use super::backend::ThrottleBackend;
use super::{Throttle, ThrottleResult};
use super::time_provider::{SystemTimeProvider, TimeProvider};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::time::Instant;
/// Tuning parameters for a leaky-bucket throttle.
///
/// `capacity` bounds how many request units the bucket may hold, and `leak_rate`
/// is the number of units drained per second.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct LeakyBucketConfig {
    /// Maximum fill level; a request is admitted only while the level is below this.
    pub capacity: usize,
    /// Units drained from the bucket per second.
    pub leak_rate: f64,
}

impl LeakyBucketConfig {
    /// Builds a config from a capacity and a per-second leak rate.
    pub fn new(capacity: usize, leak_rate: f64) -> Self {
        Self { capacity, leak_rate }
    }

    /// Builds a config that drains `rate` units every second.
    ///
    /// Note that the argument order (rate first, then capacity) is the reverse of
    /// [`LeakyBucketConfig::new`].
    pub fn per_second(rate: f64, capacity: usize) -> Self {
        Self::new(capacity, rate)
    }

    /// Builds a config that drains `rate` units every minute, which is stored as
    /// `rate / 60` units per second.
    pub fn per_minute(rate: f64, capacity: usize) -> Self {
        Self::new(capacity, rate / 60.0)
    }
}
/// Mutable state of the bucket, shared behind the throttle's `RwLock`.
#[derive(Clone, Debug)]
struct BucketState {
    /// Current fill level. It is fractional because draining is proportional to
    /// elapsed (sub-second) time.
    level: f64,
    /// Timestamp of the most recent drain.
    last_leak: Instant,
}
/// A leaky-bucket rate limiter.
///
/// Each admitted request adds one unit to the bucket. The bucket drains continuously
/// at `config.leak_rate` units per second, and requests are rejected while it is full.
/// The bucket lives in memory behind an async `RwLock`; this implementation does not
/// partition it by request key.
pub struct LeakyBucketThrottle<B, T = SystemTimeProvider>
where
    B: ThrottleBackend,
    T: TimeProvider,
{
    // Stored but never read: requests are not partitioned by key. Presumably kept for
    // parity with other throttle types — confirm before removing.
    #[allow(dead_code)]
    key: String,
    // Stored but never read: bucket state lives in `state`, not in the backend.
    #[allow(dead_code)]
    backend: Arc<B>,
    /// Capacity and leak rate.
    config: LeakyBucketConfig,
    /// Clock used to compute leaking; injectable so tests can control time.
    time_provider: Arc<T>,
    /// Lock-protected fill level and last-leak timestamp.
    state: Arc<RwLock<BucketState>>,
}
impl<B: ThrottleBackend> LeakyBucketThrottle<B, SystemTimeProvider> {
    /// Creates an empty bucket that reads time from the real system clock.
    ///
    /// This is [`LeakyBucketThrottle::with_time_provider`] with a fresh
    /// [`SystemTimeProvider`]. Use that constructor directly to inject a different clock.
    pub fn new(key: String, backend: Arc<B>, config: LeakyBucketConfig) -> Self {
        // Delegate so there is a single initialization path. The initial `last_leak`
        // timestamp then comes from the same provider instance that is stored, instead
        // of from a throwaway second provider.
        Self::with_time_provider(key, backend, config, Arc::new(SystemTimeProvider::new()))
    }
}
impl<B: ThrottleBackend, T: TimeProvider> LeakyBucketThrottle<B, T> {
    /// Creates an empty bucket whose clock is supplied by `time_provider`.
    ///
    /// The initial leak timestamp is read from `time_provider` here, so a mock provider
    /// makes the throttle's timing fully deterministic.
    pub fn with_time_provider(
        key: String,
        backend: Arc<B>,
        config: LeakyBucketConfig,
        time_provider: Arc<T>,
    ) -> Self {
        let created_at = time_provider.now();
        let empty = BucketState {
            level: 0.0,
            last_leak: created_at,
        };
        Self {
            key,
            backend,
            config,
            state: Arc::new(RwLock::new(empty)),
            time_provider,
        }
    }

    /// Drains the bucket for the time elapsed since the previous drain.
    ///
    /// Subtracts `elapsed_seconds * leak_rate` from the level, clamping at zero, and
    /// moves `last_leak` forward to the current time.
    fn leak_bucket(&self, state: &mut BucketState) {
        let now = self.time_provider.now();
        let elapsed = now.duration_since(state.last_leak).as_secs_f64();
        let drained = elapsed * self.config.leak_rate;
        state.level = f64::max(state.level - drained, 0.0);
        state.last_leak = now;
    }

