1use std::{
2 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
3 sync::Arc,
4 time::Duration,
5};
6
7use thiserror::Error;
8
/// Adds the platform's native trusted root certificates to `roots`.
///
/// This is best-effort:
/// - A certificate that `roots` rejects is skipped, and a warning is logged.
/// - If `rustls_native_certs::load_native_certs` fails, a warning is logged and nothing is added.
#[cfg(feature = "tls-native-roots")]
fn add_native_roots(roots: &mut rustls::RootCertStore) {
    tracing::debug!("Loading native root certificates");
    match rustls_native_certs::load_native_certs() {
        Ok(certs) => {
            for cert in certs {
                let cert = rustls::Certificate(cert.0);
                // Skip individual unusable anchors instead of abandoning the whole load.
                if let Err(e) = roots.add(&cert) {
                    tracing::warn!(?cert, "Failed to parse trust anchor: {}", e);
                }
            }
        }
        Err(e) => {
            tracing::warn!("Failed to load any default trust roots: {}", e);
        }
    }
}
27
/// Adds every trust anchor bundled with the `webpki_roots` crate to `roots`.
///
/// Each anchor's subject, SPKI and optional name constraints are copied into an
/// owned `rustls::OwnedTrustAnchor` before being handed to the store.
#[cfg(feature = "tls-webpki-roots")]
fn add_webpki_roots(roots: &mut rustls::RootCertStore) {
    tracing::debug!("Loading webpki root certificates");
    let anchors = webpki_roots::TLS_SERVER_ROOTS
        .0
        .iter()
        .map(|anchor| {
            rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
                anchor.subject,
                anchor.spki,
                anchor.name_constraints,
            )
        });
    roots.add_trust_anchors(anchors);
}
39
40fn load_root_certs() -> rustls::RootCertStore {
41 #[cfg(any(feature = "tls-native-roots", feature = "tls-webpki-roots"))]
42 {
43 let mut roots = rustls::RootCertStore::empty();
44 #[cfg(feature = "tls-native-roots")]
45 add_native_roots(&mut roots);
46 #[cfg(feature = "tls-webpki-roots")]
47 add_webpki_roots(&mut roots);
48 roots
49 }
50 #[cfg(not(any(feature = "tls-native-roots", feature = "tls-webpki-roots")))]
51 {
52 tracing::debug!("Creating empty root certificates store");
53 rustls::RootCertStore::empty()
54 }
55}
56
/// Builder for a [`Probe`]: collects TLS, QUIC transport and local-socket settings.
///
/// Obtain one via `ProbeBuilder::default()`, adjust it, then call `build`.
pub struct ProbeBuilder {
    /// TLS client configuration wrapped into the quinn client config by `build`.
    ///
    /// NOTE(review): unlike the other fields this one is `pub`, presumably so callers
    /// can customise TLS (e.g. ALPN) before building — confirm intent.
    pub tls_config: rustls::ClientConfig,
    /// QUIC transport settings, adjusted by `initial_rtt` / `max_idle_timeout`.
    transport_config: quinn::TransportConfig,
    /// Local address the client endpoint is bound to in `build`.
    bind_addr: SocketAddr,
}
62
63impl ProbeBuilder {
64 pub fn initial_rtt(mut self, rtt: Duration) -> Self {
65 self.transport_config.initial_rtt(rtt);
66 self
67 }
68
69 pub fn max_idle_timeout<T: TryInto<quinn::IdleTimeout>>(
70 mut self,
71 timeout: T,
72 ) -> Result<Self, T::Error> {
73 self.transport_config
74 .max_idle_timeout(Some(timeout.try_into()?));
75 Ok(self)
76 }
77
78 pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
79 self.bind_addr = addr;
80 self
81 }
82
83 pub fn build(self) -> Result<Probe, std::io::Error> {
84 let mut endpoint = quinn::Endpoint::client(self.bind_addr)?;
85 let mut client_config = quinn::ClientConfig::new(Arc::new(self.tls_config));
86 client_config.transport_config(Arc::new(self.transport_config));
87 endpoint.set_default_client_config(client_config);
88 Ok(Probe { endpoint })
89 }
90}
91
92impl Default for ProbeBuilder {
93 fn default() -> Self {
94 let roots = load_root_certs();
95 let tls_config = rustls::ClientConfig::builder()
96 .with_safe_default_cipher_suites()
97 .with_safe_default_kx_groups()
98 .with_protocol_versions(&[&rustls::version::TLS13])
99 .unwrap()
100 .with_root_certificates(roots)
101 .with_no_client_auth();
102
103 Self {
104 tls_config,
105 transport_config: quinn::TransportConfig::default(),
106 bind_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
107 }
108 }
109}
110
/// A QUIC client endpoint used to test whether servers answer a handshake.
///
/// Construct it with `ProbeBuilder::build`.
/// NOTE(review): `Clone` is derived; whether clones share one underlying socket
/// depends on `quinn::Endpoint`'s `Clone` impl — presumably a shared handle, confirm.
#[derive(Clone)]
pub struct Probe {
    /// Client endpoint whose default client config was installed by `ProbeBuilder::build`.
    endpoint: quinn::Endpoint,
}
115
116#[derive(Clone, Debug, Error)]
117pub enum ProbeError {
118 #[error("Invalid DNS name: {0}")]
119 InvalidDnsName(String),
120 #[error("Invalid remote address: {0}")]
121 InvalidRemoteAddress(SocketAddr),
122 #[error("Internal error: {0}")]
123 InternalError(String),
124}
125
126impl From<quinn::ConnectError> for ProbeError {
127 fn from(e: quinn::ConnectError) -> Self {
128 match e {
129 quinn::ConnectError::InvalidDnsName(name) => ProbeError::InvalidDnsName(name),
130 quinn::ConnectError::InvalidRemoteAddress(addr) => {
131 ProbeError::InvalidRemoteAddress(addr)
132 }
133 _ => ProbeError::InternalError(format!("{}", e)),
134 }
135 }
136}
137
138impl Probe {
139 pub async fn probe(&self, addr: SocketAddr, server_name: &str) -> Result<bool, ProbeError> {
140 let connect_result = self.endpoint.connect(addr, server_name);
141 tracing::debug!(?connect_result, "Connecting to {}({})", server_name, addr);
142 let connect_result = connect_result?.await;
143 tracing::debug!(?connect_result, "Connection result");
144 Ok(!matches!(
145 connect_result,
146 Err(quinn::ConnectionError::TimedOut)
147 ))
148 }
149}