use crate::headers::Headers;
use crate::transport::h2::hpack_impl::{Decoder, Encoder};
use bytes::Bytes;
/// Compares two byte slices for equality, ignoring ASCII letter case.
///
/// Slices of different lengths never match; bytes outside the ASCII range
/// must be identical.
fn bytes_eq_ignore_ascii_case(a: &[u8], b: &[u8]) -> bool {
    a.eq_ignore_ascii_case(b)
}
/// Order in which the four request pseudo-header fields are emitted.
///
/// RFC 9113 requires pseudo-headers to precede all regular fields but leaves
/// their relative order to the sender, so the order is an observable trait of
/// a client. [`PseudoHeaderOrder::akamai_string`] renders it in the Akamai
/// fingerprint notation, where `m` = `:method`, `a` = `:authority`,
/// `s` = `:scheme` and `p` = `:path`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PseudoHeaderOrder {
    /// `:method`, `:scheme`, `:authority`, `:path` (`m,s,a,p`). The default.
    #[default]
    Chrome,
    /// `:method`, `:path`, `:authority`, `:scheme` (`m,p,a,s`).
    Firefox,
    /// `:method`, `:scheme`, `:path`, `:authority` (`m,s,p,a`).
    Safari,
    /// `:method`, `:authority`, `:scheme`, `:path` (`m,a,s,p`).
    Standard,
    /// Caller-chosen order given as pseudo-header indices: `0` = `:method`,
    /// `1` = `:authority`, `2` = `:scheme`, `3` = `:path`.
    ///
    /// Out-of-range and repeated indices are ignored, and any pseudo-header
    /// not named is appended in index order. Each pseudo-header is therefore
    /// still emitted exactly once.
    Custom([u8; 4]),
}
impl PseudoHeaderOrder {
    /// Returns the emission order as indices into the request's pseudo-header
    /// table (`0` = `:method`, `1` = `:authority`, `2` = `:scheme`,
    /// `3` = `:path`).
    ///
    /// The result is always a permutation of `0..4`. A `Custom` order is
    /// sanitised so callers can index a 4-element table without panicking,
    /// and a request never carries a pseudo-header twice or misses one.
    fn order(&self) -> [usize; 4] {
        match self {
            Self::Chrome => [0, 2, 1, 3],
            Self::Firefox => [0, 3, 1, 2],
            Self::Safari => [0, 2, 3, 1],
            Self::Standard => [0, 1, 2, 3],
            Self::Custom(requested) => {
                let mut order = [0usize; 4];
                let mut seen = [false; 4];
                let mut filled = 0;
                // Honour the requested indices first, then back-fill any
                // pseudo-header the caller left out.
                let candidates = requested.iter().map(|&i| usize::from(i)).chain(0..4);
                for idx in candidates {
                    if idx < order.len() && !seen[idx] {
                        seen[idx] = true;
                        order[filled] = idx;
                        filled += 1;
                    }
                }
                order
            }
        }
    }

    /// Returns the pseudo-header order in Akamai fingerprint notation
    /// (e.g. `"m,a,s,p"`).
    ///
    /// `Custom` orders are reported as the fixed string `"custom"` whatever
    /// their contents.
    pub fn akamai_string(&self) -> &'static str {
        match self {
            Self::Chrome => "m,s,a,p",
            Self::Firefox => "m,p,a,s",
            Self::Safari => "m,s,p,a",
            Self::Standard => "m,a,s,p",
            Self::Custom(_) => "custom",
        }
    }
}
/// HPACK encoder for outgoing HTTP/2 request header blocks.
///
/// Wraps the lower-level [`Encoder`] and adds request-level policy. That
/// policy covers which pseudo-header order to emit (see
/// [`PseudoHeaderOrder`]) and which caller-supplied fields are dropped or
/// rejected.
///
/// The wrapped encoder is stateful: `encode` is called through `&mut self`
/// and its table size is configurable. Presumably one instance is therefore
/// kept per connection; confirm against callers.
pub struct HpackEncoder {
    encoder: Encoder,
    pseudo_order: PseudoHeaderOrder,
}
impl HpackEncoder {
    /// Creates an encoder that emits pseudo-headers in `pseudo_order`.
    pub fn new(pseudo_order: PseudoHeaderOrder) -> Self {
        Self {
            encoder: Encoder::new(),
            pseudo_order,
        }
    }

    /// Creates an encoder using [`PseudoHeaderOrder::Chrome`].
    pub fn chrome() -> Self {
        Self::new(PseudoHeaderOrder::Chrome)
    }

