Skip to main content

agent_io/memory/backends/
sqlite.rs

1//! SQLite-based memory store for persistent storage
2
3use async_trait::async_trait;
4use rusqlite::{Connection, params};
5use std::path::PathBuf;
6use std::sync::Arc;
7use tokio::sync::Mutex;
8
9use crate::Result;
10use crate::memory::entry::{MemoryEntry, MemoryType};
11use crate::memory::store::MemoryStore;
12
13/// SQLite memory store for persistent storage
14pub struct SqliteStore {
15    conn: Arc<Mutex<Connection>>,
16}
17
18impl SqliteStore {
19    /// Create a new SQLite store with an in-memory database
20    pub fn new() -> Result<Self> {
21        let conn = Connection::open_in_memory().map_err(|e| {
22            crate::Error::Agent(format!("Failed to create in-memory SQLite: {}", e))
23        })?;
24
25        Self::initialize_schema(&conn)?;
26
27        Ok(Self {
28            conn: Arc::new(Mutex::new(conn)),
29        })
30    }
31
32    /// Create a new SQLite store with a file database
33    pub fn open<P: Into<PathBuf>>(path: P) -> Result<Self> {
34        let path = path.into();
35
36        // Ensure parent directory exists
37        if let Some(parent) = path.parent() {
38            std::fs::create_dir_all(parent)
39                .map_err(|e| crate::Error::Agent(format!("Failed to create directory: {}", e)))?;
40        }
41
42        let conn = Connection::open(&path)
43            .map_err(|e| crate::Error::Agent(format!("Failed to open SQLite database: {}", e)))?;
44
45        Self::initialize_schema(&conn)?;
46
47        Ok(Self {
48            conn: Arc::new(Mutex::new(conn)),
49        })
50    }
51
52    /// Initialize database schema (synchronous)
53    fn initialize_schema(conn: &Connection) -> Result<()> {
54        conn.execute_batch(
55            r#"
56            CREATE TABLE IF NOT EXISTS memories (
57                id TEXT PRIMARY KEY,
58                content TEXT NOT NULL,
59                embedding BLOB,
60                memory_type TEXT NOT NULL DEFAULT 'short_term',
61                metadata TEXT,
62                created_at TEXT NOT NULL,
63                last_accessed TEXT,
64                importance REAL NOT NULL DEFAULT 0.5,
65                access_count INTEGER NOT NULL DEFAULT 0
66            );
67            
68            CREATE INDEX IF NOT EXISTS idx_memory_type ON memories(memory_type);
69            CREATE INDEX IF NOT EXISTS idx_importance ON memories(importance);
70            CREATE INDEX IF NOT EXISTS idx_created_at ON memories(created_at);
71            
72            CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
73                id UNINDEXED,
74                content,
75                content='memories',
76                content_rowid='rowid'
77            );
78            
79            CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
80                INSERT INTO memories_fts(rowid, id, content) 
81                VALUES (new.rowid, new.id, new.content);
82            END;
83            
84            CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
85                INSERT INTO memories_fts(memories_fts, rowid, id, content) 
86                VALUES('delete', old.rowid, old.id, old.content);
87            END;
88            
89            CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
90                INSERT INTO memories_fts(memories_fts, rowid, id, content) 
91                VALUES('delete', old.rowid, old.id, old.content);
92                INSERT INTO memories_fts(rowid, id, content) 
93                VALUES (new.rowid, new.id, new.content);
94            END;
95            "#,
96        )
97        .map_err(|e| crate::Error::Agent(format!("Failed to initialize schema: {}", e)))?;
98
99        Ok(())
100    }
101
102    /// Convert memory type to string
103    fn memory_type_to_string(t: &MemoryType) -> &'static str {
104        match t {
105            MemoryType::ShortTerm => "short_term",
106            MemoryType::LongTerm => "long_term",
107            MemoryType::Episodic => "episodic",
108            MemoryType::Semantic => "semantic",
109        }
110    }
111
112    /// Convert string to memory type
113    fn string_to_memory_type(s: &str) -> MemoryType {
114        match s {
115            "long_term" => MemoryType::LongTerm,
116            "episodic" => MemoryType::Episodic,
117            "semantic" => MemoryType::Semantic,
118            _ => MemoryType::ShortTerm,
119        }
120    }
121}
122
123impl Default for SqliteStore {
124    fn default() -> Self {
125        Self::new().expect("Failed to create default SqliteStore")
126    }
127}
128
129#[async_trait]
130impl MemoryStore for SqliteStore {
131    async fn add(&self, entry: MemoryEntry) -> Result<String> {
132        let conn = self.conn.clone();
133        let id = entry.id.clone();
134
135        tokio::task::spawn_blocking(move || {
136            let conn = conn.blocking_lock();
137
138            let embedding_bytes = entry.embedding.as_ref().map(|v| {
139                let len = v.len() * std::mem::size_of::<f32>();
140                let mut bytes = Vec::with_capacity(len);
141                for &f in v {
142                    bytes.extend_from_slice(&f.to_le_bytes());
143                }
144                bytes
145            });
146
147            conn.execute(
148                r#"
149                INSERT INTO memories (id, content, embedding, memory_type, metadata, created_at, 
150                                     last_accessed, importance, access_count)
151                VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
152                "#,
153                params![
154                    entry.id,
155                    entry.content,
156                    embedding_bytes,
157                    Self::memory_type_to_string(&entry.memory_type),
158                    if entry.metadata.is_empty() {
159                        None::<String>
160                    } else {
161                        Some(serde_json::to_string(&entry.metadata).unwrap_or_default())
162                    },
163                    entry.created_at.to_rfc3339(),
164                    entry.last_accessed.map(|t| t.to_rfc3339()),
165                    entry.importance,
166                    entry.access_count,
167                ],
168            )
169            .map_err(|e| crate::Error::Agent(format!("Failed to insert memory: {}", e)))?;
170
171            Ok(id)
172        })
173        .await
174        .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
175    }
176
177    async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
178        let conn = self.conn.clone();
179        let id = id.to_string();
180
181        tokio::task::spawn_blocking(move || {
182            let conn = conn.blocking_lock();
183
184            let result = conn.query_row(
185                "SELECT id, content, embedding, memory_type, metadata, created_at, 
186                        last_accessed, importance, access_count 
187                 FROM memories WHERE id = ?1",
188                params![id],
189                |row| {
190                    let embedding_blob: Option<Vec<u8>> = row.get(2)?;
191                    let embedding = embedding_blob.as_ref().map(|blob| {
192                        let len = blob.len() / std::mem::size_of::<f32>();
193                        let mut vec = Vec::with_capacity(len);
194                        for chunk in blob.chunks(std::mem::size_of::<f32>()) {
195                            let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
196                            vec.push(f32::from_le_bytes(bytes));
197                        }
198                        vec
199                    });
200
201                    let metadata_str: Option<String> = row.get(4)?;
202                    let metadata: std::collections::HashMap<String, serde_json::Value> =
203                        metadata_str
204                            .and_then(|s| serde_json::from_str(&s).ok())
205                            .unwrap_or_default();
206
207                    let created_at_str: String = row.get(5)?;
208                    let last_accessed_str: Option<String> = row.get(6)?;
209
210                    Ok(MemoryEntry {
211                        id: row.get(0)?,
212                        content: row.get(1)?,
213                        embedding,
214                        memory_type: Self::string_to_memory_type(&row.get::<_, String>(3)?),
215                        metadata,
216                        created_at: chrono::DateTime::parse_from_rfc3339(&created_at_str)
217                            .map(|dt| dt.with_timezone(&chrono::Utc))
218                            .unwrap_or_else(|_| chrono::Utc::now()),
219                        last_accessed: last_accessed_str
220                            .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
221                            .map(|dt| dt.with_timezone(&chrono::Utc)),
222                        importance: row.get(7)?,
223                        access_count: row.get(8)?,
224                    })
225                },
226            );
227
228            match result {
229                Ok(entry) => Ok(Some(entry)),
230                Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
231                Err(e) => Err(crate::Error::Agent(format!("Failed to get memory: {}", e))),
232            }
233        })
234        .await
235        .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
236    }
237
238    async fn delete(&self, id: &str) -> Result<()> {
239        let conn = self.conn.clone();
240        let id = id.to_string();
241
242        tokio::task::spawn_blocking(move || {
243            let conn = conn.blocking_lock();
244
245            conn.execute("DELETE FROM memories WHERE id = ?1", params![id])
246                .map_err(|e| crate::Error::Agent(format!("Failed to delete memory: {}", e)))?;
247
248            Ok(())
249        })
250        .await
251        .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
252    }
253
254    async fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
255        let conn = self.conn.clone();
256        let query = query.to_string();
257
258        tokio::task::spawn_blocking(move || {
259            let conn = conn.blocking_lock();
260
261            let mut stmt = conn
262                .prepare(
263                    r#"
264                SELECT m.id, m.content, m.embedding, m.memory_type, m.metadata, 
265                       m.created_at, m.last_accessed, m.importance, m.access_count
266                FROM memories m
267                JOIN memories_fts fts ON m.id = fts.id
268                WHERE memories_fts MATCH ?1
269                ORDER BY m.importance DESC
270                LIMIT ?2
271                "#,
272                )
273                .map_err(|e| crate::Error::Agent(format!("Failed to prepare search: {}", e)))?;
274
275            let entries = stmt
276                .query_map(params![query, limit as i64], |row| {
277                    let embedding_blob: Option<Vec<u8>> = row.get(2)?;
278                    let embedding = embedding_blob.as_ref().map(|blob| {
279                        let len = blob.len() / std::mem::size_of::<f32>();
280                        let mut vec = Vec::with_capacity(len);
281                        for chunk in blob.chunks(std::mem::size_of::<f32>()) {
282                            let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
283                            vec.push(f32::from_le_bytes(bytes));
284                        }
285                        vec
286                    });
287
288                    let metadata_str: Option<String> = row.get(4)?;
289                    let metadata: std::collections::HashMap<String, serde_json::Value> =
290                        metadata_str
291                            .and_then(|s| serde_json::from_str(&s).ok())
292                            .unwrap_or_default();
293
294                    let created_at_str: String = row.get(5)?;
295                    let last_accessed_str: Option<String> = row.get(6)?;
296
297                    Ok(MemoryEntry {
298                        id: row.get(0)?,
299                        content: row.get(1)?,
300                        embedding,
301                        memory_type: Self::string_to_memory_type(&row.get::<_, String>(3)?),
302                        metadata,
303                        created_at: chrono::DateTime::parse_from_rfc3339(&created_at_str)
304                            .map(|dt| dt.with_timezone(&chrono::Utc))
305                            .unwrap_or_else(|_| chrono::Utc::now()),
306                        last_accessed: last_accessed_str
307                            .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
308                            .map(|dt| dt.with_timezone(&chrono::Utc)),
309                        importance: row.get(7)?,
310                        access_count: row.get(8)?,
311                    })
312                })
313                .map_err(|e| crate::Error::Agent(format!("Failed to search memories: {}", e)))?;
314
315            let mut results = Vec::new();
316            for entry in entries {
317                results.push(
318                    entry.map_err(|e| {
319                        crate::Error::Agent(format!("Failed to parse entry: {}", e))
320                    })?,
321                );
322            }
323
324            Ok(results)
325        })
326        .await
327        .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
328    }
329
330    async fn search_by_embedding(
331        &self,
332        embedding: &[f32],
333        limit: usize,
334        threshold: f32,
335    ) -> Result<Vec<MemoryEntry>> {
336        // For SQLite, we need to compute similarity in memory
337        // This is a simplified implementation - for production, consider using a vector database
338        let conn = self.conn.clone();
339        let embedding = embedding.to_vec();
340
341        tokio::task::spawn_blocking(move || {
342            let conn = conn.blocking_lock();
343
344            let mut stmt = conn
345                .prepare(
346                    "SELECT id, content, embedding, memory_type, metadata, created_at, 
347                        last_accessed, importance, access_count 
348                 FROM memories 
349                 WHERE embedding IS NOT NULL
350                 ORDER BY importance DESC",
351                )
352                .map_err(|e| {
353                    crate::Error::Agent(format!("Failed to prepare embedding search: {}", e))
354                })?;
355
356            let entries = stmt
357                .query_map([], |row| {
358                    let embedding_blob: Vec<u8> = row.get(2)?;
359                    let stored_embedding: Vec<f32> = {
360                        let len = embedding_blob.len() / std::mem::size_of::<f32>();
361                        let mut vec = Vec::with_capacity(len);
362                        for chunk in embedding_blob.chunks(std::mem::size_of::<f32>()) {
363                            let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
364                            vec.push(f32::from_le_bytes(bytes));
365                        }
366                        vec
367                    };
368
369                    let metadata_str: Option<String> = row.get(4)?;
370                    let metadata: std::collections::HashMap<String, serde_json::Value> =
371                        metadata_str
372                            .and_then(|s| serde_json::from_str(&s).ok())
373                            .unwrap_or_default();
374
375                    let created_at_str: String = row.get(5)?;
376                    let last_accessed_str: Option<String> = row.get(6)?;
377
378                    let entry = MemoryEntry {
379                        id: row.get(0)?,
380                        content: row.get(1)?,
381                        embedding: Some(stored_embedding.clone()),
382                        memory_type: Self::string_to_memory_type(&row.get::<_, String>(3)?),
383                        metadata,
384                        created_at: chrono::DateTime::parse_from_rfc3339(&created_at_str)
385                            .map(|dt| dt.with_timezone(&chrono::Utc))
386                            .unwrap_or_else(|_| chrono::Utc::now()),
387                        last_accessed: last_accessed_str
388                            .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
389                            .map(|dt| dt.with_timezone(&chrono::Utc)),
390                        importance: row.get(7)?,
391                        access_count: row.get(8)?,
392                    };
393
394                    Ok((entry, stored_embedding))
395                })
396                .map_err(|e| {
397                    crate::Error::Agent(format!("Failed to search by embedding: {}", e))
398                })?;
399
400            // Compute cosine similarity and filter by threshold
401            let mut results: Vec<_> = entries
402                .filter_map(|r| r.ok())
403                .map(|(entry, stored)| {
404                    let similarity = cosine_similarity(&embedding, &stored);
405                    (entry, similarity)
406                })
407                .filter(|(_, sim)| *sim >= threshold)
408                .collect();
409
410            results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
411            results.truncate(limit);
412
413            Ok(results.into_iter().map(|(entry, _)| entry).collect())
414        })
415        .await
416        .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
417    }
418
419    async fn ids(&self) -> Result<Vec<String>> {
420        let conn = self.conn.clone();
421
422        tokio::task::spawn_blocking(move || {
423            let conn = conn.blocking_lock();
424
425            let mut stmt = conn
426                .prepare("SELECT id FROM memories ORDER BY created_at DESC")
427                .map_err(|e| crate::Error::Agent(format!("Failed to prepare ids: {}", e)))?;
428
429            let ids = stmt
430                .query_map([], |row| row.get(0))
431                .map_err(|e| crate::Error::Agent(format!("Failed to get ids: {}", e)))?;
432
433            let mut results = Vec::new();
434            for id in ids {
435                results.push(
436                    id.map_err(|e| crate::Error::Agent(format!("Failed to parse id: {}", e)))?,
437                );
438            }
439
440            Ok(results)
441        })
442        .await
443        .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
444    }
445
446    async fn count(&self) -> Result<usize> {
447        let conn = self.conn.clone();
448
449        tokio::task::spawn_blocking(move || {
450            let conn = conn.blocking_lock();
451
452            let count: i64 = conn
453                .query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))
454                .map_err(|e| crate::Error::Agent(format!("Failed to count memories: {}", e)))?;
455
456            Ok(count as usize)
457        })
458        .await
459        .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
460    }
461
462    async fn update(&self, entry: MemoryEntry) -> Result<()> {
463        let conn = self.conn.clone();
464
465        tokio::task::spawn_blocking(move || {
466            let conn = conn.blocking_lock();
467
468            let embedding_bytes = entry.embedding.as_ref().map(|v| {
469                let len = v.len() * std::mem::size_of::<f32>();
470                let mut bytes = Vec::with_capacity(len);
471                for &f in v {
472                    bytes.extend_from_slice(&f.to_le_bytes());
473                }
474                bytes
475            });
476
477            conn.execute(
478                r#"
479                UPDATE memories SET 
480                    content = ?2,
481                    embedding = ?3,
482                    memory_type = ?4,
483                    metadata = ?5,
484                    last_accessed = ?6,
485                    importance = ?7,
486                    access_count = ?8
487                WHERE id = ?1
488                "#,
489                params![
490                    entry.id,
491                    entry.content,
492                    embedding_bytes,
493                    Self::memory_type_to_string(&entry.memory_type),
494                    if entry.metadata.is_empty() {
495                        None::<String>
496                    } else {
497                        Some(serde_json::to_string(&entry.metadata).unwrap_or_default())
498                    },
499                    entry.last_accessed.map(|t| t.to_rfc3339()),
500                    entry.importance,
501                    entry.access_count,
502                ],
503            )
504            .map_err(|e| crate::Error::Agent(format!("Failed to update memory: {}", e)))?;
505
506            Ok(())
507        })
508        .await
509        .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
510    }
511
512    async fn clear(&self) -> Result<()> {
513        let conn = self.conn.clone();
514
515        tokio::task::spawn_blocking(move || {
516            let conn = conn.blocking_lock();
517
518            conn.execute("DELETE FROM memories", [])
519                .map_err(|e| crate::Error::Agent(format!("Failed to clear memories: {}", e)))?;
520
521            Ok(())
522        })
523        .await
524        .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
525    }
526}
527
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when the vectors differ in length, are empty, or when either
/// one has zero magnitude; otherwise returns `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));

    // A zero vector has no direction; avoid dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot_product: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot_product / (norm_a * norm_b)
}
544
#[cfg(test)]
mod tests {
    use super::*;

