cloudillo_core/rate_limit/
middleware.rs1use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use axum::body::Body;
9use axum::response::IntoResponse;
10use futures::future::BoxFuture;
11use hyper::Request;
12use tower::{Layer, Service};
13
14use super::extractors::extract_client_ip;
15use super::limiter::RateLimitManager;
16use crate::app::ServerMode;
17
18#[derive(Clone)]
20pub struct RateLimitLayer {
21 manager: Arc<RateLimitManager>,
22 category: &'static str,
23 mode: ServerMode,
24}
25
26impl RateLimitLayer {
27 pub fn new(manager: Arc<RateLimitManager>, category: &'static str, mode: ServerMode) -> Self {
29 Self { manager, category, mode }
30 }
31}
32
33impl<S> Layer<S> for RateLimitLayer {
34 type Service = RateLimitService<S>;
35
36 fn layer(&self, inner: S) -> Self::Service {
37 RateLimitService {
38 inner,
39 manager: self.manager.clone(),
40 category: self.category,
41 mode: self.mode,
42 }
43 }
44}
45
46#[derive(Clone)]
48pub struct RateLimitService<S> {
49 inner: S,
50 manager: Arc<RateLimitManager>,
51 category: &'static str,
52 mode: ServerMode,
53}
54
55impl<S> Service<Request<Body>> for RateLimitService<S>
56where
57 S: Service<Request<Body>, Response = axum::response::Response> + Clone + Send + 'static,
58 S::Future: Send + 'static,
59{
60 type Response = S::Response;
61 type Error = S::Error;
62 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
63
64 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65 self.inner.poll_ready(cx)
66 }
67
68 fn call(&mut self, req: Request<Body>) -> Self::Future {
69 let manager = self.manager.clone();
70 let category = self.category;
71 let mode = self.mode;
72 let mut inner = self.inner.clone();
73
74 Box::pin(async move {
75 let client_ip = extract_client_ip(&req, &mode);
77
78 if let Some(ip) = client_ip {
79 if let Err(error) = manager.check(&ip, category) {
81 return Ok(error.into_response());
83 }
84 }
85
86 inner.call(req).await
88 })
89 }
90}
91
92