pub mod api_keys;
pub mod encrypted_data_refs;
pub mod permissions;
use deadpool_postgres::{Config, ManagerConfig, Pool, RecyclingMethod, Runtime};
use once_cell::sync::OnceCell;
use sqlx::{
postgres::{PgConnectOptions, PgPoolOptions, PgSslMode},
PgPool,
};
use std::{str::FromStr, time::Duration};
use thiserror::Error;
use tracing::{info, warn};
pub use api_keys::{redact_bearer, ApiKeyRecord, ApiKeyRepository, ApiKeyUpdate};
pub use encrypted_data_refs::{DataType, EncryptedDataRefRecord, EncryptedDataRefRepository};
pub use permissions::ApiPermission;
/// Connection settings consumed by `DatabaseManager::new`.
#[derive(Debug, Clone)]
pub struct DatabaseConfig {
/// PostgreSQL connection URL. It is parsed once for the deadpool pool and once for the SQLx pool.
pub url: String,
/// Maximum pool size; applied to both the deadpool pool and the SQLx pool.
pub max_connections: u32,
/// Minimum number of connections; applied to the SQLx pool only.
pub min_connections: u32,
/// Passed to the SQLx pool as its acquire timeout.
pub connect_timeout: Duration,
/// SSL mode used for the SQLx pool when the URL carries no `sslmode` parameter.
/// A URL-supplied `sslmode` that ranks below this value is rejected.
pub ssl_mode: PgSslMode,
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
url: "postgres://newton:newton@localhost:5432/newton_gateway".to_string(),
max_connections: 20,
min_connections: 5,
connect_timeout: Duration::from_secs(30),
ssl_mode: PgSslMode::Require,
}
}
}
/// Errors produced by the database layer in this module.
///
/// Every variant carries its detail as a pre-formatted `String` rather than the
/// underlying driver error.
#[derive(Error, Debug)]
pub enum DatabaseError {
/// The deadpool pool could not be constructed.
#[error("Failed to create database pool: {0}")]
PoolCreation(String),
/// The connection URL or SSL mode was invalid, or a connection could not be obtained.
#[error("Failed to connect to database: {0}")]
Connection(String),
/// A query failed (used by the health check).
#[error("Database query error: {0}")]
Query(String),
/// A transaction-level failure.
#[error("Database transaction error: {0}")]
Transaction(String),
/// The global singleton is already bound to a different URL than the one requested.
// NOTE(review): both URLs are rendered verbatim by this message, including any
// credentials they embed.
#[error(
"Database singleton already initialized with a different URL. existing={existing}, requested={requested}. \
Both callers in this process must agree on the connection URL."
)]
SingletonUrlMismatch {
/// URL the singleton was initialized with.
existing: String,
/// URL the rejected caller asked for.
requested: String,
},
}
/// Owns the two connection pools (deadpool and SQLx) built from a single connection URL.
#[derive(Clone)]
pub struct DatabaseManager {
/// The connection URL this manager was created from, kept for singleton URL comparison.
url: String,
/// deadpool-postgres pool, exposed via `get_connection` / `deadpool`.
deadpool: Pool,
/// SQLx pool, exposed via `pool`; also used for stats and health checks.
sqlx_pool: PgPool,
}
impl DatabaseManager {
    /// Creates both connection pools (deadpool and SQLx) from `config`.
    ///
    /// The SSL mode applied to the SQLx pool is the URL's `sslmode` parameter when
    /// present, otherwise `config.ssl_mode`. A URL mode weaker than
    /// `config.ssl_mode` is rejected.
    ///
    /// # Errors
    ///
    /// * [`DatabaseError::Connection`] if the URL cannot be parsed, the SSL mode is
    ///   invalid or downgraded, or the SQLx pool fails to connect.
    /// * [`DatabaseError::PoolCreation`] if the deadpool pool cannot be built.
    pub async fn new(config: DatabaseConfig) -> Result<Self, DatabaseError> {
        info!(
            "Initializing database manager (max_connections: {}, min_connections: {})",
            config.max_connections, config.min_connections
        );

        // NOTE(review): this parser sees the URL's query parameters too; confirm it
        // accepts every `sslmode` value that `parse_sslmode` allows (e.g. verify-ca,
        // verify-full).
        let pg_config = config
            .url
            .parse::<deadpool_postgres::tokio_postgres::Config>()
            .map_err(|e| DatabaseError::Connection(format!("Invalid connection URL: {}", e)))?;

        let manager_config = ManagerConfig {
            recycling_method: RecyclingMethod::Fast,
        };

        // Only the first TCP host / port of the parsed URL is forwarded to deadpool;
        // non-TCP hosts leave `host` unset.
        let mut deadpool_config = Config::new();
        deadpool_config.host = pg_config.get_hosts().first().and_then(|h| match h {
            deadpool_postgres::tokio_postgres::config::Host::Tcp(host) => Some(host.clone()),
            _ => None,
        });
        deadpool_config.port = pg_config.get_ports().first().copied();
        deadpool_config.user = pg_config.get_user().map(String::from);
        deadpool_config.password = pg_config
            .get_password()
            .map(|p: &[u8]| String::from_utf8_lossy(p).to_string());
        deadpool_config.dbname = pg_config.get_dbname().map(String::from);
        deadpool_config.manager = Some(manager_config);
        deadpool_config.pool = Some(deadpool_postgres::PoolConfig::new(config.max_connections as usize));

        // NOTE(review): the deadpool pool is created with `NoTls`, so the SSL mode
        // validated below is enforced on the SQLx pool only.
        let deadpool = deadpool_config
            .create_pool(Some(Runtime::Tokio1), deadpool_postgres::tokio_postgres::NoTls)
            .map_err(|e| DatabaseError::PoolCreation(format!("Failed to create deadpool: {}", e)))?;

        // A URL-supplied sslmode may strengthen, but never weaken, the configured mode.
        let url_ssl_mode = parse_url_sslmode(&config.url)?;
        if let Some(ssl_mode) = url_ssl_mode {
            reject_downgraded_sslmode(ssl_mode, config.ssl_mode)?;
        }
        let effective_ssl_mode = url_ssl_mode.unwrap_or(config.ssl_mode);

        let pg_connect_options = PgConnectOptions::from_str(&config.url)
            .map_err(|e| DatabaseError::Connection(format!("Invalid SQLx connection URL: {}", e)))?
            .ssl_mode(effective_ssl_mode);
        // Log the mode actually applied, which may be stricter than `config.ssl_mode`.
        info!("Database TLS: sslmode={:?}", effective_ssl_mode);

        let sqlx_pool = PgPoolOptions::new()
            .max_connections(config.max_connections)
            .min_connections(config.min_connections)
            .acquire_timeout(config.connect_timeout)
            .connect_with(pg_connect_options)
            .await
            .map_err(|e| DatabaseError::Connection(format!("Failed to connect with SQLx: {}", e)))?;

        info!("Database manager initialized successfully");
        Ok(Self {
            url: config.url,
            deadpool,
            sqlx_pool,
        })
    }

