use std::{
cmp::Ordering,
collections::{BinaryHeap, HashMap},
time::{Duration, Instant},
};
/// Opaque handle naming a single scheduled timer.
///
/// Handed out by `TimerQueue::schedule_after` / `TimerQueue::schedule_deadline`
/// and accepted by `TimerQueue::cancel` / `TimerQueue::contains`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TimerId(u64);

impl TimerId {
    /// Wraps a raw counter value as a handle.
    fn new(raw: u64) -> Self {
        Self(raw)
    }
}
/// One scheduled timer as stored in the heap: its handle plus its firing time.
#[derive(Debug, Clone)]
struct TimerEntry {
    /// Handle that was returned to the caller when the timer was scheduled.
    id: TimerId,
    /// Instant at or after which the timer counts as expired.
    deadline: Instant,
}
/// Heap wrapper that orders [`TimerEntry`] values so that `BinaryHeap` (a
/// max-heap) pops the earliest deadline first.
///
/// Equality and ordering both use the `(deadline, raw id)` key, which keeps
/// `Ord` consistent with `PartialEq` as the `Ord` contract requires. Among
/// entries with the same deadline, the smaller id surfaces first. Ids come
/// from an incrementing counter, so equal-deadline timers fire in the order
/// they were scheduled, unless the counter has wrapped.
#[derive(Debug)]
struct HeapItem(TimerEntry);

impl HeapItem {
    /// The single key from which both equality and ordering are derived.
    fn key(&self) -> (Instant, u64) {
        (self.0.deadline, self.0.id.0)
    }
}

impl PartialEq for HeapItem {
    fn eq(&self, other: &Self) -> bool {
        self.key() == other.key()
    }
}

impl Eq for HeapItem {}

impl PartialOrd for HeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: delegate to `Ord` so the two can never disagree.
        Some(self.cmp(other))
    }
}

impl Ord for HeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed so the max-heap yields the smallest (deadline, id) first.
        other.key().cmp(&self.key())
    }
}
/// Deadline-ordered timer queue with lazy cancellation.
///
/// `active` is the source of truth for which timers are pending. `heap` keeps
/// entries ordered by deadline and may still hold entries for timers that were
/// cancelled; those are discarded lazily rather than removed eagerly.
#[derive(Debug, Default)]
pub struct TimerQueue {
    /// Raw value the next `TimerId` will be built from.
    next_id: u64,
    /// Deadline-ordered entries, possibly including cancelled ones.
    heap: BinaryHeap<HeapItem>,
    /// Pending timers mapped to their deadlines.
    active: HashMap<TimerId, Instant>,
}
impl TimerQueue {
    /// Minimum number of cancelled-but-still-queued entries before `cancel`
    /// considers rebuilding the heap, so small queues are not rebuilt on every
    /// cancellation.
    const COMPACT_MIN_STALE: usize = 64;

    /// Creates an empty queue.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if no timers are pending (cancelled timers do not count).
    pub fn is_empty(&self) -> bool {
        self.active.is_empty()
    }

    /// Number of pending timers, excluding cancelled ones.
    pub fn len(&self) -> usize {
        self.active.len()
    }

    /// Returns `true` if `id` is scheduled and has neither fired nor been cancelled.
    pub fn contains(&self, id: TimerId) -> bool {
        self.active.contains_key(&id)
    }

    /// Schedules a timer that fires `delay` after `now` and returns its handle.
    ///
    /// # Panics
    ///
    /// May panic if `now + delay` cannot be represented as an `Instant`.
    pub fn schedule_after(&mut self, delay: Duration, now: Instant) -> TimerId {
        self.schedule_deadline(now + delay)
    }

    /// Schedules a timer that fires at `deadline` and returns its handle.
    ///
    /// Handles are drawn from an internal counter that wraps on overflow.
    pub fn schedule_deadline(&mut self, deadline: Instant) -> TimerId {
        let id = TimerId::new(self.next_id);
        self.next_id = self.next_id.wrapping_add(1);
        self.active.insert(id, deadline);
        self.heap.push(HeapItem(TimerEntry { id, deadline }));
        id
    }

    /// Cancels a pending timer.
    ///
    /// Returns `true` if `id` was pending, or `false` if it is unknown, has
    /// already fired, or was already cancelled. The heap entry is dropped
    /// lazily. If cancelled entries come to outnumber live ones, the heap is
    /// rebuilt so that memory stays proportional to the number of pending timers.
    pub fn cancel(&mut self, id: TimerId) -> bool {
        if self.active.remove(&id).is_none() {
            return false;
        }
        self.compact_if_stale();
        true
    }

