cloudillo_core/rate_limit/
middleware.rs1use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use axum::body::Body;
12use axum::response::IntoResponse;
13use futures::future::BoxFuture;
14use hyper::Request;
15use tower::{Layer, Service};
16
17use super::extractors::extract_client_ip;
18use super::limiter::RateLimitManager;
19use crate::app::ServerMode;
20
21#[derive(Clone)]
23pub struct RateLimitLayer {
24 manager: Arc<RateLimitManager>,
25 category: &'static str,
26 mode: ServerMode,
27}
28
29impl RateLimitLayer {
30 pub fn new(manager: Arc<RateLimitManager>, category: &'static str, mode: ServerMode) -> Self {
32 Self { manager, category, mode }
33 }
34}
35
36impl<S> Layer<S> for RateLimitLayer {
37 type Service = RateLimitService<S>;
38
39 fn layer(&self, inner: S) -> Self::Service {
40 RateLimitService {
41 inner,
42 manager: self.manager.clone(),
43 category: self.category,
44 mode: self.mode,
45 }
46 }
47}
48
49#[derive(Clone)]
51pub struct RateLimitService<S> {
52 inner: S,
53 manager: Arc<RateLimitManager>,
54 category: &'static str,
55 mode: ServerMode,
56}
57
58impl<S> Service<Request<Body>> for RateLimitService<S>
59where
60 S: Service<Request<Body>, Response = axum::response::Response> + Clone + Send + 'static,
61 S::Future: Send + 'static,
62{
63 type Response = S::Response;
64 type Error = S::Error;
65 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
66
67 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
68 self.inner.poll_ready(cx)
69 }
70
71 fn call(&mut self, req: Request<Body>) -> Self::Future {
72 let manager = self.manager.clone();
73 let category = self.category;
74 let mode = self.mode;
75 let mut inner = self.inner.clone();
76
77 Box::pin(async move {
78 let client_ip = extract_client_ip(&req, &mode);
80
81 if let Some(ip) = client_ip {
82 if let Err(error) = manager.check(&ip, category) {
84 return Ok(error.into_response());
86 }
87 }
88
89 inner.call(req).await
91 })
92 }
93}
94
95