Skip to main content

hiver_runtime/task/
raw_task.rs

//! Raw task infrastructure for the scheduler.
//! 调度器的原始任务基础设施
//!
//! Uses vtable dispatch for type-erased task execution with manual
//! reference counting (no `Arc` — this avoids an `Arc`/`Box` layout mismatch).
//!
//! 使用vtable分发进行类型擦除的任务执行,配合手动引用计数
//! (不使用Arc — 避免Arc/Box布局不匹配)。
9
10use std::cell::UnsafeCell;
11use std::future::Future;
12use std::mem::MaybeUninit;
13use std::pin::Pin;
14use std::ptr::NonNull;
15use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
16use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
17
18use crate::scheduler::{RawTask, SchedulerHandle, TaskId, gen_task_id};
19
20const STATE_RUNNING: u8 = 0;
21const STATE_WAITING: u8 = 1;
22const STATE_COMPLETED: u8 = 2;
23
24/// Virtual function table for type-erased task operations
25struct TaskVTable {
26    poll: unsafe fn(*const TaskCore) -> bool,
27    take_output: unsafe fn(*const TaskCore, *mut ()) -> bool,
28    drop_future_and_dealloc: unsafe fn(*const TaskCore),
29}
30
31/// Core task header — always the first field of a `ConcreteTask<F>`.
32/// Manually reference-counted. Layout: `[vtable | id | state | ref_count | scheduler]`.
33#[repr(C)]
34pub(crate) struct TaskCore {
35    vtable: &'static TaskVTable,
36    id: TaskId,
37    state: AtomicU8,
38    ref_count: AtomicUsize,
39    scheduler: SchedulerHandle,
40}
41
42impl TaskCore {
43    /// Get the task ID.
44    /// 获取任务ID。
45    pub(crate) fn id(&self) -> TaskId {
46        self.id
47    }
48
49    /// Check if the task has completed.
50    /// 检查任务是否已完成。
51    #[must_use]
52    pub(crate) fn is_completed(&self) -> bool {
53        self.state.load(Ordering::Acquire) == STATE_COMPLETED
54    }
55
56    #[allow(dead_code)]
57    pub(crate) fn scheduler(&self) -> &SchedulerHandle {
58        &self.scheduler
59    }
60
61    /// Poll this task via the vtable. Returns true if completed.
62    unsafe fn poll(&self) -> bool {
63        (self.vtable.poll)(self)
64    }
65
66    /// Increment reference count. Returns the raw pointer for convenience.
67    fn inc_ref(ptr: *const Self) -> *const Self {
68        let core = &unsafe { &*ptr };
69        core.ref_count.fetch_add(1, Ordering::Relaxed);
70        ptr
71    }
72
73    /// Decrement reference count. Deallocates if it reaches zero.
74    unsafe fn dec_ref(ptr: *const Self) {
75        let core = &*ptr;
76        if core.ref_count.fetch_sub(1, Ordering::Release) == 1 {
77            std::sync::atomic::fence(Ordering::Acquire);
78            (core.vtable.drop_future_and_dealloc)(ptr);
79        }
80    }
81}
82
83/// Concrete task with a known future type.
84/// Layout: `[TaskCore | Future | Output]`
85#[repr(C)]
86struct ConcreteTask<F: Future + Send + 'static> {
87    core: TaskCore,
88    future: UnsafeCell<MaybeUninit<F>>,
89    output: UnsafeCell<MaybeUninit<F::Output>>,
90}
91
92impl<F: Future + Send + 'static> ConcreteTask<F> {
93    const fn vtable() -> &'static TaskVTable {
94        &TaskVTable {
95            poll: Self::poll_impl,
96            take_output: Self::take_output_impl,
97            drop_future_and_dealloc: Self::drop_future_and_dealloc_impl,
98        }
99    }
100
101    unsafe fn poll_impl(core: *const TaskCore) -> bool {
102        let task = &*(core as *const Self);
103        let waker = create_task_waker(core);
104        let mut cx = Context::from_waker(&waker);
105
106        let future = &mut *task.future.get();
107        let future = Pin::new_unchecked(future.assume_init_mut());
108
109        match future.poll(&mut cx) {
110            Poll::Ready(value) => {
111                (*task.output.get()).write(value);
112                task.core.state.store(STATE_COMPLETED, Ordering::Release);
113                std::ptr::drop_in_place((*task.future.get()).as_mut_ptr());
114                true
115            },
116            Poll::Pending => {
117                task.core.state.store(STATE_WAITING, Ordering::Release);
118                false
119            },
120        }
121    }
122
123    unsafe fn take_output_impl(core: *const TaskCore, dest: *mut ()) -> bool {
124        let task = &*(core as *const Self);
125        if task.core.state.load(Ordering::Acquire) != STATE_COMPLETED {
126            return false;
127        }
128        let output = (*task.output.get()).assume_init_read();
129        (dest as *mut F::Output).write(output);
130        true
131    }
132
133    unsafe fn drop_future_and_dealloc_impl(core: *const TaskCore) {
134        let task = core as *mut Self;
135        let state = (*task).core.state.load(Ordering::Acquire);
136        if state == STATE_RUNNING || state == STATE_WAITING {
137            std::ptr::drop_in_place((*(*task).future.get()).as_mut_ptr());
138        }
139        if state == STATE_COMPLETED {
140            // Output was already taken or needs to be dropped
141            // Check if it was consumed by take_output
142            // We track this by the output being MaybeUninit —
143            // if take_output ran, the value was moved out.
144            // If not, we need to drop it.
145            // For simplicity, output is always consumed via take_output
146            // before deallocation in our usage pattern.
147        }
148        let _ = Box::from_raw(task);
149    }
150}
151
152// --- Waker (manual ref-counted) ---
153
154static WAKER_VTABLE: RawWakerVTable =
155    RawWakerVTable::new(waker_clone, waker_wake, waker_wake_by_ref, waker_drop);
156
157fn create_task_waker(core: *const TaskCore) -> Waker {
158    // Increment ref for the waker
159    TaskCore::inc_ref(core);
160    unsafe { Waker::from_raw(RawWaker::new(core as *const (), &WAKER_VTABLE)) }
161}
162
163unsafe fn waker_clone(data: *const ()) -> RawWaker {
164    TaskCore::inc_ref(data as *const TaskCore);
165    RawWaker::new(data, &WAKER_VTABLE)
166}
167
168unsafe fn waker_wake(data: *const ()) {
169    let core = data as *const TaskCore;
170    try_re_enqueue(core);
171    TaskCore::dec_ref(core); // consume the waker's ref
172}
173
174unsafe fn waker_wake_by_ref(data: *const ()) {
175    try_re_enqueue(data as *const TaskCore);
176}
177
178unsafe fn waker_drop(data: *const ()) {
179    TaskCore::dec_ref(data as *const TaskCore);
180}
181
182fn try_re_enqueue(core: *const TaskCore) {
183    unsafe {
184        let c = &*core;
185        if c.state
186            .compare_exchange(STATE_WAITING, STATE_RUNNING, Ordering::AcqRel, Ordering::Relaxed)
187            .is_err()
188        {
189            return;
190        }
191        // Add a ref for the queue slot
192        TaskCore::inc_ref(core);
193        let _ = c.scheduler.submit(core as RawTask);
194    }
195}
196
197// --- Public API ---
198
199/// Handle for JoinHandle to track a raw task.
200/// Automatically decrements ref count on drop.
201pub(crate) struct TaskRef(Option<NonNull<TaskCore>>);
202
203impl TaskRef {
204    pub(crate) fn new(ptr: *const TaskCore) -> Self {
205        TaskRef(Some(unsafe { NonNull::new_unchecked(ptr.cast_mut()) }))
206    }
207
208    pub(crate) fn core(&self) -> Option<&TaskCore> {
209        self.0.map(|nn| unsafe { nn.as_ref() })
210    }
211
212    #[allow(dead_code)]
213    pub(crate) fn is_some(&self) -> bool {
214        self.0.is_some()
215    }
216}
217
218impl Drop for TaskRef {
219    fn drop(&mut self) {
220        if let Some(nn) = self.0.take() {
221            unsafe {
222                TaskCore::dec_ref(nn.as_ptr());
223            }
224        }
225    }
226}
227
228// Send + Sync: TaskRef is just a raw pointer with ref-counted ownership
229unsafe impl Send for TaskRef {}
230unsafe impl Sync for TaskRef {}
231
232/// Allocate a new task. Returns (RawTask for queue, TaskRef for JoinHandle).
233/// ref_count starts at 2: one for queue slot, one for JoinHandle.
234pub(crate) fn allocate_task<F>(future: F, scheduler: SchedulerHandle) -> (RawTask, TaskRef)
235where
236    F: Future + Send + 'static,
237    F::Output: Send + 'static,
238{
239    let id = gen_task_id();
240
241    let task = Box::new(ConcreteTask {
242        core: TaskCore {
243            vtable: ConcreteTask::<F>::vtable(),
244            id,
245            state: AtomicU8::new(STATE_RUNNING),
246            ref_count: AtomicUsize::new(2), // queue + JoinHandle
247            scheduler,
248        },
249        future: UnsafeCell::new(MaybeUninit::new(future)),
250        output: UnsafeCell::new(MaybeUninit::uninit()),
251    });
252
253    let raw: *mut ConcreteTask<F> = Box::into_raw(task);
254    let core_ptr: *const TaskCore = raw as *const TaskCore;
255
256    let raw_task = core_ptr as RawTask;
257    let task_ref = TaskRef::new(core_ptr);
258
259    (raw_task, task_ref)
260}
261
262/// Poll a raw task. Returns true if completed.
263///
264/// # Safety
265///
266/// `raw_task` must be valid. Caller must have exclusive access.
267pub unsafe fn poll_raw_task(raw_task: RawTask) -> bool {
268    let core = raw_task as *const TaskCore;
269    (*core).poll()
270}
271
272/// Deallocate a completed task (consumes the queue slot ref).
273///
274/// # Safety
275///
276/// `raw_task` must be valid. Must only be called when the task just completed
277/// (the scheduler's queue ref is being consumed).
278pub unsafe fn deallocate_completed_task(raw_task: RawTask) {
279    TaskCore::dec_ref(raw_task as *const TaskCore);
280}
281
282/// Read output from a completed task core.
283///
284/// # Safety
285///
286/// `core` must point to a completed task.
287pub(crate) unsafe fn read_output<T>(core: &TaskCore) -> Option<T> {
288    let mut output: MaybeUninit<T> = MaybeUninit::uninit();
289    let ok = (core.vtable.take_output)(core, &mut output as *mut _ as *mut ());
290    if ok { Some(output.assume_init()) } else { None }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// A future that is ready on its first poll completes, yields its output,
    /// and can then release the queue reference.
    #[test]
    fn test_sync_future_completes() {
        let (raw_task, task_ref) = allocate_task(async { 42 }, SchedulerHandle::new_default());
        let core = task_ref.core().unwrap();

        // A freshly allocated task is runnable and unfinished.
        assert_eq!(core.state.load(Ordering::Acquire), STATE_RUNNING);
        assert!(!core.is_completed());

        let completed = unsafe { poll_raw_task(raw_task) };
        assert!(completed);
        assert!(core.is_completed());

        let output = unsafe { read_output::<i32>(core) }.unwrap();
        assert_eq!(output, 42);

        // Release the queue slot; dropping `task_ref` releases the JoinHandle ref.
        unsafe { deallocate_completed_task(raw_task) };
    }

    /// Two tasks on the same scheduler receive distinct IDs.
    #[test]
    fn test_task_id_unique() {
        let scheduler = SchedulerHandle::new_default();
        let (_, first) = allocate_task(async {}, scheduler.clone());
        let (_, second) = allocate_task(async {}, scheduler);

        let id_of = |handle: &TaskRef| handle.core().unwrap().id();
        assert_ne!(id_of(&first), id_of(&second));
    }
}