// moduvex_runtime/executor/task.rs

//! Task lifecycle types: `TaskHeader`, `Task`, `JoinHandle`.
//!
//! # Memory Model
//!
//! Two separate heap allocations per spawned future:
//!
//! 1. `Arc<TaskHeader>` — shared between executor (`Task`), all `Waker`s,
//!    and `JoinHandle`. Contains the atomic state, vtable pointer, join-waker
//!    slot, and the output slot (written on completion, read by JoinHandle).
//!
//! 2. `Box<TaskBody<F>>` (stored as `body_ptr: *mut ()` in `TaskHeader`) —
//!    owns the erased `Pin<Box<F>>` (the live future). Freed by the executor
//!    the moment the future resolves or the task is cancelled, independent of
//!    when the JoinHandle reads the output.
//!
//! Separating the output from the body lets `drop_body` free the future
//! immediately on completion while the output lives safely in the Arc until
//! `JoinHandle::poll` retrieves it.
//!
//! # Thread Safety for Multi-Threaded Executor
//!
//! `join_waker` is protected by a `Mutex` to allow safe concurrent access
//! between `JoinHandle::poll` (any worker thread) and `poll_task` / `cancel`
//! (any background worker). The double-check pattern in `JoinHandle::poll`
//! ensures the waker is never missed if a task completes concurrently.
use std::any::Any;
use std::cell::UnsafeCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};

// ── State constants ───────────────────────────────────────────────────────────

/// Polled at least once and returned `Pending`; parked until a `Waker`
/// reschedules it (set by `poll_task` on a pending poll).
pub(crate) const STATE_IDLE: u32 = 0;
/// Queued for the executor to poll (initial state set by `Task::new`).
pub(crate) const STATE_SCHEDULED: u32 = 1;
/// Being polled right now; grants the polling worker exclusive access to
/// `TaskHeader::body_ptr` and `TaskHeader::output`.
pub(crate) const STATE_RUNNING: u32 = 2;
/// Future returned `Ready`; the boxed output sits in `TaskHeader::output`.
pub(crate) const STATE_COMPLETED: u32 = 3;
/// Cancelled via `JoinHandle::abort` or `Task::cancel`.
pub(crate) const STATE_CANCELLED: u32 = 4;

// ── JoinError ─────────────────────────────────────────────────────────────────

/// Error returned by a `JoinHandle` when the task does not complete normally.
#[derive(Debug)]
pub enum JoinError {
    /// Task was aborted via `JoinHandle::abort()`.
    Cancelled,
    /// Task's future panicked. Panic payload preserved.
    ///
    /// NOTE(review): nothing in this file constructs this variant — panic
    /// capture presumably happens in the executor; confirm it exists.
    Panic(Box<dyn Any + Send + 'static>),
}
54impl std::fmt::Display for JoinError {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 match self {
57 JoinError::Cancelled => write!(f, "task was cancelled"),
58 JoinError::Panic(_) => write!(f, "task panicked"),
59 }
60 }
61}
62impl std::error::Error for JoinError {}

// ── TaskVtable ────────────────────────────────────────────────────────────────

