rocket_governor/lib.rs
1//! # rocket-governor - rate-limiting implementation for Rocket web framework
2//!
3//! Provides the [rocket] guard implementing rate-limiting (based on [governor]).
4//!
5//! Declare a struct and use it with the generic [RocketGovernor] guard.
//! This requires implementing the [RocketGovernable] trait for your struct.
7//!
8//! ## Example
9//!
10//! ```rust
11//! use rocket::{catchers, get, http::Status, launch, routes};
12//! use rocket_governor::{rocket_governor_catcher, Method, Quota, RocketGovernable, RocketGovernor};
13//!
14//! pub struct RateLimitGuard;
15//!
16//! impl<'r> RocketGovernable<'r> for RateLimitGuard {
17//! fn quota(_method: Method, _route_name: &str) -> Quota {
18//! Quota::per_second(Self::nonzero(1u32))
19//! }
20//! }
21//!
22//! #[get("/")]
23//! fn route_example(_limitguard: RocketGovernor<RateLimitGuard>) -> Status {
24//! Status::Ok
25//! }
26//!
27//! #[launch]
28//! fn launch_rocket() -> _ {
29//! rocket::build()
30//! .mount("/", routes![route_example])
31//! .register("/", catchers![rocket_governor_catcher])
32//! }
33//! ```
34//!
//! See the [rocket-governor] GitHub project for more information.
36//!
37//! ## Features
38//!
39//! ### Optional feature __limit_info__
40//!
//! The optional feature __limit_info__ enables reporting about rate limits
//! in the HTTP headers of responses.
43//!
//! The implementation is based on the header fields specified in
//! <https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers>.
//! The feature provides a default implementation of a Rocket fairing,
//! which needs to be attached to get the HTTP headers set.
48//!
49//! See API documentation for [LimitHeaderGen].
50//!
//! To use it, enable the feature in `Cargo.toml`:
52//! ```toml
53//! [dependencies]
54//! rocket-governor = { version = "...", features = ["limit_info"] }
55//! ```
56//!
57//! ### Optional feature __logger__
58//!
//! The optional feature __logger__ enables some logging output.
60//!
//! To use it, enable the feature in `Cargo.toml`:
62//! ```toml
63//! [dependencies]
64//! rocket-governor = { version = "...", features = ["logger"] }
65//! ```
66//!
67//! [governor]: https://docs.rs/governor/
68//! [rocket]: https://docs.rs/rocket/
69//! [rocket-governor]: https://github.com/kolbma/rocket-governor/
70
71#![deny(clippy::all)]
72#![deny(keyword_idents)]
73#![deny(missing_docs)]
74#![deny(non_ascii_idents)]
75#![deny(unreachable_pub)]
76#![deny(unsafe_code)]
77#![deny(unused_crate_dependencies)]
78#![deny(unused_qualifications)]
79//#![deny(unused_results)]
80#![deny(warnings)]
81
82use governor::clock::{Clock, DefaultClock};
83pub use governor::Quota;
84use lazy_static::lazy_static;
85pub use limit_error::LimitError;
86#[cfg(feature = "limit_info")]
87pub use limit_header_gen::LimitHeaderGen;
88use logger::{error, info, trace};
89use registry::Registry;
90#[cfg(feature = "limit_info")]
91pub use req_state::ReqState;
92pub use rocket::http::Method;
93use rocket::{
94 async_trait, catch,
95 http::Status,
96 request::{FromRequest, Outcome},
97 Request,
98};
99pub use rocket_governable::RocketGovernable;
100use std::marker::PhantomData;
101pub use std::num::NonZeroU32;
102
103pub mod header;
104mod limit_error;
105#[cfg(feature = "limit_info")]
106mod limit_header_gen;
107mod logger;
108mod registry;
109#[cfg(feature = "limit_info")]
110mod req_state;
111mod rocket_governable;
112
113/// Generic [RocketGovernor] implementation.
114///
115/// [rocket_governor](crate) is a [rocket] guard implementation of the
116/// [governor] rate limiter.
117///
118/// Declare a struct and use it with the generic [RocketGovernor] guard.
119/// This requires to implement [RocketGovernable] for your struct.
120///
121/// See the top level [crate] documentation.
122///
123/// [governor]: https://docs.rs/governor/
124/// [rocket]: https://docs.rs/rocket/
125///
126pub struct RocketGovernor<'r, T>
127where
128 T: RocketGovernable<'r>,
129{
130 _phantom: PhantomData<&'r T>,
131}
132
133lazy_static! {
134 static ref CLOCK: DefaultClock = DefaultClock::default();
135}
136
137#[doc(hidden)]
138impl<'r, T> RocketGovernor<'r, T>
139where
140 T: RocketGovernable<'r>,
141{
142 /// Handler used in `FromRequest::from_request(request: &'r Request)`.
143 #[inline(always)]
144 pub fn handle_from_request(request: &'r Request) -> Outcome<Self, LimitError> {
145 let res = request.local_cache(|| {
146 if let Some(route) = request.route() {
147 if let Some(route_name) = &route.name {
148 let limiter = Registry::get_or_insert::<T>(
149 route.method,
150 route_name,
151 T::quota(route.method, route_name),
152 );
153 if let Some(client_ip) = request.client_ip() {
154 let limit_check_res = limiter.check_key(&client_ip);
155 match limit_check_res {
156 Ok(state) => {
157 #[allow(unused_variables)] // only used in trace or when feature limit_info
158 let request_capacity = state.remaining_burst_capacity();
159 trace!(
160 "not governed ip {} method {} route {}: remaining request capacity {}",
161 &client_ip,
162 &route.method,
163 route_name,
164 request_capacity
165 );
166
167 #[cfg(feature = "limit_info")] {
168 // `local_cache` lookup works by type and so it doesn't work to catch
169 // `LimitError` and handle different Ok objects:
170 // See https://rocket.rs/v0.5/guide/state/#request-local-state
171 // State wrapper is so cached separate...
172 let req_state = ReqState::new(state.quota(), request_capacity);
173 let is_req_state_allowed = T::limit_info_allow(Some(route.method), Some(route_name), &req_state);
174 if is_req_state_allowed {
175 // For safety and speed this is used by default in a limited way, see:
176 // * Information disclosure:
177 // https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers#section-6.2
178 //
179 let _ = request.local_cache(|| req_state);
180 }
181 }
182
183 Ok(()) // needs to be something not changing during request
184 }
185 Err(notuntil) => {
186 let wait_time = notuntil.wait_time_from(CLOCK.now()).as_secs();
187 info!(
188 "ip {} method {} route {} limited {} sec",
189 &client_ip, &route.method, route_name, &wait_time
190 );
191 Err(LimitError::GovernedRequest(wait_time, notuntil.quota()))
192 }
193 }
194 } else {
195 error!(
196 "missing ip - method {} route {}: request: {:?}",
197 &route.method, route_name, request
198 );
199 Err(LimitError::MissingClientIpAddr)
200 }
201 } else {
202 error!("route without name: request: {:?}", request);
203 Err(LimitError::MissingRouteName)
204 }
205 } else {
206 error!("routing failure: request: {:?}", request);
207 Err(LimitError::MissingRoute)
208 }
209 });
210
211 match res {
212 Ok(_) => {
213 #[cfg(feature = "limit_info")]
214 {
215 // available if `T::limit_info_allow()` is true
216 let state_opt = ReqState::get_or_default(request);
217 #[allow(unused_variables)] // state only used in trace
218 if let Some(state) = state_opt {
219 trace!(
220 "request_capacity: {} rate-limit: {}",
221 state.request_capacity,
222 state.quota.burst_size().get()
223 );
224 }
225 }
226
227 // Forward request
228 Outcome::Success(Self::default())
229 }
230 Err(e) => {
231 let e = e.clone();
232 match e {
233 LimitError::GovernedRequest(_, _) => {
234 Outcome::Error((Status::TooManyRequests, e))
235 }
236 _ => Outcome::Error((Status::BadRequest, e)),
237 }
238 }
239 }
240 }
241}
242
243#[doc(hidden)]
244impl<'r, T> Default for RocketGovernor<'r, T>
245where
246 T: RocketGovernable<'r>,
247{
248 fn default() -> Self {
249 Self {
250 _phantom: PhantomData,
251 }
252 }
253}
254
255#[doc(hidden)]
256#[async_trait]
257impl<'r, T> FromRequest<'r> for RocketGovernor<'r, T>
258where
259 T: RocketGovernable<'r>,
260{
261 type Error = LimitError;
262
263 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, LimitError> {
264 Self::handle_from_request(request)
265 }
266}
267
268/// A default implementation for Rocket [Catcher] handling HTTP TooManyRequests responses.
269///
270/// ## Example
271///
272/// ```rust
273/// use rocket::{catchers, launch};
274/// use rocket_governor::rocket_governor_catcher;
275///
276/// #[launch]
277/// fn launch_rocket() -> _ {
278/// rocket::build()
279/// .register("/", catchers![rocket_governor_catcher])
280/// }
281/// ```
282///
283/// [Catcher]: https://api.rocket.rs/v0.5/rocket/struct.Catcher.html
284#[catch(429)]
285pub fn rocket_governor_catcher<'r>(request: &'r Request) -> &'r LimitError {
286 let cached_res: &Result<(), LimitError> = request.local_cache(|| Err(LimitError::Error));
287 if let Err(limit_err) = cached_res {
288 limit_err
289 } else {
290 &LimitError::Error
291 }
292}