use std::collections::{BTreeMap, HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;
use sqlx_core::connection::Connection;
use sqlx_core::executor::Executor;
use sqlx_core::row::Row;
use sqlx_core::sql_str::AssertSqlSafe;
use sqlx_postgres::{PgConnection, PgPool};
use crate::value::{VmError, VmValue};
use super::{handle_id, pool_by_id, runtime_error, HANDLE_POOL};
/// Advisory-lock key the "harn" ledger path passes to `pg_advisory_lock` /
/// `pg_advisory_unlock`. The leading seven bytes are the ASCII text `HarnMgr`.
const MIGRATION_LOCK_KEY: i64 = 0x4861_726E_4D67_7201;
/// Ledger table used by the "harn" ledger when no string `table` option is given.
const DEFAULT_TABLE: &str = "harn_migrations";
/// Ledger table the "sqlx" ledger always uses; a different `table` option is rejected.
const SQLX_TABLE: &str = "_sqlx_migrations";
/// Bookkeeping scheme `pg_migrate` uses to record which migrations were applied.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Ledger {
/// Name-keyed ledger with SHA-256 checksums; selected when `ledger` is absent.
Harn,
/// Version-keyed ledger kept in `_sqlx_migrations` with SHA-384 checksums.
Sqlx,
}
impl Ledger {
    /// Reads the optional `ledger` entry of the options dict.
    ///
    /// A missing entry selects [`Ledger::Harn`]; the strings `"harn"` and
    /// `"sqlx"` select the matching variant.
    ///
    /// # Errors
    /// Returns a runtime error when the entry is not a string or names an
    /// unknown ledger.
    fn parse(opts: &BTreeMap<String, VmValue>) -> Result<Self, VmError> {
        let Some(raw) = opts.get("ledger") else {
            return Ok(Ledger::Harn);
        };
        let VmValue::String(text) = raw else {
            return Err(runtime_error(
                "pg_migrate: option `ledger` must be a string (\"harn\" or \"sqlx\")",
            ));
        };
        match text.as_ref() {
            "harn" => Ok(Ledger::Harn),
            "sqlx" => Ok(Ledger::Sqlx),
            other => Err(runtime_error(format!(
                "pg_migrate: unknown ledger `{other}`; expected \"harn\" or \"sqlx\""
            ))),
        }
    }
}
/// Entry point of the `pg_migrate` builtin.
///
/// `args[0]` is the pool handle and `args[1]` the options dict. The options
/// read here are `dir` (required string path), `ledger`, `table` and
/// `dry_run`. Only a literal boolean `true` enables dry-run mode, and a
/// `table` entry that is not a string is treated as if it were absent.
///
/// # Errors
/// Returns a runtime error for missing/ill-typed arguments, an unknown pool
/// handle, a `table` that conflicts with the "sqlx" ledger, an invalid table
/// name, or any failure reported by the selected ledger runner.
pub(super) async fn run(args: Vec<VmValue>) -> Result<VmValue, VmError> {
    let Some(pool_handle) = args.first() else {
        return Err(runtime_error(
            "pg_migrate: pool handle is required as the first argument",
        ));
    };
    let Some(opts) = args.get(1).and_then(VmValue::as_dict).cloned() else {
        return Err(runtime_error(
            "pg_migrate: second argument must be an options dict {dir, ...}",
        ));
    };

    let pool_id = handle_id(Some(pool_handle), HANDLE_POOL, "pg_migrate")?;
    let pool = pool_by_id(&pool_id)?;
    let dir = dir_arg(&opts, "dir")?;
    let ledger = Ledger::parse(&opts)?;

    // Only a string-valued `table` counts as a request for a specific table.
    let requested_table = match opts.get("table") {
        Some(VmValue::String(s)) => Some(s),
        _ => None,
    };
    let table_name = match (ledger, requested_table) {
        (Ledger::Sqlx, Some(s)) if s.as_ref() != SQLX_TABLE => {
            return Err(runtime_error(format!(
                "pg_migrate: ledger \"sqlx\" always uses table `{SQLX_TABLE}`; \
                 remove the conflicting `table: \"{s}\"`"
            )));
        }
        (Ledger::Sqlx, _) => SQLX_TABLE.to_string(),
        (Ledger::Harn, Some(s)) => s.to_string(),
        (Ledger::Harn, None) => DEFAULT_TABLE.to_string(),
    };
    validate_table_name(&table_name)?;

    let dry_run = matches!(opts.get("dry_run"), Some(VmValue::Bool(true)));
    if ledger == Ledger::Sqlx {
        run_sqlx(pool, &dir, table_name, dry_run).await
    } else {
        run_harn(pool, &dir, table_name, dry_run).await
    }
}
/// Runs the "harn" ledger flow: discover `.sql` files, take the advisory lock,
/// verify checksums of already-recorded migrations and apply the rest.
///
/// The ledger table is ensured even when `dry_run` is set; only the
/// per-migration apply step is skipped in that mode. The advisory lock is
/// released whether or not the locked section succeeded.
///
/// # Errors
/// Returns a runtime error when discovery, connection acquisition, locking,
/// a checksum comparison, or applying a migration fails.
async fn run_harn(
    pool: Arc<PgPool>,
    dir: &Path,
    table_name: String,
    dry_run: bool,
) -> Result<VmValue, VmError> {
    let entries = discover_migrations(dir)?;
    let started = Instant::now();
    let mut conn = match pool.acquire().await {
        Ok(conn) => conn,
        Err(error) => {
            return Err(runtime_error(format!(
                "pg_migrate: acquire connection failed: {error}"
            )));
        }
    };
    acquire_lock(&mut conn, MIGRATION_LOCK_KEY).await?;

    // Everything that needs the lock lives in this block so that the release
    // below runs on both the success and the error path.
    let outcome = async {
        ensure_migrations_table(&mut conn, &table_name).await?;
        let recorded = applied_set(&mut conn, &table_name).await?;
        let mut newly_applied = Vec::new();
        let mut already_applied = Vec::new();
        for entry in &entries {
            match recorded.get(&entry.name) {
                Some(expected) => {
                    let actual = sha256_file(&entry.path)?;
                    if actual != *expected {
                        return Err(runtime_error(format!(
                            "pg_migrate: checksum mismatch for migration {}; the \
                             recorded checksum differs from the file on disk",
                            entry.name
                        )));
                    }
                    already_applied.push(entry.name.clone());
                }
                None => {
                    if !dry_run {
                        apply_one(&mut conn, &table_name, entry).await?;
                    }
                    newly_applied.push(entry.name.clone());
                }
            }
        }
        Ok::<_, VmError>((newly_applied, already_applied))
    }
    .await;
    release_lock(&mut conn, MIGRATION_LOCK_KEY).await;

    let (applied_now, skipped) = outcome?;
    let available = entries
        .iter()
        .map(|entry| entry.name.clone())
        .collect::<Vec<String>>();
    let duration_ms = started.elapsed().as_millis() as i64;
    Ok(build_response(
        applied_now,
        skipped,
        available,
        dry_run,
        duration_ms,
        &table_name,
    ))
}
/// Fetches option `key` from `dict` and converts it to a [`PathBuf`].
///
/// # Errors
/// Returns a runtime error when the option is missing or is not a string.
fn dir_arg(dict: &BTreeMap<String, VmValue>, key: &str) -> Result<PathBuf, VmError> {
    match dict.get(key) {
        Some(VmValue::String(text)) => Ok(PathBuf::from(text.as_ref())),
        Some(_) => Err(runtime_error(format!(
            "pg_migrate: option `{key}` must be a string path"
        ))),
        None => Err(runtime_error(format!(
            "pg_migrate: option `{key}` is required and must be a path"
        ))),
    }
}
/// One migration file found on disk for the "harn" ledger.
#[derive(Clone)]
struct MigrationEntry {
// File name (lossily converted to UTF-8); used as the ledger key and sort key.
name: String,
// Full path of the file, as reported by the directory entry.
path: PathBuf,
}
/// Lists the migration files in `dir` for the "harn" ledger.
///
/// Every entry whose file name ends in `.sql` but not `.down.sql` is kept,
/// and the result is sorted by file name.
///
/// # Errors
/// Returns a runtime error when `dir` does not exist, cannot be opened, or
/// when any directory entry cannot be read. Entry errors are reported rather
/// than dropped, because silently losing an entry would silently skip a
/// migration.
fn discover_migrations(dir: &Path) -> Result<Vec<MigrationEntry>, VmError> {
    if !dir.exists() {
        return Err(runtime_error(format!(
            "pg_migrate: directory does not exist: {}",
            dir.display()
        )));
    }
    let read_dir = std::fs::read_dir(dir).map_err(|error| {
        runtime_error(format!(
            "pg_migrate: could not read directory {}: {error}",
            dir.display()
        ))
    })?;
    let mut entries: Vec<MigrationEntry> = Vec::new();
    for entry in read_dir {
        let entry = entry.map_err(|error| {
            runtime_error(format!(
                "pg_migrate: could not read directory entry in {}: {error}",
                dir.display()
            ))
        })?;
        let name = entry.file_name().to_string_lossy().into_owned();
        if name.ends_with(".sql") && !name.ends_with(".down.sql") {
            entries.push(MigrationEntry {
                name,
                path: entry.path(),
            });
        }
    }
    entries.sort_by(|a, b| a.name.cmp(&b.name));
    Ok(entries)
}
async fn ensure_migrations_table(conn: &mut PgConnection, table: &str) -> Result<(), VmError> {
let sql = format!(
"CREATE TABLE IF NOT EXISTS \"{table}\" (\
name TEXT PRIMARY KEY,\
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),\
checksum BYTEA NOT NULL\
)"
);
conn.execute(AssertSqlSafe(sql))
.await
.map_err(|error| runtime_error(format!("pg_migrate: ensure table failed: {error}")))?;
Ok(())
}
/// Loads the "harn" ledger as a map from migration name to recorded checksum.
///
/// Columns are decoded with `try_get`, so a pre-existing table whose columns
/// have unexpected types yields an error instead of a panic. The table name is
/// caller-configurable, and the caller holds an advisory lock while this runs.
///
/// # Errors
/// Returns a runtime error when the query fails or a row cannot be decoded as
/// `(TEXT, BYTEA)`.
async fn applied_set(
    conn: &mut PgConnection,
    table: &str,
) -> Result<BTreeMap<String, Vec<u8>>, VmError> {
    let sql = format!("SELECT name, checksum FROM \"{table}\"");
    let rows = sqlx_core::query::query::<sqlx_postgres::Postgres>(AssertSqlSafe(sql))
        .fetch_all(conn)
        .await
        .map_err(|error| runtime_error(format!("pg_migrate: select applied failed: {error}")))?;
    let mut applied = BTreeMap::new();
    for row in &rows {
        let name = row.try_get::<String, _>(0).map_err(|error| {
            runtime_error(format!(
                "pg_migrate: decoding applied migration name failed: {error}"
            ))
        })?;
        let checksum = row.try_get::<Vec<u8>, _>(1).map_err(|error| {
            runtime_error(format!(
                "pg_migrate: decoding checksum for {name} failed: {error}"
            ))
        })?;
        applied.insert(name, checksum);
    }
    Ok(applied)
}
async fn apply_one(
conn: &mut PgConnection,
table: &str,
entry: &MigrationEntry,
) -> Result<(), VmError> {
let sql = std::fs::read_to_string(&entry.path).map_err(|error| {
runtime_error(format!(
"pg_migrate: could not read {}: {error}",
entry.path.display()
))
})?;
let checksum = sha256(&sql);
let mut tx = conn
.begin()
.await
.map_err(|error| runtime_error(format!("pg_migrate: begin failed: {error}")))?;
(&mut *tx)
.execute(AssertSqlSafe(sql))
.await
.map_err(|error| runtime_error(format!("pg_migrate: applying {}: {error}", entry.name)))?;
let insert = format!("INSERT INTO \"{table}\" (name, checksum) VALUES ($1, $2)");
sqlx_core::query::query::<sqlx_postgres::Postgres>(AssertSqlSafe(insert))
.bind(entry.name.clone())
.bind(checksum)
.execute(&mut *tx)
.await
.map_err(|error| {
runtime_error(format!("pg_migrate: record {} failed: {error}", entry.name))
})?;
tx.commit().await.map_err(|error| {
runtime_error(format!("pg_migrate: commit {} failed: {error}", entry.name))
})
}
async fn acquire_lock(conn: &mut PgConnection, key: i64) -> Result<(), VmError> {
sqlx_core::query::query::<sqlx_postgres::Postgres>("SELECT pg_advisory_lock($1)")
.bind(key)
.execute(conn)
.await
.map_err(|error| runtime_error(format!("pg_migrate: advisory lock failed: {error}")))?;
Ok(())
}
/// Runs `SELECT pg_advisory_unlock($1)` with `key` on `conn`, best effort.
///
/// Nothing is returned to the caller: a query error, an empty result, or a
/// `false` result is only reported through `tracing::warn!`.
async fn release_lock(conn: &mut PgConnection, key: i64) {
    let outcome =
        sqlx_core::query::query::<sqlx_postgres::Postgres>("SELECT pg_advisory_unlock($1)")
            .bind(key)
            .fetch_optional(conn)
            .await;
    let row = match outcome {
        Ok(Some(row)) => row,
        Ok(None) => {
            tracing::warn!(
                lock_key = key,
                "pg_migrate: pg_advisory_unlock returned no row"
            );
            return;
        }
        Err(error) => {
            tracing::warn!(
                lock_key = key,
                %error,
                "pg_migrate: releasing advisory lock failed"
            );
            return;
        }
    };
    if !row.get::<bool, _>(0) {
        tracing::warn!(
            lock_key = key,
            "pg_migrate: pg_advisory_unlock returned false; the advisory \
             lock was not held by this connection"
        );
    }
}
/// Returns the 32-byte SHA-256 digest of the UTF-8 bytes of `text`.
fn sha256(text: &str) -> Vec<u8> {
    use sha2::{Digest, Sha256};
    Sha256::digest(text.as_bytes()).to_vec()
}
fn sha256_file(path: &Path) -> Result<Vec<u8>, VmError> {
let sql = std::fs::read_to_string(path).map_err(|error| {
runtime_error(format!(
"pg_migrate: could not read {}: {error}",
path.display()
))
})?;
Ok(sha256(&sql))
}
/// Checks that `name` is safe to splice into SQL as a quoted identifier.
///
/// Accepted names are 1..=63 bytes long, start with an ASCII letter or `_`,
/// and contain only ASCII letters, digits and `_`.
///
/// # Errors
/// Returns a runtime error describing the first rule that is violated.
fn validate_table_name(name: &str) -> Result<(), VmError> {
    if !(1..=63).contains(&name.len()) {
        return Err(runtime_error(
            "pg_migrate: option `table` must be 1..=63 bytes",
        ));
    }
    let starts_ok = matches!(
        name.chars().next(),
        Some(first) if first.is_ascii_alphabetic() || first == '_'
    );
    if !starts_ok {
        return Err(runtime_error(
            "pg_migrate: option `table` must start with a letter or underscore",
        ));
    }
    let offending = name
        .chars()
        .find(|ch| !(ch.is_ascii_alphanumeric() || *ch == '_'));
    match offending {
        Some(ch) => Err(runtime_error(format!(
            "pg_migrate: option `table` contains invalid character `{ch}`"
        ))),
        None => Ok(()),
    }
}
/// Assembles the dict returned to the caller of `pg_migrate`.
///
/// Keys: `applied`, `skipped`, `available` (lists of strings), `dry_run`
/// (bool), `duration_ms` (int) and `table` (string).
fn build_response(
    applied_now: Vec<String>,
    skipped: Vec<String>,
    available: Vec<String>,
    dry_run: bool,
    duration_ms: i64,
    table_name: &str,
) -> VmValue {
    /// Wraps a list of names as a `VmValue::List` of `VmValue::String`s.
    fn string_list(names: Vec<String>) -> VmValue {
        VmValue::List(Arc::new(
            names
                .into_iter()
                .map(|name| VmValue::String(Arc::from(name)))
                .collect(),
        ))
    }

    let fields = [
        ("applied", string_list(applied_now)),
        ("skipped", string_list(skipped)),
        ("available", string_list(available)),
        ("dry_run", VmValue::Bool(dry_run)),
        ("duration_ms", VmValue::Int(duration_ms)),
        ("table", VmValue::String(Arc::from(table_name))),
    ];
    let response: BTreeMap<String, VmValue> = fields
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect();
    VmValue::Dict(Arc::new(response))
}
/// One migration file found on disk for the "sqlx" ledger.
#[derive(Clone, Debug)]
struct SqlxMigration {
// Integer parsed from the file-name prefix before the first `_`; the ledger key.
version: i64,
// Rest of the file name with the `.sql` / `.up.sql` suffix trimmed and `_` replaced by spaces.
description: String,
// Full file name (lossily converted to UTF-8).
name: String,
// Full path of the file, as reported by the directory entry.
path: PathBuf,
}
/// Lists the migration files in `dir` for the "sqlx" ledger.
///
/// A file is considered when its name has the shape `<version>_<rest>` with
/// `<rest>` ending in `.sql`; names ending in `.down.sql` are ignored. The
/// result is ordered by version (then by name), and when several files share
/// a version only the first in that order is kept — the others are logged
/// with `tracing::warn!` and dropped.
///
/// # Errors
/// Returns a runtime error when `dir` does not exist, cannot be opened, any
/// directory entry cannot be read, or a candidate file has a non-integer
/// version prefix. Entry errors are reported rather than dropped, because
/// silently losing an entry would silently skip a migration.
fn discover_sqlx_migrations(dir: &Path) -> Result<Vec<SqlxMigration>, VmError> {
    if !dir.exists() {
        return Err(runtime_error(format!(
            "pg_migrate: directory does not exist: {}",
            dir.display()
        )));
    }
    let read_dir = std::fs::read_dir(dir).map_err(|error| {
        runtime_error(format!(
            "pg_migrate: could not read directory {}: {error}",
            dir.display()
        ))
    })?;
    let mut parsed: Vec<SqlxMigration> = Vec::new();
    for entry in read_dir {
        let entry = entry.map_err(|error| {
            runtime_error(format!(
                "pg_migrate: could not read directory entry in {}: {error}",
                dir.display()
            ))
        })?;
        let path = entry.path();
        let name = entry.file_name().to_string_lossy().into_owned();
        // No underscore means the name is not `<version>_<rest>`; skip it.
        let Some((prefix, rest)) = name.split_once('_') else {
            continue;
        };
        if !rest.ends_with(".sql") || rest.ends_with(".down.sql") {
            continue;
        }
        let version: i64 = prefix.parse().map_err(|_| {
            runtime_error(format!(
                "pg_migrate: error parsing migration filename {name:?}; \
                 expected integer version prefix (e.g. `01_foo.sql`)"
            ))
        })?;
        let suffix = if rest.ends_with(".up.sql") {
            ".up.sql"
        } else {
            ".sql"
        };
        // `trim_end_matches` (which strips repeated suffixes) is kept on
        // purpose: the description is stored in the ledger row, so the
        // derivation must not change for existing file names.
        let description = rest.trim_end_matches(suffix).replace('_', " ");
        parsed.push(SqlxMigration {
            version,
            description,
            name,
            path,
        });
    }
    parsed.sort_by(|a, b| a.version.cmp(&b.version).then(a.name.cmp(&b.name)));
    let mut seen: HashSet<i64> = HashSet::new();
    let mut deduped = Vec::with_capacity(parsed.len());
    for migration in parsed {
        if !seen.insert(migration.version) {
            tracing::warn!(
                version = migration.version,
                description = %migration.description,
                file = %migration.name,
                "pg_migrate: skipping migration with duplicate version (another file \
                 already claimed this prefix); fix by renaming one of the files",
            );
            continue;
        }
        deduped.push(migration);
    }
    Ok(deduped)
}
/// Runs the "sqlx" ledger flow: discover versioned files, take the advisory
/// lock keyed on the current database name, run the locked section, release
/// the lock, and build the response dict.
///
/// The lock is released whether or not the locked section succeeded.
///
/// # Errors
/// Returns a runtime error when discovery, connection acquisition, computing
/// the lock id, locking, or the locked section fails.
async fn run_sqlx(
    pool: Arc<PgPool>,
    dir: &Path,
    table_name: String,
    dry_run: bool,
) -> Result<VmValue, VmError> {
    let migrations = discover_sqlx_migrations(dir)?;
    let timer = Instant::now();
    let mut conn = match pool.acquire().await {
        Ok(conn) => conn,
        Err(error) => {
            return Err(runtime_error(format!(
                "pg_migrate: acquire connection failed: {error}"
            )));
        }
    };

    let lock_id = sqlx_lock_id(&mut conn).await?;
    acquire_lock(&mut conn, lock_id).await?;
    // Hold on to the outcome so the unlock below runs on the error path too.
    let outcome = run_sqlx_locked(&mut conn, &table_name, dry_run, &migrations).await;
    release_lock(&mut conn, lock_id).await;
    let (applied_now, skipped) = outcome?;

    let available = migrations
        .iter()
        .map(|migration| migration.name.clone())
        .collect::<Vec<String>>();
    let duration_ms = timer.elapsed().as_millis() as i64;
    Ok(build_response(
        applied_now,
        skipped,
        available,
        dry_run,
        duration_ms,
        &table_name,
    ))
}
/// Body of the "sqlx" flow that runs while the advisory lock is held.
///
/// Ensures the ledger table, refuses to continue when a row with
/// `success = false` exists, verifies SHA-384 checksums of recorded versions,
/// and applies the remaining migrations unless `dry_run` is set.
///
/// Returns `(applied_now, skipped)` as lists of file names.
///
/// # Errors
/// Returns a runtime error for a dirty ledger, a checksum mismatch, or any
/// failing database / file operation.
async fn run_sqlx_locked(
    conn: &mut PgConnection,
    table: &str,
    dry_run: bool,
    migrations: &[SqlxMigration],
) -> Result<(Vec<String>, Vec<String>), VmError> {
    ensure_sqlx_migrations_table(conn, table).await?;

    let dirty = sqlx_dirty_version(conn, table).await?;
    if let Some(version) = dirty {
        return Err(runtime_error(format!(
            "pg_migrate: dirty migration {version}; the ledger has a failed \
             migration recorded — resolve it before re-running"
        )));
    }

    let recorded = sqlx_applied(conn, table).await?;
    let mut newly_applied = Vec::new();
    let mut already_applied = Vec::new();
    for migration in migrations {
        match recorded.get(&migration.version) {
            Some(expected) => {
                let actual = sha384_file(&migration.path)?;
                if *expected != actual {
                    return Err(runtime_error(format!(
                        "pg_migrate: checksum mismatch for migration {} ({}); the \
                         recorded checksum differs from the file on disk",
                        migration.version, migration.name
                    )));
                }
                already_applied.push(migration.name.clone());
            }
            None => {
                if !dry_run {
                    apply_sqlx_one(conn, table, migration).await?;
                }
                newly_applied.push(migration.name.clone());
            }
        }
    }
    Ok((newly_applied, already_applied))
}
async fn ensure_sqlx_migrations_table(conn: &mut PgConnection, table: &str) -> Result<(), VmError> {
let sql = format!(
"CREATE TABLE IF NOT EXISTS \"{table}\" (\
version BIGINT PRIMARY KEY,\
description TEXT NOT NULL,\
installed_on TIMESTAMPTZ NOT NULL DEFAULT now(),\
success BOOLEAN NOT NULL,\
checksum BYTEA NOT NULL,\
execution_time BIGINT NOT NULL\
)"
);
conn.execute(AssertSqlSafe(sql))
.await
.map_err(|error| runtime_error(format!("pg_migrate: ensure sqlx table failed: {error}")))?;
Ok(())
}
/// Returns the lowest `version` in `table` whose `success` column is false,
/// or `None` when there is no such row.
///
/// # Errors
/// Returns a runtime error when the query fails.
async fn sqlx_dirty_version(conn: &mut PgConnection, table: &str) -> Result<Option<i64>, VmError> {
    let sql =
        format!("SELECT version FROM \"{table}\" WHERE success = false ORDER BY version LIMIT 1");
    let dirty_row = sqlx_core::query::query::<sqlx_postgres::Postgres>(AssertSqlSafe(sql))
        .fetch_optional(conn)
        .await
        .map_err(|error| runtime_error(format!("pg_migrate: dirty check failed: {error}")))?;
    let version = dirty_row.as_ref().map(|row| row.get::<i64, _>(0));
    Ok(version)
}
/// Loads the "sqlx" ledger as a map from version to recorded checksum.
///
/// # Errors
/// Returns a runtime error when the query fails.
async fn sqlx_applied(
    conn: &mut PgConnection,
    table: &str,
) -> Result<HashMap<i64, Vec<u8>>, VmError> {
    let sql = format!("SELECT version, checksum FROM \"{table}\" ORDER BY version");
    let rows = sqlx_core::query::query::<sqlx_postgres::Postgres>(AssertSqlSafe(sql))
        .fetch_all(conn)
        .await
        .map_err(|error| {
            runtime_error(format!("pg_migrate: select applied (sqlx) failed: {error}"))
        })?;
    let mut applied = HashMap::with_capacity(rows.len());
    for row in &rows {
        let version = row.get::<i64, _>(0);
        let checksum = row.get::<Vec<u8>, _>(1);
        applied.insert(version, checksum);
    }
    Ok(applied)
}
async fn apply_sqlx_one(
conn: &mut PgConnection,
table: &str,
migration: &SqlxMigration,
) -> Result<(), VmError> {
let sql = std::fs::read_to_string(&migration.path).map_err(|error| {
runtime_error(format!(
"pg_migrate: could not read {}: {error}",
migration.path.display()
))
})?;
let checksum = sha384(&sql);
let no_tx = sql.starts_with("-- no-transaction");
let start = Instant::now();
if no_tx {
(&mut *conn)
.execute(AssertSqlSafe(sql))
.await
.map_err(|error| {
runtime_error(format!(
"pg_migrate: applying {} ({}): {error}",
migration.version, migration.name
))
})?;
sqlx_insert_row(&mut *conn, table, migration, &checksum).await?;
} else {
let mut tx = conn
.begin()
.await
.map_err(|error| runtime_error(format!("pg_migrate: begin failed: {error}")))?;
(&mut *tx)
.execute(AssertSqlSafe(sql))
.await
.map_err(|error| {
runtime_error(format!(
"pg_migrate: applying {} ({}): {error}",
migration.version, migration.name
))
})?;
sqlx_insert_row(&mut *tx, table, migration, &checksum).await?;
tx.commit().await.map_err(|error| {
runtime_error(format!(
"pg_migrate: commit {} ({}) failed: {error}",
migration.version, migration.name
))
})?;
}
let elapsed_nanos = start.elapsed().as_nanos() as i64;
let update = format!("UPDATE \"{table}\" SET execution_time = $1 WHERE version = $2");
sqlx_core::query::query::<sqlx_postgres::Postgres>(AssertSqlSafe(update))
.bind(elapsed_nanos)
.bind(migration.version)
.execute(&mut *conn)
.await
.map_err(|error| {
runtime_error(format!(
"pg_migrate: record execution_time for {} failed: {error}",
migration.version
))
})?;
Ok(())
}
async fn sqlx_insert_row<'c, E>(
executor: E,
table: &str,
migration: &SqlxMigration,
checksum: &[u8],
) -> Result<(), VmError>
where
E: Executor<'c, Database = sqlx_postgres::Postgres>,
{
let insert = format!(
"INSERT INTO \"{table}\" (version, description, success, checksum, execution_time) \
VALUES ($1, $2, TRUE, $3, -1)"
);
sqlx_core::query::query::<sqlx_postgres::Postgres>(AssertSqlSafe(insert))
.bind(migration.version)
.bind(migration.description.clone())
.bind(checksum.to_vec())
.execute(executor)
.await
.map_err(|error| {
runtime_error(format!(
"pg_migrate: record {} ({}) failed: {error}",
migration.version, migration.name
))
})?;
Ok(())
}
async fn sqlx_lock_id(conn: &mut PgConnection) -> Result<i64, VmError> {
let row = sqlx_core::query::query::<sqlx_postgres::Postgres>("SELECT current_database()")
.fetch_one(conn)
.await
.map_err(|error| {
runtime_error(format!("pg_migrate: current_database() failed: {error}"))
})?;
let database_name = row.get::<String, _>(0);
Ok(generate_sqlx_lock_id(&database_name))
}
/// Derives the advisory-lock key for `database_name`: the constant
/// `0x3d32ad9e` multiplied by the CRC-32 of the name's bytes.
///
/// NOTE(review): the unit test name suggests this mirrors the formula sqlx's
/// own migrator uses — confirm against the sqlx version in use.
fn generate_sqlx_lock_id(database_name: &str) -> i64 {
    let crc = i64::from(crc32_iso_hdlc(database_name.as_bytes()));
    0x3d32ad9e * crc
}
/// Computes a bitwise CRC-32 of `bytes` using the reflected polynomial
/// `0xEDB8_8320`, an all-ones initial value and a final bit inversion.
///
/// The empty input yields `0`, and `b"123456789"` yields `0xCBF4_3926`.
fn crc32_iso_hdlc(bytes: &[u8]) -> u32 {
    const POLY: u32 = 0xEDB8_8320;
    let register = bytes.iter().fold(u32::MAX, |acc, &byte| {
        // Mix the byte in, then shift out eight bits one at a time.
        (0..8).fold(acc ^ u32::from(byte), |crc, _| {
            if crc & 1 == 1 {
                (crc >> 1) ^ POLY
            } else {
                crc >> 1
            }
        })
    });
    !register
}
/// Returns the 48-byte SHA-384 digest of the UTF-8 bytes of `text`.
fn sha384(text: &str) -> Vec<u8> {
    use sha2::{Digest, Sha384};
    let mut hasher = Sha384::new();
    hasher.update(text.as_bytes());
    hasher.finalize().to_vec()
}
fn sha384_file(path: &Path) -> Result<Vec<u8>, VmError> {
let sql = std::fs::read_to_string(path).map_err(|error| {
runtime_error(format!(
"pg_migrate: could not read {}: {error}",
path.display()
))
})?;
Ok(sha384(&sql))
}
// Unit tests for the pure helpers of the "sqlx" ledger path; none of them
// opens a database connection.
#[cfg(test)]
mod tests {
use super::*;
// Pins the CRC against the "123456789" check value and the empty input.
#[test]
fn crc32_iso_hdlc_matches_check_vector() {
assert_eq!(crc32_iso_hdlc(b"123456789"), 0xCBF4_3926);
assert_eq!(crc32_iso_hdlc(b""), 0x0000_0000);
}
// The lock id must be the constant multiplied by the CRC of the database name.
#[test]
fn sqlx_lock_id_matches_sqlx_formula() {
let db = "harn_cloud";
let expected = 0x3d32ad9e_i64 * (crc32_iso_hdlc(db.as_bytes()) as i64);
assert_eq!(generate_sqlx_lock_id(db), expected);
assert!(generate_sqlx_lock_id("postgres") != 0);
}
// The two ledgers use different digests: 48 bytes (SHA-384) vs 32 bytes (SHA-256).
#[test]
fn sha384_is_48_bytes_and_differs_from_sha256() {
let sql = "CREATE TABLE t (id INT);";
let s384 = sha384(sql);
let s256 = sha256(sql);
assert_eq!(s384.len(), 48);
assert_eq!(s256.len(), 32);
assert_ne!(s384[..32], s256[..]);
}
// Versions sort as integers (9 < 100), `.down.sql` and non-SQL files are
// ignored, and descriptions drop the suffix and turn `_` into spaces.
#[test]
fn discover_sqlx_parses_versions_descriptions_and_sorts_numerically() {
let tmp = tempfile::tempdir().unwrap();
let dir = tmp.path();
std::fs::write(dir.join("20260419170000_bootstrap.up.sql"), "SELECT 1").unwrap();
std::fs::write(dir.join("20260419170000_bootstrap.down.sql"), "SELECT 0").unwrap();
std::fs::write(dir.join("9_early_thing.up.sql"), "SELECT 2").unwrap();
std::fs::write(dir.join("100_later_thing.sql"), "SELECT 3").unwrap();
std::fs::write(dir.join("README.md"), "ignore me").unwrap();
let migrations = discover_sqlx_migrations(dir).expect("discover");
let versions: Vec<i64> = migrations.iter().map(|m| m.version).collect();
assert_eq!(versions, vec![9, 100, 20260419170000]);
assert_eq!(migrations[0].description, "early thing");
assert_eq!(migrations[1].description, "later thing");
assert_eq!(migrations[2].description, "bootstrap");
assert!(migrations.iter().all(|m| !m.name.ends_with(".down.sql")));
}
// A `.sql` file whose prefix is not an integer is an error, not a silent skip.
#[test]
fn discover_sqlx_errors_on_non_integer_prefix() {
let tmp = tempfile::tempdir().unwrap();
let dir = tmp.path();
std::fs::write(dir.join("notanumber_thing.up.sql"), "SELECT 1").unwrap();
let err = discover_sqlx_migrations(dir).unwrap_err();
assert!(
format!("{err:?}").contains("integer version prefix"),
"unexpected error: {err:?}"
);
}
// With two files sharing a version, the one that sorts first by name wins.
#[test]
fn discover_sqlx_dedupes_duplicate_versions() {
let tmp = tempfile::tempdir().unwrap();
let dir = tmp.path();
std::fs::write(dir.join("5_first.up.sql"), "SELECT 1").unwrap();
std::fs::write(dir.join("5_second.up.sql"), "SELECT 2").unwrap();
let migrations = discover_sqlx_migrations(dir).expect("discover");
assert_eq!(migrations.len(), 1);
assert_eq!(migrations[0].version, 5);
assert_eq!(migrations[0].name, "5_first.up.sql");
}
}