1use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::error::StorageError;
9use khive_storage::note::Note;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::NoteStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18 StorageError::driver(StorageCapability::Notes, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22 StorageError::driver(StorageCapability::Notes, op, e)
23}
24
25pub struct SqlNoteStore {
30 pool: Arc<ConnectionPool>,
31 is_file_backed: bool,
32}
33
34impl SqlNoteStore {
35 pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37 Self {
38 pool,
39 is_file_backed,
40 }
41 }
42
43 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44 let config = self.pool.config();
45 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46 operation: "note_reader".into(),
47 message: "in-memory databases do not support standalone connections".into(),
48 })?;
49
50 let conn = rusqlite::Connection::open_with_flags(
51 path,
52 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55 )
56 .map_err(|e| map_err(e, "open_note_reader"))?;
57
58 conn.busy_timeout(config.busy_timeout)
59 .map_err(|e| map_err(e, "open_note_reader"))?;
60 conn.pragma_update(None, "foreign_keys", "ON")
61 .map_err(|e| map_err(e, "open_note_reader"))?;
62 conn.pragma_update(None, "synchronous", "NORMAL")
63 .map_err(|e| map_err(e, "open_note_reader"))?;
64
65 Ok(conn)
66 }
67
68 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
70 where
71 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
72 R: Send + 'static,
73 {
74 let pool = Arc::clone(&self.pool);
75 tokio::task::spawn_blocking(move || {
76 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
77 f(guard.conn()).map_err(|e| map_err(e, op))
78 })
79 .await
80 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
81 }
82
83 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
84 where
85 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
86 R: Send + 'static,
87 {
88 if self.is_file_backed {
89 let conn = self.open_standalone_reader()?;
90 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
91 .await
92 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
93 } else {
94 let pool = Arc::clone(&self.pool);
95 tokio::task::spawn_blocking(move || {
96 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
97 f(guard.conn()).map_err(|e| map_err(e, op))
98 })
99 .await
100 .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
101 }
102 }
103}
104
105fn read_note(row: &rusqlite::Row<'_>) -> Result<Note, rusqlite::Error> {
110 let id_str: String = row.get(0)?;
111 let namespace: String = row.get(1)?;
112 let kind: String = row.get(2)?;
113 let name: Option<String> = row.get(3)?;
114 let content: String = row.get(4)?;
115 let salience: f64 = row.get(5)?;
116 let decay_factor: f64 = row.get(6)?;
117 let expires_at: Option<i64> = row.get(7)?;
118 let properties_str: Option<String> = row.get(8)?;
119 let created_at: i64 = row.get(9)?;
120 let updated_at: i64 = row.get(10)?;
121 let deleted_at: Option<i64> = row.get(11)?;
122
123 let id = parse_uuid(&id_str)?;
124
125 let properties = properties_str
126 .map(|s| {
127 serde_json::from_str(&s).map_err(|e| {
128 rusqlite::Error::FromSqlConversionFailure(
129 8,
130 rusqlite::types::Type::Text,
131 Box::new(e),
132 )
133 })
134 })
135 .transpose()?;
136
137 Ok(Note {
138 id,
139 namespace,
140 kind,
141 name,
142 content,
143 salience,
144 decay_factor,
145 expires_at,
146 properties,
147 created_at,
148 updated_at,
149 deleted_at,
150 })
151}
152
153fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
154 Uuid::parse_str(s).map_err(|e| {
155 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
156 })
157}
158
159fn build_note_where(
160 namespace: &str,
161 kind: Option<&str>,
162) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
163 let mut conditions: Vec<String> = vec![
164 "namespace = ?1".to_string(),
165 "deleted_at IS NULL".to_string(),
166 ];
167 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
168
169 if let Some(k) = kind {
170 params.push(Box::new(k.to_string()));
171 conditions.push(format!("kind = ?{}", params.len()));
172 }
173
174 let clause = format!(" WHERE {}", conditions.join(" AND "));
175 (clause, params)
176}
177
178#[async_trait]
183impl NoteStore for SqlNoteStore {
184 async fn upsert_note(&self, note: Note) -> Result<(), StorageError> {
185 let namespace = note.namespace.clone();
186 let id_str = note.id.to_string();
187 let kind_str = note.kind.to_string();
188 let properties_str = note
189 .properties
190 .as_ref()
191 .map(|v| serde_json::to_string(v).unwrap_or_default());
192
193 self.with_writer("upsert_note", move |conn| {
194 conn.execute(
195 "INSERT OR REPLACE INTO notes \
196 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
197 properties, created_at, updated_at, deleted_at) \
198 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
199 rusqlite::params![
200 id_str,
201 namespace,
202 kind_str,
203 note.name,
204 note.content,
205 note.salience,
206 note.decay_factor,
207 note.expires_at,
208 properties_str,
209 note.created_at,
210 note.updated_at,
211 note.deleted_at,
212 ],
213 )?;
214 Ok(())
215 })
216 .await
217 }
218
219 async fn upsert_notes(&self, notes: Vec<Note>) -> Result<BatchWriteSummary, StorageError> {
220 let attempted = notes.len() as u64;
221
222 self.with_writer("upsert_notes", move |conn| {
223 conn.execute_batch("BEGIN IMMEDIATE")?;
224 let mut affected = 0u64;
225 let mut failed = 0u64;
226 let mut first_error = String::new();
227
228 for note in ¬es {
229 let id_str = note.id.to_string();
230 let kind_str = note.kind.to_string();
231 let properties_str = note
232 .properties
233 .as_ref()
234 .map(|v| serde_json::to_string(v).unwrap_or_default());
235
236 match conn.execute(
237 "INSERT OR REPLACE INTO notes \
238 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
239 properties, created_at, updated_at, deleted_at) \
240 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
241 rusqlite::params![
242 id_str,
243 ¬e.namespace,
244 kind_str,
245 ¬e.name,
246 note.content,
247 note.salience,
248 note.decay_factor,
249 note.expires_at,
250 properties_str,
251 note.created_at,
252 note.updated_at,
253 note.deleted_at,
254 ],
255 ) {
256 Ok(_) => affected += 1,
257 Err(e) => {
258 if first_error.is_empty() {
259 first_error = e.to_string();
260 }
261 failed += 1;
262 }
263 }
264 }
265
266 if let Err(e) = conn.execute_batch("COMMIT") {
267 let _ = conn.execute_batch("ROLLBACK");
268 return Err(e);
269 }
270 Ok(BatchWriteSummary {
271 attempted,
272 affected,
273 failed,
274 first_error,
275 })
276 })
277 .await
278 }
279
280 async fn get_note(&self, id: Uuid) -> Result<Option<Note>, StorageError> {
281 let id_str = id.to_string();
282
283 self.with_reader("get_note", move |conn| {
284 let mut stmt = conn.prepare(
285 "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
286 properties, created_at, updated_at, deleted_at \
287 FROM notes WHERE id = ?1 AND deleted_at IS NULL",
288 )?;
289 let mut rows = stmt.query(rusqlite::params![id_str])?;
290 match rows.next()? {
291 Some(row) => Ok(Some(read_note(row)?)),
292 None => Ok(None),
293 }
294 })
295 .await
296 }
297
298 async fn delete_note(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
299 let id_str = id.to_string();
300
301 match mode {
302 DeleteMode::Soft => {
303 self.with_writer("delete_note_soft", move |conn| {
304 let now = chrono::Utc::now().timestamp_micros();
305 let deleted = conn.execute(
306 "UPDATE notes SET deleted_at = ?1 \
307 WHERE id = ?2 AND deleted_at IS NULL",
308 rusqlite::params![now, id_str],
309 )?;
310 Ok(deleted > 0)
311 })
312 .await
313 }
314 DeleteMode::Hard => {
315 self.with_writer("delete_note_hard", move |conn| {
316 let deleted =
317 conn.execute("DELETE FROM notes WHERE id = ?1", rusqlite::params![id_str])?;
318 Ok(deleted > 0)
319 })
320 .await
321 }
322 }
323 }
324
325 async fn query_notes(
326 &self,
327 namespace: &str,
328 kind: Option<&str>,
329 page: PageRequest,
330 ) -> Result<Page<Note>, StorageError> {
331 let namespace = namespace.to_string();
332 let kind = kind.map(|k| k.to_string());
333
334 self.with_reader("query_notes", move |conn| {
335 let (count_sql, count_params) = build_note_where(&namespace, kind.as_deref());
336 let total: i64 = {
337 let sql = format!("SELECT COUNT(*) FROM notes{}", count_sql);
338 let mut stmt = conn.prepare(&sql)?;
339 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
340 count_params.iter().map(|p| p.as_ref()).collect();
341 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
342 };
343
344 let (where_sql, mut data_params) = build_note_where(&namespace, kind.as_deref());
345 data_params.push(Box::new(page.limit as i64));
346 data_params.push(Box::new(page.offset as i64));
347
348 let limit_idx = data_params.len() - 1;
349 let offset_idx = data_params.len();
350
351 let data_sql = format!(
352 "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
353 properties, created_at, updated_at, deleted_at \
354 FROM notes{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
355 where_sql, limit_idx, offset_idx,
356 );
357
358 let mut stmt = conn.prepare(&data_sql)?;
359 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
360 data_params.iter().map(|p| p.as_ref()).collect();
361 let rows = stmt.query_map(param_refs.as_slice(), read_note)?;
362
363 let mut items = Vec::new();
364 for row in rows {
365 items.push(row?);
366 }
367
368 Ok(Page {
369 items,
370 total: Some(total as u64),
371 })
372 })
373 .await
374 }
375
376 async fn count_notes(&self, namespace: &str, kind: Option<&str>) -> Result<u64, StorageError> {
377 let namespace = namespace.to_string();
378 let kind = kind.map(|k| k.to_string());
379
380 self.with_reader("count_notes", move |conn| {
381 let (where_sql, params) = build_note_where(&namespace, kind.as_deref());
382 let sql = format!("SELECT COUNT(*) FROM notes{}", where_sql);
383 let mut stmt = conn.prepare(&sql)?;
384 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
385 params.iter().map(|p| p.as_ref()).collect();
386 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
387 Ok(count as u64)
388 })
389 .await
390 }
391
392 async fn upsert_note_if_below_quota(
393 &self,
394 note: Note,
395 max_notes: u64,
396 ) -> Result<bool, StorageError> {
397 let namespace = note.namespace.clone();
398 let id_str = note.id.to_string();
399 let kind_str = note.kind.to_string();
400 let properties_str = note
401 .properties
402 .as_ref()
403 .map(|v| serde_json::to_string(v).unwrap_or_default());
404
405 self.with_writer("upsert_note_if_below_quota", move |conn| {
406 let count: i64 = conn.query_row(
407 "SELECT COUNT(*) FROM notes WHERE namespace = ?1 AND deleted_at IS NULL",
408 [&namespace],
409 |row| row.get(0),
410 )?;
411 if count as u64 >= max_notes {
412 return Ok(false);
413 }
414 conn.execute(
415 "INSERT OR REPLACE INTO notes \
416 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
417 properties, created_at, updated_at, deleted_at) \
418 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
419 rusqlite::params![
420 id_str,
421 namespace,
422 kind_str,
423 note.name,
424 note.content,
425 note.salience,
426 note.decay_factor,
427 note.expires_at,
428 properties_str,
429 note.created_at,
430 note.updated_at,
431 note.deleted_at,
432 ],
433 )?;
434 Ok(true)
435 })
436 .await
437 }
438}
439
/// DDL for the `notes` table and its secondary indexes.
///
/// Every statement uses `IF NOT EXISTS`, so the batch can be re-run safely.
/// `deleted_at` is NULL for live notes; a soft delete stamps it.
/// NOTE(review): timestamp columns are presumably microseconds since the Unix
/// epoch (soft delete writes `timestamp_micros()`) — confirm for
/// `created_at` / `updated_at`.
/// Each `\` line continuation strips the next line's leading whitespace, so the
/// indentation below never reaches SQLite.
const NOTES_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS notes (\
    id TEXT PRIMARY KEY,\
    namespace TEXT NOT NULL,\
    kind TEXT NOT NULL,\
    name TEXT,\
    content TEXT NOT NULL DEFAULT '',\
    salience REAL NOT NULL DEFAULT 0.5,\
    decay_factor REAL NOT NULL DEFAULT 0.0,\
    expires_at INTEGER,\
    properties TEXT,\
    created_at INTEGER NOT NULL,\
    updated_at INTEGER NOT NULL,\
    deleted_at INTEGER\
    );\
    CREATE INDEX IF NOT EXISTS idx_notes_namespace ON notes(namespace);\
    CREATE INDEX IF NOT EXISTS idx_notes_kind ON notes(namespace, kind);\
    CREATE INDEX IF NOT EXISTS idx_notes_created ON notes(created_at DESC);\
