
1//! In-process thread-based task queue with priority and concurrency control.
2//!
3//! This crate provides a simple task queue that runs closures on a pool of worker
4//! threads. Tasks can be submitted with different priorities, and higher-priority
5//! tasks are executed first.
6//!
7//! # Example
8//!
9//! ```
10//! use philiprehberger_task_queue::{TaskQueue, Priority};
11//!
12//! let queue = TaskQueue::new(2);
13//!
14//! let handle = queue.submit(|| 1 + 1);
15//! assert_eq!(handle.join().unwrap(), 2);
16//!
17//! let handle = queue.submit_with_priority(Priority::High, || "done");
18//! assert_eq!(handle.join().unwrap(), "done");
19//!
20//! queue.shutdown();
21//! ```
22
23use std::cmp::Ordering;
24use std::collections::BinaryHeap;
25use std::panic::{self, AssertUnwindSafe};
26use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
27use std::sync::{Arc, Condvar, Mutex};
28use std::thread;
29use std::time::{Duration, Instant};
30
/// Task execution priority.
///
/// Higher-priority tasks are dequeued before lower-priority ones.
///
/// `Ord`/`PartialOrd` are derived: the derived ordering follows declaration
/// order, so `Low < Normal < High`, exactly matching the previous manual
/// discriminant-based comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    /// Lowest execution priority.
    Low,
    /// Default execution priority.
    Normal,
    /// Highest execution priority.
    High,
}
65
/// Error returned when a task fails to produce a result.
///
/// Derives `Clone`/`Copy` (the variants carry no data) and `PartialEq`/`Eq`
/// so callers can compare errors directly instead of matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskError {
    /// The task panicked during execution.
    Panicked,
    /// The task was cancelled because the queue shut down before it could run.
    Cancelled,
}
74
75impl std::fmt::Display for TaskError {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            TaskError::Panicked => write!(f, "task panicked"),
79            TaskError::Cancelled => write!(f, "task cancelled"),
80        }
81    }
82}
83
84impl std::error::Error for TaskError {}
85
/// Snapshot of task queue statistics for observability.
///
/// Obtained via [`TaskQueue::stats`]. Each counter is read independently with
/// relaxed ordering, so fields are individually accurate but are not
/// guaranteed to be mutually consistent at a single instant.
#[derive(Debug, Clone)]
pub struct TaskQueueStats {
    /// Total number of tasks submitted to the queue.
    pub total_submitted: u64,
    /// Number of tasks that completed successfully.
    pub completed: u64,
    /// Number of tasks that failed (panicked).
    pub failed: u64,
    /// Number of tasks currently being executed by workers.
    pub in_flight: u64,
}
100
/// Shared atomic counters used by the task queue for stats tracking.
struct StatsCounters {
    // Incremented for each accepted submission.
    total_submitted: AtomicU64,
    // Tasks that ran to completion without panicking.
    completed: AtomicU64,
    // Tasks whose closure panicked.
    failed: AtomicU64,
    // Tasks currently executing on a worker; also consulted by drain().
    in_flight: AtomicU64,
}
108
109impl StatsCounters {
110    fn new() -> Self {
111        Self {
112            total_submitted: AtomicU64::new(0),
113            completed: AtomicU64::new(0),
114            failed: AtomicU64::new(0),
115            in_flight: AtomicU64::new(0),
116        }
117    }
118}
119
/// Signature of the optional `on_complete` observer: receives the task's
/// success flag and its wall-clock execution duration.
type CompletionCallback = dyn Fn(bool, Duration) + Send + Sync;
121
/// A handle to a submitted task, used to retrieve the result.
///
/// # Example
///
/// ```
/// use philiprehberger_task_queue::TaskQueue;
///
/// let queue = TaskQueue::new(1);
/// let handle = queue.submit(|| 42);
/// assert_eq!(handle.join().unwrap(), 42);
/// queue.shutdown();
/// ```
pub struct TaskHandle<T> {
    // Result slot shared with the worker (and the CancelGuard) that
    // will eventually publish this task's outcome.
    inner: Arc<TaskResultSlot<T>>,
}
137
/// One-shot cell in which a task's outcome is published.
///
/// `TaskHandle::join` waits on `condvar` until `mutex` holds `Some`.
struct TaskResultSlot<T> {
    mutex: Mutex<Option<Result<T, TaskError>>>,
    condvar: Condvar,
}
142
143impl<T> TaskResultSlot<T> {
144    fn set(&self, value: Result<T, TaskError>) {
145        let mut guard = self.mutex.lock().unwrap();
146        *guard = Some(value);
147        self.condvar.notify_one();
148    }
149
150
151}
152
153impl<T> TaskHandle<T> {
154    /// Block until the task completes and return its result.
155    ///
156    /// Returns `Ok(value)` if the task completed successfully, or a [`TaskError`]
157    /// if the task panicked or was cancelled.
158    pub fn join(self) -> Result<T, TaskError> {
159        let mut guard = self.inner.mutex.lock().unwrap();
160        while guard.is_none() {
161            guard = self.inner.condvar.wait(guard).unwrap();
162        }
163        guard.take().unwrap()
164    }
165
166    /// Check whether the task has completed without blocking.
167    pub fn is_done(&self) -> bool {
168        self.inner.mutex.lock().unwrap().is_some()
169    }
170}
171
/// Guard that sets `TaskError::Cancelled` on the result slot when dropped,
/// unless the task has already completed. This ensures that `TaskHandle::join`
/// never blocks forever if the task is dropped without running.
struct CancelGuard<T> {
    // Result slot shared with the TaskHandle this guard protects.
    slot: Arc<TaskResultSlot<T>>,
}
178
179impl<T> Drop for CancelGuard<T> {
180    fn drop(&mut self) {
181        let mut guard = self.slot.mutex.lock().unwrap();
182        if guard.is_none() {
183            *guard = Some(Err(TaskError::Cancelled));
184            self.slot.condvar.notify_one();
185        }
186    }
187}
188
/// Returned by a task closure: signals the result slot after the worker
/// has finished post-task bookkeeping (stats, callback).
type TaskCompletion = Box<dyn FnOnce() + Send>;
/// Type-erased task stored in the queue; invoking it runs the user closure
/// and yields the deferred completion step above.
type BoxedTask = Box<dyn FnOnce() -> TaskCompletion + Send>;
193
/// A queued task together with the metadata used to order it in the heap.
struct QueueEntry {
    // Primary sort key: higher priority pops first.
    priority: Priority,
    // Tie-breaker: submission order, so equal priorities run FIFO.
    sequence: u64,
    // The type-erased user task.
    task: BoxedTask,
}
199
200impl Eq for QueueEntry {}
201
202impl PartialEq for QueueEntry {
203    fn eq(&self, other: &Self) -> bool {
204        self.priority == other.priority && self.sequence == other.sequence
205    }
206}
207
208impl Ord for QueueEntry {
209    fn cmp(&self, other: &Self) -> Ordering {
210        self.priority
211            .cmp(&other.priority)
212            .then_with(|| other.sequence.cmp(&self.sequence))
213    }
214}
215
216impl PartialOrd for QueueEntry {
217    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
218        Some(self.cmp(other))
219    }
220}
221
/// Mutable queue state protected by the `TaskQueue` mutex.
struct SharedState {
    // Pending tasks, popped in priority-then-FIFO order.
    queue: BinaryHeap<QueueEntry>,
    // Set on shutdown/drop: workers exit once the queue is empty.
    shutdown: bool,
    // Set by drain(): new submissions are rejected, queued tasks still run.
    draining: bool,
    // Monotonic tie-breaker giving FIFO order within a priority level.
    next_sequence: u64,
}
228
/// A thread-based task queue with configurable concurrency and priority scheduling.
///
/// Workers continuously pull the highest-priority task from the queue and execute it.
/// When the queue is shut down, running tasks are allowed to complete but pending
/// tasks are dropped (their handles will receive `TaskError::Cancelled`).
///
/// # Example
///
/// ```
/// use philiprehberger_task_queue::{TaskQueue, Priority};
///
/// let queue = TaskQueue::new(2);
///
/// let h1 = queue.submit(|| 10);
/// let h2 = queue.submit_with_priority(Priority::High, || 20);
///
/// assert_eq!(h1.join().unwrap(), 10);
/// assert_eq!(h2.join().unwrap(), 20);
///
/// queue.shutdown();
/// ```
pub struct TaskQueue {
    // Mutex-protected queue state plus the condvar that workers and drain() wait on.
    shared: Arc<(Mutex<SharedState>, Condvar)>,
    // Worker join handles; `None` once shutdown has reclaimed them.
    workers: Option<Vec<thread::JoinHandle<()>>>,
    // Shared atomic counters backing `stats()`.
    stats: Arc<StatsCounters>,
    // Optional observer invoked after every task (success flag + duration).
    callback: Arc<Mutex<Option<Arc<CompletionCallback>>>>,
}
256
257impl TaskQueue {
258    /// Create a new task queue with the given number of worker threads.
259    ///
260    /// # Panics
261    ///
262    /// Panics if `concurrency` is zero.
263    pub fn new(concurrency: usize) -> Self {
264        assert!(concurrency > 0, "concurrency must be at least 1");
265
266        let shared = Arc::new((
267            Mutex::new(SharedState {
268                queue: BinaryHeap::new(),
269                shutdown: false,
270                draining: false,
271                next_sequence: 0,
272            }),
273            Condvar::new(),
274        ));
275
276        let stats = Arc::new(StatsCounters::new());
277        let callback: Arc<Mutex<Option<Arc<CompletionCallback>>>> = Arc::new(Mutex::new(None));
278
279        let mut workers = Vec::with_capacity(concurrency);
280        for _ in 0..concurrency {
281            let shared = Arc::clone(&shared);
282            let stats = Arc::clone(&stats);
283            let callback = Arc::clone(&callback);
284            let handle = thread::spawn(move || {
285                worker_loop(&shared, &stats, &callback);
286            });
287            workers.push(handle);
288        }
289
290        TaskQueue {
291            shared,
292            workers: Some(workers),
293            stats,
294            callback,
295        }
296    }
297
    /// Submit a task with `Normal` priority.
    ///
    /// Returns a [`TaskHandle`] that can be used to retrieve the result.
    ///
    /// Equivalent to `submit_with_priority(Priority::Normal, task)`.
    pub fn submit<F, T>(&self, task: F) -> TaskHandle<T>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        self.submit_with_priority(Priority::Normal, task)
    }
