athena_rs 3.3.0

Database gateway API
Documentation
use sqlx::Row;
use sqlx::postgres::PgPool;
use std::collections::BTreeSet;
use std::str::FromStr;
use tokio_postgres::Config as PgConfig;
use tokio_postgres::config::Host;

use super::error::ProvisioningError;
use super::types::LocalClusterCreateDatabaseParams;

/// Lists the names of all non-template databases visible through `pool`.
///
/// Names are ordered case-insensitively (`ORDER BY lower(datname)`).
///
/// # Errors
///
/// Returns [`ProvisioningError::Execution`] if the catalog query fails or if a
/// `datname` value cannot be decoded as a `String`.
pub async fn list_postgres_databases(pool: &PgPool) -> Result<Vec<String>, ProvisioningError> {
    let rows = sqlx::query(
        r#"
        SELECT datname
        FROM pg_database
        WHERE datistemplate = false
        ORDER BY lower(datname)
        "#,
    )
    .fetch_all(pool)
    .await
    .map_err(|err| ProvisioningError::Execution(format!("failed to list databases: {err}")))?;

    // Propagate decode failures rather than silently dropping rows: a partial
    // listing would otherwise be indistinguishable from a complete one.
    rows.iter()
        .map(|row| {
            row.try_get::<String, _>("datname").map_err(|err| {
                ProvisioningError::Execution(format!("failed to decode database name: {err}"))
            })
        })
        .collect()
}

pub async fn create_postgres_database(
    pool: &PgPool,
    params: &LocalClusterCreateDatabaseParams,
) -> Result<(), ProvisioningError> {
    let statement = build_create_database_statement(params)?;
    sqlx::query(&statement)
        .execute(pool)
        .await
        .map_err(map_sql_error)?;
    Ok(())
}

/// Builds a `CREATE DATABASE` statement from `params`.
///
/// The following are validated as simple identifiers and emitted double-quoted:
/// the database name, owner, template, and tablespace. A simple identifier here
/// means trimmed ASCII alphanumerics/`_`, at most 63 bytes, lower-cased. Encoding
/// and locale values are trimmed and emitted as single-quoted literals. Optional
/// clauses appear in a fixed order (OWNER, TEMPLATE, ENCODING, LC_COLLATE,
/// LC_CTYPE, TABLESPACE) and only when the option is set.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] when the database name or an
/// identifier option fails validation, or when any option is present but blank.
pub fn build_create_database_statement(
    params: &LocalClusterCreateDatabaseParams,
) -> Result<String, ProvisioningError> {
    let opts = &params.options;
    let name = validate_identifier("database_name", &params.database_name)?;
    let mut sql = format!("CREATE DATABASE {}", quote_identifier(&name));

    // Array elements are evaluated left to right, so the first invalid option
    // (in clause order) is the one reported.
    let clauses: [(&str, Option<String>); 6] = [
        (
            "OWNER",
            non_empty_identifier(opts.owner.as_deref())?.map(|v| quote_identifier(&v)),
        ),
        (
            "TEMPLATE",
            non_empty_identifier(opts.template.as_deref())?.map(|v| quote_identifier(&v)),
        ),
        (
            "ENCODING",
            non_empty_text(opts.encoding.as_deref())?.map(|v| quote_literal(&v)),
        ),
        (
            "LC_COLLATE",
            non_empty_text(opts.lc_collate.as_deref())?.map(|v| quote_literal(&v)),
        ),
        (
            "LC_CTYPE",
            non_empty_text(opts.lc_ctype.as_deref())?.map(|v| quote_literal(&v)),
        ),
        (
            "TABLESPACE",
            non_empty_identifier(opts.tablespace.as_deref())?.map(|v| quote_identifier(&v)),
        ),
    ];

    for (keyword, rendered) in clauses {
        if let Some(rendered) = rendered {
            sql.push(' ');
            sql.push_str(keyword);
            sql.push(' ');
            sql.push_str(&rendered);
        }
    }

    Ok(sql)
}

