lib_migrations_sql/
migration.rs

1use lib_migrations_core::{Migration, Phase};
2
/// Minimal capability a database handle must offer to run [`SqlMigration`]s.
///
/// Implement it for whatever connection type your application already uses;
/// nothing beyond raw SQL execution is required.
pub trait SqlExecutor {
    /// Error type reported when running SQL fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Runs `sql`, which may contain one or more statements.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if execution fails.
    fn execute(&mut self, sql: &str) -> std::result::Result<(), Self::Error>;
}
12
13/// A migration defined by SQL strings.
14///
15/// Works with any context that implements `SqlExecutor`.
16pub struct SqlMigration {
17    version: u64,
18    name: String,
19    phase: Phase,
20    up_sql: String,
21    down_sql: Option<String>,
22}
23
24impl SqlMigration {
25    /// Create a new SQL migration
26    pub fn new(version: u64, name: impl Into<String>, up_sql: impl Into<String>) -> Self {
27        Self {
28            version,
29            name: name.into(),
30            phase: Phase::PreDeploy,
31            up_sql: up_sql.into(),
32            down_sql: None,
33        }
34    }
35
36    /// Set the deployment phase
37    pub fn phase(mut self, phase: Phase) -> Self {
38        self.phase = phase;
39        self
40    }
41
42    /// Add rollback SQL
43    pub fn with_down(mut self, down_sql: impl Into<String>) -> Self {
44        self.down_sql = Some(down_sql.into());
45        self
46    }
47
48    /// Get the up SQL
49    pub fn up_sql(&self) -> &str {
50        &self.up_sql
51    }
52
53    /// Get the down SQL
54    pub fn down_sql(&self) -> Option<&str> {
55        self.down_sql.as_deref()
56    }
57
58    /// Get the version
59    pub fn version(&self) -> u64 {
60        self.version
61    }
62
63    /// Get the name
64    pub fn name(&self) -> &str {
65        &self.name
66    }
67
68    /// Get the deployment phase
69    pub fn get_phase(&self) -> Phase {
70        self.phase
71    }
72
73    /// Whether this migration has rollback SQL
74    pub fn has_rollback(&self) -> bool {
75        self.down_sql.is_some()
76    }
77}
78
79impl<Ctx> Migration<Ctx> for SqlMigration
80where
81    Ctx: SqlExecutor,
82{
83    fn version(&self) -> u64 {
84        self.version
85    }
86
87    fn name(&self) -> &str {
88        &self.name
89    }
90
91    fn phase(&self) -> Phase {
92        self.phase
93    }
94
95    fn apply(&self, ctx: &mut Ctx) -> lib_migrations_core::Result<()> {
96        ctx.execute(&self.up_sql)
97            .map_err(|e| lib_migrations_core::Error::failed(self.version, e.to_string()))
98    }
99
100    fn rollback(&self, ctx: &mut Ctx) -> lib_migrations_core::Result<()> {
101        match &self.down_sql {
102            Some(sql) => ctx
103                .execute(sql)
104                .map_err(|e| lib_migrations_core::Error::failed(self.version, e.to_string())),
105            None => Err(lib_migrations_core::Error::RollbackNotSupported(
106                self.version,
107            )),
108        }
109    }
110
111    fn can_rollback(&self) -> bool {
112        self.down_sql.is_some()
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Error returned by [`RecordingExecutor`] for statements it refuses to run.
    #[derive(Debug)]
    struct RejectedStatement;

    impl std::fmt::Display for RejectedStatement {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("statement rejected")
        }
    }

    impl std::error::Error for RejectedStatement {}

    /// Test double: records every statement it runs and rejects any statement
    /// containing `FAIL`, so the `Migration` impl can be exercised without a
    /// real database.
    #[derive(Default)]
    struct RecordingExecutor {
        statements: Vec<String>,
    }

    impl SqlExecutor for RecordingExecutor {
        type Error = RejectedStatement;

        fn execute(&mut self, sql: &str) -> std::result::Result<(), Self::Error> {
            if sql.contains("FAIL") {
                return Err(RejectedStatement);
            }
            self.statements.push(sql.to_owned());
            Ok(())
        }
    }

    #[test]
    fn test_sql_migration() {
        let migration = SqlMigration::new(1, "create_users", "CREATE TABLE users (id INTEGER)")
            .with_down("DROP TABLE users");

        assert_eq!(migration.version(), 1);
        assert_eq!(migration.name(), "create_users");
        assert_eq!(migration.up_sql(), "CREATE TABLE users (id INTEGER)");
        assert_eq!(migration.down_sql(), Some("DROP TABLE users"));
        assert!(migration.has_rollback());
    }

    #[test]
    fn test_sql_migration_no_rollback() {
        let migration = SqlMigration::new(1, "create_users", "CREATE TABLE users (id INTEGER)");

        assert!(!migration.has_rollback());
        assert_eq!(migration.down_sql(), None);
    }

    #[test]
    fn test_sql_migration_phase() {
        let pre = SqlMigration::new(1, "add_column", "ALTER TABLE users ADD email TEXT");
        assert_eq!(pre.get_phase(), Phase::PreDeploy);

        let post = SqlMigration::new(2, "drop_column", "ALTER TABLE users DROP old_column")
            .phase(Phase::PostDeploy);
        assert_eq!(post.get_phase(), Phase::PostDeploy);
    }

    #[test]
    fn test_apply_and_rollback_execute_sql() {
        let migration = SqlMigration::new(1, "create_users", "CREATE TABLE users (id INTEGER)")
            .with_down("DROP TABLE users");
        let mut executor = RecordingExecutor::default();

        assert!(migration.apply(&mut executor).is_ok());
        assert!(migration.rollback(&mut executor).is_ok());
        assert!(<SqlMigration as Migration<RecordingExecutor>>::can_rollback(&migration));
        assert_eq!(
            executor.statements,
            ["CREATE TABLE users (id INTEGER)", "DROP TABLE users"]
        );
    }

    #[test]
    fn test_apply_reports_executor_failure() {
        let migration = SqlMigration::new(2, "broken", "FAIL");
        let mut executor = RecordingExecutor::default();

        assert!(migration.apply(&mut executor).is_err());
        assert!(executor.statements.is_empty());
    }

    #[test]
    fn test_rollback_without_down_sql() {
        let migration = SqlMigration::new(3, "irreversible", "DROP TABLE legacy");
        let mut executor = RecordingExecutor::default();

        let result = migration.rollback(&mut executor);
        assert!(matches!(
            result,
            Err(lib_migrations_core::Error::RollbackNotSupported(3))
        ));
        // Nothing may be executed when there is no rollback SQL.
        assert!(executor.statements.is_empty());
        assert!(!<SqlMigration as Migration<RecordingExecutor>>::can_rollback(&migration));
    }
}