#[cfg(any(feature = "postgres", feature = "mysql"))]
use sea_orm::ConnectionTrait;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::time::Duration;
/// Connection-pool sizing and timeout settings.
///
/// Timeout fields are plain integers; units are not encoded in the type.
/// `DbConfig` (below) treats idle timeouts as seconds and acquire timeouts as
/// milliseconds — presumably the same convention applies here; TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Upper bound on open connections.
    max_connections: u32,
    /// Lower bound of connections kept ready.
    min_connections: u32,
    /// Idle timeout (unit: see type-level note).
    idle_timeout: u64,
    /// Acquire timeout (unit: see type-level note).
    acquire_timeout: u64,
}
impl PoolConfig {
pub fn new(max_connections: u32, min_connections: u32, idle_timeout: u64, acquire_timeout: u64) -> Self {
Self {
max_connections,
min_connections,
idle_timeout,
acquire_timeout,
}
}
pub fn max_connections(&self) -> u32 {
self.max_connections
}
pub fn min_connections(&self) -> u32 {
self.min_connections
}
pub fn idle_timeout(&self) -> u64 {
self.idle_timeout
}
pub fn acquire_timeout(&self) -> u64 {
self.acquire_timeout
}
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_connections: 5,
min_connections: 1,
idle_timeout: 300,
acquire_timeout: 5000,
}
}
}
/// Supported database backends, classified from a connection-URL scheme.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DatabaseType {
    /// PostgreSQL (any scheme starting with "postgres").
    Postgres,
    /// MySQL (any scheme starting with "mysql").
    MySql,
    /// SQLite; also the fallback for any unrecognized scheme.
    Sqlite,
}
impl DatabaseType {
pub fn parse_database_type(s: &str) -> Self {
let s = s.to_lowercase();
if s.starts_with("postgres") {
DatabaseType::Postgres
} else if s.starts_with("mysql") {
DatabaseType::MySql
} else {
DatabaseType::Sqlite
}
}
pub fn as_str(&self) -> &'static str {
match self {
DatabaseType::Postgres => "postgres",
DatabaseType::MySql => "mysql",
DatabaseType::Sqlite => "sqlite",
}
}
pub fn is_real_database(&self) -> bool {
!matches!(self, DatabaseType::Sqlite)
}
}
impl std::fmt::Display for DatabaseType {
    /// Displays the canonical name from [`DatabaseType::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
pub use crate::error::ConfigError;
/// Fluent builder for [`DbConfig`].
///
/// Every field is optional; unset fields fall back to the crate defaults
/// (the `default_*` functions below) when [`DbConfigBuilder::build`] runs.
#[derive(Debug, Clone, Default)]
pub struct DbConfigBuilder {
    /// Connection URL; required for the built config to pass validation.
    url: Option<String>,
    /// Pool upper bound (defaults to 20).
    max_connections: Option<u32>,
    /// Pool lower bound (defaults to 5).
    min_connections: Option<u32>,
    /// Idle timeout in seconds (defaults to 300).
    idle_timeout: Option<u64>,
    /// Acquire timeout in milliseconds (defaults to 5000).
    acquire_timeout: Option<u64>,
    /// Optional path to a permissions file.
    permissions_path: Option<String>,
    /// Optional migrations directory.
    migrations_dir: Option<PathBuf>,
    /// Whether to run migrations automatically (defaults to false).
    auto_migrate: Option<bool>,
    /// Migration timeout in seconds (defaults to 60).
    migration_timeout: Option<u64>,
    /// Administrative role name (defaults to "admin").
    admin_role: Option<String>,
    /// Warmup timeout (defaults to 30; unit not exposed — see DbConfig note).
    warmup_timeout: Option<u64>,
    /// Warmup retry count (defaults to 3).
    warmup_retries: Option<u32>,
}
impl DbConfigBuilder {
    /// Creates an empty builder; unset fields receive crate defaults in
    /// [`DbConfigBuilder::build`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the database connection URL.
    pub fn url(self, url: impl Into<String>) -> Self {
        Self { url: Some(url.into()), ..self }
    }

    /// Sets the pool's maximum connection count.
    pub fn max_connections(self, n: u32) -> Self {
        Self { max_connections: Some(n), ..self }
    }

    /// Sets the pool's minimum connection count.
    pub fn min_connections(self, n: u32) -> Self {
        Self { min_connections: Some(n), ..self }
    }

    /// Sets the idle timeout (seconds).
    pub fn idle_timeout(self, timeout: u64) -> Self {
        Self { idle_timeout: Some(timeout), ..self }
    }

    /// Sets the acquire timeout (milliseconds).
    pub fn acquire_timeout(self, timeout: u64) -> Self {
        Self { acquire_timeout: Some(timeout), ..self }
    }

    /// Sets the permissions-file path.
    pub fn permissions_path(self, path: impl Into<String>) -> Self {
        Self { permissions_path: Some(path.into()), ..self }
    }

    /// Sets the migrations directory.
    pub fn migrations_dir(self, path: impl AsRef<Path>) -> Self {
        Self { migrations_dir: Some(path.as_ref().to_path_buf()), ..self }
    }

    /// Enables or disables automatic migrations.
    pub fn auto_migrate(self, auto: bool) -> Self {
        Self { auto_migrate: Some(auto), ..self }
    }

    /// Sets the migration timeout (seconds).
    pub fn migration_timeout(self, timeout: u64) -> Self {
        Self { migration_timeout: Some(timeout), ..self }
    }

    /// Sets the administrative role name.
    pub fn admin_role(self, role: impl Into<String>) -> Self {
        Self { admin_role: Some(role.into()), ..self }
    }

