use anyhow::Context;
use console::style;
use rok_orm_migrate::{FileSource, MigrationRunner};
use sqlx::postgres::PgPoolOptions;
/// Returns `url` with any password in its userinfo section replaced by `****`,
/// so the value can appear in error messages without leaking credentials.
///
/// Strings that do not have the shape `scheme://user:password@host...` are
/// returned unchanged.
fn redact_db_url(url: &str) -> String {
    let (scheme, rest) = match url.split_once("://") {
        Some(parts) => parts,
        None => return url.to_owned(),
    };
    // The authority (`userinfo@host:port`) ends at the first path, query or
    // fragment delimiter.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    // Userinfo is everything before the last '@' inside the authority.
    let at = match authority.rfind('@') {
        Some(idx) => idx,
        None => return url.to_owned(),
    };
    match authority[..at].split_once(':') {
        Some((user, _password)) => format!("{scheme}://{user}:****{}", &rest[at..]),
        None => url.to_owned(),
    }
}

/// Opens a small Postgres connection pool using `DATABASE_URL`.
///
/// A `.env` file is loaded first if present. Values already set in the
/// environment are read via `std::env::var` after that.
///
/// # Errors
///
/// Returns an error if `DATABASE_URL` is not set or the connection fails. The
/// connection error shows the URL with its password redacted.
async fn connect() -> anyhow::Result<sqlx::PgPool> {
    // A missing `.env` file is not an error: the variable may come from the
    // real environment instead.
    let _ = dotenvy::dotenv();
    let url = std::env::var("DATABASE_URL")
        .context("DATABASE_URL not set — add it to .env or the environment")?;
    PgPoolOptions::new()
        .max_connections(3)
        .connect(&url)
        .await
        .with_context(|| format!("failed to connect to database: {}", redact_db_url(&url)))
}
/// Picks the directory that migration files are read from.
///
/// Uses `database/migrations` when that path exists relative to the current
/// working directory, and falls back to `migrations` otherwise.
fn migration_dir() -> &'static str {
    const NESTED: &str = "database/migrations";
    const FLAT: &str = "migrations";

    let nested_present = std::path::Path::new(NESTED).exists();
    if !nested_present {
        return FLAT;
    }
    NESTED
}
/// Builds a migration runner on `pool` that reads migrations from the
/// directory chosen by `migration_dir()`.
fn runner(pool: sqlx::PgPool) -> MigrationRunner {
    let base = MigrationRunner::new(pool);
    base.source(FileSource::new(migration_dir()))
}
/// Prints `title` inside a bold box-drawing banner, followed by a blank line.
///
/// The top rule, the title row and the bottom rule are all built from the
/// same interior width, so the right-hand border lines up. The box widens if
/// the title (plus one space of padding on each side) does not fit in the
/// default width.
fn print_header(title: &str) {
    // Default interior width (columns between the corner glyphs).
    const MIN_INNER: usize = 38;

    // One space of padding on each side of the title.
    let inner = MIN_INNER.max(title.chars().count() + 2);
    let bar = "═".repeat(inner);

    println!("{}", style(format!("╔{bar}╗")).bold());
    // console's styled objects forward the width spec to the inner value, so
    // padding is applied to the visible title text, not the escape codes.
    println!(
        "{} {:width$} {}",
        style("║").bold(),
        style(title).cyan(),
        style("║").bold(),
        width = inner - 2
    );
    println!("{}", style(format!("╚{bar}╝")).bold());
    println!();
}
/// Runs the `db:migrate` command.
///
/// With `dry_run` set, nothing is applied: the runner's `status()` output is
/// shown instead. Otherwise all migrations are applied through the runner.
///
/// # Errors
///
/// Returns an error if connecting or the runner call fails.
pub async fn migrate(dry_run: bool) -> anyhow::Result<()> {
    let title = if dry_run { "db:migrate — dry run" } else { "db:migrate" };
    print_header(title);
    let pool = connect().await?;

    if !dry_run {
        println!(" {} Applying migrations...", style("~").yellow());
        runner(pool).run().await?;
        println!(" {} All migrations applied.", style("✔").green());
        return Ok(());
    }

    println!(" {} Pending migrations:", style("ℹ").cyan());
    runner(pool).status().await?;
    Ok(())
}
/// Applies every pending migration and reports how many were applied.
///
/// Before running, it prints the migration directory, the number of rows in
/// `_migrations`, and a rough pending estimate (`.sql` files on disk minus
/// applied rows). The estimate is informational only. The runner is always
/// invoked, because comparing counts cannot tell *which* migrations are
/// applied: a deleted file plus a new file keeps the counts equal while the
/// new migration is still pending.
///
/// # Errors
///
/// Returns an error if connecting to the database or running migrations fails.
pub async fn update() -> anyhow::Result<()> {
    // Row count of `_migrations`. A failed query is read as 0 so that a fresh
    // database, where the table may not exist yet, is handled gracefully.
    async fn applied_count(pool: &sqlx::PgPool) -> i64 {
        sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM _migrations")
            .fetch_one(pool)
            .await
            .unwrap_or(0)
    }

    print_header("db:update");
    let pool = connect().await?;
    let dir = migration_dir();
    println!(" {} Migration directory: {}", style("ℹ").cyan(), dir);

    let pre_count = applied_count(&pool).await;
    println!(" {} Applied migrations: {}", style("ℹ").cyan(), pre_count);

    // Best-effort estimate only; an unreadable directory yields 0 here and
    // the runner reports the real problem below.
    let total_files = std::fs::read_dir(dir)
        .map(|entries| {
            entries
                .filter_map(Result::ok)
                .filter(|e| e.path().extension().and_then(|x| x.to_str()) == Some("sql"))
                .count()
        })
        .unwrap_or(0);
    let estimated_pending =
        total_files.saturating_sub(usize::try_from(pre_count).unwrap_or(0));
    if estimated_pending > 0 {
        println!(
            " {} Pending migrations: {}",
            style("~").yellow(),
            estimated_pending
        );
        println!();
    }

    runner(pool.clone()).run().await?;

    let post_count = applied_count(&pool).await;
    let applied = post_count.saturating_sub(pre_count);
    if applied == 0 {
        println!(
            " {} No pending migrations. Database is up to date.",
            style("✔").green()
        );
        return Ok(());
    }

    println!();
    println!(
        " {} Applied {} migration(s) — total: {}",
        style("✔").green(),
        applied,
        post_count
    );
    Ok(())
}
/// Rolls back migrations `step` times (at least once), with one runner
/// `rollback()` call per step.
///
/// Afterwards it reports the change in the `_migrations` row count and the
/// number of rows remaining.
///
/// # Errors
///
/// Returns an error if connecting fails or any rollback call fails.
pub async fn rollback(step: u32) -> anyhow::Result<()> {
    // Row count of `_migrations`; any query failure is read as zero.
    async fn count_applied(pool: &sqlx::PgPool) -> i64 {
        sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM _migrations")
            .fetch_one(pool)
            .await
            .unwrap_or(0)
    }

    print_header("db:rollback");
    let pool = connect().await?;
    let total_steps = step.max(1);
    let before = count_applied(&pool).await;

    for current in 1..=total_steps {
        println!(
            " {} Rolling back step {}/{}...",
            style("~").yellow(),
            current,
            total_steps
        );
        runner(pool.clone()).rollback().await?;
    }

    let after = count_applied(&pool).await;
    let rolled_back = before - after;
    println!(
        " {} Rolled back {} migration(s) — remaining: {}",
        style("✔").green(),
        rolled_back,
        after
    );
    Ok(())
}
/// Prints the migration runner's status report for the configured database.
///
/// # Errors
///
/// Returns an error if connecting or the status query fails.
pub async fn status() -> anyhow::Result<()> {
    print_header("db:status");
    let migrations = runner(connect().await?);
    migrations.status().await?;
    Ok(())
}
/// Prints guidance for running seeders.
///
/// Seeders are not executed by this command. When `class` is given, it
/// prints the call to add for that seeder; otherwise it prints a generic
/// example.
pub async fn seed(class: Option<&str>) -> anyhow::Result<()> {
    print_header("db:seed");
    match class {
        None => {
            println!(
                " {} Seeders must be registered and run from your app binary.",
                style("ℹ").cyan()
            );
            println!();
            println!(" Add this to your main.rs or a seeder binary:");
            println!(" UsersSeeder::run(&pool).await?;");
        }
        Some(seeder) => {
            println!(" {} Run seeder: {}", style("~").yellow(), seeder);
            println!();
            println!(" Seeders are compiled into your app binary. Add to main.rs:");
            println!(" {}::run(&pool).await?;", seeder);
        }
    }
    Ok(())
}
/// Drops every ordinary and partitioned table in the `public` schema, then
/// re-runs all migrations.
///
/// Table names are schema-qualified (`public.<name>`), so a drop can never
/// resolve to a same-named table in another schema on the `search_path`.
/// Tables owned by an installed extension (for example PostGIS's
/// `spatial_ref_sys`) are skipped. PostgreSQL refuses to drop extension
/// members on their own, which would otherwise abort the whole wipe.
///
/// NOTE(review): only tables, and objects that depend on them via `CASCADE`,
/// are removed. Enum types, standalone sequences or independent views
/// created by migrations remain and may make the re-run fail.
///
/// # Errors
///
/// Returns an error if connecting, dropping tables, or running migrations fails.
pub async fn fresh() -> anyhow::Result<()> {
    print_header("db:fresh");
    let pool = connect().await?;
    println!(" {} Dropping all user tables...", style("~").yellow());
    sqlx::query(
        "DO $$
DECLARE r RECORD;
BEGIN
    FOR r IN (
        SELECT c.relname AS tablename
        FROM pg_class c
        JOIN pg_namespace n ON n.oid = c.relnamespace
        WHERE n.nspname = 'public'
          AND c.relkind IN ('r', 'p')
          AND NOT EXISTS (
              SELECT 1
              FROM pg_depend d
              WHERE d.classid = 'pg_class'::regclass
                AND d.objid = c.oid
                AND d.deptype = 'e'
          )
    ) LOOP
        EXECUTE format('DROP TABLE IF EXISTS public.%I CASCADE', r.tablename);
    END LOOP;
END $$;",
    )
    .execute(&pool)
    .await
    .context("failed to drop tables")?;
    println!(" {} All tables dropped.", style("✔").green());
    println!(" {} Running migrations...", style("~").yellow());
    runner(pool).run().await?;
    println!(" {} Database refreshed.", style("✔").green());
    Ok(())
}