use rusqlite::{OptionalExtension, params};
use scitadel_core::error::CoreError;
use scitadel_core::models::{PaperId, Search, SearchId, SearchResult, SourceOutcome};
use scitadel_core::ports::SearchRepository;
use super::Database;
use crate::error::DbError;
/// SQLite-backed store for searches and their results.
///
/// Implements [`SearchRepository`] and additionally offers full-text
/// similarity lookup over previously saved searches.
pub struct SqliteSearchRepository {
/// Database handle from which a connection is obtained for every operation.
db: Database,
}
impl SqliteSearchRepository {
    /// Creates a repository that runs its queries against `db`.
    pub fn new(db: Database) -> Self {
        Self { db }
    }

    /// Finds stored searches that match `query` in the `searches_fts` index.
    ///
    /// The query text is first passed through `sanitize_fts5_query`; if
    /// nothing is left afterwards an empty list is returned without touching
    /// the database. Each hit is paired with its `bm25` score and hits are
    /// ordered by ascending score (FTS5 reports better matches as numerically
    /// smaller values). At most `limit` rows are returned.
    ///
    /// # Errors
    ///
    /// Returns a [`DbError`] when no connection can be obtained, when the
    /// statement cannot be prepared or executed, or when any matching row
    /// cannot be read. Row errors are propagated rather than skipped so that
    /// callers never receive a silently truncated result.
    pub fn find_similar(&self, query: &str, limit: i64) -> Result<Vec<(Search, f64)>, DbError> {
        let sanitized = sanitize_fts5_query(query);
        if sanitized.is_empty() {
            return Ok(Vec::new());
        }

        let conn = self.db.conn()?;
        let mut stmt = conn
            .prepare(
                "SELECT s.*, bm25(searches_fts) AS rank
FROM searches_fts f
JOIN searches s ON s.id = f.search_id
WHERE searches_fts MATCH ?1
ORDER BY rank ASC
LIMIT ?2",
            )
            .map_err(DbError::Sqlite)?;

        let rows = stmt
            .query_map(params![sanitized, limit], |row| {
                let search = row_to_search(row)?;
                let rank: f64 = row.get("rank")?;
                Ok((search, rank))
            })
            .map_err(DbError::Sqlite)?;

        // Collecting into `Result` stops at the first failing row instead of
        // dropping it.
        let hits = rows
            .collect::<Result<Vec<_>, _>>()
            .map_err(DbError::Sqlite)?;
        Ok(hits)
    }
}
/// Reduces free-form text to a string that can be used as an FTS5 `MATCH`
/// expression made only of unquoted terms.
///
/// FTS5 accepts an unquoted term ("bareword") only if it consists of ASCII
/// letters, ASCII digits, `_`, the substitute character (U+001A) or non-ASCII
/// characters. Any other character is either query syntax (`"`, `(`, `)`,
/// `:`, `*`, `-`, `^`, `+`, `{`, `}`) or a syntax error (`,`, `.`, `?`, `/`,
/// ...), so each of them is treated as a separator. The surviving terms are
/// joined by single spaces; the result is empty when no term survives.
///
/// NOTE(review): the uppercase barewords `AND`, `OR` and `NOT` are passed
/// through unchanged and therefore keep their FTS5 operator meaning.
fn sanitize_fts5_query(q: &str) -> String {
    /// Whether `c` may appear in an unquoted FTS5 term.
    fn is_bareword_char(c: char) -> bool {
        !c.is_ascii() || c.is_ascii_alphanumeric() || c == '_' || c == '\u{1a}'
    }

    let spaced: String = q
        .chars()
        .map(|c| if is_bareword_char(c) { c } else { ' ' })
        .collect();

    // `split_whitespace` also splits on non-ASCII whitespace, which the map
    // above deliberately leaves in place.
    spaced.split_whitespace().collect::<Vec<_>>().join(" ")
}
/// Builds a [`Search`] from a row that exposes the columns of the `searches`
/// table by name.
///
/// The `sources`, `parameters` and `source_outcomes` columns hold JSON text.
/// A value that fails to parse is replaced by the field's default instead of
/// failing the whole row.
///
/// # Errors
///
/// Returns the underlying `rusqlite` error when a column is missing or has an
/// incompatible type.
fn row_to_search(row: &rusqlite::Row) -> rusqlite::Result<Search> {
    let id: String = row.get("id")?;
    let sources_json: String = row.get("sources")?;
    let parameters_json: String = row.get("parameters")?;
    let outcomes_json: String = row.get("source_outcomes")?;
    let created_at: String = row.get("created_at")?;

    let outcomes: Vec<SourceOutcome> = serde_json::from_str(&outcomes_json).unwrap_or_default();

    Ok(Search {
        id: SearchId::from(id),
        query: row.get("query")?,
        sources: serde_json::from_str(&sources_json).unwrap_or_default(),
        parameters: serde_json::from_str(&parameters_json).unwrap_or_default(),
        source_outcomes: outcomes,
        total_candidates: row.get("total_candidates")?,
        total_papers: row.get("total_papers")?,
        // NOTE(review): judging by its name the helper falls back to the
        // current time when the stored text is not valid RFC 3339 - confirm.
        created_at: super::parse_rfc3339_or_now(&created_at),
    })
}
impl SearchRepository for SqliteSearchRepository {
    /// Inserts `search`, or updates the stored row that has the same id.
    ///
    /// On conflict every column except `id` and `created_at` is overwritten,
    /// so an existing row keeps its original creation timestamp.
    ///
    /// # Errors
    ///
    /// Fails when no connection can be obtained or the statement fails.
    fn save(&self, search: &Search) -> Result<(), CoreError> {
        let conn = self.db.conn()?;

        // A value that cannot be serialized is stored as an empty string.
        let sources_json = serde_json::to_string(&search.sources).unwrap_or_default();
        let parameters_json = serde_json::to_string(&search.parameters).unwrap_or_default();
        let outcomes_json = serde_json::to_string(&search.source_outcomes).unwrap_or_default();

        conn.execute(
            "INSERT INTO searches
(id, query, sources, parameters, source_outcomes,
total_candidates, total_papers, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
ON CONFLICT(id) DO UPDATE SET
query = excluded.query,
sources = excluded.sources,
parameters = excluded.parameters,
source_outcomes = excluded.source_outcomes,
total_candidates = excluded.total_candidates,
total_papers = excluded.total_papers",
            params![
                search.id.as_str(),
                search.query,
                sources_json,
                parameters_json,
                outcomes_json,
                search.total_candidates,
                search.total_papers,
                search.created_at.to_rfc3339(),
            ],
        )
        .map_err(DbError::Sqlite)?;
        Ok(())
    }

    /// Loads the search with id `search_id`, or `None` when no such row exists.
    ///
    /// # Errors
    ///
    /// Fails when no connection can be obtained, the query fails, or the row
    /// cannot be read.
    fn get(&self, search_id: &str) -> Result<Option<Search>, CoreError> {
        let conn = self.db.conn()?;
        let mut stmt = conn
            .prepare("SELECT * FROM searches WHERE id = ?1")
            .map_err(DbError::Sqlite)?;
        let found = stmt
            .query_row(params![search_id], row_to_search)
            .optional()
            .map_err(DbError::Sqlite)?;
        Ok(found)
    }

    /// Stores `results` inside a single transaction.
    ///
    /// A result whose `(search_id, paper_id, source)` triple already exists
    /// has its rank, score and raw metadata overwritten. If any row fails the
    /// transaction is dropped without being committed.
    ///
    /// # Errors
    ///
    /// Fails when no connection can be obtained or when preparing, executing
    /// or committing fails.
    fn save_results(&self, results: &[SearchResult]) -> Result<(), CoreError> {
        let mut conn = self.db.conn()?;
        let tx = conn.transaction().map_err(DbError::Sqlite)?;
        {
            // Compile the statement once and reuse it for every row. The
            // inner scope ends the statement's borrow of `tx` before the
            // commit consumes the transaction.
            let mut insert = tx
                .prepare(
                    "INSERT INTO search_results
(search_id, paper_id, source, rank, score, raw_metadata)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
ON CONFLICT(search_id, paper_id, source) DO UPDATE SET
rank = excluded.rank,
score = excluded.score,
raw_metadata = excluded.raw_metadata",
                )
                .map_err(DbError::Sqlite)?;
            for r in results {
                insert
                    .execute(params![
                        r.search_id.as_str(),
                        r.paper_id.as_str(),
                        r.source,
                        r.rank,
                        r.score,
                        serde_json::to_string(&r.raw_metadata).unwrap_or_default(),
                    ])
                    .map_err(DbError::Sqlite)?;
            }
        }
        tx.commit().map_err(DbError::Sqlite)?;
        Ok(())
    }

    /// Returns every stored result that belongs to `search_id`.
    ///
    /// Raw metadata that is not valid JSON is replaced by the default value.
    ///
    /// # Errors
    ///
    /// Fails when no connection can be obtained, the query fails, or any row
    /// cannot be read. Row errors are propagated rather than skipped so the
    /// returned list is never silently incomplete.
    fn get_results(&self, search_id: &str) -> Result<Vec<SearchResult>, CoreError> {
        let conn = self.db.conn()?;
        let mut stmt = conn
            .prepare("SELECT * FROM search_results WHERE search_id = ?1")
            .map_err(DbError::Sqlite)?;
        let results = stmt
            .query_map(params![search_id], |row| {
                let stored_search_id: String = row.get("search_id")?;
                let stored_paper_id: String = row.get("paper_id")?;
                let raw_json: String = row.get("raw_metadata")?;
                Ok(SearchResult {
                    search_id: SearchId::from(stored_search_id),
                    paper_id: PaperId::from(stored_paper_id),
                    source: row.get("source")?,
                    rank: row.get("rank")?,
                    score: row.get("score")?,
                    raw_metadata: serde_json::from_str(&raw_json).unwrap_or_default(),
                })
            })
            .map_err(DbError::Sqlite)?
            .collect::<Result<Vec<_>, _>>()
            .map_err(DbError::Sqlite)?;
        Ok(results)
    }

    /// Returns at most `limit` searches ordered by `created_at`, descending.
    ///
    /// # Errors
    ///
    /// Fails when no connection can be obtained, the query fails, or any row
    /// cannot be read.
    fn list_searches(&self, limit: i64) -> Result<Vec<Search>, CoreError> {
        let conn = self.db.conn()?;
        let mut stmt = conn
            .prepare("SELECT * FROM searches ORDER BY created_at DESC LIMIT ?1")
            .map_err(DbError::Sqlite)?;
        let searches = stmt
            .query_map(params![limit], row_to_search)
            .map_err(DbError::Sqlite)?
            .collect::<Result<Vec<_>, _>>()
            .map_err(DbError::Sqlite)?;
        Ok(searches)
    }

    /// Compares the papers found by two searches.
    ///
    /// Returns `(added, removed)`: `added` holds the paper ids present in
    /// search B but not in search A, `removed` those present in A but not in
    /// B. Both lists are sorted in ascending order.
    ///
    /// # Errors
    ///
    /// Fails when no connection can be obtained, a query fails, or any row
    /// cannot be read. A dropped row would corrupt the diff, so row errors
    /// are propagated.
    fn diff_searches(
        &self,
        search_id_a: &str,
        search_id_b: &str,
    ) -> Result<(Vec<String>, Vec<String>), CoreError> {
        let conn = self.db.conn()?;

        // Ordered sets make `difference` yield ids already sorted.
        let paper_ids_of =
            |search_id: &str| -> Result<std::collections::BTreeSet<String>, DbError> {
                let mut stmt = conn
                    .prepare("SELECT DISTINCT paper_id FROM search_results WHERE search_id = ?1")
                    .map_err(DbError::Sqlite)?;
                let ids = stmt
                    .query_map(params![search_id], |row| row.get::<_, String>(0))
                    .map_err(DbError::Sqlite)?
                    .collect::<Result<std::collections::BTreeSet<String>, _>>()
                    .map_err(DbError::Sqlite)?;
                Ok(ids)
            };

        let papers_a = paper_ids_of(search_id_a)?;
        let papers_b = paper_ids_of(search_id_b)?;

        let added: Vec<String> = papers_b.difference(&papers_a).cloned().collect();
        let removed: Vec<String> = papers_a.difference(&papers_b).cloned().collect();
        Ok((added, removed))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::{Database, SqlitePaperRepository};
    use scitadel_core::models::Paper;
    use scitadel_core::ports::PaperRepository;

    /// Opens a migrated in-memory database and builds both repositories on it.
    fn setup() -> (Database, SqliteSearchRepository, SqlitePaperRepository) {
        let db = Database::open_in_memory().unwrap();
        db.migrate().unwrap();
        let searches = SqliteSearchRepository::new(db.clone());
        let papers = SqlitePaperRepository::new(db.clone());
        (db, searches, papers)
    }

    /// A saved search can be read back by its id.
    #[test]
    fn test_save_and_get_search() {
        let (_, repo, _) = setup();
        let search = Search::new("test query");
        repo.save(&search).unwrap();

        let loaded = repo
            .get(search.id.as_str())
            .unwrap()
            .unwrap();
        assert_eq!(loaded.query, "test query");
    }

    /// The sanitizer removes FTS5 syntax and collapses whitespace.
    #[test]
    fn fts_sanitizer_strips_operators() {
        let cases = [
            (r#"GAN "stability""#, "GAN stability"),
            ("foo (bar) - baz", "foo bar baz"),
            (" ", ""),
            ("scope:field", "scope field"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_fts5_query(input), expected);
        }
    }

    /// A stored search is found through a query that shares terms with it.
    #[test]
    fn fts_find_similar_roundtrip() {
        let (_, repo, _) = setup();
        let fixtures = [
            ("s-a", "generative adversarial networks stability"),
            ("s-b", "attention is all you need transformers"),
            ("s-c", "retrieval augmented generation"),
        ];
        for (id, text) in fixtures {
            let mut search = Search::new(text);
            search.id = SearchId::from(id);
            repo.save(&search).unwrap();
        }

        let hits = repo.find_similar("generative networks", 10).unwrap();
        let found_ids: Vec<&str> = hits.iter().map(|(s, _)| s.id.as_str()).collect();
        assert!(
            found_ids.contains(&"s-a"),
            "expected GAN search to be found; got {:?}",
            found_ids
        );
    }

    /// A query made only of stripped characters yields no hits.
    #[test]
    fn fts_find_similar_empty_query() {
        let (_, repo, _) = setup();
        repo.save(&Search::new("something")).unwrap();

        let hits = repo.find_similar("()(", 10).unwrap();
        assert!(hits.is_empty());
    }

    /// A saved result can be read back through its search id.
    #[test]
    fn test_save_and_get_results() {
        let (_, search_repo, paper_repo) = setup();

        let paper = Paper::new("Test Paper");
        paper_repo.save(&paper).unwrap();
        let search = Search::new("test");
        search_repo.save(&search).unwrap();

        let stored = SearchResult {
            search_id: search.id.clone(),
            paper_id: paper.id.clone(),
            source: "pubmed".to_string(),
            rank: Some(1),
            score: Some(0.95),
            raw_metadata: serde_json::Value::Null,
        };
        search_repo.save_results(&[stored]).unwrap();

        let loaded = search_repo.get_results(search.id.as_str()).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].source, "pubmed");
    }
}