use bytes::Bytes;
use futures::stream::Stream;
use hyper::{HeaderMap, StatusCode};
use serde::Serialize;
use std::pin::Pin;
/// Returns a generic, client-safe reason phrase for an error status.
///
/// Well-known statuses map to their standard phrase. Any other status falls
/// back to `"Client Error"` (4xx), `"Server Error"` (5xx), or `"Error"`.
fn safe_error_message(status: StatusCode) -> &'static str {
    // Explicitly named statuses and their canonical phrases.
    const KNOWN: [(u16, &str); 17] = [
        (400, "Bad Request"),
        (401, "Unauthorized"),
        (403, "Forbidden"),
        (404, "Not Found"),
        (405, "Method Not Allowed"),
        (406, "Not Acceptable"),
        (408, "Request Timeout"),
        (409, "Conflict"),
        (410, "Gone"),
        (413, "Payload Too Large"),
        (415, "Unsupported Media Type"),
        (422, "Unprocessable Entity"),
        (429, "Too Many Requests"),
        (500, "Internal Server Error"),
        (502, "Bad Gateway"),
        (503, "Service Unavailable"),
        (504, "Gateway Timeout"),
    ];
    let code = status.as_u16();
    KNOWN
        .iter()
        .find_map(|&(known, phrase)| (known == code).then_some(phrase))
        .unwrap_or_else(|| {
            if status.is_client_error() {
                "Client Error"
            } else if status.is_server_error() {
                "Server Error"
            } else {
                "Error"
            }
        })
}
fn safe_client_error_detail(error: &crate::Error) -> Option<String> {
use crate::Error;
match error {
Error::Validation(msg) => Some(msg.clone()),
Error::Http(msg) => Some(msg.clone()),
Error::Serialization(msg) => Some(msg.clone()),
Error::ParseError(_) => Some("Invalid request format".to_string()),
Error::BodyAlreadyConsumed => Some("Request body has already been consumed".to_string()),
Error::MissingContentType => Some("Missing Content-Type header".to_string()),
Error::InvalidPage(msg) => Some(format!("Invalid page: {}", msg)),
Error::InvalidCursor(_) => Some("Invalid cursor value".to_string()),
Error::InvalidLimit(msg) => Some(format!("Invalid limit: {}", msg)),
Error::MissingParameter(name) => Some(format!("Missing parameter: {}", name)),
Error::Conflict(msg) => Some(msg.clone()),
Error::ParamValidation(ctx) => {
Some(format!("{} parameter extraction failed", ctx.param_type))
}
_ => None,
}
}
/// Builder for JSON error responses that avoid leaking internal details.
///
/// Which of the optional fields reach the response body depends on the status
/// class and on `debug_mode`; see `build`.
pub struct SafeErrorResponse {
/// HTTP status of the rendered response.
status: StatusCode,
/// Optional detail text: rendered for 4xx statuses, and for 5xx statuses only in debug mode.
detail: Option<String>,
/// Optional diagnostic text: rendered only when `debug_mode` is true.
debug_info: Option<String>,
/// Whether debug-only fields may be included in the body.
debug_mode: bool,
}
impl SafeErrorResponse {
    /// Starts a builder for `status` with no detail, no debug info, and debug mode off.
    #[must_use]
    pub fn new(status: StatusCode) -> Self {
        Self {
            status,
            detail: None,
            debug_info: None,
            debug_mode: false,
        }
    }

