use std::{
cell::RefCell,
collections::{HashMap, HashSet},
io::ErrorKind,
net::SocketAddr,
rc::Rc,
time::{Duration, Instant},
};
use mio::{
Interest, Registry, Token,
net::{TcpStream, UdpSocket},
};
use sozu_command::{proto::command::UdpHealthConfig, state::ClusterId};
use crate::backends::BackendMap;
use crate::metrics::names;
use crate::socket::udp_connect;
/// Prefix placed at the start of every log line emitted by the UDP health checker.
macro_rules! log_context {
    () => ("UDP-HEALTH");
}
/// First mio token value handed out to health-probe sockets (2^24 + 2^20).
const UDP_HEALTH_TOKEN_BASE: usize = (1 << 24) | (1 << 20);
/// Number of consecutive token values (starting at `UDP_HEALTH_TOKEN_BASE`) reserved for
/// probes, i.e. the maximum number of probes that can be in flight at once (2^16).
const UDP_HEALTH_TOKEN_CAPACITY: usize = 0x1_0000;
/// Health-check settings for one cluster.
///
/// Can be built from the protobuf `UdpHealthConfig` with `UdpHealthSettings::from_proto`.
#[derive(Clone, Debug)]
pub struct UdpHealthSettings {
    /// Port used for the TCP connect probe on the backend's IP; `None` probes the backend's
    /// own address and port.
    pub tcp_port: Option<u16>,
    /// Success threshold passed to the backend health state's `record_success`
    /// (successes needed before a DOWN backend is reported UP).
    pub rise: u32,
    /// Failure threshold passed to the backend health state's `record_failure`
    /// (failures needed before an UP backend is reported DOWN).
    pub fall: u32,
    /// Minimum time between two probe rounds of the same cluster.
    pub interval: Duration,
    /// How long a probe may stay in flight before it is counted as a failure.
    pub timeout: Duration,
    /// Datagram sent for the UDP probe; `None` disables the UDP probe.
    pub udp_probe_payload: Option<Vec<u8>>,
}
impl UdpHealthSettings {
pub fn from_proto(cfg: &UdpHealthConfig) -> Self {
UdpHealthSettings {
tcp_port: cfg.tcp_port.map(|p| p as u16),
rise: cfg.rise.unwrap_or(2),
fall: cfg.fall.unwrap_or(3),
interval: Duration::from_secs(u64::from(cfg.probe_interval_seconds.unwrap_or(5))),
timeout: Duration::from_secs(u64::from(cfg.probe_timeout_seconds.unwrap_or(2))),
udp_probe_payload: cfg.udp_probe_payload.clone(),
}
}
}
/// Work list assembled when starting a probe round: one entry per cluster that is due,
/// holding a copy of the cluster's settings and the `(backend_id, address)` pairs to probe.
type ProbeBatch = Vec<(ClusterId, UdpHealthSettings, Vec<(String, SocketAddr)>)>;
/// Socket owned by an in-flight probe: a pending TCP connect or a UDP socket awaiting a reply.
enum ProbeSocket {
    Tcp(TcpStream),
    Udp(UdpSocket),
}