    /// Forwards `size` to the wrapped encoder's `set_max_table_size`.
    ///
    /// This is presumably the peer's SETTINGS_HEADER_TABLE_SIZE; confirm at
    /// the call site.
    pub fn set_max_table_size(&mut self, size: usize) {
        self.encoder.set_max_table_size(size);
    }

    /// Encodes a request header block.
    ///
    /// The block holds the four pseudo-headers in this encoder's
    /// [`PseudoHeaderOrder`], followed by the caller's regular headers in
    /// their original order. Regular header names are lowercased.
    ///
    /// Caller-supplied headers are filtered rather than rejected. A field is
    /// silently dropped when any of these holds:
    /// - its name is empty or starts with `:`;
    /// - its name contains a byte RFC 9113 §8.2.1 forbids (control, space,
    ///   DEL, non-ASCII, or `:`);
    /// - its value contains NUL, CR or LF;
    /// - it is connection-specific (RFC 9113 §8.2.2);
    /// - it is `te` with a value other than `trailers`.
    ///
    /// The returned block may exceed one frame; see
    /// [`HpackEncoder::chunk_encoded`].
    pub fn encode_request(
        &mut self,
        method: &str,
        scheme: &str,
        authority: &str,
        path: &str,
        headers: impl Into<Headers>,
    ) -> Bytes {
        let headers = headers.into();
        // Indexed by `PseudoHeaderOrder::order`:
        // 0 = method, 1 = authority, 2 = scheme, 3 = path.
        let pseudo_headers: [(&[u8], &[u8]); 4] = [
            (b":method", method.as_bytes()),
            (b":authority", authority.as_bytes()),
            (b":scheme", scheme.as_bytes()),
            (b":path", path.as_bytes()),
        ];

        let mut regular: Vec<(Vec<u8>, &[u8])> = Vec::with_capacity(headers.len());
        for (name, value) in headers.iter_bytes() {
            // Pseudo-headers come only from the arguments above; callers
            // cannot inject or override them.
            if name.first() == Some(&b':') || name.is_empty() {
                continue;
            }
            if Self::has_invalid_name_byte(name) || Self::has_invalid_value_byte(value) {
                continue;
            }
            let name_lower = name.to_ascii_lowercase();
            if Self::is_connection_specific(&name_lower) {
                continue;
            }
            // RFC 9113 §8.2.2: TE may only carry "trailers" in HTTP/2.
            if name_lower == b"te" && !bytes_eq_ignore_ascii_case(value, b"trailers") {
                continue;
            }
            regular.push((name_lower, value));
        }

        let mut all_headers: Vec<(&[u8], &[u8])> =
            Vec::with_capacity(pseudo_headers.len() + regular.len());
        all_headers.extend(
            self.pseudo_order
                .order()
                .iter()
                .map(|&idx| pseudo_headers[idx]),
        );
        all_headers.extend(regular.iter().map(|(name, value)| (name.as_slice(), *value)));
        let encoded = self.encoder.encode(&all_headers);
        Bytes::from(encoded)
    }

