1use std::{
29 cell::RefCell,
30 collections::{HashMap, HashSet},
31 io::ErrorKind,
32 net::SocketAddr,
33 rc::Rc,
34 time::{Duration, Instant},
35};
36
37use mio::{
38 Interest, Registry, Token,
39 net::{TcpStream, UdpSocket},
40};
41use sozu_command::{proto::command::UdpHealthConfig, state::ClusterId};
42
43use crate::backends::BackendMap;
44use crate::metrics::names;
45use crate::socket::udp_connect;
46
/// Expands to the string literal this module passes as the leading `{}`
/// argument of its log lines, so every message carries the same tag.
macro_rules! log_context {
    () => {
        "UDP-HEALTH"
    };
}
53
/// First `mio` token value of the range reserved for health-check probes.
// NOTE(review): presumably chosen to sit past token ranges used by other
// components of the proxy -- confirm against the other token allocators.
const UDP_HEALTH_TOKEN_BASE: usize = (1 << 24) + (1 << 20);
/// Number of consecutive token values, starting at `UDP_HEALTH_TOKEN_BASE`,
/// that probes may be registered under. Because every in-flight probe holds
/// one token, this is also the ceiling on simultaneously in-flight probes.
const UDP_HEALTH_TOKEN_CAPACITY: usize = 1 << 16;
60
/// Per-cluster health-check parameters used by `UdpHealthChecker`.
#[derive(Clone, Debug)]
pub struct UdpHealthSettings {
    /// Port the TCP probe connects to on the backend's IP address. When
    /// `None`, the TCP probe connects to the backend's own address.
    pub tcp_port: Option<u16>,
    /// Threshold handed to the backend's `record_success` on each successful
    /// probe.
    pub rise: u32,
    /// Threshold handed to the backend's `record_failure` on each failed
    /// probe.
    pub fall: u32,
    /// Minimum time between the start of two probe rounds for a cluster.
    pub interval: Duration,
    /// Age after which an unresolved probe is recorded as a failure.
    pub timeout: Duration,
    /// Datagram sent to the backend address as a UDP probe. When `None`, no
    /// UDP probe is started.
    pub udp_probe_payload: Option<Vec<u8>>,
}
79
80impl UdpHealthSettings {
81 pub fn from_proto(cfg: &UdpHealthConfig) -> Self {
84 UdpHealthSettings {
85 tcp_port: cfg.tcp_port.map(|p| p as u16),
86 rise: cfg.rise.unwrap_or(2),
87 fall: cfg.fall.unwrap_or(3),
88 interval: Duration::from_secs(u64::from(cfg.probe_interval_seconds.unwrap_or(5))),
89 timeout: Duration::from_secs(u64::from(cfg.probe_timeout_seconds.unwrap_or(2))),
90 udp_probe_payload: cfg.udp_probe_payload.clone(),
91 }
92 }
93}
94
/// Work list assembled by `initiate`: for each cluster that is due, its id, a
/// snapshot of its settings, and the `(backend id, address)` pairs to probe.
type ProbeBatch = Vec<(ClusterId, UdpHealthSettings, Vec<(String, SocketAddr)>)>;
98
/// Socket owned by an in-flight probe.
enum ProbeSocket {
    /// TCP stream created by `TcpStream::connect` and registered for
    /// writable events; the probe resolves once the connect outcome is known.
    Tcp(TcpStream),
    /// UDP socket the probe payload was sent on, registered for readable
    /// events; the probe resolves when a datagram (or an error) is received.
    Udp(UdpSocket),
}
110
111impl ProbeSocket {
112 fn deregister(&mut self, registry: &Registry) {
114 match self {
115 ProbeSocket::Tcp(s) => {
116 let _ = registry.deregister(s);
117 }
118 ProbeSocket::Udp(s) => {
119 let _ = registry.deregister(s);
120 }
121 }
122 }
123}
124
/// A probe that has been started and is waiting for an outcome.
struct InFlightProbe {
    /// Socket the probe runs on, registered with the event loop.
    socket: ProbeSocket,
    /// Token `socket` was registered under.
    token: Token,
    /// Cluster the probed backend belongs to.
    cluster_id: ClusterId,
    /// Identifier of the probed backend.
    backend_id: String,
    /// Backend address; used to look the backend up when the result is
    /// recorded (this is the backend's own address even when the TCP probe
    /// targets a different port).
    address: SocketAddr,
    /// Instant the probe round that spawned this probe started.
    started_at: Instant,
    /// Settings snapshot taken when the probe was spawned: maximum age before
    /// the probe is recorded as failed.
    timeout: Duration,
    /// Settings snapshot: threshold passed to `record_success`.
    rise: u32,
    /// Settings snapshot: threshold passed to `record_failure`.
    fall: u32,
}
137
/// Drives periodic health probes for the backends of configured clusters and
/// records their outcomes in the shared backend map.
#[derive(Default)]
pub struct UdpHealthChecker {
    /// Settings of every cluster that should be probed.
    settings: HashMap<ClusterId, UdpHealthSettings>,
    /// Probes that were started and have not been resolved yet.
    in_flight: Vec<InFlightProbe>,
    /// Instant the most recent probe round was started, per cluster.
    last_check: HashMap<ClusterId, Instant>,
    /// Wrapping counter from which the next token offset is derived.
    next_token_id: usize,
    /// Tokens reported through `ready` since `progress` last ran.
    ready_tokens: HashSet<Token>,
}
151
152impl UdpHealthChecker {
153 pub fn new() -> Self {
154 Self::default()
155 }
156
157 pub fn set_cluster(
165 &mut self,
166 cluster_id: &str,
167 settings: Option<UdpHealthSettings>,
168 registry: &Registry,
169 ) {
170 match settings {
171 Some(s) => {
172 self.settings.insert(cluster_id.to_owned(), s);
173 }
174 None => {
175 self.settings.remove(cluster_id);
176 self.last_check.remove(cluster_id);
177 let mut kept = Vec::with_capacity(self.in_flight.len());
182 for mut probe in self.in_flight.drain(..) {
183 if probe.cluster_id == cluster_id {
184 probe.socket.deregister(registry);
185 } else {
186 kept.push(probe);
187 }
188 }
189 self.in_flight = kept;
190 }
191 }
192 }
193
194 pub fn remove_cluster(&mut self, cluster_id: &str, registry: &Registry) {
196 self.set_cluster(cluster_id, None, registry);
197 }
198
199 pub fn owns_token(&self, token: Token) -> bool {
201 token.0 >= UDP_HEALTH_TOKEN_BASE
202 && token.0 < UDP_HEALTH_TOKEN_BASE + UDP_HEALTH_TOKEN_CAPACITY
203 }
204
205 pub fn ready(&mut self, token: Token) {
207 self.ready_tokens.insert(token);
208 }
209
210 fn allocate_token(&mut self) -> Option<Token> {
212 let in_flight: HashSet<usize> = self
213 .in_flight
214 .iter()
215 .map(|p| p.token.0 - UDP_HEALTH_TOKEN_BASE)
216 .collect();
217 for _ in 0..UDP_HEALTH_TOKEN_CAPACITY {
218 let offset = self.next_token_id % UDP_HEALTH_TOKEN_CAPACITY;
219 self.next_token_id = self.next_token_id.wrapping_add(1);
220 if !in_flight.contains(&offset) {
221 let token = Token(UDP_HEALTH_TOKEN_BASE + offset);
222 debug_assert!(
227 self.owns_token(token),
228 "allocate_token returned a token outside the health namespace"
229 );
230 debug_assert!(
231 !self.in_flight.iter().any(|p| p.token == token),
232 "allocate_token returned a token already in flight"
233 );
234 return Some(token);
235 }
236 }
237 error!(
238 "{} token table full ({} in-flight); refusing new probe slot",
239 log_context!(),
240 in_flight.len()
241 );
242 None
243 }
244
245 pub fn poll(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
249 if self.settings.is_empty() && self.in_flight.is_empty() {
250 return;
251 }
252 self.initiate(backends, registry);
253 self.progress(backends, registry);
254 }
255
256 fn initiate(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
257 let now = Instant::now();
258 let backend_map = backends.borrow();
259
260 let mut to_probe: ProbeBatch = Vec::new();
262 for (cluster_id, settings) in &self.settings {
263 let due = match self.last_check.get(cluster_id) {
264 Some(last) => now.duration_since(*last) >= settings.interval,
265 None => true,
266 };
267 if !due {
268 continue;
269 }
270 if let Some(list) = backend_map.backends.get(cluster_id) {
271 let targets: Vec<(String, SocketAddr)> =
272 list.backends
273 .iter()
274 .filter(|b| {
275 let b = b.borrow();
276 !self.in_flight.iter().any(|p| {
277 p.cluster_id == *cluster_id && p.backend_id == b.backend_id
278 })
279 })
280 .map(|b| {
281 let b = b.borrow();
282 (b.backend_id.to_owned(), b.address)
283 })
284 .collect();
285 if !targets.is_empty() {
286 to_probe.push((cluster_id.to_owned(), settings.clone(), targets));
287 }
288 }
289 }
290 drop(backend_map);
291
292 for (cluster_id, settings, targets) in to_probe {
293 self.last_check.insert(cluster_id.to_owned(), now);
294 for (backend_id, address) in targets {
295 self.spawn_tcp_probe(
297 backends,
298 registry,
299 &cluster_id,
300 &backend_id,
301 address,
302 &settings,
303 now,
304 );
305 if settings.udp_probe_payload.is_some() {
310 self.spawn_udp_probe(
311 backends,
312 registry,
313 &cluster_id,
314 &backend_id,
315 address,
316 &settings,
317 now,
318 );
319 }
320 }
321 }
322 }
323
324 #[allow(clippy::too_many_arguments)]
328 fn spawn_tcp_probe(
329 &mut self,
330 backends: &Rc<RefCell<BackendMap>>,
331 registry: &Registry,
332 cluster_id: &str,
333 backend_id: &str,
334 address: SocketAddr,
335 settings: &UdpHealthSettings,
336 now: Instant,
337 ) {
338 let probe_addr = match settings.tcp_port {
339 Some(port) => SocketAddr::new(address.ip(), port),
340 None => address,
341 };
342 let record_failure = || {
343 Self::record(
344 backends,
345 cluster_id,
346 backend_id,
347 address,
348 false,
349 settings.rise,
350 settings.fall,
351 )
352 };
353 let mut stream = match TcpStream::connect(probe_addr) {
354 Ok(stream) => stream,
355 Err(_) => return record_failure(),
356 };
357 let Some(token) = self.allocate_token() else {
358 return record_failure();
359 };
360 if registry
361 .register(&mut stream, token, Interest::WRITABLE)
362 .is_err()
363 {
364 return record_failure();
365 }
366 self.in_flight.push(InFlightProbe {
367 socket: ProbeSocket::Tcp(stream),
368 token,
369 cluster_id: cluster_id.to_owned(),
370 backend_id: backend_id.to_owned(),
371 address,
372 started_at: now,
373 timeout: settings.timeout,
374 rise: settings.rise,
375 fall: settings.fall,
376 });
377 }
378
379 #[allow(clippy::too_many_arguments)]
385 fn spawn_udp_probe(
386 &mut self,
387 backends: &Rc<RefCell<BackendMap>>,
388 registry: &Registry,
389 cluster_id: &str,
390 backend_id: &str,
391 address: SocketAddr,
392 settings: &UdpHealthSettings,
393 now: Instant,
394 ) {
395 let Some(payload) = settings.udp_probe_payload.as_deref() else {
396 return;
397 };
398 let record_failure = || {
399 Self::record(
400 backends,
401 cluster_id,
402 backend_id,
403 address,
404 false,
405 settings.rise,
406 settings.fall,
407 )
408 };
409 let mut socket = match udp_connect(address) {
412 Ok(socket) => socket,
413 Err(_) => return record_failure(),
414 };
415 match socket.send(payload) {
419 Ok(_) => {}
420 Err(ref e) if e.kind() == ErrorKind::WouldBlock => return record_failure(),
421 Err(_) => return record_failure(),
422 }
423 let Some(token) = self.allocate_token() else {
424 return record_failure();
425 };
426 if registry
427 .register(&mut socket, token, Interest::READABLE)
428 .is_err()
429 {
430 return record_failure();
431 }
432 self.in_flight.push(InFlightProbe {
433 socket: ProbeSocket::Udp(socket),
434 token,
435 cluster_id: cluster_id.to_owned(),
436 backend_id: backend_id.to_owned(),
437 address,
438 started_at: now,
439 timeout: settings.timeout,
440 rise: settings.rise,
441 fall: settings.fall,
442 });
443 }
444
445 fn progress(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
446 let now = Instant::now();
447 let ready = std::mem::take(&mut self.ready_tokens);
448 let mut completed: Vec<(usize, bool)> = Vec::new();
449
450 for (idx, probe) in self.in_flight.iter_mut().enumerate() {
451 if now.duration_since(probe.started_at) > probe.timeout {
452 completed.push((idx, false));
453 continue;
454 }
455 if !ready.contains(&probe.token) {
456 continue;
457 }
458 let success = match &mut probe.socket {
459 ProbeSocket::Tcp(stream) => {
464 let no_so_error = matches!(stream.take_error(), Ok(None));
465 no_so_error && stream.peer_addr().is_ok()
466 }
467 ProbeSocket::Udp(socket) => {
472 let mut scratch = [0u8; 16];
473 match socket.recv(&mut scratch) {
474 Ok(_) => true,
475 Err(ref e) if e.kind() == ErrorKind::WouldBlock => continue,
476 Err(_) => false,
479 }
480 }
481 };
482 completed.push((idx, success));
483 }
484
485 completed.sort_by(|a, b| b.0.cmp(&a.0));
486 for (idx, success) in completed {
487 let mut probe = self.in_flight.swap_remove(idx);
488 probe.socket.deregister(registry);
489 Self::record(
490 backends,
491 &probe.cluster_id,
492 &probe.backend_id,
493 probe.address,
494 success,
495 probe.rise,
496 probe.fall,
497 );
498 }
499 }
500
501 fn record(
505 backends: &Rc<RefCell<BackendMap>>,
506 cluster_id: &str,
507 backend_id: &str,
508 address: SocketAddr,
509 success: bool,
510 rise: u32,
511 fall: u32,
512 ) {
513 let mut backend_map = backends.borrow_mut();
514 let Some(list) = backend_map.backends.get_mut(cluster_id) else {
515 return;
516 };
517 let Some(backend_ref) = list.find_backend(&address) else {
518 return;
519 };
520 let mut backend = backend_ref.borrow_mut();
521 if success {
522 if backend.health.record_success(rise) {
523 info!(
524 "{} backend {} at {} marked UP (cluster {})",
525 log_context!(),
526 backend_id,
527 address,
528 cluster_id
529 );
530 incr!(names::udp::BACKEND_HEALTH);
531 }
532 } else if backend.health.record_failure(fall) {
533 warn!(
534 "{} backend {} at {} marked DOWN (cluster {})",
535 log_context!(),
536 backend_id,
537 address,
538 cluster_id
539 );
540 incr!(names::udp::BACKEND_HEALTH);
541 }
542 drop(backend);
543 backend_map.record_cluster_availability(cluster_id);
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::HealthState;

    /// A default state is healthy, goes unhealthy on the third consecutive
    /// failure with `fall = 3`, and comes back on the second consecutive
    /// success with `rise = 2`. `record_*` returns `true` only on the call
    /// that flips the state.
    #[test]
    fn hysteresis_rise_fall() {
        let mut state = HealthState::default();
        assert!(state.is_healthy());
        // Two failures are below the fall threshold: no transition yet.
        assert!(!state.record_failure(3));
        assert!(!state.record_failure(3));
        assert!(state.is_healthy());
        // Third failure flips the state to unhealthy.
        assert!(state.record_failure(3));
        assert!(!state.is_healthy());
        // First success is below the rise threshold; the second one flips back.
        assert!(!state.record_success(2));
        assert!(state.record_success(2));
        assert!(state.is_healthy());
    }

    /// `owns_token` accepts exactly the half-open range starting at the base
    /// and spanning the capacity, and rejects the values just outside it.
    #[test]
    fn token_namespace_is_disjoint_and_owned() {
        let hc = UdpHealthChecker::new();
        // Both ends of the range are owned.
        assert!(hc.owns_token(Token(UDP_HEALTH_TOKEN_BASE)));
        assert!(hc.owns_token(Token(UDP_HEALTH_TOKEN_BASE + UDP_HEALTH_TOKEN_CAPACITY - 1)));
        // One below and one past the range are not.
        assert!(!hc.owns_token(Token(UDP_HEALTH_TOKEN_BASE - 1)));
        assert!(!hc.owns_token(Token(UDP_HEALTH_TOKEN_BASE + UDP_HEALTH_TOKEN_CAPACITY)));
        // NOTE(review): `1 << 24` is presumably the start of another token
        // namespace in the proxy -- confirm.
        assert!(!hc.owns_token(Token(1 << 24)));
    }

    /// Unset proto fields fall back to the documented defaults while an
    /// explicitly set `tcp_port` is carried over.
    #[test]
    fn settings_from_proto_defaults() {
        let cfg = UdpHealthConfig {
            mode: None,
            tcp_port: Some(5353),
            rise: None,
            fall: None,
            fail_open: None,
            udp_probe_payload: None,
            probe_interval_seconds: None,
            probe_timeout_seconds: None,
        };
        let s = UdpHealthSettings::from_proto(&cfg);
        assert_eq!(s.tcp_port, Some(5353));
        assert_eq!(s.rise, 2);
        assert_eq!(s.fall, 3);
        assert_eq!(s.interval, Duration::from_secs(5));
        assert_eq!(s.timeout, Duration::from_secs(2));
    }

    /// A configured UDP payload is copied into the settings, and an unset
    /// `tcp_port` stays `None`.
    #[test]
    fn udp_probe_payload_is_captured() {
        let cfg = UdpHealthConfig {
            mode: Some(sozu_command::proto::command::UdpHealthMode::UdpProbe as i32),
            tcp_port: None,
            rise: Some(1),
            fall: Some(1),
            fail_open: None,
            udp_probe_payload: Some(b"PING".to_vec()),
            probe_interval_seconds: Some(1),
            probe_timeout_seconds: Some(1),
        };
        let s = UdpHealthSettings::from_proto(&cfg);
        assert_eq!(s.udp_probe_payload.as_deref(), Some(&b"PING"[..]));
        assert_eq!(s.tcp_port, None);
    }

    /// Outcomes passed to `UdpHealthChecker::record` drive the backend's
    /// health with the same rise/fall hysteresis as `HealthState` itself.
    #[test]
    fn udp_probe_result_feeds_same_hysteresis() {
        use crate::backends::{Backend, BackendMap};

        let cluster = "dns";
        let address: SocketAddr = ([127, 0, 0, 1], 5353).into();
        let backend_map = Rc::new(RefCell::new(BackendMap::new()));
        backend_map
            .borrow_mut()
            .add_backend(cluster, Backend::new("b1", address, None, None, None));
        let (rise, fall) = (2u32, 3u32);

        // Reads the current health of the single backend registered above.
        let is_healthy = |map: &Rc<RefCell<BackendMap>>| {
            let mut m = map.borrow_mut();
            let list = m.backends.get_mut(cluster).unwrap();
            let b = list.find_backend(&address).unwrap();
            b.borrow().health.is_healthy()
        };
        assert!(is_healthy(&backend_map));

        // `fall` consecutive failures take the backend down.
        for _ in 0..fall {
            UdpHealthChecker::record(&backend_map, cluster, "b1", address, false, rise, fall);
        }
        assert!(!is_healthy(&backend_map));

        // One success is not enough with `rise = 2`; the second brings it back.
        UdpHealthChecker::record(&backend_map, cluster, "b1", address, true, rise, fall);
        assert!(!is_healthy(&backend_map));
        UdpHealthChecker::record(&backend_map, cluster, "b1", address, true, rise, fall);
        assert!(is_healthy(&backend_map));
    }
}