Skip to main content

cloudillo_core/rate_limit/
middleware.rs

// SPDX-FileCopyrightText: Szilárd Hajba
// SPDX-License-Identifier: LGPL-3.0-or-later

//! Rate Limiting Middleware
//!
//! Tower middleware layer for applying rate limits to Axum routes.
7
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use axum::body::Body;
12use axum::response::IntoResponse;
13use futures::future::BoxFuture;
14use hyper::Request;
15use tower::{Layer, Service};
16
17use super::extractors::extract_client_ip;
18use super::limiter::RateLimitManager;
19use crate::app::ServerMode;
20
21/// Rate limit middleware layer
22#[derive(Clone)]
23pub struct RateLimitLayer {
24	manager: Arc<RateLimitManager>,
25	category: &'static str,
26	mode: ServerMode,
27}
28
29impl RateLimitLayer {
30	/// Create a new rate limit layer
31	pub fn new(manager: Arc<RateLimitManager>, category: &'static str, mode: ServerMode) -> Self {
32		Self { manager, category, mode }
33	}
34}
35
36impl<S> Layer<S> for RateLimitLayer {
37	type Service = RateLimitService<S>;
38
39	fn layer(&self, inner: S) -> Self::Service {
40		RateLimitService {
41			inner,
42			manager: self.manager.clone(),
43			category: self.category,
44			mode: self.mode,
45		}
46	}
47}
48
49/// Rate limit middleware service
50#[derive(Clone)]
51pub struct RateLimitService<S> {
52	inner: S,
53	manager: Arc<RateLimitManager>,
54	category: &'static str,
55	mode: ServerMode,
56}
57
58impl<S> Service<Request<Body>> for RateLimitService<S>
59where
60	S: Service<Request<Body>, Response = axum::response::Response> + Clone + Send + 'static,
61	S::Future: Send + 'static,
62{
63	type Response = S::Response;
64	type Error = S::Error;
65	type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
66
67	fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
68		self.inner.poll_ready(cx)
69	}
70
71	fn call(&mut self, req: Request<Body>) -> Self::Future {
72		let manager = self.manager.clone();
73		let category = self.category;
74		let mode = self.mode;
75		let mut inner = self.inner.clone();
76
77		Box::pin(async move {
78			// Extract client IP
79			let client_ip = extract_client_ip(&req, &mode);
80
81			if let Some(ip) = client_ip {
82				// Check rate limit
83				if let Err(error) = manager.check(&ip, category) {
84					// Rate limited - return error response
85					return Ok(error.into_response());
86				}
87			}
88
89			// Not rate limited - proceed with request
90			inner.call(req).await
91		})
92	}
93}
94
// vim: ts=4