    /// Returns the connection URL this manager was created with.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Checks out a client from the deadpool pool.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::Connection`] if the pool cannot hand out a client.
    pub async fn get_connection(&self) -> Result<deadpool_postgres::Client, DatabaseError> {
        self.deadpool
            .get()
            .await
            .map_err(|e| DatabaseError::Connection(format!("Failed to get connection: {}", e)))
    }

    /// Returns the SQLx pool.
    pub fn pool(&self) -> &PgPool {
        &self.sqlx_pool
    }

    /// Returns the deadpool pool.
    pub fn deadpool(&self) -> &Pool {
        &self.deadpool
    }

    /// Returns a snapshot of the SQLx pool's size, idle count and configured maximum.
    ///
    /// The deadpool pool is not included in these numbers.
    pub fn pool_stats(&self) -> PoolStats {
        PoolStats {
            size: self.sqlx_pool.size(),
            idle: self.sqlx_pool.num_idle(),
            max_size: self.sqlx_pool.options().get_max_connections(),
        }
    }

    /// Runs `SELECT 1` against the SQLx pool.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::Query`] if the statement fails.
    pub async fn health_check(&self) -> Result<(), DatabaseError> {
        sqlx::query("SELECT 1")
            .execute(&self.sqlx_pool)
            .await
            .map_err(|e| DatabaseError::Query(format!("Health check failed: {}", e)))?;
        Ok(())
    }
}
/// Parses a textual SSL mode (case-insensitively) into a [`PgSslMode`].
///
/// `"url"` and `"prefer"` both map to `Prefer`; `"require"`, `"verify-ca"` and
/// `"verify-full"` map to their namesakes.
///
/// # Errors
///
/// Returns [`DatabaseError::Connection`] for `"disable"` and `"allow"` (recognised
/// but not permitted) and for any unrecognised value.
pub fn parse_sslmode(modestring: &str) -> Result<PgSslMode, DatabaseError> {
    let normalized = modestring.to_lowercase();
    let mode = match normalized.as_str() {
        "url" | "prefer" => PgSslMode::Prefer,
        "require" => PgSslMode::Require,
        "verify-ca" => PgSslMode::VerifyCa,
        "verify-full" => PgSslMode::VerifyFull,
        // Modes that permit an unencrypted connection are refused outright.
        forbidden @ ("disable" | "allow") => {
            return Err(DatabaseError::Connection(format!(
                "Invalid SSL mode specified: {}",
                forbidden
            )));
        }
        _ => {
            return Err(DatabaseError::Connection(
                "Unknown SSL mode specified".to_string(),
            ));
        }
    };
    Ok(mode)
}
/// Extracts the first `sslmode` query parameter from `url`, if any.
///
/// The key is matched case-insensitively; the value is trimmed and lowercased
/// before being handed to `parse_sslmode`. Returns `Ok(None)` when the URL has
/// no `sslmode` parameter.
///
/// # Errors
///
/// Propagates the error from `parse_sslmode` when the value is not an accepted mode.
fn parse_url_sslmode(url: &str) -> Result<Option<PgSslMode>, DatabaseError> {
    // Everything after the first '?' is treated as the query string.
    let query = match url.split_once('?') {
        Some((_, rest)) => rest,
        None => "",
    };

    query
        .split('&')
        // A pair without '=' is treated as a key with an empty value.
        .map(|pair| pair.split_once('=').unwrap_or((pair, "")))
        .find(|(key, _)| key.eq_ignore_ascii_case("sslmode"))
        .map(|(_, value)| parse_sslmode(&value.trim().to_ascii_lowercase()))
        .transpose()
}
/// Maps an SSL mode to a rank used for comparisons: `Disable` is 0 and
/// `VerifyFull` is 5, with larger numbers treated as stricter.
fn sslmode_val(ssl_mode: PgSslMode) -> u32 {
    // Listed from strictest to weakest; arms are disjoint, so order is cosmetic.
    match ssl_mode {
        PgSslMode::VerifyFull => 5,
        PgSslMode::VerifyCa => 4,
        PgSslMode::Require => 3,
        PgSslMode::Prefer => 2,
        PgSslMode::Allow => 1,
        PgSslMode::Disable => 0,
    }
}
/// Ensures the URL-supplied SSL mode is at least as strict as the configured one.
///
/// # Errors
///
/// Returns [`DatabaseError::Connection`] when `url_ssl_mode` ranks below
/// `config_ssl_mode`.
fn reject_downgraded_sslmode(url_ssl_mode: PgSslMode, config_ssl_mode: PgSslMode) -> Result<(), DatabaseError> {
    let requested_rank = sslmode_val(url_ssl_mode);
    let required_rank = sslmode_val(config_ssl_mode);

    if requested_rank >= required_rank {
        Ok(())
    } else {
        Err(DatabaseError::Connection(
            "Url downgrades ssl mode and cannot be accepted.".to_string(),
        ))
    }
}
#[cfg(test)]
mod sslmode_tests {
    use super::*;

