use crate::cli::{DbCommand, DbTunnelArgs};
use crate::util::{Result, usage_error};
use regex::Regex;
use std::collections::HashSet;
use std::env;
use std::fs;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::thread;
use std::time::Duration;
/// Entry point for the `db` subcommand family.
///
/// `tunnel` is handled on its own because it does not use the migration
/// configuration. Every other subcommand first exports its `KEY=VALUE`
/// arguments into the process environment, so they take effect when
/// `MigrationConfig::from_env` reads it, and then dispatches to the matching
/// migration action. Arguments without `=` are ignored.
///
/// # Errors
/// Returns a usage error for a `KEY=VALUE` argument whose key is empty
/// (`std::env::set_var` would panic on it). Otherwise it propagates any error
/// from the dispatched action.
pub(crate) fn run_db_command(command: DbCommand) -> Result<()> {
    /// Migration actions. CLI aliases (`up`, `down`) collapse onto one variant.
    enum Action {
        Migrate,
        Status,
        Revert,
        Reset,
        New,
    }

    let (action, vars) = match command {
        DbCommand::Tunnel(args) => return open_database_tunnel(args),
        DbCommand::Migrate(args) => (Action::Migrate, args.vars),
        DbCommand::Up(args) => (Action::Migrate, args.vars),
        DbCommand::Status(args) => (Action::Status, args.vars),
        DbCommand::Revert(args) => (Action::Revert, args.vars),
        DbCommand::Down(args) => (Action::Revert, args.vars),
        DbCommand::Reset(args) => (Action::Reset, args.vars),
        DbCommand::New(args) => (Action::New, args.vars),
    };

    for arg in vars {
        // Arguments without `=` have always been skipped; keep that behaviour.
        let Some((key, value)) = arg.split_once('=') else {
            continue;
        };
        if key.is_empty() {
            return usage_error(format!(
                "invalid variable `{arg}`: expected KEY=VALUE with a non-empty KEY"
            ));
        }
        // SAFETY: `set_var` is unsafe because concurrent environment access from
        // other threads is undefined behaviour. NOTE(review): this assumes no
        // other threads are running yet at this point in the CLI — confirm.
        // `key` is non-empty and cannot contain '=' (we split on the first '='),
        // so `set_var` will not panic on it.
        unsafe { env::set_var(key, value) };
    }

    let cfg = MigrationConfig::from_env();
    match action {
        Action::Migrate => apply_migrations(&cfg),
        Action::Status => migration_status(&cfg),
        Action::Revert => revert_migration(&cfg),
        Action::Reset => reset_database(&cfg),
        Action::New => create_migration(&cfg.dir),
    }
}
/// Opens an interactive SSH local port forward to a remote Postgres server.
///
/// Missing details are filled in by prompting on stdin:
/// - the SSH target, or else the server and SSH user;
/// - the remote database host;
/// - the database name;
/// - the database user.
///
/// It prints the local connection URL and then runs
/// `ssh -N -L <local_port>:<remote_host>:<remote_port> <target>` with
/// inherited stdio. It blocks until `ssh` exits.
///
/// # Errors
/// Returns an error in these cases:
/// - a required prompt is answered empty;
/// - `ssh` cannot be spawned;
/// - `ssh` exits within the initial 800 ms window, whatever its status;
/// - `ssh` later exits unsuccessfully.
fn open_database_tunnel(mut args: DbTunnelArgs) -> Result<()> {
    let ssh_target = match args.ssh_target.take() {
        Some(target) if !target.trim().is_empty() => target,
        _ => {
            let server = match args.server.take() {
                Some(server) if !server.trim().is_empty() => server,
                _ => prompt("Server IP or hostname")?,
            };
            let ssh_user = if args.ssh_user.trim().is_empty() {
                prompt_with_default("SSH user", "root")?
            } else {
                args.ssh_user
            };
            format!("{ssh_user}@{server}")
        }
    };
    if args.remote_host.trim().is_empty() {
        args.remote_host = prompt_with_default("Remote database host", "127.0.0.1")?;
    }
    if args.database.trim().is_empty() {
        args.database = prompt_with_default("Database name", "core_auth")?;
    }
    if args.db_user.trim().is_empty() {
        args.db_user = prompt_with_default("Database user", "executesoft")?;
    }

    println!("Opening Postgres SSH tunnel:");
    println!(" local: 127.0.0.1:{}", args.local_port);
    println!(" remote: {}:{}", args.remote_host, args.remote_port);
    println!(" ssh: {ssh_target}");
    println!();
    println!("Keep this terminal open, then connect locally to:");
    println!(
        " postgres://{}:<POSTGRES_PASSWORD>@127.0.0.1:{}/{}",
        args.db_user, args.local_port, args.database
    );
    println!();

    // ssh splits the -L spec on ':', so a bare IPv6 host such as `::1` is
    // ambiguous. ssh accepts IPv6 addresses wrapped in square brackets.
    let forward = if args.remote_host.contains(':') && !args.remote_host.starts_with('[') {
        format!("{}:[{}]:{}", args.local_port, args.remote_host, args.remote_port)
    } else {
        format!("{}:{}:{}", args.local_port, args.remote_host, args.remote_port)
    };

    let mut child = Command::new("ssh")
        .args(["-N", "-L", &forward, &ssh_target])
        .stdin(Stdio::inherit())
        .stdout(Stdio::inherit())
        .stderr(Stdio::inherit())
        .spawn()?;

    // Give ssh a moment to fail fast before announcing the tunnel. Any exit in
    // this window, even with a success status, means the tunnel never came up.
    thread::sleep(Duration::from_millis(800));
    if let Some(status) = child.try_wait()? {
        println!("Tunnel disconnected.");
        return Err(format!("ssh tunnel exited with {status}").into());
    }

    println!("Tunnel connected. Press Ctrl+C to disconnect.");
    let status = child.wait()?;
    println!("Tunnel disconnected.");
    if status.success() {
        Ok(())
    } else {
        Err(format!("ssh tunnel exited with {status}").into())
    }
}
/// Prints `label: ` and reads one line of input from stdin.
///
/// The answer is returned with surrounding whitespace trimmed. An empty answer,
/// which includes end-of-input, is rejected with a usage error naming `label`.
fn prompt(label: &str) -> Result<String> {
    print!("{label}: ");
    io::stdout().flush()?;

    let mut line = String::new();
    io::stdin().read_line(&mut line)?;

    match line.trim() {
        "" => usage_error(format!("{label} is required")),
        answer => Ok(answer.to_owned()),
    }
}
/// Prints `label [default]: ` and reads one line of input from stdin.
///
/// Returns the trimmed answer. If the answer is blank (or input ended), it
/// returns `default` instead.
fn prompt_with_default(label: &str, default: &str) -> Result<String> {
    print!("{label} [{default}]: ");
    io::stdout().flush()?;

    let mut line = String::new();
    io::stdin().read_line(&mut line)?;

    let answer = line.trim();
    let chosen = if answer.is_empty() { default } else { answer };
    Ok(chosen.to_owned())
}
/// Settings for the migration commands, read from environment variables.
///
/// Each field falls back to a built-in default when its variable is unset or
/// not valid Unicode. `database` falls back to an empty string.
struct MigrationConfig {
    /// Directory scanned for migration scripts (`MIGRATIONS_DIR`).
    dir: PathBuf,
    /// Docker container in which `psql` is run (`MIGRATION_DB_CONTAINER`).
    container: String,
    /// Role passed to `psql -U` (`MIGRATION_DB_USER`).
    user: String,
    /// Target database name (`MIGRATION_DB_NAME`); empty when unset.
    database: String,
    /// Schema that holds the tracking table (`MIGRATION_SCHEMA`).
    schema: String,
    /// Name of the applied-migrations tracking table (`MIGRATION_TABLE`).
    table: String,
}

