use clap::{Parser, Subcommand};
use sha2::{Digest, Sha256};
use sqlx::postgres::PgPoolOptions;
use sqlx::{Executor, Row};
use std::fs;
use std::path::Path;
// Top-level CLI definition, parsed by clap's derive macro.
// NOTE: plain `//` comments are used deliberately — `///` doc comments on
// clap-derived items would be picked up as help text and change CLI output.
#[derive(Parser)]
#[command(name = "pgsql-migrate")]
#[command(about = "A simple PostgreSQL migration tool", long_about = None)]
struct Cli {
    // The chosen subcommand; dispatched in `main`.
    #[command(subcommand)]
    command: Commands,
}
// All subcommands. Every database-touching command takes a connection URL
// (`-d/--database`, required) and a migrations directory (`-p/--path`,
// default "migrations"). `//` comments only: `///` would become clap help text.
#[derive(Subcommand)]
enum Commands {
    // Apply all pending up migrations.
    #[command(name = "up")]
    Up {
        #[arg(short = 'p', long = "path", default_value = "migrations")]
        path: String,
        #[arg(short = 'd', long = "database")]
        database: String,
        // Environment name matched against `-- skip-on-env` block markers.
        #[arg(short = 'e', long = "env", default_value = "prod")]
        env: String,
    },
    // Roll back the most recently applied migrations.
    #[command(name = "down")]
    Down {
        #[arg(short = 'p', long = "path", default_value = "migrations")]
        path: String,
        #[arg(short = 'd', long = "database")]
        database: String,
        #[arg(short = 'e', long = "env", default_value = "prod")]
        env: String,
        // Positional: how many migrations to roll back (defaults to 1).
        #[arg(default_value = "1")]
        count: u32,
    },
    // Create a new numbered pair of up/down migration files.
    #[command(name = "create")]
    Create {
        #[arg(short = 'd', long = "dir", default_value = "migrations")]
        dir: String,
        // NOTE(review): this flag is exposed as `-s`/`--seq` although it holds
        // the migration *name* — looks like a copy-paste slip, but renaming the
        // flag would break existing invocations, so it is left as-is.
        #[arg(short = 's', long = "seq")]
        name: String,
    },
    // Record migrations up to a version as applied without executing them.
    #[command(name = "baseline")]
    Baseline {
        #[arg(short = 'p', long = "path", default_value = "migrations")]
        path: String,
        #[arg(short = 'd', long = "database")]
        database: String,
        #[arg(short = 'v', long = "version")]
        version: u32,
    },
    // Delete the highest-versioned dirty row and re-run `up`.
    #[command(name = "redo")]
    Redo {
        #[arg(short = 'p', long = "path", default_value = "migrations")]
        path: String,
        #[arg(short = 'd', long = "database")]
        database: String,
        #[arg(short = 'e', long = "env", default_value = "prod")]
        env: String,
    },
}
/// Per-file execution options declared in a migration's leading
/// `-- features:` comment line.
#[derive(Debug, Clone, PartialEq, Eq)]
enum MigrationFeature {
    /// Run the script outside a transaction.
    NoTransaction,
    /// Execute `-- split-start` / `-- split-end` delimited blocks one by one.
    SplitStatements,
}

impl MigrationFeature {
    /// Parses one comma-separated feature token (case-insensitive,
    /// surrounding whitespace ignored). Unknown tokens yield `None`.
    fn from_str(s: &str) -> Option<Self> {
        let token = s.trim().to_lowercase();
        if token == "no-tx" {
            Some(Self::NoTransaction)
        } else if token == "split-statements" {
            Some(Self::SplitStatements)
        } else {
            None
        }
    }
}
/// A migration script plus the execution features parsed from its
/// leading comment header.
#[derive(Debug, Clone)]
struct MigrationSpec {
    content: String,
    features: Vec<MigrationFeature>,
}

