use std::{
fs::{self, File},
io::BufReader,
path::Path,
};
use bevy::log::{trace, warn};
use super::EndpointCertificateError;
use crate::shared::certificate::CertificateFingerprint;
/// Records how a certificate was obtained.
///
/// NOTE(review): this type is not referenced in the visible part of this file;
/// the variant descriptions below are inferred from their names — confirm
/// against the code that constructs and consumes them.
#[derive(Debug, Clone)]
pub enum CertOrigin {
/// Presumably: the certificate was generated locally rather than read from disk.
Generated {
/// Presumably the hostname the certificate was generated for — TODO confirm.
server_hostname: String,
},
/// Presumably: the certificate was loaded from existing files.
Loaded,
}
/// Selects how `retrieve_certificate` obtains the server certificate.
#[derive(Debug, Clone)]
pub enum CertificateRetrievalMode {
/// Always generate a fresh self-signed certificate; nothing is read from or
/// written to disk.
GenerateSelfSigned {
/// Name passed to `rcgen::generate_simple_self_signed` as the single
/// subject alternative name of the generated certificate.
server_hostname: String,
},
/// Load an existing certificate chain and private key from PEM files.
LoadFromFile {
/// Path of the PEM file containing the certificate chain.
cert_file: String,
/// Path of the PEM file containing the private key.
key_file: String,
},
/// Load the certificate and key from files when both files exist, otherwise
/// generate a new self-signed certificate.
LoadFromFileOrGenerateSelfSigned {
/// Path of the PEM file containing the certificate chain.
cert_file: String,
/// Path of the PEM file containing the private key.
key_file: String,
/// When a certificate had to be generated, also write it (and its key)
/// to `cert_file` / `key_file` in PEM form.
save_on_disk: bool,
/// Name used for the generated certificate (see `GenerateSelfSigned`).
server_hostname: String,
},
}
/// A server certificate together with its private key and fingerprint.
pub struct ServerCertificate {
/// DER-encoded certificate chain. The constructors in this file compute
/// `fingerprint` from the first entry.
pub cert_chain: Vec<rustls::pki_types::CertificateDer<'static>>,
/// DER-encoded private key matching the certificate.
pub priv_key: rustls::pki_types::PrivateKeyDer<'static>,
/// Fingerprint of the certificate; in this file it is derived from the first
/// certificate of `cert_chain`.
pub fingerprint: CertificateFingerprint,
}
/// Loads a certificate chain and its private key from PEM files.
///
/// `cert_file` is parsed with `rustls_pemfile::certs` and `key_file` with
/// `rustls_pemfile::private_key`. The fingerprint of the returned
/// [`ServerCertificate`] is computed from the first certificate of the chain.
///
/// # Errors
///
/// Returns an error (converted from `std::io::Error`) when:
/// - either file cannot be opened or fails to parse as PEM;
/// - `cert_file` contains no certificate (`ErrorKind::InvalidData`);
/// - `key_file` contains no private key (`ErrorKind::InvalidData`).
///
/// Malformed or empty input files are reported as errors rather than panics,
/// so a bad configuration cannot abort the whole process.
fn read_cert_from_files(
    cert_file: &str,
    key_file: &str,
) -> Result<ServerCertificate, EndpointCertificateError> {
    let mut cert_chain_reader = BufReader::new(File::open(cert_file)?);
    let cert_chain: Vec<rustls::pki_types::CertificateDer<'static>> =
        rustls_pemfile::certs(&mut cert_chain_reader).collect::<Result<_, _>>()?;

    // The fingerprint identifies the first certificate of the chain; an empty
    // chain means the file held no usable certificate at all.
    let first_cert = cert_chain.first().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("no certificate found in {}", cert_file),
        )
    })?;
    let fingerprint = CertificateFingerprint::from(first_cert);

    let mut key_reader = BufReader::new(File::open(key_file)?);
    let priv_key = rustls_pemfile::private_key(&mut key_reader)?.ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("no private key found in {}", key_file),
        )
    })?;

    Ok(ServerCertificate {
        cert_chain,
        priv_key,
        fingerprint,
    })
}
fn write_cert_to_files(
cert: &rcgen::CertifiedKey,
cert_file: &String,
key_file: &String,
) -> std::io::Result<()> {
for file in [cert_file, key_file] {
if let Some(parent) = std::path::Path::new(file).parent() {
std::fs::create_dir_all(parent)?;
}
}
fs::write(cert_file, cert.cert.pem())?;
fs::write(key_file, cert.key_pair.serialize_pem())?;
Ok(())
}
/// Generates a new self-signed certificate for `server_host`.
///
/// `server_host` is passed to `rcgen::generate_simple_self_signed` as the only
/// subject alternative name.
///
/// Returns the certificate in two forms: a [`ServerCertificate`] (DER
/// certificate, PKCS#8 DER private key and fingerprint) and the original
/// `rcgen::CertifiedKey`, which callers can use to serialize the pair to PEM.
///
/// # Errors
///
/// Propagates any `rcgen` error raised while generating the certificate.
fn generate_self_signed_certificate(
    server_host: &str,
) -> Result<(ServerCertificate, rcgen::CertifiedKey), EndpointCertificateError> {
    let generated = rcgen::generate_simple_self_signed(vec![server_host.to_owned()])?;

    let priv_key: rustls::pki_types::PrivateKeyDer<'static> =
        rustls::pki_types::PrivatePkcs8KeyDer::from(generated.key_pair.serialize_der()).into();
    let cert_der = generated.cert.der();

    let server_cert = ServerCertificate {
        fingerprint: CertificateFingerprint::from(cert_der),
        cert_chain: vec![cert_der.clone()],
        priv_key,
    };
    Ok((server_cert, generated))
}
/// Obtains the server certificate according to `cert_mode`.
///
/// - `GenerateSelfSigned`: always generates a new self-signed certificate.
/// - `LoadFromFile`: reads the certificate chain and key from the given files.
/// - `LoadFromFileOrGenerateSelfSigned`: reads from the files when **both**
///   exist; otherwise generates a new self-signed certificate and, if
///   `save_on_disk` is set, writes it to the two paths. When only one of the
///   two files exists, that existing file is overwritten in this case.
///
/// # Errors
///
/// Propagates any error from reading the files, generating the certificate or
/// writing it to disk.
pub(crate) fn retrieve_certificate(
    cert_mode: CertificateRetrievalMode,
) -> Result<ServerCertificate, EndpointCertificateError> {
    match cert_mode {
        CertificateRetrievalMode::GenerateSelfSigned { server_hostname } => {
            let (server_cert, _rcgen_cert) = generate_self_signed_certificate(&server_hostname)?;
            trace!("Generated a new self-signed certificate");
            Ok(server_cert)
        }
        CertificateRetrievalMode::LoadFromFile {
            cert_file,
            key_file,
        } => {
            let server_cert = read_cert_from_files(&cert_file, &key_file)?;
            trace!("Successfully loaded cert and key from files");
            Ok(server_cert)
        }
        CertificateRetrievalMode::LoadFromFileOrGenerateSelfSigned {
            save_on_disk,
            cert_file,
            key_file,
            server_hostname,
        } => {
            if Path::new(&cert_file).exists() && Path::new(&key_file).exists() {
                let server_cert = read_cert_from_files(&cert_file, &key_file)?;
                trace!("Successfully loaded cert and key from files");
                Ok(server_cert)
            } else {
                warn!("{} and/or {} do not exist, could not load existing certificate. Generating a new self-signed certificate.", cert_file, key_file);
                let (server_cert, rcgen_cert) = generate_self_signed_certificate(&server_hostname)?;
                if save_on_disk {
                    // The rcgen form is needed here because it can serialize
                    // both the certificate and the key pair to PEM.
                    write_cert_to_files(&rcgen_cert, &cert_file, &key_file)?;
                    trace!("Successfully saved cert and key to files");
                }
                Ok(server_cert)
            }
        }
    }
}