use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::time::Duration;
/// Tag naming the network path that a health probe or a status entry refers to.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must keep a
/// wildcard arm when matching, so further variants can be added later.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Channel {
/// Not returned by any `HealthChannel` implementation in this file; it only
/// appears as a plain tag (for example in the `MultiChannelHealth` tests).
Default,
/// Returned by `DohHealthChannel::channel_id`.
DnsOverHttps,
/// Returned by `StaticIpHealthChannel::channel_id`.
StaticIp,
/// Returned by `RegionHealthChannel::channel_id`.
AlternateRegion,
}
/// Outcome recorded for a single channel.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
/// Produced by the TCP helpers in this file when a connection attempt succeeds.
Healthy,
/// Produced by the TCP helpers in this file when a connection attempt fails;
/// `ConsensusHealthChecker::check` also substitutes this value when a probe
/// returns an error.
Failing,
/// No result recorded yet; the initial value assigned by `MultiChannelHealth::new`.
Unknown,
}
/// Errors a probe can report instead of a `Status`.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum HealthError {
/// The `host:port` string could not be turned into a socket address:
/// `tcp_reachable` returns this when `to_socket_addrs` fails or yields no
/// addresses.
#[error("dns resolution failed for {hostport}")]
DnsResolution {
/// The `host:port` string that was being resolved.
hostport: String,
},
/// An I/O failure described by a message.
///
/// NOTE(review): not constructed anywhere in the code visible in this file;
/// presumably produced by other `HealthChannel` implementations — confirm.
#[error("io error: {0}")]
Io(String),
}
/// A single health probe that can be polled by `ConsensusHealthChecker`.
///
/// Implementors must be `Send + Sync`; the checker stores them as
/// `Box<dyn HealthChannel>`.
pub trait HealthChannel: Send + Sync {
/// Returns the `Channel` tag under which this probe's result is reported.
fn channel_id(&self) -> Channel;
/// Runs the probe once.
///
/// Returns the observed `Status`, or a `HealthError` when the probe could not
/// be carried out.
fn probe(&self) -> Result<Status, HealthError>;
}
/// Probes TCP reachability of `hostport` (a `"host:port"` string).
///
/// The string is resolved through [`ToSocketAddrs`], and every resolved
/// address is tried in order until one accepts a connection. A name that
/// resolves to several addresses (for example an IPv6 and an IPv4 entry) is
/// therefore reported as healthy as long as at least one of them is reachable,
/// instead of being judged on the first entry alone.
///
/// Each attempt is bounded by `timeout`, so the worst-case duration of the
/// call is `timeout` multiplied by the number of resolved addresses.
///
/// # Returns
///
/// * `Ok(Status::Healthy)` — some resolved address accepted the connection.
/// * `Ok(Status::Failing)` — resolution succeeded but no address could be
///   connected to within `timeout`.
///
/// # Errors
///
/// Returns [`HealthError::DnsResolution`] when `hostport` cannot be resolved
/// or resolves to an empty address list.
fn tcp_reachable(hostport: &str, timeout: Duration) -> Result<Status, HealthError> {
    // Both failure paths report the same error, so build it in one place.
    let dns_failure = || HealthError::DnsResolution {
        hostport: hostport.to_string(),
    };

    let mut candidates = hostport
        .to_socket_addrs()
        .map_err(|_| dns_failure())?
        .peekable();

    // An empty result is a resolution problem, not an unreachable endpoint.
    if candidates.peek().is_none() {
        return Err(dns_failure());
    }

    // `any` short-circuits on the first successful connection; the stream is
    // dropped (closed) immediately because only reachability matters.
    let reachable = candidates.any(|addr| TcpStream::connect_timeout(&addr, timeout).is_ok());

    Ok(if reachable {
        Status::Healthy
    } else {
        Status::Failing
    })
}
/// Attempts one TCP connection to `addr`, waiting at most `timeout`.
///
/// Returns `Status::Healthy` when the connection is established and
/// `Status::Failing` for any connection error. The stream is discarded right
/// away; only the outcome of the attempt is used.
fn static_reachable(addr: SocketAddr, timeout: Duration) -> Status {
    if TcpStream::connect_timeout(&addr, timeout).is_ok() {
        Status::Healthy
    } else {
        Status::Failing
    }
}
/// Health probe that reports as `Channel::DnsOverHttps`; its `probe` checks TCP
/// reachability of `hostport` through `tcp_reachable`.
pub struct DohHealthChannel {
/// Target in `host:port` form, handed to `tcp_reachable`.
hostport: String,
/// Connection timeout handed to `tcp_reachable`.
timeout: Duration,
}
impl DohHealthChannel {
    /// Preset that probes `1.1.1.1:443` with a 3-second timeout.
    pub fn cloudflare() -> Self {
        Self::custom("1.1.1.1:443", Duration::from_secs(3))
    }

