Skip to main content

offline_intelligence/memory_db/
mod.rs

1// "D:\_ProjectWorks\AUDIO_Interface\Server\src\memory_db\mod.rs"
2//! Memory database module - SQLite-based storage for conversations, summaries, and embeddings
3
4pub mod schema;
5pub mod migration;
6pub mod conversation_store;
7pub mod summary_store;
8pub mod embedding_store;
9
10// Re-export commonly used types
11pub use schema::*;
12pub use migration::MigrationManager;
13pub use conversation_store::ConversationStore;
14pub use summary_store::SummaryStore;
15pub use embedding_store::{EmbeddingStore, EmbeddingStats};
16
17use std::path::Path;
18use std::sync::Arc;
19use r2d2::Pool;
20use r2d2_sqlite::SqliteConnectionManager;
21use tracing::info;
22use crate::cache_management::cache_extractor::KVEntry;
23use crate::cache_management::cache_manager::SessionCacheState;
24
25/// Main database manager that coordinates all stores
26pub struct MemoryDatabase {
27    pub conversations: ConversationStore,
28    pub summaries: SummaryStore,
29    pub embeddings: EmbeddingStore,
30    pool: Arc<Pool<SqliteConnectionManager>>,
31}
32
33/// Transaction manager for atomic operations across stores
34pub struct Transaction<'a> {
35    conn: r2d2::PooledConnection<SqliteConnectionManager>,
36    _marker: std::marker::PhantomData<&'a MemoryDatabase>,
37}
38
39impl<'a> Transaction<'a> {
40    /// Commit the transaction
41    pub fn commit(self) -> anyhow::Result<()> {
42        // Changes are automatically committed when the connection is dropped
43        Ok(())
44    }
45
46    /// Rollback the transaction
47    pub fn rollback(self) -> anyhow::Result<()> {
48        // SQLite auto-rolls back on DROP if not committed
49        Ok(())
50    }
51
52    /// Get raw connection for store operations
53    pub fn connection(&mut self) -> &mut rusqlite::Connection {
54        &mut self.conn
55    }
56}
57
58impl MemoryDatabase {
59    /// Create a new memory database at the specified path
60    pub fn new(db_path: &Path) -> anyhow::Result<Self> {
61        info!("Opening memory database at: {}", db_path.display());
62
63        if let Some(parent) = db_path.parent() {
64            std::fs::create_dir_all(parent)?;
65        }
66
67        let manager = SqliteConnectionManager::file(db_path)
68            .with_flags(
69                rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
70                | rusqlite::OpenFlags::SQLITE_OPEN_CREATE
71                | rusqlite::OpenFlags::SQLITE_OPEN_FULL_MUTEX,
72            );
73
74        let pool = Pool::builder()
75            .max_size(10)
76            .build(manager)
77            .map_err(|e| anyhow::anyhow!("Failed to create connection pool: {}", e))?;
78
79        // Initialize DB and pragmas - FIXED: Use mutable connection
80        {
81            let mut conn = pool.get()?;
82            let mut migrator = migration::MigrationManager::new(&mut conn);
83            migrator.initialize_database()?;
84
85            conn.execute_batch(
86                "PRAGMA foreign_keys = ON;
87                 PRAGMA journal_mode = WAL;
88                 PRAGMA synchronous = NORMAL;
89                 PRAGMA busy_timeout = 5000;",
90            )?;
91        }
92
93        let pool = Arc::new(pool);
94
95        info!("Memory database initialized successfully");
96
97        Ok(Self {
98            conversations: ConversationStore::new(Arc::clone(&pool)),
99            summaries: SummaryStore::new(Arc::clone(&pool)),
100            embeddings: EmbeddingStore::new(Arc::clone(&pool)),
101            pool,
102        })
103    }
104
105    /// Create an in-memory database (useful for testing)
106    pub fn new_in_memory() -> anyhow::Result<Self> {
107        let manager = SqliteConnectionManager::memory();
108        let pool = Pool::builder()
109            .max_size(5)
110            .build(manager)?;
111
112        {
113            let conn = pool.get()?;
114            conn.execute_batch(schema::SCHEMA_SQL)?;
115        }
116
117        let pool = Arc::new(pool);
118
119        Ok(Self {
120            conversations: ConversationStore::new(Arc::clone(&pool)),
121            summaries: SummaryStore::new(Arc::clone(&pool)),
122            embeddings: EmbeddingStore::new(Arc::clone(&pool)),
123            pool,
124        })
125    }
126
127    /// Begin a transaction for atomic operations
128    pub fn begin_transaction(&self) -> anyhow::Result<Transaction<'_>> {
129        let conn = self.pool.get()?;
130        conn.execute_batch("BEGIN IMMEDIATE TRANSACTION;")?;
131        Ok(Transaction {
132            conn,
133            _marker: std::marker::PhantomData,
134        })
135    }
136
137    /// Execute operations in a transaction
138    pub fn with_transaction<T, F>(&self, f: F) -> anyhow::Result<T>
139    where
140        F: FnOnce(&mut Transaction<'_>) -> anyhow::Result<T>,
141    {
142        let mut tx = self.begin_transaction()?;
143        match f(&mut tx) {
144            Ok(result) => {
145                tx.commit()?;
146                Ok(result)
147            }
148            Err(e) => {
149                tx.rollback()?;
150                Err(e)
151            }
152        }
153    }
154
155    /// Get database statistics
156    pub fn get_stats(&self) -> anyhow::Result<DatabaseStats> {
157        let conn = self.pool.get()?;
158        Ok(migration::get_database_stats(&conn)?)
159    }
160
161    /// Cleanup old data (older than specified days)
162    pub fn cleanup_old_data(&self, older_than_days: i32) -> anyhow::Result<usize> {
163        let mut conn = self.pool.get()?;
164        let mut migrator = migration::MigrationManager::new(&mut conn);
165        Ok(migrator.cleanup_old_data(older_than_days)?)
166    }
167
168    /// Create a KV snapshot
169    pub async fn create_kv_snapshot(
170        &self,
171        session_id: &str,
172        entries: &[KVEntry],
173    ) -> anyhow::Result<i64> {
174        use blake3;
175
176        let mut conn = self.pool.get()?;  // FIXED: Added mut
177        let tx = conn.transaction()?;
178        
179        // Calculate total size
180        let total_size_bytes: usize = entries.iter()
181            .map(|entry| entry.value_data.len())
182            .sum();
183        
184        // Serialize entries to BLOB
185        let kv_state = bincode::serialize(entries)?;
186        let kv_state_hash = blake3::hash(&kv_state).to_string();
187        
188        // Get the latest message ID for this session
189        let message_id: i64 = tx.query_row(
190            "SELECT COALESCE(MAX(id), 0) FROM messages WHERE session_id = ?1",
191            [session_id],
192            |row| row.get(0),
193        )?;
194        
195        // Insert snapshot
196        tx.execute(
197            "INSERT INTO kv_snapshots 
198             (session_id, message_id, kv_state, kv_state_hash, size_bytes)
199             VALUES (?1, ?2, ?3, ?4, ?5)",
200            rusqlite::params![session_id, message_id, kv_state, kv_state_hash, total_size_bytes as i64],
201        )?;
202        
203        let snapshot_id = tx.last_insert_rowid();
204        
205        // Insert individual cache entries
206        for entry in entries {
207            tx.execute(
208                "INSERT INTO kv_cache_entries 
209                 (snapshot_id, key_hash, key_data, value_data, key_type, 
210                  layer_index, head_index, importance_score, access_count)
211                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
212                rusqlite::params![
213                    snapshot_id,
214                    &entry.key_hash,
215                    entry.key_data.as_deref(),
216                    &entry.value_data,
217                    &entry.key_type,
218                    entry.layer_index,
219                    entry.head_index,
220                    entry.importance_score,
221                    entry.access_count,
222                ],
223            )?;
224        }
225        
226        // Update metadata
227        let now = chrono::Utc::now().to_rfc3339();
228        tx.execute(
229            "INSERT OR REPLACE INTO kv_cache_metadata 
230             (session_id, total_entries, total_size_bytes, last_cleared_at)
231             VALUES (?1, ?2, ?3, ?4)",
232            rusqlite::params![session_id, entries.len() as i64, total_size_bytes as i64, &now],
233        )?;
234        
235        tx.commit()?;
236        
237        Ok(snapshot_id)
238    }
239    
240    /// Get recent KV snapshots for a session
241    pub async fn get_recent_kv_snapshots(
242        &self,
243        session_id: &str,
244        limit: usize,
245    ) -> anyhow::Result<Vec<crate::cache_management::cache_manager::KvSnapshot>> {
246        let conn = self.pool.get()?;
247        let mut stmt = conn.prepare(
248            "SELECT id, session_id, message_id, snapshot_type, size_bytes, created_at
249             FROM kv_snapshots 
250             WHERE session_id = ?1 
251             ORDER BY created_at DESC 
252             LIMIT ?2"
253        )?;
254        
255        let mut rows = stmt.query(rusqlite::params![session_id, limit as i64])?;
256        let mut snapshots = Vec::new();
257        
258        while let Some(row) = rows.next()? {
259            let created_at_str: String = row.get(5)?;
260            let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str)
261                .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
262                .with_timezone(&chrono::Utc);
263            
264            snapshots.push(crate::cache_management::cache_manager::KvSnapshot {
265                id: row.get(0)?,
266                session_id: row.get(1)?,
267                message_id: row.get(2)?,
268                snapshot_type: row.get(3)?,
269                size_bytes: row.get(4)?,
270                created_at,
271            });
272        }
273        
274        Ok(snapshots)
275    }
276    
277    /// Get KV snapshot entries
278    pub async fn get_kv_snapshot_entries(
279        &self,
280        snapshot_id: i64,
281    ) -> anyhow::Result<Vec<KVEntry>> {
282        let conn = self.pool.get()?;
283        let mut stmt = conn.prepare(
284            "SELECT key_hash, key_data, value_data, key_type, layer_index, 
285                    head_index, importance_score, access_count, last_accessed
286             FROM kv_cache_entries 
287             WHERE snapshot_id = ?1"
288        )?;
289        
290        let mut rows = stmt.query([snapshot_id])?;
291        let mut entries = Vec::new();
292        
293        while let Some(row) = rows.next()? {
294            let last_accessed_str: String = row.get(8)?;
295            let last_accessed = chrono::DateTime::parse_from_rfc3339(&last_accessed_str)
296                .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
297                .with_timezone(&chrono::Utc);
298            
299            entries.push(KVEntry {
300                key_hash: row.get(0)?,
301                key_data: row.get(1)?,
302                value_data: row.get(2)?,
303                key_type: row.get(3)?,
304                layer_index: row.get(4)?,
305                head_index: row.get(5)?,
306                importance_score: row.get(6)?,
307                access_count: row.get(7)?,
308                last_accessed,
309            });
310        }
311        
312        Ok(entries)
313    }
314    
315    /// Search messages by keywords (for ConversationStore)
316    pub async fn search_messages_by_keywords(
317        &self,
318        session_id: &str,
319        keywords: &[String],
320        limit: usize,
321    ) -> anyhow::Result<Vec<StoredMessage>> {
322        // Simple keyword search using LIKE pattern
323        let patterns: Vec<String> = keywords.iter()
324            .map(|k| format!("%{}%", k))
325            .collect();
326        
327        let conn = self.pool.get()?;
328        
329        // Build query with multiple LIKE conditions
330        let mut query = String::from(
331            "SELECT id, session_id, message_index, role, content, tokens, 
332                    timestamp, importance_score, embedding_generated
333             FROM messages 
334             WHERE session_id = ?1"
335        );
336        
337        for _ in &patterns {
338            query.push_str(" AND content LIKE ?");
339        }
340        
341        query.push_str(" ORDER BY timestamp DESC LIMIT ?");
342        
343        let mut stmt = conn.prepare(&query)?;
344        
345        // Build parameters: session_id + patterns + limit
346        let mut params: Vec<&dyn rusqlite::ToSql> = Vec::new();
347        params.push(&session_id);
348        for pattern in &patterns {
349            params.push(pattern);
350        }
351        // FIX: Store in variable to avoid temporary reference
352        let limit_i64 = limit as i64;
353        params.push(&limit_i64);
354        
355        let mut rows = stmt.query(rusqlite::params_from_iter(params))?;
356        let mut messages = Vec::new();
357        
358        while let Some(row) = rows.next()? {
359            let timestamp_str: String = row.get(6)?;
360            let timestamp = chrono::DateTime::parse_from_rfc3339(&timestamp_str)
361                .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
362                .with_timezone(&chrono::Utc);
363            
364            messages.push(StoredMessage {
365                id: row.get(0)?,
366                session_id: row.get(1)?,
367                message_index: row.get(2)?,
368                role: row.get(3)?,
369                content: row.get(4)?,
370                tokens: row.get(5)?,
371                timestamp,
372                importance_score: row.get(7)?,
373                embedding_generated: row.get(8)?,
374            });
375        }
376        
377        Ok(messages)
378    }
379    
380    /// Update KV cache metadata
381    pub async fn update_kv_cache_metadata(
382        &self,
383        session_id: &str,
384        state: &SessionCacheState,
385    ) -> anyhow::Result<()> {
386        let conn = self.pool.get()?;
387        let metadata_json = serde_json::to_string(&state.metadata)?;
388        
389        conn.execute(
390            "INSERT OR REPLACE INTO kv_cache_metadata 
391             (session_id, total_entries, total_size_bytes, conversation_count, metadata)
392             VALUES (?1, ?2, ?3, ?4, ?5)",
393            rusqlite::params![
394                session_id,
395                state.entry_count as i64,
396                state.cache_size_bytes as i64,
397                state.conversation_count as i64,
398                metadata_json,
399            ],
400        )?;
401        
402        Ok(())
403    }
404    
405    /// Cleanup session snapshots
406    pub async fn cleanup_session_snapshots(
407        &self,
408        session_id: &str,
409    ) -> anyhow::Result<()> {
410        let conn = self.pool.get()?;
411        
412        conn.execute(
413            "DELETE FROM kv_snapshots WHERE session_id = ?1",
414            [session_id],
415        )?;
416        
417        conn.execute(
418            "DELETE FROM kv_cache_metadata WHERE session_id = ?1",
419            [session_id],
420        )?;
421        
422        Ok(())
423    }
424    
425    /// Prune old KV snapshots
426    pub async fn prune_old_kv_snapshots(
427        &self,
428        keep_max: usize,
429    ) -> anyhow::Result<usize> {
430        let conn = self.pool.get()?;
431        
432        // Get snapshot IDs to delete (keep only the latest keep_max per session)
433        let mut stmt = conn.prepare(
434            "SELECT ks.id 
435             FROM kv_snapshots ks
436             WHERE (
437                 SELECT COUNT(*) 
438                 FROM kv_snapshots ks2 
439                 WHERE ks2.session_id = ks.session_id 
440                 AND ks2.created_at >= ks.created_at
441             ) > ?1"
442        )?;
443        
444        let ids_to_delete: Vec<i64> = stmt
445            .query_map([keep_max as i64], |row| row.get(0))?
446            .collect::<rusqlite::Result<Vec<_>>>()?;
447        
448        if ids_to_delete.is_empty() {
449            return Ok(0);
450        }
451        
452        // Delete snapshots
453        let placeholders = vec!["?"; ids_to_delete.len()].join(",");
454        let query = format!("DELETE FROM kv_snapshots WHERE id IN ({})", placeholders);
455        
456        let mut stmt = conn.prepare(&query)?;
457        let deleted = stmt.execute(rusqlite::params_from_iter(&ids_to_delete))?;
458        
459        Ok(deleted)
460    }
461}
462
463impl Drop for MemoryDatabase {
464    fn drop(&mut self) {
465        // Perform a final checkpoint on shutdown
466        if let Ok(conn) = self.pool.get() {
467            let _ = conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);");
468        }
469    }
470}