Skip to main content

khive_db/stores/
note.rs

1//! SQL-backed `NoteStore` implementation.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::error::StorageError;
9use khive_storage::note::{Note, NoteKind};
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::NoteStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18    StorageError::driver(StorageCapability::Notes, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22    StorageError::driver(StorageCapability::Notes, op, e)
23}
24
/// A [`NoteStore`] backed by SQLite. Namespace scoping is the caller's responsibility.
///
/// Note UUIDs are treated as globally unique, so `get_note` / `delete_note`
/// address a row by ID alone, while `query_notes` / `count_notes` filter on the
/// namespace argument exactly as passed. The store holds nothing but a shared
/// connection pool and a flag selecting how reads are served.
pub struct SqlNoteStore {
    /// Shared connection pool; all writes go through its writer connection.
    pool: Arc<ConnectionPool>,
    /// When `true`, every read opens its own read-only connection to the
    /// pool's database file instead of borrowing a pooled reader connection.
    is_file_backed: bool,
}
33
34impl SqlNoteStore {
35    /// Create a new store.
36    pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37        Self {
38            pool,
39            is_file_backed,
40        }
41    }
42
43    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44        let config = self.pool.config();
45        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46            operation: "note_reader".into(),
47            message: "in-memory databases do not support standalone connections".into(),
48        })?;
49
50        let conn = rusqlite::Connection::open_with_flags(
51            path,
52            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55        )
56        .map_err(|e| map_err(e, "open_note_reader"))?;
57
58        conn.busy_timeout(config.busy_timeout)
59            .map_err(|e| map_err(e, "open_note_reader"))?;
60        conn.pragma_update(None, "foreign_keys", "ON")
61            .map_err(|e| map_err(e, "open_note_reader"))?;
62        conn.pragma_update(None, "synchronous", "NORMAL")
63            .map_err(|e| map_err(e, "open_note_reader"))?;
64
65        Ok(conn)
66    }
67
68    /// Write via pool writer (serializes writes through the mutex).
69    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
70    where
71        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
72        R: Send + 'static,
73    {
74        let pool = Arc::clone(&self.pool);
75        tokio::task::spawn_blocking(move || {
76            let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
77            f(guard.conn()).map_err(|e| map_err(e, op))
78        })
79        .await
80        .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
81    }
82
83    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
84    where
85        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
86        R: Send + 'static,
87    {
88        if self.is_file_backed {
89            let conn = self.open_standalone_reader()?;
90            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
91                .await
92                .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
93        } else {
94            let pool = Arc::clone(&self.pool);
95            tokio::task::spawn_blocking(move || {
96                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
97                f(guard.conn()).map_err(|e| map_err(e, op))
98            })
99            .await
100            .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
101        }
102    }
103}
104
105// =============================================================================
106// Helpers
107// =============================================================================
108
109fn parse_note_kind(s: &str, col: usize) -> Result<NoteKind, rusqlite::Error> {
110    s.parse::<NoteKind>().map_err(|e| {
111        rusqlite::Error::FromSqlConversionFailure(col, rusqlite::types::Type::Text, e.into())
112    })
113}
114
115fn read_note(row: &rusqlite::Row<'_>) -> Result<Note, rusqlite::Error> {
116    let id_str: String = row.get(0)?;
117    let namespace: String = row.get(1)?;
118    let kind_str: String = row.get(2)?;
119    let name: Option<String> = row.get(3)?;
120    let content: String = row.get(4)?;
121    let salience: f64 = row.get(5)?;
122    let decay_factor: f64 = row.get(6)?;
123    let expires_at: Option<i64> = row.get(7)?;
124    let properties_str: Option<String> = row.get(8)?;
125    let created_at: i64 = row.get(9)?;
126    let updated_at: i64 = row.get(10)?;
127    let deleted_at: Option<i64> = row.get(11)?;
128
129    let id = parse_uuid(&id_str)?;
130    let kind = parse_note_kind(&kind_str, 2)?;
131
132    let properties = properties_str
133        .map(|s| {
134            serde_json::from_str(&s).map_err(|e| {
135                rusqlite::Error::FromSqlConversionFailure(
136                    8,
137                    rusqlite::types::Type::Text,
138                    Box::new(e),
139                )
140            })
141        })
142        .transpose()?;
143
144    Ok(Note {
145        id,
146        namespace,
147        kind,
148        name,
149        content,
150        salience,
151        decay_factor,
152        expires_at,
153        properties,
154        created_at,
155        updated_at,
156        deleted_at,
157    })
158}
159
160fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
161    Uuid::parse_str(s).map_err(|e| {
162        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
163    })
164}
165
166fn build_note_where(
167    namespace: &str,
168    kind: Option<&str>,
169) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
170    let mut conditions: Vec<String> = vec![
171        "namespace = ?1".to_string(),
172        "deleted_at IS NULL".to_string(),
173    ];
174    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
175
176    if let Some(k) = kind {
177        params.push(Box::new(k.to_string()));
178        conditions.push(format!("kind = ?{}", params.len()));
179    }
180
181    let clause = format!(" WHERE {}", conditions.join(" AND "));
182    (clause, params)
183}
184
185// =============================================================================
186// NoteStore implementation
187// =============================================================================
188
189#[async_trait]
190impl NoteStore for SqlNoteStore {
191    async fn upsert_note(&self, note: Note) -> Result<(), StorageError> {
192        let namespace = note.namespace.clone();
193        let id_str = note.id.to_string();
194        let kind_str = note.kind.to_string();
195        let properties_str = note
196            .properties
197            .as_ref()
198            .map(|v| serde_json::to_string(v).unwrap_or_default());
199
200        self.with_writer("upsert_note", move |conn| {
201            conn.execute(
202                "INSERT OR REPLACE INTO notes \
203                 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
204                  properties, created_at, updated_at, deleted_at) \
205                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
206                rusqlite::params![
207                    id_str,
208                    namespace,
209                    kind_str,
210                    note.name,
211                    note.content,
212                    note.salience,
213                    note.decay_factor,
214                    note.expires_at,
215                    properties_str,
216                    note.created_at,
217                    note.updated_at,
218                    note.deleted_at,
219                ],
220            )?;
221            Ok(())
222        })
223        .await
224    }
225
226    async fn upsert_notes(&self, notes: Vec<Note>) -> Result<BatchWriteSummary, StorageError> {
227        let attempted = notes.len() as u64;
228
229        self.with_writer("upsert_notes", move |conn| {
230            conn.execute_batch("BEGIN IMMEDIATE")?;
231            let mut affected = 0u64;
232            let mut failed = 0u64;
233            let mut first_error = String::new();
234
235            for note in &notes {
236                let id_str = note.id.to_string();
237                let kind_str = note.kind.to_string();
238                let properties_str = note
239                    .properties
240                    .as_ref()
241                    .map(|v| serde_json::to_string(v).unwrap_or_default());
242
243                match conn.execute(
244                    "INSERT OR REPLACE INTO notes \
245                     (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
246                      properties, created_at, updated_at, deleted_at) \
247                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
248                    rusqlite::params![
249                        id_str,
250                        &note.namespace,
251                        kind_str,
252                        &note.name,
253                        note.content,
254                        note.salience,
255                        note.decay_factor,
256                        note.expires_at,
257                        properties_str,
258                        note.created_at,
259                        note.updated_at,
260                        note.deleted_at,
261                    ],
262                ) {
263                    Ok(_) => affected += 1,
264                    Err(e) => {
265                        if first_error.is_empty() {
266                            first_error = e.to_string();
267                        }
268                        failed += 1;
269                    }
270                }
271            }
272
273            if let Err(e) = conn.execute_batch("COMMIT") {
274                let _ = conn.execute_batch("ROLLBACK");
275                return Err(e);
276            }
277            Ok(BatchWriteSummary {
278                attempted,
279                affected,
280                failed,
281                first_error,
282            })
283        })
284        .await
285    }
286
287    async fn get_note(&self, id: Uuid) -> Result<Option<Note>, StorageError> {
288        let id_str = id.to_string();
289
290        self.with_reader("get_note", move |conn| {
291            let mut stmt = conn.prepare(
292                "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
293                 properties, created_at, updated_at, deleted_at \
294                 FROM notes WHERE id = ?1 AND deleted_at IS NULL",
295            )?;
296            let mut rows = stmt.query(rusqlite::params![id_str])?;
297            match rows.next()? {
298                Some(row) => Ok(Some(read_note(row)?)),
299                None => Ok(None),
300            }
301        })
302        .await
303    }
304
305    async fn delete_note(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
306        let id_str = id.to_string();
307
308        match mode {
309            DeleteMode::Soft => {
310                self.with_writer("delete_note_soft", move |conn| {
311                    let now = chrono::Utc::now().timestamp_micros();
312                    let deleted = conn.execute(
313                        "UPDATE notes SET deleted_at = ?1 \
314                         WHERE id = ?2 AND deleted_at IS NULL",
315                        rusqlite::params![now, id_str],
316                    )?;
317                    Ok(deleted > 0)
318                })
319                .await
320            }
321            DeleteMode::Hard => {
322                self.with_writer("delete_note_hard", move |conn| {
323                    let deleted =
324                        conn.execute("DELETE FROM notes WHERE id = ?1", rusqlite::params![id_str])?;
325                    Ok(deleted > 0)
326                })
327                .await
328            }
329        }
330    }
331
332    async fn query_notes(
333        &self,
334        namespace: &str,
335        kind: Option<NoteKind>,
336        page: PageRequest,
337    ) -> Result<Page<Note>, StorageError> {
338        let namespace = namespace.to_string();
339        let kind = kind.map(|k| k.to_string());
340
341        self.with_reader("query_notes", move |conn| {
342            let (count_sql, count_params) = build_note_where(&namespace, kind.as_deref());
343            let total: i64 = {
344                let sql = format!("SELECT COUNT(*) FROM notes{}", count_sql);
345                let mut stmt = conn.prepare(&sql)?;
346                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
347                    count_params.iter().map(|p| p.as_ref()).collect();
348                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
349            };
350
351            let (where_sql, mut data_params) = build_note_where(&namespace, kind.as_deref());
352            data_params.push(Box::new(page.limit as i64));
353            data_params.push(Box::new(page.offset as i64));
354
355            let limit_idx = data_params.len() - 1;
356            let offset_idx = data_params.len();
357
358            let data_sql = format!(
359                "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
360                 properties, created_at, updated_at, deleted_at \
361                 FROM notes{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
362                where_sql, limit_idx, offset_idx,
363            );
364
365            let mut stmt = conn.prepare(&data_sql)?;
366            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
367                data_params.iter().map(|p| p.as_ref()).collect();
368            let rows = stmt.query_map(param_refs.as_slice(), read_note)?;
369
370            let mut items = Vec::new();
371            for row in rows {
372                items.push(row?);
373            }
374
375            Ok(Page {
376                items,
377                total: Some(total as u64),
378            })
379        })
380        .await
381    }
382
383    async fn count_notes(
384        &self,
385        namespace: &str,
386        kind: Option<NoteKind>,
387    ) -> Result<u64, StorageError> {
388        let namespace = namespace.to_string();
389        let kind = kind.map(|k| k.to_string());
390
391        self.with_reader("count_notes", move |conn| {
392            let (where_sql, params) = build_note_where(&namespace, kind.as_deref());
393            let sql = format!("SELECT COUNT(*) FROM notes{}", where_sql);
394            let mut stmt = conn.prepare(&sql)?;
395            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
396                params.iter().map(|p| p.as_ref()).collect();
397            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
398            Ok(count as u64)
399        })
400        .await
401    }
402
403    async fn upsert_note_if_below_quota(
404        &self,
405        note: Note,
406        max_notes: u64,
407    ) -> Result<bool, StorageError> {
408        let namespace = note.namespace.clone();
409        let id_str = note.id.to_string();
410        let kind_str = note.kind.to_string();
411        let properties_str = note
412            .properties
413            .as_ref()
414            .map(|v| serde_json::to_string(v).unwrap_or_default());
415
416        self.with_writer("upsert_note_if_below_quota", move |conn| {
417            let count: i64 = conn.query_row(
418                "SELECT COUNT(*) FROM notes WHERE namespace = ?1 AND deleted_at IS NULL",
419                [&namespace],
420                |row| row.get(0),
421            )?;
422            if count as u64 >= max_notes {
423                return Ok(false);
424            }
425            conn.execute(
426                "INSERT OR REPLACE INTO notes \
427                 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
428                  properties, created_at, updated_at, deleted_at) \
429                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
430                rusqlite::params![
431                    id_str,
432                    namespace,
433                    kind_str,
434                    note.name,
435                    note.content,
436                    note.salience,
437                    note.decay_factor,
438                    note.expires_at,
439                    properties_str,
440                    note.created_at,
441                    note.updated_at,
442                    note.deleted_at,
443                ],
444            )?;
445            Ok(true)
446        })
447        .await
448    }
449}
450
451// =============================================================================
452// DDL
453// =============================================================================
454
/// Schema for the `notes` table and its secondary indexes.
///
/// Every statement uses `IF NOT EXISTS`, so running this batch more than once
/// is harmless. The trailing backslashes are string-continuation escapes: the
/// resulting SQL text has no newlines or indentation.
// NOTE(review): `idx_notes_namespace` indexes a leading-column prefix of
// `idx_notes_kind (namespace, kind)` and looks redundant — confirm before
// dropping it, since existing databases already carry the index.
const NOTES_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS notes (\
        id TEXT PRIMARY KEY,\
        namespace TEXT NOT NULL,\
        kind TEXT NOT NULL,\
        name TEXT,\
        content TEXT NOT NULL DEFAULT '',\
        salience REAL NOT NULL DEFAULT 0.5,\
        decay_factor REAL NOT NULL DEFAULT 0.0,\
        expires_at INTEGER,\
        properties TEXT,\
        created_at INTEGER NOT NULL,\
        updated_at INTEGER NOT NULL,\
        deleted_at INTEGER\
    );\
    CREATE INDEX IF NOT EXISTS idx_notes_namespace ON notes(namespace);\
    CREATE INDEX IF NOT EXISTS idx_notes_kind ON notes(namespace, kind);\
    CREATE INDEX IF NOT EXISTS idx_notes_created ON notes(created_at DESC);\
