1use crate::ddl::DdlGenerator;
10use crate::diff::SchemaOperation;
11use asupersync::{Cx, Outcome};
12use sqlmodel_core::{Connection, Error, Value};
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
/// A single schema migration: an identifier, a human-readable description,
/// and the SQL that applies (`up`) and reverts (`down`) it.
#[derive(Clone, Debug)]
pub struct Migration {
    /// Unique identifier; [`Migration::new_version`] yields a `YYYYMMDDHHMMSS`
    /// timestamp suitable for this field.
    pub id: String,
    /// Short summary of what the migration does.
    pub description: String,
    /// SQL executed when the migration is applied.
    pub up: String,
    /// SQL executed when the migration is rolled back.
    pub down: String,
}
28
29impl Migration {
30 pub fn new(
32 id: impl Into<String>,
33 description: impl Into<String>,
34 up: impl Into<String>,
35 down: impl Into<String>,
36 ) -> Self {
37 Self {
38 id: id.into(),
39 description: description.into(),
40 up: up.into(),
41 down: down.into(),
42 }
43 }
44
45 #[must_use]
49 pub fn new_version() -> String {
50 use std::time::{SystemTime, UNIX_EPOCH};
51 let now = SystemTime::now()
52 .duration_since(UNIX_EPOCH)
53 .map_or(0, |d| d.as_secs());
54
55 let days = now / 86400;
57 let secs = now % 86400;
58 let hours = secs / 3600;
59 let mins = (secs % 3600) / 60;
60 let secs = secs % 60;
61
62 let mut year = 1970;
64 let mut remaining_days = days as i64;
65
66 loop {
67 let days_in_year = if is_leap_year(year) { 366 } else { 365 };
68 if remaining_days < days_in_year {
69 break;
70 }
71 remaining_days -= days_in_year;
72 year += 1;
73 }
74
75 let months_days: [i64; 12] = if is_leap_year(year) {
76 [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
77 } else {
78 [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
79 };
80
81 let mut month = 1;
82 for days_in_month in months_days {
83 if remaining_days < days_in_month {
84 break;
85 }
86 remaining_days -= days_in_month;
87 month += 1;
88 }
89
90 let day = remaining_days + 1;
91
92 format!(
93 "{:04}{:02}{:02}{:02}{:02}{:02}",
94 year, month, day, hours, mins, secs
95 )
96 }
97
98 #[tracing::instrument(level = "info", skip(ops, ddl, description))]
102 pub fn from_operations(
103 ops: &[SchemaOperation],
104 ddl: &dyn DdlGenerator,
105 description: impl Into<String>,
106 ) -> Self {
107 let description = description.into();
108 let version = Self::new_version();
109
110 tracing::info!(
111 version = %version,
112 description = %description,
113 ops_count = ops.len(),
114 dialect = ddl.dialect(),
115 "Creating migration from schema operations"
116 );
117
118 let up_stmts = ddl.generate_all(ops);
119 let down_stmts = ddl.generate_rollback(ops);
120
121 let up = up_stmts.join(";\n\n") + if up_stmts.is_empty() { "" } else { ";" };
123 let down = down_stmts.join(";\n\n") + if down_stmts.is_empty() { "" } else { ";" };
124
125 tracing::debug!(
126 up_statements = up_stmts.len(),
127 down_statements = down_stmts.len(),
128 "Generated migration SQL"
129 );
130
131 Self {
132 id: version,
133 description,
134 up,
135 down,
136 }
137 }
138}
139
/// Reports whether `year` is a leap year in the proleptic Gregorian calendar.
///
/// Multiples of 400 are leap years, other century years are not, and every
/// remaining multiple of 4 is.
fn is_leap_year(year: i64) -> bool {
    match (year % 400, year % 100, year % 4) {
        (0, _, _) => true,
        (_, 0, _) => false,
        (_, _, 0) => true,
        _ => false,
    }
}
144
/// File format used when writing a migration to disk.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum MigrationFormat {
    /// Plain SQL script with a `.sql` extension; the default.
    #[default]
    Sql,
    /// Rust source file with a `.rs` extension.
    Rust,
}
158
159pub struct MigrationWriter {
161 migrations_dir: PathBuf,
163 format: MigrationFormat,
165}
166
167impl MigrationWriter {
168 pub fn new(dir: impl Into<PathBuf>) -> Self {
170 Self {
171 migrations_dir: dir.into(),
172 format: MigrationFormat::default(),
173 }
174 }
175
176 #[must_use]
178 pub fn with_format(mut self, format: MigrationFormat) -> Self {
179 self.format = format;
180 self
181 }
182
183 pub fn migrations_dir(&self) -> &Path {
185 &self.migrations_dir
186 }
187
188 pub fn format(&self) -> MigrationFormat {
190 self.format
191 }
192
193 #[tracing::instrument(level = "info", skip(self, migration))]
198 pub fn write(&self, migration: &Migration) -> std::io::Result<PathBuf> {
199 tracing::info!(
200 version = %migration.id,
201 description = %migration.description,
202 format = ?self.format,
203 dir = %self.migrations_dir.display(),
204 "Writing migration file"
205 );
206
207 std::fs::create_dir_all(&self.migrations_dir)?;
208
209 let filename = self.filename(migration);
210 let path = self.migrations_dir.join(&filename);
211 let content = self.format_migration(migration);
212
213 std::fs::write(&path, &content)?;
214
215 tracing::info!(
216 path = %path.display(),
217 bytes = content.len(),
218 "Migration file written"
219 );
220
221 Ok(path)
222 }
223
224 fn filename(&self, m: &Migration) -> String {
226 let sanitized_desc: String = m
229 .description
230 .to_lowercase()
231 .chars()
232 .map(|c| if c.is_alphanumeric() { c } else { '_' })
233 .collect::<String>()
234 .split('_')
235 .filter(|s| !s.is_empty())
236 .collect::<Vec<_>>()
237 .join("_");
238
239 let desc = if sanitized_desc.len() > 50 {
241 &sanitized_desc[..50]
242 } else {
243 &sanitized_desc
244 };
245
246 match self.format {
247 MigrationFormat::Sql => format!("{}_{}.sql", m.id, desc),
248 MigrationFormat::Rust => format!("{}_{}.rs", m.id, desc),
249 }
250 }
251
252 fn format_migration(&self, m: &Migration) -> String {
254 match self.format {
255 MigrationFormat::Sql => self.format_sql(m),
256 MigrationFormat::Rust => self.format_rust(m),
257 }
258 }
259
260 fn format_sql(&self, m: &Migration) -> String {
262 let mut content = String::new();
263
264 content.push_str(&format!("-- Migration: {}\n", m.description));
266 content.push_str(&format!("-- Version: {}\n", m.id));
267 content.push_str(&format!(
268 "-- Generated: {}\n\n",
269 std::time::SystemTime::now()
270 .duration_since(std::time::UNIX_EPOCH)
271 .map_or(0, |d| d.as_secs())
272 ));
273
274 content.push_str("-- ========== UP ==========\n\n");
276 content.push_str(&m.up);
277 content.push_str("\n\n");
278
279 content.push_str("-- ========== DOWN ==========\n");
281 content.push_str("-- Uncomment to enable rollback:\n\n");
282 for line in m.down.lines() {
283 content.push_str("-- ");
284 content.push_str(line);
285 content.push('\n');
286 }
287
288 content
289 }
290
291 fn format_rust(&self, m: &Migration) -> String {
293 let mut content = String::new();
294
295 content.push_str("//! Auto-generated migration.\n");
297 content.push_str(&format!("//! Description: {}\n", m.description));
298 content.push_str(&format!("//! Version: {}\n\n", m.id));
299
300 content.push_str("use sqlmodel_schema::Migration;\n\n");
301
302 content.push_str("/// Returns this migration.\n");
304 content.push_str("pub fn migration() -> Migration {\n");
305 content.push_str(" Migration::new(\n");
306 content.push_str(&format!(" {:?},\n", m.id));
307 content.push_str(&format!(" {:?},\n", m.description));
308
309 content.push_str(" r#\"\n");
311 content.push_str(&m.up);
312 content.push_str("\n\"#,\n");
313
314 content.push_str(" r#\"\n");
316 content.push_str(&m.down);
317 content.push_str("\n\"#,\n");
318
319 content.push_str(" )\n");
320 content.push_str("}\n");
321
322 content
323 }
324}
325
/// Where a known migration stands relative to the database.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MigrationStatus {
    /// No record of this migration exists in the bookkeeping table.
    Pending,
    /// A record exists; `at` is its stored `applied_at` value (Unix seconds
    /// when written by `MigrationRunner::migrate`).
    Applied { at: i64 },
    /// Carries an error message describing a failure. `MigrationRunner` in
    /// this module never produces this variant.
    Failed { error: String },
}
336
337pub struct MigrationRunner {
339 migrations: Vec<Migration>,
341 table_name: String,
343}
344
/// Drops every character of `name` that is not alphanumeric or `_`.
///
/// The result cannot contain quotes, whitespace, or statement separators, so
/// it can be interpolated into SQL as an unquoted identifier.
fn sanitize_table_name(name: &str) -> String {
    name.replace(|ch: char| !(ch == '_' || ch.is_alphanumeric()), "")
}
353
354impl MigrationRunner {
355 pub fn new(migrations: Vec<Migration>) -> Self {
357 Self {
358 migrations,
359 table_name: "_sqlmodel_migrations".to_string(),
360 }
361 }
362
363 pub fn table_name(mut self, name: impl Into<String>) -> Self {
368 self.table_name = sanitize_table_name(&name.into());
369 self
370 }
371
372 pub async fn init<C: Connection>(&self, cx: &Cx, conn: &C) -> Outcome<(), Error> {
374 let sql = format!(
375 "CREATE TABLE IF NOT EXISTS {} (
376 id TEXT PRIMARY KEY,
377 description TEXT NOT NULL,
378 applied_at INTEGER NOT NULL
379 )",
380 self.table_name
381 );
382
383 conn.execute(cx, &sql, &[]).await.map(|_| ())
384 }
385
386 pub async fn status<C: Connection>(
388 &self,
389 cx: &Cx,
390 conn: &C,
391 ) -> Outcome<Vec<(String, MigrationStatus)>, Error> {
392 match self.init(cx, conn).await {
394 Outcome::Ok(()) => {}
395 Outcome::Err(e) => return Outcome::Err(e),
396 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
397 Outcome::Panicked(p) => return Outcome::Panicked(p),
398 }
399
400 let sql = format!("SELECT id, applied_at FROM {}", self.table_name);
402 let rows = match conn.query(cx, &sql, &[]).await {
403 Outcome::Ok(rows) => rows,
404 Outcome::Err(e) => return Outcome::Err(e),
405 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
406 Outcome::Panicked(p) => return Outcome::Panicked(p),
407 };
408
409 let mut applied: HashMap<String, i64> = HashMap::new();
410 for row in rows {
411 if let (Ok(id), Ok(at)) = (
412 row.get_named::<String>("id"),
413 row.get_named::<i64>("applied_at"),
414 ) {
415 applied.insert(id, at);
416 }
417 }
418
419 let status: Vec<_> = self
420 .migrations
421 .iter()
422 .map(|m| {
423 let status = if let Some(&at) = applied.get(&m.id) {
424 MigrationStatus::Applied { at }
425 } else {
426 MigrationStatus::Pending
427 };
428 (m.id.clone(), status)
429 })
430 .collect();
431
432 Outcome::Ok(status)
433 }
434
435 pub async fn migrate<C: Connection>(&self, cx: &Cx, conn: &C) -> Outcome<Vec<String>, Error> {
437 let status = match self.status(cx, conn).await {
438 Outcome::Ok(s) => s,
439 Outcome::Err(e) => return Outcome::Err(e),
440 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
441 Outcome::Panicked(p) => return Outcome::Panicked(p),
442 };
443
444 let mut applied = Vec::new();
445
446 for (id, s) in status {
447 if s == MigrationStatus::Pending {
448 let Some(migration) = self.migrations.iter().find(|m| m.id == id) else {
449 continue;
451 };
452
453 match conn.execute(cx, &migration.up, &[]).await {
455 Outcome::Ok(_) => {}
456 Outcome::Err(e) => return Outcome::Err(e),
457 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
458 Outcome::Panicked(p) => return Outcome::Panicked(p),
459 }
460
461 let record_sql = format!(
463 "INSERT INTO {} (id, description, applied_at) VALUES ($1, $2, $3)",
464 self.table_name
465 );
466 let now = std::time::SystemTime::now()
467 .duration_since(std::time::UNIX_EPOCH)
468 .map_or(0, |d| d.as_secs() as i64);
469
470 match conn
471 .execute(
472 cx,
473 &record_sql,
474 &[
475 Value::Text(migration.id.clone()),
476 Value::Text(migration.description.clone()),
477 Value::BigInt(now),
478 ],
479 )
480 .await
481 {
482 Outcome::Ok(_) => {}
483 Outcome::Err(e) => return Outcome::Err(e),
484 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
485 Outcome::Panicked(p) => return Outcome::Panicked(p),
486 }
487
488 applied.push(id);
489 }
490 }
491
492 Outcome::Ok(applied)
493 }
494
495 pub async fn rollback<C: Connection>(
497 &self,
498 cx: &Cx,
499 conn: &C,
500 ) -> Outcome<Option<String>, Error> {
501 let status = match self.status(cx, conn).await {
502 Outcome::Ok(s) => s,
503 Outcome::Err(e) => return Outcome::Err(e),
504 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
505 Outcome::Panicked(p) => return Outcome::Panicked(p),
506 };
507
508 let last_applied = status
510 .iter()
511 .filter_map(|(id, s)| {
512 if let MigrationStatus::Applied { at } = s {
513 Some((id.clone(), *at))
514 } else {
515 None
516 }
517 })
518 .max_by_key(|(_, at)| *at);
519
520 let Some((id, _)) = last_applied else {
521 return Outcome::Ok(None);
522 };
523
524 let Some(migration) = self.migrations.iter().find(|m| m.id == id) else {
525 return Outcome::Err(Error::Custom(format!(
527 "Migration '{}' not found in migrations list",
528 id
529 )));
530 };
531
532 match conn.execute(cx, &migration.down, &[]).await {
534 Outcome::Ok(_) => {}
535 Outcome::Err(e) => return Outcome::Err(e),
536 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
537 Outcome::Panicked(p) => return Outcome::Panicked(p),
538 }
539
540 let delete_sql = format!("DELETE FROM {} WHERE id = $1", self.table_name);
542 match conn
543 .execute(cx, &delete_sql, &[Value::Text(id.clone())])
544 .await
545 {
546 Outcome::Ok(_) => {}
547 Outcome::Err(e) => return Outcome::Err(e),
548 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
549 Outcome::Panicked(p) => return Outcome::Panicked(p),
550 }
551
552 Outcome::Ok(Some(id))
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_version_format() {
        let version = Migration::new_version();
        // Fixed-width YYYYMMDDHHMMSS made only of digits.
        assert_eq!(version.len(), 14);
        assert!(version.chars().all(|c| c.is_ascii_digit()));
        // Sanity-check the year against the current era.
        let year: i32 = version[0..4].parse().unwrap();
        assert!((2020..=2100).contains(&year));
    }

    #[test]
    fn test_new_version_fields_in_range() {
        let version = Migration::new_version();
        let field = |range: std::ops::Range<usize>| -> u32 { version[range].parse().unwrap() };
        assert!((1..=12).contains(&field(4..6)));
        assert!((1..=31).contains(&field(6..8)));
        assert!(field(8..10) < 24);
        assert!(field(10..12) < 60);
        assert!(field(12..14) < 60);
    }

    #[test]
    fn test_version_ordering() {
        // Versions use the same fixed-width YYYYMMDDHHMMSS layout that
        // `new_version` produces, so string order equals chronological order.
        let v1 = "20250101000000";
        let v2 = "20250101000001";
        let v3 = "20250102000000";
        let v4 = "20251231235959";
        let v5 = "20260101000000";

        let expected_len = Migration::new_version().len();
        for v in [v1, v2, v3, v4, v5] {
            assert_eq!(v.len(), expected_len);
        }

        assert!(v2 > v1);
        assert!(v3 > v2);
        assert!(v3 > v1);
        assert!(v4 > v3);
        assert!(v5 > v4);
    }

    #[test]
    fn test_migration_new() {
        let m = Migration::new(
            "001",
            "Create users table",
            "CREATE TABLE users",
            "DROP TABLE users",
        );
        assert_eq!(m.id, "001");
        assert_eq!(m.description, "Create users table");
        assert_eq!(m.up, "CREATE TABLE users");
        assert_eq!(m.down, "DROP TABLE users");
    }

    #[test]
    fn test_migration_from_operations() {
        use crate::ddl::SqliteDdlGenerator;
        use crate::introspect::{ColumnInfo, ParsedSqlType, TableInfo};

        let table = TableInfo {
            name: "heroes".to_string(),
            columns: vec![
                ColumnInfo {
                    name: "id".to_string(),
                    sql_type: "INTEGER".to_string(),
                    parsed_type: ParsedSqlType::parse("INTEGER"),
                    nullable: false,
                    default: None,
                    primary_key: true,
                    auto_increment: true,
                    comment: None,
                },
                ColumnInfo {
                    name: "name".to_string(),
                    sql_type: "TEXT".to_string(),
                    parsed_type: ParsedSqlType::parse("TEXT"),
                    nullable: false,
                    default: None,
                    primary_key: false,
                    auto_increment: false,
                    comment: None,
                },
            ],
            primary_key: vec!["id".to_string()],
            foreign_keys: Vec::new(),
            unique_constraints: Vec::new(),
            check_constraints: Vec::new(),
            indexes: Vec::new(),
            comment: None,
        };

        let ops = vec![crate::diff::SchemaOperation::CreateTable(table)];
        let ddl = SqliteDdlGenerator;
        let m = Migration::from_operations(&ops, &ddl, "Create heroes table");

        assert!(!m.id.is_empty());
        assert_eq!(m.description, "Create heroes table");
        assert!(m.up.contains("CREATE TABLE"));
        assert!(m.up.contains("heroes"));
        assert!(m.down.contains("DROP TABLE"));
    }

    #[test]
    fn test_is_leap_year() {
        assert!(!is_leap_year(2023));
        assert!(is_leap_year(2024));
        assert!(!is_leap_year(2100));
        assert!(is_leap_year(2000));
        assert!(!is_leap_year(1900));
        assert!(!is_leap_year(1970));
    }

    #[test]
    fn test_migration_format_default() {
        assert_eq!(MigrationFormat::default(), MigrationFormat::Sql);
    }

    #[test]
    fn test_migration_writer_new() {
        let writer = MigrationWriter::new("/tmp/migrations");
        assert_eq!(writer.migrations_dir(), Path::new("/tmp/migrations"));
        assert_eq!(writer.format(), MigrationFormat::Sql);
    }

    #[test]
    fn test_migration_writer_with_format() {
        let writer = MigrationWriter::new("/tmp/migrations").with_format(MigrationFormat::Rust);
        assert_eq!(writer.format(), MigrationFormat::Rust);
    }

    #[test]
    fn test_filename_sanitization() {
        let writer = MigrationWriter::new("/tmp");
        let m = Migration::new("20260127120000", "Create Users Table!!!", "", "");
        let filename = writer.filename(&m);
        assert!(filename.starts_with("20260127120000_"));
        assert!(
            Path::new(&filename)
                .extension()
                .is_some_and(|ext| ext.eq_ignore_ascii_case("sql"))
        );
        assert!(!filename.contains('!'));
        assert!(!filename.contains(' '));
    }

    #[test]
    fn test_filename_rust_format() {
        let writer = MigrationWriter::new("/tmp").with_format(MigrationFormat::Rust);
        let m = Migration::new("20260127120000", "Test migration", "", "");
        let filename = writer.filename(&m);
        assert!(
            Path::new(&filename)
                .extension()
                .is_some_and(|ext| ext.eq_ignore_ascii_case("rs"))
        );
    }

    #[test]
    fn test_format_sql_structure() {
        let writer = MigrationWriter::new("/tmp");
        let m = Migration::new(
            "20260127120000",
            "Test migration",
            "CREATE TABLE test (id INT)",
            "DROP TABLE test",
        );
        let content = writer.format_sql(&m);

        assert!(content.contains("-- Migration: Test migration"));
        assert!(content.contains("-- Version: 20260127120000"));

        assert!(content.contains("-- ========== UP =========="));
        assert!(content.contains("CREATE TABLE test"));

        assert!(content.contains("-- ========== DOWN =========="));
        assert!(content.contains("DROP TABLE test"));
    }

    #[test]
    fn test_format_rust_structure() {
        let writer = MigrationWriter::new("/tmp").with_format(MigrationFormat::Rust);
        let m = Migration::new(
            "20260127120000",
            "Test migration",
            "CREATE TABLE test",
            "DROP TABLE test",
        );
        let content = writer.format_rust(&m);

        assert!(content.contains("//! Auto-generated migration"));
        assert!(content.contains("//! Description: Test migration"));

        assert!(content.contains("use sqlmodel_schema::Migration"));

        assert!(content.contains("pub fn migration() -> Migration"));
        assert!(content.contains("Migration::new("));

        assert!(content.contains("CREATE TABLE test"));
        assert!(content.contains("DROP TABLE test"));
    }

    #[test]
    fn test_filename_truncation() {
        let writer = MigrationWriter::new("/tmp");
        let long_desc = "a".repeat(100);
        let m = Migration::new("20260127120000", &long_desc, "", "");
        let filename = writer.filename(&m);
        assert!(filename.len() < 100);
    }

    #[test]
    fn test_migration_status_enum() {
        let pending = MigrationStatus::Pending;
        let applied = MigrationStatus::Applied { at: 1_234_567_890 };
        let failed = MigrationStatus::Failed {
            error: "Test error".to_string(),
        };

        assert_eq!(pending, MigrationStatus::Pending);
        assert_ne!(pending, applied);

        assert!(matches!(
            applied,
            MigrationStatus::Applied { at } if at == 1_234_567_890
        ));
        assert!(matches!(
            failed,
            MigrationStatus::Failed { ref error } if error == "Test error"
        ));
    }

    #[test]
    fn test_migration_runner_new() {
        let migrations = vec![
            Migration::new("001", "First", "UP", "DOWN"),
            Migration::new("002", "Second", "UP", "DOWN"),
        ];
        let runner = MigrationRunner::new(migrations);
        assert_eq!(runner.table_name, "_sqlmodel_migrations");
    }

    #[test]
    fn test_migration_runner_custom_table() {
        let runner = MigrationRunner::new(vec![]).table_name("custom_migrations");
        assert_eq!(runner.table_name, "custom_migrations");
    }

    #[test]
    fn test_migration_runner_table_name_is_sanitized() {
        let runner = MigrationRunner::new(vec![]).table_name("evil; DROP TABLE users--");
        assert_eq!(runner.table_name, "evilDROPTABLEusers");
    }
}