1use actix_web::{HttpRequest, HttpResponse, HttpResponseBuilder};
2use actix_web::body::{BoxBody, MessageBody};
3use actix_web::http::StatusCode;
4use crate::error::Error;
5use crate::store::Store;
6
7pub(crate) type FromRequestFunc<I> = fn(&HttpRequest) -> I;
8pub(crate) type FromRequestWithRef<S, V> = fn(&HttpRequest, &S, Option<&V>);
9pub(crate) type FromRequestOnError<E, R> = fn(&HttpRequest, E) -> R;
10
11#[derive(Clone)]
12pub struct Controller<T: Store, B: MessageBody = BoxBody> {
13    pub(crate) fn_do_rate_limit: Option<FromRequestFunc<bool>>,
14    pub(crate) fn_find_identifier: Option<FromRequestFunc<T::Key>>,
15    pub(crate) fn_on_rate_limit_error: Option<FromRequestOnError<Error, HttpResponse<B>>>,
16    pub(crate) fn_on_store_error: Option<FromRequestOnError<<T as Store>::Error, HttpResponse<B>>>,
17    pub(crate) fn_on_success: Option<FromRequestWithRef<T, T::Value>>,
18}
19
20impl<T: Store, B: MessageBody> Controller<T, B> {
21    pub fn new() -> Self {
23        Self {
24            fn_do_rate_limit: None,
25            fn_find_identifier: None,
26            fn_on_rate_limit_error: None,
27            fn_on_store_error: None,
28            fn_on_success: None,
29        }
30    }
31
32    pub fn with_do_rate_limit(mut self, f: FromRequestFunc<bool>) -> Self {
35        self.fn_do_rate_limit = Some(f);
36        self
37    }
38
39    pub fn with_find_identifier(mut self, f: FromRequestFunc<T::Key>) -> Self {
41        self.fn_find_identifier = Some(f);
42        self
43    }
44
45    pub fn on_rate_limit_error(mut self, f: FromRequestOnError<Error, HttpResponse<B>>) -> Self {
47        self.fn_on_rate_limit_error = Some(f);
48        self
49    }
50
51    pub fn on_store_error(mut self, f: FromRequestOnError<<T as Store>::Error, HttpResponse<B>>) -> Self {
54        self.fn_on_store_error = Some(f);
55        self
56    }
57
58    pub fn on_success(mut self, f: FromRequestWithRef<T, T::Value>) -> Self {
61        self.fn_on_success = Some(f);
62        self
63    }
64}
65
66impl<T> Default for Controller<T, BoxBody>
67    where T: Store<Key = String> + 'static,
68{
69    fn default() -> Self {
71        Self::new()
72            .with_do_rate_limit(default_do_rate_limit)
73            .with_find_identifier(default_find_identifier)
74            .on_rate_limit_error(default_on_rate_limit_error)
75            .on_store_error(default_on_store_error::<T>)
76    }
77}
78
79pub(crate) fn default_do_rate_limit(_: &HttpRequest) -> bool {
80    true
81}
82
83pub(crate) fn default_find_identifier(req: &HttpRequest) -> String {
84    req.peer_addr()
85        .map(|addr| addr.ip().to_string())
86        .unwrap_or("<Unknown Source IP>".to_string())
87}
88
/// Name of the response header that `default_on_rate_limit_error` sets on a 429 response.
/// Its value is the `timestamp()` of the time carried by `Error::RateLimited`, when present.
pub const DEFAULT_RATE_LIMITED_UNTIL_HEADER: &str = "X-Rate-Limited-Until";
90
91pub(crate) fn default_on_rate_limit_error(_: &HttpRequest, error: Error) -> HttpResponse {
92    match error {
93        Error::RateLimited(until) => {
94            let mut builder = HttpResponseBuilder::new(StatusCode::TOO_MANY_REQUESTS);
95
96            if let Some(until) = until {
97                builder.insert_header((DEFAULT_RATE_LIMITED_UNTIL_HEADER, until.timestamp().to_string()));
98            }
99
100            builder.finish()
101        }
102    }
103}
104
105pub(crate) fn default_on_store_error<T: Store>(_: &HttpRequest, _: T::Error) -> HttpResponse {
106    HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
107}