use std::fmt::{self, Debug};
use thiserror::Error;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use tokio::time::{self, Duration, Interval};
use crate::Handle;
#[derive(Error, Debug, PartialEq, Clone)]
pub enum ThrottleError {
#[error("A throttle should be initialized, listen to an update or both")]
UselessThrottle,
#[error("A throttle cannot fire on events if it does not listen to them")]
InvalidFrequency,
}
/// When a throttle invokes its callback.
///
/// Independently of the variant, a throttle invokes the callback once when it
/// starts, provided a value is cached at that moment.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare by it.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum Frequency {
    /// Call after every value received from the attached channel.
    OnEvent,
    /// Call on every tick of a timer with this period, using the latest cached
    /// value whether or not a new one arrived.
    Interval(Duration),
    /// Call on a tick of a timer with this period, but only if a new value has
    /// been received since the previous call.
    OnEventWhen(Duration),
}
/// Conversion from the value held by a throttle into the argument type `F`
/// expected by its callback.
///
/// The blanket implementation below makes every `Clone` type convertible into
/// itself, so callbacks can take the broadcast type directly.
pub trait Throttled<F> {
    /// Produces the callback argument from the cached value.
    fn parse(&self) -> F;
}

/// Identity conversion: a cloneable value is handed over as a copy of itself.
impl<T> Throttled<T> for T
where
    T: Clone,
{
    fn parse(&self) -> T {
        T::clone(self)
    }
}
/// A validated throttle: holds the client, the callback and the latest value,
/// and decides in `tick` when to forward that value to the callback.
struct Throttle<C, T, F> {
    /// Schedule that decides when `call` is invoked.
    frequency: Frequency,
    /// Object handed to `call` by reference on every invocation.
    client: C,
    /// Callback invoked with the client and the cached value converted to `F`.
    call: fn(&C, F),
    /// Source of new values; `None` when the throttle only replays `cache`.
    val_rx: Option<broadcast::Receiver<T>>,
    /// Most recent value (initial or received); no call is made while `None`.
    cache: Option<T>,
}
impl<C, T, F> Throttle<C, T, F>
where
    T: Clone + Throttled<F>,
    F: Clone,
{
    /// Drives the throttle.
    ///
    /// The callback is invoked once immediately with the cached value (if any),
    /// then according to `self.frequency`:
    /// - `OnEvent`: after every value received from `val_rx`;
    /// - `Interval(d)`: on every tick of a `d`-period timer, whether or not a
    ///   new value arrived;
    /// - `OnEventWhen(d)`: on a tick of a `d`-period timer, but only if at least
    ///   one new value was received since the previous call (the start-up call
    ///   counts as a call).
    ///
    /// Lagged receives are logged and skipped. The loop returns only when the
    /// receiver reports `RecvError::Closed`; without a receiver it never returns.
    async fn tick(&mut self) {
        let mut interval = match self.frequency {
            Frequency::OnEvent => None,
            Frequency::Interval(duration) | Frequency::OnEventWhen(duration) => {
                Some(time::interval(duration))
            }
        };
        // A tokio interval's first tick completes immediately; consume it so the
        // loop below waits a full period before its first timed call.
        if let Some(iv) = interval.as_mut() {
            iv.tick().await;
        }
        self.execute_call();
        // The start-up call above already consumed whatever was cached.
        let mut event_processed = true;
        loop {
            let received_msg = tokio::select! {
                _ = Self::keep_time(&mut interval) => false,
                res = Self::check_value(&mut self.val_rx) => {
                    match res {
                        Ok(val) => {
                            event_processed = false;
                            self.cache = Some(val);
                            true
                        }
                        Err(RecvError::Closed) => {
                            log::warn!("Attached actor of type {} closed - exiting throttle", std::any::type_name::<T>());
                            break;
                        }
                        Err(RecvError::Lagged(nr)) => {
                            // The receiver skips the overwritten values; just retry.
                            log::warn!("Throttle of type {} lagged {nr} messages", std::any::type_name::<T>());
                            continue;
                        }
                    }
                },
            };
            match self.frequency {
                Frequency::OnEvent if received_msg => self.execute_call(),
                Frequency::Interval(_) if !received_msg => self.execute_call(),
                Frequency::OnEventWhen(_) if !received_msg && !event_processed => {
                    event_processed = true;
                    self.execute_call();
                }
                // Any other combination (e.g. a timer tick in `OnEvent` mode, or a
                // tick with nothing new in `OnEventWhen` mode) is a no-op.
                _ => {}
            }
        }
    }

    /// Invokes the callback with the cached value converted to `F`; does
    /// nothing while no value is cached.
    fn execute_call(&self) {
        if let Some(cached) = &self.cache {
            // `parse` already yields an owned `F`; no extra clone is needed.
            (self.call)(&self.client, <T as Throttled<F>>::parse(cached));
        }
    }

    /// Waits for the next timer tick, or forever when there is no timer.
    ///
    /// Never completing lets this branch be raced in `select!` without ever
    /// winning when the frequency has no interval.
    async fn keep_time(interval: &mut Option<Interval>) {
        match interval {
            Some(interval) => {
                interval.tick().await;
            }
            None => std::future::pending().await,
        }
    }

    /// Waits for the next value from the receiver, or forever when no receiver
    /// is attached (so the timer branch of `select!` is the only one that fires).
    async fn check_value(val_rx: &mut Option<broadcast::Receiver<T>>) -> Result<T, RecvError> {
        match val_rx {
            Some(rx) => rx.recv().await,
            None => std::future::pending().await,
        }
    }
}
/// Configures and validates a throttle before spawning it.
///
/// A throttle forwards the most recent value of type `T`, converted to `F` via
/// [`Throttled`], to `call(&client, value)` according to a [`Frequency`].
pub struct ThrottleBuilder<C, T, F> {
    /// Schedule that decides when `call` is invoked.
    frequency: Frequency,
    /// Object handed to `call` by reference on every invocation.
    client: C,
    /// Callback invoked with the client and the converted value.
    call: fn(&C, F),
    /// Value source set by `attach` / `attach_rx`; `None` until one is attached.
    val_rx: Option<broadcast::Receiver<T>>,
    /// Initial value set by `init`; `None` until one is provided.
    cache: Option<T>,
}
impl<C, T, F> fmt::Debug for ThrottleBuilder<C, T, F> {
    /// Shows the frequency and receiver directly and only the *type names* of
    /// the client, callback and cache, since those types need not be `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let client_type = std::any::type_name::<C>();
        let call_type = std::any::type_name::<fn(&C, F)>();
        let cache_type = std::any::type_name::<Option<T>>();

