Skip to main content

actus_reply/
finalizer.rs

1//! Contains the Finalizer, which converts a ReplyData into a real HTTP response.
2use crate::reply::{ProblemDetails, ReplyData, ReplySpec, WebError};
3use bytes::Bytes;
4use futures_util::StreamExt;
5use futures_util::future::BoxFuture;
6use http::{HeaderName, HeaderValue, Response, StatusCode, header};
7use http_body_util::{BodyExt, Full, StreamBody};
8use hyper::body::Frame;
9use std::str::FromStr;
10use tracing::warn;
11
/// Type-erased response body produced by every response this module builds:
/// a stream of `Bytes` frames whose errors are surfaced as [`WebError`].
type BoxBody = http_body_util::combinators::BoxBody<Bytes, WebError>;
13
/// Converts a [`ReplyData`] into a concrete `hyper` HTTP response — setting
/// status, headers, and body for buffered, streaming, and rich replies.
///
/// Connection-upgrade replies are expected to be intercepted by the server
/// before they get here; one that does reach the finalizer is answered with
/// a `500` rather than a panic.
///
/// The finalizer holds no state (it is a unit struct), so it is free to
/// construct, clone, and print.
#[derive(Debug, Clone)]
pub struct Finalizer;
18
19impl Default for Finalizer {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl Finalizer {
26    /// Create a finalizer.
27    pub fn new() -> Self {
28        Finalizer
29    }
30
31    /// Build the `hyper` response for `data` — its status line, headers, and
32    /// body (buffered, streaming, SSE, or a connection upgrade).
33    pub fn build_response<'a>(&'a self, data: ReplyData) -> BoxFuture<'a, Response<BoxBody>> {
34        Box::pin(async move {
35            match data {
36                ReplyData::Empty => Response::builder()
37                    .status(StatusCode::NO_CONTENT)
38                    .body(
39                        Full::new(Bytes::new())
40                            .map_err(|never| match never {})
41                            .boxed(),
42                    )
43                    .unwrap(),
44
45                ReplyData::Bytes { content_type, data } => Response::builder()
46                    .status(StatusCode::OK)
47                    .header(header::CONTENT_TYPE, content_type.as_ref())
48                    .body(
49                        Full::new(Bytes::from(data))
50                            .map_err(|never| match never {})
51                            .boxed(),
52                    )
53                    .unwrap(),
54
55                ReplyData::Json(val) => {
56                    let bytes = serde_json::to_vec(&val).expect("json");
57                    Response::builder()
58                        .status(StatusCode::OK)
59                        .header(header::CONTENT_TYPE, "application/json")
60                        .body(
61                            Full::new(Bytes::from(bytes))
62                                .map_err(|never| match never {})
63                                .boxed(),
64                        )
65                        .unwrap()
66                }
67
68                ReplyData::Stream(body_stream) => {
69                    let stream_of_frames = body_stream.map(|chunk| {
70                        chunk
71                            .map(Frame::data)
72                            .map_err(|e| WebError::Internal(e.to_string()))
73                    });
74                    let body = StreamBody::new(stream_of_frames);
75                    Response::builder()
76                        .status(StatusCode::OK)
77                        .body(BodyExt::boxed(body))
78                        .unwrap()
79                }
80                ReplyData::Rich(spec) => self.build_rich_response(*spec).await,
81
82                // The server intercepts `Upgrade` replies (to complete the
83                // handshake) before they reach the finalizer; reaching here
84                // means it wasn't intercepted — surface that as a 500 rather
85                // than panic.
86                ReplyData::Upgrade(_) => Response::builder()
87                    .status(StatusCode::INTERNAL_SERVER_ERROR)
88                    .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
89                    .body(
90                        Full::new(Bytes::from_static(
91                            b"upgrade reply was not handled by the server",
92                        ))
93                        .map_err(|never| match never {})
94                        .boxed(),
95                    )
96                    .unwrap(),
97            }
98        })
99    }
100
101    async fn build_rich_response(&self, spec: ReplySpec) -> Response<BoxBody> {
102        let mut res = self.build_response(spec.payload).await;
103
104        if let Some(status) = spec.status {
105            *res.status_mut() = status;
106        }
107
108        // Insert headers defensively: a hostile or sloppy caller of
109        // `ReplyData::add_header("\n", …)` shouldn't panic the request.
110        // Drop invalid name/value pairs with a `warn!` and carry on —
111        // the rest of the response is still useful.
112        for (k, v) in spec.headers {
113            let key = match HeaderName::from_str(&k) {
114                Ok(name) => name,
115                Err(e) => {
116                    warn!(name = %k, error = %e, "dropping invalid response header name");
117                    continue;
118                }
119            };
120            let value = match HeaderValue::from_str(&v) {
121                Ok(val) => val,
122                Err(e) => {
123                    warn!(name = %k, value = %v, error = %e, "dropping invalid response header value");
124                    continue;
125                }
126            };
127            res.headers_mut().insert(key, value);
128        }
129        res
130    }
131
132    /// Convert a [`WebError`] into a [`ReplyData`] carrying the canonical
133    /// `application/problem+json` body (per RFC 7807), the appropriate
134    /// status, and any error-specific headers (e.g. `Allow` for 405).
135    ///
136    /// Use this when you want an error to flow through the same response
137    /// pipeline as a handler success — after-chain middleware, compression,
138    /// CORS. The simple variants (`NotFound`, `BadRequest`, …) map to obvious
139    /// status/title pairs; `Problem(p)` preserves extension members; the
140    /// returned reply is a [`ReplyData::Rich`] so an `after` hook can stamp
141    /// headers or replace the status without manual juggling.
142    pub fn error_to_reply(&self, error: WebError) -> ReplyData {
143        let mut allow_header: Option<String> = None;
144        let mut retry_after_seconds: Option<u64> = None;
145        let problem = match error {
146            WebError::NotFound => ProblemDetails::new(StatusCode::NOT_FOUND, "Not Found"),
147            WebError::MethodNotAllowed(methods) => {
148                allow_header = Some(methods.join(", "));
149                ProblemDetails::new(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed")
150                    .extra("allowed_methods", serde_json::Value::from(methods))
151            }
152            WebError::BadRequest(msg) => {
153                ProblemDetails::new(StatusCode::BAD_REQUEST, "Bad Request").detail(msg)
154            }
155            WebError::PayloadTooLarge => {
156                ProblemDetails::new(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large")
157            }
158            WebError::TooManyRequests(retry_after) => {
159                let mut p = ProblemDetails::new(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests");
160                if let Some(d) = retry_after {
161                    // RFC 7231 §7.1.3: `Retry-After` is either an HTTP-date
162                    // or delta-seconds (an integer). We use seconds; sub-
163                    // second precision rounds up so we don't tell the
164                    // client "retry in 0s" for a 500ms hint.
165                    let secs = d.as_secs().max(if d.subsec_nanos() > 0 { 1 } else { 0 });
166                    retry_after_seconds = Some(secs);
167                    // Mirror in the problem body too, so a JSON-only client
168                    // (no header inspection) can see the hint.
169                    p = p.extra("retry_after_seconds", serde_json::Value::from(secs));
170                }
171                p
172            }
173            WebError::Timeout => {
174                ProblemDetails::new(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
175                    .detail("the request did not complete within the configured timeout")
176            }
177            WebError::Busy(retry_after) => {
178                let mut p =
179                    ProblemDetails::new(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
180                        .detail("server is overloaded");
181                if let Some(d) = retry_after {
182                    let secs = d.as_secs().max(if d.subsec_nanos() > 0 { 1 } else { 0 });
183                    retry_after_seconds = Some(secs);
184                    p = p.extra("retry_after_seconds", serde_json::Value::from(secs));
185                }
186                p
187            }
188            WebError::Unauthorized => ProblemDetails::new(StatusCode::UNAUTHORIZED, "Unauthorized"),
189            WebError::Forbidden => ProblemDetails::new(StatusCode::FORBIDDEN, "Forbidden"),
190            WebError::Internal(msg) => {
191                ProblemDetails::new(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")
192                    .detail(msg)
193            }
194            WebError::Problem(p) => p,
195        };
196
197        let mut body = serde_json::Map::new();
198        body.insert("status".into(), problem.status.as_u16().into());
199        body.insert("title".into(), problem.title.into());
200        if let Some(d) = problem.detail {
201            body.insert("detail".into(), d.into());
202        }
203        // Extension members: don't let them shadow the standard fields.
204        for (k, v) in *problem.extra {
205            if !matches!(k.as_str(), "status" | "title" | "detail") {
206                body.insert(k, v);
207            }
208        }
209        let bytes = serde_json::to_vec(&serde_json::Value::Object(body)).expect("json");
210        let status = problem.status;
211
212        let mut headers = std::collections::HashMap::new();
213        if let Some(allow) = allow_header {
214            headers.insert("Allow".to_string(), allow);
215        }
216        if let Some(secs) = retry_after_seconds {
217            headers.insert("Retry-After".to_string(), secs.to_string());
218        }
219
220        ReplyData::Rich(Box::new(ReplySpec {
221            payload: ReplyData::Bytes {
222                content_type: std::borrow::Cow::Borrowed("application/problem+json"),
223                data: bytes,
224            },
225            status: Some(status),
226            headers,
227        }))
228    }
229
230    /// Build a complete error `Response` directly — the one-shot
231    /// `error → response` path used for fallback paths (after-chain failures
232    /// while finalizing an error reply, etc.) where running the error
233    /// through the after-chain would risk recursion. For the normal error
234    /// path, prefer [`Finalizer::error_to_reply`] + [`Finalizer::build_response`]
235    /// so middleware / compression / CORS apply uniformly.
236    pub async fn build_error(&self, error: WebError) -> Response<BoxBody> {
237        self.build_response(self.error_to_reply(error)).await
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Drain a response body and parse it as JSON.
    async fn body_json(res: Response<BoxBody>) -> serde_json::Value {
        let bytes = res.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn method_not_allowed_emits_allow_header_and_lists_methods() {
        let res = Finalizer::new()
            .build_error(WebError::MethodNotAllowed(vec!["GET", "POST"]))
            .await;
        assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(res.headers().get(header::ALLOW).unwrap(), "GET, POST");
        let body = body_json(res).await;
        assert_eq!(body["status"], 405);
        assert_eq!(body["title"], "Method Not Allowed");
        assert_eq!(body["allowed_methods"], serde_json::json!(["GET", "POST"]));
    }

    #[tokio::test]
    async fn other_errors_have_no_allow_header() {
        let res = Finalizer::new().build_error(WebError::NotFound).await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
        assert!(res.headers().get(header::ALLOW).is_none());
    }

    #[tokio::test]
    async fn too_many_requests_emits_retry_after_header_and_extra_member() {
        // With a retry hint: 429 + `Retry-After: <seconds>` + an extra
        // member in the body so a JSON-only client can read it without
        // header inspection.
        let res = Finalizer::new()
            .build_error(WebError::TooManyRequests(Some(Duration::from_secs(42))))
            .await;
        assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(res.headers().get("Retry-After").unwrap(), "42");
        let body = body_json(res).await;
        assert_eq!(body["status"], 429);
        assert_eq!(body["title"], "Too Many Requests");
        assert_eq!(body["retry_after_seconds"], 42);

        // Sub-second hints round up — we never tell a client "retry in 0s"
        // for a 500ms hint (which they'd typically interpret as "now").
        let res = Finalizer::new()
            .build_error(WebError::TooManyRequests(Some(Duration::from_millis(500))))
            .await;
        assert_eq!(res.headers().get("Retry-After").unwrap(), "1");

        // Without a retry hint: 429 + no `Retry-After` header.
        let res = Finalizer::new()
            .build_error(WebError::TooManyRequests(None))
            .await;
        assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(res.headers().get("Retry-After").is_none());
    }

    #[tokio::test]
    async fn busy_emits_503_with_retry_after() {
        // With a hint: 503 + `Retry-After: <seconds>` + extra in the body.
        let res = Finalizer::new()
            .build_error(WebError::Busy(Some(Duration::from_secs(2))))
            .await;
        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(res.headers().get("Retry-After").unwrap(), "2");
        let body = body_json(res).await;
        assert_eq!(body["status"], 503);
        assert_eq!(body["title"], "Service Unavailable");
        assert_eq!(body["retry_after_seconds"], 2);

        // No hint: 503 + no Retry-After.
        let res = Finalizer::new().build_error(WebError::Busy(None)).await;
        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert!(res.headers().get("Retry-After").is_none());
    }

    #[tokio::test]
    async fn build_rich_response_drops_invalid_headers_without_panicking() {
        let mut headers = HashMap::new();
        // A newline in a header name is invalid per RFC 7230 §3.2 — must
        // be dropped, not panicked on.
        headers.insert("X-Bad\nName".to_string(), "value".to_string());
        // A control byte in the value is also invalid.
        headers.insert("X-Bad-Value".to_string(), "with\nnewline".to_string());
        // And a valid one alongside, to prove the rest of the spec survives.
        headers.insert("X-OK".to_string(), "fine".to_string());

        let spec = ReplySpec {
            payload: ReplyData::Empty,
            status: Some(StatusCode::CREATED),
            headers,
        };
        let res = Finalizer::new()
            .build_response(ReplyData::Rich(Box::new(spec)))
            .await;
        assert_eq!(res.status(), StatusCode::CREATED);
        assert_eq!(res.headers().get("X-OK").unwrap(), "fine");
        assert!(res.headers().get("X-Bad-Value").is_none());
        // Both invalid pairs were dropped: an `Empty` payload sets no headers
        // of its own, so the only header left must be the valid `X-OK`.
        // (The bad *name* can't be looked up directly — it doesn't parse.)
        assert_eq!(res.headers().len(), 1);
    }
}