use std::path::PathBuf;
use std::sync::Arc;
use crate::error::{Error, Result};
/// TLS settings consumed by `maybe_acceptor`.
///
/// The `Default` value has TLS switched off and both paths unset. The path
/// fields are only looked at when `enabled` is `true`.
#[derive(Clone, Debug, Default)]
pub struct TlsConfig {
    /// Turns TLS on; when `false`, `maybe_acceptor` yields `Ok(None)`.
    pub enabled: bool,
    /// Location of the PEM certificate chain; required when `enabled` is set.
    pub cert_path: Option<PathBuf>,
    /// Location of the PEM private key; required when `enabled` is set.
    pub key_path: Option<PathBuf>,
}
/// Unit placeholder so signatures mentioning `TlsAcceptor` still compile when the
/// crate is built without the `tls` feature.
#[cfg(not(feature = "tls"))]
pub type TlsAcceptor = ();
/// With the `tls` feature enabled, this is `tokio_rustls::TlsAcceptor`.
#[cfg(feature = "tls")]
pub type TlsAcceptor = tokio_rustls::TlsAcceptor;
pub fn load_acceptor(
cert_path: &std::path::Path,
key_path: &std::path::Path,
) -> Result<Arc<TlsAcceptor>> {
#[cfg(feature = "tls")]
{
real::load_acceptor(cert_path, key_path)
}
#[cfg(not(feature = "tls"))]
{
let _ = (cert_path, key_path);
Err(Error::NotImplemented(
"TLS support not compiled in — rebuild with --features tls",
))
}
}
/// Produces a TLS acceptor only when `cfg.enabled` is set.
///
/// Returns `Ok(None)` when TLS is disabled; in that case the path fields are
/// not inspected at all.
///
/// # Errors
///
/// * With the `tls` feature: `Error::Invalid` if `cert_path` or `key_path` is
///   unset, otherwise any error from loading the certificate chain and key.
/// * Without the `tls` feature: `Error::NotImplemented` whenever
///   `cfg.enabled` is `true`.
pub fn maybe_acceptor(cfg: &TlsConfig) -> Result<Option<Arc<TlsAcceptor>>> {
    if !cfg.enabled {
        return Ok(None);
    }
    #[cfg(not(feature = "tls"))]
    {
        Err(Error::NotImplemented(
            "tls.enabled=true but binary was built without --features tls",
        ))
    }
    #[cfg(feature = "tls")]
    {
        let cert = match cfg.cert_path.as_deref() {
            Some(path) => path,
            None => {
                return Err(Error::Invalid(
                    "tls.cert_path required when tls.enabled".into(),
                ))
            }
        };
        let key = match cfg.key_path.as_deref() {
            Some(path) => path,
            None => {
                return Err(Error::Invalid(
                    "tls.key_path required when tls.enabled".into(),
                ))
            }
        };
        real::load_acceptor(cert, key).map(Some)
    }
}
#[cfg(feature = "tls")]
mod real {
    use std::fs::File;
    use std::io::BufReader;
    use std::path::Path;
    use std::sync::Arc;

    use rustls_pemfile::{certs, private_key};
    use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
    use tokio_rustls::rustls::ServerConfig;
    use tokio_rustls::TlsAcceptor;

    use crate::error::{Error, Result};

    /// Builds a `TlsAcceptor` from a PEM certificate chain and a PEM private key.
    ///
    /// Client certificate authentication is disabled.
    ///
    /// # Errors
    ///
    /// * `Error::Internal` if either file cannot be opened or parsed, or if
    ///   rustls rejects the certificate/key pair.
    /// * `Error::Invalid` if the certificate file holds no certificates or the
    ///   key file holds no supported private key.
    pub(super) fn load_acceptor(cert_path: &Path, key_path: &Path) -> Result<Arc<TlsAcceptor>> {
        let cert_chain = read_certs(cert_path)?;
        if cert_chain.is_empty() {
            return Err(Error::Invalid(format!(
                "no certificates found in {}",
                cert_path.display()
            )));
        }
        let key = read_key(key_path)?;
        // NOTE(review): depending on the rustls version and enabled crypto
        // features, `ServerConfig::builder()` may need a process-default
        // CryptoProvider to be determinable — confirm the crate's rustls features.
        let config = ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(cert_chain, key)
            .map_err(|e| Error::Internal(format!("tls server config: {e}")))?;
        Ok(Arc::new(TlsAcceptor::from(Arc::new(config))))
    }

    /// Opens `path` for buffered reading; `what` names the file in the error message.
    fn open_buffered(path: &Path, what: &str) -> Result<BufReader<File>> {
        File::open(path)
            .map(BufReader::new)
            .map_err(|e| Error::Internal(format!("open {what} {}: {e}", path.display())))
    }

    /// Reads every PEM certificate from `path`, in file order.
    fn read_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
        let mut reader = open_buffered(path, "cert")?;
        certs(&mut reader)
            .collect::<std::result::Result<_, _>>()
            .map_err(|e| Error::Internal(format!("parse cert {}: {e}", path.display())))
    }

    /// Reads the first private key found in `path`.
    ///
    /// PKCS#8 (`PRIVATE KEY`), PKCS#1 (`RSA PRIVATE KEY`) and SEC1
    /// (`EC PRIVATE KEY`) PEM blocks are all accepted. The file is read once,
    /// and the first key in file order wins.
    fn read_key(path: &Path) -> Result<PrivateKeyDer<'static>> {
        let mut reader = open_buffered(path, "key")?;
        private_key(&mut reader)
            .map_err(|e| Error::Internal(format!("parse key {}: {e}", path.display())))?
            .ok_or_else(|| {
                Error::Invalid(format!(
                    "no PKCS#8, RSA (PKCS#1) or EC (SEC1) private key in {}",
                    path.display()
                ))
            })
    }
}