308
309    /// Submit a task with the given priority.
310    ///
311    /// Higher-priority tasks are executed before lower-priority ones when
312    /// multiple tasks are waiting in the queue.
313    ///
314    /// Returns a [`TaskHandle`] that can be used to retrieve the result.
315    ///
316    /// If the queue is draining or shut down, the returned handle will
317    /// immediately yield `TaskError::Cancelled`.
318    pub fn submit_with_priority<F, T>(&self, priority: Priority, task: F) -> TaskHandle<T>
319    where
320        F: FnOnce() -> T + Send + 'static,
321        T: Send + 'static,
322    {
323        let slot = Arc::new(TaskResultSlot {
324            mutex: Mutex::new(None),
325            condvar: Condvar::new(),
326        });
327
328        // Reject submissions if draining or shut down.
329        {
330            let (ref mutex, _) = *self.shared;
331            let state = mutex.lock().unwrap();
332            if state.draining || state.shutdown {
333                slot.set(Err(TaskError::Cancelled));
334                return TaskHandle { inner: slot };
335            }
336        }
337
338        let cancel_guard = CancelGuard {
339            slot: Arc::clone(&slot),
340        };
341
342        let boxed: BoxedTask = Box::new(move || {
343            // The cancel guard is moved into the closure. If the closure runs,
344            // we explicitly set the result and then forget the guard so it
345            // doesn't overwrite with Cancelled. If the closure is dropped without
346            // running, the guard's Drop fires and sets Cancelled.
347            let outcome = panic::catch_unwind(AssertUnwindSafe(task));
348            let success = outcome.is_ok();
349            TASK_SUCCESS.with(|s| s.set(success));
350            let value = match outcome {
351                Ok(v) => Ok(v),
352                Err(_) => Err(TaskError::Panicked),
353            };
354            let slot = Arc::clone(&cancel_guard.slot);
355            // Prevent the Drop impl from overwriting the result with Cancelled
356            std::mem::forget(cancel_guard);
357            // Return a completion callback that the worker calls AFTER stats
358            // and on_complete callback, so join() doesn't return prematurely.
359            Box::new(move || slot.set(value))
360        });
361
362        self.stats.total_submitted.fetch_add(1, AtomicOrdering::Relaxed);
363
364        let (ref mutex, ref condvar) = *self.shared;
365        let mut state = mutex.lock().unwrap();
366        let sequence = state.next_sequence;
367        state.next_sequence += 1;
368        state.queue.push(QueueEntry {
369            priority,
370            sequence,
371            task: boxed,
372        });
373        condvar.notify_one();
374
375        TaskHandle { inner: slot }
376    }
377
    /// Return a snapshot of task queue statistics.
    ///
    /// The counters are updated atomically as tasks are submitted, completed,
    /// and failed, so successive calls may return different values. Each field
    /// is loaded independently (relaxed ordering), so the snapshot is not
    /// guaranteed to be mutually consistent across fields.
    ///
    /// # Example
    ///
    /// ```
    /// use philiprehberger_task_queue::TaskQueue;
    ///
    /// let queue = TaskQueue::new(1);
    /// let handle = queue.submit(|| 1 + 1);
    /// handle.join().unwrap();
    ///
    /// let stats = queue.stats();
    /// assert_eq!(stats.total_submitted, 1);
    /// assert_eq!(stats.completed, 1);
    /// queue.shutdown();
    /// ```
    pub fn stats(&self) -> TaskQueueStats {
        TaskQueueStats {
            total_submitted: self.stats.total_submitted.load(AtomicOrdering::Relaxed),
            completed: self.stats.completed.load(AtomicOrdering::Relaxed),
            failed: self.stats.failed.load(AtomicOrdering::Relaxed),
            in_flight: self.stats.in_flight.load(AtomicOrdering::Relaxed),
        }
    }
