// rustls-cng 0.7.0
//
// Windows CNG API bridge for rustls.
// Documentation: see the rustls-cng crate docs.
use std::{
    io::{Read, Write},
    net::{Shutdown, TcpStream},
    path::PathBuf,
    sync::Arc,
};

use clap::Parser;
use rustls::{
    client::ResolvesClientCert, sign::CertifiedKey, ClientConfig, ClientConnection, RootCertStore,
    SignatureScheme, Stream,
};
use rustls_cng::{
    signer::CngSigningKey,
    store::{CertStore, CertStoreType, Pkcs12Flags},
};
use rustls_pki_types::{CertificateDer, ServerName};

/// TCP port the sample client always connects to on `server_address`.
const PORT: u16 = 8_000;

/// Client-certificate resolver for rustls that looks the certificate and its
/// private key up in a CNG certificate store every time `resolve` is called.
///
/// `Debug` is derived because `ResolvesClientCert` requires it (rustls trait
/// bound).
#[derive(Debug)]
pub struct ClientCertResolver {
    /// Store searched for the client certificate (PFX import or system store).
    store: CertStore,
    /// Subject string passed to `CertStore::find_by_subject_str` to find the cert.
    cert_name: String,
    /// Optional PIN; when set it is applied to the acquired key via `set_pin`.
    pin: Option<String>,
}

fn get_chain(
    store: &CertStore,
    name: &str,
) -> anyhow::Result<(Vec<CertificateDer<'static>>, CngSigningKey)> {
    let contexts = store.find_by_subject_str(name)?;
    let context = contexts
        .first()
        .ok_or_else(|| anyhow::Error::msg("No client cert"))?;
    let key = context.acquire_key(true)?;
    let signing_key = CngSigningKey::new(key)?;
    let chain = context
        .as_chain_der()?
        .into_iter()
        .map(Into::into)
        .collect();
    Ok((chain, signing_key))
}

impl ResolvesClientCert for ClientCertResolver {
    /// Loads the client certificate chain and key from the store and offers
    /// them if the key supports at least one signature scheme the server
    /// accepts.
    ///
    /// The store is queried on every call. Any failure is reported on stderr
    /// and results in `None`, so rustls continues the handshake without a
    /// client certificate instead of panicking. Failures include: the
    /// certificate is not found, the key cannot be acquired, the PIN is
    /// rejected, or there is no common signature scheme.
    fn resolve(
        &self,
        _acceptable_issuers: &[&[u8]],
        sigschemes: &[SignatureScheme],
    ) -> Option<Arc<CertifiedKey>> {
        println!("Server sig schemes: {sigschemes:#?}");

        let (chain, signing_key) = match get_chain(&self.store, &self.cert_name) {
            Ok(found) => found,
            Err(e) => {
                // `{:#}` prints the whole anyhow context chain.
                eprintln!("Client cert '{}' unavailable: {e:#}", self.cert_name);
                return None;
            }
        };

        if let Some(pin) = &self.pin {
            if let Err(e) = signing_key.key().set_pin(pin) {
                eprintln!("Failed to set PIN on client key: {e}");
                return None;
            }
        }

        let compatible = signing_key
            .supported_schemes()
            .iter()
            .any(|scheme| sigschemes.contains(scheme));
        if !compatible {
            eprintln!("Client key supports none of the server's signature schemes");
            return None;
        }

        Some(Arc::new(CertifiedKey {
            cert: chain,
            key: Arc::new(signing_key),
            ocsp: None,
        }))
    }

    /// Always reports that a client certificate is available; the actual
    /// lookup (and any failure) is deferred to `resolve`.
    fn has_certs(&self) -> bool {
        true
    }
}

// Command-line arguments for the sample client.
//
// Plain `//` comments are used on purpose: with clap's derive, `///` doc
// comments feed the generated about/help text and could change `--help` output.
#[derive(Parser)]
#[clap(name = "rustls-client-sample")]
struct AppParams {
    // Subject string of the CA certificate. `main` looks it up in the same store
    // as the client certificate and adds it as the only trust anchor.
    #[clap(
        short = 'c',
        long = "ca-cert",
        help = "CA cert name to verify the peer certificate"
    )]
    ca_cert: String,

    // Optional PFX file; when absent, `main` opens the CurrentUser "my" store.
    #[clap(short = 'k', long = "keystore", help = "Use external PFX keystore")]
    keystore: Option<PathBuf>,

    // Used both as the PFX import password (empty string when absent) and, in
    // `ClientCertResolver`, as the PIN set on the client key.
    #[clap(
        short = 'p',
        long = "password",
        help = "Keystore password or token pin"
    )]
    password: Option<String>,

    // TLS server name passed to `ClientConnection::new` (sent as SNI).
    #[clap(
        short = 's',
        long = "server-name",
        help = "Server name for TLS SNI extension"
    )]
    server_name: String,

    // Subject string used by the resolver to find the client cert and its key.
    #[clap(
        short = 'l',
        long = "client-cert",
        help = "Client cert name for client auth"
    )]
    client_cert: String,

    // Host to connect to; the port is always `PORT`.
    #[clap(help = "Server address")]
    server_address: String,
}

fn main() -> anyhow::Result<()> {
    let params: AppParams = AppParams::parse();

    let store = if let Some(ref keystore) = params.keystore {
        let data = std::fs::read(keystore)?;
        CertStore::from_pkcs12(
            &data,
            params.password.as_deref().unwrap_or_default(),
            Pkcs12Flags::default(),
        )?
    } else {
        CertStore::open(CertStoreType::CurrentUser, "my")?
    };

    let ca_cert_context = store.find_by_subject_str(&params.ca_cert)?;
    let ca_cert = ca_cert_context.first().unwrap();

    let mut root_store = RootCertStore::empty();
    root_store.add(ca_cert.as_der().into())?;

    let client_config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_client_cert_resolver(Arc::new(ClientCertResolver {
            store,
            cert_name: params.client_cert.clone(),
            pin: params.password.clone(),
        }));

    let server_name = ServerName::try_from(params.server_name.as_str())?.to_owned();
    let mut connection = ClientConnection::new(Arc::new(client_config), server_name)?;
    let mut client = TcpStream::connect(format!("{}:{}", params.server_address, PORT))?;

    let mut tls_stream = Stream::new(&mut connection, &mut client);
    tls_stream.write_all(b"ping")?;
    tls_stream.sock.shutdown(Shutdown::Write)?;

    let mut buf = [0u8; 4];
    tls_stream.read_exact(&mut buf)?;
    println!("{}", String::from_utf8_lossy(&buf));

    tls_stream.sock.shutdown(Shutdown::Read)?;

    Ok(())
}