1use std::collections::{BTreeSet, HashMap};
2use std::ops::Sub;
3
4use mm1_address::address::Address;
5use mm1_common::log;
6use mm1_common::types::Never;
7use mm1_core::context::{Messaging, Now, Quit, Tell};
8use mm1_core::envelope::{Envelope, dispatch};
9use mm1_proto_timer as t;
10use mm1_proto_timer::Timer;
11use tokio::task;
12
13pub async fn timer_actor<T, Ctx>(ctx: &mut Ctx, receiver: Address) -> Never
14where
15 T: Timer,
16 Ctx: Messaging + Now<Instant = T::Instant> + Quit,
17{
18 match inner::<T, _>(ctx, receiver).await {
19 Ok(never) => never,
20 Err(reason) => ctx.quit_err(Error(reason)).await,
21 }
22}
23#[derive(Debug, thiserror::Error)]
24#[error("{}", _0)]
25struct Error(#[source] Box<dyn std::error::Error + Send + Sync + 'static>);
26
/// Main loop of the timer actor.
///
/// Each iteration races two futures:
/// - a sleep until the earliest scheduled deadline (or forever, if nothing is
///   scheduled);
/// - `ctx.recv()`, which yields the next incoming envelope.
///
/// `ScheduleOnce` requests are forwarded to `State::schedule_once` and
/// `Cancel` requests to `State::cancel`. On wake-up, every timer whose
/// deadline is at or before the current time is removed from the state, and
/// its message is sent to `receiver` with `ctx.tell`.
///
/// NOTE(review): how envelopes of any other type are handled is decided by
/// `dispatch!`, which is not visible here; confirm its behaviour.
///
/// # Errors
/// Returns the error from `ctx.recv()` or `ctx.tell(..)` via `?`. The loop has
/// no `break`, and the success type is `Never`, so in practice this function
/// only returns through one of those errors.
async fn inner<T, Ctx>(
    ctx: &mut Ctx,
    receiver: Address,
) -> Result<Never, Box<dyn std::error::Error + Send + Sync + 'static>>
where
    T: Timer,
    Ctx: Messaging + Now<Instant = T::Instant> + Quit,
{
    // All scheduled timers: deadline index plus per-key entries.
    let mut state: State<T> = Default::default();
    loop {
        let now = ctx.now();
        let next_at = state.next_at();
        // `async move` gives this future its own copies of `next_at` and `now`,
        // so it holds no borrow of `state` while `ctx.recv()` is pending.
        let wake_fut = async move {
            if let Some(at) = next_at {
                // Intended to be the time left until `at`. A `None` from
                // `checked_sub` falls back to the duration type's `Default`
                // (presumably zero), i.e. wake up immediately.
                let dt = checked_sub(at, now).unwrap_or_default();
                T::sleep(dt).await
            } else {
                // Nothing is scheduled: never wake; only `recv_fut` can fire.
                std::future::pending().await
            }
        };
        let recv_fut = ctx.recv();

        // Records which branch won, so the work is done after `select!` has
        // returned and the losing future has been dropped.
        enum Selected {
            WakeUp,
            Envelope(Envelope),
        }
        let selected = tokio::select! {
            () = wake_fut => {
                log::trace!("wake");
                Selected::WakeUp
            },
            recv_result = recv_fut => {
                log::trace!("recv");
                // A failed receive terminates the actor with that error.
                let envelope = recv_result?;
                Selected::Envelope(envelope)
            }
        };

        match selected {
            Selected::Envelope(envelope) => {
                dispatch!(match envelope {
                    t::ScheduleOnce::<T> { key, at, msg } => {
                        log::trace!("recv:schedule");
                        state.schedule_once(key, at, msg)
                    },
                    t::Cancel::<T> { key } => {
                        log::trace!("recv:cancel");
                        state.cancel(key)
                    },
                })
            },
            Selected::WakeUp => {
                // Re-read the clock: time has passed while sleeping.
                let now = ctx.now();
                // Drain every timer that is due as of `now`, earliest first.
                while let Some(msg) = state.take_elapsed(now) {
                    log::trace!("wake:elapsed");

                    ctx.tell(receiver, msg).await?;

                    // Let other tasks run between deliveries when many timers
                    // are due at once.
                    task::yield_now().await;
                }
                log::trace!("wake:done");
            },
        }
    }
}
93
94struct State<T: Timer> {
95 upcoming: BTreeSet<(T::Instant, T::Key)>,
96 entries: HashMap<T::Key, Entry<T::Instant, T::Message>>,
97}
98struct Entry<I, M> {
99 at: I,
100 msg_gen: MsgGen<M>,
101}
102
/// How a timer entry produces the message it delivers.
///
/// Currently only one-shot timers exist.
enum MsgGen<Msg> {
    /// Deliver the contained message a single time.
    Once(Msg),
}
106
107impl<T: Timer> Default for State<T> {
108 fn default() -> Self {
109 Self {
110 upcoming: Default::default(),
111 entries: Default::default(),
112 }
113 }
114}
115
116impl<T: Timer> State<T> {
117 fn schedule_once(&mut self, key: T::Key, at: T::Instant, message: T::Message) {
118 if let Some(existing_entry) = self.entries.insert(
119 key.clone(),
120 Entry {
121 at,
122 msg_gen: MsgGen::Once(message),
123 },
124 ) {
125 let former_at = existing_entry.at;
126 let existed_before = self.upcoming.remove(&(former_at, key.clone()));
127
128 assert!(existed_before);
129 }
130
131 let newly_inserted = self.upcoming.insert((at, key));
132
133 assert!(newly_inserted);
134 }
135
136 fn cancel(&mut self, key: T::Key) {
137 let Some(existing_entry) = self.entries.remove(&key) else {
138 return
139 };
140 let Entry { at, msg_gen: _ } = existing_entry;
141 let existed = self.upcoming.remove(&(at, key));
142 assert!(existed);
143 }
144
145 fn next_at(&self) -> Option<T::Instant> {
146 self.upcoming.first().map(|(t, _k)| *t)
147 }
148
149 fn take_elapsed(&mut self, now: T::Instant) -> Option<T::Message> {
150 let _ = self.upcoming.first().filter(|(t, _)| *t <= now)?;
151
152 let (_at, key) = self
153 .upcoming
154 .pop_first()
155 .expect(".first() returned Some. We expect .pop_first() to return Some as well");
156
157 let Entry { at: _, msg_gen } = self
158 .entries
159 .remove(&key)
160 .expect("`upcoming` contains a key that does not exist in `entries`");
161
162 match msg_gen {
163 MsgGen::Once(message) => Some(message),
164 }
165 }
166}
167
/// Returns `Some(l - r)` when `l >= r`, and `None` when `l < r`.
///
/// This is a generic counterpart of the integer `checked_sub`. In this file
/// it is called as `checked_sub(at, now)` to get the time remaining until the
/// deadline `at`. A deadline that has already passed yields `None`.
fn checked_sub<I, D>(l: I, r: I) -> Option<D>
where
    I: Ord + Sub<I, Output = D>,
{
    // Compare before subtracting: depending on the type, `l - r` with
    // `l < r` may panic, wrap or saturate.
    (l >= r).then(|| l - r)
}