    /// Encodes an RFC 8441 extended-CONNECT request that bootstraps a
    /// WebSocket.
    ///
    /// The pseudo-headers are `:method = CONNECT` and `:protocol = websocket`,
    /// plus the given `:scheme`, `:path` and `:authority`. They are emitted in
    /// that fixed order, which matches the example in RFC 8441 §5.1; this
    /// encoder's [`PseudoHeaderOrder`] is not consulted. Regular header names
    /// are lowercased and kept in input order.
    ///
    /// This method does not check whether the peer advertised
    /// SETTINGS_ENABLE_CONNECT_PROTOCOL (RFC 8441 §3); presumably the caller
    /// does.
    ///
    /// # Errors
    ///
    /// Unlike [`HpackEncoder::encode_request`], invalid input is rejected
    /// rather than silently dropped. This happens when `authority`, `scheme`
    /// or `path` is empty, or when a caller header:
    /// - is a pseudo-header or has an empty name;
    /// - has a name with a byte forbidden by RFC 9113 §8.2.1;
    /// - is connection-specific, `host`, or `sec-websocket-key`, `-accept`
    ///   or `-extensions`;
    /// - is `te` with a value other than `trailers`;
    /// - has a value containing NUL, CR or LF.
    pub fn encode_extended_connect_websocket(
        &mut self,
        authority: &str,
        scheme: &str,
        path: &str,
        headers: impl Into<Headers>,
    ) -> Result<Bytes, String> {
        let headers = headers.into();
        if authority.is_empty() {
            return Err(":authority must not be empty".to_string());
        }
        if scheme.is_empty() {
            return Err(":scheme must not be empty".to_string());
        }
        if path.is_empty() {
            return Err(":path must not be empty".to_string());
        }
        let pseudo_headers: [(&[u8], &[u8]); 5] = [
            (b":method", b"CONNECT"),
            (b":protocol", b"websocket"),
            (b":scheme", scheme.as_bytes()),
            (b":path", path.as_bytes()),
            (b":authority", authority.as_bytes()),
        ];

        let mut regular: Vec<(Vec<u8>, &[u8])> = Vec::with_capacity(headers.len());
        for (name, value) in headers.iter_bytes() {
            if name.first() == Some(&b':') {
                return Err(format!(
                    "RFC 8441 user pseudo-header rejected: {}",
                    String::from_utf8_lossy(name)
                ));
            }
            if name.is_empty() {
                return Err("RFC 8441 header name must not be empty".to_string());
            }
            if Self::has_invalid_name_byte(name) {
                return Err(format!(
                    "RFC 8441 invalid header name rejected: {}",
                    String::from_utf8_lossy(name)
                ));
            }
            let name_lower = name.to_ascii_lowercase();
            // RFC 8441 §5 retires Sec-WebSocket-Key/Accept.
            // NOTE(review): the RFC 8441 §5.1 example request does carry
            // sec-websocket-extensions, so rejecting it is local policy
            // (presumably extensions are unsupported). Confirm before relaxing.
            if Self::is_connection_specific(&name_lower)
                || matches!(
                    name_lower.as_slice(),
                    b"host"
                        | b"sec-websocket-key"
                        | b"sec-websocket-accept"
                        | b"sec-websocket-extensions"
                )
            {
                return Err(format!(
                    "RFC 8441 forbidden header rejected: {}",
                    String::from_utf8_lossy(&name_lower)
                ));
            }
            if name_lower == b"te" && !bytes_eq_ignore_ascii_case(value, b"trailers") {
                return Err("RFC 8441 forbids TE values other than trailers".to_string());
            }
            if Self::has_invalid_value_byte(value) {
                return Err(format!(
                    "RFC 8441 invalid header value rejected: {}",
                    String::from_utf8_lossy(&name_lower)
                ));
            }
            regular.push((name_lower, value));
        }

        let mut all_headers: Vec<(&[u8], &[u8])> =
            Vec::with_capacity(pseudo_headers.len() + regular.len());
        all_headers.extend_from_slice(&pseudo_headers);
        all_headers.extend(regular.iter().map(|(name, value)| (name.as_slice(), *value)));
        let encoded = self.encoder.encode(&all_headers);
        Ok(Bytes::from(encoded))
    }

    /// Splits an encoded header block into fragments of at most
    /// `max_frame_size` bytes.
    ///
    /// The split is presumably for one HEADERS frame followed by
    /// CONTINUATION frames. Returns `(first, rest)`, where `rest` is empty
    /// when the whole block fits in one frame. Fragments are zero-copy
    /// slices of `encoded`.
    ///
    /// # Panics
    ///
    /// Panics if `max_frame_size` is `0` and `encoded` is non-empty.
    pub fn chunk_encoded(encoded: Bytes, max_frame_size: usize) -> (Bytes, Vec<Bytes>) {
        if encoded.len() <= max_frame_size {
            return (encoded, Vec::new());
        }
        assert!(max_frame_size > 0, "max_frame_size must be non-zero");
        let mut rest = encoded;
        let first = rest.split_to(max_frame_size);
        let mut continuations = Vec::with_capacity(rest.len() / max_frame_size + 1);
        while !rest.is_empty() {
            let take = rest.len().min(max_frame_size);
            continuations.push(rest.split_to(take));
        }
        (first, continuations)
    }

    /// Returns `true` if `name` contains a byte that RFC 9113 §8.2.1 forbids
    /// in a regular field name.
    ///
    /// The forbidden bytes are control characters and space (0x00–0x20), DEL
    /// and non-ASCII bytes (0x7F–0xFF), and `:`. Uppercase letters are
    /// accepted here because callers lowercase names before encoding.
    fn has_invalid_name_byte(name: &[u8]) -> bool {
        name.iter()
            .any(|&b| !(0x21..0x7F).contains(&b) || b == b':')
    }

    /// Returns `true` if `value` contains NUL, LF or CR, which RFC 9113
    /// §8.2.1 forbids in field values.
    fn has_invalid_value_byte(value: &[u8]) -> bool {
        value.iter().any(|&b| matches!(b, b'\0' | b'\n' | b'\r'))
    }

