// mlx_native/encoder_worker.rs
//! Persistent encoder worker thread (ADR-028 iter-380).
//!
//! Provides a long-lived worker thread for parallel command-buffer encoding,
//! mirroring llama.cpp's `n_cb=2` GCD `dispatch_apply` pattern (see
//! `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:438+550`).
//!
//! Per the existing `forward_decode` comment at line 4592-4595:
//! > Threaded wait DURING encode: -43 tok/s (thread spawn + Metal
//! > cross-thread synchronization overhead on command queue)
//!
//! That falsified attempt used per-token `std::thread::spawn`, paying the
//! ~50 µs spawn cost on every decode token.  This module amortizes that cost
//! by spawning the worker ONCE at process start, then submitting work via a
//! crossbeam-style mpsc channel.
//!
//! # Usage
//! ```ignore
//! use mlx_native::encoder_worker::EncoderWorker;
//!
//! // At process start (e.g., model load):
//! let worker = EncoderWorker::spawn();
//!
//! // Per-token (or per-encoding-task):
//! let (done_tx, done_rx) = std::sync::mpsc::channel();
//! worker.submit(move || {
//!     // ... encode work into a fresh CommandEncoder ...
//!     done_tx.send(()).ok();
//! });
//!
//! // Main thread can do its own work in parallel.
//!
//! // Eventually wait for worker to finish:
//! done_rx.recv().expect("worker died");
//! ```
//!
//! # Safety / lifetime
//!
//! - The worker thread is joined by [`EncoderWorker::shutdown`] (which `Drop`
//!   also invokes).  Dropping the `Sender` closes the channel, the worker's
//!   `recv()` loop exits naturally, and the thread is joined.
//! - Closures must be `'static` (they cross thread boundaries).  Use `Arc`
//!   for shared state.
//! - Closures must be `Send` (Rust's `mpsc::channel` enforces this).

use std::sync::mpsc;
use std::thread;

/// A submitted encoding task.  Boxed `FnOnce` because each task may capture
/// different types.  `Send + 'static` is required so the closure can be moved
/// to the worker thread.
type Task = Box<dyn FnOnce() + Send + 'static>;

/// A persistent worker thread that executes submitted closures sequentially
/// (in submission order).  Designed for command-buffer encoding workloads
/// where the cost of `std::thread::spawn` per task would dwarf the work.
///
/// The worker is single-threaded; submissions execute one-at-a-time.  For
/// parallelism with the main thread, the typical pattern is:
///
/// 1. Spawn one `EncoderWorker` at process start.
/// 2. Per token: submit half the encoding work to the worker, encode the
///    other half on the main thread, wait for both to complete.
///
/// `EncoderWorker` is NOT a thread pool — for that, spawn multiple workers.
pub struct EncoderWorker {
    /// Sender side of the task channel; `None` once shut down.  Dropping the
    /// sender closes the channel and terminates the worker's run loop.
    tx: Option<mpsc::Sender<Task>>,
    /// Join handle for the worker thread; `None` once joined.
    handle: Option<thread::JoinHandle<()>>,
}

impl EncoderWorker {
    /// Spawn a new persistent worker thread.  The thread runs until either
    /// [`Self::shutdown`] is called or the `EncoderWorker` is dropped.
    ///
    /// The worker's run-loop blocks on the channel; CPU usage is zero when
    /// idle.
    pub fn spawn() -> Self {
        let (tx, rx) = mpsc::channel::<Task>();
        let handle = thread::Builder::new()
            .name("mlx-native-encoder-worker".into())
            .spawn(move || {
                // Run loop: pull tasks until the channel is closed.
                while let Ok(task) = rx.recv() {
                    task();
                }
            })
            .expect("failed to spawn encoder worker thread");
        Self { tx: Some(tx), handle: Some(handle) }
    }

    /// Submit a closure for execution on the worker thread.  Returns
    /// immediately; the closure runs asynchronously.
    ///
    /// To wait for the closure to complete, the caller must arrange its own
    /// signaling (e.g., a `(tx, rx)` channel pair captured by the closure).
    ///
    /// # Errors
    /// Returns `Err` if the worker thread has been shut down or has panicked.
    pub fn submit<F>(&self, f: F) -> Result<(), &'static str>
    where
        F: FnOnce() + Send + 'static,
    {
        match self.tx.as_ref() {
            Some(tx) => tx.send(Box::new(f)).map_err(|_| "worker thread is dead"),
            None => Err("worker has been shut down"),
        }
    }

    /// Cleanly shut down the worker.  Drops the sender (closing the channel),
    /// then joins the worker thread.  Returns once the worker has processed
    /// all in-flight tasks.
    pub fn shutdown(&mut self) {
        // Drop sender → channel closes → worker's recv() returns Err → loop exits.
        self.tx = None;
        if let Some(h) = self.handle.take() {
            // Ignore worker-panic errors during shutdown (already shutting down).
            let _ = h.join();
        }
    }
}

