use std::cmp::Ordering;
24use std::collections::BinaryHeap;
25use std::panic::{self, AssertUnwindSafe};
26use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
27use std::sync::{Arc, Condvar, Mutex};
28use std::thread;
29use std::time::{Duration, Instant};
30
/// Scheduling priority of a submitted task.
///
/// Higher priorities are dequeued first; within one priority level tasks
/// run in submission (FIFO) order — see the `Ord` impl on `QueueEntry`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Priority {
    /// Runs only after all pending `Normal` and `High` work.
    Low,
    /// Default level used by `TaskQueue::submit`.
    Normal,
    /// Jumps ahead of pending `Normal` and `Low` work.
    High,
}
43
44impl Priority {
45 fn as_u8(self) -> u8 {
46 match self {
47 Priority::Low => 0,
48 Priority::Normal => 1,
49 Priority::High => 2,
50 }
51 }
52}
53
impl Ord for Priority {
    fn cmp(&self, other: &Self) -> Ordering {
        // Delegate to the numeric rank so Low < Normal < High.
        self.as_u8().cmp(&other.as_u8())
    }
}

impl PartialOrd for Priority {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Must agree with `Ord`, as the trait contract requires.
        Some(self.cmp(other))
    }
}
65
/// Why a `TaskHandle` yielded no value.
#[derive(Debug)]
pub enum TaskError {
    /// The task closure panicked; the panic was caught by the worker.
    Panicked,
    /// The task was dropped without running (queue shut down, or the
    /// submission was rejected while draining).
    Cancelled,
}
74
75impl std::fmt::Display for TaskError {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 TaskError::Panicked => write!(f, "task panicked"),
79 TaskError::Cancelled => write!(f, "task cancelled"),
80 }
81 }
82}
83
84impl std::error::Error for TaskError {}
85
/// Point-in-time snapshot of the queue's counters (see `TaskQueue::stats`).
#[derive(Debug, Clone)]
pub struct TaskQueueStats {
    /// Tasks accepted by `submit`/`submit_with_priority` since creation.
    pub total_submitted: u64,
    /// Tasks that ran to completion without panicking.
    pub completed: u64,
    /// Tasks whose closure panicked.
    pub failed: u64,
    /// Tasks currently executing on a worker thread.
    pub in_flight: u64,
}
100
/// Internal atomic counters backing `TaskQueueStats`; shared with workers
/// via `Arc` so they can be updated without taking the queue mutex.
struct StatsCounters {
    total_submitted: AtomicU64,
    completed: AtomicU64,
    failed: AtomicU64,
    in_flight: AtomicU64,
}

impl StatsCounters {
    /// All counters start at zero.
    fn new() -> Self {
        Self {
            total_submitted: AtomicU64::new(0),
            completed: AtomicU64::new(0),
            failed: AtomicU64::new(0),
            in_flight: AtomicU64::new(0),
        }
    }
}
119
/// User callback invoked after every task: `(success, wall-clock duration)`.
type CompletionCallback = dyn Fn(bool, Duration) + Send + Sync;
121
/// Handle returned by `submit`/`submit_with_priority`; lets the caller
/// wait for (or poll) the task's result.
pub struct TaskHandle<T> {
    inner: Arc<TaskResultSlot<T>>,
}

/// One-shot result slot shared between a worker thread and the handle.
struct TaskResultSlot<T> {
    // `None` until the task finishes or is cancelled; written once.
    mutex: Mutex<Option<Result<T, TaskError>>>,
    condvar: Condvar,
}
142
143impl<T> TaskResultSlot<T> {
144 fn set(&self, value: Result<T, TaskError>) {
145 let mut guard = self.mutex.lock().unwrap();
146 *guard = Some(value);
147 self.condvar.notify_one();
148 }
149
150
151}
152
153impl<T> TaskHandle<T> {
154 pub fn join(self) -> Result<T, TaskError> {
159 let mut guard = self.inner.mutex.lock().unwrap();
160 while guard.is_none() {
161 guard = self.inner.condvar.wait(guard).unwrap();
162 }
163 guard.take().unwrap()
164 }
165
166 pub fn is_done(&self) -> bool {
168 self.inner.mutex.lock().unwrap().is_some()
169 }
170}
171
/// Drop guard captured by every queued task closure.
///
/// If the closure is destroyed without running (queue cleared on shutdown,
/// or a submission rejected while draining), `Drop` marks the slot
/// `Cancelled` so `TaskHandle::join` does not hang forever. A task that
/// does run defuses the guard with `mem::forget` before publishing.
struct CancelGuard<T> {
    slot: Arc<TaskResultSlot<T>>,
}

