use super::utils::add_query_param;
use crate::core::errors::DataProfilerError;
use serde::{Deserialize, Serialize};
use std::path::Path;
/// SSL/TLS settings applied to database connection strings.
///
/// Use [`SslConfig::validate`] to check the configuration and
/// [`SslConfig::apply_to_connection_string`] to turn it into backend-specific
/// query parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SslConfig {
// Whether the connection should require SSL/TLS; `validate` warns when this is false.
pub require_ssl: bool,
// Path to a CA certificate file; must exist if set. Emitted as `sslrootcert` (PostgreSQL) or `tls-ca` (MySQL).
pub ca_cert_path: Option<String>,
// Path to a client certificate file; must exist if set. Emitted as `sslcert` (PostgreSQL only).
pub client_cert_path: Option<String>,
// Path to a client private key file; must exist if set. Emitted as `sslkey` (PostgreSQL only).
pub client_key_path: Option<String>,
// Whether the server certificate should be verified; when false, MySQL gets `tls-skip-verify=true`.
pub verify_server_cert: bool,
// libpq-style mode; `validate` accepts only disable, allow, prefer, require, verify-ca, verify-full.
pub ssl_mode: Option<String>,
}
impl Default for SslConfig {
    /// Opportunistic SSL: `prefer` mode, SSL not required, server certificate
    /// verification enabled, and no certificate or key files configured.
    fn default() -> Self {
        let ssl_mode = Some(String::from("prefer"));
        Self {
            ssl_mode,
            verify_server_cert: true,
            require_ssl: false,
            client_key_path: None,
            client_cert_path: None,
            ca_cert_path: None,
        }
    }
}
impl SslConfig {
    /// SSL modes accepted by `validate`; these follow the libpq `sslmode` vocabulary.
    const VALID_SSL_MODES: [&'static str; 6] = [
        "disable",
        "allow",
        "prefer",
        "require",
        "verify-ca",
        "verify-full",
    ];

