use rusqlite::{Connection, Error as RusqliteError, Result as SqliteResult};
use std::fmt;
/// Signature of a single schema-migration step: it is handed the open
/// connection and reports failure through the rusqlite error type.
type MigrationFn = for<'conn> fn(&'conn Connection) -> SqliteResult<()>;
/// Migration failures that are not plain SQLite errors.
///
/// Returned from `run_migrations` after conversion into a rusqlite error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationError {
    /// The database's `PRAGMA user_version` is higher than the number of
    /// migrations this binary knows how to apply.
    UnsupportedVersion {
        /// Version read from the database.
        current_version: i32,
        /// Highest version this binary can migrate to.
        max_supported: i32,
    },
}
impl fmt::Display for MigrationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MigrationError::UnsupportedVersion {
current_version,
max_supported,
} => write!(
f,
"Database schema version {} is newer than this vipune binary supports (max: {}). Upgrade vipune.",
current_version, max_supported
),
}
}
}
impl std::error::Error for MigrationError {}
impl From<MigrationError> for RusqliteError {
fn from(err: MigrationError) -> Self {
RusqliteError::ToSqlConversionFailure(Box::new(err))
}
}
/// Migration step to schema version 1: makes no schema changes and always succeeds.
///
/// NOTE(review): presumably the version-1 tables are created elsewhere and this
/// step exists only so version 1 gets recorded — confirm.
fn migrate_v1(_: &Connection) -> SqliteResult<()> {
    Ok(())
}
/// Migration step to schema version 2.
///
/// Adds `type` (default `'fact'`), `status` (default `'active'`) and a nullable
/// `superseded_by` column to `memories`, plus indexes on `type`, `status` and
/// `(project_id, status)`. The `ALTER TABLE` statements fail if a column
/// already exists, so this step must not be applied twice.
fn migrate_v2(conn: &Connection) -> SqliteResult<()> {
    const V2_SCHEMA_SQL: &str = "ALTER TABLE memories ADD COLUMN type TEXT NOT NULL DEFAULT 'fact';
ALTER TABLE memories ADD COLUMN status TEXT NOT NULL DEFAULT 'active';
ALTER TABLE memories ADD COLUMN superseded_by TEXT;
CREATE INDEX IF NOT EXISTS idx_memories_type ON memories(type);
CREATE INDEX IF NOT EXISTS idx_memories_status ON memories(status);
CREATE INDEX IF NOT EXISTS idx_memories_project_status ON memories(project_id, status);";
    conn.execute_batch(V2_SCHEMA_SQL)
}
/// All migration steps in order: the step at index `i` upgrades the schema to
/// version `i + 1`.
///
/// Only ever append new steps. A step's position is its version number, and
/// databases record in `user_version` how many steps they have applied.
fn migrations() -> Vec<MigrationFn> {
    let steps: [MigrationFn; 2] = [migrate_v1, migrate_v2];
    Vec::from(steps)
}
/// Highest schema version this binary can migrate a database to
/// (one version per migration step).
fn total_migrations() -> i32 {
    // `user_version` is read as an i32, and the step count is tiny, so the
    // narrowing cast cannot truncate in practice.
    let step_count = migrations().len();
    step_count as i32
}
/// Brings the database schema up to the newest version this binary knows.
///
/// The schema version is stored in SQLite's `PRAGMA user_version`. The step
/// at index `i` of `migrations()` upgrades the schema to version `i + 1`.
/// Every step newer than the stored version is applied in order, each in its
/// own `BEGIN EXCLUSIVE` transaction. The schema change and the
/// `user_version` bump are committed together or not at all.
///
/// A database that is already current is left untouched, and no transaction
/// is opened.
///
/// # Errors
///
/// * `MigrationError::UnsupportedVersion`, converted into a rusqlite error,
///   when the stored version is newer than the number of known migrations.
///   The database is not modified in that case.
/// * Any SQLite error from reading or writing `user_version`, from
///   `BEGIN`/`COMMIT`, or from a migration step. The failing step is rolled
///   back; steps that committed before it are kept.
pub fn run_migrations(conn: &Connection) -> SqliteResult<()> {
    let max_supported = total_migrations();
    let current = read_user_version(conn)?;
    if current > max_supported {
        return Err(MigrationError::UnsupportedVersion {
            current_version: current,
            max_supported,
        }
        .into());
    }
    for (version, migration) in (1..).zip(migrations()) {
        if version > current {
            apply_migration(conn, version, migration)?;
        }
    }
    Ok(())
}

/// Reads the schema version stored in `PRAGMA user_version`.
fn read_user_version(conn: &Connection) -> SqliteResult<i32> {
    conn.pragma_query_value(None, "user_version", |row| row.get(0))
}

