use std::collections::HashMap;
use std::time::Duration;
use crate::storage::error::Result;
use crate::storage::{RateLimitOutcome, Storage};
/// Smallest retry-after hint (in seconds) ever reported for a throttled scope.
const RETRY_AFTER_FLOOR_SECS: f64 = 1.0_f64;
/// Largest retry-after hint (in seconds) ever reported for a throttled scope; also the
/// hint used when a scope has no usable configured refill rate.
const RETRY_AFTER_CEILING_SECS: f64 = 60.0_f64;
/// Result of asking the [`RateLimiter`] for permission to proceed in a scope.
///
/// `PartialEq`/`Eq` are derived so outcomes can be compared directly
/// (e.g. `assert_eq!(outcome, AcquireOutcome::Granted)`) instead of via `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AcquireOutcome {
    /// The storage-backed limiter granted the request.
    Granted,
    /// The storage-backed limiter refused the request.
    Throttled {
        /// Suggested wait before trying the same scope again.
        retry_after: Duration,
    },
}
/// Storage-backed rate limiter that turns a throttle decision into a retry-after hint.
///
/// Token accounting is delegated to `storage.rate_limit`; this type only remembers the
/// per-scope refill rates it was constructed with, so it can size the `retry_after` it
/// reports for throttled scopes.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    storage: Storage,
    /// Configured refill rate per second, keyed by scope, exactly as passed to
    /// [`RateLimiter::new`]. Values are not validated on insert; `retry_after` ignores
    /// non-finite or non-positive rates.
    refill_per_sec: HashMap<&'static str, f64>,
}

