actix_governor/key_extractor.rs
1use actix_http::StatusCode;
2use actix_web::{dev::ServiceRequest, http::header::ContentType};
3use actix_web::{HttpResponse, HttpResponseBuilder, ResponseError};
4use governor::clock::{Clock, DefaultClock, QuantaInstant};
5use governor::NotUntil;
6
7use std::error::Error;
8use std::fmt::{Debug, Display};
9use std::{hash::Hash, net::IpAddr};
10
11/// Generic structure of what is needed to extract a rate-limiting key from an incoming request.
12///
13/// ## Example
14/// ```rust
15/// use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
16/// use actix_web::ResponseError;
17/// use actix_web::dev::ServiceRequest;
18///
19/// #[derive(Clone)]
20/// struct Foo;
21///
22/// // will return 500 error and 'Extract error' as content
23/// impl KeyExtractor for Foo {
24/// type Key = ();
25/// type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
26///
27/// fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
28/// Err(SimpleKeyExtractionError::new("Extract error"))
29/// }
30/// }
31/// ```
32///
33/// For more see [`custom_key_bearer`](https://github.com/AaronErhardt/actix-governor/blob/main/examples/custom_key_bearer.rs) example
34pub trait KeyExtractor: Clone {
35 /// The type of the key.
36 type Key: Clone + Hash + Eq;
37
38 /// The type of the error that can occur if key extraction from the request fails.
39 type KeyExtractionError: ResponseError + 'static;
40
41 #[cfg(feature = "log")]
42 /// Name of this extractor (only used in logs).
43 fn name(&self) -> &'static str;
44
45 /// Extraction method, will return [`KeyExtractionError`] response when the extract failed
46 ///
47 /// [`KeyExtractionError`]: KeyExtractor::KeyExtractionError
48 fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError>;
49
50 /// The content you want to show it when the rate limit is exceeded.
51 /// You can calculate the time at which a caller can expect the next positive rate-limiting result by using [`NotUntil`].
52 /// The [`HttpResponseBuilder`] allows you to build a fully customized [`HttpResponse`] in case of an error.
53 /// # Example
54 /// ```rust
55 /// use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
56 /// use actix_web::ResponseError;
57 /// use actix_web::dev::ServiceRequest;
58 /// use governor::{NotUntil, clock::{Clock, QuantaInstant, DefaultClock}};
59 /// use actix_web::{HttpResponse, HttpResponseBuilder};
60 /// use actix_web::http::header::ContentType;
61 ///
62 ///
63 /// #[derive(Clone)]
64 /// struct Foo;
65 ///
66 /// // will return 500 error and 'Extract error' as content
67 /// impl KeyExtractor for Foo {
68 /// type Key = ();
69 /// type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
70 ///
71 /// fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
72 /// Err(SimpleKeyExtractionError::new("Extract error"))
73 /// }
74 ///
75 /// fn exceed_rate_limit_response(
76 /// &self,
77 /// negative: &NotUntil<QuantaInstant>,
78 /// mut response: HttpResponseBuilder,
79 /// ) -> HttpResponse {
80 /// let wait_time = negative
81 /// .wait_time_from(DefaultClock::default().now())
82 /// .as_secs();
83 /// response
84 /// .content_type(ContentType::plaintext())
85 /// .body(format!("Too many requests, retry in {}s", wait_time))
86 /// }
87 /// }
88 /// ```
89 fn exceed_rate_limit_response(
90 &self,
91 negative: &NotUntil<QuantaInstant>,
92 mut response: HttpResponseBuilder,
93 ) -> HttpResponse {
94 let wait_time = negative
95 .wait_time_from(DefaultClock::default().now())
96 .as_secs();
97 response
98 .content_type(ContentType::plaintext())
99 .body(format!("Too many requests, retry in {}s", wait_time))
100 }
101
102 /// Returns a list of whitelisted keys. If a key is in this list, it will never be rate-limited.
103 fn whitelisted_keys(&self) -> Vec<Self::Key> {
104 Vec::new()
105 }
106
107 #[cfg(feature = "log")]
108 /// Value of the extracted key (only used in logs).
109 fn key_name(&self, _key: &Self::Key) -> Option<String> {
110 None
111 }
112}
113
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// A [KeyExtractor] that applies one shared rate limit to all incoming requests.
///
/// This is useful if you want to hard-limit the HTTP load your app can handle.
pub struct GlobalKeyExtractor;
117
118#[derive(Debug)]
119/// A simple struct to create error, by default the status is 500 server error and content-type is plintext
120pub struct SimpleKeyExtractionError<T: Display + Debug> {
121 /// The response body of the error.
122 pub body: T,
123 /// The status code of the error.
124 pub status_code: StatusCode,
125 /// The content type of the error.
126 pub content_type: ContentType,
127}
128
129impl<T: Display + Debug> SimpleKeyExtractionError<T> {
130 /// Create new instance by body
131 ///
132 /// # Example
133 /// ```rust
134 /// use actix_governor::SimpleKeyExtractionError;
135 /// use actix_http::StatusCode;
136 /// use actix_web::http::header::ContentType;
137 ///
138 /// let my_error = SimpleKeyExtractionError::new("Some error content");
139 ///
140 /// assert_eq!(my_error.body, "Some error content");
141 /// assert_eq!(my_error.content_type, ContentType::plaintext());
142 /// assert_eq!(my_error.status_code, StatusCode::INTERNAL_SERVER_ERROR);
143 /// ```
144 pub fn new(body: T) -> Self {
145 Self {
146 body,
147 status_code: StatusCode::INTERNAL_SERVER_ERROR,
148 content_type: ContentType::plaintext(),
149 }
150 }
151
152 /// Set a new status code, the default is [`StatusCode::INTERNAL_SERVER_ERROR`]
153 ///
154 /// # Example
155 /// ```rust
156 /// use actix_governor::SimpleKeyExtractionError;
157 /// use actix_http::StatusCode;
158 /// use actix_web::http::header::ContentType;
159 ///
160 /// let my_error = SimpleKeyExtractionError::new("Some error content")
161 /// .set_status_code(StatusCode::FORBIDDEN);
162 ///
163 /// assert_eq!(my_error.body, "Some error content");
164 /// assert_eq!(my_error.content_type, ContentType::plaintext());
165 /// assert_eq!(my_error.status_code, StatusCode::FORBIDDEN);
166 /// ```
167 pub fn set_status_code(mut self, status_code: StatusCode) -> Self {
168 self.status_code = status_code;
169 Self { ..self }
170 }
171
172 /// Set a new content type, the default is `text/plain`
173 ///
174 /// # Example
175 /// ```rust
176 /// use actix_governor::SimpleKeyExtractionError;
177 /// use actix_http::StatusCode;
178 /// use actix_web::http::header::ContentType;
179 ///
180 /// let my_error = SimpleKeyExtractionError::new(r#"{"msg":"Some error content"}"#)
181 /// .set_content_type(ContentType::json());
182 ///
183 /// assert_eq!(my_error.body, r#"{"msg":"Some error content"}"#);
184 /// assert_eq!(my_error.content_type, ContentType::json());
185 /// assert_eq!(my_error.status_code, StatusCode::INTERNAL_SERVER_ERROR);
186 /// ```
187 pub fn set_content_type(mut self, content_type: ContentType) -> Self {
188 self.content_type = content_type;
189 Self { ..self }
190 }
191}
192
193impl<T: Display + Debug> Display for SimpleKeyExtractionError<T> {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 write!(f, "SimpleKeyExtractionError")
196 }
197}
198
199impl<T: Display + Debug> Error for SimpleKeyExtractionError<T> {}
200
201impl<T: Display + Debug> ResponseError for SimpleKeyExtractionError<T> {
202 fn status_code(&self) -> StatusCode {
203 self.status_code
204 }
205
206 fn error_response(&self) -> HttpResponse<actix_http::body::BoxBody> {
207 HttpResponseBuilder::new(self.status_code())
208 .content_type(self.content_type.clone())
209 .body(self.body.to_string())
210 }
211}
212
213impl KeyExtractor for GlobalKeyExtractor {
214 type Key = ();
215 type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
216
217 #[cfg(feature = "log")]
218 fn name(&self) -> &'static str {
219 "global"
220 }
221
222 fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
223 Ok(())
224 }
225
226 #[cfg(feature = "log")]
227 fn key_name(&self, _key: &Self::Key) -> Option<String> {
228 None
229 }
230}
231
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// A [KeyExtractor] that uses the peer IP as key. **This is the default key extractor, and it
/// may not do what you want.**
///
/// **Warning:** this key extractor enforces rate limiting based on the **IPv4 _peer_ IP address**
/// or the **IPv6 /56 prefix of the _peer_ IP address**.
///
/// This means that if your app is deployed behind a reverse proxy, the peer IP address will
/// _always_ be the proxy's IP address. In this case, rate limiting will be applied to _all_
/// incoming requests as if they were from the same user.
///
/// If this is not the behavior you want, you may:
/// - implement your own [KeyExtractor] that tries to get the IP from the `Forwarded` or
///   `X-Forwarded-For` headers that most reverse proxies set
/// - make absolutely sure that you only trust these headers when the peer IP is the IP of your
///   reverse proxy (otherwise any user could set them to fake their IP)
pub struct PeerIpKeyExtractor;
245
246impl KeyExtractor for PeerIpKeyExtractor {
247 type Key = IpAddr;
248 type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
249
250 #[cfg(feature = "log")]
251 fn name(&self) -> &'static str {
252 "peer IP"
253 }
254
255 fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
256 let mut ip = req.peer_addr().map(|socket| socket.ip()).ok_or_else(|| {
257 SimpleKeyExtractionError::new("Could not extract peer IP address from request")
258 })?;
259 // customers often get their own /56 prefix, apply rate-limiting per prefix instead of per
260 // address for IPv6
261 if let IpAddr::V6(ipv6) = ip {
262 let mut octets = ipv6.octets();
263 octets[7..16].fill(0);
264 ip = IpAddr::V6(octets.into());
265 }
266 Ok(ip)
267 }
268
269 #[cfg(feature = "log")]
270 fn key_name(&self, key: &Self::Key) -> Option<String> {
271 Some(key.to_string())
272 }
273}