1use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::error::StorageError;
9use khive_storage::note::Note;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::NoteStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18 StorageError::driver(StorageCapability::Notes, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22 StorageError::driver(StorageCapability::Notes, op, e)
23}
24
25pub struct SqlNoteStore {
30 pool: Arc<ConnectionPool>,
31 is_file_backed: bool,
32}
33
34impl SqlNoteStore {
35 pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37 Self {
38 pool,
39 is_file_backed,
40 }
41 }
42
43 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44 let config = self.pool.config();
45 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46 operation: "note_reader".into(),
47 message: "in-memory databases do not support standalone connections".into(),
48 })?;
49
50 let conn = rusqlite::Connection::open_with_flags(
51 path,
52 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55 )
56 .map_err(|e| map_err(e, "open_note_reader"))?;
57
58 conn.busy_timeout(config.busy_timeout)
59 .map_err(|e| map_err(e, "open_note_reader"))?;
60 conn.pragma_update(None, "foreign_keys", "ON")
61 .map_err(|e| map_err(e, "open_note_reader"))?;
62 conn.pragma_update(None, "synchronous", "NORMAL")
63 .map_err(|e| map_err(e, "open_note_reader"))?;
64
65 Ok(conn)
66 }
67
68 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
70 where
71 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
72 R: Send + 'static,
73 {
74 let pool = Arc::clone(&self.pool);
75 tokio::task::spawn_blocking(move || {
76 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
77 f(guard.conn()).map_err(|e| map_err(e, op))
78 })
79 .await
80 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
81 }
82
83 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
84 where
85 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
86 R: Send + 'static,
87 {
88 if self.is_file_backed {
89 let conn = self.open_standalone_reader()?;
90 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
91 .await
92 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
93 } else {
94 let pool = Arc::clone(&self.pool);
95 tokio::task::spawn_blocking(move || {
96 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
97 f(guard.conn()).map_err(|e| map_err(e, op))
98 })
99 .await
100 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
101 }
102 }
103}
104
105fn read_note(row: &rusqlite::Row<'_>) -> Result<Note, rusqlite::Error> {
110 let id_str: String = row.get(0)?;
111 let namespace: String = row.get(1)?;
112 let kind: String = row.get(2)?;
113 let name: Option<String> = row.get(3)?;
114 let content: String = row.get(4)?;
115 let salience: f64 = row.get(5)?;
116 let decay_factor: f64 = row.get(6)?;
117 let expires_at: Option<i64> = row.get(7)?;
118 let properties_str: Option<String> = row.get(8)?;
119 let created_at: i64 = row.get(9)?;
120 let updated_at: i64 = row.get(10)?;
121 let deleted_at: Option<i64> = row.get(11)?;
122
123 let id = parse_uuid(&id_str)?;
124
125 let properties = properties_str
126 .map(|s| {
127 serde_json::from_str(&s).map_err(|e| {
128 rusqlite::Error::FromSqlConversionFailure(
129 8,
130 rusqlite::types::Type::Text,
131 Box::new(e),
132 )
133 })
134 })
135 .transpose()?;
136
137 Ok(Note {
138 id,
139 namespace,
140 kind,
141 name,
142 content,
143 salience,
144 decay_factor,
145 expires_at,
146 properties,
147 created_at,
148 updated_at,
149 deleted_at,
150 })
151}
152
153fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
154 Uuid::parse_str(s).map_err(|e| {
155 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
156 })
157}
158
159fn build_note_where(
160 namespace: &str,
161 kind: Option<&str>,
162) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
163 let mut conditions: Vec<String> = vec![
164 "namespace = ?1".to_string(),
165 "deleted_at IS NULL".to_string(),
166 ];
167 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
168
169 if let Some(k) = kind {
170 params.push(Box::new(k.to_string()));
171 conditions.push(format!("kind = ?{}", params.len()));
172 }
173
174 let clause = format!(" WHERE {}", conditions.join(" AND "));
175 (clause, params)
176}
177
178#[async_trait]
183impl NoteStore for SqlNoteStore {
184 async fn upsert_note(&self, note: Note) -> Result<(), StorageError> {
185 let namespace = note.namespace.clone();
186 let id_str = note.id.to_string();
187 let kind_str = note.kind.to_string();
188 let properties_str = note
189 .properties
190 .as_ref()
191 .map(|v| serde_json::to_string(v).unwrap_or_default());
192
193 self.with_writer("upsert_note", move |conn| {
194 conn.execute(
195 "INSERT OR REPLACE INTO notes \
196 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
197 properties, created_at, updated_at, deleted_at) \
198 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
199 rusqlite::params![
200 id_str,
201 namespace,
202 kind_str,
203 note.name,
204 note.content,
205 note.salience,
206 note.decay_factor,
207 note.expires_at,
208 properties_str,
209 note.created_at,
210 note.updated_at,
211 note.deleted_at,
212 ],
213 )?;
214 Ok(())
215 })
216 .await
217 }
218
219 async fn upsert_notes(&self, notes: Vec<Note>) -> Result<BatchWriteSummary, StorageError> {
220 let attempted = notes.len() as u64;
221
222 self.with_writer("upsert_notes", move |conn| {
223 conn.execute_batch("BEGIN IMMEDIATE")?;
224 let mut affected = 0u64;
225 let mut failed = 0u64;
226 let mut first_error = String::new();
227
228 for note in ¬es {
229 let id_str = note.id.to_string();
230 let kind_str = note.kind.to_string();
231 let properties_str = note
232 .properties
233 .as_ref()
234 .map(|v| serde_json::to_string(v).unwrap_or_default());
235
236 match conn.execute(
237 "INSERT OR REPLACE INTO notes \
238 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
239 properties, created_at, updated_at, deleted_at) \
240 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
241 rusqlite::params![
242 id_str,
243 ¬e.namespace,
244 kind_str,
245 ¬e.name,
246 note.content,
247 note.salience,
248 note.decay_factor,
249 note.expires_at,
250 properties_str,
251 note.created_at,
252 note.updated_at,
253 note.deleted_at,
254 ],
255 ) {
256 Ok(_) => affected += 1,
257 Err(e) => {
258 if first_error.is_empty() {
259 first_error = e.to_string();
260 }
261 failed += 1;
262 }
263 }
264 }
265
266 if let Err(e) = conn.execute_batch("COMMIT") {
267 let _ = conn.execute_batch("ROLLBACK");
268 return Err(e);
269 }
270 Ok(BatchWriteSummary {
271 attempted,
272 affected,
273 failed,
274 first_error,
275 })
276 })
277 .await
278 }
279
280 async fn get_note(&self, id: Uuid) -> Result<Option<Note>, StorageError> {
281 let id_str = id.to_string();
282
283 self.with_reader("get_note", move |conn| {
284 let mut stmt = conn.prepare(
285 "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
286 properties, created_at, updated_at, deleted_at \
287 FROM notes WHERE id = ?1 AND deleted_at IS NULL",
288 )?;
289 let mut rows = stmt.query(rusqlite::params![id_str])?;
290 match rows.next()? {
291 Some(row) => Ok(Some(read_note(row)?)),
292 None => Ok(None),
293 }
294 })
295 .await
296 }
297
298 async fn get_notes_batch(&self, ids: &[Uuid]) -> Result<Vec<Note>, StorageError> {
299 if ids.is_empty() {
300 return Ok(vec![]);
301 }
302 let id_strings: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
303
304 self.with_reader("get_notes_batch", move |conn| {
305 let placeholders: String = (1..=id_strings.len())
306 .map(|i| format!("?{i}"))
307 .collect::<Vec<_>>()
308 .join(", ");
309 let sql = format!(
310 "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
311 properties, created_at, updated_at, deleted_at \
312 FROM notes WHERE id IN ({placeholders}) AND deleted_at IS NULL"
313 );
314 let mut stmt = conn.prepare(&sql)?;
315 let params: Vec<&dyn rusqlite::types::ToSql> = id_strings
316 .iter()
317 .map(|s| s as &dyn rusqlite::types::ToSql)
318 .collect();
319 let rows = stmt.query_map(params.as_slice(), read_note)?;
320 let mut out = Vec::new();
321 for row in rows {
322 out.push(row?);
323 }
324 Ok(out)
325 })
326 .await
327 }
328
329 async fn delete_note(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
330 let id_str = id.to_string();
331
332 match mode {
333 DeleteMode::Soft => {
334 self.with_writer("delete_note_soft", move |conn| {
335 let now = chrono::Utc::now().timestamp_micros();
336 let deleted = conn.execute(
337 "UPDATE notes SET deleted_at = ?1 \
338 WHERE id = ?2 AND deleted_at IS NULL",
339 rusqlite::params![now, id_str],
340 )?;
341 Ok(deleted > 0)
342 })
343 .await
344 }
345 DeleteMode::Hard => {
346 self.with_writer("delete_note_hard", move |conn| {
347 let deleted =
348 conn.execute("DELETE FROM notes WHERE id = ?1", rusqlite::params![id_str])?;
349 Ok(deleted > 0)
350 })
351 .await
352 }
353 }
354 }
355
356 async fn query_notes(
357 &self,
358 namespace: &str,
359 kind: Option<&str>,
360 page: PageRequest,
361 ) -> Result<Page<Note>, StorageError> {
362 let namespace = namespace.to_string();
363 let kind = kind.map(|k| k.to_string());
364
365 self.with_reader("query_notes", move |conn| {
366 let (count_sql, count_params) = build_note_where(&namespace, kind.as_deref());
367 let total: i64 = {
368 let sql = format!("SELECT COUNT(*) FROM notes{}", count_sql);
369 let mut stmt = conn.prepare(&sql)?;
370 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
371 count_params.iter().map(|p| p.as_ref()).collect();
372 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
373 };
374
375 let (where_sql, mut data_params) = build_note_where(&namespace, kind.as_deref());
376 data_params.push(Box::new(page.limit as i64));
377 data_params.push(Box::new(page.offset as i64));
378
379 let limit_idx = data_params.len() - 1;
380 let offset_idx = data_params.len();
381
382 let data_sql = format!(
383 "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
384 properties, created_at, updated_at, deleted_at \
385 FROM notes{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
386 where_sql, limit_idx, offset_idx,
387 );
388
389 let mut stmt = conn.prepare(&data_sql)?;
390 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
391 data_params.iter().map(|p| p.as_ref()).collect();
392 let rows = stmt.query_map(param_refs.as_slice(), read_note)?;
393
394 let mut items = Vec::new();
395 for row in rows {
396 items.push(row?);
397 }
398
399 Ok(Page {
400 items,
401 total: Some(total as u64),
402 })
403 })
404 .await
405 }
406
407 async fn count_notes(&self, namespace: &str, kind: Option<&str>) -> Result<u64, StorageError> {
408 let namespace = namespace.to_string();
409 let kind = kind.map(|k| k.to_string());
410
411 self.with_reader("count_notes", move |conn| {
412 let (where_sql, params) = build_note_where(&namespace, kind.as_deref());
413 let sql = format!("SELECT COUNT(*) FROM notes{}", where_sql);
414 let mut stmt = conn.prepare(&sql)?;
415 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
416 params.iter().map(|p| p.as_ref()).collect();
417 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
418 Ok(count as u64)
419 })
420 .await
421 }
422
423 async fn upsert_note_if_below_quota(
424 &self,
425 note: Note,
426 max_notes: u64,
427 ) -> Result<bool, StorageError> {
428 let namespace = note.namespace.clone();
429 let id_str = note.id.to_string();
430 let kind_str = note.kind.to_string();
431 let properties_str = note
432 .properties
433 .as_ref()
434 .map(|v| serde_json::to_string(v).unwrap_or_default());
435
436 self.with_writer("upsert_note_if_below_quota", move |conn| {
437 let count: i64 = conn.query_row(
438 "SELECT COUNT(*) FROM notes WHERE namespace = ?1 AND deleted_at IS NULL",
439 [&namespace],
440 |row| row.get(0),
441 )?;
442 if count as u64 >= max_notes {
443 return Ok(false);
444 }
445 conn.execute(
446 "INSERT OR REPLACE INTO notes \
447 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
448 properties, created_at, updated_at, deleted_at) \
449 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
450 rusqlite::params![
451 id_str,
452 namespace,
453 kind_str,
454 note.name,
455 note.content,
456 note.salience,
457 note.decay_factor,
458 note.expires_at,
459 properties_str,
460 note.created_at,
461 note.updated_at,
462 note.deleted_at,
463 ],
464 )?;
465 Ok(true)
466 })
467 .await
468 }
469}
470
/// Idempotent schema for the `notes` table and its indexes.
///
/// Every statement uses `IF NOT EXISTS`, so the batch can be executed on each
/// start-up. The pieces are concatenated without separators, yielding one
/// single-line SQL batch.
const NOTES_DDL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS notes (",
    "id TEXT PRIMARY KEY,",
    "namespace TEXT NOT NULL,",
    "kind TEXT NOT NULL,",
    "name TEXT,",
    "content TEXT NOT NULL DEFAULT '',",
    "salience REAL NOT NULL DEFAULT 0.5,",
    "decay_factor REAL NOT NULL DEFAULT 0.0,",
    "expires_at INTEGER,",
    "properties TEXT,",
    "created_at INTEGER NOT NULL,",
    "updated_at INTEGER NOT NULL,",
    "deleted_at INTEGER",
    ");",
    "CREATE INDEX IF NOT EXISTS idx_notes_namespace ON notes(namespace);",
    "CREATE INDEX IF NOT EXISTS idx_notes_kind ON notes(namespace, kind);",
    "CREATE INDEX IF NOT EXISTS idx_notes_created ON notes(created_at DESC);",
);
494
495pub(crate) fn ensure_notes_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
496 conn.execute_batch(NOTES_DDL)
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// Builds a pool with no file path and applies the notes schema to it.
    fn setup_pool() -> Arc<ConnectionPool> {
        let pool = ConnectionPool::new(PoolConfig {
            path: None,
            ..PoolConfig::default()
        })
        .unwrap();
        let pool = Arc::new(pool);
        // Scope the writer guard so it is released before the pool is handed out.
        {
            let guard = pool.writer().unwrap();
            guard.conn().execute_batch(NOTES_DDL).unwrap();
        }
        pool
    }

    /// Store over a fresh pool, reading through the pool's reader.
    fn setup_memory_store() -> SqlNoteStore {
        let pool = setup_pool();
        SqlNoteStore::new(pool, false)
    }

    /// Shorthand for `Note::new`.
    fn make_note(namespace: &str, kind: &str, content: &str) -> Note {
        Note::new(namespace, kind, content)
    }

    /// A stored note can be read back with the same id, content and kind.
    #[tokio::test]
    async fn test_upsert_and_get_note() {
        let store = setup_memory_store();

        let original = make_note("default", "observation", "Hello world");
        let note_id = original.id;
        store.upsert_note(original).await.unwrap();

        let lookup = store.get_note(note_id).await.unwrap();
        assert!(lookup.is_some());

        let stored = lookup.unwrap();
        assert_eq!(stored.id, note_id);
        assert_eq!(stored.content, "Hello world");
        assert_eq!(stored.kind, "observation");
    }

    /// Every kind string survives a write/read round trip unchanged.
    #[tokio::test]
    async fn test_kind_roundtrip_all_variants() {
        let store = setup_memory_store();
        let kinds = [
            "observation",
            "insight",
            "question",
            "decision",
            "reference",
        ];

        for kind in kinds {
            let original = make_note("default", kind, "content");
            let note_id = original.id;
            store.upsert_note(original).await.unwrap();

            let stored = store.get_note(note_id).await.unwrap().unwrap();
            assert_eq!(stored.kind, kind);
        }
    }

    /// A soft-deleted note is reported as deleted and no longer returned.
    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let original = make_note("default", "observation", "to be deleted");
        let note_id = original.id;
        store.upsert_note(original).await.unwrap();

        let was_deleted = store.delete_note(note_id, DeleteMode::Soft).await.unwrap();
        assert!(was_deleted);

        let lookup = store.get_note(note_id).await.unwrap();
        assert!(lookup.is_none());
    }

    /// A hard-deleted note is reported as deleted and no longer returned.
    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let original = make_note("default", "observation", "to be hard deleted");
        let note_id = original.id;
        store.upsert_note(original).await.unwrap();

        let was_deleted = store.delete_note(note_id, DeleteMode::Hard).await.unwrap();
        assert!(was_deleted);

        let lookup = store.get_note(note_id).await.unwrap();
        assert!(lookup.is_none());
    }

    /// Counts only see notes of the requested namespace.
    #[tokio::test]
    async fn test_namespace_isolation() {
        let store = setup_memory_store();

        for _ in 0..3 {
            let note = make_note("ns1", "observation", "content");
            store.upsert_note(note).await.unwrap();
        }
        let other = make_note("ns2", "observation", "other");
        store.upsert_note(other).await.unwrap();

        let count_ns1 = store.count_notes("ns1", None).await.unwrap();
        assert_eq!(count_ns1, 3);

        let count_ns2 = store.count_notes("ns2", None).await.unwrap();
        assert_eq!(count_ns2, 1);
    }

    /// Inserts succeed up to the quota and are refused afterwards.
    #[tokio::test]
    async fn test_quota() {
        let store = setup_memory_store();

        for _ in 0..3 {
            let note = make_note("quota_ns", "observation", "x");
            let inserted = store.upsert_note_if_below_quota(note, 3).await.unwrap();
            assert!(inserted);
        }

        let overflow = make_note("quota_ns", "observation", "x");
        let inserted = store.upsert_note_if_below_quota(overflow, 3).await.unwrap();
        assert!(!inserted);
    }

    /// Queries and counts are scoped to the namespace passed by the caller.
    #[tokio::test]
    async fn test_query_and_count_use_caller_namespace() {
        let store = setup_memory_store();

        let note_a = make_note("ns_a", "observation", "A");
        store.upsert_note(note_a).await.unwrap();
        let note_b = make_note("ns_b", "insight", "B");
        store.upsert_note(note_b).await.unwrap();

        let page_a = store
            .query_notes("ns_a", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items.len(), 1);
        assert_eq!(page_a.items[0].content, "A");

        let page_b = store
            .query_notes("ns_b", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items.len(), 1);
        assert_eq!(page_b.items[0].content, "B");

        let count_a = store.count_notes("ns_a", None).await.unwrap();
        let count_b = store.count_notes("ns_b", None).await.unwrap();
        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);
    }
}