use std::{
cell::RefCell,
collections::{HashMap, HashSet},
hash::{Hash, Hasher},
io::{Read, Write},
net::SocketAddr,
rc::Rc,
time::{Duration, Instant},
};
use mio::{Interest, Registry, Token, net::TcpStream};
use sozu_command::{
proto::command::{Event, EventKind, HealthCheckConfig},
state::ClusterId,
};
use crate::metrics::names;
use crate::{
backends::BackendMap,
protocol::mux::{
parser::{
FLAG_END_HEADERS, FLAG_PADDED, FLAG_PRIORITY, FRAME_HEADER_SIZE, FrameType,
frame_header,
},
serializer::H2_PRI,
},
server::push_event,
};
/// Expands to the `&'static str` prefix shared by every health-check log line.
///
/// With no argument it yields `"HEALTH-CHECK"`; given a literal cluster id it
/// yields `"HEALTH-CHECK cluster=<id>"` (the pieces are joined by `concat!`,
/// so the argument must be a literal).
macro_rules! log_context {
    () => ("HEALTH-CHECK");
    ($cluster:expr) => (concat!("HEALTH-CHECK cluster=", $cluster));
}
/// First mio token value handed out to health-check probe sockets (`1 << 24`).
const HEALTH_CHECK_TOKEN_BASE: usize = 0x0100_0000;
/// Number of consecutive token slots, starting at `HEALTH_CHECK_TOKEN_BASE`,
/// that concurrently in-flight probes may occupy (`1 << 16`).
const HEALTH_CHECK_TOKEN_CAPACITY: usize = 0x0001_0000;
/// Probe work gathered while the backend map is borrowed: one entry per due
/// cluster with its id, its probe configuration, whether the cluster speaks
/// h2c, and the `(backend_id, address)` pairs to probe.
type PendingChecks = Vec<(ClusterId, HealthCheckConfig, bool, Vec<(String, SocketAddr)>)>;
/// State of one outstanding health-check probe against a single backend.
#[derive(Debug)]
struct InFlightCheck {
/// Non-blocking connection to the backend, registered with mio under `token`.
stream: TcpStream,
/// Token allocated from the health-check token range for `stream`.
token: Token,
/// Cluster the probed backend belongs to.
cluster_id: ClusterId,
/// Identifier of the probed backend (used in logs and metrics).
backend_id: String,
/// Backend address; also the key used to find the backend when the result is recorded.
address: SocketAddr,
/// When the probe was initiated; compared against `timeout`.
started_at: Instant,
/// How long the probe may stay in flight before it counts as a failure.
timeout: Duration,
/// Request bytes still to send; `None` once the whole request has been written.
request_bytes: Option<Vec<u8>>,
/// How many bytes of `request_bytes` have already been written.
write_offset: usize,
/// Response bytes received so far (capped at 4 KiB while reading).
response_buf: Vec<u8>,
/// Probe configuration of the cluster (URI, timeout, thresholds, expected status).
config: HealthCheckConfig,
/// `true` when the probe speaks prior-knowledge h2c instead of HTTP/1.1.
h2c: bool,
}
/// Active HTTP/1.1 and h2c health checker for cluster backends.
///
/// Readiness for tokens it owns (see `owns_token`) is fed in through `ready`;
/// `poll` starts probes for due clusters and advances in-flight ones.
#[derive(Debug)]
pub struct HealthChecker {
/// Probes whose socket is open and that have not reached a verdict yet.
in_flight: Vec<InFlightCheck>,
/// When each cluster's latest probe round was started; drives the interval.
last_check_time: HashMap<ClusterId, Instant>,
/// Round-robin cursor for token allocation (reduced modulo the slot capacity).
next_token_id: usize,
/// Tokens reported ready since the last pass over the in-flight probes.
ready_tokens: HashSet<Token>,
}
impl Default for HealthChecker {
fn default() -> Self {
Self::new()
}
}
impl HealthChecker {
    /// Creates a checker with no probe in flight and no per-cluster history.
    pub fn new() -> Self {
        HealthChecker {
            in_flight: Vec::new(),
            last_check_time: HashMap::new(),
            next_token_id: 0,
            ready_tokens: HashSet::new(),
        }
    }

    /// Picks a free mio token for a new probe socket.
    ///
    /// Slot offsets in `0..HEALTH_CHECK_TOKEN_CAPACITY` (added to
    /// `HEALTH_CHECK_TOKEN_BASE`) are scanned round-robin from `next_token_id`,
    /// skipping offsets held by an in-flight probe, so a token released by a
    /// finished probe is not handed out again right away.
    ///
    /// Returns `None`, after logging an error, when every slot is occupied.
    fn allocate_token(&mut self) -> Option<Token> {
        let in_flight: HashSet<usize> = self
            .in_flight
            .iter()
            .map(|c| c.token.0 - HEALTH_CHECK_TOKEN_BASE)
            .collect();
        debug_assert!(
            in_flight.iter().all(|&o| o < HEALTH_CHECK_TOKEN_CAPACITY),
            "every in-flight token offset must fall within the slot capacity"
        );
        debug_assert!(
            in_flight.len() <= HEALTH_CHECK_TOKEN_CAPACITY,
            "cannot have more in-flight checks than the token slot capacity"
        );
        for _ in 0..HEALTH_CHECK_TOKEN_CAPACITY {
            let offset = self.next_token_id % HEALTH_CHECK_TOKEN_CAPACITY;
            self.next_token_id = self.next_token_id.wrapping_add(1);
            if !in_flight.contains(&offset) {
                let token = Token(HEALTH_CHECK_TOKEN_BASE + offset);
                debug_assert!(
                    self.owns_token(token),
                    "allocated token must fall inside the health-check namespace"
                );
                return Some(token);
            }
        }
        debug_assert_eq!(
            in_flight.len(),
            HEALTH_CHECK_TOKEN_CAPACITY,
            "allocation only fails when every slot is occupied"
        );
        error!(
            "{} token-table full ({} in-flight checks); refusing to allocate a new probe slot",
            log_context!(),
            in_flight.len()
        );
        None
    }

