use std::{env, path::PathBuf};
/// Filesystem locations of the certificate, private key and optional client CA
/// used for gRPC TLS.
///
/// Only paths are stored; no file is read until `load` / `load_client_ca` are
/// called (those methods exist only with the `grpc-tls` feature enabled).
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to the TLS certificate file.
    pub cert_path: PathBuf,
    /// Path to the private key file.
    pub key_path: PathBuf,
    /// Optional client CA file; `None` means no client CA is configured.
    /// How it is applied (e.g. client-certificate verification) is up to the caller.
    pub client_ca_path: Option<PathBuf>,
}

impl TlsConfig {
    /// Creates a config from a certificate path and a key path, with no client CA.
    pub fn new(cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
        Self {
            cert_path: cert_path.into(),
            key_path: key_path.into(),
            client_ca_path: None,
        }
    }

    /// Returns this config with `client_ca_path` set, replacing any previous value.
    #[must_use]
    pub fn with_client_ca(mut self, client_ca_path: impl Into<PathBuf>) -> Self {
        self.client_ca_path = Some(client_ca_path.into());
        self
    }

    /// Builds a config from environment variables.
    ///
    /// Each setting is looked up in its primary variable first, then in a legacy fallback:
    ///
    /// - certificate: `GRPC_TLS_CERT`, then `TLS_CERT_PATH`
    /// - key: `GRPC_TLS_KEY`, then `TLS_KEY_PATH`
    /// - client CA: `GRPC_TLS_CLIENT_CA`, then `TLS_CLIENT_CA_PATH`
    ///
    /// A variable that is unset *or set to an empty string* counts as absent, so an
    /// empty primary variable does not hide a populated fallback. Values are read as
    /// raw OS strings, so non-UTF-8 paths are accepted.
    ///
    /// Returns `None` unless both a certificate path and a key path are found. The
    /// client CA is optional.
    pub fn from_env() -> Option<Self> {
        let cert_path = Self::env_path("GRPC_TLS_CERT", "TLS_CERT_PATH")?;
        let key_path = Self::env_path("GRPC_TLS_KEY", "TLS_KEY_PATH")?;
        let client_ca_path = Self::env_path("GRPC_TLS_CLIENT_CA", "TLS_CLIENT_CA_PATH");
        Some(Self {
            cert_path,
            key_path,
            client_ca_path,
        })
    }

