use crate::connection::Role;
use crate::error::{Error, ErrorCode};
use crate::qpack::Header;
/// Lowercase names of connection-specific header fields (per RFC 9114 §4.2 these must
/// not appear in HTTP/3 messages). The request, response and trailer validators in this
/// module compare field names against this list byte-for-byte and reject any match as
/// `MessageError`.
const CONNECTION_SPECIFIC_FIELDS: &[&[u8]] = &[
b"connection",
b"keep-alive",
b"proxy-connection",
b"transfer-encoding",
b"upgrade",
];
/// Returns `true` if `b` is an HTTP `tchar` (RFC 9110 §5.6.2): an ASCII letter or digit,
/// or one of the fifteen listed punctuation characters.
fn is_tchar(b: u8) -> bool {
    const PUNCTUATION: &[u8] = b"!#$%&'*+-.^_`|~";
    b.is_ascii_alphanumeric() || PUNCTUATION.contains(&b)
}
fn is_valid_method(method: &[u8]) -> bool {
!method.is_empty() && method.iter().all(|&b| is_tchar(b))
}
/// Checks the RFC 3986 `scheme` syntax: an ASCII letter followed by any number of
/// ASCII letters, digits, `+`, `-` or `.`. Empty input is rejected.
fn is_valid_scheme(scheme: &[u8]) -> bool {
    let Some((&first, rest)) = scheme.split_first() else {
        return false;
    };
    let is_scheme_char = |c: u8| c.is_ascii_alphanumeric() || matches!(c, b'+' | b'-' | b'.');
    first.is_ascii_alphabetic() && rest.iter().all(|&c| is_scheme_char(c))
}
fn is_valid_protocol(value: &[u8]) -> bool {
if value.is_empty() {
return false;
}
let mut parts = value.splitn(2, |b| *b == b'/');
let name = parts.next().unwrap_or(&[]);
if name.is_empty() || !name.iter().all(|&b| is_tchar(b)) {
return false;
}
if let Some(version) = parts.next()
&& (version.is_empty() || !version.iter().all(|&b| is_tchar(b)))
{
return false;
}
true
}
/// Checks an `http`/`https` `:path` value: it must start with `/` (origin-form), except
/// that exactly `*` is allowed for the `OPTIONS` method. Empty paths are rejected.
fn is_valid_http_path(path: &[u8], method: &[u8]) -> bool {
    match path {
        [] => false,
        [b'*'] => method == b"OPTIONS",
        [b'/', ..] => true,
        _ => false,
    }
}
/// Returns `true` if `b` may appear in an HTTP/3 field name: a `tchar` that is not an
/// uppercase letter (field names must be lowercase).
fn is_valid_field_name_byte(b: u8) -> bool {
    const PUNCTUATION: &[u8] = b"!#$%&'*+-.^_`|~";
    b.is_ascii_lowercase() || b.is_ascii_digit() || PUNCTUATION.contains(&b)
}
fn is_valid_field_name(name: &[u8]) -> bool {
!name.is_empty() && name.iter().all(|&b| is_valid_field_name_byte(b))
}
/// Validates a field value against RFC 9110 `field-value` syntax.
///
/// An empty value is accepted. Otherwise every byte must be visible ASCII (0x21-0x7e),
/// obs-text (0x80-0xff), space or horizontal tab, and the first and last bytes must not
/// be whitespace. Control characters (including CR, LF, NUL) and DEL are rejected.
fn is_valid_field_value(value: &[u8]) -> bool {
    fn is_field_vchar(b: u8) -> bool {
        b >= 0x80 || (0x21..=0x7e).contains(&b)
    }
    let (Some(&first), Some(&last)) = (value.first(), value.last()) else {
        return true;
    };
    is_field_vchar(first)
        && is_field_vchar(last)
        && value
            .iter()
            .all(|&b| b == b' ' || b == b'\t' || is_field_vchar(b))
}
/// Returns `true` for bytes allowed in an RFC 3986 `reg-name`: unreserved characters,
/// sub-delims, and `%` (the percent-encoding introducer; its two hex digits are checked
/// by the caller).
fn is_reg_name_char(b: u8) -> bool {
    const EXTRA: &[u8] = b"-._~!$&'()*+,;=%";
    b.is_ascii_alphanumeric() || EXTRA.contains(&b)
}
/// Checks that `host` is made only of [`is_reg_name_char`] bytes, with every `%`
/// followed by exactly two hex digits (a percent-encoded octet). An empty input is
/// accepted; callers that need a non-empty host check that themselves.
fn is_valid_reg_name(host: &[u8]) -> bool {
    let mut remaining = host;
    loop {
        match remaining {
            [] => return true,
            [b'%', hi, lo, tail @ ..] if hi.is_ascii_hexdigit() && lo.is_ascii_hexdigit() => {
                remaining = tail;
            }
            // A `%` without two following hex digits is a broken escape.
            [b'%', ..] => return false,
            [c, tail @ ..] if is_reg_name_char(*c) => remaining = tail,
            _ => return false,
        }
    }
}
/// Validates the text between the brackets of an RFC 3986 `IP-literal`.
///
/// Two forms are accepted:
/// - `IPvFuture = "v" 1*HEXDIG "." 1*( unreserved / sub-delims / ":" )`. The leading `v`
///   is matched case-insensitively, because RFC 5234 quoted strings are case-insensitive.
///   At least one hex digit must precede the first `.`. The tail may use any unreserved
///   character, including all ASCII letters, not just hex digits.
/// - An IPv6 address. This form is checked only loosely: every byte must be a hex digit,
///   `:` or `.`. The group structure of the address is not verified.
///
/// Empty input is rejected.
fn is_valid_ip_literal_content(content: &[u8]) -> bool {
    match content.split_first() {
        None => false,
        Some((b'v' | b'V', future)) => {
            let Some(dot) = future.iter().position(|&b| b == b'.') else {
                return false;
            };
            let (version, tail) = (&future[..dot], &future[dot + 1..]);
            // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
            // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
            let is_future_char = |b: u8| {
                b.is_ascii_alphanumeric()
                    || matches!(
                        b,
                        b'-' | b'.'
                            | b'_'
                            | b'~'
                            | b'!'
                            | b'$'
                            | b'&'
                            | b'\''
                            | b'('
                            | b')'
                            | b'*'
                            | b'+'
                            | b','
                            | b';'
                            | b'='
                            | b':'
                    )
            };
            !version.is_empty()
                && version.iter().all(u8::is_ascii_hexdigit)
                && !tail.is_empty()
                && tail.iter().all(|&b| is_future_char(b))
        }
        Some(_) => content
            .iter()
            .all(|&b| b.is_ascii_hexdigit() || b == b':' || b == b'.'),
    }
}
/// Validates an `authority` (`:authority` or `Host`) for requests other than plain CONNECT.
///
/// Accepts `host [":" port]`:
/// - `host` is either a bracketed IP literal (`[...]`, see
///   [`is_valid_ip_literal_content`]) or a non-empty reg-name (see [`is_valid_reg_name`]).
/// - `port` is zero or more ASCII digits. RFC 3986 defines `port = *DIGIT`, so both
///   `[::1]:` and `example.com:` are syntactically valid. Previously only the bracketed
///   form accepted an empty port.
///
/// Userinfo (`user@host`) is never accepted, because `@` is not a reg-name character.
fn is_valid_authority(value: &[u8]) -> bool {
    let is_port = |port: &[u8]| port.iter().all(u8::is_ascii_digit);

    if let Some(after_bracket) = value.strip_prefix(b"[") {
        let Some(close) = after_bracket.iter().position(|&b| b == b']') else {
            return false;
        };
        let rest = &after_bracket[close + 1..];
        return is_valid_ip_literal_content(&after_bracket[..close])
            && (rest.is_empty() || rest.strip_prefix(b":").is_some_and(is_port));
    }

    // `:` cannot occur inside a reg-name, so the last colon (if any) starts the port.
    match value.iter().rposition(|&b| b == b':') {
        Some(colon) => {
            let (host, port) = (&value[..colon], &value[colon + 1..]);
            !host.is_empty() && is_valid_reg_name(host) && is_port(port)
        }
        None => !value.is_empty() && is_valid_reg_name(value),
    }
}
/// Validates the `:authority` of a plain CONNECT request, which must be authority-form
/// `host:port`.
///
/// - `host` is a bracketed IP literal (see [`is_valid_ip_literal_content`]) or a non-empty
///   reg-name (see [`is_valid_reg_name`]).
/// - The port is mandatory and must be one or more ASCII digits.
fn is_valid_connect_authority(value: &[u8]) -> bool {
    let is_port = |port: &[u8]| !port.is_empty() && port.iter().all(u8::is_ascii_digit);

    if let Some(after_bracket) = value.strip_prefix(b"[") {
        let Some(close) = after_bracket.iter().position(|&b| b == b']') else {
            return false;
        };
        return match after_bracket[close + 1..].split_first() {
            Some((b':', port)) => {
                is_valid_ip_literal_content(&after_bracket[..close]) && is_port(port)
            }
            _ => false,
        };
    }

    match value.iter().rposition(|&b| b == b':') {
        Some(colon) => {
            let (host, port) = (&value[..colon], &value[colon + 1..]);
            !host.is_empty() && is_valid_reg_name(host) && is_port(port)
        }
        None => false,
    }
}
pub fn validate_request_headers(headers: &[Header]) -> Result<(), Error> {
let mut method: Option<&[u8]> = None;
let mut scheme: Option<&[u8]> = None;
let mut path: Option<&[u8]> = None;
let mut authority: Option<&[u8]> = None;
let mut protocol: Option<&[u8]> = None;
let mut host: Option<&[u8]> = None;
let mut pseudo_done = false;
for header in headers {
if header.name().starts_with(b":") {
if pseudo_done {
return Err(Error::StreamError(ErrorCode::MessageError));
}
match header.name() {
b":method" => {
if method.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
method = Some(header.value());
}
b":scheme" => {
if scheme.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
scheme = Some(header.value());
}
b":path" => {
if path.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
path = Some(header.value());
}
b":authority" => {
if authority.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
authority = Some(header.value());
}
b":protocol" => {
if protocol.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
protocol = Some(header.value());
}
b":status" => {
return Err(Error::StreamError(ErrorCode::MessageError));
}
_ => {
return Err(Error::StreamError(ErrorCode::MessageError));
}
}
} else {
pseudo_done = true;
if !is_valid_field_name(header.name()) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if CONNECTION_SPECIFIC_FIELDS.contains(&header.name()) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if header.name() == b"te" && header.value() != b"trailers" {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if !is_valid_field_value(header.value()) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if header.name() == b"host" {
if host.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
host = Some(header.value());
}
}
}
let method = method.ok_or(Error::StreamError(ErrorCode::MessageError))?;
if !is_valid_method(method) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if method == b"CONNECT" && protocol.is_some() {
if scheme.is_none() || path.is_none() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if let Some(p) = protocol
&& !is_valid_protocol(p)
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
if let Some(s) = scheme
&& !is_valid_scheme(s)
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
if let Some(p) = path
&& p.is_empty()
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
let ext_scheme_requires_authority =
matches!(scheme, Some(s) if s == b"http" || s == b"https");
if ext_scheme_requires_authority && authority.is_none() && host.is_none() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if let (Some(a), Some(h)) = (authority, host)
&& a != h
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
if let Some(a) = authority
&& a.is_empty()
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
if let Some(h) = host
&& h.is_empty()
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
} else if method == b"CONNECT" {
if scheme.is_some() || path.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
match authority {
None => return Err(Error::StreamError(ErrorCode::MessageError)),
Some([]) => return Err(Error::StreamError(ErrorCode::MessageError)),
Some(val) => {
if !is_valid_connect_authority(val) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
}
}
} else {
if protocol.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if scheme.is_none() || path.is_none() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if let Some(s) = scheme
&& !is_valid_scheme(s)
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
let is_http_or_https = matches!(scheme, Some(s) if s == b"http" || s == b"https");
if let Some(p) = path {
if p.is_empty() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if is_http_or_https && !is_valid_http_path(p, method) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
}
if is_http_or_https && authority.is_none() && host.is_none() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
}
if let (Some(a), Some(h)) = (authority, host)
&& a != h
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
if let Some(a) = authority
&& a.is_empty()
&& method != b"CONNECT"
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
if let Some(h) = host
&& h.is_empty()
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
if authority.is_none()
&& let Some(h) = host
&& !is_valid_authority(h)
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
let is_http_scheme = matches!(scheme, Some(s) if s == b"http" || s == b"https");
if is_http_scheme
&& let Some(a) = authority
&& a.contains(&b'@')
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
if (method != b"CONNECT" || protocol.is_some())
&& let Some(a) = authority
&& !is_valid_authority(a)
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
Ok(())
}
pub fn validate_response_headers(headers: &[Header]) -> Result<(), Error> {
let mut status: Option<&[u8]> = None;
let mut pseudo_done = false;
for header in headers {
if header.name().starts_with(b":") {
if pseudo_done {
return Err(Error::StreamError(ErrorCode::MessageError));
}
match header.name() {
b":status" => {
if status.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if header.value().len() != 3
|| !header.value().iter().all(|b| b.is_ascii_digit())
{
return Err(Error::StreamError(ErrorCode::MessageError));
}
if header.value() == b"101" {
return Err(Error::StreamError(ErrorCode::MessageError));
}
status = Some(header.value());
}
_ => {
return Err(Error::StreamError(ErrorCode::MessageError));
}
}
} else {
pseudo_done = true;
if !is_valid_field_name(header.name()) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if CONNECTION_SPECIFIC_FIELDS.contains(&header.name()) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if header.name() == b"te" {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if !is_valid_field_value(header.value()) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
}
}
status.ok_or(Error::StreamError(ErrorCode::MessageError))?;
Ok(())
}
/// Validates a header section according to the local endpoint's `role`.
///
/// With `Role::Client` the headers are checked as a response section
/// ([`validate_response_headers`]). With `Role::Server` they are checked as a request
/// section ([`validate_request_headers`]).
pub fn validate_headers(headers: &[Header], role: Role) -> Result<(), Error> {
    match role {
        Role::Client => validate_response_headers(headers),
        Role::Server => validate_request_headers(headers),
    }
}
pub fn validate_content_length(
headers: &[Header],
received_body_size: u64,
skip_body_check: bool,
) -> Result<(), Error> {
let mut content_length: Option<u64> = None;
for header in headers {
if header.name() != b"content-length" {
continue;
}
if content_length.is_some() {
return Err(Error::StreamError(ErrorCode::MessageError));
}
let value_str = std::str::from_utf8(header.value())
.map_err(|_| Error::StreamError(ErrorCode::MessageError))?;
let value = value_str
.parse::<u64>()
.map_err(|_| Error::StreamError(ErrorCode::MessageError))?;
content_length = Some(value);
}
let Some(expected) = content_length else {
return Ok(());
};
if skip_body_check {
return Ok(());
}
if expected != received_body_size {
return Err(Error::StreamError(ErrorCode::MessageError));
}
Ok(())
}
pub fn validate_trailer_headers(headers: &[Header]) -> Result<(), Error> {
for header in headers {
if header.name().starts_with(b":") {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if !is_valid_field_name(header.name()) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if CONNECTION_SPECIFIC_FIELDS.contains(&header.name()) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if header.name() == b"te" {
return Err(Error::StreamError(ErrorCode::MessageError));
}
if !is_valid_field_value(header.value()) {
return Err(Error::StreamError(ErrorCode::MessageError));
}
}
Ok(())
}
/// Returns the size of a field section: the sum of `Header::size()` over every field,
/// as a `u64`.
pub fn calculate_field_section_size(headers: &[Header]) -> u64 {
    let mut total: u64 = 0;
    for header in headers {
        total += header.size() as u64;
    }
    total
}
pub fn check_field_section_size(headers: &[Header], peer_max: Option<u64>) -> Result<(), Error> {
if let Some(max_size) = peer_max {
let size = calculate_field_section_size(headers);
if size > max_size {
return Err(Error::ConnectionError(ErrorCode::InternalError));
}
}
Ok(())
}