1use std::fmt::{self, Debug};
2use tokio::sync::broadcast::error::RecvError;
3use tokio::sync::broadcast::{self, Receiver};
4use tokio::time::{self, Duration, Interval};
5
/// Policy deciding when a [`Throttle`] forwards its current value to the
/// client callback.
///
/// In every mode the callback is invoked once when the throttle starts, provided
/// it already holds a value; the variants govern the calls made afterwards.
///
/// The timed variants hand their period to `tokio::time::interval`, which panics
/// on a zero period, so a zero `Duration` makes the spawned throttle task panic.
///
/// The derived `Ord` compares variants in declaration order first, so the order
/// of the variants below is observable and must not be changed casually.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Frequency {
    /// Invoke the callback for every value received from the channel (values
    /// skipped because the receiver lagged are not forwarded).
    OnEvent,
    /// Invoke the callback on every tick of an interval with this period,
    /// whether or not a new value arrived in between; received values only
    /// replace what the next tick forwards.
    Interval(Duration),
    /// Invoke the callback on an interval tick only if at least one new value
    /// arrived since the previous call; a burst of values is coalesced into a
    /// single call carrying the most recent one.
    OnEventWhen(Duration),
}
16
/// Conversion from the value a throttle stores into the argument type `F`
/// that its callback takes.
///
/// For `Clone` types the identity conversion (`T` into `T`) is already
/// provided by a blanket implementation; implement this trait to feed a
/// callback that expects some other type.
pub trait Throttled<F> {
    /// Produces the callback argument from a borrowed value, leaving the
    /// stored value in place so it can be forwarded again later.
    fn parse(&self) -> F;
}
23
24impl<T: Clone> Throttled<T> for T {
27 fn parse(&self) -> T {
28 self.clone()
29 }
30}
31
/// State owned by one spawned throttling task.
///
/// Holds the client handed to every callback invocation, the callback itself,
/// an optional broadcast receiver that supplies new values, and the most recent
/// value. The `spawn_*` constructors build one and move it into a tokio task.
pub struct Throttle<C, T, F> {
    /// Scheduling policy deciding when `call` is invoked.
    frequency: Frequency,
    /// Passed by reference as the first argument of every `call`.
    client: C,
    /// Callback receiving the client and the current value converted to `F`
    /// through `Throttled::parse`.
    call: fn(&C, F),
    /// Source of new values; `None` for throttles built by `spawn_interval`,
    /// whose value never changes.
    val_rx: Option<broadcast::Receiver<T>>,
    /// Latest value; nothing is forwarded while this is `None`.
    current_val: Option<T>,
}
44
45impl<C, T, F> fmt::Debug for Throttle<C, T, F> {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 f.debug_struct("Throttle")
48 .field("frequency", &self.frequency)
49 .field("client", &std::any::type_name::<C>().to_string())
50 .field("call", &std::any::type_name::<fn(&C, F)>().to_string())
51 .field("val_rx", &self.val_rx)
52 .field(
53 "current_val",
54 &std::any::type_name::<Option<T>>().to_string(),
55 )
56 .finish()
57 }
58}
59
60impl<C, T, F> Throttle<C, T, F>
61where
62 C: Send + Sync + 'static,
63 T: Clone + Throttled<F> + Send + Sync + 'static,
64 F: Clone + Send + Sync + 'static,
65{
66 pub fn spawn_from_receiver(
67 client: C,
68 call: fn(&C, F),
69 frequency: Frequency,
70 receiver: Receiver<T>,
71 init: Option<T>,
72 ) {
73 let mut throttle = Throttle {
74 frequency,
75 client,
76 call,
77 val_rx: Some(receiver),
78 current_val: init,
79 };
80 tokio::spawn(async move { throttle.tick().await });
81 }
82
83 pub fn spawn_interval(client: C, call: fn(&C, F), interval: Duration, val: T) {
84 let mut throttle = Throttle {
85 frequency: Frequency::Interval(interval),
86 client,
87 call,
88 val_rx: None,
89 current_val: Some(val),
90 };
91 tokio::spawn(async move { throttle.tick().await });
92 }
93
94 async fn tick(&mut self) {
95 let mut interval = match self.frequency {
96 Frequency::OnEvent => None,
97 Frequency::Interval(duration) => Some(time::interval(duration)),
98 Frequency::OnEventWhen(duration) => Some(time::interval(duration)),
99 };
100
101 if let Some(iv) = &mut interval {
102 iv.tick().await; }
104
105 self.execute_call(); let mut event_processed = true;
108 loop {
109 let received_msg = tokio::select!(
111 _ = Throttle::<C, T, F>::keep_time(&mut interval) => false,
112 res = Throttle::<C, T, F>::check_value(&mut self.val_rx) => {
113 match res {
114 Ok(val) => {
115 event_processed = false;
116 self.current_val = Some(val);
117 true
118 }
119 Err(RecvError::Closed) => {
120 log::debug!("Attached actor of type {} closed - exiting throttle", std::any::type_name::<T>());
121 break
122 }
123 Err(RecvError::Lagged(nr)) => {
124 log::debug!("Throttle of type {} lagged {nr} messages", std::any::type_name::<T>());
125 continue
126 }
127 }
128
129 },
130 );
131
132 match self.frequency {
133 Frequency::OnEvent if received_msg => self.execute_call(),
134 Frequency::Interval(_) if !received_msg => self.execute_call(),
135 Frequency::OnEventWhen(_) if !received_msg && !event_processed => {
136 event_processed = true;
137 self.execute_call()
138 }
139 _ => continue,
140 }
141 }
142 }
143
144 fn execute_call(&self) {
145 let val = if let Some(inner) = &self.current_val {
147 inner.parse()
148 } else {
149 return; };
151
152 (self.call)(&self.client, F::clone(&val));
154 }
155
156 async fn keep_time(interval: &mut Option<Interval>) {
157 if let Some(interval) = interval {
158 interval.tick().await;
159 } else {
160 std::future::pending::<()>().await;
161 }
162 }
163
164 async fn check_value(val_rx: &mut Option<broadcast::Receiver<T>>) -> Result<T, RecvError> {
165 if let Some(rx) = val_rx {
166 rx.recv().await
167 } else {
168 std::future::pending::<Result<T, RecvError>>().await
169 }
170 }
171}
172
// Timing tests for `Throttle`. Every test runs with tokio's clock paused
// (`start_paused = true`), so sleeps and interval ticks advance virtual time
// deterministically and the elapsed-time assertions are reproducible.
#[cfg(test)]
mod tests {
    use crate::Handle;

    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio::time::{Duration, Instant, sleep};

