actix_governor/
key_extractor.rs

1use actix_http::StatusCode;
2use actix_web::{dev::ServiceRequest, http::header::ContentType};
3use actix_web::{HttpResponse, HttpResponseBuilder, ResponseError};
4use governor::clock::{Clock, DefaultClock, QuantaInstant};
5use governor::NotUntil;
6
7use std::error::Error;
8use std::fmt::{Debug, Display};
9use std::{hash::Hash, net::IpAddr};
10
11/// Generic structure of what is needed to extract a rate-limiting key from an incoming request.
12///
13/// ## Example
14/// ```rust
15/// use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
16/// use actix_web::ResponseError;
17/// use actix_web::dev::ServiceRequest;
18///
19/// #[derive(Clone)]
20/// struct Foo;
21///
22/// // will return 500 error and 'Extract error' as content
23/// impl KeyExtractor for Foo {
24///     type Key = ();
25///     type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
26///
27///     fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
28///         Err(SimpleKeyExtractionError::new("Extract error"))
29///     }
30/// }
31/// ```
32///
33/// For more see [`custom_key_bearer`](https://github.com/AaronErhardt/actix-governor/blob/main/examples/custom_key_bearer.rs) example
34pub trait KeyExtractor: Clone {
35    /// The type of the key.
36    type Key: Clone + Hash + Eq;
37
38    /// The type of the error that can occur if key extraction from the request fails.
39    type KeyExtractionError: ResponseError + 'static;
40
41    #[cfg(feature = "log")]
42    /// Name of this extractor (only used in logs).
43    fn name(&self) -> &'static str;
44
45    /// Extraction method, will return [`KeyExtractionError`] response when the extract failed
46    ///
47    /// [`KeyExtractionError`]: KeyExtractor::KeyExtractionError
48    fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError>;
49
50    /// The content you want to show it when the rate limit is exceeded.
51    /// You can calculate the time at which a caller can expect the next positive rate-limiting result by using [`NotUntil`].
52    /// The [`HttpResponseBuilder`] allows you to build a fully customized [`HttpResponse`] in case of an error.
53    /// # Example
54    /// ```rust
55    /// use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
56    /// use actix_web::ResponseError;
57    /// use actix_web::dev::ServiceRequest;
58    /// use governor::{NotUntil, clock::{Clock, QuantaInstant, DefaultClock}};
59    /// use actix_web::{HttpResponse, HttpResponseBuilder};
60    /// use actix_web::http::header::ContentType;
61    ///
62    ///
63    /// #[derive(Clone)]
64    /// struct Foo;
65    ///
66    /// // will return 500 error and 'Extract error' as content
67    /// impl KeyExtractor for Foo {
68    ///     type Key = ();
69    ///     type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
70    ///
71    ///     fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
72    ///         Err(SimpleKeyExtractionError::new("Extract error"))
73    ///     }
74    ///
75    ///     fn exceed_rate_limit_response(
76    ///             &self,
77    ///             negative: &NotUntil<QuantaInstant>,
78    ///             mut response: HttpResponseBuilder,
79    ///         ) -> HttpResponse {
80    ///             let wait_time = negative
81    ///                 .wait_time_from(DefaultClock::default().now())
82    ///                 .as_secs();
83    ///             response
84    ///                 .content_type(ContentType::plaintext())
85    ///                 .body(format!("Too many requests, retry in {}s", wait_time))
86    ///     }
87    /// }
88    /// ```
89    fn exceed_rate_limit_response(
90        &self,
91        negative: &NotUntil<QuantaInstant>,
92        mut response: HttpResponseBuilder,
93    ) -> HttpResponse {
94        let wait_time = negative
95            .wait_time_from(DefaultClock::default().now())
96            .as_secs();
97        response
98            .content_type(ContentType::plaintext())
99            .body(format!("Too many requests, retry in {}s", wait_time))
100    }
101
102    /// Returns a list of whitelisted keys. If a key is in this list, it will never be rate-limited.
103    fn whitelisted_keys(&self) -> Vec<Self::Key> {
104        Vec::new()
105    }
106
107    #[cfg(feature = "log")]
108    /// Value of the extracted key (only used in logs).
109    fn key_name(&self, _key: &Self::Key) -> Option<String> {
110        None
111    }
112}
113
/// A [`KeyExtractor`] that applies one rate limit to all incoming requests together.
///
/// This is useful if you want to hard-limit the HTTP load your app can handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GlobalKeyExtractor;
117
118#[derive(Debug)]
119/// A simple struct to create  error, by default the status is  500 server error and content-type is plintext
120pub struct SimpleKeyExtractionError<T: Display + Debug> {
121    /// The response body of the error.
122    pub body: T,
123    /// The status code of the error.
124    pub status_code: StatusCode,
125    /// The content type of the error.
126    pub content_type: ContentType,
127}
128
129impl<T: Display + Debug> SimpleKeyExtractionError<T> {
130    /// Create new instance by body
131    ///
132    /// # Example
133    /// ```rust
134    /// use actix_governor::SimpleKeyExtractionError;
135    /// use actix_http::StatusCode;
136    /// use actix_web::http::header::ContentType;
137    ///
138    /// let my_error = SimpleKeyExtractionError::new("Some error content");
139    ///
140    /// assert_eq!(my_error.body, "Some error content");
141    /// assert_eq!(my_error.content_type, ContentType::plaintext());
142    /// assert_eq!(my_error.status_code, StatusCode::INTERNAL_SERVER_ERROR);
143    /// ```
144    pub fn new(body: T) -> Self {
145        Self {
146            body,
147            status_code: StatusCode::INTERNAL_SERVER_ERROR,
148            content_type: ContentType::plaintext(),
149        }
150    }
151
152    /// Set a new status code, the default is [`StatusCode::INTERNAL_SERVER_ERROR`]
153    ///
154    /// # Example
155    /// ```rust
156    /// use actix_governor::SimpleKeyExtractionError;
157    /// use actix_http::StatusCode;
158    /// use actix_web::http::header::ContentType;
159    ///
160    /// let my_error = SimpleKeyExtractionError::new("Some error content")
161    ///         .set_status_code(StatusCode::FORBIDDEN);
162    ///
163    /// assert_eq!(my_error.body, "Some error content");
164    /// assert_eq!(my_error.content_type, ContentType::plaintext());
165    /// assert_eq!(my_error.status_code, StatusCode::FORBIDDEN);
166    /// ```
167    pub fn set_status_code(mut self, status_code: StatusCode) -> Self {
168        self.status_code = status_code;
169        Self { ..self }
170    }
171
172    /// Set a new content type, the default is `text/plain`
173    ///
174    /// # Example
175    /// ```rust
176    /// use actix_governor::SimpleKeyExtractionError;
177    /// use actix_http::StatusCode;
178    /// use actix_web::http::header::ContentType;
179    ///
180    /// let my_error = SimpleKeyExtractionError::new(r#"{"msg":"Some error content"}"#)
181    ///         .set_content_type(ContentType::json());
182    ///
183    /// assert_eq!(my_error.body, r#"{"msg":"Some error content"}"#);
184    /// assert_eq!(my_error.content_type, ContentType::json());
185    /// assert_eq!(my_error.status_code, StatusCode::INTERNAL_SERVER_ERROR);
186    /// ```
187    pub fn set_content_type(mut self, content_type: ContentType) -> Self {
188        self.content_type = content_type;
189        Self { ..self }
190    }
191}
192
193impl<T: Display + Debug> Display for SimpleKeyExtractionError<T> {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        write!(f, "SimpleKeyExtractionError")
196    }
197}
198
199impl<T: Display + Debug> Error for SimpleKeyExtractionError<T> {}
200
201impl<T: Display + Debug> ResponseError for SimpleKeyExtractionError<T> {
202    fn status_code(&self) -> StatusCode {
203        self.status_code
204    }
205
206    fn error_response(&self) -> HttpResponse<actix_http::body::BoxBody> {
207        HttpResponseBuilder::new(self.status_code())
208            .content_type(self.content_type.clone())
209            .body(self.body.to_string())
210    }
211}
212
213impl KeyExtractor for GlobalKeyExtractor {
214    type Key = ();
215    type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
216
217    #[cfg(feature = "log")]
218    fn name(&self) -> &'static str {
219        "global"
220    }
221
222    fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
223        Ok(())
224    }
225
226    #[cfg(feature = "log")]
227    fn key_name(&self, _key: &Self::Key) -> Option<String> {
228        None
229    }
230}
231
/// A [`KeyExtractor`] that uses the peer IP address as the key.
///
/// **This is the default key extractor, and it may not do what you want — read the warning
/// below.**
///
/// **Warning:** this key extractor enforces rate limiting based on the **IPv4 _peer_ IP address**
/// or the **IPv6 /56 prefix of the _peer_ IP address**.
///
/// This means that if your app is deployed behind a reverse proxy, the peer IP address will
/// _always_ be the proxy's IP address. In this case, rate limiting will be applied to _all_
/// incoming requests as if they were from the same user.
///
/// If this is not the behavior you want, you may:
/// - implement your own [`KeyExtractor`] that tries to get the IP from the `Forwarded` or
///   `X-Forwarded-For` headers that most reverse proxies set;
/// - make absolutely sure that you only trust these headers when the peer IP is the IP of your
///   reverse proxy (otherwise any user could set them to fake their IP).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PeerIpKeyExtractor;
245
246impl KeyExtractor for PeerIpKeyExtractor {
247    type Key = IpAddr;
248    type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
249
250    #[cfg(feature = "log")]
251    fn name(&self) -> &'static str {
252        "peer IP"
253    }
254
255    fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
256        let mut ip = req.peer_addr().map(|socket| socket.ip()).ok_or_else(|| {
257            SimpleKeyExtractionError::new("Could not extract peer IP address from request")
258        })?;
259        // customers often get their own /56 prefix, apply rate-limiting per prefix instead of per
260        // address for IPv6
261        if let IpAddr::V6(ipv6) = ip {
262            let mut octets = ipv6.octets();
263            octets[7..16].fill(0);
264            ip = IpAddr::V6(octets.into());
265        }
266        Ok(ip)
267    }
268
269    #[cfg(feature = "log")]
270    fn key_name(&self, key: &Self::Key) -> Option<String> {
271        Some(key.to_string())
272    }
273}