impl MigrationSpec {
    /// Wraps `content`, extracting any declared features from its header.
    fn new(content: String) -> Self {
        let features = Self::parse_features(&content);
        Self { content, features }
    }

    /// A spec with no SQL and no features (used when a down file is absent).
    fn empty() -> Self {
        Self {
            content: String::new(),
            features: Vec::new(),
        }
    }

    /// True when the script declared the `no-tx` feature.
    fn has_no_tx(&self) -> bool {
        self.features.contains(&MigrationFeature::NoTransaction)
    }

    /// True when the script declared the `split-statements` feature.
    fn has_split_statements(&self) -> bool {
        self.features.contains(&MigrationFeature::SplitStatements)
    }

    /// True when there is no script content at all.
    fn is_empty(&self) -> bool {
        self.content.is_empty()
    }

    /// Scans the leading comment header for a `-- features: a, b` line.
    /// Scanning stops at the first non-empty, non-comment line, so a
    /// features line buried after SQL is deliberately ignored.
    fn parse_features(content: &str) -> Vec<MigrationFeature> {
        content
            .lines()
            .map(str::trim)
            .take_while(|line| line.is_empty() || line.starts_with("--"))
            .find_map(|line| line.strip_prefix("-- features:"))
            .map(|list| {
                list.split(',')
                    .filter_map(MigrationFeature::from_str)
                    .collect()
            })
            .unwrap_or_default()
    }
}
/// One migration version paired with its up and down scripts.
struct Migration {
    /// Numeric version parsed from the `NNNNNN_` filename prefix.
    version: u32,
    /// The up file's name, used in log output.
    filename: String,
    /// Contents of `NNNNNN_name.up.sql`.
    up: MigrationSpec,
    /// Contents of `NNNNNN_name.down.sql`; empty spec when the file is absent.
    down: MigrationSpec,
}
/// Entry point: parse the command line and dispatch to the matching handler.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    match Cli::parse().command {
        Commands::Up { path, database, env } => run_up(&path, &database, &env).await?,
        Commands::Down {
            path,
            database,
            env,
            count,
        } => run_down(&path, &database, &env, count).await?,
        Commands::Create { dir, name } => create_migration(&dir, &name)?,
        Commands::Baseline {
            path,
            database,
            version,
        } => run_baseline(&path, &database, version).await?,
        Commands::Redo { path, database, env } => run_redo(&path, &database, &env).await?,
    }
    Ok(())
}
/// Normalizes a human-supplied migration name into a filename-safe slug:
/// lowercased, spaces and hyphens become underscores, everything that is
/// not alphanumeric or `_` is dropped.
///
/// Fix: hyphens were previously removed outright (they are not alphanumeric),
/// silently merging words — "add-users" became "addusers". They now map to
/// `_` like spaces. Names without hyphens normalize exactly as before.
fn normalize_name(name: &str) -> String {
    name.to_lowercase()
        .chars()
        .map(|c| if c == ' ' || c == '-' { '_' } else { c })
        .filter(|c| c.is_alphanumeric() || *c == '_')
        .collect()
}
/// Returns the next migration version for `dir`: one past the highest
/// numeric `NNNNNN_` filename prefix found, or 1 for a missing/empty
/// directory. Files whose prefix does not parse as `u32` are ignored.
fn get_next_version(dir: &Path) -> Result<u32, Box<dyn std::error::Error>> {
    if !dir.exists() {
        // No directory yet: numbering starts at 1.
        return Ok(1);
    }
    let mut max_version = 0u32;
    for entry in fs::read_dir(dir)? {
        let file_name = entry?.file_name();
        let parsed = file_name
            .to_string_lossy()
            .split('_')
            .next()
            .and_then(|prefix| prefix.parse::<u32>().ok());
        if let Some(version) = parsed {
            max_version = max_version.max(version);
        }
    }
    Ok(max_version + 1)
}
/// Creates a numbered pair of empty up/down migration files in `dir`,
/// creating the directory itself first if needed.
fn create_migration(dir: &str, name: &str) -> Result<(), Box<dyn std::error::Error>> {
    let dir_path = Path::new(dir);
    if !dir_path.exists() {
        fs::create_dir_all(dir_path)?;
        println!("Created migrations directory: {}", dir);
    }
    let version = get_next_version(dir_path)?;
    let slug = normalize_name(name);
    // Zero-padded 6-digit prefix keeps lexicographic order == numeric order.
    let up_path = dir_path.join(format!("{:06}_{}.up.sql", version, slug));
    let down_path = dir_path.join(format!("{:06}_{}.down.sql", version, slug));
    fs::write(&up_path, "-- Add migration script here\n")?;
    fs::write(&down_path, "-- Add rollback script here\n")?;
    println!("Created migration files:");
    println!(" {}", up_path.display());
    println!(" {}", down_path.display());
    Ok(())
}
/// SHA-256 digest of `content`, rendered as a lowercase hex string.
/// Used to detect migration files edited after being applied.
fn compute_hash(content: &str) -> String {
    // One-shot digest instead of new/update/finalize; identical output.
    format!("{:x}", Sha256::digest(content.as_bytes()))
}
/// One executable block extracted from a `split-statements` migration,
/// plus the environment names for which it should be skipped.
#[derive(Debug, Clone)]
struct SqlBlock {
    content: String,
    skip_on_env: Vec<String>,
}

