use std::collections::VecDeque;
use crate::error::H2Error;
/// Upper bound on the number of header fields [`Decoder::decode`] will return
/// for a single header block; a block carrying more fields is rejected with
/// [`H2Error::MaxSizeExceeded`].
pub const MAX_HEADER_FIELDS: usize = 256;
/// One HTTP header field held as raw octets.
///
/// HPACK operates on bytes rather than strings, so neither the name nor the
/// value is validated or normalised here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeaderField {
    /// Field name octets.
    pub name: Vec<u8>,
    /// Field value octets.
    pub value: Vec<u8>,
}
impl HeaderField {
    /// Builds a field from anything convertible into owned bytes
    /// (`&str`, `String`, `&[u8]`, `&[u8; N]`, `Vec<u8>`, ...).
    pub fn new(name: impl Into<Vec<u8>>, value: impl Into<Vec<u8>>) -> Self {
        let (name, value) = (name.into(), value.into());
        HeaderField { name, value }
    }

    /// Size of this field when stored in a dynamic table (RFC 7541 §4.1):
    /// name length + value length + a fixed 32-octet per-entry overhead.
    fn size(&self) -> usize {
        const ENTRY_OVERHEAD: usize = 32;
        ENTRY_OVERHEAD + self.name.len() + self.value.len()
    }
}
/// Appends `value` to `buf` as an HPACK prefixed integer (RFC 7541 §5.1).
///
/// The low `prefix_bits` bits of the first octet hold the value (or all ones
/// when it does not fit); `pattern` supplies the remaining high bits of that
/// octet. Any excess is written as 7-bit groups, least significant first, with
/// the high bit of each continuation octet set.
pub(crate) fn encode_prefix_int(buf: &mut Vec<u8>, value: u64, prefix_bits: u8, pattern: u8) {
    let prefix_max = (1u64 << prefix_bits) - 1;
    if value >= prefix_max {
        // Prefix saturated: fill it, then stream the remainder 7 bits at a time.
        buf.push(pattern | prefix_max as u8);
        let mut rest = value - prefix_max;
        loop {
            let low = (rest & 0x7f) as u8;
            rest >>= 7;
            if rest == 0 {
                buf.push(low);
                break;
            }
            buf.push(low | 0x80);
        }
    } else {
        buf.push(pattern | value as u8);
    }
}
/// Decodes an HPACK prefixed integer (RFC 7541 §5.1) from the start of `buf`.
///
/// Only the low `prefix_bits` bits of the first octet are read. Returns the
/// value and the number of octets consumed. Returns `None` if `buf` is empty,
/// ends before the final (high-bit-clear) octet, or holds more than nine
/// continuation octets. That bound keeps every shift at or below 56, so the
/// accumulator cannot overflow.
pub(crate) fn decode_prefix_int(buf: &[u8], prefix_bits: u8) -> Option<(u64, usize)> {
    let (&head, tail) = buf.split_first()?;
    let prefix_max = (1u64 << prefix_bits) - 1;
    let prefix = u64::from(head) & prefix_max;
    if prefix != prefix_max {
        return Some((prefix, 1));
    }
    let mut acc = prefix_max;
    for (offset, &byte) in tail.iter().enumerate() {
        let shift = 7 * offset as u32;
        acc += u64::from(byte & 0x7f) << shift;
        if byte & 0x80 == 0 {
            // +1 for the prefix octet, +1 because `offset` is zero-based.
            return Some((acc, offset + 2));
        }
        if shift + 7 > 56 {
            return None;
        }
    }
    None
}
/// The HPACK static table (RFC 7541 Appendix A) as `(name, value)` pairs.
///
/// HPACK index `i` (1-based, `1..=61`) is stored at `STATIC_TABLE[i - 1]`.
/// Dynamic-table entries are addressed by the indices immediately following
/// the last entry here (`STATIC_TABLE.len() + 1` onward).
const STATIC_TABLE: &[(&[u8], &[u8])] = &[
(b":authority", b""), (b":method", b"GET"), (b":method", b"POST"), (b":path", b"/"), (b":path", b"/index.html"), (b":scheme", b"http"), (b":scheme", b"https"), (b":status", b"200"), (b":status", b"204"), (b":status", b"206"), (b":status", b"304"), (b":status", b"400"), (b":status", b"404"), (b":status", b"500"), (b"accept-charset", b""), (b"accept-encoding", b"gzip, deflate"), (b"accept-language", b""), (b"accept-ranges", b""), (b"accept", b""), (b"access-control-allow-origin", b""), (b"age", b""), (b"allow", b""), (b"authorization", b""), (b"cache-control", b""), (b"content-disposition", b""), (b"content-encoding", b""), (b"content-language", b""), (b"content-length", b""), (b"content-location", b""), (b"content-range", b""), (b"content-type", b""), (b"cookie", b""), (b"date", b""), (b"etag", b""), (b"expect", b""), (b"expires", b""), (b"from", b""), (b"host", b""), (b"if-match", b""), (b"if-modified-since", b""), (b"if-none-match", b""), (b"if-range", b""), (b"if-unmodified-since", b""), (b"last-modified", b""), (b"link", b""), (b"location", b""), (b"max-forwards", b""), (b"proxy-authenticate", b""), (b"proxy-authorization", b""), (b"range", b""), (b"referer", b""), (b"refresh", b""), (b"retry-after", b""), (b"server", b""), (b"set-cookie", b""), (b"strict-transport-security", b""), (b"transfer-encoding", b""), (b"user-agent", b""), (b"vary", b""), (b"via", b""), (b"www-authenticate", b""), ];
/// Returns the 1-based HPACK index of the first static-table entry whose name
/// AND value both equal the arguments, or `None` if there is no exact match.
fn find_static_name_value(name: &[u8], value: &[u8]) -> Option<usize> {
    for (offset, &(entry_name, entry_value)) in STATIC_TABLE.iter().enumerate() {
        if entry_name == name && entry_value == value {
            return Some(offset + 1);
        }
    }
    None
}
/// Returns the 1-based HPACK index of the first static-table entry whose name
/// equals `name` (the value is ignored), or `None` if the name is absent.
fn find_static_name(name: &[u8]) -> Option<usize> {
    STATIC_TABLE
        .iter()
        .enumerate()
        .find_map(|(offset, &(entry_name, _))| (entry_name == name).then_some(offset + 1))
}
pub struct DynamicTable {
entries: VecDeque<HeaderField>,
size: usize,
max_size: usize,
}
impl DynamicTable {
pub fn new(max_size: usize) -> Self {
Self {
entries: VecDeque::new(),
size: 0,
max_size,
}
}
pub fn get(&self, index: usize) -> Option<&HeaderField> {
self.entries.get(index)
}
pub fn insert(&mut self, field: HeaderField) {
let entry_size = field.size();
while self.size + entry_size > self.max_size && !self.entries.is_empty() {
if let Some(evicted) = self.entries.pop_back() {
self.size -= evicted.size();
}
}
if entry_size > self.max_size {
self.entries.clear();
self.size = 0;
return;
}
self.entries.push_front(field);
self.size += entry_size;
}
pub fn set_max_size(&mut self, max_size: usize) {
self.max_size = max_size;
while self.size > self.max_size && !self.entries.is_empty() {
if let Some(evicted) = self.entries.pop_back() {
self.size -= evicted.size();
}
}
}
fn find_name_value(&self, name: &[u8], value: &[u8]) -> Option<usize> {
self.entries
.iter()
.position(|h| h.name == name && h.value == value)
.map(|i| i + 62) }
fn find_name(&self, name: &[u8]) -> Option<usize> {
self.entries
.iter()
.position(|h| h.name == name)
.map(|i| i + 62)
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
}
/// Appends `data` to `buf` as an HPACK string literal (RFC 7541 §5.2).
///
/// Huffman coding is used only when it is strictly shorter than the raw
/// octets. The H bit (0x80) of the 7-bit length prefix tells the decoder which
/// form follows.
fn encode_string_literal(buf: &mut Vec<u8>, data: &[u8]) {
    let raw_len = data.len();
    let huffman_len = crate::huffman::encoded_len(data);
    if huffman_len >= raw_len {
        encode_prefix_int(buf, raw_len as u64, 7, 0x00);
        buf.extend_from_slice(data);
        return;
    }
    encode_prefix_int(buf, huffman_len as u64, 7, 0x80);
    crate::huffman::encode(data, buf);
}
/// Decodes one HPACK string literal (RFC 7541 §5.2) from the start of `buf`.
///
/// Returns the decoded octets and the number of bytes of `buf` consumed (the
/// length prefix plus the string data).
///
/// # Errors
///
/// Returns `H2Error::CompressionError` in these cases:
/// - `buf` is empty or the length prefix is malformed.
/// - The declared length does not fit in `usize`.
/// - The string would run past the end of `buf`.
///
/// Errors from Huffman decoding are propagated unchanged.
fn decode_string_literal(buf: &[u8]) -> Result<(Vec<u8>, usize), H2Error> {
    let &first = buf.first().ok_or(H2Error::CompressionError)?;
    let huffman = first & 0x80 != 0;
    let (str_len, n) = decode_prefix_int(buf, 7).ok_or(H2Error::CompressionError)?;
    // The length comes from the peer. Convert and add with checks so a huge
    // value can neither truncate (32-bit targets) nor wrap `total` into a
    // smaller number that would make the slice below panic.
    let total = usize::try_from(str_len)
        .ok()
        .and_then(|len| n.checked_add(len))
        .ok_or(H2Error::CompressionError)?;
    let data = buf.get(n..total).ok_or(H2Error::CompressionError)?;
    let value = if huffman {
        crate::huffman::decode(data)?
    } else {
        data.to_vec()
    };
    Ok((value, total))
}
/// HPACK header-block encoder (RFC 7541).
///
/// Owns the encoder-side dynamic table and remembers table-size changes
/// requested via `update_max_table_size` until the next `encode` call.
pub struct Encoder {
    dynamic_table: DynamicTable,
    /// Smallest size requested since the last header block, if any.
    pending_min: Option<usize>,
    /// Most recent size requested since the last header block, if any.
    pending_latest: Option<usize>,
}
impl Encoder {
    /// Creates an encoder whose dynamic table holds at most `max_table_size` octets.
    pub fn new(max_table_size: usize) -> Self {
        Self {
            dynamic_table: DynamicTable::new(max_table_size),
            pending_min: None,
            pending_latest: None,
        }
    }

