Skip to main content

fastapi_http/
response.rs

1//! HTTP response writer.
2
3use asupersync::stream::Stream;
4use fastapi_core::{BodyStream, Response, ResponseBody, StatusCode};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8/// Serialized response output.
9pub enum ResponseWrite {
10    /// Fully-buffered response bytes.
11    Full(Vec<u8>),
12    /// Chunked stream (head + body chunks).
13    Stream(ChunkedEncoder),
14}
15
/// Trailer fields emitted after the last chunk of a chunked response body.
///
/// RFC 7230 lets a chunked message carry extra header fields after the
/// zero-length terminating chunk. Common uses include content digests,
/// signatures, and final status information after streaming.
///
/// # Example
///
/// ```
/// use fastapi_http::Trailers;
///
/// let trailers = Trailers::new()
///     .add("Content-MD5", "Q2hlY2tzdW0=")
///     .add("Server-Timing", "total;dur=123");
/// ```
#[derive(Debug, Clone, Default)]
pub struct Trailers {
    /// Trailer fields as `(name, value)` pairs, kept in insertion order.
    headers: Vec<(String, String)>,
}

impl Trailers {
    /// Creates a trailer set containing no fields.
    #[must_use]
    pub fn new() -> Self {
        Self {
            headers: Vec::new(),
        }
    }

    /// Appends one trailer field and returns the updated set.
    ///
    /// Names are stored exactly as given; adding the same name twice keeps both.
    #[must_use]
    pub fn add(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let field = (name.into(), value.into());
        self.headers.push(field);
        self
    }

