use std::num::NonZeroU32;
use std::sync::Arc;
use async_trait::async_trait;
use governor::clock::DefaultClock;
use governor::middleware::NoOpMiddleware;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter};
use super::Middleware;
use crate::error::KojinError;
use crate::message::TaskMessage;
/// Unkeyed limiter: a single bucket that every caller draws from.
type DirectLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;

/// Keyed limiter: one independent bucket per `String` key.
///
/// The middleware (see `before`) uses the task name as the key.
type KeyedLimiter = RateLimiter<
    String,
    governor::state::keyed::DefaultKeyedStateStore<String>,
    DefaultClock,
    NoOpMiddleware,
>;
/// Middleware that throttles tasks with a [`governor`] rate limiter.
///
/// Construct it with `global`/`per_second` to get one quota shared by every task. Use
/// `per_task`/`per_second_per_task` to give each task name its own quota.
///
/// The limiter state lives behind an [`Arc`], so cloning is cheap. Every clone draws
/// from the *same* buckets, which lets one limit be shared across several registrations.
#[derive(Clone)]
pub struct RateLimitMiddleware {
    inner: Arc<RateLimitInner>,
}
/// The limiting strategy that a `RateLimitMiddleware` applies.
enum RateLimitInner {
    /// One bucket consumed by every task, regardless of its name.
    Global(DirectLimiter),
    /// One bucket per distinct `task_name` seen by the middleware.
    ///
    /// NOTE(review): governor's keyed store keeps an entry for every key it has
    /// seen, unless `retain_recent` is called. Confirm that task names form a
    /// bounded set; if they don't, periodically prune the store.
    PerTask {
        limiter: KeyedLimiter,
    },
}
impl RateLimitMiddleware {
pub fn global(quota: Quota) -> Self {
Self {
inner: Arc::new(RateLimitInner::Global(RateLimiter::direct(quota))),
}
}
pub fn per_task(quota: Quota) -> Self {
Self {
inner: Arc::new(RateLimitInner::PerTask {
limiter: RateLimiter::keyed(quota),
}),
}
}
pub fn per_second(n: NonZeroU32) -> Self {
Self::global(Quota::per_second(n))
}
pub fn per_second_per_task(n: NonZeroU32) -> Self {
Self::per_task(Quota::per_second(n))
}
}
#[async_trait]
impl Middleware for RateLimitMiddleware {
async fn before(&self, message: &TaskMessage) -> Result<(), KojinError> {
match self.inner.as_ref() {
RateLimitInner::Global(limiter) => {
limiter.until_ready().await;
}
RateLimitInner::PerTask { limiter } => {
limiter.until_key_ready(&message.task_name).await;
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    #[tokio::test]
    async fn global_rate_limit_applies_backpressure() {
        let mw = RateLimitMiddleware::per_second(NonZeroU32::new(10).unwrap());
        let msg = TaskMessage::new("test", "default", serde_json::json!({}));

        // The first 10 calls consume the initial burst and should not wait.
        let start = Instant::now();
        for _ in 0..10 {
            mw.before(&msg).await.unwrap();
        }
        let burst_elapsed = start.elapsed();

        // The 11th call must wait for a cell to replenish (roughly 100ms at 10/s).
        let wait_start = Instant::now();
        mw.before(&msg).await.unwrap();
        let wait_elapsed = wait_start.elapsed();

        assert!(
            burst_elapsed < Duration::from_millis(50),
            "burst should be fast, took {:?}",
            burst_elapsed
        );
        assert!(
            wait_elapsed >= Duration::from_millis(50),
            "should have waited, only took {:?}",
            wait_elapsed
        );
    }

    #[tokio::test]
    async fn per_task_limits_are_independent() {
        let mw = RateLimitMiddleware::per_second_per_task(NonZeroU32::new(5).unwrap());
        let msg_a = TaskMessage::new("task_a", "default", serde_json::json!({}));
        let msg_b = TaskMessage::new("task_b", "default", serde_json::json!({}));

        // Exhaust task_a's burst.
        for _ in 0..5 {
            mw.before(&msg_a).await.unwrap();
        }

        // task_b has its own bucket, so it must still pass immediately.
        let start = Instant::now();
        mw.before(&msg_b).await.unwrap();
        assert!(
            start.elapsed() < Duration::from_millis(50),
            "task_b should not be blocked by task_a"
        );

        // task_a's bucket is empty, so its next call must wait (roughly 200ms at 5/s).
        // Without this check, the test would pass even if no limiting happened at all.
        let start = Instant::now();
        mw.before(&msg_a).await.unwrap();
        let waited = start.elapsed();
        assert!(
            waited >= Duration::from_millis(100),
            "task_a should be throttled after its burst, only took {:?}",
            waited
        );
    }

    #[tokio::test]
    async fn rate_limit_with_quota_constructors() {
        let global = RateLimitMiddleware::global(Quota::per_second(NonZeroU32::new(100).unwrap()));
        let per_task =
            RateLimitMiddleware::per_task(Quota::per_minute(NonZeroU32::new(60).unwrap()));
        let msg = TaskMessage::new("test", "default", serde_json::json!({}));
        global.before(&msg).await.unwrap();
        per_task.before(&msg).await.unwrap();
    }
}