1use crate::{
2 config::Config,
3 utils::{calculate_checksum, extract_version_from_filename, parse_migration_content},
4 MigrationError, MigrationFile, MigrationRecord,
5};
6use anyhow::Result;
7use chrono::{TimeZone, Utc};
8use scylla::{Session, SessionBuilder};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use scylla::frame::value::CqlTimestamp;
12use tokio::fs;
13use tracing::{debug, info, warn};
14use walkdir::WalkDir;
15
16pub struct MigrationManager {
18 session: Session,
19 config: Config,
20}
21
22impl MigrationManager {
23 pub async fn new(config: Config) -> Result<Self, MigrationError> {
25 info!("Connecting to ScyllaDB at: {:?}", config.database.hosts);
26
27 let mut session_builder = SessionBuilder::new().known_nodes(&config.database.hosts);
28
29 if !config.database.username.is_empty() {
30 session_builder =
31 session_builder.user(&config.database.username, &config.database.password);
32 }
33
34 let session = session_builder.build().await?;
35
36 let manager = Self { session, config };
37
38 manager.initialize_schema().await?;
40
41 Ok(manager)
42 }
43
44 async fn initialize_schema(&self) -> Result<(), MigrationError> {
46 if self.config.behavior.auto_create_keyspace {
48 let create_keyspace_query = format!(
49 "CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class': 'SimpleStrategy', 'replication_factor': 1}}",
50 self.config.database.keyspace
51 );
52
53 debug!("Creating keyspace: {}", create_keyspace_query);
54 self.session.query(create_keyspace_query, &[]).await?;
55 }
56
57 let use_keyspace_query = format!("USE {}", self.config.database.keyspace);
59 self.session.query(use_keyspace_query, &[]).await?;
60
61 let create_table_query = format!(
63 "CREATE TABLE IF NOT EXISTS {} (
64 version TEXT PRIMARY KEY,
65 applied_at TIMESTAMP,
66 checksum TEXT,
67 description TEXT
68 )",
69 self.config.migrations.table_name
70 );
71
72 debug!("Creating migrations table: {}", create_table_query);
73 self.session.query(create_table_query, &[]).await?;
74
75 info!("Schema initialization completed");
76 Ok(())
77 }
78
79 pub async fn get_applied_migrations(&self) -> Result<Vec<MigrationRecord>, MigrationError> {
81 let query = format!(
82 "SELECT version, applied_at, checksum, description FROM {}",
83 self.config.migrations.table_name
84 );
85
86 let rows = self.session.query(query, &[]).await?;
87 let mut migrations = Vec::new();
88
89 for row in rows
90 .rows_typed::<(String, CqlTimestamp, String, String)>()
91 .map_err(|e| MigrationError::IntegrityError(e.to_string()))?
92 {
93 let (version, applied_at_ts, checksum, description) =
94 row.map_err(|e| MigrationError::IntegrityError(e.to_string()))?;
95
96 let applied_at = Utc
97 .timestamp_millis_opt(applied_at_ts.0)
98 .single()
99 .ok_or_else(|| MigrationError::IntegrityError("Invalid timestamp".into()))?;
100
101 migrations.push(MigrationRecord {
102 version,
103 applied_at,
104 checksum,
105 description,
106 });
107 }
108
109 Ok(migrations)
110 }
111
112 pub async fn get_migration_files(&self) -> Result<Vec<MigrationFile>, MigrationError> {
114 let migrations_dir = &self.config.migrations.directory;
115
116 if !migrations_dir.exists() {
117 fs::create_dir_all(migrations_dir).await?;
118 return Ok(Vec::new());
119 }
120
121 let mut files = Vec::new();
122
123 for entry in WalkDir::new(migrations_dir)
124 .min_depth(1)
125 .max_depth(1)
126 .sort_by_file_name()
127 {
128 let entry = entry.map_err(|e| MigrationError::ConfigError(e.to_string()))?;
129 let path = entry.path();
130
131 if path.extension().and_then(|s| s.to_str()) != Some("cql") {
132 continue;
133 }
134
135 let filename = path
136 .file_name()
137 .and_then(|n| n.to_str())
138 .ok_or_else(|| MigrationError::InvalidFormat("Invalid filename".to_string()))?;
139
140 if let Some(version) = extract_version_from_filename(filename) {
141 let content = fs::read_to_string(path).await?;
142 let checksum = calculate_checksum(&content);
143 let description = crate::utils::extract_description_from_filename(filename);
144
145 files.push(MigrationFile {
146 version,
147 description,
148 file_path: path.to_path_buf(),
149 content,
150 checksum,
151 });
152 } else {
153 warn!("Skipping file with invalid format: {}", filename);
154 }
155 }
156
157 Ok(files)
158 }
159
160 pub async fn get_pending_migrations(&self) -> Result<Vec<MigrationFile>, MigrationError> {
162 let applied = self.get_applied_migrations().await?;
163 let files = self.get_migration_files().await?;
164
165 let applied_versions: std::collections::HashSet<String> =
166 applied.into_iter().map(|m| m.version).collect();
167
168 let pending: Vec<MigrationFile> = files
169 .into_iter()
170 .filter(|f| !applied_versions.contains(&f.version))
171 .collect();
172
173 Ok(pending)
174 }
175
176 pub async fn apply_migration(
178 &mut self,
179 migration: &MigrationFile,
180 ) -> Result<(), MigrationError> {
181 info!("Applying migration: {}", migration.version);
182
183 if self.is_migration_applied(&migration.version).await? {
185 return Err(MigrationError::AlreadyApplied {
186 version: migration.version.clone(),
187 });
188 }
189
190 let (up_content, _down_content) = parse_migration_content(&migration.content)
192 .map_err(|e| MigrationError::InvalidFormat(e))?;
193
194 for statement in split_cql_statements(&up_content) {
196 if !statement.trim().is_empty() {
197 debug!("Executing: {}", statement.trim());
198 self.session.query(statement, &[]).await?;
199 }
200 }
201
202 self.record_migration_applied(migration).await?;
204
205 info!("✅ Applied migration: {}", migration.version);
206 Ok(())
207 }
208
209 pub async fn rollback_migration(&mut self, version: &str) -> Result<(), MigrationError> {
211 info!("Rolling back migration: {}", version);
212
213 if !self.is_migration_applied(version).await? {
215 return Err(MigrationError::MigrationNotFound(version.to_string()));
216 }
217
218 let files = self.get_migration_files().await?;
220 let migration_file = files
221 .iter()
222 .find(|f| f.version == version)
223 .ok_or_else(|| MigrationError::MigrationNotFound(version.to_string()))?;
224
225 let (_up_content, down_content) = parse_migration_content(&migration_file.content)
227 .map_err(|e| MigrationError::InvalidFormat(e))?;
228
229 let down_content = down_content.ok_or_else(|| MigrationError::RollbackError {
230 version: version.to_string(),
231 reason: "No DOWN section found in migration".to_string(),
232 })?;
233
234 for statement in split_cql_statements(&down_content) {
236 if !statement.trim().is_empty() {
237 debug!("Executing rollback: {}", statement.trim());
238 self.session.query(statement, &[]).await?;
239 }
240 }
241
242 self.remove_migration_record(version).await?;
244
245 info!("✅ Rolled back migration: {}", version);
246 Ok(())
247 }
248
249 pub async fn is_migration_applied(&self, version: &str) -> Result<bool, MigrationError> {
251 let query = format!(
252 "SELECT version FROM {} WHERE version = ? LIMIT 1",
253 self.config.migrations.table_name
254 );
255
256 let rows = self.session.query(query, (version,)).await?;
257 Ok(!rows.rows.unwrap_or_default().is_empty())
258 }
259
260 async fn record_migration_applied(
262 &self,
263 migration: &MigrationFile,
264 ) -> Result<(), MigrationError> {
265 let query = format!(
266 "INSERT INTO {} (version, applied_at, checksum, description) VALUES (?, ?, ?, ?)",
267 self.config.migrations.table_name
268 );
269
270 self.session
271 .query(
272 query,
273 (
274 &migration.version,
275 CqlTimestamp(Utc::now().timestamp_millis()),
276 &migration.checksum,
277 &migration.description,
278 ),
279 )
280 .await?;
281
282 Ok(())
283 }
284
285 pub(crate) async fn remove_migration_record(
287 &self,
288 version: &str,
289 ) -> Result<(), MigrationError> {
290 let query = format!(
291 "DELETE FROM {} WHERE version = ?",
292 self.config.migrations.table_name
293 );
294
295 self.session.query(query, (version,)).await?;
296 Ok(())
297 }
298
299 pub async fn verify_migrations(&self) -> Result<Vec<MigrationError>, MigrationError> {
301 let applied = self.get_applied_migrations().await?;
302 let files = self.get_migration_files().await?;
303
304 let file_map: HashMap<String, &MigrationFile> =
305 files.iter().map(|f| (f.version.clone(), f)).collect();
306
307 let mut errors = Vec::new();
308
309 for applied_migration in applied {
310 if let Some(file) = file_map.get(&applied_migration.version) {
311 if file.checksum != applied_migration.checksum {
312 errors.push(MigrationError::ChecksumMismatch {
313 version: applied_migration.version,
314 expected: applied_migration.checksum,
315 actual: file.checksum.clone(),
316 });
317 }
318 } else {
319 errors.push(MigrationError::MigrationNotFound(applied_migration.version));
320 }
321 }
322
323 Ok(errors)
324 }
325
326 pub async fn reset_migrations(&mut self) -> Result<(), MigrationError> {
328 if !self.config.behavior.allow_destructive {
329 return Err(MigrationError::ConfigError(
330 "Destructive operations are disabled in configuration".to_string(),
331 ));
332 }
333
334 warn!("Resetting all migrations - this is destructive!");
335
336 let drop_query = format!("DROP TABLE IF EXISTS {}", self.config.migrations.table_name);
338 self.session.query(drop_query, &[]).await?;
339
340 self.initialize_schema().await?;
341
342 info!("✅ All migrations reset");
343 Ok(())
344 }
345
346 pub fn get_config(&self) -> &Config {
348 &self.config
349 }
350
351 pub async fn update_migration_checksum(
353 &self,
354 version: &str,
355 new_checksum: &str,
356 ) -> Result<(), MigrationError> {
357 let query = format!(
358 "UPDATE {} SET checksum = ? WHERE version = ?",
359 self.config.migrations.table_name
360 );
361
362 self.session.query(query, (new_checksum, version)).await?;
363 Ok(())
364 }
365
366 pub async fn create_migration_file(
368 &self,
369 description: &str,
370 ) -> Result<PathBuf, MigrationError> {
371 let filename = crate::utils::create_migration_filename(description);
372 let file_path = self.config.migrations.directory.join(&filename);
373
374 if let Some(parent) = file_path.parent() {
376 fs::create_dir_all(parent).await?;
377 }
378
379 let content = crate::utils::generate_migration_template(description);
381
382 fs::write(&file_path, content).await?;
384
385 info!("✅ Created migration file: {}", filename);
386 Ok(file_path)
387 }
388}
389
/// Splits a block of CQL into individual statements on top-level `;` separators.
///
/// A `;` only separates statements when it appears outside of:
/// - single-quoted string literals (`'a;b'`, where `''` is an escaped quote),
/// - double-quoted identifiers (`"a;b"`, where `""` is an escaped quote),
/// - `$$ ... $$` dollar-quoted literals (e.g. UDF bodies),
/// - `-- ...` / `// ...` line comments and `/* ... */` block comments.
///
/// Each returned statement is trimmed. Fragments consisting only of whitespace and
/// comments (such as a trailing `-- end` after the last `;`) are dropped so they are never
/// sent to the server. A comment that precedes real code stays part of that statement.
/// Unterminated literals or comments extend to the end of the input.
fn split_cql_statements(content: &str) -> Vec<String> {
    /// Appends characters to `out` up to and including the first `close` pair, or until the
    /// input runs out.
    fn copy_through(
        chars: &mut std::iter::Peekable<std::str::Chars<'_>>,
        out: &mut String,
        close: (char, char),
    ) {
        let mut prev = None;
        for ch in chars.by_ref() {
            out.push(ch);
            if prev == Some(close.0) && ch == close.1 {
                return;
            }
            prev = Some(ch);
        }
    }

    let mut statements = Vec::new();
    let mut current = String::new();
    // True once `current` holds something other than whitespace and comments.
    let mut has_code = false;
    let mut chars = content.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            ';' => {
                if has_code {
                    statements.push(current.trim().to_string());
                }
                current.clear();
                has_code = false;
            }
            '\'' | '"' => {
                has_code = true;
                current.push(c);
                // Copy through the closing quote; a doubled quote is an escaped quote.
                while let Some(q) = chars.next() {
                    current.push(q);
                    if q == c {
                        match chars.next_if_eq(&c) {
                            Some(escaped) => current.push(escaped),
                            None => break,
                        }
                    }
                }
            }
            '$' if chars.peek() == Some(&'$') => {
                has_code = true;
                current.push(c);
                current.extend(chars.next());
                copy_through(&mut chars, &mut current, ('$', '$'));
            }
            '-' | '/' if chars.peek() == Some(&c) => {
                // Line comment (`--` or `//`): copy through the end of the line.
                current.push(c);
                for ch in chars.by_ref() {
                    current.push(ch);
                    if ch == '\n' {
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                current.push(c);
                current.extend(chars.next());
                copy_through(&mut chars, &mut current, ('*', '/'));
            }
            _ => {
                has_code |= !c.is_whitespace();
                current.push(c);
            }
        }
    }

    if has_code {
        statements.push(current.trim().to_string());
    }
    statements
}