// mlx_native/encoder_worker.rs
1//! Persistent encoder worker thread (ADR-028 iter-380).
2//!
3//! Provides a long-lived worker thread for parallel command-buffer encoding,
4//! mirroring llama.cpp's `n_cb=2` GCD `dispatch_apply` pattern (see
5//! `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m:438+550`).
6//!
7//! Per the existing `forward_decode` comment at line 4592-4595:
8//! > Threaded wait DURING encode: -43 tok/s (thread spawn + Metal
9//! > cross-thread synchronization overhead on command queue)
10//!
11//! That falsified attempt used per-token `std::thread::spawn`, paying the
12//! ~50 µs spawn cost on every decode token. This module amortizes that cost
13//! by spawning the worker ONCE at process start, then submitting work via a
14//! crossbeam-style mpsc channel.
15//!
16//! # Usage
17//! ```ignore
18//! use mlx_native::encoder_worker::EncoderWorker;
19//!
20//! // At process start (e.g., model load):
21//! let worker = EncoderWorker::spawn();
22//!
23//! // Per-token (or per-encoding-task):
24//! let (done_tx, done_rx) = std::sync::mpsc::channel();
25//! worker.submit(move || {
26//! // ... encode work into a fresh CommandEncoder ...
27//! done_tx.send(()).ok();
28//! });
29//!
30//! // Main thread can do its own work in parallel.
31//!
32//! // Eventually wait for worker to finish:
33//! done_rx.recv().expect("worker died");
34//! ```
35//!
36//! # Safety / lifetime
37//!
//! - The worker thread is joined on [`EncoderWorker::shutdown`] (which is
//!   also invoked from `Drop`). The thread holds a `Receiver<Task>`; once
//!   every `Sender` clone is dropped, the receive loop exits and the join
//!   completes.
41//! - Closures must be `'static` (they cross thread boundaries). Use `Arc`
42//! for shared state.
//! - Closures must be `Send` (enforced by the `Send` bound on the boxed
//!   `Task` type; each closure is moved across the thread boundary).
44
45use std::sync::mpsc;
46use std::thread;
47
/// A unit of encoding work. Boxed `FnOnce` so heterogeneous closures can
/// travel through a single channel; `Send + 'static` because every task is
/// moved onto the worker thread.
type Task = Box<dyn FnOnce() + Send + 'static>;

/// A long-lived worker thread that runs submitted closures one at a time,
/// in submission order. Intended for command-buffer encoding workloads,
/// where paying the `std::thread::spawn` cost per task would dwarf the
/// task itself.
///
/// The worker is strictly sequential. To overlap with the main thread, the
/// usual pattern is:
///
/// 1. Spawn a single `EncoderWorker` when the process starts.
/// 2. Per token: hand half the encoding to the worker, encode the other
///    half on the main thread, then wait for both halves to finish.
///
/// `EncoderWorker` is NOT a thread pool — spawn several workers if you
/// need more lanes.
pub struct EncoderWorker {
    /// Sending side of the task channel; `None` once shut down.
    tx: Option<mpsc::Sender<Task>>,
    /// Handle used to join the worker; taken by [`Self::shutdown`].
    handle: Option<thread::JoinHandle<()>>,
}

impl EncoderWorker {
    /// Start the persistent worker thread. It runs until [`Self::shutdown`]
    /// is called or the `EncoderWorker` is dropped.
    ///
    /// The run loop blocks on the channel, so an idle worker consumes no
    /// CPU.
    pub fn spawn() -> Self {
        let (sender, receiver) = mpsc::channel::<Task>();
        let join_handle = thread::Builder::new()
            .name("mlx-native-encoder-worker".into())
            .spawn(move || {
                // `Receiver` iteration blocks on `recv()` and terminates
                // once every `Sender` clone has been dropped.
                for task in receiver {
                    task();
                }
            })
            .expect("failed to spawn encoder worker thread");
        Self {
            tx: Some(sender),
            handle: Some(join_handle),
        }
    }

    /// Queue a closure on the worker thread and return immediately; the
    /// closure executes asynchronously.
    ///
    /// Completion signaling is the caller's responsibility (e.g. capture
    /// the sending half of an `mpsc` channel inside the closure).
    ///
    /// # Errors
    /// `Err` if the worker has been shut down or its thread has panicked.
    pub fn submit<F>(&self, f: F) -> Result<(), &'static str>
    where
        F: FnOnce() + Send + 'static,
    {
        let tx = self.tx.as_ref().ok_or("worker has been shut down")?;
        tx.send(Box::new(f)).map_err(|_| "worker thread is dead")
    }

    /// Shut the worker down cleanly: close the channel, then join the
    /// thread. Returns only after every already-queued task has run.
    pub fn shutdown(&mut self) {
        // Taking (and dropping) the sender closes the channel, which ends
        // the worker's receive loop after it drains in-flight tasks.
        drop(self.tx.take());
        if let Some(join_handle) = self.handle.take() {
            // A worker panic is irrelevant mid-shutdown; swallow it.
            let _ = join_handle.join();
        }
    }
}

