async_dispatcher/
lib.rs

1pub use async_task::Runnable;
2use futures_lite::FutureExt;
3use std::{
4    error::Error,
5    fmt,
6    future::Future,
7    sync::{mpsc::RecvTimeoutError, OnceLock},
8    task::Poll,
9    time::{Duration, Instant},
10};
11
12pub fn block_on<T>(future: impl Future<Output = T>) -> T {
13    futures_lite::future::block_on(future)
14}
15
16static DISPATCHER: OnceLock<Box<dyn Dispatcher>> = OnceLock::new();
17
18pub trait Dispatcher: 'static + Send + Sync {
19    fn dispatch(&self, runnable: Runnable);
20    fn dispatch_after(&self, duration: Duration, runnable: Runnable);
21}
22
23pub fn set_dispatcher(dispatcher: impl Dispatcher) {
24    DISPATCHER.set(Box::new(dispatcher)).ok();
25}
26
27fn get_dispatcher() -> &'static dyn Dispatcher {
28    DISPATCHER
29        .get()
30        .expect("The dispatcher requires a call to set_dispatcher()")
31        .as_ref()
32}
33
34#[derive(Debug)]
35pub struct JoinHandle<T> {
36    task: Option<async_task::Task<T>>,
37}
38
39pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
40where
41    F: Future + 'static + Send,
42    F::Output: 'static + Send,
43{
44    let dispatcher = get_dispatcher();
45    let (runnable, task) = async_task::spawn(future, |runnable| dispatcher.dispatch(runnable));
46    runnable.schedule();
47    JoinHandle { task: Some(task) }
48}
49
50impl<T> Future for JoinHandle<T> {
51    type Output = T;
52
53    fn poll(
54        mut self: std::pin::Pin<&mut Self>,
55        cx: &mut std::task::Context<'_>,
56    ) -> Poll<Self::Output> {
57        std::pin::Pin::new(
58            self.task
59                .as_mut()
60                .expect("poll should not be called after drop"),
61        )
62        .poll(cx)
63    }
64}
65
66impl<T> Drop for JoinHandle<T> {
67    fn drop(&mut self) {
68        self.task
69            .take()
70            .expect("This is the only place the option is mutated")
71            .detach();
72    }
73}
74
75pub struct Sleep {
76    task: async_task::Task<()>,
77}
78
79pub fn sleep(time: Duration) -> Sleep {
80    let dispatcher = get_dispatcher();
81    let (runnable, task) = async_task::spawn(async {}, move |runnable| {
82        dispatcher.dispatch_after(time, runnable)
83    });
84    runnable.schedule();
85
86    Sleep { task }
87}
88
89impl Sleep {
90    pub fn reset(&mut self, deadline: Instant) {
91        let duration = deadline.saturating_duration_since(Instant::now());
92        self.task = sleep(duration).task
93    }
94}
95
96impl Future for Sleep {
97    type Output = ();
98
99    fn poll(
100        mut self: std::pin::Pin<&mut Self>,
101        cx: &mut std::task::Context<'_>,
102    ) -> Poll<Self::Output> {
103        std::pin::Pin::new(&mut self.task).poll(cx)
104    }
105}
106
107#[derive(Clone, Copy, Debug, Eq, PartialEq)]
108pub struct TimeoutError;
109
110impl Error for TimeoutError {}
111
112pub fn timeout<T>(
113    duration: Duration,
114    future: T,
115) -> impl Future<Output = Result<T::Output, TimeoutError>>
116where
117    T: Future,
118{
119    let future = async move { Ok(future.await) };
120    let timeout = async move {
121        sleep(duration).await;
122        Err(TimeoutError)
123    };
124    future.or(timeout)
125}
126
127impl fmt::Display for TimeoutError {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        "future has timed out".fmt(f)
130    }
131}
132
133pub fn thread_dispatcher() -> impl Dispatcher {
134    struct SimpleDispatcher {
135        tx: std::sync::mpsc::Sender<(Runnable, Option<Instant>)>,
136        _thread: std::thread::JoinHandle<()>,
137    }
138
139    impl Dispatcher for SimpleDispatcher {
140        fn dispatch(&self, runnable: Runnable) {
141            self.tx.send((runnable, None)).ok();
142        }
143
144        fn dispatch_after(&self, duration: Duration, runnable: Runnable) {
145            self.tx
146                .send((runnable, Some(Instant::now() + duration)))
147                .ok();
148        }
149    }
150
151    let (tx, rx) = std::sync::mpsc::channel::<(Runnable, Option<Instant>)>();
152    let _thread = std::thread::spawn(move || {
153        let mut timers = Vec::<(Runnable, Instant)>::new();
154        let mut recv_timeout = Duration::MAX;
155        loop {
156            match rx.recv_timeout(recv_timeout) {
157                Ok((runnable, time)) => {
158                    if let Some(time) = time {
159                        let now = Instant::now();
160                        if time > now {
161                            let ix = match timers.binary_search_by_key(&time, |t| t.1) {
162                                Ok(i) | Err(i) => i,
163                            };
164                            timers.insert(ix, (runnable, time));
165                            recv_timeout = timers.first().unwrap().1 - now;
166                            continue;
167                        }
168                    }
169                    runnable.run();
170                }
171                Err(RecvTimeoutError::Timeout) => {
172                    let now = Instant::now();
173                    while let Some((_, time)) = timers.first() {
174                        if *time > now {
175                            recv_timeout = *time - now;
176                            break;
177                        }
178                        timers.remove(0).0.run();
179                    }
180                }
181                Err(RecvTimeoutError::Disconnected) => break,
182            }
183        }
184    });
185
186    SimpleDispatcher { tx, _thread }
187}
188
189#[cfg(feature = "macros")]
190pub use async_dispatcher_macros::test;
191
192#[cfg(feature = "macros")]
193pub use async_dispatcher_macros::main;