/// Returns `source_uri` with its database path replaced by `database_name`.
///
/// Only `postgres://` and `postgresql://` URIs are accepted. The following are kept
/// verbatim:
/// - everything before the database path (scheme, credentials, hosts, ports);
/// - everything from the `?` query separator onward.
///
/// The new name is first validated and lower-cased by `validate_identifier`.
///
/// The authority is delimited the way libpq parses connection URIs:
/// - the optional `user[:password]@` prefix ends at the first `@`, provided one
///   occurs before any `/`;
/// - the host list then ends at the first `/` or `?`.
///
/// As a result, an unencoded `?` inside a password is not mistaken for the start
/// of the query string.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] in these cases:
/// - the name is invalid;
/// - the scheme is unsupported;
/// - no `/database` path follows the host list.
pub fn replace_uri_database_name(
    source_uri: &str,
    database_name: &str,
) -> Result<String, ProvisioningError> {
    let database_name = validate_identifier("database_name", database_name)?;
    let trimmed = source_uri.trim();

    if !(trimmed.starts_with("postgres://") || trimmed.starts_with("postgresql://")) {
        return Err(ProvisioningError::InvalidInput(
            "only postgres:// or postgresql:// URIs are supported for local cluster provisioning"
                .to_string(),
        ));
    }

    let Some(scheme_sep) = trimmed.find("://") else {
        return Err(ProvisioningError::InvalidInput(
            "invalid Postgres URI: missing scheme separator".to_string(),
        ));
    };

    let authority_start = scheme_sep + 3;
    let after_scheme = &trimmed[authority_start..];

    // Skip `user[:password]@` only when an `@` appears before the first `/`
    // (libpq rule); credentials may legitimately contain `?`.
    let hosts_offset = match after_scheme.find(|c: char| c == '@' || c == '/') {
        Some(idx) if after_scheme.as_bytes()[idx] == b'@' => idx + 1,
        _ => 0,
    };
    let hosts = &after_scheme[hosts_offset..];
    let hosts_len = hosts
        .find(|c: char| c == '/' || c == '?')
        .unwrap_or(hosts.len());
    let path_start = authority_start + hosts_offset + hosts_len;

    if !trimmed[path_start..].starts_with('/') {
        return Err(ProvisioningError::InvalidInput(
            "invalid Postgres URI: missing database path".to_string(),
        ));
    }

    // The query string, if any, starts at the first `?` after the path separator.
    let query_start = trimmed[path_start..]
        .find('?')
        .map_or(trimmed.len(), |idx| path_start + idx);

    Ok(format!(
        "{}/{}{}",
        &trimmed[..path_start],
        database_name,
        &trimmed[query_start..]
    ))
}

/// Extracts the database name from a Postgres connection string.
///
/// Accepts anything `tokio_postgres::Config` can parse, whether in URI or
/// key/value form. Returns `None` when the string does not parse or specifies no
/// database name.
pub fn postgres_uri_database_name(uri: &str) -> Option<String> {
    let config = PgConfig::from_str(uri).ok()?;
    let name = config.get_dbname()?;
    Some(name.to_owned())
}

/// Builds a normalised identity string for the role and server endpoints that a
/// Postgres connection string targets.
///
/// The format is `user|host:port[,host:port...]`:
/// - the user (or `<none>`) and TCP host names are lower-cased;
/// - Unix-socket hosts are rendered as `unix:<path>`;
/// - endpoints are de-duplicated and sorted, so host order does not affect the
///   result.
///
/// Ports are resolved as `tokio_postgres` does when connecting: host *i* uses
/// port *i* if present, otherwise the first listed port, otherwise 5432.
///
/// Returns `None` if the string does not parse or lists no hosts.
pub fn postgres_uri_fingerprint(uri: &str) -> Option<String> {
    let cfg = PgConfig::from_str(uri).ok()?;
    let ports = cfg.get_ports();

    let mut pairs = BTreeSet::new();
    for (index, host) in cfg.get_hosts().iter().enumerate() {
        let host_label = match host {
            Host::Tcp(value) => value.to_ascii_lowercase(),
            #[cfg(unix)]
            Host::Unix(path) => format!("unix:{}", path.display()),
        };
        // A single `port` applies to every host (e.g. `host=a,b port=6543`).
        // Falling back to the first port keeps the fingerprint aligned with the
        // endpoints actually dialled.
        let port = ports
            .get(index)
            .or_else(|| ports.first())
            .copied()
            .unwrap_or(5432);
        pairs.insert(format!("{host_label}:{port}"));
    }

    if pairs.is_empty() {
        return None;
    }

    let user = cfg.get_user().unwrap_or("<none>").to_ascii_lowercase();
    let endpoints = pairs.into_iter().collect::<Vec<_>>().join(",");
    Some(format!("{user}|{endpoints}"))
}

