1use std::io::{Read, Write};
19use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpStream};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22
23use socket2::{Domain, Protocol, Socket, Type};
24
25use netcore::Error;
26use netcore::Result as NcResult;
27use netcore::diag::{PingOpts, ProbeCapabilities, TraceOpts};
28use netcore::path::{Hop, HttpProbeResult, PingResult, TcpProbeResult, TlsProbeResult};
29use netcore::traits::Reachability;
30
31mod icmp;
32
33pub struct ProbeBackend {
36 caps: ProbeCapabilities,
37}
38
39impl ProbeBackend {
40 pub fn new() -> Self {
42 Self {
43 caps: detect_capabilities(),
44 }
45 }
46}
47
48impl Default for ProbeBackend {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl Reachability for ProbeBackend {
55 fn ping(&self, ip: IpAddr, opts: PingOpts) -> NcResult<PingResult> {
56 if !self.caps.has_ping {
57 return Err(Error::Unsupported(
58 "unprivileged ICMP unavailable; run as root or adjust net.ipv4.ping_group_range",
59 ));
60 }
61 icmp::ping(ip, opts)
62 }
63
64 fn tcp_connect(&self, sa: SocketAddr, timeout: Duration) -> NcResult<TcpProbeResult> {
65 let start = Instant::now();
66 match TcpStream::connect_timeout(&sa, timeout) {
67 Ok(s) => {
68 let took = start.elapsed();
69 let _ = s.shutdown(std::net::Shutdown::Both);
70 Ok(TcpProbeResult {
71 addr: sa,
72 connected: true,
73 took,
74 error: None,
75 })
76 }
77 Err(e) => Ok(TcpProbeResult {
78 addr: sa,
79 connected: false,
80 took: start.elapsed(),
81 error: Some(classify_tcp_error(&e)),
82 }),
83 }
84 }
85
86 fn tls_handshake(
87 &self,
88 sa: SocketAddr,
89 sni: &str,
90 timeout: Duration,
91 ) -> NcResult<TlsProbeResult> {
92 let start = Instant::now();
93 let out = do_tls(sa, sni, timeout);
94 let took = start.elapsed();
95 match out {
96 Ok(()) => Ok(TlsProbeResult {
97 peer: sa,
98 sni: sni.into(),
99 negotiated: true,
100 took,
101 error: None,
102 }),
103 Err(e) => Ok(TlsProbeResult {
104 peer: sa,
105 sni: sni.into(),
106 negotiated: false,
107 took,
108 error: Some(e.to_string()),
109 }),
110 }
111 }
112
113 fn http_head(&self, url: &url::Url, timeout: Duration) -> NcResult<HttpProbeResult> {
114 let start = Instant::now();
115 let result = do_http_head(url, timeout);
116 let took = start.elapsed();
117 match result {
118 Ok(status) => Ok(HttpProbeResult {
119 url: url.to_string(),
120 status: Some(status),
121 took,
122 error: None,
123 }),
124 Err(e) => Ok(HttpProbeResult {
125 url: url.to_string(),
126 status: None,
127 took,
128 error: Some(e.to_string()),
129 }),
130 }
131 }
132
133 fn trace(&self, ip: IpAddr, opts: TraceOpts) -> NcResult<Vec<Hop>> {
134 if !self.caps.has_ping {
135 return Err(Error::Unsupported(
136 "trace requires unprivileged ICMP; not available on this host",
137 ));
138 }
139 icmp::trace(ip, opts)
140 }
141
142 fn capabilities(&self) -> ProbeCapabilities {
143 self.caps.clone()
144 }
145}
146
147fn detect_capabilities() -> ProbeCapabilities {
151 let has_v4 = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4)).is_ok();
152 ProbeCapabilities {
153 has_ping: has_v4,
154 has_traceroute: has_v4,
155 has_mtr: false,
156 has_tracepath: false,
157 unprivileged_icmp: has_v4,
158 }
159}
160
/// Maps a connect error to a short label for probe results.
///
/// Timeouts (including `WouldBlock`), refusals, resets, and host/network unreachability
/// get fixed labels; any other error falls back to its `Display` text.
fn classify_tcp_error(e: &std::io::Error) -> String {
    use std::io::ErrorKind;
    let label = match e.kind() {
        ErrorKind::TimedOut | ErrorKind::WouldBlock => "timeout",
        ErrorKind::ConnectionRefused => "refused",
        ErrorKind::ConnectionReset => "reset",
        ErrorKind::HostUnreachable | ErrorKind::NetworkUnreachable => "unreachable",
        _ => return e.to_string(),
    };
    label.to_owned()
}
171
172fn do_tls(sa: SocketAddr, sni: &str, timeout: Duration) -> Result<(), Box<dyn std::error::Error>> {
175 use rustls::ClientConnection;
176 let server_name = rustls_pki_types::ServerName::try_from(sni.to_string())?;
177 let mut sock = TcpStream::connect_timeout(&sa, timeout)?;
178 sock.set_read_timeout(Some(timeout))?;
179 sock.set_write_timeout(Some(timeout))?;
180 let config = tls_client_config()?;
181 let mut conn = ClientConnection::new(config, server_name)?;
182 while conn.is_handshaking() {
184 if conn.wants_write() {
185 conn.write_tls(&mut sock)?;
186 }
187 if conn.wants_read() {
188 conn.read_tls(&mut sock)?;
189 conn.process_new_packets()?;
190 }
191 }
192 Ok(())
193}
194
195fn tls_client_config() -> Result<Arc<rustls::ClientConfig>, Box<dyn std::error::Error>> {
196 let mut roots = rustls::RootCertStore::empty();
197 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
198 let cfg = rustls::ClientConfig::builder()
199 .with_root_certificates(roots)
200 .with_no_client_auth();
201 Ok(Arc::new(cfg))
202}
203
204fn do_http_head(url: &url::Url, timeout: Duration) -> Result<u16, Box<dyn std::error::Error>> {
207 let host = url.host_str().ok_or("url missing host")?;
208 let port = url.port_or_known_default().ok_or("url missing port")?;
209 let path = if url.path().is_empty() {
210 "/"
211 } else {
212 url.path()
213 };
214 let sas: Vec<SocketAddr> = (host, port).to_socket_addrs_local()?.collect();
215 let sa = *sas.first().ok_or("no address for host")?;
216 let req = format!(
217 "HEAD {path} HTTP/1.1\r\nHost: {host}\r\nUser-Agent: jip/0.1\r\nConnection: close\r\nAccept: */*\r\n\r\n"
218 );
219 let mut buf = [0u8; 2048];
220 let n = match url.scheme() {
221 "https" => http_over_tls(sa, host, &req, timeout, &mut buf)?,
222 "http" => http_plain(sa, &req, timeout, &mut buf)?,
223 other => return Err(format!("unsupported scheme: {other}").into()),
224 };
225 parse_http_status(&buf[..n])
226}
227
/// Sends `req` to `sa` over plain TCP and reads the start of the reply into `buf`.
///
/// A single `read` is performed, so `buf` receives at most one chunk of the response.
/// Returns the number of bytes written into `buf`.
///
/// # Errors
/// Propagates connect, timeout-configuration, write, and read failures.
fn http_plain(
    sa: SocketAddr,
    req: &str,
    timeout: Duration,
    buf: &mut [u8],
) -> Result<usize, Box<dyn std::error::Error>> {
    let limit = Some(timeout);
    let mut stream = TcpStream::connect_timeout(&sa, timeout)?;
    stream.set_read_timeout(limit)?;
    stream.set_write_timeout(limit)?;
    stream.write_all(req.as_bytes())?;
    let received = stream.read(buf)?;
    Ok(received)
}
240
241fn http_over_tls(
242 sa: SocketAddr,
243 sni: &str,
244 req: &str,
245 timeout: Duration,
246 buf: &mut [u8],
247) -> Result<usize, Box<dyn std::error::Error>> {
248 use rustls::ClientConnection;
249 let server_name = rustls_pki_types::ServerName::try_from(sni.to_string())?;
250 let mut sock = TcpStream::connect_timeout(&sa, timeout)?;
251 sock.set_read_timeout(Some(timeout))?;
252 sock.set_write_timeout(Some(timeout))?;
253 let config = tls_client_config()?;
254 let mut conn = ClientConnection::new(config, server_name)?;
255 let mut tls = rustls::Stream::new(&mut conn, &mut sock);
256 tls.write_all(req.as_bytes())?;
257 Ok(tls.read(buf).unwrap_or(0))
259}
260
/// Extracts the numeric status code from the status line at the start of `buf`.
///
/// Only the first line is decoded, so a truncated or non-UTF-8 header later in the
/// buffer cannot break parsing. Accepts both CRLF and bare LF line endings.
///
/// # Errors
/// Fails if the buffer is empty, the first line is not UTF-8, the version token does
/// not start with `HTTP/`, or the status code is missing or not exactly three digits.
fn parse_http_status(buf: &[u8]) -> Result<u16, Box<dyn std::error::Error>> {
    let line_end = buf.iter().position(|&b| b == b'\n').unwrap_or(buf.len());
    let first = std::str::from_utf8(&buf[..line_end])?.trim_end_matches('\r');
    if first.is_empty() {
        return Err("empty http response".into());
    }
    let mut parts = first.split_whitespace();
    let version = parts.next().ok_or("missing version")?;
    if !version.starts_with("HTTP/") {
        return Err(format!("not an http response: {version}").into());
    }
    let code = parts.next().ok_or("missing code")?;
    // status-code = 3DIGIT (RFC 9110 §15).
    if code.len() != 3 || !code.bytes().all(|b| b.is_ascii_digit()) {
        return Err(format!("invalid status code: {code}").into());
    }
    Ok(code.parse()?)
}
270
/// Resolves an address spec into an owned, concrete iterator of socket addresses.
///
/// Thin wrapper over [`std::net::ToSocketAddrs`] that collects the results into a
/// `Vec`, so callers get a `std::vec::IntoIter` regardless of the input type.
trait ToSocketAddrsLocal {
    fn to_socket_addrs_local(&self) -> std::io::Result<std::vec::IntoIter<SocketAddr>>;
}

