1use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
4use std::path::Path;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use rustls_pki_types::CertificateDer;
9use rustls_pki_types::ServerName;
10use rustls_pki_types::pem::PemObject;
11use tokio_rustls::TlsConnector;
12use tokio_rustls::rustls::{ClientConfig, RootCertStore};
13
/// Errors produced while building a TLS client connector
/// (see `connector_from_ca`).
#[derive(Debug, thiserror::Error)]
pub enum TlsClientError {
    /// An underlying `std::io::Error` (convertible via `From`).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Returned by `connector_from_ca` when the CA file yields no certificates
    /// or its PEM content cannot be parsed.
    #[error("no certificates found in CA file")]
    NoCertificates,

    /// rustls rejected part of the configuration, e.g. a CA certificate that
    /// could not be added to the root store.
    #[error("TLS configuration error: {0}")]
    Rustls(#[from] tokio_rustls::rustls::Error),
}
26
/// Errors returned by `parse_remote_target` when turning a `host:port`
/// string into a socket address plus a TLS server name.
#[derive(Debug, thiserror::Error)]
pub enum RemoteTargetError {
    /// No `:port` suffix could be found. This is also used when a `[` that
    /// opens an IPv6 literal has no matching `]`. Holds the full input string.
    #[error("missing ':port' in remote target: {0:?}")]
    MissingPort(String),

    /// The port text did not parse as a `u16`.
    #[error("invalid port in remote target {raw:?}: {source}")]
    InvalidPort {
        /// The full input string (not just the port part).
        raw: String,
        /// The integer-parse failure for the port text.
        #[source]
        source: std::num::ParseIntError,
    },

    /// Resolving the hostname via `ToSocketAddrs` returned an error.
    #[error("DNS lookup failed for {host:?}: {source}")]
    DnsLookupFailed {
        /// Hostname that was looked up.
        host: String,
        /// Underlying resolver I/O error.
        #[source]
        source: std::io::Error,
    },

    /// Hostname resolution succeeded but produced no addresses.
    #[error("DNS lookup for {host:?} returned no addresses")]
    DnsNoAddresses { host: String },