    /// Returns `true` when `token` lies in the range this checker allocates
    /// probe tokens from (`HEALTH_CHECK_TOKEN_BASE` plus
    /// `HEALTH_CHECK_TOKEN_CAPACITY` slots).
    pub fn owns_token(&self, token: Token) -> bool {
        let owned = token.0 >= HEALTH_CHECK_TOKEN_BASE
            && token.0 < HEALTH_CHECK_TOKEN_BASE + HEALTH_CHECK_TOKEN_CAPACITY;
        debug_assert!(
            !owned || token.0 - HEALTH_CHECK_TOKEN_BASE < HEALTH_CHECK_TOKEN_CAPACITY,
            "an owned token must map to a valid bounded slot offset"
        );
        debug_assert!(
            owned || token != Token(HEALTH_CHECK_TOKEN_BASE),
            "the base token itself must always be classified as owned"
        );
        owned
    }

    /// Records that `token` received a readiness event; the matching probe is
    /// driven on the next `poll`.
    pub fn ready(&mut self, token: Token) {
        self.ready_tokens.insert(token);
        debug_assert!(
            self.ready_tokens.contains(&token),
            "ready() must record the token in the readiness set"
        );
    }

    /// Starts probes for due clusters, then advances in-flight probes.
    ///
    /// Returns immediately when nothing is in flight and no cluster has a
    /// health-check configuration.
    pub fn poll(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
        if self.in_flight.is_empty() && backends.borrow().health_check_configs.is_empty() {
            return;
        }
        self.initiate_checks(backends, registry);
        self.progress_checks(backends, registry);
    }

    /// Opens a probe against every eligible backend of every due cluster.
    ///
    /// A cluster is due when it was never checked, or when its interval plus a
    /// per-cluster jitter has elapsed since its last round. The jitter is
    /// derived from a hash of the cluster id and lies in
    /// `[0, max(interval_ms / 5, 1))` milliseconds. Only backends with
    /// `BackendStatus::Normal` and no probe already in flight are probed.
    ///
    /// Failures before a probe is in flight (no free token, connect error,
    /// registration error) are recorded immediately as failed checks.
    fn initiate_checks(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
        let backend_map = backends.borrow();
        let now = Instant::now();
        let mut to_check: PendingChecks = Vec::new();
        for (cluster_id, config) in &backend_map.health_check_configs {
            let interval = Duration::from_secs(u64::from(config.interval));
            let mut hasher = std::collections::hash_map::DefaultHasher::new();
            cluster_id.hash(&mut hasher);
            let jitter_ms = hasher.finish() % (interval.as_millis() as u64 / 5).max(1);
            let jittered_interval = interval + Duration::from_millis(jitter_ms);
            let should_check = match self.last_check_time.get(cluster_id) {
                Some(last) => now.duration_since(*last) >= jittered_interval,
                None => true,
            };
            if !should_check {
                continue;
            }
            if let Some(backend_list) = backend_map.backends.get(cluster_id) {
                let backends_to_check: Vec<(String, SocketAddr)> = backend_list
                    .backends
                    .iter()
                    .filter(|b| {
                        let b = b.borrow();
                        b.status == crate::backends::BackendStatus::Normal
                            && !self.in_flight.iter().any(|f| {
                                f.cluster_id == *cluster_id && f.backend_id == b.backend_id
                            })
                    })
                    .map(|b| {
                        let b = b.borrow();
                        (b.backend_id.to_owned(), b.address)
                    })
                    .collect();
                if !backends_to_check.is_empty() {
                    let h2c = backend_map
                        .cluster_http2
                        .get(cluster_id)
                        .copied()
                        .unwrap_or(false);
                    to_check.push((
                        cluster_id.to_owned(),
                        config.to_owned(),
                        h2c,
                        backends_to_check,
                    ));
                }
            }
        }
        // Release the shared borrow: `record_check_result` needs `borrow_mut`.
        drop(backend_map);
        for (cluster_id, config, h2c, backends_to_check) in to_check {
            self.last_check_time.insert(cluster_id.to_owned(), now);
            let probe_uri = config.uri.as_str();
            for (backend_id, address) in backends_to_check {
                // Reserve the token before dialing so an exhausted token table
                // does not open (and immediately drop) a connection to the backend.
                let Some(token) = self.allocate_token() else {
                    Self::record_check_result(
                        backends,
                        &cluster_id,
                        &backend_id,
                        address,
                        false,
                        &config,
                    );
                    continue;
                };
                let mut stream = match TcpStream::connect(address) {
                    Ok(stream) => stream,
                    Err(e) => {
                        debug!(
                            "{} failed to connect to {} ({}) for cluster {}: {}",
                            log_context!(),
                            backend_id,
                            address,
                            cluster_id,
                            e
                        );
                        Self::record_check_result(
                            backends,
                            &cluster_id,
                            &backend_id,
                            address,
                            false,
                            &config,
                        );
                        continue;
                    }
                };
                if let Err(e) = registry.register(
                    &mut stream,
                    token,
                    Interest::READABLE | Interest::WRITABLE,
                ) {
                    debug!(
                        "{} failed to register socket for {} ({}) in cluster {}: {}",
                        log_context!(),
                        backend_id,
                        address,
                        cluster_id,
                        e
                    );
                    Self::record_check_result(
                        backends,
                        &cluster_id,
                        &backend_id,
                        address,
                        false,
                        &config,
                    );
                    continue;
                }
                trace!(
                    "{} initiated connection to {} ({}) for cluster {}",
                    log_context!(),
                    backend_id,
                    address,
                    cluster_id
                );
                let request_bytes = if h2c {
                    build_h2c_probe_bytes(probe_uri, address)
                } else {
                    format!(
                        "GET {probe_uri} HTTP/1.1\r\nHost: {address}\r\nConnection: close\r\n\r\n"
                    )
                    .into_bytes()
                };
                self.in_flight.push(InFlightCheck {
                    stream,
                    token,
                    cluster_id: cluster_id.to_owned(),
                    backend_id,
                    address,
                    started_at: now,
                    timeout: Duration::from_secs(u64::from(config.timeout)),
                    request_bytes: Some(request_bytes),
                    write_offset: 0,
                    response_buf: Vec::with_capacity(256),
                    config: config.to_owned(),
                    h2c,
                });
            }
        }
    }

