use std::{
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
#[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
use crate::hybrid::HybridRateLimiterProvider;
use crate::{
HardLimitFactor, LocalRateLimiterOptions, LocalRateLimiterProvider, RateGroupSizeMs,
SuppressionFactorCacheMs, TrypemaError, WindowSizeSeconds,
};
#[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "redis-tokio", feature = "redis-smol"))))]
use crate::redis::{RedisKey, RedisRateLimiterOptions, RedisRateLimiterProvider};
#[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
use crate::hybrid::SyncIntervalMs;
#[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
use redis::aio::ConnectionManager;
/// Configuration consumed by [`RateLimiter::new`].
///
/// `local` configures the in-process provider. When the `redis-tokio` or
/// `redis-smol` feature is enabled, `redis` is used to construct both the
/// Redis provider and the hybrid provider.
#[derive(Clone, Debug)]
pub struct RateLimiterOptions {
    /// Options for the in-process (local) provider.
    pub local: LocalRateLimiterOptions,
    /// Options shared by the Redis and hybrid providers.
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    pub redis: RedisRateLimiterOptions,
}
#[cfg(not(any(feature = "redis-tokio", feature = "redis-smol")))]
impl Default for RateLimiterOptions {
    /// Local-only configuration built from [`LocalRateLimiterOptions::default`].
    fn default() -> Self {
        let local = LocalRateLimiterOptions::default();
        Self { local }
    }
}
/// Entry point that owns every rate-limiter provider enabled by the current
/// feature set. It can also run a background loop that periodically calls the
/// providers' `cleanup` methods.
///
/// The local provider is always available. The Redis and hybrid providers are
/// compiled in only with the `redis-tokio` or `redis-smol` feature.
pub struct RateLimiter {
    local: LocalRateLimiterProvider,
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    redis: RedisRateLimiterProvider,
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    hybrid: HybridRateLimiterProvider,
    /// `true` while a cleanup loop is considered active. Prevents the
    /// `run_cleanup_loop*` methods from starting a second one.
    is_loop_running: AtomicBool,
    /// Identifies the most recently started cleanup loop.
    ///
    /// Every start and every stop bumps this value. A loop keeps running only
    /// while the value still matches the one it was started with. Without it,
    /// `stop_cleanup_loop` followed by `run_cleanup_loop` could leave the old
    /// (still sleeping) loop alive next to the new one.
    loop_generation: std::sync::atomic::AtomicUsize,
}

impl Drop for RateLimiter {
    fn drop(&mut self) {
        // While idle, the loop only holds a `Weak` handle, so it would also exit
        // when its upgrade fails. Retiring it here makes the shutdown explicit.
        self.stop_cleanup_loop();
    }
}

impl RateLimiter {
    /// `stale_after_ms` used by [`RateLimiter::run_cleanup_loop`] (10 minutes).
    const DEFAULT_STALE_AFTER_MS: u64 = 10 * 60 * 1000;
    /// `cleanup_interval_ms` used by [`RateLimiter::run_cleanup_loop`] (30 seconds).
    const DEFAULT_CLEANUP_INTERVAL_MS: u64 = 30 * 1000;

    /// Creates a limiter from `options` without starting the cleanup loop.
    ///
    /// With the Redis features enabled, `options.redis` configures both the
    /// Redis provider and the hybrid provider.
    pub fn new(options: RateLimiterOptions) -> Self {
        Self {
            local: LocalRateLimiterProvider::new(options.local),
            #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
            redis: RedisRateLimiterProvider::new(options.redis.clone()),
            #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
            hybrid: HybridRateLimiterProvider::new(options.redis),
            is_loop_running: AtomicBool::new(false),
            loop_generation: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Starts the background cleanup loop with the default settings:
    /// `stale_after_ms` = 10 minutes and a 30 second interval.
    ///
    /// See [`RateLimiter::run_cleanup_loop_with_config`] for details.
    pub fn run_cleanup_loop(self: &Arc<Self>) {
        self.run_cleanup_loop_with_config(
            Self::DEFAULT_STALE_AFTER_MS,
            Self::DEFAULT_CLEANUP_INTERVAL_MS,
        );
    }

    /// Starts a background loop that, once every `cleanup_interval_ms`, passes
    /// `stale_after_ms` to each provider's `cleanup`.
    ///
    /// - Does nothing if a loop is already running.
    /// - Without the Redis features, the loop runs on a dedicated OS thread.
    /// - With the Redis features, it is spawned through
    ///   `crate::runtime::spawn_task`. Redis/hybrid cleanup errors are logged
    ///   with `tracing::warn!` and retried on the next pass.
    ///
    /// The first pass runs one interval after the call. Between passes the
    /// loop holds only a `Weak` reference, so it never keeps the limiter
    /// alive. It exits once the limiter is dropped or
    /// [`RateLimiter::stop_cleanup_loop`] is called.
    ///
    /// A `cleanup_interval_ms` of `0` is treated as `1` so the loop can never
    /// spin without pausing.
    pub fn run_cleanup_loop_with_config(
        self: &Arc<Self>,
        stale_after_ms: u64,
        cleanup_interval_ms: u64,
    ) {
        if self.is_loop_running.swap(true, Ordering::SeqCst) {
            return;
        }
        // Claim a new generation. Any earlier loop that is still asleep will see
        // the mismatch when it wakes up and exit.
        let generation = self
            .loop_generation
            .fetch_add(1, Ordering::SeqCst)
            .wrapping_add(1);
        let interval = Duration::from_millis(cleanup_interval_ms.max(1));
        let weak = Arc::downgrade(self);

        #[cfg(not(any(feature = "redis-tokio", feature = "redis-smol")))]
        {
            std::thread::spawn(move || loop {
                std::thread::sleep(interval);
                // Upgrade only after sleeping so that no strong reference is
                // held while idle. `limiter` is released at the end of this pass.
                let Some(limiter) = weak.upgrade() else {
                    break;
                };
                if !limiter.is_current_loop(generation) {
                    break;
                }
                limiter.local.cleanup(stale_after_ms);
            });
        }
        #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
        {
            crate::runtime::spawn_task(async move {
                let mut ticker = crate::runtime::new_interval(interval);
                // Consume the initial tick so the first pass waits one full
                // interval. Tokio-style intervals fire immediately on the first
                // tick. NOTE(review): confirm the smol runtime behaves the same.
                crate::runtime::tick(&mut ticker).await;
                loop {
                    crate::runtime::tick(&mut ticker).await;
                    let Some(limiter) = weak.upgrade() else {
                        break;
                    };
                    if !limiter.is_current_loop(generation) {
                        break;
                    }
                    limiter.local.cleanup(stale_after_ms);
                    if let Err(e) = limiter.redis.cleanup(stale_after_ms).await {
                        tracing::warn!(error = ?e, "Redis cleanup failed, will retry");
                    }
                    if let Err(e) = limiter.hybrid.cleanup(stale_after_ms).await {
                        tracing::warn!(error = ?e, "Hybrid cleanup failed, will retry");
                    }
                }
            });
        }
    }

    /// Stops the background cleanup loop, if one is running.
    ///
    /// The loop observes this on its next wake-up (at most one interval later)
    /// and does not start another pass. Calling a `run_cleanup_loop*` method
    /// afterwards starts a fresh loop; the retired one never resumes.
    pub fn stop_cleanup_loop(&self) {
        // Retire the current loop before allowing a new one to be started.
        self.loop_generation.fetch_add(1, Ordering::SeqCst);
        self.is_loop_running.store(false, Ordering::SeqCst);
    }

    /// Whether the loop started with `generation` should run another pass.
    fn is_current_loop(&self, generation: usize) -> bool {
        self.is_loop_running.load(Ordering::SeqCst)
            && self.loop_generation.load(Ordering::SeqCst) == generation
    }

    /// Returns the Redis-backed provider.
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "redis-tokio", feature = "redis-smol"))))]
    pub fn redis(&self) -> &RedisRateLimiterProvider {
        &self.redis
    }

    /// Returns the hybrid provider.
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "redis-tokio", feature = "redis-smol"))))]
    pub fn hybrid(&self) -> &HybridRateLimiterProvider {
        &self.hybrid
    }

    /// Returns the in-process provider.
    pub fn local(&self) -> &LocalRateLimiterProvider {
        &self.local
    }
}
/// Fluent builder for [`RateLimiter`].
///
/// The setters store raw values only. Validation happens when
/// [`RateLimiterBuilder::build`] converts them into their checked option types;
/// any conversion failure is reported as a [`TrypemaError`].
#[must_use = "a RateLimiterBuilder does nothing until `build()` is called"]
pub struct RateLimiterBuilder {
    /// Raw value later converted into `WindowSizeSeconds`.
    window_size_seconds: u64,
    /// Raw value later converted into `RateGroupSizeMs`.
    rate_group_size_ms: u64,
    /// Raw value later converted into `HardLimitFactor`.
    hard_limit_factor: f64,
    /// Raw value later converted into `SuppressionFactorCacheMs`.
    suppression_factor_cache_ms: u64,
    /// `stale_after_ms` handed to the cleanup loop started by `build`.
    stale_after_ms: u64,
    /// `cleanup_interval_ms` handed to the cleanup loop started by `build`.
    cleanup_interval_ms: u64,
    /// Connection used by the Redis-backed options.
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    connection_manager: ConnectionManager,
    /// Optional key prefix for the Redis-backed options (`None` unless set).
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    redis_prefix: Option<RedisKey>,
    /// Raw value later converted into `SyncIntervalMs`.
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    sync_interval_ms: u64,
}
#[cfg(not(any(feature = "redis-tokio", feature = "redis-smol")))]
impl Default for RateLimiterBuilder {
    /// Builder pre-filled with each option type's default value, plus a
    /// 10 minute `stale_after_ms` and a 30 second `cleanup_interval_ms`.
    fn default() -> Self {
        let window_size_seconds = *WindowSizeSeconds::default();
        let rate_group_size_ms = *RateGroupSizeMs::default();
        let hard_limit_factor = *HardLimitFactor::default();
        let suppression_factor_cache_ms = *SuppressionFactorCacheMs::default();
        Self {
            window_size_seconds,
            rate_group_size_ms,
            hard_limit_factor,
            suppression_factor_cache_ms,
            stale_after_ms: 600_000,     // 10 minutes
            cleanup_interval_ms: 30_000, // 30 seconds
        }
    }
}
#[cfg(not(any(feature = "redis-tokio", feature = "redis-smol")))]
impl RateLimiter {
pub fn builder() -> RateLimiterBuilder {
RateLimiterBuilder::default()
}
}
#[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "redis-tokio", feature = "redis-smol"))))]
impl RateLimiterBuilder {
    /// Creates a builder bound to `connection_manager`.
    ///
    /// Every other setting starts at its default: each option type's default
    /// value, a 10 minute `stale_after_ms`, a 30 second `cleanup_interval_ms`,
    /// and no Redis key prefix.
    fn new_with_connection_manager(connection_manager: ConnectionManager) -> Self {
        const TEN_MINUTES_MS: u64 = 10 * 60 * 1000;
        const THIRTY_SECONDS_MS: u64 = 30 * 1000;

        let window_size_seconds = *WindowSizeSeconds::default();
        let rate_group_size_ms = *RateGroupSizeMs::default();
        let hard_limit_factor = *HardLimitFactor::default();
        let suppression_factor_cache_ms = *SuppressionFactorCacheMs::default();
        let sync_interval_ms = *SyncIntervalMs::default();

        Self {
            window_size_seconds,
            rate_group_size_ms,
            hard_limit_factor,
            suppression_factor_cache_ms,
            stale_after_ms: TEN_MINUTES_MS,
            cleanup_interval_ms: THIRTY_SECONDS_MS,
            connection_manager,
            redis_prefix: None,
            sync_interval_ms,
        }
    }
}
#[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "redis-tokio", feature = "redis-smol"))))]
impl RateLimiter {
    /// Returns a [`RateLimiterBuilder`] bound to `connection_manager`.
    ///
    /// That connection ends up in the Redis options used to construct the
    /// Redis and hybrid providers. All other settings start at their defaults.
    pub fn builder(connection_manager: ConnectionManager) -> RateLimiterBuilder {
        RateLimiterBuilder::new_with_connection_manager(
            connection_manager,
        )
    }
}
impl RateLimiterBuilder {
    /// Sets the raw window size in seconds (validated by `build`).
    pub fn window_size_seconds(self, v: u64) -> Self {
        Self {
            window_size_seconds: v,
            ..self
        }
    }

