quic_probe/
lib.rs

1use std::{
2    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
3    sync::Arc,
4    time::Duration,
5};
6
7use thiserror::Error;
8
/// Adds the operating system's trusted root certificates to `roots`.
///
/// Loading is best-effort. If the native store cannot be read, a warning is
/// logged and `roots` is left unchanged. If an individual certificate is
/// rejected by `RootCertStore::add`, a warning is logged and that certificate is
/// skipped, so `roots` may end up only partially filled.
#[cfg(feature = "tls-native-roots")]
fn add_native_roots(roots: &mut rustls::RootCertStore) {
    tracing::debug!("Loading native root certificates");
    let certs = match rustls_native_certs::load_native_certs() {
        Ok(certs) => certs,
        Err(e) => {
            tracing::warn!("Failed to load any default trust roots: {}", e);
            return;
        }
    };
    for cert in certs {
        // Re-wrap the raw DER bytes in rustls' certificate type.
        let cert = rustls::Certificate(cert.0);
        if let Err(e) = roots.add(&cert) {
            tracing::warn!(?cert, "Failed to parse trust anchor: {}", e);
        }
    }
}
27
/// Adds the root certificates bundled with the `webpki_roots` crate to `roots`.
///
/// Each bundled trust anchor is converted into a `rustls::OwnedTrustAnchor`
/// (subject, SPKI and name constraints) before being handed to the store.
#[cfg(feature = "tls-webpki-roots")]
fn add_webpki_roots(roots: &mut rustls::RootCertStore) {
    tracing::debug!("Loading webpki root certificates");
    let owned_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|anchor| {
        rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
            anchor.subject,
            anchor.spki,
            anchor.name_constraints,
        )
    });
    roots.add_trust_anchors(owned_anchors);
}
39
40fn load_root_certs() -> rustls::RootCertStore {
41    #[cfg(any(feature = "tls-native-roots", feature = "tls-webpki-roots"))]
42    {
43        let mut roots = rustls::RootCertStore::empty();
44        #[cfg(feature = "tls-native-roots")]
45        add_native_roots(&mut roots);
46        #[cfg(feature = "tls-webpki-roots")]
47        add_webpki_roots(&mut roots);
48        roots
49    }
50    #[cfg(not(any(feature = "tls-native-roots", feature = "tls-webpki-roots")))]
51    {
52        tracing::debug!("Creating empty root certificates store");
53        rustls::RootCertStore::empty()
54    }
55}
56
57pub struct ProbeBuilder {
58    pub tls_config: rustls::ClientConfig,
59    transport_config: quinn::TransportConfig,
60    bind_addr: SocketAddr,
61}
62
63impl ProbeBuilder {
64    pub fn initial_rtt(mut self, rtt: Duration) -> Self {
65        self.transport_config.initial_rtt(rtt);
66        self
67    }
68
69    pub fn max_idle_timeout<T: TryInto<quinn::IdleTimeout>>(
70        mut self,
71        timeout: T,
72    ) -> Result<Self, T::Error> {
73        self.transport_config
74            .max_idle_timeout(Some(timeout.try_into()?));
75        Ok(self)
76    }
77
78    pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
79        self.bind_addr = addr;
80        self
81    }
82
83    pub fn build(self) -> Result<Probe, std::io::Error> {
84        let mut endpoint = quinn::Endpoint::client(self.bind_addr)?;
85        let mut client_config = quinn::ClientConfig::new(Arc::new(self.tls_config));
86        client_config.transport_config(Arc::new(self.transport_config));
87        endpoint.set_default_client_config(client_config);
88        Ok(Probe { endpoint })
89    }
90}
91
92impl Default for ProbeBuilder {
93    fn default() -> Self {
94        let roots = load_root_certs();
95        let tls_config = rustls::ClientConfig::builder()
96            .with_safe_default_cipher_suites()
97            .with_safe_default_kx_groups()
98            .with_protocol_versions(&[&rustls::version::TLS13])
99            .unwrap()
100            .with_root_certificates(roots)
101            .with_no_client_auth();
102
103        Self {
104            tls_config,
105            transport_config: quinn::TransportConfig::default(),
106            bind_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
107        }
108    }
109}
110
111#[derive(Clone)]
112pub struct Probe {
113    endpoint: quinn::Endpoint,
114}
115
116#[derive(Clone, Debug, Error)]
117pub enum ProbeError {
118    #[error("Invalid DNS name: {0}")]
119    InvalidDnsName(String),
120    #[error("Invalid remote address: {0}")]
121    InvalidRemoteAddress(SocketAddr),
122    #[error("Internal error: {0}")]
123    InternalError(String),
124}
125
126impl From<quinn::ConnectError> for ProbeError {
127    fn from(e: quinn::ConnectError) -> Self {
128        match e {
129            quinn::ConnectError::InvalidDnsName(name) => ProbeError::InvalidDnsName(name),
130            quinn::ConnectError::InvalidRemoteAddress(addr) => {
131                ProbeError::InvalidRemoteAddress(addr)
132            }
133            _ => ProbeError::InternalError(format!("{}", e)),
134        }
135    }
136}
137
138impl Probe {
139    pub async fn probe(&self, addr: SocketAddr, server_name: &str) -> Result<bool, ProbeError> {
140        let connect_result = self.endpoint.connect(addr, server_name);
141        tracing::debug!(?connect_result, "Connecting to {}({})", server_name, addr);
142        let connect_result = connect_result?.await;
143        tracing::debug!(?connect_result, "Connection result");
144        Ok(!matches!(
145            connect_result,
146            Err(quinn::ConnectionError::TimedOut)
147        ))
148    }
149}