1pub mod schema;
5pub mod migration;
6pub mod conversation_store;
7pub mod embedding_store;
8pub mod local_files_store;
9pub mod all_files_store;
10pub mod api_keys_store;
11pub mod users_store;
12pub mod session_file_contexts_store;
13pub mod session_summaries_store;
14
15pub use schema::*;
17pub use migration::MigrationManager;
18pub use conversation_store::ConversationStore;
19pub use embedding_store::{EmbeddingStore, EmbeddingStats};
20pub use local_files_store::{LocalFilesStore, LocalFile, LocalFileTree};
21pub use all_files_store::{AllFilesStore, AllFile, AllFileTree};
22pub use api_keys_store::{ApiKeysStore, ApiKeyType, ApiKeyRecord, Encryption};
23pub use users_store::{UsersStore, User};
24pub use session_file_contexts_store::{SessionFileContextsStore, SessionFileContext, AttachmentRef};
25pub use session_summaries_store::SessionSummariesStore;
26
27use std::path::Path;
28use std::sync::Arc;
29use r2d2::Pool;
30use r2d2_sqlite::SqliteConnectionManager;
31use tracing::info;
32use crate::cache_management::cache_extractor::KVEntry;
33use crate::cache_management::cache_manager::SessionCacheState;
34
/// Handle to the application's SQLite-backed memory database.
///
/// Every store below is constructed from a clone of the same connection pool, so they all
/// read and write the same database.
pub struct MemoryDatabase {
    pub conversations: ConversationStore,
    pub embeddings: EmbeddingStore,
    /// Constructed with the `Aud.io` application-data directory.
    pub local_files: LocalFilesStore,
    /// Constructed with the `all_files` subdirectory of the `Aud.io` application-data directory.
    pub all_files: AllFilesStore,
    /// Its schema is initialized at construction; a failure there is logged, not returned.
    pub api_keys: ApiKeysStore,
    /// Its schema is initialized at construction; a failure there is logged, not returned.
    pub users: UsersStore,
    pub session_file_contexts: SessionFileContextsStore,
    pub session_summaries: SessionSummariesStore,
    // Shared pool; also used directly for transactions, KV-cache snapshots and maintenance.
    pool: Arc<Pool<SqliteConnectionManager>>,
}
47
48pub struct Transaction<'a> {
50 conn: r2d2::PooledConnection<SqliteConnectionManager>,
51 _marker: std::marker::PhantomData<&'a MemoryDatabase>,
52}
53
54impl<'a> Transaction<'a> {
55 pub fn commit(self) -> anyhow::Result<()> {
57 Ok(())
59 }
60
61 pub fn rollback(self) -> anyhow::Result<()> {
63 Ok(())
65 }
66
67 pub fn connection(&mut self) -> &mut rusqlite::Connection {
69 &mut self.conn
70 }
71}
72
73impl MemoryDatabase {
74 pub fn new(db_path: &Path) -> anyhow::Result<Self> {
76 info!("Opening memory database at: {}", db_path.display());
77
78 if let Some(parent) = db_path.parent() {
79 std::fs::create_dir_all(parent)?;
80 }
81
82 let manager = SqliteConnectionManager::file(db_path)
83 .with_flags(
84 rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
85 | rusqlite::OpenFlags::SQLITE_OPEN_CREATE
86 | rusqlite::OpenFlags::SQLITE_OPEN_FULL_MUTEX,
87 );
88
89 let pool = Pool::builder()
90 .max_size(20) .build(manager)
92 .map_err(|e| anyhow::anyhow!("Failed to create connection pool: {}", e))?;
93
94 {
96 let mut conn = pool.get()?;
97 let mut migrator = migration::MigrationManager::new(&mut conn);
98 migrator.initialize_database()?;
99
100 conn.execute_batch(
101 "PRAGMA foreign_keys = ON;
102 PRAGMA journal_mode = WAL;
103 PRAGMA synchronous = NORMAL;
104 PRAGMA busy_timeout = 5000;",
105 )?;
106 }
107
108 let pool = Arc::new(pool);
109
110 let app_data_dir = dirs::data_dir()
112 .unwrap_or_else(|| std::path::PathBuf::from("."))
113 .join("Aud.io");
114
115 let all_files_dir = app_data_dir.join("all_files");
117
118 let api_keys = ApiKeysStore::new(Arc::clone(&pool));
120 if let Err(e) = api_keys.initialize_schema() {
121 tracing::warn!("Failed to initialize API keys schema: {}", e);
122 }
123
124 let users = UsersStore::new(Arc::clone(&pool));
126 if let Err(e) = users.initialize_schema() {
127 tracing::warn!("Failed to initialize users schema: {}", e);
128 }
129
130 info!("Memory database initialized successfully");
131
132 Ok(Self {
133 conversations: ConversationStore::new(Arc::clone(&pool)),
134 embeddings: EmbeddingStore::new(Arc::clone(&pool)),
135 local_files: LocalFilesStore::new(Arc::clone(&pool), app_data_dir.clone()),
136 all_files: AllFilesStore::new(Arc::clone(&pool), all_files_dir),
137 api_keys,
138 users,
139 session_file_contexts: SessionFileContextsStore::new(Arc::clone(&pool)),
140 session_summaries: SessionSummariesStore::new(Arc::clone(&pool)),
141 pool,
142 })
143 }
144
145 pub fn new_in_memory() -> anyhow::Result<Self> {
147 let manager = SqliteConnectionManager::memory();
148 let pool = Pool::builder()
149 .max_size(10) .build(manager)?;
151
152 {
153 let conn = pool.get()?;
154 conn.execute_batch(schema::SCHEMA_SQL)?;
155 }
156
157 let pool = Arc::new(pool);
158
159 let app_data_dir = dirs::data_dir()
161 .unwrap_or_else(|| std::path::PathBuf::from("."))
162 .join("Aud.io");
163
164 let all_files_dir = app_data_dir.join("all_files");
166
167 let api_keys = ApiKeysStore::new(Arc::clone(&pool));
169 if let Err(e) = api_keys.initialize_schema() {
170 tracing::warn!("Failed to initialize API keys schema (in-memory): {}", e);
171 }
172
173 let users = UsersStore::new(Arc::clone(&pool));
175 if let Err(e) = users.initialize_schema() {
176 tracing::warn!("Failed to initialize users schema (in-memory): {}", e);
177 }
178
179 Ok(Self {
180 conversations: ConversationStore::new(Arc::clone(&pool)),
181 embeddings: EmbeddingStore::new(Arc::clone(&pool)),
182 local_files: LocalFilesStore::new(Arc::clone(&pool), app_data_dir.clone()),
183 all_files: AllFilesStore::new(Arc::clone(&pool), all_files_dir),
184 api_keys,
185 users,
186 session_file_contexts: SessionFileContextsStore::new(Arc::clone(&pool)),
187 session_summaries: SessionSummariesStore::new(Arc::clone(&pool)),
188 pool,
189 })
190 }
191
192 pub fn begin_transaction(&self) -> anyhow::Result<Transaction<'_>> {
194 let conn = self.pool.get()?;
195 conn.execute_batch("BEGIN IMMEDIATE TRANSACTION;")?;
196 Ok(Transaction {
197 conn,
198 _marker: std::marker::PhantomData,
199 })
200 }
201
202 pub fn with_transaction<T, F>(&self, f: F) -> anyhow::Result<T>
204 where
205 F: FnOnce(&mut Transaction<'_>) -> anyhow::Result<T>,
206 {
207 let mut tx = self.begin_transaction()?;
208 match f(&mut tx) {
209 Ok(result) => {
210 tx.commit()?;
211 Ok(result)
212 }
213 Err(e) => {
214 tx.rollback()?;
215 Err(e)
216 }
217 }
218 }
219
220 pub fn get_stats(&self) -> anyhow::Result<DatabaseStats> {
222 let conn = self.pool.get()?;
223 Ok(migration::get_database_stats(&conn)?)
224 }
225
226 pub fn cleanup_old_data(&self, older_than_days: i32) -> anyhow::Result<usize> {
228 let mut conn = self.pool.get()?;
229 let mut migrator = migration::MigrationManager::new(&mut conn);
230 Ok(migrator.cleanup_old_data(older_than_days)?)
231 }
232
233 pub async fn create_kv_snapshot(
235 &self,
236 session_id: &str,
237 entries: &[KVEntry],
238 ) -> anyhow::Result<i64> {
239 use blake3;
240
241 let mut conn = self.pool.get()?; let tx = conn.transaction()?;
243
244 let total_size_bytes: usize = entries.iter()
246 .map(|entry| entry.value_data.len())
247 .sum();
248
249 let kv_state = bincode::serialize(entries)?;
251 let kv_state_hash = blake3::hash(&kv_state).to_string();
252
253 let message_id: i64 = tx.query_row(
255 "SELECT COALESCE(MAX(id), 0) FROM messages WHERE session_id = ?1",
256 [session_id],
257 |row| row.get(0),
258 )?;
259
260 tx.execute(
262 "INSERT INTO kv_snapshots
263 (session_id, message_id, kv_state, kv_state_hash, size_bytes)
264 VALUES (?1, ?2, ?3, ?4, ?5)",
265 rusqlite::params![session_id, message_id, kv_state, kv_state_hash, total_size_bytes as i64],
266 )?;
267
268 let snapshot_id = tx.last_insert_rowid();
269
270 for entry in entries {
272 tx.execute(
273 "INSERT INTO kv_cache_entries
274 (snapshot_id, key_hash, key_data, value_data, key_type,
275 layer_index, head_index, importance_score, access_count)
276 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
277 rusqlite::params![
278 snapshot_id,
279 &entry.key_hash,
280 entry.key_data.as_deref(),
281 &entry.value_data,
282 &entry.key_type,
283 entry.layer_index,
284 entry.head_index,
285 entry.importance_score,
286 entry.access_count,
287 ],
288 )?;
289 }
290
291 let now = chrono::Utc::now().to_rfc3339();
293 tx.execute(
294 "INSERT OR REPLACE INTO kv_cache_metadata
295 (session_id, total_entries, total_size_bytes, last_cleared_at)
296 VALUES (?1, ?2, ?3, ?4)",
297 rusqlite::params![session_id, entries.len() as i64, total_size_bytes as i64, &now],
298 )?;
299
300 tx.commit()?;
301
302 Ok(snapshot_id)
303 }
304
305 pub async fn get_recent_kv_snapshots(
307 &self,
308 session_id: &str,
309 limit: usize,
310 ) -> anyhow::Result<Vec<crate::cache_management::cache_manager::KvSnapshot>> {
311 let conn = self.pool.get()?;
312 let mut stmt = conn.prepare(
313 "SELECT id, session_id, message_id, snapshot_type, size_bytes, created_at
314 FROM kv_snapshots
315 WHERE session_id = ?1
316 ORDER BY created_at DESC
317 LIMIT ?2"
318 )?;
319
320 let mut rows = stmt.query(rusqlite::params![session_id, limit as i64])?;
321 let mut snapshots = Vec::new();
322
323 while let Some(row) = rows.next()? {
324 let created_at_str: String = row.get(5)?;
325 let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str)
326 .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
327 .with_timezone(&chrono::Utc);
328
329 snapshots.push(crate::cache_management::cache_manager::KvSnapshot {
330 id: row.get(0)?,
331 session_id: row.get(1)?,
332 message_id: row.get(2)?,
333 snapshot_type: row.get(3)?,
334 size_bytes: row.get(4)?,
335 created_at,
336 });
337 }
338
339 Ok(snapshots)
340 }
341
342 pub async fn get_kv_snapshot_entries(
344 &self,
345 snapshot_id: i64,
346 ) -> anyhow::Result<Vec<KVEntry>> {
347 let conn = self.pool.get()?;
348 let mut stmt = conn.prepare(
349 "SELECT key_hash, key_data, value_data, key_type, layer_index,
350 head_index, importance_score, access_count, last_accessed
351 FROM kv_cache_entries
352 WHERE snapshot_id = ?1"
353 )?;
354
355 let mut rows = stmt.query([snapshot_id])?;
356 let mut entries = Vec::new();
357
358 while let Some(row) = rows.next()? {
359 let last_accessed_str: String = row.get(8)?;
360 let last_accessed = chrono::DateTime::parse_from_rfc3339(&last_accessed_str)
361 .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
362 .with_timezone(&chrono::Utc);
363
364 entries.push(KVEntry {
365 key_hash: row.get(0)?,
366 key_data: row.get(1)?,
367 value_data: row.get(2)?,
368 key_type: row.get(3)?,
369 layer_index: row.get(4)?,
370 head_index: row.get(5)?,
371 importance_score: row.get(6)?,
372 access_count: row.get(7)?,
373 last_accessed,
374 token_positions: None, embedding: None, size_bytes: { let val: Vec<u8> = row.get(2)?; val.len() as usize }, is_persistent: false, });
379 }
380
381 Ok(entries)
382 }
383
384 pub async fn search_messages_by_keywords(
386 &self,
387 session_id: &str,
388 keywords: &[String],
389 limit: usize,
390 ) -> anyhow::Result<Vec<StoredMessage>> {
391 let patterns: Vec<String> = keywords.iter()
393 .map(|k| format!("%{}%", k))
394 .collect();
395
396 let conn = self.pool.get()?;
397
398 let mut query = String::from(
400 "SELECT id, session_id, message_index, role, content, tokens,
401 timestamp, importance_score, embedding_generated
402 FROM messages
403 WHERE session_id = ?1"
404 );
405
406 for _ in &patterns {
407 query.push_str(" AND content LIKE ?");
408 }
409
410 query.push_str(" ORDER BY timestamp DESC LIMIT ?");
411
412 let mut stmt = conn.prepare(&query)?;
413
414 let mut params: Vec<&dyn rusqlite::ToSql> = Vec::new();
416 params.push(&session_id);
417 for pattern in &patterns {
418 params.push(pattern);
419 }
420 let limit_i64 = limit as i64;
422 params.push(&limit_i64);
423
424 let mut rows = stmt.query(rusqlite::params_from_iter(params))?;
425 let mut messages = Vec::new();
426
427 while let Some(row) = rows.next()? {
428 let timestamp_str: String = row.get(6)?;
429 let timestamp = chrono::DateTime::parse_from_rfc3339(×tamp_str)
430 .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
431 .with_timezone(&chrono::Utc);
432
433 messages.push(StoredMessage {
434 id: row.get(0)?,
435 session_id: row.get(1)?,
436 message_index: row.get(2)?,
437 role: row.get(3)?,
438 content: row.get(4)?,
439 tokens: row.get(5)?,
440 timestamp,
441 importance_score: row.get(7)?,
442 embedding_generated: row.get(8)?,
443 });
444 }
445
446 Ok(messages)
447 }
448
449 pub async fn update_kv_cache_metadata(
451 &self,
452 session_id: &str,
453 state: &SessionCacheState,
454 ) -> anyhow::Result<()> {
455 let conn = self.pool.get()?;
456 let metadata_json = serde_json::to_string(&state.metadata)?;
457
458 conn.execute(
459 "INSERT OR REPLACE INTO kv_cache_metadata
460 (session_id, total_entries, total_size_bytes, conversation_count, metadata)
461 VALUES (?1, ?2, ?3, ?4, ?5)",
462 rusqlite::params![
463 session_id,
464 state.entry_count as i64,
465 state.cache_size_bytes as i64,
466 state.conversation_count as i64,
467 metadata_json,
468 ],
469 )?;
470
471 Ok(())
472 }
473
474 pub async fn cleanup_session_snapshots(
476 &self,
477 session_id: &str,
478 ) -> anyhow::Result<()> {
479 let conn = self.pool.get()?;
480
481 conn.execute(
482 "DELETE FROM kv_snapshots WHERE session_id = ?1",
483 [session_id],
484 )?;
485
486 conn.execute(
487 "DELETE FROM kv_cache_metadata WHERE session_id = ?1",
488 [session_id],
489 )?;
490
491 Ok(())
492 }
493
494 pub async fn prune_old_kv_snapshots(
496 &self,
497 keep_max: usize,
498 ) -> anyhow::Result<usize> {
499 let conn = self.pool.get()?;
500
501 let mut stmt = conn.prepare(
503 "SELECT ks.id
504 FROM kv_snapshots ks
505 WHERE (
506 SELECT COUNT(*)
507 FROM kv_snapshots ks2
508 WHERE ks2.session_id = ks.session_id
509 AND ks2.created_at >= ks.created_at
510 ) > ?1"
511 )?;
512
513 let ids_to_delete: Vec<i64> = stmt
514 .query_map([keep_max as i64], |row| row.get(0))?
515 .collect::<rusqlite::Result<Vec<_>>>()?;
516
517 if ids_to_delete.is_empty() {
518 return Ok(0);
519 }
520
521 let placeholders = vec!["?"; ids_to_delete.len()].join(",");
523 let query = format!("DELETE FROM kv_snapshots WHERE id IN ({})", placeholders);
524
525 let mut stmt = conn.prepare(&query)?;
526 let deleted = stmt.execute(rusqlite::params_from_iter(&ids_to_delete))?;
527
528 Ok(deleted)
529 }
530
531 pub fn optimize(&self) -> anyhow::Result<()> {
537 let conn = self.pool.get()?;
538 conn.execute_batch(
539 "PRAGMA optimize;
540 PRAGMA wal_checkpoint(TRUNCATE);"
541 )?;
542 tracing::info!("SQLite optimize + WAL checkpoint completed");
543 Ok(())
544 }
545}
546
547impl Drop for MemoryDatabase {
548 fn drop(&mut self) {
549 if let Ok(conn) = self.pool.get() {
551 let _ = conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);");
552 }
553 }
554}