oximedia_core/work_queue_ws.rs
1//! Chase-Lev work-stealing work queue for multi-threaded media pipelines.
2//!
3//! This module exposes [`WorkQueue`] — a thread-safe, work-stealing task
4//! distributor backed by [`crossbeam_deque`]. Each call to [`WorkQueue::new`]
5//! creates one injector (global push point) and `workers` local steal handles.
6//!
7//! # Design
8//!
//! ```text
//! Producer       Injector           Worker 0 deque
//! ────────► push ──────► steal ───► pop / steal
//!                                    │
//!                      Worker 1 ─────┘  (steals from Worker 0)
//! ```
15//!
16//! Tasks pushed via [`WorkQueue::push`] land in the global injector queue.
17//! Worker threads call [`WorkQueue::steal`] which first drains the injector,
18//! then falls back to stealing from sibling workers. [`WorkQueue::len`]
19//! returns an approximate total count.
20//!
21//! # Examples
22//!
23//! ```
24//! use oximedia_core::work_queue_ws::WorkQueue;
25//! use std::sync::Arc;
26//! use std::sync::atomic::{AtomicUsize, Ordering};
27//!
28//! let wq: WorkQueue<u32> = WorkQueue::new(2);
29//! for i in 0..10_u32 {
30//! wq.push(i);
31//! }
32//! // Any thread can steal.
33//! let _item = wq.steal();
34//! assert!(wq.len() <= 10);
35//! ```
36
37use crossbeam_deque::{Injector, Steal, Stealer, Worker};
38use std::sync::{Arc, Mutex};
39
40// ─────────────────────────────────────────────────────────────────────────────
41// WorkQueue
42// ─────────────────────────────────────────────────────────────────────────────
43
44/// Inner shared state of a [`WorkQueue`].
45struct Inner<T> {
46 /// The global injection point; any thread may push here.
47 injector: Injector<T>,
48 /// One stealer handle per logical worker (cloned from the worker deques).
49 stealers: Vec<Stealer<T>>,
50 /// One worker deque per logical worker (protected behind a mutex so that
51 /// steal() can borrow a deque without requiring the caller to own a slot).
52 workers: Vec<Mutex<Worker<T>>>,
53 /// Approximate item count (incremented on push, decremented on steal).
54 len: std::sync::atomic::AtomicIsize,
55}
56
57/// A work-stealing work queue for distributing tasks across multiple workers.
58///
59/// `WorkQueue<T>` is `Clone` — all clones share the same underlying state,
60/// so tasks pushed from one clone are visible to all others.
61///
62/// # Thread safety
63///
64/// `WorkQueue<T>` is `Send + Sync` when `T: Send`. Multiple threads may
65/// call [`push`](WorkQueue::push) and [`steal`](WorkQueue::steal)
66/// concurrently without external synchronisation.
67///
68/// # Examples
69///
70/// ```
71/// use oximedia_core::work_queue_ws::WorkQueue;
72/// use std::thread;
73/// use std::sync::Arc;
74/// use std::sync::atomic::{AtomicUsize, Ordering};
75///
76/// let wq = WorkQueue::<u32>::new(4);
77/// for i in 0..100_u32 {
78/// wq.push(i);
79/// }
80///
81/// let total = Arc::new(AtomicUsize::new(0));
82/// let mut handles = Vec::new();
83///
84/// for _ in 0..4 {
85/// let wq2 = wq.clone();
86/// let count = Arc::clone(&total);
87/// handles.push(thread::spawn(move || {
88/// while let Some(_task) = wq2.steal() {
89/// count.fetch_add(1, Ordering::Relaxed);
90/// }
91/// }));
92/// }
93/// for h in handles { h.join().expect("thread panicked"); }
94/// assert_eq!(total.load(Ordering::Relaxed), 100);
95/// ```
96#[derive(Clone)]
97pub struct WorkQueue<T: Send + 'static> {
98 inner: Arc<Inner<T>>,
99}
100
101impl<T: Send + 'static> WorkQueue<T> {
102 /// Creates a new `WorkQueue` with `workers` local deques.
103 ///
104 /// `workers` controls the number of distinct steal handles. A value of
105 /// `0` is clamped to `1`.
106 ///
107 /// # Examples
108 ///
109 /// ```
110 /// use oximedia_core::work_queue_ws::WorkQueue;
111 ///
112 /// let wq = WorkQueue::<i32>::new(4);
113 /// assert_eq!(wq.len(), 0);
114 /// ```
115 #[must_use]
116 pub fn new(workers: usize) -> Self {
117 let num = workers.max(1);
118 let injector = Injector::new();
119 let mut worker_deques = Vec::with_capacity(num);
120 let mut stealers = Vec::with_capacity(num);
121
122 for _ in 0..num {
123 let w: Worker<T> = Worker::new_fifo();
124 stealers.push(w.stealer());
125 worker_deques.push(Mutex::new(w));
126 }
127
128 Self {
129 inner: Arc::new(Inner {
130 injector,
131 stealers,
132 workers: worker_deques,
133 len: std::sync::atomic::AtomicIsize::new(0),
134 }),
135 }
136 }
137
138 /// Pushes a task into the global injection queue.
139 ///
140 /// # Examples
141 ///
142 /// ```
143 /// use oximedia_core::work_queue_ws::WorkQueue;
144 ///
145 /// let wq = WorkQueue::<u32>::new(2);
146 /// wq.push(42_u32);
147 /// assert_eq!(wq.len(), 1);
148 /// ```
149 pub fn push(&self, task: T) {
150 self.inner.injector.push(task);
151 self.inner
152 .len
153 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
154 }
155
156 /// Attempts to steal a task from any available source.
157 ///
158 /// The implementation first drains the global injector into a local worker
159 /// deque (slot 0), then tries to pop from each worker in round-robin order,
160 /// retrying on contention.
161 ///
162 /// Returns `None` when all queues appear empty.
163 ///
164 /// # Examples
165 ///
166 /// ```
167 /// use oximedia_core::work_queue_ws::WorkQueue;
168 ///
169 /// let wq = WorkQueue::<u32>::new(2);
170 /// wq.push(1_u32);
171 /// wq.push(2_u32);
172 /// let t1 = wq.steal();
173 /// let t2 = wq.steal();
174 /// assert!(t1.is_some());
175 /// assert!(t2.is_some());
176 /// ```
177 pub fn steal(&self) -> Option<T> {
178 // Try draining the injector into worker 0 first.
179 if let Ok(guard) = self.inner.workers[0].lock() {
180 loop {
181 match self.inner.injector.steal_batch_and_pop(&guard) {
182 Steal::Success(v) => {
183 self.inner
184 .len
185 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
186 return Some(v);
187 }
188 Steal::Retry => continue,
189 Steal::Empty => break,
190 }
191 }
192 }
193
194 // Try popping from each worker deque in turn.
195 for w_mutex in &self.inner.workers {
196 if let Ok(guard) = w_mutex.lock() {
197 if let Some(item) = guard.pop() {
198 self.inner
199 .len
200 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
201 return Some(item);
202 }
203 }
204 }
205
206 // Fall back to stealing via stealer handles (cross-thread steal).
207 for stealer in &self.inner.stealers {
208 loop {
209 match stealer.steal() {
210 Steal::Success(v) => {
211 self.inner
212 .len
213 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
214 return Some(v);
215 }
216 Steal::Retry => continue,
217 Steal::Empty => break,
218 }
219 }
220 }
221
222 None
223 }
224
225 /// Returns the approximate number of tasks currently in the queue.
226 ///
227 /// This value may be slightly stale due to concurrent operations. It
228 /// saturates at zero rather than going negative.
229 ///
230 /// # Examples
231 ///
232 /// ```
233 /// use oximedia_core::work_queue_ws::WorkQueue;
234 ///
235 /// let wq = WorkQueue::<u32>::new(2);
236 /// wq.push(1_u32);
237 /// wq.push(2_u32);
238 /// assert_eq!(wq.len(), 2);
239 /// ```
240 #[must_use]
241 pub fn len(&self) -> usize {
242 let v = self.inner.len.load(std::sync::atomic::Ordering::Relaxed);
243 v.max(0) as usize
244 }
245
246 /// Returns `true` if the queue appears empty.
247 #[must_use]
248 pub fn is_empty(&self) -> bool {
249 self.len() == 0
250 }
251}
252
253// ─────────────────────────────────────────────────────────────────────────────
254// Tests
255// ─────────────────────────────────────────────────────────────────────────────
256
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    /// Two pushed tasks can both be taken back, leaving the queue empty.
    #[test]
    fn push_and_steal_basic() {
        let queue = WorkQueue::<u32>::new(1);
        queue.push(10_u32);
        queue.push(20_u32);

        let first = queue.steal();
        let second = queue.steal();

        assert!(first.is_some());
        assert!(second.is_some());
        assert_eq!(queue.len(), 0);
    }

    /// A freshly created queue has nothing to hand out.
    #[test]
    fn steal_empty_returns_none() {
        let queue = WorkQueue::<u32>::new(2);
        assert!(queue.steal().is_none());
    }

    /// `len` follows pushes and steals one step at a time.
    #[test]
    fn len_tracks_count() {
        let queue = WorkQueue::<u32>::new(2);
        assert_eq!(queue.len(), 0);

        queue.push(1_u32);
        assert_eq!(queue.len(), 1);

        queue.push(2_u32);
        assert_eq!(queue.len(), 2);

        queue.steal();
        assert_eq!(queue.len(), 1);
    }

    /// `is_empty` flips to `false` after the first push.
    #[test]
    fn is_empty_basic() {
        let queue = WorkQueue::<u32>::new(2);
        assert!(queue.is_empty());

        queue.push(1_u32);
        assert!(!queue.is_empty());
    }

    /// A task pushed through one handle is visible through its clone.
    #[test]
    fn clone_shares_state() {
        let original = WorkQueue::<u32>::new(2);
        let alias = original.clone();

        original.push(99_u32);

        assert_eq!(alias.steal(), Some(99_u32));
    }

    /// Stress test: 4 worker threads drain 10 000 pre-queued tasks.
    #[test]
    fn threaded_stress_10000_tasks() {
        const TASKS: u32 = 10_000;
        const WORKERS: usize = 4;

        let queue = WorkQueue::<u32>::new(WORKERS);
        (0..TASKS).for_each(|task| queue.push(task));

        let stolen_count = Arc::new(AtomicUsize::new(0));

        // Collect all handles first so every worker is running before any
        // of them is joined.
        let handles: Vec<_> = (0..WORKERS)
            .map(|_| {
                let worker_queue = queue.clone();
                let tally = Arc::clone(&stolen_count);
                thread::spawn(move || {
                    let mut taken = 0usize;
                    let mut misses = 0usize;
                    loop {
                        if worker_queue.steal().is_some() {
                            taken += 1;
                            misses = 0;
                            continue;
                        }
                        // More than 200 misses in a row: treat the queue as
                        // drained.
                        misses += 1;
                        if misses > 200 {
                            break;
                        }
                        std::hint::spin_loop();
                    }
                    tally.fetch_add(taken, Ordering::Relaxed);
                })
            })
            .collect();

        handles
            .into_iter()
            .for_each(|handle| handle.join().expect("worker thread panicked"));

        let total = stolen_count.load(Ordering::Relaxed);
        assert_eq!(
            total, TASKS as usize,
            "expected all {TASKS} tasks to be consumed, got {total}"
        );
    }

    /// Several producers fill the queue, then several consumers empty it.
    #[test]
    fn multi_producer_multi_consumer() {
        const PER_PRODUCER: usize = 1_000;
        const PRODUCERS: usize = 4;
        const CONSUMERS: usize = 4;
        const TOTAL: usize = PER_PRODUCER * PRODUCERS;

        let queue = WorkQueue::<usize>::new(CONSUMERS);
        let consumed = Arc::new(AtomicUsize::new(0));

        // Producer phase: each thread pushes its own disjoint id range.
        let producers: Vec<_> = (0..PRODUCERS)
            .map(|producer| {
                let producer_queue = queue.clone();
                thread::spawn(move || {
                    let base = producer * PER_PRODUCER;
                    (0..PER_PRODUCER).for_each(|offset| producer_queue.push(base + offset));
                })
            })
            .collect();
        producers
            .into_iter()
            .for_each(|handle| handle.join().expect("producer panicked"));

        // Consumer phase: starts only once every producer has finished.
        let consumers: Vec<_> = (0..CONSUMERS)
            .map(|_| {
                let consumer_queue = queue.clone();
                let tally = Arc::clone(&consumed);
                thread::spawn(move || {
                    let mut misses = 0;
                    loop {
                        if consumer_queue.steal().is_some() {
                            tally.fetch_add(1, Ordering::Relaxed);
                            misses = 0;
                            continue;
                        }
                        misses += 1;
                        if misses > 500 {
                            break;
                        }
                    }
                })
            })
            .collect();
        consumers
            .into_iter()
            .for_each(|handle| handle.join().expect("consumer panicked"));

        assert_eq!(consumed.load(Ordering::Relaxed), TOTAL);
    }
}