1use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use arrow_array::types::Float32Type;
7use arrow_array::{FixedSizeListArray, RecordBatch, RecordBatchIterator, StringArray};
8use arrow_schema::{DataType, Field, Schema};
9use futures::TryStreamExt;
10use lancedb::query::{ExecutableQuery, QueryBase};
11
12use crate::config::{ContextConfig, SearchResult, StorageStats};
13use crate::error::ContextStoreError;
14
/// Process-wide sequence number mixed into every generated id.
static ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Builds a document id of the form `<nanos-hex>-<counter-hex>`.
///
/// The first part is the number of nanoseconds since the Unix epoch (or `0`
/// when the system clock reports a time before the epoch); the second part is
/// the value of [`ID_COUNTER`] before it is incremented by this call.
fn generate_id() -> String {
    let sequence = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_nanos());
    format!("{:x}-{:x}", since_epoch, sequence)
}
25
26pub struct ContextStore {
27 table: lancedb::Table,
28 embedding: Arc<dyn erio_embedding::EmbeddingEngine>,
29}
30
31impl ContextStore {
32 pub async fn new(
33 config: ContextConfig,
34 embedding: Arc<dyn erio_embedding::EmbeddingEngine>,
35 ) -> Result<Self, ContextStoreError> {
36 let dims = embedding.dimensions();
37 let schema = Arc::new(Schema::new(vec![
38 Field::new("id", DataType::Utf8, false),
39 Field::new("content", DataType::Utf8, false),
40 Field::new("metadata", DataType::Utf8, false),
41 Field::new(
42 "vector",
43 DataType::FixedSizeList(
44 Arc::new(Field::new("item", DataType::Float32, true)),
45 i32::try_from(dims).map_err(|e| {
46 ContextStoreError::InvalidInput(format!("bad dimensions: {e}"))
47 })?,
48 ),
49 false,
50 ),
51 ]));
52
53 let db = lancedb::connect(config.path.to_string_lossy().as_ref())
54 .execute()
55 .await
56 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
57
58 let table = db
59 .create_empty_table("context", schema)
60 .execute()
61 .await
62 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
63
64 Ok(Self { table, embedding })
65 }
66
67 pub async fn add(
68 &self,
69 content: &str,
70 metadata: serde_json::Value,
71 ) -> Result<String, ContextStoreError> {
72 if content.is_empty() {
73 return Err(ContextStoreError::InvalidInput(
74 "content must not be empty".into(),
75 ));
76 }
77
78 let id = generate_id();
79 let vector = self
80 .embedding
81 .embed(content)
82 .await
83 .map_err(|e| ContextStoreError::Embedding(e.to_string()))?;
84
85 let dims = i32::try_from(self.embedding.dimensions())
86 .map_err(|e| ContextStoreError::InvalidInput(format!("bad dimensions: {e}")))?;
87
88 let schema = self
89 .table
90 .schema()
91 .await
92 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
93
94 let batch = RecordBatch::try_new(
95 schema,
96 vec![
97 Arc::new(StringArray::from(vec![id.as_str()])),
98 Arc::new(StringArray::from(vec![content])),
99 Arc::new(StringArray::from(vec![metadata.to_string().as_str()])),
100 Arc::new(
101 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
102 vec![Some(vector.into_iter().map(Some).collect::<Vec<_>>())],
103 dims,
104 ),
105 ),
106 ],
107 )
108 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
109
110 let batches = RecordBatchIterator::new(
111 vec![Ok(batch)],
112 self.table
113 .schema()
114 .await
115 .map_err(|e| ContextStoreError::Storage(e.to_string()))?,
116 );
117
118 self.table
119 .add(batches)
120 .execute()
121 .await
122 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
123
124 Ok(id)
125 }
126
127 pub async fn search(
128 &self,
129 query: &str,
130 k: usize,
131 filter: Option<String>,
132 ) -> Result<Vec<SearchResult>, ContextStoreError> {
133 let count = self
134 .table
135 .count_rows(None)
136 .await
137 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
138 if count == 0 {
139 return Ok(vec![]);
140 }
141
142 let vector = self
143 .embedding
144 .embed(query)
145 .await
146 .map_err(|e| ContextStoreError::Embedding(e.to_string()))?;
147
148 let mut builder = self
149 .table
150 .query()
151 .nearest_to(vector)
152 .map_err(|e| ContextStoreError::Storage(e.to_string()))?
153 .limit(k);
154
155 if let Some(f) = filter {
156 builder = builder.only_if(f);
157 }
158
159 let batches: Vec<RecordBatch> = builder
160 .execute()
161 .await
162 .map_err(|e| ContextStoreError::Storage(e.to_string()))?
163 .try_collect()
164 .await
165 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
166
167 let mut results = Vec::new();
168 for batch in &batches {
169 let content_col = batch
170 .column_by_name("content")
171 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
172 .ok_or_else(|| ContextStoreError::Storage("missing content column".into()))?;
173
174 let metadata_col = batch
175 .column_by_name("metadata")
176 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
177 .ok_or_else(|| ContextStoreError::Storage("missing metadata column".into()))?;
178
179 let distance_col = batch
180 .column_by_name("_distance")
181 .and_then(|c| c.as_any().downcast_ref::<arrow_array::Float32Array>())
182 .ok_or_else(|| ContextStoreError::Storage("missing _distance column".into()))?;
183
184 for i in 0..batch.num_rows() {
185 let meta_str = metadata_col.value(i);
186 let metadata: serde_json::Value =
187 serde_json::from_str(meta_str).unwrap_or_default();
188
189 results.push(SearchResult {
190 content: content_col.value(i).to_string(),
191 score: 1.0 - distance_col.value(i),
192 metadata,
193 });
194 }
195 }
196
197 Ok(results)
198 }
199
200 pub async fn stats(&self) -> Result<StorageStats, ContextStoreError> {
201 let count = self
202 .table
203 .count_rows(None)
204 .await
205 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
206 Ok(StorageStats {
207 document_count: count,
208 })
209 }
210
211 pub async fn delete(&self, id: &str) -> Result<(), ContextStoreError> {
212 let escaped = id.replace('\'', "''");
213 let filter = format!("id = '{escaped}'");
214
215 let count_before = self
216 .table
217 .count_rows(Some(filter.clone()))
218 .await
219 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
220
221 if count_before == 0 {
222 return Err(ContextStoreError::NotFound(id.to_string()));
223 }
224
225 self.table
226 .delete(&filter)
227 .await
228 .map_err(|e| ContextStoreError::Storage(e.to_string()))?;
229
230 Ok(())
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ContextConfig;
    use crate::config::HnswConfig;
    use erio_embedding::EmbeddingError;
    use std::sync::Arc;

    /// Deterministic three-dimensional engine: each component is one of the
    /// first three bytes of the text, with `0` standing in for missing bytes.
    struct FakeEmbedding;

    #[async_trait::async_trait]
    impl erio_embedding::EmbeddingEngine for FakeEmbedding {
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "fake"
        }

        fn dimensions(&self) -> usize {
            3
        }

        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            let raw = text.as_bytes();
            let component = |idx: usize| f32::from(raw.get(idx).copied().unwrap_or(0));
            Ok(vec![component(0), component(1), component(2)])
        }

        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            let mut vectors = Vec::with_capacity(texts.len());
            for &text in texts {
                let vector = self.embed(text).await?;
                vectors.push(vector);
            }
            Ok(vectors)
        }
    }

    /// Store configuration rooted at the given temporary directory.
    fn test_config(dir: &tempfile::TempDir) -> ContextConfig {
        ContextConfig {
            path: dir.path().to_path_buf(),
            index: HnswConfig::default(),
        }
    }

    fn fake_embedding() -> Arc<dyn erio_embedding::EmbeddingEngine> {
        Arc::new(FakeEmbedding)
    }

    /// Opens a store in a fresh temporary directory. The directory guard is
    /// returned alongside the store so the files live as long as the test.
    async fn open_store() -> (tempfile::TempDir, ContextStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = ContextStore::new(test_config(&dir), fake_embedding())
            .await
            .unwrap();
        (dir, store)
    }

    /// Empty JSON object used where a test does not care about metadata.
    fn no_meta() -> serde_json::Value {
        serde_json::json!({})
    }

    #[tokio::test]
    async fn new_creates_empty_store() {
        let (_dir, store) = open_store().await;

        let stats = store.stats().await.unwrap();
        assert_eq!(stats.document_count, 0);
    }

    #[tokio::test]
    async fn add_returns_document_id() {
        let (_dir, store) = open_store().await;

        let id = store.add("hello world", no_meta()).await.unwrap();
        assert!(!id.is_empty());
    }

    #[tokio::test]
    async fn add_increments_document_count() {
        let (_dir, store) = open_store().await;

        store.add("first", no_meta()).await.unwrap();
        store.add("second", no_meta()).await.unwrap();

        let stats = store.stats().await.unwrap();
        assert_eq!(stats.document_count, 2);
    }

    #[tokio::test]
    async fn add_rejects_empty_content() {
        let (_dir, store) = open_store().await;

        let outcome = store.add("", no_meta()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn search_returns_matching_results() {
        let (_dir, store) = open_store().await;

        store.add("hello world", no_meta()).await.unwrap();
        store.add("hello there", no_meta()).await.unwrap();
        store.add("goodbye", no_meta()).await.unwrap();

        let hits = store.search("hello", 2, None).await.unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[tokio::test]
    async fn search_returns_results_with_scores() {
        let (_dir, store) = open_store().await;

        store.add("hello", no_meta()).await.unwrap();

        let hits = store.search("hello", 1, None).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits[0].score >= 0.0);
    }

    #[tokio::test]
    async fn search_respects_k_limit() {
        let (_dir, store) = open_store().await;

        for i in 0..10 {
            let text = format!("document {i}");
            store.add(&text, no_meta()).await.unwrap();
        }

        let hits = store.search("document", 3, None).await.unwrap();
        assert_eq!(hits.len(), 3);
    }

    #[tokio::test]
    async fn search_on_empty_store_returns_empty() {
        let (_dir, store) = open_store().await;

        let hits = store.search("anything", 5, None).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn delete_removes_document() {
        let (_dir, store) = open_store().await;

        let id = store.add("to be deleted", no_meta()).await.unwrap();
        store.delete(&id).await.unwrap();

        let stats = store.stats().await.unwrap();
        assert_eq!(stats.document_count, 0);
    }

    #[tokio::test]
    async fn delete_nonexistent_returns_error() {
        let (_dir, store) = open_store().await;

        let outcome = store.delete("nonexistent_id").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn search_with_metadata_filter_narrows_results() {
        let (_dir, store) = open_store().await;

        store
            .add("hello alpha", serde_json::json!({"source": "docs"}))
            .await
            .unwrap();
        store
            .add("hello beta", serde_json::json!({"source": "code"}))
            .await
            .unwrap();
        store
            .add("hello gamma", serde_json::json!({"source": "docs"}))
            .await
            .unwrap();

        let filter = "metadata LIKE '%\"source\":\"docs\"%'".to_string();
        let hits = store.search("hello", 10, Some(filter)).await.unwrap();
        assert_eq!(hits.len(), 2);
        for hit in &hits {
            assert_eq!(hit.metadata["source"], "docs");
        }
    }

    #[tokio::test]
    async fn search_filter_returns_empty_when_nothing_matches() {
        let (_dir, store) = open_store().await;

        store
            .add("hello", serde_json::json!({"source": "docs"}))
            .await
            .unwrap();

        let filter = "metadata LIKE '%\"source\":\"nonexistent\"%'".to_string();
        let hits = store.search("hello", 10, Some(filter)).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn search_preserves_metadata_in_results() {
        let (_dir, store) = open_store().await;

        let meta = serde_json::json!({"source": "test", "priority": 5});
        store.add("hello world", meta.clone()).await.unwrap();

        let hits = store.search("hello", 1, None).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].metadata["source"], "test");
        assert_eq!(hits[0].metadata["priority"], 5);
    }
}