use super::backend::ThrottleBackend;
use super::{Throttle, ThrottleResult};
use super::time_provider::{SystemTimeProvider, TimeProvider};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::time::{Duration, Instant};
/// Configuration for a [`TokenBucket`] throttle.
///
/// The bucket holds at most `capacity` tokens and starts full. Every elapsed
/// `refill_interval` (measured in seconds) adds `refill_rate` tokens, never
/// exceeding `capacity`, and each request spends `tokens_per_request` tokens.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct TokenBucketConfig {
    /// Maximum number of tokens held at once (the burst size).
    pub capacity: usize,
    /// Tokens added for each elapsed refill interval.
    pub refill_rate: usize,
    /// Length of one refill interval, in seconds.
    pub refill_interval: u64,
    /// Tokens spent by a single request.
    pub tokens_per_request: usize,
}

impl TokenBucketConfig {
    /// Builds a configuration from explicit values.
    pub fn new(
        capacity: usize,
        refill_rate: usize,
        refill_interval: u64,
        tokens_per_request: usize,
    ) -> Self {
        Self { capacity, refill_rate, refill_interval, tokens_per_request }
    }

    /// Starts a [`TokenBucketConfigBuilder`] with every field unset.
    pub fn builder() -> TokenBucketConfigBuilder {
        TokenBucketConfigBuilder::default()
    }

    /// `rate` tokens every second, bursts up to `burst`, one token per request.
    pub fn per_second(rate: usize, burst: usize) -> Self {
        Self::single_token_preset(rate, burst, 1)
    }

    /// `rate` tokens every minute, bursts up to `burst`, one token per request.
    pub fn per_minute(rate: usize, burst: usize) -> Self {
        Self::single_token_preset(rate, burst, 60)
    }

    /// `rate` tokens every hour, bursts up to `burst`, one token per request.
    pub fn per_hour(rate: usize, burst: usize) -> Self {
        Self::single_token_preset(rate, burst, 60 * 60)
    }

    /// Shared body of the `per_*` presets: `burst` capacity, `rate` tokens
    /// every `interval_secs` seconds, one token spent per request.
    fn single_token_preset(rate: usize, burst: usize, interval_secs: u64) -> Self {
        Self::new(burst, rate, interval_secs, 1)
    }
}

/// Step-by-step constructor for [`TokenBucketConfig`].
///
/// `capacity`, `refill_rate` and `refill_interval` are required;
/// `tokens_per_request` defaults to 1 when left unset.
#[derive(Debug, Default)]
pub struct TokenBucketConfigBuilder {
    capacity: Option<usize>,
    refill_rate: Option<usize>,
    refill_interval: Option<u64>,
    tokens_per_request: Option<usize>,
}

impl TokenBucketConfigBuilder {
    /// Sets the maximum number of tokens (burst size).
    pub fn capacity(self, capacity: usize) -> Self {
        Self { capacity: Some(capacity), ..self }
    }

    /// Sets how many tokens each refill interval adds.
    pub fn refill_rate(self, rate: usize) -> Self {
        Self { refill_rate: Some(rate), ..self }
    }

    /// Sets the refill interval length, in seconds.
    pub fn refill_interval(self, interval: u64) -> Self {
        Self { refill_interval: Some(interval), ..self }
    }

    /// Sets how many tokens a single request spends.
    pub fn tokens_per_request(self, tokens: usize) -> Self {
        Self { tokens_per_request: Some(tokens), ..self }
    }

