Skip to main content

khive_db/stores/
note.rs

1//! SQL-backed `NoteStore` implementation.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::error::StorageError;
9use khive_storage::note::Note;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::NoteStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18    StorageError::driver(StorageCapability::Notes, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22    StorageError::driver(StorageCapability::Notes, op, e)
23}
24
25/// A NoteStore backed by SQLite. Namespace is the caller's responsibility.
26///
27/// UUID is globally unique — get/delete by ID alone. Query/count use the
28/// namespace parameter as passed. The store is just a pool + is_file_backed.
29pub struct SqlNoteStore {
30    pool: Arc<ConnectionPool>,
31    is_file_backed: bool,
32}
33
34impl SqlNoteStore {
35    /// Create a new store.
36    pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37        Self {
38            pool,
39            is_file_backed,
40        }
41    }
42
43    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44        let config = self.pool.config();
45        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46            operation: "note_reader".into(),
47            message: "in-memory databases do not support standalone connections".into(),
48        })?;
49
50        let conn = rusqlite::Connection::open_with_flags(
51            path,
52            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55        )
56        .map_err(|e| map_err(e, "open_note_reader"))?;
57
58        conn.busy_timeout(config.busy_timeout)
59            .map_err(|e| map_err(e, "open_note_reader"))?;
60        conn.pragma_update(None, "foreign_keys", "ON")
61            .map_err(|e| map_err(e, "open_note_reader"))?;
62        conn.pragma_update(None, "synchronous", "NORMAL")
63            .map_err(|e| map_err(e, "open_note_reader"))?;
64
65        Ok(conn)
66    }
67
68    /// Write via pool writer (serializes writes through the mutex).
69    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
70    where
71        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
72        R: Send + 'static,
73    {
74        let pool = Arc::clone(&self.pool);
75        tokio::task::spawn_blocking(move || {
76            let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
77            f(guard.conn()).map_err(|e| map_err(e, op))
78        })
79        .await
80        .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
81    }
82
83    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
84    where
85        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
86        R: Send + 'static,
87    {
88        if self.is_file_backed {
89            let conn = self.open_standalone_reader()?;
90            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
91                .await
92                .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
93        } else {
94            let pool = Arc::clone(&self.pool);
95            tokio::task::spawn_blocking(move || {
96                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
97                f(guard.conn()).map_err(|e| map_err(e, op))
98            })
99            .await
100            .map_err(|e| StorageError::driver(StorageCapability::Notes, op, e))?
101        }
102    }
103}
104
105// =============================================================================
106// Helpers
107// =============================================================================
108
109fn read_note(row: &rusqlite::Row<'_>) -> Result<Note, rusqlite::Error> {
110    let id_str: String = row.get(0)?;
111    let namespace: String = row.get(1)?;
112    let kind: String = row.get(2)?;
113    let name: Option<String> = row.get(3)?;
114    let content: String = row.get(4)?;
115    let salience: f64 = row.get(5)?;
116    let decay_factor: f64 = row.get(6)?;
117    let expires_at: Option<i64> = row.get(7)?;
118    let properties_str: Option<String> = row.get(8)?;
119    let created_at: i64 = row.get(9)?;
120    let updated_at: i64 = row.get(10)?;
121    let deleted_at: Option<i64> = row.get(11)?;
122
123    let id = parse_uuid(&id_str)?;
124
125    let properties = properties_str
126        .map(|s| {
127            serde_json::from_str(&s).map_err(|e| {
128                rusqlite::Error::FromSqlConversionFailure(
129                    8,
130                    rusqlite::types::Type::Text,
131                    Box::new(e),
132                )
133            })
134        })
135        .transpose()?;
136
137    Ok(Note {
138        id,
139        namespace,
140        kind,
141        name,
142        content,
143        salience,
144        decay_factor,
145        expires_at,
146        properties,
147        created_at,
148        updated_at,
149        deleted_at,
150    })
151}
152
153fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
154    Uuid::parse_str(s).map_err(|e| {
155        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
156    })
157}
158
159fn build_note_where(
160    namespace: &str,
161    kind: Option<&str>,
162) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
163    let mut conditions: Vec<String> = vec![
164        "namespace = ?1".to_string(),
165        "deleted_at IS NULL".to_string(),
166    ];
167    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
168
169    if let Some(k) = kind {
170        params.push(Box::new(k.to_string()));
171        conditions.push(format!("kind = ?{}", params.len()));
172    }
173
174    let clause = format!(" WHERE {}", conditions.join(" AND "));
175    (clause, params)
176}
177
178// =============================================================================
179// NoteStore implementation
180// =============================================================================
181
182#[async_trait]
183impl NoteStore for SqlNoteStore {
184    async fn upsert_note(&self, note: Note) -> Result<(), StorageError> {
185        let namespace = note.namespace.clone();
186        let id_str = note.id.to_string();
187        let kind_str = note.kind.to_string();
188        let properties_str = note
189            .properties
190            .as_ref()
191            .map(|v| serde_json::to_string(v).unwrap_or_default());
192
193        self.with_writer("upsert_note", move |conn| {
194            conn.execute(
195                "INSERT OR REPLACE INTO notes \
196                 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
197                  properties, created_at, updated_at, deleted_at) \
198                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
199                rusqlite::params![
200                    id_str,
201                    namespace,
202                    kind_str,
203                    note.name,
204                    note.content,
205                    note.salience,
206                    note.decay_factor,
207                    note.expires_at,
208                    properties_str,
209                    note.created_at,
210                    note.updated_at,
211                    note.deleted_at,
212                ],
213            )?;
214            Ok(())
215        })
216        .await
217    }
218
219    async fn upsert_notes(&self, notes: Vec<Note>) -> Result<BatchWriteSummary, StorageError> {
220        let attempted = notes.len() as u64;
221
222        self.with_writer("upsert_notes", move |conn| {
223            conn.execute_batch("BEGIN IMMEDIATE")?;
224            let mut affected = 0u64;
225            let mut failed = 0u64;
226            let mut first_error = String::new();
227
228            for note in &notes {
229                let id_str = note.id.to_string();
230                let kind_str = note.kind.to_string();
231                let properties_str = note
232                    .properties
233                    .as_ref()
234                    .map(|v| serde_json::to_string(v).unwrap_or_default());
235
236                match conn.execute(
237                    "INSERT OR REPLACE INTO notes \
238                     (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
239                      properties, created_at, updated_at, deleted_at) \
240                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
241                    rusqlite::params![
242                        id_str,
243                        &note.namespace,
244                        kind_str,
245                        &note.name,
246                        note.content,
247                        note.salience,
248                        note.decay_factor,
249                        note.expires_at,
250                        properties_str,
251                        note.created_at,
252                        note.updated_at,
253                        note.deleted_at,
254                    ],
255                ) {
256                    Ok(_) => affected += 1,
257                    Err(e) => {
258                        if first_error.is_empty() {
259                            first_error = e.to_string();
260                        }
261                        failed += 1;
262                    }
263                }
264            }
265
266            if let Err(e) = conn.execute_batch("COMMIT") {
267                let _ = conn.execute_batch("ROLLBACK");
268                return Err(e);
269            }
270            Ok(BatchWriteSummary {
271                attempted,
272                affected,
273                failed,
274                first_error,
275            })
276        })
277        .await
278    }
279
280    async fn get_note(&self, id: Uuid) -> Result<Option<Note>, StorageError> {
281        let id_str = id.to_string();
282
283        self.with_reader("get_note", move |conn| {
284            let mut stmt = conn.prepare(
285                "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
286                 properties, created_at, updated_at, deleted_at \
287                 FROM notes WHERE id = ?1 AND deleted_at IS NULL",
288            )?;
289            let mut rows = stmt.query(rusqlite::params![id_str])?;
290            match rows.next()? {
291                Some(row) => Ok(Some(read_note(row)?)),
292                None => Ok(None),
293            }
294        })
295        .await
296    }
297
298    async fn get_notes_batch(&self, ids: &[Uuid]) -> Result<Vec<Note>, StorageError> {
299        if ids.is_empty() {
300            return Ok(vec![]);
301        }
302        let id_strings: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
303
304        self.with_reader("get_notes_batch", move |conn| {
305            let placeholders: String = (1..=id_strings.len())
306                .map(|i| format!("?{i}"))
307                .collect::<Vec<_>>()
308                .join(", ");
309            let sql = format!(
310                "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
311                 properties, created_at, updated_at, deleted_at \
312                 FROM notes WHERE id IN ({placeholders}) AND deleted_at IS NULL"
313            );
314            let mut stmt = conn.prepare(&sql)?;
315            let params: Vec<&dyn rusqlite::types::ToSql> = id_strings
316                .iter()
317                .map(|s| s as &dyn rusqlite::types::ToSql)
318                .collect();
319            let rows = stmt.query_map(params.as_slice(), read_note)?;
320            let mut out = Vec::new();
321            for row in rows {
322                out.push(row?);
323            }
324            Ok(out)
325        })
326        .await
327    }
328
329    async fn delete_note(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
330        let id_str = id.to_string();
331
332        match mode {
333            DeleteMode::Soft => {
334                self.with_writer("delete_note_soft", move |conn| {
335                    let now = chrono::Utc::now().timestamp_micros();
336                    let deleted = conn.execute(
337                        "UPDATE notes SET deleted_at = ?1 \
338                         WHERE id = ?2 AND deleted_at IS NULL",
339                        rusqlite::params![now, id_str],
340                    )?;
341                    Ok(deleted > 0)
342                })
343                .await
344            }
345            DeleteMode::Hard => {
346                self.with_writer("delete_note_hard", move |conn| {
347                    let deleted =
348                        conn.execute("DELETE FROM notes WHERE id = ?1", rusqlite::params![id_str])?;
349                    Ok(deleted > 0)
350                })
351                .await
352            }
353        }
354    }
355
356    async fn query_notes(
357        &self,
358        namespace: &str,
359        kind: Option<&str>,
360        page: PageRequest,
361    ) -> Result<Page<Note>, StorageError> {
362        let namespace = namespace.to_string();
363        let kind = kind.map(|k| k.to_string());
364
365        self.with_reader("query_notes", move |conn| {
366            let (count_sql, count_params) = build_note_where(&namespace, kind.as_deref());
367            let total: i64 = {
368                let sql = format!("SELECT COUNT(*) FROM notes{}", count_sql);
369                let mut stmt = conn.prepare(&sql)?;
370                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
371                    count_params.iter().map(|p| p.as_ref()).collect();
372                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
373            };
374
375            let (where_sql, mut data_params) = build_note_where(&namespace, kind.as_deref());
376            data_params.push(Box::new(page.limit as i64));
377            data_params.push(Box::new(page.offset as i64));
378
379            let limit_idx = data_params.len() - 1;
380            let offset_idx = data_params.len();
381
382            let data_sql = format!(
383                "SELECT id, namespace, kind, name, content, salience, decay_factor, expires_at, \
384                 properties, created_at, updated_at, deleted_at \
385                 FROM notes{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
386                where_sql, limit_idx, offset_idx,
387            );
388
389            let mut stmt = conn.prepare(&data_sql)?;
390            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
391                data_params.iter().map(|p| p.as_ref()).collect();
392            let rows = stmt.query_map(param_refs.as_slice(), read_note)?;
393
394            let mut items = Vec::new();
395            for row in rows {
396                items.push(row?);
397            }
398
399            Ok(Page {
400                items,
401                total: Some(total as u64),
402            })
403        })
404        .await
405    }
406
407    async fn count_notes(&self, namespace: &str, kind: Option<&str>) -> Result<u64, StorageError> {
408        let namespace = namespace.to_string();
409        let kind = kind.map(|k| k.to_string());
410
411        self.with_reader("count_notes", move |conn| {
412            let (where_sql, params) = build_note_where(&namespace, kind.as_deref());
413            let sql = format!("SELECT COUNT(*) FROM notes{}", where_sql);
414            let mut stmt = conn.prepare(&sql)?;
415            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
416                params.iter().map(|p| p.as_ref()).collect();
417            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
418            Ok(count as u64)
419        })
420        .await
421    }
422
423    async fn upsert_note_if_below_quota(
424        &self,
425        note: Note,
426        max_notes: u64,
427    ) -> Result<bool, StorageError> {
428        let namespace = note.namespace.clone();
429        let id_str = note.id.to_string();
430        let kind_str = note.kind.to_string();
431        let properties_str = note
432            .properties
433            .as_ref()
434            .map(|v| serde_json::to_string(v).unwrap_or_default());
435
436        self.with_writer("upsert_note_if_below_quota", move |conn| {
437            let count: i64 = conn.query_row(
438                "SELECT COUNT(*) FROM notes WHERE namespace = ?1 AND deleted_at IS NULL",
439                [&namespace],
440                |row| row.get(0),
441            )?;
442            if count as u64 >= max_notes {
443                return Ok(false);
444            }
445            conn.execute(
446                "INSERT OR REPLACE INTO notes \
447                 (id, namespace, kind, name, content, salience, decay_factor, expires_at, \
448                  properties, created_at, updated_at, deleted_at) \
449                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
450                rusqlite::params![
451                    id_str,
452                    namespace,
453                    kind_str,
454                    note.name,
455                    note.content,
456                    note.salience,
457                    note.decay_factor,
458                    note.expires_at,
459                    properties_str,
460                    note.created_at,
461                    note.updated_at,
462                    note.deleted_at,
463                ],
464            )?;
465            Ok(true)
466        })
467        .await
468    }
469}
470
471// =============================================================================
472// DDL
473// =============================================================================
474
/// Idempotent schema for the `notes` table and its indexes.
///
/// Timestamps are stored as INTEGER; soft deletion is represented by a
/// non-NULL `deleted_at`. `properties` holds JSON text (see `read_note`).
///
/// NOTE(review): `idx_notes_namespace` duplicates the leading column of
/// `idx_notes_kind(namespace, kind)`, so it looks redundant. `query_notes`
/// filters on `namespace` and orders by `created_at`, so an index on
/// `(namespace, created_at)` may serve it better — confirm with
/// `EXPLAIN QUERY PLAN` before changing the schema.
const NOTES_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS notes (\
        id TEXT PRIMARY KEY,\
        namespace TEXT NOT NULL,\
        kind TEXT NOT NULL,\
        name TEXT,\
        content TEXT NOT NULL DEFAULT '',\
        salience REAL NOT NULL DEFAULT 0.5,\
        decay_factor REAL NOT NULL DEFAULT 0.0,\
        expires_at INTEGER,\
        properties TEXT,\
        created_at INTEGER NOT NULL,\
        updated_at INTEGER NOT NULL,\
        deleted_at INTEGER\
    );\
    CREATE INDEX IF NOT EXISTS idx_notes_namespace ON notes(namespace);\
    CREATE INDEX IF NOT EXISTS idx_notes_kind ON notes(namespace, kind);\
    CREATE INDEX IF NOT EXISTS idx_notes_created ON notes(created_at DESC);\
