// scitadel-db 0.6.0
//
// SQLite-backed repositories and migrations for scitadel.
use rusqlite::{OptionalExtension, params};
use scitadel_core::error::CoreError;
use scitadel_core::models::{
    Citation, CitationDirection, PaperId, QuestionId, SearchId, SnowballRun, SnowballRunId,
};
use scitadel_core::ports::CitationRepository;

use super::Database;
use crate::error::DbError;

/// Insert-or-merge statement for a single citation edge.
///
/// Positional parameters: `?1` source paper id, `?2` target paper id,
/// `?3` direction, `?4` discovered_by, `?5` depth, `?6` snowball run id
/// (may be NULL).
///
/// When a row with the same `(source_paper_id, target_paper_id, direction)`
/// already exists, it is merged rather than replaced:
/// - `depth` keeps the smaller of the stored and incoming values;
/// - `snowball_run_id` takes the incoming value unless that is NULL, in which
///   case the stored value is kept;
/// - `discovered_by` is not in the `SET` list, so the first writer's value
///   persists.
const UPSERT_SQL: &str = "\
    INSERT INTO citations
        (source_paper_id, target_paper_id, direction,
         discovered_by, depth, snowball_run_id)
    VALUES (?1, ?2, ?3, ?4, ?5, ?6)
    ON CONFLICT(source_paper_id, target_paper_id, direction) DO UPDATE SET
        depth = MIN(citations.depth, excluded.depth),
        snowball_run_id = COALESCE(excluded.snowball_run_id,
                                   citations.snowball_run_id)";

/// [`CitationRepository`] implementation that stores citation edges and
/// snowball runs through a shared SQLite [`Database`] handle.
pub struct SqliteCitationRepository {
    /// Handle every query in this repository is issued through.
    db: Database,
}

impl SqliteCitationRepository {
    /// Wraps `db` so citation data can be read and written through it.
    pub fn new(db: Database) -> Self {
        SqliteCitationRepository { db }
    }
}

/// Decodes one row of the `citations` table into a [`Citation`].
///
/// Column-type mismatches are returned as `rusqlite` errors. A `direction`
/// string that `CitationDirection::from_str_value` rejects is *not* an error:
/// the edge is read back as [`CitationDirection::References`] instead.
fn row_to_citation(row: &rusqlite::Row) -> rusqlite::Result<Citation> {
    let source_paper_id = PaperId::from(row.get::<_, String>("source_paper_id")?);
    let target_paper_id = PaperId::from(row.get::<_, String>("target_paper_id")?);
    let direction_text = row.get::<_, String>("direction")?;
    let snowball_run_id = row
        .get::<_, Option<String>>("snowball_run_id")?
        .map(SnowballRunId::from);

    // Unknown direction strings silently degrade to `References`.
    let direction = CitationDirection::from_str_value(&direction_text)
        .unwrap_or(CitationDirection::References);
    let discovered_by = row.get("discovered_by")?;
    let depth = row.get("depth")?;

    Ok(Citation {
        source_paper_id,
        target_paper_id,
        direction,
        discovered_by,
        depth,
        snowball_run_id,
    })
}

/// Decodes one row of the `snowball_runs` table into a [`SnowballRun`].
///
/// `created_at` is stored as text and converted with
/// `super::parse_rfc3339_or_now`. NOTE(review): judging by its name, that
/// helper substitutes the current time on a parse failure; confirm in the
/// parent module.
fn row_to_snowball_run(row: &rusqlite::Row) -> rusqlite::Result<SnowballRun> {
    let id = SnowballRunId::from(row.get::<_, String>("id")?);
    let search_id = row
        .get::<_, Option<String>>("search_id")?
        .map(SearchId::from);
    let question_id = row
        .get::<_, Option<String>>("question_id")?
        .map(QuestionId::from);
    let created_at_text = row.get::<_, String>("created_at")?;

    Ok(SnowballRun {
        id,
        search_id,
        question_id,
        direction: row.get("direction")?,
        max_depth: row.get("max_depth")?,
        threshold: row.get("threshold")?,
        total_discovered: row.get("total_discovered")?,
        total_new_papers: row.get("total_new_papers")?,
        created_at: super::parse_rfc3339_or_now(&created_at_text),
    })
}

