1use async_trait::async_trait;
4use rusqlite::{Connection, params};
5use std::path::PathBuf;
6use std::sync::Arc;
7use tokio::sync::Mutex;
8
9use crate::Result;
10use crate::memory::entry::{MemoryEntry, MemoryType};
11use crate::memory::store::MemoryStore;
12
/// A [`MemoryStore`] implementation that persists entries in a SQLite database.
///
/// The store owns a single `rusqlite` [`Connection`]; every operation locks it
/// for the duration of its statement(s), so operations on one store are
/// serialized.
pub struct SqliteStore {
    /// The one connection used by every operation, behind an async-aware mutex
    /// and an `Arc` so it can be moved into blocking worker tasks.
    conn: Arc<Mutex<Connection>>,
}
17
18impl SqliteStore {
19 pub fn new() -> Result<Self> {
21 let conn = Connection::open_in_memory().map_err(|e| {
22 crate::Error::Agent(format!("Failed to create in-memory SQLite: {}", e))
23 })?;
24
25 Self::initialize_schema(&conn)?;
26
27 Ok(Self {
28 conn: Arc::new(Mutex::new(conn)),
29 })
30 }
31
32 pub fn open<P: Into<PathBuf>>(path: P) -> Result<Self> {
34 let path = path.into();
35
36 if let Some(parent) = path.parent() {
38 std::fs::create_dir_all(parent)
39 .map_err(|e| crate::Error::Agent(format!("Failed to create directory: {}", e)))?;
40 }
41
42 let conn = Connection::open(&path)
43 .map_err(|e| crate::Error::Agent(format!("Failed to open SQLite database: {}", e)))?;
44
45 Self::initialize_schema(&conn)?;
46
47 Ok(Self {
48 conn: Arc::new(Mutex::new(conn)),
49 })
50 }
51
52 fn initialize_schema(conn: &Connection) -> Result<()> {
54 conn.execute_batch(
55 r#"
56 CREATE TABLE IF NOT EXISTS memories (
57 id TEXT PRIMARY KEY,
58 content TEXT NOT NULL,
59 embedding BLOB,
60 memory_type TEXT NOT NULL DEFAULT 'short_term',
61 metadata TEXT,
62 created_at TEXT NOT NULL,
63 last_accessed TEXT,
64 importance REAL NOT NULL DEFAULT 0.5,
65 access_count INTEGER NOT NULL DEFAULT 0
66 );
67
68 CREATE INDEX IF NOT EXISTS idx_memory_type ON memories(memory_type);
69 CREATE INDEX IF NOT EXISTS idx_importance ON memories(importance);
70 CREATE INDEX IF NOT EXISTS idx_created_at ON memories(created_at);
71
72 CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
73 id UNINDEXED,
74 content,
75 content='memories',
76 content_rowid='rowid'
77 );
78
79 CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
80 INSERT INTO memories_fts(rowid, id, content)
81 VALUES (new.rowid, new.id, new.content);
82 END;
83
84 CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
85 INSERT INTO memories_fts(memories_fts, rowid, id, content)
86 VALUES('delete', old.rowid, old.id, old.content);
87 END;
88
89 CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
90 INSERT INTO memories_fts(memories_fts, rowid, id, content)
91 VALUES('delete', old.rowid, old.id, old.content);
92 INSERT INTO memories_fts(rowid, id, content)
93 VALUES (new.rowid, new.id, new.content);
94 END;
95 "#,
96 )
97 .map_err(|e| crate::Error::Agent(format!("Failed to initialize schema: {}", e)))?;
98
99 Ok(())
100 }
101
102 fn memory_type_to_string(t: &MemoryType) -> &'static str {
104 match t {
105 MemoryType::ShortTerm => "short_term",
106 MemoryType::LongTerm => "long_term",
107 MemoryType::Episodic => "episodic",
108 MemoryType::Semantic => "semantic",
109 }
110 }
111
112 fn string_to_memory_type(s: &str) -> MemoryType {
114 match s {
115 "long_term" => MemoryType::LongTerm,
116 "episodic" => MemoryType::Episodic,
117 "semantic" => MemoryType::Semantic,
118 _ => MemoryType::ShortTerm,
119 }
120 }
121}
122
/// Builds an in-memory store via [`SqliteStore::new`].
impl Default for SqliteStore {
    /// # Panics
    ///
    /// Panics if [`SqliteStore::new`] returns an error, because `Default`
    /// has no way to report a failure to the caller.
    fn default() -> Self {
        Self::new().expect("Failed to create default SqliteStore")
    }
}
128
129#[async_trait]
130impl MemoryStore for SqliteStore {
131 async fn add(&self, entry: MemoryEntry) -> Result<String> {
132 let conn = self.conn.clone();
133 let id = entry.id.clone();
134
135 tokio::task::spawn_blocking(move || {
136 let conn = conn.blocking_lock();
137
138 let embedding_bytes = entry.embedding.as_ref().map(|v| {
139 let len = v.len() * std::mem::size_of::<f32>();
140 let mut bytes = Vec::with_capacity(len);
141 for &f in v {
142 bytes.extend_from_slice(&f.to_le_bytes());
143 }
144 bytes
145 });
146
147 conn.execute(
148 r#"
149 INSERT INTO memories (id, content, embedding, memory_type, metadata, created_at,
150 last_accessed, importance, access_count)
151 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
152 "#,
153 params![
154 entry.id,
155 entry.content,
156 embedding_bytes,
157 Self::memory_type_to_string(&entry.memory_type),
158 if entry.metadata.is_empty() {
159 None::<String>
160 } else {
161 Some(serde_json::to_string(&entry.metadata).unwrap_or_default())
162 },
163 entry.created_at.to_rfc3339(),
164 entry.last_accessed.map(|t| t.to_rfc3339()),
165 entry.importance,
166 entry.access_count,
167 ],
168 )
169 .map_err(|e| crate::Error::Agent(format!("Failed to insert memory: {}", e)))?;
170
171 Ok(id)
172 })
173 .await
174 .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
175 }
176
177 async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
178 let conn = self.conn.clone();
179 let id = id.to_string();
180
181 tokio::task::spawn_blocking(move || {
182 let conn = conn.blocking_lock();
183
184 let result = conn.query_row(
185 "SELECT id, content, embedding, memory_type, metadata, created_at,
186 last_accessed, importance, access_count
187 FROM memories WHERE id = ?1",
188 params![id],
189 |row| {
190 let embedding_blob: Option<Vec<u8>> = row.get(2)?;
191 let embedding = embedding_blob.as_ref().map(|blob| {
192 let len = blob.len() / std::mem::size_of::<f32>();
193 let mut vec = Vec::with_capacity(len);
194 for chunk in blob.chunks(std::mem::size_of::<f32>()) {
195 let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
196 vec.push(f32::from_le_bytes(bytes));
197 }
198 vec
199 });
200
201 let metadata_str: Option<String> = row.get(4)?;
202 let metadata: std::collections::HashMap<String, serde_json::Value> =
203 metadata_str
204 .and_then(|s| serde_json::from_str(&s).ok())
205 .unwrap_or_default();
206
207 let created_at_str: String = row.get(5)?;
208 let last_accessed_str: Option<String> = row.get(6)?;
209
210 Ok(MemoryEntry {
211 id: row.get(0)?,
212 content: row.get(1)?,
213 embedding,
214 memory_type: Self::string_to_memory_type(&row.get::<_, String>(3)?),
215 metadata,
216 created_at: chrono::DateTime::parse_from_rfc3339(&created_at_str)
217 .map(|dt| dt.with_timezone(&chrono::Utc))
218 .unwrap_or_else(|_| chrono::Utc::now()),
219 last_accessed: last_accessed_str
220 .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
221 .map(|dt| dt.with_timezone(&chrono::Utc)),
222 importance: row.get(7)?,
223 access_count: row.get(8)?,
224 })
225 },
226 );
227
228 match result {
229 Ok(entry) => Ok(Some(entry)),
230 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
231 Err(e) => Err(crate::Error::Agent(format!("Failed to get memory: {}", e))),
232 }
233 })
234 .await
235 .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
236 }
237
238 async fn delete(&self, id: &str) -> Result<()> {
239 let conn = self.conn.clone();
240 let id = id.to_string();
241
242 tokio::task::spawn_blocking(move || {
243 let conn = conn.blocking_lock();
244
245 conn.execute("DELETE FROM memories WHERE id = ?1", params![id])
246 .map_err(|e| crate::Error::Agent(format!("Failed to delete memory: {}", e)))?;
247
248 Ok(())
249 })
250 .await
251 .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
252 }
253
254 async fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
255 let conn = self.conn.clone();
256 let query = query.to_string();
257
258 tokio::task::spawn_blocking(move || {
259 let conn = conn.blocking_lock();
260
261 let mut stmt = conn
262 .prepare(
263 r#"
264 SELECT m.id, m.content, m.embedding, m.memory_type, m.metadata,
265 m.created_at, m.last_accessed, m.importance, m.access_count
266 FROM memories m
267 JOIN memories_fts fts ON m.id = fts.id
268 WHERE memories_fts MATCH ?1
269 ORDER BY m.importance DESC
270 LIMIT ?2
271 "#,
272 )
273 .map_err(|e| crate::Error::Agent(format!("Failed to prepare search: {}", e)))?;
274
275 let entries = stmt
276 .query_map(params![query, limit as i64], |row| {
277 let embedding_blob: Option<Vec<u8>> = row.get(2)?;
278 let embedding = embedding_blob.as_ref().map(|blob| {
279 let len = blob.len() / std::mem::size_of::<f32>();
280 let mut vec = Vec::with_capacity(len);
281 for chunk in blob.chunks(std::mem::size_of::<f32>()) {
282 let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
283 vec.push(f32::from_le_bytes(bytes));
284 }
285 vec
286 });
287
288 let metadata_str: Option<String> = row.get(4)?;
289 let metadata: std::collections::HashMap<String, serde_json::Value> =
290 metadata_str
291 .and_then(|s| serde_json::from_str(&s).ok())
292 .unwrap_or_default();
293
294 let created_at_str: String = row.get(5)?;
295 let last_accessed_str: Option<String> = row.get(6)?;
296
297 Ok(MemoryEntry {
298 id: row.get(0)?,
299 content: row.get(1)?,
300 embedding,
301 memory_type: Self::string_to_memory_type(&row.get::<_, String>(3)?),
302 metadata,
303 created_at: chrono::DateTime::parse_from_rfc3339(&created_at_str)
304 .map(|dt| dt.with_timezone(&chrono::Utc))
305 .unwrap_or_else(|_| chrono::Utc::now()),
306 last_accessed: last_accessed_str
307 .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
308 .map(|dt| dt.with_timezone(&chrono::Utc)),
309 importance: row.get(7)?,
310 access_count: row.get(8)?,
311 })
312 })
313 .map_err(|e| crate::Error::Agent(format!("Failed to search memories: {}", e)))?;
314
315 let mut results = Vec::new();
316 for entry in entries {
317 results.push(
318 entry.map_err(|e| {
319 crate::Error::Agent(format!("Failed to parse entry: {}", e))
320 })?,
321 );
322 }
323
324 Ok(results)
325 })
326 .await
327 .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
328 }
329
330 async fn search_by_embedding(
331 &self,
332 embedding: &[f32],
333 limit: usize,
334 threshold: f32,
335 ) -> Result<Vec<MemoryEntry>> {
336 let conn = self.conn.clone();
339 let embedding = embedding.to_vec();
340
341 tokio::task::spawn_blocking(move || {
342 let conn = conn.blocking_lock();
343
344 let mut stmt = conn
345 .prepare(
346 "SELECT id, content, embedding, memory_type, metadata, created_at,
347 last_accessed, importance, access_count
348 FROM memories
349 WHERE embedding IS NOT NULL
350 ORDER BY importance DESC",
351 )
352 .map_err(|e| {
353 crate::Error::Agent(format!("Failed to prepare embedding search: {}", e))
354 })?;
355
356 let entries = stmt
357 .query_map([], |row| {
358 let embedding_blob: Vec<u8> = row.get(2)?;
359 let stored_embedding: Vec<f32> = {
360 let len = embedding_blob.len() / std::mem::size_of::<f32>();
361 let mut vec = Vec::with_capacity(len);
362 for chunk in embedding_blob.chunks(std::mem::size_of::<f32>()) {
363 let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
364 vec.push(f32::from_le_bytes(bytes));
365 }
366 vec
367 };
368
369 let metadata_str: Option<String> = row.get(4)?;
370 let metadata: std::collections::HashMap<String, serde_json::Value> =
371 metadata_str
372 .and_then(|s| serde_json::from_str(&s).ok())
373 .unwrap_or_default();
374
375 let created_at_str: String = row.get(5)?;
376 let last_accessed_str: Option<String> = row.get(6)?;
377
378 let entry = MemoryEntry {
379 id: row.get(0)?,
380 content: row.get(1)?,
381 embedding: Some(stored_embedding.clone()),
382 memory_type: Self::string_to_memory_type(&row.get::<_, String>(3)?),
383 metadata,
384 created_at: chrono::DateTime::parse_from_rfc3339(&created_at_str)
385 .map(|dt| dt.with_timezone(&chrono::Utc))
386 .unwrap_or_else(|_| chrono::Utc::now()),
387 last_accessed: last_accessed_str
388 .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
389 .map(|dt| dt.with_timezone(&chrono::Utc)),
390 importance: row.get(7)?,
391 access_count: row.get(8)?,
392 };
393
394 Ok((entry, stored_embedding))
395 })
396 .map_err(|e| {
397 crate::Error::Agent(format!("Failed to search by embedding: {}", e))
398 })?;
399
400 let mut results: Vec<_> = entries
402 .filter_map(|r| r.ok())
403 .map(|(entry, stored)| {
404 let similarity = cosine_similarity(&embedding, &stored);
405 (entry, similarity)
406 })
407 .filter(|(_, sim)| *sim >= threshold)
408 .collect();
409
410 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
411 results.truncate(limit);
412
413 Ok(results.into_iter().map(|(entry, _)| entry).collect())
414 })
415 .await
416 .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
417 }
418
419 async fn ids(&self) -> Result<Vec<String>> {
420 let conn = self.conn.clone();
421
422 tokio::task::spawn_blocking(move || {
423 let conn = conn.blocking_lock();
424
425 let mut stmt = conn
426 .prepare("SELECT id FROM memories ORDER BY created_at DESC")
427 .map_err(|e| crate::Error::Agent(format!("Failed to prepare ids: {}", e)))?;
428
429 let ids = stmt
430 .query_map([], |row| row.get(0))
431 .map_err(|e| crate::Error::Agent(format!("Failed to get ids: {}", e)))?;
432
433 let mut results = Vec::new();
434 for id in ids {
435 results.push(
436 id.map_err(|e| crate::Error::Agent(format!("Failed to parse id: {}", e)))?,
437 );
438 }
439
440 Ok(results)
441 })
442 .await
443 .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
444 }
445
446 async fn count(&self) -> Result<usize> {
447 let conn = self.conn.clone();
448
449 tokio::task::spawn_blocking(move || {
450 let conn = conn.blocking_lock();
451
452 let count: i64 = conn
453 .query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))
454 .map_err(|e| crate::Error::Agent(format!("Failed to count memories: {}", e)))?;
455
456 Ok(count as usize)
457 })
458 .await
459 .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
460 }
461
462 async fn update(&self, entry: MemoryEntry) -> Result<()> {
463 let conn = self.conn.clone();
464
465 tokio::task::spawn_blocking(move || {
466 let conn = conn.blocking_lock();
467
468 let embedding_bytes = entry.embedding.as_ref().map(|v| {
469 let len = v.len() * std::mem::size_of::<f32>();
470 let mut bytes = Vec::with_capacity(len);
471 for &f in v {
472 bytes.extend_from_slice(&f.to_le_bytes());
473 }
474 bytes
475 });
476
477 conn.execute(
478 r#"
479 UPDATE memories SET
480 content = ?2,
481 embedding = ?3,
482 memory_type = ?4,
483 metadata = ?5,
484 last_accessed = ?6,
485 importance = ?7,
486 access_count = ?8
487 WHERE id = ?1
488 "#,
489 params![
490 entry.id,
491 entry.content,
492 embedding_bytes,
493 Self::memory_type_to_string(&entry.memory_type),
494 if entry.metadata.is_empty() {
495 None::<String>
496 } else {
497 Some(serde_json::to_string(&entry.metadata).unwrap_or_default())
498 },
499 entry.last_accessed.map(|t| t.to_rfc3339()),
500 entry.importance,
501 entry.access_count,
502 ],
503 )
504 .map_err(|e| crate::Error::Agent(format!("Failed to update memory: {}", e)))?;
505
506 Ok(())
507 })
508 .await
509 .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
510 }
511
512 async fn clear(&self) -> Result<()> {
513 let conn = self.conn.clone();
514
515 tokio::task::spawn_blocking(move || {
516 let conn = conn.blocking_lock();
517
518 conn.execute("DELETE FROM memories", [])
519 .map_err(|e| crate::Error::Agent(format!("Failed to clear memories: {}", e)))?;
520
521 Ok(())
522 })
523 .await
524 .map_err(|e| crate::Error::Agent(format!("Task join error: {}", e)))?
525 }
526}
527
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either one
/// has zero magnitude. Otherwise the result lies in `[-1.0, 1.0]` (or is NaN
/// if an input component is NaN or infinite).
///
/// Sums are accumulated in `f64`: squaring very small or very large `f32`
/// components underflows to zero or overflows to infinity in `f32`, which
/// would otherwise report `0.0` or NaN for perfectly aligned vectors.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // One pass over both slices: dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let cosine = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());

    // Clamp away rounding drift outside the mathematical range; narrowing a
    // value in [-1, 1] to f32 only loses precision.
    cosine.clamp(-1.0, 1.0) as f32
}
544
#[cfg(test)]
mod tests {
    use super::*;

