actix_rl/
controller.rs

1use actix_web::{HttpRequest, HttpResponse, HttpResponseBuilder};
2use actix_web::body::{BoxBody, MessageBody};
3use actix_web::http::StatusCode;
4use crate::error::Error;
5use crate::store::Store;
6
7pub(crate) type FromRequestFunc<I> = fn(&HttpRequest) -> I;
8pub(crate) type FromRequestWithRef<S, V> = fn(&HttpRequest, &S, Option<&V>);
9pub(crate) type FromRequestOnError<E, R> = fn(&HttpRequest, E) -> R;
10
11#[derive(Clone)]
12pub struct Controller<T: Store, B: MessageBody = BoxBody> {
13    pub(crate) fn_do_rate_limit: Option<FromRequestFunc<bool>>,
14    pub(crate) fn_find_identifier: Option<FromRequestFunc<T::Key>>,
15    pub(crate) fn_on_rate_limit_error: Option<FromRequestOnError<Error, HttpResponse<B>>>,
16    pub(crate) fn_on_store_error: Option<FromRequestOnError<<T as Store>::Error, HttpResponse<B>>>,
17    pub(crate) fn_on_success: Option<FromRequestWithRef<T, T::Value>>,
18}
19
20impl<T: Store, B: MessageBody> Controller<T, B> {
21    /// Create a default Controller, with all functions as [None]
22    pub fn new() -> Self {
23        Self {
24            fn_do_rate_limit: None,
25            fn_find_identifier: None,
26            fn_on_rate_limit_error: None,
27            fn_on_store_error: None,
28            fn_on_success: None,
29        }
30    }
31
32    /// Determine if a request needs to be checked for rate limiting.
33    /// If not set, all requests will be checked.
34    pub fn with_do_rate_limit(mut self, f: FromRequestFunc<bool>) -> Self {
35        self.fn_do_rate_limit = Some(f);
36        self
37    }
38
39    /// Extract the identifier from the request, such as the IP address or other information.
40    pub fn with_find_identifier(mut self, f: FromRequestFunc<T::Key>) -> Self {
41        self.fn_find_identifier = Some(f);
42        self
43    }
44
45    /// Set the [`HttpResponse<B>`] to be returned when a rate-limit error occurs.
46    pub fn on_rate_limit_error(mut self, f: FromRequestOnError<Error, HttpResponse<B>>) -> Self {
47        self.fn_on_rate_limit_error = Some(f);
48        self
49    }
50
51    /// Set the [`HttpResponse<B>`] to be returned when an error occurs in the [Store]
52    /// (such as Redis or other storage structures).
53    pub fn on_store_error(mut self, f: FromRequestOnError<<T as Store>::Error, HttpResponse<B>>) -> Self {
54        self.fn_on_store_error = Some(f);
55        self
56    }
57
58    /// Execute this function whenever a request successfully passes
59    /// (including those skipped by [Self::fn_do_rate_limit]).
60    pub fn on_success(mut self, f: FromRequestWithRef<T, T::Value>) -> Self {
61        self.fn_on_success = Some(f);
62        self
63    }
64}
65
66impl<T> Default for Controller<T, BoxBody>
67    where T: Store<Key = String> + 'static,
68{
69    /// alias of [Self::new], but use default functions.
70    fn default() -> Self {
71        Self::new()
72            .with_do_rate_limit(default_do_rate_limit)
73            .with_find_identifier(default_find_identifier)
74            .on_rate_limit_error(default_on_rate_limit_error)
75            .on_store_error(default_on_store_error::<T>)
76    }
77}
78
79pub(crate) fn default_do_rate_limit(_: &HttpRequest) -> bool {
80    true
81}
82
83pub(crate) fn default_find_identifier(req: &HttpRequest) -> String {
84    req.peer_addr()
85        .map(|addr| addr.ip().to_string())
86        .unwrap_or("<Unknown Source IP>".to_string())
87}
88
89pub const DEFAULT_RATE_LIMITED_UNTIL_HEADER: &str = "X-Rate-Limited-Until";
90
91pub(crate) fn default_on_rate_limit_error(_: &HttpRequest, error: Error) -> HttpResponse {
92    match error {
93        Error::RateLimited(until) => {
94            let mut builder = HttpResponseBuilder::new(StatusCode::TOO_MANY_REQUESTS);
95
96            if let Some(until) = until {
97                builder.insert_header((DEFAULT_RATE_LIMITED_UNTIL_HEADER, until.timestamp().to_string()));
98            }
99
100            builder.finish()
101        }
102    }
103}
104
105pub(crate) fn default_on_store_error<T: Store>(_: &HttpRequest, _: T::Error) -> HttpResponse {
106    HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
107}