use std::collections::{HashMap, VecDeque};
use std::io::{self, Read, Write};
use std::net::TcpStream;
use std::sync::{Arc, Mutex, OnceLock};
use crate::error::{Error, Result};
use crate::tls::TlsStream;
use crate::{Request, Response};
/// Client connection preface that must open every HTTP/2 connection (RFC 9113 §3.4).
const PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
// Frame type codes (RFC 9113 §6).
const F_DATA: u8 = 0x0;
const F_HEADERS: u8 = 0x1;
// NOTE(review): F_PRIORITY is not referenced in the visible code; the allow keeps it so the
// frame-type table stays complete.
#[allow(dead_code)]
const F_PRIORITY: u8 = 0x2;
const F_RST_STREAM: u8 = 0x3;
const F_SETTINGS: u8 = 0x4;
// NOTE(review): F_PUSH_PROMISE is matched in process_stream_frame, so this allow may be stale.
#[allow(dead_code)]
const F_PUSH_PROMISE: u8 = 0x5;
const F_PING: u8 = 0x6;
const F_GOAWAY: u8 = 0x7;
const F_WINDOW_UPDATE: u8 = 0x8;
const F_CONTINUATION: u8 = 0x9;
// Frame flag bits. END_STREAM (DATA/HEADERS) and ACK (SETTINGS/PING) share bit 0x01; which
// meaning applies depends on the frame type.
const FLAG_END_STREAM: u8 = 0x01;
const FLAG_ACK: u8 = 0x01;
const FLAG_END_HEADERS: u8 = 0x04;
const FLAG_PADDED: u8 = 0x08;
const FLAG_PRIORITY: u8 = 0x20;
// SETTINGS parameter identifiers (RFC 9113 §6.5.2).
const S_HEADER_TABLE_SIZE: u16 = 0x1;
const S_ENABLE_PUSH: u16 = 0x2;
const S_MAX_CONCURRENT_STREAMS: u16 = 0x3;
const S_INITIAL_WINDOW_SIZE: u16 = 0x4;
const S_MAX_FRAME_SIZE: u16 = 0x5;
const S_MAX_HEADER_LIST_SIZE: u16 = 0x6;
// Legal ranges for SETTINGS values: INITIAL_WINDOW_SIZE <= 2^31-1, MAX_FRAME_SIZE in [2^14, 2^24-1].
const INITIAL_WINDOW_SIZE_MAX: u32 = 0x7fff_ffff; const MAX_FRAME_SIZE_MIN: u32 = 16_384; const MAX_FRAME_SIZE_MAX: u32 = 16_777_215;
// Local safety limits on buffered data.
// NOTE(review): MAX_RESPONSE_BYTES is not referenced in the visible code; presumably it caps the
// buffered response body — confirm at its use site.
const MAX_RESPONSE_BYTES: usize = 256 * 1024 * 1024;
// Cap on the raw (still HPACK-encoded) header block accumulated per stream (see push_header_fragment).
const MAX_HEADERS_BUF: usize = 256 * 1024;
// Cap on the decoded header list size (name + value + 32 per entry) enforced in Decoder::decode_block.
const MAX_DECODED_HEADER_LIST: usize = 256 * 1024;
/// Settings most recently advertised by the peer, starting from the protocol defaults
/// listed in RFC 9113 §6.5.2.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PeerSettings {
    header_table_size: u32,
    enable_push: bool,
    max_concurrent_streams: u32,
    initial_window_size: u32,
    max_frame_size: u32,
    max_header_list_size: u32,
}

impl Default for PeerSettings {
    /// Protocol defaults that apply until the peer's first SETTINGS frame arrives.
    fn default() -> Self {
        Self {
            header_table_size: 4096,
            enable_push: true,
            max_concurrent_streams: u32::MAX,
            initial_window_size: 65_535,
            max_frame_size: 16_384,
            max_header_list_size: u32::MAX,
        }
    }
}

impl PeerSettings {
    /// Applies every `(identifier, value)` entry of a SETTINGS payload, in order.
    ///
    /// Unknown identifiers are ignored, as RFC 9113 §6.5.2 requires.
    ///
    /// # Errors
    /// `Error::BadResponse` if the payload is not a whole number of 6-byte entries, or if
    /// ENABLE_PUSH, INITIAL_WINDOW_SIZE or MAX_FRAME_SIZE carry an illegal value. Entries
    /// before the offending one have already been applied at that point.
    fn apply_settings_payload(&mut self, payload: &[u8]) -> Result<()> {
        if payload.len() % 6 != 0 {
            return Err(Error::BadResponse(format!(
                "SETTINGS payload length {} not a multiple of 6",
                payload.len()
            )));
        }
        for entry in payload.chunks_exact(6) {
            let (id_bytes, val_bytes) = entry.split_at(2);
            let id = u16::from_be_bytes([id_bytes[0], id_bytes[1]]);
            let val = u32::from_be_bytes([val_bytes[0], val_bytes[1], val_bytes[2], val_bytes[3]]);
            match id {
                S_HEADER_TABLE_SIZE => self.header_table_size = val,
                S_ENABLE_PUSH if val > 1 => {
                    return Err(Error::BadResponse(format!(
                        "SETTINGS_ENABLE_PUSH must be 0 or 1, got {val}"
                    )));
                }
                S_ENABLE_PUSH => self.enable_push = val == 1,
                S_MAX_CONCURRENT_STREAMS => self.max_concurrent_streams = val,
                S_INITIAL_WINDOW_SIZE if val > INITIAL_WINDOW_SIZE_MAX => {
                    return Err(Error::BadResponse(format!(
                        "SETTINGS_INITIAL_WINDOW_SIZE {val} exceeds 2^31-1 (FLOW_CONTROL_ERROR)"
                    )));
                }
                S_INITIAL_WINDOW_SIZE => self.initial_window_size = val,
                S_MAX_FRAME_SIZE if !(MAX_FRAME_SIZE_MIN..=MAX_FRAME_SIZE_MAX).contains(&val) => {
                    return Err(Error::BadResponse(format!(
                        "SETTINGS_MAX_FRAME_SIZE {val} out of range [16384, 16777215]"
                    )));
                }
                S_MAX_FRAME_SIZE => self.max_frame_size = val,
                S_MAX_HEADER_LIST_SIZE => self.max_header_list_size = val,
                // Unknown setting identifiers are skipped.
                _ => {}
            }
        }
        Ok(())
    }
}
/// Largest legal flow-control window size, 2^31-1 (RFC 9113 §6.9.1).
const WINDOW_MAX: i64 = 0x7fff_ffff;
/// Receive window this client starts with on the connection and on each stream.
const OUR_INITIAL_WINDOW: i64 = 65_535;

/// Connection-level send window: DATA payload bytes the peer currently allows us to send
/// across the whole connection.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ConnSendWindow {
    available: i64,
}

impl ConnSendWindow {
    /// Starts at the protocol-mandated connection window of 65,535 bytes. SETTINGS cannot
    /// change it; only WINDOW_UPDATE on stream 0 can (RFC 9113 §6.9.2).
    fn new() -> Self {
        ConnSendWindow { available: 65_535 }
    }

    /// Grows the window by a WINDOW_UPDATE `increment` received on stream 0.
    ///
    /// # Errors
    /// `Error::BadResponse` if `increment` is zero, which RFC 9113 §6.9 makes a
    /// PROTOCOL_ERROR (a connection error when it arrives on stream 0). Also returned if the
    /// window would exceed 2^31-1 (FLOW_CONTROL_ERROR); the window is left unchanged.
    fn apply_window_update(&mut self, increment: u32) -> Result<()> {
        if increment == 0 {
            return Err(Error::BadResponse(
                "WINDOW_UPDATE with zero increment on connection (PROTOCOL_ERROR)".into(),
            ));
        }
        let new_val = self.available + i64::from(increment);
        if new_val > WINDOW_MAX {
            return Err(Error::BadResponse(format!(
                "WINDOW_UPDATE pushes conn send window to {new_val} > 2^31-1 (FLOW_CONTROL_ERROR)"
            )));
        }
        self.available = new_val;
        Ok(())
    }

    /// Records that `n` bytes of DATA payload were sent.
    fn consume(&mut self, n: usize) {
        self.available -= n as i64;
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct StreamSendWindow {
available: i64,
initial_peer_window: i64,
}
impl StreamSendWindow {
fn new(initial: i64) -> Self {
StreamSendWindow {
available: initial,
initial_peer_window: initial,
}
}
fn apply_window_update(&mut self, increment: u32) -> Result<()> {
if increment == 0 {
return Err(Error::BadResponse(
"WINDOW_UPDATE with zero increment on stream (PROTOCOL_ERROR)".into(),
));
}
let new_val = self.available + increment as i64;
if new_val > WINDOW_MAX {
return Err(Error::BadResponse(format!(
"WINDOW_UPDATE pushes stream send window to {new_val} > 2^31-1 (FLOW_CONTROL_ERROR)"
)));
}
self.available = new_val;
Ok(())
}
fn apply_initial_window_change(&mut self, new_initial: u32) -> Result<()> {
let new_i = new_initial as i64;
let delta = new_i - self.initial_peer_window;
let new_available = self.available + delta;
if new_available > WINDOW_MAX {
return Err(Error::BadResponse(format!(
"SETTINGS_INITIAL_WINDOW_SIZE delta pushes stream send window to {new_available} > 2^31-1 (FLOW_CONTROL_ERROR)"
)));
}
self.available = new_available;
self.initial_peer_window = new_i;
Ok(())
}
fn consume(&mut self, n: usize) {
self.available -= n as i64;
}
}
/// Connection-level receive window: how much more DATA the peer may send before this client
/// grants more credit with a WINDOW_UPDATE on stream 0.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ConnRecvWindow {
    available: i64,
    initial: i64,
}

impl ConnRecvWindow {
    /// Starts with `OUR_INITIAL_WINDOW` bytes of credit.
    fn new() -> Self {
        Self {
            available: OUR_INITIAL_WINDOW,
            initial: OUR_INITIAL_WINDOW,
        }
    }

    /// Records `n` bytes of received DATA payload.
    fn consume(&mut self, n: usize) {
        self.available -= n as i64;
    }

    /// Once less than half of the initial window remains, refills the window to `initial`
    /// and returns the connection-level WINDOW_UPDATE granting the difference. Otherwise
    /// returns `None`.
    fn replenish(&mut self) -> Option<Frame> {
        if self.available >= self.initial / 2 {
            return None;
        }
        let increment = (self.initial - self.available) as u32;
        self.available = self.initial;
        Some(window_update_frame(0, increment))
    }
}
/// Per-stream receive window: how much more DATA the peer may send on one stream before
/// this client grants more credit with a WINDOW_UPDATE on that stream.
#[derive(Debug, Clone, PartialEq, Eq)]
struct StreamRecvWindow {
    available: i64,
    initial: i64,
}

impl StreamRecvWindow {
    /// Starts with `OUR_INITIAL_WINDOW` bytes of credit.
    fn new() -> Self {
        Self {
            initial: OUR_INITIAL_WINDOW,
            available: OUR_INITIAL_WINDOW,
        }
    }

    /// Records `n` bytes of received DATA payload on this stream.
    fn consume(&mut self, n: usize) {
        self.available -= n as i64;
    }

    /// Below half of the initial window, refills to `initial` and yields a WINDOW_UPDATE
    /// for `stream_id` granting the difference; otherwise `None`.
    fn replenish(&mut self, stream_id: u32) -> Option<Frame> {
        let below_half = self.available < self.initial / 2;
        below_half.then(|| {
            let increment = (self.initial - self.available) as u32;
            self.available = self.initial;
            window_update_frame(stream_id, increment)
        })
    }
}
/// Builds a WINDOW_UPDATE frame for `stream_id` (0 = connection) carrying `increment`.
/// The reserved high bit of the increment is cleared before encoding.
fn window_update_frame(stream_id: u32, increment: u32) -> Frame {
    let payload = (increment & 0x7fff_ffff).to_be_bytes().to_vec();
    Frame {
        typ: F_WINDOW_UPDATE,
        flags: 0,
        stream_id,
        payload,
    }
}
fn parse_window_update(payload: &[u8]) -> Result<u32> {
if payload.len() != 4 {
return Err(Error::BadResponse(format!(
"WINDOW_UPDATE payload length {} (expected 4) (FRAME_SIZE_ERROR)",
payload.len()
)));
}
let raw = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
Ok(raw & 0x7fff_ffff)
}
/// One raw HTTP/2 frame: the fields of the 9-octet frame header plus its payload
/// (RFC 9113 §4.1).
#[derive(Debug, Clone, PartialEq, Eq)]
struct Frame {
// Frame type code (one of the `F_*` constants, or an unknown type).
typ: u8,
// Type-specific flag bits (`FLAG_*`).
flags: u8,
// 31-bit stream identifier (the reserved bit is masked off in read_frame); 0 addresses the
// connection itself.
stream_id: u32,
// Payload bytes, excluding the 9-octet header.
payload: Vec<u8>,
}
// Hard cap (1 MiB) on the payload length accepted by read_frame and emitted by write_frame.
// It is larger than the 16,384-byte default SETTINGS_MAX_FRAME_SIZE, so it acts as a memory
// guard rather than as protocol-level frame-size enforcement.
const MAX_FRAME_PAYLOAD: usize = 1 << 20;
/// Fills `buf` completely from `r`, delegating to `Read::read_exact`. If the reader ends
/// first, the error kind is `UnexpectedEof`.
fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> io::Result<()> {
    Read::read_exact(r, buf)
}
/// Reads one complete frame (9-octet header plus payload) from `r`.
///
/// The reserved high bit of the stream identifier is discarded.
///
/// # Errors
/// I/O errors from the reader (`UnexpectedEof` on a short read), or `InvalidData` when the
/// declared payload length exceeds `MAX_FRAME_PAYLOAD`.
fn read_frame<R: Read>(r: &mut R) -> io::Result<Frame> {
    let mut header = [0u8; 9];
    read_exact(r, &mut header)?;
    let [l0, l1, l2, typ, flags, s0, s1, s2, s3] = header;
    let length = (usize::from(l0) << 16) | (usize::from(l1) << 8) | usize::from(l2);
    let stream_id = u32::from_be_bytes([s0 & 0x7f, s1, s2, s3]);
    if length > MAX_FRAME_PAYLOAD {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("frame payload too large: {length}"),
        ));
    }
    let mut payload = vec![0u8; length];
    if !payload.is_empty() {
        read_exact(r, &mut payload)?;
    }
    Ok(Frame {
        typ,
        flags,
        stream_id,
        payload,
    })
}
/// Serialises `f` (9-octet header, then payload) into `w`. It does not flush.
///
/// The reserved high bit of the stream identifier is always written as zero.
///
/// # Errors
/// `InvalidInput` if the payload exceeds `MAX_FRAME_PAYLOAD`; otherwise any writer error.
fn write_frame<W: Write>(w: &mut W, f: &Frame) -> io::Result<()> {
    let len = f.payload.len();
    if len > MAX_FRAME_PAYLOAD {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "frame payload too large",
        ));
    }
    // `len` fits in 24 bits here, so the low three bytes of its u32 form are the length field.
    let len_be = (len as u32).to_be_bytes();
    let sid_be = (f.stream_id & 0x7fff_ffff).to_be_bytes();
    let header = [
        len_be[1], len_be[2], len_be[3], f.typ, f.flags, sid_be[0], sid_be[1], sid_be[2], sid_be[3],
    ];
    w.write_all(&header)?;
    if len > 0 {
        w.write_all(&f.payload)?;
    }
    Ok(())
}
/// Encodes `value` as an HPACK prefixed integer (RFC 7541 §5.1) using the low
/// `prefix_bits` bits of the first octet. The caller ORs any flag bits into `result[0]`.
fn encode_int(value: u64, prefix_bits: u8) -> Vec<u8> {
    let limit: u64 = (1u64 << prefix_bits) - 1;
    if value < limit {
        return vec![value as u8];
    }
    let mut encoded = vec![limit as u8];
    let mut rest = value - limit;
    // Emit 7 bits at a time, least-significant first; the high bit marks "more follows".
    loop {
        let low = (rest & 0x7f) as u8;
        rest >>= 7;
        if rest == 0 {
            encoded.push(low);
            return encoded;
        }
        encoded.push(low | 0x80);
    }
}
fn decode_int(buf: &[u8], prefix_bits: u8) -> Result<(u64, usize)> {
if buf.is_empty() {
return Err(Error::BadResponse("hpack: empty integer".into()));
}
let max_prefix: u64 = (1u64 << prefix_bits) - 1;
let mut value = (buf[0] as u64) & max_prefix;
if value < max_prefix {
return Ok((value, 1));
}
let mut i = 1usize;
let mut shift = 0u32;
loop {
if i >= buf.len() {
return Err(Error::BadResponse("hpack: truncated integer".into()));
}
let b = buf[i];
i += 1;
value = value
.checked_add(((b & 0x7f) as u64) << shift)
.ok_or_else(|| Error::BadResponse("hpack: integer overflow".into()))?;
if b & 0x80 == 0 {
return Ok((value, i));
}
shift += 7;
if shift > 63 {
return Err(Error::BadResponse("hpack: integer overflow".into()));
}
}
}
/// HPACK static table (RFC 7541 Appendix A). Entry `i` of this slice has HPACK index `i + 1`.
const STATIC_TABLE: &[(&str, &str)] = &[
(":authority", ""), (":method", "GET"), (":method", "POST"), (":path", "/"), (":path", "/index.html"), (":scheme", "http"), (":scheme", "https"), (":status", "200"), (":status", "204"), (":status", "206"), (":status", "304"), (":status", "400"), (":status", "404"), (":status", "500"), ("accept-charset", ""), ("accept-encoding", "gzip, deflate"), ("accept-language", ""), ("accept-ranges", ""), ("accept", ""), ("access-control-allow-origin", ""), ("age", ""), ("allow", ""), ("authorization", ""), ("cache-control", ""), ("content-disposition", ""), ("content-encoding", ""), ("content-language", ""), ("content-length", ""), ("content-location", ""), ("content-range", ""), ("content-type", ""), ("cookie", ""), ("date", ""), ("etag", ""), ("expect", ""), ("expires", ""), ("from", ""), ("host", ""), ("if-match", ""), ("if-modified-since", ""), ("if-none-match", ""), ("if-range", ""), ("if-unmodified-since", ""), ("last-modified", ""), ("link", ""), ("location", ""), ("max-forwards", ""), ("proxy-authenticate", ""), ("proxy-authorization", ""), ("range", ""), ("referer", ""), ("refresh", ""), ("retry-after", ""), ("server", ""), ("set-cookie", ""), ("strict-transport-security", ""), ("transfer-encoding", ""), ("user-agent", ""), ("vary", ""), ("via", ""), ("www-authenticate", ""), ];

