use std::io::{self, Read, Write};
use crate::ascii::HttpChar;
use crate::error::{Error, ResponseErrorKind};
use crate::headers::{Header, HeaderName, StatusCode};
use crate::validate::HttpValidate;
#[allow(clippy::cast_possible_truncation)] const DIGITS_LUT: [u8; 200] = {
let mut t = [0u8; 200];
let mut i = 0u16;
while i < 100 {
t[i as usize * 2] = HttpChar::Zero as u8 + (i / 10) as u8;
t[i as usize * 2 + 1] = HttpChar::Zero as u8 + (i % 10) as u8;
i += 1;
}
t
};
/// Renders `n` as ASCII decimal digits without allocating.
///
/// Digits are produced right-to-left, two at a time via `DIGITS_LUT`, into a
/// fixed 20-byte buffer (enough for `usize::MAX` on 64-bit targets). The
/// returned value exposes only the populated tail of that buffer.
#[inline]
// `value < 10` in the single-digit branch, so `value as u8` cannot truncate.
#[allow(clippy::cast_possible_truncation)]
fn format_usize(n: usize) -> FormattedUsize {
    let mut buf = [0u8; 20];
    let mut start = buf.len();
    let mut value = n;

    // Peel off two digits per iteration while three or more remain.
    while value >= 100 {
        let lut_idx = (value % 100) * 2;
        value /= 100;
        start -= 2;
        buf[start..start + 2].copy_from_slice(&DIGITS_LUT[lut_idx..lut_idx + 2]);
    }

    // One or two leading digits remain (`value < 100`).
    if value < 10 {
        start -= 1;
        buf[start] = HttpChar::Zero + value as u8;
    } else {
        let lut_idx = value * 2;
        start -= 2;
        buf[start..start + 2].copy_from_slice(&DIGITS_LUT[lut_idx..lut_idx + 2]);
    }

    FormattedUsize { buf, start }
}
/// Decimal digits produced by `format_usize`, right-aligned in a fixed buffer.
///
/// Only `buf[start..]` holds meaningful bytes; the prefix is unused padding.
struct FormattedUsize {
    /// Backing storage; 20 bytes fit every `usize` value on 64-bit targets.
    buf: [u8; 20],
    /// Index of the first digit within `buf`.
    start: usize,
}

