use async_lock::Barrier;
use async_task::Runnable;
use async_task::Task;
use concurrent_queue::ConcurrentQueue;
use futures::future;
use futures::future::Either;
use futures::future::select;
use slab::Slab;
use std::collections::HashSet;
use std::fmt;
use std::future::Future;
use std::panic::RefUnwindSafe;
use std::panic::UnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use std::thread;
use crate::runtime::BlockId;
use crate::runtime::FlowgraphMessage;
use crate::runtime::block::Block;
use crate::runtime::channel::mpsc::Sender;
use crate::runtime::channel::oneshot;
use crate::runtime::config;
use crate::runtime::scheduler::Scheduler;
/// Scheduler handle backed by a shared [`FlowExecutor`] and a set of worker threads.
///
/// Clones share the same reference-counted inner state; the worker threads are
/// stopped when the last clone (and therefore the inner state) is dropped.
#[derive(Clone, Debug)]
pub struct FlowScheduler {
// Shared state: executor, worker handles and the pinned-block mapping.
inner: Arc<FlowSchedulerInner>,
}
/// Shared state behind a [`FlowScheduler`].
struct FlowSchedulerInner {
// Executor onto which all tasks of this scheduler are spawned.
executor: Arc<FlowExecutor>,
// One entry per worker thread: its join handle and the sender used to tell it to stop.
workers: Vec<(thread::JoinHandle<()>, oneshot::Sender<()>)>,
// Explicit placement: the blocks listed at index `i` are spawned on worker `i`.
pinned_blocks: Vec<Vec<BlockId>>,
}
impl fmt::Debug for FlowSchedulerInner {
    /// Prints only the type name; none of the fields are included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("FlowSchedulerInner");
        builder.finish()
    }
}
impl Drop for FlowSchedulerInner {
    /// Signals every worker to stop and waits for its thread to finish.
    fn drop(&mut self) {
        for (handle, stop) in self.workers.drain(..) {
            // A failed send means the receiving side is already gone.
            if stop.send(()).is_err() {
                warn!("Worker task already terminated.");
            }

            // Never join the current thread: if this drop runs on a worker,
            // joining that worker's own handle could not complete.
            let on_worker_thread = std::thread::current().id() == handle.thread().id();
            if !on_worker_thread && handle.join().is_err() {
                warn!("Worker thread already terminated.");
            }
        }
    }
}
impl FlowScheduler {
pub fn new() -> FlowScheduler {
FlowScheduler::with_pinned_blocks(Vec::new())
}
pub fn with_pinned_blocks(pinned_blocks: Vec<Vec<BlockId>>) -> FlowScheduler {
let core_ids = core_affinity::get_core_ids().unwrap();
let executor = Arc::new(FlowExecutor::new(core_ids.len()));
let mut workers = Vec::new();
debug!("flowsched: core ids {}", core_ids.len());
let barrier = Arc::new(Barrier::new(core_ids.len() + 1));
for (worker_index, id) in core_ids.into_iter().enumerate() {
let b = barrier.clone();
let e = executor.clone();
let (sender, receiver) = oneshot::channel::<()>();
let handle = thread::Builder::new()
.stack_size(config::config().stack_size)
.name(format!("flow-{}", id.id))
.spawn(move || {
debug!("starting executor thread on core id {}", id.id);
core_affinity::set_for_current(id);
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
crate::runtime::block_on(e.run_on(worker_index, async {
b.wait().await;
receiver.await
}))
}));
if result.is_err() {
eprintln!("flow worker panicked {result:?}");
std::process::exit(1);
}
})
.expect("cannot spawn executor thread");
workers.push((handle, sender));
}
crate::runtime::block_on(barrier.wait());
FlowScheduler {
inner: Arc::new(FlowSchedulerInner {
executor,
workers,
pinned_blocks,
}),
}
}
fn map_block(block: usize, n_blocks: usize, n_cores: usize) -> usize {
let n = n_blocks / n_cores;
let r = n_blocks % n_cores;
for x in 1..n_cores {
if block < ((x) * n) + std::cmp::min(x, r) {
return x - 1;
}
}
n_cores - 1
}
}
impl Scheduler for FlowScheduler {
    /// Spawns every block of a domain on the executor and returns one task per block.
    ///
    /// Blocks named in the pinned-block mapping are spawned on the worker whose
    /// index matches their position in the mapping. Mapping entries are skipped
    /// with a warning when the worker index is out of range, when a block id is
    /// listed more than once, or when the id is not part of `blocks`. All
    /// remaining blocks are distributed over the workers with
    /// [`FlowScheduler::map_block`], keyed by their block id.
    ///
    /// Each returned task resolves to the block's id and the block itself once
    /// the block has finished running.
    fn run_domain(
        &self,
        blocks: Vec<Box<dyn Block>>,
        main_channel: &Sender<FlowgraphMessage>,
    ) -> Vec<Task<(BlockId, Box<dyn Block>)>> {
        let n_blocks = blocks.len();
        let n_cores = self.inner.workers.len();
        let mut spawned: HashSet<BlockId> = HashSet::new();

        let mut blocks_by_id = Vec::with_capacity(n_blocks);
        for block in blocks {
            let id = block.id();
            blocks_by_id.push((id, block));
        }

        let mut tasks = Vec::with_capacity(n_blocks);

        // First pass: honour the explicit block-to-worker mapping.
        for (executor, block_ids) in self.inner.pinned_blocks.iter().enumerate() {
            if executor >= n_cores {
                warn!(
                    "flowsched mapping has executor index {} but only {} executors are available",
                    executor, n_cores
                );
                continue;
            }
            for block_id in block_ids {
                // Check for duplicates first: an already-spawned block has been
                // removed from `blocks_by_id`, so the lookup below would otherwise
                // misreport a repeated id as unknown.
                if spawned.contains(block_id) {
                    warn!(
                        "flowsched mapping references block id {:?} more than once",
                        block_id
                    );
                    continue;
                }
                let Some(pos) = blocks_by_id.iter().position(|(id, _)| id == block_id) else {
                    warn!(
                        "flowsched mapping references unknown block id {:?}",
                        block_id
                    );
                    continue;
                };
                spawned.insert(*block_id);
                let (_, block) = blocks_by_id.swap_remove(pos);
                tasks.push(spawn_block_on_executor(
                    &self.inner.executor,
                    block,
                    main_channel.clone(),
                    executor,
                ));
            }
        }

        // Second pass: spread everything that was not pinned across the workers.
        for (id, block) in blocks_by_id.into_iter() {
            if spawned.contains(&id) {
                continue;
            }
            let executor = FlowScheduler::map_block(id.0, n_blocks, n_cores);
            tasks.push(spawn_block_on_executor(
                &self.inner.executor,
                block,
                main_channel.clone(),
                executor,
            ));
        }
        tasks
    }

