1use std::sync::atomic::{AtomicU64, Ordering};
2
/// A counter that hands out `u64` sequence numbers.
///
/// The only state is a single [`AtomicU64`], so every operation works through
/// a shared reference and the counter can be used from several threads at
/// once (for example behind an `Arc`) without an additional lock.
#[derive(Debug)]
pub struct SequenceCounter {
    /// The number that the next call to `next` will return.
    value: AtomicU64,
}
22
23impl SequenceCounter {
24 pub fn new(start: u64) -> Self {
26 Self { value: AtomicU64::new(start) }
27 }
28
29 pub fn next(&self) -> u64 {
32 self.value.fetch_add(1, Ordering::SeqCst)
33 }
34
35 pub fn current(&self) -> u64 {
37 self.value.load(Ordering::SeqCst)
38 }
39}
40
41impl Default for SequenceCounter {
42 fn default() -> Self {
43 Self::new(0)
44 }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Starts `num_threads` threads that each call `next` on `counter`
    /// `increments_per_thread` times.
    ///
    /// Each returned handle yields the values its thread received, in the
    /// order that thread received them.
    fn spawn_workers(
        counter: &Arc<SequenceCounter>,
        num_threads: usize,
        increments_per_thread: usize,
    ) -> Vec<thread::JoinHandle<Vec<u64>>> {
        let mut workers = Vec::with_capacity(num_threads);
        for _ in 0..num_threads {
            let shared = Arc::clone(counter);
            workers.push(thread::spawn(move || {
                (0..increments_per_thread)
                    .map(|_| shared.next())
                    .collect::<Vec<u64>>()
            }));
        }
        workers
    }

    /// A default counter hands out 0 first.
    #[test]
    fn test_starts_at_zero() {
        assert_eq!(SequenceCounter::default().next(), 0);
    }

    /// A counter built with `new(42)` hands out 42, then 43.
    #[test]
    fn test_starts_at_custom_value() {
        let counter = SequenceCounter::new(42);
        for expected in [42, 43] {
            assert_eq!(counter.next(), expected);
        }
    }

    /// 101 consecutive calls on one thread yield strictly increasing values.
    #[test]
    fn test_increments_monotonically() {
        let counter = SequenceCounter::default();
        let issued: Vec<u64> = (0..=100).map(|_| counter.next()).collect();
        for pair in issued.windows(2) {
            let (prev, curr) = (pair[0], pair[1]);
            assert!(curr > prev, "expected {curr} > {prev}");
        }
    }

    /// `current` reports the value repeatedly without advancing it.
    #[test]
    fn test_current_does_not_increment() {
        let counter = SequenceCounter::default();
        for _ in 0..2 {
            assert_eq!(counter.current(), 0);
        }
        counter.next();
        for _ in 0..2 {
            assert_eq!(counter.current(), 1);
        }
    }

    /// Eight threads taking 1000 values each never receive a duplicate, and
    /// the counter ends at the total number of values taken.
    #[test]
    fn test_thread_safe_concurrent_access() {
        let num_threads = 8;
        let increments_per_thread = 1000;
        let counter = Arc::new(SequenceCounter::default());

        let mut all_values: Vec<u64> =
            spawn_workers(&counter, num_threads, increments_per_thread)
                .into_iter()
                .flat_map(|worker| worker.join().unwrap())
                .collect();

        // After sorting and removing duplicates, any shortfall in length
        // means two threads received the same value.
        all_values.sort_unstable();
        all_values.dedup();
        let expected_total = num_threads * increments_per_thread;
        assert_eq!(
            all_values.len(),
            expected_total,
            "expected {expected_total} unique values, got {}",
            all_values.len()
        );

        assert_eq!(counter.current(), expected_total as u64);
    }

    /// Under contention from four threads, each thread still sees its own
    /// values strictly increase.
    #[test]
    fn test_thread_safe_values_are_monotonic_per_thread() {
        let counter = Arc::new(SequenceCounter::default());

        for worker in spawn_workers(&counter, 4, 500) {
            let values = worker.join().unwrap();
            for (earlier, later) in values.iter().zip(values.iter().skip(1)) {
                assert!(
                    later > earlier,
                    "expected monotonically increasing within thread, got {} followed by {}",
                    earlier,
                    later
                );
            }
        }
    }
}