/// 1-based HPACK index of the first static entry whose name AND value both match.
fn static_full_index(name: &str, value: &str) -> Option<usize> {
    for (i, &(n, v)) in STATIC_TABLE.iter().enumerate() {
        if n == name && v == value {
            return Some(i + 1);
        }
    }
    None
}

/// 1-based HPACK index of the first static entry whose name matches, regardless of value.
fn static_name_index(name: &str) -> Option<usize> {
    STATIC_TABLE
        .iter()
        .zip(1usize..)
        .find_map(|(&(n, _), index)| (n == name).then_some(index))
}
/// HPACK Huffman code (RFC 7541 Appendix B): entry `sym` is `(code, bit_length)` for octet
/// `sym`, and entry 256 is EOS.
const HUFFMAN: [(u32, u8); 257] = [
(0x1ff8, 13),
(0x7fffd8, 23),
(0xfffffe2, 28),
(0xfffffe3, 28),
(0xfffffe4, 28),
(0xfffffe5, 28),
(0xfffffe6, 28),
(0xfffffe7, 28),
(0xfffffe8, 28),
(0xffffea, 24),
(0x3ffffffc, 30),
(0xfffffe9, 28),
(0xfffffea, 28),
(0x3ffffffd, 30),
(0xfffffeb, 28),
(0xfffffec, 28),
(0xfffffed, 28),
(0xfffffee, 28),
(0xfffffef, 28),
(0xffffff0, 28),
(0xffffff1, 28),
(0xffffff2, 28),
(0x3ffffffe, 30),
(0xffffff3, 28),
(0xffffff4, 28),
(0xffffff5, 28),
(0xffffff6, 28),
(0xffffff7, 28),
(0xffffff8, 28),
(0xffffff9, 28),
(0xffffffa, 28),
(0xffffffb, 28),
(0x14, 6),
(0x3f8, 10),
(0x3f9, 10),
(0xffa, 12),
(0x1ff9, 13),
(0x15, 6),
(0xf8, 8),
(0x7fa, 11),
(0x3fa, 10),
(0x3fb, 10),
(0xf9, 8),
(0x7fb, 11),
(0xfa, 8),
(0x16, 6),
(0x17, 6),
(0x18, 6),
(0x0, 5),
(0x1, 5),
(0x2, 5),
(0x19, 6),
(0x1a, 6),
(0x1b, 6),
(0x1c, 6),
(0x1d, 6),
(0x1e, 6),
(0x1f, 6),
(0x5c, 7),
(0xfb, 8),
(0x7ffc, 15),
(0x20, 6),
(0xffb, 12),
(0x3fc, 10),
(0x1ffa, 13),
(0x21, 6),
(0x5d, 7),
(0x5e, 7),
(0x5f, 7),
(0x60, 7),
(0x61, 7),
(0x62, 7),
(0x63, 7),
(0x64, 7),
(0x65, 7),
(0x66, 7),
(0x67, 7),
(0x68, 7),
(0x69, 7),
(0x6a, 7),
(0x6b, 7),
(0x6c, 7),
(0x6d, 7),
(0x6e, 7),
(0x6f, 7),
(0x70, 7),
(0x71, 7),
(0x72, 7),
(0xfc, 8),
(0x73, 7),
(0xfd, 8),
(0x1ffb, 13),
(0x7fff0, 19),
(0x1ffc, 13),
(0x3ffc, 14),
(0x22, 6),
(0x7ffd, 15),
(0x3, 5),
(0x23, 6),
(0x4, 5),
(0x24, 6),
(0x5, 5),
(0x25, 6),
(0x26, 6),
(0x27, 6),
(0x6, 5),
(0x74, 7),
(0x75, 7),
(0x28, 6),
(0x29, 6),
(0x2a, 6),
(0x7, 5),
(0x2b, 6),
(0x76, 7),
(0x2c, 6),
(0x8, 5),
(0x9, 5),
(0x2d, 6),
(0x77, 7),
(0x78, 7),
(0x79, 7),
(0x7a, 7),
(0x7b, 7),
(0x7ffe, 15),
(0x7fc, 11),
(0x3ffd, 14),
(0x1ffd, 13),
(0xffffffc, 28),
(0xfffe6, 20),
(0x3fffd2, 22),
(0xfffe7, 20),
(0xfffe8, 20),
(0x3fffd3, 22),
(0x3fffd4, 22),
(0x3fffd5, 22),
(0x7fffd9, 23),
(0x3fffd6, 22),
(0x7fffda, 23),
(0x7fffdb, 23),
(0x7fffdc, 23),
(0x7fffdd, 23),
(0x7fffde, 23),
(0xffffeb, 24),
(0x7fffdf, 23),
(0xffffec, 24),
(0xffffed, 24),
(0x3fffd7, 22),
(0x7fffe0, 23),
(0xffffee, 24),
(0x7fffe1, 23),
(0x7fffe2, 23),
(0x7fffe3, 23),
(0x7fffe4, 23),
(0x1fffdc, 21),
(0x3fffd8, 22),
(0x7fffe5, 23),
(0x3fffd9, 22),
(0x7fffe6, 23),
(0x7fffe7, 23),
(0xffffef, 24),
(0x3fffda, 22),
(0x1fffdd, 21),
(0xfffe9, 20),
(0x3fffdb, 22),
(0x3fffdc, 22),
(0x7fffe8, 23),
(0x7fffe9, 23),
(0x1fffde, 21),
(0x7fffea, 23),
(0x3fffdd, 22),
(0x3fffde, 22),
(0xfffff0, 24),
(0x1fffdf, 21),
(0x3fffdf, 22),
(0x7fffeb, 23),
(0x7fffec, 23),
(0x1fffe0, 21),
(0x1fffe1, 21),
(0x3fffe0, 22),
(0x1fffe2, 21),
(0x7fffed, 23),
(0x3fffe1, 22),
(0x7fffee, 23),
(0x7fffef, 23),
(0xfffea, 20),
(0x3fffe2, 22),
(0x3fffe3, 22),
(0x3fffe4, 22),
(0x7ffff0, 23),
(0x3fffe5, 22),
(0x3fffe6, 22),
(0x7ffff1, 23),
(0x3ffffe0, 26),
(0x3ffffe1, 26),
(0xfffeb, 20),
(0x7fff1, 19),
(0x3fffe7, 22),
(0x7ffff2, 23),
(0x3fffe8, 22),
(0x1ffffec, 25),
(0x3ffffe2, 26),
(0x3ffffe3, 26),
(0x3ffffe4, 26),
(0x7ffffde, 27),
(0x7ffffdf, 27),
(0x3ffffe5, 26),
(0xfffff1, 24),
(0x1ffffed, 25),
(0x7fff2, 19),
(0x1fffe3, 21),
(0x3ffffe6, 26),
(0x7ffffe0, 27),
(0x7ffffe1, 27),
(0x3ffffe7, 26),
(0x7ffffe2, 27),
(0xfffff2, 24),
(0x1fffe4, 21),
(0x1fffe5, 21),
(0x3ffffe8, 26),
(0x3ffffe9, 26),
(0xffffffd, 28),
(0x7ffffe3, 27),
(0x7ffffe4, 27),
(0x7ffffe5, 27),
(0xfffec, 20),
(0xfffff3, 24),
(0xfffed, 20),
(0x1fffe6, 21),
(0x3fffe9, 22),
(0x1fffe7, 21),
(0x1fffe8, 21),
(0x7ffff3, 23),
(0x3fffea, 22),
(0x3fffeb, 22),
(0x1ffffee, 25),
(0x1ffffef, 25),
(0xfffff4, 24),
(0xfffff5, 24),
(0x3ffffea, 26),
(0x7ffff4, 23),
(0x3ffffeb, 26),
(0x7ffffe6, 27),
(0x3ffffec, 26),
(0x3ffffed, 26),
(0x7ffffe7, 27),
(0x7ffffe8, 27),
(0x7ffffe9, 27),
(0x7ffffea, 27),
(0x7ffffeb, 27),
(0xffffffe, 28),
(0x7ffffec, 27),
(0x7ffffed, 27),
(0x7ffffee, 27),
(0x7ffffef, 27),
(0x7fffff0, 27),
(0x3ffffee, 26),
(0x3fffffff, 30),
];

/// Decodes a Huffman-coded HPACK string literal (RFC 7541 §5.2).
///
/// # Errors
/// `Error::BadResponse` if the input contains the EOS symbol, if more than 7 bits are left
/// over at the end, or if the leftover padding bits are not all ones.
fn huffman_decode(input: &[u8]) -> Result<Vec<u8>> {
    let mut decoded = Vec::with_capacity(input.len().saturating_mul(2));
    // Bit reservoir: only the low `pending` bits of `bits` are meaningful.
    let mut bits: u64 = 0;
    let mut pending: u8 = 0;
    for &octet in input {
        bits = (bits << 8) | u64::from(octet);
        pending += 8;
        while pending >= 5 {
            let longest = pending.min(30);
            // Codes are prefix-free, so the shortest matching prefix is the symbol.
            let hit = (5..=longest).find_map(|width| {
                let candidate = (bits >> (pending - width)) & ((1u64 << width) - 1);
                lookup_huffman(candidate as u32, width).map(|sym| (sym, width))
            });
            match hit {
                Some((256, _)) => {
                    return Err(Error::BadResponse(
                        "hpack: EOS symbol in Huffman literal".into(),
                    ));
                }
                Some((sym, width)) => {
                    decoded.push(sym as u8);
                    pending -= width;
                }
                // Not enough bits for any code yet; wait for the next octet.
                None => break,
            }
        }
    }
    if pending >= 8 {
        return Err(Error::BadResponse(
            "hpack: trailing Huffman bits >= 8".into(),
        ));
    }
    if pending > 0 {
        let ones = (1u64 << pending) - 1;
        if bits & ones != ones {
            return Err(Error::BadResponse("hpack: bad Huffman padding".into()));
        }
    }
    Ok(decoded)
}

/// Symbol (0..=256) whose Huffman code is exactly `code` with bit length `len`, if any.
fn lookup_huffman(code: u32, len: u8) -> Option<u16> {
    HUFFMAN
        .iter()
        .position(|&(c, l)| l == len && c == code)
        .map(|sym| sym as u16)
}
/// Dynamic-table capacity used by both HPACK directions. It equals the protocol-default
/// SETTINGS_HEADER_TABLE_SIZE (4096). This client does not advertise a different value (its
/// SETTINGS frame only carries ENABLE_PUSH), so 4096 is also the largest size the peer's
/// encoder may select for our decoder.
const DYN_TABLE_CAP: usize = 4096;

/// HPACK decoder state for header blocks received from the peer (RFC 7541).
struct Decoder {
    // Dynamic table, newest entry first: HPACK index `STATIC_TABLE.len() + 1 + i` is
    // `dyn_table[i]`.
    dyn_table: Vec<(String, String)>,
    // Sum of `entry_size` over every entry in `dyn_table`.
    dyn_table_size: usize,
    // Maximum table size currently in force (last size update from the peer).
    dyn_table_cap: usize,
}

impl Decoder {
    /// Creates a decoder with an empty dynamic table of capacity `DYN_TABLE_CAP`.
    fn new() -> Self {
        Decoder {
            dyn_table: Vec::new(),
            dyn_table_size: 0,
            dyn_table_cap: DYN_TABLE_CAP,
        }
    }

    /// Table accounting size of an entry: octet lengths plus 32 (RFC 7541 §4.1).
    fn entry_size(name: &str, value: &str) -> usize {
        name.len() + value.len() + 32
    }

    /// Evicts the oldest entries until `incoming` more bytes fit within `dyn_table_cap`.
    fn evict_to_fit(&mut self, incoming: usize) {
        while self.dyn_table_size + incoming > self.dyn_table_cap && !self.dyn_table.is_empty() {
            let (n, v) = self.dyn_table.pop().unwrap();
            self.dyn_table_size = self.dyn_table_size.saturating_sub(Self::entry_size(&n, &v));
        }
    }

    /// Inserts a new newest entry. An entry larger than the whole table empties the table
    /// instead of being stored (RFC 7541 §4.4).
    fn insert(&mut self, name: String, value: String) {
        let sz = Self::entry_size(&name, &value);
        if sz > self.dyn_table_cap {
            self.dyn_table.clear();
            self.dyn_table_size = 0;
            return;
        }
        self.evict_to_fit(sz);
        self.dyn_table.insert(0, (name, value));
        self.dyn_table_size += sz;
    }

    /// Resolves a 1-based HPACK index against the static table, then the dynamic table.
    ///
    /// # Errors
    /// Index 0 or an index beyond the current dynamic table.
    fn lookup(&self, index: u64) -> Result<(String, String)> {
        if index == 0 {
            return Err(Error::BadResponse("hpack: index 0".into()));
        }
        let idx = index as usize;
        if idx <= STATIC_TABLE.len() {
            let (n, v) = STATIC_TABLE[idx - 1];
            return Ok((n.to_string(), v.to_string()));
        }
        let dyn_idx = idx - STATIC_TABLE.len() - 1;
        if dyn_idx >= self.dyn_table.len() {
            return Err(Error::BadResponse(format!(
                "hpack: index {idx} out of range"
            )));
        }
        let (n, v) = &self.dyn_table[dyn_idx];
        Ok((n.clone(), v.clone()))
    }

    /// Name half of `lookup`.
    fn lookup_name(&self, index: u64) -> Result<String> {
        Ok(self.lookup(index)?.0)
    }

    /// Reads one string literal (H flag + 7-bit-prefix length + octets) at `*pos` and
    /// advances `*pos` past it (RFC 7541 §5.2).
    ///
    /// # Errors
    /// Truncated input, malformed Huffman data, or non-UTF-8 octets.
    fn read_string(&self, buf: &[u8], pos: &mut usize) -> Result<String> {
        if *pos >= buf.len() {
            return Err(Error::BadResponse("hpack: truncated string".into()));
        }
        let huffman = buf[*pos] & 0x80 != 0;
        let (len, consumed) = decode_int(&buf[*pos..], 7)?;
        *pos += consumed;
        let end = pos
            .checked_add(len as usize)
            .ok_or_else(|| Error::BadResponse("hpack: string length overflow".into()))?;
        if end > buf.len() {
            return Err(Error::BadResponse("hpack: truncated string body".into()));
        }
        let raw = &buf[*pos..end];
        *pos = end;
        if huffman {
            let bytes = huffman_decode(raw)?;
            String::from_utf8(bytes)
                .map_err(|_| Error::BadResponse("hpack: non-utf8 Huffman literal".into()))
        } else {
            String::from_utf8(raw.to_vec())
                .map_err(|_| Error::BadResponse("hpack: non-utf8 literal".into()))
        }
    }