impl ProbeSocket {
    /// Removes the socket from `registry`.
    ///
    /// Best effort: any deregistration error is ignored; every caller in this module drops
    /// the probe right after deregistering it.
    fn deregister(&mut self, registry: &Registry) {
        let _ = match self {
            Self::Tcp(stream) => registry.deregister(stream),
            Self::Udp(socket) => registry.deregister(socket),
        };
    }
}
/// Bookkeeping for one started probe that is waiting for readiness or its timeout.
struct InFlightProbe {
    /// Registered socket whose readiness resolves the probe.
    socket: ProbeSocket,
    /// mio token the socket is registered under (inside the health-probe token range).
    token: Token,
    /// Cluster the result is recorded against.
    cluster_id: ClusterId,
    /// Backend the probe targets; used to avoid starting a second probe for it.
    backend_id: String,
    /// Backend address from the backend list; used to look the backend up when recording.
    address: SocketAddr,
    /// When the probe was started; compared against `timeout`.
    started_at: Instant,
    timeout: Duration,
    /// Thresholds copied from the cluster settings when the probe was started.
    rise: u32,
    fall: u32,
}
/// Active health checker for clusters configured with `UdpHealthSettings`.
///
/// Each probe owns a socket registered with the mio `Registry` under a token from the range
/// starting at `UDP_HEALTH_TOKEN_BASE`; results are written into the backends' health state
/// held in the shared `BackendMap`.
#[derive(Default)]
pub struct UdpHealthChecker {
    /// Settings per cluster; only clusters present here get new probe rounds.
    settings: HashMap<ClusterId, UdpHealthSettings>,
    /// Probes started but not yet resolved by readiness or timeout.
    in_flight: Vec<InFlightProbe>,
    /// Start time of each cluster's most recent probe round.
    last_check: HashMap<ClusterId, Instant>,
    /// Rolling counter from which token offsets are picked.
    next_token_id: usize,
    /// Tokens reported through `ready` and not yet processed by `poll`.
    ready_tokens: HashSet<Token>,
}
impl UdpHealthChecker {
    /// Creates a checker with no cluster configured and no probe in flight.
    pub fn new() -> Self {
        Self::default()
    }

    /// Installs, replaces or removes the health-check settings of `cluster_id`.
    ///
    /// * `Some(settings)` replaces any previous settings. Probes already in flight keep the
    ///   timeout and thresholds they were started with.
    /// * `None` stops checking the cluster: its settings and last-check timestamp are
    ///   forgotten, and its in-flight probes are deregistered from `registry` and dropped,
    ///   together with any readiness already queued for their tokens.
    pub fn set_cluster(
        &mut self,
        cluster_id: &str,
        settings: Option<UdpHealthSettings>,
        registry: &Registry,
    ) {
        match settings {
            Some(s) => {
                self.settings.insert(cluster_id.to_owned(), s);
            }
            None => {
                self.settings.remove(cluster_id);
                self.last_check.remove(cluster_id);
                // Readiness queued for a dropped probe must not be applied to a later probe
                // that happens to be handed the same token.
                let ready_tokens = &mut self.ready_tokens;
                self.in_flight.retain_mut(|probe| {
                    if probe.cluster_id != cluster_id {
                        return true;
                    }
                    probe.socket.deregister(registry);
                    ready_tokens.remove(&probe.token);
                    false
                });
            }
        }
    }

    /// Stops health-checking `cluster_id`; shorthand for `set_cluster(cluster_id, None, registry)`.
    pub fn remove_cluster(&mut self, cluster_id: &str, registry: &Registry) {
        self.set_cluster(cluster_id, None, registry);
    }

    /// Returns `true` if `token` lies in the range this checker allocates probe tokens from.
    pub fn owns_token(&self, token: Token) -> bool {
        (UDP_HEALTH_TOKEN_BASE..UDP_HEALTH_TOKEN_BASE + UDP_HEALTH_TOKEN_CAPACITY)
            .contains(&token.0)
    }

    /// Queues a readiness notification for `token`; it is consumed by the next `poll`.
    pub fn ready(&mut self, token: Token) {
        self.ready_tokens.insert(token);
    }

