// adk_managed/sequence.rs

use std::sync::atomic::{AtomicU64, Ordering};
2
/// Per-session monotonic sequence counter.
///
/// Each session maintains its own counter. Every emitted `SessionEvent`
/// gets a unique, strictly increasing `seq` value from this counter.
/// Thread-safe via `AtomicU64`: every method takes `&self`, so one counter
/// can be shared between threads (for example behind an `Arc`) without a lock.
///
/// The [`Default`] counter starts at `0`.
///
/// # Example
///
/// ```rust
/// use adk_managed::sequence::SequenceCounter;
///
/// let counter = SequenceCounter::default();
/// assert_eq!(counter.next(), 0);
/// assert_eq!(counter.next(), 1);
/// assert_eq!(counter.next(), 2);
/// ```
#[derive(Debug, Default)]
pub struct SequenceCounter {
    /// The value that the next call to `next()` will hand out.
    value: AtomicU64,
}

impl SequenceCounter {
    /// Create a new counter starting at the given value.
    ///
    /// This is a `const fn`, so it can also initialize a `static` counter.
    #[must_use]
    pub const fn new(start: u64) -> Self {
        Self { value: AtomicU64::new(start) }
    }

    /// Get the next sequence number (strictly increasing).
    /// First call returns `start`, second returns `start + 1`, etc.
    ///
    /// The increment is a single atomic `fetch_add`, so concurrent callers
    /// never receive the same value. The underlying addition wraps around
    /// after `u64::MAX` has been handed out.
    pub fn next(&self) -> u64 {
        // `fetch_add` returns the value *before* the increment, which is
        // exactly the sequence number being claimed by this call.
        self.value.fetch_add(1, Ordering::SeqCst)
    }

    /// Get the current value without incrementing.
    ///
    /// This is the value the next call to `next()` would return, provided no
    /// other thread calls `next()` in between.
    #[must_use]
    pub fn current(&self) -> u64 {
        self.value.load(Ordering::SeqCst)
    }
}
46
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Runs `num_threads` worker threads against one shared counter that
    /// starts at zero.
    ///
    /// Every worker calls `next()` `increments_per_thread` times and records
    /// the values in the order it received them. Returns the shared counter
    /// together with one `Vec` of observed values per worker.
    fn run_workers(
        num_threads: usize,
        increments_per_thread: usize,
    ) -> (Arc<SequenceCounter>, Vec<Vec<u64>>) {
        let counter = Arc::new(SequenceCounter::default());

        // Collect the handles first so every thread is spawned before any is
        // joined; joining inside the same lazy chain would run the workers
        // one after another and defeat the purpose of the test.
        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let counter = Arc::clone(&counter);
                thread::spawn(move || {
                    (0..increments_per_thread)
                        .map(|_| counter.next())
                        .collect::<Vec<u64>>()
                })
            })
            .collect();

        let per_thread: Vec<Vec<u64>> = handles
            .into_iter()
            .map(|handle| handle.join().expect("worker thread panicked"))
            .collect();

        (counter, per_thread)
    }

    /// A default counter hands out `0` first.
    #[test]
    fn test_starts_at_zero() {
        let counter = SequenceCounter::default();
        assert_eq!(counter.next(), 0);
    }

    /// A counter built with `new(start)` hands out `start`, then `start + 1`.
    #[test]
    fn test_starts_at_custom_value() {
        let counter = SequenceCounter::new(42);
        assert_eq!(counter.next(), 42);
        assert_eq!(counter.next(), 43);
    }

    /// Successive calls on a single thread are strictly increasing.
    #[test]
    fn test_increments_monotonically() {
        let counter = SequenceCounter::default();
        let mut prev = counter.next();
        for _ in 0..100 {
            let curr = counter.next();
            assert!(curr > prev, "expected {curr} > {prev}");
            prev = curr;
        }
    }

    /// `current()` is a pure read: repeated calls return the same value.
    #[test]
    fn test_current_does_not_increment() {
        let counter = SequenceCounter::default();
        assert_eq!(counter.current(), 0);
        assert_eq!(counter.current(), 0);
        counter.next();
        assert_eq!(counter.current(), 1);
        assert_eq!(counter.current(), 1);
    }

    /// Concurrent callers never receive duplicate values and no increment is lost.
    #[test]
    fn test_thread_safe_concurrent_access() {
        let num_threads = 8;
        let increments_per_thread = 1000;
        let expected_total = num_threads * increments_per_thread;

        let (counter, per_thread) = run_workers(num_threads, increments_per_thread);
        let mut all_values: Vec<u64> = per_thread.into_iter().flatten().collect();

        // Stability is irrelevant for plain integers, so the unstable sort is enough.
        all_values.sort_unstable();

        // Starting from zero, the handed-out values must be exactly 0..total.
        // (`usize` -> `u64` is lossless on every supported target.)
        assert!(
            all_values.iter().copied().eq(0..expected_total as u64),
            "expected the values 0..{expected_total} with no gaps or repeats"
        );

        // All values should be unique (no duplicates)
        all_values.dedup();
        assert_eq!(
            all_values.len(),
            expected_total,
            "expected {expected_total} unique values, got {}",
            all_values.len()
        );

        // Final counter value should equal total increments
        assert_eq!(counter.current(), expected_total as u64);
    }

    /// Under contention, each thread still observes its own values in
    /// strictly increasing order.
    #[test]
    fn test_thread_safe_values_are_monotonic_per_thread() {
        let (_counter, per_thread) = run_workers(4, 500);

        for values in per_thread {
            // Each thread's own sequence should be strictly increasing
            for window in values.windows(2) {
                assert!(
                    window[1] > window[0],
                    "expected monotonically increasing within thread, got {} followed by {}",
                    window[0],
                    window[1]
                );
            }
        }
    }
}