405
    /// Stop accepting new tasks and wait for all queued and in-flight tasks to
    /// complete.
    ///
    /// Unlike [`shutdown`](TaskQueue::shutdown), `drain` does **not** drop
    /// pending tasks — every task that was already submitted will run to
    /// completion. New submissions made after `drain` is called will be
    /// immediately cancelled.
    ///
    /// This method blocks until the queue is empty and all workers are idle,
    /// then shuts down the worker threads.
    ///
    /// # Example
    ///
    /// ```
    /// use philiprehberger_task_queue::TaskQueue;
    /// use std::sync::Arc;
    /// use std::sync::atomic::{AtomicUsize, Ordering};
    ///
    /// let queue = TaskQueue::new(2);
    /// let counter = Arc::new(AtomicUsize::new(0));
    ///
    /// for _ in 0..5 {
    ///     let c = counter.clone();
    ///     queue.submit(move || { c.fetch_add(1, Ordering::SeqCst); });
    /// }
    ///
    /// queue.drain();
    /// assert_eq!(counter.load(Ordering::SeqCst), 5);
    /// ```
    pub fn drain(mut self) {
        self.do_drain();
    }
438
    // Shared drain implementation: flip the draining flag, wait for the queue
    // to empty and all in-flight work to finish, then shut the workers down.
    fn do_drain(&mut self) {
        let (ref mutex, ref condvar) = *self.shared;
        {
            let mut state = mutex.lock().unwrap();
            state.draining = true;
            // Do NOT clear the queue — let workers process everything.
        }

        // Wait until the queue is empty and no tasks are in-flight. Workers
        // signal this condvar after each task completes, re-triggering the
        // predicate check.
        //
        // NOTE(review): a worker increments in_flight only after popping and
        // releasing the lock, so there is a brief window where the queue looks
        // empty and in_flight reads 0 while a task is in transit; the join()
        // inside do_shutdown() still guarantees that task finishes before
        // drain() returns — confirm this stays true if the worker loop changes.
        {
            let mut state = mutex.lock().unwrap();
            while !state.queue.is_empty()
                || self.stats.in_flight.load(AtomicOrdering::SeqCst) > 0
            {
                state = condvar.wait(state).unwrap();
            }
        }

        // Now perform a normal shutdown (workers will exit because queue is
        // empty and shutdown flag is set).
        self.do_shutdown();
    }