    /// Sets the warmup timeout.
    pub fn warmup_timeout(self, timeout: u64) -> Self {
        Self { warmup_timeout: Some(timeout), ..self }
    }

    /// Sets the warmup retry count.
    pub fn warmup_retries(self, retries: u32) -> Self {
        Self { warmup_retries: Some(retries), ..self }
    }

    /// Finalizes the configuration: fills unset fields with crate defaults,
    /// then validates the result.
    ///
    /// # Errors
    /// Propagates any failure from [`DbConfig::validate`] (e.g. empty URL,
    /// `min_connections > max_connections`).
    pub fn build(self) -> Result<DbConfig, ConfigError> {
        let config = DbConfig {
            url: self.url.unwrap_or_default(),
            max_connections: self.max_connections.unwrap_or_else(default_max_connections),
            min_connections: self.min_connections.unwrap_or_else(default_min_connections),
            idle_timeout: self.idle_timeout.unwrap_or_else(default_idle_timeout),
            acquire_timeout: self.acquire_timeout.unwrap_or_else(default_acquire_timeout),
            permissions_path: self.permissions_path,
            migrations_dir: self.migrations_dir,
            auto_migrate: self.auto_migrate.unwrap_or(false),
            migration_timeout: self.migration_timeout.unwrap_or_else(default_migration_timeout),
            admin_role: self.admin_role.unwrap_or_else(default_admin_role),
            warmup_timeout: self.warmup_timeout.unwrap_or_else(default_warmup_timeout),
            warmup_retries: self.warmup_retries.unwrap_or_else(default_warmup_retries),
        };
        config.validate()?;
        Ok(config)
    }
}
/// Full database configuration, deserializable from YAML/TOML and buildable
/// from the environment.
///
/// The `#[serde(default = "...")]` attributes apply only during
/// deserialization. The derived `Default` impl zero-initializes every field
/// (empty URL, 0 connections) and does NOT use these functions — see
/// `test_default_config_values` below.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct DbConfig {
    /// Connection URL; may embed credentials — log via `url_sanitized()`.
    #[serde(default)]
    url: String,
    /// Pool upper bound (serde default: 20).
    #[serde(default = "default_max_connections")]
    max_connections: u32,
    /// Pool lower bound (serde default: 5).
    #[serde(default = "default_min_connections")]
    min_connections: u32,
    /// Idle timeout in seconds (serde default: 300); see `idle_timeout_duration`.
    #[serde(default = "default_idle_timeout")]
    idle_timeout: u64,
    /// Acquire timeout in milliseconds (serde default: 5000); see
    /// `acquire_timeout_duration`.
    #[serde(default = "default_acquire_timeout")]
    acquire_timeout: u64,
    /// Optional path to a permissions file.
    #[serde(default)]
    permissions_path: Option<String>,
    /// Optional migrations directory.
    #[serde(default)]
    migrations_dir: Option<PathBuf>,
    /// Run migrations automatically (serde default: false).
    #[serde(default)]
    auto_migrate: bool,
    /// Migration timeout in seconds (serde default: 60).
    #[serde(default = "default_migration_timeout")]
    migration_timeout: u64,
    /// Role name treated as administrative (serde default: "admin").
    #[serde(default = "default_admin_role")]
    admin_role: String,
    /// Warmup timeout (serde default: 30); no Duration accessor exists, so the
    /// unit is presumably seconds — TODO confirm against consumers.
    #[serde(default = "default_warmup_timeout")]
    warmup_timeout: u64,
    /// Warmup retry attempts (serde default: 3).
    #[serde(default = "default_warmup_retries")]
    warmup_retries: u32,
}
impl DbConfig {
    /// Raw connection URL, possibly containing credentials.
    #[deprecated(
        since = "0.1.1",
        note = "Use url_sanitized() for logging to prevent credential leakage"
    )]
    pub(crate) fn url(&self) -> &str {
        &self.url
    }
    /// URL with any `user:password` section masked; safe to log.
    pub fn url_sanitized(&self) -> String {
        sanitize_url_for_logging(&self.url)
    }
    /// Raw URL for establishing connections; intentionally undocumented in
    /// rustdoc so callers reach for `url_sanitized()` when logging.
    #[doc(hidden)]
    pub(crate) fn url_for_connection(&self) -> &str {
        &self.url
    }
    /// Pool upper bound.
    pub fn max_connections(&self) -> u32 {
        self.max_connections
    }
    /// Pool lower bound.
    pub fn min_connections(&self) -> u32 {
        self.min_connections
    }
    /// Idle timeout in seconds (see `idle_timeout_duration`).
    pub fn idle_timeout(&self) -> u64 {
        self.idle_timeout
    }
    /// Acquire timeout in milliseconds (see `acquire_timeout_duration`).
    pub fn acquire_timeout(&self) -> u64 {
        self.acquire_timeout
    }
    /// Permissions-file path, if configured.
    pub fn permissions_path(&self) -> Option<&str> {
        self.permissions_path.as_deref()
    }
    /// Migrations directory, if configured.
    pub fn migrations_dir(&self) -> Option<&Path> {
        self.migrations_dir.as_deref()
    }
    /// Whether migrations run automatically on startup.
    pub fn auto_migrate(&self) -> bool {
        self.auto_migrate
    }
    /// Migration timeout in seconds (see `migration_timeout_duration`).
    pub fn migration_timeout(&self) -> u64 {
        self.migration_timeout
    }
    /// Administrative role name.
    pub fn admin_role(&self) -> &str {
        &self.admin_role
    }
    /// Warmup timeout (unit presumed seconds — see struct-level note).
    pub fn warmup_timeout(&self) -> u64 {
        self.warmup_timeout
    }
    /// Warmup retry attempts.
    pub fn warmup_retries(&self) -> u32 {
        self.warmup_retries
    }
    // Crate-internal mutators; note none of them re-run validate(), so callers
    // are responsible for keeping the config consistent after mutation.
    pub(crate) fn set_url(&mut self, url: String) {
        self.url = url;
    }
    pub(crate) fn set_max_connections(&mut self, max_connections: u32) {
        self.max_connections = max_connections;
    }
    pub(crate) fn set_min_connections(&mut self, min_connections: u32) {
        self.min_connections = min_connections;
    }
    pub(crate) fn set_idle_timeout(&mut self, idle_timeout: u64) {
        self.idle_timeout = idle_timeout;
    }
    pub(crate) fn set_acquire_timeout(&mut self, acquire_timeout: u64) {
        self.acquire_timeout = acquire_timeout;
    }
    /// Explicitly-named clone; equivalent to `.clone()`.
    pub(crate) fn clone_config(&self) -> Self {
        self.clone()
    }
}
/// Returns a copy of `url` that is safe to log: everything between the scheme
/// separator and the first `'@'` (the `user:password` userinfo section) is
/// replaced with `****`.
///
/// SQLite URLs carry no credentials and are returned unchanged, as is any URL
/// without an `'@'`.
fn sanitize_url_for_logging(url: &str) -> String {
    // SQLite connection strings (file-backed or in-memory, "sqlite:" or
    // "sqlite3:") never contain credentials.
    if url.starts_with("sqlite:") || url.starts_with("sqlite3:") {
        return url.to_string();
    }
    if let Some(at_pos) = url.find('@') {
        // Keep the scheme up to and including "://" (empty when absent).
        let protocol_end = url.find("://").map(|p| p + 3).unwrap_or(0);
        let protocol_part = &url[..protocol_end];
        // Skip the '@' itself; previously it was kept, producing a doubled
        // separator like "postgres://****@@host".
        let host_part = &url[at_pos + 1..];
        format!("{}****@{}", protocol_part, host_part)
    } else {
        url.to_string()
    }
}
/// Default administrative role name.
fn default_admin_role() -> String {
    String::from("admin")
}

