1use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_score::DeterministicScore;
9use khive_storage::error::StorageError;
10use khive_storage::types::{
11 BatchWriteSummary, SparseRecord, SparseSearchHit, SparseSearchRequest, SparseVector,
12};
13use khive_storage::{SparseStore, StorageCapability};
14use khive_types::SubstrateKind;
15
16use crate::error::SqliteError;
17use crate::pool::ConnectionPool;
18
19fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
20 StorageError::driver(StorageCapability::Sparse, op, e)
21}
22
23fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
24 StorageError::driver(StorageCapability::Sparse, op, e)
25}
26
27fn validate_sparse_vector(vector: &SparseVector, op: &'static str) -> Result<(), StorageError> {
34 if vector.indices.len() != vector.values.len() {
35 return Err(StorageError::InvalidInput {
36 capability: StorageCapability::Sparse,
37 operation: op.into(),
38 message: format!(
39 "indices length ({}) != values length ({})",
40 vector.indices.len(),
41 vector.values.len()
42 ),
43 });
44 }
45 if vector.indices.is_empty() {
46 return Err(StorageError::InvalidInput {
47 capability: StorageCapability::Sparse,
48 operation: op.into(),
49 message: "sparse vector must have at least one element".into(),
50 });
51 }
52 for (i, v) in vector.values.iter().enumerate() {
53 if !v.is_finite() {
54 return Err(StorageError::InvalidInput {
55 capability: StorageCapability::Sparse,
56 operation: op.into(),
57 message: format!("non-finite value at position {i}: {v}"),
58 });
59 }
60 }
61 for window in vector.indices.windows(2) {
63 if window[0] >= window[1] {
64 return Err(StorageError::InvalidInput {
65 capability: StorageCapability::Sparse,
66 operation: op.into(),
67 message: format!(
68 "indices must be strictly increasing; found {} then {}",
69 window[0], window[1]
70 ),
71 });
72 }
73 }
74 Ok(())
75}
76
/// Views an `f32` slice as its underlying bytes without copying.
///
/// The bytes are in the host's native byte order.
/// NOTE(review): the search path in this file decodes stored blobs with
/// `f32::from_le_bytes`, so values only round-trip on little-endian targets.
fn f32_slice_as_bytes(data: &[f32]) -> &[u8] {
    let byte_len = data.len() * std::mem::size_of::<f32>();
    let start = data.as_ptr().cast::<u8>();
    // SAFETY: `start` points at the initialized storage of `data`, which spans exactly
    // `byte_len` bytes. `u8` has alignment 1 and no invalid bit patterns, and the
    // returned slice borrows for the same lifetime as `data`.
    unsafe { std::slice::from_raw_parts(start, byte_len) }
}
82
83pub(crate) fn ensure_sparse_schema(
85 conn: &rusqlite::Connection,
86 model_key: &str,
87) -> Result<(), rusqlite::Error> {
88 let table = format!("sparse_{}", model_key);
89 let ddl = format!(
90 "CREATE TABLE IF NOT EXISTS {table} (\
91 subject_id TEXT NOT NULL, \
92 namespace TEXT NOT NULL, \
93 kind TEXT NOT NULL, \
94 field TEXT NOT NULL, \
95 indices_json TEXT NOT NULL, \
96 values_blob BLOB NOT NULL, \
97 updated_at INTEGER NOT NULL, \
98 PRIMARY KEY(subject_id, namespace, field)\
99 ); \
100 CREATE INDEX IF NOT EXISTS idx_{table}_namespace_kind \
101 ON {table}(namespace, kind);"
102 );
103 conn.execute_batch(&ddl)
104}
105
/// SQLite-backed sparse-vector store bound to a single `sparse_<model_key>` table.
pub struct SqliteSparseStore {
    /// Connection pool used for every write and, when the store is not file-backed,
    /// for reads as well.
    pool: Arc<ConnectionPool>,
    /// When `true`, reads open a standalone read-only connection to the path found in
    /// the pool's configuration instead of borrowing a pooled reader.
    is_file_backed: bool,
    /// Name of the table this store reads and writes (`sparse_<model_key>`).
    table_name: String,
    /// Namespace applied to deletes and counts, and used by searches whose request
    /// does not name one.
    namespace: String,
}
112
113impl SqliteSparseStore {
114 pub fn new(
115 pool: Arc<ConnectionPool>,
116 is_file_backed: bool,
117 model_key: String,
118 namespace: String,
119 ) -> Result<Self, SqliteError> {
120 let table_name = format!("sparse_{}", model_key);
121 Ok(Self {
122 pool,
123 is_file_backed,
124 table_name,
125 namespace,
126 })
127 }
128
129 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
130 where
131 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
132 R: Send + 'static,
133 {
134 let pool = Arc::clone(&self.pool);
135 tokio::task::spawn_blocking(move || {
136 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
137 f(guard.conn()).map_err(|e| map_err(e, op))
138 })
139 .await
140 .map_err(|e| StorageError::driver(StorageCapability::Sparse, op, e))?
141 }
142
143 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
144 where
145 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
146 R: Send + 'static,
147 {
148 if self.is_file_backed {
149 let config = self.pool.config();
151 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
152 operation: "sparse_reader".into(),
153 message: "in-memory databases do not support standalone connections".into(),
154 })?;
155 let conn = rusqlite::Connection::open_with_flags(
156 path,
157 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
158 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
159 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
160 )
161 .map_err(|e| map_err(e, op))?;
162 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
163 .await
164 .map_err(|e| StorageError::driver(StorageCapability::Sparse, op, e))?
165 } else {
166 let pool = Arc::clone(&self.pool);
167 tokio::task::spawn_blocking(move || {
168 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
169 f(guard.conn()).map_err(|e| map_err(e, op))
170 })
171 .await
172 .map_err(|e| StorageError::driver(StorageCapability::Sparse, op, e))?
173 }
174 }
175
176 async fn upsert_sparse_vector(
177 &self,
178 subject_id: Uuid,
179 kind: SubstrateKind,
180 namespace: &str,
181 field: &str,
182 vector: SparseVector,
183 ) -> Result<(), StorageError> {
184 let table = self.table_name.clone();
185 let ns = namespace.to_string();
186 let field = field.to_string();
187 let id_str = subject_id.to_string();
188 let kind_str = kind.to_string();
189
190 self.with_writer("sparse_upsert", move |conn| {
191 let indices_json = serde_json::to_string(&vector.indices).map_err(|e| {
192 rusqlite::Error::FromSqlConversionFailure(
193 0,
194 rusqlite::types::Type::Text,
195 Box::new(e),
196 )
197 })?;
198 let values_blob = f32_slice_as_bytes(&vector.values);
199 let now = chrono::Utc::now().timestamp();
200 let sql = format!(
201 "INSERT INTO {table} \
202 (subject_id, namespace, kind, field, indices_json, values_blob, updated_at) \
203 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) \
204 ON CONFLICT(subject_id, namespace, field) DO UPDATE SET \
205 kind = excluded.kind, \
206 indices_json = excluded.indices_json, \
207 values_blob = excluded.values_blob, \
208 updated_at = excluded.updated_at"
209 );
210 conn.execute(
211 &sql,
212 rusqlite::params![
213 &id_str,
214 &ns,
215 &kind_str,
216 &field,
217 &indices_json,
218 values_blob,
219 now
220 ],
221 )?;
222 Ok(())
223 })
224 .await
225 }
226
227 async fn insert_sparse_batch(
228 &self,
229 records: Vec<SparseRecord>,
230 ) -> Result<BatchWriteSummary, StorageError> {
231 let table = self.table_name.clone();
232 let attempted = records.len() as u64;
233
234 self.with_writer("sparse_insert_batch", move |conn| {
235 let sql = format!(
236 "INSERT INTO {table} \
237 (subject_id, namespace, kind, field, indices_json, values_blob, updated_at) \
238 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) \
239 ON CONFLICT(subject_id, namespace, field) DO UPDATE SET \
240 indices_json = excluded.indices_json, \
241 values_blob = excluded.values_blob, \
242 updated_at = excluded.updated_at"
243 );
244
245 conn.execute_batch("BEGIN IMMEDIATE")?;
246 let mut affected = 0u64;
247 let mut failed = 0u64;
248 let mut first_error = String::new();
249
250 for record in &records {
251 if record.vector.indices.len() != record.vector.values.len()
253 || record.vector.indices.is_empty()
254 || record.vector.values.iter().any(|v| !v.is_finite())
255 || record.vector.indices.windows(2).any(|w| w[0] >= w[1])
256 {
257 if first_error.is_empty() {
258 first_error =
259 format!("invalid sparse vector for subject {}", record.subject_id);
260 }
261 failed += 1;
262 continue;
263 }
264
265 let indices_json = match serde_json::to_string(&record.vector.indices) {
266 Ok(j) => j,
267 Err(e) => {
268 if first_error.is_empty() {
269 first_error = e.to_string();
270 }
271 failed += 1;
272 continue;
273 }
274 };
275 let values_blob = f32_slice_as_bytes(&record.vector.values);
276 let now = record.updated_at.timestamp();
277 let id_str = record.subject_id.to_string();
278 let kind_str = record.kind.to_string();
279
280 match conn.execute(
281 &sql,
282 rusqlite::params![
283 &id_str,
284 &record.namespace,
285 &kind_str,
286 &record.field,
287 &indices_json,
288 values_blob,
289 now
290 ],
291 ) {
292 Ok(_) => affected += 1,
293 Err(e) => {
294 if first_error.is_empty() {
295 first_error = e.to_string();
296 }
297 failed += 1;
298 }
299 }
300 }
301
302 conn.execute_batch("COMMIT")?;
303 Ok(BatchWriteSummary {
304 attempted,
305 affected,
306 failed,
307 first_error,
308 })
309 })
310 .await
311 }
312
313 async fn delete_sparse_subject(&self, subject_id: Uuid) -> Result<bool, StorageError> {
314 let table = self.table_name.clone();
315 let namespace = self.namespace.clone();
316 let id_str = subject_id.to_string();
317
318 self.with_writer("sparse_delete", move |conn| {
319 let sql = format!("DELETE FROM {table} WHERE subject_id = ?1 AND namespace = ?2");
320 let deleted = conn.execute(&sql, rusqlite::params![&id_str, &namespace])?;
321 Ok(deleted > 0)
322 })
323 .await
324 }
325
326 async fn search_sparse_vectors(
327 &self,
328 request: SparseSearchRequest,
329 ) -> Result<Vec<SparseSearchHit>, StorageError> {
330 let table = self.table_name.clone();
331 let ns = request
332 .namespace
333 .clone()
334 .unwrap_or_else(|| self.namespace.clone());
335 let kind_filter = request.kind.map(|k| k.to_string());
336 let query = request.query;
337 let top_k = request.top_k as usize;
338
339 self.with_reader("sparse_search", move |conn| {
340 let (sql, kind_str_ref) = if let Some(ref kind_str) = kind_filter {
342 (
343 format!(
344 "SELECT subject_id, indices_json, values_blob \
345 FROM {table} WHERE namespace = ?1 AND kind = ?2"
346 ),
347 Some(kind_str.as_str()),
348 )
349 } else {
350 (
351 format!(
352 "SELECT subject_id, indices_json, values_blob \
353 FROM {table} WHERE namespace = ?1"
354 ),
355 None,
356 )
357 };
358
359 let mut stmt = conn.prepare(&sql)?;
360
361 let rows: Vec<rusqlite::Result<(String, String, Vec<u8>)>> =
363 if let Some(kind_str) = kind_str_ref {
364 stmt.query_map(rusqlite::params![&ns, kind_str], |row| {
365 Ok((row.get(0)?, row.get(1)?, row.get(2)?))
366 })?
367 .collect()
368 } else {
369 stmt.query_map(rusqlite::params![&ns], |row| {
370 Ok((row.get(0)?, row.get(1)?, row.get(2)?))
371 })?
372 .collect()
373 };
374
375 let mut scored: Vec<(Uuid, f64)> = Vec::new();
377 for row_result in rows {
378 let (id_str, indices_json, values_blob) = row_result?;
379
380 let subject_id = Uuid::parse_str(&id_str).map_err(|e| {
381 rusqlite::Error::FromSqlConversionFailure(
382 0,
383 rusqlite::types::Type::Text,
384 Box::new(e),
385 )
386 })?;
387
388 let stored_indices: Vec<u32> =
389 serde_json::from_str(&indices_json).unwrap_or_default();
390 let stored_values: Vec<f32> = if values_blob.len() % 4 == 0 {
392 values_blob
393 .chunks_exact(4)
394 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
395 .collect()
396 } else {
397 continue;
398 };
399
400 if stored_indices.len() != stored_values.len() {
401 continue;
402 }
403
404 let score = sparse_dot_product(
406 &query.indices,
407 &query.values,
408 &stored_indices,
409 &stored_values,
410 );
411 scored.push((subject_id, score));
412 }
413
414 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
416 scored.truncate(top_k);
417
418 let hits = scored
419 .into_iter()
420 .enumerate()
421 .map(|(i, (subject_id, score))| SparseSearchHit {
422 subject_id,
423 score: DeterministicScore::from_f64(score),
424 rank: (i + 1) as u32,
425 })
426 .collect();
427
428 Ok(hits)
429 })
430 .await
431 }
432
433 async fn count_sparse_rows(&self) -> Result<u64, StorageError> {
434 let table = self.table_name.clone();
435 let namespace = self.namespace.clone();
436 self.with_reader("sparse_count", move |conn| {
437 let sql = format!("SELECT COUNT(*) FROM {table} WHERE namespace = ?1");
438 let count: i64 =
439 conn.query_row(&sql, rusqlite::params![&namespace], |row| row.get(0))?;
440 Ok(count as u64)
441 })
442 .await
443 }
444}
445
/// Dot product of two sparse vectors given as parallel index/value slices.
///
/// Both index slices are walked in lockstep, advancing whichever side holds the
/// smaller index, so the result is only meaningful when each index slice is sorted in
/// increasing order. Products are accumulated in `f64`.
fn sparse_dot_product(q_idx: &[u32], q_val: &[f32], s_idx: &[u32], s_val: &[f32]) -> f64 {
    use std::cmp::Ordering;

    let (mut q, mut s) = (0usize, 0usize);
    let mut acc = 0.0f64;

    while let (Some(q_key), Some(s_key)) = (q_idx.get(q), s_idx.get(s)) {
        match q_key.cmp(s_key) {
            Ordering::Less => q += 1,
            Ordering::Greater => s += 1,
            Ordering::Equal => {
                // Shared index: this pair contributes to the product.
                acc += f64::from(q_val[q]) * f64::from(s_val[s]);
                q += 1;
                s += 1;
            }
        }
    }

    acc
}
464
465#[async_trait]
466impl SparseStore for SqliteSparseStore {
467 async fn insert_sparse(
468 &self,
469 subject_id: Uuid,
470 kind: SubstrateKind,
471 namespace: &str,
472 field: &str,
473 vector: SparseVector,
474 ) -> Result<(), StorageError> {
475 validate_sparse_vector(&vector, "sparse_insert")?;
476 self.upsert_sparse_vector(subject_id, kind, namespace, field, vector)
477 .await
478 }
479
480 async fn insert_batch(
481 &self,
482 records: Vec<SparseRecord>,
483 ) -> Result<BatchWriteSummary, StorageError> {
484 self.insert_sparse_batch(records).await
485 }
486
487 async fn delete(&self, subject_id: Uuid) -> Result<bool, StorageError> {
488 self.delete_sparse_subject(subject_id).await
489 }
490
491 async fn search_sparse(
492 &self,
493 request: SparseSearchRequest,
494 ) -> Result<Vec<SparseSearchHit>, StorageError> {
495 validate_sparse_vector(&request.query, "sparse_search")?;
496 self.search_sparse_vectors(request).await
497 }
498
499 async fn count(&self) -> Result<u64, StorageError> {
500 self.count_sparse_rows().await
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::{ConnectionPool, PoolConfig};

    /// Builds a store over a fresh pathless pool, with the `sparse_<model_key>` table
    /// already created and the default namespace set to `ns:test`.
    fn make_store(model_key: &str) -> SqliteSparseStore {
        let pool = Arc::new(
            ConnectionPool::new(PoolConfig {
                path: None,
                ..PoolConfig::default()
            })
            .expect("pool"),
        );
        // Keep the writer guard in its own scope so it is released before the store
        // is handed to the test.
        {
            let writer = pool.try_writer().expect("writer");
            ensure_sparse_schema(writer.conn(), model_key).expect("schema");
        }
        SqliteSparseStore::new(pool, false, model_key.to_string(), "ns:test".to_string())
            .expect("store")
    }

    /// Shorthand constructor for a sparse vector.
    fn sv(indices: Vec<u32>, values: Vec<f32>) -> SparseVector {
        SparseVector { indices, values }
    }

    /// Inserts `vector` as the `body` field of an `Entity` subject.
    async fn insert_body(
        store: &SqliteSparseStore,
        id: Uuid,
        namespace: &str,
        vector: SparseVector,
    ) -> Result<(), StorageError> {
        store
            .insert_sparse(id, SubstrateKind::Entity, namespace, "body", vector)
            .await
    }

    /// Asserts that an insert was rejected as invalid input.
    fn assert_invalid_input(result: Result<(), StorageError>) {
        assert!(matches!(result, Err(StorageError::InvalidInput { .. })));
    }

    #[tokio::test]
    async fn insert_and_count() {
        let store = make_store("test_count");
        insert_body(&store, Uuid::new_v4(), "ns:test", sv(vec![0, 2], vec![1.0, 0.5]))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn insert_and_search() {
        let store = make_store("test_search");
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        // Only id1 has weight on index 0, which is the single index the query uses.
        insert_body(&store, id1, "ns:test", sv(vec![0, 1], vec![1.0, 0.0]))
            .await
            .unwrap();
        insert_body(&store, id2, "ns:test", sv(vec![0, 1], vec![0.0, 1.0]))
            .await
            .unwrap();

        let request = SparseSearchRequest {
            query: sv(vec![0], vec![1.0]),
            top_k: 2,
            namespace: Some("ns:test".into()),
            kind: None,
        };
        let hits = store.search_sparse(request).await.unwrap();

        assert!(!hits.is_empty());
        assert_eq!(hits[0].subject_id, id1, "id1 should rank first");
        assert_eq!(hits[0].rank, 1);
    }

    #[tokio::test]
    async fn delete_removes_row() {
        let store = make_store("test_delete");
        let id = Uuid::new_v4();
        insert_body(&store, id, "ns:test", sv(vec![1], vec![1.0]))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        let deleted = store.delete(id).await.unwrap();
        assert!(deleted);
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn mismatched_lengths_rejected() {
        let store = make_store("test_mismatch");
        let vector = SparseVector {
            indices: vec![0, 1],
            values: vec![1.0],
        };
        assert_invalid_input(insert_body(&store, Uuid::new_v4(), "ns:test", vector).await);
    }

    #[tokio::test]
    async fn non_finite_values_rejected() {
        let store = make_store("test_nonfinite");
        let vector = sv(vec![0], vec![f32::NAN]);
        assert_invalid_input(insert_body(&store, Uuid::new_v4(), "ns:test", vector).await);
    }

    #[tokio::test]
    async fn duplicate_indices_rejected() {
        let store = make_store("test_dup_idx");
        let vector = sv(vec![0, 0], vec![1.0, 2.0]);
        assert_invalid_input(insert_body(&store, Uuid::new_v4(), "ns:test", vector).await);
    }

    #[tokio::test]
    async fn empty_vector_rejected() {
        let store = make_store("test_empty");
        let vector = sv(vec![], vec![]);
        assert_invalid_input(insert_body(&store, Uuid::new_v4(), "ns:test", vector).await);
    }

    #[tokio::test]
    async fn namespace_isolation() {
        let store = make_store("test_ns_iso");
        insert_body(&store, Uuid::new_v4(), "ns:a", sv(vec![0], vec![1.0]))
            .await
            .unwrap();

        let request = SparseSearchRequest {
            query: sv(vec![0], vec![1.0]),
            top_k: 5,
            namespace: Some("ns:b".into()),
            kind: None,
        };
        let hits = store.search_sparse(request).await.unwrap();
        assert!(hits.is_empty(), "ns:b should not see ns:a data");
    }

    #[tokio::test]
    async fn insert_batch_happy_path() {
        use chrono::Utc;
        use khive_types::SubstrateKind;

        let store = make_store("test_batch");
        let first = SparseRecord {
            subject_id: Uuid::new_v4(),
            kind: SubstrateKind::Entity,
            namespace: "ns:test".into(),
            field: "body".into(),
            vector: sv(vec![0, 3], vec![0.5, 0.8]),
            updated_at: Utc::now(),
        };
        let second = SparseRecord {
            subject_id: Uuid::new_v4(),
            kind: SubstrateKind::Entity,
            namespace: "ns:test".into(),
            field: "body".into(),
            vector: sv(vec![1], vec![1.0]),
            updated_at: Utc::now(),
        };

        let summary = store.insert_batch(vec![first, second]).await.unwrap();
        assert_eq!(summary.attempted, 2);
        assert_eq!(summary.affected, 2);
        assert_eq!(summary.failed, 0);
        assert_eq!(store.count().await.unwrap(), 2);
    }
}