/// Type-erased function pointers for a concrete `TaskBody<F>`.
///
/// Both entries are `unsafe fn`s: callers must pass the live `Box::into_raw`
/// pointer of the matching `TaskBody<F>` allocation (see `TaskHeader::body_ptr`).
pub(crate) struct TaskVtable {
    /// Poll the future once. Returns `true` when the future completed (Ready).
    /// On Ready the output has been written to `TaskHeader.output`.
    pub poll: unsafe fn(body: *mut (), header: &TaskHeader, cx: &mut Context<'_>) -> bool,

    /// Free the `Box<TaskBody<F>>` allocation (future only; output lives in header).
    pub drop_body: unsafe fn(body: *mut ()),
}

// ── TaskBody ──────────────────────────────────────────────────────────────────

/// Heap allocation that owns the erased future.
///
/// Created in `Task::new` via `Box::into_raw` and freed exclusively through
/// `TaskVtable::drop_body`.
struct TaskBody<F> {
    // `Pin<Box<F>>` keeps the future pinned while the outer `TaskBody` box
    // itself travels through raw-pointer casts as `*mut ()`.
    future: Pin<Box<F>>,
}

// ── Vtable implementations ────────────────────────────────────────────────────

/// Polls the erased future once; on `Ready`, boxes the output into
/// `header.output` and returns `true`.
///
/// # Safety
/// - `body_ptr` must be the live pointer produced by `Box::into_raw` on the
///   `TaskBody<F>` this instantiation was built for.
/// - The caller must hold exclusive access (task state == RUNNING) so the
///   `&mut` borrow of the body and the write to `header.output` cannot race.
unsafe fn body_poll<F, T>(body_ptr: *mut (), header: &TaskHeader, cx: &mut Context<'_>) -> bool
where
    F: Future<Output = T>,
    T: Send + 'static,
{
    // SAFETY: `body_ptr` is `Box::into_raw(Box<TaskBody<F>>)` cast to `*mut ()`.
    let body = &mut *(body_ptr as *mut TaskBody<F>);
    match body.future.as_mut().poll(cx) {
        Poll::Ready(val) => {
            // Store the boxed (type-erased) output into the header's slot.
            // SAFETY: state=RUNNING — only this call site writes `output`.
            *header.output.get() = Some(Box::new(val) as Box<dyn Any + Send>);
            true
        }
        Poll::Pending => false,
    }
}
/// Reconstructs and drops the `Box<TaskBody<F>>`, freeing the future.
///
/// # Safety
/// `ptr` must be the live `Box::into_raw` pointer for a `TaskBody<F>` and
/// must not be used again afterwards (callers null out `body_ptr`).
unsafe fn body_drop<F>(ptr: *mut ()) {
    // SAFETY: `ptr` is `Box::into_raw(Box<TaskBody<F>>)`.
    drop(Box::from_raw(ptr as *mut TaskBody<F>));
}
/// Returns the `'static` vtable for the concrete `F`/`T` pair.
///
/// The struct expression contains only function pointers, so it is
/// const-promoted into static memory — one instance per monomorphization,
/// no runtime allocation.
fn make_vtable<F, T>() -> &'static TaskVtable
where
    F: Future<Output = T>,
    T: Send + 'static,
{
    &TaskVtable {
        poll: body_poll::<F, T>,
        drop_body: body_drop::<F>,
    }
}

// ── TaskHeader ────────────────────────────────────────────────────────────────

/// Shared, reference-counted task descriptor.
///
/// Lives inside an `Arc<TaskHeader>`. Every `Waker`, the executor's `Task`,
/// and the user's `JoinHandle` all hold a clone of this Arc.
pub(crate) struct TaskHeader {
    /// Lifecycle state — see `STATE_*` constants.
    pub state: AtomicU32,

    /// Type-erased vtable for the concrete `F` / `T` types.
    /// (`'static` thanks to const promotion in `make_vtable`.)
    pub vtable: &'static TaskVtable,

    /// Waker registered by `JoinHandle::poll`. Called when the task finishes.
    ///
    /// Protected by a `Mutex` to allow safe concurrent access between
    /// `JoinHandle::poll` (on any worker thread) and `poll_task`/`cancel`
    /// (on any background worker). The double-check pattern in `JoinHandle::poll`
    /// ensures no missed wake-ups.
    pub join_waker: Mutex<Option<Waker>>,

    /// Raw pointer to the `Box<TaskBody<F>>` allocation.
    ///
    /// # Safety invariant
    /// Non-null from `Task::new` until `drop_body` is called by either
    /// `poll_task` (on completion) or `cancel`. Nulled immediately after.
    /// Only read/written while `state == STATE_RUNNING` or during cancellation.
    pub body_ptr: UnsafeCell<*mut ()>,

    /// Output value written by the vtable's `poll` on completion.
    ///
    /// Written before the state → COMPLETED transition (Release store).
    /// Read only after observing STATE_COMPLETED (Acquire load).
    /// The Release/Acquire pair on `state` provides the memory barrier.
    pub output: UnsafeCell<Option<Box<dyn Any + Send>>>,
}
// SAFETY: `body_ptr` and `output` are UnsafeCell fields accessed under the
// state machine's ordering guarantees:
// - `body_ptr`: only accessed while state == STATE_RUNNING (exclusive)
// - `output`: written before the STATE_COMPLETED store (Release); read after
//   a STATE_COMPLETED load (Acquire)
// `join_waker` is protected by its own Mutex, so sharing `&TaskHeader`
// across threads is sound.
unsafe impl Send for TaskHeader {}
unsafe impl Sync for TaskHeader {}

// ── Task ──────────────────────────────────────────────────────────────────────

/// Executor-owned handle to a spawned task.
///
/// Holds one strong reference to the shared header; the `JoinHandle` (and any
/// `Waker`s) hold the others.
pub(crate) struct Task {
    pub(crate) header: Arc<TaskHeader>,
}
impl Task {
    /// Allocate a new task returning the executor `Task` + user `JoinHandle<T>`.
    ///
    /// The future is boxed, pinned, and leaked into `body_ptr`; the header
    /// starts in `STATE_SCHEDULED` so the executor will poll it.
    pub(crate) fn new<F, T>(future: F) -> (Task, JoinHandle<T>)
    where
        F: Future<Output = T> + 'static,
        T: Send + 'static,
    {
        // Allocate and leak the future body (freed via vtable.drop_body).
        let body: Box<TaskBody<F>> = Box::new(TaskBody {
            future: Box::pin(future),
        });
        let body_ptr = Box::into_raw(body) as *mut ();

        let header = Arc::new(TaskHeader {
            state: AtomicU32::new(STATE_SCHEDULED),
            vtable: make_vtable::<F, T>(),
            join_waker: Mutex::new(None),
            body_ptr: UnsafeCell::new(body_ptr),
            output: UnsafeCell::new(None),
        });

        // Two strong references: one for the executor, one for the user.
        let join_arc = Arc::clone(&header);
        let task = Task { header };
        let jh = JoinHandle {
            header: join_arc,
            _marker: std::marker::PhantomData,
        };
        (task, jh)
    }

    /// Poll the task's future once. Returns `true` when the future completed.
    ///
    /// State transitions: SCHEDULED → RUNNING → IDLE (Pending) | COMPLETED (Ready)
    ///
    /// NOTE(review): the unconditional RUNNING store below also overwrites a
    /// CANCELLED state set concurrently by `JoinHandle::abort`; presumably the
    /// executor checks the state before calling `poll_task` — confirm.
    ///
    /// NOTE(review): a panicking future unwinds straight through this method,
    /// leaving state = RUNNING and the body allocation live. `JoinError::Panic`
    /// is never constructed in this file — confirm the executor catches panics.
    pub(crate) fn poll_task(&self, cx: &mut Context<'_>) -> bool {
        let h = &self.header;
        h.state.store(STATE_RUNNING, Ordering::Release);

        // SAFETY: state=RUNNING — exclusive access to body_ptr.
        let body_ptr = unsafe { *h.body_ptr.get() };
        debug_assert!(!body_ptr.is_null(), "poll_task called on freed body");

        // SAFETY: vtable matches the concrete types used in `new`.
        let completed = unsafe { (h.vtable.poll)(body_ptr, h, cx) };

        if completed {
            // Free the future body — output is now in h.output.
            // SAFETY: body_ptr valid; state=RUNNING prevents concurrent access.
            unsafe {
                (h.vtable.drop_body)(body_ptr);
                *h.body_ptr.get() = std::ptr::null_mut();
            }
            // Set COMPLETED with Release so the output write is visible to
            // any thread that observes STATE_COMPLETED with Acquire.
            h.state.store(STATE_COMPLETED, Ordering::Release);
            // Wake the JoinHandle waiter under the Mutex to prevent races
            // with JoinHandle::poll registering a waker concurrently.
            let waker = h.join_waker.lock().unwrap().take();
            if let Some(w) = waker {
                w.wake();
            }
        } else {
            h.state.store(STATE_IDLE, Ordering::Release);
        }
        completed
    }

    /// Cancel the task: drop the future body and wake the JoinHandle.
    ///
    /// Must be called at most once by the executor. Consuming `self` drops
    /// the executor's strong reference to the header on return.
    pub(crate) fn cancel(self) {
        let h = &self.header;
        // SAFETY: executor holds the Task exclusively; state = SCHEDULED or CANCELLED.
        let body_ptr = unsafe { *h.body_ptr.get() };
        if !body_ptr.is_null() {
            unsafe {
                (h.vtable.drop_body)(body_ptr);
                *h.body_ptr.get() = std::ptr::null_mut();
            }
        }
        h.state.store(STATE_CANCELLED, Ordering::Release);
        // Wake JoinHandle under the Mutex so no waker is missed.
        let waker = h.join_waker.lock().unwrap().take();
        if let Some(w) = waker {
            w.wake();
        }
        // Arc refcount decremented when `self` drops.
    }
}

// ── JoinHandle ────────────────────────────────────────────────────────────────

/// Future returned from `spawn()`. Resolves when the spawned task completes.
///
/// `T` appears only in `PhantomData`: the actual output travels type-erased
/// through `TaskHeader.output` and is downcast back in `take_output`.
pub struct JoinHandle<T> {
    pub(crate) header: Arc<TaskHeader>,
    _marker: std::marker::PhantomData<T>,
}
269impl<T: Send + 'static> JoinHandle<T> {
270 /// Request cancellation. If the task hasn't started or is idle, it will be
271 /// dropped by the executor on its next scheduling pass.
272 pub fn abort(&self) {
273 // Try to flip IDLE → CANCELLED.
274 let _ = self.header.state.compare_exchange(
275 STATE_IDLE,
276 STATE_CANCELLED,
277 Ordering::AcqRel,
278 Ordering::Relaxed,
279 );
280 // Try to flip SCHEDULED → CANCELLED.
281 let _ = self.header.state.compare_exchange(
282 STATE_SCHEDULED,
283 STATE_CANCELLED,
284 Ordering::AcqRel,
285 Ordering::Relaxed,
286 );
287 }
288}
impl<T: Send + 'static> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Fast path: check state before acquiring the waker lock.
        let state = self.header.state.load(Ordering::Acquire);