    /// Reports whether the set holds no trailer fields.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.headers.is_empty()
    }

    /// Lists the trailer names, separated by `", "`, in insertion order.
    ///
    /// Intended as the value of the `Trailer` response header. An empty set
    /// yields an empty string.
    #[must_use]
    pub fn trailer_header_value(&self) -> String {
        let mut names = String::new();
        for (index, (name, _)) in self.headers.iter().enumerate() {
            if index != 0 {
                names.push_str(", ");
            }
            names.push_str(name);
        }
        names
    }

    /// Serializes the fields as `name: value\r\n` lines, one per trailer.
    ///
    /// Does not write the blank line that ends the trailer section; the
    /// chunk encoder appends it.
    fn encode(&self) -> Vec<u8> {
        let mut encoded = Vec::new();
        for (name, value) in self.headers.iter() {
            for part in [name.as_bytes(), &b": "[..], value.as_bytes(), &b"\r\n"[..]] {
                encoded.extend_from_slice(part);
            }
        }
        encoded
    }
}
81
82/// Streaming chunked response encoder.
83pub struct ChunkedEncoder {
84    head: Option<Vec<u8>>,
85    body: BodyStream,
86    finished: bool,
87    trailers: Option<Trailers>,
88}
89
90impl ChunkedEncoder {
91    fn new(head: Vec<u8>, body: BodyStream) -> Self {
92        Self {
93            head: Some(head),
94            body,
95            finished: false,
96            trailers: None,
97        }
98    }
99
100    /// Set trailers to be sent after the final chunk.
101    #[must_use]
102    pub fn with_trailers(mut self, trailers: Trailers) -> Self {
103        self.trailers = Some(trailers);
104        self
105    }
106
107    fn encode_chunk(chunk: &[u8]) -> Vec<u8> {
108        // Use std::io::Write to format hex directly into buffer without allocation.
109        // Max hex digits for usize is 16 (64-bit), so we pre-allocate conservatively.
110        use std::io::Write as _;
111        let mut out = Vec::with_capacity(20 + chunk.len() + 4);
112        write!(out, "{:x}\r\n", chunk.len()).expect("write to Vec cannot fail");
113        out.extend_from_slice(chunk);
114        out.extend_from_slice(b"\r\n");
115        out
116    }
117
118    /// Encode the final chunk with optional trailers.
119    ///
120    /// Per RFC 7230 Section 4.1:
121    /// - Without trailers: `0\r\n\r\n`
122    /// - With trailers: `0\r\n<trailer-headers>\r\n`
123    fn encode_final_chunk(&self) -> Vec<u8> {
124        let mut out = Vec::new();
125        out.extend_from_slice(b"0\r\n");
126        if let Some(ref trailers) = self.trailers {
127            out.extend_from_slice(&trailers.encode());
128        }
129        out.extend_from_slice(b"\r\n");
130        out
131    }
132}
133
134impl Stream for ChunkedEncoder {
135    type Item = Vec<u8>;
136
137    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138        if let Some(head) = self.head.take() {
139            return Poll::Ready(Some(head));
140        }
141
142        if self.finished {
143            return Poll::Ready(None);
144        }
145
146        loop {
147            match self.body.as_mut().poll_next(cx) {
148                Poll::Pending => return Poll::Pending,
149                Poll::Ready(Some(chunk)) => {
150                    if chunk.is_empty() {
151                        continue;
152                    }
153                    return Poll::Ready(Some(Self::encode_chunk(&chunk)));
154                }
155                Poll::Ready(None) => {
156                    self.finished = true;
157                    return Poll::Ready(Some(self.encode_final_chunk()));
158                }
159            }
160        }
161    }
162}
163
164/// Writes HTTP responses to a buffer.
165pub struct ResponseWriter {
166    buffer: Vec<u8>,
167}
168
169impl ResponseWriter {
170    /// Create a new response writer.
171    #[must_use]
172    pub fn new() -> Self {
173        Self {
174            buffer: Vec::with_capacity(4096),
175        }
176    }
177
178    /// Write a response into either a full buffer or a stream.
179    #[must_use]
180    pub fn write(&mut self, response: Response) -> ResponseWrite {
181        let (status, headers, body) = response.into_parts();
182        match body {
183            ResponseBody::Empty => {
184                let bytes = self.write_full(status, &headers, &[]);
185                ResponseWrite::Full(bytes)
186            }
187            ResponseBody::Bytes(body) => {
188                let bytes = self.write_full(status, &headers, &body);
189                ResponseWrite::Full(bytes)
190            }
191            ResponseBody::Stream(body) => {
192                let head = self.write_stream_head(status, &headers);
193                ResponseWrite::Stream(ChunkedEncoder::new(head, body))
194            }
195        }
196    }
197
198    fn write_full(
199        &mut self,
200        status: StatusCode,
201        headers: &[(String, Vec<u8>)],
202        body: &[u8],
203    ) -> Vec<u8> {
204        self.buffer.clear();
205
206        // Status line
207        self.buffer.extend_from_slice(b"HTTP/1.1 ");
208        self.write_status(status);
209        self.buffer.extend_from_slice(b"\r\n");
210
211        // Headers (filter hop-by-hop content-length/transfer-encoding)
212        for (name, value) in headers {
213            if is_content_length(name) || is_transfer_encoding(name) {
214                continue;
215            }
216            self.buffer.extend_from_slice(name.as_bytes());
217            self.buffer.extend_from_slice(b": ");
218            self.buffer.extend_from_slice(value);
219            self.buffer.extend_from_slice(b"\r\n");
220        }
221
222        // Content-Length
223        self.buffer.extend_from_slice(b"content-length: ");
224        self.buffer
225            .extend_from_slice(body.len().to_string().as_bytes());
226        self.buffer.extend_from_slice(b"\r\n");
227
228        // End of headers
229        self.buffer.extend_from_slice(b"\r\n");
230
231        // Body
232        self.buffer.extend_from_slice(body);
233
234        self.take_buffer()
235    }
236
237    fn write_stream_head(&mut self, status: StatusCode, headers: &[(String, Vec<u8>)]) -> Vec<u8> {
238        self.buffer.clear();
239
240        // Status line
241        self.buffer.extend_from_slice(b"HTTP/1.1 ");
242        self.write_status(status);
243        self.buffer.extend_from_slice(b"\r\n");
244
245        // Headers (filter hop-by-hop content-length/transfer-encoding)
246        for (name, value) in headers {
247            if is_content_length(name) || is_transfer_encoding(name) {
248                continue;
249            }
250            self.buffer.extend_from_slice(name.as_bytes());
251            self.buffer.extend_from_slice(b": ");
252            self.buffer.extend_from_slice(value);
253            self.buffer.extend_from_slice(b"\r\n");
254        }
255
256        // Transfer-Encoding: chunked
257        self.buffer
258            .extend_from_slice(b"transfer-encoding: chunked\r\n");
259
260        // End of headers
261        self.buffer.extend_from_slice(b"\r\n");
262
263        self.take_buffer()
264    }
265
266    fn write_status(&mut self, status: StatusCode) {
267        let code = status.as_u16();
268        self.buffer.extend_from_slice(code.to_string().as_bytes());
269        self.buffer.extend_from_slice(b" ");
270        self.buffer
271            .extend_from_slice(status.canonical_reason().as_bytes());
272    }
273
274    fn take_buffer(&mut self) -> Vec<u8> {
275        let mut out = Vec::new();
276        std::mem::swap(&mut out, &mut self.buffer);
277        self.buffer = Vec::with_capacity(out.capacity());
278        out
279    }
280}
281
/// Case-insensitive check for the `Content-Length` header name.
fn is_content_length(name: &str) -> bool {
    "content-length".eq_ignore_ascii_case(name)
}

