Skip to main content

p2ps/
lib.rs

1use std::sync::Arc;
2use rcgen::generate_simple_self_signed;
3use rustls::{ClientConfig, ServerConfig};
4use sha2::{Digest, Sha256};
5use tokio::net::{TcpListener, TcpStream};
6use tokio_rustls::{TlsAcceptor, TlsConnector, TlsStream};
7pub use rustls::pki_types::{CertificateDer, PrivateKeyDer};
8pub use rustls::crypto::ring;
9use crate::fingerprint_client_verifier::FingerprintClientVerifier;
10use crate::fingerprint_server_verifier::FingerprintServerVerifier;
11pub mod fingerprint_server_verifier;
12pub mod fingerprint_client_verifier;
13
14/// Creates a hash out of a cert
15pub fn get_cert_fingerprint(cert: &CertificateDer<'_>) -> String {
16    let mut hasher = Sha256::new();
17    hasher.update(cert.as_ref()); // Hash the raw DER bytes
18    let result = hasher.finalize();
19    hex::encode(result) // Returns a string like "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
20}
21
22/// TLS Wrapper
23pub struct SecureConn {
24    pub stream: TlsStream<TcpStream>
25}
26
27impl SecureConn {
28    pub fn new(stream: TlsStream<TcpStream>) -> Self {
29        Self { stream }
30    }
31}
32
33/// Generates self signer cert and private key
34pub fn generate_identity() -> anyhow::Result<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
35    let cert = generate_simple_self_signed(vec!["p2ps".into()])?;
36    Ok((cert.cert.into(), cert.signing_key.into()))
37}
38
39/// Accepts an incomming request, upon aceptance a tls secured connection is created and ready to use
40pub async fn accept(
41    listener: &TcpListener,
42    server_cert: CertificateDer<'static>,
43    server_key: PrivateKeyDer<'static>,
44    expected_client_hash: String, // Added: The string hash we expect from the client
45) -> anyhow::Result<SecureConn> {
46
47    // Build our custom verifier using the expected string hash
48    let verifier = Arc::new(FingerprintClientVerifier::new(expected_client_hash));
49
50    // Configure the server to use it
51    let config = ServerConfig::builder()
52        .with_client_cert_verifier(verifier) // Enforce our custom mTLS hash check
53        .with_single_cert(vec![server_cert], server_key)?;
54
55    let acceptor = TlsAcceptor::from(Arc::new(config));
56    let (socket, _) = listener.accept().await?;
57    let tls_stream = acceptor.accept(socket).await?;
58
59    Ok(SecureConn::new(tokio_rustls::TlsStream::Server(tls_stream)))
60}
61
62/// Request connection, upon connection a tls secured connection is created and ready to use
63pub async fn connect(
64    addr: &str,
65    expected_server_hash: String,
66    client_cert: CertificateDer<'static>,      // Added: Client's identity
67    client_key: PrivateKeyDer<'static>,        // Added: Client's key
68) -> anyhow::Result<SecureConn> {
69
70    // Use our custom verifier
71    let verifier = Arc::new(FingerprintServerVerifier::new(expected_server_hash));
72
73    let config = ClientConfig::builder()
74        .dangerous() // Required to use custom verifiers
75        .with_custom_certificate_verifier(verifier)
76        .with_client_auth_cert(vec![client_cert], client_key)?;
77
78    let connector = TlsConnector::from(Arc::new(config));
79    let domain = "p2ps".try_into()?;
80    let stream = TcpStream::connect(addr).await?;
81    let tls_stream = connector.connect(domain, stream).await?;
82
83    Ok(SecureConn::new(tokio_rustls::TlsStream::Client(tls_stream)))
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// End-to-end check over loopback: both peers authenticate each other by
    /// certificate fingerprint and then exchange one message in each direction.
    #[tokio::test]
    async fn test_mtls_connection() -> anyhow::Result<()> {
        // Select `ring` as the process-wide rustls crypto provider. The result
        // is discarded (presumably it only fails when a provider is already
        // installed, which is harmless here).
        let _ = ring::default_provider().install_default();

        // Independent identities for the two peers.
        let (server_cert, server_key) = generate_identity()?;
        let (client_cert, client_key) = generate_identity()?;

        // Each side is told the fingerprint of the other side's certificate.
        let server_fingerprint = get_cert_fingerprint(&server_cert);
        let client_fingerprint = get_cert_fingerprint(&client_cert);

        // Bind to port 0 so the OS picks a free port, then read back the
        // address the client has to dial.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let server_addr = listener.local_addr()?.to_string();

        // Server side runs concurrently in its own task.
        let server_task = tokio::spawn(async move {
            let mut conn = accept(&listener, server_cert, server_key, client_fingerprint)
                .await
                .expect("Server failed to accept");

            // Expect the client's greeting first...
            let mut request = [0u8; 12];
            conn.stream.read_exact(&mut request).await.unwrap();
            assert_eq!(&request, b"Hello Server");

            // ...then answer it.
            conn.stream.write_all(b"Hello Client").await.unwrap();
        });

        // Client side runs on the test task itself.
        let mut conn = connect(&server_addr, server_fingerprint, client_cert, client_key).await?;
        conn.stream.write_all(b"Hello Server").await?;

        let mut reply = [0u8; 12];
        conn.stream.read_exact(&mut reply).await?;
        assert_eq!(&reply, b"Hello Client");

        // Surface any panic from the server task as a test failure.
        server_task.await?;
        Ok(())
    }
}