use futures::future::{BoxFuture, Future, FutureExt};
use parking_lot::Mutex;
use crate::fairing::{Fairing, Info, Kind, Result};
use crate::route::RouteUri;
use crate::trace::Trace;
use crate::{Build, Data, Orbit, Request, Response, Rocket};
/// A fairing assembled from a single callback.
///
/// Each constructor on this type stores one callback in `kind`, tagged with the
/// lifecycle event it handles. The `Fairing` implementation for this type
/// reports `name` and the matching `Kind` from `info()`, and invokes the
/// callback only from the hook that corresponds to its variant.
pub struct AdHoc {
/// Name reported through `Info::name`.
name: &'static str,
/// The stored callback together with the event it is attached to.
kind: AdHocKind,
}
/// A mutex-guarded slot holding a boxed value that can be moved out through a
/// shared reference.
///
/// The inner `Option` is `Some` after `new()` and becomes `None` once `take()`
/// has removed the value; a further `take()` panics.
struct Once<F: ?Sized>(Mutex<Option<Box<F>>>);
impl<F: ?Sized> Once<F> {
fn new(f: Box<F>) -> Self {
Once(Mutex::new(Some(f)))
}
#[track_caller]
fn take(&self) -> Box<F> {
self.0.lock().take().expect("Once::take() called once")
}
}
/// The callback stored by an `AdHoc` fairing, tagged by the event it handles.
///
/// The `Ignite`, `Liftoff` and `Shutdown` callbacks are `FnOnce` and are kept
/// in a `Once` slot so they can be moved out through `&self` when invoked.
/// The `Request` and `Response` callbacks are `Fn` and are called by
/// reference, so they may run any number of times.
enum AdHocKind {
/// One-shot callback that receives the `Rocket<Build>` by value and returns
/// a boxed `'static` future resolving to a fairing `Result`.
Ignite(Once<dyn FnOnce(Rocket<Build>) -> BoxFuture<'static, Result> + Send + 'static>),
/// One-shot callback that borrows the `Rocket<Orbit>`; the returned future
/// may borrow from it for the same lifetime.
Liftoff(Once<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>),
/// Reusable callback given mutable access to the request and its data; the
/// returned future may hold both borrows.
Request(
Box<
dyn for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()>
+ Send
+ Sync
+ 'static,
>,
),
/// Reusable callback given shared access to the request and mutable access
/// to the response; the returned future is bound to the response borrow.
Response(
Box<
dyn for<'r, 'b> Fn(&'r Request<'_>, &'b mut Response<'r>) -> BoxFuture<'b, ()>
+ Send
+ Sync
+ 'static,
>,
),
/// One-shot callback with the same shape as `Liftoff`, invoked from the
/// shutdown hook instead.
Shutdown(Once<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>),
}
impl AdHoc {
pub fn on_ignite<F, Fut>(name: &'static str, f: F) -> AdHoc
where
F: FnOnce(Rocket<Build>) -> Fut + Send + 'static,
Fut: Future<Output = Rocket<Build>> + Send + 'static,
{
AdHoc::try_on_ignite(name, |rocket| f(rocket).map(Ok))
}
pub fn try_on_ignite<F, Fut>(name: &'static str, f: F) -> AdHoc
where
F: FnOnce(Rocket<Build>) -> Fut + Send + 'static,
Fut: Future<Output = Result> + Send + 'static,
{
AdHoc {
name,
kind: AdHocKind::Ignite(Once::new(Box::new(|r| f(r).boxed()))),
}
}
pub fn on_liftoff<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where
F: for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()>,
{
AdHoc {
name,
kind: AdHocKind::Liftoff(Once::new(Box::new(f))),
}
}
pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where
F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()>,
{
AdHoc {
name,
kind: AdHocKind::Request(Box::new(f)),
}
}
pub fn on_response<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where
F: for<'b, 'r> Fn(&'r Request<'_>, &'b mut Response<'r>) -> BoxFuture<'b, ()>,
{
AdHoc {
name,
kind: AdHocKind::Response(Box::new(f)),
}
}
pub fn on_shutdown<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where
F: for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()>,
{
AdHoc {
name,
kind: AdHocKind::Shutdown(Once::new(Box::new(f))),
}
}
pub fn config<'de, T>() -> AdHoc
where
T: serde::Deserialize<'de> + Send + Sync + 'static,
{
AdHoc::try_on_ignite(std::any::type_name::<T>(), |rocket| async {
let app_config = match rocket.figment().extract::<T>() {
Ok(config) => config,
Err(e) => {
e.trace_error();
return Err(rocket);
}
};
Ok(rocket.manage(app_config))
})
}
/// Returns a fairing that keeps requests with trailing slashes routable.
///
/// The fairing acts at two points:
///
/// * **Ignite:** for every route whose URI metadata has `dynamic_trail` set
///   and whose path has more than one segment, a copy of the route is
///   mounted with the final path segment removed (the query, if any, is
///   kept). Copies of named routes get a ` [normalized]` name suffix.
///   Routes for which the shortened URI fails to parse are skipped.
/// * **Request:** a request whose URI is not already
///   `is_normalized_nontrailing()` is rewritten with
///   `into_normalized_nontrailing()`, unless one of the mounted routes with
///   a trailing slash matches the request as-is.
///
/// The list of trailing-slash routes is computed once, at liftoff or on the
/// first request that needs it, and then reused.
pub fn uri_normalizer() -> impl Fairing {
    /// Fairing state: the mounted routes whose URI has a trailing slash,
    /// initialized on first use.
    #[derive(Default)]
    struct Normalizer {
        routes: state::InitCell<Vec<crate::Route>>,
    }

    impl Normalizer {
        /// Returns the cached trailing-slash routes, collecting them from
        /// `rocket` if the cache has not been initialized yet.
        fn routes(&self, rocket: &Rocket<Orbit>) -> &[crate::Route] {
            let collect_trailing = || {
                rocket
                    .routes()
                    .filter(|route| route.uri.has_trailing_slash())
                    .cloned()
                    .collect::<Vec<_>>()
            };

            self.routes.get_or_init(collect_trailing)
        }
    }

    #[crate::async_trait]
    impl Fairing for Normalizer {
        fn info(&self) -> Info {
            let kind = Kind::Ignite | Kind::Liftoff | Kind::Request;
            Info {
                name: "URI Normalizer",
                kind,
            }
        }

        async fn on_ignite(&self, rocket: Rocket<Build>) -> Result {
            let normalized_trailing: Vec<_> = rocket
                .routes()
                .filter(|r| r.uri.metadata.dynamic_trail && r.uri.path().segments().num() > 1)
                .filter_map(|route| {
                    let unmounted = route.uri.unmounted();

                    // Everything before the last `/`; an empty or missing
                    // prefix means the parent is the root path.
                    let parent = match unmounted.path().as_str().rsplit_once('/') {
                        Some((prefix, _)) if !prefix.is_empty() => prefix,
                        _ => "/",
                    };

                    let uri = if let Some(q) = unmounted.query() {
                        format!("{}?{}", parent, q)
                    } else {
                        parent.to_string()
                    };

                    let base = route.uri.base().as_str();
                    let mut shortened = route.clone();
                    shortened.uri = RouteUri::try_new(base, &uri).ok()?;
                    shortened.name = shortened
                        .name
                        .map(|n| format!("{} [normalized]", n).into());
                    Some(shortened)
                })
                .collect();

            Ok(rocket.mount("/", normalized_trailing))
        }

        async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
            // Fill the route cache now; the returned slice is not needed here.
            let _ = self.routes(rocket);
        }

        async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
            // Nothing to rewrite for URIs that already pass this check.
            if req.uri().is_normalized_nontrailing() {
                return;
            }

            // Leave the request alone when a trailing-slash route matches it
            // in its current form.
            let matches_as_is = self
                .routes(req.rocket())
                .iter()
                .any(|route| route.matches(req));
            if matches_as_is {
                return;
            }

            let normalized = req.uri().clone().into_normalized_nontrailing();
            warn!(original = %req.uri(), %normalized,
                "incoming request URI normalized for compatibility");
            req.set_uri(normalized);
        }
    }

    Normalizer::default()
}
}
/// Dispatches each fairing hook to the stored callback, if the callback was
/// registered for that hook; every other hook is a no-op.
#[crate::async_trait]
impl Fairing for AdHoc {
    /// Reports the fairing's name and the single `Kind` matching the stored
    /// callback.
    fn info(&self) -> Info {
        let kind = match &self.kind {
            AdHocKind::Ignite(_) => Kind::Ignite,
            AdHocKind::Liftoff(_) => Kind::Liftoff,
            AdHocKind::Request(_) => Kind::Request,
            AdHocKind::Response(_) => Kind::Response,
            AdHocKind::Shutdown(_) => Kind::Shutdown,
        };

        Info {
            name: self.name,
            kind,
        }
    }

