use super::Error;
use tokio_native_tls::native_tls::{Certificate, Identity, TlsAcceptor, TlsConnector};
/// Server-side TLS acceptor type returned by this native-tls backend's
/// `make_server_config*` functions; a tokio wrapper around a `native_tls::TlsAcceptor`.
pub type TlsIdentityInner = tokio_native_tls::TlsAcceptor;
/// HTTPS connector for hyper's legacy client: `hyper-tls` layered over a plain
/// `HttpConnector` (only compiled with the `server` feature).
#[cfg(feature = "server")]
pub type HyperConnector =
hyper_tls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>;
/// Builds a native-tls server acceptor from PEM files on disk.
///
/// `cert_path` holds the PEM certificate(s) and `key_path` the PKCS#8 PEM private key.
///
/// # Errors
///
/// Returns an error if either file cannot be read, if the PEM data cannot be turned
/// into an identity, or if `client_ca_path` is `Some` (client certificate
/// verification is not supported by this backend).
pub async fn make_server_config(
    cert_path: &str,
    key_path: &str,
    client_ca_path: Option<&str>,
) -> Result<TlsIdentityInner, Error> {
    read_key_cert(key_path, cert_path)
        .await
        .and_then(|identity| make_server_config_from_mem(identity, client_ca_path))
}
/// Builds a native-tls server acceptor from in-memory PEM strings (used by the ACME flow).
///
/// `certs` is the PEM certificate data and `priv_key_pem` the PKCS#8 PEM private key.
///
/// # Errors
///
/// Returns an error if the PEM data cannot be parsed into an identity, or if
/// `client_ca_path` is `Some` (not supported by this backend).
#[cfg(feature = "acme")]
pub async fn make_server_config_from_pem(
    certs: String,
    priv_key_pem: String,
    client_ca_path: Option<&str>,
) -> Result<TlsIdentityInner, Error> {
    let cert_bytes = certs.as_bytes();
    let key_bytes = priv_key_pem.as_bytes();
    let server_identity = Identity::from_pkcs8(cert_bytes, key_bytes)?;
    make_server_config_from_mem(server_identity, client_ca_path)
}
fn make_server_config_from_mem(
identity: Identity,
client_ca_path: Option<&str>,
) -> Result<TlsIdentityInner, Error> {
if client_ca_path.is_some() {
return Err(Error::UnsupportedFeature(
"client CA verification",
"requires rustls",
));
}
let raw_acceptor = TlsAcceptor::builder(identity).build()?;
Ok(raw_acceptor.into())
}
pub async fn make_client_config(
cert_path: Option<&str>,
key_path: Option<&str>,
ca_path: Option<&str>,
tls_skip_verify: bool,
tls_alpn: Option<&[&str]>,
) -> Result<TlsConnector, Error> {
let mut tls_config_builder = TlsConnector::builder();
tls_config_builder
.danger_accept_invalid_certs(tls_skip_verify)
.danger_accept_invalid_hostnames(tls_skip_verify);
if let Some(tls_alpn) = tls_alpn {
tls_config_builder.request_alpns(tls_alpn);
}
if let Some(ca_path) = ca_path {
let ca = tokio::fs::read(ca_path).await.map_err(Error::ReadCert)?;
tls_config_builder.add_root_certificate(Certificate::from_pem(&ca)?);
}
if let Some(cert_path) = cert_path {
let identity = read_key_cert(key_path.unwrap_or(cert_path), cert_path).await?;
tls_config_builder.identity(identity);
}
Ok(tls_config_builder.build()?)
}
async fn read_key_cert(key_path: &str, cert_path: &str) -> Result<Identity, Error> {
let key = tokio::fs::read(key_path).await.map_err(Error::ReadCert)?;
let cert = tokio::fs::read(cert_path).await.map_err(Error::ReadCert)?;
Ok(Identity::from_pkcs8(&cert, &key)?)
}
/// Creates the default `hyper-tls` HTTPS connector.
///
/// This never fails. The `io::Result` return type is deliberate, as the `#[expect]`
/// below shows; presumably it matches the signature of other TLS backends (confirm).
#[cfg(feature = "server")]
#[expect(clippy::unnecessary_wraps)]
pub fn make_hyper_connector() -> std::io::Result<HyperConnector> {
    let connector = HyperConnector::new();
    Ok(connector)
}
#[cfg(test)]
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
mod tests {
    use super::*;
    use rcgen::CertificateParams;
    use tempfile::tempdir;

    /// Generates a self-signed `example.com` certificate with a fresh ECDSA P-384 key.
    ///
    /// Returns `(certificate_pem, private_key_pem)`.
    fn self_signed_pem() -> (String, String) {
        let cert_params = CertificateParams::new(vec!["example.com".into()]).unwrap();
        let keypair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P384_SHA384).unwrap();
        let cert = cert_params.self_signed(&keypair).unwrap();
        (cert.pem(), keypair.serialize_pem())
    }

    #[tokio::test]
    async fn test_read_key_cert() {
        crate::tests::setup_logging();
        let tmpdir = tempdir().unwrap();
        let key_path = tmpdir.path().join("key.pem");
        let cert_path = tmpdir.path().join("cert.pem");
        let (crt, crt_key) = self_signed_pem();
        tokio::fs::write(&cert_path, crt).await.unwrap();
        tokio::fs::write(&key_path, crt_key).await.unwrap();
        read_key_cert(key_path.to_str().unwrap(), cert_path.to_str().unwrap())
            .await
            .unwrap();
    }

    #[tokio::test]
    #[cfg(feature = "acme")]
    async fn test_make_server_config_from_rcgen_pem() {
        crate::tests::setup_logging();
        let (crt, key) = self_signed_pem();
        // `expect` (rather than `assert!(is_ok())`) surfaces the actual error on failure.
        make_server_config_from_pem(crt, key, None)
            .await
            .expect("server config from rcgen PEM should build");
    }
}