    /// Returns the first non-empty value of `primary`, then `fallback`, as a path.
    fn env_path(primary: &str, fallback: &str) -> Option<PathBuf> {
        [primary, fallback]
            .iter()
            .filter_map(env::var_os)
            .find(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    /// Reads the certificate and key files and returns their raw bytes as `(cert, key)`.
    ///
    /// # Errors
    ///
    /// Returns `TlsError::CertificateLoad` if the certificate file cannot be read. In
    /// that case the key file is not attempted. Returns `TlsError::KeyLoad` if the key
    /// file cannot be read.
    #[cfg(feature = "grpc-tls")]
    pub fn load(&self) -> Result<(Vec<u8>, Vec<u8>), TlsError> {
        let cert = std::fs::read(&self.cert_path).map_err(|e| TlsError::CertificateLoad {
            path: self.cert_path.clone(),
            source: e.to_string(),
        })?;
        let key = std::fs::read(&self.key_path).map_err(|e| TlsError::KeyLoad {
            path: self.key_path.clone(),
            source: e.to_string(),
        })?;
        Ok((cert, key))
    }

    /// Reads the client CA file, if one is configured.
    ///
    /// Returns `Ok(None)` when `client_ca_path` is `None`, without touching the filesystem.
    ///
    /// # Errors
    ///
    /// Returns `TlsError::ClientCaLoad` if a path is configured but cannot be read.
    #[cfg(feature = "grpc-tls")]
    pub fn load_client_ca(&self) -> Result<Option<Vec<u8>>, TlsError> {
        match &self.client_ca_path {
            Some(path) => {
                let ca = std::fs::read(path).map_err(|e| TlsError::ClientCaLoad {
                    path: path.clone(),
                    source: e.to_string(),
                })?;
                Ok(Some(ca))
            }
            None => Ok(None),
        }
    }
}
/// Failure while loading TLS material from disk or validating TLS settings.
#[derive(Debug)]
pub enum TlsError {
    /// The certificate file at `path` could not be read; `source` is the
    /// underlying error rendered as text.
    CertificateLoad { path: PathBuf, source: String },
    /// The private key file at `path` could not be read; `source` is the
    /// underlying error rendered as text.
    KeyLoad { path: PathBuf, source: String },
    /// The client CA file at `path` could not be read; `source` is the
    /// underlying error rendered as text.
    ClientCaLoad { path: PathBuf, source: String },
    /// Any other TLS configuration problem, described by the message.
    Configuration(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The three file-loading variants share one message template and differ
        // only in the noun naming the file that failed.
        let (what, path, source) = match self {
            TlsError::Configuration(msg) => {
                return write!(f, "TLS configuration error: {}", msg);
            }
            TlsError::CertificateLoad { path, source } => ("certificate", path, source),
            TlsError::KeyLoad { path, source } => ("key", path, source),
            TlsError::ClientCaLoad { path, source } => ("client CA", path, source),
        };
        write!(f, "Failed to load {} from {:?}: {}", what, path, source)
    }
}

impl std::error::Error for TlsError {}
/// Unit tests for `TlsConfig` construction and `TlsError` formatting.
///
/// These tests never touch the process environment or the filesystem, so they
/// are safe to run in parallel with other tests in the crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_new() {
        let config = TlsConfig::new("/path/to/cert.pem", "/path/to/key.pem");
        assert_eq!(config.cert_path, PathBuf::from("/path/to/cert.pem"));
        assert_eq!(config.key_path, PathBuf::from("/path/to/key.pem"));
        assert!(config.client_ca_path.is_none());
    }

    #[test]
    fn test_tls_config_with_client_ca() {
        let config = TlsConfig::new("/path/to/cert.pem", "/path/to/key.pem")
            .with_client_ca("/path/to/ca.pem");
        assert_eq!(
            config.client_ca_path,
            Some(PathBuf::from("/path/to/ca.pem"))
        );
    }

    #[test]
    fn test_tls_config_with_client_ca_replaces_previous() {
        // The builder overwrites rather than accumulates.
        let config = TlsConfig::new("/path/to/cert.pem", "/path/to/key.pem")
            .with_client_ca("/path/to/old-ca.pem")
            .with_client_ca("/path/to/new-ca.pem");
        assert_eq!(
            config.client_ca_path,
            Some(PathBuf::from("/path/to/new-ca.pem"))
        );
    }

    #[test]
    fn test_tls_error_display() {
        let err = TlsError::CertificateLoad {
            path: PathBuf::from("/path/to/cert.pem"),
            source: "file not found".to_string(),
        };
        assert!(err.to_string().contains("cert.pem"));
        assert!(err.to_string().contains("file not found"));
    }

    #[test]
    fn test_tls_error_display_all_variants() {
        let key = TlsError::KeyLoad {
            path: PathBuf::from("/path/to/key.pem"),
            source: "permission denied".to_string(),
        };
        assert_eq!(
            key.to_string(),
            "Failed to load key from \"/path/to/key.pem\": permission denied"
        );

        let ca = TlsError::ClientCaLoad {
            path: PathBuf::from("/path/to/ca.pem"),
            source: "file not found".to_string(),
        };
        assert_eq!(
            ca.to_string(),
            "Failed to load client CA from \"/path/to/ca.pem\": file not found"
        );

        let cfg = TlsError::Configuration("missing key".to_string());
        assert_eq!(cfg.to_string(), "TLS configuration error: missing key");
    }

    #[test]
    fn test_tls_error_is_std_error() {
        // Usable as a boxed trait object; no underlying cause is chained.
        let err: Box<dyn std::error::Error> =
            Box::new(TlsError::Configuration("bad".to_string()));
        assert!(err.source().is_none());
    }
}