use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::error::{ForgeError, Result};
use super::default_true;
use super::types::DurationStr;
/// Settings for the primary PostgreSQL database and its optional replicas,
/// read from the `database.*` configuration keys.
///
/// Unknown keys are rejected (`deny_unknown_fields`), so stale keys such as the
/// legacy `[pools.*]` tables fail to parse instead of being silently ignored.
/// The struct is `#[non_exhaustive]`; code outside this crate should build it via
/// [`DatabaseConfig::new`] or [`Default`].
///
/// `Debug` is implemented by hand so that passwords embedded in connection
/// strings are masked when the config is logged.
#[derive(Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
#[non_exhaustive]
pub struct DatabaseConfig {
    /// Primary PostgreSQL connection string, e.g.
    /// `postgres://user:pass@localhost/mydb`. Required: `validate` rejects an
    /// empty value.
    #[serde(default)]
    pub url: String,
    /// Size of the primary connection pool (default: 100).
    /// NOTE(review): presumably the maximum number of open connections — confirm
    /// against the pool builder.
    #[serde(default = "default_pool_size")]
    pub pool_size: u32,
    /// Pool timeout (default: 30 seconds).
    /// NOTE(review): presumably the connection-acquire timeout — confirm.
    #[serde(default = "default_pool_timeout")]
    pub pool_timeout: DurationStr,
    /// Statement timeout (default: 30 seconds).
    #[serde(default = "default_statement_timeout")]
    pub statement_timeout: DurationStr,
    /// Connection strings for replica databases (default: none).
    #[serde(default)]
    pub replica_urls: Vec<String>,
    /// Whether reads should be served from replicas (default: `false`).
    #[serde(default)]
    pub read_from_replica: bool,
    /// Pool size for replica connections (default: `None`).
    /// NOTE(review): presumably falls back to `pool_size` when unset — confirm.
    #[serde(default)]
    pub replica_pool_size: Option<u32>,
    /// Minimum pool size (default: 0).
    #[serde(default)]
    pub min_pool_size: u32,
    /// Whether a connection is tested before it is handed out (default: `true`).
    /// NOTE(review): semantics inferred from the name — confirm with the pool setup.
    #[serde(default = "default_true")]
    pub test_before_acquire: bool,
}

impl std::fmt::Debug for DatabaseConfig {
    /// Formats every field like `#[derive(Debug)]` would, except that the
    /// password part of `url` and of each `replica_urls` entry is replaced by `***`.
    ///
    /// Only the `scheme://user:password@host` userinfo form is masked; a password
    /// passed as a query parameter (e.g. `?password=...`) is printed as-is.
    /// Keep the field list in sync with the struct definition.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let replica_urls: Vec<String> = self
            .replica_urls
            .iter()
            .map(String::as_str)
            .map(redact_url_password)
            .collect();
        f.debug_struct("DatabaseConfig")
            .field("url", &redact_url_password(&self.url))
            .field("pool_size", &self.pool_size)
            .field("pool_timeout", &self.pool_timeout)
            .field("statement_timeout", &self.statement_timeout)
            .field("replica_urls", &replica_urls)
            .field("read_from_replica", &self.read_from_replica)
            .field("replica_pool_size", &self.replica_pool_size)
            .field("min_pool_size", &self.min_pool_size)
            .field("test_before_acquire", &self.test_before_acquire)
            .finish()
    }
}

/// Returns `url` with its userinfo password replaced by `***`, or `url`
/// unchanged when it has no `scheme://user:password@` component.
fn redact_url_password(url: &str) -> String {
    mask_userinfo_password(url).unwrap_or_else(|| url.to_owned())
}

