atomr_testkit/
test_scheduler.rs1use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23
/// A boxed one-shot callback that may be sent across threads.
///
/// `Box<dyn Trait>` defaults to a `'static` object lifetime, so the closure may not borrow
/// from its environment.
type Callback = Box<dyn FnOnce() + Send>;
25
/// Opaque handle identifying one callback registered with `TestScheduler::schedule_after`.
///
/// Pass it to `cancel` to revoke the callback, or to `fired` to ask whether it has run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScheduledToken(u64);
28
29struct Entry {
30 fire_at: Instant,
31 cb: Option<Callback>,
32 fired: bool,
33 cancelled: bool,
34}
35
36struct Inner {
37 now: Instant,
38 next_token: u64,
39 entries: Vec<(ScheduledToken, Entry)>,
40}
41
42#[derive(Clone)]
48pub struct TestScheduler {
49 inner: Arc<Mutex<Inner>>,
50}
51
52impl Default for TestScheduler {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl TestScheduler {
59 pub fn new() -> Self {
60 Self {
61 inner: Arc::new(Mutex::new(Inner { now: Instant::now(), next_token: 0, entries: Vec::new() })),
62 }
63 }
64
65 pub fn now(&self) -> Instant {
67 self.inner.lock().unwrap().now
68 }
69
70 pub fn schedule_after<F>(&self, delay: Duration, cb: F) -> ScheduledToken
72 where
73 F: FnOnce() + Send + 'static,
74 {
75 let mut g = self.inner.lock().unwrap();
76 let token = ScheduledToken(g.next_token);
77 g.next_token += 1;
78 let fire_at = g.now + delay;
79 g.entries.push((token, Entry { fire_at, cb: Some(Box::new(cb)), fired: false, cancelled: false }));
80 token
81 }
82
83 pub fn cancel(&self, token: ScheduledToken) -> bool {
88 let mut g = self.inner.lock().unwrap();
89 for (tok, entry) in g.entries.iter_mut() {
90 if *tok == token && !entry.fired && !entry.cancelled {
91 entry.cancelled = true;
92 return true;
93 }
94 }
95 false
96 }
97
98 pub async fn advance_by(&self, delta: Duration) {
101 let target = {
102 let g = self.inner.lock().unwrap();
103 g.now + delta
104 };
105 self.advance_to(target).await;
106 }
107
108 pub async fn advance_to(&self, target: Instant) {
110 loop {
111 let next = {
113 let g = self.inner.lock().unwrap();
114 let mut due: Vec<(usize, Instant)> = g
115 .entries
116 .iter()
117 .enumerate()
118 .filter(|(_, (_, e))| !e.fired && !e.cancelled && e.fire_at <= target)
119 .map(|(i, (_, e))| (i, e.fire_at))
120 .collect();
121 due.sort_by_key(|(_, t)| *t);
122 due.first().copied()
123 };
124 match next {
125 Some((idx, t)) => {
126 let cb = {
127 let mut g = self.inner.lock().unwrap();
128 g.now = t;
129 let entry = &mut g.entries[idx].1;
130 entry.fired = true;
131 entry.cb.take()
132 };
133 if let Some(cb) = cb {
134 cb();
135 }
136 tokio::task::yield_now().await;
138 }
139 None => {
140 let mut g = self.inner.lock().unwrap();
141 if g.now < target {
142 g.now = target;
143 }
144 return;
145 }
146 }
147 }
148 }
149
150 pub fn fired(&self, token: ScheduledToken) -> bool {
152 self.inner.lock().unwrap().entries.iter().any(|(t, e)| *t == token && e.fired)
153 }
154
155 pub fn pending(&self) -> usize {
158 self.inner.lock().unwrap().entries.iter().filter(|(_, e)| !e.fired && !e.cancelled).count()
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Returns a fresh counter plus a callback that increments it once when run.
    fn counting_callback() -> (Arc<AtomicU32>, impl FnOnce() + Send + 'static) {
        let counter = Arc::new(AtomicU32::new(0));
        let handle = Arc::clone(&counter);
        let cb = move || {
            handle.fetch_add(1, Ordering::SeqCst);
        };
        (counter, cb)
    }

    #[tokio::test]
    async fn fires_after_advance() {
        let s = TestScheduler::new();
        let (counter, cb) = counting_callback();
        let token = s.schedule_after(Duration::from_secs(5), cb);
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        s.advance_by(Duration::from_secs(5)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(s.fired(token));
        assert_eq!(s.pending(), 0);
    }

    #[tokio::test]
    async fn does_not_fire_before_delay() {
        let s = TestScheduler::new();
        let (counter, cb) = counting_callback();
        let _ = s.schedule_after(Duration::from_secs(10), cb);
        s.advance_by(Duration::from_secs(9)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert_eq!(s.pending(), 1);
    }

    #[tokio::test]
    async fn cancel_prevents_fire() {
        let s = TestScheduler::new();
        let (counter, cb) = counting_callback();
        let t = s.schedule_after(Duration::from_secs(1), cb);
        assert!(s.cancel(t));
        s.advance_by(Duration::from_secs(2)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert!(!s.fired(t));
        assert_eq!(s.pending(), 0);
    }

    #[tokio::test]
    async fn cancel_after_fire_returns_false() {
        let s = TestScheduler::new();
        let t = s.schedule_after(Duration::from_secs(1), || {});
        s.advance_by(Duration::from_secs(1)).await;
        assert!(s.fired(t));
        assert!(!s.cancel(t));
    }

    #[test]
    fn cancel_twice_returns_false() {
        let s = TestScheduler::new();
        let t = s.schedule_after(Duration::from_secs(1), || {});
        assert!(s.cancel(t));
        assert!(!s.cancel(t));
        assert_eq!(s.pending(), 0);
    }

    #[test]
    fn unknown_token_is_rejected() {
        let s = TestScheduler::new();
        let bogus = ScheduledToken(42);
        assert!(!s.cancel(bogus));
        assert!(!s.fired(bogus));
    }

    #[tokio::test]
    async fn fires_in_order() {
        let s = TestScheduler::new();
        let order = Arc::new(Mutex::new(Vec::<u32>::new()));
        for (delay, id) in [(3u64, 1u32), (1, 2), (2, 3)] {
            let order = order.clone();
            s.schedule_after(Duration::from_secs(delay), move || {
                order.lock().unwrap().push(id);
            });
        }
        s.advance_by(Duration::from_secs(5)).await;
        assert_eq!(*order.lock().unwrap(), vec![2, 3, 1]);
    }

    #[tokio::test]
    async fn equal_deadlines_fire_in_scheduling_order() {
        let s = TestScheduler::new();
        let order = Arc::new(Mutex::new(Vec::<u32>::new()));
        for id in 0..4u32 {
            let order = order.clone();
            s.schedule_after(Duration::from_secs(2), move || {
                order.lock().unwrap().push(id);
            });
        }
        s.advance_by(Duration::from_secs(2)).await;
        assert_eq!(*order.lock().unwrap(), vec![0, 1, 2, 3]);
    }

    #[tokio::test]
    async fn clock_tracks_deadlines_and_target() {
        let s = TestScheduler::new();
        let start = s.now();
        let seen = Arc::new(Mutex::new(None));
        let (s2, seen2) = (s.clone(), seen.clone());
        s.schedule_after(Duration::from_secs(2), move || {
            *seen2.lock().unwrap() = Some(s2.now());
        });
        s.advance_by(Duration::from_secs(5)).await;
        assert_eq!(*seen.lock().unwrap(), Some(start + Duration::from_secs(2)));
        assert_eq!(s.now(), start + Duration::from_secs(5));
    }

    #[tokio::test]
    async fn advance_to_past_does_not_rewind() {
        let s = TestScheduler::new();
        let start = s.now();
        s.advance_by(Duration::from_secs(3)).await;
        s.advance_to(start).await;
        assert_eq!(s.now(), start + Duration::from_secs(3));
    }

    #[tokio::test]
    async fn callback_can_schedule_follow_up_within_same_advance() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let (s2, c2) = (s.clone(), counter.clone());
        s.schedule_after(Duration::from_secs(1), move || {
            s2.schedule_after(Duration::from_secs(1), move || {
                c2.fetch_add(1, Ordering::SeqCst);
            });
        });
        s.advance_by(Duration::from_secs(2)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(s.pending(), 0);
    }
}