Skip to main content

offline_intelligence/memory_db/
migration.rs

1//! Database migration system
2
3use rusqlite::{Connection, OptionalExtension, Result};
4use std::path::Path;
5use tracing::{error, info, warn};
6
7// Import the schema module from the same memory_db module
8use crate::memory_db::schema;
9
10/// Manages database schema migrations
11pub struct MigrationManager<'a> {
12    conn: &'a mut Connection,
13}
14
15impl<'a> MigrationManager<'a> {
16    /// Create a new migration manager
17    pub fn new(conn: &'a mut Connection) -> Self {
18        Self { conn }
19    }
20
21    /// Initialize database with current schema
22    pub fn initialize_database(&mut self) -> Result<()> {
23        info!("Initializing memory database schema...");
24
25        // Create schema version table if it doesn't exist
26        self.conn.execute(
27            "CREATE TABLE IF NOT EXISTS schema_version (
28                version INTEGER PRIMARY KEY,
29                applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
30            )",
31            [],
32        )?;
33
34        // Get current version
35        let current_version: i32 = self
36            .conn
37            .query_row(
38                "SELECT COALESCE(MAX(version), 0) FROM schema_version",
39                [],
40                |row| row.get(0),
41            )
42            .unwrap_or(0);
43
44        info!("Current database schema version: {}", current_version);
45
46        // Apply migrations based on current version
47        self.apply_migrations(current_version)?;
48
49        Ok(())
50    }
51
52    /// Apply all pending migrations
53    fn apply_migrations(&mut self, current_version: i32) -> Result<()> {
54        let migrations = get_migrations();
55
56        for (version, migration_sql) in migrations.iter() {
57            if *version > current_version {
58                info!("Applying migration {}...", version);
59
60                // Begin transaction - requires mutable self
61                let tx = self.conn.transaction()?;
62
63                // Apply migration
64                if let Err(e) = tx.execute_batch(migration_sql) {
65                    error!("Failed to apply migration {}: {}", version, e);
66                    return Err(e);
67                }
68
69                // Record migration
70                tx.execute("INSERT INTO schema_version (version) VALUES (?)", [version])?;
71
72                // Commit transaction
73                tx.commit()?;
74
75                info!("Migration {} applied successfully", version);
76            }
77        }
78
79        Ok(())
80    }
81
82    /// Create database connection with migrations applied
83    pub fn create_connection(db_path: &Path) -> Result<Connection> {
84        // Open or create database
85        let mut conn = Connection::open(db_path)?;
86
87        // Enable foreign keys and WAL mode for better performance
88        conn.execute_batch(
89            "
90            PRAGMA foreign_keys = ON;
91            PRAGMA journal_mode = WAL;
92            PRAGMA synchronous = NORMAL;
93            PRAGMA cache_size = -2000; -- 2MB cache
94        ",
95        )?;
96
97        // Apply migrations - need mutable access
98        let mut migrator = MigrationManager::new(&mut conn);
99        migrator.initialize_database()?;
100
101        Ok(conn)
102    }
103
104    /// Clean up old data - needs mutable access
105    pub fn cleanup_old_data(&mut self, older_than_days: i32) -> Result<usize> {
106        let cutoff = chrono::Utc::now() - chrono::Duration::days(older_than_days as i64);
107        let cutoff_str = cutoff.to_rfc3339();
108
109        // Delete old sessions and their related data (cascading delete)
110        let deleted = self.conn.execute(
111            "DELETE FROM sessions WHERE last_accessed < ?1",
112            [&cutoff_str],
113        )?;
114
115        info!("Cleaned up {} old sessions", deleted);
116
117        // Vacuum to reclaim space
118        if deleted > 0 {
119            self.conn.execute_batch("VACUUM")?;
120            info!("Database vacuum completed");
121        }
122
123        Ok(deleted)
124    }
125
126    /// Get current schema version
127    pub fn get_current_version(&self) -> Result<i32> {
128        self.conn
129            .query_row(
130                "SELECT COALESCE(MAX(version), 0) FROM schema_version",
131                [],
132                |row| row.get(0),
133            )
134            .or_else(|_| Ok(0))
135    }
136
137    /// Check if a specific migration has been applied
138    pub fn has_migration_applied(&self, version: i32) -> Result<bool> {
139        self.conn
140            .query_row(
141                "SELECT 1 FROM schema_version WHERE version = ?",
142                [version],
143                |_| Ok(1),
144            )
145            .optional()
146            .map(|result| result.is_some())
147    }
148}
149
150/// Get all migration SQL scripts
151fn get_migrations() -> Vec<(i32, &'static str)> {
152    vec![
153        (1, include_str!("migrations/001_initial.sql")),
154        (2, include_str!("migrations/002_add_embeddings.sql")),
155        (3, include_str!("migrations/003_add_kv_snapshots.sql")),
156        (4, include_str!("migrations/004_local_files.sql")),
157        (5, include_str!("migrations/005_curated_files.sql")),
158        (6, include_str!("migrations/006_all_files.sql")),
159        (7, include_str!("migrations/007_session_file_contexts.sql")),
160        (8, include_str!("migrations/008_session_summaries.sql")),
161    ]
162}
163
164/// Get database statistics from a connection
165/// This is safe to call even with a locked connection since it only performs read queries
166pub fn get_database_stats(conn: &Connection) -> Result<schema::DatabaseStats> {
167    // Helper function to safely get count from a table
168    fn get_table_count(conn: &Connection, table_name: &str) -> Result<i64> {
169        conn.query_row(&format!("SELECT COUNT(*) FROM {}", table_name), [], |row| {
170            row.get(0)
171        })
172        .or_else(|e| {
173            warn!("Failed to get count from table {}: {}", table_name, e);
174            Ok(0) // Return 0 if table doesn't exist or query fails
175        })
176    }
177
178    let total_sessions = get_table_count(conn, "sessions")?;
179    let total_messages = get_table_count(conn, "messages")?;
180    let total_summaries = get_table_count(conn, "summaries")?;
181    let total_details = get_table_count(conn, "details")?;
182    let total_embeddings = get_table_count(conn, "embeddings")?;
183
184    // Get database size - this query is safe and doesn't modify anything
185    let database_size_bytes: i64 = conn
186        .query_row(
187            "SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
188            [],
189            |row| row.get(0),
190        )
191        .unwrap_or(0);
192
193    Ok(schema::DatabaseStats {
194        total_sessions,
195        total_messages,
196        total_summaries,
197        total_details,
198        total_embeddings,
199        database_size_bytes,
200    })
201}
202
203/// Get database statistics with connection creation
204/// Useful when you don't have an existing connection
205pub fn get_database_stats_from_path(db_path: &Path) -> Result<schema::DatabaseStats> {
206    let conn = Connection::open(db_path)?;
207    get_database_stats(&conn)
208}
209
210/// Run database maintenance tasks
211pub fn run_maintenance(conn: &mut Connection) -> Result<()> {
212    info!("Running database maintenance...");
213
214    // Analyze for better query optimization
215    conn.execute_batch("ANALYZE")?;
216
217    // Incremental vacuum if needed
218    conn.execute_batch("PRAGMA incremental_vacuum(100)")?;
219
220    // Check integrity
221    conn.execute_batch("PRAGMA integrity_check")?;
222
223    info!("Database maintenance completed");
224    Ok(())
225}