/// Rebuilds `url` with a masked password; `None` when there is nothing to mask.
fn mask_userinfo_password(url: &str) -> Option<String> {
    let (scheme, rest) = url.split_once("://")?;
    // The authority ends at the first path, query, or fragment delimiter.
    let (authority, tail) = match rest.find(|c: char| matches!(c, '/' | '?' | '#')) {
        Some(idx) => rest.split_at(idx),
        None => (rest, ""),
    };
    // Split on the LAST '@' so a stray unescaped '@' inside the password is
    // still fully masked.
    let (userinfo, host) = authority.rsplit_once('@')?;
    let (user, _password) = userinfo.split_once(':')?;
    Some(format!("{scheme}://{user}:***@{host}{tail}"))
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
url: String::new(),
pool_size: default_pool_size(),
pool_timeout: default_pool_timeout(),
statement_timeout: default_statement_timeout(),
replica_urls: Vec::new(),
read_from_replica: false,
replica_pool_size: None,
min_pool_size: 0,
test_before_acquire: true,
}
}
}
impl DatabaseConfig {
    /// Creates a configuration for the given primary connection string, with
    /// every other field set to its [`Default`] value.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Default::default()
        }
    }

    /// Returns the primary connection string.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Checks that the configuration is usable before any connection is attempted.
    ///
    /// # Errors
    ///
    /// Returns a configuration error (built with `ForgeError::config`) when:
    /// - `url` is empty or whitespace-only;
    /// - `pool_size` is `0`;
    /// - `min_pool_size` is greater than `pool_size`;
    /// - `replica_pool_size` is `Some(0)`;
    /// - any entry of `replica_urls` is empty or whitespace-only.
    pub fn validate(&self) -> Result<()> {
        // A whitespace-only string is as unusable as an empty one.
        if self.url.trim().is_empty() {
            return Err(ForgeError::config(
                "database.url is required. \
                 Set database.url to a PostgreSQL connection string \
                 (e.g., \"postgres://user:pass@localhost/mydb\").",
            ));
        }
        // A zero-sized pool leaves no connection to hand out.
        if self.pool_size == 0 {
            return Err(ForgeError::config(
                "database.pool_size must be at least 1.",
            ));
        }
        if self.min_pool_size > self.pool_size {
            return Err(ForgeError::config(
                "database.min_pool_size must not be greater than database.pool_size.",
            ));
        }
        if self.replica_pool_size == Some(0) {
            return Err(ForgeError::config(
                "database.replica_pool_size must be at least 1 when set.",
            ));
        }
        if self.replica_urls.iter().any(|u| u.trim().is_empty()) {
            return Err(ForgeError::config(
                "database.replica_urls must not contain empty entries.",
            ));
        }
        Ok(())
    }
}
/// Default for `pool_size` (serde default and `DatabaseConfig::default`): 100.
fn default_pool_size() -> u32 {
    const DEFAULT_POOL_SIZE: u32 = 100;
    DEFAULT_POOL_SIZE
}
/// Default for `pool_timeout` (serde default and `DatabaseConfig::default`): 30 seconds.
fn default_pool_timeout() -> DurationStr {
    let thirty_seconds = Duration::from_millis(30_000);
    DurationStr::new(thirty_seconds)
}
/// Default for `statement_timeout` (serde default and `DatabaseConfig::default`): 30 seconds.
fn default_statement_timeout() -> DurationStr {
    let thirty_seconds = Duration::from_millis(30_000);
    DurationStr::new(thirty_seconds)
}
#[cfg(test)]
// Tests may unwrap/index freely: a panic here is simply a failed test.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_default_database_config() {
        let config = DatabaseConfig::default();
        assert!(config.url.is_empty());
        assert_eq!(config.pool_size, 100);
        assert_eq!(config.pool_timeout.as_secs(), 30);
        assert_eq!(config.statement_timeout.as_secs(), 30);
        assert!(config.replica_urls.is_empty());
        assert!(!config.read_from_replica);
        assert_eq!(config.replica_pool_size, None);
        assert_eq!(config.min_pool_size, 0);
        assert!(config.test_before_acquire);
    }

    /// An empty document must produce exactly what `Default` produces, so the
    /// `#[serde(default = ...)]` attributes and `impl Default` cannot drift apart.
    #[test]
    fn test_serde_defaults_match_default_impl() {
        let parsed: DatabaseConfig = toml::from_str("").unwrap();
        let expected = DatabaseConfig::default();
        assert_eq!(parsed.url, expected.url);
        assert_eq!(parsed.pool_size, expected.pool_size);
        assert_eq!(parsed.pool_timeout.as_secs(), expected.pool_timeout.as_secs());
        assert_eq!(
            parsed.statement_timeout.as_secs(),
            expected.statement_timeout.as_secs()
        );
        assert_eq!(parsed.replica_urls, expected.replica_urls);
        assert_eq!(parsed.read_from_replica, expected.read_from_replica);
        assert_eq!(parsed.replica_pool_size, expected.replica_pool_size);
        assert_eq!(parsed.min_pool_size, expected.min_pool_size);
        assert_eq!(parsed.test_before_acquire, expected.test_before_acquire);
    }

    #[test]
    fn test_new_config() {
        let config = DatabaseConfig::new("postgres://localhost/test");
        assert_eq!(config.url(), "postgres://localhost/test");
    }

    #[test]
    fn test_parse_config() {
        // Every numeric value differs from its default, so each assertion
        // proves the value was actually read from the document.
        let toml = r#"
            url = "postgres://localhost/test"
            pool_size = 25
            min_pool_size = 2
            replica_pool_size = 4
            replica_urls = ["postgres://replica1/test", "postgres://replica2/test"]
            read_from_replica = true
        "#;
        let config: DatabaseConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.url(), "postgres://localhost/test");
        assert_eq!(config.pool_size, 25);
        assert_eq!(config.min_pool_size, 2);
        assert_eq!(config.replica_pool_size, Some(4));
        assert_eq!(
            config.replica_urls,
            vec!["postgres://replica1/test", "postgres://replica2/test"]
        );
        assert!(config.read_from_replica);
    }

    #[test]
    fn test_validate_with_url() {
        let config = DatabaseConfig::new("postgres://localhost/test");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_empty_url() {
        let config = DatabaseConfig::default();
        let result = config.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    #[test]
    fn test_rejects_legacy_pools_blocks() {
        let toml = r#"
            url = "postgres://localhost/test"
            [pools.jobs]
            size = 10
        "#;
        let err = toml::from_str::<DatabaseConfig>(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("unknown field"),
            "expected unknown-field error, got: {msg}"
        );
    }
}