use crate::config::CONFIG;
use crate::prelude::ChatId;
use async_trait::async_trait;
use crossbeam_utils::CachePadded;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::time::{Duration, sleep};
use tracing::{debug, info, warn};
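/// A token bucket whose mutable state lives entirely in atomics, so
/// `try_consume` never takes a lock. The request counter and the token
/// count share one `AtomicU64` (high half / low half) so both can change
/// in a single compare-and-swap.
///
/// A minimal usage sketch (crate path assumed; adjust to this crate's name):
///
/// ```ignore
/// let bucket = LockFreeTokenBucket::new(10, 5); // 10 tokens, 5 refilled/sec
/// assert!(bucket.try_consume());
/// assert_eq!(bucket.available_tokens(), 9);
/// ```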
#[derive(Debug)]
pub struct LockFreeTokenBucket {
    // Pad the atomics themselves, not `Arc` pointers to them: padding a
    // pointer leaves the pointee on a shared cache line and does nothing
    // against false sharing between these three hot counters.
    state: CachePadded<AtomicU64>,
    last_refill: CachePadded<AtomicU64>,
    capacity: u32,
    refill_rate: u32,
    // Packed as (rate_limited << 32) | allowed; updated with `fetch_add`.
    stats: CachePadded<AtomicU64>,
}
impl LockFreeTokenBucket {
pub fn new(capacity: u32, refill_rate: u32) -> Self {
        let now_micros = Self::current_time_micros();
        Self {
            state: CachePadded::new(AtomicU64::new(Self::pack_state(0, capacity))),
            last_refill: CachePadded::new(AtomicU64::new(now_micros)),
            capacity,
            refill_rate,
            stats: CachePadded::new(AtomicU64::new(0)),
        }
}
#[inline]
fn pack_state(total_requests: u32, available_tokens: u32) -> u64 {
((total_requests as u64) << 32) | (available_tokens as u64)
}
#[inline]
fn unpack_state(state: u64) -> (u32, u32) {
((state >> 32) as u32, state as u32)
}
    #[inline]
    fn current_time_micros() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is before UNIX_EPOCH")
            .as_micros() as u64
    }
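    /// Attempts to take one token. Refill is lazy: elapsed time since
    /// `last_refill` is converted to tokens at `refill_rate` per second,
    /// then a CAS retry loop consumes one token. Returns `false` (and bumps
    /// the rate-limited counter) when the bucket is empty.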
    #[inline]
    pub fn try_consume(&self) -> bool {
        let now = Self::current_time_micros();
        // Claim the elapsed-time refill exactly once: only the thread that
        // advances `last_refill` via CAS may credit the new tokens, so two
        // threads can no longer both convert the same interval into tokens.
        let last_refill = self.last_refill.load(Ordering::Acquire);
        let elapsed = now.saturating_sub(last_refill);
        let mut refill = ((elapsed * self.refill_rate as u64) / 1_000_000) as u32;
        if refill > 0
            && self
                .last_refill
                .compare_exchange(last_refill, now, Ordering::AcqRel, Ordering::Acquire)
                .is_err()
        {
            // Another thread claimed this interval's refill.
            refill = 0;
        }
        loop {
            let current_state = self.state.load(Ordering::Acquire);
            let (total_requests, available_tokens) = Self::unpack_state(current_state);
            let new_available = available_tokens.saturating_add(refill).min(self.capacity);
            if new_available == 0 {
                // Stats are packed as (rate_limited << 32) | allowed; a single
                // fetch_add atomically bumps one half, unlike a load/store
                // pair, which can lose concurrent updates.
                self.stats.fetch_add(1u64 << 32, Ordering::Relaxed);
                return false;
            }
            let new_state =
                Self::pack_state(total_requests.wrapping_add(1), new_available - 1);
            match self.state.compare_exchange_weak(
                current_state,
                new_state,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    // Low half counts allowed requests; wraparound after
                    // 2^32 requests is tolerated.
                    self.stats.fetch_add(1, Ordering::Relaxed);
                    return true;
                }
                // Contention: state changed under us. A claimed refill is
                // carried into the retry, so no tokens are lost.
                Err(_) => continue,
            }
        }
    }
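    /// Tokens currently stored in the packed state. Pending lazy refill is
    /// NOT applied here; only `try_consume` advances the refill clock.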
#[inline]
pub fn available_tokens(&self) -> u32 {
let state = self.state.load(Ordering::Relaxed);
let (_, available_tokens) = Self::unpack_state(state);
available_tokens
}
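    /// Whether this bucket would be at full capacity once its pending lazy
    /// refill were applied. An idle bucket is indistinguishable from a
    /// freshly created one, so the cleanup pass may drop it safely. This is
    /// a helper added for `maybe_cleanup_inactive_buckets`, assuming the
    /// lazy-refill semantics of `try_consume` above.
    #[inline]
    pub fn is_idle(&self) -> bool {
        let (_, available) = Self::unpack_state(self.state.load(Ordering::Relaxed));
        if available >= self.capacity {
            return true;
        }
        let elapsed = Self::current_time_micros()
            .saturating_sub(self.last_refill.load(Ordering::Relaxed));
        let pending = (elapsed * self.refill_rate as u64) / 1_000_000;
        available as u64 + pending >= self.capacity as u64
    }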
    /// Returns `(total_requests, allowed, rate_limited)` for this bucket.
    pub fn get_stats(&self) -> (u32, u32, u32) {
let state = self.state.load(Ordering::Relaxed);
let stats = self.stats.load(Ordering::Relaxed);
let (total_requests, _) = Self::unpack_state(state);
let (rate_limited, allowed) = Self::unpack_state(stats);
(total_requests, allowed, rate_limited)
}
}
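/// Decides *when* bucket cleanup should run: either the estimated memory
/// footprint exceeds `memory_threshold`, or `cleanup_interval` has elapsed
/// since the last pass. `last_cleanup` starts at zero, so the very first
/// check always reports that a cleanup is due.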
#[derive(Debug)]
pub struct ProactiveCleanup {
memory_threshold: usize,
cleanup_interval: Duration,
last_cleanup: Arc<AtomicU64>,
}
impl ProactiveCleanup {
pub fn new(memory_threshold: usize, cleanup_interval: Duration) -> Self {
Self {
memory_threshold,
cleanup_interval,
last_cleanup: Arc::new(AtomicU64::new(0)),
}
}
#[inline]
    pub fn should_cleanup(&self, current_bucket_count: usize) -> bool {
        let current_memory =
            current_bucket_count * std::mem::size_of::<LockFreeTokenBucket>();
        let now = LockFreeTokenBucket::current_time_micros();
        let last_cleanup = self.last_cleanup.load(Ordering::Relaxed);
        // saturating_sub guards against a clock that appears to step backwards.
        current_memory > self.memory_threshold
            || now.saturating_sub(last_cleanup) > self.cleanup_interval.as_micros() as u64
    }
#[inline]
pub fn mark_cleanup_completed(&self) {
let now = LockFreeTokenBucket::current_time_micros();
self.last_cleanup.store(now, Ordering::Relaxed);
}
}
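/// Process-wide counters, one atomic per metric so writers never contend on
/// a lock. Reads in `get_snapshot` are individually relaxed, so a snapshot
/// is only approximately consistent across the three counters.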
#[derive(Debug)]
pub struct LockFreeGlobalStats {
total_requests: Arc<AtomicU64>,
rate_limited_requests: Arc<AtomicU64>,
allowed_requests: Arc<AtomicU64>,
active_buckets: Arc<AtomicU32>,
}
impl Default for LockFreeGlobalStats {
fn default() -> Self {
Self::new()
}
}
impl LockFreeGlobalStats {
pub fn new() -> Self {
Self {
total_requests: Arc::new(AtomicU64::new(0)),
rate_limited_requests: Arc::new(AtomicU64::new(0)),
allowed_requests: Arc::new(AtomicU64::new(0)),
active_buckets: Arc::new(AtomicU32::new(0)),
}
}
#[inline]
pub fn record_request(&self, was_allowed: bool) {
self.total_requests.fetch_add(1, Ordering::Relaxed);
if was_allowed {
self.allowed_requests.fetch_add(1, Ordering::Relaxed);
} else {
self.rate_limited_requests.fetch_add(1, Ordering::Relaxed);
}
}
#[inline]
pub fn set_active_buckets(&self, count: u32) {
self.active_buckets.store(count, Ordering::Relaxed);
}
pub fn get_snapshot(&self) -> BucketStats {
BucketStats {
total_requests: self.total_requests.load(Ordering::Relaxed),
rate_limited_requests: self.rate_limited_requests.load(Ordering::Relaxed),
allowed_requests: self.allowed_requests.load(Ordering::Relaxed),
last_access: SystemTime::now(),
            max_tokens_used: 0,
        }
    }
}
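/// Point-in-time snapshot of either one chat's bucket or the global counters.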
#[derive(Debug, Clone)]
pub struct BucketStats {
pub total_requests: u64,
pub rate_limited_requests: u64,
pub allowed_requests: u64,
pub last_access: SystemTime,
pub max_tokens_used: u32,
}
impl Default for BucketStats {
fn default() -> Self {
Self {
total_requests: 0,
rate_limited_requests: 0,
allowed_requests: 0,
last_access: SystemTime::now(),
max_tokens_used: 0,
}
}
}
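/// Per-chat rate limiting over a `DashMap` of lock-free token buckets,
/// configured from `CONFIG.rate_limit`. `Clone` is cheap: every field is
/// behind an `Arc`, so clones share the same buckets and statistics.
///
/// A minimal sketch of the intended call flow (paths assumed):
///
/// ```ignore
/// let limiter = RateLimiter::new();
/// if limiter.check_rate_limit(&chat_id).await {
///     // handle the message
/// } else {
///     // drop or defer it
/// }
/// ```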
#[derive(Debug, Clone)]
pub struct RateLimiter {
chat_buckets: Arc<DashMap<ChatId, Arc<LockFreeTokenBucket>>>,
global_stats: Arc<LockFreeGlobalStats>,
cleanup_manager: Arc<ProactiveCleanup>,
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
impl RateLimiter {
#[tracing::instrument]
pub fn new() -> Self {
debug!("Creating lock-free high-performance RateLimiter");
let cfg = &CONFIG.rate_limit;
        // `init_bucket` is the number of buckets to pre-allocate,
        // not a token capacity.
        let initial_buckets = u32::try_from(cfg.init_bucket)
            .unwrap_or_else(|_| panic!("Initial bucket count too large: {}", cfg.init_bucket));
        Self {
            chat_buckets: Arc::new(DashMap::with_capacity(initial_buckets as usize)),
global_stats: Arc::new(LockFreeGlobalStats::new()),
cleanup_manager: Arc::new(ProactiveCleanup::new(
cfg.init_bucket * std::mem::size_of::<LockFreeTokenBucket>(),
Duration::from_secs(cfg.cleanup_interval),
)),
}
}
pub async fn get_global_stats(&self) -> BucketStats {
self.global_stats.get_snapshot()
}
pub async fn get_chat_stats(&self, chat_id: &ChatId) -> Option<BucketStats> {
if let Some(bucket) = self.chat_buckets.get(chat_id) {
let (total, allowed, rate_limited) = bucket.get_stats();
Some(BucketStats {
total_requests: total as u64,
rate_limited_requests: rate_limited as u64,
allowed_requests: allowed as u64,
                // Per-bucket last-access time isn't tracked; report "now".
                last_access: SystemTime::now(),
                max_tokens_used: 0,
            })
} else {
None
}
}
    fn maybe_cleanup_inactive_buckets(&self) {
        if self.cleanup_manager.should_cleanup(self.chat_buckets.len()) {
            // Only drop buckets that are (effectively) full again: a full
            // bucket behaves exactly like a freshly created one, so removing
            // it is invisible to callers. Dropping *empty* buckets would
            // instead reset precisely the chats that are being throttled.
            let to_remove: Vec<ChatId> = self
                .chat_buckets
                .iter()
                .filter(|entry| entry.value().is_idle())
                .map(|entry| entry.key().clone())
                .take(10)
                .collect();
            let mut removed_count = 0;
            for chat_id in to_remove {
                if self.chat_buckets.remove(&chat_id).is_some() {
                    removed_count += 1;
                }
            }
            if removed_count > 0 {
                debug!("Proactively cleaned up {} idle buckets", removed_count);
            }
            self.cleanup_manager.mark_cleanup_completed();
            self.global_stats
                .set_active_buckets(self.chat_buckets.len() as u32);
        }
    }
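    /// Runs an opportunistic cleanup pass, then consumes one token from the
    /// chat's bucket, creating the bucket on first sight of the chat. Each
    /// bucket holds `CONFIG.rate_limit.limit` tokens and refills at roughly
    /// `limit / duration` tokens per second (clamped to at least one).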
#[tracing::instrument(skip(self))]
pub async fn check_rate_limit(&self, chat_id: &ChatId) -> bool {
self.maybe_cleanup_inactive_buckets();
let bucket = self
.chat_buckets
.entry(chat_id.clone())
.or_insert_with(|| {
debug!(
"Creating new lock-free token bucket for chat_id: {}",
chat_id.0
);
                let cfg = &CONFIG.rate_limit;
                let capacity = u32::try_from(cfg.limit)
                    .unwrap_or_else(|_| panic!("Rate limit capacity too large: {}", cfg.limit));
                // Integer division can yield zero when the window is longer
                // than the capacity; clamp so buckets always refill.
                let refill_rate =
                    u32::try_from((capacity as u64 / cfg.duration.max(1)).max(1)).unwrap_or(1);
                Arc::new(LockFreeTokenBucket::new(capacity, refill_rate))
})
.clone();
let result = bucket.try_consume();
self.global_stats.record_request(result);
debug!(
"Rate limit check for chat_id {}: {}",
chat_id.0,
if result { "allowed" } else { "limited" }
);
result
}
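    /// Like `check_rate_limit`, but retries up to `cfg.retry_attempts` times
    /// with linear backoff plus jitter:
    /// `delay = base_ms * (1 + 0.5 * attempt) + (now_micros % 50) ms`.
    /// With an illustrative 100 ms base (not necessarily the configured
    /// value) that is roughly 150 ms, 200 ms, 250 ms, ...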
#[tracing::instrument(skip(self))]
    // All state is behind atomics and `Arc`s, so `&self` suffices here;
    // `&mut self` was unnecessarily restrictive (and `&mut` callers still compile).
    pub async fn wait_if_needed(&self, chat_id: &ChatId) -> bool {
let cfg = &CONFIG.rate_limit;
let mut attempts = 0;
let base_retry_delay = Duration::from_millis(cfg.retry_delay);
while attempts < cfg.retry_attempts {
if self.check_rate_limit(chat_id).await {
if attempts > 0 {
debug!(
"Rate limit passed after {} attempts for chat_id: {}",
attempts + 1,
chat_id.0
);
}
return true;
}
attempts += 1;
let jitter_ms = LockFreeTokenBucket::current_time_micros() % 50;
let factor = 1.0 + (attempts as f64 * 0.5);
let retry_delay = Duration::from_millis(
(base_retry_delay.as_millis() as f64 * factor) as u64 + jitter_ms,
);
debug!(
"Rate limit exceeded, attempt {}/{}, backing off for {:?}",
attempts, cfg.retry_attempts, retry_delay
);
sleep(retry_delay).await;
}
warn!(
"Rate limit exceeded after {} attempts for chat_id: {}",
attempts, chat_id.0
);
false
}
pub async fn clear_rate_limit(&self, chat_id: &ChatId) {
if self.chat_buckets.remove(chat_id).is_some() {
debug!("Cleared rate limit for chat_id: {}", chat_id.0);
self.global_stats
.set_active_buckets(self.chat_buckets.len() as u32);
}
}
pub async fn get_available_tokens(&self, chat_id: &ChatId) -> Option<u32> {
self.chat_buckets
.get(chat_id)
.map(|bucket| bucket.available_tokens())
}
pub async fn get_bucket_capacity(&self, chat_id: &ChatId) -> Option<u32> {
if self.chat_buckets.contains_key(chat_id) {
let cfg = &CONFIG.rate_limit;
Some(u32::try_from(cfg.limit).unwrap_or(100))
} else {
None
}
}
pub fn active_bucket_count(&self) -> usize {
self.chat_buckets.len()
}
pub async fn shutdown(&self) {
let bucket_count = self.chat_buckets.len();
info!(
"Shutting down lock-free RateLimiter with {} active buckets",
bucket_count
);
self.chat_buckets.clear();
self.global_stats.set_active_buckets(0);
info!("Lock-free RateLimiter shutdown complete");
}
}
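/// Extension hook for priority-aware checks: higher priority earns more
/// immediate retries against the same bucket, without the sleeps that
/// `wait_if_needed` inserts between attempts.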
#[async_trait]
pub trait RateLimiterExt {
async fn check_with_priority(&self, chat_id: &ChatId, priority: u8) -> bool;
}
#[async_trait]
impl RateLimiterExt for RateLimiter {
async fn check_with_priority(&self, chat_id: &ChatId, priority: u8) -> bool {
        // Higher priority earns more immediate retries (no backoff between them).
        let attempts = match priority {
            0 => 1,
            1 => 2,
            _ => 3,
        };
for _ in 0..attempts {
if self.check_rate_limit(chat_id).await {
return true;
}
}
false
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{Duration, sleep};
fn create_test_bucket(capacity: u32, refill_rate: u32) -> LockFreeTokenBucket {
LockFreeTokenBucket::new(capacity, refill_rate)
}
#[tokio::test]
async fn test_bucket_created_with_full_capacity() {
let bucket = create_test_bucket(10, 60);
assert_eq!(bucket.available_tokens(), 10);
assert!(
bucket.try_consume(),
"First token should be available immediately"
);
assert_eq!(bucket.available_tokens(), 9);
for i in 0..9 {
assert!(bucket.try_consume(), "Token {} should be available", i + 2);
}
assert_eq!(bucket.available_tokens(), 0);
assert!(!bucket.try_consume(), "Should be rate limited when empty");
}
#[tokio::test]
async fn test_background_refill() {
let bucket = create_test_bucket(5, 5);
for _ in 0..5 {
assert!(bucket.try_consume());
}
assert_eq!(bucket.available_tokens(), 0);
sleep(Duration::from_millis(1200)).await;
let tokens_after_wait = bucket.available_tokens();
let can_consume = bucket.try_consume();
assert!(
can_consume || tokens_after_wait > 0,
"Should have tokens after waiting or be able to consume"
);
}
#[tokio::test]
async fn test_concurrent_token_consumption() {
let bucket = Arc::new(create_test_bucket(100, 60));
let mut handles = vec![];
for i in 0..20 {
let bucket_clone = Arc::clone(&bucket);
let handle = tokio::spawn(async move {
let mut successes = 0;
for _ in 0..10 {
if bucket_clone.try_consume() {
successes += 1;
}
}
(i, successes)
});
handles.push(handle);
}
let mut total_successes = 0;
for handle in handles {
let (_, successes) = handle.await.unwrap();
total_successes += successes;
}
assert!(total_successes <= 100, "Should not exceed bucket capacity");
assert!(
total_successes >= 90,
"Should consume most tokens with high concurrency, got: {total_successes}"
);
}
#[tokio::test]
    async fn test_no_refill_without_consume() {
        // `available_tokens` reads the stored count without applying the
        // lazy refill, so tokens must not grow while the bucket is untouched.
        let bucket = create_test_bucket(10, 60);
assert!(bucket.try_consume());
let initial_tokens = bucket.available_tokens();
assert!(initial_tokens > 0, "Should have remaining tokens");
sleep(Duration::from_millis(100)).await;
let tokens_after_wait = bucket.available_tokens();
assert!(
tokens_after_wait <= initial_tokens,
"Tokens should not increase without refill trigger"
);
}
#[tokio::test]
async fn test_bucket_capacity_limits() {
let bucket = create_test_bucket(3, 1);
sleep(Duration::from_millis(5000)).await;
let _ = bucket.try_consume();
assert!(
bucket.available_tokens() <= 3,
"Should never exceed capacity even after long wait, got: {}",
bucket.available_tokens()
);
}
#[tokio::test]
async fn test_atomic_operations_performance() {
let bucket = create_test_bucket(1000, 60);
let start = std::time::Instant::now();
for _ in 0..1000 {
bucket.try_consume();
}
let duration = start.elapsed();
assert!(
            duration.as_millis() < 100,
            "Atomic operations should be very fast: {duration:?}"
);
}
#[tokio::test]
async fn test_cache_padded_alignment() {
let bucket = create_test_bucket(10, 60);
assert_eq!(bucket.available_tokens(), 10);
assert!(bucket.try_consume());
assert_eq!(bucket.available_tokens(), 9);
let (total, allowed, rate_limited) = bucket.get_stats();
assert_eq!(total, 1);
assert_eq!(allowed, 1);
assert_eq!(rate_limited, 0);
}
}
#[cfg(test)]
mod ratelimiter_tests {
use super::*;
use crate::prelude::ChatId;
use tokio::time::{Duration, sleep};
fn chat_id(n: u64) -> ChatId {
ChatId::from(format!("chat_{n}"))
}
#[tokio::test]
async fn test_new_and_active_bucket_count() {
let limiter = RateLimiter::new();
assert_eq!(limiter.active_bucket_count(), 0);
let _ = limiter.check_rate_limit(&chat_id(1)).await;
assert_eq!(limiter.active_bucket_count(), 1);
}
#[tokio::test]
async fn test_check_rate_limit_and_tokens() {
let limiter = RateLimiter::new();
let cid = chat_id(2);
for _ in 0..5 {
assert!(limiter.check_rate_limit(&cid).await);
}
let mut limited = false;
for _ in 0..100 {
if !limiter.check_rate_limit(&cid).await {
limited = true;
break;
}
}
assert!(
limited,
"Rate limiter должен ограничивать после превышения лимита"
);
}
#[tokio::test]
async fn test_clear_rate_limit_and_capacity() {
let limiter = RateLimiter::new();
let cid = chat_id(3);
let _ = limiter.check_rate_limit(&cid).await;
let cap = limiter.get_bucket_capacity(&cid).await;
assert!(cap.is_some());
limiter.clear_rate_limit(&cid).await;
let _ = limiter.check_rate_limit(&cid).await;
let cap2 = limiter.get_bucket_capacity(&cid).await;
assert_eq!(cap, cap2);
}
#[tokio::test]
async fn test_get_available_tokens() {
let limiter = RateLimiter::new();
let cid = chat_id(4);
let _ = limiter.check_rate_limit(&cid).await;
let tokens = limiter.get_available_tokens(&cid).await;
assert!(tokens.is_some());
}
#[tokio::test]
async fn test_get_global_and_chat_stats() {
let limiter = RateLimiter::new();
let cid = chat_id(5);
for _ in 0..3 {
let _ = limiter.check_rate_limit(&cid).await;
}
let global = limiter.get_global_stats().await;
let chat = limiter.get_chat_stats(&cid).await;
assert!(global.total_requests > 0);
assert!(chat.is_some());
assert!(chat.unwrap().total_requests > 0);
}
#[tokio::test]
async fn test_shutdown_and_post_shutdown_behavior() {
let limiter = RateLimiter::new();
let cid = chat_id(6);
let _ = limiter.check_rate_limit(&cid).await;
limiter.shutdown().await;
let tokens = limiter.get_available_tokens(&cid).await;
sleep(Duration::from_millis(1200)).await;
let tokens_after = limiter.get_available_tokens(&cid).await;
assert_eq!(tokens, tokens_after);
}
#[tokio::test]
async fn test_check_with_priority() {
let limiter = RateLimiter::new();
let cid = chat_id(7);
let mut allowed_high = 0;
let mut allowed_low = 0;
for _ in 0..20 {
if limiter.check_with_priority(&cid, 10).await {
allowed_high += 1;
}
}
for _ in 0..20 {
if limiter.check_with_priority(&cid, 1).await {
allowed_low += 1;
}
}
assert!(allowed_high >= allowed_low);
}
#[tokio::test]
async fn test_concurrent_access() {
let limiter = Arc::new(RateLimiter::new());
let mut handles = vec![];
for i in 0..10 {
let limiter = Arc::clone(&limiter);
let cid = chat_id(100 + i);
handles.push(tokio::spawn(async move {
for _ in 0..10 {
let _ = limiter.check_rate_limit(&cid).await;
}
}));
}
for h in handles {
h.await.unwrap();
}
assert!(limiter.active_bucket_count() >= 10);
}
}