use super::dynamic_table::DynamicTable;
use super::error::HpackError;
use super::huffman::huffman_decode;
use super::integer::decode_integer;
use super::static_table::get_static_entry;
/// Number of entries in the static table: indices `1..=STATIC_TABLE_SIZE` are
/// resolved through `get_static_entry`, and any larger index addresses the
/// dynamic table at `index - STATIC_TABLE_SIZE`.
const STATIC_TABLE_SIZE: usize = 61;
/// HPACK header-block decoder.
///
/// Owns the dynamic table that "literal with incremental indexing"
/// representations append to and that indexed representations above
/// `STATIC_TABLE_SIZE` read from.
pub struct Decoder {
/// Dynamic table; constructed with `DynamicTable::new(4096)` in `Decoder::new`.
dynamic_table: DynamicTable,
}
impl Decoder {
    /// Creates a decoder backed by `DynamicTable::new(4096)`.
    ///
    /// 4096 matches HTTP/2's default `SETTINGS_HEADER_TABLE_SIZE`; use
    /// [`Decoder::set_max_table_size`] to apply a different value.
    pub fn new() -> Self {
        Self {
            dynamic_table: DynamicTable::new(4096),
        }
    }

    /// Forwards `size` to the dynamic table's `set_max_size`.
    pub fn set_max_table_size(&mut self, size: usize) {
        self.dynamic_table.set_max_size(size);
    }

    /// Decodes the header block `data`, invoking `cb(name, value)` once per
    /// decoded header field, in wire order.
    ///
    /// The first byte of each representation selects how it is decoded:
    ///
    /// | bit pattern | representation                      | prefix bits |
    /// |-------------|-------------------------------------|-------------|
    /// | `1xxxxxxx`  | indexed header field                | 7           |
    /// | `01xxxxxx`  | literal, added to the dynamic table | 6           |
    /// | `001xxxxx`  | dynamic table size update           | 5           |
    /// | `0001xxxx`  | literal, never indexed              | 4           |
    /// | `0000xxxx`  | literal, without indexing           | 4           |
    ///
    /// The two 4-bit-prefix literal forms are decoded identically; the
    /// "never indexed" marker is not surfaced to `cb`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `decode_integer`,
    /// `huffman_decode`, a table lookup (`HpackError::InvalidIndex`) or a
    /// truncated string literal (`HpackError::UnexpectedEof`). Decoding stops
    /// at that point: headers already passed to `cb` and dynamic-table changes
    /// already applied are not rolled back.
    pub fn decode_with_cb<F>(&mut self, data: &[u8], mut cb: F) -> Result<(), HpackError>
    where
        F: FnMut(&[u8], &[u8]),
    {
        let mut pos = 0;
        while let Some(&byte) = data.get(pos) {
            // `decode_integer` reports the bytes it consumed *after* the prefix
            // byte, hence the `+ 1` whenever `pos` is advanced past an integer.
            if byte & 0x80 != 0 {
                // Indexed header field: both name and value come from a table.
                let (index, consumed) = decode_integer(&data[pos..], 7, 0x7F)?;
                pos += consumed + 1;
                let (name, value) = self.get_entry(index)?;
                cb(name, value);
            } else if byte & 0x40 != 0 {
                // Literal with incremental indexing: emit, then remember it.
                let (name_idx, consumed) = decode_integer(&data[pos..], 6, 0x3F)?;
                pos += consumed + 1;
                let (name, value, used) = self.decode_literal_field(name_idx, &data[pos..])?;
                pos += used;
                cb(&name, &value);
                self.dynamic_table.add(name, value);
            } else if byte & 0x20 != 0 {
                // Dynamic table size update.
                // NOTE(review): RFC 7541 requires the new size to stay within the
                // limit negotiated via SETTINGS and the update to appear only at
                // the start of a header block; neither is enforced here — confirm
                // whether a caller or `DynamicTable` is expected to do so.
                let (size, consumed) = decode_integer(&data[pos..], 5, 0x1F)?;
                self.dynamic_table.set_max_size(size);
                pos += consumed + 1;
            } else {
                // `0000xxxx` (without indexing) or `0001xxxx` (never indexed):
                // emitted to the callback but never stored in the dynamic table.
                let (name_idx, consumed) = decode_integer(&data[pos..], 4, 0x0F)?;
                pos += consumed + 1;
                let (name, value, used) = self.decode_literal_field(name_idx, &data[pos..])?;
                pos += used;
                cb(&name, &value);
            }
        }
        Ok(())
    }

    /// Decodes the name and value that follow the prefix integer of a literal
    /// header field representation.
    ///
    /// `name_idx` is the already-decoded prefix integer: `0` means the name is
    /// an inline string literal at the start of `rest`; any other value is
    /// looked up with [`Decoder::get_entry`] and its name is copied.
    ///
    /// Returns `(name, value, number of bytes of rest consumed)`.
    fn decode_literal_field(
        &self,
        name_idx: usize,
        rest: &[u8],
    ) -> Result<(Vec<u8>, Vec<u8>, usize), HpackError> {
        let (name, name_len) = if name_idx == 0 {
            self.decode_string_literal(rest)?
        } else {
            let (indexed_name, _) = self.get_entry(name_idx)?;
            (indexed_name.to_vec(), 0)
        };
        // `name_len` never exceeds `rest.len()`: `decode_string_literal` only
        // reports bytes it has bounds-checked.
        let (value, value_len) = self.decode_string_literal(&rest[name_len..])?;
        Ok((name, value, name_len + value_len))
    }

