Skip to main content

kozan_scheduler/
waker.rs

//! Cross-thread waker — wakes the window thread from background threads.
//!
//! When a background task (fetch, file I/O, compute) completes on a
//! thread pool, it needs to wake the window thread so the result can
//! be processed. This module provides the thread-safe waking mechanism.
//!
//! # Architecture
//!
//! ```text
//! Background Thread              Window Thread
//! ┌────────────────┐            ┌───────────────────────┐
//! │ HTTP done!     │            │ Scheduler (parked)    │
//! │ sender.send()──┼───────────→│ notify() (if set)     │
//! │                │            │ receiver.try_recv()   │
//! │                │            │ → wakes up, runs task │
//! └────────────────┘            └───────────────────────┘
//! ```
//!
//! A successful send enqueues the task and then invokes the sender's
//! optional notify callback; that callback is how a parked window thread
//! learns it has work to drain.
//!
//! # Chrome mapping
//!
//! Chrome uses `base::TaskRunner::PostTask()` to cross thread boundaries.
//! The `TaskRunner` carries the target thread + priority. In Rust, we use
//! `mpsc::sync_channel` (bounded), which gives us the same semantics plus
//! backpressure when the window thread is overwhelmed.
//!
//! # Performance
//!
//! - Bounded channel prevents unbounded memory growth from runaway senders.
//! - `try_recv()` is non-blocking — fits into the event loop.
//! - At most one heap allocation per task: `CrossThreadTask` boxes its
//!   closure (zero-sized closures need no allocation).

31use std::sync::Arc;
32use std::sync::mpsc;
33
34use crate::task::TaskPriority;
35
/// Upper bound on cross-thread tasks waiting for the window thread.
///
/// Once this many tasks are queued, [`WakeSender::send`] blocks and
/// [`WakeSender::try_send`] returns [`SendError::Full`], so a runaway
/// background producer cannot grow memory without limit. 1024 is a
/// generous ceiling: a window thread that falls 1024 tasks behind in a
/// single frame has a problem somewhere else.
const CHANNEL_CAPACITY: usize = 1024;

