#![cfg(feature = "tower")]
mod common;
use apollo_errors::tower_http::ErrorLayer;
use common::ErrorWithFields;
use http::{Request, Response, StatusCode, header};
use http_body_util::BodyExt;
use std::{
convert::Infallible,
future::Ready,
task::{Context, Poll},
};
use tower::{BoxError, Layer, Service, ServiceBuilder, ServiceExt};
use tower_test::mock;
/// Minimal in-memory body used for both requests and responses in these tests.
///
/// It holds at most one chunk: the first `poll_frame` hands it out as a single
/// data frame, and every later poll reports end-of-stream.
#[derive(Default)]
struct TestBody {
    // The chunk still waiting to be yielded; `None` once consumed (or when empty).
    data: Option<bytes::Bytes>,
}

impl http_body::Body for TestBody {
    type Data = bytes::Bytes;
    type Error = Infallible;

    fn poll_frame(
        mut self: std::pin::Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        // Never `Pending`: the body is fully buffered, so it either yields the
        // remaining chunk or signals the end of the stream.
        let next_frame = self
            .data
            .take()
            .map(|chunk| Ok(http_body::Frame::data(chunk)));
        Poll::Ready(next_frame)
    }

    /// The stream is finished exactly when no chunk remains to be yielded.
    fn is_end_stream(&self) -> bool {
        self.data.is_none()
    }
}
/// Inner service whose every `call` fails with a boxed clone of `error`.
///
/// `poll_ready` always succeeds, so the failure surfaces from the response future.
#[derive(Clone)]
struct BoxErrorService {
    error: ErrorWithFields,
}

impl Service<Request<TestBody>> for BoxErrorService {
    type Response = Response<TestBody>;
    type Error = BoxError;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, _req: Request<TestBody>) -> Self::Future {
        // Erase the concrete error type up front so the future is exactly `Self::Future`.
        let boxed: Self::Error = Box::new(self.error.clone());
        std::future::ready(Err(boxed))
    }
}
/// Inner service that rejects its first readiness check with `error`.
///
/// The first `poll_ready` returns `Err` and sets `failed`. Every later
/// `poll_ready` reports ready, and `call` always succeeds with
/// `Response::new` wrapping an empty `TestBody`.
#[derive(Clone)]
struct PollReadyErrorService {
    error: ErrorWithFields,
    // Set once the error has been reported; later readiness checks then succeed.
    failed: bool,
}

