Skip to main content

moduvex_runtime/executor/
task.rs

//! Task lifecycle types: `TaskHeader`, `Task`, `JoinHandle`.
//!
//! # Memory Model
//!
//! Two separate heap allocations per spawned future:
//!
//! 1. `Arc<TaskHeader>` — shared between the executor (`Task`), all `Waker`s,
//!    and the `JoinHandle`. Contains the atomic state, the vtable pointer, the
//!    join-waker slot, and the output slot (written on completion, read by the
//!    `JoinHandle`).
//!
//! 2. `Box<TaskBody<F>>` (stored as `body_ptr: *mut ()` in `TaskHeader`) —
//!    owns the erased `Pin<Box<F>>` (the live future). Freed by the executor
//!    the moment the future resolves or the task is cancelled, independent of
//!    when the `JoinHandle` reads the output.
//!
//! Separating the output from the body lets `drop_body` free the future
//! immediately on completion while the output lives safely in the `Arc` until
//! `JoinHandle::poll` retrieves it.
//!
//! # Thread Safety for Multi-Threaded Executor
//!
//! `join_waker` is protected by a `Mutex` to allow safe concurrent access
//! between `JoinHandle::poll` (any worker thread) and `poll_task` / `cancel`
//! (any background worker). The double-check pattern in `JoinHandle::poll`
//! (re-reading the task state after taking that lock) ensures the waker is
//! never missed if a task completes concurrently.

use std::any::Any;
use std::cell::UnsafeCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{Context, Poll, Waker};

// ── State constants ───────────────────────────────────────────────────────────

/// Not being polled; set by `poll_task` after the future returned `Pending`.
pub(crate) const STATE_IDLE: u32 = 0;
/// Initial state assigned by `Task::new`.
pub(crate) const STATE_SCHEDULED: u32 = 1;
/// Set by `poll_task` for the duration of a poll.
pub(crate) const STATE_RUNNING: u32 = 2;
/// The future resolved and its output has been stored in the header.
pub(crate) const STATE_COMPLETED: u32 = 3;
/// Cancelled, either by `Task::cancel` or by `JoinHandle::abort`.
pub(crate) const STATE_CANCELLED: u32 = 4;

// ── JoinError ─────────────────────────────────────────────────────────────────

/// Reason a `JoinHandle` resolved without yielding the task's output.
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled before it finished, e.g. via
    /// `JoinHandle::abort()`.
    Cancelled,
    /// The task's future panicked; the boxed value is the preserved panic
    /// payload.
    Panic(Box<dyn Any + Send + 'static>),
}

impl std::fmt::Display for JoinError {
    /// Writes a short fixed description. The panic payload is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::Cancelled => "task was cancelled",
            Self::Panic(_) => "task panicked",
        };
        f.write_str(description)
    }
}

impl std::error::Error for JoinError {}

