use crate::config::{NetworkSettings, ServerConfig};
use crate::server::{ServerError, ServerResult};
use std::fs;
use std::path::Path;
use tonic::transport::{Certificate, Identity, ServerTlsConfig};
use tracing::{debug, info};
/// Factory for the server-side TLS configuration.
///
/// This is a stateless unit type; everything is exposed through associated functions.
pub struct TlsServerBuilder;

impl TlsServerBuilder {
    /// Derives the server TLS configuration from `config.network`.
    ///
    /// Returns `Ok(None)` when `tls_enabled` is false. Otherwise the certificate and
    /// private-key files are read and combined into the server identity. When
    /// `require_client_cert` is set, the CA file is read as well, registered through
    /// `client_ca_root`, and `client_auth_optional(false)` is applied.
    ///
    /// # Errors
    ///
    /// * `ServerError::ConfigValidation` if `tls_cert` or `tls_key` is unset, or if
    ///   `tls_ca` is unset while client certificates are required.
    /// * `ServerError::TlsSetup` if any of the referenced files cannot be read.
    pub fn build(config: &ServerConfig) -> ServerResult<Option<ServerTlsConfig>> {
        let settings = &config.network;

        if !settings.tls_enabled {
            debug!("TLS is not enabled");
            return Ok(None);
        }

        info!("Building TLS configuration");

        // Both paths are validated before either file is read, so a missing key path
        // surfaces as a configuration error even when the certificate file is unreadable.
        let cert_file = match settings.tls_cert.as_ref() {
            Some(path) => path,
            None => {
                return Err(ServerError::ConfigValidation(
                    "TLS cert path is required when TLS is enabled".to_string(),
                ));
            }
        };
        let key_file = match settings.tls_key.as_ref() {
            Some(path) => path,
            None => {
                return Err(ServerError::ConfigValidation(
                    "TLS key path is required when TLS is enabled".to_string(),
                ));
            }
        };

        let cert_bytes = Self::load_file(cert_file).map_err(|cause| {
            ServerError::TlsSetup(format!("Failed to load certificate: {}", cause))
        })?;
        let key_bytes = Self::load_file(key_file).map_err(|cause| {
            ServerError::TlsSetup(format!("Failed to load private key: {}", cause))
        })?;

        let server_identity = Identity::from_pem(&cert_bytes, &key_bytes);
        let base = ServerTlsConfig::new().identity(server_identity);

        let tls = if settings.require_client_cert {
            info!("mTLS enabled - requiring client certificates");

            let ca_file = match settings.tls_ca.as_ref() {
                Some(path) => path,
                None => {
                    return Err(ServerError::ConfigValidation(
                        "TLS CA path is required when client certificates are required"
                            .to_string(),
                    ));
                }
            };
            let ca_bytes = Self::load_file(ca_file).map_err(|cause| {
                ServerError::TlsSetup(format!("Failed to load CA certificate: {}", cause))
            })?;

            base.client_ca_root(Certificate::from_pem(&ca_bytes))
                .client_auth_optional(false)
        } else {
            base
        };

        info!("TLS configuration built successfully");
        Ok(Some(tls))
    }

    /// Reads the entire file at `path` into memory.
    ///
    /// # Errors
    ///
    /// Any I/O failure is reported as `ServerError::TlsSetup`, with the path and the
    /// underlying error included in the message.
    fn load_file(path: &Path) -> ServerResult<Vec<u8>> {
        match fs::read(path) {
            Ok(contents) => Ok(contents),
            Err(io_err) => Err(ServerError::TlsSetup(format!(
                "Failed to read file {:?}: {}",
                path, io_err
            ))),
        }
    }
}
/// Factory for the TLS configuration used on client connections.
///
/// This is a stateless unit type; everything is exposed through associated functions.
pub struct TlsClientBuilder;

