//! Background CPU work queue.
//!
//! A single-threaded worker that accepts closures via an mpsc channel and
//! executes them in order. Designed for offloading CPU-bound work (checkpoints,
//! evaluation, file I/O) off the GPU training thread.

use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::{self, Sender};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
/// A single-threaded background worker that executes closures in order.
///
/// Tasks are queued over an mpsc channel and run FIFO on one dedicated
/// thread, so a long task delays — but never interleaves with — later ones.
///
/// ```ignore
/// let worker = CpuWorker::new();
/// worker.submit(|| {
///     save_checkpoint(&snapshot, "model.pt").unwrap();
/// });
/// // GPU training continues immediately
/// worker.finish(); // blocks until all queued work completes
/// ```
#[derive(Debug)]
pub struct CpuWorker {
    /// Sending half of the task channel; `None` once `finish` has run.
    tx: Option<Sender<Box<dyn FnOnce() + Send>>>,
    /// Handle used to join the worker thread; `None` once joined.
    handle: Option<JoinHandle<()>>,
    /// Tasks submitted but not yet completed (queued + currently running).
    ///
    /// A counter rather than a "busy" bool: it is incremented in `submit`
    /// *before* the task is sent, so `is_idle` can never report idle while
    /// work is queued but not yet picked up by the worker thread.
    pending: Arc<AtomicUsize>,
}

impl CpuWorker {
    /// Spawn the background worker thread.
    pub fn new() -> Self {
        let (tx, rx) = mpsc::channel::<Box<dyn FnOnce() + Send>>();
        let pending = Arc::new(AtomicUsize::new(0));
        let worker_pending = pending.clone();

        let handle = thread::spawn(move || {
            // The loop ends once every sender is dropped (see `finish`).
            for task in rx {
                // Contain panics so one failing task cannot kill the queue
                // (and leave `pending` stuck nonzero forever). The boxed
                // closure is consumed here and never observed after a panic,
                // so AssertUnwindSafe is sound.
                let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task));
                worker_pending.fetch_sub(1, Ordering::SeqCst);
            }
        });

        CpuWorker {
            tx: Some(tx),
            handle: Some(handle),
            pending,
        }
    }

    /// Submit a closure to run on the background thread.
    ///
    /// Silently drops the task if `finish` has already been called.
    pub fn submit<F: FnOnce() + Send + 'static>(&self, f: F) {
        if let Some(ref tx) = self.tx {
            // Count the task before sending so `is_idle` is false from the
            // moment of submission, not merely from when the worker thread
            // happens to pick the task up.
            self.pending.fetch_add(1, Ordering::SeqCst);
            if tx.send(Box::new(f)).is_err() {
                // Receiver gone (worker thread died); roll the count back
                // so `is_idle` stays accurate.
                self.pending.fetch_sub(1, Ordering::SeqCst);
            }
        }
    }

    /// Check whether the worker is idle (no task queued or running).
    ///
    /// Useful for skip-if-busy semantics: only submit a new checkpoint
    /// if the previous one has finished.
    pub fn is_idle(&self) -> bool {
        self.pending.load(Ordering::SeqCst) == 0
    }

    /// Drop the sender and join the worker thread, blocking until all
    /// queued tasks have completed. Idempotent: later calls are no-ops.
    pub fn finish(&mut self) {
        self.tx.take(); // drop sender → rx iterator ends
        if let Some(h) = self.handle.take() {
            let _ = h.join();
        }
    }
}
75impl Default for CpuWorker {
76    fn default() -> Self {
77        Self::new()
78    }
79}
// Joining in `drop` guarantees that queued tasks complete before the
// worker is discarded, so work cannot be silently lost when a `CpuWorker`
// goes out of scope (e.g. on an early return from the training loop).
// `finish` is idempotent, so an explicit `finish()` followed by drop is fine.
impl Drop for CpuWorker {
    fn drop(&mut self) {
        self.finish();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// A submitted closure runs before `finish` returns.
    #[test]
    fn submit_and_finish() {
        let ran = Arc::new(AtomicBool::new(false));
        let worker_ran = Arc::clone(&ran);

        let mut worker = CpuWorker::new();
        worker.submit(move || worker_ran.store(true, Ordering::Release));
        worker.finish();

        assert!(ran.load(Ordering::Acquire), "closure should have run");
    }

    /// Tasks run strictly in submission order (single worker thread, FIFO channel).
    #[test]
    fn tasks_execute_in_order() {
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));

        let mut worker = CpuWorker::new();
        for id in 0..5 {
            let sink = Arc::clone(&order);
            worker.submit(move || sink.lock().unwrap().push(id));
        }
        worker.finish();

        let seen = order.lock().unwrap().clone();
        assert_eq!(seen, (0..5).collect::<Vec<_>>());
    }

    /// Dropping the worker joins the thread, so queued work is never lost.
    #[test]
    fn drop_joins_thread() {
        let hits = Arc::new(AtomicUsize::new(0));
        let worker_hits = Arc::clone(&hits);

        {
            let worker = CpuWorker::new();
            worker.submit(move || {
                worker_hits.fetch_add(1, Ordering::Release);
            });
        } // `Drop` runs `finish`, blocking until the task completes.

        assert_eq!(hits.load(Ordering::Acquire), 1);
    }
}