    /// An entry can be read back using the id returned by `add`.
    #[tokio::test]
    async fn test_sqlite_store_basic() {
        let store = SqliteStore::new().expect("Failed to create store");

        let id = store
            .add(MemoryEntry::new("This is a test memory"))
            .await
            .expect("Failed to add");

        let retrieved = store
            .get(&id)
            .await
            .expect("Failed to get")
            .expect("Entry should exist after add");
        assert_eq!(retrieved.id, id);
        assert_eq!(retrieved.content, "This is a test memory");
    }

    /// A deleted entry is no longer returned by `get`.
    #[tokio::test]
    async fn test_sqlite_store_delete() {
        let store = SqliteStore::new().expect("Failed to create store");

        let id = store
            .add(MemoryEntry::new("Memory to delete"))
            .await
            .expect("Failed to add");

        store.delete(&id).await.expect("Failed to delete");

        let retrieved = store.get(&id).await.expect("Failed to get");
        assert!(retrieved.is_none());
    }

    /// Full-text search returns exactly the entries containing the term.
    #[tokio::test]
    async fn test_sqlite_store_search() {
        let store = SqliteStore::new().expect("Failed to create store");

        // Setup failures must fail the test rather than be silently discarded.
        for content in [
            "Rust programming language",
            "Python machine learning",
            "Rust async programming",
        ] {
            store
                .add(MemoryEntry::new(content))
                .await
                .expect("Failed to add");
        }

        let results = store.search("Rust", 10).await.expect("Failed to search");
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|entry| entry.content.contains("Rust")));
    }

    /// `count` reflects the number of stored entries, and `clear` empties the store.
    #[tokio::test]
    async fn test_sqlite_store_count() {
        let store = SqliteStore::new().expect("Failed to create store");

        store.clear().await.expect("Failed to clear");
        assert_eq!(store.count().await.expect("Failed to count"), 0);

        store
            .add(MemoryEntry::new("Test 1"))
            .await
            .expect("Failed to add");
        store
            .add(MemoryEntry::new("Test 2"))
            .await
            .expect("Failed to add");

        let count = store.count().await.expect("Failed to count");
        assert_eq!(count, 2);
    }

    /// An entry is found by an identical query embedding and keeps its embedding.
    #[tokio::test]
    async fn test_sqlite_store_embedding() {
        let store = SqliteStore::new().expect("Failed to create store");

        let entry =
            MemoryEntry::new("Test with embedding").with_embedding(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        store.add(entry).await.expect("Failed to add");

        let query_embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let results = store
            .search_by_embedding(&query_embedding, 10, 0.9)
            .await
            .expect("Failed to search by embedding");

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].content, "Test with embedding");
        assert!(results[0].embedding.is_some());
    }
}