1use crate::reply::{ProblemDetails, ReplyData, ReplySpec, WebError};
3use bytes::Bytes;
4use futures_util::StreamExt;
5use futures_util::future::BoxFuture;
6use http::{HeaderName, HeaderValue, Response, StatusCode, header};
7use http_body_util::{BodyExt, Full, StreamBody};
8use hyper::body::Frame;
9use std::str::FromStr;
10use tracing::warn;
11
12type BoxBody = http_body_util::combinators::BoxBody<Bytes, WebError>;
13
/// Stateless converter that turns reply values and errors into HTTP responses.
///
/// The type carries no data, so it is free to copy and can be created with
/// either `Finalizer::new()` or `Finalizer::default()`.
#[derive(Debug, Clone, Copy)]
pub struct Finalizer;
18
19impl Default for Finalizer {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl Finalizer {
26 pub fn new() -> Self {
28 Finalizer
29 }
30
31 pub fn build_response<'a>(&'a self, data: ReplyData) -> BoxFuture<'a, Response<BoxBody>> {
34 Box::pin(async move {
35 match data {
36 ReplyData::Empty => Response::builder()
37 .status(StatusCode::NO_CONTENT)
38 .body(
39 Full::new(Bytes::new())
40 .map_err(|never| match never {})
41 .boxed(),
42 )
43 .unwrap(),
44
45 ReplyData::Bytes { content_type, data } => Response::builder()
46 .status(StatusCode::OK)
47 .header(header::CONTENT_TYPE, content_type.as_ref())
48 .body(
49 Full::new(Bytes::from(data))
50 .map_err(|never| match never {})
51 .boxed(),
52 )
53 .unwrap(),
54
55 ReplyData::Json(val) => {
56 let bytes = serde_json::to_vec(&val).expect("json");
57 Response::builder()
58 .status(StatusCode::OK)
59 .header(header::CONTENT_TYPE, "application/json")
60 .body(
61 Full::new(Bytes::from(bytes))
62 .map_err(|never| match never {})
63 .boxed(),
64 )
65 .unwrap()
66 }
67
68 ReplyData::Stream(body_stream) => {
69 let stream_of_frames = body_stream.map(|chunk| {
70 chunk
71 .map(Frame::data)
72 .map_err(|e| WebError::Internal(e.to_string()))
73 });
74 let body = StreamBody::new(stream_of_frames);
75 Response::builder()
76 .status(StatusCode::OK)
77 .body(BodyExt::boxed(body))
78 .unwrap()
79 }
80 ReplyData::Rich(spec) => self.build_rich_response(*spec).await,
81
82 ReplyData::Upgrade(_) => Response::builder()
87 .status(StatusCode::INTERNAL_SERVER_ERROR)
88 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
89 .body(
90 Full::new(Bytes::from_static(
91 b"upgrade reply was not handled by the server",
92 ))
93 .map_err(|never| match never {})
94 .boxed(),
95 )
96 .unwrap(),
97 }
98 })
99 }
100
101 async fn build_rich_response(&self, spec: ReplySpec) -> Response<BoxBody> {
102 let mut res = self.build_response(spec.payload).await;
103
104 if let Some(status) = spec.status {
105 *res.status_mut() = status;
106 }
107
108 for (k, v) in spec.headers {
113 let key = match HeaderName::from_str(&k) {
114 Ok(name) => name,
115 Err(e) => {
116 warn!(name = %k, error = %e, "dropping invalid response header name");
117 continue;
118 }
119 };
120 let value = match HeaderValue::from_str(&v) {
121 Ok(val) => val,
122 Err(e) => {
123 warn!(name = %k, value = %v, error = %e, "dropping invalid response header value");
124 continue;
125 }
126 };
127 res.headers_mut().insert(key, value);
128 }
129 res
130 }
131
132 pub fn error_to_reply(&self, error: WebError) -> ReplyData {
143 let mut allow_header: Option<String> = None;
144 let mut retry_after_seconds: Option<u64> = None;
145 let problem = match error {
146 WebError::NotFound => ProblemDetails::new(StatusCode::NOT_FOUND, "Not Found"),
147 WebError::MethodNotAllowed(methods) => {
148 allow_header = Some(methods.join(", "));
149 ProblemDetails::new(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed")
150 .extra("allowed_methods", serde_json::Value::from(methods))
151 }
152 WebError::BadRequest(msg) => {
153 ProblemDetails::new(StatusCode::BAD_REQUEST, "Bad Request").detail(msg)
154 }
155 WebError::PayloadTooLarge => {
156 ProblemDetails::new(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large")
157 }
158 WebError::TooManyRequests(retry_after) => {
159 let mut p = ProblemDetails::new(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests");
160 if let Some(d) = retry_after {
161 let secs = d.as_secs().max(if d.subsec_nanos() > 0 { 1 } else { 0 });
166 retry_after_seconds = Some(secs);
167 p = p.extra("retry_after_seconds", serde_json::Value::from(secs));
170 }
171 p
172 }
173 WebError::Timeout => {
174 ProblemDetails::new(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
175 .detail("the request did not complete within the configured timeout")
176 }
177 WebError::Busy(retry_after) => {
178 let mut p =
179 ProblemDetails::new(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
180 .detail("server is overloaded");
181 if let Some(d) = retry_after {
182 let secs = d.as_secs().max(if d.subsec_nanos() > 0 { 1 } else { 0 });
183 retry_after_seconds = Some(secs);
184 p = p.extra("retry_after_seconds", serde_json::Value::from(secs));
185 }
186 p
187 }
188 WebError::Unauthorized => ProblemDetails::new(StatusCode::UNAUTHORIZED, "Unauthorized"),
189 WebError::Forbidden => ProblemDetails::new(StatusCode::FORBIDDEN, "Forbidden"),
190 WebError::Internal(msg) => {
191 ProblemDetails::new(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")
192 .detail(msg)
193 }
194 WebError::Problem(p) => p,
195 };
196
197 let mut body = serde_json::Map::new();
198 body.insert("status".into(), problem.status.as_u16().into());
199 body.insert("title".into(), problem.title.into());
200 if let Some(d) = problem.detail {
201 body.insert("detail".into(), d.into());
202 }
203 for (k, v) in *problem.extra {
205 if !matches!(k.as_str(), "status" | "title" | "detail") {
206 body.insert(k, v);
207 }
208 }
209 let bytes = serde_json::to_vec(&serde_json::Value::Object(body)).expect("json");
210 let status = problem.status;
211
212 let mut headers = std::collections::HashMap::new();
213 if let Some(allow) = allow_header {
214 headers.insert("Allow".to_string(), allow);
215 }
216 if let Some(secs) = retry_after_seconds {
217 headers.insert("Retry-After".to_string(), secs.to_string());
218 }
219
220 ReplyData::Rich(Box::new(ReplySpec {
221 payload: ReplyData::Bytes {
222 content_type: std::borrow::Cow::Borrowed("application/problem+json"),
223 data: bytes,
224 },
225 status: Some(status),
226 headers,
227 }))
228 }
229
230 pub async fn build_error(&self, error: WebError) -> Response<BoxBody> {
237 self.build_response(self.error_to_reply(error)).await
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Collects the whole response body and parses it as JSON.
    async fn body_json(res: Response<BoxBody>) -> serde_json::Value {
        let bytes = res.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn empty_reply_is_no_content_without_headers() {
        let res = Finalizer::new().build_response(ReplyData::Empty).await;
        assert_eq!(res.status(), StatusCode::NO_CONTENT);
        assert!(res.headers().is_empty());
        let bytes = res.into_body().collect().await.unwrap().to_bytes();
        assert!(bytes.is_empty());
    }

    #[tokio::test]
    async fn bytes_reply_is_ok_with_declared_content_type() {
        let res = Finalizer::new()
            .build_response(ReplyData::Bytes {
                content_type: std::borrow::Cow::Borrowed("text/plain"),
                data: b"hi".to_vec(),
            })
            .await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.headers().get(header::CONTENT_TYPE).unwrap(), "text/plain");
        let bytes = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(bytes, Bytes::from_static(b"hi"));
    }

    #[tokio::test]
    async fn method_not_allowed_emits_allow_header_and_lists_methods() {
        let res = Finalizer::new()
            .build_error(WebError::MethodNotAllowed(vec!["GET", "POST"]))
            .await;
        assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(res.headers().get(header::ALLOW).unwrap(), "GET, POST");
        assert_eq!(
            res.headers().get(header::CONTENT_TYPE).unwrap(),
            "application/problem+json"
        );
        let body = body_json(res).await;
        assert_eq!(body["status"], 405);
        assert_eq!(body["title"], "Method Not Allowed");
        assert_eq!(body["allowed_methods"], serde_json::json!(["GET", "POST"]));
    }

    #[tokio::test]
    async fn other_errors_have_no_allow_header() {
        let res = Finalizer::new().build_error(WebError::NotFound).await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
        assert!(res.headers().get(header::ALLOW).is_none());
    }

    #[tokio::test]
    async fn too_many_requests_emits_retry_after_header_and_extra_member() {
        let res = Finalizer::new()
            .build_error(WebError::TooManyRequests(Some(Duration::from_secs(42))))
            .await;
        assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(res.headers().get("Retry-After").unwrap(), "42");
        let body = body_json(res).await;
        assert_eq!(body["status"], 429);
        assert_eq!(body["title"], "Too Many Requests");
        assert_eq!(body["retry_after_seconds"], 42);

        // A sub-second delay must not be advertised as zero seconds.
        let res = Finalizer::new()
            .build_error(WebError::TooManyRequests(Some(Duration::from_millis(500))))
            .await;
        assert_eq!(res.headers().get("Retry-After").unwrap(), "1");
        let body = body_json(res).await;
        assert_eq!(body["retry_after_seconds"], 1);

        // Without a delay there is neither a header nor an extension member.
        let res = Finalizer::new()
            .build_error(WebError::TooManyRequests(None))
            .await;
        assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(res.headers().get("Retry-After").is_none());
        let body = body_json(res).await;
        assert!(body.get("retry_after_seconds").is_none());
    }

    #[tokio::test]
    async fn busy_emits_503_with_retry_after() {
        let res = Finalizer::new()
            .build_error(WebError::Busy(Some(Duration::from_secs(2))))
            .await;
        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(res.headers().get("Retry-After").unwrap(), "2");
        let body = body_json(res).await;
        assert_eq!(body["status"], 503);
        assert_eq!(body["title"], "Service Unavailable");
        assert_eq!(body["retry_after_seconds"], 2);

        let res = Finalizer::new().build_error(WebError::Busy(None)).await;
        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert!(res.headers().get("Retry-After").is_none());
    }

    #[tokio::test]
    async fn build_rich_response_drops_invalid_headers_without_panicking() {
        use crate::reply::ReplySpec;

        let mut headers = HashMap::new();
        // Invalid name: contains a newline.
        headers.insert("X-Bad\nName".to_string(), "value".to_string());
        // Valid name, invalid value: contains a newline.
        headers.insert("X-Bad-Value".to_string(), "with\nnewline".to_string());
        headers.insert("X-OK".to_string(), "fine".to_string());

        let spec = ReplySpec {
            payload: ReplyData::Empty,
            status: Some(StatusCode::CREATED),
            headers,
        };
        let res = Finalizer::new()
            .build_response(ReplyData::Rich(Box::new(spec)))
            .await;
        assert_eq!(res.status(), StatusCode::CREATED);
        assert_eq!(res.headers().get("X-OK").unwrap(), "fine");
        assert!(res.headers().get("X-Bad-Value").is_none());
        // The empty payload contributes no headers, so exactly one surviving header
        // proves the entry with the invalid name was dropped as well.
        assert_eq!(res.headers().len(), 1);
    }
}