    // A throttle spawned through the handle should forward exactly once even
    // though no value is set afterwards, so the call must come from the initial
    // value the handle supplies (presumably its current value, 1 - TODO confirm
    // in `Handle::spawn_throttle`).
    #[tokio::test(start_paused = true)]
    async fn test_first_shot() {
        let handle = Handle::new(1);
        let counter = CounterClient::new();

        handle
            .spawn_throttle(counter.clone(), CounterClient::call, Frequency::OnEvent)
            .await;
        sleep(Duration::from_millis(200)).await;

        let count = *counter.count.lock().unwrap();
        assert_eq!(count, 1)
    }

    // Same as `test_first_shot`, but the throttle is spawned from a cache
    // created by the handle.
    #[tokio::test(start_paused = true)]
    async fn test_throttle_from_cache() {
        let handle = Handle::new(1);
        let counter = CounterClient::new();
        let cache = handle.create_cache().await;

        cache.spawn_throttle(counter.clone(), CounterClient::call, Frequency::OnEvent);
        sleep(Duration::from_millis(200)).await;

        let count = *counter.count.lock().unwrap();
        assert_eq!(count, 1)
    }

    // After the handle is dropped (presumably closing the broadcast channel -
    // TODO confirm), the throttle task should exit, so the interval-driven call
    // count must stop increasing.
    #[tokio::test(start_paused = true)]
    async fn test_exit_on_shutdown() {
        let handle = Handle::new(1);
        let receiver = handle.subscribe();

        let counter = CounterClient::new();

        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::Interval(Duration::from_millis(100)),
            receiver,
            None,
        );

        sleep(Duration::from_millis(500)).await;

        let count_before_drop = *counter.count.lock().unwrap();

        drop(handle);

        sleep(Duration::from_millis(500)).await;

        let count_after_drop = *counter.count.lock().unwrap();