        let mut out = f.debug_struct("ThrottleBuilder");
        out.field("frequency", &self.frequency);
        out.field("client", &client_type);
        out.field("call", &call_type);
        out.field("val_rx", &self.val_rx);
        out.field("cache", &cache_type);
        out.finish()
    }
}
impl<C, T, F> ThrottleBuilder<C, T, F>
where
    C: Send + Sync + 'static,
    T: Clone + Debug + Throttled<F> + Send + Sync + 'static,
    F: Clone + Send + Sync + 'static,
{
    /// Starts configuring a throttle that calls `call(&client, value)` at the
    /// given frequency. No initial value and no value source are set yet.
    #[must_use]
    pub fn new(client: C, call: fn(&C, F), freq: Frequency) -> ThrottleBuilder<C, T, F> {
        ThrottleBuilder {
            frequency: freq,
            client,
            call,
            val_rx: None,
            cache: None,
        }
    }

    /// Sets the cached value; the throttle calls `call` with it as soon as it
    /// starts running.
    #[must_use]
    pub fn init(mut self, val: T) -> Self {
        self.cache = Some(val);
        self
    }

    /// Listens for values through the receiver returned by
    /// `handle.subscribe()`, replacing any previously attached receiver.
    ///
    /// NOTE(review): presumably only values published after subscribing are
    /// received — confirm against `Handle::subscribe`.
    #[must_use]
    pub fn attach(mut self, handle: Handle<T>) -> Self {
        self.val_rx = Some(handle.subscribe());
        self
    }

    /// Listens for values on an existing broadcast receiver, replacing any
    /// previously attached receiver.
    #[must_use]
    pub fn attach_rx(mut self, rx: broadcast::Receiver<T>) -> Self {
        self.val_rx = Some(rx);
        self
    }

    /// Validates the configuration and runs the throttle on a detached Tokio
    /// task.
    ///
    /// # Errors
    /// Same as `build`: [`ThrottleError::UselessThrottle`] or
    /// [`ThrottleError::InvalidFrequency`].
    ///
    /// # Panics
    /// `tokio::spawn` panics when called outside a Tokio runtime.
    pub fn spawn(self) -> Result<(), ThrottleError> {
        let mut throttle = self.build()?;
        tokio::spawn(async move { throttle.tick().await });
        Ok(())
    }

    /// Validates the configuration and turns it into a [`Throttle`].
    ///
    /// # Errors
    /// - [`ThrottleError::UselessThrottle`] when neither an initial value nor a
    ///   receiver is configured.
    /// - [`ThrottleError::InvalidFrequency`] when an event-driven frequency
    ///   (`OnEvent` or `OnEventWhen`) is used without a receiver.
    fn build(self) -> Result<Throttle<C, T, F>, ThrottleError> {
        let listens = self.val_rx.is_some();
        if self.cache.is_none() && !listens {
            return Err(ThrottleError::UselessThrottle);
        }
        // `OnEventWhen` also only fires after a received value. Without a
        // receiver it would make its start-up call, then wake on every tick
        // forever without firing or exiting — exactly like `OnEvent`.
        if !listens && matches!(self.frequency, Frequency::OnEvent | Frequency::OnEventWhen(_)) {
            return Err(ThrottleError::InvalidFrequency);
        }
        Ok(Throttle {
            frequency: self.frequency,
            client: self.client,
            call: self.call,
            val_rx: self.val_rx,
            cache: self.cache,
        })
    }
}
// Timing-based tests. Each one records, via `CounterClient`, how many times
// and when (ms since the client was created) the throttle invoked its callback.
// Most timing assertions allow a 10% relative error.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Mutex};
use tokio::time::{sleep, Duration, Instant};
/// After the attached handle is shut down and dropped, the throttle must stop
/// calling the client: the count sampled right after shutdown must not change
/// during the following 200 ms, despite the 10 ms interval.
/// NOTE(review): relies on `Handle::shutdown` / drop closing the broadcast
/// channel so that `tick` sees `RecvError::Closed` — confirm against `Handle`.
#[tokio::test]
async fn test_exit_on_shutdown() {
let handle = Handle::new(1);
let counter = CounterClient::new();
ThrottleBuilder::new(
counter.clone(),
CounterClient::call,
Frequency::Interval(Duration::from_millis(10)),
)
.attach(handle.clone())
.spawn()
.unwrap();
handle.set(2).await.unwrap(); sleep(Duration::from_millis(200)).await;
handle.shutdown().await.unwrap();
let count_before_drop = *counter.count.lock().unwrap();
drop(handle);
sleep(Duration::from_millis(200)).await;
let count_after_drop = *counter.count.lock().unwrap();
assert_eq!(count_before_drop, count_after_drop);
}
/// `OnEvent` without `init`: the start-up call finds no cached value and does
/// nothing, so only the value set at ~200 ms triggers a call. Expects exactly
/// one call at ~200 ms.
#[tokio::test]
async fn test_on_event() {
let timer = 200.;
let handle = Handle::new(1);
let mut interval = time::interval(Duration::from_millis(timer as u64));
// First tick completes immediately; the next one is ~200 ms later.
interval.tick().await; let counter = CounterClient::new();
ThrottleBuilder::new(counter.clone(), CounterClient::call, Frequency::OnEvent)
.attach(handle.clone())
.spawn()
.unwrap();
interval.tick().await; handle.set(2).await.unwrap(); sleep(Duration::from_millis(10)).await; let time = *counter.elapsed.lock().unwrap() as f64;
let count = *counter.count.lock().unwrap();
assert!((timer - time).abs() / timer < 0.1 && count == 1);
}
/// `OnEventWhen(200 ms)` under a burst of updates: ten values set 20 ms apart
/// must collapse into a single call on the first timer tick at ~200 ms.
#[tokio::test]
async fn test_hot_on_event_when() {
let timer = 200.;
let handle = Handle::new(1);
let mut interval = time::interval(Duration::from_millis(timer as u64));
interval.tick().await; let counter = CounterClient::new();
ThrottleBuilder::new(
counter.clone(),
CounterClient::call,
Frequency::OnEventWhen(Duration::from_millis(timer as u64)),
)
.attach(handle.clone())
.spawn()
.unwrap();
for i in 0..10 {
handle.set(i).await.unwrap();
sleep(Duration::from_millis((timer / 10.) as u64)).await;
}
sleep(Duration::from_millis(5)).await;
let time = *counter.elapsed.lock().unwrap() as f64;
let count = *counter.count.lock().unwrap();
assert!((timer - time).abs() / timer < 0.1 && count == 1);
}
/// `Interval(200 ms)` with an initial value and no receiver: one start-up call
/// plus one call per tick. After five ticks (~1000 ms) the count is 6 and the
/// most recent call happened at ~1000 ms.
#[tokio::test]
async fn test_interval() {
let timer = 200.;
let mut interval = time::interval(Duration::from_millis(timer as u64));
interval.tick().await; let counter = CounterClient::new();
ThrottleBuilder::new(
counter.clone(),
CounterClient::call,
Frequency::Interval(Duration::from_millis(timer as u64)),
)
.init(1)
.spawn()
.unwrap();
for _ in 0..5 {
interval.tick().await; }
// Small margin so the throttle's own tick at ~1000 ms has run.
sleep(Duration::from_millis(20)).await; let time = *counter.elapsed.lock().unwrap() as f64;
let count = *counter.count.lock().unwrap();
assert!((timer * 5. - time).abs() / (5. * timer) < 0.1 && count == 6);
}
/// `OnEventWhen(110 ms)`: the throttle's timer ticks at ~110, ~220, ... ms. A
/// value set at ~200 ms is delivered on the next tick, ~220 ms (= 1.1 x timer);
/// checked at ~400 ms with exactly one call expected.
#[tokio::test]
async fn test_on_event_when_interval_passed() {
let timer = 200.;
let handle = Handle::new(1);
let mut interval = time::interval(Duration::from_millis(timer as u64));
interval.tick().await; let counter = CounterClient::new();
ThrottleBuilder::new(
counter.clone(),
CounterClient::call,
Frequency::OnEventWhen(Duration::from_millis((timer * 0.55) as u64)),
)
.attach(handle.clone())
.spawn()
.unwrap();
interval.tick().await; handle.set(2).await.unwrap(); interval.tick().await;
let time = *counter.elapsed.lock().unwrap() as f64;
let count = *counter.count.lock().unwrap();
assert!((timer * 1.1 - time).abs() / (timer * 1.1) < 0.1 && count == 1);
}
/// `OnEventWhen(300 ms)`: a value set at ~200 ms must not have been delivered
/// yet, because the throttle's first timed tick is at ~300 ms.
/// NOTE(review): the counter is sampled immediately after `set`, without any
/// wait, so this would also pass if delivery were merely delayed; consider
/// sleeping briefly (well under 100 ms) before asserting.
#[tokio::test]
async fn test_on_event_when_too_soon() {
let timer = 200.;
let handle = Handle::new(1);
let mut interval = time::interval(Duration::from_millis(timer as u64));
interval.tick().await; let counter = CounterClient::new();
ThrottleBuilder::new(
counter.clone(),
CounterClient::call,
Frequency::OnEventWhen(Duration::from_millis((timer * 1.5) as u64)),
)
.attach(handle.clone())
.spawn()
.unwrap();
interval.tick().await; handle.set(2).await.unwrap(); let time = *counter.elapsed.lock().unwrap();
let count = *counter.count.lock().unwrap();
assert!(count == 0);
assert_eq!(time, 0);
}
/// Checks that a single source type `A` can feed callbacks taking `A` (via the
/// blanket identity impl), `B` and `C` (via the `Throttled` impls below).
/// Only `build` is exercised; nothing is spawned.
#[tokio::test]
async fn test_throttle_parsing() {
ThrottleBuilder::new(
DummyClient {},
DummyClient::call_a,
Frequency::Interval(Duration::from_millis(100)),
)
.init(A {})
.build()
.unwrap();
ThrottleBuilder::new(
DummyClient {},
DummyClient::call_b,
Frequency::Interval(Duration::from_millis(100)),
)
.init(A {})
.build()
.unwrap();
ThrottleBuilder::new(
DummyClient {},
DummyClient::call_c,
Frequency::Interval(Duration::from_millis(100)),
)
.init(A {})
.build()
.unwrap();
}
/// Source value type used by `test_throttle_parsing`.
#[derive(Debug, Clone)]
struct A {}
/// Conversion target of `A`.
#[derive(Debug, Clone)]
struct B {}
/// Second conversion target of `A`.
#[derive(Debug, Clone)]
struct C {}
impl Throttled<B> for A {
fn parse(&self) -> B {
B {}
}
}
impl Throttled<C> for A {
fn parse(&self) -> C {
C {}
}
}
/// Client whose callbacks ignore their argument; used only for type checking.
#[derive(Debug, Clone)]
struct DummyClient {}
impl DummyClient {
fn call_a(&self, _event: A) {}
fn call_b(&self, _event: B) {}
fn call_c(&self, _event: C) {}
}
/// Test client that counts callback invocations and records when the most
/// recent one happened. Clones share the same counters via `Arc`.
#[derive(Debug, Clone)]
struct CounterClient {
// Reference point for `elapsed`: the moment `new` was called.
start: Instant,
// Milliseconds from `start` to the most recent `call`; 0 if never called.
elapsed: Arc<Mutex<u128>>,
// Number of `call` invocations so far.
count: Arc<Mutex<i32>>,
}
impl CounterClient {
fn new() -> Self {
CounterClient {
start: Instant::now(),
elapsed: Arc::new(Mutex::new(0)),
count: Arc::new(Mutex::new(0)),
}
}
// Throttle callback: records the call time and increments the counter.
fn call(&self, _event: i32) {
let mut time = self.elapsed.lock().unwrap();
*time = self.start.elapsed().as_millis();
let mut count = self.count.lock().unwrap();
*count += 1;
}
}
}