use crate::compression::{CompressionError, CompressionStatus, Compressor, NoCompression};
use crate::error::EncodeError;
use crate::host::Host;
use crate::request::Request;
use crate::request_target::RequestTargetForm;
use crate::response::Response;
use crate::validate::{
is_valid_field_value, is_valid_header_name, is_valid_method, is_valid_reason_phrase,
is_valid_request_target, is_valid_status_code,
};
use alloc::string::{String, ToString};
use alloc::vec::Vec;
/// Returns `true` when `version` is a non-empty run of visible ASCII
/// characters (`!` through `~`), so it cannot smuggle spaces, control bytes,
/// or non-ASCII data into the start line.
fn is_valid_version_for_encode(version: &str) -> bool {
    if version.is_empty() {
        return false;
    }
    version.bytes().all(|byte| byte.is_ascii_graphic())
}
/// Appends the lowercase hexadecimal form of `n` to `buf`, with no `0x`
/// prefix and no leading zeros (`0` is written as `"0"`).
fn write_hex_usize(buf: &mut Vec<u8>, n: usize) {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    // 16 hex digits hold any 64-bit `usize`; digits are filled from the end.
    let mut digits = [0u8; 16];
    let mut start = digits.len();
    let mut value = n;
    loop {
        start -= 1;
        digits[start] = HEX_DIGITS[value & 0xF];
        value >>= 4;
        if value == 0 {
            break;
        }
    }
    buf.extend_from_slice(&digits[start..]);
}
/// Appends the base-10 ASCII digits of `n` to `buf` without going through
/// `core::fmt` (`0` is written as `"0"`).
fn write_usize_decimal(buf: &mut Vec<u8>, n: usize) {
    // 20 digits cover the largest 64-bit value (18446744073709551615).
    let mut digits = [0u8; 20];
    let mut start = digits.len();
    let mut value = n;
    loop {
        start -= 1;
        digits[start] = b'0' + (value % 10) as u8;
        value /= 10;
        if value == 0 {
            break;
        }
    }
    buf.extend_from_slice(&digits[start..]);
}
/// Largest estimated size (64 MiB) that is reserved up front for an encoded
/// message; bigger estimates start from an empty, growable buffer instead.
const ENCODE_CAPACITY_LIMIT: usize = 1 << 26;
/// Bytes reserved for a synthesized `Content-Length: <n>\r\n` line: the
/// 16-byte field prefix, up to 20 decimal digits, and the trailing CRLF.
const AUTO_CONTENT_LENGTH_CAPACITY: usize = b"Content-Length: ".len() + 20 + b"\r\n".len();
/// Reports whether a `Content-Length` header should be synthesized for
/// `request`: only when it has a body and the caller set neither
/// `Content-Length` nor `Transfer-Encoding`.
fn should_auto_emit_content_length_for_request(request: &Request) -> bool {
    if request.body.is_none() {
        return false;
    }
    !(request.has_header("Content-Length") || request.has_header("Transfer-Encoding"))
}
/// Returns `false` for statuses that never carry a message body (any 1xx,
/// 204, and 304) and `true` for every other code.
fn response_status_has_body(status_code: u16) -> bool {
    !matches!(status_code, 100..=199 | 204 | 304)
}
/// Reports whether a `Content-Length` header should be synthesized for
/// `response`.
///
/// That happens when the status allows a body, the caller set neither
/// `Content-Length` nor `Transfer-Encoding`, and a body is present, except
/// when the body is empty and `omit_body` is set.
fn should_auto_emit_content_length_for_response(response: &Response) -> bool {
    if !response_status_has_body(response.status_code)
        || response.has_header("Content-Length")
        || response.has_header("Transfer-Encoding")
    {
        return false;
    }
    match response.body.as_deref() {
        None => false,
        Some(body) => !(response.omit_body && body.is_empty()),
    }
}
/// Computes an upper bound on the number of bytes `encode_request` writes for
/// `request`, or `None` if the total would overflow `usize`.
fn estimate_request_capacity(request: &Request) -> Option<usize> {
    // Request line: method SP uri SP version CRLF.
    let start_line = request
        .method
        .len()
        .checked_add(request.uri.len())?
        .checked_add(request.version.len())?
        .checked_add(4)?;
    // Each header line: name ": " value CRLF.
    let with_headers = request
        .headers
        .iter()
        .try_fold(start_line, |acc, (name, value)| {
            acc.checked_add(name.len())?
                .checked_add(value.len())?
                .checked_add(4)
        })?;
    let synthesized_length = if should_auto_emit_content_length_for_request(request) {
        AUTO_CONTENT_LENGTH_CAPACITY
    } else {
        0
    };
    // Blank line that ends the header section.
    let head = with_headers.checked_add(synthesized_length)?.checked_add(2)?;
    let body_len = request.body.as_deref().map_or(0, <[u8]>::len);
    head.checked_add(body_len)
}
/// Computes an upper bound on the number of bytes `encode_response` writes for
/// `response`, or `None` if the total would overflow `usize`.
fn estimate_response_capacity(response: &Response) -> Option<usize> {
    // Assumes a validated status code is printed as 3 digits.
    const STATUS_CODE_DIGITS: usize = 3;
    // Status line: version SP status-code SP reason CRLF.
    let start_line = response
        .version
        .len()
        .checked_add(STATUS_CODE_DIGITS)?
        .checked_add(response.reason_phrase.len())?
        .checked_add(4)?;
    // Each header line: name ": " value CRLF.
    let with_headers = response
        .headers
        .iter()
        .try_fold(start_line, |acc, (name, value)| {
            acc.checked_add(name.len())?
                .checked_add(value.len())?
                .checked_add(4)
        })?;
    let synthesized_length = if should_auto_emit_content_length_for_response(response) {
        AUTO_CONTENT_LENGTH_CAPACITY
    } else {
        0
    };
    // Blank line that ends the header section.
    let head = with_headers.checked_add(synthesized_length)?.checked_add(2)?;
    // Body bytes count only when they are actually written.
    let body_len = if response_status_has_body(response.status_code) && !response.omit_body {
        response.body.as_deref().map_or(0, <[u8]>::len)
    } else {
        0
    };
    head.checked_add(body_len)
}
/// Returns an empty output buffer.
///
/// `estimated` bytes are pre-reserved only when the estimate is known and does
/// not exceed `ENCODE_CAPACITY_LIMIT`. This keeps an enormous estimate from
/// triggering one huge up-front allocation.
fn allocate_encode_buffer(estimated: Option<usize>) -> Vec<u8> {
    estimated
        .filter(|&capacity| capacity <= ENCODE_CAPACITY_LIMIT)
        .map_or_else(Vec::new, Vec::with_capacity)
}
/// Validates the request-line components and header fields of `request`.
///
/// # Errors
///
/// Returns the first failure, checked in this order:
/// - an invalid method;
/// - an invalid request target, including any byte above `0x7E`;
/// - a target form not allowed for the method;
/// - an invalid version token;
/// - an invalid header name or value.
fn validate_request_fields(request: &Request) -> Result<(), EncodeError> {
    if !is_valid_method(&request.method) {
        return Err(EncodeError::InvalidMethod {
            method: request.method.clone(),
        });
    }
    // DEL and every non-ASCII byte are rejected on top of the grammar check.
    let target_ok =
        is_valid_request_target(&request.uri) && request.uri.bytes().all(|b| b <= 0x7E);
    if !target_ok {
        return Err(EncodeError::InvalidRequestTarget {
            uri: request.uri.clone(),
        });
    }
    validate_request_target_form(&request.method, &request.uri)?;
    if !is_valid_version_for_encode(&request.version) {
        return Err(EncodeError::InvalidVersion {
            version: request.version.clone(),
        });
    }
    validate_headers(&request.headers)
}
/// Classifies `uri` as one of the four request-target forms.
///
/// The tests run in priority order: `*`, anything containing `://`, a leading
/// `/`, a `host:port` shape, and finally a bare `scheme:` prefix.
///
/// # Errors
///
/// Returns `InvalidRequestTarget` when none of the forms match.
fn detect_request_target_form(uri: &str) -> Result<RequestTargetForm, EncodeError> {
    match uri {
        "*" => Ok(RequestTargetForm::Asterisk),
        _ if uri.contains("://") => Ok(RequestTargetForm::Absolute),
        _ if uri.starts_with('/') => Ok(RequestTargetForm::Origin),
        _ if looks_like_authority_form(uri) => Ok(RequestTargetForm::Authority),
        _ if detect_scheme(uri).is_some() => Ok(RequestTargetForm::Absolute),
        _ => Err(EncodeError::InvalidRequestTarget {
            uri: uri.to_string(),
        }),
    }
}
/// Heuristically detects `host:port` authority form.
///
/// The target must contain no `@`. The text after its last `:` must be a
/// non-empty all-digit string that fits in `u16`, and the text before that
/// colon must be non-empty.
fn looks_like_authority_form(uri: &str) -> bool {
    if uri.contains('@') {
        return false;
    }
    let Some((host, port)) = uri.rsplit_once(':') else {
        return false;
    };
    !host.is_empty()
        && !port.is_empty()
        && port.bytes().all(|b| b.is_ascii_digit())
        && port.parse::<u16>().is_ok()
}
/// Checks that the host part of an authority-form target, meaning everything
/// before the last `:`, is accepted by [`Host::parse`]. A target without a
/// `:` passes unchecked.
///
/// # Errors
///
/// Returns `InvalidRequestTarget` when the host part fails to parse.
fn validate_encoder_authority_form(uri: &str) -> Result<(), EncodeError> {
    let Some((host, _port)) = uri.rsplit_once(':') else {
        return Ok(());
    };
    match Host::parse(host) {
        Ok(_) => Ok(()),
        Err(_) => Err(EncodeError::InvalidRequestTarget {
            uri: uri.to_string(),
        }),
    }
}
/// Returns the byte index of the `:` ending a URI scheme at the start of
/// `target`, or `None`.
///
/// A scheme here is an ASCII letter followed by letters, digits, `+`, `-` or
/// `.`. At least one byte must follow the colon.
fn detect_scheme(target: &str) -> Option<usize> {
    let bytes = target.as_bytes();
    if !bytes.first()?.is_ascii_alphabetic() {
        return None;
    }
    let colon = bytes.iter().position(|&b| b == b':')?;
    let scheme_tail_ok = bytes[1..colon]
        .iter()
        .all(|&b| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'-' | b'.'));
    let has_rest = colon + 1 < bytes.len();
    (scheme_tail_ok && has_rest).then_some(colon)
}
/// Checks that the form of `uri` is allowed for `method`.
///
/// - `CONNECT` requires authority form, and its host must parse.
/// - Asterisk form is allowed only for `OPTIONS`.
/// - Authority form is rejected for every other method.
/// - Origin form may not contain `[` or `]`.
/// - Absolute `http`/`https` targets must continue with `://`.
///
/// # Errors
///
/// Returns `InvalidRequestTargetForm` for a method/form mismatch, or
/// `InvalidRequestTarget` for a malformed target.
fn validate_request_target_form(method: &str, uri: &str) -> Result<(), EncodeError> {
    let form_mismatch = || EncodeError::InvalidRequestTargetForm {
        method: method.to_string(),
        uri: uri.to_string(),
    };
    let form = detect_request_target_form(uri)?;
    let is_connect = method == "CONNECT";
    match form {
        RequestTargetForm::Authority if is_connect => validate_encoder_authority_form(uri),
        _ if is_connect => Err(form_mismatch()),
        RequestTargetForm::Asterisk if method == "OPTIONS" => Ok(()),
        RequestTargetForm::Asterisk | RequestTargetForm::Authority => Err(form_mismatch()),
        RequestTargetForm::Origin => {
            if uri.contains(['[', ']']) {
                Err(EncodeError::InvalidRequestTarget {
                    uri: uri.to_string(),
                })
            } else {
                Ok(())
            }
        }
        RequestTargetForm::Absolute => reject_http_without_authority_prefix(uri),
    }
}
/// Validates the status-line components and header fields of `response`.
///
/// # Errors
///
/// Returns the first failure among version, status code, reason phrase, and
/// header name/value, checked in that order.
fn validate_response_fields(response: &Response) -> Result<(), EncodeError> {
    if !is_valid_version_for_encode(&response.version) {
        Err(EncodeError::InvalidVersion {
            version: response.version.clone(),
        })
    } else if !is_valid_status_code(response.status_code) {
        Err(EncodeError::InvalidStatusCode {
            code: response.status_code,
        })
    } else if !is_valid_reason_phrase(&response.reason_phrase) {
        Err(EncodeError::InvalidReasonPhrase {
            phrase: response.reason_phrase.clone(),
        })
    } else {
        validate_headers(&response.headers)
    }
}
/// Checks every header pair in order.
///
/// # Errors
///
/// Stops at the first invalid name (`InvalidHeaderName`) or invalid value
/// (`InvalidHeaderValue`).
fn validate_headers(headers: &[(String, String)]) -> Result<(), EncodeError> {
    headers.iter().try_for_each(|(name, value)| {
        if !is_valid_header_name(name) {
            Err(EncodeError::InvalidHeaderName { name: name.clone() })
        } else if !is_valid_field_value(value) {
            Err(EncodeError::InvalidHeaderValue {
                name: name.clone(),
                value: value.clone(),
            })
        } else {
            Ok(())
        }
    })
}
/// Validates the `Host` header of `request` against its request target.
///
/// Only requests whose version is exactly `"HTTP/1.1"` are checked; any other
/// version returns `Ok(())` immediately. For HTTP/1.1:
///
/// - exactly one `Host` header (name matched ignoring ASCII case) must exist;
/// - a non-empty value must be accepted by [`Host::parse`];
/// - if the target contains `://`, it may yield a non-empty authority
///   (userinfo stripped). That authority must equal the `Host` value ignoring
///   ASCII case. The comparison is textual, so `example.com:80` and
///   `example.com` differ;
/// - for `CONNECT` with a `:` in the target, the `Host` value must be
///   non-empty and its host must equal the target's host, ignoring ASCII case.
///   When the `Host` value carries a port and the target's port parses as
///   `u16`, the two ports must be equal;
/// - for an absolute-form target without `://`, the `Host` value must be empty.
///
/// # Errors
///
/// Returns `MissingHostHeader`, `DuplicateHostHeader`, `InvalidHostHeader`,
/// `HostAuthorityMismatch`, or `NonEmptyHostWithoutAuthority` for the
/// corresponding violation above.
fn validate_host_header(request: &Request) -> Result<(), EncodeError> {
    // Only HTTP/1.1 requests are checked here.
    if request.version != "HTTP/1.1" {
        return Ok(());
    }
    // Collect every Host header so both absence and duplicates can be reported.
    let host_headers: Vec<&str> = request
        .headers
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("Host"))
        .map(|(_, value)| value.as_str())
        .collect();
    if host_headers.is_empty() {
        return Err(EncodeError::MissingHostHeader);
    }
    if host_headers.len() > 1 {
        return Err(EncodeError::DuplicateHostHeader);
    }
    let host_value = host_headers[0];
    // An empty Host value is allowed here; a non-empty one must parse.
    if !host_value.is_empty() && Host::parse(host_value).is_err() {
        return Err(EncodeError::InvalidHostHeader {
            value: host_value.to_string(),
        });
    }
    // Target with a `//` authority: Host must match it verbatim (ASCII case-insensitive).
    if request.uri.contains("://")
        && let Some(authority) = extract_authority_from_uri(&request.uri)
        && !authority.is_empty()
        && !authority.eq_ignore_ascii_case(host_value)
    {
        return Err(EncodeError::HostAuthorityMismatch {
            host: host_value.to_string(),
            authority: authority.to_string(),
        });
    }
    // CONNECT: split the target at its last ':' into host and port.
    if request.method == "CONNECT"
        && let Some(colon_pos) = request.uri.rfind(':')
    {
        let target_host = &request.uri[..colon_pos];
        let target_port_str = &request.uri[colon_pos + 1..];
        if host_value.is_empty() {
            return Err(EncodeError::HostAuthorityMismatch {
                host: host_value.to_string(),
                authority: request.uri.clone(),
            });
        }
        // A non-empty value that failed to parse was already rejected above.
        if let Ok(parsed_host) = Host::parse(host_value) {
            if !parsed_host.host().eq_ignore_ascii_case(target_host) {
                return Err(EncodeError::HostAuthorityMismatch {
                    host: host_value.to_string(),
                    authority: request.uri.clone(),
                });
            }
            // Ports are compared only when both sides provide one.
            if let Some(host_port) = parsed_host.port()
                && let Ok(target_port) = target_port_str.parse::<u16>()
                && host_port != target_port
            {
                return Err(EncodeError::HostAuthorityMismatch {
                    host: host_value.to_string(),
                    authority: request.uri.clone(),
                });
            }
        }
    }
    // Absolute-form target with no `//` authority component: Host must be empty.
    if let Ok(RequestTargetForm::Absolute) = detect_request_target_form(&request.uri)
        && !request.uri.contains("://")
        && !host_value.is_empty()
    {
        return Err(EncodeError::NonEmptyHostWithoutAuthority {
            host: host_value.to_string(),
            uri: request.uri.clone(),
        });
    }
    Ok(())
}
/// Rejects `http://` / `https://` targets (scheme matched ignoring ASCII case)
/// whose authority contains userinfo, i.e. an `@` before the first `/` or `?`.
/// Any other target passes.
///
/// # Errors
///
/// Returns `UserinfoInHttpUri` when userinfo is present.
fn reject_http_userinfo(uri: &str) -> Result<(), EncodeError> {
    let has_http_scheme = ["http://", "https://"].into_iter().any(|prefix| {
        uri.get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    });
    if !has_http_scheme {
        return Ok(());
    }
    let Some((_, after_scheme)) = uri.split_once("://") else {
        return Ok(());
    };
    let authority_end = after_scheme
        .find(['/', '?'])
        .unwrap_or(after_scheme.len());
    if after_scheme[..authority_end].contains('@') {
        Err(EncodeError::UserinfoInHttpUri {
            uri: uri.to_string(),
        })
    } else {
        Ok(())
    }
}
/// Rejects `http://` / `https://` targets (scheme matched ignoring ASCII case)
/// whose authority has an empty host. This covers both an empty authority and
/// one that starts with `:port`, after dropping any userinfo.
///
/// # Errors
///
/// Returns `EmptyHostInHttpUri` when the host is empty.
fn reject_http_empty_host(uri: &str) -> Result<(), EncodeError> {
    let has_http_scheme = ["http://", "https://"].into_iter().any(|prefix| {
        uri.get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    });
    if !has_http_scheme {
        return Ok(());
    }
    let Some((_, after_scheme)) = uri.split_once("://") else {
        return Ok(());
    };
    let authority = match after_scheme.find(['/', '?']) {
        Some(end) => &after_scheme[..end],
        None => after_scheme,
    };
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, host_port)| host_port);
    if host_port.is_empty() || host_port.starts_with(':') {
        Err(EncodeError::EmptyHostInHttpUri {
            uri: uri.to_string(),
        })
    } else {
        Ok(())
    }
}
/// Rejects targets whose scheme (text before the first `:`) is `http` or
/// `https`, ignoring ASCII case, but which do not continue with `://`.
///
/// # Errors
///
/// Returns `InvalidRequestTarget` for such targets.
fn reject_http_without_authority_prefix(uri: &str) -> Result<(), EncodeError> {
    let Some((scheme, after_colon)) = uri.split_once(':') else {
        return Ok(());
    };
    let is_http_family = ["http", "https"]
        .into_iter()
        .any(|candidate| scheme.eq_ignore_ascii_case(candidate));
    if is_http_family && !after_colon.starts_with("//") {
        return Err(EncodeError::InvalidRequestTarget {
            uri: uri.to_string(),
        });
    }
    Ok(())
}
/// Validates every `Content-Length` header in `headers`.
///
/// Returns the common declared length, or `None` when no such header is
/// present. Each value may be surrounded by HTTP optional whitespace
/// (SP / HTAB). Apart from that it must consist only of ASCII digits and fit
/// in a `u64`.
///
/// # Errors
///
/// - `InvalidContentLengthValue` for an empty, non-numeric, or overflowing
///   value.
/// - `DuplicateContentLength` when two headers declare different lengths.
fn validate_content_length_headers(
    headers: &[(String, String)],
) -> Result<Option<u64>, EncodeError> {
    let mut result: Option<u64> = None;
    for (name, value) in headers {
        if !name.eq_ignore_ascii_case("Content-Length") {
            continue;
        }
        // Only OWS (SP / HTAB) may surround the number. `str::trim` would also
        // strip Unicode whitespace such as U+00A0. That would accept a value
        // whose bytes, which are emitted verbatim, are not a valid length.
        let trimmed = value.trim_matches(|c: char| c == ' ' || c == '\t');
        let invalid = || EncodeError::InvalidContentLengthValue {
            value: value.clone(),
        };
        if trimmed.is_empty() || !trimmed.bytes().all(|b| b.is_ascii_digit()) {
            return Err(invalid());
        }
        // A digits-only value can still overflow `u64`.
        let parsed = trimmed.parse::<u64>().map_err(|_| invalid())?;
        match result {
            None => result = Some(parsed),
            Some(prev) if prev != parsed => return Err(EncodeError::DuplicateContentLength),
            Some(_) => {}
        }
    }
    Ok(result)
}
/// Extracts the `host[:port]` part of a URI's authority.
///
/// The authority is the text after the first `://` and before the first `/`
/// or `?`, with any userinfo up to the last `@` removed. Returns `None` when
/// the URI has no `://`.
fn extract_authority_from_uri(uri: &str) -> Option<String> {
    let (_, after_scheme) = uri.split_once("://")?;
    let authority = match after_scheme.find(['/', '?']) {
        Some(end) => &after_scheme[..end],
        None => after_scheme,
    };
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, tail)| tail);
    Some(host_port.to_string())
}
pub fn encode_request(request: &Request) -> Result<Vec<u8>, EncodeError> {
validate_request_fields(request)?;
reject_http_userinfo(&request.uri)?;
reject_http_empty_host(&request.uri)?;
validate_host_header(request)?;
if request.has_header("Transfer-Encoding") && request.has_header("Content-Length") {
return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
}
if !request.has_header("Transfer-Encoding")
&& let Some(header_value) = validate_content_length_headers(&request.headers)?
{
let body_length = request.body.as_deref().map(<[u8]>::len).unwrap_or(0) as u64;
if header_value != body_length {
return Err(EncodeError::ContentLengthMismatch {
header_value,
body_length,
});
}
}
let mut buf = allocate_encode_buffer(estimate_request_capacity(request));
buf.extend_from_slice(request.method.as_bytes());
buf.push(b' ');
buf.extend_from_slice(request.uri.as_bytes());
buf.push(b' ');
buf.extend_from_slice(request.version.as_bytes());
buf.extend_from_slice(b"\r\n");
for (name, value) in &request.headers {
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
if let Some(body) = request.body.as_deref()
&& !request.has_header("Content-Length")
&& !request.has_header("Transfer-Encoding")
{
buf.extend_from_slice(b"Content-Length: ");
write_usize_decimal(&mut buf, body.len());
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"\r\n");
if let Some(body) = request.body.as_deref() {
buf.extend_from_slice(body);
}
Ok(buf)
}
/// Serializes `response` into a status line, a header section, and (when
/// permitted) the body.
///
/// A `Content-Length` header is synthesized when all of these hold:
/// - the status allows a body;
/// - the caller set neither `Content-Length` nor `Transfer-Encoding`;
/// - a body is present, and it is not an empty body with `omit_body` set.
///
/// Body bytes are written only when the status allows a body and `omit_body`
/// is false. This lets the length be advertised while the body itself is
/// omitted.
///
/// # Errors
///
/// - any error from field/header validation;
/// - `ConflictingTransferEncodingAndContentLength` when both headers are set;
/// - `ForbiddenTransferEncoding` / `ForbiddenContentLength` for 1xx and 204;
///   for 205 they apply to any `Transfer-Encoding`, and to a `Content-Length`
///   other than `0`;
/// - `ForbiddenBodyFor205` for a non-empty body on 205;
/// - errors from `validate_content_length_headers`;
/// - `ContentLengthMismatch` when an explicit `Content-Length` disagrees with
///   the body length. The check is skipped when the body is both omitted and
///   absent/empty.
pub fn encode_response(response: &Response) -> Result<Vec<u8>, EncodeError> {
    validate_response_fields(response)?;
    // Both framing headers at once would make the message length ambiguous.
    if response.has_header("Transfer-Encoding") && response.has_header("Content-Length") {
        return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
    }
    // 1xx and 204 responses may carry neither framing header.
    let is_1xx_or_204 = (100..200).contains(&response.status_code) || response.status_code == 204;
    if is_1xx_or_204 && response.has_header("Transfer-Encoding") {
        return Err(EncodeError::ForbiddenTransferEncoding {
            status_code: response.status_code,
        });
    }
    if is_1xx_or_204 && response.has_header("Content-Length") {
        return Err(EncodeError::ForbiddenContentLength {
            status_code: response.status_code,
        });
    }
    // 205: no body bytes, no Transfer-Encoding, and only a zero Content-Length.
    if response.status_code == 205 {
        if response.body.as_deref().is_some_and(|b| !b.is_empty()) {
            return Err(EncodeError::ForbiddenBodyFor205);
        }
        if response.has_header("Transfer-Encoding") {
            return Err(EncodeError::ForbiddenTransferEncoding { status_code: 205 });
        }
        if let Some(cl) = response.get_header("Content-Length")
            && cl.trim() != "0"
        {
            return Err(EncodeError::ForbiddenContentLength { status_code: 205 });
        }
    }
    // Body bytes go on the wire only for body-bearing statuses without omit_body.
    let body_will_be_encoded =
        response_status_has_body(response.status_code) && !response.omit_body;
    if response_status_has_body(response.status_code)
        && !response.has_header("Transfer-Encoding")
        && let Some(header_value) = validate_content_length_headers(&response.headers)?
    {
        let body_length = response.body.as_deref().map(<[u8]>::len).unwrap_or(0) as u64;
        // With the body omitted and no body bytes supplied, there is nothing to
        // compare against, so the declared length is accepted as-is.
        let should_validate = body_will_be_encoded || body_length != 0;
        if should_validate && header_value != body_length {
            return Err(EncodeError::ContentLengthMismatch {
                header_value,
                body_length,
            });
        }
    }
    let mut buf = allocate_encode_buffer(estimate_response_capacity(response));
    // Status line: version SP status-code SP reason-phrase CRLF.
    buf.extend_from_slice(response.version.as_bytes());
    buf.push(b' ');
    write_usize_decimal(&mut buf, response.status_code as usize);
    buf.push(b' ');
    buf.extend_from_slice(response.reason_phrase.as_bytes());
    buf.extend_from_slice(b"\r\n");
    for (name, value) in &response.headers {
        buf.extend_from_slice(name.as_bytes());
        buf.extend_from_slice(b": ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }
    // Synthesized length is computed from the body even when it is omitted.
    if should_auto_emit_content_length_for_response(response) {
        let len = response.body.as_deref().map(<[u8]>::len).unwrap_or(0);
        buf.extend_from_slice(b"Content-Length: ");
        write_usize_decimal(&mut buf, len);
        buf.extend_from_slice(b"\r\n");
    }
    // Blank line ending the header section.
    buf.extend_from_slice(b"\r\n");
    if body_will_be_encoded && let Some(body) = response.body.as_deref() {
        buf.extend_from_slice(body);
    }
    Ok(buf)
}
impl Request {
    /// Serializes this request with [`encode_request`].
    ///
    /// # Panics
    ///
    /// Panics on any error `encode_request` reports, not only a missing `Host`
    /// header. The panic message includes the error. Use
    /// [`Request::try_encode`] to handle failures.
    pub fn encode(&self) -> Vec<u8> {
        encode_request(self).expect("invalid request (e.g. HTTP/1.1 request requires Host header)")
    }
    /// Serializes this request with [`encode_request`], returning its error
    /// instead of panicking.
    pub fn try_encode(&self) -> Result<Vec<u8>, EncodeError> {
        encode_request(self)
    }
}
impl Response {
    /// Serializes this response with [`encode_response`].
    ///
    /// # Panics
    ///
    /// Panics on any error `encode_response` reports, including an invalid
    /// status line or header value, not only a bad header combination. The
    /// panic message includes the error. Use [`Response::try_encode`] to
    /// handle failures.
    pub fn encode(&self) -> Vec<u8> {
        encode_response(self).expect("invalid response or invalid header combination")
    }
    /// Serializes this response with [`encode_response`], returning its error
    /// instead of panicking.
    pub fn try_encode(&self) -> Result<Vec<u8>, EncodeError> {
        encode_response(self)
    }
}
/// Encodes `data` as a single chunk: hex size, CRLF, the bytes, CRLF.
///
/// Empty `data` produces the terminating `0\r\n\r\n` sequence instead.
pub fn encode_chunk(data: &[u8]) -> Vec<u8> {
    if data.is_empty() {
        return b"0\r\n\r\n".to_vec();
    }
    // Framing is at most 16 hex digits plus two CRLFs, i.e. 20 bytes.
    let mut buf = data
        .len()
        .checked_add(20)
        .map_or_else(Vec::new, Vec::with_capacity);
    write_hex_usize(&mut buf, data.len());
    buf.extend_from_slice(b"\r\n");
    buf.extend_from_slice(data);
    buf.extend_from_slice(b"\r\n");
    buf
}
/// Encodes `chunks` as a complete chunked body.
///
/// Each non-empty slice becomes one chunk (hex size, CRLF, bytes, CRLF). The
/// body ends with the terminating `0\r\n\r\n` sequence.
///
/// Empty slices are skipped. A zero-size chunk is the last-chunk marker, so
/// writing one mid-stream would end the body early and leave the remaining
/// bytes to be parsed as a new message.
pub fn encode_chunks(chunks: &[&[u8]]) -> Vec<u8> {
    let mut buf = encode_chunks_capacity(chunks).map_or_else(Vec::new, Vec::with_capacity);
    for chunk in chunks.iter().filter(|chunk| !chunk.is_empty()) {
        write_hex_usize(&mut buf, chunk.len());
        buf.extend_from_slice(b"\r\n");
        buf.extend_from_slice(chunk);
        buf.extend_from_slice(b"\r\n");
    }
    buf.extend_from_slice(b"0\r\n\r\n");
    buf
}
/// Upper bound on the bytes `encode_chunks` writes for `chunks`, or `None` on
/// `usize` overflow.
///
/// Each chunk is counted as its length plus 20 bytes of framing, and the
/// 5-byte terminator is added once.
fn encode_chunks_capacity(chunks: &[&[u8]]) -> Option<usize> {
    const PER_CHUNK_FRAMING: usize = 20;
    const TERMINATOR_LEN: usize = 5;
    chunks.iter().try_fold(TERMINATOR_LEN, |total, chunk| {
        total.checked_add(chunk.len().checked_add(PER_CHUNK_FRAMING)?)
    })
}
pub fn encode_request_headers(request: &Request) -> Result<Vec<u8>, EncodeError> {
validate_request_fields(request)?;
reject_http_userinfo(&request.uri)?;
reject_http_empty_host(&request.uri)?;
validate_host_header(request)?;
if request.has_header("Transfer-Encoding") && request.has_header("Content-Length") {
return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
}
let mut buf = Vec::new();
buf.extend_from_slice(request.method.as_bytes());
buf.push(b' ');
buf.extend_from_slice(request.uri.as_bytes());
buf.push(b' ');
buf.extend_from_slice(request.version.as_bytes());
buf.extend_from_slice(b"\r\n");
for (name, value) in &request.headers {
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"\r\n");
Ok(buf)
}
pub fn encode_response_headers(response: &Response) -> Result<Vec<u8>, EncodeError> {
validate_response_fields(response)?;
if response.has_header("Transfer-Encoding") && response.has_header("Content-Length") {
return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
}
let is_1xx_or_204 = (100..200).contains(&response.status_code) || response.status_code == 204;
if is_1xx_or_204 && response.has_header("Transfer-Encoding") {
return Err(EncodeError::ForbiddenTransferEncoding {
status_code: response.status_code,
});
}
if is_1xx_or_204 && response.has_header("Content-Length") {
return Err(EncodeError::ForbiddenContentLength {
status_code: response.status_code,
});
}
if response.status_code == 205 && response.has_header("Transfer-Encoding") {
return Err(EncodeError::ForbiddenTransferEncoding { status_code: 205 });
}
if response.status_code == 205
&& let Some(cl) = response.get_header("Content-Length")
&& cl.trim() != "0"
{
return Err(EncodeError::ForbiddenContentLength { status_code: 205 });
}
let mut buf = Vec::new();
buf.extend_from_slice(response.version.as_bytes());
buf.push(b' ');
write_usize_decimal(&mut buf, response.status_code as usize);
buf.push(b' ');
buf.extend_from_slice(response.reason_phrase.as_bytes());
buf.extend_from_slice(b"\r\n");
for (name, value) in &response.headers {
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"\r\n");
Ok(buf)
}
impl Request {
    /// Serializes the request line and headers with [`encode_request_headers`].
    ///
    /// # Panics
    ///
    /// Panics on any error `encode_request_headers` reports, not only a missing
    /// `Host` header. The panic message includes the error. Use
    /// [`Request::try_encode_headers`] to handle failures.
    pub fn encode_headers(&self) -> Vec<u8> {
        encode_request_headers(self)
            .expect("invalid request (e.g. HTTP/1.1 request requires Host header)")
    }
    /// Serializes the request line and headers with [`encode_request_headers`],
    /// returning its error instead of panicking.
    pub fn try_encode_headers(&self) -> Result<Vec<u8>, EncodeError> {
        encode_request_headers(self)
    }
}
impl Response {
    /// Serializes the status line and headers with [`encode_response_headers`].
    ///
    /// # Panics
    ///
    /// Panics on any error `encode_response_headers` reports, including an
    /// invalid status line or header value. The panic message includes the
    /// error. Use [`Response::try_encode_headers`] to handle failures.
    pub fn encode_headers(&self) -> Vec<u8> {
        encode_response_headers(self).expect("invalid response or invalid header combination")
    }
    /// Serializes the status line and headers with [`encode_response_headers`],
    /// returning its error instead of panicking.
    pub fn try_encode_headers(&self) -> Result<Vec<u8>, EncodeError> {
        encode_response_headers(self)
    }
}
/// Wraps a [`Compressor`] for response bodies; defaults to [`NoCompression`].
///
/// Every operation forwards directly to the wrapped compressor.
#[derive(Debug)]
pub struct ResponseEncoder<C: Compressor = NoCompression> {
    compressor: C,
}
impl Default for ResponseEncoder<NoCompression> {
    fn default() -> Self {
        Self::new()
    }
}
impl ResponseEncoder<NoCompression> {
    /// Creates an encoder backed by [`NoCompression::new`].
    pub fn new() -> Self {
        Self::with_compressor(NoCompression::new())
    }
}
impl<C: Compressor> ResponseEncoder<C> {
    /// Creates an encoder around an existing `compressor`.
    pub fn with_compressor(compressor: C) -> Self {
        ResponseEncoder { compressor }
    }
    /// Forwards `input` and `output` to [`Compressor::compress`].
    pub fn compress_body(
        &mut self,
        input: &[u8],
        output: &mut [u8],
    ) -> Result<CompressionStatus, CompressionError> {
        let compressor = &mut self.compressor;
        compressor.compress(input, output)
    }
    /// Forwards `output` to [`Compressor::finish`].
    pub fn finish(&mut self, output: &mut [u8]) -> Result<CompressionStatus, CompressionError> {
        let compressor = &mut self.compressor;
        compressor.finish(output)
    }
    /// Forwards to [`Compressor::reset`].
    pub fn reset(&mut self) {
        let compressor = &mut self.compressor;
        compressor.reset();
    }
}
/// Wraps a [`Compressor`] for request bodies; defaults to [`NoCompression`].
///
/// Every operation forwards directly to the wrapped compressor.
#[derive(Debug)]
pub struct RequestEncoder<C: Compressor = NoCompression> {
    compressor: C,
}
impl Default for RequestEncoder<NoCompression> {
    fn default() -> Self {
        Self::new()
    }
}
impl RequestEncoder<NoCompression> {
    /// Creates an encoder backed by [`NoCompression::new`].
    pub fn new() -> Self {
        Self::with_compressor(NoCompression::new())
    }
}
impl<C: Compressor> RequestEncoder<C> {
    /// Creates an encoder around an existing `compressor`.
    pub fn with_compressor(compressor: C) -> Self {
        RequestEncoder { compressor }
    }
    /// Forwards `input` and `output` to [`Compressor::compress`].
    pub fn compress_body(
        &mut self,
        input: &[u8],
        output: &mut [u8],
    ) -> Result<CompressionStatus, CompressionError> {
        let compressor = &mut self.compressor;
        compressor.compress(input, output)
    }
    /// Forwards `output` to [`Compressor::finish`].
    pub fn finish(&mut self, output: &mut [u8]) -> Result<CompressionStatus, CompressionError> {
        let compressor = &mut self.compressor;
        compressor.finish(output)
    }
    /// Forwards to [`Compressor::reset`].
    pub fn reset(&mut self) {
        let compressor = &mut self.compressor;
        compressor.reset();
    }
}
// Capacity-estimate regression tests. Each case encodes a message and asserts
// that `estimate_*_capacity` is at least the number of bytes actually produced.
// Below `ENCODE_CAPACITY_LIMIT`, this means the buffer reserved by
// `allocate_encode_buffer` is large enough.
#[cfg(test)]
mod capacity_tests {
    use super::*;
    use crate::request::Request;
    use crate::response::Response;
    // Asserts the request estimate is an upper bound on the encoded length.
    fn assert_request_capacity_sufficient(req: &Request) {
        let est = estimate_request_capacity(req).expect("estimate overflow");
        let out = encode_request(req).expect("encode failed");
        assert!(
            est >= out.len(),
            "estimate {} < output {}: req={req:?}",
            est,
            out.len(),
        );
    }
    // Asserts the response estimate is an upper bound on the encoded length.
    fn assert_response_capacity_sufficient(res: &Response) {
        let est = estimate_response_capacity(res).expect("estimate overflow");
        let out = encode_response(res).expect("encode failed");
        assert!(
            est >= out.len(),
            "estimate {} < output {}: res={res:?}",
            est,
            out.len(),
        );
    }
    #[test]
    fn test_request_capacity_simple_get() {
        let req = Request::new("GET", "/").header("Host", "example.com");
        assert_request_capacity_sufficient(&req);
    }
    // Body without framing headers: the synthesized Content-Length must be counted.
    #[test]
    fn test_request_capacity_post_with_body_auto_content_length() {
        let req = Request::new("POST", "/api")
            .header("Host", "example.com")
            .body(b"hello world".to_vec());
        assert_request_capacity_sufficient(&req);
    }
    #[test]
    fn test_request_capacity_post_with_explicit_content_length() {
        let req = Request::new("POST", "/api")
            .header("Host", "example.com")
            .header("Content-Length", "11")
            .body(b"hello world".to_vec());
        assert_request_capacity_sufficient(&req);
    }
    // Transfer-Encoding suppresses the synthesized Content-Length.
    #[test]
    fn test_request_capacity_post_with_transfer_encoding_no_auto() {
        let req = Request::new("POST", "/api")
            .header("Host", "example.com")
            .header("Transfer-Encoding", "chunked")
            .body(b"hello".to_vec());
        assert_request_capacity_sufficient(&req);
    }
    #[test]
    fn test_request_capacity_many_headers() {
        let mut req = Request::new("GET", "/").header("Host", "example.com");
        for i in 0..50 {
            req = req.header(
                &alloc::format!("X-Custom-{i}"),
                &alloc::format!("value-{i}-with-some-padding"),
            );
        }
        assert_request_capacity_sufficient(&req);
    }
    // An empty body still produces `Content-Length: 0`.
    #[test]
    fn test_request_capacity_empty_body_auto_content_length_zero() {
        let req = Request::new("POST", "/")
            .header("Host", "example.com")
            .body(Vec::new());
        assert_request_capacity_sufficient(&req);
    }
    #[test]
    fn test_request_capacity_no_body() {
        let req = Request::new("GET", "/path/to/resource?q=1").header("Host", "example.com");
        assert_request_capacity_sufficient(&req);
    }
    #[test]
    fn test_response_capacity_simple_ok() {
        let res = Response::new(200, "OK").body(b"hello".to_vec());
        assert_response_capacity_sufficient(&res);
    }
    // Statuses that never carry a body.
    #[test]
    fn test_response_capacity_no_body_status() {
        for &code in &[100u16, 204, 304] {
            let res = Response::new(code, "Reason");
            assert_response_capacity_sufficient(&res);
        }
    }
    // Omitted body: the declared length is not checked against missing body bytes.
    #[test]
    fn test_response_capacity_omit_body_with_content_length() {
        let res = Response::new(200, "OK")
            .header("Content-Length", "100")
            .omit_body(true);
        assert_response_capacity_sufficient(&res);
    }
    #[test]
    fn test_response_capacity_with_transfer_encoding() {
        let res = Response::new(200, "OK").header("Transfer-Encoding", "chunked");
        assert_response_capacity_sufficient(&res);
    }
    #[test]
    fn test_response_capacity_many_headers() {
        let mut res = Response::new(200, "OK").body(vec![b'X'; 1024]);
        for i in 0..50 {
            res = res.header(
                &alloc::format!("X-Custom-{i}"),
                &alloc::format!("value-{i}-with-some-padding"),
            );
        }
        assert_response_capacity_sufficient(&res);
    }
    // The estimate reserves 3 bytes for the status code.
    #[test]
    fn test_response_capacity_status_code_3_digit_boundary() {
        for &code in &[100u16, 200, 599] {
            let res = Response::new(code, "Phrase").body(b"body".to_vec());
            assert_response_capacity_sufficient(&res);
        }
    }
    #[test]
    fn test_request_capacity_normal_input_does_not_panic() {
        let req = Request::new("GET", "/").header("Host", "example.com");
        let _ = encode_request(&req).unwrap();
    }
}