    /// Immediately resizes the dynamic table and appends a Dynamic Table Size
    /// Update (RFC 7541 §6.3) to `buf`. The decoder in this file rejects a size
    /// update that follows a header field, so call this before encoding the
    /// block's headers.
    pub fn set_max_table_size(&mut self, new_size: usize, buf: &mut Vec<u8>) {
        self.emit_size_update(new_size, buf);
    }

    /// Records a table-size change to be signalled at the start of the next
    /// `encode` call. Both the smallest and the latest request are tracked.
    pub fn update_max_table_size(&mut self, new_size: usize) {
        let smallest = self.pending_min.map_or(new_size, |m| m.min(new_size));
        self.pending_min = Some(smallest);
        self.pending_latest = Some(new_size);
    }

    /// Encodes `headers` as one header block appended to `buf`.
    ///
    /// Pending size changes are flushed first. When the table shrank below its
    /// final size in the meantime, the minimum is signalled before the final
    /// value (RFC 7541 §4.2).
    pub fn encode(&mut self, headers: &[HeaderField], buf: &mut Vec<u8>) {
        if let Some(latest) = self.pending_latest.take() {
            let smallest = self.pending_min.take().unwrap_or(latest);
            if smallest != latest {
                self.emit_size_update(smallest, buf);
            }
            self.emit_size_update(latest, buf);
        }
        for field in headers {
            self.encode_header(field, buf);
        }
    }

