use std::sync::mpsc;
use std::thread;
/// A unit of work for the worker thread: a boxed, one-shot closure that can be
/// sent across threads.
type Task = Box<dyn FnOnce() + Send + 'static>;

/// A single background thread that executes submitted closures one at a time,
/// in the order they were submitted.
///
/// The worker is stopped by [`EncoderWorker::shutdown`], which is also invoked
/// automatically when the value is dropped.
pub struct EncoderWorker {
    /// Sending half of the task queue; `None` once `shutdown` has run.
    tx: Option<mpsc::Sender<Task>>,
    /// Handle of the worker thread; taken (left as `None`) by `shutdown`.
    handle: Option<thread::JoinHandle<()>>,
}

impl EncoderWorker {
    /// Starts the worker thread (named `mlx-native-encoder-worker`) and returns
    /// a handle for submitting work to it.
    ///
    /// The thread loops receiving tasks from an mpsc channel and runs each one
    /// to completion before taking the next. The loop ends once every sender
    /// has been dropped and the queue is empty.
    ///
    /// # Panics
    ///
    /// Panics if the operating system fails to spawn the thread.
    pub fn spawn() -> Self {
        let (tx, rx) = mpsc::channel::<Task>();
        let handle = thread::Builder::new()
            .name("mlx-native-encoder-worker".into())
            .spawn(move || {
                // `recv` keeps yielding already-queued tasks after the sender
                // is dropped and only errors once the queue is drained.
                while let Ok(task) = rx.recv() {
                    task();
                }
            })
            .expect("failed to spawn encoder worker thread");
        Self {
            tx: Some(tx),
            handle: Some(handle),
        }
    }

    /// Queues `f` to run on the worker thread.
    ///
    /// This only enqueues the closure; it does not wait for it to run.
    ///
    /// # Errors
    ///
    /// * `"worker has been shut down"` if [`EncoderWorker::shutdown`] has
    ///   already been called.
    /// * `"worker thread is dead"` if the receiving end of the queue is gone,
    ///   i.e. the worker thread has exited (for instance after a task panicked).
    pub fn submit<F>(&self, f: F) -> Result<(), &'static str>
    where
        F: FnOnce() + Send + 'static,
    {
        match self.tx.as_ref() {
            Some(tx) => tx.send(Box::new(f)).map_err(|_| "worker thread is dead"),
            None => Err("worker has been shut down"),
        }
    }

    /// Stops accepting work and waits for the worker thread to finish.
    ///
    /// Tasks that were already queued still run before the thread exits.
    /// Calling this more than once is harmless: later calls find no sender and
    /// no handle and do nothing.
    ///
    /// If this is called *from the worker thread itself* (e.g. a task dropped
    /// the last owner of the worker), the thread is detached instead of
    /// joined, because a thread joining itself may panic or deadlock.
    pub fn shutdown(&mut self) {
        // Dropping the sender is what lets the worker's receive loop end.
        self.tx = None;
        if let Some(handle) = self.handle.take() {
            if handle.thread().id() == thread::current().id() {
                // Self-join is never valid; dropping the handle detaches the
                // thread, which exits on its own once the current task returns.
                return;
            }
            // A join error only reports that a task panicked; shutdown is
            // best-effort (it also runs from `Drop`), so it is ignored.
            let _ = handle.join();
        }
    }
}