    /// Preset that probes `8.8.8.8:443` with a 3-second timeout.
    pub fn google() -> Self {
        Self::custom("8.8.8.8:443", Duration::from_secs(3))
    }

    /// Builds a probe for an arbitrary `host:port` target.
    ///
    /// `hostport` is stored as an owned `String`; `timeout` is stored as given.
    /// Neither value is validated here.
    pub fn custom(hostport: impl Into<String>, timeout: Duration) -> Self {
        let hostport = hostport.into();
        Self { hostport, timeout }
    }
}
impl HealthChannel for DohHealthChannel {
    /// Always reports `Channel::DnsOverHttps`.
    fn channel_id(&self) -> Channel {
        Channel::DnsOverHttps
    }

    /// Delegates to `tcp_reachable` with the stored target and timeout and
    /// returns its result unchanged.
    fn probe(&self) -> Result<Status, HealthError> {
        let Self { hostport, timeout } = self;
        tcp_reachable(hostport, *timeout)
    }
}
/// Health probe that reports as `Channel::AlternateRegion`; its `probe` checks
/// TCP reachability of `hostport` through `tcp_reachable`.
pub struct RegionHealthChannel {
/// Target in `host:port` form, handed to `tcp_reachable`.
hostport: String,
/// Connection timeout handed to `tcp_reachable`.
timeout: Duration,
}
impl RegionHealthChannel {
pub fn new(hostport: impl Into<String>, timeout: Duration) -> Self {
Self {
hostport: hostport.into(),
timeout,
}
}
}
impl HealthChannel for RegionHealthChannel {
    /// Always reports `Channel::AlternateRegion`.
    fn channel_id(&self) -> Channel {
        Channel::AlternateRegion
    }

    /// Delegates to `tcp_reachable` with the stored target and timeout and
    /// returns its result unchanged.
    fn probe(&self) -> Result<Status, HealthError> {
        let Self { hostport, timeout } = self;
        tcp_reachable(hostport, *timeout)
    }
}
/// Health probe that reports as `Channel::StaticIp`; its `probe` connects to a
/// fixed socket address through `static_reachable`, with no name resolution.
pub struct StaticIpHealthChannel {
/// Socket address handed to `static_reachable`.
addr: SocketAddr,
/// Connection timeout handed to `static_reachable`.
timeout: Duration,
}
impl StaticIpHealthChannel {
pub fn new(addr: SocketAddr, timeout: Duration) -> Self {
Self { addr, timeout }
}
}
impl HealthChannel for StaticIpHealthChannel {
    /// Always reports `Channel::StaticIp`.
    fn channel_id(&self) -> Channel {
        Channel::StaticIp
    }

