// moduvex_runtime/executor/task.rs
//! Task lifecycle types: `TaskHeader`, `Task`, `JoinHandle`.
//!
//! # Memory Model
//!
//! Two separate heap allocations per spawned future:
//!
//! 1. `Arc<TaskHeader>` — shared between executor (`Task`), all `Waker`s,
//!    and `JoinHandle`. Contains the atomic state, vtable pointer, join-waker
//!    slot, and the output slot (written on completion, read by JoinHandle).
//!
//! 2. `Box<TaskBody<F>>` (stored as `body_ptr: *mut ()` in `TaskHeader`) —
//!    owns the erased `Pin<Box<F>>` (the live future). Freed by the executor
//!    the moment the future resolves or the task is cancelled, independent of
//!    when the JoinHandle reads the output.
//!
//! Separating the output from the body lets `drop_body` free the future
//! immediately on completion while the output lives safely in the Arc until
//! `JoinHandle::poll` retrieves it.

use std::any::Any;
use std::cell::UnsafeCell;
use std::future::Future;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};

// ── State constants ───────────────────────────────────────────────────────────

/// Parked: last poll returned `Pending`; not currently queued.
pub(crate) const STATE_IDLE: u32 = 0;
/// Queued for (re)polling by the executor. Initial state after `Task::new`.
pub(crate) const STATE_SCHEDULED: u32 = 1;
/// Executor is currently inside the future's `poll`.
pub(crate) const STATE_RUNNING: u32 = 2;
/// Future returned `Ready`; the boxed output is parked in `TaskHeader::output`.
pub(crate) const STATE_COMPLETED: u32 = 3;
/// Task was aborted/cancelled; terminal — nothing ever leaves this state.
pub(crate) const STATE_CANCELLED: u32 = 4;

// ── JoinError ─────────────────────────────────────────────────────────────────

/// Error returned by a `JoinHandle` when the task does not complete normally.
#[derive(Debug)]
pub enum JoinError {
    /// Task was aborted via `JoinHandle::abort()`.
    /// Also returned when the output slot has already been emptied by an
    /// earlier poll of the same handle (see `JoinHandle::poll`).
    Cancelled,
    /// Task's future panicked. Panic payload preserved.
    Panic(Box<dyn Any + Send + 'static>),
}

47impl std::fmt::Display for JoinError {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 JoinError::Cancelled => write!(f, "task was cancelled"),
51 JoinError::Panic(_) => write!(f, "task panicked"),
52 }
53 }
54}
55impl std::error::Error for JoinError {}
56
// ── TaskVtable ────────────────────────────────────────────────────────────────