impl Drop for EncoderWorker {
    /// Shuts the worker down so the thread does not outlive its handle.
    fn drop(&mut self) {
        self.shutdown();
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
#[test]
fn submit_runs_closure() {
    // A single submitted closure must execute exactly once on the worker.
    let worker = EncoderWorker::spawn();
    let hits = Arc::new(AtomicU32::new(0));
    let (finished_tx, finished_rx) = std::sync::mpsc::channel();

    // Build the task up front: it bumps the shared counter, then signals
    // the test thread that it has run.
    let task = {
        let hits = Arc::clone(&hits);
        move || {
            hits.fetch_add(1, Ordering::SeqCst);
            finished_tx.send(()).ok();
        }
    };
    worker.submit(task).expect("submit");

    // Block until the task reports in, then check its side effect.
    finished_rx
        .recv()
        .expect("worker did not signal completion");
    assert_eq!(hits.load(Ordering::SeqCst), 1);
}
#[test]
fn submissions_run_in_order() {
    // Five tasks each record their index; the recorded sequence must match
    // the submission sequence.
    let worker = EncoderWorker::spawn();
    let order = Arc::new(std::sync::Mutex::new(Vec::new()));

    // Submit every task first, keeping one completion receiver per task.
    let signals: Vec<_> = (0..5)
        .map(|i| {
            let log = Arc::clone(&order);
            let (tx, rx) = std::sync::mpsc::channel();
            worker
                .submit(move || {
                    log.lock().expect("lock").push(i);
                    tx.send(()).ok();
                })
                .expect("submit");
            rx
        })
        .collect();

    // Wait for each task to signal that it has finished.
    signals.into_iter().for_each(|rx| {
        rx.recv().expect("worker panicked");
    });

    let observed = order.lock().expect("lock").clone();
    assert_eq!(
        observed,
        vec![0, 1, 2, 3, 4],
        "tasks ran out of order: {:?}",
        observed
    );
}
#[test]
fn shutdown_waits_for_in_flight_work() {
    // Queue three slow tasks, then shut down: every task must have finished
    // by the time `shutdown` returns.
    let mut worker = EncoderWorker::spawn();
    let completed = Arc::new(AtomicU32::new(0));
    let pause = std::time::Duration::from_millis(10);

    (0..3).for_each(|_| {
        let completed = Arc::clone(&completed);
        worker
            .submit(move || {
                std::thread::sleep(pause);
                completed.fetch_add(1, Ordering::SeqCst);
            })
            .expect("submit");
    });

    worker.shutdown();
    assert_eq!(completed.load(Ordering::SeqCst), 3);
}
#[test]
fn submit_after_shutdown_errors() {
    // Once the worker has been shut down, new submissions must be rejected.
    let mut worker = EncoderWorker::spawn();
    worker.shutdown();

    let outcome = worker.submit(|| {});
    assert!(outcome.is_err());
}
#[cfg(target_vendor = "apple")]
#[test]
fn worker_can_create_metal_encoder_and_commit() {
    // Creates a command encoder on the worker thread and commits it,
    // reporting the outcome back to the test thread over a channel.
    let device_arc = Arc::new(crate::MlxDevice::new().expect("MlxDevice"));
    let worker = EncoderWorker::spawn();
    let (done_tx, done_rx) = std::sync::mpsc::channel::<Result<(), String>>();

    let worker_device = Arc::clone(&device_arc);
    worker
        .submit(move || {
            // Errors are stringified so they can cross the channel.
            let encode_and_commit = || -> Result<(), String> {
                let mut enc = worker_device
                    .command_encoder()
                    .map_err(|e| format!("enc create: {e}"))?;
                enc.commit_and_wait()
                    .map_err(|e| format!("commit_and_wait: {e}"))?;
                Ok(())
            };
            done_tx.send(encode_and_commit()).ok();
        })
        .expect("submit");

    let outcome = done_rx.recv().expect("worker died");
    assert!(outcome.is_ok(), "worker Metal encoder failed: {:?}", outcome);
}
#[cfg(target_vendor = "apple")]
#[test]
fn worker_can_dispatch_real_kernel_zero_buffer() {
    use crate::ops::moe_dispatch::moe_zero_buffer_encode;
    use crate::DType;

    // Number of f32 elements in the buffer under test.
    const N: usize = 1024;

    let device_arc = Arc::new(crate::MlxDevice::new().expect("MlxDevice"));
    let worker = EncoderWorker::spawn();

    // Fill the buffer with ones so that a successful zeroing is observable.
    let mut buf = device_arc
        .alloc_buffer(N * 4, DType::F32, vec![N])
        .expect("alloc");
    buf.as_mut_slice::<f32>()
        .expect("init slice")
        .iter_mut()
        .for_each(|v| *v = 1.0);
    let buf_arc = Arc::new(std::sync::Mutex::new(buf));

    let (done_tx, done_rx) = std::sync::mpsc::channel::<Result<(), String>>();
    let task = {
        let device = Arc::clone(&device_arc);
        let shared_buf = Arc::clone(&buf_arc);
        move || {
            // Encode the zero-buffer call and commit, all on the worker thread.
            let zero_on_worker = || -> Result<(), String> {
                let mut registry = crate::KernelRegistry::new();
                let mut enc = device
                    .command_encoder()
                    .map_err(|e| format!("enc: {e}"))?;
                {
                    // The lock is held only while encoding and is released
                    // before the commit below.
                    let buf_guard = shared_buf.lock().expect("lock");
                    moe_zero_buffer_encode(
                        &mut enc,
                        &mut registry,
                        device.metal_device(),
                        &buf_guard,
                        N,
                    )
                    .map_err(|e| format!("zero_buffer: {e}"))?;
                }
                enc.commit_and_wait()
                    .map_err(|e| format!("commit: {e}"))?;
                Ok(())
            };
            done_tx.send(zero_on_worker()).ok();
        }
    };
    worker.submit(task).expect("submit");

    // First `expect`: the worker replied at all; second: the task succeeded.
    done_rx.recv().expect("worker died").expect("worker error");

    // Every element must now read back as zero.
    let guard = buf_arc.lock().expect("lock");
    let slice = guard.as_slice::<f32>().expect("read");
    slice
        .iter()
        .enumerate()
        .for_each(|(i, &v)| assert_eq!(v, 0.0, "element {i} not zeroed: {v}"));
}
}