rocket_community/fairing/ad_hoc.rs

1use futures::future::{BoxFuture, Future, FutureExt};
2use parking_lot::Mutex;
3
4use crate::fairing::{Fairing, Info, Kind, Result};
5use crate::route::RouteUri;
6use crate::trace::Trace;
7use crate::{Build, Data, Orbit, Request, Response, Rocket};
8
9/// A ad-hoc fairing that can be created from a function or closure.
10///
11/// This enum can be used to create a fairing from a simple function or closure
12/// without creating a new structure or implementing `Fairing` directly.
13///
14/// # Usage
15///
16/// Use [`AdHoc::on_ignite`], [`AdHoc::on_liftoff`], [`AdHoc::on_request()`], or
17/// [`AdHoc::on_response()`] to create an `AdHoc` structure from a function or
18/// closure. Then, simply attach the structure to the `Rocket` instance.
19///
20/// # Example
21///
22/// The following snippet creates a `Rocket` instance with two ad-hoc fairings.
23/// The first, a liftoff fairing named "Liftoff Printer", simply prints a message
24/// indicating that Rocket has launched. The second named "Put Rewriter", a
25/// request fairing, rewrites the method of all requests to be `PUT`.
26///
27/// ```rust
28/// # extern crate rocket_community as rocket;
29/// use rocket::fairing::AdHoc;
30/// use rocket::http::Method;
31///
32/// rocket::build()
33///     .attach(AdHoc::on_liftoff("Liftoff Printer", |_| Box::pin(async move {
34///         println!("...annnddd we have liftoff!");
35///     })))
36///     .attach(AdHoc::on_request("Put Rewriter", |req, _| Box::pin(async move {
37///         req.set_method(Method::Put);
38///     })));
39/// ```
40pub struct AdHoc {
41    name: &'static str,
42    kind: AdHocKind,
43}
44
45struct Once<F: ?Sized>(Mutex<Option<Box<F>>>);
46
47impl<F: ?Sized> Once<F> {
48    fn new(f: Box<F>) -> Self {
49        Once(Mutex::new(Some(f)))
50    }
51
52    #[track_caller]
53    fn take(&self) -> Box<F> {
54        self.0.lock().take().expect("Once::take() called once")
55    }
56}
57
58enum AdHocKind {
59    /// An ad-hoc **ignite** fairing. Called during ignition.
60    Ignite(Once<dyn FnOnce(Rocket<Build>) -> BoxFuture<'static, Result> + Send + 'static>),
61
62    /// An ad-hoc **liftoff** fairing. Called just after Rocket launches.
63    Liftoff(Once<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>),
64
65    /// An ad-hoc **request** fairing. Called when a request is received.
66    Request(
67        Box<
68            dyn for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()>
69                + Send
70                + Sync
71                + 'static,
72        >,
73    ),
74
75    /// An ad-hoc **response** fairing. Called when a response is ready to be
76    /// sent to a client.
77    Response(
78        Box<
79            dyn for<'r, 'b> Fn(&'r Request<'_>, &'b mut Response<'r>) -> BoxFuture<'b, ()>
80                + Send
81                + Sync
82                + 'static,
83        >,
84    ),
85
86    /// An ad-hoc **shutdown** fairing. Called on shutdown.
87    Shutdown(Once<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>),
88}
89
90impl AdHoc {
91    /// Constructs an `AdHoc` ignite fairing named `name`. The function `f` will
92    /// be called by Rocket during the [`Rocket::ignite()`] phase.
93    ///
94    /// This version of an `AdHoc` ignite fairing cannot abort ignite. For a
95    /// fallible version that can, see [`AdHoc::try_on_ignite()`].
96    ///
97    /// # Example
98    ///
99    /// ```rust
100    /// # extern crate rocket_community as rocket;
101    /// use rocket::fairing::AdHoc;
102    ///
103    /// // The no-op ignite fairing.
104    /// let fairing = AdHoc::on_ignite("Boom!", |rocket| async move {
105    ///     rocket
106    /// });
107    /// ```
108    pub fn on_ignite<F, Fut>(name: &'static str, f: F) -> AdHoc
109    where
110        F: FnOnce(Rocket<Build>) -> Fut + Send + 'static,
111        Fut: Future<Output = Rocket<Build>> + Send + 'static,
112    {
113        AdHoc::try_on_ignite(name, |rocket| f(rocket).map(Ok))
114    }
115
116    /// Constructs an `AdHoc` ignite fairing named `name`. The function `f` will
117    /// be called by Rocket during the [`Rocket::ignite()`] phase. Returning an
118    /// `Err` aborts ignition and thus launch.
119    ///
120    /// For an infallible version, see [`AdHoc::on_ignite()`].
121    ///
122    /// # Example
123    ///
124    /// ```rust
125    /// # extern crate rocket_community as rocket;
126    /// use rocket::fairing::AdHoc;
127    ///
128    /// // The no-op try ignite fairing.
129    /// let fairing = AdHoc::try_on_ignite("No-Op", |rocket| async { Ok(rocket) });
130    /// ```
131    pub fn try_on_ignite<F, Fut>(name: &'static str, f: F) -> AdHoc
132    where
133        F: FnOnce(Rocket<Build>) -> Fut + Send + 'static,
134        Fut: Future<Output = Result> + Send + 'static,
135    {
136        AdHoc {
137            name,
138            kind: AdHocKind::Ignite(Once::new(Box::new(|r| f(r).boxed()))),
139        }
140    }
141
142    /// Constructs an `AdHoc` liftoff fairing named `name`. The function `f`
143    /// will be called by Rocket just after [`Rocket::launch()`].
144    ///
145    /// # Example
146    ///
147    /// ```rust
148    /// # extern crate rocket_community as rocket;
149    /// use rocket::fairing::AdHoc;
150    ///
151    /// // A fairing that prints a message just before launching.
152    /// let fairing = AdHoc::on_liftoff("Boom!", |_| Box::pin(async move {
153    ///     println!("Rocket has lifted off!");
154    /// }));
155    /// ```
156    pub fn on_liftoff<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
157    where
158        F: for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()>,
159    {
160        AdHoc {
161            name,
162            kind: AdHocKind::Liftoff(Once::new(Box::new(f))),
163        }
164    }
165
166    /// Constructs an `AdHoc` request fairing named `name`. The function `f`
167    /// will be called and the returned `Future` will be `await`ed by Rocket
168    /// when a new request is received.
169    ///
170    /// # Example
171    ///
172    /// ```rust
173    /// # extern crate rocket_community as rocket;
174    /// use rocket::fairing::AdHoc;
175    ///
176    /// // The no-op request fairing.
177    /// let fairing = AdHoc::on_request("Dummy", |req, data| {
178    ///     Box::pin(async move {
179    ///         // do something with the request and data...
180    /// #       let (_, _) = (req, data);
181    ///     })
182    /// });
183    /// ```
184    pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
185    where
186        F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()>,
187    {
188        AdHoc {
189            name,
190            kind: AdHocKind::Request(Box::new(f)),
191        }
192    }
193
194    // FIXME(rustc): We'd like to allow passing `async fn` to these methods...
195    // https://github.com/rust-lang/rust/issues/64552#issuecomment-666084589
196
197    /// Constructs an `AdHoc` response fairing named `name`. The function `f`
198    /// will be called and the returned `Future` will be `await`ed by Rocket
199    /// when a response is ready to be sent.
200    ///
201    /// # Example
202    ///
203    /// ```rust
204    /// # extern crate rocket_community as rocket;
205    /// use rocket::fairing::AdHoc;
206    ///
207    /// // The no-op response fairing.
208    /// let fairing = AdHoc::on_response("Dummy", |req, resp| {
209    ///     Box::pin(async move {
210    ///         // do something with the request and pending response...
211    /// #       let (_, _) = (req, resp);
212    ///     })
213    /// });
214    /// ```
215    pub fn on_response<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
216    where
217        F: for<'b, 'r> Fn(&'r Request<'_>, &'b mut Response<'r>) -> BoxFuture<'b, ()>,
218    {
219        AdHoc {
220            name,
221            kind: AdHocKind::Response(Box::new(f)),
222        }
223    }
224
225    /// Constructs an `AdHoc` shutdown fairing named `name`. The function `f`
226    /// will be called by Rocket when [shutdown is triggered].
227    ///
228    /// [shutdown is triggered]: crate::config::ShutdownConfig#triggers
229    ///
230    /// # Example
231    ///
232    /// ```rust
233    /// # extern crate rocket_community as rocket;
234    /// use rocket::fairing::AdHoc;
235    ///
236    /// // A fairing that prints a message just before launching.
237    /// let fairing = AdHoc::on_shutdown("Bye!", |_| Box::pin(async move {
238    ///     println!("Rocket is on its way back!");
239    /// }));
240    /// ```
241    pub fn on_shutdown<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
242    where
243        F: for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()>,
244    {
245        AdHoc {
246            name,
247            kind: AdHocKind::Shutdown(Once::new(Box::new(f))),
248        }
249    }
250
251    /// Constructs an `AdHoc` launch fairing that extracts a configuration of
252    /// type `T` from the configured provider and stores it in managed state. If
253    /// extractions fails, pretty-prints the error message and aborts launch.
254    ///
255    /// # Example
256    ///
257    /// ```rust
258    /// # extern crate rocket_community as rocket;
259    /// # use rocket::launch;
260    /// use serde::Deserialize;
261    /// use rocket::fairing::AdHoc;
262    ///
263    /// #[derive(Deserialize)]
264    /// struct Config {
265    ///     field: String,
266    ///     other: usize,
267    ///     /* and so on.. */
268    /// }
269    ///
270    /// #[launch]
271    /// fn rocket() -> _ {
272    ///     rocket::build().attach(AdHoc::config::<Config>())
273    /// }
274    /// ```
275    pub fn config<'de, T>() -> AdHoc
276    where
277        T: serde::Deserialize<'de> + Send + Sync + 'static,
278    {
279        AdHoc::try_on_ignite(std::any::type_name::<T>(), |rocket| async {
280            let app_config = match rocket.figment().extract::<T>() {
281                Ok(config) => config,
282                Err(e) => {
283                    e.trace_error();
284                    return Err(rocket);
285                }
286            };
287
288            Ok(rocket.manage(app_config))
289        })
290    }
291
292    /// Constructs an `AdHoc` request fairing that strips trailing slashes from
293    /// all URIs in all incoming requests.
294    ///
295    /// The fairing returned by this method is intended largely for applications
296    /// that migrated from Rocket v0.4 to Rocket v0.5. In Rocket v0.4, requests
297    /// with a trailing slash in the URI were treated as if the trailing slash
298    /// were not present. For example, the request URI `/foo/` would match the
299    /// route `/<a>` with `a = foo`. If the application depended on this
300    /// behavior, say by using URIs with previously innocuous trailing slashes
301    /// in an external application, requests will not be routed as expected.
302    ///
303    /// This fairing resolves this issue by stripping a trailing slash, if any,
304    /// in all incoming URIs. When it does so, it logs a warning. It is
305    /// recommended to use this fairing as a stop-gap measure instead of a
306    /// permanent resolution, if possible.
307    //
308    /// # Example
309    ///
310    /// With the fairing attached, request URIs have a trailing slash stripped:
311    ///
312    /// ```rust
313    /// # #[macro_use] extern crate rocket_community as rocket;
314    /// use rocket::local::blocking::Client;
315    /// use rocket::fairing::AdHoc;
316    ///
317    /// #[get("/<param>")]
318    /// fn foo(param: &str) -> &str {
319    ///     param
320    /// }
321    ///
322    /// #[launch]
323    /// fn rocket() -> _ {
324    ///     rocket::build()
325    ///         .mount("/", routes![foo])
326    ///         .attach(AdHoc::uri_normalizer())
327    /// }
328    ///
329    /// # let client = Client::debug(rocket()).unwrap();
330    /// let response = client.get("/bar/").dispatch();
331    /// assert_eq!(response.into_string().unwrap(), "bar");
332    /// ```
333    ///
334    /// Without it, request URIs are unchanged and routed normally:
335    ///
336    /// ```rust
337    /// # #[macro_use] extern crate rocket_community as rocket;
338    /// use rocket::local::blocking::Client;
339    /// use rocket::fairing::AdHoc;
340    ///
341    /// #[get("/<param>")]
342    /// fn foo(param: &str) -> &str {
343    ///     param
344    /// }
345    ///
346    /// #[launch]
347    /// fn rocket() -> _ {
348    ///     rocket::build().mount("/", routes![foo])
349    /// }
350    ///
351    /// # let client = Client::debug(rocket()).unwrap();
352    /// let response = client.get("/bar/").dispatch();
353    /// assert!(response.status().class().is_client_error());
354    ///
355    /// let response = client.get("/bar").dispatch();
356    /// assert_eq!(response.into_string().unwrap(), "bar");
357    /// ```
358    // #[deprecated(since = "0.7", note = "routing from Rocket 0.6 is now standard")]
359    pub fn uri_normalizer() -> impl Fairing {
360        #[derive(Default)]
361        struct Normalizer {
362            routes: state::InitCell<Vec<crate::Route>>,
363        }
364
365        impl Normalizer {
366            fn routes(&self, rocket: &Rocket<Orbit>) -> &[crate::Route] {
367                self.routes.get_or_init(|| {
368                    rocket
369                        .routes()
370                        .filter(|r| r.uri.has_trailing_slash())
371                        .cloned()
372                        .collect()
373                })
374            }
375        }
376
377        #[crate::async_trait]
378        impl Fairing for Normalizer {
379            fn info(&self) -> Info {
380                Info {
381                    name: "URI Normalizer",
382                    kind: Kind::Ignite | Kind::Liftoff | Kind::Request,
383                }
384            }
385
386            async fn on_ignite(&self, rocket: Rocket<Build>) -> Result {
387                // We want a route like `/foo/<bar..>` to match a request for
388                // `/foo` as it would have before. While we could check if a
389                // route is mounted that would cause this match and then rewrite
390                // the request URI as `/foo/`, doing so is expensive and
391                // potentially incorrect due to request guards and ranking.
392                //
393                // Instead, we generate a new route with URI `/foo` with the
394                // same rank and handler as the `/foo/<bar..>` route and mount
395                // it to this instance of `rocket`. This preserves the previous
396                // matching while still checking request guards.
397                let normalized_trailing = rocket
398                    .routes()
399                    .filter(|r| r.uri.metadata.dynamic_trail)
400                    .filter(|r| r.uri.path().segments().num() > 1)
401                    .filter_map(|route| {
402                        let path = route.uri.unmounted().path();
403                        let new_path = path
404                            .as_str()
405                            .rsplit_once('/')
406                            .map(|(prefix, _)| prefix)
407                            .filter(|path| !path.is_empty())
408                            .unwrap_or("/");
409
410                        let base = route.uri.base().as_str();
411                        let uri = match route.uri.unmounted().query() {
412                            Some(q) => format!("{}?{}", new_path, q),
413                            None => new_path.to_string(),
414                        };
415
416                        let mut route = route.clone();
417                        route.uri = RouteUri::try_new(base, &uri).ok()?;
418                        route.name = route.name.map(|r| format!("{} [normalized]", r).into());
419                        Some(route)
420                    })
421                    .collect::<Vec<_>>();
422
423                Ok(rocket.mount("/", normalized_trailing))
424            }
425
426            async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
427                let _ = self.routes(rocket);
428            }
429
430            async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
431                // If the URI has no trailing slash, it routes as before.
432                if req.uri().is_normalized_nontrailing() {
433                    return;
434                }
435
436                // Otherwise, check if there's a route that matches the request
437                // with a trailing slash. If there is, leave the request alone.
438                // This allows incremental compatibility updates. Otherwise,
439                // rewrite the request URI to remove the `/`.
440                if !self.routes(req.rocket()).iter().any(|r| r.matches(req)) {
441                    let normalized = req.uri().clone().into_normalized_nontrailing();
442                    warn!(original = %req.uri(), %normalized,
443                        "incoming request URI normalized for compatibility");
444                    req.set_uri(normalized);
445                }
446            }
447        }
448
449        Normalizer::default()
450    }
451}
452
453#[crate::async_trait]
454impl Fairing for AdHoc {
455    fn info(&self) -> Info {
456        let kind = match self.kind {
457            AdHocKind::Ignite(_) => Kind::Ignite,
458            AdHocKind::Liftoff(_) => Kind::Liftoff,
459            AdHocKind::Request(_) => Kind::Request,
460            AdHocKind::Response(_) => Kind::Response,
461            AdHocKind::Shutdown(_) => Kind::Shutdown,
462        };
463
464        Info {
465            name: self.name,
466            kind,
467        }
468    }
469
470    async fn on_ignite(&self, rocket: Rocket<Build>) -> Result {
471        match self.kind {
472            AdHocKind::Ignite(ref f) => (f.take())(rocket).await,
473            _ => Ok(rocket),
474        }
475    }
476
477    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
478        if let AdHocKind::Liftoff(ref f) = self.kind {
479            (f.take())(rocket).await
480        }
481    }
482
483    async fn on_request(&self, req: &mut Request<'_>, data: &mut Data<'_>) {
484        if let AdHocKind::Request(ref f) = self.kind {
485            f(req, data).await
486        }
487    }
488
489    async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
490        if let AdHocKind::Response(ref f) = self.kind {
491            f(req, res).await
492        }
493    }
494
495    async fn on_shutdown(&self, rocket: &Rocket<Orbit>) {
496        if let AdHocKind::Shutdown(ref f) = self.kind {
497            (f.take())(rocket).await
498        }
499    }
500}