1use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
7use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
8use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
9use rustls::{
10 ClientConfig, DigitallySignedStruct, DistinguishedName, ServerConfig, SignatureScheme,
11};
12use sha2::{Digest, Sha256};
13use std::sync::Arc;
14
/// Raw 32-byte SHA-256 digest of a DER-encoded certificate (see
/// `compute_fingerprint`); this is the value each side pins for its peer.
pub type Fingerprint = [u8; 32];

/// A certificate, its private key, and the certificate's pinned fingerprint.
///
/// `fingerprint` is expected to equal `compute_fingerprint(&cert_der)`;
/// `generate_self_signed_cert` fills it in that way.
/// NOTE(review): there is no `Debug` derive — presumably deliberate, since
/// `key_der` is secret key material that should not end up in logs.
#[derive(Clone)]
pub struct CertifiedKey {
    /// DER-encoded X.509 certificate presented to the peer.
    pub cert_der: Vec<u8>,
    /// DER-encoded private key matching `cert_der`.
    pub key_der: Vec<u8>,
    /// SHA-256 of `cert_der`.
    pub fingerprint: Fingerprint,
}
25
26pub fn generate_self_signed_cert() -> anyhow::Result<CertifiedKey> {
31 use rcgen::{CertificateParams, KeyPair};
32 let key_pair = KeyPair::generate_for(&rcgen::PKCS_ED25519)?;
34 let mut params = CertificateParams::default();
36 params.distinguished_name = rcgen::DistinguishedName::new();
37 params.distinguished_name.push(
38 rcgen::DnType::CommonName,
39 format!("rcp-{}", rand::random::<u64>()),
40 );
41 let cert = params.self_signed(&key_pair)?;
43 let cert_der = cert.der().to_vec();
44 let key_der = key_pair.serialize_der();
45 let fingerprint = compute_fingerprint(&cert_der);
47 Ok(CertifiedKey {
48 cert_der,
49 key_der,
50 fingerprint,
51 })
52}
53
54pub fn compute_fingerprint(cert_der: &[u8]) -> Fingerprint {
56 let mut hasher = Sha256::new();
57 hasher.update(cert_der);
58 hasher.finalize().into()
59}
60
61pub fn fingerprint_to_hex(fp: &Fingerprint) -> String {
63 hex::encode(fp)
64}
65
66pub fn fingerprint_from_hex(s: &str) -> anyhow::Result<Fingerprint> {
68 let bytes = hex::decode(s)?;
69 if bytes.len() != 32 {
70 anyhow::bail!(
71 "fingerprint must be 32 bytes (64 hex chars), got {}",
72 bytes.len()
73 );
74 }
75 let mut fp = [0u8; 32];
76 fp.copy_from_slice(&bytes);
77 Ok(fp)
78}
79
80pub fn create_server_config(cert_key: &CertifiedKey) -> anyhow::Result<Arc<ServerConfig>> {
84 let cert = CertificateDer::from(cert_key.cert_der.clone());
85 let key = PrivateKeyDer::try_from(cert_key.key_der.clone())
86 .map_err(|e| anyhow::anyhow!("invalid private key: {e}"))?;
87 let config = ServerConfig::builder()
88 .with_no_client_auth()
89 .with_single_cert(vec![cert], key)?;
90 Ok(Arc::new(config))
91}
92
93pub fn create_server_config_with_client_auth(
97 cert_key: &CertifiedKey,
98 expected_client_fingerprint: Fingerprint,
99) -> anyhow::Result<Arc<ServerConfig>> {
100 let cert = CertificateDer::from(cert_key.cert_der.clone());
101 let key = PrivateKeyDer::try_from(cert_key.key_der.clone())
102 .map_err(|e| anyhow::anyhow!("invalid private key: {e}"))?;
103 let client_verifier = Arc::new(FingerprintClientCertVerifier::new(
104 expected_client_fingerprint,
105 ));
106 let config = ServerConfig::builder()
107 .with_client_cert_verifier(client_verifier)
108 .with_single_cert(vec![cert], key)?;
109 Ok(Arc::new(config))
110}
111
112pub fn create_client_config(expected_server_fingerprint: Fingerprint) -> Arc<ClientConfig> {
116 let verifier = Arc::new(FingerprintServerCertVerifier::new(
117 expected_server_fingerprint,
118 ));
119 let config = ClientConfig::builder()
120 .dangerous()
121 .with_custom_certificate_verifier(verifier)
122 .with_no_client_auth();
123 Arc::new(config)
124}
125
126pub fn create_client_config_with_cert(
130 client_cert_key: &CertifiedKey,
131 expected_server_fingerprint: Fingerprint,
132) -> anyhow::Result<Arc<ClientConfig>> {
133 let verifier = Arc::new(FingerprintServerCertVerifier::new(
134 expected_server_fingerprint,
135 ));
136 let cert = CertificateDer::from(client_cert_key.cert_der.clone());
137 let key = PrivateKeyDer::try_from(client_cert_key.key_der.clone())
138 .map_err(|e| anyhow::anyhow!("invalid private key: {e}"))?;
139 let config = ClientConfig::builder()
140 .dangerous()
141 .with_custom_certificate_verifier(verifier)
142 .with_client_auth_cert(vec![cert], key)?;
143 Ok(Arc::new(config))
144}
145
146#[derive(Debug)]
148struct FingerprintServerCertVerifier {
149 expected_fingerprint: Fingerprint,
150}
151
152impl FingerprintServerCertVerifier {
153 fn new(expected_fingerprint: Fingerprint) -> Self {
154 Self {
155 expected_fingerprint,
156 }
157 }
158}
159
160impl ServerCertVerifier for FingerprintServerCertVerifier {
161 fn verify_server_cert(
162 &self,
163 end_entity: &CertificateDer<'_>,
164 _intermediates: &[CertificateDer<'_>],
165 _server_name: &ServerName<'_>,
166 _ocsp_response: &[u8],
167 _now: UnixTime,
168 ) -> Result<ServerCertVerified, rustls::Error> {
169 let actual_fingerprint = compute_fingerprint(end_entity.as_ref());
170 if actual_fingerprint == self.expected_fingerprint {
171 Ok(ServerCertVerified::assertion())
172 } else {
173 tracing::error!(
174 "TLS server certificate fingerprint mismatch: expected {}, got {}",
175 fingerprint_to_hex(&self.expected_fingerprint),
176 fingerprint_to_hex(&actual_fingerprint)
177 );
178 Err(rustls::Error::InvalidCertificate(
179 rustls::CertificateError::BadSignature,
180 ))
181 }
182 }
183 fn verify_tls12_signature(
184 &self,
185 _message: &[u8],
186 _cert: &CertificateDer<'_>,
187 _dss: &DigitallySignedStruct,
188 ) -> Result<HandshakeSignatureValid, rustls::Error> {
189 Ok(HandshakeSignatureValid::assertion())
191 }
192 fn verify_tls13_signature(
193 &self,
194 _message: &[u8],
195 _cert: &CertificateDer<'_>,
196 _dss: &DigitallySignedStruct,
197 ) -> Result<HandshakeSignatureValid, rustls::Error> {
198 Ok(HandshakeSignatureValid::assertion())
200 }
201 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
202 vec![
203 SignatureScheme::ED25519,
204 SignatureScheme::ECDSA_NISTP256_SHA256,
205 SignatureScheme::ECDSA_NISTP384_SHA384,
206 SignatureScheme::RSA_PSS_SHA256,
207 SignatureScheme::RSA_PSS_SHA384,
208 SignatureScheme::RSA_PSS_SHA512,
209 SignatureScheme::RSA_PKCS1_SHA256,
210 SignatureScheme::RSA_PKCS1_SHA384,
211 SignatureScheme::RSA_PKCS1_SHA512,
212 ]
213 }
214}
215
216#[derive(Debug)]
218struct FingerprintClientCertVerifier {
219 expected_fingerprint: Fingerprint,
220}
221
222impl FingerprintClientCertVerifier {
223 fn new(expected_fingerprint: Fingerprint) -> Self {
224 Self {
225 expected_fingerprint,
226 }
227 }
228}
229
230impl ClientCertVerifier for FingerprintClientCertVerifier {
231 fn root_hint_subjects(&self) -> &[DistinguishedName] {
232 &[]
233 }
234 fn verify_client_cert(
235 &self,
236 end_entity: &CertificateDer<'_>,
237 _intermediates: &[CertificateDer<'_>],
238 _now: UnixTime,
239 ) -> Result<ClientCertVerified, rustls::Error> {
240 let actual_fingerprint = compute_fingerprint(end_entity.as_ref());
241 if actual_fingerprint == self.expected_fingerprint {
242 Ok(ClientCertVerified::assertion())
243 } else {
244 tracing::error!(
245 "TLS client certificate fingerprint mismatch: expected {}, got {}",
246 fingerprint_to_hex(&self.expected_fingerprint),
247 fingerprint_to_hex(&actual_fingerprint)
248 );
249 Err(rustls::Error::InvalidCertificate(
250 rustls::CertificateError::BadSignature,
251 ))
252 }
253 }
254 fn verify_tls12_signature(
255 &self,
256 _message: &[u8],
257 _cert: &CertificateDer<'_>,
258 _dss: &DigitallySignedStruct,
259 ) -> Result<HandshakeSignatureValid, rustls::Error> {
260 Ok(HandshakeSignatureValid::assertion())
261 }
262 fn verify_tls13_signature(
263 &self,
264 _message: &[u8],
265 _cert: &CertificateDer<'_>,
266 _dss: &DigitallySignedStruct,
267 ) -> Result<HandshakeSignatureValid, rustls::Error> {
268 Ok(HandshakeSignatureValid::assertion())
269 }
270 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
271 vec![
272 SignatureScheme::ED25519,
273 SignatureScheme::ECDSA_NISTP256_SHA256,
274 SignatureScheme::ECDSA_NISTP384_SHA384,
275 SignatureScheme::RSA_PSS_SHA256,
276 SignatureScheme::RSA_PSS_SHA384,
277 SignatureScheme::RSA_PSS_SHA512,
278 SignatureScheme::RSA_PKCS1_SHA256,
279 SignatureScheme::RSA_PKCS1_SHA384,
280 SignatureScheme::RSA_PKCS1_SHA512,
281 ]
282 }
283 fn client_auth_mandatory(&self) -> bool {
284 true
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Makes ring the process-wide rustls provider. Every call after the first
    /// fails because a provider is already installed; that is intentionally ignored.
    fn install_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Fixture: a freshly generated certificate plus its pinned fingerprint.
    fn fresh_cert() -> (CertificateDer<'static>, Fingerprint) {
        install_crypto_provider();
        let generated = generate_self_signed_cert().unwrap();
        (CertificateDer::from(generated.cert_der), generated.fingerprint)
    }

    #[test]
    fn test_generate_cert_and_fingerprint() {
        install_crypto_provider();
        let generated = generate_self_signed_cert().unwrap();
        assert_eq!(generated.fingerprint.len(), 32);
        assert!(!generated.cert_der.is_empty());
        assert!(!generated.key_der.is_empty());
        // The cached fingerprint must equal a fresh hash of the certificate bytes.
        let recomputed = compute_fingerprint(&generated.cert_der);
        assert_eq!(generated.fingerprint, recomputed);
    }

    #[test]
    fn test_fingerprint_hex_roundtrip() {
        install_crypto_provider();
        let generated = generate_self_signed_cert().unwrap();
        let encoded = fingerprint_to_hex(&generated.fingerprint);
        assert_eq!(encoded.len(), 64);
        assert_eq!(fingerprint_from_hex(&encoded).unwrap(), generated.fingerprint);
    }

    #[test]
    fn test_fingerprint_from_hex_invalid() {
        // Valid hex, but only 2 bytes long.
        assert!(fingerprint_from_hex("abcd").is_err());
        // Not hex at all.
        assert!(fingerprint_from_hex("zzzz").is_err());
    }

    #[test]
    fn test_create_server_config() {
        install_crypto_provider();
        let generated = generate_self_signed_cert().unwrap();
        let server_config = create_server_config(&generated).unwrap();
        assert!(server_config.alpn_protocols.is_empty());
    }

    #[test]
    fn test_create_client_config() {
        install_crypto_provider();
        let client_config = create_client_config([0u8; 32]);
        assert!(client_config.alpn_protocols.is_empty());
    }

    #[test]
    fn test_server_fingerprint_verifier_accepts_matching() {
        let (cert, fingerprint) = fresh_cert();
        let verifier = FingerprintServerCertVerifier::new(fingerprint);
        let server_name = ServerName::try_from("rcp").unwrap();
        let outcome = verifier.verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_server_fingerprint_verifier_rejects_mismatch() {
        let (cert, _) = fresh_cert();
        // An all-zero pin will not match a real certificate hash.
        let verifier = FingerprintServerCertVerifier::new([0u8; 32]);
        let server_name = ServerName::try_from("rcp").unwrap();
        match verifier.verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now()) {
            Err(rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature)) => {}
            other => panic!("expected BadSignature error, got: {:?}", other),
        }
    }

    #[test]
    fn test_client_fingerprint_verifier_accepts_matching() {
        let (cert, fingerprint) = fresh_cert();
        let verifier = FingerprintClientCertVerifier::new(fingerprint);
        let outcome = verifier.verify_client_cert(&cert, &[], UnixTime::now());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_client_fingerprint_verifier_rejects_mismatch() {
        let (cert, _) = fresh_cert();
        let verifier = FingerprintClientCertVerifier::new([0u8; 32]);
        match verifier.verify_client_cert(&cert, &[], UnixTime::now()) {
            Err(rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature)) => {}
            other => panic!("expected BadSignature error, got: {:?}", other),
        }
    }

    #[test]
    fn test_client_verifier_requires_auth() {
        install_crypto_provider();
        assert!(FingerprintClientCertVerifier::new([0u8; 32]).client_auth_mandatory());
    }
}
407
/// End-to-end TLS handshakes over loopback TCP using the config builders above.
///
/// Each test binds `127.0.0.1:0` (an OS-assigned port), runs the server side in
/// a spawned task, drives the client side from the test body, and finally joins
/// the server task so that panics inside it fail the test.
#[cfg(test)]
mod integration_tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_rustls::{TlsAcceptor, TlsConnector};

    /// Installs ring as the process-wide rustls crypto provider. A second install
    /// returns an error, which `.ok()` discards so every test can call this.
    fn install_crypto_provider() {
        rustls::crypto::ring::default_provider()
            .install_default()
            .ok();
    }

    /// The client pins the server's real fingerprint. The handshake succeeds and
    /// the bytes written by the server arrive intact.
    #[tokio::test]
    async fn test_tls_handshake_success_with_matching_fingerprint() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let server_config = create_server_config(&server_cert).unwrap();
        let acceptor = TlsAcceptor::from(server_config);
        let client_config = create_client_config(server_cert.fingerprint);
        let connector = TlsConnector::from(client_config);
        // Bind before spawning so the client knows the address up front.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_acceptor = acceptor.clone();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut tls_stream = server_acceptor.accept(stream).await.unwrap();
            tls_stream.write_all(b"hello").await.unwrap();
            tls_stream.shutdown().await.unwrap();
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        // The server-cert verifier ignores the server name, so any valid name works.
        let server_name = ServerName::try_from("rcp").unwrap();
        let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
        let mut buf = [0u8; 5];
        tls_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");
        server_task.await.unwrap();
    }

    /// The client pins a fingerprint that does not match the server certificate.
    /// The client-side handshake must fail with a certificate-related error.
    #[tokio::test]
    async fn test_tls_handshake_fails_with_wrong_server_fingerprint() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let server_config = create_server_config(&server_cert).unwrap();
        let acceptor = TlsAcceptor::from(server_config);
        let wrong_fingerprint = [0xAB; 32];
        let client_config = create_client_config(wrong_fingerprint);
        let connector = TlsConnector::from(client_config);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_acceptor = acceptor.clone();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            // The server's own result is irrelevant here; it only drives the handshake.
            let _ = server_acceptor.accept(stream).await;
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let server_name = ServerName::try_from("rcp").unwrap();
        let result = connector.connect(server_name, stream).await;
        assert!(result.is_err(), "expected TLS handshake to fail");
        let err = result.unwrap_err();
        // Loose match on the error text. NOTE(review): the wording comes from
        // rustls/tokio-rustls and may change between versions.
        assert!(
            err.to_string().contains("certificate")
                || err.to_string().contains("Certificate")
                || err.to_string().contains("invalid"),
            "expected certificate error, got: {}",
            err
        );
        server_task.await.unwrap();
    }

    /// The server pins a client fingerprint that does not match the client
    /// certificate, so the server must reject the connection.
    #[tokio::test]
    async fn test_mutual_tls_fails_with_wrong_client_fingerprint() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let client_cert = generate_self_signed_cert().unwrap();
        let wrong_fingerprint = [0xCD; 32];
        let server_config =
            create_server_config_with_client_auth(&server_cert, wrong_fingerprint).unwrap();
        let acceptor = TlsAcceptor::from(server_config);
        // The client pins the server correctly, so only the client certificate is wrong.
        let client_config =
            create_client_config_with_cert(&client_cert, server_cert.fingerprint).unwrap();
        let connector = TlsConnector::from(client_config);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_acceptor = acceptor.clone();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let result = server_acceptor.accept(stream).await;
            assert!(result.is_err(), "expected server to reject client cert");
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let server_name = ServerName::try_from("rcp").unwrap();
        // The client may see the rejection as a handshake error, or only when it
        // first reads afterwards. Presumably this is because under TLS 1.3 the
        // client can finish its side before the server has checked the client
        // certificate.
        match connector.connect(server_name, stream).await {
            Ok(mut tls_stream) => {
                let mut buf = [0u8; 1];
                // A read must error or hit EOF (0 bytes); real data would mean the
                // server accepted the connection.
                let read_result = tls_stream.read(&mut buf).await;
                assert!(
                    read_result.is_err() || read_result.unwrap() == 0,
                    "expected read to fail or return EOF after server rejection"
                );
            }
            Err(_) => {
                // A client-side handshake failure is also an acceptable outcome.
            }
        }
        // The server-side rejection assertion is enforced by joining its task.
        server_task.await.unwrap();
    }

    /// Both sides pin each other's real fingerprints. Mutual TLS succeeds and
    /// data flows from server to client.
    #[tokio::test]
    async fn test_mutual_tls_success_with_matching_fingerprints() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let client_cert = generate_self_signed_cert().unwrap();
        let server_config =
            create_server_config_with_client_auth(&server_cert, client_cert.fingerprint).unwrap();
        let acceptor = TlsAcceptor::from(server_config);
        let client_config =
            create_client_config_with_cert(&client_cert, server_cert.fingerprint).unwrap();
        let connector = TlsConnector::from(client_config);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_acceptor = acceptor.clone();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut tls_stream = server_acceptor.accept(stream).await.unwrap();
            tls_stream.write_all(b"mutual").await.unwrap();
            tls_stream.shutdown().await.unwrap();
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let server_name = ServerName::try_from("rcp").unwrap();
        let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
        let mut buf = [0u8; 6];
        tls_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"mutual");
        server_task.await.unwrap();
    }
}
576}