Skip to main content

hiver_runtime/
task.rs

1//! Task management module
2//! 任务管理模块
3//!
4//! # Overview / 概述
5//!
6//! This module provides task spawning and management with support for:
//! - Task lifecycle tracking (Running, Waiting, Completed, Cancelled, Panicked)
8//! - Wake-up notifications for async polling
9//! - Join handles for awaiting task completion
10//!
11//! 本模块提供任务生成和管理,支持:
//! - 任务生命周期跟踪(运行中、等待中、已完成、已取消、发生panic)
13//! - 异步轮询的唤醒通知
14//! - 等待任务完成的join句柄
15
16#![allow(private_interfaces)]
17
18pub mod raw_task;
19
20use std::future::Future;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
23use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
24
25use crate::scheduler::{RawTask, SchedulerHandle};
26
27/// Task ID type
28/// 任务ID类型
29pub use crate::scheduler::TaskId;
30
31/// Generate a new unique task ID
32/// 生成新的唯一任务ID
33pub use crate::scheduler::gen_task_id;
34
/// Lifecycle state of a task, stored as its `u8` discriminant inside an `AtomicU8`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum TaskState {
    /// The task is currently running.
    Running = 0,
    /// The task is parked, waiting for a wake-up.
    Waiting = 1,
    /// The task finished successfully.
    Completed = 2,
    /// The task was cancelled.
    Cancelled = 3,
    /// The task panicked.
    Panicked = 4,
}

impl TaskState {
    /// Every state, positioned at the index equal to its discriminant.
    const ALL: [Self; 5] = [
        Self::Running,
        Self::Waiting,
        Self::Completed,
        Self::Cancelled,
        Self::Panicked,
    ];

    /// Decodes a raw discriminant; values outside `0..=4` yield `None`.
    fn from_u8(value: u8) -> Option<Self> {
        Self::ALL.get(usize::from(value)).copied()
    }

    /// Returns `true` for the terminal states `Completed`, `Cancelled` and `Panicked`.
    fn is_finished(self) -> bool {
        !matches!(self, Self::Running | Self::Waiting)
    }
}
71
/// State shared between a task, its wakers, and its join handle.
///
/// Always held behind an `Arc`; the `Arc`'s strong count governs when this
/// allocation is freed.
#[allow(dead_code)]
struct TaskInner<T> {
    /// Task identifier.
    id: TaskId,
    /// Current `TaskState`, stored as its `u8` discriminant.
    state: AtomicU8,
    /// Reference counter updated by the raw waker functions.
    /// NOTE(review): nothing in this file reads it to decide when to free
    /// memory — it appears to be bookkeeping only; confirm before relying on it.
    ref_count: AtomicUsize,
    /// Scheduler handle used to re-submit the task when it is woken.
    scheduler: SchedulerHandle,
    /// Raw task pointer stored as a `usize`; 0 means no pointer is published.
    raw_task: AtomicUsize,
    /// Task output, written once the task has completed.
    output: lock::OptionalCell<T>,
    /// Waker registered by a future that is waiting for this task to finish.
    waiter: futures::task::AtomicWaker,
}
91
/// Thread-safe, take-once cell holding a task's output.
///
/// Backed by a `Mutex<Option<T>>`, so the `Option` itself tracks ownership:
/// reading moves the value out and leaves `None` behind, and any value still
/// stored is dropped exactly once when the cell is dropped. The earlier
/// `MaybeUninit` + flag design dropped a value twice after it had been read.
mod lock {
    use std::sync::{Mutex, MutexGuard, PoisonError};

    pub(super) struct OptionalCell<T> {
        /// The stored output, if any. `Mutex<Option<T>>` is `Send + Sync`
        /// whenever `T: Send`, so no manual `unsafe impl` is required.
        slot: Mutex<Option<T>>,
    }