impl MigrationConfig {
    /// Builds the configuration from the current process environment.
    fn from_env() -> Self {
        Self {
            dir: PathBuf::from(Self::env_or("MIGRATIONS_DIR", "migrations")),
            container: Self::env_or("MIGRATION_DB_CONTAINER", "executesoft-dev-postgres-1"),
            user: Self::env_or("MIGRATION_DB_USER", "executesoft"),
            database: Self::env_or("MIGRATION_DB_NAME", ""),
            schema: Self::env_or("MIGRATION_SCHEMA", "public"),
            table: Self::env_or("MIGRATION_TABLE", "schema_migrations"),
        }
    }

    /// Returns the value of `key`, or `fallback` when it is unset or not valid Unicode.
    fn env_or(key: &str, fallback: &str) -> String {
        env::var(key).unwrap_or_else(|_| fallback.to_owned())
    }
}
/// Fails with a usage error unless a database name was configured
/// (`MIGRATION_DB_NAME` resolved to a non-empty string).
fn require_database(cfg: &MigrationConfig) -> Result<()> {
    if !cfg.database.is_empty() {
        return Ok(());
    }
    usage_error("MIGRATION_DB_NAME is required".into())
}
fn docker_psql(
cfg: &MigrationConfig,
database: &str,
extra: &[&str],
stdin: Option<&str>,
) -> Result<String> {
let mut cmd = Command::new("docker");
cmd.args([
"exec",
"-i",
&cfg.container,
"psql",
"-v",
"ON_ERROR_STOP=1",
"-U",
&cfg.user,
"-d",
database,
]);
cmd.args(extra);
cmd.stderr(Stdio::inherit());
if stdin.is_some() {
cmd.stdin(Stdio::piped());
}
cmd.stdout(Stdio::piped());
let mut child = cmd.spawn()?;
if let Some(input) = stdin {
child.stdin.as_mut().unwrap().write_all(input.as_bytes())?;
}
let output = child.wait_with_output()?;
if !output.status.success() {
return Err(format!("docker psql exited with {}", output.status).into());
}
Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
}
/// Creates `cfg.database` on the server if no database of that name exists.
///
/// It connects to the `postgres` maintenance database and pipes in a script
/// that uses psql's `\gexec` to run the generated `CREATE DATABASE` only when
/// `pg_database` has no matching row. The name is passed as the psql variable
/// `db_name`. psql interpolates it as a quoted literal (`:'db_name'`), and
/// `format('%I', …)` quotes it as an identifier.
fn ensure_database(cfg: &MigrationConfig) -> Result<()> {
    const CREATE_IF_MISSING: &str = "SELECT format('CREATE DATABASE %I', :'db_name') WHERE NOT EXISTS (SELECT 1 FROM pg_database WHERE datname = :'db_name')\\gexec\n";

    require_database(cfg)?;
    let db_var = format!("db_name={}", cfg.database);
    docker_psql(cfg, "postgres", &["-v", &db_var], Some(CREATE_IF_MISSING)).map(|_| ())
}
/// Runs `psql` against the configured target database, creating that database
/// first if it does not exist yet.
///
/// `extra` and `stdin` are forwarded unchanged to `docker_psql`.
fn psql(cfg: &MigrationConfig, extra: &[&str], stdin: Option<&str>) -> Result<String> {
    ensure_database(cfg).and_then(|()| docker_psql(cfg, &cfg.database, extra, stdin))
}
/// Creates the migration tracking table `<schema>.<table>` if it is missing.
///
/// The table has three columns:
/// - `version` (primary key);
/// - `filename`;
/// - `applied_at`, which defaults to `now()`.
///
/// The schema and table names are interpolated unquoted.
fn ensure_tracking_table(cfg: &MigrationConfig) -> Result<()> {
    let ddl = format!(
        "CREATE TABLE IF NOT EXISTS {schema}.{table} (version TEXT PRIMARY KEY, filename TEXT NOT NULL, applied_at TIMESTAMPTZ NOT NULL DEFAULT now());",
        schema = cfg.schema,
        table = cfg.table,
    );
    psql(cfg, &[], Some(&ddl))?;
    Ok(())
}
/// Lists the forward ("up") migration scripts directly inside `dir`, sorted by path.
///
/// A regular file qualifies when its name ends in `.sql` but not in
/// `.down.sql` or `_down.sql`. Subdirectories such as `down/` are skipped.
/// If `dir` does not exist, a notice is printed and an empty list is returned.
fn migration_files(dir: &Path) -> Result<Vec<PathBuf>> {
    if !dir.exists() {
        println!("No migrations directory found: {}", dir.display());
        return Ok(Vec::new());
    }

    let is_up_script = |file_name: &str| {
        file_name.ends_with(".sql")
            && !(file_name.ends_with(".down.sql") || file_name.ends_with("_down.sql"))
    };

    let mut scripts = Vec::new();
    for item in fs::read_dir(dir)? {
        let item = item?;
        let candidate = item.path();
        if candidate.is_file() && is_up_script(&item.file_name().to_string_lossy()) {
            scripts.push(candidate);
        }
    }
    scripts.sort();
    Ok(scripts)
}
/// Returns the migration version for `path`, which is its file name without
/// the final extension.
///
/// Returns an empty string when the path has no file name (e.g. `/`).
fn migration_version(path: &Path) -> String {
    match path.file_stem() {
        Some(stem) => stem.to_string_lossy().into_owned(),
        None => String::new(),
    }
}
fn apply_migrations(cfg: &MigrationConfig) -> Result<()> {
let files = migration_files(&cfg.dir)?;
if files.is_empty() {
println!("No migration files found in {}.", cfg.dir.display());
return Ok(());
}
ensure_tracking_table(cfg)?;
let mut applied = 0;
for file in files {
let version = migration_version(&file);
let exists = psql(
cfg,
&[
"-At",
"-c",
&format!(
"SELECT EXISTS (SELECT 1 FROM {}.{} WHERE version = '{}');",
cfg.schema,
cfg.table,
sql_literal(&version)
),
],
None,
)?;
if exists == "t" {
println!("Skipping already applied migration: {version}");
continue;
}
println!("Applying migration: {version}");
let sql = fs::read_to_string(&file)?;
psql(cfg, &[], Some(&sql))?;
psql(
cfg,
&[
"-c",
&format!(
"INSERT INTO {}.{} (version, filename) VALUES ('{}', '{}');",
cfg.schema,
cfg.table,
sql_literal(&version),
sql_literal(file.file_name().unwrap().to_string_lossy().as_ref())
),
],
None,
)?;
applied += 1;
}
if applied == 0 {
println!("No pending migrations.");
}
Ok(())
}
fn migration_status(cfg: &MigrationConfig) -> Result<()> {
let files = migration_files(&cfg.dir)?;
if files.is_empty() {
println!("No migration files found in {}.", cfg.dir.display());
return Ok(());
}
ensure_tracking_table(cfg)?;
for file in files {
let version = migration_version(&file);
let exists = psql(
cfg,
&[
"-At",
"-c",
&format!(
"SELECT EXISTS (SELECT 1 FROM {}.{} WHERE version = '{}');",
cfg.schema,
cfg.table,
sql_literal(&version)
),
],
None,
)?;
println!(
"{} {version}",
if exists == "t" { "applied" } else { "pending" }
);
}
Ok(())
}
/// Finds the rollback script for `version` under `dir`, if one exists.
///
/// Candidates are checked in priority order:
/// 1. `down/<version>.sql`
/// 2. `<version>.down.sql`
/// 3. `<version>_down.sql`
fn down_file_for(dir: &Path, version: &str) -> Option<PathBuf> {
    let candidates = [
        dir.join("down").join(format!("{version}.sql")),
        dir.join(format!("{version}.down.sql")),
        dir.join(format!("{version}_down.sql")),
    ];
    for candidate in candidates {
        if candidate.exists() {
            return Some(candidate);
        }
    }
    None
}
/// Reverts the most recently applied migration.
///
/// The latest version is the tracking-table row with the newest `applied_at`,
/// with ties broken by the highest `version`. Its rollback script is located
/// via `down_file_for` and run through `psql`, and then the version's tracking
/// row is deleted. If nothing has been applied, a notice is printed and the
/// call succeeds.
///
/// # Errors
/// Fails when no rollback script exists for the latest version, or when a
/// `psql` call or a file read fails.
fn revert_migration(cfg: &MigrationConfig) -> Result<()> {
    ensure_tracking_table(cfg)?;

    let latest_query = format!(
        "SELECT version FROM {}.{} ORDER BY applied_at DESC, version DESC LIMIT 1;",
        cfg.schema, cfg.table
    );
    let latest = psql(cfg, &["-At", "-c", &latest_query], None)?;
    if latest.is_empty() {
        println!("No applied migrations to revert.");
        return Ok(());
    }

    let Some(rollback_path) = down_file_for(&cfg.dir, &latest) else {
        return Err(format!("No down migration found for {latest}").into());
    };
    println!("Reverting migration: {latest}");
    let rollback_sql = fs::read_to_string(rollback_path)?;
    psql(cfg, &[], Some(&rollback_sql))?;

    let forget_version = format!(
        "DELETE FROM {}.{} WHERE version = '{}';",
        cfg.schema,
        cfg.table,
        sql_literal(&latest)
    );
    psql(cfg, &["-c", &forget_version], None)?;
    Ok(())
}
fn reset_database(cfg: &MigrationConfig) -> Result<()> {
if env::var("CONFIRM").ok().as_deref() != Some("yes") {
return usage_error("Refusing to reset database without CONFIRM=yes".into());
}
println!(
"Resetting schema {} in database {}.",
cfg.schema, cfg.database
);
psql(
cfg,
&[],
Some(&format!(
"DROP SCHEMA IF EXISTS {} CASCADE; CREATE SCHEMA {};",
cfg.schema, cfg.schema
)),
)?;
apply_migrations(cfg)
}
fn create_migration(dir: &Path) -> Result<()> {
let name =
env::var("NAME").map_err(|_| "NAME is required. Example: exe db new NAME=create_orders")?;
fs::create_dir_all(dir.join("down"))?;
let ts = Command::new("date").arg("+%Y%m%d%H%M%S").output()?;
let timestamp = String::from_utf8_lossy(&ts.stdout).trim().to_string();
let safe = Regex::new("[^a-z0-9_]+")?
.replace_all(&name.to_lowercase(), "_")
.trim_matches('_')
.to_string();
let version = format!("{timestamp}_{safe}");
fs::write(
dir.join(format!("{version}.sql")),
"-- Add migration SQL here.\n",
)?;
fs::write(
dir.join("down").join(format!("{version}.sql")),
"-- Add rollback SQL here.\n",
)?;
println!("Created {}/{}.sql", dir.display(), version);
println!("Created {}/down/{}.sql", dir.display(), version);
Ok(())
}
/// Escapes `value` for use inside a single-quoted SQL string literal by
/// doubling every `'`. The caller adds the surrounding quotes.
///
/// NOTE(review): only quotes are escaped. This assumes the server runs with
/// `standard_conforming_strings` on, so backslashes are literal — confirm.
fn sql_literal(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}