    /// Picks a token of the health-probe range that no in-flight probe uses.
    ///
    /// Offsets are handed out round-robin from `next_token_id`, so a token that was just
    /// released is not reused immediately. Returns `None` (and logs an error) when every
    /// token of the range is in use.
    fn allocate_token(&mut self) -> Option<Token> {
        let in_flight: HashSet<usize> = self
            .in_flight
            .iter()
            .map(|p| p.token.0 - UDP_HEALTH_TOKEN_BASE)
            .collect();
        for _ in 0..UDP_HEALTH_TOKEN_CAPACITY {
            let offset = self.next_token_id % UDP_HEALTH_TOKEN_CAPACITY;
            self.next_token_id = self.next_token_id.wrapping_add(1);
            if !in_flight.contains(&offset) {
                let token = Token(UDP_HEALTH_TOKEN_BASE + offset);
                debug_assert!(
                    self.owns_token(token),
                    "allocate_token returned a token outside the health namespace"
                );
                debug_assert!(
                    !self.in_flight.iter().any(|p| p.token == token),
                    "allocate_token returned a token already in flight"
                );
                return Some(token);
            }
        }
        error!(
            "{} token table full ({} in-flight); refusing new probe slot",
            log_context!(),
            in_flight.len()
        );
        None
    }

    /// Runs one round of the checker.
    ///
    /// It first starts probes for clusters whose interval has elapsed. It then resolves
    /// in-flight probes using the readiness queued through `ready` and the probe timeouts.
    /// Results are written into `backends`.
    pub fn poll(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
        if self.settings.is_empty() && self.in_flight.is_empty() {
            // No probe can consume queued readiness; discard it so that a stale entry cannot
            // be matched against a future probe that gets the same token.
            self.ready_tokens.clear();
            return;
        }
        self.initiate(backends, registry);
        self.progress(backends, registry);
    }

    /// Starts a probe round for every cluster that is due.
    ///
    /// A cluster is due when its `interval` has elapsed since its last round, or when it has
    /// never been checked. Backends that already have a probe in flight are skipped.
    ///
    /// NOTE(review): when `udp_probe_payload` is set, a TCP connect probe is still started
    /// next to the UDP probe. It targets `tcp_port`, or the backend's own port when that is
    /// unset. Both results feed the same hysteresis; confirm this is intended for UDP-only
    /// backends.
    fn initiate(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
        let now = Instant::now();
        // Snapshot the work first: spawning probes may record results through
        // `backends.borrow_mut()`, so the shared borrow of the map must be released before.
        let to_probe: ProbeBatch = {
            let backend_map = backends.borrow();
            // (cluster, backend) pairs that already have a probe in flight, built once rather
            // than rescanning `in_flight` for every backend.
            let busy: HashSet<(&str, &str)> = self
                .in_flight
                .iter()
                .map(|p| (p.cluster_id.as_str(), p.backend_id.as_str()))
                .collect();
            let mut batch = Vec::new();
            for (cluster_id, settings) in &self.settings {
                let due = match self.last_check.get(cluster_id) {
                    Some(last) => now.duration_since(*last) >= settings.interval,
                    None => true,
                };
                if !due {
                    continue;
                }
                let Some(list) = backend_map.backends.get(cluster_id) else {
                    continue;
                };
                let targets: Vec<(String, SocketAddr)> = list
                    .backends
                    .iter()
                    .filter_map(|backend| {
                        let backend = backend.borrow();
                        if busy.contains(&(cluster_id.as_str(), backend.backend_id.as_str())) {
                            None
                        } else {
                            Some((backend.backend_id.to_owned(), backend.address))
                        }
                    })
                    .collect();
                if !targets.is_empty() {
                    batch.push((cluster_id.to_owned(), settings.clone(), targets));
                }
            }
            batch
        };
        for (cluster_id, settings, targets) in to_probe {
            self.last_check.insert(cluster_id.clone(), now);
            for (backend_id, address) in targets {
                self.spawn_tcp_probe(
                    backends,
                    registry,
                    &cluster_id,
                    &backend_id,
                    address,
                    &settings,
                    now,
                );
                if settings.udp_probe_payload.is_some() {
                    self.spawn_udp_probe(
                        backends,
                        registry,
                        &cluster_id,
                        &backend_id,
                        address,
                        &settings,
                        now,
                    );
                }
            }
        }
    }

