use http::StatusCode;
use http::header::ALLOW;
use std::borrow::Cow;
use std::fmt::{self, Display, Formatter};
use super::{Error, ErrorKindRef, Errors};
use crate::middleware::{BoxFuture, Middleware};
use crate::response::{Finalize, Response, ResponseBuilder};
use crate::{Next, Request};
/// Heap-allocated wrapper around the user-supplied recovery callback.
struct Recover<F>(Box<F>);
/// Middleware that intercepts errors returned by the rest of the stack and
/// hands each one to a user-supplied callback through a [`Sanitizer`]
/// before converting it into a response.
pub struct Rescue<F> {
    /// The boxed callback invoked once for every rescued error.
    recover: Recover<F>,
}
/// Mutable view of a rescued error that a [`Rescue`] callback uses to decide
/// what the client sees.
///
/// Any override left unset falls back to the wrapped error's own status or
/// representation when the sanitizer is finalized into a response.
pub struct Sanitizer<'a> {
    /// When `true`, the response body is serialized as JSON instead of text.
    json: bool,
    /// The error being rescued.
    error: &'a Error,
    /// Status override; `None` means the error's own status is used.
    status: Option<StatusCode>,
    /// Message override; `None` means the error's own representation is used.
    message: Option<Cow<'a, str>>,
}
impl<F> Rescue<F>
where
F: Fn(&mut Sanitizer) + Send + Sync,
{
pub fn new(recover: F) -> Self {
Self {
recover: Recover(Box::new(recover)),
}
}
}
impl<App, F> Middleware<App> for Rescue<F>
where
    // `Copy` already implies `Sized` (via `Clone`), so no explicit `Sized`
    // bound is needed.
    F: Fn(&mut Sanitizer) + Copy + Send + Sync + 'static,
{
    /// Runs the rest of the middleware stack and, if it fails, lets the
    /// recover callback sanitize the error before it becomes a response.
    ///
    /// If finalizing the sanitized response fails, the residual error is
    /// printed to stderr in debug builds and the original error is converted
    /// into a response instead.
    fn call(&self, request: Request<App>, next: Next<App>) -> BoxFuture {
        let future = next.call(request);
        // Copy the callback out of its box so the `async move` block owns it
        // rather than borrowing `self`.
        let recover = *self.recover.0;

        Box::pin(async move {
            future.await.or_else(|error| {
                let mut sanitizer = Sanitizer::new(&error);
                recover(&mut sanitizer);

                sanitizer.finalize(Response::build()).or_else(|residual| {
                    if cfg!(debug_assertions) {
                        // One `eprintln!` holds the stderr lock for the whole
                        // message, so both lines stay together when other
                        // threads write to stderr concurrently.
                        eprintln!("warn: a residual error occurred in rescue\n{residual}");
                    }
                    Ok(error.into())
                })
            })
        })
    }
}
impl<'a> Sanitizer<'a> {
    /// Delegates to the rescued error's `source`.
    pub fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let rescued = self.error;
        rescued.source()
    }

    /// Overrides the message sent to the client.
    pub fn set_message<T>(&mut self, message: T)
    where
        Cow<'a, str>: From<T>,
    {
        let text: Cow<'a, str> = message.into();
        self.message = Some(text);
    }

    /// Overrides the status code of the response.
    pub fn set_status(&mut self, status: StatusCode) {
        self.status = Some(status);
    }

    /// Replaces the message with the canonical reason phrase of the
    /// effective status (the override if set, else the error's status).
    ///
    /// If that status has no canonical reason phrase, the message is cleared,
    /// so the error's own representation is used for the body instead.
    pub fn use_canonical_reason(&mut self) {
        let reason = self.status().canonical_reason();
        self.message = reason.map(Cow::Borrowed);
    }

    /// Serializes the response body as JSON instead of plain text.
    pub fn use_json(&mut self) {
        self.json = true;
    }
}
impl<'a> Sanitizer<'a> {
fn new(error: &'a Error) -> Self {
Self {
json: false,
error,
status: None,
message: None,
}
}
fn status(&self) -> StatusCode {
self.status.unwrap_or(self.error.status)
}
}
impl Display for Sanitizer<'_> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Display::fmt(self.error, f)
}
}
impl Finalize for Sanitizer<'_> {
    /// Builds the client-facing response for the rescued error.
    ///
    /// The response gets the effective status. If the error kind is
    /// `AllowMethod` and its `allows()` yields a value, it also gets an
    /// `Allow` header. The body depends on the sanitizer's settings:
    ///
    /// - JSON mode: the override message wrapped in `Errors`, or the error's
    ///   own JSON representation when no message was set.
    /// - Text mode: the override message, or the error's `Display` output.
    fn finalize(self, builder: ResponseBuilder) -> Result<Response, Error> {
        let status = self.status();
        let builder = builder.status(status);

        // Advertise the permitted methods when the error carries them.
        let builder = match self.error.kind() {
            ErrorKindRef::AllowMethod(method_error) => match method_error.allows() {
                Some(allow) => builder.header(ALLOW, allow),
                None => builder,
            },
            _ => builder,
        };

        if !self.json {
            return match self.message {
                Some(message) => builder.text(message),
                None => builder.text(self.error.to_string()),
            };
        }

        let payload = match self.message.as_deref() {
            Some(message) => Errors::new(status, message),
            None => self.error.repr_json(status),
        };
        builder.json(&payload)
    }
}