42/// A task sent from a background thread to the window thread.
43///
44/// Must be `Send` because it crosses thread boundaries.
45/// Carries a priority so the scheduler can route it to the correct queue.
46///
47/// # Example
48///
49/// ```ignore
50/// // On background thread:
51/// let data = reqwest::get(url).await?;
52///
53/// // Send result back to window thread at Normal priority:
54/// sender.send(CrossThreadTask::new(TaskPriority::Normal, move || {
55///     btn.set_text(&data.title);  // safe! runs on window thread
56/// }));
57/// ```
58pub struct CrossThreadTask {
59    callback: Box<dyn FnOnce() + Send>,
60    priority: TaskPriority,
61}
62
63impl CrossThreadTask {
64    /// Create a new cross-thread task with the given priority.
65    #[inline]
66    pub fn new(priority: TaskPriority, callback: impl FnOnce() + Send + 'static) -> Self {
67        Self {
68            callback: Box::new(callback),
69            priority,
70        }
71    }
72
73    /// The priority this task should be routed to.
74    #[inline]
75    #[must_use]
76    pub fn priority(&self) -> TaskPriority {
77        self.priority
78    }
79
80    /// Execute this task on the window thread.
81    #[inline]
82    pub fn run(self) {
83        (self.callback)();
84    }
85}
86
87/// The sending half — cloned and given to background threads.
88///
89/// `Send + Clone` — clone this and move the clone to each background thread.
90/// Each clone can independently send tasks to the window thread.
91///
92/// Like Chrome's `scoped_refptr<base::TaskRunner>` which can be used
93/// from any thread to post tasks to a specific sequence.
94///
95/// Note: `WakeSender` is `Send` but not `Sync`. You cannot share it
96/// by `&` reference across threads — clone it and move the clone.
97pub struct WakeSender {
98    sender: mpsc::SyncSender<CrossThreadTask>,
99    /// Called after a successful send to unpark the view thread.
100    notify: Option<Arc<dyn Fn() + Send + Sync>>,
101}
102
103impl WakeSender {
104    /// Wire a "wake the event loop" callback into this sender.
105    ///
106    /// After every successful `send` / `post`, `notify` is called so the
107    /// view thread stops parking and drains the cross-thread task queue.
108    pub fn set_notify(&mut self, notify: Arc<dyn Fn() + Send + Sync>) {
109        self.notify = Some(notify);
110    }
111
112    /// Send a task to the window thread with the given priority.
113    ///
114    /// Blocks if the channel is full (backpressure). Returns `Err`
115    /// if the receiver has been dropped (window closed).
116    #[inline]
117    pub fn send(&self, task: CrossThreadTask) -> Result<(), SendError> {
118        let result = self.sender.send(task).map_err(|_| SendError::Disconnected);
119        if result.is_ok() {
120            if let Some(notify) = &self.notify {
121                notify();
122            }
123        }
124        result
125    }
126
127    /// Try to send without blocking. Returns `Err(Full)` if channel is full.
128    #[inline]
129    pub fn try_send(&self, task: CrossThreadTask) -> Result<(), SendError> {
130        let result = self.sender.try_send(task).map_err(|e| match e {
131            mpsc::TrySendError::Full(_) => SendError::Full,
132            mpsc::TrySendError::Disconnected(_) => SendError::Disconnected,
133        });
134        if result.is_ok() {
135            if let Some(notify) = &self.notify {
136                notify();
137            }
138        }
139        result
140    }
141
142    /// Convenience: send a closure at Normal priority.
143    #[inline]
144    pub fn post(&self, callback: impl FnOnce() + Send + 'static) -> Result<(), SendError> {
145        self.send(CrossThreadTask::new(TaskPriority::Normal, callback))
146    }
147
148    /// Convenience: send a closure at a specific priority.
149    #[inline]
150    pub fn post_with_priority(
151        &self,
152        priority: TaskPriority,
153        callback: impl FnOnce() + Send + 'static,
154    ) -> Result<(), SendError> {
155        self.send(CrossThreadTask::new(priority, callback))
156    }
157}
158
159impl Clone for WakeSender {
160    fn clone(&self) -> Self {
161        Self {
162            sender: self.sender.clone(),
163            notify: self.notify.clone(),
164        }
165    }
166}
167
168/// The receiving half — stays on the window thread.
169///
170/// `!Send` by design (via `PhantomData`). Only the window thread should
171/// drain incoming cross-thread tasks.
172pub struct WakeReceiver {
173    receiver: mpsc::Receiver<CrossThreadTask>,
174    _not_send: std::marker::PhantomData<*const ()>,
175}
176
177impl WakeReceiver {
178    /// Try to receive a cross-thread task without blocking.
179    ///
180    /// Returns `None` if no tasks are available.
181    #[inline]
182    #[must_use]
183    pub fn try_recv(&self) -> Option<CrossThreadTask> {
184        self.receiver.try_recv().ok()
185    }
186
187    /// Process all pending cross-thread tasks through a callback.
188    ///
189    /// Calls `f` for each task. No allocation — iterates inline.
190    /// Returns the number of tasks processed.
191    ///
192    /// ```ignore
193    /// receiver.drain_into(|task| {
194    ///     task_queue.push(Task::new(task.priority(), move || task.run()));
195    /// });
196    /// ```
197    pub fn drain_into(&self, mut f: impl FnMut(CrossThreadTask)) -> usize {
198        let mut count = 0;
199        while let Some(task) = self.try_recv() {
200            f(task);
201            count += 1;
202        }
203        count
204    }
205}
206
/// Why a cross-thread send could not enqueue its task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// The window thread's receiver was dropped (window closed).
    Disconnected,
    /// The channel is full (backpressure — window thread is overwhelmed).
    Full,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Disconnected => "window thread disconnected",
            Self::Full => "cross-thread channel full",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SendError {}

