#[cfg(feature = "rate-limiting")]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "rate-limiting")]
use std::sync::Arc;
#[cfg(feature = "rate-limiting")]
use std::time::Duration;
#[cfg(feature = "rate-limiting")]
use tokio::sync::Mutex;
/// Tuning knobs for the token-bucket rate limiter.
#[cfg(feature = "rate-limiting")]
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
/// Steady-state refill rate: tokens added to each bucket per second.
pub max_requests_per_second: u64,
/// Maximum tokens a bucket can hold (the short-burst allowance).
pub burst_capacity: u64,
/// NOTE(review): declared but never read anywhere in this file —
/// presumably how long an offending client stays blocked; confirm callers.
pub block_duration_secs: u64,
}
#[cfg(feature = "rate-limiting")]
impl Default for RateLimitConfig {
    /// Conservative defaults: 1000 req/s sustained, 2000-token burst,
    /// 10-second block window.
    fn default() -> Self {
        RateLimitConfig {
            block_duration_secs: 10,
            burst_capacity: 2000,
            max_requests_per_second: 1000,
        }
    }
}
/// Token bucket with lazy refill: tokens accrue continuously at
/// `refill_rate` per second, capped at `capacity`, and are accounted
/// for only when the bucket is queried or drawn from.
#[cfg(feature = "rate-limiting")]
#[derive(Debug)]
pub struct TokenBucket {
// Tokens currently banked (refill is applied lazily on access).
tokens: AtomicU64,
// Unix-epoch milliseconds of the last refill accounting.
last_update: AtomicU64,
// Maximum tokens the bucket can ever hold.
capacity: u64,
// Tokens added per second.
refill_rate: u64, }
#[cfg(feature = "rate-limiting")]
impl TokenBucket {
    /// Creates a bucket that starts full at `capacity` tokens and refills
    /// at `refill_rate` tokens per second.
    pub fn new(capacity: u64, refill_rate: u64) -> Self {
        let now = Self::now_millis();
        Self {
            tokens: AtomicU64::new(capacity),
            last_update: AtomicU64::new(now),
            capacity,
            refill_rate,
        }
    }

    /// Milliseconds since the Unix epoch, clamped to 0 if the system clock
    /// reports a time before the epoch (clock rollback).
    #[inline]
    fn now_millis() -> u64 {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_else(|e| {
                tracing::warn!("系统时间异常,可能时钟回退: {}", e);
                std::time::Duration::ZERO
            });
        since_epoch.as_millis() as u64
    }

    /// Attempts to take a single token; never blocks.
    pub fn try_acquire(&self) -> bool {
        self.try_acquire_n(1)
    }

    /// Attempts to take `n` tokens, applying accrued refill first.
    ///
    /// Fixed: the original used a plain load + store pair on `tokens`, so
    /// two concurrent callers could observe the same balance and each spend
    /// it (double-spend), or overwrite each other's counts. The balance is
    /// now committed with a compare-exchange retry loop. `last_update` is
    /// still a separate atomic, so racing callers may each count the same
    /// refill interval — generous but bounded by `capacity` — while tokens
    /// can no longer be over-spent.
    pub fn try_acquire_n(&self, n: u64) -> bool {
        loop {
            let now = Self::now_millis();
            let last_update = self.last_update.load(Ordering::Acquire);
            let elapsed = now.saturating_sub(last_update);
            // saturating_mul: a wild clock jump must not overflow u64.
            let refill = elapsed.saturating_mul(self.refill_rate) / 1000;
            let current = self.tokens.load(Ordering::Acquire);
            let available = current.saturating_add(refill).min(self.capacity);
            if available < n {
                // Insufficient tokens: consume nothing.
                return false;
            }
            if self
                .tokens
                .compare_exchange(current, available - n, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                self.last_update.store(now, Ordering::Release);
                return true;
            }
            // Lost a race with a concurrent caller: recompute and retry.
        }
    }

