use std::net::{IpAddr, SocketAddr, UdpSocket};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use serde::Serialize;
use tokio::time;
use tracing::warn;
use crate::config::parser::UnboundConfig;
/// Seconds between two probe rounds of the upstream health loop.
const PROBE_INTERVAL_SECS: u64 = 30;
/// How long a single probe waits for a reply before the upstream is reported unhealthy.
const PROBE_TIMEOUT_MS: u64 = 2_000;
/// Wildcard local address (any interface, OS-chosen port) for IPv4 probe sockets.
const BIND_V4: SocketAddr = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0);
/// Wildcard local address (any interface, OS-chosen port) for IPv6 probe sockets.
const BIND_V6: SocketAddr = SocketAddr::new(IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED), 0);
/// Minimal DNS query (RFC 1035 wire format) for the root name, type A, class IN.
const DNS_PROBE_PACKET: [u8; 17] = [
    // --- header (12 bytes) ---
    0x00, 0x01, // transaction ID = 1
    0x01, 0x00, // flags: RD (recursion desired)
    0x00, 0x01, // QDCOUNT = 1
    0x00, 0x00, // ANCOUNT = 0
    0x00, 0x00, // NSCOUNT = 0
    0x00, 0x00, // ARCOUNT = 0
    // --- question ---
    0x00, // QNAME = root (zero-length label)
    0x00, 0x01, // QTYPE = A
    0x00, 0x01, // QCLASS = IN
];
/// Serializable health snapshot of one forward-zone upstream.
#[derive(Serialize, Clone, Debug)]
pub struct UpstreamStatus {
    /// Configured upstream address, with its `@port` suffix stripped by `init_upstreams`.
    pub addr: String,
    /// Outcome of the most recent probe; `false` until the first probe has run.
    pub healthy: bool,
    /// Round-trip time of the most recent probe in milliseconds; `None` when that probe
    /// failed or no probe has run yet.
    pub latency_ms: Option<u64>,
    /// Time of the most recent probe, as formatted by `crate::logbuffer::format_ts`;
    /// empty until the first probe has run.
    pub last_check: String,
    /// Name of the forward zone this upstream was configured under.
    pub zone: String,
}