impl<T> Drop for CancelGuard<T> {
    fn drop(&mut self) {
        let mut guard = self.slot.mutex.lock().unwrap();
        // Only cancel an empty slot; never overwrite a published result.
        if guard.is_none() {
            *guard = Some(Err(TaskError::Cancelled));
            self.slot.condvar.notify_one();
        }
    }
}
188
// Second stage of a task: publishes the result into the handle's slot.
type TaskCompletion = Box<dyn FnOnce() + Send>;
// First stage: runs the user closure, returns the completion stage.
type BoxedTask = Box<dyn FnOnce() -> TaskCompletion + Send>;

/// Heap entry pairing a task with its scheduling key.
struct QueueEntry {
    priority: Priority,
    // Monotonic submission counter; breaks priority ties FIFO.
    sequence: u64,
    task: BoxedTask,
}
199
impl Eq for QueueEntry {}

impl PartialEq for QueueEntry {
    fn eq(&self, other: &Self) -> bool {
        // `sequence` is unique per queue, so (priority, sequence) identifies
        // an entry; the task closure itself cannot be compared.
        self.priority == other.priority && self.sequence == other.sequence
    }
}

impl Ord for QueueEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // `BinaryHeap` is a max-heap: the greatest entry is popped first.
        // Higher priority wins; within a priority, the *lower* sequence
        // (earlier submission) is deliberately treated as greater so that
        // equal-priority tasks run in FIFO order.
        self.priority
            .cmp(&other.priority)
            .then_with(|| other.sequence.cmp(&self.sequence))
    }
}

impl PartialOrd for QueueEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
221
/// Mutable queue state protected by the shared mutex.
struct SharedState {
    queue: BinaryHeap<QueueEntry>,
    // Set by shutdown/Drop: workers exit as soon as the flag is seen.
    shutdown: bool,
    // Set by drain: reject new submissions but finish queued work.
    draining: bool,
    // Next value handed out as `QueueEntry::sequence`.
    next_sequence: u64,
}

