rex/
timeout.rs

1#![allow(dead_code)]
2
3use std::{
4    collections::{BTreeMap, HashMap, HashSet},
5    fmt,
6    iter::IntoIterator,
7    sync::Arc,
8    time::Duration,
9};
10
11use bigerror::{attachment::DisplayDuration, ConversionError, Report};
12use parking_lot::Mutex;
13use tokio::{
14    sync::{mpsc, mpsc::UnboundedSender},
15    task::JoinSet,
16    time::Instant,
17};
18use tracing::{debug, error, instrument, warn, Instrument};
19
20use crate::{
21    manager::{HashKind, Signal, SignalQueue},
22    notification::{Notification, NotificationProcessor, RexMessage, UnaryRequest},
23    Kind, Rex, StateId,
24};
25
/// Default interval at which the [`TimeoutManager`] polls its ledger for due entries.
pub const DEFAULT_TICK_RATE: Duration = Duration::from_millis(5);
/// Requested timeouts this close to "now" (or closer) are logged at `warn`
/// level instead of `debug` when they are registered.
const SHORT_TIMEOUT: Duration = Duration::from_millis(10_000);
28
/// Render a [`Duration`] as a compact human-readable string.
///
/// * A zero duration renders as `"ZERO"`.
/// * Anything shorter than one second renders as whole milliseconds, e.g. `"250ms"`.
/// * Otherwise, hours and minutes appear only when non-zero. Every component is
///   zero-padded to two digits, e.g. `"01H02m03s"`, `"05m00s"` or `"42s"`.
///   Sub-second precision is dropped in this case.
fn hms_string(duration: Duration) -> String {
    if duration.is_zero() {
        return "ZERO".to_string();
    }

    let total_secs = duration.as_secs();
    if total_secs == 0 {
        return format!("{}ms", duration.subsec_millis());
    }

    let hours = total_secs / 3600;
    let minutes = total_secs % 3600 / 60;
    let seconds = total_secs % 60;

    let mut rendered = String::new();
    if hours > 0 {
        rendered.push_str(&format!("{hours:02}H"));
    }
    if minutes > 0 {
        rendered.push_str(&format!("{minutes:02}m"));
    }
    rendered.push_str(&format!("{seconds:02}s"));
    rendered
}
55
56/// `TimeoutLedger` contains a [`BTreeMap`] that uses [`Instant`]s to time out
57/// specific [`StateId`]s and a [`HashMap`] that indexes `Instant`s by [`StateId`].
58///
59/// This double indexing allows [`Operation::Cancel`]s to go
60/// through without having to provide an `Instant`.
61#[derive(Debug)]
62struct TimeoutLedger<K>
63where
64    K: Kind + Rex,
65    K::Message: TimeoutMessage<K>,
66{
67    timers: BTreeMap<Instant, HashSet<StateId<K>>>,
68    ids: HashMap<StateId<K>, Instant>,
69    retainer: BTreeMap<Instant, Vec<RetainPair<K>>>,
70}
71type RetainPair<K> = (StateId<K>, RetainItem<K>);
72
73impl<K> TimeoutLedger<K>
74where
75    K: Rex + HashKind + Copy,
76    K::Message: TimeoutMessage<K>,
77{
78    fn new() -> Self {
79        Self {
80            timers: BTreeMap::new(),
81            ids: HashMap::new(),
82            retainer: BTreeMap::new(),
83        }
84    }
85
86    fn lint_instant(instant: Instant) {
87        let now = Instant::now();
88        if instant < now {
89            error!("requested timeout is in the past");
90        }
91        let duration = instant - now;
92        if duration <= SHORT_TIMEOUT {
93            warn!(duration = %DisplayDuration(instant - now), "setting short timeout");
94        } else {
95            debug!(duration = %DisplayDuration(instant - now), "setting timeout");
96        }
97    }
98
99    #[instrument(skip_all, fields(%id))]
100    fn retain(&mut self, id: StateId<K>, instant: Instant, item: RetainItem<K>) {
101        Self::lint_instant(instant);
102        self.retainer.entry(instant).or_default().push((id, item));
103    }
104
105    // set timeout for a given instant and associate it with a given id
106    // remove old instants associated with the same id if they exist
107    #[instrument(skip_all, fields(%id))]
108    fn set_timeout(&mut self, id: StateId<K>, instant: Instant) {
109        Self::lint_instant(instant);
110
111        if let Some(old_instant) = self.ids.insert(id, instant) {
112            // remove older reference to id
113            // if instants differ
114            if old_instant != instant {
115                debug!(%id, "renewing timeout");
116                self.timers.get_mut(&old_instant).map(|set| set.remove(&id));
117            }
118        }
119
120        self.timers
121            .entry(instant)
122            .and_modify(|set| {
123                set.insert(id);
124            })
125            .or_default()
126            .insert(id);
127    }
128
129    // remove existing timeout by id, this should remove
130    // one entry in `self.ids` and one entry in `self.timers[id_instant]`
131    fn cancel_timeout(&mut self, id: StateId<K>) {
132        if let Some(instant) = self.ids.remove(&id) {
133            // remove reference to id
134            // from associated instant
135            let removed_id = self.timers.get_mut(&instant).map(|set| set.remove(&id));
136            // if
137            //   `instant` is missing from `self.timers`
138            // or
139            //   `id` is missing from `self.timers[instant]`:
140            //   warn
141            if matches!(removed_id, None | Some(false)) {
142                warn!("timers[{instant:?}][{id}] not found, cancellation ignored");
143            } else {
144                debug!(%id, "cancelled timeout");
145            }
146        }
147    }
148}
149
150pub trait TimeoutMessage<K: Rex>:
151    std::fmt::Debug
152    + RexMessage
153    + From<UnaryRequest<K, Operation<Self::Item>>>
154    + TryInto<UnaryRequest<K, Operation<Self::Item>>, Error = Report<ConversionError>>
155{
156    type Item: Copy + Send + std::fmt::Debug;
157}
158
159pub trait Timeout: Rex
160where
161    Self::Message: TimeoutMessage<Self>,
162{
163    fn return_item(&self, _item: RetainItem<Self>) -> Option<Self::Input> {
164        None
165    }
166}
167
168#[derive(Copy, Clone, Debug, derive_more::Display)]
169pub struct NoRetain;
170
171#[derive(Copy, Clone, Debug)]
172pub enum Operation<T> {
173    Cancel,
174    Set(Instant),
175    Retain(T, Instant),
176}
177
178impl<T> std::fmt::Display for Operation<T> {
179    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180        let op = match self {
181            Self::Cancel => "timeout::Cancel",
182            Self::Set(_) => "timeout::Set",
183            Self::Retain(_, _) => "timeout::Retain",
184        };
185        write!(f, "{op}")
186    }
187}
188
189impl<T> Operation<T> {
190    #[must_use]
191    pub fn from_duration(duration: Duration) -> Self {
192        Self::Set(Instant::now() + duration)
193    }
194
195    #[must_use]
196    pub fn from_millis(millis: u64) -> Self {
197        Self::Set(Instant::now() + Duration::from_millis(millis))
198    }
199}
200
201pub type TimeoutInput<K> = UnaryRequest<K, TimeoutOp<K>>;
202pub type TimeoutOp<K> = Operation<<<K as Rex>::Message as TimeoutMessage<K>>::Item>;
203pub type RetainItem<K> = <<K as Rex>::Message as TimeoutMessage<K>>::Item;
204
205impl<K> UnaryRequest<K, TimeoutOp<K>>
206where
207    K: Rex,
208    K::Message: TimeoutMessage<K>,
209{
210    #[cfg(test)]
211    pub(crate) fn set_timeout_millis(id: StateId<K>, millis: u64) -> Self {
212        Self {
213            id,
214            op: Operation::from_millis(millis),
215        }
216    }
217
218    pub fn set_timeout(id: StateId<K>, duration: Duration) -> Self {
219        Self {
220            id,
221            op: Operation::from_duration(duration),
222        }
223    }
224
225    pub const fn cancel_timeout(id: StateId<K>) -> Self {
226        Self {
227            id,
228            op: Operation::Cancel,
229        }
230    }
231
232    pub fn retain(id: StateId<K>, item: RetainItem<K>, duration: Duration) -> Self {
233        Self {
234            id,
235            op: Operation::Retain(item, Instant::now() + duration),
236        }
237    }
238
239    #[cfg(test)]
240    const fn with_id(&self, id: StateId<K>) -> Self {
241        Self { id, ..*self }
242    }
243    #[cfg(test)]
244    const fn with_op(&self, op: TimeoutOp<K>) -> Self {
245        Self { op, ..*self }
246    }
247}
248
249/// Processes incoming [`Operation`]s and modifies the [`TimeoutLedger`]
250/// through a polling loop.
251pub struct TimeoutManager<K>
252where
253    K: Rex,
254    K::Message: TimeoutMessage<K>,
255{
256    // the interval at which  the TimeoutLedger checks for timeouts
257    tick_rate: Duration,
258    ledger: Arc<Mutex<TimeoutLedger<K>>>,
259    topic: <K::Message as RexMessage>::Topic,
260
261    pub(crate) signal_queue: SignalQueue<K>,
262}
263
264impl<K> TimeoutManager<K>
265where
266    K: Rex + Timeout,
267    K::Message: TimeoutMessage<K>,
268{
269    #[must_use]
270    pub fn new(
271        signal_queue: SignalQueue<K>,
272        topic: impl Into<<K::Message as RexMessage>::Topic>,
273    ) -> Self {
274        Self {
275            tick_rate: DEFAULT_TICK_RATE,
276            signal_queue,
277            ledger: Arc::new(Mutex::new(TimeoutLedger::new())),
278            topic: topic.into(),
279        }
280    }
281
282    #[must_use]
283    pub fn with_tick_rate(self, tick_rate: Duration) -> Self {
284        Self { tick_rate, ..self }
285    }
286
287    pub fn init_inner(&self) -> UnboundedSender<Notification<K::Message>> {
288        let mut join_set = JoinSet::new();
289        let tx = self.init_inner_with_handle(&mut join_set);
290        join_set.detach_all();
291        tx
292    }
293
294    pub fn init_inner_with_handle(
295        &self,
296        join_set: &mut JoinSet<()>,
297    ) -> UnboundedSender<Notification<K::Message>> {
298        let (input_tx, mut input_rx) = mpsc::unbounded_channel::<Notification<K::Message>>();
299        let in_ledger = self.ledger.clone();
300
301        join_set.spawn(
302            async move {
303                debug!(target: "state_machine", spawning = "TimeoutManager.notification_tx");
304                while let Some(Notification(msg)) = input_rx.recv().await {
305                    match msg.try_into() {
306                        Ok(UnaryRequest { id, op }) => {
307                            let mut ledger = in_ledger.lock();
308                            match op {
309                                Operation::Cancel => {
310                                    ledger.cancel_timeout(id);
311                                }
312                                Operation::Set(instant) => {
313                                    ledger.set_timeout(id, instant);
314                                }
315                                Operation::Retain(item, instant) => {
316                                    ledger.retain(id, instant, item);
317                                }
318                            }
319                        }
320                        Err(_e) => {
321                            warn!("Invalid input");
322                            continue;
323                        }
324                    }
325                }
326            }
327            .in_current_span(),
328        );
329
330        let timer_ledger = self.ledger.clone();
331        let mut interval = tokio::time::interval(self.tick_rate);
332        let signal_queue = self.signal_queue.clone();
333        join_set.spawn(
334            async move {
335                loop {
336                    interval.tick().await;
337
338                    let now = Instant::now();
339                    let mut ledger = timer_ledger.lock();
340                    // Get all instants where `instant <= now`
341                    let mut release = ledger.timers.split_off(&now);
342                    std::mem::swap(&mut release, &mut ledger.timers);
343
344                    for id in release.into_values().flat_map(IntoIterator::into_iter) {
345                        warn!(%id, "timed out");
346                        ledger.ids.remove(&id);
347                        if let Some(input) = id.timeout_input(now) {
348                            // caveat with this push_front setup is
349                            // that later timeouts will be on top of the stack
350                            signal_queue.push_front(Signal { id, input });
351                        } else {
352                            warn!(%id, "timeout not supported!");
353                        }
354                    }
355
356                    let mut release = ledger.retainer.split_off(&now);
357                    std::mem::swap(&mut release, &mut ledger.retainer);
358                    drop(ledger);
359                    for (id, item) in release.into_values().flat_map(IntoIterator::into_iter) {
360                        if let Some(input) = id.return_item(item) {
361                            // caveat with this push_front setup is
362                            // that later timeouts will be on top of the stack
363                            signal_queue.push_front(Signal { id, input });
364                        } else {
365                            warn!(%id, "timeout not supported!");
366                        }
367                    }
368                }
369            }
370            .in_current_span(),
371        );
372
373        input_tx
374    }
375}
376
377impl<K> NotificationProcessor<K::Message> for TimeoutManager<K>
378where
379    K: Rex + Timeout,
380    K::Message: TimeoutMessage<K>,
381{
382    fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
383        self.init_inner_with_handle(join_set)
384    }
385
386    fn get_topics(&self) -> &[<K::Message as RexMessage>::Topic] {
387        std::slice::from_ref(&self.topic)
388    }
389}
390
/// Topic marker type available to unit tests.
#[cfg(test)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TimeoutTopic;

