use crate::error::{OverlayError, Result};
#[cfg(feature = "nat")]
use crate::nat::ConnectionType;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::process::Command;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Default pause between health-check passes in `OverlayHealthChecker::run`; used by
/// `OverlayHealthChecker::default_for_interface`.
pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(30);
/// Default maximum age, in seconds, of a peer's last handshake for the peer to count as
/// healthy (`OverlayHealthChecker::new` uses it; `with_handshake_timeout` overrides it).
pub const HANDSHAKE_TIMEOUT_SECS: u64 = 180;
/// Time limit, in seconds, for a single `ping_peer` or `tcp_check` probe.
pub const PING_TIMEOUT_SECS: u64 = 5;
/// Health snapshot of one overlay peer, as built by `OverlayHealthChecker::check_all`
/// and cached by `OverlayHealthChecker::run`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeerStatus {
/// Peer public key. `check_all` fills this with the base64 form of the UAPI hex key
/// (or the raw string when it is not valid hex).
pub public_key: String,
/// First single-host allowed IP of the peer (`/32` IPv4 or `/128` IPv6), if any.
pub overlay_ip: Option<IpAddr>,
/// Whether the last handshake was younger than the checker's handshake timeout.
pub healthy: bool,
/// Seconds between the last handshake and the check; `None` if the peer never
/// completed a handshake.
pub last_handshake_secs: Option<u64>,
/// Ping round-trip in milliseconds. `check_all` always leaves this `None`.
pub last_ping_ms: Option<u64>,
/// `check_all` sets this to `1` for an unhealthy peer and `0` otherwise.
pub failure_count: u32,
/// Unix timestamp (seconds) of the check that produced this status.
pub last_check: u64,
/// NAT connection type (only with the `nat` feature). `check_all` stores
/// `ConnectionType::default()`; the serde default applies when the field is missing.
#[cfg(feature = "nat")]
#[serde(default)]
pub connection_type: ConnectionType,
}
/// Result of one `OverlayHealthChecker::check_all` pass over an interface.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlayHealth {
/// Interface that was checked.
pub interface: String,
/// Number of peers reported by the interface.
pub total_peers: usize,
/// Peers whose `healthy` flag is set.
pub healthy_peers: usize,
/// `total_peers - healthy_peers`.
pub unhealthy_peers: usize,
/// Per-peer detail, in the order the UAPI response listed the peers.
pub peers: Vec<PeerStatus>,
/// Unix timestamp (seconds) of the pass.
pub last_check: u64,
}
/// Raw per-peer statistics parsed from a WireGuard UAPI `get=1` response.
#[derive(Debug, Clone)]
pub struct WgPeerStats {
/// Base64 form of the UAPI hex public key (raw string if it was not valid hex).
pub public_key: String,
/// `endpoint=` value; `None` when absent or reported as `(none)`.
pub endpoint: Option<String>,
/// `allowed_ip=` values (CIDR strings), in response order.
pub allowed_ips: Vec<String>,
/// `last_handshake_time_sec=` as Unix seconds; `None` when absent, zero or unparsable.
pub last_handshake_time: Option<u64>,
/// `rx_bytes=` value; `0` when absent or unparsable.
pub transfer_rx: u64,
/// `tx_bytes=` value; `0` when absent or unparsable.
pub transfer_tx: u64,
}
/// Periodically inspects a WireGuard interface via its UAPI socket and tracks per-peer
/// health based on handshake age.
pub struct OverlayHealthChecker {
/// Interface name; the UAPI socket is `/var/run/wireguard/<interface>.sock`.
interface: String,
/// Pause between passes in `run`.
check_interval: Duration,
/// Maximum handshake age (whole seconds are compared) for a peer to be healthy.
handshake_timeout: Duration,
/// Last status seen per peer public key; written by `run`, read by `get_cached_status`.
peer_status: RwLock<HashMap<String, PeerStatus>>,
}
impl OverlayHealthChecker {
    /// Upper bound on one UAPI `get` exchange, so a wedged daemon cannot stall a pass.
    const UAPI_TIMEOUT: Duration = Duration::from_secs(5);