        assert_eq!(count_before_drop, count_after_drop);
    }

    // With `OnEvent` and no initial value, a single `set` at ~`timer` ms must
    // produce exactly one call at about that time.
    #[tokio::test(start_paused = true)]
    async fn test_on_event() {
        let timer = 200.;
        let handle = Handle::new(1);
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        // The first tick completes immediately; consume it so the next tick
        // lands `timer` ms later.
        interval.tick().await;
        let counter = CounterClient::new();

        let receiver = handle.subscribe();
        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::OnEvent,
            receiver,
            None,
        );

        interval.tick().await;
        handle.set(2).await;
        // Give the throttle task a moment to observe the value.
        sleep(Duration::from_millis(10)).await;
        // Note: this local shadows the `time` module for the rest of the test.
        let time = *counter.elapsed.lock().unwrap() as f64;
        let count = *counter.count.lock().unwrap();
        assert_eq!(count, 1);
        assert!((timer - time).abs() / timer < 0.1);
    }

    // Ten values set 20 ms apart inside one 200 ms `OnEventWhen` period must be
    // coalesced into a single call at the period boundary.
    #[tokio::test(start_paused = true)]
    async fn test_hot_on_event_when() {
        let timer = 200.;
        let handle = Handle::new(1);
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        interval.tick().await;
        let counter = CounterClient::new();

        let receiver = handle.subscribe();
        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::OnEventWhen(Duration::from_millis(timer as u64)),
            receiver,
            None,
        );

        for i in 0..10 {
            handle.set(i).await;
            sleep(Duration::from_millis((timer / 10.) as u64)).await;
        }

        sleep(Duration::from_millis(5)).await;

        let time = *counter.elapsed.lock().unwrap() as f64;
        let count = *counter.count.lock().unwrap();

        assert!((timer - time).abs() / timer < 0.1 && count == 1);
    }

    // `spawn_interval` calls once immediately and then once per tick:
    // after five periods that is 6 calls, the last at ~5 * `timer` ms.
    #[tokio::test(start_paused = true)]
    async fn test_interval() {
        let timer = 200.;
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        interval.tick().await;
        let counter = CounterClient::new();

        Throttle::spawn_interval(
            counter.clone(),
            CounterClient::call,
            Duration::from_millis(timer as u64),
            1,
        );

        for _ in 0..5 {
            interval.tick().await;
        }
        sleep(Duration::from_millis(20)).await;
        let time = *counter.elapsed.lock().unwrap() as f64;
        let count = *counter.count.lock().unwrap();
        assert!((timer * 5. - time).abs() / (5. * timer) < 0.1 && count == 6);
    }

    // Throttle period 110 ms, value set at ~200 ms: the call should happen at
    // the next throttle tick (~220 ms = `timer` * 1.1), and only once.
    #[tokio::test(start_paused = true)]
    async fn test_on_event_when_interval_passed() {
        let timer = 200.;
        let handle = Handle::new(1);
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        interval.tick().await;
        let counter = CounterClient::new();

        let receiver = handle.subscribe();
        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::OnEventWhen(Duration::from_millis((timer * 0.55) as u64)),
            receiver,
            None,
        );

        interval.tick().await;
        handle.set(2).await;
        interval.tick().await;

        let time = *counter.elapsed.lock().unwrap() as f64;
        let count = *counter.count.lock().unwrap();
        assert!((timer * 1.1 - time).abs() / (timer * 1.1) < 0.1 && count == 1);
    }

    // Throttle period 300 ms, value set at ~200 ms and checked right away: the
    // throttle tick has not come yet, so no call may have happened.
    #[tokio::test(start_paused = true)]
    async fn test_on_event_when_too_soon() {
        let timer = 200.;
        let handle = Handle::new(1);
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        interval.tick().await;
        let counter = CounterClient::new();

        let receiver = handle.subscribe();
        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::OnEventWhen(Duration::from_millis((timer * 1.5) as u64)),
            receiver,
            None,
        );

        interval.tick().await;
        handle.set(2).await;
        let time = *counter.elapsed.lock().unwrap();
        let count = *counter.count.lock().unwrap();
        assert!(count == 0);
        assert_eq!(time, 0);
    }

    // No assertions: this test checks that a stored `A` can feed callbacks
    // taking `A` (blanket impl), `B` and `C` (impls below). The body never
    // awaits, so the spawned tasks are not expected to do observable work.
    #[tokio::test(start_paused = true)]
    async fn test_throttle_parsing() {
        Throttle::spawn_interval(
            DummyClient {},
            DummyClient::call_a,
            Duration::from_millis(100),
            A {},
        );

        Throttle::spawn_interval(
            DummyClient {},
            DummyClient::call_b,
            Duration::from_millis(100),
            A {},
        );

        Throttle::spawn_interval(
            DummyClient {},
            DummyClient::call_c,
            Duration::from_millis(100),
            A {},
        );
    }

    /// Stored value type for `test_throttle_parsing`.
    #[derive(Debug, Clone)]
    struct A {}

    /// Conversion target of `A` used by `DummyClient::call_b`.
    #[derive(Debug, Clone)]
    struct B {}

    /// Conversion target of `A` used by `DummyClient::call_c`.
    #[derive(Debug, Clone)]
    struct C {}

    impl Throttled<B> for A {
        fn parse(&self) -> B {
            B {}
        }
    }

    impl Throttled<C> for A {
        fn parse(&self) -> C {
            C {}
        }
    }

    /// Client whose callbacks ignore their argument; only the types matter.
    #[derive(Debug, Clone)]
    struct DummyClient {}

    impl DummyClient {
        fn call_a(&self, _event: A) {}
        fn call_b(&self, _event: B) {}
        fn call_c(&self, _event: C) {}
    }

    /// Test client recording how many times `call` ran and when the most
    /// recent call happened (ms since the client was constructed). Clones
    /// share the same counters through the `Arc`s.
    #[derive(Debug, Clone)]
    struct CounterClient {
        start: Instant,
        elapsed: Arc<Mutex<u128>>,
        count: Arc<Mutex<i32>>,
    }

    impl CounterClient {
        fn new() -> Self {
            CounterClient {
                start: Instant::now(),
                elapsed: Arc::new(Mutex::new(0)),
                count: Arc::new(Mutex::new(0)),
            }
        }

        /// Callback handed to the throttle: records the call time and bumps the count.
        fn call(&self, _event: i32) {
            let mut time = self.elapsed.lock().unwrap();
            *time = self.start.elapsed().as_millis();

            let mut count = self.count.lock().unwrap();
            *count += 1;
        }
    }
}