Skip to main content

muxtop_proto/
tls.rs

1// TLS client configuration for muxtop.
2
3use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
4use std::path::Path;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use rustls_pki_types::CertificateDer;
9use rustls_pki_types::ServerName;
10use rustls_pki_types::pem::PemObject;
11use tokio_rustls::TlsConnector;
12use tokio_rustls::rustls::{ClientConfig, RootCertStore};
13
/// TLS client configuration errors.
///
/// This is the error type returned by [`connector_from_ca`] when a
/// `TlsConnector` cannot be built from a CA bundle.
#[derive(Debug, thiserror::Error)]
pub enum TlsClientError {
    /// An I/O operation failed. Convertible from [`std::io::Error`] through
    /// the `#[from]` attribute, so `?` can be used on I/O results.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The CA file did not produce any certificate that could be loaded.
    #[error("no certificates found in CA file")]
    NoCertificates,

    /// rustls rejected a value while the configuration was being assembled
    /// (for example a certificate it could not add to the root store).
    /// Convertible from [`tokio_rustls::rustls::Error`] through `#[from]`.
    #[error("TLS configuration error: {0}")]
    Rustls(#[from] tokio_rustls::rustls::Error),
}
26
/// Errors from parsing the `--remote` host:port target.
///
/// Produced by [`parse_remote_target`] (and the splitting helper it calls).
/// Variants are listed in the order the corresponding checks run: textual
/// splitting first, then DNS resolution, then SNI name construction.
#[derive(Debug, thiserror::Error)]
pub enum RemoteTargetError {
    /// The target has no `:port` component. The payload is the raw target
    /// string exactly as it was supplied.
    #[error("missing ':port' in remote target: {0:?}")]
    MissingPort(String),

    /// The text after the port separator is not a valid `u16`.
    ///
    /// `raw` is the whole target string (not just the port text); `source`
    /// is the integer parse failure.
    #[error("invalid port in remote target {raw:?}: {source}")]
    InvalidPort {
        raw: String,
        #[source]
        source: std::num::ParseIntError,
    },

    /// The resolver returned an error while looking up `host`.
    #[error("DNS lookup failed for {host:?}: {source}")]
    DnsLookupFailed {
        host: String,
        #[source]
        source: std::io::Error,
    },

    /// The resolver succeeded for `host` but yielded an empty address list.
    #[error("DNS lookup for {host:?} returned no addresses")]
    DnsNoAddresses { host: String },

