via 2.0.0-gm.24

An async multi-threaded web framework for people who appreciate simplicity.
Documentation
use http::{StatusCode, header};
use std::borrow::Cow;
use std::fmt::{self, Display, Formatter};
use std::sync::Arc;

use super::{Error, ErrorKind, Errors};
use crate::middleware::{BoxFuture, Middleware};
use crate::response::{Finalize, Json, Response, ResponseBuilder};
use crate::{Next, Request};

/// Recover from errors that occur in downstream middleware.
///
pub struct Rescue<F> {
    recover: Arc<F>,
}

/// Customize how an [`Error`] is converted to a response.
///
pub struct Sanitizer<'a> {
    json: bool,
    error: &'a Error,
    status: Option<StatusCode>,
    message: Option<Cow<'a, str>>,
}

impl<F> Rescue<F>
where
    F: Fn(&mut Sanitizer) + Send + Sync,
{
    pub fn with(recover: F) -> Self {
        Self {
            recover: Arc::new(recover),
        }
    }
}

impl<App, F> Middleware<App> for Rescue<F>
where
    F: Fn(&mut Sanitizer) + Send + Sync + 'static,
{
    fn call(&self, request: Request<App>, next: Next<App>) -> BoxFuture {
        let recover = Arc::clone(&self.recover);
        let future = next.call(request);

        Box::pin(async move {
            future.await.or_else(|error| {
                let mut sanitizer = Sanitizer::new(&error);
                recover(&mut sanitizer);

                let response = Response::build();
                sanitizer.finalize(response).or_else(|residual| {
                    if cfg!(debug_assertions) {
                        eprintln!("warn: a residual error occurred in rescue");
                        eprintln!("{}", residual);
                    }

                    Ok(error.into())
                })
            })
        })
    }
}

impl<'a> Sanitizer<'a> {
    /// Returns a reference to the error source.
    ///
    pub fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        self.error.source()
    }

    /// Provide a custom message to use for the response generated from this
    /// error.
    ///
    pub fn set_message<T>(&mut self, message: T)
    where
        Cow<'a, str>: From<T>,
    {
        self.message = Some(message.into());
    }

    /// Overrides the HTTP status code of the error.
    ///
    pub fn set_status(&mut self, status: StatusCode) {
        self.status = Some(status);
    }

    /// Use the canonical reason of the status code as the error message.
    ///
    pub fn use_canonical_reason(&mut self) {
        self.message = self.status().canonical_reason().map(Cow::Borrowed);
    }

    /// Generate a json response for the error.
    ///
    pub fn use_json(&mut self) {
        self.json = true;
    }
}

impl<'a> Sanitizer<'a> {
    fn new(error: &'a Error) -> Self {
        Self {
            json: false,
            error,
            status: None,
            message: None,
        }
    }

    fn status(&self) -> StatusCode {
        self.status.unwrap_or(self.error.status)
    }

    fn repr_json(self, status: StatusCode) -> Errors<'a> {
        if let Some(message) = self.message {
            let mut errors = Errors::new(status);
            errors.push(message);
            errors
        } else {
            self.error.repr_json(status)
        }
    }
}

impl Display for Sanitizer<'_> {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        Display::fmt(self.error, f)
    }
}

impl Finalize for Sanitizer<'_> {
    fn finalize(self, builder: ResponseBuilder) -> Result<Response, Error> {
        let status = self.status();
        let mut builder = builder.status(status);

        if let ErrorKind::MethodNotAllowed(error) = &self.error.kind {
            builder = builder.header(header::ALLOW, error.allows());
        }

        if self.json {
            Json(&self.repr_json(status)).finalize(builder)
        } else if let Some(message) = self.message {
            builder.text(message.into_owned())
        } else {
            builder.text(self.error.to_string())
        }
    }
}