oxify_storage/
migration_runner.rs

//! Database Migration Runner
//!
//! This module provides utilities for applying and tracking database schema migrations.
//!
//! ## Features
//!
//! - Version-tracked migrations, applied in the order they are declared
//! - Idempotent migration application (safe to run multiple times)
//! - Transaction-wrapped migrations for atomicity
//! - Migration history tracking in the `schema_migrations` table
//! - Rollback support for reversible migrations
//! - Status reporting for previewing which migrations are still pending
//! - SHA-256 checksum verification to detect migrations edited after being applied
//!
//! ## Usage
//!
//! ```ignore
//! use oxify_storage::{DatabasePool, migration_runner::MigrationRunner};
//!
//! let pool = DatabasePool::new(config).await?;
//! let runner = MigrationRunner::new(pool);
//!
//! // Run all pending migrations
//! let applied = runner.run_pending_migrations().await?;
//! println!("Applied {} migrations", applied.len());
//!
//! // Check migration status
//! let status = runner.get_migration_status().await?;
//! for migration in status {
//!     println!("{}: {} - {}", migration.version, migration.name,
//!              if migration.applied { "applied" } else { "pending" });
//! }
//! ```
33
34use crate::{DatabasePool, Result, StorageError};
35use chrono::{DateTime, Utc};
36use sqlx::Row;
37use std::collections::HashMap;
38
/// A single schema migration: the SQL that applies it and, optionally, the SQL that undoes it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Unique migration version (e.g., "20260101_001").
    ///
    /// This is the primary key recorded in the `schema_migrations` table once applied.
    pub version: String,
    /// Human-readable migration name
    pub name: String,
    /// SQL executed to apply the migration; its SHA-256 checksum is recorded when applied.
    pub up_sql: String,
    /// SQL executed to roll the migration back, or `None` if it cannot be rolled back.
    pub down_sql: Option<String>,
    /// Optional longer description of what the migration does.
    pub description: Option<String>,
}
53
54/// Migration status information
55#[derive(Debug, Clone)]
56pub struct MigrationStatus {
57    /// Migration version
58    pub version: String,
59    /// Migration name
60    pub name: String,
61    /// Whether the migration has been applied
62    pub applied: bool,
63    /// When the migration was applied (if applied)
64    pub applied_at: Option<DateTime<Utc>>,
65    /// Checksum of the migration SQL
66    pub checksum: Option<String>,
67}
68
/// One row of the `schema_migrations` history table.
#[derive(Debug, Clone)]
struct MigrationRecord {
    /// Version key of the applied migration.
    version: String,
    /// Name recorded when the migration was applied.
    // Selected from the table and shown in `Debug` output, but not otherwise read yet.
    // The lint is suppressed on this field only, so any other field that becomes
    // unused still produces a warning.
    #[allow(dead_code)]
    name: String,
    /// Raw `applied_at` text exactly as stored by the database.
    applied_at: String,
    /// SHA-256 hex checksum of the migration's `up_sql` at the time it was applied.
    checksum: String,
}
78
79/// Migration runner for applying and tracking database migrations
80pub struct MigrationRunner {
81    pool: DatabasePool,
82}
83
84impl MigrationRunner {
85    /// Create a new migration runner
86    pub fn new(pool: DatabasePool) -> Self {
87        Self { pool }
88    }
89
90    /// Initialize the migrations tracking table
91    ///
92    /// Creates the `schema_migrations` table if it doesn't exist.
93    pub async fn initialize(&self) -> Result<()> {
94        sqlx::query(
95            r"
96            CREATE TABLE IF NOT EXISTS schema_migrations (
97                version TEXT PRIMARY KEY,
98                name TEXT NOT NULL,
99                applied_at TEXT NOT NULL DEFAULT (datetime('now')),
100                checksum TEXT NOT NULL,
101                execution_time_ms INTEGER
102            )
103            ",
104        )
105        .execute(self.pool.pool())
106        .await?;
107
108        // Create index for faster lookups
109        sqlx::query(
110            r"
111            CREATE INDEX IF NOT EXISTS idx_schema_migrations_applied_at
112            ON schema_migrations(applied_at DESC)
113            ",
114        )
115        .execute(self.pool.pool())
116        .await?;
117
118        Ok(())
119    }
120
121    /// Get all built-in migrations
122    ///
123    /// Returns the standard set of migrations for oxify-storage.
124    pub fn get_builtin_migrations() -> Vec<Migration> {
125        vec![
126            Migration {
127                version: "001_performance_indexes".to_string(),
128                name: "Add Performance Indexes".to_string(),
129                description: Some("Creates indexes for improved query performance".to_string()),
130                up_sql: include_str!("../sql/001_performance_indexes.sql").to_string(),
131                down_sql: Some(include_str!("../sql/001_performance_indexes_down.sql").to_string()),
132            },
133            Migration {
134                version: "002_schema_constraints".to_string(),
135                name: "Add Schema Constraints".to_string(),
136                description: Some("Adds foreign keys and CHECK constraints".to_string()),
137                up_sql: include_str!("../sql/002_schema_constraints.sql").to_string(),
138                down_sql: Some(include_str!("../sql/002_schema_constraints_down.sql").to_string()),
139            },
140        ]
141    }
142
143    /// Check if a migration has been applied
144    pub async fn is_applied(&self, version: &str) -> Result<bool> {
145        let row = sqlx::query(
146            r"
147            SELECT COUNT(*) as count FROM schema_migrations WHERE version = ?
148            ",
149        )
150        .bind(version)
151        .fetch_one(self.pool.pool())
152        .await?;
153
154        let count: i64 = row.get("count");
155        Ok(count > 0)
156    }
157
158    /// Apply a single migration
159    ///
160    /// Runs the migration in a transaction and records it in the migrations table.
161    pub async fn apply_migration(&self, migration: &Migration) -> Result<i64> {
162        // Check if already applied
163        if self.is_applied(&migration.version).await? {
164            return Ok(0);
165        }
166
167        let start_time = std::time::Instant::now();
168        let mut tx = self.pool.pool().begin().await?;
169
170        // Execute the migration SQL
171        // For dynamic SQL, we use Box::leak to get a static string (SqlSafeStr requirement)
172        let up_sql: &'static str = Box::leak(migration.up_sql.clone().into_boxed_str());
173        sqlx::query(up_sql).execute(&mut *tx).await?;
174
175        // Calculate checksum
176        let checksum = Self::calculate_checksum(&migration.up_sql);
177
178        // Record the migration
179        let execution_time_ms = start_time.elapsed().as_millis() as i64;
180        sqlx::query(
181            r"
182            INSERT INTO schema_migrations (version, name, checksum, execution_time_ms)
183            VALUES (?, ?, ?, ?)
184            ",
185        )
186        .bind(&migration.version)
187        .bind(&migration.name)
188        .bind(&checksum)
189        .bind(execution_time_ms)
190        .execute(&mut *tx)
191        .await?;
192
193        tx.commit().await?;
194
195        Ok(execution_time_ms)
196    }
197
198    /// Run all pending migrations
199    ///
200    /// Applies all migrations that haven't been applied yet.
201    pub async fn run_pending_migrations(&self) -> Result<Vec<String>> {
202        self.initialize().await?;
203
204        let migrations = Self::get_builtin_migrations();
205        let mut applied = Vec::new();
206
207        for migration in migrations {
208            if !self.is_applied(&migration.version).await? {
209                self.apply_migration(&migration).await?;
210                applied.push(migration.version.clone());
211            }
212        }
213
214        Ok(applied)
215    }
216
217    /// Get migration status for all migrations
218    pub async fn get_migration_status(&self) -> Result<Vec<MigrationStatus>> {
219        self.initialize().await?;
220
221        let migrations = Self::get_builtin_migrations();
222        let applied_records = self.get_applied_migrations().await?;
223
224        let applied_map: HashMap<String, MigrationRecord> = applied_records
225            .into_iter()
226            .map(|r| (r.version.clone(), r))
227            .collect();
228
229        let mut status = Vec::new();
230        for migration in migrations {
231            let record = applied_map.get(&migration.version);
232            status.push(MigrationStatus {
233                version: migration.version,
234                name: migration.name,
235                applied: record.is_some(),
236                applied_at: record.and_then(|r| {
237                    DateTime::parse_from_rfc3339(&r.applied_at)
238                        .ok()
239                        .map(|dt| dt.with_timezone(&Utc))
240                }),
241                checksum: record.map(|r| r.checksum.clone()),
242            });
243        }
244
245        Ok(status)
246    }
247
248    /// Get all applied migrations
249    async fn get_applied_migrations(&self) -> Result<Vec<MigrationRecord>> {
250        let rows = sqlx::query(
251            r"
252            SELECT version, name, applied_at, checksum
253            FROM schema_migrations
254            ORDER BY applied_at ASC
255            ",
256        )
257        .fetch_all(self.pool.pool())
258        .await?;
259
260        let records: Vec<MigrationRecord> = rows
261            .into_iter()
262            .map(|row| MigrationRecord {
263                version: row.get("version"),
264                name: row.get("name"),
265                applied_at: row.get("applied_at"),
266                checksum: row.get("checksum"),
267            })
268            .collect();
269
270        Ok(records)
271    }
272
273    /// Rollback a migration
274    ///
275    /// Reverts a migration if it has a down_sql defined.
276    pub async fn rollback_migration(&self, version: &str) -> Result<()> {
277        let migrations = Self::get_builtin_migrations();
278        let migration = migrations
279            .iter()
280            .find(|m| m.version == version)
281            .ok_or_else(|| {
282                StorageError::NotFoundLegacy(format!("Migration {version} not found"))
283            })?;
284
285        let down_sql = migration.down_sql.as_ref().ok_or_else(|| {
286            StorageError::ValidationError(format!("Migration {version} has no rollback defined"))
287        })?;
288
289        let mut tx = self.pool.pool().begin().await?;
290
291        // Execute rollback SQL
292        // For dynamic SQL, we use Box::leak to get a static string (SqlSafeStr requirement)
293        let down_sql_static: &'static str = Box::leak(down_sql.clone().into_boxed_str());
294        sqlx::query(down_sql_static).execute(&mut *tx).await?;
295
296        // Remove migration record
297        sqlx::query("DELETE FROM schema_migrations WHERE version = ?")
298            .bind(version)
299            .execute(&mut *tx)
300            .await?;
301
302        tx.commit().await?;
303
304        Ok(())
305    }
306
307    /// Calculate SHA-256 checksum of SQL content
308    fn calculate_checksum(sql: &str) -> String {
309        use sha2::{Digest, Sha256};
310        let mut hasher = Sha256::new();
311        hasher.update(sql.as_bytes());
312        format!("{:x}", hasher.finalize())
313    }
314
315    /// Verify migration checksums match recorded values
316    ///
317    /// Detects if migration files have been modified after being applied.
318    pub async fn verify_checksums(&self) -> Result<Vec<String>> {
319        let migrations = Self::get_builtin_migrations();
320        let applied = self.get_applied_migrations().await?;
321        let mut mismatches = Vec::new();
322
323        let applied_map: HashMap<String, MigrationRecord> = applied
324            .into_iter()
325            .map(|r| (r.version.clone(), r))
326            .collect();
327
328        for migration in migrations {
329            if let Some(record) = applied_map.get(&migration.version) {
330                let current_checksum = Self::calculate_checksum(&migration.up_sql);
331                if current_checksum != record.checksum {
332                    mismatches.push(migration.version.clone());
333                }
334            }
335        }
336
337        Ok(mismatches)
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checksum_calculation() {
        let sql = "CREATE TABLE test (id INT)";
        let checksum = MigrationRunner::calculate_checksum(sql);
        assert_eq!(checksum.len(), 64); // SHA-256 produces 64 hex characters
        assert!(checksum
            .chars()
            .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
    }

    #[test]
    fn test_checksum_known_vector() {
        // FIPS 180-2 SHA-256 test vector for the message "abc".
        assert_eq!(
            MigrationRunner::calculate_checksum("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_checksum_consistency() {
        let sql = "SELECT * FROM users";
        let checksum1 = MigrationRunner::calculate_checksum(sql);
        let checksum2 = MigrationRunner::calculate_checksum(sql);
        assert_eq!(checksum1, checksum2);
    }

    #[test]
    fn test_checksum_distinguishes_sql() {
        assert_ne!(
            MigrationRunner::calculate_checksum("SELECT 1"),
            MigrationRunner::calculate_checksum("SELECT 2")
        );
    }

    #[test]
    fn test_builtin_migrations_exist() {
        let migrations = MigrationRunner::get_builtin_migrations();
        assert!(!migrations.is_empty());
        assert!(migrations
            .iter()
            .any(|m| m.version.contains("performance_indexes")));
        assert!(migrations
            .iter()
            .any(|m| m.version.contains("schema_constraints")));
    }

    #[test]
    fn test_builtin_migration_versions_unique_and_ordered() {
        let migrations = MigrationRunner::get_builtin_migrations();
        let versions: Vec<&str> = migrations.iter().map(|m| m.version.as_str()).collect();
        let mut expected = versions.clone();
        expected.sort_unstable();
        expected.dedup();
        assert_eq!(
            versions, expected,
            "built-in migration versions must be unique and declared in ascending order"
        );
    }

    #[test]
    fn test_builtin_migrations_are_reversible() {
        for migration in MigrationRunner::get_builtin_migrations() {
            assert!(
                migration.down_sql.is_some(),
                "migration {} has no down_sql",
                migration.version
            );
        }
    }
}
371}