executor_core/
tokio.rs

1//! Integration with the Tokio async runtime.
2//!
3//! This module provides implementations of the [`Executor`] and [`LocalExecutor`] traits
4//! for the Tokio runtime, along with task wrappers that provide panic safety.
5
6#[cfg(feature = "std")]
7extern crate std;
8
9use crate::{Executor, LocalExecutor, Task};
10use alloc::boxed::Box;
11use core::{
12    future::Future,
13    pin::Pin,
14    task::{Context, Poll},
15};
16
/// Stateless executor handle backed by Tokio's free spawning functions.
///
/// Through its [`Executor`] implementation it forwards `Send` futures to
/// `tokio::task::spawn`, and through its [`LocalExecutor`] implementation it
/// forwards non-`Send` futures to `tokio::task::spawn_local`.
#[derive(Debug, Clone, Copy)]
pub struct TokioExecutor;
24
25pub use tokio::{runtime::Runtime, task::JoinHandle, task::LocalSet};
26
27impl TokioExecutor {
28    /// Create a new [`TokioExecutor`].
29    pub fn new() -> Self {
30        Self
31    }
32}
33
34impl Default for TokioExecutor {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40/// Task wrapper for Tokio's `JoinHandle` that implements the [`Task`] trait.
41///
42/// This provides panic safety and proper error handling for tasks spawned
43/// with Tokio's `spawn` function.
44pub struct TokioTask<T> {
45    handle: tokio::task::JoinHandle<T>,
46}
47
48impl<T> core::fmt::Debug for TokioTask<T> {
49    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
50        f.debug_struct("TokioTask").finish_non_exhaustive()
51    }
52}
53
54impl<T: Send + 'static> Future for TokioTask<T> {
55    type Output = T;
56
57    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
58        match Pin::new(&mut self.handle).poll(cx) {
59            Poll::Ready(Ok(result)) => Poll::Ready(result),
60            Poll::Ready(Err(err)) => {
61                if err.is_panic() {
62                    std::panic::resume_unwind(err.into_panic());
63                } else {
64                    // Task was cancelled
65                    std::panic::panic_any("Task was cancelled")
66                }
67            }
68            Poll::Pending => Poll::Pending,
69        }
70    }
71}
72
73impl<T: Send + 'static> Task<T> for TokioTask<T> {
74    fn poll_result(
75        mut self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77    ) -> Poll<Result<T, crate::Error>> {
78        match Pin::new(&mut self.handle).poll(cx) {
79            Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
80            Poll::Ready(Err(err)) => {
81                let error: crate::Error = if err.is_panic() {
82                    err.into_panic()
83                } else {
84                    Box::new("Task was cancelled")
85                };
86                Poll::Ready(Err(error))
87            }
88            Poll::Pending => Poll::Pending,
89        }
90    }
91
92    fn poll_cancel(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
93        let this = unsafe { self.get_unchecked_mut() };
94        this.handle.abort();
95        Poll::Ready(())
96    }
97}
98
99impl Executor for TokioExecutor {
100    type Task<T: Send + 'static> = TokioTask<T>;
101
102    fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
103    where
104        Fut: Future<Output: Send> + Send + 'static,
105    {
106        let handle = tokio::task::spawn(fut);
107        TokioTask { handle }
108    }
109}
110
111/// Task wrapper for Tokio's local `JoinHandle` (non-Send futures).
112///
113/// This provides panic safety and proper error handling for tasks spawned
114/// with Tokio's `spawn_local` function.
115pub struct TokioLocalTask<T> {
116    handle: tokio::task::JoinHandle<T>,
117}
118
119impl<T> core::fmt::Debug for TokioLocalTask<T> {
120    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
121        f.debug_struct("TokioLocalTask").finish_non_exhaustive()
122    }
123}
124
125impl<T: 'static> Future for TokioLocalTask<T> {
126    type Output = T;
127
128    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
129        match Pin::new(&mut self.handle).poll(cx) {
130            Poll::Ready(Ok(result)) => Poll::Ready(result),
131            Poll::Ready(Err(err)) => {
132                if err.is_panic() {
133                    std::panic::resume_unwind(err.into_panic());
134                } else {
135                    // Task was cancelled
136                    std::panic::panic_any("Task was cancelled")
137                }
138            }
139            Poll::Pending => Poll::Pending,
140        }
141    }
142}
143
144impl<T: 'static> Task<T> for TokioLocalTask<T> {
145    fn poll_result(
146        mut self: Pin<&mut Self>,
147        cx: &mut Context<'_>,
148    ) -> Poll<Result<T, crate::Error>> {
149        match Pin::new(&mut self.handle).poll(cx) {
150            Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
151            Poll::Ready(Err(err)) => {
152                let error: crate::Error = if err.is_panic() {
153                    err.into_panic()
154                } else {
155                    Box::new("Task was cancelled")
156                };
157                Poll::Ready(Err(error))
158            }
159            Poll::Pending => Poll::Pending,
160        }
161    }
162
163    fn poll_cancel(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
164        let this = unsafe { self.get_unchecked_mut() };
165        this.handle.abort();
166        Poll::Ready(())
167    }
168}
169
170impl LocalExecutor for TokioExecutor {
171    type Task<T: 'static> = TokioLocalTask<T>;
172
173    fn spawn_local<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
174    where
175        Fut: Future + 'static,
176    {
177        let handle = tokio::task::spawn_local(fut);
178        TokioLocalTask { handle }
179    }
180}
181
182impl Executor for tokio::runtime::Runtime {
183    type Task<T: Send + 'static> = TokioTask<T>;
184
185    fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
186    where
187        Fut: Future<Output: Send> + Send + 'static,
188    {
189        let handle = self.spawn(fut);
190        TokioTask { handle }
191    }
192}
193
194impl LocalExecutor for tokio::task::LocalSet {
195    type Task<T: 'static> = TokioLocalTask<T>;
196
197    fn spawn_local<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
198    where
199        Fut: Future + 'static,
200    {
201        let handle = self.spawn_local(fut);
202        TokioLocalTask { handle }
203    }
204}
205
#[cfg(test)]
mod tests {
    //! Tests for the Tokio executor integration: spawning, polling, cancellation,
    //! panic reporting and `Debug` output for both `Send` and local tasks.

