//! `TestScheduler` — virtual-time scheduler for deterministic tests.
//!
//! Maintains its own virtual clock plus a list of registered callbacks;
//! no real time elapses. `advance_by`/`advance_to` move the clock
//! forward, firing due callbacks in order and yielding to the Tokio
//! runtime after each one so spawned tasks can observe the effects.
//!
//! Typical pattern:
//!
//! ```no_run
//! # use std::time::Duration;
//! # use atomr_testkit::TestScheduler;
//! # async fn ex() {
//! let sched = TestScheduler::new();
//! let token = sched.schedule_after(Duration::from_secs(60), || println!("fired"));
//! // No real time elapses; the callback runs once we advance.
//! sched.advance_by(Duration::from_secs(60)).await;
//! assert!(sched.fired(token));
//! # }
//! ```

use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

type Callback = Box<dyn FnOnce() + Send + 'static>;

/// Opaque handle identifying a scheduled callback.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScheduledToken(u64);

struct Entry {
    /// Virtual instant at which the callback comes due.
    fire_at: Instant,
    /// Taken (set to `None`) once the callback has fired.
    cb: Option<Callback>,
    fired: bool,
    cancelled: bool,
}

struct Inner {
    /// Current virtual time.
    now: Instant,
    /// Monotonically increasing source of `ScheduledToken`s.
    next_token: u64,
    entries: Vec<(ScheduledToken, Entry)>,
}

/// Virtual-time scheduler. Time only advances when [`advance_by`] /
/// [`advance_to`] is called.
///
/// Cloning is cheap: clones share the same clock and entries, so a
/// clone can be captured by a callback or moved into a spawned task.
///
/// [`advance_by`]: TestScheduler::advance_by
/// [`advance_to`]: TestScheduler::advance_to
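///
/// A small sketch of the shared-state behaviour (the no-op callback is
/// illustrative):
///
/// ```no_run
/// # use std::time::Duration;
/// # use atomr_testkit::TestScheduler;
/// let sched = TestScheduler::new();
/// let handle = sched.clone();
/// handle.schedule_after(Duration::from_secs(1), || ());
/// // Both handles observe the same entry.
/// assert_eq!(sched.pending(), 1);
/// ```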
#[derive(Clone)]
pub struct TestScheduler {
    inner: Arc<Mutex<Inner>>,
}

impl Default for TestScheduler {
    fn default() -> Self {
        Self::new()
    }
}

impl TestScheduler {
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(Inner {
                now: Instant::now(),
                next_token: 0,
                entries: Vec::new(),
            })),
        }
    }

    /// Current virtual time.
    pub fn now(&self) -> Instant {
        self.inner.lock().unwrap().now
    }

    /// Schedule `cb` to fire `delay` after the current virtual time.
    pub fn schedule_after<F>(&self, delay: Duration, cb: F) -> ScheduledToken
    where
        F: FnOnce() + Send + 'static,
    {
        let mut g = self.inner.lock().unwrap();
        let token = ScheduledToken(g.next_token);
        g.next_token += 1;
        let fire_at = g.now + delay;
        g.entries.push((token, Entry { fire_at, cb: Some(Box::new(cb)), fired: false, cancelled: false }));
        token
    }

    /// Cancel a scheduled callback if it has not yet fired or been
    /// cancelled. Returns `true` iff this call performed the
    /// cancellation; cancelling an already-cancelled, already-fired,
    /// or unknown token is a no-op returning `false`.
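    ///
    /// A small sketch of the semantics (first cancel wins; repeats are
    /// no-ops):
    ///
    /// ```no_run
    /// # use std::time::Duration;
    /// # use atomr_testkit::TestScheduler;
    /// # async fn ex() {
    /// let sched = TestScheduler::new();
    /// let token = sched.schedule_after(Duration::from_secs(1), || ());
    /// assert!(sched.cancel(token));   // performed the cancellation
    /// assert!(!sched.cancel(token));  // already cancelled: no-op
    /// sched.advance_by(Duration::from_secs(1)).await;
    /// assert!(!sched.fired(token));   // cancelled entries never fire
    /// # }
    /// ```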
    pub fn cancel(&self, token: ScheduledToken) -> bool {
        let mut g = self.inner.lock().unwrap();
        for (tok, entry) in g.entries.iter_mut() {
            if *tok == token && !entry.fired && !entry.cancelled {
                entry.cancelled = true;
                return true;
            }
        }
        false
    }

    /// Advance virtual time by `delta`, firing every callback whose
    /// fire-at time falls within the new range. Callbacks fire in
    /// fire-at order; ties fire in the order they were scheduled.
    pub async fn advance_by(&self, delta: Duration) {
        let target = {
            let g = self.inner.lock().unwrap();
            g.now + delta
        };
        self.advance_to(target).await;
    }

    /// Advance virtual time to `target`. Targets earlier than the
    /// current time are a no-op.
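    ///
    /// A sketch using an absolute deadline (the 30s/10s values are
    /// arbitrary):
    ///
    /// ```no_run
    /// # use std::time::Duration;
    /// # use atomr_testkit::TestScheduler;
    /// # async fn ex() {
    /// let sched = TestScheduler::new();
    /// let deadline = sched.now() + Duration::from_secs(30);
    /// sched.schedule_after(Duration::from_secs(10), || println!("checkpoint"));
    /// sched.advance_to(deadline).await; // fires everything due by `deadline`
    /// assert_eq!(sched.pending(), 0);
    /// # }
    /// ```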
    pub async fn advance_to(&self, target: Instant) {
        loop {
            // Find the next due entry. Re-scan from scratch each
            // iteration so entries scheduled by a callback are picked
            // up within the same advance.
            let next = {
                let g = self.inner.lock().unwrap();
                g.entries
                    .iter()
                    .enumerate()
                    .filter(|(_, (_, e))| !e.fired && !e.cancelled && e.fire_at <= target)
                    .map(|(i, (_, e))| (i, e.fire_at))
                    .min_by_key(|&(_, t)| t)
            };
            match next {
                Some((idx, t)) => {
                    // Take the callback while holding the lock, then run
                    // it unlocked so it may call back into the scheduler.
                    let cb = {
                        let mut g = self.inner.lock().unwrap();
                        g.now = t;
                        let entry = &mut g.entries[idx].1;
                        entry.fired = true;
                        entry.cb.take()
                    };
                    if let Some(cb) = cb {
                        cb();
                    }
                    // Yield so any spawned tasks can observe the call.
                    tokio::task::yield_now().await;
                }
                None => {
                    let mut g = self.inner.lock().unwrap();
                    if g.now < target {
                        g.now = target;
                    }
                    return;
                }
            }
        }
    }

    /// Has the callback identified by `token` fired?
    pub fn fired(&self, token: ScheduledToken) -> bool {
        self.inner.lock().unwrap().entries.iter().any(|(t, e)| *t == token && e.fired)
    }

    /// Number of scheduled entries still pending (neither fired nor
    /// cancelled).
    pub fn pending(&self) -> usize {
        self.inner.lock().unwrap().entries.iter().filter(|(_, e)| !e.fired && !e.cancelled).count()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn fires_after_advance() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let c2 = counter.clone();
        let token = s.schedule_after(Duration::from_secs(5), move || {
            c2.fetch_add(1, Ordering::SeqCst);
        });
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        s.advance_by(Duration::from_secs(5)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(s.fired(token));
        assert_eq!(s.pending(), 0);
    }

    #[tokio::test]
    async fn does_not_fire_before_delay() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let c2 = counter.clone();
        let _ = s.schedule_after(Duration::from_secs(10), move || {
            c2.fetch_add(1, Ordering::SeqCst);
        });
        s.advance_by(Duration::from_secs(9)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert_eq!(s.pending(), 1);
    }

    #[tokio::test]
    async fn cancel_prevents_fire() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let c2 = counter.clone();
        let t = s.schedule_after(Duration::from_secs(1), move || {
            c2.fetch_add(1, Ordering::SeqCst);
        });
        assert!(s.cancel(t));
        s.advance_by(Duration::from_secs(2)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert!(!s.fired(t));
    }

    #[tokio::test]
    async fn fires_in_order() {
        let s = TestScheduler::new();
        let order = Arc::new(Mutex::new(Vec::<u32>::new()));
        for (delay, id) in [(3u64, 1u32), (1, 2), (2, 3)] {
            let order = order.clone();
            s.schedule_after(Duration::from_secs(delay), move || {
                order.lock().unwrap().push(id);
            });
        }
        s.advance_by(Duration::from_secs(5)).await;
        assert_eq!(*order.lock().unwrap(), vec![2, 3, 1]);
    }
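
    // A sketch of reentrancy: because `TestScheduler` is `Clone`, a
    // callback can capture a clone and schedule further work, and
    // `advance_*` re-scans after every firing, so work scheduled
    // mid-advance still fires if it comes due before the target. The
    // 1s + 1s delays here are arbitrary illustrative values.
    #[tokio::test]
    async fn callback_can_schedule_more_work() {
        let s = TestScheduler::new();
        let counter = Arc::new(AtomicU32::new(0));
        let s2 = s.clone();
        let c2 = counter.clone();
        s.schedule_after(Duration::from_secs(1), move || {
            let c3 = c2.clone();
            // Scheduled from inside a callback, due within the advance.
            s2.schedule_after(Duration::from_secs(1), move || {
                c3.fetch_add(1, Ordering::SeqCst);
            });
        });
        s.advance_by(Duration::from_secs(2)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }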
}