athena_rs 3.3.0

Database gateway API
Documentation
use sqlx::postgres::PgPoolOptions;

use crate::provision_sql::PROVISION_SQL;

use super::error::ProvisioningError;

/// Maximum number of bytes of a failing statement echoed back in error messages.
const STATEMENT_PREVIEW_MAX_BYTES: usize = 120;

/// Apply bundled `sql/provision.sql` to a target Postgres URI.
///
/// The bundled script is split with [`split_provision_statements`] and each statement is
/// executed in order through a pool limited to a single connection. Statements are not
/// wrapped in a transaction, so statements that succeeded before a failure stay applied.
///
/// Returns the number of statements executed.
///
/// # Errors
///
/// Returns [`ProvisioningError::Execution`] if the connection cannot be established or if
/// any statement fails. For a failing statement the message includes its 1-based position,
/// the total count, a (possibly truncated) preview of its text, and the database error.
pub async fn run_provision_sql(pg_uri: &str) -> Result<usize, ProvisioningError> {
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect(pg_uri)
        .await
        .map_err(|err| {
            ProvisioningError::Execution(format!("failed to connect to Postgres: {err}"))
        })?;

    let statements: Vec<&str> = split_provision_statements(PROVISION_SQL);
    let total: usize = statements.len();

    for (index, statement) in statements.iter().enumerate() {
        sqlx::query(statement).execute(&pool).await.map_err(|err| {
            ProvisioningError::Execution(format!(
                "statement {}/{} failed: {}: {}",
                index + 1,
                total,
                statement_preview(statement),
                err
            ))
        })?;
    }

    Ok(total)
}

/// Shorten `statement` to at most [`STATEMENT_PREVIEW_MAX_BYTES`] bytes for error messages,
/// appending `...` when text was cut.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so statements
/// containing multi-byte characters never cause a slicing panic.
fn statement_preview(statement: &str) -> String {
    if statement.len() <= STATEMENT_PREVIEW_MAX_BYTES {
        return statement.to_owned();
    }
    let mut end = STATEMENT_PREVIEW_MAX_BYTES;
    // Index 0 is always a char boundary, so this loop terminates.
    while !statement.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &statement[..end])
}

/// Split a provisioning SQL script into individual statements.
///
/// The script is split only on `;` terminators at the top level. Semicolons inside any of
/// the following do not end a statement:
/// - single-quoted literals, including `E'...'` escape strings;
/// - double-quoted identifiers;
/// - dollar-quoted bodies (`$$ ... $$`, `$tag$ ... $tag$`);
/// - `--` line comments;
/// - `/* ... */` block comments, which nest as in PostgreSQL.
///
/// Leading whitespace and comments are stripped from each statement; trailing whitespace
/// is kept. Pieces that are empty after stripping are dropped, such as the text after the
/// final `;` or a comment-only section. Returned slices borrow from `sql`.
pub fn split_provision_statements(sql: &str) -> Vec<&str> {
    split_top_level(sql)
        .into_iter()
        .map(trim_leading_comments)
        .filter(|statement| !statement.is_empty())
        .collect()
}

/// Split `sql` on every `;` that is outside literals, quoted identifiers, dollar quotes,
/// and comments. Like `str::split`, this always yields the (possibly empty) final piece.
fn split_top_level(sql: &str) -> Vec<&str> {
    let bytes = sql.as_bytes();
    let mut pieces = Vec::new();
    let mut start = 0;
    let mut i = 0;
    // Scanning bytes is safe: every delimiter is ASCII and UTF-8 continuation bytes are
    // never ASCII, so every slice boundary below falls on a character boundary.
    while i < bytes.len() {
        i = match bytes[i] {
            b';' => {
                pieces.push(&sql[start..i]);
                start = i + 1;
                i + 1
            }
            b'\'' => skip_quoted(bytes, i + 1, b'\'', is_escape_string_prefix(bytes, i)),
            b'"' => skip_quoted(bytes, i + 1, b'"', false),
            b'-' if bytes.get(i + 1) == Some(&b'-') => skip_line_comment(bytes, i + 2),
            b'/' if bytes.get(i + 1) == Some(&b'*') => skip_block_comment(bytes, i + 2),
            b'$' => match dollar_quote_delimiter(bytes, i) {
                Some(delimiter) => skip_dollar_body(bytes, i + delimiter.len(), delimiter),
                None => i + 1,
            },
            _ => i + 1,
        };
    }
    pieces.push(&sql[start..]);
    pieces
}