    /// Spawns `future` on the executor's shared queue, where any worker may run it.
    fn spawn<T: Send + 'static>(
        &self,
        future: impl Future<Output = T> + Send + 'static,
    ) -> Task<T> {
        self.inner.executor.spawn(future)
    }
}
impl Default for FlowScheduler {
fn default() -> Self {
Self::new()
}
}
/// Spawns `block` on the worker queue `queue_index` of `executor`.
///
/// The returned task runs the block with `main_channel` and resolves to the
/// block's id together with the block once `run` has finished. Debug builds
/// assert that the block is not a blocking one.
fn spawn_block_on_executor(
    executor: &FlowExecutor,
    block: Box<dyn Block>,
    main_channel: Sender<FlowgraphMessage>,
    queue_index: usize,
) -> Task<(BlockId, Box<dyn Block>)> {
    debug_assert!(
        !block.is_blocking(),
        "blocking blocks must be placed in local domains before scheduling"
    );

    let run_to_completion = async move {
        let mut block = block;
        let id = block.id();
        block.run(main_channel).await;
        (id, block)
    };
    executor.spawn_executor(run_to_completion, queue_index)
}
/// Task executor with one shared queue plus one local queue per worker.
pub struct FlowExecutor {
// Shared executor state, created on first access.
state: once_cell::sync::OnceCell<Arc<State>>,
// Number of workers (and therefore local queues and wake signals) to create.
worker_count: usize,
}
/// Capacity of each worker's bounded local queue.
const LOCAL_QUEUE_CAPACITY: usize = 512;
// Marker impls: declare `FlowExecutor` usable across `catch_unwind` boundaries.
// NOTE(review): these are manual assertions, not derived from the fields —
// confirm they still hold whenever the executor's fields change.
impl UnwindSafe for FlowExecutor {}
impl RefUnwindSafe for FlowExecutor {}
impl FlowExecutor {
    /// Creates an executor for `worker_count` workers.
    ///
    /// The shared state (queues and wake signals) is built lazily on first use.
    pub const fn new(worker_count: usize) -> FlowExecutor {
        FlowExecutor {
            state: once_cell::sync::OnceCell::new(),
            worker_count,
        }
    }

