1use bytes::Bytes;
2use futures::stream::Stream;
3use hyper::{HeaderMap, StatusCode};
4use serde::Serialize;
5use std::pin::Pin;
6
7fn safe_error_message(status: StatusCode) -> &'static str {
13 match status.as_u16() {
14 400 => "Bad Request",
15 401 => "Unauthorized",
16 403 => "Forbidden",
17 404 => "Not Found",
18 405 => "Method Not Allowed",
19 406 => "Not Acceptable",
20 408 => "Request Timeout",
21 409 => "Conflict",
22 410 => "Gone",
23 413 => "Payload Too Large",
24 415 => "Unsupported Media Type",
25 422 => "Unprocessable Entity",
26 429 => "Too Many Requests",
27 500 => "Internal Server Error",
29 502 => "Bad Gateway",
30 503 => "Service Unavailable",
31 504 => "Gateway Timeout",
32 _ if status.is_client_error() => "Client Error",
33 _ if status.is_server_error() => "Server Error",
34 _ => "Error",
35 }
36}
37
38fn safe_client_error_detail(error: &crate::Error) -> Option<String> {
44 use crate::Error;
45 match error {
46 Error::Validation(msg) => Some(msg.clone()),
47 Error::ParseError(_) => Some("Invalid request format".to_string()),
48 Error::BodyAlreadyConsumed => Some("Request body has already been consumed".to_string()),
49 Error::MissingContentType => Some("Missing Content-Type header".to_string()),
50 Error::InvalidPage(msg) => Some(format!("Invalid page: {}", msg)),
51 Error::InvalidCursor(_) => Some("Invalid cursor value".to_string()),
52 Error::InvalidLimit(msg) => Some(format!("Invalid limit: {}", msg)),
53 Error::MissingParameter(name) => Some(format!("Missing parameter: {}", name)),
54 Error::ParamValidation(ctx) => {
55 Some(format!("{} parameter extraction failed", ctx.param_type))
56 }
57 _ => None,
59 }
60}
61
62pub struct SafeErrorResponse {
83 status: StatusCode,
84 detail: Option<String>,
85 debug_info: Option<String>,
86 debug_mode: bool,
87}
88
89impl SafeErrorResponse {
90 pub fn new(status: StatusCode) -> Self {
92 Self {
93 status,
94 detail: None,
95 debug_info: None,
96 debug_mode: false,
97 }
98 }
99
100 pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
105 self.detail = Some(detail.into());
106 self
107 }
108
109 pub fn with_debug_info(mut self, info: impl Into<String>) -> Self {
113 self.debug_info = Some(info.into());
114 self
115 }
116
117 pub fn with_debug_mode(mut self, debug: bool) -> Self {
122 self.debug_mode = debug;
123 self
124 }
125
126 pub fn build(self) -> Response {
128 let message = safe_error_message(self.status);
129 let mut body = serde_json::json!({
130 "error": message,
131 });
132
133 if self.status.is_client_error()
135 && let Some(detail) = &self.detail
136 {
137 body["detail"] = serde_json::Value::String(detail.clone());
138 }
139
140 if self.debug_mode {
142 if let Some(debug_info) = &self.debug_info {
143 body["debug"] = serde_json::Value::String(debug_info.clone());
144 }
145 if self.status.is_server_error()
147 && let Some(detail) = &self.detail
148 {
149 body["detail"] = serde_json::Value::String(detail.clone());
150 }
151 }
152
153 Response::new(self.status)
154 .with_json(&body)
155 .unwrap_or_else(|_| Response::internal_server_error())
156 }
157}
158
/// Shortens `input` for inclusion in log output.
///
/// Strings of at most `max_length` bytes are returned unchanged. Longer strings
/// are cut to at most `max_length` bytes and suffixed with a marker recording
/// the original length in bytes, e.g. `"abc...[truncated, 10 total bytes]"`.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so a
/// multi-byte character is never split. Slicing mid-character would panic,
/// which is unacceptable in a logging helper that may see arbitrary input.
pub fn truncate_for_log(input: &str, max_length: usize) -> String {
    if input.len() <= max_length {
        return input.to_string();
    }

    // Here `max_length < input.len()`, and `is_char_boundary(0)` is always
    // true, so this loop terminates without underflow.
    let mut end = max_length;
    while !input.is_char_boundary(end) {
        end -= 1;
    }

    format!(
        "{}...[truncated, {} total bytes]",
        &input[..end],
        input.len()
    )
}
186
/// A fully buffered HTTP response: status code, headers, and an in-memory body.
#[derive(Debug)]
pub struct Response {
    /// HTTP status code of the response.
    pub status: StatusCode,
    /// Response headers.
    pub headers: HeaderMap,
    /// Complete response body.
    pub body: Bytes,
    /// Flag written by `with_stop_chain` and read by `should_stop_chain`.
    // NOTE(review): presumably tells the handler/middleware chain to stop
    // processing; confirm against the call sites.
    stop_chain: bool,
}
197
/// An HTTP response whose body is produced incrementally by the stream `S`.
pub struct StreamingResponse<S> {
    /// HTTP status code of the response.
    pub status: StatusCode,
    /// Response headers.
    pub headers: HeaderMap,
    /// Body source. The constructors in this module require it to yield
    /// `Result<Bytes, Box<dyn std::error::Error + Send + Sync>>` items.
    pub stream: S,
}
204
/// A type-erased, heap-allocated, pinned `Send` stream of body chunks.
///
/// Each item is either a `Bytes` chunk or a boxed error. The item type and the
/// `Send + 'static` bounds match what `StreamingResponse`'s constructors require.
pub type StreamBody =
    Pin<Box<dyn Stream<Item = Result<Bytes, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