";
474
475pub(crate) fn ensure_notes_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
476    conn.execute_batch(NOTES_DDL)
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// In-memory pool with the notes schema applied through the writer.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(NOTES_DDL).unwrap();
        }
        pool
    }

    /// Store over a fresh in-memory pool; reads use pooled readers.
    fn setup_memory_store() -> SqlNoteStore {
        SqlNoteStore::new(setup_pool(), false)
    }

    fn make_note(namespace: &str, kind: NoteKind, content: &str) -> Note {
        Note::new(namespace, kind, content)
    }

    #[tokio::test]
    async fn test_upsert_and_get_note() {
        let store = setup_memory_store();

        let note = make_note("default", NoteKind::Observation, "Hello world");
        let id = note.id;

        store.upsert_note(note).await.unwrap();

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.content, "Hello world");
        assert_eq!(fetched.kind, NoteKind::Observation);
    }

    #[tokio::test]
    async fn test_kind_roundtrip_all_variants() {
        let store = setup_memory_store();
        for kind in [
            NoteKind::Observation,
            NoteKind::Insight,
            NoteKind::Question,
            NoteKind::Decision,
            NoteKind::Reference,
        ] {
            let note = make_note("default", kind, "content");
            let id = note.id;
            store.upsert_note(note).await.unwrap();
            let fetched = store.get_note(id).await.unwrap().unwrap();
            assert_eq!(fetched.kind, kind);
        }
    }

    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let note = make_note("default", NoteKind::Observation, "to be deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Soft).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let note = make_note("default", NoteKind::Observation, "to be hard deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    /// Deleting again reports `false`; a hard delete still removes a soft-deleted row.
    #[tokio::test]
    async fn test_delete_return_values() {
        let store = setup_memory_store();

        let note = make_note("default", NoteKind::Observation, "delete twice");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        assert!(store.delete_note(id, DeleteMode::Soft).await.unwrap());
        assert!(!store.delete_note(id, DeleteMode::Soft).await.unwrap());
        assert!(store.delete_note(id, DeleteMode::Hard).await.unwrap());
        assert!(!store.delete_note(id, DeleteMode::Hard).await.unwrap());
    }

    /// Namespace isolation: one store, two namespaces — each query sees only its own.
    #[tokio::test]
    async fn test_namespace_isolation() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            store
                .upsert_note(make_note("ns1", NoteKind::Observation, "content"))
                .await
                .unwrap();
        }
        store
            .upsert_note(make_note("ns2", NoteKind::Observation, "other"))
            .await
            .unwrap();

        let count_ns1 = store.count_notes("ns1", None).await.unwrap();
        assert_eq!(count_ns1, 3);

        let count_ns2 = store.count_notes("ns2", None).await.unwrap();
        assert_eq!(count_ns2, 1);
    }

    #[tokio::test]
    async fn test_quota() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            let inserted = store
                .upsert_note_if_below_quota(make_note("quota_ns", NoteKind::Observation, "x"), 3)
                .await
                .unwrap();
            assert!(inserted);
        }

        let inserted = store
            .upsert_note_if_below_quota(make_note("quota_ns", NoteKind::Observation, "x"), 3)
            .await
            .unwrap();
        assert!(!inserted);
    }

    /// Soft-deleted notes no longer occupy quota.
    #[tokio::test]
    async fn test_quota_ignores_soft_deleted_notes() {
        let store = setup_memory_store();

        let first = make_note("quota_soft", NoteKind::Observation, "a");
        let first_id = first.id;
        assert!(store.upsert_note_if_below_quota(first, 1).await.unwrap());
        assert!(!store
            .upsert_note_if_below_quota(make_note("quota_soft", NoteKind::Observation, "b"), 1)
            .await
            .unwrap());

        assert!(store.delete_note(first_id, DeleteMode::Soft).await.unwrap());
        assert!(store
            .upsert_note_if_below_quota(make_note("quota_soft", NoteKind::Observation, "c"), 1)
            .await
            .unwrap());
    }

    /// A batch upsert reports every row as affected and makes each one readable.
    #[tokio::test]
    async fn test_upsert_notes_batch_summary() {
        let store = setup_memory_store();

        let notes: Vec<Note> = (0..3)
            .map(|i| make_note("batch", NoteKind::Observation, &format!("note {i}")))
            .collect();
        let ids: Vec<_> = notes.iter().map(|n| n.id).collect();

        let summary = store.upsert_notes(notes).await.unwrap();
        assert_eq!(summary.attempted, 3);
        assert_eq!(summary.affected, 3);
        assert_eq!(summary.failed, 0);
        assert!(summary.first_error.is_empty());

        for id in ids {
            assert!(store.get_note(id).await.unwrap().is_some());
        }
        assert_eq!(store.count_notes("batch", None).await.unwrap(), 3);
    }

    /// Kind filters apply to both counting and querying, and LIMIT/OFFSET are
    /// bound after the WHERE parameters.
    // Field-by-field mutation keeps these tests independent of PageRequest's
    // full field list.
    #[allow(clippy::field_reassign_with_default)]
    #[tokio::test]
    async fn test_kind_filter_and_pagination() {
        let store = setup_memory_store();

        for _ in 0..3 {
            store
                .upsert_note(make_note("kinds", NoteKind::Observation, "obs"))
                .await
                .unwrap();
        }
        store
            .upsert_note(make_note("kinds", NoteKind::Insight, "ins"))
            .await
            .unwrap();

        assert_eq!(
            store
                .count_notes("kinds", Some(NoteKind::Observation))
                .await
                .unwrap(),
            3
        );
        assert_eq!(
            store
                .count_notes("kinds", Some(NoteKind::Insight))
                .await
                .unwrap(),
            1
        );

        let insights = store
            .query_notes("kinds", Some(NoteKind::Insight), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(insights.items.len(), 1);
        assert_eq!(insights.items[0].kind, NoteKind::Insight);
        assert_eq!(insights.total, Some(1));

        let mut first_req = PageRequest::default();
        first_req.limit = 2;
        first_req.offset = 0;
        let first = store
            .query_notes("kinds", Some(NoteKind::Observation), first_req)
            .await
            .unwrap();
        assert_eq!(first.items.len(), 2);
        assert_eq!(first.total, Some(3));

        let mut second_req = PageRequest::default();
        second_req.limit = 2;
        second_req.offset = 2;
        let second = store
            .query_notes("kinds", Some(NoteKind::Observation), second_req)
            .await
            .unwrap();
        assert_eq!(second.items.len(), 1);
        assert_eq!(second.total, Some(3));
    }

    /// query_notes and count_notes use the namespace parameter as passed.
    #[tokio::test]
    async fn test_query_and_count_use_caller_namespace() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        store
            .upsert_note(make_note("ns_a", NoteKind::Observation, "A"))
            .await
            .unwrap();
        store
            .upsert_note(make_note("ns_b", NoteKind::Insight, "B"))
            .await
            .unwrap();

        let page_a = store
            .query_notes("ns_a", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items.len(), 1);
        assert_eq!(page_a.items[0].content, "A");

        let page_b = store
            .query_notes("ns_b", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items.len(), 1);
        assert_eq!(page_b.items[0].content, "B");

        let count_a = store.count_notes("ns_a", None).await.unwrap();
        let count_b = store.count_notes("ns_b", None).await.unwrap();
        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);
    }
}