use sqlx::sqlite::SqlitePool;
use sqlx::Row;
/// A single versioned schema change.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Version number recorded in `schema_versions` once this migration is applied. The
    /// runner only applies a migration whose version is greater than the highest version
    /// already recorded.
    pub version: u32,
    /// Human-readable summary, stored next to the version in `schema_versions.description`.
    pub description: &'static str,
    /// SQL to execute. Several statements may be separated by `;`. The runner splits on every
    /// `;`, so a semicolon must not appear inside a string literal, comment, or trigger body.
    pub sql: &'static str,
}
/// Built-in schema migrations used by [`MigrationRunner::new`], listed in ascending version
/// order. Each entry's `sql` is split on `;` by the runner and executed statement by statement.
pub static BUILTIN_MIGRATIONS: &[Migration] = &[
    Migration {
        version: 1,
        description: "Initial schema: tasks table with context_id and state indexes",
        sql: "CREATE TABLE IF NOT EXISTS tasks (\n\
              id TEXT PRIMARY KEY,\n\
              context_id TEXT NOT NULL,\n\
              state TEXT NOT NULL,\n\
              data TEXT NOT NULL,\n\
              updated_at TEXT NOT NULL DEFAULT (datetime('now'))\n\
              );\n\
              CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id);\n\
              CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state);",
    },
    Migration {
        version: 2,
        description: "Add created_at column to tasks table",
        // SQLite does not allow `ALTER TABLE ... ADD COLUMN` with a parenthesised, non-constant
        // DEFAULT such as `(datetime('now'))` on a table that already contains rows. Older
        // SQLite releases reject it on any table. A populated v1 database could therefore
        // never be upgraded with a plain ADD COLUMN.
        //
        // Instead, rebuild the table. This gives the same column order and defaults and works
        // whether or not `tasks` is empty. Existing rows get their `updated_at` as
        // `created_at`, since that is the only timestamp recorded for them. DROP TABLE removes
        // the v1 indexes, so they are recreated on the renamed table.
        sql: "CREATE TABLE tasks_v2 (\n\
              id TEXT PRIMARY KEY,\n\
              context_id TEXT NOT NULL,\n\
              state TEXT NOT NULL,\n\
              data TEXT NOT NULL,\n\
              updated_at TEXT NOT NULL DEFAULT (datetime('now')),\n\
              created_at TEXT NOT NULL DEFAULT (datetime('now'))\n\
              );\n\
              INSERT INTO tasks_v2 (id, context_id, state, data, updated_at, created_at)\n\
              SELECT id, context_id, state, data, updated_at, updated_at FROM tasks;\n\
              DROP TABLE tasks;\n\
              ALTER TABLE tasks_v2 RENAME TO tasks;\n\
              CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id);\n\
              CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state);",
    },
    Migration {
        version: 3,
        description: "Add composite index on (context_id, state) for combined filter queries",
        sql: "CREATE INDEX IF NOT EXISTS idx_tasks_context_id_state ON tasks(context_id, state);",
    },
];
/// Applies [`Migration`]s to a SQLite database and records every applied version in the
/// `schema_versions` table.
#[derive(Debug, Clone)]
pub struct MigrationRunner {
    /// Pool that both the migration SQL and the version bookkeeping run against.
    pool: SqlitePool,
    /// Candidate migrations. This is [`BUILTIN_MIGRATIONS`] unless the runner was built with
    /// `MigrationRunner::with_migrations`.
    migrations: &'static [Migration],
}
impl MigrationRunner {
#[must_use]
pub fn new(pool: SqlitePool) -> Self {
Self {
pool,
migrations: BUILTIN_MIGRATIONS,
}
}
#[must_use]
pub const fn with_migrations(pool: SqlitePool, migrations: &'static [Migration]) -> Self {
Self { pool, migrations }
}
async fn ensure_version_table(&self) -> Result<(), sqlx::Error> {
sqlx::query(
"CREATE TABLE IF NOT EXISTS schema_versions (
version INTEGER PRIMARY KEY,
description TEXT NOT NULL,
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
)",
)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn current_version(&self) -> Result<u32, sqlx::Error> {
self.ensure_version_table().await?;
let row = sqlx::query("SELECT COALESCE(MAX(version), 0) AS v FROM schema_versions")
.fetch_one(&self.pool)
.await?;
let version: i32 = row.get("v");
#[allow(clippy::cast_sign_loss)]
Ok(version as u32)
}
pub async fn pending_migrations(&self) -> Result<Vec<&Migration>, sqlx::Error> {
let current = self.current_version().await?;
Ok(self
.migrations
.iter()
.filter(|m| m.version > current)
.collect())
}
pub async fn run_pending(&self) -> Result<Vec<u32>, sqlx::Error> {
self.ensure_version_table().await?;
let mut applied = Vec::new();
for migration in self.migrations {
let mut conn = self.pool.acquire().await?;
sqlx::query("BEGIN EXCLUSIVE").execute(&mut *conn).await?;
let row = sqlx::query("SELECT COALESCE(MAX(version), 0) AS v FROM schema_versions")
.fetch_one(&mut *conn)
.await?;
let current: i32 = row.get("v");
#[allow(clippy::cast_sign_loss)]
let current = current as u32;
if migration.version <= current {
sqlx::query("ROLLBACK").execute(&mut *conn).await?;
continue;
}
for statement in migration.sql.split(';') {
let trimmed = statement.trim();
if trimmed.is_empty() {
continue;
}
sqlx::query(trimmed).execute(&mut *conn).await?;
}
sqlx::query("INSERT INTO schema_versions (version, description) VALUES (?1, ?2)")
.bind(migration.version)
.bind(migration.description)
.execute(&mut *conn)
.await?;
sqlx::query("COMMIT").execute(&mut *conn).await?;
applied.push(migration.version);
}
Ok(applied)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::sqlite::SqlitePoolOptions;

    /// Opens a fresh in-memory SQLite database behind a single-connection pool.
    async fn memory_pool() -> SqlitePool {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("failed to open in-memory sqlite")
    }

    /// Projects a list of migrations onto their version numbers, preserving order.
    fn versions_of(migrations: &[&Migration]) -> Vec<u32> {
        migrations.iter().map(|migration| migration.version).collect()
    }

    #[tokio::test]
    async fn current_version_starts_at_zero() {
        let runner = MigrationRunner::new(memory_pool().await);
        let version = runner.current_version().await.unwrap();
        assert_eq!(version, 0);
    }

    #[tokio::test]
    async fn run_pending_applies_all_builtin_migrations() {
        let pool = memory_pool().await;
        let runner = MigrationRunner::new(pool.clone());

        assert_eq!(runner.run_pending().await.unwrap(), vec![1, 2, 3]);
        assert_eq!(runner.current_version().await.unwrap(), 3);

        let columns: Vec<String> = sqlx::query("PRAGMA table_info(tasks)")
            .fetch_all(&pool)
            .await
            .unwrap()
            .iter()
            .map(|info| info.get::<String, _>("name"))
            .collect();
        for expected in ["id", "context_id", "state", "data", "updated_at", "created_at"] {
            assert!(columns.iter().any(|column| column == expected));
        }
    }

    #[tokio::test]
    async fn run_pending_is_idempotent() {
        let runner = MigrationRunner::new(memory_pool().await);
        assert_eq!(runner.run_pending().await.unwrap(), vec![1, 2, 3]);
        assert!(runner.run_pending().await.unwrap().is_empty());
        assert_eq!(runner.current_version().await.unwrap(), 3);
    }

    #[tokio::test]
    async fn pending_migrations_returns_unapplied() {
        let runner = MigrationRunner::new(memory_pool().await);

        let before = runner.pending_migrations().await.unwrap();
        assert_eq!(versions_of(&before), vec![1, 2, 3]);

        runner.run_pending().await.unwrap();
        assert!(runner.pending_migrations().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn partial_application_tracks_correctly() {
        let pool = memory_pool().await;

        let first_only: &'static [Migration] = &BUILTIN_MIGRATIONS[..1];
        let partial = MigrationRunner::with_migrations(pool.clone(), first_only);
        assert_eq!(partial.run_pending().await.unwrap(), vec![1]);
        assert_eq!(partial.current_version().await.unwrap(), 1);

        let full = MigrationRunner::new(pool);
        let remaining = full.pending_migrations().await.unwrap();
        assert_eq!(versions_of(&remaining), vec![2, 3]);
        assert_eq!(full.run_pending().await.unwrap(), vec![2, 3]);
        assert_eq!(full.current_version().await.unwrap(), 3);
    }

    #[tokio::test]
    async fn schema_versions_table_records_metadata() {
        let pool = memory_pool().await;
        MigrationRunner::new(pool.clone())
            .run_pending()
            .await
            .unwrap();

        let recorded = sqlx::query(
            "SELECT version, description, applied_at FROM schema_versions ORDER BY version",
        )
        .fetch_all(&pool)
        .await
        .unwrap();
        assert_eq!(recorded.len(), 3);

        let first = &recorded[0];
        assert_eq!(first.get::<i32, _>("version"), 1);
        assert!(!first.get::<String, _>("description").is_empty());
        assert!(!first.get::<String, _>("applied_at").is_empty());
    }

    #[tokio::test]
    async fn composite_index_exists_after_v3() {
        let pool = memory_pool().await;
        MigrationRunner::new(pool.clone())
            .run_pending()
            .await
            .unwrap();

        let matches = sqlx::query(
            "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_tasks_context_id_state'",
        )
        .fetch_all(&pool)
        .await
        .unwrap();
        assert_eq!(matches.len(), 1);
    }
}