    /// Resolves a 1-based HPACK index to a `(name, value)` pair.
    ///
    /// Indices `1..=STATIC_TABLE_SIZE` are served by `get_static_entry`;
    /// larger ones are looked up in the dynamic table at
    /// `index - STATIC_TABLE_SIZE`.
    ///
    /// # Errors
    ///
    /// `HpackError::InvalidIndex(index)` when `index` is `0` or no entry
    /// exists at that position.
    fn get_entry(&self, index: usize) -> Result<(&[u8], &[u8]), HpackError> {
        let entry = match index {
            0 => None,
            1..=STATIC_TABLE_SIZE => get_static_entry(index),
            _ => self
                .dynamic_table
                .get(index - STATIC_TABLE_SIZE)
                .map(|e| (e.name(), e.value())),
        };
        entry.ok_or(HpackError::InvalidIndex(index))
    }

    /// Decodes one string literal from the start of `data`.
    ///
    /// The top bit of the first byte selects Huffman coding; the remaining
    /// 7 bits start the length integer. Returns the decoded bytes together
    /// with the total number of input bytes consumed (length prefix plus
    /// payload).
    ///
    /// # Errors
    ///
    /// `HpackError::UnexpectedEof` when `data` is empty or shorter than the
    /// declared length; otherwise whatever `decode_integer` or
    /// `huffman_decode` report.
    fn decode_string_literal(&self, data: &[u8]) -> Result<(Vec<u8>, usize), HpackError> {
        let first = *data.first().ok_or(HpackError::UnexpectedEof)?;
        let huffman = first & 0x80 != 0;
        let (length, length_consumed) = decode_integer(data, 7, 0x7F)?;
        let start: usize = length_consumed + 1;
        // `length` comes straight off the wire, so compute the end offset with
        // overflow checking: an absurdly large length must be reported as a
        // truncated input rather than wrapping around or panicking.
        let end = start
            .checked_add(length)
            .filter(|&end| end <= data.len())
            .ok_or(HpackError::UnexpectedEof)?;
        let payload = &data[start..end];
        let decoded = if huffman {
            huffman_decode(payload)?
        } else {
            payload.to_vec()
        };
        Ok((decoded, end))
    }
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::h2::hpack_impl::encoder::Encoder;

    /// Decodes `data` with `decoder` and returns every emitted header as
    /// owned `(name, value)` byte vectors, panicking on a decode error.
    fn collect_headers(decoder: &mut Decoder, data: &[u8]) -> Vec<(Vec<u8>, Vec<u8>)> {
        let mut collected = Vec::new();
        decoder
            .decode_with_cb(data, |name, value| {
                collected.push((name.to_vec(), value.to_vec()));
            })
            .unwrap();
        collected
    }

    /// A single indexed representation (static index 2) yields `:method: GET`.
    #[test]
    fn test_decode_indexed_header() {
        let mut decoder = Decoder::new();
        let headers = collect_headers(&mut decoder, &[0x82]);
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, b":method");
        assert_eq!(headers[0].1, b"GET");
    }

    /// A header unknown to the static table survives an encode/decode cycle.
    #[test]
    fn test_decode_literal() {
        let mut encoder = Encoder::new();
        let input = [(b"custom-key".as_slice(), b"custom-value".as_slice())];
        let encoded = encoder.encode(&input);
        let mut decoder = Decoder::new();
        let headers = collect_headers(&mut decoder, &encoded);
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, b"custom-key");
        assert_eq!(headers[0].1, b"custom-value");
    }

    /// Every header of a multi-field block must come back unchanged and in
    /// the original order, not just the first one.
    #[test]
    fn test_round_trip() {
        let headers = [
            (b":method".as_slice(), b"GET".as_slice()),
            (b":scheme".as_slice(), b"http".as_slice()),
            (b":path".as_slice(), b"/".as_slice()),
            (b":authority".as_slice(), b"www.example.com".as_slice()),
        ];
        let mut encoder = Encoder::new();
        let encoded = encoder.encode(&headers);
        let mut decoder = Decoder::new();
        let decoded = collect_headers(&mut decoder, &encoded);
        let expected: Vec<(Vec<u8>, Vec<u8>)> = headers
            .iter()
            .map(|&(name, value)| (name.to_vec(), value.to_vec()))
            .collect();
        assert_eq!(decoded, expected);
    }

    /// A string literal that declares more bytes than are present is an
    /// error and must not reach the callback.
    #[test]
    fn test_truncated_literal_is_rejected() {
        let mut decoder = Decoder::new();
        let mut calls = 0;
        // Literal without indexing, new name: declared length 5, only 1 byte follows.
        let result = decoder.decode_with_cb(&[0x00, 0x05, b'a'], |_, _| calls += 1);
        assert!(result.is_err());
        assert_eq!(calls, 0);
    }

    /// Index 0 is not a valid indexed header field.
    #[test]
    fn test_index_zero_is_rejected() {
        let mut decoder = Decoder::new();
        let result = decoder.decode_with_cb(&[0x80], |_, _| {});
        assert!(result.is_err());
    }
}