dhttp 0.2.0

The True Internet
Documentation
use std::{
    collections::HashMap,
    sync::{Arc, RwLock},
    task::{Context, Poll},
};

use futures::future::BoxFuture;
use http::{Method, StatusCode};
use snafu::{Report, ResultExt, Snafu};

use crate::endpoint::server::{
    BoxService, BoxServiceFuture, IntoBoxService, Request, ResolveError, Response, Serve,
    UnresolvedRequest, box_service,
};

/// Fallback used when nothing else handles a request: answers `404 Not Found`.
///
/// The request is not inspected. Whatever `set_status` returns is discarded,
/// so this handler never reports a problem back to its caller.
#[tracing::instrument(skip_all)]
pub async fn default_fallback(_request: &mut Request, response: &mut Response) {
    tracing::debug!("call default fallback service (404 Not Found)");
    let _ = response.set_status(StatusCode::NOT_FOUND);
}

/// Replaceable fallback service behind a shared, lock-protected slot.
///
/// `Clone` only clones the `Arc`, so every clone (including the ones handed to
/// method routers by `RouterInner::on`) observes a later [`Fallback::set`].
#[derive(Debug, Clone)]
struct Fallback(Arc<RwLock<BoxService>>);

impl Fallback {
    /// Stores `service` in a fresh shared slot.
    pub fn new(service: BoxService) -> Self {
        let slot = RwLock::new(service);
        Fallback(Arc::new(slot))
    }

    /// Replaces the stored service; visible through every clone of this handle.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn set(&mut self, service: BoxService) {
        let mut slot = self.0.write().expect("lock is not poisoned");
        *slot = service;
    }
}

impl Serve for Fallback {
    type Future<'s> = BoxServiceFuture<'s>;

    /// Forwards to whichever service currently occupies the fallback slot.
    ///
    /// The read guard is a local of this function, so the lock is released
    /// once the future has been created rather than while it runs.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> Self::Future<'s> {
        tracing::debug!("call fallback service");
        let current = self.0.read().expect("lock is not poisoned");
        current.serve(request, response)
    }
}

/// Routing table behind [`Service`]: route patterns mapped to services, plus
/// the fallback used when no pattern matches the request path.
#[derive(Debug, Clone)]
struct RouterInner {
    // Path-pattern router; values may be plain services (registered via
    // `route`) or boxed `MethodRouter`s (created by `on`).
    router: matchit::Router<BoxService>,
    // Shared, replaceable fallback. `on` also hands clones of it to the method
    // routers it creates for new paths.
    fallback: Fallback,
}

impl Default for RouterInner {
    fn default() -> Self {
        Self {
            router: Default::default(),
            fallback: Fallback::new(box_service(default_fallback)),
        }
    }
}

impl RouterInner {
    /// Registers `service` for every HTTP method on the route pattern `path`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `matchit` router rejects the pattern (for
    /// example because it is already registered or conflicts with another route).
    fn route(&mut self, path: &str, service: impl IntoBoxService) {
        self.router
            .insert(path, service.into_box_service())
            .expect("failed to register route");
    }

