Skip to main content

moduvex_runtime/executor/
task.rs

1//! Task lifecycle types: `TaskHeader`, `Task`, `JoinHandle`.
2//!
3//! # Memory Model
4//!
//! Each spawned future is backed by two independently freed pieces of heap state:
//!
//! 1. `Arc<TaskHeader>` — shared between the executor (`Task`), all `Waker`s,
//!    and the `JoinHandle`. Contains the atomic state, vtable pointer, join-waker
//!    slot, and the output slot (written on completion, read by `JoinHandle`).
//!
//! 2. `Box<TaskBody<F>>` (stored as `body_ptr: *mut ()` in `TaskHeader`) —
//!    owns the type-erased future, which is itself boxed as `Pin<Box<F>>`, so
//!    this part is two allocations. Both are freed by the executor the moment
//!    the future resolves or the task is cancelled, independent of when the
//!    `JoinHandle` reads the output.
//!
//! On completion the output value is boxed once more (`Box<dyn Any + Send>`)
//! and stored in the header's output slot.
15//!
16//! Separating the output from the body lets `drop_body` free the future
17//! immediately on completion while the output lives safely in the Arc until
18//! `JoinHandle::poll` retrieves it.
19
20use std::any::Any;
21use std::cell::UnsafeCell;
22use std::future::Future;
23use std::pin::Pin;
24use std::sync::atomic::{AtomicU32, Ordering};
25use std::sync::Arc;
26use std::task::{Context, Poll, Waker};
27
28// ── State constants ───────────────────────────────────────────────────────────
29
/// The most recent poll returned `Poll::Pending`.
pub(crate) const STATE_IDLE: u32 = 0;
/// Waiting to be polled; every task is created in this state.
pub(crate) const STATE_SCHEDULED: u32 = 1;
/// `Task::poll_task` is polling the future right now.
pub(crate) const STATE_RUNNING: u32 = 2;
/// The future returned `Poll::Ready`; its output sits in `TaskHeader::output`.
pub(crate) const STATE_COMPLETED: u32 = 3;
/// Aborted via `JoinHandle::abort` or cancelled by the executor; no output will follow.
pub(crate) const STATE_CANCELLED: u32 = 4;
35
36// ── JoinError ─────────────────────────────────────────────────────────────────
37
/// Reason a `JoinHandle` resolved without the task's output.
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled (for example through `JoinHandle::abort()`)
    /// before it produced a value.
    Cancelled,
    /// The task's future panicked; the original panic payload is kept intact.
    Panic(Box<dyn Any + Send + 'static>),
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Cancelled => "task was cancelled",
            Self::Panic(_) => "task panicked",
        };
        f.write_str(message)
    }
}

