// agent_io/memory/backends/lancedb.rs
use arrow::array::{Array, FixedSizeListArray, Float32Array, Int64Array, StringArray, UInt32Array};
4use arrow::record_batch::RecordBatch;
5use arrow_array::RecordBatchIterator;
6use arrow_array::types::Float32Type;
7use arrow_schema::{DataType, Field, Schema};
8use async_trait::async_trait;
9use futures::TryStreamExt;
10use lancedb::query::{ExecutableQuery, QueryBase, Select};
11use lancedb::{Table, connect};
12use std::path::PathBuf;
13use std::sync::Arc;
14
15use crate::Result;
16use crate::memory::entry::{MemoryEntry, MemoryType};
17use crate::memory::store::MemoryStore;
18
19const TABLE_NAME: &str = "memories";
20const EMBEDDING_DIM: usize = 1536; pub struct LanceDbStore {
24 table: Arc<Table>,
25}
26
27impl LanceDbStore {
28 pub async fn new() -> Result<Self> {
30 Self::open_uri("memory://agent_io_memories").await
31 }
32
33 pub async fn open<P: Into<PathBuf>>(path: P) -> Result<Self> {
35 let path = path.into();
36
37 if let Some(parent) = path.parent() {
39 std::fs::create_dir_all(parent)
40 .map_err(|e| crate::Error::Agent(format!("Failed to create directory: {}", e)))?;
41 }
42
43 let uri = path.to_string_lossy().to_string();
44 Self::open_uri(&uri).await
45 }
46
47 async fn open_uri(uri: &str) -> Result<Self> {
48 let db = connect(uri)
49 .execute()
50 .await
51 .map_err(|e| crate::Error::Agent(format!("Failed to connect to LanceDB: {}", e)))?;
52
53 let table_names = db
54 .table_names()
55 .execute()
56 .await
57 .map_err(|e| crate::Error::Agent(format!("Failed to list tables: {}", e)))?;
58
59 let table = if table_names.contains(&TABLE_NAME.to_string()) {
60 db.open_table(TABLE_NAME)
61 .execute()
62 .await
63 .map_err(|e| crate::Error::Agent(format!("Failed to open table: {}", e)))?
64 } else {
65 let schema = Self::schema();
67 db.create_empty_table(TABLE_NAME, schema)
68 .execute()
69 .await
70 .map_err(|e| crate::Error::Agent(format!("Failed to create table: {}", e)))?
71 };
72
73 Ok(Self {
74 table: Arc::new(table),
75 })
76 }
77
78 fn schema() -> Arc<Schema> {
80 Arc::new(Schema::new(vec![
81 Field::new("id", DataType::Utf8, false),
82 Field::new("content", DataType::Utf8, false),
83 Field::new(
84 "embedding",
85 DataType::FixedSizeList(
86 Arc::new(Field::new("item", DataType::Float32, true)),
87 EMBEDDING_DIM as i32,
88 ),
89 true,
90 ),
91 Field::new("memory_type", DataType::Utf8, false),
92 Field::new("metadata", DataType::Utf8, true),
93 Field::new("created_at", DataType::Int64, false),
94 Field::new("last_accessed", DataType::Int64, true),
95 Field::new("importance", DataType::Float32, false),
96 Field::new("access_count", DataType::UInt32, false),
97 ]))
98 }
99
100 fn memory_type_to_string(t: &MemoryType) -> &'static str {
102 match t {
103 MemoryType::ShortTerm => "short_term",
104 MemoryType::LongTerm => "long_term",
105 MemoryType::Episodic => "episodic",
106 MemoryType::Semantic => "semantic",
107 }
108 }
109
110 fn string_to_memory_type(s: &str) -> MemoryType {
112 match s {
113 "long_term" => MemoryType::LongTerm,
114 "episodic" => MemoryType::Episodic,
115 "semantic" => MemoryType::Semantic,
116 _ => MemoryType::ShortTerm,
117 }
118 }
119
120 fn entry_to_batch(entry: &MemoryEntry) -> Result<RecordBatch> {
122 let schema = Self::schema();
123
124 let id_array = StringArray::from(vec![entry.id.clone()]);
125 let content_array = StringArray::from(vec![entry.content.clone()]);
126
127 let embedding_array = if let Some(ref embedding) = entry.embedding {
129 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
130 vec![Some(embedding.iter().map(|&v| Some(v)).collect::<Vec<_>>())],
131 EMBEDDING_DIM as i32,
132 )
133 } else {
134 FixedSizeListArray::from_iter_primitive::<Float32Type, Option<Option<f32>>, _>(
136 vec![None],
137 EMBEDDING_DIM as i32,
138 )
139 };
140
141 let memory_type_array =
142 StringArray::from(vec![Self::memory_type_to_string(&entry.memory_type)]);
143
144 let metadata_array = if entry.metadata.is_empty() {
145 StringArray::from(vec![None::<String>])
146 } else {
147 StringArray::from(vec![Some(
148 serde_json::to_string(&entry.metadata).unwrap_or_default(),
149 )])
150 };
151
152 let created_at_array = Int64Array::from(vec![entry.created_at.timestamp()]);
153 let last_accessed_array =
154 Int64Array::from(vec![entry.last_accessed.map(|la| la.timestamp())]);
155 let importance_array = Float32Array::from(vec![entry.importance]);
156 let access_count_array = UInt32Array::from(vec![entry.access_count]);
157
158 RecordBatch::try_new(
159 schema,
160 vec![
161 Arc::new(id_array),
162 Arc::new(content_array),
163 Arc::new(embedding_array),
164 Arc::new(memory_type_array),
165 Arc::new(metadata_array),
166 Arc::new(created_at_array),
167 Arc::new(last_accessed_array),
168 Arc::new(importance_array),
169 Arc::new(access_count_array),
170 ],
171 )
172 .map_err(|e| crate::Error::Agent(format!("Failed to create record batch: {}", e)))
173 }
174
175 fn parse_batch_row(batch: &RecordBatch, i: usize) -> Result<MemoryEntry> {
176 let id = batch
177 .column(0)
178 .as_any()
179 .downcast_ref::<StringArray>()
180 .map(|arr| arr.value(i).to_string())
181 .unwrap_or_default();
182
183 let content = batch
184 .column(1)
185 .as_any()
186 .downcast_ref::<StringArray>()
187 .map(|arr| arr.value(i).to_string())
188 .unwrap_or_default();
189
190 let embedding = batch
191 .column(2)
192 .as_any()
193 .downcast_ref::<FixedSizeListArray>()
194 .and_then(|arr| {
195 if arr.is_null(i) {
196 return None;
197 }
198 let values = arr.value(i);
199 values
200 .as_any()
201 .downcast_ref::<Float32Array>()
202 .map(|v| v.values().to_vec())
203 });
204
205 let memory_type = batch
206 .column(3)
207 .as_any()
208 .downcast_ref::<StringArray>()
209 .map(|arr| arr.value(i).to_string())
210 .unwrap_or_default();
211
212 let metadata = batch
213 .column(4)
214 .as_any()
215 .downcast_ref::<StringArray>()
216 .and_then(|arr| {
217 if arr.is_null(i) {
218 None
219 } else {
220 Some(arr.value(i).to_string())
221 }
222 });
223
224 let created_at = batch
225 .column(5)
226 .as_any()
227 .downcast_ref::<Int64Array>()
228 .map(|arr| arr.value(i))
229 .unwrap_or(0);
230
231 let last_accessed = batch
232 .column(6)
233 .as_any()
234 .downcast_ref::<Int64Array>()
235 .and_then(|arr| {
236 if arr.is_null(i) {
237 None
238 } else {
239 Some(arr.value(i))
240 }
241 });
242
243 let importance = batch
244 .column(7)
245 .as_any()
246 .downcast_ref::<Float32Array>()
247 .map(|arr| arr.value(i))
248 .unwrap_or(0.5);
249
250 let access_count = batch
251 .column(8)
252 .as_any()
253 .downcast_ref::<UInt32Array>()
254 .map(|arr| arr.value(i))
255 .unwrap_or(0);
256
257 let metadata_map: std::collections::HashMap<String, serde_json::Value> = metadata
258 .as_ref()
259 .and_then(|s| serde_json::from_str(s).ok())
260 .unwrap_or_default();
261
262 Ok(MemoryEntry {
263 id,
264 content,
265 embedding,
266 memory_type: Self::string_to_memory_type(&memory_type),
267 metadata: metadata_map,
268 created_at: chrono::DateTime::from_timestamp(created_at, 0)
269 .map(|dt| dt.with_timezone(&chrono::Utc))
270 .unwrap_or_else(chrono::Utc::now),
271 last_accessed: last_accessed
272 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
273 .map(|dt| dt.with_timezone(&chrono::Utc)),
274 importance,
275 access_count,
276 })
277 }
278}
279
280#[async_trait]
281impl MemoryStore for LanceDbStore {
282 async fn add(&self, entry: MemoryEntry) -> Result<String> {
283 let id = entry.id.clone();
284 let batch = Self::entry_to_batch(&entry)?;
285
286 self.table
287 .add(RecordBatchIterator::new(
288 vec![Ok(batch.clone())],
289 batch.schema(),
290 ))
291 .execute()
292 .await
293 .map_err(|e| crate::Error::Agent(format!("Failed to add memory: {}", e)))?;
294
295 Ok(id)
296 }
297
298 async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
299 let batches = self
300 .table
301 .query()
302 .only_if(format!("id = '{}'", id.replace('\'', "''")))
303 .execute()
304 .await
305 .map_err(|e| crate::Error::Agent(format!("Failed to query: {}", e)))?
306 .try_collect::<Vec<_>>()
307 .await
308 .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
309
310 if let Some(batch) = batches.first()
311 && batch.num_rows() > 0
312 {
313 return Ok(Some(Self::parse_batch_row(batch, 0)?));
314 }
315
316 Ok(None)
317 }
318
319 async fn delete(&self, id: &str) -> Result<()> {
320 self.table
321 .delete(&format!("id = '{}'", id.replace('\'', "''")))
322 .await
323 .map_err(|e| crate::Error::Agent(format!("Failed to delete memory: {}", e)))?;
324
325 Ok(())
326 }
327
328 async fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
329 let batches = self
330 .table
331 .query()
332 .only_if(format!("content LIKE '%{}%'", query.replace('\'', "''")))
333 .limit(limit)
334 .execute()
335 .await
336 .map_err(|e| crate::Error::Agent(format!("Failed to search: {}", e)))?
337 .try_collect::<Vec<_>>()
338 .await
339 .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
340
341 let mut entries = Vec::new();
342 for batch in batches {
343 for i in 0..batch.num_rows() {
344 entries.push(Self::parse_batch_row(&batch, i)?);
345 }
346 }
347
348 Ok(entries)
349 }
350
351 async fn search_by_embedding(
352 &self,
353 embedding: &[f32],
354 limit: usize,
355 threshold: f32,
356 ) -> Result<Vec<MemoryEntry>> {
357 let batches = self
358 .table
359 .query()
360 .limit(limit * 2) .nearest_to(embedding)
362 .map_err(|e| crate::Error::Agent(format!("Failed to create vector search: {}", e)))?
363 .execute()
364 .await
365 .map_err(|e| crate::Error::Agent(format!("Failed to search by embedding: {}", e)))?
366 .try_collect::<Vec<_>>()
367 .await
368 .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
369
370 let mut entries_with_score = Vec::new();
371 for batch in batches {
372 for i in 0..batch.num_rows() {
373 let entry = Self::parse_batch_row(&batch, i)?;
374
375 let similarity = if let Some(distance_col) = batch.column_by_name("_distance") {
377 let dist = distance_col
378 .as_any()
379 .downcast_ref::<Float32Array>()
380 .map(|arr| arr.value(i))
381 .unwrap_or(1.0);
382 1.0 - dist } else if let Some(ref entry_embedding) = entry.embedding {
384 cosine_similarity(embedding, entry_embedding)
385 } else {
386 0.0
387 };
388
389 if similarity >= threshold {
390 entries_with_score.push((entry, similarity));
391 }
392 }
393 }
394
395 entries_with_score
397 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
398 entries_with_score.truncate(limit);
399
400 Ok(entries_with_score.into_iter().map(|(e, _)| e).collect())
401 }
402
403 async fn ids(&self) -> Result<Vec<String>> {
404 let batches = self
405 .table
406 .query()
407 .select(Select::columns(&["id"]))
408 .execute()
409 .await
410 .map_err(|e| crate::Error::Agent(format!("Failed to query ids: {}", e)))?
411 .try_collect::<Vec<_>>()
412 .await
413 .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
414
415 let mut ids = Vec::new();
416 for batch in batches {
417 if let Some(id_array) = batch
418 .column_by_name("id")
419 .and_then(|col| col.as_any().downcast_ref::<StringArray>())
420 {
421 for i in 0..id_array.len() {
422 ids.push(id_array.value(i).to_string());
423 }
424 }
425 }
426
427 Ok(ids)
428 }
429
430 async fn count(&self) -> Result<usize> {
431 let batches = self
432 .table
433 .query()
434 .select(Select::columns(&["id"]))
435 .execute()
436 .await
437 .map_err(|e| crate::Error::Agent(format!("Failed to count: {}", e)))?
438 .try_collect::<Vec<_>>()
439 .await
440 .map_err(|e| crate::Error::Agent(format!("Failed to collect batches: {}", e)))?;
441
442 let mut count = 0;
443 for batch in batches {
444 count += batch.num_rows();
445 }
446
447 Ok(count)
448 }
449
450 async fn update(&self, entry: MemoryEntry) -> Result<()> {
451 self.delete(&entry.id).await?;
453 self.add(entry).await?;
454 Ok(())
455 }
456
457 async fn clear(&self) -> Result<()> {
458 self.table
459 .delete("true")
460 .await
461 .map_err(|e| crate::Error::Agent(format!("Failed to clear memories: {}", e)))?;
462
463 Ok(())
464 }
465}
466
/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 for empty inputs, mismatched lengths, or zero-magnitude
/// vectors; otherwise the result lies in [-1.0, 1.0].
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // sqrt of a non-negative f32 is zero iff the input is zero, so checking
    // the squared norms is equivalent to checking the magnitudes.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a.sqrt() * norm_b.sqrt())
}
483
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every test calls `LanceDbStore::new()`, which connects to
    // the same "memory://agent_io_memories" URI. If LanceDB shares in-memory
    // databases by URI across connections, tests running in parallel could
    // observe each other's rows (the count test below assumes isolation) —
    // confirm the isolation semantics of `memory://` URIs.

    // Round-trip: an added entry can be retrieved by the returned id.
    #[tokio::test]
    async fn test_lancedb_store_basic() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        let entry = MemoryEntry::new("This is a test memory");
        let id = store.add(entry.clone()).await.expect("Failed to add");

        let retrieved = store.get(&id).await.expect("Failed to get");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "This is a test memory");
    }

    // Deleting an entry makes a subsequent lookup return None.
    #[tokio::test]
    async fn test_lancedb_store_delete() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        let entry = MemoryEntry::new("Memory to delete");
        let id = store.add(entry).await.expect("Failed to add");

        store.delete(&id).await.expect("Failed to delete");

        let retrieved = store.get(&id).await.expect("Failed to get");
        assert!(retrieved.is_none());
    }

    // LIKE-based substring search finds entries containing the query text.
    #[tokio::test]
    async fn test_lancedb_store_search() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        store
            .add(MemoryEntry::new("Rust programming language"))
            .await
            .ok();
        store
            .add(MemoryEntry::new("Python machine learning"))
            .await
            .ok();
        store
            .add(MemoryEntry::new("Rust async programming"))
            .await
            .ok();

        let results = store.search("Rust", 10).await.expect("Failed to search");
        assert!(!results.is_empty());
    }

    // count() reflects the row total after a clear followed by two adds.
    #[tokio::test]
    async fn test_lancedb_store_count() {
        let store = LanceDbStore::new().await.expect("Failed to create store");

        store.clear().await.ok();

        store.add(MemoryEntry::new("Test 1")).await.ok();
        store.add(MemoryEntry::new("Test 2")).await.ok();

        let count = store.count().await.expect("Failed to count");
        assert_eq!(count, 2);
    }
}