use super::backend::ThrottleBackend;
use super::{Throttle, ThrottleError, ThrottleResult};
use std::sync::Arc;
use tokio::sync::Mutex;
/// Rate limiter that enforces two independent windows for every key: a
/// burst window and a sustained window, both counted in a shared
/// [`ThrottleBackend`].
///
/// A request is admitted only while neither window has reached its limit.
pub struct BurstRateThrottle<B: ThrottleBackend> {
    /// Counter store, shared via `Arc` and guarded by an async (tokio) mutex.
    backend: Arc<Mutex<B>>,

    /// Request cap for the burst window.
    burst_rate: usize,
    /// Window length handed to the backend when incrementing the burst counter.
    burst_duration: std::time::Duration,

    /// Request cap for the sustained window.
    sustained_rate: usize,
    /// Window length handed to the backend when incrementing the sustained counter.
    sustained_duration: std::time::Duration,
}
impl<B: ThrottleBackend> BurstRateThrottle<B> {
    /// Builds a throttle from a shared backend and the limits of both windows.
    ///
    /// * `backend` - counter store; the `Arc` may be shared with other owners.
    /// * `burst_rate`, `burst_duration` - request cap and window length of the
    ///   burst window.
    /// * `sustained_rate`, `sustained_duration` - request cap and window length
    ///   of the sustained window.
    ///
    /// The values are stored exactly as given; no validation is performed.
    pub fn new(
        backend: Arc<Mutex<B>>,
        burst_rate: usize,
        sustained_rate: usize,
        burst_duration: std::time::Duration,
        sustained_duration: std::time::Duration,
    ) -> Self {
        Self {
            burst_rate,
            burst_duration,
            sustained_rate,
            sustained_duration,
            backend,
        }
    }
}
#[async_trait::async_trait]
impl<B: ThrottleBackend> Throttle for BurstRateThrottle<B> {
    /// Decides whether a request for `key` may proceed. When it may, one slot
    /// is consumed in both the burst and the sustained window.
    ///
    /// Counters live under `burst:{key}` and `sustained:{key}`. The backend
    /// guard is held for the whole check-then-increment sequence. Two callers
    /// going through this mutex therefore cannot both pass the check before
    /// either one increments.
    ///
    /// # Errors
    /// Propagates any backend error from reading or incrementing a counter.
    async fn allow_request(&self, key: &str) -> ThrottleResult<bool> {
        let backend = self.backend.lock().await;
        let burst_key = format!("burst:{}", key);
        let sustained_key = format!("sustained:{}", key);

        // NOTE(review): `get_count` errors are wrapped in
        // `ThrottleError::ThrottleError`, while increment / wait-time errors
        // go straight through `?`. Presumably the backend uses different
        // error types for these calls; confirm against `ThrottleBackend`.
        let burst_count = backend
            .get_count(&burst_key)
            .await
            .map_err(ThrottleError::ThrottleError)?;
        if burst_count >= self.burst_rate {
            return Ok(false);
        }

        let sustained_count = backend
            .get_count(&sustained_key)
            .await
            .map_err(ThrottleError::ThrottleError)?;
        if sustained_count >= self.sustained_rate {
            return Ok(false);
        }

        // Both windows have room, so record this request in each of them.
        backend
            .increment_duration(&burst_key, self.burst_duration)
            .await?;
        backend
            .increment_duration(&sustained_key, self.sustained_duration)
            .await?;
        Ok(true)
    }

    /// Returns how many whole seconds to wait before a request for `key`
    /// could be admitted.
    ///
    /// The result is `None` when no window is currently exhausted, or when the
    /// backend reports no wait for the exhausted ones.
    ///
    /// Only windows whose counter has reached its limit are consulted. These
    /// are the same `burst:{key}` / `sustained:{key}` counters that
    /// `allow_request` maintains. The longest of their reported waits is
    /// returned, because a request needs room in every window. Sub-second
    /// remainders are truncated by `Duration::as_secs`.
    ///
    /// # Errors
    /// Propagates backend errors from reading counters or wait times.
    async fn wait_time(&self, key: &str) -> ThrottleResult<Option<u64>> {
        let backend = self.backend.lock().await;
        let windows = [
            (format!("burst:{}", key), self.burst_rate),
            (format!("sustained:{}", key), self.sustained_rate),
        ];

        let mut longest: Option<std::time::Duration> = None;
        for (window_key, limit) in &windows {
            let count = backend
                .get_count(window_key)
                .await
                .map_err(ThrottleError::ThrottleError)?;
            if count < *limit {
                // This window still has room, so it imposes no wait.
                continue;
            }
            // `None < Some(_)` under `Ord`, so `max` keeps the longest
            // reported wait and ignores windows without one.
            longest = longest.max(backend.get_wait_time(window_key).await?);
        }
        Ok(longest.map(|d| d.as_secs()))
    }

    /// Reports the sustained limit as `(requests, window length in whole
    /// seconds)`. The burst limit is not included.
    fn get_rate(&self) -> (usize, u64) {
        (self.sustained_rate, self.sustained_duration.as_secs())
    }
}