use std::path::{Path, PathBuf};
use std::sync::Arc;
use serde::Deserialize;
use super::Storage;
use super::error::{Result, StorageError};
use super::paths::QueuePaths;
use super::sqlite::SqliteStorage;
/// Which storage backend the job queue uses, as read from `queue_database.toml`.
///
/// The TOML key `adapter` selects the variant (`adapter = "sqlite"` or
/// `adapter = "postgres"`); the remaining keys of the file are deserialized into
/// that variant's settings struct.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "adapter", rename_all = "lowercase")]
#[non_exhaustive]
pub enum DatabaseConfig {
/// SQLite database file; see [`SqliteConfig`].
Sqlite(SqliteConfig),
/// PostgreSQL server; see [`PostgresConfig`].
Postgres(PostgresConfig),
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self::Sqlite(SqliteConfig::default())
}
}
/// Settings for the SQLite adapter (`adapter = "sqlite"`).
#[derive(Debug, Clone, Default, Deserialize)]
pub struct SqliteConfig {
/// Database file to use. When omitted, `open_sqlite` delegates to
/// `SqliteStorage::open_default`, whereas `create_sqlite` and `drop_sqlite`
/// use `queue.sqlite` inside the data directory.
/// NOTE(review): confirm `open_default` resolves to that same file.
#[serde(default)]
pub path: Option<PathBuf>,
}
/// Connection settings for the PostgreSQL adapter (`adapter = "postgres"`).
///
/// The password comes from `password` (a literal, which wins) or from the
/// environment variable named by `password_env`; see `resolve_password`.
///
/// `Debug` is implemented by hand rather than derived so that a literal
/// `password` is never written to logs or panic messages through `{:?}`
/// (including via `DatabaseConfig`'s derived `Debug`).
#[derive(Clone, Deserialize)]
pub struct PostgresConfig {
    /// Server host name or address.
    pub host: String,
    /// Server port; `default_pg_port` supplies the value when omitted.
    #[serde(default = "default_pg_port")]
    pub port: u16,
    /// Database to connect to, and the one created/dropped by the maintenance commands.
    pub database: String,
    /// Role to authenticate as.
    pub username: String,
    /// Literal password; takes precedence over `password_env`.
    #[serde(default)]
    pub password: Option<String>,
    /// Name of an environment variable that holds the password.
    #[serde(default)]
    pub password_env: Option<String>,
    /// Connection-pool cap passed to the Postgres storage; `default_pg_max_conn`
    /// supplies the value when omitted.
    #[serde(default = "default_pg_max_conn")]
    pub max_connections: u32,
}

impl std::fmt::Debug for PostgresConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("username", &self.username)
            // Show only whether a literal password is configured, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("password_env", &self.password_env)
            .field("max_connections", &self.max_connections)
            .finish()
    }
}
/// Serde default for `PostgresConfig::port`.
const fn default_pg_port() -> u16 {
    5_432
}
/// Serde default for `PostgresConfig::max_connections`.
const fn default_pg_max_conn() -> u32 {
    const DEFAULT_MAX_CONNECTIONS: u32 = 30;
    DEFAULT_MAX_CONNECTIONS
}
/// Number of directories `find_repo_local_config` inspects, counting the
/// current working directory itself as the first one.
const MAX_REPO_WALK_DEPTH: usize = 4;

/// Looks for a repo-local `queue_database.toml` in the current working
/// directory and then its parents, inspecting at most `MAX_REPO_WALK_DEPTH`
/// directories (fewer if the filesystem root is reached first).
///
/// Returns the first regular file found, or `None` if there is none or the
/// working directory cannot be determined.
fn find_repo_local_config() -> Option<PathBuf> {
    let cwd = std::env::current_dir().ok()?;
    cwd.ancestors()
        .take(MAX_REPO_WALK_DEPTH)
        .map(|dir| dir.join("queue_database.toml"))
        .find(|candidate| candidate.is_file())
}
impl DatabaseConfig {
    /// Resolves and loads the active queue database configuration.
    ///
    /// Lookup order:
    /// 1. a repo-local `queue_database.toml` in the working directory or one of
    ///    its nearest ancestors (see `find_repo_local_config`);
    /// 2. `queue_database.toml` inside `paths.config_dir()`;
    /// 3. otherwise the SQLite [`Default`].
    ///
    /// # Errors
    ///
    /// Propagates a failure of `paths.config_dir()`, and any read/parse error
    /// reported by [`DatabaseConfig::load_from`] for a file that exists.
    pub fn load(paths: &dyn QueuePaths) -> Result<Self> {
        if let Some(repo_path) = find_repo_local_config() {
            tracing::info!(
                path = %repo_path.display(),
                "queue_database.toml: loading repo-local config",
            );
            return Self::load_from(&repo_path);
        }
        let xdg_path = paths.config_dir()?.join("queue_database.toml");
        if xdg_path.is_file() {
            tracing::info!(
                path = %xdg_path.display(),
                "queue_database.toml: loading XDG config",
            );
        } else if matches!(xdg_path.try_exists(), Ok(false)) {
            // Announce the SQLite fallback only when `load_from` will actually
            // take it (the path is absent). A path that exists but is not a
            // regular file, e.g. a directory, makes `load_from` fail instead.
            tracing::info!(
                "queue_database.toml: no config file found in CWD, its parent \
                 directories, or XDG; using SQLite default",
            );
        }
        Self::load_from(&xdg_path)
    }

