//! scitadel-db 0.7.0
//!
//! SQLite-backed repositories and migrations for scitadel.
use rusqlite::Connection;

use crate::error::DbError;

// SQL migration scripts, embedded into the binary at compile time via `include_str!`.
const MIGRATION_001: &str = include_str!("../../migrations/001_initial.sql");
const MIGRATION_002: &str = include_str!("../../migrations/002_citations.sql");
const MIGRATION_003: &str = include_str!("../../migrations/003_full_text.sql");
const MIGRATION_004: &str = include_str!("../../migrations/004_paper_state.sql");
const MIGRATION_005: &str = include_str!("../../migrations/005_annotations.sql");
const MIGRATION_006: &str = include_str!("../../migrations/006_search_fts.sql");
const MIGRATION_007: &str = include_str!("../../migrations/007_paper_download_state.sql");
const MIGRATION_008: &str = include_str!("../../migrations/008_tui_state.sql");
const MIGRATION_009: &str = include_str!("../../migrations/009_bibtex_keys.sql");
const MIGRATION_010: &str = include_str!("../../migrations/010_shortlists.sql");
const MIGRATION_011: &str = include_str!("../../migrations/011_paper_aliases.sql");
const MIGRATION_012: &str = include_str!("../../migrations/012_paper_tags.sql");

/// Every known migration as `(version, sql)`, in the order it must be applied.
///
/// Versions must be strictly increasing: [`run_migrations`] applies entries in
/// slice order, and `schema_version.version` is a primary key, so a duplicate or
/// out-of-order entry is a bug. The `const` assertion below enforces this at
/// compile time.
const MIGRATIONS: &[(i64, &str)] = &[
    (1, MIGRATION_001),
    (2, MIGRATION_002),
    (3, MIGRATION_003),
    (4, MIGRATION_004),
    (5, MIGRATION_005),
    (6, MIGRATION_006),
    (7, MIGRATION_007),
    (8, MIGRATION_008),
    (9, MIGRATION_009),
    (10, MIGRATION_010),
    (11, MIGRATION_011),
    (12, MIGRATION_012),
];

// Compile-time guard: the build fails if a newly added migration repeats a
// version or is listed out of order.
const _: () = {
    let mut i = 1;
    while i < MIGRATIONS.len() {
        assert!(
            MIGRATIONS[i - 1].0 < MIGRATIONS[i].0,
            "MIGRATIONS versions must be strictly increasing"
        );
        i += 1;
    }
};

/// Run all pending migrations, skipping already-applied ones.
///
/// Steps:
/// 1. Ensure the `schema_version` bookkeeping table exists.
/// 2. Read the set of versions already recorded in it.
/// 3. Execute, in slice order, every entry of [`MIGRATIONS`] whose version is
///    not yet recorded.
///
/// After each migration script succeeds, its version is recorded in
/// `schema_version`, so a later call skips it. Calling this repeatedly on the
/// same database is therefore safe.
///
/// Each script runs through `execute_batch` with no enclosing transaction
/// added here. A script that fails partway may leave its earlier statements
/// applied, unless the script wraps itself in a transaction.
///
/// # Errors
///
/// Returns [`DbError::Migration`] in any of these cases:
/// - the bookkeeping table cannot be created or read;
/// - a migration script fails;
/// - a migration's version cannot be recorded.
///
/// Migrations that completed before the failure stay applied and recorded.
pub fn run_migrations(conn: &Connection) -> Result<(), DbError> {
    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS schema_version (
            version INTEGER PRIMARY KEY,
            applied_at TEXT NOT NULL
        )",
    )
    .map_err(|e| DbError::Migration(e.to_string()))?;

    let applied: std::collections::HashSet<i64> = {
        let mut stmt = conn
            .prepare("SELECT version FROM schema_version")
            .map_err(|e| DbError::Migration(e.to_string()))?;
        let rows = stmt
            .query_map([], |row| row.get(0))
            .map_err(|e| DbError::Migration(e.to_string()))?;
        // Propagate row errors. Silently dropping a row would make an
        // already-applied migration look pending and re-run it.
        rows.collect::<Result<std::collections::HashSet<i64>, _>>()
            .map_err(|e| DbError::Migration(e.to_string()))?
    };

    for &(version, sql) in MIGRATIONS {
        if applied.contains(&version) {
            continue;
        }
        conn.execute_batch(sql)
            .map_err(|e| DbError::Migration(format!("migration {version} failed: {e}")))?;
        // Record the version so the next call skips this migration. `OR IGNORE`
        // keeps this harmless if the script already inserted its own row.
        conn.execute(
            "INSERT OR IGNORE INTO schema_version (version, applied_at)
             VALUES (?1, datetime('now'))",
            [version],
        )
        .map_err(|e| {
            DbError::Migration(format!("recording migration {version} failed: {e}"))
        })?;
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every table the full migration set is expected to leave behind.
    const EXPECTED_TABLES: [&str; 15] = [
        "papers",
        "searches",
        "search_results",
        "research_questions",
        "search_terms",
        "assessments",
        "citations",
        "snowball_runs",
        "paper_state",
        "annotations",
        "annotation_reads",
        "searches_fts",
        "paper_aliases",
        "paper_tags",
        "schema_version",
    ];

    /// Open a fresh in-memory database and apply every migration once.
    fn migrated_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        run_migrations(&db).unwrap();
        db
    }

    #[test]
    fn test_migrations_idempotent() {
        let db = migrated_db();
        // A second pass over an already-migrated database must also succeed.
        run_migrations(&db).unwrap();
    }

    #[test]
    fn test_all_tables_created() {
        let db = migrated_db();

        let mut query = db
            .prepare("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
            .unwrap();
        let present: Vec<String> = query
            .query_map([], |row| row.get(0))
            .unwrap()
            .filter_map(Result::ok)
            .collect();

        for wanted in EXPECTED_TABLES {
            assert!(
                present.iter().any(|name| name.as_str() == wanted),
                "missing table {wanted}"
            );
        }
    }
}