use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
/// Database engine that a [`DatabaseConfig`] targets.
///
/// Serialized and deserialized as a lowercase name: `postgres`, `mysql`, `mariadb`, `sqlite`
/// or `mongodb`. `Postgres` is the [`Default`] variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum DbSystem {
    #[default]
    #[serde(rename = "postgres")]
    Postgres,
    #[serde(rename = "mysql")]
    MySql,
    #[serde(rename = "mariadb")]
    MariaDb,
    #[serde(rename = "sqlite")]
    Sqlite,
    #[serde(rename = "mongodb")]
    MongoDb,
}
impl DbSystem {
pub fn default_port(&self) -> Option<u16> {
match self {
DbSystem::Postgres => Some(5432),
DbSystem::MySql | DbSystem::MariaDb => Some(3306),
DbSystem::MongoDb => Some(27017),
DbSystem::Sqlite => None,
}
}
pub fn scheme(&self) -> &'static str {
match self {
DbSystem::Postgres => "postgres",
DbSystem::MySql | DbSystem::MariaDb => "mysql",
DbSystem::Sqlite => "sqlite",
DbSystem::MongoDb => "mongodb",
}
}
pub fn is_relational(&self) -> bool {
!matches!(self, DbSystem::MongoDb)
}
}
/// Connection-pool settings.
///
/// Fields missing from a config document take the same values as [`PoolConfig::default`]:
/// 10 max connections, 0 min connections, an acquire timeout of 30, and no idle/lifetime limits.
/// How these numbers are enforced is up to whatever pool consumes this config (not visible here).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Maximum pool size; defaults to 10 when omitted.
    #[serde(default = "PoolConfig::default_max_connections")]
    pub max_connections: u32,
    /// Minimum pool size; defaults to 0 when omitted.
    #[serde(default)]
    pub min_connections: u32,
    /// Acquire timeout (seconds, per the field name); defaults to 30 when omitted.
    #[serde(default = "PoolConfig::default_acquire_timeout_secs")]
    pub acquire_timeout_secs: u64,
    /// Optional idle timeout in seconds; `None` (the default) means no limit is configured.
    #[serde(default)]
    pub idle_timeout_secs: Option<u64>,
    /// Optional maximum connection lifetime in seconds; `None` (the default) means unset.
    #[serde(default)]
    pub max_lifetime_secs: Option<u64>,
}
impl PoolConfig {
    /// Pool size used when `max_connections` is not configured.
    const DEFAULT_MAX_CONNECTIONS: u32 = 10;
    /// Acquire timeout used when `acquire_timeout_secs` is not configured.
    const DEFAULT_ACQUIRE_TIMEOUT_SECS: u64 = 30;

    /// Serde field default for `max_connections` (also used by `Default`).
    fn default_max_connections() -> u32 {
        Self::DEFAULT_MAX_CONNECTIONS
    }

    /// Serde field default for `acquire_timeout_secs` (also used by `Default`).
    fn default_acquire_timeout_secs() -> u64 {
        Self::DEFAULT_ACQUIRE_TIMEOUT_SECS
    }
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_connections: Self::default_max_connections(),
min_connections: 0,
acquire_timeout_secs: Self::default_acquire_timeout_secs(),
idle_timeout_secs: None,
max_lifetime_secs: None,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DatabaseConfig {
#[serde(default)]
pub system: DbSystem,
#[serde(default = "default_host")]
pub host: String,
#[serde(default)]
pub port: Option<u16>,
#[serde(default)]
pub database: String,
#[serde(default)]
pub username: Option<String>,
#[serde(default)]
pub password: Option<String>,
#[serde(default)]
pub url: Option<String>,
#[serde(default)]
pub options: BTreeMap<String, String>,
#[serde(default)]
pub pool: PoolConfig,
}
/// Serde default for [`DatabaseConfig::host`]: connect to the local machine.
fn default_host() -> String {
    String::from("localhost")
}
impl DatabaseConfig {
    /// Returns the port to connect on.
    ///
    /// An explicit `port` wins. Otherwise this falls back to [`DbSystem::default_port`], which
    /// is `None` for SQLite.
    pub fn effective_port(&self) -> Option<u16> {
        self.port.or_else(|| self.system.default_port())
    }