    /// A URL without a query string yields no SSL mode.
    #[test]
    fn accepts_no_sslmode() {
        let parsed = parse_url_sslmode("postgres://u:p@h/db").unwrap();
        assert!(parsed.is_none());
    }

    /// Accepted modes parse to the matching variant; weak modes are rejected.
    #[test]
    fn parses_mode() {
        // Helper: parse and require that an sslmode was present and valid.
        let mode_of = |url: &str| parse_url_sslmode(url).unwrap().unwrap();

        assert!(matches!(
            mode_of("postgres://u:p@h/db?sslmode=prefer"),
            PgSslMode::Prefer
        ));
        assert!(matches!(
            mode_of("postgres://u:p@h/db?sslmode=require"),
            PgSslMode::Require
        ));
        assert!(matches!(
            mode_of("postgres://u:p@h/db?sslmode=verify-ca"),
            PgSslMode::VerifyCa
        ));
        assert!(matches!(
            mode_of("postgres://u:p@h/db?sslmode=verify-full"),
            PgSslMode::VerifyFull
        ));
        // Value case is normalised and the parameter need not come first.
        assert!(matches!(
            mode_of("postgres://u:p@h/db?foo=bar&sslmode=Require"),
            PgSslMode::Require
        ));

        for weak_url in [
            "postgres://u:p@h/db?sslmode=allow",
            "postgres://u:p@h/db?sslmode=disable",
        ] {
            assert!(parse_url_sslmode(weak_url).is_err());
        }

        assert!(matches!(parse_sslmode("url").unwrap(), PgSslMode::Prefer));
    }