/// Validates an optional identifier-valued advanced option.
///
/// `None` passes through as `Ok(None)`. A present value must not be blank and
/// must satisfy `validate_identifier`, which reports failures under the field name
/// `option`. On success the validated, lower-cased identifier is returned.
fn non_empty_identifier(value: Option<&str>) -> Result<Option<String>, ProvisioningError> {
    let Some(raw) = value else {
        return Ok(None);
    };

    if raw.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "advanced option values must not be empty strings".to_string(),
        ));
    }

    validate_identifier("option", raw).map(Some)
}

/// Normalises an optional free-text advanced option.
///
/// `None` passes through as `Ok(None)`. A present value is returned trimmed.
/// It is rejected with `InvalidInput` if nothing remains after trimming.
fn non_empty_text(value: Option<&str>) -> Result<Option<String>, ProvisioningError> {
    value
        .map(str::trim)
        .map(|text| {
            if text.is_empty() {
                Err(ProvisioningError::InvalidInput(
                    "advanced option values must not be empty strings".to_string(),
                ))
            } else {
                Ok(text.to_owned())
            }
        })
        .transpose()
}

/// Validates `value` as a simple PostgreSQL identifier and normalises it.
///
/// Surrounding whitespace is trimmed first. The remainder must be:
/// - non-empty;
/// - at most 63 bytes long;
/// - made up solely of ASCII letters, digits, and `_`.
///
/// On success the identifier is returned lower-cased. `field` is used only to
/// label error messages.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] for the first rule violated. Rules
/// are checked in this order: empty, too long, invalid characters.
fn validate_identifier(field: &str, value: &str) -> Result<String, ProvisioningError> {
    let candidate = value.trim();

    let problem = if candidate.is_empty() {
        Some(format!("'{field}' must not be empty"))
    } else if candidate.len() > 63 {
        Some(format!(
            "'{field}' exceeds PostgreSQL identifier length limit (63)"
        ))
    } else if candidate
        .chars()
        .any(|ch| !(ch.is_ascii_alphanumeric() || ch == '_'))
    {
        Some(format!(
            "'{field}' may contain only ASCII letters, numbers, and '_'"
        ))
    } else {
        None
    };

    match problem {
        Some(message) => Err(ProvisioningError::InvalidInput(message)),
        None => Ok(candidate.to_ascii_lowercase()),
    }
}

/// Wraps `identifier` in double quotes and doubles any embedded `"`, producing a
/// single quoted SQL identifier.
fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}

/// Renders `value` as a PostgreSQL string literal.
///
/// Embedded single quotes are doubled. If the value contains a backslash, the
/// literal is emitted in escape-string form (`E'...'`) with each backslash doubled.
/// This matches PostgreSQL's own `quote_literal()`. The literal then means the
/// same thing whether or not `standard_conforming_strings` is enabled, so a
/// backslash in caller-supplied text can never escape the closing quote.
fn quote_literal(value: &str) -> String {
    let escaped = value.replace('\'', "''");
    if value.contains('\\') {
        format!("E'{}'", escaped.replace('\\', "\\\\"))
    } else {
        format!("'{escaped}'")
    }
}

