executor_core/
tokio.rs

1//! Integration with the Tokio async runtime.
2//!
3//! This module provides implementations of the [`Executor`] and [`LocalExecutor`] traits
4//! for the Tokio runtime, along with task wrappers that provide panic safety.
5
6#[cfg(feature = "std")]
7extern crate std;
8
9use crate::{Executor, LocalExecutor, Task};
10use alloc::boxed::Box;
11use core::{
12    future::Future,
13    pin::Pin,
14    task::{Context, Poll},
15};
16
/// Stateless, zero-sized executor backed by Tokio.
///
/// As an [`Executor`] it hands `Send` futures to `tokio::task::spawn`. As a
/// [`LocalExecutor`] it hands non-`Send` futures to `tokio::task::spawn_local`. Both calls
/// use the Tokio context that is current when the task is spawned.
#[derive(Debug, Clone, Copy)]
pub struct TokioExecutor;
24
25pub use tokio::{runtime::Runtime, task::JoinHandle, task::LocalSet};
26
27impl TokioExecutor {
28    /// Create a new [`TokioExecutor`].
29    pub fn new() -> Self {
30        Self
31    }
32}
33
34impl Default for TokioExecutor {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40/// Task wrapper for Tokio's `JoinHandle` that implements the [`Task`] trait.
41///
42/// This provides panic safety and proper error handling for tasks spawned
43/// with Tokio's `spawn` function.
44pub struct TokioTask<T> {
45    handle: tokio::task::JoinHandle<T>,
46}
47
48impl<T> core::fmt::Debug for TokioTask<T> {
49    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
50        f.debug_struct("TokioTask").finish_non_exhaustive()
51    }
52}
53
54impl<T: Send + 'static> Future for TokioTask<T> {
55    type Output = T;
56
57    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
58        match Pin::new(&mut self.handle).poll(cx) {
59            Poll::Ready(Ok(result)) => Poll::Ready(result),
60            Poll::Ready(Err(err)) => {
61                if err.is_panic() {
62                    std::panic::resume_unwind(err.into_panic());
63                } else {
64                    // Task was cancelled
65                    std::panic::panic_any("Task was cancelled")
66                }
67            }
68            Poll::Pending => Poll::Pending,
69        }
70    }
71}
72
73impl<T: Send + 'static> Task<T> for TokioTask<T> {
74    fn poll_result(
75        mut self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77    ) -> Poll<Result<T, crate::Error>> {
78        match Pin::new(&mut self.handle).poll(cx) {
79            Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
80            Poll::Ready(Err(err)) => {
81                let error: crate::Error = if err.is_panic() {
82                    err.into_panic()
83                } else {
84                    Box::new("Task was cancelled")
85                };
86                Poll::Ready(Err(error))
87            }
88            Poll::Pending => Poll::Pending,
89        }
90    }
91
92    fn poll_cancel(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
93        let this = unsafe { self.get_unchecked_mut() };
94        this.handle.abort();
95        Poll::Ready(())
96    }
97}
98
99impl Executor for TokioExecutor {
100    type Task<T: Send + 'static> = TokioTask<T>;
101
102    fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
103    where
104        Fut: Future<Output: Send> + Send + 'static,
105    {
106        let handle = tokio::task::spawn(fut);
107        TokioTask { handle }
108    }
109}
110
111/// Task wrapper for Tokio's local `JoinHandle` (non-Send futures).
112///
113/// This provides panic safety and proper error handling for tasks spawned
114/// with Tokio's `spawn_local` function.
115pub struct TokioLocalTask<T> {
116    handle: tokio::task::JoinHandle<T>,
117}
118
119impl<T> core::fmt::Debug for TokioLocalTask<T> {
120    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
121        f.debug_struct("TokioLocalTask").finish_non_exhaustive()
122    }
123}
124
125impl<T: 'static> Future for TokioLocalTask<T> {
126    type Output = T;
127
128    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
129        match Pin::new(&mut self.handle).poll(cx) {
130            Poll::Ready(Ok(result)) => Poll::Ready(result),
131            Poll::Ready(Err(err)) => {
132                if err.is_panic() {
133                    std::panic::resume_unwind(err.into_panic());
134                } else {
135                    // Task was cancelled
136                    std::panic::panic_any("Task was cancelled")
137                }
138            }
139            Poll::Pending => Poll::Pending,
140        }
141    }
142}
143
144impl<T: 'static> Task<T> for TokioLocalTask<T> {
145    fn poll_result(
146        mut self: Pin<&mut Self>,
147        cx: &mut Context<'_>,
148    ) -> Poll<Result<T, crate::Error>> {
149        match Pin::new(&mut self.handle).poll(cx) {
150            Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
151            Poll::Ready(Err(err)) => {
152                let error: crate::Error = if err.is_panic() {
153                    err.into_panic()
154                } else {
155                    Box::new("Task was cancelled")
156                };
157                Poll::Ready(Err(error))
158            }
159            Poll::Pending => Poll::Pending,
160        }
161    }
162
163    fn poll_cancel(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
164        let this = unsafe { self.get_unchecked_mut() };
165        this.handle.abort();
166        Poll::Ready(())
167    }
168}
169
170impl LocalExecutor for TokioExecutor {
171    type Task<T: 'static> = TokioLocalTask<T>;
172
173    fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
174    where
175        Fut: Future + 'static,
176    {
177        let handle = tokio::task::spawn_local(fut);
178        TokioLocalTask { handle }
179    }
180}
181
182impl Executor for tokio::runtime::Runtime {
183    type Task<T: Send + 'static> = TokioTask<T>;
184
185    fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
186    where
187        Fut: Future<Output: Send> + Send + 'static,
188    {
189        let handle = self.spawn(fut);
190        TokioTask { handle }
191    }
192}
193
194impl LocalExecutor for tokio::task::LocalSet {
195    type Task<T: 'static> = TokioLocalTask<T>;
196
197    fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
198    where
199        Fut: Future + 'static,
200    {
201        let handle = self.spawn_local(fut);
202        TokioLocalTask { handle }
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Executor, LocalExecutor, Task};
    use alloc::task::Wake;
    use alloc::{format, sync::Arc};
    use core::future::Future;
    use core::{
        pin::Pin,
        task::{Context, Poll, Waker},
    };
    use tokio::time::{Duration, sleep};

    /// Waker whose wake-ups are ignored; used to poll a task by hand exactly once.
    struct TestWaker;
    impl Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    fn create_waker() -> Waker {
        Arc::new(TestWaker).into()
    }

    #[tokio::test]
    async fn test_default_executor_spawn() {
        let executor = TokioExecutor::new();
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 42 });
        let result = task.await;
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn test_default_executor_spawn_async_operation() {
        let executor = TokioExecutor::new();
        let task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(10)).await;
            "completed"
        });
        let result = task.await;
        assert_eq!(result, "completed");
    }

    #[tokio::test]
    async fn test_tokio_task_future_impl() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<i32> = Executor::spawn(&executor, async { 100 });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        match Pin::new(&mut task).poll(&mut cx) {
            Poll::Ready(result) => assert_eq!(result, 100),
            Poll::Pending => {
                let result = task.await;
                assert_eq!(result, 100);
            }
        }
    }

    #[tokio::test]
    async fn test_tokio_task_poll_result() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<&str> = Executor::spawn(&executor, async { "success" });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        match Pin::new(&mut task).poll_result(&mut cx) {
            Poll::Ready(Ok(result)) => assert_eq!(result, "success"),
            Poll::Ready(Err(_)) => panic!("Task should not fail"),
            Poll::Pending => {
                let result = task.result().await;
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), "success");
            }
        }
    }

    #[tokio::test]
    async fn test_tokio_task_cancel() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_secs(10)).await;
            "should be cancelled"
        });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        let cancel_result = Pin::new(&mut task).poll_cancel(&mut cx);
        assert_eq!(cancel_result, Poll::Ready(()));

        // The abort must actually take effect. The task sleeps for 10s, so it cannot have
        // completed, and its outcome has to be the cancellation error.
        let result = task.result().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_tokio_task_panic_handling() {
        let executor = TokioExecutor::new();
        let task: TokioTask<()> = Executor::spawn(&executor, async {
            panic!("test panic");
        });

        let result = task.result().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_default_executor_default() {
        // An executor obtained through `Default` must behave like one from `new()`.
        let executor1 = TokioExecutor::new();
        let executor2 = TokioExecutor::default();

        let task1: TokioTask<i32> = Executor::spawn(&executor1, async { 1 });
        let task2: TokioTask<i32> = Executor::spawn(&executor2, async { 2 });

        assert_eq!(task1.await, 1);
        assert_eq!(task2.await, 2);
    }

    #[test]
    fn test_runtime_executor_impl() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<&str> = Executor::spawn(&rt, async { "runtime task" });
        let result = rt.block_on(task);
        assert_eq!(result, "runtime task");
    }

    #[tokio::test]
    async fn test_local_set_executor() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<&str> =
                    LocalExecutor::spawn(&local_set, async { "local task" });
                let result = task.await;
                assert_eq!(result, "local task");
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_future_impl() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<i32> =
                    LocalExecutor::spawn(&local_set, async { 200 });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll(&mut cx) {
                    Poll::Ready(result) => assert_eq!(result, 200),
                    Poll::Pending => {
                        let result = task.await;
                        assert_eq!(result, 200);
                    }
                }
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_poll_result() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<&str> =
                    LocalExecutor::spawn(&local_set, async { "local success" });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll_result(&mut cx) {
                    Poll::Ready(Ok(result)) => assert_eq!(result, "local success"),
                    Poll::Ready(Err(_)) => panic!("Local task should not fail"),
                    Poll::Pending => {
                        let result = task.result().await;
                        assert!(result.is_ok());
                        assert_eq!(result.unwrap(), "local success");
                    }
                }
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_cancel() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<&str> = LocalExecutor::spawn(&local_set, async {
                    sleep(Duration::from_secs(10)).await;
                    "should be cancelled"
                });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                let cancel_result = Pin::new(&mut task).poll_cancel(&mut cx);
                assert_eq!(cancel_result, Poll::Ready(()));

                // The aborted 10s task cannot have completed, so its outcome must be an error.
                let result = task.result().await;
                assert!(result.is_err());
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_panic_handling() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<()> = LocalExecutor::spawn(&local_set, async {
                    panic!("local panic");
                });

                let result = task.result().await;
                assert!(result.is_err());
            })
            .await;
    }

    #[test]
    fn test_tokio_task_debug() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<i32> = Executor::spawn(&rt, async { 42 });
        let debug_str = format!("{:?}", task);
        assert!(debug_str.contains("TokioTask"));
    }

    #[test]
    fn test_tokio_local_task_debug() {
        let local_set = tokio::task::LocalSet::new();
        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(local_set.run_until(async {
            let task: TokioLocalTask<i32> = LocalExecutor::spawn(&local_set, async { 42 });
            let debug_str = format!("{:?}", task);
            assert!(debug_str.contains("TokioLocalTask"));
        }));
    }

    #[test]
    fn test_default_executor_debug() {
        let executor = TokioExecutor::new();
        let debug_str = format!("{:?}", executor);
        assert!(debug_str.contains("TokioExecutor"));
    }

    #[tokio::test]
    async fn test_task_result_future() {
        let executor = TokioExecutor::new();
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 123 });

        let result = task.result().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 123);
    }

    #[tokio::test]
    async fn test_task_cancel_future() {
        let executor = TokioExecutor::new();
        let task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_secs(10)).await;
            "cancelled"
        });

        task.cancel().await;
    }

    #[tokio::test]
    async fn test_multiple_tasks_concurrency() {
        let executor = TokioExecutor::new();

        let task1: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(50)).await;
            1
        });

        let task2: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(25)).await;
            2
        });

        let task3: TokioTask<i32> = Executor::spawn(&executor, async { 3 });

        let (r1, r2, r3) = tokio::join!(task1, task2, task3);
        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
        assert_eq!(r3, 3);
    }
}