nexus_async_rt/task.rs
1//! Task storage: header + future/output union in a contiguous allocation.
2//!
//! Each task is a `Task<S>` struct, where `S` is either the future itself or
//! a future/output union. The raw pointer to the allocation IS the task
//! handle — no index layer, no separate metadata store.
5//!
6//! The waker holds the raw pointer directly. `wake()` sets `is_queued`
7//! and pushes the pointer to the ready queue. Zero allocations.
8//!
9//! Tasks can be allocated via Box (default) or slab (power user).
10//! The `free_fn` in the header knows how to deallocate regardless
11//! of which allocator was used.
12//!
13//! ## Union storage
14//!
15//! The slot at `storage_offset` holds either `F` (the future) or `T` (the output),
16//! never both. While running, `F` is live. When the future completes,
17//! `poll_join` drops `F` in place and writes `T` to the same bytes.
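//! `drop_fn` is switched from the future-dropping trampoline
//! (`drop_future_in_union::<F>`) to `drop_output::<T>` — passing through a
//! transient `drop_noop` so a panicking `F::drop` cannot cause a double
//! drop — so subsequent cleanup targets the correct type.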
20
21use std::cell::UnsafeCell;
22use std::future::Future;
23use std::marker::PhantomData;
24use std::pin::Pin;
25use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU16, Ordering};
26use std::task::{Context, Poll, Waker};
27
28// =============================================================================
29// Task flags
30// =============================================================================
31
/// Flag bit: a `JoinHandle` for this task is still alive.
const HAS_JOIN: u8 = 1 << 0;
/// Flag bit: the `JoinHandle` has already moved the output out via `poll`.
const OUTPUT_TAKEN: u8 = 1 << 1;
/// Flag bit: `JoinHandle::abort()` was called while the task was still running.
const ABORTED: u8 = 1 << 2;

// =============================================================================
// Task layout
// =============================================================================

/// Size in bytes of the fixed task header, i.e. every field of `Task<S>`
/// that precedes the `storage` field. The raw-pointer accessors in this
/// module depend on this exact header layout.
pub const TASK_HEADER_SIZE: usize = 64;
46
/// Task header + storage in a contiguous allocation. `repr(C)` for
/// deterministic layout: the accessor functions in this module read and
/// write header fields through hard-coded byte offsets, so fields must not
/// be reordered, resized, or retyped without updating those accessors.
///
/// `S` is the storage type — either just `F` (fire-and-forget) or a union
/// of `F` and `T` (joinable). The header fields always occupy the first
/// 64 bytes regardless of `S`. `storage` starts right after them unless `S`
/// is aligned to more than 64 bytes. For that reason, its real position is
/// recorded in `storage_offset` rather than assumed.
///
/// Layout (64-bit):
/// ```text
/// offset 0:  poll_fn        (8B, fn pointer — polls the future)
/// offset 8:  drop_fn        (8B, fn pointer — drops F or T in place)
/// offset 16: free_fn        (8B, fn pointer — deallocates the task storage)
/// offset 24: is_queued      (1B, AtomicBool — cross-thread wakers CAS this)
/// offset 25: is_completed   (1B, AtomicBool — cross-thread reads with Acquire)
/// offset 26: ref_count      (2B, AtomicU16 — number of live references)
/// offset 28: tracker_key    (4B, u32 — index in Executor::all_tasks slab)
/// offset 32: cross_next     (8B, AtomicPtr — intrusive cross-thread wake queue)
/// offset 40: join_waker     (16B, UnsafeCell<Option<Waker>>)
/// offset 56: storage_offset (2B, u16 — byte offset to storage field)
/// offset 58: flags          (1B, Cell<u8> — HAS_JOIN | OUTPUT_TAKEN | ABORTED)
/// offset 59: _pad           (5B)
/// offset 64: storage        (S bytes — future F or union { F, T })
/// ```
#[repr(C)]
pub(crate) struct Task<S> {
    /// Polls the future. Receives the task base pointer (not the storage
    /// pointer); the trampoline locates the storage via `storage_offset`.
    poll_fn: unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>,
    /// Drops whatever is currently live at `storage_offset` (future F or
    /// output T). Receives the base pointer. `poll_join` rewrites this slot
    /// when the future completes.
    drop_fn: unsafe fn(*mut u8),
    /// Deallocates the task storage. Does not drop the storage contents.
    free_fn: unsafe fn(*mut u8),
    /// True while the task is queued to run; cross-thread wakers claim it
    /// with a compare-exchange (`try_set_queued`).
    is_queued: AtomicBool,
    /// Set when the future completes or is aborted.
    /// Written with Release and read with Acquire by the accessors.
    is_completed: AtomicBool,
    /// Number of live references (executor + waker clones + JoinHandle).
    ref_count: AtomicU16,
    /// Index into the Executor's `all_tasks` slab.
    tracker_key: u32,
    /// Intrusive next pointer for the cross-thread wake queue.
    cross_next: AtomicPtr<u8>,
    /// Waker for the task awaiting this JoinHandle (stored when the
    /// JoinHandle is polled before completion).
    join_waker: UnsafeCell<Option<Waker>>,
    /// Byte offset from task base to the storage field.
    /// Set at construction from `offset_of!(Task<S>, storage)`.
    storage_offset: u16,
    /// Packed flags: HAS_JOIN | OUTPUT_TAKEN | ABORTED.
    /// Single-threaded — no atomics needed.
    flags: std::cell::Cell<u8>,
    /// Explicit padding so the header is exactly 64 bytes.
    _pad: [u8; 5],
    /// The future `F`, or the `FutureOrOutput<F, T>` union for joinable tasks.
    storage: S,
}
98
/// Union storage for joinable tasks: the same bytes hold the future `F`
/// while it runs and its output `T` after it completes.
///
/// The union is as large as the larger of `F` and `T`. Every union field
/// starts at the union's own address, so a pointer to the storage is also a
/// pointer to whichever field is live. The union never drops either field
/// itself (hence `ManuallyDrop`); the task's `drop_fn` decides which one to
/// drop.
#[repr(C)]
pub(crate) union FutureOrOutput<F, T> {
    /// Live from spawn until `poll_join` observes `Poll::Ready`.
    pub(crate) future: std::mem::ManuallyDrop<F>,
    /// Live after `poll_join` writes the output, until it is read or dropped.
    pub(crate) output: std::mem::ManuallyDrop<T>,
}
106
107// Static assertion: header layout matches TASK_HEADER_SIZE.
108const _: () = {
109 assert!(std::mem::size_of::<Task<()>>() == TASK_HEADER_SIZE);
110};
111
112impl<F: Future<Output = ()> + 'static> Task<F> {
113 /// Construct a fire-and-forget task (no JoinHandle) with Box-based free.
114 ///
115 /// Used internally for tests and low-level task construction.
116 /// `ref_count = 1` (executor only), `HAS_JOIN` not set.
117 ///
118 /// # Why `Output = ()` is required
119 ///
120 /// This uses `poll_join::<F>` which writes T at offset 64 after dropping F.
121 /// The storage is `F` (not `FutureOrOutput<F, T>`), so it's only sized for F.
122 /// With `T = ()` (ZST), the write is zero-size and the `drop_fn` overwrite
123 /// to `drop_output::<()>` is a no-op. Relaxing this bound to non-ZST T
124 /// would write T into storage not sized for it — UB.
125 #[cfg(test)]
126 #[inline]
127 pub(crate) fn new_boxed(future: F, tracker_key: u32) -> Self {
128 Self {
129 poll_fn: poll_join::<F>,
130 drop_fn: drop_future::<F>,
131 free_fn: box_free::<F>,
132 is_queued: AtomicBool::new(false),
133 is_completed: AtomicBool::new(false),
134 ref_count: AtomicU16::new(1),
135 tracker_key,
136 cross_next: AtomicPtr::new(std::ptr::null_mut()),
137 join_waker: UnsafeCell::new(None),
138 flags: std::cell::Cell::new(0),
139 storage_offset: std::mem::offset_of!(Task<F>, storage) as u16,
140 _pad: [0; 5],
141 storage: future,
142 }
143 }
144}
145
146/// Allocate a joinable Box task and return the raw pointer.
147///
148/// The task has `ref_count = 2` (executor + JoinHandle) and `HAS_JOIN` set.
149/// The allocation is sized for `max(size_of::<F>(), size_of::<T>())` via
150/// the `FutureOrOutput<F, T>` union.
151pub(crate) fn box_spawn_joinable<F>(future: F, tracker_key: u32) -> *mut u8
152where
153 F: Future + 'static,
154 F::Output: 'static,
155{
156 type Storage<F> = FutureOrOutput<F, <F as Future>::Output>;
157
158 let task: Task<Storage<F>> = Task {
159 poll_fn: poll_join::<F>,
160 drop_fn: drop_future_in_union::<F>,
161 free_fn: box_free::<Storage<F>>,
162 is_queued: AtomicBool::new(false),
163 is_completed: AtomicBool::new(false),
164 ref_count: AtomicU16::new(2), // executor + JoinHandle
165 tracker_key,
166 cross_next: AtomicPtr::new(std::ptr::null_mut()),
167 join_waker: UnsafeCell::new(None),
168 flags: std::cell::Cell::new(HAS_JOIN),
169 storage_offset: std::mem::offset_of!(Task<Storage<F>>, storage) as u16,
170 _pad: [0; 5],
171 storage: FutureOrOutput {
172 future: std::mem::ManuallyDrop::new(future),
173 },
174 };
175 Box::into_raw(Box::new(task)) as *mut u8
176}
177
178/// Construct a joinable task for slab allocation.
179///
180/// Returns the task struct to be copied into a slab slot. Uses the
181/// `FutureOrOutput<F, T>` union so the allocation fits both.
182pub(crate) fn new_joinable_slab<F>(
183 future: F,
184 tracker_key: u32,
185 free_fn: unsafe fn(*mut u8),
186) -> Task<FutureOrOutput<F, F::Output>>
187where
188 F: Future + 'static,
189 F::Output: 'static,
190{
191 type Storage<F> = FutureOrOutput<F, <F as Future>::Output>;
192
193 Task {
194 poll_fn: poll_join::<F>,
195 drop_fn: drop_future_in_union::<F>,
196 free_fn,
197 is_queued: AtomicBool::new(false),
198 is_completed: AtomicBool::new(false),
199 ref_count: AtomicU16::new(2), // executor + JoinHandle
200 tracker_key,
201 cross_next: AtomicPtr::new(std::ptr::null_mut()),
202 join_waker: UnsafeCell::new(None),
203 flags: std::cell::Cell::new(HAS_JOIN),
204 storage_offset: std::mem::offset_of!(Task<Storage<F>>, storage) as u16,
205 _pad: [0; 5],
206 storage: FutureOrOutput {
207 future: std::mem::ManuallyDrop::new(future),
208 },
209 }
210}
211
212// =============================================================================
213// Task handle — raw pointer operations
214// =============================================================================
215
/// Opaque identifier for a task: the task's base pointer, which does not
/// move for as long as the task exists.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct TaskId(pub(crate) *mut u8);

impl TaskId {
    /// The raw task pointer wrapped by this id.
    #[allow(dead_code)] // keeps the lint quiet when no caller uses this accessor
    pub(crate) fn as_ptr(&self) -> *mut u8 {
        let TaskId(raw) = *self;
        raw
    }
}
228
229// =============================================================================
230// JoinHandle
231// =============================================================================
232
233/// Handle to a spawned task. Await to get the result.
234///
235/// Dropping the handle detaches the task — it continues running but the
236/// output is dropped when the task completes. Use [`abort()`](Self::abort)
237/// to cancel the task.
238///
239/// `JoinHandle` is `!Send` and `!Sync` — it must stay on the executor thread.
240#[must_use = "dropping a JoinHandle detaches the task — await it or call .abort()"]
241pub struct JoinHandle<T> {
242 ptr: *mut u8,
243 _marker: PhantomData<T>,
244 _not_send: PhantomData<*const ()>, // !Send + !Sync
245}
246
247impl<T: 'static> Future for JoinHandle<T> {
248 type Output = T;
249
250 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
251 let ptr = self.ptr;
252
253 // SAFETY: ptr is valid — JoinHandle holds a ref (refcount >= 1).
254 if unsafe { is_completed(ptr) } {
255 let flags = unsafe { task_flags(ptr) };
256 assert!(
257 flags & ABORTED == 0,
258 "polled JoinHandle after task was aborted"
259 );
260 // SAFETY: Task completed, so poll_join already transitioned the union
261 // from F to T. The output is live at storage_offset. ptr::read moves
262 // it out (bitwise copy). OUTPUT_TAKEN prevents double-read.
263 let output_ptr = unsafe { ptr.add(storage_offset(ptr)) };
264 let value = unsafe { std::ptr::read(output_ptr.cast::<T>()) };
265 unsafe { set_flag(ptr, OUTPUT_TAKEN) };
266 Poll::Ready(value)
267 } else {
268 // SAFETY: Task still running, single-threaded — safe to write waker.
269 unsafe { set_join_waker(ptr, cx.waker().clone()) };
270 Poll::Pending
271 }
272 }
273}
274
275impl<T> JoinHandle<T> {
276 pub(crate) fn new(ptr: *mut u8) -> Self {
277 Self {
278 ptr,
279 _marker: PhantomData,
280 _not_send: PhantomData,
281 }
282 }
283
284 /// Returns `true` if the task has completed (output is ready).
285 pub fn is_finished(&self) -> bool {
286 unsafe { is_completed(self.ptr) }
287 }
288
289 /// Abort the task and consume the handle.
290 ///
291 /// The future is dropped on the next poll cycle. Consumes the handle
292 /// so it cannot be awaited after abort — this is enforced at the type
293 /// level rather than via a runtime panic.
294 ///
295 /// Returns `true` if the task was still running, `false` if it had
296 /// already completed (output is dropped by `JoinHandle::drop`).
297 #[must_use = "returns whether the task was still running"]
298 pub fn abort(self) -> bool {
299 let ptr = self.ptr;
300 let was_running = !unsafe { is_completed(ptr) };
301 if was_running {
302 unsafe { set_flag(ptr, ABORTED) };
303 }
304 // self is consumed — Drop runs, which clears HAS_JOIN,
305 // takes the join waker, and decrements refcount.
306 was_running
307 }
308}
309
310impl<T> Drop for JoinHandle<T> {
311 fn drop(&mut self) {
312 let ptr = self.ptr;
313 // SAFETY: ptr is valid — JoinHandle holds a ref (refcount >= 1).
314 // All accessor calls below are single-threaded and target valid
315 // header fields at known offsets.
316 let flags = unsafe { task_flags(ptr) };
317
318 if unsafe { is_completed(ptr) } && (flags & OUTPUT_TAKEN == 0) && (flags & ABORTED == 0) {
319 // Task completed but output was never read — drop it.
320 // SAFETY: poll_join overwrote drop_fn to drop_output::<T>,
321 // so this drops the output T (not the future F).
322 unsafe { drop_task_future(ptr) };
323 }
324
325 // Clear HAS_JOIN so complete_task knows nobody is waiting.
326 unsafe { clear_flag(ptr, HAS_JOIN) };
327
328 // If we previously polled to Pending, a cloned waker is stored in the
329 // task. Clear it so the parent task's refcount isn't kept alive until
330 // the child completes. take() returns None if no waker was stored.
331 let _ = unsafe { take_join_waker(ptr) };
332
333 // Release our reference. If refcount hits 0, the task is complete and
334 // all other refs (executor, wakers) are gone — defer the free.
335 let should_free = unsafe { ref_dec(ptr) };
336 if should_free {
337 // SAFETY: refcount is 0, task is completed. Can't free directly
338 // because we may be outside the poll cycle. defer_free pushes to
339 // TLS deferred list (or leaks if TLS unavailable — Executor::drop
340 // catches those).
341 unsafe { defer_free_slot(ptr) };
342 }
343 }
344}
345
346/// Push a task to the deferred free list, or free immediately if outside poll.
347///
348/// # Safety
349///
350/// `ptr` must point to a completed task with ref_count 0.
351unsafe fn defer_free_slot(ptr: *mut u8) {
352 unsafe { crate::waker::defer_free(ptr) };
353}
354
355// =============================================================================
356// Task header accessor functions
357// =============================================================================
358
/// Fetch the executor tracking key stored in a task header.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn tracker_key(ptr: *mut u8) -> u32 {
    // SAFETY: in the repr(C) header the u32 `tracker_key` sits at byte 28,
    // which is 4-aligned within the pointer-aligned task allocation.
    let field = unsafe { ptr.add(28) }.cast::<u32>();
    unsafe { field.read() }
}
369
/// Increment the task refcount. Called on waker clone.
///
/// # Panics
///
/// Panics with "waker refcount overflow" if the count is already
/// `u16::MAX`. In that case the counter is left untouched, so every
/// outstanding reference still sees an accurate count. A plain `fetch_add`
/// would have wrapped it to 0 before the check could fire.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn ref_inc(ptr: *mut u8) {
    // SAFETY: ref_count is AtomicU16 at offset 26 in repr(C) Task.
    let rc = unsafe { &*ptr.add(26).cast::<AtomicU16>() };
    // Publish the increment only if it does not overflow. Relaxed suffices:
    // a new reference is always made from an existing one, so no other memory
    // needs to be synchronized here.
    let bumped = rc.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_add(1));
    assert!(bumped.is_ok(), "waker refcount overflow");
}
382
/// Release one task reference. Returns `true` when that was the last one,
/// meaning the slot may now be freed.
///
/// # Safety
///
/// `ptr` must point to a live (or completed) `Task<F>`.
#[inline]
pub(crate) unsafe fn ref_dec(ptr: *mut u8) -> bool {
    // SAFETY: the AtomicU16 refcount lives at byte 26 of the repr(C) header.
    let counter: &AtomicU16 = unsafe { &*ptr.add(26).cast() };
    // AcqRel: Release publishes this holder's accesses, and Acquire lets the
    // final decrementer observe everyone else's before the slot is freed.
    let before = counter.fetch_sub(1, Ordering::AcqRel);
    debug_assert!(before > 0, "waker refcount underflow");
    before == 1
}
396
/// Snapshot of the task refcount (Relaxed load).
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[allow(dead_code)] // only exercised by tests within this file
#[inline]
pub(crate) unsafe fn ref_count(ptr: *mut u8) -> u16 {
    // SAFETY: the AtomicU16 refcount lives at byte 26 of the repr(C) header.
    let counter = unsafe { &*ptr.add(26).cast::<AtomicU16>() };
    counter.load(Ordering::Relaxed)
}
408
/// Mark the task as completed (Release store, pairing with `is_completed`).
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn set_completed(ptr: *mut u8) {
    // SAFETY: `is_completed` is the AtomicBool at byte 25 of the repr(C) header.
    let done = unsafe { &*ptr.add(25).cast::<AtomicBool>() };
    done.store(true, Ordering::Release);
}

/// Whether the task has completed (Acquire load, pairing with `set_completed`).
///
/// # Safety
///
/// `ptr` must point to a (possibly completed) `Task<F>`.
#[inline]
pub(crate) unsafe fn is_completed(ptr: *mut u8) -> bool {
    // SAFETY: `is_completed` is the AtomicBool at byte 25 of the repr(C) header.
    let done = unsafe { &*ptr.add(25).cast::<AtomicBool>() };
    done.load(Ordering::Acquire)
}
430
/// Address of the intrusive `cross_next` link used by the cross-thread
/// wake queue.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn cross_next(ptr: *mut u8) -> *const AtomicPtr<u8> {
    // SAFETY: `cross_next` occupies bytes 32..40 of the repr(C) header.
    let slot = unsafe { ptr.add(32) };
    slot.cast::<AtomicPtr<u8>>().cast_const()
}
441
/// Whether the task is currently marked as queued (Relaxed load).
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn is_queued(ptr: *mut u8) -> bool {
    // SAFETY: `is_queued` is the AtomicBool at byte 24 of the repr(C) header.
    let flag = unsafe { &*ptr.add(24).cast::<AtomicBool>() };
    flag.load(Ordering::Relaxed)
}

/// Overwrite the task's `is_queued` flag (Relaxed store).
///
/// NOTE(review): remote wakers use `try_set_queued`, whose failure ordering
/// is Relaxed. If such a CAS observes `true` just before this store clears
/// the flag, nothing in this module orders the waker's preceding writes
/// before the executor's next poll. Confirm the executor supplies that
/// ordering (e.g. a fence); otherwise a wake-up could be lost.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn set_queued(ptr: *mut u8, queued: bool) {
    // SAFETY: `is_queued` is the AtomicBool at byte 24 of the repr(C) header.
    let flag = unsafe { &*ptr.add(24).cast::<AtomicBool>() };
    flag.store(queued, Ordering::Relaxed);
}

/// Claim the task for queuing by flipping `is_queued` from false to true.
/// Returns `true` only for the caller that performed the flip (the task was
/// not queued before). Used by cross-thread wakers.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
#[inline]
pub(crate) unsafe fn try_set_queued(ptr: *mut u8) -> bool {
    // SAFETY: `is_queued` is the AtomicBool at byte 24 of the repr(C) header.
    let flag = unsafe { &*ptr.add(24).cast::<AtomicBool>() };
    let outcome = flag.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed);
    outcome.is_ok()
}
478
/// Byte distance from the task base pointer to its storage field, as
/// recorded in the header.
///
/// # Safety
///
/// `ptr` must point to a live `Task<S>`.
#[inline]
pub(crate) unsafe fn storage_offset(ptr: *mut u8) -> usize {
    // SAFETY: the u16 `storage_offset` field sits at byte 56 of the header.
    let raw = unsafe { ptr.add(56).cast::<u16>().read() };
    usize::from(raw)
}
489
/// Current value of the packed flag byte.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
unsafe fn task_flags(ptr: *mut u8) -> u8 {
    // SAFETY: `flags` is the Cell<u8> at byte 58 of the header.
    let cell = unsafe { &*ptr.add(58).cast::<std::cell::Cell<u8>>() };
    cell.get()
}

/// Turn on the bit(s) in `flag`, leaving the others unchanged.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
unsafe fn set_flag(ptr: *mut u8, flag: u8) {
    // SAFETY: `flags` is the Cell<u8> at byte 58 of the header.
    let cell = unsafe { &*ptr.add(58).cast::<std::cell::Cell<u8>>() };
    let current = cell.get();
    cell.set(current | flag);
}

/// Turn off the bit(s) in `flag`, leaving the others unchanged.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
unsafe fn clear_flag(ptr: *mut u8, flag: u8) {
    // SAFETY: `flags` is the Cell<u8> at byte 58 of the header.
    let cell = unsafe { &*ptr.add(58).cast::<std::cell::Cell<u8>>() };
    let current = cell.get();
    cell.set(current & !flag);
}
522
523/// Check if HAS_JOIN flag is set.
524///
525/// # Safety
526///
527/// `ptr` must point to a live `Task<F>`.
528#[inline]
529pub(crate) unsafe fn has_join(ptr: *mut u8) -> bool {
530 (unsafe { task_flags(ptr) }) & HAS_JOIN != 0
531}
532
533/// Check if ABORTED flag is set.
534///
535/// # Safety
536///
537/// `ptr` must point to a live `Task<F>`.
538#[inline]
539pub(crate) unsafe fn is_aborted(ptr: *mut u8) -> bool {
540 (unsafe { task_flags(ptr) }) & ABORTED != 0
541}
542
/// Park `waker` as the join waker, dropping any waker stored earlier.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
unsafe fn set_join_waker(ptr: *mut u8, waker: Waker) {
    // SAFETY: `join_waker` is the UnsafeCell<Option<Waker>> at byte 40, and
    // single-threaded access means no other reference to its contents exists.
    let slot = unsafe { &mut *(*ptr.add(40).cast::<UnsafeCell<Option<Waker>>>()).get() };
    *slot = Some(waker);
}

/// Remove and return the parked join waker, leaving `None` behind.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Single-threaded access only.
#[inline]
pub(crate) unsafe fn take_join_waker(ptr: *mut u8) -> Option<Waker> {
    // SAFETY: as in `set_join_waker`.
    let slot = unsafe { &mut *(*ptr.add(40).cast::<UnsafeCell<Option<Waker>>>()).get() };
    slot.take()
}
565
/// Poll the task through its type-erased `poll_fn`.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`.
/// The future must not have been dropped.
#[inline]
pub(crate) unsafe fn poll_task(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()> {
    type PollFn = unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>;
    // SAFETY: `poll_fn` is the first field (offset 0) of the repr(C) header.
    let trampoline = unsafe { ptr.cast::<PollFn>().read() };
    // The trampoline receives the task base pointer, not the storage pointer.
    unsafe { trampoline(ptr, cx) }
}

/// Drop whatever currently lives in the task's storage (future or output)
/// through the type-erased `drop_fn`.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. Must only be called once.
#[inline]
pub(crate) unsafe fn drop_task_future(ptr: *mut u8) {
    // SAFETY: `drop_fn` is at byte 8 of the repr(C) header.
    let dropper = unsafe { ptr.add(8).cast::<unsafe fn(*mut u8)>().read() };
    unsafe { dropper(ptr) }
}

/// Release the task's memory through the type-erased `free_fn`.
///
/// # Safety
///
/// `ptr` must point to a `Task<F>` whose future has already been dropped.
/// Must only be called once (after refcount reaches 0).
#[inline]
pub(crate) unsafe fn free_task(ptr: *mut u8) {
    // SAFETY: `free_fn` is at byte 16 of the repr(C) header.
    let releaser = unsafe { ptr.add(16).cast::<unsafe fn(*mut u8)>().read() };
    unsafe { releaser(ptr) }
}
606
607// =============================================================================
608// Type-erased vtable functions
609// =============================================================================
610
/// Poll trampoline for every task (installed as `poll_fn`), joinable or not.
///
/// If the task has been aborted, returns `Ready(())` without polling or
/// dropping the future. `drop_fn` still targets `F` in that case, so the
/// future can be dropped afterwards through it.
///
/// On completion: drops F, writes T into the same location, and overwrites
/// drop_fn to target T instead of F. While the swap is in progress, drop_fn
/// is set to `drop_noop`, so a panicking `F::drop` cannot lead to a double
/// drop.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>`. The future must not have been dropped.
/// The storage must also have room for `F::Output`. `FutureOrOutput`
/// guarantees this; for a bare `F` it requires `F::Output = ()`.
unsafe fn poll_join<F: Future>(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()>
where
    F::Output: 'static,
{
    // Check if aborted: report completion without touching the future.
    if unsafe { is_aborted(ptr) } {
        return Poll::Ready(());
    }

    // SAFETY: a live F sits at storage_offset. The task pointer is stable for
    // the task's lifetime, so pinning the future in place is sound.
    let future_ptr = unsafe { ptr.add(storage_offset(ptr)) };
    let future = unsafe { Pin::new_unchecked(&mut *future_ptr.cast::<F>()) };
    match future.poll(cx) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(value) => {
            // drop_fn lives at byte 8 of the repr(C) header.
            let drop_fn_slot = unsafe { ptr.add(8).cast::<unsafe fn(*mut u8)>() };
            // 1. Overwrite drop_fn to no-op BEFORE dropping F.
            //    If F::drop() panics, this prevents double-drop —
            //    subsequent cleanup calls the no-op instead of
            //    drop_future_in_union on a partially-dropped F.
            //    The output (value) is dropped during unwind (stack-owned).
            unsafe { *drop_fn_slot = drop_noop };
            // 2. Drop the future in place (panic-safe now)
            unsafe { std::ptr::drop_in_place(future_ptr.cast::<F>()) };
            // 3. Write output T into the same location. The old bytes belong
            //    to the already-dropped F, so nothing is overwritten that
            //    still needs dropping.
            unsafe { std::ptr::write(future_ptr.cast::<F::Output>(), value) };
            // 4. Overwrite drop_fn: now drops T instead of F
            unsafe { *drop_fn_slot = drop_output::<F::Output> };
            Poll::Ready(())
        }
    }
}
650
/// Drop trampoline for a future stored directly as the task storage
/// (fire-and-forget tasks built by `Task::new_boxed`).
///
/// # Safety
///
/// `ptr` must point to a live `Task<F>` with a live future at `storage_offset`.
#[cfg(test)]
unsafe fn drop_future<F>(ptr: *mut u8) {
    // SAFETY: the caller guarantees a live F at storage_offset.
    unsafe {
        let future = ptr.add(storage_offset(ptr)).cast::<F>();
        std::ptr::drop_in_place(future);
    }
}
661
662/// Drop trampoline for futures stored in FutureOrOutput union.
663///
664/// # Safety
665///
666/// `ptr` must point to a `Task<FutureOrOutput<F, T>>` with a live future.
667unsafe fn drop_future_in_union<F: Future>(ptr: *mut u8) {
668 let storage_ptr = unsafe { ptr.add(storage_offset(ptr)) };
669 // The future is at the start of the union (same offset as the union itself).
670 unsafe { std::ptr::drop_in_place(storage_ptr.cast::<F>()) }
671}
672
/// Drop trampoline that does nothing. `poll_join` installs it during the
/// F→T transition so that a panicking `F::drop()` cannot cause a double drop.
///
/// # Safety
///
/// Always safe — does nothing.
unsafe fn drop_noop(_task: *mut u8) {
    // Intentionally empty: nothing in the storage needs dropping.
}
680
681/// Drop trampoline for output values. Receives the task base pointer.
682///
683/// Installed by `poll_join` after the future completes, replacing `drop_future`.
684///
685/// # Safety
686///
687/// `ptr` must point to a `Task` with a live `T` at `storage_offset`.
688unsafe fn drop_output<T>(ptr: *mut u8) {
689 let output_ptr = unsafe { ptr.add(storage_offset(ptr)) };
690 unsafe { std::ptr::drop_in_place(output_ptr.cast::<T>()) }
691}
692
693/// Free function for Box-allocated tasks.
694///
695/// Deallocates the memory without running destructors — the future/output
696/// was already dropped via `drop_task_future`, and the header fields
697/// are all trivial. Only the heap allocation needs to be freed.
698///
699/// # Safety
700///
701/// `ptr` must have been produced by `Box::into_raw(Box::new(Task<F>))`.
702/// The value at offset 64 must already be dropped.
703unsafe fn box_free<F>(ptr: *mut u8) {
704 // SAFETY: Layout matches what Box::new(Task<F>) allocated.
705 let layout = std::alloc::Layout::new::<Task<F>>();
706 unsafe { std::alloc::dealloc(ptr, layout) }
707}
708
// Joinable tasks are built only through `box_spawn_joinable` (Box allocation)
// and `new_joinable_slab` (caller-provided allocation). The former
// `new_joinable_boxed` helper was removed because it had a bad API.
711
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::task::{RawWaker, RawWakerVTable, Waker};

    /// Waker whose clone/wake/drop entries do nothing — enough to build a
    /// `Context` for driving trampolines by hand.
    fn noop_waker() -> Waker {
        static NOOP_VTABLE: RawWakerVTable =
            RawWakerVTable::new(|p| RawWaker::new(p, &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
        // SAFETY: no vtable entry dereferences the (null) data pointer.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE)) }
    }

    /// Future that completes immediately with `()`.
    struct Noop;

    impl Future for Noop {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Ready(())
        }
    }

    #[test]
    fn task_header_size() {
        assert_eq!(TASK_HEADER_SIZE, 64);
        assert_eq!(std::mem::size_of::<Task<()>>(), 64);
    }

    #[test]
    fn task_layout_offsets() {
        assert_eq!(std::mem::offset_of!(Task<()>, poll_fn), 0);
        assert_eq!(std::mem::offset_of!(Task<()>, drop_fn), 8);
        assert_eq!(std::mem::offset_of!(Task<()>, free_fn), 16);
        assert_eq!(std::mem::offset_of!(Task<()>, is_queued), 24);
        assert_eq!(std::mem::offset_of!(Task<()>, is_completed), 25);
        assert_eq!(std::mem::offset_of!(Task<()>, ref_count), 26);
        assert_eq!(std::mem::offset_of!(Task<()>, tracker_key), 28);
        assert_eq!(std::mem::offset_of!(Task<()>, cross_next), 32);
        assert_eq!(std::mem::offset_of!(Task<()>, join_waker), 40);
        assert_eq!(std::mem::offset_of!(Task<()>, storage_offset), 56);
        assert_eq!(std::mem::offset_of!(Task<()>, flags), 58);
        assert_eq!(std::mem::offset_of!(Task<()>, _pad), 59);
        assert_eq!(std::mem::offset_of!(Task<()>, storage), 64);
    }

    #[test]
    fn task_size_with_future() {
        #[allow(dead_code)] // the array only gives the future its size
        struct SmallFuture([u8; 24]);
        impl Future for SmallFuture {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }

        // 64 byte header + 24 byte future = 88 bytes
        assert_eq!(
            std::mem::size_of::<Task<SmallFuture>>(),
            TASK_HEADER_SIZE + 24
        );
    }

    #[test]
    fn queued_flag_via_pointer() {
        let ptr = Box::into_raw(Box::new(Task::new_boxed(Noop, 0))) as *mut u8;

        unsafe {
            assert!(!is_queued(ptr));
            set_queued(ptr, true);
            assert!(is_queued(ptr));
            set_queued(ptr, false);
            assert!(!is_queued(ptr));

            // Drop future, then free storage (matches executor lifecycle).
            drop_task_future(ptr);
            free_task(ptr);
        }
    }

    #[test]
    fn box_free_works() {
        let ptr = Box::into_raw(Box::new(Task::new_boxed(Noop, 42))) as *mut u8;

        unsafe {
            assert_eq!(tracker_key(ptr), 42);
            assert_eq!(ref_count(ptr), 1);
            // Drop future, then free storage.
            drop_task_future(ptr);
            free_task(ptr);
        }
    }

    #[test]
    fn joinable_task_flags() {
        struct ReadyU64;
        impl Future for ReadyU64 {
            type Output = u64;
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u64> {
                Poll::Ready(42)
            }
        }

        let ptr = box_spawn_joinable(ReadyU64, 0);
        unsafe {
            assert!(has_join(ptr));
            assert!(!is_aborted(ptr));
            assert_eq!(ref_count(ptr), 2); // executor + JoinHandle

            // Clean up
            drop_task_future(ptr);
            ref_dec(ptr); // JoinHandle ref
            ref_dec(ptr); // executor ref
            free_task(ptr);
        }
    }

    // =========================================================================
    // Panic safety — drop_fn transitions
    // =========================================================================

    /// Future whose Drop impl panics. Used to verify the drop_noop guard
    /// in poll_join prevents double-drop.
    struct PanickingDrop {
        drop_count: *mut u32,
    }

    impl Future for PanickingDrop {
        type Output = u64;
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u64> {
            Poll::Ready(42)
        }
    }

    impl Drop for PanickingDrop {
        fn drop(&mut self) {
            unsafe { *self.drop_count += 1 };
            panic!("intentional drop panic");
        }
    }

    #[test]
    fn poll_join_panic_in_drop_prevents_double_drop() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let mut drop_count: u32 = 0;
        let ptr = box_spawn_joinable(
            PanickingDrop {
                drop_count: &raw mut drop_count,
            },
            0,
        );

        // poll_join completes the future, then drops F — which panics.
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
            poll_task(ptr, &mut cx)
        }));

        // The panic should have been caught.
        assert!(result.is_err(), "expected panic from PanickingDrop");
        // F was dropped exactly once (by poll_join, before the panic propagated).
        assert_eq!(drop_count, 1, "future should be dropped exactly once");

        // drop_fn should now be drop_noop — calling it must NOT double-drop F.
        unsafe { drop_task_future(ptr) };
        assert_eq!(
            drop_count, 1,
            "drop_task_future after panic must be a no-op (drop_noop)"
        );

        // Clean up: dec both refs (executor + JoinHandle), then free.
        unsafe {
            ref_dec(ptr);
            ref_dec(ptr);
            free_task(ptr);
        }
    }

    #[test]
    fn drop_fn_transitions_correctly_on_normal_completion() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        // Atomic counter instead of `static mut`: no unsafe access needed.
        static OUTPUT_DROP_COUNT: AtomicU32 = AtomicU32::new(0);
        struct TrackedOutput;
        impl Drop for TrackedOutput {
            fn drop(&mut self) {
                OUTPUT_DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        struct ProduceTracked;
        impl Future for ProduceTracked {
            type Output = TrackedOutput;
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<TrackedOutput> {
                Poll::Ready(TrackedOutput)
            }
        }

        let ptr = box_spawn_joinable(ProduceTracked, 0);

        // Poll to completion — F dropped, T written, drop_fn → drop_output.
        let result = unsafe { poll_task(ptr, &mut cx) };
        assert!(result.is_ready());

        // drop_fn should now target T (TrackedOutput).
        OUTPUT_DROP_COUNT.store(0, Ordering::Relaxed);
        unsafe { drop_task_future(ptr) };
        assert_eq!(
            OUTPUT_DROP_COUNT.load(Ordering::Relaxed),
            1,
            "drop_fn should drop the output exactly once"
        );

        // Clean up.
        unsafe {
            ref_dec(ptr);
            ref_dec(ptr);
            free_task(ptr);
        }
    }
}