1use async_trait::async_trait;
25use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions};
26use std::collections::HashMap;
27use std::path::Path;
28use std::str::FromStr;
29use std::sync::Arc;
30
31use cortexai_core::errors::MemoryError;
32use cortexai_core::Message;
33
34use crate::persistence::{MemoryBackend, Session};
35use crate::vector_store::{SearchResult, VectorDocument, VectorStore};
36
/// Connection and pool settings used to build a `SqliteStore`.
#[derive(Debug, Clone)]
pub struct SqliteConfig {
    /// Database location: a filesystem path, or `":memory:"` for a transient database.
    pub path: String,
    /// `true` selects the WAL journal mode, `false` the rollback-journal `DELETE` mode.
    pub wal_mode: bool,
    /// Maximum number of pooled connections.
    pub max_connections: u32,
    /// Seconds a connection waits on a locked database before giving up.
    pub busy_timeout_secs: u64,
    /// Create the database when it does not exist yet.
    pub create_if_missing: bool,
    /// Enforce `FOREIGN KEY` constraints on every connection.
    pub foreign_keys: bool,
}

impl Default for SqliteConfig {
    /// `agent_data.db` in the working directory, WAL on, 5 connections,
    /// 5-second busy timeout, auto-create and foreign keys enabled.
    fn default() -> Self {
        SqliteConfig {
            path: String::from("agent_data.db"),
            max_connections: 5,
            busy_timeout_secs: 5,
            wal_mode: true,
            create_if_missing: true,
            foreign_keys: true,
        }
    }
}

impl SqliteConfig {
    /// Default settings pointed at `path`.
    pub fn new(path: impl Into<String>) -> Self {
        let path: String = path.into();
        let mut config = Self::default();
        config.path = path;
        config
    }

