Skip to main content

khive_db/stores/
note.rs

1//! SQL-backed `NoteStore` implementation.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::error::StorageError;
9use khive_storage::note::Note;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::NoteStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18    StorageError::driver(StorageCapability::Notes, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22    StorageError::driver(StorageCapability::Notes, op, e)
23}
24
25/// A NoteStore backed by SQLite. Namespace is the caller's responsibility.
26///
27/// UUID is globally unique — get/delete by ID alone. Query/count use the
28/// namespace parameter as passed. The store is just a pool + is_file_backed.
29pub struct SqlNoteStore {
30    pool: Arc<ConnectionPool>,
31    is_file_backed: bool,
32}
33
34impl SqlNoteStore {
35    /// Create a new store.
36    pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37        Self {
38            pool,
39            is_file_backed,
40        }
41    }
42
43    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44        let config = self.pool.config();
45        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46            operation: "note_reader".into(),
47            message: "in-memory databases do not support standalone connections".into(),
48        })?;
49
50        let conn = rusqlite::Connection::open_with_flags(
51            path,
52            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55        )
56        .map_err(|e| map_err(e, "open_note_reader"))?;
57
58        conn.busy_timeout(config.busy_timeout)
59            .map_err(|e| map_err(e, "open_note_reader"))?;
60        conn.pragma_update(None, "foreign_keys", "ON")
61            .map_err(|e| map_err(e, "open_note_reader"))?;
62        conn.pragma_update(None, "synchronous", "NORMAL")
63            .map_err(|e| map_err(e, "open_note_reader"))?;
64
65        Ok(conn)
66    }
67
68    /// Write via pool writer (serializes writes through the mutex).
69    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
70    where
71        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
72        R: Send + 'static,
73    {
74        let pool = Arc::clone(&self.pool);
75        tokio::task::spawn_blocking(move || {
76            let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
77            f(guard.conn()).map_err(|e| map_err(e, op))
78        })
79        .await
80        .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
81    }
82
83    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
84    where
85        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
86        R: Send + 'static,
87    {
88        if self.is_file_backed {
89            let conn = self.open_standalone_reader()?;
90            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
91                .await
92                .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
93        } else {
94            let pool = Arc::clone(&self.pool);
95            tokio::task::spawn_blocking(move || {
96                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
97                f(guard.conn()).map_err(|e| map_err(e, op))
98            })
99            .await
100            .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
101        }
102    }
103}
104
105// =============================================================================
106// Helpers
107// =============================================================================
108
109fn read_note(row: &rusqlite::Row<'_>) -> Result<Note, rusqlite::Error> {
110    let id_str: String = row.get(0)?;
111    let namespace: String = row.get(1)?;
112    let kind: String = row.get(2)?;
113    let name: Option<String> = row.get(3)?;
114    let content: String = row.get(4)?;
115    let salience: f64 = row.get(5)?;
116    let decay_factor: f64 = row.get(6)?;
117    let expires_at: Option<i64> = row.get(7)?;
118    let properties_str: Option<String> = row.get(8)?;
119    let created_at: i64 = row.get(9)?;
120    let updated_at: i64 = row.get(10)?;
121    let deleted_at: Option<i64> = row.get(11)?;
122
123    let id = parse_uuid(&id_str)?;
124
125    let properties = properties_str
126        .map(|s| {
127            serde_json::from_str(&s).map_err(|e| {
128                rusqlite::Error::FromSqlConversionFailure(
129                    8,
130                    rusqlite::types::Type::Text,
131                    Box::new(e),
132                )
133            })
134        })
135        .transpose()?;
136
137    Ok(Note {
138        id,
139        namespace,
140        kind,
141        name,
142        content,
143        salience,
144        decay_factor,
145        expires_at,
146        properties,
147        created_at,
148        updated_at,
149        deleted_at,
150    })
151}
152
153fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
154    Uuid::parse_str(s).map_err(|e| {
155        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
156    })
157}
158
159fn build_note_where(
160    namespace: &str,
161    kind: Option<&str>,
162) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
163    let mut conditions: Vec<String> = vec![
164        "namespace = ?1".to_string(),
165        "deleted_at IS NULL".to_string(),
166    ];
167    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
168
169    if let Some(k) = kind {
170        params.push(Box::new(k.to_string()));
171        conditions.push(format!("kind = ?{}", params.len()));
172    }
173
174    let clause = format!(" WHERE {}", conditions.join(" AND "));
175    (clause, params)
176}
177
178// =============================================================================
179// NoteStore implementation
180// =============================================================================
181
182#[async_trait]
183impl NoteStore for SqlNoteStore {
184    async fn upsert_note(&self, note: Note) -> Result<(), StorageError> {
185        let namespace = note.namespace.clone();
186        let id_str = note.id.to_string();
187        let kind_str = note.kind.to_string();
188        let properties_str = note
189            .properties
190            .as_ref()
191            .map(|v| serde_json::to_string(v).unwrap_or_default());
192
193        self.with_writer("upsert_note", move |conn| {
194            conn.execute(
195                "INSERT OR REPLACE INTO notes \
196                 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
197                  properties, created_at, updated_at, deleted_at) \
198                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
199                rusqlite::params![
200                    id_str,
201                    namespace,
202                    kind_str,
203                    note.name,
204                    note.content,
205                    note.salience,
206                    note.decay_factor,
207                    note.expires_at,
208                    properties_str,
209                    note.created_at,
210                    note.updated_at,
211                    note.deleted_at,
212                ],
213            )?;
214            Ok(())
215        })
216        .await
217    }
218
219    async fn upsert_notes(&self, notes: Vec<Note>) -> Result<BatchWriteSummary, StorageError> {
220        let attempted = notes.len() as u64;
221
222        self.with_writer("upsert_notes", move |conn| {
223            conn.execute_batch("BEGIN IMMEDIATE")?;
224            let mut affected = 0u64;
225            let mut failed = 0u64;
226            let mut first_error = String::new();
227
228            for note in &notes {
229                let id_str = note.id.to_string();
230                let kind_str = note.kind.to_string();
231                let properties_str = note
232                    .properties
233                    .as_ref()
234                    .map(|v| serde_json::to_string(v).unwrap_or_default());
235
236                match conn.execute(
237                    "INSERT OR REPLACE INTO notes \
238                     (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
239                      properties, created_at, updated_at, deleted_at) \
240                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
241                    rusqlite::params![
242                        id_str,
243                        &note.namespace,
244                        kind_str,
245                        &note.name,
246                        note.content,
247                        note.salience,
248                        note.decay_factor,
249                        note.expires_at,
250                        properties_str,
251                        note.created_at,
252                        note.updated_at,
253                        note.deleted_at,
254                    ],
255                ) {
256                    Ok(_) => affected += 1,
257                    Err(e) => {
258                        if first_error.is_empty() {
259                            first_error = e.to_string();
260                        }
261                        failed += 1;
262                    }
263                }
264            }
265
266            if let Err(e) = conn.execute_batch("COMMIT") {
267                let _ = conn.execute_batch("ROLLBACK");
268                return Err(e);
269            }
270            Ok(BatchWriteSummary {
271                attempted,
272                affected,
273                failed,
274                first_error,
275            })
276        })
277        .await
278    }
279
280    async fn get_note(&self, id: Uuid) -> Result<Option<Note>, StorageError> {
281        let id_str = id.to_string();
282
283        self.with_reader("get_note", move |conn| {
284            let mut stmt = conn.prepare(
285                "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
286                 properties, created_at, updated_at, deleted_at \
287                 FROM notes WHERE id = ?1 AND deleted_at IS NULL",
288            )?;
289            let mut rows = stmt.query(rusqlite::params![id_str])?;
290            match rows.next()? {
291                Some(row) => Ok(Some(read_note(row)?)),
292                None => Ok(None),
293            }
294        })
295        .await
296    }
297
298    async fn delete_note(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
299        let id_str = id.to_string();
300
301        match mode {
302            DeleteMode::Soft => {
303                self.with_writer("delete_note_soft", move |conn| {
304                    let now = chrono::Utc::now().timestamp_micros();
305                    let deleted = conn.execute(
306                        "UPDATE notes SET deleted_at = ?1 \
307                         WHERE id = ?2 AND deleted_at IS NULL",
308                        rusqlite::params![now, id_str],
309                    )?;
310                    Ok(deleted > 0)
311                })
312                .await
313            }
314            DeleteMode::Hard => {
315                self.with_writer("delete_note_hard", move |conn| {
316                    let deleted =
317                        conn.execute("DELETE FROM notes WHERE id = ?1", rusqlite::params![id_str])?;
318                    Ok(deleted > 0)
319                })
320                .await
321            }
322        }
323    }
324
325    async fn query_notes(
326        &self,
327        namespace: &str,
328        kind: Option<&str>,
329        page: PageRequest,
330    ) -> Result<Page<Note>, StorageError> {
331        let namespace = namespace.to_string();
332        let kind = kind.map(|k| k.to_string());
333
334        self.with_reader("query_notes", move |conn| {
335            let (count_sql, count_params) = build_note_where(&namespace, kind.as_deref());
336            let total: i64 = {
337                let sql = format!("SELECT COUNT(*) FROM notes{}", count_sql);
338                let mut stmt = conn.prepare(&sql)?;
339                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
340                    count_params.iter().map(|p| p.as_ref()).collect();
341                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
342            };
343
344            let (where_sql, mut data_params) = build_note_where(&namespace, kind.as_deref());
345            data_params.push(Box::new(page.limit as i64));
346            data_params.push(Box::new(page.offset as i64));
347
348            let limit_idx = data_params.len() - 1;
349            let offset_idx = data_params.len();
350
351            let data_sql = format!(
352                "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
353                 properties, created_at, updated_at, deleted_at \
354                 FROM notes{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
355                where_sql, limit_idx, offset_idx,
356            );
357
358            let mut stmt = conn.prepare(&data_sql)?;
359            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
360                data_params.iter().map(|p| p.as_ref()).collect();
361            let rows = stmt.query_map(param_refs.as_slice(), read_note)?;
362
363            let mut items = Vec::new();
364            for row in rows {
365                items.push(row?);
366            }
367
368            Ok(Page {
369                items,
370                total: Some(total as u64),
371            })
372        })
373        .await
374    }
375
376    async fn count_notes(&self, namespace: &str, kind: Option<&str>) -> Result<u64, StorageError> {
377        let namespace = namespace.to_string();
378        let kind = kind.map(|k| k.to_string());
379
380        self.with_reader("count_notes", move |conn| {
381            let (where_sql, params) = build_note_where(&namespace, kind.as_deref());
382            let sql = format!("SELECT COUNT(*) FROM notes{}", where_sql);
383            let mut stmt = conn.prepare(&sql)?;
384            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
385                params.iter().map(|p| p.as_ref()).collect();
386            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
387            Ok(count as u64)
388        })
389        .await
390    }
391
392    async fn upsert_note_if_below_quota(
393        &self,
394        note: Note,
395        max_notes: u64,
396    ) -> Result<bool, StorageError> {
397        let namespace = note.namespace.clone();
398        let id_str = note.id.to_string();
399        let kind_str = note.kind.to_string();
400        let properties_str = note
401            .properties
402            .as_ref()
403            .map(|v| serde_json::to_string(v).unwrap_or_default());
404
405        self.with_writer("upsert_note_if_below_quota", move |conn| {
406            let count: i64 = conn.query_row(
407                "SELECT COUNT(*) FROM notes WHERE namespace = ?1 AND deleted_at IS NULL",
408                [&namespace],
409                |row| row.get(0),
410            )?;
411            if count as u64 >= max_notes {
412                return Ok(false);
413            }
414            conn.execute(
415                "INSERT OR REPLACE INTO notes \
416                 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
417                  properties, created_at, updated_at, deleted_at) \
418                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
419                rusqlite::params![
420                    id_str,
421                    namespace,
422                    kind_str,
423                    note.name,
424                    note.content,
425                    note.salience,
426                    note.decay_factor,
427                    note.expires_at,
428                    properties_str,
429                    note.created_at,
430                    note.updated_at,
431                    note.deleted_at,
432                ],
433            )?;
434            Ok(true)
435        })
436        .await
437    }
438}
439
440// =============================================================================
441// DDL
442// =============================================================================
443
// Schema for the `notes` table plus its indexes. Every statement uses
// `IF NOT EXISTS`, so the batch can be re-run against an existing database.
//
// Each line ends in a `\` string continuation. That drops the newline and the
// next line's leading whitespace, so the SQL is emitted as one line with
// columns separated only by commas, which SQLite accepts.
//
// Indexes: `namespace`, `(namespace, kind)`, and `created_at DESC`.
// `deleted_at` is NULL for live rows. The soft-delete path in this file
// stamps it with microseconds since the epoch. Presumably `created_at` and
// `updated_at` use the same unit (they come from `Note`); confirm.
const NOTES_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS notes (\
        id TEXT PRIMARY KEY,\
        namespace TEXT NOT NULL,\
        kind TEXT NOT NULL,\
        name TEXT,\
        content TEXT NOT NULL DEFAULT '',\
        salience REAL NOT NULL DEFAULT 0.5,\
        decay_factor REAL NOT NULL DEFAULT 0.0,\
        expires_at INTEGER,\
        properties TEXT,\
        created_at INTEGER NOT NULL,\
        updated_at INTEGER NOT NULL,\
        deleted_at INTEGER\
    );\
    CREATE INDEX IF NOT EXISTS idx_notes_namespace ON notes(namespace);\
    CREATE INDEX IF NOT EXISTS idx_notes_kind ON notes(namespace, kind);\
    CREATE INDEX IF NOT EXISTS idx_notes_created ON notes(created_at DESC);\
