use std::{
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use futures::future::join_all;
use tokio::{sync::RwLock, task::JoinHandle};
use crate::utils::RateLimiter;
/// Tuning parameters for a `TokenBucket`.
///
/// A bucket starts with `initial_tokens`; every `replenish_interval` its background
/// replenisher adds `tokens_per_interval`, clamping the result of each refill to
/// `max_tokens`.
///
/// The type is a small `Copy` value and derives `PartialEq`/`Eq` so configurations can be
/// compared directly instead of field by field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenBucketConfig {
    /// Token count a bucket is created with.
    pub initial_tokens: u64,
    /// Tokens added on each replenish tick.
    pub tokens_per_interval: u64,
    /// How long the replenisher sleeps between ticks.
    pub replenish_interval: Duration,
    /// Cap applied to the token count on every refill.
    pub max_tokens: u64,
}
impl Default for TokenBucketConfig {
    /// A bucket that starts full at 100 tokens and refills 10 tokens once per second,
    /// never holding more than 100.
    fn default() -> Self {
        // The bucket starts at its cap, so both fields share one value.
        const CAPACITY: u64 = 100;
        Self {
            max_tokens: CAPACITY,
            initial_tokens: CAPACITY,
            replenish_interval: Duration::from_secs(1),
            tokens_per_interval: 10,
        }
    }
}
/// A token-bucket rate limiter whose refill runs as a background Tokio task.
///
/// Consumers draw tokens through the `RateLimiter` impl; a task spawned by
/// `TokenBucket::start` periodically adds tokens according to the shared configuration.
pub struct TokenBucket {
    // Current token count, shared with the replenisher task.
    tokens: Arc<AtomicU64>,
    // Live configuration; the replenisher reads it, the setters write it.
    config: Arc<RwLock<TokenBucketConfig>>,
    // Handle of the replenisher owned by this bucket (set by `initialize`); `stop` and
    // `Drop` take it to shut the task down.
    task_handle: Option<JoinHandle<()>>,
    // Set to `true` to ask the replenisher loop to exit.
    shutdown_flag: Arc<AtomicBool>,
}
impl std::fmt::Debug for TokenBucket {
    /// Formats the current token count, the configuration and the shutdown flag.
    ///
    /// The configuration is read with the non-blocking `try_read`. Tokio's
    /// `blocking_read` panics when called from inside an async execution context, which
    /// is exactly where `{:?}` on a bucket is most likely to be evaluated. If a writer
    /// currently holds the lock, `<locked>` is printed instead of waiting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TokenBucket");
        out.field("tokens", &self.tokens.load(Ordering::Acquire));
        match self.config.try_read() {
            Ok(config) => out.field("config", &*config),
            Err(_) => out.field("config", &format_args!("<locked>")),
        };
        out.field("shutdown", &self.shutdown_flag.load(Ordering::Acquire))
            .finish()
    }
}
impl TokenBucket {
pub fn new(config: TokenBucketConfig) -> Self {
let tokens = Arc::new(AtomicU64::new(config.initial_tokens));
Self {
tokens,
config: Arc::new(RwLock::new(config)),
task_handle: None,
shutdown_flag: Arc::new(AtomicBool::new(false)),
}
}
pub fn initialize(config: TokenBucketConfig) -> Self {
let mut limiter = Self::new(config);
let handle = limiter.start();
limiter.task_handle = Some(handle);
limiter
}
pub async fn get_config(&self) -> TokenBucketConfig {
let config_guard = self.config.read().await;
*config_guard
}
pub async fn update_config(&self, new_config: TokenBucketConfig) {
let mut config_guard = self.config.write().await;
*config_guard = new_config;
}
pub async fn set_tokens_per_interval(&self, tokens_per_interval: u64) {
let mut config_guard = self.config.write().await;
config_guard.tokens_per_interval = tokens_per_interval;
}
pub async fn set_replenish_interval(&self, replenish_interval: Duration) {
let mut config_guard = self.config.write().await;
config_guard.replenish_interval = replenish_interval;
}
pub async fn set_max_tokens(&self, max_tokens: u64) {
let mut config_guard = self.config.write().await;
config_guard.max_tokens = max_tokens;
}
pub async fn get_config_all<'a, I: IntoIterator<Item = &'a Self>>(
rate_limiters: I,
) -> Vec<TokenBucketConfig> {
join_all(
rate_limiters
.into_iter()
.map(|limiter| async move { limiter.get_config().await }),
)
.await
}
pub async fn update_config_all<'a, I: IntoIterator<Item = &'a Self>>(
rate_limiters: I,
new_config: TokenBucketConfig,
) {
join_all(
rate_limiters
.into_iter()
.map(|limiter| limiter.update_config(new_config)),
)
.await;
}
pub async fn set_tokens_per_interval_all<'a, I: IntoIterator<Item = &'a Self>>(
rate_limiters: I,
tokens_per_interval: u64,
) {
join_all(
rate_limiters
.into_iter()
.map(|limiter| limiter.set_tokens_per_interval(tokens_per_interval)),
)
.await;
}
pub async fn set_replenish_interval_all<'a, I: IntoIterator<Item = &'a Self>>(
rate_limiters: I,
replenish_interval: Duration,
) {
join_all(
rate_limiters
.into_iter()
.map(|limiter| limiter.set_replenish_interval(replenish_interval)),
)
.await;
}
pub async fn set_max_tokens_all<'a, I: IntoIterator<Item = &'a Self>>(
rate_limiters: I,
max_tokens: u64,
) {
join_all(
rate_limiters
.into_iter()
.map(|limiter| limiter.set_max_tokens(max_tokens)),
)
.await;
}
}
impl TokenBucket {
pub fn start(&mut self) -> JoinHandle<()> {
let tokens = self.tokens.clone();
let config = self.config.clone();
let shutdown_flag = self.shutdown_flag.clone();
tokio::spawn(async move {
loop {
if shutdown_flag.load(Ordering::Acquire) {
break;
}
let current_config = *config.read().await;
tokio::time::sleep(current_config.replenish_interval).await;
let _ = tokens.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
let new_value = std::cmp::min(
current.saturating_add(current_config.tokens_per_interval),
current_config.max_tokens,
);
Some(new_value)
});
}
})
}
pub async fn stop(&mut self) {
self.shutdown_flag.store(true, Ordering::Release);
if let Some(handle) = self.task_handle.take() {
let _ = handle.await;
}
}
pub fn get_tokens(&self) -> &Arc<AtomicU64> {
&self.tokens
}
}
impl RateLimiter for TokenBucket {
    type TokenType = u64;