    /// Spawns `future` onto the shared queue, where any worker may pick it up.
    ///
    /// The task's waker is recorded in the `active` slab and removed again when
    /// the task's future is dropped.
    pub fn spawn<T: Send + 'static>(
        &self,
        future: impl Future<Output = T> + Send + 'static,
    ) -> Task<T> {
        let mut active = self.state().active.lock().unwrap();
        let entry = active.vacant_entry();
        let key = entry.key();
        let state = self.state().clone();
        let future = async move {
            // Unregister the task once its future is dropped (finished or cancelled).
            let _guard = CallOnDrop(move || drop(state.active.lock().unwrap().try_remove(key)));
            future.await
        };
        // SAFETY: the caller's future is `Send + 'static`, the wrapper additionally
        // captures only an `Arc<State>` and a slab key, and the schedule closure is
        // declared `Send + Sync + 'static`.
        // NOTE(review): this relies on `State` being `Send + Sync`.
        let (runnable, task) = unsafe { async_task::spawn_unchecked(future, self.schedule()) };
        entry.insert(runnable.waker());
        runnable.schedule();
        task
    }

    /// Spawns `future` onto the local queue of worker `executor`.
    ///
    /// Whenever the task is scheduled it is pushed to that worker's local queue;
    /// if the push fails (e.g. the bounded queue is full) it is pushed to the
    /// shared queue instead.
    ///
    /// # Panics
    ///
    /// Panics if `executor` is not a valid worker index.
    pub fn spawn_executor<T: Send + 'static>(
        &self,
        future: impl Future<Output = T> + Send + 'static,
        executor: usize,
    ) -> Task<T> {
        let state = self.state();
        // Resolve the target queue *before* taking the `active` lock: panicking
        // on a bad index while holding the guard would poison the mutex and make
        // every later spawn and task-completion guard panic as well.
        let local = state
            .local_queues
            .get(executor)
            .cloned()
            .expect("executor queue not initialized");

        let mut active = state.active.lock().unwrap();
        let entry = active.vacant_entry();
        let key = entry.key();
        let task_state = state.clone();
        let future = async move {
            // Unregister the task once its future is dropped (finished or cancelled).
            let _guard =
                CallOnDrop(move || drop(task_state.active.lock().unwrap().try_remove(key)));
            future.await
        };
        // SAFETY: the caller's future is `Send + 'static`, the wrapper additionally
        // captures only an `Arc<State>` and a slab key, and the schedule closure is
        // declared `Send + Sync + 'static`.
        // NOTE(review): this relies on `State` being `Send + Sync`.
        let (runnable, task) =
            unsafe { async_task::spawn_unchecked(future, self.schedule_executor(local, executor)) };
        entry.insert(runnable.waker());
        runnable.schedule();
        task
    }

    /// Runs tasks as worker `worker_index` until `future` completes, then
    /// returns the output of `future`.
    pub async fn run_on<T>(&self, worker_index: usize, future: impl Future<Output = T>) -> T {
        let mut runner = Runner::new(self.state(), worker_index);
        let run_forever = async {
            loop {
                // Run a batch of tasks, then yield so that `future` gets polled again.
                for _ in 0..200 {
                    let runnable = runner.runnable().await;
                    runnable.run();
                }
                yield_now().await;
            }
        };
        futures::pin_mut!(future);
        futures::pin_mut!(run_forever);
        match select(future, run_forever).await {
            Either::Left((v, _other)) => v,
            // `run_forever` never completes; this arm only satisfies the type checker.
            Either::Right((v, _other)) => v,
        }
    }

    /// Returns the schedule function for tasks on the shared queue: push the
    /// runnable, then wake one sleeping worker.
    fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
        let state = self.state().clone();
        move |runnable| {
            state.queue.push(runnable).unwrap();
            state.notify();
        }
    }

    /// Returns the schedule function for tasks bound to worker `executor`.
    ///
    /// The runnable goes to `local`; if that push fails it is pushed to the
    /// shared queue and any sleeping worker is notified instead.
    fn schedule_executor(
        &self,
        local: Arc<ConcurrentQueue<Runnable>>,
        executor: usize,
    ) -> impl Fn(Runnable) + Send + Sync + 'static {
        let state = self.state().clone();
        move |runnable| {
            if let Err(err) = local.push(runnable) {
                state.queue.push(err.into_inner()).unwrap();
                state.notify();
                return;
            }
            // The worker may already be awake; in that case no wake-up is needed.
            let _ = state.wake_worker(executor);
        }
    }

    /// Returns the shared state, creating it on first access.
    fn state(&self) -> &Arc<State> {
        self.state
            .get_or_init(|| Arc::new(State::new(self.worker_count)))
    }
}
impl Drop for FlowExecutor {
    /// Wakes every registered task and then empties the shared and local queues,
    /// dropping the runnables that were queued.
    #[allow(clippy::significant_drop_in_scrutinee)]
    fn drop(&mut self) {
        debug!("dropping flow executor");

        // Nothing to clean up if the state was never initialized.
        let Some(state) = self.state.get() else {
            return;
        };

        {
            let active = state.active.lock().unwrap();
            active.iter().for_each(|(_, waker)| waker.wake_by_ref());
            // The guard is released here, before the queues are drained.
        }

        while state.queue.pop().is_ok() {}
        for local in &state.local_queues {
            while local.pop().is_ok() {}
        }
    }
}
/// State shared between the executor handle, its workers and the spawned tasks.
struct State {
// Shared (unbounded) queue; any worker may take runnables from it.
queue: ConcurrentQueue<Runnable>,
// One bounded queue per worker, indexed by worker index.
local_queues: Vec<Arc<ConcurrentQueue<Runnable>>>,
// One sleep/wake signal per worker, indexed by worker index.
worker_signals: Vec<Arc<WorkerSignal>>,
// Rotating start index used by `notify` when looking for a sleeping worker.
next_wake: AtomicUsize,
// Wakers of all currently registered tasks.
active: Mutex<Slab<Waker>>,
}
impl State {
    /// Builds the state for `worker_count` workers: one bounded local queue and
    /// one wake signal per worker, plus the unbounded shared queue.
    fn new(worker_count: usize) -> State {
        let local_queues = std::iter::repeat_with(|| {
            Arc::new(ConcurrentQueue::bounded(LOCAL_QUEUE_CAPACITY))
        })
        .take(worker_count)
        .collect();
        let worker_signals = std::iter::repeat_with(|| Arc::new(WorkerSignal::default()))
            .take(worker_count)
            .collect();

        State {
            queue: ConcurrentQueue::unbounded(),
            local_queues,
            worker_signals,
            next_wake: AtomicUsize::new(0),
            active: Mutex::new(Slab::new()),
        }
    }

