Skip to main content

oximedia_core/
work_queue_ws.rs

1//! Chase-Lev work-stealing work queue for multi-threaded media pipelines.
2//!
3//! This module exposes [`WorkQueue`] — a thread-safe, work-stealing task
4//! distributor backed by [`crossbeam_deque`].  Each call to [`WorkQueue::new`]
5//! creates one injector (global push point) and `workers` local steal handles.
6//!
7//! # Design
8//!
9//! ```text
10//!   Producer         Injector          Worker 0 deque
11//!   ────────►  push  ──────► steal ───►  pop / steal
12//!                                         │
13//!                           Worker 1 ─────┘  (steals from Worker 0)
14//! ```
15//!
16//! Tasks pushed via [`WorkQueue::push`] land in the global injector queue.
17//! Worker threads call [`WorkQueue::steal`] which first drains the injector,
18//! then falls back to stealing from sibling workers.  [`WorkQueue::len`]
19//! returns an approximate total count.
20//!
21//! # Examples
22//!
23//! ```
24//! use oximedia_core::work_queue_ws::WorkQueue;
25//! use std::sync::Arc;
26//! use std::sync::atomic::{AtomicUsize, Ordering};
27//!
28//! let wq: WorkQueue<u32> = WorkQueue::new(2);
29//! for i in 0..10_u32 {
30//!     wq.push(i);
31//! }
32//! // Any thread can steal.
33//! let _item = wq.steal();
34//! assert!(wq.len() <= 10);
35//! ```
36
37use crossbeam_deque::{Injector, Steal, Stealer, Worker};
38use std::sync::{Arc, Mutex};
39
40// ─────────────────────────────────────────────────────────────────────────────
41// WorkQueue
42// ─────────────────────────────────────────────────────────────────────────────
43
44/// Inner shared state of a [`WorkQueue`].
45struct Inner<T> {
46    /// The global injection point; any thread may push here.
47    injector: Injector<T>,
48    /// One stealer handle per logical worker (cloned from the worker deques).
49    stealers: Vec<Stealer<T>>,
50    /// One worker deque per logical worker (protected behind a mutex so that
51    /// steal() can borrow a deque without requiring the caller to own a slot).
52    workers: Vec<Mutex<Worker<T>>>,
53    /// Approximate item count (incremented on push, decremented on steal).
54    len: std::sync::atomic::AtomicIsize,
55}
56
57/// A work-stealing work queue for distributing tasks across multiple workers.
58///
59/// `WorkQueue<T>` is `Clone` — all clones share the same underlying state,
60/// so tasks pushed from one clone are visible to all others.
61///
62/// # Thread safety
63///
64/// `WorkQueue<T>` is `Send + Sync` when `T: Send`.  Multiple threads may
65/// call [`push`](WorkQueue::push) and [`steal`](WorkQueue::steal)
66/// concurrently without external synchronisation.
67///
68/// # Examples
69///
70/// ```
71/// use oximedia_core::work_queue_ws::WorkQueue;
72/// use std::thread;
73/// use std::sync::Arc;
74/// use std::sync::atomic::{AtomicUsize, Ordering};
75///
76/// let wq = WorkQueue::<u32>::new(4);
77/// for i in 0..100_u32 {
78///     wq.push(i);
79/// }
80///
81/// let total = Arc::new(AtomicUsize::new(0));
82/// let mut handles = Vec::new();
83///
84/// for _ in 0..4 {
85///     let wq2 = wq.clone();
86///     let count = Arc::clone(&total);
87///     handles.push(thread::spawn(move || {
88///         while let Some(_task) = wq2.steal() {
89///             count.fetch_add(1, Ordering::Relaxed);
90///         }
91///     }));
92/// }
93/// for h in handles { h.join().expect("thread panicked"); }
94/// assert_eq!(total.load(Ordering::Relaxed), 100);
95/// ```
96#[derive(Clone)]
97pub struct WorkQueue<T: Send + 'static> {
98    inner: Arc<Inner<T>>,
99}
100
101impl<T: Send + 'static> WorkQueue<T> {
102    /// Creates a new `WorkQueue` with `workers` local deques.
103    ///
104    /// `workers` controls the number of distinct steal handles.  A value of
105    /// `0` is clamped to `1`.
106    ///
107    /// # Examples
108    ///
109    /// ```
110    /// use oximedia_core::work_queue_ws::WorkQueue;
111    ///
112    /// let wq = WorkQueue::<i32>::new(4);
113    /// assert_eq!(wq.len(), 0);
114    /// ```
115    #[must_use]
116    pub fn new(workers: usize) -> Self {
117        let num = workers.max(1);
118        let injector = Injector::new();
119        let mut worker_deques = Vec::with_capacity(num);
120        let mut stealers = Vec::with_capacity(num);
121
122        for _ in 0..num {
123            let w: Worker<T> = Worker::new_fifo();
124            stealers.push(w.stealer());
125            worker_deques.push(Mutex::new(w));
126        }
127
128        Self {
129            inner: Arc::new(Inner {
130                injector,
131                stealers,
132                workers: worker_deques,
133                len: std::sync::atomic::AtomicIsize::new(0),
134            }),
135        }
136    }
137
138    /// Pushes a task into the global injection queue.
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// use oximedia_core::work_queue_ws::WorkQueue;
144    ///
145    /// let wq = WorkQueue::<u32>::new(2);
146    /// wq.push(42_u32);
147    /// assert_eq!(wq.len(), 1);
148    /// ```
149    pub fn push(&self, task: T) {
150        self.inner.injector.push(task);
151        self.inner
152            .len
153            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
154    }
155
156    /// Attempts to steal a task from any available source.
157    ///
158    /// The implementation first drains the global injector into a local worker
159    /// deque (slot 0), then tries to pop from each worker in round-robin order,
160    /// retrying on contention.
161    ///
162    /// Returns `None` when all queues appear empty.
163    ///
164    /// # Examples
165    ///
166    /// ```
167    /// use oximedia_core::work_queue_ws::WorkQueue;
168    ///
169    /// let wq = WorkQueue::<u32>::new(2);
170    /// wq.push(1_u32);
171    /// wq.push(2_u32);
172    /// let t1 = wq.steal();
173    /// let t2 = wq.steal();
174    /// assert!(t1.is_some());
175    /// assert!(t2.is_some());
176    /// ```
177    pub fn steal(&self) -> Option<T> {
178        // Try draining the injector into worker 0 first.
179        if let Ok(guard) = self.inner.workers[0].lock() {
180            loop {
181                match self.inner.injector.steal_batch_and_pop(&guard) {
182                    Steal::Success(v) => {
183                        self.inner
184                            .len
185                            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
186                        return Some(v);
187                    }
188                    Steal::Retry => continue,
189                    Steal::Empty => break,
190                }
191            }
192        }
193
194        // Try popping from each worker deque in turn.
195        for w_mutex in &self.inner.workers {
196            if let Ok(guard) = w_mutex.lock() {
197                if let Some(item) = guard.pop() {
198                    self.inner
199                        .len
200                        .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
201                    return Some(item);
202                }
203            }
204        }
205
206        // Fall back to stealing via stealer handles (cross-thread steal).
207        for stealer in &self.inner.stealers {
208            loop {
209                match stealer.steal() {
210                    Steal::Success(v) => {
211                        self.inner
212                            .len
213                            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
214                        return Some(v);
215                    }
216                    Steal::Retry => continue,
217                    Steal::Empty => break,
218                }
219            }
220        }
221
222        None
223    }
224
225    /// Returns the approximate number of tasks currently in the queue.
226    ///
227    /// This value may be slightly stale due to concurrent operations.  It
228    /// saturates at zero rather than going negative.
229    ///
230    /// # Examples
231    ///
232    /// ```
233    /// use oximedia_core::work_queue_ws::WorkQueue;
234    ///
235    /// let wq = WorkQueue::<u32>::new(2);
236    /// wq.push(1_u32);
237    /// wq.push(2_u32);
238    /// assert_eq!(wq.len(), 2);
239    /// ```
240    #[must_use]
241    pub fn len(&self) -> usize {
242        let v = self.inner.len.load(std::sync::atomic::Ordering::Relaxed);
243        v.max(0) as usize
244    }
245
246    /// Returns `true` if the queue appears empty.
247    #[must_use]
248    pub fn is_empty(&self) -> bool {
249        self.len() == 0
250    }
251}
252
253// ─────────────────────────────────────────────────────────────────────────────
254// Tests
255// ─────────────────────────────────────────────────────────────────────────────
256
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    /// Two queued tasks can both be stolen, after which the length is zero.
    #[test]
    fn push_and_steal_basic() {
        let queue = WorkQueue::<u32>::new(1);
        queue.push(10_u32);
        queue.push(20_u32);
        let first = queue.steal();
        let second = queue.steal();
        assert!(first.is_some());
        assert!(second.is_some());
        assert_eq!(queue.len(), 0);
    }

    /// Stealing from a freshly created queue yields nothing.
    #[test]
    fn steal_empty_returns_none() {
        let queue = WorkQueue::<u32>::new(2);
        assert!(queue.steal().is_none());
    }

    /// `len` follows pushes and steals on a single thread.
    #[test]
    fn len_tracks_count() {
        let queue = WorkQueue::<u32>::new(2);
        assert_eq!(queue.len(), 0);
        queue.push(1_u32);
        assert_eq!(queue.len(), 1);
        queue.push(2_u32);
        assert_eq!(queue.len(), 2);
        let _ = queue.steal();
        assert_eq!(queue.len(), 1);
    }

    /// `is_empty` flips once a task has been pushed.
    #[test]
    fn is_empty_basic() {
        let queue = WorkQueue::<u32>::new(2);
        assert!(queue.is_empty());
        queue.push(1_u32);
        assert!(!queue.is_empty());
    }

    /// A task pushed through one handle is visible through its clone.
    #[test]
    fn clone_shares_state() {
        let original = WorkQueue::<u32>::new(2);
        let copy = original.clone();
        original.push(99_u32);
        assert_eq!(copy.steal(), Some(99_u32));
    }

    /// Stress test: 4 consumer threads drain 10 000 pre-loaded tasks.
    #[test]
    fn threaded_stress_10000_tasks() {
        const TASKS: u32 = 10_000;
        const WORKERS: usize = 4;

        let queue = WorkQueue::<u32>::new(WORKERS);
        (0..TASKS).for_each(|task| queue.push(task));

        let stolen_count = Arc::new(AtomicUsize::new(0));
        let threads: Vec<_> = (0..WORKERS)
            .map(|_| {
                let local_queue = queue.clone();
                let shared_total = Arc::clone(&stolen_count);
                thread::spawn(move || {
                    let mut taken = 0usize;
                    // Give up once more than 200 steals in a row come back empty.
                    let mut misses = 0usize;
                    while misses <= 200 {
                        if local_queue.steal().is_some() {
                            taken += 1;
                            misses = 0;
                        } else {
                            misses += 1;
                            if misses <= 200 {
                                std::hint::spin_loop();
                            }
                        }
                    }
                    shared_total.fetch_add(taken, Ordering::Relaxed);
                })
            })
            .collect();

        for handle in threads {
            handle.join().expect("worker thread panicked");
        }

        let total = stolen_count.load(Ordering::Relaxed);
        assert_eq!(
            total, TASKS as usize,
            "expected all {TASKS} tasks to be consumed, got {total}"
        );
    }

    /// Several producers fill the queue, then several consumers drain it.
    #[test]
    fn multi_producer_multi_consumer() {
        const PER_PRODUCER: usize = 1_000;
        const PRODUCERS: usize = 4;
        const CONSUMERS: usize = 4;
        const TOTAL: usize = PER_PRODUCER * PRODUCERS;

        let queue = WorkQueue::<usize>::new(CONSUMERS);
        let consumed = Arc::new(AtomicUsize::new(0));

        // Producers run to completion before any consumer starts.
        let producers: Vec<_> = (0..PRODUCERS)
            .map(|producer| {
                let sink = queue.clone();
                thread::spawn(move || {
                    let base = producer * PER_PRODUCER;
                    for offset in 0..PER_PRODUCER {
                        sink.push(base + offset);
                    }
                })
            })
            .collect();
        for handle in producers {
            handle.join().expect("producer panicked");
        }

        // Each consumer stops once more than 500 steals in a row come back empty.
        let consumers: Vec<_> = (0..CONSUMERS)
            .map(|_| {
                let source = queue.clone();
                let tally = Arc::clone(&consumed);
                thread::spawn(move || {
                    let mut misses = 0;
                    while misses <= 500 {
                        if source.steal().is_some() {
                            tally.fetch_add(1, Ordering::Relaxed);
                            misses = 0;
                        } else {
                            misses += 1;
                        }
                    }
                })
            })
            .collect();
        for handle in consumers {
            handle.join().expect("consumer panicked");
        }

        assert_eq!(consumed.load(Ordering::Relaxed), TOTAL);
    }
}