actix_governor/
key_extractor.rs

1use actix_http::StatusCode;
2use actix_web::{dev::ServiceRequest, http::header::ContentType};
3use actix_web::{HttpResponse, HttpResponseBuilder, ResponseError};
4use governor::clock::{Clock, DefaultClock, QuantaInstant};
5use governor::NotUntil;
6
7use std::fmt::{Debug, Display};
8use std::{hash::Hash, net::IpAddr};
9
10/// Generic structure of what is needed to extract a rate-limiting key from an incoming request.
11///
12/// ## Example
13/// ```rust
14/// use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
15/// use actix_web::ResponseError;
16/// use actix_web::dev::ServiceRequest;
17///
18/// #[derive(Clone)]
19/// struct Foo;
20///
21/// // will return 500 error and 'Extract error' as content
22/// impl KeyExtractor for Foo {
23///     type Key = ();
24///     type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
25///
26///     fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
27///         Err(SimpleKeyExtractionError::new("Extract error"))
28///     }
29/// }
30/// ```
31///
32/// For more see [`custom_key_bearer`](https://github.com/AaronErhardt/actix-governor/blob/main/examples/custom_key_bearer.rs) example
33pub trait KeyExtractor: Clone {
34    /// The type of the key.
35    type Key: Clone + Hash + Eq;
36
37    /// The type of the error that can occur if key extraction from the request fails.
38    type KeyExtractionError: ResponseError + 'static;
39
40    #[cfg(feature = "log")]
41    /// Name of this extractor (only used in logs).
42    fn name(&self) -> &'static str;
43
44    /// Extraction method, will return [`KeyExtractionError`] response when the extract failed
45    ///
46    /// [`KeyExtractionError`]: KeyExtractor::KeyExtractionError
47    fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError>;
48
49    /// The content you want to show it when the rate limit is exceeded.
50    /// You can calculate the time at which a caller can expect the next positive rate-limiting result by using [`NotUntil`].
51    /// The [`HttpResponseBuilder`] allows you to build a fully customized [`HttpResponse`] in case of an error.
52    /// # Example
53    /// ```rust
54    /// use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
55    /// use actix_web::ResponseError;
56    /// use actix_web::dev::ServiceRequest;
57    /// use governor::{NotUntil, clock::{Clock, QuantaInstant, DefaultClock}};
58    /// use actix_web::{HttpResponse, HttpResponseBuilder};
59    /// use actix_web::http::header::ContentType;
60    ///
61    ///
62    /// #[derive(Clone)]
63    /// struct Foo;
64    ///
65    /// // will return 500 error and 'Extract error' as content
66    /// impl KeyExtractor for Foo {
67    ///     type Key = ();
68    ///     type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
69    ///
70    ///     fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
71    ///         Err(SimpleKeyExtractionError::new("Extract error"))
72    ///     }
73    ///
74    ///     fn exceed_rate_limit_response(
75    ///             &self,
76    ///             negative: &NotUntil<QuantaInstant>,
77    ///             mut response: HttpResponseBuilder,
78    ///         ) -> HttpResponse {
79    ///             let wait_time = negative
80    ///                 .wait_time_from(DefaultClock::default().now())
81    ///                 .as_secs();
82    ///             response
83    ///                 .content_type(ContentType::plaintext())
84    ///                 .body(format!("Too many requests, retry in {}s", wait_time))
85    ///     }
86    /// }
87    /// ```
88    fn exceed_rate_limit_response(
89        &self,
90        negative: &NotUntil<QuantaInstant>,
91        mut response: HttpResponseBuilder,
92    ) -> HttpResponse {
93        let wait_time = negative
94            .wait_time_from(DefaultClock::default().now())
95            .as_secs();
96        response
97            .content_type(ContentType::plaintext())
98            .body(format!("Too many requests, retry in {}s", wait_time))
99    }
100
101    /// Returns a list of whitelisted keys. If a key is in this list, it will never be rate-limited.
102    fn whitelisted_keys(&self) -> Vec<Self::Key> {
103        Vec::new()
104    }
105
106    #[cfg(feature = "log")]
107    /// Value of the extracted key (only used in logs).
108    fn key_name(&self, _key: &Self::Key) -> Option<String> {
109        None
110    }
111}
112
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// A [KeyExtractor] that applies a single, shared rate limit to all incoming requests.
///
/// This is useful if you want to hard-limit the HTTP load your app can handle.
pub struct GlobalKeyExtractor;
116
117#[derive(Debug)]
118/// A simple struct to create  error, by default the status is  500 server error and content-type is plintext
119pub struct SimpleKeyExtractionError<T: Display + Debug> {
120    /// The response body of the error.
121    pub body: T,
122    /// The status code of the error.
123    pub status_code: StatusCode,
124    /// The content type of the error.
125    pub content_type: ContentType,
126}
127
128impl<T: Display + Debug> SimpleKeyExtractionError<T> {
129    /// Create new instance by body
130    ///
131    /// # Example
132    /// ```rust
133    /// use actix_governor::SimpleKeyExtractionError;
134    /// use actix_http::StatusCode;
135    /// use actix_web::http::header::ContentType;
136    ///
137    /// let my_error = SimpleKeyExtractionError::new("Some error content");
138    ///
139    /// assert_eq!(my_error.body, "Some error content");
140    /// assert_eq!(my_error.content_type, ContentType::plaintext());
141    /// assert_eq!(my_error.status_code, StatusCode::INTERNAL_SERVER_ERROR);
142    /// ```
143    pub fn new(body: T) -> Self {
144        Self {
145            body,
146            status_code: StatusCode::INTERNAL_SERVER_ERROR,
147            content_type: ContentType::plaintext(),
148        }
149    }
150
151    /// Set a new status code, the default is [`StatusCode::INTERNAL_SERVER_ERROR`]
152    ///
153    /// # Example
154    /// ```rust
155    /// use actix_governor::SimpleKeyExtractionError;
156    /// use actix_http::StatusCode;
157    /// use actix_web::http::header::ContentType;
158    ///
159    /// let my_error = SimpleKeyExtractionError::new("Some error content")
160    ///         .set_status_code(StatusCode::FORBIDDEN);
161    ///
162    /// assert_eq!(my_error.body, "Some error content");
163    /// assert_eq!(my_error.content_type, ContentType::plaintext());
164    /// assert_eq!(my_error.status_code, StatusCode::FORBIDDEN);
165    /// ```
166    pub fn set_status_code(mut self, status_code: StatusCode) -> Self {
167        self.status_code = status_code;
168        Self { ..self }
169    }
170
171    /// Set a new content type, the default is `text/plain`
172    ///
173    /// # Example
174    /// ```rust
175    /// use actix_governor::SimpleKeyExtractionError;
176    /// use actix_http::StatusCode;
177    /// use actix_web::http::header::ContentType;
178    ///
179    /// let my_error = SimpleKeyExtractionError::new(r#"{"msg":"Some error content"}"#)
180    ///         .set_content_type(ContentType::json());
181    ///
182    /// assert_eq!(my_error.body, r#"{"msg":"Some error content"}"#);
183    /// assert_eq!(my_error.content_type, ContentType::json());
184    /// assert_eq!(my_error.status_code, StatusCode::INTERNAL_SERVER_ERROR);
185    /// ```
186    pub fn set_content_type(mut self, content_type: ContentType) -> Self {
187        self.content_type = content_type;
188        Self { ..self }
189    }
190}
191
192impl<T: Display + Debug> Display for SimpleKeyExtractionError<T> {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        write!(f, "SimpleKeyExtractionError")
195    }
196}
197
198impl<T: Display + Debug> ResponseError for SimpleKeyExtractionError<T> {
199    fn status_code(&self) -> StatusCode {
200        self.status_code
201    }
202
203    fn error_response(&self) -> HttpResponse<actix_http::body::BoxBody> {
204        HttpResponseBuilder::new(self.status_code())
205            .content_type(self.content_type.clone())
206            .body(self.body.to_string())
207    }
208}
209
210impl KeyExtractor for GlobalKeyExtractor {
211    type Key = ();
212    type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
213
214    #[cfg(feature = "log")]
215    fn name(&self) -> &'static str {
216        "global"
217    }
218
219    fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
220        Ok(())
221    }
222
223    #[cfg(feature = "log")]
224    fn key_name(&self, _key: &Self::Key) -> Option<String> {
225        None
226    }
227}
228
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// A [KeyExtractor] that uses the peer IP as the key.
///
/// **This is the default key extractor, and it may not do what you want** (see the warning
/// below).
///
/// **Warning:** this key extractor enforces rate limiting based on the **IPv4 _peer_ IP address**
/// or the **IPv6 /56 prefix of the _peer_ IP address**.
///
/// This means that if your app is deployed behind a reverse proxy, the peer IP address will _always_ be the proxy's IP address.
/// In this case, rate limiting will be applied to _all_ incoming requests as if they were from the same user.
///
/// If this is not the behavior you want, you may:
/// - implement your own [KeyExtractor] that tries to get the IP from the `Forwarded` or `X-Forwarded-For` headers that most reverse proxies set
/// - make absolutely sure that you only trust these headers when the peer IP is the IP of your reverse proxy (otherwise any user could set them to fake their IP)
pub struct PeerIpKeyExtractor;
242
243impl KeyExtractor for PeerIpKeyExtractor {
244    type Key = IpAddr;
245    type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
246
247    #[cfg(feature = "log")]
248    fn name(&self) -> &'static str {
249        "peer IP"
250    }
251
252    fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
253        let mut ip = req.peer_addr().map(|socket| socket.ip()).ok_or_else(|| {
254            SimpleKeyExtractionError::new("Could not extract peer IP address from request")
255        })?;
256        // customers often get their own /56 prefix, apply rate-limiting per prefix instead of per
257        // address for IPv6
258        if let IpAddr::V6(ipv6) = ip {
259            let mut octets = ipv6.octets();
260            octets[7..16].fill(0);
261            ip = IpAddr::V6(octets.into());
262        }
263        Ok(ip)
264    }
265
266    #[cfg(feature = "log")]
267    fn key_name(&self, key: &Self::Key) -> Option<String> {
268        Some(key.to_string())
269    }
270}