    /// A URL mode weaker than the configured mode is refused; equal or stronger passes.
    #[test]
    fn rejects_downgrade() {
        let downgrades = [
            (PgSslMode::Disable, PgSslMode::Require),
            (PgSslMode::Allow, PgSslMode::Require),
            (PgSslMode::Prefer, PgSslMode::Require),
            (PgSslMode::Require, PgSslMode::VerifyCa),
            (PgSslMode::VerifyCa, PgSslMode::VerifyFull),
        ];
        for (from_url, configured) in downgrades {
            assert!(
                reject_downgraded_sslmode(from_url, configured).is_err(),
                "{:?}",
                from_url
            );
        }

        let upgrades = [PgSslMode::Require, PgSslMode::VerifyCa, PgSslMode::VerifyFull];
        for from_url in upgrades {
            assert!(
                reject_downgraded_sslmode(from_url, PgSslMode::Prefer).is_ok(),
                "{:?}",
                from_url
            );
        }
    }
}
/// Point-in-time counters for the SQLx pool, as returned by `DatabaseManager::pool_stats`.
#[derive(Debug, Clone, Copy)]
pub struct PoolStats {
/// Value reported by the SQLx pool's `size()`.
pub size: u32,
/// Value reported by the SQLx pool's `num_idle()`.
pub idle: usize,
/// Maximum connection count the SQLx pool was configured with.
pub max_size: u32,
}
impl std::fmt::Debug for DatabaseManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let stats = self.pool_stats();
f.debug_struct("DatabaseManager")
.field("pool_size", &stats.size)
.field("idle_connections", &stats.idle)
.field("max_connections", &stats.max_size)
.finish()
}
}
/// Process-wide `DatabaseManager` singleton; populated by `initialize_database`
/// and read by `get_database` / `try_get_database`.
static DATABASE: OnceCell<DatabaseManager> = OnceCell::new();
/// Initializes the global database singleton, or confirms that an existing one
/// was created from the same URL.
///
/// Calling this again with the same URL is a no-op. Two callers racing to
/// initialize both succeed as long as their URLs match; the loser's freshly
/// built manager is dropped.
///
/// # Errors
///
/// * [`DatabaseError::SingletonUrlMismatch`] if the singleton already exists (or
///   wins the race) with a different URL.
/// * Any error from `DatabaseManager::new` when the manager has to be built.
///
// NOTE(review): the mismatch error stores both URLs verbatim, so any credentials
// embedded in them appear wherever that error is rendered.
pub async fn initialize_database(config: DatabaseConfig) -> Result<(), DatabaseError> {
    // Fast path: the singleton is already set, so only the URLs need comparing.
    match DATABASE.get() {
        Some(current) if current.url() == config.url => {
            info!("Global database singleton already initialized with matching URL; reusing");
            return Ok(());
        }
        Some(current) => {
            warn!("Database singleton URL mismatch — rejecting. In-process callers must agree on the DB URL.");
            return Err(DatabaseError::SingletonUrlMismatch {
                existing: current.url().to_string(),
                requested: config.url,
            });
        }
        None => {}
    }

    let candidate = DatabaseManager::new(config).await?;

    // `set` hands the value back if another caller populated the cell while we
    // were awaiting above.
    let rejected = match DATABASE.set(candidate) {
        Ok(()) => {
            info!("Global database singleton initialized");
            return Ok(());
        }
        Err(rejected) => rejected,
    };

    let winner = DATABASE.get().expect("cell was just populated");
    if winner.url() == rejected.url() {
        info!("Global database singleton raced; both callers agreed on URL");
        Ok(())
    } else {
        warn!("Database singleton URL mismatch detected on race — rejecting.");
        Err(DatabaseError::SingletonUrlMismatch {
            existing: winner.url().to_string(),
            requested: rejected.url().to_string(),
        })
    }
}
/// Returns the global [`DatabaseManager`].
///
/// # Panics
///
/// Panics if the singleton has not been initialized yet.
pub fn get_database() -> &'static DatabaseManager {
    match DATABASE.get() {
        Some(manager) => manager,
        None => panic!("Database not initialized. Call initialize_database() first."),
    }
}
/// Returns the global `DatabaseManager` if it has been initialized, or `None` otherwise.
pub fn try_get_database() -> Option<&'static DatabaseManager> {
DATABASE.get()
}