Skip to main content

erio_context_store/
store.rs

1//! Context store backed by `LanceDB`.
2
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use arrow_array::types::Float32Type;
7use arrow_array::{FixedSizeListArray, RecordBatch, RecordBatchIterator, StringArray};
8use arrow_schema::{DataType, Field, Schema};
9use futures::TryStreamExt;
10use lancedb::query::{ExecutableQuery, QueryBase};
11
12use crate::config::{ContextConfig, SearchResult, StorageStats};
13use crate::error::ContextStoreError;
14
/// Per-process sequence number appended to every generated id.
static ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Produces a document id shaped like `<unix-nanos-hex>-<sequence-hex>`.
///
/// The sequence part makes ids produced within one process distinct even if
/// two calls read the same clock value. If the system clock reports a time
/// before the Unix epoch, the timestamp part falls back to `0` instead of
/// failing.
fn generate_id() -> String {
    let sequence = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    let timestamp = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    format!("{timestamp:x}-{sequence:x}")
}
25
26pub struct ContextStore {
27    table: lancedb::Table,
28    embedding: Arc<dyn erio_embedding::EmbeddingEngine>,
29}
30
31impl ContextStore {
32    pub async fn new(
33        config: ContextConfig,
34        embedding: Arc<dyn erio_embedding::EmbeddingEngine>,
35    ) -> Result<Self, ContextStoreError> {
36        let dims = embedding.dimensions();
37        let schema = Arc::new(Schema::new(vec![
38            Field::new("id", DataType::Utf8, false),
39            Field::new("content", DataType::Utf8, false),
40            Field::new("metadata", DataType::Utf8, false),
41            Field::new(
42                "vector",
43                DataType::FixedSizeList(
44                    Arc::new(Field::new("item", DataType::Float32, true)),
45                    i32::try_from(dims).map_err(|e| {
46                        ContextStoreError::InvalidInput(format!("bad dimensions: {e}"))
47                    })?,
48                ),
49                false,
50            ),
51        ]));
52
53        let db = lancedb::connect(config.path.to_string_lossy().as_ref())
54            .execute()
55            .await
56            .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
57
58        let table = db
59            .create_empty_table("context", schema)
60            .execute()
61            .await
62            .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
63
64        Ok(Self { table, embedding })
65    }
66
67    pub async fn add(
68        &self,
69        content: &str,
70        metadata: serde_json::Value,
71    ) -> Result<String, ContextStoreError> {
72        if content.is_empty() {
73            return Err(ContextStoreError::InvalidInput(
74                "content must not be empty".into(),
75            ));
76        }
77
78        let id = generate_id();
79        let vector = self
80            .embedding
81            .embed(content)
82            .await
83            .map_err(|e| ContextStoreError::Embedding(e.to_string()))?;
84
85        let dims = i32::try_from(self.embedding.dimensions())
86            .map_err(|e| ContextStoreError::InvalidInput(format!("bad dimensions: {e}")))?;
87
88        let schema = self
89            .table
90            .schema()
91            .await
92            .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
93
94        let batch = RecordBatch::try_new(
95            schema,
96            vec![
97                Arc::new(StringArray::from(vec![id.as_str()])),
98                Arc::new(StringArray::from(vec![content])),
99                Arc::new(StringArray::from(vec![metadata.to_string().as_str()])),
100                Arc::new(
101                    FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
102                        vec![Some(vector.into_iter().map(Some).collect::<Vec<_>>())],
103                        dims,
104                    ),
105                ),
106            ],
107        )
108        .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
109
110        let batches = RecordBatchIterator::new(
111            vec![Ok(batch)],
112            self.table
113                .schema()
114                .await
115                .map_err(|e| ContextStoreError::Storage(e.to_string()))?,
116        );
117
118        self.table
119            .add(batches)
120            .execute()
121            .await
122            .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
123
124        Ok(id)
125    }
126
127    pub async fn search(
128        &self,
129        query: &str,
130        k: usize,
131        filter: Option<String>,
132    ) -> Result<Vec<SearchResult>, ContextStoreError> {
133        let count = self
134            .table
135            .count_rows(None)
136            .await
137            .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
138        if count == 0 {
139            return Ok(vec![]);
140        }
141
142        let vector = self
143            .embedding
144            .embed(query)
145            .await
146            .map_err(|e| ContextStoreError::Embedding(e.to_string()))?;
147
148        let mut builder = self
149            .table
150            .query()
151            .nearest_to(vector)
152            .map_err(|e| ContextStoreError::Storage(e.to_string()))?
153            .limit(k);
154
155        if let Some(f) = filter {
156            builder = builder.only_if(f);
157        }
158
159        let batches: Vec<RecordBatch> = builder
160            .execute()
161            .await
162            .map_err(|e| ContextStoreError::Storage(e.to_string()))?
163            .try_collect()
164            .await
165            .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
166
167        let mut results = Vec::new();
168        for batch in &batches {
169            let content_col = batch
170                .column_by_name("content")
171                .and_then(|c| c.as_any().downcast_ref::<StringArray>())
172                .ok_or_else(|| ContextStoreError::Storage("missing content column".into()))?;
173
174            let metadata_col = batch
175                .column_by_name("metadata")
176                .and_then(|c| c.as_any().downcast_ref::<StringArray>())
177                .ok_or_else(|| ContextStoreError::Storage("missing metadata column".into()))?;
178
179            let distance_col = batch
180                .column_by_name("_distance")
181                .and_then(|c| c.as_any().downcast_ref::<arrow_array::Float32Array>())
182                .ok_or_else(|| ContextStoreError::Storage("missing _distance column".into()))?;
183
184            for i in 0..batch.num_rows() {
185                let meta_str = metadata_col.value(i);
186                let metadata: serde_json::Value =
187                    serde_json::from_str(meta_str).unwrap_or_default();
188
189                results.push(SearchResult {
190                    content: content_col.value(i).to_string(),
191                    score: 1.0 - distance_col.value(i),
192                    metadata,
193                });
194            }
195        }
196
197        Ok(results)
198    }
199
200    pub async fn stats(&self) -> Result<StorageStats, ContextStoreError> {
201        let count = self
202            .table
203            .count_rows(None)
204            .await
205            .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
206        Ok(StorageStats {
207            document_count: count,
208        })
209    }
210
211    pub async fn delete(&self, id: &str) -> Result<(), ContextStoreError> {
212        let escaped = id.replace('\'', "''");
213        let filter = format!("id = '{escaped}'");
214
215        let count_before = self
216            .table
217            .count_rows(Some(filter.clone()))
218            .await
219            .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
220
221        if count_before == 0 {
222            return Err(ContextStoreError::NotFound(id.to_string()));
223        }
224
225        self.table
226            .delete(&filter)
227            .await
228            .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
229
230        Ok(())
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ContextConfig;
    use crate::config::HnswConfig;
    use crate::error::ContextStoreError;
    use erio_embedding::EmbeddingError;
    use std::sync::Arc;

    /// Deterministic 3-dimensional embedding: the first three bytes of the text
    /// as `f32`s (zero-padded), so texts sharing a 3-byte prefix embed identically.
    struct FakeEmbedding;

    #[async_trait::async_trait]
    impl erio_embedding::EmbeddingEngine for FakeEmbedding {
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "fake"
        }

        fn dimensions(&self) -> usize {
            3
        }

        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            let bytes = text.as_bytes();
            Ok(vec![
                f32::from(*bytes.first().unwrap_or(&0)),
                f32::from(*bytes.get(1).unwrap_or(&0)),
                f32::from(*bytes.get(2).unwrap_or(&0)),
            ])
        }

        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            let mut results = Vec::with_capacity(texts.len());
            for text in texts {
                results.push(self.embed(text).await?);
            }
            Ok(results)
        }
    }

    fn test_config(dir: &tempfile::TempDir) -> ContextConfig {
        ContextConfig {
            path: dir.path().to_path_buf(),
            index: HnswConfig::default(),
        }
    }

    fn fake_embedding() -> Arc<dyn erio_embedding::EmbeddingEngine> {
        Arc::new(FakeEmbedding)
    }

    /// Opens a store in `dir` backed by [`FakeEmbedding`].
    async fn new_store(dir: &tempfile::TempDir) -> ContextStore {
        ContextStore::new(test_config(dir), fake_embedding())
            .await
            .unwrap()
    }

    // === Construction ===

    #[tokio::test]
    async fn new_creates_empty_store() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        let stats = store.stats().await.unwrap();
        assert_eq!(stats.document_count, 0);
    }

    // === Add ===

    #[tokio::test]
    async fn add_returns_document_id() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        let id = store
            .add("hello world", serde_json::json!({}))
            .await
            .unwrap();
        assert!(!id.is_empty());

        let other = store
            .add("hello world", serde_json::json!({}))
            .await
            .unwrap();
        assert_ne!(id, other, "identical content must still get distinct ids");
    }

    #[tokio::test]
    async fn add_increments_document_count() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        store.add("first", serde_json::json!({})).await.unwrap();
        store.add("second", serde_json::json!({})).await.unwrap();
        let stats = store.stats().await.unwrap();
        assert_eq!(stats.document_count, 2);
    }

    #[tokio::test]
    async fn add_rejects_empty_content() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        let result = store.add("", serde_json::json!({})).await;
        assert!(matches!(result, Err(ContextStoreError::InvalidInput(_))));
        assert_eq!(store.stats().await.unwrap().document_count, 0);
    }

    // === Search ===

    #[tokio::test]
    async fn search_returns_matching_results() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        store
            .add("hello world", serde_json::json!({}))
            .await
            .unwrap();
        store
            .add("hello there", serde_json::json!({}))
            .await
            .unwrap();
        store.add("goodbye", serde_json::json!({})).await.unwrap();

        let results = store.search("hello", 2, None).await.unwrap();
        assert_eq!(results.len(), 2);
        // Both "hello …" documents embed identically to the query; "goodbye" does not.
        for r in &results {
            assert!(
                r.content.starts_with("hello"),
                "unexpected hit: {}",
                r.content
            );
        }
    }

    #[tokio::test]
    async fn search_returns_results_with_scores() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        store.add("hello", serde_json::json!({})).await.unwrap();

        let results = store.search("hello", 1, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].content, "hello");
        // Identical embeddings are at (near-)zero distance, so the score is ~1.
        assert!(
            (results[0].score - 1.0).abs() < 1e-4,
            "score was {}",
            results[0].score
        );
    }

    #[tokio::test]
    async fn search_respects_k_limit() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        for i in 0..10 {
            store
                .add(&format!("document {i}"), serde_json::json!({}))
                .await
                .unwrap();
        }

        let results = store.search("document", 3, None).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn search_on_empty_store_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        let results = store.search("anything", 5, None).await.unwrap();
        assert!(results.is_empty());
    }

    // === Delete ===

    #[tokio::test]
    async fn delete_removes_document() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        let id = store
            .add("to be deleted", serde_json::json!({}))
            .await
            .unwrap();

        store.delete(&id).await.unwrap();

        let stats = store.stats().await.unwrap();
        assert_eq!(stats.document_count, 0);

        // A second delete of the same id must report it as gone.
        let again = store.delete(&id).await;
        assert!(matches!(again, Err(ContextStoreError::NotFound(_))));
    }

    #[tokio::test]
    async fn delete_only_removes_target_document() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        let doomed = store.add("remove me", serde_json::json!({})).await.unwrap();
        store.add("keep me", serde_json::json!({})).await.unwrap();

        store.delete(&doomed).await.unwrap();

        assert_eq!(store.stats().await.unwrap().document_count, 1);
    }

    #[tokio::test]
    async fn delete_nonexistent_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        let result = store.delete("nonexistent_id").await;
        assert!(matches!(result, Err(ContextStoreError::NotFound(_))));
    }

    #[tokio::test]
    async fn delete_escapes_quotes_in_id() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        store.add("keep me", serde_json::json!({})).await.unwrap();

        // Unescaped, this id would turn the predicate into `... OR '1'='1'` and
        // match every row; escaped, it is just an id that does not exist.
        let result = store.delete("x' OR '1'='1").await;
        assert!(matches!(result, Err(ContextStoreError::NotFound(_))));
        assert_eq!(store.stats().await.unwrap().document_count, 1);
    }

    // === Metadata Filtering ===

    #[tokio::test]
    async fn search_with_metadata_filter_narrows_results() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        store
            .add("hello alpha", serde_json::json!({"source": "docs"}))
            .await
            .unwrap();
        store
            .add("hello beta", serde_json::json!({"source": "code"}))
            .await
            .unwrap();
        store
            .add("hello gamma", serde_json::json!({"source": "docs"}))
            .await
            .unwrap();

        let filter = "metadata LIKE '%\"source\":\"docs\"%'".to_string();
        let results = store.search("hello", 10, Some(filter)).await.unwrap();
        assert_eq!(results.len(), 2);
        for r in &results {
            assert_eq!(r.metadata["source"], "docs");
        }
    }

    #[tokio::test]
    async fn search_filter_returns_empty_when_nothing_matches() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        store
            .add("hello", serde_json::json!({"source": "docs"}))
            .await
            .unwrap();

        let filter = "metadata LIKE '%\"source\":\"nonexistent\"%'".to_string();
        let results = store.search("hello", 10, Some(filter)).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_preserves_metadata_in_results() {
        let dir = tempfile::tempdir().unwrap();
        let store = new_store(&dir).await;
        let meta = serde_json::json!({"source": "test", "priority": 5});
        store.add("hello world", meta.clone()).await.unwrap();

        let results = store.search("hello", 1, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, meta);
        assert_eq!(results[0].metadata["source"], "test");
        assert_eq!(results[0].metadata["priority"], 5);
    }
}