    impl<T> OptionalCell<T> {
        /// Creates an empty cell.
        #[allow(dead_code)]
        pub(super) fn new() -> Self {
            Self {
                slot: Mutex::new(None),
            }
        }

        /// Stores `value`, dropping any value that was stored before it.
        #[allow(dead_code)]
        pub(super) fn set(&self, value: T) {
            *self.guard() = Some(value);
        }

        /// Moves the stored value out of the cell.
        ///
        /// Returns `None` if nothing has been stored or the value was already
        /// taken, so each output can be observed at most once.
        ///
        /// # Safety
        ///
        /// There are no safety preconditions. The function stays `unsafe` only so
        /// that existing `unsafe { cell.get() }` call sites compile unchanged.
        #[allow(dead_code)]
        pub(super) unsafe fn get(&self) -> Option<T> {
            self.guard().take()
        }

        /// Locks the slot, recovering the data if a previous holder panicked.
        fn guard(&self) -> MutexGuard<'_, Option<T>> {
            // The critical sections only assign or `take()` an `Option`, so a
            // poisoned lock cannot leave the slot half-written.
            self.slot.lock().unwrap_or_else(PoisonError::into_inner)
        }
    }
}
149
150/// A spawned task
151/// 生成的任务
152///
153/// Wraps a future and manages its execution lifecycle.
154/// 包装一个future并管理其执行生命周期。
155#[allow(dead_code)]
156pub struct Task<T> {
157    inner: Arc<TaskInner<T>>,
158}
159
160impl<T> Task<T> {
161    /// Create a new task
162    /// 创建新任务
163    #[allow(dead_code)]
164    fn new<F>(_future: F, id: TaskId, scheduler: SchedulerHandle) -> (Self, RawTask)
165    where
166        F: Future<Output = T> + Send + 'static,
167        T: Send + 'static,
168    {
169        let inner = Arc::new(TaskInner {
170            id,
171            state: AtomicU8::new(TaskState::Running as u8),
172            ref_count: AtomicUsize::new(2), // Task + waker
173            scheduler,
174            raw_task: AtomicUsize::new(0),
175            output: lock::OptionalCell::new(),
176            waiter: futures::task::AtomicWaker::new(),
177        });
178
179        let raw_task = Arc::into_raw(inner.clone()) as RawTask;
180        inner.raw_task.store(raw_task as usize, Ordering::Release);
181
182        let task = Task { inner };
183        (task, raw_task)
184    }
185
186    /// Get the task ID
187    /// 获取任务ID
188    #[must_use]
189    pub fn id(&self) -> TaskId {
190        self.inner.id
191    }
192
193    /// Poll the task future
194    /// 轮询任务future
195    #[allow(dead_code)]
196    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<T> {
197        // This would be called by the executor
198        // For now, we'll use a simpler approach
199        // 这将由执行器调用
200        // 目前我们使用更简单的方法
201        Poll::Pending
202    }
203}
204
205use std::pin::Pin;
206
207impl<T> Drop for Task<T> {
208    fn drop(&mut self) {
209        // Clear the raw_task pointer to prevent use-after-free
210        // 清除raw_task指针以防止use-after-free
211        self.inner.raw_task.store(0, Ordering::Release);
212    }
213}
214
215/// Custom waker for task wake-up notifications
216/// 用于任务唤醒通知的自定义waker
217///
218/// Uses the vtable pattern for raw waker implementation.
219/// 使用vtable模式实现原始waker。
220#[allow(dead_code)]
221fn task_waker(inner: &Arc<TaskInner<()>>) -> Waker {
222    // Clone and convert to raw pointer
223    // 克隆并转换为原始指针
224    let cloned = inner.clone();
225    let data = Arc::into_raw(cloned) as *const ();
226
227    unsafe { Waker::from_raw(RawWaker::new(data, &RAW_WAKER_VTABLE)) }
228}
229
230/// VTable for the task waker
231/// 任务waker的VTable
232///
233/// Provides functions for cloning, waking, and dropping the waker.
234/// 提供克隆、唤醒和删除waker的函数。
235#[allow(dead_code)]
236static RAW_WAKER_VTABLE: RawWakerVTable =
237    RawWakerVTable::new(raw_waker_clone, raw_waker_wake, raw_waker_wake_by_ref, raw_waker_drop);
238
239#[allow(dead_code)]
240unsafe fn raw_waker_clone(data: *const ()) -> RawWaker {
241    // Increment reference count
242    // 增加引用计数
243    let inner = &*(data as *const TaskInner<()>);
244    inner.ref_count.fetch_add(1, Ordering::Relaxed);
245
246    RawWaker::new(data, &RAW_WAKER_VTABLE)
247}
248
249#[allow(dead_code)]
250unsafe fn raw_waker_wake(data: *const ()) {
251    raw_waker_wake_by_ref(data);
252    raw_waker_drop(data);
253}
254
255#[allow(dead_code)]
256unsafe fn raw_waker_wake_by_ref(data: *const ()) {
257    let inner = &*(data as *const TaskInner<()>);
258
259    // Try to transition from Waiting to Running
260    // 尝试从Waiting转换到Running
261    if inner
262        .state
263        .compare_exchange(
264            TaskState::Waiting as u8,
265            TaskState::Running as u8,
266            Ordering::Release,
267            Ordering::Relaxed,
268        )
269        .is_err()
270    {
271        return; // Not in waiting state
272    }
273
274    // Re-schedule the task
275    // 重新调度任务
276    let raw_task = inner.raw_task.load(Ordering::Acquire) as RawTask;
277    if raw_task as usize != 0 {
278        let _ = inner.scheduler.submit(raw_task);
279    }
280}
281
282#[allow(dead_code)]
283unsafe fn raw_waker_drop(data: *const ()) {
284    let inner = &*(data as *const TaskInner<()>);
285
286    // Decrement reference count
287    // 减少引用计数
288    if inner.ref_count.fetch_sub(1, Ordering::Release) == 1 {
289        // Last reference, deallocate
290        // 最后一个引用,释放内存
291        // Note: This is handled by Arc, we don't need explicit deallocation
292        // 注意:这由Arc处理,我们不需要显式释放
293    }
294}
295
/// Join handle for a spawned task.
///
/// Lets the caller query the task and await its result. Handles created on the
/// scheduler path of `spawn` carry a `raw_core` reference (plus a placeholder
/// `inner` whose state this file never updates). Handles created by the
/// thread-per-task fallback carry only `inner`.
pub struct JoinHandle<T> {
    inner: Option<Arc<TaskInner<T>>>,
    raw_core: Option<raw_task::TaskRef>,
}

