/// A timer waiting to fire: either a one-shot timeout or a repeating interval.
#[derive(Debug, Clone)]
pub struct PendingTimer {
    /// Identifier handed out by [`TimerState::register`]; kept across interval re-arms.
    pub id: u32,
    /// Delay in simulated milliseconds between (re-)arming and firing.
    pub delay_ms: u32,
    /// `true` if the timer is re-armed every time it fires.
    pub is_interval: bool,
    /// Callback text handed back to the caller when the timer fires.
    pub callback_source: String,
    /// Simulated time at which the timer was registered or last re-armed.
    pub registered_at_ms: u64,
}

/// Maximum number of timers that may be pending at once; `TimerState::register`
/// refuses new timers once this many are queued.
const MAX_PENDING_TIMERS: usize = 10_000;

/// Queue of pending timers driven by a simulated (not wall-clock) millisecond clock.
#[derive(Debug, Default)]
pub struct TimerState {
    /// Pending timers in queue order. Cancelled timers are removed eagerly, so
    /// every entry here is live.
    timers: std::collections::VecDeque<PendingTimer>,
    /// Identifier that the next successful `register` call will return.
    next_id: u32,
    /// Current simulated time in milliseconds; never moves backwards.
    simulated_time_ms: u64,
    /// Distinct callback sources returned by `drain_next` since the last reset.
    drained_callbacks: std::collections::HashSet<String>,
}

impl TimerState {
    /// Creates an empty state with the simulated clock at zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Simulated time at which `timer` becomes due, or `None` when that instant
    /// does not fit in a `u64`.
    fn due_at(timer: &PendingTimer) -> Option<u64> {
        timer.registered_at_ms.checked_add(u64::from(timer.delay_ms))
    }

    /// Copy of `timer` re-armed so that its delay starts counting from `now_ms`.
    fn rearm(timer: &PendingTimer, now_ms: u64) -> PendingTimer {
        PendingTimer {
            registered_at_ms: now_ms,
            ..timer.clone()
        }
    }

    /// Queues a new timer and returns its identifier.
    ///
    /// Returns `None` without queuing anything when [`MAX_PENDING_TIMERS`] timers
    /// are already pending or when the `u32` identifier space is exhausted.
    pub fn register(
        &mut self,
        delay_ms: u32,
        is_interval: bool,
        callback_source: String,
    ) -> Option<u32> {
        if self.pending_count() >= MAX_PENDING_TIMERS {
            return None;
        }
        let id = self.next_id;
        self.next_id = self.next_id.checked_add(1)?;
        self.timers.push_back(PendingTimer {
            id,
            delay_ms,
            is_interval,
            callback_source,
            registered_at_ms: self.simulated_time_ms,
        });
        Some(id)
    }

    /// Cancels the pending timer with the given `id`.
    ///
    /// Unknown identifiers are ignored. In particular, cancelling an identifier
    /// that has not been issued yet has no effect on the timer that later
    /// receives it.
    pub fn cancel(&mut self, id: u32) {
        // Remove eagerly instead of remembering cancelled ids: nothing is left
        // behind to accumulate, and every queued timer is known to be live.
        self.timers.retain(|timer| timer.id != id);
    }

    /// Fires the timer at the front of the queue and returns its callback source.
    ///
    /// Timers are taken in queue order (registration order, with re-armed
    /// intervals moving to the back), regardless of their delays. The simulated
    /// clock jumps forward to the timer's due time if that lies in the future and
    /// is otherwise left alone. The callback source is recorded for
    /// [`Self::is_callback_drained`] and [`Self::unique_drained_count`].
    ///
    /// Returns `None` when no timer is pending.
    pub fn drain_next(&mut self) -> Option<String> {
        let timer = self.timers.pop_front()?;
        let due = timer
            .registered_at_ms
            .saturating_add(u64::from(timer.delay_ms));
        self.simulated_time_ms = self.simulated_time_ms.max(due);
        // Only allocate a copy for the tracking set when the source is new.
        if !self.drained_callbacks.contains(&timer.callback_source) {
            self.drained_callbacks.insert(timer.callback_source.clone());
        }
        if timer.is_interval {
            self.timers
                .push_back(Self::rearm(&timer, self.simulated_time_ms));
        }
        Some(timer.callback_source)
    }

    /// Reports whether `callback_source` has been returned by
    /// [`Self::drain_next`] since the last [`Self::reset_drained_tracking`].
    #[must_use]
    pub fn is_callback_drained(&self, callback_source: &str) -> bool {
        self.drained_callbacks.contains(callback_source)
    }

    /// Number of distinct callback sources returned by [`Self::drain_next`]
    /// since the last [`Self::reset_drained_tracking`].
    #[must_use]
    pub fn unique_drained_count(&self) -> usize {
        self.drained_callbacks.len()
    }

    /// Forgets which callback sources have been drained so far.
    pub fn reset_drained_tracking(&mut self) {
        self.drained_callbacks.clear();
    }

    /// Advances the simulated clock by `advance_ms` (saturating at `u64::MAX`)
    /// and returns the callback sources of every timer that became due, in
    /// chronological order. Timers due at the same instant fire in queue order.
    ///
    /// Intervals are re-armed at the moment they fire and may fire several times
    /// within one call. An interval with a zero delay fires at most once per
    /// call, because it would otherwise be due again immediately and the call
    /// would never return. A timer whose due time does not fit in a `u64` is
    /// never considered due here.
    ///
    /// Callbacks fired by this method are not added to the drained-callback
    /// tracking; only [`Self::drain_next`] feeds it.
    pub fn fast_forward(&mut self, advance_ms: u64) -> Vec<String> {
        let target_time = self.simulated_time_ms.saturating_add(advance_ms);
        let mut callbacks = Vec::new();
        // Zero-delay intervals that already fired; re-queued once the loop ends.
        let mut parked = Vec::new();
        loop {
            // Earliest due timer; the tuple ordering breaks ties by queue position.
            let next_due = self
                .timers
                .iter()
                .enumerate()
                .filter_map(|(index, timer)| {
                    let due = Self::due_at(timer)?;
                    (due <= target_time).then_some((due, index))
                })
                .min();
            let Some((due, index)) = next_due else { break };
            let Some(timer) = self.timers.remove(index) else {
                break;
            };
            // A timer may already be overdue (e.g. after `drain_next` jumped the
            // clock ahead); the clock must still never run backwards.
            self.simulated_time_ms = self.simulated_time_ms.max(due);
            if timer.is_interval {
                let rearmed = Self::rearm(&timer, self.simulated_time_ms);
                if rearmed.delay_ms == 0 {
                    parked.push(rearmed);
                } else {
                    self.timers.push_back(rearmed);
                }
            }
            callbacks.push(timer.callback_source);
        }
        self.timers.extend(parked);
        self.simulated_time_ms = target_time;
        callbacks
    }

    /// Number of timers currently waiting to fire.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.timers.len()
    }

    /// Current simulated time in milliseconds.
    #[must_use]
    pub fn simulated_time_ms(&self) -> u64 {
        self.simulated_time_ms
    }
}