Skip to main content

cloudillo_core/rate_limit/
middleware.rs

1//! Rate Limiting Middleware
2//!
3//! Tower middleware layer for applying rate limits to Axum routes.
4
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use axum::body::Body;
9use axum::response::IntoResponse;
10use futures::future::BoxFuture;
11use hyper::Request;
12use tower::{Layer, Service};
13
14use super::extractors::extract_client_ip;
15use super::limiter::RateLimitManager;
16use crate::app::ServerMode;
17
18/// Rate limit middleware layer
19#[derive(Clone)]
20pub struct RateLimitLayer {
21	manager: Arc<RateLimitManager>,
22	category: &'static str,
23	mode: ServerMode,
24}
25
26impl RateLimitLayer {
27	/// Create a new rate limit layer
28	pub fn new(manager: Arc<RateLimitManager>, category: &'static str, mode: ServerMode) -> Self {
29		Self { manager, category, mode }
30	}
31}
32
33impl<S> Layer<S> for RateLimitLayer {
34	type Service = RateLimitService<S>;
35
36	fn layer(&self, inner: S) -> Self::Service {
37		RateLimitService {
38			inner,
39			manager: self.manager.clone(),
40			category: self.category,
41			mode: self.mode,
42		}
43	}
44}
45
46/// Rate limit middleware service
47#[derive(Clone)]
48pub struct RateLimitService<S> {
49	inner: S,
50	manager: Arc<RateLimitManager>,
51	category: &'static str,
52	mode: ServerMode,
53}
54
55impl<S> Service<Request<Body>> for RateLimitService<S>
56where
57	S: Service<Request<Body>, Response = axum::response::Response> + Clone + Send + 'static,
58	S::Future: Send + 'static,
59{
60	type Response = S::Response;
61	type Error = S::Error;
62	type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
63
64	fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65		self.inner.poll_ready(cx)
66	}
67
68	fn call(&mut self, req: Request<Body>) -> Self::Future {
69		let manager = self.manager.clone();
70		let category = self.category;
71		let mode = self.mode;
72		let mut inner = self.inner.clone();
73
74		Box::pin(async move {
75			// Extract client IP
76			let client_ip = extract_client_ip(&req, &mode);
77
78			if let Some(ip) = client_ip {
79				// Check rate limit
80				if let Err(error) = manager.check(&ip, category) {
81					// Rate limited - return error response
82					return Ok(error.into_response());
83				}
84			}
85
86			// Not rate limited - proceed with request
87			inner.call(req).await
88		})
89	}
90}
91
92// vim: ts=4