    /// Runs `static_reachable` against the stored address and timeout.
    ///
    /// This implementation never returns `Err`: a failed connection is
    /// reported as whatever status `static_reachable` yields, wrapped in `Ok`.
    fn probe(&self) -> Result<Status, HealthError> {
        let observed = static_reachable(self.addr, self.timeout);
        Ok(observed)
    }
}
/// Snapshot produced by one call to `ConsensusHealthChecker::check`.
#[derive(Debug, Clone)]
pub struct ConsensusResult {
/// One `(channel id, status)` pair per probed channel, in probing order.
pub per_channel: Vec<(Channel, Status)>,
/// Number of entries in `per_channel` whose status is `Status::Failing`.
pub failing_count: usize,
/// Number of entries in `per_channel`.
pub total: usize,
/// `true` when `failing_count` is greater than or equal to the checker's threshold.
pub is_quorum_failing: bool,
}
/// Polls a set of `HealthChannel` probes and decides whether enough of them are
/// failing to count as a quorum failure.
pub struct ConsensusHealthChecker {
/// Probes polled by `check`, in this order.
channels: Vec<Box<dyn HealthChannel>>,
/// Minimum number of failing channels that makes `check` report a quorum failure.
threshold: usize,
}
impl ConsensusHealthChecker {
pub fn new(channels: Vec<Box<dyn HealthChannel>>, threshold: usize) -> Self {
Self {
channels,
threshold,
}
}
pub fn two_of_three(channels: Vec<Box<dyn HealthChannel>>) -> Self {
assert_eq!(
channels.len(),
3,
"two_of_three requires exactly 3 channels; got {}",
channels.len()
);
Self::new(channels, 2)
}
pub fn total(&self) -> usize {
self.channels.len()
}
pub fn check(&self) -> ConsensusResult {
let mut per_channel = Vec::with_capacity(self.channels.len());
let mut failing_count = 0usize;
for ch in &self.channels {
let status = ch.probe().unwrap_or(Status::Failing);
if status == Status::Failing {
failing_count += 1;
}
per_channel.push((ch.channel_id(), status));
}
let total = per_channel.len();
let is_quorum_failing = failing_count >= self.threshold;
ConsensusResult {
per_channel,
failing_count,
total,
is_quorum_failing,
}
}
}
/// Table of the latest status recorded for each channel; statuses are supplied
/// by the caller through `set_status` rather than obtained by probing.
pub struct MultiChannelHealth {
/// `(channel, status)` entries; `set_status` overwrites the first entry whose
/// channel matches and appends a new entry when none does.
channels: Vec<(Channel, Status)>,
}
impl MultiChannelHealth {
    /// Creates a table with one `Status::Unknown` entry per element of
    /// `channels`, in the given order. Duplicates in the input are kept as
    /// separate entries.
    pub fn new(channels: &[Channel]) -> Self {
        let channels = channels
            .iter()
            .copied()
            .map(|id| (id, Status::Unknown))
            .collect();
        Self { channels }
    }

    /// Records `status` for `channel`.
    ///
    /// The first entry tagged with `channel` is overwritten; when no entry
    /// exists yet, a new one is appended.
    pub fn set_status(&mut self, channel: Channel, status: Status) {
        let existing = self.channels.iter().position(|(id, _)| *id == channel);
        match existing {
            // `idx` comes from `position` on the same vector, so it is in bounds.
            Some(idx) => self.channels[idx].1 = status,
            None => self.channels.push((channel, status)),
        }
    }

    /// Returns `true` when at least `threshold` entries are `Status::Failing`.
    ///
    /// A `threshold` of 0 therefore always yields `true`.
    pub fn is_quorum_failing(&self, threshold: usize) -> bool {
        let failing = self
            .channels
            .iter()
            .filter(|(_, recorded)| matches!(recorded, Status::Failing))
            .count();
        failing >= threshold
    }