/// Default pool upper bound (connections).
fn default_max_connections() -> u32 {
    20
}

/// Default pool lower bound (connections).
fn default_min_connections() -> u32 {
    5
}

/// Default idle timeout, in seconds.
fn default_idle_timeout() -> u64 {
    300
}

/// Default acquire timeout, in milliseconds.
fn default_acquire_timeout() -> u64 {
    5000
}

/// Default migration timeout, in seconds.
fn default_migration_timeout() -> u64 {
    60
}

/// Default warmup timeout.
fn default_warmup_timeout() -> u64 {
    30
}

/// Default number of warmup retries.
fn default_warmup_retries() -> u32 {
    3
}
impl DbConfig {
    /// Loads configuration from environment variables.
    ///
    /// `DATABASE_URL` is required; every other `DB_*` variable is optional and
    /// falls back to the crate default. Strings are length-checked to bound
    /// log/error output and reject pathological input.
    ///
    /// # Errors
    /// Returns `ConfigError::MissingField` when `DATABASE_URL` is unset, and
    /// `ConfigError::InvalidFormat` when a connection/timeout variable fails
    /// to parse or a string exceeds its length limit.
    pub fn from_env() -> Result<Self, ConfigError> {
        const MAX_URL_LENGTH: usize = 2048;
        const MAX_ROLE_LENGTH: usize = 64;
        const MAX_PATH_LENGTH: usize = 512;
        let url = std::env::var("DATABASE_URL").map_err(|_| ConfigError::MissingField("DATABASE_URL"))?;
        if url.len() > MAX_URL_LENGTH {
            return Err(ConfigError::InvalidFormat("URL too long".to_string()));
        }
        let max_connections = std::env::var("DB_MAX_CONNECTIONS")
            .unwrap_or_else(|_| "20".to_string())
            .parse()
            .map_err(|e| ConfigError::InvalidFormat(format!("DB_MAX_CONNECTIONS: {}", e)))?;
        let min_connections = std::env::var("DB_MIN_CONNECTIONS")
            .unwrap_or_else(|_| "5".to_string())
            .parse()
            .map_err(|e| ConfigError::InvalidFormat(format!("DB_MIN_CONNECTIONS: {}", e)))?;
        let idle_timeout = std::env::var("DB_IDLE_TIMEOUT")
            .unwrap_or_else(|_| "300".to_string())
            .parse()
            .map_err(|e| ConfigError::InvalidFormat(format!("DB_IDLE_TIMEOUT: {}", e)))?;
        let acquire_timeout = std::env::var("DB_ACQUIRE_TIMEOUT")
            .unwrap_or_else(|_| "5000".to_string())
            .parse()
            .map_err(|e| ConfigError::InvalidFormat(format!("DB_ACQUIRE_TIMEOUT: {}", e)))?;
        let admin_role = std::env::var("DB_ADMIN_ROLE").unwrap_or_else(|_| "admin".to_string());
        if admin_role.len() > MAX_ROLE_LENGTH {
            return Err(ConfigError::InvalidFormat("admin_role too long".to_string()));
        }
        // Enforce the path limit MAX_PATH_LENGTH was declared for; previously
        // the constant was unused and these env paths went unchecked.
        let permissions_path = std::env::var("DB_PERMISSIONS_PATH").ok();
        if let Some(p) = &permissions_path {
            if p.len() > MAX_PATH_LENGTH {
                return Err(ConfigError::InvalidFormat("permissions_path too long".to_string()));
            }
        }
        let migrations_dir = std::env::var("DB_MIGRATIONS_DIR").ok();
        if let Some(p) = &migrations_dir {
            if p.len() > MAX_PATH_LENGTH {
                return Err(ConfigError::InvalidFormat("migrations_dir too long".to_string()));
            }
        }
        Ok(Self {
            url,
            max_connections,
            min_connections,
            idle_timeout,
            acquire_timeout,
            permissions_path,
            migrations_dir: migrations_dir.map(PathBuf::from),
            // The remaining values are best-effort: a malformed variable
            // silently falls back to its default instead of erroring, unlike
            // the connection/timeout values above.
            auto_migrate: std::env::var("DB_AUTO_MIGRATE")
                .unwrap_or_else(|_| "false".to_string())
                .parse()
                .unwrap_or(false),
            migration_timeout: std::env::var("DB_MIGRATION_TIMEOUT")
                .unwrap_or_else(|_| "60".to_string())
                .parse()
                .unwrap_or(60),
            admin_role,
            warmup_timeout: std::env::var("DB_WARMUP_TIMEOUT")
                .unwrap_or_else(|_| "30".to_string())
                .parse()
                .unwrap_or(30),
            warmup_retries: std::env::var("DB_WARMUP_RETRIES")
                .unwrap_or_else(|_| "3".to_string())
                .parse()
                .unwrap_or(3),
        })
    }