    /// Advances every in-flight probe and finalizes those that have a verdict.
    ///
    /// A probe past its timeout fails whether or not its socket is ready. A
    /// probe whose token was reported through `ready` since the previous pass
    /// is driven by `drive_check`. Finished probes are removed from
    /// `in_flight`, deregistered from `registry` (best effort) and their
    /// outcome is passed to `record_check_result`.
    fn progress_checks(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
        let now = Instant::now();
        let mut completed = Vec::new();
        let ready = std::mem::take(&mut self.ready_tokens);
        debug_assert!(
            self.ready_tokens.is_empty(),
            "readiness set must be drained before processing in-flight checks"
        );
        let in_flight_before = self.in_flight.len();
        for (idx, check) in self.in_flight.iter_mut().enumerate() {
            debug_assert!(
                idx < in_flight_before,
                "in-flight index ({idx}) must be within the live slot range ({in_flight_before})"
            );
            debug_assert!(
                check
                    .request_bytes
                    .as_ref()
                    .is_none_or(|r| check.write_offset <= r.len()),
                "write_offset must never exceed the request length"
            );
            if now.duration_since(check.started_at) > check.timeout {
                debug!(
                    "{} timeout for {} ({}) in cluster {}",
                    log_context!(),
                    check.backend_id,
                    check.address,
                    check.cluster_id
                );
                completed.push((idx, false));
                continue;
            }
            if !ready.contains(&check.token) {
                continue;
            }
            if let Some(success) = Self::drive_check(check) {
                completed.push((idx, success));
            }
        }
        // `swap_remove` moves the last element into the freed slot, so indices
        // must be processed from highest to lowest to stay valid.
        completed.sort_by(|a, b| b.0.cmp(&a.0));
        debug_assert!(
            completed.len() <= in_flight_before,
            "cannot complete more checks ({}) than were in flight ({in_flight_before})",
            completed.len()
        );
        debug_assert!(
            completed.windows(2).all(|w| w[0].0 > w[1].0),
            "completed indices must be strictly descending and unique for swap_remove safety"
        );
        for (idx, success) in completed {
            let len_before = self.in_flight.len();
            let mut check = self.in_flight.swap_remove(idx);
            debug_assert_eq!(
                self.in_flight.len(),
                len_before - 1,
                "swap_remove must drop exactly one in-flight check"
            );
            let _ = registry.deregister(&mut check.stream);
            Self::record_check_result(
                backends,
                &check.cluster_id,
                &check.backend_id,
                check.address,
                success,
                &check.config,
            );
        }
    }

