1#[doc(hidden)]
7extern crate alloc;
8
9#[cfg(feature = "std")]
10extern crate std;
11
12use alloc::{boxed::Box, string::String, sync::Arc};
14use core::{fmt::Debug, time::Duration};
15use dimas_core::{
16 Result,
17 enums::{OperationState, TaskSignal},
18 traits::{Capability, Context},
19};
20#[cfg(feature = "std")]
21use std::sync::Mutex;
22#[cfg(feature = "std")]
23use tokio::{task::JoinHandle, time};
24use tracing::{Level, error, info, instrument, warn};
25pub type ArcTimerCallback<P> =
30 Arc<Mutex<dyn FnMut(Context<P>) -> Result<()> + Send + Sync + 'static>>;
31pub enum Timer<P>
36where
37 P: Send + Sync + 'static,
38{
39 Interval {
41 selector: String,
43 context: Context<P>,
45 activation_state: OperationState,
47 callback: ArcTimerCallback<P>,
49 interval: Duration,
51 handle: Mutex<Option<JoinHandle<()>>>,
53 },
54 DelayedInterval {
56 selector: String,
58 context: Context<P>,
60 activation_state: OperationState,
62 callback: ArcTimerCallback<P>,
64 interval: Duration,
66 delay: Duration,
68 handle: Mutex<Option<JoinHandle<()>>>,
70 },
71}
72
73impl<P> Debug for Timer<P>
74where
75 P: Send + Sync + 'static,
76{
77 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78 match self {
79 Self::Interval { interval, .. } => f
80 .debug_struct("IntervalTimer")
81 .field("interval", interval)
82 .finish_non_exhaustive(),
83 Self::DelayedInterval {
84 delay, interval, ..
85 } => f
86 .debug_struct("DelayedIntervalTimer")
87 .field("delay", delay)
88 .field("interval", interval)
89 .finish_non_exhaustive(),
90 }
91 }
92}
93
94impl<P> Capability for Timer<P>
95where
96 P: Send + Sync + 'static,
97{
98 fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
99 match self {
100 Self::Interval {
101 selector: _,
102 context: _,
103 activation_state,
104 interval: _,
105 callback: _,
106 handle: _,
107 }
108 | Self::DelayedInterval {
109 selector: _,
110 context: _,
111 activation_state,
112 delay: _,
113 interval: _,
114 callback: _,
115 handle: _,
116 } => {
117 if state >= activation_state {
118 self.start()
119 } else if state < activation_state {
120 self.stop()
121 } else {
122 Ok(())
123 }
124 }
125 }
126 }
127}
128
129impl<P> Timer<P>
130where
131 P: Send + Sync + 'static,
132{
133 #[must_use]
135 pub fn new(
136 name: String,
137 context: Context<P>,
138 activation_state: OperationState,
139 callback: ArcTimerCallback<P>,
140 interval: Duration,
141 delay: Option<Duration>,
142 ) -> Self {
143 match delay {
144 Some(delay) => Self::DelayedInterval {
145 selector: name,
146 context,
147 activation_state,
148 delay,
149 interval,
150 callback,
151 handle: Mutex::new(None),
152 },
153 None => Self::Interval {
154 selector: name,
155 context,
156 activation_state,
157 interval,
158 callback,
159 handle: Mutex::new(None),
160 },
161 }
162 }
163
164 #[instrument(level = Level::TRACE, skip_all)]
167 fn start(&self) -> Result<()> {
168 self.stop()?;
169
170 match self {
171 Self::Interval {
172 selector,
173 context,
174 activation_state: _,
175 interval,
176 callback,
177 handle,
178 } => {
179 {
181 if callback.lock().is_err() {
182 warn!("found poisoned Mutex");
183 callback.clear_poison();
184 }
185 }
186
187 let key = selector.clone();
188 let interval = *interval;
189 let cb = callback.clone();
190 let ctx1 = context.clone();
191 let ctx2 = context.clone();
192
193 handle.lock().map_or_else(
194 |_| todo!(),
195 |mut handle| {
196 handle.replace(tokio::task::spawn(async move {
197 std::panic::set_hook(Box::new(move |reason| {
198 error!("delayed timer panic: {}", reason);
199 if let Err(reason) = ctx1
200 .sender()
201 .blocking_send(TaskSignal::RestartTimer(key.clone()))
202 {
203 error!("could not restart timer: {}", reason);
204 } else {
205 info!("restarting timer!");
206 }
207 }));
208 run_timer(interval, cb, ctx2).await;
209 }));
210 Ok(())
211 },
212 )
213 }
214 Self::DelayedInterval {
215 selector,
216 context,
217 activation_state: _,
218 delay,
219 interval,
220 callback,
221 handle,
222 } => {
223 {
225 if callback.lock().is_err() {
226 warn!("found poisoned Mutex");
227 callback.clear_poison();
228 }
229 }
230
231 let key = selector.clone();
232 let delay = *delay;
233 let interval = *interval;
234 let cb = callback.clone();
235 let ctx1 = context.clone();
236 let ctx2 = context.clone();
237
238 handle.lock().map_or_else(
239 |_| todo!(),
240 |mut handle| {
241 handle.replace(tokio::task::spawn(async move {
242 std::panic::set_hook(Box::new(move |reason| {
243 error!("delayed timer panic: {}", reason);
244 if let Err(reason) = ctx1
245 .sender()
246 .blocking_send(TaskSignal::RestartTimer(key.clone()))
247 {
248 error!("could not restart timer: {}", reason);
249 } else {
250 info!("restarting timer!");
251 }
252 }));
253 tokio::time::sleep(delay).await;
254 run_timer(interval, cb, ctx2).await;
255 }));
256 Ok(())
257 },
258 )
259 }
260 }
261 }
262
263 #[instrument(level = Level::TRACE, skip_all)]
265 fn stop(&self) -> Result<()> {
266 match self {
267 Self::Interval {
268 selector: _,
269 context: _,
270 activation_state: _,
271 interval: _,
272 callback: _,
273 handle,
274 }
275 | Self::DelayedInterval {
276 selector: _,
277 context: _,
278 activation_state: _,
279 delay: _,
280 interval: _,
281 callback: _,
282 handle,
283 } => handle.lock().map_or_else(
284 |_| todo!(),
285 |mut handle| {
286 if let Some(handle) = handle.take() {
287 handle.abort();
288 }
289 Ok(())
290 },
291 ),
292 }
293 }
294}
295
296#[instrument(name="timer", level = Level::ERROR, skip_all)]
297async fn run_timer<P>(interval: Duration, cb: ArcTimerCallback<P>, ctx: Context<P>)
298where
299 P: Send + Sync + 'static,
300{
301 let mut interval = time::interval(interval);
302 loop {
303 let ctx = ctx.clone();
304 interval.tick().await;
305
306 match cb.lock() {
307 Ok(mut cb) => {
308 if let Err(error) = cb(ctx) {
309 error!("callback failed with {error}");
310 }
311 }
312 Err(err) => {
313 error!("callback lock failed with {err}");
314 }
315 }
316 }
317}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder property type used only to instantiate `Timer<P>`.
    #[derive(Debug)]
    struct Props {}

    /// Compile-time check: only instantiable for `Sized + Send + Sync` types.
    const fn assert_send_sync<T>()
    where
        T: Sized + Send + Sync,
    {
    }

    /// `Timer` must be shareable between threads.
    #[test]
    const fn normal_types() {
        assert_send_sync::<Timer<Props>>();
    }
}