1use rusqlite::{OptionalExtension, params};
2use scitadel_core::error::CoreError;
3use scitadel_core::models::{PaperId, Search, SearchId, SearchResult, SourceOutcome};
4use scitadel_core::ports::SearchRepository;
5
6use super::Database;
7use crate::error::DbError;
8
9pub struct SqliteSearchRepository {
10 db: Database,
11}
12
13impl SqliteSearchRepository {
14 pub fn new(db: Database) -> Self {
15 Self { db }
16 }
17
18 pub fn find_similar(&self, query: &str, limit: i64) -> Result<Vec<(Search, f64)>, DbError> {
23 let sanitized = sanitize_fts5_query(query);
24 if sanitized.is_empty() {
25 return Ok(Vec::new());
26 }
27 let conn = self.db.conn()?;
28 let mut stmt = conn
29 .prepare(
30 "SELECT s.*, bm25(searches_fts) AS rank
31 FROM searches_fts f
32 JOIN searches s ON s.id = f.search_id
33 WHERE searches_fts MATCH ?1
34 ORDER BY rank ASC
35 LIMIT ?2",
36 )
37 .map_err(DbError::Sqlite)?;
38 let rows = stmt
39 .query_map(params![sanitized, limit], |row| {
40 let search = row_to_search(row)?;
41 let rank: f64 = row.get("rank")?;
42 Ok((search, rank))
43 })
44 .map_err(DbError::Sqlite)?;
45 let out: Vec<(Search, f64)> = rows.filter_map(Result::ok).collect();
46 Ok(out)
47 }
48}
49
/// Reduces free-form text to a string that can be used as an FTS5 `MATCH`
/// expression without triggering a syntax error.
///
/// FTS5 accepts unquoted terms ("barewords") made only of ASCII letters,
/// ASCII digits, `_` and non-ASCII characters. Every other character is either
/// query syntax (`"`, `(`, `)`, `:`, `*`, `-`, `+`, `^`, `{`, `}`, `,`) or
/// rejected outright (`'`, `.`, `?`, `/`, ...). Each such character is
/// replaced by a space, after which runs of whitespace are collapsed.
///
/// Returns the surviving terms joined by single spaces, or an empty string
/// when no term survives.
///
/// NOTE(review): the upper-case words `AND`, `OR` and `NOT` are passed through
/// unchanged, so FTS5 still interprets them as boolean operators.
fn sanitize_fts5_query(q: &str) -> String {
    q.chars()
        .map(|c| {
            // Non-ASCII whitespace is kept here and removed by `split_whitespace` below.
            let bareword_char = c.is_ascii_alphanumeric() || c == '_' || !c.is_ascii();
            if bareword_char {
                c
            } else {
                ' '
            }
        })
        .collect::<String>()
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
}
66
67fn row_to_search(row: &rusqlite::Row) -> rusqlite::Result<Search> {
68 let id: String = row.get("id")?;
69 let sources_json: String = row.get("sources")?;
70 let parameters_json: String = row.get("parameters")?;
71 let outcomes_json: String = row.get("source_outcomes")?;
72 let created_at: String = row.get("created_at")?;
73
74 let outcomes: Vec<SourceOutcome> = serde_json::from_str(&outcomes_json).unwrap_or_default();
75
76 Ok(Search {
77 id: SearchId::from(id),
78 query: row.get("query")?,
79 sources: serde_json::from_str(&sources_json).unwrap_or_default(),
80 parameters: serde_json::from_str(¶meters_json).unwrap_or_default(),
81 source_outcomes: outcomes,
82 total_candidates: row.get("total_candidates")?,
83 total_papers: row.get("total_papers")?,
84 created_at: super::parse_rfc3339_or_now(&created_at),
85 })
86}
87
88impl SearchRepository for SqliteSearchRepository {
89 fn save(&self, search: &Search) -> Result<(), CoreError> {
90 let conn = self.db.conn()?;
91 let outcomes_json = serde_json::to_string(&search.source_outcomes).unwrap_or_default();
92 conn.execute(
93 "INSERT INTO searches
94 (id, query, sources, parameters, source_outcomes,
95 total_candidates, total_papers, created_at)
96 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
97 ON CONFLICT(id) DO UPDATE SET
98 query = excluded.query,
99 sources = excluded.sources,
100 parameters = excluded.parameters,
101 source_outcomes = excluded.source_outcomes,
102 total_candidates = excluded.total_candidates,
103 total_papers = excluded.total_papers",
104 params![
105 search.id.as_str(),
106 search.query,
107 serde_json::to_string(&search.sources).unwrap_or_default(),
108 serde_json::to_string(&search.parameters).unwrap_or_default(),
109 outcomes_json,
110 search.total_candidates,
111 search.total_papers,
112 search.created_at.to_rfc3339(),
113 ],
114 )
115 .map_err(DbError::Sqlite)?;
116 Ok(())
117 }
118
119 fn get(&self, search_id: &str) -> Result<Option<Search>, CoreError> {
120 let conn = self.db.conn()?;
121 let mut stmt = conn
122 .prepare("SELECT * FROM searches WHERE id = ?1")
123 .map_err(DbError::Sqlite)?;
124 let result = stmt
125 .query_row(params![search_id], row_to_search)
126 .optional()
127 .map_err(DbError::Sqlite)?;
128 Ok(result)
129 }
130
131 fn save_results(&self, results: &[SearchResult]) -> Result<(), CoreError> {
132 let mut conn = self.db.conn()?;
133 let tx = conn.transaction().map_err(DbError::Sqlite)?;
134 for r in results {
135 tx.execute(
136 "INSERT INTO search_results
137 (search_id, paper_id, source, rank, score, raw_metadata)
138 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
139 ON CONFLICT(search_id, paper_id, source) DO UPDATE SET
140 rank = excluded.rank,
141 score = excluded.score,
142 raw_metadata = excluded.raw_metadata",
143 params![
144 r.search_id.as_str(),
145 r.paper_id.as_str(),
146 r.source,
147 r.rank,
148 r.score,
149 serde_json::to_string(&r.raw_metadata).unwrap_or_default(),
150 ],
151 )
152 .map_err(DbError::Sqlite)?;
153 }
154 tx.commit().map_err(DbError::Sqlite)?;
155 Ok(())
156 }
157
158 fn get_results(&self, search_id: &str) -> Result<Vec<SearchResult>, CoreError> {
159 let conn = self.db.conn()?;
160 let mut stmt = conn
161 .prepare("SELECT * FROM search_results WHERE search_id = ?1")
162 .map_err(DbError::Sqlite)?;
163 let results = stmt
164 .query_map(params![search_id], |row| {
165 let search_id: String = row.get("search_id")?;
166 let paper_id: String = row.get("paper_id")?;
167 let raw_json: String = row.get("raw_metadata")?;
168 Ok(SearchResult {
169 search_id: SearchId::from(search_id),
170 paper_id: PaperId::from(paper_id),
171 source: row.get("source")?,
172 rank: row.get("rank")?,
173 score: row.get("score")?,
174 raw_metadata: serde_json::from_str(&raw_json).unwrap_or_default(),
175 })
176 })
177 .map_err(DbError::Sqlite)?
178 .filter_map(Result::ok)
179 .collect();
180 Ok(results)
181 }
182
183 fn list_searches(&self, limit: i64) -> Result<Vec<Search>, CoreError> {
184 let conn = self.db.conn()?;
185 let mut stmt = conn
186 .prepare("SELECT * FROM searches ORDER BY created_at DESC LIMIT ?1")
187 .map_err(DbError::Sqlite)?;
188 let searches = stmt
189 .query_map(params![limit], row_to_search)
190 .map_err(DbError::Sqlite)?
191 .filter_map(Result::ok)
192 .collect();
193 Ok(searches)
194 }
195
196 fn diff_searches(
197 &self,
198 search_id_a: &str,
199 search_id_b: &str,
200 ) -> Result<(Vec<String>, Vec<String>), CoreError> {
201 let conn = self.db.conn()?;
202
203 let get_paper_ids =
204 |search_id: &str| -> Result<std::collections::HashSet<String>, DbError> {
205 let mut stmt = conn
206 .prepare("SELECT DISTINCT paper_id FROM search_results WHERE search_id = ?1")
207 .map_err(DbError::Sqlite)?;
208 let ids: std::collections::HashSet<String> = stmt
209 .query_map(params![search_id], |row| row.get(0))
210 .map_err(DbError::Sqlite)?
211 .filter_map(Result::ok)
212 .collect();
213 Ok(ids)
214 };
215
216 let papers_a = get_paper_ids(search_id_a).map_err(Into::<CoreError>::into)?;
217 let papers_b = get_paper_ids(search_id_b).map_err(Into::<CoreError>::into)?;
218
219 let mut added: Vec<String> = papers_b.difference(&papers_a).cloned().collect();
220 let mut removed: Vec<String> = papers_a.difference(&papers_b).cloned().collect();
221 added.sort();
222 removed.sort();
223
224 Ok((added, removed))
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::{Database, SqlitePaperRepository};
    use scitadel_core::models::Paper;
    use scitadel_core::ports::PaperRepository;

    /// Opens a migrated in-memory database and builds both repositories on it.
    fn setup() -> (Database, SqliteSearchRepository, SqlitePaperRepository) {
        let db = Database::open_in_memory().unwrap();
        db.migrate().unwrap();
        let searches = SqliteSearchRepository::new(db.clone());
        let papers = SqlitePaperRepository::new(db.clone());
        (db, searches, papers)
    }

    /// A saved search can be read back by id with its query intact.
    #[test]
    fn test_save_and_get_search() {
        let (_, repo, _) = setup();
        let search = Search::new("test query");
        repo.save(&search).unwrap();

        let loaded = repo.get(search.id.as_str()).unwrap();
        assert_eq!(loaded.unwrap().query, "test query");
    }

    /// The sanitizer removes FTS5 syntax and collapses whitespace.
    #[test]
    fn fts_sanitizer_strips_operators() {
        let cases = [
            (r#"GAN "stability""#, "GAN stability"),
            ("foo (bar) - baz", "foo bar baz"),
            ("   ", ""),
            ("scope:field", "scope field"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_fts5_query(input), expected);
        }
    }

    /// A query sharing terms with one stored search finds that search.
    #[test]
    fn fts_find_similar_roundtrip() {
        let (_, repo, _) = setup();
        let fixtures = [
            ("s-a", "generative adversarial networks stability"),
            ("s-b", "attention is all you need transformers"),
            ("s-c", "retrieval augmented generation"),
        ];
        for (id, text) in fixtures {
            let mut search = Search::new(text);
            search.id = SearchId::from(id);
            repo.save(&search).unwrap();
        }

        let hits = repo.find_similar("generative networks", 10).unwrap();
        let found = hits.iter().any(|(s, _)| s.id.as_str() == "s-a");
        assert!(
            found,
            "expected GAN search to be found; got {:?}",
            hits.iter().map(|(s, _)| s.id.as_str()).collect::<Vec<_>>()
        );
    }

    /// A query made only of stripped characters yields no hits and no error.
    #[test]
    fn fts_find_similar_empty_query() {
        let (_, repo, _) = setup();
        repo.save(&Search::new("something")).unwrap();

        let hits = repo.find_similar("()(", 10).unwrap();
        assert!(hits.is_empty());
    }

    /// A stored result row can be read back for its search.
    #[test]
    fn test_save_and_get_results() {
        let (_, search_repo, paper_repo) = setup();

        let paper = Paper::new("Test Paper");
        paper_repo.save(&paper).unwrap();

        let search = Search::new("test");
        search_repo.save(&search).unwrap();

        let rows = [SearchResult {
            search_id: search.id.clone(),
            paper_id: paper.id.clone(),
            source: "pubmed".to_string(),
            rank: Some(1),
            score: Some(0.95),
            raw_metadata: serde_json::Value::Null,
        }];
        search_repo.save_results(&rows).unwrap();

        let results = search_repo.get_results(search.id.as_str()).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source, "pubmed");
    }
}