    /// Attaches a detail message.
    ///
    /// It is rendered for 4xx statuses always, and for 5xx statuses only in debug mode.
    #[must_use]
    pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
        self.detail = Some(detail.into());
        self
    }

    /// Attaches diagnostic text that is rendered only in debug mode.
    #[must_use]
    pub fn with_debug_info(mut self, info: impl Into<String>) -> Self {
        self.debug_info = Some(info.into());
        self
    }

    /// Enables or disables rendering of debug-only fields.
    #[must_use]
    pub fn with_debug_mode(mut self, debug: bool) -> Self {
        self.debug_mode = debug;
        self
    }

    /// Renders the error as a JSON [`Response`] carrying the configured status.
    ///
    /// The body always contains `"error"` with a generic reason phrase. It may also contain:
    /// - `"detail"`: for 4xx statuses whenever a detail was set; for 5xx statuses only
    ///   when debug mode is on.
    /// - `"debug"`: only when debug mode is on and debug info was set.
    ///
    /// If JSON encoding fails, an empty `500 Internal Server Error` response is returned.
    #[must_use]
    pub fn build(self) -> Response {
        // `self` is consumed, so move the strings out instead of cloning them.
        let Self {
            status,
            mut detail,
            debug_info,
            debug_mode,
        } = self;
        let mut body = serde_json::json!({
            "error": safe_error_message(status),
        });
        // Detail attached to client errors is treated as safe to show the caller.
        if status.is_client_error()
            && let Some(text) = detail.take()
        {
            body["detail"] = serde_json::Value::String(text);
        }
        if debug_mode {
            if let Some(info) = debug_info {
                body["debug"] = serde_json::Value::String(info);
            }
            // Server-error detail may describe internals, so it is only exposed while debugging.
            if status.is_server_error()
                && let Some(text) = detail
            {
                body["detail"] = serde_json::Value::String(text);
            }
        }
        Response::new(status)
            .with_json(&body)
            .unwrap_or_else(|_| Response::internal_server_error())
    }
}
/// Shortens `input` for log output.
///
/// Strings of at most `max_length` bytes are returned unchanged. Longer strings
/// are cut at the last UTF-8 character boundary at or before byte `max_length`,
/// so a multi-byte character is never split. A marker reporting the original
/// length in bytes is then appended.
pub fn truncate_for_log(input: &str, max_length: usize) -> String {
    let total = input.len();
    if total <= max_length {
        return input.to_string();
    }
    // Here `max_length < total`, so every probed index is in range. A UTF-8
    // character spans at most 4 bytes, so at most 3 steps back are needed.
    // Index 0 is always a boundary.
    let cut = (0..=max_length)
        .rev()
        .find(|&idx| input.is_char_boundary(idx))
        .unwrap_or(0);
    let kept = &input[..cut];
    format!("{}...[truncated, {} total bytes]", kept, total)
}
/// An HTTP response whose body is held in memory as `Bytes`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Response {
/// Status code of the response.
pub status: StatusCode,
/// Response headers.
pub headers: HeaderMap,
/// Complete response body.
pub body: Bytes,
/// Flag read via `should_stop_chain` and set via `with_stop_chain`.
/// NOTE(review): presumably tells a handler/middleware chain to stop processing — confirm with callers.
stop_chain: bool,
}
/// An HTTP response whose body is produced by a stream `S` rather than held in memory.
pub struct StreamingResponse<S> {
/// Status code of the response.
pub status: StatusCode,
/// Response headers.
pub headers: HeaderMap,
/// Source of body chunks. The builder methods in this file require it to yield
/// `Result<Bytes, Box<dyn std::error::Error + Send + Sync>>` and be `Send + 'static`.
pub stream: S,
}
/// Boxed, pinned, `Send` stream of body chunks with the type-erased error type used by
/// `StreamingResponse`'s builder bounds.
/// Presumably meant as the `S` of `StreamingResponse` when the concrete stream type
/// should be hidden — confirm against callers.
pub type StreamBody =
Pin<Box<dyn Stream<Item = Result<Bytes, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
impl Response {
pub fn new(status: StatusCode) -> Self {
Self {
status,
headers: HeaderMap::new(),
body: Bytes::new(),
stop_chain: false,
}
}
pub fn ok() -> Self {
Self::new(StatusCode::OK)
}
pub fn created() -> Self {
Self::new(StatusCode::CREATED)
}
pub fn no_content() -> Self {
Self::new(StatusCode::NO_CONTENT)
}
pub fn bad_request() -> Self {
Self::new(StatusCode::BAD_REQUEST)
}
pub fn unauthorized() -> Self {
Self::new(StatusCode::UNAUTHORIZED)
}
pub fn forbidden() -> Self {
Self::new(StatusCode::FORBIDDEN)
}
pub fn not_found() -> Self {
Self::new(StatusCode::NOT_FOUND)
}
pub fn internal_server_error() -> Self {
Self::new(StatusCode::INTERNAL_SERVER_ERROR)
}
pub fn gone() -> Self {
Self::new(StatusCode::GONE)
}
pub fn permanent_redirect(location: impl AsRef<str>) -> Self {
Self::new(StatusCode::MOVED_PERMANENTLY).with_location(location.as_ref())
}
pub fn temporary_redirect(location: impl AsRef<str>) -> Self {
Self::new(StatusCode::FOUND).with_location(location.as_ref())
}
pub fn temporary_redirect_preserve_method(location: impl AsRef<str>) -> Self {
Self::new(StatusCode::TEMPORARY_REDIRECT).with_location(location.as_ref())
}
pub fn with_body(mut self, body: impl Into<Bytes>) -> Self {
self.body = body.into();
self
}
pub fn try_with_header(mut self, name: &str, value: &str) -> crate::Result<Self> {
let header_name = hyper::header::HeaderName::from_bytes(name.as_bytes())
.map_err(|e| crate::Error::Http(format!("Invalid header name '{}': {}", name, e)))?;
let header_value = hyper::header::HeaderValue::from_str(value).map_err(|e| {
crate::Error::Http(format!("Invalid header value for '{}': {}", name, e))
})?;
self.headers.insert(header_name, header_value);
Ok(self)
}
pub fn with_header_if_absent(mut self, name: &str, value: &str) -> Self {
if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes())
&& !self.headers.contains_key(&header_name)
&& let Ok(header_value) = hyper::header::HeaderValue::from_str(value)
{
self.headers.insert(header_name, header_value);
}
self
}
pub fn try_with_header_if_absent(mut self, name: &str, value: &str) -> crate::Result<Self> {
let header_name = hyper::header::HeaderName::from_bytes(name.as_bytes())
.map_err(|e| crate::Error::Http(format!("Invalid header name '{}': {}", name, e)))?;
if !self.headers.contains_key(&header_name) {
let header_value = hyper::header::HeaderValue::from_str(value).map_err(|e| {
crate::Error::Http(format!("Invalid header value for '{}': {}", name, e))
})?;
self.headers.insert(header_name, header_value);
}
Ok(self)
}
pub fn with_header(mut self, name: &str, value: &str) -> Self {
if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes())
&& let Ok(header_value) = hyper::header::HeaderValue::from_str(value)
{
self.headers.insert(header_name, header_value);
}
self
}
pub fn append_header(mut self, name: &str, value: &str) -> Self {
if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes())
&& let Ok(header_value) = hyper::header::HeaderValue::from_str(value)
{
self.headers.append(header_name, header_value);
}
self
}
pub fn with_location(mut self, location: &str) -> Self {
if let Ok(value) = hyper::header::HeaderValue::from_str(location) {
self.headers.insert(hyper::header::LOCATION, value);
}
self
}
pub fn with_json<T: Serialize>(mut self, data: &T) -> crate::Result<Self> {
use crate::Error;
let json = serde_json::to_vec(data).map_err(|e| Error::Serialization(e.to_string()))?;
self.body = Bytes::from(json);
self.headers.insert(
hyper::header::CONTENT_TYPE,
hyper::header::HeaderValue::from_static("application/json"),
);
Ok(self)
}
pub fn with_typed_header(
mut self,
key: hyper::header::HeaderName,
value: hyper::header::HeaderValue,
) -> Self {
self.headers.insert(key, value);
self
}
pub fn should_stop_chain(&self) -> bool {
self.stop_chain
}
pub fn with_stop_chain(mut self, stop: bool) -> Self {
self.stop_chain = stop;
self
}
}
impl From<crate::Error> for Response {
fn from(error: crate::Error) -> Self {
let status =
StatusCode::from_u16(error.status_code()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
tracing::error!(
status = status.as_u16(),
error = %error,
"Request error"
);
let mut response = SafeErrorResponse::new(status);
if status.is_client_error()
&& let Some(detail) = safe_client_error_detail(&error)
{
response = response.with_detail(detail);
}
response.build()
}
}
impl<S> StreamingResponse<S>
where
    S: Stream<Item = Result<Bytes, Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
{
    /// Wraps `stream` in a `200 OK` response with no headers.
    pub fn new(stream: S) -> Self {
        Self::with_status(stream, StatusCode::OK)
    }

    /// Wraps `stream` in a response with the given `status` and no headers.
    pub fn with_status(stream: S, status: StatusCode) -> Self {
        let headers = HeaderMap::new();
        Self {
            status,
            headers,
            stream,
        }
    }

    /// Replaces the status code.
    pub fn status(self, status: StatusCode) -> Self {
        Self { status, ..self }
    }

    /// Sets header `key` to `value`, replacing any existing values for that name.
    pub fn header(
        mut self,
        key: hyper::header::HeaderName,
        value: hyper::header::HeaderValue,
    ) -> Self {
        self.headers.insert(key, value);
        self
    }

    /// Sets `Content-Type` to `media_type`.
    ///
    /// Falls back to `application/octet-stream` when `media_type` is not a valid
    /// header value.
    pub fn media_type(self, media_type: &str) -> Self {
        let content_type = match hyper::header::HeaderValue::from_str(media_type) {
            Ok(parsed) => parsed,
            Err(_) => hyper::header::HeaderValue::from_static("application/octet-stream"),
        };
        self.header(hyper::header::CONTENT_TYPE, content_type)
    }
}
impl<S> StreamingResponse<S> {
    /// Consumes the response and returns its body stream, discarding status and headers.
    pub fn into_stream(self) -> S {
        let Self { stream, .. } = self;
        stream
    }
}
// Unit tests for error sanitization (SafeErrorResponse / From<Error>),
// log truncation, and the Response header helpers.
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
// 5xx: a detail set on the builder must not appear in the body when debug mode is off.
#[rstest]
#[case(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")]
#[case(StatusCode::BAD_GATEWAY, "Bad Gateway")]
#[case(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")]
#[case(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")]
fn test_5xx_errors_never_include_internal_details(
#[case] status: StatusCode,
#[case] expected_message: &str,
) {
let sensitive_detail = "Internal path /src/db/connection.rs:42 failed";
let response = SafeErrorResponse::new(status)
.with_detail(sensitive_detail)
.build();
let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
assert_eq!(body["error"], expected_message);
assert!(body.get("detail").is_none());
assert_eq!(response.status, status);
}
// 4xx: the detail is passed through to the body.
#[rstest]
#[case(StatusCode::BAD_REQUEST, "Bad Request")]
#[case(StatusCode::UNAUTHORIZED, "Unauthorized")]
#[case(StatusCode::FORBIDDEN, "Forbidden")]
#[case(StatusCode::NOT_FOUND, "Not Found")]
#[case(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed")]
#[case(StatusCode::CONFLICT, "Conflict")]
fn test_4xx_errors_include_safe_detail(
#[case] status: StatusCode,
#[case] expected_message: &str,
) {
let detail = "Missing required field: name";
let response = SafeErrorResponse::new(status).with_detail(detail).build();
let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
assert_eq!(body["error"], expected_message);
assert_eq!(body["detail"], detail);
assert_eq!(response.status, status);
}
// Debug mode: both the 5xx detail and the debug info are rendered.
#[rstest]
fn test_debug_mode_includes_full_error_info() {
let debug_info = "Error at src/handlers/user.rs:42: column 'email' not found";
let response = SafeErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
.with_detail("Database query failed")
.with_debug_info(debug_info)
.with_debug_mode(true)
.build();
let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
assert_eq!(body["error"], "Internal Server Error");
assert_eq!(body["detail"], "Database query failed");
assert_eq!(body["debug"], debug_info);
}
#[rstest]
fn test_debug_mode_disabled_excludes_debug_info() {
let debug_info = "Sensitive internal detail";
let response = SafeErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
.with_debug_info(debug_info)
.with_debug_mode(false)
.build();
let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
assert!(body.get("debug").is_none());
}
// safe_error_message: every explicitly mapped status, then the class fallbacks.
#[rstest]
#[case(StatusCode::BAD_REQUEST, "Bad Request")]
#[case(StatusCode::UNAUTHORIZED, "Unauthorized")]
#[case(StatusCode::FORBIDDEN, "Forbidden")]
#[case(StatusCode::NOT_FOUND, "Not Found")]
#[case(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed")]
#[case(StatusCode::NOT_ACCEPTABLE, "Not Acceptable")]
#[case(StatusCode::REQUEST_TIMEOUT, "Request Timeout")]
#[case(StatusCode::CONFLICT, "Conflict")]
#[case(StatusCode::GONE, "Gone")]
#[case(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large")]
#[case(StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported Media Type")]
#[case(StatusCode::UNPROCESSABLE_ENTITY, "Unprocessable Entity")]
#[case(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests")]
#[case(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")]
#[case(StatusCode::BAD_GATEWAY, "Bad Gateway")]
#[case(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")]
#[case(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")]
fn test_safe_error_message_returns_correct_messages(
#[case] status: StatusCode,
#[case] expected: &str,
) {
let message = safe_error_message(status);
assert_eq!(message, expected);
}
#[rstest]
fn test_safe_error_message_fallback_client_error() {
let status = StatusCode::IM_A_TEAPOT;
let message = safe_error_message(status);
assert_eq!(message, "Client Error");
}
#[rstest]
fn test_safe_error_message_fallback_server_error() {
let status = StatusCode::HTTP_VERSION_NOT_SUPPORTED;
let message = safe_error_message(status);
assert_eq!(message, "Server Error");
}
// truncate_for_log: passthrough, truncation marker, and UTF-8 boundary safety.
#[rstest]
fn test_truncate_for_log_short_string() {
let input = "hello";
let result = truncate_for_log(input, 10);
assert_eq!(result, "hello");
}
#[rstest]
fn test_truncate_for_log_long_string() {
let input = "a".repeat(100);
let result = truncate_for_log(&input, 10);
assert!(result.starts_with("aaaaaaaaaa"));
assert!(result.contains("...[truncated, 100 total bytes]"));
}
#[rstest]
fn test_truncate_for_log_exact_length() {
let input = "abcde";
let result = truncate_for_log(input, 5);
assert_eq!(result, "abcde");
}
#[rstest]
fn test_truncate_for_log_multi_byte_utf8_does_not_panic() {
let input = "日本語テスト文字列";
let result = truncate_for_log(input, 4);
assert!(result.starts_with("日"));
assert!(result.contains("...[truncated"));
}
#[rstest]
fn test_truncate_for_log_emoji_boundary() {
let input = "🦀🐍🐹🐿️";
let result = truncate_for_log(input, 5);
assert!(result.starts_with("🦀"));
assert!(result.contains("...[truncated"));
}
#[rstest]
fn test_truncate_for_log_mixed_ascii_and_multibyte() {
let input = "abc日本語def";
let result = truncate_for_log(input, 5);
assert!(result.starts_with("abc"));
assert!(result.contains("...[truncated"));
}
#[rstest]
fn test_truncate_for_log_zero_max_length() {
let input = "hello";
let result = truncate_for_log(input, 0);
assert!(result.starts_with("...[truncated"));
}
// From<crate::Error>: internal payloads must not leak into the body.
#[rstest]
fn test_from_error_produces_safe_output_for_5xx() {
let error = crate::Error::Database(
"Connection to postgres://user:pass@db:5432/mydb failed".to_string(),
);
let response: Response = error.into();
assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
assert_eq!(body["error"], "Internal Server Error");
let body_str = String::from_utf8_lossy(&response.body);
assert!(!body_str.contains("postgres://"));
assert!(!body_str.contains("user:pass"));
assert!(body.get("detail").is_none());
}
#[rstest]
fn test_from_error_produces_safe_output_for_4xx_validation() {
let error = crate::Error::Validation("Email format is invalid".to_string());
let response: Response = error.into();
assert_eq!(response.status, StatusCode::BAD_REQUEST);
let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
assert_eq!(body["error"], "Bad Request");
assert_eq!(body["detail"], "Email format is invalid");
}
#[rstest]
fn test_from_error_produces_safe_output_for_4xx_parse() {
let error = crate::Error::ParseError(
"invalid digit found in string at src/parser.rs:42".to_string(),
);
let response: Response = error.into();
assert_eq!(response.status, StatusCode::BAD_REQUEST);
let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
assert_eq!(body["error"], "Bad Request");
assert_eq!(body["detail"], "Invalid request format");
let body_str = String::from_utf8_lossy(&response.body);
assert!(!body_str.contains("src/parser.rs"));
}
#[rstest]
fn test_from_error_body_already_consumed() {
let error = crate::Error::BodyAlreadyConsumed;
let response: Response = error.into();
assert_eq!(response.status, StatusCode::BAD_REQUEST);
let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
assert_eq!(body["detail"], "Request body has already been consumed");
}
#[rstest]
fn test_from_error_internal_error_hides_details() {
let error =
crate::Error::Internal("panic at /Users/dev/projects/app/src/main.rs:10".to_string());
let response: Response = error.into();
assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
let body_str = String::from_utf8_lossy(&response.body);
assert!(!body_str.contains("/Users/dev"));
assert!(!body_str.contains("main.rs"));
}
#[rstest]
fn test_safe_error_response_no_detail_set() {
let response = SafeErrorResponse::new(StatusCode::BAD_REQUEST).build();
let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
assert_eq!(body["error"], "Bad Request");
assert!(body.get("detail").is_none());
}
#[rstest]
fn test_safe_error_response_content_type_is_json() {
let response = SafeErrorResponse::new(StatusCode::NOT_FOUND).build();
let content_type = response
.headers
.get("content-type")
.unwrap()
.to_str()
.unwrap();
assert_eq!(content_type, "application/json");
}
// Header helpers: infallible variants ignore invalid input; try_* variants report it.
#[rstest]
fn test_with_header_invalid_name_does_not_panic() {
let response = Response::ok();
let response = response.with_header("Invalid Header", "value");
assert!(response.headers.is_empty());
}
#[rstest]
fn test_with_header_invalid_value_does_not_panic() {
let response = Response::ok();
let response = response.with_header("X-Test", "value\x00with\x01control");
assert!(response.headers.get("X-Test").is_none());
}
#[rstest]
fn test_with_header_valid_header_works() {
let response = Response::ok();
let response = response.with_header("X-Custom", "custom-value");
assert_eq!(
response.headers.get("X-Custom").unwrap().to_str().unwrap(),
"custom-value"
);
}
#[rstest]
fn test_try_with_header_invalid_name_returns_error() {
let response = Response::ok();
let result = response.try_with_header("Invalid Header", "value");
assert!(result.is_err());
}
#[rstest]
fn test_try_with_header_valid_header_returns_ok() {
let response = Response::ok();
let result = response.try_with_header("X-Custom", "valid-value");
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(
response.headers.get("X-Custom").unwrap().to_str().unwrap(),
"valid-value"
);
}
// append_header keeps earlier values for the same name (multi-valued headers).
#[rstest]
fn test_append_header_adds_multiple_values() {
let response = Response::ok()
.append_header("Set-Cookie", "a=1; Path=/")
.append_header("Set-Cookie", "b=2; Path=/");
let cookies: Vec<_> = response.headers.get_all("set-cookie").iter().collect();
assert_eq!(cookies.len(), 2);
assert_eq!(cookies[0].to_str().unwrap(), "a=1; Path=/");
assert_eq!(cookies[1].to_str().unwrap(), "b=2; Path=/");
}
#[rstest]
fn test_append_header_coexists_with_with_header() {
let response = Response::ok()
.with_header("Set-Cookie", "a=1; Path=/")
.append_header("Set-Cookie", "b=2; Path=/");
let cookies: Vec<_> = response.headers.get_all("set-cookie").iter().collect();
assert_eq!(cookies.len(), 2);
}
}