    use super::*;
    use crate::{Executor, LocalExecutor, Task};
    use alloc::task::Wake;
    use alloc::{format, sync::Arc};
    use core::future::Future;
    use core::{
        pin::Pin,
        task::{Context, Poll, Waker},
    };
    use tokio::time::{Duration, sleep};

    /// Waker that ignores wake-ups; used for manual single polls.
    struct TestWaker;
    impl Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Builds a no-op [`Waker`] for constructing a manual polling context.
    fn create_waker() -> Waker {
        Arc::new(TestWaker).into()
    }

    #[tokio::test]
    async fn test_default_executor_spawn() {
        let executor = TokioExecutor::new();
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 42 });
        let result = task.await;
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn test_default_executor_spawn_async_operation() {
        let executor = TokioExecutor::new();
        let task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(10)).await;
            "completed"
        });
        let result = task.await;
        assert_eq!(result, "completed");
    }

    #[tokio::test]
    async fn test_tokio_task_future_impl() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<i32> = Executor::spawn(&executor, async { 100 });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        // The task may or may not have run yet; both outcomes are acceptable.
        match Pin::new(&mut task).poll(&mut cx) {
            Poll::Ready(result) => assert_eq!(result, 100),
            Poll::Pending => {
                let result = task.await;
                assert_eq!(result, 100);
            }
        }
    }

    #[tokio::test]
    async fn test_tokio_task_poll_result() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<&str> = Executor::spawn(&executor, async { "success" });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        match Pin::new(&mut task).poll_result(&mut cx) {
            Poll::Ready(Ok(result)) => assert_eq!(result, "success"),
            Poll::Ready(Err(_)) => panic!("Task should not fail"),
            Poll::Pending => {
                let result = task.result().await;
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), "success");
            }
        }
    }

    #[tokio::test]
    async fn test_tokio_task_cancel() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_secs(10)).await;
            "should be cancelled"
        });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        let cancel_result = Pin::new(&mut task).poll_cancel(&mut cx);
        assert_eq!(cancel_result, Poll::Ready(()));

        // The abort must actually take effect: the task reports an error
        // instead of running to completion.
        let result = task.result().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_tokio_task_panic_handling() {
        let executor = TokioExecutor::new();
        let task: TokioTask<()> = Executor::spawn(&executor, async {
            panic!("test panic");
        });

        let result = task.result().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_default_executor_default() {
        // Exercise both construction paths: `new` and the `Default` impl.
        let executor1 = TokioExecutor::new();
        let executor2 = TokioExecutor::default();

        let task1: TokioTask<i32> = Executor::spawn(&executor1, async { 1 });
        let task2: TokioTask<i32> = Executor::spawn(&executor2, async { 2 });

        assert_eq!(task1.await, 1);
        assert_eq!(task2.await, 2);
    }

    #[test]
    fn test_runtime_executor_impl() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<&str> = Executor::spawn(&rt, async { "runtime task" });
        let result = rt.block_on(task);
        assert_eq!(result, "runtime task");
    }

    #[tokio::test]
    async fn test_local_set_executor() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<&str> =
                    LocalExecutor::spawn_local(&local_set, async { "local task" });
                let result = task.await;
                assert_eq!(result, "local task");
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_future_impl() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<i32> =
                    LocalExecutor::spawn_local(&local_set, async { 200 });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll(&mut cx) {
                    Poll::Ready(result) => assert_eq!(result, 200),
                    Poll::Pending => {
                        let result = task.await;
                        assert_eq!(result, 200);
                    }
                }
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_poll_result() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<&str> =
                    LocalExecutor::spawn_local(&local_set, async { "local success" });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll_result(&mut cx) {
                    Poll::Ready(Ok(result)) => assert_eq!(result, "local success"),
                    Poll::Ready(Err(_)) => panic!("Local task should not fail"),
                    Poll::Pending => {
                        let result = task.result().await;
                        assert!(result.is_ok());
                        assert_eq!(result.unwrap(), "local success");
                    }
                }
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_cancel() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<&str> =
                    LocalExecutor::spawn_local(&local_set, async {
                        sleep(Duration::from_secs(10)).await;
                        "should be cancelled"
                    });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                let cancel_result = Pin::new(&mut task).poll_cancel(&mut cx);
                assert_eq!(cancel_result, Poll::Ready(()));

                // The abort must actually take effect: the task reports an
                // error instead of running to completion.
                let result = task.result().await;
                assert!(result.is_err());
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_panic_handling() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<()> = LocalExecutor::spawn_local(&local_set, async {
                    panic!("local panic");
                });

                let result = task.result().await;
                assert!(result.is_err());
            })
            .await;
    }

    #[test]
    fn test_tokio_task_debug() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<i32> = Executor::spawn(&rt, async { 42 });
        let debug_str = format!("{:?}", task);
        assert!(debug_str.contains("TokioTask"));
    }

    #[test]
    fn test_tokio_local_task_debug() {
        let local_set = tokio::task::LocalSet::new();
        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(local_set.run_until(async {
            let task: TokioLocalTask<i32> = LocalExecutor::spawn_local(&local_set, async { 42 });
            let debug_str = format!("{:?}", task);
            assert!(debug_str.contains("TokioLocalTask"));
        }));
    }

    #[test]
    fn test_default_executor_debug() {
        let executor = TokioExecutor::new();
        let debug_str = format!("{:?}", executor);
        assert!(debug_str.contains("TokioExecutor"));
    }

    #[tokio::test]
    async fn test_task_result_future() {
        let executor = TokioExecutor::new();
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 123 });

        let result = task.result().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 123);
    }

    #[tokio::test]
    async fn test_task_cancel_future() {
        let executor = TokioExecutor::new();
        let task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_secs(10)).await;
            "cancelled"
        });

        task.cancel().await;
    }

    #[tokio::test]
    async fn test_multiple_tasks_concurrency() {
        let executor = TokioExecutor::new();

        let task1: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(50)).await;
            1
        });

        let task2: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(25)).await;
            2
        });

        let task3: TokioTask<i32> = Executor::spawn(&executor, async { 3 });

        let (r1, r2, r3) = tokio::join!(task1, task2, task3);
        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
        assert_eq!(r3, 3);
    }
}