/// Case-insensitive check for the `Transfer-Encoding` header name.
fn is_transfer_encoding(name: &str) -> bool {
    name.as_bytes().eq_ignore_ascii_case(b"transfer-encoding")
}
289
290impl Default for ResponseWriter {
291    fn default() -> Self {
292        Self::new()
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::stream::iter;
    use std::sync::Arc;
    use std::task::{Wake, Waker};

    /// Waker whose wake-ups are ignored; the streams under test never pend.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    fn noop_waker() -> Waker {
        Arc::new(NoopWaker).into()
    }

    /// Drain `stream` synchronously, concatenating every yielded buffer.
    ///
    /// Panics if the stream ever reports `Pending`.
    fn collect_stream<S: Stream<Item = Vec<u8>> + Unpin>(mut stream: S) -> Vec<u8> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut collected = Vec::new();

        loop {
            let piece = match Pin::new(&mut stream).poll_next(&mut cx) {
                Poll::Pending => panic!("unexpected pending stream"),
                Poll::Ready(None) => return collected,
                Poll::Ready(Some(piece)) => piece,
            };
            collected.extend_from_slice(&piece);
        }
    }

    /// Build a body stream over in-memory chunks.
    fn body_from(chunks: Vec<Vec<u8>>) -> BodyStream {
        Box::pin(iter(chunks))
    }

    #[test]
    fn write_full_sets_content_length() {
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(b"hello".to_vec()));
        let mut writer = ResponseWriter::new();
        let bytes = if let ResponseWrite::Full(bytes) = writer.write(response) {
            bytes
        } else {
            panic!("expected full response");
        };

        let text = String::from_utf8_lossy(&bytes);
        assert!(text.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(text.contains("content-length: 5\r\n"));
        assert!(text.contains("\r\n\r\nhello"));
    }

    #[test]
    fn write_stream_uses_chunked_encoding() {
        let chunks = iter(vec![b"hello".to_vec(), b"world".to_vec()]);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::stream(chunks));
        let mut writer = ResponseWriter::new();
        let bytes = if let ResponseWrite::Stream(encoder) = writer.write(response) {
            collect_stream(encoder)
        } else {
            panic!("expected stream response");
        };

        let expected = b"HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\ntransfer-encoding: chunked\r\n\r\n5\r\nhello\r\n5\r\nworld\r\n0\r\n\r\n";
        assert_eq!(bytes, expected);
    }

    // ====================================================================
    // Trailer Tests
    // ====================================================================

    #[test]
    fn trailers_empty() {
        let trailers = Trailers::new();
        assert!(trailers.is_empty());
        assert_eq!(trailers.trailer_header_value(), "");
    }

    #[test]
    fn trailers_encode() {
        let trailers = Trailers::new()
            .add("Content-MD5", "abc123")
            .add("Server-Timing", "total;dur=50");
        assert!(!trailers.is_empty());
        assert_eq!(trailers.trailer_header_value(), "Content-MD5, Server-Timing");

        let encoded = trailers.encode();
        let text = std::str::from_utf8(&encoded).unwrap();
        assert!(text.contains("Content-MD5: abc123\r\n"));
        assert!(text.contains("Server-Timing: total;dur=50\r\n"));
    }

    #[test]
    fn chunked_encoder_with_trailers() {
        let head = b"HTTP/1.1 200 OK\r\n\r\n".to_vec();
        let encoder = ChunkedEncoder::new(head, body_from(vec![b"data".to_vec()]))
            .with_trailers(Trailers::new().add("Checksum", "deadbeef"));

        let bytes = collect_stream(encoder);
        let text = std::str::from_utf8(&bytes).unwrap();
        // The trailer block must directly follow the zero-length final chunk.
        assert!(text.contains("0\r\nChecksum: deadbeef\r\n\r\n"));
    }

    #[test]
    fn chunked_encoder_without_trailers_unchanged() {
        let head = b"HTTP/1.1 200 OK\r\n\r\n".to_vec();
        let encoder = ChunkedEncoder::new(head, body_from(vec![b"hi".to_vec()]));

        let bytes = collect_stream(encoder);
        assert!(bytes.ends_with(b"0\r\n\r\n"));
    }

    #[test]
    fn final_chunk_format_with_multiple_trailers() {
        let trailers = Trailers::new()
            .add("Digest", "sha-256=abc")
            .add("Signature", "sig123");
        let encoder = ChunkedEncoder {
            head: None,
            body: body_from(Vec::new()),
            finished: false,
            trailers: Some(trailers),
        };

        let final_chunk = encoder.encode_final_chunk();
        let text = std::str::from_utf8(&final_chunk).unwrap();
        assert_eq!(text, "0\r\nDigest: sha-256=abc\r\nSignature: sig123\r\n\r\n");
    }
}