Skip to main content

netcore_probe/
lib.rs

1//! Reachability backend.
2//!
3//! All probes run synchronously on the calling thread. No tokio — probes are
4//! small sequential I/O and the overhead of a runtime per call isn't worth
5//! it here (unlike netlink/D-Bus, which need async for stream protocols).
6//!
//! ICMP: datagram sockets with `IPPROTO_ICMP` / `IPPROTO_ICMPV6`. These work
//! without `CAP_NET_RAW` when the kernel's `ping_group_range` covers the
//! caller. Availability is detected once at startup by trying to open an
//! unprivileged ICMPv4 datagram socket; the `ping_group_range` file itself is
//! not parsed. When the socket can't be opened, `capabilities().has_ping` is
//! `false` and `ping`/`trace` return `Error::Unsupported`, so the
//! Diagnostician can skip ICMP probes gracefully.
12//!
13//! TCP: `std::net::TcpStream::connect_timeout`. Good enough for v0.1.
14//!
15//! TLS: rustls with webpki-roots; we just want "did the handshake complete
16//! with a valid chain?" — not full validation semantics.
17
18use std::io::{Read, Write};
19use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpStream};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22
23use socket2::{Domain, Protocol, Socket, Type};
24
25use netcore::Error;
26use netcore::Result as NcResult;
27use netcore::diag::{PingOpts, ProbeCapabilities, TraceOpts};
28use netcore::path::{Hop, HttpProbeResult, PingResult, TcpProbeResult, TlsProbeResult};
29use netcore::traits::Reachability;
30
31mod icmp;
32
33/// Reachability backend: ICMP ping, TCP connect, TLS handshake, HTTP HEAD,
34/// and traceroute. All methods are synchronous.
35pub struct ProbeBackend {
36    caps: ProbeCapabilities,
37}
38
39impl ProbeBackend {
40    /// Detect available capabilities and return a new backend.
41    pub fn new() -> Self {
42        Self {
43            caps: detect_capabilities(),
44        }
45    }
46}
47
48impl Default for ProbeBackend {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl Reachability for ProbeBackend {
55    fn ping(&self, ip: IpAddr, opts: PingOpts) -> NcResult<PingResult> {
56        if !self.caps.has_ping {
57            return Err(Error::Unsupported(
58                "unprivileged ICMP unavailable; run as root or adjust net.ipv4.ping_group_range",
59            ));
60        }
61        icmp::ping(ip, opts)
62    }
63
64    fn tcp_connect(&self, sa: SocketAddr, timeout: Duration) -> NcResult<TcpProbeResult> {
65        let start = Instant::now();
66        match TcpStream::connect_timeout(&sa, timeout) {
67            Ok(s) => {
68                let took = start.elapsed();
69                let _ = s.shutdown(std::net::Shutdown::Both);
70                Ok(TcpProbeResult {
71                    addr: sa,
72                    connected: true,
73                    took,
74                    error: None,
75                })
76            }
77            Err(e) => Ok(TcpProbeResult {
78                addr: sa,
79                connected: false,
80                took: start.elapsed(),
81                error: Some(classify_tcp_error(&e)),
82            }),
83        }
84    }
85
86    fn tls_handshake(
87        &self,
88        sa: SocketAddr,
89        sni: &str,
90        timeout: Duration,
91    ) -> NcResult<TlsProbeResult> {
92        let start = Instant::now();
93        let out = do_tls(sa, sni, timeout);
94        let took = start.elapsed();
95        match out {
96            Ok(()) => Ok(TlsProbeResult {
97                peer: sa,
98                sni: sni.into(),
99                negotiated: true,
100                took,
101                error: None,
102            }),
103            Err(e) => Ok(TlsProbeResult {
104                peer: sa,
105                sni: sni.into(),
106                negotiated: false,
107                took,
108                error: Some(e.to_string()),
109            }),
110        }
111    }
112
113    fn http_head(&self, url: &url::Url, timeout: Duration) -> NcResult<HttpProbeResult> {
114        let start = Instant::now();
115        let result = do_http_head(url, timeout);
116        let took = start.elapsed();
117        match result {
118            Ok(status) => Ok(HttpProbeResult {
119                url: url.to_string(),
120                status: Some(status),
121                took,
122                error: None,
123            }),
124            Err(e) => Ok(HttpProbeResult {
125                url: url.to_string(),
126                status: None,
127                took,
128                error: Some(e.to_string()),
129            }),
130        }
131    }
132
133    fn trace(&self, ip: IpAddr, opts: TraceOpts) -> NcResult<Vec<Hop>> {
134        if !self.caps.has_ping {
135            return Err(Error::Unsupported(
136                "trace requires unprivileged ICMP; not available on this host",
137            ));
138        }
139        icmp::trace(ip, opts)
140    }
141
142    fn capabilities(&self) -> ProbeCapabilities {
143        self.caps.clone()
144    }
145}
146
147/// Probe whether we can open unprivileged ICMP sockets once, at startup.
148/// Much cheaper than checking `ping_group_range` text — the kernel tells us
149/// directly by permitting or refusing the socket call.
150fn detect_capabilities() -> ProbeCapabilities {
151    let has_v4 = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4)).is_ok();
152    ProbeCapabilities {
153        has_ping: has_v4,
154        has_traceroute: has_v4,
155        has_mtr: false,
156        has_tracepath: false,
157        unprivileged_icmp: has_v4,
158    }
159}
160
/// Map a TCP connect failure onto a short, stable label.
///
/// Timeouts (including `WouldBlock`), refusals, resets and unreachable
/// host/network errors get fixed labels; anything else falls back to the
/// error's own `Display` text.
fn classify_tcp_error(e: &std::io::Error) -> String {
    use std::io::ErrorKind;

    let label = match e.kind() {
        ErrorKind::TimedOut | ErrorKind::WouldBlock => "timeout",
        ErrorKind::ConnectionRefused => "refused",
        ErrorKind::ConnectionReset => "reset",
        ErrorKind::HostUnreachable | ErrorKind::NetworkUnreachable => "unreachable",
        _ => return e.to_string(),
    };
    label.to_owned()
}
171
172// ---- TLS ----
173
174fn do_tls(sa: SocketAddr, sni: &str, timeout: Duration) -> Result<(), Box<dyn std::error::Error>> {
175    use rustls::ClientConnection;
176    let server_name = rustls_pki_types::ServerName::try_from(sni.to_string())?;
177    let mut sock = TcpStream::connect_timeout(&sa, timeout)?;
178    sock.set_read_timeout(Some(timeout))?;
179    sock.set_write_timeout(Some(timeout))?;
180    let config = tls_client_config()?;
181    let mut conn = ClientConnection::new(config, server_name)?;
182    // Drive the handshake to completion.
183    while conn.is_handshaking() {
184        if conn.wants_write() {
185            conn.write_tls(&mut sock)?;
186        }
187        if conn.wants_read() {
188            conn.read_tls(&mut sock)?;
189            conn.process_new_packets()?;
190        }
191    }
192    Ok(())
193}
194
195fn tls_client_config() -> Result<Arc<rustls::ClientConfig>, Box<dyn std::error::Error>> {
196    let mut roots = rustls::RootCertStore::empty();
197    roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
198    let cfg = rustls::ClientConfig::builder()
199        .with_root_certificates(roots)
200        .with_no_client_auth();
201    Ok(Arc::new(cfg))
202}
203
204// ---- HTTP HEAD ----
205
206fn do_http_head(url: &url::Url, timeout: Duration) -> Result<u16, Box<dyn std::error::Error>> {
207    let host = url.host_str().ok_or("url missing host")?;
208    let port = url.port_or_known_default().ok_or("url missing port")?;
209    let path = if url.path().is_empty() {
210        "/"
211    } else {
212        url.path()
213    };
214    let sas: Vec<SocketAddr> = (host, port).to_socket_addrs_local()?.collect();
215    let sa = *sas.first().ok_or("no address for host")?;
216    let req = format!(
217        "HEAD {path} HTTP/1.1\r\nHost: {host}\r\nUser-Agent: jip/0.1\r\nConnection: close\r\nAccept: */*\r\n\r\n"
218    );
219    let mut buf = [0u8; 2048];
220    let n = match url.scheme() {
221        "https" => http_over_tls(sa, host, &req, timeout, &mut buf)?,
222        "http" => http_plain(sa, &req, timeout, &mut buf)?,
223        other => return Err(format!("unsupported scheme: {other}").into()),
224    };
225    parse_http_status(&buf[..n])
226}
227
/// Send `req` over a plain TCP connection to `sa` and read the start of the
/// response into `buf`.
///
/// `timeout` applies separately to the connect, the write and the read.
/// Exactly one `read` is issued; the number of bytes it returned is the
/// result.
fn http_plain(
    sa: SocketAddr,
    req: &str,
    timeout: Duration,
    buf: &mut [u8],
) -> Result<usize, Box<dyn std::error::Error>> {
    let mut stream = TcpStream::connect_timeout(&sa, timeout)?;
    let io_limit = Some(timeout);
    stream.set_read_timeout(io_limit)?;
    stream.set_write_timeout(io_limit)?;

    stream.write_all(req.as_bytes())?;
    let received = stream.read(buf)?;
    Ok(received)
}
240
241fn http_over_tls(
242    sa: SocketAddr,
243    sni: &str,
244    req: &str,
245    timeout: Duration,
246    buf: &mut [u8],
247) -> Result<usize, Box<dyn std::error::Error>> {
248    use rustls::ClientConnection;
249    let server_name = rustls_pki_types::ServerName::try_from(sni.to_string())?;
250    let mut sock = TcpStream::connect_timeout(&sa, timeout)?;
251    sock.set_read_timeout(Some(timeout))?;
252    sock.set_write_timeout(Some(timeout))?;
253    let config = tls_client_config()?;
254    let mut conn = ClientConnection::new(config, server_name)?;
255    let mut tls = rustls::Stream::new(&mut conn, &mut sock);
256    tls.write_all(req.as_bytes())?;
257    // Read may not fill the buffer; one read is enough to get the status line.
258    Ok(tls.read(buf).unwrap_or(0))
259}
260
/// Extract the numeric status code from the start of an HTTP response.
///
/// Only the first line (up to the first `\n`) is decoded as UTF-8, so
/// non-UTF-8 header bytes or a multibyte sequence cut off by a short read
/// later in the buffer cannot invalidate an otherwise good status line.
///
/// # Errors
///
/// Returns an error when `buf` is empty, the first line is not valid UTF-8,
/// the version or code token is missing, or the code does not parse as `u16`.
fn parse_http_status(buf: &[u8]) -> Result<u16, Box<dyn std::error::Error>> {
    if buf.is_empty() {
        return Err("empty http response".into());
    }
    let status_line = buf.split(|&b| b == b'\n').next().unwrap_or(buf);
    let status_line = std::str::from_utf8(status_line)?;
    // HTTP/1.1 200 OK
    let mut parts = status_line.split_whitespace();
    parts.next().ok_or("missing version")?;
    let code = parts.next().ok_or("missing code")?;
    Ok(code.parse()?)
}
270
/// Thin wrapper over `std::net::ToSocketAddrs` for `(host, port)` pairs.
///
/// The std trait is imported only inside the impl, so call sites can resolve
/// a pair without bringing `ToSocketAddrs` into scope themselves.
trait ToSocketAddrsLocal {
    /// Resolve `self` and return the resulting socket addresses.
    fn to_socket_addrs_local(&self) -> std::io::Result<std::vec::IntoIter<SocketAddr>>;
}