461
462    /// Register a callback that fires after each task completes.
463    ///
464    /// The callback receives two arguments:
465    /// - `success` — `true` if the task completed without panicking, `false` otherwise.
466    /// - `duration` — wall-clock time the task took to execute.
467    ///
468    /// Only one callback may be active at a time; calling this again replaces
469    /// the previous callback.
470    ///
471    /// # Example
472    ///
473    /// ```
474    /// use philiprehberger_task_queue::TaskQueue;
475    /// use std::sync::Arc;
476    /// use std::sync::atomic::{AtomicUsize, Ordering};
477    ///
478    /// let queue = TaskQueue::new(1);
479    /// let count = Arc::new(AtomicUsize::new(0));
480    /// let c = count.clone();
481    /// queue.on_complete(move |_success, _dur| {
482    ///     c.fetch_add(1, Ordering::SeqCst);
483    /// });
484    ///
485    /// queue.submit(|| 42).join().unwrap();
486    /// assert_eq!(count.load(Ordering::SeqCst), 1);
487    /// queue.shutdown();
488    /// ```
489    pub fn on_complete<F>(&self, callback: F)
490    where
491        F: Fn(bool, Duration) + Send + Sync + 'static,
492    {
493        let mut guard = self.callback.lock().unwrap();
494        *guard = Some(Arc::new(callback));
495    }
496
    /// Shut down the task queue.
    ///
    /// Signals all workers to stop, waits for currently running tasks to finish,
    /// and drops any pending tasks. Pending task handles will receive
    /// `TaskError::Cancelled` when joined.
    ///
    /// Blocks until every worker thread has exited.
    pub fn shutdown(mut self) {
        self.do_shutdown();
    }
