use std::time::{Duration, Instant};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::{policy::Policy, error::DoOverError};
/// Fixed-window rate-limiting policy.
///
/// Permits at most `capacity` executions per window of length `interval`.
///
/// Cloning is cheap and clones are *not* independent. Every clone shares the same
/// token state through an [`Arc`], so all clones draw from a single budget.
#[derive(Clone)]
pub struct RateLimiter {
    /// Number of tokens available at the start of each window.
    capacity: u64,
    /// Window length. Once this much time has elapsed since the window started, the
    /// next call refills the tokens to `capacity` and starts a new window.
    interval: Duration,
    /// Shared `(remaining tokens, window start)` pair, guarded by an async mutex.
    state: Arc<Mutex<(u64, Instant)>>,
}
impl RateLimiter {
    /// Creates a limiter allowing `capacity` executions per `interval`.
    ///
    /// The token count starts full, and the first window begins at the moment of
    /// construction. With a `capacity` of zero, every call is rejected.
    pub fn new(capacity: u64, interval: Duration) -> Self {
        let initial = (capacity, Instant::now());
        let state = Arc::new(Mutex::new(initial));
        Self {
            capacity,
            interval,
            state,
        }
    }
}
#[async_trait::async_trait]
impl<E> Policy<DoOverError<E>> for RateLimiter
where
    E: Send + Sync,
{
    /// Runs `f` if a token is available in the current window, otherwise rejects it.
    ///
    /// A token is consumed before `f` is invoked. It is consumed whether or not
    /// `f` later succeeds.
    ///
    /// # Errors
    ///
    /// Returns `DoOverError::BulkheadFull` when the current window has no tokens
    /// left. Otherwise returns whatever `f` produces.
    async fn execute<F, Fut, T>(&self, f: F) -> Result<T, DoOverError<E>>
    where
        F: Fn() -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Result<T, DoOverError<E>>> + Send,
        T: Send,
    {
        // Scope the guard so the lock is released before `f` runs. Otherwise
        // concurrent callers would queue behind the wrapped operation.
        {
            let mut guard = self.state.lock().await;
            let (remaining, window_start) = &mut *guard;

            // Refill only after a whole interval has passed since the window began.
            // The next window starts at the current instant.
            let window_expired = window_start.elapsed() >= self.interval;
            if window_expired {
                *remaining = self.capacity;
                *window_start = Instant::now();
            }

            // NOTE(review): running out of rate-limit tokens is reported with the
            // `BulkheadFull` variant. Confirm whether `DoOverError` offers a more
            // specific variant for rate limiting.
            match remaining.checked_sub(1) {
                Some(left) => *remaining = left,
                None => return Err(DoOverError::BulkheadFull),
            }
        }

        f().await
    }
}