1use super::{timer::Meta, tw_proto::TimeWheelProto};
2use futures::{ready, Future, FutureExt};
3use parking_lot::RwLock;
4use pin_project::pin_project;
5use std::{
6 collections::{HashMap, VecDeque},
7 fmt::Debug,
8 ops::{Add, AddAssign},
9 pin::Pin,
10 sync::Arc,
11 task::Poll,
12};
13use tokio::sync::mpsc;
14
/// Client handle to a time wheel whose state lives in background tokio tasks.
///
/// Commands (add, cancel, accelerate, delay, trigger) are sent to the wheel task
/// over `inner_tx`; timers that fire come back as batches on `tick_rx`. The
/// `T: Debug + Send` bound is repeated on the `Drop` impl, which Rust requires to
/// match the struct's own bounds.
pub struct WheelProxy<T>
where
    T: Debug + Send,
{
    /// Number of slots in one rotation of the wheel.
    slot: u32,
    /// Time covered by each slot.
    slot_duration: std::time::Duration,
    /// Start of the current round, shared with the wheel task, which advances it
    /// by `slot_duration` on every tick.
    start: Arc<RwLock<std::time::Instant>>,
    /// Batches of fired (ticked or triggered) timers sent by the wheel task.
    tick_rx: mpsc::Receiver<VecDeque<Meta<T>>>,
    /// Command channel to the wheel task.
    inner_tx: mpsc::Sender<TimeWheelProto<T>>,
    /// Snapshots of timers added through this handle that have not yet been
    /// cancelled or returned by `tick`, keyed by timer id.
    timer_map: HashMap<u64, super::Snapshot>,

    /// Ticker task handle; aborted on drop.
    ticker_join: tokio::task::JoinHandle<()>,
    /// Wheel task handle; aborted on drop.
    inner_join: tokio::task::JoinHandle<()>,
}
48
49impl<T> WheelProxy<T>
50where
51 T: Send + Debug + 'static,
52{
53 pub fn new(slot: u32, slot_duration: std::time::Duration, start: std::time::Instant) -> Self {
54 Inner::new(slot, slot_duration, start)
55 }
56
57 pub fn slot(&self) -> u32 {
58 self.slot
59 }
60
61 pub fn slot_duration(&self) -> std::time::Duration {
62 self.slot_duration.clone()
63 }
64
65 pub fn round_duration(&self) -> std::time::Duration {
66 self.slot_duration * self.slot
67 }
68
69 pub fn round_end(&self) -> std::time::Instant {
70 self.start.read().add(self.slot * self.slot_duration)
71 }
72
73 pub async fn dispatch(
74 &mut self,
75 duration: std::time::Duration,
76 data: T,
77 ) -> Result<super::Snapshot, super::Error<T>> {
78 self.dispatch_until(std::time::Instant::now() + duration, data)
79 .await
80 }
81
82 pub async fn dispatch_until(
83 &mut self,
84 end: std::time::Instant,
85 data: T,
86 ) -> Result<super::Snapshot, super::Error<T>> {
87 let now = std::time::Instant::now();
88 if end < now {
89 return Err(super::Error::TimeElapse(Some(data)));
90 }
91 let round_end = self.round_end();
92 if end > round_end {
93 return Err(super::Error::Overflow(Some(data)));
94 }
95 let meta = Meta::new(now, end, data);
96 let snapshot = super::Snapshot {
97 id: meta.id,
98 start: meta.start,
99 end,
100 };
101 let timer_id = meta.id;
102 if let Err(err) = self.inner_tx.send(TimeWheelProto::Add(meta)).await {
103 return Err(match err.0 {
104 TimeWheelProto::Add(meta) => super::Error::Channel(meta.data),
105 _ => panic!("unexpected error"),
107 });
108 }
109 self.timer_map.insert(
110 timer_id,
111 super::Snapshot {
112 id: snapshot.id,
113 start: snapshot.start,
114 end,
115 },
116 );
117 return Ok(snapshot);
118 }
119
120 pub async fn cancel(&mut self, id: u64) -> Result<(), super::Error<T>> {
121 let snapshot = self
122 .timer_map
123 .remove(&id)
124 .ok_or(super::Error::NoRecord(id))?;
125 let slot = find_slot(
126 self.start.read().clone(),
127 self.slot_duration.as_nanos(),
128 snapshot.end,
129 );
130 self.inner_tx
131 .send(TimeWheelProto::Cancel {
132 id,
133 slot_hint: slot as usize,
134 })
135 .await
136 .map_err(|_| super::Error::Channel(None))
137 }
138
139 pub async fn accelerate(
140 &mut self,
141 id: u64,
142 acc_duration: std::time::Duration,
143 ) -> Result<(), super::Error<T>> {
144 let now = std::time::Instant::now();
145 let snapshot = self
147 .timer_map
148 .get_mut(&id)
149 .ok_or(super::Error::NoRecord(id))?;
150 let slot = find_slot(
151 self.start.read().clone(),
152 self.slot_duration.as_nanos(),
153 snapshot.end,
154 );
155 if snapshot.end - acc_duration < now {
157 self.inner_tx
158 .send(TimeWheelProto::Trigger {
159 id,
160 slot_hint: slot as usize,
161 })
162 .await
163 .map_err(|_| super::Error::Channel(None))?;
164 snapshot.end -= acc_duration;
165 return Ok(());
166 }
167 self.inner_tx
169 .send(TimeWheelProto::Accelerate {
170 id,
171 slot_hint: slot as usize,
172 dur: acc_duration,
173 })
174 .await
175 .map_err(|_| super::Error::Channel(None))
176 }
177
178 pub async fn delay(
179 &mut self,
180 id: u64,
181 delay_duration: std::time::Duration,
182 ) -> Result<(), super::Error<T>> {
183 let snapshot = self
184 .timer_map
185 .get_mut(&id)
186 .ok_or(super::Error::NoRecord(id))?;
187 let round_start = self.start.read().clone();
188 if snapshot.end + delay_duration > round_start + self.slot_duration * self.slot {
190 return Err(super::Error::Overflow(None));
191 }
192 let slot = find_slot(round_start, self.slot_duration.as_nanos(), snapshot.end);
193 self.inner_tx
194 .send(TimeWheelProto::Delay {
195 id,
196 slot_hint: slot as usize,
197 dur: delay_duration,
198 })
199 .await
200 .map_err(|_| super::Error::Channel(None))
201 }
202
203 pub async fn trigger(&mut self, id: u64) -> Result<(), super::Error<T>> {
204 let snapshot = self
205 .timer_map
206 .get_mut(&id)
207 .ok_or(super::Error::NoRecord(id))?;
208 let slot = find_slot(
209 self.start.read().clone(),
210 self.slot_duration.as_nanos(),
211 snapshot.end,
212 );
213 self.inner_tx
214 .send(TimeWheelProto::Trigger {
215 id,
216 slot_hint: slot as usize,
217 })
218 .await
219 .map_err(|_| super::Error::Channel(None))
220 }
221
222 pub async fn batch_add(
226 &mut self,
227 metas: Vec<Meta<T>>,
228 ) -> Result<Vec<super::Snapshot>, super::Error<T>> {
229 if metas.len() == 0 {
230 return Ok(Vec::new());
231 }
232
233 let start = self.start.read().clone();
234 let round_end = start.add(self.round_duration());
235 let slot_dur = self.slot_duration.as_nanos();
236 let mut proto = Vec::with_capacity(metas.len());
237 for mut meta in metas {
238 if self.timer_map.contains_key(&meta.id) {
240 return Err(super::Error::DupTimer(meta.id));
241 }
242 if meta.start < start {
244 return Err(super::Error::TimeElapse(meta.data.take()));
245 } else if meta.end > round_end {
246 return Err(super::Error::Overflow(meta.data.take()));
247 } else {
248 let meta_end = meta.end;
250 proto.push((
251 meta,
252 find_slot(start, slot_dur, meta_end) as usize,
254 ));
255 }
256 }
257 proto.iter().for_each(|(meta, _)| {
259 self.timer_map.insert(
260 meta.id,
261 super::Snapshot {
262 id: meta.id,
263 start: meta.start,
264 end: meta.end,
265 },
266 );
267 });
268 let ret = proto
269 .iter()
270 .map(|(meta, _)| super::Snapshot {
271 id: meta.id,
272 start: meta.start,
273 end: meta.end,
274 })
275 .collect();
276 if let Err(err) = self.inner_tx.send(TimeWheelProto::BatchAdd(proto)).await {
277 return Err(match err.0 {
278 TimeWheelProto::BatchAdd(batch) => {
279 super::Error::BatchChannel(batch.into_iter().map(|(meta, _)| meta).collect())
280 }
281 _ => panic!("unexpected error"),
283 });
284 }
285 return Ok(ret);
286 }
287
288 pub async fn tick(&mut self) -> Vec<Meta<T>> {
289 if let Some(metas) = self.tick_rx.recv().await {
290 return metas
291 .into_iter()
292 .filter(|meta| {
293 if self.timer_map.get(&meta.id).is_some() {
294 self.timer_map.remove(&meta.id);
295 return true;
296 }
297 return false;
298 })
299 .collect();
300 }
301 panic!()
302 }
303}
304
305impl<T> Drop for WheelProxy<T>
306where
307 T: Debug + Send,
308{
309 fn drop(&mut self) {
310 self.inner_join.abort();
311 self.ticker_join.abort();
312 }
313}
314
/// Current step of the wheel task's `poll` state machine.
enum InnerState {
    /// Waiting for the next command on the command channel.
    PollRecv,
    /// Delivering a batch of fired timers to the proxy. Command processing is
    /// paused until this future completes.
    SendTick(Pin<Box<dyn Future<Output = ()> + Send>>),
}
319
/// State of the wheel task; this struct is itself the `Future` spawned by `Inner::new`.
#[pin_project]
struct Inner<T: Debug + Send> {
    /// Number of slots; `wq` is created with this many entries.
    pub(crate) slot: u32,
    /// Time covered by each slot.
    pub(crate) slot_duration: std::time::Duration,
    /// The slots. On each tick the front slot is popped (fired) and an empty slot
    /// is pushed at the back, so index 0 is always the next slot to fire.
    pub(crate) wq: VecDeque<VecDeque<Meta<T>>>,
    /// Start of the current round, shared with the proxy; advanced by
    /// `slot_duration` on every tick.
    start: Arc<RwLock<std::time::Instant>>,
    /// Clone of the command sender. Because this task holds it, `rx` never sees
    /// the channel closed while the task is alive.
    tx: mpsc::Sender<TimeWheelProto<T>>,
    /// Incoming commands from the proxy and the ticker task.
    rx: mpsc::Receiver<TimeWheelProto<T>>,
    /// Outgoing batches of fired timers, consumed by the proxy's `tick`.
    tick_tx: mpsc::Sender<VecDeque<Meta<T>>>,
    /// Current step of the poll state machine.
    state: InnerState,
}
331
332impl<T> Inner<T>
333where
334 T: Send + Debug + 'static,
335{
336 fn new(
337 slot: u32,
338 slot_duration: std::time::Duration,
339 start: std::time::Instant,
340 ) -> WheelProxy<T> {
341 let (tx, rx) = mpsc::channel(4);
342 let (tick_tx, tick_rx) = mpsc::channel(64);
343 let mut wq = VecDeque::with_capacity(slot as usize);
344 wq.resize_with(slot as usize, Default::default);
345 let arc_start = Arc::new(RwLock::new(start));
346 let inner = Self {
347 slot,
348 slot_duration: slot_duration.clone(),
349 start: arc_start.clone(),
350 wq,
351 tx: tx.clone(),
352 rx,
353 tick_tx,
354 state: InnerState::PollRecv,
355 };
356 let inner_tx = tx.clone();
357 let ticker_join = tokio::spawn(async move {
358 let mut interval =
359 tokio::time::interval_at(start.clone().into(), inner.slot_duration.clone());
360 let tx = inner_tx.clone();
361 loop {
362 interval.tick().await;
363 tx.send(TimeWheelProto::Tick).await.unwrap();
364 }
365 });
366
367 WheelProxy {
368 tick_rx,
369 inner_tx: tx,
370 ticker_join,
371 inner_join: tokio::spawn(inner),
372 slot,
373 slot_duration,
374 start: arc_start,
375 timer_map: Default::default(),
376 }
377 }
378
379}
380
381impl<T: Debug + Send> Future for Inner<T>
382where
383 T: Send + 'static,
384{
385 type Output = ();
386
387 fn poll(
388 self: std::pin::Pin<&mut Self>,
389 cx: &mut std::task::Context<'_>,
390 ) -> std::task::Poll<Self::Output> {
391 let this = self.project();
392 loop {
393 match this.state {
394 InnerState::PollRecv => {
395 if let Some(proto) = ready!(this.rx.poll_recv(cx)) {
396 match proto {
397 TimeWheelProto::Tick => {
398 this.start.write().add_assign(this.slot_duration.clone());
399 if let Some(metas) = this.wq.pop_front() {
400 this.wq.push_back(Default::default());
401 if metas.len() == 0 {
402 continue;
403 }
404 let tx = this.tick_tx.clone();
405 let mut fut = Box::pin(async move {
406 if let Err(err) = tx.send(metas).await {
407 tracing::error!("TickTx send error: {:?}", err);
408 }
409 });
410 if let Poll::Pending = fut.poll_unpin(cx) {
412 *this.state = InnerState::SendTick(fut);
413 return Poll::Pending;
415 }
416 }
417 }
418 TimeWheelProto::Add(meta) => {
419 let start = this.start.read().clone();
420 let slot_dur = this.slot_duration.as_nanos();
421 let slot = find_slot(start, slot_dur, meta.end);
422 if slot as u32 >= *this.slot {
423 tracing::error!("[Add] overflow timer. {:?}", meta);
424 continue;
425 }
426 tracing::info!("[Add] add timer. {:?}", meta);
427 this.wq.get_mut(slot as usize).unwrap().push_back(meta);
428 }
429 TimeWheelProto::BatchAdd(batch) => {
430 for (meta, slot) in batch {
431 let vec = this.wq.get_mut(slot).unwrap();
432 tracing::info!("[BatchAdd] add timer. {:?}", meta);
433 vec.push_back(meta);
434 }
435 }
436 TimeWheelProto::Cancel { id, slot_hint } => {
437 let vec = this.wq.get_mut(slot_hint).unwrap();
438 tracing::info!("[Cancel] cancel timer {}", id);
439 if let Some((idx, _)) =
440 vec.iter().enumerate().find(|(_, meta)| meta.id == id)
441 {
442 vec.swap_remove_back(idx);
443 } else {
444 tracing::warn!("[Cancel] timer {} not found", id);
445 }
446 }
447 TimeWheelProto::Accelerate { id, slot_hint, dur } => {
448 let vec = this.wq.get_mut(slot_hint).unwrap();
449 if let Some((idx, meta)) =
450 vec.iter_mut().enumerate().find(|(_, meta)| meta.id == id)
451 {
452 let new_slot = find_slot(
453 this.start.read().clone(),
454 this.slot_duration.as_nanos(),
455 meta.end - dur,
456 ) as usize;
457 meta.end -= dur;
458 if new_slot < slot_hint {
460 let meta = vec.remove(idx).unwrap();
461 this.wq.get_mut(new_slot).unwrap().push_back(meta);
462 }
463 } else {
464 tracing::warn!("[Accelerate] timer {} not found", id);
465 }
466 }
467 TimeWheelProto::Delay { id, slot_hint, dur } => {
468 let vec = this.wq.get_mut(slot_hint).unwrap();
469 if let Some((idx, meta)) =
470 vec.iter_mut().enumerate().find(|(_, meta)| meta.id == id)
471 {
472 let new_slot = find_slot(
473 this.start.read().clone(),
474 this.slot_duration.as_nanos(),
475 meta.end + dur,
476 ) as usize;
477 meta.end += dur;
478 if new_slot > slot_hint {
480 let meta = vec.remove(idx).unwrap();
481 this.wq.get_mut(new_slot).unwrap().push_back(meta);
482 }
483 } else {
484 tracing::warn!("[Delay] timer {} not found", id);
485 }
486 }
487 TimeWheelProto::Trigger { id, slot_hint } => {
488 let now = std::time::Instant::now();
489 let vec = this.wq.get_mut(slot_hint).unwrap();
490 tracing::info!("[Trigger] trigger timer {} now", id);
491 if let Some((idx, _)) =
492 vec.iter().enumerate().find(|(_, meta)| meta.id == id)
493 {
494 if let Some(mut meta) = vec.swap_remove_back(idx) {
495 meta.end = now;
496 let mut metas = VecDeque::with_capacity(1);
497 metas.push_back(meta);
498 let tx = this.tick_tx.clone();
499 let mut fut = Box::pin(async move {
500 if let Err(err) = tx.send(metas).await {
501 tracing::error!("TickTx send error: {:?}", err);
502 }
503 });
504 if let Poll::Pending = fut.poll_unpin(cx) {
506 *this.state = InnerState::SendTick(fut);
507 return Poll::Pending;
509 }
510 } else {
511 tracing::warn!("[Trigger] trigger timer {} not found", id);
512 }
513 } else {
514 tracing::warn!("[Cancel] timer {} not found", id);
515 }
516 }
517 }
518 }
519 }
520 InnerState::SendTick(fut) => match fut.poll_unpin(cx) {
521 std::task::Poll::Ready(_) => *this.state = InnerState::PollRecv,
522 std::task::Poll::Pending => return Poll::Pending,
523 },
524 }
525 }
526 }
527}
528
/// Maps a deadline to a wheel slot index relative to `wheel_start`.
///
/// A deadline `diff` nanoseconds after `wheel_start` lands in slot
/// `ceil(diff / slot_dur_ns) - 1`. In other words, slot `k` covers the half-open
/// range `(wheel_start + k*slot, wheel_start + (k+1)*slot]`. Deadlines at or
/// before `wheel_start` map to slot 0 instead of underflowing.
///
/// The result is not bounded by the wheel size; callers must range-check it.
///
/// # Panics
/// Panics if `slot_dur_ns` is zero.
#[inline]
pub(crate) fn find_slot(
    wheel_start: std::time::Instant,
    slot_dur_ns: u128,
    timer_end: std::time::Instant,
) -> u128 {
    let diff = timer_end.saturating_duration_since(wheel_start).as_nanos();
    // For diff > 0, (diff - 1) / d == ceil(diff / d) - 1; a zero diff yields slot 0.
    diff.saturating_sub(1) / slot_dur_ns
}