impl TlsClientBuilder {
    /// Derives the client TLS configuration from `network`.
    ///
    /// Returns `Ok(None)` when `tls_enabled` is false. When `require_client_cert` is
    /// set, the certificate and key files are read and attached as the client identity.
    /// When `tls_ca` is set, that file is read and registered via `ca_certificate`.
    ///
    /// # Errors
    ///
    /// * `ServerError::ConfigValidation` if `require_client_cert` is set but `tls_cert`
    ///   or `tls_key` is missing.
    /// * `ServerError::TlsSetup` if any of the referenced files cannot be read.
    pub fn build(
        network: &NetworkSettings,
    ) -> ServerResult<Option<tonic::transport::ClientTlsConfig>> {
        use tonic::transport::ClientTlsConfig;

        if !network.tls_enabled {
            debug!("TLS is not enabled for client connections");
            return Ok(None);
        }

        info!("Building client TLS configuration");

        let base = ClientTlsConfig::new();

        // Step 1: attach a client identity only when mTLS is requested.
        let with_identity = if network.require_client_cert {
            let cert_file = match network.tls_cert.as_ref() {
                Some(path) => path,
                None => {
                    return Err(ServerError::ConfigValidation(
                        "TLS cert path is required for client mTLS".to_string(),
                    ));
                }
            };
            let key_file = match network.tls_key.as_ref() {
                Some(path) => path,
                None => {
                    return Err(ServerError::ConfigValidation(
                        "TLS key path is required for client mTLS".to_string(),
                    ));
                }
            };

            let cert_bytes = TlsServerBuilder::load_file(cert_file).map_err(|cause| {
                ServerError::TlsSetup(format!("Failed to load client certificate: {}", cause))
            })?;
            let key_bytes = TlsServerBuilder::load_file(key_file).map_err(|cause| {
                ServerError::TlsSetup(format!("Failed to load client private key: {}", cause))
            })?;

            base.identity(Identity::from_pem(&cert_bytes, &key_bytes))
        } else {
            base
        };

        // Step 2: the CA certificate is optional and independent of the mTLS setting.
        let tls = match &network.tls_ca {
            Some(ca_file) => {
                let ca_bytes = TlsServerBuilder::load_file(ca_file).map_err(|cause| {
                    ServerError::TlsSetup(format!(
                        "Failed to load CA certificate for client: {}",
                        cause
                    ))
                })?;
                with_identity.ca_certificate(Certificate::from_pem(&ca_bytes))
            }
            None => with_identity,
        };

        info!("Client TLS configuration built successfully");
        Ok(Some(tls))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        AuthSettings, AuthorizationSettings, LoggingSettings, MetricsSettings, ServerSettings,
        StorageSettings,
    };
    use std::fs;
    use std::path::PathBuf;