    /// Creates a checker for `interface` that runs a pass every `check_interval`.
    ///
    /// The handshake timeout starts at [`HANDSHAKE_TIMEOUT_SECS`]; change it with
    /// [`Self::with_handshake_timeout`].
    #[must_use]
    pub fn new(interface: &str, check_interval: Duration) -> Self {
        Self {
            interface: interface.to_string(),
            check_interval,
            handshake_timeout: Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
            peer_status: RwLock::new(HashMap::new()),
        }
    }

    /// Creates a checker for `interface` using [`DEFAULT_CHECK_INTERVAL`].
    #[must_use]
    pub fn default_for_interface(interface: &str) -> Self {
        Self::new(interface, DEFAULT_CHECK_INTERVAL)
    }

    /// Returns the checker with its handshake timeout replaced by `timeout`.
    ///
    /// A peer is healthy only while its last handshake is strictly younger than this
    /// timeout; ages are compared in whole seconds.
    #[must_use]
    pub fn with_handshake_timeout(mut self, timeout: Duration) -> Self {
        self.handshake_timeout = timeout;
        self
    }

    /// Runs health-check passes forever, sleeping `check_interval` after each one.
    ///
    /// `on_status_change(public_key, healthy)` fires for every peer seen for the first
    /// time and for every peer whose `healthy` flag differs from the cached value.
    /// Callbacks run after the cache lock is released, so they never block
    /// [`Self::get_cached_status`]. A failed pass is logged and retried on the next tick.
    /// Peers that disappear from the interface keep their last cached status.
    ///
    /// The returned future never completes on its own; cancel it to stop the loop.
    pub async fn run<F>(&self, mut on_status_change: F)
    where
        F: FnMut(&str, bool) + Send + 'static,
    {
        info!(
            interface = %self.interface,
            interval_secs = self.check_interval.as_secs(),
            "Starting health check loop"
        );
        loop {
            match self.check_all().await {
                Ok(health) => {
                    for peer in self.update_cache(&health.peers).await {
                        on_status_change(&peer.public_key, peer.healthy);
                    }
                }
                Err(e) => {
                    warn!(error = %e, "Health check failed");
                }
            }
            tokio::time::sleep(self.check_interval).await;
        }
    }

