// agent_io/memory/backends/lancedb.rs
//! LanceDB-based memory store for persistent vector storage.

use arrow::array::{Array, FixedSizeListArray, Float32Array, Int64Array, StringArray, UInt32Array};
use arrow::record_batch::RecordBatch;
use arrow_array::RecordBatchIterator;
use arrow_array::types::Float32Type;
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase, Select};
use lancedb::{DistanceType, Table, connect};
use std::path::PathBuf;
use std::sync::Arc;

use crate::Result;
use crate::memory::entry::{MemoryEntry, MemoryType};
use crate::memory::store::MemoryStore;

/// Name of the LanceDB table that holds every memory row.
const TABLE_NAME: &str = "memories";

/// Fixed length of the `embedding` column (the OpenAI embedding dimension).
///
/// Embeddings of any other length do not fit the table schema.
const EMBEDDING_DIM: usize = 1_536;

22/// LanceDB memory store for persistent vector storage with FTS support
23pub struct LanceDbStore {
24    table: Arc<Table>,
25}
26
27impl LanceDbStore {
28    /// Create a new LanceDB store with an in-memory database
29    pub async fn new() -> Result<Self> {
30        Self::open_uri("memory://agent_io_memories").await
31    }
32
33    /// Create a new LanceDB store with a file database
34    pub async fn open<P: Into<PathBuf>>(path: P) -> Result<Self> {
35        let path = path.into();
36
37        // Ensure parent directory exists
38        if let Some(parent) = path.parent() {
39            std::fs::create_dir_all(parent)
40                .map_err(|e| crate::Error::Agent(format!("Failed to create directory: {}", e)))?;
41        }
42
43        let uri = path.to_string_lossy().to_string();
44        Self::open_uri(&uri).await
45    }
46
47    async fn open_uri(uri: &str) -> Result<Self> {
48        let db = connect(uri)
49            .execute()
50            .await
51            .map_err(|e| crate::Error::Agent(format!("Failed to connect to LanceDB: {}", e)))?;
52
53        let table_names = db
54            .table_names()
55            .execute()
56            .await
57            .map_err(|e| crate::Error::Agent(format!("Failed to list tables: {}", e)))?;
58
59        let table = if table_names.contains(&TABLE_NAME.to_string()) {
60            db.open_table(TABLE_NAME)
61                .execute()
62                .await
63                .map_err(|e| crate::Error::Agent(format!("Failed to open table: {}", e)))?
64        } else {
65            // Create an empty table with schema
66            let schema = Self::schema();
67            db.create_empty_table(TABLE_NAME, schema)
68                .execute()
69                .await
70                .map_err(|e| crate::Error::Agent(format!("Failed to create table: {}", e)))?
71        };
72
73        Ok(Self {
74            table: Arc::new(table),
75        })
76    }
77
78    /// Get the table schema
79    fn schema() -> Arc<Schema> {
80        Arc::new(Schema::new(vec![
81            Field::new("id", DataType::Utf8, false),
82            Field::new("content", DataType::Utf8, false),
83            Field::new(
84                "embedding",
85                DataType::FixedSizeList(
86                    Arc::new(Field::new("item", DataType::Float32, true)),
87                    EMBEDDING_DIM as i32,
88                ),
89                true,
90            ),
91            Field::new("memory_type", DataType::Utf8, false),
92            Field::new("metadata", DataType::Utf8, true),
93            Field::new("created_at", DataType::Int64, false),
94            Field::new("last_accessed", DataType::Int64, true),
95            Field::new("importance", DataType::Float32, false),
96            Field::new("access_count", DataType::UInt32, false),
97        ]))
98    }
99
100    /// Convert memory type to string
101    fn memory_type_to_string(t: &MemoryType) -> &'static str {
102        match t {
103            MemoryType::ShortTerm => "short_term",
104            MemoryType::LongTerm => "long_term",
105            MemoryType::Episodic => "episodic",
106            MemoryType::Semantic => "semantic",
107        }
108    }
109
110    /// Convert string to memory type
111    fn string_to_memory_type(s: &str) -> MemoryType {
112        match s {
113            "long_term" => MemoryType::LongTerm,
114            "episodic" => MemoryType::Episodic,
115            "semantic" => MemoryType::Semantic,
116            _ => MemoryType::ShortTerm,
117        }
118    }
119
120    /// Convert MemoryEntry to RecordBatch
121    fn entry_to_batch(entry: &MemoryEntry) -> Result<RecordBatch> {
122        let schema = Self::schema();
123
124        let id_array = StringArray::from(vec![entry.id.clone()]);
125        let content_array = StringArray::from(vec![entry.content.clone()]);
126
127        // Handle embedding as FixedSizeList
128        let embedding_array = if let Some(ref embedding) = entry.embedding {
129            FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
130                vec![Some(embedding.iter().map(|&v| Some(v)).collect::<Vec<_>>())],
131                EMBEDDING_DIM as i32,
132            )
133        } else {
134            // Create a null array with the correct type
135            FixedSizeListArray::from_iter_primitive::<Float32Type, Option<Option<f32>>, _>(
136                vec![None],
137                EMBEDDING_DIM as i32,
138            )
139        };
140
141        let memory_type_array =
142            StringArray::from(vec![Self::memory_type_to_string(&entry.memory_type)]);
143
144        let metadata_array = if entry.metadata.is_empty() {
145            StringArray::from(vec![None::<String>])
146        } else {
147            StringArray::from(vec![Some(
148                serde_json::to_string(&entry.metadata).unwrap_or_default(),
149            )])
150        };
151
152        let created_at_array = Int64Array::from(vec![entry.created_at.timestamp()]);
153        let last_accessed_array =
154            Int64Array::from(vec![entry.last_accessed.map(|la| la.timestamp())]);
155        let importance_array = Float32Array::from(vec![entry.importance]);
156        let access_count_array = UInt32Array::from(vec![entry.access_count]);
157
158        RecordBatch::try_new(
159            schema,
160            vec![
161                Arc::new(id_array),
162                Arc::new(content_array),
163                Arc::new(embedding_array),
164                Arc::new(memory_type_array),
165                Arc::new(metadata_array),
166                Arc::new(created_at_array),
167                Arc::new(last_accessed_array),
168                Arc::new(importance_array),
169                Arc::new(access_count_array),
170            ],
171        )
172        .map_err(|e| crate::Error::Agent(format!("Failed to create record batch: {}", e)))
173    }
174
175    fn parse_batch_row(batch: &RecordBatch, i: usize) -> Result<MemoryEntry> {
176        let id = batch
177            .column(0)
178            .as_any()
179            .downcast_ref::<StringArray>()
180            .map(|arr| arr.value(i).to_string())
181            .unwrap_or_default();
182
183        let content = batch
184            .column(1)
185            .as_any()
186            .downcast_ref::<StringArray>()
187            .map(|arr| arr.value(i).to_string())
188            .unwrap_or_default();
189
190        let embedding = batch
191            .column(2)
192            .as_any()
193            .downcast_ref::<FixedSizeListArray>()
194            .and_then(|arr| {
195                if arr.is_null(i) {
196                    return None;
197                }
198                let values = arr.value(i);
199                values
200                    .as_any()
201                    .downcast_ref::<Float32Array>()
202                    .map(|v| v.values().to_vec())
203            });
204
205        let memory_type = batch
206            .column(3)
207            .as_any()
208            .downcast_ref::<StringArray>()
209            .map(|arr| arr.value(i).to_string())
210            .unwrap_or_default();
211
212        let metadata = batch
213            .column(4)
214            .as_any()
215            .downcast_ref::<StringArray>()
216            .and_then(|arr| {
217                if arr.is_null(i) {
218                    None
219                } else {
220                    Some(arr.value(i).to_string())
221                }
222            });
223
224        let created_at = batch
225            .column(5)
226            .as_any()
227            .downcast_ref::<Int64Array>()
228            .map(|arr| arr.value(i))
229            .unwrap_or(0);
230
231        let last_accessed = batch
232            .column(6)
233            .as_any()
234            .downcast_ref::<Int64Array>()
235            .and_then(|arr| {
236                if arr.is_null(i) {
237                    None
238                } else {
239                    Some(arr.value(i))
240                }
241            });
242
243        let importance = batch
244            .column(7)
245            .as_any()
246            .downcast_ref::<Float32Array>()
247            .map(|arr| arr.value(i))
248            .unwrap_or(0.5);
249
250        let access_count = batch
251            .column(8)
252            .as_any()
253            .downcast_ref::<UInt32Array>()
254            .map(|arr| arr.value(i))
255            .unwrap_or(0);
256
257        let metadata_map: std::collections::HashMap<String, serde_json::Value> = metadata
258            .as_ref()
259            .and_then(|s| serde_json::from_str(s).ok())
260            .unwrap_or_default();
261
262        Ok(MemoryEntry {
263            id,
264            content,
265            embedding,
266            memory_type: Self::string_to_memory_type(&memory_type),
267            metadata: metadata_map,
268            created_at: chrono::DateTime::from_timestamp(created_at, 0)
269                .map(|dt| dt.with_timezone(&chrono::Utc))
270                .unwrap_or_else(chrono::Utc::now),
271            last_accessed: last_accessed
272                .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
273                .map(|dt| dt.with_timezone(&chrono::Utc)),
274            importance,
275            access_count,
276        })
277    }
278}
279
280#[async_trait]
281impl MemoryStore for LanceDbStore {
282    async fn add(&self, entry: MemoryEntry) -> Result<String> {
283        let id = entry.id.clone();
284        let batch = Self::entry_to_batch(&entry)?;
285
286        self.table
287            .add(RecordBatchIterator::new(
288                vec![Ok(batch.clone())],
289                batch.schema(),
290            ))
291            .execute()
292            .await
293            .map_err(|e| crate::Error::Agent(format!("Failed to add memory: {}", e)))?;
294
295        Ok(id)
296    }
297
298    async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
299        let batches = self
300            .table
301            .query()
302            .only_if(format!("id = '{}'", id.replace('\'', "''")))
303            .execute()
304            .await
305            .map_err(|e| crate::Error::Agent(format!("Failed to query: {}", e)))?
306            .try_collect::<Vec<_>>()
307            .await
308            .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
309
310        if let Some(batch) = batches.first()
311            && batch.num_rows() > 0
312        {
313            return Ok(Some(Self::parse_batch_row(batch, 0)?));
314        }
315
316        Ok(None)
317    }
318
319    async fn delete(&self, id: &str) -> Result<()> {
320        self.table
321            .delete(&format!("id = '{}'", id.replace('\'', "''")))
322            .await
323            .map_err(|e| crate::Error::Agent(format!("Failed to delete memory: {}", e)))?;
324
325        Ok(())
326    }
327
328    async fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
329        let batches = self
330            .table
331            .query()
332            .only_if(format!("content LIKE '%{}%'", query.replace('\'', "''")))
333            .limit(limit)
334            .execute()
335            .await
336            .map_err(|e| crate::Error::Agent(format!("Failed to search: {}", e)))?
337            .try_collect::<Vec<_>>()
338            .await
339            .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
340
341        let mut entries = Vec::new();
342        for batch in batches {
343            for i in 0..batch.num_rows() {
344                entries.push(Self::parse_batch_row(&batch, i)?);
345            }
346        }
347
348        Ok(entries)
349    }
350
351    async fn search_by_embedding(
352        &self,
353        embedding: &[f32],
354        limit: usize,
355        threshold: f32,
356    ) -> Result<Vec<MemoryEntry>> {
357        let batches = self
358            .table
359            .query()
360            .limit(limit * 2) // Fetch more to filter by threshold
361            .nearest_to(embedding)
362            .map_err(|e| crate::Error::Agent(format!("Failed to create vector search: {}", e)))?
363            .execute()
364            .await
365            .map_err(|e| crate::Error::Agent(format!("Failed to search by embedding: {}", e)))?
366            .try_collect::<Vec<_>>()
367            .await
368            .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
369
370        let mut entries_with_score = Vec::new();
371        for batch in batches {
372            for i in 0..batch.num_rows() {
373                let entry = Self::parse_batch_row(&batch, i)?;
374
375                // Get similarity from _distance column if present
376                let similarity = if let Some(distance_col) = batch.column_by_name("_distance") {
377                    let dist = distance_col
378                        .as_any()
379                        .downcast_ref::<Float32Array>()
380                        .map(|arr| arr.value(i))
381                        .unwrap_or(1.0);
382                    1.0 - dist // Convert distance to similarity
383                } else if let Some(ref entry_embedding) = entry.embedding {
384                    cosine_similarity(embedding, entry_embedding)
385                } else {
386                    0.0
387                };
388
389                if similarity >= threshold {
390                    entries_with_score.push((entry, similarity));
391                }
392            }
393        }
394
395        // Sort by similarity descending
396        entries_with_score
397            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
398        entries_with_score.truncate(limit);
399
400        Ok(entries_with_score.into_iter().map(|(e, _)| e).collect())
401    }
402
403    async fn ids(&self) -> Result<Vec<String>> {
404        let batches = self
405            .table
406            .query()
407            .select(Select::columns(&["id"]))
408            .execute()
409            .await
410            .map_err(|e| crate::Error::Agent(format!("Failed to query ids: {}", e)))?
411            .try_collect::<Vec<_>>()
412            .await
413            .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
414
415        let mut ids = Vec::new();
416        for batch in batches {
417            if let Some(id_array) = batch
418                .column_by_name("id")
419                .and_then(|col| col.as_any().downcast_ref::<StringArray>())
420            {
421                for i in 0..id_array.len() {
422                    ids.push(id_array.value(i).to_string());
423                }
424            }
425        }
426
427        Ok(ids)
428    }
429
430    async fn count(&self) -> Result<usize> {
431        let batches = self
432            .table
433            .query()
434            .select(Select::columns(&["id"]))
435            .execute()
436            .await
437            .map_err(|e| crate::Error::Agent(format!("Failed to count: {}", e)))?
438            .try_collect::<Vec<_>>()
439            .await
440            .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
441
442        let mut count = 0;
443        for batch in batches {
444            count += batch.num_rows();
445        }
446
447        Ok(count)
448    }
449
450    async fn update(&self, entry: MemoryEntry) -> Result<()> {
451        // LanceDB doesn't have a native update, so we delete and re-add
452        self.delete(&entry.id).await?;
453        self.add(entry).await?;
454        Ok(())
455    }
456
457    async fn clear(&self) -> Result<()> {
458        self.table
459            .delete("true")
460            .await
461            .map_err(|e| crate::Error::Agent(format!("Failed to clear memories: {}", e)))?;
462
463        Ok(())
464    }
465}
466
/// Cosine similarity of `a` and `b`: their dot product divided by the product of
/// their Euclidean norms.
///
/// Returns `0.0` instead of dividing when the slices differ in length, are empty,
/// or either one has zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// An added entry can be read back by the id `add` returns.
    #[tokio::test]
    async fn test_lancedb_store_basic() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        let entry = MemoryEntry::new("This is a test memory");
        let id = store.add(entry.clone()).await.expect("Failed to add");

        let retrieved = store
            .get(&id)
            .await
            .expect("Failed to get")
            .expect("Stored memory should be retrievable");
        assert_eq!(retrieved.id, id);
        assert_eq!(retrieved.content, "This is a test memory");
    }

    /// Looking up an unknown id (containing a quote, to exercise escaping)
    /// returns `None` rather than an error.
    #[tokio::test]
    async fn test_lancedb_store_get_missing() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        let retrieved = store
            .get("no-such-id-'quoted'")
            .await
            .expect("Failed to get");
        assert!(retrieved.is_none());
    }

    /// A deleted entry is no longer returned by `get`.
    #[tokio::test]
    async fn test_lancedb_store_delete() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        let entry = MemoryEntry::new("Memory to delete");
        let id = store.add(entry).await.expect("Failed to add");

        store.delete(&id).await.expect("Failed to delete");

        let retrieved = store.get(&id).await.expect("Failed to get");
        assert!(retrieved.is_none());
    }

    /// `update` replaces the stored content of an existing entry.
    #[tokio::test]
    async fn test_lancedb_store_update() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        let mut entry = MemoryEntry::new("Original content");
        let id = store.add(entry.clone()).await.expect("Failed to add");

        entry.content = "Updated content".to_string();
        store.update(entry).await.expect("Failed to update");

        let retrieved = store
            .get(&id)
            .await
            .expect("Failed to get")
            .expect("Updated memory should still exist");
        assert_eq!(retrieved.content, "Updated content");
    }

    /// Text search returns matches, and only matches.
    #[tokio::test]
    async fn test_lancedb_store_search() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        for content in [
            "Rust programming language",
            "Python machine learning",
            "Rust async programming",
        ] {
            store
                .add(MemoryEntry::new(content))
                .await
                .expect("Failed to add");
        }

        let results = store.search("Rust", 10).await.expect("Failed to search");
        assert!(!results.is_empty());
        assert!(results.iter().all(|entry| entry.content.contains("Rust")));
    }

    /// `count` reflects the number of rows added after a `clear`.
    #[tokio::test]
    async fn test_lancedb_store_count() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        store.clear().await.expect("Failed to clear");

        store
            .add(MemoryEntry::new("Test 1"))
            .await
            .expect("Failed to add");
        store
            .add(MemoryEntry::new("Test 2"))
            .await
            .expect("Failed to add");

        let count = store.count().await.expect("Failed to count");
        assert_eq!(count, 2);
    }
}