    /// An added entry can be read back by the id that `add` returns.
    #[tokio::test]
    async fn test_sqlite_store_basic() {
        let store = SqliteStore::new().expect("Failed to create store");

        let entry = MemoryEntry::new("This is a test memory");
        let expected_id = entry.id.clone();
        let id = store.add(entry).await.expect("Failed to add");
        assert_eq!(id, expected_id);

        let retrieved = store
            .get(&id)
            .await
            .expect("Failed to get")
            .expect("Entry should exist after add");
        assert_eq!(retrieved.id, id);
        assert_eq!(retrieved.content, "This is a test memory");
    }

    /// A deleted entry is no longer returned by `get`.
    #[tokio::test]
    async fn test_sqlite_store_delete() {
        let store = SqliteStore::new().expect("Failed to create store");

        let entry = MemoryEntry::new("Memory to delete");
        let id = store.add(entry).await.expect("Failed to add");

        store.delete(&id).await.expect("Failed to delete");

        let retrieved = store.get(&id).await.expect("Failed to get");
        assert!(retrieved.is_none());
    }

    /// Full-text search returns exactly the matching entries and honors `limit`.
    #[tokio::test]
    async fn test_sqlite_store_search() {
        let store = SqliteStore::new().expect("Failed to create store");

        // Setup failures must fail the test rather than be swallowed.
        for content in [
            "Rust programming language",
            "Python machine learning",
            "Rust async programming",
        ] {
            store
                .add(MemoryEntry::new(content))
                .await
                .expect("Failed to add");
        }

        let results = store.search("Rust", 10).await.expect("Failed to search");
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|entry| entry.content.contains("Rust")));

        let limited = store.search("Rust", 1).await.expect("Failed to search");
        assert_eq!(limited.len(), 1);
    }

    /// `count` reflects adds, and `clear` empties the store.
    #[tokio::test]
    async fn test_sqlite_store_count() {
        let store = SqliteStore::new().expect("Failed to create store");

        store.clear().await.expect("Failed to clear");

        store
            .add(MemoryEntry::new("Test 1"))
            .await
            .expect("Failed to add");
        store
            .add(MemoryEntry::new("Test 2"))
            .await
            .expect("Failed to add");

        let count = store.count().await.expect("Failed to count");
        assert_eq!(count, 2);

        store.clear().await.expect("Failed to clear");
        let count = store.count().await.expect("Failed to count");
        assert_eq!(count, 0);
    }

    /// Embeddings round-trip through storage and are filtered by the threshold.
    #[tokio::test]
    async fn test_sqlite_store_embedding() {
        let store = SqliteStore::new().expect("Failed to create store");

        let entry =
            MemoryEntry::new("Test with embedding").with_embedding(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        store.add(entry).await.expect("Failed to add");

        let query_embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let results = store
            .search_by_embedding(&query_embedding, 10, 0.9)
            .await
            .expect("Failed to search by embedding");

        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].embedding.as_deref(),
            Some(query_embedding.as_slice())
        );

        // The opposite direction scores -1 and must fall below the threshold.
        let opposite: Vec<f32> = query_embedding.iter().map(|x| -x).collect();
        let results = store
            .search_by_embedding(&opposite, 10, 0.9)
            .await
            .expect("Failed to search by embedding");
        assert!(results.is_empty());

        // A query of a different dimension scores 0 against every entry.
        let results = store
            .search_by_embedding(&[1.0, 0.0], 10, 0.9)
            .await
            .expect("Failed to search by embedding");
        assert!(results.is_empty());
    }

    /// `cosine_similarity` handles degenerate inputs and basic geometry.
    #[test]
    fn test_cosine_similarity() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]), 1.0);
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]), -1.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }
}