    /// Produces the finished configuration.
    ///
    /// # Panics
    ///
    /// Panics if `capacity`, `refill_rate` or `refill_interval` was never set
    /// (checked in that order).
    pub fn build(self) -> TokenBucketConfig {
        let Self { capacity, refill_rate, refill_interval, tokens_per_request } = self;
        TokenBucketConfig::new(
            capacity.expect("capacity must be set"),
            refill_rate.expect("refill_rate must be set"),
            refill_interval.expect("refill_interval must be set"),
            tokens_per_request.unwrap_or(1),
        )
    }
}
/// Mutable contents of a bucket, kept behind the `RwLock` in [`TokenBucket`].
#[derive(Clone, Debug)]
struct BucketState {
    /// Tokens currently available to spend.
    tokens: usize,
    /// Instant from which elapsed refill time is measured.
    last_refill: Instant,
}
/// An in-process token-bucket throttle.
///
/// Tokens are tracked in `state` and refilled according to `config`, using the
/// clock supplied by `time_provider`.
pub struct TokenBucket<B: ThrottleBackend, T: TimeProvider = SystemTimeProvider> {
    // NOTE(review): `key` and `backend` are stored but never read in this file,
    // which is why `dead_code` is allowed on them. Bucket state lives only in
    // `state`, and the `key` argument of `allow_request`/`wait_time` is ignored,
    // so every key shares this one bucket. Confirm whether per-key or
    // backend-persisted buckets are intended.
    #[allow(dead_code)]
    key: String,
    #[allow(dead_code)]
    backend: Arc<B>,
    /// Capacity and refill parameters.
    config: TokenBucketConfig,
    /// Clock used for refill timing; injectable (tests pass a mock clock).
    time_provider: Arc<T>,
    /// Current token count and last-refill instant behind an async `RwLock`.
    state: Arc<RwLock<BucketState>>,
}
impl<B: ThrottleBackend> TokenBucket<B, SystemTimeProvider> {
    /// Creates a token bucket driven by the real system clock.
    ///
    /// The bucket starts full (`config.capacity` tokens). This delegates to
    /// [`TokenBucket::with_time_provider`] so both constructors share one
    /// initialisation path, and the initial refill timestamp is taken from the
    /// same provider instance that the bucket keeps using afterwards.
    pub fn new(key: String, backend: Arc<B>, config: TokenBucketConfig) -> Self {
        Self::with_time_provider(key, backend, config, Arc::new(SystemTimeProvider::new()))
    }
}
impl<B: ThrottleBackend, T: TimeProvider> TokenBucket<B, T> {
    /// Creates a token bucket that reads time from `time_provider`.
    ///
    /// The bucket starts full (`config.capacity` tokens) with its refill clock
    /// anchored at `time_provider.now()`. Injecting the provider lets tests
    /// drive time deterministically.
    pub fn with_time_provider(
        key: String,
        backend: Arc<B>,
        config: TokenBucketConfig,
        time_provider: Arc<T>,
    ) -> Self {
        let initial_state = BucketState {
            tokens: config.capacity,
            last_refill: time_provider.now(),
        };
        Self {
            key,
            backend,
            config,
            time_provider,
            state: Arc::new(RwLock::new(initial_state)),
        }
    }

    /// Credits the tokens earned since `state.last_refill`, capped at `capacity`.
    ///
    /// Tokens are granted in whole intervals: each full `refill_interval`
    /// seconds that has elapsed adds `refill_rate` tokens. `last_refill` is
    /// advanced by exactly the credited intervals, so a partial interval carries
    /// over to the next call instead of being discarded (which would otherwise
    /// make the effective rate lower than configured).
    ///
    /// A `refill_interval` of zero is treated as an unbounded refill rate: the
    /// bucket is topped up to capacity (if `refill_rate > 0`) rather than
    /// dividing by zero.
    fn refill_tokens(&self, state: &mut BucketState) {
        let now = self.time_provider.now();
        let interval_secs = self.config.refill_interval;

        if interval_secs == 0 {
            if self.config.refill_rate > 0 {
                state.tokens = self.config.capacity;
            }
            state.last_refill = now;
            return;
        }

        let elapsed = now.duration_since(state.last_refill);
        let intervals = elapsed.as_secs() / interval_secs;
        if intervals == 0 {
            return;
        }

        // Saturate instead of overflowing after long idle periods or with very
        // large rates; the result is capped at `capacity` anyway.
        let tokens_to_add = usize::try_from(intervals)
            .unwrap_or(usize::MAX)
            .saturating_mul(self.config.refill_rate);
        state.tokens = state
            .tokens
            .saturating_add(tokens_to_add)
            .min(self.config.capacity);

        // `intervals * interval_secs <= elapsed.as_secs()`, so this cannot
        // overflow and never moves `last_refill` past `now`.
        state.last_refill += Duration::from_secs(intervals * interval_secs);
    }