    /// Modes under which libpq may still fall back to an unencrypted connection.
    const NON_ENFORCING_SSL_MODES: [&'static str; 3] = ["disable", "allow", "prefer"];

    /// Production preset: SSL is required (`ssl_mode = "require"`) and server
    /// certificate verification is enabled; no certificate files are configured.
    ///
    /// Note: for PostgreSQL, libpq's `require` mode only verifies the server
    /// certificate when a root CA is available; use `verify-full` together with
    /// `ca_cert_path` for strict verification.
    pub fn production() -> Self {
        Self {
            require_ssl: true,
            ca_cert_path: None,
            client_cert_path: None,
            client_key_path: None,
            verify_server_cert: true,
            ssl_mode: Some("require".to_string()),
        }
    }

    /// Development preset: SSL is optional (`ssl_mode = "prefer"`) and server
    /// certificate verification is disabled.
    pub fn development() -> Self {
        Self {
            require_ssl: false,
            ca_cert_path: None,
            client_cert_path: None,
            client_key_path: None,
            verify_server_cert: false,
            ssl_mode: Some("prefer".to_string()),
        }
    }

    /// Checks that every referenced certificate/key file exists and that
    /// `ssl_mode` is a recognised value.
    ///
    /// Insecure-but-valid settings are reported through `log::warn!` rather than
    /// as errors. These are: SSL not required, verification disabled while SSL is
    /// required, or `require_ssl` paired with a mode that does not enforce
    /// encryption.
    ///
    /// # Errors
    ///
    /// Returns a `DataProfilerError::database_ssl` error when:
    /// - a configured CA certificate, client certificate or client key path does
    ///   not exist (the first missing one, in that order, is reported), or
    /// - `ssl_mode` is not one of `disable`, `allow`, `prefer`, `require`,
    ///   `verify-ca`, `verify-full`.
    pub fn validate(&self) -> Result<(), DataProfilerError> {
        // Order matters: the first missing file is the one reported.
        let cert_files = [
            ("CA certificate", self.ca_cert_path.as_deref()),
            ("Client certificate", self.client_cert_path.as_deref()),
            ("Client private key", self.client_key_path.as_deref()),
        ];
        for (label, path) in cert_files {
            if let Some(path) = path
                && !Path::new(path).exists()
            {
                return Err(DataProfilerError::database_ssl(&format!(
                    "{} file not found: {}",
                    label, path
                )));
            }
        }

        if let Some(ssl_mode) = self.ssl_mode.as_deref() {
            if !Self::VALID_SSL_MODES.contains(&ssl_mode) {
                return Err(DataProfilerError::database_ssl(&format!(
                    "Invalid SSL mode '{}'. Valid modes: {}",
                    ssl_mode,
                    Self::VALID_SSL_MODES.join(", ")
                )));
            }
            // `require_ssl` is only meaningful if the chosen mode actually enforces it.
            if self.require_ssl && Self::NON_ENFORCING_SSL_MODES.contains(&ssl_mode) {
                log::warn!(
                    "SSL required but ssl_mode '{}' does not enforce encryption. Use 'require', 'verify-ca' or 'verify-full'.",
                    ssl_mode
                );
            }
        }

        if self.require_ssl && !self.verify_server_cert {
            log::warn!(
                "SSL required but server certificate verification is disabled. This may be insecure."
            );
        }
        if !self.require_ssl {
            log::warn!("SSL not required. Database connections may be unencrypted.");
        }
        Ok(())
    }

    /// Appends this configuration's SSL settings to `connection_string` as query
    /// parameters for the given backend, and returns the resulting string.
    ///
    /// - `postgresql` / `postgres`: `sslmode`, `sslrootcert`, `sslcert`, `sslkey`.
    ///   If `require_ssl` is set but no `ssl_mode` is configured, `sslmode=require`
    ///   is emitted so the requirement is actually enforced.
    /// - `mysql`: `tls=true` when SSL is required, `tls-skip-verify=true` when
    ///   server verification is disabled, and `tls-ca` for the CA file. Client
    ///   certificate/key paths and `ssl_mode` are not applied for MySQL.
    /// - `sqlite`: returned unchanged (embedded database); warns if SSL was required.
    /// - anything else: returned unchanged, with a warning.
    pub fn apply_to_connection_string(
        &self,
        mut connection_string: String,
        database_type: &str,
    ) -> String {
        match database_type {
            "postgresql" | "postgres" => {
                // Without an explicit sslmode, libpq defaults to `prefer`, which can
                // silently fall back to plaintext; honour `require_ssl` here.
                let ssl_mode = self
                    .ssl_mode
                    .as_deref()
                    .or(self.require_ssl.then_some("require"));
                if let Some(ssl_mode) = ssl_mode {
                    connection_string = add_query_param(connection_string, "sslmode", ssl_mode);
                }
                if let Some(ca_cert) = &self.ca_cert_path {
                    connection_string = add_query_param(connection_string, "sslrootcert", ca_cert);
                }
                if let Some(client_cert) = &self.client_cert_path {
                    connection_string = add_query_param(connection_string, "sslcert", client_cert);
                }
                if let Some(client_key) = &self.client_key_path {
                    connection_string = add_query_param(connection_string, "sslkey", client_key);
                }
            }
            "mysql" => {
                if self.require_ssl {
                    connection_string = add_query_param(connection_string, "tls", "true");
                }
                if !self.verify_server_cert {
                    connection_string =
                        add_query_param(connection_string, "tls-skip-verify", "true");
                }
                if let Some(ca_cert) = &self.ca_cert_path {
                    connection_string = add_query_param(connection_string, "tls-ca", ca_cert);
                }
            }
            "sqlite" => {
                if self.require_ssl {
                    log::warn!("SSL configuration ignored for SQLite (embedded database)");
                }
            }
            _ => {
                log::warn!(
                    "SSL configuration validation for '{}' not fully implemented",
                    database_type
                );
            }
        }
        connection_string
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration validates, while one that points at a
    /// non-existent CA certificate is rejected.
    #[test]
    fn test_ssl_config_validation() {
        let baseline = SslConfig::default();
        assert!(baseline.validate().is_ok());

        let missing_ca = SslConfig {
            ca_cert_path: Some("/nonexistent/path".to_string()),
            ..baseline
        };
        assert!(missing_ca.validate().is_err());
    }
}