executor_core/
tokio.rs

//! Integration with the Tokio async runtime.
//!
//! This module provides implementations of the [`Executor`] and [`LocalExecutor`] traits
//! for Tokio's [`Runtime`] and [`LocalSet`], along with task wrappers that report a
//! spawned task's panic or cancellation to whoever awaits it.
5
6#[cfg(feature = "std")]
7extern crate std;
8
9use crate::{Executor, LocalExecutor, Task};
10use alloc::boxed::Box;
11use core::{
12    future::Future,
13    pin::Pin,
14    task::{Context, Poll},
15};
16
17pub use tokio::{runtime::Runtime, task::JoinHandle, task::LocalSet};
18
19/// Task wrapper for Tokio's `JoinHandle` that implements the [`Task`] trait.
20///
21/// This provides panic safety and proper error handling for tasks spawned
22/// with Tokio's `spawn` function.
23pub struct TokioTask<T> {
24    handle: tokio::task::JoinHandle<T>,
25}
26
27impl<T> core::fmt::Debug for TokioTask<T> {
28    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
29        f.debug_struct("TokioTask").finish_non_exhaustive()
30    }
31}
32
33impl<T: Send + 'static> Future for TokioTask<T> {
34    type Output = T;
35
36    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
37        match Pin::new(&mut self.handle).poll(cx) {
38            Poll::Ready(Ok(result)) => Poll::Ready(result),
39            Poll::Ready(Err(err)) => {
40                if err.is_panic() {
41                    std::panic::resume_unwind(err.into_panic());
42                } else {
43                    // Task was cancelled
44                    std::panic::panic_any("Task was cancelled")
45                }
46            }
47            Poll::Pending => Poll::Pending,
48        }
49    }
50}
51
52impl<T: Send + 'static> Task<T> for TokioTask<T> {
53    fn poll_result(
54        mut self: Pin<&mut Self>,
55        cx: &mut Context<'_>,
56    ) -> Poll<Result<T, crate::Error>> {
57        match Pin::new(&mut self.handle).poll(cx) {
58            Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
59            Poll::Ready(Err(err)) => {
60                let error: crate::Error = if err.is_panic() {
61                    err.into_panic()
62                } else {
63                    Box::new("Task was cancelled")
64                };
65                Poll::Ready(Err(error))
66            }
67            Poll::Pending => Poll::Pending,
68        }
69    }
70}
71
72impl<T> Drop for TokioTask<T> {
73    fn drop(&mut self) {
74        self.handle.abort();
75    }
76}
77
78/// Task wrapper for Tokio's local `JoinHandle` (non-Send futures).
79///
80/// This provides panic safety and proper error handling for tasks spawned
81/// with Tokio's `spawn_local` function.
82pub struct TokioLocalTask<T> {
83    handle: tokio::task::JoinHandle<T>,
84}
85
86impl<T> core::fmt::Debug for TokioLocalTask<T> {
87    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
88        f.debug_struct("TokioLocalTask").finish_non_exhaustive()
89    }
90}
91
92impl<T: 'static> Future for TokioLocalTask<T> {
93    type Output = T;
94
95    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
96        match Pin::new(&mut self.handle).poll(cx) {
97            Poll::Ready(Ok(result)) => Poll::Ready(result),
98            Poll::Ready(Err(err)) => {
99                if err.is_panic() {
100                    std::panic::resume_unwind(err.into_panic());
101                } else {
102                    // Task was cancelled
103                    std::panic::panic_any("Task was cancelled")
104                }
105            }
106            Poll::Pending => Poll::Pending,
107        }
108    }
109}
110
111impl<T: 'static> Task<T> for TokioLocalTask<T> {
112    fn poll_result(
113        mut self: Pin<&mut Self>,
114        cx: &mut Context<'_>,
115    ) -> Poll<Result<T, crate::Error>> {
116        match Pin::new(&mut self.handle).poll(cx) {
117            Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
118            Poll::Ready(Err(err)) => {
119                let error: crate::Error = if err.is_panic() {
120                    err.into_panic()
121                } else {
122                    Box::new("Task was cancelled")
123                };
124                Poll::Ready(Err(error))
125            }
126            Poll::Pending => Poll::Pending,
127        }
128    }
129}
130
131impl Executor for tokio::runtime::Runtime {
132    type Task<T: Send + 'static> = TokioTask<T>;
133
134    fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
135    where
136        Fut: Future<Output: Send> + Send + 'static,
137    {
138        let handle = self.spawn(fut);
139        TokioTask { handle }
140    }
141}
142
143impl LocalExecutor for tokio::task::LocalSet {
144    type Task<T: 'static> = TokioLocalTask<T>;
145
146    fn spawn_local<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
147    where
148        Fut: Future + 'static,
149    {
150        let handle = self.spawn_local(fut);
151        TokioLocalTask { handle }
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Executor, LocalExecutor, Task};
    use alloc::task::Wake;
    use alloc::{format, sync::Arc};
    use core::future::Future;
    use core::sync::atomic::{AtomicBool, Ordering};
    use core::{
        pin::Pin,
        task::{Context, Poll, Waker},
    };
    use tokio::time::{Duration, sleep};

    /// Waker that ignores wake-ups; used to poll tasks by hand.
    struct TestWaker;
    impl Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    fn create_waker() -> Waker {
        Arc::new(TestWaker).into()
    }

    /// Returns the message carried by a panic payload, if it is a string.
    ///
    /// `panic!` with a plain literal produces a `&'static str` payload, formatted
    /// panics produce a `String`, and the wrappers' cancellation error is a boxed
    /// `&'static str`.
    fn panic_message(payload: &dyn core::any::Any) -> Option<&str> {
        payload
            .downcast_ref::<&str>()
            .copied()
            .or_else(|| payload.downcast_ref::<alloc::string::String>().map(|s| s.as_str()))
    }

    #[test]
    fn test_default_executor_spawn() {
        let executor = Runtime::new().expect("Failed to create Tokio runtime");
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 42 });
        let result = executor.block_on(task);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_default_executor_spawn_async_operation() {
        let executor = Runtime::new().expect("Failed to create Tokio runtime");
        let task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(10)).await;
            "completed"
        });
        let result = executor.block_on(task);
        assert_eq!(result, "completed");
    }

    #[test]
    fn test_tokio_task_future_impl() {
        let executor = Runtime::new().expect("Failed to create Tokio runtime");
        let mut task: TokioTask<i32> = Executor::spawn(&executor, async { 100 });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        match Pin::new(&mut task).poll(&mut cx) {
            Poll::Ready(result) => assert_eq!(result, 100),
            Poll::Pending => {
                let result = executor.block_on(task);
                assert_eq!(result, 100);
            }
        }
    }

    #[test]
    fn test_tokio_task_poll_result() {
        let executor = Runtime::new().expect("Failed to create Tokio runtime");
        let mut task: TokioTask<&str> = Executor::spawn(&executor, async { "success" });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        match Pin::new(&mut task).poll_result(&mut cx) {
            Poll::Ready(Ok(result)) => assert_eq!(result, "success"),
            Poll::Ready(Err(_)) => panic!("Task should not fail"),
            Poll::Pending => {
                let result = executor.block_on(task.result());
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), "success");
            }
        }
    }

    #[test]
    fn test_tokio_task_panic_handling() {
        let executor = Runtime::new().expect("Failed to create Tokio runtime");
        let task: TokioTask<()> = Executor::spawn(&executor, async {
            panic!("test panic");
        });

        let error = executor
            .block_on(task.result())
            .expect_err("a panicking task must report an error");
        // The error must carry the task's original panic payload.
        assert_eq!(panic_message(&*error), Some("test panic"));
    }

    #[test]
    fn test_tokio_task_await_resumes_panic() {
        let executor = Runtime::new().expect("Failed to create Tokio runtime");
        let task: TokioTask<()> = Executor::spawn(&executor, async {
            panic!("test panic");
        });

        // Await inside another Tokio task so the resumed panic is caught by Tokio
        // and surfaced as a `JoinError` instead of failing this test thread.
        let outer = executor.spawn(async move { task.await });
        let join_error = executor
            .block_on(outer)
            .expect_err("awaiting a panicked task must panic");
        assert!(join_error.is_panic());
        assert_eq!(panic_message(&*join_error.into_panic()), Some("test panic"));
    }

    #[test]
    fn test_tokio_task_cancelled_by_runtime_shutdown() {
        let runtime = Runtime::new().expect("Failed to create Tokio runtime");
        let task: TokioTask<()> = Executor::spawn(&runtime, core::future::pending::<()>());
        // Shutting the runtime down cancels every task that is still pending.
        drop(runtime);

        let observer = Runtime::new().expect("Failed to create Tokio runtime");
        let error = observer
            .block_on(task.result())
            .expect_err("a cancelled task must report an error");
        assert_eq!(panic_message(&*error), Some("Task was cancelled"));
    }

    #[test]
    fn test_tokio_task_drop_aborts_task() {
        /// Flips the shared flag when the spawned future (which owns it) is dropped.
        struct SetOnDrop(Arc<AtomicBool>);
        impl Drop for SetOnDrop {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }

        // A current-thread runtime only runs tasks inside `block_on`, so the task
        // cannot finish on its own before we drop its handle.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("Failed to create Tokio runtime");
        let dropped = Arc::new(AtomicBool::new(false));
        let guard = SetOnDrop(Arc::clone(&dropped));

        let task: TokioTask<()> = Executor::spawn(&runtime, async move {
            let _guard = guard;
            core::future::pending::<()>().await;
        });
        drop(task);

        // Give the scheduler a chance to observe the abort and drop the future.
        runtime.block_on(async {
            for _ in 0..100 {
                if dropped.load(Ordering::SeqCst) {
                    break;
                }
                tokio::task::yield_now().await;
            }
        });
        // Checked while the runtime is still alive: runtime shutdown would drop the
        // future anyway and mask a missing abort.
        assert!(
            dropped.load(Ordering::SeqCst),
            "dropping a TokioTask must abort the spawned task"
        );
    }

    #[test]
    fn test_default_executor_default() {
        let executor1 = Runtime::new().expect("Failed to create Tokio runtime");
        let executor2 = Runtime::new().expect("Failed to create Tokio runtime");

        let task1: TokioTask<i32> = Executor::spawn(&executor1, async { 1 });
        let task2: TokioTask<i32> = Executor::spawn(&executor2, async { 2 });

        assert_eq!(executor1.block_on(task1), 1);
        assert_eq!(executor2.block_on(task2), 2);
    }

    #[test]
    fn test_runtime_executor_impl() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<&str> = Executor::spawn(&rt, async { "runtime task" });
        let result = rt.block_on(task);
        assert_eq!(result, "runtime task");
    }

    #[tokio::test]
    async fn test_local_set_executor() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<&str> =
                    LocalExecutor::spawn_local(&local_set, async { "local task" });
                let result = task.await;
                assert_eq!(result, "local task");
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_future_impl() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<i32> =
                    LocalExecutor::spawn_local(&local_set, async { 200 });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll(&mut cx) {
                    Poll::Ready(result) => assert_eq!(result, 200),
                    Poll::Pending => {
                        let result = task.await;
                        assert_eq!(result, 200);
                    }
                }
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_poll_result() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<&str> =
                    LocalExecutor::spawn_local(&local_set, async { "local success" });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll_result(&mut cx) {
                    Poll::Ready(Ok(result)) => assert_eq!(result, "local success"),
                    Poll::Ready(Err(_)) => panic!("Local task should not fail"),
                    Poll::Pending => {
                        let result = task.result().await;
                        assert!(result.is_ok());
                        assert_eq!(result.unwrap(), "local success");
                    }
                }
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_panic_handling() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<()> = LocalExecutor::spawn_local(&local_set, async {
                    panic!("local panic");
                });

                let error = task
                    .result()
                    .await
                    .expect_err("a panicking local task must report an error");
                assert_eq!(panic_message(&*error), Some("local panic"));
            })
            .await;
    }

    #[test]
    fn test_tokio_task_debug() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<i32> = Executor::spawn(&rt, async { 42 });
        let debug_str = format!("{:?}", task);
        assert!(debug_str.contains("TokioTask"));
    }

    #[test]
    fn test_tokio_local_task_debug() {
        let local_set = tokio::task::LocalSet::new();
        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(local_set.run_until(async {
            let task: TokioLocalTask<i32> = LocalExecutor::spawn_local(&local_set, async { 42 });
            let debug_str = format!("{:?}", task);
            assert!(debug_str.contains("TokioLocalTask"));
        }));
    }

    #[test]
    fn test_default_executor_debug() {
        let executor = Runtime::new().expect("Failed to create Tokio runtime");
        let debug_str = format!("{:?}", executor);
        assert!(!debug_str.is_empty());
    }

    #[test]
    fn test_task_result_future() {
        let executor = Runtime::new().expect("Failed to create Tokio runtime");
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 123 });

        let result = executor.block_on(task.result());
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 123);
    }

    #[test]
    fn test_multiple_tasks_concurrency() {
        let executor = Runtime::new().expect("Failed to create Tokio runtime");

        let task1: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(50)).await;
            1
        });

        let task2: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(25)).await;
            2
        });

        let task3: TokioTask<i32> = Executor::spawn(&executor, async { 3 });

        let (r1, r2, r3) = executor.block_on(async { tokio::join!(task1, task2, task3) });
        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
        assert_eq!(r3, 3);
    }
}