impl std::error::Error for JoinError {}
56
57// ── TaskVtable ────────────────────────────────────────────────────────────────
58
59/// Type-erased function pointers for a concrete `TaskBody<F>`.
60pub(crate) struct TaskVtable {
61    /// Poll the future once. Returns `true` when the future completed (Ready).
62    /// On Ready the output has been written to `TaskHeader.output`.
63    pub poll: unsafe fn(body: *mut (), header: &TaskHeader, cx: &mut Context<'_>) -> bool,
64
65    /// Free the `Box<TaskBody<F>>` allocation (future only; output lives in header).
66    pub drop_body: unsafe fn(body: *mut ()),
67}
68
69// ── TaskBody ──────────────────────────────────────────────────────────────────
70
/// Heap-allocated owner of a spawned future.
///
/// Referenced only in type-erased form, through `TaskHeader::body_ptr`.
struct TaskBody<F> {
    /// The live future, pinned inside its own box.
    future: Pin<Box<F>>,
}
75
76// ── Vtable implementations ────────────────────────────────────────────────────
77
78unsafe fn body_poll<F, T>(body_ptr: *mut (), header: &TaskHeader, cx: &mut Context<'_>) -> bool
79where
80    F: Future<Output = T>,
81    T: Send + 'static,
82{
83    // SAFETY: `body_ptr` is `Box::into_raw(Box<TaskBody<F>>)` cast to `*mut ()`.
84    let body = &mut *(body_ptr as *mut TaskBody<F>);
85    match body.future.as_mut().poll(cx) {
86        Poll::Ready(val) => {
87            // Store the boxed output into the header's output slot.
88            // SAFETY: state=RUNNING — only this call site writes `output`.
89            *header.output.get() = Some(Box::new(val) as Box<dyn Any + Send>);
90            true
91        }
92        Poll::Pending => false,
93    }
94}
95
96unsafe fn body_drop<F>(ptr: *mut ()) {
97    // SAFETY: `ptr` is `Box::into_raw(Box<TaskBody<F>>)`.
98    drop(Box::from_raw(ptr as *mut TaskBody<F>));
99}
100
101fn make_vtable<F, T>() -> &'static TaskVtable
102where
103    F: Future<Output = T>,
104    T: Send + 'static,
105{
106    &TaskVtable {
107        poll: body_poll::<F, T>,
108        drop_body: body_drop::<F>,
109    }
110}
111
112// ── TaskHeader ────────────────────────────────────────────────────────────────
113
114/// Shared, reference-counted task descriptor.
115///
116/// Lives inside an `Arc<TaskHeader>`. Every `Waker`, the executor's `Task`,
117/// and the user's `JoinHandle` all hold a clone of this Arc.
118pub(crate) struct TaskHeader {
119    /// Lifecycle state — see `STATE_*` constants.
120    pub state: AtomicU32,
121
122    /// Type-erased vtable for the concrete `F` / `T` types.
123    pub vtable: &'static TaskVtable,
124
125    /// Waker registered by `JoinHandle::poll`. Called when the task finishes.
126    ///
127    /// # Safety invariant
128    /// Written only when `state < STATE_COMPLETED` (by `JoinHandle::poll` on
129    /// the executor thread). Read+cleared only when transitioning to
130    /// COMPLETED/CANCELLED (by `Task::poll_task` / `Task::cancel`, also on the
131    /// executor thread). Single-threaded executor guarantees no data race.
132    pub join_waker: UnsafeCell<Option<Waker>>,
133
134    /// Raw pointer to the `Box<TaskBody<F>>` allocation.
135    ///
136    /// # Safety invariant
137    /// Non-null from `Task::new` until `drop_body` is called by either
138    /// `poll_task` (on completion) or `cancel`. Nulled immediately after.
139    /// Only read/written while `state == STATE_RUNNING` or during cancellation.
140    pub body_ptr: UnsafeCell<*mut ()>,
141
142    /// Output value written by the vtable's `poll` on completion.
143    /// Read (and taken) exactly once by `JoinHandle::poll`.
144    ///
145    /// # Safety invariant
146    /// Written when `state` transitions to COMPLETED. Read when `state` is
147    /// observed as COMPLETED by `JoinHandle::poll`. Single-threaded executor
148    /// prevents concurrent writes+reads.
149    pub output: UnsafeCell<Option<Box<dyn Any + Send>>>,
150}
151
152// SAFETY: All `UnsafeCell` fields in `TaskHeader` are protected by the
153// atomic `state` field and the single-threaded executor invariant.
154// No two threads access mutable fields concurrently.
155unsafe impl Send for TaskHeader {}
156unsafe impl Sync for TaskHeader {}
157
158// ── Task ──────────────────────────────────────────────────────────────────────
159
160/// Executor-owned handle to a spawned task.
161pub(crate) struct Task {
162    pub(crate) header: Arc<TaskHeader>,
163}
164
165impl Task {
166    /// Allocate a new task returning the executor `Task` + user `JoinHandle<T>`.
167    pub(crate) fn new<F, T>(future: F) -> (Task, JoinHandle<T>)
168    where
169        F: Future<Output = T> + 'static,
170        T: Send + 'static,
171    {
172        // Allocate and leak the future body (freed via vtable.drop_body).
173        let body: Box<TaskBody<F>> = Box::new(TaskBody {
174            future: Box::pin(future),
175        });
176        let body_ptr = Box::into_raw(body) as *mut ();
177
178        let header = Arc::new(TaskHeader {
179            state: AtomicU32::new(STATE_SCHEDULED),
180            vtable: make_vtable::<F, T>(),
181            join_waker: UnsafeCell::new(None),
182            body_ptr: UnsafeCell::new(body_ptr),
183            output: UnsafeCell::new(None),
184        });
185
186        let join_arc = Arc::clone(&header);
187        let task = Task { header };
188        let jh = JoinHandle {
189            header: join_arc,
190            _marker: std::marker::PhantomData,
191        };
192        (task, jh)
193    }
194
195    /// Poll the task's future once. Returns `true` when the future completed.
196    ///
197    /// State transitions: SCHEDULED → RUNNING → IDLE (Pending) | COMPLETED (Ready)
198    pub(crate) fn poll_task(&self, cx: &mut Context<'_>) -> bool {
199        let h = &self.header;
200        h.state.store(STATE_RUNNING, Ordering::Release);
201
202        // SAFETY: state=RUNNING — exclusive access to body_ptr.
203        let body_ptr = unsafe { *h.body_ptr.get() };
204        debug_assert!(!body_ptr.is_null(), "poll_task called on freed body");
205
206        // SAFETY: vtable matches the concrete types used in `new`.
207        let completed = unsafe { (h.vtable.poll)(body_ptr, h, cx) };
208
209        if completed {
210            // Free the future body — output is now in h.output.
211            // SAFETY: body_ptr valid; state=RUNNING prevents concurrent access.
212            unsafe {
213                (h.vtable.drop_body)(body_ptr);
214                *h.body_ptr.get() = std::ptr::null_mut();
215            }
216            h.state.store(STATE_COMPLETED, Ordering::Release);
217            // Wake the JoinHandle waiter.
218            // SAFETY: state=COMPLETED — no concurrent join_waker writes.
219            let waker = unsafe { (*h.join_waker.get()).take() };
220            if let Some(w) = waker {
221                w.wake();
222            }
223        } else {
224            h.state.store(STATE_IDLE, Ordering::Release);
225        }
226        completed
227    }
228
229    /// Cancel the task: drop the future body and wake the JoinHandle.
230    ///
231    /// Must be called at most once by the executor.
232    pub(crate) fn cancel(self) {
233        let h = &self.header;
234        // SAFETY: executor guarantees cancel is called while holding the Task,
235        // which means state is SCHEDULED or CANCELLED (set by abort()).
236        // Either way we own exclusive access to body_ptr.
237        let body_ptr = unsafe { *h.body_ptr.get() };
238        if !body_ptr.is_null() {
239            unsafe {
240                (h.vtable.drop_body)(body_ptr);
241                *h.body_ptr.get() = std::ptr::null_mut();
242            }
243        }
244        h.state.store(STATE_CANCELLED, Ordering::Release);
245        // Wake JoinHandle so it returns JoinError::Cancelled.
246        // SAFETY: state=CANCELLED — exclusive join_waker access.
247        let waker = unsafe { (*h.join_waker.get()).take() };
248        if let Some(w) = waker {
249            w.wake();
250        }
251        // Arc refcount decremented when `self` drops.
252    }
253}
254
255// ── JoinHandle ────────────────────────────────────────────────────────────────
256
257/// Future returned from `spawn()`. Resolves when the spawned task completes.
258pub struct JoinHandle<T> {
259    pub(crate) header: Arc<TaskHeader>,
260    _marker: std::marker::PhantomData<T>,
261}
262
263impl<T: Send + 'static> JoinHandle<T> {
264    /// Request cancellation. If the task hasn't started or is idle, it will be
265    /// dropped by the executor on its next scheduling pass.
266    pub fn abort(&self) {
267        // Try to flip IDLE → CANCELLED.
268        let _ = self.header.state.compare_exchange(
269            STATE_IDLE,
270            STATE_CANCELLED,
271            Ordering::AcqRel,
272            Ordering::Relaxed,
273        );
274        // Try to flip SCHEDULED → CANCELLED.
275        let _ = self.header.state.compare_exchange(
276            STATE_SCHEDULED,
277            STATE_CANCELLED,
278            Ordering::AcqRel,
279            Ordering::Relaxed,
280        );
281    }
282}
283
284impl<T: Send + 'static> Future for JoinHandle<T> {
285    type Output = Result<T, JoinError>;
286
287    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
288        let state = self.header.state.load(Ordering::Acquire);
289
290        match state {
291            STATE_COMPLETED => {
292                // Take the output the task wrote into the header.
293                // SAFETY: state=COMPLETED — the executor will not write output again.
294                // Single-threaded: no concurrent reads from another JoinHandle.
295                let boxed = unsafe { (*self.header.output.get()).take() };
296                match boxed {
297                    Some(any_val) => match any_val.downcast::<T>() {
298                        Ok(val) => Poll::Ready(Ok(*val)),
299                        Err(_) => Poll::Ready(Err(JoinError::Cancelled)), // type mismatch (bug)
300                    },
301                    None => Poll::Ready(Err(JoinError::Cancelled)), // already taken
302                }
303            }
304            STATE_CANCELLED => Poll::Ready(Err(JoinError::Cancelled)),
305            _ => {
306                // Task still in flight — register our waker.
307                // SAFETY: state is IDLE/SCHEDULED/RUNNING (not COMPLETED/CANCELLED).
308                // The executor will write join_waker only after observing COMPLETED/CANCELLED,
309                // which has not happened yet. Single-threaded: no concurrent poll.
310                unsafe {
311                    *self.header.join_waker.get() = Some(cx.waker().clone());
312                }
313                Poll::Pending
314            }
315        }
316    }
317}
318
319// ── Tests ─────────────────────────────────────────────────────────────────────
320
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize};
    use std::task::Wake;

    /// Waker that counts how many times it has been woken.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns the counter together with a `Waker` that increments it.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    #[test]
    fn task_new_initial_state() {
        let (task, _jh) = Task::new(async { 42u32 });
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
    }

    #[test]
    fn join_error_display() {
        assert_eq!(JoinError::Cancelled.to_string(), "task was cancelled");
        assert!(JoinError::Panic(Box::new("x"))
            .to_string()
            .contains("panicked"));
    }

    #[test]
    fn abort_from_idle_sets_cancelled() {
        let (task, jh) = Task::new(async { 1u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn abort_from_scheduled_sets_cancelled() {
        let (task, jh) = Task::new(async { 1u32 });
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn abort_after_completion_is_noop() {
        let (task, jh) = Task::new(async { 7u8 });
        let (_count, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(task.poll_task(&mut cx));
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_COMPLETED);
    }

    #[test]
    fn completed_task_delivers_output_and_wakes_join_handle() {
        let (task, mut jh) = Task::new(async { 42u32 });
        let (count, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);

        // Before the task runs, the handle is pending and registers its waker.
        assert!(Pin::new(&mut jh).poll(&mut cx).is_pending());

        assert!(task.poll_task(&mut cx));
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_COMPLETED);
        assert_eq!(count.0.load(Ordering::SeqCst), 1, "join waker must fire once");

        match Pin::new(&mut jh).poll(&mut cx) {
            Poll::Ready(Ok(v)) => assert_eq!(v, 42),
            other => panic!("expected Ready(Ok(42)), got {:?}", other),
        }
    }

    #[test]
    fn pending_poll_returns_to_idle() {
        let (task, _jh) = Task::new(std::future::pending::<()>());
        let (_count, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(!task.poll_task(&mut cx));
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_IDLE);
        task.cancel();
    }

    #[test]
    fn cancel_resolves_join_handle_with_cancelled() {
        let (task, mut jh) = Task::new(std::future::pending::<u32>());
        let (count, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);

        assert!(Pin::new(&mut jh).poll(&mut cx).is_pending());
        task.cancel();
        assert_eq!(count.0.load(Ordering::SeqCst), 1, "cancel must wake the handle");
        assert!(matches!(
            Pin::new(&mut jh).poll(&mut cx),
            Poll::Ready(Err(JoinError::Cancelled))
        ));
    }

    #[test]
    fn cancel_drops_future() {
        let dropped = Arc::new(AtomicBool::new(false));
        let d = dropped.clone();

        struct Bomb(Arc<AtomicBool>);
        impl Drop for Bomb {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }
        impl Future for Bomb {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<()> {
                Poll::Pending
            }
        }

        let (task, _jh) = Task::new(Bomb(d));
        task.cancel();
        assert!(
            dropped.load(Ordering::SeqCst),
            "future must be dropped on cancel"
        );
    }
}