1use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::error::StorageError;
9use khive_storage::note::Note;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::NoteStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18 StorageError::driver(StorageCapability::Notes, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22 StorageError::driver(StorageCapability::Notes, op, e)
23}
24
25pub struct SqlNoteStore {
30 pool: Arc<ConnectionPool>,
31 is_file_backed: bool,
32}
33
34impl SqlNoteStore {
35 pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37 Self {
38 pool,
39 is_file_backed,
40 }
41 }
42
43 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44 let config = self.pool.config();
45 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46 operation: "note_reader".into(),
47 message: "in-memory databases do not support standalone connections".into(),
48 })?;
49
50 let conn = rusqlite::Connection::open_with_flags(
51 path,
52 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55 )
56 .map_err(|e| map_err(e, "open_note_reader"))?;
57
58 conn.busy_timeout(config.busy_timeout)
59 .map_err(|e| map_err(e, "open_note_reader"))?;
60 conn.pragma_update(None, "foreign_keys", "ON")
61 .map_err(|e| map_err(e, "open_note_reader"))?;
62 conn.pragma_update(None, "synchronous", "NORMAL")
63 .map_err(|e| map_err(e, "open_note_reader"))?;
64
65 Ok(conn)
66 }
67
68 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
70 where
71 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
72 R: Send + 'static,
73 {
74 let pool = Arc::clone(&self.pool);
75 tokio::task::spawn_blocking(move || {
76 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
77 f(guard.conn()).map_err(|e| map_err(e, op))
78 })
79 .await
80 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
81 }
82
83 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
84 where
85 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
86 R: Send + 'static,
87 {
88 if self.is_file_backed {
89 let conn = self.open_standalone_reader()?;
90 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
91 .await
92 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
93 } else {
94 let pool = Arc::clone(&self.pool);
95 tokio::task::spawn_blocking(move || {
96 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
97 f(guard.conn()).map_err(|e| map_err(e, op))
98 })
99 .await
100 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
101 }
102 }
103}
104
105fn read_note(row: &rusqlite::Row<'_>) -> Result<Note, rusqlite::Error> {
110 let id_str: String = row.get(0)?;
111 let namespace: String = row.get(1)?;
112 let kind: String = row.get(2)?;
113 let status: String = row.get(3)?;
114 let name: Option<String> = row.get(4)?;
115 let content: String = row.get(5)?;
116 let salience: Option<f64> = row.get(6)?;
117 let decay_factor: Option<f64> = row.get(7)?;
118 let expires_at: Option<i64> = row.get(8)?;
119 let properties_str: Option<String> = row.get(9)?;
120 let created_at: i64 = row.get(10)?;
121 let updated_at: i64 = row.get(11)?;
122 let deleted_at: Option<i64> = row.get(12)?;
123
124 let id = parse_uuid(&id_str)?;
125
126 let properties = properties_str
127 .map(|s| {
128 serde_json::from_str(&s).map_err(|e| {
129 rusqlite::Error::FromSqlConversionFailure(
130 9,
131 rusqlite::types::Type::Text,
132 Box::new(e),
133 )
134 })
135 })
136 .transpose()?;
137
138 Ok(Note {
139 id,
140 namespace,
141 kind,
142 status,
143 name,
144 content,
145 salience,
146 decay_factor,
147 expires_at,
148 properties,
149 created_at,
150 updated_at,
151 deleted_at,
152 })
153}
154
155fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
156 Uuid::parse_str(s).map_err(|e| {
157 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
158 })
159}
160
161fn build_note_where(
162 namespace: &str,
163 kind: Option<&str>,
164) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
165 let mut conditions: Vec<String> = vec![
166 "namespace = ?1".to_string(),
167 "deleted_at IS NULL".to_string(),
168 ];
169 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
170
171 if let Some(k) = kind {
172 params.push(Box::new(k.to_string()));
173 conditions.push(format!("kind = ?{}", params.len()));
174 }
175
176 let clause = format!(" WHERE {}", conditions.join(" AND "));
177 (clause, params)
178}
179
180#[async_trait]
185impl NoteStore for SqlNoteStore {
186 async fn upsert_note(&self, note: Note) -> Result<(), StorageError> {
187 let namespace = note.namespace.clone();
188 let id_str = note.id.to_string();
189 let kind_str = note.kind.to_string();
190 let status_str = note.status.clone();
191 let properties_str = note
192 .properties
193 .as_ref()
194 .map(|v| serde_json::to_string(v).unwrap_or_default());
195
196 self.with_writer("upsert_note", move |conn| {
197 conn.execute(
198 "INSERT OR REPLACE INTO notes \
199 (id, namespace, kind, status, name, content, salience, decay_factor, expires_at, \
200 properties, created_at, updated_at, deleted_at) \
201 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
202 rusqlite::params![
203 id_str,
204 namespace,
205 kind_str,
206 status_str,
207 note.name,
208 note.content,
209 note.salience,
210 note.decay_factor,
211 note.expires_at,
212 properties_str,
213 note.created_at,
214 note.updated_at,
215 note.deleted_at,
216 ],
217 )?;
218 Ok(())
219 })
220 .await
221 }
222
223 async fn upsert_notes(&self, notes: Vec<Note>) -> Result<BatchWriteSummary, StorageError> {
224 let attempted = notes.len() as u64;
225
226 self.with_writer("upsert_notes", move |conn| {
227 conn.execute_batch("BEGIN IMMEDIATE")?;
228 let mut affected = 0u64;
229 let mut failed = 0u64;
230 let mut first_error = String::new();
231
232 for note in ¬es {
233 let id_str = note.id.to_string();
234 let kind_str = note.kind.to_string();
235 let status_str = note.status.clone();
236 let properties_str = note
237 .properties
238 .as_ref()
239 .map(|v| serde_json::to_string(v).unwrap_or_default());
240
241 match conn.execute(
242 "INSERT OR REPLACE INTO notes \
243 (id, namespace, kind, status, name, content, salience, decay_factor, expires_at, \
244 properties, created_at, updated_at, deleted_at) \
245 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
246 rusqlite::params![
247 id_str,
248 ¬e.namespace,
249 kind_str,
250 status_str,
251 ¬e.name,
252 note.content,
253 note.salience,
254 note.decay_factor,
255 note.expires_at,
256 properties_str,
257 note.created_at,
258 note.updated_at,
259 note.deleted_at,
260 ],
261 ) {
262 Ok(_) => affected += 1,
263 Err(e) => {
264 if first_error.is_empty() {
265 first_error = e.to_string();
266 }
267 failed += 1;
268 }
269 }
270 }
271
272 if let Err(e) = conn.execute_batch("COMMIT") {
273 let _ = conn.execute_batch("ROLLBACK");
274 return Err(e);
275 }
276 Ok(BatchWriteSummary {
277 attempted,
278 affected,
279 failed,
280 first_error,
281 })
282 })
283 .await
284 }
285
286 async fn get_note(&self, id: Uuid) -> Result<Option<Note>, StorageError> {
287 let id_str = id.to_string();
288
289 self.with_reader("get_note", move |conn| {
290 let mut stmt = conn.prepare(
291 "SELECT id, namespace, kind, status, name, content, salience, decay_factor, expires_at, \
292 properties, created_at, updated_at, deleted_at \
293 FROM notes WHERE id = ?1 AND deleted_at IS NULL",
294 )?;
295 let mut rows = stmt.query(rusqlite::params![id_str])?;
296 match rows.next()? {
297 Some(row) => Ok(Some(read_note(row)?)),
298 None => Ok(None),
299 }
300 })
301 .await
302 }
303
304 async fn get_notes_batch(&self, ids: &[Uuid]) -> Result<Vec<Note>, StorageError> {
305 if ids.is_empty() {
306 return Ok(vec![]);
307 }
308 let id_strings: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
309
310 self.with_reader("get_notes_batch", move |conn| {
311 let placeholders: String = (1..=id_strings.len())
312 .map(|i| format!("?{i}"))
313 .collect::<Vec<_>>()
314 .join(", ");
315 let sql = format!(
316 "SELECT id, namespace, kind, status, name, content, salience, decay_factor, expires_at, \
317 properties, created_at, updated_at, deleted_at \
318 FROM notes WHERE id IN ({placeholders}) AND deleted_at IS NULL"
319 );
320 let mut stmt = conn.prepare(&sql)?;
321 let params: Vec<&dyn rusqlite::types::ToSql> = id_strings
322 .iter()
323 .map(|s| s as &dyn rusqlite::types::ToSql)
324 .collect();
325 let rows = stmt.query_map(params.as_slice(), read_note)?;
326 let mut out = Vec::new();
327 for row in rows {
328 out.push(row?);
329 }
330 Ok(out)
331 })
332 .await
333 }
334
335 async fn delete_note(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
336 let id_str = id.to_string();
337
338 match mode {
339 DeleteMode::Soft => {
340 self.with_writer("delete_note_soft", move |conn| {
341 let now = chrono::Utc::now().timestamp_micros();
342 let deleted = conn.execute(
343 "UPDATE notes SET status = 'deleted', deleted_at = ?1 \
344 WHERE id = ?2 AND deleted_at IS NULL",
345 rusqlite::params![now, id_str],
346 )?;
347 Ok(deleted > 0)
348 })
349 .await
350 }
351 DeleteMode::Hard => {
352 self.with_writer("delete_note_hard", move |conn| {
353 let deleted =
354 conn.execute("DELETE FROM notes WHERE id = ?1", rusqlite::params![id_str])?;
355 Ok(deleted > 0)
356 })
357 .await
358 }
359 }
360 }
361
362 async fn query_notes(
363 &self,
364 namespace: &str,
365 kind: Option<&str>,
366 page: PageRequest,
367 ) -> Result<Page<Note>, StorageError> {
368 let namespace = namespace.to_string();
369 let kind = kind.map(|k| k.to_string());
370
371 self.with_reader("query_notes", move |conn| {
372 let (count_sql, count_params) = build_note_where(&namespace, kind.as_deref());
373 let total: i64 = {
374 let sql = format!("SELECT COUNT(*) FROM notes{}", count_sql);
375 let mut stmt = conn.prepare(&sql)?;
376 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
377 count_params.iter().map(|p| p.as_ref()).collect();
378 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
379 };
380
381 let (where_sql, mut data_params) = build_note_where(&namespace, kind.as_deref());
382 data_params.push(Box::new(page.limit as i64));
383 data_params.push(Box::new(page.offset as i64));
384
385 let limit_idx = data_params.len() - 1;
386 let offset_idx = data_params.len();
387
388 let data_sql = format!(
389 "SELECT id, namespace, kind, status, name, content, salience, decay_factor, expires_at, \
390 properties, created_at, updated_at, deleted_at \
391 FROM notes{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
392 where_sql, limit_idx, offset_idx,
393 );
394
395 let mut stmt = conn.prepare(&data_sql)?;
396 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
397 data_params.iter().map(|p| p.as_ref()).collect();
398 let rows = stmt.query_map(param_refs.as_slice(), read_note)?;
399
400 let mut items = Vec::new();
401 for row in rows {
402 items.push(row?);
403 }
404
405 Ok(Page {
406 items,
407 total: Some(total as u64),
408 })
409 })
410 .await
411 }
412
413 async fn count_notes(&self, namespace: &str, kind: Option<&str>) -> Result<u64, StorageError> {
414 let namespace = namespace.to_string();
415 let kind = kind.map(|k| k.to_string());
416
417 self.with_reader("count_notes", move |conn| {
418 let (where_sql, params) = build_note_where(&namespace, kind.as_deref());
419 let sql = format!("SELECT COUNT(*) FROM notes{}", where_sql);
420 let mut stmt = conn.prepare(&sql)?;
421 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
422 params.iter().map(|p| p.as_ref()).collect();
423 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
424 Ok(count as u64)
425 })
426 .await
427 }
428}
429
/// Schema for the `notes` table and its indexes, run as one batch.
///
/// Every statement uses `IF NOT EXISTS`, so applying it repeatedly is safe.
/// The pieces are concatenated with no separators, which yields the same
/// compact SQL text as a line-continued literal.
const NOTES_DDL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS notes (",
    "id TEXT PRIMARY KEY,",
    "namespace TEXT NOT NULL,",
    "kind TEXT NOT NULL,",
    "status TEXT NOT NULL DEFAULT 'active',",
    "name TEXT,",
    "content TEXT NOT NULL DEFAULT '',",
    "salience REAL,",
    "decay_factor REAL,",
    "expires_at INTEGER,",
    "properties TEXT,",
    "created_at INTEGER NOT NULL,",
    "updated_at INTEGER NOT NULL,",
    "deleted_at INTEGER",
    ");",
    "CREATE INDEX IF NOT EXISTS idx_notes_namespace ON notes(namespace);",
    "CREATE INDEX IF NOT EXISTS idx_notes_kind ON notes(namespace, kind);",
    "CREATE INDEX IF NOT EXISTS idx_notes_created ON notes(created_at DESC);",
);
454
455pub(crate) fn ensure_notes_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
456 conn.execute_batch(NOTES_DDL)
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// Builds an in-memory pool with the notes schema already applied.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        // The writer guard is a temporary, released at the end of this statement.
        pool.writer()
            .unwrap()
            .conn()
            .execute_batch(NOTES_DDL)
            .unwrap();
        pool
    }

    /// In-memory store (`is_file_backed = false`, so reads use the pool).
    fn setup_memory_store() -> SqlNoteStore {
        SqlNoteStore::new(setup_pool(), false)
    }

    fn make_note(namespace: &str, kind: &str, content: &str) -> Note {
        Note::new(namespace, kind, content)
    }

    /// Inserts one note, deletes it with `mode`, and checks that the delete
    /// reported success and that `get_note` no longer returns it.
    async fn assert_delete_hides_note(mode: DeleteMode, content: &str) {
        let store = setup_memory_store();
        let note = make_note("default", "observation", content);
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        assert!(store.delete_note(id, mode).await.unwrap());
        assert!(store.get_note(id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_upsert_and_get_note() {
        let store = setup_memory_store();
        let note = make_note("default", "observation", "Hello world");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let fetched = store
            .get_note(id)
            .await
            .unwrap()
            .expect("note should be readable after upsert");
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.content, "Hello world");
        assert_eq!(fetched.kind, "observation");
    }

    #[tokio::test]
    async fn test_kind_roundtrip_all_variants() {
        const KINDS: [&str; 5] = ["observation", "insight", "question", "decision", "reference"];

        let store = setup_memory_store();
        for expected in KINDS {
            let note = make_note("default", expected, "content");
            let id = note.id;
            store.upsert_note(note).await.unwrap();

            let roundtripped = store.get_note(id).await.unwrap().unwrap();
            assert_eq!(roundtripped.kind, expected);
        }
    }

    #[tokio::test]
    async fn test_soft_delete() {
        assert_delete_hides_note(DeleteMode::Soft, "to be deleted").await;
    }

    #[tokio::test]
    async fn test_hard_delete() {
        assert_delete_hides_note(DeleteMode::Hard, "to be hard deleted").await;
    }

    #[tokio::test]
    async fn test_namespace_isolation() {
        let store = setup_memory_store();

        for _ in 0..3 {
            store
                .upsert_note(make_note("ns1", "observation", "content"))
                .await
                .unwrap();
        }
        store
            .upsert_note(make_note("ns2", "observation", "other"))
            .await
            .unwrap();

        for (namespace, expected) in [("ns1", 3), ("ns2", 1)] {
            assert_eq!(store.count_notes(namespace, None).await.unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn test_query_and_count_use_caller_namespace() {
        let store = setup_memory_store();

        store
            .upsert_note(make_note("ns_a", "observation", "A"))
            .await
            .unwrap();
        store
            .upsert_note(make_note("ns_b", "insight", "B"))
            .await
            .unwrap();

        for (namespace, content) in [("ns_a", "A"), ("ns_b", "B")] {
            let page = store
                .query_notes(namespace, None, PageRequest::default())
                .await
                .unwrap();
            assert_eq!(page.items.len(), 1);
            assert_eq!(page.items[0].content, content);
            assert_eq!(page.total, Some(1));
        }

        for namespace in ["ns_a", "ns_b"] {
            assert_eq!(store.count_notes(namespace, None).await.unwrap(), 1);
        }
    }

    #[tokio::test]
    async fn test_soft_delete_sets_status_deleted() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);
        let note = make_note("default", "observation", "to delete");
        let id = note.id;
        store.upsert_note(note).await.unwrap();
        assert!(store.delete_note(id, DeleteMode::Soft).await.unwrap());

        // get_note hides soft-deleted rows, so inspect the raw row directly.
        let writer = pool.writer().unwrap();
        let status: String = writer
            .conn()
            .query_row(
                "SELECT status FROM notes WHERE id = ?1",
                [id.to_string()],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(status, "deleted");
    }

    #[tokio::test]
    async fn test_note_status_field_roundtrip() {
        let store = setup_memory_store();
        let note = make_note("default", "observation", "status test");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let fetched = store.get_note(id).await.unwrap().unwrap();
        assert_eq!(fetched.status, "active");
    }
}