Skip to main content

actify/
throttle.rs

1use std::fmt::{self, Debug};
2use tokio::sync::broadcast::error::RecvError;
3use tokio::sync::broadcast::{self, Receiver};
4use tokio::time::{self, Duration, Interval};
5
/// The Frequency is used to tune the speed of a [`Throttle`].
///
/// Every variant carries at most a [`Duration`], so the type is cheap to copy and derives
/// `Copy` and `Hash` in addition to the comparison traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Frequency {
    /// Fires every time an event arrives. Designed for infrequent but important events.
    OnEvent,
    /// Fires every interval with the most recently cached value, regardless of whether new
    /// events arrived in the meantime.
    Interval(Duration),
    /// Fires at most once per interval, and only if at least one event arrived since the
    /// previous call; the most recent value is forwarded. Designed for high-throughput types.
    OnEventWhen(Duration),
}
16
/// Converts the value held by an actor into the argument type `F` that a [`Throttle`]
/// hands to its callback.
///
/// Since the trait is generic over `F`, one actor type may implement it several times. That
/// lets a single [`Handle`](crate::Handle) feed multiple throttles, each receiving its own
/// parsed representation.
pub trait Throttled<F> {
    /// Builds the callback argument from `self`.
    fn parse(&self) -> F;
}
23
24// TODO add a derive macro for Throttled derivation for self
25/// A blanket implementation is used to ensure any standard type implements it
26impl<T: Clone> Throttled<T> for T {
27    fn parse(&self) -> T {
28        self.clone()
29    }
30}
31
32/// Rate-limits broadcasted updates from a [`Handle`](crate::Handle) or [`Cache`](crate::Cache)
33/// before forwarding them to a callback.
34///
35/// Configure the rate with [`Frequency`]. The actor type must implement [`Throttled<F>`](Throttled)
36/// to convert the actor value into the callback argument type `F`.
37pub struct Throttle<C, T, F> {
38    frequency: Frequency,
39    client: C,
40    call: fn(&C, F),
41    val_rx: Option<broadcast::Receiver<T>>,
42    current_val: Option<T>,
43}
44
45impl<C, T, F> fmt::Debug for Throttle<C, T, F> {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        f.debug_struct("Throttle")
48            .field("frequency", &self.frequency)
49            .field("client", &std::any::type_name::<C>().to_string())
50            .field("call", &std::any::type_name::<fn(&C, F)>().to_string())
51            .field("val_rx", &self.val_rx)
52            .field(
53                "current_val",
54                &std::any::type_name::<Option<T>>().to_string(),
55            )
56            .finish()
57    }
58}
59
60impl<C, T, F> Throttle<C, T, F>
61where
62    C: Send + Sync + 'static,
63    T: Clone + Throttled<F> + Send + Sync + 'static,
64    F: Clone + Send + Sync + 'static,
65{
66    pub fn spawn_from_receiver(
67        client: C,
68        call: fn(&C, F),
69        frequency: Frequency,
70        receiver: Receiver<T>,
71        init: Option<T>,
72    ) {
73        let mut throttle = Throttle {
74            frequency,
75            client,
76            call,
77            val_rx: Some(receiver),
78            current_val: init,
79        };
80        tokio::spawn(async move { throttle.tick().await });
81    }
82
83    pub fn spawn_interval(client: C, call: fn(&C, F), interval: Duration, val: T) {
84        let mut throttle = Throttle {
85            frequency: Frequency::Interval(interval),
86            client,
87            call,
88            val_rx: None,
89            current_val: Some(val),
90        };
91        tokio::spawn(async move { throttle.tick().await });
92    }
93
94    async fn tick(&mut self) {
95        let mut interval = match self.frequency {
96            Frequency::OnEvent => None,
97            Frequency::Interval(duration) => Some(time::interval(duration)),
98            Frequency::OnEventWhen(duration) => Some(time::interval(duration)),
99        };
100
101        if let Some(iv) = &mut interval {
102            iv.tick().await; // First tick completes immediately, so ignore by calling prior
103        }
104
105        self.execute_call(); // Always execute the call once in case it was initialized
106
107        let mut event_processed = true;
108        loop {
109            // Wait or update cache
110            let received_msg = tokio::select!(
111                _ = Throttle::<C, T, F>::keep_time(&mut interval) => false,
112                res = Throttle::<C, T, F>::check_value(&mut self.val_rx) => {
113                    match res {
114                        Ok(val) => {
115                            event_processed = false;
116                            self.current_val = Some(val);
117                            true
118                        }
119                        Err(RecvError::Closed) => {
120                            log::debug!("Attached actor of type {} closed - exiting throttle", std::any::type_name::<T>());
121                            break
122                        }
123                        Err(RecvError::Lagged(nr)) => {
124                            log::debug!("Throttle of type {} lagged {nr} messages", std::any::type_name::<T>());
125                            continue
126                        }
127                    }
128
129                },
130            );
131
132            match self.frequency {
133                Frequency::OnEvent if received_msg => self.execute_call(),
134                Frequency::Interval(_) if !received_msg => self.execute_call(),
135                Frequency::OnEventWhen(_) if !received_msg && !event_processed => {
136                    event_processed = true;
137                    self.execute_call()
138                }
139                _ => continue,
140            }
141        }
142    }
143
144    fn execute_call(&self) {
145        // Either parse the value to a different type F, or to itself when T = F
146        let val = if let Some(inner) = &self.current_val {
147            inner.parse()
148        } else {
149            return; // If cache empty, skip call
150        };
151
152        // Perform the call
153        (self.call)(&self.client, F::clone(&val));
154    }
155
156    async fn keep_time(interval: &mut Option<Interval>) {
157        if let Some(interval) = interval {
158            interval.tick().await;
159        } else {
160            std::future::pending::<()>().await;
161        }
162    }
163
164    async fn check_value(val_rx: &mut Option<broadcast::Receiver<T>>) -> Result<T, RecvError> {
165        if let Some(rx) = val_rx {
166            rx.recv().await
167        } else {
168            std::future::pending::<Result<T, RecvError>>().await
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use crate::Handle;

    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio::time::{Duration, Instant, sleep};

    #[tokio::test(start_paused = true)]
    async fn test_first_shot() {
        let handle = Handle::new(1);
        let counter = CounterClient::new();

        // Spawn a throttle that should only fire once, on creation
        handle
            .spawn_throttle(counter.clone(), CounterClient::call, Frequency::OnEvent)
            .await;
        sleep(Duration::from_millis(200)).await;

        let count = *counter.count.lock().unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test(start_paused = true)]
    async fn test_throttle_from_cache() {
        let handle = Handle::new(1);
        let counter = CounterClient::new();
        let cache = handle.create_cache().await;

        // Spawn a throttle that should only fire once, on creation
        cache.spawn_throttle(counter.clone(), CounterClient::call, Frequency::OnEvent);
        sleep(Duration::from_millis(200)).await;

        let count = *counter.count.lock().unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test(start_paused = true)]
    async fn test_exit_on_shutdown() {
        let handle = Handle::new(1);
        let receiver = handle.subscribe();

        let counter = CounterClient::new();

        // Spawn a throttle with an initial value, so the interval produces calls while the
        // throttle is alive (without a cached value it would never call back at all)
        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::Interval(Duration::from_millis(100)),
            receiver,
            Some(1),
        );

        // Sample between two throttle ticks to avoid racing with a tick at the same instant
        sleep(Duration::from_millis(450)).await;

        let count_before_drop = *counter.count.lock().unwrap();
        assert!(count_before_drop > 0, "throttle should have fired before the drop");

        // The throttle will stop, as no handles are present anymore
        drop(handle);

        sleep(Duration::from_millis(500)).await;

        let count_after_drop = *counter.count.lock().unwrap();

        // No calls were made even though the frequency is a constant interval, as the throttle has exited
        assert_eq!(count_before_drop, count_after_drop);
    }

    #[tokio::test(start_paused = true)]
    async fn test_on_event() {
        // The update fired after 200ms should be forwarded immediately
        let timer = 200.;
        let handle = Handle::new(1);
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        interval.tick().await; // Completed immediately

        // Start counter
        let counter = CounterClient::new();

        // Spawn throttle
        let receiver = handle.subscribe();
        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::OnEvent,
            receiver,
            None,
        );

        interval.tick().await; // Waits until exactly 200ms
        handle.set(2).await; // Update handle, firing event
        sleep(Duration::from_millis(10)).await; // Give the throttle task time to execute the call

        let time = *counter.elapsed.lock().unwrap() as f64;
        let count = *counter.count.lock().unwrap();
        assert_eq!(count, 1);
        assert!((timer - time).abs() / timer < 0.1);
    }

    #[tokio::test(start_paused = true)]
    async fn test_hot_on_event_when() {
        // A burst of updates within one interval should be coalesced into a single call,
        // made when the interval elapses
        let timer = 200.;
        let handle = Handle::new(1);
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        interval.tick().await; // Completed immediately

        // Start counter
        let counter = CounterClient::new();

        // Spawn throttle
        let receiver = handle.subscribe();
        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::OnEventWhen(Duration::from_millis(timer as u64)),
            receiver,
            None,
        );

        // Many updates are triggered in quick succession
        for i in 0..10 {
            handle.set(i).await;
            sleep(Duration::from_millis((timer / 10.) as u64)).await;
        }

        sleep(Duration::from_millis(5)).await;

        let time = *counter.elapsed.lock().unwrap() as f64;
        let count = *counter.count.lock().unwrap();

        // Despite ten updates the counter was invoked only once, at the first interval tick
        assert_eq!(count, 1);
        assert!((timer - time).abs() / timer < 0.1);
    }

    #[tokio::test(start_paused = true)]
    async fn test_interval() {
        // spawn_interval should call back immediately and then once per elapsed interval

        let timer = 200.;
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        interval.tick().await; // Completed immediately

        // Start counter
        let counter = CounterClient::new();

        // Spawn throttle
        Throttle::spawn_interval(
            counter.clone(),
            CounterClient::call,
            Duration::from_millis(timer as u64),
            1,
        );

        for _ in 0..5 {
            interval.tick().await; // Waits exactly 200ms each time
        }
        sleep(Duration::from_millis(20)).await; // Allow last call to be processed

        // Initial call plus one call per elapsed interval
        let time = *counter.elapsed.lock().unwrap() as f64;
        let count = *counter.count.lock().unwrap();
        assert_eq!(count, 6);
        assert!((timer * 5. - time).abs() / (5. * timer) < 0.1);
    }

    #[tokio::test(start_paused = true)]
    async fn test_on_event_when_interval_passed() {
        // The interval passed to the throttle is shorter than the time to the event, so its value is passed to the client call
        // Throttle interval passes at 0.55 timer, does nothing
        // Event fires at 1. timer
        // Throttle interval passes at 1.1 timer, and processes event
        // Throttle interval passes at 1.65 timer, does nothing

        let timer = 200.;
        let handle = Handle::new(1);
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        interval.tick().await; // Completed immediately

        // Start counter
        let counter = CounterClient::new();

        // Spawn throttle
        let receiver = handle.subscribe();
        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::OnEventWhen(Duration::from_millis((timer * 0.55) as u64)),
            receiver,
            None,
        );

        interval.tick().await; // Waits until exactly 200ms
        handle.set(2).await; // Update handle, firing event
        interval.tick().await;

        // Update should be received directly after the next throttle interval
        let time = *counter.elapsed.lock().unwrap() as f64;
        let count = *counter.count.lock().unwrap();
        assert_eq!(count, 1);
        assert!((timer * 1.1 - time).abs() / (timer * 1.1) < 0.1);
    }

    #[tokio::test(start_paused = true)]
    async fn test_on_event_when_too_soon() {
        // The interval passed to the throttle is longer than the time to the event, so its value is not forwarded yet
        // Event fires at 1. timer
        // The test asserts at 1.25 timer, before the throttle interval passes at 1.5 timer

        let timer = 200.;
        let handle = Handle::new(1);
        let mut interval = time::interval(Duration::from_millis(timer as u64));
        interval.tick().await; // Completed immediately

        // Start counter
        let counter = CounterClient::new();

        // Spawn throttle
        let receiver = handle.subscribe();
        Throttle::spawn_from_receiver(
            counter.clone(),
            CounterClient::call,
            Frequency::OnEventWhen(Duration::from_millis((timer * 1.5) as u64)),
            receiver,
            None,
        );

        interval.tick().await; // Waits until exactly 200ms
        handle.set(2).await; // Update handle, firing event

        // Let the throttle task run, so an early (incorrect) call would be observed,
        // while staying before its first tick at 1.5 timer
        sleep(Duration::from_millis((timer * 0.25) as u64)).await;

        // Update should not be processed
        let time = *counter.elapsed.lock().unwrap();
        let count = *counter.count.lock().unwrap();
        assert_eq!(count, 0);
        assert_eq!(time, 0);
    }

    #[tokio::test(start_paused = true)]
    async fn test_throttle_parsing() {
        // This test only checks that the target type F is inferred from the callback signature

        // Parsing to self should succeed
        Throttle::spawn_interval(
            DummyClient {},
            DummyClient::call_a,
            Duration::from_millis(100),
            A {},
        );

        // Parsing to either B or C should be inferred by the compiler
        Throttle::spawn_interval(
            DummyClient {},
            DummyClient::call_b,
            Duration::from_millis(100),
            A {},
        );

        Throttle::spawn_interval(
            DummyClient {},
            DummyClient::call_c,
            Duration::from_millis(100),
            A {},
        );
    }

    #[derive(Debug, Clone)]
    struct A {}

    #[derive(Debug, Clone)]
    struct B {}

    #[derive(Debug, Clone)]
    struct C {}

    impl Throttled<B> for A {
        fn parse(&self) -> B {
            B {}
        }
    }

    impl Throttled<C> for A {
        fn parse(&self) -> C {
            C {}
        }
    }

    #[derive(Debug, Clone)]
    struct DummyClient {}

    impl DummyClient {
        fn call_a(&self, _event: A) {}
        fn call_b(&self, _event: B) {}
        fn call_c(&self, _event: C) {}
    }

    /// Test client recording how many times it was called and when (ms since creation) the
    /// most recent call happened.
    #[derive(Debug, Clone)]
    struct CounterClient {
        start: Instant,
        elapsed: Arc<Mutex<u128>>,
        count: Arc<Mutex<i32>>,
    }

    impl CounterClient {
        fn new() -> Self {
            CounterClient {
                start: Instant::now(),
                elapsed: Arc::new(Mutex::new(0)),
                count: Arc::new(Mutex::new(0)),
            }
        }

        fn call(&self, _event: i32) {
            let mut time = self.elapsed.lock().unwrap();
            *time = self.start.elapsed().as_millis();

            let mut count = self.count.lock().unwrap();
            *count += 1;
        }
    }
}