// saphir 3.1.0
//
// Fully async-await HTTP server framework.
// Documentation
//! Router is responsible for redirecting requests to handlers.
//!
//! *SAFETY NOTICE*
//!
//! Inside the router we need a little bit of unsafe code. This code allows us
//! to treat the futures generated by the handlers as 'static. This is
//! considered safe since all data lives within the server stack, which lives
//! for the whole lifetime of your application. We plan to remove this unsafe
//! code as soon as we find another solution.
#![allow(clippy::type_complexity)]

use crate::{
    body::Body,
    controller::{Controller, DynControllerHandler},
    error::SaphirError,
    guard::{Builder as GuardBuilder, GuardChain, GuardChainEnd},
    handler::DynHandler,
    http_context::{HandlerMetadata, HttpContext, RouteId, State},
    request::Request,
    responder::{DynResponder, Responder},
    utils::{EndpointResolver, EndpointResolverResult},
};
use futures::{future::BoxFuture, FutureExt};
use http::Method;
use std::{collections::HashMap, sync::Arc};

/// Builder type for the router.
///
/// `resolver` maps every registered route string to the endpoint resolver
/// that matches it. `chain` stores the handlers those endpoints dispatch to.
pub struct Builder<Chain>
where
    Chain: RouterChain + Send + Sync + Unpin + 'static,
{
    resolver: HashMap<String, EndpointResolver>,
    chain: Chain,
}

impl Default for Builder<RouterChainEnd> {
    fn default() -> Self {
        Self {
            resolver: Default::default(),
            chain: RouterChainEnd { handlers: Default::default() },
        }
    }
}

impl<Controllers: 'static + RouterChain + Unpin + Send + Sync> Builder<Controllers> {
    /// Add a simple request handle to a given path
    ///
    /// ```rust
    /// # use saphir::router::Builder as RBuilder;
    /// # use saphir::prelude::*;
    /// #
    /// # let builder = RBuilder::default();
    /// // Simply declare a handler fn
    /// async fn simple_handler(req: Request<Body>) -> impl Responder {200}
    ///
    /// // Then while building your server
    /// // ...
    /// builder.route("/simple", Method::GET, simple_handler);
    /// // ...
    /// ```
    pub fn route<H>(mut self, route: &str, method: Method, handler: H) -> Self
    where
        H: 'static + DynHandler<Body> + Send + Sync,
    {
        let endpoint_id = if let Some(er) = self.resolver.get_mut(route) {
            er.add_method(method.clone());
            er.id()
        } else {
            let er = EndpointResolver::new(route, method.clone()).expect("Unable to construct endpoint resolver");
            let er_id = er.id();
            self.resolver.insert(route.to_string(), er);
            er_id
        };

        self.chain
            .add_handler(endpoint_id, method, Box::new(handler), crate::guard::Builder::default().build());

        self
    }

    /// Add a request handler to a given path behind guards
    ///
    /// ```rust
    /// # use saphir::router::Builder as RBuilder;
    /// # use saphir::prelude::*;
    /// #
    /// # let builder = RBuilder::default();
    ///
    /// async fn handler(req: Request<Body>) -> impl Responder { 200 }
    ///
    /// async fn value_guard(req: Request<Body>) -> Result<Request<Body>, u16> {
    ///    match req.captures().get("value") {
    ///        Some(v) if v.eq("allowed_value") => Ok(req),
    ///        Some(_) => Err(403),
    ///        None => Err(400),
    ///    }
    /// }
    ///
    /// builder.route_with_guards("/handler/{value}", Method::GET, handler, |g| {
    ///     g.apply(value_guard)
    /// });
    /// // ...
    /// ```
    pub fn route_with_guards<H, F, Chain>(mut self, route: &str, method: Method, handler: H, guards: F) -> Self
    where
        H: 'static + DynHandler<Body> + Send + Sync,
        F: FnOnce(GuardBuilder<GuardChainEnd>) -> GuardBuilder<Chain>,
        Chain: GuardChain + 'static,
    {
        let endpoint_id = if let Some(er) = self.resolver.get_mut(route) {
            er.add_method(method.clone());
            er.id()
        } else {
            let er = EndpointResolver::new(route, method.clone()).expect("Unable to construct endpoint resolver");
            let er_id = er.id();
            self.resolver.insert(route.to_string(), er);
            er_id
        };

        self.chain
            .add_handler(endpoint_id, method, Box::new(handler), guards(GuardBuilder::default()).build());

        self
    }

    /// Add a simple request handle to a given path
    ///
    /// ```rust
    /// # use saphir::router::Builder as RBuilder;
    /// # use saphir::prelude::*;
    /// #
    /// # let builder = RBuilder::default();
    /// // Implement controller for your struct
    /// struct SimpleController;
    /// # impl Controller for SimpleController {
    /// #    const BASE_PATH: &'static str = "/basic";
    /// #    fn handlers(&self) -> Vec<ControllerEndpoint<Self>> where Self: Sized {EndpointsBuilder::new().build()}
    /// # }
    /// // Then while building your server
    /// // ...
    /// builder.controller(SimpleController);
    /// // ...
    /// ```
    pub fn controller<C: Controller + Send + Unpin + Sync>(mut self, controller: C) -> Builder<RouterChainLink<C, Controllers>> {
        let mut handlers = HashMap::new();
        for (name, method, subroute, handler, guard_chain) in controller.handlers() {
            let route = format!("{}{}", C::BASE_PATH, subroute);
            let meta = name.map(|name| HandlerMetadata {
                route_id: Default::default(),
                name: Some(name),
            });
            let endpoint_id = if let Some(er) = self.resolver.get_mut(&route) {
                er.add_method_with_metadata(method.clone(), meta);
                er.id()
            } else {
                let er = EndpointResolver::new_with_metadata(&route, method.clone(), meta).expect("Unable to construct endpoint resolver");
                let er_id = er.id();
                self.resolver.insert(route, er);
                er_id
            };

            handlers.insert((endpoint_id, method), (handler, guard_chain));
        }

        Builder {
            resolver: self.resolver,
            chain: RouterChainLink {
                controller,
                handlers,
                rest: self.chain,
            },
        }
    }

    pub(crate) fn build(self) -> Router {
        let Builder { resolver, chain: controllers } = self;

        let mut resolvers: Vec<_> = resolver.into_values().collect();
        resolvers.sort_unstable();

        Router {
            inner: Arc::new(RouterInner {
                resolvers,
                chain: Box::new(controllers),
            }),
        }
    }
}

