1use asupersync::stream::Stream;
4use fastapi_core::{BodyStream, Response, ResponseBody, StatusCode};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8pub enum ResponseWrite {
10 Full(Vec<u8>),
12 Stream(ChunkedEncoder),
14}
15
/// Trailer fields sent after the terminating zero-length chunk of a
/// chunked response body.
///
/// Fields are kept in insertion order and written verbatim; no validation of
/// names or values is performed here.
#[derive(Debug, Clone, Default)]
pub struct Trailers {
    /// `(name, value)` pairs in the order they were added.
    headers: Vec<(String, String)>,
}

impl Trailers {
    /// Creates a trailer set with no fields.
    #[must_use]
    pub fn new() -> Self {
        Self { headers: Vec::new() }
    }

    /// Appends one `name: value` field and returns the updated set (builder style).
    #[must_use]
    pub fn add(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let field = (name.into(), value.into());
        self.headers.push(field);
        self
    }

    /// Returns `true` when no trailer field has been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.headers.is_empty()
    }

    /// Builds the value for a `Trailer` header announcing these fields.
    ///
    /// Field names are joined by `", "` in insertion order. The result is an
    /// empty string when the set is empty.
    #[must_use]
    pub fn trailer_header_value(&self) -> String {
        let mut joined = String::new();
        for (position, (field_name, _)) in self.headers.iter().enumerate() {
            if position > 0 {
                joined.push_str(", ");
            }
            joined.push_str(field_name);
        }
        joined
    }

    /// Serializes every field as `name: value\r\n`, concatenated in order.
    ///
    /// The blank line that closes the trailer section is not included; the
    /// chunk encoder appends it.
    fn encode(&self) -> Vec<u8> {
        self.headers
            .iter()
            .fold(Vec::new(), |mut wire, (field_name, field_value)| {
                wire.extend_from_slice(field_name.as_bytes());
                wire.extend_from_slice(b": ");
                wire.extend_from_slice(field_value.as_bytes());
                wire.extend_from_slice(b"\r\n");
                wire
            })
    }
}
81
82pub struct ChunkedEncoder {
84 head: Option<Vec<u8>>,
85 body: BodyStream,
86 finished: bool,
87 trailers: Option<Trailers>,
88}
89
90impl ChunkedEncoder {
91 fn new(head: Vec<u8>, body: BodyStream) -> Self {
92 Self {
93 head: Some(head),
94 body,
95 finished: false,
96 trailers: None,
97 }
98 }
99
100 #[must_use]
102 pub fn with_trailers(mut self, trailers: Trailers) -> Self {
103 self.trailers = Some(trailers);
104 self
105 }
106
107 fn encode_chunk(chunk: &[u8]) -> Vec<u8> {
108 use std::io::Write as _;
111 let mut out = Vec::with_capacity(20 + chunk.len() + 4);
112 write!(out, "{:x}\r\n", chunk.len()).expect("write to Vec cannot fail");
113 out.extend_from_slice(chunk);
114 out.extend_from_slice(b"\r\n");
115 out
116 }
117
118 fn encode_final_chunk(&self) -> Vec<u8> {
124 let mut out = Vec::new();
125 out.extend_from_slice(b"0\r\n");
126 if let Some(ref trailers) = self.trailers {
127 out.extend_from_slice(&trailers.encode());
128 }
129 out.extend_from_slice(b"\r\n");
130 out
131 }
132}
133
134impl Stream for ChunkedEncoder {
135 type Item = Vec<u8>;
136
137 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138 if let Some(head) = self.head.take() {
139 return Poll::Ready(Some(head));
140 }
141
142 if self.finished {
143 return Poll::Ready(None);
144 }
145
146 loop {
147 match self.body.as_mut().poll_next(cx) {
148 Poll::Pending => return Poll::Pending,
149 Poll::Ready(Some(chunk)) => {
150 if chunk.is_empty() {
151 continue;
152 }
153 return Poll::Ready(Some(Self::encode_chunk(&chunk)));
154 }
155 Poll::Ready(None) => {
156 self.finished = true;
157 return Poll::Ready(Some(self.encode_final_chunk()));
158 }
159 }
160 }
161 }
162}
163
/// Serializes [`Response`] values into HTTP/1.1 wire bytes.
pub struct ResponseWriter {
    /// Scratch buffer in which each response head (and full body) is assembled.
    buffer: Vec<u8>,
}
169impl ResponseWriter {
170 #[must_use]
172 pub fn new() -> Self {
173 Self {
174 buffer: Vec::with_capacity(4096),
175 }
176 }
177
178 #[must_use]
180 pub fn write(&mut self, response: Response) -> ResponseWrite {
181 let (status, headers, body) = response.into_parts();
182 match body {
183 ResponseBody::Empty => {
184 let bytes = self.write_full(status, &headers, &[]);
185 ResponseWrite::Full(bytes)
186 }
187 ResponseBody::Bytes(body) => {
188 let bytes = self.write_full(status, &headers, &body);
189 ResponseWrite::Full(bytes)
190 }
191 ResponseBody::Stream(body) => {
192 let head = self.write_stream_head(status, &headers);
193 ResponseWrite::Stream(ChunkedEncoder::new(head, body))
194 }
195 }
196 }
197
198 fn write_full(
199 &mut self,
200 status: StatusCode,
201 headers: &[(String, Vec<u8>)],
202 body: &[u8],
203 ) -> Vec<u8> {
204 self.buffer.clear();
205
206 self.buffer.extend_from_slice(b"HTTP/1.1 ");
208 self.write_status(status);
209 self.buffer.extend_from_slice(b"\r\n");
210
211 for (name, value) in headers {
213 if is_content_length(name) || is_transfer_encoding(name) {
214 continue;
215 }
216 self.buffer.extend_from_slice(name.as_bytes());
217 self.buffer.extend_from_slice(b": ");
218 self.buffer.extend_from_slice(value);
219 self.buffer.extend_from_slice(b"\r\n");
220 }
221
222 self.buffer.extend_from_slice(b"content-length: ");
224 self.buffer
225 .extend_from_slice(body.len().to_string().as_bytes());
226 self.buffer.extend_from_slice(b"\r\n");
227
228 self.buffer.extend_from_slice(b"\r\n");
230
231 self.buffer.extend_from_slice(body);
233
234 self.take_buffer()
235 }
236
237 fn write_stream_head(&mut self, status: StatusCode, headers: &[(String, Vec<u8>)]) -> Vec<u8> {
238 self.buffer.clear();
239
240 self.buffer.extend_from_slice(b"HTTP/1.1 ");
242 self.write_status(status);
243 self.buffer.extend_from_slice(b"\r\n");
244
245 for (name, value) in headers {
247 if is_content_length(name) || is_transfer_encoding(name) {
248 continue;
249 }
250 self.buffer.extend_from_slice(name.as_bytes());
251 self.buffer.extend_from_slice(b": ");
252 self.buffer.extend_from_slice(value);
253 self.buffer.extend_from_slice(b"\r\n");
254 }
255
256 self.buffer
258 .extend_from_slice(b"transfer-encoding: chunked\r\n");
259
260 self.buffer.extend_from_slice(b"\r\n");
262
263 self.take_buffer()
264 }
265
266 fn write_status(&mut self, status: StatusCode) {
267 let code = status.as_u16();
268 self.buffer.extend_from_slice(code.to_string().as_bytes());
269 self.buffer.extend_from_slice(b" ");
270 self.buffer
271 .extend_from_slice(status.canonical_reason().as_bytes());
272 }
273
274 fn take_buffer(&mut self) -> Vec<u8> {
275 let mut out = Vec::new();
276 std::mem::swap(&mut out, &mut self.buffer);
277 self.buffer = Vec::with_capacity(out.capacity());
278 out
279 }
280}
281
/// Reports whether `name` is the `Content-Length` field name.
///
/// The comparison is ASCII case-insensitive, since HTTP field names are
/// case-insensitive.
fn is_content_length(name: &str) -> bool {
    const FIELD: &str = "content-length";
    FIELD.eq_ignore_ascii_case(name)
}
285
/// Reports whether `name` is the `Transfer-Encoding` field name.
///
/// The comparison is ASCII case-insensitive, since HTTP field names are
/// case-insensitive.
fn is_transfer_encoding(name: &str) -> bool {
    const FIELD: &str = "transfer-encoding";
    FIELD.eq_ignore_ascii_case(name)
}
289
290impl Default for ResponseWriter {
291 fn default() -> Self {
292 Self::new()
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::stream::iter;
    use std::sync::Arc;
    use std::task::{Wake, Waker};

    /// Waker that ignores wake-ups; the streams under test never return `Pending`.
    struct IgnoreWake;

    impl Wake for IgnoreWake {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `stream` to completion and concatenates every yielded chunk.
    ///
    /// # Panics
    /// Panics if the stream ever reports `Pending`.
    fn drain<S: Stream<Item = Vec<u8>> + Unpin>(mut stream: S) -> Vec<u8> {
        let waker = Waker::from(Arc::new(IgnoreWake));
        let mut cx = Context::from_waker(&waker);
        let mut collected = Vec::new();

        while let Poll::Ready(item) = Pin::new(&mut stream).poll_next(&mut cx) {
            match item {
                Some(chunk) => collected.extend_from_slice(&chunk),
                None => return collected,
            }
        }
        panic!("unexpected pending stream");
    }

    #[test]
    fn write_full_sets_content_length() {
        let mut writer = ResponseWriter::new();
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(b"hello".to_vec()));

        let bytes = if let ResponseWrite::Full(bytes) = writer.write(response) {
            bytes
        } else {
            panic!("expected full response");
        };

        let text = String::from_utf8_lossy(&bytes);
        assert!(text.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(text.contains("content-length: 5\r\n"));
        assert!(text.contains("\r\n\r\nhello"));
    }

    #[test]
    fn write_stream_uses_chunked_encoding() {
        let mut writer = ResponseWriter::new();
        let source = iter(vec![b"hello".to_vec(), b"world".to_vec()]);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::stream(source));

        let bytes = if let ResponseWrite::Stream(encoder) = writer.write(response) {
            drain(encoder)
        } else {
            panic!("expected stream response");
        };

        let expected = b"HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\ntransfer-encoding: chunked\r\n\r\n5\r\nhello\r\n5\r\nworld\r\n0\r\n\r\n";
        assert_eq!(bytes, expected);
    }

    #[test]
    fn trailers_empty() {
        let trailers = Trailers::new();
        assert!(trailers.is_empty());
        assert_eq!(trailers.trailer_header_value(), "");
    }

    #[test]
    fn trailers_encode() {
        let trailers = Trailers::new()
            .add("Content-MD5", "abc123")
            .add("Server-Timing", "total;dur=50");
        assert!(!trailers.is_empty());
        assert_eq!(trailers.trailer_header_value(), "Content-MD5, Server-Timing");

        let wire = trailers.encode();
        let text = std::str::from_utf8(&wire).unwrap();
        assert!(text.contains("Content-MD5: abc123\r\n"));
        assert!(text.contains("Server-Timing: total;dur=50\r\n"));
    }

    #[test]
    fn chunked_encoder_with_trailers() {
        let body = Box::pin(iter(vec![b"data".to_vec()])) as BodyStream;
        let head = b"HTTP/1.1 200 OK\r\n\r\n".to_vec();
        let encoder = ChunkedEncoder::new(head, body)
            .with_trailers(Trailers::new().add("Checksum", "deadbeef"));

        let wire = drain(encoder);
        let text = std::str::from_utf8(&wire).unwrap();
        assert!(text.contains("0\r\nChecksum: deadbeef\r\n\r\n"));
    }

    #[test]
    fn chunked_encoder_without_trailers_unchanged() {
        let body = Box::pin(iter(vec![b"hi".to_vec()])) as BodyStream;
        let head = b"HTTP/1.1 200 OK\r\n\r\n".to_vec();
        let wire = drain(ChunkedEncoder::new(head, body));
        assert!(wire.ends_with(b"0\r\n\r\n"));
    }

    #[test]
    fn final_chunk_format_with_multiple_trailers() {
        let trailers = Trailers::new()
            .add("Digest", "sha-256=abc")
            .add("Signature", "sig123");
        let encoder = ChunkedEncoder {
            head: None,
            body: Box::pin(iter(Vec::<Vec<u8>>::new())),
            finished: false,
            trailers: Some(trailers),
        };

        let terminator = encoder.encode_final_chunk();
        let text = std::str::from_utf8(&terminator).unwrap();
        assert_eq!(text, "0\r\nDigest: sha-256=abc\r\nSignature: sig123\r\n\r\n");
    }
}