    /// Returns `true` for the connection-specific fields that RFC 9113
    /// §8.2.2 forbids in HTTP/2.
    ///
    /// `name` must already be lowercase.
    fn is_connection_specific(name: &[u8]) -> bool {
        matches!(
            name,
            b"connection" | b"keep-alive" | b"proxy-connection" | b"transfer-encoding" | b"upgrade"
        )
    }
}
/// HPACK decoder for incoming header blocks.
///
/// Thin wrapper around the lower-level [`Decoder`]. It materialises each
/// decoded field as an owned `(name, value)` pair of `String`s.
pub struct HpackDecoder {
    decoder: Decoder,
}
impl Default for HpackDecoder {
    /// Equivalent to [`HpackDecoder::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl HpackDecoder {
    /// Creates a decoder wrapping a freshly constructed [`Decoder`].
    pub fn new() -> Self {
        let decoder = Decoder::new();
        HpackDecoder { decoder }
    }

    /// Forwards `size` to the wrapped decoder's `set_max_table_size`.
    pub fn set_max_table_size(&mut self, size: usize) {
        self.decoder.set_max_table_size(size);
    }

    /// Decodes one header block into `(name, value)` pairs, in the order the
    /// decoder reports them.
    ///
    /// Names and values are converted with `String::from_utf8_lossy`, so
    /// invalid UTF-8 is replaced with U+FFFD rather than reported.
    ///
    /// # Errors
    ///
    /// Returns `"HPACK decode error: …"` (the `Debug` form of the decoder's
    /// error) when `data` cannot be decoded. Any fields decoded before the
    /// failure are discarded.
    pub fn decode(&mut self, data: &[u8]) -> Result<Vec<(String, String)>, String> {
        let mut decoded = Vec::new();
        if let Err(e) = self.decoder.decode_with_cb(data, |raw_name, raw_value| {
            decoded.push((
                String::from_utf8_lossy(raw_name).into_owned(),
                String::from_utf8_lossy(raw_value).into_owned(),
            ));
        }) {
            return Err(format!("HPACK decode error: {:?}", e));
        }
        Ok(decoded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `block` with a brand-new decoder, panicking on malformed input.
    fn decode_fresh(block: &[u8]) -> Vec<(String, String)> {
        HpackDecoder::new().decode(block).unwrap()
    }

    /// Builds a `Headers` value from `(name, value)` string pairs.
    fn header_list(pairs: &[(&str, &str)]) -> Headers {
        Headers::from(
            pairs
                .iter()
                .map(|&(name, value)| (name.to_string(), value.to_string()))
                .collect::<Vec<(String, String)>>(),
        )
    }

    #[test]
    fn test_pseudo_order_chrome() {
        assert_eq!(PseudoHeaderOrder::Chrome.akamai_string(), "m,s,a,p");
    }

    #[test]
    fn test_pseudo_order_standard() {
        assert_eq!(PseudoHeaderOrder::Standard.akamai_string(), "m,a,s,p");
    }

    #[test]
    fn test_encoder_creates_valid_block() {
        let mut encoder = HpackEncoder::chrome();
        let block = encoder.encode_request(
            "GET",
            "https",
            "example.com",
            "/",
            &header_list(&[("user-agent", "test")]),
        );
        assert!(!block.is_empty());

        let decoded = decode_fresh(&block);
        let got: Vec<(&str, &str)> = decoded
            .iter()
            .map(|(name, value)| (name.as_str(), value.as_str()))
            .collect();
        let expected = [
            (":method", "GET"),
            (":scheme", "https"),
            (":authority", "example.com"),
            (":path", "/"),
            ("user-agent", "test"),
        ];
        assert_eq!(got, expected);
    }

    #[test]
    fn test_encoder_standard_order() {
        let mut encoder = HpackEncoder::new(PseudoHeaderOrder::Standard);
        let block = encoder.encode_request("GET", "https", "example.com", "/", &Headers::new());

        let decoded = decode_fresh(&block);
        let names: Vec<&str> = decoded.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(&names[..4], [":method", ":authority", ":scheme", ":path"]);
    }

    #[test]
    fn test_encoder_filters_connection_headers() {
        let mut encoder = HpackEncoder::chrome();
        let block = encoder.encode_request(
            "GET",
            "https",
            "example.com",
            "/",
            &header_list(&[
                ("connection", "keep-alive"),
                ("keep-alive", "timeout=5"),
                ("user-agent", "test"),
            ]),
        );

        let decoded = decode_fresh(&block);
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded[4].0, "user-agent");
    }
}