1use crate::error::{DbError, Result};
38use chrono::{DateTime, Utc};
39use serde::{Deserialize, Serialize};
40use sqlx::{PgPool, Postgres, Transaction};
41use std::path::Path;
42use tracing::{info, warn};
43
/// Configuration knobs for [`MigrationRunner`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationConfig {
    /// Directory scanned for `*.sql` migration files.
    pub migrations_dir: String,

    /// When `true`, `verify_migration_order` skips its sequential-version
    /// check entirely.
    pub allow_missing: bool,

    /// When `true`, pending migrations are logged and counted as applied
    /// but no SQL is executed and nothing is recorded.
    pub dry_run: bool,
}
56
57impl Default for MigrationConfig {
58 fn default() -> Self {
59 Self {
60 migrations_dir: "migrations".to_string(),
61 allow_missing: false,
62 dry_run: false,
63 }
64 }
65}
66
/// A single SQL migration parsed from a `<version>_<name>.sql` file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Migration {
    /// Version prefix of the filename (text before the first underscore).
    pub version: String,

    /// Name portion of the filename (text after the first underscore).
    pub name: String,

    /// Full SQL contents read from the migration file.
    pub sql: String,

    /// Path the migration was loaded from.
    pub file_path: String,
}
82
83impl Migration {
84 pub fn from_file(file_path: &str) -> Result<Self> {
86 let path = Path::new(file_path);
87 let filename = path
88 .file_name()
89 .and_then(|n| n.to_str())
90 .ok_or_else(|| DbError::NotFound("Invalid migration filename".to_string()))?;
91
92 let parts: Vec<&str> = filename.trim_end_matches(".sql").splitn(2, '_').collect();
94
95 if parts.len() != 2 {
96 return Err(DbError::NotFound(format!(
97 "Invalid migration filename format: {}",
98 filename
99 )));
100 }
101
102 let version = parts[0].to_string();
103 let name = parts[1].to_string();
104
105 let sql = std::fs::read_to_string(file_path)
106 .map_err(|e| DbError::NotFound(format!("Failed to read migration file: {}", e)))?;
107
108 Ok(Self {
109 version,
110 name,
111 sql,
112 file_path: file_path.to_string(),
113 })
114 }
115
116 pub fn id(&self) -> String {
118 format!("{}_{}", self.version, self.name)
119 }
120}
121
/// Summary produced by [`MigrationRunner::run_pending_migrations`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationResult {
    /// Number of migrations applied this run (in dry-run mode: would apply).
    pub applied_count: usize,

    /// Number of migrations skipped because they were already applied.
    pub skipped_count: usize,

    /// Ids of the migrations applied this run, in order.
    pub applied_migrations: Vec<String>,

    /// Ids of the migrations skipped this run, in order.
    pub skipped_migrations: Vec<String>,

    /// `false` when a migration failed and the run stopped early.
    pub success: bool,

    /// Description of the failure when `success` is `false`.
    pub error: Option<String>,

    /// Wall-clock duration of the whole run, in milliseconds.
    pub execution_time_ms: u64,
}
146
/// One row of the `_migrations` bookkeeping table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationHistory {
    /// Migration identifier (`<version>_<name>`).
    pub migration_id: String,

    /// Timestamp recorded when the migration was applied.
    pub applied_at: DateTime<Utc>,

    /// How long the migration's SQL took to run, in milliseconds.
    pub execution_time_ms: i64,

    /// Whether the migration was recorded as successful.
    pub success: bool,
}
162
/// Discovers, verifies, and applies SQL migrations against a Postgres pool.
pub struct MigrationRunner {
    // Connection pool used for every query and transaction.
    pool: PgPool,
    // Behavior configuration (directory, dry-run, ordering checks).
    config: MigrationConfig,
}
168
169impl MigrationRunner {
170 pub fn new(pool: PgPool, config: MigrationConfig) -> Self {
172 Self { pool, config }
173 }
174
175 pub fn with_defaults(pool: PgPool) -> Self {
177 Self::new(pool, MigrationConfig::default())
178 }
179
180 async fn ensure_migrations_table(&self) -> Result<()> {
182 sqlx::query(
183 r#"
184 CREATE TABLE IF NOT EXISTS _migrations (
185 id SERIAL PRIMARY KEY,
186 migration_id VARCHAR(255) NOT NULL UNIQUE,
187 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
188 execution_time_ms BIGINT NOT NULL,
189 success BOOLEAN NOT NULL DEFAULT TRUE
190 )
191 "#,
192 )
193 .execute(&self.pool)
194 .await?;
195
196 Ok(())
197 }
198
199 async fn get_applied_migrations(&self) -> Result<Vec<String>> {
201 let records = sqlx::query_scalar::<_, String>(
202 "SELECT migration_id FROM _migrations WHERE success = TRUE ORDER BY id",
203 )
204 .fetch_all(&self.pool)
205 .await?;
206
207 Ok(records)
208 }
209
210 #[allow(dead_code)]
212 async fn is_migration_applied(&self, migration_id: &str) -> Result<bool> {
213 let count = sqlx::query_scalar::<_, i64>(
214 "SELECT COUNT(*) FROM _migrations WHERE migration_id = $1 AND success = TRUE",
215 )
216 .bind(migration_id)
217 .fetch_one(&self.pool)
218 .await?;
219
220 Ok(count > 0)
221 }
222
    /// Inserts a bookkeeping row for `migration_id` into `_migrations`.
    ///
    /// Runs inside the caller-supplied transaction so the record commits
    /// (or rolls back) atomically with the migration itself.
    ///
    /// # Errors
    ///
    /// Propagates any database error from the INSERT — including a UNIQUE
    /// violation if the migration id was already recorded (the table
    /// declares `migration_id` UNIQUE).
    async fn record_migration(
        &self,
        tx: &mut Transaction<'_, Postgres>,
        migration_id: &str,
        execution_time_ms: i64,
        success: bool,
    ) -> Result<()> {
        sqlx::query(
            r#"
            INSERT INTO _migrations (migration_id, applied_at, execution_time_ms, success)
            VALUES ($1, NOW(), $2, $3)
            "#,
        )
        .bind(migration_id)
        .bind(execution_time_ms)
        .bind(success)
        // Double-deref: Transaction -> underlying connection expected by execute.
        .execute(&mut **tx)
        .await?;

        Ok(())
    }
245
246 pub fn discover_migrations(&self) -> Result<Vec<Migration>> {
248 let migrations_dir = Path::new(&self.config.migrations_dir);
249
250 if !migrations_dir.exists() {
251 return Err(DbError::NotFound(format!(
252 "Migrations directory not found: {}",
253 self.config.migrations_dir
254 )));
255 }
256
257 let mut migrations = Vec::new();
258
259 let entries = std::fs::read_dir(migrations_dir).map_err(|e| {
260 DbError::NotFound(format!("Failed to read migrations directory: {}", e))
261 })?;
262
263 for entry in entries {
264 let entry = entry
265 .map_err(|e| DbError::NotFound(format!("Failed to read directory entry: {}", e)))?;
266
267 let path = entry.path();
268
269 if path.extension().and_then(|s| s.to_str()) == Some("sql") {
270 let migration = Migration::from_file(path.to_str().unwrap())?;
271 migrations.push(migration);
272 }
273 }
274
275 migrations.sort_by(|a, b| a.version.cmp(&b.version));
277
278 Ok(migrations)
279 }
280
281 pub async fn run_pending_migrations(&self) -> Result<MigrationResult> {
283 let start_time = std::time::Instant::now();
284
285 self.ensure_migrations_table().await?;
287
288 let all_migrations = self.discover_migrations()?;
290 let applied_migrations = self.get_applied_migrations().await?;
291
292 let mut result = MigrationResult {
293 applied_count: 0,
294 skipped_count: 0,
295 applied_migrations: Vec::new(),
296 skipped_migrations: Vec::new(),
297 success: true,
298 error: None,
299 execution_time_ms: 0,
300 };
301
302 for migration in all_migrations {
303 let migration_id = migration.id();
304
305 if applied_migrations.contains(&migration_id) {
307 info!(
308 migration_id = migration_id,
309 "Skipping already applied migration"
310 );
311 result.skipped_count += 1;
312 result.skipped_migrations.push(migration_id);
313 continue;
314 }
315
316 if self.config.dry_run {
317 info!(
318 migration_id = migration_id,
319 "Dry-run: Would apply migration"
320 );
321 result.applied_count += 1;
322 result.applied_migrations.push(migration_id);
323 continue;
324 }
325
326 match self.execute_migration(&migration).await {
328 Ok(execution_time_ms) => {
329 info!(
330 migration_id = migration_id,
331 execution_time_ms = execution_time_ms,
332 "Migration applied successfully"
333 );
334 result.applied_count += 1;
335 result.applied_migrations.push(migration_id);
336 }
337 Err(e) => {
338 warn!(
339 migration_id = migration_id,
340 error = %e,
341 "Migration failed"
342 );
343 result.success = false;
344 result.error = Some(format!("Migration {} failed: {}", migration_id, e));
345 break;
346 }
347 }
348 }
349
350 result.execution_time_ms = start_time.elapsed().as_millis() as u64;
351
352 Ok(result)
353 }
354
355 async fn execute_migration(&self, migration: &Migration) -> Result<i64> {
357 let start_time = std::time::Instant::now();
358
359 let mut tx = self.pool.begin().await?;
360
361 match sqlx::query(&migration.sql).execute(&mut *tx).await {
363 Ok(_) => {
364 let execution_time_ms = start_time.elapsed().as_millis() as i64;
365
366 self.record_migration(&mut tx, &migration.id(), execution_time_ms, true)
368 .await?;
369
370 tx.commit().await?;
372
373 Ok(execution_time_ms)
374 }
375 Err(e) => {
376 tx.rollback().await?;
378
379 Err(DbError::from(e))
380 }
381 }
382 }
383
384 pub async fn get_migration_history(&self) -> Result<Vec<MigrationHistory>> {
386 self.ensure_migrations_table().await?;
387
388 let records = sqlx::query_as::<_, (String, DateTime<Utc>, i64, bool)>(
389 "SELECT migration_id, applied_at, execution_time_ms, success FROM _migrations ORDER BY id",
390 )
391 .fetch_all(&self.pool)
392 .await?;
393
394 Ok(records
395 .into_iter()
396 .map(
397 |(migration_id, applied_at, execution_time_ms, success)| MigrationHistory {
398 migration_id,
399 applied_at,
400 execution_time_ms,
401 success,
402 },
403 )
404 .collect())
405 }
406
407 pub fn verify_migration_order(&self, migrations: &[Migration]) -> Result<()> {
409 if self.config.allow_missing {
410 return Ok(());
411 }
412
413 for (i, migration) in migrations.iter().enumerate() {
414 let expected_version = format!("{:03}", i + 1);
415 if migration.version != expected_version {
416 return Err(DbError::NotFound(format!(
417 "Migration version mismatch: expected {}, found {}",
418 expected_version, migration.version
419 )));
420 }
421 }
422
423 Ok(())
424 }
425}
426
427#[cfg(test)]
428mod tests {
429 use super::*;
430
431 #[test]
432 fn test_migration_config_default() {
433 let config = MigrationConfig::default();
434 assert_eq!(config.migrations_dir, "migrations");
435 assert!(!config.allow_missing);
436 assert!(!config.dry_run);
437 }
438
439 #[test]
440 fn test_migration_id() {
441 let migration = Migration {
442 version: "001".to_string(),
443 name: "initial".to_string(),
444 sql: "CREATE TABLE test();".to_string(),
445 file_path: "migrations/001_initial.sql".to_string(),
446 };
447
448 assert_eq!(migration.id(), "001_initial");
449 }
450
451 #[test]
452 fn test_migration_result_serialization() {
453 let result = MigrationResult {
454 applied_count: 3,
455 skipped_count: 2,
456 applied_migrations: vec!["001_initial".to_string()],
457 skipped_migrations: vec!["000_setup".to_string()],
458 success: true,
459 error: None,
460 execution_time_ms: 1500,
461 };
462
463 let json = serde_json::to_string(&result).unwrap();
464 assert!(json.contains("applied_count"));
465 assert!(json.contains("\"success\":true"));
466 }
467
468 #[test]
469 fn test_migration_history_serialization() {
470 let history = MigrationHistory {
471 migration_id: "001_initial".to_string(),
472 applied_at: Utc::now(),
473 execution_time_ms: 1000,
474 success: true,
475 };
476
477 let json = serde_json::to_string(&history).unwrap();
478 assert!(json.contains("migration_id"));
479 assert!(json.contains("001_initial"));
480 }
481
482 #[tokio::test]
483 async fn test_verify_migration_order_success() {
484 let migrations = vec![
485 Migration {
486 version: "001".to_string(),
487 name: "first".to_string(),
488 sql: String::new(),
489 file_path: String::new(),
490 },
491 Migration {
492 version: "002".to_string(),
493 name: "second".to_string(),
494 sql: String::new(),
495 file_path: String::new(),
496 },
497 Migration {
498 version: "003".to_string(),
499 name: "third".to_string(),
500 sql: String::new(),
501 file_path: String::new(),
502 },
503 ];
504
505 let pool = PgPool::connect_lazy("postgresql://localhost/test").unwrap();
506 let runner = MigrationRunner::with_defaults(pool);
507
508 assert!(runner.verify_migration_order(&migrations).is_ok());
509 }
510
511 #[tokio::test]
512 async fn test_verify_migration_order_failure() {
513 let migrations = vec![
514 Migration {
515 version: "001".to_string(),
516 name: "first".to_string(),
517 sql: String::new(),
518 file_path: String::new(),
519 },
520 Migration {
521 version: "003".to_string(), name: "third".to_string(),
523 sql: String::new(),
524 file_path: String::new(),
525 },
526 ];
527
528 let pool = PgPool::connect_lazy("postgresql://localhost/test").unwrap();
529 let runner = MigrationRunner::with_defaults(pool);
530
531 assert!(runner.verify_migration_order(&migrations).is_err());
532 }
533
534 #[tokio::test]
535 async fn test_verify_migration_order_with_allow_missing() {
536 let migrations = vec![
537 Migration {
538 version: "001".to_string(),
539 name: "first".to_string(),
540 sql: String::new(),
541 file_path: String::new(),
542 },
543 Migration {
544 version: "003".to_string(),
545 name: "third".to_string(),
546 sql: String::new(),
547 file_path: String::new(),
548 },
549 ];
550
551 let pool = PgPool::connect_lazy("postgresql://localhost/test").unwrap();
552 let config = MigrationConfig {
553 allow_missing: true,
554 ..Default::default()
555 };
556 let runner = MigrationRunner::new(pool, config);
557
558 assert!(runner.verify_migration_order(&migrations).is_ok());
559 }
560}