        if state == STATE_COMPLETED {
            return self.take_output();
        }
        if state == STATE_CANCELLED {
            return Poll::Ready(Err(JoinError::Cancelled));
        }

        // Task still in flight. Register waker under the Mutex to prevent a
        // race with poll_task completing the task simultaneously.
        //
        // Double-check pattern:
        // 1. Lock the waker Mutex.
        // 2. Re-read state under the lock.
        // 3. If still in-flight, store the waker.
        // 4. If completed/cancelled, return Ready immediately.
        let mut guard = self.header.join_waker.lock().unwrap();
        // Why no wake-up can be missed: `poll_task`/`cancel` store the final
        // state *first* and only then lock this Mutex to take the waker. The
        // Mutex totally orders the two critical sections:
        // - if ours runs first, the waker stored below is observed (and woken)
        //   by the completer's subsequent lock + take;
        // - if the completer's ran first, its earlier state store is visible
        //   to the re-read below (acquiring the lock synchronizes with its
        //   release), so we return Ready instead of parking.
        let state = self.header.state.load(Ordering::Acquire);
        match state {
            STATE_COMPLETED => {
                drop(guard);
                self.take_output()
            }
            STATE_CANCELLED => {
                drop(guard);
                Poll::Ready(Err(JoinError::Cancelled))
            }
            _ => {
                // Overwrite any previously registered waker — only the most
                // recent poll's waker matters, per the Future contract.
                *guard = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
impl<T: Send + 'static> JoinHandle<T> {
    /// Take the output from the header after observing STATE_COMPLETED.
    ///
    /// Returns `Err(Cancelled)` when the slot is already empty (output taken
    /// by an earlier poll) or when the downcast fails.
    fn take_output(self: Pin<&mut Self>) -> Poll<Result<T, JoinError>> {
        // SAFETY: state=COMPLETED (observed with Acquire). The worker that set
        // COMPLETED used Release ordering. The Release/Acquire pair establishes
        // happens-before: output write → COMPLETED store → our load → output read.
        let boxed = unsafe { (*self.header.output.get()).take() };
        match boxed {
            // NOTE(review): a failed downcast means the stored type does not
            // match `T` — a type-confusion bug rather than a cancellation;
            // consider a dedicated error or `unreachable!` here — confirm.
            Some(any_val) => match any_val.downcast::<T>() {
                Ok(val) => Poll::Ready(Ok(*val)),
                Err(_) => Poll::Ready(Err(JoinError::Cancelled)),
            },
            None => Poll::Ready(Err(JoinError::Cancelled)), // already taken
        }
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    #[test]
    fn task_new_initial_state() {
        let (task, _jh) = Task::new(async { 42u32 });
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
    }

    #[test]
    fn join_error_display() {
        assert_eq!(JoinError::Cancelled.to_string(), "task was cancelled");
        assert!(JoinError::Panic(Box::new("x"))
            .to_string()
            .contains("panicked"));
    }

    #[test]
    fn abort_from_idle_sets_cancelled() {
        let (task, jh) = Task::new(async { 1u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn cancel_drops_future() {
        let dropped = Arc::new(AtomicBool::new(false));
        let d = dropped.clone();

        // A future that never completes but flags its own Drop.
        struct Bomb(Arc<AtomicBool>);
        impl Drop for Bomb {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }
        impl Future for Bomb {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<()> {
                Poll::Pending
            }
        }

        let (task, _jh) = Task::new(Bomb(d));
        task.cancel();
        assert!(
            dropped.load(Ordering::SeqCst),
            "future must be dropped on cancel"
        );
    }

    // ── Additional task tests ──────────────────────────────────────────────

    #[test]
    fn join_error_panic_display() {
        let err = JoinError::Panic(Box::new("boom"));
        let s = err.to_string();
        assert!(s.contains("panic"));
    }

    #[test]
    fn join_error_cancelled_display() {
        let err = JoinError::Cancelled;
        let s = err.to_string();
        assert!(s.contains("cancel") || s.contains("Cancel"));
    }

    #[test]
    fn abort_from_scheduled_sets_cancelled() {
        let (_task, jh) = Task::new(async { 1u32 });
        // Initial state is SCHEDULED
        jh.abort();
        assert_eq!(
            jh.header.state.load(Ordering::Acquire),
            STATE_CANCELLED
        );
    }

    #[test]
    fn task_header_initial_state_is_scheduled() {
        let (task, _jh) = Task::new(async { 0u8 });
        assert_eq!(
            task.header.state.load(Ordering::Acquire),
            STATE_SCHEDULED
        );
    }

    #[test]
    fn cancel_sets_state_to_cancelled() {
        let (task, jh) = Task::new(async { 0u8 });
        task.cancel();
        // Fixed: the original test promised this check in a comment but never
        // asserted it. Read through `jh`, which still holds the header Arc
        // after `task` was consumed by `cancel`.
        assert_eq!(jh.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn abort_completed_task_has_no_effect() {
        let (task, jh) = Task::new(async { 99u32 });
        // Manually set state to COMPLETED (simulating task that already ran)
        task.header.state.store(STATE_COMPLETED, Ordering::Release);
        jh.abort(); // abort on completed task — must not panic
        // State remains COMPLETED (CAS to IDLE fails, CAS to SCHEDULED fails)
        assert_eq!(
            jh.header.state.load(Ordering::Acquire),
            STATE_COMPLETED
        );
    }

    #[test]
    fn state_constants_distinct() {
        // All STATE_* constants must be distinct values
        let states = [
            STATE_IDLE,
            STATE_SCHEDULED,
            STATE_RUNNING,
            STATE_COMPLETED,
            STATE_CANCELLED,
        ];
        let unique: std::collections::HashSet<u32> = states.iter().cloned().collect();
        assert_eq!(unique.len(), states.len());
    }

    #[test]
    fn join_error_debug_format() {
        let err = JoinError::Cancelled;
        let s = format!("{err:?}");
        assert!(!s.is_empty());
    }

    #[test]
    fn task_new_creates_join_handle_with_same_header() {
        let (task, jh) = Task::new(async { 0u32 });
        // Both task and jh share the same header Arc
        assert!(Arc::ptr_eq(&task.header, &jh.header));
    }

    #[test]
    fn abort_from_idle_state_succeeds() {
        let (task, jh) = Task::new(async { 0u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn multiple_aborts_are_idempotent() {
        let (_task, jh) = Task::new(async { 0u32 });
        // Abort multiple times — must not panic
        jh.abort();
        jh.abort();
        jh.abort();
    }
}