use std::path::Path;
use std::sync::Arc;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use zamsync_core::{ZamError, ZamResult};
/// Raw PEM material for this node's mutual-TLS identity.
///
/// The bytes are stored unparsed; they are decoded each time a rustls
/// server or client configuration is built from them.
pub struct TlsConfig {
    /// PEM bytes of this node's certificate.
    cert_pem: Vec<u8>,
    /// PEM bytes of the private key matching `cert_pem`.
    key_pem: Vec<u8>,
    /// PEM bytes of the CA certificate(s) used to authenticate peers.
    ca_pem: Vec<u8>,
}
impl TlsConfig {
    /// Reads PEM-encoded credentials from disk.
    ///
    /// * `cert_path` - this node's certificate, optionally followed by intermediate
    ///   CA certificates (leaf first).
    /// * `key_path` - the private key for that certificate.
    /// * `ca_path` - one or more CA certificates trusted to authenticate peers.
    ///
    /// The files are only read here; their contents are parsed when a rustls
    /// configuration is built.
    ///
    /// # Errors
    /// Propagates (via `?`) the I/O error of the first file that cannot be read.
    pub fn from_files(
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
        ca_path: impl AsRef<Path>,
    ) -> ZamResult<Self> {
        Ok(Self {
            cert_pem: std::fs::read(cert_path)?,
            key_pem: std::fs::read(key_path)?,
            ca_pem: std::fs::read(ca_path)?,
        })
    }

    /// Wraps in-memory PEM strings (same layout as [`TlsConfig::from_files`])
    /// without parsing them.
    pub fn from_pem(cert_pem: String, key_pem: String, ca_pem: String) -> Self {
        Self {
            cert_pem: cert_pem.into_bytes(),
            key_pem: key_pem.into_bytes(),
            ca_pem: ca_pem.into_bytes(),
        }
    }