    /// `host` could not be converted into a TLS `ServerName`. `message` is
    /// the stringified conversion error.
    #[error("invalid SNI hostname {host:?}: {message}")]
    InvalidSni { host: String, message: String },
}
53
54/// Split `host:port` into its `(host, port)` text components.
55///
56/// Handles bracketed IPv6 literals (`[::1]:4242`) and rejects bare IPv6
57/// literals that lack a port (which are ambiguous w.r.t. the colon).
58fn split_host_port(input: &str) -> Result<(&str, u16), RemoteTargetError> {
59    if let Some(rest) = input.strip_prefix('[') {
60        // IPv6: `[host]:port`
61        let close = rest
62            .find(']')
63            .ok_or_else(|| RemoteTargetError::MissingPort(input.to_string()))?;
64        let host = &rest[..close];
65        let after = &rest[close + 1..];
66        let port_str = after
67            .strip_prefix(':')
68            .ok_or_else(|| RemoteTargetError::MissingPort(input.to_string()))?;
69        let port = port_str
70            .parse::<u16>()
71            .map_err(|source| RemoteTargetError::InvalidPort {
72                raw: input.to_string(),
73                source,
74            })?;
75        Ok((host, port))
76    } else {
77        // IPv4 or DNS: split on the **last** ':' so that hostnames containing
78        // letters work; an IPv4 literal `127.0.0.1:4242` only has one colon
79        // anyway.
80        let idx = input
81            .rfind(':')
82            .ok_or_else(|| RemoteTargetError::MissingPort(input.to_string()))?;
83        let host = &input[..idx];
84        let port_str = &input[idx + 1..];
85        let port = port_str
86            .parse::<u16>()
87            .map_err(|source| RemoteTargetError::InvalidPort {
88                raw: input.to_string(),
89                source,
90            })?;
91        Ok((host, port))
92    }
93}
94
95/// Parse a `--remote` target string into `(socket_addr, sni_server_name)`.
96///
97/// Per ADR-30-1: the `host` portion is preserved as-is for SNI so that
98/// hostname-bound certificates work; the IP is resolved separately for the
99/// TCP connect.
100///
101/// Behaviour:
102/// - `host:port` where `host` is an IP literal → `(SocketAddr,
103///   ServerName::IpAddress)`. No DNS round-trip.
104/// - `host:port` where `host` is a DNS name → DNS-resolve `host:port` to an
105///   IP, build `ServerName::DnsName(host.to_string())`. The hostname (NOT the
106///   resolved IP) is what rustls validates against the cert SAN.
107/// - `[ipv6]:port` → IPv6 literal handling (also IpAddress SNI).
108pub fn parse_remote_target(
109    input: &str,
110) -> Result<(SocketAddr, ServerName<'static>), RemoteTargetError> {
111    let (host, port) = split_host_port(input)?;
112
113    // 1. Resolve a SocketAddr for `connect`. If the host is an IP literal we
114    //    skip DNS entirely; otherwise we ask the OS resolver and pick the
115    //    first answer (this matches std's `(host, port).to_socket_addrs()`
116    //    semantics — the kernel honours `getaddrinfo` ordering, including
117    //    IPv6/IPv4 preference).
118    let socket_addr = if let Ok(ip) = IpAddr::from_str(host) {
119        SocketAddr::new(ip, port)
120    } else {
121        let mut iter = (host, port).to_socket_addrs().map_err(|source| {
122            RemoteTargetError::DnsLookupFailed {
123                host: host.to_string(),
124                source,
125            }
126        })?;
127        iter.next()
128            .ok_or_else(|| RemoteTargetError::DnsNoAddresses {
129                host: host.to_string(),
130            })?
131    };
132
133    // 2. Build the SNI ServerName from the **original host string**, not from
134    //    `socket_addr.ip()`. This is the whole point of HIGH-S2: when the
135    //    user types `--remote example.com:4242`, the rustls handshake
136    //    validates the certificate's CN/SAN against `example.com`, not
137    //    against whatever IP DNS happened to return.
138    let server_name = if let Ok(ip) = IpAddr::from_str(host) {
139        ServerName::IpAddress(ip.into())
140    } else {
141        ServerName::try_from(host.to_string()).map_err(|e| RemoteTargetError::InvalidSni {
142            host: host.to_string(),
143            message: e.to_string(),
144        })?
145    };
146
147    Ok((socket_addr, server_name))
148}
149
150/// Build a `TlsConnector` that trusts certificates from a PEM-encoded CA file.
151pub fn connector_from_ca(ca_path: &Path) -> Result<TlsConnector, TlsClientError> {
152    let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(ca_path)
153        .map_err(|_| TlsClientError::NoCertificates)?
154        .collect::<Result<Vec<_>, _>>()
155        .map_err(|_| TlsClientError::NoCertificates)?;
156
157    if certs.is_empty() {
158        return Err(TlsClientError::NoCertificates);
159    }
160
161    let mut root_store = RootCertStore::empty();
162    for cert in certs {
163        root_store.add(cert).map_err(TlsClientError::Rustls)?;
164    }
165
166    let config = ClientConfig::builder()
167        .with_root_certificates(root_store)
168        .with_no_client_auth();
169
170    Ok(TlsConnector::from(Arc::new(config)))
171}
172
173/// Build a `TlsConnector` that skips certificate verification (INSECURE — for development only).
174pub fn connector_insecure() -> TlsConnector {
175    let config = ClientConfig::builder()
176        .dangerous()
177        .with_custom_certificate_verifier(Arc::new(NoVerifier))
178        .with_no_client_auth();
179
180    TlsConnector::from(Arc::new(config))
181}
182
/// A certificate verifier that accepts any certificate (INSECURE).
///
/// Stateless unit type; [`connector_insecure`] installs it as the custom
/// verifier of the client configuration it builds.
#[derive(Debug)]
struct NoVerifier;
186
187impl tokio_rustls::rustls::client::danger::ServerCertVerifier for NoVerifier {
188    fn verify_server_cert(
189        &self,
190        _end_entity: &CertificateDer<'_>,
191        _intermediates: &[CertificateDer<'_>],
192        server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>,
193        _ocsp_response: &[u8],
194        _now: tokio_rustls::rustls::pki_types::UnixTime,
195    ) -> Result<tokio_rustls::rustls::client::danger::ServerCertVerified, tokio_rustls::rustls::Error>
196    {
197        // HIGH-S1 (partial): persistent log heartbeat. Each TLS handshake
198        // performed in `--tls-skip-verify` mode emits a warning on the
199        // dedicated `muxtop::insecure` target so operators can grep for it
200        // and so that long-running sessions cannot silently forget that the
201        // session is unauthenticated.
202        tracing::warn!(
203            target: "muxtop::insecure",
204            server_name = ?server_name,
205            "TLS certificate verification disabled — only safe in local dev"
206        );
207        Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
208    }
209
210    fn verify_tls12_signature(
211        &self,
212        _message: &[u8],
213        _cert: &CertificateDer<'_>,
214        _dss: &tokio_rustls::rustls::DigitallySignedStruct,
215    ) -> Result<
216        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
217        tokio_rustls::rustls::Error,
218    > {
219        Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
220    }
221
222    fn verify_tls13_signature(
223        &self,
224        _message: &[u8],
225        _cert: &CertificateDer<'_>,
226        _dss: &tokio_rustls::rustls::DigitallySignedStruct,
227    ) -> Result<
228        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
229        tokio_rustls::rustls::Error,
230    > {
231        Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
232    }
233
234    fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
235        tokio_rustls::rustls::crypto::aws_lc_rs::default_provider()
236            .signature_verification_algorithms
237            .supported_schemes()
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Generate a throwaway self-signed certificate for `localhost`, returned
    /// as `(certificate_pem, private_key_pem)`.
    fn make_self_signed_cert() -> (String, String) {
        let certified =
            rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
        let cert_pem = certified.cert.pem();
        let key_pem = certified.signing_key.serialize_pem();
        (cert_pem, key_pem)
    }

    /// Fail the calling test unless `sni` is the IP-address kind of name.
    fn assert_ip_sni(sni: &ServerName<'_>) {
        assert!(
            matches!(sni, ServerName::IpAddress(_)),
            "expected ServerName::IpAddress, got {sni:?}"
        );
    }

    #[test]
    fn test_connector_from_ca_valid() {
        let (ca_pem, _key_pem) = make_self_signed_cert();
        let mut ca_file = NamedTempFile::new().unwrap();
        ca_file.write_all(ca_pem.as_bytes()).unwrap();

        assert!(connector_from_ca(ca_file.path()).is_ok());
    }

    #[test]
    fn test_connector_from_ca_missing_file() {
        let missing = Path::new("/nonexistent/ca.pem");
        assert!(connector_from_ca(missing).is_err());
    }

    #[test]
    fn test_connector_insecure_builds() {
        // Building must simply not panic.
        let _ = connector_insecure();
    }

    #[test]
    fn test_parse_remote_target_ipv4_literal() {
        let expected: SocketAddr = "127.0.0.1:4242".parse().unwrap();

        let (addr, sni) = parse_remote_target("127.0.0.1:4242").unwrap();

        assert_eq!(addr, expected);
        assert_ip_sni(&sni);
    }

    #[test]
    fn test_parse_remote_target_ipv6_literal() {
        let expected: SocketAddr = "[::1]:4242".parse().unwrap();

        let (addr, sni) = parse_remote_target("[::1]:4242").unwrap();

        assert_eq!(addr, expected);
        assert_ip_sni(&sni);
    }

    #[test]
    fn test_parse_remote_target_hostname_uses_dns_sni() {
        // `localhost` is expected to resolve on ordinary machines and CI
        // runners, which lets this check the `DnsName("localhost")` SNI
        // without touching the open internet.
        let (addr, sni) = match parse_remote_target("localhost:4242") {
            Ok(pair) => pair,
            Err(err) => {
                // A hardened / no-network sandbox may fail even this lookup;
                // there is then nothing meaningful to assert, so skip
                // instead of failing.
                eprintln!("skipping hostname SNI test: {err}");
                return;
            }
        };

        assert_eq!(addr.port(), 4242);
        match sni {
            ServerName::DnsName(name) => assert_eq!(name.as_ref(), "localhost"),
            other => panic!("expected ServerName::DnsName, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_remote_target_missing_port() {
        let outcome = parse_remote_target("127.0.0.1");
        assert!(matches!(
            outcome.unwrap_err(),
            RemoteTargetError::MissingPort(_)
        ));
    }

    #[test]
    fn test_parse_remote_target_invalid_port() {
        let outcome = parse_remote_target("127.0.0.1:notaport");
        assert!(matches!(
            outcome.unwrap_err(),
            RemoteTargetError::InvalidPort { .. }
        ));
    }

    #[test]
    fn test_split_host_port_ipv6_no_port() {
        // A bracketed literal with nothing after `]` has no port.
        let outcome = split_host_port("[::1]");
        assert!(matches!(
            outcome.unwrap_err(),
            RemoteTargetError::MissingPort(_)
        ));
    }
}