use sqlx::Row;
use sqlx::postgres::PgPool;
use std::collections::BTreeSet;
use std::str::FromStr;
use tokio_postgres::Config as PgConfig;
use tokio_postgres::config::Host;
use super::error::ProvisioningError;
use super::types::LocalClusterCreateDatabaseParams;
/// Lists the names of every non-template database visible through `pool`,
/// ordered by `lower(datname)` (case-insensitive).
///
/// Rows whose `datname` column cannot be decoded into a `String` are skipped
/// (best effort) rather than failing the whole listing.
///
/// # Errors
///
/// Returns `ProvisioningError::Execution` when the catalog query itself fails.
pub async fn list_postgres_databases(pool: &PgPool) -> Result<Vec<String>, ProvisioningError> {
    let query = sqlx::query(
        r#"
SELECT datname
FROM pg_database
WHERE datistemplate = false
ORDER BY lower(datname)
"#,
    );
    let rows = match query.fetch_all(pool).await {
        Ok(rows) => rows,
        Err(err) => {
            return Err(ProvisioningError::Execution(format!(
                "failed to list databases: {err}"
            )));
        }
    };
    Ok(rows
        .iter()
        .filter_map(|row| row.try_get::<String, _>("datname").ok())
        .collect())
}
pub async fn create_postgres_database(
pool: &PgPool,
params: &LocalClusterCreateDatabaseParams,
) -> Result<(), ProvisioningError> {
let statement = build_create_database_statement(params)?;
sqlx::query(&statement)
.execute(pool)
.await
.map_err(map_sql_error)?;
Ok(())
}
/// Builds the `CREATE DATABASE` SQL statement for `params`.
///
/// Identifier-valued inputs (the database name, `OWNER`, `TEMPLATE`,
/// `TABLESPACE`) are checked by `validate_identifier`, which trims and
/// lowercases them, and are then emitted as double-quoted identifiers.
///
/// Free-text options (`ENCODING`, `LC_COLLATE`, `LC_CTYPE`) are trimmed and
/// emitted as escaped string literals.
///
/// An option set to `None` is omitted. An option set to a blank string is
/// rejected. Clauses appear in a fixed order: owner, template, encoding,
/// collate, ctype, tablespace.
///
/// # Errors
///
/// Returns `ProvisioningError::InvalidInput` when the database name or any
/// option fails validation.
pub fn build_create_database_statement(
    params: &LocalClusterCreateDatabaseParams,
) -> Result<String, ProvisioningError> {
    let database_name = validate_identifier("database_name", &params.database_name)?;
    let options = &params.options;
    let mut parts = vec![format!(
        "CREATE DATABASE {}",
        quote_identifier(&database_name)
    )];
    // Identifier-valued options: validated, then double-quoted.
    if let Some(owner) = non_empty_identifier(options.owner.as_deref())? {
        parts.push(format!("OWNER {}", quote_identifier(&owner)));
    }
    if let Some(template) = non_empty_identifier(options.template.as_deref())? {
        parts.push(format!("TEMPLATE {}", quote_identifier(&template)));
    }
    // Free-text options: emitted as escaped string literals.
    if let Some(encoding) = non_empty_text(options.encoding.as_deref())? {
        parts.push(format!("ENCODING {}", quote_literal(&encoding)));
    }
    if let Some(collate) = non_empty_text(options.lc_collate.as_deref())? {
        parts.push(format!("LC_COLLATE {}", quote_literal(&collate)));
    }
    if let Some(ctype) = non_empty_text(options.lc_ctype.as_deref())? {
        parts.push(format!("LC_CTYPE {}", quote_literal(&ctype)));
    }
    if let Some(tablespace) = non_empty_identifier(options.tablespace.as_deref())? {
        parts.push(format!("TABLESPACE {}", quote_identifier(&tablespace)));
    }
    Ok(parts.join(" "))
}
/// Returns `source_uri` with its database path replaced by `database_name`.
///
/// Requirements:
/// - `database_name` must pass `validate_identifier`; it is inserted trimmed
///   and lowercased.
/// - Only `postgres://` and `postgresql://` URIs are accepted.
/// - The URI must already contain a `/` path after the authority.
///
/// The query string is kept, except that any `dbname` parameter is dropped.
/// In libpq-style URIs a `?dbname=` parameter overrides the path, so keeping
/// it would leave the rewritten URI pointing at the original database.
///
/// # Errors
///
/// Returns `ProvisioningError::InvalidInput` in any of these cases:
/// - the database name is invalid;
/// - the scheme is unsupported;
/// - the URI has no database path.
pub fn replace_uri_database_name(
    source_uri: &str,
    database_name: &str,
) -> Result<String, ProvisioningError> {
    let database_name = validate_identifier("database_name", database_name)?;
    let trimmed = source_uri.trim();
    if !(trimmed.starts_with("postgres://") || trimmed.starts_with("postgresql://")) {
        return Err(ProvisioningError::InvalidInput(
            "only postgres:// or postgresql:// URIs are supported for local cluster provisioning"
                .to_string(),
        ));
    }
    let query_start = trimmed.find('?').unwrap_or(trimmed.len());
    let base = &trimmed[..query_start];
    let query_suffix = strip_dbname_query_param(&trimmed[query_start..]);
    let Some(scheme_sep) = base.find("://") else {
        return Err(ProvisioningError::InvalidInput(
            "invalid Postgres URI: missing scheme separator".to_string(),
        ));
    };
    let authority_start = scheme_sep + 3;
    // The first '/' after the authority starts the database path.
    let Some(path_start_rel) = base[authority_start..].find('/') else {
        return Err(ProvisioningError::InvalidInput(
            "invalid Postgres URI: missing database path".to_string(),
        ));
    };
    let path_start = authority_start + path_start_rel;
    Ok(format!(
        "{}/{}{}",
        &base[..path_start],
        database_name,
        query_suffix
    ))
}

