use async_trait::async_trait;
use std::sync::Arc;
use governor::{Quota, RateLimiter};
use governor::state::{NotKeyed, InMemoryState};
use governor::clock::DefaultClock;
use crate::error::{WorkerError, WorkerResult};
use crate::message::ReceivedMessage;
use crate::middleware::{MessageHandler, Middleware};
/// Middleware that consults a rate limiter before passing a message on to the
/// next handler in the chain.
pub struct RateLimitMiddleware {
/// Un-keyed, in-memory `governor` limiter (a single limit, not one per key),
/// held behind an `Arc`.
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
/// Name reported by `Middleware::name` and embedded in the rate-limit error message.
name: String,
}
/// Hand-written `Debug`: only the `name` field is printed; the `limiter`
/// field is left out of the output.
impl std::fmt::Debug for RateLimitMiddleware {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = formatter.debug_struct("RateLimitMiddleware");
        repr.field("name", &self.name);
        repr.finish()
    }
}
impl RateLimitMiddleware {
    /// Creates a middleware that admits `rate` messages per second, with at
    /// most `burst` messages admitted back-to-back.
    ///
    /// Out-of-range inputs are clamped instead of rejected:
    /// * a `rate` or `burst` of `0` is treated as `1`;
    /// * a `rate` above `u32::MAX` saturates at `u32::MAX`. (A plain `as u32`
    ///   cast would wrap, so e.g. `1 << 32` would end up as 1 request per
    ///   second.)
    ///
    /// The middleware name embeds `rate` exactly as passed in, even when the
    /// effective rate was clamped.
    pub fn new(rate: u64, burst: u32) -> Self {
        let per_second = Self::clamp_to_non_zero(rate);
        let max_burst = Self::clamp_to_non_zero(u64::from(burst));
        let quota = Quota::per_second(per_second).allow_burst(max_burst);

        Self {
            limiter: Arc::new(RateLimiter::direct(quota)),
            name: format!("rate-limit-{}rps", rate),
        }
    }

    /// Creates a middleware whose burst size equals its per-second rate.
    ///
    /// `rate` is clamped into `1..=u32::MAX` for both the rate and the burst,
    /// see [`RateLimitMiddleware::new`].
    pub fn with_rate(rate: u64) -> Self {
        Self::new(rate, Self::clamp_to_non_zero(rate).get())
    }

    /// Clamps `value` into `1..=u32::MAX` and returns it as a `NonZeroU32`.
    fn clamp_to_non_zero(value: u64) -> std::num::NonZeroU32 {
        // Clamp before narrowing so the cast below cannot truncate.
        let clamped = value.clamp(1, u64::from(u32::MAX)) as u32;
        std::num::NonZeroU32::new(clamped)
            .expect("value was clamped into 1..=u32::MAX, so it is non-zero")
    }
}
#[async_trait]
impl Middleware for RateLimitMiddleware {
    /// Returns the name stored on this middleware.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Checks the limiter once and either forwards `message` to `next` or
    /// rejects it.
    ///
    /// # Errors
    ///
    /// Returns `WorkerError::ProcessingFailed` when the limiter check fails;
    /// in that case `next` is never invoked. Otherwise returns whatever
    /// `next.handle` returns.
    async fn handle(
        &self,
        message: ReceivedMessage<serde_json::Value>,
        next: Box<dyn MessageHandler>,
    ) -> WorkerResult<()> {
        // Guard clause: bail out immediately when the limiter refuses the message.
        if self.limiter.check().is_err() {
            return Err(WorkerError::ProcessingFailed(format!(
                "Rate limit exceeded for middleware '{}'",
                self.name
            )));
        }

        next.handle(message).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::Arc;
    use tokio::time::{self, Duration};

    /// Downstream handler that accepts every message.
    struct SuccessHandler;

    #[async_trait]
    impl MessageHandler for SuccessHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Builds a fixed test message whose ack handle succeeds on both `ack` and `nack`.
    fn create_test_message() -> ReceivedMessage<serde_json::Value> {
        #[derive(Debug)]
        struct MockAckHandle;

        #[async_trait]
        impl AckHandle for MockAckHandle {
            async fn ack(&self) -> WorkerResult<()> {
                Ok(())
            }

            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
                Ok(())
            }
        }

        ReceivedMessage::new(
            Message {
                id: String::from("test-1"),
                payload: serde_json::json!({"test": "data"}),
                metadata: MessageMetadata::new("test-queue"),
            },
            Arc::new(MockAckHandle),
        )
    }

    /// Sends one fresh test message through `middleware` into a `SuccessHandler`.
    async fn send_one(middleware: &RateLimitMiddleware) -> WorkerResult<()> {
        let message = create_test_message();
        middleware.handle(message, Box::new(SuccessHandler)).await
    }

    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        let burst: u32 = 10;
        let middleware = RateLimitMiddleware::new(10, burst);

        // Every message within the burst allowance must be accepted.
        for _ in 0..burst {
            assert!(send_one(&middleware).await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_rate_limit_rejects_over_limit() {
        let burst: u32 = 5;
        let middleware = RateLimitMiddleware::new(10, burst);

        // Use up the whole burst allowance.
        for _ in 0..burst {
            send_one(&middleware).await.unwrap();
        }

        // The next message must be rejected with the rate-limit error.
        let outcome = send_one(&middleware).await;
        assert!(outcome.is_err());
        assert!(
            matches!(outcome, Err(WorkerError::ProcessingFailed(_))),
            "Expected ProcessingFailed (rate limited)"
        );
    }

    #[tokio::test]
    async fn test_rate_limit_with_custom_burst() {
        let burst: u32 = 20;
        let middleware = RateLimitMiddleware::new(100, burst);

        for _ in 0..burst {
            assert!(send_one(&middleware).await.is_ok());
        }

        // Two further messages: the first must fail, the second must fail
        // with the rate-limit error variant specifically.
        assert!(send_one(&middleware).await.is_err());
        assert!(
            matches!(
                send_one(&middleware).await,
                Err(WorkerError::ProcessingFailed(_))
            ),
            "Expected ProcessingFailed (rate limited)"
        );
    }

    #[tokio::test]
    async fn test_rate_limit_refills_over_time() {
        let middleware = RateLimitMiddleware::new(1, 1);

        // First message is accepted, the immediate follow-up is rejected.
        assert!(send_one(&middleware).await.is_ok());
        assert!(send_one(&middleware).await.is_err());

        // Wait a little over one second (rate is 1 per second) before retrying.
        time::sleep(Duration::from_millis(1100)).await;
        assert!(send_one(&middleware).await.is_ok());
    }
}