1use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::error::StorageError;
9use khive_storage::note::{Note, NoteKind};
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::NoteStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18 StorageError::driver(StorageCapability::Notes, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22 StorageError::driver(StorageCapability::Notes, op, e)
23}
24
25pub struct SqlNoteStore {
30 pool: Arc<ConnectionPool>,
31 is_file_backed: bool,
32}
33
34impl SqlNoteStore {
35 pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37 Self {
38 pool,
39 is_file_backed,
40 }
41 }
42
43 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44 let config = self.pool.config();
45 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46 operation: "note_reader".into(),
47 message: "in-memory databases do not support standalone connections".into(),
48 })?;
49
50 let conn = rusqlite::Connection::open_with_flags(
51 path,
52 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55 )
56 .map_err(|e| map_err(e, "open_note_reader"))?;
57
58 conn.busy_timeout(config.busy_timeout)
59 .map_err(|e| map_err(e, "open_note_reader"))?;
60 conn.pragma_update(None, "foreign_keys", "ON")
61 .map_err(|e| map_err(e, "open_note_reader"))?;
62 conn.pragma_update(None, "synchronous", "NORMAL")
63 .map_err(|e| map_err(e, "open_note_reader"))?;
64
65 Ok(conn)
66 }
67
68 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
70 where
71 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
72 R: Send + 'static,
73 {
74 let pool = Arc::clone(&self.pool);
75 tokio::task::spawn_blocking(move || {
76 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
77 f(guard.conn()).map_err(|e| map_err(e, op))
78 })
79 .await
80 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
81 }
82
83 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
84 where
85 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
86 R: Send + 'static,
87 {
88 if self.is_file_backed {
89 let conn = self.open_standalone_reader()?;
90 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
91 .await
92 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
93 } else {
94 let pool = Arc::clone(&self.pool);
95 tokio::task::spawn_blocking(move || {
96 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
97 f(guard.conn()).map_err(|e| map_err(e, op))
98 })
99 .await
100 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
101 }
102 }
103}
104
105fn parse_note_kind(s: &str, col: usize) -> Result<NoteKind, rusqlite::Error> {
110 s.parse::<NoteKind>().map_err(|e| {
111 rusqlite::Error::FromSqlConversionFailure(col, rusqlite::types::Type::Text, e.into())
112 })
113}
114
115fn read_note(row: &rusqlite::Row<'_>) -> Result<Note, rusqlite::Error> {
116 let id_str: String = row.get(0)?;
117 let namespace: String = row.get(1)?;
118 let kind_str: String = row.get(2)?;
119 let name: Option<String> = row.get(3)?;
120 let content: String = row.get(4)?;
121 let salience: f64 = row.get(5)?;
122 let decay_factor: f64 = row.get(6)?;
123 let expires_at: Option<i64> = row.get(7)?;
124 let properties_str: Option<String> = row.get(8)?;
125 let created_at: i64 = row.get(9)?;
126 let updated_at: i64 = row.get(10)?;
127 let deleted_at: Option<i64> = row.get(11)?;
128
129 let id = parse_uuid(&id_str)?;
130 let kind = parse_note_kind(&kind_str, 2)?;
131
132 let properties = properties_str
133 .map(|s| {
134 serde_json::from_str(&s).map_err(|e| {
135 rusqlite::Error::FromSqlConversionFailure(
136 8,
137 rusqlite::types::Type::Text,
138 Box::new(e),
139 )
140 })
141 })
142 .transpose()?;
143
144 Ok(Note {
145 id,
146 namespace,
147 kind,
148 name,
149 content,
150 salience,
151 decay_factor,
152 expires_at,
153 properties,
154 created_at,
155 updated_at,
156 deleted_at,
157 })
158}
159
160fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
161 Uuid::parse_str(s).map_err(|e| {
162 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
163 })
164}
165
166fn build_note_where(
167 namespace: &str,
168 kind: Option<&str>,
169) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
170 let mut conditions: Vec<String> = vec![
171 "namespace = ?1".to_string(),
172 "deleted_at IS NULL".to_string(),
173 ];
174 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
175
176 if let Some(k) = kind {
177 params.push(Box::new(k.to_string()));
178 conditions.push(format!("kind = ?{}", params.len()));
179 }
180
181 let clause = format!(" WHERE {}", conditions.join(" AND "));
182 (clause, params)
183}
184
185#[async_trait]
190impl NoteStore for SqlNoteStore {
191 async fn upsert_note(&self, note: Note) -> Result<(), StorageError> {
192 let namespace = note.namespace.clone();
193 let id_str = note.id.to_string();
194 let kind_str = note.kind.to_string();
195 let properties_str = note
196 .properties
197 .as_ref()
198 .map(|v| serde_json::to_string(v).unwrap_or_default());
199
200 self.with_writer("upsert_note", move |conn| {
201 conn.execute(
202 "INSERT OR REPLACE INTO notes \
203 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
204 properties, created_at, updated_at, deleted_at) \
205 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
206 rusqlite::params![
207 id_str,
208 namespace,
209 kind_str,
210 note.name,
211 note.content,
212 note.salience,
213 note.decay_factor,
214 note.expires_at,
215 properties_str,
216 note.created_at,
217 note.updated_at,
218 note.deleted_at,
219 ],
220 )?;
221 Ok(())
222 })
223 .await
224 }
225
226 async fn upsert_notes(&self, notes: Vec<Note>) -> Result<BatchWriteSummary, StorageError> {
227 let attempted = notes.len() as u64;
228
229 self.with_writer("upsert_notes", move |conn| {
230 conn.execute_batch("BEGIN IMMEDIATE")?;
231 let mut affected = 0u64;
232 let mut failed = 0u64;
233 let mut first_error = String::new();
234
235 for note in ¬es {
236 let id_str = note.id.to_string();
237 let kind_str = note.kind.to_string();
238 let properties_str = note
239 .properties
240 .as_ref()
241 .map(|v| serde_json::to_string(v).unwrap_or_default());
242
243 match conn.execute(
244 "INSERT OR REPLACE INTO notes \
245 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
246 properties, created_at, updated_at, deleted_at) \
247 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
248 rusqlite::params![
249 id_str,
250 ¬e.namespace,
251 kind_str,
252 ¬e.name,
253 note.content,
254 note.salience,
255 note.decay_factor,
256 note.expires_at,
257 properties_str,
258 note.created_at,
259 note.updated_at,
260 note.deleted_at,
261 ],
262 ) {
263 Ok(_) => affected += 1,
264 Err(e) => {
265 if first_error.is_empty() {
266 first_error = e.to_string();
267 }
268 failed += 1;
269 }
270 }
271 }
272
273 if let Err(e) = conn.execute_batch("COMMIT") {
274 let _ = conn.execute_batch("ROLLBACK");
275 return Err(e);
276 }
277 Ok(BatchWriteSummary {
278 attempted,
279 affected,
280 failed,
281 first_error,
282 })
283 })
284 .await
285 }
286
287 async fn get_note(&self, id: Uuid) -> Result<Option<Note>, StorageError> {
288 let id_str = id.to_string();
289
290 self.with_reader("get_note", move |conn| {
291 let mut stmt = conn.prepare(
292 "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
293 properties, created_at, updated_at, deleted_at \
294 FROM notes WHERE id = ?1 AND deleted_at IS NULL",
295 )?;
296 let mut rows = stmt.query(rusqlite::params![id_str])?;
297 match rows.next()? {
298 Some(row) => Ok(Some(read_note(row)?)),
299 None => Ok(None),
300 }
301 })
302 .await
303 }
304
305 async fn delete_note(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
306 let id_str = id.to_string();
307
308 match mode {
309 DeleteMode::Soft => {
310 self.with_writer("delete_note_soft", move |conn| {
311 let now = chrono::Utc::now().timestamp_micros();
312 let deleted = conn.execute(
313 "UPDATE notes SET deleted_at = ?1 \
314 WHERE id = ?2 AND deleted_at IS NULL",
315 rusqlite::params![now, id_str],
316 )?;
317 Ok(deleted > 0)
318 })
319 .await
320 }
321 DeleteMode::Hard => {
322 self.with_writer("delete_note_hard", move |conn| {
323 let deleted =
324 conn.execute("DELETE FROM notes WHERE id = ?1", rusqlite::params![id_str])?;
325 Ok(deleted > 0)
326 })
327 .await
328 }
329 }
330 }
331
332 async fn query_notes(
333 &self,
334 namespace: &str,
335 kind: Option<NoteKind>,
336 page: PageRequest,
337 ) -> Result<Page<Note>, StorageError> {
338 let namespace = namespace.to_string();
339 let kind = kind.map(|k| k.to_string());
340
341 self.with_reader("query_notes", move |conn| {
342 let (count_sql, count_params) = build_note_where(&namespace, kind.as_deref());
343 let total: i64 = {
344 let sql = format!("SELECT COUNT(*) FROM notes{}", count_sql);
345 let mut stmt = conn.prepare(&sql)?;
346 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
347 count_params.iter().map(|p| p.as_ref()).collect();
348 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
349 };
350
351 let (where_sql, mut data_params) = build_note_where(&namespace, kind.as_deref());
352 data_params.push(Box::new(page.limit as i64));
353 data_params.push(Box::new(page.offset as i64));
354
355 let limit_idx = data_params.len() - 1;
356 let offset_idx = data_params.len();
357
358 let data_sql = format!(
359 "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
360 properties, created_at, updated_at, deleted_at \
361 FROM notes{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
362 where_sql, limit_idx, offset_idx,
363 );
364
365 let mut stmt = conn.prepare(&data_sql)?;
366 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
367 data_params.iter().map(|p| p.as_ref()).collect();
368 let rows = stmt.query_map(param_refs.as_slice(), read_note)?;
369
370 let mut items = Vec::new();
371 for row in rows {
372 items.push(row?);
373 }
374
375 Ok(Page {
376 items,
377 total: Some(total as u64),
378 })
379 })
380 .await
381 }
382
383 async fn count_notes(
384 &self,
385 namespace: &str,
386 kind: Option<NoteKind>,
387 ) -> Result<u64, StorageError> {
388 let namespace = namespace.to_string();
389 let kind = kind.map(|k| k.to_string());
390
391 self.with_reader("count_notes", move |conn| {
392 let (where_sql, params) = build_note_where(&namespace, kind.as_deref());
393 let sql = format!("SELECT COUNT(*) FROM notes{}", where_sql);
394 let mut stmt = conn.prepare(&sql)?;
395 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
396 params.iter().map(|p| p.as_ref()).collect();
397 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
398 Ok(count as u64)
399 })
400 .await
401 }
402
403 async fn upsert_note_if_below_quota(
404 &self,
405 note: Note,
406 max_notes: u64,
407 ) -> Result<bool, StorageError> {
408 let namespace = note.namespace.clone();
409 let id_str = note.id.to_string();
410 let kind_str = note.kind.to_string();
411 let properties_str = note
412 .properties
413 .as_ref()
414 .map(|v| serde_json::to_string(v).unwrap_or_default());
415
416 self.with_writer("upsert_note_if_below_quota", move |conn| {
417 let count: i64 = conn.query_row(
418 "SELECT COUNT(*) FROM notes WHERE namespace = ?1 AND deleted_at IS NULL",
419 [&namespace],
420 |row| row.get(0),
421 )?;
422 if count as u64 >= max_notes {
423 return Ok(false);
424 }
425 conn.execute(
426 "INSERT OR REPLACE INTO notes \
427 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
428 properties, created_at, updated_at, deleted_at) \
429 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
430 rusqlite::params![
431 id_str,
432 namespace,
433 kind_str,
434 note.name,
435 note.content,
436 note.salience,
437 note.decay_factor,
438 note.expires_at,
439 properties_str,
440 note.created_at,
441 note.updated_at,
442 note.deleted_at,
443 ],
444 )?;
445 Ok(true)
446 })
447 .await
448 }
449}
450
/// Schema for the `notes` table and its indexes, as one `;`-separated batch.
///
/// Every statement uses `IF NOT EXISTS`, so executing the batch repeatedly is
/// harmless. Each `concat!` piece below is one column definition or statement.
const NOTES_DDL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS notes (",
    "id TEXT PRIMARY KEY,",
    "namespace TEXT NOT NULL,",
    "kind TEXT NOT NULL,",
    "name TEXT,",
    "content TEXT NOT NULL DEFAULT '',",
    "salience REAL NOT NULL DEFAULT 0.5,",
    "decay_factor REAL NOT NULL DEFAULT 0.0,",
    "expires_at INTEGER,",
    "properties TEXT,",
    "created_at INTEGER NOT NULL,",
    "updated_at INTEGER NOT NULL,",
    "deleted_at INTEGER",
    ");",
    "CREATE INDEX IF NOT EXISTS idx_notes_namespace ON notes(namespace);",
    "CREATE INDEX IF NOT EXISTS idx_notes_kind ON notes(namespace, kind);",
    "CREATE INDEX IF NOT EXISTS idx_notes_created ON notes(created_at DESC);",
);
474
475pub(crate) fn ensure_notes_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
476 conn.execute_batch(NOTES_DDL)
477}
478
#[cfg(test)]
mod tests {
    //! Behavioral tests for `SqlNoteStore` against an in-memory pool.

