use serde::{Deserialize, Serialize};
/// Database backend selection as written in configuration.
///
/// Variant names are (de)serialized in lowercase (`auto`, `postgres`,
/// `mysql`). `Auto` is the default variant; turning a selection into a
/// concrete backend is done by [`DatabaseBackend::resolve`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DatabaseBackend {
/// Let `resolve` pick the backend from the connection URL's scheme.
#[default]
Auto,
/// Always resolves to [`ResolvedBackend::Postgres`]; the URL is not inspected.
Postgres,
/// Always resolves to [`ResolvedBackend::MySql`]; the URL is not inspected.
// NOTE(review): `rename_all = "lowercase"` should already map `MySql` to
// `mysql`, so this alias looks redundant — confirm before removing.
#[serde(alias = "mysql")]
MySql,
}
impl DatabaseBackend {
    /// Resolves this setting to a concrete backend.
    ///
    /// `Postgres` and `MySql` resolve to themselves and never inspect `url`.
    /// `Auto` picks the backend from the URL scheme, compared ASCII
    /// case-insensitively: `postgres://` and `postgresql://` select
    /// [`ResolvedBackend::Postgres`], `mysql://` selects
    /// [`ResolvedBackend::MySql`].
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::UnknownUrlScheme`] when `self` is `Auto` and the
    /// URL's scheme is missing or not one of the supported ones. The error
    /// payload contains only the scheme (or the placeholder `unknown`), never
    /// the rest of the URL, so credentials cannot end up in the message.
    pub fn resolve(self, url: &str) -> Result<ResolvedBackend, ConfigError> {
        /// True when `candidate` is non-empty and made only of letters, digits,
        /// `+`, `-` and `.` — the characters RFC 3986 allows in a scheme.
        fn is_scheme(candidate: &str) -> bool {
            !candidate.is_empty()
                && candidate
                    .bytes()
                    .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'-' | b'.'))
        }

        match self {
            Self::Postgres => Ok(ResolvedBackend::Postgres),
            Self::MySql => Ok(ResolvedBackend::MySql),
            Self::Auto => {
                // Only text in front of a real "://" counts as a scheme. A
                // string without the separator (for example a bare
                // `user:pass@host/db`) must not be echoed back in the error.
                let scheme = url
                    .split_once("://")
                    .map(|(scheme, _)| scheme)
                    .filter(|scheme| is_scheme(scheme))
                    .unwrap_or("unknown");

                if scheme.eq_ignore_ascii_case("postgres")
                    || scheme.eq_ignore_ascii_case("postgresql")
                {
                    Ok(ResolvedBackend::Postgres)
                } else if scheme.eq_ignore_ascii_case("mysql") {
                    Ok(ResolvedBackend::MySql)
                } else {
                    Err(ConfigError::UnknownUrlScheme(format!("{scheme}://...")))
                }
            }
        }
    }
}
/// A concrete database backend, with no "auto" option left to decide.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResolvedBackend {
    /// Displayed as `postgres`.
    Postgres,
    /// Displayed as `mysql`.
    MySql,
}

impl std::fmt::Display for ResolvedBackend {
    /// Writes the lowercase backend name (`postgres` or `mysql`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::Postgres => "postgres",
            Self::MySql => "mysql",
        };
        f.write_str(name)
    }
}
/// Database connection settings, loadable through serde or
/// [`DatabaseConfig::from_env`].
///
/// `Debug` is implemented by hand rather than derived so that the `url` field
/// can be passed through `redact_url_userinfo` before being printed.
#[derive(Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct DatabaseConfig {
/// Connection URL. Skipped when serializing; no serde default is declared,
/// so deserialized input has to supply it.
#[serde(skip_serializing)]
pub url: String,
/// Backend selection; falls back to `DatabaseBackend::default()` when omitted.
#[serde(default)]
pub backend: DatabaseBackend,
/// Falls back to `default_max_connections()` (32) when omitted.
#[serde(default = "default_max_connections")]
pub max_connections: u32,
/// Falls back to `u32::default()` (0) when omitted.
#[serde(default)]
pub min_connections: u32,
/// Falls back to `default_true()` (`true`) when omitted.
// NOTE(review): presumably consulted at startup to decide whether to apply
// schema migrations — the consumer is not in this file; confirm.
#[serde(default = "default_true")]
pub run_migrations: bool,
/// Timeout in seconds, going by the field name; falls back to
/// `default_acquire_timeout_secs()` (5) when omitted.
#[serde(default = "default_acquire_timeout_secs")]
pub acquire_timeout_secs: u64,
}
/// Fallback for `DatabaseConfig::max_connections` when no value is supplied.
const fn default_max_connections() -> u32 {
    const MAX_CONNECTIONS: u32 = 32;
    MAX_CONNECTIONS
}
/// Fallback for `DatabaseConfig::acquire_timeout_secs` when no value is supplied.
const fn default_acquire_timeout_secs() -> u64 {
    const ACQUIRE_TIMEOUT_SECS: u64 = 5;
    ACQUIRE_TIMEOUT_SECS
}
/// Serde default helper for boolean fields that are enabled unless stated
/// otherwise (used by `DatabaseConfig::run_migrations`).
const fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
impl Default for DatabaseConfig {
    /// Returns a configuration with an empty URL, automatic backend detection,
    /// the default pool limits (maximum from `default_max_connections`,
    /// minimum 0), migrations enabled, and the default acquire timeout.
    fn default() -> Self {
        let url = String::default();
        let backend = DatabaseBackend::Auto;
        let max_connections = default_max_connections();
        let min_connections = 0;
        let run_migrations = true;
        let acquire_timeout_secs = default_acquire_timeout_secs();

        Self {
            url,
            backend,
            max_connections,
            min_connections,
            run_migrations,
            acquire_timeout_secs,
        }
    }
}
impl std::fmt::Debug for DatabaseConfig {
    /// Formats every field of the configuration.
    ///
    /// Implemented by hand instead of derived so that `url` is passed through
    /// `redact_url_userinfo` first: credentials embedded in the connection URL
    /// are replaced with `***` and do not reach logs or panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatabaseConfig")
            .field("url", &redact_url_userinfo(&self.url))
            .field("backend", &self.backend)
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .field("run_migrations", &self.run_migrations)
            // Previously missing: without it the output silently hid one of
            // the struct's fields.
            .field("acquire_timeout_secs", &self.acquire_timeout_secs)
            .finish()
    }
}
/// Returns `url` with the userinfo part of its authority replaced by `***`.
///
/// The authority is the text between `://` and the next `/` (or the end of
/// the string). If it contains an `@`, everything before the *last* `@` is
/// treated as userinfo and masked; the `@`, host, port, path and anything
/// after them are kept as they are.
///
/// The input is returned unchanged when it has no `://` separator or when the
/// authority contains no `@`. An `@` that appears only after the first `/`
/// is never treated as a userinfo delimiter.
fn redact_url_userinfo(url: &str) -> String {
    const SCHEME_SEPARATOR: &str = "://";

    let Some(scheme_end) = url.find(SCHEME_SEPARATOR) else {
        return url.to_string();
    };
    let after_scheme = scheme_end + SCHEME_SEPARATOR.len();
    let rest = &url[after_scheme..];
    let authority_end = rest.find('/').unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    // Use the LAST '@' in the authority: a password may itself contain an
    // unencoded '@' (`user:p@ss@host`), and cutting at the first one would
    // leave the tail of the password in the output.
    match authority.rfind('@') {
        Some(at) => {
            let mut redacted = String::with_capacity(url.len());
            redacted.push_str(&url[..after_scheme]);
            redacted.push_str("***");
            // `at` indexes an ASCII '@', so slicing here is on a char boundary.
            redacted.push_str(&rest[at..]);
            redacted
        }
        None => url.to_string(),
    }
}
impl DatabaseConfig {
pub fn from_env() -> Result<Self, ConfigError> {
let url = std::env::var("DATABASE_URL")
.map_err(|_| ConfigError::MissingEnvVar("DATABASE_URL".to_string()))?;
let backend = match std::env::var("DB_BACKEND") {
Ok(v) => match v.to_lowercase().as_str() {
"postgres" | "postgresql" => DatabaseBackend::Postgres,
"mysql" => DatabaseBackend::MySql,
"auto" => DatabaseBackend::Auto,
other => {
return Err(ConfigError::InvalidBackend(other.to_string()));
}
},
Err(_) => DatabaseBackend::Auto,
};
let max_connections = std::env::var("DB_MAX_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or_else(default_max_connections);
let min_connections = std::env::var("DB_MIN_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(0);
let run_migrations =
std::env::var("DB_RUN_MIGRATIONS").map_or(true, |v| v != "false" && v != "0");
let acquire_timeout_secs = std::env::var("DB_ACQUIRE_TIMEOUT_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or_else(default_acquire_timeout_secs);
Ok(Self {
url,
backend,
max_connections,
min_connections,
run_migrations,
acquire_timeout_secs,
})
}
}
/// Errors produced while loading or resolving the database configuration.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ConfigError {
/// A required environment variable could not be read; the payload is the
/// variable's name.
#[error("Required environment variable not set: {0}")]
MissingEnvVar(String),
/// Automatic backend detection did not recognise the URL's scheme; the
/// payload is the text inserted into the message in place of `{0}`.
#[error("Failed to determine database backend from URL scheme: {0}")]
UnknownUrlScheme(String),
/// An explicitly requested backend name is not supported; the payload is
/// the rejected name.
#[error("Unknown database backend: '{0}' (expected: postgres, mysql, or auto)")]
InvalidBackend(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default configuration whose only customised field is `url`
    /// and returns its `Debug` rendering.
    fn debug_output_for(url: &str) -> String {
        let cfg = DatabaseConfig {
            url: url.to_string(),
            ..DatabaseConfig::default()
        };
        format!("{cfg:?}")
    }

    // --- DatabaseBackend::resolve -------------------------------------------

    #[test]
    fn auto_detect_postgres_url() {
        let resolved = DatabaseBackend::Auto
            .resolve("postgres://localhost/mydb")
            .expect("resolve");
        assert_eq!(resolved, ResolvedBackend::Postgres);
    }

    #[test]
    fn auto_detect_postgresql_url() {
        let resolved = DatabaseBackend::Auto
            .resolve("postgresql://localhost/mydb")
            .expect("resolve");
        assert_eq!(resolved, ResolvedBackend::Postgres);
    }

    #[test]
    fn auto_detect_mysql_url() {
        let resolved = DatabaseBackend::Auto
            .resolve("mysql://localhost/mydb")
            .expect("resolve");
        assert_eq!(resolved, ResolvedBackend::MySql);
    }

    #[test]
    fn auto_detect_unknown_url_is_error() {
        let err = DatabaseBackend::Auto
            .resolve("sqlite://local.db")
            .expect_err("sqlite is not a supported backend");
        let is_unknown_scheme = matches!(err, ConfigError::UnknownUrlScheme(_));
        assert!(
            is_unknown_scheme,
            "expected ConfigError::UnknownUrlScheme, got {err:?}"
        );
    }

    #[test]
    fn explicit_backend_ignores_url() {
        // An explicit choice wins even when the URL points at another backend.
        let resolved = DatabaseBackend::Postgres
            .resolve("mysql://localhost/mydb")
            .expect("resolve");
        assert_eq!(resolved, ResolvedBackend::Postgres);
    }

    #[test]
    fn default_backend_is_auto() {
        let fallback = DatabaseBackend::default();
        assert_eq!(fallback, DatabaseBackend::Auto);
    }

    // --- ResolvedBackend ----------------------------------------------------

    #[test]
    fn resolved_backend_display() {
        let postgres = ResolvedBackend::Postgres.to_string();
        let mysql = ResolvedBackend::MySql.to_string();
        assert_eq!(postgres, "postgres");
        assert_eq!(mysql, "mysql");
    }

    // --- Debug for DatabaseConfig -------------------------------------------

    #[test]
    fn debug_redacts_password() {
        let dbg = debug_output_for("postgres://alice:supersecret@db.internal:5432/app");
        assert!(
            !dbg.contains("supersecret"),
            "password leaked in Debug: {dbg}"
        );
        assert!(!dbg.contains("alice"), "username leaked in Debug: {dbg}");
        assert!(dbg.contains("db.internal:5432/app"));
        assert!(dbg.contains("***"));
    }

    #[test]
    fn debug_passes_through_url_without_userinfo() {
        let dbg = debug_output_for("postgres://db.internal:5432/app");
        assert!(dbg.contains("db.internal:5432/app"));
        assert!(!dbg.contains("***"));
    }

    #[test]
    fn debug_handles_empty_url() {
        let dbg = format!("{:?}", DatabaseConfig::default());
        assert!(dbg.contains("DatabaseConfig"));
    }

    // --- redact_url_userinfo ------------------------------------------------

    #[test]
    fn redact_url_userinfo_preserves_non_url_strings() {
        for input in ["", "not-a-url"] {
            assert_eq!(redact_url_userinfo(input), input);
        }
    }

    #[test]
    fn redact_url_userinfo_does_not_touch_at_in_path() {
        let url = "postgres://host:5432/db/path@with-at";
        let redacted = redact_url_userinfo(url);
        assert_eq!(redacted, url);
    }
}