// fastapi_http/response.rs
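//! HTTP/1.1 response writer.
//!
//! [`ResponseWriter`] serializes a `Response` into wire bytes. In-memory bodies are written in
//! full with a `content-length` header. Streaming bodies are framed with chunked transfer coding
//! by [`ChunkedEncoder`], optionally followed by [`Trailers`].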
2
3use asupersync::stream::Stream;
4use fastapi_core::{BodyStream, Response, ResponseBody, StatusCode};
5use std::borrow::Cow;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9/// Serialized response output.
10pub enum ResponseWrite {
11    /// Fully-buffered response bytes.
12    Full(Vec<u8>),
13    /// Chunked stream (head + body chunks).
14    Stream(ChunkedEncoder),
15}
16
/// A set of HTTP trailer fields emitted after the last chunk of a chunked
/// response body.
///
/// RFC 7230 §4.1.2 defines trailers as header fields that follow the
/// zero-length last chunk of the chunked transfer coding. They carry values
/// that are only known once the body has been produced, such as content
/// digests, signatures, or server timing.
///
/// Fields are kept in insertion order. When encoded, names that are not valid
/// header tokens are skipped, and CR, LF and NUL bytes are stripped from values.
///
/// # Example
///
/// ```
/// use fastapi_http::Trailers;
///
/// let trailers = Trailers::new()
///     .add("Content-MD5", "Q2hlY2tzdW0=")
///     .add("Server-Timing", "total;dur=123");
/// assert!(!trailers.is_empty());
/// assert_eq!(trailers.trailer_header_value(), "Content-MD5, Server-Timing");
/// ```
#[derive(Debug, Clone, Default)]
pub struct Trailers {
    /// `(name, value)` pairs in the order they were added.
    headers: Vec<(String, String)>,
}
36
37impl Trailers {
38    /// Create an empty trailers set.
39    #[must_use]
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Add a trailer header.
45    #[must_use]
46    pub fn add(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
47        self.headers.push((name.into(), value.into()));
48        self
49    }
50
51    /// Returns true if no trailers are set.
52    #[must_use]
53    pub fn is_empty(&self) -> bool {
54        self.headers.is_empty()
55    }
56
57    /// Returns the trailer header names as a comma-separated string
58    /// for the `Trailer` response header.
59    #[must_use]
60    pub fn trailer_header_value(&self) -> String {
61        self.headers
62            .iter()
63            .map(|(n, _)| n.as_str())
64            .collect::<Vec<_>>()
65            .join(", ")
66    }
67
68    /// Encode the trailers as bytes for the chunked encoding terminator.
69    ///
70    /// Format: `name: value\r\n` for each trailer.
71    fn encode(&self) -> Vec<u8> {
72        let mut out = Vec::new();
73        for (name, value) in &self.headers {
74            write_header_line(&mut out, name, value.as_bytes());
75        }
76        out
77    }
78}
79
80/// Streaming chunked response encoder.
81pub struct ChunkedEncoder {
82    head: Option<Vec<u8>>,
83    body: BodyStream,
84    finished: bool,
85    trailers: Option<Trailers>,
86}
87
88impl ChunkedEncoder {
89    fn new(head: Vec<u8>, body: BodyStream) -> Self {
90        Self {
91            head: Some(head),
92            body,
93            finished: false,
94            trailers: None,
95        }
96    }
97
98    /// Set trailers to be sent after the final chunk.
99    #[must_use]
100    pub fn with_trailers(mut self, trailers: Trailers) -> Self {
101        self.trailers = Some(trailers);
102        self
103    }
104
105    fn encode_chunk(chunk: &[u8]) -> Vec<u8> {
106        // Use std::io::Write to format hex directly into buffer without allocation.
107        // Max hex digits for usize is 16 (64-bit), so we pre-allocate conservatively.
108        use std::io::Write as _;
109        let mut out = Vec::with_capacity(20 + chunk.len() + 4);
110        write!(out, "{:x}\r\n", chunk.len()).expect("write to Vec cannot fail");
111        out.extend_from_slice(chunk);
112        out.extend_from_slice(b"\r\n");
113        out
114    }
115
116    /// Encode the final chunk with optional trailers.
117    ///
118    /// Per RFC 7230 Section 4.1:
119    /// - Without trailers: `0\r\n\r\n`
120    /// - With trailers: `0\r\n<trailer-headers>\r\n`
121    fn encode_final_chunk(&self) -> Vec<u8> {
122        let mut out = Vec::new();
123        out.extend_from_slice(b"0\r\n");
124        if let Some(ref trailers) = self.trailers {
125            out.extend_from_slice(&trailers.encode());
126        }
127        out.extend_from_slice(b"\r\n");
128        out
129    }
130}
131
132impl Stream for ChunkedEncoder {
133    type Item = Vec<u8>;
134
135    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136        if let Some(head) = self.head.take() {
137            return Poll::Ready(Some(head));
138        }
139
140        if self.finished {
141            return Poll::Ready(None);
142        }
143
144        loop {
145            match self.body.as_mut().poll_next(cx) {
146                Poll::Pending => return Poll::Pending,
147                Poll::Ready(Some(chunk)) => {
148                    if chunk.is_empty() {
149                        continue;
150                    }
151                    return Poll::Ready(Some(Self::encode_chunk(&chunk)));
152                }
153                Poll::Ready(None) => {
154                    self.finished = true;
155                    return Poll::Ready(Some(self.encode_final_chunk()));
156                }
157            }
158        }
159    }
160}
161
/// Serializes [`Response`] values into HTTP/1.1 wire bytes.
///
/// Each message is assembled in an internal scratch buffer.
pub struct ResponseWriter {
    /// Scratch space in which the head (and, for buffered responses, the body) is built.
    buffer: Vec<u8>,
}
166
167impl ResponseWriter {
168    /// Create a new response writer.
169    #[must_use]
170    pub fn new() -> Self {
171        Self {
172            buffer: Vec::with_capacity(4096),
173        }
174    }
175
176    /// Write a response into either a full buffer or a stream.
177    #[must_use]
178    pub fn write(&mut self, response: Response) -> ResponseWrite {
179        let (status, headers, body) = response.into_parts();
180        match body {
181            ResponseBody::Empty => {
182                let bytes = self.write_full(status, &headers, &[]);
183                ResponseWrite::Full(bytes)
184            }
185            ResponseBody::Bytes(body) => {
186                let bytes = self.write_full(status, &headers, &body);
187                ResponseWrite::Full(bytes)
188            }
189            ResponseBody::Stream(body) => {
190                let head = self.write_stream_head(status, &headers);
191                ResponseWrite::Stream(ChunkedEncoder::new(head, body))
192            }
193        }
194    }
195
196    fn write_full(
197        &mut self,
198        status: StatusCode,
199        headers: &[(String, Vec<u8>)],
200        body: &[u8],
201    ) -> Vec<u8> {
202        self.buffer.clear();
203
204        // Status line
205        self.buffer.extend_from_slice(b"HTTP/1.1 ");
206        self.write_status(status);
207        self.buffer.extend_from_slice(b"\r\n");
208
209        // Headers (filter hop-by-hop content-length/transfer-encoding)
210        for (name, value) in headers {
211            if is_content_length(name) || is_transfer_encoding(name) {
212                continue;
213            }
214            write_header_line(&mut self.buffer, name, value);
215        }
216
217        // Content-Length
218        self.buffer.extend_from_slice(b"content-length: ");
219        self.buffer
220            .extend_from_slice(body.len().to_string().as_bytes());
221        self.buffer.extend_from_slice(b"\r\n");
222
223        // End of headers
224        self.buffer.extend_from_slice(b"\r\n");
225
226        // Body
227        self.buffer.extend_from_slice(body);
228
229        self.take_buffer()
230    }
231
232    fn write_stream_head(&mut self, status: StatusCode, headers: &[(String, Vec<u8>)]) -> Vec<u8> {
233        self.buffer.clear();
234
235        // Status line
236        self.buffer.extend_from_slice(b"HTTP/1.1 ");
237        self.write_status(status);
238        self.buffer.extend_from_slice(b"\r\n");
239
240        // Headers (filter hop-by-hop content-length/transfer-encoding)
241        for (name, value) in headers {
242            if is_content_length(name) || is_transfer_encoding(name) {
243                continue;
244            }
245            write_header_line(&mut self.buffer, name, value);
246        }
247
248        // Transfer-Encoding: chunked
249        self.buffer
250            .extend_from_slice(b"transfer-encoding: chunked\r\n");
251
252        // End of headers
253        self.buffer.extend_from_slice(b"\r\n");
254
255        self.take_buffer()
256    }
257
258    fn write_status(&mut self, status: StatusCode) {
259        let code = status.as_u16();
260        self.buffer.extend_from_slice(code.to_string().as_bytes());
261        self.buffer.extend_from_slice(b" ");
262        self.buffer
263            .extend_from_slice(status.canonical_reason().as_bytes());
264    }
265
266    fn take_buffer(&mut self) -> Vec<u8> {
267        let mut out = Vec::new();
268        std::mem::swap(&mut out, &mut self.buffer);
269        self.buffer = Vec::with_capacity(out.capacity());
270        out
271    }
272}
273
/// Returns `true` if `name` is the `Content-Length` header, compared ASCII-case-insensitively.
fn is_content_length(name: &str) -> bool {
    "content-length".eq_ignore_ascii_case(name)
}
277
/// Returns `true` if `name` is the `Transfer-Encoding` header, compared ASCII-case-insensitively.
fn is_transfer_encoding(name: &str) -> bool {
    "transfer-encoding".eq_ignore_ascii_case(name)
}
281
282fn write_header_line(buffer: &mut Vec<u8>, name: &str, value: &[u8]) {
283    if !is_valid_header_name(name) {
284        return;
285    }
286    buffer.extend_from_slice(name.as_bytes());
287    buffer.extend_from_slice(b": ");
288    buffer.extend_from_slice(sanitize_header_value(value).as_ref());
289    buffer.extend_from_slice(b"\r\n");
290}
291
/// Remove CR, LF and NUL bytes from a header value.
///
/// Returns the input unchanged (borrowed) when it contains none of them;
/// otherwise returns an owned copy with those bytes filtered out.
fn sanitize_header_value(value: &[u8]) -> Cow<'_, [u8]> {
    // These bytes would terminate the header line early or be rejected by peers.
    fn is_forbidden(byte: u8) -> bool {
        matches!(byte, b'\r' | b'\n' | 0)
    }

    if !value.iter().any(|&byte| is_forbidden(byte)) {
        Cow::Borrowed(value)
    } else {
        let kept: Vec<u8> = value
            .iter()
            .copied()
            .filter(|&byte| !is_forbidden(byte))
            .collect();
        Cow::Owned(kept)
    }
}
307
/// Returns `true` if `name` is a non-empty RFC 7230 `token` (§3.2.6).
///
/// A token is made only of `tchar`s: ASCII letters, digits, and the
/// punctuation ``!#$%&'*+-.^_`|~``.
fn is_valid_header_name(name: &str) -> bool {
    // The non-alphanumeric members of `tchar`.
    const TCHAR_PUNCT: &[u8] = b"!#$%&'*+-.^_`|~";
    let is_tchar = |byte: u8| byte.is_ascii_alphanumeric() || TCHAR_PUNCT.contains(&byte);

    !name.is_empty() && name.bytes().all(is_tchar)
}
333
334impl Default for ResponseWriter {
335    fn default() -> Self {
336        Self::new()
337    }
338}
339
// Unit tests for the response writer, the chunked encoder and trailer encoding.
// Streams are driven synchronously with a no-op waker. Every stream used here is
// built from `iter(...)`, and `collect_stream` treats `Pending` as a failure.
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::stream::iter;
    use std::sync::Arc;
    use std::task::{Wake, Waker};

    /// Waker whose `wake` does nothing.
    ///
    /// This is sufficient because `collect_stream` never waits on `Pending`.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Build a `Waker` backed by `NoopWaker`.
    fn noop_waker() -> Waker {
        Waker::from(Arc::new(NoopWaker))
    }

    /// Poll `stream` to completion, concatenating every yielded item.
    ///
    /// Panics if the stream ever returns `Pending`.
    fn collect_stream<S: Stream<Item = Vec<u8>> + Unpin>(mut stream: S) -> Vec<u8> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut out = Vec::new();

        loop {
            match Pin::new(&mut stream).poll_next(&mut cx) {
                Poll::Ready(Some(chunk)) => out.extend_from_slice(&chunk),
                Poll::Ready(None) => break,
                Poll::Pending => panic!("unexpected pending stream"),
            }
        }

        out
    }

    // A byte body is written in full. content-length matches the 5-byte body,
    // and the body follows the empty line that ends the head.
    #[test]
    fn write_full_sets_content_length() {
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(b"hello".to_vec()));
        let mut writer = ResponseWriter::new();
        let bytes = match writer.write(response) {
            ResponseWrite::Full(bytes) => bytes,
            ResponseWrite::Stream(_) => panic!("expected full response"),
        };
        let text = String::from_utf8_lossy(&bytes);
        assert!(text.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(text.contains("content-length: 5\r\n"));
        assert!(text.contains("\r\n\r\nhello"));
    }

    // A streaming body produces the exact chunked wire format: head with
    // transfer-encoding, one hex-length-prefixed chunk per item, then `0\r\n\r\n`.
    #[test]
    fn write_stream_uses_chunked_encoding() {
        let stream = iter(vec![b"hello".to_vec(), b"world".to_vec()]);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::stream(stream));
        let mut writer = ResponseWriter::new();
        let bytes = match writer.write(response) {
            ResponseWrite::Stream(stream) => collect_stream(stream),
            ResponseWrite::Full(_) => panic!("expected stream response"),
        };

        let expected = b"HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\ntransfer-encoding: chunked\r\n\r\n5\r\nhello\r\n5\r\nworld\r\n0\r\n\r\n";
        assert_eq!(bytes, expected);
    }

    // ====================================================================
    // Trailer Tests
    // ====================================================================

    // A fresh trailer set is empty and advertises no names.
    #[test]
    fn trailers_empty() {
        let t = Trailers::new();
        assert!(t.is_empty());
        assert_eq!(t.trailer_header_value(), "");
    }

    // Names are advertised in insertion order and each field encodes as one
    // `name: value\r\n` line.
    #[test]
    fn trailers_encode() {
        let t = Trailers::new()
            .add("Content-MD5", "abc123")
            .add("Server-Timing", "total;dur=50");
        assert!(!t.is_empty());
        assert_eq!(t.trailer_header_value(), "Content-MD5, Server-Timing");
        let encoded = t.encode();
        let s = std::str::from_utf8(&encoded).unwrap();
        assert!(s.contains("Content-MD5: abc123\r\n"));
        assert!(s.contains("Server-Timing: total;dur=50\r\n"));
    }

    // Trailers attached via `with_trailers` appear between the last chunk and
    // the final empty line.
    #[test]
    fn chunked_encoder_with_trailers() {
        let stream = iter(vec![b"data".to_vec()]);
        let body = Box::pin(stream) as BodyStream;
        let head = b"HTTP/1.1 200 OK\r\n\r\n".to_vec();
        let trailers = Trailers::new().add("Checksum", "deadbeef");
        let encoder = ChunkedEncoder::new(head, body).with_trailers(trailers);
        let bytes = collect_stream(encoder);
        let s = std::str::from_utf8(&bytes).unwrap();
        // Should contain the trailer after the final chunk
        assert!(s.contains("0\r\nChecksum: deadbeef\r\n\r\n"));
    }

    // Without trailers the stream still ends with the plain `0\r\n\r\n` terminator.
    #[test]
    fn chunked_encoder_without_trailers_unchanged() {
        let stream = iter(vec![b"hi".to_vec()]);
        let body = Box::pin(stream) as BodyStream;
        let head = b"HTTP/1.1 200 OK\r\n\r\n".to_vec();
        let encoder = ChunkedEncoder::new(head, body);
        let bytes = collect_stream(encoder);
        assert!(bytes.ends_with(b"0\r\n\r\n"));
    }

    // Calls `encode_final_chunk` directly (the empty body is never polled) to pin
    // the exact terminator layout with several trailers.
    #[test]
    fn final_chunk_format_with_multiple_trailers() {
        let t = Trailers::new()
            .add("Digest", "sha-256=abc")
            .add("Signature", "sig123");
        let encoder = ChunkedEncoder {
            head: None,
            body: Box::pin(iter(Vec::<Vec<u8>>::new())),
            finished: false,
            trailers: Some(t),
        };
        let final_chunk = encoder.encode_final_chunk();
        let s = std::str::from_utf8(&final_chunk).unwrap();
        assert_eq!(s, "0\r\nDigest: sha-256=abc\r\nSignature: sig123\r\n\r\n");
    }

    // Header-injection guard for buffered responses: invalid names are dropped
    // and CR/LF are stripped from values, so no extra header line appears.
    #[test]
    fn write_full_drops_invalid_header_names_and_sanitizes_values() {
        let mut writer = ResponseWriter::new();
        let headers = vec![
            ("x-ok".to_string(), b"safe".to_vec()),
            ("bad\r\nname".to_string(), b"ignored".to_vec()),
            ("x-test".to_string(), b"hello\r\nx-injected: yes".to_vec()),
        ];

        let bytes = writer.write_full(StatusCode::OK, &headers, b"body");
        let text = String::from_utf8_lossy(&bytes);

        assert!(text.contains("x-ok: safe\r\n"));
        assert!(!text.contains("bad\r\nname:"));
        assert!(text.contains("x-test: hellox-injected: yes\r\n"));
        assert!(!text.contains("\r\nx-injected: yes\r\n"));
    }

    // Same injection guard for the head of a streaming (chunked) response.
    #[test]
    fn write_stream_head_drops_invalid_header_names_and_sanitizes_values() {
        let mut writer = ResponseWriter::new();
        let headers = vec![
            ("content-type".to_string(), b"text/plain".to_vec()),
            ("bad\nname".to_string(), b"ignored".to_vec()),
            ("x-test".to_string(), b"hello\r\nx-injected: yes".to_vec()),
        ];

        let bytes = writer.write_stream_head(StatusCode::OK, &headers);
        let text = String::from_utf8_lossy(&bytes);

        assert!(text.contains("content-type: text/plain\r\n"));
        assert!(!text.contains("bad\nname:"));
        assert!(text.contains("x-test: hellox-injected: yes\r\n"));
        assert!(!text.contains("\r\nx-injected: yes\r\n"));
    }

    // Same injection guard applied to trailer fields.
    #[test]
    fn trailers_encode_drops_invalid_names_and_sanitizes_values() {
        let encoded = Trailers::new()
            .add("Checksum", "abc123")
            .add("Bad\r\nName", "ignored")
            .add("Signature", "sig\r\nInjected: yes")
            .encode();
        let text = std::str::from_utf8(&encoded).unwrap();

        assert!(text.contains("Checksum: abc123\r\n"));
        assert!(!text.contains("Bad\r\nName"));
        assert!(text.contains("Signature: sigInjected: yes\r\n"));
        assert!(!text.contains("\r\nInjected: yes\r\n"));
    }
}