    /// Current balance including refill accrued since `last_update`.
    /// Read-only: the accrued refill is not persisted here.
    pub fn available_tokens(&self) -> u64 {
        let now = Self::now_millis();
        let last_update = self.last_update.load(Ordering::Acquire);
        let elapsed = now.saturating_sub(last_update);
        let refill = elapsed.saturating_mul(self.refill_rate) / 1000;
        self.tokens
            .load(Ordering::Acquire)
            .saturating_add(refill)
            .min(self.capacity)
    }
}
/// Two-level rate limiter: one shared global bucket plus a lazily created
/// bucket per client id, all sized from the same [`RateLimitConfig`].
#[cfg(feature = "rate-limiting")]
#[derive(Debug)]
pub struct ClientRateLimiter {
// Per-client buckets keyed by client id, behind an async mutex.
// NOTE(review): entries are never evicted — unbounded growth if client
// ids are attacker-controlled; consider periodic cleanup.
per_client: Mutex<std::collections::HashMap<String, Arc<TokenBucket>>>,
// Bucket shared by every client.
global_limit: TokenBucket,
// Parameters for the global bucket and each new per-client bucket.
config: RateLimitConfig,
}
#[cfg(feature = "rate-limiting")]
impl ClientRateLimiter {
    /// Creates a limiter whose global bucket and (future) per-client
    /// buckets are all sized from `config`.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            per_client: Mutex::new(std::collections::HashMap::new()),
            global_limit: TokenBucket::new(config.burst_capacity, config.max_requests_per_second),
            config,
        }
    }

    /// Converts a token deficit into a suggested retry delay.
    /// `max(1)` guards the division when the configured rate is 0.
    fn retry_after(&self, deficit: u64) -> Duration {
        Duration::from_millis(
            deficit.saturating_mul(1000) / self.config.max_requests_per_second.max(1),
        )
    }

    /// Charges `cost` tokens against both the client's bucket and the
    /// global bucket.
    ///
    /// Returns `Err(wait)` with a suggested back-off when either bucket
    /// lacks `cost` tokens; nothing is consumed on rejection.
    ///
    /// Fixed vs. the original: the global balance was sampled *before*
    /// taking the map lock (stale by acquisition time), and both
    /// `try_acquire_n` results were silently discarded.
    pub async fn check_rate_limit(&self, client_id: &str, cost: u64) -> Result<(), Duration> {
        // The map lock serializes every acquisition in this limiter, so the
        // check-then-acquire sequence below cannot race with other callers.
        let mut per_client_map = self.per_client.lock().await;
        let bucket = per_client_map
            .entry(client_id.to_string())
            .or_insert_with(|| {
                Arc::new(TokenBucket::new(
                    self.config.burst_capacity,
                    self.config.max_requests_per_second,
                ))
            });
        // Check both buckets before consuming from either, so a rejection
        // never leaves tokens partially spent.
        let per_client_available = bucket.available_tokens();
        if per_client_available < cost {
            return Err(self.retry_after(cost - per_client_available));
        }
        let global_available = self.global_limit.available_tokens();
        if global_available < cost {
            return Err(self.retry_after(cost - global_available));
        }
        // Both checks passed under the lock and refill only adds tokens, so
        // these cannot fail; verify instead of ignoring the results.
        let client_ok = bucket.try_acquire_n(cost);
        let global_ok = self.global_limit.try_acquire_n(cost);
        debug_assert!(
            client_ok && global_ok,
            "token acquisition failed after a successful availability check"
        );
        Ok(())
    }

    /// Point-in-time snapshot of `client_id`'s bucket and the global bucket.
    /// An unknown client is reported as a full, not-yet-created bucket.
    pub async fn get_client_status(&self, client_id: &str) -> RateLimitStatus {
        let per_client_map = self.per_client.lock().await;
        match per_client_map.get(client_id) {
            Some(b) => RateLimitStatus {
                client_available: b.available_tokens(),
                client_capacity: b.capacity,
                global_available: self.global_limit.available_tokens(),
                global_capacity: self.global_limit.capacity,
            },
            None => RateLimitStatus {
                client_available: self.config.burst_capacity,
                client_capacity: self.config.burst_capacity,
                global_available: self.global_limit.available_tokens(),
                global_capacity: self.global_limit.capacity,
            },
        }
    }
}
/// Point-in-time snapshot of one client's and the global bucket's tokens.
#[cfg(feature = "rate-limiting")]
#[derive(Debug, Clone)]
pub struct RateLimitStatus {
/// Tokens currently available to the queried client.
pub client_available: u64,
/// Capacity of the client's bucket.
pub client_capacity: u64,
/// Tokens currently available in the shared global bucket.
pub global_available: u64,
/// Capacity of the global bucket.
pub global_capacity: u64,
}
/// Cheaply cloneable handle (an `Arc`) around a shared [`ClientRateLimiter`].
#[cfg(feature = "rate-limiting")]
#[derive(Debug, Clone)]
pub struct GlobalRateLimiter(Arc<ClientRateLimiter>);
#[cfg(feature = "rate-limiting")]
impl GlobalRateLimiter {
    /// Builds a shared limiter; `None` falls back to
    /// `RateLimitConfig::default()`.
    pub fn new(config: Option<RateLimitConfig>) -> Self {
        let cfg = match config {
            Some(c) => c,
            None => RateLimitConfig::default(),
        };
        GlobalRateLimiter(Arc::new(ClientRateLimiter::new(cfg)))
    }