    /// The host is not an acceptable TLS server name (SNI). `message` is the
    /// display text of the server-name conversion error.
    #[error("invalid SNI hostname {host:?}: {message}")]
    InvalidSni { host: String, message: String },
}
53
54fn split_host_port(input: &str) -> Result<(&str, u16), RemoteTargetError> {
59 if let Some(rest) = input.strip_prefix('[') {
60 let close = rest
62 .find(']')
63 .ok_or_else(|| RemoteTargetError::MissingPort(input.to_string()))?;
64 let host = &rest[..close];
65 let after = &rest[close + 1..];
66 let port_str = after
67 .strip_prefix(':')
68 .ok_or_else(|| RemoteTargetError::MissingPort(input.to_string()))?;
69 let port = port_str
70 .parse::<u16>()
71 .map_err(|source| RemoteTargetError::InvalidPort {
72 raw: input.to_string(),
73 source,
74 })?;
75 Ok((host, port))
76 } else {
77 let idx = input
81 .rfind(':')
82 .ok_or_else(|| RemoteTargetError::MissingPort(input.to_string()))?;
83 let host = &input[..idx];
84 let port_str = &input[idx + 1..];
85 let port = port_str
86 .parse::<u16>()
87 .map_err(|source| RemoteTargetError::InvalidPort {
88 raw: input.to_string(),
89 source,
90 })?;
91 Ok((host, port))
92 }
93}
94
95pub fn parse_remote_target(
109 input: &str,
110) -> Result<(SocketAddr, ServerName<'static>), RemoteTargetError> {
111 let (host, port) = split_host_port(input)?;
112
113 let socket_addr = if let Ok(ip) = IpAddr::from_str(host) {
119 SocketAddr::new(ip, port)
120 } else {
121 let mut iter = (host, port).to_socket_addrs().map_err(|source| {
122 RemoteTargetError::DnsLookupFailed {
123 host: host.to_string(),
124 source,
125 }
126 })?;
127 iter.next()
128 .ok_or_else(|| RemoteTargetError::DnsNoAddresses {
129 host: host.to_string(),
130 })?
131 };
132
133 let server_name = if let Ok(ip) = IpAddr::from_str(host) {
139 ServerName::IpAddress(ip.into())
140 } else {
141 ServerName::try_from(host.to_string()).map_err(|e| RemoteTargetError::InvalidSni {
142 host: host.to_string(),
143 message: e.to_string(),
144 })?
145 };
146
147 Ok((socket_addr, server_name))
148}
149
150pub fn connector_from_ca(ca_path: &Path) -> Result<TlsConnector, TlsClientError> {
152 let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(ca_path)
153 .map_err(|_| TlsClientError::NoCertificates)?
154 .collect::<Result<Vec<_>, _>>()
155 .map_err(|_| TlsClientError::NoCertificates)?;
156
157 if certs.is_empty() {
158 return Err(TlsClientError::NoCertificates);
159 }
160
161 let mut root_store = RootCertStore::empty();
162 for cert in certs {
163 root_store.add(cert).map_err(TlsClientError::Rustls)?;
164 }
165
166 let config = ClientConfig::builder()
167 .with_root_certificates(root_store)
168 .with_no_client_auth();
169
170 Ok(TlsConnector::from(Arc::new(config)))
171}
172
173pub fn connector_insecure() -> TlsConnector {
175 let config = ClientConfig::builder()
176 .dangerous()
177 .with_custom_certificate_verifier(Arc::new(NoVerifier))
178 .with_no_client_auth();
179
180 TlsConnector::from(Arc::new(config))
181}
182
183#[derive(Debug)]
185struct NoVerifier;
186
187impl tokio_rustls::rustls::client::danger::ServerCertVerifier for NoVerifier {
188 fn verify_server_cert(
189 &self,
190 _end_entity: &CertificateDer<'_>,
191 _intermediates: &[CertificateDer<'_>],
192 server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>,
193 _ocsp_response: &[u8],
194 _now: tokio_rustls::rustls::pki_types::UnixTime,
195 ) -> Result<tokio_rustls::rustls::client::danger::ServerCertVerified, tokio_rustls::rustls::Error>
196 {
197 tracing::warn!(
203 target: "muxtop::insecure",
204 server_name = ?server_name,
205 "TLS certificate verification disabled — only safe in local dev"
206 );
207 Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
208 }
209
210 fn verify_tls12_signature(
211 &self,
212 _message: &[u8],
213 _cert: &CertificateDer<'_>,
214 _dss: &tokio_rustls::rustls::DigitallySignedStruct,
215 ) -> Result<
216 tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
217 tokio_rustls::rustls::Error,
218 > {
219 Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
220 }
221
222 fn verify_tls13_signature(
223 &self,
224 _message: &[u8],
225 _cert: &CertificateDer<'_>,
226 _dss: &tokio_rustls::rustls::DigitallySignedStruct,
227 ) -> Result<
228 tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
229 tokio_rustls::rustls::Error,
230 > {
231 Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
232 }
233
234 fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
235 tokio_rustls::rustls::crypto::aws_lc_rs::default_provider()
236 .signature_verification_algorithms
237 .supported_schemes()
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Generates a self-signed certificate for `localhost`, returning
    /// `(certificate_pem, private_key_pem)`.
    fn make_self_signed_cert() -> (String, String) {
        let generated =
            rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
        let cert_pem = generated.cert.pem();
        let key_pem = generated.signing_key.serialize_pem();
        (cert_pem, key_pem)
    }

    /// Fails the test unless `sni` is the IP-address form of `ServerName`.
    fn assert_ip_sni(sni: &ServerName<'_>) {
        assert!(
            matches!(sni, ServerName::IpAddress(_)),
            "expected ServerName::IpAddress, got {sni:?}"
        );
    }

    #[test]
    fn test_connector_from_ca_valid() {
        let (cert_pem, _key_pem) = make_self_signed_cert();
        let mut ca_file = NamedTempFile::new().unwrap();
        ca_file.write_all(cert_pem.as_bytes()).unwrap();

        assert!(connector_from_ca(ca_file.path()).is_ok());
    }

    #[test]
    fn test_connector_from_ca_missing_file() {
        let missing = Path::new("/nonexistent/ca.pem");
        assert!(connector_from_ca(missing).is_err());
    }

    #[test]
    fn test_connector_insecure_builds() {
        let _connector = connector_insecure();
    }

    #[test]
    fn test_parse_remote_target_ipv4_literal() {
        let (addr, sni) = parse_remote_target("127.0.0.1:4242").unwrap();
        let expected: SocketAddr = "127.0.0.1:4242".parse().unwrap();
        assert_eq!(addr, expected);
        assert_ip_sni(&sni);
    }

    #[test]
    fn test_parse_remote_target_ipv6_literal() {
        let (addr, sni) = parse_remote_target("[::1]:4242").unwrap();
        let expected: SocketAddr = "[::1]:4242".parse().unwrap();
        assert_eq!(addr, expected);
        assert_ip_sni(&sni);
    }

    #[test]
    fn test_parse_remote_target_hostname_uses_dns_sni() {
        // Resolving `localhost` depends on the machine's resolver setup, so a
        // failure here means "environment can't run this test", not a bug.
        let (addr, sni) = match parse_remote_target("localhost:4242") {
            Ok(parsed) => parsed,
            Err(e) => {
                eprintln!("skipping hostname SNI test: {e}");
                return;
            }
        };
        assert_eq!(addr.port(), 4242);
        match sni {
            ServerName::DnsName(name) => assert_eq!(name.as_ref(), "localhost"),
            other => panic!("expected ServerName::DnsName, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_remote_target_missing_port() {
        let outcome = parse_remote_target("127.0.0.1");
        assert!(matches!(outcome, Err(RemoteTargetError::MissingPort(_))));
    }

    #[test]
    fn test_parse_remote_target_invalid_port() {
        let outcome = parse_remote_target("127.0.0.1:notaport");
        assert!(matches!(outcome, Err(RemoteTargetError::InvalidPort { .. })));
    }

    #[test]
    fn test_split_host_port_ipv6_no_port() {
        let outcome = split_host_port("[::1]");
        assert!(matches!(outcome, Err(RemoteTargetError::MissingPort(_))));
    }
}