impl ToSocketAddrsLocal for (&str, u16) {
    fn to_socket_addrs_local(&self) -> std::io::Result<std::vec::IntoIter<SocketAddr>> {
        use std::net::ToSocketAddrs;
        self.to_socket_addrs()
            .map(|resolved| resolved.collect::<Vec<SocketAddr>>().into_iter())
    }
}
284
// Exists only to mention `Ipv4Addr` and `Ipv6Addr` outside the test module,
// keeping those imports in use. It is never called, hence `dead_code`.
#[allow(dead_code)]
fn _touch(_: Ipv4Addr, _: Ipv6Addr) {}
288
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{SocketAddr, TcpListener};

    /// Connecting to a listener we own must succeed with no error string.
    #[test]
    fn tcp_connect_to_localhost_listener() {
        let server = TcpListener::bind("127.0.0.1:0").unwrap();
        let target = server.local_addr().unwrap();
        let outcome = ProbeBackend::new()
            .tcp_connect(target, Duration::from_millis(500))
            .unwrap();
        assert!(outcome.connected, "connect to own listener");
        assert!(outcome.error.is_none());
    }

    /// Connecting to a port whose listener was just dropped.
    #[test]
    fn tcp_connect_refused_on_unused_port() {
        // Bind an ephemeral port and release it immediately. The OS will
        // almost certainly keep it closed for the instant we need, but another
        // process could grab it; in that case the connect succeeds. Either
        // outcome is accepted — the check is only that the result is coherent.
        let target: SocketAddr = TcpListener::bind("127.0.0.1:0")
            .unwrap()
            .local_addr()
            .unwrap();
        let outcome = ProbeBackend::new()
            .tcp_connect(target, Duration::from_millis(200))
            .unwrap();
        // A successful connect must never carry an error string.
        assert!(
            !(outcome.connected && outcome.error.is_some()),
            "structural invariants hold even under race"
        );
    }

    /// Capability detection must be callable whatever the host allows.
    #[test]
    fn capabilities_are_detected() {
        // `has_ping` depends on the host's ping_group_range, so its value is
        // not asserted; reading it is enough.
        let _ = ProbeBackend::new().capabilities().has_ping;
    }

    /// When ICMP is usable, a single ping to loopback must get its reply.
    #[test]
    fn ping_loopback_if_icmp_available() {
        let backend = ProbeBackend::new();
        if !backend.capabilities().has_ping {
            // Nothing to verify on hosts without unprivileged ICMP.
            return;
        }
        let opts = PingOpts {
            count: 1,
            timeout: Duration::from_millis(500),
        };
        let reply = backend
            .ping(IpAddr::V4(Ipv4Addr::LOCALHOST), opts)
            .unwrap();
        assert_eq!(reply.sent, 1);
        assert_eq!(reply.received, 1, "loopback ICMP must round-trip");
    }
}
350}