Skip to main content

zsh/extensions/
worker.rs

1//! Worker pool for zshrs — persistent threads for background work.
2//!
3//! **zshrs-original infrastructure — no C source counterpart.** This
4//! module does NOT port a corresponding `Src/*.c` file. C zsh's
5//! background-work strategy is `fork(2)`: every completion run,
6//! process substitution, or command substitution is a child process
7//! (see `zfork()` in Src/exec.c and the `forklevel` machinery
8//! Src/init.c uses to track depth). zshrs replaces that pattern with
9//! a fixed-size thread pool + crossbeam channel dispatch.
10//!
11//! Replacement rationale (vs the fork() path the C source takes):
12//!   - No fork overhead (50-500μs per fork on macOS)
13//!   - No address space duplication
14//!   - Warm thread stacks ready to go
15//!   - Backpressure via bounded channel
16//!
17//! Pool size = available_parallelism() clamped to [2, 18].
18//! Channel capacity = 4 × pool size (bounded backpressure).
19//!
20//! Audit fixes applied:
21//!   1. crossbeam-channel replaces Arc<Mutex<mpsc::Receiver>> — no mutex contention
22//!   2. Bounded channel (4×N) provides backpressure
23//!   3. catch_unwind wraps every task — panics logged, worker stays alive
24//!   4. tracing spans on submit + worker loop
25//!   5. Queue depth metric on submit
26//!   6. Task cancellation via AtomicBool flag
27
28use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
29use std::sync::Arc;
30use std::thread;
31
/// A unit of work the pool can execute.
///
/// Boxed so closures of different concrete types can travel through
/// the same channel; `Send + 'static` because every task is handed to
/// a worker thread and invoked there exactly once.
type Task = Box<dyn FnOnce() + Send + 'static>;
34
/// Fixed-size thread pool with bounded FIFO task queue.
///
/// zshrs-original — replaces C zsh's per-task `fork()` + `wait()`
/// pattern (Src/exec.c `zfork()` / Src/jobs.c child management) with
/// a persistent thread pool. Uses crossbeam-channel for lock-free
/// multi-consumer dispatch — each worker calls `recv()` directly,
/// no mutex.
pub struct WorkerPool {
    /// One entry per spawned worker thread.
    workers: Vec<Worker>,
    /// Producer side of the task channel. Held in an `Option` so that
    /// `Drop` can `take()` it; dropping the sender closes the channel
    /// and makes the workers' `recv()` calls return `Err`.
    sender: Option<crossbeam_channel::Sender<Task>>,
    /// Number of worker threads in the pool.
    size: usize,
    /// Shared cancellation flag — while set, workers dequeue tasks but
    /// drop them without running them.
    cancelled: Arc<AtomicBool>,
    /// Queue depth — incremented by `submit` before the task is sent,
    /// decremented by a worker as soon as it dequeues a task (whether
    /// the task is then run or skipped).
    queued: Arc<AtomicUsize>,
    /// Total tasks that were actually run (including ones that
    /// panicked) across all workers; tasks skipped because of
    /// cancellation are not counted.
    completed: Arc<AtomicUsize>,
}
53
/// Bookkeeping record for one spawned worker thread.
struct Worker {
    /// Index of the worker within the pool; the same value is used in
    /// the thread's name. The field is stored but never read back in
    /// this file, hence the lint allowance.
    #[allow(dead_code)]
    id: usize,
    /// Join handle of the worker thread. Wrapped in `Option` so that
    /// `Drop for WorkerPool` can move it out with `take()`.
    handle: Option<thread::JoinHandle<()>>,
}
59
60impl WorkerPool {
61    /// Create a pool with `size` worker threads and bounded channel.
62    /// Channel capacity = 4 × size (provides backpressure without
63    /// starving).
64    /// zshrs-original — no C counterpart. Replaces the
65    /// "spawn-on-demand" semantics of `zfork()` (Src/exec.c) with
66    /// pre-spawned threads ready to receive work over a bounded
67    /// channel.
68    pub fn new(size: usize) -> Self {
69        let capacity = size * 4;
70        let (sender, receiver) = crossbeam_channel::bounded::<Task>(capacity);
71        let cancelled = Arc::new(AtomicBool::new(false));
72        let queued = Arc::new(AtomicUsize::new(0));
73        let completed = Arc::new(AtomicUsize::new(0));
74
75        let mut workers = Vec::with_capacity(size);
76        for id in 0..size {
77            let rx = receiver.clone();
78            let cancelled = Arc::clone(&cancelled);
79            let queued = Arc::clone(&queued);
80            let completed = Arc::clone(&completed);
81
82            let handle = thread::Builder::new()
83                .name(format!("zshrs-worker-{}", id))
84                .spawn(move || {
85                    loop {
86                        let task = match rx.recv() {
87                            Ok(task) => task,
88                            Err(_) => break, // channel closed → shutdown
89                        };
90
91                        queued.fetch_sub(1, Ordering::Relaxed);
92
93                        // Check cancellation before running
94                        if cancelled.load(Ordering::Relaxed) {
95                            continue; // drain without executing
96                        }
97
98                        // catch_unwind keeps the worker alive if a task panics
99                        if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task))
100                        {
101                            let msg = if let Some(s) = e.downcast_ref::<&str>() {
102                                (*s).to_string()
103                            } else if let Some(s) = e.downcast_ref::<String>() {
104                                s.clone()
105                            } else {
106                                "unknown panic".to_string()
107                            };
108                            tracing::error!(
109                                worker = id,
110                                panic = %msg,
111                                "worker task panicked"
112                            );
113                        }
114
115                        completed.fetch_add(1, Ordering::Relaxed);
116                    }
117                    tracing::debug!(worker = id, "worker thread exiting");
118                })
119                .expect("failed to spawn worker thread");
120
121            workers.push(Worker {
122                id,
123                handle: Some(handle),
124            });
125        }
126
127        tracing::info!(
128            pool_size = size,
129            channel_capacity = capacity,
130            "worker pool started"
131        );
132
133        WorkerPool {
134            workers,
135            sender: Some(sender),
136            size,
137            cancelled,
138            queued,
139            completed,
140        }
141    }
142
143    /// Create a pool sized to the machine's parallelism, clamped to
144    /// `[2, 18]`.
145    /// zshrs-original — no C counterpart. C zsh has no concept of a
146    /// "pool size" because it forks on demand (one child per
147    /// background task, see Src/jobs.c).
148    pub fn default_size() -> Self {
149        let cpus = thread::available_parallelism()
150            .map(|n| n.get())
151            .unwrap_or(4);
152        Self::new(cpus.clamp(2, 18))
153    }
154
155    /// Submit a task to the pool. Blocks if the queue is full
156    /// (backpressure). Panics if the pool has been shut down.
157    /// zshrs-original — replaces the `fork() + execve()` /
158    /// `fork() + run-shell-fn` dispatch pairs in Src/exec.c.
159    pub fn submit<F>(&self, f: F)
160    where
161        F: FnOnce() + Send + 'static,
162    {
163        let depth = self.queued.fetch_add(1, Ordering::Relaxed) + 1;
164        if depth > self.size * 2 {
165            tracing::debug!(queue_depth = depth, "worker pool queue building up");
166        }
167        self.sender
168            .as_ref()
169            .expect("pool shut down")
170            .send(Box::new(f))
171            .expect("all workers dead");
172    }
173
174    /// Submit a task and get a receiver for its result.
175    /// zshrs-original — closest C analog is the pipe-based
176    /// command-substitution result capture in Src/exec.c
177    /// (`getoutput()` reading the child's stdout pipe), but using a
178    /// typed Rust channel sidesteps the marshalling.
179    pub fn submit_with_result<F, R>(&self, f: F) -> crossbeam_channel::Receiver<R>
180    where
181        F: FnOnce() -> R + Send + 'static,
182        R: Send + 'static,
183    {
184        let (tx, rx) = crossbeam_channel::bounded(1);
185        self.submit(move || {
186            let result = f();
187            let _ = tx.send(result);
188        });
189        rx
190    }
191
192    /// Signal all workers to drop pending tasks.
193    /// Already-running tasks will finish, but queued tasks are
194    /// skipped. Reset with `reset_cancel()`.
195    /// zshrs-original — closest C analog is the SIGINT/SIGQUIT
196    /// signal-storm dispatch C zsh fires at its background children
197    /// in Src/signals.c (`killjb()` / `killpg()`), but here we set a
198    /// flag instead of sending a signal across a fork boundary.
199    pub fn cancel(&self) {
200        self.cancelled.store(true, Ordering::Relaxed);
201        tracing::info!("worker pool: cancel requested");
202    }
203
204    /// Clear the cancellation flag — pool resumes normal execution.
205    /// zshrs-original — no C counterpart.
206    pub fn reset_cancel(&self) {
207        self.cancelled.store(false, Ordering::Relaxed);
208    }
209
210    /// Number of worker threads.
211    /// zshrs-original — no C counterpart.
212    pub fn size(&self) -> usize {
213        self.size
214    }
215
216    /// Approximate number of tasks waiting in the queue.
217    /// zshrs-original — no C counterpart; closest equivalent is the
218    /// `jobtab` length walk Src/jobs.c uses for `jobs -l` output.
219    pub fn queue_depth(&self) -> usize {
220        self.queued.load(Ordering::Relaxed)
221    }
222
223    /// Total tasks completed since pool creation.
224    /// zshrs-original — no C counterpart.
225    pub fn completed(&self) -> usize {
226        self.completed.load(Ordering::Relaxed)
227    }
228}
229
230impl Drop for WorkerPool {
231    fn drop(&mut self) {
232        // Signal workers to skip remaining queued tasks
233        self.cancelled.store(true, Ordering::Relaxed);
234        // Drop the sender → channel closes → recv() returns Err → threads exit
235        drop(self.sender.take());
236        // Give workers a brief window to finish their current task.
237        // Don't block indefinitely — the process is exiting.
238        for w in &mut self.workers {
239            if let Some(handle) = w.handle.take() {
240                // Detach the thread — OS cleans up on process exit.
241                // join() would block if a worker is mid-parse on a 500-line
242                // completion function. Not worth the wait on Ctrl-D/exit.
243                drop(handle);
244            }
245        }
246        tracing::info!(
247            tasks_completed = self.completed.load(Ordering::Relaxed),
248            "worker pool shut down"
249        );
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Barrier, Condvar, Mutex};
    use std::time::{Duration, Instant};

    /// Spin-wait helper for tests: poll `counter` until it reaches
    /// `target` or the deadline elapses (then panic). Needed because
    /// dropping the pool does not wait for anything — `Drop` sets
    /// cancelled=true and detaches the workers, so queued tasks are
    /// skipped on drop instead of drained.
    fn wait_for_count(counter: &AtomicUsize, target: usize, max_wait_ms: u64) {
        let deadline = Instant::now() + Duration::from_millis(max_wait_ms);
        while counter.load(Ordering::Relaxed) < target {
            if Instant::now() >= deadline {
                panic!(
                    "wait_for_count timed out: counter={} target={} after {}ms",
                    counter.load(Ordering::Relaxed),
                    target,
                    max_wait_ms
                );
            }
            std::thread::sleep(Duration::from_millis(2));
        }
    }

    #[test]
    fn test_pool_executes_tasks() {
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..100 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        // Drain explicitly — Drop sets cancelled=true and skips queued
        // tasks (intentional for shell exit), so the test can't rely on
        // `drop(pool)` to wait.
        wait_for_count(&counter, 100, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_submit_with_result() {
        let pool = WorkerPool::new(2);
        let rx = pool.submit_with_result(|| 42);
        assert_eq!(rx.recv().unwrap(), 42);
    }

    #[test]
    fn test_default_size() {
        let pool = WorkerPool::default_size();
        assert!(pool.size() >= 2);
        assert!(pool.size() <= 18);
    }

    #[test]
    fn test_panic_does_not_kill_worker() {
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));

        // Submit a task that panics
        pool.submit(|| panic!("intentional test panic"));

        // Submit tasks after the panic — they should still run
        for _ in 0..10 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        wait_for_count(&counter, 10, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 10);
    }

    #[test]
    fn test_cancel_skips_queued_tasks() {
        let pool = WorkerPool::new(1); // single worker to control ordering
        let barrier = Arc::new(Barrier::new(2));
        // Flag the worker raises when it ENTERS the barrier task. Lets
        // the main thread wait until the worker is provably blocked
        // inside the barrier BEFORE calling cancel(). Without this, a
        // pre-empted worker that hasn't yet pulled task #1 would see the
        // cancel flag, skip task #1, and the main thread's barrier.wait()
        // below would deadlock waiting for a second party that never
        // arrives.
        let started = Arc::new(Mutex::new(false));
        let started_cv = Arc::new(Condvar::new());
        let counter = Arc::new(AtomicUsize::new(0));

        let b = Arc::clone(&barrier);
        let started_clone = Arc::clone(&started);
        let cv_clone = Arc::clone(&started_cv);
        pool.submit(move || {
            // Mark "task entered" + notify before blocking.
            *started_clone.lock().unwrap() = true;
            cv_clone.notify_one();
            b.wait();
        });

        // Wait until the worker is provably inside the task (and thus
        // committed to calling b.wait() — no race with cancel below).
        // The 5s timeout is only a safety net.
        {
            let (_guard, wait_result) = started_cv
                .wait_timeout_while(started.lock().unwrap(), Duration::from_secs(5), |entered| {
                    !*entered
                })
                .unwrap();
            if wait_result.timed_out() {
                panic!("worker never started task #1 within 5s — test scaffolding broken");
            }
        }

        // Queue tasks that should be skipped (worker is parked at b.wait()).
        // Cap at channel capacity (size * 4 = 4 for a 1-worker pool) MINUS 1
        // for safety. Submitting more than the channel holds while the
        // worker is blocked deadlocks `submit` itself, since the bounded
        // crossbeam channel back-pressures `send()`. 3 skipped tasks is
        // enough to prove "queued tasks get cancelled" — the count isn't
        // load-bearing.
        for _ in 0..3 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        // Cancel, then unblock the worker — it'll return from b.wait(),
        // loop, see cancelled=true, and drain the 3 queued tasks without
        // executing them.
        pool.cancel();
        barrier.wait();

        // Give the worker time to drain. A fixed pause is used on
        // purpose: polling queue_depth() would race with reset_cancel()
        // below, because the worker decrements the depth *before* it
        // checks the cancel flag.
        std::thread::sleep(Duration::from_millis(50));

        // Queued tasks should have been skipped
        assert_eq!(counter.load(Ordering::Relaxed), 0);

        // Reset and verify pool still works
        pool.reset_cancel();
        let c = Arc::clone(&counter);
        pool.submit(move || {
            c.fetch_add(1, Ordering::Relaxed);
        });
        // Wait for the post-reset task to complete BEFORE drop, since
        // Drop sets cancelled=true again and would skip any
        // not-yet-pulled task.
        wait_for_count(&counter, 1, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_metrics() {
        let pool = WorkerPool::new(2);
        assert_eq!(pool.completed(), 0);
        assert_eq!(pool.queue_depth(), 0);

        for _ in 0..10 {
            pool.submit(|| {});
        }

        // Drop does not wait for queued work, so poll the pool's own
        // completed counter until every submitted task has been run.
        wait_for_count(&pool.completed, 10, 5_000);
        assert_eq!(pool.completed(), 10);
        drop(pool);
    }

    #[test]
    fn test_backpressure_bounded() {
        // Pool of 1 with capacity 4 — once the queue is full, submit()
        // blocks (back-pressure) until the worker drains a slot. With
        // 20 submits + 1 worker the loop therefore throttles itself to
        // the worker's pace; the remaining tasks finish shortly after.
        let pool = WorkerPool::new(1);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..20 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        wait_for_count(&counter, 20, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 20);
    }
}