impl Drop for EncoderWorker {
    fn drop(&mut self) {
        self.shutdown();
    }
}
126
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    // A submitted closure actually executes on the worker thread; completion
    // is observed through a channel, the side effect through a shared counter.
    #[test]
    fn submit_runs_closure() {
        let worker = EncoderWorker::spawn();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let (done_tx, done_rx) = std::sync::mpsc::channel();
        worker.submit(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            done_tx.send(()).ok();
        }).expect("submit");

        done_rx.recv().expect("worker did not signal completion");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // The worker is single-threaded, so tasks must run strictly in
    // submission order — verified by pushing indices into a shared Vec.
    #[test]
    fn submissions_run_in_order() {
        let worker = EncoderWorker::spawn();
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut signals = Vec::new();

        for i in 0..5 {
            let order_clone = Arc::clone(&order);
            // One completion channel per task so we can wait for all five.
            let (tx, rx) = std::sync::mpsc::channel();
            signals.push(rx);
            worker.submit(move || {
                order_clone.lock().expect("lock").push(i);
                tx.send(()).ok();
            }).expect("submit");
        }

        for rx in signals {
            rx.recv().expect("worker panicked");
        }

        let final_order = order.lock().expect("lock").clone();
        assert_eq!(final_order, vec![0, 1, 2, 3, 4],
            "tasks ran out of order: {:?}", final_order);
    }

    // shutdown() must drain all queued tasks before joining — the sleeps
    // make it likely tasks are still pending when shutdown is called.
    #[test]
    fn shutdown_waits_for_in_flight_work() {
        let mut worker = EncoderWorker::spawn();
        let counter = Arc::new(AtomicU32::new(0));

        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            worker.submit(move || {
                std::thread::sleep(std::time::Duration::from_millis(10));
                counter_clone.fetch_add(1, Ordering::SeqCst);
            }).expect("submit");
        }

        worker.shutdown();
        // After shutdown, all 3 tasks should have completed.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // Once shut down, submit must report an error instead of queueing.
    #[test]
    fn submit_after_shutdown_errors() {
        let mut worker = EncoderWorker::spawn();
        worker.shutdown();
        assert!(worker.submit(|| {}).is_err());
    }

    // ---------------------------------------------------------------------
    // Metal-dispatch integration tests (ADR-028 iter-381)
    // ---------------------------------------------------------------------

    #[cfg(target_vendor = "apple")]
    #[test]
    fn worker_can_create_metal_encoder_and_commit() {
        // Validates: MlxDevice can be cloned + Arc'd + sent to worker thread,
        // CommandEncoder can be created from worker thread, commit_and_wait
        // works from worker thread.
        let device = crate::MlxDevice::new().expect("MlxDevice");
        let device_arc = Arc::new(device);
        let worker = EncoderWorker::spawn();

        let device_clone = Arc::clone(&device_arc);
        // Result channel carries any encoder/commit failure back to the test
        // thread so the assertion (and its message) fires here, not in the
        // worker.
        let (done_tx, done_rx) = std::sync::mpsc::channel::<Result<(), String>>();
        worker.submit(move || {
            // Immediately-invoked closure so `?` can be used for each step.
            let result = (|| -> Result<(), String> {
                let mut enc = device_clone.command_encoder()
                    .map_err(|e| format!("enc create: {e}"))?;
                enc.commit_and_wait()
                    .map_err(|e| format!("commit_and_wait: {e}"))?;
                Ok(())
            })();
            done_tx.send(result).ok();
        }).expect("submit");

        let result = done_rx.recv().expect("worker died");
        assert!(result.is_ok(), "worker Metal encoder failed: {:?}", result);
    }

    #[cfg(target_vendor = "apple")]
    #[test]
    fn worker_can_dispatch_real_kernel_zero_buffer() {
        // Validates: worker thread can register a kernel, allocate a buffer,
        // dispatch a real Metal compute kernel, commit + wait, and the host
        // sees the GPU-modified buffer contents after worker completion.
        use crate::DType;
        use crate::ops::moe_dispatch::moe_zero_buffer_encode;

        let device = crate::MlxDevice::new().expect("MlxDevice");
        let device_arc = Arc::new(device);
        let worker = EncoderWorker::spawn();

        // Allocate a buffer initialized to 1.0; worker should zero it via GPU.
        // NOTE(review): N * 4 assumes DType::F32 is 4 bytes — confirm against
        // the DType definition if element sizes ever change.
        const N: usize = 1024;
        let mut buf = device_arc
            .alloc_buffer(N * 4, DType::F32, vec![N])
            .expect("alloc");
        for v in buf.as_mut_slice::<f32>().expect("init slice").iter_mut() {
            *v = 1.0;
        }

        // Wrap buffer in Arc<Mutex> so worker can mutate via &mut.
        let buf_arc = Arc::new(std::sync::Mutex::new(buf));
        let device_clone = Arc::clone(&device_arc);
        let buf_clone = Arc::clone(&buf_arc);

        let (done_tx, done_rx) = std::sync::mpsc::channel::<Result<(), String>>();
        worker.submit(move || {
            let result = (|| -> Result<(), String> {
                let mut registry = crate::KernelRegistry::new();
                let mut enc = device_clone.command_encoder()
                    .map_err(|e| format!("enc: {e}"))?;
                let buf_guard = buf_clone.lock().expect("lock");
                moe_zero_buffer_encode(
                    &mut enc, &mut registry, device_clone.metal_device(),
                    &buf_guard, N,
                ).map_err(|e| format!("zero_buffer: {e}"))?;
                drop(buf_guard); // release before commit so commit doesn't deadlock on a re-lock
                enc.commit_and_wait().map_err(|e| format!("commit: {e}"))?;
                Ok(())
            })();
            done_tx.send(result).ok();
        }).expect("submit");

        done_rx.recv().expect("worker died").expect("worker error");

        // Host-side verification: every element must now read back 0.0.
        let buf_guard = buf_arc.lock().expect("lock");
        let slice = buf_guard.as_slice::<f32>().expect("read");
        for (i, &v) in slice.iter().enumerate() {
            assert_eq!(v, 0.0, "element {i} not zeroed: {v}");
        }
    }
}
283}