Skip to main content

khive_db/stores/
sparse.rs

1//! SQLite-backed `SparseStore` implementation (ADR-031).
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_score::DeterministicScore;
9use khive_storage::error::StorageError;
10use khive_storage::types::{
11    BatchWriteSummary, SparseRecord, SparseSearchHit, SparseSearchRequest, SparseVector,
12};
13use khive_storage::{SparseStore, StorageCapability};
14use khive_types::SubstrateKind;
15
16use crate::error::SqliteError;
17use crate::pool::ConnectionPool;
18
19fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
20    StorageError::driver(StorageCapability::Sparse, op, e)
21}
22
23fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
24    StorageError::driver(StorageCapability::Sparse, op, e)
25}
26
27/// Validate that a sparse vector is well-formed.
28///
29/// - indices and values must have equal lengths
30/// - at least one element
31/// - all values must be finite
32/// - indices must be strictly increasing (no duplicates)
33fn validate_sparse_vector(vector: &SparseVector, op: &'static str) -> Result<(), StorageError> {
34    if vector.indices.len() != vector.values.len() {
35        return Err(StorageError::InvalidInput {
36            capability: StorageCapability::Sparse,
37            operation: op.into(),
38            message: format!(
39                "indices length ({}) != values length ({})",
40                vector.indices.len(),
41                vector.values.len()
42            ),
43        });
44    }
45    if vector.indices.is_empty() {
46        return Err(StorageError::InvalidInput {
47            capability: StorageCapability::Sparse,
48            operation: op.into(),
49            message: "sparse vector must have at least one element".into(),
50        });
51    }
52    for (i, v) in vector.values.iter().enumerate() {
53        if !v.is_finite() {
54            return Err(StorageError::InvalidInput {
55                capability: StorageCapability::Sparse,
56                operation: op.into(),
57                message: format!("non-finite value at position {i}: {v}"),
58            });
59        }
60    }
61    // Verify strictly increasing indices.
62    for window in vector.indices.windows(2) {
63        if window[0] >= window[1] {
64            return Err(StorageError::InvalidInput {
65                capability: StorageCapability::Sparse,
66                operation: op.into(),
67                message: format!(
68                    "indices must be strictly increasing; found {} then {}",
69                    window[0], window[1]
70                ),
71            });
72        }
73    }
74    Ok(())
75}
76
/// Encode `data` as consecutive little-endian `f32` values (4 bytes each).
///
/// The search path decodes `values_blob` with `f32::from_le_bytes`, so the byte order is
/// fixed explicitly here. Reinterpreting the slice's in-memory (native-endian)
/// representation would write blobs that decode incorrectly on big-endian targets.
/// The result is owned (`Vec<u8>` implements `rusqlite::ToSql`), which also removes the
/// need for `unsafe`.
fn f32_slice_as_bytes(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(std::mem::size_of_val(data));
    for value in data {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
82
83/// Create the sparse table and its index for the given model_key.
84pub(crate) fn ensure_sparse_schema(
85    conn: &rusqlite::Connection,
86    model_key: &str,
87) -> Result<(), rusqlite::Error> {
88    let table = format!("sparse_{}", model_key);
89    let ddl = format!(
90        "CREATE TABLE IF NOT EXISTS {table} (\
91         subject_id TEXT NOT NULL, \
92         namespace TEXT NOT NULL, \
93         kind TEXT NOT NULL, \
94         field TEXT NOT NULL, \
95         indices_json TEXT NOT NULL, \
96         values_blob BLOB NOT NULL, \
97         updated_at INTEGER NOT NULL, \
98         PRIMARY KEY(subject_id, namespace, field)\
99         ); \
100         CREATE INDEX IF NOT EXISTS idx_{table}_namespace_kind \
101         ON {table}(namespace, kind);"
102    );
103    conn.execute_batch(&ddl)
104}
105
106pub struct SqliteSparseStore {
107    pool: Arc<ConnectionPool>,
108    is_file_backed: bool,
109    table_name: String,
110    namespace: String,
111}
112
113impl SqliteSparseStore {
114    pub fn new(
115        pool: Arc<ConnectionPool>,
116        is_file_backed: bool,
117        model_key: String,
118        namespace: String,
119    ) -> Result<Self, SqliteError> {
120        let table_name = format!("sparse_{}", model_key);
121        Ok(Self {
122            pool,
123            is_file_backed,
124            table_name,
125            namespace,
126        })
127    }
128
129    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
130    where
131        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
132        R: Send + 'static,
133    {
134        let pool = Arc::clone(&self.pool);
135        tokio::task::spawn_blocking(move || {
136            let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
137            f(guard.conn()).map_err(|e| map_err(e, op))
138        })
139        .await
140        .map_err(|e| StorageError::driver(StorageCapability::Sparse, op, e))?
141    }
142
143    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
144    where
145        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
146        R: Send + 'static,
147    {
148        if self.is_file_backed {
149            // For file-backed DBs open a standalone read-only connection.
150            let config = self.pool.config();
151            let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
152                operation: "sparse_reader".into(),
153                message: "in-memory databases do not support standalone connections".into(),
154            })?;
155            let conn = rusqlite::Connection::open_with_flags(
156                path,
157                rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
158                    | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
159                    | rusqlite::OpenFlags::SQLITE_OPEN_URI,
160            )
161            .map_err(|e| map_err(e, op))?;
162            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
163                .await
164                .map_err(|e| StorageError::driver(StorageCapability::Sparse, op, e))?
165        } else {
166            let pool = Arc::clone(&self.pool);
167            tokio::task::spawn_blocking(move || {
168                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
169                f(guard.conn()).map_err(|e| map_err(e, op))
170            })
171            .await
172            .map_err(|e| StorageError::driver(StorageCapability::Sparse, op, e))?
173        }
174    }
175
176    async fn upsert_sparse_vector(
177        &self,
178        subject_id: Uuid,
179        kind: SubstrateKind,
180        namespace: &str,
181        field: &str,
182        vector: SparseVector,
183    ) -> Result<(), StorageError> {
184        let table = self.table_name.clone();
185        let ns = namespace.to_string();
186        let field = field.to_string();
187        let id_str = subject_id.to_string();
188        let kind_str = kind.to_string();
189
190        self.with_writer("sparse_upsert", move |conn| {
191            let indices_json = serde_json::to_string(&vector.indices).map_err(|e| {
192                rusqlite::Error::FromSqlConversionFailure(
193                    0,
194                    rusqlite::types::Type::Text,
195                    Box::new(e),
196                )
197            })?;
198            let values_blob = f32_slice_as_bytes(&vector.values);
199            let now = chrono::Utc::now().timestamp();
200            let sql = format!(
201                "INSERT INTO {table} \
202                 (subject_id, namespace, kind, field, indices_json, values_blob, updated_at) \
203                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) \
204                 ON CONFLICT(subject_id, namespace, field) DO UPDATE SET \
205                 kind = excluded.kind, \
206                 indices_json = excluded.indices_json, \
207                 values_blob = excluded.values_blob, \
208                 updated_at = excluded.updated_at"
209            );
210            conn.execute(
211                &sql,
212                rusqlite::params![
213                    &id_str,
214                    &ns,
215                    &kind_str,
216                    &field,
217                    &indices_json,
218                    values_blob,
219                    now
220                ],
221            )?;
222            Ok(())
223        })
224        .await
225    }
226
227    async fn insert_sparse_batch(
228        &self,
229        records: Vec<SparseRecord>,
230    ) -> Result<BatchWriteSummary, StorageError> {
231        let table = self.table_name.clone();
232        let attempted = records.len() as u64;
233
234        self.with_writer("sparse_insert_batch", move |conn| {
235            let sql = format!(
236                "INSERT INTO {table} \
237                 (subject_id, namespace, kind, field, indices_json, values_blob, updated_at) \
238                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) \
239                 ON CONFLICT(subject_id, namespace, field) DO UPDATE SET \
240                 indices_json = excluded.indices_json, \
241                 values_blob = excluded.values_blob, \
242                 updated_at = excluded.updated_at"
243            );
244
245            conn.execute_batch("BEGIN IMMEDIATE")?;
246            let mut affected = 0u64;
247            let mut failed = 0u64;
248            let mut first_error = String::new();
249
250            for record in &records {
251                // Validate inline — skip invalid records rather than aborting the batch.
252                if record.vector.indices.len() != record.vector.values.len()
253                    || record.vector.indices.is_empty()
254                    || record.vector.values.iter().any(|v| !v.is_finite())
255                    || record.vector.indices.windows(2).any(|w| w[0] >= w[1])
256                {
257                    if first_error.is_empty() {
258                        first_error =
259                            format!("invalid sparse vector for subject {}", record.subject_id);
260                    }
261                    failed += 1;
262                    continue;
263                }
264
265                let indices_json = match serde_json::to_string(&record.vector.indices) {
266                    Ok(j) => j,
267                    Err(e) => {
268                        if first_error.is_empty() {
269                            first_error = e.to_string();
270                        }
271                        failed += 1;
272                        continue;
273                    }
274                };
275                let values_blob = f32_slice_as_bytes(&record.vector.values);
276                let now = record.updated_at.timestamp();
277                let id_str = record.subject_id.to_string();
278                let kind_str = record.kind.to_string();
279
280                match conn.execute(
281                    &sql,
282                    rusqlite::params![
283                        &id_str,
284                        &record.namespace,
285                        &kind_str,
286                        &record.field,
287                        &indices_json,
288                        values_blob,
289                        now
290                    ],
291                ) {
292                    Ok(_) => affected += 1,
293                    Err(e) => {
294                        if first_error.is_empty() {
295                            first_error = e.to_string();
296                        }
297                        failed += 1;
298                    }
299                }
300            }
301
302            conn.execute_batch("COMMIT")?;
303            Ok(BatchWriteSummary {
304                attempted,
305                affected,
306                failed,
307                first_error,
308            })
309        })
310        .await
311    }
312
313    async fn delete_sparse_subject(&self, subject_id: Uuid) -> Result<bool, StorageError> {
314        let table = self.table_name.clone();
315        let namespace = self.namespace.clone();
316        let id_str = subject_id.to_string();
317
318        self.with_writer("sparse_delete", move |conn| {
319            let sql = format!("DELETE FROM {table} WHERE subject_id = ?1 AND namespace = ?2");
320            let deleted = conn.execute(&sql, rusqlite::params![&id_str, &namespace])?;
321            Ok(deleted > 0)
322        })
323        .await
324    }
325
326    async fn search_sparse_vectors(
327        &self,
328        request: SparseSearchRequest,
329    ) -> Result<Vec<SparseSearchHit>, StorageError> {
330        let table = self.table_name.clone();
331        let ns = request
332            .namespace
333            .clone()
334            .unwrap_or_else(|| self.namespace.clone());
335        let kind_filter = request.kind.map(|k| k.to_string());
336        let query = request.query;
337        let top_k = request.top_k as usize;
338
339        self.with_reader("sparse_search", move |conn| {
340            // Load candidate rows for namespace (and optional kind).
341            let (sql, kind_str_ref) = if let Some(ref kind_str) = kind_filter {
342                (
343                    format!(
344                        "SELECT subject_id, indices_json, values_blob \
345                         FROM {table} WHERE namespace = ?1 AND kind = ?2"
346                    ),
347                    Some(kind_str.as_str()),
348                )
349            } else {
350                (
351                    format!(
352                        "SELECT subject_id, indices_json, values_blob \
353                         FROM {table} WHERE namespace = ?1"
354                    ),
355                    None,
356                )
357            };
358
359            let mut stmt = conn.prepare(&sql)?;
360
361            // Collect rows.
362            let rows: Vec<rusqlite::Result<(String, String, Vec<u8>)>> =
363                if let Some(kind_str) = kind_str_ref {
364                    stmt.query_map(rusqlite::params![&ns, kind_str], |row| {
365                        Ok((row.get(0)?, row.get(1)?, row.get(2)?))
366                    })?
367                    .collect()
368                } else {
369                    stmt.query_map(rusqlite::params![&ns], |row| {
370                        Ok((row.get(0)?, row.get(1)?, row.get(2)?))
371                    })?
372                    .collect()
373                };
374
375            // Compute sparse dot product for each candidate.
376            let mut scored: Vec<(Uuid, f64)> = Vec::new();
377            for row_result in rows {
378                let (id_str, indices_json, values_blob) = row_result?;
379
380                let subject_id = Uuid::parse_str(&id_str).map_err(|e| {
381                    rusqlite::Error::FromSqlConversionFailure(
382                        0,
383                        rusqlite::types::Type::Text,
384                        Box::new(e),
385                    )
386                })?;
387
388                let stored_indices: Vec<u32> =
389                    serde_json::from_str(&indices_json).unwrap_or_default();
390                // Deserialize f32 values from little-endian bytes.
391                let stored_values: Vec<f32> = if values_blob.len() % 4 == 0 {
392                    values_blob
393                        .chunks_exact(4)
394                        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
395                        .collect()
396                } else {
397                    continue;
398                };
399
400                if stored_indices.len() != stored_values.len() {
401                    continue;
402                }
403
404                // Sparse dot product using merge of sorted index arrays.
405                let score = sparse_dot_product(
406                    &query.indices,
407                    &query.values,
408                    &stored_indices,
409                    &stored_values,
410                );
411                scored.push((subject_id, score));
412            }
413
414            // Sort descending by score, take top_k, assign 1-based rank.
415            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
416            scored.truncate(top_k);
417
418            let hits = scored
419                .into_iter()
420                .enumerate()
421                .map(|(i, (subject_id, score))| SparseSearchHit {
422                    subject_id,
423                    score: DeterministicScore::from_f64(score),
424                    rank: (i + 1) as u32,
425                })
426                .collect();
427
428            Ok(hits)
429        })
430        .await
431    }
432
433    async fn count_sparse_rows(&self) -> Result<u64, StorageError> {
434        let table = self.table_name.clone();
435        let namespace = self.namespace.clone();
436        self.with_reader("sparse_count", move |conn| {
437            let sql = format!("SELECT COUNT(*) FROM {table} WHERE namespace = ?1");
438            let count: i64 =
439                conn.query_row(&sql, rusqlite::params![&namespace], |row| row.get(0))?;
440            Ok(count as u64)
441        })
442        .await
443    }
444}
445
/// Dot product of two sparse vectors given as parallel `(index, value)` slices.
///
/// Both index slices must be sorted ascending. The walk advances whichever side has the
/// smaller index, multiplying values only where indices coincide, so it runs in
/// O(|q| + |s|). Products are accumulated in `f64`.
fn sparse_dot_product(q_idx: &[u32], q_val: &[f32], s_idx: &[u32], s_val: &[f32]) -> f64 {
    use std::cmp::Ordering;

    let (mut a, mut b) = (0usize, 0usize);
    let mut total = 0.0f64;
    while let (Some(&qa), Some(&sb)) = (q_idx.get(a), s_idx.get(b)) {
        match qa.cmp(&sb) {
            Ordering::Less => a += 1,
            Ordering::Greater => b += 1,
            Ordering::Equal => {
                total += f64::from(q_val[a]) * f64::from(s_val[b]);
                a += 1;
                b += 1;
            }
        }
    }
    total
}
464
465#[async_trait]
466impl SparseStore for SqliteSparseStore {
467    async fn insert_sparse(
468        &self,
469        subject_id: Uuid,
470        kind: SubstrateKind,
471        namespace: &str,
472        field: &str,
473        vector: SparseVector,
474    ) -> Result<(), StorageError> {
475        validate_sparse_vector(&vector, "sparse_insert")?;
476        self.upsert_sparse_vector(subject_id, kind, namespace, field, vector)
477            .await
478    }
479
480    async fn insert_batch(
481        &self,
482        records: Vec<SparseRecord>,
483    ) -> Result<BatchWriteSummary, StorageError> {
484        self.insert_sparse_batch(records).await
485    }
486
487    async fn delete(&self, subject_id: Uuid) -> Result<bool, StorageError> {
488        self.delete_sparse_subject(subject_id).await
489    }
490
491    async fn search_sparse(
492        &self,
493        request: SparseSearchRequest,
494    ) -> Result<Vec<SparseSearchHit>, StorageError> {
495        validate_sparse_vector(&request.query, "sparse_search")?;
496        self.search_sparse_vectors(request).await
497    }
498
499    async fn count(&self) -> Result<u64, StorageError> {
500        self.count_sparse_rows().await
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::{ConnectionPool, PoolConfig};
    use chrono::Utc;

    /// Build a store over a fresh in-memory pool with the `sparse_{model_key}` schema applied.
    ///
    /// Each test uses its own `model_key` so tables never collide.
    fn make_store(model_key: &str) -> SqliteSparseStore {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).expect("pool"));
        {
            let writer = pool.try_writer().expect("writer");
            ensure_sparse_schema(writer.conn(), model_key).expect("schema");
        }
        SqliteSparseStore::new(pool, false, model_key.to_string(), "ns:test".to_string())
            .expect("store")
    }

    fn sv(indices: Vec<u32>, values: Vec<f32>) -> SparseVector {
        SparseVector { indices, values }
    }

    /// Insert `vector` as the `body` field of an entity in `namespace`.
    async fn put(
        store: &SqliteSparseStore,
        id: Uuid,
        namespace: &str,
        vector: SparseVector,
    ) -> Result<(), StorageError> {
        store
            .insert_sparse(id, SubstrateKind::Entity, namespace, "body", vector)
            .await
    }

    /// Batch record for an entity's `body` field in the default test namespace.
    fn record(id: Uuid, vector: SparseVector) -> SparseRecord {
        SparseRecord {
            subject_id: id,
            kind: SubstrateKind::Entity,
            namespace: "ns:test".into(),
            field: "body".into(),
            vector,
            updated_at: Utc::now(),
        }
    }

    /// Assert that inserting `vector` into a fresh store fails with `InvalidInput`.
    async fn assert_insert_rejected(model_key: &str, vector: SparseVector) {
        let store = make_store(model_key);
        let result = put(&store, Uuid::new_v4(), "ns:test", vector).await;
        assert!(
            matches!(result, Err(StorageError::InvalidInput { .. })),
            "expected InvalidInput for {model_key}"
        );
    }

    #[tokio::test]
    async fn insert_and_count() {
        let store = make_store("test_count");
        put(&store, Uuid::new_v4(), "ns:test", sv(vec![0, 2], vec![1.0, 0.5]))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn insert_and_search() {
        let store = make_store("test_search");
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        put(&store, id1, "ns:test", sv(vec![0, 1], vec![1.0, 0.0]))
            .await
            .unwrap();
        put(&store, id2, "ns:test", sv(vec![0, 1], vec![0.0, 1.0]))
            .await
            .unwrap();

        let hits = store
            .search_sparse(SparseSearchRequest {
                query: sv(vec![0], vec![1.0]),
                top_k: 2,
                namespace: Some("ns:test".into()),
                kind: None,
            })
            .await
            .unwrap();

        assert!(!hits.is_empty());
        assert_eq!(hits[0].subject_id, id1, "id1 should rank first");
        assert_eq!(hits[0].rank, 1);
    }

    #[tokio::test]
    async fn delete_removes_row() {
        let store = make_store("test_delete");
        let id = Uuid::new_v4();
        put(&store, id, "ns:test", sv(vec![1], vec![1.0]))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        assert!(store.delete(id).await.unwrap());
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn delete_missing_subject_returns_false() {
        let store = make_store("test_delete_missing");
        assert!(!store.delete(Uuid::new_v4()).await.unwrap());
    }

    #[tokio::test]
    async fn mismatched_lengths_rejected() {
        assert_insert_rejected(
            "test_mismatch",
            SparseVector {
                indices: vec![0, 1],
                values: vec![1.0],
            },
        )
        .await;
    }

    #[tokio::test]
    async fn non_finite_values_rejected() {
        assert_insert_rejected("test_nonfinite", sv(vec![0], vec![f32::NAN])).await;
    }

    #[tokio::test]
    async fn duplicate_indices_rejected() {
        assert_insert_rejected("test_dup_idx", sv(vec![0, 0], vec![1.0, 2.0])).await;
    }

    #[tokio::test]
    async fn empty_vector_rejected() {
        assert_insert_rejected("test_empty", sv(vec![], vec![])).await;
    }

    #[tokio::test]
    async fn invalid_query_rejected() {
        let store = make_store("test_bad_query");
        let result = store
            .search_sparse(SparseSearchRequest {
                query: sv(vec![], vec![]),
                top_k: 1,
                namespace: None,
                kind: None,
            })
            .await;
        assert!(matches!(result, Err(StorageError::InvalidInput { .. })));
    }

    #[tokio::test]
    async fn namespace_isolation() {
        let store = make_store("test_ns_iso");
        put(&store, Uuid::new_v4(), "ns:a", sv(vec![0], vec![1.0]))
            .await
            .unwrap();

        let hits = store
            .search_sparse(SparseSearchRequest {
                query: sv(vec![0], vec![1.0]),
                top_k: 5,
                namespace: Some("ns:b".into()),
                kind: None,
            })
            .await
            .unwrap();
        assert!(hits.is_empty(), "ns:b should not see ns:a data");
    }

    #[tokio::test]
    async fn upsert_replaces_existing_vector() {
        let store = make_store("test_upsert");
        let id = Uuid::new_v4();
        put(&store, id, "ns:test", sv(vec![0], vec![1.0]))
            .await
            .unwrap();
        put(&store, id, "ns:test", sv(vec![7], vec![2.0]))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 1, "same key must overwrite");

        // `namespace: None` falls back to the store's own namespace.
        let hits = store
            .search_sparse(SparseSearchRequest {
                query: sv(vec![7], vec![1.0]),
                top_k: 1,
                namespace: None,
                kind: None,
            })
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].subject_id, id);
    }

    #[tokio::test]
    async fn search_with_kind_filter() {
        let store = make_store("test_kind_filter");
        let id = Uuid::new_v4();
        put(&store, id, "ns:test", sv(vec![0], vec![1.0]))
            .await
            .unwrap();

        let hits = store
            .search_sparse(SparseSearchRequest {
                query: sv(vec![0], vec![1.0]),
                top_k: 5,
                namespace: Some("ns:test".into()),
                kind: Some(SubstrateKind::Entity),
            })
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].subject_id, id);
    }

    #[tokio::test]
    async fn search_truncates_to_top_k_with_sequential_ranks() {
        let store = make_store("test_top_k");
        let ids: Vec<Uuid> = (0..3).map(|_| Uuid::new_v4()).collect();
        for (id, weight) in ids.iter().zip([0.25f32, 0.5, 1.0]) {
            put(&store, *id, "ns:test", sv(vec![3], vec![weight]))
                .await
                .unwrap();
        }

        let hits = store
            .search_sparse(SparseSearchRequest {
                query: sv(vec![3], vec![1.0]),
                top_k: 2,
                namespace: Some("ns:test".into()),
                kind: None,
            })
            .await
            .unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].subject_id, ids[2]);
        assert_eq!(hits[0].rank, 1);
        assert_eq!(hits[1].subject_id, ids[1]);
        assert_eq!(hits[1].rank, 2);
    }

    #[tokio::test]
    async fn insert_batch_happy_path() {
        let store = make_store("test_batch");
        let records = vec![
            record(Uuid::new_v4(), sv(vec![0, 3], vec![0.5, 0.8])),
            record(Uuid::new_v4(), sv(vec![1], vec![1.0])),
        ];
        let summary = store.insert_batch(records).await.unwrap();
        assert_eq!(summary.attempted, 2);
        assert_eq!(summary.affected, 2);
        assert_eq!(summary.failed, 0);
        assert_eq!(store.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn insert_batch_skips_invalid_records() {
        let store = make_store("test_batch_invalid");
        let records = vec![
            record(Uuid::new_v4(), sv(vec![0], vec![1.0])),
            record(Uuid::new_v4(), sv(vec![], vec![])),
        ];
        let summary = store.insert_batch(records).await.unwrap();
        assert_eq!(summary.attempted, 2);
        assert_eq!(summary.affected, 1);
        assert_eq!(summary.failed, 1);
        assert!(!summary.first_error.is_empty());
        assert_eq!(store.count().await.unwrap(), 1);
    }
}