";
463
464pub(crate) fn ensure_notes_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
465 conn.execute_batch(NOTES_DDL)
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// In-memory pool with the notes schema already applied.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(NOTES_DDL).unwrap();
        }
        pool
    }

    fn setup_memory_store() -> SqlNoteStore {
        SqlNoteStore::new(setup_pool(), false)
    }

    fn make_note(namespace: &str, kind: &str, content: &str) -> Note {
        Note::new(namespace, kind, content)
    }

    #[tokio::test]
    async fn test_upsert_and_get_note() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "Hello world");
        let id = note.id;

        store.upsert_note(note).await.unwrap();

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.content, "Hello world");
        assert_eq!(fetched.kind, "observation");
    }

    #[tokio::test]
    async fn test_kind_roundtrip_all_variants() {
        let store = setup_memory_store();
        for kind in [
            "observation",
            "insight",
            "question",
            "decision",
            "reference",
        ] {
            let note = make_note("default", kind, "content");
            let id = note.id;
            store.upsert_note(note).await.unwrap();
            let fetched = store.get_note(id).await.unwrap().unwrap();
            assert_eq!(fetched.kind, kind);
        }
    }

    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "to be deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Soft).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    /// A second soft delete is a no-op, and soft-deleted notes leave the count.
    #[tokio::test]
    async fn test_soft_delete_twice_and_count() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "gone");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        assert!(store.delete_note(id, DeleteMode::Soft).await.unwrap());
        assert!(!store.delete_note(id, DeleteMode::Soft).await.unwrap());
        assert_eq!(store.count_notes("default", None).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "to be hard deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_namespace_isolation() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            store
                .upsert_note(make_note("ns1", "observation", "content"))
                .await
                .unwrap();
        }
        store
            .upsert_note(make_note("ns2", "observation", "other"))
            .await
            .unwrap();

        let count_ns1 = store.count_notes("ns1", None).await.unwrap();
        assert_eq!(count_ns1, 3);

        let count_ns2 = store.count_notes("ns2", None).await.unwrap();
        assert_eq!(count_ns2, 1);
    }

    #[tokio::test]
    async fn test_quota() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            let inserted = store
                .upsert_note_if_below_quota(make_note("quota_ns", "observation", "x"), 3)
                .await
                .unwrap();
            assert!(inserted);
        }

        let inserted = store
            .upsert_note_if_below_quota(make_note("quota_ns", "observation", "x"), 3)
            .await
            .unwrap();
        assert!(!inserted);
    }

    /// Batch upsert reports a clean summary and persists every note.
    #[tokio::test]
    async fn test_upsert_notes_batch() {
        let store = setup_memory_store();

        let notes: Vec<Note> = (0..3)
            .map(|i| make_note("batch", "observation", &format!("n{}", i)))
            .collect();
        let ids: Vec<Uuid> = notes.iter().map(|n| n.id).collect();

        let summary = store.upsert_notes(notes).await.unwrap();
        assert_eq!(summary.attempted, 3);
        assert_eq!(summary.affected, 3);
        assert_eq!(summary.failed, 0);
        assert!(summary.first_error.is_empty());

        for id in ids {
            assert!(store.get_note(id).await.unwrap().is_some());
        }
        assert_eq!(store.count_notes("batch", None).await.unwrap(), 3);
    }

    /// Kind filtering and OFFSET paging return every note exactly once, newest first.
    // PageRequest is built via Default + field assignment so the test does not
    // depend on its full field set.
    #[allow(clippy::field_reassign_with_default)]
    #[tokio::test]
    async fn test_query_filters_by_kind_and_paginates() {
        let store = setup_memory_store();

        for i in 0..5i64 {
            let kind = if i % 2 == 0 { "observation" } else { "insight" };
            let mut note = make_note("paged", kind, &format!("note-{}", i));
            // Distinct timestamps make the expected order deterministic.
            note.created_at = 1_000 + i;
            store.upsert_note(note).await.unwrap();
        }

        let insights = store
            .query_notes("paged", Some("insight"), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(insights.total, Some(2));
        assert!(insights.items.iter().all(|n| n.kind == "insight"));

        let mut seen: Vec<String> = Vec::new();
        for offset in [0, 2, 4] {
            let mut page = PageRequest::default();
            page.limit = 2;
            page.offset = offset;
            let result = store.query_notes("paged", None, page).await.unwrap();
            assert_eq!(result.total, Some(5));
            seen.extend(result.items.into_iter().map(|n| n.content));
        }
        assert_eq!(seen, vec!["note-4", "note-3", "note-2", "note-1", "note-0"]);
    }

    #[tokio::test]
    async fn test_query_and_count_use_caller_namespace() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        store
            .upsert_note(make_note("ns_a", "observation", "A"))
            .await
            .unwrap();
        store
            .upsert_note(make_note("ns_b", "insight", "B"))
            .await
            .unwrap();

        let page_a = store
            .query_notes("ns_a", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items.len(), 1);
        assert_eq!(page_a.items[0].content, "A");

        let page_b = store
            .query_notes("ns_b", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items.len(), 1);
        assert_eq!(page_b.items[0].content, "B");

        let count_a = store.count_notes("ns_a", None).await.unwrap();
        let count_b = store.count_notes("ns_b", None).await.unwrap();
        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);
    }
}