/// Runs one migration step inside an exclusive transaction.
///
/// The transaction is always closed before returning: it is committed on
/// success and rolled back on any failure, including a failed `COMMIT`.
fn apply_migration<F>(conn: &Connection, version: i32, migration: F) -> SqliteResult<()>
where
    F: FnOnce(&Connection) -> SqliteResult<()>,
{
    conn.execute_batch("BEGIN EXCLUSIVE;")?;
    let outcome = migrate_in_transaction(conn, version, migration)
        .and_then(|()| conn.execute_batch("COMMIT;"));
    if outcome.is_err() {
        // Best-effort cleanup. The step's own error is the one worth reporting,
        // so a ROLLBACK failure (for example, if SQLite already rolled the
        // transaction back itself) is deliberately ignored.
        let _ = conn.execute_batch("ROLLBACK;");
    }
    outcome
}

/// Body of one migration step. Must be called with a transaction already open.
fn migrate_in_transaction<F>(conn: &Connection, version: i32, migration: F) -> SqliteResult<()>
where
    F: FnOnce(&Connection) -> SqliteResult<()>,
{
    // Re-check while holding the exclusive lock. Another connection may have
    // applied this step after `run_migrations` first read the version.
    if read_user_version(conn)? >= version {
        return Ok(());
    }
    migration(conn)?;
    // Record the new version inside the same transaction. A crash can then
    // never leave the step applied but unrecorded; re-running a
    // non-idempotent ALTER TABLE would fail on every later start.
    conn.pragma_update(None, "user_version", version)
}
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_db() -> Connection {
        Connection::open_in_memory().unwrap()
    }

    /// Creates the version-0 `memories` table that the migrations operate on.
    fn init_schema(conn: &Connection) -> SqliteResult<()> {
        conn.execute_batch(
            r#"
            CREATE TABLE IF NOT EXISTS memories (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                content TEXT NOT NULL,
                embedding BLOB NOT NULL,
                metadata TEXT,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            );
            "#,
        )?;
        Ok(())
    }

    /// Reads `PRAGMA user_version`, panicking on SQLite errors.
    fn user_version(conn: &Connection) -> i32 {
        conn.pragma_query_value(None, "user_version", |r| r.get(0))
            .unwrap()
    }

    /// Lists the column names of `table` via `PRAGMA table_info`.
    fn column_names(conn: &Connection, table: &str) -> Vec<String> {
        let mut stmt = conn
            .prepare(&format!("PRAGMA table_info({})", table))
            .unwrap();
        // Column 1 of `table_info` is the column name.
        let rows = stmt.query_map([], |row| row.get::<_, String>(1)).unwrap();
        let mut names = Vec::new();
        for name in rows {
            names.push(name.unwrap());
        }
        names
    }

    #[test]
    fn test_fresh_db_version_becomes_2() {
        let conn = create_test_db();
        init_schema(&conn).unwrap();
        assert_eq!(user_version(&conn), 0);
        run_migrations(&conn).unwrap();
        assert_eq!(user_version(&conn), 2);
        let columns = column_names(&conn, "memories");
        for expected in ["type", "status", "superseded_by"] {
            assert!(
                columns.iter().any(|c| c == expected),
                "missing column {}",
                expected
            );
        }
    }

    #[test]
    fn test_already_at_version_2_is_noop() {
        let conn = create_test_db();
        init_schema(&conn).unwrap();
        run_migrations(&conn).unwrap();
        assert_eq!(user_version(&conn), 2);
        run_migrations(&conn).unwrap();
        assert_eq!(user_version(&conn), 2);
    }

    #[test]
    fn test_upgrade_from_v0_to_v2() {
        let conn = create_test_db();
        init_schema(&conn).unwrap();
        conn.pragma_update(None, "user_version", 0).unwrap();
        assert_eq!(user_version(&conn), 0);
        run_migrations(&conn).unwrap();
        assert_eq!(user_version(&conn), 2);
    }

    #[test]
    fn test_migration_framework_idempotent() {
        let conn = create_test_db();
        init_schema(&conn).unwrap();
        for _ in 0..5 {
            run_migrations(&conn).unwrap();
        }
        assert_eq!(user_version(&conn), 2);
    }

    #[test]
    fn test_migration_transaction_rollback_on_error() {
        let conn = create_test_db();
        // A pre-existing `status` column makes migrate_v2's second ALTER TABLE
        // fail after its first one (adding `type`) has already executed.
        conn.execute_batch(
            "CREATE TABLE memories (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                content TEXT NOT NULL,
                embedding BLOB NOT NULL,
                metadata TEXT,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                status TEXT
            );",
        )
        .unwrap();
        assert_eq!(user_version(&conn), 0);

        assert!(run_migrations(&conn).is_err());

        // Step 1 committed. Step 2 was rolled back entirely, including the
        // `type` column its first statement added. No transaction is left open.
        assert_eq!(user_version(&conn), 1);
        assert!(!column_names(&conn, "memories").iter().any(|c| c == "type"));
        assert!(conn.is_autocommit());
    }

    #[test]
    fn test_future_version_database_error() {
        let conn = create_test_db();
        init_schema(&conn).unwrap();
        conn.pragma_update(None, "user_version", 999).unwrap();
        let result = run_migrations(&conn);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("schema version"));
        assert!(err_msg.contains("999"));
        assert!(err_msg.contains("Upgrade vipune"));
        assert_eq!(user_version(&conn), 999);
    }
}