    /// Applies any due refill, then takes `count` tokens if that many are available.
    ///
    /// Returns `Ok(true)` when the tokens were taken and `Ok(false)` (bucket
    /// unchanged) otherwise. The write lock is held across the refill and the
    /// check, so concurrent callers cannot spend the same tokens twice.
    async fn consume_tokens(&self, count: usize) -> ThrottleResult<bool> {
        let mut state = self.state.write().await;
        self.refill_tokens(&mut state);
        if state.tokens >= count {
            state.tokens -= count;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Returns the number of tokens available now, applying any due refill first.
    ///
    /// Takes the write lock because the refill updates the stored state.
    pub async fn tokens(&self) -> usize {
        let mut state = self.state.write().await;
        self.refill_tokens(&mut state);
        state.tokens
    }

    /// Refills the bucket to capacity and restarts the refill clock at the
    /// provider's current time.
    pub async fn reset(&self) {
        let mut state = self.state.write().await;
        state.tokens = self.config.capacity;
        state.last_refill = self.time_provider.now();
    }
}
#[async_trait]
impl<B: ThrottleBackend, T: TimeProvider> Throttle for TokenBucket<B, T> {
    /// Tries to spend `tokens_per_request` tokens from the bucket.
    ///
    /// NOTE(review): `_key` is ignored — every key shares this instance's single
    /// bucket. Confirm that one `TokenBucket` per key is the intended usage.
    async fn allow_request(&self, _key: &str) -> ThrottleResult<bool> {
        self.consume_tokens(self.config.tokens_per_request).await
    }

    /// Estimates how many whole seconds to wait before a request can succeed.
    ///
    /// Returns `None` when enough tokens are available right now. Otherwise it
    /// returns the time until enough refill intervals (counted from the last
    /// refill) will have elapsed to cover the shortfall. The value is rounded
    /// **up** to the next second, so sleeping for it is sufficient. `Some(0)`
    /// means a refill is already due but has not been applied yet (this method
    /// only takes a read lock and does not refill).
    ///
    /// If the request can never succeed (`refill_rate == 0`, or
    /// `tokens_per_request > capacity`), returns `Some(u64::MAX)`.
    async fn wait_time(&self, _key: &str) -> ThrottleResult<Option<u64>> {
        let state = self.state.read().await;
        let needed = self.config.tokens_per_request;
        if state.tokens >= needed {
            return Ok(None);
        }

        let rate = self.config.refill_rate;
        if rate == 0 || needed > self.config.capacity {
            return Ok(Some(u64::MAX));
        }

        // Whole intervals needed to cover the shortfall (ceiling division that
        // cannot overflow). A single interval may not be enough when
        // `tokens_per_request` exceeds `refill_rate`.
        let deficit = needed - state.tokens;
        let intervals = deficit / rate + usize::from(deficit % rate != 0);
        let intervals = u64::try_from(intervals).unwrap_or(u64::MAX);
        let total = Duration::from_secs(intervals.saturating_mul(self.config.refill_interval));

        let elapsed = self.time_provider.now().duration_since(state.last_refill);
        let wait = total.saturating_sub(elapsed);

        // Round partial seconds up: truncating would report 0 while a refill
        // is still pending, prompting an immediate retry that gets denied.
        let secs = wait
            .as_secs()
            .saturating_add(u64::from(wait.subsec_nanos() > 0));
        Ok(Some(secs))
    }

    /// Returns `(refill_rate, refill_interval)`, the interval being in seconds.
    fn get_rate(&self) -> (usize, u64) {
        (self.config.refill_rate, self.config.refill_interval)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::throttling::backend::MemoryBackend;
    use crate::throttling::time_provider::MockTimeProvider;
    use tokio::time::Instant;

    /// Asserts that each of the next `n` requests is allowed.
    async fn expect_allowed(throttle: &impl Throttle, n: usize) {
        for _ in 0..n {
            assert!(throttle.allow_request("user").await.unwrap());
        }
    }

    /// Asserts that the next request is rejected.
    async fn expect_denied(throttle: &impl Throttle) {
        assert!(!throttle.allow_request("user").await.unwrap());
    }

    /// Sends `n` requests without checking whether they were allowed.
    async fn spend(throttle: &impl Throttle, n: usize) {
        for _ in 0..n {
            throttle.allow_request("user").await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_token_bucket_basic() {
        let backend = Arc::new(MemoryBackend::new());
        let throttle =
            TokenBucket::new("test".to_string(), backend, TokenBucketConfig::new(5, 5, 10, 1));
        expect_allowed(&throttle, 5).await;
        expect_denied(&throttle).await;
    }

    #[tokio::test]
    async fn test_token_bucket_refill() {
        let clock = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(clock.clone()));
        let throttle = TokenBucket::with_time_provider(
            "test".to_string(),
            backend,
            TokenBucketConfig::new(10, 5, 1, 1),
            clock.clone(),
        );

        expect_allowed(&throttle, 10).await;
        expect_denied(&throttle).await;

        clock.advance(std::time::Duration::from_secs(1));

        expect_allowed(&throttle, 5).await;
        expect_denied(&throttle).await;
    }

    #[tokio::test]
    async fn test_token_bucket_burst() {
        let backend = Arc::new(MemoryBackend::new());
        let config = TokenBucketConfig::per_second(5, 20);
        let throttle = TokenBucket::new("test".to_string(), backend, config);
        expect_allowed(&throttle, 20).await;
        expect_denied(&throttle).await;
    }

    #[tokio::test]
    async fn test_token_bucket_tokens_per_request() {
        let backend = Arc::new(MemoryBackend::new());
        let config = TokenBucketConfig::new(10, 10, 10, 2);
        let throttle = TokenBucket::new("test".to_string(), backend, config);
        expect_allowed(&throttle, 5).await;
        expect_denied(&throttle).await;
    }

    #[tokio::test]
    async fn test_token_bucket_get_tokens() {
        let backend = Arc::new(MemoryBackend::new());
        let throttle =
            TokenBucket::new("test".to_string(), backend, TokenBucketConfig::new(10, 5, 10, 1));
        assert_eq!(throttle.tokens().await, 10);
        spend(&throttle, 3).await;
        assert_eq!(throttle.tokens().await, 7);
    }

    #[tokio::test]
    async fn test_token_bucket_reset() {
        let backend = Arc::new(MemoryBackend::new());
        let throttle =
            TokenBucket::new("test".to_string(), backend, TokenBucketConfig::new(10, 5, 10, 1));
        spend(&throttle, 10).await;
        assert_eq!(throttle.tokens().await, 0);
        throttle.reset().await;
        assert_eq!(throttle.tokens().await, 10);
    }

    #[tokio::test]
    async fn test_token_bucket_wait_time() {
        let clock = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(clock.clone()));
        let throttle = TokenBucket::with_time_provider(
            "test".to_string(),
            backend,
            TokenBucketConfig::new(5, 5, 10, 1),
            clock.clone(),
        );
        spend(&throttle, 5).await;
        let wait = throttle.wait_time("user").await.unwrap();
        assert!(matches!(wait, Some(secs) if secs > 0));
    }

    #[test]
    fn test_token_bucket_config_builder() {
        let config = TokenBucketConfig::builder()
            .capacity(100)
            .refill_rate(50)
            .refill_interval(60)
            .tokens_per_request(2)
            .build();
        assert_eq!(
            (
                config.capacity,
                config.refill_rate,
                config.refill_interval,
                config.tokens_per_request
            ),
            (100, 50, 60, 2)
        );
    }

    #[test]
    fn test_token_bucket_config_per_second() {
        let config = TokenBucketConfig::per_second(10, 20);
        assert_eq!(
            (config.refill_rate, config.capacity, config.refill_interval),
            (10, 20, 1)
        );
    }

    #[test]
    fn test_token_bucket_config_per_minute() {
        let config = TokenBucketConfig::per_minute(100, 150);
        assert_eq!(
            (config.refill_rate, config.capacity, config.refill_interval),
            (100, 150, 60)
        );
    }

    #[test]
    fn test_token_bucket_config_per_hour() {
        let config = TokenBucketConfig::per_hour(1000, 1500);
        assert_eq!(
            (config.refill_rate, config.capacity, config.refill_interval),
            (1000, 1500, 3600)
        );
    }
}