/// Routing state shared by every clone of a `Router`.
///
/// It holds the endpoint resolvers and the chain of handler tables they
/// dispatch to.
struct RouterInner {
    resolvers: Vec<EndpointResolver>,
    chain: Box<dyn RouterChain + Send + Sync + Unpin>,
}

/// Cheaply clonable handle to the routing table.
///
/// Clones share one `RouterInner` through an `Arc`.
#[doc(hidden)]
#[derive(Clone)]
pub struct Router {
    inner: Arc<RouterInner>,
}

impl Router {
    pub fn builder() -> Builder<RouterChainEnd> {
        Builder::default()
    }

    pub fn resolve(&self, req: &mut Request<Body>) -> Result<u64, u16> {
        let mut method_not_allowed = false;
        for endpoint_resolver in &self.inner.resolvers {
            match endpoint_resolver.resolve(req) {
                EndpointResolverResult::InvalidPath => continue,
                EndpointResolverResult::MethodNotAllowed => method_not_allowed = true,
                EndpointResolverResult::Match(_) => return Ok(endpoint_resolver.id()),
            }
        }

        if method_not_allowed {
            Err(405)
        } else {
            Err(404)
        }
    }

    pub fn resolve_metadata(&self, req: &mut Request) -> HandlerMetadata {
        let mut method_not_allowed = false;

        for endpoint_resolver in &self.inner.resolvers {
            match endpoint_resolver.resolve(req) {
                EndpointResolverResult::InvalidPath => continue,
                EndpointResolverResult::MethodNotAllowed => method_not_allowed = true,
                EndpointResolverResult::Match(meta) => return meta.clone(),
            }
        }

        if method_not_allowed {
            HandlerMetadata::not_allowed()
        } else {
            HandlerMetadata::not_found()
        }
    }

    pub async fn dispatch(&self, mut ctx: HttpContext) -> Result<HttpContext, SaphirError> {
        let req = ctx.state.take_request().ok_or(SaphirError::RequestMovedBeforeHandler)?;
        // # SAFETY #
        // The router is initialized in static memory when calling run on Server.
        let static_self = unsafe { std::mem::transmute::<&'_ Self, &'static Self>(self) };
        let b = crate::response::Builder::new();
        let route_id = match ctx.metadata.route_id {
            RouteId::Id(id) => id,
            RouteId::Error(e) => {
                return e.respond_with_builder(b, &ctx).build().map(|r| {
                    ctx.state = State::After(Box::new(r));
                    ctx
                });
            }
        };
        let res = if let Some(responder) = static_self.inner.chain.dispatch(route_id, req) {
            responder.await.dyn_respond(b, &ctx)
        } else {
            404.respond_with_builder(b, &ctx)
        }
        .build();

        res.map(|r| {
            ctx.state = State::After(Box::new(r));
            ctx
        })
    }
}

