1use rusqlite::{OptionalExtension, params};
2use scitadel_core::error::CoreError;
3use scitadel_core::models::{
4 Citation, CitationDirection, PaperId, QuestionId, SearchId, SnowballRun, SnowballRunId,
5};
6use scitadel_core::ports::CitationRepository;
7
8use super::Database;
9use crate::error::DbError;
10
/// Upsert statement for a single row of the `citations` table.
///
/// Parameters, in order: source paper id, target paper id, direction,
/// discovered-by, depth, snowball run id.
///
/// When a row with the same `(source_paper_id, target_paper_id, direction)`
/// already exists, the smaller of the stored and incoming `depth` is kept, and
/// `snowball_run_id` is overwritten only when the incoming value is non-NULL.
/// `discovered_by` is left untouched on conflict.
const UPSERT_SQL: &str = "\
 INSERT INTO citations
 (source_paper_id, target_paper_id, direction,
 discovered_by, depth, snowball_run_id)
 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
 ON CONFLICT(source_paper_id, target_paper_id, direction) DO UPDATE SET
 depth = MIN(citations.depth, excluded.depth),
 snowball_run_id = COALESCE(excluded.snowball_run_id,
 citations.snowball_run_id)";
20
/// SQLite-backed implementation of [`CitationRepository`].
///
/// Persists citation edges in the `citations` table and snowball run records
/// in the `snowball_runs` table.
pub struct SqliteCitationRepository {
    /// Database handle; every operation obtains its connection from it.
    db: Database,
}
24
25impl SqliteCitationRepository {
26 pub fn new(db: Database) -> Self {
27 Self { db }
28 }
29}
30
31fn row_to_citation(row: &rusqlite::Row) -> rusqlite::Result<Citation> {
32 let source_paper_id: String = row.get("source_paper_id")?;
33 let target_paper_id: String = row.get("target_paper_id")?;
34 let direction: String = row.get("direction")?;
35 let snowball_run_id: Option<String> = row.get("snowball_run_id")?;
36
37 Ok(Citation {
38 source_paper_id: PaperId::from(source_paper_id),
39 target_paper_id: PaperId::from(target_paper_id),
40 direction: CitationDirection::from_str_value(&direction)
41 .unwrap_or(CitationDirection::References),
42 discovered_by: row.get("discovered_by")?,
43 depth: row.get("depth")?,
44 snowball_run_id: snowball_run_id.map(SnowballRunId::from),
45 })
46}
47
48fn row_to_snowball_run(row: &rusqlite::Row) -> rusqlite::Result<SnowballRun> {
49 let id: String = row.get("id")?;
50 let search_id: Option<String> = row.get("search_id")?;
51 let question_id: Option<String> = row.get("question_id")?;
52 let created_at: String = row.get("created_at")?;
53
54 Ok(SnowballRun {
55 id: SnowballRunId::from(id),
56 search_id: search_id.map(SearchId::from),
57 question_id: question_id.map(QuestionId::from),
58 direction: row.get("direction")?,
59 max_depth: row.get("max_depth")?,
60 threshold: row.get("threshold")?,
61 total_discovered: row.get("total_discovered")?,
62 total_new_papers: row.get("total_new_papers")?,
63 created_at: super::parse_rfc3339_or_now(&created_at),
64 })
65}
66
67impl CitationRepository for SqliteCitationRepository {
68 fn save(&self, citation: &Citation) -> Result<(), CoreError> {
69 let conn = self.db.conn()?;
70 conn.execute(
71 UPSERT_SQL,
72 params![
73 citation.source_paper_id.as_str(),
74 citation.target_paper_id.as_str(),
75 citation.direction.to_string(),
76 citation.discovered_by,
77 citation.depth,
78 citation
79 .snowball_run_id
80 .as_ref()
81 .map(|id| id.as_str().to_string()),
82 ],
83 )
84 .map_err(DbError::Sqlite)?;
85 Ok(())
86 }
87
88 fn save_many(&self, citations: &[Citation]) -> Result<(), CoreError> {
89 let conn = self.db.conn()?;
90 for c in citations {
91 conn.execute(
92 UPSERT_SQL,
93 params![
94 c.source_paper_id.as_str(),
95 c.target_paper_id.as_str(),
96 c.direction.to_string(),
97 c.discovered_by,
98 c.depth,
99 c.snowball_run_id.as_ref().map(|id| id.as_str().to_string()),
100 ],
101 )
102 .map_err(DbError::Sqlite)?;
103 }
104 Ok(())
105 }
106
107 fn get_references(&self, paper_id: &str) -> Result<Vec<Citation>, CoreError> {
108 let conn = self.db.conn()?;
109 let mut stmt = conn
110 .prepare("SELECT * FROM citations WHERE source_paper_id = ?1 AND direction = ?2")
111 .map_err(DbError::Sqlite)?;
112 let citations = stmt
113 .query_map(
114 params![paper_id, CitationDirection::References.to_string()],
115 row_to_citation,
116 )
117 .map_err(DbError::Sqlite)?
118 .filter_map(Result::ok)
119 .collect();
120 Ok(citations)
121 }
122
123 fn get_citations(&self, paper_id: &str) -> Result<Vec<Citation>, CoreError> {
124 let conn = self.db.conn()?;
125 let mut stmt = conn
126 .prepare("SELECT * FROM citations WHERE target_paper_id = ?1 AND direction = ?2")
127 .map_err(DbError::Sqlite)?;
128 let citations = stmt
129 .query_map(
130 params![paper_id, CitationDirection::CitedBy.to_string()],
131 row_to_citation,
132 )
133 .map_err(DbError::Sqlite)?
134 .filter_map(Result::ok)
135 .collect();
136 Ok(citations)
137 }
138
139 fn exists(
140 &self,
141 source_paper_id: &str,
142 target_paper_id: &str,
143 direction: &str,
144 ) -> Result<bool, CoreError> {
145 let conn = self.db.conn()?;
146 let exists: bool = conn
147 .query_row(
148 "SELECT EXISTS(SELECT 1 FROM citations WHERE source_paper_id = ?1 AND target_paper_id = ?2 AND direction = ?3)",
149 params![source_paper_id, target_paper_id, direction],
150 |row| row.get(0),
151 )
152 .map_err(DbError::Sqlite)?;
153 Ok(exists)
154 }
155
156 fn save_snowball_run(&self, run: &SnowballRun) -> Result<(), CoreError> {
157 let conn = self.db.conn()?;
158 conn.execute(
159 "INSERT OR REPLACE INTO snowball_runs
160 (id, search_id, question_id, direction, max_depth,
161 threshold, total_discovered, total_new_papers, created_at)
162 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
163 params![
164 run.id.as_str(),
165 run.search_id.as_ref().map(|id| id.as_str().to_string()),
166 run.question_id.as_ref().map(|id| id.as_str().to_string()),
167 run.direction,
168 run.max_depth,
169 run.threshold,
170 run.total_discovered,
171 run.total_new_papers,
172 run.created_at.to_rfc3339(),
173 ],
174 )
175 .map_err(DbError::Sqlite)?;
176 Ok(())
177 }
178
179 fn get_snowball_run(&self, run_id: &str) -> Result<Option<SnowballRun>, CoreError> {
180 let conn = self.db.conn()?;
181 let mut stmt = conn
182 .prepare("SELECT * FROM snowball_runs WHERE id = ?1")
183 .map_err(DbError::Sqlite)?;
184 let result = stmt
185 .query_row(params![run_id], row_to_snowball_run)
186 .optional()
187 .map_err(DbError::Sqlite)?;
188 Ok(result)
189 }
190
191 fn list_snowball_runs(&self, limit: i64) -> Result<Vec<SnowballRun>, CoreError> {
192 let conn = self.db.conn()?;
193 let mut stmt = conn
194 .prepare("SELECT * FROM snowball_runs ORDER BY created_at DESC LIMIT ?1")
195 .map_err(DbError::Sqlite)?;
196 let runs = stmt
197 .query_map(params![limit], row_to_snowball_run)
198 .map_err(DbError::Sqlite)?
199 .filter_map(Result::ok)
200 .collect();
201 Ok(runs)
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::{Database, SqlitePaperRepository};
    use scitadel_core::models::Paper;
    use scitadel_core::ports::PaperRepository;

    /// Opens a fresh in-memory database and applies the migrations.
    fn migrated_db() -> Database {
        let db = Database::open_in_memory().unwrap();
        db.migrate().unwrap();
        db
    }

    /// Saves one citation between two stored papers, then checks that it is
    /// returned by `get_references` and reported by `exists`.
    #[test]
    fn test_citation_crud() {
        let db = migrated_db();
        let papers = SqlitePaperRepository::new(db.clone());
        let citations = SqliteCitationRepository::new(db);

        // Both endpoint papers are stored before the edge that links them.
        let citing = Paper::new("Paper A");
        let cited = Paper::new("Paper B");
        papers.save(&citing).unwrap();
        papers.save(&cited).unwrap();

        let edge = Citation {
            source_paper_id: citing.id.clone(),
            target_paper_id: cited.id.clone(),
            direction: CitationDirection::References,
            discovered_by: "openalex".to_string(),
            depth: 1,
            snowball_run_id: None,
        };
        citations.save(&edge).unwrap();

        let references = citations.get_references(citing.id.as_str()).unwrap();
        assert_eq!(references.len(), 1);
        assert_eq!(references[0].target_paper_id, cited.id);

        let found = citations
            .exists(citing.id.as_str(), cited.id.as_str(), "references")
            .unwrap();
        assert!(found);
    }

    /// Saves a default snowball run and checks that it can be loaded by id.
    #[test]
    fn test_snowball_run_crud() {
        let repo = SqliteCitationRepository::new(migrated_db());

        let run = SnowballRun::new();
        repo.save_snowball_run(&run).unwrap();

        let reloaded = repo
            .get_snowball_run(run.id.as_str())
            .unwrap()
            .unwrap();
        assert_eq!(reloaded.direction, "both");
    }
}