    /// Number of entries whose status is `Status::Healthy`.
    pub fn healthy_count(&self) -> usize {
        self.channels
            .iter()
            .filter(|(_, recorded)| matches!(recorded, Status::Healthy))
            .count()
    }
}
// Unit tests for the consensus checker, the concrete channels and the status table.
#[cfg(test)]
// Tests are allowed to panic and unwrap: a failure there is the test failing.
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
use super::*;
// Test double whose probe always succeeds with a fixed, preconfigured status.
struct FakeChannel {
id: Channel,
status: Status,
}
impl HealthChannel for FakeChannel {
fn channel_id(&self) -> Channel {
self.id
}
fn probe(&self) -> Result<Status, HealthError> {
Ok(self.status)
}
}
// Boxes a `FakeChannel` so it can be handed to `ConsensusHealthChecker`.
fn fake(id: Channel, status: Status) -> Box<dyn HealthChannel> {
Box::new(FakeChannel { id, status })
}
// One failing channel out of three stays below the 2-of-3 threshold.
#[test]
fn consensus_single_failure_does_not_trigger() {
let checker = ConsensusHealthChecker::two_of_three(vec![
fake(Channel::DnsOverHttps, Status::Healthy),
fake(Channel::AlternateRegion, Status::Healthy),
fake(Channel::StaticIp, Status::Failing),
]);
let r = checker.check();
assert_eq!(r.failing_count, 1);
assert_eq!(r.total, 3);
assert!(!r.is_quorum_failing);
}
// Exactly two failing channels meet the 2-of-3 threshold.
#[test]
fn consensus_two_of_three_triggers() {
let checker = ConsensusHealthChecker::two_of_three(vec![
fake(Channel::DnsOverHttps, Status::Failing),
fake(Channel::AlternateRegion, Status::Failing),
fake(Channel::StaticIp, Status::Healthy),
]);
let r = checker.check();
assert_eq!(r.failing_count, 2);
assert!(r.is_quorum_failing);
}
// All three channels failing exceeds the threshold and still reports a quorum failure.
#[test]
fn consensus_all_failing_triggers() {
let checker = ConsensusHealthChecker::two_of_three(vec![
fake(Channel::DnsOverHttps, Status::Failing),
fake(Channel::AlternateRegion, Status::Failing),
fake(Channel::StaticIp, Status::Failing),
]);
let r = checker.check();
assert_eq!(r.failing_count, 3);
assert!(r.is_quorum_failing);
}
// `new` accepts an arbitrary threshold: with threshold 1, a single failure is enough.
#[test]
fn explicit_threshold_overrides_default() {
let checker = ConsensusHealthChecker::new(
vec![
fake(Channel::DnsOverHttps, Status::Failing),
fake(Channel::AlternateRegion, Status::Healthy),
],
1,
);
let r = checker.check();
assert_eq!(r.failing_count, 1);
assert!(r.is_quorum_failing);
}
// Performs a real connection attempt to 127.0.0.1:1; either outcome is accepted,
// the test only checks the channel id and that `probe` returns `Ok`.
#[test]
fn static_ip_channel_records_channel_id() {
let addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
let ch = StaticIpHealthChannel::new(addr, Duration::from_millis(50));
assert_eq!(ch.channel_id(), Channel::StaticIp);
let status = ch.probe().unwrap();
assert!(matches!(status, Status::Failing | Status::Healthy));
}
// Both presets report the DoH channel id; no probe is executed here.
#[test]
fn doh_defaults_provide_cloudflare_and_google() {
let cf = DohHealthChannel::cloudflare();
let gg = DohHealthChannel::google();
assert_eq!(cf.channel_id(), Channel::DnsOverHttps);
assert_eq!(gg.channel_id(), Channel::DnsOverHttps);
}
// The region channel reports its id; no probe is executed here.
#[test]
fn region_channel_reports_channel_id() {
let ch = RegionHealthChannel::new("127.0.0.1:1", Duration::from_millis(50));
assert_eq!(ch.channel_id(), Channel::AlternateRegion);
}
// The status table reaches a quorum only once the second channel is marked failing.
#[test]
fn n_of_m_2_of_3_trigger() {
let mut h =
MultiChannelHealth::new(&[Channel::Default, Channel::DnsOverHttps, Channel::StaticIp]);
h.set_status(Channel::Default, Status::Failing);
assert!(!h.is_quorum_failing(2));
h.set_status(Channel::DnsOverHttps, Status::Failing);
assert!(h.is_quorum_failing(2));
}
// Entries left at `Status::Unknown` are not counted as healthy.
#[test]
fn healthy_count_accurate() {
let mut h = MultiChannelHealth::new(&[Channel::Default, Channel::DnsOverHttps]);
h.set_status(Channel::Default, Status::Healthy);
assert_eq!(h.healthy_count(), 1);
}
}