use sqlx::postgres::PgPoolOptions;
use crate::provision_sql::PROVISION_SQL;
use super::error::ProvisioningError;
/// Provisions the database at `pg_uri` by executing every statement of
/// `PROVISION_SQL`, one at a time and in order.
///
/// The script is split with [`split_provision_statements`], and each statement is
/// executed through a pool limited to a single connection. Statements are not
/// wrapped in a transaction, so statements that ran before a failure are not
/// rolled back by this function. The pool is closed before returning, on both
/// the success and the failure path.
///
/// Returns the number of statements that were executed.
///
/// # Errors
///
/// Returns [`ProvisioningError::Execution`] in two cases:
///
/// * the connection cannot be established;
/// * a statement fails. Execution stops at that statement. The message gives
///   its 1-based position, a preview of at most 120 characters of its text,
///   and the database error.
pub async fn run_provision_sql(pg_uri: &str) -> Result<usize, ProvisioningError> {
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect(pg_uri)
        .await
        .map_err(|err| {
            ProvisioningError::Execution(format!("failed to connect to Postgres: {err}"))
        })?;

    let statements = split_provision_statements(PROVISION_SQL);
    let total = statements.len();

    let mut outcome: Result<usize, ProvisioningError> = Ok(total);
    for (index, statement) in statements.iter().enumerate() {
        if let Err(err) = sqlx::query(statement).execute(&pool).await {
            let (preview, ellipsis) = statement_preview(statement);
            outcome = Err(ProvisioningError::Execution(format!(
                "statement {}/{} failed: {}{} — {}",
                index + 1,
                total,
                preview,
                ellipsis,
                err
            )));
            break;
        }
    }

    // Close explicitly (on both paths) so the pool's connection is shut down
    // before this function returns.
    pool.close().await;
    outcome
}

/// Returns at most 120 characters of `statement` for use in an error message,
/// together with `"…"` if the text was truncated (or `""` if it was not).
///
/// The cut is made on a `char` boundary. Slicing at a fixed byte offset would
/// panic when that offset lands inside a multi-byte UTF-8 character.
fn statement_preview(statement: &str) -> (&str, &'static str) {
    const MAX_CHARS: usize = 120;
    match statement.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => (&statement[..cut], "…"),
        None => (statement, ""),
    }
}
/// Splits a PostgreSQL script into the individual statements to execute.
///
/// Only top-level `;` characters separate statements. A `;` inside any of the
/// following is treated as ordinary text:
///
/// * `--` line comments and `/* … */` block comments (block comments nest, as
///   they do in PostgreSQL);
/// * single-quoted string literals, including doubled `''` quotes and `E'…'`
///   escape strings, where a backslash escapes the next character;
/// * double-quoted identifiers;
/// * dollar-quoted bodies such as `$$ … $$` or `$body$ … $body$`.
///
/// Leading whitespace and comments are stripped from each piece, and pieces
/// that are empty afterwards are dropped. The returned slices borrow from `sql`
/// and keep any trailing whitespace or comments. An unterminated quote or
/// comment runs to the end of `sql`, and no further splitting happens after it.
pub fn split_provision_statements(sql: &str) -> Vec<&str> {
    split_top_level(sql)
        .into_iter()
        .map(trim_leading_comments)
        .filter(|statement| !statement.is_empty())
        .collect()
}