    /// Borrows the shared inner limiter (clone the `Arc` to keep it).
    pub fn inner(&self) -> &Arc<ClientRateLimiter> {
        &self.0
    }
}
#[cfg(feature = "rate-limiting")]
impl Default for GlobalRateLimiter {
    /// Equivalent to `GlobalRateLimiter::new(None)`: default configuration.
    fn default() -> Self {
        GlobalRateLimiter::new(None)
    }
}
/// Zero-sized stand-in for the feature-enabled config; keeps call sites
/// compiling when rate limiting is built out.
#[cfg(not(feature = "rate-limiting"))]
#[derive(Debug, Clone, Default)]
pub struct RateLimitConfig;
#[cfg(not(feature = "rate-limiting"))]
impl RateLimitConfig {
    /// Mirrors the feature-enabled constructor's signature; every argument
    /// is ignored in this build.
    pub fn new(
        _max_requests_per_second: u64,
        _burst_capacity: u64,
        _block_duration_secs: u64,
    ) -> Self {
        RateLimitConfig
    }
}
/// Zero-sized stand-in for the feature-enabled bucket: every acquisition
/// succeeds and the supply is reported as unlimited.
#[cfg(not(feature = "rate-limiting"))]
#[derive(Debug, Clone, Default)]
pub struct TokenBucket;
#[cfg(not(feature = "rate-limiting"))]
impl TokenBucket {
    /// No-op constructor; capacity and refill rate are ignored.
    pub fn new(_capacity: u64, _refill_rate: u64) -> Self {
        TokenBucket
    }

    /// Always succeeds; limiting is compiled out.
    pub fn try_acquire(&self) -> bool {
        self.try_acquire_n(1)
    }

    /// Always succeeds; limiting is compiled out.
    pub fn try_acquire_n(&self, _n: u64) -> bool {
        true
    }

    /// Reports an effectively unlimited token supply.
    pub fn available_tokens(&self) -> u64 {
        u64::MAX
    }
}
/// Zero-sized stand-in for the feature-enabled limiter: every request is
/// admitted and status reads come back all-zero.
#[cfg(not(feature = "rate-limiting"))]
#[derive(Debug, Clone, Default)]
pub struct ClientRateLimiter;
#[cfg(not(feature = "rate-limiting"))]
impl ClientRateLimiter {
    /// No-op constructor; the config is ignored in this build.
    pub fn new(_config: RateLimitConfig) -> Self {
        ClientRateLimiter
    }

    /// Always admits the request; rate limiting is compiled out.
    pub async fn check_rate_limit(
        &self,
        _client_id: &str,
        _cost: u64,
    ) -> Result<(), std::time::Duration> {
        Ok(())
    }