    /// Runs the ignite callback, or passes `rocket` through unchanged when
    /// this fairing holds a different kind of callback.
    ///
    /// The callback is taken out of its one-shot slot, so a second
    /// invocation on an ignite fairing panics.
    async fn on_ignite(&self, rocket: Rocket<Build>) -> Result {
        match &self.kind {
            AdHocKind::Ignite(once) => {
                let callback = once.take();
                callback(rocket).await
            }
            _ => Ok(rocket),
        }
    }

    /// Runs the liftoff callback, if any. The callback is taken out of its
    /// one-shot slot, so a second invocation on a liftoff fairing panics.
    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
        if let AdHocKind::Liftoff(once) = &self.kind {
            let callback = once.take();
            callback(rocket).await;
        }
    }

    /// Runs the request callback, if any.
    async fn on_request(&self, req: &mut Request<'_>, data: &mut Data<'_>) {
        if let AdHocKind::Request(handler) = &self.kind {
            handler(req, data).await;
        }
    }

    /// Runs the response callback, if any.
    async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
        if let AdHocKind::Response(handler) = &self.kind {
            handler(req, res).await;
        }
    }

    /// Runs the shutdown callback, if any. The callback is taken out of its
    /// one-shot slot, so a second invocation on a shutdown fairing panics.
    async fn on_shutdown(&self, rocket: &Rocket<Orbit>) {
        if let AdHocKind::Shutdown(once) = &self.kind {
            let callback = once.take();
            callback(rocket).await;
        }
    }
}