    /// Returns the fill level after draining the bucket up to the current time.
    ///
    /// Takes the write lock because draining updates the stored state.
    pub async fn level(&self) -> f64 {
        let mut guard = self.state.write().await;
        self.leak_bucket(&mut guard);
        guard.level
    }

    /// Empties the bucket and restarts the leak clock from the current time.
    pub async fn reset(&self) {
        let mut guard = self.state.write().await;
        // Read the clock only after the lock is held, so the timestamp reflects when
        // the reset actually took effect.
        *guard = BucketState {
            level: 0.0,
            last_leak: self.time_provider.now(),
        };
    }
}
#[async_trait]
impl<B: ThrottleBackend, T: TimeProvider> Throttle for LeakyBucketThrottle<B, T> {
    /// Admits the request if the bucket, after draining, has room; if so, adds one unit.
    ///
    /// `_key` is ignored: every caller shares this throttle's single bucket.
    async fn allow_request(&self, _key: &str) -> ThrottleResult<bool> {
        let mut state = self.state.write().await;
        self.leak_bucket(&mut state);
        if state.level < self.config.capacity as f64 {
            state.level += 1.0;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Returns how long to wait before a request could be admitted.
    ///
    /// - `None`: a request would be admitted now.
    /// - `Some(secs)`: whole seconds (rounded up) until one full unit has drained below
    ///   capacity.
    /// - `Some(u64::MAX)`: the bucket can never admit a request (zero capacity, or a
    ///   leak rate that is not positive).
    async fn wait_time(&self, _key: &str) -> ThrottleResult<Option<u64>> {
        // Drain up to "now" before looking at the level, as `allow_request` and `level`
        // do. Reading the stored level without leaking uses the level as of the last
        // leak, so a bucket that has since drained would still report a wait.
        let mut state = self.state.write().await;
        self.leak_bucket(&mut state);

        let capacity = self.config.capacity as f64;
        if state.level < capacity {
            return Ok(None);
        }

        let rate = self.config.leak_rate;
        // With zero capacity nothing is ever admitted. With a non-positive (or NaN) rate
        // the bucket never drains. Either way there is no finite wait.
        if self.config.capacity == 0 || rate.is_nan() || rate <= 0.0 {
            return Ok(Some(u64::MAX));
        }

        // Wait until the level has dropped to `capacity - 1`, i.e. one full slot is free.
        let excess = state.level - (capacity - 1.0);
        let wait_secs = (excess / rate).ceil();
        // A float-to-int `as` cast saturates, so an enormous wait clamps to `u64::MAX`.
        Ok(Some(wait_secs as u64))
    }

    /// Reports the configured leak rate paired with `1`.
    fn get_rate(&self) -> (usize, u64) {
        // NOTE(review): the `as usize` cast truncates toward zero, so fractional rates
        // are under-reported (e.g. `per_minute(30.0, _)` = 0.5/s reports 0). Presumably
        // the tuple means (requests, period_secs); confirm the `Throttle` contract before
        // changing the representation.
        (self.config.leak_rate as usize, 1)
    }
}
#[cfg(test)]
mod tests {
    // `Arc`, `Instant` and the `Throttle` trait arrive from the parent module via the glob.
    use super::*;
    use crate::throttling::backend::MemoryBackend;
    use crate::throttling::time_provider::MockTimeProvider;
    use std::time::Duration;

    #[tokio::test]
    async fn test_leaky_bucket_basic() {
        let time_provider = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(time_provider.clone()));
        let config = LeakyBucketConfig::new(5, 1.0);
        let throttle = LeakyBucketThrottle::with_time_provider(
            "test".to_string(),
            backend,
            config,
            time_provider,
        );
        for _ in 0..5 {
            assert!(throttle.allow_request("user").await.unwrap());
        }
        assert!(!throttle.allow_request("user").await.unwrap());
    }

    #[tokio::test]
    async fn test_leaky_bucket_leak() {
        let time_provider = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(time_provider.clone()));
        let config = LeakyBucketConfig::new(10, 2.0);
        let throttle = LeakyBucketThrottle::with_time_provider(
            "test".to_string(),
            backend,
            config,
            time_provider.clone(),
        );
        for _ in 0..10 {
            assert!(throttle.allow_request("user").await.unwrap());
        }
        assert!(!throttle.allow_request("user").await.unwrap());
        // One second at 2 units/s frees exactly two slots.
        time_provider.advance(Duration::from_secs(1));
        assert!(throttle.allow_request("user").await.unwrap());
        assert!(throttle.allow_request("user").await.unwrap());
        assert!(!throttle.allow_request("user").await.unwrap());
    }