    /// Registers `service` for requests with `method` on the route pattern `path`.
    ///
    /// Registrations for the same pattern are merged into one [`MethodRouter`]:
    /// - no route for `path` yet: a new `MethodRouter` is created whose
    ///   fallback is the router-wide [`Fallback`];
    /// - an existing `MethodRouter`: the service for `method` is added or replaced;
    /// - a plain service registered via `route`: it becomes the fallback of a
    ///   new `MethodRouter`, so it still handles every other method.
    ///
    /// # Panics
    ///
    /// Panics if the pattern cannot be inserted into the `matchit` router.
    pub fn on(&mut self, method: Method, path: &str, service: impl IntoBoxService) {
        let service = service.into_box_service();

        // Look the route up by its exact *pattern*. `at_mut(path)` would match
        // `path` as a request path instead, so registering "/users/me" after
        // "/users/{id}" would overwrite the `{id}` route's handler rather than
        // add a new static route.
        let merged = match self.router.remove(path) {
            Some(mut existing) => {
                if let Some(router) = existing.downcast_mut::<MethodRouter<BoxService>>() {
                    router.set(method, service);
                    existing
                } else {
                    // A service registered through `route` keeps serving the
                    // methods that have no dedicated handler.
                    let mut router = MethodRouter::new(existing);
                    router.set(method, service);
                    router.into_box_service()
                }
            }
            None => {
                let mut router = MethodRouter::new(self.fallback.clone().into_box_service());
                router.set(method, service);
                router.into_box_service()
            }
        };

        // Insert the already-boxed value directly rather than through `route`.
        // NOTE(review): if `into_box_service` on a `BoxService` adds another
        // layer, that layer would hide the `MethodRouter` from `downcast_mut`
        // on the next registration.
        self.router
            .insert(path, merged)
            .expect("failed to register route");
    }
}

impl Serve for RouterInner {
    type Future<'s> = BoxServiceFuture<'s>;

    /// Routes by the path component of the request URI (the query is ignored).
    ///
    /// Falls back to the shared fallback service when the URI carries no path
    /// or when no registered pattern matches it.
    fn serve<'s>(
        &self,
        request: &'s mut Request,
        response: &'s mut Response,
    ) -> BoxServiceFuture<'s> {
        let path = match request.uri().path_and_query() {
            Some(path_and_query) => path_and_query.path(),
            None => {
                tracing::debug!("missing path in request URI, call fallback service");
                return self.fallback.serve(request, response);
            }
        };

        if let Ok(matched) = self.router.at(path) {
            tracing::debug!(path, "path route found, call matched service");
            return matched.value.serve(request, response);
        }

        tracing::debug!(path, "path route: not found, call fallback service");
        self.fallback.serve(request, response)
    }
}

/// HTTP router: dispatches requests by path pattern and method.
///
/// The routing table lives behind an `Arc`, so cloning a `Service` is cheap.
/// Builder methods obtain a mutable table via `Arc::make_mut`, which clones it
/// first if other handles still share it.
#[derive(Debug, Default, Clone)]
pub struct Service {
    inner: Arc<RouterInner>,
}

impl Service {
    /// Creates a router with no routes whose fallback answers `404 Not Found`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Shared access to the routing table.
    fn inner_ref(&self) -> &RouterInner {
        &self.inner
    }

    /// Exclusive access to the routing table. If other `Service` handles still
    /// share it, the table is cloned first (`Arc::make_mut`).
    ///
    /// NOTE(review): a cloned `RouterInner` still shares its `Fallback` slot
    /// (an `Arc<RwLock<_>>`) with the original. Calling [`Service::fallback`]
    /// on one clone therefore also changes the fallback seen by the others.
    fn inner_mut(&mut self) -> &mut RouterInner {
        Arc::make_mut(&mut self.inner)
    }

    /// Registers `service` for every HTTP method on the route pattern `path`.
    ///
    /// # Panics
    ///
    /// Panics if the route cannot be registered (the underlying `matchit`
    /// insert fails, e.g. for a conflicting pattern).
    pub fn route(mut self, path: &str, service: impl IntoBoxService) -> Self {
        // `RouterInner::route` performs the boxing; pass the service through as-is.
        self.inner_mut().route(path, service);
        self
    }

    /// Registers `service` for `method` requests on the route pattern `path`.
    pub fn on(mut self, method: Method, path: &str, service: impl IntoBoxService) -> Self {
        // `RouterInner::on` performs the boxing; pass the service through as-is.
        self.inner_mut().on(method, path, service);
        self
    }

    /// Replaces the service used when no route matches the request path.
    ///
    /// Method routers created for paths registered only through method
    /// shortcuts also defer to it for methods without a handler.
    pub fn fallback(mut self, service: impl IntoBoxService) -> Self {
        self.inner_mut().fallback.set(service.into_box_service());
        self
    }

