1use std::{
6 cmp::Ordering,
7 collections::{BinaryHeap, HashMap},
8 time::{Duration, Instant},
9};
10
/// Opaque handle identifying one scheduled timer.
///
/// Handles are `Copy` and hashable, so they can be stored freely and used as
/// map keys. The inner value and the constructor are private, so only this
/// module can mint handles.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TimerId(u64);

impl TimerId {
    /// Wraps a raw sequence number in a handle.
    fn new(seq: u64) -> Self {
        Self(seq)
    }
}
20
21#[derive(Debug, Clone)]
23struct TimerEntry {
24 id: TimerId,
25 deadline: Instant,
26}
27
28#[derive(Debug)]
30struct HeapItem(TimerEntry);
31
32impl PartialEq for HeapItem {
33 fn eq(&self, other: &Self) -> bool {
34 self.0.deadline.eq(&other.0.deadline) && self.0.id.eq(&other.0.id)
35 }
36}
37
38impl Eq for HeapItem {}
39
40impl PartialOrd for HeapItem {
41 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
42 Some(other.0.deadline.cmp(&self.0.deadline))
44 }
45}
46
47impl Ord for HeapItem {
48 fn cmp(&self, other: &Self) -> Ordering {
49 other.0.deadline.cmp(&self.0.deadline)
51 }
52}
53
54#[derive(Debug, Default)]
60pub struct TimerQueue {
61 next_id: u64,
62 heap: BinaryHeap<HeapItem>,
63 active: HashMap<TimerId, Instant>,
66}
67
68impl TimerQueue {
69 pub fn new() -> Self {
71 Self::default()
72 }
73
74 pub fn is_empty(&self) -> bool {
76 self.active.is_empty()
77 }
78
79 pub fn len(&self) -> usize {
81 self.active.len()
82 }
83
84 pub fn contains(&self, id: TimerId) -> bool {
86 self.active.contains_key(&id)
87 }
88
89 pub fn schedule_after(&mut self, delay: Duration, now: Instant) -> TimerId {
91 self.schedule_deadline(now + delay)
92 }
93
94 pub fn schedule_deadline(&mut self, deadline: Instant) -> TimerId {
96 let id = TimerId::new(self.next_id);
97 self.next_id = self.next_id.wrapping_add(1);
98
99 let entry = TimerEntry { id, deadline };
100 self.active.insert(id, deadline);
101 self.heap.push(HeapItem(entry));
102
103 id
104 }
105
106 pub fn cancel(&mut self, id: TimerId) -> bool {
109 self.active.remove(&id).is_some()
110 }
111
112 pub fn next_deadline(&mut self) -> Option<Instant> {
114 self.peek_active().map(|entry| entry.deadline)
115 }
116
117 pub fn poll_expired(&mut self, now: Instant) -> Vec<TimerId> {
120 let mut expired = Vec::new();
121
122 while let Some(entry) = self.peek_active() {
123 if entry.deadline > now {
124 break;
125 }
126
127 let id = entry.id;
128 self.active.remove(&id);
131 let _ = self.heap.pop();
136 expired.push(id);
137 }
138
139 expired
140 }
141
142 fn peek_active(&mut self) -> Option<TimerEntry> {
143 loop {
144 let head = match self.heap.peek() {
145 Some(item) => item.0.clone(),
146 None => return None,
147 };
148
149 if self.active.contains_key(&head.id) {
150 return Some(head);
151 }
152
153 let _ = self.heap.pop();
155 }
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::TimerQueue;
    use std::time::{Duration, Instant};

    #[test]
    fn schedule_and_expire_in_order() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();

        let t1 = timers.schedule_after(Duration::from_millis(10), start);
        let t2 = timers.schedule_after(Duration::from_millis(20), start);

        assert_eq!(timers.len(), 2);

        let expired = timers.poll_expired(start + Duration::from_millis(15));
        assert_eq!(expired, vec![t1]);

        let expired = timers.poll_expired(start + Duration::from_millis(25));
        assert_eq!(expired, vec![t2]);
        assert!(timers.is_empty());
    }

    #[test]
    fn cancellation_prevents_expiry() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();

        let t1 = timers.schedule_after(Duration::from_millis(10), start);
        let _t2 = timers.schedule_after(Duration::from_millis(20), start);

        assert!(timers.cancel(t1));

        let expired = timers.poll_expired(start + Duration::from_millis(25));
        assert_eq!(expired.len(), 1);
        assert_ne!(expired[0], t1);
    }

    #[test]
    fn next_deadline_reports_earliest() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();

        let d1 = start + Duration::from_millis(30);
        let d2 = start + Duration::from_millis(10);

        timers.schedule_deadline(d1);
        timers.schedule_deadline(d2);

        assert_eq!(timers.next_deadline(), Some(d2));
    }

    /// Expiry is inclusive: a timer fires when `now == deadline`.
    #[test]
    fn deadline_equal_to_now_expires() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();

        let t = timers.schedule_after(Duration::from_millis(10), start);

        assert!(timers.poll_expired(start + Duration::from_millis(9)).is_empty());
        assert!(timers.contains(t));
        assert_eq!(timers.poll_expired(start + Duration::from_millis(10)), vec![t]);
        assert!(!timers.contains(t));
    }

    #[test]
    fn cancel_rejects_repeated_and_fired_timers() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();

        let t1 = timers.schedule_after(Duration::from_millis(10), start);
        let t2 = timers.schedule_after(Duration::from_millis(20), start);

        assert!(timers.cancel(t1));
        assert!(!timers.cancel(t1));

        assert_eq!(timers.poll_expired(start + Duration::from_millis(20)), vec![t2]);
        assert!(!timers.cancel(t2));
        assert!(timers.is_empty());
    }

    #[test]
    fn next_deadline_skips_cancelled_head() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();

        let early = timers.schedule_after(Duration::from_millis(5), start);
        timers.schedule_after(Duration::from_millis(50), start);

        assert!(timers.cancel(early));
        assert_eq!(timers.next_deadline(), Some(start + Duration::from_millis(50)));
        assert_eq!(timers.len(), 1);
    }

    #[test]
    fn empty_queue_reports_nothing() {
        let mut timers = TimerQueue::new();

        assert!(timers.is_empty());
        assert_eq!(timers.len(), 0);
        assert_eq!(timers.next_deadline(), None);
        assert!(timers.poll_expired(Instant::now()).is_empty());
    }

    /// Cancelling most of a large batch must not disturb the survivors.
    #[test]
    fn mass_cancellation_keeps_survivors() {
        let start = Instant::now();
        let mut timers = TimerQueue::new();

        let ids: Vec<_> = (0..500u64)
            .map(|i| timers.schedule_after(Duration::from_millis(i), start))
            .collect();

        for (i, id) in ids.iter().enumerate() {
            if i % 10 != 0 {
                assert!(timers.cancel(*id));
            }
        }
        assert_eq!(timers.len(), 50);

        let expected: Vec<_> = ids.iter().copied().step_by(10).collect();
        let expired = timers.poll_expired(start + Duration::from_millis(1_000));
        assert_eq!(expired, expected);
        assert!(timers.is_empty());
    }
}