use super::{NegotiationConfig, body::ErrorBody};
use crate::ext::HeapErrorExt;
use bytes::Bytes;
use http::{Response, header};
use http_body::Body;
use pin_project_lite::pin_project;
use std::{
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use tower::BoxError;
pin_project! {
/// Future that wraps an inner response future and always resolves to `Ok`.
///
/// Its `Future` impl passes successful responses through with the body
/// wrapped in `ErrorBody`, and turns errors into rendered error responses
/// using `config` and `accept`.
pub struct ErrorServiceFuture<F, ResBody> {
// The wrapped future; structurally pinned so it can be polled in place.
#[pin]
pub(super) inner: F,
// Negotiation settings used to pick a renderer and its format options.
pub(super) config: NegotiationConfig,
// Optional string handed to content negotiation
// (presumably the request's `Accept` header value; verify against the caller).
pub(super) accept: Option<String>,
// When `Some`, `poll` renders this error immediately and never polls `inner`.
// NOTE(review): the name suggests it carries a failure from the service's
// `poll_ready` — confirm at the construction site.
pub(super) poll_ready_error: Option<BoxError>,
// Zero-sized marker tying the future to its response body type.
// NOTE(review): `#[pin]` on a `PhantomData` looks unnecessary and likely makes
// `Unpin` depend on `ResBody: Unpin` — confirm whether this is intended.
#[pin]
pub(super) _marker: PhantomData<ResBody>,
}
}
impl<F, ResBody, E> Future for ErrorServiceFuture<F, ResBody>
where
F: Future<Output = Result<Response<ResBody>, E>>,
E: Into<BoxError>,
ResBody: Body,
ResBody::Error: Into<BoxError>,
{
type Output = Result<Response<ErrorBody<ResBody>>, std::convert::Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
if let Some(error) = this.poll_ready_error.take() {
return Poll::Ready(Ok(Self::render_error_response(
&error,
this.config,
this.accept.as_deref(),
)));
}
match this.inner.poll(cx) {
Poll::Ready(Ok(response)) => {
let (parts, body) = response.into_parts();
let response = Response::from_parts(parts, ErrorBody::passthrough(body));
Poll::Ready(Ok(response))
}
Poll::Ready(Err(error)) => {
let boxed_error = error.into();
Poll::Ready(Ok(Self::render_error_response(
&boxed_error,
this.config,
this.accept.as_deref(),
)))
}
Poll::Pending => Poll::Pending,
}
}
}
impl<F, ResBody> ErrorServiceFuture<F, ResBody> {
fn render_error_response<B>(
boxed_error: &BoxError,
config: &NegotiationConfig,
accept: Option<&str>,
) -> Response<ErrorBody<B>> {
let (renderer, content_type) = config.negotiate_with_content_type(accept);
let format_config = config.format_for(renderer);
let body_bytes = match renderer {
super::Renderer::Json => {
match boxed_error.to_json(format_config) {
Ok(json) => serde_json::to_vec(&json).unwrap_or_else(|_| {
br#"{"error":"INTERNAL_SERVER_ERROR","message":"Failed to serialize error"}"#.to_vec()
}),
Err(_) => {
br#"{"error":"INTERNAL_SERVER_ERROR","message":"Failed to serialize error response"}"#.to_vec()
}
}
}
super::Renderer::Html => {
let html = boxed_error.to_html(format_config);
html.into_bytes()
}
super::Renderer::GraphQL => {
match boxed_error.to_graphql(format_config) {
Ok(graphql_error) => {
let graphql = serde_json::json!({ "errors": [graphql_error] });
serde_json::to_vec(&graphql).unwrap_or_else(|_| {
br#"{"errors":[{"message":"Internal server error"}]}"#.to_vec()
})
}
Err(_) => {
br#"{"errors":[{"message":"Failed to serialize error response"}]}"#.to_vec()
}
}
}
super::Renderer::Text => {
let text = boxed_error.to_text(format_config);
text.into_bytes()
}
super::Renderer::JsonRpc => {
const INTERNAL_ERROR: &[u8] =
br#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"Internal error"},"id":null}"#;
match boxed_error.to_jsonrpc(format_config) {
Ok(jsonrpc_error) => {
let jsonrpc = serde_json::json!({
"jsonrpc": "2.0",
"error": jsonrpc_error,
"id": null
});
serde_json::to_vec(&jsonrpc).unwrap_or_else(|_| INTERNAL_ERROR.to_vec())
}
Err(_) => INTERNAL_ERROR.to_vec(),
}
}
};
let status_code = boxed_error.http_status();
let error_headers = boxed_error.http_headers();
let mut builder = Response::builder()
.status(status_code)
.header(header::CONTENT_TYPE, content_type);
for (name, value) in error_headers {
builder = builder.header(name, value);
}
builder
.body(ErrorBody::error(Bytes::from(body_bytes)))
.unwrap_or_else(|_e| {
Response::builder()
.status(500)
.header(header::CONTENT_TYPE, "text/plain")
.body(ErrorBody::error(Bytes::from_static(
b"Internal Server Error",
)))
.expect("static error response should always build successfully")
})
}
}