oxidite_db/
migrations.rs

1//! Database migration system
2
3use std::path::{Path, PathBuf};
4use std::fs;
5use chrono::Utc;
6use sqlx::Row;
7
/// A single schema migration: a version identifier plus the SQL for its
/// `-- migrate:up` and `-- migrate:down` sections.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Identifier used for ordering and for tracking applied migrations; it is also
    /// the file stem on disk (`<timestamp>_<name>` for migrations created via `new`).
    pub version: String,
    /// Human-readable name (the part of the file stem after the first `_`).
    pub name: String,
    /// SQL of the `-- migrate:up` section (applies the migration).
    pub up_sql: String,
    /// SQL of the `-- migrate:down` section (reverts the migration).
    pub down_sql: String,
}
16
17impl Migration {
18    pub fn new(name: &str) -> Self {
19        let timestamp = Utc::now().format("%Y%m%d%H%M%S").to_string();
20        let version = format!("{}_{}", timestamp, name);
21        
22        Self {
23            version,
24            name: name.to_string(),
25            up_sql: String::new(),
26            down_sql: String::new(),
27        }
28    }
29    
30    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
31        let path = path.as_ref();
32        let content = fs::read_to_string(path)?;
33        
34        let filename = path.file_stem()
35            .and_then(|s| s.to_str())
36            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid filename"))?;
37        
38        // Parse filename: 20240101120000_create_users
39        let parts: Vec<&str> = filename.splitn(2, '_').collect();
40        if parts.len() != 2 {
41            return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid migration filename"));
42        }
43        
44        let version = filename.to_string();
45        let name = parts[1].to_string();
46        
47        // Split content into up/down SQL
48        let sections: Vec<&str> = content.split("-- migrate:down").collect();
49        let up_sql = sections.get(0)
50            .unwrap_or(&"")
51            .replace("-- migrate:up", "")
52            .trim()
53            .to_string();
54        let down_sql = sections.get(1)
55            .unwrap_or(&"")
56            .trim()
57            .to_string();
58        
59        Ok(Self {
60            version,
61            name,
62            up_sql,
63            down_sql,
64        })
65    }
66    
67    pub fn save(&self, migrations_dir: impl AsRef<Path>) -> Result<PathBuf, std::io::Error> {
68        let migrations_dir = migrations_dir.as_ref();
69        fs::create_dir_all(migrations_dir)?;
70        
71        let filename = format!("{}.sql", self.version);
72        let path = migrations_dir.join(filename);
73        
74        let content = format!(
75            "-- migrate:up\n{}\n\n-- migrate:down\n{}\n",
76            self.up_sql,
77            self.down_sql
78        );
79        
80        fs::write(&path, content)?;
81        Ok(path)
82    }
83}
84
/// Discovers and creates migration files in one directory, and records which
/// migrations have been applied in the database's `_migrations` table.
#[derive(Debug, Clone)]
pub struct MigrationManager {
    /// Directory holding the `<version>.sql` migration files.
    migrations_dir: PathBuf,
}
89
90impl MigrationManager {
91    pub fn new(migrations_dir: impl AsRef<Path>) -> Self {
92        Self {
93            migrations_dir: migrations_dir.as_ref().to_path_buf(),
94        }
95    }
96    
97    pub fn list_migrations(&self) -> Result<Vec<Migration>, std::io::Error> {
98        let mut migrations = Vec::new();
99        
100        if !self.migrations_dir.exists() {
101            return Ok(migrations);
102        }
103        
104        for entry in fs::read_dir(&self.migrations_dir)? {
105            let entry = entry?;
106            let path = entry.path();
107            
108            if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("sql") {
109                if let Ok(migration) = Migration::from_file(&path) {
110                    migrations.push(migration);
111                }
112            }
113        }
114        
115        // Sort by version
116        migrations.sort_by(|a, b| a.version.cmp(&b.version));
117        
118        Ok(migrations)
119    }
120    
121    pub fn create_migration(&self, name: &str) -> Result<PathBuf, std::io::Error> {
122        let migration = Migration::new(name);
123        migration.save(&self.migrations_dir)
124    }
125    
126    /// Ensure migrations table exists
127    pub async fn ensure_migrations_table(&self, db: &impl crate::Database) -> crate::Result<()> {
128        let sql = r#"
129            CREATE TABLE IF NOT EXISTS _migrations (
130                id INTEGER PRIMARY KEY AUTOINCREMENT,
131                version TEXT NOT NULL UNIQUE,
132                applied_at INTEGER NOT NULL
133            )
134        "#;
135        db.execute(sql).await?;
136        Ok(())
137    }
138    
139    /// Get list of applied migrations
140    pub async fn get_applied_migrations(&self, db: &impl crate::Database) -> crate::Result<Vec<String>> {
141        self.ensure_migrations_table(db).await?;
142        
143        let rows = db.query("SELECT version FROM _migrations ORDER BY version").await?;
144        let mut versions = Vec::new();
145        
146        for row in rows {
147            if let Ok(version) = row.try_get::<String, _>("version") {
148                versions.push(version);
149            }
150        }
151        
152        Ok(versions)
153    }
154    
155    /// Mark migration as applied
156    pub async fn mark_migration_applied(&self, db: &impl crate::Database, version: &str) -> crate::Result<()> {
157        self.ensure_migrations_table(db).await?;
158        
159        let timestamp = chrono::Utc::now().timestamp();
160        let sql = format!(
161            "INSERT INTO _migrations (version, applied_at) VALUES ('{}', {})",
162            version, timestamp
163        );
164        db.execute(&sql).await?;
165        Ok(())
166    }
167    
168    /// Remove migration record (for rollback)
169    pub async fn mark_migration_reverted(&self, db: &impl crate::Database, version: &str) -> crate::Result<()> {
170        let sql = format!("DELETE FROM _migrations WHERE version = '{}'", version);
171        db.execute(&sql).await?;
172        Ok(())
173    }
174    
175    /// Get pending migrations
176    pub async fn get_pending_migrations(&self, db: &impl crate::Database) -> Result<Vec<Migration>, Box<dyn std::error::Error>> {
177        let all_migrations = self.list_migrations()?;
178        let applied = self.get_applied_migrations(db).await?;
179        
180        let pending: Vec<Migration> = all_migrations
181            .into_iter()
182            .filter(|m| !applied.contains(&m.version))
183            .collect();
184        
185        Ok(pending)
186    }
187}