1#[cfg(feature = "std")]
7extern crate std;
8
9use crate::{Executor, LocalExecutor, Task};
10use alloc::boxed::Box;
11use core::{
12 future::Future,
13 pin::Pin,
14 task::{Context, Poll},
15};
16
/// Zero-sized, stateless handle for spawning onto Tokio.
///
/// [`Executor::spawn`] forwards to `tokio::task::spawn` and [`LocalExecutor::spawn_local`]
/// forwards to `tokio::task::spawn_local`, so the value itself carries no data. Every instance
/// is interchangeable with every other: they are `Copy` and always compare equal.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TokioExecutor;
24
25pub use tokio::{runtime::Runtime, task::JoinHandle, task::LocalSet};
26
27impl TokioExecutor {
28 pub fn new() -> Self {
30 Self
31 }
32}
33
34impl Default for TokioExecutor {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40pub struct TokioTask<T> {
45 handle: tokio::task::JoinHandle<T>,
46}
47
48impl<T> core::fmt::Debug for TokioTask<T> {
49 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
50 f.debug_struct("TokioTask").finish_non_exhaustive()
51 }
52}
53
54impl<T: Send + 'static> Future for TokioTask<T> {
55 type Output = T;
56
57 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
58 match Pin::new(&mut self.handle).poll(cx) {
59 Poll::Ready(Ok(result)) => Poll::Ready(result),
60 Poll::Ready(Err(err)) => {
61 if err.is_panic() {
62 std::panic::resume_unwind(err.into_panic());
63 } else {
64 std::panic::panic_any("Task was cancelled")
66 }
67 }
68 Poll::Pending => Poll::Pending,
69 }
70 }
71}
72
73impl<T: Send + 'static> Task<T> for TokioTask<T> {
74 fn poll_result(
75 mut self: Pin<&mut Self>,
76 cx: &mut Context<'_>,
77 ) -> Poll<Result<T, crate::Error>> {
78 match Pin::new(&mut self.handle).poll(cx) {
79 Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
80 Poll::Ready(Err(err)) => {
81 let error: crate::Error = if err.is_panic() {
82 err.into_panic()
83 } else {
84 Box::new("Task was cancelled")
85 };
86 Poll::Ready(Err(error))
87 }
88 Poll::Pending => Poll::Pending,
89 }
90 }
91
92 fn poll_cancel(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
93 let this = unsafe { self.get_unchecked_mut() };
94 this.handle.abort();
95 Poll::Ready(())
96 }
97}
98
99impl Executor for TokioExecutor {
100 type Task<T: Send + 'static> = TokioTask<T>;
101
102 fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
103 where
104 Fut: Future<Output: Send> + Send + 'static,
105 {
106 let handle = tokio::task::spawn(fut);
107 TokioTask { handle }
108 }
109}
110
111pub struct TokioLocalTask<T> {
116 handle: tokio::task::JoinHandle<T>,
117}
118
119impl<T> core::fmt::Debug for TokioLocalTask<T> {
120 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
121 f.debug_struct("TokioLocalTask").finish_non_exhaustive()
122 }
123}
124
125impl<T: 'static> Future for TokioLocalTask<T> {
126 type Output = T;
127
128 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
129 match Pin::new(&mut self.handle).poll(cx) {
130 Poll::Ready(Ok(result)) => Poll::Ready(result),
131 Poll::Ready(Err(err)) => {
132 if err.is_panic() {
133 std::panic::resume_unwind(err.into_panic());
134 } else {
135 std::panic::panic_any("Task was cancelled")
137 }
138 }
139 Poll::Pending => Poll::Pending,
140 }
141 }
142}
143
144impl<T: 'static> Task<T> for TokioLocalTask<T> {
145 fn poll_result(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 ) -> Poll<Result<T, crate::Error>> {
149 match Pin::new(&mut self.handle).poll(cx) {
150 Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
151 Poll::Ready(Err(err)) => {
152 let error: crate::Error = if err.is_panic() {
153 err.into_panic()
154 } else {
155 Box::new("Task was cancelled")
156 };
157 Poll::Ready(Err(error))
158 }
159 Poll::Pending => Poll::Pending,
160 }
161 }
162
163 fn poll_cancel(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
164 let this = unsafe { self.get_unchecked_mut() };
165 this.handle.abort();
166 Poll::Ready(())
167 }
168}
169
170impl LocalExecutor for TokioExecutor {
171 type Task<T: 'static> = TokioLocalTask<T>;
172
173 fn spawn_local<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
174 where
175 Fut: Future + 'static,
176 {
177 let handle = tokio::task::spawn_local(fut);
178 TokioLocalTask { handle }
179 }
180}
181
182impl Executor for tokio::runtime::Runtime {
183 type Task<T: Send + 'static> = TokioTask<T>;
184
185 fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
186 where
187 Fut: Future<Output: Send> + Send + 'static,
188 {
189 let handle = self.spawn(fut);
190 TokioTask { handle }
191 }
192}
193
194impl LocalExecutor for tokio::task::LocalSet {
195 type Task<T: 'static> = TokioLocalTask<T>;
196
197 fn spawn_local<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
198 where
199 Fut: Future + 'static,
200 {
201 let handle = self.spawn_local(fut);
202 TokioLocalTask { handle }
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Executor, LocalExecutor, Task};
    use alloc::task::Wake;
    use alloc::{format, sync::Arc};
    use core::future::Future;
    use core::{
        pin::Pin,
        task::{Context, Poll, Waker},
    };
    use tokio::time::{Duration, sleep};

    /// Waker whose `wake` does nothing; used only to build a `Context` for a single manual poll.
    struct TestWaker;
    impl Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Builds a `Waker` backed by [`TestWaker`].
    fn create_waker() -> Waker {
        Arc::new(TestWaker).into()
    }

    #[tokio::test]
    async fn test_default_executor_spawn() {
        let executor = TokioExecutor::new();
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 42 });
        let result = task.await;
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn test_default_executor_spawn_async_operation() {
        let executor = TokioExecutor::new();
        let task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(10)).await;
            "completed"
        });
        let result = task.await;
        assert_eq!(result, "completed");
    }

    #[tokio::test]
    async fn test_tokio_task_future_impl() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<i32> = Executor::spawn(&executor, async { 100 });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        match Pin::new(&mut task).poll(&mut cx) {
            Poll::Ready(result) => assert_eq!(result, 100),
            Poll::Pending => {
                let result = task.await;
                assert_eq!(result, 100);
            }
        }
    }

    #[tokio::test]
    async fn test_tokio_task_poll_result() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<&str> = Executor::spawn(&executor, async { "success" });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        match Pin::new(&mut task).poll_result(&mut cx) {
            Poll::Ready(Ok(result)) => assert_eq!(result, "success"),
            Poll::Ready(Err(_)) => panic!("Task should not fail"),
            Poll::Pending => {
                let result = task.result().await;
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), "success");
            }
        }
    }

    #[tokio::test]
    async fn test_tokio_task_cancel() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_secs(10)).await;
            "should be cancelled"
        });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        let cancel_result = Pin::new(&mut task).poll_cancel(&mut cx);
        assert_eq!(cancel_result, Poll::Ready(()));

        // The abort must actually take effect. The task now resolves to the cancellation
        // error rather than the value its future would have produced.
        let error = task
            .result()
            .await
            .expect_err("an aborted task must not produce its output");
        assert_eq!(error.downcast_ref::<&str>(), Some(&"Task was cancelled"));
    }

    #[tokio::test]
    #[should_panic(expected = "Task was cancelled")]
    async fn test_tokio_task_await_after_cancel_panics() {
        let executor = TokioExecutor::new();
        let mut task: TokioTask<()> = Executor::spawn(&executor, async {
            sleep(Duration::from_secs(10)).await;
        });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);
        assert_eq!(Pin::new(&mut task).poll_cancel(&mut cx), Poll::Ready(()));

        // `Future::poll` turns a cancellation into a panic.
        task.await;
    }

    #[tokio::test]
    async fn test_tokio_task_panic_handling() {
        let executor = TokioExecutor::new();
        let task: TokioTask<()> = Executor::spawn(&executor, async {
            panic!("test panic");
        });

        let result = task.result().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_default_executor_default() {
        // An executor from `Default` must behave exactly like one from `new()`.
        let executor1 = TokioExecutor::new();
        let executor2 = TokioExecutor::default();

        let task1: TokioTask<i32> = Executor::spawn(&executor1, async { 1 });
        let task2: TokioTask<i32> = Executor::spawn(&executor2, async { 2 });

        assert_eq!(task1.await, 1);
        assert_eq!(task2.await, 2);
    }

    #[test]
    fn test_runtime_executor_impl() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<&str> = Executor::spawn(&rt, async { "runtime task" });
        let result = rt.block_on(task);
        assert_eq!(result, "runtime task");
    }

    #[tokio::test]
    async fn test_local_set_executor() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<&str> =
                    LocalExecutor::spawn_local(&local_set, async { "local task" });
                let result = task.await;
                assert_eq!(result, "local task");
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_future_impl() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<i32> =
                    LocalExecutor::spawn_local(&local_set, async { 200 });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll(&mut cx) {
                    Poll::Ready(result) => assert_eq!(result, 200),
                    Poll::Pending => {
                        let result = task.await;
                        assert_eq!(result, 200);
                    }
                }
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_poll_result() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<&str> =
                    LocalExecutor::spawn_local(&local_set, async { "local success" });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll_result(&mut cx) {
                    Poll::Ready(Ok(result)) => assert_eq!(result, "local success"),
                    Poll::Ready(Err(_)) => panic!("Local task should not fail"),
                    Poll::Pending => {
                        let result = task.result().await;
                        assert!(result.is_ok());
                        assert_eq!(result.unwrap(), "local success");
                    }
                }
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_cancel() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<&str> =
                    LocalExecutor::spawn_local(&local_set, async {
                        sleep(Duration::from_secs(10)).await;
                        "should be cancelled"
                    });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                let cancel_result = Pin::new(&mut task).poll_cancel(&mut cx);
                assert_eq!(cancel_result, Poll::Ready(()));

                // The abort must actually take effect on the local task as well.
                let error = task
                    .result()
                    .await
                    .expect_err("an aborted local task must not produce its output");
                assert_eq!(error.downcast_ref::<&str>(), Some(&"Task was cancelled"));
            })
            .await;
    }

    #[tokio::test]
    async fn test_tokio_local_task_panic_handling() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<()> = LocalExecutor::spawn_local(&local_set, async {
                    panic!("local panic");
                });

                let result = task.result().await;
                assert!(result.is_err());
            })
            .await;
    }

    #[test]
    fn test_tokio_task_debug() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<i32> = Executor::spawn(&rt, async { 42 });
        let debug_str = format!("{:?}", task);
        assert!(debug_str.contains("TokioTask"));
    }

    #[test]
    fn test_tokio_local_task_debug() {
        let local_set = tokio::task::LocalSet::new();
        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(local_set.run_until(async {
            let task: TokioLocalTask<i32> = LocalExecutor::spawn_local(&local_set, async { 42 });
            let debug_str = format!("{:?}", task);
            assert!(debug_str.contains("TokioLocalTask"));
        }));
    }

    #[test]
    fn test_default_executor_debug() {
        let executor = TokioExecutor::new();
        let debug_str = format!("{:?}", executor);
        assert!(debug_str.contains("TokioExecutor"));
    }

    #[tokio::test]
    async fn test_task_result_future() {
        let executor = TokioExecutor::new();
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 123 });

        let result = task.result().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 123);
    }

    #[tokio::test]
    async fn test_task_cancel_future() {
        let executor = TokioExecutor::new();
        let task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_secs(10)).await;
            "cancelled"
        });

        task.cancel().await;
    }

    #[tokio::test]
    async fn test_multiple_tasks_concurrency() {
        let executor = TokioExecutor::new();

        let task1: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(50)).await;
            1
        });

        let task2: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(25)).await;
            2
        });

        let task3: TokioTask<i32> = Executor::spawn(&executor, async { 3 });

        let (r1, r2, r3) = tokio::join!(task1, task2, task3);
        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
        assert_eq!(r3, 3);
    }
}