1pub use async_task::Runnable;
2use futures_lite::FutureExt;
3use std::{
4 error::Error,
5 fmt,
6 future::Future,
7 sync::{mpsc::RecvTimeoutError, OnceLock},
8 task::Poll,
9 time::{Duration, Instant},
10};
11
12pub fn block_on<T>(future: impl Future<Output = T>) -> T {
13 futures_lite::future::block_on(future)
14}
15
16static DISPATCHER: OnceLock<Box<dyn Dispatcher>> = OnceLock::new();
17
18pub trait Dispatcher: 'static + Send + Sync {
19 fn dispatch(&self, runnable: Runnable);
20 fn dispatch_after(&self, duration: Duration, runnable: Runnable);
21}
22
23pub fn set_dispatcher(dispatcher: impl Dispatcher) {
24 DISPATCHER.set(Box::new(dispatcher)).ok();
25}
26
27fn get_dispatcher() -> &'static dyn Dispatcher {
28 DISPATCHER
29 .get()
30 .expect("The dispatcher requires a call to set_dispatcher()")
31 .as_ref()
32}
33
34#[derive(Debug)]
35pub struct JoinHandle<T> {
36 task: Option<async_task::Task<T>>,
37}
38
39pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
40where
41 F: Future + 'static + Send,
42 F::Output: 'static + Send,
43{
44 let dispatcher = get_dispatcher();
45 let (runnable, task) = async_task::spawn(future, |runnable| dispatcher.dispatch(runnable));
46 runnable.schedule();
47 JoinHandle { task: Some(task) }
48}
49
50impl<T> Future for JoinHandle<T> {
51 type Output = T;
52
53 fn poll(
54 mut self: std::pin::Pin<&mut Self>,
55 cx: &mut std::task::Context<'_>,
56 ) -> Poll<Self::Output> {
57 std::pin::Pin::new(
58 self.task
59 .as_mut()
60 .expect("poll should not be called after drop"),
61 )
62 .poll(cx)
63 }
64}
65
66impl<T> Drop for JoinHandle<T> {
67 fn drop(&mut self) {
68 self.task
69 .take()
70 .expect("This is the only place the option is mutated")
71 .detach();
72 }
73}
74
75pub struct Sleep {
76 task: async_task::Task<()>,
77}
78
79pub fn sleep(time: Duration) -> Sleep {
80 let dispatcher = get_dispatcher();
81 let (runnable, task) = async_task::spawn(async {}, move |runnable| {
82 dispatcher.dispatch_after(time, runnable)
83 });
84 runnable.schedule();
85
86 Sleep { task }
87}
88
89impl Sleep {
90 pub fn reset(&mut self, deadline: Instant) {
91 let duration = deadline.saturating_duration_since(Instant::now());
92 self.task = sleep(duration).task
93 }
94}
95
96impl Future for Sleep {
97 type Output = ();
98
99 fn poll(
100 mut self: std::pin::Pin<&mut Self>,
101 cx: &mut std::task::Context<'_>,
102 ) -> Poll<Self::Output> {
103 std::pin::Pin::new(&mut self.task).poll(cx)
104 }
105}
106
107#[derive(Clone, Copy, Debug, Eq, PartialEq)]
108pub struct TimeoutError;
109
110impl Error for TimeoutError {}
111
112pub fn timeout<T>(
113 duration: Duration,
114 future: T,
115) -> impl Future<Output = Result<T::Output, TimeoutError>>
116where
117 T: Future,
118{
119 let future = async move { Ok(future.await) };
120 let timeout = async move {
121 sleep(duration).await;
122 Err(TimeoutError)
123 };
124 future.or(timeout)
125}
126
127impl fmt::Display for TimeoutError {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 "future has timed out".fmt(f)
130 }
131}
132
133pub fn thread_dispatcher() -> impl Dispatcher {
134 struct SimpleDispatcher {
135 tx: std::sync::mpsc::Sender<(Runnable, Option<Instant>)>,
136 _thread: std::thread::JoinHandle<()>,
137 }
138
139 impl Dispatcher for SimpleDispatcher {
140 fn dispatch(&self, runnable: Runnable) {
141 self.tx.send((runnable, None)).ok();
142 }
143
144 fn dispatch_after(&self, duration: Duration, runnable: Runnable) {
145 self.tx
146 .send((runnable, Some(Instant::now() + duration)))
147 .ok();
148 }
149 }
150
151 let (tx, rx) = std::sync::mpsc::channel::<(Runnable, Option<Instant>)>();
152 let _thread = std::thread::spawn(move || {
153 let mut timers = Vec::<(Runnable, Instant)>::new();
154 let mut recv_timeout = Duration::MAX;
155 loop {
156 match rx.recv_timeout(recv_timeout) {
157 Ok((runnable, time)) => {
158 if let Some(time) = time {
159 let now = Instant::now();
160 if time > now {
161 let ix = match timers.binary_search_by_key(&time, |t| t.1) {
162 Ok(i) | Err(i) => i,
163 };
164 timers.insert(ix, (runnable, time));
165 recv_timeout = timers.first().unwrap().1 - now;
166 continue;
167 }
168 }
169 runnable.run();
170 }
171 Err(RecvTimeoutError::Timeout) => {
172 let now = Instant::now();
173 while let Some((_, time)) = timers.first() {
174 if *time > now {
175 recv_timeout = *time - now;
176 break;
177 }
178 timers.remove(0).0.run();
179 }
180 }
181 Err(RecvTimeoutError::Disconnected) => break,
182 }
183 }
184 });
185
186 SimpleDispatcher { tx, _thread }
187}
188
189#[cfg(feature = "macros")]
190pub use async_dispatcher_macros::test;
191
192#[cfg(feature = "macros")]
193pub use async_dispatcher_macros::main;