use super::*;
use thiserror::Error;
/// Failures that can occur while verifying or upgrading the schema in [`migrate`].
#[derive(Error, Debug)]
pub enum Error {
    /// An underlying SQLite error. `#[error(transparent)]` forwards the wrapped
    /// error's message and source unchanged; `#[from]` lets `?` convert it.
    #[error(transparent)]
    Sqlite(#[from] rusqlite::Error),
    /// A migration that this build expects the database to have applied (its
    /// sequence number is at or below the highest one recorded) has no row with
    /// the same file name in `schema_versions`. Carries the expected file name.
    #[error("Database does not contain historic migration '{0}'")]
    MissingHistoric(String),
    /// `schema_versions` records a file name that is not among this build's
    /// migrations up to the highest recorded sequence number. Carries the
    /// recorded file name.
    #[error("Database contains unexpected historic migration '{0}'")]
    ExtraHistoric(String),
    /// A recorded migration's hash or file name differs from this build's
    /// migration with the same sequence number. Carries the file name as
    /// recorded in the database.
    #[error("Historic migration '{0}' has a different hash")]
    AlteredHistoric(String),
    /// Migrations are pending, but upgrading was not permitted by the caller.
    #[error("Database schema requires updating")]
    UpdateRequired,
}
#[cfg_attr(feature = "instrument", instrument(skip(conn)))]
pub fn migrate(conn: &mut rusqlite::Connection, upgrade: bool) -> Result<(), Error> {
let migrations = include!(concat!(env!("OUT_DIR"), "/migrations.rs"));
let trans = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)?;
trans.execute_batch(
r"
CREATE TABLE IF NOT EXISTS schema_versions (
seq_no INTEGER UNIQUE NOT NULL,
file_name TEXT UNIQUE NOT NULL,
hash BLOB NOT NULL,
timestamp TEXT NOT NULL
)",
)?;
let mut next = 0;
if let Some(Some::<isize>(current_max)) = trans
.query_row(r"SELECT max(seq_no) FROM schema_versions", [], |row| {
row.get(0)
})
.optional()?
{
trans.execute_batch(
r"
CREATE TEMPORARY TABLE temp.schema_check (
seq_no INTEGER UNIQUE NOT NULL,
file_name TEXT NOT NULL,
hash BLOB NOT NULL
)",
)?;
let mut query = trans
.prepare(r"INSERT INTO temp.schema_check (seq_no,file_name,hash) VALUES (?1,?2,?3)")?;
for (i, (seq, file_name, hash, _)) in migrations.iter().enumerate() {
next = i + 1;
if *seq <= current_max {
query.execute((seq, file_name, hash))?;
} else {
break;
}
}
if let Some(file_name) = trans
.query_row(
r"
SELECT file_name FROM temp.schema_check AS sc
WHERE sc.file_name NOT IN (
SELECT file_name FROM schema_versions
)",
[],
|row| row.get(0),
)
.optional()?
{
Err(Error::MissingHistoric(file_name))?;
}
if let Some(file_name) = trans
.query_row(
r"
SELECT file_name FROM schema_versions AS sv
WHERE sv.file_name NOT IN (
SELECT file_name FROM temp.schema_check
)",
[],
|row| row.get(0),
)
.optional()?
{
Err(Error::ExtraHistoric(file_name))?;
}
if let Some(file_name) = trans
.query_row(
r"
SELECT sv.file_name FROM temp.schema_check AS sc
JOIN schema_versions AS sv ON sc.seq_no = sv.seq_no
WHERE sc.hash != sv.hash OR sc.file_name != sv.file_name
",
[],
|row| row.get(0),
)
.optional()?
{
Err(Error::AlteredHistoric(file_name))?;
}
trans.execute_batch("DROP TABLE temp.schema_check")?;
}
if next < migrations.len() {
if upgrade {
for (seq, file_name, hash, migration) in migrations[next..].iter() {
trans.execute_batch(migration)?;
trans.execute(r"INSERT INTO schema_versions (seq_no,file_name,hash,timestamp) VALUES (?1,?2,?3,datetime('now'))",(seq, file_name, hash))?;
}
} else {
Err(Error::UpdateRequired)?;
}
}
trans.commit()?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a brand-new, empty in-memory database.
    fn open_memory_db() -> rusqlite::Connection {
        rusqlite::Connection::open_in_memory().unwrap()
    }

    /// Opens a new in-memory database and runs a full upgrade on it.
    fn fully_migrated_db() -> rusqlite::Connection {
        let mut db = open_memory_db();
        migrate(&mut db, true).unwrap();
        db
    }

    /// Runs a query whose first column of its first row is an integer, and returns it.
    fn scalar_i64(db: &rusqlite::Connection, sql: &str) -> i64 {
        db.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    /// Number of rows currently stored in `schema_versions`.
    fn recorded_versions(db: &rusqlite::Connection) -> i64 {
        scalar_i64(db, "SELECT count(*) FROM schema_versions")
    }

    /// Executes a raw statement outside of `migrate`, used to corrupt the recorded
    /// history so the verification paths can be exercised.
    fn tamper(db: &rusqlite::Connection, sql: &str) {
        db.execute(sql, []).unwrap();
    }

    #[test]
    fn test_migration_creates_schema() {
        let db = fully_migrated_db();
        assert!(
            recorded_versions(&db) > 0,
            "schema_versions should have migration records"
        );
        let bundles = scalar_i64(
            &db,
            "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='bundles'",
        );
        assert_eq!(bundles, 1, "bundles table should exist");
    }

    #[test]
    fn test_migration_reopen_is_noop() {
        let mut db = fully_migrated_db();
        let before = recorded_versions(&db);
        migrate(&mut db, true).unwrap();
        assert_eq!(before, recorded_versions(&db));
    }

    #[test]
    fn test_migration_upgrade_required() {
        let mut db = open_memory_db();
        let result = migrate(&mut db, false);
        assert!(
            matches!(result, Err(Error::UpdateRequired)),
            "should return UpdateRequired when upgrade=false on fresh DB"
        );
    }

    #[test]
    fn test_migration_detects_missing_historic() {
        let mut db = fully_migrated_db();
        tamper(
            &db,
            "UPDATE schema_versions SET file_name = 'renamed.sql' WHERE seq_no = (SELECT min(seq_no) FROM schema_versions)",
        );
        let result = migrate(&mut db, true);
        assert!(
            matches!(result, Err(Error::MissingHistoric(_))),
            "should detect missing historic migration: {result:?}"
        );
    }

    #[test]
    fn test_migration_detects_extra_historic() {
        let mut db = fully_migrated_db();
        tamper(
            &db,
            "INSERT INTO schema_versions (seq_no, file_name, hash, timestamp) VALUES (999, 'fake.sql', X'00', datetime('now'))",
        );
        let result = migrate(&mut db, true);
        assert!(
            matches!(result, Err(Error::ExtraHistoric(_))),
            "should detect extra historic migration: {result:?}"
        );
    }

    #[test]
    fn test_migration_detects_altered_historic() {
        let mut db = fully_migrated_db();
        tamper(
            &db,
            "UPDATE schema_versions SET hash = X'DEADBEEF' WHERE seq_no = (SELECT min(seq_no) FROM schema_versions)",
        );
        let result = migrate(&mut db, true);
        assert!(
            matches!(result, Err(Error::AlteredHistoric(_))),
            "should detect altered historic migration: {result:?}"
        );
    }
}