    use super::*;
    use crate::pool::PoolConfig;

    /// Builds an in-memory pool with the notes schema installed.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            // Scoped so the writer handle is dropped before the pool is returned.
            let writer = pool.writer().unwrap();
            ensure_notes_schema(writer.conn()).unwrap();
        }
        pool
    }

    fn setup_memory_store() -> SqlNoteStore {
        SqlNoteStore::new(setup_pool(), false)
    }

    fn make_note(namespace: &str, kind: NoteKind, content: &str) -> Note {
        Note::new(namespace, kind, content)
    }

    #[tokio::test]
    async fn test_upsert_and_get_note() {
        let store = setup_memory_store();

        let note = make_note("default", NoteKind::Observation, "Hello world");
        let id = note.id;

        store.upsert_note(note).await.unwrap();

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.content, "Hello world");
        assert_eq!(fetched.kind, NoteKind::Observation);
    }

    #[tokio::test]
    async fn test_kind_roundtrip_all_variants() {
        let store = setup_memory_store();
        for kind in [
            NoteKind::Observation,
            NoteKind::Insight,
            NoteKind::Question,
            NoteKind::Decision,
            NoteKind::Reference,
        ] {
            let note = make_note("default", kind, "content");
            let id = note.id;
            store.upsert_note(note).await.unwrap();
            let fetched = store.get_note(id).await.unwrap().unwrap();
            assert_eq!(fetched.kind, kind);
        }
    }

    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let note = make_note("default", NoteKind::Observation, "to be deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Soft).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let note = make_note("default", NoteKind::Observation, "to be hard deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_namespace_isolation() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            store
                .upsert_note(make_note("ns1", NoteKind::Observation, "content"))
                .await
                .unwrap();
        }
        store
            .upsert_note(make_note("ns2", NoteKind::Observation, "other"))
            .await
            .unwrap();

        let count_ns1 = store.count_notes("ns1", None).await.unwrap();
        assert_eq!(count_ns1, 3);

        let count_ns2 = store.count_notes("ns2", None).await.unwrap();
        assert_eq!(count_ns2, 1);
    }

    #[tokio::test]
    async fn test_quota() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            let inserted = store
                .upsert_note_if_below_quota(make_note("quota_ns", NoteKind::Observation, "x"), 3)
                .await
                .unwrap();
            assert!(inserted);
        }

        let inserted = store
            .upsert_note_if_below_quota(make_note("quota_ns", NoteKind::Observation, "x"), 3)
            .await
            .unwrap();
        assert!(!inserted);
    }

    #[tokio::test]
    async fn test_query_and_count_use_caller_namespace() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        store
            .upsert_note(make_note("ns_a", NoteKind::Observation, "A"))
            .await
            .unwrap();
        store
            .upsert_note(make_note("ns_b", NoteKind::Insight, "B"))
            .await
            .unwrap();

        let page_a = store
            .query_notes("ns_a", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items.len(), 1);
        assert_eq!(page_a.items[0].content, "A");

        let page_b = store
            .query_notes("ns_b", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items.len(), 1);
        assert_eq!(page_b.items[0].content, "B");

        let count_a = store.count_notes("ns_a", None).await.unwrap();
        let count_b = store.count_notes("ns_b", None).await.unwrap();
        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);
    }

    #[tokio::test]
    async fn test_upsert_notes_batch() {
        let store = setup_memory_store();

        let notes: Vec<Note> = (0..3)
            .map(|i| make_note("batch", NoteKind::Observation, &format!("note {}", i)))
            .collect();
        let ids: Vec<Uuid> = notes.iter().map(|n| n.id).collect();

        let summary = store.upsert_notes(notes).await.unwrap();
        assert_eq!(summary.attempted, 3);
        assert_eq!(summary.affected, 3);
        assert_eq!(summary.failed, 0);
        assert!(summary.first_error.is_empty());

        for id in ids {
            assert!(store.get_note(id).await.unwrap().is_some());
        }
        assert_eq!(store.count_notes("batch", None).await.unwrap(), 3);
    }

    #[tokio::test]
    async fn test_query_notes_kind_filter_and_paging() {
        let store = setup_memory_store();
        for _ in 0..3 {
            store
                .upsert_note(make_note("paged", NoteKind::Observation, "obs"))
                .await
                .unwrap();
        }
        for _ in 0..2 {
            store
                .upsert_note(make_note("paged", NoteKind::Insight, "ins"))
                .await
                .unwrap();
        }

        let insights = store
            .query_notes("paged", Some(NoteKind::Insight), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(insights.items.len(), 2);
        assert_eq!(insights.total, Some(2));
        assert!(insights.items.iter().all(|n| n.kind == NoteKind::Insight));
        assert_eq!(
            store
                .count_notes("paged", Some(NoteKind::Insight))
                .await
                .unwrap(),
            2
        );

        let first = store
            .query_notes(
                "paged",
                None,
                PageRequest {
                    limit: 2,
                    offset: 0,
                    ..PageRequest::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(first.items.len(), 2);
        assert_eq!(first.total, Some(5));

        let last = store
            .query_notes(
                "paged",
                None,
                PageRequest {
                    limit: 2,
                    offset: 4,
                    ..PageRequest::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(last.items.len(), 1);
        assert_eq!(last.total, Some(5));
    }

    #[tokio::test]
    async fn test_soft_deleted_notes_are_not_counted() {
        let store = setup_memory_store();

        let kept = make_note("sd", NoteKind::Observation, "keep");
        let gone = make_note("sd", NoteKind::Observation, "gone");
        let gone_id = gone.id;
        store.upsert_note(kept).await.unwrap();
        store.upsert_note(gone).await.unwrap();

        assert!(store.delete_note(gone_id, DeleteMode::Soft).await.unwrap());
        assert_eq!(store.count_notes("sd", None).await.unwrap(), 1);

        // A second soft delete finds no live row to update.
        assert!(!store.delete_note(gone_id, DeleteMode::Soft).await.unwrap());
    }
}