/// Translates a `sqlx` error into a [`ProvisioningError`].
///
/// - SQLSTATE `42P04` (duplicate_database) becomes `Conflict`.
/// - SQLSTATE `42501` (insufficient_privilege) becomes `Unavailable`.
/// - Every other failure, including non-database errors that carry no SQLSTATE,
///   becomes `Execution` with the original error text appended.
fn map_sql_error(err: sqlx::Error) -> ProvisioningError {
    let sqlstate = err.as_database_error().and_then(|db| db.code());

    match sqlstate.as_deref() {
        Some("42P04") => ProvisioningError::Conflict("database already exists".to_string()),
        Some("42501") => ProvisioningError::Unavailable(
            "database role lacks CREATE DATABASE privilege".to_string(),
        ),
        _ => ProvisioningError::Execution(format!("database operation failed: {err}")),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Options struct with every advanced option unset.
    fn no_options() -> crate::provisioning::types::LocalClusterDatabaseCreateOptions {
        crate::provisioning::types::LocalClusterDatabaseCreateOptions {
            owner: None,
            template: None,
            encoding: None,
            lc_collate: None,
            lc_ctype: None,
            tablespace: None,
        }
    }

    #[test]
    fn create_database_statement_includes_advanced_options() {
        let statement = build_create_database_statement(&LocalClusterCreateDatabaseParams {
            database_name: "tenant_a".to_string(),
            options: crate::provisioning::types::LocalClusterDatabaseCreateOptions {
                owner: Some("athena_owner".to_string()),
                template: Some("template0".to_string()),
                encoding: Some("UTF8".to_string()),
                lc_collate: Some("en_US.UTF-8".to_string()),
                lc_ctype: Some("en_US.UTF-8".to_string()),
                tablespace: Some("pg_default".to_string()),
            },
        })
        .expect("statement should build");

        assert!(statement.contains("CREATE DATABASE \"tenant_a\""));
        assert!(statement.contains("OWNER \"athena_owner\""));
        assert!(statement.contains("TEMPLATE \"template0\""));
        assert!(statement.contains("ENCODING 'UTF8'"));
        assert!(statement.contains("LC_COLLATE 'en_US.UTF-8'"));
        assert!(statement.contains("LC_CTYPE 'en_US.UTF-8'"));
        assert!(statement.contains("TABLESPACE \"pg_default\""));
    }

    #[test]
    fn create_database_statement_lowercases_name_and_omits_unset_options() {
        let statement = build_create_database_statement(&LocalClusterCreateDatabaseParams {
            database_name: "  Tenant_B ".to_string(),
            options: no_options(),
        })
        .expect("statement should build");

        assert_eq!(statement, "CREATE DATABASE \"tenant_b\"");
    }

    #[test]
    fn create_database_statement_rejects_blank_option() {
        let mut options = no_options();
        options.owner = Some("   ".to_string());

        let err = build_create_database_statement(&LocalClusterCreateDatabaseParams {
            database_name: "tenant_c".to_string(),
            options,
        })
        .expect_err("blank option must fail");

        assert!(matches!(err, ProvisioningError::InvalidInput(_)));
    }

    #[test]
    fn replace_uri_database_name_preserves_query() {
        let replaced = replace_uri_database_name(
            "postgres://user:pass@localhost:5432/old_db?sslmode=require",
            "new_db",
        )
        .expect("URI should be rewritten");

        assert_eq!(
            replaced,
            "postgres://user:pass@localhost:5432/new_db?sslmode=require"
        );
    }

    #[test]
    fn replace_uri_database_name_rejects_invalid_database() {
        let err =
            replace_uri_database_name("postgres://user:pass@localhost:5432/old_db", "bad-name")
                .expect_err("invalid identifier must fail");

        match err {
            ProvisioningError::InvalidInput(message) => {
                assert!(message.contains("database_name"));
            }
            _ => panic!("expected invalid input"),
        }
    }

    #[test]
    fn replace_uri_database_name_rejects_unsupported_scheme() {
        let err = replace_uri_database_name("mysql://localhost:3306/old_db", "new_db")
            .expect_err("non-postgres scheme must fail");

        assert!(matches!(err, ProvisioningError::InvalidInput(_)));
    }

    #[test]
    fn replace_uri_database_name_requires_database_path() {
        let err = replace_uri_database_name("postgres://localhost:5432", "new_db")
            .expect_err("URI without a database path must fail");

        assert!(matches!(err, ProvisioningError::InvalidInput(_)));
    }

    #[test]
    fn uri_helpers_extract_database_and_fingerprint() {
        let uri = "postgres://User@Host:5433/app_db";

        assert_eq!(postgres_uri_database_name(uri), Some("app_db".to_string()));
        assert_eq!(
            postgres_uri_fingerprint(uri).as_deref(),
            Some("user|host:5433")
        );
    }
}