1use crate::{Executor, InnerJoinHandle, JoinHandle};
2use std::future::Future;
3use std::sync::Arc;
4use tokio::runtime::Runtime;
5
/// A zero-sized [`Executor`] that spawns futures onto the Tokio runtime the caller is
/// currently running inside (via [`tokio::task::spawn`]).
///
/// It carries no state, so it is `Copy` and `Default`, and all values compare equal.
///
/// # Panics
///
/// Spawning with this executor panics when called outside of a Tokio runtime context,
/// exactly as [`tokio::task::spawn`] does.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TokioExecutor;
9
10impl Executor for TokioExecutor {
11 fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
12 where
13 F: Future + Send + 'static,
14 F::Output: Send + 'static,
15 {
16 let handle = tokio::task::spawn(future);
17 let inner = InnerJoinHandle::TokioHandle(handle);
18 JoinHandle { inner }
19 }
20}
21
22#[derive(Clone, Debug)]
24pub struct TokioRuntimeExecutor {
25 runtime: Arc<Runtime>,
26}
27
28impl TokioRuntimeExecutor {
29 pub fn with_single_thread() -> std::io::Result<Self> {
31 let runtime = tokio::runtime::Builder::new_current_thread()
32 .enable_all()
33 .build()?;
34 Ok(Self::with_runtime(runtime))
35 }
36
37 pub fn with_multi_thread() -> std::io::Result<Self> {
39 let runtime = tokio::runtime::Builder::new_multi_thread()
40 .enable_all()
41 .build()?;
42 Ok(Self::with_runtime(runtime))
43 }
44
45 pub fn with_runtime(runtime: Runtime) -> Self {
47 let runtime = Arc::new(runtime);
48 Self { runtime }
49 }
50}
51
52impl Executor for TokioRuntimeExecutor {
53 fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
54 where
55 F: Future + Send + 'static,
56 F::Output: Send + 'static,
57 {
58 let handle = self.runtime.spawn(future);
59 let inner = InnerJoinHandle::TokioHandle(handle);
60 JoinHandle { inner }
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::TokioExecutor;
    use crate::Executor;
    use futures::channel::mpsc::{Receiver, UnboundedReceiver};
    use futures::channel::oneshot;
    use futures::stream::StreamExt;
    use std::time::Duration;

    /// Dropping the handle returned by `spawn_abortable` must cancel the task. That drops
    /// the task's oneshot sender, so the receiver resolves to an error.
    #[tokio::test]
    async fn default_abortable_task() {
        async fn never_finishes(done: oneshot::Sender<()>) {
            // The delay gives the abort ample time to win. If it does not, the send below
            // makes the final assertion fail.
            futures_timer::Delay::new(Duration::from_secs(5)).await;
            let _ = done.send(());
            unreachable!();
        }

        let executor = TokioExecutor;
        let (done_tx, done_rx) = oneshot::channel::<()>();

        drop(executor.spawn_abortable(never_finishes(done_tx)));

        assert!(done_rx.await.is_err());
    }

    /// A bounded coroutine echoes each message back on its reply channel.
    #[tokio::test]
    async fn task_coroutine() {
        enum Message {
            Send(String, oneshot::Sender<String>),
        }

        let executor = TokioExecutor;
        let mut echo = executor.spawn_coroutine(|mut inbox: Receiver<Message>| async move {
            while let Some(Message::Send(text, reply)) = inbox.next().await {
                reply.send(text).unwrap();
            }
        });

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        echo.send(Message::Send("Hello".into(), reply_tx))
            .await
            .unwrap();

        assert_eq!(reply_rx.await.unwrap(), "Hello");
    }

    /// A bounded coroutine with context keeps state across messages.
    #[tokio::test]
    async fn task_coroutine_with_context() {
        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(oneshot::Sender<String>),
        }

        let executor = TokioExecutor;
        let mut store = executor.spawn_coroutine_with_context(
            State::default(),
            |mut state, mut inbox: Receiver<Message>| async move {
                while let Some(message) = inbox.next().await {
                    match message {
                        Message::Set(value) => state.message = value,
                        Message::Get(reply) => reply.send(state.message.clone()).unwrap(),
                    }
                }
            },
        );

        store.send(Message::Set("Hello".into())).await.unwrap();

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        store.send(Message::Get(reply_tx)).await.unwrap();

        assert_eq!(reply_rx.await.unwrap(), "Hello");
    }

    /// An unbounded coroutine echoes each message back on its reply channel.
    #[tokio::test]
    async fn task_unbounded_coroutine() {
        enum Message {
            Send(String, oneshot::Sender<String>),
        }

        let executor = TokioExecutor;
        let mut echo =
            executor.spawn_unbounded_coroutine(|mut inbox: UnboundedReceiver<Message>| async move {
                while let Some(Message::Send(text, reply)) = inbox.next().await {
                    reply.send(text).unwrap();
                }
            });

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        echo.send(Message::Send("Hello".into(), reply_tx)).unwrap();

        assert_eq!(reply_rx.await.unwrap(), "Hello");
    }

    /// An unbounded coroutine with context keeps state across messages.
    #[tokio::test]
    async fn task_unbounded_coroutine_with_context() {
        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(oneshot::Sender<String>),
        }

        let executor = TokioExecutor;
        let mut store = executor.spawn_unbounded_coroutine_with_context(
            State::default(),
            |mut state, mut inbox: UnboundedReceiver<Message>| async move {
                while let Some(message) = inbox.next().await {
                    match message {
                        Message::Set(value) => state.message = value,
                        Message::Get(reply) => reply.send(state.message.clone()).unwrap(),
                    }
                }
            },
        );

        store.send(Message::Set("Hello".into())).unwrap();

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        store.send(Message::Get(reply_tx)).unwrap();

        assert_eq!(reply_rx.await.unwrap(), "Hello");
    }
}