1#![allow(dead_code)]
2
3use std::{
4 collections::{BTreeMap, HashMap, HashSet},
5 fmt,
6 iter::IntoIterator,
7 sync::Arc,
8 time::Duration,
9};
10
11use bigerror::{attachment::DisplayDuration, ConversionError, Report};
12use parking_lot::Mutex;
13use tokio::{
14 sync::{mpsc, mpsc::UnboundedSender},
15 task::JoinSet,
16 time::Instant,
17};
18use tracing::{debug, error, instrument, warn, Instrument};
19
20use crate::{
21 manager::{HashKind, Signal, SignalQueue},
22 notification::{Notification, NotificationProcessor, RexMessage, UnaryRequest},
23 Kind, Rex, StateId,
24};
25
/// Tick rate a [`TimeoutManager`] is created with by [`TimeoutManager::new`];
/// it is used as the period of the timer task's sweep interval.
pub const DEFAULT_TICK_RATE: Duration = Duration::from_millis(5);
/// Timeouts whose distance from "now" is at or below this value are logged
/// at `warn` level ("setting short timeout") by `TimeoutLedger::lint_instant`.
const SHORT_TIMEOUT: Duration = Duration::from_secs(10);
28
/// Renders `duration` as a compact, human-readable string.
///
/// * A zero duration yields `"ZERO"`.
/// * Anything below one second yields whole milliseconds, e.g. `"5ms"`.
/// * Otherwise the result is built from zero-padded `H`/`m`/`s` components;
///   hour and minute components are omitted when they are zero, the seconds
///   component is always present, and sub-second precision is dropped.
fn hms_string(duration: Duration) -> String {
    if duration.is_zero() {
        return "ZERO".to_string();
    }

    let total_secs = duration.as_secs();
    if total_secs == 0 {
        // Only a sub-second part is present.
        return format!("{}ms", duration.subsec_millis());
    }

    let hours = total_secs / 3600;
    let minutes = (total_secs % 3600) / 60;
    let seconds = total_secs % 60;

    // Leave out the components that are zero (seconds are always shown).
    match (hours, minutes) {
        (0, 0) => format!("{seconds:02}s"),
        (0, _) => format!("{minutes:02}m{seconds:02}s"),
        (_, 0) => format!("{hours:02}H{seconds:02}s"),
        (_, _) => format!("{hours:02}H{minutes:02}m{seconds:02}s"),
    }
}
55
/// Book-keeping for pending timeouts and retained items.
///
/// * `timers` maps a deadline to the set of ids registered for that instant.
/// * `ids` is the reverse index from an id to its currently registered deadline.
/// * `retainer` maps a deadline to the `(id, item)` pairs held until that instant.
#[derive(Debug)]
struct TimeoutLedger<K>
where
    K: Kind + Rex,
    K::Message: TimeoutMessage<K>,
{
    timers: BTreeMap<Instant, HashSet<StateId<K>>>,
    ids: HashMap<StateId<K>, Instant>,
    retainer: BTreeMap<Instant, Vec<RetainPair<K>>>,
}
/// An id together with the item retained on its behalf.
type RetainPair<K> = (StateId<K>, RetainItem<K>);
72
73impl<K> TimeoutLedger<K>
74where
75 K: Rex + HashKind + Copy,
76 K::Message: TimeoutMessage<K>,
77{
78 fn new() -> Self {
79 Self {
80 timers: BTreeMap::new(),
81 ids: HashMap::new(),
82 retainer: BTreeMap::new(),
83 }
84 }
85
86 fn lint_instant(instant: Instant) {
87 let now = Instant::now();
88 if instant < now {
89 error!("requested timeout is in the past");
90 }
91 let duration = instant - now;
92 if duration <= SHORT_TIMEOUT {
93 warn!(duration = %DisplayDuration(instant - now), "setting short timeout");
94 } else {
95 debug!(duration = %DisplayDuration(instant - now), "setting timeout");
96 }
97 }
98
99 #[instrument(skip_all, fields(%id))]
100 fn retain(&mut self, id: StateId<K>, instant: Instant, item: RetainItem<K>) {
101 Self::lint_instant(instant);
102 self.retainer.entry(instant).or_default().push((id, item));
103 }
104
105 #[instrument(skip_all, fields(%id))]
108 fn set_timeout(&mut self, id: StateId<K>, instant: Instant) {
109 Self::lint_instant(instant);
110
111 if let Some(old_instant) = self.ids.insert(id, instant) {
112 if old_instant != instant {
115 debug!(%id, "renewing timeout");
116 self.timers.get_mut(&old_instant).map(|set| set.remove(&id));
117 }
118 }
119
120 self.timers
121 .entry(instant)
122 .and_modify(|set| {
123 set.insert(id);
124 })
125 .or_default()
126 .insert(id);
127 }
128
129 fn cancel_timeout(&mut self, id: StateId<K>) {
132 if let Some(instant) = self.ids.remove(&id) {
133 let removed_id = self.timers.get_mut(&instant).map(|set| set.remove(&id));
136 if matches!(removed_id, None | Some(false)) {
142 warn!("timers[{instant:?}][{id}] not found, cancellation ignored");
143 } else {
144 debug!(%id, "cancelled timeout");
145 }
146 }
147 }
148}
149
/// Message type that can carry timeout requests for kind `K`.
///
/// An implementor must be constructible from a
/// `UnaryRequest<K, Operation<Self::Item>>` and fallibly convertible back into
/// one; a failed conversion is reported as `Report<ConversionError>`.
pub trait TimeoutMessage<K: Rex>:
    std::fmt::Debug
    + RexMessage
    + From<UnaryRequest<K, Operation<Self::Item>>>
    + TryInto<UnaryRequest<K, Operation<Self::Item>>, Error = Report<ConversionError>>
{
    /// Payload type carried by [`Operation::Retain`].
    type Item: Copy + Send + std::fmt::Debug;
}
158
/// Hook for kinds whose messages support timeouts.
pub trait Timeout: Rex
where
    Self::Message: TimeoutMessage<Self>,
{
    /// Converts a retained item into an input once its deadline has elapsed.
    ///
    /// The timer task of [`TimeoutManager`] calls this on the state id
    /// (presumably resolving to the kind through a deref; confirm against
    /// `StateId`). `Some(input)` is pushed onto the signal queue, `None` only
    /// produces a warning. The default implementation returns `None`.
    fn return_item(&self, _item: RetainItem<Self>) -> Option<Self::Input> {
        None
    }
}
167
/// Unit marker type.
///
/// NOTE(review): presumably meant as the [`TimeoutMessage::Item`] of kinds
/// that never retain anything; confirm against the implementors.
#[derive(Copy, Clone, Debug, derive_more::Display)]
pub struct NoRetain;
170
/// Request understood by the [`TimeoutManager`] input task.
#[derive(Copy, Clone, Debug)]
pub enum Operation<T> {
    /// Cancel the timeout registered for the request's id.
    Cancel,
    /// Register (or renew) a timeout that fires at the given instant.
    Set(Instant),
    /// Retain an item of type `T` until the given instant.
    Retain(T, Instant),
}
177
178impl<T> std::fmt::Display for Operation<T> {
179 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180 let op = match self {
181 Self::Cancel => "timeout::Cancel",
182 Self::Set(_) => "timeout::Set",
183 Self::Retain(_, _) => "timeout::Retain",
184 };
185 write!(f, "{op}")
186 }
187}
188
189impl<T> Operation<T> {
190 #[must_use]
191 pub fn from_duration(duration: Duration) -> Self {
192 Self::Set(Instant::now() + duration)
193 }
194
195 #[must_use]
196 pub fn from_millis(millis: u64) -> Self {
197 Self::Set(Instant::now() + Duration::from_millis(millis))
198 }
199}
200
/// A timeout request: a state id paired with a [`TimeoutOp`].
pub type TimeoutInput<K> = UnaryRequest<K, TimeoutOp<K>>;
/// [`Operation`] specialised to the retain item type of `K`'s message.
pub type TimeoutOp<K> = Operation<<<K as Rex>::Message as TimeoutMessage<K>>::Item>;
/// The [`TimeoutMessage::Item`] type associated with `K`'s message.
pub type RetainItem<K> = <<K as Rex>::Message as TimeoutMessage<K>>::Item;
204
205impl<K> UnaryRequest<K, TimeoutOp<K>>
206where
207 K: Rex,
208 K::Message: TimeoutMessage<K>,
209{
210 #[cfg(test)]
211 pub(crate) fn set_timeout_millis(id: StateId<K>, millis: u64) -> Self {
212 Self {
213 id,
214 op: Operation::from_millis(millis),
215 }
216 }
217
218 pub fn set_timeout(id: StateId<K>, duration: Duration) -> Self {
219 Self {
220 id,
221 op: Operation::from_duration(duration),
222 }
223 }
224
225 pub const fn cancel_timeout(id: StateId<K>) -> Self {
226 Self {
227 id,
228 op: Operation::Cancel,
229 }
230 }
231
232 pub fn retain(id: StateId<K>, item: RetainItem<K>, duration: Duration) -> Self {
233 Self {
234 id,
235 op: Operation::Retain(item, Instant::now() + duration),
236 }
237 }
238
239 #[cfg(test)]
240 const fn with_id(&self, id: StateId<K>) -> Self {
241 Self { id, ..*self }
242 }
243 #[cfg(test)]
244 const fn with_op(&self, op: TimeoutOp<K>) -> Self {
245 Self { op, ..*self }
246 }
247}
248
/// Turns timeout requests into signals.
///
/// Requests arrive as notifications and are recorded in a shared ledger; a
/// timer task sweeps the ledger every `tick_rate` and pushes a [`Signal`] onto
/// `signal_queue` for every entry whose deadline has passed.
pub struct TimeoutManager<K>
where
    K: Rex,
    K::Message: TimeoutMessage<K>,
{
    /// Period of the timer task's sweep interval.
    tick_rate: Duration,
    /// Ledger shared between the input task and the timer task.
    ledger: Arc<Mutex<TimeoutLedger<K>>>,
    /// The single topic reported by `get_topics`.
    topic: <K::Message as RexMessage>::Topic,

    /// Queue that receives the signals produced by the timer task.
    pub(crate) signal_queue: SignalQueue<K>,
}
263
264impl<K> TimeoutManager<K>
265where
266 K: Rex + Timeout,
267 K::Message: TimeoutMessage<K>,
268{
269 #[must_use]
270 pub fn new(
271 signal_queue: SignalQueue<K>,
272 topic: impl Into<<K::Message as RexMessage>::Topic>,
273 ) -> Self {
274 Self {
275 tick_rate: DEFAULT_TICK_RATE,
276 signal_queue,
277 ledger: Arc::new(Mutex::new(TimeoutLedger::new())),
278 topic: topic.into(),
279 }
280 }
281
282 #[must_use]
283 pub fn with_tick_rate(self, tick_rate: Duration) -> Self {
284 Self { tick_rate, ..self }
285 }
286
287 pub fn init_inner(&self) -> UnboundedSender<Notification<K::Message>> {
288 let mut join_set = JoinSet::new();
289 let tx = self.init_inner_with_handle(&mut join_set);
290 join_set.detach_all();
291 tx
292 }
293
294 pub fn init_inner_with_handle(
295 &self,
296 join_set: &mut JoinSet<()>,
297 ) -> UnboundedSender<Notification<K::Message>> {
298 let (input_tx, mut input_rx) = mpsc::unbounded_channel::<Notification<K::Message>>();
299 let in_ledger = self.ledger.clone();
300
301 join_set.spawn(
302 async move {
303 debug!(target: "state_machine", spawning = "TimeoutManager.notification_tx");
304 while let Some(Notification(msg)) = input_rx.recv().await {
305 match msg.try_into() {
306 Ok(UnaryRequest { id, op }) => {
307 let mut ledger = in_ledger.lock();
308 match op {
309 Operation::Cancel => {
310 ledger.cancel_timeout(id);
311 }
312 Operation::Set(instant) => {
313 ledger.set_timeout(id, instant);
314 }
315 Operation::Retain(item, instant) => {
316 ledger.retain(id, instant, item);
317 }
318 }
319 }
320 Err(_e) => {
321 warn!("Invalid input");
322 continue;
323 }
324 }
325 }
326 }
327 .in_current_span(),
328 );
329
330 let timer_ledger = self.ledger.clone();
331 let mut interval = tokio::time::interval(self.tick_rate);
332 let signal_queue = self.signal_queue.clone();
333 join_set.spawn(
334 async move {
335 loop {
336 interval.tick().await;
337
338 let now = Instant::now();
339 let mut ledger = timer_ledger.lock();
340 let mut release = ledger.timers.split_off(&now);
342 std::mem::swap(&mut release, &mut ledger.timers);
343
344 for id in release.into_values().flat_map(IntoIterator::into_iter) {
345 warn!(%id, "timed out");
346 ledger.ids.remove(&id);
347 if let Some(input) = id.timeout_input(now) {
348 signal_queue.push_front(Signal { id, input });
351 } else {
352 warn!(%id, "timeout not supported!");
353 }
354 }
355
356 let mut release = ledger.retainer.split_off(&now);
357 std::mem::swap(&mut release, &mut ledger.retainer);
358 drop(ledger);
359 for (id, item) in release.into_values().flat_map(IntoIterator::into_iter) {
360 if let Some(input) = id.return_item(item) {
361 signal_queue.push_front(Signal { id, input });
364 } else {
365 warn!(%id, "timeout not supported!");
366 }
367 }
368 }
369 }
370 .in_current_span(),
371 );
372
373 input_tx
374 }
375}
376
377impl<K> NotificationProcessor<K::Message> for TimeoutManager<K>
378where
379 K: Rex + Timeout,
380 K::Message: TimeoutMessage<K>,
381{
382 fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
383 self.init_inner_with_handle(join_set)
384 }
385
386 fn get_topics(&self) -> &[<K::Message as RexMessage>::Topic] {
387 std::slice::from_ref(&self.topic)
388 }
389}
390
/// Test-only unit topic type.
#[cfg(test)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub struct TimeoutTopic;