    /// Atomically withdraws `tokens` if at least that many are available.
    ///
    /// Returns `true` on success. On failure the balance is left untouched and `false`
    /// is returned.
    fn try_consume(&self, tokens: u64) -> bool {
        // `checked_sub` yields `None` exactly when the balance is too small, which makes
        // `fetch_update` give up without storing anything.
        let withdrawal = self.tokens.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |balance| balance.checked_sub(tokens),
        );
        withdrawal.is_ok()
    }

    /// Current token balance (a snapshot; it may change right after being read).
    fn available_tokens(&self) -> u64 {
        let balance = self.tokens.load(Ordering::Acquire);
        balance
    }
}
impl Drop for TokenBucket {
    /// Signals shutdown and aborts the replenisher this bucket owns, if any.
    fn drop(&mut self) {
        // Raise the flag first, so replenishers not owned by the bucket (spawned via a
        // direct `start` call) also exit on their next check.
        self.shutdown_flag.store(true, Ordering::Release);
        let owned_replenisher = self.task_handle.take();
        if let Some(replenisher) = owned_replenisher {
            replenisher.abort();
        }
    }
}
// Behavioural tests for `TokenBucket`.
//
// `#[tokio::test]` uses a current-thread runtime by default. The replenisher spawned by
// `initialize` therefore only makes progress while the test body is suspended (e.g. at
// `sleep(..).await`). The timing-based assertions below use 100 ms intervals and
// 150 ms waits, so exactly one refill is expected.
#[cfg(test)]
mod tests {
    use std::time::Instant;
    use tokio::time::sleep;
    use super::*;

    // Before the first tick the balance equals `initial_tokens`.
    #[tokio::test]
    async fn test_initial_tokens() {
        let config = TokenBucketConfig {
            initial_tokens: 50,
            tokens_per_interval: 10,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 100,
        };
        let limiter = TokenBucket::initialize(config);
        assert_eq!(limiter.available_tokens(), 50);
    }

    // Successive consumes drain the balance; draining exactly to zero succeeds.
    // The 1 s interval keeps refills out of the picture.
    #[tokio::test]
    async fn test_try_consume_success() {
        let config = TokenBucketConfig {
            initial_tokens: 50,
            tokens_per_interval: 10,
            replenish_interval: Duration::from_secs(1),
            max_tokens: 100,
        };
        let limiter = TokenBucket::initialize(config);
        assert!(limiter.try_consume(20));
        assert_eq!(limiter.available_tokens(), 30);
        assert!(limiter.try_consume(30));
        assert_eq!(limiter.available_tokens(), 0);
    }

    // A consume larger than the balance fails and must not change the balance.
    #[tokio::test]
    async fn test_try_consume_failure() {
        let config = TokenBucketConfig {
            initial_tokens: 10,
            tokens_per_interval: 5,
            replenish_interval: Duration::from_secs(1),
            max_tokens: 100,
        };
        let limiter = TokenBucket::initialize(config);
        assert!(limiter.try_consume(5));
        assert_eq!(limiter.available_tokens(), 5);
        assert!(!limiter.try_consume(10));
        assert_eq!(limiter.available_tokens(), 5);
    }