/// Tick rate used by the `TimeoutManager`s built in tests.
#[cfg(test)]
pub(crate) const TEST_TICK_RATE: Duration = Duration::from_millis(3);

/// A short timeout duration (11ms) for tests.
#[cfg(test)]
pub(crate) const TEST_TIMEOUT: Duration = Duration::from_millis(11);
400
#[cfg(test)]
mod tests {

    use super::*;
    use crate::test_support::*;

    impl TestDefault for TimeoutManager<TestKind> {
        fn test_default() -> Self {
            let signal_queue = SignalQueue::default();
            Self::new(signal_queue, TestTopic::Timeout).with_tick_rate(TEST_TICK_RATE)
        }
    }

    /// A single timeout surfaces as a `TestInput::Timeout` signal for the same
    /// id, stamped no earlier than the requested deadline.
    #[tokio::test]
    async fn timeout_to_signal() {
        let mut timeout_manager = TimeoutManager::test_default();

        let mut join_set = JoinSet::new();
        let timeout_tx: UnboundedSender<Notification<TestMsg>> =
            timeout_manager.init(&mut join_set);

        let test_id = StateId::new_rand(TestKind);
        let timeout_duration = Duration::from_millis(5);

        let timeout = Instant::now() + timeout_duration;
        let set_timeout = UnaryRequest::set_timeout(test_id, timeout_duration);

        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(set_timeout)))
            .unwrap();

        // sleep for three times the timeout so several ticks run past the deadline
        tokio::time::sleep(timeout_duration * 3).await;

        let Signal { id, input } = timeout_manager.signal_queue.pop_front().unwrap();
        assert_eq!(test_id, id);

        let TestInput::Timeout(signal_timeout) = input else {
            panic!("{input:?}");
        };
        assert!(
            signal_timeout >= timeout,
            "out[{signal_timeout:?}] >= in[{timeout:?}]"
        );
    }

    /// A timeout cancelled before its deadline never produces a signal.
    #[tokio::test]
    async fn timeout_cancellation() {
        let mut timeout_manager = TimeoutManager::test_default();

        let mut join_set = JoinSet::new();
        let timeout_tx: UnboundedSender<Notification<TestMsg>> =
            timeout_manager.init(&mut join_set);

        let test_id = StateId::new_rand(TestKind);
        let set_timeout = UnaryRequest::set_timeout_millis(test_id, 10);

        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(set_timeout)))
            .unwrap();

        tokio::time::sleep(Duration::from_millis(2)).await;
        let cancel_timeout = UnaryRequest {
            id: test_id,
            op: Operation::Cancel,
        };
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(cancel_timeout)))
            .unwrap();

        // sleep 3ms + 3 ticks (12ms): past the ~8ms left of the original 10ms
        // deadline, with some tick slack, so an uncancelled timer would have fired
        tokio::time::sleep(Duration::from_millis(3) + TEST_TICK_RATE * 3).await;

        // we should not be getting any signal since the timeout was cancelled
        assert!(timeout_manager.signal_queue.pop_front().is_none());
    }

    // Of three timers, id2 is cancelled and id3 is renewed to an earlier
    // deadline, so only id3 and id1 produce signals.
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn partial_timeout_cancellation() {
        let mut timeout_manager = TimeoutManager::test_default();

        let mut join_set = JoinSet::new();
        let timeout_tx: UnboundedSender<Notification<TestMsg>> =
            timeout_manager.init(&mut join_set);

        let id1 = StateId::new_with_u128(TestKind, 1);
        let id2 = StateId::new_with_u128(TestKind, 2); // gets cancelled
        let id3 = StateId::new_with_u128(TestKind, 3); // gets overridden with earlier timeout

        let timeout_duration = Duration::from_millis(5);
        let now = Instant::now();
        let timeout = now + timeout_duration;
        let early_timeout = timeout - Duration::from_millis(2);
        let set_timeout = UnaryRequest {
            id: id1,
            op: Operation::Set(timeout),
        };

        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(set_timeout)))
            .unwrap();
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(
                set_timeout.with_id(id2),
            )))
            .unwrap();
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(
                set_timeout.with_id(id3),
            )))
            .unwrap();

        // id1 should time out after 5 milliseconds
        // ...
        // id2 cancellation
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(
                set_timeout.with_id(id2).with_op(Operation::Cancel),
            )))
            .unwrap();
        // id3 should time out 2 milliseconds earlier than id1
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(
                set_timeout
                    .with_id(id3)
                    .with_op(Operation::Set(early_timeout)),
            )))
            .unwrap();

        tokio::time::sleep(timeout_duration * 3).await;

        let first_timeout = timeout_manager.signal_queue.pop_front().unwrap();
        assert_eq!(id3, first_timeout.id);

        let second_timeout = timeout_manager.signal_queue.pop_front().unwrap();
        assert_eq!(id1, second_timeout.id);

        // ... and id2 should be cancelled
        assert!(timeout_manager.signal_queue.pop_front().is_none());
    }
}