    /// Decodes one complete header block (all HEADERS/CONTINUATION fragments concatenated)
    /// into an ordered list of `(name, value)` pairs, updating the dynamic table as it goes.
    ///
    /// # Errors
    /// Any malformed representation. Also:
    /// - a dynamic table size update above `DYN_TABLE_CAP`, or one appearing after a header
    ///   field (RFC 7541 §4.2, §6.3);
    /// - a decoded list larger than `MAX_DECODED_HEADER_LIST`.
    fn decode_block(&mut self, buf: &[u8]) -> Result<Vec<(String, String)>> {
        let mut out = Vec::new();
        let mut pos = 0;
        let mut list_size: usize = 0;
        // Size updates are only legal at the very start of a header block.
        let mut size_update_allowed = true;
        while pos < buf.len() {
            let b = buf[pos];
            let entry: (String, String);
            if b & 0x80 != 0 {
                // Indexed Header Field (1xxxxxxx), §6.1.
                let (idx, n) = decode_int(&buf[pos..], 7)?;
                pos += n;
                entry = self.lookup(idx)?;
            } else if b & 0x40 != 0 {
                // Literal Header Field with Incremental Indexing (01xxxxxx), §6.2.1.
                let (idx, n) = decode_int(&buf[pos..], 6)?;
                pos += n;
                let name = if idx == 0 {
                    self.read_string(buf, &mut pos)?
                } else {
                    self.lookup_name(idx)?
                };
                let value = self.read_string(buf, &mut pos)?;
                self.insert(name.clone(), value.clone());
                entry = (name, value);
            } else if b & 0x20 != 0 {
                // Dynamic Table Size Update (001xxxxx), §6.3.
                if !size_update_allowed {
                    return Err(Error::BadResponse(
                        "hpack: dynamic table size update after a header field (COMPRESSION_ERROR)"
                            .into(),
                    ));
                }
                let (new_size, n) = decode_int(&buf[pos..], 5)?;
                pos += n;
                // Exceeding the limit we advertised is a decoding error, not something to clamp.
                if new_size > DYN_TABLE_CAP as u64 {
                    return Err(Error::BadResponse(format!(
                        "hpack: dynamic table size update to {new_size} exceeds limit {DYN_TABLE_CAP} (COMPRESSION_ERROR)"
                    )));
                }
                self.dyn_table_cap = new_size as usize;
                self.evict_to_fit(0);
                continue;
            } else {
                // Literal without Indexing (0000xxxx) or Never Indexed (0001xxxx), §6.2.2-§6.2.3.
                let (idx, n) = decode_int(&buf[pos..], 4)?;
                pos += n;
                let name = if idx == 0 {
                    self.read_string(buf, &mut pos)?
                } else {
                    self.lookup_name(idx)?
                };
                let value = self.read_string(buf, &mut pos)?;
                entry = (name, value);
            }
            size_update_allowed = false;
            list_size = list_size
                .saturating_add(entry.0.len())
                .saturating_add(entry.1.len())
                .saturating_add(32);
            if list_size > MAX_DECODED_HEADER_LIST {
                return Err(Error::BadResponse(
                    "hpack: decoded header list exceeds limit".into(),
                ));
            }
            out.push(entry);
        }
        Ok(out)
    }
}
fn huffman_encode(input: &[u8]) -> Vec<u8> {
let total_bits: usize = input.iter().map(|b| HUFFMAN[*b as usize].1 as usize).sum();
let out_len = total_bits.div_ceil(8);
let mut out = vec![0u8; out_len];
let mut bit_pos: usize = 0;
for &b in input {
let (code, len) = HUFFMAN[b as usize];
let len = len as usize;
let mut remaining = len;
let mut code_left = code as u64;
while remaining > 0 {
let byte_index = bit_pos / 8;
let bit_in_byte = bit_pos % 8; let space_in_byte = 8 - bit_in_byte;
let take = remaining.min(space_in_byte);
let shift = (remaining - take) as u32;
let chunk = ((code_left >> shift) & ((1u64 << take) - 1)) as u8;
out[byte_index] |= chunk << (space_in_byte - take);
if shift > 0 {
code_left &= (1u64 << shift) - 1;
} else {
code_left = 0;
}
remaining -= take;
bit_pos += take;
}
}
let trailing = (8 - (total_bits % 8)) % 8;
if trailing > 0 {
let last = out.len() - 1;
out[last] |= (1u8 << trailing) - 1;
}
out
}
/// Appends `s` to `out` as an HPACK string literal (RFC 7541 §5.2). The Huffman form is
/// used only when it is strictly shorter than the raw octets; the H flag (0x80) in the
/// length prefix says which form follows.
fn encode_literal_string(out: &mut Vec<u8>, s: &str) {
    let raw = s.as_bytes();
    let huffman = huffman_encode(raw);
    let (h_flag, body): (u8, &[u8]) = if huffman.len() < raw.len() {
        (0x80, &huffman)
    } else {
        (0x00, raw)
    };
    let mut prefix = encode_int(body.len() as u64, 7);
    prefix[0] = (prefix[0] & 0x7f) | h_flag;
    out.extend_from_slice(&prefix);
    out.extend_from_slice(body);
}
/// HPACK encoder state for the request header blocks this client sends (RFC 7541).
struct Encoder {
    // Dynamic table, newest entry at the front: HPACK index `STATIC_TABLE.len() + 1 + i`
    // is `dyn_table[i]`.
    dyn_table: VecDeque<(String, String)>,
    // Sum of `entry_size` over every entry in `dyn_table`.
    dyn_table_size: usize,
    // Table size the encoder currently uses (never above the peer's limit or DYN_TABLE_CAP).
    max_dyn_table_size: usize,
    // Size to announce with a Dynamic Table Size Update at the start of the next header block.
    pending_max_table_size_signal: Option<usize>,
}

impl Encoder {
    /// Creates an encoder with an empty dynamic table of `DYN_TABLE_CAP` bytes.
    fn new() -> Self {
        Encoder {
            dyn_table: VecDeque::new(),
            dyn_table_size: 0,
            max_dyn_table_size: DYN_TABLE_CAP,
            pending_max_table_size_signal: None,
        }
    }

    /// Table accounting size of an entry: octet lengths plus 32 (RFC 7541 §4.1).
    fn entry_size(name: &str, value: &str) -> usize {
        name.len() + value.len() + 32
    }

    /// Headers that carry credentials. They are sent as "never indexed" literals so they are
    /// not retained in either side's compression context (RFC 7541 §7.1.3).
    fn is_sensitive(name: &str) -> bool {
        name.eq_ignore_ascii_case("authorization") || name.eq_ignore_ascii_case("proxy-authorization")
    }

    /// Reacts to a new SETTINGS_HEADER_TABLE_SIZE from the peer.
    ///
    /// The encoder may use any size up to the peer's limit (RFC 7541 §4.2). It is capped at
    /// `DYN_TABLE_CAP` so a peer advertising a huge table (up to u32::MAX) cannot make this
    /// client retain unbounded header state. The chosen size is announced at the start of
    /// the next header block.
    fn set_peer_max_table_size(&mut self, n: usize) {
        let n = n.min(DYN_TABLE_CAP);
        self.max_dyn_table_size = n;
        self.evict_to_fit(0);
        self.pending_max_table_size_signal = Some(n);
    }

    /// Evicts the oldest entries until `incoming` more bytes fit.
    fn evict_to_fit(&mut self, incoming: usize) {
        while self.dyn_table_size + incoming > self.max_dyn_table_size && !self.dyn_table.is_empty()
        {
            let (n, v) = self.dyn_table.pop_back().unwrap();
            self.dyn_table_size = self.dyn_table_size.saturating_sub(Self::entry_size(&n, &v));
        }
    }

    /// Mirrors the decoder-side insertion. An entry larger than the whole table empties the
    /// table (RFC 7541 §4.4).
    fn insert(&mut self, name: &str, value: &str) {
        let sz = Self::entry_size(name, value);
        if sz > self.max_dyn_table_size {
            self.dyn_table.clear();
            self.dyn_table_size = 0;
            return;
        }
        self.evict_to_fit(sz);
        self.dyn_table
            .push_front((name.to_string(), value.to_string()));
        self.dyn_table_size += sz;
    }

    /// HPACK index of an exact (name, value) match, static table first.
    fn combined_full_index(&self, name: &str, value: &str) -> Option<u32> {
        if let Some(i) = static_full_index(name, value) {
            return Some(i as u32);
        }
        for (i, (n, v)) in self.dyn_table.iter().enumerate() {
            if n == name && v == value {
                return Some((STATIC_TABLE.len() + 1 + i) as u32);
            }
        }
        None
    }

    /// HPACK index of any entry with a matching name, static table first.
    fn combined_name_index(&self, name: &str) -> Option<u32> {
        if let Some(i) = static_name_index(name) {
            return Some(i as u32);
        }
        for (i, (n, _)) in self.dyn_table.iter().enumerate() {
            if n == name {
                return Some((STATIC_TABLE.len() + 1 + i) as u32);
            }
        }
        None
    }