    /// Returns the all-zero default status; no accounting exists here.
    pub async fn get_client_status(&self, _client_id: &str) -> RateLimitStatus {
        RateLimitStatus::default()
    }
}
/// Mirror of the feature-enabled status struct; the stub limiter always
/// returns the `Default` (all-zero) value.
#[cfg(not(feature = "rate-limiting"))]
#[derive(Debug, Clone, Default)]
pub struct RateLimitStatus {
/// Tokens reported for the queried client (always 0 in this build).
pub client_available: u64,
/// Reported client bucket capacity (always 0 in this build).
pub client_capacity: u64,
/// Reported global tokens (always 0 in this build).
pub global_available: u64,
/// Reported global capacity (always 0 in this build).
pub global_capacity: u64,
}
/// Zero-sized stand-in for the feature-enabled handle.
#[cfg(not(feature = "rate-limiting"))]
#[derive(Debug, Clone, Default)]
pub struct GlobalRateLimiter;
#[cfg(not(feature = "rate-limiting"))]
impl GlobalRateLimiter {
    /// No-op constructor; the optional config is ignored in this build.
    pub fn new(_config: Option<RateLimitConfig>) -> Self {
        Self
    }

    /// Returns a process-wide shared no-op inner limiter.
    ///
    /// Fixed: the original wrote `static EMPTY: Arc<…> = Arc::new(…)`,
    /// which does not compile — `Arc::new` is not a `const fn`, and `Arc`
    /// is only imported under the "rate-limiting" feature. A `OnceLock`
    /// initializes the shared value lazily, and fully qualified paths
    /// avoid depending on the feature-gated `use`.
    pub fn inner(&self) -> &std::sync::Arc<ClientRateLimiter> {
        static EMPTY: std::sync::OnceLock<std::sync::Arc<ClientRateLimiter>> =
            std::sync::OnceLock::new();
        EMPTY.get_or_init(|| std::sync::Arc::new(ClientRateLimiter))
    }
}
#[cfg(test)]
#[cfg(feature = "rate-limiting")]
mod tests {
use super::*;
// Exact consume/deny accounting; runs fast enough that refill (10/s)
// contributes nothing.
#[test]
fn test_token_bucket_basic() {
let bucket = TokenBucket::new(10, 10);
assert_eq!(bucket.available_tokens(), 10);
assert!(bucket.try_acquire_n(5));
assert_eq!(bucket.available_tokens(), 5);
// Requesting more than the remaining 5 tokens must fail...
assert!(!bucket.try_acquire_n(6));
// ...and the failed attempt must not have consumed anything.
assert!(bucket.try_acquire_n(5));
assert_eq!(bucket.available_tokens(), 0);
}
// Refill accrues with wall-clock time: at 100 tokens/s, >= 50 ms of
// sleep yields at least 5 tokens (sleep guarantees a minimum duration).
// NOTE(review): timing-based — a very coarse system clock could still
// make this flaky on some platforms.
#[test]
fn test_token_bucket_refill() {
let bucket = TokenBucket::new(10, 100);
bucket.try_acquire_n(10);
assert_eq!(bucket.available_tokens(), 0);
std::thread::sleep(Duration::from_millis(50));
let tokens = bucket.available_tokens();
assert!(
tokens >= 5,
"Expected at least 5 tokens after refill, but got {}",
tokens
);
}
// End-to-end: a fresh client sees full buckets, then exhausting the
// burst capacity yields Err.
// NOTE(review): if the 100-iteration loop takes longer than ~10 ms, the
// bucket refills a token (100/s) and the final is_err() could flake.
#[tokio::test]
async fn test_client_rate_limiter() {
let limiter = ClientRateLimiter::new(RateLimitConfig {
max_requests_per_second: 100,
burst_capacity: 100,
block_duration_secs: 10,
});
let status = limiter.get_client_status("test_client").await;
assert_eq!(status.client_available, 100);
assert_eq!(status.global_available, 100);
assert!(limiter.check_rate_limit("test_client", 1).await.is_ok());
for _ in 0..100 {
let _ = limiter.check_rate_limit("test_client", 1).await;
}
assert!(limiter.check_rate_limit("test_client", 1).await.is_err());
}
}