    /// Applies `size` to the local table and writes the matching §6.3 update.
    fn emit_size_update(&mut self, size: usize, buf: &mut Vec<u8>) {
        self.dynamic_table.set_max_size(size);
        encode_prefix_int(buf, size as u64, 5, 0x20);
    }

    /// Encodes a single field. In order of preference:
    /// 1. An Indexed Header Field for an exact static or dynamic match.
    /// 2. A Literal with Incremental Indexing that reuses an indexed name.
    /// 3. A Literal with Incremental Indexing with a new, literal name.
    ///
    /// Literal forms also insert the field into the dynamic table.
    fn encode_header(&mut self, header: &HeaderField, buf: &mut Vec<u8>) {
        let (name, value) = (header.name.as_slice(), header.value.as_slice());

        let exact = find_static_name_value(name, value)
            .or_else(|| self.dynamic_table.find_name_value(name, value));
        if let Some(index) = exact {
            encode_prefix_int(buf, index as u64, 7, 0x80);
            return;
        }

        match find_static_name(name).or_else(|| self.dynamic_table.find_name(name)) {
            Some(name_index) => encode_prefix_int(buf, name_index as u64, 6, 0x40),
            None => {
                // Name index 0: the name follows as a string literal.
                buf.push(0x40);
                encode_string_literal(buf, name);
            }
        }
        encode_string_literal(buf, value);
        self.dynamic_table.insert(header.clone());
    }
}
/// HPACK header-block decoder (RFC 7541).
///
/// Owns the decoder-side dynamic table plus `max_table_size`, the largest
/// value that a Dynamic Table Size Update in an incoming block may request.
pub struct Decoder {
    dynamic_table: DynamicTable,
    max_table_size: usize,
}
impl Decoder {
    /// Creates a decoder whose dynamic table starts with, and is capped at,
    /// `max_table_size` octets.
    pub fn new(max_table_size: usize) -> Self {
        Self {
            dynamic_table: DynamicTable::new(max_table_size),
            max_table_size,
        }
    }