    /// Parses the TOML database config stored at `path`.
    ///
    /// A missing file is not an error: it yields [`DatabaseConfig::default`].
    ///
    /// # Errors
    ///
    /// `StorageError::InvalidInput` if the file exists but cannot be read, or
    /// if its contents do not deserialize into a `DatabaseConfig`.
    pub fn load_from(path: &Path) -> Result<Self> {
        match std::fs::read_to_string(path) {
            Ok(s) => toml::from_str(&s).map_err(|e| {
                StorageError::InvalidInput(format!(
                    "parsing database config {}: {e}",
                    path.display()
                ))
            }),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Self::default()),
            Err(e) => Err(StorageError::InvalidInput(format!(
                "reading database config {}: {e}",
                path.display()
            ))),
        }
    }

    /// Opens the storage backend selected by this config.
    ///
    /// # Errors
    ///
    /// Propagates the adapter's open failure. For Postgres this includes an
    /// unresolvable password and builds without the `postgres` feature.
    pub async fn open_storage(&self, paths: &dyn QueuePaths) -> Result<Storage> {
        match self {
            Self::Sqlite(cfg) => open_sqlite(cfg, paths).await,
            Self::Postgres(cfg) => open_postgres(cfg).await,
        }
    }

    /// Short lowercase adapter name, matching the `adapter` key in the TOML file.
    #[must_use]
    pub const fn adapter_name(&self) -> &'static str {
        match self {
            Self::Sqlite(_) => "sqlite",
            Self::Postgres(_) => "postgres",
        }
    }

    /// Creates the backing database if it does not exist yet.
    ///
    /// SQLite creates the file and its parent directories; Postgres issues
    /// `CREATE DATABASE` unless a database with that name is already present.
    ///
    /// # Errors
    ///
    /// Propagates filesystem, connection or SQL failures from the adapter.
    pub async fn create_database(&self, paths: &dyn QueuePaths) -> Result<()> {
        match self {
            Self::Sqlite(cfg) => create_sqlite(cfg, paths).await,
            Self::Postgres(cfg) => create_postgres(cfg).await,
        }
    }

    /// Irreversibly removes the backing database.
    ///
    /// SQLite deletes the database file and its sidecar files; Postgres issues
    /// `DROP DATABASE IF EXISTS ... WITH (FORCE)`.
    ///
    /// # Errors
    ///
    /// Propagates filesystem, connection or SQL failures from the adapter.
    pub async fn drop_database(&self, paths: &dyn QueuePaths) -> Result<()> {
        match self {
            Self::Sqlite(cfg) => drop_sqlite(cfg, paths).await,
            Self::Postgres(cfg) => drop_postgres(cfg).await,
        }
    }

    /// Opens the configured storage and immediately releases it.
    ///
    /// NOTE(review): nothing here invokes a migration explicitly. This relies
    /// on the adapters' open routines applying pending migrations; confirm in
    /// `SqliteStorage`/`PostgresStorage`.
    pub async fn migrate(&self, paths: &dyn QueuePaths) -> Result<()> {
        let _storage = self.open_storage(paths).await?;
        Ok(())
    }

    /// Health check: opens the storage and calls `describe()` on its jobs
    /// store, discarding the result.
    ///
    /// # Errors
    ///
    /// Fails if either opening the storage or the `describe()` call fails.
    pub async fn ping(&self, paths: &dyn QueuePaths) -> Result<()> {
        let storage = self.open_storage(paths).await?;
        storage.jobs.describe().await?;
        Ok(())
    }
}
/// SQLite file used by `create_sqlite`/`drop_sqlite` when `SqliteConfig::path`
/// is unset: `queue.sqlite` inside the data directory reported by `paths`.
fn default_sqlite_path(paths: &dyn QueuePaths) -> Result<PathBuf> {
    let data_dir = paths.data_dir()?;
    Ok(data_dir.join("queue.sqlite"))
}
/// Creates the SQLite database file, creating any missing parent directories first.
///
/// Opening the file through `SqliteStorage::open_file` is what materialises
/// it. The opened handle is dropped straight away. An already existing file is
/// simply opened, hence the "(or already existed)" wording in the log line.
async fn create_sqlite(cfg: &SqliteConfig, paths: &dyn QueuePaths) -> Result<()> {
    let path = if let Some(explicit) = &cfg.path {
        explicit.clone()
    } else {
        default_sqlite_path(paths)?
    };
    if let Some(dir) = path.parent() {
        std::fs::create_dir_all(dir)?;
    }
    let _ = SqliteStorage::open_file(&path).await?;
    tracing::info!(path = %path.display(), "sqlite database created (or already existed)");
    Ok(())
}
#[allow(
clippy::unused_async,
reason = "matches the postgres twin's signature so `drop_database` can `.await` both arms uniformly"
)]
async fn drop_sqlite(cfg: &SqliteConfig, paths: &dyn QueuePaths) -> Result<()> {
let path = match cfg.path.clone() {
Some(p) => p,
None => default_sqlite_path(paths)?,
};
for suffix in ["", "-wal", "-shm"] {
let mut p = path.as_os_str().to_owned();
p.push(suffix);
match std::fs::remove_file(std::path::Path::new(&p)) {
Ok(()) => tracing::info!(path = ?p, "sqlite: removed"),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => return Err(e.into()),
}
}
Ok(())
}
/// Creates the configured PostgreSQL database unless it already exists.
///
/// Connects to the `postgres` maintenance database (see
/// `PostgresConfig::maintenance_connect_options`) because the target database
/// may not exist yet. `CREATE DATABASE` cannot take the name as a bind
/// parameter, so it is checked by `validate_pg_identifier` before being quoted
/// into the statement.
///
/// The single-connection pool is closed explicitly on every path, including
/// errors and the "already exists" early return. Dropping a sqlx pool only
/// closes the sockets client-side, which leaves the server to notice the dead
/// connection later.
///
/// # Errors
///
/// Invalid database name, unresolvable password, connection failure, or a
/// failing lookup / `CREATE DATABASE`.
#[cfg(feature = "postgres")]
async fn create_postgres(cfg: &PostgresConfig) -> Result<()> {
    // Existence check + CREATE on an already-open pool; kept separate so the
    // caller can close the pool whatever the outcome.
    async fn create_if_missing(pool: &sqlx::postgres::PgPool, database: &str) -> Result<()> {
        let exists: Option<sqlx::postgres::PgRow> =
            sqlx::query("SELECT 1 AS x FROM pg_database WHERE datname = $1")
                .bind(database)
                .fetch_optional(pool)
                .await
                .map_err(|e| StorageError::Backend(format!("pg_database lookup: {e}")))?;
        if exists.is_some() {
            tracing::info!(database = %database, "postgres database already exists; skipping CREATE");
            return Ok(());
        }
        let sql = format!("CREATE DATABASE \"{database}\"");
        sqlx::query(&sql)
            .execute(pool)
            .await
            .map_err(|e| StorageError::Backend(format!("CREATE DATABASE: {e}")))?;
        tracing::info!(database = %database, "postgres database created");
        Ok(())
    }

    validate_pg_identifier(&cfg.database)?;
    let opts = cfg.maintenance_connect_options()?;
    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(1)
        .acquire_timeout(std::time::Duration::from_secs(10))
        .connect_with(opts)
        .await
        .map_err(|e| StorageError::Backend(format!("postgres maintenance connect: {e}")))?;
    let outcome = create_if_missing(&pool, &cfg.database).await;
    pool.close().await;
    outcome
}
/// Drops the configured PostgreSQL database if it exists.
///
/// Runs `DROP DATABASE IF EXISTS "<name>" WITH (FORCE)` over a connection to
/// the `postgres` maintenance database (see
/// `PostgresConfig::maintenance_connect_options`). The name cannot be a bind
/// parameter, so it is checked by `validate_pg_identifier` before being quoted
/// into the statement.
///
/// The pool is closed explicitly before the result of the DROP is propagated.
/// Dropping a sqlx pool only closes the sockets client-side, which leaves the
/// server to notice the dead connection later.
///
/// # Errors
///
/// Invalid database name, unresolvable password, connection failure, or a
/// failing `DROP DATABASE`.
#[cfg(feature = "postgres")]
async fn drop_postgres(cfg: &PostgresConfig) -> Result<()> {
    validate_pg_identifier(&cfg.database)?;
    let opts = cfg.maintenance_connect_options()?;
    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(1)
        .acquire_timeout(std::time::Duration::from_secs(10))
        .connect_with(opts)
        .await
        .map_err(|e| StorageError::Backend(format!("postgres maintenance connect: {e}")))?;
    let sql = format!("DROP DATABASE IF EXISTS \"{}\" WITH (FORCE)", cfg.database);
    let outcome = sqlx::query(&sql)
        .execute(&pool)
        .await
        .map_err(|e| StorageError::Backend(format!("DROP DATABASE: {e}")));
    pool.close().await;
    outcome?;
    tracing::info!(database = %cfg.database, "postgres database dropped");
    Ok(())
}
/// Stand-in for builds without the `postgres` feature: always fails with
/// `StorageError::InvalidInput` explaining how to enable it.
#[cfg(not(feature = "postgres"))]
#[allow(
    clippy::unused_async,
    reason = "async signature matches the `feature = postgres` variant so `create_database` can `.await` both arms uniformly"
)]
async fn create_postgres(_cfg: &PostgresConfig) -> Result<()> {
    Err(StorageError::InvalidInput(
        "create on postgres requires --features postgres".into(),
    ))
}
/// Stand-in for builds without the `postgres` feature: always fails with
/// `StorageError::InvalidInput` explaining how to enable it.
#[cfg(not(feature = "postgres"))]
#[allow(
    clippy::unused_async,
    reason = "async signature matches the `feature = postgres` variant so `drop_database` can `.await` both arms uniformly"
)]
async fn drop_postgres(_cfg: &PostgresConfig) -> Result<()> {
    Err(StorageError::InvalidInput(
        "drop on postgres requires --features postgres".into(),
    ))
}
/// Checks that `name` is safe to splice, double-quoted, into the
/// `CREATE DATABASE` / `DROP DATABASE` statements, which do not take the name
/// as a bind parameter.
///
/// Accepted: 1..=63 characters, starting with an ASCII letter or `_` and
/// continuing with ASCII letters, digits or `_`. Because `"` is never accepted,
/// the quoted identifier cannot be escaped from.
///
/// # Errors
///
/// `StorageError::InvalidInput` describing the first rule that is violated.
#[cfg(feature = "postgres")]
fn validate_pg_identifier(name: &str) -> Result<()> {
    let length_error =
        || StorageError::InvalidInput(format!("database name `{name}` must be 1..=63 chars"));
    if name.len() > 63 {
        return Err(length_error());
    }
    let mut rest = name.chars();
    let Some(head) = rest.next() else {
        return Err(length_error());
    };
    if !(head == '_' || head.is_ascii_alphabetic()) {
        return Err(StorageError::InvalidInput(format!(
            "database name `{name}` must start with a letter or underscore"
        )));
    }
    match rest.find(|c: &char| !(*c == '_' || c.is_ascii_alphanumeric())) {
        Some(c) => Err(StorageError::InvalidInput(format!(
            "database name `{name}` contains invalid character `{c}` \
             (allowed: letters, digits, underscore)"
        ))),
        None => Ok(()),
    }
}
/// Opens the SQLite backend: the configured file when `path` is set, otherwise
/// whatever `SqliteStorage::open_default` selects from `paths`.
async fn open_sqlite(cfg: &SqliteConfig, paths: &dyn QueuePaths) -> Result<Storage> {
    let sqlite = match &cfg.path {
        Some(file) => SqliteStorage::open_file(file).await?,
        None => SqliteStorage::open_default(paths).await?,
    };
    Ok(Storage::from_one(Arc::new(sqlite)))
}
/// Opens the PostgreSQL backend with a connection pool capped at
/// `cfg.max_connections`.
///
/// # Errors
///
/// Password resolution failure (from `pg_connect_options`) or a failure
/// reported by `PostgresStorage::open_with_options`.
#[cfg(feature = "postgres")]
async fn open_postgres(cfg: &PostgresConfig) -> Result<Storage> {
    use super::postgres::PostgresStorage;

    let connect = cfg.pg_connect_options()?;
    let backend = PostgresStorage::open_with_options(connect, cfg.max_connections).await?;
    Ok(Storage::from_one(Arc::new(backend)))
}
/// Stand-in for builds without the `postgres` feature: a config that selects
/// the Postgres adapter is rejected with instructions on how to fix the build
/// or the config.
#[cfg(not(feature = "postgres"))]
#[allow(
    clippy::unused_async,
    reason = "async signature matches the `feature = postgres` variant so callers can `.await` it uniformly"
)]
async fn open_postgres(_cfg: &PostgresConfig) -> Result<Storage> {
    const MISSING_FEATURE: &str = "queue_database.toml requests adapter = \"postgres\" but this build was compiled \
        without the `postgres` feature. Rebuild with `--features postgres` (or set \
        `adapter = \"sqlite\"`).";
    Err(StorageError::InvalidInput(MISSING_FEATURE.to_owned()))
}
impl PostgresConfig {
    /// Returns the password to authenticate with.
    ///
    /// A literal `password` takes precedence; otherwise the environment
    /// variable named by `password_env` is read.
    ///
    /// # Errors
    ///
    /// `StorageError::InvalidInput` when neither field is set, when the named
    /// variable is not set, or when it is set but its value is not valid UTF-8.
    #[cfg_attr(
        not(feature = "postgres"),
        allow(
            dead_code,
            reason = "only called by `open_postgres` under the postgres feature; \
                      kept in scope so the config still parses + the unit tests \
                      can validate the resolution logic without the feature on"
        )
    )]
    fn resolve_password(&self) -> Result<String> {
        if let Some(p) = &self.password {
            return Ok(p.clone());
        }
        let Some(env_name) = &self.password_env else {
            return Err(StorageError::InvalidInput(
                "queue_database.toml: postgres adapter requires either `password` (literal) \
                 or `password_env` (name of an env var holding the password)"
                    .into(),
            ));
        };
        std::env::var(env_name).map_err(|err| match err {
            // A variable that is set but not UTF-8 must not be reported as
            // "not set"; that would send the user looking in the wrong place.
            std::env::VarError::NotUnicode(_) => StorageError::InvalidInput(format!(
                "queue_database.toml: password_env `{env_name}` is set but its value \
                 is not valid UTF-8"
            )),
            std::env::VarError::NotPresent => {
                // Env var names are conventionally upper-case, so a non-empty,
                // purely lower-case value suggests the password itself was put
                // into `password_env`.
                let likely_field_mixup =
                    !env_name.is_empty() && env_name.chars().all(|c| c.is_ascii_lowercase());
                let hint = if likely_field_mixup {
                    format!(
                        " — `password_env` expects the *name* of an env var, not the \
                         password itself. If `{env_name}` is your literal password, \
                         use `password = \"{env_name}\"` instead."
                    )
                } else {
                    String::new()
                };
                StorageError::InvalidInput(format!(
                    "queue_database.toml: password_env `{env_name}` is not set{hint}"
                ))
            }
        })
    }

    /// Connection options for `database` on `host:port`, authenticating as
    /// `username` with the password from `resolve_password`.
    ///
    /// # Errors
    ///
    /// Propagates `resolve_password` failures.
    #[cfg(feature = "postgres")]
    pub fn pg_connect_options(&self) -> Result<sqlx::postgres::PgConnectOptions> {
        let password = self.resolve_password()?;
        Ok(sqlx::postgres::PgConnectOptions::new()
            .host(&self.host)
            .port(self.port)
            .database(&self.database)
            .username(&self.username)
            .password(&password))
    }

    /// Same as [`PostgresConfig::pg_connect_options`] but targeting the
    /// `postgres` database. The create/drop maintenance commands use it because
    /// they operate on the configured database rather than inside it.
    ///
    /// # Errors
    ///
    /// Propagates `resolve_password` failures.
    #[cfg(feature = "postgres")]
    pub fn maintenance_connect_options(&self) -> Result<sqlx::postgres::PgConnectOptions> {
        Ok(self.pg_connect_options()?.database("postgres"))
    }
}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::panic,
        reason = "unit tests crash loudly on setup failure"
    )]
    use super::*;

    /// `PostgresConfig` with throwaway connection fields, so each test only
    /// spells out the password settings it exercises.
    fn pg_config(password: Option<&str>, password_env: Option<&str>) -> PostgresConfig {
        PostgresConfig {
            host: "x".into(),
            port: 5432,
            database: "x".into(),
            username: "x".into(),
            password: password.map(Into::into),
            password_env: password_env.map(Into::into),
            max_connections: 30,
        }
    }

    #[test]
    fn default_is_sqlite_with_no_explicit_path() {
        let cfg = DatabaseConfig::default();
        match cfg {
            DatabaseConfig::Sqlite(s) => assert!(s.path.is_none()),
            DatabaseConfig::Postgres(_) => panic!("expected sqlite default"),
        }
    }

    #[test]
    fn parses_minimal_sqlite_toml() {
        let parsed: DatabaseConfig = toml::from_str(r#"adapter = "sqlite""#).unwrap();
        match parsed {
            DatabaseConfig::Sqlite(s) => assert!(s.path.is_none()),
            DatabaseConfig::Postgres(_) => panic!("expected sqlite"),
        }
    }

    #[test]
    fn parses_sqlite_with_custom_path() {
        let parsed: DatabaseConfig = toml::from_str(
            r#"
adapter = "sqlite"
path = "/var/lib/tech-admin/queue.sqlite"
"#,
        )
        .unwrap();
        match parsed {
            DatabaseConfig::Sqlite(s) => assert_eq!(
                s.path,
                Some(PathBuf::from("/var/lib/tech-admin/queue.sqlite")),
            ),
            DatabaseConfig::Postgres(_) => panic!("expected sqlite"),
        }
    }

    #[test]
    fn parses_postgres_with_defaults() {
        let parsed: DatabaseConfig = toml::from_str(
            r#"
adapter = "postgres"
host = "db.internal"
database = "tech_admin"
username = "tech_admin"
password_env = "TECH_ADMIN_DB_PASSWORD"
"#,
        )
        .unwrap();
        match parsed {
            DatabaseConfig::Postgres(p) => {
                assert_eq!(p.host, "db.internal");
                assert_eq!(p.port, 5432, "default port applies");
                assert_eq!(p.max_connections, 30, "default cap applies");
                assert_eq!(p.password_env.as_deref(), Some("TECH_ADMIN_DB_PASSWORD"));
            }
            DatabaseConfig::Sqlite(_) => panic!("expected postgres"),
        }
    }

    #[test]
    fn parses_postgres_with_explicit_overrides() {
        let parsed: DatabaseConfig = toml::from_str(
            r#"
adapter = "postgres"
host = "db.internal"
port = 6543
database = "tech_admin"
username = "tech_admin"
password = "hunter2"
max_connections = 5
"#,
        )
        .unwrap();
        match parsed {
            DatabaseConfig::Postgres(p) => {
                assert_eq!(p.port, 6543, "explicit port wins over the default");
                assert_eq!(p.max_connections, 5, "explicit cap wins over the default");
                assert_eq!(p.password.as_deref(), Some("hunter2"));
                assert!(p.password_env.is_none());
            }
            DatabaseConfig::Sqlite(_) => panic!("expected postgres"),
        }
    }

    #[test]
    fn rejects_unknown_adapter() {
        assert!(toml::from_str::<DatabaseConfig>(r#"adapter = "mysql""#).is_err());
    }

    #[test]
    fn rejects_missing_adapter_or_required_postgres_fields() {
        assert!(toml::from_str::<DatabaseConfig>("").is_err(), "adapter key is required");
        assert!(
            toml::from_str::<DatabaseConfig>(r#"adapter = "postgres""#).is_err(),
            "host/database/username are required for postgres"
        );
    }

    #[test]
    fn resolve_password_errors_when_neither_is_set() {
        assert!(pg_config(None, None).resolve_password().is_err());
    }

    #[test]
    fn resolve_password_prefers_literal_over_env() {
        let cfg = pg_config(Some("hunter2"), Some("DEFINITELY_NOT_SET_FOR_TEST_999"));
        assert_eq!(cfg.resolve_password().unwrap(), "hunter2");
    }

    #[test]
    fn resolve_password_errors_when_env_var_missing() {
        let cfg = pg_config(None, Some("TECH_ADMIN_DEFINITELY_NOT_SET_8a3f"));
        assert!(cfg.resolve_password().is_err());
    }

    #[test]
    fn load_from_missing_path_returns_default() {
        let cfg =
            DatabaseConfig::load_from(Path::new("/nonexistent/tech-admin/queue_database.toml"))
                .unwrap();
        match cfg {
            DatabaseConfig::Sqlite(s) => assert!(s.path.is_none()),
            DatabaseConfig::Postgres(_) => panic!("expected sqlite default"),
        }
    }
}