use crate::boxed::Box;
use crate::cell::{Cell, RefCell};
use crate::collections::VecDeque;
use crate::future::Future;
use crate::io;
use crate::pin::Pin;
use crate::task::{Context, Poll, Waker};
use crate::time::Duration;
use alloc::sync::Arc;
use core::sync::atomic::{AtomicBool, Ordering};
use super::{timers::next_deadline, waker::task_waker};
#[cfg(target_os = "linux")]
use super::linux_epoll::Driver;
#[cfg(windows)]
use super::windows::Driver;
#[cfg(any(
target_os = "macos",
target_os = "freebsd",
target_os = "openbsd",
target_os = "netbsd"
))]
use super::bsd::Driver;
/// A spawned unit of work held in the thread-local run queue.
struct Task {
    /// The type-erased, heap-pinned future being driven. `poll_step` drops the task once
    /// this future returns `Ready`.
    future: Pin<Box<dyn Future<Output = ()>>>,
    /// Scheduling flag: `poll_step` only polls the task when this is set, clearing it just
    /// before the poll. The waker passed to each poll is built (via `task_waker`) from a
    /// clone of this same `Arc`.
    woken: Arc<AtomicBool>,
}
thread_local! {
    /// Per-thread run queue; tasks are popped from the front and re-queued at the back.
    static TASKS: RefCell<VecDeque<Task>> = RefCell::new(VecDeque::new());
    /// Per-thread I/O driver, created lazily by `with_driver` on first use.
    static DRIVER: RefCell<Option<Driver>> = RefCell::new(None);
    /// Count of still-pending tasks recorded by `poll_step` at the end of each pass.
    /// No reader is visible in this file — presumably consumed elsewhere in the crate.
    static TASK_DEPTH: Cell<usize> = Cell::new(0);
}
/// Runs `f` with exclusive access to this thread's task queue and returns its result.
///
/// The queue stays mutably borrowed while `f` runs, so `f` must not call anything that
/// touches the queue again (such as `spawn`); the nested `RefCell` borrow would panic.
fn tasks_mut<R>(f: impl FnOnce(&mut VecDeque<Task>) -> R) -> R {
    TASKS.with(|cell| {
        let mut queue = cell.borrow_mut();
        f(&mut *queue)
    })
}
/// Runs `f` against this thread's I/O driver, creating the driver on first use.
///
/// The driver stays mutably borrowed for the duration of `f`, so `f` must not call
/// `with_driver` again; the nested `RefCell` borrow would panic.
///
/// # Errors
///
/// Returns the error from `Driver::new` when the driver does not exist yet and cannot be
/// created. The slot is then left empty, so a later call retries the construction.
pub(crate) fn with_driver<R>(f: impl FnOnce(&mut Driver) -> R) -> io::Result<R> {
    DRIVER.with(|slot| -> io::Result<R> {
        let mut guard = slot.borrow_mut();
        if guard.is_none() {
            let fresh = Driver::new()?;
            *guard = Some(fresh);
        }
        guard
            .as_mut()
            .map(f)
            .ok_or_else(|| io::Error::other("driver init failed"))
    })
}
/// State shared between a spawned task's `JoinFuture` wrapper and its `JoinHandle`.
struct JoinState<T> {
    /// The task's output: stored by `JoinFuture` on completion, taken by the first
    /// `JoinHandle` poll that observes it.
    result: Option<T>,
    /// Waker registered by the most recent `JoinHandle` poll that found no result yet;
    /// taken and woken by `JoinFuture` when the output is stored.
    waker: Option<Waker>,
}
/// Queue-side wrapper around a spawned future: drives it and publishes its output to the
/// shared `JoinState`. This is what `spawn` actually places in the run queue, which is why
/// its own output type is `()`.
struct JoinFuture<T> {
    /// The user's future, type-erased and pinned on the heap.
    inner: Pin<Box<dyn Future<Output = T>>>,
    /// Shared with the `JoinHandle` returned by `spawn`.
    state: Arc<RefCell<JoinState<T>>>,
}
impl<T: 'static> Future for JoinFuture<T> {
    type Output = ();

    /// Polls the wrapped future. On completion, stores its output in the shared join state
    /// and wakes the task awaiting the matching `JoinHandle`, if one has registered a waker.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // `JoinFuture` is `Unpin` (it only holds a `Pin<Box<_>>` and an `Arc`).
        let this = Pin::into_inner(self);
        let Poll::Ready(val) = this.inner.as_mut().poll(cx) else {
            return Poll::Pending;
        };
        // Publish the result and grab the waiter inside a short scope so the `RefCell`
        // borrow is released before `wake` runs. A waker that re-enters the join state
        // (e.g. by polling the handle synchronously) would otherwise panic with
        // `BorrowMutError`; no foreign code should run while the state is borrowed.
        let waiter = {
            let mut state = this.state.borrow_mut();
            state.result = Some(val);
            state.waker.take()
        };
        if let Some(w) = waiter {
            w.wake();
        }
        Poll::Ready(())
    }
}
/// Handle to a task created by [`spawn`]; awaiting it yields the task's output.
///
/// Dropping the handle detaches the task: it still runs to completion, and its output is
/// discarded along with the shared state.
pub struct JoinHandle<T> {
    /// State shared with the task's `JoinFuture` wrapper.
    state: Arc<RefCell<JoinState<T>>>,
}
impl<T> Future for JoinHandle<T> {
    type Output = T;