    /// Starts a mio TCP connect to the backend and tracks it as an in-flight probe that
    /// waits for writability.
    ///
    /// The connect targets `settings.tcp_port` on the backend's IP when it is set. If creating
    /// the socket, allocating a token or registering the socket fails, a failed check is
    /// recorded immediately.
    // The probe needs the full cluster/backend context plus the shared reactor handles.
    #[allow(clippy::too_many_arguments)]
    fn spawn_tcp_probe(
        &mut self,
        backends: &Rc<RefCell<BackendMap>>,
        registry: &Registry,
        cluster_id: &str,
        backend_id: &str,
        address: SocketAddr,
        settings: &UdpHealthSettings,
        now: Instant,
    ) {
        let probe_addr = match settings.tcp_port {
            Some(port) => SocketAddr::new(address.ip(), port),
            None => address,
        };
        let record_failure = || {
            Self::record(
                backends,
                cluster_id,
                backend_id,
                address,
                false,
                settings.rise,
                settings.fall,
            )
        };
        let mut stream = match TcpStream::connect(probe_addr) {
            Ok(stream) => stream,
            Err(_) => return record_failure(),
        };
        let Some(token) = self.allocate_token() else {
            return record_failure();
        };
        if registry
            .register(&mut stream, token, Interest::WRITABLE)
            .is_err()
        {
            return record_failure();
        }
        self.in_flight.push(InFlightProbe {
            socket: ProbeSocket::Tcp(stream),
            token,
            cluster_id: cluster_id.to_owned(),
            backend_id: backend_id.to_owned(),
            address,
            started_at: now,
            timeout: settings.timeout,
            rise: settings.rise,
            fall: settings.fall,
        });
    }

    /// Sends `settings.udp_probe_payload` to the backend over a connected UDP socket and
    /// tracks the socket as an in-flight probe waiting for a reply.
    ///
    /// Does nothing when no payload is configured. If creating the socket, sending the
    /// datagram (including `WouldBlock`), allocating a token or registering the socket
    /// fails, a failed check is recorded immediately.
    // The probe needs the full cluster/backend context plus the shared reactor handles.
    #[allow(clippy::too_many_arguments)]
    fn spawn_udp_probe(
        &mut self,
        backends: &Rc<RefCell<BackendMap>>,
        registry: &Registry,
        cluster_id: &str,
        backend_id: &str,
        address: SocketAddr,
        settings: &UdpHealthSettings,
        now: Instant,
    ) {
        let Some(payload) = settings.udp_probe_payload.as_deref() else {
            return;
        };
        let record_failure = || {
            Self::record(
                backends,
                cluster_id,
                backend_id,
                address,
                false,
                settings.rise,
                settings.fall,
            )
        };
        let mut socket = match udp_connect(address) {
            Ok(socket) => socket,
            Err(_) => return record_failure(),
        };
        // Any send error, `WouldBlock` included, means the probe datagram never left.
        if socket.send(payload).is_err() {
            return record_failure();
        }
        let Some(token) = self.allocate_token() else {
            return record_failure();
        };
        if registry
            .register(&mut socket, token, Interest::READABLE)
            .is_err()
        {
            return record_failure();
        }
        self.in_flight.push(InFlightProbe {
            socket: ProbeSocket::Udp(socket),
            token,
            cluster_id: cluster_id.to_owned(),
            backend_id: backend_id.to_owned(),
            address,
            started_at: now,
            timeout: settings.timeout,
            rise: settings.rise,
            fall: settings.fall,
        });
    }