/// Splits `sql` at every `;` that is not inside a comment, quoted literal,
/// quoted identifier or dollar-quoted body. Like `str::split`, the piece after
/// the last separator is always included, even when it is empty.
fn split_top_level(sql: &str) -> Vec<&str> {
    let bytes = sql.as_bytes();
    let mut pieces = Vec::new();
    let mut start = 0;
    let mut i = 0;
    // Every delimiter examined here is ASCII, and every byte of a multi-byte
    // UTF-8 character is >= 0x80. So the slice positions below always fall on
    // char boundaries.
    while i < bytes.len() {
        match bytes[i] {
            b';' => {
                pieces.push(&sql[start..i]);
                i += 1;
                start = i;
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => i = skip_line_comment(bytes, i),
            b'/' if bytes.get(i + 1) == Some(&b'*') => i = skip_block_comment(bytes, i),
            b'\'' => {
                // `E'…'` escape strings treat backslash as an escape character.
                // The `E` must be a standalone prefix, not the end of an identifier.
                let escape_string = i > 0
                    && matches!(bytes[i - 1], b'e' | b'E')
                    && (i < 2 || !is_ident_byte(bytes[i - 2]));
                i = skip_quoted(bytes, i, b'\'', escape_string);
            }
            b'"' => i = skip_quoted(bytes, i, b'"', false),
            // A `$` directly after an identifier character belongs to that identifier.
            b'$' if i == 0 || !is_ident_byte(bytes[i - 1]) => {
                i = match dollar_quote_end(bytes, i) {
                    Some(body_start) => {
                        let delimiter = &sql[i..body_start];
                        sql[body_start..]
                            .find(delimiter)
                            .map_or(bytes.len(), |offset| body_start + offset + delimiter.len())
                    }
                    None => i + 1,
                };
            }
            _ => i += 1,
        }
    }
    pieces.push(&sql[start..]);
    pieces
}

/// Returns the index just past the newline that ends the `--` comment starting
/// at `start`, or `bytes.len()` if the comment runs to the end of the input.
fn skip_line_comment(bytes: &[u8], start: usize) -> usize {
    bytes[start..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |offset| start + offset + 1)
}

/// Returns the index just past the `*/` that closes the block comment opening
/// at `start`, or `bytes.len()` if the comment is unterminated.
///
/// `bytes[start..]` must begin with `/*`. Nested `/* … */` pairs are balanced,
/// matching PostgreSQL's comment syntax.
fn skip_block_comment(bytes: &[u8], start: usize) -> usize {
    let mut depth = 1usize;
    let mut i = start + 2;
    while i + 1 < bytes.len() {
        match (bytes[i], bytes[i + 1]) {
            (b'/', b'*') => {
                depth += 1;
                i += 2;
            }
            (b'*', b'/') => {
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    bytes.len()
}

/// Returns the index just past the closing `quote` of the quoted run that opens
/// at `open`, or `bytes.len()` if it is unterminated.
///
/// A doubled quote (`''` or `""`) is an escaped quote and does not close the
/// run. When `backslash_escapes` is set, a backslash also escapes the next byte.
fn skip_quoted(bytes: &[u8], open: usize, quote: u8, backslash_escapes: bool) -> usize {
    let mut i = open + 1;
    while i < bytes.len() {
        let b = bytes[i];
        if backslash_escapes && b == b'\\' {
            i += 2;
        } else if b == quote {
            if bytes.get(i + 1) == Some(&quote) {
                i += 2;
            } else {
                return i + 1;
            }
        } else {
            i += 1;
        }
    }
    bytes.len()
}

/// Checks whether a dollar-quote opening delimiter (`$$` or `$tag$`) starts at
/// `start`. If it does, returns the index just past that delimiter.
///
/// Returns `None` when `$` does not open a quote, for example a `$1` parameter.
fn dollar_quote_end(bytes: &[u8], start: usize) -> Option<usize> {
    let mut end = start + 1;
    if matches!(bytes.get(end), Some(&b) if b.is_ascii_alphabetic() || b == b'_' || b >= 0x80) {
        end += 1;
        while matches!(
            bytes.get(end),
            Some(&b) if b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
        ) {
            end += 1;
        }
    }
    if bytes.get(end) == Some(&b'$') {
        Some(end + 1)
    } else {
        None
    }
}

/// Returns true for bytes that can continue an unquoted PostgreSQL identifier
/// (bytes >= 0x80 are parts of non-ASCII characters).
fn is_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
}

/// Strips leading whitespace, `--` line comments and (possibly nested)
/// `/* … */` block comments from `statement`, and returns the remaining slice.
///
/// Returns `""` if nothing but whitespace and comments remains. That includes
/// the case where a leading comment is never terminated.
fn trim_leading_comments(mut statement: &str) -> &str {
    loop {
        statement = statement.trim_start();
        if statement.starts_with("--") {
            match statement.find('\n') {
                Some(newline) => statement = &statement[newline + 1..],
                None => return "",
            }
        } else if statement.starts_with("/*") {
            statement = &statement[skip_block_comment(statement.as_bytes(), 0)..];
        } else {
            return statement;
        }
    }
}