/// Shared handle to the upstream status list; `upstream_health_loop` updates it in place.
pub type SharedUpstreams = Arc<RwLock<Vec<UpstreamStatus>>>;
/// Builds the initial, not-yet-probed upstream list from `cfg.forward_zones`.
///
/// One entry is created per forward address, zone by zone in configuration order. Only the
/// host part of each address is kept: unbound's `@port` and `#tls-auth-name` suffixes are
/// removed, as is surrounding whitespace. Every entry starts with `healthy: false`, no
/// latency and an empty `last_check`.
pub fn init_upstreams(cfg: &UnboundConfig) -> SharedUpstreams {
    let mut statuses = Vec::new();
    for fz in &cfg.forward_zones {
        for addr in &fz.addrs {
            // e.g. "1.1.1.1@853#cloudflare-dns.com" -> "1.1.1.1". A leftover "#name" would
            // make the address unparsable for the health probe.
            let host = addr
                .split(|c: char| c == '@' || c == '#')
                .next()
                .unwrap_or_default()
                .trim();
            statuses.push(UpstreamStatus {
                addr: host.to_string(),
                healthy: false,
                latency_ms: None,
                last_check: String::new(),
                zone: fz.name.clone(),
            });
        }
    }
    Arc::new(RwLock::new(statuses))
}
pub async fn upstream_health_loop(upstreams: SharedUpstreams) {
let mut interval = time::interval(Duration::from_secs(PROBE_INTERVAL_SECS));
interval.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
let addrs: Vec<String> = {
upstreams.read().expect("upstreams: RwLock poisoned in health task").iter().map(|s| s.addr.clone()).collect()
};
let mut results: Vec<(bool, Option<u64>)> = Vec::with_capacity(addrs.len());
for addr in &addrs {
results.push(probe_upstream(addr));
}
let mut statuses = upstreams.write().expect("upstreams: RwLock poisoned in health task");
let now = crate::logbuffer::format_ts(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
);
for (i, (healthy, latency_ms)) in results.into_iter().enumerate() {
if let Some(s) = statuses.get_mut(i) {
if !healthy {
warn!(upstream = %s.addr, "Upstream DNS health check failed");
}
s.healthy = healthy;
s.latency_ms = latency_ms;
s.last_check = now.clone();
}
}
}
}
/// Sends a single DNS query to `addr` over UDP and reports whether the upstream answered.
///
/// `addr` may take any of these forms:
/// - a bare IP (`192.0.2.1`, `2001:db8::1`);
/// - a bracketed IPv6 address (`[2001:db8::1]`);
/// - a socket address (`192.0.2.1:53`, `[2001:db8::1]:53`);
/// - unbound's `ip@port` form.
///
/// Any of these may be followed by a `#tls-auth-name` suffix, which is ignored. The port
/// defaults to 53. Host names are not resolved.
///
/// Returns `(true, Some(round_trip_ms))` when a DNS response carrying the probe's transaction
/// ID arrives from `addr` within `PROBE_TIMEOUT_MS`. Any response code counts as healthy.
///
/// Returns `(false, None)` when:
/// - `addr` cannot be parsed;
/// - the socket cannot be set up;
/// - the send fails;
/// - nothing arrives in time;
/// - the reply is not a response to the probe.
///
/// This is blocking I/O that can take up to `PROBE_TIMEOUT_MS`; keep it off async worker threads.
fn probe_upstream(addr: &str) -> (bool, Option<u64>) {
    let Some(target) = parse_probe_target(addr) else {
        return (false, None);
    };
    let bind: SocketAddr = match target.ip() {
        IpAddr::V4(_) => BIND_V4,
        IpAddr::V6(_) => BIND_V6,
    };
    let Ok(sock) = UdpSocket::bind(bind) else {
        return (false, None);
    };
    // Without a read timeout `recv` could block forever, so a failure here is fatal for the probe.
    if sock
        .set_read_timeout(Some(Duration::from_millis(PROBE_TIMEOUT_MS)))
        .is_err()
    {
        return (false, None);
    }
    // A connected UDP socket only receives datagrams from `target`, so stray or spoofed
    // packets from other peers are ignored. On some platforms it also reports ICMP
    // port-unreachable as an immediate error instead of waiting for the timeout.
    if sock.connect(target).is_err() {
        return (false, None);
    }
    let t0 = Instant::now();
    if sock.send(&DNS_PROBE_PACKET).is_err() {
        return (false, None);
    }
    let mut buf = [0u8; 512];
    match sock.recv(&mut buf) {
        // Elapsed time is bounded by PROBE_TIMEOUT_MS, so the u128 -> u64 cast cannot truncate.
        Ok(n) if is_probe_response(&buf[..n]) => (true, Some(t0.elapsed().as_millis() as u64)),
        _ => (false, None),
    }
}

/// Parses an upstream address (see `probe_upstream` for accepted forms) into a socket address.
fn parse_probe_target(addr: &str) -> Option<SocketAddr> {
    let addr = addr.trim();
    // Drop unbound's optional "#tls-auth-name" suffix.
    let addr = addr.split_once('#').map_or(addr, |(host, _)| host);
    // Full socket address: "192.0.2.1:53" or "[2001:db8::1]:53".
    if let Ok(sock_addr) = addr.parse::<SocketAddr>() {
        return Some(sock_addr);
    }
    // Unbound's "ip@port" form; otherwise fall back to the standard DNS port.
    let (host, port) = match addr.split_once('@') {
        Some((host, port)) => (host, port.parse::<u16>().ok()?),
        None => (addr, 53),
    };
    // Accept bracketed IPv6 without a port, e.g. "[2001:db8::1]" or "[2001:db8::1]@5353".
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    host.parse::<IpAddr>()
        .ok()
        .map(|ip| SocketAddr::new(ip, port))
}

/// True when `reply` is a complete DNS header with the QR (response) bit set and the
/// probe's transaction ID.
fn is_probe_response(reply: &[u8]) -> bool {
    reply.len() >= 12 && reply[..2] == DNS_PROBE_PACKET[..2] && reply[2] & 0x80 != 0
}