1#[cfg(feature = "std")]
7extern crate std;
8
9use crate::{Executor, LocalExecutor, Task};
10use alloc::boxed::Box;
11use core::{
12 future::Future,
13 pin::Pin,
14 task::{Context, Poll},
15};
16
/// Zero-sized executor handle backed by Tokio.
///
/// The value carries no state: the [`Executor`] implementation in this file
/// forwards to `tokio::task::spawn` and the [`LocalExecutor`] implementation
/// forwards to `tokio::task::spawn_local`, so the runtime that is used is
/// whichever one is current at the call site.
#[derive(Debug, Clone, Copy)]
pub struct DefaultExecutor;
24
25pub use tokio::{runtime::Runtime, task::JoinHandle, task::LocalSet};
26
27impl DefaultExecutor {
28 pub fn new() -> Self {
30 Self
31 }
32}
33
34impl Default for DefaultExecutor {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40pub struct TokioTask<T> {
45 handle: tokio::task::JoinHandle<T>,
46}
47
48impl<T> core::fmt::Debug for TokioTask<T> {
49 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
50 f.debug_struct("TokioTask").finish_non_exhaustive()
51 }
52}
53
54impl<T: Send + 'static> Future for TokioTask<T> {
55 type Output = T;
56
57 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
58 match Pin::new(&mut self.handle).poll(cx) {
59 Poll::Ready(Ok(result)) => Poll::Ready(result),
60 Poll::Ready(Err(err)) => {
61 if err.is_panic() {
62 std::panic::resume_unwind(err.into_panic());
63 } else {
64 std::panic::panic_any("Task was cancelled")
66 }
67 }
68 Poll::Pending => Poll::Pending,
69 }
70 }
71}
72
73impl<T: Send + 'static> Task<T> for TokioTask<T> {
74 fn poll_result(
75 mut self: Pin<&mut Self>,
76 cx: &mut Context<'_>,
77 ) -> Poll<Result<T, crate::Error>> {
78 match Pin::new(&mut self.handle).poll(cx) {
79 Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
80 Poll::Ready(Err(err)) => {
81 let error: crate::Error = if err.is_panic() {
82 err.into_panic()
83 } else {
84 Box::new("Task was cancelled")
85 };
86 Poll::Ready(Err(error))
87 }
88 Poll::Pending => Poll::Pending,
89 }
90 }
91
92 fn poll_cancel(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
93 let this = unsafe { self.get_unchecked_mut() };
94 this.handle.abort();
95 Poll::Ready(())
96 }
97}
98
99impl Executor for DefaultExecutor {
100 type Task<T: Send + 'static> = TokioTask<T>;
101
102 fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
103 where
104 Fut: Future<Output: Send> + Send + 'static,
105 {
106 let handle = tokio::task::spawn(fut);
107 TokioTask { handle }
108 }
109}
110
111pub struct TokioLocalTask<T> {
116 handle: tokio::task::JoinHandle<T>,
117}
118
119impl<T> core::fmt::Debug for TokioLocalTask<T> {
120 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
121 f.debug_struct("TokioLocalTask").finish_non_exhaustive()
122 }
123}
124
125impl<T: 'static> Future for TokioLocalTask<T> {
126 type Output = T;
127
128 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
129 match Pin::new(&mut self.handle).poll(cx) {
130 Poll::Ready(Ok(result)) => Poll::Ready(result),
131 Poll::Ready(Err(err)) => {
132 if err.is_panic() {
133 std::panic::resume_unwind(err.into_panic());
134 } else {
135 std::panic::panic_any("Task was cancelled")
137 }
138 }
139 Poll::Pending => Poll::Pending,
140 }
141 }
142}
143
144impl<T: 'static> Task<T> for TokioLocalTask<T> {
145 fn poll_result(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 ) -> Poll<Result<T, crate::Error>> {
149 match Pin::new(&mut self.handle).poll(cx) {
150 Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
151 Poll::Ready(Err(err)) => {
152 let error: crate::Error = if err.is_panic() {
153 err.into_panic()
154 } else {
155 Box::new("Task was cancelled")
156 };
157 Poll::Ready(Err(error))
158 }
159 Poll::Pending => Poll::Pending,
160 }
161 }
162
163 fn poll_cancel(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
164 let this = unsafe { self.get_unchecked_mut() };
165 this.handle.abort();
166 Poll::Ready(())
167 }
168}
169
170impl LocalExecutor for DefaultExecutor {
171 type Task<T: 'static> = TokioLocalTask<T>;
172
173 fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
174 where
175 Fut: Future + 'static,
176 {
177 let handle = tokio::task::spawn_local(fut);
178 TokioLocalTask { handle }
179 }
180}
181
182impl Executor for tokio::runtime::Runtime {
183 type Task<T: Send + 'static> = TokioTask<T>;
184
185 fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
186 where
187 Fut: Future<Output: Send> + Send + 'static,
188 {
189 let handle = self.spawn(fut);
190 TokioTask { handle }
191 }
192}
193
194impl LocalExecutor for tokio::task::LocalSet {
195 type Task<T: 'static> = TokioLocalTask<T>;
196
197 fn spawn<Fut>(&self, fut: Fut) -> Self::Task<Fut::Output>
198 where
199 Fut: Future + 'static,
200 {
201 let handle = self.spawn_local(fut);
202 TokioLocalTask { handle }
203 }
204}
205
/// Tests for the Tokio-backed executors and task handles.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Executor, LocalExecutor, Task};
    use alloc::task::Wake;
    use alloc::{format, sync::Arc};
    use core::future::Future;
    use core::{
        pin::Pin,
        task::{Context, Poll, Waker},
    };
    use tokio::time::{Duration, sleep};

    /// Waker that ignores wake-ups; used only to build a `Context` for
    /// polling a task by hand.
    struct TestWaker;

    impl Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Builds a no-op [`Waker`] backed by [`TestWaker`].
    fn create_waker() -> Waker {
        Arc::new(TestWaker).into()
    }

    /// A spawned future's output is returned when the task is awaited.
    #[tokio::test]
    async fn test_default_executor_spawn() {
        let executor = DefaultExecutor::new();
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 42 });
        let result = task.await;
        assert_eq!(result, 42);
    }

    /// A spawned future that suspends before finishing still yields its output.
    #[tokio::test]
    async fn test_default_executor_spawn_async_operation() {
        let executor = DefaultExecutor::new();
        let task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(10)).await;
            "completed"
        });
        let result = task.await;
        assert_eq!(result, "completed");
    }

    /// `Future::poll` can be driven by hand; if the first poll is pending,
    /// awaiting the task afterwards completes it.
    #[tokio::test]
    async fn test_tokio_task_future_impl() {
        let executor = DefaultExecutor::new();
        let mut task: TokioTask<i32> = Executor::spawn(&executor, async { 100 });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        match Pin::new(&mut task).poll(&mut cx) {
            Poll::Ready(result) => assert_eq!(result, 100),
            Poll::Pending => {
                let result = task.await;
                assert_eq!(result, 100);
            }
        }
    }

    /// `poll_result` reports a successful task as `Ok`.
    #[tokio::test]
    async fn test_tokio_task_poll_result() {
        let executor = DefaultExecutor::new();
        let mut task: TokioTask<&str> = Executor::spawn(&executor, async { "success" });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        match Pin::new(&mut task).poll_result(&mut cx) {
            Poll::Ready(Ok(result)) => assert_eq!(result, "success"),
            Poll::Ready(Err(_)) => panic!("Task should not fail"),
            Poll::Pending => {
                let result = task.result().await;
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), "success");
            }
        }
    }

    /// `poll_cancel` completes immediately and the task then reports an error
    /// instead of running to completion.
    #[tokio::test]
    async fn test_tokio_task_cancel() {
        let executor = DefaultExecutor::new();
        let mut task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_secs(10)).await;
            "should be cancelled"
        });

        let waker = create_waker();
        let mut cx = Context::from_waker(&waker);

        let cancel_result = Pin::new(&mut task).poll_cancel(&mut cx);
        assert_eq!(cancel_result, Poll::Ready(()));

        // Verify the abort took effect rather than only checking the return
        // value of `poll_cancel`.
        let outcome = task.result().await;
        assert!(outcome.is_err());
    }

    /// A panicking task surfaces as `Err` through `result()`.
    #[tokio::test]
    async fn test_tokio_task_panic_handling() {
        let executor = DefaultExecutor::new();
        let task: TokioTask<()> = Executor::spawn(&executor, async {
            panic!("test panic");
        });

        let result = task.result().await;
        assert!(result.is_err());
    }

    /// An executor obtained through `Default` behaves like one from `new`.
    #[tokio::test]
    async fn test_default_executor_default() {
        let executor1 = DefaultExecutor::new();
        // Exercise the `Default` impl, which this test is named after.
        let executor2 = DefaultExecutor::default();

        let task1: TokioTask<i32> = Executor::spawn(&executor1, async { 1 });
        let task2: TokioTask<i32> = Executor::spawn(&executor2, async { 2 });

        assert_eq!(task1.await, 1);
        assert_eq!(task2.await, 2);
    }

    /// A `tokio::runtime::Runtime` can be used directly as an `Executor`.
    #[test]
    fn test_runtime_executor_impl() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<&str> = Executor::spawn(&rt, async { "runtime task" });
        let result = rt.block_on(task);
        assert_eq!(result, "runtime task");
    }

    /// A `LocalSet` can be used directly as a `LocalExecutor`.
    #[tokio::test]
    async fn test_local_set_executor() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<&str> =
                    LocalExecutor::spawn(&local_set, async { "local task" });
                let result = task.await;
                assert_eq!(result, "local task");
            })
            .await;
    }

    /// `Future::poll` on a local task can be driven by hand.
    #[tokio::test]
    async fn test_tokio_local_task_future_impl() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<i32> =
                    LocalExecutor::spawn(&local_set, async { 200 });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll(&mut cx) {
                    Poll::Ready(result) => assert_eq!(result, 200),
                    Poll::Pending => {
                        let result = task.await;
                        assert_eq!(result, 200);
                    }
                }
            })
            .await;
    }

    /// `poll_result` on a local task reports success as `Ok`.
    #[tokio::test]
    async fn test_tokio_local_task_poll_result() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<&str> =
                    LocalExecutor::spawn(&local_set, async { "local success" });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                match Pin::new(&mut task).poll_result(&mut cx) {
                    Poll::Ready(Ok(result)) => assert_eq!(result, "local success"),
                    Poll::Ready(Err(_)) => panic!("Local task should not fail"),
                    Poll::Pending => {
                        let result = task.result().await;
                        assert!(result.is_ok());
                        assert_eq!(result.unwrap(), "local success");
                    }
                }
            })
            .await;
    }

    /// `poll_cancel` on a local task completes immediately and the task then
    /// reports an error instead of running to completion.
    #[tokio::test]
    async fn test_tokio_local_task_cancel() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let mut task: TokioLocalTask<&str> = LocalExecutor::spawn(&local_set, async {
                    sleep(Duration::from_secs(10)).await;
                    "should be cancelled"
                });

                let waker = create_waker();
                let mut cx = Context::from_waker(&waker);

                let cancel_result = Pin::new(&mut task).poll_cancel(&mut cx);
                assert_eq!(cancel_result, Poll::Ready(()));

                // Verify the abort took effect rather than only checking the
                // return value of `poll_cancel`.
                let outcome = task.result().await;
                assert!(outcome.is_err());
            })
            .await;
    }

    /// A panicking local task surfaces as `Err` through `result()`.
    #[tokio::test]
    async fn test_tokio_local_task_panic_handling() {
        let local_set = tokio::task::LocalSet::new();

        local_set
            .run_until(async {
                let task: TokioLocalTask<()> = LocalExecutor::spawn(&local_set, async {
                    panic!("local panic");
                });

                let result = task.result().await;
                assert!(result.is_err());
            })
            .await;
    }

    /// `Debug` output of `TokioTask` names the type.
    #[test]
    fn test_tokio_task_debug() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let task: TokioTask<i32> = Executor::spawn(&rt, async { 42 });
        let debug_str = format!("{:?}", task);
        assert!(debug_str.contains("TokioTask"));
    }

    /// `Debug` output of `TokioLocalTask` names the type.
    #[test]
    fn test_tokio_local_task_debug() {
        let local_set = tokio::task::LocalSet::new();
        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(local_set.run_until(async {
            let task: TokioLocalTask<i32> = LocalExecutor::spawn(&local_set, async { 42 });
            let debug_str = format!("{:?}", task);
            assert!(debug_str.contains("TokioLocalTask"));
        }));
    }

    /// `Debug` output of `DefaultExecutor` names the type.
    #[test]
    fn test_default_executor_debug() {
        let executor = DefaultExecutor::new();
        let debug_str = format!("{:?}", executor);
        assert!(debug_str.contains("DefaultExecutor"));
    }

    /// `result()` resolves to `Ok` with the task's output on success.
    #[tokio::test]
    async fn test_task_result_future() {
        let executor = DefaultExecutor::new();
        let task: TokioTask<i32> = Executor::spawn(&executor, async { 123 });

        let result = task.result().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 123);
    }

    /// `cancel()` resolves without waiting for the long-running task body.
    #[tokio::test]
    async fn test_task_cancel_future() {
        let executor = DefaultExecutor::new();
        let task: TokioTask<&str> = Executor::spawn(&executor, async {
            sleep(Duration::from_secs(10)).await;
            "cancelled"
        });

        task.cancel().await;
    }

    /// Several tasks spawned from one executor each yield their own output.
    #[tokio::test]
    async fn test_multiple_tasks_concurrency() {
        let executor = DefaultExecutor::new();

        let task1: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(50)).await;
            1
        });

        let task2: TokioTask<i32> = Executor::spawn(&executor, async {
            sleep(Duration::from_millis(25)).await;
            2
        });

        let task3: TokioTask<i32> = Executor::spawn(&executor, async { 3 });

        let (r1, r2, r3) = tokio::join!(task1, task2, task3);
        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
        assert_eq!(r3, 3);
    }
}