Skip to main content

atomr_testkit/
test_scheduler.rs

1//! `TestScheduler` — virtual-time scheduler for deterministic tests.
2//!
//! Phase 4 of `docs/full-port-plan.md`. Akka.NET parity:
//! `Akka.TestKit.TestScheduler`. Differs in API shape: the scheduler
//! keeps its own virtual clock plus a list of registered callbacks,
//! and a lightweight `advance_by`/`advance_to` helper drives that
//! clock directly, yielding to the Tokio runtime between firings.
8//!
9//! Typical pattern:
10//!
11//! ```no_run
12//! # use std::time::Duration;
13//! # use atomr_testkit::TestScheduler;
14//! # async fn ex() {
//! let sched = TestScheduler::new();
16//! let token = sched.schedule_after(Duration::from_secs(60), || println!("fired"));
17//! // No real time elapses; callback runs once we advance.
18//! sched.advance_by(Duration::from_secs(60)).await;
19//! assert!(sched.fired(token));
20//! # }
21//! ```
22
23use std::sync::{Arc, Mutex};
24use std::time::{Duration, Instant};
25
/// Boxed one-shot callback held by the scheduler until it fires.
/// `Send + 'static` because scheduler handles are `Clone` over an
/// `Arc` and may be shared across tasks/threads.
type Callback = Box<dyn FnOnce() + Send + 'static>;
27
/// Opaque handle returned by [`TestScheduler::schedule_after`].
/// Used to [`cancel`](TestScheduler::cancel) a pending callback or
/// ask whether it [`fired`](TestScheduler::fired).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScheduledToken(u64);
30
/// Internal bookkeeping for one scheduled callback.
struct Entry {
    // Virtual instant at which the callback becomes due.
    fire_at: Instant,
    // The callback itself; `take()`n (left as `None`) once fired.
    cb: Option<Callback>,
    // Set once the callback has been invoked.
    fired: bool,
    // Set by `cancel`; a cancelled entry is never fired.
    cancelled: bool,
}
37
/// Shared state behind the scheduler's mutex.
struct Inner {
    // Current virtual time; only moves forward via advance_by/advance_to.
    now: Instant,
    // Monotonically increasing counter used to mint unique tokens.
    next_token: u64,
    // Every entry ever scheduled. Entries are never removed, only
    // flagged fired/cancelled, so indices and tokens stay stable.
    entries: Vec<(ScheduledToken, Entry)>,
}
43
/// Virtual-time scheduler. Time only advances when [`advance_by`] /
/// [`advance_to`] is called.
///
/// Cloning is cheap (an `Arc` bump); all clones share the same
/// clock and entry list.
///
/// [`advance_by`]: TestScheduler::advance_by
/// [`advance_to`]: TestScheduler::advance_to
#[derive(Clone)]
pub struct TestScheduler {
    inner: Arc<Mutex<Inner>>,
}
53
54impl Default for TestScheduler {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl TestScheduler {
61    pub fn new() -> Self {
62        Self {
63            inner: Arc::new(Mutex::new(Inner { now: Instant::now(), next_token: 0, entries: Vec::new() })),
64        }
65    }
66
67    /// Current virtual time.
68    pub fn now(&self) -> Instant {
69        self.inner.lock().unwrap().now
70    }
71
72    /// Schedule `cb` to fire `delay` after the current virtual time.
73    pub fn schedule_after<F>(&self, delay: Duration, cb: F) -> ScheduledToken
74    where
75        F: FnOnce() + Send + 'static,
76    {
77        let mut g = self.inner.lock().unwrap();
78        let token = ScheduledToken(g.next_token);
79        g.next_token += 1;
80        let fire_at = g.now + delay;
81        g.entries.push((token, Entry { fire_at, cb: Some(Box::new(cb)), fired: false, cancelled: false }));
82        token
83    }
84
85    /// Cancel a scheduled callback if it hasn't fired yet.
86    pub fn cancel(&self, token: ScheduledToken) -> bool {
87        let mut g = self.inner.lock().unwrap();
88        for (tok, entry) in g.entries.iter_mut() {
89            if *tok == token && !entry.fired {
90                entry.cancelled = true;
91                return true;
92            }
93        }
94        false
95    }
96
97    /// Advance virtual time by `delta`, firing all callbacks whose
98    /// fire-at falls in the new range. Callbacks fire in fire-at order.
99    pub async fn advance_by(&self, delta: Duration) {
100        let target = {
101            let g = self.inner.lock().unwrap();
102            g.now + delta
103        };
104        self.advance_to(target).await;
105    }
106
107    /// Advance virtual time to `target` (must be ≥ current time).
108    pub async fn advance_to(&self, target: Instant) {
109        loop {
110            // Find the next due entry.
111            let next = {
112                let g = self.inner.lock().unwrap();
113                let mut due: Vec<(usize, Instant)> = g
114                    .entries
115                    .iter()
116                    .enumerate()
117                    .filter(|(_, (_, e))| !e.fired && !e.cancelled && e.fire_at <= target)
118                    .map(|(i, (_, e))| (i, e.fire_at))
119                    .collect();
120                due.sort_by_key(|(_, t)| *t);
121                due.first().copied()
122            };
123            match next {
124                Some((idx, t)) => {
125                    let cb = {
126                        let mut g = self.inner.lock().unwrap();
127                        g.now = t;
128                        let entry = &mut g.entries[idx].1;
129                        entry.fired = true;
130                        entry.cb.take()
131                    };
132                    if let Some(cb) = cb {
133                        cb();
134                    }
135                    // Yield so any spawned tasks can observe the call.
136                    tokio::task::yield_now().await;
137                }
138                None => {
139                    let mut g = self.inner.lock().unwrap();
140                    if g.now < target {
141                        g.now = target;
142                    }
143                    return;
144                }
145            }
146        }
147    }
148
149    /// Has the scheduled callback fired?
150    pub fn fired(&self, token: ScheduledToken) -> bool {
151        self.inner.lock().unwrap().entries.iter().any(|(t, e)| *t == token && e.fired)
152    }
153
154    /// How many scheduled entries are still pending (not fired,
155    /// not cancelled)?
156    pub fn pending(&self) -> usize {
157        self.inner.lock().unwrap().entries.iter().filter(|(_, e)| !e.fired && !e.cancelled).count()
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    // A callback fires exactly once when virtual time reaches its
    // deadline, and is then neither pending nor un-fired.
    #[tokio::test]
    async fn fires_after_advance() {
        let sched = TestScheduler::new();
        let hits = Arc::new(AtomicU32::new(0));
        let token = {
            let hits = hits.clone();
            sched.schedule_after(Duration::from_secs(5), move || {
                hits.fetch_add(1, Ordering::SeqCst);
            })
        };
        assert_eq!(hits.load(Ordering::SeqCst), 0);
        sched.advance_by(Duration::from_secs(5)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        assert!(sched.fired(token));
        assert_eq!(sched.pending(), 0);
    }

    // Advancing to just before the deadline must not fire anything.
    #[tokio::test]
    async fn does_not_fire_before_delay() {
        let sched = TestScheduler::new();
        let hits = Arc::new(AtomicU32::new(0));
        {
            let hits = hits.clone();
            let _ = sched.schedule_after(Duration::from_secs(10), move || {
                hits.fetch_add(1, Ordering::SeqCst);
            });
        }
        sched.advance_by(Duration::from_secs(9)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0);
        assert_eq!(sched.pending(), 1);
    }

    // A cancelled entry never fires, even after its deadline passes.
    #[tokio::test]
    async fn cancel_prevents_fire() {
        let sched = TestScheduler::new();
        let hits = Arc::new(AtomicU32::new(0));
        let token = {
            let hits = hits.clone();
            sched.schedule_after(Duration::from_secs(1), move || {
                hits.fetch_add(1, Ordering::SeqCst);
            })
        };
        assert!(sched.cancel(token));
        sched.advance_by(Duration::from_secs(2)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0);
        assert!(!sched.fired(token));
    }

    // Callbacks fire in deadline order regardless of registration
    // order.
    #[tokio::test]
    async fn fires_in_order() {
        let sched = TestScheduler::new();
        let seen = Arc::new(Mutex::new(Vec::<u32>::new()));
        let add = |secs: u64, id: u32| {
            let seen = seen.clone();
            sched.schedule_after(Duration::from_secs(secs), move || {
                seen.lock().unwrap().push(id);
            });
        };
        add(3, 1);
        add(1, 2);
        add(2, 3);
        sched.advance_by(Duration::from_secs(5)).await;
        assert_eq!(*seen.lock().unwrap(), vec![2, 3, 1]);
    }
}
221}