505
506    fn do_shutdown(&mut self) {
507        let (ref mutex, ref condvar) = *self.shared;
508
509        {
510            let mut state = mutex.lock().unwrap();
511            state.shutdown = true;
512            condvar.notify_all();
513            // Drain the queue — dropping each entry drops its closure, which
514            // drops the CancelGuard, which sets TaskError::Cancelled on the slot.
515            state.queue.clear();
516        }
517
518        if let Some(workers) = self.workers.take() {
519            for w in workers {
520                let _ = w.join();
521            }
522        }
523    }
524}
525
526impl Drop for TaskQueue {
527    fn drop(&mut self) {
528        let (ref mutex, ref condvar) = *self.shared;
529        {
530            let mut state = mutex.lock().unwrap();
531            if !state.shutdown {
532                state.shutdown = true;
533                if !state.draining {
534                    state.queue.clear();
535                }
536                condvar.notify_all();
537            }
538        }
539        if let Some(workers) = self.workers.take() {
540            for w in workers {
541                let _ = w.join();
542            }
543        }
544    }
545}
546
thread_local! {
    /// Used by the task closure to communicate success/failure to the worker loop.
    /// Written by each task just before it returns its completion closure, and
    /// read by the same worker thread immediately after the task runs.
    static TASK_SUCCESS: std::cell::Cell<bool> = const { std::cell::Cell::new(true) };
}
551
552fn worker_loop(
553    shared: &(Mutex<SharedState>, Condvar),
554    stats: &StatsCounters,
555    callback: &Mutex<Option<Arc<CompletionCallback>>>,
556) {
557    let (ref mutex, ref condvar) = *shared;
558    loop {
559        let task = {
560            let mut state = mutex.lock().unwrap();
561            loop {
562                if let Some(entry) = state.queue.pop() {
563                    break Some(entry.task);
564                }
565                if state.shutdown || (state.draining && state.queue.is_empty()) {
566                    break None;
567                }
568                state = condvar.wait(state).unwrap();
569            }
570        };
571        match task {
572            Some(task) => {
573                stats.in_flight.fetch_add(1, AtomicOrdering::SeqCst);
574                let start = Instant::now();
575                let completion = task();
576                let elapsed = start.elapsed();
577                stats.in_flight.fetch_sub(1, AtomicOrdering::SeqCst);
578
579                // The task closure uses catch_unwind internally and communicates
580                // success/failure via a thread-local, since the boxed closure
581                // always returns () without panicking.
582                let success = TASK_SUCCESS.with(|s| s.get());
583                if success {
584                    stats.completed.fetch_add(1, AtomicOrdering::Relaxed);
585                } else {
586                    stats.failed.fetch_add(1, AtomicOrdering::Relaxed);
587                }
588
589                // Fire the on_complete callback if registered.
590                if let Ok(guard) = callback.lock() {
591                    if let Some(ref cb) = *guard {
592                        cb(success, elapsed);
593                    }
594                }
595
596                // Now set the result and notify the TaskHandle — this ensures
597                // stats and callback have both completed before join() returns.
598                completion();
599
600                // Notify condvar so drain() can check progress.
601                condvar.notify_all();
602            }
603            None => return,
604        }
605    }
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc;
    use std::sync::Barrier;
    use std::time::Duration;

    #[test]
    fn submit_and_join() {
        let queue = TaskQueue::new(1);
        let handle = queue.submit(|| 42);
        assert_eq!(handle.join().unwrap(), 42);
        queue.shutdown();
    }

    #[test]
    fn submit_multiple_tasks_all_complete() {
        let queue = TaskQueue::new(2);
        let handles: Vec<_> = (0..10).map(|i| queue.submit(move || i * 2)).collect();
        let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        for (i, r) in results.iter().enumerate() {
            assert_eq!(*r, i * 2);
        }
        queue.shutdown();
    }

    #[test]
    fn priority_ordering() {
        let queue = TaskQueue::new(1);
        let barrier = Arc::new(Barrier::new(2));
        let order = Arc::new(Mutex::new(Vec::new()));

        // Block the single worker
        let b = barrier.clone();
        queue.submit(move || {
            b.wait();
        });

        // Give the worker time to pick up the blocking task
        thread::sleep(Duration::from_millis(50));

        // Now submit tasks with different priorities — they'll queue up
        let o = order.clone();
        let h_low = queue.submit_with_priority(Priority::Low, move || {
            o.lock().unwrap().push("low");
        });

        let o = order.clone();
        let h_high = queue.submit_with_priority(Priority::High, move || {
            o.lock().unwrap().push("high");
        });

        let o = order.clone();
        let h_normal = queue.submit_with_priority(Priority::Normal, move || {
            o.lock().unwrap().push("normal");
        });

        // Unblock the worker
        barrier.wait();

        // Wait for all tasks
        h_low.join().unwrap();
        h_high.join().unwrap();
        h_normal.join().unwrap();

        // Despite submission order low → high → normal, execution must follow
        // priority order.
        let final_order = order.lock().unwrap();
        assert_eq!(*final_order, vec!["high", "normal", "low"]);

        queue.shutdown();
    }

    #[test]
    fn is_done_returns_false_then_true() {
        let queue = TaskQueue::new(1);
        let barrier = Arc::new(Barrier::new(2));

        let b = barrier.clone();
        let handle = queue.submit(move || {
            b.wait();
            99
        });

        // Task is blocked, so not done yet
        assert!(!handle.is_done());

        // Unblock the task
        barrier.wait();

        // Wait for completion
        let result = handle.join().unwrap();
        assert_eq!(result, 99);

        queue.shutdown();
    }

    #[test]
    fn shutdown_completes_running_tasks() {
        let queue = TaskQueue::new(1);
        let (tx, rx) = mpsc::channel();

        queue.submit(move || {
            thread::sleep(Duration::from_millis(50));
            tx.send(true).unwrap();
        });

        // Give the worker time to start the task
        thread::sleep(Duration::from_millis(10));

        // Shutdown should wait for the running task
        queue.shutdown();

        // The task should have completed
        assert!(rx.recv_timeout(Duration::from_millis(100)).unwrap());
    }

    #[test]
    fn panicking_task_returns_panicked_error() {
        let queue = TaskQueue::new(1);
        let handle = queue.submit(|| {
            panic!("intentional panic");
        });
        match handle.join() {
            Err(TaskError::Panicked) => {}
            other => panic!("expected TaskError::Panicked, got {:?}", other.err()),
        }

        // Queue should still work after a panic
        let handle = queue.submit(|| 123);
        assert_eq!(handle.join().unwrap(), 123);

        queue.shutdown();
    }

    #[test]
    fn concurrency_limit_is_respected() {
        let concurrency = 3;
        let queue = TaskQueue::new(concurrency);
        let running = Arc::new(AtomicUsize::new(0));
        let max_running = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..concurrency * 2 {
            let r = running.clone();
            let m = max_running.clone();
            handles.push(queue.submit(move || {
                let current = r.fetch_add(1, Ordering::SeqCst) + 1;
                // Update max using compare-and-swap loop
                loop {
                    let prev_max = m.load(Ordering::SeqCst);
                    if current <= prev_max {
                        break;
                    }
                    if m.compare_exchange(prev_max, current, Ordering::SeqCst, Ordering::SeqCst)
                        .is_ok()
                    {
                        break;
                    }
                }
                thread::sleep(Duration::from_millis(50));
                r.fetch_sub(1, Ordering::SeqCst);
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // Observed parallelism must never exceed the worker-thread count.
        let observed_max = max_running.load(Ordering::SeqCst);
        assert!(
            observed_max <= concurrency,
            "max concurrent tasks ({observed_max}) exceeded concurrency limit ({concurrency})"
        );

        queue.shutdown();
    }

    #[test]
    fn stats_tracks_submitted_and_completed() {
        let queue = TaskQueue::new(2);

        let handles: Vec<_> = (0..5).map(|i| queue.submit(move || i)).collect();
        for h in handles {
            h.join().unwrap();
        }

        let s = queue.stats();
        assert_eq!(s.total_submitted, 5);
        assert_eq!(s.completed, 5);
        assert_eq!(s.failed, 0);
        assert_eq!(s.in_flight, 0);

        queue.shutdown();
    }

    #[test]
    fn stats_tracks_failures() {
        let queue = TaskQueue::new(1);

        let h1 = queue.submit(|| panic!("boom"));
        let _ = h1.join(); // Err(Panicked)

        let h2 = queue.submit(|| 42);
        h2.join().unwrap();

        let s = queue.stats();
        assert_eq!(s.total_submitted, 2);
        assert_eq!(s.completed, 1);
        assert_eq!(s.failed, 1);

        queue.shutdown();
    }

    #[test]
    fn drain_completes_all_pending_tasks() {
        let queue = TaskQueue::new(1);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let c = counter.clone();
            queue.submit(move || {
                c.fetch_add(1, Ordering::SeqCst);
            });
        }

        // Unlike shutdown(), drain() must run every queued task.
        queue.drain();
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn drain_rejects_new_submissions() {
        let queue = TaskQueue::new(1);
        let barrier = Arc::new(Barrier::new(2));

        // Block the worker so we can call drain from another context
        let b = barrier.clone();
        queue.submit(move || {
            b.wait();
        });

        // Give the worker time to pick up the task
        thread::sleep(Duration::from_millis(50));

        // Submit a task that should be queued
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();
        queue.submit(move || {
            c.fetch_add(1, Ordering::SeqCst);
        });

        // We need to set draining and then unblock. Since drain() consumes self,
        // we test the rejection behavior differently: submit after drain finishes
        // is not possible (self consumed). Instead, verify that drain processes
        // all queued tasks.
        barrier.wait();
        queue.drain();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn on_complete_callback_fires_on_success() {
        let queue = TaskQueue::new(1);
        let call_count = Arc::new(AtomicUsize::new(0));
        let success_count = Arc::new(AtomicUsize::new(0));

        let cc = call_count.clone();
        let sc = success_count.clone();
        queue.on_complete(move |success, dur| {
            cc.fetch_add(1, Ordering::SeqCst);
            if success {
                sc.fetch_add(1, Ordering::SeqCst);
            }
            assert!(dur.as_nanos() > 0);
        });

        let h = queue.submit(|| 42);
        h.join().unwrap();

        // The callback runs before join() returns, so counts are visible here.
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert_eq!(success_count.load(Ordering::SeqCst), 1);

        queue.shutdown();
    }

    #[test]
    fn on_complete_callback_fires_on_failure() {
        let queue = TaskQueue::new(1);
        let failure_count = Arc::new(AtomicUsize::new(0));

        let fc = failure_count.clone();
        queue.on_complete(move |success, _dur| {
            if !success {
                fc.fetch_add(1, Ordering::SeqCst);
            }
        });

        let h = queue.submit(|| panic!("intentional"));
        let _ = h.join();

        assert_eq!(failure_count.load(Ordering::SeqCst), 1);

        queue.shutdown();
    }

    #[test]
    fn on_complete_callback_reports_duration() {
        let queue = TaskQueue::new(1);
        let observed_duration = Arc::new(Mutex::new(Duration::ZERO));

        let od = observed_duration.clone();
        queue.on_complete(move |_success, dur| {
            *od.lock().unwrap() = dur;
        });

        let h = queue.submit(|| {
            thread::sleep(Duration::from_millis(50));
        });
        h.join().unwrap();

        // Allow some timer slack below the 50ms sleep.
        let dur = *observed_duration.lock().unwrap();
        assert!(dur >= Duration::from_millis(40), "duration was {dur:?}");

        queue.shutdown();
    }

    #[test]
    fn replacing_callback() {
        let queue = TaskQueue::new(1);
        let first_count = Arc::new(AtomicUsize::new(0));
        let second_count = Arc::new(AtomicUsize::new(0));

        let fc = first_count.clone();
        queue.on_complete(move |_, _| {
            fc.fetch_add(1, Ordering::SeqCst);
        });

        queue.submit(|| {}).join().unwrap();

        let sc = second_count.clone();
        queue.on_complete(move |_, _| {
            sc.fetch_add(1, Ordering::SeqCst);
        });

        queue.submit(|| {}).join().unwrap();

        // Each callback should have observed exactly one task.
        assert_eq!(first_count.load(Ordering::SeqCst), 1);
        assert_eq!(second_count.load(Ordering::SeqCst), 1);

        queue.shutdown();
    }
}