/// Removes every `dbname` parameter from a URI query suffix.
///
/// `query_suffix` is either empty or starts with `?`. If no `dbname`
/// parameter is present, the suffix is returned unchanged. If removing them
/// leaves no parameters, an empty string is returned.
fn strip_dbname_query_param(query_suffix: &str) -> String {
    let Some(query) = query_suffix.strip_prefix('?') else {
        return query_suffix.to_string();
    };
    let is_dbname = |pair: &str| pair.split_once('=').map_or(pair, |(key, _)| key) == "dbname";
    if !query.split('&').any(is_dbname) {
        return query_suffix.to_string();
    }
    let kept: Vec<&str> = query.split('&').filter(|&pair| !is_dbname(pair)).collect();
    if kept.is_empty() {
        String::new()
    } else {
        format!("?{}", kept.join("&"))
    }
}
pub fn postgres_uri_database_name(uri: &str) -> Option<String> {
PgConfig::from_str(uri)
.ok()
.and_then(|cfg| cfg.get_dbname().map(str::to_string))
}
/// Builds a normalized fingerprint of the user and endpoint(s) in a Postgres
/// connection string.
///
/// Returns `None` when the string does not parse or names no host.
///
/// The format is `user|host:port[,host:port...]`:
/// - Host/port pairs are deduplicated and sorted, so host order does not
///   matter.
/// - TCP host names are lowercased.
/// - Unix-socket hosts are rendered as `unix:<path>`.
/// - A missing user is rendered as `<none>`.
/// - The database name is not included.
///
/// Ports are resolved the way tokio-postgres resolves them when connecting:
/// - ports pair with hosts by position;
/// - a single port applies to every host;
/// - 5432 is used when no port is given.
pub fn postgres_uri_fingerprint(uri: &str) -> Option<String> {
    let cfg = PgConfig::from_str(uri).ok()?;
    let ports = cfg.get_ports();
    let mut pairs = BTreeSet::new();
    for (index, host) in cfg.get_hosts().iter().enumerate() {
        let host_label = match host {
            Host::Tcp(value) => value.to_ascii_lowercase(),
            #[cfg(unix)]
            Host::Unix(path) => format!("unix:{}", path.display()),
        };
        // `host=a,b port=6000` means both hosts use 6000, so fall back to the
        // first port before falling back to the default.
        let port = ports
            .get(index)
            .or_else(|| ports.first())
            .copied()
            .unwrap_or(5432);
        pairs.insert(format!("{host_label}:{port}"));
    }
    if pairs.is_empty() {
        return None;
    }
    // NOTE(review): PostgreSQL role names are case-sensitive, so lowercasing
    // can conflate distinct roles such as `Admin` and `admin`. Confirm that
    // this is intended.
    let user = cfg.get_user().unwrap_or("<none>").to_ascii_lowercase();
    let endpoints = pairs.into_iter().collect::<Vec<_>>().join(",");
    Some(format!("{user}|{endpoints}"))
}
/// Validates an optional identifier-valued `CREATE DATABASE` option.
///
/// - `None` yields `Ok(None)`.
/// - An empty or whitespace-only value is rejected.
/// - Any other value is checked by `validate_identifier`, with errors reported
///   under the field name `"option"`, and is returned trimmed and lowercased.
fn non_empty_identifier(value: Option<&str>) -> Result<Option<String>, ProvisioningError> {
    let Some(raw) = value else {
        return Ok(None);
    };
    if raw.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "advanced option values must not be empty strings".to_string(),
        ));
    }
    validate_identifier("option", raw).map(Some)
}
/// Normalizes an optional free-text `CREATE DATABASE` option.
///
/// - `None` yields `Ok(None)`.
/// - An empty or whitespace-only value is rejected with `InvalidInput`.
/// - Any other value is returned trimmed. No further validation is done; the
///   caller is responsible for quoting.
fn non_empty_text(value: Option<&str>) -> Result<Option<String>, ProvisioningError> {
    value
        .map(str::trim)
        .map(|text| {
            if text.is_empty() {
                Err(ProvisioningError::InvalidInput(
                    "advanced option values must not be empty strings".to_string(),
                ))
            } else {
                Ok(text.to_owned())
            }
        })
        .transpose()
}
/// Validates `value` as a restricted PostgreSQL identifier and returns it
/// trimmed and lowercased.
///
/// After trimming, the value must:
/// - be non-empty;
/// - be at most 63 bytes long;
/// - consist only of ASCII letters, digits and `_`.
///
/// The checks run in that order. `field` names the input in error messages.
///
/// # Errors
///
/// Returns `ProvisioningError::InvalidInput` describing the first failed check.
fn validate_identifier(field: &str, value: &str) -> Result<String, ProvisioningError> {
    // PostgreSQL's default NAMEDATALEN (64) leaves 63 usable bytes.
    const MAX_IDENTIFIER_BYTES: usize = 63;

    let candidate = value.trim();
    let problem = if candidate.is_empty() {
        Some(format!("'{field}' must not be empty"))
    } else if candidate.len() > MAX_IDENTIFIER_BYTES {
        Some(format!(
            "'{field}' exceeds PostgreSQL identifier length limit (63)"
        ))
    } else if !candidate
        .bytes()
        .all(|byte| byte == b'_' || byte.is_ascii_alphanumeric())
    {
        Some(format!(
            "'{field}' may contain only ASCII letters, numbers, and '_'"
        ))
    } else {
        None
    };
    match problem {
        Some(message) => Err(ProvisioningError::InvalidInput(message)),
        None => Ok(candidate.to_ascii_lowercase()),
    }
}
/// Wraps `identifier` in double quotes, doubling any embedded `"`, so it is
/// read as a quoted SQL identifier.
fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Quotes `value` as a PostgreSQL string literal.
///
/// Single quotes are always doubled. If `value` contains a backslash, an
/// escape-string literal (`E'...'`) with doubled backslashes is produced
/// instead of a plain `'...'` literal. This is the same approach PostgreSQL's
/// own `quote_literal` takes.
///
/// Why: with `standard_conforming_strings = off`, a plain literal treats `\`
/// as an escape character. A value like `\'` could then terminate the
/// literal early. The `E'...'` form means the same thing regardless of that
/// setting.
fn quote_literal(value: &str) -> String {
    let escaped = value.replace('\'', "''");
    if value.contains('\\') {
        format!("E'{}'", escaped.replace('\\', "\\\\"))
    } else {
        format!("'{escaped}'")
    }
}
/// Translates a `sqlx::Error` into a `ProvisioningError` based on the
/// server's SQLSTATE code.
///
/// - `42P04` (duplicate_database) → `Conflict`
/// - `42501` (insufficient_privilege) → `Unavailable`
/// - anything else, including non-database errors → `Execution`, carrying the
///   original error text
///
/// NOTE(review): every `42501` is reported as a missing CREATE DATABASE
/// privilege, whichever statement failed. That is only accurate while this is
/// used solely for `CREATE DATABASE`.
fn map_sql_error(err: sqlx::Error) -> ProvisioningError {
    let sqlstate = err.as_database_error().and_then(|db| db.code());
    match sqlstate.as_deref() {
        Some("42P04") => ProvisioningError::Conflict("database already exists".to_string()),
        Some("42501") => ProvisioningError::Unavailable(
            "database role lacks CREATE DATABASE privilege".to_string(),
        ),
        _ => ProvisioningError::Execution(format!("database operation failed: {err}")),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provisioning::types::LocalClusterDatabaseCreateOptions;

    /// Parameters that set every advanced `CREATE DATABASE` option.
    fn params_with_all_options() -> LocalClusterCreateDatabaseParams {
        LocalClusterCreateDatabaseParams {
            database_name: "tenant_a".to_string(),
            options: LocalClusterDatabaseCreateOptions {
                owner: Some("athena_owner".to_string()),
                template: Some("template0".to_string()),
                encoding: Some("UTF8".to_string()),
                lc_collate: Some("en_US.UTF-8".to_string()),
                lc_ctype: Some("en_US.UTF-8".to_string()),
                tablespace: Some("pg_default".to_string()),
            },
        }
    }

    #[test]
    fn create_database_statement_includes_advanced_options() {
        let statement = build_create_database_statement(&params_with_all_options())
            .expect("statement should build");
        let expected_clauses = [
            "CREATE DATABASE \"tenant_a\"",
            "OWNER \"athena_owner\"",
            "TEMPLATE \"template0\"",
            "ENCODING 'UTF8'",
            "LC_COLLATE 'en_US.UTF-8'",
            "LC_CTYPE 'en_US.UTF-8'",
            "TABLESPACE \"pg_default\"",
        ];
        for clause in expected_clauses {
            assert!(
                statement.contains(clause),
                "missing `{clause}` in `{statement}`"
            );
        }
    }

    #[test]
    fn replace_uri_database_name_preserves_query() {
        let rewritten = replace_uri_database_name(
            "postgres://user:pass@localhost:5432/old_db?sslmode=require",
            "new_db",
        )
        .expect("URI should be rewritten");
        assert_eq!(
            rewritten,
            "postgres://user:pass@localhost:5432/new_db?sslmode=require"
        );
    }

    #[test]
    fn replace_uri_database_name_rejects_invalid_database() {
        let err =
            replace_uri_database_name("postgres://user:pass@localhost:5432/old_db", "bad-name")
                .expect_err("invalid identifier must fail");
        let ProvisioningError::InvalidInput(message) = err else {
            panic!("expected invalid input");
        };
        assert!(message.contains("database_name"));
    }
}