impl<T> JoinHandle<T> {
    /// Returns the task ID.
    ///
    /// Prefers the ID reported by the scheduler core, falls back to `inner`,
    /// and returns 0 when neither is available.
    #[must_use]
    pub fn id(&self) -> TaskId {
        if let Some(refs) = &self.raw_core
            && let Some(core) = refs.core()
        {
            return core.id();
        }
        self.inner.as_ref().map_or(0, |i| i.id)
    }

    /// Returns whether the task has finished (completed, cancelled, or panicked).
    ///
    /// Uses `core.is_completed()` when a scheduler core is available; otherwise
    /// decodes `inner.state`. An unknown state byte counts as not finished.
    #[must_use]
    pub fn is_finished(&self) -> bool {
        if let Some(refs) = &self.raw_core
            && let Some(core) = refs.core()
        {
            return core.is_completed();
        }
        self.inner
            .as_ref()
            .and_then(|i| TaskState::from_u8(i.state.load(Ordering::Acquire)))
            .is_some_and(TaskState::is_finished)
    }

    /// Waits for the task to complete and returns its result.
    ///
    /// # Errors
    ///
    /// - `JoinError::TaskCancelled` if the task was cancelled, or if no output
    ///   is available once it is reported complete.
    /// - `JoinError::TaskPanic` if a fallback-spawned task is recorded as panicked.
    pub async fn wait(self) -> Result<T, JoinError> {
        if let Some(refs) = &self.raw_core
            && let Some(core) = refs.core()
        {
            // Re-wake on every `Pending` so the executor keeps polling until the
            // core reports completion.
            // NOTE(review): this busy-polls the executor for the whole lifetime
            // of the task; consider registering a real waker once `raw_task`
            // exposes one.
            std::future::poll_fn(|cx| {
                if core.is_completed() {
                    Poll::Ready(())
                } else {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            })
            .await;
            // NOTE(review): `read_output`'s safety contract is defined in
            // `raw_task`; presumably the completion check above satisfies it — confirm.
            return unsafe { raw_task::read_output::<T>(core) }.ok_or(JoinError::TaskCancelled);
        }
        // Fallback path: wait on the shared state that the spawning thread updates.
        if let Some(inner) = self.inner {
            return WaitForTask::new(inner).await;
        }
        Err(JoinError::TaskCancelled)
    }
}
357
358/// Future for waiting on task completion
359/// 等待任务完成的future
360struct WaitForTask<T> {
361    inner: Option<Arc<TaskInner<T>>>,
362}
363
364impl<T> WaitForTask<T> {
365    fn new(inner: Arc<TaskInner<T>>) -> Self {
366        Self { inner: Some(inner) }
367    }
368}
369
370impl<T> Future for WaitForTask<T> {
371    type Output = Result<T, JoinError>;
372
373    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
374        let inner = self.inner.as_ref().unwrap();
375
376        // Register waker so the completing task can wake us
377        inner.waiter.register(cx.waker());
378
379        // Check current state
380        // 检查当前状态
381        let state = TaskState::from_u8(inner.state.load(Ordering::Acquire));
382
383        match state {
384            Some(TaskState::Completed) => {
385                // Get the output
386                // 获取输出
387                let output = unsafe { inner.output.get() };
388                if let Some(result) = output {
389                    self.inner = None;
390                    Poll::Ready(Ok(result))
391                } else {
392                    // Should not happen
393                    // 不应该发生
394                    Poll::Ready(Err(JoinError::TaskCancelled))
395                }
396            },
397            Some(TaskState::Cancelled) => {
398                self.inner = None;
399                Poll::Ready(Err(JoinError::TaskCancelled))
400            },
401            Some(TaskState::Panicked) => {
402                self.inner = None;
403                Poll::Ready(Err(JoinError::TaskPanic))
404            },
405            Some(TaskState::Running | TaskState::Waiting) => {
406                // Task still running, park this future
407                // 任务仍在运行,暂停此future
408                Poll::Pending
409            },
410            None => Poll::Ready(Err(JoinError::TaskCancelled)),
411        }
412    }
413}
414
415impl<T> Drop for WaitForTask<T> {
416    fn drop(&mut self) {
417        // Clear inner to prevent holding reference
418        // 清除inner以防止持有引用
419        self.inner = None;
420    }
421}
422
/// Reason why joining a task did not yield its output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinError {
    /// The task was cancelled, or no output was available.
    TaskCancelled,
    /// The task panicked.
    TaskPanic,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::TaskCancelled => "Task was cancelled",
            Self::TaskPanic => "Task panicked",
        };
        f.write_str(text)
    }
}