impl FormattedUsize {
    /// Borrows the rendered digits, i.e. the tail of `buf` beginning at `start`.
    #[inline]
    fn as_bytes(&self) -> &[u8] {
        let (_padding, digits) = self.buf.split_at(self.start);
        digits
    }
}
/// Header capacity used when `Response` is named without an explicit `MAX_HDRS`.
pub const DEFAULT_MAX_RESPONSE_HEADERS: usize = 16;
/// An HTTP/1.1 response builder.
///
/// Headers are kept in a fixed-size array of `MAX_HDRS` slots, and values are
/// borrowed for `'resp` rather than copied. The response is serialized to any
/// `io::Write` by the methods on this type.
#[must_use]
pub struct Response<'resp, const MAX_HDRS: usize = DEFAULT_MAX_RESPONSE_HEADERS> {
// Status written first, via `StatusCode::status_line`.
status: StatusCode,
// Header slots: only the first `header_count` are populated; the rest stay
// `Header::EMPTY`.
headers: [Header<'resp>; MAX_HDRS],
// Number of populated slots at the front of `headers`.
header_count: usize,
// Index into `headers` of the caller-supplied `Content-Length` header, if any.
// The writers use it to validate the body length, to skip adding their own
// `Content-Length`, and (in streaming) to bound the copied body.
content_length_idx: Option<usize>,
}
impl<'resp, const MAX_HDRS: usize> Response<'resp, MAX_HDRS> {
pub const fn new(status: StatusCode) -> Self {
Self {
status,
headers: [Header::EMPTY; MAX_HDRS],
header_count: 0,
content_length_idx: None,
}
}
pub fn header(
&mut self,
name: impl HeaderName<'resp>,
value: &'resp [u8],
) -> Result<&mut Self, Error> {
if self.header_count >= MAX_HDRS {
return Err(ResponseErrorKind::HeaderCapacityExceeded.into());
}
let name_bytes = name.as_header_bytes();
if !name_bytes.is_valid_token() {
return Err(ResponseErrorKind::InvalidHeaderName.into());
}
if !value.is_valid_header_value() {
return Err(ResponseErrorKind::InvalidHeaderValue.into());
}
let is_content_length = name.known_index().map_or_else(
|| name_bytes.eq_ignore_ascii_case(b"Content-Length"),
|idx| idx == crate::headers::ResponseHeader::ContentLength as usize,
);
if is_content_length {
if self.content_length_idx.is_some() {
return Err(ResponseErrorKind::DuplicateContentLength.into());
}
self.content_length_idx = Some(self.header_count);
}
self.headers[self.header_count] = Header::new(name_bytes, value);
self.header_count += 1;
Ok(self)
}
fn write_status_and_headers(&self, writer: &mut impl Write) -> io::Result<()> {
writer.write_all(self.status.status_line())?;
for i in 0..self.header_count {
let hdr = &self.headers[i];
writer.write_all(hdr.name())?;
writer.write_all(b": ")?;
writer.write_all(hdr.value())?;
writer.write_all(b"\r\n")?;
}
Ok(())
}
pub fn write_headers_to(&self, writer: &mut impl Write) -> io::Result<()> {
self.write_status_and_headers(writer)?;
writer.write_all(b"\r\n")
}
pub fn write(&self, writer: &mut impl Write, body: &[u8]) -> io::Result<()> {
if let Some(cl_idx) = self.content_length_idx {
let expected = format_usize(body.len());
if self.headers[cl_idx].value() != expected.as_bytes() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Content-Length header does not match body length",
));
}
}
self.write_status_and_headers(writer)?;
if self.content_length_idx.is_none() {
writer.write_all(b"Content-Length: ")?;
writer.write_all(format_usize(body.len()).as_bytes())?;
writer.write_all(b"\r\n")?;
}
writer.write_all(b"\r\n")?;
if !body.is_empty() {
writer.write_all(body)?;
}
Ok(())
}
pub fn write_streaming(&self, writer: &mut impl Write, body: &mut impl Read) -> io::Result<()> {
let content_length = self.content_length_idx.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"streaming response requires Content-Length header; \
use write_headers_to and encode chunks manually for chunked responses",
)
})?;
let limit = crate::ascii::parse_content_length(self.headers[content_length].value())
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Content-Length is not a valid integer",
)
})?;
self.write_headers_to(writer)?;
let copied = io::copy(&mut body.take(limit), writer)?;
if copied < limit {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"body reader produced fewer bytes than Content-Length",
));
}
Ok(())
}
}
// Unit tests for `Response`: status-line and header layout, header name/value
// validation, header capacity, automatic and manual `Content-Length` in
// `write`, and the framing rules of `write_streaming`.
#[cfg(test)]
mod tests {
use super::*;
use crate::error::ResponseErrorKind;
use crate::headers::ResponseHeader;
// Serializes `response` with `body` via `Response::write` and returns the
// output as a `String`; panics if writing fails or the output is not UTF-8.
fn render(response: &Response, body: &[u8]) -> String {
let mut out = Vec::new();
response.write(&mut out, body).unwrap();
String::from_utf8(out).unwrap()
}
// `render` with an empty body.
fn render_empty(response: &Response) -> String {
render(response, b"")
}
// --- Status line, header and body layout produced by `write` ---
#[test]
fn basic_200_response() {
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Type", b"text/plain").unwrap();
let rendered = render(&resp, b"hello");
assert!(rendered.starts_with("HTTP/1.1 200 OK\r\n"));
assert!(rendered.contains("Content-Type: text/plain\r\n"));
assert!(rendered.contains("\r\n\r\nhello"));
}
#[test]
fn basic_404_response() {
let resp = render(&Response::new(StatusCode::NotFound), b"not found");
assert!(resp.starts_with("HTTP/1.1 404 Not Found\r\n"));
assert!(resp.ends_with("not found"));
}
#[test]
fn response_with_no_body() {
let resp = render_empty(&Response::new(StatusCode::Ok));
assert!(resp.ends_with("\r\n\r\n"));
}
// Non-UTF-8 bodies are written byte-for-byte.
#[test]
fn response_with_binary_body() {
let body = &[0u8, 1, 2, 255, 254, 253];
let mut out = Vec::new();
Response::<16>::new(StatusCode::Ok)
.write(&mut out, body)
.unwrap();
assert!(out.ends_with(body));
}
#[test]
fn multiple_headers() {
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Type", b"text/html").unwrap();
resp.header("Cache-Control", b"no-cache").unwrap();
resp.header("X-Custom", b"value").unwrap();
let rendered = render(&resp, b"");
assert!(rendered.contains("Content-Type: text/html\r\n"));
assert!(rendered.contains("Cache-Control: no-cache\r\n"));
assert!(rendered.contains("X-Custom: value\r\n"));
}
// Typed `ResponseHeader` names render with their canonical spelling.
#[test]
fn typed_response_headers() {
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header(ResponseHeader::ContentType, b"text/html")
.unwrap();
resp.header(ResponseHeader::CacheControl, b"no-store")
.unwrap();
resp.header(ResponseHeader::Connection, b"close").unwrap();
let rendered = render(&resp, b"");
assert!(rendered.contains("Content-Type: text/html\r\n"));
assert!(rendered.contains("Cache-Control: no-store\r\n"));
assert!(rendered.contains("Connection: close\r\n"));
}
// Every listed `StatusCode` renders the expected "HTTP/1.1 <code> <reason>" line.
#[test]
fn all_status_codes() {
for (status, expected) in [
(StatusCode::Continue, "100 Continue"),
(StatusCode::SwitchingProtocols, "101 Switching Protocols"),
(StatusCode::Ok, "200 OK"),
(StatusCode::Created, "201 Created"),
(StatusCode::Accepted, "202 Accepted"),
(StatusCode::NoContent, "204 No Content"),
(StatusCode::MovedPermanently, "301 Moved Permanently"),
(StatusCode::Found, "302 Found"),
(StatusCode::SeeOther, "303 See Other"),
(StatusCode::NotModified, "304 Not Modified"),
(StatusCode::TemporaryRedirect, "307 Temporary Redirect"),
(StatusCode::PermanentRedirect, "308 Permanent Redirect"),
(StatusCode::BadRequest, "400 Bad Request"),
(StatusCode::Unauthorized, "401 Unauthorized"),
(StatusCode::Forbidden, "403 Forbidden"),
(StatusCode::NotFound, "404 Not Found"),
(StatusCode::MethodNotAllowed, "405 Method Not Allowed"),
(StatusCode::Conflict, "409 Conflict"),
(StatusCode::Gone, "410 Gone"),
(StatusCode::LengthRequired, "411 Length Required"),
(StatusCode::PayloadTooLarge, "413 Payload Too Large"),
(StatusCode::UriTooLong, "414 URI Too Long"),
(
StatusCode::UnsupportedMediaType,
"415 Unsupported Media Type",
),
(StatusCode::UnprocessableEntity, "422 Unprocessable Entity"),
(StatusCode::TooManyRequests, "429 Too Many Requests"),
(StatusCode::InternalServerError, "500 Internal Server Error"),
(StatusCode::NotImplemented, "501 Not Implemented"),
(StatusCode::BadGateway, "502 Bad Gateway"),
(StatusCode::ServiceUnavailable, "503 Service Unavailable"),
(StatusCode::GatewayTimeout, "504 Gateway Timeout"),
] {
let resp = render_empty(&Response::new(status));
assert!(
resp.starts_with(&format!("HTTP/1.1 {expected}\r\n")),
"expected {expected}, got: {resp}"
);
}
}
// --- `write_headers_to`: header section only; the caller writes the body ---
#[test]
fn write_headers_only() {
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Type", b"text/plain").unwrap();
resp.write_headers_to(&mut out).unwrap();
let headers = String::from_utf8(out).unwrap();
assert!(headers.starts_with("HTTP/1.1 200 OK\r\n"));
assert!(headers.contains("Content-Type: text/plain\r\n"));
assert!(headers.ends_with("\r\n\r\n"));
}
#[test]
fn write_headers_then_manual_body() {
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Type", b"text/plain").unwrap();
resp.write_headers_to(&mut out).unwrap();
out.extend_from_slice(b"manual body");
let resp = String::from_utf8(out).unwrap();
assert!(resp.ends_with("manual body"));
}
// --- `write_streaming` with a caller-supplied Content-Length ---
#[test]
fn write_streaming_from_reader() {
let body = b"streamed content";
let mut reader = &body[..];
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Type", b"application/octet-stream")
.unwrap();
resp.header("Content-Length", b"16").unwrap();
resp.write_streaming(&mut out, &mut reader).unwrap();
let rendered = String::from_utf8(out).unwrap();
assert!(rendered.starts_with("HTTP/1.1 200 OK\r\n"));
assert!(rendered.ends_with("streamed content"));
}
#[test]
fn write_streaming_empty_reader() {
let mut reader = std::io::empty();
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"0").unwrap();
resp.write_streaming(&mut out, &mut reader).unwrap();
let rendered = String::from_utf8(out).unwrap();
assert!(rendered.ends_with("\r\n\r\n"));
}
// --- Header capacity (the `MAX_HDRS` const parameter) ---
#[test]
fn custom_small_response_capacity() {
let mut out = Vec::new();
let mut resp = Response::<2>::new(StatusCode::Ok);
resp.header("A", b"1").unwrap();
resp.header("B", b"2").unwrap();
resp.write(&mut out, b"").unwrap();
let s = String::from_utf8(out).unwrap();
assert!(s.contains("A: 1\r\n"));
assert!(s.contains("B: 2\r\n"));
}
#[test]
fn exceeding_header_capacity_returns_error() {
let mut resp = Response::<2>::new(StatusCode::Ok);
resp.header("A", b"1").unwrap();
resp.header("B", b"2").unwrap();
let err = resp.header("C", b"3").err().expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::HeaderCapacityExceeded)
));
}
// --- Header name/value validation (guards against CRLF header injection) ---
#[test]
fn rejects_value_with_crlf_injection() {
let mut resp = Response::<16>::new(StatusCode::Ok);
let err = resp
.header("Content-Type", b"text/html\r\nX-Injected: evil")
.err()
.expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::InvalidHeaderValue)
));
}
#[test]
fn rejects_value_with_nul() {
let mut resp = Response::<16>::new(StatusCode::Ok);
let err = resp
.header("X-Bad", b"val\x00ue")
.err()
.expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::InvalidHeaderValue)
));
}
#[test]
fn rejects_value_with_bare_cr() {
let mut resp = Response::<16>::new(StatusCode::Ok);
let err = resp
.header("X-Bad", b"val\rue")
.err()
.expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::InvalidHeaderValue)
));
}
#[test]
fn rejects_value_with_bare_lf() {
let mut resp = Response::<16>::new(StatusCode::Ok);
let err = resp
.header("X-Bad", b"val\nue")
.err()
.expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::InvalidHeaderValue)
));
}
#[test]
fn rejects_invalid_header_name() {
let mut resp = Response::<16>::new(StatusCode::Ok);
let err = resp
.header("Bad Name", b"value")
.err()
.expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::InvalidHeaderName)
));
}
#[test]
fn rejects_header_name_with_crlf() {
let mut resp = Response::<16>::new(StatusCode::Ok);
let err = resp
.header("Bad\r\nName", b"value")
.err()
.expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::InvalidHeaderName)
));
}
#[test]
fn rejects_empty_header_name() {
let mut resp = Response::<16>::new(StatusCode::Ok);
let err = resp.header("", b"value").err().expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::InvalidHeaderName)
));
}
// Bytes >= 0x80 are accepted in header values.
#[test]
fn allows_valid_header_value_with_high_bytes() {
let mut resp = Response::<16>::new(StatusCode::Ok);
assert!(resp.header("X-Custom", &[0x80, 0xFF, 0xFE]).is_ok());
}
// A rejected header must not poison the builder: later valid headers still work.
#[test]
fn builder_recovers_from_error() {
let mut resp = Response::<16>::new(StatusCode::Ok);
let err = resp
.header("Bad Name", b"value")
.err()
.expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::InvalidHeaderName)
));
resp.header("Good-Name", b"value").unwrap();
let rendered = render(&resp, b"");
assert!(rendered.contains("Good-Name: value\r\n"));
}
// --- Automatic and manual Content-Length handling in `write` ---
#[test]
fn auto_content_length_added() {
let resp = render(&Response::new(StatusCode::Ok), b"hello");
assert!(resp.contains("Content-Length: 5\r\n"));
}
#[test]
fn auto_content_length_zero() {
let resp = render_empty(&Response::new(StatusCode::Ok));
assert!(resp.contains("Content-Length: 0\r\n"));
}
#[test]
fn manual_content_length_not_duplicated() {
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"5").unwrap();
let rendered = render(&resp, b"hello");
assert_eq!(rendered.matches("Content-Length").count(), 1);
}
#[test]
fn rejects_content_length_mismatch() {
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"999").unwrap();
let err = resp.write(&mut out, b"hello").unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
}
#[test]
fn accepts_correct_manual_content_length() {
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"5").unwrap();
resp.write(&mut out, b"hello").unwrap();
let rendered = String::from_utf8(out).unwrap();
assert!(rendered.contains("Content-Length: 5\r\n"));
assert!(rendered.ends_with("hello"));
}
// --- `write_streaming` framing rules ---
#[test]
fn streaming_rejects_missing_framing() {
let mut reader = &b"body"[..];
let mut out = Vec::new();
let err = Response::<16>::new(StatusCode::Ok)
.write_streaming(&mut out, &mut reader)
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
}
#[test]
fn streaming_accepts_content_length() {
let mut reader = &b"body"[..];
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"4").unwrap();
resp.write_streaming(&mut out, &mut reader).unwrap();
let rendered = String::from_utf8(out).unwrap();
assert!(rendered.ends_with("body"));
}
// A Transfer-Encoding header alone is not accepted as framing by `write_streaming`.
#[test]
fn streaming_rejects_transfer_encoding_without_content_length() {
let mut reader = &b"4\r\nbody\r\n0\r\n\r\n"[..];
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Transfer-Encoding", b"chunked").unwrap();
let err = resp.write_streaming(&mut out, &mut reader).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
}
// Reader bytes past Content-Length must not leak into the output.
#[test]
fn streaming_limits_body_to_content_length() {
let mut reader = &b"bodySECRET"[..];
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"4").unwrap();
resp.write_streaming(&mut out, &mut reader).unwrap();
let rendered = String::from_utf8(out).unwrap();
assert!(rendered.ends_with("body"));
assert!(!rendered.contains("SECRET"));
}
#[test]
fn streaming_rejects_short_body() {
let mut reader = &b"hi"[..];
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"10").unwrap();
let err = resp.write_streaming(&mut out, &mut reader).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
}
#[test]
fn streaming_rejects_invalid_content_length() {
let mut reader = &b"body"[..];
let mut out = Vec::new();
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"not-a-number").unwrap();
let err = resp.write_streaming(&mut out, &mut reader).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
}
// --- Duplicate Content-Length detection in `header` ---
#[test]
fn rejects_duplicate_content_length_header() {
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"5").unwrap();
let err = resp
.header("Content-Length", b"10")
.err()
.expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::DuplicateContentLength)
));
}
#[test]
fn rejects_duplicate_content_length_case_insensitive() {
let mut resp = Response::<16>::new(StatusCode::Ok);
resp.header("Content-Length", b"5").unwrap();
let err = resp
.header("content-length", b"5")
.err()
.expect("expected error");
assert!(matches!(
err,
Error::Response(ResponseErrorKind::DuplicateContentLength)
));
}
}