actix_governor/key_extractor.rs
1use actix_http::StatusCode;
2use actix_web::{dev::ServiceRequest, http::header::ContentType};
3use actix_web::{HttpResponse, HttpResponseBuilder, ResponseError};
4use governor::clock::{Clock, DefaultClock, QuantaInstant};
5use governor::NotUntil;
6
7use std::fmt::{Debug, Display};
8use std::{hash::Hash, net::IpAddr};
9
10/// Generic structure of what is needed to extract a rate-limiting key from an incoming request.
11///
12/// ## Example
13/// ```rust
14/// use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
15/// use actix_web::ResponseError;
16/// use actix_web::dev::ServiceRequest;
17///
18/// #[derive(Clone)]
19/// struct Foo;
20///
21/// // will return 500 error and 'Extract error' as content
22/// impl KeyExtractor for Foo {
23/// type Key = ();
24/// type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
25///
26/// fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
27/// Err(SimpleKeyExtractionError::new("Extract error"))
28/// }
29/// }
30/// ```
31///
32/// For more see [`custom_key_bearer`](https://github.com/AaronErhardt/actix-governor/blob/main/examples/custom_key_bearer.rs) example
33pub trait KeyExtractor: Clone {
34 /// The type of the key.
35 type Key: Clone + Hash + Eq;
36
37 /// The type of the error that can occur if key extraction from the request fails.
38 type KeyExtractionError: ResponseError + 'static;
39
40 #[cfg(feature = "log")]
41 /// Name of this extractor (only used in logs).
42 fn name(&self) -> &'static str;
43
44 /// Extraction method, will return [`KeyExtractionError`] response when the extract failed
45 ///
46 /// [`KeyExtractionError`]: KeyExtractor::KeyExtractionError
47 fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError>;
48
49 /// The content you want to show it when the rate limit is exceeded.
50 /// You can calculate the time at which a caller can expect the next positive rate-limiting result by using [`NotUntil`].
51 /// The [`HttpResponseBuilder`] allows you to build a fully customized [`HttpResponse`] in case of an error.
52 /// # Example
53 /// ```rust
54 /// use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
55 /// use actix_web::ResponseError;
56 /// use actix_web::dev::ServiceRequest;
57 /// use governor::{NotUntil, clock::{Clock, QuantaInstant, DefaultClock}};
58 /// use actix_web::{HttpResponse, HttpResponseBuilder};
59 /// use actix_web::http::header::ContentType;
60 ///
61 ///
62 /// #[derive(Clone)]
63 /// struct Foo;
64 ///
65 /// // will return 500 error and 'Extract error' as content
66 /// impl KeyExtractor for Foo {
67 /// type Key = ();
68 /// type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
69 ///
70 /// fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
71 /// Err(SimpleKeyExtractionError::new("Extract error"))
72 /// }
73 ///
74 /// fn exceed_rate_limit_response(
75 /// &self,
76 /// negative: &NotUntil<QuantaInstant>,
77 /// mut response: HttpResponseBuilder,
78 /// ) -> HttpResponse {
79 /// let wait_time = negative
80 /// .wait_time_from(DefaultClock::default().now())
81 /// .as_secs();
82 /// response
83 /// .content_type(ContentType::plaintext())
84 /// .body(format!("Too many requests, retry in {}s", wait_time))
85 /// }
86 /// }
87 /// ```
88 fn exceed_rate_limit_response(
89 &self,
90 negative: &NotUntil<QuantaInstant>,
91 mut response: HttpResponseBuilder,
92 ) -> HttpResponse {
93 let wait_time = negative
94 .wait_time_from(DefaultClock::default().now())
95 .as_secs();
96 response
97 .content_type(ContentType::plaintext())
98 .body(format!("Too many requests, retry in {}s", wait_time))
99 }
100
101 /// Returns a list of whitelisted keys. If a key is in this list, it will never be rate-limited.
102 fn whitelisted_keys(&self) -> Vec<Self::Key> {
103 Vec::new()
104 }
105
106 #[cfg(feature = "log")]
107 /// Value of the extracted key (only used in logs).
108 fn key_name(&self, _key: &Self::Key) -> Option<String> {
109 None
110 }
111}
112
/// A [`KeyExtractor`] that maps every incoming request to the same key, so all requests share a
/// single rate limit.
///
/// This is useful if you want to hard-limit the HTTP load your app can handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GlobalKeyExtractor;
116
117#[derive(Debug)]
118/// A simple struct to create error, by default the status is 500 server error and content-type is plintext
119pub struct SimpleKeyExtractionError<T: Display + Debug> {
120 /// The response body of the error.
121 pub body: T,
122 /// The status code of the error.
123 pub status_code: StatusCode,
124 /// The content type of the error.
125 pub content_type: ContentType,
126}
127
128impl<T: Display + Debug> SimpleKeyExtractionError<T> {
129 /// Create new instance by body
130 ///
131 /// # Example
132 /// ```rust
133 /// use actix_governor::SimpleKeyExtractionError;
134 /// use actix_http::StatusCode;
135 /// use actix_web::http::header::ContentType;
136 ///
137 /// let my_error = SimpleKeyExtractionError::new("Some error content");
138 ///
139 /// assert_eq!(my_error.body, "Some error content");
140 /// assert_eq!(my_error.content_type, ContentType::plaintext());
141 /// assert_eq!(my_error.status_code, StatusCode::INTERNAL_SERVER_ERROR);
142 /// ```
143 pub fn new(body: T) -> Self {
144 Self {
145 body,
146 status_code: StatusCode::INTERNAL_SERVER_ERROR,
147 content_type: ContentType::plaintext(),
148 }
149 }
150
151 /// Set a new status code, the default is [`StatusCode::INTERNAL_SERVER_ERROR`]
152 ///
153 /// # Example
154 /// ```rust
155 /// use actix_governor::SimpleKeyExtractionError;
156 /// use actix_http::StatusCode;
157 /// use actix_web::http::header::ContentType;
158 ///
159 /// let my_error = SimpleKeyExtractionError::new("Some error content")
160 /// .set_status_code(StatusCode::FORBIDDEN);
161 ///
162 /// assert_eq!(my_error.body, "Some error content");
163 /// assert_eq!(my_error.content_type, ContentType::plaintext());
164 /// assert_eq!(my_error.status_code, StatusCode::FORBIDDEN);
165 /// ```
166 pub fn set_status_code(mut self, status_code: StatusCode) -> Self {
167 self.status_code = status_code;
168 Self { ..self }
169 }
170
171 /// Set a new content type, the default is `text/plain`
172 ///
173 /// # Example
174 /// ```rust
175 /// use actix_governor::SimpleKeyExtractionError;
176 /// use actix_http::StatusCode;
177 /// use actix_web::http::header::ContentType;
178 ///
179 /// let my_error = SimpleKeyExtractionError::new(r#"{"msg":"Some error content"}"#)
180 /// .set_content_type(ContentType::json());
181 ///
182 /// assert_eq!(my_error.body, r#"{"msg":"Some error content"}"#);
183 /// assert_eq!(my_error.content_type, ContentType::json());
184 /// assert_eq!(my_error.status_code, StatusCode::INTERNAL_SERVER_ERROR);
185 /// ```
186 pub fn set_content_type(mut self, content_type: ContentType) -> Self {
187 self.content_type = content_type;
188 Self { ..self }
189 }
190}
191
192impl<T: Display + Debug> Display for SimpleKeyExtractionError<T> {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 write!(f, "SimpleKeyExtractionError")
195 }
196}
197
198impl<T: Display + Debug> ResponseError for SimpleKeyExtractionError<T> {
199 fn status_code(&self) -> StatusCode {
200 self.status_code
201 }
202
203 fn error_response(&self) -> HttpResponse<actix_http::body::BoxBody> {
204 HttpResponseBuilder::new(self.status_code())
205 .content_type(self.content_type.clone())
206 .body(self.body.to_string())
207 }
208}
209
210impl KeyExtractor for GlobalKeyExtractor {
211 type Key = ();
212 type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
213
214 #[cfg(feature = "log")]
215 fn name(&self) -> &'static str {
216 "global"
217 }
218
219 fn extract(&self, _req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
220 Ok(())
221 }
222
223 #[cfg(feature = "log")]
224 fn key_name(&self, _key: &Self::Key) -> Option<String> {
225 None
226 }
227}
228
/// A [`KeyExtractor`] that uses the peer IP address as key.
///
/// **This is the default key extractor, and it may not do what you want** (see the warning
/// below).
///
/// **Warning:** this key extractor enforces rate limiting based on the **IPv4 _peer_ IP address**
/// or the **IPv6 /56 prefix of the _peer_ IP address**.
///
/// This means that if your app is deployed behind a reverse proxy, the peer IP address will
/// _always_ be the proxy's IP address. In this case, rate limiting will be applied to _all_
/// incoming requests as if they came from the same user.
///
/// If this is not the behavior you want, you may:
/// - implement your own [`KeyExtractor`] that tries to get the IP from the `Forwarded` or
///   `X-Forwarded-For` headers that most reverse proxies set
/// - make absolutely sure that you only trust these headers when the peer IP is the IP of your
///   reverse proxy (otherwise any user could set them to fake their IP)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PeerIpKeyExtractor;
242
243impl KeyExtractor for PeerIpKeyExtractor {
244 type Key = IpAddr;
245 type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
246
247 #[cfg(feature = "log")]
248 fn name(&self) -> &'static str {
249 "peer IP"
250 }
251
252 fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
253 let mut ip = req.peer_addr().map(|socket| socket.ip()).ok_or_else(|| {
254 SimpleKeyExtractionError::new("Could not extract peer IP address from request")
255 })?;
256 // customers often get their own /56 prefix, apply rate-limiting per prefix instead of per
257 // address for IPv6
258 if let IpAddr::V6(ipv6) = ip {
259 let mut octets = ipv6.octets();
260 octets[7..16].fill(0);
261 ip = IpAddr::V6(octets.into());
262 }
263 Ok(ip)
264 }
265
/// Renders the key as its textual IP form for log output.
#[cfg(feature = "log")]
fn key_name(&self, key: &Self::Key) -> Option<String> {
    let rendered = key.to_string();
    Some(rendered)
}
270}