1mod annotations;
2mod assessments;
3mod citations;
4mod migrations;
5mod paper_aliases;
6mod paper_state;
7mod paper_tags;
8mod papers;
9mod questions;
10mod searches;
11mod shortlist;
12mod tui_state;
13
14pub use annotations::{SqliteAnnotationRepository, resolve_anchor};
15
16pub use assessments::SqliteAssessmentRepository;
17pub use citations::SqliteCitationRepository;
18pub use migrations::run_migrations;
19pub use paper_aliases::{SOURCE_BIBTEX_IMPORT, SOURCE_REKEY, SqlitePaperAliasRepository};
20pub use paper_state::{PaperState, SqlitePaperStateRepository};
21pub use paper_tags::{SqlitePaperTagRepository, TAG_SOURCE_BIBTEX_IMPORT};
22pub use papers::SqlitePaperRepository;
23pub use questions::SqliteQuestionRepository;
24pub use rusqlite::Transaction as SqliteTransaction;
28pub use searches::SqliteSearchRepository;
29pub use shortlist::SqliteShortlistRepository;
30pub use tui_state::{SqliteTuiStateRepository, TuiState};
31
32use r2d2::Pool;
33use r2d2_sqlite::SqliteConnectionManager;
34use rusqlite::functions::FunctionFlags;
35use std::path::Path;
36
37use crate::error::DbError;
38
39fn register_unicode_lower(conn: &rusqlite::Connection) -> rusqlite::Result<()> {
45 conn.create_scalar_function(
46 "unicode_lower",
47 1,
48 FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
49 |ctx| {
50 let s = ctx.get::<Option<String>>(0)?;
51 Ok(s.map(|v| v.to_lowercase()))
52 },
53 )
54}
55
56pub(crate) fn parse_rfc3339_or_now(s: &str) -> chrono::DateTime<chrono::Utc> {
58 chrono::DateTime::parse_from_rfc3339(s)
59 .map_or_else(|_| chrono::Utc::now(), |dt| dt.with_timezone(&chrono::Utc))
60}
61
62#[derive(Clone)]
64pub struct Database {
65 pool: Pool<SqliteConnectionManager>,
66}
67
68impl Database {
69 pub fn open(path: &Path) -> Result<Self, DbError> {
71 if let Some(parent) = path.parent() {
72 std::fs::create_dir_all(parent)
73 .map_err(|e| DbError::Migration(format!("failed to create db directory: {e}")))?;
74 }
75
76 let manager = SqliteConnectionManager::file(path).with_init(|conn| {
77 conn.execute_batch(
83 "PRAGMA journal_mode=WAL;
84 PRAGMA synchronous=NORMAL;
85 PRAGMA foreign_keys=ON;
86 PRAGMA busy_timeout=5000;",
87 )?;
88 register_unicode_lower(conn)?;
89 Ok(())
90 });
91
92 let pool = Pool::builder().max_size(4).build(manager)?;
93
94 Ok(Self { pool })
95 }
96
97 pub fn open_in_memory() -> Result<Self, DbError> {
99 let manager = SqliteConnectionManager::memory().with_init(|conn| {
100 conn.execute_batch(
101 "PRAGMA journal_mode=WAL;
102 PRAGMA foreign_keys=ON;",
103 )?;
104 register_unicode_lower(conn)?;
105 Ok(())
106 });
107
108 let pool = Pool::builder().max_size(1).build(manager)?;
109
110 Ok(Self { pool })
111 }
112
113 pub fn migrate(&self) -> Result<(), DbError> {
120 let conn = self.pool.get()?;
121 run_migrations(&conn)?;
122 drop(conn);
123 self.backfill_bibtex_keys()?;
124 Ok(())
125 }
126
127 fn backfill_bibtex_keys(&self) -> Result<(), DbError> {
133 use scitadel_core::bibtex_key::assign_keys;
134 use std::collections::HashSet;
135
136 let conn = self.pool.get()?;
137
138 let mut taken: HashSet<String> = conn
140 .prepare("SELECT bibtex_key FROM papers WHERE bibtex_key IS NOT NULL")?
141 .query_map([], |row| row.get::<_, String>(0))?
142 .filter_map(Result::ok)
143 .collect();
144
145 let mut stmt =
148 conn.prepare("SELECT id, title, authors, year FROM papers WHERE bibtex_key IS NULL")?;
149 let rows: Vec<(String, String, String, Option<i32>)> = stmt
150 .query_map([], |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?)))?
151 .filter_map(Result::ok)
152 .collect();
153 drop(stmt);
154
155 if rows.is_empty() {
156 return Ok(());
157 }
158
159 let papers: Vec<scitadel_core::models::Paper> = rows
162 .iter()
163 .map(|(id, title, authors_json, year)| {
164 let mut p = scitadel_core::models::Paper::new(title);
165 p.id = scitadel_core::models::PaperId::from(id.as_str());
166 p.authors = serde_json::from_str(authors_json).unwrap_or_default();
167 p.year = *year;
168 p
169 })
170 .collect();
171
172 let keys = assign_keys(&papers, &mut taken);
173 for (paper, key) in papers.iter().zip(keys) {
174 conn.execute(
175 "UPDATE papers SET bibtex_key = ?1 WHERE id = ?2",
176 rusqlite::params![key, paper.id.as_str()],
177 )?;
178 }
179 Ok(())
180 }
181
182 pub fn conn(&self) -> Result<r2d2::PooledConnection<SqliteConnectionManager>, DbError> {
184 Ok(self.pool.get()?)
185 }
186
187 pub fn repositories(
189 &self,
190 ) -> (
191 SqlitePaperRepository,
192 SqliteSearchRepository,
193 SqliteQuestionRepository,
194 SqliteAssessmentRepository,
195 SqliteCitationRepository,
196 ) {
197 let db = self.clone();
198 (
199 SqlitePaperRepository::new(db.clone()),
200 SqliteSearchRepository::new(db.clone()),
201 SqliteQuestionRepository::new(db.clone()),
202 SqliteAssessmentRepository::new(db.clone()),
203 SqliteCitationRepository::new(db),
204 )
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use scitadel_core::models::Paper;
    use scitadel_core::ports::PaperRepository;

    /// Reads the `bibtex_key` column of every paper, ordered by id.
    ///
    /// Rows whose key cannot be read as a `String` (for example a NULL key) are
    /// dropped, so a shorter result than the number of papers means some are unkeyed.
    fn bibtex_keys_by_id(db: &Database) -> Vec<String> {
        let conn = db.conn().unwrap();
        let mut stmt = conn
            .prepare("SELECT bibtex_key FROM papers ORDER BY id")
            .unwrap();
        let keys: Vec<String> = stmt
            .query_map([], |r| r.get::<_, String>(0))
            .unwrap()
            .filter_map(Result::ok)
            .collect();
        keys
    }

    /// Two independent `Database` handles on the same file stand in for two processes:
    /// a paper saved through one must be listed by the other on its next read.
    #[test]
    fn cross_process_write_visible_within_one_redraw() {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("scitadel.db");

        // Handle A owns the schema: it is the one that migrates.
        let reader_db = Database::open(&file).unwrap();
        reader_db.migrate().unwrap();
        let (reader_papers, ..) = reader_db.repositories();

        // Handle B opens the same file through its own pool.
        let writer_db = Database::open(&file).unwrap();
        let (writer_papers, ..) = writer_db.repositories();

        assert!(reader_papers.list_all(10, 0).unwrap().is_empty());

        let paper = Paper::new("MCP-side write");
        writer_papers.save(&paper).unwrap();

        let listed = reader_papers.list_all(10, 0).unwrap();
        assert_eq!(listed.len(), 1, "TUI process must see MCP process's write");
        assert_eq!(listed[0].title, "MCP-side write");
    }

    /// `migrate` gives every paper with a NULL `bibtex_key` a unique key, and running
    /// it again leaves the keys untouched.
    #[test]
    fn migrate_backfills_bibtex_keys() {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("backfill.db");
        let db = Database::open(&file).unwrap();
        db.migrate().unwrap();

        let fixtures = [
            (
                "p-1",
                "Attention Is All You Need",
                r#"["Vaswani, A."]"#,
                2017,
            ),
            ("p-2", "Deep Residual Learning", r#"["Kaiming He"]"#, 2015),
            ("p-3", "Quantum Computing", r#"["Müller, Hans"]"#, 2023),
        ];

        {
            // Scoped so the connection goes back to the pool before `migrate` runs.
            let conn = db.conn().unwrap();
            for (id, title, authors, year) in fixtures {
                conn.execute(
                    "INSERT INTO papers (id, title, authors, year, created_at, updated_at)
                     VALUES (?1, ?2, ?3, ?4, datetime('now'), datetime('now'))",
                    rusqlite::params![id, title, authors, year],
                )
                .unwrap();
            }
            // Force the "legacy row without a key" state regardless of insert defaults.
            conn.execute("UPDATE papers SET bibtex_key = NULL", [])
                .unwrap();
        }

        db.migrate().unwrap();

        let keys = bibtex_keys_by_id(&db);
        assert_eq!(keys.len(), 3, "every paper got a key");
        // Only the author+year prefix is pinned; the title-derived suffix is left open.
        assert!(
            keys.iter().any(|k| k.starts_with("vaswani2017")),
            "got: {keys:?}"
        );
        let distinct: std::collections::HashSet<_> = keys.iter().collect();
        assert_eq!(distinct.len(), keys.len());

        db.migrate().unwrap();
        let keys2 = bibtex_keys_by_id(&db);
        assert_eq!(keys, keys2, "re-migrate is idempotent");
    }

    /// A file-backed database reports WAL as its journal mode.
    #[test]
    fn pragma_journal_mode_is_wal_on_disk() {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("pragma.db");
        let db = Database::open(&file).unwrap();
        let conn = db.conn().unwrap();
        let mode: String = conn
            .query_row("PRAGMA journal_mode", [], |r| r.get(0))
            .unwrap();
        assert_eq!(mode.to_lowercase(), "wal");
    }
}