    /// Settings for a transient in-memory database (WAL disabled).
    pub fn in_memory() -> Self {
        Self {
            wal_mode: false,
            ..Self::new(":memory:")
        }
    }
}
83
84pub struct SqliteStore {
86 pool: SqlitePool,
87 #[allow(dead_code)]
88 config: SqliteConfig,
89}
90
91impl SqliteStore {
92 pub async fn new(config: SqliteConfig) -> Result<Self, MemoryError> {
94 let options = SqliteConnectOptions::from_str(&format!("sqlite://{}?mode=rwc", config.path))
95 .map_err(|e| MemoryError::StorageError(e.to_string()))?
96 .journal_mode(if config.wal_mode {
97 SqliteJournalMode::Wal
98 } else {
99 SqliteJournalMode::Delete
100 })
101 .create_if_missing(config.create_if_missing)
102 .foreign_keys(config.foreign_keys)
103 .busy_timeout(std::time::Duration::from_secs(config.busy_timeout_secs));
104
105 let pool = SqlitePoolOptions::new()
106 .max_connections(config.max_connections)
107 .connect_with(options)
108 .await
109 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
110
111 let store = Self { pool, config };
112 store.initialize_schema().await?;
113
114 Ok(store)
115 }
116
117 pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self, MemoryError> {
119 let config = SqliteConfig::new(path.as_ref().to_string_lossy().to_string());
120 Self::new(config).await
121 }
122
123 pub async fn in_memory() -> Result<Self, MemoryError> {
125 Self::new(SqliteConfig::in_memory()).await
126 }
127
128 async fn initialize_schema(&self) -> Result<(), MemoryError> {
130 sqlx::query(
132 r#"
133 CREATE TABLE IF NOT EXISTS sessions (
134 id TEXT PRIMARY KEY,
135 agent_id TEXT NOT NULL,
136 messages TEXT NOT NULL,
137 metadata TEXT NOT NULL,
138 created_at INTEGER NOT NULL,
139 updated_at INTEGER NOT NULL,
140 resume_token TEXT
141 )
142 "#,
143 )
144 .execute(&self.pool)
145 .await
146 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
147
148 sqlx::query(
150 r#"
151 CREATE INDEX IF NOT EXISTS idx_sessions_resume_token
152 ON sessions(resume_token) WHERE resume_token IS NOT NULL
153 "#,
154 )
155 .execute(&self.pool)
156 .await
157 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
158
159 sqlx::query(
161 r#"
162 CREATE INDEX IF NOT EXISTS idx_sessions_agent_id
163 ON sessions(agent_id)
164 "#,
165 )
166 .execute(&self.pool)
167 .await
168 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
169
170 sqlx::query(
172 r#"
173 CREATE TABLE IF NOT EXISTS vector_documents (
174 id TEXT PRIMARY KEY,
175 content TEXT NOT NULL,
176 embedding BLOB NOT NULL,
177 metadata TEXT NOT NULL,
178 source_id TEXT,
179 chunk_index INTEGER
180 )
181 "#,
182 )
183 .execute(&self.pool)
184 .await
185 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
186
187 sqlx::query(
189 r#"
190 CREATE INDEX IF NOT EXISTS idx_vector_docs_source
191 ON vector_documents(source_id) WHERE source_id IS NOT NULL
192 "#,
193 )
194 .execute(&self.pool)
195 .await
196 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
197
198 sqlx::query(
200 r#"
201 CREATE TABLE IF NOT EXISTS agent_memory (
202 id TEXT PRIMARY KEY,
203 agent_id TEXT NOT NULL,
204 content TEXT NOT NULL,
205 embedding BLOB,
206 tags TEXT,
207 created_at INTEGER NOT NULL
208 )
209 "#,
210 )
211 .execute(&self.pool)
212 .await
213 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
214
215 sqlx::query(
216 r#"
217 CREATE INDEX IF NOT EXISTS idx_agent_memory_agent
218 ON agent_memory(agent_id)
219 "#,
220 )
221 .execute(&self.pool)
222 .await
223 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
224
225 tracing::info!("SQLite schema initialized");
226 Ok(())
227 }
228
229 pub fn pool(&self) -> &SqlitePool {
231 &self.pool
232 }
233
234 pub fn session_backend(self: &Arc<Self>) -> SqliteSessionBackend {
236 SqliteSessionBackend {
237 store: Arc::clone(self),
238 }
239 }
240
241 pub fn vector_store(self: &Arc<Self>, dimension: usize) -> SqliteVectorStore {
243 SqliteVectorStore {
244 store: Arc::clone(self),
245 dimension,
246 }
247 }
248}
249
250pub struct SqliteSessionBackend {
252 store: Arc<SqliteStore>,
253}
254
255#[async_trait]
256impl MemoryBackend for SqliteSessionBackend {
257 async fn save_session(&self, session: &Session) -> Result<(), MemoryError> {
258 let messages_json = serde_json::to_string(&session.messages)
259 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
260 let metadata_json = serde_json::to_string(&session.metadata)
261 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
262
263 sqlx::query(
264 r#"
265 INSERT OR REPLACE INTO sessions
266 (id, agent_id, messages, metadata, created_at, updated_at, resume_token)
267 VALUES (?, ?, ?, ?, ?, ?, ?)
268 "#,
269 )
270 .bind(&session.id)
271 .bind(&session.agent_id)
272 .bind(&messages_json)
273 .bind(&metadata_json)
274 .bind(session.created_at)
275 .bind(session.updated_at)
276 .bind(&session.resume_token)
277 .execute(&self.store.pool)
278 .await
279 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
280
281 Ok(())
282 }
283
284 async fn load_session(&self, session_id: &str) -> Result<Option<Session>, MemoryError> {
285 let row: Option<SessionRow> = sqlx::query_as(
286 "SELECT id, agent_id, messages, metadata, created_at, updated_at, resume_token FROM sessions WHERE id = ?",
287 )
288 .bind(session_id)
289 .fetch_optional(&self.store.pool)
290 .await
291 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
292
293 row.map(|r| r.into_session()).transpose()
294 }
295
296 async fn load_by_resume_token(&self, token: &str) -> Result<Option<Session>, MemoryError> {
297 let row: Option<SessionRow> = sqlx::query_as(
298 "SELECT id, agent_id, messages, metadata, created_at, updated_at, resume_token FROM sessions WHERE resume_token = ?",
299 )
300 .bind(token)
301 .fetch_optional(&self.store.pool)
302 .await
303 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
304
305 row.map(|r| r.into_session()).transpose()
306 }
307
308 async fn delete_session(&self, session_id: &str) -> Result<(), MemoryError> {
309 sqlx::query("DELETE FROM sessions WHERE id = ?")
310 .bind(session_id)
311 .execute(&self.store.pool)
312 .await
313 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
314
315 Ok(())
316 }
317
318 async fn list_sessions(&self, agent_id: &str) -> Result<Vec<Session>, MemoryError> {
319 let rows: Vec<SessionRow> = sqlx::query_as(
320 "SELECT id, agent_id, messages, metadata, created_at, updated_at, resume_token FROM sessions WHERE agent_id = ?",
321 )
322 .bind(agent_id)
323 .fetch_all(&self.store.pool)
324 .await
325 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
326
327 rows.into_iter().map(|r| r.into_session()).collect()
328 }
329
330 async fn list_session_ids(&self, agent_id: &str) -> Result<Vec<String>, MemoryError> {
331 let rows: Vec<(String,)> = sqlx::query_as("SELECT id FROM sessions WHERE agent_id = ?")
332 .bind(agent_id)
333 .fetch_all(&self.store.pool)
334 .await
335 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
336
337 Ok(rows.into_iter().map(|(id,)| id).collect())
338 }
339
340 async fn session_exists(&self, session_id: &str) -> Result<bool, MemoryError> {
341 let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM sessions WHERE id = ?")
342 .bind(session_id)
343 .fetch_one(&self.store.pool)
344 .await
345 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
346
347 Ok(count.0 > 0)
348 }
349
350 async fn clear_agent_sessions(&self, agent_id: &str) -> Result<usize, MemoryError> {
351 let result = sqlx::query("DELETE FROM sessions WHERE agent_id = ?")
352 .bind(agent_id)
353 .execute(&self.store.pool)
354 .await
355 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
356
357 Ok(result.rows_affected() as usize)
358 }
359
360 fn backend_name(&self) -> &'static str {
361 "sqlite"
362 }
363}
364
365#[derive(sqlx::FromRow)]
367struct SessionRow {
368 id: String,
369 agent_id: String,
370 messages: String,
371 metadata: String,
372 created_at: i64,
373 updated_at: i64,
374 resume_token: Option<String>,
375}
376
377impl SessionRow {
378 fn into_session(self) -> Result<Session, MemoryError> {
379 let messages: Vec<Message> = serde_json::from_str(&self.messages)
380 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
381 let metadata: HashMap<String, serde_json::Value> = serde_json::from_str(&self.metadata)
382 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
383
384 Ok(Session {
385 id: self.id,
386 agent_id: self.agent_id,
387 messages,
388 metadata,
389 created_at: self.created_at,
390 updated_at: self.updated_at,
391 resume_token: self.resume_token,
392 })
393 }
394}
395
396pub struct SqliteVectorStore {
398 store: Arc<SqliteStore>,
399 dimension: usize,
400}
401
402impl SqliteVectorStore {
403 fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
405 embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
406 }
407
408 fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
410 bytes
411 .chunks_exact(4)
412 .map(|chunk| {
413 let arr: [u8; 4] = chunk.try_into().unwrap();
414 f32::from_le_bytes(arr)
415 })
416 .collect()
417 }
418
419 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
421 if a.len() != b.len() {
422 return 0.0;
423 }
424
425 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
426 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
427 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
428
429 if norm_a == 0.0 || norm_b == 0.0 {
430 return 0.0;
431 }
432
433 dot_product / (norm_a * norm_b)
434 }
435}
436
437#[async_trait]
438impl VectorStore for SqliteVectorStore {
439 async fn insert(&self, doc: VectorDocument) -> Result<(), MemoryError> {
440 if doc.embedding.len() != self.dimension {
441 return Err(MemoryError::StorageError(format!(
442 "Embedding dimension mismatch: expected {}, got {}",
443 self.dimension,
444 doc.embedding.len()
445 )));
446 }
447
448 let embedding_bytes = Self::embedding_to_bytes(&doc.embedding);
449 let metadata_json = serde_json::to_string(&doc.metadata)
450 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
451
452 sqlx::query(
453 r#"
454 INSERT OR REPLACE INTO vector_documents
455 (id, content, embedding, metadata, source_id, chunk_index)
456 VALUES (?, ?, ?, ?, ?, ?)
457 "#,
458 )
459 .bind(&doc.id)
460 .bind(&doc.content)
461 .bind(&embedding_bytes)
462 .bind(&metadata_json)
463 .bind(&doc.source_id)
464 .bind(doc.chunk_index.map(|i| i as i64))
465 .execute(&self.store.pool)
466 .await
467 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
468
469 Ok(())
470 }
471
472 async fn search(
473 &self,
474 query_embedding: &[f32],
475 top_k: usize,
476 ) -> Result<Vec<SearchResult>, MemoryError> {
477 if query_embedding.len() != self.dimension {
478 return Err(MemoryError::StorageError(format!(
479 "Query embedding dimension mismatch: expected {}, got {}",
480 self.dimension,
481 query_embedding.len()
482 )));
483 }
484
485 let rows: Vec<VectorDocRow> = sqlx::query_as(
487 "SELECT id, content, embedding, metadata, source_id, chunk_index FROM vector_documents",
488 )
489 .fetch_all(&self.store.pool)
490 .await
491 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
492
493 let mut results: Vec<SearchResult> = rows
494 .into_iter()
495 .filter_map(|row| {
496 let doc = row.into_document().ok()?;
497 let score = Self::cosine_similarity(query_embedding, &doc.embedding);
498 Some(SearchResult {
499 document: doc,
500 score,
501 })
502 })
503 .collect();
504
505 results.sort_by(|a, b| {
507 b.score
508 .partial_cmp(&a.score)
509 .unwrap_or(std::cmp::Ordering::Equal)
510 });
511 results.truncate(top_k);
512
513 Ok(results)
514 }
515
516 async fn get(&self, id: &str) -> Result<Option<VectorDocument>, MemoryError> {
517 let row: Option<VectorDocRow> = sqlx::query_as(
518 "SELECT id, content, embedding, metadata, source_id, chunk_index FROM vector_documents WHERE id = ?",
519 )
520 .bind(id)
521 .fetch_optional(&self.store.pool)
522 .await
523 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
524
525 row.map(|r| r.into_document()).transpose()
526 }
527
528 async fn delete(&self, id: &str) -> Result<bool, MemoryError> {
529 let result = sqlx::query("DELETE FROM vector_documents WHERE id = ?")
530 .bind(id)
531 .execute(&self.store.pool)
532 .await
533 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
534
535 Ok(result.rows_affected() > 0)
536 }
537
538 async fn delete_by_source(&self, source_id: &str) -> Result<usize, MemoryError> {
539 let result = sqlx::query("DELETE FROM vector_documents WHERE source_id = ?")
540 .bind(source_id)
541 .execute(&self.store.pool)
542 .await
543 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
544
545 Ok(result.rows_affected() as usize)
546 }
547
548 async fn count(&self) -> Result<usize, MemoryError> {
549 let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM vector_documents")
550 .fetch_one(&self.store.pool)
551 .await
552 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
553
554 Ok(count.0 as usize)
555 }
556
557 async fn clear(&self) -> Result<(), MemoryError> {
558 sqlx::query("DELETE FROM vector_documents")
559 .execute(&self.store.pool)
560 .await
561 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
562
563 Ok(())
564 }
565
566 fn name(&self) -> &'static str {
567 "sqlite"
568 }
569}
570
571#[derive(sqlx::FromRow)]
573struct VectorDocRow {
574 id: String,
575 content: String,
576 embedding: Vec<u8>,
577 metadata: String,
578 source_id: Option<String>,
579 chunk_index: Option<i64>,
580}
581
582impl VectorDocRow {
583 fn into_document(self) -> Result<VectorDocument, MemoryError> {
584 let embedding = SqliteVectorStore::bytes_to_embedding(&self.embedding);
585 let metadata: HashMap<String, serde_json::Value> = serde_json::from_str(&self.metadata)
586 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
587
588 Ok(VectorDocument {
589 id: self.id,
590 content: self.content,
591 embedding,
592 metadata,
593 source_id: self.source_id,
594 chunk_index: self.chunk_index.map(|i| i as usize),
595 })
596 }
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_sqlite_store_creation() {
        let store = SqliteStore::in_memory().await.unwrap();
        assert!(store.pool().acquire().await.is_ok());
    }

    /// Two simultaneously held pool connections must see the same in-memory schema.
    #[tokio::test]
    async fn test_sqlite_in_memory_shared_across_pool_connections() {
        let store = SqliteStore::in_memory().await.unwrap();
        let mut first = store.pool().acquire().await.unwrap();
        let mut second = store.pool().acquire().await.unwrap();

        let (n1,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM sessions")
            .fetch_one(&mut *first)
            .await
            .unwrap();
        let (n2,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM sessions")
            .fetch_one(&mut *second)
            .await
            .unwrap();
        assert_eq!((n1, n2), (0, 0));
    }

    #[tokio::test]
    async fn test_sqlite_session_backend() {
        let store = Arc::new(SqliteStore::in_memory().await.unwrap());
        let backend = store.session_backend();

        let mut session = Session::new("test-agent");
        session.set_metadata("key", serde_json::json!("value"));

        backend.save_session(&session).await.unwrap();

        let loaded = backend.load_session(&session.id).await.unwrap().unwrap();
        assert_eq!(loaded.agent_id, "test-agent");
        assert_eq!(
            loaded.get_metadata("key"),
            Some(&serde_json::json!("value"))
        );

        let sessions = backend.list_sessions("test-agent").await.unwrap();
        assert_eq!(sessions.len(), 1);

        backend.delete_session(&session.id).await.unwrap();
        assert!(!backend.session_exists(&session.id).await.unwrap());
    }

    #[tokio::test]
    async fn test_sqlite_session_upsert_replaces_existing() {
        let store = Arc::new(SqliteStore::in_memory().await.unwrap());
        let backend = store.session_backend();

        let mut session = Session::new("agent-upsert");
        backend.save_session(&session).await.unwrap();
        session.set_metadata("k", serde_json::json!(1));
        backend.save_session(&session).await.unwrap();

        let all = backend.list_sessions("agent-upsert").await.unwrap();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].get_metadata("k"), Some(&serde_json::json!(1)));
    }

    #[tokio::test]
    async fn test_sqlite_list_session_ids_and_unknown_lookups() {
        let store = Arc::new(SqliteStore::in_memory().await.unwrap());
        let backend = store.session_backend();

        let session1 = Session::new("agent-ids");
        let session2 = Session::new("agent-ids");
        backend.save_session(&session1).await.unwrap();
        backend.save_session(&session2).await.unwrap();

        let mut ids = backend.list_session_ids("agent-ids").await.unwrap();
        ids.sort();
        let mut expected = vec![session1.id.clone(), session2.id.clone()];
        expected.sort();
        assert_eq!(ids, expected);

        assert!(backend.session_exists(&session1.id).await.unwrap());
        assert!(backend.load_session("no-such-session").await.unwrap().is_none());
        assert!(backend
            .load_by_resume_token("no-such-token")
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_sqlite_session_resume_token() {
        let store = Arc::new(SqliteStore::in_memory().await.unwrap());
        let backend = store.session_backend();

        let mut session = Session::new("test-agent");
        let token = session.generate_resume_token();
        backend.save_session(&session).await.unwrap();

        let resumed = backend.load_by_resume_token(&token).await.unwrap().unwrap();
        assert_eq!(resumed.id, session.id);
    }

    #[tokio::test]
    async fn test_sqlite_vector_store() {
        let store = Arc::new(SqliteStore::in_memory().await.unwrap());
        let vector_store = store.vector_store(4);

        let doc1 = VectorDocument::new("doc1", "Hello world", vec![1.0, 0.0, 0.0, 0.0]);
        let doc2 = VectorDocument::new("doc2", "Goodbye", vec![0.0, 1.0, 0.0, 0.0]);

        vector_store.insert(doc1).await.unwrap();
        vector_store.insert(doc2).await.unwrap();

        assert_eq!(vector_store.count().await.unwrap(), 2);

        let results = vector_store.search(&[1.0, 0.0, 0.0, 0.0], 1).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].document.id, "doc1");

        vector_store.delete("doc1").await.unwrap();
        assert_eq!(vector_store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_sqlite_vector_store_get_and_clear() {
        let store = Arc::new(SqliteStore::in_memory().await.unwrap());
        let vector_store = store.vector_store(4);

        let doc = VectorDocument::new("chunk1", "Part 1", vec![0.5, -1.0, 0.0, 2.0])
            .with_source("src", 3);
        vector_store.insert(doc).await.unwrap();

        let loaded = vector_store.get("chunk1").await.unwrap().unwrap();
        assert_eq!(loaded.id, "chunk1");
        assert_eq!(loaded.content, "Part 1");
        assert_eq!(loaded.embedding, vec![0.5, -1.0, 0.0, 2.0]);
        assert_eq!(loaded.source_id.as_deref(), Some("src"));

        assert!(vector_store.get("missing").await.unwrap().is_none());
        assert!(!vector_store.delete("missing").await.unwrap());

        vector_store.clear().await.unwrap();
        assert_eq!(vector_store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_sqlite_vector_store_rejects_wrong_dimension() {
        let store = Arc::new(SqliteStore::in_memory().await.unwrap());
        let vector_store = store.vector_store(4);

        let bad = VectorDocument::new("bad", "short", vec![1.0, 0.0]);
        assert!(vector_store.insert(bad).await.is_err());
        assert!(vector_store.search(&[1.0, 0.0], 1).await.is_err());
        assert_eq!(vector_store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_sqlite_vector_store_by_source() {
        let store = Arc::new(SqliteStore::in_memory().await.unwrap());
        let vector_store = store.vector_store(4);

        let doc1 = VectorDocument::new("chunk1", "Part 1", vec![1.0, 0.0, 0.0, 0.0])
            .with_source("doc1", 0);
        let doc2 = VectorDocument::new("chunk2", "Part 2", vec![0.0, 1.0, 0.0, 0.0])
            .with_source("doc1", 1);
        let doc3 =
            VectorDocument::new("other", "Other", vec![0.0, 0.0, 1.0, 0.0]).with_source("doc2", 0);

        vector_store.insert(doc1).await.unwrap();
        vector_store.insert(doc2).await.unwrap();
        vector_store.insert(doc3).await.unwrap();

        let deleted = vector_store.delete_by_source("doc1").await.unwrap();
        assert_eq!(deleted, 2);
        assert_eq!(vector_store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_sqlite_clear_agent_sessions() {
        let store = Arc::new(SqliteStore::in_memory().await.unwrap());
        let backend = store.session_backend();

        let session1 = Session::new("agent-1");
        let session2 = Session::new("agent-1");
        let session3 = Session::new("agent-2");

        backend.save_session(&session1).await.unwrap();
        backend.save_session(&session2).await.unwrap();
        backend.save_session(&session3).await.unwrap();

        let cleared = backend.clear_agent_sessions("agent-1").await.unwrap();
        assert_eq!(cleared, 2);

        let remaining = backend.list_sessions("agent-1").await.unwrap();
        assert_eq!(remaining.len(), 0);

        let agent2_sessions = backend.list_sessions("agent-2").await.unwrap();
        assert_eq!(agent2_sessions.len(), 1);
    }
}