impl Service<Request<TestBody>> for PollReadyErrorService {
    type Response = Response<TestBody>;
    type Error = BoxError;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Mark as failed unconditionally and branch on the previous value.
        let already_failed = std::mem::replace(&mut self.failed, true);
        if already_failed {
            return Poll::Ready(Ok(()));
        }
        let boxed: Self::Error = Box::new(self.error.clone());
        Poll::Ready(Err(boxed))
    }

    fn call(&mut self, _req: Request<TestBody>) -> Self::Future {
        let empty = Response::new(TestBody::default());
        std::future::ready(Ok(empty))
    }
}
/// A `BoxError` returned from the inner service's response future is rendered by
/// `ErrorLayer` (installed via `ServiceBuilder`) as a `500 Internal Server Error`
/// with a JSON body, for a request that accepts `application/json`.
#[tokio::test]
async fn test_box_error_with_service_builder() {
    let inner = BoxErrorService {
        error: ErrorWithFields::InvalidPort {
            port: 8080,
            config_file: "/etc/config.toml".to_string(),
        },
    };
    let mut svc = ServiceBuilder::new().layer(ErrorLayer::new()).service(inner);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(TestBody::default())
        .unwrap();
    let ready_svc = svc.ready().await.unwrap();
    let (parts, body) = ready_svc.call(request).await.unwrap().into_parts();

    assert_eq!(parts.status, StatusCode::INTERNAL_SERVER_ERROR);
    assert_eq!(
        parts.headers.get(header::CONTENT_TYPE).unwrap(),
        "application/json"
    );

    let bytes = body.collect().await.unwrap().to_bytes();
    let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    insta::assert_json_snapshot!(json, @r#"
{
"config_file": "/etc/config.toml",
"error": "config::invalid_port",
"message": "Invalid port",
"port": 8080
}
"#);
}
/// An error reported by the inner service's `poll_ready` does not make `ready()`
/// on the layered service fail. Instead it comes back as the JSON
/// `500 Internal Server Error` response of the following `call`.
#[tokio::test]
async fn test_box_error_from_poll_ready() {
    let inner = PollReadyErrorService {
        error: ErrorWithFields::MissingConfig {
            expected_path: "/etc/config.yaml".to_string(),
        },
        failed: false,
    };
    let mut svc = ServiceBuilder::new().layer(ErrorLayer::new()).service(inner);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(TestBody::default())
        .unwrap();
    let ready_svc = svc.ready().await.unwrap();
    let (parts, body) = ready_svc.call(request).await.unwrap().into_parts();

    assert_eq!(parts.status, StatusCode::INTERNAL_SERVER_ERROR);
    assert_eq!(
        parts.headers.get(header::CONTENT_TYPE).unwrap(),
        "application/json"
    );

    let bytes = body.collect().await.unwrap().to_bytes();
    let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    insta::assert_json_snapshot!(json, @r#"
{
"error": "config::missing",
"expected_path": "/etc/config.yaml",
"message": "Missing configuration"
}
"#);
}
/// `ErrorLayer` also translates errors produced by a `tower_test` mock service.
///
/// A spawned task plays the inner service. It checks that the `Accept` header
/// reached it and answers with a boxed `ErrorWithFields`. The task is joined as
/// soon as the call returns, so a failing assertion inside it is reported
/// directly rather than showing up later as a confusing mismatch on the response.
#[tokio::test]
async fn test_box_error_with_tower_test_mock() {
    let (mock_service, mut handle) = mock::pair::<Request<TestBody>, Response<TestBody>>();
    let mut service = ErrorLayer::new().layer(mock_service);

    let handle_task = tokio::spawn(async move {
        let (request, send_response) = handle.next_request().await.expect("service not called");
        assert_eq!(
            request.headers().get(header::ACCEPT).unwrap(),
            "application/json"
        );
        let error: BoxError = Box::new(ErrorWithFields::InvalidPort {
            port: 9000,
            config_file: "/app/config.toml".to_string(),
        });
        send_response.send_error(error);
    });

    let req = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(TestBody::default())
        .unwrap();
    let response = service.ready().await.unwrap().call(req).await.unwrap();

    // The call only resolves once the handler has sent (or dropped) its responder,
    // so joining here cannot block. A handler panic is reported before any
    // assertion on the response.
    handle_task.await.unwrap();

    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    assert_eq!(
        response.headers().get(header::CONTENT_TYPE).unwrap(),
        "application/json"
    );
    let body = response.into_body().collect().await.unwrap().to_bytes();
    let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    insta::assert_json_snapshot!(json, @r#"
{
"config_file": "/app/config.toml",
"error": "config::invalid_port",
"message": "Invalid port",
"port": 9000
}
"#);
}
/// The layered service can be cloned, and each clone's error response carries a
/// content type matching its own request's `Accept` header.
#[tokio::test]
async fn test_service_is_clone() {
    let inner = BoxErrorService {
        error: ErrorWithFields::InvalidPort {
            port: 8080,
            config_file: "/etc/config.toml".to_string(),
        },
    };
    let layered = ErrorLayer::new().layer(inner);
    let mut json_svc = layered.clone();
    let mut html_svc = layered;

    let json_req = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(TestBody::default())
        .unwrap();
    let html_req = Request::builder()
        .header(header::ACCEPT, "text/html")
        .body(TestBody::default())
        .unwrap();

    let json_resp = json_svc.ready().await.unwrap().call(json_req).await.unwrap();
    let html_resp = html_svc.ready().await.unwrap().call(html_req).await.unwrap();

    // Check each response in turn: status first, then content type.
    for (resp, expected_type) in [(&json_resp, "application/json"), (&html_resp, "text/html")] {
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            resp.headers().get(header::CONTENT_TYPE).unwrap(),
            expected_type
        );
    }
}