    /// Stores `peers` in the cache under a single write lock and returns those whose
    /// health is new or changed compared with the previously cached entry.
    async fn update_cache<'a>(&self, peers: &'a [PeerStatus]) -> Vec<&'a PeerStatus> {
        let mut cache = self.peer_status.write().await;
        let mut changed = Vec::new();
        for peer in peers {
            let previous = cache.insert(peer.public_key.clone(), peer.clone());
            if previous.is_none_or(|prev| prev.healthy != peer.healthy) {
                changed.push(peer);
            }
        }
        changed
    }

    /// Performs one health-check pass over every peer on the interface.
    ///
    /// Statistics come from the interface's UAPI socket; a missing socket yields an empty
    /// report. One timestamp is used for the health verdict, `last_handshake_secs` and
    /// `last_check`, so they always agree. `last_ping_ms` is left `None` and
    /// `failure_count` is `1` for unhealthy peers, `0` otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`OverlayError::TransportCommand`] if the UAPI socket exists but cannot be
    /// queried, times out, or reports a non-zero `errno`.
    #[allow(clippy::similar_names)] // pedantic lint on local binding names; kept from before
    pub async fn check_all(&self) -> Result<OverlayHealth> {
        let now = current_timestamp();
        let peers: Vec<PeerStatus> = self
            .get_wg_stats()
            .await?
            .into_iter()
            .map(|stat| self.peer_status_from_stats(stat, now))
            .collect();
        let total_peers = peers.len();
        let healthy_peers = peers.iter().filter(|peer| peer.healthy).count();
        Ok(OverlayHealth {
            interface: self.interface.clone(),
            total_peers,
            healthy_peers,
            unhealthy_peers: total_peers - healthy_peers,
            peers,
            last_check: now,
        })
    }

    /// Converts raw UAPI statistics for one peer into a [`PeerStatus`] evaluated at `now`.
    fn peer_status_from_stats(&self, stat: WgPeerStats, now: u64) -> PeerStatus {
        let healthy = self.is_peer_healthy_at(&stat, now);
        PeerStatus {
            overlay_ip: Self::overlay_ip_from_allowed_ips(&stat.allowed_ips),
            last_handshake_secs: stat.last_handshake_time.map(|t| now.saturating_sub(t)),
            public_key: stat.public_key,
            healthy,
            last_ping_ms: None,
            failure_count: u32::from(!healthy),
            last_check: now,
            #[cfg(feature = "nat")]
            connection_type: ConnectionType::default(),
        }
    }

    /// Picks the peer's overlay address: the first allowed IP that is a single-host route
    /// (`a.b.c.d/32` for IPv4, `x::y/128` for IPv6). Other prefixes are ignored.
    fn overlay_ip_from_allowed_ips(allowed_ips: &[String]) -> Option<IpAddr> {
        allowed_ips.iter().find_map(|cidr| {
            if let Some(host) = cidr.strip_suffix("/32") {
                host.parse::<IpAddr>().ok().filter(IpAddr::is_ipv4)
            } else if let Some(host) = cidr.strip_suffix("/128") {
                host.parse::<IpAddr>().ok().filter(IpAddr::is_ipv6)
            } else {
                None
            }
        })
    }

    /// Health verdict for `stats` against the current wall clock (test convenience).
    #[cfg(test)]
    fn is_peer_healthy(&self, stats: &WgPeerStats) -> bool {
        self.is_peer_healthy_at(stats, current_timestamp())
    }

    /// Returns `true` if the peer has handshaken and `now - last_handshake` is strictly
    /// less than the handshake timeout (whole seconds). Peers that never handshook are
    /// unhealthy.
    fn is_peer_healthy_at(&self, stats: &WgPeerStats, now: u64) -> bool {
        let timeout_secs = self.handshake_timeout.as_secs();
        stats
            .last_handshake_time
            .is_some_and(|t| now.saturating_sub(t) < timeout_secs)
    }

    /// Sends one ICMP echo to `overlay_ip` with the system `ping` binary and returns the
    /// wall-clock time the command took.
    ///
    /// The duration includes process spawn and exit, so it overstates the network
    /// round-trip. The probe is bounded by [`PING_TIMEOUT_SECS`]; when that limit fires the
    /// `ping` child process is killed instead of being left running.
    ///
    /// # Errors
    ///
    /// Returns [`OverlayError::PeerUnreachable`] if `ping` exits unsuccessfully, cannot be
    /// spawned, or does not finish in time.
    pub async fn ping_peer(&self, overlay_ip: IpAddr) -> Result<Duration> {
        let start = Instant::now();
        let mut cmd = match overlay_ip {
            IpAddr::V4(_) => Command::new("ping"),
            #[cfg(target_os = "macos")]
            IpAddr::V6(_) => Command::new("ping6"),
            #[cfg(not(target_os = "macos"))]
            IpAddr::V6(_) => {
                let mut ping = Command::new("ping");
                ping.arg("-6");
                ping
            }
        };
        // Per-reply wait handed to ping itself: macOS `ping -W` takes milliseconds, Linux
        // `ping -W` takes seconds. macOS `ping6`'s `-W` is a value-less switch (not a wait
        // time), so it is omitted there and the outer timeout alone bounds the probe.
        #[cfg(target_os = "macos")]
        let wait_arg = overlay_ip
            .is_ipv4()
            .then(|| (PING_TIMEOUT_SECS * 1000).to_string());
        #[cfg(not(target_os = "macos"))]
        let wait_arg = Some(PING_TIMEOUT_SECS.to_string());

        cmd.args(["-c", "1"]);
        if let Some(wait) = &wait_arg {
            cmd.args(["-W", wait.as_str()]);
        }
        cmd.arg(overlay_ip.to_string());
        // If the timeout below drops the `output()` future, the child must die with it.
        cmd.kill_on_drop(true);

        let output =
            tokio::time::timeout(Duration::from_secs(PING_TIMEOUT_SECS), cmd.output()).await;
        match output {
            Ok(Ok(result)) if result.status.success() => Ok(start.elapsed()),
            Ok(Ok(_)) => Err(OverlayError::PeerUnreachable {
                ip: overlay_ip,
                reason: "ping failed".to_string(),
            }),
            Ok(Err(e)) => Err(OverlayError::PeerUnreachable {
                ip: overlay_ip,
                reason: e.to_string(),
            }),
            Err(_) => Err(OverlayError::PeerUnreachable {
                ip: overlay_ip,
                reason: "timeout".to_string(),
            }),
        }
    }

    /// Opens (and immediately drops) a TCP connection to `overlay_ip:port` and returns how
    /// long the connect took.
    ///
    /// # Errors
    ///
    /// Returns [`OverlayError::PeerUnreachable`] if the connect fails or does not complete
    /// within [`PING_TIMEOUT_SECS`].
    pub async fn tcp_check(&self, overlay_ip: IpAddr, port: u16) -> Result<Duration> {
        let start = Instant::now();
        let addr = SocketAddr::new(overlay_ip, port);
        let result = tokio::time::timeout(
            Duration::from_secs(PING_TIMEOUT_SECS),
            tokio::net::TcpStream::connect(addr),
        )
        .await;
        match result {
            Ok(Ok(_stream)) => Ok(start.elapsed()),
            Ok(Err(e)) => Err(OverlayError::PeerUnreachable {
                ip: overlay_ip,
                reason: e.to_string(),
            }),
            Err(_) => Err(OverlayError::PeerUnreachable {
                ip: overlay_ip,
                reason: "timeout".to_string(),
            }),
        }
    }

    /// Fetches per-peer statistics from `/var/run/wireguard/<interface>.sock`.
    ///
    /// A missing socket (`NotFound`) or a refused connection (`ConnectionRefused`) is
    /// treated as "interface has no peers" and yields an empty list.
    ///
    /// # Errors
    ///
    /// Returns [`OverlayError::TransportCommand`] for any other I/O failure, when the
    /// exchange exceeds [`Self::UAPI_TIMEOUT`], or when the response carries a non-zero
    /// `errno=` line.
    async fn get_wg_stats(&self) -> Result<Vec<WgPeerStats>> {
        let sock_path = format!("/var/run/wireguard/{}.sock", self.interface);
        let response =
            match tokio::time::timeout(Self::UAPI_TIMEOUT, uapi_get_raw(&sock_path)).await {
                Ok(Ok(resp)) => resp,
                Ok(Err(e)) => {
                    // Classify by error kind rather than by (platform-dependent) message text.
                    let socket_absent =
                        e.downcast_ref::<std::io::Error>().is_some_and(|io_err| {
                            matches!(
                                io_err.kind(),
                                std::io::ErrorKind::NotFound
                                    | std::io::ErrorKind::ConnectionRefused
                            )
                        });
                    if socket_absent {
                        return Ok(Vec::new());
                    }
                    return Err(OverlayError::TransportCommand(e.to_string()));
                }
                Err(_) => {
                    return Err(OverlayError::TransportCommand(format!(
                        "UAPI get on {sock_path} timed out after {}s",
                        Self::UAPI_TIMEOUT.as_secs()
                    )));
                }
            };
        // The daemon reports the outcome as `errno=<n>`; anything but 0 means the get failed
        // and the peer list must not be trusted (it would otherwise read as "no peers").
        if let Some(errno) = response
            .lines()
            .filter_map(|line| line.trim().strip_prefix("errno="))
            .find(|errno| errno.trim() != "0")
        {
            return Err(OverlayError::TransportCommand(format!(
                "UAPI get on {sock_path} failed with errno={errno}"
            )));
        }
        let peers = parse_uapi_get_response(&response);
        debug!(
            interface = %self.interface,
            peer_count = peers.len(),
            "Retrieved overlay peer stats via UAPI"
        );
        Ok(peers)
    }

    /// Returns a copy of the last status cached by [`Self::run`] for `public_key`, if any.
    pub async fn get_cached_status(&self, public_key: &str) -> Option<PeerStatus> {
        let cache = self.peer_status.read().await;
        cache.get(public_key).cloned()
    }

    /// Pause between passes in [`Self::run`].
    #[must_use]
    pub fn check_interval(&self) -> Duration {
        self.check_interval
    }

    /// Name of the monitored interface.
    #[must_use]
    pub fn interface(&self) -> &str {
        &self.interface
    }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0` instead of an error.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// Sends a UAPI `get=1` request over the Unix socket at `sock_path` and returns the raw
/// response text.
///
/// The write half is shut down right after the request, and the response is read until
/// the daemon closes the connection. Every failure (connect, write, shutdown, read or
/// non-UTF-8 data) is an I/O error, boxed.
async fn uapi_get_raw(sock_path: &str) -> std::result::Result<String, Box<dyn std::error::Error>> {
    const GET_REQUEST: &[u8] = b"get=1\n\n";

    let mut conn = tokio::net::UnixStream::connect(sock_path).await?;
    conn.write_all(GET_REQUEST).await?;
    conn.shutdown().await?;

    let mut body = String::new();
    let _bytes_read = conn.read_to_string(&mut body).await?;
    Ok(body)
}
/// Converts a hex-encoded WireGuard key (the form UAPI reports) to standard base64.
///
/// Input that is not valid hex is returned unchanged so the caller still has some
/// identifier for the peer.
fn hex_key_to_base64(hex_key: &str) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine as _};
    hex::decode(hex_key).map_or_else(|_| hex_key.to_owned(), |raw| STANDARD.encode(raw))
}
/// Parses the body of a WireGuard UAPI `get=1` response into per-peer statistics.
///
/// Each `public_key=` line starts a new peer; the per-peer keys that follow
/// (`endpoint`, `allowed_ip`, `last_handshake_time_sec`, `rx_bytes`, `tx_bytes`) apply to
/// the most recently started peer. Keys seen before the first peer (device-level keys such
/// as `private_key`), unknown keys, `errno=` lines, blank lines and lines without `=` are
/// ignored. `endpoint=(none)` and a zero handshake time leave the field unset, and
/// unparsable byte counters read as `0`.
fn parse_uapi_get_response(response: &str) -> Vec<WgPeerStats> {
    let mut parsed: Vec<WgPeerStats> = Vec::new();
    for raw in response.lines() {
        let entry = raw.trim();
        if entry.is_empty() || entry.starts_with("errno=") {
            continue;
        }
        let Some((key, value)) = entry.split_once('=') else {
            continue;
        };
        if key == "public_key" {
            parsed.push(WgPeerStats {
                public_key: hex_key_to_base64(value),
                endpoint: None,
                allowed_ips: Vec::new(),
                last_handshake_time: None,
                transfer_rx: 0,
                transfer_tx: 0,
            });
            continue;
        }
        // Every remaining key is per-peer; skip it until the first peer has started.
        let Some(peer) = parsed.last_mut() else {
            continue;
        };
        match key {
            "endpoint" if value != "(none)" => peer.endpoint = Some(value.to_string()),
            "allowed_ip" => peer.allowed_ips.push(value.to_string()),
            "last_handshake_time_sec" => {
                if let Some(secs) = value.parse::<u64>().ok().filter(|&s| s > 0) {
                    peer.last_handshake_time = Some(secs);
                }
            }
            "rx_bytes" => peer.transfer_rx = value.parse().unwrap_or(0),
            "tx_bytes" => peer.transfer_tx = value.parse().unwrap_or(0),
            _ => {}
        }
    }
    parsed
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Peer stats with only the handshake time set; everything else is zero/empty.
    fn stats_with_handshake(last_handshake_time: Option<u64>) -> WgPeerStats {
        WgPeerStats {
            public_key: "key".to_string(),
            endpoint: None,
            allowed_ips: vec![],
            last_handshake_time,
            transfer_rx: 0,
            transfer_tx: 0,
        }
    }

    /// Checker with an explicit handshake timeout, so the health tests do not silently
    /// depend on the value of `HANDSHAKE_TIMEOUT_SECS`.
    fn checker_with_timeout(timeout_secs: u64) -> OverlayHealthChecker {
        OverlayHealthChecker::new("wg0", Duration::from_secs(30))
            .with_handshake_timeout(Duration::from_secs(timeout_secs))
    }

    #[test]
    fn test_peer_status_serialization_v4() {
        let status = PeerStatus {
            public_key: "test_key".to_string(),
            overlay_ip: Some("10.200.0.5".parse::<IpAddr>().unwrap()),
            healthy: true,
            last_handshake_secs: Some(10),
            last_ping_ms: Some(5),
            failure_count: 0,
            last_check: 1_234_567_890,
            #[cfg(feature = "nat")]
            connection_type: ConnectionType::default(),
        };
        let json = serde_json::to_string(&status).unwrap();
        let deserialized: PeerStatus = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.public_key, "test_key");
        assert!(deserialized.healthy);
        assert_eq!(
            deserialized.overlay_ip,
            Some("10.200.0.5".parse::<IpAddr>().unwrap())
        );
    }

    #[test]
    fn test_peer_status_serialization_v6() {
        let status = PeerStatus {
            public_key: "test_key_v6".to_string(),
            overlay_ip: Some("fd00::5".parse::<IpAddr>().unwrap()),
            healthy: true,
            last_handshake_secs: Some(10),
            last_ping_ms: Some(5),
            failure_count: 0,
            last_check: 1_234_567_890,
            #[cfg(feature = "nat")]
            connection_type: ConnectionType::default(),
        };
        let json = serde_json::to_string(&status).unwrap();
        let deserialized: PeerStatus = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.public_key, "test_key_v6");
        assert!(deserialized.healthy);
        assert_eq!(
            deserialized.overlay_ip,
            Some("fd00::5".parse::<IpAddr>().unwrap())
        );
    }

    #[test]
    fn test_overlay_health_serialization() {
        let health = OverlayHealth {
            interface: "zl-overlay0".to_string(),
            total_peers: 2,
            healthy_peers: 1,
            unhealthy_peers: 1,
            peers: vec![],
            last_check: 1_234_567_890,
        };
        let json = serde_json::to_string_pretty(&health).unwrap();
        assert!(json.contains("zl-overlay0"));
    }

    #[test]
    fn test_health_checker_creation() {
        let checker = OverlayHealthChecker::new("wg0", Duration::from_secs(60));
        assert_eq!(checker.interface(), "wg0");
        assert_eq!(checker.check_interval(), Duration::from_secs(60));
        assert_eq!(
            checker.handshake_timeout,
            Duration::from_secs(HANDSHAKE_TIMEOUT_SECS)
        );
    }

    #[test]
    fn test_default_for_interface_uses_default_interval() {
        let checker = OverlayHealthChecker::default_for_interface("wg1");
        assert_eq!(checker.interface(), "wg1");
        assert_eq!(checker.check_interval(), DEFAULT_CHECK_INTERVAL);
    }

    #[test]
    fn test_is_peer_healthy_recent_handshake() {
        let checker = checker_with_timeout(120);
        let now = current_timestamp();
        assert!(checker.is_peer_healthy(&stats_with_handshake(Some(now.saturating_sub(60)))));
    }

    #[test]
    fn test_is_peer_healthy_stale_handshake() {
        let checker = checker_with_timeout(120);
        let now = current_timestamp();
        assert!(!checker.is_peer_healthy(&stats_with_handshake(Some(now.saturating_sub(300)))));
    }

    #[test]
    fn test_is_peer_healthy_exactly_at_timeout_is_stale() {
        // The age comparison is strict, so a handshake exactly `timeout` old is unhealthy.
        // The checker reads the clock after `now`, so the observed age can only be larger.
        let checker = checker_with_timeout(120);
        let now = current_timestamp();
        assert!(!checker.is_peer_healthy(&stats_with_handshake(Some(now.saturating_sub(120)))));
    }

    #[test]
    fn test_is_peer_healthy_no_handshake() {
        let checker = checker_with_timeout(120);
        assert!(!checker.is_peer_healthy(&stats_with_handshake(None)));
    }

    #[test]
    fn test_parse_uapi_get_response() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};
        let key_bytes = [0xABu8; 32];
        let hex_key = hex::encode(key_bytes);
        let expected_b64 = STANDARD.encode(key_bytes);
        let response = format!(
            "private_key=0000000000000000000000000000000000000000000000000000000000000000\n\
             listen_port=51820\n\
             public_key={hex_key}\n\
             endpoint=192.168.1.5:51820\n\
             allowed_ip=10.200.0.2/32\n\
             last_handshake_time_sec=1700000000\n\
             last_handshake_time_nsec=0\n\
             rx_bytes=12345\n\
             tx_bytes=67890\n\
             persistent_keepalive_interval=25\n\
             errno=0\n"
        );
        let peers = parse_uapi_get_response(&response);
        assert_eq!(peers.len(), 1);
        let peer = &peers[0];
        assert_eq!(peer.public_key, expected_b64);
        assert_eq!(peer.endpoint, Some("192.168.1.5:51820".to_string()));
        assert_eq!(peer.allowed_ips, vec!["10.200.0.2/32".to_string()]);
        assert_eq!(peer.last_handshake_time, Some(1_700_000_000));
        assert_eq!(peer.transfer_rx, 12345);
        assert_eq!(peer.transfer_tx, 67890);
    }

    #[test]
    fn test_parse_uapi_get_response_multiple_peers() {
        let key1 = hex::encode([0x01u8; 32]);
        let key2 = hex::encode([0x02u8; 32]);
        let response = format!(
            "private_key=0000000000000000000000000000000000000000000000000000000000000000\n\
             listen_port=51820\n\
             public_key={key1}\n\
             endpoint=10.0.0.1:51820\n\
             allowed_ip=10.200.0.2/32\n\
             rx_bytes=100\n\
             tx_bytes=200\n\
             public_key={key2}\n\
             endpoint=10.0.0.2:51821\n\
             allowed_ip=10.200.0.3/32\n\
             allowed_ip=10.200.1.0/24\n\
             rx_bytes=300\n\
             tx_bytes=400\n\
             errno=0\n"
        );
        let peers = parse_uapi_get_response(&response);
        assert_eq!(peers.len(), 2);
        assert_eq!(peers[0].transfer_rx, 100);
        assert_eq!(peers[1].transfer_rx, 300);
        assert_eq!(peers[1].allowed_ips.len(), 2);
    }

    #[test]
    fn test_parse_uapi_get_response_unset_fields() {
        let key = hex::encode([0x03u8; 32]);
        let response = format!(
            "public_key={key}\n\
             endpoint=(none)\n\
             last_handshake_time_sec=0\n\
             rx_bytes=not-a-number\n\
             errno=0\n"
        );
        let peers = parse_uapi_get_response(&response);
        assert_eq!(peers.len(), 1);
        let peer = &peers[0];
        assert_eq!(peer.endpoint, None);
        assert_eq!(peer.last_handshake_time, None);
        assert_eq!(peer.transfer_rx, 0);
        assert!(peer.allowed_ips.is_empty());
    }

    #[test]
    fn test_parse_uapi_get_response_empty() {
        let response = "private_key=0000\nlisten_port=51820\nerrno=0\n";
        let peers = parse_uapi_get_response(response);
        assert!(peers.is_empty());
    }

    #[test]
    fn test_hex_key_to_base64_roundtrip() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};
        let key_bytes = [0xCDu8; 32];
        let hex_key = hex::encode(key_bytes);
        let b64 = hex_key_to_base64(&hex_key);
        let expected = STANDARD.encode(key_bytes);
        assert_eq!(b64, expected);
    }

    #[test]
    fn test_hex_key_to_base64_non_hex_passthrough() {
        assert_eq!(hex_key_to_base64("not-hex"), "not-hex");
    }
}