// rakka_testkit/test_scheduler.rs
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::{Duration, Instant};

/// Boxed one-shot callback held by the scheduler until it is run.
///
/// A boxed trait object without an explicit lifetime bound defaults to `'static`, so this is
/// the same type as `Box<dyn FnOnce() + Send + 'static>`.
type Callback = Box<dyn FnOnce() + Send>;

/// Opaque handle for one callback registered with `TestScheduler::schedule_after`.
///
/// Each `schedule_after` call on a scheduler hands out a fresh value. Tokens are `Copy` and can
/// be compared or hashed, e.g. to keep them in a set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScheduledToken(u64);

31struct Entry {
32 fire_at: Instant,
33 cb: Option<Callback>,
34 fired: bool,
35 cancelled: bool,
36}
37
38struct Inner {
39 now: Instant,
40 next_token: u64,
41 entries: Vec<(ScheduledToken, Entry)>,
42}
43
44#[derive(Clone)]
50pub struct TestScheduler {
51 inner: Arc<Mutex<Inner>>,
52}
53
54impl Default for TestScheduler {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl TestScheduler {
61 pub fn new() -> Self {
62 Self {
63 inner: Arc::new(Mutex::new(Inner { now: Instant::now(), next_token: 0, entries: Vec::new() })),
64 }
65 }
66
67 pub fn now(&self) -> Instant {
69 self.inner.lock().unwrap().now
70 }
71
72 pub fn schedule_after<F>(&self, delay: Duration, cb: F) -> ScheduledToken
74 where
75 F: FnOnce() + Send + 'static,
76 {
77 let mut g = self.inner.lock().unwrap();
78 let token = ScheduledToken(g.next_token);
79 g.next_token += 1;
80 let fire_at = g.now + delay;
81 g.entries.push((token, Entry { fire_at, cb: Some(Box::new(cb)), fired: false, cancelled: false }));
82 token
83 }
84
85 pub fn cancel(&self, token: ScheduledToken) -> bool {
87 let mut g = self.inner.lock().unwrap();
88 for (tok, entry) in g.entries.iter_mut() {
89 if *tok == token && !entry.fired {
90 entry.cancelled = true;
91 return true;
92 }
93 }
94 false
95 }
96
97 pub async fn advance_by(&self, delta: Duration) {
100 let target = {
101 let g = self.inner.lock().unwrap();
102 g.now + delta
103 };
104 self.advance_to(target).await;
105 }
106
107 pub async fn advance_to(&self, target: Instant) {
109 loop {
110 let next = {
112 let g = self.inner.lock().unwrap();
113 let mut due: Vec<(usize, Instant)> = g
114 .entries
115 .iter()
116 .enumerate()
117 .filter(|(_, (_, e))| !e.fired && !e.cancelled && e.fire_at <= target)
118 .map(|(i, (_, e))| (i, e.fire_at))
119 .collect();
120 due.sort_by_key(|(_, t)| *t);
121 due.first().copied()
122 };
123 match next {
124 Some((idx, t)) => {
125 let cb = {
126 let mut g = self.inner.lock().unwrap();
127 g.now = t;
128 let entry = &mut g.entries[idx].1;
129 entry.fired = true;
130 entry.cb.take()
131 };
132 if let Some(cb) = cb {
133 cb();
134 }
135 tokio::task::yield_now().await;
137 }
138 None => {
139 let mut g = self.inner.lock().unwrap();
140 if g.now < target {
141 g.now = target;
142 }
143 return;
144 }
145 }
146 }
147 }
148
149 pub fn fired(&self, token: ScheduledToken) -> bool {
151 self.inner.lock().unwrap().entries.iter().any(|(t, e)| *t == token && e.fired)
152 }
153
154 pub fn pending(&self) -> usize {
157 self.inner.lock().unwrap().entries.iter().filter(|(_, e)| !e.fired && !e.cancelled).count()
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn fires_after_advance() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let c2 = counter.clone();
        let token = s.schedule_after(Duration::from_secs(5), move || {
            c2.fetch_add(1, Ordering::SeqCst);
        });
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        s.advance_by(Duration::from_secs(5)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(s.fired(token));
        assert_eq!(s.pending(), 0);
    }

    #[tokio::test]
    async fn does_not_fire_before_delay() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let c2 = counter.clone();
        let _ = s.schedule_after(Duration::from_secs(10), move || {
            c2.fetch_add(1, Ordering::SeqCst);
        });
        s.advance_by(Duration::from_secs(9)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert_eq!(s.pending(), 1);
    }

    #[tokio::test]
    async fn cancel_prevents_fire() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let c2 = counter.clone();
        let t = s.schedule_after(Duration::from_secs(1), move || {
            c2.fetch_add(1, Ordering::SeqCst);
        });
        assert!(s.cancel(t));
        s.advance_by(Duration::from_secs(2)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert!(!s.fired(t));
    }

    #[tokio::test]
    async fn fires_in_order() {
        let s = TestScheduler::new();
        let order = Arc::new(Mutex::new(Vec::<u32>::new()));
        for (delay, id) in [(3u64, 1u32), (1, 2), (2, 3)] {
            let order = order.clone();
            s.schedule_after(Duration::from_secs(delay), move || {
                order.lock().unwrap().push(id);
            });
        }
        s.advance_by(Duration::from_secs(5)).await;
        assert_eq!(*order.lock().unwrap(), vec![2, 3, 1]);
    }

    // Callbacks sharing a deadline must run in the order they were scheduled.
    #[tokio::test]
    async fn equal_deadlines_fire_in_schedule_order() {
        let s = TestScheduler::new();
        let order = Arc::new(Mutex::new(Vec::<u32>::new()));
        for id in 1..=3u32 {
            let order = order.clone();
            s.schedule_after(Duration::from_secs(1), move || {
                order.lock().unwrap().push(id);
            });
        }
        s.advance_by(Duration::from_secs(1)).await;
        assert_eq!(*order.lock().unwrap(), vec![1, 2, 3]);
    }

    // A zero delay is due immediately, so even a zero-length advance runs it.
    #[tokio::test]
    async fn zero_delay_fires_on_zero_advance() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let c2 = counter.clone();
        let token = s.schedule_after(Duration::ZERO, move || {
            c2.fetch_add(1, Ordering::SeqCst);
        });
        s.advance_by(Duration::ZERO).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(s.fired(token));
    }

    // Inside a callback the clock reads its deadline; afterwards it reads the advance target.
    #[tokio::test]
    async fn clock_reaches_target_and_callbacks_see_fire_time() {
        let s = TestScheduler::new();
        let start = s.now();
        let seen = Arc::new(Mutex::new(None::<Instant>));
        let (s2, seen2) = (s.clone(), seen.clone());
        s.schedule_after(Duration::from_secs(2), move || {
            *seen2.lock().unwrap() = Some(s2.now());
        });
        s.advance_by(Duration::from_secs(5)).await;
        assert_eq!(*seen.lock().unwrap(), Some(start + Duration::from_secs(2)));
        assert_eq!(s.now(), start + Duration::from_secs(5));
    }

    // A callback may schedule more work; if it falls within the target it runs in the same advance.
    #[tokio::test]
    async fn callback_can_schedule_follow_up() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let (s2, c2) = (s.clone(), counter.clone());
        s.schedule_after(Duration::from_secs(2), move || {
            c2.fetch_add(1, Ordering::SeqCst);
            s2.schedule_after(Duration::from_secs(1), move || {
                c2.fetch_add(1, Ordering::SeqCst);
            });
        });
        s.advance_by(Duration::from_secs(5)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        assert_eq!(s.pending(), 0);
    }

    #[tokio::test]
    async fn cancel_after_fire_returns_false() {
        let s = TestScheduler::new();
        let t = s.schedule_after(Duration::from_secs(1), || {});
        s.advance_by(Duration::from_secs(1)).await;
        assert!(s.fired(t));
        assert!(!s.cancel(t));
    }
}