Skip to main content

scitadel_db/sqlite/
citations.rs

1use rusqlite::{OptionalExtension, params};
2use scitadel_core::error::CoreError;
3use scitadel_core::models::{
4    Citation, CitationDirection, PaperId, QuestionId, SearchId, SnowballRun, SnowballRunId,
5};
6use scitadel_core::ports::CitationRepository;
7
8use super::Database;
9use crate::error::DbError;
10
/// Upsert statement for a single row of the `citations` table.
///
/// Positional parameters, in order: `source_paper_id`, `target_paper_id`,
/// `direction`, `discovered_by`, `depth`, `snowball_run_id`.
///
/// If a row with the same `(source_paper_id, target_paper_id, direction)`
/// already exists, the smaller of the stored and incoming `depth` is kept and
/// `snowball_run_id` is overwritten only when the incoming value is non-NULL.
/// The existing row's `discovered_by` is not modified.
const UPSERT_SQL: &str = concat!(
    "INSERT INTO citations\n",
    "        (source_paper_id, target_paper_id, direction,\n",
    "         discovered_by, depth, snowball_run_id)\n",
    "    VALUES (?1, ?2, ?3, ?4, ?5, ?6)\n",
    "    ON CONFLICT(source_paper_id, target_paper_id, direction) DO UPDATE SET\n",
    "        depth = MIN(citations.depth, excluded.depth),\n",
    "        snowball_run_id = COALESCE(excluded.snowball_run_id,\n",
    "                                   citations.snowball_run_id)",
);
20
21pub struct SqliteCitationRepository {
22    db: Database,
23}
24
25impl SqliteCitationRepository {
26    pub fn new(db: Database) -> Self {
27        Self { db }
28    }
29}
30
31fn row_to_citation(row: &rusqlite::Row) -> rusqlite::Result<Citation> {
32    let source_paper_id: String = row.get("source_paper_id")?;
33    let target_paper_id: String = row.get("target_paper_id")?;
34    let direction: String = row.get("direction")?;
35    let snowball_run_id: Option<String> = row.get("snowball_run_id")?;
36
37    Ok(Citation {
38        source_paper_id: PaperId::from(source_paper_id),
39        target_paper_id: PaperId::from(target_paper_id),
40        direction: CitationDirection::from_str_value(&direction)
41            .unwrap_or(CitationDirection::References),
42        discovered_by: row.get("discovered_by")?,
43        depth: row.get("depth")?,
44        snowball_run_id: snowball_run_id.map(SnowballRunId::from),
45    })
46}
47
48fn row_to_snowball_run(row: &rusqlite::Row) -> rusqlite::Result<SnowballRun> {
49    let id: String = row.get("id")?;
50    let search_id: Option<String> = row.get("search_id")?;
51    let question_id: Option<String> = row.get("question_id")?;
52    let created_at: String = row.get("created_at")?;
53
54    Ok(SnowballRun {
55        id: SnowballRunId::from(id),
56        search_id: search_id.map(SearchId::from),
57        question_id: question_id.map(QuestionId::from),
58        direction: row.get("direction")?,
59        max_depth: row.get("max_depth")?,
60        threshold: row.get("threshold")?,
61        total_discovered: row.get("total_discovered")?,
62        total_new_papers: row.get("total_new_papers")?,
63        created_at: super::parse_rfc3339_or_now(&created_at),
64    })
65}
66
67impl CitationRepository for SqliteCitationRepository {
68    fn save(&self, citation: &Citation) -> Result<(), CoreError> {
69        let conn = self.db.conn()?;
70        conn.execute(
71            UPSERT_SQL,
72            params![
73                citation.source_paper_id.as_str(),
74                citation.target_paper_id.as_str(),
75                citation.direction.to_string(),
76                citation.discovered_by,
77                citation.depth,
78                citation
79                    .snowball_run_id
80                    .as_ref()
81                    .map(|id| id.as_str().to_string()),
82            ],
83        )
84        .map_err(DbError::Sqlite)?;
85        Ok(())
86    }
87
88    fn save_many(&self, citations: &[Citation]) -> Result<(), CoreError> {
89        let conn = self.db.conn()?;
90        for c in citations {
91            conn.execute(
92                UPSERT_SQL,
93                params![
94                    c.source_paper_id.as_str(),
95                    c.target_paper_id.as_str(),
96                    c.direction.to_string(),
97                    c.discovered_by,
98                    c.depth,
99                    c.snowball_run_id.as_ref().map(|id| id.as_str().to_string()),
100                ],
101            )
102            .map_err(DbError::Sqlite)?;
103        }
104        Ok(())
105    }
106
107    fn get_references(&self, paper_id: &str) -> Result<Vec<Citation>, CoreError> {
108        let conn = self.db.conn()?;
109        let mut stmt = conn
110            .prepare("SELECT * FROM citations WHERE source_paper_id = ?1 AND direction = ?2")
111            .map_err(DbError::Sqlite)?;
112        let citations = stmt
113            .query_map(
114                params![paper_id, CitationDirection::References.to_string()],
115                row_to_citation,
116            )
117            .map_err(DbError::Sqlite)?
118            .filter_map(Result::ok)
119            .collect();
120        Ok(citations)
121    }
122
123    fn get_citations(&self, paper_id: &str) -> Result<Vec<Citation>, CoreError> {
124        let conn = self.db.conn()?;
125        let mut stmt = conn
126            .prepare("SELECT * FROM citations WHERE target_paper_id = ?1 AND direction = ?2")
127            .map_err(DbError::Sqlite)?;
128        let citations = stmt
129            .query_map(
130                params![paper_id, CitationDirection::CitedBy.to_string()],
131                row_to_citation,
132            )
133            .map_err(DbError::Sqlite)?
134            .filter_map(Result::ok)
135            .collect();
136        Ok(citations)
137    }
138
139    fn exists(
140        &self,
141        source_paper_id: &str,
142        target_paper_id: &str,
143        direction: &str,
144    ) -> Result<bool, CoreError> {
145        let conn = self.db.conn()?;
146        let exists: bool = conn
147            .query_row(
148                "SELECT EXISTS(SELECT 1 FROM citations WHERE source_paper_id = ?1 AND target_paper_id = ?2 AND direction = ?3)",
149                params![source_paper_id, target_paper_id, direction],
150                |row| row.get(0),
151            )
152            .map_err(DbError::Sqlite)?;
153        Ok(exists)
154    }
155
156    fn save_snowball_run(&self, run: &SnowballRun) -> Result<(), CoreError> {
157        let conn = self.db.conn()?;
158        conn.execute(
159            "INSERT OR REPLACE INTO snowball_runs
160                (id, search_id, question_id, direction, max_depth,
161                 threshold, total_discovered, total_new_papers, created_at)
162             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
163            params![
164                run.id.as_str(),
165                run.search_id.as_ref().map(|id| id.as_str().to_string()),
166                run.question_id.as_ref().map(|id| id.as_str().to_string()),
167                run.direction,
168                run.max_depth,
169                run.threshold,
170                run.total_discovered,
171                run.total_new_papers,
172                run.created_at.to_rfc3339(),
173            ],
174        )
175        .map_err(DbError::Sqlite)?;
176        Ok(())
177    }
178
179    fn get_snowball_run(&self, run_id: &str) -> Result<Option<SnowballRun>, CoreError> {
180        let conn = self.db.conn()?;
181        let mut stmt = conn
182            .prepare("SELECT * FROM snowball_runs WHERE id = ?1")
183            .map_err(DbError::Sqlite)?;
184        let result = stmt
185            .query_row(params![run_id], row_to_snowball_run)
186            .optional()
187            .map_err(DbError::Sqlite)?;
188        Ok(result)
189    }
190
191    fn list_snowball_runs(&self, limit: i64) -> Result<Vec<SnowballRun>, CoreError> {
192        let conn = self.db.conn()?;
193        let mut stmt = conn
194            .prepare("SELECT * FROM snowball_runs ORDER BY created_at DESC LIMIT ?1")
195            .map_err(DbError::Sqlite)?;
196        let runs = stmt
197            .query_map(params![limit], row_to_snowball_run)
198            .map_err(DbError::Sqlite)?
199            .filter_map(Result::ok)
200            .collect();
201        Ok(runs)
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::{Database, SqlitePaperRepository};
    use scitadel_core::models::Paper;
    use scitadel_core::ports::PaperRepository;

    /// Opens a fresh in-memory database and applies the migrations.
    fn migrated_db() -> Database {
        let db = Database::open_in_memory().unwrap();
        db.migrate().unwrap();
        db
    }

    /// A saved citation is returned by `get_references` and seen by `exists`.
    #[test]
    fn test_citation_crud() {
        let db = migrated_db();
        let papers = SqlitePaperRepository::new(db.clone());
        let citations = SqliteCitationRepository::new(db);

        // Both endpoints of the citation are stored as papers first.
        let citing = Paper::new("Paper A");
        let cited = Paper::new("Paper B");
        papers.save(&citing).unwrap();
        papers.save(&cited).unwrap();

        let edge = Citation {
            source_paper_id: citing.id.clone(),
            target_paper_id: cited.id.clone(),
            direction: CitationDirection::References,
            discovered_by: "openalex".to_string(),
            depth: 1,
            snowball_run_id: None,
        };
        citations.save(&edge).unwrap();

        let found = citations.get_references(citing.id.as_str()).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].target_paper_id, cited.id);

        let present = citations
            .exists(citing.id.as_str(), cited.id.as_str(), "references")
            .unwrap();
        assert!(present);
    }

    /// A saved snowball run can be loaded back by its id.
    #[test]
    fn test_snowball_run_crud() {
        let citations = SqliteCitationRepository::new(migrated_db());

        let original = SnowballRun::new();
        citations.save_snowball_run(&original).unwrap();

        let reloaded = citations
            .get_snowball_run(original.id.as_str())
            .unwrap()
            .unwrap();
        assert_eq!(reloaded.direction, "both");
    }
}
257}