impl CitationRepository for SqliteCitationRepository {
    fn save(&self, citation: &Citation) -> Result<(), CoreError> {
        let conn = self.db.conn()?;
        conn.execute(
            UPSERT_SQL,
            params![
                citation.source_paper_id.as_str(),
                citation.target_paper_id.as_str(),
                citation.direction.to_string(),
                citation.discovered_by,
                citation.depth,
                citation
                    .snowball_run_id
                    .as_ref()
                    .map(|id| id.as_str().to_string()),
            ],
        )
        .map_err(DbError::Sqlite)?;
        Ok(())
    }

    fn save_many(&self, citations: &[Citation]) -> Result<(), CoreError> {
        let conn = self.db.conn()?;
        for c in citations {
            conn.execute(
                UPSERT_SQL,
                params![
                    c.source_paper_id.as_str(),
                    c.target_paper_id.as_str(),
                    c.direction.to_string(),
                    c.discovered_by,
                    c.depth,
                    c.snowball_run_id.as_ref().map(|id| id.as_str().to_string()),
                ],
            )
            .map_err(DbError::Sqlite)?;
        }
        Ok(())
    }

    fn get_references(&self, paper_id: &str) -> Result<Vec<Citation>, CoreError> {
        let conn = self.db.conn()?;
        let mut stmt = conn
            .prepare("SELECT * FROM citations WHERE source_paper_id = ?1 AND direction = ?2")
            .map_err(DbError::Sqlite)?;
        let citations = stmt
            .query_map(
                params![paper_id, CitationDirection::References.to_string()],
                row_to_citation,
            )
            .map_err(DbError::Sqlite)?
            .filter_map(Result::ok)
            .collect();
        Ok(citations)
    }

    fn get_citations(&self, paper_id: &str) -> Result<Vec<Citation>, CoreError> {
        let conn = self.db.conn()?;
        let mut stmt = conn
            .prepare("SELECT * FROM citations WHERE target_paper_id = ?1 AND direction = ?2")
            .map_err(DbError::Sqlite)?;
        let citations = stmt
            .query_map(
                params![paper_id, CitationDirection::CitedBy.to_string()],
                row_to_citation,
            )
            .map_err(DbError::Sqlite)?
            .filter_map(Result::ok)
            .collect();
        Ok(citations)
    }

    fn exists(
        &self,
        source_paper_id: &str,
        target_paper_id: &str,
        direction: &str,
    ) -> Result<bool, CoreError> {
        let conn = self.db.conn()?;
        let exists: bool = conn
            .query_row(
                "SELECT EXISTS(SELECT 1 FROM citations WHERE source_paper_id = ?1 AND target_paper_id = ?2 AND direction = ?3)",
                params![source_paper_id, target_paper_id, direction],
                |row| row.get(0),
            )
            .map_err(DbError::Sqlite)?;
        Ok(exists)
    }

    fn save_snowball_run(&self, run: &SnowballRun) -> Result<(), CoreError> {
        let conn = self.db.conn()?;
        conn.execute(
            "INSERT OR REPLACE INTO snowball_runs
                (id, search_id, question_id, direction, max_depth,
                 threshold, total_discovered, total_new_papers, created_at)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
            params![
                run.id.as_str(),
                run.search_id.as_ref().map(|id| id.as_str().to_string()),
                run.question_id.as_ref().map(|id| id.as_str().to_string()),
                run.direction,
                run.max_depth,
                run.threshold,
                run.total_discovered,
                run.total_new_papers,
                run.created_at.to_rfc3339(),
            ],
        )
        .map_err(DbError::Sqlite)?;
        Ok(())
    }

    fn get_snowball_run(&self, run_id: &str) -> Result<Option<SnowballRun>, CoreError> {
        let conn = self.db.conn()?;
        let mut stmt = conn
            .prepare("SELECT * FROM snowball_runs WHERE id = ?1")
            .map_err(DbError::Sqlite)?;
        let result = stmt
            .query_row(params![run_id], row_to_snowball_run)
            .optional()
            .map_err(DbError::Sqlite)?;
        Ok(result)
    }