    /// Resolves in-flight probes and records their results.
    ///
    /// A probe whose token was reported ready is inspected first, so a response that
    /// already arrived is not discarded as a timeout. A probe that is still unresolved
    /// after its `timeout` counts as a failure. Resolved probes are deregistered and
    /// removed from `in_flight`.
    fn progress(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
        let now = Instant::now();
        let ready = std::mem::take(&mut self.ready_tokens);
        let mut completed: Vec<(usize, bool)> = Vec::new();
        for (idx, probe) in self.in_flight.iter_mut().enumerate() {
            let answered = if ready.contains(&probe.token) {
                Self::probe_outcome(&mut probe.socket)
            } else {
                None
            };
            let outcome = match answered {
                Some(success) => Some(success),
                None if now.duration_since(probe.started_at) > probe.timeout => Some(false),
                None => None,
            };
            if let Some(success) = outcome {
                completed.push((idx, success));
            }
        }
        // `completed` is in ascending index order. Removing from the highest index down
        // guarantees that `swap_remove` never moves an element that still has to be removed.
        for (idx, success) in completed.into_iter().rev() {
            let mut probe = self.in_flight.swap_remove(idx);
            probe.socket.deregister(registry);
            Self::record(
                backends,
                &probe.cluster_id,
                &probe.backend_id,
                probe.address,
                success,
                probe.rise,
                probe.fall,
            );
        }
    }

    /// Inspects a probe socket that mio reported ready.
    ///
    /// Returns `Some(true)` / `Some(false)` once the probe has a definitive result. Returns
    /// `None` when the readiness was spurious or the operation is still pending, in which
    /// case the probe keeps waiting (bounded by its timeout).
    fn probe_outcome(socket: &mut ProbeSocket) -> Option<bool> {
        match socket {
            ProbeSocket::Tcp(stream) => {
                // mio's non-blocking connect protocol: after a writable event, a pending
                // socket error means failure. Otherwise `peer_addr` tells whether the
                // connection is established; `NotConnected` means it is still in progress.
                match stream.take_error() {
                    Ok(None) => {}
                    Ok(Some(_)) | Err(_) => return Some(false),
                }
                match stream.peer_addr() {
                    Ok(_) => Some(true),
                    Err(ref e) if e.kind() == ErrorKind::NotConnected => None,
                    Err(_) => Some(false),
                }
            }
            ProbeSocket::Udp(socket) => {
                // Any datagram from the connected peer counts as a reply; its content is
                // not inspected.
                let mut scratch = [0u8; 16];
                match socket.recv(&mut scratch) {
                    Ok(_) => Some(true),
                    Err(ref e) if e.kind() == ErrorKind::WouldBlock => None,
                    Err(_) => Some(false),
                }
            }
        }
    }

