use super::{Migration, MigrationRecord, MigrationRunner};
use crate::Result;
use chrono::Utc;
use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};
use sqlx::{PgPool, Row};
use tracing::{debug, error, info, warn};
/// Parses a raw SQL script into individual executable statements using
/// sqlparser's PostgreSQL dialect.
///
/// Two post-processing fixups are applied to sqlparser's re-rendered output,
/// both restoring the empty argument list `()` that the AST printer drops for
/// zero-argument functions:
/// 1. `CREATE FUNCTION name RETURNS ...` -> `CREATE FUNCTION name() RETURNS ...`
/// 2. `EXECUTE FUNCTION name` (trigger definitions) -> `EXECUTE FUNCTION name()`
///
/// If sqlparser cannot parse the input, this falls back to
/// `split_sql_respecting_quotes` and still returns `Ok`, so callers currently
/// never observe the `Err` variant.
fn parse_sql_statements(
    sql: &str,
) -> std::result::Result<Vec<String>, sqlparser::parser::ParserError> {
    // Marker preceding a trigger's function reference; its length replaces
    // the magic offset `17` used previously.
    const EXECUTE_FUNCTION_PREFIX: &str = "EXECUTE FUNCTION ";
    debug!(
        "parse_sql_statements called with {} characters of SQL",
        sql.len()
    );
    let dialect = PostgreSqlDialect {};
    match Parser::parse_sql(&dialect, sql) {
        Ok(statements) => {
            debug!(
                "Successfully parsed {} statements with sqlparser",
                statements.len()
            );
            Ok(statements
                .iter()
                .map(|stmt| {
                    let mut sql_string = format!("{}", stmt);
                    // Fixup 1: a zero-argument CREATE FUNCTION loses its "()"
                    // when the AST is rendered back to text.
                    if let sqlparser::ast::Statement::CreateFunction { args, .. } = stmt {
                        if args.is_none() {
                            if let Some(returns_pos) = sql_string.find(" RETURNS ") {
                                let before_returns = &sql_string[..returns_pos];
                                // Only insert when the name is not already
                                // followed by a closing parenthesis.
                                if !before_returns.ends_with("()") && !before_returns.ends_with(")")
                                {
                                    sql_string.insert_str(returns_pos, "()");
                                }
                            }
                        }
                    }
                    // Fixup 2: "EXECUTE FUNCTION name" in trigger definitions
                    // also needs its "()" restored.
                    if sql_string.contains("EXECUTE FUNCTION") {
                        debug!("Before EXECUTE FUNCTION fix: {}", sql_string);
                        if let Some(execute_pos) = sql_string.find(EXECUTE_FUNCTION_PREFIX) {
                            let after_execute =
                                &sql_string[execute_pos + EXECUTE_FUNCTION_PREFIX.len()..];
                            // The function reference ends at ";", a newline,
                            // or the end of the rendered statement.
                            let end_pos = after_execute
                                .find(';')
                                .or_else(|| after_execute.find('\n'))
                                .unwrap_or(after_execute.len());
                            let function_part = &after_execute[..end_pos];
                            debug!("Function part: '{}'", function_part.trim());
                            if !function_part.trim().ends_with(')') {
                                // NOTE(review): assumes function_part has no
                                // leading whitespace (sqlparser prints exactly
                                // one space after the keyword); otherwise this
                                // byte offset would be short by the amount of
                                // leading whitespace trimmed away.
                                let insert_pos = execute_pos
                                    + EXECUTE_FUNCTION_PREFIX.len()
                                    + function_part.trim().len();
                                debug!("Adding () at position {}", insert_pos);
                                sql_string.insert_str(insert_pos, "()");
                                debug!("After EXECUTE FUNCTION fix: {}", sql_string);
                            }
                        }
                    }
                    // Re-append the terminating semicolon sqlparser strips.
                    format!("{};", sql_string)
                })
                .collect())
        }
        Err(e) => {
            // sqlparser could not handle the script; fall back to the manual,
            // quote-aware splitter instead of failing the migration outright.
            error!("sqlparser failed to parse SQL: {}", e);
            debug!("Failed SQL content: {}", sql);
            debug!("Falling back to manual SQL splitting");
            Ok(split_sql_respecting_quotes(sql))
        }
    }
}
fn split_sql_respecting_quotes(sql: &str) -> Vec<String> {
let mut statements = Vec::new();
let mut current_statement = String::new();
let mut in_single_quote = false;
let mut in_dollar_quote = false;
let mut dollar_tag = String::new();
let mut chars = sql.chars().peekable();
while let Some(ch) = chars.next() {
current_statement.push(ch);
match ch {
'\'' if !in_dollar_quote => {
if in_single_quote {
if chars.peek() == Some(&'\'') {
current_statement.push(chars.next().unwrap());
} else {
in_single_quote = false;
}
} else {
in_single_quote = true;
}
}
'$' if !in_single_quote => {
if in_dollar_quote {
let mut temp_tag = String::new();
let chars_ahead: Vec<char> = chars.clone().collect();
let mut i = 0;
while i < chars_ahead.len()
&& (chars_ahead[i].is_alphanumeric() || chars_ahead[i] == '_')
{
temp_tag.push(chars_ahead[i]);
i += 1;
}
if i < chars_ahead.len() && chars_ahead[i] == '$' && temp_tag == dollar_tag {
for _ in 0..=i {
if let Some(c) = chars.next() {
current_statement.push(c);
}
}
in_dollar_quote = false;
dollar_tag.clear();
}
} else {
let chars_ahead: Vec<char> = chars.clone().collect();
let mut i = 0;
let mut temp_tag = String::new();
while i < chars_ahead.len()
&& (chars_ahead[i].is_alphanumeric() || chars_ahead[i] == '_')
{
temp_tag.push(chars_ahead[i]);
i += 1;
}
if i < chars_ahead.len() && chars_ahead[i] == '$' {
for _ in 0..=i {
if let Some(c) = chars.next() {
current_statement.push(c);
}
}
in_dollar_quote = true;
dollar_tag = temp_tag;
}
}
}
';' if !in_single_quote && !in_dollar_quote => {
let trimmed = current_statement.trim();
let lines: Vec<&str> = trimmed.lines().collect();
let has_sql_content = lines.iter().any(|line| {
let line_trimmed = line.trim();
!line_trimmed.is_empty() && !line_trimmed.starts_with("--")
});
if has_sql_content {
statements.push(current_statement.clone());
}
current_statement.clear();
}
_ => {
}
}
}
let trimmed = current_statement.trim();
if !trimmed.is_empty() {
let lines: Vec<&str> = trimmed.lines().collect();
let has_sql_content = lines.iter().any(|line| {
let line_trimmed = line.trim();
!line_trimmed.is_empty() && !line_trimmed.starts_with("--")
});
if has_sql_content {
statements.push(current_statement);
}
}
statements
}
/// Executes and tracks hammerwork migrations against a PostgreSQL database.
pub struct PostgresMigrationRunner {
    // Shared connection pool used for every migration query.
    pool: PgPool,
}
impl PostgresMigrationRunner {
    /// Creates a runner backed by the given connection pool.
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }
}
#[async_trait::async_trait]
impl MigrationRunner<sqlx::Postgres> for PostgresMigrationRunner {
    /// Splits the migration's SQL into statements and executes them one by
    /// one inside a single transaction; the first failing statement returns
    /// its error (the transaction rolls back when `tx` is dropped uncommitted).
    async fn run_migration(&self, migration: &Migration, sql: &str) -> Result<()> {
        debug!("Executing PostgreSQL migration: {}", migration.id);
        let mut tx = self.pool.begin().await?;
        let statements = match parse_sql_statements(sql) {
            Ok(stmts) => {
                debug!(
                    "sqlparser-rs successfully parsed {} statements",
                    stmts.len()
                );
                for (i, stmt) in stmts.iter().enumerate() {
                    debug!("Parsed statement {}: '{}'", i + 1, stmt.trim());
                }
                stmts
            }
            Err(e) => {
                // NOTE(review): parse_sql_statements currently always returns
                // Ok (on parse failure it falls back to
                // split_sql_respecting_quotes internally), so this naive
                // semicolon-splitting branch is defensive dead code.
                warn!(
                    "Failed to parse SQL with sqlparser, falling back to naive splitting: {}",
                    e
                );
                sql.split(';')
                    .map(|s| s.trim().to_string())
                    .filter(|s| !s.is_empty() && !s.chars().all(|c| c.is_whitespace() || c == '\n'))
                    .collect()
            }
        };
        debug!(
            "Parsed {} statements for migration {}",
            statements.len(),
            migration.id
        );
        for (i, statement) in statements.iter().enumerate() {
            debug!("Statement {}: '{}'", i + 1, statement.trim());
        }
        for (i, statement) in statements.iter().enumerate() {
            // Ensure each statement is terminated before execution.
            let full_statement = if statement.ends_with(';') {
                statement.to_string()
            } else {
                format!("{};", statement)
            };
            debug!(
                "Executing statement {} of {} for migration {}: {}",
                i + 1,
                statements.len(),
                migration.id,
                full_statement
            );
            if let Err(e) = sqlx::query(&full_statement).execute(&mut *tx).await {
                error!(
                    "Failed to execute statement {}: {} - Error: {}",
                    i + 1,
                    full_statement,
                    e
                );
                return Err(e.into());
            }
        }
        tx.commit().await?;
        info!(
            "Successfully executed PostgreSQL migration: {} ({} statements)",
            migration.id,
            statements.len()
        );
        Ok(())
    }
    /// Returns whether the `hammerwork_migrations` tracking table exists in
    /// the `public` schema.
    async fn migration_table_exists(&self) -> Result<bool> {
        let row = sqlx::query(
            "SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = 'hammerwork_migrations'
)",
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(row.get::<bool, _>(0))
    }
    /// Creates the migration tracking table if it does not already exist
    /// (idempotent via IF NOT EXISTS).
    async fn create_migration_table(&self) -> Result<()> {
        sqlx::query(
            r#"
CREATE TABLE IF NOT EXISTS hammerwork_migrations (
migration_id VARCHAR NOT NULL PRIMARY KEY,
executed_at TIMESTAMPTZ NOT NULL,
execution_time_ms BIGINT NOT NULL
)
"#,
        )
        .execute(&self.pool)
        .await?;
        info!("Created PostgreSQL migration tracking table");
        Ok(())
    }
    /// Fetches all previously executed migrations, ordered by execution time.
    async fn get_executed_migrations(&self) -> Result<Vec<MigrationRecord>> {
        let rows = sqlx::query(
            "SELECT migration_id, executed_at, execution_time_ms
FROM hammerwork_migrations
ORDER BY executed_at",
        )
        .fetch_all(&self.pool)
        .await?;
        let mut records = Vec::new();
        for row in rows {
            records.push(MigrationRecord {
                migration_id: row.get("migration_id"),
                executed_at: row.get("executed_at"),
                // Stored as BIGINT (i64); non-negative in practice, widened
                // back to u64 for the in-memory record.
                execution_time_ms: row.get::<i64, _>("execution_time_ms") as u64,
            });
        }
        Ok(records)
    }
    /// Records a successfully executed migration with the current timestamp
    /// and its execution duration in milliseconds.
    async fn record_migration(&self, migration: &Migration, execution_time_ms: u64) -> Result<()> {
        sqlx::query(
            "INSERT INTO hammerwork_migrations (migration_id, executed_at, execution_time_ms)
VALUES ($1, $2, $3)",
        )
        .bind(&migration.id)
        .bind(Utc::now())
        // BIGINT column; cast mirrors the read path in get_executed_migrations.
        .bind(execution_time_ms as i64)
        .execute(&self.pool)
        .await?;
        debug!("Recorded PostgreSQL migration: {}", migration.id);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    // Baseline: naive ";\n" splitting is sufficient for simple DDL scripts
    // with single-line comments (documents the behavior the quote-aware
    // splitter improves upon).
    #[test]
    fn test_sql_statement_splitting() {
        let multi_statement_sql = r#"
-- Comment line
CREATE TABLE test_table (
id INTEGER PRIMARY KEY
);
-- Another comment
ALTER TABLE test_table ADD COLUMN name VARCHAR(50);
CREATE INDEX idx_test ON test_table (name);
"#;
        let statements: Vec<&str> = multi_statement_sql
            .split(";\n")
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .collect();
        assert_eq!(statements.len(), 3);
        assert!(statements[0].contains("CREATE TABLE"));
        assert!(statements[1].contains("ALTER TABLE"));
        assert!(statements[2].contains("CREATE INDEX"));
    }
    // The manual splitter must not split on semicolons inside a $$-quoted
    // plpgsql function body.
    #[test]
    fn test_dollar_quoted_string_parsing() {
        let sql_with_function = r#"
CREATE OR REPLACE FUNCTION update_timestamp()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER test_trigger BEFORE UPDATE ON test_table FOR EACH ROW EXECUTE FUNCTION update_timestamp();
"#;
        let statements = super::split_sql_respecting_quotes(sql_with_function);
        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("RETURNS TRIGGER AS $$"));
        assert!(statements[0].contains("$$ LANGUAGE plpgsql"));
        assert!(statements[1].contains("CREATE TRIGGER"));
    }
    // sqlparser-rs itself must accept dollar-quoted bodies and keep the
    // LANGUAGE clause when re-rendering the statement.
    #[test]
    fn test_sqlparser_integration() {
        let sql_with_function = r#"
CREATE OR REPLACE FUNCTION update_hammerwork_queue_pause_updated_at()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        let result = super::parse_sql_statements(sql_with_function);
        assert!(
            result.is_ok(),
            "sqlparser should handle dollar-quoted strings"
        );
        let statements = result.unwrap();
        assert_eq!(statements.len(), 1);
        assert!(
            statements[0].contains("LANGUAGE plpgsql"),
            "Statement should contain LANGUAGE plpgsql: {}",
            statements[0]
        );
    }
    // Regression test for real migration 014 (queue pause): both the manual
    // splitter and sqlparser must yield 5 statements, and the rendered
    // CREATE FUNCTION must keep its empty "()" argument list.
    #[test]
    fn test_migration_014_sql_parsing() {
        let migration_014_sql = r#"-- Add queue pause functionality
-- Migration 014: Add queue pause state tracking
-- Create table for tracking queue pause states
CREATE TABLE IF NOT EXISTS hammerwork_queue_pause (
queue_name VARCHAR(255) PRIMARY KEY,
paused_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
paused_by VARCHAR(255),
reason TEXT,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
-- Create index for faster lookups
CREATE INDEX IF NOT EXISTS idx_hammerwork_queue_pause_paused_at ON hammerwork_queue_pause(paused_at);
-- Add function to automatically update the updated_at timestamp
CREATE OR REPLACE FUNCTION update_hammerwork_queue_pause_updated_at()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Create trigger to automatically update updated_at
DROP TRIGGER IF EXISTS trigger_update_hammerwork_queue_pause_updated_at ON hammerwork_queue_pause;
CREATE TRIGGER trigger_update_hammerwork_queue_pause_updated_at
BEFORE UPDATE ON hammerwork_queue_pause
FOR EACH ROW
EXECUTE FUNCTION update_hammerwork_queue_pause_updated_at();"#;
        let statements = super::split_sql_respecting_quotes(migration_014_sql);
        assert_eq!(
            statements.len(),
            5,
            "Should parse 5 statements from migration 014"
        );
        let function_statement = &statements[2];
        assert!(function_statement.contains("CREATE OR REPLACE FUNCTION"));
        assert!(function_statement.contains("RETURNS TRIGGER AS $$"));
        assert!(function_statement.contains("$$ LANGUAGE plpgsql"));
        assert!(function_statement.contains("NEW.updated_at = NOW()"));
        let result = super::parse_sql_statements(migration_014_sql);
        assert!(
            result.is_ok(),
            "Migration 014 SQL should parse successfully with sqlparser-rs: {:?}",
            result
        );
        let parsed_statements = result.unwrap();
        assert_eq!(
            parsed_statements.len(),
            5,
            "Should parse 5 statements from migration 014 with sqlparser"
        );
        let create_function_stmt = parsed_statements
            .iter()
            .find(|stmt| stmt.contains("CREATE OR REPLACE FUNCTION"))
            .expect("Should find CREATE FUNCTION statement");
        // Exercises the "()" restoration fixup in parse_sql_statements.
        assert!(
            create_function_stmt.contains("update_hammerwork_queue_pause_updated_at()"),
            "Function name must include parentheses even with no parameters. Statement: {}",
            create_function_stmt
        );
        assert!(
            create_function_stmt.contains("RETURNS TRIGGER"),
            "Statement should contain RETURNS TRIGGER"
        );
        assert!(
            create_function_stmt.contains("LANGUAGE plpgsql"),
            "Statement should contain LANGUAGE plpgsql"
        );
    }
    // Regression test for real migration 012 (JSONB -> UUID[] dependency
    // arrays): a long transactional script with casts, regexes inside string
    // literals, comments, and 22 top-level statements.
    #[test]
    fn test_migration_012_sql_parsing() {
        let migration_012_sql = r#"-- Migration 012: Optimize job dependencies using native PostgreSQL arrays
-- Converts JSONB dependency arrays to native UUID[] arrays for better performance
-- This migration is wrapped in a transaction for safety
BEGIN;
-- Step 1: Add new UUID array columns
ALTER TABLE hammerwork_jobs
ADD COLUMN IF NOT EXISTS depends_on_array UUID[] DEFAULT '{}';
ALTER TABLE hammerwork_jobs
ADD COLUMN IF NOT EXISTS dependents_array UUID[] DEFAULT '{}';
-- Step 2: Migrate existing JSONB data to UUID arrays with validation
-- Handle depends_on column with UUID validation
UPDATE hammerwork_jobs
SET depends_on_array = CASE
WHEN depends_on IS NULL OR depends_on = 'null'::jsonb OR depends_on = '[]'::jsonb THEN '{}'::UUID[]
WHEN jsonb_typeof(depends_on) = 'array' THEN
ARRAY(
SELECT elem::UUID
FROM jsonb_array_elements_text(depends_on) AS elem
WHERE elem::text ~ '^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$'
)
ELSE '{}'::UUID[]
END;
-- Handle dependents column with UUID validation
UPDATE hammerwork_jobs
SET dependents_array = CASE
WHEN dependents IS NULL OR dependents = 'null'::jsonb OR dependents = '[]'::jsonb THEN '{}'::UUID[]
WHEN jsonb_typeof(dependents) = 'array' THEN
ARRAY(
SELECT elem::UUID
FROM jsonb_array_elements_text(dependents) AS elem
WHERE elem::text ~ '^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$'
)
ELSE '{}'::UUID[]
END;
-- Step 3: Verify data migration integrity (simplified for migration runner compatibility)
-- Note: Since the migration runner splits on semicolons, we skip complex validation
-- The column constraints and indexes below will catch any issues
-- Step 4: Create indexes on new array columns (before dropping old ones)
CREATE INDEX IF NOT EXISTS idx_hammerwork_jobs_depends_on_array
ON hammerwork_jobs USING GIN (depends_on_array);
CREATE INDEX IF NOT EXISTS idx_hammerwork_jobs_dependents_array
ON hammerwork_jobs USING GIN (dependents_array);
-- Step 5: Drop old JSONB indexes (will be recreated after column rename)
DROP INDEX IF EXISTS idx_hammerwork_jobs_depends_on;
DROP INDEX IF EXISTS idx_hammerwork_jobs_dependents;
-- Step 6: Drop old JSONB columns and rename array columns
ALTER TABLE hammerwork_jobs DROP COLUMN IF EXISTS depends_on;
ALTER TABLE hammerwork_jobs DROP COLUMN IF EXISTS dependents;
ALTER TABLE hammerwork_jobs RENAME COLUMN depends_on_array TO depends_on;
ALTER TABLE hammerwork_jobs RENAME COLUMN dependents_array TO dependents;
-- Step 7: Recreate indexes with original names
DROP INDEX IF EXISTS idx_hammerwork_jobs_depends_on_array;
DROP INDEX IF EXISTS idx_hammerwork_jobs_dependents_array;
CREATE INDEX IF NOT EXISTS idx_hammerwork_jobs_depends_on
ON hammerwork_jobs USING GIN (depends_on);
CREATE INDEX IF NOT EXISTS idx_hammerwork_jobs_dependents
ON hammerwork_jobs USING GIN (dependents);
-- Step 8: Update comments to reflect new column types
COMMENT ON COLUMN hammerwork_jobs.depends_on IS 'Array of job IDs this job depends on (native UUID array)';
COMMENT ON COLUMN hammerwork_jobs.dependents IS 'Cached array of job IDs that depend on this job (native UUID array)';
-- Step 9: Add constraint to ensure reasonable array sizes (prevent abuse)
ALTER TABLE hammerwork_jobs
ADD CONSTRAINT chk_depends_on_size
CHECK (array_length(depends_on, 1) IS NULL OR array_length(depends_on, 1) <= 1000);
ALTER TABLE hammerwork_jobs
ADD CONSTRAINT chk_dependents_size
CHECK (array_length(dependents, 1) IS NULL OR array_length(dependents, 1) <= 10000);
COMMIT;"#;
        let statements = super::split_sql_respecting_quotes(migration_012_sql);
        assert_eq!(
            statements.len(),
            22,
            "Should parse 22 statements from migration 012"
        );
        assert!(
            statements[0].contains("BEGIN"),
            "First statement should be BEGIN"
        );
        assert!(
            statements[21].contains("COMMIT"),
            "Last statement should be COMMIT"
        );
        let depends_on_update = statements.iter().find(|stmt| {
            stmt.contains("SET depends_on_array = CASE")
                && stmt.contains("jsonb_array_elements_text(depends_on)")
        });
        assert!(
            depends_on_update.is_some(),
            "Should find depends_on UPDATE statement"
        );
        let dependents_update = statements.iter().find(|stmt| {
            stmt.contains("SET dependents_array = CASE")
                && stmt.contains("jsonb_array_elements_text(dependents)")
        });
        assert!(
            dependents_update.is_some(),
            "Should find dependents UPDATE statement"
        );
        let comment_statements: Vec<_> = statements
            .iter()
            .filter(|stmt| stmt.contains("COMMENT ON COLUMN"))
            .collect();
        assert_eq!(
            comment_statements.len(),
            2,
            "Should have 2 COMMENT statements"
        );
        let constraint_statements: Vec<_> = statements
            .iter()
            .filter(|stmt| {
                stmt.contains("ADD CONSTRAINT")
                    && (stmt.contains("chk_depends_on_size")
                        || stmt.contains("chk_dependents_size"))
            })
            .collect();
        assert_eq!(
            constraint_statements.len(),
            2,
            "Should have 2 constraint statements"
        );
        let result = super::parse_sql_statements(migration_012_sql);
        assert!(
            result.is_ok(),
            "Migration 012 SQL should parse successfully with sqlparser-rs: {:?}",
            result
        );
    }
}