Skip to main content

remote/
tls.rs

1//! TLS support for encrypted and authenticated connections.
2//!
3//! This module provides certificate generation and TLS configuration for:
4//! - Master↔rcpd connections (rcpd is server, master verifies fingerprint)
5//! - Source↔Destination connections (mutual TLS with client certificates)
6use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
7use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
8use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
9use rustls::{
10    ClientConfig, DigitallySignedStruct, DistinguishedName, ServerConfig, SignatureScheme,
11};
12use sha2::{Digest, Sha256};
13use std::sync::Arc;
14
/// SHA-256 digest of a DER-encoded certificate; the value peers pin instead of a CA chain.
pub type Fingerprint = [u8; 32];

/// An ephemeral certificate bundled with its private key and precomputed fingerprint.
#[derive(Clone)]
pub struct CertifiedKey {
    /// DER-encoded X.509 certificate.
    pub cert_der: Vec<u8>,
    /// DER-encoded private key belonging to `cert_der`.
    pub key_der: Vec<u8>,
    /// Fingerprint of `cert_der`, as produced by `compute_fingerprint`.
    pub fingerprint: Fingerprint,
}
25
26/// Generates an ephemeral self-signed certificate using Ed25519.
27///
28/// The certificate is valid for 1 day (doesn't matter since ephemeral).
29/// Returns the certificate, private key, and fingerprint.
30pub fn generate_self_signed_cert() -> anyhow::Result<CertifiedKey> {
31    use rcgen::{CertificateParams, KeyPair};
32    // generate Ed25519 key pair
33    let key_pair = KeyPair::generate_for(&rcgen::PKCS_ED25519)?;
34    // create certificate parameters with random subject
35    let mut params = CertificateParams::default();
36    params.distinguished_name = rcgen::DistinguishedName::new();
37    params.distinguished_name.push(
38        rcgen::DnType::CommonName,
39        format!("rcp-{}", rand::random::<u64>()),
40    );
41    // self-sign the certificate
42    let cert = params.self_signed(&key_pair)?;
43    let cert_der = cert.der().to_vec();
44    let key_der = key_pair.serialize_der();
45    // compute fingerprint
46    let fingerprint = compute_fingerprint(&cert_der);
47    Ok(CertifiedKey {
48        cert_der,
49        key_der,
50        fingerprint,
51    })
52}
53
54/// Computes SHA-256 fingerprint of a DER-encoded certificate.
55pub fn compute_fingerprint(cert_der: &[u8]) -> Fingerprint {
56    let mut hasher = Sha256::new();
57    hasher.update(cert_der);
58    hasher.finalize().into()
59}
60
61/// Converts a fingerprint to lowercase hex string (64 characters).
62pub fn fingerprint_to_hex(fp: &Fingerprint) -> String {
63    hex::encode(fp)
64}
65
66/// Parses a fingerprint from hex string.
67pub fn fingerprint_from_hex(s: &str) -> anyhow::Result<Fingerprint> {
68    let bytes = hex::decode(s)?;
69    if bytes.len() != 32 {
70        anyhow::bail!(
71            "fingerprint must be 32 bytes (64 hex chars), got {}",
72            bytes.len()
73        );
74    }
75    let mut fp = [0u8; 32];
76    fp.copy_from_slice(&bytes);
77    Ok(fp)
78}
79
80/// Creates a TLS server config for rcpd (no client authentication required).
81///
82/// Used for master→rcpd connections where master verifies rcpd's certificate.
83pub fn create_server_config(cert_key: &CertifiedKey) -> anyhow::Result<Arc<ServerConfig>> {
84    let cert = CertificateDer::from(cert_key.cert_der.clone());
85    let key = PrivateKeyDer::try_from(cert_key.key_der.clone())
86        .map_err(|e| anyhow::anyhow!("invalid private key: {e}"))?;
87    let config = ServerConfig::builder()
88        .with_no_client_auth()
89        .with_single_cert(vec![cert], key)?;
90    Ok(Arc::new(config))
91}
92
93/// Creates a TLS server config with client certificate verification.
94///
95/// Used for source→destination connections where source verifies destination's client cert.
96pub fn create_server_config_with_client_auth(
97    cert_key: &CertifiedKey,
98    expected_client_fingerprint: Fingerprint,
99) -> anyhow::Result<Arc<ServerConfig>> {
100    let cert = CertificateDer::from(cert_key.cert_der.clone());
101    let key = PrivateKeyDer::try_from(cert_key.key_der.clone())
102        .map_err(|e| anyhow::anyhow!("invalid private key: {e}"))?;
103    let client_verifier = Arc::new(FingerprintClientCertVerifier::new(
104        expected_client_fingerprint,
105    ));
106    let config = ServerConfig::builder()
107        .with_client_cert_verifier(client_verifier)
108        .with_single_cert(vec![cert], key)?;
109    Ok(Arc::new(config))
110}
111
112/// Creates a TLS client config that verifies the server's certificate fingerprint.
113///
114/// Used for master→rcpd connections where master has no client certificate.
115pub fn create_client_config(expected_server_fingerprint: Fingerprint) -> Arc<ClientConfig> {
116    let verifier = Arc::new(FingerprintServerCertVerifier::new(
117        expected_server_fingerprint,
118    ));
119    let config = ClientConfig::builder()
120        .dangerous()
121        .with_custom_certificate_verifier(verifier)
122        .with_no_client_auth();
123    Arc::new(config)
124}
125
126/// Creates a TLS client config with a client certificate.
127///
128/// Used for destination→source connections where destination presents its certificate.
129pub fn create_client_config_with_cert(
130    client_cert_key: &CertifiedKey,
131    expected_server_fingerprint: Fingerprint,
132) -> anyhow::Result<Arc<ClientConfig>> {
133    let verifier = Arc::new(FingerprintServerCertVerifier::new(
134        expected_server_fingerprint,
135    ));
136    let cert = CertificateDer::from(client_cert_key.cert_der.clone());
137    let key = PrivateKeyDer::try_from(client_cert_key.key_der.clone())
138        .map_err(|e| anyhow::anyhow!("invalid private key: {e}"))?;
139    let config = ClientConfig::builder()
140        .dangerous()
141        .with_custom_certificate_verifier(verifier)
142        .with_client_auth_cert(vec![cert], key)?;
143    Ok(Arc::new(config))
144}
145
146/// Server certificate verifier that checks the certificate's fingerprint.
147#[derive(Debug)]
148struct FingerprintServerCertVerifier {
149    expected_fingerprint: Fingerprint,
150}
151
152impl FingerprintServerCertVerifier {
153    fn new(expected_fingerprint: Fingerprint) -> Self {
154        Self {
155            expected_fingerprint,
156        }
157    }
158}
159
160impl ServerCertVerifier for FingerprintServerCertVerifier {
161    fn verify_server_cert(
162        &self,
163        end_entity: &CertificateDer<'_>,
164        _intermediates: &[CertificateDer<'_>],
165        _server_name: &ServerName<'_>,
166        _ocsp_response: &[u8],
167        _now: UnixTime,
168    ) -> Result<ServerCertVerified, rustls::Error> {
169        let actual_fingerprint = compute_fingerprint(end_entity.as_ref());
170        if actual_fingerprint == self.expected_fingerprint {
171            Ok(ServerCertVerified::assertion())
172        } else {
173            tracing::error!(
174                "TLS server certificate fingerprint mismatch: expected {}, got {}",
175                fingerprint_to_hex(&self.expected_fingerprint),
176                fingerprint_to_hex(&actual_fingerprint)
177            );
178            Err(rustls::Error::InvalidCertificate(
179                rustls::CertificateError::BadSignature,
180            ))
181        }
182    }
183    fn verify_tls12_signature(
184        &self,
185        _message: &[u8],
186        _cert: &CertificateDer<'_>,
187        _dss: &DigitallySignedStruct,
188    ) -> Result<HandshakeSignatureValid, rustls::Error> {
189        // we trust the certificate based on fingerprint, not signature chain
190        Ok(HandshakeSignatureValid::assertion())
191    }
192    fn verify_tls13_signature(
193        &self,
194        _message: &[u8],
195        _cert: &CertificateDer<'_>,
196        _dss: &DigitallySignedStruct,
197    ) -> Result<HandshakeSignatureValid, rustls::Error> {
198        // we trust the certificate based on fingerprint, not signature chain
199        Ok(HandshakeSignatureValid::assertion())
200    }
201    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
202        vec![
203            SignatureScheme::ED25519,
204            SignatureScheme::ECDSA_NISTP256_SHA256,
205            SignatureScheme::ECDSA_NISTP384_SHA384,
206            SignatureScheme::RSA_PSS_SHA256,
207            SignatureScheme::RSA_PSS_SHA384,
208            SignatureScheme::RSA_PSS_SHA512,
209            SignatureScheme::RSA_PKCS1_SHA256,
210            SignatureScheme::RSA_PKCS1_SHA384,
211            SignatureScheme::RSA_PKCS1_SHA512,
212        ]
213    }
214}
215
216/// Client certificate verifier that checks the certificate's fingerprint.
217#[derive(Debug)]
218struct FingerprintClientCertVerifier {
219    expected_fingerprint: Fingerprint,
220}
221
222impl FingerprintClientCertVerifier {
223    fn new(expected_fingerprint: Fingerprint) -> Self {
224        Self {
225            expected_fingerprint,
226        }
227    }
228}
229
230impl ClientCertVerifier for FingerprintClientCertVerifier {
231    fn root_hint_subjects(&self) -> &[DistinguishedName] {
232        &[]
233    }
234    fn verify_client_cert(
235        &self,
236        end_entity: &CertificateDer<'_>,
237        _intermediates: &[CertificateDer<'_>],
238        _now: UnixTime,
239    ) -> Result<ClientCertVerified, rustls::Error> {
240        let actual_fingerprint = compute_fingerprint(end_entity.as_ref());
241        if actual_fingerprint == self.expected_fingerprint {
242            Ok(ClientCertVerified::assertion())
243        } else {
244            tracing::error!(
245                "TLS client certificate fingerprint mismatch: expected {}, got {}",
246                fingerprint_to_hex(&self.expected_fingerprint),
247                fingerprint_to_hex(&actual_fingerprint)
248            );
249            Err(rustls::Error::InvalidCertificate(
250                rustls::CertificateError::BadSignature,
251            ))
252        }
253    }
254    fn verify_tls12_signature(
255        &self,
256        _message: &[u8],
257        _cert: &CertificateDer<'_>,
258        _dss: &DigitallySignedStruct,
259    ) -> Result<HandshakeSignatureValid, rustls::Error> {
260        Ok(HandshakeSignatureValid::assertion())
261    }
262    fn verify_tls13_signature(
263        &self,
264        _message: &[u8],
265        _cert: &CertificateDer<'_>,
266        _dss: &DigitallySignedStruct,
267    ) -> Result<HandshakeSignatureValid, rustls::Error> {
268        Ok(HandshakeSignatureValid::assertion())
269    }
270    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
271        vec![
272            SignatureScheme::ED25519,
273            SignatureScheme::ECDSA_NISTP256_SHA256,
274            SignatureScheme::ECDSA_NISTP384_SHA384,
275            SignatureScheme::RSA_PSS_SHA256,
276            SignatureScheme::RSA_PSS_SHA384,
277            SignatureScheme::RSA_PSS_SHA512,
278            SignatureScheme::RSA_PKCS1_SHA256,
279            SignatureScheme::RSA_PKCS1_SHA384,
280            SignatureScheme::RSA_PKCS1_SHA512,
281        ]
282    }
283    fn client_auth_mandatory(&self) -> bool {
284        true
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Installs ring as the process-wide rustls crypto provider.
    ///
    /// Another test may already have installed one; that error is deliberately ignored.
    fn install_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Installs the crypto provider, then generates a brand-new self-signed certificate.
    fn fresh_cert() -> CertifiedKey {
        install_crypto_provider();
        generate_self_signed_cert().unwrap()
    }

    #[test]
    fn test_generate_cert_and_fingerprint() {
        let generated = fresh_cert();
        assert_eq!(generated.fingerprint.len(), 32);
        assert!(!generated.cert_der.is_empty());
        assert!(!generated.key_der.is_empty());
        // hashing the same DER bytes again must reproduce the stored fingerprint
        let recomputed = compute_fingerprint(&generated.cert_der);
        assert_eq!(generated.fingerprint, recomputed);
    }

    #[test]
    fn test_fingerprint_hex_roundtrip() {
        let generated = fresh_cert();
        let encoded = fingerprint_to_hex(&generated.fingerprint);
        assert_eq!(encoded.len(), 64);
        let decoded = fingerprint_from_hex(&encoded).unwrap();
        assert_eq!(generated.fingerprint, decoded);
    }

    #[test]
    fn test_fingerprint_from_hex_invalid() {
        // "abcd" is valid hex but far too short; "zzzz" is not hex at all
        for bad in ["abcd", "zzzz"] {
            assert!(fingerprint_from_hex(bad).is_err());
        }
    }

    #[test]
    fn test_create_server_config() {
        let config = create_server_config(&fresh_cert()).unwrap();
        assert!(config.alpn_protocols.is_empty());
    }

    #[test]
    fn test_create_client_config() {
        install_crypto_provider();
        let config = create_client_config([0u8; 32]);
        assert!(config.alpn_protocols.is_empty());
    }

    #[test]
    fn test_server_fingerprint_verifier_accepts_matching() {
        let generated = fresh_cert();
        let verifier = FingerprintServerCertVerifier::new(generated.fingerprint);
        let name = ServerName::try_from("rcp").unwrap();
        let outcome = verifier.verify_server_cert(
            &CertificateDer::from(generated.cert_der),
            &[],
            &name,
            &[],
            UnixTime::now(),
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_server_fingerprint_verifier_rejects_mismatch() {
        let generated = fresh_cert();
        // an all-zero fingerprint cannot match a real certificate
        let verifier = FingerprintServerCertVerifier::new([0u8; 32]);
        let name = ServerName::try_from("rcp").unwrap();
        let outcome = verifier.verify_server_cert(
            &CertificateDer::from(generated.cert_der),
            &[],
            &name,
            &[],
            UnixTime::now(),
        );
        assert!(outcome.is_err());
        // the mismatch is reported with this specific error variant
        assert!(
            matches!(
                outcome,
                Err(rustls::Error::InvalidCertificate(
                    rustls::CertificateError::BadSignature
                ))
            ),
            "expected BadSignature error, got: {:?}",
            outcome
        );
    }

    #[test]
    fn test_client_fingerprint_verifier_accepts_matching() {
        let generated = fresh_cert();
        let verifier = FingerprintClientCertVerifier::new(generated.fingerprint);
        let outcome = verifier.verify_client_cert(
            &CertificateDer::from(generated.cert_der),
            &[],
            UnixTime::now(),
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_client_fingerprint_verifier_rejects_mismatch() {
        let generated = fresh_cert();
        // an all-zero fingerprint cannot match a real certificate
        let verifier = FingerprintClientCertVerifier::new([0u8; 32]);
        let outcome = verifier.verify_client_cert(
            &CertificateDer::from(generated.cert_der),
            &[],
            UnixTime::now(),
        );
        assert!(outcome.is_err());
        // the mismatch is reported with this specific error variant
        assert!(
            matches!(
                outcome,
                Err(rustls::Error::InvalidCertificate(
                    rustls::CertificateError::BadSignature
                ))
            ),
            "expected BadSignature error, got: {:?}",
            outcome
        );
    }

    #[test]
    fn test_client_verifier_requires_auth() {
        install_crypto_provider();
        assert!(FingerprintClientCertVerifier::new([0u8; 32]).client_auth_mandatory());
    }
}
407
#[cfg(test)]
mod integration_tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_rustls::{TlsAcceptor, TlsConnector};

    /// Installs ring as the process-wide rustls provider, ignoring "already installed".
    fn install_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Binds a listener on an OS-assigned loopback port and returns it with its address.
    async fn bind_loopback() -> (TcpListener, std::net::SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    /// Opens a TCP connection to `addr` and runs the TLS client handshake for name "rcp".
    async fn connect_tls(
        connector: &TlsConnector,
        addr: std::net::SocketAddr,
    ) -> std::io::Result<tokio_rustls::client::TlsStream<TcpStream>> {
        let tcp = TcpStream::connect(addr).await.unwrap();
        connector
            .connect(ServerName::try_from("rcp").unwrap(), tcp)
            .await
    }

    /// Test TLS handshake succeeds with correct fingerprints.
    #[tokio::test]
    async fn test_tls_handshake_success_with_matching_fingerprint() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let acceptor = TlsAcceptor::from(create_server_config(&server_cert).unwrap());
        // the client pins the server's real fingerprint
        let connector = TlsConnector::from(create_client_config(server_cert.fingerprint));
        let (listener, addr) = bind_loopback().await;
        let server = tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let mut tls = acceptor.accept(tcp).await.unwrap();
            tls.write_all(b"hello").await.unwrap();
            tls.shutdown().await.unwrap();
        });
        let mut tls = connect_tls(&connector, addr).await.unwrap();
        let mut received = [0u8; 5];
        tls.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"hello");
        server.await.unwrap();
    }

    /// Test TLS handshake fails when client has wrong server fingerprint.
    #[tokio::test]
    async fn test_tls_handshake_fails_with_wrong_server_fingerprint() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let acceptor = TlsAcceptor::from(create_server_config(&server_cert).unwrap());
        // the client pins a fingerprint that does not belong to the server
        let connector = TlsConnector::from(create_client_config([0xAB; 32]));
        let (listener, addr) = bind_loopback().await;
        let server = tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            // the server side may error once the client aborts the handshake
            let _ = acceptor.accept(tcp).await;
        });
        let Err(err) = connect_tls(&connector, addr).await else {
            panic!("expected TLS handshake to fail");
        };
        // the error text should point at certificate validation
        let text = err.to_string();
        assert!(
            text.contains("certificate") || text.contains("Certificate") || text.contains("invalid"),
            "expected certificate error, got: {}",
            err
        );
        server.await.unwrap();
    }

    /// Test mutual TLS handshake fails when server has wrong client fingerprint.
    #[tokio::test]
    async fn test_mutual_tls_fails_with_wrong_client_fingerprint() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let client_cert = generate_self_signed_cert().unwrap();
        // the server expects a client fingerprint the client does not have
        let acceptor = TlsAcceptor::from(
            create_server_config_with_client_auth(&server_cert, [0xCD; 32]).unwrap(),
        );
        // the client pins the correct server fingerprint
        let connector = TlsConnector::from(
            create_client_config_with_cert(&client_cert, server_cert.fingerprint).unwrap(),
        );
        let (listener, addr) = bind_loopback().await;
        let server = tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let verdict = acceptor.accept(tcp).await;
            assert!(verdict.is_err(), "expected server to reject client cert");
        });
        // In TLS 1.3 the server checks the client certificate after the client already
        // considers the handshake complete, so the rejection may surface either from
        // connect() itself or from the first read afterwards.
        if let Ok(mut tls) = connect_tls(&connector, addr).await {
            let mut probe = [0u8; 1];
            let read_result = tls.read(&mut probe).await;
            assert!(
                matches!(read_result, Err(_) | Ok(0)),
                "expected read to fail or return EOF after server rejection"
            );
        }
        // an Err from connect_tls means the handshake failed directly, which is also fine
        server.await.unwrap();
    }

    /// Test mutual TLS handshake succeeds with correct fingerprints.
    #[tokio::test]
    async fn test_mutual_tls_success_with_matching_fingerprints() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let client_cert = generate_self_signed_cert().unwrap();
        // each side pins the other's real fingerprint
        let acceptor = TlsAcceptor::from(
            create_server_config_with_client_auth(&server_cert, client_cert.fingerprint).unwrap(),
        );
        let connector = TlsConnector::from(
            create_client_config_with_cert(&client_cert, server_cert.fingerprint).unwrap(),
        );
        let (listener, addr) = bind_loopback().await;
        let server = tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let mut tls = acceptor.accept(tcp).await.unwrap();
            tls.write_all(b"mutual").await.unwrap();
            tls.shutdown().await.unwrap();
        });
        let mut tls = connect_tls(&connector, addr).await.unwrap();
        let mut received = [0u8; 6];
        tls.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"mutual");
        server.await.unwrap();
    }
}