    /// Builds a `ServerConfig` backed by a fresh temporary directory.
    ///
    /// When `tls_enabled` is true, placeholder certificate, key and CA files are written
    /// into the directory and their paths are stored in the network settings; otherwise
    /// all three paths are `None`. The returned `TempDir` must be kept alive for as long
    /// as the configuration is used, because dropping it deletes the files.
    fn create_test_config(
        tls_enabled: bool,
        require_client_cert: bool,
    ) -> (ServerConfig, tempfile::TempDir) {
        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let data_dir = temp_dir.path().join("data");
        fs::create_dir(&data_dir).expect("Failed to create data dir");

        let (tls_cert, tls_key, tls_ca) = if tls_enabled {
            let cert_path = temp_dir.path().join("server.crt");
            let key_path = temp_dir.path().join("server.key");
            let ca_path = temp_dir.path().join("ca.crt");
            fs::write(
                &cert_path,
                b"-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----\n",
            )
            .expect("Failed to write cert");
            fs::write(
                &key_path,
                b"-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----\n",
            )
            .expect("Failed to write key");
            fs::write(
                &ca_path,
                b"-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----\n",
            )
            .expect("Failed to write CA");
            (Some(cert_path), Some(key_path), Some(ca_path))
        } else {
            (None, None, None)
        };

        let config = ServerConfig {
            server: ServerSettings {
                bind_address: "127.0.0.1:50051".to_string(),
                data_dir,
                pid_file: temp_dir.path().join("server.pid"),
                max_connections: 1000,
                shutdown_timeout_secs: 30,
            },
            storage: StorageSettings {
                engine: "memory".to_string(),
                wal: crate::config::WalSettings {
                    enabled: true,
                    dir: PathBuf::from("wal"),
                    segment_size_mb: 64,
                    sync_mode: "interval".to_string(),
                },
                memtable_size_mb: 64,
                block_cache_size_mb: 256,
                compaction: crate::config::CompactionSettings {
                    strategy: "leveled".to_string(),
                    num_levels: 7,
                    level_multiplier: 10,
                    max_concurrent: 2,
                },
            },
            network: NetworkSettings {
                tls_enabled,
                tls_cert,
                tls_key,
                tls_ca,
                require_client_cert,
                connection_timeout_secs: 30,
                keepalive_interval_secs: 60,
            },
            cluster: None,
            logging: LoggingSettings {
                level: "info".to_string(),
                format: "json".to_string(),
                file_enabled: false,
                file_path: None,
                rotation: crate::config::LogRotationSettings::default(),
            },
            metrics: MetricsSettings {
                enabled: true,
                bind_address: "127.0.0.1:9090".to_string(),
                export_interval_secs: 60,
            },
            auth: AuthSettings::default(),
            authz: AuthorizationSettings {
                enabled: false,
                default_role: "user".to_string(),
                roles_file: None,
                policies_file: None,
                collection_permissions: true,
                default_mode: "deny-by-default".to_string(),
                audit_enabled: false,
                audit_log_path: None,
            },
        };
        (config, temp_dir)
    }

    #[test]
    fn test_tls_disabled() {
        let (config, _temp_dir) = create_test_config(false, false);
        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Ok(None)));
    }

    // The only fallible steps in `build` are the path checks and the file reads, so the
    // placeholder PEM files written by `create_test_config` are enough for it to succeed.
    #[test]
    fn test_tls_enabled_basic() {
        let (config, _temp_dir) = create_test_config(true, false);
        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_mtls_enabled() {
        let (config, _temp_dir) = create_test_config(true, true);
        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_tls_missing_cert_path() {
        let (mut config, _temp_dir) = create_test_config(true, false);
        config.network.tls_cert = None;
        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_mtls_missing_ca_path() {
        let (mut config, _temp_dir) = create_test_config(true, true);
        config.network.tls_ca = None;
        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_tls_unreadable_cert_file() {
        let (mut config, temp_dir) = create_test_config(true, false);
        // Point at a file that was never created inside the temporary directory.
        config.network.tls_cert = Some(temp_dir.path().join("missing.crt"));
        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::TlsSetup(_))));
    }

    #[test]
    fn test_load_file() {
        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let file_path = temp_dir.path().join("test.txt");
        fs::write(&file_path, b"test content").expect("Failed to write file");
        let content = TlsServerBuilder::load_file(&file_path);
        assert_eq!(content.ok(), Some(b"test content".to_vec()));
    }

    #[test]
    fn test_load_file_not_found() {
        let result = TlsServerBuilder::load_file(Path::new("/nonexistent/file.txt"));
        assert!(matches!(result, Err(ServerError::TlsSetup(_))));
    }

    #[test]
    fn test_client_tls_disabled() {
        let (config, _temp_dir) = create_test_config(false, false);
        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_client_tls_enabled() {
        let (config, _temp_dir) = create_test_config(true, false);
        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_client_mtls_enabled() {
        let (config, _temp_dir) = create_test_config(true, true);
        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_client_mtls_missing_cert_path() {
        let (mut config, _temp_dir) = create_test_config(true, true);
        config.network.tls_cert = None;
        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }
}