use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
/// State shared between one spawned task and every waker created for it.
pub(crate) struct WakerInner {
    /// Raised by the task's wakers; the executor lowers it right before each poll.
    pub(crate) woken: AtomicBool,
    /// Optional callback invoked on every wake, copied from the executor when the task
    /// was spawned.
    notify: Option<Arc<dyn Fn() + Sync + Send>>,
}
/// Handle to a task spawned on a `LocalExecutor`.
///
/// The wrapped value is the task's slot index. Slots are recycled once a task finishes,
/// so a handle kept after its task completed may later alias a newer task in the same slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(usize);
/// Type-erased, heap-pinned task body. It carries no `Send` bound, so tasks may
/// capture thread-local values such as `Rc`.
type BoxLocalFuture = Pin<Box<dyn Future<Output = ()> + 'static>>;
/// One spawned task: its pinned future plus the wake state shared with its wakers.
struct LocalTask {
    /// Task body, polled in place through its pinned box.
    future: BoxLocalFuture,
    /// Shared with every waker handed to `future` while it is polled.
    inner: Arc<WakerInner>,
    /// Becomes `true` once `future` returned `Poll::Ready`; the executor never polls
    /// the task after that.
    completed: bool,
}
/// Executor for futures that need not be `Send`.
///
/// It owns its (possibly `!Send`) futures, so the executor itself stays on the thread
/// that created it, while the wakers it hands out may be used from any thread.
pub struct LocalExecutor {
    /// Task slots indexed by `TaskId`; `None` marks a vacant slot.
    tasks: Vec<Option<LocalTask>>,
    /// Vacant slot indices, handed out again oldest-first by `spawn`.
    free: VecDeque<usize>,
    /// Slots of tasks that were spawned but have not been polled yet, in spawn order.
    spawn_queue: VecDeque<usize>,
    /// Callback copied into each newly spawned task's wake state.
    notify: Option<Arc<dyn Fn() + Send + Sync>>,
}
impl LocalExecutor {
    /// Creates an executor with no tasks and no wake-up callback.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tasks: Vec::new(),
            free: VecDeque::new(),
            spawn_queue: VecDeque::new(),
            notify: None,
        }
    }

    /// Installs a callback that each task's waker runs after marking the task woken.
    ///
    /// The callback is copied into a task when it is spawned, so it only applies to
    /// tasks spawned after this call; existing tasks keep whatever they were created with.
    pub fn set_notify(&mut self, notify: Arc<dyn Fn() + Send + Sync>) {
        self.notify = Some(notify);
    }

    /// Queues `future` to be polled on the next [`poll_all`](Self::poll_all).
    ///
    /// Slots vacated by finished tasks are reused oldest-first, so the returned
    /// [`TaskId`] may equal the id of an earlier task that has already completed.
    pub fn spawn(&mut self, future: impl Future<Output = ()> + 'static) -> TaskId {
        let task = LocalTask {
            future: Box::pin(future),
            // Start woken so `has_woken` reports the task before its first poll.
            inner: Arc::new(WakerInner {
                woken: AtomicBool::new(true),
                notify: self.notify.clone(),
            }),
            completed: false,
        };
        let id = if let Some(idx) = self.free.pop_front() {
            self.tasks[idx] = Some(task);
            idx
        } else {
            let idx = self.tasks.len();
            self.tasks.push(Some(task));
            idx
        };
        self.spawn_queue.push_back(id);
        TaskId(id)
    }

    /// Runs one scheduling pass and returns the number of polls performed.
    ///
    /// First every newly spawned task is polled in spawn order. Then each slot is
    /// visited once, in index order, and polled if its task's woken flag is set.
    /// Finished tasks are dropped and their slots recycled immediately. A wake that
    /// happens during the scan is handled in this pass only if the woken task's slot
    /// has not been visited yet; otherwise it is handled by the next call.
    pub fn poll_all(&mut self) -> usize {
        let mut progress = 0;
        while let Some(id) = self.spawn_queue.pop_front() {
            if self.poll_and_reap(id) {
                progress += 1;
            }
        }
        // Spawning needs `&mut self`, so `tasks` cannot grow while this loop runs.
        for id in 0..self.tasks.len() {
            let wants_poll = self.tasks[id]
                .as_ref()
                .is_some_and(|t| !t.completed && t.inner.woken.load(Ordering::Acquire));
            if wants_poll && self.poll_and_reap(id) {
                progress += 1;
            }
        }
        progress
    }

    /// Polls task `id`; if that left it completed, drops it and recycles its slot.
    ///
    /// Returns whether a poll actually took place.
    fn poll_and_reap(&mut self, id: usize) -> bool {
        let polled = self.poll_task(id);
        if self.tasks[id].as_ref().is_some_and(|t| t.completed) {
            self.tasks[id] = None;
            self.free.push_back(id);
        }
        polled
    }

    /// Polls task `id` once with a waker bound to that task.
    ///
    /// Returns `false` without polling when the slot is vacant or the task has
    /// already completed.
    fn poll_task(&mut self, id: usize) -> bool {
        let Some(task) = &mut self.tasks[id] else {
            return false;
        };
        if task.completed {
            return false;
        }
        // Clear the flag *before* polling so a wake arriving while the future runs
        // re-arms it for a later pass instead of being lost.
        task.inner.woken.store(false, Ordering::Release);
        let waker = create_waker(Arc::clone(&task.inner));
        let mut cx = Context::from_waker(&waker);
        if task.future.as_mut().poll(&mut cx).is_ready() {
            task.completed = true;
        }
        true
    }

    /// Returns `true` if the task behind `id` has finished.
    ///
    /// Finished tasks are removed, so a vacant slot also reports `true`. Because slots
    /// are reused by [`spawn`](Self::spawn), once a newer task occupies the slot this
    /// reports that newer task's state instead.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not index a slot of this executor (for example, a handle
    /// obtained from a different executor).
    #[must_use]
    pub fn is_completed(&self, id: TaskId) -> bool {
        match self.tasks.get(id.0) {
            Some(Some(task)) => task.completed,
            Some(None) => true,
            None => {
                panic!("TaskId({}) out of range (max: {})", id.0, self.tasks.len());
            }
        }
    }

    /// Number of spawned tasks that have not completed yet.
    #[must_use]
    pub fn active_count(&self) -> usize {
        self.tasks
            .iter()
            .filter(|slot| slot.as_ref().is_some_and(|t| !t.completed))
            .count()
    }

    /// Returns `true` when there is no unfinished task and nothing awaiting a first poll.
    #[must_use]
    pub fn is_idle(&self) -> bool {
        self.active_count() == 0 && self.spawn_queue.is_empty()
    }

    /// Returns `true` if some task is waiting to be polled.
    ///
    /// A task is waiting if it has never been polled, or if its waker fired since its
    /// last poll.
    #[must_use]
    pub fn has_woken(&self) -> bool {
        if !self.spawn_queue.is_empty() {
            return true;
        }
        self.tasks.iter().any(|slot| {
            slot.as_ref()
                .is_some_and(|t| !t.completed && t.inner.woken.load(Ordering::Acquire))
        })
    }
}
impl Default for LocalExecutor {
fn default() -> Self {
Self::new()
}
}
/// Turns one strong reference to a task's wake state into a [`Waker`].
///
/// Ownership of `inner` moves into the returned waker and is released by the vtable's
/// `wake`/`drop` entry.
fn create_waker(inner: Arc<WakerInner>) -> Waker {
    let data: *const () = Arc::into_raw(inner).cast();
    // SAFETY: every vtable entry interprets `data` as a pointer produced by
    // `Arc::into_raw::<WakerInner>`, which is exactly what it is here. `WakerInner` only
    // holds an `AtomicBool` and an `Arc<dyn Fn() + Send + Sync>`, so it is `Send + Sync`
    // and the waker may be used from any thread.
    unsafe { Waker::from_raw(RawWaker::new(data, &VTABLE)) }
}
/// Vtable shared by every waker this module creates.
///
/// This is a `static`, not a `const`, so it has exactly one address. Each use of a
/// `const` may materialise a separate copy. `Waker::will_wake` compares the data and
/// vtable pointers, so separate copies would make it report `false` for two wakers of
/// the same task.
static VTABLE: RawWakerVTable =
    RawWakerVTable::new(waker_clone, waker_wake, waker_wake_by_ref, waker_drop);