    /// Decodes one complete header block, returning its fields in wire order.
    ///
    /// Literal-with-incremental-indexing fields are also inserted into the
    /// dynamic table. Dynamic Table Size Updates are accepted only before the
    /// first header field of the block.
    ///
    /// # Errors
    ///
    /// Returns `H2Error::CompressionError` for any of the following:
    /// - malformed or truncated input;
    /// - an index of 0 or one outside both tables;
    /// - a size update that follows a field or exceeds `max_table_size`;
    /// - a wire integer that does not fit in `usize`.
    ///
    /// Returns `H2Error::MaxSizeExceeded` when the block carries more than
    /// [`MAX_HEADER_FIELDS`] fields. Huffman decoding errors are propagated.
    pub fn decode(&mut self, buf: &[u8]) -> Result<Vec<HeaderField>, H2Error> {
        let mut headers = Vec::new();
        let mut pos = 0;
        while let Some(&first) = buf.get(pos) {
            let rest = &buf[pos..];
            if first & 0x80 != 0 {
                // Indexed Header Field (1xxxxxxx, 7-bit index).
                let (index, n) = decode_prefix_int(rest, 7).ok_or(H2Error::CompressionError)?;
                pos += n;
                Self::check_field_limit(&headers)?;
                headers.push(self.get_indexed(Self::wire_int_to_usize(index)?)?);
            } else if first & 0x40 != 0 {
                // Literal with Incremental Indexing (01xxxxxx, 6-bit name
                // index): the field is also added to the dynamic table.
                let (field, n) = self.decode_literal(rest, 6)?;
                pos += n;
                Self::check_field_limit(&headers)?;
                self.dynamic_table.insert(field.clone());
                headers.push(field);
            } else if first & 0x20 != 0 {
                // Dynamic Table Size Update (001xxxxx, 5-bit size); only legal
                // before the first header field of the block.
                if !headers.is_empty() {
                    return Err(H2Error::CompressionError);
                }
                let (new_size, n) = decode_prefix_int(rest, 5).ok_or(H2Error::CompressionError)?;
                pos += n;
                let new_size = Self::wire_int_to_usize(new_size)?;
                if new_size > self.max_table_size {
                    return Err(H2Error::CompressionError);
                }
                self.dynamic_table.set_max_size(new_size);
            } else {
                // Literal without Indexing (0000xxxx) or Never Indexed
                // (0001xxxx), both with a 4-bit name index. Neither touches the
                // dynamic table, and the never-indexed flag is not retained in
                // the returned field.
                let (field, n) = self.decode_literal(rest, 4)?;
                pos += n;
                Self::check_field_limit(&headers)?;
                headers.push(field);
            }
        }
        Ok(headers)
    }