/// Type-erased function pointers for a concrete `TaskBody<F>`.
///
/// Built once per `(F, T)` pair by `make_vtable` and stored as a `&'static`
/// in `TaskHeader`, so the header itself needs no type parameters.
pub(crate) struct TaskVtable {
    /// Poll the future once. Returns `true` when the future completed (Ready).
    /// On Ready the output has been written to `TaskHeader.output`.
    pub poll: unsafe fn(body: *mut (), header: &TaskHeader, cx: &mut Context<'_>) -> bool,

    /// Free the `Box<TaskBody<F>>` allocation (future only; output lives in header).
    pub drop_body: unsafe fn(body: *mut ()),
}

// ── TaskBody ──────────────────────────────────────────────────────────────────

/// Heap allocation that owns the erased future.
///
/// NOTE(review): the future is double-boxed here — `Box<TaskBody<F>>` around a
/// `Pin<Box<F>>`. Pinning the `TaskBody` allocation itself would save one box,
/// but the vtable shims rely on this exact layout; confirm before changing.
struct TaskBody<F> {
    // Pinned so `body_poll` can call `Future::poll` across executor runs.
    future: Pin<Box<F>>,
}

76// ── Vtable implementations ────────────────────────────────────────────────────
77
78unsafe fn body_poll<F, T>(body_ptr: *mut (), header: &TaskHeader, cx: &mut Context<'_>) -> bool
79where
80 F: Future<Output = T>,
81 T: Send + 'static,
82{
83 // SAFETY: `body_ptr` is `Box::into_raw(Box<TaskBody<F>>)` cast to `*mut ()`.
84 let body = &mut *(body_ptr as *mut TaskBody<F>);
85 match body.future.as_mut().poll(cx) {
86 Poll::Ready(val) => {
87 // Store the boxed output into the header's output slot.
88 // SAFETY: state=RUNNING — only this call site writes `output`.
89 *header.output.get() = Some(Box::new(val) as Box<dyn Any + Send>);
90 true
91 }
92 Poll::Pending => false,
93 }
94}
95
96unsafe fn body_drop<F>(ptr: *mut ()) {
97 // SAFETY: `ptr` is `Box::into_raw(Box<TaskBody<F>>)`.
98 drop(Box::from_raw(ptr as *mut TaskBody<F>));
99}
100
/// Build the `&'static TaskVtable` for a concrete `(F, T)` pair.
///
/// The `'static` lifetime comes from constant promotion: the struct literal
/// contains only function pointers, so the compiler promotes it to a static
/// per monomorphization — no allocation, no registry.
fn make_vtable<F, T>() -> &'static TaskVtable
where
    F: Future<Output = T>,
    T: Send + 'static,
{
    &TaskVtable {
        poll: body_poll::<F, T>,
        drop_body: body_drop::<F>,
    }
}

// ── TaskHeader ────────────────────────────────────────────────────────────────

/// Shared, reference-counted task descriptor.
///
/// Lives inside an `Arc<TaskHeader>`. Every `Waker`, the executor's `Task`,
/// and the user's `JoinHandle` all hold a clone of this Arc. All mutable
/// fields are `UnsafeCell`s guarded by `state` plus the single-threaded
/// executor invariant (see the `Send`/`Sync` impls below).
pub(crate) struct TaskHeader {
    /// Lifecycle state — see `STATE_*` constants.
    pub state: AtomicU32,

    /// Type-erased vtable for the concrete `F` / `T` types.
    pub vtable: &'static TaskVtable,

    /// Waker registered by `JoinHandle::poll`. Called when the task finishes.
    ///
    /// # Safety invariant
    /// Written only when `state < STATE_COMPLETED` (by `JoinHandle::poll` on
    /// the executor thread). Read+cleared only when transitioning to
    /// COMPLETED/CANCELLED (by `Task::poll_task` / `Task::cancel`, also on the
    /// executor thread). Single-threaded executor guarantees no data race.
    pub join_waker: UnsafeCell<Option<Waker>>,

    /// Raw pointer to the `Box<TaskBody<F>>` allocation.
    ///
    /// # Safety invariant
    /// Non-null from `Task::new` until `drop_body` is called by either
    /// `poll_task` (on completion) or `cancel`. Nulled immediately after,
    /// which is what makes a later double-free impossible.
    /// Only read/written while `state == STATE_RUNNING` or during cancellation.
    pub body_ptr: UnsafeCell<*mut ()>,

    /// Output value written by the vtable's `poll` on completion.
    /// Read (and taken) exactly once by `JoinHandle::poll`.
    ///
    /// # Safety invariant
    /// Written when `state` transitions to COMPLETED. Read when `state` is
    /// observed as COMPLETED by `JoinHandle::poll` (Acquire load pairing with
    /// the executor's Release store). Single-threaded executor prevents
    /// concurrent writes+reads.
    pub output: UnsafeCell<Option<Box<dyn Any + Send>>>,
}

// SAFETY: All `UnsafeCell` fields in `TaskHeader` are protected by the
// atomic `state` field and the single-threaded executor invariant.
// No two threads access mutable fields concurrently.
//
// NOTE(review): nothing in the type system enforces that invariant. These
// impls make `JoinHandle` movable to another thread, where `JoinHandle::poll`
// writes `join_waker` guarded only by a state load — confirm every handle
// really stays on the executor thread, or protect `join_waker` with a lock.
unsafe impl Send for TaskHeader {}
unsafe impl Sync for TaskHeader {}

// ── Task ──────────────────────────────────────────────────────────────────────

/// Executor-owned handle to a spawned task.
///
/// Holds one strong reference to the shared header; dropping the `Task`
/// (e.g. at the end of `cancel`) releases only that reference.
pub(crate) struct Task {
    // Shared with every Waker and with the user's JoinHandle.
    pub(crate) header: Arc<TaskHeader>,
}

impl Task {
    /// Allocate a new task returning the executor `Task` + user `JoinHandle<T>`.
    ///
    /// The future is boxed, pinned, and leaked into `body_ptr`; it is
    /// reclaimed exactly once via `vtable.drop_body` (by `poll_task` on
    /// completion or by `cancel`).
    pub(crate) fn new<F, T>(future: F) -> (Task, JoinHandle<T>)
    where
        F: Future<Output = T> + 'static,
        T: Send + 'static,
    {
        // Allocate and leak the future body (freed via vtable.drop_body).
        let body: Box<TaskBody<F>> = Box::new(TaskBody {
            future: Box::pin(future),
        });
        let body_ptr = Box::into_raw(body) as *mut ();

        // Tasks start SCHEDULED: the spawner is expected to queue them.
        let header = Arc::new(TaskHeader {
            state: AtomicU32::new(STATE_SCHEDULED),
            vtable: make_vtable::<F, T>(),
            join_waker: UnsafeCell::new(None),
            body_ptr: UnsafeCell::new(body_ptr),
            output: UnsafeCell::new(None),
        });

        let join_arc = Arc::clone(&header);
        let task = Task { header };
        // `T` is only a phantom on the handle; the value itself travels
        // type-erased through `TaskHeader::output`. This is also what ties
        // the handle's `T` to the future's output type.
        let jh = JoinHandle {
            header: join_arc,
            _marker: std::marker::PhantomData,
        };
        (task, jh)
    }

    /// Poll the task's future once. Returns `true` when the future completed.
    ///
    /// State transitions: SCHEDULED → RUNNING → IDLE (Pending) | COMPLETED (Ready)
    ///
    /// NOTE(review): the unconditional RUNNING store below also overwrites
    /// STATE_CANCELLED if `abort()` ran after this task was dequeued — the
    /// cancellation request is then lost and the future is polled anyway.
    /// A compare_exchange from SCHEDULED would surface that; confirm against
    /// the executor's scheduling loop before changing.
    pub(crate) fn poll_task(&self, cx: &mut Context<'_>) -> bool {
        let h = &self.header;
        h.state.store(STATE_RUNNING, Ordering::Release);

        // SAFETY: state=RUNNING — exclusive access to body_ptr.
        let body_ptr = unsafe { *h.body_ptr.get() };
        debug_assert!(!body_ptr.is_null(), "poll_task called on freed body");

        // SAFETY: vtable matches the concrete types used in `new`.
        let completed = unsafe { (h.vtable.poll)(body_ptr, h, cx) };

        if completed {
            // Free the future body — output is now in h.output.
            // SAFETY: body_ptr valid; state=RUNNING prevents concurrent access.
            unsafe {
                (h.vtable.drop_body)(body_ptr);
                *h.body_ptr.get() = std::ptr::null_mut();
            }
            // Release-publish so a JoinHandle that Acquire-loads COMPLETED
            // also sees the output write above.
            h.state.store(STATE_COMPLETED, Ordering::Release);
            // Wake the JoinHandle waiter.
            // SAFETY: state=COMPLETED — no concurrent join_waker writes.
            let waker = unsafe { (*h.join_waker.get()).take() };
            if let Some(w) = waker {
                w.wake();
            }
        } else {
            // Park the task until its Waker reschedules it.
            h.state.store(STATE_IDLE, Ordering::Release);
        }
        completed
    }

    /// Cancel the task: drop the future body and wake the JoinHandle.
    ///
    /// Must be called at most once by the executor.
    pub(crate) fn cancel(self) {
        let h = &self.header;
        // SAFETY: executor guarantees cancel is called while holding the Task,
        // which means state is SCHEDULED or CANCELLED (set by abort()).
        // Either way we own exclusive access to body_ptr.
        let body_ptr = unsafe { *h.body_ptr.get() };
        // Defensive: per the invariant above body_ptr should be non-null
        // here; the check makes a double-free impossible regardless.
        if !body_ptr.is_null() {
            unsafe {
                (h.vtable.drop_body)(body_ptr);
                *h.body_ptr.get() = std::ptr::null_mut();
            }
        }
        h.state.store(STATE_CANCELLED, Ordering::Release);
        // Wake JoinHandle so it returns JoinError::Cancelled.
        // SAFETY: state=CANCELLED — exclusive join_waker access.
        let waker = unsafe { (*h.join_waker.get()).take() };
        if let Some(w) = waker {
            w.wake();
        }
        // Arc refcount decremented when `self` drops.
    }
}

// ── JoinHandle ────────────────────────────────────────────────────────────────

/// Future returned from `spawn()`. Resolves when the spawned task completes.
pub struct JoinHandle<T> {
    // Shared task descriptor; the output travels type-erased through it.
    pub(crate) header: Arc<TaskHeader>,
    // Ties the handle to the future's output type T without storing a T.
    _marker: std::marker::PhantomData<T>,
}

263impl<T: Send + 'static> JoinHandle<T> {
264 /// Request cancellation. If the task hasn't started or is idle, it will be
265 /// dropped by the executor on its next scheduling pass.
266 pub fn abort(&self) {
267 // Try to flip IDLE → CANCELLED.
268 let _ = self.header.state.compare_exchange(
269 STATE_IDLE,
270 STATE_CANCELLED,
271 Ordering::AcqRel,
272 Ordering::Relaxed,
273 );
274 // Try to flip SCHEDULED → CANCELLED.
275 let _ = self.header.state.compare_exchange(
276 STATE_SCHEDULED,
277 STATE_CANCELLED,
278 Ordering::AcqRel,
279 Ordering::Relaxed,
280 );
281 }
282}
283
284impl<T: Send + 'static> Future for JoinHandle<T> {
285 type Output = Result<T, JoinError>;
286
287 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
288 let state = self.header.state.load(Ordering::Acquire);
289
290 match state {
291 STATE_COMPLETED => {
292 // Take the output the task wrote into the header.
293 // SAFETY: state=COMPLETED — the executor will not write output again.
294 // Single-threaded: no concurrent reads from another JoinHandle.
295 let boxed = unsafe { (*self.header.output.get()).take() };
296 match boxed {
297 Some(any_val) => match any_val.downcast::<T>() {
298 Ok(val) => Poll::Ready(Ok(*val)),
299 Err(_) => Poll::Ready(Err(JoinError::Cancelled)), // type mismatch (bug)
300 },
301 None => Poll::Ready(Err(JoinError::Cancelled)), // already taken
302 }
303 }
304 STATE_CANCELLED => Poll::Ready(Err(JoinError::Cancelled)),
305 _ => {
306 // Task still in flight — register our waker.
307 // SAFETY: state is IDLE/SCHEDULED/RUNNING (not COMPLETED/CANCELLED).
308 // The executor will write join_waker only after observing COMPLETED/CANCELLED,
309 // which has not happened yet. Single-threaded: no concurrent poll.
310 unsafe {
311 *self.header.join_waker.get() = Some(cx.waker().clone());
312 }
313 Poll::Pending
314 }
315 }
316 }
317}
318
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    // A freshly spawned task must start queued for its first poll.
    #[test]
    fn task_new_initial_state() {
        let (task, _jh) = Task::new(async { 42u32 });
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
    }

    // Display strings are part of the public error contract.
    #[test]
    fn join_error_display() {
        assert_eq!(JoinError::Cancelled.to_string(), "task was cancelled");
        assert!(JoinError::Panic(Box::new("x"))
            .to_string()
            .contains("panicked"));
    }

    // abort() must flip a parked (IDLE) task to CANCELLED.
    #[test]
    fn abort_from_idle_sets_cancelled() {
        let (task, jh) = Task::new(async { 1u32 });
        // Simulate the executor having polled once and parked the task.
        task.header.state.store(STATE_IDLE, Ordering::Release);
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    // cancel() must run the future's destructor immediately, not defer it
    // to JoinHandle drop.
    #[test]
    fn cancel_drops_future() {
        let dropped = Arc::new(AtomicBool::new(false));
        let d = dropped.clone();

        // Never-ready future that records when it is dropped.
        struct Bomb(Arc<AtomicBool>);
        impl Drop for Bomb {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }
        impl Future for Bomb {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<()> {
                Poll::Pending
            }
        }

        let (task, _jh) = Task::new(Bomb(d));
        task.cancel();
        assert!(
            dropped.load(Ordering::SeqCst),
            "future must be dropped on cancel"
        );
    }
}