1use std::sync::Arc;
2use rcgen::generate_simple_self_signed;
3use rustls::{ClientConfig, ServerConfig};
4use sha2::{Digest, Sha256};
5use tokio::net::{TcpListener, TcpStream};
6use tokio_rustls::{TlsAcceptor, TlsConnector, TlsStream};
7pub use rustls::pki_types::{CertificateDer, PrivateKeyDer};
8pub use rustls::crypto::ring;
9use crate::fingerprint_client_verifier::FingerprintClientVerifier;
10use crate::fingerprint_server_verifier::FingerprintServerVerifier;
11pub mod fingerprint_server_verifier;
12pub mod fingerprint_client_verifier;
13
14pub fn get_cert_fingerprint(cert: &CertificateDer<'_>) -> String {
16 let mut hasher = Sha256::new();
17 hasher.update(cert.as_ref()); let result = hasher.finalize();
19 hex::encode(result) }
21
22pub struct SecureConn {
24 pub stream: TlsStream<TcpStream>
25}
26
27impl SecureConn {
28 pub fn new(stream: TlsStream<TcpStream>) -> Self {
29 Self { stream }
30 }
31}
32
33pub fn generate_identity() -> anyhow::Result<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
35 let cert = generate_simple_self_signed(vec!["p2ps".into()])?;
36 Ok((cert.cert.into(), cert.signing_key.into()))
37}
38
39pub async fn accept(
41 listener: &TcpListener,
42 server_cert: CertificateDer<'static>,
43 server_key: PrivateKeyDer<'static>,
44 expected_client_hash: String, ) -> anyhow::Result<SecureConn> {
46
47 let verifier = Arc::new(FingerprintClientVerifier::new(expected_client_hash));
49
50 let config = ServerConfig::builder()
52 .with_client_cert_verifier(verifier) .with_single_cert(vec![server_cert], server_key)?;
54
55 let acceptor = TlsAcceptor::from(Arc::new(config));
56 let (socket, _) = listener.accept().await?;
57 let tls_stream = acceptor.accept(socket).await?;
58
59 Ok(SecureConn::new(tokio_rustls::TlsStream::Server(tls_stream)))
60}
61
62pub async fn connect(
64 addr: &str,
65 expected_server_hash: String,
66 client_cert: CertificateDer<'static>, client_key: PrivateKeyDer<'static>, ) -> anyhow::Result<SecureConn> {
69
70 let verifier = Arc::new(FingerprintServerVerifier::new(expected_server_hash));
72
73 let config = ClientConfig::builder()
74 .dangerous() .with_custom_certificate_verifier(verifier)
76 .with_client_auth_cert(vec![client_cert], client_key)?;
77
78 let connector = TlsConnector::from(Arc::new(config));
79 let domain = "p2ps".try_into()?;
80 let stream = TcpStream::connect(addr).await?;
81 let tls_stream = connector.connect(domain, stream).await?;
82
83 Ok(SecureConn::new(tokio_rustls::TlsStream::Client(tls_stream)))
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Installs `ring` as the process-wide rustls provider. The result is ignored
    /// because another test in the same process may already have installed it.
    fn install_provider() {
        let _ = ring::default_provider().install_default();
    }

    /// Both sides pin each other's fingerprint; the handshake succeeds and data
    /// flows in both directions.
    #[tokio::test]
    async fn test_mtls_connection() -> anyhow::Result<()> {
        install_provider();

        let (server_cert, server_key) = generate_identity()?;
        let (client_cert, client_key) = generate_identity()?;

        let server_hash = get_cert_fingerprint(&server_cert);
        let client_hash = get_cert_fingerprint(&client_cert);

        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let server_handle = tokio::spawn(async move {
            let mut secure_conn = accept(&listener, server_cert, server_key, client_hash)
                .await
                .expect("Server failed to accept");

            let mut buf = [0u8; 12];
            secure_conn.stream.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"Hello Server");

            secure_conn.stream.write_all(b"Hello Client").await.unwrap();
        });

        let mut client_conn =
            connect(&addr.to_string(), server_hash, client_cert, client_key).await?;

        client_conn.stream.write_all(b"Hello Server").await?;

        let mut buf = [0u8; 12];
        client_conn.stream.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"Hello Client");

        server_handle.await?;
        Ok(())
    }

    /// A client that pins a fingerprint which does not match the server's
    /// certificate must not complete the handshake, and the server must observe
    /// the aborted handshake as an error rather than hang or succeed.
    #[tokio::test]
    async fn test_rejects_unexpected_server_fingerprint() -> anyhow::Result<()> {
        install_provider();

        let (server_cert, server_key) = generate_identity()?;
        let (client_cert, client_key) = generate_identity()?;
        let (impostor_cert, _impostor_key) = generate_identity()?;

        let client_hash = get_cert_fingerprint(&client_cert);
        let wrong_server_hash = get_cert_fingerprint(&impostor_cert);

        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let server_handle = tokio::spawn(async move {
            accept(&listener, server_cert, server_key, client_hash).await
        });

        let client_result =
            connect(&addr.to_string(), wrong_server_hash, client_cert, client_key).await;
        assert!(
            client_result.is_err(),
            "client accepted a server certificate with the wrong fingerprint"
        );

        let server_result = server_handle.await?;
        assert!(
            server_result.is_err(),
            "server completed a handshake the client should have aborted"
        );
        Ok(())
    }
}