    /// Reads, parses, and validates a YAML configuration file.
    #[cfg(feature = "config-yaml")]
    pub fn from_yaml_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
        let content = std::fs::read_to_string(path.as_ref())?;
        Self::from_yaml_str(&content)
    }

    /// Reads, parses, and validates a TOML configuration file.
    ///
    /// Accepts either a flat document or one nested under a `[database]`
    /// table. The flat layout wins when it parses and validates with a
    /// non-empty URL.
    #[cfg(feature = "config-toml")]
    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
        let content = std::fs::read_to_string(path.as_ref())?;
        // Route the flat layout through from_toml_str so validation is applied
        // uniformly; previously a flat config bypassed validate() entirely.
        if let Ok(config) = Self::from_toml_str(&content) {
            if !config.url.is_empty() {
                return Ok(config);
            }
        }
        // Fallback: the `[database]`-wrapped layout.
        #[derive(Debug, serde::Deserialize)]
        struct ConfigWrapper {
            database: DbConfig,
        }
        let wrapper: ConfigWrapper =
            toml::from_str(&content).map_err(|e| ConfigError::InvalidFormat(format!("TOML parse error: {}", e)))?;
        wrapper.database.validate()?;
        Ok(wrapper.database)
    }

    /// Parses and validates configuration from a YAML string.
    #[cfg(feature = "config-yaml")]
    pub fn from_yaml_str(yaml: &str) -> Result<Self, ConfigError> {
        let config: DbConfig =
            serde_yaml::from_str(yaml).map_err(|e| ConfigError::InvalidFormat(format!("YAML parse error: {}", e)))?;
        config.validate()?;
        Ok(config)
    }

    /// Parses and validates configuration from a TOML string.
    #[cfg(feature = "config-toml")]
    pub fn from_toml_str(toml: &str) -> Result<Self, ConfigError> {
        let config: DbConfig =
            toml::from_str(toml).map_err(|e| ConfigError::InvalidFormat(format!("TOML parse error: {}", e)))?;
        config.validate()?;
        Ok(config)
    }

    /// Checks invariants: a non-empty URL with a supported scheme, and pool
    /// bounds within sane limits (max 1..=1000, min 1..=100, min <= max).
    ///
    /// # Errors
    /// `MissingField` for empty URL / zero max_connections, `ValidationFailed`
    /// for out-of-range pool bounds, `InvalidFormat`/`InvalidUrl`/
    /// `UnsupportedProtocol` for the remaining cases.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.url.is_empty() {
            return Err(ConfigError::MissingField("url"));
        }
        self.validate_url_format()?;
        if self.max_connections == 0 {
            return Err(ConfigError::MissingField("max_connections"));
        }
        if self.max_connections > 1000 {
            return Err(ConfigError::ValidationFailed);
        }
        if self.min_connections == 0 || self.min_connections > 100 {
            return Err(ConfigError::ValidationFailed);
        }
        if self.min_connections > self.max_connections {
            return Err(ConfigError::InvalidFormat(
                "min_connections > max_connections".to_string(),
            ));
        }
        Ok(())
    }

    /// Validates the URL's scheme, host, and port.
    ///
    /// SQLite URLs are accepted verbatim because they are not generally
    /// parseable by `url::Url` (e.g. `sqlite::memory:`).
    fn validate_url_format(&self) -> Result<(), ConfigError> {
        if self.url.starts_with("sqlite:") || self.url.starts_with("sqlite3:") {
            return Ok(());
        }
        let url = url::Url::parse(&self.url).map_err(|_| ConfigError::InvalidUrl("Invalid URL format".to_string()))?;
        let protocol = url.scheme();
        if !protocol
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '.' || c == '-')
        {
            return Err(ConfigError::InvalidUrl(
                "Protocol contains invalid characters".to_string(),
            ));
        }
        // NOTE(review): a former extra clause accepting schemes that both
        // start with "sqlite" AND equal "file"/"mem" was unsatisfiable and
        // has been dropped; the allow-list below is the sole check.
        let valid_protocols = ["sqlite", "sqlite3", "postgres", "postgresql", "mysql"];
        if !valid_protocols.contains(&protocol) {
            return Err(ConfigError::UnsupportedProtocol);
        }
        if let Some(host) = url.host() {
            let host_str = host.to_string();
            // Reject whitespace and shell/SQL metacharacters that could leak
            // into logs or downstream commands.
            if host_str
                .chars()
                .any(|c| c.is_whitespace() || matches!(c, '\'' | '"' | ';' | '|' | '&' | '$' | '`'))
            {
                return Err(ConfigError::InvalidUrl(
                    "Hostname contains invalid characters".to_string(),
                ));
            }
        }
        if let Some(port) = url.port() {
            if port == 0 {
                return Err(ConfigError::InvalidUrl(
                    "Port number out of valid range (1-65535)".to_string(),
                ));
            }
        }
        Ok(())
    }

    /// Idle timeout as a `Duration` (field is seconds).
    pub fn idle_timeout_duration(&self) -> Duration {
        Duration::from_secs(self.idle_timeout)
    }

    /// Acquire timeout as a `Duration` (field is milliseconds).
    pub fn acquire_timeout_duration(&self) -> Duration {
        Duration::from_millis(self.acquire_timeout)
    }

    /// Migration timeout as a `Duration` (field is seconds).
    pub fn migration_timeout_duration(&self) -> Duration {
        Duration::from_secs(self.migration_timeout)
    }

    /// Serializes this configuration to YAML.
    #[cfg(feature = "config-yaml")]
    pub fn to_yaml(&self) -> Result<String, ConfigError> {
        serde_yaml::to_string(self).map_err(|e| ConfigError::InvalidFormat(format!("YAML serialize error: {}", e)))
    }

    /// Serializes this configuration to TOML.
    #[cfg(feature = "config-toml")]
    pub fn to_toml(&self) -> Result<String, ConfigError> {
        toml::to_string(self).map_err(|e| ConfigError::InvalidFormat(format!("TOML serialize error: {}", e)))
    }

    /// Searches well-known locations (working directory, `config/`, then the
    /// user's home directory) and loads the first config file that passes
    /// [`Self::is_safe_config_path`].
    ///
    /// # Errors
    /// `ConfigError::FileNotFound` when no candidate exists; otherwise
    /// whatever the chosen parser returns.
    pub fn from_config_files() -> Result<Self, ConfigError> {
        #[cfg(all(feature = "config-yaml", feature = "config-toml"))]
        {
            let config_paths = [
                "dbnexus.yaml",
                "dbnexus.toml",
                "config/dbnexus.yaml",
                "config/dbnexus.toml",
            ];
            for config_path in &config_paths {
                let path = Path::new(config_path);
                if Self::is_safe_config_path(path)? {
                    tracing::info!("Loading configuration from: {}", config_path);
                    if config_path.ends_with(".yaml") || config_path.ends_with(".yml") {
                        return Self::from_yaml_file(path);
                    } else {
                        return Self::from_toml_file(path);
                    }
                }
            }
            if let Some(home_dir) = home::home_dir() {
                let user_config_paths = [
                    home_dir.join(".config").join("dbnexus").join("config.yaml"),
                    home_dir.join(".dbnexus").join("config.toml"),
                ];
                for config_path in &user_config_paths {
                    if Self::is_safe_config_path(config_path)? {
                        tracing::info!("Loading configuration from: {}", config_path.display());
                        // BUGFIX: `Path::ends_with(".yaml")` compares whole
                        // path components and never matched "config.yaml",
                        // sending YAML files to the TOML parser. Inspect the
                        // file extension instead.
                        let is_yaml = matches!(
                            config_path.extension().and_then(|e| e.to_str()),
                            Some("yaml") | Some("yml")
                        );
                        if is_yaml {
                            return Self::from_yaml_file(config_path);
                        } else {
                            return Self::from_toml_file(config_path);
                        }
                    }
                }
            }
        }
        #[cfg(all(feature = "config-yaml", not(feature = "config-toml")))]
        {
            let config_paths = ["dbnexus.yaml", "config/dbnexus.yaml"];
            for config_path in &config_paths {
                let path = Path::new(config_path);
                if Self::is_safe_config_path(path)? {
                    tracing::info!("Loading configuration from: {}", config_path);
                    return Self::from_yaml_file(path);
                }
            }
            if let Some(home_dir) = home::home_dir() {
                let user_config_paths = [home_dir.join(".config").join("dbnexus").join("config.yaml")];
                for config_path in &user_config_paths {
                    if Self::is_safe_config_path(config_path)? {
                        tracing::info!("Loading configuration from: {}", config_path.display());
                        return Self::from_yaml_file(config_path);
                    }
                }
            }
        }
        #[cfg(all(not(feature = "config-yaml"), feature = "config-toml"))]
        {
            let config_paths = ["dbnexus.toml", "config/dbnexus.toml"];
            for config_path in &config_paths {
                let path = Path::new(config_path);
                if Self::is_safe_config_path(path)? {
                    tracing::info!("Loading configuration from: {}", config_path);
                    return Self::from_toml_file(path);
                }
            }
            if let Some(home_dir) = home::home_dir() {
                let user_config_paths = [home_dir.join(".dbnexus").join("config.toml")];
                for config_path in &user_config_paths {
                    if Self::is_safe_config_path(config_path)? {
                        tracing::info!("Loading configuration from: {}", config_path.display());
                        return Self::from_toml_file(config_path);
                    }
                }
            }
        }
        Err(ConfigError::FileNotFound)
    }

    /// Returns `Ok(true)` only when `path` is an existing regular file that is
    /// safe to read as configuration: no null bytes, no traversal sequences,
    /// not a symlink, not a directory, and not inside a system directory.
    ///
    /// Rejections are logged and reported as `Ok(false)` so the caller can
    /// keep probing other candidate locations.
    fn is_safe_config_path(path: &Path) -> Result<bool, ConfigError> {
        let path_str = path.to_string_lossy();
        if path_str.contains('\0') {
            tracing::warn!("Rejected config path with null byte: {:?}", path);
            return Ok(false);
        }
        if path_str.contains("..") {
            tracing::warn!("Rejected config path with parent directory traversal: {:?}", path);
            return Ok(false);
        }
        // `contains` subsumes the previous redundant `starts_with(".\\")`.
        if path_str.contains(".\\") {
            tracing::warn!("Rejected config path with Windows-style traversal: {:?}", path);
            return Ok(false);
        }
        // Canonicalization also implies existence: missing files are rejected.
        let canonical = match path.canonicalize() {
            Ok(p) => p,
            Err(e) => {
                tracing::warn!("Failed to canonicalize config path {:?}: {}", path, e);
                return Ok(false);
            }
        };
        if canonical.is_absolute() {
            let forbidden_prefixes = [
                "/etc", "/usr", "/var", "/root", "/boot", "/srv", "/opt", "/bin", "/sbin", "/lib", "/lib64",
            ];
            for prefix in &forbidden_prefixes {
                if canonical.starts_with(prefix) {
                    tracing::warn!("Rejected config path in system directory: {:?}", path);
                    return Ok(false);
                }
            }
        }
        // Checked on the original (pre-canonicalization) path, since
        // canonicalize() resolves symlinks away.
        if path.is_symlink() {
            tracing::warn!("Rejected symlink config path: {:?}", path);
            return Ok(false);
        }
        // Defense in depth; canonical paths should not contain "..".
        if canonical.to_string_lossy().contains("..") {
            tracing::warn!(
                "Rejected config path with hidden traversal after canonicalization: {:?}",
                path
            );
            return Ok(false);
        }
        if canonical.is_dir() {
            tracing::warn!("Rejected config path pointing to directory: {:?}", path);
            return Ok(false);
        }
        Ok(true)
    }
}
/// Stateless helpers that repair questionable [`DbConfig`] values instead of
/// rejecting them outright.
#[derive(Debug, Clone)]
pub struct ConfigCorrector;
impl ConfigCorrector {
    /// Returns an estimated server-side connection limit for the backend.
    ///
    /// NOTE(review): despite the name, the SQL statements here are only
    /// executed as a connectivity probe — their results are discarded and
    /// fixed estimates are returned (100 for PostgreSQL, 200 for MySQL,
    /// `u32::MAX` for SQLite). Confirm whether reading the real server value
    /// was intended.
    pub(crate) async fn query_database_max_connections(
        connection: &sea_orm::DatabaseConnection,
        db_type: DatabaseType,
    ) -> u32 {
        // Suppresses the unused-parameter warning when neither the
        // "postgres" nor "mysql" feature is enabled.
        let _ = connection;
        match db_type {
            DatabaseType::Postgres => {
                #[cfg(feature = "postgres")]
                {
                    let result = connection.execute_unprepared("SHOW max_connections").await;
                    match result {
                        Ok(result) => {
                            let rows_affected = result.rows_affected();
                            if rows_affected > 0 {
                                tracing::info!(
                                    "PostgreSQL max_connections query executed, using conservative estimate"
                                );
                            }
                        }
                        Err(e) => {
                            // Probe failure is non-fatal; fall through to the estimate.
                            tracing::warn!("Failed to query PostgreSQL max_connections: {}", e);
                        }
                    }
                    100
                }
                #[cfg(not(feature = "postgres"))]
                {
                    100
                }
            }
            DatabaseType::MySql => {
                #[cfg(feature = "mysql")]
                {
                    let result = connection
                        .execute_unprepared("SHOW VARIABLES LIKE 'max_connections'")
                        .await;
                    match result {
                        Ok(_) => {
                            tracing::info!("MySQL max_connections query executed, using conservative estimate");
                        }
                        Err(e) => {
                            tracing::warn!("Failed to query MySQL max_connections: {}", e);
                        }
                    }
                    200
                }
                #[cfg(not(feature = "mysql"))]
                {
                    200
                }
            }
            DatabaseType::Sqlite => {
                // SQLite has no server-side connection limit.
                u32::MAX
            }
        }
    }
    /// Returns a copy of `config` with out-of-range values clamped to safe
    /// defaults, logging each correction. Never fails.
    pub(crate) fn auto_correct(mut config: DbConfig) -> DbConfig {
        // Order matters: min>max is resolved before the zero checks, so e.g.
        // (min=5, max=0) becomes (min=1, max=10) rather than (min=5, max=10).
        if config.min_connections > config.max_connections {
            tracing::warn!(
                "Correcting min_connections ({}) > max_connections ({}), setting min to max",
                config.min_connections(),
                config.max_connections()
            );
            config.min_connections = config.max_connections;
        }
        if config.min_connections == 0 {
            config.min_connections = 1;
            tracing::warn!("Correcting min_connections from 0 to 1");
        }
        if config.max_connections == 0 {
            config.max_connections = 10;
            tracing::warn!("Correcting max_connections from 0 to 10");
        }
        // acquire_timeout (ms): 0 -> default 5000, otherwise clamped to
        // [1000, 60000].
        if config.acquire_timeout == 0 {
            config.acquire_timeout = 5000;
        } else if config.acquire_timeout < 1000 {
            tracing::warn!(
                "Adjusting acquire_timeout from {}ms to minimum 1000ms",
                config.acquire_timeout()
            );
            config.acquire_timeout = 1000;
        } else if config.acquire_timeout > 60000 {
            tracing::warn!(
                "Adjusting acquire_timeout from {}ms to maximum 60000ms",
                config.acquire_timeout()
            );
            config.acquire_timeout = 60000;
        }
        // idle_timeout (s): 0 -> default 300, otherwise clamped to [30, 3600].
        if config.idle_timeout == 0 {
            config.idle_timeout = 300;
        } else if config.idle_timeout < 30 {
            tracing::warn!("Adjusting idle_timeout from {}s to minimum 30s", config.idle_timeout());
            config.idle_timeout = 30;
        } else if config.idle_timeout > 3600 {
            tracing::warn!(
                "Adjusting idle_timeout from {}s to maximum 3600s",
                config.idle_timeout()
            );
            config.idle_timeout = 3600;
        }
        // Append a connect timeout to plain localhost URLs that carry no
        // query string yet.
        // NOTE(review): the match arms only handle "mysql://" and
        // "postgres://" — a "postgresql://" URL enters this branch but falls
        // through unchanged; confirm whether that scheme should be covered.
        if config.url.starts_with("mysql") || config.url.starts_with("postgres") {
            if config.url.contains("localhost") && !config.url.contains("?") && !config.url.contains(";") {
                match config.url.as_str() {
                    url if url.starts_with("mysql://") => {
                        config.url = format!("{}?connect_timeout=10", url);
                    }
                    url if url.starts_with("postgres://") => {
                        config.url = format!("{}?connect_timeout=10", url);
                    }
                    _ => {}
                }
            }
        }
        config
    }
    /// Validates `config` without mutating it, collecting every problem into
    /// a list of human-readable messages.
    pub(crate) fn validate_config(config: &DbConfig) -> Result<(), Vec<String>> {
        let mut errors = Vec::new();
        if config.url.is_empty() {
            errors.push("Database URL cannot be empty".to_string());
        }
        if config.max_connections() == 0 {
            errors.push("max_connections must be greater than 0".to_string());
        }
        if config.min_connections() > config.max_connections() {
            errors.push("min_connections cannot be greater than max_connections".to_string());
        }
        if config.acquire_timeout() == 0 {
            errors.push("acquire_timeout must be greater than 0".to_string());
        }
        if config.idle_timeout() == 0 {
            errors.push("idle_timeout must be greater than 0".to_string());
        }
        if errors.is_empty() { Ok(()) } else { Err(errors) }
    }
    /// Convenience: [`DbConfig::from_env`] followed by [`Self::auto_correct`].
    pub(crate) fn load_and_correct_from_env() -> Result<DbConfig, ConfigError> {
        let mut config = DbConfig::from_env()?;
        config = ConfigCorrector::auto_correct(config);
        Ok(config)
    }
    /// Convenience: [`DbConfig::from_yaml_file`] followed by
    /// [`Self::auto_correct`].
    #[cfg(feature = "config-yaml")]
    pub(crate) fn load_and_correct_from_file(path: impl AsRef<Path>) -> Result<DbConfig, ConfigError> {
        let mut config = DbConfig::from_yaml_file(path)?;
        config = ConfigCorrector::auto_correct(config);
        Ok(config)
    }
    /// Validates the original config and returns a corrected copy; when
    /// validation failed, the error list is annotated to say corrections were
    /// applied (the corrected copy itself is discarded in that case).
    pub(crate) fn validate_and_correct(config: &DbConfig) -> Result<DbConfig, Vec<String>> {
        let errors = Self::validate_config(config);
        let corrected_config = Self::auto_correct(config.clone());
        match errors {
            Ok(()) => Ok(corrected_config),
            Err(mut validation_errors) => {
                validation_errors.extend([
                    "Some configuration values were automatically corrected".to_string(),
                    "Consider updating your configuration file to match corrected values".to_string(),
                ]);
                Err(validation_errors)
            }
        }
    }
    /// The configuration as it would actually be used after auto-correction.
    pub(crate) fn get_actual_config(config: &DbConfig) -> DbConfig {
        Self::auto_correct(config.clone())
    }
    /// Caps the pool size at 80% of the (estimated) server connection limit,
    /// then re-establishes `min <= max`.
    pub(crate) async fn auto_correct_with_database_capability(
        mut config: DbConfig,
        connection: &sea_orm::DatabaseConnection,
        db_type: DatabaseType,
    ) -> DbConfig {
        let db_max_connections = Self::query_database_max_connections(connection, db_type).await;
        // Leave 20% headroom for other clients / admin sessions.
        let recommended_max = (db_max_connections as f64 * 0.8).floor() as u32;
        if config.max_connections() > recommended_max {
            tracing::warn!(
                "Config corrected: max_connections {} -> {} (80% of database limit {})",
                config.max_connections(),
                recommended_max,
                db_max_connections
            );
            config.max_connections = recommended_max;
        }
        if config.min_connections() > config.max_connections() {
            tracing::warn!(
                "Config corrected: min_connections {} -> {} (equal to max_connections)",
                config.min_connections(),
                config.max_connections()
            );
            config.min_connections = config.max_connections();
        }
        config
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Documents that the derived Default zero-initializes fields: the
    // serde `default = "..."` attributes do NOT apply to `DbConfig::default()`.
    #[test]
    fn test_default_config_values() {
        let config = DbConfig::default();
        assert_eq!(config.url_sanitized(), "");
        assert_eq!(config.max_connections(), 0);
        assert_eq!(config.min_connections(), 0);
        assert_eq!(config.idle_timeout(), 0);
        assert_eq!(config.acquire_timeout(), 0);
        assert!(config.permissions_path().is_none());
    }
    // idle_timeout is interpreted as seconds, acquire_timeout as milliseconds.
    #[test]
    fn test_config_duration_conversion() {
        let config = DbConfigBuilder::new()
            .url("sqlite::memory:")
            .max_connections(10)
            .min_connections(2)
            .idle_timeout(300)
            .acquire_timeout(5000)
            .admin_role("admin")
            .build()
            .unwrap();
        assert_eq!(config.idle_timeout_duration(), Duration::from_secs(300));
        assert_eq!(config.acquire_timeout_duration(), Duration::from_millis(5000));
    }
    // A config mutated into an invalid state (min > max via the crate-internal
    // setter) is repaired by auto-correction: min is clamped down to max.
    #[test]
    fn test_get_actual_config() {
        let mut config = DbConfigBuilder::new()
            .url("sqlite::memory:")
            .max_connections(10)
            .min_connections(10)
            .admin_role("admin")
            .build()
            .unwrap();
        config.set_min_connections(30);
        let actual = ConfigCorrector::get_actual_config(&config);
        assert_eq!(actual.max_connections(), 10);
        assert_eq!(actual.min_connections(), 10);
    }
    // All-zero values are replaced with auto_correct's fallback defaults
    // (max 10, min 1, idle 300 s, acquire 5000 ms).
    #[test]
    fn test_get_actual_config_zero_values() {
        let config = DbConfigBuilder::new()
            .url("sqlite::memory:")
            .max_connections(5)
            .min_connections(5)
            .idle_timeout(0)
            .acquire_timeout(0)
            .admin_role("admin")
            .build()
            .unwrap();
        let mut zero_config = config.clone();
        zero_config.set_max_connections(0);
        zero_config.set_min_connections(0);
        zero_config.set_idle_timeout(0);
        zero_config.set_acquire_timeout(0);
        let actual = ConfigCorrector::get_actual_config(&zero_config);
        assert_eq!(actual.max_connections(), 10);
        assert_eq!(actual.min_connections(), 1);
        assert_eq!(actual.idle_timeout(), 300);
        assert_eq!(actual.acquire_timeout(), 5000);
    }
    // Explicitly-set builder fields are passed through unchanged.
    #[test]
    fn test_config_builder_basic() {
        let config = DbConfigBuilder::new()
            .url("sqlite::memory:")
            .max_connections(20)
            .min_connections(5)
            .build()
            .unwrap();
        assert_eq!(config.url_sanitized(), "sqlite::memory:");
        assert_eq!(config.max_connections(), 20);
        assert_eq!(config.min_connections(), 5);
    }
    // Every builder setter round-trips through the built config's accessors.
    #[test]
    fn test_config_builder_all_fields() {
        let config = DbConfigBuilder::new()
            .url("sqlite::memory:")
            .max_connections(20)
            .min_connections(5)
            .idle_timeout(300)
            .acquire_timeout(5000)
            .permissions_path("/etc/dbnexus/permissions.yaml")
            .auto_migrate(true)
            .admin_role("superuser")
            .build()
            .unwrap();
        assert_eq!(config.url_sanitized(), "sqlite::memory:");
        assert_eq!(config.max_connections(), 20);
        assert_eq!(config.min_connections(), 5);
        assert_eq!(config.idle_timeout(), 300);
        assert_eq!(config.acquire_timeout(), 5000);
        assert_eq!(config.permissions_path(), Some("/etc/dbnexus/permissions.yaml"));
        assert!(config.auto_migrate());
        assert_eq!(config.admin_role(), "superuser");
    }
    // build() rejects min_connections > max_connections.
    #[test]
    fn test_config_builder_validation_failure() {
        let result = DbConfigBuilder::new()
            .url("sqlite::memory:")
            .max_connections(10)
            .min_connections(20)
            .build();
        assert!(result.is_err());
    }
    // Unset builder fields receive the crate defaults (default_* functions).
    #[test]
    fn test_config_builder_defaults() {
        let config = DbConfigBuilder::new().url("sqlite::memory:").build().unwrap();
        assert_eq!(config.max_connections(), 20);
        assert_eq!(config.min_connections(), 5);
        assert_eq!(config.idle_timeout(), 300);
        assert_eq!(config.acquire_timeout(), 5000);
        assert_eq!(config.admin_role(), "admin");
    }
    // YAML deserialization honors explicit fields and passes validation.
    #[cfg(feature = "config-yaml")]
    #[test]
    fn test_config_loader() {
        let yaml = r#"
url: "sqlite::memory:"
max_connections: 20
min_connections: 5
"#;
        let config = DbConfig::from_yaml_str(yaml).unwrap();
        {
            assert_eq!(config.url_sanitized(), "sqlite::memory:");
            assert_eq!(config.max_connections(), 20);
        }
    }
    // An unset URL fails validation with the expected error message.
    #[test]
    fn test_config_validation_empty_url() {
        let config = DbConfigBuilder::new().build().unwrap_err();
        assert_eq!(config.to_string(), "Missing required field: url");
    }
    // max_connections == 0 is rejected by validate().
    #[test]
    fn test_config_validation_invalid_connections() {
        let result = DbConfigBuilder::new().url("sqlite::memory:").max_connections(0).build();
        assert!(result.is_err());
    }
}