Skip to main content

oxide_sql_core/migrations/
migration.rs

1//! Migration trait and runner.
2//!
3//! Provides the `Migration` trait that all migrations implement, and the
4//! `MigrationRunner` that executes migrations in dependency order.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use super::dialect::MigrationDialect;
9use super::operation::Operation;
10use super::state::MigrationState;
11
12/// A database migration with typed up/down operations.
13///
14/// Implement this trait for each migration in your application.
15///
16/// # Example
17///
18/// ```rust
19/// use oxide_sql_core::migrations::{
20///     Migration, Operation, CreateTableBuilder,
21///     bigint, varchar, timestamp,
22/// };
23///
24/// pub struct Migration0001;
25///
26/// impl Migration for Migration0001 {
27///     const ID: &'static str = "0001_create_users";
28///
29///     fn up() -> Vec<Operation> {
30///         vec![
31///             CreateTableBuilder::new()
32///                 .name("users")
33///                 .column(bigint("id").primary_key().autoincrement().build())
34///                 .column(varchar("username", 255).not_null().unique().build())
35///                 .build()
36///                 .into(),
37///         ]
38///     }
39///
40///     fn down() -> Vec<Operation> {
41///         vec![
42///             Operation::drop_table("users"),
43///         ]
44///     }
45/// }
46/// ```
47pub trait Migration {
48    /// Unique migration identifier (e.g., "0001_initial", "0002_add_email").
49    ///
50    /// This ID is stored in the migrations table to track which migrations
51    /// have been applied.
52    const ID: &'static str;
53
54    /// Dependencies on other migrations (must run first).
55    ///
56    /// Each string should be the `ID` of another migration.
57    const DEPENDENCIES: &'static [&'static str] = &[];
58
59    /// Apply the migration (forward).
60    ///
61    /// Returns a list of operations to execute.
62    fn up() -> Vec<Operation>;
63
64    /// Reverse the migration (backward).
65    ///
66    /// Returns a list of operations to execute to undo the migration.
67    /// Return an empty vec if the migration is not reversible.
68    fn down() -> Vec<Operation>;
69}
70
71/// A registered migration with runtime-accessible metadata.
72pub struct RegisteredMigration {
73    /// Migration ID.
74    pub id: &'static str,
75    /// Dependencies.
76    pub dependencies: &'static [&'static str],
77    /// Function to get up operations.
78    pub up: fn() -> Vec<Operation>,
79    /// Function to get down operations.
80    pub down: fn() -> Vec<Operation>,
81}
82
83impl RegisteredMigration {
84    /// Creates a new registered migration from a `Migration` implementor.
85    #[must_use]
86    pub const fn new<M: Migration>() -> Self {
87        Self {
88            id: M::ID,
89            dependencies: M::DEPENDENCIES,
90            up: M::up,
91            down: M::down,
92        }
93    }
94}
95
/// Snapshot of whether one registered migration has been applied.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MigrationStatus {
    /// ID of the migration this status describes.
    pub id: &'static str,
    /// `true` once the migration has been applied.
    pub applied: bool,
    /// Time the migration was applied, when that information is available.
    pub applied_at: Option<String>,
}
106
107/// Runs migrations in dependency order.
108///
109/// The runner tracks which migrations are registered and uses the provided
110/// `MigrationState` to determine which migrations need to be applied.
111///
112/// # Example
113///
114/// ```rust
115/// use oxide_sql_core::migrations::{
116///     Migration, MigrationRunner, MigrationState, Operation,
117///     CreateTableBuilder, SqliteDialect, bigint,
118/// };
119///
120/// // Define a migration
121/// pub struct Migration0001;
122/// impl Migration for Migration0001 {
123///     const ID: &'static str = "0001_initial";
124///     fn up() -> Vec<Operation> {
125///         vec![CreateTableBuilder::new()
126///             .name("test")
127///             .column(bigint("id").primary_key().build())
128///             .build()
129///             .into()]
130///     }
131///     fn down() -> Vec<Operation> {
132///         vec![Operation::drop_table("test")]
133///     }
134/// }
135///
136/// // Create runner
137/// let mut runner = MigrationRunner::new(SqliteDialect::new());
138/// runner.register::<Migration0001>();
139///
140/// // Check status
141/// let state = MigrationState::new();
142/// let pending = runner.pending_migrations(&state);
143/// assert_eq!(pending.len(), 1);
144/// ```
145pub struct MigrationRunner<D: MigrationDialect> {
146    migrations: Vec<RegisteredMigration>,
147    dialect: D,
148}
149
150impl<D: MigrationDialect> MigrationRunner<D> {
151    /// Creates a new migration runner with the given dialect.
152    #[must_use]
153    pub fn new(dialect: D) -> Self {
154        Self {
155            migrations: Vec::new(),
156            dialect,
157        }
158    }
159
160    /// Registers a migration.
161    pub fn register<M: Migration>(&mut self) -> &mut Self {
162        self.migrations.push(RegisteredMigration::new::<M>());
163        self
164    }
165
166    /// Returns all registered migrations.
167    #[must_use]
168    pub fn migrations(&self) -> &[RegisteredMigration] {
169        &self.migrations
170    }
171
172    /// Returns the dialect.
173    #[must_use]
174    pub fn dialect(&self) -> &D {
175        &self.dialect
176    }
177
178    /// Returns migrations that haven't been applied yet.
179    #[must_use]
180    pub fn pending_migrations(&self, state: &MigrationState) -> Vec<&RegisteredMigration> {
181        self.migrations
182            .iter()
183            .filter(|m| !state.is_applied(m.id))
184            .collect()
185    }
186
187    /// Returns the status of all migrations.
188    #[must_use]
189    pub fn status(&self, state: &MigrationState) -> Vec<MigrationStatus> {
190        self.migrations
191            .iter()
192            .map(|m| MigrationStatus {
193                id: m.id,
194                applied: state.is_applied(m.id),
195                applied_at: None, // Would need to query the DB for this
196            })
197            .collect()
198    }
199
200    /// Returns migrations in dependency order (topological sort).
201    ///
202    /// Returns `Err` if there's a circular dependency.
203    pub fn sorted_migrations(&self) -> Result<Vec<&RegisteredMigration>, MigrationError> {
204        // Build dependency graph
205        let mut in_degree: HashMap<&str, usize> = HashMap::new();
206        let mut dependents: HashMap<&str, Vec<&str>> = HashMap::new();
207        let migration_map: HashMap<&str, &RegisteredMigration> =
208            self.migrations.iter().map(|m| (m.id, m)).collect();
209
210        for m in &self.migrations {
211            in_degree.entry(m.id).or_insert(0);
212            for dep in m.dependencies {
213                *in_degree.entry(m.id).or_insert(0) += 1;
214                dependents.entry(*dep).or_default().push(m.id);
215            }
216        }
217
218        // Kahn's algorithm for topological sort
219        let mut queue: VecDeque<&str> = in_degree
220            .iter()
221            .filter(|(_, deg)| **deg == 0)
222            .map(|(id, _)| *id)
223            .collect();
224        let mut result = Vec::new();
225
226        while let Some(id) = queue.pop_front() {
227            if let Some(m) = migration_map.get(id) {
228                result.push(*m);
229            }
230
231            if let Some(deps) = dependents.get(id) {
232                for dep in deps {
233                    if let Some(deg) = in_degree.get_mut(dep) {
234                        *deg -= 1;
235                        if *deg == 0 {
236                            queue.push_back(dep);
237                        }
238                    }
239                }
240            }
241        }
242
243        if result.len() != self.migrations.len() {
244            return Err(MigrationError::CircularDependency);
245        }
246
247        Ok(result)
248    }
249
250    /// Generates SQL for all pending migrations.
251    ///
252    /// Returns a list of (migration_id, sql_statements) pairs.
253    pub fn sql_for_pending(
254        &self,
255        state: &MigrationState,
256    ) -> Result<Vec<(&'static str, Vec<String>)>, MigrationError> {
257        let sorted = self.sorted_migrations()?;
258        let pending: Vec<_> = sorted
259            .into_iter()
260            .filter(|m| !state.is_applied(m.id))
261            .collect();
262
263        let mut result = Vec::new();
264        for migration in pending {
265            let operations = (migration.up)();
266            let sqls: Vec<String> = operations
267                .iter()
268                .map(|op| self.dialect.generate_sql(op))
269                .collect();
270            result.push((migration.id, sqls));
271        }
272
273        Ok(result)
274    }
275
276    /// Generates SQL for rolling back migrations.
277    ///
278    /// Returns a list of (migration_id, sql_statements) pairs in reverse order.
279    pub fn sql_for_rollback(
280        &self,
281        state: &MigrationState,
282        count: usize,
283    ) -> Result<Vec<(&'static str, Vec<String>)>, MigrationError> {
284        let sorted = self.sorted_migrations()?;
285
286        // Get applied migrations in reverse order
287        let applied: Vec<_> = sorted
288            .into_iter()
289            .rev()
290            .filter(|m| state.is_applied(m.id))
291            .take(count)
292            .collect();
293
294        let mut result = Vec::new();
295        for migration in applied {
296            let operations = (migration.down)();
297            if operations.is_empty() {
298                return Err(MigrationError::NotReversible(migration.id.to_string()));
299            }
300            let sqls: Vec<String> = operations
301                .iter()
302                .map(|op| self.dialect.generate_sql(op))
303                .collect();
304            result.push((migration.id, sqls));
305        }
306
307        Ok(result)
308    }
309
310    /// Validates that all dependencies exist and are registered.
311    pub fn validate(&self) -> Result<(), MigrationError> {
312        let ids: HashSet<&str> = self.migrations.iter().map(|m| m.id).collect();
313
314        for m in &self.migrations {
315            for dep in m.dependencies {
316                if !ids.contains(dep) {
317                    return Err(MigrationError::MissingDependency {
318                        migration: m.id.to_string(),
319                        dependency: (*dep).to_string(),
320                    });
321                }
322            }
323        }
324
325        // Check for circular dependencies
326        let _ = self.sorted_migrations()?;
327
328        Ok(())
329    }
330}
331
/// Failures reported while ordering, validating, or rolling back migrations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MigrationError {
    /// The dependency graph contains a cycle, so no valid order exists.
    CircularDependency,
    /// `migration` lists `dependency` as a prerequisite, but no migration
    /// with that ID is registered.
    MissingDependency {
        /// ID of the migration that declared the dependency.
        migration: String,
        /// The dependency ID that could not be found.
        dependency: String,
    },
    /// The migration with this ID cannot be rolled back.
    NotReversible(String),
    /// A database-level failure, described by its message.
    DatabaseError(String),
}
349
350impl std::fmt::Display for MigrationError {
351    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352        match self {
353            Self::CircularDependency => write!(f, "Circular dependency detected in migrations"),
354            Self::MissingDependency {
355                migration,
356                dependency,
357            } => write!(
358                f,
359                "Migration '{}' depends on '{}' which doesn't exist",
360                migration, dependency
361            ),
362            Self::NotReversible(id) => write!(f, "Migration '{}' is not reversible", id),
363            Self::DatabaseError(msg) => write!(f, "Database error: {}", msg),
364        }
365    }
366}
367
368impl std::error::Error for MigrationError {}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migrations::column_builder::{bigint, boolean, varchar};
    use crate::migrations::dialect::SqliteDialect;
    use crate::migrations::table_builder::CreateTableBuilder;

    // Test migrations: a linear chain 0001 <- 0002 <- 0003.
    struct Migration0001;
    impl Migration for Migration0001 {
        const ID: &'static str = "0001_initial";
        fn up() -> Vec<Operation> {
            vec![CreateTableBuilder::new()
                .name("users")
                .column(bigint("id").primary_key().autoincrement().build())
                .column(varchar("username", 255).not_null().build())
                .build()
                .into()]
        }
        fn down() -> Vec<Operation> {
            vec![Operation::drop_table("users")]
        }
    }

    struct Migration0002;
    impl Migration for Migration0002 {
        const ID: &'static str = "0002_add_email";
        const DEPENDENCIES: &'static [&'static str] = &["0001_initial"];
        fn up() -> Vec<Operation> {
            vec![Operation::add_column(
                "users",
                varchar("email", 255).build(),
            )]
        }
        fn down() -> Vec<Operation> {
            vec![Operation::drop_column("users", "email")]
        }
    }

    struct Migration0003;
    impl Migration for Migration0003 {
        const ID: &'static str = "0003_add_active";
        const DEPENDENCIES: &'static [&'static str] = &["0002_add_email"];
        fn up() -> Vec<Operation> {
            vec![Operation::add_column(
                "users",
                boolean("active").not_null().default_bool(true).build(),
            )]
        }
        fn down() -> Vec<Operation> {
            vec![Operation::drop_column("users", "active")]
        }
    }

    // Irreversible migration (empty `down`) that depends on 0001.
    struct Migration0004;
    impl Migration for Migration0004 {
        const ID: &'static str = "0004_drop_legacy";
        const DEPENDENCIES: &'static [&'static str] = &["0001_initial"];
        fn up() -> Vec<Operation> {
            vec![Operation::drop_table("legacy")]
        }
        fn down() -> Vec<Operation> {
            vec![]
        }
    }

    // Two migrations that depend on each other.
    struct CycleA;
    impl Migration for CycleA {
        const ID: &'static str = "cycle_a";
        const DEPENDENCIES: &'static [&'static str] = &["cycle_b"];
        fn up() -> Vec<Operation> {
            vec![]
        }
        fn down() -> Vec<Operation> {
            vec![]
        }
    }

    struct CycleB;
    impl Migration for CycleB {
        const ID: &'static str = "cycle_b";
        const DEPENDENCIES: &'static [&'static str] = &["cycle_a"];
        fn up() -> Vec<Operation> {
            vec![]
        }
        fn down() -> Vec<Operation> {
            vec![]
        }
    }

    #[test]
    fn test_register_migrations() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        assert_eq!(runner.migrations().len(), 2);
    }

    #[test]
    fn test_registered_migration_defaults() {
        let registered = RegisteredMigration::new::<Migration0001>();
        assert_eq!(registered.id, "0001_initial");
        assert!(registered.dependencies.is_empty());

        let dependent = RegisteredMigration::new::<Migration0002>();
        assert_eq!(dependent.dependencies, &["0001_initial"]);
    }

    #[test]
    fn test_pending_migrations() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        let state = MigrationState::new();
        let pending = runner.pending_migrations(&state);
        assert_eq!(pending.len(), 2);

        let mut state = MigrationState::new();
        state.mark_applied("0001_initial");
        let pending = runner.pending_migrations(&state);
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].id, "0002_add_email");
    }

    #[test]
    fn test_topological_sort() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        // Register in reverse order
        runner.register::<Migration0003>();
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        let sorted = runner.sorted_migrations().unwrap();
        let ids: Vec<_> = sorted.iter().map(|m| m.id).collect();

        // 0001 must come before 0002, 0002 must come before 0003
        let pos_0001 = ids.iter().position(|&id| id == "0001_initial").unwrap();
        let pos_0002 = ids.iter().position(|&id| id == "0002_add_email").unwrap();
        let pos_0003 = ids.iter().position(|&id| id == "0003_add_active").unwrap();

        assert!(pos_0001 < pos_0002);
        assert!(pos_0002 < pos_0003);
    }

    #[test]
    fn test_circular_dependency() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<CycleA>();
        runner.register::<CycleB>();

        assert_eq!(
            runner.sorted_migrations().err(),
            Some(MigrationError::CircularDependency)
        );
        assert_eq!(runner.validate(), Err(MigrationError::CircularDependency));
    }

    #[test]
    fn test_sql_generation() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();

        let state = MigrationState::new();
        let sql = runner.sql_for_pending(&state).unwrap();

        assert_eq!(sql.len(), 1);
        assert_eq!(sql[0].0, "0001_initial");
        assert!(!sql[0].1.is_empty());
        assert!(sql[0].1[0].contains("CREATE TABLE"));
    }

    #[test]
    fn test_sql_for_pending_skips_applied_in_dependency_order() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0003>();
        runner.register::<Migration0002>();
        runner.register::<Migration0001>();

        let mut state = MigrationState::new();
        state.mark_applied("0001_initial");

        let sql = runner.sql_for_pending(&state).unwrap();
        let ids: Vec<_> = sql.iter().map(|(id, _)| *id).collect();
        assert_eq!(ids, vec!["0002_add_email", "0003_add_active"]);
        assert!(sql.iter().all(|(_, stmts)| !stmts.is_empty()));
    }

    #[test]
    fn test_rollback_sql() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        let mut state = MigrationState::new();
        state.mark_applied("0001_initial");
        state.mark_applied("0002_add_email");

        let sql = runner.sql_for_rollback(&state, 1).unwrap();
        assert_eq!(sql.len(), 1);
        assert_eq!(sql[0].0, "0002_add_email");
        assert!(sql[0].1[0].contains("DROP COLUMN"));
    }

    #[test]
    fn test_rollback_order_and_count() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();
        runner.register::<Migration0003>();

        let mut state = MigrationState::new();
        state.mark_applied("0001_initial");
        state.mark_applied("0002_add_email");
        state.mark_applied("0003_add_active");

        // A count larger than the number of applied migrations rolls back all of them,
        // newest dependency first.
        let sql = runner.sql_for_rollback(&state, 10).unwrap();
        let ids: Vec<_> = sql.iter().map(|(id, _)| *id).collect();
        assert_eq!(
            ids,
            vec!["0003_add_active", "0002_add_email", "0001_initial"]
        );

        assert!(runner.sql_for_rollback(&state, 0).unwrap().is_empty());
    }

    #[test]
    fn test_rollback_not_reversible() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0004>();

        let mut state = MigrationState::new();
        state.mark_applied("0001_initial");
        state.mark_applied("0004_drop_legacy");

        assert_eq!(
            runner.sql_for_rollback(&state, 1),
            Err(MigrationError::NotReversible("0004_drop_legacy".to_string()))
        );
    }

    #[test]
    fn test_missing_dependency() {
        struct BadMigration;
        impl Migration for BadMigration {
            const ID: &'static str = "bad_migration";
            const DEPENDENCIES: &'static [&'static str] = &["nonexistent"];
            fn up() -> Vec<Operation> {
                vec![]
            }
            fn down() -> Vec<Operation> {
                vec![]
            }
        }

        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<BadMigration>();

        let result = runner.validate();
        assert!(matches!(
            result,
            Err(MigrationError::MissingDependency { .. })
        ));
    }

    #[test]
    fn test_status() {
        let mut runner = MigrationRunner::new(SqliteDialect::new());
        runner.register::<Migration0001>();
        runner.register::<Migration0002>();

        let mut state = MigrationState::new();
        state.mark_applied("0001_initial");

        let status = runner.status(&state);
        assert_eq!(status.len(), 2);

        let s1 = status.iter().find(|s| s.id == "0001_initial").unwrap();
        assert!(s1.applied);

        let s2 = status.iter().find(|s| s.id == "0002_add_email").unwrap();
        assert!(!s2.applied);
    }
}