impl Drop for EncoderWorker {
    fn drop(&mut self) {
        self.shutdown();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn submit_runs_closure() {
        let worker = EncoderWorker::spawn();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let (done_tx, done_rx) = std::sync::mpsc::channel();
        worker.submit(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            done_tx.send(()).ok();
        }).expect("submit");

        done_rx.recv().expect("worker did not signal completion");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn submissions_run_in_order() {
        let worker = EncoderWorker::spawn();
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut signals = Vec::new();

        for i in 0..5 {
            let order_clone = Arc::clone(&order);
            let (tx, rx) = std::sync::mpsc::channel();
            signals.push(rx);
            worker.submit(move || {
                order_clone.lock().expect("lock").push(i);
                tx.send(()).ok();
            }).expect("submit");
        }

        for rx in signals {
            rx.recv().expect("worker panicked");
        }

        let final_order = order.lock().expect("lock").clone();
        assert_eq!(final_order, vec![0, 1, 2, 3, 4],
            "tasks ran out of order: {:?}", final_order);
    }

    #[test]
    fn shutdown_waits_for_in_flight_work() {
        let mut worker = EncoderWorker::spawn();
        let counter = Arc::new(AtomicU32::new(0));

        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            worker.submit(move || {
                std::thread::sleep(std::time::Duration::from_millis(10));
                counter_clone.fetch_add(1, Ordering::SeqCst);
            }).expect("submit");
        }

        worker.shutdown();
        // After shutdown, all 3 tasks should have completed.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn submit_after_shutdown_errors() {
        let mut worker = EncoderWorker::spawn();
        worker.shutdown();
        assert!(worker.submit(|| {}).is_err());
    }

    // ---------------------------------------------------------------------
    // Metal-dispatch integration tests (ADR-028 iter-381)
    // ---------------------------------------------------------------------

    #[cfg(target_vendor = "apple")]
    #[test]
    fn worker_can_create_metal_encoder_and_commit() {
        // Validates: MlxDevice can be cloned + Arc'd + sent to worker thread,
        // CommandEncoder can be created from worker thread, commit_and_wait
        // works from worker thread.
        let device = crate::MlxDevice::new().expect("MlxDevice");
        let device_arc = Arc::new(device);
        let worker = EncoderWorker::spawn();

        let device_clone = Arc::clone(&device_arc);
        let (done_tx, done_rx) = std::sync::mpsc::channel::<Result<(), String>>();
        worker.submit(move || {
            let result = (|| -> Result<(), String> {
                let mut enc = device_clone.command_encoder()
                    .map_err(|e| format!("enc create: {e}"))?;
                enc.commit_and_wait()
                    .map_err(|e| format!("commit_and_wait: {e}"))?;
                Ok(())
            })();
            done_tx.send(result).ok();
        }).expect("submit");

        let result = done_rx.recv().expect("worker died");
        assert!(result.is_ok(), "worker Metal encoder failed: {:?}", result);
    }

    #[cfg(target_vendor = "apple")]
    #[test]
    fn worker_can_dispatch_real_kernel_zero_buffer() {
        // Validates: worker thread can register a kernel, allocate a buffer,
        // dispatch a real Metal compute kernel, commit + wait, and the host
        // sees the GPU-modified buffer contents after worker completion.
        use crate::DType;
        use crate::ops::moe_dispatch::moe_zero_buffer_encode;

        let device = crate::MlxDevice::new().expect("MlxDevice");
        let device_arc = Arc::new(device);
        let worker = EncoderWorker::spawn();

        // Allocate a buffer initialized to 1.0; worker should zero it via GPU.
        const N: usize = 1024;
        let mut buf = device_arc
            .alloc_buffer(N * 4, DType::F32, vec![N])
            .expect("alloc");
        for v in buf.as_mut_slice::<f32>().expect("init slice").iter_mut() {
            *v = 1.0;
        }

        // Wrap buffer in Arc<Mutex> so worker can mutate via &mut.
        let buf_arc = Arc::new(std::sync::Mutex::new(buf));
        let device_clone = Arc::clone(&device_arc);
        let buf_clone = Arc::clone(&buf_arc);

        let (done_tx, done_rx) = std::sync::mpsc::channel::<Result<(), String>>();
        worker.submit(move || {
            let result = (|| -> Result<(), String> {
                let mut registry = crate::KernelRegistry::new();
                let mut enc = device_clone.command_encoder()
                    .map_err(|e| format!("enc: {e}"))?;
                let buf_guard = buf_clone.lock().expect("lock");
                moe_zero_buffer_encode(
                    &mut enc, &mut registry, device_clone.metal_device(),
                    &buf_guard, N,
                ).map_err(|e| format!("zero_buffer: {e}"))?;
                drop(buf_guard); // release before commit so commit doesn't deadlock on a re-lock
                enc.commit_and_wait().map_err(|e| format!("commit: {e}"))?;
                Ok(())
            })();
            done_tx.send(result).ok();
        }).expect("submit");

        done_rx.recv().expect("worker died").expect("worker error");

        let buf_guard = buf_arc.lock().expect("lock");
        let slice = buf_guard.as_slice::<f32>().expect("read");
        for (i, &v) in slice.iter().enumerate() {
            assert_eq!(v, 0.0, "element {i} not zeroed: {v}");
        }
    }
}
283}