#![cfg(feature = "tower")]
mod common;
use std::convert::Infallible;
use apollo_errors::tower_http::{ErrorLayer, NegotiationConfig, Renderer};
use apollo_errors::{CodeCase, FieldCase, FormatConfig};
use bytes::Bytes;
use common::{
ErrorWithFields, ErrorWithHeaders, ErrorWithRenamedField, ErrorWithStatus, JsonRpcError,
};
use http::{Request, Response, StatusCode, header};
use http_body::Body;
use http_body_util::{Full, combinators::UnsyncBoxBody};
use tower::{BoxError, Layer, Service, ServiceExt, service_fn};
use tower_test::mock;
/// Response body type for the test services: a boxed `UnsyncBoxBody` yielding
/// `Bytes` chunks whose error type is `Infallible`.
type TestBody = UnsyncBoxBody<Bytes, Infallible>;
/// Builds a service that fails every request with a clone of `error`.
///
/// The response type is fixed to `Response<TestBody>`; no response is ever
/// actually produced.
fn error_service<E: Clone + Send + 'static>(
    error: E,
) -> impl Service<Request<()>, Response = Response<TestBody>, Error = E> {
    let reject = move |_req: Request<()>| {
        let cloned = error.clone();
        async move { Err::<Response<TestBody>, E>(cloned) }
    };
    service_fn(reject)
}
/// Drains `body` to completion and returns all of its bytes as a `Vec<u8>`.
///
/// # Panics
///
/// Panics if the body yields an error while it is being collected.
async fn collect_body<B>(body: B) -> Vec<u8>
where
    B: http_body::Body,
    B::Error: std::fmt::Debug,
{
    use http_body_util::BodyExt;

    body.collect()
        .await
        .expect("failed to collect response body")
        .to_bytes()
        .to_vec()
}
/// An `InvalidPort` error requested as `application/json` is rendered as a
/// flat JSON object with a 500 status and a JSON content type.
#[tokio::test]
async fn test_real_error_with_fields_json() {
    let failing = error_service(ErrorWithFields::InvalidPort {
        port: 8080,
        config_file: "/etc/config.toml".to_string(),
    });
    let mut layered = ErrorLayer::new().layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(())
        .unwrap();
    let ready = layered.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
    assert_eq!(content_type, "application/json");

    let raw = collect_body(response.into_body()).await;
    let parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    insta::assert_json_snapshot!(parsed, @r###"
{
"config_file": "/etc/config.toml",
"error": "config::invalid_port",
"message": "Invalid port",
"port": 8080
}
"###);
}
/// A `MissingConfig` error requested as `application/graphql-response+json`
/// is rendered in the GraphQL `errors` array shape with a 500 status.
#[tokio::test]
async fn test_real_error_with_fields_graphql() {
    let failing = error_service(ErrorWithFields::MissingConfig {
        expected_path: "/app/config.yaml".to_string(),
    });
    let mut layered = ErrorLayer::new().layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/graphql-response+json")
        .body(())
        .unwrap();
    let ready = layered.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

    let raw = collect_body(response.into_body()).await;
    let parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    insta::assert_json_snapshot!(parsed, @r#"
{
"errors": [
{
"extensions": {
"code": "CONFIG_MISSING",
"expectedPath": "/app/config.yaml"
},
"message": "Missing configuration"
}
]
}
"#);
}
/// An error carrying a custom HTTP status (`ErrorWithStatus::NotFound`) is
/// rendered with that status (404).
///
/// The mock backend first checks that the `Accept` header reached it
/// unchanged, then fails the request with the boxed error.
#[tokio::test]
async fn test_error_with_custom_http_status() {
    let (mock_service, mut handle) = mock::pair::<Request<()>, Response<TestBody>>();
    let mut service = ErrorLayer::new().layer(mock_service);
    let handle_task = tokio::spawn(async move {
        let (request, send_response) = handle.next_request().await.expect("service not called");
        assert_eq!(
            request.headers().get(header::ACCEPT).unwrap(),
            "application/json"
        );
        let error: BoxError = Box::new(ErrorWithStatus::NotFound);
        send_response.send_error(error);
    });
    let req = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(())
        .unwrap();
    let response = service.ready().await.unwrap().call(req).await.unwrap();
    // Join the mock task before checking the response. If its own assertion
    // panicked, its responder was dropped and the call resolved with some
    // other error. The status check would then report a misleading mismatch
    // instead of the real cause.
    handle_task.await.expect("mock handler task panicked");
    assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
/// An error carrying a 503 status (`ErrorWithStatus::ServiceUnavailable`) is
/// rendered with `503 Service Unavailable`.
///
/// The mock backend first checks that the `Accept` header reached it
/// unchanged, then fails the request with the boxed error.
#[tokio::test]
async fn test_error_with_503_status() {
    let (mock_service, mut handle) = mock::pair::<Request<()>, Response<TestBody>>();
    let mut service = ErrorLayer::new().layer(mock_service);
    let handle_task = tokio::spawn(async move {
        let (request, send_response) = handle.next_request().await.expect("service not called");
        assert_eq!(
            request.headers().get(header::ACCEPT).unwrap(),
            "application/json"
        );
        let error: BoxError = Box::new(ErrorWithStatus::ServiceUnavailable);
        send_response.send_error(error);
    });
    let req = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(())
        .unwrap();
    let response = service.ready().await.unwrap().call(req).await.unwrap();
    // Join the mock task before checking the response, so a panic inside it
    // (e.g. an unexpected Accept header) is reported directly rather than as a
    // confusing status mismatch.
    handle_task.await.expect("mock handler task panicked");
    assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
}
/// Requests accepting `text/html` receive an HTML fragment listing the error
/// code, message and each field.
#[tokio::test]
async fn test_html_rendering() {
    let failing = error_service(ErrorWithFields::InvalidPort {
        port: 9000,
        config_file: "/etc/app.conf".to_string(),
    });
    let mut svc = ErrorLayer::new().layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "text/html")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
    assert_eq!(content_type, "text/html");

    let bytes = collect_body(response.into_body()).await;
    let markup = String::from_utf8(bytes).unwrap();
    insta::assert_snapshot!(markup, @r#"
<div class="error">
<h3 class="error-code">config::invalid_port</h3>
<p class="error-message">Invalid port</p>
<div class="error-extensions">
<div class="error-field"><span class="field-name">port:</span> <span class="field-value">9000</span></div>
<div class="error-field"><span class="field-name">config_file:</span> <span class="field-value">/etc/app.conf</span></div>
</div>
</div>
"#);
}
/// Requests accepting `text/plain` receive a one-line `[code] message` body.
#[tokio::test]
async fn test_text_rendering() {
    let failing = error_service(ErrorWithFields::InvalidPort {
        port: 9000,
        config_file: "/etc/app.conf".to_string(),
    });
    let mut svc = ErrorLayer::new().layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "text/plain")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
    assert_eq!(content_type, "text/plain");

    let bytes = collect_body(response.into_body()).await;
    let rendered = String::from_utf8(bytes).unwrap();
    insta::assert_snapshot!(rendered, @"[config::invalid_port] Invalid port");
}
/// With `application/json` explicitly mapped to the GraphQL renderer, a plain
/// JSON request gets a GraphQL-shaped body while keeping the
/// `application/json` content type.
#[tokio::test]
async fn test_graphql_legacy_content_type() {
    let legacy_mapping =
        NegotiationConfig::new().with_mapping("application/json", Renderer::GraphQL);
    let failing = error_service(ErrorWithFields::InvalidPort {
        port: 8080,
        config_file: "/etc/config.toml".to_string(),
    });
    let mut svc = ErrorLayer::with_config(legacy_mapping).layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
    assert_eq!(content_type, "application/json");

    let raw = collect_body(response.into_body()).await;
    let parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    insta::assert_json_snapshot!(parsed, @r#"
{
"errors": [
{
"extensions": {
"code": "CONFIG_INVALID_PORT",
"configFile": "/etc/config.toml",
"port": 8080
},
"message": "Invalid port"
}
]
}
"#);
}
/// With `application/json;q=0.5, text/html;q=0.9`, the higher-quality
/// `text/html` wins negotiation, and the error's 400 status is preserved.
///
/// The mock backend checks that the `Accept` header reached it unchanged, then
/// fails the request with `ErrorWithStatus::BadRequest`.
#[tokio::test]
async fn test_content_negotiation_quality() {
    let (mock_service, mut handle) = mock::pair::<Request<()>, Response<TestBody>>();
    let mut service = ErrorLayer::new().layer(mock_service);
    let handle_task = tokio::spawn(async move {
        let (request, send_response) = handle.next_request().await.expect("service not called");
        assert_eq!(
            request.headers().get(header::ACCEPT).unwrap(),
            "application/json;q=0.5, text/html;q=0.9"
        );
        let error: BoxError = Box::new(ErrorWithStatus::BadRequest);
        send_response.send_error(error);
    });
    let req = Request::builder()
        .header(header::ACCEPT, "application/json;q=0.5, text/html;q=0.9")
        .body(())
        .unwrap();
    let response = service.ready().await.unwrap().call(req).await.unwrap();
    // Join the mock task before inspecting the response, so a panic inside it
    // is reported directly rather than masked by a content-type or status
    // mismatch on the fallback error response.
    handle_task.await.expect("mock handler task panicked");
    assert_eq!(
        response.headers().get(header::CONTENT_TYPE).unwrap(),
        "text/html"
    );
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
/// A successful inner response passes through the layer with its body's exact
/// `size_hint` intact.
#[tokio::test]
async fn test_size_hint_preserved_for_success_response() {
    let payload = Bytes::from_static(b"Hello, World!");
    let expected_len = u64::try_from(payload.len()).unwrap();

    let inner = service_fn(move |_req: Request<()>| {
        let body = Full::new(payload.clone());
        async move { Ok::<_, Infallible>(Response::new(body)) }
    });
    let mut svc = ErrorLayer::new().layer(inner);

    let request = Request::builder().body(()).unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    let hint = response.body().size_hint();
    assert_eq!(
        hint.exact(),
        Some(expected_len),
        "size_hint should report exact size for Full<Bytes> body"
    );
}
#[tokio::test]
async fn test_size_hint_preserved_for_error_response() {
let mut service = ErrorLayer::new().layer(error_service(ErrorWithFields::InvalidPort {
port: 8080,
config_file: "/etc/config.toml".to_string(),
}));
let req = Request::builder()
.header(header::ACCEPT, "application/json")
.body(())
.unwrap();
let response = service.ready().await.unwrap().call(req).await.unwrap();
let size_hint = response.body().size_hint();
assert!(
size_hint.exact().is_some(),
"size_hint should report exact size for error response body"
);
}
/// Requests accepting `application/json-rpc` receive a JSON-RPC 2.0 error
/// envelope. The default code is -32000, and the fields and diagnostic code
/// go under `data`.
#[tokio::test]
async fn test_jsonrpc_rendering() {
    let failing = error_service(ErrorWithFields::InvalidPort {
        port: 9000,
        config_file: "/etc/app.conf".to_string(),
    });
    let mut svc = ErrorLayer::new().layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json-rpc")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
    assert_eq!(content_type, "application/json-rpc");

    let raw = collect_body(response.into_body()).await;
    let envelope: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    insta::assert_json_snapshot!(envelope, @r#"
{
"error": {
"code": -32000,
"data": {
"config_file": "/etc/app.conf",
"diagnostic_code": "config::invalid_port",
"port": 9000
},
"message": "Invalid port"
},
"id": null,
"jsonrpc": "2.0"
}
"#);
}
/// An error that declares its own JSON-RPC code (`MethodNotFound`, -32601)
/// has that code emitted in the envelope instead of the default.
#[tokio::test]
async fn test_jsonrpc_with_custom_code() {
    let failing = error_service(JsonRpcError::MethodNotFound {
        method: "eth_call".to_string(),
    });
    let mut svc = ErrorLayer::new().layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json-rpc")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
    assert_eq!(content_type, "application/json-rpc");

    let raw = collect_body(response.into_body()).await;
    let envelope: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    insta::assert_json_snapshot!(envelope, @r#"
{
"error": {
"code": -32601,
"data": {
"diagnostic_code": "jsonrpc::method_not_found",
"method": "eth_call"
},
"message": "Method not found: eth_call"
},
"id": null,
"jsonrpc": "2.0"
}
"#);
}
/// Error fields marked as HTTP headers (`retry-after`,
/// `x-ratelimit-remaining`) appear on the response next to the 429 status and
/// the JSON content type.
#[tokio::test]
async fn test_http_headers_set_on_response() {
    let failing = error_service(ErrorWithHeaders::RateLimitExceeded {
        retry_after: 60,
        remaining: 0,
    });
    let mut svc = ErrorLayer::new().layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    let headers = response.headers();
    assert_eq!(headers.get("retry-after").unwrap(), "60");
    assert_eq!(headers.get("x-ratelimit-remaining").unwrap(), "0");
    assert_eq!(
        headers.get(header::CONTENT_TYPE).unwrap(),
        "application/json"
    );
}
/// An `Option` header field set to `None` produces no header on the response.
#[tokio::test]
async fn test_http_headers_option_none_omitted() {
    let failing = error_service(ErrorWithHeaders::QuotaWarning { remaining: None });
    let mut svc = ErrorLayer::new().layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    let remaining_header = response.headers().get("x-ratelimit-remaining");
    assert!(remaining_header.is_none());
}
/// An `Option` header field set to `Some(5)` is emitted as the header value
/// `"5"`.
#[tokio::test]
async fn test_http_headers_option_some_included() {
    let failing = error_service(ErrorWithHeaders::QuotaWarning { remaining: Some(5) });
    let mut svc = ErrorLayer::new().layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    let remaining_header = response.headers().get("x-ratelimit-remaining").unwrap();
    assert_eq!(remaining_header, "5");
}
/// A per-renderer `FormatConfig` for JSON switches fields to camelCase and the
/// error code to SCREAMING_SNAKE_CASE.
#[tokio::test]
async fn test_with_format_override_json() {
    let json_format = FormatConfig {
        field_case: FieldCase::CamelCase,
        code_case: CodeCase::ScreamingSnakeCase,
    };
    let config = NegotiationConfig::new().with_format_config(Renderer::Json, json_format);
    let failing = error_service(ErrorWithFields::InvalidPort {
        port: 8080,
        config_file: "/etc/config.toml".to_string(),
    });
    let mut svc = ErrorLayer::with_config(config).layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/json")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    let raw = collect_body(response.into_body()).await;
    let parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    insta::assert_json_snapshot!(parsed, @r#"
{
"configFile": "/etc/config.toml",
"error": "CONFIG_INVALID_PORT",
"message": "Invalid port",
"port": 8080
}
"#);
}
/// A GraphQL `FormatConfig` override (snake_case fields, default code case)
/// replaces the GraphQL renderer's usual casing. In the snapshot the extension
/// key is `my_renamed_field` and the code is `test::renamed_field`.
#[tokio::test]
async fn test_with_format_override_graphql_wins_over_default() {
    let graphql_format = FormatConfig {
        field_case: FieldCase::SnakeCase,
        code_case: CodeCase::Default,
    };
    let config = NegotiationConfig::new().with_format_config(Renderer::GraphQL, graphql_format);
    let failing = error_service(ErrorWithRenamedField::WithRename {
        my_field: "value".to_string(),
    });
    let mut svc = ErrorLayer::with_config(config).layer(failing);

    let request = Request::builder()
        .header(header::ACCEPT, "application/graphql-response+json")
        .body(())
        .unwrap();
    let ready = svc.ready().await.unwrap();
    let response = ready.call(request).await.unwrap();

    let raw = collect_body(response.into_body()).await;
    let parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    insta::assert_json_snapshot!(parsed, @r#"
{
"errors": [
{
"extensions": {
"code": "test::renamed_field",
"my_renamed_field": "value"
},
"message": "Field was renamed"
}
]
}
"#);
}
/// Two layered services built from different configs do not share format
/// settings. The camelCase-configured service and the default service each
/// render the same error in their own casing.
#[tokio::test]
async fn test_format_config_isolated_across_services() {
    let error = ErrorWithFields::InvalidPort {
        port: 8080,
        config_file: "/etc/config.toml".to_string(),
    };
    let camel_format = FormatConfig {
        field_case: FieldCase::CamelCase,
        code_case: CodeCase::ScreamingSnakeCase,
    };
    let camel_config = NegotiationConfig::new().with_format_config(Renderer::Json, camel_format);

    let mut camel_svc = ErrorLayer::with_config(camel_config).layer(error_service(error.clone()));
    let mut default_svc = ErrorLayer::new().layer(error_service(error));

    let build_request = || {
        Request::builder()
            .header(header::ACCEPT, "application/json")
            .body(())
            .unwrap()
    };

    // Drive the camelCase service first, then the default one.
    let camel_ready = camel_svc.ready().await.unwrap();
    let camel_response = camel_ready.call(build_request()).await.unwrap();
    let camel_bytes = collect_body(camel_response.into_body()).await;

    let default_ready = default_svc.ready().await.unwrap();
    let default_response = default_ready.call(build_request()).await.unwrap();
    let default_bytes = collect_body(default_response.into_body()).await;

    let camel_json: serde_json::Value = serde_json::from_slice(&camel_bytes).unwrap();
    let snake_json: serde_json::Value = serde_json::from_slice(&default_bytes).unwrap();

    assert_eq!(camel_json["configFile"], "/etc/config.toml");
    assert_eq!(camel_json["error"], "CONFIG_INVALID_PORT");
    assert!(
        camel_json.get("config_file").is_none(),
        "camelCase service produced snake_case field"
    );

    assert_eq!(snake_json["config_file"], "/etc/config.toml");
    assert_eq!(snake_json["error"], "config::invalid_port");
    assert!(
        snake_json.get("configFile").is_none(),
        "snake_case service produced camelCase field"
    );
}