impl ToSocketAddrsLocal for (&str, u16) {
    fn to_socket_addrs_local(&self) -> std::io::Result<std::vec::IntoIter<SocketAddr>> {
        use std::net::ToSocketAddrs;
        self.to_socket_addrs()
            .map(|resolved| resolved.collect::<Vec<SocketAddr>>().into_iter())
    }
}
284
// NOTE(review): this appears to exist only so `Ipv4Addr` and `Ipv6Addr` are referenced
// outside `#[cfg(test)]` code, keeping non-test builds free of unused-import warnings.
// Confirm no other part of the file uses them before removing it together with the
// imports. The allow is belt-and-braces; the leading underscore already silences
// dead_code.
#[allow(dead_code)]
fn _touch(_v4: Ipv4Addr, _v6: Ipv6Addr) {}
288
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{SocketAddr, TcpListener};

    #[test]
    fn tcp_connect_to_localhost_listener() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let sa = listener.local_addr().unwrap();
        let b = ProbeBackend::new();
        let r = b.tcp_connect(sa, Duration::from_millis(500)).unwrap();
        assert!(r.connected, "connect to own listener");
        assert!(r.error.is_none());
        assert_eq!(r.addr, sa);
    }

    #[test]
    fn tcp_connect_refused_on_unused_port() {
        // Bind an ephemeral port, then drop the listener so the port is very likely closed.
        // Another process could grab it in between, so only structural invariants are
        // asserted. The exact error label is platform-dependent (e.g. loopback refusals can
        // surface as timeouts on some OSes).
        let sa: SocketAddr = {
            let l = TcpListener::bind("127.0.0.1:0").unwrap();
            l.local_addr().unwrap()
        };
        let b = ProbeBackend::new();
        let r = b.tcp_connect(sa, Duration::from_millis(200)).unwrap();
        assert_eq!(
            r.connected,
            r.error.is_none(),
            "an error must be reported exactly when the connect failed"
        );
        assert_eq!(r.addr, sa);
    }

    #[test]
    fn capabilities_are_detected() {
        let c = ProbeBackend::new().capabilities();
        // All ICMP-derived flags come from the same detection probe, so they must agree.
        assert_eq!(c.has_ping, c.unprivileged_icmp);
        assert_eq!(c.has_traceroute, c.has_ping);
        // This backend never advertises mtr or tracepath.
        assert!(!c.has_mtr);
        assert!(!c.has_tracepath);
    }

    #[test]
    fn ping_loopback_if_icmp_available() {
        let b = ProbeBackend::new();
        if !b.capabilities().has_ping {
            return;
        }
        let r = b
            .ping(
                IpAddr::V4(Ipv4Addr::LOCALHOST),
                PingOpts {
                    count: 1,
                    timeout: Duration::from_millis(500),
                },
            )
            .unwrap();
        assert_eq!(r.sent, 1);
        assert_eq!(r.received, 1, "loopback ICMP must round-trip");
    }
}