use crate::metrics::Metrics;
use crate::types::{ThrottleRequest, ThrottleResponse};
use anyhow::Result;
use std::sync::Arc;
use throttlecrab::{AdaptiveStore, CellError, PeriodicStore, ProbabilisticStore, RateLimiter};
use tokio::sync::{mpsc, oneshot};
/// Messages accepted by the rate limiter actor's event loop.
pub enum RateLimiterMessage {
    /// Ask the actor to run a rate-limit check and reply with the outcome.
    Throttle {
        /// The request to evaluate.
        request: ThrottleRequest,
        /// One-shot channel on which the actor sends back the check's result.
        response_tx: oneshot::Sender<Result<ThrottleResponse>>,
    },
}
/// Cloneable front end for submitting requests to a running rate limiter actor.
///
/// All clones share the same underlying channel sender and the same `metrics` `Arc`.
#[derive(Clone)]
pub struct RateLimiterHandle {
    tx: mpsc::Sender<RateLimiterMessage>,
    // NOTE(review): nothing in this file reads `metrics`, which is presumably why the
    // lint is silenced; confirm whether other modules use it before removing the attribute.
    #[allow(dead_code)]
    pub metrics: Arc<Metrics>,
}

impl RateLimiterHandle {
    /// Sends `request` to the actor and waits for its reply.
    ///
    /// Sending waits for free buffer space if the actor's bounded channel is full.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// - the actor is no longer receiving messages;
    /// - the actor drops the reply channel without answering;
    /// - the actor answers with an error from the rate-limit check.
    pub async fn throttle(&self, request: ThrottleRequest) -> Result<ThrottleResponse> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let message = RateLimiterMessage::Throttle {
            request,
            response_tx: reply_tx,
        };

        if self.tx.send(message).await.is_err() {
            return Err(anyhow::anyhow!("Rate limiter actor has shut down"));
        }

        match reply_rx.await {
            // The actor's own Ok/Err verdict is passed through untouched.
            Ok(outcome) => outcome,
            Err(_) => Err(anyhow::anyhow!("Rate limiter actor dropped response channel")),
        }
    }
}
pub struct RateLimiterActor;
impl RateLimiterActor {
pub fn spawn_periodic(
buffer_size: usize,
store: PeriodicStore,
metrics: Arc<Metrics>,
) -> RateLimiterHandle {
let (tx, rx) = mpsc::channel(buffer_size);
let metrics_clone = Arc::clone(&metrics);
tokio::spawn(async move {
let store_type = StoreType::Periodic(RateLimiter::new(store));
run_actor(rx, store_type, metrics_clone).await;
});
RateLimiterHandle { tx, metrics }
}
pub fn spawn_probabilistic(
buffer_size: usize,
store: ProbabilisticStore,
metrics: Arc<Metrics>,
) -> RateLimiterHandle {
let (tx, rx) = mpsc::channel(buffer_size);
let metrics_clone = Arc::clone(&metrics);
tokio::spawn(async move {
let store_type = StoreType::Probabilistic(RateLimiter::new(store));
run_actor(rx, store_type, metrics_clone).await;
});
RateLimiterHandle { tx, metrics }
}
pub fn spawn_adaptive(
buffer_size: usize,
store: AdaptiveStore,
metrics: Arc<Metrics>,
) -> RateLimiterHandle {
let (tx, rx) = mpsc::channel(buffer_size);
let metrics_clone = Arc::clone(&metrics);
tokio::spawn(async move {
let store_type = StoreType::Adaptive(RateLimiter::new(store));
run_actor(rx, store_type, metrics_clone).await;
});
RateLimiterHandle { tx, metrics }
}
}
/// The concrete limiter owned by the actor, with one variant per supported store backend.
///
/// Dispatch to the wrapped limiter is a plain `match`, with one arm per variant.
enum StoreType {
    Periodic(RateLimiter<PeriodicStore>),
    Probabilistic(RateLimiter<ProbabilisticStore>),
    Adaptive(RateLimiter<AdaptiveStore>),
}

impl StoreType {
    /// Forwards a rate-limit check to whichever store-specific limiter this value holds.
    ///
    /// Every argument is passed through unchanged to `RateLimiter::rate_limit`. Its
    /// `(allowed, result)` tuple or [`CellError`] is returned as-is.
    fn rate_limit(
        &mut self,
        key: &str,
        max_burst: i64,
        count_per_period: i64,
        period: i64,
        quantity: i64,
        timestamp: std::time::SystemTime,
    ) -> Result<(bool, throttlecrab::RateLimitResult), CellError> {
        match self {
            Self::Periodic(inner) => {
                inner.rate_limit(key, max_burst, count_per_period, period, quantity, timestamp)
            }
            Self::Probabilistic(inner) => {
                inner.rate_limit(key, max_burst, count_per_period, period, quantity, timestamp)
            }
            Self::Adaptive(inner) => {
                inner.rate_limit(key, max_burst, count_per_period, period, quantity, timestamp)
            }
        }
    }
}
/// The actor's event loop: handles messages from `rx` one at a time until the channel
/// yields `None` (all senders dropped), then logs that it is shutting down.
///
/// Requests are processed sequentially against `store_type`, which only this task holds.
/// `_metrics` is accepted but not currently used.
async fn run_actor(
    mut rx: mpsc::Receiver<RateLimiterMessage>,
    mut store_type: StoreType,
    _metrics: Arc<Metrics>,
) {
    loop {
        let message = match rx.recv().await {
            Some(message) => message,
            // The channel is closed and drained; no further requests can arrive.
            None => break,
        };

        match message {
            RateLimiterMessage::Throttle {
                request,
                response_tx: reply_to,
            } => {
                let outcome = handle_throttle(&mut store_type, request);
                // The requester may have stopped waiting (e.g. its future was dropped).
                // A failed reply is harmless, so it is deliberately ignored.
                let _ = reply_to.send(outcome);
            }
        }
    }
    tracing::info!("Rate limiter actor shutting down");
}
/// Runs one rate-limit check for `request` against `store_type`.
///
/// The limiter's `(allowed, result)` pair is converted into a [`ThrottleResponse`].
///
/// # Errors
///
/// Returns an `anyhow` error wrapping the message of any [`CellError`] reported by the
/// limiter.
fn handle_throttle(
    store_type: &mut StoreType,
    request: ThrottleRequest,
) -> Result<ThrottleResponse> {
    store_type
        .rate_limit(
            &request.key,
            request.max_burst,
            request.count_per_period,
            request.period,
            request.quantity,
            request.timestamp,
        )
        .map_err(|cell_err| anyhow::anyhow!("Rate limit check failed: {}", cell_err))
        .map(ThrottleResponse::from)
}