Skip to main content

nexus_async_rt/
task.rs

1//! Task storage: header + future/output union in a contiguous allocation.
2//!
3//! Each task is a `Task<F>` struct. The raw pointer to the allocation
4//! IS the task handle — no index layer, no separate metadata store.
5//!
6//! The waker holds the raw pointer directly. `wake()` sets `is_queued`
7//! and pushes the pointer to the ready queue. Zero allocations.
8//!
9//! Tasks can be allocated via Box (default) or slab (power user).
10//! The `free_fn` in the header knows how to deallocate regardless
11//! of which allocator was used.
12//!
13//! ## Union storage
14//!
15//! The slot at `storage_offset` holds either `F` (the future) or `T` (the output),
16//! never both. While running, `F` is live. When the future completes,
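//! `poll_join` drops `F` in place and writes `T` to the same bytes.
//! `drop_fn` is overwritten from `drop_future_in_union::<F>` to
//! `drop_output::<T>` (passing briefly through a no-op while `F` is dropped,
//! so a panicking `F::drop` cannot cause a double drop), so subsequent cleanup
//! targets the correct type.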
20
21use std::cell::UnsafeCell;
22use std::future::Future;
23use std::marker::PhantomData;
24use std::pin::Pin;
25use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU16, Ordering};
26use std::task::{Context, Poll, Waker};
27
28// =============================================================================
29// Task flags
30// =============================================================================
31
/// Flag bit: a `JoinHandle` for this task is still alive.
const HAS_JOIN: u8 = 1 << 0;
/// Flag bit: the `JoinHandle` has already moved the output out via `poll`.
const OUTPUT_TAKEN: u8 = 1 << 1;
/// Flag bit: `JoinHandle::abort()` was called while the task was still running.
const ABORTED: u8 = 1 << 2;
38
39// =============================================================================
40// Task layout
41// =============================================================================
42
/// Byte size of the fixed task header: everything in `Task<S>` that precedes
/// the `storage` field. Eight 8-byte words on 64-bit targets.
pub const TASK_HEADER_SIZE: usize = 8 * 8;
46
/// Task header + storage in a contiguous allocation. `repr(C)` for
/// deterministic layout.
///
/// `S` is the storage type — either just `F` (fire-and-forget) or a union
/// of `F` and `T` (joinable). The header is always 64 bytes regardless of `S`.
///
/// The free accessor functions in this module read header fields through a
/// raw base pointer at the hard-coded offsets listed below. Any change to field
/// order or field types must be mirrored there; the `task_layout_offsets` test
/// and the static size assertion guard against drift.
///
/// Layout (64-bit):
/// ```text
/// offset  0: poll_fn       (8B, fn pointer — polls the future)
/// offset  8: drop_fn       (8B, fn pointer — drops F or T in place)
/// offset 16: free_fn       (8B, fn pointer — deallocates the task storage)
/// offset 24: is_queued      (1B, AtomicBool — cross-thread wakers CAS this)
/// offset 25: is_completed   (1B, AtomicBool — cross-thread reads with Acquire)
/// offset 26: ref_count      (2B, AtomicU16 — number of live references)
/// offset 28: tracker_key    (4B, u32 — index in Executor::all_tasks slab)
/// offset 32: cross_next     (8B, AtomicPtr — intrusive cross-thread wake queue)
/// offset 40: join_waker     (16B, UnsafeCell<Option<Waker>>)
/// offset 56: storage_offset (2B, u16 — byte offset to storage field)
/// offset 58: flags          (1B, Cell<u8> — HAS_JOIN | OUTPUT_TAKEN | ABORTED)
/// offset 59: _pad           (5B)
/// offset 64: storage        (S bytes — future F or union { F, T })
/// ```
///
/// `storage` sits at offset 64 whenever `align_of::<S>() <= 64`. A stricter
/// alignment rounds it up further, which is why the real offset is recorded in
/// `storage_offset` instead of being assumed.
#[repr(C)]
pub(crate) struct Task<S> {
    /// Polls the future. Receives the task base pointer.
    poll_fn: unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>,
    /// Drops the value at `storage_offset` (future F or output T). Receives base pointer.
    drop_fn: unsafe fn(*mut u8),
    /// Deallocates the task storage.
    free_fn: unsafe fn(*mut u8),
    /// Set while the task sits in a ready queue. Cross-thread wakers flip it
    /// from `false` to `true` with a CAS (see `try_set_queued`).
    is_queued: AtomicBool,
    /// Set when the future completes or is aborted.
    is_completed: AtomicBool,
    /// Number of live references (executor + waker clones + JoinHandle).
    ref_count: AtomicU16,
    /// Index into the Executor's `all_tasks` slab.
    tracker_key: u32,
    /// Intrusive next pointer for the cross-thread wake queue.
    cross_next: AtomicPtr<u8>,
    /// Waker for the task awaiting this JoinHandle.
    join_waker: UnsafeCell<Option<Waker>>,
    /// Byte offset from task base to the storage field.
    /// Set at construction from `offset_of!(Task<S>, storage)`.
    storage_offset: u16,
    /// Packed flags: HAS_JOIN | OUTPUT_TAKEN | ABORTED.
    /// Single-threaded — no atomics needed.
    flags: std::cell::Cell<u8>,
    /// Padding to reach 64 bytes.
    _pad: [u8; 5],
    /// The future `F` (fire-and-forget) or `FutureOrOutput<F, T>` (joinable).
    storage: S,
}
98
/// Union storage for joinable tasks. It holds the future `F` while the task
/// runs and the output `T` after it completes, in the same bytes.
///
/// `repr(C)` places every field at offset 0, so the union's address can be cast
/// directly to `*mut F` or `*mut T`. Its size is the larger of the two (rounded
/// to the stricter alignment), so one allocation fits both. Union fields with
/// destructors must be `ManuallyDrop`; the task drops whichever variant is live
/// through its `drop_fn` trampoline.
#[repr(C)]
pub(crate) union FutureOrOutput<F, T> {
    pub(crate) future: std::mem::ManuallyDrop<F>,
    pub(crate) output: std::mem::ManuallyDrop<T>,
}
106
107// Static assertion: header layout matches TASK_HEADER_SIZE.
108const _: () = {
109    assert!(std::mem::size_of::<Task<()>>() == TASK_HEADER_SIZE);
110};
111
112impl<F: Future<Output = ()> + 'static> Task<F> {
113    /// Construct a fire-and-forget task (no JoinHandle) with Box-based free.
114    ///
115    /// Used internally for tests and low-level task construction.
116    /// `ref_count = 1` (executor only), `HAS_JOIN` not set.
117    ///
118    /// # Why `Output = ()` is required
119    ///
120    /// This uses `poll_join::<F>` which writes T at offset 64 after dropping F.
121    /// The storage is `F` (not `FutureOrOutput<F, T>`), so it's only sized for F.
122    /// With `T = ()` (ZST), the write is zero-size and the `drop_fn` overwrite
123    /// to `drop_output::<()>` is a no-op. Relaxing this bound to non-ZST T
124    /// would write T into storage not sized for it — UB.
125    #[cfg(test)]
126    #[inline]
127    pub(crate) fn new_boxed(future: F, tracker_key: u32) -> Self {
128        Self {
129            poll_fn: poll_join::<F>,
130            drop_fn: drop_future::<F>,
131            free_fn: box_free::<F>,
132            is_queued: AtomicBool::new(false),
133            is_completed: AtomicBool::new(false),
134            ref_count: AtomicU16::new(1),
135            tracker_key,
136            cross_next: AtomicPtr::new(std::ptr::null_mut()),
137            join_waker: UnsafeCell::new(None),
138            flags: std::cell::Cell::new(0),
139            storage_offset: std::mem::offset_of!(Task<F>, storage) as u16,
140            _pad: [0; 5],
141            storage: future,
142        }
143    }
144}
145
146/// Allocate a joinable Box task and return the raw pointer.
147///
148/// The task has `ref_count = 2` (executor + JoinHandle) and `HAS_JOIN` set.
149/// The allocation is sized for `max(size_of::<F>(), size_of::<T>())` via
150/// the `FutureOrOutput<F, T>` union.
151pub(crate) fn box_spawn_joinable<F>(future: F, tracker_key: u32) -> *mut u8
152where
153    F: Future + 'static,
154    F::Output: 'static,
155{
156    type Storage<F> = FutureOrOutput<F, <F as Future>::Output>;
157
158    let task: Task<Storage<F>> = Task {
159        poll_fn: poll_join::<F>,
160        drop_fn: drop_future_in_union::<F>,
161        free_fn: box_free::<Storage<F>>,
162        is_queued: AtomicBool::new(false),
163        is_completed: AtomicBool::new(false),
164        ref_count: AtomicU16::new(2), // executor + JoinHandle
165        tracker_key,
166        cross_next: AtomicPtr::new(std::ptr::null_mut()),
167        join_waker: UnsafeCell::new(None),
168        flags: std::cell::Cell::new(HAS_JOIN),
169        storage_offset: std::mem::offset_of!(Task<Storage<F>>, storage) as u16,
170        _pad: [0; 5],
171        storage: FutureOrOutput {
172            future: std::mem::ManuallyDrop::new(future),
173        },
174    };
175    Box::into_raw(Box::new(task)) as *mut u8
176}
177
178/// Construct a joinable task for slab allocation.
179///
180/// Returns the task struct to be copied into a slab slot. Uses the
181/// `FutureOrOutput<F, T>` union so the allocation fits both.
182pub(crate) fn new_joinable_slab<F>(
183    future: F,
184    tracker_key: u32,
185    free_fn: unsafe fn(*mut u8),
186) -> Task<FutureOrOutput<F, F::Output>>
187where
188    F: Future + 'static,
189    F::Output: 'static,
190{
191    type Storage<F> = FutureOrOutput<F, <F as Future>::Output>;
192
193    Task {
194        poll_fn: poll_join::<F>,
195        drop_fn: drop_future_in_union::<F>,
196        free_fn,
197        is_queued: AtomicBool::new(false),
198        is_completed: AtomicBool::new(false),
199        ref_count: AtomicU16::new(2), // executor + JoinHandle
200        tracker_key,
201        cross_next: AtomicPtr::new(std::ptr::null_mut()),
202        join_waker: UnsafeCell::new(None),
203        flags: std::cell::Cell::new(HAS_JOIN),
204        storage_offset: std::mem::offset_of!(Task<Storage<F>>, storage) as u16,
205        _pad: [0; 5],
206        storage: FutureOrOutput {
207            future: std::mem::ManuallyDrop::new(future),
208        },
209    }
210}
211
212// =============================================================================
213// Task handle — raw pointer operations
214// =============================================================================
215
/// Opaque task identifier: the task's base pointer, which stays fixed for as
/// long as the task is alive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct TaskId(pub(crate) *mut u8);

impl TaskId {
    /// The raw base pointer of the task this id names.
    // Not every build calls this accessor; allow it rather than cfg-gate it.
    #[allow(dead_code)]
    pub(crate) fn as_ptr(&self) -> *mut u8 {
        let TaskId(raw) = *self;
        raw
    }
}
228
229// =============================================================================
230// JoinHandle
231// =============================================================================
232
/// Handle to a spawned task. Await to get the result.
///
/// Dropping the handle detaches the task — it continues running but the
/// output is dropped when the task completes. Use [`abort()`](Self::abort)
/// to cancel the task.
///
/// `JoinHandle` is `!Send` and `!Sync` — it must stay on the executor thread.
#[must_use = "dropping a JoinHandle detaches the task — await it or call .abort()"]
pub struct JoinHandle<T> {
    /// Base pointer of the task. The handle owns one task reference, which
    /// `Drop` releases via `ref_dec`.
    ptr: *mut u8,
    /// Records the output type `T` that `poll` reads out of task storage;
    /// no `T` is stored in the handle itself.
    _marker: PhantomData<T>,
    _not_send: PhantomData<*const ()>, // !Send + !Sync
}
246
247impl<T: 'static> Future for JoinHandle<T> {
248    type Output = T;
249
250    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
251        let ptr = self.ptr;
252
253        // SAFETY: ptr is valid — JoinHandle holds a ref (refcount >= 1).
254        if unsafe { is_completed(ptr) } {
255            let flags = unsafe { task_flags(ptr) };
256            assert!(
257                flags & ABORTED == 0,
258                "polled JoinHandle after task was aborted"
259            );
260            // SAFETY: Task completed, so poll_join already transitioned the union
261            // from F to T. The output is live at storage_offset. ptr::read moves
262            // it out (bitwise copy). OUTPUT_TAKEN prevents double-read.
263            let output_ptr = unsafe { ptr.add(storage_offset(ptr)) };
264            let value = unsafe { std::ptr::read(output_ptr.cast::<T>()) };
265            unsafe { set_flag(ptr, OUTPUT_TAKEN) };
266            Poll::Ready(value)
267        } else {
268            // SAFETY: Task still running, single-threaded — safe to write waker.
269            unsafe { set_join_waker(ptr, cx.waker().clone()) };
270            Poll::Pending
271        }
272    }
273}
274
275impl<T> JoinHandle<T> {
276    pub(crate) fn new(ptr: *mut u8) -> Self {
277        Self {
278            ptr,
279            _marker: PhantomData,
280            _not_send: PhantomData,
281        }
282    }
283
284    /// Returns `true` if the task has completed (output is ready).
285    pub fn is_finished(&self) -> bool {
286        unsafe { is_completed(self.ptr) }
287    }
288
289    /// Abort the task and consume the handle.
290    ///
291    /// The future is dropped on the next poll cycle. Consumes the handle
292    /// so it cannot be awaited after abort — this is enforced at the type
293    /// level rather than via a runtime panic.
294    ///
295    /// Returns `true` if the task was still running, `false` if it had
296    /// already completed (output is dropped by `JoinHandle::drop`).
297    #[must_use = "returns whether the task was still running"]
298    pub fn abort(self) -> bool {
299        let ptr = self.ptr;
300        let was_running = !unsafe { is_completed(ptr) };
301        if was_running {
302            unsafe { set_flag(ptr, ABORTED) };
303        }
304        // self is consumed — Drop runs, which clears HAS_JOIN,
305        // takes the join waker, and decrements refcount.
306        was_running
307    }
308}
309
310impl<T> Drop for JoinHandle<T> {
311    fn drop(&mut self) {
312        let ptr = self.ptr;
313        // SAFETY: ptr is valid — JoinHandle holds a ref (refcount >= 1).
314        // All accessor calls below are single-threaded and target valid
315        // header fields at known offsets.
316        let flags = unsafe { task_flags(ptr) };
317
318        if unsafe { is_completed(ptr) } && (flags & OUTPUT_TAKEN == 0) && (flags & ABORTED == 0) {
319            // Task completed but output was never read — drop it.
320            // SAFETY: poll_join overwrote drop_fn to drop_output::<T>,
321            // so this drops the output T (not the future F).
322            unsafe { drop_task_future(ptr) };
323        }
324
325        // Clear HAS_JOIN so complete_task knows nobody is waiting.
326        unsafe { clear_flag(ptr, HAS_JOIN) };
327
328        // If we previously polled to Pending, a cloned waker is stored in the
329        // task. Clear it so the parent task's refcount isn't kept alive until
330        // the child completes. take() returns None if no waker was stored.
331        let _ = unsafe { take_join_waker(ptr) };
332
333        // Release our reference. If refcount hits 0, the task is complete and
334        // all other refs (executor, wakers) are gone — defer the free.
335        let should_free = unsafe { ref_dec(ptr) };
336        if should_free {
337            // SAFETY: refcount is 0, task is completed. Can't free directly
338            // because we may be outside the poll cycle. defer_free pushes to
339            // TLS deferred list (or leaks if TLS unavailable — Executor::drop
340            // catches those).
341            unsafe { defer_free_slot(ptr) };
342        }
343    }
344}
345
/// Hand a task whose refcount has reached 0 to `crate::waker::defer_free`.
///
/// NOTE(review): the actual freeing policy lives in `crate::waker::defer_free`,
/// which is not visible here. Per the comment in `JoinHandle::drop`, it pushes
/// the task onto a thread-local deferred-free list, or leaks it when TLS is
/// unavailable (with `Executor::drop` reclaiming those). Whether it ever frees
/// immediately when called outside a poll cycle is unconfirmed — check
/// `defer_free` before relying on either behavior.
///
/// # Safety
///
/// `ptr` must point to a completed task with ref_count 0.
unsafe fn defer_free_slot(ptr: *mut u8) {
    unsafe { crate::waker::defer_free(ptr) };
}
354
355// =============================================================================
356// Task header accessor functions
357// =============================================================================
358
/// Returns the executor-side `tracker_key` stored in the task header.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn tracker_key(ptr: *mut u8) -> u32 {
    // SAFETY: repr(C) Task keeps the u32 `tracker_key` at byte offset 28
    // (4-byte aligned), and the caller guarantees `ptr` is a live task base.
    unsafe { ptr.add(28).cast::<u32>().read() }
}
369
/// Increment the task refcount. Called on waker clone.
///
/// # Panics
///
/// Panics with "waker refcount overflow" if the count is already `u16::MAX`.
/// The counter is left unchanged in that case, so the existing references stay
/// correctly accounted for even if the panic is caught.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn ref_inc(ptr: *mut u8) {
    // SAFETY: ref_count is AtomicU16 at offset 26 in repr(C) Task.
    let rc = unsafe { &*ptr.add(26).cast::<AtomicU16>() };
    // Relaxed suffices for an increment: the caller already holds a reference,
    // so the task cannot be freed concurrently. `checked_add` refuses to wrap.
    // A plain `fetch_add` would set the count to 0 before any check could run,
    // and a later clone+drop would then free a task that still has live
    // references.
    let bumped = rc.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_add(1));
    assert!(bumped.is_ok(), "waker refcount overflow");
}
382
/// Drops one reference. Returns `true` when this call released the last one,
/// meaning the task storage may now be freed.
///
/// # Safety
///
/// `ptr` must point to a live (or completed) `Task<F>`.
#[inline]
pub(crate) unsafe fn ref_dec(ptr: *mut u8) -> bool {
    // SAFETY: repr(C) Task stores the AtomicU16 refcount at offset 26.
    let counter = unsafe { &*ptr.add(26).cast::<AtomicU16>() };
    // AcqRel: the final decrement must observe every other owner's writes
    // before the task is torn down.
    let before = counter.fetch_sub(1, Ordering::AcqRel);
    debug_assert!(before > 0, "waker refcount underflow");
    matches!(before, 1)
}
396
/// Returns the current reference count (a Relaxed snapshot).
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
// Used by this file's tests; may have no non-test callers.
#[allow(dead_code)]
#[inline]
pub(crate) unsafe fn ref_count(ptr: *mut u8) -> u16 {
    // SAFETY: repr(C) Task stores the AtomicU16 refcount at offset 26.
    let counter = unsafe { &*ptr.add(26).cast::<AtomicU16>() };
    counter.load(Ordering::Relaxed)
}
408
/// Marks the task as completed.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn set_completed(ptr: *mut u8) {
    // SAFETY: repr(C) Task stores the AtomicBool `is_completed` at offset 25.
    let completed = unsafe { &*ptr.add(25).cast::<AtomicBool>() };
    // Release pairs with the Acquire load in `is_completed`.
    completed.store(true, Ordering::Release);
}
419
/// Reports whether the task has been marked completed.
///
/// # Safety
///
/// `ptr` must point to a (possibly completed) `Task<F>`.
#[inline]
pub(crate) unsafe fn is_completed(ptr: *mut u8) -> bool {
    // SAFETY: repr(C) Task stores the AtomicBool `is_completed` at offset 25.
    let completed = unsafe { &*ptr.add(25).cast::<AtomicBool>() };
    // Acquire pairs with the Release store in `set_completed`.
    completed.load(Ordering::Acquire)
}
430
/// Returns a pointer to the task's intrusive `cross_next` link.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn cross_next(ptr: *mut u8) -> *const AtomicPtr<u8> {
    // SAFETY: offset 32 lies inside the 64-byte header, where repr(C) Task
    // stores `cross_next`.
    let field = unsafe { ptr.add(32) };
    field.cast::<AtomicPtr<u8>>().cast_const()
}
441
/// Reports whether the task is currently marked as queued.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn is_queued(ptr: *mut u8) -> bool {
    // SAFETY: repr(C) Task stores the AtomicBool `is_queued` at offset 24.
    let queued = unsafe { &*ptr.add(24).cast::<AtomicBool>() };
    queued.load(Ordering::Relaxed)
}
452
/// Overwrites the task's `is_queued` flag with `queued`.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn set_queued(ptr: *mut u8, queued: bool) {
    // SAFETY: repr(C) Task stores the AtomicBool `is_queued` at offset 24.
    let flag = unsafe { &*ptr.add(24).cast::<AtomicBool>() };
    flag.store(queued, Ordering::Relaxed);
}
463
/// Atomically flips `is_queued` from `false` to `true`.
///
/// Returns `true` only for the caller that performed the flip (the task was
/// not queued before). Cross-thread wakers use this so that exactly one of
/// them enqueues the task.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn try_set_queued(ptr: *mut u8) -> bool {
    // SAFETY: repr(C) Task stores the AtomicBool `is_queued` at offset 24.
    let flag = unsafe { &*ptr.add(24).cast::<AtomicBool>() };
    matches!(
        flag.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed),
        Ok(_)
    )
}
478
/// Returns the byte distance from the task base to its storage field.
///
/// # Safety
///
/// `ptr` must point to a live `Task<S>`.
#[inline]
pub(crate) unsafe fn storage_offset(ptr: *mut u8) -> usize {
    // SAFETY: repr(C) Task stores the u16 `storage_offset` at offset 56.
    let raw: u16 = unsafe { ptr.add(56).cast::<u16>().read() };
    usize::from(raw)
}
489
/// Returns the packed `HAS_JOIN | OUTPUT_TAKEN | ABORTED` flag byte.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
unsafe fn task_flags(ptr: *mut u8) -> u8 {
    // SAFETY: repr(C) Task stores `flags: Cell<u8>` at offset 58.
    let cell = unsafe { &*ptr.add(58).cast::<std::cell::Cell<u8>>() };
    cell.get()
}
500
/// ORs `flag` into the task's flag byte.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
unsafe fn set_flag(ptr: *mut u8, flag: u8) {
    // SAFETY: repr(C) Task stores `flags: Cell<u8>` at offset 58.
    let flags = unsafe { &*ptr.add(58).cast::<std::cell::Cell<u8>>() };
    let updated = flags.get() | flag;
    flags.set(updated);
}
511
/// Clears the bits of `flag` in the task's flag byte.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
unsafe fn clear_flag(ptr: *mut u8, flag: u8) {
    // SAFETY: repr(C) Task stores `flags: Cell<u8>` at offset 58.
    let flags = unsafe { &*ptr.add(58).cast::<std::cell::Cell<u8>>() };
    let updated = flags.get() & !flag;
    flags.set(updated);
}
522
523/// Check if HAS_JOIN flag is set.
524///
525/// # Safety
526///
527/// `ptr` must point to a live `Task<F>`.
528#[inline]
529pub(crate) unsafe fn has_join(ptr: *mut u8) -> bool {
530    (unsafe { task_flags(ptr) }) & HAS_JOIN != 0
531}
532
533/// Check if ABORTED flag is set.
534///
535/// # Safety
536///
537/// `ptr` must point to a live `Task<F>`.
538#[inline]
539pub(crate) unsafe fn is_aborted(ptr: *mut u8) -> bool {
540    (unsafe { task_flags(ptr) }) & ABORTED != 0
541}
542
/// Stores `waker` as the one to notify when the task completes. Any
/// previously stored waker is dropped.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
unsafe fn set_join_waker(ptr: *mut u8, waker: Waker) {
    // SAFETY: repr(C) Task stores `join_waker: UnsafeCell<Option<Waker>>` at
    // offset 40. The caller guarantees no other access is in progress.
    let slot = unsafe { &*ptr.add(40).cast::<UnsafeCell<Option<Waker>>>() };
    unsafe { *slot.get() = Some(waker) };
}
554
/// Removes and returns the stored join waker, leaving the slot empty.
/// Returns `None` if no waker was stored.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
pub(crate) unsafe fn take_join_waker(ptr: *mut u8) -> Option<Waker> {
    // SAFETY: repr(C) Task stores `join_waker: UnsafeCell<Option<Waker>>` at
    // offset 40. Single-threaded, so no other borrow of the slot is live.
    let slot = unsafe { &*ptr.add(40).cast::<UnsafeCell<Option<Waker>>>() };
    let stored: &mut Option<Waker> = unsafe { &mut *slot.get() };
    stored.take()
}
565
/// Polls the task by calling the type-erased `poll_fn` stored at the start of
/// its header.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
/// The future must not have been dropped.
#[inline]
pub(crate) unsafe fn poll_task(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()> {
    type PollFn = unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>;
    // SAFETY: `poll_fn` is the first field (offset 0) of repr(C) Task, and the
    // caller guarantees `ptr` addresses a live task.
    let trampoline: PollFn = unsafe { ptr.cast::<PollFn>().read() };
    // The trampoline gets the base pointer and locates storage via storage_offset.
    unsafe { trampoline(ptr, cx) }
}
580
/// Drops whatever currently lives in the task's storage (future or output) by
/// calling the type-erased `drop_fn`.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Must only be called once.
#[inline]
pub(crate) unsafe fn drop_task_future(ptr: *mut u8) {
    // SAFETY: repr(C) Task stores `drop_fn` at offset 8.
    let dropper = unsafe { ptr.add(8).cast::<unsafe fn(*mut u8)>().read() };
    // The trampoline receives the base pointer and reads storage_offset itself.
    unsafe { dropper(ptr) }
}
593
/// Releases the task's allocation by calling the type-erased `free_fn`.
///
/// # Safety
///
/// `ptr` must point to a `Task<F>` whose future has already been dropped.
/// Must only be called once (after refcount reaches 0).
#[inline]
pub(crate) unsafe fn free_task(ptr: *mut u8) {
    // SAFETY: repr(C) Task stores `free_fn` at offset 16.
    let release = unsafe { ptr.add(16).cast::<unsafe fn(*mut u8)>().read() };
    unsafe { release(ptr) }
}
606
607// =============================================================================
608// Type-erased vtable functions
609// =============================================================================
610
/// Poll trampoline installed as `poll_fn` for every task built in this file.
///
/// If the task's `ABORTED` flag is set, it returns `Poll::Ready(())` right away
/// without polling or dropping the future. Dropping is left to a later
/// `drop_fn` call (presumably made by the executor — not visible here).
///
/// Otherwise it polls `F` in place. On completion it drops F, writes the
/// output T into the same location, and overwrites drop_fn to target T
/// instead of F.
///
/// # Safety
///
/// `ptr` must point to a live task whose storage at `storage_offset` holds a
/// live `F` and is large enough, and suitably aligned, for `F::Output`. That
/// means a `Task<FutureOrOutput<F, F::Output>>`, or a `Task<F>` with
/// `F::Output = ()`. The future must not have been dropped.
unsafe fn poll_join<F: Future>(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()>
where
    F::Output: 'static,
{
    // Check if aborted
    if unsafe { is_aborted(ptr) } {
        return Poll::Ready(());
    }

    let future_ptr = unsafe { ptr.add(storage_offset(ptr)) };
    // Pinning in place is sound because a task's base pointer stays fixed for
    // its whole lifetime (see `TaskId`).
    let future = unsafe { Pin::new_unchecked(&mut *future_ptr.cast::<F>()) };
    match future.poll(cx) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(value) => {
            // drop_fn lives at offset 8 in repr(C) Task.
            let drop_fn_slot = unsafe { ptr.add(8).cast::<unsafe fn(*mut u8)>() };
            // 1. Overwrite drop_fn to no-op BEFORE dropping F.
            //    If F::drop() panics, this prevents double-drop —
            //    subsequent cleanup calls the no-op instead of
            //    drop_future_in_union on a partially-dropped F.
            //    The output (value) is dropped during unwind (stack-owned).
            unsafe { *drop_fn_slot = drop_noop };
            // 2. Drop the future in place (panic-safe now)
            unsafe { std::ptr::drop_in_place(future_ptr.cast::<F>()) };
            // 3. Write output T into the same location
            unsafe { std::ptr::write(future_ptr.cast::<F::Output>(), value) };
            // 4. Overwrite drop_fn: now drops T instead of F
            unsafe { *drop_fn_slot = drop_output::<F::Output> };
            Poll::Ready(())
        }
    }
}
650
/// Drop trampoline for tasks whose storage is a bare future `F`
/// (fire-and-forget test tasks built by `Task::new_boxed`).
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>` with a live future at `storage_offset`.
#[cfg(test)]
unsafe fn drop_future<F>(ptr: *mut u8) {
    // SAFETY: the caller guarantees a live F at storage_offset.
    unsafe {
        let slot = ptr.add(storage_offset(ptr)).cast::<F>();
        std::ptr::drop_in_place(slot);
    }
}
661
662/// Drop trampoline for futures stored in FutureOrOutput union.
663///
664/// # Safety
665///
666/// `ptr` must point to a `Task<FutureOrOutput<F, T>>` with a live future.
667unsafe fn drop_future_in_union<F: Future>(ptr: *mut u8) {
668    let storage_ptr = unsafe { ptr.add(storage_offset(ptr)) };
669    // The future is at the start of the union (same offset as the union itself).
670    unsafe { std::ptr::drop_in_place(storage_ptr.cast::<F>()) }
671}
672
/// Drop trampoline that does nothing.
///
/// `poll_join` parks `drop_fn` here while it drops the completed future. If
/// `F::drop()` panics, any later cleanup then cannot drop the future a second
/// time.
///
/// # Safety
///
/// Always safe to call: the pointer is ignored.
unsafe fn drop_noop(_ptr: *mut u8) {
    // Intentionally empty.
}
680
681/// Drop trampoline for output values. Receives the task base pointer.
682///
683/// Installed by `poll_join` after the future completes, replacing `drop_future`.
684///
685/// # Safety
686///
687/// `ptr` must point to a `Task` with a live `T` at `storage_offset`.
688unsafe fn drop_output<T>(ptr: *mut u8) {
689    let output_ptr = unsafe { ptr.add(storage_offset(ptr)) };
690    unsafe { std::ptr::drop_in_place(output_ptr.cast::<T>()) }
691}
692
693/// Free function for Box-allocated tasks.
694///
695/// Deallocates the memory without running destructors — the future/output
696/// was already dropped via `drop_task_future`, and the header fields
697/// are all trivial. Only the heap allocation needs to be freed.
698///
699/// # Safety
700///
701/// `ptr` must have been produced by `Box::into_raw(Box::new(Task<F>))`.
702/// The value at offset 64 must already be dropped.
703unsafe fn box_free<F>(ptr: *mut u8) {
704    // SAFETY: Layout matches what Box::new(Task<F>) allocated.
705    let layout = std::alloc::Layout::new::<Task<F>>();
706    unsafe { std::alloc::dealloc(ptr, layout) }
707}
708
// Joinable tasks are constructed only through `box_spawn_joinable` (Box) and
// `new_joinable_slab` (slab). The former `new_joinable_boxed` helper was removed
// because of its bad API.
711
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn task_header_size() {
        assert_eq!(TASK_HEADER_SIZE, 64);
        assert_eq!(std::mem::size_of::<Task<()>>(), 64);
    }

    #[test]
    fn task_layout_offsets() {
        assert_eq!(std::mem::offset_of!(Task<()>, poll_fn), 0);
        assert_eq!(std::mem::offset_of!(Task<()>, drop_fn), 8);
        assert_eq!(std::mem::offset_of!(Task<()>, free_fn), 16);
        assert_eq!(std::mem::offset_of!(Task<()>, is_queued), 24);
        assert_eq!(std::mem::offset_of!(Task<()>, is_completed), 25);
        assert_eq!(std::mem::offset_of!(Task<()>, ref_count), 26);
        assert_eq!(std::mem::offset_of!(Task<()>, tracker_key), 28);
        assert_eq!(std::mem::offset_of!(Task<()>, cross_next), 32);
        assert_eq!(std::mem::offset_of!(Task<()>, join_waker), 40);
        assert_eq!(std::mem::offset_of!(Task<()>, storage_offset), 56);
        assert_eq!(std::mem::offset_of!(Task<()>, flags), 58);
        assert_eq!(std::mem::offset_of!(Task<()>, _pad), 59);
        assert_eq!(std::mem::offset_of!(Task<()>, storage), 64);
    }

    #[test]
    fn task_size_with_future() {
        // The field only exists to give the future a size; it is never read.
        #[allow(dead_code)]
        struct SmallFuture([u8; 24]);
        impl Future for SmallFuture {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }

        // 64 byte header + 24 byte future = 88 bytes
        assert_eq!(
            std::mem::size_of::<Task<SmallFuture>>(),
            TASK_HEADER_SIZE + 24
        );
    }

    #[test]
    fn queued_flag_via_pointer() {
        struct Noop;
        impl Future for Noop {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }

        let task = Box::new(Task::new_boxed(Noop, 0));
        let ptr = Box::into_raw(task) as *mut u8;

        unsafe {
            assert!(!is_queued(ptr));
            set_queued(ptr, true);
            assert!(is_queued(ptr));
            set_queued(ptr, false);
            assert!(!is_queued(ptr));

            // Drop future, then free storage (matches executor lifecycle).
            drop_task_future(ptr);
            free_task(ptr);
        }
    }

    #[test]
    fn box_free_works() {
        struct Noop;
        impl Future for Noop {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }

        let task = Box::new(Task::new_boxed(Noop, 42));
        let ptr = Box::into_raw(task) as *mut u8;

        unsafe {
            assert_eq!(tracker_key(ptr), 42);
            assert_eq!(ref_count(ptr), 1);
            // Drop future, then free storage.
            drop_task_future(ptr);
            free_task(ptr);
        }
    }

    #[test]
    fn joinable_task_flags() {
        struct Noop;
        impl Future for Noop {
            type Output = u64;
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u64> {
                Poll::Ready(42)
            }
        }

        let ptr = box_spawn_joinable(Noop, 0);
        unsafe {
            assert!(has_join(ptr));
            assert!(!is_aborted(ptr));
            assert_eq!(ref_count(ptr), 2); // executor + JoinHandle

            // Clean up
            drop_task_future(ptr);
            ref_dec(ptr); // JoinHandle ref
            ref_dec(ptr); // executor ref
            free_task(ptr);
        }
    }

    // =========================================================================
    // Panic safety — drop_fn transitions
    // =========================================================================

    /// Future whose Drop impl panics. Used to verify the drop_noop guard
    /// in poll_join prevents double-drop.
    struct PanickingDrop {
        drop_count: *mut u32,
    }

    impl Future for PanickingDrop {
        type Output = u64;
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u64> {
            Poll::Ready(42)
        }
    }

    impl Drop for PanickingDrop {
        fn drop(&mut self) {
            unsafe { *self.drop_count += 1 };
            panic!("intentional drop panic");
        }
    }

    #[test]
    fn poll_join_panic_in_drop_prevents_double_drop() {
        use std::task::{RawWaker, RawWakerVTable, Waker};

        // A named static is required so the vtable's clone fn can refer to it.
        static NOOP_VTABLE: RawWakerVTable =
            RawWakerVTable::new(|p| RawWaker::new(p, &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
        let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE)) };
        let mut cx = Context::from_waker(&waker);

        let mut drop_count: u32 = 0;
        let ptr = box_spawn_joinable(
            PanickingDrop {
                drop_count: &raw mut drop_count,
            },
            0,
        );

        // poll_join completes the future, then drops F — which panics.
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
            poll_task(ptr, &mut cx)
        }));

        // The panic should have been caught.
        assert!(result.is_err(), "expected panic from PanickingDrop");
        // F was dropped exactly once (by poll_join, before the panic propagated).
        assert_eq!(drop_count, 1, "future should be dropped exactly once");

        // drop_fn should now be drop_noop — calling it must NOT double-drop F.
        unsafe { drop_task_future(ptr) };
        assert_eq!(
            drop_count, 1,
            "drop_task_future after panic must be a no-op (drop_noop)"
        );

        // Clean up: dec both refs (executor + JoinHandle), then free.
        unsafe {
            ref_dec(ptr);
            ref_dec(ptr);
            free_task(ptr);
        }
    }

    #[test]
    fn drop_fn_transitions_correctly_on_normal_completion() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::task::{RawWaker, RawWakerVTable, Waker};

        static NOOP_VTABLE: RawWakerVTable =
            RawWakerVTable::new(|p| RawWaker::new(p, &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
        let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE)) };
        let mut cx = Context::from_waker(&waker);

        // An atomic counter instead of `static mut`: no unsafe access and no
        // `static_mut_refs` lint.
        static OUTPUT_DROP_COUNT: AtomicU32 = AtomicU32::new(0);
        struct TrackedOutput;
        impl Drop for TrackedOutput {
            fn drop(&mut self) {
                OUTPUT_DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        struct ProduceTracked;
        impl Future for ProduceTracked {
            type Output = TrackedOutput;
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<TrackedOutput> {
                Poll::Ready(TrackedOutput)
            }
        }

        let ptr = box_spawn_joinable(ProduceTracked, 0);

        // Poll to completion — F dropped, T written, drop_fn → drop_output.
        let result = unsafe { poll_task(ptr, &mut cx) };
        assert!(result.is_ready());

        // drop_fn should now target T (TrackedOutput).
        OUTPUT_DROP_COUNT.store(0, Ordering::Relaxed);
        unsafe { drop_task_future(ptr) };
        assert_eq!(
            OUTPUT_DROP_COUNT.load(Ordering::Relaxed),
            1,
            "drop_fn should drop the output exactly once"
        );

        // Clean up.
        unsafe {
            ref_dec(ptr);
            ref_dec(ptr);
            free_task(ptr);
        }
    }
}