    /// Registers `service` for `OPTIONS` requests on `path`.
    pub fn options(self, path: &str, service: impl IntoBoxService) -> Self {
        self.on(Method::OPTIONS, path, service)
    }
    /// Registers `service` for `GET` requests on `path`.
    pub fn get(self, path: &str, service: impl IntoBoxService) -> Self {
        self.on(Method::GET, path, service)
    }
    /// Registers `service` for `POST` requests on `path`.
    pub fn post(self, path: &str, service: impl IntoBoxService) -> Self {
        self.on(Method::POST, path, service)
    }
    /// Registers `service` for `PUT` requests on `path`.
    pub fn put(self, path: &str, service: impl IntoBoxService) -> Self {
        self.on(Method::PUT, path, service)
    }
    /// Registers `service` for `DELETE` requests on `path`.
    pub fn delete(self, path: &str, service: impl IntoBoxService) -> Self {
        self.on(Method::DELETE, path, service)
    }
    /// Registers `service` for `HEAD` requests on `path`.
    pub fn head(self, path: &str, service: impl IntoBoxService) -> Self {
        self.on(Method::HEAD, path, service)
    }
    /// Registers `service` for `TRACE` requests on `path`.
    pub fn trace(self, path: &str, service: impl IntoBoxService) -> Self {
        self.on(Method::TRACE, path, service)
    }
    /// Registers `service` for `CONNECT` requests on `path`.
    pub fn connect(self, path: &str, service: impl IntoBoxService) -> Self {
        self.on(Method::CONNECT, path, service)
    }
    /// Registers `service` for `PATCH` requests on `path`.
    pub fn patch(self, path: &str, service: impl IntoBoxService) -> Self {
        self.on(Method::PATCH, path, service)
    }

    /// Routes an already-resolved request through the routing table.
    pub fn serve<'s>(
        &self,
        request: &'s mut Request,
        response: &'s mut Response,
    ) -> BoxServiceFuture<'s> {
        self.inner_ref().serve(request, response)
    }

    /// Resolves a raw server request, routes it, then finishes the response.
    ///
    /// # Errors
    ///
    /// Returns [`HandleError::Resolve`] if the request cannot be resolved.
    /// Failures while finishing the response are only logged at debug level.
    #[tracing::instrument(skip(self, req), fields(method = tracing::field::Empty, uri = tracing::field::Empty))]
    pub async fn handle(&self, req: UnresolvedRequest) -> Result<(), HandleError> {
        let (mut request, mut response) = crate::endpoint::server::resolve(req)
            .await
            .context(handle_error::ResolveSnafu)?;

        tracing::Span::current()
            .record("method", request.method().as_str())
            .record("uri", request.uri().to_string());

        self.serve(&mut request, &mut response).await;

        // Drop response in place to avoid spawning another tokio task
        // FIXME: remove this when async drop is stabilized (https://github.com/rust-lang/rust/issues/126482)
        if let Some(drop_future) = response.drop()
            && let Err(error) = drop_future.await
        {
            let report = Report::from_error(&error);
            tracing::debug!(error = %report, "failed to finish response after service handler");
        }

        Ok(())
    }
}

impl Serve for Service {
    type Future<'s> = BoxServiceFuture<'s>;

    /// Delegates to the routing table, exactly like the inherent `Service::serve`.
    fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> Self::Future<'s> {
        self.inner_ref().serve(request, response)
    }
}

impl tower_service::Service<UnresolvedRequest> for Service {
    type Response = ();
    type Error = HandleError;
    type Future = BoxFuture<'static, Result<(), HandleError>>;

    /// Always ready: this implementation applies no backpressure.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Runs [`Service::handle`] on an owned clone of the router so the
    /// returned future is `'static`.
    fn call(&mut self, req: UnresolvedRequest) -> Self::Future {
        let router = self.clone();
        let fut = async move { router.handle(req).await };
        Box::pin(fut)
    }
}