    /// Resolves with the task's output once it has been published. Otherwise, records the
    /// current waker (replacing any earlier one) so the task can wake this awaiter when it
    /// completes. The output is moved out on the first `Ready`, so polling again afterwards
    /// returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        let mut shared = self.state.borrow_mut();
        match shared.result.take() {
            Some(output) => Poll::Ready(output),
            None => {
                shared.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
/// Queues `future` on the current thread's executor and returns a handle to its output.
///
/// The task is created already marked as woken, so the next `poll_step` pass that starts
/// after this call will poll it. Dropping the returned [`JoinHandle`] detaches the task; it
/// still runs to completion.
pub fn spawn<F, T>(future: F) -> JoinHandle<T>
where
    F: Future<Output = T> + 'static,
    T: 'static,
{
    let shared = Arc::new(RefCell::new(JoinState {
        result: None,
        waker: None,
    }));
    let task = Task {
        future: Box::pin(JoinFuture {
            inner: Box::pin(future),
            state: Arc::clone(&shared),
        }),
        // Start "woken" so the scheduler polls the task without an explicit wake.
        woken: Arc::new(AtomicBool::new(true)),
    };
    tasks_mut(move |queue| queue.push_back(task));
    JoinHandle { state: shared }
}
/// Spawns `future` as a detached task on the current thread's executor.
///
/// This only enqueues the task; it is executed by subsequent calls to `poll_step`.
pub(crate) fn run<F>(future: F)
where
    F: Future<Output = ()> + 'static,
{
    // The handle is not needed: a detached task still runs to completion.
    drop(spawn(future));
}
/// Output of [`select`]: which of the two futures finished first, carrying its value.
pub enum Either<A, B> {
    /// The first (left) future completed.
    Left(A),
    /// The second (right) future completed.
    Right(B),
}

/// Future returned by [`select`]; resolves with whichever of two futures completes first.
///
/// Both futures are boxed, which makes `Select` itself `Unpin` regardless of `FA` and `FB`.
/// The other future is simply dropped together with the `Select`.
pub struct Select<FA, FB> {
    a: Pin<Box<FA>>,
    b: Pin<Box<FB>>,
}

impl<FA: Future, FB: Future> Future for Select<FA, FB> {
    type Output = Either<FA::Output, FB::Output>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Select` is `Unpin`, so a plain `&mut` to its fields is fine.
        let this = self.get_mut();
        // The left future is always polled first, so it wins when both are ready.
        match this.a.as_mut().poll(cx) {
            Poll::Ready(left) => Poll::Ready(Either::Left(left)),
            Poll::Pending => this.b.as_mut().poll(cx).map(Either::Right),
        }
    }
}

/// Races `a` against `b`, resolving with the output of whichever completes first.
///
/// On every poll `a` is polled before `b`, so `a` wins if both are ready at once. Like any
/// future, the result does nothing until it is polled.
pub fn select<FA: Future, FB: Future>(a: FA, b: FB) -> Select<FA, FB> {
    let (a, b) = (Box::pin(a), Box::pin(b));
    Select { a, b }
}
/// Outcome of a single [`poll_step`] pass, telling the caller how to proceed.
#[derive(Debug, Clone, Copy)]
pub enum PollStatus {
    /// The executor has no remaining tasks and the driver reported no events.
    Done,
    /// Another pass should run right away (e.g. a task completed or the driver reported
    /// events).
    Ready,
    /// Tasks remain but none made progress in this pass.
    Idle {
        /// Result of `timers::next_deadline()` at the end of the pass — presumably the time
        /// until the nearest pending timer, `None` if there is none (confirm in `timers`).
        next_deadline: Option<Duration>,
    },
}
/// Runs one scheduling pass of this thread's executor.
///
/// First collects I/O events from the thread-local driver without blocking. Then it visits
/// every task that was queued when the pass began: tasks whose wake flag is set are polled,
/// completed tasks are dropped, and everything else is re-queued at the back. Tasks spawned
/// during the pass are left for the next pass, so a task that keeps spawning cannot starve
/// the rest of the queue.
///
/// # Returns
///
/// * [`PollStatus::Done`]: the queue is empty and the driver reported no events.
/// * [`PollStatus::Ready`]: a task completed, the driver reported events, or some queued
///   task (including one spawned or re-woken during this pass) is already marked woken,
///   so the caller should run another pass without blocking.
/// * [`PollStatus::Idle`]: tasks remain but none is runnable.
///
/// # Errors
///
/// Propagates errors from creating the driver or from its non-blocking poll.
pub fn poll_step() -> io::Result<PollStatus> {
    let had_events = with_driver(|d| d.poll_nonblocking())??;
    let n = tasks_mut(|q| q.len());
    let mut made_progress = false;
    for _ in 0..n {
        let Some(mut task) = tasks_mut(|q| q.pop_front()) else {
            break;
        };
        if task.woken.swap(false, Ordering::AcqRel) {
            let waker = task_waker(Arc::clone(&task.woken));
            let mut cx = Context::from_waker(&waker);
            if task.future.as_mut().poll(&mut cx).is_ready() {
                // Completed: drop the task instead of re-queueing it.
                made_progress = true;
                continue;
            }
        }
        tasks_mut(|q| q.push_back(task));
    }
    // Measure the queue *after* the pass. Tasks spawned while polling are queued too and
    // must keep the executor alive even if every task visited above completed. A task that
    // is already woken (self-wake, or newly spawned) means there is work to do right now,
    // so the caller must not block.
    let (remaining, runnable) = tasks_mut(|q| {
        let runnable = q.iter().any(|t| t.woken.load(Ordering::Acquire));
        (q.len(), runnable)
    });
    TASK_DEPTH.with(|d| d.set(remaining));
    Ok(if remaining == 0 && !had_events {
        PollStatus::Done
    } else if made_progress || had_events || runnable {
        PollStatus::Ready
    } else {
        PollStatus::Idle {
            next_deadline: next_deadline(),
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::rc::Rc;
    use core::cell::Cell;

    /// Runs `poll_step` until the executor reports that all work is finished.
    fn drive_until_done() {
        loop {
            match poll_step().unwrap() {
                PollStatus::Done => break,
                PollStatus::Ready | PollStatus::Idle { .. } => continue,
            }
        }
    }

    /// Future that returns `Pending` exactly once (after waking its own task) and then
    /// `Ready(())`, forcing the awaiting task through an extra scheduler pass.
    struct YieldOnce {
        yielded: bool,
    }

    impl Future for YieldOnce {
        type Output = ();
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.yielded {
                Poll::Ready(())
            } else {
                self.yielded = true;
                // Re-arm our own wake flag so the executor polls us again next pass.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    fn yield_now() -> YieldOnce {
        YieldOnce { yielded: false }
    }

    #[test]
    fn join_handle_resolves_with_task_output() {
        let result: Rc<Cell<u32>> = Rc::new(Cell::new(0));
        let result2 = Rc::clone(&result);
        let handle = spawn(async { 42u32 });
        let _ = spawn(async move {
            result2.set(handle.await);
        });
        drive_until_done();
        assert_eq!(result.get(), 42);
    }

    #[test]
    fn join_handle_waits_for_producer() {
        let result: Rc<Cell<u32>> = Rc::new(Cell::new(0));
        let result2 = Rc::clone(&result);
        // The producer suspends once before finishing, so the consumer's first poll of the
        // handle finds no result, registers its waker, and must be woken on completion.
        let handle = spawn(async {
            yield_now().await;
            99u32
        });
        let _ = spawn(async move {
            result2.set(handle.await);
        });
        drive_until_done();
        assert_eq!(result.get(), 99);
    }

    #[test]
    fn drop_join_handle_task_still_runs() {
        let ran: Rc<Cell<bool>> = Rc::new(Cell::new(false));
        let ran2 = Rc::clone(&ran);
        drop(spawn(async move {
            ran2.set(true);
        }));
        drive_until_done();
        assert!(ran.get(), "task should run even after handle is dropped");
    }

    #[test]
    fn nested_spawn_resolves() {
        let result: Rc<Cell<bool>> = Rc::new(Cell::new(false));
        let result2 = Rc::clone(&result);
        let _ = spawn(async move {
            let handle = spawn(async { true });
            result2.set(handle.await);
        });
        drive_until_done();
        assert!(result.get());
    }

    #[test]
    fn multiple_tasks_all_run() {
        let counter: Rc<Cell<u32>> = Rc::new(Cell::new(0));
        for _ in 0..5 {
            let c = Rc::clone(&counter);
            let _ = spawn(async move {
                c.set(c.get() + 1);
            });
        }
        drive_until_done();
        assert_eq!(counter.get(), 5);
    }

    #[test]
    fn select_left_wins_when_immediately_ready() {
        let winner: Rc<Cell<u8>> = Rc::new(Cell::new(0));
        let w2 = Rc::clone(&winner);
        let _ = spawn(async move {
            let r = select(async { 1u8 }, core::future::pending::<u8>()).await;
            match r {
                Either::Left(v) => w2.set(v),
                Either::Right(_) => w2.set(99),
            }
        });
        drive_until_done();
        assert_eq!(winner.get(), 1);
    }

    #[test]
    fn select_right_wins_when_left_never_resolves() {
        let winner: Rc<Cell<u8>> = Rc::new(Cell::new(0));
        let w2 = Rc::clone(&winner);
        let _ = spawn(async move {
            let r = select(core::future::pending::<u8>(), async { 2u8 }).await;
            match r {
                Either::Left(_) => w2.set(99),
                Either::Right(v) => w2.set(v),
            }
        });
        drive_until_done();
        assert_eq!(winner.get(), 2);
    }

    #[test]
    fn select_left_wins_when_both_immediately_ready() {
        let winner: Rc<Cell<u8>> = Rc::new(Cell::new(0));
        let w2 = Rc::clone(&winner);
        let _ = spawn(async move {
            let r = select(async { 10u8 }, async { 20u8 }).await;
            match r {
                Either::Left(v) | Either::Right(v) => w2.set(v),
            }
        });
        drive_until_done();
        assert_eq!(
            winner.get(),
            10,
            "left is polled first so it wins when both ready"
        );
    }

    #[test]
    fn select_with_join_handle_arm() {
        let result: Rc<Cell<u32>> = Rc::new(Cell::new(0));
        let r2 = Rc::clone(&result);
        let _ = spawn(async move {
            let fast = spawn(async { 7u32 });
            let r = select(fast, core::future::pending::<u32>()).await;
            match r {
                Either::Left(v) => r2.set(v),
                Either::Right(v) => r2.set(v),
            }
        });
        drive_until_done();
        assert_eq!(result.get(), 7);
    }
}