    #[tokio::test]
    async fn test_leaky_bucket_smoothing() {
        let time_provider = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(time_provider.clone()));
        let config = LeakyBucketConfig::per_second(5.0, 10);
        let throttle = LeakyBucketThrottle::with_time_provider(
            "test".to_string(),
            backend,
            config,
            time_provider.clone(),
        );
        for _ in 0..10 {
            assert!(throttle.allow_request("user").await.unwrap());
        }
        assert!(!throttle.allow_request("user").await.unwrap());
        // Two seconds at 5 units/s fully drains a bucket of 10.
        time_provider.advance(Duration::from_secs(2));
        for _ in 0..10 {
            assert!(throttle.allow_request("user").await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_leaky_bucket_level() {
        let time_provider = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(time_provider.clone()));
        let config = LeakyBucketConfig::new(10, 2.0);
        let throttle = LeakyBucketThrottle::with_time_provider(
            "test".to_string(),
            backend,
            config,
            time_provider.clone(),
        );
        assert_eq!(throttle.level().await, 0.0);
        for _ in 0..5 {
            throttle.allow_request("user").await.unwrap();
        }
        assert_eq!(throttle.level().await, 5.0);
        time_provider.advance(Duration::from_secs(1));
        assert_eq!(throttle.level().await, 3.0);
    }

    #[tokio::test]
    async fn test_leaky_bucket_reset() {
        let time_provider = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(time_provider.clone()));
        let config = LeakyBucketConfig::new(10, 1.0);
        let throttle = LeakyBucketThrottle::with_time_provider(
            "test".to_string(),
            backend,
            config,
            time_provider,
        );
        for _ in 0..10 {
            throttle.allow_request("user").await.unwrap();
        }
        assert!(throttle.level().await > 0.0);
        throttle.reset().await;
        assert_eq!(throttle.level().await, 0.0);
    }

    #[tokio::test]
    async fn test_leaky_bucket_reset_readmits() {
        let time_provider = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(time_provider.clone()));
        let config = LeakyBucketConfig::new(2, 1.0);
        let throttle = LeakyBucketThrottle::with_time_provider(
            "test".to_string(),
            backend,
            config,
            time_provider,
        );
        for _ in 0..2 {
            assert!(throttle.allow_request("user").await.unwrap());
        }
        assert!(!throttle.allow_request("user").await.unwrap());
        throttle.reset().await;
        assert!(throttle.allow_request("user").await.unwrap());
    }

    #[tokio::test]
    async fn test_leaky_bucket_wait_time() {
        let time_provider = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(time_provider.clone()));
        let config = LeakyBucketConfig::new(5, 1.0);
        let throttle = LeakyBucketThrottle::with_time_provider(
            "test".to_string(),
            backend,
            config,
            time_provider,
        );
        for _ in 0..5 {
            throttle.allow_request("user").await.unwrap();
        }
        // Full bucket at 1 unit/s: one slot frees after exactly one second.
        let wait = throttle.wait_time("user").await.unwrap();
        assert_eq!(wait, Some(1));
    }

    #[tokio::test]
    async fn test_leaky_bucket_wait_time_none_when_not_full() {
        let time_provider = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(time_provider.clone()));
        let config = LeakyBucketConfig::new(5, 1.0);
        let throttle = LeakyBucketThrottle::with_time_provider(
            "test".to_string(),
            backend,
            config,
            time_provider,
        );
        for _ in 0..3 {
            assert!(throttle.allow_request("user").await.unwrap());
        }
        assert_eq!(throttle.wait_time("user").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_leaky_bucket_get_rate() {
        let time_provider = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(time_provider.clone()));
        let config = LeakyBucketConfig::new(10, 2.0);
        let throttle = LeakyBucketThrottle::with_time_provider(
            "test".to_string(),
            backend,
            config,
            time_provider,
        );
        assert_eq!(throttle.get_rate(), (2, 1));
    }

    #[test]
    fn test_leaky_bucket_config_per_second() {
        let config = LeakyBucketConfig::per_second(10.0, 20);
        assert_eq!(config.leak_rate, 10.0);
        assert_eq!(config.capacity, 20);
    }

    #[test]
    fn test_leaky_bucket_config_per_minute() {
        let config = LeakyBucketConfig::per_minute(60.0, 100);
        assert_eq!(config.leak_rate, 1.0);
        assert_eq!(config.capacity, 100);
    }
}