    /// Decodes a literal field representation (RFC 7541 §6.2) whose name index
    /// uses a `prefix_bits`-bit prefix. A name index of 0 means the name itself
    /// follows as a string literal. Returns the field and the bytes consumed.
    /// Does not modify the dynamic table.
    fn decode_literal(
        &self,
        buf: &[u8],
        prefix_bits: u8,
    ) -> Result<(HeaderField, usize), H2Error> {
        let (name_index, mut pos) =
            decode_prefix_int(buf, prefix_bits).ok_or(H2Error::CompressionError)?;
        let name = if name_index == 0 {
            let (name, consumed) = decode_string_literal(&buf[pos..])?;
            pos += consumed;
            name
        } else {
            self.get_name(Self::wire_int_to_usize(name_index)?)?
        };
        let (value, consumed) = decode_string_literal(&buf[pos..])?;
        pos += consumed;
        Ok((HeaderField { name, value }, pos))
    }

    /// Fails with `H2Error::MaxSizeExceeded` when `headers` already holds
    /// [`MAX_HEADER_FIELDS`] fields, i.e. accepting one more would exceed the cap.
    fn check_field_limit(headers: &[HeaderField]) -> Result<(), H2Error> {
        if headers.len() >= MAX_HEADER_FIELDS {
            return Err(H2Error::MaxSizeExceeded(format!(
                "HPACK block exceeds MAX_HEADER_FIELDS ({MAX_HEADER_FIELDS})"
            )));
        }
        Ok(())
    }

    /// Converts a decoded wire integer to `usize`, rejecting values that do not
    /// fit. A silent `as` cast would truncate on 32-bit targets and could alias
    /// a valid index or size.
    fn wire_int_to_usize(value: u64) -> Result<usize, H2Error> {
        usize::try_from(value).map_err(|_| H2Error::CompressionError)
    }

    /// Resolves a 1-based HPACK index to a `(name, value)` pair borrowed from
    /// the static table (`1..=STATIC_TABLE.len()`) or the dynamic table (the
    /// indices after it, newest entry first). Index 0 and out-of-range indices
    /// yield `H2Error::CompressionError`.
    fn lookup(&self, index: usize) -> Result<(&[u8], &[u8]), H2Error> {
        match index {
            0 => Err(H2Error::CompressionError),
            i if i <= STATIC_TABLE.len() => {
                let (name, value) = STATIC_TABLE[i - 1];
                Ok((name, value))
            }
            i => self
                .dynamic_table
                .get(i - STATIC_TABLE.len() - 1)
                .map(|h| (h.name.as_slice(), h.value.as_slice()))
                .ok_or(H2Error::CompressionError),
        }
    }

    /// Returns an owned copy of the full field at HPACK index `index`.
    fn get_indexed(&self, index: usize) -> Result<HeaderField, H2Error> {
        let (name, value) = self.lookup(index)?;
        Ok(HeaderField {
            name: name.to_vec(),
            value: value.to_vec(),
        })
    }

    /// Returns an owned copy of just the name at HPACK index `index`.
    fn get_name(&self, index: usize) -> Result<Vec<u8>, H2Error> {
        self.lookup(index).map(|(name, _)| name.to_vec())
    }

