1use std::rc::Rc;
2use std::sync::Arc;
3use actix_web::body::{BoxBody, EitherBody, MessageBody};
4use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
5use futures_util::future::{LocalBoxFuture, Ready, ready};
6use crate::controller::{Controller, default_do_rate_limit, default_on_rate_limit_error, default_on_store_error};
7use crate::error::Error;
8use crate::store::{Store, Value};
9use crate::utils::RateLimitByPass;
10
11pub type RateLimitMiddleware<T, CB> = RateLimit<T, CB>;
13
14#[derive(Clone)]
20pub struct RateLimit<T: Store, CB: MessageBody = BoxBody> {
21 inner: Arc<RateLimitInner<T, CB>>,
22}
23
24#[derive(Clone)]
25struct RateLimitInner<T: Store, CB: MessageBody = BoxBody> {
26 pub store: T,
27 pub max: <<T as Store>::Value as Value>::Count,
28 pub controller: Controller<T, CB>,
29}
30
31impl<T, CB, S, B> Transform<S, ServiceRequest> for RateLimit<T, CB>
32 where
33 T: Store + 'static,
34 CB: MessageBody + 'static,
35 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
36 S::Future: 'static,
37 B: 'static,
38 <T as Store>::Key: 'static,
39{
40 type Response = ServiceResponse<EitherBody<B, EitherBody<BoxBody, CB>>>;
41 type Error = S::Error;
42 type Transform = RateLimitService<T, CB, S>;
43 type InitError = ();
44 type Future = Ready<Result<Self::Transform, Self::InitError>>;
45
46 fn new_transform(&self, service: S) -> Self::Future {
47 ready(Ok(RateLimitService {
48 inner: self.inner.clone(),
49 service: Rc::new(service),
50 }))
51 }
52}
53
54#[derive(Clone)]
55pub struct RateLimitService<T, CB, S>
56 where
57 T: Store,
58 CB: MessageBody,
59{
60 inner: Arc<RateLimitInner<T, CB>>,
61 service: Rc<S>,
62}
63
64impl<T, CB, S, B> Service<ServiceRequest> for RateLimitService<T, CB, S>
65 where
66 T: Store + 'static,
67 CB: MessageBody + 'static,
68 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
69 S::Future: 'static,
70 B: 'static,
71 <T as Store>::Key: 'static,
72{
73 type Response = ServiceResponse<EitherBody<B, EitherBody<BoxBody, CB>>>;
74 type Error = S::Error;
75 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
76
77 forward_ready!(service);
78
79 fn call(&self, svc: ServiceRequest) -> Self::Future {
80 let service = self.service.clone();
81 let inner = self.inner.clone();
82
83 Box::pin(async move {
84 let checked = RateLimitByPass::<T>::checked(svc.request());
85 let do_rate_limit = !checked && if let Some(f) = &inner.controller.fn_do_rate_limit {
86 f(svc.request())
87 } else {
88 default_do_rate_limit(svc.request())
90 };
91
92 let mut rate_limit_value = None;
93
94 if do_rate_limit {
95 let identifier = inner.controller.fn_find_identifier.as_ref()
97 .map(|f| f(svc.request()));
98
99 if let Some(identifier) = identifier { let req = svc.request();
101 match inner.store.incr(identifier).await {
102 Err(e) => {
103 return if let Some(f) = &inner.controller.fn_on_store_error {
105 let body = f(req, e);
106 Ok(ServiceResponse::new(
107 req.clone(),
108 body.map_into_right_body().map_into_right_body(),
109 ))
110 } else {
111 let body = default_on_store_error::<T>(req, e);
112 Ok(ServiceResponse::new(
113 req.clone(),
114 body.map_into_left_body().map_into_right_body(),
115 ))
116 }
117
118 },
119 Ok(value) => {
120 if value.count() > inner.max {
121 let err = Error::RateLimited(value.expire_date());
123
124 return if let Some(f) = &inner.controller.fn_on_rate_limit_error {
125 let body = f(req, err);
126 Ok(ServiceResponse::new(
127 req.clone(),
128 body.map_into_right_body().map_into_right_body(),
129 ))
130 } else {
131 let body = default_on_rate_limit_error(req, err);
132 Ok(ServiceResponse::new(
133 req.clone(),
134 body.map_into_left_body().map_into_right_body(),
135 ))
136 }
137 }
138
139 rate_limit_value = Some(value);
140 },
141 }
142 }
143 }
144
145 RateLimitByPass::<T>::check(svc.request(), rate_limit_value.clone());
148
149 if let Some(f) = inner.controller.fn_on_success {
151 f(svc.request(), &inner.store, rate_limit_value.as_ref());
152 }
153
154 let res = service.call(svc).await?.map_into_left_body();
156 Ok(res)
157 })
158 }
159}
160
161impl<T: Store, CB: MessageBody> RateLimit<T, CB> {
162 pub fn new(
164 store: T,
165 max: <<T as Store>::Value as Value>::Count,
166 controller: Controller<T, CB>
167 ) -> Self {
168 Self {
169 inner: Arc::new(RateLimitInner {
170 store,
171 max,
172 controller,
173 })
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use actix_web::http::StatusCode;
    use actix_web::{test, web, App, HttpRequest, HttpResponse};
    use chrono::Utc;

    use super::*;
    use crate::controller::{default_find_identifier, DEFAULT_RATE_LIMITED_UNTIL_HEADER};
    use crate::store::MemStore;

    async fn empty() -> HttpResponse {
        HttpResponse::new(StatusCode::NO_CONTENT)
    }

    /// Extracts the "rate limited until" unix timestamp from a rejected response,
    /// panicking with a descriptive message if the header is missing or malformed.
    fn rate_limited_until<B>(resp: &ServiceResponse<B>) -> i64 {
        resp.headers()
            .get(DEFAULT_RATE_LIMITED_UNTIL_HEADER)
            .expect("rate-limited response is missing the until header")
            .to_str()
            .expect("until header is not valid ASCII")
            .parse()
            .expect("until header is not an integer timestamp")
    }

    /// Sleeps until one second after the unix timestamp `until`. Returns
    /// immediately if that moment has already passed, instead of panicking on a
    /// negative duration.
    async fn sleep_past(until: i64) {
        let remaining = until - Utc::now().timestamp() + 1;
        let secs = u64::try_from(remaining).unwrap_or(0);
        tokio::time::sleep(std::time::Duration::from_secs(secs)).await;
    }

    #[tokio::test]
    async fn test_middleware() -> anyhow::Result<()> {
        let store = MemStore::new(1024, chrono::Duration::seconds(10));
        let controller = Controller::default();

        let app = test::init_service(
            App::new()
                .wrap(RateLimit::new(store, 10, controller))
                .route("/", web::get().to(empty)),
        )
        .await;

        for _ in 0..10 {
            let req = test::TestRequest::get().to_request();
            let resp = test::call_service(&app, req).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        let mut wait_until = 0i64;
        for _ in 0..10 {
            let req = test::TestRequest::get().to_request();
            let resp = test::call_service(&app, req).await;
            assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
            wait_until = rate_limited_until(&resp);
        }

        println!("rate limited until: {}", wait_until);
        sleep_past(wait_until).await;

        for _ in 0..5 {
            let req = test::TestRequest::get().to_request();
            let resp = test::call_service(&app, req).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        Ok(())
    }

    /// Limits every path except `/bypass`.
    fn test_do_rate_limit_default_rate_limit_func(req: &HttpRequest) -> bool {
        req.path() != "/bypass"
    }

    #[tokio::test]
    async fn test_do_rate_limit() -> anyhow::Result<()> {
        let store = MemStore::new(1024, chrono::Duration::seconds(10));

        let controller = Controller::<_, BoxBody>::new()
            .with_do_rate_limit(test_do_rate_limit_default_rate_limit_func)
            .with_find_identifier(default_find_identifier)
            .on_rate_limit_error(default_on_rate_limit_error)
            .on_store_error(default_on_store_error::<MemStore>);

        let app = test::init_service(
            App::new()
                .wrap(RateLimit::new(store, 10, controller))
                .route("/", web::get().to(empty))
                .route("/bypass", web::get().to(empty)),
        )
        .await;

        for _ in 0..10 {
            let req = test::TestRequest::get().to_request();
            let resp = test::call_service(&app, req).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        let mut wait_until = 0i64;
        for _ in 0..10 {
            let req = test::TestRequest::get().to_request();
            let resp = test::call_service(&app, req).await;
            assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
            wait_until = rate_limited_until(&resp);
        }

        // `/bypass` is exempt even while the identifier is rate limited.
        for _ in 0..10 {
            let req = test::TestRequest::get().uri("/bypass").to_request();
            let resp = test::call_service(&app, req).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        println!("rate limited until: {}", wait_until);
        sleep_past(wait_until).await;

        for _ in 0..5 {
            let req = test::TestRequest::get().to_request();
            let resp = test::call_service(&app, req).await;
            assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        }

        Ok(())
    }
}