// db_migrate/migration.rs

use std::collections::{HashMap, HashSet};
use std::path::PathBuf;

use anyhow::Result;
use chrono::{TimeZone, Utc};
use scylla::frame::value::CqlTimestamp;
use scylla::{Session, SessionBuilder};
use tokio::fs;
use tracing::{debug, info, warn};
use walkdir::WalkDir;

use crate::{
    config::Config,
    utils::{calculate_checksum, extract_version_from_filename, parse_migration_content},
    MigrationError, MigrationFile, MigrationRecord,
};

16/// Main migration manager that handles all migration operations
17pub struct MigrationManager {
18    session: Session,
19    config: Config,
20}
21
22impl MigrationManager {
23    /// Create a new migration manager and establish database connection
24    pub async fn new(config: Config) -> Result<Self, MigrationError> {
25        info!("Connecting to ScyllaDB at: {:?}", config.database.hosts);
26
27        let mut session_builder = SessionBuilder::new().known_nodes(&config.database.hosts);
28
29        if !config.database.username.is_empty() {
30            session_builder =
31                session_builder.user(&config.database.username, &config.database.password);
32        }
33
34        let session = session_builder.build().await?;
35
36        let manager = Self { session, config };
37
38        // Ensure keyspace and migrations table exist
39        manager.initialize_schema().await?;
40
41        Ok(manager)
42    }
43
44    /// Initialize the keyspace and migrations tracking table
45    async fn initialize_schema(&self) -> Result<(), MigrationError> {
46        // Create keyspace if it doesn't exist and auto_create is enabled
47        if self.config.behavior.auto_create_keyspace {
48            let create_keyspace_query = format!(
49                "CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class': 'SimpleStrategy', 'replication_factor': 1}}",
50                self.config.database.keyspace
51            );
52
53            debug!("Creating keyspace: {}", create_keyspace_query);
54            self.session.query(create_keyspace_query, &[]).await?;
55        }
56
57        // Use the keyspace
58        let use_keyspace_query = format!("USE {}", self.config.database.keyspace);
59        self.session.query(use_keyspace_query, &[]).await?;
60
61        // Create migrations table
62        let create_table_query = format!(
63            "CREATE TABLE IF NOT EXISTS {} (
64                version TEXT PRIMARY KEY,
65                applied_at TIMESTAMP,
66                checksum TEXT,
67                description TEXT
68            )",
69            self.config.migrations.table_name
70        );
71
72        debug!("Creating migrations table: {}", create_table_query);
73        self.session.query(create_table_query, &[]).await?;
74
75        info!("Schema initialization completed");
76        Ok(())
77    }
78
79    /// Get all applied migrations from the database
80    pub async fn get_applied_migrations(&self) -> Result<Vec<MigrationRecord>, MigrationError> {
81        let query = format!(
82            "SELECT version, applied_at, checksum, description FROM {}",
83            self.config.migrations.table_name
84        );
85
86        let rows = self.session.query(query, &[]).await?;
87        let mut migrations = Vec::new();
88
89        for row in rows
90            .rows_typed::<(String, CqlTimestamp, String, String)>()
91            .map_err(|e| MigrationError::IntegrityError(e.to_string()))?
92        {
93            let (version, applied_at_ts, checksum, description) =
94                row.map_err(|e| MigrationError::IntegrityError(e.to_string()))?;
95
96            let applied_at = Utc
97                .timestamp_millis_opt(applied_at_ts.0)
98                .single()
99                .ok_or_else(|| MigrationError::IntegrityError("Invalid timestamp".into()))?;
100
101            migrations.push(MigrationRecord {
102                version,
103                applied_at,
104                checksum,
105                description,
106            });
107        }
108
109        Ok(migrations)
110    }
111
112    /// Get all migration files from the filesystem
113    pub async fn get_migration_files(&self) -> Result<Vec<MigrationFile>, MigrationError> {
114        let migrations_dir = &self.config.migrations.directory;
115
116        if !migrations_dir.exists() {
117            fs::create_dir_all(migrations_dir).await?;
118            return Ok(Vec::new());
119        }
120
121        let mut files = Vec::new();
122
123        for entry in WalkDir::new(migrations_dir)
124            .min_depth(1)
125            .max_depth(1)
126            .sort_by_file_name()
127        {
128            let entry = entry.map_err(|e| MigrationError::ConfigError(e.to_string()))?;
129            let path = entry.path();
130
131            if path.extension().and_then(|s| s.to_str()) != Some("cql") {
132                continue;
133            }
134
135            let filename = path
136                .file_name()
137                .and_then(|n| n.to_str())
138                .ok_or_else(|| MigrationError::InvalidFormat("Invalid filename".to_string()))?;
139
140            if let Some(version) = extract_version_from_filename(filename) {
141                let content = fs::read_to_string(path).await?;
142                let checksum = calculate_checksum(&content);
143                let description = crate::utils::extract_description_from_filename(filename);
144
145                files.push(MigrationFile {
146                    version,
147                    description,
148                    file_path: path.to_path_buf(),
149                    content,
150                    checksum,
151                });
152            } else {
153                warn!("Skipping file with invalid format: {}", filename);
154            }
155        }
156
157        Ok(files)
158    }
159
160    /// Get pending migrations (files that haven't been applied)
161    pub async fn get_pending_migrations(&self) -> Result<Vec<MigrationFile>, MigrationError> {
162        let applied = self.get_applied_migrations().await?;
163        let files = self.get_migration_files().await?;
164
165        let applied_versions: std::collections::HashSet<String> =
166            applied.into_iter().map(|m| m.version).collect();
167
168        let pending: Vec<MigrationFile> = files
169            .into_iter()
170            .filter(|f| !applied_versions.contains(&f.version))
171            .collect();
172
173        Ok(pending)
174    }
175
176    /// Apply a single migration
177    pub async fn apply_migration(
178        &mut self,
179        migration: &MigrationFile,
180    ) -> Result<(), MigrationError> {
181        info!("Applying migration: {}", migration.version);
182
183        // Check if already applied
184        if self.is_migration_applied(&migration.version).await? {
185            return Err(MigrationError::AlreadyApplied {
186                version: migration.version.clone(),
187            });
188        }
189
190        // Parse migration content
191        let (up_content, _down_content) = parse_migration_content(&migration.content)
192            .map_err(|e| MigrationError::InvalidFormat(e))?;
193
194        // Execute UP statements
195        for statement in split_cql_statements(&up_content) {
196            if !statement.trim().is_empty() {
197                debug!("Executing: {}", statement.trim());
198                self.session.query(statement, &[]).await?;
199            }
200        }
201
202        // Record the migration as applied
203        self.record_migration_applied(migration).await?;
204
205        info!("✅ Applied migration: {}", migration.version);
206        Ok(())
207    }
208
209    /// Rollback a single migration
210    pub async fn rollback_migration(&mut self, version: &str) -> Result<(), MigrationError> {
211        info!("Rolling back migration: {}", version);
212
213        // Check if migration is applied
214        if !self.is_migration_applied(version).await? {
215            return Err(MigrationError::MigrationNotFound(version.to_string()));
216        }
217
218        // Find the migration file
219        let files = self.get_migration_files().await?;
220        let migration_file = files
221            .iter()
222            .find(|f| f.version == version)
223            .ok_or_else(|| MigrationError::MigrationNotFound(version.to_string()))?;
224
225        // Parse migration content
226        let (_up_content, down_content) = parse_migration_content(&migration_file.content)
227            .map_err(|e| MigrationError::InvalidFormat(e))?;
228
229        let down_content = down_content.ok_or_else(|| MigrationError::RollbackError {
230            version: version.to_string(),
231            reason: "No DOWN section found in migration".to_string(),
232        })?;
233
234        // Execute DOWN statements
235        for statement in split_cql_statements(&down_content) {
236            if !statement.trim().is_empty() {
237                debug!("Executing rollback: {}", statement.trim());
238                self.session.query(statement, &[]).await?;
239            }
240        }
241
242        // Remove the migration record
243        self.remove_migration_record(version).await?;
244
245        info!("✅ Rolled back migration: {}", version);
246        Ok(())
247    }
248
249    /// Check if a migration is already applied
250    pub async fn is_migration_applied(&self, version: &str) -> Result<bool, MigrationError> {
251        let query = format!(
252            "SELECT version FROM {} WHERE version = ? LIMIT 1",
253            self.config.migrations.table_name
254        );
255
256        let rows = self.session.query(query, (version,)).await?;
257        Ok(!rows.rows.unwrap_or_default().is_empty())
258    }
259
260    /// Record a migration as applied
261    async fn record_migration_applied(
262        &self,
263        migration: &MigrationFile,
264    ) -> Result<(), MigrationError> {
265        let query = format!(
266            "INSERT INTO {} (version, applied_at, checksum, description) VALUES (?, ?, ?, ?)",
267            self.config.migrations.table_name
268        );
269
270        self.session
271            .query(
272                query,
273                (
274                    &migration.version,
275                    CqlTimestamp(Utc::now().timestamp_millis()),
276                    &migration.checksum,
277                    &migration.description,
278                ),
279            )
280            .await?;
281
282        Ok(())
283    }
284
285    /// Remove a migration record
286    pub(crate) async fn remove_migration_record(
287        &self,
288        version: &str,
289    ) -> Result<(), MigrationError> {
290        let query = format!(
291            "DELETE FROM {} WHERE version = ?",
292            self.config.migrations.table_name
293        );
294
295        self.session.query(query, (version,)).await?;
296        Ok(())
297    }
298
299    /// Verify migration integrity (check checksums)
300    pub async fn verify_migrations(&self) -> Result<Vec<MigrationError>, MigrationError> {
301        let applied = self.get_applied_migrations().await?;
302        let files = self.get_migration_files().await?;
303
304        let file_map: HashMap<String, &MigrationFile> =
305            files.iter().map(|f| (f.version.clone(), f)).collect();
306
307        let mut errors = Vec::new();
308
309        for applied_migration in applied {
310            if let Some(file) = file_map.get(&applied_migration.version) {
311                if file.checksum != applied_migration.checksum {
312                    errors.push(MigrationError::ChecksumMismatch {
313                        version: applied_migration.version,
314                        expected: applied_migration.checksum,
315                        actual: file.checksum.clone(),
316                    });
317                }
318            } else {
319                errors.push(MigrationError::MigrationNotFound(applied_migration.version));
320            }
321        }
322
323        Ok(errors)
324    }
325
326    /// Reset all migrations (destructive operation)
327    pub async fn reset_migrations(&mut self) -> Result<(), MigrationError> {
328        if !self.config.behavior.allow_destructive {
329            return Err(MigrationError::ConfigError(
330                "Destructive operations are disabled in configuration".to_string(),
331            ));
332        }
333
334        warn!("Resetting all migrations - this is destructive!");
335
336        // Drop and recreate the migrations table
337        let drop_query = format!("DROP TABLE IF EXISTS {}", self.config.migrations.table_name);
338        self.session.query(drop_query, &[]).await?;
339
340        self.initialize_schema().await?;
341
342        info!("✅ All migrations reset");
343        Ok(())
344    }
345
346    /// Get the configuration
347    pub fn get_config(&self) -> &Config {
348        &self.config
349    }
350
351    /// Update the checksum of an existing migration record
352    pub async fn update_migration_checksum(
353        &self,
354        version: &str,
355        new_checksum: &str,
356    ) -> Result<(), MigrationError> {
357        let query = format!(
358            "UPDATE {} SET checksum = ? WHERE version = ?",
359            self.config.migrations.table_name
360        );
361
362        self.session.query(query, (new_checksum, version)).await?;
363        Ok(())
364    }
365
366    /// Create a new migration file
367    pub async fn create_migration_file(
368        &self,
369        description: &str,
370    ) -> Result<PathBuf, MigrationError> {
371        let filename = crate::utils::create_migration_filename(description);
372        let file_path = self.config.migrations.directory.join(&filename);
373
374        // Ensure migrations directory exists
375        if let Some(parent) = file_path.parent() {
376            fs::create_dir_all(parent).await?;
377        }
378
379        // Generate template content
380        let content = crate::utils::generate_migration_template(description);
381
382        // Write the file
383        fs::write(&file_path, content).await?;
384
385        info!("✅ Created migration file: {}", filename);
386        Ok(file_path)
387    }
388}
389
/// Split CQL content into individual statements.
///
/// A `;` ends a statement only when it appears in plain CQL. Semicolons in the following
/// places do not split:
/// - single-quoted string literals (`'a;b'`, with `''` as an escaped quote);
/// - double-quoted identifiers;
/// - `$$`-delimited strings;
/// - `--` / `//` line comments and `/* */` block comments.
///
/// Each returned statement is trimmed and keeps any comments it contains. Fragments that
/// are blank or contain only comments (e.g. a trailing `-- end` after the final `;`) are
/// dropped, since sending them to the server would be a syntax error.
fn split_cql_statements(content: &str) -> Vec<String> {
    /// Lexical context of the scanner at the current character.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Context {
        Code,
        SingleQuoted,
        DoubleQuoted,
        DollarQuoted,
        LineComment,
        BlockComment,
    }

    /// Push `buffer` (trimmed) onto `out` if it held real CQL, then reset the buffer.
    fn flush(buffer: &mut String, has_code: &mut bool, out: &mut Vec<String>) {
        if *has_code {
            out.push(buffer.trim().to_string());
        }
        buffer.clear();
        *has_code = false;
    }

    let mut statements = Vec::new();
    let mut current = String::new();
    // Becomes true once `current` holds a character outside comments and whitespace.
    let mut has_code = false;
    let mut context = Context::Code;
    let mut chars = content.chars().peekable();

    while let Some(c) = chars.next() {
        let next = chars.peek().copied();
        match context {
            Context::Code => match (c, next) {
                (';', _) => {
                    flush(&mut current, &mut has_code, &mut statements);
                    continue;
                }
                ('-', Some('-')) | ('/', Some('/')) => context = Context::LineComment,
                ('/', Some('*')) => {
                    // Consume both characters so `/*/` is not read as open-then-close.
                    current.push_str("/*");
                    chars.next();
                    context = Context::BlockComment;
                    continue;
                }
                ('$', Some('$')) => {
                    current.push_str("$$");
                    chars.next();
                    has_code = true;
                    context = Context::DollarQuoted;
                    continue;
                }
                ('\'', _) => {
                    has_code = true;
                    context = Context::SingleQuoted;
                }
                ('"', _) => {
                    has_code = true;
                    context = Context::DoubleQuoted;
                }
                _ => {
                    if !c.is_whitespace() {
                        has_code = true;
                    }
                }
            },
            Context::SingleQuoted | Context::DoubleQuoted => {
                let quote = if context == Context::SingleQuoted { '\'' } else { '"' };
                if c == quote {
                    if next == Some(quote) {
                        // A doubled quote is an escaped quote, not the terminator.
                        current.push(c);
                        current.push(quote);
                        chars.next();
                        continue;
                    }
                    context = Context::Code;
                }
            }
            Context::DollarQuoted => {
                if c == '$' && next == Some('$') {
                    current.push_str("$$");
                    chars.next();
                    context = Context::Code;
                    continue;
                }
            }
            Context::LineComment => {
                if c == '\n' {
                    context = Context::Code;
                }
            }
            Context::BlockComment => {
                if c == '*' && next == Some('/') {
                    current.push_str("*/");
                    chars.next();
                    context = Context::Code;
                    continue;
                }
            }
        }
        current.push(c);
    }

    flush(&mut current, &mut has_code, &mut statements);
    statements
}