64// ── TaskVtable ────────────────────────────────────────────────────────────────
65
66/// Type-erased function pointers for a concrete `TaskBody<F>`.
67pub(crate) struct TaskVtable {
68    /// Poll the future once. Returns `true` when the future completed (Ready).
69    /// On Ready the output has been written to `TaskHeader.output`.
70    pub poll: unsafe fn(body: *mut (), header: &TaskHeader, cx: &mut Context<'_>) -> bool,
71
72    /// Free the `Box<TaskBody<F>>` allocation (future only; output lives in header).
73    pub drop_body: unsafe fn(body: *mut ()),
74}
75
// ── TaskBody ──────────────────────────────────────────────────────────────────

/// Separately boxed holder for the pinned future. Its address is stored
/// type-erased (as `*mut ()`) in the task header.
struct TaskBody<F> {
    future: Pin<Box<F>>,
}

83// ── Vtable implementations ────────────────────────────────────────────────────
84
85unsafe fn body_poll<F, T>(body_ptr: *mut (), header: &TaskHeader, cx: &mut Context<'_>) -> bool
86where
87    F: Future<Output = T>,
88    T: Send + 'static,
89{
90    // SAFETY: `body_ptr` is `Box::into_raw(Box<TaskBody<F>>)` cast to `*mut ()`.
91    let body = &mut *(body_ptr as *mut TaskBody<F>);
92    match body.future.as_mut().poll(cx) {
93        Poll::Ready(val) => {
94            // Store the boxed output into the header's output slot.
95            // SAFETY: state=RUNNING — only this call site writes `output`.
96            *header.output.get() = Some(Box::new(val) as Box<dyn Any + Send>);
97            true
98        }
99        Poll::Pending => false,
100    }
101}
102
103unsafe fn body_drop<F>(ptr: *mut ()) {
104    // SAFETY: `ptr` is `Box::into_raw(Box<TaskBody<F>>)`.
105    drop(Box::from_raw(ptr as *mut TaskBody<F>));
106}
107
108fn make_vtable<F, T>() -> &'static TaskVtable
109where
110    F: Future<Output = T>,
111    T: Send + 'static,
112{
113    &TaskVtable {
114        poll: body_poll::<F, T>,
115        drop_body: body_drop::<F>,
116    }
117}
118
119// ── TaskHeader ────────────────────────────────────────────────────────────────
120
121/// Shared, reference-counted task descriptor.
122///
123/// Lives inside an `Arc<TaskHeader>`. Every `Waker`, the executor's `Task`,
124/// and the user's `JoinHandle` all hold a clone of this Arc.
125pub(crate) struct TaskHeader {
126    /// Lifecycle state — see `STATE_*` constants.
127    pub state: AtomicU32,
128
129    /// Type-erased vtable for the concrete `F` / `T` types.
130    pub vtable: &'static TaskVtable,
131
132    /// Waker registered by `JoinHandle::poll`. Called when the task finishes.
133    ///
134    /// Protected by a `Mutex` to allow safe concurrent access between
135    /// `JoinHandle::poll` (on any worker thread) and `poll_task`/`cancel`
136    /// (on any background worker). The double-check pattern in `JoinHandle::poll`
137    /// ensures no missed wake-ups.
138    pub join_waker: Mutex<Option<Waker>>,
139
140    /// Raw pointer to the `Box<TaskBody<F>>` allocation.
141    ///
142    /// # Safety invariant
143    /// Non-null from `Task::new` until `drop_body` is called by either
144    /// `poll_task` (on completion) or `cancel`. Nulled immediately after.
145    /// Only read/written while `state == STATE_RUNNING` or during cancellation.
146    pub body_ptr: UnsafeCell<*mut ()>,
147
148    /// Output value written by the vtable's `poll` on completion.
149    ///
150    /// Written with Release ordering on state → COMPLETED transition.
151    /// Read with Acquire ordering after observing STATE_COMPLETED.
152    /// The Release/Acquire pair on `state` provides the memory barrier.
153    pub output: UnsafeCell<Option<Box<dyn Any + Send>>>,
154}
155
156// SAFETY: `body_ptr` and `output` are UnsafeCell fields accessed under the
157// state machine's ordering guarantees:
158// - `body_ptr`: only accessed while state == STATE_RUNNING (exclusive)
159// - `output`: written before STATE_COMPLETED store (Release); read after
160//   STATE_COMPLETED load (Acquire)
161// `join_waker` is protected by its own Mutex.
162unsafe impl Send for TaskHeader {}
163unsafe impl Sync for TaskHeader {}
164
165// ── Task ──────────────────────────────────────────────────────────────────────
166
167/// Executor-owned handle to a spawned task.
168pub(crate) struct Task {
169    pub(crate) header: Arc<TaskHeader>,
170}
171
172impl Task {
173    /// Allocate a new task returning the executor `Task` + user `JoinHandle<T>`.
174    pub(crate) fn new<F, T>(future: F) -> (Task, JoinHandle<T>)
175    where
176        F: Future<Output = T> + 'static,
177        T: Send + 'static,
178    {
179        // Allocate and leak the future body (freed via vtable.drop_body).
180        let body: Box<TaskBody<F>> = Box::new(TaskBody {
181            future: Box::pin(future),
182        });
183        let body_ptr = Box::into_raw(body) as *mut ();
184
185        let header = Arc::new(TaskHeader {
186            state: AtomicU32::new(STATE_SCHEDULED),
187            vtable: make_vtable::<F, T>(),
188            join_waker: Mutex::new(None),
189            body_ptr: UnsafeCell::new(body_ptr),
190            output: UnsafeCell::new(None),
191        });
192
193        let join_arc = Arc::clone(&header);
194        let task = Task { header };
195        let jh = JoinHandle {
196            header: join_arc,
197            _marker: std::marker::PhantomData,
198        };
199        (task, jh)
200    }
201
202    /// Poll the task's future once. Returns `true` when the future completed.
203    ///
204    /// State transitions: SCHEDULED → RUNNING → IDLE (Pending) | COMPLETED (Ready)
205    pub(crate) fn poll_task(&self, cx: &mut Context<'_>) -> bool {
206        let h = &self.header;
207        h.state.store(STATE_RUNNING, Ordering::Release);
208
209        // SAFETY: state=RUNNING — exclusive access to body_ptr.
210        let body_ptr = unsafe { *h.body_ptr.get() };
211        debug_assert!(!body_ptr.is_null(), "poll_task called on freed body");
212
213        // SAFETY: vtable matches the concrete types used in `new`.
214        let completed = unsafe { (h.vtable.poll)(body_ptr, h, cx) };
215
216        if completed {
217            // Free the future body — output is now in h.output.
218            // SAFETY: body_ptr valid; state=RUNNING prevents concurrent access.
219            unsafe {
220                (h.vtable.drop_body)(body_ptr);
221                *h.body_ptr.get() = std::ptr::null_mut();
222            }
223            // Set COMPLETED with Release so the output write is visible to
224            // any thread that observes STATE_COMPLETED with Acquire.
225            h.state.store(STATE_COMPLETED, Ordering::Release);
226            // Wake the JoinHandle waiter under the Mutex to prevent races
227            // with JoinHandle::poll registering a waker concurrently.
228            let waker = h.join_waker.lock().unwrap().take();
229            if let Some(w) = waker {
230                w.wake();
231            }
232        } else {
233            h.state.store(STATE_IDLE, Ordering::Release);
234        }
235        completed
236    }
237
238    /// Cancel the task: drop the future body and wake the JoinHandle.
239    ///
240    /// Must be called at most once by the executor.
241    pub(crate) fn cancel(self) {
242        let h = &self.header;
243        // SAFETY: executor holds the Task exclusively; state = SCHEDULED or CANCELLED.
244        let body_ptr = unsafe { *h.body_ptr.get() };
245        if !body_ptr.is_null() {
246            unsafe {
247                (h.vtable.drop_body)(body_ptr);
248                *h.body_ptr.get() = std::ptr::null_mut();
249            }
250        }
251        h.state.store(STATE_CANCELLED, Ordering::Release);
252        // Wake JoinHandle under the Mutex so no waker is missed.
253        let waker = h.join_waker.lock().unwrap().take();
254        if let Some(w) = waker {
255            w.wake();
256        }
257        // Arc refcount decremented when `self` drops.
258    }
259}
260
261// ── JoinHandle ────────────────────────────────────────────────────────────────
262
263/// Future returned from `spawn()`. Resolves when the spawned task completes.
264pub struct JoinHandle<T> {
265    pub(crate) header: Arc<TaskHeader>,
266    _marker: std::marker::PhantomData<T>,
267}
268
269impl<T: Send + 'static> JoinHandle<T> {
270    /// Request cancellation. If the task hasn't started or is idle, it will be
271    /// dropped by the executor on its next scheduling pass.
272    pub fn abort(&self) {
273        // Try to flip IDLE → CANCELLED.
274        let _ = self.header.state.compare_exchange(
275            STATE_IDLE,
276            STATE_CANCELLED,
277            Ordering::AcqRel,
278            Ordering::Relaxed,
279        );
280        // Try to flip SCHEDULED → CANCELLED.
281        let _ = self.header.state.compare_exchange(
282            STATE_SCHEDULED,
283            STATE_CANCELLED,
284            Ordering::AcqRel,
285            Ordering::Relaxed,
286        );
287    }
288}
289
290impl<T: Send + 'static> Future for JoinHandle<T> {
291    type Output = Result<T, JoinError>;
292
293    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
294        // Fast path: check state before acquiring the waker lock.
295        let state = self.header.state.load(Ordering::Acquire);
296
297        if state == STATE_COMPLETED {
298            return self.take_output();
299        }
300        if state == STATE_CANCELLED {
301            return Poll::Ready(Err(JoinError::Cancelled));
302        }
303
304        // Task still in flight. Register waker under the Mutex to prevent a
305        // race with poll_task completing the task simultaneously.
306        //
307        // Double-check pattern:
308        //   1. Lock the waker Mutex.
309        //   2. Re-read state (now synchronized with poll_task's Mutex lock).
310        //   3. If still in-flight, store waker.
311        //   4. If completed/cancelled, return Ready immediately.
312        let mut guard = self.header.join_waker.lock().unwrap();
313        // Re-check under lock: poll_task takes the lock before setting
314        // STATE_COMPLETED, so if state is not COMPLETED here, we're safe to
315        // store the waker and it will be taken by poll_task later.
316        let state = self.header.state.load(Ordering::Acquire);
317        match state {
318            STATE_COMPLETED => {
319                drop(guard);
320                self.take_output()
321            }
322            STATE_CANCELLED => {
323                drop(guard);
324                Poll::Ready(Err(JoinError::Cancelled))
325            }
326            _ => {
327                *guard = Some(cx.waker().clone());
328                Poll::Pending
329            }
330        }
331    }
332}
333
334impl<T: Send + 'static> JoinHandle<T> {
335    /// Take the output from the header after observing STATE_COMPLETED.
336    fn take_output(self: Pin<&mut Self>) -> Poll<Result<T, JoinError>> {
337        // SAFETY: state=COMPLETED (observed with Acquire). The worker that set
338        // COMPLETED used Release ordering. The Release/Acquire pair establishes
339        // happens-before: output write → COMPLETED store → our load → output read.
340        let boxed = unsafe { (*self.header.output.get()).take() };
341        match boxed {
342            Some(any_val) => match any_val.downcast::<T>() {
343                Ok(val) => Poll::Ready(Ok(*val)),
344                Err(_) => Poll::Ready(Err(JoinError::Cancelled)),
345            },
346            None => Poll::Ready(Err(JoinError::Cancelled)), // already taken
347        }
348    }
349}
350
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;
    use std::task::Wake;

    /// Waker backend that records whether it has been woken.
    struct FlagWake(AtomicBool);

    impl Wake for FlagWake {
        fn wake(self: Arc<Self>) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    /// Builds a `Waker` plus a handle to the flag it sets when woken.
    fn flag_waker() -> (Arc<FlagWake>, Waker) {
        let flag = Arc::new(FlagWake(AtomicBool::new(false)));
        let waker = Waker::from(Arc::clone(&flag));
        (flag, waker)
    }

    #[test]
    fn task_new_initial_state() {
        let (task, _jh) = Task::new(async { 42u32 });
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
    }

    #[test]
    fn join_error_display() {
        assert_eq!(JoinError::Cancelled.to_string(), "task was cancelled");
        assert!(JoinError::Panic(Box::new("x"))
            .to_string()
            .contains("panicked"));
    }

    #[test]
    fn abort_from_idle_sets_cancelled() {
        let (task, jh) = Task::new(async { 1u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn cancel_drops_future() {
        let dropped = Arc::new(AtomicBool::new(false));
        let d = dropped.clone();

        struct Bomb(Arc<AtomicBool>);
        impl Drop for Bomb {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }
        impl Future for Bomb {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<()> {
                Poll::Pending
            }
        }

        let (task, _jh) = Task::new(Bomb(d));
        task.cancel();
        assert!(
            dropped.load(Ordering::SeqCst),
            "future must be dropped on cancel"
        );
    }

    // ── Additional task tests ──────────────────────────────────────────────

    #[test]
    fn join_error_panic_display() {
        let err = JoinError::Panic(Box::new("boom"));
        let s = err.to_string();
        assert!(s.contains("panic"));
    }

    #[test]
    fn join_error_cancelled_display() {
        let err = JoinError::Cancelled;
        let s = err.to_string();
        assert!(s.contains("cancel") || s.contains("Cancel"));
    }

    #[test]
    fn abort_from_scheduled_sets_cancelled() {
        let (_task, jh) = Task::new(async { 1u32 });
        // Initial state is SCHEDULED
        jh.abort();
        assert_eq!(
            jh.header.state.load(Ordering::Acquire),
            STATE_CANCELLED
        );
    }

    #[test]
    fn task_header_initial_state_is_scheduled() {
        let (task, _jh) = Task::new(async { 0u8 });
        assert_eq!(
            task.header.state.load(Ordering::Acquire),
            STATE_SCHEDULED
        );
    }

    #[test]
    fn cancel_sets_state_to_cancelled() {
        let (task, jh) = Task::new(async { 0u8 });
        task.cancel();
        // `cancel` consumed the Task, so read the state through the
        // JoinHandle, which still holds a clone of the header Arc.
        assert_eq!(
            jh.header.state.load(Ordering::Acquire),
            STATE_CANCELLED
        );
    }

    #[test]
    fn abort_completed_task_has_no_effect() {
        let (task, jh) = Task::new(async { 99u32 });
        // Manually set state to COMPLETED (simulating task that already ran)
        task.header.state.store(STATE_COMPLETED, Ordering::Release);
        jh.abort(); // abort on completed task — must not panic
        // State remains COMPLETED (neither the IDLE nor the SCHEDULED
        // precondition for cancelling holds)
        assert_eq!(
            jh.header.state.load(Ordering::Acquire),
            STATE_COMPLETED
        );
    }

    #[test]
    fn state_constants_distinct() {
        // All STATE_* constants must be distinct values
        let states = [
            STATE_IDLE,
            STATE_SCHEDULED,
            STATE_RUNNING,
            STATE_COMPLETED,
            STATE_CANCELLED,
        ];
        let unique: std::collections::HashSet<u32> = states.iter().cloned().collect();
        assert_eq!(unique.len(), states.len());
    }

    #[test]
    fn join_error_debug_format() {
        let err = JoinError::Cancelled;
        let s = format!("{err:?}");
        assert!(!s.is_empty());
    }

    #[test]
    fn task_new_creates_join_handle_with_same_header() {
        let (task, jh) = Task::new(async { 0u32 });
        // Both task and jh share the same header Arc
        assert!(Arc::ptr_eq(&task.header, &jh.header));
    }

    #[test]
    fn abort_from_idle_state_succeeds() {
        let (task, jh) = Task::new(async { 0u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn multiple_aborts_are_idempotent() {
        let (_task, jh) = Task::new(async { 0u32 });
        // Abort multiple times — must not panic
        jh.abort();
        jh.abort();
        jh.abort();
    }

    // ── poll_task / JoinHandle::poll round trips ───────────────────────────

    #[test]
    fn poll_task_completion_delivers_output() {
        let (task, mut jh) = Task::new(async { 7u32 });
        let (_flag, waker) = flag_waker();
        let mut cx = Context::from_waker(&waker);

        assert!(task.poll_task(&mut cx), "ready future must report completion");
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_COMPLETED);

        match Pin::new(&mut jh).poll(&mut cx) {
            Poll::Ready(Ok(value)) => assert_eq!(value, 7),
            other => panic!("unexpected poll result: {other:?}"),
        }
    }

    #[test]
    fn pending_join_handle_is_woken_by_cancel() {
        let (task, mut jh) = Task::new(async { 1u32 });
        let (flag, waker) = flag_waker();
        let mut cx = Context::from_waker(&waker);

        // Task has not run yet, so the handle registers its waker and waits.
        assert!(matches!(Pin::new(&mut jh).poll(&mut cx), Poll::Pending));
        assert!(!flag.0.load(Ordering::SeqCst));

        task.cancel();
        assert!(
            flag.0.load(Ordering::SeqCst),
            "cancel must wake the registered join waker"
        );
        assert!(matches!(
            Pin::new(&mut jh).poll(&mut cx),
            Poll::Ready(Err(JoinError::Cancelled))
        ));
    }
}