async_rt/rt/tokio.rs

1use crate::{Executor, InnerJoinHandle, JoinHandle};
2use std::future::Future;
3use std::sync::Arc;
4use tokio::runtime::Runtime;
5
6/// Tokio executor
7#[derive(Clone, Copy, Debug, PartialOrd, PartialEq, Eq)]
8pub struct TokioExecutor;
9
10impl Executor for TokioExecutor {
11    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
12    where
13        F: Future + Send + 'static,
14        F::Output: Send + 'static,
15    {
16        let handle = tokio::task::spawn(future);
17        let inner = InnerJoinHandle::TokioHandle(handle);
18        JoinHandle { inner }
19    }
20}
21
22/// Tokio executor with an explicit [`Runtime`]
23#[derive(Clone, Debug)]
24pub struct TokioRuntimeExecutor {
25    runtime: Arc<Runtime>,
26}
27
28impl TokioRuntimeExecutor {
29    /// Creates a tokio runtime with the current thread scheduler selected.
30    pub fn with_single_thread() -> std::io::Result<Self> {
31        let runtime = tokio::runtime::Builder::new_current_thread()
32            .enable_all()
33            .build()?;
34        Ok(Self::with_runtime(runtime))
35    }
36
37    /// Creates a tokio runtime with multi-thread scheduler selected.
38    pub fn with_multi_thread() -> std::io::Result<Self> {
39        let runtime = tokio::runtime::Builder::new_multi_thread()
40            .enable_all()
41            .build()?;
42        Ok(Self::with_runtime(runtime))
43    }
44
45    /// Create an executor with the supplied [`Runtime`].
46    pub fn with_runtime(runtime: Runtime) -> Self {
47        let runtime = Arc::new(runtime);
48        Self { runtime }
49    }
50}
51
52impl Executor for TokioRuntimeExecutor {
53    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
54    where
55        F: Future + Send + 'static,
56        F::Output: Send + 'static,
57    {
58        let handle = self.runtime.spawn(future);
59        let inner = InnerJoinHandle::TokioHandle(handle);
60        JoinHandle { inner }
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::{TokioExecutor, TokioRuntimeExecutor};
    use crate::Executor;
    use futures::channel::mpsc::{Receiver, UnboundedReceiver};

    #[tokio::test]
    async fn default_abortable_task() {
        let executor = TokioExecutor;

        // The task sleeps long enough that it can only reach `tx.send` if the
        // abort on drop fails; `unreachable!` makes that failure loud.
        async fn task(tx: futures::channel::oneshot::Sender<()>) {
            futures_timer::Delay::new(std::time::Duration::from_secs(5)).await;
            let _ = tx.send(());
            unreachable!();
        }

        let (tx, rx) = futures::channel::oneshot::channel::<()>();

        let handle = executor.spawn_abortable(task(tx));

        // Dropping the abortable handle should cancel the task, which drops
        // `tx` without sending, so the receiver observes cancellation.
        drop(handle);
        let result = rx.await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn task_coroutine() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        enum Message {
            Send(String, futures::channel::oneshot::Sender<String>),
        }

        let mut task = executor.spawn_coroutine(|mut rx: Receiver<Message>| async move {
            while let Some(msg) = rx.next().await {
                match msg {
                    Message::Send(msg, sender) => {
                        sender.send(msg).unwrap();
                    }
                }
            }
        });

        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Send("Hello".into(), tx);

        task.send(msg).await.unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn task_coroutine_with_context() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(futures::channel::oneshot::Sender<String>),
        }

        let mut task = executor.spawn_coroutine_with_context(
            State::default(),
            |mut state, mut rx: Receiver<Message>| async move {
                while let Some(msg) = rx.next().await {
                    match msg {
                        Message::Set(msg) => {
                            state.message = msg;
                        }
                        Message::Get(resp) => {
                            resp.send(state.message.clone()).unwrap();
                        }
                    }
                }
            },
        );

        let msg = Message::Set("Hello".into());

        task.send(msg).await.unwrap();
        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Get(tx);
        task.send(msg).await.unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn task_unbounded_coroutine() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        enum Message {
            Send(String, futures::channel::oneshot::Sender<String>),
        }

        let mut task =
            executor.spawn_unbounded_coroutine(|mut rx: UnboundedReceiver<Message>| async move {
                while let Some(msg) = rx.next().await {
                    match msg {
                        Message::Send(msg, sender) => {
                            sender.send(msg).unwrap();
                        }
                    }
                }
            });

        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Send("Hello".into(), tx);

        task.send(msg).unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn task_unbounded_coroutine_with_context() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(futures::channel::oneshot::Sender<String>),
        }

        let mut task = executor.spawn_unbounded_coroutine_with_context(
            State::default(),
            |mut state, mut rx: UnboundedReceiver<Message>| async move {
                while let Some(msg) = rx.next().await {
                    match msg {
                        Message::Set(msg) => {
                            state.message = msg;
                        }
                        Message::Get(resp) => {
                            resp.send(state.message.clone()).unwrap();
                        }
                    }
                }
            },
        );

        let msg = Message::Set("Hello".into());

        task.send(msg).unwrap();
        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Get(tx);
        task.send(msg).unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    // A plain `#[test]` (not `#[tokio::test]`): the executor owns its runtime,
    // and Tokio panics if a runtime is dropped from inside an async context.
    #[test]
    fn runtime_executor_spawn() {
        let executor = TokioRuntimeExecutor::with_multi_thread().unwrap();
        let (tx, rx) = std::sync::mpsc::channel::<u32>();

        // Keep the handle alive until the result has been received so the
        // test does not depend on what dropping a `JoinHandle` does.
        let _handle = executor.spawn(async move {
            tx.send(42).unwrap();
        });

        // The multi-thread runtime drives the task on its own worker threads,
        // so a blocking receive with a timeout is enough to observe it.
        let value = rx
            .recv_timeout(std::time::Duration::from_secs(5))
            .expect("task spawned on the owned runtime should run");
        assert_eq!(value, 42);
    }
}