// atlas_runtime/dependency_tracker.rs
use std::sync::{
    atomic::{AtomicU64, Ordering},
    Condvar, Mutex,
};

/// Tracks ordering dependencies between sequentially numbered units of work.
///
/// Callers obtain increasing sequence numbers from
/// [`DependencyTracker::declare_work`], report progress with
/// [`DependencyTracker::mark_this_and_all_previous_work_processed`], and block
/// in [`DependencyTracker::wait_for_dependency`] until a given sequence number
/// has been reported as processed.
#[derive(Debug, Default)]
pub struct DependencyTracker {
    /// Most recently declared sequence number; 0 until the first declaration.
    work_sequence: AtomicU64,
    /// Highest sequence number reported as processed; `None` until the first report.
    processed_work_sequence: Mutex<Option<u64>>,
    /// Notified whenever `processed_work_sequence` advances.
    condvar: Condvar,
}

/// Returns `true` when the processed watermark `a` is strictly below `b`.
///
/// `None` means nothing has been processed yet, which is treated as being
/// below every sequence number, including 0.
fn less_than(a: &Option<u64>, b: u64) -> bool {
    a.is_none_or(|processed| processed < b)
}

18impl DependencyTracker {
19 pub fn declare_work(&self) -> u64 {
22 self.work_sequence
23 .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
24 + 1
25 }
26
27 pub fn mark_this_and_all_previous_work_processed(&self, sequence: u64) {
32 let mut work_sequence = self.processed_work_sequence.lock().unwrap();
33 if less_than(&work_sequence, sequence) {
34 *work_sequence = Some(sequence);
35 self.condvar.notify_all();
36 }
37 }
38
39 pub fn wait_for_dependency(&self, sequence: u64) {
41 if sequence == 0 {
42 return; }
44 let mut processed_sequence = self.processed_work_sequence.lock().unwrap();
45 while less_than(&processed_sequence, sequence) {
46 processed_sequence = self.condvar.wait(processed_sequence).unwrap();
47 }
48 }
49
50 pub fn get_current_declared_work(&self) -> u64 {
52 self.work_sequence.load(std::sync::atomic::Ordering::SeqCst)
53 }
54}
55
#[cfg(test)]
mod tests {
    use {
        super::*,
        std::{sync::Arc, thread},
    };

    /// Reads the tracker's current processed watermark.
    fn processed_sequence(tracker: &DependencyTracker) -> Option<u64> {
        *tracker.processed_work_sequence.lock().unwrap()
    }

    #[test]
    fn test_less_than() {
        assert!(less_than(&None, 0));
        assert!(less_than(&Some(0), 1));
        assert!(!less_than(&Some(1), 1));
        assert!(!less_than(&Some(2), 1));
    }

    #[test]
    fn test_get_new_work_sequence() {
        let dependency_tracker = DependencyTracker::default();
        assert_eq!(dependency_tracker.get_current_declared_work(), 0);
        assert_eq!(dependency_tracker.declare_work(), 1);
        assert_eq!(dependency_tracker.declare_work(), 2);
        assert_eq!(dependency_tracker.get_current_declared_work(), 2);
    }

    #[test]
    fn test_notify_work_processed() {
        let dependency_tracker = DependencyTracker::default();
        assert_eq!(processed_sequence(&dependency_tracker), None);

        dependency_tracker.mark_this_and_all_previous_work_processed(1);
        assert_eq!(processed_sequence(&dependency_tracker), Some(1));

        // A lower sequence must not move the watermark backwards.
        dependency_tracker.mark_this_and_all_previous_work_processed(0);
        assert_eq!(processed_sequence(&dependency_tracker), Some(1));

        dependency_tracker.mark_this_and_all_previous_work_processed(2);
        assert_eq!(processed_sequence(&dependency_tracker), Some(2));

        // Re-reporting the current watermark is a no-op.
        dependency_tracker.mark_this_and_all_previous_work_processed(2);
        assert_eq!(processed_sequence(&dependency_tracker), Some(2));
    }

    #[test]
    fn test_wait_for_zero_does_not_block() {
        let dependency_tracker = DependencyTracker::default();
        // Nothing has been processed, yet waiting on 0 must return immediately.
        dependency_tracker.wait_for_dependency(0);
        assert_eq!(processed_sequence(&dependency_tracker), None);
    }

    #[test]
    fn test_wait_for_already_processed_work() {
        let dependency_tracker = DependencyTracker::default();
        let first = dependency_tracker.declare_work();
        let second = dependency_tracker.declare_work();
        dependency_tracker.mark_this_and_all_previous_work_processed(second);
        // Marking `second` also covers `first`; neither wait may block.
        dependency_tracker.wait_for_dependency(first);
        dependency_tracker.wait_for_dependency(second);
    }

    #[test]
    fn test_wait_and_notify_work_processed() {
        let dependency_tracker = Arc::new(DependencyTracker::default());
        let tracker_clone = Arc::clone(&dependency_tracker);

        let work = dependency_tracker.declare_work();
        assert_eq!(work, 1);
        let work = dependency_tracker.declare_work();
        assert_eq!(work, 2);

        let work_to_wait = dependency_tracker.get_current_declared_work();
        let handle = thread::spawn(move || {
            tracker_clone.wait_for_dependency(work_to_wait);
        });

        // Give the waiter a chance to actually block on the condvar so the
        // notification path is exercised; the test is correct either way.
        thread::sleep(std::time::Duration::from_millis(100));
        dependency_tracker.mark_this_and_all_previous_work_processed(work);
        handle.join().unwrap();

        assert_eq!(processed_sequence(&dependency_tracker), Some(2));
    }
}