1mod annotations;
2mod assessments;
3mod citations;
4mod migrations;
5mod paper_state;
6mod papers;
7mod questions;
8mod searches;
9mod shortlist;
10mod tui_state;
11
12pub use annotations::{SqliteAnnotationRepository, resolve_anchor};
13pub use assessments::SqliteAssessmentRepository;
14pub use citations::SqliteCitationRepository;
15pub use migrations::run_migrations;
16pub use paper_state::{PaperState, SqlitePaperStateRepository};
17pub use papers::SqlitePaperRepository;
18pub use questions::SqliteQuestionRepository;
19pub use searches::SqliteSearchRepository;
20pub use shortlist::SqliteShortlistRepository;
21pub use tui_state::{SqliteTuiStateRepository, TuiState};
22
23use r2d2::Pool;
24use r2d2_sqlite::SqliteConnectionManager;
25use std::path::Path;
26
27use crate::error::DbError;
28
29pub(crate) fn parse_rfc3339_or_now(s: &str) -> chrono::DateTime<chrono::Utc> {
31 chrono::DateTime::parse_from_rfc3339(s)
32 .map_or_else(|_| chrono::Utc::now(), |dt| dt.with_timezone(&chrono::Utc))
33}
34
/// Handle to the SQLite database, backed by an r2d2 pool of SQLite connections.
///
/// `Clone` is derived so that a copy of the handle can be given to each
/// repository (see `Database::repositories`).
#[derive(Clone)]
pub struct Database {
    /// Connection pool from which every database connection is checked out.
    pool: Pool<SqliteConnectionManager>,
}
40
41impl Database {
42 pub fn open(path: &Path) -> Result<Self, DbError> {
44 if let Some(parent) = path.parent() {
45 std::fs::create_dir_all(parent)
46 .map_err(|e| DbError::Migration(format!("failed to create db directory: {e}")))?;
47 }
48
49 let manager = SqliteConnectionManager::file(path).with_init(|conn| {
50 conn.execute_batch(
56 "PRAGMA journal_mode=WAL;
57 PRAGMA synchronous=NORMAL;
58 PRAGMA foreign_keys=ON;
59 PRAGMA busy_timeout=5000;",
60 )?;
61 Ok(())
62 });
63
64 let pool = Pool::builder().max_size(4).build(manager)?;
65
66 Ok(Self { pool })
67 }
68
69 pub fn open_in_memory() -> Result<Self, DbError> {
71 let manager = SqliteConnectionManager::memory().with_init(|conn| {
72 conn.execute_batch(
73 "PRAGMA journal_mode=WAL;
74 PRAGMA foreign_keys=ON;",
75 )?;
76 Ok(())
77 });
78
79 let pool = Pool::builder().max_size(1).build(manager)?;
80
81 Ok(Self { pool })
82 }
83
84 pub fn migrate(&self) -> Result<(), DbError> {
91 let conn = self.pool.get()?;
92 run_migrations(&conn)?;
93 drop(conn);
94 self.backfill_bibtex_keys()?;
95 Ok(())
96 }
97
98 fn backfill_bibtex_keys(&self) -> Result<(), DbError> {
104 use scitadel_core::bibtex_key::assign_keys;
105 use std::collections::HashSet;
106
107 let conn = self.pool.get()?;
108
109 let mut taken: HashSet<String> = conn
111 .prepare("SELECT bibtex_key FROM papers WHERE bibtex_key IS NOT NULL")?
112 .query_map([], |row| row.get::<_, String>(0))?
113 .filter_map(Result::ok)
114 .collect();
115
116 let mut stmt =
119 conn.prepare("SELECT id, title, authors, year FROM papers WHERE bibtex_key IS NULL")?;
120 let rows: Vec<(String, String, String, Option<i32>)> = stmt
121 .query_map([], |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?)))?
122 .filter_map(Result::ok)
123 .collect();
124 drop(stmt);
125
126 if rows.is_empty() {
127 return Ok(());
128 }
129
130 let papers: Vec<scitadel_core::models::Paper> = rows
133 .iter()
134 .map(|(id, title, authors_json, year)| {
135 let mut p = scitadel_core::models::Paper::new(title);
136 p.id = scitadel_core::models::PaperId::from(id.as_str());
137 p.authors = serde_json::from_str(authors_json).unwrap_or_default();
138 p.year = *year;
139 p
140 })
141 .collect();
142
143 let keys = assign_keys(&papers, &mut taken);
144 for (paper, key) in papers.iter().zip(keys) {
145 conn.execute(
146 "UPDATE papers SET bibtex_key = ?1 WHERE id = ?2",
147 rusqlite::params![key, paper.id.as_str()],
148 )?;
149 }
150 Ok(())
151 }
152
153 pub fn conn(&self) -> Result<r2d2::PooledConnection<SqliteConnectionManager>, DbError> {
155 Ok(self.pool.get()?)
156 }
157
158 pub fn repositories(
160 &self,
161 ) -> (
162 SqlitePaperRepository,
163 SqliteSearchRepository,
164 SqliteQuestionRepository,
165 SqliteAssessmentRepository,
166 SqliteCitationRepository,
167 ) {
168 let db = self.clone();
169 (
170 SqlitePaperRepository::new(db.clone()),
171 SqliteSearchRepository::new(db.clone()),
172 SqliteQuestionRepository::new(db.clone()),
173 SqliteAssessmentRepository::new(db.clone()),
174 SqliteCitationRepository::new(db),
175 )
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use scitadel_core::models::Paper;
    use scitadel_core::ports::PaperRepository;

    /// Reads `bibtex_key` for all papers ordered by id, skipping rows whose
    /// key cannot be read as a `String` (for example NULL keys).
    fn stored_bibtex_keys(db: &Database) -> Vec<String> {
        let conn = db.conn().unwrap();
        let mut stmt = conn
            .prepare("SELECT bibtex_key FROM papers ORDER BY id")
            .unwrap();
        let mut keys = Vec::new();
        keys.extend(
            stmt.query_map([], |r| r.get::<_, String>(0))
                .unwrap()
                .flatten(),
        );
        keys
    }

    /// A write made through one pool must be visible through a second,
    /// independent pool opened on the same database file.
    #[test]
    fn cross_process_write_visible_within_one_redraw() {
        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("scitadel.db");

        // Two separate pools over one file stand in for the two processes.
        let tui_db = Database::open(&db_file).unwrap();
        tui_db.migrate().unwrap();
        let (tui_papers, ..) = tui_db.repositories();

        let mcp_db = Database::open(&db_file).unwrap();
        let (mcp_papers, ..) = mcp_db.repositories();

        assert!(tui_papers.list_all(10, 0).unwrap().is_empty());

        mcp_papers.save(&Paper::new("MCP-side write")).unwrap();

        let seen = tui_papers.list_all(10, 0).unwrap();
        assert_eq!(seen.len(), 1, "TUI process must see MCP process's write");
        assert_eq!(seen[0].title, "MCP-side write");
    }

    /// `migrate` must give every key-less paper a unique BibTeX key and must
    /// leave those keys unchanged when run again.
    #[test]
    fn migrate_backfills_bibtex_keys() {
        let tmp = tempfile::tempdir().unwrap();
        let db = Database::open(&tmp.path().join("backfill.db")).unwrap();
        db.migrate().unwrap();

        let fixtures = [
            (
                "p-1",
                "Attention Is All You Need",
                r#"["Vaswani, A."]"#,
                2017,
            ),
            ("p-2", "Deep Residual Learning", r#"["Kaiming He"]"#, 2015),
            ("p-3", "Quantum Computing", r#"["Müller, Hans"]"#, 2023),
        ];
        {
            // Scoped so the connection is back in the pool before migrating.
            let conn = db.conn().unwrap();
            for (id, title, authors, year) in fixtures {
                conn.execute(
                    "INSERT INTO papers (id, title, authors, year, created_at, updated_at)
                     VALUES (?1, ?2, ?3, ?4, datetime('now'), datetime('now'))",
                    rusqlite::params![id, title, authors, year],
                )
                .unwrap();
            }
            conn.execute("UPDATE papers SET bibtex_key = NULL", [])
                .unwrap();
        }

        db.migrate().unwrap();

        let keys = stored_bibtex_keys(&db);
        assert_eq!(keys.len(), 3, "every paper got a key");
        assert!(
            keys.iter().any(|k| k.starts_with("vaswani2017")),
            "got: {keys:?}"
        );
        let distinct: std::collections::HashSet<&String> = keys.iter().collect();
        assert_eq!(distinct.len(), keys.len());

        db.migrate().unwrap();
        let keys_after_rerun = stored_bibtex_keys(&db);
        assert_eq!(keys, keys_after_rerun, "re-migrate is idempotent");
    }

    /// An on-disk database must report WAL as its journal mode.
    #[test]
    fn pragma_journal_mode_is_wal_on_disk() {
        let tmp = tempfile::tempdir().unwrap();
        let db = Database::open(&tmp.path().join("pragma.db")).unwrap();
        let journal_mode: String = db
            .conn()
            .unwrap()
            .query_row("PRAGMA journal_mode", [], |row| row.get(0))
            .unwrap();
        assert_eq!(journal_mode.to_lowercase(), "wal");
    }
}