    /// Pushes one ready probe as far as its socket allows without blocking.
    ///
    /// mio reports readiness edge-triggered, so each phase is repeated until
    /// the socket returns `WouldBlock`: first the remaining request bytes are
    /// written, then the response is read and re-parsed after every chunk.
    /// Stopping after a single read could leave buffered bytes (or the peer's
    /// EOF) unobserved with no further readiness event ever arriving, stalling
    /// the probe until its timeout.
    ///
    /// Returns `Some(success)` once the probe has a verdict, and `None` while
    /// it must wait for the next readiness event.
    fn drive_check(check: &mut InFlightCheck) -> Option<bool> {
        // Responses that grow past this size without a verdict fail the probe.
        const MAX_HEALTH_RESPONSE_SIZE: usize = 4096;

        if let Some(request_bytes) = check.request_bytes.as_deref() {
            while check.write_offset < request_bytes.len() {
                match check.stream.write(&request_bytes[check.write_offset..]) {
                    // A zero-length write of a non-empty buffer means the socket
                    // cannot take the request (`write_all` reports this as
                    // `WriteZero`); retrying would spin forever.
                    Ok(0) => return Some(false),
                    Ok(n) => check.write_offset += n,
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => return None,
                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(_e) => return Some(false),
                }
            }
            check.request_bytes = None;
        }

        let mut chunk = [0u8; 256];
        loop {
            match check.stream.read(&mut chunk) {
                Ok(0) => {
                    // Peer closed: judge what arrived; an incomplete response fails.
                    return Some(
                        parse_probe_response(&check.response_buf, &check.config, check.h2c)
                            .unwrap_or(false),
                    );
                }
                Ok(n) => {
                    debug_assert!(
                        n <= chunk.len(),
                        "read reported {n} bytes into a {}-byte buffer",
                        chunk.len()
                    );
                    if check.response_buf.len() + n > MAX_HEALTH_RESPONSE_SIZE {
                        return Some(false);
                    }
                    check.response_buf.extend_from_slice(&chunk[..n]);
                    debug_assert!(
                        check.response_buf.len() <= MAX_HEALTH_RESPONSE_SIZE,
                        "response buffer must stay within the max health response size"
                    );
                    if let Some(success) =
                        parse_probe_response(&check.response_buf, &check.config, check.h2c)
                    {
                        return Some(success);
                    }
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => return None,
                Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(_e) => return Some(false),
            }
        }
    }

    /// Applies one probe outcome to the backend at `address` in `cluster_id`.
    ///
    /// The outcome feeds the backend's health state machine
    /// (`record_success` / `record_failure` with the configured thresholds).
    /// When that flips the backend UP or DOWN, the transition is logged,
    /// counted, reflected in the backend availability gauge and published as
    /// an event. The cluster's healthy-backend gauge is then refreshed, a
    /// warning is logged when at most half (but at least one) of its backends
    /// are healthy, and `record_cluster_availability` is invoked.
    ///
    /// Nothing happens if the cluster, or a backend with that address, is no
    /// longer present in the backend map.
    fn record_check_result(
        backends: &Rc<RefCell<BackendMap>>,
        cluster_id: &str,
        backend_id: &str,
        address: SocketAddr,
        success: bool,
        config: &HealthCheckConfig,
    ) {
        let mut backend_map = backends.borrow_mut();
        let Some(backend_list) = backend_map.backends.get_mut(cluster_id) else {
            return;
        };
        let Some(backend_ref) = backend_list.find_backend(&address) else {
            return;
        };
        let mut backend = backend_ref.borrow_mut();
        if success {
            let was_healthy = backend.health.is_healthy();
            let transitioned = backend.health.record_success(config.healthy_threshold);
            debug_assert!(
                backend.health.consecutive_failures == 0,
                "a recorded success must zero the consecutive-failure counter"
            );
            debug_assert_eq!(
                transitioned,
                !was_healthy && backend.health.is_healthy(),
                "transition flag must be set iff the backend just flipped to healthy"
            );
            debug_assert!(
                !transitioned || backend.health.consecutive_successes >= config.healthy_threshold,
                "an UP transition only fires once the rise counter reaches the healthy threshold"
            );
            debug_assert!(
                !transitioned || backend.health.is_healthy(),
                "after an UP transition the backend must report healthy"
            );
            if transitioned {
                info!(
                    "{} backend {} at {} marked UP (health check passed {} consecutive times) for cluster {}",
                    log_context!(),
                    backend_id,
                    address,
                    config.healthy_threshold,
                    cluster_id
                );
                incr!(names::health_check::UP);
                gauge!(
                    names::backend::AVAILABLE,
                    1,
                    Some(cluster_id),
                    Some(backend_id)
                );
                push_event(Event {
                    kind: EventKind::HealthCheckHealthy as i32,
                    cluster_id: Some(cluster_id.to_owned()),
                    backend_id: Some(backend_id.to_owned()),
                    address: Some(address.into()),
                    metric_detail: None,
                });
            }
            count!(names::health_check::SUCCESS, 1);
        } else {
            let was_healthy = backend.health.is_healthy();
            let transitioned = backend.health.record_failure(config.unhealthy_threshold);
            debug_assert!(
                backend.health.consecutive_successes == 0,
                "a recorded failure must zero the consecutive-success counter"
            );
            debug_assert_eq!(
                transitioned,
                was_healthy && !backend.health.is_healthy(),
                "transition flag must be set iff the backend just flipped to unhealthy"
            );
            debug_assert!(
                !transitioned || backend.health.consecutive_failures >= config.unhealthy_threshold,
                "a DOWN transition only fires once the fall counter reaches the unhealthy threshold"
            );
            debug_assert!(
                !transitioned || !backend.health.is_healthy(),
                "after a DOWN transition the backend must report unhealthy"
            );
            if transitioned {
                warn!(
                    "{} backend {} at {} marked DOWN (health check failed {} consecutive times) for cluster {}",
                    log_context!(),
                    backend_id,
                    address,
                    config.unhealthy_threshold,
                    cluster_id
                );
                incr!(names::health_check::DOWN);
                gauge!(
                    names::backend::AVAILABLE,
                    0,
                    Some(cluster_id),
                    Some(backend_id)
                );
                push_event(Event {
                    kind: EventKind::HealthCheckUnhealthy as i32,
                    cluster_id: Some(cluster_id.to_owned()),
                    backend_id: Some(backend_id.to_owned()),
                    address: Some(address.into()),
                    metric_detail: None,
                });
            }
            count!(names::health_check::FAILURE, 1);
        }
        // Release the backend before re-borrowing every backend of the list below.
        drop(backend);
        let total = backend_list.backends.len();
        let healthy = backend_list
            .backends
            .iter()
            .filter(|b| b.borrow().health.is_healthy())
            .count();
        debug_assert!(
            healthy <= total,
            "healthy backend count ({healthy}) must not exceed total ({total})"
        );
        if total > 0 {
            gauge!(
                "health_check.healthy_backends",
                healthy,
                Some(cluster_id),
                None
            );
            if healthy > 0 && healthy * 2 <= total {
                warn!(
                    "{} cluster {} has only {}/{} healthy backends",
                    log_context!(),
                    cluster_id,
                    healthy,
                    total
                );
            }
        }
        backend_map.record_cluster_availability(cluster_id);
    }

    /// Forgets a cluster: drops its last-check timestamp, every in-flight probe
    /// targeting it, and any pending readiness recorded for those probes.
    ///
    /// NOTE(review): the dropped probe sockets are not explicitly deregistered
    /// (no `Registry` is available here); this relies on closing the socket
    /// removing it from the poller — confirm for every supported platform.
    pub fn remove_cluster(&mut self, cluster_id: &str) {
        self.last_check_time.remove(cluster_id);
        let ready_tokens = &mut self.ready_tokens;
        self.in_flight.retain(|check| {
            let keep = check.cluster_id != cluster_id;
            if !keep {
                // The token may be reallocated; a stale readiness entry must not
                // be attributed to whichever probe receives it next.
                ready_tokens.remove(&check.token);
            }
            keep
        });
        debug_assert!(
            self.in_flight.iter().all(|c| c.cluster_id != cluster_id),
            "remove_cluster must drop every in-flight check for the cluster"
        );
        debug_assert!(
            !self.last_check_time.contains_key(cluster_id),
            "remove_cluster must forget the cluster's last-check timestamp"
        );
    }
}
fn parse_probe_response(buf: &[u8], config: &HealthCheckConfig, h2c: bool) -> Option<bool> {
if h2c {
try_parse_h2c_status(buf, config)
} else {
try_parse_status_line(buf, config)
}
}
/// Extracts the status code from a buffered HTTP/1.x response and judges it.
///
/// Only the status line (`HTTP-version SP status-code [SP reason]`) is
/// inspected. It is located by a byte search, so the rest of the response
/// (headers or body) need not be valid UTF-8. A read boundary that splits a
/// multi-byte character later in the buffer therefore cannot stall the probe.
/// Both `\r\n` and a bare `\n` terminate the line, as RFC 9112 §2.2 allows.
///
/// Returns `None` while the status line is still incomplete and
/// `Some(healthy)` once it has arrived. A complete but malformed status line
/// (not UTF-8, no status code, non-numeric code) yields `Some(false)` rather
/// than waiting for bytes that cannot fix it.
fn try_parse_status_line(buf: &[u8], config: &HealthCheckConfig) -> Option<bool> {
    let line_end = buf.iter().position(|&byte| byte == b'\n')?;
    debug_assert!(
        line_end < buf.len(),
        "status line must be a strict prefix ending before the line terminator"
    );
    let raw_line = &buf[..line_end];
    let status_line = raw_line.strip_suffix(b"\r").unwrap_or(raw_line);
    // Any parse failure maps to 0, which `is_status_healthy` never accepts.
    let status_code = std::str::from_utf8(status_line)
        .ok()
        .and_then(|line| line.split_once(' '))
        .and_then(|(_version, rest)| rest.split(' ').next())
        .and_then(|code| code.parse::<u32>().ok())
        .unwrap_or(0);
    Some(is_status_healthy(status_code, config.expected_status))
}
/// Judges a response status code against the configured expectation.
///
/// An `expected` of `0` accepts any 2xx code; any other value must match
/// `actual` exactly.
fn is_status_healthy(actual: u32, expected: u32) -> bool {
    let healthy = match expected {
        0 => (200..=299).contains(&actual),
        wanted => actual == wanted,
    };
    debug_assert!(
        expected == 0 || healthy == (actual == expected),
        "with a specific expected status, health must be exact equality"
    );
    healthy
}
fn build_h2c_probe_bytes(uri: &str, address: SocketAddr) -> Vec<u8> {
let authority = address.to_string();
let mut encoder = loona_hpack::Encoder::new();
let mut hpack: Vec<u8> = Vec::new();
let headers: [(&[u8], &[u8]); 4] = [
(b":method", b"GET"),
(b":scheme", b"http"),
(b":path", uri.as_bytes()),
(b":authority", authority.as_bytes()),
];
if encoder.encode_into(headers, &mut hpack).is_err() {
return Vec::new();
}
let mut out = Vec::with_capacity(H2_PRI.len() + FRAME_HEADER_SIZE * 2 + hpack.len());
out.extend_from_slice(H2_PRI.as_bytes());
out.extend_from_slice(&[0, 0, 0, 0x04, 0, 0, 0, 0, 0]);
let len = hpack.len() as u32;
out.push(((len >> 16) & 0xFF) as u8);
out.push(((len >> 8) & 0xFF) as u8);
out.push((len & 0xFF) as u8);
out.push(0x01); out.push(0x05); out.extend_from_slice(&[0, 0, 0, 1]); out.extend_from_slice(&hpack);
debug_assert!(
out.starts_with(H2_PRI.as_bytes()),
"an h2c probe must begin with the connection preface"
);
debug_assert_eq!(
out.len(),
H2_PRI.len() + FRAME_HEADER_SIZE * 2 + hpack.len(),
"probe length must be preface + SETTINGS + HEADERS header + HPACK block"
);
out
}
/// Scans a buffered h2c response for the `:status` of stream 1 and judges it.
///
/// Frames are consumed in order:
/// - A header block (a HEADERS frame on stream 1 plus any CONTINUATION frames,
///   up to END_HEADERS) is HPACK-decoded by `decode_status_from_block`.
/// - GOAWAY fails the probe.
/// - Other frames (SETTINGS, WINDOW_UPDATE, other streams, ...) are skipped.
///
/// Returns:
/// - `Some(healthy)` once a verdict is reached;
/// - `Some(false)` for responses that can never become valid: a frame header
///   rejected by `frame_header`, or a complete HEADERS frame whose
///   padding/priority fields do not fit. It is also returned when an open
///   header block is followed by anything but a CONTINUATION on stream 1,
///   which is a connection error per RFC 9113 §6.10;
/// - `None` when the buffer ends mid-frame or before a verdict (keep reading).
fn try_parse_h2c_status(buf: &[u8], config: &HealthCheckConfig) -> Option<bool> {
    const MAX_FRAME_SIZE: u32 = (1 << 24) - 1;
    let mut remaining: &[u8] = buf;
    // Header-block fragments of stream 1 collected while waiting for END_HEADERS.
    let mut headers_block: Option<Vec<u8>> = None;
    while !remaining.is_empty() {
        if remaining.len() < FRAME_HEADER_SIZE {
            return None;
        }
        let consumable = remaining.len();
        let (rest, header) = match frame_header(remaining, MAX_FRAME_SIZE) {
            Ok(parsed) => parsed,
            Err(_) => return Some(false),
        };
        debug_assert!(
            rest.len() < consumable,
            "frame_header must consume at least the fixed frame header"
        );
        debug_assert_eq!(
            consumable - rest.len(),
            FRAME_HEADER_SIZE,
            "frame_header must consume exactly the fixed-size frame header"
        );
        debug_assert!(
            header.payload_len <= MAX_FRAME_SIZE,
            "frame_header must enforce the max-frame-size bound it was given"
        );
        let payload_len = header.payload_len as usize;
        if rest.len() < payload_len {
            return None;
        }
        let (payload, after) = rest.split_at(payload_len);
        debug_assert_eq!(
            payload.len(),
            payload_len,
            "payload split must yield exactly the declared payload length"
        );
        debug_assert_eq!(
            payload.len() + after.len(),
            rest.len(),
            "payload + remainder must equal the pre-split buffer"
        );
        debug_assert!(
            after.len() < remaining.len(),
            "each iteration must shrink the remaining buffer to guarantee termination"
        );
        let continues_open_block =
            matches!(header.frame_type, FrameType::Continuation) && header.stream_id == 1;
        if headers_block.is_some() && !continues_open_block {
            return Some(false);
        }
        match header.frame_type {
            FrameType::Headers if header.stream_id == 1 => {
                // The frame is complete, so a padding/priority mismatch is a
                // malformed frame; more bytes could never fix it.
                let Some(block) = strip_padded_priority(payload, header.flags) else {
                    return Some(false);
                };
                if header.flags & FLAG_END_HEADERS != 0 {
                    return Some(decode_status_from_block(block, config));
                }
                headers_block = Some(block.to_vec());
            }
            FrameType::Continuation if header.stream_id == 1 => {
                let Some(mut accumulator) = headers_block.take() else {
                    return Some(false);
                };
                accumulator.extend_from_slice(payload);
                if header.flags & FLAG_END_HEADERS != 0 {
                    return Some(decode_status_from_block(&accumulator, config));
                }
                headers_block = Some(accumulator);
            }
            FrameType::GoAway => return Some(false),
            _ => {}
        }
        remaining = after;
    }
    None
}
/// Returns the header-block fragment carried by a HEADERS frame payload.
///
/// With `FLAG_PADDED`, the leading Pad Length octet and the trailing padding
/// are removed. With `FLAG_PRIORITY`, the 5 priority bytes (stream dependency
/// plus weight) in front of the fragment are skipped.
///
/// Returns `None` when the payload is too short for the fields its flags
/// announce, including padding longer than what follows the Pad Length octet.
fn strip_padded_priority(payload: &[u8], flags: u8) -> Option<&[u8]> {
    let (mut start, mut end) = (0usize, payload.len());
    if flags & FLAG_PADDED != 0 {
        let pad = usize::from(*payload.first()?);
        start = 1;
        // `first()` succeeded, so `end >= 1 == start` and this cannot underflow.
        if pad > end - start {
            return None;
        }
        end -= pad;
    }
    if flags & FLAG_PRIORITY != 0 {
        start = start
            .checked_add(5)
            .filter(|&after_priority| after_priority <= end)?;
    }
    debug_assert!(
        start <= end && end <= payload.len(),
        "stripped header window [{start}, {end}) must lie within the payload ({})",
        payload.len()
    );
    let block = payload.get(start..end)?;
    debug_assert!(
        block.len() <= payload.len(),
        "stripped block must never be larger than the original payload"
    );
    Some(block)
}
/// HPACK-decodes a complete response header block and judges its `:status`.
///
/// The first `:status` header whose value parses as a `u32` is the one used.
/// Returns `false` when the block fails to decode or carries no usable
/// `:status`.
fn decode_status_from_block(block: &[u8], config: &HealthCheckConfig) -> bool {
    let mut decoder = loona_hpack::Decoder::new();
    let mut status_code: Option<u32> = None;
    let outcome = decoder.decode_with_cb(block, |name, value| {
        if status_code.is_none() && name.as_ref() == b":status" {
            // An unparsable value leaves `status_code` unset, so a later
            // `:status` may still be picked up.
            status_code = std::str::from_utf8(value.as_ref())
                .ok()
                .and_then(|text| text.parse::<u32>().ok());
        }
    });
    if outcome.is_err() {
        return false;
    }
    status_code.is_some_and(|code| is_status_healthy(code, config.expected_status))
}
#[cfg(test)]
mod tests {
    //! Unit tests for status evaluation, the backend health-state machine,
    //! and the h2c probe encoder / response parser.

