reinhardt_dispatch/
exception.rs1use async_trait::async_trait;
7use bytes::Bytes;
8use hyper::StatusCode;
9use reinhardt_http::{Request, Response};
10use std::fmt;
11use std::future::Future;
12use tracing::{error, warn};
13
14use crate::DispatchError;
15use crate::build_error_response;
16
17pub type ExceptionResult = Result<Response, DispatchError>;
19
20#[async_trait]
22pub trait ExceptionHandler: Send + Sync {
23 async fn handle_exception(&self, request: &Request, error: DispatchError) -> Response;
25}
26
27pub struct DefaultExceptionHandler;
31
32#[async_trait]
33impl ExceptionHandler for DefaultExceptionHandler {
34 async fn handle_exception(&self, _request: &Request, error: DispatchError) -> Response {
35 let (status, client_message) = match &error {
38 DispatchError::View(msg) => {
39 warn!("View error: {}", msg);
40 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")
41 }
42 DispatchError::UrlResolution(msg) => {
43 warn!("URL resolution error: {}", msg);
44 (StatusCode::NOT_FOUND, "Not Found")
45 }
46 DispatchError::Middleware(msg) => {
47 error!("Middleware error: {}", msg);
48 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")
49 }
50 DispatchError::Http(msg) => {
51 warn!("HTTP error: {}", msg);
52 (StatusCode::BAD_REQUEST, "Bad Request")
53 }
54 DispatchError::Internal(msg) => {
55 error!("Internal error: {}", msg);
56 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")
57 }
58 };
59
60 build_error_response(status, client_message)
61 }
62}
63
64pub async fn convert_exception_to_response<F, Fut>(handler: F, request: Request) -> Response
74where
75 F: FnOnce(Request) -> Fut,
76 Fut: Future<Output = Result<Response, DispatchError>>,
77{
78 let method = request.method.clone();
81 let uri = request.uri.clone();
82 let version = request.version;
83 let headers = request.headers.clone();
84
85 match handler(request).await {
86 Ok(response) => response,
87 Err(error) => {
88 let exception_handler = DefaultExceptionHandler;
89 match Request::builder()
91 .method(method)
92 .uri(uri.to_string())
93 .version(version)
94 .headers(headers)
95 .body(Bytes::new())
96 .build()
97 {
98 Ok(context_request) => {
99 exception_handler
100 .handle_exception(&context_request, error)
101 .await
102 }
103 Err(_) => {
104 let mut response = Response::new(hyper::StatusCode::INTERNAL_SERVER_ERROR);
105 response.body = Bytes::from("Internal Server Error");
106 response
107 }
108 }
109 }
110 }
111}
112
113pub trait IntoResponse {
115 fn into_response(self) -> Response;
117}
118
119impl IntoResponse for Response {
120 fn into_response(self) -> Response {
121 self
122 }
123}
124
125impl IntoResponse for String {
126 fn into_response(self) -> Response {
127 let mut response = Response::new(StatusCode::OK);
128 response.body = Bytes::from(self.into_bytes());
129 response
130 }
131}
132
133impl IntoResponse for &str {
134 fn into_response(self) -> Response {
135 let mut response = Response::new(StatusCode::OK);
136 response.body = Bytes::from(self.as_bytes().to_vec());
137 response
138 }
139}
140
141impl IntoResponse for Vec<u8> {
142 fn into_response(self) -> Response {
143 let mut response = Response::new(StatusCode::OK);
144 response.body = Bytes::from(self);
145 response
146 }
147}
148
149impl IntoResponse for StatusCode {
150 fn into_response(self) -> Response {
151 Response::new(self)
152 }
153}
154
155impl<T: IntoResponse, E: fmt::Display> IntoResponse for Result<T, E> {
156 fn into_response(self) -> Response {
157 match self {
158 Ok(value) => value.into_response(),
159 Err(error) => {
160 error!("Error converting to response: {}", error);
162 build_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")
163 }
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GET /` HTTP/1.1 request with no headers and an empty body.
    fn build_request() -> Request {
        let builder = Request::builder()
            .method(hyper::Method::GET)
            .uri("/")
            .version(hyper::Version::HTTP_11)
            .headers(hyper::HeaderMap::new())
            .body(Bytes::new());
        builder.build().unwrap()
    }

    /// Decodes a response body as UTF-8, panicking if it is not valid text.
    fn body_text(response: &Response) -> String {
        String::from_utf8(response.body.to_vec()).unwrap()
    }

    /// Runs `DefaultExceptionHandler` on `error` and returns the resulting
    /// status together with the decoded body.
    async fn handle(error: DispatchError) -> (StatusCode, String) {
        let request = build_request();
        let response = DefaultExceptionHandler
            .handle_exception(&request, error)
            .await;
        let body = body_text(&response);
        (response.status, body)
    }

    #[tokio::test]
    async fn test_internal_error_does_not_expose_details() {
        let detail = "database pool exhausted at /src/db/pool.rs:99".to_string();

        let (status, body) = handle(DispatchError::Internal(detail)).await;

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body, "Internal Server Error");
        assert!(!body.contains("database"));
        assert!(!body.contains(".rs:"));
    }

    #[tokio::test]
    async fn test_middleware_error_does_not_expose_details() {
        let detail = "JWT decode failed: invalid signature for key abc123".to_string();

        let (status, body) = handle(DispatchError::Middleware(detail)).await;

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body, "Internal Server Error");
        assert!(!body.contains("JWT"));
        assert!(!body.contains("abc123"));
    }

    #[tokio::test]
    async fn test_view_error_does_not_expose_details() {
        let detail = "template rendering panicked at /src/views/admin.rs:42".to_string();

        let (status, body) = handle(DispatchError::View(detail)).await;

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body, "Internal Server Error");
        assert!(!body.contains("panicked"));
        assert!(!body.contains(".rs:"));
    }

    #[tokio::test]
    async fn test_url_resolution_returns_not_found() {
        let detail = "no route matched".to_string();

        let (status, body) = handle(DispatchError::UrlResolution(detail)).await;

        assert_eq!(status, StatusCode::NOT_FOUND);
        assert_eq!(body, "Not Found");
    }

    #[tokio::test]
    async fn test_http_error_returns_bad_request() {
        let detail = "malformed header".to_string();

        let (status, body) = handle(DispatchError::Http(detail)).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body, "Bad Request");
    }

    #[test]
    fn test_into_response_for_result_err_does_not_expose_error() {
        let failure: Result<String, String> =
            Err("connection string: postgres://admin:pass@host/db".to_string());

        let response = failure.into_response();
        let body = body_text(&response);

        assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!body.contains("postgres"));
        assert!(!body.contains("admin"));
        assert_eq!(body, "Internal Server Error");
    }
}