/// Fixed-size worker pool executing prioritized one-shot tasks.
pub struct TaskQueue {
    // State + condvar shared with every worker thread.
    shared: Arc<(Mutex<SharedState>, Condvar)>,
    // `Some` until the workers have been joined (shutdown/drain/Drop).
    workers: Option<Vec<thread::JoinHandle<()>>>,
    stats: Arc<StatsCounters>,
    // Optional completion callback, replaceable via `on_complete`.
    callback: Arc<Mutex<Option<Arc<CompletionCallback>>>>,
}
256
257impl TaskQueue {
258 pub fn new(concurrency: usize) -> Self {
264 assert!(concurrency > 0, "concurrency must be at least 1");
265
266 let shared = Arc::new((
267 Mutex::new(SharedState {
268 queue: BinaryHeap::new(),
269 shutdown: false,
270 draining: false,
271 next_sequence: 0,
272 }),
273 Condvar::new(),
274 ));
275
276 let stats = Arc::new(StatsCounters::new());
277 let callback: Arc<Mutex<Option<Arc<CompletionCallback>>>> = Arc::new(Mutex::new(None));
278
279 let mut workers = Vec::with_capacity(concurrency);
280 for _ in 0..concurrency {
281 let shared = Arc::clone(&shared);
282 let stats = Arc::clone(&stats);
283 let callback = Arc::clone(&callback);
284 let handle = thread::spawn(move || {
285 worker_loop(&shared, &stats, &callback);
286 });
287 workers.push(handle);
288 }
289
290 TaskQueue {
291 shared,
292 workers: Some(workers),
293 stats,
294 callback,
295 }
296 }
297
298 pub fn submit<F, T>(&self, task: F) -> TaskHandle<T>
302 where
303 F: FnOnce() -> T + Send + 'static,
304 T: Send + 'static,
305 {
306 self.submit_with_priority(Priority::Normal, task)
307 }
308
309 pub fn submit_with_priority<F, T>(&self, priority: Priority, task: F) -> TaskHandle<T>
319 where
320 F: FnOnce() -> T + Send + 'static,
321 T: Send + 'static,
322 {
323 let slot = Arc::new(TaskResultSlot {
324 mutex: Mutex::new(None),
325 condvar: Condvar::new(),
326 });
327
328 {
330 let (ref mutex, _) = *self.shared;
331 let state = mutex.lock().unwrap();
332 if state.draining || state.shutdown {
333 slot.set(Err(TaskError::Cancelled));
334 return TaskHandle { inner: slot };
335 }
336 }
337
338 let cancel_guard = CancelGuard {
339 slot: Arc::clone(&slot),
340 };
341
342 let boxed: BoxedTask = Box::new(move || {
343 let outcome = panic::catch_unwind(AssertUnwindSafe(task));
348 let success = outcome.is_ok();
349 TASK_SUCCESS.with(|s| s.set(success));
350 let value = match outcome {
351 Ok(v) => Ok(v),
352 Err(_) => Err(TaskError::Panicked),
353 };
354 let slot = Arc::clone(&cancel_guard.slot);
355 std::mem::forget(cancel_guard);
357 Box::new(move || slot.set(value))
360 });
361
362 self.stats.total_submitted.fetch_add(1, AtomicOrdering::Relaxed);
363
364 let (ref mutex, ref condvar) = *self.shared;
365 let mut state = mutex.lock().unwrap();
366 let sequence = state.next_sequence;
367 state.next_sequence += 1;
368 state.queue.push(QueueEntry {
369 priority,
370 sequence,
371 task: boxed,
372 });
373 condvar.notify_one();
374
375 TaskHandle { inner: slot }
376 }
377
378 pub fn stats(&self) -> TaskQueueStats {
398 TaskQueueStats {
399 total_submitted: self.stats.total_submitted.load(AtomicOrdering::Relaxed),
400 completed: self.stats.completed.load(AtomicOrdering::Relaxed),
401 failed: self.stats.failed.load(AtomicOrdering::Relaxed),
402 in_flight: self.stats.in_flight.load(AtomicOrdering::Relaxed),
403 }
404 }
405
406 pub fn drain(mut self) {
436 self.do_drain();
437 }
438
439 fn do_drain(&mut self) {
440 let (ref mutex, ref condvar) = *self.shared;
441 {
442 let mut state = mutex.lock().unwrap();
443 state.draining = true;
444 }
446
447 {
449 let mut state = mutex.lock().unwrap();
450 while !state.queue.is_empty()
451 || self.stats.in_flight.load(AtomicOrdering::SeqCst) > 0
452 {
453 state = condvar.wait(state).unwrap();
454 }
455 }
456
457 self.do_shutdown();
460 }
461
462 pub fn on_complete<F>(&self, callback: F)
490 where
491 F: Fn(bool, Duration) + Send + Sync + 'static,
492 {
493 let mut guard = self.callback.lock().unwrap();
494 *guard = Some(Arc::new(callback));
495 }
496
497 pub fn shutdown(mut self) {
503 self.do_shutdown();
504 }
505
506 fn do_shutdown(&mut self) {
507 let (ref mutex, ref condvar) = *self.shared;
508
509 {
510 let mut state = mutex.lock().unwrap();
511 state.shutdown = true;
512 condvar.notify_all();
513 state.queue.clear();
516 }
517
518 if let Some(workers) = self.workers.take() {
519 for w in workers {
520 let _ = w.join();
521 }
522 }
523 }
524}
525
impl Drop for TaskQueue {
    // Best-effort cleanup when the queue is dropped without an explicit
    // `shutdown()`/`drain()` call (both of those join the workers and
    // leave `workers` as `None`, making this a no-op).
    fn drop(&mut self) {
        let (ref mutex, ref condvar) = *self.shared;
        {
            let mut state = mutex.lock().unwrap();
            if !state.shutdown {
                state.shutdown = true;
                // Preserve drain semantics: if a drain was in progress,
                // keep its pending tasks; otherwise cancel them (clearing
                // the heap fires each entry's CancelGuard).
                if !state.draining {
                    state.queue.clear();
                }
                condvar.notify_all();
            }
        }
        if let Some(workers) = self.workers.take() {
            for w in workers {
                let _ = w.join();
            }
        }
    }
}
546
thread_local! {
    // Per-worker flag: the boxed task records whether the user closure
    // panicked; `worker_loop` reads it right after the task returns.
    static TASK_SUCCESS: std::cell::Cell<bool> = const { std::cell::Cell::new(true) };
}
551
552fn worker_loop(
553 shared: &(Mutex<SharedState>, Condvar),
554 stats: &StatsCounters,
555 callback: &Mutex<Option<Arc<CompletionCallback>>>,
556) {
557 let (ref mutex, ref condvar) = *shared;
558 loop {
559 let task = {
560 let mut state = mutex.lock().unwrap();
561 loop {
562 if let Some(entry) = state.queue.pop() {
563 break Some(entry.task);
564 }
565 if state.shutdown || (state.draining && state.queue.is_empty()) {
566 break None;
567 }
568 state = condvar.wait(state).unwrap();
569 }
570 };
571 match task {
572 Some(task) => {
573 stats.in_flight.fetch_add(1, AtomicOrdering::SeqCst);
574 let start = Instant::now();
575 let completion = task();
576 let elapsed = start.elapsed();
577 stats.in_flight.fetch_sub(1, AtomicOrdering::SeqCst);
578
579 let success = TASK_SUCCESS.with(|s| s.get());
583 if success {
584 stats.completed.fetch_add(1, AtomicOrdering::Relaxed);
585 } else {
586 stats.failed.fetch_add(1, AtomicOrdering::Relaxed);
587 }
588
589 if let Ok(guard) = callback.lock() {
591 if let Some(ref cb) = *guard {
592 cb(success, elapsed);
593 }
594 }
595
596 completion();
599
600 condvar.notify_all();
602 }
603 None => return,
604 }
605 }
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc;
    use std::sync::Barrier;
    use std::time::Duration;

    // Happy path: one worker, one task, result comes back through join().
    #[test]
    fn submit_and_join() {
        let queue = TaskQueue::new(1);
        let handle = queue.submit(|| 42);
        assert_eq!(handle.join().unwrap(), 42);
        queue.shutdown();
    }

    // Each handle resolves to its own task's result, independent of the
    // order in which the two workers actually ran them.
    #[test]
    fn submit_multiple_tasks_all_complete() {
        let queue = TaskQueue::new(2);
        let handles: Vec<_> = (0..10).map(|i| queue.submit(move || i * 2)).collect();
        let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        for (i, r) in results.iter().enumerate() {
            assert_eq!(*r, i * 2);
        }
        queue.shutdown();
    }

    // A blocker task holds the single worker while low/high/normal tasks
    // queue up; once released, they must execute high -> normal -> low.
    #[test]
    fn priority_ordering() {
        let queue = TaskQueue::new(1);
        let barrier = Arc::new(Barrier::new(2));
        let order = Arc::new(Mutex::new(Vec::new()));

        let b = barrier.clone();
        queue.submit(move || {
            b.wait();
        });

        // Give the worker time to pick up the blocker before queueing.
        thread::sleep(Duration::from_millis(50));

        let o = order.clone();
        let h_low = queue.submit_with_priority(Priority::Low, move || {
            o.lock().unwrap().push("low");
        });

        let o = order.clone();
        let h_high = queue.submit_with_priority(Priority::High, move || {
            o.lock().unwrap().push("high");
        });

        let o = order.clone();
        let h_normal = queue.submit_with_priority(Priority::Normal, move || {
            o.lock().unwrap().push("normal");
        });

        // Release the blocker; the queued tasks now run by priority.
        barrier.wait();

        h_low.join().unwrap();
        h_high.join().unwrap();
        h_normal.join().unwrap();

        let final_order = order.lock().unwrap();
        assert_eq!(*final_order, vec!["high", "normal", "low"]);

        queue.shutdown();
    }

    // While the task is parked on the barrier the handle reports not-done;
    // after release, join() yields the value.
    #[test]
    fn is_done_returns_false_then_true() {
        let queue = TaskQueue::new(1);
        let barrier = Arc::new(Barrier::new(2));

        let b = barrier.clone();
        let handle = queue.submit(move || {
            b.wait();
            99
        });

        assert!(!handle.is_done());

        barrier.wait();

        let result = handle.join().unwrap();
        assert_eq!(result, 99);

        queue.shutdown();
    }

    // shutdown() must let an already-running task finish (it sends on the
    // channel) rather than killing it mid-flight.
    #[test]
    fn shutdown_completes_running_tasks() {
        let queue = TaskQueue::new(1);
        let (tx, rx) = mpsc::channel();

        queue.submit(move || {
            thread::sleep(Duration::from_millis(50));
            tx.send(true).unwrap();
        });

        thread::sleep(Duration::from_millis(10));

        queue.shutdown();

        assert!(rx.recv_timeout(Duration::from_millis(100)).unwrap());
    }

    // A panicking task surfaces as Err(Panicked), and the worker survives
    // to run the next task.
    #[test]
    fn panicking_task_returns_panicked_error() {
        let queue = TaskQueue::new(1);
        let handle = queue.submit(|| {
            panic!("intentional panic");
        });
        match handle.join() {
            Err(TaskError::Panicked) => {}
            other => panic!("expected TaskError::Panicked, got {:?}", other.err()),
        }

        let handle = queue.submit(|| 123);
        assert_eq!(handle.join().unwrap(), 123);

        queue.shutdown();
    }

    // Tracks the high-water mark of concurrently running tasks via a CAS
    // loop and checks it never exceeds the pool size.
    #[test]
    fn concurrency_limit_is_respected() {
        let concurrency = 3;
        let queue = TaskQueue::new(concurrency);
        let running = Arc::new(AtomicUsize::new(0));
        let max_running = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..concurrency * 2 {
            let r = running.clone();
            let m = max_running.clone();
            handles.push(queue.submit(move || {
                let current = r.fetch_add(1, Ordering::SeqCst) + 1;
                loop {
                    let prev_max = m.load(Ordering::SeqCst);
                    if current <= prev_max {
                        break;
                    }
                    if m.compare_exchange(prev_max, current, Ordering::SeqCst, Ordering::SeqCst)
                        .is_ok()
                    {
                        break;
                    }
                }
                thread::sleep(Duration::from_millis(50));
                r.fetch_sub(1, Ordering::SeqCst);
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        let observed_max = max_running.load(Ordering::SeqCst);
        assert!(
            observed_max <= concurrency,
            "max concurrent tasks ({observed_max}) exceeded concurrency limit ({concurrency})"
        );

        queue.shutdown();
    }

    // After joining every handle, the counters must reflect 5 submitted,
    // 5 completed, nothing failed or in flight.
    #[test]
    fn stats_tracks_submitted_and_completed() {
        let queue = TaskQueue::new(2);

        let handles: Vec<_> = (0..5).map(|i| queue.submit(move || i)).collect();
        for h in handles {
            h.join().unwrap();
        }

        let s = queue.stats();
        assert_eq!(s.total_submitted, 5);
        assert_eq!(s.completed, 5);
        assert_eq!(s.failed, 0);
        assert_eq!(s.in_flight, 0);

        queue.shutdown();
    }

    // One panicking and one succeeding task: counters split 1 / 1.
    #[test]
    fn stats_tracks_failures() {
        let queue = TaskQueue::new(1);

        let h1 = queue.submit(|| panic!("boom"));
        // The Err(Panicked) result is deliberately discarded here.
        let _ = h1.join(); let h2 = queue.submit(|| 42);
        h2.join().unwrap();

        let s = queue.stats();
        assert_eq!(s.total_submitted, 2);
        assert_eq!(s.completed, 1);
        assert_eq!(s.failed, 1);

        queue.shutdown();
    }

    // drain() must execute all ten queued increments before returning.
    #[test]
    fn drain_completes_all_pending_tasks() {
        let queue = TaskQueue::new(1);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let c = counter.clone();
            queue.submit(move || {
                c.fetch_add(1, Ordering::SeqCst);
            });
        }

        queue.drain();
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    // Tasks accepted before drain() still run to completion; the counter
    // reflects exactly the one task queued behind the blocker.
    #[test]
    fn drain_rejects_new_submissions() {
        let queue = TaskQueue::new(1);
        let barrier = Arc::new(Barrier::new(2));

        let b = barrier.clone();
        queue.submit(move || {
            b.wait();
        });

        thread::sleep(Duration::from_millis(50));

        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();
        queue.submit(move || {
            c.fetch_add(1, Ordering::SeqCst);
        });

        barrier.wait();
        queue.drain();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // The callback fires exactly once per task, reporting success and a
    // non-zero measured duration.
    #[test]
    fn on_complete_callback_fires_on_success() {
        let queue = TaskQueue::new(1);
        let call_count = Arc::new(AtomicUsize::new(0));
        let success_count = Arc::new(AtomicUsize::new(0));

        let cc = call_count.clone();
        let sc = success_count.clone();
        queue.on_complete(move |success, dur| {
            cc.fetch_add(1, Ordering::SeqCst);
            if success {
                sc.fetch_add(1, Ordering::SeqCst);
            }
            assert!(dur.as_nanos() > 0);
        });

        let h = queue.submit(|| 42);
        h.join().unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert_eq!(success_count.load(Ordering::SeqCst), 1);

        queue.shutdown();
    }

    // A panicking task reaches the callback with success == false.
    #[test]
    fn on_complete_callback_fires_on_failure() {
        let queue = TaskQueue::new(1);
        let failure_count = Arc::new(AtomicUsize::new(0));

        let fc = failure_count.clone();
        queue.on_complete(move |success, _dur| {
            if !success {
                fc.fetch_add(1, Ordering::SeqCst);
            }
        });

        let h = queue.submit(|| panic!("intentional"));
        let _ = h.join();

        assert_eq!(failure_count.load(Ordering::SeqCst), 1);

        queue.shutdown();
    }

    // The reported duration must roughly cover the 50ms sleep (40ms lower
    // bound leaves slack for timer coarseness).
    #[test]
    fn on_complete_callback_reports_duration() {
        let queue = TaskQueue::new(1);
        let observed_duration = Arc::new(Mutex::new(Duration::ZERO));

        let od = observed_duration.clone();
        queue.on_complete(move |_success, dur| {
            *od.lock().unwrap() = dur;
        });

        let h = queue.submit(|| {
            thread::sleep(Duration::from_millis(50));
        });
        h.join().unwrap();

        let dur = *observed_duration.lock().unwrap();
        assert!(dur >= Duration::from_millis(40), "duration was {dur:?}");

        queue.shutdown();
    }

    // Installing a second callback replaces the first: each fires once.
    #[test]
    fn replacing_callback() {
        let queue = TaskQueue::new(1);
        let first_count = Arc::new(AtomicUsize::new(0));
        let second_count = Arc::new(AtomicUsize::new(0));

        let fc = first_count.clone();
        queue.on_complete(move |_, _| {
            fc.fetch_add(1, Ordering::SeqCst);
        });

        queue.submit(|| {}).join().unwrap();

        let sc = second_count.clone();
        queue.on_complete(move |_, _| {
            sc.fetch_add(1, Ordering::SeqCst);
        });

        queue.submit(|| {}).join().unwrap();

        assert_eq!(first_count.load(Ordering::SeqCst), 1);
        assert_eq!(second_count.load(Ordering::SeqCst), 1);

        queue.shutdown();
    }
}