1use crate::{Executor, ExecutorBlocking, InnerJoinHandle, JoinHandle};
2use std::future::Future;
3use std::sync::Arc;
4use tokio::runtime::Runtime;
5
/// Executor that spawns work onto the Tokio runtime that is ambient at the
/// call site, via `tokio::task::spawn` and `tokio::task::spawn_blocking`.
///
/// It is a zero-sized unit struct, so it can be created freely either as
/// `TokioExecutor` or through `TokioExecutor::default()` in generic code.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TokioExecutor;
9
10impl Executor for TokioExecutor {
11 fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
12 where
13 F: Future + Send + 'static,
14 F::Output: Send + 'static,
15 {
16 let handle = tokio::task::spawn(future);
17 let inner = InnerJoinHandle::TokioHandle(handle);
18 JoinHandle { inner }
19 }
20}
21
22impl ExecutorBlocking for TokioExecutor {
23 fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
24 where
25 F: FnOnce() -> R + Send + 'static,
26 R: Send + 'static,
27 {
28 let handle = tokio::task::spawn_blocking(f);
29 let inner = InnerJoinHandle::TokioHandle(handle);
30 JoinHandle { inner }
31 }
32}
33
34#[derive(Clone, Debug)]
36pub struct TokioRuntimeExecutor {
37 runtime: Arc<Runtime>,
38}
39
40impl TokioRuntimeExecutor {
41 pub fn with_single_thread() -> std::io::Result<Self> {
43 let runtime = tokio::runtime::Builder::new_current_thread()
44 .enable_all()
45 .build()?;
46 Ok(Self::with_runtime(runtime))
47 }
48
49 pub fn with_multi_thread() -> std::io::Result<Self> {
51 let runtime = tokio::runtime::Builder::new_multi_thread()
52 .enable_all()
53 .build()?;
54 Ok(Self::with_runtime(runtime))
55 }
56
57 pub fn with_runtime(runtime: Runtime) -> Self {
59 let runtime = Arc::new(runtime);
60 Self { runtime }
61 }
62}
63
64impl Executor for TokioRuntimeExecutor {
65 fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
66 where
67 F: Future + Send + 'static,
68 F::Output: Send + 'static,
69 {
70 let handle = self.runtime.spawn(future);
71 let inner = InnerJoinHandle::TokioHandle(handle);
72 JoinHandle { inner }
73 }
74}
75
76impl ExecutorBlocking for TokioRuntimeExecutor {
77 fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
78 where
79 F: FnOnce() -> R + Send + 'static,
80 R: Send + 'static,
81 {
82 let handle = self.runtime.spawn_blocking(f);
83 let inner = InnerJoinHandle::TokioHandle(handle);
84 JoinHandle { inner }
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::{TokioExecutor, TokioRuntimeExecutor};
    use crate::{Executor, ExecutorBlocking};
    use futures::channel::mpsc::{Receiver, UnboundedReceiver};

    #[tokio::test]
    async fn default_abortable_task() {
        let executor = TokioExecutor;

        // Would only finish after 5s; the test expects dropping the abortable
        // handle to cancel it first, which drops `tx` and makes `rx` error out.
        async fn task(tx: futures::channel::oneshot::Sender<()>) {
            futures_timer::Delay::new(std::time::Duration::from_secs(5)).await;
            let _ = tx.send(());
            unreachable!();
        }

        let (tx, rx) = futures::channel::oneshot::channel::<()>();

        let handle = executor.spawn_abortable(task(tx));

        drop(handle);
        let result = rx.await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn task_coroutine() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        enum Message {
            Send(String, futures::channel::oneshot::Sender<String>),
        }

        let mut task = executor.spawn_coroutine(|mut rx: Receiver<Message>| async move {
            while let Some(msg) = rx.next().await {
                match msg {
                    Message::Send(msg, sender) => {
                        sender.send(msg).unwrap();
                    }
                }
            }
        });

        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Send("Hello".into(), tx);

        task.send(msg).await.unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn task_coroutine_with_context() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(futures::channel::oneshot::Sender<String>),
        }

        let mut task = executor.spawn_coroutine_with_context(
            State::default(),
            |mut state, mut rx: Receiver<Message>| async move {
                while let Some(msg) = rx.next().await {
                    match msg {
                        Message::Set(msg) => {
                            state.message = msg;
                        }
                        Message::Get(resp) => {
                            resp.send(state.message.clone()).unwrap();
                        }
                    }
                }
            },
        );

        let msg = Message::Set("Hello".into());

        task.send(msg).await.unwrap();
        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Get(tx);
        task.send(msg).await.unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn task_unbounded_coroutine() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        enum Message {
            Send(String, futures::channel::oneshot::Sender<String>),
        }

        let mut task =
            executor.spawn_unbounded_coroutine(|mut rx: UnboundedReceiver<Message>| async move {
                while let Some(msg) = rx.next().await {
                    match msg {
                        Message::Send(msg, sender) => {
                            sender.send(msg).unwrap();
                        }
                    }
                }
            });

        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Send("Hello".into(), tx);

        task.send(msg).unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn task_unbounded_coroutine_with_context() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(futures::channel::oneshot::Sender<String>),
        }

        let mut task = executor.spawn_unbounded_coroutine_with_context(
            State::default(),
            |mut state, mut rx: UnboundedReceiver<Message>| async move {
                while let Some(msg) = rx.next().await {
                    match msg {
                        Message::Set(msg) => {
                            state.message = msg;
                        }
                        Message::Get(resp) => {
                            resp.send(state.message.clone()).unwrap();
                        }
                    }
                }
            },
        );

        let msg = Message::Set("Hello".into());

        task.send(msg).unwrap();
        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Get(tx);
        task.send(msg).unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn blocking_task() {
        let executor = TokioExecutor;

        let task = executor.spawn_blocking(|| {
            std::thread::sleep(std::time::Duration::from_millis(100));
            "Hello"
        });
        let resp = task.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    // The `TokioRuntimeExecutor` tests below are plain `#[test]`s rather than
    // `#[tokio::test]`s: the executor owns its runtime and drops it at the end
    // of the test, which per tokio's docs is not allowed inside an async
    // context. The owned runtime is reached through the private `runtime`
    // field (visible here because `tests` is a child module).

    #[test]
    fn runtime_executor_spawn() {
        let executor = TokioRuntimeExecutor::with_multi_thread().unwrap();

        let task = executor.spawn(async { "Hello" });
        let resp = executor.runtime.block_on(task).unwrap();
        assert_eq!(resp, "Hello");
    }

    #[test]
    fn runtime_executor_blocking_task() {
        let executor = TokioRuntimeExecutor::with_multi_thread().unwrap();

        let task = executor.spawn_blocking(|| {
            std::thread::sleep(std::time::Duration::from_millis(10));
            "Hello"
        });
        let resp = executor.runtime.block_on(task).unwrap();
        assert_eq!(resp, "Hello");
    }

    #[test]
    fn single_thread_runtime_executor_spawn() {
        let executor = TokioRuntimeExecutor::with_single_thread().unwrap();

        // A current-thread runtime polls spawned tasks while `block_on` drives it.
        let task = executor.spawn(async { "Hello" });
        let resp = executor.runtime.block_on(task).unwrap();
        assert_eq!(resp, "Hello");
    }
}