1use asupersync::stream::Stream;
4use fastapi_core::{BodyStream, Response, ResponseBody, StatusCode};
5use std::borrow::Cow;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
/// Outcome of [`ResponseWriter::write`]: either the complete serialized
/// response, or an encoder that yields it incrementally.
pub enum ResponseWrite {
    /// Fully serialized response: status line, headers, blank line, and any
    /// body bytes, ready to be written as-is.
    Full(Vec<u8>),
    /// Streaming response: the encoder yields the serialized head first and
    /// then the chunk-encoded body.
    Stream(ChunkedEncoder),
}
16
17#[derive(Debug, Clone, Default)]
33pub struct Trailers {
34 headers: Vec<(String, String)>,
35}
36
37impl Trailers {
38 #[must_use]
40 pub fn new() -> Self {
41 Self::default()
42 }
43
44 #[must_use]
46 pub fn add(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
47 self.headers.push((name.into(), value.into()));
48 self
49 }
50
51 #[must_use]
53 pub fn is_empty(&self) -> bool {
54 self.headers.is_empty()
55 }
56
57 #[must_use]
60 pub fn trailer_header_value(&self) -> String {
61 self.headers
62 .iter()
63 .map(|(n, _)| n.as_str())
64 .collect::<Vec<_>>()
65 .join(", ")
66 }
67
68 fn encode(&self) -> Vec<u8> {
72 let mut out = Vec::new();
73 for (name, value) in &self.headers {
74 write_header_line(&mut out, name, value.as_bytes());
75 }
76 out
77 }
78}
79
80pub struct ChunkedEncoder {
82 head: Option<Vec<u8>>,
83 body: BodyStream,
84 finished: bool,
85 trailers: Option<Trailers>,
86}
87
88impl ChunkedEncoder {
89 fn new(head: Vec<u8>, body: BodyStream) -> Self {
90 Self {
91 head: Some(head),
92 body,
93 finished: false,
94 trailers: None,
95 }
96 }
97
98 #[must_use]
100 pub fn with_trailers(mut self, trailers: Trailers) -> Self {
101 self.trailers = Some(trailers);
102 self
103 }
104
105 fn encode_chunk(chunk: &[u8]) -> Vec<u8> {
106 use std::io::Write as _;
109 let mut out = Vec::with_capacity(20 + chunk.len() + 4);
110 write!(out, "{:x}\r\n", chunk.len()).expect("write to Vec cannot fail");
111 out.extend_from_slice(chunk);
112 out.extend_from_slice(b"\r\n");
113 out
114 }
115
116 fn encode_final_chunk(&self) -> Vec<u8> {
122 let mut out = Vec::new();
123 out.extend_from_slice(b"0\r\n");
124 if let Some(ref trailers) = self.trailers {
125 out.extend_from_slice(&trailers.encode());
126 }
127 out.extend_from_slice(b"\r\n");
128 out
129 }
130}
131
132impl Stream for ChunkedEncoder {
133 type Item = Vec<u8>;
134
135 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136 if let Some(head) = self.head.take() {
137 return Poll::Ready(Some(head));
138 }
139
140 if self.finished {
141 return Poll::Ready(None);
142 }
143
144 loop {
145 match self.body.as_mut().poll_next(cx) {
146 Poll::Pending => return Poll::Pending,
147 Poll::Ready(Some(chunk)) => {
148 if chunk.is_empty() {
149 continue;
150 }
151 return Poll::Ready(Some(Self::encode_chunk(&chunk)));
152 }
153 Poll::Ready(None) => {
154 self.finished = true;
155 return Poll::Ready(Some(self.encode_final_chunk()));
156 }
157 }
158 }
159 }
160}
161
162pub struct ResponseWriter {
164 buffer: Vec<u8>,
165}
166
167impl ResponseWriter {
168 #[must_use]
170 pub fn new() -> Self {
171 Self {
172 buffer: Vec::with_capacity(4096),
173 }
174 }
175
176 #[must_use]
178 pub fn write(&mut self, response: Response) -> ResponseWrite {
179 let (status, headers, body) = response.into_parts();
180 match body {
181 ResponseBody::Empty => {
182 let bytes = self.write_full(status, &headers, &[]);
183 ResponseWrite::Full(bytes)
184 }
185 ResponseBody::Bytes(body) => {
186 let bytes = self.write_full(status, &headers, &body);
187 ResponseWrite::Full(bytes)
188 }
189 ResponseBody::Stream(body) => {
190 let head = self.write_stream_head(status, &headers);
191 ResponseWrite::Stream(ChunkedEncoder::new(head, body))
192 }
193 }
194 }
195
196 fn write_full(
197 &mut self,
198 status: StatusCode,
199 headers: &[(String, Vec<u8>)],
200 body: &[u8],
201 ) -> Vec<u8> {
202 self.buffer.clear();
203
204 self.buffer.extend_from_slice(b"HTTP/1.1 ");
206 self.write_status(status);
207 self.buffer.extend_from_slice(b"\r\n");
208
209 for (name, value) in headers {
211 if is_content_length(name) || is_transfer_encoding(name) {
212 continue;
213 }
214 write_header_line(&mut self.buffer, name, value);
215 }
216
217 self.buffer.extend_from_slice(b"content-length: ");
219 self.buffer
220 .extend_from_slice(body.len().to_string().as_bytes());
221 self.buffer.extend_from_slice(b"\r\n");
222
223 self.buffer.extend_from_slice(b"\r\n");
225
226 self.buffer.extend_from_slice(body);
228
229 self.take_buffer()
230 }
231
232 fn write_stream_head(&mut self, status: StatusCode, headers: &[(String, Vec<u8>)]) -> Vec<u8> {
233 self.buffer.clear();
234
235 self.buffer.extend_from_slice(b"HTTP/1.1 ");
237 self.write_status(status);
238 self.buffer.extend_from_slice(b"\r\n");
239
240 for (name, value) in headers {
242 if is_content_length(name) || is_transfer_encoding(name) {
243 continue;
244 }
245 write_header_line(&mut self.buffer, name, value);
246 }
247
248 self.buffer
250 .extend_from_slice(b"transfer-encoding: chunked\r\n");
251
252 self.buffer.extend_from_slice(b"\r\n");
254
255 self.take_buffer()
256 }
257
258 fn write_status(&mut self, status: StatusCode) {
259 let code = status.as_u16();
260 self.buffer.extend_from_slice(code.to_string().as_bytes());
261 self.buffer.extend_from_slice(b" ");
262 self.buffer
263 .extend_from_slice(status.canonical_reason().as_bytes());
264 }
265
266 fn take_buffer(&mut self) -> Vec<u8> {
267 let mut out = Vec::new();
268 std::mem::swap(&mut out, &mut self.buffer);
269 self.buffer = Vec::with_capacity(out.capacity());
270 out
271 }
272}
273
/// Returns `true` if `name` is the `Content-Length` field name, compared
/// ASCII case-insensitively.
fn is_content_length(name: &str) -> bool {
    name.as_bytes().eq_ignore_ascii_case(b"content-length")
}
277
/// Returns `true` if `name` is the `Transfer-Encoding` field name, compared
/// ASCII case-insensitively.
fn is_transfer_encoding(name: &str) -> bool {
    name.as_bytes().eq_ignore_ascii_case(b"transfer-encoding")
}
281
282fn write_header_line(buffer: &mut Vec<u8>, name: &str, value: &[u8]) {
283 if !is_valid_header_name(name) {
284 return;
285 }
286 buffer.extend_from_slice(name.as_bytes());
287 buffer.extend_from_slice(b": ");
288 buffer.extend_from_slice(sanitize_header_value(value).as_ref());
289 buffer.extend_from_slice(b"\r\n");
290}
291
/// Removes CR, LF and NUL bytes from a header value.
///
/// Returns the input borrowed, without allocating, when it contains none of
/// them. Otherwise returns an owned copy with those bytes filtered out.
fn sanitize_header_value(value: &[u8]) -> Cow<'_, [u8]> {
    let is_forbidden = |byte: &u8| matches!(*byte, b'\r' | b'\n' | 0);
    if value.iter().any(is_forbidden) {
        Cow::Owned(
            value
                .iter()
                .copied()
                .filter(|byte: &u8| !is_forbidden(byte))
                .collect(),
        )
    } else {
        Cow::Borrowed(value)
    }
}
307
/// Returns `true` if `name` is a non-empty HTTP token: ASCII letters,
/// digits, and the `tchar` symbols ``!#$%&'*+-.^_`|~``.
fn is_valid_header_name(name: &str) -> bool {
    const TCHAR_SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    if name.is_empty() {
        return false;
    }
    name.bytes()
        .all(|b| b.is_ascii_alphanumeric() || TCHAR_SYMBOLS.contains(&b))
}
333
334impl Default for ResponseWriter {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
// Unit tests for response serialization, chunked framing, trailers, and
// header-injection hardening.
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::stream::iter;
    use std::sync::Arc;
    use std::task::{Wake, Waker};

    // Waker whose `wake` does nothing. `collect_stream` panics on `Pending`,
    // so a wake-up is never needed.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    fn noop_waker() -> Waker {
        Waker::from(Arc::new(NoopWaker))
    }

    // Polls `stream` to completion synchronously and concatenates every
    // yielded item. Panics if the stream ever reports `Pending`.
    fn collect_stream<S: Stream<Item = Vec<u8>> + Unpin>(mut stream: S) -> Vec<u8> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut out = Vec::new();

        loop {
            match Pin::new(&mut stream).poll_next(&mut cx) {
                Poll::Ready(Some(chunk)) => out.extend_from_slice(&chunk),
                Poll::Ready(None) => break,
                Poll::Pending => panic!("unexpected pending stream"),
            }
        }

        out
    }

    #[test]
    fn write_full_sets_content_length() {
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(b"hello".to_vec()));
        let mut writer = ResponseWriter::new();
        let bytes = match writer.write(response) {
            ResponseWrite::Full(bytes) => bytes,
            ResponseWrite::Stream(_) => panic!("expected full response"),
        };
        let text = String::from_utf8_lossy(&bytes);
        assert!(text.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(text.contains("content-length: 5\r\n"));
        assert!(text.contains("\r\n\r\nhello"));
    }

    // Exact byte-level check of a streamed response: head, one framed chunk
    // per item (hex size line), then the zero-size last chunk.
    #[test]
    fn write_stream_uses_chunked_encoding() {
        let stream = iter(vec![b"hello".to_vec(), b"world".to_vec()]);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::stream(stream));
        let mut writer = ResponseWriter::new();
        let bytes = match writer.write(response) {
            ResponseWrite::Stream(stream) => collect_stream(stream),
            ResponseWrite::Full(_) => panic!("expected stream response"),
        };

        let expected = b"HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\ntransfer-encoding: chunked\r\n\r\n5\r\nhello\r\n5\r\nworld\r\n0\r\n\r\n";
        assert_eq!(bytes, expected);
    }

    #[test]
    fn trailers_empty() {
        let t = Trailers::new();
        assert!(t.is_empty());
        assert_eq!(t.trailer_header_value(), "");
    }

    #[test]
    fn trailers_encode() {
        let t = Trailers::new()
            .add("Content-MD5", "abc123")
            .add("Server-Timing", "total;dur=50");
        assert!(!t.is_empty());
        assert_eq!(t.trailer_header_value(), "Content-MD5, Server-Timing");
        let encoded = t.encode();
        let s = std::str::from_utf8(&encoded).unwrap();
        assert!(s.contains("Content-MD5: abc123\r\n"));
        assert!(s.contains("Server-Timing: total;dur=50\r\n"));
    }

    // Trailer lines must appear between the "0\r\n" last-chunk line and the
    // final blank line.
    #[test]
    fn chunked_encoder_with_trailers() {
        let stream = iter(vec![b"data".to_vec()]);
        let body = Box::pin(stream) as BodyStream;
        let head = b"HTTP/1.1 200 OK\r\n\r\n".to_vec();
        let trailers = Trailers::new().add("Checksum", "deadbeef");
        let encoder = ChunkedEncoder::new(head, body).with_trailers(trailers);
        let bytes = collect_stream(encoder);
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(s.contains("0\r\nChecksum: deadbeef\r\n\r\n"));
    }

    #[test]
    fn chunked_encoder_without_trailers_unchanged() {
        let stream = iter(vec![b"hi".to_vec()]);
        let body = Box::pin(stream) as BodyStream;
        let head = b"HTTP/1.1 200 OK\r\n\r\n".to_vec();
        let encoder = ChunkedEncoder::new(head, body);
        let bytes = collect_stream(encoder);
        assert!(bytes.ends_with(b"0\r\n\r\n"));
    }

    // Builds the encoder directly (no head, empty body) to test
    // `encode_final_chunk` on its own; trailer order follows insertion order.
    #[test]
    fn final_chunk_format_with_multiple_trailers() {
        let t = Trailers::new()
            .add("Digest", "sha-256=abc")
            .add("Signature", "sig123");
        let encoder = ChunkedEncoder {
            head: None,
            body: Box::pin(iter(Vec::<Vec<u8>>::new())),
            finished: false,
            trailers: Some(t),
        };
        let final_chunk = encoder.encode_final_chunk();
        let s = std::str::from_utf8(&final_chunk).unwrap();
        assert_eq!(s, "0\r\nDigest: sha-256=abc\r\nSignature: sig123\r\n\r\n");
    }

    // Header-injection hardening: invalid names are dropped and CR/LF are
    // stripped from values, so the injected text stays on the same line.
    #[test]
    fn write_full_drops_invalid_header_names_and_sanitizes_values() {
        let mut writer = ResponseWriter::new();
        let headers = vec![
            ("x-ok".to_string(), b"safe".to_vec()),
            ("bad\r\nname".to_string(), b"ignored".to_vec()),
            ("x-test".to_string(), b"hello\r\nx-injected: yes".to_vec()),
        ];

        let bytes = writer.write_full(StatusCode::OK, &headers, b"body");
        let text = String::from_utf8_lossy(&bytes);

        assert!(text.contains("x-ok: safe\r\n"));
        assert!(!text.contains("bad\r\nname:"));
        assert!(text.contains("x-test: hellox-injected: yes\r\n"));
        assert!(!text.contains("\r\nx-injected: yes\r\n"));
    }

    // Same hardening checks for the streaming head path.
    #[test]
    fn write_stream_head_drops_invalid_header_names_and_sanitizes_values() {
        let mut writer = ResponseWriter::new();
        let headers = vec![
            ("content-type".to_string(), b"text/plain".to_vec()),
            ("bad\nname".to_string(), b"ignored".to_vec()),
            ("x-test".to_string(), b"hello\r\nx-injected: yes".to_vec()),
        ];

        let bytes = writer.write_stream_head(StatusCode::OK, &headers);
        let text = String::from_utf8_lossy(&bytes);

        assert!(text.contains("content-type: text/plain\r\n"));
        assert!(!text.contains("bad\nname:"));
        assert!(text.contains("x-test: hellox-injected: yes\r\n"));
        assert!(!text.contains("\r\nx-injected: yes\r\n"));
    }

    // Same hardening checks for trailer fields.
    #[test]
    fn trailers_encode_drops_invalid_names_and_sanitizes_values() {
        let encoded = Trailers::new()
            .add("Checksum", "abc123")
            .add("Bad\r\nName", "ignored")
            .add("Signature", "sig\r\nInjected: yes")
            .encode();
        let text = std::str::from_utf8(&encoded).unwrap();

        assert!(text.contains("Checksum: abc123\r\n"));
        assert!(!text.contains("Bad\r\nName"));
        assert!(text.contains("Signature: sigInjected: yes\r\n"));
        assert!(!text.contains("\r\nInjected: yes\r\n"));
    }
}