/// Tick rate used by the `TimeoutManager` built in this module's tests.
#[cfg(test)]
pub(crate) const TEST_TICK_RATE: Duration = Duration::from_millis(3);

/// Test-only timeout duration; not referenced in this file.
#[cfg(test)]
pub(crate) const TEST_TIMEOUT: Duration = Duration::from_millis(11);
400
#[cfg(test)]
mod tests {

    use super::*;
    use crate::test_support::*;

    impl TestDefault for TimeoutManager<TestKind> {
        // Manager on a fresh signal queue, ticking at `TEST_TICK_RATE`.
        fn test_default() -> Self {
            let signal_queue = SignalQueue::default();
            Self::new(signal_queue, TestTopic::Timeout).with_tick_rate(TEST_TICK_RATE)
        }
    }

    // A timeout that is set and left alone must show up on the signal queue,
    // stamped with an instant that is not earlier than the requested deadline.
    #[tokio::test]
    async fn timeout_to_signal() {
        let mut timeout_manager = TimeoutManager::test_default();

        let mut join_set = JoinSet::new();
        let timeout_tx: UnboundedSender<Notification<TestMsg>> =
            timeout_manager.init(&mut join_set);

        let test_id = StateId::new_rand(TestKind);
        let timeout_duration = Duration::from_millis(5);

        // Taken before the request is built, so it is a lower bound for the
        // deadline computed inside `set_timeout`.
        let timeout = Instant::now() + timeout_duration;
        let set_timeout = UnaryRequest::set_timeout(test_id, timeout_duration);

        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(set_timeout)))
            .unwrap();

        // Wait well past the deadline so the timer task has had time to fire.
        tokio::time::sleep(timeout_duration * 3).await;

        let Signal { id, input } = timeout_manager.signal_queue.pop_front().unwrap();
        assert_eq!(test_id, id);

        let TestInput::Timeout(signal_timeout) = input else {
            panic!("{input:?}");
        };
        assert!(
            signal_timeout >= timeout,
            "out[{signal_timeout:?}] >= in[{timeout:?}]"
        );
    }

    // A timeout that is cancelled before its deadline must not produce a signal.
    #[tokio::test]
    async fn timeout_cancellation() {
        let mut timeout_manager = TimeoutManager::test_default();

        let mut join_set = JoinSet::new();
        let timeout_tx: UnboundedSender<Notification<TestMsg>> =
            timeout_manager.init(&mut join_set);

        let test_id = StateId::new_rand(TestKind);
        let set_timeout = UnaryRequest::set_timeout_millis(test_id, 10);

        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(set_timeout)))
            .unwrap();

        // Cancel while the 10ms timeout is still pending.
        tokio::time::sleep(Duration::from_millis(2)).await;
        let cancel_timeout = UnaryRequest {
            id: test_id,
            op: Operation::Cancel,
        };
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(cancel_timeout)))
            .unwrap();

        // Sleep past the original deadline plus a few ticks.
        tokio::time::sleep(Duration::from_millis(3) + TEST_TICK_RATE * 3).await;

        assert!(timeout_manager.signal_queue.pop_front().is_none());
    }

    // Three ids share a deadline; one is cancelled and one is moved to an
    // earlier deadline. Only the two remaining ids may produce a signal.
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn partial_timeout_cancellation() {
        let mut timeout_manager = TimeoutManager::test_default();

        let mut join_set = JoinSet::new();
        let timeout_tx: UnboundedSender<Notification<TestMsg>> =
            timeout_manager.init(&mut join_set);

        let id1 = StateId::new_with_u128(TestKind, 1);
        let id2 = StateId::new_with_u128(TestKind, 2);
        let id3 = StateId::new_with_u128(TestKind, 3);
        let timeout_duration = Duration::from_millis(5);
        let now = Instant::now();
        let timeout = now + timeout_duration;
        let early_timeout = timeout - Duration::from_millis(2);
        let set_timeout = UnaryRequest {
            id: id1,
            op: Operation::Set(timeout),
        };

        // Register the same deadline for id1, id2 and id3.
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(set_timeout)))
            .unwrap();
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(
                set_timeout.with_id(id2),
            )))
            .unwrap();
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(
                set_timeout.with_id(id3),
            )))
            .unwrap();

        // id2 is cancelled outright.
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(
                set_timeout.with_id(id2).with_op(Operation::Cancel),
            )))
            .unwrap();
        // id3 is renewed with a deadline 2ms earlier than the shared one.
        timeout_tx
            .send(Notification(TestMsg::TimeoutInput(
                set_timeout
                    .with_id(id3)
                    .with_op(Operation::Set(early_timeout)),
            )))
            .unwrap();

        tokio::time::sleep(timeout_duration * 3).await;

        // Expected pop order: id3, then id1, then nothing (id2 was cancelled).
        let first_timeout = timeout_manager.signal_queue.pop_front().unwrap();
        assert_eq!(id3, first_timeout.id);

        let second_timeout = timeout_manager.signal_queue.pop_front().unwrap();
        assert_eq!(id1, second_timeout.id);

        assert!(timeout_manager.signal_queue.pop_front().is_none());
    }
}