/// Error returned by [`Service::handle`] and by the `tower_service::Service` impl.
#[derive(Debug, Snafu)]
// Generates the context selectors in a `handle_error` module
// (used as `handle_error::ResolveSnafu` in `Service::handle`).
#[snafu(module)]
pub enum HandleError {
    // The raw request could not be resolved into a request/response pair.
    #[snafu(display("failed to resolve server request"))]
    Resolve { source: ResolveError },
}

/// Dispatches a request to a service chosen by its HTTP method.
///
/// Each of the nine standard methods has its own field; any other (extension)
/// method is kept in `extensions`. Requests whose method has no registered
/// service are sent to `fallback`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MethodRouter<S> {
    // most used methods are stored separately for faster access
    options: Option<S>,
    get: Option<S>,
    post: Option<S>,
    put: Option<S>,
    delete: Option<S>,
    head: Option<S>,
    trace: Option<S>,
    connect: Option<S>,
    patch: Option<S>,
    // other
    extensions: HashMap<Method, S>,
    // fallback service when no method match
    fallback: S,
}

impl<S> MethodRouter<S> {
    pub fn new(fallback: S) -> Self {
        Self {
            options: None,
            get: None,
            post: None,
            put: None,
            delete: None,
            head: None,
            trace: None,
            connect: None,
            patch: None,
            extensions: HashMap::new(),
            fallback,
        }
    }

    pub fn service(&self, method: Method) -> Option<&S> {
        match method {
            Method::OPTIONS => self.options.as_ref(),
            Method::GET => self.get.as_ref(),
            Method::POST => self.post.as_ref(),
            Method::PUT => self.put.as_ref(),
            Method::DELETE => self.delete.as_ref(),
            Method::HEAD => self.head.as_ref(),
            Method::TRACE => self.trace.as_ref(),
            Method::CONNECT => self.connect.as_ref(),
            Method::PATCH => self.patch.as_ref(),
            _ => self.extensions.get(&method),
        }
    }

    pub fn service_mut(&mut self, method: Method) -> Option<&mut S> {
        match method {
            Method::OPTIONS => self.options.as_mut(),
            Method::GET => self.get.as_mut(),
            Method::POST => self.post.as_mut(),
            Method::PUT => self.put.as_mut(),
            Method::DELETE => self.delete.as_mut(),
            Method::HEAD => self.head.as_mut(),
            Method::TRACE => self.trace.as_mut(),
            Method::CONNECT => self.connect.as_mut(),
            Method::PATCH => self.patch.as_mut(),
            _ => self.extensions.get_mut(&method),
        }
    }

    pub fn set(&mut self, method: Method, service: S) {
        match method {
            Method::OPTIONS => self.options = Some(service),
            Method::GET => self.get = Some(service),
            Method::POST => self.post = Some(service),
            Method::PUT => self.put = Some(service),
            Method::DELETE => self.delete = Some(service),
            Method::HEAD => self.head = Some(service),
            Method::TRACE => self.trace = Some(service),
            Method::CONNECT => self.connect = Some(service),
            Method::PATCH => self.patch = Some(service),
            _ => _ = self.extensions.insert(method, service),
        }
    }

    pub fn set_fallback(&mut self, service: S) {
        self.fallback = service;
    }
}

impl<S> Serve for MethodRouter<S>
where
    S: Clone + for<'s> Serve<Future<'s>: Send> + Send + 'static,
{
    type Future<'s> = BoxServiceFuture<'s>;

    /// Serves the request with the handler registered for its method, or with
    /// the fallback when none is registered.
    ///
    /// The returned future's type only carries `'s`, not the lifetime of
    /// `&self`. The chosen handler is therefore cloned and moved into the future.
    fn serve<'s>(
        &self,
        request: &'s mut super::Request,
        response: &'s mut super::Response,
    ) -> Self::Future<'s> {
        let chosen = self
            .service(request.method().clone())
            .unwrap_or(&self.fallback)
            .clone();
        Box::pin(async move { chosen.serve(request, response).await })
    }
}