    /// Wakes at most one sleeping worker, starting the search at a rotating index.
    #[inline]
    fn notify(&self) {
        let worker_total = self.worker_signals.len();
        if worker_total == 0 {
            return;
        }
        let first = self.next_wake.fetch_add(1, Ordering::Relaxed) % worker_total;
        // `any` stops at the first worker that was actually woken.
        let _woke_one = (0..worker_total)
            .map(|offset| (first + offset) % worker_total)
            .any(|candidate| self.wake_worker(candidate));
    }

    /// Wakes worker `queue_index` if it is marked as sleeping.
    ///
    /// Returns `true` only if the worker was sleeping and has been woken;
    /// an out-of-range index yields `false`.
    #[inline]
    fn wake_worker(&self, queue_index: usize) -> bool {
        let Some(signal) = self.worker_signals.get(queue_index) else {
            return false;
        };
        let was_sleeping = signal.sleeping.swap(false, Ordering::AcqRel);
        if was_sleeping {
            signal.waker.wake();
        }
        was_sleeping
    }
}
/// Per-worker sleep flag and waker used to wake an idle worker.
#[derive(Debug, Default)]
struct WorkerSignal {
// Set by the worker before it goes idle; cleared by whoever wakes it.
sleeping: AtomicBool,
// Waker of the worker's task, registered while the worker searches for work.
waker: futures::task::AtomicWaker,
}
/// Waits for runnables on behalf of one worker, using that worker's signal.
struct Ticker<'a> {
// Signal through which this worker is marked sleeping and woken.
signal: &'a WorkerSignal,
}
impl Ticker<'_> {
/// Creates a ticker bound to `signal`.
fn new(signal: &WorkerSignal) -> Ticker<'_> {
Ticker { signal }
}
/// Resolves to the next runnable found by `search`, going idle while there is none.
///
/// `search` is called repeatedly; it returns `Some` when a runnable is available.
async fn runnable_with(&mut self, mut search: impl FnMut() -> Option<Runnable>) -> Runnable {
future::poll_fn(|cx| {
loop {
// Fast path: work is available right away.
if let Some(r) = search() {
self.signal.sleeping.store(false, Ordering::Release);
return Poll::Ready(r);
}
// Announce the intent to sleep before registering the waker, so that a
// concurrent wake-up can be detected below.
self.signal.sleeping.store(true, Ordering::Release);
self.signal.waker.register(cx.waker());
// The flag was cleared in the meantime: someone tried to wake this worker,
// so search again instead of going idle.
if !self.signal.sleeping.load(Ordering::Acquire) {
continue;
}
// Final check after registering: catches work queued before the sleeping
// flag became visible to the producer.
if let Some(r) = search() {
self.signal.sleeping.store(false, Ordering::Release);
return Poll::Ready(r);
}
return Poll::Pending;
}
})
.await
}
}
impl Drop for Ticker<'_> {
// Clears the sleeping flag so a dropped ticker is never left marked as sleeping.
fn drop(&mut self) {
self.signal.sleeping.store(false, Ordering::Release);
}
}
/// A worker's view of the executor: shared state, its ticker and its local queue.
struct Runner<'a> {
// Shared executor state; gives access to the shared queue.
state: &'a State,
// Used to wait until a runnable becomes available.
ticker: Ticker<'a>,
// This worker's local queue, checked before the shared queue.
local: Arc<ConcurrentQueue<Runnable>>,
}
impl Runner<'_> {
fn new(state: &State, worker_index: usize) -> Runner<'_> {
let local = state
.local_queues
.get(worker_index)
.cloned()
.expect("worker local queue not initialized");
let signal = state
.worker_signals
.get(worker_index)
.expect("worker signal not initialized");
Runner {
state,
ticker: Ticker::new(signal),
local,
}
}
async fn runnable(&mut self) -> Runnable {
self.ticker
.runnable_with(|| {
if let Ok(r) = self.local.pop() {
return Some(r);
}
if let Ok(r) = self.state.queue.pop() {
return Some(r);
}
None
})
.await
}
}
impl Drop for Runner<'_> {
// Intentionally empty: no cleanup is performed here.
// NOTE(review): presumably kept as a placeholder — confirm whether it can be removed.
fn drop(&mut self) {
}
}
/// Guard that invokes the wrapped closure when the guard is dropped.
struct CallOnDrop<F: Fn()>(F);

impl<F: Fn()> Drop for CallOnDrop<F> {
    /// Runs the stored closure.
    fn drop(&mut self) {
        let callback: &F = &self.0;
        callback();
    }
}
/// Returns a future that is pending on its first poll (after waking itself)
/// and ready on every later poll.
fn yield_now() -> YieldNow {
YieldNow(false)
}
/// Future that yields to the caller exactly once.
///
/// The wrapped flag records whether the single yield has already happened.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct YieldNow(bool);

impl Future for YieldNow {
    type Output = ();

    /// The first poll wakes the task and returns `Pending`; later polls return `Ready`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let already_yielded = std::mem::replace(&mut self.0, true);
        if already_yielded {
            return Poll::Ready(());
        }
        // Request an immediate re-poll so the task resumes after yielding.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}