    /// Earliest deadline among pending timers, or `None` if there are none.
    ///
    /// Takes `&mut self` because it discards cancelled entries found at the top.
    pub fn next_deadline(&mut self) -> Option<Instant> {
        self.peek_active().map(|entry| entry.deadline)
    }

    /// Removes and returns every pending timer whose deadline is `<= now`,
    /// earliest deadline first.
    pub fn poll_expired(&mut self, now: Instant) -> Vec<TimerId> {
        let mut expired = Vec::new();
        while let Some(entry) = self.peek_active() {
            if entry.deadline > now {
                break;
            }
            self.active.remove(&entry.id);
            // `peek_active` just confirmed this entry is the live heap top.
            let _ = self.heap.pop();
            expired.push(entry.id);
        }
        expired
    }

    /// Returns a copy of the earliest live entry. Any cancelled entries that
    /// sit above it in the heap are popped first.
    fn peek_active(&mut self) -> Option<TimerEntry> {
        loop {
            let head = self.heap.peek()?.0.clone();
            if self.active.contains_key(&head.id) {
                return Some(head);
            }
            let _ = self.heap.pop();
        }
    }

    /// Rebuilds the heap without cancelled entries once they reach
    /// `COMPACT_MIN_STALE` and outnumber live timers.
    ///
    /// Each rebuild is O(heap size), and at least `live + 1` cancellations must
    /// occur between rebuilds, so the cost is amortized O(1) per cancel.
    fn compact_if_stale(&mut self) {
        let live = self.active.len();
        let stale = self.heap.len().saturating_sub(live);
        if stale < Self::COMPACT_MIN_STALE || stale <= live {
            return;
        }
        let active = &self.active;
        let entries = std::mem::take(&mut self.heap).into_vec();
        self.heap = entries
            .into_iter()
            .filter(|item| active.contains_key(&item.0.id))
            .collect();
    }
}
#[cfg(test)]
mod tests {
    // Behavioral tests for `TimerQueue`: expiry ordering, cancellation, and
    // `next_deadline` bookkeeping.
    use super::TimerQueue;
    use std::time::{Duration, Instant};

    #[test]
    fn schedule_and_expire_in_order() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();
        let t1 = timers.schedule_after(Duration::from_millis(10), start);
        let t2 = timers.schedule_after(Duration::from_millis(20), start);
        assert_eq!(timers.len(), 2);
        let expired = timers.poll_expired(start + Duration::from_millis(15));
        assert_eq!(expired, vec![t1]);
        let expired = timers.poll_expired(start + Duration::from_millis(25));
        assert_eq!(expired, vec![t2]);
        assert!(timers.is_empty());
    }

    #[test]
    fn cancellation_prevents_expiry() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();
        let t1 = timers.schedule_after(Duration::from_millis(10), start);
        let t2 = timers.schedule_after(Duration::from_millis(20), start);
        assert!(timers.cancel(t1));
        assert!(!timers.contains(t1));
        assert_eq!(timers.len(), 1);
        // Pin the exact survivor rather than only "not t1".
        let expired = timers.poll_expired(start + Duration::from_millis(25));
        assert_eq!(expired, vec![t2]);
        assert!(timers.is_empty());
    }

    #[test]
    fn cancel_is_idempotent_and_rejects_fired_timers() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();
        let t1 = timers.schedule_after(Duration::from_millis(10), start);
        assert!(timers.cancel(t1));
        assert!(!timers.cancel(t1), "second cancel must report false");
        let t2 = timers.schedule_after(Duration::from_millis(10), start);
        // A timer whose deadline equals `now` counts as expired.
        let expired = timers.poll_expired(start + Duration::from_millis(10));
        assert_eq!(expired, vec![t2]);
        assert!(!timers.cancel(t2), "a fired timer cannot be cancelled");
    }

    #[test]
    fn next_deadline_reports_earliest() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();
        let d1 = start + Duration::from_millis(30);
        let d2 = start + Duration::from_millis(10);
        timers.schedule_deadline(d1);
        timers.schedule_deadline(d2);
        assert_eq!(timers.next_deadline(), Some(d2));
    }

    #[test]
    fn next_deadline_skips_cancelled_head() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();
        let early = timers.schedule_deadline(start + Duration::from_millis(10));
        let late_deadline = start + Duration::from_millis(30);
        timers.schedule_deadline(late_deadline);
        assert!(timers.cancel(early));
        assert_eq!(timers.next_deadline(), Some(late_deadline));
    }

    #[test]
    fn empty_queue_reports_nothing() {
        let mut timers = TimerQueue::new();
        assert!(timers.is_empty());
        assert_eq!(timers.len(), 0);
        assert_eq!(timers.next_deadline(), None);
        assert!(timers.poll_expired(Instant::now()).is_empty());
    }
}