use std::{
sync::{Arc, Weak},
time::Duration,
};
use dashmap::DashMap;
use moka::future::{Cache, CacheBuilder};
use tokio::{sync::Semaphore, time};
use super::types::RateLimiter;
/// Per-device bandwidth limiters with a background refill mechanism.
///
/// Clones share the same `active_limiters` registry through an `Arc`.
#[derive(Clone)]
pub struct RateLimiterCache {
    /// Sustained rate in KiB/s (multiplied by 1024 to get bytes); `0` disables limiting.
    rate_kbs: u64,
    /// Burst ceiling in KiB; when smaller than `rate_kbs`, 1.2 × the rate is used instead.
    burst_kbs: u64,
    /// Device id → limiter, with entries expiring after the configured time-to-live.
    limiters: Cache<String, Arc<RateLimiter>>,
    /// Weak handles to registered limiters; the refill task walks this map every second.
    active_limiters: Arc<DashMap<String, Weak<RateLimiter>>>,
}
impl RateLimiterCache {
pub fn new(
max_capacity: u64,
time_to_live: u64,
rate_kbs: u64,
burst_kbs: u64,
) -> Self {
let active_limiters = Arc::new(DashMap::new());
let active_limiters_clone = active_limiters.clone();
let limiters = CacheBuilder::new(max_capacity)
.time_to_live(Duration::from_secs(time_to_live))
.eviction_listener(move |key: Arc<String>, _value, _cause| {
active_limiters_clone.remove(key.as_ref());
})
.build();
Self {
limiters,
active_limiters,
rate_kbs,
burst_kbs,
}
}
pub async fn fetch_limiter(&self, device_id: &str) -> Arc<RateLimiter> {
if self.rate_kbs == 0 {
return Arc::new(RateLimiter {
semaphore: Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)),
});
}
self.limiters
.get_with(device_id.to_string(), async {
let bytes_per_sec = self
.rate_kbs
.checked_mul(1024)
.map(|v| v as usize)
.unwrap_or(Semaphore::MAX_PERMITS)
.min(Semaphore::MAX_PERMITS);
let limiter = Arc::new(RateLimiter {
semaphore: Arc::new(Semaphore::new(bytes_per_sec)),
});
self.active_limiters
.insert(device_id.to_string(), Arc::downgrade(&limiter));
limiter
})
.await
}
pub fn start_refill_task(&self) {
if self.rate_kbs == 0 {
return;
}
let active_limiters = self.active_limiters.clone();
let bytes_to_add_per_second = self
.rate_kbs
.checked_mul(1024)
.map(|v| v as usize)
.unwrap_or(Semaphore::MAX_PERMITS)
.min(Semaphore::MAX_PERMITS);
let max_permits =
if self.burst_kbs >= self.rate_kbs && self.rate_kbs > 0 {
self.burst_kbs
.checked_mul(1024)
.map(|v| v as usize)
.unwrap_or(Semaphore::MAX_PERMITS)
.min(Semaphore::MAX_PERMITS)
} else {
((bytes_to_add_per_second as f64 * 1.2) as usize)
.min(Semaphore::MAX_PERMITS)
};
tokio::spawn(async move {
let mut interval = time::interval(Duration::from_secs(1));
loop {
interval.tick().await;
active_limiters.retain(|_key, weak_limiter| {
if let Some(limiter) = weak_limiter.upgrade() {
let current_permits =
limiter.semaphore.available_permits();
if current_permits < max_permits {
let to_add = (max_permits - current_permits)
.min(bytes_to_add_per_second);
limiter.semaphore.add_permits(to_add);
}
true
} else {
false
}
});
}
});
}
pub fn get_limiters_count(&self) -> u64 {
self.limiters.entry_count()
}
}