moduvex_runtime/executor/waker.rs
//! Custom `RawWakerVTable` implementation.
//!
//! Each `Waker` holds an `Arc<WakerData>` — bundling the task's
//! `Arc<TaskHeader>`, the `Arc<GlobalQueue>` to reschedule onto, and an
//! optional `WorkerNotifier` — cast to a raw `*const ()`.
//! The four vtable functions implement the `RawWaker` contract:
//!
//! | function      | action                                                  |
//! |---------------|---------------------------------------------------------|
//! | `clone_waker` | add a reference to the shared `WakerData` (refcount +1) |
//! | `wake`        | schedule task, then release this waker's reference      |
//! | `wake_by_ref` | schedule task, keep this waker's reference alive        |
//! | `drop_waker`  | release this waker's reference (refcount -1)            |
//!
//! Safety contract: the data pointer is always a valid `Arc<WakerData>` pointer
//! created via `Arc::into_raw`. Every vtable function interprets it that way —
//! taking ownership of the reference in `wake`/`drop_waker`, borrowing it
//! otherwise — so the reference count stays balanced.
16
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::Arc;
19use std::task::{RawWaker, RawWakerVTable, Waker};
20
21use super::scheduler::GlobalQueue;
22use super::task::{TaskHeader, STATE_IDLE, STATE_SCHEDULED};
23
24// ── Vtable ────────────────────────────────────────────────────────────────────
25
26/// The single static vtable shared by all task wakers.
27static TASK_WAKER_VTABLE: RawWakerVTable =
28 RawWakerVTable::new(clone_waker, wake, wake_by_ref, drop_waker);
29
30// ── Public entry point ────────────────────────────────────────────────────────
31
32/// Construct a `Waker` from an `Arc<TaskHeader>` and a reference to the
33/// global queue into which the waker will push the task when fired.
34///
35/// `notifier` is optional: in multi-threaded mode it writes to a worker's
36/// self-pipe to unpark it after re-scheduling a task.
37pub(crate) fn make_waker(
38 header: Arc<TaskHeader>,
39 queue: Arc<GlobalQueue>,
40) -> Waker {
41 make_waker_with_notifier(header, queue, None)
42}
43
44/// Like `make_waker` but with an explicit `WorkerNotifier` for multi-threaded mode.
45pub(crate) fn make_waker_with_notifier(
46 header: Arc<TaskHeader>,
47 queue: Arc<GlobalQueue>,
48 notifier: Option<Arc<WorkerNotifier>>,
49) -> Waker {
50 let data = Arc::new(WakerData {
51 header,
52 queue,
53 notifier,
54 });
55 let ptr = Arc::into_raw(data) as *const ();
56 let raw = RawWaker::new(ptr, &TASK_WAKER_VTABLE);
57 // SAFETY: The vtable functions correctly implement the RawWaker contract
58 // (see module doc). `ptr` is a valid Arc pointer.
59 unsafe { Waker::from_raw(raw) }
60}
61
62// ── WorkerNotifier ────────────────────────────────────────────────────────────
63
64/// Holds write-end fds of all worker self-pipes. Used to unpark a worker
65/// after pushing a task to GlobalQueue.
66pub(crate) struct WorkerNotifier {
67 wake_fds: std::sync::Mutex<Vec<i32>>,
68 next: AtomicUsize,
69}
70
71
72impl WorkerNotifier {
73 pub(crate) fn new() -> Self {
74 Self {
75 wake_fds: std::sync::Mutex::new(Vec::new()),
76 next: AtomicUsize::new(0),
77 }
78 }
79
80 /// Register a worker's self-pipe write fd.
81 pub(crate) fn add_fd(&self, fd: i32) {
82 self.wake_fds.lock().unwrap().push(fd);
83 }
84
85 /// Write one byte to a worker's self-pipe (round-robin) to unpark it.
86 #[cfg(unix)]
87 pub(crate) fn notify_one(&self) {
88 let fds = self.wake_fds.lock().unwrap();
89 if fds.is_empty() {
90 return;
91 }
92 let idx = self.next.fetch_add(1, Ordering::Relaxed) % fds.len();
93 let fd = fds[idx];
94 drop(fds);
95 unsafe {
96 let b: u8 = 1;
97 libc::write(fd, &b as *const u8 as *const _, 1);
98 }
99 }
100
101 #[cfg(not(unix))]
102 pub(crate) fn notify_one(&self) {}
103}
104
105// ── WakerData ─────────────────────────────────────────────────────────────────
106
107/// Heap allocation backing each `Waker`. Bundles the task header with the
108/// queue reference needed to reschedule the task.
109struct WakerData {
110 header: Arc<TaskHeader>,
111 queue: Arc<GlobalQueue>,
112 notifier: Option<Arc<WorkerNotifier>>,
113}
114
115
116// ── Vtable functions ──────────────────────────────────────────────────────────
117
118/// Reconstruct an `Arc<WakerData>` from a raw pointer WITHOUT consuming it,
119/// then immediately `forget` the Arc so the refcount is unchanged.
120///
121/// # Safety
122/// `ptr` must be a valid `Arc<WakerData>` pointer produced by `Arc::into_raw`.
123#[inline]
124unsafe fn data_ref(ptr: *const ()) -> std::mem::ManuallyDrop<Arc<WakerData>> {
125 // SAFETY: `ptr` is always `Arc::into_raw(Arc<WakerData>)`.
126 std::mem::ManuallyDrop::new(Arc::from_raw(ptr as *const WakerData))
127}
128
129unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
130 // SAFETY: `ptr` is a valid Arc<WakerData> pointer (contract of RawWaker).
131 let data = data_ref(ptr);
132 // Increment refcount by cloning, then leak the clone.
133 let cloned = Arc::clone(&*data);
134 let new_ptr = Arc::into_raw(cloned) as *const ();
135 RawWaker::new(new_ptr, &TASK_WAKER_VTABLE)
136}
137
138unsafe fn wake(ptr: *const ()) {
139 // SAFETY: `ptr` is `Arc::into_raw(Arc<WakerData>)`; consuming it here
140 // correctly decrements the refcount when `data` is dropped at end of fn.
141 let data = Arc::from_raw(ptr as *const WakerData);
142 schedule_task(&data);
143 // `data` drops here → Arc refcount decremented.
144}
145
146unsafe fn wake_by_ref(ptr: *const ()) {
147 // SAFETY: same pointer contract; we borrow without consuming.
148 let data = data_ref(ptr);
149 schedule_task(&data);
150 // ManuallyDrop — refcount unchanged.
151}
152
153unsafe fn drop_waker(ptr: *const ()) {
154 // SAFETY: Reconstruct and immediately drop to decrement Arc refcount.
155 drop(Arc::from_raw(ptr as *const WakerData));
156}
157
158// ── Scheduling helper ─────────────────────────────────────────────────────────
159
160/// Attempt to transition the task from IDLE → SCHEDULED and push it to the
161/// global queue. If the task is already SCHEDULED/RUNNING, skip (it will be
162/// re-polled automatically).
163fn schedule_task(data: &WakerData) {
164 let header = &data.header;
165 // Only transition IDLE → SCHEDULED. Other states:
166 // SCHEDULED: already queued, nothing to do.
167 // RUNNING: executor holds it; it will check for re-schedule after poll.
168 // COMPLETED/CANCELLED: done, ignore wake.
169 let prev = header.state.compare_exchange(
170 STATE_IDLE,
171 STATE_SCHEDULED,
172 Ordering::AcqRel,
173 Ordering::Relaxed,
174 );
175 if prev.is_ok() {
176 data.queue.push_header(Arc::clone(header));
177 // Notify a parked worker to check the global queue.
178 if let Some(ref notifier) = data.notifier {
179 notifier.notify_one();
180 }
181 }
182}
183
184// ── Tests ─────────────────────────────────────────────────────────────────────
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::task::{Task, STATE_IDLE, STATE_SCHEDULED};
    use std::sync::atomic::Ordering;

    /// Build a waker for `task` backed by a fresh, empty queue.
    fn make_test_waker(task: &Task) -> (Waker, Arc<GlobalQueue>) {
        let q = Arc::new(GlobalQueue::new());
        let w = make_waker(Arc::clone(&task.header), Arc::clone(&q));
        (w, q)
    }

    #[test]
    fn waker_clone_increments_refcount() {
        let (task, _jh) = Task::new(async { 1u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let q = Arc::new(GlobalQueue::new());
        let header_base = Arc::strong_count(&task.header);
        let queue_base = Arc::strong_count(&q);

        let w1 = make_waker(Arc::clone(&task.header), Arc::clone(&q));
        // One WakerData allocation holds exactly one header and one queue reference.
        assert_eq!(Arc::strong_count(&task.header), header_base + 1);
        assert_eq!(Arc::strong_count(&q), queue_base + 1);

        // Clones share that WakerData, so the header/queue counts do not move.
        let w2 = w1.clone();
        assert_eq!(Arc::strong_count(&task.header), header_base + 1);
        assert_eq!(Arc::strong_count(&q), queue_base + 1);

        // The clone's reference alone keeps the WakerData alive.
        drop(w1);
        assert_eq!(Arc::strong_count(&task.header), header_base + 1);

        // Last waker gone → WakerData freed → its references released (no leak).
        drop(w2);
        assert_eq!(Arc::strong_count(&task.header), header_base);
        assert_eq!(Arc::strong_count(&q), queue_base);
    }

    #[test]
    fn wake_by_ref_schedules_idle_task() {
        let (task, _jh) = Task::new(async { 2u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let (waker, queue) = make_test_waker(&task);
        waker.wake_by_ref();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
        assert!(queue.pop().is_some());
    }

    #[test]
    fn wake_consumes_and_schedules() {
        let (task, _jh) = Task::new(async { 3u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let (waker, queue) = make_test_waker(&task);
        waker.wake(); // consumes the waker
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
        assert!(queue.pop().is_some());
    }

    #[test]
    fn wake_noop_when_already_scheduled() {
        let (task, _jh) = Task::new(async { 4u32 });
        task.header.state.store(STATE_SCHEDULED, Ordering::Release);
        let (waker, queue) = make_test_waker(&task);
        waker.wake_by_ref();
        // State stays SCHEDULED, queue stays empty (CAS rejected).
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
        assert!(queue.pop().is_none());
    }

    #[test]
    fn clone_outlives_original_and_still_wakes() {
        let (task, _jh) = Task::new(async { 5u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let (waker, queue) = make_test_waker(&task);
        let clone = waker.clone();
        drop(waker);
        clone.wake();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
        assert!(queue.pop().is_some());
    }

    #[test]
    fn repeated_wake_enqueues_once() {
        let (task, _jh) = Task::new(async { 6u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let (waker, queue) = make_test_waker(&task);
        waker.wake_by_ref();
        waker.wake_by_ref(); // second CAS fails: already SCHEDULED
        assert!(queue.pop().is_some());
        assert!(queue.pop().is_none());
    }

    #[test]
    fn wake_with_empty_notifier_still_schedules() {
        let (task, _jh) = Task::new(async { 7u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let queue = Arc::new(GlobalQueue::new());
        let notifier = Arc::new(WorkerNotifier::new());
        let waker = make_waker_with_notifier(
            Arc::clone(&task.header),
            Arc::clone(&queue),
            Some(notifier),
        );
        // No fds registered: notify_one must be a harmless no-op.
        waker.wake_by_ref();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
        assert!(queue.pop().is_some());
    }
}