1use crate::error::{OverlayError, Result};
8#[cfg(feature = "nat")]
9use crate::nat::ConnectionType;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::{IpAddr, SocketAddr};
13use std::time::{Duration, Instant};
14#[cfg(unix)]
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::process::Command;
17use tokio::sync::RwLock;
18use tracing::{debug, info, warn};
19
/// Polling interval used by [`OverlayHealthChecker::default_for_interface`].
pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(30);

/// Default handshake timeout, in seconds, installed by [`OverlayHealthChecker::new`].
///
/// A peer whose last handshake is at least this old is reported as unhealthy.
pub const HANDSHAKE_TIMEOUT_SECS: u64 = 180;

/// Timeout, in seconds, applied to both the ICMP ping probe and the TCP connect probe.
pub const PING_TIMEOUT_SECS: u64 = 5;
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct PeerStatus {
32 pub public_key: String,
34
35 pub overlay_ip: Option<IpAddr>,
37
38 pub healthy: bool,
40
41 pub last_handshake_secs: Option<u64>,
43
44 pub last_ping_ms: Option<u64>,
46
47 pub failure_count: u32,
49
50 pub last_check: u64,
52
53 #[cfg(feature = "nat")]
55 #[serde(default)]
56 pub connection_type: ConnectionType,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct OverlayHealth {
62 pub interface: String,
64
65 pub total_peers: usize,
67
68 pub healthy_peers: usize,
70
71 pub unhealthy_peers: usize,
73
74 pub peers: Vec<PeerStatus>,
76
77 pub last_check: u64,
79}
80
/// Raw per-peer statistics as parsed from a UAPI `get` response.
///
/// Derives `PartialEq`, `Eq` and `Default` in addition to `Debug`/`Clone` so
/// values can be compared directly and built with struct-update syntax.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct WgPeerStats {
    /// Peer public key (base64 when the UAPI value was valid hex).
    pub public_key: String,
    /// Peer endpoint as reported, or `None` when absent or `(none)`.
    pub endpoint: Option<String>,
    /// Allowed-IP entries in CIDR notation, in the order they were reported.
    pub allowed_ips: Vec<String>,
    /// Unix timestamp (seconds) of the last handshake; `None` when the peer
    /// has never completed one (a reported value of `0` is treated as "never").
    pub last_handshake_time: Option<u64>,
    /// Bytes received from the peer.
    pub transfer_rx: u64,
    /// Bytes sent to the peer.
    pub transfer_tx: u64,
}
91
92pub struct OverlayHealthChecker {
97 interface: String,
99
100 check_interval: Duration,
102
103 handshake_timeout: Duration,
105
106 peer_status: RwLock<HashMap<String, PeerStatus>>,
108}
109
110impl OverlayHealthChecker {
111 #[must_use]
113 pub fn new(interface: &str, check_interval: Duration) -> Self {
114 Self {
115 interface: interface.to_string(),
116 check_interval,
117 handshake_timeout: Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
118 peer_status: RwLock::new(HashMap::new()),
119 }
120 }
121
122 #[must_use]
124 pub fn default_for_interface(interface: &str) -> Self {
125 Self::new(interface, DEFAULT_CHECK_INTERVAL)
126 }
127
128 #[must_use]
130 pub fn with_handshake_timeout(mut self, timeout: Duration) -> Self {
131 self.handshake_timeout = timeout;
132 self
133 }
134
135 pub async fn run<F>(&self, mut on_status_change: F)
139 where
140 F: FnMut(&str, bool) + Send + 'static,
141 {
142 info!(
143 interface = %self.interface,
144 interval_secs = self.check_interval.as_secs(),
145 "Starting health check loop"
146 );
147
148 loop {
149 match self.check_all().await {
150 Ok(health) => {
151 for peer in &health.peers {
152 let mut cache = self.peer_status.write().await;
154 let changed = cache
155 .get(&peer.public_key)
156 .is_none_or(|prev| prev.healthy != peer.healthy);
157
158 if changed {
159 on_status_change(&peer.public_key, peer.healthy);
160 }
161
162 cache.insert(peer.public_key.clone(), peer.clone());
163 }
164 }
165 Err(e) => {
166 warn!(error = %e, "Health check failed");
167 }
168 }
169
170 tokio::time::sleep(self.check_interval).await;
171 }
172 }
173
174 #[allow(clippy::similar_names)]
180 pub async fn check_all(&self) -> Result<OverlayHealth> {
181 let now = current_timestamp();
182 let stats = self.get_wg_stats().await?;
183
184 let mut peers = Vec::with_capacity(stats.len());
185 let mut healthy_count = 0;
186
187 for stat in stats {
188 let healthy = self.is_peer_healthy(&stat);
189
190 if healthy {
191 healthy_count += 1;
192 }
193
194 let overlay_ip: Option<IpAddr> = stat.allowed_ips.iter().find_map(|ip_str| {
196 if ip_str.ends_with("/32") {
197 ip_str
198 .trim_end_matches("/32")
199 .parse::<IpAddr>()
200 .ok()
201 .filter(IpAddr::is_ipv4)
202 } else if ip_str.ends_with("/128") {
203 ip_str
204 .trim_end_matches("/128")
205 .parse::<IpAddr>()
206 .ok()
207 .filter(IpAddr::is_ipv6)
208 } else {
209 None
210 }
211 });
212
213 let status = PeerStatus {
214 public_key: stat.public_key,
215 overlay_ip,
216 healthy,
217 last_handshake_secs: stat.last_handshake_time.map(|t| now.saturating_sub(t)),
218 last_ping_ms: None, failure_count: u32::from(!healthy),
220 last_check: now,
221 #[cfg(feature = "nat")]
222 connection_type: ConnectionType::default(),
223 };
224
225 peers.push(status);
226 }
227
228 let total = peers.len();
229 Ok(OverlayHealth {
230 interface: self.interface.clone(),
231 total_peers: total,
232 healthy_peers: healthy_count,
233 unhealthy_peers: total - healthy_count,
234 peers,
235 last_check: now,
236 })
237 }
238
239 fn is_peer_healthy(&self, stats: &WgPeerStats) -> bool {
241 let now = current_timestamp();
242 let timeout_secs = self.handshake_timeout.as_secs();
243
244 stats
245 .last_handshake_time
246 .is_some_and(|t| now.saturating_sub(t) < timeout_secs)
247 }
248
249 pub async fn ping_peer(&self, overlay_ip: IpAddr) -> Result<Duration> {
260 let start = Instant::now();
261
262 #[cfg(target_os = "macos")]
265 let timeout_arg = (PING_TIMEOUT_SECS * 1000).to_string();
266 #[cfg(not(target_os = "macos"))]
267 let timeout_arg = PING_TIMEOUT_SECS.to_string();
268
269 let mut cmd = match overlay_ip {
271 IpAddr::V4(_) => Command::new("ping"),
272 IpAddr::V6(_) => {
273 #[cfg(target_os = "macos")]
274 {
275 Command::new("ping6")
276 }
277 #[cfg(not(target_os = "macos"))]
278 {
279 let mut c = Command::new("ping");
280 c.arg("-6");
281 c
282 }
283 }
284 };
285
286 cmd.args([
287 "-c",
288 "1", "-W",
290 &timeout_arg,
291 &overlay_ip.to_string(),
292 ]);
293
294 let output =
295 tokio::time::timeout(Duration::from_secs(PING_TIMEOUT_SECS), cmd.output()).await;
296
297 match output {
298 Ok(Ok(result)) if result.status.success() => Ok(start.elapsed()),
299 Ok(Ok(_)) => Err(OverlayError::PeerUnreachable {
300 ip: overlay_ip,
301 reason: "ping failed".to_string(),
302 }),
303 Ok(Err(e)) => Err(OverlayError::PeerUnreachable {
304 ip: overlay_ip,
305 reason: e.to_string(),
306 }),
307 Err(_) => Err(OverlayError::PeerUnreachable {
308 ip: overlay_ip,
309 reason: "timeout".to_string(),
310 }),
311 }
312 }
313
314 pub async fn tcp_check(&self, overlay_ip: IpAddr, port: u16) -> Result<Duration> {
322 let start = Instant::now();
323
324 let addr = SocketAddr::new(overlay_ip, port);
325 let result = tokio::time::timeout(
326 Duration::from_secs(PING_TIMEOUT_SECS),
327 tokio::net::TcpStream::connect(addr),
328 )
329 .await;
330
331 match result {
332 Ok(Ok(_stream)) => Ok(start.elapsed()),
333 Ok(Err(e)) => Err(OverlayError::PeerUnreachable {
334 ip: overlay_ip,
335 reason: e.to_string(),
336 }),
337 Err(_) => Err(OverlayError::PeerUnreachable {
338 ip: overlay_ip,
339 reason: "timeout".to_string(),
340 }),
341 }
342 }
343
344 async fn get_wg_stats(&self) -> Result<Vec<WgPeerStats>> {
350 let sock_path = format!("/var/run/wireguard/{}.sock", self.interface);
351
352 let response = match uapi_get_raw(&sock_path).await {
353 Ok(resp) => resp,
354 Err(e) => {
355 let msg = e.to_string();
356 if msg.contains("No such file")
358 || msg.contains("Connection refused")
359 || msg.contains("not found")
360 {
361 return Ok(Vec::new());
362 }
363 return Err(OverlayError::TransportCommand(msg));
364 }
365 };
366
367 let peers = parse_uapi_get_response(&response);
368
369 debug!(interface = %self.interface, peer_count = peers.len(), "Retrieved overlay peer stats via UAPI");
370 Ok(peers)
371 }
372
373 pub async fn get_cached_status(&self, public_key: &str) -> Option<PeerStatus> {
375 let cache = self.peer_status.read().await;
376 cache.get(public_key).cloned()
377 }
378
379 pub fn check_interval(&self) -> Duration {
381 self.check_interval
382 }
383
384 pub fn interface(&self) -> &str {
386 &self.interface
387 }
388}
389
/// Returns the current Unix time in whole seconds.
///
/// Yields `0` when the system clock reads earlier than the Unix epoch.
fn current_timestamp() -> u64 {
    let since_epoch = std::time::UNIX_EPOCH.elapsed();
    since_epoch.map_or(0, |elapsed| elapsed.as_secs())
}
397
398#[cfg(unix)]
405async fn uapi_get_raw(sock_path: &str) -> std::result::Result<String, Box<dyn std::error::Error>> {
406 let mut stream = tokio::net::UnixStream::connect(sock_path).await?;
407 stream.write_all(b"get=1\n\n").await?;
408 stream.shutdown().await?;
409 let mut response = String::new();
410 stream.read_to_string(&mut response).await?;
411 Ok(response)
412}
413
/// Stand-in for platforms without Unix sockets: always fails with an
/// `io::ErrorKind::NotFound` error.
// Kept `async` so the signature matches the Unix implementation.
#[cfg(not(unix))]
#[allow(clippy::unused_async)]
async fn uapi_get_raw(_sock_path: &str) -> std::result::Result<String, Box<dyn std::error::Error>> {
    let unsupported = std::io::Error::new(
        std::io::ErrorKind::NotFound,
        "UAPI Unix socket not supported on this platform",
    );
    Err(unsupported.into())
}
428
429fn hex_key_to_base64(hex_key: &str) -> String {
431 use base64::{engine::general_purpose::STANDARD, Engine as _};
432 match hex::decode(hex_key) {
433 Ok(bytes) => STANDARD.encode(bytes),
434 Err(_) => hex_key.to_string(), }
436}
437
438fn parse_uapi_get_response(response: &str) -> Vec<WgPeerStats> {
446 let mut peers = Vec::new();
447 let mut current_peer: Option<WgPeerStats> = None;
448 let mut in_peer = false;
449
450 for line in response.lines() {
451 let line = line.trim();
452 if line.is_empty() || line.starts_with("errno=") {
453 continue;
454 }
455
456 let Some((key, value)) = line.split_once('=') else {
457 continue;
458 };
459
460 match key {
461 "public_key" => {
462 if let Some(peer) = current_peer.take() {
464 peers.push(peer);
465 }
466 in_peer = true;
467 current_peer = Some(WgPeerStats {
468 public_key: hex_key_to_base64(value),
469 endpoint: None,
470 allowed_ips: Vec::new(),
471 last_handshake_time: None,
472 transfer_rx: 0,
473 transfer_tx: 0,
474 });
475 }
476 "endpoint" if in_peer => {
477 if let Some(ref mut peer) = current_peer {
478 if value != "(none)" {
479 peer.endpoint = Some(value.to_string());
480 }
481 }
482 }
483 "allowed_ip" if in_peer => {
484 if let Some(ref mut peer) = current_peer {
485 peer.allowed_ips.push(value.to_string());
486 }
487 }
488 "last_handshake_time_sec" if in_peer => {
489 if let Some(ref mut peer) = current_peer {
490 if let Ok(t) = value.parse::<u64>() {
491 if t > 0 {
492 peer.last_handshake_time = Some(t);
493 }
494 }
495 }
496 }
497 "rx_bytes" if in_peer => {
498 if let Some(ref mut peer) = current_peer {
499 peer.transfer_rx = value.parse().unwrap_or(0);
500 }
501 }
502 "tx_bytes" if in_peer => {
503 if let Some(ref mut peer) = current_peer {
504 peer.transfer_tx = value.parse().unwrap_or(0);
505 }
506 }
507 _ => {}
509 }
510 }
511
512 if let Some(peer) = current_peer {
514 peers.push(peer);
515 }
516
517 peers
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a healthy `PeerStatus` fixture through JSON and checks that
    /// the key, the health flag and the overlay IP survive.
    fn assert_status_roundtrip(public_key: &str, overlay_ip: &str) {
        let ip = overlay_ip.parse::<IpAddr>().unwrap();
        let status = PeerStatus {
            public_key: public_key.to_string(),
            overlay_ip: Some(ip),
            healthy: true,
            last_handshake_secs: Some(10),
            last_ping_ms: Some(5),
            failure_count: 0,
            last_check: 1_234_567_890,
            #[cfg(feature = "nat")]
            connection_type: ConnectionType::default(),
        };

        let json = serde_json::to_string(&status).unwrap();
        let decoded: PeerStatus = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.public_key, public_key);
        assert!(decoded.healthy);
        assert_eq!(decoded.overlay_ip, Some(ip));
    }

    /// Checker with the default 180 s handshake timeout used by the
    /// `is_peer_healthy` tests.
    fn wg0_checker() -> OverlayHealthChecker {
        OverlayHealthChecker::new("wg0", Duration::from_secs(30))
    }

    /// Peer stats fixture that only varies in its last handshake time.
    fn stats_with_handshake(last_handshake_time: Option<u64>) -> WgPeerStats {
        WgPeerStats {
            public_key: "key".to_string(),
            endpoint: None,
            allowed_ips: vec![],
            last_handshake_time,
            transfer_rx: 0,
            transfer_tx: 0,
        }
    }

    #[test]
    fn test_peer_status_serialization_v4() {
        assert_status_roundtrip("test_key", "10.200.0.5");
    }

    #[test]
    fn test_peer_status_serialization_v6() {
        assert_status_roundtrip("test_key_v6", "fd00::5");
    }

    #[test]
    fn test_overlay_health_serialization() {
        let health = OverlayHealth {
            interface: "zl-overlay0".to_string(),
            total_peers: 2,
            healthy_peers: 1,
            unhealthy_peers: 1,
            peers: vec![],
            last_check: 1_234_567_890,
        };

        let json = serde_json::to_string_pretty(&health).unwrap();
        assert!(json.contains("zl-overlay0"));
    }

    #[test]
    fn test_health_checker_creation() {
        let interval = Duration::from_secs(60);
        let checker = OverlayHealthChecker::new("wg0", interval);

        assert_eq!(checker.interface(), "wg0");
        assert_eq!(checker.check_interval(), interval);
    }

    #[test]
    fn test_is_peer_healthy_recent_handshake() {
        // 60 s old: inside the default timeout.
        let stats = stats_with_handshake(Some(current_timestamp() - 60));

        assert!(wg0_checker().is_peer_healthy(&stats));
    }

    #[test]
    fn test_is_peer_healthy_stale_handshake() {
        // 300 s old: past the default timeout.
        let stats = stats_with_handshake(Some(current_timestamp() - 300));

        assert!(!wg0_checker().is_peer_healthy(&stats));
    }

    #[test]
    fn test_is_peer_healthy_no_handshake() {
        let stats = stats_with_handshake(None);

        assert!(!wg0_checker().is_peer_healthy(&stats));
    }

    #[test]
    fn test_parse_uapi_get_response() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let key_bytes = [0xABu8; 32];
        let hex_key = hex::encode(key_bytes);
        let expected_b64 = STANDARD.encode(key_bytes);

        let response = format!(
            "private_key=0000000000000000000000000000000000000000000000000000000000000000\n\
             listen_port=51820\n\
             public_key={hex_key}\n\
             endpoint=192.168.1.5:51820\n\
             allowed_ip=10.200.0.2/32\n\
             last_handshake_time_sec=1700000000\n\
             last_handshake_time_nsec=0\n\
             rx_bytes=12345\n\
             tx_bytes=67890\n\
             persistent_keepalive_interval=25\n\
             errno=0\n"
        );

        let peers = parse_uapi_get_response(&response);
        assert_eq!(peers.len(), 1);

        let only = &peers[0];
        assert_eq!(only.public_key, expected_b64);
        assert_eq!(only.endpoint, Some("192.168.1.5:51820".to_string()));
        assert_eq!(only.allowed_ips, vec!["10.200.0.2/32".to_string()]);
        assert_eq!(only.last_handshake_time, Some(1_700_000_000));
        assert_eq!(only.transfer_rx, 12345);
        assert_eq!(only.transfer_tx, 67890);
    }

    #[test]
    fn test_parse_uapi_get_response_multiple_peers() {
        let key1 = hex::encode([0x01u8; 32]);
        let key2 = hex::encode([0x02u8; 32]);

        let response = format!(
            "private_key=0000000000000000000000000000000000000000000000000000000000000000\n\
             listen_port=51820\n\
             public_key={key1}\n\
             endpoint=10.0.0.1:51820\n\
             allowed_ip=10.200.0.2/32\n\
             rx_bytes=100\n\
             tx_bytes=200\n\
             public_key={key2}\n\
             endpoint=10.0.0.2:51821\n\
             allowed_ip=10.200.0.3/32\n\
             allowed_ip=10.200.1.0/24\n\
             rx_bytes=300\n\
             tx_bytes=400\n\
             errno=0\n"
        );

        let peers = parse_uapi_get_response(&response);
        assert_eq!(peers.len(), 2);
        assert_eq!(peers[0].transfer_rx, 100);
        assert_eq!(peers[1].transfer_rx, 300);
        assert_eq!(peers[1].allowed_ips.len(), 2);
    }

    #[test]
    fn test_parse_uapi_get_response_empty() {
        // Interface-level keys only: no peer section at all.
        let peers = parse_uapi_get_response("private_key=0000\nlisten_port=51820\nerrno=0\n");
        assert!(peers.is_empty());
    }

    #[test]
    fn test_hex_key_to_base64_roundtrip() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let key_bytes = [0xCDu8; 32];
        let encoded = hex_key_to_base64(&hex::encode(key_bytes));

        assert_eq!(encoded, STANDARD.encode(key_bytes));
    }
}