impl SqlBlock {
    /// True when `current_env` matches any `skip-on-env` entry
    /// (ASCII case-insensitive comparison).
    fn should_skip(&self, current_env: &str) -> bool {
        self.skip_on_env
            .iter()
            .any(|e| e.eq_ignore_ascii_case(current_env))
    }
}

/// Splits `content` into blocks delimited by `-- split-start` /
/// `-- split-end` lines. Text outside any block is ignored; empty blocks
/// are dropped.
///
/// Inside a block, a `-- skip-on-env env1, env2` line tags the block to be
/// skipped in the listed environments. Fix: an optional colon after the
/// marker (`-- skip-on-env: prod`) is now accepted too, mirroring the
/// `-- features:` header syntax — previously the colon was kept in the env
/// name (": prod"), which could never match an environment.
///
/// Returns `Err` on unbalanced markers or when no block is found at all.
fn split_sql_by_markers(content: &str) -> Result<Vec<SqlBlock>, String> {
    let mut blocks = Vec::new();
    let mut current_block = String::new();
    let mut current_skip_envs: Vec<String> = Vec::new();
    let mut in_block = false;
    let mut block_start_line = 0;
    for (line_num, line) in content.lines().enumerate() {
        let trimmed = line.trim();
        let line_number = line_num + 1; // 1-based for error messages
        if trimmed == "-- split-start" {
            if in_block {
                return Err(format!(
                    "Line {}: Found '-- split-start' but previous block starting at line {} was not closed with '-- split-end'",
                    line_number, block_start_line
                ));
            }
            in_block = true;
            block_start_line = line_number;
            current_block.clear();
            current_skip_envs.clear();
            continue;
        }
        if trimmed == "-- split-end" {
            if !in_block {
                return Err(format!(
                    "Line {}: Found '-- split-end' without a matching '-- split-start'",
                    line_number
                ));
            }
            let block_content = current_block.trim().to_string();
            if !block_content.is_empty() {
                blocks.push(SqlBlock {
                    content: block_content,
                    skip_on_env: current_skip_envs.clone(),
                });
            }
            in_block = false;
            current_block.clear();
            current_skip_envs.clear();
            continue;
        }
        if in_block {
            if let Some(rest) = trimmed.strip_prefix("-- skip-on-env") {
                // Accept both `-- skip-on-env prod` and `-- skip-on-env: prod`.
                // A later marker replaces any earlier one in the same block.
                current_skip_envs = rest
                    .trim_start()
                    .trim_start_matches(':')
                    .split(',')
                    .map(|e| e.trim().to_lowercase())
                    .filter(|e| !e.is_empty())
                    .collect();
                continue;
            }
            if !current_block.is_empty() {
                current_block.push('\n');
            }
            current_block.push_str(line);
        }
    }
    if in_block {
        return Err(format!(
            "Block starting at line {} was not closed with '-- split-end'",
            block_start_line
        ));
    }
    if blocks.is_empty() {
        return Err(
            "split-statements feature requires at least one block delimited by '-- split-start' and '-- split-end'".to_string()
        );
    }
    Ok(blocks)
}
/// Creates the bookkeeping table for applied migrations if it does not yet
/// exist. `dirty` marks a migration that started but did not finish;
/// `content_hash` lets later runs detect migration files edited after apply.
async fn ensure_schema_migrations_table(
    pool: &sqlx::PgPool,
) -> Result<(), Box<dyn std::error::Error>> {
    pool.execute(
        r#"
CREATE TABLE IF NOT EXISTS pgsql_migrate_schema_migrations (
version BIGINT PRIMARY KEY,
dirty BOOLEAN NOT NULL DEFAULT FALSE,
content_hash VARCHAR(64),
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"#,
    )
    .await?;
    Ok(())
}
/// Fetches all recorded migrations as `(version, dirty, content_hash)`
/// tuples, ordered ascending by version. `content_hash` may be NULL for
/// rows written before hashing existed, hence `Option<String>`.
async fn get_applied_migrations(
    pool: &sqlx::PgPool,
) -> Result<Vec<(i64, bool, Option<String>)>, Box<dyn std::error::Error>> {
    let rows = sqlx::query(
        "SELECT version, dirty, content_hash FROM pgsql_migrate_schema_migrations ORDER BY version",
    )
    .fetch_all(pool)
    .await?;
    Ok(rows
        .into_iter()
        .map(|row| (row.get("version"), row.get("dirty"), row.get("content_hash")))
        .collect())
}
/// Errors out if any recorded migration is marked dirty (i.e. a previous
/// run failed mid-apply). The lowest dirty version is reported, since
/// `get_applied_migrations` returns rows in ascending order.
async fn check_dirty_migrations(pool: &sqlx::PgPool) -> Result<(), Box<dyn std::error::Error>> {
    let applied = get_applied_migrations(pool).await?;
    if let Some((version, _, _)) = applied.iter().find(|(_, dirty, _)| *dirty) {
        return Err(format!(
            "Migration {} is dirty. Please fix it manually and update the pgsql_migrate_schema_migrations table.",
            version
        )
        .into());
    }
    Ok(())
}
/// Returns the highest applied migration version, or `None` when the
/// bookkeeping table is empty (`MAX` yields SQL NULL).
async fn get_current_version(
    pool: &sqlx::PgPool,
) -> Result<Option<i64>, Box<dyn std::error::Error>> {
    let row = sqlx::query("SELECT MAX(version) as max_version FROM pgsql_migrate_schema_migrations")
        .fetch_one(pool)
        .await?;
    Ok(row.get("max_version"))
}
/// Prints the current schema version to stdout (or a placeholder when no
/// migrations have been applied yet).
async fn print_current_version(pool: &sqlx::PgPool) -> Result<(), Box<dyn std::error::Error>> {
    let message = match get_current_version(pool).await? {
        Some(version) => format!("Current version: {}", version),
        None => "Current version: None (no migrations applied)".to_string(),
    };
    println!("{}", message);
    Ok(())
}
/// Reads `dir` and pairs every `NNNNNN_*.up.sql` with its matching
/// `NNNNNN_*.down.sql`, returning migrations sorted ascending by version.
///
/// Files whose numeric prefix does not parse are ignored; a missing down
/// file yields an empty down spec (rollback then warns instead of running).
fn parse_migrations(dir: &Path) -> Result<Vec<Migration>, Box<dyn std::error::Error>> {
    if !dir.exists() {
        return Err(format!("Migrations directory '{}' does not exist", dir.display()).into());
    }
    // version -> (filename, content) for up files; version -> content for down.
    let mut up_files: std::collections::HashMap<u32, (String, String)> =
        std::collections::HashMap::new();
    let mut down_files: std::collections::HashMap<u32, String> = std::collections::HashMap::new();
    for entry in fs::read_dir(dir)? {
        let entry = entry?;
        let name = entry.file_name().to_string_lossy().to_string();
        let parsed_version = name.split('_').next().and_then(|v| v.parse::<u32>().ok());
        if let Some(version) = parsed_version {
            if name.ends_with(".up.sql") {
                let content = fs::read_to_string(entry.path())?;
                up_files.insert(version, (name.clone(), content));
            } else if name.ends_with(".down.sql") {
                down_files.insert(version, fs::read_to_string(entry.path())?);
            }
        }
    }
    let mut migrations: Vec<Migration> = up_files
        .into_iter()
        .map(|(version, (filename, up_content))| {
            let down_content = down_files.remove(&version).unwrap_or_default();
            Migration {
                version,
                filename,
                up: MigrationSpec::new(up_content),
                down: if down_content.is_empty() {
                    MigrationSpec::empty()
                } else {
                    MigrationSpec::new(down_content)
                },
            }
        })
        .collect();
    migrations.sort_by_key(|m| m.version);
    Ok(migrations)
}
/// Applies all pending up migrations from `path` against `database`.
///
/// Protocol per migration: insert its row with `dirty = TRUE`, execute the
/// script, then flip `dirty` to `FALSE`. A failure therefore leaves the row
/// dirty, which blocks subsequent runs until the operator resolves it.
///
/// Already-applied migrations are skipped; if their on-disk content no
/// longer matches the stored hash, a warning is printed but the run goes on.
///
/// Per-file features (from a leading `-- features:` comment):
/// - `no-tx`: run the script outside a transaction;
/// - `split-statements`: execute `-- split-start`/`-- split-end` blocks
///   individually, honoring `-- skip-on-env` tags against `env`.
async fn run_up(path: &str, database: &str, env: &str) -> Result<(), Box<dyn std::error::Error>> {
    println!("Running migrations in environment: {}", env);
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect(database)
        .await?;
    ensure_schema_migrations_table(&pool).await?;
    // Refuse to run on top of a half-applied migration.
    check_dirty_migrations(&pool).await?;
    let applied = get_applied_migrations(&pool).await?;
    // version -> stored content hash (may be NULL for pre-hash rows).
    let applied_map: std::collections::HashMap<i64, Option<String>> = applied
        .iter()
        .map(|(v, _, hash)| (*v, hash.clone()))
        .collect();
    let migrations = parse_migrations(Path::new(path))?;
    let mut applied_count = 0;
    for migration in migrations {
        let version_i64 = migration.version as i64;
        let current_hash = compute_hash(&migration.up.content);
        if let Some(stored_hash) = applied_map.get(&version_i64) {
            // Already applied: warn (but do not fail) on content drift.
            if let Some(ref hash) = stored_hash {
                // Fix: this comparison was corrupted to `¤t_hash` (a mangled
                // `&current_hash` — the `&curren` HTML entity), which is not
                // valid Rust; same for the `.bind()` below.
                if hash != &current_hash {
                    eprintln!(
                        " WARNING: Migration {} content has changed since it was applied!",
                        migration.filename
                    );
                    eprintln!(" Stored hash: {}", hash);
                    eprintln!(" Current hash: {}", current_hash);
                }
            }
            continue;
        }
        println!("Applying migration: {}", migration.filename);
        // Record the row as dirty BEFORE executing, so a crash mid-migration
        // is visible on the next run.
        sqlx::query("INSERT INTO pgsql_migrate_schema_migrations (version, dirty, content_hash) VALUES ($1, TRUE, $2)")
            .bind(version_i64)
            .bind(&current_hash)
            .execute(&pool)
            .await?;
        let use_transaction = !migration.up.has_no_tx();
        let use_split = migration.up.has_split_statements();
        if !use_transaction {
            println!(" (running without transaction due to no-tx feature)");
        }
        if use_split {
            println!(" (splitting statements by markers due to split-statements feature)");
        }
        let result: Result<(), Box<dyn std::error::Error>> = if use_split {
            match split_sql_by_markers(&migration.up.content) {
                Ok(blocks) => {
                    let mut exec_result: Result<(), Box<dyn std::error::Error>> = Ok(());
                    for (i, block) in blocks.iter().enumerate() {
                        if block.should_skip(env) {
                            println!(
                                " Skipping block {} (skip-on-env: {} matches current env: {})",
                                i + 1,
                                block.skip_on_env.join(","),
                                env
                            );
                            continue;
                        }
                        if use_transaction {
                            // Each block runs in its own transaction; a failure
                            // aborts the remaining blocks.
                            let mut tx = pool.begin().await?;
                            match tx.execute(block.content.as_str()).await {
                                Ok(_) => {
                                    tx.commit().await?;
                                }
                                Err(e) => {
                                    eprintln!(" Error in block {}: {}", i + 1, e);
                                    exec_result = Err(e.into());
                                    break;
                                }
                            }
                        } else {
                            match pool.execute(block.content.as_str()).await {
                                Ok(_) => {}
                                Err(e) => {
                                    eprintln!(" Error in block {}: {}", i + 1, e);
                                    exec_result = Err(e.into());
                                    break;
                                }
                            }
                        }
                    }
                    exec_result
                }
                Err(e) => Err(format!("Failed to parse split markers: {}", e).into()),
            }
        } else if use_transaction {
            // Whole script in one transaction.
            let mut tx = pool.begin().await?;
            match tx.execute(migration.up.content.as_str()).await {
                Ok(_) => {
                    tx.commit().await?;
                    Ok(())
                }
                Err(e) => Err(e.into()),
            }
        } else {
            // no-tx: run the script directly on the pool connection.
            pool.execute(migration.up.content.as_str())
                .await
                .map(|_| ())
                .map_err(|e| e.into())
        };
        match result {
            Ok(_) => {
                // Success: clear the dirty flag.
                sqlx::query(
                    "UPDATE pgsql_migrate_schema_migrations SET dirty = FALSE WHERE version = $1",
                )
                .bind(version_i64)
                .execute(&pool)
                .await?;
                println!(" Applied successfully");
                applied_count += 1;
            }
            Err(e) => {
                // Failure: the row stays dirty; abort the whole run.
                eprintln!(" Error applying migration {}: {}", migration.filename, e);
                eprintln!(" Migration {} is now marked as dirty.", migration.version);
                eprintln!(" Please fix the issue and update pgsql_migrate_schema_migrations table manually.");
                return Err(e);
            }
        }
    }
    if applied_count == 0 {
        println!("No new migrations to apply.");
    } else {
        println!("Applied {} migration(s).", applied_count);
    }
    print_current_version(&pool).await?;
    Ok(())
}
/// Rolls back the most recently applied `count` migrations using their
/// `.down.sql` scripts, newest first.
///
/// Mirrors `run_up`'s safety protocol: each row is flagged `dirty = TRUE`
/// before its down script runs and is deleted only on success, so a failed
/// rollback blocks further runs until fixed manually. Down scripts honor the
/// same `no-tx` / `split-statements` features as up scripts.
async fn run_down(
    path: &str,
    database: &str,
    env: &str,
    count: u32,
) -> Result<(), Box<dyn std::error::Error>> {
    println!("Running rollback in environment: {}", env);
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect(database)
        .await?;
    ensure_schema_migrations_table(&pool).await?;
    // Refuse to roll back over a half-applied migration.
    check_dirty_migrations(&pool).await?;
    let applied = get_applied_migrations(&pool).await?;
    if applied.is_empty() {
        println!("No migrations to rollback.");
        return Ok(());
    }
    let migrations = parse_migrations(Path::new(path))?;
    let migration_map: std::collections::HashMap<u32, Migration> =
        migrations.into_iter().map(|m| (m.version, m)).collect();
    // `applied` is ascending by version; reverse + truncate selects the
    // newest `count` versions, processed newest-first.
    let mut versions_to_rollback: Vec<i64> = applied.iter().map(|(v, _, _)| *v).collect();
    versions_to_rollback.reverse();
    versions_to_rollback.truncate(count as usize);
    let mut rolled_back_count = 0;
    for version in versions_to_rollback {
        let version_u32 = version as u32;
        if let Some(migration) = migration_map.get(&version_u32) {
            println!("Rolling back migration: {}", migration.filename);
            if migration.down.is_empty() {
                // No .down.sql on disk: leave the row in place and move on.
                eprintln!(" Warning: No down migration found for version {}", version);
                continue;
            }
            // Mark dirty BEFORE executing so a crash mid-rollback is visible.
            sqlx::query(
                "UPDATE pgsql_migrate_schema_migrations SET dirty = TRUE WHERE version = $1",
            )
            .bind(version)
            .execute(&pool)
            .await?;
            let use_transaction = !migration.down.has_no_tx();
            let use_split = migration.down.has_split_statements();
            if !use_transaction {
                println!(" (running without transaction due to no-tx feature)");
            }
            if use_split {
                println!(" (splitting statements by markers due to split-statements feature)");
            }
            // Execute the down script: either as marker-delimited blocks
            // (split-statements) or as one script, with or without a
            // transaction (no-tx).
            let result: Result<(), Box<dyn std::error::Error>> = if use_split {
                match split_sql_by_markers(&migration.down.content) {
                    Ok(blocks) => {
                        let mut exec_result: Result<(), Box<dyn std::error::Error>> = Ok(());
                        for (i, block) in blocks.iter().enumerate() {
                            if block.should_skip(env) {
                                println!(
                                    " Skipping block {} (skip-on-env: {} matches current env: {})",
                                    i + 1,
                                    block.skip_on_env.join(","),
                                    env
                                );
                                continue;
                            }
                            if use_transaction {
                                // Each block gets its own transaction; a failure
                                // aborts the remaining blocks.
                                let mut tx = pool.begin().await?;
                                match tx.execute(block.content.as_str()).await {
                                    Ok(_) => {
                                        tx.commit().await?;
                                    }
                                    Err(e) => {
                                        eprintln!(" Error in block {}: {}", i + 1, e);
                                        exec_result = Err(e.into());
                                        break;
                                    }
                                }
                            } else {
                                match pool.execute(block.content.as_str()).await {
                                    Ok(_) => {}
                                    Err(e) => {
                                        eprintln!(" Error in block {}: {}", i + 1, e);
                                        exec_result = Err(e.into());
                                        break;
                                    }
                                }
                            }
                        }
                        exec_result
                    }
                    Err(e) => Err(format!("Failed to parse split markers: {}", e).into()),
                }
            } else if use_transaction {
                let mut tx = pool.begin().await?;
                match tx.execute(migration.down.content.as_str()).await {
                    Ok(_) => {
                        tx.commit().await?;
                        Ok(())
                    }
                    Err(e) => Err(e.into()),
                }
            } else {
                pool.execute(migration.down.content.as_str())
                    .await
                    .map(|_| ())
                    .map_err(|e| e.into())
            };
            match result {
                Ok(_) => {
                    // Success: forget the version entirely.
                    sqlx::query("DELETE FROM pgsql_migrate_schema_migrations WHERE version = $1")
                        .bind(version)
                        .execute(&pool)
                        .await?;
                    println!(" Rolled back successfully");
                    rolled_back_count += 1;
                }
                Err(e) => {
                    // Failure: the row stays dirty; abort the whole run.
                    eprintln!(
                        " Error rolling back migration {}: {}",
                        migration.filename, e
                    );
                    eprintln!(" Migration {} is now marked as dirty.", version);
                    eprintln!(
                        " Please fix the issue and update pgsql_migrate_schema_migrations table manually."
                    );
                    return Err(e);
                }
            }
        } else {
            // Row exists in the database but no matching file on disk.
            eprintln!("Warning: Migration file not found for version {}", version);
        }
    }
    if rolled_back_count == 0 {
        println!("No migrations rolled back.");
    } else {
        println!("Rolled back {} migration(s).", rolled_back_count);
    }
    print_current_version(&pool).await?;
    Ok(())
}
/// Marks all migrations up to `target_version` as applied without executing
/// them — for adopting the tool on a database whose schema already matches.
/// Content hashes are still recorded so later `up` runs can detect drift.
async fn run_baseline(
    path: &str,
    database: &str,
    target_version: u32,
) -> Result<(), Box<dyn std::error::Error>> {
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect(database)
        .await?;
    ensure_schema_migrations_table(&pool).await?;
    let already_applied: std::collections::HashSet<i64> = get_applied_migrations(&pool)
        .await?
        .into_iter()
        .map(|(v, _, _)| v)
        .collect();
    let migrations = parse_migrations(Path::new(path))?;
    let candidates: Vec<&Migration> = migrations
        .iter()
        .filter(|m| m.version <= target_version)
        .collect();
    if candidates.is_empty() {
        println!("No migrations found up to version {}", target_version);
        return Ok(());
    }
    let mut baselined_count = 0;
    for migration in candidates {
        let version_i64 = migration.version as i64;
        if already_applied.contains(&version_i64) {
            println!("Skipping already applied migration: {}", migration.filename);
            continue;
        }
        let content_hash = compute_hash(&migration.up.content);
        // Inserted directly as clean (dirty = FALSE): nothing was executed.
        sqlx::query(
            "INSERT INTO pgsql_migrate_schema_migrations (version, dirty, content_hash, applied_at) VALUES ($1, FALSE, $2, NOW())",
        )
        .bind(version_i64)
        .bind(&content_hash)
        .execute(&pool)
        .await?;
        println!("Baselined migration: {}", migration.filename);
        baselined_count += 1;
    }
    if baselined_count == 0 {
        println!("No new migrations to baseline.");
    } else {
        println!(
            "Baselined {} migration(s) up to version {}.",
            baselined_count, target_version
        );
    }
    Ok(())
}
/// Re-runs the highest-versioned dirty migration: its bookkeeping row is
/// deleted so `run_up` treats it as pending again and re-applies it (along
/// with any other pending migrations).
async fn run_redo(path: &str, database: &str, env: &str) -> Result<(), Box<dyn std::error::Error>> {
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect(database)
        .await?;
    ensure_schema_migrations_table(&pool).await?;
    let applied = get_applied_migrations(&pool).await?;
    // Highest dirty version, if any.
    let dirty_version = applied
        .iter()
        .filter(|(_, dirty, _)| *dirty)
        .map(|(version, _, _)| *version)
        .max();
    let version = match dirty_version {
        Some(v) => v,
        None => {
            println!("No dirty migrations found.");
            return Ok(());
        }
    };
    println!("Redoing migration version: {}", version);
    sqlx::query("DELETE FROM pgsql_migrate_schema_migrations WHERE version = $1")
        .bind(version)
        .execute(&pool)
        .await?;
    run_up(path, database, env).await?;
    Ok(())
}