// gsfw_util/timer/wheel.rs

1use super::{timer::Meta, tw_proto::TimeWheelProto};
2use futures::{ready, Future, FutureExt};
3use parking_lot::RwLock;
4use pin_project::pin_project;
5use std::{
6    collections::{HashMap, VecDeque},
7    fmt::Debug,
8    ops::{Add, AddAssign},
9    pin::Pin,
10    sync::Arc,
11    task::Poll,
12};
13use tokio::sync::mpsc;
14
15///
16/// # Time Wheel Proxy
17/// this provide user with easy-use functions to interact with TimeWheel
18/// # example
19/// ```rust
20/// use gsfw_util::timer::Wheel;
21/// #[tokio::main]
22/// async fn main() {
23///     let now = std::time::Instant::now();
24///     let mut wheel = Wheel::<i32>::new(60, std::time::Duration::from_secs(1), now);
25///     let snapshot = wheel.dispatch(std::time::Duration::from_secs(2), 1).await.unwrap();
26///     let recv_data = wheel.tick().await;
27///     assert_eq!(recv_data.len(), 1);
28///     assert_eq!(recv_data[0].id, snapshot.id);
29///     assert_eq!(recv_data[0].data, Some(1));
30///     assert_eq!(recv_data[0].start, snapshot.start);
31///     assert_eq!(recv_data[0].end, snapshot.end);
32/// }
33/// ```
34pub struct WheelProxy<T>
35where
36    T: Debug + Send,
37{
38    slot: u32,
39    slot_duration: std::time::Duration,
40    start: Arc<RwLock<std::time::Instant>>,
41    tick_rx: mpsc::Receiver<VecDeque<Meta<T>>>,
42    inner_tx: mpsc::Sender<TimeWheelProto<T>>,
43    timer_map: HashMap<u64, super::Snapshot>,
44
45    ticker_join: tokio::task::JoinHandle<()>,
46    inner_join: tokio::task::JoinHandle<()>,
47}
48
49impl<T> WheelProxy<T>
50where
51    T: Send + Debug + 'static,
52{
53    pub fn new(slot: u32, slot_duration: std::time::Duration, start: std::time::Instant) -> Self {
54        Inner::new(slot, slot_duration, start)
55    }
56
57    pub fn slot(&self) -> u32 {
58        self.slot
59    }
60
61    pub fn slot_duration(&self) -> std::time::Duration {
62        self.slot_duration.clone()
63    }
64
65    pub fn round_duration(&self) -> std::time::Duration {
66        self.slot_duration * self.slot
67    }
68
69    pub fn round_end(&self) -> std::time::Instant {
70        self.start.read().add(self.slot * self.slot_duration)
71    }
72
73    pub async fn dispatch(
74        &mut self,
75        duration: std::time::Duration,
76        data: T,
77    ) -> Result<super::Snapshot, super::Error<T>> {
78        self.dispatch_until(std::time::Instant::now() + duration, data)
79            .await
80    }
81
82    pub async fn dispatch_until(
83        &mut self,
84        end: std::time::Instant,
85        data: T,
86    ) -> Result<super::Snapshot, super::Error<T>> {
87        let now = std::time::Instant::now();
88        if end < now {
89            return Err(super::Error::TimeElapse(Some(data)));
90        }
91        let round_end = self.round_end();
92        if end > round_end {
93            return Err(super::Error::Overflow(Some(data)));
94        }
95        let meta = Meta::new(now, end, data);
96        let snapshot = super::Snapshot {
97            id: meta.id,
98            start: meta.start,
99            end,
100        };
101        let timer_id = meta.id;
102        if let Err(err) = self.inner_tx.send(TimeWheelProto::Add(meta)).await {
103            return Err(match err.0 {
104                TimeWheelProto::Add(meta) => super::Error::Channel(meta.data),
105                // this shall never happen
106                _ => panic!("unexpected error"),
107            });
108        }
109        self.timer_map.insert(
110            timer_id,
111            super::Snapshot {
112                id: snapshot.id,
113                start: snapshot.start,
114                end,
115            },
116        );
117        return Ok(snapshot);
118    }
119
120    pub async fn cancel(&mut self, id: u64) -> Result<(), super::Error<T>> {
121        let snapshot = self
122            .timer_map
123            .remove(&id)
124            .ok_or(super::Error::NoRecord(id))?;
125        let slot = find_slot(
126            self.start.read().clone(),
127            self.slot_duration.as_nanos(),
128            snapshot.end,
129        );
130        self.inner_tx
131            .send(TimeWheelProto::Cancel {
132                id,
133                slot_hint: slot as usize,
134            })
135            .await
136            .map_err(|_| super::Error::Channel(None))
137    }
138
139    pub async fn accelerate(
140        &mut self,
141        id: u64,
142        acc_duration: std::time::Duration,
143    ) -> Result<(), super::Error<T>> {
144        let now = std::time::Instant::now();
145        // check timer exist
146        let snapshot = self
147            .timer_map
148            .get_mut(&id)
149            .ok_or(super::Error::NoRecord(id))?;
150        let slot = find_slot(
151            self.start.read().clone(),
152            self.slot_duration.as_nanos(),
153            snapshot.end,
154        );
155        // trigger now -> Trigger
156        if snapshot.end - acc_duration < now {
157            self.inner_tx
158                .send(TimeWheelProto::Trigger {
159                    id,
160                    slot_hint: slot as usize,
161                })
162                .await
163                .map_err(|_| super::Error::Channel(None))?;
164            snapshot.end -= acc_duration;
165            return Ok(());
166        }
167        // future trigger -> Accelerate
168        self.inner_tx
169            .send(TimeWheelProto::Accelerate {
170                id,
171                slot_hint: slot as usize,
172                dur: acc_duration,
173            })
174            .await
175            .map_err(|_| super::Error::Channel(None))
176    }
177
178    pub async fn delay(
179        &mut self,
180        id: u64,
181        delay_duration: std::time::Duration,
182    ) -> Result<(), super::Error<T>> {
183        let snapshot = self
184            .timer_map
185            .get_mut(&id)
186            .ok_or(super::Error::NoRecord(id))?;
187        let round_start = self.start.read().clone();
188        // check overflow
189        if snapshot.end + delay_duration > round_start + self.slot_duration * self.slot {
190            return Err(super::Error::Overflow(None));
191        }
192        let slot = find_slot(round_start, self.slot_duration.as_nanos(), snapshot.end);
193        self.inner_tx
194            .send(TimeWheelProto::Delay {
195                id,
196                slot_hint: slot as usize,
197                dur: delay_duration,
198            })
199            .await
200            .map_err(|_| super::Error::Channel(None))
201    }
202
203    pub async fn trigger(&mut self, id: u64) -> Result<(), super::Error<T>> {
204        let snapshot = self
205            .timer_map
206            .get_mut(&id)
207            .ok_or(super::Error::NoRecord(id))?;
208        let slot = find_slot(
209            self.start.read().clone(),
210            self.slot_duration.as_nanos(),
211            snapshot.end,
212        );
213        self.inner_tx
214            .send(TimeWheelProto::Trigger {
215                id,
216                slot_hint: slot as usize,
217            })
218            .await
219            .map_err(|_| super::Error::Channel(None))
220    }
221
222    /// batch_add atomic operation. either all metas are added to the wheel, nor none is added.
223    /// batch_add will first check metas's id is unique and
224    /// then check all metas are not overflow or elapse
225    pub async fn batch_add(
226        &mut self,
227        metas: Vec<Meta<T>>,
228    ) -> Result<Vec<super::Snapshot>, super::Error<T>> {
229        if metas.len() == 0 {
230            return Ok(Vec::new());
231        }
232
233        let start = self.start.read().clone();
234        let round_end = start.add(self.round_duration());
235        let slot_dur = self.slot_duration.as_nanos();
236        let mut proto = Vec::with_capacity(metas.len());
237        for mut meta in metas {
238            // check id unique
239            if self.timer_map.contains_key(&meta.id) {
240                return Err(super::Error::DupTimer(meta.id));
241            }
242            // ensure not overflow and elapse
243            if meta.start < start {
244                return Err(super::Error::TimeElapse(meta.data.take()));
245            } else if meta.end > round_end {
246                return Err(super::Error::Overflow(meta.data.take()));
247            } else {
248                // let diff = (meta.end - start).as_nanos();
249                let meta_end = meta.end;
250                proto.push((
251                    meta,
252                    // (diff / slot_dur - if diff % slot_dur != 0 { 0 } else { 1 }) as usize,
253                    find_slot(start, slot_dur, meta_end) as usize,
254                ));
255            }
256        }
257        // insert snapshot of this batch timers
258        proto.iter().for_each(|(meta, _)| {
259            self.timer_map.insert(
260                meta.id,
261                super::Snapshot {
262                    id: meta.id,
263                    start: meta.start,
264                    end: meta.end,
265                },
266            );
267        });
268        let ret = proto
269            .iter()
270            .map(|(meta, _)| super::Snapshot {
271                id: meta.id,
272                start: meta.start,
273                end: meta.end,
274            })
275            .collect();
276        if let Err(err) = self.inner_tx.send(TimeWheelProto::BatchAdd(proto)).await {
277            return Err(match err.0 {
278                TimeWheelProto::BatchAdd(batch) => {
279                    super::Error::BatchChannel(batch.into_iter().map(|(meta, _)| meta).collect())
280                }
281                // this shall never happen
282                _ => panic!("unexpected error"),
283            });
284        }
285        return Ok(ret);
286    }
287
288    pub async fn tick(&mut self) -> Vec<Meta<T>> {
289        if let Some(metas) = self.tick_rx.recv().await {
290            return metas
291                .into_iter()
292                .filter(|meta| {
293                    if self.timer_map.get(&meta.id).is_some() {
294                        self.timer_map.remove(&meta.id);
295                        return true;
296                    }
297                    return false;
298                })
299                .collect();
300        }
301        panic!()
302    }
303}
304
305impl<T> Drop for WheelProxy<T>
306where
307    T: Debug + Send,
308{
309    fn drop(&mut self) {
310        self.inner_join.abort();
311        self.ticker_join.abort();
312    }
313}
314
/// Progress of the wheel future's poll loop.
enum InnerState {
    /// Waiting for the next command on the command channel.
    PollRecv,
    /// Delivering a batch of expired timers to the proxy; no command is read until this
    /// send completes.
    SendTick(Pin<Box<dyn Future<Output = ()> + Send + 'static>>),
}
319
320#[pin_project]
321struct Inner<T: Debug + Send> {
322    pub(crate) slot: u32,
323    pub(crate) slot_duration: std::time::Duration,
324    pub(crate) wq: VecDeque<VecDeque<Meta<T>>>,
325    start: Arc<RwLock<std::time::Instant>>,
326    tx: mpsc::Sender<TimeWheelProto<T>>,
327    rx: mpsc::Receiver<TimeWheelProto<T>>,
328    tick_tx: mpsc::Sender<VecDeque<Meta<T>>>,
329    state: InnerState,
330}
331
332impl<T> Inner<T>
333where
334    T: Send + Debug + 'static,
335{
336    fn new(
337        slot: u32,
338        slot_duration: std::time::Duration,
339        start: std::time::Instant,
340    ) -> WheelProxy<T> {
341        let (tx, rx) = mpsc::channel(4);
342        let (tick_tx, tick_rx) = mpsc::channel(64);
343        let mut wq = VecDeque::with_capacity(slot as usize);
344        wq.resize_with(slot as usize, Default::default);
345        let arc_start = Arc::new(RwLock::new(start));
346        let inner = Self {
347            slot,
348            slot_duration: slot_duration.clone(),
349            start: arc_start.clone(),
350            wq,
351            tx: tx.clone(),
352            rx,
353            tick_tx,
354            state: InnerState::PollRecv,
355        };
356        let inner_tx = tx.clone();
357        let ticker_join = tokio::spawn(async move {
358            let mut interval =
359                tokio::time::interval_at(start.clone().into(), inner.slot_duration.clone());
360            let tx = inner_tx.clone();
361            loop {
362                interval.tick().await;
363                tx.send(TimeWheelProto::Tick).await.unwrap();
364            }
365        });
366
367        WheelProxy {
368            tick_rx,
369            inner_tx: tx,
370            ticker_join,
371            inner_join: tokio::spawn(inner),
372            slot,
373            slot_duration,
374            start: arc_start,
375            timer_map: Default::default(),
376        }
377    }
378
379}
380
381impl<T: Debug + Send> Future for Inner<T>
382where
383    T: Send + 'static,
384{
385    type Output = ();
386
387    fn poll(
388        self: std::pin::Pin<&mut Self>,
389        cx: &mut std::task::Context<'_>,
390    ) -> std::task::Poll<Self::Output> {
391        let this = self.project();
392        loop {
393            match this.state {
394                InnerState::PollRecv => {
395                    if let Some(proto) = ready!(this.rx.poll_recv(cx)) {
396                        match proto {
397                            TimeWheelProto::Tick => {
398                                this.start.write().add_assign(this.slot_duration.clone());
399                                if let Some(metas) = this.wq.pop_front() {
400                                    this.wq.push_back(Default::default());
401                                    if metas.len() == 0 {
402                                        continue;
403                                    }
404                                    let tx = this.tick_tx.clone();
405                                    let mut fut = Box::pin(async move {
406                                        if let Err(err) = tx.send(metas).await {
407                                            tracing::error!("TickTx send error: {:?}", err);
408                                        }
409                                    });
410                                    // try poll immediately
411                                    if let Poll::Pending = fut.poll_unpin(cx) {
412                                        *this.state = InnerState::SendTick(fut);
413                                        // yield
414                                        return Poll::Pending;
415                                    }
416                                }
417                            }
418                            TimeWheelProto::Add(meta) => {
419                                let start = this.start.read().clone();
420                                let slot_dur = this.slot_duration.as_nanos();
421                                let slot = find_slot(start, slot_dur, meta.end);
422                                if slot as u32 >= *this.slot {
423                                    tracing::error!("[Add] overflow timer. {:?}", meta);
424                                    continue;
425                                }
426                                tracing::info!("[Add] add timer. {:?}", meta);
427                                this.wq.get_mut(slot as usize).unwrap().push_back(meta);
428                            }
429                            TimeWheelProto::BatchAdd(batch) => {
430                                for (meta, slot) in batch {
431                                    let vec = this.wq.get_mut(slot).unwrap();
432                                    tracing::info!("[BatchAdd] add timer. {:?}", meta);
433                                    vec.push_back(meta);
434                                }
435                            }
436                            TimeWheelProto::Cancel { id, slot_hint } => {
437                                let vec = this.wq.get_mut(slot_hint).unwrap();
438                                tracing::info!("[Cancel] cancel timer {}", id);
439                                if let Some((idx, _)) =
440                                    vec.iter().enumerate().find(|(_, meta)| meta.id == id)
441                                {
442                                    vec.swap_remove_back(idx);
443                                } else {
444                                    tracing::warn!("[Cancel] timer {} not found", id);
445                                }
446                            }
447                            TimeWheelProto::Accelerate { id, slot_hint, dur } => {
448                                let vec = this.wq.get_mut(slot_hint).unwrap();
449                                if let Some((idx, meta)) =
450                                    vec.iter_mut().enumerate().find(|(_, meta)| meta.id == id)
451                                {
452                                    let new_slot = find_slot(
453                                        this.start.read().clone(),
454                                        this.slot_duration.as_nanos(),
455                                        meta.end - dur,
456                                    ) as usize;
457                                    meta.end -= dur;
458                                    // before current slot
459                                    if new_slot < slot_hint {
460                                        let meta = vec.remove(idx).unwrap();
461                                        this.wq.get_mut(new_slot).unwrap().push_back(meta);
462                                    }
463                                } else {
464                                    tracing::warn!("[Accelerate] timer {} not found", id);
465                                }
466                            }
467                            TimeWheelProto::Delay { id, slot_hint, dur } => {
468                                let vec = this.wq.get_mut(slot_hint).unwrap();
469                                if let Some((idx, meta)) =
470                                    vec.iter_mut().enumerate().find(|(_, meta)| meta.id == id)
471                                {
472                                    let new_slot = find_slot(
473                                        this.start.read().clone(),
474                                        this.slot_duration.as_nanos(),
475                                        meta.end + dur,
476                                    ) as usize;
477                                    meta.end += dur;
478                                    // after current slot
479                                    if new_slot > slot_hint {
480                                        let meta = vec.remove(idx).unwrap();
481                                        this.wq.get_mut(new_slot).unwrap().push_back(meta);
482                                    }
483                                } else {
484                                    tracing::warn!("[Delay] timer {} not found", id);
485                                }
486                            }
487                            TimeWheelProto::Trigger { id, slot_hint } => {
488                                let now = std::time::Instant::now();
489                                let vec = this.wq.get_mut(slot_hint).unwrap();
490                                tracing::info!("[Trigger] trigger timer {} now", id);
491                                if let Some((idx, _)) =
492                                    vec.iter().enumerate().find(|(_, meta)| meta.id == id)
493                                {
494                                    if let Some(mut meta) = vec.swap_remove_back(idx) {
495                                        meta.end = now;
496                                        let mut metas = VecDeque::with_capacity(1);
497                                        metas.push_back(meta);
498                                        let tx = this.tick_tx.clone();
499                                        let mut fut = Box::pin(async move {
500                                            if let Err(err) = tx.send(metas).await {
501                                                tracing::error!("TickTx send error: {:?}", err);
502                                            }
503                                        });
504                                        // try poll immediately
505                                        if let Poll::Pending = fut.poll_unpin(cx) {
506                                            *this.state = InnerState::SendTick(fut);
507                                            // yield
508                                            return Poll::Pending;
509                                        }
510                                    } else {
511                                        tracing::warn!("[Trigger] trigger timer {} not found", id);
512                                    }
513                                } else {
514                                    tracing::warn!("[Cancel] timer {} not found", id);
515                                }
516                            }
517                        }
518                    }
519                }
520                InnerState::SendTick(fut) => match fut.poll_unpin(cx) {
521                    std::task::Poll::Ready(_) => *this.state = InnerState::PollRecv,
522                    std::task::Poll::Pending => return Poll::Pending,
523                },
524            }
525        }
526    }
527}
528
/// Index of the slot that a timer ending at `timer_end` belongs to, relative to a round
/// starting at `wheel_start` with slots of `slot_dur_ns` nanoseconds each.
///
/// With `d = slot_dur_ns`, slot `i` holds timers whose end lies in
/// `(wheel_start + i*d, wheel_start + (i+1)*d]`. For `diff = timer_end - wheel_start > 0`
/// the result is `ceil(diff / d) - 1`.
///
/// A timer ending at or before `wheel_start` maps to slot `0` instead of underflowing. This
/// happens when a tick advanced the start after the timer was scheduled. The result is not
/// range-checked against the wheel size; callers must do that.
///
/// # Panics
/// Panics if `slot_dur_ns` is zero (division by zero).
#[inline]
pub(crate) fn find_slot(
    wheel_start: std::time::Instant,
    slot_dur_ns: u128,
    timer_end: std::time::Instant,
) -> u128 {
    let diff = timer_end.saturating_duration_since(wheel_start).as_nanos();
    // For diff >= 1, (diff - 1) / d == ceil(diff / d) - 1; saturating keeps diff == 0 at slot 0.
    diff.saturating_sub(1) / slot_dur_ns
}
535    let diff = (timer_end - wheel_start).as_nanos();
536    diff / slot_dur_ns - if diff % slot_dur_ns != 0 { 0 } else { 1 }
537}