1#![allow(private_interfaces)]
17
18pub mod raw_task;
19
20use std::future::Future;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
23use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
24
25use crate::scheduler::{RawTask, SchedulerHandle};
26
27pub use crate::scheduler::TaskId;
30
31pub use crate::scheduler::gen_task_id;
34
/// Lifecycle state of a task, stored as its `u8` discriminant inside an `AtomicU8`.
///
/// The explicit discriminants are the encoding decoded by `TaskState::from_u8`;
/// `#[repr(u8)]` pins that representation so `state as u8` always fits the atomic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum TaskState {
    /// Initial state of every newly created task.
    Running = 0,
    /// Parked; a waker moves the task back to `Running` and resubmits it.
    Waiting = 1,
    /// Finished; waiting on the task reads its output.
    Completed = 2,
    /// Cancelled; waiting on the task yields `JoinError::TaskCancelled`.
    Cancelled = 3,
    /// The task panicked; waiting on it yields `JoinError::TaskPanic`.
    Panicked = 4,
}
50
51impl TaskState {
52 fn from_u8(value: u8) -> Option<Self> {
55 match value {
56 0 => Some(Self::Running),
57 1 => Some(Self::Waiting),
58 2 => Some(Self::Completed),
59 3 => Some(Self::Cancelled),
60 4 => Some(Self::Panicked),
61 _ => None,
62 }
63 }
64
65 fn is_finished(self) -> bool {
68 matches!(self, Self::Completed | Self::Cancelled | Self::Panicked)
69 }
70}
71
72#[allow(dead_code)]
75struct TaskInner<T> {
76 id: TaskId,
78 state: AtomicU8,
80 ref_count: AtomicUsize,
82 scheduler: SchedulerHandle,
84 raw_task: AtomicUsize,
86 output: lock::OptionalCell<T>,
88 waiter: futures::task::AtomicWaker,
90}
91
mod lock {
    //! A thread-safe slot that holds a task's output until it is taken.

    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Holds at most one value of type `T`.
    ///
    /// A value stored with `set` can be moved out exactly once with `get`. Afterwards
    /// the slot is empty, so the value is never handed out twice and never dropped twice.
    /// An untaken value is dropped together with the cell.
    ///
    /// `Send`/`Sync` are derived automatically from `Mutex<Option<T>>` and therefore
    /// require `T: Send`, as before.
    pub(super) struct OptionalCell<T> {
        slot: Mutex<Option<T>>,
    }

    impl<T> OptionalCell<T> {
        /// Creates an empty cell.
        #[allow(dead_code)]
        pub(super) fn new() -> Self {
            Self {
                slot: Mutex::new(None),
            }
        }

        /// Stores `value`, dropping any previously stored value that was never taken.
        #[allow(dead_code)]
        pub(super) fn set(&self, value: T) {
            *self.guard() = Some(value);
        }

        /// Moves the stored value out, leaving the cell empty.
        ///
        /// Returns `None` if nothing was stored or the value was already taken.
        ///
        /// # Safety
        ///
        /// This performs no unsafe operation. It is kept `unsafe` only so the signature
        /// stays compatible with existing `unsafe { .. }` call sites.
        #[allow(dead_code)]
        pub(super) unsafe fn get(&self) -> Option<T> {
            // The check and the take happen under one lock, so concurrent readers
            // cannot both observe the value.
            self.guard().take()
        }