impl std::error::Error for JoinError {}
443
444/// Spawn a new async task
445/// 生成新的异步任务
446///
447/// # Panics / 恐慌
448///
449/// Panics if called outside of a runtime context.
450/// 如果在运行时上下文之外调用则恐慌。
451///
452/// # Example / 示例
453///
454/// ```rust,no_run,ignore
455/// use hiver_runtime::task::spawn;
456///
457/// async fn my_task() -> i32 {
458///     42
459/// }
460///
461/// async fn main() {
462///     let handle = spawn(my_task());
463///     let result = handle.wait().await.unwrap();
464///     assert_eq!(result, 42);
465/// }
466/// ```
467///
468/// Note: This is a simplified implementation for Phase 2.
469/// Full integration with the runtime scheduler will be added in Phase 3.
470/// 注意:这是第2阶段的简化实现。
471/// 与运行时调度器的完全集成将在第3阶段添加。
472pub fn spawn<F, T>(future: F) -> JoinHandle<T>
473where
474    F: Future<Output = T> + Send + 'static,
475    T: Send + 'static,
476{
477    // Try to use the scheduler if a runtime context is available
478    // 如果运行时上下文可用,尝试使用调度器
479    if let Some(handle) = crate::runtime::Handle::try_current() {
480        let (raw_task, task_ref) = raw_task::allocate_task(future, handle.scheduler().clone());
481
482        let id = task_ref.core().map_or(0, raw_task::TaskCore::id);
483        let _ = handle.scheduler().submit(raw_task);
484
485        return JoinHandle {
486            inner: Some(Arc::new(TaskInner {
487                id,
488                state: AtomicU8::new(TaskState::Running as u8),
489                ref_count: AtomicUsize::new(1),
490                scheduler: handle.scheduler().clone(),
491                raw_task: AtomicUsize::new(0),
492                output: lock::OptionalCell::new(),
493                waiter: futures::task::AtomicWaker::new(),
494            })),
495            raw_core: Some(task_ref),
496        };
497    }
498
499    // Fallback: thread-per-task executor (when no runtime context)
500    // 回退:每任务一线程执行器(无运行时上下文时)
501    let id = gen_task_id();
502    let inner = Arc::new(TaskInner {
503        id,
504        state: AtomicU8::new(TaskState::Running as u8),
505        ref_count: AtomicUsize::new(1),
506        scheduler: SchedulerHandle::new_default(),
507        raw_task: AtomicUsize::new(0),
508        output: lock::OptionalCell::new(),
509        waiter: futures::task::AtomicWaker::new(),
510    });
511
512    let inner_clone = inner.clone();
513
514    std::thread::spawn(move || {
515        let mut future = Box::pin(future);
516        let waker = Waker::noop();
517        let mut context = Context::from_waker(waker);
518
519        let result = loop {
520            match Pin::new(&mut future).poll(&mut context) {
521                Poll::Ready(value) => break value,
522                Poll::Pending => {
523                    std::thread::sleep(std::time::Duration::from_millis(1));
524                },
525            }
526        };
527
528        inner_clone.output.set(result);
529        inner_clone
530            .state
531            .store(TaskState::Completed as u8, Ordering::Release);
532        inner_clone.waiter.wake();
533    });
534
535    JoinHandle {
536        inner: Some(inner),
537        raw_core: None,
538    }
539}
540
/// Runs `future` to completion and returns its output, blocking the calling thread.
///
/// The future is driven on a freshly spawned OS thread. That thread polls it
/// with a no-op waker and sleeps 1 ms after every `Poll::Pending`, so futures
/// that rely on being woken still make progress, at the cost of up to ~1 ms of
/// polling latency. No runtime is created.
///
/// # Panics
///
/// If the future panics, the panic is re-raised on the calling thread with its
/// original payload.
///
/// # Example
///
/// ```rust,no_run,ignore
/// use hiver_runtime::task::block_on;
///
/// block_on(async {
///     println!("Hello from async!");
/// });
/// ```
pub fn block_on<F, T>(future: F) -> T
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    let executor = std::thread::spawn(move || {
        let mut future = Box::pin(future);
        let mut cx = Context::from_waker(Waker::noop());

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(output) => break output,
                // The no-op waker never signals readiness, so periodic re-polling
                // is the only way to make progress; sleeping avoids burning CPU.
                Poll::Pending => std::thread::sleep(std::time::Duration::from_millis(1)),
            }
        }
    });

    match executor.join() {
        Ok(output) => output,
        // Propagate the future's own panic instead of a generic message.
        Err(payload) => std::panic::resume_unwind(payload),
    }
}
611
/// No-op raw-waker vtable: cloning yields another no-op waker, and waking or
/// dropping does nothing. Arguments follow `RawWakerVTable::new` order:
/// clone, wake, wake_by_ref, drop.
///
/// `std::task::Waker::noop()` is the usual alternative; this vtable is kept for
/// code that builds a `RawWaker` by hand, so it may go unused.
#[allow(dead_code)]
const NOOP_RAW_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    |_| RawWaker::new(std::ptr::null(), &NOOP_RAW_WAKER_VTABLE), // clone
    |_| {},                                                      // wake
    |_| {},                                                      // wake_by_ref
    |_| {},                                                      // drop
);
620
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_id_generation() {
        let (earlier, later) = (gen_task_id(), gen_task_id());
        assert!(later > earlier);
    }

    #[test]
    fn test_task_state() {
        assert_eq!(TaskState::Running as u8, 0);
        assert_eq!(TaskState::Completed as u8, 2);
        assert!(TaskState::Completed.is_finished());
        assert!(!TaskState::Running.is_finished());
    }

    #[test]
    fn test_join_error_display() {
        let cases = [
            (JoinError::TaskCancelled, "Task was cancelled"),
            (JoinError::TaskPanic, "Task panicked"),
        ];
        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }

    #[test]
    fn test_join_error_equality() {
        let (cancelled, panicked) = (JoinError::TaskCancelled, JoinError::TaskPanic);
        assert_eq!(cancelled, JoinError::TaskCancelled);
        assert_eq!(panicked, JoinError::TaskPanic);
        assert_ne!(cancelled, panicked);
    }

    #[test]
    fn test_join_error_is_std_error() {
        let cases: [(Box<dyn std::error::Error>, &str); 2] = [
            (Box::new(JoinError::TaskCancelled), "Task was cancelled"),
            (Box::new(JoinError::TaskPanic), "Task panicked"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn test_block_on_free_function() {
        assert_eq!(block_on(async { 42i32 }), 42);
    }

    #[test]
    fn test_block_on_free_function_string() {
        let text = block_on(async { String::from("hiver") });
        assert_eq!(text, "hiver");
    }

    #[test]
    fn test_block_on_free_function_unit() {
        let () = block_on(async {});
    }

    #[test]
    fn test_block_on_free_function_complex() {
        let sum = block_on(async {
            let (a, b) = (10, 20);
            a + b
        });
        assert_eq!(sum, 30);
    }

    #[test]
    fn test_task_id_uniqueness() {
        use std::collections::HashSet;
        let ids: HashSet<_> = std::iter::repeat_with(gen_task_id).take(100).collect();
        assert_eq!(ids.len(), 100, "all generated task IDs should be unique");
    }

    #[test]
    fn test_task_state_is_finished() {
        let expectations = [
            (TaskState::Completed, true),
            (TaskState::Cancelled, true),
            (TaskState::Panicked, true),
            (TaskState::Running, false),
            (TaskState::Waiting, false),
        ];
        for (state, finished) in expectations {
            assert_eq!(state.is_finished(), finished);
        }
    }

    #[test]
    fn test_task_state_from_u8_roundtrip() {
        let every_state = [
            TaskState::Running,
            TaskState::Waiting,
            TaskState::Completed,
            TaskState::Cancelled,
            TaskState::Panicked,
        ];
        assert!(
            every_state
                .iter()
                .all(|&state| TaskState::from_u8(state as u8) == Some(state))
        );
        assert!(TaskState::from_u8(255).is_none());
    }
}