    /// Sets the raw rate-group size in milliseconds (validated by `build`).
    pub fn rate_group_size_ms(self, v: u64) -> Self {
        Self {
            rate_group_size_ms: v,
            ..self
        }
    }

    /// Sets the raw hard-limit factor (validated by `build`).
    pub fn hard_limit_factor(self, v: f64) -> Self {
        Self {
            hard_limit_factor: v,
            ..self
        }
    }

    /// Sets the raw suppression-factor cache duration in milliseconds
    /// (validated by `build`).
    pub fn suppression_factor_cache_ms(self, v: u64) -> Self {
        Self {
            suppression_factor_cache_ms: v,
            ..self
        }
    }

    /// Sets the `stale_after_ms` passed to the cleanup loop started by `build`.
    pub fn stale_after_ms(self, v: u64) -> Self {
        Self {
            stale_after_ms: v,
            ..self
        }
    }

    /// Sets the interval of the cleanup loop started by `build`, in milliseconds.
    pub fn cleanup_interval_ms(self, v: u64) -> Self {
        Self {
            cleanup_interval_ms: v,
            ..self
        }
    }

    /// Sets the key prefix used by the Redis-backed options.
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "redis-tokio", feature = "redis-smol"))))]
    pub fn redis_prefix(self, v: RedisKey) -> Self {
        Self {
            redis_prefix: Some(v),
            ..self
        }
    }

    /// Sets the raw sync interval in milliseconds (validated by `build`).
    #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "redis-tokio", feature = "redis-smol"))))]
    pub fn sync_interval_ms(self, v: u64) -> Self {
        Self {
            sync_interval_ms: v,
            ..self
        }
    }

    /// Validates the collected settings, constructs the [`RateLimiter`], and
    /// starts its background cleanup loop with the configured
    /// `stale_after_ms` and `cleanup_interval_ms`.
    ///
    /// # Errors
    ///
    /// Returns the first [`TrypemaError`] produced while converting a raw
    /// setting into its validated option type. The local options are checked
    /// first, then (with the Redis features) the Redis options.
    pub fn build(self) -> Result<Arc<RateLimiter>, TrypemaError> {
        let local = LocalRateLimiterOptions {
            window_size_seconds: WindowSizeSeconds::try_from(self.window_size_seconds)?,
            rate_group_size_ms: RateGroupSizeMs::try_from(self.rate_group_size_ms)?,
            hard_limit_factor: HardLimitFactor::try_from(self.hard_limit_factor)?,
            suppression_factor_cache_ms: SuppressionFactorCacheMs::try_from(
                self.suppression_factor_cache_ms,
            )?,
        };

        #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
        let redis = RedisRateLimiterOptions {
            connection_manager: self.connection_manager,
            prefix: self.redis_prefix,
            window_size_seconds: WindowSizeSeconds::try_from(self.window_size_seconds)?,
            rate_group_size_ms: RateGroupSizeMs::try_from(self.rate_group_size_ms)?,
            hard_limit_factor: HardLimitFactor::try_from(self.hard_limit_factor)?,
            suppression_factor_cache_ms: SuppressionFactorCacheMs::try_from(
                self.suppression_factor_cache_ms,
            )?,
            sync_interval_ms: SyncIntervalMs::try_from(self.sync_interval_ms)?,
        };

        let options = RateLimiterOptions {
            local,
            #[cfg(any(feature = "redis-tokio", feature = "redis-smol"))]
            redis,
        };

        let limiter = Arc::new(RateLimiter::new(options));
        limiter.run_cleanup_loop_with_config(self.stale_after_ms, self.cleanup_interval_ms);
        Ok(limiter)
    }
}