rustls-cng 0.7.0

Windows CNG API bridge for rustls
Documentation
/// Subject string used to look up the test CA certificate in both the client and the
/// server PFX stores; that certificate becomes the single trust anchor on each side.
const CA_SUBJECT: &str = "Inforce Technologies CA";

/// Password passed to `CertStore::from_pkcs12` when importing either PFX bundle below.
const PASSWORD: &str = "changeit";

/// PKCS#12 bundle imported by the client side of the test.
const CLIENT_PFX: &[u8] = include_bytes!("assets/rustls-client.pfx");

/// PKCS#12 bundle imported by the server side of the test.
const SERVER_PFX: &[u8] = include_bytes!("assets/rustls-server.pfx");

mod client {
    use std::{
        io::{Read, Write},
        net::{Shutdown, TcpStream},
        sync::Arc,
    };

    use rustls::{
        client::ResolvesClientCert, sign::CertifiedKey, ClientConfig, ClientConnection,
        RootCertStore, SignatureScheme, Stream,
    };
    use rustls_pki_types::CertificateDer;

    use rustls_cng::{
        signer::CngSigningKey,
        store::{CertStore, Pkcs12Flags},
    };

    #[derive(Debug)]
    pub struct ClientCertResolver(CertStore, String);

    fn get_chain(
        store: &CertStore,
        name: &str,
    ) -> anyhow::Result<(Vec<CertificateDer<'static>>, CngSigningKey)> {
        let contexts = store.find_by_subject_str(name)?;
        let context = contexts
            .first()
            .ok_or_else(|| anyhow::Error::msg("No client cert"))?;
        let key = context.acquire_key(true)?;
        let signing_key = CngSigningKey::new(key)?;
        let chain = context
            .as_chain_der()?
            .into_iter()
            .map(Into::into)
            .collect();
        Ok((chain, signing_key))
    }

    impl ResolvesClientCert for ClientCertResolver {
        fn resolve(
            &self,
            _acceptable_issuers: &[&[u8]],
            sigschemes: &[SignatureScheme],
        ) -> Option<Arc<CertifiedKey>> {
            let (chain, signing_key) = get_chain(&self.0, &self.1).ok()?;
            for scheme in signing_key.supported_schemes() {
                if sigschemes.contains(scheme) {
                    return Some(Arc::new(CertifiedKey {
                        cert: chain,
                        key: Arc::new(signing_key),
                        ocsp: None,
                    }));
                }
            }
            None
        }

        fn has_certs(&self) -> bool {
            true
        }
    }

    pub fn run_client(port: u16) -> anyhow::Result<()> {
        let store =
            CertStore::from_pkcs12(super::CLIENT_PFX, super::PASSWORD, Pkcs12Flags::default())?;

        let ca_cert_context = store.find_by_subject_str(super::CA_SUBJECT)?;
        let ca_cert = ca_cert_context.first().unwrap();

        let mut root_store = RootCertStore::empty();
        root_store.add(ca_cert.as_der().into())?;

        let client_config = ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_client_cert_resolver(Arc::new(ClientCertResolver(
                store,
                "rustls-client".to_string(),
            )));

        let mut connection =
            ClientConnection::new(Arc::new(client_config), "rustls-server".try_into()?)?;

        let mut client = TcpStream::connect(format!("localhost:{port}"))?;

        let mut tls_stream = Stream::new(&mut connection, &mut client);
        tls_stream.write_all(b"ping")?;
        tls_stream.sock.shutdown(Shutdown::Write)?;

        let mut buf = [0u8; 4];
        tls_stream.read_exact(&mut buf)?;
        assert_eq!(&buf, b"pong");

        tls_stream.sock.shutdown(Shutdown::Read)?;

        Ok(())
    }
}

mod server {
    use std::{
        io::{Read, Write},
        net::{Shutdown, TcpListener, TcpStream},
        sync::{mpsc::Sender, Arc},
    };

    use rustls::{
        server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier},
        sign::CertifiedKey,
        RootCertStore, ServerConfig, ServerConnection, Stream,
    };
    use rustls_cng::{
        signer::CngSigningKey,
        store::{CertStore, Pkcs12Flags},
    };

    #[derive(Debug)]
    pub struct ServerCertResolver(CertStore);

    impl ResolvesServerCert for ServerCertResolver {
        fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
            let name = client_hello.server_name()?;

            let contexts = self.0.find_by_subject_str(name).ok()?;

            let (context, key) = contexts.into_iter().find_map(|ctx| {
                let key = ctx.acquire_key(true).ok()?;
                CngSigningKey::new(key).ok().map(|key| (ctx, key))
            })?;

            let chain = context.as_chain_der().ok()?;
            let certs = chain.into_iter().map(Into::into).collect();

            Some(Arc::new(CertifiedKey {
                cert: certs,
                key: Arc::new(key),
                ocsp: None,
            }))
        }
    }

    fn handle_connection(mut stream: TcpStream, config: Arc<ServerConfig>) -> anyhow::Result<()> {
        let mut connection = ServerConnection::new(config)?;
        let mut tls_stream = Stream::new(&mut connection, &mut stream);

        let mut buf = [0u8; 4];
        tls_stream.read_exact(&mut buf)?;
        assert_eq!(&buf, b"ping");
        tls_stream.sock.shutdown(Shutdown::Read)?;
        tls_stream.write_all(b"pong")?;
        tls_stream.sock.shutdown(Shutdown::Write)?;

        Ok(())
    }

    pub fn run_server(sender: Sender<u16>) -> anyhow::Result<()> {
        let store =
            CertStore::from_pkcs12(super::SERVER_PFX, super::PASSWORD, Pkcs12Flags::default())?;

        let ca_cert_context = store.find_by_subject_str(super::CA_SUBJECT)?;
        let ca_cert = ca_cert_context.first().unwrap();

        let mut root_store = RootCertStore::empty();
        root_store.add(ca_cert.as_der().into())?;

        let verifier = WebPkiClientVerifier::builder(Arc::new(root_store))
            .build()
            .unwrap();

        let server_config = ServerConfig::builder()
            .with_client_cert_verifier(verifier)
            .with_cert_resolver(Arc::new(ServerCertResolver(store)));

        let server = TcpListener::bind("127.0.0.1:0")?;

        let _ = sender.send(server.local_addr()?.port());

        let stream = server.incoming().next().unwrap()?;
        let config = Arc::new(server_config);
        handle_connection(stream, config)?;

        Ok(())
    }
}

/// End-to-end check: a server and a client, both using CNG-held keys imported from PFX
/// bundles, complete a mutually authenticated TLS handshake and exchange `ping`/`pong`.
///
/// Both sides' failures are reported. A server that dies before publishing its port
/// fails the test instead of letting it pass without ever running the client.
#[test]
fn test_client_server() {
    let (tx, rx) = std::sync::mpsc::channel();

    let server_thread = std::thread::spawn(move || server::run_server(tx));

    // `recv` only errors when `tx` was dropped without sending, i.e. the server stopped
    // before it could listen; surface its own result instead of silently skipping.
    let port = match rx.recv() {
        Ok(port) => port,
        Err(_) => match server_thread.join() {
            Ok(result) => panic!("server exited before reporting its port: {result:?}"),
            Err(_) => panic!("server thread panicked before reporting its port"),
        },
    };

    // Fail on a client error before joining: if the client never connected, the server
    // would still be blocked in `accept` and the join would hang.
    client::run_client(port).expect("client failed");

    server_thread
        .join()
        .expect("server thread panicked")
        .expect("server failed");
}