    /// Parses every certificate in `pem`, in file order.
    ///
    /// `what` names the source in the error returned when `pem` holds no
    /// certificate at all.
    fn parse_certs(pem: &[u8], what: &str) -> ZamResult<Vec<CertificateDer<'static>>> {
        let mut reader = pem;
        let certs = rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
        if certs.is_empty() {
            return Err(ZamError::Config(format!("no certificate in {what} file")));
        }
        Ok(certs)
    }

    /// Loads this node's full certificate chain (leaf first, then intermediates).
    ///
    /// The whole chain is kept so peers that only trust the root can still build
    /// a path through any intermediates included in the cert file.
    fn load_cert_chain(&self) -> ZamResult<Vec<CertificateDer<'static>>> {
        Self::parse_certs(&self.cert_pem, "cert")
    }

    /// Loads the first private key found in the key PEM.
    fn load_key(&self) -> ZamResult<PrivateKeyDer<'static>> {
        rustls_pemfile::private_key(&mut self.key_pem.as_slice())?
            .ok_or_else(|| ZamError::Config("no private key in key file".into()))
    }

    /// Builds a root store holding every CA certificate in the CA PEM.
    ///
    /// Trusting the whole bundle (not just its first entry) lets operators ship
    /// old and new CAs side by side during a rotation.
    fn load_ca_roots(&self) -> ZamResult<rustls::RootCertStore> {
        let mut roots = rustls::RootCertStore::empty();
        for ca in Self::parse_certs(&self.ca_pem, "CA")? {
            roots
                .add(ca)
                .map_err(|e| ZamError::Config(format!("invalid CA cert: {e}")))?;
        }
        Ok(roots)
    }

    /// Builds the server-side rustls configuration.
    ///
    /// It presents this node's certificate chain, and it verifies client
    /// certificates against the configured CA certificates using rustls'
    /// WebPKI client verifier with default settings.
    ///
    /// # Errors
    /// Returns an error if the certificate, key, or CA material cannot be parsed,
    /// or if rustls rejects the verifier or the certificate/key pair.
    pub(crate) fn server_config(&self) -> ZamResult<Arc<rustls::ServerConfig>> {
        let chain = self.load_cert_chain()?;
        let key = self.load_key()?;
        let client_roots = self.load_ca_roots()?;
        let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(client_roots))
            .build()
            .map_err(|e| ZamError::Config(format!("client verifier: {e}")))?;
        let config = rustls::ServerConfig::builder()
            .with_client_cert_verifier(client_verifier)
            .with_single_cert(chain, key)
            .map_err(|e| ZamError::Config(format!("server TLS config: {e}")))?;
        Ok(Arc::new(config))
    }

    /// Builds the client-side rustls configuration.
    ///
    /// Servers are authenticated only against the configured CA certificates,
    /// and this node's certificate chain is offered as the client certificate.
    ///
    /// # Errors
    /// Returns an error if the certificate, key, or CA material cannot be parsed,
    /// or if rustls rejects the certificate/key pair.
    pub(crate) fn client_config(&self) -> ZamResult<Arc<rustls::ClientConfig>> {
        let chain = self.load_cert_chain()?;
        let key = self.load_key()?;
        let server_roots = self.load_ca_roots()?;
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(server_roots)
            .with_client_auth_cert(chain, key)
            .map_err(|e| ZamError::Config(format!("client TLS config: {e}")))?;
        Ok(Arc::new(config))
    }
}
/// PEM output of [`generate_credentials`]: a new CA and one node identity issued by it.
///
/// `ca_key_pem` is the CA's private key. Keep it secret; it is what later lets
/// additional node certificates be issued.
pub struct GeneratedCredentials {
    /// Self-signed CA certificate.
    pub ca_cert_pem: String,
    /// Private key of the CA.
    pub ca_key_pem: String,
    /// Node certificate signed by the CA.
    pub node_cert_pem: String,
    /// Private key matching `node_cert_pem`.
    pub node_key_pem: String,
}
pub fn generate_credentials() -> ZamResult<GeneratedCredentials> {
let ca_key = rcgen::KeyPair::generate()
.map_err(|e| ZamError::Config(format!("CA key generation failed: {e}")))?;
let mut ca_params = rcgen::CertificateParams::new(vec!["ZamSync CA".to_string()])
.map_err(|e| ZamError::Config(format!("CA params: {e}")))?;
ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
let ca_cert = ca_params
.self_signed(&ca_key)
.map_err(|e| ZamError::Config(format!("CA self-sign failed: {e}")))?;
let node_key = rcgen::KeyPair::generate()
.map_err(|e| ZamError::Config(format!("node key generation failed: {e}")))?;
let node_params = rcgen::CertificateParams::new(vec!["zamsync.local".to_string()])
.map_err(|e| ZamError::Config(format!("node params: {e}")))?;
let node_cert = node_params
.signed_by(&node_key, &ca_cert, &ca_key)
.map_err(|e| ZamError::Config(format!("node cert signing failed: {e}")))?;
Ok(GeneratedCredentials {
ca_cert_pem: ca_cert.pem(),
ca_key_pem: ca_key.serialize_pem(),
node_cert_pem: node_cert.pem(),
node_key_pem: node_key.serialize_pem(),
})
}
/// PEM output of [`sign_node_cert`]: a new node identity plus the CA it chains to.
///
/// Unlike [`GeneratedCredentials`], this carries no CA private key. It is meant
/// to be handed to the node being enrolled.
pub struct SignedNodeCredentials {
    /// CA certificate the node should trust (echoed back from the caller).
    pub ca_cert_pem: String,
    /// Freshly issued node certificate.
    pub node_cert_pem: String,
    /// Private key matching `node_cert_pem`.
    pub node_key_pem: String,
}
pub fn sign_node_cert(ca_cert_pem: &str, ca_key_pem: &str) -> ZamResult<SignedNodeCredentials> {
let ca_key = rcgen::KeyPair::from_pem(ca_key_pem)
.map_err(|e| ZamError::Config(format!("parse CA key: {e}")))?;
let mut ca_params = rcgen::CertificateParams::new(vec!["ZamSync CA".to_string()])
.map_err(|e| ZamError::Config(format!("CA params: {e}")))?;
ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
let ca_cert = ca_params
.self_signed(&ca_key)
.map_err(|e| ZamError::Config(format!("reconstruct CA cert: {e}")))?;
let node_key = rcgen::KeyPair::generate()
.map_err(|e| ZamError::Config(format!("node key generation failed: {e}")))?;
let node_params = rcgen::CertificateParams::new(vec!["zamsync.local".to_string()])
.map_err(|e| ZamError::Config(format!("node params: {e}")))?;
let node_cert = node_params
.signed_by(&node_key, &ca_cert, &ca_key)
.map_err(|e| ZamError::Config(format!("node cert signing failed: {e}")))?;
Ok(SignedNodeCredentials {
ca_cert_pem: ca_cert_pem.to_owned(),
node_cert_pem: node_cert.pem(),
node_key_pem: node_key.serialize_pem(),
})
}
/// Makes ring the process-wide rustls `CryptoProvider`, unless one is already set.
///
/// Safe to call repeatedly. Only the first successful installation in a process
/// takes effect.
pub fn install_crypto_provider() {
    let provider = rustls::crypto::ring::default_provider();
    // `install_default` only fails when a process-wide provider was already
    // installed (e.g. by an earlier call), which is acceptable here, so the
    // result is intentionally discarded.
    let _ = provider.install_default();
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::pki_types::UnixTime;
    use rustls::server::danger::ClientCertVerifier;

    /// Parses every certificate in `pem`, panicking on malformed input.
    fn certs_from_pem(pem: &str) -> Vec<CertificateDer<'static>> {
        rustls_pemfile::certs(&mut pem.as_bytes())
            .collect::<Result<Vec<_>, _>>()
            .expect("parse PEM certificates")
    }

    /// Builds a WebPKI client-certificate verifier whose only trust anchor is the
    /// first certificate in `ca_pem`.
    fn client_verifier(ca_pem: &str) -> Arc<dyn ClientCertVerifier> {
        let ca = certs_from_pem(ca_pem)
            .into_iter()
            .next()
            .expect("CA PEM contains a certificate");
        let mut roots = rustls::RootCertStore::empty();
        roots.add(ca).expect("add CA to root store");
        rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
            .build()
            .expect("build verifier")
    }

    #[test]
    fn test_sign_node_cert_produces_valid_chain() {
        install_crypto_provider();
        let hub = generate_credentials().expect("hub keygen failed");
        let clinic_a =
            sign_node_cert(&hub.ca_cert_pem, &hub.ca_key_pem).expect("clinic_a signing failed");
        let clinic_b =
            sign_node_cert(&hub.ca_cert_pem, &hub.ca_key_pem).expect("clinic_b signing failed");
        assert_eq!(clinic_a.ca_cert_pem, hub.ca_cert_pem);
        assert_eq!(clinic_b.ca_cert_pem, hub.ca_cert_pem);
        assert_ne!(clinic_a.node_cert_pem, clinic_b.node_cert_pem);
        assert_ne!(clinic_a.node_key_pem, clinic_b.node_key_pem);

        // Building rustls configs does not check that a leaf chains to the CA,
        // so verify each issued certificate against the hub CA explicitly.
        let verifier = client_verifier(&hub.ca_cert_pem);
        for (label, cert_pem) in [
            ("clinic_a", &clinic_a.node_cert_pem),
            ("clinic_b", &clinic_b.node_cert_pem),
        ] {
            let leaf = certs_from_pem(cert_pem);
            if let Err(e) = verifier.verify_client_cert(&leaf[0], &[], UnixTime::now()) {
                panic!("{label} cert must chain to the hub CA: {e}");
            }
        }

        let hub_tls = TlsConfig::from_pem(
            hub.node_cert_pem.clone(),
            hub.node_key_pem.clone(),
            hub.ca_cert_pem.clone(),
        );
        let clinic_a_tls = TlsConfig::from_pem(
            clinic_a.node_cert_pem.clone(),
            clinic_a.node_key_pem.clone(),
            clinic_a.ca_cert_pem.clone(),
        );
        hub_tls.server_config().expect("hub server_config failed");
        clinic_a_tls
            .client_config()
            .expect("clinic_a client_config failed");
    }

    #[test]
    fn test_rogue_node_with_own_ca_rejected() {
        install_crypto_provider();
        let hub = generate_credentials().expect("hub keygen");
        let rogue = generate_credentials().expect("rogue keygen");
        let hub_tls = TlsConfig::from_pem(
            hub.node_cert_pem.clone(),
            hub.node_key_pem.clone(),
            hub.ca_cert_pem.clone(),
        );
        // The rogue node presents its own cert but claims to trust the hub CA.
        let rogue_tls = TlsConfig::from_pem(
            rogue.node_cert_pem.clone(),
            rogue.node_key_pem.clone(),
            hub.ca_cert_pem.clone(),
        );
        hub_tls.server_config().expect("hub server_config failed");

        let verifier = client_verifier(&hub.ca_cert_pem);
        // Positive control: the verifier must accept a legitimately issued cert,
        // otherwise the rejection asserted below would prove nothing.
        let hub_leaf = certs_from_pem(&hub.node_cert_pem);
        verifier
            .verify_client_cert(&hub_leaf[0], &[], UnixTime::now())
            .expect("hub's own node cert must be accepted");

        let rogue_leaf = certs_from_pem(&rogue.node_cert_pem);
        let result = verifier.verify_client_cert(&rogue_leaf[0], &[], UnixTime::now());
        assert!(
            result.is_err(),
            "rogue cert must be rejected by hub CA verifier"
        );
        rogue_tls
            .client_config()
            .expect("client config builds -- rejection happens at handshake");
    }

    #[test]
    fn test_expired_cert_rejected_by_verifier() {
        install_crypto_provider();
        let ca_key = rcgen::KeyPair::generate().expect("CA key");
        let mut ca_params =
            rcgen::CertificateParams::new(vec!["ZamSync CA".to_string()]).expect("CA params");
        ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
        let ca_cert = ca_params.self_signed(&ca_key).expect("CA self-sign");
        let node_key = rcgen::KeyPair::generate().expect("node key");
        let mut node_params =
            rcgen::CertificateParams::new(vec!["zamsync.local".to_string()]).expect("node params");
        // Valid only during the first day of the Unix epoch.
        node_params.not_before =
            time::OffsetDateTime::from_unix_timestamp(0).expect("epoch start");
        node_params.not_after =
            time::OffsetDateTime::from_unix_timestamp(86400).expect("epoch + 1 day");
        let expired_cert = node_params
            .signed_by(&node_key, &ca_cert, &ca_key)
            .expect("sign expired cert");

        let verifier = client_verifier(&ca_cert.pem());
        let expired_der = certs_from_pem(&expired_cert.pem());

        // Positive control: inside its validity window the same cert verifies,
        // so the failure below can only be caused by the expiry date.
        let inside_window = UnixTime::since_unix_epoch(std::time::Duration::from_secs(43_200));
        verifier
            .verify_client_cert(&expired_der[0], &[], inside_window)
            .expect("cert must verify inside its validity window");

        let result = verifier.verify_client_cert(&expired_der[0], &[], UnixTime::now());
        assert!(
            result.is_err(),
            "expired certificate must be rejected; got Ok instead"
        );
    }
}