atlas_runtime/
dependency_tracker.rs

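//! Utility to track dependent work: work items are declared with increasing sequence numbers,
//! and threads can block until a given sequence has been marked as processed.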
2
3use std::sync::{atomic::AtomicU64, Condvar, Mutex};
4
/// Tracks ordered units of work by sequence number so that consumers can block until a
/// dependency has been processed.
///
/// Work is declared with [`DependencyTracker::declare_work`]. Progress is reported with
/// [`DependencyTracker::mark_this_and_all_previous_work_processed`]. Dependents block in
/// [`DependencyTracker::wait_for_dependency`] until the sequence they need is covered.
#[derive(Debug, Default)]
pub struct DependencyTracker {
    /// The most recently declared work sequence number (0 until any work has been declared).
    work_sequence: AtomicU64,
    /// The highest sequence number marked as processed; `None` means no work has been
    /// processed yet.
    processed_work_sequence: Mutex<Option<u64>>,
    /// Notified whenever `processed_work_sequence` advances; always waited on together with
    /// that mutex.
    condvar: Condvar,
}

/// Returns `true` if the processed sequence `a` has not yet reached `b`.
///
/// `None` (nothing processed yet) is treated as lower than every sequence number, including 0.
fn less_than(a: &Option<u64>, b: u64) -> bool {
    a.is_none_or(|a| a < b)
}

impl DependencyTracker {
    /// Declares a new unit of work and returns its sequence number.
    ///
    /// The first call returns 1, and each subsequent call returns one more than the previous.
    /// 0 is never handed out, which is why [`DependencyTracker::wait_for_dependency`] treats
    /// sequence 0 as already satisfied.
    pub fn declare_work(&self) -> u64 {
        self.work_sequence
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
            + 1
    }

    /// Marks work `sequence`, and implicitly every earlier sequence, as processed.
    ///
    /// If the processed watermark advances, all threads blocked in
    /// [`DependencyTracker::wait_for_dependency`] are woken. The watermark only moves forward:
    /// if a processed sequence is already recorded and `sequence` is not greater than it, this
    /// call changes nothing and wakes no one.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn mark_this_and_all_previous_work_processed(&self, sequence: u64) {
        let mut processed = self.processed_work_sequence.lock().unwrap();
        if less_than(&processed, sequence) {
            *processed = Some(sequence);
            self.condvar.notify_all();
        }
    }

    /// Blocks the calling thread until work `sequence` (and therefore every earlier sequence)
    /// has been marked processed.
    ///
    /// Returns immediately for `sequence == 0`, because declared work starts at 1. There is no
    /// timeout: waiting on a sequence that is never marked processed blocks forever.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait_for_dependency(&self, sequence: u64) {
        if sequence == 0 {
            return; // No need to wait for sequence 0 as real work starts from 1.
        }
        let guard = self.processed_work_sequence.lock().unwrap();
        // `wait_while` re-checks the predicate after every wakeup. That covers spurious wakeups
        // and notifications for sequences still below the one we depend on.
        let _guard = self
            .condvar
            .wait_while(guard, |processed| less_than(processed, sequence))
            .unwrap();
    }

    /// Returns the most recently declared work sequence number, or 0 if no work has been
    /// declared yet.
    pub fn get_current_declared_work(&self) -> u64 {
        self.work_sequence.load(std::sync::atomic::Ordering::SeqCst)
    }
}
55
#[cfg(test)]
mod tests {
    use {
        super::*,
        std::{sync::Arc, thread, time::Duration},
    };

    /// Reads the processed watermark directly from the tracker's internal state.
    fn processed(dependency_tracker: &DependencyTracker) -> Option<u64> {
        *dependency_tracker.processed_work_sequence.lock().unwrap()
    }

    #[test]
    fn test_less_than() {
        assert!(less_than(&None, 0));
        assert!(less_than(&Some(0), 1));
        assert!(!less_than(&Some(1), 1));
        assert!(!less_than(&Some(2), 1));
    }

    #[test]
    fn test_get_new_work_sequence() {
        let dependency_tracker = DependencyTracker::default();
        assert_eq!(dependency_tracker.get_current_declared_work(), 0);
        assert_eq!(dependency_tracker.declare_work(), 1);
        assert_eq!(dependency_tracker.declare_work(), 2);
        assert_eq!(dependency_tracker.get_current_declared_work(), 2);
    }

    #[test]
    fn test_notify_work_processed() {
        let dependency_tracker = DependencyTracker::default();
        assert_eq!(processed(&dependency_tracker), None);

        dependency_tracker.mark_this_and_all_previous_work_processed(1);
        assert_eq!(processed(&dependency_tracker), Some(1));

        // notify a smaller sequence number, should not change the processed sequence
        dependency_tracker.mark_this_and_all_previous_work_processed(0);
        assert_eq!(processed(&dependency_tracker), Some(1));
        // notify a larger sequence number, should change the processed sequence
        dependency_tracker.mark_this_and_all_previous_work_processed(2);
        assert_eq!(processed(&dependency_tracker), Some(2));
        // notify the same sequence number, should not change the processed sequence
        dependency_tracker.mark_this_and_all_previous_work_processed(2);
        assert_eq!(processed(&dependency_tracker), Some(2));
    }

    #[test]
    fn test_wait_returns_immediately_when_already_satisfied() {
        let dependency_tracker = DependencyTracker::default();
        // Sequence 0 never needs waiting, even before any work has been processed.
        dependency_tracker.wait_for_dependency(0);

        // Marking 3 implicitly covers every earlier sequence.
        dependency_tracker.mark_this_and_all_previous_work_processed(3);
        dependency_tracker.wait_for_dependency(1);
        dependency_tracker.wait_for_dependency(3);
    }

    #[test]
    fn test_wait_and_notify_work_processed() {
        let dependency_tracker = Arc::new(DependencyTracker::default());
        let tracker_clone = Arc::clone(&dependency_tracker);

        let work = dependency_tracker.declare_work();
        assert_eq!(work, 1);
        let work = dependency_tracker.declare_work();
        assert_eq!(work, 2);
        let work_to_wait = dependency_tracker.get_current_declared_work();
        let handle = thread::spawn(move || {
            tracker_clone.wait_for_dependency(work_to_wait);
        });

        // Give the waiter time to block. Whatever the scheduling, it cannot have finished,
        // because nothing has been marked processed yet.
        thread::sleep(Duration::from_millis(100));
        assert!(!handle.is_finished());

        // Marking an earlier sequence must not release a waiter on a later one.
        dependency_tracker.mark_this_and_all_previous_work_processed(1);
        thread::sleep(Duration::from_millis(20));
        assert!(!handle.is_finished());

        dependency_tracker.mark_this_and_all_previous_work_processed(work);
        handle.join().unwrap();

        assert_eq!(processed(&dependency_tracker), Some(2));
    }

    #[test]
    fn test_later_mark_releases_all_earlier_waiters() {
        let dependency_tracker = Arc::new(DependencyTracker::default());
        let sequences: Vec<u64> = (0..3).map(|_| dependency_tracker.declare_work()).collect();
        let handles: Vec<_> = sequences
            .iter()
            .map(|&sequence| {
                let tracker = Arc::clone(&dependency_tracker);
                thread::spawn(move || tracker.wait_for_dependency(sequence))
            })
            .collect();

        // A single mark of the last sequence covers every earlier one.
        dependency_tracker.mark_this_and_all_previous_work_processed(*sequences.last().unwrap());
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(processed(&dependency_tracker), Some(3));
    }
}