use sqlx::postgres::PgPool;
use sqlx::Row;
/// One versioned PostgreSQL schema migration.
#[derive(Clone, Debug)]
pub struct PgMigration {
    /// Version number; a migration is pending while this exceeds the highest
    /// version recorded in `schema_versions`.
    pub version: u32,
    /// Human-readable summary, stored next to the version once applied.
    pub description: &'static str,
    /// SQL to execute; may contain several statements separated by `;`.
    pub sql: &'static str,
}
/// Migrations bundled with this module; [`PgMigrationRunner::new`] uses this list.
///
/// Entries are listed in ascending `version` order. Each `sql` value may hold
/// several statements: the runner splits it on `;`, so a `;` may only appear as
/// a statement terminator. Inside the multi-line literal, a line ending in `\`
/// is a Rust string continuation: the newline and the next line's leading
/// whitespace are dropped, so e.g. `);` is immediately followed by `CREATE INDEX`.
pub static BUILTIN_PG_MIGRATIONS: &[PgMigration] = &[
// Version 1: `tasks` table plus single-column indexes on `context_id` and `state`.
PgMigration {
version: 1,
description: "Initial schema: tasks table with indexes",
sql: "\
CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
context_id TEXT NOT NULL,
state TEXT NOT NULL,
data JSONB NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);\
CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id);\
CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state)",
},
// Version 2: composite index for queries filtering on both `context_id` and `state`.
// NOTE(review): with `context_id` as the leading column this index can presumably also
// serve context_id-only lookups, which may make `idx_tasks_context_id` redundant — confirm
// with EXPLAIN before dropping anything.
PgMigration {
version: 2,
description: "Add composite index on (context_id, state) for combined filter queries",
sql: "CREATE INDEX IF NOT EXISTS idx_tasks_context_id_state ON tasks(context_id, state)",
},
];
/// Applies [`PgMigration`]s to a PostgreSQL database, recording each applied
/// version in a `schema_versions` table.
#[derive(Clone, Debug)]
pub struct PgMigrationRunner {
    /// Pool every runner query is executed against.
    pool: PgPool,
    /// The full set of migrations this runner knows about.
    migrations: &'static [PgMigration],
}
impl PgMigrationRunner {
#[must_use]
pub fn new(pool: PgPool) -> Self {
Self {
pool,
migrations: BUILTIN_PG_MIGRATIONS,
}
}
#[must_use]
pub const fn with_migrations(pool: PgPool, migrations: &'static [PgMigration]) -> Self {
Self { pool, migrations }
}
async fn ensure_version_table(&self) -> Result<(), sqlx::Error> {
sqlx::query(
"CREATE TABLE IF NOT EXISTS schema_versions (
version INTEGER PRIMARY KEY,
description TEXT NOT NULL,
applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
)",
)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn current_version(&self) -> Result<u32, sqlx::Error> {
self.ensure_version_table().await?;
let row = sqlx::query("SELECT COALESCE(MAX(version), 0) AS v FROM schema_versions")
.fetch_one(&self.pool)
.await?;
let version: i32 = row.get("v");
#[allow(clippy::cast_sign_loss)]
Ok(version as u32)
}
pub async fn pending_migrations(&self) -> Result<Vec<&PgMigration>, sqlx::Error> {
let current = self.current_version().await?;
Ok(self
.migrations
.iter()
.filter(|m| m.version > current)
.collect())
}
pub async fn run_pending(&self) -> Result<Vec<u32>, sqlx::Error> {
self.ensure_version_table().await?;
let current = self.current_version().await?;
let mut applied = Vec::new();
for migration in self.migrations {
if migration.version <= current {
continue;
}
let mut tx = self.pool.begin().await?;
sqlx::query("LOCK TABLE schema_versions IN EXCLUSIVE MODE")
.execute(&mut *tx)
.await?;
let row = sqlx::query("SELECT COALESCE(MAX(version), 0) AS v FROM schema_versions")
.fetch_one(&mut *tx)
.await?;
let current_in_tx: i32 = row.get("v");
#[allow(clippy::cast_sign_loss)]
if migration.version <= current_in_tx as u32 {
tx.rollback().await?;
continue;
}
for statement in migration.sql.split(';') {
let trimmed = statement.trim();
if trimmed.is_empty() {
continue;
}
sqlx::query(trimmed).execute(&mut *tx).await?;
}
#[allow(clippy::cast_possible_wrap)] let version_i32 = migration.version as i32;
sqlx::query("INSERT INTO schema_versions (version, description) VALUES ($1, $2)")
.bind(version_i32)
.bind(migration.description)
.execute(&mut *tx)
.await?;
tx.commit().await?;
applied.push(migration.version);
}
Ok(applied)
}
}