Skip to main content

async_rt/rt/
tokio.rs

1use crate::{Executor, ExecutorBlocking, InnerJoinHandle, JoinHandle};
2use std::future::Future;
3use std::sync::Arc;
4use tokio::runtime::Runtime;
5
/// Zero-sized executor handle backed by Tokio's task-spawning free functions.
///
/// The trait implementations for this type submit work through
/// `tokio::task::spawn` and `tokio::task::spawn_blocking`, so it relies on the
/// caller already running inside a Tokio runtime context.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct TokioExecutor;
9
10impl Executor for TokioExecutor {
11    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
12    where
13        F: Future + Send + 'static,
14        F::Output: Send + 'static,
15    {
16        let handle = tokio::task::spawn(future);
17        let inner = InnerJoinHandle::TokioHandle(handle);
18        JoinHandle { inner }
19    }
20}
21
22impl ExecutorBlocking for TokioExecutor {
23    fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
24    where
25        F: FnOnce() -> R + Send + 'static,
26        R: Send + 'static,
27    {
28        let handle = tokio::task::spawn_blocking(f);
29        let inner = InnerJoinHandle::TokioHandle(handle);
30        JoinHandle { inner }
31    }
32}
33
34/// Tokio executor with an explicit [`Runtime`]
35#[derive(Clone, Debug)]
36pub struct TokioRuntimeExecutor {
37    runtime: Arc<Runtime>,
38}
39
40impl TokioRuntimeExecutor {
41    /// Creates a tokio runtime with the current thread scheduler selected.
42    pub fn with_single_thread() -> std::io::Result<Self> {
43        let runtime = tokio::runtime::Builder::new_current_thread()
44            .enable_all()
45            .build()?;
46        Ok(Self::with_runtime(runtime))
47    }
48
49    /// Creates a tokio runtime with multi-thread scheduler selected.
50    pub fn with_multi_thread() -> std::io::Result<Self> {
51        let runtime = tokio::runtime::Builder::new_multi_thread()
52            .enable_all()
53            .build()?;
54        Ok(Self::with_runtime(runtime))
55    }
56
57    /// Create an executor with the supplied [`Runtime`].
58    pub fn with_runtime(runtime: Runtime) -> Self {
59        let runtime = Arc::new(runtime);
60        Self { runtime }
61    }
62}
63
64impl Executor for TokioRuntimeExecutor {
65    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
66    where
67        F: Future + Send + 'static,
68        F::Output: Send + 'static,
69    {
70        let handle = self.runtime.spawn(future);
71        let inner = InnerJoinHandle::TokioHandle(handle);
72        JoinHandle { inner }
73    }
74}
75
76impl ExecutorBlocking for TokioRuntimeExecutor {
77    fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
78    where
79        F: FnOnce() -> R + Send + 'static,
80        R: Send + 'static,
81    {
82        let handle = self.runtime.spawn_blocking(f);
83        let inner = InnerJoinHandle::TokioHandle(handle);
84        JoinHandle { inner }
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::TokioExecutor;
    use crate::{Executor, ExecutorBlocking};
    use futures::channel::mpsc::{Receiver, UnboundedReceiver};
    use futures::channel::oneshot;
    use futures::stream::StreamExt;
    use std::time::Duration;

    /// Drops the handle from `spawn_abortable` straight away and checks that
    /// the receiver resolves with an error, i.e. the task never reached its
    /// `send`.
    #[tokio::test]
    async fn default_abortable_task() {
        // Sleeps well past the point where the handle is dropped; reaching
        // the end of this function would trip `unreachable!`.
        async fn delayed_notify(done: oneshot::Sender<()>) {
            futures_timer::Delay::new(Duration::from_secs(5)).await;
            let _ = done.send(());
            unreachable!();
        }

        let executor = TokioExecutor;
        let (done_tx, done_rx) = oneshot::channel::<()>();

        let guard = executor.spawn_abortable(delayed_notify(done_tx));
        drop(guard);

        assert!(done_rx.await.is_err());
    }

    /// A bounded coroutine echoes the string it receives back over the
    /// supplied oneshot channel.
    #[tokio::test]
    async fn task_coroutine() {
        enum Message {
            Send(String, oneshot::Sender<String>),
        }

        let executor = TokioExecutor;
        let mut echo = executor.spawn_coroutine(|mut inbox: Receiver<Message>| async move {
            while let Some(Message::Send(text, reply)) = inbox.next().await {
                reply.send(text).unwrap();
            }
        });

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        echo.send(Message::Send("Hello".into(), reply_tx))
            .await
            .unwrap();

        assert_eq!(reply_rx.await.unwrap(), "Hello");
    }

    /// A bounded coroutine with context stores a value via `Set` and returns
    /// it on a later `Get`.
    #[tokio::test]
    async fn task_coroutine_with_context() {
        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(oneshot::Sender<String>),
        }

        let executor = TokioExecutor;
        let mut store = executor.spawn_coroutine_with_context(
            State::default(),
            |mut state, mut inbox: Receiver<Message>| async move {
                while let Some(request) = inbox.next().await {
                    match request {
                        Message::Set(text) => state.message = text,
                        Message::Get(reply) => reply.send(state.message.clone()).unwrap(),
                    }
                }
            },
        );

        store.send(Message::Set("Hello".into())).await.unwrap();

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        store.send(Message::Get(reply_tx)).await.unwrap();

        assert_eq!(reply_rx.await.unwrap(), "Hello");
    }

    /// Same echo scenario as `task_coroutine`, using the unbounded variant
    /// whose `send` is not awaited.
    #[tokio::test]
    async fn task_unbounded_coroutine() {
        enum Message {
            Send(String, oneshot::Sender<String>),
        }

        let executor = TokioExecutor;
        let mut echo =
            executor.spawn_unbounded_coroutine(|mut inbox: UnboundedReceiver<Message>| async move {
                while let Some(Message::Send(text, reply)) = inbox.next().await {
                    reply.send(text).unwrap();
                }
            });

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        echo.send(Message::Send("Hello".into(), reply_tx)).unwrap();

        assert_eq!(reply_rx.await.unwrap(), "Hello");
    }

    /// Same set-then-get scenario as `task_coroutine_with_context`, using the
    /// unbounded variant whose `send` is not awaited.
    #[tokio::test]
    async fn task_unbounded_coroutine_with_context() {
        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(oneshot::Sender<String>),
        }

        let executor = TokioExecutor;
        let mut store = executor.spawn_unbounded_coroutine_with_context(
            State::default(),
            |mut state, mut inbox: UnboundedReceiver<Message>| async move {
                while let Some(request) = inbox.next().await {
                    match request {
                        Message::Set(text) => state.message = text,
                        Message::Get(reply) => reply.send(state.message.clone()).unwrap(),
                    }
                }
            },
        );

        store.send(Message::Set("Hello".into())).unwrap();

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        store.send(Message::Get(reply_tx)).unwrap();

        assert_eq!(reply_rx.await.unwrap(), "Hello");
    }

    /// A closure that sleeps on its thread is run via `spawn_blocking` and
    /// its return value is delivered through the join handle.
    #[tokio::test]
    async fn blocking_task() {
        let executor = TokioExecutor;

        let greeting = executor
            .spawn_blocking(|| {
                std::thread::sleep(Duration::from_millis(100));
                "Hello"
            })
            .await
            .unwrap();

        assert_eq!(greeting, "Hello");
    }
}
261}