    /// Builds the connection URL for this configuration.
    ///
    /// * If `url` is set, it is returned verbatim and every component field is ignored.
    /// * SQLite: `sqlite://{database}`. `database` is inserted unencoded as a file path, so an
    ///   absolute path yields `sqlite:///abs/path`.
    /// * Other systems: `{scheme}://[user[:password]@]host[:port]/database`.
    ///   - The password is only emitted when a username is set.
    ///   - The port comes from [`Self::effective_port`].
    ///
    /// For every system, non-empty `options` are appended as a `?key=value&...` query string in
    /// key order.
    ///
    /// The username, password, database name (non-SQLite only) and option keys/values are
    /// percent-encoded, so supply them raw, not pre-encoded. Unencoded, characters such as `@`,
    /// `:`, `/`, `?`, `&` or `#` would corrupt the URL structure. `host` is inserted as-is.
    pub fn connection_url(&self) -> String {
        if let Some(url) = &self.url {
            return url.clone();
        }
        let scheme = self.system.scheme();
        let mut url = if self.system == DbSystem::Sqlite {
            // A filesystem path, not a URL path segment: keep its separators intact.
            format!("{scheme}://{}", self.database)
        } else {
            let mut base = format!("{scheme}://");
            if let Some(user) = &self.username {
                base.push_str(&encode_url_component(user));
                if let Some(password) = &self.password {
                    base.push(':');
                    base.push_str(&encode_url_component(password));
                }
                base.push('@');
            }
            base.push_str(&self.host);
            if let Some(port) = self.effective_port() {
                base.push(':');
                base.push_str(&port.to_string());
            }
            base.push('/');
            base.push_str(&encode_url_component(&self.database));
            base
        };
        if !self.options.is_empty() {
            let query = self
                .options
                .iter()
                .map(|(k, v)| format!("{}={}", encode_url_component(k), encode_url_component(v)))
                .collect::<Vec<_>>()
                .join("&");
            url.push('?');
            url.push_str(&query);
        }
        url
    }
}

/// Percent-encodes `raw` for use inside a single URL component (RFC 3986 section 2.1).
///
/// Only the unreserved set `A-Z a-z 0-9 - . _ ~` passes through unchanged. Every other byte of
/// the UTF-8 encoding becomes `%XX` (uppercase hex). This is safe in userinfo, path segments
/// and query keys/values alike.
fn encode_url_component(raw: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a `DatabaseConfig` fixture from a JSON value.
    fn config_from(value: serde_json::Value) -> DatabaseConfig {
        serde_json::from_value(value).unwrap()
    }

    #[test]
    fn deserializes_with_defaults() {
        let cfg = config_from(json!({
            "system": "postgres",
            "database": "app",
            "username": "svc",
            "password": "pw"
        }));

        // Keys absent from the document fall back to their serde defaults.
        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.pool.max_connections, 10);
        assert_eq!(cfg.effective_port(), Some(5432));
        assert_eq!(cfg.connection_url(), "postgres://svc:pw@localhost:5432/app");
    }

    #[test]
    fn url_field_overrides_components() {
        let explicit = "postgres://custom/db";
        let cfg = DatabaseConfig { url: Some(explicit.to_owned()), ..Default::default() };
        assert_eq!(cfg.connection_url(), explicit);
    }

    #[test]
    fn mongodb_url_with_options() {
        let cfg = DatabaseConfig {
            system: DbSystem::MongoDb,
            host: "mongo".into(),
            database: "app".into(),
            options: BTreeMap::from([("replicaSet".to_owned(), "rs0".to_owned())]),
            ..Default::default()
        };

        assert!(!cfg.system.is_relational());
        assert_eq!(cfg.effective_port(), Some(27017));
        assert_eq!(cfg.connection_url(), "mongodb://mongo:27017/app?replicaSet=rs0");
    }

    #[test]
    fn db_system_external_names_are_natural() {
        let cases = [
            ("postgres", DbSystem::Postgres),
            ("mysql", DbSystem::MySql),
            ("mariadb", DbSystem::MariaDb),
            ("sqlite", DbSystem::Sqlite),
            ("mongodb", DbSystem::MongoDb),
        ];
        for (name, system) in cases {
            // Round-trip: the external name decodes to the variant and encodes back to itself.
            let decoded = serde_json::from_value::<DbSystem>(json!(name)).unwrap();
            assert_eq!(decoded, system, "deserializing {name}");
            let encoded = serde_json::to_value(system).unwrap();
            assert_eq!(encoded, json!(name));
        }
    }

    #[test]
    fn sqlite_uses_path() {
        let cfg = DatabaseConfig {
            system: DbSystem::Sqlite,
            database: "/var/lib/app.db".into(),
            ..Default::default()
        };

        // SQLite has no network port; the database path follows the scheme directly.
        assert!(cfg.effective_port().is_none());
        assert_eq!(cfg.connection_url(), "sqlite:///var/lib/app.db");
    }
}