use std::sync::Arc;
use rcgen::generate_simple_self_signed;
use rustls::{ClientConfig, ServerConfig};
use sha2::{Digest, Sha256};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{TlsAcceptor, TlsConnector, TlsStream};
pub use rustls::pki_types::{CertificateDer, PrivateKeyDer};
pub use rustls::crypto::ring;
use crate::fingerprint_client_verifier::FingerprintClientVerifier;
use crate::fingerprint_server_verifier::FingerprintServerVerifier;
pub mod fingerprint_server_verifier;
pub mod fingerprint_client_verifier;
/// Computes the fingerprint of a DER-encoded certificate.
///
/// The fingerprint is the SHA-256 digest of the certificate's raw DER bytes,
/// rendered as a hex string by `hex::encode`. This is the value passed as
/// `expected_client_hash` / `expected_server_hash` to [`accept`] and [`connect`].
pub fn get_cert_fingerprint(cert: &CertificateDer<'_>) -> String {
    let digest = Sha256::digest(cert.as_ref());
    hex::encode(digest)
}
/// A TLS connection over TCP, as produced by [`accept`] (server side) and
/// [`connect`] (client side).
pub struct SecureConn {
    /// The underlying TLS stream; application data is read from and written to it directly.
    pub stream: TlsStream<TcpStream>,
}

impl SecureConn {
    /// Wraps an already-established TLS stream without performing any I/O.
    pub fn new(stream: TlsStream<TcpStream>) -> Self {
        SecureConn { stream }
    }
}
/// Generates a fresh self-signed certificate and its private key.
///
/// The certificate carries the single subject alternative name `"p2ps"`,
/// which is also the server name that [`connect`] presents during the handshake.
///
/// # Errors
///
/// Returns an error if `rcgen` fails to generate the certificate.
pub fn generate_identity() -> anyhow::Result<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
    let subject_alt_names = vec![String::from("p2ps")];
    let certified = generate_simple_self_signed(subject_alt_names)?;
    let cert_der: CertificateDer<'static> = certified.cert.into();
    let key_der: PrivateKeyDer<'static> = certified.signing_key.into();
    Ok((cert_der, key_der))
}
pub async fn accept(
listener: &TcpListener,
server_cert: CertificateDer<'static>,
server_key: PrivateKeyDer<'static>,
expected_client_hash: String, ) -> anyhow::Result<SecureConn> {
let verifier = Arc::new(FingerprintClientVerifier::new(expected_client_hash));
let config = ServerConfig::builder()
.with_client_cert_verifier(verifier) .with_single_cert(vec![server_cert], server_key)?;
let acceptor = TlsAcceptor::from(Arc::new(config));
let (socket, _) = listener.accept().await?;
let tls_stream = acceptor.accept(socket).await?;
Ok(SecureConn::new(tokio_rustls::TlsStream::Server(tls_stream)))
}
pub async fn connect(
addr: &str,
expected_server_hash: String,
client_cert: CertificateDer<'static>, client_key: PrivateKeyDer<'static>, ) -> anyhow::Result<SecureConn> {
let verifier = Arc::new(FingerprintServerVerifier::new(expected_server_hash));
let config = ClientConfig::builder()
.dangerous() .with_custom_certificate_verifier(verifier)
.with_client_auth_cert(vec![client_cert], client_key)?;
let connector = TlsConnector::from(Arc::new(config));
let domain = "p2ps".try_into()?;
let stream = TcpStream::connect(addr).await?;
let tls_stream = connector.connect(domain, stream).await?;
Ok(SecureConn::new(tokio_rustls::TlsStream::Client(tls_stream)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// End-to-end check over loopback. Two freshly generated identities must complete a
    /// fingerprint-pinned mutual-TLS handshake and exchange one message in each direction.
    #[tokio::test]
    async fn test_mtls_connection() -> anyhow::Result<()> {
        // `install_default` errors if a provider is already installed (e.g. by another
        // test in the same process); that is harmless here, so the result is ignored.
        let _ = ring::default_provider().install_default();

        let (server_cert, server_key) = generate_identity()?;
        let (client_cert, client_key) = generate_identity()?;
        // Each side pins the other's certificate by fingerprint.
        let server_hash = get_cert_fingerprint(&server_cert);
        let client_hash = get_cert_fingerprint(&client_cert);

        // Port 0 lets the OS choose a free port.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let server_handle = tokio::spawn(async move {
            let mut secure_conn =
                accept(&listener, server_cert, server_key, client_hash).await?;
            let mut buf = [0u8; 12];
            secure_conn.stream.read_exact(&mut buf).await?;
            assert_eq!(&buf, b"Hello Server");
            secure_conn.stream.write_all(b"Hello Client").await?;
            // `write_all` may return while TLS records are still buffered in the rustls
            // session; flush so the reply reaches the socket before the stream is dropped.
            secure_conn.stream.flush().await?;
            Ok::<(), anyhow::Error>(())
        });

        let mut client_conn =
            connect(&addr.to_string(), server_hash, client_cert, client_key).await?;
        client_conn.stream.write_all(b"Hello Server").await?;
        // Flush before waiting for the reply. A request left buffered locally would leave
        // both sides blocked in `read_exact`.
        client_conn.stream.flush().await?;
        let mut buf = [0u8; 12];
        client_conn.stream.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"Hello Client");

        // Outer `?` surfaces a panic/cancellation of the server task, inner `?` its error.
        server_handle.await??;
        Ok(())
    }
}