1pub mod schema;
5pub mod migration;
6pub mod conversation_store;
7pub mod summary_store;
8pub mod embedding_store;
9
10pub use schema::*;
12pub use migration::MigrationManager;
13pub use conversation_store::ConversationStore;
14pub use summary_store::SummaryStore;
15pub use embedding_store::{EmbeddingStore, EmbeddingStats};
16
17use std::path::Path;
18use std::sync::Arc;
19use r2d2::Pool;
20use r2d2_sqlite::SqliteConnectionManager;
21use tracing::info;
22use crate::cache_management::cache_extractor::KVEntry;
23use crate::cache_management::cache_manager::SessionCacheState;
24
/// Top-level handle to the SQLite-backed memory database.
///
/// Holds one r2d2 connection pool (`pool`) and three per-domain stores. The
/// constructors in this file build each store from an `Arc::clone` of that same
/// pool, so all of them share the pool's connections.
pub struct MemoryDatabase {
    /// Conversation store, built over a clone of `pool`.
    pub conversations: ConversationStore,
    /// Summary store, built over a clone of `pool`.
    pub summaries: SummaryStore,
    /// Embedding store, built over a clone of `pool`.
    pub embeddings: EmbeddingStore,
    /// Pool used directly by the methods on `MemoryDatabase` itself
    /// (transactions, stats, KV-snapshot helpers).
    pool: Arc<Pool<SqliteConnectionManager>>,
}
32
/// A transaction running on one pooled connection, created by
/// `MemoryDatabase::begin_transaction` (which executes `BEGIN IMMEDIATE TRANSACTION`).
///
/// The lifetime `'a` ties the transaction to the `MemoryDatabase` it was opened
/// from, so it cannot outlive that database handle.
pub struct Transaction<'a> {
    /// Connection checked out of the pool; r2d2 returns it to the pool when dropped.
    conn: r2d2::PooledConnection<SqliteConnectionManager>,
    /// Zero-sized borrow of the owning database for `'a`.
    _marker: std::marker::PhantomData<&'a MemoryDatabase>,
}
38
39impl<'a> Transaction<'a> {
40 pub fn commit(self) -> anyhow::Result<()> {
42 Ok(())
44 }
45
46 pub fn rollback(self) -> anyhow::Result<()> {
48 Ok(())
50 }
51
52 pub fn connection(&mut self) -> &mut rusqlite::Connection {
54 &mut self.conn
55 }
56}
57
58impl MemoryDatabase {
59 pub fn new(db_path: &Path) -> anyhow::Result<Self> {
61 info!("Opening memory database at: {}", db_path.display());
62
63 if let Some(parent) = db_path.parent() {
64 std::fs::create_dir_all(parent)?;
65 }
66
67 let manager = SqliteConnectionManager::file(db_path)
68 .with_flags(
69 rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
70 | rusqlite::OpenFlags::SQLITE_OPEN_CREATE
71 | rusqlite::OpenFlags::SQLITE_OPEN_FULL_MUTEX,
72 );
73
74 let pool = Pool::builder()
75 .max_size(10)
76 .build(manager)
77 .map_err(|e| anyhow::anyhow!("Failed to create connection pool: {}", e))?;
78
79 {
81 let mut conn = pool.get()?;
82 let mut migrator = migration::MigrationManager::new(&mut conn);
83 migrator.initialize_database()?;
84
85 conn.execute_batch(
86 "PRAGMA foreign_keys = ON;
87 PRAGMA journal_mode = WAL;
88 PRAGMA synchronous = NORMAL;
89 PRAGMA busy_timeout = 5000;",
90 )?;
91 }
92
93 let pool = Arc::new(pool);
94
95 info!("Memory database initialized successfully");
96
97 Ok(Self {
98 conversations: ConversationStore::new(Arc::clone(&pool)),
99 summaries: SummaryStore::new(Arc::clone(&pool)),
100 embeddings: EmbeddingStore::new(Arc::clone(&pool)),
101 pool,
102 })
103 }
104
105 pub fn new_in_memory() -> anyhow::Result<Self> {
107 let manager = SqliteConnectionManager::memory();
108 let pool = Pool::builder()
109 .max_size(5)
110 .build(manager)?;
111
112 {
113 let conn = pool.get()?;
114 conn.execute_batch(schema::SCHEMA_SQL)?;
115 }
116
117 let pool = Arc::new(pool);
118
119 Ok(Self {
120 conversations: ConversationStore::new(Arc::clone(&pool)),
121 summaries: SummaryStore::new(Arc::clone(&pool)),
122 embeddings: EmbeddingStore::new(Arc::clone(&pool)),
123 pool,
124 })
125 }
126
127 pub fn begin_transaction(&self) -> anyhow::Result<Transaction<'_>> {
129 let conn = self.pool.get()?;
130 conn.execute_batch("BEGIN IMMEDIATE TRANSACTION;")?;
131 Ok(Transaction {
132 conn,
133 _marker: std::marker::PhantomData,
134 })
135 }
136
137 pub fn with_transaction<T, F>(&self, f: F) -> anyhow::Result<T>
139 where
140 F: FnOnce(&mut Transaction<'_>) -> anyhow::Result<T>,
141 {
142 let mut tx = self.begin_transaction()?;
143 match f(&mut tx) {
144 Ok(result) => {
145 tx.commit()?;
146 Ok(result)
147 }
148 Err(e) => {
149 tx.rollback()?;
150 Err(e)
151 }
152 }
153 }
154
155 pub fn get_stats(&self) -> anyhow::Result<DatabaseStats> {
157 let conn = self.pool.get()?;
158 Ok(migration::get_database_stats(&conn)?)
159 }
160
161 pub fn cleanup_old_data(&self, older_than_days: i32) -> anyhow::Result<usize> {
163 let mut conn = self.pool.get()?;
164 let mut migrator = migration::MigrationManager::new(&mut conn);
165 Ok(migrator.cleanup_old_data(older_than_days)?)
166 }
167
168 pub async fn create_kv_snapshot(
170 &self,
171 session_id: &str,
172 entries: &[KVEntry],
173 ) -> anyhow::Result<i64> {
174 use blake3;
175
176 let mut conn = self.pool.get()?; let tx = conn.transaction()?;
178
179 let total_size_bytes: usize = entries.iter()
181 .map(|entry| entry.value_data.len())
182 .sum();
183
184 let kv_state = bincode::serialize(entries)?;
186 let kv_state_hash = blake3::hash(&kv_state).to_string();
187
188 let message_id: i64 = tx.query_row(
190 "SELECT COALESCE(MAX(id), 0) FROM messages WHERE session_id = ?1",
191 [session_id],
192 |row| row.get(0),
193 )?;
194
195 tx.execute(
197 "INSERT INTO kv_snapshots
198 (session_id, message_id, kv_state, kv_state_hash, size_bytes)
199 VALUES (?1, ?2, ?3, ?4, ?5)",
200 rusqlite::params![session_id, message_id, kv_state, kv_state_hash, total_size_bytes as i64],
201 )?;
202
203 let snapshot_id = tx.last_insert_rowid();
204
205 for entry in entries {
207 tx.execute(
208 "INSERT INTO kv_cache_entries
209 (snapshot_id, key_hash, key_data, value_data, key_type,
210 layer_index, head_index, importance_score, access_count)
211 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
212 rusqlite::params![
213 snapshot_id,
214 &entry.key_hash,
215 entry.key_data.as_deref(),
216 &entry.value_data,
217 &entry.key_type,
218 entry.layer_index,
219 entry.head_index,
220 entry.importance_score,
221 entry.access_count,
222 ],
223 )?;
224 }
225
226 let now = chrono::Utc::now().to_rfc3339();
228 tx.execute(
229 "INSERT OR REPLACE INTO kv_cache_metadata
230 (session_id, total_entries, total_size_bytes, last_cleared_at)
231 VALUES (?1, ?2, ?3, ?4)",
232 rusqlite::params![session_id, entries.len() as i64, total_size_bytes as i64, &now],
233 )?;
234
235 tx.commit()?;
236
237 Ok(snapshot_id)
238 }
239
240 pub async fn get_recent_kv_snapshots(
242 &self,
243 session_id: &str,
244 limit: usize,
245 ) -> anyhow::Result<Vec<crate::cache_management::cache_manager::KvSnapshot>> {
246 let conn = self.pool.get()?;
247 let mut stmt = conn.prepare(
248 "SELECT id, session_id, message_id, snapshot_type, size_bytes, created_at
249 FROM kv_snapshots
250 WHERE session_id = ?1
251 ORDER BY created_at DESC
252 LIMIT ?2"
253 )?;
254
255 let mut rows = stmt.query(rusqlite::params![session_id, limit as i64])?;
256 let mut snapshots = Vec::new();
257
258 while let Some(row) = rows.next()? {
259 let created_at_str: String = row.get(5)?;
260 let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str)
261 .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
262 .with_timezone(&chrono::Utc);
263
264 snapshots.push(crate::cache_management::cache_manager::KvSnapshot {
265 id: row.get(0)?,
266 session_id: row.get(1)?,
267 message_id: row.get(2)?,
268 snapshot_type: row.get(3)?,
269 size_bytes: row.get(4)?,
270 created_at,
271 });
272 }
273
274 Ok(snapshots)
275 }
276
277 pub async fn get_kv_snapshot_entries(
279 &self,
280 snapshot_id: i64,
281 ) -> anyhow::Result<Vec<KVEntry>> {
282 let conn = self.pool.get()?;
283 let mut stmt = conn.prepare(
284 "SELECT key_hash, key_data, value_data, key_type, layer_index,
285 head_index, importance_score, access_count, last_accessed
286 FROM kv_cache_entries
287 WHERE snapshot_id = ?1"
288 )?;
289
290 let mut rows = stmt.query([snapshot_id])?;
291 let mut entries = Vec::new();
292
293 while let Some(row) = rows.next()? {
294 let last_accessed_str: String = row.get(8)?;
295 let last_accessed = chrono::DateTime::parse_from_rfc3339(&last_accessed_str)
296 .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
297 .with_timezone(&chrono::Utc);
298
299 entries.push(KVEntry {
300 key_hash: row.get(0)?,
301 key_data: row.get(1)?,
302 value_data: row.get(2)?,
303 key_type: row.get(3)?,
304 layer_index: row.get(4)?,
305 head_index: row.get(5)?,
306 importance_score: row.get(6)?,
307 access_count: row.get(7)?,
308 last_accessed,
309 });
310 }
311
312 Ok(entries)
313 }
314
315 pub async fn search_messages_by_keywords(
317 &self,
318 session_id: &str,
319 keywords: &[String],
320 limit: usize,
321 ) -> anyhow::Result<Vec<StoredMessage>> {
322 let patterns: Vec<String> = keywords.iter()
324 .map(|k| format!("%{}%", k))
325 .collect();
326
327 let conn = self.pool.get()?;
328
329 let mut query = String::from(
331 "SELECT id, session_id, message_index, role, content, tokens,
332 timestamp, importance_score, embedding_generated
333 FROM messages
334 WHERE session_id = ?1"
335 );
336
337 for _ in &patterns {
338 query.push_str(" AND content LIKE ?");
339 }
340
341 query.push_str(" ORDER BY timestamp DESC LIMIT ?");
342
343 let mut stmt = conn.prepare(&query)?;
344
345 let mut params: Vec<&dyn rusqlite::ToSql> = Vec::new();
347 params.push(&session_id);
348 for pattern in &patterns {
349 params.push(pattern);
350 }
351 let limit_i64 = limit as i64;
353 params.push(&limit_i64);
354
355 let mut rows = stmt.query(rusqlite::params_from_iter(params))?;
356 let mut messages = Vec::new();
357
358 while let Some(row) = rows.next()? {
359 let timestamp_str: String = row.get(6)?;
360 let timestamp = chrono::DateTime::parse_from_rfc3339(×tamp_str)
361 .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
362 .with_timezone(&chrono::Utc);
363
364 messages.push(StoredMessage {
365 id: row.get(0)?,
366 session_id: row.get(1)?,
367 message_index: row.get(2)?,
368 role: row.get(3)?,
369 content: row.get(4)?,
370 tokens: row.get(5)?,
371 timestamp,
372 importance_score: row.get(7)?,
373 embedding_generated: row.get(8)?,
374 });
375 }
376
377 Ok(messages)
378 }
379
380 pub async fn update_kv_cache_metadata(
382 &self,
383 session_id: &str,
384 state: &SessionCacheState,
385 ) -> anyhow::Result<()> {
386 let conn = self.pool.get()?;
387 let metadata_json = serde_json::to_string(&state.metadata)?;
388
389 conn.execute(
390 "INSERT OR REPLACE INTO kv_cache_metadata
391 (session_id, total_entries, total_size_bytes, conversation_count, metadata)
392 VALUES (?1, ?2, ?3, ?4, ?5)",
393 rusqlite::params![
394 session_id,
395 state.entry_count as i64,
396 state.cache_size_bytes as i64,
397 state.conversation_count as i64,
398 metadata_json,
399 ],
400 )?;
401
402 Ok(())
403 }
404
405 pub async fn cleanup_session_snapshots(
407 &self,
408 session_id: &str,
409 ) -> anyhow::Result<()> {
410 let conn = self.pool.get()?;
411
412 conn.execute(
413 "DELETE FROM kv_snapshots WHERE session_id = ?1",
414 [session_id],
415 )?;
416
417 conn.execute(
418 "DELETE FROM kv_cache_metadata WHERE session_id = ?1",
419 [session_id],
420 )?;
421
422 Ok(())
423 }
424
425 pub async fn prune_old_kv_snapshots(
427 &self,
428 keep_max: usize,
429 ) -> anyhow::Result<usize> {
430 let conn = self.pool.get()?;
431
432 let mut stmt = conn.prepare(
434 "SELECT ks.id
435 FROM kv_snapshots ks
436 WHERE (
437 SELECT COUNT(*)
438 FROM kv_snapshots ks2
439 WHERE ks2.session_id = ks.session_id
440 AND ks2.created_at >= ks.created_at
441 ) > ?1"
442 )?;
443
444 let ids_to_delete: Vec<i64> = stmt
445 .query_map([keep_max as i64], |row| row.get(0))?
446 .collect::<rusqlite::Result<Vec<_>>>()?;
447
448 if ids_to_delete.is_empty() {
449 return Ok(0);
450 }
451
452 let placeholders = vec!["?"; ids_to_delete.len()].join(",");
454 let query = format!("DELETE FROM kv_snapshots WHERE id IN ({})", placeholders);
455
456 let mut stmt = conn.prepare(&query)?;
457 let deleted = stmt.execute(rusqlite::params_from_iter(&ids_to_delete))?;
458
459 Ok(deleted)
460 }
461}
462
463impl Drop for MemoryDatabase {
464 fn drop(&mut self) {
465 if let Ok(conn) = self.pool.get() {
467 let _ = conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);");
468 }
469 }
470}