";
494
495pub(crate) fn ensure_notes_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
496    conn.execute_batch(NOTES_DDL)
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// In-memory pool with the notes schema applied through the writer.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(NOTES_DDL).unwrap();
        }
        pool
    }

    fn setup_memory_store() -> SqlNoteStore {
        SqlNoteStore::new(setup_pool(), false)
    }

    fn make_note(namespace: &str, kind: &str, content: &str) -> Note {
        Note::new(namespace, kind, content)
    }

    #[tokio::test]
    async fn test_upsert_and_get_note() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "Hello world");
        let id = note.id;

        store.upsert_note(note).await.unwrap();

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.content, "Hello world");
        assert_eq!(fetched.kind, "observation");
    }

    #[tokio::test]
    async fn test_kind_roundtrip_all_variants() {
        let store = setup_memory_store();
        for kind in [
            "observation",
            "insight",
            "question",
            "decision",
            "reference",
        ] {
            let note = make_note("default", kind, "content");
            let id = note.id;
            store.upsert_note(note).await.unwrap();
            let fetched = store.get_note(id).await.unwrap().unwrap();
            assert_eq!(fetched.kind, kind);
        }
    }

    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "to be deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Soft).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let note = make_note("default", "observation", "to be hard deleted");
        let id = note.id;
        store.upsert_note(note).await.unwrap();

        let deleted = store.delete_note(id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        let fetched = store.get_note(id).await.unwrap();
        assert!(fetched.is_none());
    }

    /// Deleting an id that was never stored reports that nothing changed.
    #[tokio::test]
    async fn test_delete_missing_note_returns_false() {
        let store = setup_memory_store();
        let never_stored = make_note("default", "observation", "never stored").id;

        assert!(!store.delete_note(never_stored, DeleteMode::Hard).await.unwrap());
        assert!(!store.delete_note(never_stored, DeleteMode::Soft).await.unwrap());
    }

    /// Soft-deleted notes drop out of counts and batch reads; a second soft
    /// delete of the same note is a no-op.
    #[tokio::test]
    async fn test_soft_deleted_notes_excluded_from_counts() {
        let store = setup_memory_store();
        let kept = make_note("sd", "observation", "keep");
        let gone = make_note("sd", "observation", "gone");
        let gone_id = gone.id;
        store.upsert_note(kept).await.unwrap();
        store.upsert_note(gone).await.unwrap();

        assert!(store.delete_note(gone_id, DeleteMode::Soft).await.unwrap());
        assert!(!store.delete_note(gone_id, DeleteMode::Soft).await.unwrap());

        assert_eq!(store.count_notes("sd", None).await.unwrap(), 1);
        assert!(store.get_notes_batch(&[gone_id]).await.unwrap().is_empty());
    }

    /// Namespace isolation: one store, two namespaces — each query sees only its own.
    #[tokio::test]
    async fn test_namespace_isolation() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            store
                .upsert_note(make_note("ns1", "observation", "content"))
                .await
                .unwrap();
        }
        store
            .upsert_note(make_note("ns2", "observation", "other"))
            .await
            .unwrap();

        let count_ns1 = store.count_notes("ns1", None).await.unwrap();
        assert_eq!(count_ns1, 3);

        let count_ns2 = store.count_notes("ns2", None).await.unwrap();
        assert_eq!(count_ns2, 1);
    }

    #[tokio::test]
    async fn test_quota() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        for _ in 0..3 {
            let inserted = store
                .upsert_note_if_below_quota(make_note("quota_ns", "observation", "x"), 3)
                .await
                .unwrap();
            assert!(inserted);
        }

        let inserted = store
            .upsert_note_if_below_quota(make_note("quota_ns", "observation", "x"), 3)
            .await
            .unwrap();
        assert!(!inserted);
    }

    /// Batch upsert reports a clean summary and batch get returns every note;
    /// an empty id slice yields an empty result.
    #[tokio::test]
    async fn test_upsert_notes_batch_and_get_batch() {
        let store = setup_memory_store();
        let notes: Vec<Note> = (0..3)
            .map(|i| make_note("batch", "observation", &format!("n{i}")))
            .collect();
        let ids: Vec<Uuid> = notes.iter().map(|n| n.id).collect();

        let summary = store.upsert_notes(notes).await.unwrap();
        assert_eq!(summary.attempted, 3);
        assert_eq!(summary.affected, 3);
        assert_eq!(summary.failed, 0);
        assert!(summary.first_error.is_empty());

        let fetched = store.get_notes_batch(&ids).await.unwrap();
        assert_eq!(fetched.len(), 3);
        for id in &ids {
            assert!(fetched.iter().any(|n| n.id == *id));
        }

        assert!(store.get_notes_batch(&[]).await.unwrap().is_empty());
    }

    /// Kind filtering narrows both items and total; LIMIT/OFFSET page sizes
    /// are honoured while `total` stays the full match count.
    #[tokio::test]
    async fn test_query_notes_kind_filter_and_pagination() {
        let store = setup_memory_store();
        for _ in 0..3 {
            store
                .upsert_note(make_note("q", "observation", "o"))
                .await
                .unwrap();
        }
        store
            .upsert_note(make_note("q", "insight", "i"))
            .await
            .unwrap();

        let insights = store
            .query_notes("q", Some("insight"), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(insights.items.len(), 1);
        assert_eq!(insights.items[0].kind, "insight");
        assert_eq!(insights.total, Some(1));

        let first = store
            .query_notes(
                "q",
                None,
                PageRequest {
                    limit: 2,
                    offset: 0,
                    ..PageRequest::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(first.items.len(), 2);
        assert_eq!(first.total, Some(4));

        let last = store
            .query_notes(
                "q",
                None,
                PageRequest {
                    limit: 2,
                    offset: 3,
                    ..PageRequest::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(last.items.len(), 1);
        assert_eq!(last.total, Some(4));
    }

    /// query_notes and count_notes use the namespace parameter as passed.
    #[tokio::test]
    async fn test_query_and_count_use_caller_namespace() {
        let pool = setup_pool();
        let store = SqlNoteStore::new(Arc::clone(&pool), false);

        store
            .upsert_note(make_note("ns_a", "observation", "A"))
            .await
            .unwrap();
        store
            .upsert_note(make_note("ns_b", "insight", "B"))
            .await
            .unwrap();

        let page_a = store
            .query_notes("ns_a", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items.len(), 1);
        assert_eq!(page_a.items[0].content, "A");

        let page_b = store
            .query_notes("ns_b", None, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items.len(), 1);
        assert_eq!(page_b.items[0].content, "B");

        let count_a = store.count_notes("ns_a", None).await.unwrap();
        let count_b = store.count_notes("ns_b", None).await.unwrap();
        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);
    }
}