    /// Feeds one probe result into the health state of the backend at `address` in
    /// `cluster_id`.
    ///
    /// Logs and counts UP/DOWN transitions, then refreshes the cluster's availability in
    /// the backend map. Does nothing if the cluster or backend no longer exists.
    fn record(
        backends: &Rc<RefCell<BackendMap>>,
        cluster_id: &str,
        backend_id: &str,
        address: SocketAddr,
        success: bool,
        rise: u32,
        fall: u32,
    ) {
        let mut backend_map = backends.borrow_mut();
        let Some(list) = backend_map.backends.get_mut(cluster_id) else {
            return;
        };
        let Some(backend_ref) = list.find_backend(&address) else {
            return;
        };
        let mut backend = backend_ref.borrow_mut();
        if success {
            if backend.health.record_success(rise) {
                info!(
                    "{} backend {} at {} marked UP (cluster {})",
                    log_context!(),
                    backend_id,
                    address,
                    cluster_id
                );
                incr!(names::udp::BACKEND_HEALTH);
            }
        } else if backend.health.record_failure(fall) {
            warn!(
                "{} backend {} at {} marked DOWN (cluster {})",
                log_context!(),
                backend_id,
                address,
                cluster_id
            );
            incr!(names::udp::BACKEND_HEALTH);
        }
        // Release the backend borrow before touching the map again.
        drop(backend);
        backend_map.record_cluster_availability(cluster_id);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::HealthState;

    /// `HealthState` flips DOWN only on the `fall`-th failure and back UP only on the
    /// `rise`-th success, reporting `true` exactly on the transition.
    #[test]
    fn hysteresis_rise_fall() {
        let mut health = HealthState::default();
        assert!(health.is_healthy());

        // Two failures out of three do not change the state.
        for _ in 0..2 {
            assert!(!health.record_failure(3));
        }
        assert!(health.is_healthy());

        // The third one does.
        assert!(health.record_failure(3));
        assert!(!health.is_healthy());

        // Recovery needs two successes; only the second reports the transition.
        assert!(!health.record_success(2));
        assert!(health.record_success(2));
        assert!(health.is_healthy());
    }

    /// The checker owns exactly the tokens in `[BASE, BASE + CAPACITY)`.
    #[test]
    fn token_namespace_is_disjoint_and_owned() {
        let checker = UdpHealthChecker::new();
        let first = UDP_HEALTH_TOKEN_BASE;
        let last = UDP_HEALTH_TOKEN_BASE + UDP_HEALTH_TOKEN_CAPACITY - 1;

        for owned in [first, last] {
            assert!(checker.owns_token(Token(owned)), "token {owned} should be owned");
        }
        for foreign in [first - 1, last + 1, 1 << 24] {
            assert!(
                !checker.owns_token(Token(foreign)),
                "token {foreign} should not be owned"
            );
        }
    }

    /// Unset protobuf fields take the documented defaults.
    #[test]
    fn settings_from_proto_defaults() {
        let settings = UdpHealthSettings::from_proto(&UdpHealthConfig {
            mode: None,
            tcp_port: Some(5353),
            rise: None,
            fall: None,
            fail_open: None,
            udp_probe_payload: None,
            probe_interval_seconds: None,
            probe_timeout_seconds: None,
        });

        assert_eq!(settings.tcp_port, Some(5353));
        assert_eq!((settings.rise, settings.fall), (2, 3));
        assert_eq!(settings.interval, Duration::from_secs(5));
        assert_eq!(settings.timeout, Duration::from_secs(2));
    }

    /// The UDP probe payload is copied verbatim into the settings.
    #[test]
    fn udp_probe_payload_is_captured() {
        use sozu_command::proto::command::UdpHealthMode;

        let cfg = UdpHealthConfig {
            mode: Some(UdpHealthMode::UdpProbe as i32),
            tcp_port: None,
            rise: Some(1),
            fall: Some(1),
            fail_open: None,
            udp_probe_payload: Some(b"PING".to_vec()),
            probe_interval_seconds: Some(1),
            probe_timeout_seconds: Some(1),
        };
        let settings = UdpHealthSettings::from_proto(&cfg);

        assert_eq!(settings.udp_probe_payload.as_deref(), Some(b"PING".as_slice()));
        assert!(settings.tcp_port.is_none());
    }

    /// Results recorded through `UdpHealthChecker::record` go through the backend's
    /// rise/fall hysteresis.
    #[test]
    fn udp_probe_result_feeds_same_hysteresis() {
        use crate::backends::{Backend, BackendMap};

        let cluster = "dns";
        let address = SocketAddr::from(([127, 0, 0, 1], 5353));
        let map = Rc::new(RefCell::new(BackendMap::new()));
        map.borrow_mut()
            .add_backend(cluster, Backend::new("b1", address, None, None, None));

        let rise = 2u32;
        let fall = 3u32;
        let healthy = |backends: &Rc<RefCell<BackendMap>>| {
            let mut guard = backends.borrow_mut();
            let list = guard.backends.get_mut(cluster).unwrap();
            list.find_backend(&address).unwrap().borrow().health.is_healthy()
        };
        let report = |success: bool| {
            UdpHealthChecker::record(&map, cluster, "b1", address, success, rise, fall)
        };

        assert!(healthy(&map));

        // `fall` failures take the backend down.
        for _ in 0..fall {
            report(false);
        }
        assert!(!healthy(&map));

        // The first success is not enough to bring it back...
        report(true);
        assert!(!healthy(&map));

        // ...the `rise`-th one is.
        report(true);
        assert!(healthy(&map));
    }
}