/// Strip leading whitespace, `--` line comments, and (nestable) `/* */` block comments
/// from `statement`. An unterminated comment consumes the rest of the text, yielding `""`.
fn trim_leading_comments(mut statement: &str) -> &str {
    loop {
        statement = statement.trim_start();
        let bytes = statement.as_bytes();
        let comment_end = if bytes.starts_with(b"--") {
            skip_line_comment(bytes, 2)
        } else if bytes.starts_with(b"/*") {
            skip_block_comment(bytes, 2)
        } else {
            return statement;
        };
        statement = &statement[comment_end..];
    }
}

/// Return the index just past the `\n` ending a `--` comment whose text starts at `from`,
/// or `bytes.len()` if the comment runs to the end of input.
fn skip_line_comment(bytes: &[u8], from: usize) -> usize {
    bytes[from..]
        .iter()
        .position(|&byte| byte == b'\n')
        .map_or(bytes.len(), |offset| from + offset + 1)
}

/// Return the index just past the `*/` closing a block comment whose body starts at
/// `from`. PostgreSQL block comments nest, so inner `/* ... */` pairs are balanced first.
/// Returns `bytes.len()` if the comment is never closed.
fn skip_block_comment(bytes: &[u8], mut from: usize) -> usize {
    let mut depth = 1usize;
    while from < bytes.len() {
        match (bytes[from], bytes.get(from + 1).copied()) {
            (b'*', Some(b'/')) => {
                depth -= 1;
                from += 2;
                if depth == 0 {
                    return from;
                }
            }
            (b'/', Some(b'*')) => {
                depth += 1;
                from += 2;
            }
            _ => from += 1,
        }
    }
    bytes.len()
}

/// Return the index just past the closing `quote` of a quoted token whose body starts at
/// `from`. A doubled quote (`''` / `""`) is an escaped quote rather than the terminator;
/// when `backslash_escapes` is set (`E'...'` strings) a backslash also escapes the next
/// byte. Returns `bytes.len()` if the token is never closed.
fn skip_quoted(bytes: &[u8], mut from: usize, quote: u8, backslash_escapes: bool) -> usize {
    while from < bytes.len() {
        let byte = bytes[from];
        if backslash_escapes && byte == b'\\' {
            from += 2;
        } else if byte == quote {
            if bytes.get(from + 1) == Some(&quote) {
                from += 2;
            } else {
                return from + 1;
            }
        } else {
            from += 1;
        }
    }
    bytes.len()
}

/// Whether the `'` at `quote_idx` opens an escape string constant (`E'...'` / `e'...'`),
/// i.e. it is preceded by a standalone `E`/`e` rather than the tail of a longer word.
fn is_escape_string_prefix(bytes: &[u8], quote_idx: usize) -> bool {
    match quote_idx.checked_sub(1).map(|prev| bytes[prev]) {
        Some(b'E' | b'e') => quote_idx < 2 || !is_identifier_byte(bytes[quote_idx - 2]),
        _ => false,
    }
}

/// Whether `byte` can appear inside an unquoted PostgreSQL identifier. Non-ASCII bytes are
/// accepted because PostgreSQL allows non-ASCII letters in identifiers.
fn is_identifier_byte(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || byte == b'_' || byte == b'$' || !byte.is_ascii()
}

/// If the `$` at `start` opens a dollar quote, return its full delimiter (`$$` or
/// `$tag$`). Returns `None` for positional parameters such as `$1` and for a `$` glued to
/// a preceding identifier (where it is part of that identifier).
fn dollar_quote_delimiter(bytes: &[u8], start: usize) -> Option<&[u8]> {
    if start > 0 && is_identifier_byte(bytes[start - 1]) {
        return None;
    }
    for (offset, &byte) in bytes[start + 1..].iter().enumerate() {
        if byte == b'$' {
            return Some(&bytes[start..start + offset + 2]);
        }
        // Tags follow identifier rules: they may not start with a digit.
        let allowed = byte == b'_'
            || !byte.is_ascii()
            || if offset == 0 {
                byte.is_ascii_alphabetic()
            } else {
                byte.is_ascii_alphanumeric()
            };
        if !allowed {
            return None;
        }
    }
    None
}

/// Return the index just past the `delimiter` that closes a dollar-quoted body starting
/// at `from`, or `bytes.len()` if it is never closed.
fn skip_dollar_body(bytes: &[u8], from: usize, delimiter: &[u8]) -> usize {
    bytes[from..]
        .windows(delimiter.len())
        .position(|window| window == delimiter)
        .map_or(bytes.len(), |offset| from + offset + delimiter.len())
}