1use std::{
10 cell::RefCell,
11 collections::{HashMap, HashSet},
12 hash::{Hash, Hasher},
13 io::{Read, Write},
14 net::SocketAddr,
15 rc::Rc,
16 time::{Duration, Instant},
17};
18
19use mio::{Interest, Registry, Token, net::TcpStream};
20use sozu_command::{
21 proto::command::{Event, EventKind, HealthCheckConfig},
22 state::ClusterId,
23};
24
25use crate::metrics::names;
26use crate::{
27 backends::BackendMap,
28 protocol::mux::{
29 parser::{
30 FLAG_END_HEADERS, FLAG_PADDED, FLAG_PRIORITY, FRAME_HEADER_SIZE, FrameType,
31 frame_header,
32 },
33 serializer::H2_PRI,
34 },
35 server::push_event,
36};
37
/// Expands to the string-literal prefix used by every health-check log line.
///
/// `log_context!()` yields `"HEALTH-CHECK"`; `log_context!("name")` appends
/// ` cluster=name`. The one-argument form goes through `concat!`, so its
/// argument has to be a literal.
macro_rules! log_context {
    () => ("HEALTH-CHECK");
    ($cluster:expr) => (concat!("HEALTH-CHECK cluster=", $cluster));
}
52
/// First mio token value of the window reserved for health-check probe
/// sockets (2^24).
const HEALTH_CHECK_TOKEN_BASE: usize = 0x0100_0000;
/// Number of consecutive tokens in that window (2^16), i.e. the maximum
/// number of probes that can be in flight at once.
const HEALTH_CHECK_TOKEN_CAPACITY: usize = 0x0001_0000;
65
/// Work list built by `initiate_checks` while the backend map is borrowed:
/// one entry per cluster that is due for a probe round.
type PendingChecks = Vec<(
    // Cluster to probe.
    ClusterId,
    // Its health-check configuration, cloned out of the backend map.
    HealthCheckConfig,
    // Whether the cluster is probed over h2c (from `cluster_http2`) instead of HTTP/1.1.
    bool,
    // `(backend_id, address)` of every backend selected for this round.
    Vec<(String, SocketAddr)>,
)>;
76
/// One probe whose socket has been opened and registered and whose verdict is
/// still pending.
#[derive(Debug)]
struct InFlightCheck {
    /// Non-blocking socket registered with mio under `token`.
    stream: TcpStream,
    /// Token from the health-check window identifying this socket's events.
    token: Token,
    cluster_id: ClusterId,
    backend_id: String,
    address: SocketAddr,
    /// Start of the probe round; compared against `timeout`.
    started_at: Instant,
    /// Maximum probe duration, built from the config's `timeout` in seconds.
    timeout: Duration,
    /// Request bytes to send; `None` once the request was fully written.
    request_bytes: Option<Vec<u8>>,
    /// How many bytes of `request_bytes` have been written so far.
    write_offset: usize,
    /// Response bytes received so far.
    response_buf: Vec<u8>,
    /// Configuration the probe was started with (thresholds, expected status).
    config: HealthCheckConfig,
    /// `true` when the probe speaks HTTP/2 cleartext, `false` for HTTP/1.1.
    h2c: bool,
}
98
/// Runs active HTTP health checks (HTTP/1.1 or h2c `GET` probes) against the
/// backends of clusters that have a health-check configuration.
#[derive(Debug)]
pub struct HealthChecker {
    /// Probes that are connected or connecting and still awaiting a verdict.
    in_flight: Vec<InFlightCheck>,
    /// When each cluster's most recent probe round was started.
    last_check_time: HashMap<ClusterId, Instant>,
    /// Round-robin cursor into the token window, advanced by `allocate_token`.
    next_token_id: usize,
    /// Tokens reported ready through `ready()` and not yet processed by `poll()`.
    ready_tokens: HashSet<Token>,
}
107
108impl Default for HealthChecker {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl HealthChecker {
115 pub fn new() -> Self {
116 HealthChecker {
117 in_flight: Vec::new(),
118 last_check_time: HashMap::new(),
119 next_token_id: 0,
120 ready_tokens: HashSet::new(),
121 }
122 }
123
124 fn allocate_token(&mut self) -> Option<Token> {
130 let in_flight: HashSet<usize> = self
131 .in_flight
132 .iter()
133 .map(|c| c.token.0 - HEALTH_CHECK_TOKEN_BASE)
134 .collect();
135 debug_assert!(
139 in_flight.iter().all(|&o| o < HEALTH_CHECK_TOKEN_CAPACITY),
140 "every in-flight token offset must fall within the slot capacity"
141 );
142 debug_assert!(
143 in_flight.len() <= HEALTH_CHECK_TOKEN_CAPACITY,
144 "cannot have more in-flight checks than the token slot capacity"
145 );
146
147 for _ in 0..HEALTH_CHECK_TOKEN_CAPACITY {
148 let offset = self.next_token_id % HEALTH_CHECK_TOKEN_CAPACITY;
149 self.next_token_id = self.next_token_id.wrapping_add(1);
150 if !in_flight.contains(&offset) {
151 let token = Token(HEALTH_CHECK_TOKEN_BASE + offset);
152 debug_assert!(
155 self.owns_token(token),
156 "allocated token must fall inside the health-check namespace"
157 );
158 debug_assert!(
159 !in_flight.contains(&offset),
160 "allocated offset must not already be in flight"
161 );
162 return Some(token);
163 }
164 }
165 debug_assert_eq!(
169 in_flight.len(),
170 HEALTH_CHECK_TOKEN_CAPACITY,
171 "allocation only fails when every slot is occupied"
172 );
173 error!(
174 "{} token-table full ({} in-flight checks); refusing to allocate a new probe slot",
175 log_context!(),
176 in_flight.len()
177 );
178 None
179 }
180
181 pub fn owns_token(&self, token: Token) -> bool {
187 let owned = token.0 >= HEALTH_CHECK_TOKEN_BASE
188 && token.0 < HEALTH_CHECK_TOKEN_BASE + HEALTH_CHECK_TOKEN_CAPACITY;
189 debug_assert!(
194 !owned || token.0 - HEALTH_CHECK_TOKEN_BASE < HEALTH_CHECK_TOKEN_CAPACITY,
195 "an owned token must map to a valid bounded slot offset"
196 );
197 debug_assert!(
198 owned || token != Token(HEALTH_CHECK_TOKEN_BASE),
199 "the base token itself must always be classified as owned"
200 );
201 owned
202 }
203
204 pub fn ready(&mut self, token: Token) {
206 self.ready_tokens.insert(token);
207 debug_assert!(
211 self.ready_tokens.contains(&token),
212 "ready() must record the token in the readiness set"
213 );
214 }
215
216 pub fn poll(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
219 if self.in_flight.is_empty() && backends.borrow().health_check_configs.is_empty() {
220 return;
221 }
222 self.initiate_checks(backends, registry);
223 self.progress_checks(backends, registry);
224 }
225
226 fn initiate_checks(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
227 let backend_map = backends.borrow();
228 let now = Instant::now();
229
230 let mut to_check: PendingChecks = Vec::new();
231
232 for (cluster_id, config) in &backend_map.health_check_configs {
233 let interval = Duration::from_secs(u64::from(config.interval));
234
235 let mut hasher = std::collections::hash_map::DefaultHasher::new();
237 cluster_id.hash(&mut hasher);
238 let jitter_ms = hasher.finish() % (interval.as_millis() as u64 / 5).max(1);
239 let jittered_interval = interval + Duration::from_millis(jitter_ms);
240
241 let should_check = match self.last_check_time.get(cluster_id) {
242 Some(last) => now.duration_since(*last) >= jittered_interval,
243 None => true,
244 };
245
246 if !should_check {
247 continue;
248 }
249
250 if let Some(backend_list) = backend_map.backends.get(cluster_id) {
251 let backends_to_check: Vec<(String, SocketAddr)> = backend_list
252 .backends
253 .iter()
254 .filter(|b| {
255 let b = b.borrow();
256 b.status == crate::backends::BackendStatus::Normal
257 && !self.in_flight.iter().any(|f| {
258 f.cluster_id == *cluster_id && f.backend_id == b.backend_id
259 })
260 })
261 .map(|b| {
262 let b = b.borrow();
263 (b.backend_id.to_owned(), b.address)
264 })
265 .collect();
266
267 if !backends_to_check.is_empty() {
268 let h2c = backend_map
269 .cluster_http2
270 .get(cluster_id)
271 .copied()
272 .unwrap_or(false);
273 to_check.push((
274 cluster_id.to_owned(),
275 config.to_owned(),
276 h2c,
277 backends_to_check,
278 ));
279 }
280 }
281 }
282
283 drop(backend_map);
284
285 for (cluster_id, config, h2c, backends_to_check) in to_check {
286 self.last_check_time.insert(cluster_id.to_owned(), now);
287
288 let probe_uri = config.uri.as_str();
295
296 for (backend_id, address) in backends_to_check {
297 match TcpStream::connect(address) {
298 Ok(mut stream) => {
299 let Some(token) = self.allocate_token() else {
300 Self::record_check_result(
304 backends,
305 &cluster_id,
306 &backend_id,
307 address,
308 false,
309 &config,
310 );
311 continue;
312 };
313 if let Err(e) = registry.register(
314 &mut stream,
315 token,
316 Interest::READABLE | Interest::WRITABLE,
317 ) {
318 debug!(
319 "{} failed to register socket for {} ({}) in cluster {}: {}",
320 log_context!(),
321 backend_id,
322 address,
323 cluster_id,
324 e
325 );
326 Self::record_check_result(
327 backends,
328 &cluster_id,
329 &backend_id,
330 address,
331 false,
332 &config,
333 );
334 continue;
335 }
336 trace!(
337 "{} initiated connection to {} ({}) for cluster {}",
338 log_context!(),
339 backend_id,
340 address,
341 cluster_id
342 );
343 let request_bytes = if h2c {
344 build_h2c_probe_bytes(probe_uri, address)
345 } else {
346 format!(
360 "GET {probe_uri} HTTP/1.1\r\nHost: {address}\r\nConnection: close\r\n\r\n"
361 )
362 .into_bytes()
363 };
364 self.in_flight.push(InFlightCheck {
365 stream,
366 token,
367 cluster_id: cluster_id.to_owned(),
368 backend_id,
369 address,
370 started_at: now,
371 timeout: Duration::from_secs(u64::from(config.timeout)),
372 request_bytes: Some(request_bytes),
373 write_offset: 0,
374 response_buf: Vec::with_capacity(256),
375 config: config.to_owned(),
376 h2c,
377 });
378 }
379 Err(e) => {
380 debug!(
381 "{} failed to connect to {} ({}) for cluster {}: {}",
382 log_context!(),
383 backend_id,
384 address,
385 cluster_id,
386 e
387 );
388 Self::record_check_result(
389 backends,
390 &cluster_id,
391 &backend_id,
392 address,
393 false,
394 &config,
395 );
396 }
397 }
398 }
399 }
400 }
401
402 fn progress_checks(&mut self, backends: &Rc<RefCell<BackendMap>>, registry: &Registry) {
403 const MAX_HEALTH_RESPONSE_SIZE: usize = 4096;
404
405 let now = Instant::now();
406 let mut completed = Vec::new();
407 let ready = std::mem::take(&mut self.ready_tokens);
408 debug_assert!(
411 self.ready_tokens.is_empty(),
412 "readiness set must be drained before processing in-flight checks"
413 );
414 let in_flight_before = self.in_flight.len();
415
416 for (idx, check) in self.in_flight.iter_mut().enumerate() {
417 debug_assert!(
420 idx < in_flight_before,
421 "in-flight index ({idx}) must be within the live slot range ({in_flight_before})"
422 );
423 debug_assert!(
425 check
426 .request_bytes
427 .as_ref()
428 .is_none_or(|r| check.write_offset <= r.len()),
429 "write_offset must never exceed the request length"
430 );
431
432 if now.duration_since(check.started_at) > check.timeout {
434 debug!(
435 "{} timeout for {} ({}) in cluster {}",
436 log_context!(),
437 check.backend_id,
438 check.address,
439 check.cluster_id
440 );
441 completed.push((idx, false));
442 continue;
443 }
444
445 if !ready.contains(&check.token) {
447 continue;
448 }
449
450 if let Some(ref request_bytes) = check.request_bytes {
451 match check.stream.write(&request_bytes[check.write_offset..]) {
452 Ok(n) => {
453 check.write_offset += n;
454 if check.write_offset >= request_bytes.len() {
455 check.request_bytes = None;
456 } else {
457 continue;
458 }
459 }
460 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
461 continue;
462 }
463 Err(_e) => {
464 completed.push((idx, false));
465 continue;
466 }
467 }
468 }
469
470 let mut buf = [0u8; 256];
471 match check.stream.read(&mut buf) {
472 Ok(0) => {
473 let success =
474 parse_probe_response(&check.response_buf, &check.config, check.h2c)
475 .unwrap_or(false);
476 completed.push((idx, success));
477 }
478 Ok(n) => {
479 debug_assert!(
482 n <= buf.len(),
483 "read reported {n} bytes into a {}-byte buffer",
484 buf.len()
485 );
486 if check.response_buf.len() + n > MAX_HEALTH_RESPONSE_SIZE {
487 completed.push((idx, false));
488 continue;
489 }
490 check.response_buf.extend_from_slice(&buf[..n]);
491 debug_assert!(
494 check.response_buf.len() <= MAX_HEALTH_RESPONSE_SIZE,
495 "response buffer must stay within the max health response size"
496 );
497 if let Some(success) =
498 parse_probe_response(&check.response_buf, &check.config, check.h2c)
499 {
500 completed.push((idx, success));
501 }
502 }
503 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
504 Err(_e) => {
505 completed.push((idx, false));
506 }
507 }
508 }
509
510 completed.sort_by(|a, b| b.0.cmp(&a.0));
514 debug_assert!(
517 completed.len() <= in_flight_before,
518 "cannot complete more checks ({}) than were in flight ({in_flight_before})",
519 completed.len()
520 );
521 debug_assert!(
522 completed.windows(2).all(|w| w[0].0 > w[1].0),
523 "completed indices must be strictly descending and unique for swap_remove safety"
524 );
525 for (idx, success) in completed {
526 let len_before = self.in_flight.len();
527 let mut check = self.in_flight.swap_remove(idx);
528 debug_assert_eq!(
530 self.in_flight.len(),
531 len_before - 1,
532 "swap_remove must drop exactly one in-flight check"
533 );
534 let _ = registry.deregister(&mut check.stream);
535 Self::record_check_result(
536 backends,
537 &check.cluster_id,
538 &check.backend_id,
539 check.address,
540 success,
541 &check.config,
542 );
543 }
544 }
545
546 fn record_check_result(
547 backends: &Rc<RefCell<BackendMap>>,
548 cluster_id: &str,
549 backend_id: &str,
550 address: SocketAddr,
551 success: bool,
552 config: &HealthCheckConfig,
553 ) {
554 let mut backend_map = backends.borrow_mut();
555 let Some(backend_list) = backend_map.backends.get_mut(cluster_id) else {
556 return;
557 };
558
559 let Some(backend_ref) = backend_list.find_backend(&address) else {
560 return;
561 };
562
563 let mut backend = backend_ref.borrow_mut();
564
565 if success {
566 let was_healthy = backend.health.is_healthy();
570 let transitioned = backend.health.record_success(config.healthy_threshold);
571 debug_assert!(
576 backend.health.consecutive_failures == 0,
577 "a recorded success must zero the consecutive-failure counter"
578 );
579 debug_assert_eq!(
580 transitioned,
581 !was_healthy && backend.health.is_healthy(),
582 "transition flag must be set iff the backend just flipped to healthy"
583 );
584 debug_assert!(
585 !transitioned || backend.health.consecutive_successes >= config.healthy_threshold,
586 "an UP transition only fires once the rise counter reaches the healthy threshold"
587 );
588 debug_assert!(
589 !transitioned || backend.health.is_healthy(),
590 "after an UP transition the backend must report healthy"
591 );
592 if transitioned {
593 info!(
594 "{} backend {} at {} marked UP (health check passed {} consecutive times) for cluster {}",
595 log_context!(),
596 backend_id,
597 address,
598 config.healthy_threshold,
599 cluster_id
600 );
601 incr!(names::health_check::UP);
602 gauge!(
603 names::backend::AVAILABLE,
604 1,
605 Some(cluster_id),
606 Some(backend_id)
607 );
608 push_event(Event {
609 kind: EventKind::HealthCheckHealthy as i32,
610 cluster_id: Some(cluster_id.to_owned()),
611 backend_id: Some(backend_id.to_owned()),
612 address: Some(address.into()),
613 metric_detail: None,
614 });
615 }
616 count!(names::health_check::SUCCESS, 1);
617 } else {
618 let was_healthy = backend.health.is_healthy();
619 let transitioned = backend.health.record_failure(config.unhealthy_threshold);
620 debug_assert!(
625 backend.health.consecutive_successes == 0,
626 "a recorded failure must zero the consecutive-success counter"
627 );
628 debug_assert_eq!(
629 transitioned,
630 was_healthy && !backend.health.is_healthy(),
631 "transition flag must be set iff the backend just flipped to unhealthy"
632 );
633 debug_assert!(
634 !transitioned || backend.health.consecutive_failures >= config.unhealthy_threshold,
635 "a DOWN transition only fires once the fall counter reaches the unhealthy threshold"
636 );
637 debug_assert!(
638 !transitioned || !backend.health.is_healthy(),
639 "after a DOWN transition the backend must report unhealthy"
640 );
641 if transitioned {
642 warn!(
643 "{} backend {} at {} marked DOWN (health check failed {} consecutive times) for cluster {}",
644 log_context!(),
645 backend_id,
646 address,
647 config.unhealthy_threshold,
648 cluster_id
649 );
650 incr!(names::health_check::DOWN);
651 gauge!(
652 names::backend::AVAILABLE,
653 0,
654 Some(cluster_id),
655 Some(backend_id)
656 );
657 push_event(Event {
658 kind: EventKind::HealthCheckUnhealthy as i32,
659 cluster_id: Some(cluster_id.to_owned()),
660 backend_id: Some(backend_id.to_owned()),
661 address: Some(address.into()),
662 metric_detail: None,
663 });
664 }
665 count!(names::health_check::FAILURE, 1);
666 }
667
668 drop(backend);
679 let total = backend_list.backends.len();
680 let healthy = backend_list
681 .backends
682 .iter()
683 .filter(|b| b.borrow().health.is_healthy())
684 .count();
685 debug_assert!(
689 healthy <= total,
690 "healthy backend count ({healthy}) must not exceed total ({total})"
691 );
692 if total > 0 {
693 gauge!(
694 "health_check.healthy_backends",
695 healthy,
696 Some(cluster_id),
697 None
698 );
699 if healthy > 0 && healthy * 2 <= total {
700 warn!(
701 "{} cluster {} has only {}/{} healthy backends",
702 log_context!(),
703 cluster_id,
704 healthy,
705 total
706 );
707 }
708 }
709 backend_map.record_cluster_availability(cluster_id);
722 }
723
724 pub fn remove_cluster(&mut self, cluster_id: &str) {
725 self.last_check_time.remove(cluster_id);
726 self.in_flight
727 .retain(|check| check.cluster_id != cluster_id);
728 debug_assert!(
730 self.in_flight.iter().all(|c| c.cluster_id != cluster_id),
731 "remove_cluster must drop every in-flight check for the cluster"
732 );
733 debug_assert!(
734 !self.last_check_time.contains_key(cluster_id),
735 "remove_cluster must forget the cluster's last-check timestamp"
736 );
737 }
738}
739
740fn parse_probe_response(buf: &[u8], config: &HealthCheckConfig, h2c: bool) -> Option<bool> {
745 if h2c {
746 try_parse_h2c_status(buf, config)
747 } else {
748 try_parse_status_line(buf, config)
749 }
750}
751
752fn try_parse_status_line(buf: &[u8], config: &HealthCheckConfig) -> Option<bool> {
753 let response = std::str::from_utf8(buf).ok()?;
754 let first_line_end = response.find("\r\n")?;
755 let status_line = &response[..first_line_end];
756 debug_assert!(
759 status_line.len() < response.len(),
760 "status line must be a strict prefix ending before the CRLF"
761 );
762
763 let (_, rest) = status_line.split_once(' ')?;
764 let status_str = rest.split(' ').next()?;
765 let status_code: u32 = status_str.parse().unwrap_or(0);
766 Some(is_status_healthy(status_code, config.expected_status))
767}
768
/// Decides whether status code `actual` counts as healthy.
///
/// An `expected` of `0` accepts any 2xx code; any other value must match
/// `actual` exactly.
fn is_status_healthy(actual: u32, expected: u32) -> bool {
    let healthy = match expected {
        0 => (200..300).contains(&actual),
        exact => actual == exact,
    };
    debug_assert!(
        expected == 0 || healthy == (actual == expected),
        "with a specific expected status, health must be exact equality"
    );
    healthy
}
784
785fn build_h2c_probe_bytes(uri: &str, address: SocketAddr) -> Vec<u8> {
798 let authority = address.to_string();
799
800 let mut encoder = loona_hpack::Encoder::new();
804 let mut hpack: Vec<u8> = Vec::new();
805 let headers: [(&[u8], &[u8]); 4] = [
806 (b":method", b"GET"),
807 (b":scheme", b"http"),
808 (b":path", uri.as_bytes()),
809 (b":authority", authority.as_bytes()),
810 ];
811 if encoder.encode_into(headers, &mut hpack).is_err() {
814 return Vec::new();
817 }
818
819 let mut out = Vec::with_capacity(H2_PRI.len() + FRAME_HEADER_SIZE * 2 + hpack.len());
821 out.extend_from_slice(H2_PRI.as_bytes());
822
823 out.extend_from_slice(&[0, 0, 0, 0x04, 0, 0, 0, 0, 0]);
825
826 let len = hpack.len() as u32;
829 out.push(((len >> 16) & 0xFF) as u8);
830 out.push(((len >> 8) & 0xFF) as u8);
831 out.push((len & 0xFF) as u8);
832 out.push(0x01); out.push(0x05); out.extend_from_slice(&[0, 0, 0, 1]); out.extend_from_slice(&hpack);
836 debug_assert!(
839 out.starts_with(H2_PRI.as_bytes()),
840 "an h2c probe must begin with the connection preface"
841 );
842 debug_assert_eq!(
843 out.len(),
844 H2_PRI.len() + FRAME_HEADER_SIZE * 2 + hpack.len(),
845 "probe length must be preface + SETTINGS + HEADERS header + HPACK block"
846 );
847 out
848}
849
850fn try_parse_h2c_status(buf: &[u8], config: &HealthCheckConfig) -> Option<bool> {
868 const MAX_FRAME_SIZE: u32 = (1 << 24) - 1;
872
873 let mut remaining: &[u8] = buf;
874 let mut headers_block: Option<Vec<u8>> = None;
880
881 while !remaining.is_empty() {
882 if remaining.len() < FRAME_HEADER_SIZE {
887 return None;
888 }
889 let consumable = remaining.len();
890 let (rest, header) = match frame_header(remaining, MAX_FRAME_SIZE) {
891 Ok(parsed) => parsed,
892 Err(_) => return Some(false),
896 };
897 debug_assert!(
900 rest.len() < consumable,
901 "frame_header must consume at least the fixed frame header"
902 );
903 debug_assert_eq!(
904 consumable - rest.len(),
905 FRAME_HEADER_SIZE,
906 "frame_header must consume exactly the fixed-size frame header"
907 );
908 debug_assert!(
909 header.payload_len <= MAX_FRAME_SIZE,
910 "frame_header must enforce the max-frame-size bound it was given"
911 );
912
913 let payload_len = header.payload_len as usize;
914 if rest.len() < payload_len {
915 return None;
917 }
918 let (payload, after) = rest.split_at(payload_len);
919 debug_assert_eq!(
923 payload.len(),
924 payload_len,
925 "payload split must yield exactly the declared payload length"
926 );
927 debug_assert_eq!(
928 payload.len() + after.len(),
929 rest.len(),
930 "payload + remainder must equal the pre-split buffer"
931 );
932 debug_assert!(
933 after.len() < remaining.len(),
934 "each iteration must shrink the remaining buffer to guarantee termination"
935 );
936
937 match header.frame_type {
938 FrameType::Headers if header.stream_id == 1 => {
939 let block = strip_padded_priority(payload, header.flags)?;
940 let mut accumulator = headers_block.take().unwrap_or_default();
941 accumulator.extend_from_slice(block);
942 if header.flags & FLAG_END_HEADERS != 0 {
943 return Some(decode_status_from_block(&accumulator, config));
944 }
945 headers_block = Some(accumulator);
946 }
947 FrameType::Continuation if header.stream_id == 1 => {
948 let Some(mut accumulator) = headers_block.take() else {
951 return Some(false);
954 };
955 accumulator.extend_from_slice(payload);
956 if header.flags & FLAG_END_HEADERS != 0 {
957 return Some(decode_status_from_block(&accumulator, config));
958 }
959 headers_block = Some(accumulator);
960 }
961 FrameType::GoAway => return Some(false),
962 _ => {}
965 }
966
967 remaining = after;
968 }
969 None
970}
971
972fn strip_padded_priority(payload: &[u8], flags: u8) -> Option<&[u8]> {
977 let mut start = 0usize;
978 let mut end = payload.len();
979
980 if flags & FLAG_PADDED != 0 {
981 let &pad_len = payload.first()?;
982 start = 1;
983 let pad = pad_len as usize;
984 let available = end.checked_sub(start)?;
988 if pad > available {
989 return None;
990 }
991 end -= pad;
992 }
993 if flags & FLAG_PRIORITY != 0 {
994 let new_start = start.checked_add(5)?;
995 if new_start > end {
996 return None;
997 }
998 start = new_start;
999 }
1000 debug_assert!(
1004 start <= end && end <= payload.len(),
1005 "stripped header window [{start}, {end}) must lie within the payload ({})",
1006 payload.len()
1007 );
1008 let block = payload.get(start..end)?;
1009 debug_assert!(
1010 block.len() <= payload.len(),
1011 "stripped block must never be larger than the original payload"
1012 );
1013 Some(block)
1014}
1015
1016fn decode_status_from_block(block: &[u8], config: &HealthCheckConfig) -> bool {
1022 let mut decoder = loona_hpack::Decoder::new();
1023 let mut status: Option<u32> = None;
1024 let decode_result = decoder.decode_with_cb(block, |name, value| {
1025 if status.is_some() {
1026 return;
1027 }
1028 if name.as_ref() == b":status"
1029 && let Ok(s) = std::str::from_utf8(value.as_ref())
1030 && let Ok(parsed) = s.parse::<u32>()
1031 {
1032 status = Some(parsed);
1033 }
1034 });
1035 if decode_result.is_err() {
1036 return false;
1037 }
1038 match status {
1039 Some(code) => is_status_healthy(code, config.expected_status),
1040 None => false,
1041 }
1042}
1043
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::HealthState;

    // HTTP/2 frame type codes used to hand-build server responses.
    const TYPE_HEADERS: u8 = 0x01;
    const TYPE_SETTINGS: u8 = 0x04;
    const TYPE_GOAWAY: u8 = 0x07;
    const TYPE_CONTINUATION: u8 = 0x09;

    /// Probe configuration used by every test; `expected` 0 means "any 2xx".
    fn probe_config(expected: u32) -> HealthCheckConfig {
        HealthCheckConfig {
            uri: "/health".to_owned(),
            interval: 10,
            timeout: 5,
            healthy_threshold: 3,
            unhealthy_threshold: 3,
            expected_status: expected,
        }
    }

    /// Builds one raw HTTP/2 frame: 9-byte header followed by `payload`.
    fn frame_with_header(frame_type: u8, flags: u8, sid: u32, payload: &[u8]) -> Vec<u8> {
        let length = (payload.len() as u32).to_be_bytes();
        let mut frame = Vec::with_capacity(FRAME_HEADER_SIZE + payload.len());
        frame.extend_from_slice(&length[1..]);
        frame.extend_from_slice(&[frame_type, flags]);
        frame.extend_from_slice(&sid.to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// HPACK-encodes `headers` with a fresh encoder, as a server would.
    fn encode_response_headers(headers: &[(&[u8], &[u8])]) -> Vec<u8> {
        let mut encoder = loona_hpack::Encoder::new();
        let mut block = Vec::new();
        encoder
            .encode_into(headers.iter().copied(), &mut block)
            .unwrap();
        block
    }

    #[test]
    fn test_is_status_healthy_any_2xx() {
        for code in [200, 204, 299] {
            assert!(is_status_healthy(code, 0), "{code} should be healthy");
        }
        for code in [301, 500, 0] {
            assert!(!is_status_healthy(code, 0), "{code} should be unhealthy");
        }
    }

    #[test]
    fn test_is_status_healthy_specific() {
        assert!(is_status_healthy(200, 200));
        for code in [204, 500] {
            assert!(!is_status_healthy(code, 200), "{code} must not match 200");
        }
    }

    #[test]
    fn test_try_parse_status_line() {
        let config = probe_config(0);
        let cases: [(&[u8], Option<bool>); 3] = [
            (&b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"[..], Some(true)),
            (&b"HTTP/1.1 500 Internal Server Error\r\n\r\n"[..], Some(false)),
            (&b"HTTP/1.1 200"[..], None),
        ];
        for (raw, expected) in cases {
            assert_eq!(try_parse_status_line(raw, &config), expected);
        }
    }

    #[test]
    fn test_health_state_transitions() {
        let mut state = HealthState::default();
        assert!(state.is_healthy());

        // Two failures stay below the threshold of three...
        for _ in 0..2 {
            assert!(!state.record_failure(3));
        }
        assert!(state.is_healthy());
        // ...the third one flips the backend DOWN.
        assert!(state.record_failure(3));
        assert!(!state.is_healthy());

        for _ in 0..2 {
            assert!(!state.record_success(3));
        }
        assert!(!state.is_healthy());
        assert!(state.record_success(3));
        assert!(state.is_healthy());
    }

    #[test]
    fn build_h2c_probe_starts_with_preface_and_frames() {
        let probe = build_h2c_probe_bytes("/health", "127.0.0.1:8080".parse().unwrap());

        let preface: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
        assert!(probe.starts_with(preface));

        // Empty SETTINGS frame: length 0, type 0x4, flags 0, stream 0.
        let settings_start = preface.len();
        assert_eq!(
            &probe[settings_start..settings_start + 9],
            &[0u8, 0, 0, 0x04, 0, 0, 0, 0, 0]
        );

        // HEADERS frame: type 0x1, END_STREAM | END_HEADERS, stream 1.
        let headers_start = settings_start + 9;
        assert_eq!(
            &probe[headers_start + 3..headers_start + 9],
            &[0x01u8, 0x05, 0, 0, 0, 1]
        );

        let mut decoded: std::collections::HashMap<Vec<u8>, Vec<u8>> =
            std::collections::HashMap::new();
        let mut decoder = loona_hpack::Decoder::new();
        decoder
            .decode_with_cb(&probe[headers_start + 9..], |name, value| {
                decoded.insert(name.to_vec(), value.to_vec());
            })
            .expect("loona_hpack decodes a freshly-encoded probe");

        let expected: [(&[u8], &[u8]); 4] = [
            (b":method", b"GET"),
            (b":scheme", b"http"),
            (b":path", b"/health"),
            (b":authority", b"127.0.0.1:8080"),
        ];
        for (name, value) in expected {
            assert_eq!(decoded.get(name).map(Vec::as_slice), Some(value));
        }
    }

    #[test]
    fn h2c_response_with_status_200_decodes_healthy() {
        let block = encode_response_headers(&[(b":status", b"200")]);
        let buf = frame_with_header(TYPE_HEADERS, FLAG_END_HEADERS, 1, &block);
        assert_eq!(try_parse_h2c_status(&buf, &probe_config(0)), Some(true));
    }

    #[test]
    fn h2c_response_with_status_500_fails_default_2xx_check() {
        let block = encode_response_headers(&[(b":status", b"500")]);
        let buf = frame_with_header(TYPE_HEADERS, FLAG_END_HEADERS, 1, &block);
        assert_eq!(try_parse_h2c_status(&buf, &probe_config(0)), Some(false));
    }

    #[test]
    fn h2c_response_with_status_503_matches_expected_503() {
        let block =
            encode_response_headers(&[(b":status", b"503"), (b"content-type", b"text/plain")]);
        let buf = frame_with_header(TYPE_HEADERS, FLAG_END_HEADERS, 1, &block);
        assert_eq!(try_parse_h2c_status(&buf, &probe_config(503)), Some(true));
    }

    #[test]
    fn h2c_response_with_continuation_decodes_status_200_healthy() {
        let block = encode_response_headers(&[
            (b":status", b"200"),
            (b"x-trace-id", b"abc-123"),
            (b"server", b"sozu-test"),
        ]);
        assert!(block.len() >= 4, "HPACK block needs to be splittable");
        let (first_half, second_half) = block.split_at(block.len() / 2);

        // HEADERS without END_HEADERS, completed by a CONTINUATION that sets it.
        let mut buf = frame_with_header(TYPE_HEADERS, 0, 1, first_half);
        buf.extend(frame_with_header(
            TYPE_CONTINUATION,
            FLAG_END_HEADERS,
            1,
            second_half,
        ));

        assert_eq!(try_parse_h2c_status(&buf, &probe_config(0)), Some(true));
    }

    #[test]
    fn h2c_response_with_padded_priority_headers_decodes_status_200() {
        let block = encode_response_headers(&[(b":status", b"200")]);
        let padding = [0u8; 3];
        let priority = [0u8, 0, 0, 0, 16];

        // Pad Length, priority fields, header block, then the padding itself.
        let mut payload = vec![padding.len() as u8];
        payload.extend_from_slice(&priority);
        payload.extend_from_slice(&block);
        payload.extend_from_slice(&padding);

        let flags = FLAG_PADDED | FLAG_PRIORITY | FLAG_END_HEADERS;
        let buf = frame_with_header(TYPE_HEADERS, flags, 1, &payload);
        assert_eq!(try_parse_h2c_status(&buf, &probe_config(0)), Some(true));
    }

    #[test]
    fn h2c_response_after_unrelated_settings_frame_decodes_healthy() {
        // Server SETTINGS, then a SETTINGS ACK, then the response HEADERS.
        let mut buf = frame_with_header(TYPE_SETTINGS, 0, 0, &[]);
        buf.extend(frame_with_header(TYPE_SETTINGS, 0x01, 0, &[]));
        let block = encode_response_headers(&[(b":status", b"200")]);
        buf.extend(frame_with_header(TYPE_HEADERS, FLAG_END_HEADERS, 1, &block));

        assert_eq!(try_parse_h2c_status(&buf, &probe_config(0)), Some(true));
    }

    #[test]
    fn h2c_goaway_returns_unhealthy() {
        let buf = frame_with_header(TYPE_GOAWAY, 0, 0, &[0u8; 8]);
        assert_eq!(try_parse_h2c_status(&buf, &probe_config(0)), Some(false));
    }

    #[test]
    fn h2c_truncated_frame_returns_none() {
        // The header announces a 10-byte HEADERS payload but only 5 bytes follow.
        let mut buf: Vec<u8> = vec![0, 0, 10, TYPE_HEADERS, FLAG_END_HEADERS];
        buf.extend_from_slice(&1u32.to_be_bytes());
        buf.extend_from_slice(&[0u8; 5]);
        assert_eq!(try_parse_h2c_status(&buf, &probe_config(0)), None);
    }

    #[test]
    fn h2c_partial_frame_header_returns_none() {
        let config = probe_config(0);
        for partial_len in 0..FRAME_HEADER_SIZE {
            let buf = vec![0u8; partial_len];
            assert_eq!(
                try_parse_h2c_status(&buf, &config),
                None,
                "partial buffer of {partial_len} byte(s) should be 'keep reading'"
            );
        }
    }

    #[test]
    fn h2c_continuation_without_preceding_headers_returns_unhealthy() {
        let block = encode_response_headers(&[(b":status", b"200")]);
        let buf = frame_with_header(TYPE_CONTINUATION, FLAG_END_HEADERS, 1, &block);
        assert_eq!(try_parse_h2c_status(&buf, &probe_config(0)), Some(false));
    }
}