    // After one 100 ms tick an empty bucket should have received one refill of 20.
    // `>=` rather than `==` because scheduling jitter could allow a second tick.
    #[tokio::test]
    async fn test_token_replenishment() {
        let config = TokenBucketConfig {
            initial_tokens: 10,
            tokens_per_interval: 20,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 100,
        };
        let limiter = TokenBucket::initialize(config);
        assert!(limiter.try_consume(10));
        assert_eq!(limiter.available_tokens(), 0);
        sleep(Duration::from_millis(150)).await;
        let tokens = limiter.available_tokens();
        assert!(tokens >= 20, "Expected at least 20 tokens, got {tokens}");
    }

    // 90 + 20 would exceed the cap of 100, so the refill must be clamped to exactly 100.
    #[tokio::test]
    async fn test_max_tokens_cap() {
        let config = TokenBucketConfig {
            initial_tokens: 90,
            tokens_per_interval: 20,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 100,
        };
        let limiter = TokenBucket::initialize(config);
        sleep(Duration::from_millis(150)).await;
        let tokens = limiter.available_tokens();
        assert!(tokens <= 100, "Tokens exceeded max: {tokens}");
        assert_eq!(tokens, 100, "Expected tokens to be capped at 100");
    }

    // The config is replaced before the test first suspends, so the first refill is
    // expected to use the new `tokens_per_interval` of 30 rather than the original 5.
    #[tokio::test]
    async fn test_dynamic_config_update() {
        let config = TokenBucketConfig {
            initial_tokens: 10,
            tokens_per_interval: 5,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 50,
        };
        let limiter = TokenBucket::initialize(config);
        assert!(limiter.try_consume(10));
        assert_eq!(limiter.available_tokens(), 0);
        let new_config = TokenBucketConfig {
            initial_tokens: 10,
            tokens_per_interval: 30,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 50,
        };
        limiter.update_config(new_config).await;
        sleep(Duration::from_millis(150)).await;
        let tokens = limiter.available_tokens();
        assert!(tokens >= 30, "Expected at least 30 tokens, got {tokens}");
    }

    // 10 tasks x 10 consumes x 10 tokens hit the shared counter concurrently; the test
    // only checks that the cap is never exceeded.
    // NOTE(review): `try_consume` is called on an `Arc<AtomicU64>` here, which relies on
    // a `RateLimiter` (or similar) impl for `AtomicU64` that is not visible in this file;
    // its result is also ignored — confirm this is intended.
    #[tokio::test]
    async fn test_concurrent_consumption() {
        let config = TokenBucketConfig {
            initial_tokens: 1000,
            tokens_per_interval: 100,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 1000,
        };
        let limiter = TokenBucket::initialize(config);
        let tokens = limiter.get_tokens();
        let mut handles = vec![];
        for _ in 0..10 {
            let tokens = tokens.clone();
            let handle = tokio::spawn(async move {
                for _ in 0..10 {
                    tokens.try_consume(10);
                    sleep(Duration::from_millis(5)).await;
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let tokens = limiter.available_tokens();
        assert!(tokens <= 1000, "Tokens exceeded max");
    }

    // Once drained, the bucket cannot satisfy another request until a refill lands, and
    // the first refill only lands after the first 100 ms sleep completes.
    #[tokio::test]
    async fn test_rate_limiting_behavior() {
        let config = TokenBucketConfig {
            initial_tokens: 5,
            tokens_per_interval: 5,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 10,
        };
        let limiter = TokenBucket::initialize(config);
        let start = Instant::now();
        assert!(limiter.try_consume(5));
        assert!(!limiter.try_consume(5));
        // Poll until the replenisher has added enough for another request.
        while limiter.available_tokens() < 5 {
            sleep(Duration::from_millis(10)).await;
        }
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(100));
        assert!(limiter.try_consume(5));
    }

    // `get_config` returns exactly the configuration the bucket was created with.
    #[tokio::test]
    async fn test_get_config() {
        let config = TokenBucketConfig {
            initial_tokens: 42,
            tokens_per_interval: 13,
            replenish_interval: Duration::from_millis(250),
            max_tokens: 200,
        };
        let limiter = TokenBucket::initialize(config);
        let retrieved_config = limiter.get_config().await;
        assert_eq!(retrieved_config.initial_tokens, 42);
        assert_eq!(retrieved_config.tokens_per_interval, 13);
        assert_eq!(
            retrieved_config.replenish_interval,
            Duration::from_millis(250)
        );
        assert_eq!(retrieved_config.max_tokens, 200);
    }
}