    fn list_snowball_runs(&self, limit: i64) -> Result<Vec<SnowballRun>, CoreError> {
        let conn = self.db.conn()?;
        let mut stmt = conn
            .prepare("SELECT * FROM snowball_runs ORDER BY created_at DESC LIMIT ?1")
            .map_err(DbError::Sqlite)?;
        let runs = stmt
            .query_map(params![limit], row_to_snowball_run)
            .map_err(DbError::Sqlite)?
            .filter_map(Result::ok)
            .collect();
        Ok(runs)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::{Database, SqlitePaperRepository};
    use scitadel_core::models::Paper;
    use scitadel_core::ports::PaperRepository;

    /// Opens a migrated in-memory database and returns paper and citation
    /// repositories that share it.
    fn setup() -> (SqlitePaperRepository, SqliteCitationRepository) {
        let db = Database::open_in_memory().unwrap();
        db.migrate().unwrap();
        (
            SqlitePaperRepository::new(db.clone()),
            SqliteCitationRepository::new(db),
        )
    }

    /// Builds a depth-1 `References` edge from `source` to `target` with no
    /// snowball run attached.
    fn reference(source: &Paper, target: &Paper) -> Citation {
        Citation {
            source_paper_id: source.id.clone(),
            target_paper_id: target.id.clone(),
            direction: CitationDirection::References,
            discovered_by: "openalex".to_string(),
            depth: 1,
            snowball_run_id: None,
        }
    }

    #[test]
    fn test_citation_crud() {
        let (paper_repo, citation_repo) = setup();

        let paper_a = Paper::new("Paper A");
        let paper_b = Paper::new("Paper B");
        paper_repo.save(&paper_a).unwrap();
        paper_repo.save(&paper_b).unwrap();

        citation_repo.save(&reference(&paper_a, &paper_b)).unwrap();

        let refs = citation_repo.get_references(paper_a.id.as_str()).unwrap();
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].target_paper_id, paper_b.id);

        assert!(
            citation_repo
                .exists(paper_a.id.as_str(), paper_b.id.as_str(), "references")
                .unwrap()
        );
    }

    #[test]
    fn test_upsert_keeps_minimum_depth() {
        let (paper_repo, citation_repo) = setup();

        let paper_a = Paper::new("Paper A");
        let paper_b = Paper::new("Paper B");
        paper_repo.save(&paper_a).unwrap();
        paper_repo.save(&paper_b).unwrap();

        let mut citation = reference(&paper_a, &paper_b);
        citation.depth = 2;
        citation_repo.save(&citation).unwrap();
        citation.depth = 1;
        citation_repo.save(&citation).unwrap();
        // A deeper rediscovery must not overwrite the shallower depth.
        citation.depth = 3;
        citation_repo.save(&citation).unwrap();

        let refs = citation_repo.get_references(paper_a.id.as_str()).unwrap();
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].depth, 1);
    }

    #[test]
    fn test_save_many() {
        let (paper_repo, citation_repo) = setup();

        let paper_a = Paper::new("Paper A");
        let paper_b = Paper::new("Paper B");
        let paper_c = Paper::new("Paper C");
        paper_repo.save(&paper_a).unwrap();
        paper_repo.save(&paper_b).unwrap();
        paper_repo.save(&paper_c).unwrap();

        citation_repo.save_many(&[]).unwrap();
        citation_repo
            .save_many(&[reference(&paper_a, &paper_b), reference(&paper_a, &paper_c)])
            .unwrap();

        let refs = citation_repo.get_references(paper_a.id.as_str()).unwrap();
        assert_eq!(refs.len(), 2);
    }

    #[test]
    fn test_snowball_run_crud() {
        let (_, repo) = setup();

        let run = SnowballRun::new();
        repo.save_snowball_run(&run).unwrap();

        let loaded = repo.get_snowball_run(run.id.as_str()).unwrap().unwrap();
        assert_eq!(loaded.direction, "both");
    }

    #[test]
    fn test_snowball_run_update_and_lookup() {
        let (_, repo) = setup();

        let mut run = SnowballRun::new();
        repo.save_snowball_run(&run).unwrap();
        run.total_discovered = 7;
        repo.save_snowball_run(&run).unwrap();

        let loaded = repo.get_snowball_run(run.id.as_str()).unwrap().unwrap();
        assert_eq!(loaded.total_discovered, 7);
        assert_eq!(repo.list_snowball_runs(10).unwrap().len(), 1);
        assert!(repo.get_snowball_run("missing-run-id").unwrap().is_none());
    }
}