use crate::compression::{CompressionError, CompressionStatus, Compressor, NoCompression};
use crate::error::EncodeError;
use crate::host::Host;
use crate::request::Request;
use crate::response::Response;
use crate::validate::{
is_valid_field_value, is_valid_header_name, is_valid_method, is_valid_reason_phrase,
is_valid_request_target, is_valid_status_code, is_valid_version_for_encode,
};
/// Validates every request field that ends up on the request line, plus the header list.
///
/// Checks run in this order and the first failure is returned:
/// method, request-target syntax together with its raw byte range (any byte above
/// `0x7E`, i.e. DEL or non-ASCII, is rejected), request-target form for the method,
/// HTTP version, and finally each header field.
///
/// # Errors
/// `InvalidMethod`, `InvalidRequestTarget`, any error from the target-form check,
/// `InvalidVersion`, or any error from header validation.
fn validate_request_fields(request: &Request) -> Result<(), EncodeError> {
    let method = &request.method;
    let target = &request.uri;

    if !is_valid_method(method) {
        return Err(EncodeError::InvalidMethod {
            method: method.clone(),
        });
    }

    // Syntax check and byte-range check share one error variant, so combine them.
    let target_is_acceptable =
        is_valid_request_target(target) && target.bytes().all(|b| b <= 0x7E);
    if !target_is_acceptable {
        return Err(EncodeError::InvalidRequestTarget {
            uri: target.clone(),
        });
    }

    validate_request_target_form(method, target)?;

    if !is_valid_version_for_encode(&request.version) {
        return Err(EncodeError::InvalidVersion {
            version: request.version.clone(),
        });
    }

    validate_headers(&request.headers)
}
/// Classifies a request-target into one of the RFC 9112 §3.2 forms.
///
/// Checks run in this order:
/// 1. exactly `*` → asterisk-form;
/// 2. a leading `/` → origin-form. This must be tested before looking for `://`,
///    otherwise an origin-form target whose query embeds a URL
///    (e.g. `/login?next=https://example.com/`) is misclassified as absolute-form;
/// 3. anything containing `://` → absolute-form;
/// 4. a `host:port` shape → authority-form;
/// 5. any other `scheme:` prefix (e.g. `urn:...`) → absolute-form.
///
/// # Errors
/// Returns `EncodeError::InvalidRequestTarget` when none of the forms match.
fn detect_request_target_form(uri: &str) -> Result<RequestTargetForm, EncodeError> {
    if uri == "*" {
        Ok(RequestTargetForm::Asterisk)
    } else if uri.starts_with('/') {
        // Absolute-form always starts with a scheme letter, never with '/'.
        Ok(RequestTargetForm::Origin)
    } else if uri.contains("://") {
        Ok(RequestTargetForm::Absolute)
    } else if looks_like_authority_form(uri) {
        Ok(RequestTargetForm::Authority)
    } else if detect_scheme(uri).is_some() {
        Ok(RequestTargetForm::Absolute)
    } else {
        Err(EncodeError::InvalidRequestTarget {
            uri: uri.to_string(),
        })
    }
}
/// Reports whether `uri` has the `host:port` shape of an authority-form target.
///
/// The shape is accepted only when `uri` contains no `@` (userinfo is not permitted),
/// the text before the last `:` is non-empty, and the text after it consists solely of
/// ASCII digits whose value fits in a `u16`. The host itself is not validated here.
fn looks_like_authority_form(uri: &str) -> bool {
    if uri.contains('@') {
        return false;
    }
    let Some((host, port)) = uri.rsplit_once(':') else {
        return false;
    };
    // The explicit digit check matters: `u16::from_str` would also accept a leading '+'.
    !host.is_empty()
        && !port.is_empty()
        && port.bytes().all(|b| b.is_ascii_digit())
        && port.parse::<u16>().is_ok()
}
/// Validates the host portion of an authority-form target (`host:port`).
///
/// The host is everything before the last `:` and must be accepted by [`Host::parse`].
/// A target without any `:` has no host/port split and is accepted unchanged.
///
/// # Errors
/// Returns `EncodeError::InvalidRequestTarget` (carrying the whole target) when the
/// host does not parse.
fn validate_encoder_authority_form(uri: &str) -> Result<(), EncodeError> {
    let Some((host, _port)) = uri.rsplit_once(':') else {
        return Ok(());
    };
    match Host::parse(host) {
        Ok(_) => Ok(()),
        Err(_) => Err(EncodeError::InvalidRequestTarget {
            uri: uri.to_string(),
        }),
    }
}
/// Locates the `:` ending a URI scheme and returns its byte index.
///
/// The scheme (the text before the first `:`) must match RFC 3986 §3.1:
/// `ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`. Returns `None` when there is no `:`,
/// the scheme is empty or malformed, or nothing follows the `:`.
fn detect_scheme(target: &str) -> Option<usize> {
    let colon = target.find(':')?;
    let scheme = &target[..colon];
    let remainder = &target[colon + 1..];

    let mut scheme_bytes = scheme.bytes();
    let leads_with_letter = scheme_bytes
        .next()
        .is_some_and(|b| b.is_ascii_alphabetic());
    let tail_is_valid =
        scheme_bytes.all(|b| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'-' | b'.'));

    (leads_with_letter && tail_is_valid && !remainder.is_empty()).then_some(colon)
}
/// The four request-target forms defined by RFC 9112 §3.2.
///
/// Deriving `Copy`/`PartialEq` lets callers compare forms with `==` or `matches!`
/// instead of only destructuring them; `Debug` makes them printable in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RequestTargetForm {
    /// A path with optional query, e.g. `/index.html?x=1`.
    Origin,
    /// A full URI with a scheme, e.g. `http://example.com/` or `urn:x`.
    Absolute,
    /// `host:port`; the encoder accepts it only for CONNECT.
    Authority,
    /// A lone `*`; the encoder accepts it only for OPTIONS.
    Asterisk,
}
/// Checks that the request-target form is permitted for `method`.
///
/// Rules:
/// - CONNECT requires authority-form, whose host must also parse;
/// - asterisk-form is only allowed with OPTIONS;
/// - authority-form is not allowed for any other method;
/// - origin-form must not contain `[` or `]`;
/// - absolute-form `http`/`https` targets must use `scheme://`.
///
/// # Errors
/// Any error from form detection, `InvalidRequestTargetForm` for a disallowed
/// method/form pair, `InvalidRequestTarget` for bracketed origin-form, or any error
/// from the authority / absolute-form helpers.
fn validate_request_target_form(method: &str, uri: &str) -> Result<(), EncodeError> {
    let form_error = || EncodeError::InvalidRequestTargetForm {
        method: method.to_string(),
        uri: uri.to_string(),
    };
    let is_connect = method == "CONNECT";

    match detect_request_target_form(uri)? {
        RequestTargetForm::Authority if is_connect => validate_encoder_authority_form(uri),
        _ if is_connect => Err(form_error()),
        RequestTargetForm::Asterisk if method == "OPTIONS" => Ok(()),
        RequestTargetForm::Asterisk | RequestTargetForm::Authority => Err(form_error()),
        RequestTargetForm::Origin if uri.bytes().any(|b| matches!(b, b'[' | b']')) => {
            Err(EncodeError::InvalidRequestTarget {
                uri: uri.to_string(),
            })
        }
        RequestTargetForm::Origin => Ok(()),
        RequestTargetForm::Absolute => reject_http_without_authority_prefix(uri),
    }
}
/// Validates the status-line fields of a response, then its headers.
///
/// Order: HTTP version, status code, reason phrase, header fields; the first
/// failure is returned.
///
/// # Errors
/// `InvalidVersion`, `InvalidStatusCode`, `InvalidReasonPhrase`, or any error from
/// header validation.
fn validate_response_fields(response: &Response) -> Result<(), EncodeError> {
    let version = &response.version;
    if !is_valid_version_for_encode(version) {
        return Err(EncodeError::InvalidVersion {
            version: version.clone(),
        });
    }

    let code = response.status_code;
    if !is_valid_status_code(code) {
        return Err(EncodeError::InvalidStatusCode { code });
    }

    let phrase = &response.reason_phrase;
    if !is_valid_reason_phrase(phrase) {
        return Err(EncodeError::InvalidReasonPhrase {
            phrase: phrase.clone(),
        });
    }

    validate_headers(&response.headers)
}
/// Validates each `(name, value)` header pair in order.
///
/// For each pair the name is checked before the value; the first failure is returned.
///
/// # Errors
/// `InvalidHeaderName` or `InvalidHeaderValue` for the first offending header.
fn validate_headers(headers: &[(String, String)]) -> Result<(), EncodeError> {
    headers.iter().try_for_each(|(name, value)| {
        if !is_valid_header_name(name) {
            Err(EncodeError::InvalidHeaderName { name: name.clone() })
        } else if !is_valid_field_value(value) {
            Err(EncodeError::InvalidHeaderValue {
                name: name.clone(),
                value: value.clone(),
            })
        } else {
            Ok(())
        }
    })
}
/// Enforces Host-header rules for HTTP/1.1 requests.
///
/// Requests whose version is not exactly `HTTP/1.1` are accepted unchanged. For
/// HTTP/1.1 the checks are, in order:
/// - exactly one `Host` header (name compared ASCII case-insensitively);
/// - a non-empty Host value must parse with [`Host::parse`];
/// - for a target that is not origin-form and contains `scheme://`, the URI authority
///   (with any userinfo stripped) must equal the Host value, ignoring ASCII case;
/// - for CONNECT, the Host value must be non-empty, its host must equal the target
///   host (ignoring ASCII case), and, when both sides carry a port, the ports must match;
/// - an absolute-form target without `://` (e.g. `urn:...`) requires an empty Host.
///
/// # Errors
/// `MissingHostHeader`, `DuplicateHostHeader`, `InvalidHostHeader`,
/// `HostAuthorityMismatch`, or `NonEmptyHostWithoutAuthority`.
fn validate_host_header(request: &Request) -> Result<(), EncodeError> {
    if request.version != "HTTP/1.1" {
        return Ok(());
    }
    let host_headers: Vec<&str> = request
        .headers
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("Host"))
        .map(|(_, value)| value.as_str())
        .collect();
    if host_headers.is_empty() {
        return Err(EncodeError::MissingHostHeader);
    }
    if host_headers.len() > 1 {
        return Err(EncodeError::DuplicateHostHeader);
    }
    let host_value = host_headers[0];
    if !host_value.is_empty() && Host::parse(host_value).is_err() {
        return Err(EncodeError::InvalidHostHeader {
            value: host_value.to_string(),
        });
    }
    // Only a target that itself carries a `scheme://authority` prefix has an authority to
    // compare. A bare `contains("://")` also fires for origin-form targets whose query
    // embeds a URL (e.g. `/login?next=https://other/`), which would then be rejected as a
    // Host/authority mismatch even though such targets have no authority at all.
    let target_has_scheme_authority =
        !request.uri.starts_with('/') && request.uri.contains("://");
    if target_has_scheme_authority
        && let Some(authority) = extract_authority_from_uri(&request.uri)
        && !authority.is_empty()
        && !authority.eq_ignore_ascii_case(host_value)
    {
        return Err(EncodeError::HostAuthorityMismatch {
            host: host_value.to_string(),
            authority: authority.to_string(),
        });
    }
    if request.method == "CONNECT"
        && let Some(colon_pos) = request.uri.rfind(':')
    {
        let target_host = &request.uri[..colon_pos];
        let target_port_str = &request.uri[colon_pos + 1..];
        if host_value.is_empty() {
            return Err(EncodeError::HostAuthorityMismatch {
                host: host_value.to_string(),
                authority: request.uri.clone(),
            });
        }
        if let Ok(parsed_host) = Host::parse(host_value) {
            if !parsed_host.host().eq_ignore_ascii_case(target_host) {
                return Err(EncodeError::HostAuthorityMismatch {
                    host: host_value.to_string(),
                    authority: request.uri.clone(),
                });
            }
            // Ports are compared only when the Host value states one explicitly.
            if let Some(host_port) = parsed_host.port()
                && let Ok(target_port) = target_port_str.parse::<u16>()
                && host_port != target_port
            {
                return Err(EncodeError::HostAuthorityMismatch {
                    host: host_value.to_string(),
                    authority: request.uri.clone(),
                });
            }
        }
    }
    // Absolute-form without `://` has no authority, so Host must be sent empty.
    if let Ok(RequestTargetForm::Absolute) = detect_request_target_form(&request.uri)
        && !request.uri.contains("://")
        && !host_value.is_empty()
    {
        return Err(EncodeError::NonEmptyHostWithoutAuthority {
            host: host_value.to_string(),
            uri: request.uri.clone(),
        });
    }
    Ok(())
}
/// Rejects `http`/`https` URIs whose authority carries userinfo (`user@host`).
///
/// The authority is the text after the first `://`, up to the first `/` or `?`.
/// Targets with any other scheme, or without `://`, are accepted unchanged.
///
/// # Errors
/// Returns `EncodeError::UserinfoInHttpUri` when the authority contains `@`.
fn reject_http_userinfo(uri: &str) -> Result<(), EncodeError> {
    let lowered = uri.to_ascii_lowercase();
    let is_http = ["http://", "https://"]
        .into_iter()
        .any(|prefix| lowered.starts_with(prefix));
    if !is_http {
        return Ok(());
    }
    let Some((_, rest)) = uri.split_once("://") else {
        return Ok(());
    };
    let authority = rest.split(['/', '?']).next().unwrap_or(rest);
    if authority.contains('@') {
        Err(EncodeError::UserinfoInHttpUri {
            uri: uri.to_string(),
        })
    } else {
        Ok(())
    }
}
/// Rejects `http`/`https` URIs whose authority has an empty host.
///
/// The authority is the text after the first `://`, up to the first `/` or `?`. Any
/// userinfo (up to the last `@`) is stripped. The remainder must be non-empty and must
/// not begin with `:` (a port without a host). Other schemes pass through unchanged.
///
/// # Errors
/// Returns `EncodeError::EmptyHostInHttpUri` when the host is missing.
fn reject_http_empty_host(uri: &str) -> Result<(), EncodeError> {
    let lowered = uri.to_ascii_lowercase();
    if !(lowered.starts_with("http://") || lowered.starts_with("https://")) {
        return Ok(());
    }
    let Some((_, rest)) = uri.split_once("://") else {
        return Ok(());
    };
    let authority = rest.split(['/', '?']).next().unwrap_or(rest);
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, after_at)| after_at);
    if host_port.is_empty() || host_port.starts_with(':') {
        return Err(EncodeError::EmptyHostInHttpUri {
            uri: uri.to_string(),
        });
    }
    Ok(())
}
/// Rejects targets whose scheme is `http`/`https` (ASCII case-insensitive) but whose
/// first `:` is not followed by `//`, e.g. `http:example.com`.
///
/// Targets without a `:` or with any other scheme are accepted.
///
/// # Errors
/// Returns `EncodeError::InvalidRequestTarget` for such targets.
fn reject_http_without_authority_prefix(uri: &str) -> Result<(), EncodeError> {
    let Some((scheme, after_colon)) = uri.split_once(':') else {
        return Ok(());
    };
    let is_http_scheme =
        scheme.eq_ignore_ascii_case("http") || scheme.eq_ignore_ascii_case("https");
    if is_http_scheme && !after_colon.starts_with("//") {
        Err(EncodeError::InvalidRequestTarget {
            uri: uri.to_string(),
        })
    } else {
        Ok(())
    }
}
/// Parses every `Content-Length` header (name matched ASCII case-insensitively).
///
/// Each value, after trimming surrounding whitespace, must be a non-empty run of ASCII
/// digits that fits in `u64`. When several headers are present they must all carry the
/// same number.
///
/// Returns `Ok(None)` when no `Content-Length` header exists, otherwise the agreed value.
///
/// # Errors
/// `InvalidContentLengthValue` (carrying the untrimmed value) for a malformed value, or
/// `DuplicateContentLength` when two headers disagree. Headers are processed in order
/// and the first problem is reported.
fn validate_content_length_headers(
    headers: &[(String, String)],
) -> Result<Option<u64>, EncodeError> {
    headers
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("Content-Length"))
        .try_fold(
            None::<u64>,
            |seen, (_, raw)| -> Result<Option<u64>, EncodeError> {
                let invalid = || EncodeError::InvalidContentLengthValue { value: raw.clone() };
                let digits = raw.trim();
                // Digit-only check first: `u64::from_str` would also accept a leading '+'.
                if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
                    return Err(invalid());
                }
                let parsed = digits.parse::<u64>().map_err(|_| invalid())?;
                match seen {
                    Some(previous) if previous != parsed => {
                        Err(EncodeError::DuplicateContentLength)
                    }
                    _ => Ok(Some(parsed)),
                }
            },
        )
}
/// Extracts the `host[:port]` portion of a `scheme://authority/...` URI.
///
/// The authority runs from just after the first `://` to the first `/` or `?`. Any
/// userinfo up to the last `@` is removed. Returns `None` when `uri` has no `://`; the
/// result may be an empty string (e.g. for `http://`).
fn extract_authority_from_uri(uri: &str) -> Option<String> {
    let (_, rest) = uri.split_once("://")?;
    let authority = rest.split(['/', '?']).next().unwrap_or(rest);
    let host_port = match authority.rsplit_once('@') {
        Some((_userinfo, after_at)) => after_at,
        None => authority,
    };
    Some(host_port.to_owned())
}
/// Validates a request and serializes it as request line, headers, blank line, body.
///
/// Validation covers the request line and headers, userinfo / empty-host rules for
/// `http(s)` targets, Host-header rules, the `Transfer-Encoding` + `Content-Length`
/// conflict, CONNECT requests carrying content, and agreement between a declared
/// `Content-Length` and the body length (skipped when `Transfer-Encoding` is present).
///
/// Headers are written in the caller's order as `name: value`. When the body is non-empty
/// and neither framing header is present, a `Content-Length` header is appended. The body
/// bytes are written verbatim.
///
/// # Errors
/// Returns the first [`EncodeError`] raised by the checks above.
pub fn encode_request(request: &Request) -> Result<Vec<u8>, EncodeError> {
    /// Appends each byte slice to `out` in order.
    fn push_all(out: &mut Vec<u8>, parts: &[&[u8]]) {
        parts.iter().for_each(|part| out.extend_from_slice(part));
    }

    validate_request_fields(request)?;
    reject_http_userinfo(&request.uri)?;
    reject_http_empty_host(&request.uri)?;
    validate_host_header(request)?;

    let has_transfer_encoding = request.has_header("Transfer-Encoding");
    let has_content_length = request.has_header("Content-Length");
    if has_transfer_encoding && has_content_length {
        return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
    }
    if request.method == "CONNECT"
        && (has_content_length || has_transfer_encoding || !request.body.is_empty())
    {
        return Err(EncodeError::ConnectRequestWithContent);
    }
    if !has_transfer_encoding
        && let Some(header_value) = validate_content_length_headers(&request.headers)?
    {
        let body_length = request.body.len() as u64;
        if header_value != body_length {
            return Err(EncodeError::ContentLengthMismatch {
                header_value,
                body_length,
            });
        }
    }

    let mut out = Vec::new();
    push_all(
        &mut out,
        &[
            request.method.as_bytes(),
            b" ",
            request.uri.as_bytes(),
            b" ",
            request.version.as_bytes(),
            b"\r\n",
        ],
    );
    for (name, value) in &request.headers {
        push_all(&mut out, &[name.as_bytes(), b": ", value.as_bytes(), b"\r\n"]);
    }
    if !request.body.is_empty() && !has_content_length && !has_transfer_encoding {
        let length_text = request.body.len().to_string();
        push_all(
            &mut out,
            &[b"Content-Length: ", length_text.as_bytes(), b"\r\n"],
        );
    }
    out.extend_from_slice(b"\r\n");
    out.extend_from_slice(&request.body);
    Ok(out)
}
/// Validates a response and serializes it as status line, headers, blank line, body.
///
/// Framing rules enforced:
/// - `Transfer-Encoding` and `Content-Length` must not both be present;
/// - 1xx and 204 responses may carry neither framing header;
/// - 205 responses must have an empty body, no `Transfer-Encoding`, and any
///   `Content-Length` must be `0`;
/// - every `Content-Length` value must be syntactically valid (also on 304, where it is
///   written to the wire even though no body follows);
/// - for statuses that carry content, a declared `Content-Length` must equal the body
///   length whenever the body will be written or is non-empty.
///
/// When the status carries content, no framing header is present, and the body is either
/// going to be written or non-empty, a `Content-Length` header is appended. The body is
/// written only for statuses with content and when `omit_body` is false.
///
/// # Errors
/// Returns the first [`EncodeError`] raised by the checks above.
pub fn encode_response(response: &Response) -> Result<Vec<u8>, EncodeError> {
    validate_response_fields(response)?;
    if response.has_header("Transfer-Encoding") && response.has_header("Content-Length") {
        return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
    }
    let is_1xx_or_204 = (100..200).contains(&response.status_code) || response.status_code == 204;
    if is_1xx_or_204 && response.has_header("Transfer-Encoding") {
        return Err(EncodeError::ForbiddenTransferEncoding {
            status_code: response.status_code,
        });
    }
    if is_1xx_or_204 && response.has_header("Content-Length") {
        return Err(EncodeError::ForbiddenContentLength {
            status_code: response.status_code,
        });
    }
    if response.status_code == 205 {
        if !response.body.is_empty() {
            return Err(EncodeError::ForbiddenBodyFor205);
        }
        if response.has_header("Transfer-Encoding") {
            return Err(EncodeError::ForbiddenTransferEncoding { status_code: 205 });
        }
        if let Some(cl) = response.get_header("Content-Length")
            && cl.trim() != "0"
        {
            return Err(EncodeError::ForbiddenContentLength { status_code: 205 });
        }
    }
    let status_has_body = !((100..200).contains(&response.status_code)
        || response.status_code == 204
        || response.status_code == 304);
    // `omit_body` suppresses the body bytes while the headers are still written
    // (presumably for responses to HEAD — TODO confirm against callers).
    let body_will_be_encoded = status_has_body && !response.omit_body;
    if !response.has_header("Transfer-Encoding") {
        // Check Content-Length syntax for every status (notably 304, where the header is
        // still emitted), but only compare it with the body when the status has content.
        let declared_length = validate_content_length_headers(&response.headers)?;
        if status_has_body && let Some(header_value) = declared_length {
            let body_length = response.body.len() as u64;
            // With the body omitted and empty, the header describes a body we don't hold.
            let should_validate = body_will_be_encoded || body_length != 0;
            if should_validate && header_value != body_length {
                return Err(EncodeError::ContentLengthMismatch {
                    header_value,
                    body_length,
                });
            }
        }
    }
    let mut buf = Vec::new();
    buf.extend_from_slice(response.version.as_bytes());
    buf.push(b' ');
    buf.extend_from_slice(response.status_code.to_string().as_bytes());
    buf.push(b' ');
    buf.extend_from_slice(response.reason_phrase.as_bytes());
    buf.extend_from_slice(b"\r\n");
    for (name, value) in &response.headers {
        buf.extend_from_slice(name.as_bytes());
        buf.extend_from_slice(b": ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }
    if status_has_body
        && (!response.omit_body || !response.body.is_empty())
        && !response.has_header("Content-Length")
        && !response.has_header("Transfer-Encoding")
    {
        buf.extend_from_slice(b"Content-Length: ");
        buf.extend_from_slice(response.body.len().to_string().as_bytes());
        buf.extend_from_slice(b"\r\n");
    }
    buf.extend_from_slice(b"\r\n");
    if body_will_be_encoded {
        buf.extend_from_slice(&response.body);
    }
    Ok(buf)
}
impl Request {
    /// Validates and encodes the full request; see [`encode_request`].
    ///
    /// # Errors
    /// Any [`EncodeError`] reported by [`encode_request`].
    pub fn try_encode(&self) -> Result<Vec<u8>, EncodeError> {
        encode_request(self)
    }

    /// Validates and encodes the full request, panicking on failure.
    ///
    /// # Panics
    /// Panics whenever [`encode_request`] returns an error. NOTE(review): the panic
    /// message mentions only the Host header, but any `EncodeError` triggers it (the
    /// actual error is appended by `expect`). Prefer [`Request::try_encode`].
    pub fn encode(&self) -> Vec<u8> {
        self.try_encode()
            .expect("HTTP/1.1 request requires Host header")
    }
}
impl Response {
    /// Validates and encodes the full response; see [`encode_response`].
    ///
    /// # Errors
    /// Any [`EncodeError`] reported by [`encode_response`].
    pub fn try_encode(&self) -> Result<Vec<u8>, EncodeError> {
        encode_response(self)
    }

    /// Validates and encodes the full response, panicking on failure.
    ///
    /// # Panics
    /// Panics whenever [`encode_response`] returns an error, including failures that are
    /// not about header combinations; the actual error is appended by `expect`.
    /// Prefer [`Response::try_encode`].
    pub fn encode(&self) -> Vec<u8> {
        self.try_encode().expect("invalid header combination")
    }
}
/// Frames `data` as a single chunk of the chunked transfer coding.
///
/// A non-empty slice becomes `<hex length>\r\n<data>\r\n`. An empty slice produces
/// `0\r\n\r\n`, i.e. the last-chunk followed by an empty trailer section, which ends
/// the body.
pub fn encode_chunk(data: &[u8]) -> Vec<u8> {
    if data.is_empty() {
        return b"0\r\n\r\n".to_vec();
    }
    let mut framed = format!("{:x}\r\n", data.len()).into_bytes();
    framed.extend_from_slice(data);
    framed.extend_from_slice(b"\r\n");
    framed
}
/// Encodes a sequence of data slices as a complete chunked body.
///
/// Each non-empty slice becomes `<hex length>\r\n<data>\r\n`, and the body is closed
/// with the last-chunk and an empty trailer section (`0\r\n\r\n`).
///
/// Empty slices are skipped. A zero-size chunk *is* the last-chunk marker in the HTTP/1.1
/// chunked coding (RFC 9112 §7.1). Emitting one mid-stream would end the body early and
/// leave the remaining chunks to be misread as the start of the next message.
pub fn encode_chunks(chunks: &[&[u8]]) -> Vec<u8> {
    let mut buf = Vec::new();
    for chunk in chunks.iter().filter(|chunk| !chunk.is_empty()) {
        buf.extend_from_slice(format!("{:x}\r\n", chunk.len()).as_bytes());
        buf.extend_from_slice(chunk);
        buf.extend_from_slice(b"\r\n");
    }
    buf.extend_from_slice(b"0\r\n\r\n");
    buf
}
pub fn encode_request_headers(request: &Request) -> Result<Vec<u8>, EncodeError> {
validate_request_fields(request)?;
reject_http_userinfo(&request.uri)?;
reject_http_empty_host(&request.uri)?;
validate_host_header(request)?;
if request.has_header("Transfer-Encoding") && request.has_header("Content-Length") {
return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
}
if request.method == "CONNECT"
&& (request.has_header("Content-Length") || request.has_header("Transfer-Encoding"))
{
return Err(EncodeError::ConnectRequestWithContent);
}
let mut buf = Vec::new();
buf.extend_from_slice(request.method.as_bytes());
buf.push(b' ');
buf.extend_from_slice(request.uri.as_bytes());
buf.push(b' ');
buf.extend_from_slice(request.version.as_bytes());
buf.extend_from_slice(b"\r\n");
for (name, value) in &request.headers {
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"\r\n");
Ok(buf)
}
pub fn encode_response_headers(response: &Response) -> Result<Vec<u8>, EncodeError> {
validate_response_fields(response)?;
if response.has_header("Transfer-Encoding") && response.has_header("Content-Length") {
return Err(EncodeError::ConflictingTransferEncodingAndContentLength);
}
let is_1xx_or_204 = (100..200).contains(&response.status_code) || response.status_code == 204;
if is_1xx_or_204 && response.has_header("Transfer-Encoding") {
return Err(EncodeError::ForbiddenTransferEncoding {
status_code: response.status_code,
});
}
if is_1xx_or_204 && response.has_header("Content-Length") {
return Err(EncodeError::ForbiddenContentLength {
status_code: response.status_code,
});
}
if response.status_code == 205 && response.has_header("Transfer-Encoding") {
return Err(EncodeError::ForbiddenTransferEncoding { status_code: 205 });
}
if response.status_code == 205
&& let Some(cl) = response.get_header("Content-Length")
&& cl.trim() != "0"
{
return Err(EncodeError::ForbiddenContentLength { status_code: 205 });
}
let mut buf = Vec::new();
buf.extend_from_slice(response.version.as_bytes());
buf.push(b' ');
buf.extend_from_slice(response.status_code.to_string().as_bytes());
buf.push(b' ');
buf.extend_from_slice(response.reason_phrase.as_bytes());
buf.extend_from_slice(b"\r\n");
for (name, value) in &response.headers {
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"\r\n");
Ok(buf)
}
impl Request {
    /// Validates and encodes only the request line and headers; see
    /// [`encode_request_headers`].
    ///
    /// # Errors
    /// Any [`EncodeError`] reported by [`encode_request_headers`].
    pub fn try_encode_headers(&self) -> Result<Vec<u8>, EncodeError> {
        encode_request_headers(self)
    }

    /// Validates and encodes only the request line and headers, panicking on failure.
    ///
    /// # Panics
    /// Panics whenever [`encode_request_headers`] returns an error. NOTE(review): the
    /// message mentions only the Host header, but any `EncodeError` triggers it (the
    /// actual error is appended by `expect`). Prefer [`Request::try_encode_headers`].
    pub fn encode_headers(&self) -> Vec<u8> {
        self.try_encode_headers()
            .expect("HTTP/1.1 request requires Host header")
    }
}
impl Response {
    /// Validates and encodes only the status line and headers; see
    /// [`encode_response_headers`].
    ///
    /// # Errors
    /// Any [`EncodeError`] reported by [`encode_response_headers`].
    pub fn try_encode_headers(&self) -> Result<Vec<u8>, EncodeError> {
        encode_response_headers(self)
    }

    /// Validates and encodes only the status line and headers, panicking on failure.
    ///
    /// # Panics
    /// Panics whenever [`encode_response_headers`] returns an error; the actual error is
    /// appended by `expect`. Prefer [`Response::try_encode_headers`].
    pub fn encode_headers(&self) -> Vec<u8> {
        self.try_encode_headers()
            .expect("invalid header combination")
    }
}
/// Response-side wrapper around a [`Compressor`]; every method forwards to it.
///
/// Defaults to [`NoCompression`]; use [`ResponseEncoder::with_compressor`] to supply a
/// different implementation.
#[derive(Debug)]
pub struct ResponseEncoder<C: Compressor = NoCompression> {
    compressor: C,
}

impl Default for ResponseEncoder<NoCompression> {
    fn default() -> Self {
        Self::new()
    }
}

impl ResponseEncoder<NoCompression> {
    /// Creates an encoder backed by [`NoCompression::new`].
    pub fn new() -> Self {
        Self::with_compressor(NoCompression::new())
    }
}

impl<C: Compressor> ResponseEncoder<C> {
    /// Creates an encoder that uses `compressor`.
    pub fn with_compressor(compressor: C) -> Self {
        Self { compressor }
    }

    /// Forwards `input` and `output` to the compressor's `compress`.
    ///
    /// # Errors
    /// Whatever [`CompressionError`] the compressor reports.
    pub fn compress_body(
        &mut self,
        input: &[u8],
        output: &mut [u8],
    ) -> Result<CompressionStatus, CompressionError> {
        let Self { compressor } = self;
        compressor.compress(input, output)
    }

    /// Forwards `output` to the compressor's `finish`.
    ///
    /// # Errors
    /// Whatever [`CompressionError`] the compressor reports.
    pub fn finish(&mut self, output: &mut [u8]) -> Result<CompressionStatus, CompressionError> {
        let Self { compressor } = self;
        compressor.finish(output)
    }

    /// Calls the compressor's `reset`.
    pub fn reset(&mut self) {
        let Self { compressor } = self;
        compressor.reset();
    }
}
/// Request-side wrapper around a [`Compressor`]; every method forwards to it.
///
/// Defaults to [`NoCompression`]; use [`RequestEncoder::with_compressor`] to supply a
/// different implementation.
#[derive(Debug)]
pub struct RequestEncoder<C: Compressor = NoCompression> {
    compressor: C,
}

impl Default for RequestEncoder<NoCompression> {
    fn default() -> Self {
        Self::new()
    }
}

impl RequestEncoder<NoCompression> {
    /// Creates an encoder backed by [`NoCompression::new`].
    pub fn new() -> Self {
        Self::with_compressor(NoCompression::new())
    }
}

impl<C: Compressor> RequestEncoder<C> {
    /// Creates an encoder that uses `compressor`.
    pub fn with_compressor(compressor: C) -> Self {
        Self { compressor }
    }

    /// Forwards `input` and `output` to the compressor's `compress`.
    ///
    /// # Errors
    /// Whatever [`CompressionError`] the compressor reports.
    pub fn compress_body(
        &mut self,
        input: &[u8],
        output: &mut [u8],
    ) -> Result<CompressionStatus, CompressionError> {
        let Self { compressor } = self;
        compressor.compress(input, output)
    }

    /// Forwards `output` to the compressor's `finish`.
    ///
    /// # Errors
    /// Whatever [`CompressionError`] the compressor reports.
    pub fn finish(&mut self, output: &mut [u8]) -> Result<CompressionStatus, CompressionError> {
        let Self { compressor } = self;
        compressor.finish(output)
    }

    /// Calls the compressor's `reset`.
    pub fn reset(&mut self) {
        let Self { compressor } = self;
        compressor.reset();
    }
}