        /// Locks the slot, recovering from poisoning: a plain `Option<T>` has no
        /// invariant that a panicking lock holder could have broken.
        fn guard(&self) -> MutexGuard<'_, Option<T>> {
            self.slot.lock().unwrap_or_else(PoisonError::into_inner)
        }
    }
}
149
150#[allow(dead_code)]
156pub struct Task<T> {
157 inner: Arc<TaskInner<T>>,
158}
159
160impl<T> Task<T> {
161 #[allow(dead_code)]
164 fn new<F>(_future: F, id: TaskId, scheduler: SchedulerHandle) -> (Self, RawTask)
165 where
166 F: Future<Output = T> + Send + 'static,
167 T: Send + 'static,
168 {
169 let inner = Arc::new(TaskInner {
170 id,
171 state: AtomicU8::new(TaskState::Running as u8),
172 ref_count: AtomicUsize::new(2), scheduler,
174 raw_task: AtomicUsize::new(0),
175 output: lock::OptionalCell::new(),
176 waiter: futures::task::AtomicWaker::new(),
177 });
178
179 let raw_task = Arc::into_raw(inner.clone()) as RawTask;
180 inner.raw_task.store(raw_task as usize, Ordering::Release);
181
182 let task = Task { inner };
183 (task, raw_task)
184 }
185
186 #[must_use]
189 pub fn id(&self) -> TaskId {
190 self.inner.id
191 }
192
193 #[allow(dead_code)]
196 fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<T> {
197 Poll::Pending
202 }
203}
204
205use std::pin::Pin;
206
207impl<T> Drop for Task<T> {
208 fn drop(&mut self) {
209 self.inner.raw_task.store(0, Ordering::Release);
212 }
213}
214
215#[allow(dead_code)]
221fn task_waker(inner: &Arc<TaskInner<()>>) -> Waker {
222 let cloned = inner.clone();
225 let data = Arc::into_raw(cloned) as *const ();
226
227 unsafe { Waker::from_raw(RawWaker::new(data, &RAW_WAKER_VTABLE)) }
228}
229
230#[allow(dead_code)]
236static RAW_WAKER_VTABLE: RawWakerVTable =
237 RawWakerVTable::new(raw_waker_clone, raw_waker_wake, raw_waker_wake_by_ref, raw_waker_drop);
238
239#[allow(dead_code)]
240unsafe fn raw_waker_clone(data: *const ()) -> RawWaker {
241 let inner = &*(data as *const TaskInner<()>);
244 inner.ref_count.fetch_add(1, Ordering::Relaxed);
245
246 RawWaker::new(data, &RAW_WAKER_VTABLE)
247}
248
249#[allow(dead_code)]
250unsafe fn raw_waker_wake(data: *const ()) {
251 raw_waker_wake_by_ref(data);
252 raw_waker_drop(data);
253}
254
255#[allow(dead_code)]
256unsafe fn raw_waker_wake_by_ref(data: *const ()) {
257 let inner = &*(data as *const TaskInner<()>);
258
259 if inner
262 .state
263 .compare_exchange(
264 TaskState::Waiting as u8,
265 TaskState::Running as u8,
266 Ordering::Release,
267 Ordering::Relaxed,
268 )
269 .is_err()
270 {
271 return; }
273
274 let raw_task = inner.raw_task.load(Ordering::Acquire) as RawTask;
277 if raw_task as usize != 0 {
278 let _ = inner.scheduler.submit(raw_task);
279 }
280}
281
282#[allow(dead_code)]
283unsafe fn raw_waker_drop(data: *const ()) {
284 let inner = &*(data as *const TaskInner<()>);
285
286 if inner.ref_count.fetch_sub(1, Ordering::Release) == 1 {
289 }
294}
295
296pub struct JoinHandle<T> {
302 inner: Option<Arc<TaskInner<T>>>,
303 raw_core: Option<raw_task::TaskRef>,
304}
305
306impl<T> JoinHandle<T> {
307 #[must_use]
310 pub fn id(&self) -> TaskId {
311 if let Some(refs) = &self.raw_core
312 && let Some(core) = refs.core()
313 {
314 return core.id();
315 }
316 self.inner.as_ref().map_or(0, |i| i.id)
317 }
318
319 #[must_use]
322 pub fn is_finished(&self) -> bool {
323 if let Some(refs) = &self.raw_core
324 && let Some(core) = refs.core()
325 {
326 return core.is_completed();
327 }
328 self.inner
329 .as_ref()
330 .and_then(|i| TaskState::from_u8(i.state.load(Ordering::Acquire)))
331 .is_some_and(TaskState::is_finished)
332 }
333
334 pub async fn wait(self) -> Result<T, JoinError> {
337 if let Some(refs) = &self.raw_core
338 && let Some(core) = refs.core()
339 {
340 std::future::poll_fn(|cx| {
341 if core.is_completed() {
342 Poll::Ready(())
343 } else {
344 cx.waker().wake_by_ref();
345 Poll::Pending
346 }
347 })
348 .await;
349 return unsafe { raw_task::read_output::<T>(core) }.ok_or(JoinError::TaskCancelled);
350 }
351 if let Some(inner) = self.inner {
352 return WaitForTask::new(inner).await;
353 }
354 Err(JoinError::TaskCancelled)
355 }
356}
357
358struct WaitForTask<T> {
361 inner: Option<Arc<TaskInner<T>>>,
362}
363
364impl<T> WaitForTask<T> {
365 fn new(inner: Arc<TaskInner<T>>) -> Self {
366 Self { inner: Some(inner) }
367 }
368}
369
370impl<T> Future for WaitForTask<T> {
371 type Output = Result<T, JoinError>;
372
373 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
374 let inner = self.inner.as_ref().unwrap();
375
376 inner.waiter.register(cx.waker());
378
379 let state = TaskState::from_u8(inner.state.load(Ordering::Acquire));
382
383 match state {
384 Some(TaskState::Completed) => {
385 let output = unsafe { inner.output.get() };
388 if let Some(result) = output {
389 self.inner = None;
390 Poll::Ready(Ok(result))
391 } else {
392 Poll::Ready(Err(JoinError::TaskCancelled))
395 }
396 },
397 Some(TaskState::Cancelled) => {
398 self.inner = None;
399 Poll::Ready(Err(JoinError::TaskCancelled))
400 },
401 Some(TaskState::Panicked) => {
402 self.inner = None;
403 Poll::Ready(Err(JoinError::TaskPanic))
404 },
405 Some(TaskState::Running | TaskState::Waiting) => {
406 Poll::Pending
409 },
410 None => Poll::Ready(Err(JoinError::TaskCancelled)),
411 }
412 }
413}
414
415impl<T> Drop for WaitForTask<T> {
416 fn drop(&mut self) {
417 self.inner = None;
420 }
421}
422
/// Reason a task's result could not be obtained from its join handle.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum JoinError {
    /// The task was cancelled, or its output was not available.
    TaskCancelled,
    /// The task panicked.
    TaskPanic,
}
432
433impl std::fmt::Display for JoinError {
434 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435 match self {
436 Self::TaskCancelled => write!(f, "Task was cancelled"),
437 Self::TaskPanic => write!(f, "Task panicked"),
438 }
439 }
440}
441
442impl std::error::Error for JoinError {}
443
444pub fn spawn<F, T>(future: F) -> JoinHandle<T>
473where
474 F: Future<Output = T> + Send + 'static,
475 T: Send + 'static,
476{
477 if let Some(handle) = crate::runtime::Handle::try_current() {
480 let (raw_task, task_ref) = raw_task::allocate_task(future, handle.scheduler().clone());
481
482 let id = task_ref.core().map_or(0, raw_task::TaskCore::id);
483 let _ = handle.scheduler().submit(raw_task);
484
485 return JoinHandle {
486 inner: Some(Arc::new(TaskInner {
487 id,
488 state: AtomicU8::new(TaskState::Running as u8),
489 ref_count: AtomicUsize::new(1),
490 scheduler: handle.scheduler().clone(),
491 raw_task: AtomicUsize::new(0),
492 output: lock::OptionalCell::new(),
493 waiter: futures::task::AtomicWaker::new(),
494 })),
495 raw_core: Some(task_ref),
496 };
497 }
498
499 let id = gen_task_id();
502 let inner = Arc::new(TaskInner {
503 id,
504 state: AtomicU8::new(TaskState::Running as u8),
505 ref_count: AtomicUsize::new(1),
506 scheduler: SchedulerHandle::new_default(),
507 raw_task: AtomicUsize::new(0),
508 output: lock::OptionalCell::new(),
509 waiter: futures::task::AtomicWaker::new(),
510 });
511
512 let inner_clone = inner.clone();
513
514 std::thread::spawn(move || {
515 let mut future = Box::pin(future);
516 let waker = Waker::noop();
517 let mut context = Context::from_waker(waker);
518
519 let result = loop {
520 match Pin::new(&mut future).poll(&mut context) {
521 Poll::Ready(value) => break value,
522 Poll::Pending => {
523 std::thread::sleep(std::time::Duration::from_millis(1));
524 },
525 }
526 };
527
528 inner_clone.output.set(result);
529 inner_clone
530 .state
531 .store(TaskState::Completed as u8, Ordering::Release);
532 inner_clone.waiter.wake();
533 });
534
535 JoinHandle {
536 inner: Some(inner),
537 raw_core: None,
538 }
539}
540
541pub fn block_on<F, T>(future: F) -> T
560where
561 F: Future<Output = T> + Send + 'static,
562 T: Send + 'static,
563{
564 use std::pin::Pin;
565 use std::sync::mpsc;
566 use std::task::{Context, Poll, RawWaker, Waker};
567 use std::{ptr, thread};
568
569 let (sender, receiver) = mpsc::channel();
572
573 let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_RAW_WAKER_VTABLE)) };
576
577 thread::spawn(move || {
580 let mut future = Box::pin(future);
581 let mut cx = Context::from_waker(&waker);
582
583 loop {
586 match Pin::as_mut(&mut future).poll(&mut cx) {
587 Poll::Ready(result) => {
588 let _ = sender.send(result);
591 break;
592 },
593 Poll::Pending => {
594 thread::sleep(std::time::Duration::from_millis(1));
600 },
601 }
602 }
603 });
604
605 receiver
608 .recv()
609 .unwrap_or_else(|_| panic!("block_on: Failed to receive result from executor"))
610}
611
/// Vtable for a waker that does nothing.
///
/// Every entry ignores its data pointer. `clone` returns another no-op waker, while
/// `wake`, `wake_by_ref` and `drop` are empty.
const NOOP_RAW_WAKER_VTABLE: RawWakerVTable = {
    fn clone_noop(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &NOOP_RAW_WAKER_VTABLE)
    }

    fn ignore(_: *const ()) {}

    RawWakerVTable::new(clone_noop, ignore, ignore, ignore)
};
620
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_id_generation() {
        let (earlier, later) = (gen_task_id(), gen_task_id());
        assert!(earlier < later);
    }

    #[test]
    fn test_task_state() {
        let running = TaskState::Running;
        let completed = TaskState::Completed;
        assert_eq!(running as u8, 0);
        assert_eq!(completed as u8, 2);
        assert!(completed.is_finished());
        assert!(!running.is_finished());
    }

    #[test]
    fn test_join_error_display() {
        assert_eq!(JoinError::TaskCancelled.to_string(), "Task was cancelled");
        assert_eq!(JoinError::TaskPanic.to_string(), "Task panicked");
    }

    #[test]
    fn test_join_error_equality() {
        let cancelled = JoinError::TaskCancelled;
        let panicked = JoinError::TaskPanic;
        assert_eq!(cancelled, JoinError::TaskCancelled);
        assert_eq!(panicked, JoinError::TaskPanic);
        assert_ne!(cancelled, panicked);
    }

    #[test]
    fn test_join_error_is_std_error() {
        let cases = [
            (JoinError::TaskCancelled, "Task was cancelled"),
            (JoinError::TaskPanic, "Task panicked"),
        ];
        for (error, expected) in cases {
            let boxed: Box<dyn std::error::Error> = Box::new(error);
            assert_eq!(boxed.to_string(), expected);
        }
    }

    #[test]
    fn test_block_on_free_function() {
        assert_eq!(block_on(async { 42i32 }), 42);
    }

    #[test]
    fn test_block_on_free_function_string() {
        let greeting = block_on(async { String::from("hiver") });
        assert_eq!(greeting, "hiver");
    }

    #[test]
    fn test_block_on_free_function_unit() {
        let () = block_on(async {});
    }

    #[test]
    fn test_block_on_free_function_complex() {
        let sum = block_on(async {
            let (a, b) = (10, 20);
            a + b
        });
        assert_eq!(sum, 30);
    }

    #[test]
    fn test_task_id_uniqueness() {
        use std::collections::HashSet;
        let ids: HashSet<_> = std::iter::repeat_with(gen_task_id).take(100).collect();
        assert_eq!(ids.len(), 100, "all generated task IDs should be unique");
    }

    #[test]
    fn test_task_state_is_finished() {
        for terminal in [TaskState::Completed, TaskState::Cancelled, TaskState::Panicked] {
            assert!(terminal.is_finished());
        }
        for active in [TaskState::Running, TaskState::Waiting] {
            assert!(!active.is_finished());
        }
    }

    #[test]
    fn test_task_state_from_u8_roundtrip() {
        let all_states = [
            TaskState::Running,
            TaskState::Waiting,
            TaskState::Completed,
            TaskState::Cancelled,
            TaskState::Panicked,
        ];
        for state in all_states {
            assert!(TaskState::from_u8(state as u8) == Some(state));
        }
        assert!(TaskState::from_u8(255).is_none());
    }
}