227/// Create a linked sender/receiver pair.
228///
229/// The `WakeSender` can be cloned and sent to background threads.
230/// The `WakeReceiver` stays on the window thread.
231/// The channel is bounded to `CHANNEL_CAPACITY` to prevent
232/// unbounded memory growth from runaway background senders.
233///
234/// ```ignore
235/// let (sender, receiver) = cross_thread_channel();
236///
237/// // Give sender to background threads
238/// tokio::spawn(async move {
239///     let data = fetch(url).await;
240///     sender.post(move || btn.set_text(&data)).unwrap();
241/// });
242///
243/// // On window thread event loop
244/// receiver.drain_into(|task| {
245///     queue.push(Task::new(task.priority(), move || task.run()));
246/// });
247/// ```
248#[must_use]
249pub fn cross_thread_channel() -> (WakeSender, WakeReceiver) {
250    let (sender, receiver) = mpsc::sync_channel(CHANNEL_CAPACITY);
251    (
252        WakeSender {
253            sender,
254            notify: None,
255        },
256        WakeReceiver {
257            receiver,
258            _not_send: std::marker::PhantomData,
259        },
260    )
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    /// A no-op task at Normal priority, for tests that only exercise the
    /// channel bookkeeping.
    fn noop_task() -> CrossThreadTask {
        CrossThreadTask::new(TaskPriority::Normal, || {})
    }

    #[test]
    fn send_and_receive_task() {
        let (sender, receiver) = cross_thread_channel();

        let called = Arc::new(AtomicBool::new(false));
        let c = called.clone();

        sender
            .post(move || c.store(true, Ordering::SeqCst))
            .unwrap();

        let task = receiver.try_recv().unwrap();
        task.run();
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn try_recv_empty_returns_none() {
        let (_sender, receiver) = cross_thread_channel();
        assert!(receiver.try_recv().is_none());
    }

    #[test]
    fn drain_into_runs_all() {
        let (sender, receiver) = cross_thread_channel();
        let counter = Arc::new(AtomicU32::new(0));

        for _ in 0..5 {
            let c = counter.clone();
            sender
                .post(move || {
                    c.fetch_add(1, Ordering::SeqCst);
                })
                .unwrap();
        }

        let executed = receiver.drain_into(|task| task.run());
        assert_eq!(executed, 5);
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[test]
    fn drain_into_empty_returns_zero() {
        let (_sender, receiver) = cross_thread_channel();
        assert_eq!(receiver.drain_into(|_| {}), 0);
    }

    #[test]
    fn sender_clone_works() {
        let (sender, receiver) = cross_thread_channel();
        let sender2 = sender.clone();

        sender.post(|| {}).unwrap();
        sender2.post(|| {}).unwrap();

        assert!(receiver.try_recv().is_some());
        assert!(receiver.try_recv().is_some());
        assert!(receiver.try_recv().is_none());
    }

    #[test]
    fn send_from_another_thread() {
        let (sender, receiver) = cross_thread_channel();

        let handle = std::thread::spawn(move || {
            sender.post(|| {}).unwrap();
        });

        handle.join().unwrap();
        assert!(receiver.try_recv().is_some());
    }

    #[test]
    fn send_to_dropped_receiver_errors() {
        let (sender, receiver) = cross_thread_channel();
        drop(receiver);

        let result = sender.post(|| {});
        assert_eq!(result, Err(SendError::Disconnected));
    }

    #[test]
    fn try_send_reports_full_at_capacity() {
        let (sender, receiver) = cross_thread_channel();
        for _ in 0..CHANNEL_CAPACITY {
            sender.try_send(noop_task()).unwrap();
        }
        assert_eq!(sender.try_send(noop_task()), Err(SendError::Full));

        // Draining frees space again.
        assert_eq!(receiver.drain_into(|_| {}), CHANNEL_CAPACITY);
        assert_eq!(sender.try_send(noop_task()), Ok(()));
    }

    #[test]
    fn try_send_to_dropped_receiver_errors() {
        let (sender, receiver) = cross_thread_channel();
        drop(receiver);
        assert_eq!(sender.try_send(noop_task()), Err(SendError::Disconnected));
    }

    #[test]
    fn notify_fires_only_after_successful_send() {
        let (mut sender, receiver) = cross_thread_channel();
        let wakes = Arc::new(AtomicU32::new(0));
        let w = Arc::clone(&wakes);
        sender.set_notify(Arc::new(move || {
            w.fetch_add(1, Ordering::SeqCst);
        }));

        sender.post(|| {}).unwrap();
        sender.try_send(noop_task()).unwrap();
        // Clones made after `set_notify` share the hook.
        sender.clone().post(|| {}).unwrap();
        assert_eq!(wakes.load(Ordering::SeqCst), 3);

        drop(receiver);
        assert_eq!(sender.post(|| {}), Err(SendError::Disconnected));
        assert_eq!(sender.try_send(noop_task()), Err(SendError::Disconnected));
        assert_eq!(wakes.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn wake_sender_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<WakeSender>();
    }

    #[test]
    fn send_error_display() {
        assert_eq!(
            SendError::Disconnected.to_string(),
            "window thread disconnected"
        );
        assert_eq!(SendError::Full.to_string(), "cross-thread channel full");
    }

    #[test]
    fn cross_thread_task_carries_priority() {
        let task = CrossThreadTask::new(TaskPriority::Input, || {});
        assert_eq!(task.priority(), TaskPriority::Input);
    }

    #[test]
    fn post_with_priority() {
        let (sender, receiver) = cross_thread_channel();
        sender
            .post_with_priority(TaskPriority::Input, || {})
            .unwrap();

        let task = receiver.try_recv().unwrap();
        assert_eq!(task.priority(), TaskPriority::Input);
    }
}