impl RateLimiter {
    /// Builds a limiter over `storage`, recording the refill rate of each default scope.
    ///
    /// `defaults` uses the same `(scope, capacity, refill_per_sec)` tuples as
    /// `ensure_default_rate_limits`; the capacity column is not used here. If a scope
    /// appears more than once, the last entry wins.
    #[must_use]
    pub fn new(storage: Storage, defaults: &[(&'static str, i64, f64)]) -> Self {
        let mut refill_per_sec = HashMap::with_capacity(defaults.len());
        for &(scope, _capacity, rate) in defaults {
            refill_per_sec.insert(scope, rate);
        }
        Self {
            storage,
            refill_per_sec,
        }
    }

    /// Asks storage for permission to proceed in `scope`.
    ///
    /// Returns [`AcquireOutcome::Granted`] when storage grants the request. Otherwise it
    /// returns [`AcquireOutcome::Throttled`], carrying a hint computed from the refill rate
    /// configured for `scope`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `storage.rate_limit.acquire`.
    pub async fn acquire(&self, scope: &str) -> Result<AcquireOutcome> {
        let decision = self.storage.rate_limit.acquire(scope).await?;
        let outcome = match decision {
            RateLimitOutcome::Granted => AcquireOutcome::Granted,
            RateLimitOutcome::Throttled => {
                let retry_after = self.retry_after(scope);
                AcquireOutcome::Throttled { retry_after }
            }
        };
        Ok(outcome)
    }

    /// Forwards to `storage.rate_limit.drain` for `scope`.
    ///
    /// NOTE(review): the concrete effect (presumably emptying the scope's bucket) is defined
    /// by the storage layer — confirm there.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the storage layer.
    pub async fn drain(&self, scope: &str) -> Result<()> {
        self.storage.rate_limit.drain(scope).await?;
        Ok(())
    }

    /// Suggested wait before retrying a throttled `scope`.
    ///
    /// If the configured rate `r` is finite and strictly positive, the unclamped hint is
    /// `2 / r` seconds, i.e. the time to refill two units at that rate. The factor of two is
    /// presumably deliberate headroom. Unknown scopes and unusable rates fall back to the
    /// ceiling. The result is always clamped to
    /// `[RETRY_AFTER_FLOOR_SECS, RETRY_AFTER_CEILING_SECS]`.
    fn retry_after(&self, scope: &str) -> Duration {
        let unclamped_secs = match self.refill_per_sec.get(scope) {
            Some(&rate) if rate.is_finite() && rate > 0.0 => 2.0 / rate,
            _ => RETRY_AFTER_CEILING_SECS,
        };
        // A tiny positive rate can make `2.0 / rate` overflow to +inf. The clamp maps that
        // to the ceiling, so `from_secs_f64` never sees a non-finite value.
        let secs = unclamped_secs.clamp(RETRY_AFTER_FLOOR_SECS, RETRY_AFTER_CEILING_SECS);
        Duration::from_secs_f64(secs)
    }
}
/// Calls `storage.rate_limit.ensure_default` once per `(scope, capacity, refill_per_sec)`
/// entry of `defaults`.
///
/// Entries are applied sequentially, in slice order. The first failure stops the loop and
/// is returned, so earlier entries may already have been applied. NOTE(review): whether an
/// existing limit for a scope is overwritten is decided by the storage layer — confirm there.
///
/// # Errors
///
/// Propagates the first error returned by `ensure_default`.
pub async fn ensure_default_rate_limits(
    storage: &Storage,
    defaults: &[(&'static str, i64, f64)],
) -> Result<()> {
    let rate_limits = &storage.rate_limit;
    for &(scope, capacity, refill_per_sec) in defaults {
        rate_limits
            .ensure_default(scope, capacity, refill_per_sec)
            .await?;
    }
    Ok(())
}
/// Built-in rate-limit defaults as `(scope, capacity, refill_per_sec)` tuples. This is the
/// shape accepted by [`RateLimiter::new`] and [`ensure_default_rate_limits`].
///
/// Each refill rate is written as `capacity / window_secs`, so a full capacity is
/// replenished over one window:
/// - `"slack"`: 50 per 60 s (presumably Slack's per-minute API limit — confirm).
/// - `"gh"`: 5000 per 3600 s (presumably GitHub's hourly REST API limit — confirm).
pub const DEFAULT_RATE_LIMIT_SCOPES: &[(&str, i64, f64)] = &[
    ("slack", 50, 50.0 / 60.0),
    ("gh", 5000, 5000.0 / 3600.0),
];
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::unnecessary_semicolon,
        clippy::unreadable_literal,
        reason = "unit tests crash loudly on setup/assert failures; that's the point"
    )]
    use std::sync::Arc;
    use super::*;
    use crate::storage::Storage as JobStorage;
    use crate::storage::sqlite::SqliteStorage;

    /// Opens a fresh in-memory SQLite storage so each test starts from an empty state.
    async fn fresh() -> JobStorage {
        let s = Arc::new(
            SqliteStorage::open_in_memory()
                .await
                .expect("open_in_memory"),
        );
        JobStorage::from_one(s)
    }

    #[tokio::test]
    async fn acquire_returns_granted_then_throttled_with_sane_retry_after() {
        let storage = fresh().await;
        ensure_default_rate_limits(&storage, &[("test", 1, 1.0)])
            .await
            .unwrap();
        let limiter = RateLimiter::new(storage, &[("test", 1, 1.0)]);
        let first = limiter.acquire("test").await.unwrap();
        assert!(
            matches!(first, AcquireOutcome::Granted),
            "first acquire grants"
        );
        let second = limiter.acquire("test").await.unwrap();
        match second {
            AcquireOutcome::Throttled { retry_after } => {
                assert_eq!(retry_after, Duration::from_secs(2));
            }
            AcquireOutcome::Granted => panic!("second acquire must throttle on capacity=1"),
        }
    }

    #[tokio::test]
    async fn unknown_scope_throttles_with_ceiling_retry_after() {
        let storage = fresh().await;
        let limiter = RateLimiter::new(storage, &[]);
        match limiter.acquire("ghost").await.unwrap() {
            AcquireOutcome::Throttled { retry_after } => {
                assert_eq!(retry_after, Duration::from_mins(1));
            }
            AcquireOutcome::Granted => panic!("unknown scope must throttle"),
        }
    }

    #[tokio::test]
    async fn retry_after_clamps_below_floor_and_above_ceiling() {
        let storage = fresh().await;
        let defaults: &[(&'static str, i64, f64)] = &[("fast", 1, 1000.0), ("slow", 1, 0.01)];
        ensure_default_rate_limits(&storage, defaults).await.unwrap();
        let limiter = RateLimiter::new(storage, defaults);

        // Pin the clamping arithmetic directly, independent of wall-clock timing:
        // 2 / 1000 s clamps up to the floor and 2 / 0.01 s clamps down to the ceiling.
        assert_eq!(limiter.retry_after("fast"), Duration::from_secs(1), "must clamp to floor");
        assert_eq!(limiter.retry_after("slow"), Duration::from_mins(1), "must clamp to ceiling");

        // At 1000 tokens/s a token may refill between back-to-back calls, so a second
        // grant on "fast" is acceptable; only a throttle's hint is checked.
        assert!(matches!(limiter.acquire("fast").await.unwrap(), AcquireOutcome::Granted));
        match limiter.acquire("fast").await.unwrap() {
            AcquireOutcome::Throttled { retry_after } => {
                assert_eq!(retry_after, Duration::from_secs(1), "must clamp to floor");
            }
            AcquireOutcome::Granted => {}
        }

        // At 0.01 tokens/s a whole token cannot refill between back-to-back calls, so the
        // second acquire on "slow" must throttle. Accepting a grant here would silently
        // skip the ceiling assertion.
        assert!(matches!(limiter.acquire("slow").await.unwrap(), AcquireOutcome::Granted));
        match limiter.acquire("slow").await.unwrap() {
            AcquireOutcome::Throttled { retry_after } => {
                assert_eq!(retry_after, Duration::from_mins(1), "must clamp to ceiling");
            }
            AcquireOutcome::Granted => panic!("slow scope must throttle on capacity=1"),
        }
    }

    #[tokio::test]
    async fn retry_after_falls_back_to_ceiling_for_unusable_rates() {
        let limiter = RateLimiter::new(
            fresh().await,
            &[
                ("zero", 1, 0.0),
                ("neg", 1, -1.0),
                ("nan", 1, f64::NAN),
                ("inf", 1, f64::INFINITY),
            ],
        );
        for scope in ["zero", "neg", "nan", "inf"] {
            assert_eq!(
                limiter.retry_after(scope),
                Duration::from_mins(1),
                "scope {scope} must fall back to the ceiling"
            );
        }
    }
}