use crate::compression::{CompressionError, CompressionStatus, Compressor, NoCompression};
use crate::decoder::HttpHead;
use crate::error::EncodeError;
use crate::host::Host;
use crate::request::Request;
use crate::request_target::RequestTargetForm;
use crate::response::Response;
use crate::validate::{
is_valid_field_value, is_valid_header_name, is_valid_method, is_valid_reason_phrase,
is_valid_request_target, is_valid_status_code, trim_ows,
};
use alloc::string::{String, ToString};
use alloc::vec::Vec;
/// Returns `true` when `version` is non-empty and made solely of visible ASCII
/// characters (`!` through `~`): no spaces, control bytes, or non-ASCII bytes.
///
/// This is a loose token check; it does not enforce the `HTTP/x.y` grammar.
fn is_valid_version_for_encode(version: &str) -> bool {
    if version.is_empty() {
        return false;
    }
    version.bytes().all(|byte| byte.is_ascii_graphic())
}
/// Appends the lowercase hexadecimal representation of `n` (no prefix, no padding)
/// to `buf`. Zero is written as a single `0`.
fn write_hex_usize(buf: &mut Vec<u8>, n: usize) {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    // 16 nibbles are enough for a 64-bit usize.
    let mut scratch = [0u8; 16];
    let mut start = scratch.len();
    let mut rest = n;
    // Emit at least one digit so that `n == 0` produces "0".
    loop {
        start -= 1;
        scratch[start] = DIGITS[rest & 0xF];
        rest >>= 4;
        if rest == 0 {
            break;
        }
    }
    buf.extend_from_slice(&scratch[start..]);
}
/// Appends the base-10 representation of `n` (no sign, no padding) to `buf`.
/// Zero is written as a single `0`.
fn write_usize_decimal(buf: &mut Vec<u8>, n: usize) {
    // 20 digits hold u64::MAX (18446744073709551615).
    let mut scratch = [0u8; 20];
    let mut start = scratch.len();
    let mut rest = n;
    // Emit at least one digit so that `n == 0` produces "0".
    loop {
        start -= 1;
        scratch[start] = b'0' + (rest % 10) as u8;
        rest /= 10;
        if rest == 0 {
            break;
        }
    }
    buf.extend_from_slice(&scratch[start..]);
}
/// Largest capacity (64 MiB) reserved up front for an encoded message. Larger estimates
/// make `allocate_encode_buffer` start from an empty `Vec` instead.
const ENCODE_CAPACITY_LIMIT: usize = 64 << 20;
/// Worst-case byte length of a synthesized `Content-Length: <n>\r\n` line: the header
/// prefix, up to 20 decimal digits (the width of `u64::MAX`), and the trailing CRLF.
const AUTO_CONTENT_LENGTH_CAPACITY: usize = "Content-Length: ".len() + 20 + "\r\n".len();
/// Decides whether a `Content-Length` header must be synthesized for `request`.
///
/// This is the case when a body is attached (even an empty one) and the caller has
/// declared neither `Content-Length` nor `Transfer-Encoding`.
fn should_auto_emit_content_length_for_request(request: &Request) -> bool {
    let framing_declared =
        request.has_header("Content-Length") || request.has_header("Transfer-Encoding");
    !framing_declared && request.body_bytes().is_some()
}
/// Returns `false` for status codes whose responses never carry a message body
/// (1xx, 204 No Content, 304 Not Modified) and `true` for every other code.
fn response_status_has_body(status_code: u16) -> bool {
    !matches!(status_code, 100..=199 | 204 | 304)
}
/// Decides whether `encode_response` must synthesize a `Content-Length` header.
///
/// The status must allow a body, and the caller must have declared neither
/// `Content-Length` nor `Transfer-Encoding`. Body bytes must also be attached. If the
/// body is marked as omitted, the attached bytes must be non-empty, so the header can
/// still advertise their length.
fn should_auto_emit_content_length_for_response(response: &Response) -> bool {
    if !response_status_has_body(response.status_code())
        || response.has_header("Content-Length")
        || response.has_header("Transfer-Encoding")
    {
        return false;
    }
    match response.body_bytes() {
        None => false,
        Some(body) => !(response.is_body_omitted() && body.is_empty()),
    }
}
/// Computes an upper bound on the number of bytes `encode_request` will produce for
/// `request`, or `None` if the sum overflows `usize`.
fn estimate_request_capacity(request: &Request) -> Option<usize> {
    // Request line: METHOD SP TARGET SP VERSION CRLF (two spaces + CRLF = 4 bytes).
    let request_line = request
        .method()
        .len()
        .checked_add(request.uri().len())?
        .checked_add(request.version().len())?
        .checked_add(4)?;
    // Each header line: NAME ": " VALUE CRLF (4 bytes of punctuation).
    let with_headers = HttpHead::headers(request)
        .iter()
        .try_fold(request_line, |acc, (name, value)| {
            acc.checked_add(name.len())?
                .checked_add(value.len())?
                .checked_add(4)
        })?;
    let synthesized = if should_auto_emit_content_length_for_request(request) {
        AUTO_CONTENT_LENGTH_CAPACITY
    } else {
        0
    };
    let body = request.body_bytes().map_or(0, <[u8]>::len);
    // The trailing 2 bytes are the blank line that ends the header section.
    with_headers
        .checked_add(synthesized)?
        .checked_add(2)?
        .checked_add(body)
}
/// Computes an upper bound on the number of bytes `encode_response` will produce for
/// `response`, or `None` if the sum overflows `usize`.
///
/// Body bytes are counted only when they will actually be written: the status allows
/// a body and the body is not marked as omitted.
fn estimate_response_capacity(response: &Response) -> Option<usize> {
    // Status line: VERSION SP CODE SP REASON CRLF -> 3 code digits + two spaces + CRLF.
    let status_line = HttpHead::version(response)
        .len()
        .checked_add(3)?
        .checked_add(response.reason_phrase().len())?
        .checked_add(4)?;
    // Each header line: NAME ": " VALUE CRLF (4 bytes of punctuation).
    let with_headers = HttpHead::headers(response)
        .iter()
        .try_fold(status_line, |acc, (name, value)| {
            acc.checked_add(name.len())?
                .checked_add(value.len())?
                .checked_add(4)
        })?;
    let synthesized = if should_auto_emit_content_length_for_response(response) {
        AUTO_CONTENT_LENGTH_CAPACITY
    } else {
        0
    };
    let body_will_be_encoded =
        response_status_has_body(response.status_code()) && !response.is_body_omitted();
    let body = match response.body_bytes() {
        Some(bytes) if body_will_be_encoded => bytes.len(),
        _ => 0,
    };
    // The trailing 2 bytes are the blank line that ends the header section.
    with_headers
        .checked_add(synthesized)?
        .checked_add(2)?
        .checked_add(body)
}
/// Creates the output buffer for an encoder.
///
/// The estimated capacity is reserved only when it is known and does not exceed
/// `ENCODE_CAPACITY_LIMIT`. Otherwise the buffer starts empty and grows on demand, so an
/// overflowed or huge estimate never triggers a giant up-front allocation.
fn allocate_encode_buffer(estimated: Option<usize>) -> Vec<u8> {
    estimated
        .filter(|&capacity| capacity <= ENCODE_CAPACITY_LIMIT)
        .map_or_else(Vec::new, Vec::with_capacity)
}
/// Validates the request line and header syntax of `request` before encoding.
///
/// Checks run in this order: method, request target (including a ban on bytes above
/// 0x7E), request-target form for the method, version token, then every header.
///
/// # Errors
/// Returns the first failure: `InvalidMethod`, `InvalidRequestTarget`,
/// `InvalidRequestTargetForm`, `InvalidVersion`, `InvalidHeaderName`, or
/// `InvalidHeaderValue`.
fn validate_request_fields(request: &Request) -> Result<(), EncodeError> {
    let method = request.method();
    let uri = request.uri();
    let version = request.version();

    if !is_valid_method(method) {
        return Err(EncodeError::InvalidMethod {
            method: method.to_string(),
        });
    }
    // Bytes above 0x7E (DEL and anything non-ASCII) are refused here in addition to
    // whatever `is_valid_request_target` enforces.
    if !is_valid_request_target(uri) || uri.bytes().any(|b| b > 0x7E) {
        return Err(EncodeError::InvalidRequestTarget {
            uri: uri.to_string(),
        });
    }
    validate_request_target_form(method, uri)?;
    if !is_valid_version_for_encode(version) {
        return Err(EncodeError::InvalidVersion {
            version: version.to_string(),
        });
    }
    validate_headers(HttpHead::headers(request))
}
/// Classifies a request target into one of the request-target forms.
///
/// Checks, in order:
/// 1. exactly `*` → asterisk-form;
/// 2. a leading `/` → origin-form;
/// 3. `://` anywhere → absolute-form;
/// 4. `host:port` (per `looks_like_authority_form`) → authority-form;
/// 5. a syntactically valid `scheme:` prefix (per `detect_scheme`) → absolute-form
///    without an authority component.
///
/// The origin-form test comes before the `://` test on purpose. An origin-form target
/// whose query embeds a URL, such as `/login?next=http://a.example/`, must not be
/// misread as absolute-form. Otherwise it would skip the origin-form checks and be
/// treated as carrying an authority.
///
/// # Errors
/// `EncodeError::InvalidRequestTarget` when none of the forms match.
fn detect_request_target_form(uri: &str) -> Result<RequestTargetForm, EncodeError> {
    if uri == "*" {
        Ok(RequestTargetForm::Asterisk)
    } else if uri.starts_with('/') {
        Ok(RequestTargetForm::Origin)
    } else if uri.contains("://") {
        Ok(RequestTargetForm::Absolute)
    } else if looks_like_authority_form(uri) {
        Ok(RequestTargetForm::Authority)
    } else if detect_scheme(uri).is_some() {
        Ok(RequestTargetForm::Absolute)
    } else {
        Err(EncodeError::InvalidRequestTarget {
            uri: uri.to_string(),
        })
    }
}
/// Heuristically recognizes an authority-form target (`host:port`, as used by CONNECT).
///
/// The text after the last `:` must be a non-empty run of ASCII digits that fits in a
/// `u16`. The text before it must be non-empty. Any `@` (userinfo) disqualifies the
/// target.
fn looks_like_authority_form(uri: &str) -> bool {
    if uri.contains('@') {
        return false;
    }
    let Some((host, port)) = uri.rsplit_once(':') else {
        return false;
    };
    // The digit check rejects a leading `+`, which `u16::from_str` would accept.
    !host.is_empty()
        && !port.is_empty()
        && port.bytes().all(|b| b.is_ascii_digit())
        && port.parse::<u16>().is_ok()
}
/// Checks that the host part of an authority-form target (everything before the last
/// `:`) parses as a [`Host`]. A target with no `:` is accepted unchanged.
///
/// # Errors
/// `EncodeError::InvalidRequestTarget` when the host part fails to parse.
fn validate_encoder_authority_form(uri: &str) -> Result<(), EncodeError> {
    match uri.rsplit_once(':') {
        Some((host, _port)) => Host::parse(host)
            .map(|_| ())
            .map_err(|_| EncodeError::InvalidRequestTarget {
                uri: uri.to_string(),
            }),
        None => Ok(()),
    }
}
/// Detects a URI scheme prefix and returns the byte index of its terminating `:`.
///
/// The scheme must start with an ASCII letter and continue with letters, digits,
/// `+`, `-`, or `.`. At least one byte must follow the `:`. Otherwise `None`.
fn detect_scheme(target: &str) -> Option<usize> {
    let bytes = target.as_bytes();
    let first = *bytes.first()?;
    if !first.is_ascii_alphabetic() {
        return None;
    }
    // The first byte is a letter, so the colon index is always >= 1.
    let colon = bytes.iter().position(|&b| b == b':')?;
    let scheme_tail_ok = bytes[1..colon]
        .iter()
        .all(|&b| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'-' | b'.'));
    (scheme_tail_ok && colon + 1 < bytes.len()).then_some(colon)
}
/// Checks that the form of `uri` is permitted for `method`.
///
/// - `CONNECT` requires authority-form, and its host must parse.
/// - Asterisk-form (`*`) is only allowed with `OPTIONS`.
/// - Authority-form is only allowed with `CONNECT`.
/// - Origin-form may not contain `[` or `]`.
/// - Absolute-form `http`/`https` targets must use `://`.
///
/// # Errors
/// `InvalidRequestTarget` or `InvalidRequestTargetForm`, as well as any error from
/// form detection or authority validation.
fn validate_request_target_form(method: &str, uri: &str) -> Result<(), EncodeError> {
    let form_error = || EncodeError::InvalidRequestTargetForm {
        method: method.to_string(),
        uri: uri.to_string(),
    };
    let form = detect_request_target_form(uri)?;

    if method == "CONNECT" {
        return match form {
            RequestTargetForm::Authority => validate_encoder_authority_form(uri),
            _ => Err(form_error()),
        };
    }

    match form {
        RequestTargetForm::Asterisk if method == "OPTIONS" => Ok(()),
        RequestTargetForm::Asterisk | RequestTargetForm::Authority => Err(form_error()),
        RequestTargetForm::Origin if uri.bytes().any(|b| matches!(b, b'[' | b']')) => {
            Err(EncodeError::InvalidRequestTarget {
                uri: uri.to_string(),
            })
        }
        RequestTargetForm::Origin => Ok(()),
        RequestTargetForm::Absolute => reject_http_without_authority_prefix(uri),
    }
}
/// Validates the status line and header syntax of `response` before encoding.
///
/// An empty reason phrase is allowed; a non-empty one must pass
/// `is_valid_reason_phrase`.
///
/// # Errors
/// Returns the first failure: `InvalidVersion`, `InvalidStatusCode`,
/// `InvalidReasonPhrase`, `InvalidHeaderName`, or `InvalidHeaderValue`.
fn validate_response_fields(response: &Response) -> Result<(), EncodeError> {
    let version = HttpHead::version(response);
    if !is_valid_version_for_encode(version) {
        return Err(EncodeError::InvalidVersion {
            version: version.to_string(),
        });
    }

    let code = response.status_code();
    if !is_valid_status_code(code) {
        return Err(EncodeError::InvalidStatusCode { code });
    }

    let phrase = response.reason_phrase();
    if !(phrase.is_empty() || is_valid_reason_phrase(phrase)) {
        return Err(EncodeError::InvalidReasonPhrase {
            phrase: phrase.to_string(),
        });
    }

    validate_headers(HttpHead::headers(response))
}
/// Checks every `(name, value)` pair for a valid field name and field value.
///
/// # Errors
/// Stops at the first offending header and returns `InvalidHeaderName`, or
/// `InvalidHeaderValue` if the name is fine but the value is not.
fn validate_headers(headers: &[(String, String)]) -> Result<(), EncodeError> {
    headers.iter().try_for_each(|(name, value)| {
        if !is_valid_header_name(name) {
            return Err(EncodeError::InvalidHeaderName { name: name.clone() });
        }
        if !is_valid_field_value(value) {
            return Err(EncodeError::InvalidHeaderValue {
                name: name.clone(),
                value: value.clone(),
            });
        }
        Ok(())
    })
}
/// Validates the `Host` header of an HTTP/1.1 request against its request target.
///
/// Requests whose version is not exactly `HTTP/1.1` are accepted without checks. For
/// HTTP/1.1:
/// - exactly one `Host` header must be present;
/// - a non-empty value must parse as a [`Host`];
/// - if the target is `scheme://authority...`, the value must equal that authority
///   (ASCII case-insensitive, userinfo stripped);
/// - for `CONNECT`, the value must be non-empty. Its host must match the target host,
///   and its port, if present, must match the target port;
/// - an absolute-form target without `://` requires an empty `Host`.
///
/// # Errors
/// `MissingHostHeader`, `DuplicateHostHeader`, `InvalidHostHeader`,
/// `HostAuthorityMismatch`, or `NonEmptyHostWithoutAuthority`.
fn validate_host_header(request: &Request) -> Result<(), EncodeError> {
    if request.version() != "HTTP/1.1" {
        return Ok(());
    }
    let uri = request.uri();

    let host_headers: Vec<&str> = HttpHead::headers(request)
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("Host"))
        .map(|(_, value)| value.as_str())
        .collect();
    let host_value = match host_headers.as_slice() {
        [] => return Err(EncodeError::MissingHostHeader),
        [single] => *single,
        _ => return Err(EncodeError::DuplicateHostHeader),
    };
    if !host_value.is_empty() && Host::parse(host_value).is_err() {
        return Err(EncodeError::InvalidHostHeader {
            value: host_value.to_string(),
        });
    }

    // An origin-form target (`/path?query`) never carries an authority, even when its
    // query embeds a URL such as `/login?next=http://other.example/`. Without the
    // leading-slash guard, the embedded host would be compared against `Host`, and a
    // perfectly valid request would be rejected as a mismatch.
    let target_may_carry_authority = !uri.starts_with('/') && uri.contains("://");
    if target_may_carry_authority
        && let Some(authority) = extract_authority_from_uri(uri)
        && !authority.is_empty()
        && !authority.eq_ignore_ascii_case(host_value)
    {
        return Err(EncodeError::HostAuthorityMismatch {
            host: host_value.to_string(),
            authority,
        });
    }

    if request.method() == "CONNECT"
        && let Some((target_host, target_port_str)) = uri.rsplit_once(':')
    {
        if host_value.is_empty() {
            return Err(EncodeError::HostAuthorityMismatch {
                host: host_value.to_string(),
                authority: uri.to_string(),
            });
        }
        if let Ok(parsed_host) = Host::parse(host_value) {
            if !parsed_host.host().eq_ignore_ascii_case(target_host) {
                return Err(EncodeError::HostAuthorityMismatch {
                    host: host_value.to_string(),
                    authority: uri.to_string(),
                });
            }
            // Ports are compared only when the Host header actually names one.
            if let Some(host_port) = parsed_host.port()
                && let Ok(target_port) = target_port_str.parse::<u16>()
                && host_port != target_port
            {
                return Err(EncodeError::HostAuthorityMismatch {
                    host: host_value.to_string(),
                    authority: uri.to_string(),
                });
            }
        }
    }

    // Absolute-form without an authority component (no `://`): there is no authority for
    // `Host` to mirror, so it must be empty.
    if let Ok(RequestTargetForm::Absolute) = detect_request_target_form(uri)
        && !uri.contains("://")
        && !host_value.is_empty()
    {
        return Err(EncodeError::NonEmptyHostWithoutAuthority {
            host: host_value.to_string(),
            uri: uri.to_string(),
        });
    }
    Ok(())
}
/// Rejects `http://` / `https://` targets (scheme matched case-insensitively) whose
/// authority contains userinfo (an `@`). Other targets are accepted unchanged.
///
/// # Errors
/// `EncodeError::UserinfoInHttpUri` when userinfo is present.
fn reject_http_userinfo(uri: &str) -> Result<(), EncodeError> {
    let starts_with_ci = |prefix: &str| {
        uri.as_bytes()
            .get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix.as_bytes()))
    };
    if !(starts_with_ci("http://") || starts_with_ci("https://")) {
        return Ok(());
    }
    let Some((_, after_scheme)) = uri.split_once("://") else {
        return Ok(());
    };
    // The authority ends at the first `/` or `?`, or at the end of the string.
    let authority = after_scheme
        .split(['/', '?'])
        .next()
        .unwrap_or(after_scheme);
    if authority.contains('@') {
        return Err(EncodeError::UserinfoInHttpUri {
            uri: uri.to_string(),
        });
    }
    Ok(())
}
/// Rejects `http://` / `https://` targets (scheme matched case-insensitively) whose
/// authority has an empty host, e.g. `http:///path` or `http://:8080/`. Any userinfo
/// before the last `@` is ignored. Other targets are accepted unchanged.
///
/// # Errors
/// `EncodeError::EmptyHostInHttpUri` when the host is empty.
fn reject_http_empty_host(uri: &str) -> Result<(), EncodeError> {
    let starts_with_ci = |prefix: &str| {
        uri.as_bytes()
            .get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix.as_bytes()))
    };
    if !(starts_with_ci("http://") || starts_with_ci("https://")) {
        return Ok(());
    }
    let Some((_, after_scheme)) = uri.split_once("://") else {
        return Ok(());
    };
    let authority = after_scheme
        .split(['/', '?'])
        .next()
        .unwrap_or(after_scheme);
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, rest)| rest);
    if host_port.is_empty() || host_port.starts_with(':') {
        return Err(EncodeError::EmptyHostInHttpUri {
            uri: uri.to_string(),
        });
    }
    Ok(())
}
/// Rejects targets whose scheme (text before the first `:`) is `http` or `https`
/// (case-insensitive) but is not followed by `://`, such as `http:foo`.
///
/// # Errors
/// `EncodeError::InvalidRequestTarget` for such targets.
fn reject_http_without_authority_prefix(uri: &str) -> Result<(), EncodeError> {
    let Some((scheme, _)) = uri.split_once(':') else {
        return Ok(());
    };
    let is_http_scheme =
        scheme.eq_ignore_ascii_case("http") || scheme.eq_ignore_ascii_case("https");
    if is_http_scheme && !uri[scheme.len()..].starts_with("://") {
        Err(EncodeError::InvalidRequestTarget {
            uri: uri.to_string(),
        })
    } else {
        Ok(())
    }
}
/// Parses every `Content-Length` header (name matched case-insensitively) and returns
/// the agreed value, or `None` if there is no such header.
///
/// Each value, after OWS trimming, must be a non-empty run of ASCII digits that fits
/// in a `u64`. Repeated headers are allowed only if they all carry the same value.
///
/// # Errors
/// `InvalidContentLengthValue` for a malformed value, or `DuplicateContentLength` when
/// two headers disagree. The first problem in header order wins.
fn validate_content_length_headers(
    headers: &[(String, String)],
) -> Result<Option<u64>, EncodeError> {
    let mut agreed: Option<u64> = None;
    let content_lengths = headers
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("Content-Length"))
        .map(|(_, value)| value);
    for value in content_lengths {
        let invalid = || EncodeError::InvalidContentLengthValue {
            value: value.clone(),
        };
        let digits = trim_ows(value);
        // `u64::from_str` alone would also accept a leading `+`, so insist on bare digits.
        if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
            return Err(invalid());
        }
        let parsed = digits.parse::<u64>().map_err(|_| invalid())?;
        if *agreed.get_or_insert(parsed) != parsed {
            return Err(EncodeError::DuplicateContentLength);
        }
    }
    Ok(agreed)
}
/// Returns the `host[:port]` part of the authority that follows the first `://` in
/// `uri`, with any userinfo (up to the last `@`) stripped. The authority ends at the
/// first `/` or `?`. Returns `None` if `uri` contains no `://`; the result may be empty.
fn extract_authority_from_uri(uri: &str) -> Option<String> {
    let (_, after_scheme) = uri.split_once("://")?;
    let authority = after_scheme
        .split(['/', '?'])
        .next()
        .unwrap_or(after_scheme);
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, rest)| rest);
    Some(host_port.to_string())
}
pub fn encode_request(request: &Request) -> Result<Vec<u8>, EncodeError> {
validate_request_fields(request)?;
reject_http_userinfo(request.uri())?;
reject_http_empty_host(request.uri())?;
validate_host_header(request)?;
if request.has_header("Transfer-Encoding") && request.has_header("Content-Length") {
return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
}
if !request.has_header("Transfer-Encoding")
&& let Some(header_value) = validate_content_length_headers(HttpHead::headers(request))?
{
let body_length = request.body_bytes().map(<[u8]>::len).unwrap_or(0) as u64;
if header_value != body_length {
return Err(EncodeError::ContentLengthMismatch {
header_value,
body_length,
});
}
}
let mut buf = allocate_encode_buffer(estimate_request_capacity(request));
buf.extend_from_slice(request.method().as_bytes());
buf.push(b' ');
buf.extend_from_slice(request.uri().as_bytes());
buf.push(b' ');
buf.extend_from_slice(request.version().as_bytes());
buf.extend_from_slice(b"\r\n");
for (name, value) in HttpHead::headers(request) {
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
if let Some(body) = request.body_bytes()
&& !request.has_header("Content-Length")
&& !request.has_header("Transfer-Encoding")
{
buf.extend_from_slice(b"Content-Length: ");
write_usize_decimal(&mut buf, body.len());
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"\r\n");
if let Some(body) = request.body_bytes() {
buf.extend_from_slice(body);
}
Ok(buf)
}
/// Serializes `response` into a complete message: status line, headers, a blank line,
/// and the body when the status permits one and it is not marked as omitted.
///
/// Framing rules, checked after `validate_response_fields` and in this order:
/// - `Transfer-Encoding` and `Content-Length` may not both be present;
/// - 1xx and 204 responses may carry neither header;
/// - 205 responses may carry no non-empty body, no `Transfer-Encoding`, and no
///   `Content-Length` other than `0` (after trimming);
/// - for statuses that allow a body and no `Transfer-Encoding`, a declared
///   `Content-Length` must equal the attached body length. The exception is an omitted
///   body with no body bytes attached.
///
/// A `Content-Length` header is synthesized when `should_auto_emit_content_length_for_response`
/// says so.
///
/// # Errors
/// The first validation failure encountered, per the order above.
pub fn encode_response(response: &Response) -> Result<Vec<u8>, EncodeError> {
    validate_response_fields(response)?;

    let status_code = response.status_code();
    let has_transfer_encoding = response.has_header("Transfer-Encoding");
    let has_content_length = response.has_header("Content-Length");

    if has_transfer_encoding && has_content_length {
        return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
    }
    if (100..200).contains(&status_code) || status_code == 204 {
        if has_transfer_encoding {
            return Err(EncodeError::ForbiddenTransferEncoding { status_code });
        }
        if has_content_length {
            return Err(EncodeError::ForbiddenContentLength { status_code });
        }
    }
    if status_code == 205 {
        if response.body_bytes().is_some_and(|b| !b.is_empty()) {
            return Err(EncodeError::ForbiddenBodyFor205);
        }
        if has_transfer_encoding {
            return Err(EncodeError::ForbiddenTransferEncoding { status_code: 205 });
        }
        if response
            .get_header("Content-Length")
            .is_some_and(|cl| cl.trim() != "0")
        {
            return Err(EncodeError::ForbiddenContentLength { status_code: 205 });
        }
    }

    let status_has_body = response_status_has_body(status_code);
    let body_will_be_encoded = status_has_body && !response.is_body_omitted();
    if status_has_body
        && !has_transfer_encoding
        && let Some(header_value) = validate_content_length_headers(HttpHead::headers(response))?
    {
        let body_length = response.body_bytes().map_or(0, |body| body.len() as u64);
        // An omitted body with nothing attached (presumably a HEAD response) may
        // advertise any length.
        let must_match = body_will_be_encoded || body_length != 0;
        if must_match && header_value != body_length {
            return Err(EncodeError::ContentLengthMismatch {
                header_value,
                body_length,
            });
        }
    }

    let mut out = allocate_encode_buffer(estimate_response_capacity(response));
    out.extend_from_slice(HttpHead::version(response).as_bytes());
    out.push(b' ');
    write_usize_decimal(&mut out, usize::from(status_code));
    out.push(b' ');
    out.extend_from_slice(response.reason_phrase().as_bytes());
    out.extend_from_slice(b"\r\n");
    for (name, value) in HttpHead::headers(response) {
        out.extend_from_slice(name.as_bytes());
        out.extend_from_slice(b": ");
        out.extend_from_slice(value.as_bytes());
        out.extend_from_slice(b"\r\n");
    }
    if should_auto_emit_content_length_for_response(response) {
        let body_len = response.body_bytes().map_or(0, <[u8]>::len);
        out.extend_from_slice(b"Content-Length: ");
        write_usize_decimal(&mut out, body_len);
        out.extend_from_slice(b"\r\n");
    }
    out.extend_from_slice(b"\r\n");
    if body_will_be_encoded && let Some(body) = response.body_bytes() {
        out.extend_from_slice(body);
    }
    Ok(out)
}
impl Request {
/// Serializes this request into a complete message (request line, headers,
/// blank line, body) by delegating to [`encode_request`].
///
/// # Errors
/// Returns every error produced by [`encode_request`].
pub fn encode(&self) -> Result<Vec<u8>, EncodeError> {
encode_request(self)
}
}
impl Response {
/// Serializes this response into a complete message (status line, headers,
/// blank line, body when applicable) by delegating to [`encode_response`].
///
/// # Errors
/// Returns every error produced by [`encode_response`].
pub fn encode(&self) -> Result<Vec<u8>, EncodeError> {
encode_response(self)
}
}
/// Encodes `data` as a single chunk of the chunked transfer coding:
/// `<hex-size>\r\n<data>\r\n`.
///
/// An empty `data` yields the terminating zero-size chunk followed by an empty trailer
/// section, `0\r\n\r\n`.
pub fn encode_chunk(data: &[u8]) -> Vec<u8> {
    if data.is_empty() {
        return b"0\r\n\r\n".to_vec();
    }
    // Room for up to 16 hex size digits plus two CRLFs. On overflow, just grow on demand.
    let mut buf = data
        .len()
        .checked_add(20)
        .map_or_else(Vec::new, Vec::with_capacity);
    write_hex_usize(&mut buf, data.len());
    buf.extend_from_slice(b"\r\n");
    buf.extend_from_slice(data);
    buf.extend_from_slice(b"\r\n");
    buf
}
/// Encodes `chunks` as a complete chunked-transfer-coded body: one
/// `<hex-size>\r\n<data>\r\n` chunk per non-empty slice, then the terminating
/// `0\r\n\r\n`.
///
/// Empty slices are skipped. A zero-size chunk *is* the last-chunk marker, so writing
/// one mid-stream would end the body early. A peer would then parse the remaining chunk
/// bytes as the start of the next message.
pub fn encode_chunks(chunks: &[&[u8]]) -> Vec<u8> {
    let mut buf = encode_chunks_capacity(chunks).map_or_else(Vec::new, Vec::with_capacity);
    for chunk in chunks.iter().filter(|chunk| !chunk.is_empty()) {
        write_hex_usize(&mut buf, chunk.len());
        buf.extend_from_slice(b"\r\n");
        buf.extend_from_slice(chunk);
        buf.extend_from_slice(b"\r\n");
    }
    buf.extend_from_slice(b"0\r\n\r\n");
    buf
}
/// Upper bound on the size of `encode_chunks(chunks)`. Each chunk costs its length plus
/// 20 bytes (up to 16 hex digits and two CRLFs), and 5 more bytes cover the terminating
/// `0\r\n\r\n`. Returns `None` on `usize` overflow.
fn encode_chunks_capacity(chunks: &[&[u8]]) -> Option<usize> {
    chunks.iter().try_fold(5usize, |acc, chunk| {
        chunk
            .len()
            .checked_add(20)
            .and_then(|per_chunk| acc.checked_add(per_chunk))
    })
}
/// Serializes only the head of `request`: request line, headers, and the blank line.
///
/// No body bytes are written and no `Content-Length` is synthesized. Applies the same
/// field, URI, and `Host` validation as `encode_request`, and rejects
/// `Transfer-Encoding` combined with `Content-Length`.
///
/// # Errors
/// Any validation error from those checks, or
/// `ConflictingTransferEncodingAndContentLength`.
pub fn encode_request_headers(request: &Request) -> Result<Vec<u8>, EncodeError> {
    let uri = request.uri();
    validate_request_fields(request)?;
    reject_http_userinfo(uri)?;
    reject_http_empty_host(uri)?;
    validate_host_header(request)?;
    if request.has_header("Transfer-Encoding") && request.has_header("Content-Length") {
        return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
    }

    let mut head = Vec::new();
    let mut put = |bytes: &[u8]| head.extend_from_slice(bytes);
    put(request.method().as_bytes());
    put(b" ");
    put(uri.as_bytes());
    put(b" ");
    put(request.version().as_bytes());
    put(b"\r\n");
    for (name, value) in HttpHead::headers(request) {
        put(name.as_bytes());
        put(b": ");
        put(value.as_bytes());
        put(b"\r\n");
    }
    put(b"\r\n");
    Ok(head)
}
pub fn encode_response_headers(response: &Response) -> Result<Vec<u8>, EncodeError> {
validate_response_fields(response)?;
if response.has_header("Transfer-Encoding") && response.has_header("Content-Length") {
return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
}
let is_1xx_or_204 =
(100..200).contains(&response.status_code()) || response.status_code() == 204;
if is_1xx_or_204 && response.has_header("Transfer-Encoding") {
return Err(EncodeError::ForbiddenTransferEncoding {
status_code: response.status_code(),
});
}
if is_1xx_or_204 && response.has_header("Content-Length") {
return Err(EncodeError::ForbiddenContentLength {
status_code: response.status_code(),
});
}
if response.status_code() == 205 && response.has_header("Transfer-Encoding") {
return Err(EncodeError::ForbiddenTransferEncoding { status_code: 205 });
}
if response.status_code() == 205
&& let Some(cl) = response.get_header("Content-Length")
&& cl.trim() != "0"
{
return Err(EncodeError::ForbiddenContentLength { status_code: 205 });
}
debug_assert!(
{
if response.has_header("Transfer-Encoding") {
true
} else if let Ok(Some(cl)) =
validate_content_length_headers(HttpHead::headers(response))
{
let body_len = response.body_bytes().map(|b| b.len() as u64).unwrap_or(0);
cl == body_len
} else {
true
}
},
"Content-Length header value does not match body length in encode_response_headers"
);
let mut buf = Vec::new();
buf.extend_from_slice(HttpHead::version(response).as_bytes());
buf.push(b' ');
write_usize_decimal(&mut buf, response.status_code() as usize);
buf.push(b' ');
buf.extend_from_slice(response.reason_phrase().as_bytes());
buf.extend_from_slice(b"\r\n");
for (name, value) in HttpHead::headers(response) {
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"\r\n");
Ok(buf)
}
impl Request {
/// Serializes only the request line, headers, and blank line by delegating to
/// [`encode_request_headers`]. No body bytes are written.
///
/// # Errors
/// Returns every error produced by [`encode_request_headers`].
pub fn encode_headers(&self) -> Result<Vec<u8>, EncodeError> {
encode_request_headers(self)
}
}
impl Response {
/// Serializes only the status line, headers, and blank line by delegating to
/// [`encode_response_headers`]. No body bytes are written.
///
/// # Errors
/// Returns every error produced by [`encode_response_headers`].
pub fn encode_headers(&self) -> Result<Vec<u8>, EncodeError> {
encode_response_headers(self)
}
}
/// Response-side body encoder that forwards body compression to a pluggable
/// [`Compressor`].
///
/// The default type parameter is [`NoCompression`]; use
/// [`ResponseEncoder::with_compressor`] to supply a different compressor.
#[derive(Debug)]
pub struct ResponseEncoder<C: Compressor = NoCompression> {
    compressor: C,
}

impl Default for ResponseEncoder<NoCompression> {
    fn default() -> Self {
        Self::new()
    }
}

impl ResponseEncoder<NoCompression> {
    /// Creates an encoder backed by [`NoCompression`].
    pub fn new() -> Self {
        Self::with_compressor(NoCompression::new())
    }
}

impl<C: Compressor> ResponseEncoder<C> {
    /// Creates an encoder that uses `compressor` for all body data.
    pub fn with_compressor(compressor: C) -> Self {
        Self { compressor }
    }

    /// Passes `input` to the compressor, which writes into `output`. Returns the
    /// compressor's status unchanged.
    ///
    /// # Errors
    /// Propagates the compressor's [`CompressionError`].
    pub fn compress_body(
        &mut self,
        input: &[u8],
        output: &mut [u8],
    ) -> Result<CompressionStatus, CompressionError> {
        let compressor = &mut self.compressor;
        compressor.compress(input, output)
    }

    /// Calls the compressor's `finish` with `output`, presumably to flush any buffered
    /// data, and returns its status.
    ///
    /// # Errors
    /// Propagates the compressor's [`CompressionError`].
    pub fn finish(&mut self, output: &mut [u8]) -> Result<CompressionStatus, CompressionError> {
        let compressor = &mut self.compressor;
        compressor.finish(output)
    }

    /// Calls the compressor's `reset`, presumably so the encoder can be reused for
    /// another body.
    pub fn reset(&mut self) {
        let compressor = &mut self.compressor;
        compressor.reset();
    }
}
/// Request-side body encoder that forwards body compression to a pluggable
/// [`Compressor`].
///
/// The default type parameter is [`NoCompression`]; use
/// [`RequestEncoder::with_compressor`] to supply a different compressor.
#[derive(Debug)]
pub struct RequestEncoder<C: Compressor = NoCompression> {
    compressor: C,
}

impl Default for RequestEncoder<NoCompression> {
    fn default() -> Self {
        Self::new()
    }
}

impl RequestEncoder<NoCompression> {
    /// Creates an encoder backed by [`NoCompression`].
    pub fn new() -> Self {
        Self::with_compressor(NoCompression::new())
    }
}

impl<C: Compressor> RequestEncoder<C> {
    /// Creates an encoder that uses `compressor` for all body data.
    pub fn with_compressor(compressor: C) -> Self {
        Self { compressor }
    }

    /// Passes `input` to the compressor, which writes into `output`. Returns the
    /// compressor's status unchanged.
    ///
    /// # Errors
    /// Propagates the compressor's [`CompressionError`].
    pub fn compress_body(
        &mut self,
        input: &[u8],
        output: &mut [u8],
    ) -> Result<CompressionStatus, CompressionError> {
        let compressor = &mut self.compressor;
        compressor.compress(input, output)
    }

    /// Calls the compressor's `finish` with `output`, presumably to flush any buffered
    /// data, and returns its status.
    ///
    /// # Errors
    /// Propagates the compressor's [`CompressionError`].
    pub fn finish(&mut self, output: &mut [u8]) -> Result<CompressionStatus, CompressionError> {
        let compressor = &mut self.compressor;
        compressor.finish(output)
    }

    /// Calls the compressor's `reset`, presumably so the encoder can be reused for
    /// another body.
    pub fn reset(&mut self) {
        let compressor = &mut self.compressor;
        compressor.reset();
    }
}
// Checks that `estimate_request_capacity` / `estimate_response_capacity` never report
// fewer bytes than `encode_request` / `encode_response` actually produce. Covers bodies,
// framing headers, many headers, and body-less statuses.
#[cfg(test)]
mod capacity_tests {
use super::*;
use crate::request::Request;
use crate::response::Response;
use crate::status_code::StatusCode;
// Encodes `req` and asserts that the pre-computed capacity covers the whole output.
fn assert_request_capacity_sufficient(req: &Request) {
let est = estimate_request_capacity(req).expect("estimate overflow");
let out = encode_request(req).expect("encode failed");
assert!(
est >= out.len(),
"estimate {} < output {}: req={req:?}",
est,
out.len(),
);
}
// Encodes `res` and asserts that the pre-computed capacity covers the whole output.
fn assert_response_capacity_sufficient(res: &Response) {
let est = estimate_response_capacity(res).expect("estimate overflow");
let out = encode_response(res).expect("encode failed");
assert!(
est >= out.len(),
"estimate {} < output {}: res={res:?}",
est,
out.len(),
);
}
#[test]
fn test_request_capacity_simple_get() {
let req = Request::new("GET", "/")
.unwrap()
.header("Host", "example.com")
.unwrap();
assert_request_capacity_sufficient(&req);
}
// Body with no framing header: the encoder synthesizes Content-Length.
#[test]
fn test_request_capacity_post_with_body_auto_content_length() {
let req = Request::new("POST", "/api")
.unwrap()
.header("Host", "example.com")
.unwrap()
.body(b"hello world".to_vec());
assert_request_capacity_sufficient(&req);
}
// Caller-supplied Content-Length matching the body: nothing is synthesized.
#[test]
fn test_request_capacity_post_with_explicit_content_length() {
let req = Request::new("POST", "/api")
.unwrap()
.header("Host", "example.com")
.unwrap()
.header("Content-Length", "11")
.unwrap()
.body(b"hello world".to_vec());
assert_request_capacity_sufficient(&req);
}
// Transfer-Encoding present: no Content-Length is synthesized, body written as-is.
#[test]
fn test_request_capacity_post_with_transfer_encoding_no_auto() {
let req = Request::new("POST", "/api")
.unwrap()
.header("Host", "example.com")
.unwrap()
.header("Transfer-Encoding", "chunked")
.unwrap()
.body(b"hello".to_vec());
assert_request_capacity_sufficient(&req);
}
#[test]
fn test_request_capacity_many_headers() {
let mut req = Request::new("GET", "/")
.unwrap()
.header("Host", "example.com")
.unwrap();
for i in 0..50 {
req = req
.header(
alloc::format!("X-Custom-{i}"),
alloc::format!("value-{i}-with-some-padding"),
)
.unwrap();
}
assert_request_capacity_sufficient(&req);
}
// An attached but empty body still triggers a synthesized "Content-Length: 0".
#[test]
fn test_request_capacity_empty_body_auto_content_length_zero() {
let req = Request::new("POST", "/")
.unwrap()
.header("Host", "example.com")
.unwrap()
.body(Vec::new());
assert_request_capacity_sufficient(&req);
}
#[test]
fn test_request_capacity_no_body() {
let req = Request::new("GET", "/path/to/resource?q=1")
.unwrap()
.header("Host", "example.com")
.unwrap();
assert_request_capacity_sufficient(&req);
}
#[test]
fn test_response_capacity_simple_ok() {
let res = Response::with_status(StatusCode::OK).body(b"hello".to_vec());
assert_response_capacity_sufficient(&res);
}
// Statuses for which `response_status_has_body` is false.
#[test]
fn test_response_capacity_no_body_status() {
for &code in &[100u16, 204, 304] {
let res = Response::new(code, "Reason").unwrap();
assert_response_capacity_sufficient(&res);
}
}
// Omitted body with no bytes attached: the declared length is not compared to a body.
#[test]
fn test_response_capacity_omit_body_with_content_length() {
let res = Response::with_status(StatusCode::OK)
.header("Content-Length", "100")
.unwrap()
.omit_body(true);
assert_response_capacity_sufficient(&res);
}
#[test]
fn test_response_capacity_with_transfer_encoding() {
let res = Response::with_status(StatusCode::OK)
.header("Transfer-Encoding", "chunked")
.unwrap();
assert_response_capacity_sufficient(&res);
}
#[test]
fn test_response_capacity_many_headers() {
let mut res = Response::with_status(StatusCode::OK).body(vec![b'X'; 1024]);
for i in 0..50 {
res = res
.header(
alloc::format!("X-Custom-{i}"),
alloc::format!("value-{i}-with-some-padding"),
)
.unwrap();
}
assert_response_capacity_sufficient(&res);
}
// The estimate reserves exactly 3 bytes for the status code.
#[test]
fn test_response_capacity_status_code_3_digit_boundary() {
for &code in &[100u16, 200, 599] {
let res = Response::new(code, "Phrase")
.unwrap()
.body(b"body".to_vec());
assert_response_capacity_sufficient(&res);
}
}
// Smoke test: encoding an ordinary request succeeds.
#[test]
fn test_request_capacity_normal_input_does_not_panic() {
let req = Request::new("GET", "/")
.unwrap()
.header("Host", "example.com")
.unwrap();
let _ = encode_request(&req).unwrap();
}
}
// Tests for `validate_response_fields`, exercised through `encode_response`.
#[cfg(test)]
mod validate_response_fields_tests {
use super::*;
use crate::response::Response;
use alloc::string::ToString;
// An empty reason phrase is accepted; the status line keeps its trailing space before CRLF.
#[test]
fn test_validate_response_fields_empty_reason_phrase_is_accepted() {
let res =
Response::from_raw_parts("HTTP/1.1".to_string(), 200, String::new(), Vec::new(), None);
let encoded = encode_response(&res).unwrap();
assert!(encoded.starts_with(b"HTTP/1.1 200 \r\n"));
}
}