unsafe fn waker_clone(ptr: *const ()) -> RawWaker {
let arc = unsafe { Arc::from_raw(ptr as *const WakerInner) };
let cloned = arc.clone();
std::mem::forget(arc); RawWaker::new(Arc::into_raw(cloned) as *const (), &VTABLE)
}
unsafe fn waker_wake(ptr: *const ()) {
let arc = unsafe { Arc::from_raw(ptr as *const WakerInner) };
arc.woken.store(true, Ordering::Release);
if let Some(notify) = &arc.notify {
notify();
}
}
unsafe fn waker_wake_by_ref(ptr: *const ()) {
let arc = unsafe { Arc::from_raw(ptr as *const WakerInner) };
arc.woken.store(true, Ordering::Release);
if let Some(notify) = &arc.notify {
notify();
}
std::mem::forget(arc); }
/// Vtable `drop` entry: releases the strong reference owned by the dropped waker.
unsafe fn waker_drop(ptr: *const ()) {
    // SAFETY: `ptr` came from `Arc::into_raw`, and the waker being dropped owns exactly
    // one strong reference, which is given up here.
    unsafe { Arc::decrement_strong_count(ptr as *const WakerInner) };
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::{Cell, RefCell};
    use std::rc::Rc;
    use std::sync::atomic::AtomicUsize;
    use std::task::Waker;

    #[test]
    fn spawn_and_complete_immediate() {
        let mut exec = LocalExecutor::new();
        let done = Rc::new(Cell::new(false));
        let d = done.clone();
        let id = exec.spawn(async move {
            d.set(true);
        });
        assert_eq!(exec.active_count(), 1);
        exec.poll_all();
        assert!(done.get());
        assert!(exec.is_completed(id));
        assert!(exec.is_idle());
    }

    #[test]
    fn spawn_pending_then_wake() {
        let mut exec = LocalExecutor::new();
        let counter = Rc::new(Cell::new(0u32));
        let waker_holder: Rc<RefCell<Option<Waker>>> = Rc::new(RefCell::new(None));
        let c = counter.clone();
        let wh = waker_holder.clone();
        exec.spawn(async move {
            std::future::poll_fn(|cx| {
                let count = c.get();
                if count == 0 {
                    *wh.borrow_mut() = Some(cx.waker().clone());
                    c.set(1);
                    Poll::Pending
                } else {
                    c.set(2);
                    Poll::Ready(())
                }
            })
            .await;
        });
        assert_eq!(exec.poll_all(), 1);
        assert_eq!(counter.get(), 1);
        assert!(!exec.is_idle());
        // Not woken yet: a further pass must not poll the task.
        assert_eq!(exec.poll_all(), 0);
        waker_holder.borrow().as_ref().unwrap().wake_by_ref();
        assert_eq!(exec.poll_all(), 1);
        assert_eq!(counter.get(), 2);
        assert!(exec.is_idle());
    }

    #[test]
    fn wake_from_another_thread() {
        let mut exec = LocalExecutor::new();
        let waker_holder: Arc<std::sync::Mutex<Option<Waker>>> =
            Arc::new(std::sync::Mutex::new(None));
        let wh = waker_holder.clone();
        exec.spawn(async move {
            std::future::poll_fn(|cx| {
                let mut guard = wh.lock().unwrap();
                if guard.is_none() {
                    *guard = Some(cx.waker().clone());
                    Poll::Pending
                } else {
                    Poll::Ready(())
                }
            })
            .await;
        });
        exec.poll_all();
        assert!(!exec.is_idle());
        let wh = waker_holder.clone();
        let handle = std::thread::spawn(move || {
            let guard = wh.lock().unwrap();
            guard.as_ref().unwrap().wake_by_ref();
        });
        handle.join().unwrap();
        exec.poll_all();
        assert!(exec.is_idle());
    }

    #[test]
    fn multiple_tasks() {
        let mut exec = LocalExecutor::new();
        let log = Rc::new(RefCell::new(Vec::new()));
        for i in 0..5 {
            let l = log.clone();
            exec.spawn(async move {
                l.borrow_mut().push(i);
            });
        }
        assert_eq!(exec.poll_all(), 5);
        assert_eq!(*log.borrow(), vec![0, 1, 2, 3, 4]);
        assert!(exec.is_idle());
        assert_eq!(exec.poll_all(), 0);
    }

    #[test]
    fn task_id_reuse() {
        let mut exec = LocalExecutor::new();
        let id1 = exec.spawn(async {});
        exec.poll_all();
        assert!(exec.is_completed(id1));
        let id2 = exec.spawn(async {});
        assert_eq!(id1.0, id2.0);
        exec.poll_all();
        assert!(exec.is_completed(id2));
    }

    #[test]
    fn has_woken() {
        let mut exec = LocalExecutor::new();
        assert!(!exec.has_woken());
        let wh: Rc<RefCell<Option<Waker>>> = Rc::new(RefCell::new(None));
        let wh2 = wh.clone();
        exec.spawn(async move {
            std::future::poll_fn(|cx| {
                *wh2.borrow_mut() = Some(cx.waker().clone());
                Poll::<()>::Pending
            })
            .await;
        });
        assert!(exec.has_woken());
        exec.poll_all();
        assert!(!exec.has_woken());
        wh.borrow().as_ref().unwrap().wake_by_ref();
        assert!(exec.has_woken());
    }

    #[test]
    fn waker_clone_and_drop() {
        let inner = Arc::new(WakerInner {
            woken: AtomicBool::new(false),
            notify: None,
        });
        let waker = create_waker(inner.clone());
        assert_eq!(Arc::strong_count(&inner), 2);
        let waker2 = waker.clone();
        assert_eq!(Arc::strong_count(&inner), 3);
        drop(waker);
        assert_eq!(Arc::strong_count(&inner), 2);
        waker2.wake_by_ref();
        assert!(inner.woken.load(Ordering::Acquire));
        assert_eq!(Arc::strong_count(&inner), 2);
        drop(waker2);
        assert_eq!(Arc::strong_count(&inner), 1);

        // `wake` consumes the waker and must release its reference.
        inner.woken.store(false, Ordering::Release);
        let waker3 = create_waker(inner.clone());
        waker3.wake();
        assert!(inner.woken.load(Ordering::Acquire));
        assert_eq!(Arc::strong_count(&inner), 1);
    }

    #[test]
    fn completed_task_query() {
        let mut exec = LocalExecutor::new();
        let gate = Rc::new(Cell::new(false));
        let g = gate.clone();
        let wh: Rc<RefCell<Option<Waker>>> = Rc::new(RefCell::new(None));
        let wh2 = wh.clone();
        let id = exec.spawn(async move {
            std::future::poll_fn(|cx| {
                if g.get() {
                    Poll::Ready(())
                } else {
                    *wh2.borrow_mut() = Some(cx.waker().clone());
                    Poll::Pending
                }
            })
            .await;
        });
        assert!(!exec.is_completed(id));
        exec.poll_all();
        assert!(!exec.is_completed(id));
        gate.set(true);
        wh.borrow().as_ref().unwrap().wake_by_ref();
        exec.poll_all();
        assert!(exec.is_completed(id));
    }

    #[test]
    #[should_panic(expected = "out of range")]
    fn is_completed_out_of_range_panics() {
        let exec = LocalExecutor::new();
        let _ = exec.is_completed(TaskId(0));
    }

    #[test]
    fn notify_runs_on_wake() {
        let mut exec = LocalExecutor::new();
        let hits = Arc::new(AtomicUsize::new(0));
        let h = hits.clone();
        exec.set_notify(Arc::new(move || {
            h.fetch_add(1, Ordering::SeqCst);
        }));
        let wh: Rc<RefCell<Option<Waker>>> = Rc::new(RefCell::new(None));
        let wh2 = wh.clone();
        exec.spawn(async move {
            std::future::poll_fn(|cx| {
                *wh2.borrow_mut() = Some(cx.waker().clone());
                Poll::<()>::Pending
            })
            .await;
        });
        exec.poll_all();
        assert_eq!(hits.load(Ordering::SeqCst), 0);
        let waker = wh.borrow().clone().unwrap();
        waker.wake_by_ref();
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        waker.wake();
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn immediate_cleanup_on_complete() {
        let mut exec = LocalExecutor::new();
        let id = exec.spawn(async {});
        exec.poll_all();
        assert!(exec.is_completed(id));
        assert!(exec.is_idle());
        let id2 = exec.spawn(async {});
        assert_eq!(id.0, id2.0);
    }
}