    use super::*;
    use crate::backends::HealthState;

    /// Probe configuration shared by the tests; only `expected_status` varies.
    fn h2c_config(expected: u32) -> HealthCheckConfig {
        HealthCheckConfig {
            uri: "/health".to_owned(),
            interval: 10,
            timeout: 5,
            healthy_threshold: 3,
            unhealthy_threshold: 3,
            expected_status: expected,
        }
    }

    /// Serializes a raw HTTP/2 frame: 24-bit big-endian length, type, flags,
    /// 32-bit stream id, then the payload.
    fn frame_with_header(frame_type: u8, flags: u8, sid: u32, payload: &[u8]) -> Vec<u8> {
        let len = payload.len();
        let mut frame = Vec::with_capacity(FRAME_HEADER_SIZE + len);
        frame.extend_from_slice(&[
            ((len >> 16) & 0xFF) as u8,
            ((len >> 8) & 0xFF) as u8,
            (len & 0xFF) as u8,
            frame_type,
            flags,
        ]);
        frame.extend_from_slice(&sid.to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// HPACK-encodes a response header list with a fresh encoder.
    fn encode_response_headers(headers: &[(&[u8], &[u8])]) -> Vec<u8> {
        let mut encoder = loona_hpack::Encoder::new();
        let mut block = Vec::new();
        encoder
            .encode_into(headers.iter().copied(), &mut block)
            .unwrap();
        block
    }

    #[test]
    fn test_is_status_healthy_any_2xx() {
        for code in [200, 204, 299] {
            assert!(is_status_healthy(code, 0));
        }
        for code in [301, 500, 0] {
            assert!(!is_status_healthy(code, 0));
        }
    }

    #[test]
    fn test_is_status_healthy_specific() {
        assert!(is_status_healthy(200, 200));
        for code in [204, 500] {
            assert!(!is_status_healthy(code, 200));
        }
    }

    #[test]
    fn test_try_parse_status_line() {
        let config = h2c_config(0);
        let cases = [
            (&b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"[..], Some(true)),
            (&b"HTTP/1.1 500 Internal Server Error\r\n\r\n"[..], Some(false)),
            (&b"HTTP/1.1 200"[..], None),
        ];
        for (response, expected) in cases {
            assert_eq!(try_parse_status_line(response, &config), expected);
        }
    }

    #[test]
    fn test_health_state_transitions() {
        let mut state = HealthState::default();
        assert!(state.is_healthy());

        // Two failures stay below the threshold of 3; the third flips to DOWN.
        for _ in 0..2 {
            assert!(!state.record_failure(3));
        }
        assert!(state.is_healthy());
        assert!(state.record_failure(3));
        assert!(!state.is_healthy());

        // Recovery mirrors it: the third consecutive success flips back to UP.
        for _ in 0..2 {
            assert!(!state.record_success(3));
        }
        assert!(!state.is_healthy());
        assert!(state.record_success(3));
        assert!(state.is_healthy());
    }

    #[test]
    fn build_h2c_probe_starts_with_preface_and_frames() {
        let bytes = build_h2c_probe_bytes("/health", "127.0.0.1:8080".parse().unwrap());
        assert!(bytes.starts_with(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"));

        // Empty SETTINGS frame on stream 0 right after the 24-byte preface.
        let settings_start = 24;
        let settings = &bytes[settings_start..settings_start + 9];
        assert_eq!(&settings[..3], &[0u8, 0, 0]);
        assert_eq!(settings[3], 0x04);
        assert_eq!(settings[4], 0);
        assert_eq!(&settings[5..], &[0u8, 0, 0, 0]);

        // HEADERS on stream 1 with END_STREAM | END_HEADERS.
        let headers_start = settings_start + 9;
        let headers = &bytes[headers_start..headers_start + 9];
        assert_eq!(headers[3], 0x01);
        assert_eq!(headers[4], 0x05);
        assert_eq!(&headers[5..], &[0u8, 0, 0, 1]);

        let payload_start = headers_start + 9;
        let (mut method, mut scheme, mut path, mut authority) = (None, None, None, None);
        let mut decoder = loona_hpack::Decoder::new();
        decoder
            .decode_with_cb(&bytes[payload_start..], |name, value| match name.as_ref() {
                b":method" => method = Some(value.to_vec()),
                b":scheme" => scheme = Some(value.to_vec()),
                b":path" => path = Some(value.to_vec()),
                b":authority" => authority = Some(value.to_vec()),
                _ => {}
            })
            .expect("loona_hpack decodes a freshly-encoded probe");
        assert_eq!(method.as_deref(), Some(b"GET" as &[u8]));
        assert_eq!(scheme.as_deref(), Some(b"http" as &[u8]));
        assert_eq!(path.as_deref(), Some(b"/health" as &[u8]));
        assert_eq!(authority.as_deref(), Some(b"127.0.0.1:8080" as &[u8]));
    }

    #[test]
    fn h2c_response_with_status_200_decodes_healthy() {
        let block = encode_response_headers(&[(b":status", b"200")]);
        let response = frame_with_header(0x01, FLAG_END_HEADERS, 1, &block);
        assert_eq!(try_parse_h2c_status(&response, &h2c_config(0)), Some(true));
    }

    #[test]
    fn h2c_response_with_status_500_fails_default_2xx_check() {
        let block = encode_response_headers(&[(b":status", b"500")]);
        let response = frame_with_header(0x01, FLAG_END_HEADERS, 1, &block);
        assert_eq!(try_parse_h2c_status(&response, &h2c_config(0)), Some(false));
    }

    #[test]
    fn h2c_response_with_status_503_matches_expected_503() {
        let block =
            encode_response_headers(&[(b":status", b"503"), (b"content-type", b"text/plain")]);
        let response = frame_with_header(0x01, FLAG_END_HEADERS, 1, &block);
        assert_eq!(try_parse_h2c_status(&response, &h2c_config(503)), Some(true));
    }

    #[test]
    fn h2c_response_with_continuation_decodes_status_200_healthy() {
        let block = encode_response_headers(&[
            (b":status", b"200"),
            (b"x-trace-id", b"abc-123"),
            (b"server", b"sozu-test"),
        ]);
        assert!(block.len() >= 4, "HPACK block needs to be splittable");
        let (head, tail) = block.split_at(block.len() / 2);
        // HEADERS without END_HEADERS, completed by a CONTINUATION frame.
        let mut response = frame_with_header(0x01, 0, 1, head);
        response.extend(frame_with_header(0x09, FLAG_END_HEADERS, 1, tail));
        assert_eq!(try_parse_h2c_status(&response, &h2c_config(0)), Some(true));
    }

    #[test]
    fn h2c_response_with_padded_priority_headers_decodes_status_200() {
        let block = encode_response_headers(&[(b":status", b"200")]);
        let pad_len: u8 = 3;
        let mut payload = vec![pad_len];
        // Priority fields: stream dependency 0, weight 16.
        payload.extend_from_slice(&[0u8, 0, 0, 0, 16]);
        payload.extend_from_slice(&block);
        payload.resize(payload.len() + usize::from(pad_len), 0);
        let flags = FLAG_PADDED | FLAG_PRIORITY | FLAG_END_HEADERS;
        let response = frame_with_header(0x01, flags, 1, &payload);
        assert_eq!(try_parse_h2c_status(&response, &h2c_config(0)), Some(true));
    }

    #[test]
    fn h2c_response_after_unrelated_settings_frame_decodes_healthy() {
        // Server SETTINGS followed by a SETTINGS ACK, then the response headers.
        let mut response = frame_with_header(0x04, 0, 0, &[]);
        response.extend(frame_with_header(0x04, 0x01, 0, &[]));
        let block = encode_response_headers(&[(b":status", b"200")]);
        response.extend(frame_with_header(0x01, FLAG_END_HEADERS, 1, &block));
        assert_eq!(try_parse_h2c_status(&response, &h2c_config(0)), Some(true));
    }

    #[test]
    fn h2c_goaway_returns_unhealthy() {
        let response = frame_with_header(0x07, 0, 0, &[0u8; 8]);
        assert_eq!(try_parse_h2c_status(&response, &h2c_config(0)), Some(false));
    }

    #[test]
    fn h2c_truncated_frame_returns_none() {
        // The header promises a 10-byte payload, but only 5 bytes have arrived.
        let mut response = frame_with_header(0x01, FLAG_END_HEADERS, 1, &[0u8; 10]);
        response.truncate(response.len() - 5);
        assert_eq!(try_parse_h2c_status(&response, &h2c_config(0)), None);
    }

    #[test]
    fn h2c_partial_frame_header_returns_none() {
        let cfg = h2c_config(0);
        for partial_len in 0usize..FRAME_HEADER_SIZE {
            let buf = vec![0u8; partial_len];
            assert_eq!(
                try_parse_h2c_status(&buf, &cfg),
                None,
                "partial buffer of {partial_len} byte(s) should be 'keep reading'"
            );
        }
    }

    #[test]
    fn h2c_continuation_without_preceding_headers_returns_unhealthy() {
        let block = encode_response_headers(&[(b":status", b"200")]);
        let response = frame_with_header(0x09, FLAG_END_HEADERS, 1, &block);
        assert_eq!(try_parse_h2c_status(&response, &h2c_config(0)), Some(false));
    }
}