    /// Appends the encoding of one header field to `out`.
    ///
    /// The order is:
    /// 1. Any pending table-size update (this must lead the block).
    /// 2. Sensitive headers, as never-indexed literals.
    /// 3. An exact indexed match, if one exists.
    /// 4. Otherwise a literal with incremental indexing, reusing an indexed name when
    ///    possible.
    fn encode_header(&mut self, out: &mut Vec<u8>, name: &str, value: &str) {
        if let Some(n) = self.pending_max_table_size_signal.take() {
            let mut bytes = encode_int(n as u64, 5);
            bytes[0] |= 0x20;
            out.extend_from_slice(&bytes);
        }
        if Self::is_sensitive(name) {
            // Literal Header Field Never Indexed (0001xxxx), RFC 7541 §6.2.3; not inserted
            // into the dynamic table.
            let idx = self.combined_name_index(name).unwrap_or(0);
            let mut bytes = encode_int(u64::from(idx), 4);
            bytes[0] |= 0x10;
            out.extend_from_slice(&bytes);
            if idx == 0 {
                encode_literal_string(out, name);
            }
            encode_literal_string(out, value);
            return;
        }
        if let Some(idx) = self.combined_full_index(name, value) {
            let mut bytes = encode_int(idx as u64, 7);
            bytes[0] |= 0x80;
            out.extend_from_slice(&bytes);
            return;
        }
        if let Some(idx) = self.combined_name_index(name) {
            let mut bytes = encode_int(idx as u64, 6);
            bytes[0] |= 0x40;
            out.extend_from_slice(&bytes);
            encode_literal_string(out, value);
            self.insert(name, value);
            return;
        }
        out.push(0x40);
        encode_literal_string(out, name);
        encode_literal_string(out, value);
        self.insert(name, value);
    }
}
fn tcp_connect(req: &Request) -> Result<TcpStream> {
let proxy = req
.proxy
.as_ref()
.filter(|_| !crate::http::proxy_bypassed(req));
let (target_host, target_port) = match proxy {
Some(p) => (p.host.as_str(), p.port),
None => (req.url.host.as_str(), req.url.port),
};
let addr = format!("{target_host}:{target_port}");
let stream = match req.connect_timeout {
Some(t) => {
let first = std::net::ToSocketAddrs::to_socket_addrs(&addr)?
.next()
.ok_or_else(|| Error::InvalidUrl(target_host.to_string()))?;
TcpStream::connect_timeout(&first, t)?
}
None => TcpStream::connect(&addr)?,
};
stream.set_read_timeout(req.read_timeout)?;
stream.set_write_timeout(req.read_timeout)?;
Ok(stream)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
Idle,
Open,
HalfClosedLocal,
HalfClosedRemote,
Closed,
}
impl StreamState {
fn send_data(self, end_stream: bool) -> Result<StreamState> {
match self {
StreamState::Idle => {
if end_stream {
Ok(StreamState::HalfClosedLocal)
} else {
Ok(StreamState::Open)
}
}
StreamState::Open => Ok(if end_stream {
StreamState::HalfClosedLocal
} else {
StreamState::Open
}),
StreamState::HalfClosedRemote => Ok(if end_stream {
StreamState::Closed
} else {
StreamState::HalfClosedRemote
}),
StreamState::HalfClosedLocal | StreamState::Closed => Err(Error::BadResponse(format!(
"internal: tried to send DATA in stream state {self:?}"
))),
}
}
fn recv_data(self, end_stream: bool) -> Result<StreamState> {
match self {
StreamState::Open => Ok(if end_stream {
StreamState::HalfClosedRemote
} else {
StreamState::Open
}),
StreamState::HalfClosedLocal => Ok(if end_stream {
StreamState::Closed
} else {
StreamState::HalfClosedLocal
}),
StreamState::Idle | StreamState::HalfClosedRemote | StreamState::Closed => {
Err(Error::BadResponse(format!(
"received DATA in stream state {self:?} (RFC 9113 §5.1)"
)))
}
}
}
fn recv_headers(self, end_stream: bool) -> Result<StreamState> {
match self {
StreamState::Open => Ok(if end_stream {
StreamState::HalfClosedRemote
} else {
StreamState::Open
}),
StreamState::HalfClosedLocal => Ok(if end_stream {
StreamState::Closed
} else {
StreamState::HalfClosedLocal
}),
StreamState::Closed => Ok(StreamState::Closed),
StreamState::Idle | StreamState::HalfClosedRemote => Err(Error::BadResponse(format!(
"received HEADERS in stream state {self:?} (RFC 9113 §5.1)"
))),
}
}
fn recv_rst(self) -> Result<StreamState> {
match self {
StreamState::Idle => Err(Error::BadResponse(
"RST_STREAM on idle stream (RFC 9113 §5.1)".into(),
)),
_ => Ok(StreamState::Closed),
}
}
}
struct Stream {
#[allow(dead_code)]
id: u32,
state: StreamState,
send_window: StreamSendWindow,
recv_window: StreamRecvWindow,
headers_buf: Vec<u8>,
response_headers: Option<Vec<(String, String)>>,
body: Vec<u8>,
end_stream_recv: bool,
}
impl Stream {
fn new(id: u32, initial_peer_window: i64) -> Self {
Stream {
id,
state: StreamState::Idle,
send_window: StreamSendWindow::new(initial_peer_window),
recv_window: StreamRecvWindow::new(),
headers_buf: Vec::new(),
response_headers: None,
body: Vec::new(),
end_stream_recv: false,
}
}
fn send_budget(&self, conn_window: &ConnSendWindow) -> i64 {
self.send_window.available.min(conn_window.available)
}
fn push_header_fragment(&mut self, frag: &[u8]) -> Result<()> {
if self.headers_buf.len().saturating_add(frag.len()) > MAX_HEADERS_BUF {
return Err(Error::BadResponse(
"header block exceeds size limit (CONTINUATION flood?)".into(),
));
}
self.headers_buf.extend_from_slice(frag);
Ok(())
}
}
/// One HTTP/2 client connection over a byte stream `S`, plus the per-connection protocol
/// state needed to drive requests on it.
struct Connection<S: Read + Write> {
// Underlying transport. The name suggests a TLS session, but any Read + Write works.
tls: S,
// Settings most recently received from the peer.
peer: PeerSettings,
// Connection-level send and receive flow-control windows.
conn_send_window: ConnSendWindow,
conn_recv_window: ConnRecvWindow,
// HPACK decoder for received header blocks.
decoder: Decoder,
// HPACK encoder for request headers; its table size follows the peer's
// SETTINGS_HEADER_TABLE_SIZE.
encoder: Encoder,
// Streams currently tracked, keyed by stream id; removed once a response is handed back.
streams: HashMap<u32, Stream>,
// Next client-initiated (odd) stream id: starts at 1 and advances by 2.
next_stream_id: u32,
// Last-stream-id from a received GOAWAY, if any; once set, the connection is not reusable.
goaway_received: Option<u32>,
// Stream whose header block is still waiting for CONTINUATION frames; while set, any other
// frame is rejected in process_frame.
expecting_continuation: Option<u32>,
}
/// Result of processing one inbound frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DispatchOutcome {
// Frame handled; nothing finished.
Continue,
// The stream with this id finished (produced by stream-level handlers not shown here —
// TODO confirm exact conditions).
Done(u32),
}
impl<S: Read + Write> Connection<S> {
fn new(mut tls: S) -> Result<Self> {
tls.write_all(PREFACE)?;
let mut settings_payload = Vec::with_capacity(6);
settings_payload.extend_from_slice(&S_ENABLE_PUSH.to_be_bytes());
settings_payload.extend_from_slice(&0u32.to_be_bytes());
let our_settings = Frame {
typ: F_SETTINGS,
flags: 0,
stream_id: 0,
payload: settings_payload,
};
write_frame(&mut tls, &our_settings)?;
tls.flush()?;
Ok(Connection {
tls,
peer: PeerSettings::default(),
conn_send_window: ConnSendWindow::new(),
conn_recv_window: ConnRecvWindow::new(),
decoder: Decoder::new(),
encoder: Encoder::new(),
streams: HashMap::new(),
next_stream_id: 1,
goaway_received: None,
expecting_continuation: None,
})
}
/// Whether another request may be started on this connection. That requires all of:
/// - no GOAWAY has been received;
/// - no stream is in flight;
/// - client stream ids are not exhausted.
fn is_usable(&self) -> bool {
    let ids_left = self.next_stream_id < 0x8000_0000;
    self.goaway_received.is_none() && self.streams.is_empty() && ids_left
}
/// Allocates the next client stream id and registers an idle `Stream` for it.
///
/// # Errors
/// - The peer's MAX_CONCURRENT_STREAMS is reached.
/// - The 31-bit id space is exhausted.
/// - A GOAWAY has been received. RFC 9113 §6.8 forbids opening any new stream after GOAWAY,
///   whatever its last-stream-id; a graceful-shutdown GOAWAY carries 2^31-1, which the
///   previous id comparison let through.
fn open_stream(&mut self) -> Result<u32> {
    if (self.streams.len() as u64) >= u64::from(self.peer.max_concurrent_streams) {
        return Err(Error::BadResponse("at MAX_CONCURRENT_STREAMS limit".into()));
    }
    if self.next_stream_id >= 0x8000_0000 {
        return Err(Error::BadResponse(
            "stream id space exhausted (RFC 9113 §5.1.1)".into(),
        ));
    }
    if let Some(last) = self.goaway_received {
        return Err(Error::BadResponse(format!(
            "GOAWAY received with last-stream-id={last}; cannot allocate id={}",
            self.next_stream_id
        )));
    }
    let id = self.next_stream_id;
    self.next_stream_id = self.next_stream_id.saturating_add(2);
    self.streams
        .insert(id, Stream::new(id, i64::from(self.peer.initial_window_size)));
    Ok(id)
}
/// Sends `req` on the already-opened stream `stream_id`.
///
/// First the header block goes out, fragmented to the peer's MAX_FRAME_SIZE. Then the body
/// is sent as DATA frames within the stream and connection send windows. While waiting for
/// window, inbound frames are processed, so SETTINGS/PING/WINDOW_UPDATE are serviced during
/// uploads.
///
/// # Errors
/// - The stream is unknown.
/// - The server finishes the stream before the body is fully sent.
/// - An illegal state transition (e.g. the stream was closed by GOAWAY).
/// - Any transport failure.
fn send_request_on(&mut self, stream_id: u32, req: &Request) -> Result<()> {
    let header_block = build_header_block(&mut self.encoder, req);
    let has_body = !req.body.is_empty();
    let max_frame_size = self.peer.max_frame_size as usize;
    let header_frames =
        fragment_header_block(stream_id, &header_block, max_frame_size, !has_body);
    for f in &header_frames {
        write_frame(&mut self.tls, f)?;
    }
    {
        let s = self
            .streams
            .get_mut(&stream_id)
            .ok_or_else(|| Error::BadResponse(format!("stream {stream_id} not found")))?;
        s.state = s.state.send_data(!has_body)?;
    }
    let mut remaining: &[u8] = req.body.as_slice();
    while !remaining.is_empty() {
        let budget = self.wait_for_send_budget(stream_id)?;
        let max_frame_size = self.peer.max_frame_size as usize;
        let n = next_data_chunk_size(max_frame_size, budget, remaining.len());
        debug_assert!(n > 0, "wait_for_send_budget guarantees positive budget");
        let (chunk, rest) = remaining.split_at(n);
        let is_last = rest.is_empty();
        // Validate the transition before touching the wire. A stream the peer has already
        // closed (process_conn_frame marks GOAWAY-refused streams Closed) must not receive
        // a DATA frame.
        let next_state = self
            .streams
            .get(&stream_id)
            .ok_or_else(|| {
                Error::BadResponse(format!("stream {stream_id} disappeared mid-send"))
            })?
            .state
            .send_data(is_last)?;
        let data_frame = Frame {
            typ: F_DATA,
            flags: if is_last { FLAG_END_STREAM } else { 0 },
            stream_id,
            payload: chunk.to_vec(),
        };
        write_frame(&mut self.tls, &data_frame)?;
        remaining = rest;
        self.conn_send_window.consume(n);
        let s = self
            .streams
            .get_mut(&stream_id)
            .expect("stream looked up above and no frames were read since");
        s.send_window.consume(n);
        s.state = next_state;
    }
    self.tls.flush()?;
    Ok(())
}

/// Blocks until `stream_id` has a positive send budget and returns it, processing inbound
/// frames while waiting.
///
/// Everything written so far is flushed before each blocking read. The peer can only grant
/// more window after it has seen our HEADERS/DATA; leaving them in a write buffer could
/// deadlock both sides.
///
/// # Errors
/// - The stream vanished.
/// - The server finished this stream before the body was fully sent.
/// - Any transport/protocol failure.
fn wait_for_send_budget(&mut self, stream_id: u32) -> Result<i64> {
    loop {
        let budget = self
            .streams
            .get(&stream_id)
            .ok_or_else(|| {
                Error::BadResponse(format!("stream {stream_id} disappeared mid-send"))
            })?
            .send_budget(&self.conn_send_window);
        if budget > 0 {
            return Ok(budget);
        }
        self.tls.flush()?;
        match self.read_and_dispatch()? {
            DispatchOutcome::Done(done_id) if done_id == stream_id => {
                return Err(Error::BadResponse(
                    "server ended stream before request body was fully sent".into(),
                ));
            }
            // Progress on other streams, or nothing finished: keep waiting.
            DispatchOutcome::Continue | DispatchOutcome::Done(_) => {}
        }
    }
}
/// Reads and dispatches frames until `stream_id` has finished, then removes the stream from
/// the connection and returns it.
///
/// "Finished" means one of:
/// - response headers are present, END_STREAM was received, and the state is `Closed` or
///   `HalfClosedRemote`;
/// - a handler reported `DispatchOutcome::Done` for this stream.
///
/// # Errors
/// - The stream is not registered.
/// - The stream was refused by a GOAWAY whose last-stream-id is below it. RFC 9113 §6.8 says
///   such a request was not processed. Waiting on it would otherwise block until the server
///   closes the connection, or forever without a read timeout.
/// - Any transport or protocol error from dispatch.
fn drive_until_stream_done(&mut self, stream_id: u32) -> Result<Stream> {
    loop {
        let Some(s) = self.streams.get(&stream_id) else {
            return Err(Error::BadResponse(format!(
                "stream {stream_id} not registered"
            )));
        };
        let complete = matches!(s.state, StreamState::Closed | StreamState::HalfClosedRemote)
            && s.response_headers.is_some()
            && s.end_stream_recv;
        if complete {
            return Ok(self
                .streams
                .remove(&stream_id)
                .expect("stream present: looked up above"));
        }
        if let Some(last) = self.goaway_received {
            if stream_id > last {
                return Err(Error::BadResponse(format!(
                    "stream {stream_id} refused by GOAWAY (last-stream-id={last}); the server did not process it"
                )));
            }
        }
        match self.read_and_dispatch()? {
            DispatchOutcome::Continue => {}
            DispatchOutcome::Done(done_id) if done_id == stream_id => {
                return Ok(self
                    .streams
                    .remove(&stream_id)
                    .expect("stream present: dispatcher just reported it done"));
            }
            // Another stream finished; keep driving ours.
            DispatchOutcome::Done(_) => {}
        }
    }
}
fn read_and_dispatch(&mut self) -> Result<DispatchOutcome> {
let frame = match read_frame(&mut self.tls) {
Ok(f) => f,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
return Err(Error::UnexpectedEof);
}
Err(e) => return Err(Error::Io(e)),
};
self.process_frame(frame)
}
/// Routes one inbound frame to the connection-level (stream 0) or stream-level handler.
///
/// While a header block is incomplete, only a CONTINUATION on that same stream is accepted
/// (RFC 9113 §6.10).
fn process_frame(&mut self, frame: Frame) -> Result<DispatchOutcome> {
    match self.expecting_continuation {
        Some(awaiting) if frame.typ != F_CONTINUATION || frame.stream_id != awaiting => {
            return Err(Error::BadResponse(format!(
                "expected CONTINUATION on stream {awaiting}, got type=0x{:x} stream={}",
                frame.typ, frame.stream_id
            )));
        }
        _ => {}
    }
    if frame.stream_id == 0 {
        self.process_conn_frame(frame)
    } else {
        self.process_stream_frame(frame)
    }
}
fn process_conn_frame(&mut self, frame: Frame) -> Result<DispatchOutcome> {
match frame.typ {
F_SETTINGS if frame.flags & FLAG_ACK == 0 => {
let old_initial = self.peer.initial_window_size;
let old_header_table_size = self.peer.header_table_size;
self.peer.apply_settings_payload(&frame.payload)?;
let new_initial = self.peer.initial_window_size;
if new_initial != old_initial {
for s in self.streams.values_mut() {
s.send_window.apply_initial_window_change(new_initial)?;
}
}
if self.peer.header_table_size != old_header_table_size {
self.encoder
.set_peer_max_table_size(self.peer.header_table_size as usize);
}
let ack = Frame {
typ: F_SETTINGS,
flags: FLAG_ACK,
stream_id: 0,
payload: Vec::new(),
};
write_frame(&mut self.tls, &ack)?;
self.tls.flush()?;
}
F_SETTINGS => { }
F_PING if frame.flags & FLAG_ACK == 0 => {
let pong = Frame {
typ: F_PING,
flags: FLAG_ACK,
stream_id: 0,
payload: frame.payload.clone(),
};
write_frame(&mut self.tls, &pong)?;
self.tls.flush()?;
}
F_PING => {}
F_WINDOW_UPDATE => {
let inc = parse_window_update(&frame.payload)?;
self.conn_send_window.apply_window_update(inc)?;
}
F_GOAWAY => {
let last = if frame.payload.len() >= 4 {
u32::from_be_bytes([
frame.payload[0],
frame.payload[1],
frame.payload[2],
frame.payload[3],
]) & 0x7fff_ffff
} else {
0
};
self.goaway_received = Some(last);
let doomed: Vec<u32> = self
.streams
.iter()
.filter(|(id, _)| **id > last)
.map(|(id, _)| *id)
.collect();
for id in doomed {
if let Some(s) = self.streams.get_mut(&id) {
s.state = StreamState::Closed;
}
}
}
_ => {
}
}
Ok(DispatchOutcome::Continue)
}
fn process_stream_frame(&mut self, frame: Frame) -> Result<DispatchOutcome> {
match frame.typ {
F_HEADERS => self.process_headers(frame),
F_CONTINUATION => self.process_continuation(frame),
F_DATA => self.process_data(frame),
F_RST_STREAM => self.process_rst(frame),
F_WINDOW_UPDATE => {
let inc = parse_window_update(&frame.payload)?;
if let Some(s) = self.streams.get_mut(&frame.stream_id) {
s.send_window.apply_window_update(inc)?;
}
Ok(DispatchOutcome::Continue)
}
F_PUSH_PROMISE => {
Err(Error::BadResponse(
"received PUSH_PROMISE despite SETTINGS_ENABLE_PUSH=0".into(),
))
}
_ => {
Ok(DispatchOutcome::Continue)
}
}
}
fn process_headers(&mut self, frame: Frame) -> Result<DispatchOutcome> {
let mut payload = frame.payload.as_slice();
let mut pad_len = 0usize;
if frame.flags & FLAG_PADDED != 0 {
if payload.is_empty() {
return Err(Error::BadResponse(
"HEADERS PADDED with empty payload".into(),
));
}
pad_len = payload[0] as usize;
payload = &payload[1..];
}
if frame.flags & FLAG_PRIORITY != 0 {
if payload.len() < 5 {
return Err(Error::BadResponse(
"HEADERS PRIORITY with insufficient payload".into(),
));
}
payload = &payload[5..];
}
if payload.len() < pad_len {
return Err(Error::BadResponse(
"HEADERS padding overruns payload".into(),
));
}
let frag = &payload[..payload.len() - pad_len];
let end_headers = frame.flags & FLAG_END_HEADERS != 0;
let end_stream = frame.flags & FLAG_END_STREAM != 0;
let stream_id = frame.stream_id;
let known = self.streams.contains_key(&stream_id);
if !known {
return Err(Error::BadResponse(format!(
"HEADERS on unknown stream {stream_id} (server push disabled)"
)));
}
let state = self.streams.get(&stream_id).unwrap().state;
if state == StreamState::Closed {
if end_headers {
let _ = self.decoder.decode_block(frag)?;
} else {
self.streams
.get_mut(&stream_id)
.unwrap()
.push_header_fragment(frag)?;
self.expecting_continuation = Some(stream_id);
}
return Ok(DispatchOutcome::Continue);
}
let new_state = state.recv_headers(end_stream)?;
let s = self.streams.get_mut(&stream_id).unwrap();
s.push_header_fragment(frag)?;
if end_stream {
s.end_stream_recv = true;
}
if end_headers {
let block = std::mem::take(&mut s.headers_buf);
let decoded = self.decoder.decode_block(&block)?;
let s = self.streams.get_mut(&stream_id).unwrap();
s.response_headers = Some(decoded);
s.state = new_state;
self.expecting_continuation = None;
} else {
s.state = new_state;
self.expecting_continuation = Some(stream_id);
}
let done = matches!(
self.streams.get(&stream_id).unwrap().state,
StreamState::Closed | StreamState::HalfClosedRemote
) && self.streams.get(&stream_id).unwrap().end_stream_recv
&& self
.streams
.get(&stream_id)
.unwrap()
.response_headers
.is_some();
Ok(if done {
DispatchOutcome::Done(stream_id)
} else {
DispatchOutcome::Continue
})
}
fn process_continuation(&mut self, frame: Frame) -> Result<DispatchOutcome> {
let stream_id = frame.stream_id;
match self.expecting_continuation {
Some(awaiting) if awaiting == stream_id => {}
_ => {
return Err(Error::BadResponse(format!(
"unexpected CONTINUATION on stream {stream_id}"
)));
}
}
let s = self.streams.get_mut(&stream_id).ok_or_else(|| {
Error::BadResponse(format!("CONTINUATION on unknown stream {stream_id}"))
})?;
s.push_header_fragment(&frame.payload)?;
let end_headers = frame.flags & FLAG_END_HEADERS != 0;
if end_headers {
let block = std::mem::take(&mut s.headers_buf);
let decoded = self.decoder.decode_block(&block)?;
let s = self.streams.get_mut(&stream_id).unwrap();
if s.state != StreamState::Closed {
s.response_headers = Some(decoded);
}
self.expecting_continuation = None;
}
let done = matches!(
self.streams.get(&stream_id).unwrap().state,
StreamState::Closed | StreamState::HalfClosedRemote
) && self.streams.get(&stream_id).unwrap().end_stream_recv
&& self
.streams
.get(&stream_id)
.unwrap()
.response_headers
.is_some();
Ok(if done {
DispatchOutcome::Done(stream_id)
} else {
DispatchOutcome::Continue
})
}
fn process_data(&mut self, frame: Frame) -> Result<DispatchOutcome> {
let stream_id = frame.stream_id;
let frame_bytes = frame.payload.len();
self.conn_recv_window.consume(frame_bytes);
let known = self.streams.contains_key(&stream_id);
if !known {
if let Some(upd) = self.conn_recv_window.replenish() {
write_frame(&mut self.tls, &upd)?;
self.tls.flush()?;
}
return Ok(DispatchOutcome::Continue);
}
let state = self.streams.get(&stream_id).unwrap().state;
if state == StreamState::Closed {
return Err(Error::BadResponse(format!(
"DATA on closed stream {stream_id}"
)));
}
let end_stream = frame.flags & FLAG_END_STREAM != 0;
let new_state = state.recv_data(end_stream)?;
let s = self.streams.get_mut(&stream_id).unwrap();
s.recv_window.consume(frame_bytes);
let mut payload = frame.payload.as_slice();
if frame.flags & FLAG_PADDED != 0 {
if payload.is_empty() {
return Err(Error::BadResponse("DATA PADDED with empty payload".into()));
}
let pad_len = payload[0] as usize;
payload = &payload[1..];
if payload.len() < pad_len {
return Err(Error::BadResponse("DATA padding overruns payload".into()));
}
payload = &payload[..payload.len() - pad_len];
}
if s.body.len().saturating_add(payload.len()) > MAX_RESPONSE_BYTES {
return Err(Error::BadResponse(
"response body exceeds size limit".into(),
));
}
s.body.extend_from_slice(payload);
if end_stream {
s.end_stream_recv = true;
}
s.state = new_state;
if let Some(upd) = self.conn_recv_window.replenish() {
write_frame(&mut self.tls, &upd)?;
}
if let Some(upd) = self
.streams
.get_mut(&stream_id)
.unwrap()
.recv_window
.replenish(stream_id)
{
write_frame(&mut self.tls, &upd)?;
}
self.tls.flush()?;
let s = self.streams.get(&stream_id).unwrap();
let done = matches!(s.state, StreamState::Closed | StreamState::HalfClosedRemote)
&& s.end_stream_recv
&& s.response_headers.is_some();
Ok(if done {
DispatchOutcome::Done(stream_id)
} else {
DispatchOutcome::Continue
})
}
/// Handles RST_STREAM.
///
/// The error code is read from the first four payload bytes (0 if the payload is shorter).
/// A reset for an untracked stream, or one already `Closed`, is ignored. Otherwise the stream
/// is moved to its post-reset state and a `BadResponse` naming the error code is returned.
fn process_rst(&mut self, frame: Frame) -> Result<DispatchOutcome> {
    let stream_id = frame.stream_id;
    let code = match frame.payload.get(..4) {
        Some(b) => u32::from_be_bytes([b[0], b[1], b[2], b[3]]),
        None => 0,
    };
    let Some(s) = self.streams.get_mut(&stream_id) else {
        return Ok(DispatchOutcome::Continue);
    };
    if s.state == StreamState::Closed {
        return Ok(DispatchOutcome::Continue);
    }
    s.state = s.state.recv_rst()?;
    Err(Error::BadResponse(format!(
        "stream {stream_id} reset by server, error code {code}"
    )))
}
}
/// Splits an encoded header block into one HEADERS frame followed by as many CONTINUATION
/// frames as needed, each carrying at most `max_frame_size` bytes.
///
/// END_STREAM (when `end_stream`) goes on the HEADERS frame only; END_HEADERS goes on the
/// last frame. An empty block yields a single empty HEADERS frame with END_HEADERS set.
///
/// # Panics
/// If `max_frame_size` is 0 (debug assertion, or division by zero in release builds).
fn fragment_header_block(
    stream_id: u32,
    header_block: &[u8],
    max_frame_size: usize,
    end_stream: bool,
) -> Vec<Frame> {
    debug_assert!(max_frame_size > 0, "max_frame_size must be > 0");
    let stream_flag = if end_stream { FLAG_END_STREAM } else { 0 };
    if header_block.is_empty() {
        return vec![Frame {
            typ: F_HEADERS,
            flags: FLAG_END_HEADERS | stream_flag,
            stream_id,
            payload: Vec::new(),
        }];
    }
    let last_index = header_block.len().div_ceil(max_frame_size) - 1;
    header_block
        .chunks(max_frame_size)
        .enumerate()
        .map(|(idx, piece)| {
            let end_headers = if idx == last_index { FLAG_END_HEADERS } else { 0 };
            let (typ, flags) = if idx == 0 {
                (F_HEADERS, stream_flag | end_headers)
            } else {
                (F_CONTINUATION, end_headers)
            };
            Frame {
                typ,
                flags,
                stream_id,
                payload: piece.to_vec(),
            }
        })
        .collect()
}
/// Returns how many body bytes the next DATA frame may carry.
///
/// This is the smallest of the available send window, the bytes still to send, and the
/// peer's maximum frame size. It returns 0 when the window is exhausted or negative.
fn next_data_chunk_size(max_frame_size: usize, available: i64, remaining: usize) -> usize {
    match available {
        window if window <= 0 => 0,
        window => {
            let limit = std::cmp::min(remaining as i64, max_frame_size as i64);
            std::cmp::min(window, limit) as usize
        }
    }
}
/// Key under which idle connections are pooled: the request URL's scheme, host, and port.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub(crate) struct PoolKey {
    scheme: String,
    host: String,
    port: u16,
}
impl PoolKey {
    /// Builds the pool key for the origin `req` targets.
    fn from_request(req: &Request) -> Self {
        let url = &req.url;
        Self {
            host: url.host.clone(),
            port: url.port,
            scheme: url.scheme.clone(),
        }
    }
}
/// Maximum number of idle connections retained for a single `PoolKey`.
const POOL_PER_KEY_CAP: usize = 4;
/// Maximum number of idle connections retained across all keys.
const POOL_GLOBAL_CAP: usize = 32;
/// An idle connection as stored in the pool.
type PooledConn<S> = Arc<Mutex<Connection<S>>>;
/// Idle connections grouped by `PoolKey`; keys with no connections are removed.
pub(crate) struct PoolInner<S: Read + Write> {
    entries: HashMap<PoolKey, Vec<PooledConn<S>>>,
}
impl<S: Read + Write> PoolInner<S> {
    /// Creates an empty pool.
    fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }
    /// Takes the most recently released connection for `key`, if any, dropping the key's
    /// bucket once it is empty.
    fn checkout(&mut self, key: &PoolKey) -> Option<PooledConn<S>> {
        let (taken, now_empty) = {
            let bucket = self.entries.get_mut(key)?;
            let popped = bucket.pop();
            (popped, bucket.is_empty())
        };
        if now_empty {
            self.entries.remove(key);
        }
        taken
    }
    /// Stores `conn` under `key` unless the global or per-key cap is already reached, in
    /// which case `conn` is simply dropped.
    fn release(&mut self, key: PoolKey, conn: PooledConn<S>) {
        let idle: usize = self.entries.values().map(Vec::len).sum();
        if idle < POOL_GLOBAL_CAP {
            let bucket = self.entries.entry(key).or_default();
            if bucket.len() < POOL_PER_KEY_CAP {
                bucket.push(conn);
            }
        }
    }
    /// Total number of idle connections across all keys.
    #[cfg(test)]
    fn total_len(&self) -> usize {
        self.entries.values().fold(0, |acc, bucket| acc + bucket.len())
    }
}
/// Process-wide pool of idle HTTP/2-over-TLS connections, created on first use.
static POOL: OnceLock<Mutex<PoolInner<TlsStream<TcpStream>>>> = OnceLock::new();
/// Returns the global connection pool, initializing it empty on the first call.
fn global_pool() -> &'static Mutex<PoolInner<TlsStream<TcpStream>>> {
    POOL.get_or_init(|| {
        let empty = PoolInner::new();
        Mutex::new(empty)
    })
}
fn dial_h2(req: &Request) -> Result<Connection<TlsStream<TcpStream>>> {
let tcp = tcp_connect(req)?;
if let Some(p) = req
.proxy
.as_ref()
.filter(|_| !crate::http::proxy_bypassed(req))
{
crate::http::connect_tunnel(&tcp, &req.url, p, &mut std::io::sink())?;
}
let opts = crate::http::tls_opts_from(req, &[b"h2"])?;
let tls = crate::tls::connect_over_tls(tcp, &req.url.host, opts)?;
let negotiated_h2 = tls.alpn_selected().map(|p| p == b"h2").unwrap_or(false);
if !negotiated_h2 {
return Err(Error::H2NotNegotiated);
}
Connection::new(tls)
}
fn pool_eligible(req: &Request) -> bool {
req.verify_tls && req.ca_bundle.is_none()
}
pub fn send(req: Request) -> Result<Response> {
if req.url.scheme != "https" {
return Err(Error::UnsupportedScheme(format!(
"http/2 over {} not supported",
req.url.scheme
)));
}
let key = PoolKey::from_request(&req);
let eligible = pool_eligible(&req);
if eligible {
let pooled = {
let mut guard = global_pool().lock().expect("pool mutex poisoned");
guard.checkout(&key)
};
if let Some(arc) = pooled {
let mut conn_guard = arc.lock().expect("pooled conn mutex poisoned");
if conn_guard.is_usable() {
match run_one_request(&mut conn_guard, &req) {
Ok(resp) => {
let still_usable = conn_guard.is_usable();
drop(conn_guard);
if still_usable {
let mut guard = global_pool().lock().expect("pool mutex poisoned");
guard.release(key.clone(), arc);
}
return Ok(resp);
}
Err(_e) => {
drop(conn_guard);
}
}
}
}
}
let mut fresh = dial_h2(&req)?;
let resp = run_one_request(&mut fresh, &req)?;
if eligible && fresh.is_usable() {
let arc = Arc::new(Mutex::new(fresh));
let mut guard = global_pool().lock().expect("pool mutex poisoned");
guard.release(key, arc);
}
Ok(resp)
}
fn run_one_request<S: Read + Write>(conn: &mut Connection<S>, req: &Request) -> Result<Response> {
let stream_id = conn.open_stream()?;
conn.send_request_on(stream_id, req)?;
let stream = conn.drive_until_stream_done(stream_id)?;
build_response_from_stream(stream)
}
fn build_response_from_stream(stream: Stream) -> Result<Response> {
let headers = stream
.response_headers
.ok_or_else(|| Error::BadResponse("response ended before any HEADERS frame".into()))?;
let mut status: Option<u16> = None;
let mut clean_headers: Vec<(String, String)> = Vec::with_capacity(headers.len());
for (k, v) in headers {
if k == ":status" {
status = Some(
v.parse::<u16>()
.map_err(|_| Error::BadResponse(format!("bad :status {v:?}")))?,
);
} else if k.starts_with(':') {
} else {
clean_headers.push((k, v));
}
}
let status = status.ok_or_else(|| Error::BadResponse("response missing :status".into()))?;
let (clean_headers, body) =
crate::http::maybe_decode_body(clean_headers, stream.body, &mut std::io::sink())?;
Ok(Response {
status,
reason: String::new(), version: "HTTP/2".to_string(),
headers: clean_headers,
body,
})
}
/// HPACK-encodes the request header block for `req` using `encoder`.
///
/// The pseudo-headers `:method`, `:scheme`, `:authority`, and `:path` are emitted first.
/// `:authority` omits the port for https on 443. The caller's headers follow with lowercased
/// names. `Host` and connection-specific fields are dropped, except `TE: trailers`, the one
/// TE value HTTP/2 permits (RFC 9113 §8.2.2), which is sent as `te: trailers`.
///
/// Unless the caller supplied them, defaults are added for:
/// - `authorization`, when `effective_basic_auth` yields credentials;
/// - `user-agent`;
/// - `accept`;
/// - `accept-encoding`.
///
/// A `content-length` is added when the body is non-empty.
fn build_header_block(encoder: &mut Encoder, req: &Request) -> Vec<u8> {
    let mut out = Vec::new();
    encoder.encode_header(&mut out, ":method", &req.method);
    encoder.encode_header(&mut out, ":scheme", &req.url.scheme);
    let authority = if req.url.port == 443 && req.url.scheme == "https" {
        req.url.host.clone()
    } else {
        format!("{}:{}", req.url.host, req.url.port)
    };
    encoder.encode_header(&mut out, ":authority", &authority);
    encoder.encode_header(&mut out, ":path", &req.url.path);
    let mut have_ua = false;
    let mut have_accept = false;
    let mut have_accept_enc = false;
    let mut have_auth = false;
    for (k, v) in &req.headers {
        // :authority already carries the host.
        if k.eq_ignore_ascii_case("host") {
            continue;
        }
        // TE is connection-specific, but HTTP/2 allows it with exactly the value "trailers"
        // (some servers, e.g. gRPC, require it).
        if k.eq_ignore_ascii_case("te") && v.trim().eq_ignore_ascii_case("trailers") {
            encoder.encode_header(&mut out, "te", "trailers");
            continue;
        }
        if is_connection_specific_header(k) {
            continue;
        }
        // HTTP/2 requires lowercase field names.
        let lk = k.to_ascii_lowercase();
        match lk.as_str() {
            "user-agent" => have_ua = true,
            "accept" => have_accept = true,
            "accept-encoding" => have_accept_enc = true,
            "authorization" => have_auth = true,
            _ => {}
        }
        encoder.encode_header(&mut out, &lk, v);
    }
    if !have_auth {
        if let Some(creds) = crate::http::effective_basic_auth(req) {
            let value = format!("Basic {creds}");
            encoder.encode_header(&mut out, "authorization", &value);
        }
    }
    if !have_ua {
        encoder.encode_header(
            &mut out,
            "user-agent",
            concat!("rsurl/", env!("CARGO_PKG_VERSION")),
        );
    }
    if !have_accept {
        encoder.encode_header(&mut out, "accept", "*/*");
    }
    if !have_accept_enc {
        encoder.encode_header(&mut out, "accept-encoding", "gzip, deflate");
    }
    if !req.body.is_empty() {
        let len = req.body.len().to_string();
        encoder.encode_header(&mut out, "content-length", &len);
    }
    out
}
/// Reports whether `name` is a connection-specific header field that is not forwarded over
/// HTTP/2. The comparison is ASCII case-insensitive.
fn is_connection_specific_header(name: &str) -> bool {
    const CONNECTION_SPECIFIC: [&str; 6] = [
        "connection",
        "proxy-connection",
        "keep-alive",
        "transfer-encoding",
        "upgrade",
        "te",
    ];
    CONNECTION_SPECIFIC
        .iter()
        .any(|banned| name.eq_ignore_ascii_case(banned))
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
#[test]
fn int_encode_small() {
assert_eq!(encode_int(10, 5), vec![10]);
}
#[test]
fn int_encode_large() {
assert_eq!(encode_int(1337, 5), vec![0x1f, 0x9a, 0x0a]);
}
#[test]
fn int_encode_eight_bit() {
assert_eq!(encode_int(42, 8), vec![42]);
}
#[test]
fn int_decode_round_trips() {
for &(v, p) in &[
(0u64, 5),
(10, 5),
(30, 5),
(31, 5),
(1337, 5),
(1, 8),
(255, 8),
] {
let enc = encode_int(v, p);
let (dec, n) = decode_int(&enc, p).unwrap();
assert_eq!(dec, v, "value {v} with {p}-bit prefix");
assert_eq!(n, enc.len());
}
}
#[test]
fn int_decode_truncated_errors() {
assert!(decode_int(&[0x1f], 5).is_err());
assert!(decode_int(&[0x1f, 0x80], 5).is_err());
}
#[test]
fn static_table_method_get() {
assert_eq!(static_full_index(":method", "GET"), Some(2));
}
#[test]
fn static_table_method_post() {
assert_eq!(static_full_index(":method", "POST"), Some(3));
}
#[test]
fn static_table_name_only() {
assert_eq!(static_name_index(":status"), Some(8));
assert_eq!(static_name_index("user-agent"), Some(58));
assert_eq!(static_name_index("does-not-exist"), None);
}
#[test]
fn static_table_length() {
assert_eq!(STATIC_TABLE.len(), 61);
}
#[test]
fn frame_round_trip_empty_settings() {
let f = Frame {
typ: F_SETTINGS,
flags: 0,
stream_id: 0,
payload: Vec::new(),
};
let mut buf = Vec::new();
write_frame(&mut buf, &f).unwrap();
assert_eq!(buf.len(), 9);
let mut cur = Cursor::new(buf);
let g = read_frame(&mut cur).unwrap();
assert_eq!(g, f);
}
#[test]
fn frame_round_trip_headers_with_payload() {
let f = Frame {
typ: F_HEADERS,
flags: FLAG_END_STREAM | FLAG_END_HEADERS,
stream_id: 1,
payload: vec![
0x82, 0x86, 0x84, 0x41, 0x88, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab,
0x90, 0xf4, 0xff,
],
};
let mut buf = Vec::new();
write_frame(&mut buf, &f).unwrap();
let mut cur = Cursor::new(buf);
let g = read_frame(&mut cur).unwrap();
assert_eq!(g, f);
assert_eq!(g.flags, 0x05);
}
#[test]
fn frame_stream_id_high_bit_masked_on_read() {
let buf = vec![0, 0, 0, F_DATA, 0, 0x80, 0, 0, 1];
let mut cur = Cursor::new(buf);
let f = read_frame(&mut cur).unwrap();
assert_eq!(f.stream_id, 1);
}
#[test]
fn hpack_encode_indexed_method() {
let mut enc = Encoder::new();
let mut out = Vec::new();
enc.encode_header(&mut out, ":method", "GET");
assert_eq!(out, vec![0x82]);
assert!(enc.dyn_table.is_empty());
}
#[test]
fn hpack_encode_literal_with_indexed_name() {
let mut enc = Encoder::new();
let mut out = Vec::new();
enc.encode_header(&mut out, ":path", "/foo");
assert_eq!(out[0], 0x44);
let mut dec = Decoder::new();
let got = dec.decode_block(&out).unwrap();
assert_eq!(got, vec![(":path".into(), "/foo".into())]);
assert_eq!(enc.dyn_table.len(), 1);
assert_eq!(enc.dyn_table[0], (":path".to_string(), "/foo".to_string()));
}
#[test]
fn hpack_encode_literal_full() {
let mut enc = Encoder::new();
let mut out = Vec::new();
enc.encode_header(&mut out, "x-custom", "yes");
assert_eq!(out[0], 0x40);
let mut dec = Decoder::new();
let got = dec.decode_block(&out).unwrap();
assert_eq!(got, vec![("x-custom".into(), "yes".into())]);
assert_eq!(enc.dyn_table[0], ("x-custom".into(), "yes".into()));
}
#[test]
fn hpack_decode_round_trip_pseudo_headers() {
let mut enc = Encoder::new();
let mut block = Vec::new();
enc.encode_header(&mut block, ":method", "GET");
enc.encode_header(&mut block, ":scheme", "https");
enc.encode_header(&mut block, ":authority", "example.com");
enc.encode_header(&mut block, ":path", "/");
let mut dec = Decoder::new();
let got = dec.decode_block(&block).unwrap();
assert_eq!(got.len(), 4);
assert_eq!(got[0], (":method".into(), "GET".into()));
assert_eq!(got[1], (":scheme".into(), "https".into()));
assert_eq!(got[2], (":authority".into(), "example.com".into()));
assert_eq!(got[3], (":path".into(), "/".into()));
}
#[test]
fn hpack_decode_indexed_static() {
let mut dec = Decoder::new();
let got = dec.decode_block(&[0x82]).unwrap();
assert_eq!(got, vec![(":method".into(), "GET".into())]);
}
#[test]
fn hpack_decode_literal_with_incremental_indexing() {
let buf: Vec<u8> = vec![
0x40, 0x0a, b'c', b'u', b's', b't', b'o', b'm', b'-', b'k', b'e', b'y', 0x0d, b'c',
b'u', b's', b't', b'o', b'm', b'-', b'h', b'e', b'a', b'd', b'e', b'r',
];
let mut dec = Decoder::new();
let got = dec.decode_block(&buf).unwrap();
assert_eq!(got, vec![("custom-key".into(), "custom-header".into())]);
assert_eq!(dec.dyn_table.len(), 1);
}
#[test]
fn huffman_decode_c4_1() {
let coded = [
0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
];
let out = huffman_decode(&coded).unwrap();
assert_eq!(out, b"www.example.com");
}
#[test]
fn huffman_decode_c4_2() {
let coded = [0xa8, 0xeb, 0x10, 0x64, 0x9c, 0xbf];
let out = huffman_decode(&coded).unwrap();
assert_eq!(out, b"no-cache");
}
#[test]
fn huffman_decode_c4_3() {
let coded = [0x25, 0xa8, 0x49, 0xe9, 0x5b, 0xa9, 0x7d, 0x7f];
let out = huffman_decode(&coded).unwrap();
assert_eq!(out, b"custom-key");
}
#[test]
fn huffman_decode_rejects_short_padding() {
assert!(huffman_decode(&[0x00]).is_err());
}
#[test]
fn huffman_encode_padding_bits() {
let out = huffman_encode(b"a");
assert_eq!(out, vec![0x1f]);
}
#[test]
fn huffman_encode_appendix_c_www_example_com() {
let out = huffman_encode(b"www.example.com");
assert_eq!(
out,
vec![0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,]
);
}
#[test]
fn huffman_encode_appendix_c_no_cache() {
let out = huffman_encode(b"no-cache");
assert_eq!(out, vec![0xa8, 0xeb, 0x10, 0x64, 0x9c, 0xbf]);
}
#[test]
fn huffman_encode_appendix_c_custom_key() {
let out = huffman_encode(b"custom-key");
assert_eq!(out, vec![0x25, 0xa8, 0x49, 0xe9, 0x5b, 0xa9, 0x7d, 0x7f]);
}
#[test]
fn huffman_encode_appendix_c_custom_value() {
let out = huffman_encode(b"custom-value");
assert_eq!(
out,
vec![0x25, 0xa8, 0x49, 0xe9, 0x5b, 0xb8, 0xe8, 0xb4, 0xbf]
);
}
#[test]
fn huffman_encode_round_trips_through_decoder() {
for s in &[
"",
"a",
"ab",
"abc",
"Hello, world!",
"the quick brown fox jumps",
"/foo/bar/baz",
] {
let bytes = s.as_bytes();
if bytes.is_empty() {
let enc = huffman_encode(bytes);
assert!(enc.is_empty());
continue;
}
let enc = huffman_encode(bytes);
let dec = huffman_decode(&enc).unwrap();
assert_eq!(dec, bytes, "round-trip mismatch for {s:?}");
}
}
#[test]
fn encode_literal_chooses_huffman_when_shorter() {
let mut out = Vec::new();
let s: String = "a".repeat(100);
encode_literal_string(&mut out, &s);
assert_eq!(out[0] & 0x80, 0x80, "Huffman bit should be set");
}
#[test]
fn encode_literal_chooses_raw_when_huffman_longer() {
let mut out = Vec::new();
let bytes: Vec<u8> = vec![0xff; 100];
let huff = huffman_encode(&bytes);
assert!(
huff.len() > bytes.len(),
"0xff Huffman should be longer than raw"
);
let s: String = "|".repeat(50);
out.clear();
encode_literal_string(&mut out, &s);
assert_eq!(out[0] & 0x80, 0x00, "Huffman bit should be cleared");
assert_eq!(out[0] as usize & 0x7f, 50);
assert_eq!(&out[1..], s.as_bytes());
}
#[test]
fn encoder_inserts_into_dyn_table_on_incremental_indexing() {
let mut enc = Encoder::new();
let mut out = Vec::new();
enc.encode_header(&mut out, "x-custom", "value1");
assert_eq!(enc.dyn_table.len(), 1);
assert_eq!(
enc.dyn_table[0],
("x-custom".to_string(), "value1".to_string())
);
assert_eq!(enc.dyn_table_size, "x-custom".len() + "value1".len() + 32);
}
#[test]
fn encoder_evicts_to_fit_max_size() {
let mut enc = Encoder::new();
enc.max_dyn_table_size = 64;
let mut out = Vec::new();
enc.encode_header(&mut out, "n1aa", "v1aa");
enc.encode_header(&mut out, "n2aa", "v2aa");
assert_eq!(enc.dyn_table.len(), 1, "only the newest should remain");
assert_eq!(enc.dyn_table[0], ("n2aa".to_string(), "v2aa".to_string()));
assert_eq!(enc.dyn_table_size, 40);
}
#[test]
fn encoder_emits_size_update_signal_on_next_encode_after_setting_change() {
let mut enc = Encoder::new();
enc.set_peer_max_table_size(1024);
let mut out = Vec::new();
enc.encode_header(&mut out, ":method", "GET");
assert_eq!(out, vec![0x3f, 0xe1, 0x07, 0x82]);
out.clear();
enc.encode_header(&mut out, ":method", "GET");
assert_eq!(out, vec![0x82]);
}
#[test]
fn encoder_uses_dynamic_index_for_repeat() {
let mut enc = Encoder::new();
let mut out = Vec::new();
enc.encode_header(&mut out, "x", "y");
out.clear();
enc.encode_header(&mut out, "x", "y");
assert_eq!(out, vec![0xbe]);
}
#[test]
fn encoder_uses_indexed_name_from_dyn_table() {
let mut enc = Encoder::new();
let mut out = Vec::new();
enc.encode_header(&mut out, "x-foo", "v1");
out.clear();
enc.encode_header(&mut out, "x-foo", "v2");
assert_eq!(out[0], 0x7e);
assert_eq!(enc.dyn_table.len(), 2);
assert_eq!(enc.dyn_table[0].1, "v2");
assert_eq!(enc.dyn_table[1].1, "v1");
}
#[test]
fn encode_decode_round_trip() {
let mut enc = Encoder::new();
let mut dec = Decoder::new();
let inputs: Vec<(&str, &str)> = vec![
(":method", "GET"),
(":scheme", "https"),
(":authority", "example.com"),
(":path", "/foo"),
("user-agent", "rsurl/test"),
("accept", "*/*"),
("x-custom", "hello world"),
("user-agent", "rsurl/test"), ("x-custom", "different"), ];
let mut buf = Vec::new();
for (n, v) in &inputs {
enc.encode_header(&mut buf, n, v);
}
let got = dec.decode_block(&buf).unwrap();
let expected: Vec<(String, String)> = inputs
.into_iter()
.map(|(n, v)| (n.to_string(), v.to_string()))
.collect();
assert_eq!(got, expected);
}
#[test]
fn encoder_size_update_evicts_oversize_entries_immediately() {
let mut enc = Encoder::new();
let mut out = Vec::new();
enc.encode_header(&mut out, "n1aa", "v1aa"); enc.encode_header(&mut out, "n2aa", "v2aa"); assert_eq!(enc.dyn_table.len(), 2);
enc.set_peer_max_table_size(50);
assert_eq!(enc.dyn_table.len(), 1);
assert_eq!(enc.dyn_table[0].0, "n2aa");
}
#[test]
fn hpack_decode_huffman_literal_value() {
let buf = vec![
0x44, 0x8c, 0x60, 0xd4, 0x85, 0x31, 0x68, 0xdf, 0x1c, 0x6f, 0xa2, 0xa6, 0xfd, 0x95,
0xb6, 0x88,
];
let _ = Decoder::new().decode_block(&buf);
}
#[test]
fn build_header_block_includes_pseudo() {
let req = Request::new("GET", "https://example.com/foo").unwrap();
let mut enc = Encoder::new();
let block = build_header_block(&mut enc, &req);
let mut dec = Decoder::new();
let headers = dec.decode_block(&block).unwrap();
let kv: Vec<(&str, &str)> = headers
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect();
assert!(kv.contains(&(":method", "GET")));
assert!(kv.contains(&(":scheme", "https")));
assert!(kv.contains(&(":authority", "example.com")));
assert!(kv.contains(&(":path", "/foo")));
assert!(kv.iter().any(|(k, _)| *k == "user-agent"));
assert!(kv.iter().any(|(k, _)| *k == "accept"));
}
#[test]
fn build_header_block_strips_banned_headers() {
let req = Request::new("GET", "https://example.com/")
.unwrap()
.header("Connection", "close")
.header("Host", "evil.example")
.header("X-Allowed", "yes");
let mut enc = Encoder::new();
let block = build_header_block(&mut enc, &req);
let mut dec = Decoder::new();
let headers = dec.decode_block(&block).unwrap();
let names: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect();
assert!(!names.contains(&"connection"));
assert!(!names.contains(&"host"));
assert!(names.contains(&"x-allowed"));
}
#[test]
fn build_header_block_authority_includes_nonstandard_port() {
let req = Request::new("GET", "https://example.com:8443/").unwrap();
let mut enc = Encoder::new();
let block = build_header_block(&mut enc, &req);
let mut dec = Decoder::new();
let headers = dec.decode_block(&block).unwrap();
let auth = headers.iter().find(|(k, _)| k == ":authority").unwrap();
assert_eq!(auth.1, "example.com:8443");
}
#[test]
fn decoder_dynamic_table_size_update_caps_to_4096() {
let mut dec = Decoder::new();
dec.decode_block(&[0x20]).unwrap();
assert_eq!(dec.dyn_table_cap, 0);
}
#[test]
fn decoder_rejects_oversize_index() {
let mut dec = Decoder::new();
let err = dec.decode_block(&[0xff, 0x01]).unwrap_err();
match err {
Error::BadResponse(_) => {}
other => panic!("expected BadResponse, got {other:?}"),
}
}
fn settings_payload(entries: &[(u16, u32)]) -> Vec<u8> {
let mut out = Vec::with_capacity(entries.len() * 6);
for (id, val) in entries {
out.extend_from_slice(&id.to_be_bytes());
out.extend_from_slice(&val.to_be_bytes());
}
out
}
#[test]
fn peer_settings_defaults_match_rfc() {
let p = PeerSettings::default();
assert_eq!(p.header_table_size, 4096);
assert!(p.enable_push);
assert_eq!(p.max_concurrent_streams, u32::MAX);
assert_eq!(p.initial_window_size, 65_535);
assert_eq!(p.max_frame_size, 16_384);
assert_eq!(p.max_header_list_size, u32::MAX);
}
#[test]
fn peer_settings_apply_updates_known_identifiers() {
let mut p = PeerSettings::default();
let payload = settings_payload(&[
(S_HEADER_TABLE_SIZE, 8192),
(S_INITIAL_WINDOW_SIZE, 131_072),
(S_MAX_FRAME_SIZE, 32_768),
]);
p.apply_settings_payload(&payload).unwrap();
assert_eq!(p.header_table_size, 8192);
assert_eq!(p.initial_window_size, 131_072);
assert_eq!(p.max_frame_size, 32_768);
assert!(p.enable_push);
assert_eq!(p.max_concurrent_streams, u32::MAX);
assert_eq!(p.max_header_list_size, u32::MAX);
}
#[test]
fn peer_settings_ignores_unknown_identifier() {
let mut p = PeerSettings::default();
let before = p.clone();
let payload = settings_payload(&[(0xFFFF, 42)]);
p.apply_settings_payload(&payload).unwrap();
assert_eq!(p, before);
}
#[test]
fn peer_settings_rejects_bad_enable_push() {
let mut p = PeerSettings::default();
let payload = settings_payload(&[(S_ENABLE_PUSH, 2)]);
let err = p.apply_settings_payload(&payload).unwrap_err();
match err {
Error::BadResponse(_) => {}
other => panic!("expected BadResponse, got {other:?}"),
}
}
#[test]
fn peer_settings_rejects_oversize_window() {
let mut p = PeerSettings::default();
let payload = settings_payload(&[(S_INITIAL_WINDOW_SIZE, 0x8000_0000)]);
let err = p.apply_settings_payload(&payload).unwrap_err();
match err {
Error::BadResponse(_) => {}
other => panic!("expected BadResponse, got {other:?}"),
}
}
#[test]
fn peer_settings_rejects_undersize_max_frame() {
let mut p = PeerSettings::default();
let payload = settings_payload(&[(S_MAX_FRAME_SIZE, 16_383)]);
let err = p.apply_settings_payload(&payload).unwrap_err();
match err {
Error::BadResponse(_) => {}
other => panic!("expected BadResponse, got {other:?}"),
}
}
#[test]
fn peer_settings_rejects_truncated_payload() {
let mut p = PeerSettings::default();
let payload = vec![0u8; 5];
let err = p.apply_settings_payload(&payload).unwrap_err();
match err {
Error::BadResponse(_) => {}
other => panic!("expected BadResponse, got {other:?}"),
}
}
#[test]
fn peer_settings_enable_push_zero_disables() {
let mut p = PeerSettings::default();
p.apply_settings_payload(&settings_payload(&[(S_ENABLE_PUSH, 0)]))
.unwrap();
assert!(!p.enable_push);
p.apply_settings_payload(&settings_payload(&[(S_ENABLE_PUSH, 1)]))
.unwrap();
assert!(p.enable_push);
}
#[test]
fn peer_settings_max_frame_size_boundaries() {
let mut p = PeerSettings::default();
p.apply_settings_payload(&settings_payload(&[(S_MAX_FRAME_SIZE, 16_384)]))
.unwrap();
assert_eq!(p.max_frame_size, 16_384);
p.apply_settings_payload(&settings_payload(&[(S_MAX_FRAME_SIZE, 16_777_215)]))
.unwrap();
assert_eq!(p.max_frame_size, 16_777_215);
let err = p
.apply_settings_payload(&settings_payload(&[(S_MAX_FRAME_SIZE, 16_777_216)]))
.unwrap_err();
assert!(matches!(err, Error::BadResponse(_)));
}
#[test]
fn send_window_defaults_match_rfc() {
let c = ConnSendWindow::new();
assert_eq!(c.available, 65_535);
let s = StreamSendWindow::new(65_535);
assert_eq!(s.available, 65_535);
assert_eq!(s.initial_peer_window, 65_535);
}
#[test]
fn send_window_decrements_after_data() {
let mut c = ConnSendWindow::new();
let mut s = StreamSendWindow::new(65_535);
c.consume(1000);
s.consume(1000);
assert_eq!(c.available, 64_535);
assert_eq!(s.available, 64_535);
c.consume(64_535);
s.consume(64_535);
assert_eq!(c.available, 0);
assert_eq!(s.available, 0);
}
#[test]
fn window_update_zero_increment_is_error() {
let zero_payload = [0u8; 4];
let inc = parse_window_update(&zero_payload).unwrap();
assert_eq!(inc, 0);
let mut c = ConnSendWindow::new();
assert!(matches!(
c.apply_window_update(inc),
Err(Error::BadResponse(_))
));
let mut s = StreamSendWindow::new(65_535);
assert!(matches!(
s.apply_window_update(inc),
Err(Error::BadResponse(_))
));
}
#[test]
fn window_update_overflow_is_error() {
let mut c = ConnSendWindow::new();
c.available = WINDOW_MAX;
assert!(matches!(
c.apply_window_update(1),
Err(Error::BadResponse(_))
));
let mut s = StreamSendWindow::new(65_535);
s.available = WINDOW_MAX;
assert!(matches!(
s.apply_window_update(1),
Err(Error::BadResponse(_))
));
}
#[test]
fn window_update_high_bit_ignored_on_parse() {
let payload = [0x80, 0x00, 0x00, 0x01];
let inc = parse_window_update(&payload).unwrap();
assert_eq!(inc, 1);
}
#[test]
fn window_update_wrong_length_is_error() {
assert!(matches!(
parse_window_update(&[0u8; 3]),
Err(Error::BadResponse(_))
));
assert!(matches!(
parse_window_update(&[0u8; 5]),
Err(Error::BadResponse(_))
));
}
#[test]
fn initial_window_size_delta_adjusts_stream_send_window() {
let mut s = StreamSendWindow::new(65_535);
s.apply_initial_window_change(131_072).unwrap();
assert_eq!(s.available, 65_535 + (131_072 - 65_535));
assert_eq!(s.initial_peer_window, 131_072);
s.apply_initial_window_change(0).unwrap();
assert_eq!(s.available, 0);
assert_eq!(s.initial_peer_window, 0);
}
#[test]
fn initial_window_size_delta_overflow_is_error() {
let mut s = StreamSendWindow::new(65_535);
s.available = WINDOW_MAX;
let err = s.apply_initial_window_change(65_536).unwrap_err();
assert!(matches!(err, Error::BadResponse(_)));
}
#[test]
fn initial_window_size_delta_allows_negative_window() {
    // Shrinking the initial window below what has already been spent may
    // leave the stream's send window negative; that is accepted.
    let mut window = StreamSendWindow::new(65_535);
    window.available = 100;
    window.apply_initial_window_change(0).unwrap();
    assert_eq!(window.available, 100 - 65_535);
}
#[test]
fn recv_window_defaults_match_rfc() {
    // Freshly built receive windows start full: initial == available ==
    // OUR_INITIAL_WINDOW, for the connection and for a stream alike.
    let conn_window = ConnRecvWindow::new();
    let stream_window = StreamRecvWindow::new();
    assert_eq!(conn_window.initial, OUR_INITIAL_WINDOW);
    assert_eq!(conn_window.available, OUR_INITIAL_WINDOW);
    assert_eq!(stream_window.initial, OUR_INITIAL_WINDOW);
    assert_eq!(stream_window.available, OUR_INITIAL_WINDOW);
}
#[test]
fn recv_window_no_replenish_above_half() {
    // Consuming only 1000 bytes keeps both windows well stocked, so
    // `replenish` yields no WINDOW_UPDATE and `available` is unchanged.
    let expected = OUR_INITIAL_WINDOW - 1000;

    let mut conn_window = ConnRecvWindow::new();
    conn_window.consume(1000);
    assert!(conn_window.replenish().is_none());
    assert_eq!(conn_window.available, expected);

    let mut stream_window = StreamRecvWindow::new();
    stream_window.consume(1000);
    assert!(stream_window.replenish(1).is_none());
    assert_eq!(stream_window.available, expected);
}
#[test]
fn recv_window_replenishes_when_below_half() {
    // Connection window: two 20_000-byte reads leave 25_535 available.
    let mut conn_window = ConnRecvWindow::new();
    for _ in 0..2 {
        conn_window.consume(20_000);
    }
    assert_eq!(conn_window.available, 25_535);

    let update = conn_window
        .replenish()
        .expect("conn window expected replenish");
    assert_eq!(update.typ, F_WINDOW_UPDATE);
    assert_eq!(update.stream_id, 0);
    assert_eq!(
        parse_window_update(&update.payload).unwrap(),
        (OUR_INITIAL_WINDOW - 25_535) as u32
    );
    // The window is topped back up, so a second call has nothing to send.
    assert_eq!(conn_window.available, OUR_INITIAL_WINDOW);
    assert!(conn_window.replenish().is_none());

    // Stream window: one 40_000-byte read triggers an update for stream 7.
    let mut stream_window = StreamRecvWindow::new();
    stream_window.consume(40_000);
    let update = stream_window
        .replenish(7)
        .expect("stream window expected replenish");
    assert_eq!(update.typ, F_WINDOW_UPDATE);
    assert_eq!(update.stream_id, 7);
    assert_eq!(parse_window_update(&update.payload).unwrap(), 40_000);
    assert_eq!(stream_window.available, OUR_INITIAL_WINDOW);
}
#[test]
fn window_update_frame_payload_shape() {
    // The increment is written big-endian into a 4-byte, flag-less frame.
    let frame = window_update_frame(7, 0x0102_0304);
    assert_eq!(
        (frame.typ, frame.flags, frame.stream_id),
        (F_WINDOW_UPDATE, 0, 7)
    );
    assert_eq!(frame.payload, [0x01, 0x02, 0x03, 0x04]);
}
#[test]
fn fragment_header_block_into_continuation() {
    let max: usize = 16_384;
    let payload_len = 2 * max + 7;
    let block: Vec<u8> = (0..payload_len).map(|i| i as u8).collect();

    // Without END_STREAM: HEADERS(max) + CONTINUATION(max) + CONTINUATION(7).
    let frames = fragment_header_block(1, &block, max, false);
    assert_eq!(frames.len(), 3, "expected HEADERS + 2 CONTINUATION");
    let expected = [(F_HEADERS, max), (F_CONTINUATION, max), (F_CONTINUATION, 7)];
    for (frame, &(typ, len)) in frames.iter().zip(expected.iter()) {
        assert_eq!(frame.typ, typ);
        assert_eq!(frame.stream_id, 1);
        assert_eq!(frame.payload.len(), len);
    }
    // Only the last fragment carries END_HEADERS; the middle one has no flags.
    assert_eq!(frames[0].flags & FLAG_END_HEADERS, 0);
    assert_eq!(frames[0].flags & FLAG_END_STREAM, 0);
    assert_eq!(frames[1].flags, 0);
    assert_eq!(frames[2].flags, FLAG_END_HEADERS);

    // Concatenating the fragments must give back the original block.
    let reassembled: Vec<u8> = frames
        .iter()
        .flat_map(|f| f.payload.iter().copied())
        .collect();
    assert_eq!(reassembled, block);

    // With END_STREAM: the flag rides on HEADERS, never on CONTINUATION.
    let frames = fragment_header_block(1, &block, max, true);
    assert_eq!(frames[0].flags & FLAG_END_STREAM, FLAG_END_STREAM);
    assert_eq!(frames[2].flags & FLAG_END_STREAM, 0);
    assert_eq!(frames[2].flags & FLAG_END_HEADERS, FLAG_END_HEADERS);
}
#[test]
fn fragment_header_block_exact_fit() {
    // A block exactly max_frame_size long fits in a single HEADERS frame.
    let max: usize = 16_384;
    let block = vec![0xab_u8; max];

    let without_end = fragment_header_block(1, &block, max, false);
    assert_eq!(without_end.len(), 1);
    let only = &without_end[0];
    assert_eq!(only.typ, F_HEADERS);
    assert_eq!(only.stream_id, 1);
    assert_eq!(only.payload.len(), max);
    assert_eq!(only.flags & FLAG_END_HEADERS, FLAG_END_HEADERS);
    assert_eq!(only.flags & FLAG_END_STREAM, 0);

    let with_end = fragment_header_block(1, &block, max, true);
    assert_eq!(with_end.len(), 1);
    assert_eq!(
        with_end[0].flags,
        FLAG_END_HEADERS | FLAG_END_STREAM,
        "exact-fit HEADERS with no body should have END_HEADERS|END_STREAM"
    );
}
#[test]
fn fragment_header_block_empty() {
    // An empty header block still produces exactly one HEADERS frame.
    let frames = fragment_header_block(1, &[], 16_384, true);
    match frames.as_slice() {
        [only] => {
            assert_eq!(only.typ, F_HEADERS);
            assert!(only.payload.is_empty());
            assert_eq!(only.flags, FLAG_END_HEADERS | FLAG_END_STREAM);
        }
        other => panic!("expected exactly one frame, got {}", other.len()),
    }
}
#[test]
fn fragment_header_block_small_under_cap() {
    // Three HPACK static-table references (:method GET, :scheme http,
    // :path /), far below the frame-size cap.
    let block: Vec<u8> = vec![0x82, 0x86, 0x84];
    let frames = fragment_header_block(1, &block, 16_384, false);
    assert_eq!(frames.len(), 1);
    let frame = &frames[0];
    assert_eq!(frame.typ, F_HEADERS);
    assert_eq!(frame.flags, FLAG_END_HEADERS);
    assert_eq!(frame.payload, block);
}
#[test]
fn next_data_chunk_size_clamps_to_min_of_three() {
    // (max_frame_size, send window, bytes remaining, expected chunk):
    // whichever of the three limits is smallest wins.
    let cases = [
        (16_384, 65_535, 100, 100),
        (16_384, 65_535, 1_000_000, 16_384),
        (16_384, 1_000, 1_000_000, 1_000),
        (16_384, 5_000, 8_000, 5_000),
    ];
    for &(max_frame, window, remaining, expected) in cases.iter() {
        assert_eq!(next_data_chunk_size(max_frame, window, remaining), expected);
    }
}
#[test]
fn next_data_chunk_size_zero_when_window_depleted() {
    // A zero or negative send window permits no DATA at all.
    for &window in [0, -1, -65_535].iter() {
        assert_eq!(next_data_chunk_size(16_384, window, 100), 0);
    }
}
#[test]
fn fragment_data_into_chunks() {
    /// Splits `body` into DATA frames on stream 1, sizing each chunk with
    /// `next_data_chunk_size` and stopping early when it returns 0.
    /// END_STREAM is set only on the frame that drains the body.
    fn fragment(body: &[u8], max_frame_size: usize, mut available: i64) -> Vec<Frame> {
        let mut frames = Vec::new();
        let mut rest = body;
        while !rest.is_empty() {
            let n = next_data_chunk_size(max_frame_size, available, rest.len());
            if n == 0 {
                break;
            }
            let (chunk, after) = rest.split_at(n);
            rest = after;
            frames.push(Frame {
                typ: F_DATA,
                flags: if rest.is_empty() { FLAG_END_STREAM } else { 0 },
                stream_id: 1,
                payload: chunk.to_vec(),
            });
            available -= n as i64;
        }
        frames
    }

    let body: Vec<u8> = (0..50_000u32).map(|i| i as u8).collect();

    // Ample window: three full frames plus a short tail carrying END_STREAM.
    let frames = fragment(&body, 16_384, 65_535);
    let lens: Vec<usize> = frames.iter().map(|f| f.payload.len()).collect();
    assert_eq!(lens, [16_384, 16_384, 16_384, 50_000 - 3 * 16_384]);
    let flags: Vec<u8> = frames.iter().map(|f| f.flags).collect();
    assert_eq!(flags, [0, 0, 0, FLAG_END_STREAM]);
    let roundtrip: Vec<u8> = frames
        .iter()
        .flat_map(|f| f.payload.iter().copied())
        .collect();
    assert_eq!(roundtrip, body);

    // Window smaller than one frame: a single partial frame, no END_STREAM.
    let frames = fragment(&body, 16_384, 4_000);
    assert_eq!(frames.len(), 1);
    assert_eq!(frames[0].payload.len(), 4_000);
    assert_eq!(frames[0].flags, 0);

    // Body exactly one frame long: one frame, and it ends the stream.
    let exact = vec![0xab_u8; 16_384];
    let frames = fragment(&exact, 16_384, 65_535);
    assert_eq!(frames.len(), 1);
    assert_eq!(frames[0].payload.len(), 16_384);
    assert_eq!(frames[0].flags, FLAG_END_STREAM);

    // Empty body: nothing to send.
    assert!(fragment(&[], 16_384, 65_535).is_empty());
}
/// In-memory test double for the transport: `Read` is served from
/// `wire_in`, and every `Write` is appended to `wire_out`.
struct FakeTls {
    /// Bytes handed out by `Read::read`.
    wire_in: Cursor<Vec<u8>>,
    /// Everything passed to `Write::write`, in call order.
    wire_out: Vec<u8>,
}
impl FakeTls {
    /// Creates a double with nothing to read and nothing written yet.
    fn new() -> Self {
        Self {
            wire_in: Cursor::new(vec![]),
            wire_out: vec![],
        }
    }
}
impl Read for FakeTls {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Read::read(&mut self.wire_in, buf)
    }
}
impl Write for FakeTls {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Every byte is always accepted.
        let accepted = buf.len();
        self.wire_out.extend_from_slice(buf);
        Ok(accepted)
    }
    fn flush(&mut self) -> io::Result<()> {
        // Writes land in `wire_out` immediately; nothing is buffered.
        Ok(())
    }
}
/// Builds a `Connection` over a fresh `FakeTls`. It has default peer
/// settings, new flow-control windows and HPACK state, no streams, no
/// GOAWAY, no pending CONTINUATION, and the next stream id set to 1.
fn fake_conn() -> Connection<FakeTls> {
    let tls = FakeTls::new();
    Connection {
        tls,
        decoder: Decoder::new(),
        encoder: Encoder::new(),
        peer: PeerSettings::default(),
        conn_send_window: ConnSendWindow::new(),
        conn_recv_window: ConnRecvWindow::new(),
        streams: HashMap::new(),
        next_stream_id: 1,
        goaway_received: None,
        expecting_continuation: None,
    }
}
#[test]
fn connection_process_settings_acks_and_applies() {
    let settings = Frame {
        typ: F_SETTINGS,
        flags: 0,
        stream_id: 0,
        payload: settings_payload(&[
            (S_MAX_FRAME_SIZE, 32_768),
            (S_INITIAL_WINDOW_SIZE, 131_072),
        ]),
    };
    let mut conn = fake_conn();
    assert_eq!(
        conn.process_frame(settings).unwrap(),
        DispatchOutcome::Continue
    );

    // The new values are recorded...
    assert_eq!(conn.peer.max_frame_size, 32_768);
    assert_eq!(conn.peer.initial_window_size, 131_072);
    // ...but INITIAL_WINDOW_SIZE leaves the connection send window alone.
    assert_eq!(conn.conn_send_window.available, 65_535);

    // Exactly one 9-byte frame was written: an empty SETTINGS ACK.
    assert_eq!(conn.tls.wire_out.len(), 9);
    let ack = read_frame(&mut Cursor::new(conn.tls.wire_out.clone())).unwrap();
    assert_eq!(ack.typ, F_SETTINGS);
    assert_eq!(ack.flags, FLAG_ACK);
    assert_eq!(ack.stream_id, 0);
    assert!(ack.payload.is_empty());
}
#[test]
fn connection_process_window_update_replenishes_send_window() {
    let mut conn = fake_conn();
    let id = conn.open_stream().unwrap();

    // A stream-scoped update grows only that stream's send window.
    conn.process_frame(window_update_frame(id, 10_000)).unwrap();
    let stream_available = conn.streams.get(&id).unwrap().send_window.available;
    assert_eq!(stream_available, 65_535 + 10_000);
    assert_eq!(conn.conn_send_window.available, 65_535);

    // An update on stream 0 grows the connection send window.
    conn.process_frame(window_update_frame(0, 5_000)).unwrap();
    assert_eq!(conn.conn_send_window.available, 65_535 + 5_000);
}
#[test]
fn stream_state_open_to_half_closed_local_on_end_stream_send() {
    // Sending DATA with END_STREAM moves Open to HalfClosedLocal; without
    // the flag the stream stays Open.
    let after_end = StreamState::Open.send_data(true).unwrap();
    assert_eq!(after_end, StreamState::HalfClosedLocal);
    let after_more = StreamState::Open.send_data(false).unwrap();
    assert_eq!(after_more, StreamState::Open);
}
#[test]
fn stream_state_recv_data_in_idle_is_error() {
    // DATA arriving while the stream is still Idle is rejected.
    let result = StreamState::Idle.recv_data(false);
    assert!(matches!(result, Err(Error::BadResponse(_))));
}
#[test]
fn stream_state_recv_headers_on_closed_stream_is_ignored() {
    // HEADERS on an already-closed stream is tolerated and leaves it Closed.
    let next = StreamState::Closed.recv_headers(true).unwrap();
    assert_eq!(next, StreamState::Closed);
}
#[test]
fn next_stream_id_allocates_odd_only() {
    // Client-initiated ids are odd and handed out in increasing order.
    let mut conn = fake_conn();
    let mut ids = Vec::with_capacity(4);
    for _ in 0..4 {
        ids.push(conn.open_stream().unwrap());
    }
    assert_eq!(ids, [1, 3, 5, 7]);
}
#[test]
fn open_stream_refuses_at_max_concurrent() {
    // The peer allows two concurrent streams, so a third open must fail.
    let mut conn = fake_conn();
    conn.peer.max_concurrent_streams = 2;
    for _ in 0..2 {
        assert!(conn.open_stream().is_ok());
    }
    assert!(matches!(conn.open_stream(), Err(Error::BadResponse(_))));
}
#[test]
fn open_stream_refuses_after_goaway() {
    // With `goaway_received = Some(3)`, ids 1 and 3 are still handed out
    // and the next one is refused.
    // NOTE(review): RFC 7540 §6.8 says a GOAWAY receiver MUST NOT open any
    // additional streams. This test pins the more permissive behaviour of
    // `open_stream`; confirm that is intended.
    let mut conn = fake_conn();
    conn.goaway_received = Some(3);
    let allowed: Vec<u32> = (0..2).map(|_| conn.open_stream().unwrap()).collect();
    assert_eq!(allowed, [1, 3]);
    assert!(matches!(conn.open_stream(), Err(Error::BadResponse(_))));
}
/// Builds a single, complete HEADERS frame for stream `id`. Its one-byte
/// HPACK block (0x88, static-table index 8) decodes to `:status: 200`.
/// END_STREAM is added when `end_stream` is true.
fn synth_status_200_headers(id: u32, end_stream: bool) -> Frame {
    let end_stream_flag = if end_stream { FLAG_END_STREAM } else { 0 };
    Frame {
        typ: F_HEADERS,
        flags: FLAG_END_HEADERS | end_stream_flag,
        stream_id: id,
        payload: vec![0x88],
    }
}
/// Builds a DATA frame for stream `id` carrying a copy of `body`. END_STREAM
/// is set when `end_stream` is true.
fn synth_data(id: u32, body: &[u8], end_stream: bool) -> Frame {
    let flags = match end_stream {
        true => FLAG_END_STREAM,
        false => 0,
    };
    Frame {
        typ: F_DATA,
        flags,
        stream_id: id,
        payload: Vec::from(body),
    }
}
#[test]
fn dispatch_frame_routes_to_correct_stream() {
    let mut conn = fake_conn();
    let id_a = conn.open_stream().unwrap();
    let id_b = conn.open_stream().unwrap();
    for id in [id_a, id_b].iter() {
        conn.streams.get_mut(id).unwrap().state = StreamState::Open;
    }

    // Interleave frames for the two streams. Each body must collect only
    // its own DATA payloads, in arrival order.
    let script = vec![
        synth_status_200_headers(id_a, false),
        synth_status_200_headers(id_b, false),
        synth_data(id_a, b"aaa", false),
        synth_data(id_b, b"bbbb", false),
        synth_data(id_a, b"AAA", true),
        synth_data(id_b, b"BBBB", true),
    ];
    for frame in script {
        conn.process_frame(frame).unwrap();
    }

    assert_eq!(conn.streams.get(&id_a).unwrap().body, b"aaaAAA");
    assert_eq!(conn.streams.get(&id_b).unwrap().body, b"bbbbBBBB");
}
#[test]
fn dispatch_data_on_unknown_stream_is_silently_dropped() {
    // Stream 7 was never opened. Its DATA is accepted without error and
    // creates no stream state.
    let mut conn = fake_conn();
    let orphan = synth_data(7, b"orphaned", false);
    let outcome = conn.process_frame(orphan).unwrap();
    assert_eq!(outcome, DispatchOutcome::Continue);
    assert!(conn.streams.is_empty());
    // NOTE(review): RFC 7540 §6.9 requires every flow-controlled frame to be
    // charged against the connection window. This only checks that the
    // window did not grow; consider pinning the exact value.
    assert!(conn.conn_recv_window.available <= OUR_INITIAL_WINDOW);
}
#[test]
fn dispatch_continuation_on_wrong_stream_is_protocol_error() {
    let mut conn = fake_conn();
    let first = conn.open_stream().unwrap();
    let second = conn.open_stream().unwrap();
    assert_eq!((first, second), (1, 3));
    for id in [first, second].iter() {
        conn.streams.get_mut(id).unwrap().state = StreamState::Open;
    }

    // HEADERS without END_HEADERS on stream 1 leaves the connection
    // expecting a CONTINUATION on that same stream.
    let open_block = Frame {
        typ: F_HEADERS,
        flags: 0,
        stream_id: first,
        payload: vec![0x88],
    };
    conn.process_frame(open_block).unwrap();
    assert_eq!(conn.expecting_continuation, Some(first));

    // A CONTINUATION for a different stream must be rejected.
    let stray = Frame {
        typ: F_CONTINUATION,
        flags: FLAG_END_HEADERS,
        stream_id: second,
        payload: Vec::new(),
    };
    assert!(matches!(conn.process_frame(stray), Err(Error::BadResponse(_))));
}
#[test]
fn data_frames_past_body_cap_are_rejected() {
    let mut conn = fake_conn();
    let id = conn.open_stream().unwrap();
    {
        let stream = conn.streams.get_mut(&id).unwrap();
        stream.state = StreamState::Open;
        // Two bytes of headroom; the 3-byte DATA frame below overshoots it.
        stream.body = vec![0u8; MAX_RESPONSE_BYTES - 2];
    }

    let result = conn.process_frame(synth_data(id, b"abc", false));
    assert!(matches!(result, Err(Error::BadResponse(_))));
    // The rejected payload must not have been partially appended.
    assert_eq!(
        conn.streams.get(&id).unwrap().body.len(),
        MAX_RESPONSE_BYTES - 2
    );
}
#[test]
fn continuation_flood_is_bounded() {
    let mut conn = fake_conn();
    let id = conn.open_stream().unwrap();
    conn.streams.get_mut(&id).unwrap().state = StreamState::Open;

    // Start an unterminated header block (no END_HEADERS)...
    conn.process_frame(Frame {
        typ: F_HEADERS,
        flags: 0,
        stream_id: id,
        payload: vec![0u8; 8 * 1024],
    })
    .unwrap();

    // ...then keep appending 16 KiB CONTINUATION fragments, a few more than
    // MAX_HEADERS_BUF can hold. One of them must be refused.
    const CHUNK: usize = 16 * 1024;
    let attempts = MAX_HEADERS_BUF / CHUNK + 4;
    let mut hit_cap = false;
    for _ in 0..attempts {
        let fragment = Frame {
            typ: F_CONTINUATION,
            flags: 0,
            stream_id: id,
            payload: vec![0u8; CHUNK],
        };
        match conn.process_frame(fragment) {
            Err(Error::BadResponse(_)) => {
                hit_cap = true;
                break;
            }
            other => {
                other.unwrap();
            }
        }
    }

    assert!(hit_cap, "CONTINUATION flood was not bounded");
    assert!(conn.streams.get(&id).unwrap().headers_buf.len() <= MAX_HEADERS_BUF);
}
#[test]
fn hpack_decompression_bomb_is_rejected() {
    /// Encodes one "literal header field with incremental indexing, new
    /// name" (RFC 7541 §6.2.1). The field is 0x40 followed by `name` and
    /// `value` as length-prefixed, non-Huffman strings. `name` must be
    /// shorter than 127 bytes so its length fits the 7-bit prefix.
    fn literal_field(name: &[u8], value: &[u8]) -> Vec<u8> {
        let mut field = vec![0x40, name.len() as u8];
        field.extend_from_slice(name);
        encode_int_local(value.len() as u64, 7, 0x00, &mut field);
        field.extend_from_slice(value);
        field
    }

    // Case 1: sheer volume. The block is 200 copies of `a: <4096 x 'x'>`,
    // roughly 820 KB on the wire and decoded alike, and must be rejected.
    let mut dec = Decoder::new();
    let flood = literal_field(b"a", &[b'x'; 4096]).repeat(200);
    let err = dec.decode_block(&flood).unwrap_err();
    assert!(matches!(err, Error::BadResponse(_)));

    // Case 2: real amplification, which case 1 never exercises. Each case-1
    // entry (1 + 4096 + 32 bytes) is too big to be kept in a 4096-byte
    // dynamic table, so nothing there is ever re-referenced.
    //
    // Here one literal inserts a 1 + 4000 + 32 = 4033-byte entry. That fits,
    // assuming Decoder::new() starts at the RFC 7541 default table size of
    // 4096 (TODO confirm). Each following byte 0xbe (0x80 | 62, the first
    // dynamic-table index) re-emits the entry.
    //
    // About 4.1 KB of input expands to about 400 KB of decoded headers,
    // well above MAX_DECODED_HEADER_LIST (256 KiB), so it must be rejected.
    let mut dec = Decoder::new();
    let mut bomb = literal_field(b"a", &[b'y'; 4000]);
    bomb.extend(std::iter::repeat(0xbe_u8).take(100));
    let err = dec.decode_block(&bomb).unwrap_err();
    assert!(matches!(err, Error::BadResponse(_)));
}
/// Test-local HPACK integer encoder (RFC 7541 §5.1).
///
/// Appends `value` to `out` using a `prefix_bits`-bit prefix. The bits above
/// the prefix in the first byte come from `first_byte_high`. A value that
/// does not fit the prefix saturates it, and the remainder follows as 7-bit
/// groups, least-significant first, with the 0x80 continuation bit on all
/// but the last. The prefix mask is truncated to a `u8`, so only
/// `prefix_bits` in 1..=8 are meaningful.
fn encode_int_local(mut value: u64, prefix_bits: u8, first_byte_high: u8, out: &mut Vec<u8>) {
    let prefix_mask = (1u64 << prefix_bits) - 1;
    if value >= prefix_mask {
        out.push(first_byte_high | prefix_mask as u8);
        value -= prefix_mask;
        while value > 0x7f {
            out.push(0x80 | (value % 128) as u8);
            value /= 128;
        }
        out.push(value as u8);
    } else {
        out.push(first_byte_high | value as u8);
    }
}
/// Wraps a fresh `fake_conn()` in the `Arc<Mutex<_>>` form that the pool
/// tests hand to `PoolInner::release`.
fn fake_arc_conn() -> Arc<Mutex<Connection<FakeTls>>> {
    let guarded = Mutex::new(fake_conn());
    Arc::new(guarded)
}
/// Derives the pool key for a GET request to `url`. Panics if
/// `Request::new` rejects the URL.
fn url_key(url: &str) -> PoolKey {
    PoolKey::from_request(&Request::new("GET", url).unwrap())
}
#[test]
fn pool_key_round_trip() {
    let base = url_key("https://example.com/a");
    // The path does not take part in the key...
    assert_eq!(base, url_key("https://example.com/b"));
    // ...but the port and the host do.
    assert_ne!(base, url_key("https://example.com:8443/a"), "port differs");
    assert_ne!(base, url_key("https://other.example/a"), "host differs");
}
#[test]
fn pool_checkout_empty_returns_none() {
    // Nothing has been released, so there is nothing to check out.
    let key = url_key("https://example.com/");
    let mut pool: PoolInner<FakeTls> = PoolInner::new();
    assert!(pool.checkout(&key).is_none());
}
#[test]
fn pool_release_then_checkout_returns_same_conn() {
    let key = url_key("https://example.com/");
    let mut pool: PoolInner<FakeTls> = PoolInner::new();
    let conn = fake_arc_conn();
    // Record the allocation's address without holding an extra strong ref.
    let released_ptr = Arc::as_ptr(&conn);
    pool.release(key.clone(), conn);

    let checked_out = pool.checkout(&key).expect("checkout after release");
    assert!(
        std::ptr::eq(released_ptr, Arc::as_ptr(&checked_out)),
        "pool returned a different Arc"
    );
    // The only pooled connection has been handed out; nothing is left.
    assert!(pool.checkout(&key).is_none());
}
#[test]
fn pool_per_key_cap_drops_overflow() {
    let key = url_key("https://example.com/");
    let mut pool: PoolInner<FakeTls> = PoolInner::new();
    // Release two more connections than the per-key cap allows...
    for _ in 0..(POOL_PER_KEY_CAP + 2) {
        pool.release(key.clone(), fake_arc_conn());
    }
    // ...and only POOL_PER_KEY_CAP of them can be checked back out.
    let mut retained = 0;
    loop {
        if pool.checkout(&key).is_none() {
            break;
        }
        retained += 1;
    }
    assert_eq!(retained, POOL_PER_KEY_CAP);
}
#[test]
fn pool_global_cap_drops_overflow() {
    let mut pool: PoolInner<FakeTls> = PoolInner::new();
    // One connection per distinct host, for twice as many hosts as the cap.
    for i in 0..(POOL_GLOBAL_CAP * 2) {
        pool.release(url_key(&format!("https://h{i}.example/")), fake_arc_conn());
    }
    let total = pool.total_len();
    assert!(
        total <= POOL_GLOBAL_CAP,
        "pool grew past global cap: {} > {}",
        total,
        POOL_GLOBAL_CAP
    );
    // The pool fills up to the cap exactly rather than stopping short.
    assert_eq!(pool.total_len(), POOL_GLOBAL_CAP);
}
#[test]
fn connection_is_usable_false_after_goaway() {
    let mut conn = fake_conn();
    assert!(
        conn.streams.is_empty(),
        "precondition: fresh conn has no streams"
    );
    // Any recorded GOAWAY, even `Some(0)`, retires an otherwise idle connection.
    conn.goaway_received = Some(0);
    assert!(!conn.is_usable());
}
#[test]
fn connection_is_usable_true_initially() {
    // A brand-new connection, with no GOAWAY and no streams, reports usable.
    assert!(fake_conn().is_usable());
}
#[test]
fn initial_window_size_delta_applies_to_all_streams() {
    let mut conn = fake_conn();
    let ids = [conn.open_stream().unwrap(), conn.open_stream().unwrap()];

    conn.process_frame(Frame {
        typ: F_SETTINGS,
        flags: 0,
        stream_id: 0,
        payload: settings_payload(&[(S_INITIAL_WINDOW_SIZE, 131_072)]),
    })
    .unwrap();

    // Every existing stream's send window shifts by the same delta...
    let expect = 65_535 + (131_072 - 65_535);
    for id in ids.iter() {
        assert_eq!(conn.streams.get(id).unwrap().send_window.available, expect);
    }
    // ...while the connection-level send window is untouched.
    assert_eq!(conn.conn_send_window.available, 65_535);
}
}