208
209impl Response {
210 pub fn new(status: StatusCode) -> Self {
223 Self {
224 status,
225 headers: HeaderMap::new(),
226 body: Bytes::new(),
227 stop_chain: false,
228 }
229 }
230 pub fn ok() -> Self {
242 Self::new(StatusCode::OK)
243 }
244 pub fn created() -> Self {
256 Self::new(StatusCode::CREATED)
257 }
258 pub fn no_content() -> Self {
270 Self::new(StatusCode::NO_CONTENT)
271 }
272 pub fn bad_request() -> Self {
284 Self::new(StatusCode::BAD_REQUEST)
285 }
286 pub fn unauthorized() -> Self {
298 Self::new(StatusCode::UNAUTHORIZED)
299 }
300 pub fn forbidden() -> Self {
312 Self::new(StatusCode::FORBIDDEN)
313 }
314 pub fn not_found() -> Self {
326 Self::new(StatusCode::NOT_FOUND)
327 }
328 pub fn internal_server_error() -> Self {
340 Self::new(StatusCode::INTERNAL_SERVER_ERROR)
341 }
342 pub fn gone() -> Self {
356 Self::new(StatusCode::GONE)
357 }
358 pub fn permanent_redirect(location: impl AsRef<str>) -> Self {
374 Self::new(StatusCode::MOVED_PERMANENTLY).with_location(location.as_ref())
375 }
376 pub fn temporary_redirect(location: impl AsRef<str>) -> Self {
392 Self::new(StatusCode::FOUND).with_location(location.as_ref())
393 }
394 pub fn temporary_redirect_preserve_method(location: impl AsRef<str>) -> Self {
412 Self::new(StatusCode::TEMPORARY_REDIRECT).with_location(location.as_ref())
413 }
414 pub fn with_body(mut self, body: impl Into<Bytes>) -> Self {
426 self.body = body.into();
427 self
428 }
429 pub fn try_with_header(mut self, name: &str, value: &str) -> crate::Result<Self> {
455 let header_name = hyper::header::HeaderName::from_bytes(name.as_bytes())
456 .map_err(|e| crate::Error::Http(format!("Invalid header name '{}': {}", name, e)))?;
457 let header_value = hyper::header::HeaderValue::from_str(value).map_err(|e| {
458 crate::Error::Http(format!("Invalid header value for '{}': {}", name, e))
459 })?;
460 self.headers.insert(header_name, header_value);
461 Ok(self)
462 }
463
464 pub fn with_header(mut self, name: &str, value: &str) -> Self {
489 if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes())
490 && let Ok(header_value) = hyper::header::HeaderValue::from_str(value)
491 {
492 self.headers.insert(header_name, header_value);
493 }
494 self
495 }
496 pub fn with_location(mut self, location: &str) -> Self {
511 if let Ok(value) = hyper::header::HeaderValue::from_str(location) {
512 self.headers.insert(hyper::header::LOCATION, value);
513 }
514 self
515 }
516 pub fn with_json<T: Serialize>(mut self, data: &T) -> crate::Result<Self> {
533 use crate::Error;
534 let json = serde_json::to_vec(data).map_err(|e| Error::Serialization(e.to_string()))?;
535 self.body = Bytes::from(json);
536 self.headers.insert(
537 hyper::header::CONTENT_TYPE,
538 hyper::header::HeaderValue::from_static("application/json"),
539 );
540 Ok(self)
541 }
542 pub fn with_typed_header(
560 mut self,
561 key: hyper::header::HeaderName,
562 value: hyper::header::HeaderValue,
563 ) -> Self {
564 self.headers.insert(key, value);
565 self
566 }
567
568 pub fn should_stop_chain(&self) -> bool {
584 self.stop_chain
585 }
586
587 pub fn with_stop_chain(mut self, stop: bool) -> Self {
617 self.stop_chain = stop;
618 self
619 }
620}
621
622impl From<crate::Error> for Response {
623 fn from(error: crate::Error) -> Self {
624 let status =
625 StatusCode::from_u16(error.status_code()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
626
627 tracing::error!(
629 status = status.as_u16(),
630 error = %error,
631 "Request error"
632 );
633
634 let mut response = SafeErrorResponse::new(status);
635
636 if status.is_client_error()
639 && let Some(detail) = safe_client_error_detail(&error)
640 {
641 response = response.with_detail(detail);
642 }
643
644 response.build()
645 }
646}
647
648impl<S> StreamingResponse<S>
649where
650 S: Stream<Item = Result<Bytes, Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
651{
652 pub fn new(stream: S) -> Self {
669 Self {
670 status: StatusCode::OK,
671 headers: HeaderMap::new(),
672 stream,
673 }
674 }
675 pub fn with_status(stream: S, status: StatusCode) -> Self {
692 Self {
693 status,
694 headers: HeaderMap::new(),
695 stream,
696 }
697 }
698 pub fn status(mut self, status: StatusCode) -> Self {
715 self.status = status;
716 self
717 }
718 pub fn header(
739 mut self,
740 key: hyper::header::HeaderName,
741 value: hyper::header::HeaderValue,
742 ) -> Self {
743 self.headers.insert(key, value);
744 self
745 }
746 pub fn media_type(self, media_type: &str) -> Self {
766 self.header(
767 hyper::header::CONTENT_TYPE,
768 hyper::header::HeaderValue::from_str(media_type).unwrap_or_else(|_| {
769 hyper::header::HeaderValue::from_static("application/octet-stream")
770 }),
771 )
772 }
773}
774
775impl<S> StreamingResponse<S> {
776 pub fn into_stream(self) -> S {
796 self.stream
797 }
798}
799
// Unit tests for error-message sanitization, log truncation, and header builders.
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // 5xx responses must drop the supplied detail when debug mode is off.
    #[rstest]
    #[case(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")]
    #[case(StatusCode::BAD_GATEWAY, "Bad Gateway")]
    #[case(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")]
    #[case(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")]
    fn test_5xx_errors_never_include_internal_details(
        #[case] status: StatusCode,
        #[case] expected_message: &str,
    ) {
        let sensitive_detail = "Internal path /src/db/connection.rs:42 failed";

        let response = SafeErrorResponse::new(status)
            .with_detail(sensitive_detail)
            .build();

        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(body["error"], expected_message);
        assert!(body.get("detail").is_none());
        assert_eq!(response.status, status);
    }

    // 4xx responses pass the supplied detail through unchanged.
    #[rstest]
    #[case(StatusCode::BAD_REQUEST, "Bad Request")]
    #[case(StatusCode::UNAUTHORIZED, "Unauthorized")]
    #[case(StatusCode::FORBIDDEN, "Forbidden")]
    #[case(StatusCode::NOT_FOUND, "Not Found")]
    #[case(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed")]
    #[case(StatusCode::CONFLICT, "Conflict")]
    fn test_4xx_errors_include_safe_detail(
        #[case] status: StatusCode,
        #[case] expected_message: &str,
    ) {
        let detail = "Missing required field: name";

        let response = SafeErrorResponse::new(status).with_detail(detail).build();

        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(body["error"], expected_message);
        assert_eq!(body["detail"], detail);
        assert_eq!(response.status, status);
    }

    // Debug mode exposes both the 5xx detail and the debug info.
    #[rstest]
    fn test_debug_mode_includes_full_error_info() {
        let debug_info = "Error at src/handlers/user.rs:42: column 'email' not found";

        let response = SafeErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
            .with_detail("Database query failed")
            .with_debug_info(debug_info)
            .with_debug_mode(true)
            .build();

        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(body["error"], "Internal Server Error");
        assert_eq!(body["detail"], "Database query failed");
        assert_eq!(body["debug"], debug_info);
    }

    // Debug info is withheld when debug mode is explicitly disabled.
    #[rstest]
    fn test_debug_mode_disabled_excludes_debug_info() {
        let debug_info = "Sensitive internal detail";

        let response = SafeErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
            .with_debug_info(debug_info)
            .with_debug_mode(false)
            .build();

        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert!(body.get("debug").is_none());
    }

    // Every explicitly mapped status code yields its dedicated message.
    #[rstest]
    #[case(StatusCode::BAD_REQUEST, "Bad Request")]
    #[case(StatusCode::UNAUTHORIZED, "Unauthorized")]
    #[case(StatusCode::FORBIDDEN, "Forbidden")]
    #[case(StatusCode::NOT_FOUND, "Not Found")]
    #[case(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed")]
    #[case(StatusCode::NOT_ACCEPTABLE, "Not Acceptable")]
    #[case(StatusCode::REQUEST_TIMEOUT, "Request Timeout")]
    #[case(StatusCode::CONFLICT, "Conflict")]
    #[case(StatusCode::GONE, "Gone")]
    #[case(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large")]
    #[case(StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported Media Type")]
    #[case(StatusCode::UNPROCESSABLE_ENTITY, "Unprocessable Entity")]
    #[case(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests")]
    #[case(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")]
    #[case(StatusCode::BAD_GATEWAY, "Bad Gateway")]
    #[case(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")]
    #[case(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")]
    fn test_safe_error_message_returns_correct_messages(
        #[case] status: StatusCode,
        #[case] expected: &str,
    ) {
        let message = safe_error_message(status);

        assert_eq!(message, expected);
    }

    // An unmapped 4xx code (418) falls back to the generic client label.
    #[rstest]
    fn test_safe_error_message_fallback_client_error() {
        let status = StatusCode::IM_A_TEAPOT;

        let message = safe_error_message(status);

        assert_eq!(message, "Client Error");
    }

    // An unmapped 5xx code (505) falls back to the generic server label.
    #[rstest]
    fn test_safe_error_message_fallback_server_error() {
        let status = StatusCode::HTTP_VERSION_NOT_SUPPORTED;

        let message = safe_error_message(status);

        assert_eq!(message, "Server Error");
    }

    // Input shorter than the limit is returned unchanged.
    #[rstest]
    fn test_truncate_for_log_short_string() {
        let input = "hello";

        let result = truncate_for_log(input, 10);

        assert_eq!(result, "hello");
    }

    // Input longer than the limit is cut and annotated with its total length.
    #[rstest]
    fn test_truncate_for_log_long_string() {
        let input = "a".repeat(100);

        let result = truncate_for_log(&input, 10);

        assert!(result.starts_with("aaaaaaaaaa"));
        assert!(result.contains("...[truncated, 100 total bytes]"));
    }

    // Input exactly at the limit is not truncated.
    #[rstest]
    fn test_truncate_for_log_exact_length() {
        let input = "abcde";

        let result = truncate_for_log(input, 5);

        assert_eq!(result, "abcde");
    }

    // A 5xx error converted via `From` must not echo credentials or URLs.
    #[rstest]
    fn test_from_error_produces_safe_output_for_5xx() {
        let error = crate::Error::Database(
            "Connection to postgres://user:pass@db:5432/mydb failed".to_string(),
        );

        let response: Response = error.into();

        assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(body["error"], "Internal Server Error");
        let body_str = String::from_utf8_lossy(&response.body);
        assert!(!body_str.contains("postgres://"));
        assert!(!body_str.contains("user:pass"));
        assert!(body.get("detail").is_none());
    }

    // Validation messages are considered safe and are passed to the client.
    #[rstest]
    fn test_from_error_produces_safe_output_for_4xx_validation() {
        let error = crate::Error::Validation("Email format is invalid".to_string());

        let response: Response = error.into();

        assert_eq!(response.status, StatusCode::BAD_REQUEST);
        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(body["error"], "Bad Request");
        assert_eq!(body["detail"], "Email format is invalid");
    }

    // Parse errors are replaced by a fixed message rather than echoed.
    #[rstest]
    fn test_from_error_produces_safe_output_for_4xx_parse() {
        let error = crate::Error::ParseError(
            "invalid digit found in string at src/parser.rs:42".to_string(),
        );

        let response: Response = error.into();

        assert_eq!(response.status, StatusCode::BAD_REQUEST);
        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(body["error"], "Bad Request");
        assert_eq!(body["detail"], "Invalid request format");
        let body_str = String::from_utf8_lossy(&response.body);
        assert!(!body_str.contains("src/parser.rs"));
    }

    #[rstest]
    fn test_from_error_body_already_consumed() {
        let error = crate::Error::BodyAlreadyConsumed;

        let response: Response = error.into();

        assert_eq!(response.status, StatusCode::BAD_REQUEST);
        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(body["detail"], "Request body has already been consumed");
    }

    // Internal errors must not leak filesystem paths.
    #[rstest]
    fn test_from_error_internal_error_hides_details() {
        let error =
            crate::Error::Internal("panic at /Users/dev/projects/app/src/main.rs:10".to_string());

        let response: Response = error.into();

        assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
        let body_str = String::from_utf8_lossy(&response.body);
        assert!(!body_str.contains("/Users/dev"));
        assert!(!body_str.contains("main.rs"));
    }

    #[rstest]
    fn test_safe_error_response_no_detail_set() {
        let response = SafeErrorResponse::new(StatusCode::BAD_REQUEST).build();

        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(body["error"], "Bad Request");
        assert!(body.get("detail").is_none());
    }

    #[rstest]
    fn test_safe_error_response_content_type_is_json() {
        let response = SafeErrorResponse::new(StatusCode::NOT_FOUND).build();

        let content_type = response
            .headers
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(content_type, "application/json");
    }

    // A header name containing a space is rejected silently, not by panicking.
    #[rstest]
    fn test_with_header_invalid_name_does_not_panic() {
        let response = Response::ok();

        let response = response.with_header("Invalid Header", "value");

        assert!(response.headers.is_empty());
    }

    // A header value containing control characters is rejected silently.
    #[rstest]
    fn test_with_header_invalid_value_does_not_panic() {
        let response = Response::ok();

        let response = response.with_header("X-Test", "value\x00with\x01control");

        assert!(response.headers.get("X-Test").is_none());
    }

    #[rstest]
    fn test_with_header_valid_header_works() {
        let response = Response::ok();

        let response = response.with_header("X-Custom", "custom-value");

        assert_eq!(
            response.headers.get("X-Custom").unwrap().to_str().unwrap(),
            "custom-value"
        );
    }

    #[rstest]
    fn test_try_with_header_invalid_name_returns_error() {
        let response = Response::ok();

        let result = response.try_with_header("Invalid Header", "value");

        assert!(result.is_err());
    }

    #[rstest]
    fn test_try_with_header_valid_header_returns_ok() {
        let response = Response::ok();

        let result = response.try_with_header("X-Custom", "valid-value");

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(
            response.headers.get("X-Custom").unwrap().to_str().unwrap(),
            "valid-value"
        );
    }
}