1#[doc(hidden)]
7extern crate alloc;
8
9#[cfg(feature = "std")]
10extern crate std;
11
12use alloc::{boxed::Box, string::String, sync::Arc};
14use core::{fmt::Debug, time::Duration};
15use dimas_core::{
16 Result,
17 enums::{OperationState, TaskSignal},
18 traits::{Capability, Context},
19};
20#[cfg(feature = "std")]
21use std::sync::Mutex;
22#[cfg(feature = "std")]
23use tokio::{task::JoinHandle, time};
24use tracing::{Level, error, info, instrument, warn};
25pub type ArcTimerCallback<P> =
30 Arc<Mutex<dyn FnMut(Context<P>) -> Result<()> + Send + Sync + 'static>>;
31pub enum Timer<P>
36where
37 P: Send + Sync + 'static,
38{
39 Interval {
41 selector: String,
43 context: Context<P>,
45 activation_state: OperationState,
47 callback: ArcTimerCallback<P>,
49 interval: Duration,
51 handle: Mutex<Option<JoinHandle<()>>>,
53 },
54 DelayedInterval {
56 selector: String,
58 context: Context<P>,
60 activation_state: OperationState,
62 callback: ArcTimerCallback<P>,
64 interval: Duration,
66 delay: Duration,
68 handle: Mutex<Option<JoinHandle<()>>>,
70 },
71}
72
73impl<P> Debug for Timer<P>
74where
75 P: Send + Sync + 'static,
76{
77 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78 match self {
79 Self::Interval { interval, .. } => f
80 .debug_struct("IntervalTimer")
81 .field("interval", interval)
82 .finish_non_exhaustive(),
83 Self::DelayedInterval {
84 delay, interval, ..
85 } => f
86 .debug_struct("DelayedIntervalTimer")
87 .field("delay", delay)
88 .field("interval", interval)
89 .finish_non_exhaustive(),
90 }
91 }
92}
93
94impl<P> Capability for Timer<P>
95where
96 P: Send + Sync + 'static,
97{
98 fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
99 match self {
100 Self::Interval {
101 selector: _,
102 context: _,
103 activation_state,
104 interval: _,
105 callback: _,
106 handle: _,
107 }
108 | Self::DelayedInterval {
109 selector: _,
110 context: _,
111 activation_state,
112 delay: _,
113 interval: _,
114 callback: _,
115 handle: _,
116 } => {
117 if state >= activation_state {
118 self.start()
119 } else if state < activation_state {
120 self.stop()
121 } else {
122 Ok(())
123 }
124 }
125 }
126 }
127}
128
129impl<P> Timer<P>
130where
131 P: Send + Sync + 'static,
132{
133 #[must_use]
135 pub fn new(
136 name: String,
137 context: Context<P>,
138 activation_state: OperationState,
139 callback: ArcTimerCallback<P>,
140 interval: Duration,
141 delay: Option<Duration>,
142 ) -> Self {
143 match delay {
144 Some(delay) => Self::DelayedInterval {
145 selector: name,
146 context,
147 activation_state,
148 delay,
149 interval,
150 callback,
151 handle: Mutex::new(None),
152 },
153 None => Self::Interval {
154 selector: name,
155 context,
156 activation_state,
157 interval,
158 callback,
159 handle: Mutex::new(None),
160 },
161 }
162 }
163
164 #[allow(clippy::cognitive_complexity)]
167 #[instrument(level = Level::TRACE, skip_all)]
168 fn start(&self) -> Result<()> {
169 self.stop()?;
170
171 match self {
172 Self::Interval {
173 selector,
174 context,
175 activation_state: _,
176 interval,
177 callback,
178 handle,
179 } => {
180 {
182 if callback.lock().is_err() {
183 warn!("found poisoned Mutex");
184 callback.clear_poison();
185 }
186 }
187
188 let key = selector.clone();
189 let interval = *interval;
190 let cb = callback.clone();
191 let ctx1 = context.clone();
192 let ctx2 = context.clone();
193
194 handle.lock().map_or_else(
195 |_| todo!(),
196 |mut handle| {
197 handle.replace(tokio::task::spawn(async move {
198 std::panic::set_hook(Box::new(move |reason| {
199 error!("delayed timer panic: {}", reason);
200 if let Err(reason) = ctx1
201 .sender()
202 .blocking_send(TaskSignal::RestartTimer(key.clone()))
203 {
204 error!("could not restart timer: {}", reason);
205 } else {
206 info!("restarting timer!");
207 }
208 }));
209 run_timer(interval, cb, ctx2).await;
210 }));
211 Ok(())
212 },
213 )
214 }
215 Self::DelayedInterval {
216 selector,
217 context,
218 activation_state: _,
219 delay,
220 interval,
221 callback,
222 handle,
223 } => {
224 {
226 if callback.lock().is_err() {
227 warn!("found poisoned Mutex");
228 callback.clear_poison();
229 }
230 }
231
232 let key = selector.clone();
233 let delay = *delay;
234 let interval = *interval;
235 let cb = callback.clone();
236 let ctx1 = context.clone();
237 let ctx2 = context.clone();
238
239 handle.lock().map_or_else(
240 |_| todo!(),
241 |mut handle| {
242 handle.replace(tokio::task::spawn(async move {
243 std::panic::set_hook(Box::new(move |reason| {
244 error!("delayed timer panic: {}", reason);
245 if let Err(reason) = ctx1
246 .sender()
247 .blocking_send(TaskSignal::RestartTimer(key.clone()))
248 {
249 error!("could not restart timer: {}", reason);
250 } else {
251 info!("restarting timer!");
252 }
253 }));
254 tokio::time::sleep(delay).await;
255 run_timer(interval, cb, ctx2).await;
256 }));
257 Ok(())
258 },
259 )
260 }
261 }
262 }
263
264 #[instrument(level = Level::TRACE, skip_all)]
266 fn stop(&self) -> Result<()> {
267 match self {
268 Self::Interval {
269 selector: _,
270 context: _,
271 activation_state: _,
272 interval: _,
273 callback: _,
274 handle,
275 }
276 | Self::DelayedInterval {
277 selector: _,
278 context: _,
279 activation_state: _,
280 delay: _,
281 interval: _,
282 callback: _,
283 handle,
284 } => handle.lock().map_or_else(
285 |_| todo!(),
286 |mut handle| {
287 if let Some(handle) = handle.take() {
288 handle.abort();
289 }
290 Ok(())
291 },
292 ),
293 }
294 }
295}
296
297#[instrument(name="timer", level = Level::ERROR, skip_all)]
298async fn run_timer<P>(interval: Duration, cb: ArcTimerCallback<P>, ctx: Context<P>)
299where
300 P: Send + Sync + 'static,
301{
302 let mut interval = time::interval(interval);
303 loop {
304 let ctx = ctx.clone();
305 interval.tick().await;
306
307 match cb.lock() {
308 Ok(mut cb) => {
309 if let Err(error) = cb(ctx) {
310 error!("callback failed with {error}");
311 }
312 }
313 Err(err) => {
314 error!("callback lock failed with {err}");
315 }
316 }
317 }
318}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal property type used to instantiate [`Timer`] in the tests.
    #[derive(Debug)]
    struct Props {}

    /// Compiles only when `T` is `Sized + Send + Sync`; does nothing at runtime.
    const fn is_normal<T>()
    where
        T: Sized + Send + Sync,
    {
    }

    /// Compile-time check that `Timer` can be shared and sent across threads.
    #[test]
    const fn normal_types() {
        is_normal::<Timer<Props>>();
    }
}