    /// Sets the limit that incoming Dynamic Table Size Updates are validated
    /// against. The current dynamic table is not resized here.
    pub fn set_max_table_size(&mut self, max_size: usize) {
        self.max_table_size = max_size;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prefix_int_round_trip() {
        for &(value, prefix_bits, pattern) in &[
            (0u64, 7, 0x80u8),
            (5, 7, 0x80),
            (126, 7, 0x80),
            (127, 7, 0x80),
            (128, 7, 0x80),
            (1000, 7, 0x80),
            (0, 6, 0x40),
            (62, 6, 0x40),
            (63, 6, 0x40),
            (64, 6, 0x40),
            (255, 6, 0x40),
            (0, 5, 0x20),
            (31, 5, 0x20),
            (32, 5, 0x20),
            (4096, 5, 0x20),
            (0, 4, 0x00),
            (15, 4, 0x00),
            (16, 4, 0x00),
        ] {
            let mut buf = Vec::new();
            encode_prefix_int(&mut buf, value, prefix_bits, pattern);
            let (decoded, len) = decode_prefix_int(&buf, prefix_bits).unwrap();
            assert_eq!(
                decoded, value,
                "mismatch for value={value} prefix={prefix_bits}"
            );
            assert_eq!(len, buf.len());
            let mask = !((1u8 << prefix_bits) - 1);
            assert_eq!(buf[0] & mask, pattern & mask);
        }
    }

    /// Truncated input and more than nine continuation octets must be rejected.
    #[test]
    fn decode_prefix_int_rejects_truncated_and_overlong() {
        assert_eq!(decode_prefix_int(&[], 5), None);
        assert_eq!(decode_prefix_int(&[0x1f, 0x80], 5), None);
        let mut overlong = vec![0x1f];
        overlong.extend(std::iter::repeat_n(0xff, 9));
        assert_eq!(decode_prefix_int(&overlong, 5), None);
    }

    #[test]
    fn static_table_size() {
        assert_eq!(STATIC_TABLE.len(), 61);
    }

    #[test]
    fn encode_decode_indexed() {
        let mut encoder = Encoder::new(4096);
        let mut decoder = Decoder::new(4096);
        let headers = vec![HeaderField::new(b":method", b"GET")];
        let mut buf = Vec::new();
        encoder.encode(&headers, &mut buf);
        let decoded = decoder.decode(&buf).unwrap();
        assert_eq!(decoded, headers);
    }

    #[test]
    fn encode_decode_name_reference() {
        let mut encoder = Encoder::new(4096);
        let mut decoder = Decoder::new(4096);
        let headers = vec![HeaderField::new(b":path", b"/foo")];
        let mut buf = Vec::new();
        encoder.encode(&headers, &mut buf);
        let decoded = decoder.decode(&buf).unwrap();
        assert_eq!(decoded, headers);
    }

    #[test]
    fn encode_decode_literal() {
        let mut encoder = Encoder::new(4096);
        let mut decoder = Decoder::new(4096);
        let headers = vec![HeaderField::new(b"x-custom", b"value123")];
        let mut buf = Vec::new();
        encoder.encode(&headers, &mut buf);
        let decoded = decoder.decode(&buf).unwrap();
        assert_eq!(decoded, headers);
    }

    #[test]
    fn encode_decode_multiple_headers() {
        let mut encoder = Encoder::new(4096);
        let mut decoder = Decoder::new(4096);
        let headers = vec![
            HeaderField::new(b":method", b"GET"),
            HeaderField::new(b":path", b"/"),
            HeaderField::new(b":scheme", b"https"),
            HeaderField::new(b":authority", b"example.com"),
            HeaderField::new(b"accept", b"*/*"),
            HeaderField::new(b"x-request-id", b"abc123"),
        ];
        let mut buf = Vec::new();
        encoder.encode(&headers, &mut buf);
        let decoded = decoder.decode(&buf).unwrap();
        assert_eq!(decoded, headers);
    }

    #[test]
    fn dynamic_table_reuse() {
        let mut encoder = Encoder::new(4096);
        let mut decoder = Decoder::new(4096);
        let headers1 = vec![
            HeaderField::new(b":method", b"GET"),
            HeaderField::new(b"x-token", b"abc"),
        ];
        let mut buf1 = Vec::new();
        encoder.encode(&headers1, &mut buf1);
        let decoded1 = decoder.decode(&buf1).unwrap();
        assert_eq!(decoded1, headers1);
        let headers2 = vec![
            HeaderField::new(b":method", b"GET"),
            HeaderField::new(b"x-token", b"abc"),
        ];
        let mut buf2 = Vec::new();
        encoder.encode(&headers2, &mut buf2);
        let decoded2 = decoder.decode(&buf2).unwrap();
        assert_eq!(decoded2, headers2);
        assert!(buf2.len() <= buf1.len());
    }

    /// The entry costs 18 + 21 + 32 = 71 octets, more than the 64-octet table,
    /// so both sides must end up with an empty dynamic table (RFC 7541 §4.4).
    #[test]
    fn dynamic_table_eviction() {
        let mut encoder = Encoder::new(64);
        let mut decoder = Decoder::new(64);
        let headers = vec![HeaderField::new(
            b"x-long-header-name",
            b"a-somewhat-long-value",
        )];
        let mut buf = Vec::new();
        encoder.encode(&headers, &mut buf);
        let decoded = decoder.decode(&buf).unwrap();
        assert_eq!(decoded, headers);
        assert!(encoder.dynamic_table.is_empty());
        assert!(decoder.dynamic_table.is_empty());
    }

    /// Each entry costs 1 + 1 + 32 = 34 octets, so at most two fit in 100.
    #[test]
    fn dynamic_table_evicts_oldest_first() {
        let mut table = DynamicTable::new(100);
        table.insert(HeaderField::new(b"a", b"1"));
        table.insert(HeaderField::new(b"b", b"2"));
        table.insert(HeaderField::new(b"c", b"3"));
        assert_eq!(table.len(), 2);
        assert_eq!(table.get(0), Some(&HeaderField::new(b"c", b"3")));
        assert_eq!(table.get(1), Some(&HeaderField::new(b"b", b"2")));
        table.set_max_size(40);
        assert_eq!(table.len(), 1);
        assert_eq!(table.get(0), Some(&HeaderField::new(b"c", b"3")));
    }

    #[test]
    fn dynamic_table_oversized_entry_empties_table() {
        let mut table = DynamicTable::new(64);
        table.insert(HeaderField::new(b"a", b"1"));
        assert_eq!(table.len(), 1);
        table.insert(HeaderField::new(
            b"x-long-header-name",
            b"a-somewhat-long-value",
        ));
        assert!(table.is_empty());
    }

    #[test]
    fn encode_decode_status_200() {
        let mut encoder = Encoder::new(4096);
        let mut decoder = Decoder::new(4096);
        let headers = vec![
            HeaderField::new(b":status", b"200"),
            HeaderField::new(b"content-type", b"text/plain"),
        ];
        let mut buf = Vec::new();
        encoder.encode(&headers, &mut buf);
        let decoded = decoder.decode(&buf).unwrap();
        assert_eq!(decoded, headers);
    }

    #[test]
    fn table_size_update() {
        let mut encoder = Encoder::new(4096);
        let mut decoder = Decoder::new(4096);
        let mut buf = Vec::new();
        encoder.set_max_table_size(256, &mut buf);
        encoder.encode(&[HeaderField::new(b":method", b"GET")], &mut buf);
        let decoded = decoder.decode(&buf).unwrap();
        assert_eq!(decoded, vec![HeaderField::new(b":method", b"GET")]);
    }

    #[test]
    fn encoder_emits_min_then_latest_when_multiple_updates_pending() {
        let mut encoder = Encoder::new(4096);
        encoder.update_max_table_size(4096);
        encoder.update_max_table_size(1024);
        encoder.update_max_table_size(2048);
        let mut buf = Vec::new();
        encoder.encode(&[HeaderField::new(b":method", b"GET")], &mut buf);
        let (first, n) = decode_prefix_int(&buf, 5).unwrap();
        assert_eq!(first, 1024, "expected min (1024) first");
        let (second, _) = decode_prefix_int(&buf[n..], 5).unwrap();
        assert_eq!(second, 2048, "expected latest (2048) second");
    }

    #[test]
    fn encoder_single_update_when_min_equals_latest() {
        let mut encoder = Encoder::new(4096);
        encoder.update_max_table_size(2048);
        encoder.update_max_table_size(2048);
        let mut buf = Vec::new();
        encoder.encode(&[HeaderField::new(b":method", b"GET")], &mut buf);
        let (size, n) = decode_prefix_int(&buf, 5).unwrap();
        assert_eq!(size, 2048);
        assert_eq!(buf[n], 0x82);
    }

    #[test]
    fn decoder_caps_at_max_header_fields() {
        let mut decoder = Decoder::new(4096);
        let block: Vec<u8> = std::iter::repeat_n(0x82u8, MAX_HEADER_FIELDS + 1).collect();
        let err = decoder.decode(&block).err().unwrap();
        assert!(matches!(err, H2Error::MaxSizeExceeded(_)));
    }

    #[test]
    fn decoder_rejects_table_size_update_after_header() {
        let mut decoder = Decoder::new(4096);
        let block = vec![0x82, 0x20];
        let err = decoder.decode(&block).err().unwrap();
        assert!(matches!(err, H2Error::CompressionError));
    }

    #[test]
    fn decoder_accepts_table_size_update_before_header() {
        let mut decoder = Decoder::new(4096);
        let block = vec![0x20, 0x82];
        let headers = decoder.decode(&block).unwrap();
        assert_eq!(headers.len(), 1);
    }

    #[test]
    fn decoder_rejects_table_size_update_above_limit() {
        let mut decoder = Decoder::new(4096);
        let mut block = Vec::new();
        encode_prefix_int(&mut block, 4097, 5, 0x20);
        let err = decoder.decode(&block).err().unwrap();
        assert!(matches!(err, H2Error::CompressionError));
    }

    #[test]
    fn decoder_rejects_index_zero() {
        let mut decoder = Decoder::new(4096);
        let err = decoder.decode(&[0x80]).err().unwrap();
        assert!(matches!(err, H2Error::CompressionError));
    }

    /// A literal whose declared length runs past the end of the block.
    #[test]
    fn decoder_rejects_truncated_string_literal() {
        let mut decoder = Decoder::new(4096);
        let err = decoder.decode(&[0x40, 0x05, b'a']).err().unwrap();
        assert!(matches!(err, H2Error::CompressionError));
    }

    /// 0x04 = literal without indexing and 0x14 = never indexed, both naming
    /// static index 4 (`:path`); neither may populate the dynamic table.
    #[test]
    fn literal_without_indexing_leaves_dynamic_table_untouched() {
        let mut decoder = Decoder::new(4096);
        let block = [0x04, 0x02, b'/', b'a', 0x14, 0x02, b'/', b'b'];
        let headers = decoder.decode(&block).unwrap();
        assert_eq!(
            headers,
            vec![
                HeaderField::new(b":path", b"/a"),
                HeaderField::new(b":path", b"/b"),
            ]
        );
        assert!(decoder.dynamic_table.is_empty());
    }

    /// RFC 7541 Appendix C.3.1 and C.3.2: requests without Huffman coding,
    /// decoded on one connection so the second block reuses the dynamic entry
    /// added by the first (0xbe = index 62).
    #[test]
    fn rfc7541_appendix_c3_requests_without_huffman() {
        let mut decoder = Decoder::new(4096);

        let mut first = vec![0x82, 0x86, 0x84, 0x41, 0x0f];
        first.extend_from_slice(b"www.example.com");
        let decoded = decoder.decode(&first).unwrap();
        assert_eq!(
            decoded,
            vec![
                HeaderField::new(b":method", b"GET"),
                HeaderField::new(b":scheme", b"http"),
                HeaderField::new(b":path", b"/"),
                HeaderField::new(b":authority", b"www.example.com"),
            ]
        );
        assert_eq!(decoder.dynamic_table.len(), 1);

        let mut second = vec![0x82, 0x86, 0x84, 0xbe, 0x58, 0x08];
        second.extend_from_slice(b"no-cache");
        let decoded = decoder.decode(&second).unwrap();
        assert_eq!(
            decoded,
            vec![
                HeaderField::new(b":method", b"GET"),
                HeaderField::new(b":scheme", b"http"),
                HeaderField::new(b":path", b"/"),
                HeaderField::new(b":authority", b"www.example.com"),
                HeaderField::new(b"cache-control", b"no-cache"),
            ]
        );
        assert_eq!(decoder.dynamic_table.len(), 2);
        assert_eq!(
            decoder.dynamic_table.get(0),
            Some(&HeaderField::new(b"cache-control", b"no-cache"))
        );
    }

    #[test]
    fn rfc7541_appendix_c1_integer_examples() {
        let mut buf = Vec::new();
        encode_prefix_int(&mut buf, 10, 5, 0x00);
        assert_eq!(buf, vec![0x0a]);
        let mut buf = Vec::new();
        encode_prefix_int(&mut buf, 1337, 5, 0x00);
        assert_eq!(buf, vec![0x1f, 0x9a, 0x0a]);
        let mut buf = Vec::new();
        encode_prefix_int(&mut buf, 42, 8, 0x00);
        assert_eq!(buf, vec![0x2a]);
    }
}