actix_rl/
middleware.rs

1use std::rc::Rc;
2use std::sync::Arc;
3use actix_web::body::{BoxBody, EitherBody, MessageBody};
4use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
5use futures_util::future::{LocalBoxFuture, Ready, ready};
6use crate::controller::{Controller, default_do_rate_limit, default_on_rate_limit_error, default_on_store_error};
7use crate::error::Error;
8use crate::store::{Store, Value};
9use crate::utils::RateLimitByPass;
10
11/// alias of [RateLimit]
12pub type RateLimitMiddleware<T, CB> = RateLimit<T, CB>;
13
14/// [RateLimit] is the rate-limit middleware.
15///
16/// Params [T]: the [Store];
17///
18/// Params [CB]: the response body for [Controller]. (Controller.Body)
19#[derive(Clone)]
20pub struct RateLimit<T: Store, CB: MessageBody = BoxBody> {
21    inner: Arc<RateLimitInner<T, CB>>,
22}
23
24#[derive(Clone)]
25struct RateLimitInner<T: Store, CB: MessageBody = BoxBody> {
26    pub store: T,
27    pub max: <<T as Store>::Value as Value>::Count,
28    pub controller: Controller<T, CB>,
29}
30
31impl<T, CB, S, B> Transform<S, ServiceRequest> for RateLimit<T, CB>
32    where
33        T: Store + 'static,
34        CB: MessageBody + 'static,
35        S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
36        S::Future: 'static,
37        B: 'static,
38        <T as Store>::Key: 'static,
39{
40    type Response = ServiceResponse<EitherBody<B, EitherBody<BoxBody, CB>>>;
41    type Error = S::Error;
42    type Transform = RateLimitService<T, CB, S>;
43    type InitError = ();
44    type Future = Ready<Result<Self::Transform, Self::InitError>>;
45
46    fn new_transform(&self, service: S) -> Self::Future {
47        ready(Ok(RateLimitService {
48            inner: self.inner.clone(),
49            service: Rc::new(service),
50        }))
51    }
52}
53
54#[derive(Clone)]
55pub struct RateLimitService<T, CB, S>
56    where
57        T: Store,
58        CB: MessageBody,
59{
60    inner: Arc<RateLimitInner<T, CB>>,
61    service: Rc<S>,
62}
63
64impl<T, CB, S, B> Service<ServiceRequest> for RateLimitService<T, CB, S>
65    where
66        T: Store + 'static,
67        CB: MessageBody + 'static,
68        S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
69        S::Future: 'static,
70        B: 'static,
71        <T as Store>::Key: 'static,
72{
73    type Response = ServiceResponse<EitherBody<B, EitherBody<BoxBody, CB>>>;
74    type Error = S::Error;
75    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
76
77    forward_ready!(service);
78
79    fn call(&self, svc: ServiceRequest) -> Self::Future {
80        let service = self.service.clone();
81        let inner = self.inner.clone();
82
83        Box::pin(async move {
84            let checked = RateLimitByPass::<T>::checked(svc.request());
85            let do_rate_limit = !checked && if let Some(f) = &inner.controller.fn_do_rate_limit {
86                f(svc.request())
87            } else {
88                // use default function
89                default_do_rate_limit(svc.request())
90            };
91
92            let mut rate_limit_value = None;
93
94            if do_rate_limit {
95                // get identifier of this request
96                let identifier = inner.controller.fn_find_identifier.as_ref()
97                    .map(|f| f(svc.request()));
98
99                if let Some(identifier) = identifier { // continue only when identifier is found.
100                    let req = svc.request();
101                    match inner.store.incr(identifier).await {
102                        Err(e) => {
103                            // store error occur
104                            return if let Some(f) = &inner.controller.fn_on_store_error {
105                                let body = f(req, e);
106                                Ok(ServiceResponse::new(
107                                    req.clone(),
108                                    body.map_into_right_body().map_into_right_body(),
109                                ))
110                            } else {
111                                let body = default_on_store_error::<T>(req, e);
112                                Ok(ServiceResponse::new(
113                                    req.clone(),
114                                    body.map_into_left_body().map_into_right_body(),
115                                ))
116                            }
117
118                        },
119                        Ok(value) => {
120                            if value.count() > inner.max {
121                                // rate limit error occur
122                                let err = Error::RateLimited(value.expire_date());
123
124                                return if let Some(f) = &inner.controller.fn_on_rate_limit_error {
125                                    let body = f(req, err);
126                                    Ok(ServiceResponse::new(
127                                        req.clone(),
128                                        body.map_into_right_body().map_into_right_body(),
129                                    ))
130                                } else {
131                                    let body = default_on_rate_limit_error(req, err);
132                                    Ok(ServiceResponse::new(
133                                        req.clone(),
134                                        body.map_into_left_body().map_into_right_body(),
135                                    ))
136                                }
137                            }
138
139                            rate_limit_value = Some(value);
140                        },
141                    }
142                }
143            }
144
145            // rate-limit bypass
146            // Add a marker to the request to ensure that no further checks are performed on it.
147            RateLimitByPass::<T>::check(svc.request(), rate_limit_value.clone());
148
149            // call on-success
150            if let Some(f) = inner.controller.fn_on_success {
151                f(svc.request(), &inner.store, rate_limit_value.as_ref());
152            }
153
154            // rate-limit bypass
155            let res = service.call(svc).await?.map_into_left_body();
156            Ok(res)
157        })
158    }
159}
160
161impl<T: Store, CB: MessageBody> RateLimit<T, CB> {
162    /// create a new [RateLimit] middleware, with all custom functions.
163    pub fn new(
164        store: T,
165        max: <<T as Store>::Value as Value>::Count,
166        controller: Controller<T, CB>
167    ) -> Self {
168        Self {
169            inner: Arc::new(RateLimitInner {
170                store,
171                max,
172                controller,
173            })
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use actix_web::http::StatusCode;
    use actix_web::{test, web, App, HttpRequest, HttpResponse};
    use chrono::Utc;
    use tokio::time::Instant;

    use super::*;
    use crate::controller::{default_find_identifier, DEFAULT_RATE_LIMITED_UNTIL_HEADER};
    use crate::store::MemStore;

    /// Handler behind every test route: always replies `204 No Content`.
    async fn empty() -> HttpResponse {
        HttpResponse::new(StatusCode::NO_CONTENT)
    }

    /// Parses the "rate limited until" unix timestamp from a rejected response.
    fn rate_limited_until<B>(resp: &ServiceResponse<B>) -> i64 {
        let header = resp
            .headers()
            .get(DEFAULT_RATE_LIMITED_UNTIL_HEADER)
            .unwrap();
        header.to_str().unwrap_or_default().parse().unwrap()
    }

    /// Prints `until`, then sleeps until one second past that unix timestamp.
    async fn sleep_past(until: i64) {
        println!("rate limited until: {}", until);
        let start = Instant::now();
        let remaining = chrono::Duration::seconds(until - Utc::now().timestamp() + 1);
        tokio::time::sleep_until(start + remaining.to_std().unwrap()).await;
    }

    #[tokio::test]
    async fn test_middleware() -> anyhow::Result<()> {
        let mem_store = MemStore::new(1024, chrono::Duration::seconds(10));
        let controller = Controller::default();

        let app = test::init_service(
            App::new()
                .wrap(RateLimit::new(mem_store, 10, controller))
                .route("/", web::get().to(empty)),
        )
        .await;

        // the first 10 requests fit within the limit
        for _ in 0..10 {
            let resp = test::call_service(&app, test::TestRequest::get().to_request()).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        // every further request is rejected until the window expires
        let mut wait_until = 0i64;
        for _ in 0..10 {
            let resp = test::call_service(&app, test::TestRequest::get().to_request()).await;
            assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
            wait_until = rate_limited_until(&resp);
        }

        sleep_past(wait_until).await;

        // after the window, requests pass again
        for _ in 0..5 {
            let resp = test::call_service(&app, test::TestRequest::get().to_request()).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        Ok(())
    }

    /// Custom `do_rate_limit` predicate: everything except `/bypass` is limited.
    fn test_do_rate_limit_default_rate_limit_func(req: &HttpRequest) -> bool {
        req.path() != "/bypass"
    }

    #[tokio::test]
    async fn test_do_rate_limit() -> anyhow::Result<()> {
        let mem_store = MemStore::new(1024, chrono::Duration::seconds(10));

        let controller = Controller::<_, BoxBody>::new()
            .with_do_rate_limit(test_do_rate_limit_default_rate_limit_func)
            .with_find_identifier(default_find_identifier)
            .on_rate_limit_error(default_on_rate_limit_error)
            .on_store_error(default_on_store_error::<MemStore>);

        let app = test::init_service(
            App::new()
                .wrap(RateLimit::new(mem_store, 10, controller))
                .route("/", web::get().to(empty))
                .route("/bypass", web::get().to(empty)),
        )
        .await;

        // the first 10 requests fit within the limit
        for _ in 0..10 {
            let resp = test::call_service(&app, test::TestRequest::get().to_request()).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        // every further request on "/" is rejected until the window expires
        let mut wait_until = 0i64;
        for _ in 0..10 {
            let resp = test::call_service(&app, test::TestRequest::get().to_request()).await;
            assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
            wait_until = rate_limited_until(&resp);
        }

        // ...while "/bypass" is never rate limited
        for _ in 0..10 {
            let bypass_req = test::TestRequest::get().uri("/bypass").to_request();
            let resp = test::call_service(&app, bypass_req).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        sleep_past(wait_until).await;

        // after the window, requests pass again
        for _ in 0..5 {
            let resp = test::call_service(&app, test::TestRequest::get().to_request()).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        Ok(())
    }
}