";
463
464pub(crate) fn ensure_notes_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
465    conn.execute_batch(NOTES_DDL)
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// In-memory pool with the `notes` schema applied through the writer.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(NOTES_DDL).unwrap();
        }
        pool
    }

    fn setup_memory_store() -> SqlNoteStore {
        SqlNoteStore::new(setup_pool(), false)
    }

    fn make_note(namespace: &str, kind: &str, content: &str) -> Note {
        Note::new(namespace, kind, content)
    }

    #[tokio::test]
    async fn test_upsert_and_get_note() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "Hello world");
        let id = note.id;

        store.upsert_note(note).await.unwrap();

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.content, "Hello world");
        assert_eq!(fetched.kind, "observation");
    }

    #[tokio::test]
    async fn test_kind_roundtrip_all_variants() {
        let store = setup_memory_store();
        for kind in [
            "observation",
            "insight",
            "question",
            "decision",
            "reference",
        ] {
            let note = make_note("default", kind, "content");
            let id = note.id;
            store.upsert_note(note).await.unwrap();
            let fetched = store.get_note(id).await.unwrap().unwrap();
            assert_eq!(fetched.kind, kind);
        }
    }

    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "to be deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Soft).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "to be hard deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    /// Deleting an already-deleted note reports that nothing changed.
    #[tokio::test]
    async fn test_repeated_delete_returns_false() {
        let store = setup_memory_store();

        let soft = make_note("default", "observation", "soft");
        let soft_id = soft.id;
        store.upsert_note(soft).await.unwrap();
        assert!(store.delete_note(soft_id, DeleteMode::Soft).await.unwrap());
        assert!(!store.delete_note(soft_id, DeleteMode::Soft).await.unwrap());

        let hard = make_note("default", "observation", "hard");
        let hard_id = hard.id;
        store.upsert_note(hard).await.unwrap();
        assert!(store.delete_note(hard_id, DeleteMode::Hard).await.unwrap());
        assert!(!store.delete_note(hard_id, DeleteMode::Hard).await.unwrap());
    }

    /// Upserting an existing id replaces the row instead of adding a new one.
    #[tokio::test]
    async fn test_upsert_replaces_existing_note() {
        let store = setup_memory_store();

        let note = make_note("replace_ns", "observation", "v1");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let mut fetched = store.get_note(id).await.unwrap().unwrap();
        fetched.content = "v2".to_string();
        store.upsert_note(fetched).await.unwrap();

        assert_eq!(store.count_notes("replace_ns", None).await.unwrap(), 1);
        let again = store.get_note(id).await.unwrap().unwrap();
        assert_eq!(again.content, "v2");
    }

    /// Namespace isolation: one store, two namespaces — each query sees only its own.
    #[tokio::test]
    async fn test_namespace_isolation() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            store
                .upsert_note(make_note("ns1", "observation", "content"))
                .await
                .unwrap();
        }
        store
            .upsert_note(make_note("ns2", "observation", "other"))
            .await
            .unwrap();

        let count_ns1 = store.count_notes("ns1", None).await.unwrap();
        assert_eq!(count_ns1, 3);

        let count_ns2 = store.count_notes("ns2", None).await.unwrap();
        assert_eq!(count_ns2, 1);
    }

    #[tokio::test]
    async fn test_quota() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            let inserted = store
                .upsert_note_if_below_quota(make_note("quota_ns", "observation", "x"), 3)
                .await
                .unwrap();
            assert!(inserted);
        }

        let inserted = store
            .upsert_note_if_below_quota(make_note("quota_ns", "observation", "x"), 3)
            .await
            .unwrap();
        assert!(!inserted);
    }

    /// Soft-deleted notes no longer count toward `count_notes` or the quota.
    #[tokio::test]
    async fn test_soft_deleted_notes_excluded_from_count_and_quota() {
        let store = setup_memory_store();

        let note = make_note("sd_ns", "observation", "x");
        let id = note.id;
        assert!(store.upsert_note_if_below_quota(note, 1).await.unwrap());
        assert!(!store
            .upsert_note_if_below_quota(make_note("sd_ns", "observation", "y"), 1)
            .await
            .unwrap());

        assert!(store.delete_note(id, DeleteMode::Soft).await.unwrap());
        assert_eq!(store.count_notes("sd_ns", None).await.unwrap(), 0);

        assert!(store
            .upsert_note_if_below_quota(make_note("sd_ns", "observation", "z"), 1)
            .await
            .unwrap());
    }

    /// query_notes and count_notes use the namespace parameter as passed.
    #[tokio::test]
    async fn test_query_and_count_use_caller_namespace() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        store
            .upsert_note(make_note("ns_a", "observation", "A"))
            .await
            .unwrap();
        store
            .upsert_note(make_note("ns_b", "insight", "B"))
            .await
            .unwrap();

        let page_a = store
            .query_notes("ns_a", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items.len(), 1);
        assert_eq!(page_a.items[0].content, "A");

        let page_b = store
            .query_notes("ns_b", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items.len(), 1);
        assert_eq!(page_b.items[0].content, "B");

        let count_a = store.count_notes("ns_a", None).await.unwrap();
        let count_b = store.count_notes("ns_b", None).await.unwrap();
        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);
    }

    /// LIMIT/OFFSET slice the result while `total` reports the unpaginated count.
    #[tokio::test]
    async fn test_query_pagination_and_kind_filter() {
        let store = setup_memory_store();
        for _ in 0..3 {
            store
                .upsert_note(make_note("pg_ns", "observation", "o"))
                .await
                .unwrap();
        }
        store
            .upsert_note(make_note("pg_ns", "insight", "i"))
            .await
            .unwrap();

        let first = store
            .query_notes(
                "pg_ns",
                None,
                PageRequest {
                    limit: 2,
                    offset: 0,
                    ..PageRequest::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(first.items.len(), 2);
        assert_eq!(first.total, Some(4));

        let second = store
            .query_notes(
                "pg_ns",
                None,
                PageRequest {
                    limit: 2,
                    offset: 2,
                    ..PageRequest::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(second.items.len(), 2);
        assert_eq!(second.total, Some(4));

        let past_end = store
            .query_notes(
                "pg_ns",
                None,
                PageRequest {
                    limit: 2,
                    offset: 4,
                    ..PageRequest::default()
                },
            )
            .await
            .unwrap();
        assert!(past_end.items.is_empty());
        assert_eq!(past_end.total, Some(4));

        let insights = store
            .query_notes("pg_ns", Some("insight"), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(insights.items.len(), 1);
        assert_eq!(insights.items[0].content, "i");
        assert_eq!(insights.total, Some(1));
        assert_eq!(store.count_notes("pg_ns", Some("observation")).await.unwrap(), 3);
    }

    #[tokio::test]
    async fn test_upsert_notes_batch() {
        let store = setup_memory_store();
        let notes: Vec<Note> = (0..3)
            .map(|i| make_note("batch_ns", "observation", &format!("n{}", i)))
            .collect();

        let summary = store.upsert_notes(notes).await.unwrap();
        assert_eq!(summary.attempted, 3);
        assert_eq!(summary.affected, 3);
        assert_eq!(summary.failed, 0);
        assert!(summary.first_error.is_empty());
        assert_eq!(store.count_notes("batch_ns", None).await.unwrap(), 3);
    }

    #[tokio::test]
    async fn test_upsert_notes_empty_batch() {
        let store = setup_memory_store();

        let summary = store.upsert_notes(Vec::new()).await.unwrap();
        assert_eq!(summary.attempted, 0);
        assert_eq!(summary.affected, 0);
        assert_eq!(summary.failed, 0);
        assert!(summary.first_error.is_empty());
    }
}