kaccy_db/
migration_runner.rs

1//! Migration Runner
2//!
3//! Automated database migration execution and rollback system.
4//! Complements the migration_utils module by providing execution capabilities.
5//!
6//! # Features
7//!
8//! - Automatic migration discovery and ordering
9//! - Safe migration execution with transactions
10//! - Rollback support for failed migrations
11//! - Dry-run mode for testing
12//! - Progress tracking and reporting
13//! - Migration dependency validation
14//! - Idempotent execution (safe to run multiple times)
15//!
16//! # Example
17//!
18//! ```rust,no_run
19//! use kaccy_db::migration_runner::{MigrationRunner, MigrationConfig};
20//! use sqlx::PgPool;
21//!
22//! async fn run_migrations(pool: &PgPool) -> Result<(), Box<dyn std::error::Error>> {
23//!     let config = MigrationConfig {
24//!         migrations_dir: "migrations".to_string(),
25//!         allow_missing: false,
26//!         dry_run: false,
27//!     };
28//!
29//!     let runner = MigrationRunner::new(pool.clone(), config);
30//!     let result = runner.run_pending_migrations().await?;
31//!
32//!     println!("Applied {} migrations", result.applied_count);
33//!     Ok(())
34//! }
35//! ```
36
37use crate::error::{DbError, Result};
38use chrono::{DateTime, Utc};
39use serde::{Deserialize, Serialize};
40use sqlx::{PgPool, Postgres, Transaction};
41use std::path::Path;
42use tracing::{info, warn};
43
44/// Configuration for the migration runner
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct MigrationConfig {
47    /// Directory containing migration SQL files
48    pub migrations_dir: String,
49
50    /// Whether to allow missing migrations in sequence
51    pub allow_missing: bool,
52
53    /// Whether to run in dry-run mode (no actual changes)
54    pub dry_run: bool,
55}
56
57impl Default for MigrationConfig {
58    fn default() -> Self {
59        Self {
60            migrations_dir: "migrations".to_string(),
61            allow_missing: false,
62            dry_run: false,
63        }
64    }
65}
66
67/// Information about a single migration
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct Migration {
70    /// Migration version (extracted from filename, e.g., "001" from "001_initial.sql")
71    pub version: String,
72
73    /// Migration name (e.g., "initial" from "001_initial.sql")
74    pub name: String,
75
76    /// SQL content of the migration
77    pub sql: String,
78
79    /// File path to the migration
80    pub file_path: String,
81}
82
83impl Migration {
84    /// Parse a migration from a file path
85    pub fn from_file(file_path: &str) -> Result<Self> {
86        let path = Path::new(file_path);
87        let filename = path
88            .file_name()
89            .and_then(|n| n.to_str())
90            .ok_or_else(|| DbError::NotFound("Invalid migration filename".to_string()))?;
91
92        // Parse filename: "001_initial.sql" -> version="001", name="initial"
93        let parts: Vec<&str> = filename.trim_end_matches(".sql").splitn(2, '_').collect();
94
95        if parts.len() != 2 {
96            return Err(DbError::NotFound(format!(
97                "Invalid migration filename format: {}",
98                filename
99            )));
100        }
101
102        let version = parts[0].to_string();
103        let name = parts[1].to_string();
104
105        let sql = std::fs::read_to_string(file_path)
106            .map_err(|e| DbError::NotFound(format!("Failed to read migration file: {}", e)))?;
107
108        Ok(Self {
109            version,
110            name,
111            sql,
112            file_path: file_path.to_string(),
113        })
114    }
115
116    /// Get the migration ID (version_name format for tracking)
117    pub fn id(&self) -> String {
118        format!("{}_{}", self.version, self.name)
119    }
120}
121
122/// Result of running migrations
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct MigrationResult {
125    /// Number of migrations applied
126    pub applied_count: usize,
127
128    /// Number of migrations skipped (already applied)
129    pub skipped_count: usize,
130
131    /// List of applied migrations
132    pub applied_migrations: Vec<String>,
133
134    /// List of skipped migrations
135    pub skipped_migrations: Vec<String>,
136
137    /// Whether the migration run was successful
138    pub success: bool,
139
140    /// Error message if failed
141    pub error: Option<String>,
142
143    /// Execution time in milliseconds
144    pub execution_time_ms: u64,
145}
146
147/// Migration execution history record
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct MigrationHistory {
150    /// Migration ID
151    pub migration_id: String,
152
153    /// When the migration was applied
154    pub applied_at: DateTime<Utc>,
155
156    /// Execution time in milliseconds
157    pub execution_time_ms: i64,
158
159    /// Whether the migration succeeded
160    pub success: bool,
161}
162
163/// Migration runner
164pub struct MigrationRunner {
165    pool: PgPool,
166    config: MigrationConfig,
167}
168
169impl MigrationRunner {
170    /// Create a new migration runner
171    pub fn new(pool: PgPool, config: MigrationConfig) -> Self {
172        Self { pool, config }
173    }
174
175    /// Create a migration runner with default configuration
176    pub fn with_defaults(pool: PgPool) -> Self {
177        Self::new(pool, MigrationConfig::default())
178    }
179
180    /// Ensure the migrations tracking table exists
181    async fn ensure_migrations_table(&self) -> Result<()> {
182        sqlx::query(
183            r#"
184            CREATE TABLE IF NOT EXISTS _migrations (
185                id SERIAL PRIMARY KEY,
186                migration_id VARCHAR(255) NOT NULL UNIQUE,
187                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
188                execution_time_ms BIGINT NOT NULL,
189                success BOOLEAN NOT NULL DEFAULT TRUE
190            )
191            "#,
192        )
193        .execute(&self.pool)
194        .await?;
195
196        Ok(())
197    }
198
199    /// Get list of applied migrations
200    async fn get_applied_migrations(&self) -> Result<Vec<String>> {
201        let records = sqlx::query_scalar::<_, String>(
202            "SELECT migration_id FROM _migrations WHERE success = TRUE ORDER BY id",
203        )
204        .fetch_all(&self.pool)
205        .await?;
206
207        Ok(records)
208    }
209
210    /// Check if a migration has been applied
211    #[allow(dead_code)]
212    async fn is_migration_applied(&self, migration_id: &str) -> Result<bool> {
213        let count = sqlx::query_scalar::<_, i64>(
214            "SELECT COUNT(*) FROM _migrations WHERE migration_id = $1 AND success = TRUE",
215        )
216        .bind(migration_id)
217        .fetch_one(&self.pool)
218        .await?;
219
220        Ok(count > 0)
221    }
222
223    /// Record a migration as applied
224    async fn record_migration(
225        &self,
226        tx: &mut Transaction<'_, Postgres>,
227        migration_id: &str,
228        execution_time_ms: i64,
229        success: bool,
230    ) -> Result<()> {
231        sqlx::query(
232            r#"
233            INSERT INTO _migrations (migration_id, applied_at, execution_time_ms, success)
234            VALUES ($1, NOW(), $2, $3)
235            "#,
236        )
237        .bind(migration_id)
238        .bind(execution_time_ms)
239        .bind(success)
240        .execute(&mut **tx)
241        .await?;
242
243        Ok(())
244    }
245
246    /// Discover all migration files in the migrations directory
247    pub fn discover_migrations(&self) -> Result<Vec<Migration>> {
248        let migrations_dir = Path::new(&self.config.migrations_dir);
249
250        if !migrations_dir.exists() {
251            return Err(DbError::NotFound(format!(
252                "Migrations directory not found: {}",
253                self.config.migrations_dir
254            )));
255        }
256
257        let mut migrations = Vec::new();
258
259        let entries = std::fs::read_dir(migrations_dir).map_err(|e| {
260            DbError::NotFound(format!("Failed to read migrations directory: {}", e))
261        })?;
262
263        for entry in entries {
264            let entry = entry
265                .map_err(|e| DbError::NotFound(format!("Failed to read directory entry: {}", e)))?;
266
267            let path = entry.path();
268
269            if path.extension().and_then(|s| s.to_str()) == Some("sql") {
270                let migration = Migration::from_file(path.to_str().unwrap())?;
271                migrations.push(migration);
272            }
273        }
274
275        // Sort by version
276        migrations.sort_by(|a, b| a.version.cmp(&b.version));
277
278        Ok(migrations)
279    }
280
281    /// Run all pending migrations
282    pub async fn run_pending_migrations(&self) -> Result<MigrationResult> {
283        let start_time = std::time::Instant::now();
284
285        // Ensure migrations table exists
286        self.ensure_migrations_table().await?;
287
288        // Discover all migrations
289        let all_migrations = self.discover_migrations()?;
290        let applied_migrations = self.get_applied_migrations().await?;
291
292        let mut result = MigrationResult {
293            applied_count: 0,
294            skipped_count: 0,
295            applied_migrations: Vec::new(),
296            skipped_migrations: Vec::new(),
297            success: true,
298            error: None,
299            execution_time_ms: 0,
300        };
301
302        for migration in all_migrations {
303            let migration_id = migration.id();
304
305            // Skip if already applied
306            if applied_migrations.contains(&migration_id) {
307                info!(
308                    migration_id = migration_id,
309                    "Skipping already applied migration"
310                );
311                result.skipped_count += 1;
312                result.skipped_migrations.push(migration_id);
313                continue;
314            }
315
316            if self.config.dry_run {
317                info!(
318                    migration_id = migration_id,
319                    "Dry-run: Would apply migration"
320                );
321                result.applied_count += 1;
322                result.applied_migrations.push(migration_id);
323                continue;
324            }
325
326            // Execute migration in transaction
327            match self.execute_migration(&migration).await {
328                Ok(execution_time_ms) => {
329                    info!(
330                        migration_id = migration_id,
331                        execution_time_ms = execution_time_ms,
332                        "Migration applied successfully"
333                    );
334                    result.applied_count += 1;
335                    result.applied_migrations.push(migration_id);
336                }
337                Err(e) => {
338                    warn!(
339                        migration_id = migration_id,
340                        error = %e,
341                        "Migration failed"
342                    );
343                    result.success = false;
344                    result.error = Some(format!("Migration {} failed: {}", migration_id, e));
345                    break;
346                }
347            }
348        }
349
350        result.execution_time_ms = start_time.elapsed().as_millis() as u64;
351
352        Ok(result)
353    }
354
355    /// Execute a single migration
356    async fn execute_migration(&self, migration: &Migration) -> Result<i64> {
357        let start_time = std::time::Instant::now();
358
359        let mut tx = self.pool.begin().await?;
360
361        // Execute the migration SQL
362        match sqlx::query(&migration.sql).execute(&mut *tx).await {
363            Ok(_) => {
364                let execution_time_ms = start_time.elapsed().as_millis() as i64;
365
366                // Record the migration
367                self.record_migration(&mut tx, &migration.id(), execution_time_ms, true)
368                    .await?;
369
370                // Commit transaction
371                tx.commit().await?;
372
373                Ok(execution_time_ms)
374            }
375            Err(e) => {
376                // Rollback on error
377                tx.rollback().await?;
378
379                Err(DbError::from(e))
380            }
381        }
382    }
383
384    /// Get migration history
385    pub async fn get_migration_history(&self) -> Result<Vec<MigrationHistory>> {
386        self.ensure_migrations_table().await?;
387
388        let records = sqlx::query_as::<_, (String, DateTime<Utc>, i64, bool)>(
389            "SELECT migration_id, applied_at, execution_time_ms, success FROM _migrations ORDER BY id",
390        )
391        .fetch_all(&self.pool)
392        .await?;
393
394        Ok(records
395            .into_iter()
396            .map(
397                |(migration_id, applied_at, execution_time_ms, success)| MigrationHistory {
398                    migration_id,
399                    applied_at,
400                    execution_time_ms,
401                    success,
402                },
403            )
404            .collect())
405    }
406
407    /// Verify all migrations are in order (no missing versions)
408    pub fn verify_migration_order(&self, migrations: &[Migration]) -> Result<()> {
409        if self.config.allow_missing {
410            return Ok(());
411        }
412
413        for (i, migration) in migrations.iter().enumerate() {
414            let expected_version = format!("{:03}", i + 1);
415            if migration.version != expected_version {
416                return Err(DbError::NotFound(format!(
417                    "Migration version mismatch: expected {}, found {}",
418                    expected_version, migration.version
419                )));
420            }
421        }
422
423        Ok(())
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a migration with empty SQL and path; only `version` and `name` matter here.
    fn migration(version: &str, name: &str) -> Migration {
        Migration {
            version: version.to_string(),
            name: name.to_string(),
            sql: String::new(),
            file_path: String::new(),
        }
    }

    /// Pool that never connects eagerly; `verify_migration_order` does not use it.
    fn lazy_pool() -> PgPool {
        PgPool::connect_lazy("postgresql://localhost/test").unwrap()
    }

    #[test]
    fn test_migration_config_default() {
        let defaults = MigrationConfig::default();
        assert_eq!(defaults.migrations_dir, "migrations");
        assert!(!defaults.allow_missing);
        assert!(!defaults.dry_run);
    }

    #[test]
    fn test_migration_id() {
        let m = Migration {
            version: "001".to_string(),
            name: "initial".to_string(),
            sql: "CREATE TABLE test();".to_string(),
            file_path: "migrations/001_initial.sql".to_string(),
        };

        assert_eq!(m.id(), "001_initial");
    }

    #[test]
    fn test_migration_result_serialization() {
        let summary = MigrationResult {
            applied_count: 3,
            skipped_count: 2,
            applied_migrations: vec!["001_initial".to_string()],
            skipped_migrations: vec!["000_setup".to_string()],
            success: true,
            error: None,
            execution_time_ms: 1500,
        };

        let encoded = serde_json::to_string(&summary).unwrap();
        assert!(encoded.contains("applied_count"));
        assert!(encoded.contains("\"success\":true"));
    }

    #[test]
    fn test_migration_history_serialization() {
        let row = MigrationHistory {
            migration_id: "001_initial".to_string(),
            applied_at: Utc::now(),
            execution_time_ms: 1000,
            success: true,
        };

        let encoded = serde_json::to_string(&row).unwrap();
        assert!(encoded.contains("migration_id"));
        assert!(encoded.contains("001_initial"));
    }

    #[tokio::test]
    async fn test_verify_migration_order_success() {
        let contiguous = vec![
            migration("001", "first"),
            migration("002", "second"),
            migration("003", "third"),
        ];

        let runner = MigrationRunner::with_defaults(lazy_pool());
        assert!(runner.verify_migration_order(&contiguous).is_ok());
    }

    #[tokio::test]
    async fn test_verify_migration_order_failure() {
        // Version 002 is missing.
        let with_gap = vec![migration("001", "first"), migration("003", "third")];

        let runner = MigrationRunner::with_defaults(lazy_pool());
        assert!(runner.verify_migration_order(&with_gap).is_err());
    }

    #[tokio::test]
    async fn test_verify_migration_order_with_allow_missing() {
        let with_gap = vec![migration("001", "first"), migration("003", "third")];

        let lenient = MigrationConfig {
            allow_missing: true,
            ..Default::default()
        };
        let runner = MigrationRunner::new(lazy_pool(), lenient);
        assert!(runner.verify_migration_order(&with_gap).is_ok());
    }
}