use rusqlite::Connection;
use std::collections::HashSet;
use crate::{Error, MigrateResult, Migration, Seed};
/// Creates the `_migrations` bookkeeping table if it does not already exist.
///
/// Each row holds the id of an applied migration plus an `applied_at` text column that
/// SQLite fills with `CURRENT_TIMESTAMP` when the row is inserted.
///
/// # Errors
///
/// Propagates any error raised while executing the `CREATE TABLE` statement.
fn ensure_migrations_table(conn: &mut Connection) -> MigrateResult<()> {
    let ddl = "CREATE TABLE IF NOT EXISTS _migrations (
id TEXT PRIMARY KEY,
applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
)";
    // The affected-row count of a DDL statement carries no information here.
    let _ = conn.execute(ddl, [])?;
    Ok(())
}
/// Returns the ids of every migration recorded in the `_migrations` table.
///
/// The table must already exist (see `ensure_migrations_table`); otherwise preparing the
/// query fails.
///
/// # Errors
///
/// Fails if the query cannot be prepared, if stepping through the result set fails, or if
/// a stored `id` cannot be read as a `String` (for example a `NULL` id).
///
/// Row-level errors are propagated rather than skipped. Silently dropping an id would make
/// an already-applied migration look pending, and `migrate` would then try to run it again.
fn get_applied_migrations(conn: &Connection) -> MigrateResult<HashSet<String>> {
    let mut statement = conn.prepare("SELECT id FROM _migrations")?;
    // Bind to a local (rather than using a tail expression) so the row iterator, which
    // borrows `statement`, is dropped before `statement` itself.
    let applied_set = statement
        .query_map([], |row| row.get::<_, String>(0))?
        .collect::<Result<HashSet<_>, _>>()?;
    Ok(applied_set)
}
pub fn migrate(conn: &mut Connection, migrations: &[Migration]) -> MigrateResult<()> {
ensure_migrations_table(conn)?;
let applied_migrations = get_applied_migrations(conn)?;
let pending_migrations: Vec<&Migration> = migrations
.iter()
.filter(|m| !applied_migrations.contains(m.id))
.collect();
if pending_migrations.is_empty() {
return Ok(());
}
let tx = conn.transaction()?;
for migration in pending_migrations {
tx.execute_batch(migration.sql)
.map_err(|e| Error::MigrationFailed {
id: migration.id.to_string(),
message: e.to_string(),
})?;
tx.execute("INSERT INTO _migrations(id) VALUES (?)", [migration.id])?;
}
tx.commit()?;
Ok(())
}
/// Creates the `_seeds` bookkeeping table if it does not already exist.
///
/// Each row holds the id of an applied seed plus an `applied_at` text column that SQLite
/// fills with `CURRENT_TIMESTAMP` when the row is inserted.
///
/// # Errors
///
/// Propagates any error raised while executing the `CREATE TABLE` statement.
fn ensure_seeds_table(conn: &mut Connection) -> MigrateResult<()> {
    const SEEDS_DDL: &str = "CREATE TABLE IF NOT EXISTS _seeds (
id TEXT PRIMARY KEY,
applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
)";
    conn.execute(SEEDS_DDL, [])?;
    Ok(())
}
/// Returns the ids of every seed recorded in the `_seeds` table.
///
/// The table must already exist (see `ensure_seeds_table`); otherwise preparing the query
/// fails.
///
/// # Errors
///
/// Fails if the query cannot be prepared, if stepping through the result set fails, or if
/// a stored `id` cannot be read as a `String` (for example a `NULL` id).
///
/// Row-level errors are propagated rather than skipped. Silently dropping an id would make
/// an already-applied seed look pending, and `seed` would run it again and duplicate its data.
fn get_applied_seeds(conn: &Connection) -> MigrateResult<HashSet<String>> {
    let mut statement = conn.prepare("SELECT id FROM _seeds")?;
    // Bind to a local (rather than using a tail expression) so the row iterator, which
    // borrows `statement`, is dropped before `statement` itself.
    let applied_set = statement
        .query_map([], |row| row.get::<_, String>(0))?
        .collect::<Result<HashSet<_>, _>>()?;
    Ok(applied_set)
}
pub fn seed(conn: &mut Connection, seeds: &[Seed]) -> MigrateResult<()> {
ensure_seeds_table(conn)?;
let applied_seeds = get_applied_seeds(conn)?;
let pending_seeds: Vec<&Seed> = seeds
.iter()
.filter(|s| !applied_seeds.contains(s.id))
.collect();
if pending_seeds.is_empty() {
return Ok(());
}
for seed in pending_seeds {
let tx = conn.transaction()?;
(seed.seed_fn)(&tx).map_err(|e| Error::MigrationFailed {
id: seed.id.to_string(),
message: e.to_string(),
})?;
tx.execute("INSERT INTO _seeds(id) VALUES (?)", [seed.id])?;
tx.commit()?;
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Runs a query expected to yield one integer and returns the first row's first column.
    ///
    /// Panics (failing the calling test) if the query errors or returns no rows.
    fn query_i64(conn: &Connection, sql: &str) -> i64 {
        conn.query_row(sql, [], |row| row.get::<_, i64>(0))
            .unwrap_or_else(|e| panic!("query `{}` failed: {}", sql, e))
    }

    #[test]
    fn test_migration_creation() {
        let migration = Migration::new("001_test", "CREATE TABLE test (id INTEGER);");
        assert_eq!(migration.id, "001_test");
        assert_eq!(migration.sql, "CREATE TABLE test (id INTEGER);");
    }

    #[test]
    fn test_ensure_migrations_table() {
        let mut conn = Connection::open_in_memory().unwrap();
        ensure_migrations_table(&mut conn).unwrap();
        let count = query_i64(
            &conn,
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='_migrations'",
        );
        assert_eq!(count, 1);
    }

    #[test]
    fn test_up_migrations() {
        let mut conn = Connection::open_in_memory().unwrap();
        let migrations = &[
            Migration::new(
                "001_create_users",
                "CREATE TABLE users (id INTEGER PRIMARY KEY);",
            ),
            Migration::new("002_add_email", "ALTER TABLE users ADD COLUMN email TEXT;"),
        ];
        migrate(&mut conn, migrations).unwrap();
        let applied = get_applied_migrations(&conn).unwrap();
        assert!(applied.contains("001_create_users"));
        assert!(applied.contains("002_add_email"));
        let count = query_i64(
            &conn,
            "SELECT COUNT(*) FROM pragma_table_info('users') WHERE name='email'",
        );
        assert_eq!(count, 1);
    }

    #[test]
    fn test_up_migrations_only_runs_new_ones() {
        let mut conn = Connection::open_in_memory().unwrap();
        let first = &[Migration::new(
            "001_create_users",
            "CREATE TABLE users (id INTEGER PRIMARY KEY);",
        )];
        migrate(&mut conn, first).unwrap();
        // Re-running 001 would fail with "table users already exists", so success here shows
        // that only the new migration 002 was executed.
        let second = &[
            Migration::new(
                "001_create_users",
                "CREATE TABLE users (id INTEGER PRIMARY KEY);",
            ),
            Migration::new("002_add_email", "ALTER TABLE users ADD COLUMN email TEXT;"),
        ];
        migrate(&mut conn, second).unwrap();
        let applied = get_applied_migrations(&conn).unwrap();
        assert_eq!(applied.len(), 2);
        let count = query_i64(
            &conn,
            "SELECT COUNT(*) FROM pragma_table_info('users') WHERE name='email'",
        );
        assert_eq!(count, 1);
    }

    #[test]
    fn test_up_migrations_idempotency() {
        let mut conn = Connection::open_in_memory().unwrap();
        let migrations = &[Migration::new(
            "001_test",
            "CREATE TABLE test (id INTEGER);",
        )];
        migrate(&mut conn, migrations).unwrap();
        migrate(&mut conn, migrations).unwrap();
        let count = query_i64(&conn, "SELECT COUNT(*) FROM _migrations WHERE id='001_test'");
        assert_eq!(count, 1);
    }

    #[test]
    fn test_migration_failure_rollback() {
        let mut conn = Connection::open_in_memory().unwrap();
        let migrations = &[
            Migration::new("001_valid", "CREATE TABLE test (id INTEGER);"),
            Migration::new("002_invalid", "INVALID SQL STATEMENT;"),
        ];
        let result = migrate(&mut conn, migrations);
        assert!(result.is_err());
        let applied = get_applied_migrations(&conn).unwrap();
        assert!(applied.is_empty());
        let count = query_i64(
            &conn,
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='test'",
        );
        assert_eq!(count, 0);
    }

    #[test]
    fn test_ensure_seeds_table() {
        let mut conn = Connection::open_in_memory().unwrap();
        ensure_seeds_table(&mut conn).unwrap();
        let count = query_i64(
            &conn,
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='_seeds'",
        );
        assert_eq!(count, 1);
    }

    /// Creates `test_users` (if missing) and inserts two rows.
    fn seed_test_data(conn: &Connection) -> MigrateResult<()> {
        conn.execute(
            "CREATE TABLE IF NOT EXISTS test_users (id INTEGER PRIMARY KEY, name TEXT)",
            [],
        )?;
        conn.execute("INSERT INTO test_users (name) VALUES ('Alice')", [])?;
        conn.execute("INSERT INTO test_users (name) VALUES ('Bob')", [])?;
        Ok(())
    }

    /// Inserts one more row into an existing `test_users` table.
    fn seed_more_data(conn: &Connection) -> MigrateResult<()> {
        conn.execute("INSERT INTO test_users (name) VALUES ('Charlie')", [])?;
        Ok(())
    }

    /// Inserts a row and then fails, so the test can check the insert is rolled back.
    fn seed_failing_after_insert(conn: &Connection) -> MigrateResult<()> {
        conn.execute("INSERT INTO test_users (name) VALUES ('Dave')", [])?;
        conn.execute("INVALID SQL STATEMENT", [])?;
        Ok(())
    }

    #[test]
    fn test_seed_execution() {
        let mut conn = Connection::open_in_memory().unwrap();
        let seeds = &[
            Seed::new("001_initial", seed_test_data),
            Seed::new("002_more", seed_more_data),
        ];
        seed(&mut conn, seeds).unwrap();
        let applied = get_applied_seeds(&conn).unwrap();
        assert!(applied.contains("001_initial"));
        assert!(applied.contains("002_more"));
        assert_eq!(query_i64(&conn, "SELECT COUNT(*) FROM test_users"), 3);
    }

    #[test]
    fn test_seed_failure_rolls_back_only_failing_seed() {
        let mut conn = Connection::open_in_memory().unwrap();
        let seeds = &[
            Seed::new("001_initial", seed_test_data),
            Seed::new("002_broken", seed_failing_after_insert),
        ];
        assert!(seed(&mut conn, seeds).is_err());
        // Seeds run in separate transactions: 001 stays committed, 002 is discarded.
        let applied = get_applied_seeds(&conn).unwrap();
        assert!(applied.contains("001_initial"));
        assert!(!applied.contains("002_broken"));
        // 'Dave' from the failing seed must not survive.
        assert_eq!(query_i64(&conn, "SELECT COUNT(*) FROM test_users"), 2);
    }

    #[test]
    fn test_seed_idempotency() {
        let mut conn = Connection::open_in_memory().unwrap();
        let seeds = &[Seed::new("001_test", seed_test_data)];
        seed(&mut conn, seeds).unwrap();
        seed(&mut conn, seeds).unwrap();
        let count = query_i64(&conn, "SELECT COUNT(*) FROM _seeds WHERE id='001_test'");
        assert_eq!(count, 1);
        let user_count = query_i64(&conn, "SELECT COUNT(*) FROM test_users");
        assert_eq!(user_count, 2);
    }
}