Skip to main content

actr_cli/core/components/
network_validator.rs

1//! Default NetworkValidator implementation
2
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use std::net::ToSocketAddrs;
6use std::time::Duration;
7use tokio::net::TcpStream;
8use url::Url;
9
10use super::{
11    ConnectivityStatus, HealthStatus, LatencyInfo, NetworkCheckOptions, NetworkCheckResult,
12    NetworkValidator,
13};
14
/// Default [`NetworkValidator`] implementation.
///
/// Stateless: reachability and latency are judged by opening plain TCP
/// connections to the target, so every check performs fresh network I/O.
#[derive(Debug, Clone)]
pub struct DefaultNetworkValidator;
17
18impl DefaultNetworkValidator {
19    pub fn new() -> Self {
20        Self
21    }
22
23    /// Try to connect to a host and measure latency
24    async fn ping_host(&self, host_port: &str, timeout: Duration) -> Result<Duration> {
25        let start = std::time::Instant::now();
26
27        // Attempt TCP connection as a proxy for reachability
28        let _stream = tokio::time::timeout(timeout, TcpStream::connect(host_port))
29            .await
30            .context("Connection timeout")?
31            .context("Failed to connect")?;
32
33        Ok(start.elapsed())
34    }
35
36    /// Parse a service name or URL into a host:port string
37    fn resolve_address(&self, address: &str) -> Result<String> {
38        if let Ok(url) = Url::parse(address) {
39            let host = url.host_str().context("No host in URL")?;
40            let port = url.port_or_known_default().unwrap_or(80);
41            Ok(format!("{}:{}", host, port))
42        } else if address.contains(':') {
43            Ok(address.to_string())
44        } else {
45            // Assume it's a hostname and try to resolve to verify it exists
46            let addr = format!("{}:80", address);
47            if addr.to_socket_addrs().is_ok() {
48                Ok(addr)
49            } else {
50                anyhow::bail!("Invalid address format: {}", address)
51            }
52        }
53    }
54}
55
56impl Default for DefaultNetworkValidator {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62#[async_trait]
63impl NetworkValidator for DefaultNetworkValidator {
64    async fn check_connectivity(
65        &self,
66        service_address: &str,
67        options: &NetworkCheckOptions,
68    ) -> Result<ConnectivityStatus> {
69        let timeout = options.timeout;
70        match self.resolve_address(service_address) {
71            Ok(addr) => match self.ping_host(&addr, timeout).await {
72                Ok(latency) => Ok(ConnectivityStatus {
73                    is_reachable: true,
74                    response_time_ms: Some(latency.as_millis() as u64),
75                    error: None,
76                }),
77                Err(e) => Ok(ConnectivityStatus {
78                    is_reachable: false,
79                    response_time_ms: None,
80                    error: Some(e.to_string()),
81                }),
82            },
83            Err(e) => Ok(ConnectivityStatus {
84                is_reachable: false,
85                response_time_ms: None,
86                error: Some(format!("Address resolution failed: {}", e)),
87            }),
88        }
89    }
90
91    async fn verify_service_health(
92        &self,
93        service_name: &str,
94        options: &NetworkCheckOptions,
95    ) -> Result<HealthStatus> {
96        let status = self.check_connectivity(service_name, options).await?;
97        if status.is_reachable {
98            Ok(HealthStatus::Healthy)
99        } else {
100            Ok(HealthStatus::Unhealthy)
101        }
102    }
103
104    async fn test_latency(
105        &self,
106        service_name: &str,
107        options: &NetworkCheckOptions,
108    ) -> Result<LatencyInfo> {
109        let mut samples = Vec::new();
110        let timeout = options.timeout;
111        let addr = self.resolve_address(service_name)?;
112
113        for _ in 0..3 {
114            if let Ok(latency) = self.ping_host(&addr, timeout).await {
115                samples.push(latency.as_millis() as u64);
116            }
117            tokio::time::sleep(Duration::from_millis(100)).await;
118        }
119
120        if samples.is_empty() {
121            anyhow::bail!("Failed to get latency samples for {}", service_name);
122        }
123
124        let min = *samples.iter().min().unwrap();
125        let max = *samples.iter().max().unwrap();
126        let avg = samples.iter().sum::<u64>() / samples.len() as u64;
127
128        Ok(LatencyInfo {
129            min_ms: min,
130            max_ms: max,
131            avg_ms: avg,
132            samples: samples.len() as u32,
133        })
134    }
135
136    async fn batch_check(
137        &self,
138        service_names: &[String],
139        options: &NetworkCheckOptions,
140    ) -> Result<Vec<NetworkCheckResult>> {
141        let mut results = Vec::new();
142        for name in service_names {
143            let connectivity = self.check_connectivity(name, options).await?;
144            let health = if connectivity.is_reachable {
145                HealthStatus::Healthy
146            } else {
147                HealthStatus::Unhealthy
148            };
149
150            let latency = if connectivity.is_reachable {
151                self.test_latency(name, options).await.ok()
152            } else {
153                None
154            };
155
156            results.push(NetworkCheckResult {
157                connectivity,
158                health,
159                latency,
160            });
161        }
162        Ok(results)
163    }
164}