/// Chain of handler tables walked when dispatching a request.
///
/// Implemented by `RouterChainLink` (one link per registered controller) and
/// terminated by `RouterChainEnd`.
#[doc(hidden)]
pub trait RouterChain {
    /// Returns the future producing the response of the handler registered
    /// for `(resolver_id, req.method())`, or `None` if no handler is found.
    fn dispatch(
        &'static self,
        resolver_id: u64,
        req: Request<Body>,
    ) -> Option<BoxFuture<'static, Box<dyn DynResponder + Send>>>;

    /// Registers a plain handler and its guard chain for
    /// `(endpoint_id, method)`.
    fn add_handler(
        &mut self,
        endpoint_id: u64,
        method: Method,
        handler: Box<dyn DynHandler<Body> + Send + Sync>,
        guards: Box<dyn GuardChain>,
    );
}

/// Terminal link of the router chain.
///
/// It holds the plain handlers, together with their guard chains, keyed by
/// endpoint id and HTTP method.
#[doc(hidden)]
pub struct RouterChainEnd {
    handlers: HashMap<
        (u64, Method),
        (Box<dyn DynHandler<Body> + Send + Sync>, Box<dyn GuardChain>),
    >,
}

impl RouterChain for RouterChainEnd {
    /// Looks up the plain handler for `(resolver_id, method)`.
    ///
    /// Returns `None` when no handler is registered for that key. Otherwise
    /// returns the handler's future, preceded by guard validation when the
    /// guard chain is not empty.
    #[inline]
    fn dispatch(&'static self, resolver_id: u64, req: Request<Body>) -> Option<BoxFuture<'static, Box<dyn DynResponder + Send>>> {
        let key = (resolver_id, req.method().clone());
        let (handler, guards) = self.handlers.get(&key)?;

        if guards.is_end() {
            // No guard to run: hand the request straight to the handler.
            return Some(handler.dyn_handle(req));
        }

        // Validate first; a guard rejection is returned as the response.
        let guarded = guards.validate(req).then(move |outcome| async move {
            match outcome {
                Ok(req) => handler.dyn_handle(req).await,
                Err(rejection) => rejection,
            }
        });
        Some(guarded.boxed())
    }

    /// Stores the handler and its guards, replacing any previous entry for
    /// the same `(endpoint_id, method)`.
    #[inline]
    fn add_handler(&mut self, endpoint_id: u64, method: Method, handler: Box<dyn DynHandler<Body> + Send + Sync>, guards: Box<dyn GuardChain>) {
        let key = (endpoint_id, method);
        self.handlers.insert(key, (handler, guards));
    }
}

/// One link of the router chain.
///
/// It pairs a controller with its handlers and guard chains, keyed by
/// endpoint id and HTTP method, followed by the `rest` of the chain.
#[doc(hidden)]
pub struct RouterChainLink<C, Rest>
where
    Rest: RouterChain,
{
    controller: C,
    handlers: HashMap<
        (u64, Method),
        (Box<dyn DynControllerHandler<C, Body> + Send + Sync>, Box<dyn GuardChain>),
    >,
    rest: Rest,
}

impl<C: Sync + Send, Rest: RouterChain + Sync + Send> RouterChain for RouterChainLink<C, Rest> {
    /// Dispatches to this controller's handler for `(resolver_id, method)`.
    ///
    /// Falls through to the rest of the chain when this controller has no
    /// handler for that key.
    #[inline]
    fn dispatch(&'static self, resolver_id: u64, req: Request<Body>) -> Option<BoxFuture<'static, Box<dyn DynResponder + Send>>> {
        let key = (resolver_id, req.method().clone());
        let (handler, guards) = match self.handlers.get(&key) {
            Some(entry) => entry,
            // Not one of this controller's endpoints: let the next link try.
            None => return self.rest.dispatch(resolver_id, req),
        };
        let controller = &self.controller;

        if guards.is_end() {
            return Some(handler.dyn_handle(controller, req));
        }

        // Validate first; a guard rejection is returned as the response.
        let guarded = guards.validate(req).then(move |outcome| async move {
            match outcome {
                Ok(req) => handler.dyn_handle(controller, req).await,
                Err(rejection) => rejection,
            }
        });
        Some(guarded.boxed())
    }

    /// Plain handlers do not belong to a controller, so they are forwarded
    /// toward the end of the chain.
    #[inline]
    fn add_handler(&mut self, endpoint_id: u64, method: Method, handler: Box<dyn DynHandler<Body> + Send + Sync>, guards: Box<dyn GuardChain>) {
        let rest = &mut self.rest;
        rest.add_handler(endpoint_id, method, handler, guards);
    }
}