1use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Token};
7use vespertide_loader::{load_migrations_at_compile_time, load_models_at_compile_time};
8use vespertide_planner::apply_action;
9use vespertide_query::{DatabaseBackend, build_plan_queries};
10
11struct MacroInput {
12 pool: Expr,
13 version_table: Option<String>,
14}
15
16impl Parse for MacroInput {
17 fn parse(input: ParseStream) -> syn::Result<Self> {
18 let pool = input.parse()?;
19 let mut version_table = None;
20
21 while !input.is_empty() {
22 input.parse::<Token![,]>()?;
23 if input.is_empty() {
24 break;
25 }
26
27 let key: Ident = input.parse()?;
28 if key == "version_table" {
29 input.parse::<Token![=]>()?;
30 let value: syn::LitStr = input.parse()?;
31 version_table = Some(value.value());
32 } else {
33 return Err(syn::Error::new(
34 key.span(),
35 "unsupported option for vespertide_migration!",
36 ));
37 }
38 }
39
40 Ok(MacroInput {
41 pool,
42 version_table,
43 })
44 }
45}
46
47pub(crate) fn build_migration_block(
50 migration: &vespertide_core::MigrationPlan,
51 baseline_schema: &mut Vec<vespertide_core::TableDef>,
52) -> Result<proc_macro2::TokenStream, String> {
53 let version = migration.version;
54
55 let queries = build_plan_queries(migration, baseline_schema).map_err(|e| {
57 format!(
58 "Failed to build queries for migration version {}: {}",
59 version, e
60 )
61 })?;
62
63 for action in &migration.actions {
65 let _ = apply_action(baseline_schema, action);
66 }
67
68 let mut pg_sqls = Vec::new();
71 let mut mysql_sqls = Vec::new();
72 let mut sqlite_sqls = Vec::new();
73
74 for q in &queries {
75 for stmt in &q.postgres {
76 pg_sqls.push(stmt.build(DatabaseBackend::Postgres));
77 }
78 for stmt in &q.mysql {
79 mysql_sqls.push(stmt.build(DatabaseBackend::MySql));
80 }
81 for stmt in &q.sqlite {
82 sqlite_sqls.push(stmt.build(DatabaseBackend::Sqlite));
83 }
84 }
85
86 let block = quote! {
88 if version < #version {
89 let txn = __pool.begin().await.map_err(|e| {
91 ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
92 })?;
93
94 let sqls: &[&str] = match backend {
96 sea_orm::DatabaseBackend::Postgres => &[#(#pg_sqls),*],
97 sea_orm::DatabaseBackend::MySql => &[#(#mysql_sqls),*],
98 sea_orm::DatabaseBackend::Sqlite => &[#(#sqlite_sqls),*],
99 _ => &[#(#pg_sqls),*], };
101
102 for sql in sqls {
104 if !sql.is_empty() {
105 let stmt = sea_orm::Statement::from_string(backend, *sql);
106 txn.execute_raw(stmt).await.map_err(|e| {
107 ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
108 })?;
109 }
110 }
111
112 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
114 let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
115 let stmt = sea_orm::Statement::from_string(backend, insert_sql);
116 txn.execute_raw(stmt).await.map_err(|e| {
117 ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
118 })?;
119
120 txn.commit().await.map_err(|e| {
122 ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
123 })?;
124 }
125 };
126
127 Ok(block)
128}
129
130pub(crate) fn generate_migration_code(
132 pool: &Expr,
133 version_table: &str,
134 migration_blocks: Vec<proc_macro2::TokenStream>,
135) -> proc_macro2::TokenStream {
136 quote! {
137 async {
138 use sea_orm::{ConnectionTrait, TransactionTrait};
139 let __pool = #pool;
140 let version_table = #version_table;
141 let backend = __pool.get_database_backend();
142
143 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
146 let create_table_sql = format!(
147 "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
148 version_table
149 );
150 let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
151 __pool.execute_raw(stmt).await.map_err(|e| {
152 ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
153 })?;
154
155 let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
157 let stmt = sea_orm::Statement::from_string(backend, select_sql);
158 let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
159 ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
160 })?;
161
162 let mut version = version_result
163 .and_then(|row| row.try_get::<i32>("", "version").ok())
164 .unwrap_or(0) as u32;
165
166 #(#migration_blocks)*
168
169 Ok::<(), ::vespertide::MigrationError>(())
170 }
171 }
172}
173
174pub(crate) fn vespertide_migration_impl(
176 input: proc_macro2::TokenStream,
177) -> proc_macro2::TokenStream {
178 let input: MacroInput = match syn::parse2(input) {
179 Ok(input) => input,
180 Err(e) => return e.to_compile_error(),
181 };
182 let pool = &input.pool;
183 let version_table = input
184 .version_table
185 .unwrap_or_else(|| "vespertide_version".to_string());
186
187 let migrations = match load_migrations_at_compile_time() {
189 Ok(migrations) => migrations,
190 Err(e) => {
191 return syn::Error::new(
192 proc_macro2::Span::call_site(),
193 format!("Failed to load migrations at compile time: {}", e),
194 )
195 .to_compile_error();
196 }
197 };
198 let _models = match load_models_at_compile_time() {
199 Ok(models) => models,
200 #[cfg(not(tarpaulin_include))]
201 Err(e) => {
202 return syn::Error::new(
203 proc_macro2::Span::call_site(),
204 format!("Failed to load models at compile time: {}", e),
205 )
206 .to_compile_error();
207 }
208 };
209
210 let mut baseline_schema = Vec::new();
212 let mut migration_blocks = Vec::new();
213
214 #[cfg(not(tarpaulin_include))]
215 for migration in &migrations {
216 match build_migration_block(migration, &mut baseline_schema) {
217 Ok(block) => migration_blocks.push(block),
218 Err(e) => {
219 return syn::Error::new(proc_macro2::Span::call_site(), e).to_compile_error();
220 }
221 }
222 }
223
224 generate_migration_code(pool, &version_table, migration_blocks)
225}
226
227#[cfg(not(tarpaulin_include))]
229#[proc_macro]
230pub fn vespertide_migration(input: TokenStream) -> TokenStream {
231 vespertide_migration_impl(input.into()).into()
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::fs::File;
    use std::io::Write;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, MigrationAction, MigrationPlan, SimpleColumnType, StrOrBoolOrArray,
    };

    /// Serializes every test that reaches the compile-time loaders, which resolve the
    /// project from `CARGO_MANIFEST_DIR`.
    ///
    /// The test harness runs tests on several threads. Without this lock, one test's
    /// `set_var`/`remove_var` can change the project another test is loading, making
    /// assertions flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning so one failed test does not fail the others.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Holds `ENV_LOCK` and restores `CARGO_MANIFEST_DIR` to its original value on drop,
    /// including when the test panics part-way through.
    struct ManifestDirGuard {
        original: Option<OsString>,
        // Declared last so the lock is released only after `Drop::drop` restored the variable.
        _lock: MutexGuard<'static, ()>,
    }

    impl ManifestDirGuard {
        fn new() -> Self {
            let lock = env_lock();
            Self {
                original: std::env::var_os("CARGO_MANIFEST_DIR"),
                _lock: lock,
            }
        }

        fn set(&self, value: impl AsRef<OsStr>) {
            // SAFETY: `self` holds ENV_LOCK, which every test touching this variable takes,
            // so no other test mutates or reads it concurrently.
            unsafe { std::env::set_var("CARGO_MANIFEST_DIR", value) };
        }

        fn remove(&self) {
            // SAFETY: same invariant as `set` — ENV_LOCK is held by `self`.
            unsafe { std::env::remove_var("CARGO_MANIFEST_DIR") };
        }
    }

    impl Drop for ManifestDirGuard {
        fn drop(&mut self) {
            // SAFETY: `_lock` is still held; struct fields are dropped after this body runs.
            unsafe {
                match &self.original {
                    Some(value) => std::env::set_var("CARGO_MANIFEST_DIR", value),
                    None => std::env::remove_var("CARGO_MANIFEST_DIR"),
                }
            }
        }
    }

    #[test]
    fn test_macro_expansion_with_runtime_macros() {
        let _env = env_lock();
        let dir = tempdir().unwrap();

        let test_file_path = dir.path().join("test_macro.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(
            test_file,
            r#"vespertide_migration!(pool, version_table = "test_versions");"#
        )
        .unwrap();

        let file = File::open(&test_file_path).unwrap();
        let result = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );

        // The expansion reports failures as `compile_error!` tokens instead of panicking,
        // so emulating it over a well-formed file must succeed.
        assert!(result.is_ok(), "runtime macro emulation failed");
    }

    #[test]
    fn test_macro_with_simple_pool() {
        let _env = env_lock();
        let dir = tempdir().unwrap();
        let test_file_path = dir.path().join("test_simple.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(test_file, r#"vespertide_migration!(db_pool);"#).unwrap();

        let file = File::open(&test_file_path).unwrap();
        let result = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );

        assert!(result.is_ok(), "runtime macro emulation failed");
    }

    #[test]
    fn test_macro_parsing_invalid_option() {
        // Parsing fails before any project loading, so no env lock is needed.
        let input: proc_macro2::TokenStream = "pool, invalid_option = \"value\"".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(output_str.contains("unsupported option"));
    }

    #[test]
    fn test_macro_parsing_valid_input() {
        let _env = env_lock();
        let input: proc_macro2::TokenStream = "my_pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
        assert!(
            output_str.contains("async") || output_str.contains("Failed to load"),
            "Unexpected output: {}",
            output_str
        );
    }

    #[test]
    fn test_macro_parsing_with_version_table() {
        let _env = env_lock();
        let input: proc_macro2::TokenStream =
            r#"pool, version_table = "custom_versions""#.parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
    }

    #[test]
    fn test_macro_parsing_trailing_comma() {
        let _env = env_lock();
        let input: proc_macro2::TokenStream = "pool,".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
    }

    fn test_column(name: &str) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    #[test]
    fn test_build_migration_block_create_table() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        let block = result.unwrap();
        let block_str = block.to_string();

        assert!(block_str.contains("version < 1u32"));
        assert!(block_str.contains("CREATE TABLE"));

        assert_eq!(baseline.len(), 1);
        assert_eq!(baseline[0].name, "users");
    }

    #[test]
    fn test_build_migration_block_add_column() {
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let _ = build_migration_block(&create_migration, &mut baseline);

        let add_column_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "users".into(),
                column: Box::new(ColumnDef {
                    name: "email".into(),
                    r#type: ColumnType::Simple(SimpleColumnType::Text),
                    nullable: true,
                    default: None,
                    comment: None,
                    primary_key: None,
                    unique: None,
                    index: None,
                    foreign_key: None,
                }),
                fill_with: None,
            }],
        };

        let result = build_migration_block(&add_column_migration, &mut baseline);
        assert!(result.is_ok());
        let block = result.unwrap();
        let block_str = block.to_string();

        assert!(block_str.contains("version < 2u32"));
        assert!(block_str.contains("ALTER TABLE"));
        assert!(block_str.contains("ADD COLUMN"));
    }

    #[test]
    fn test_build_migration_block_multiple_actions() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![
                MigrationAction::CreateTable {
                    table: "users".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
                MigrationAction::CreateTable {
                    table: "posts".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
            ],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        assert_eq!(baseline.len(), 2);
    }

    #[test]
    fn test_generate_migration_code() {
        let pool: Expr = syn::parse_str("db_pool").unwrap();
        let version_table = "test_versions";

        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let block = build_migration_block(&migration, &mut baseline).unwrap();

        let generated = generate_migration_code(&pool, version_table, vec![block]);
        let generated_str = generated.to_string();

        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("db_pool"));
        assert!(generated_str.contains("test_versions"));
        assert!(generated_str.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(generated_str.contains("SELECT MAX"));
    }

    #[test]
    fn test_generate_migration_code_empty_migrations() {
        let pool: Expr = syn::parse_str("pool").unwrap();
        let version_table = "vespertide_version";

        let generated = generate_migration_code(&pool, version_table, vec![]);
        let generated_str = generated.to_string();

        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("vespertide_version"));
    }

    #[test]
    fn test_generate_migration_code_multiple_blocks() {
        let pool: Expr = syn::parse_str("connection").unwrap();

        let mut baseline = Vec::new();

        let migration1 = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block1 = build_migration_block(&migration1, &mut baseline).unwrap();

        let migration2 = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "posts".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block2 = build_migration_block(&migration2, &mut baseline).unwrap();

        let generated = generate_migration_code(&pool, "migrations", vec![block1, block2]);
        let generated_str = generated.to_string();

        assert!(generated_str.contains("version < 1u32"));
        assert!(generated_str.contains("version < 2u32"));
    }

    #[test]
    fn test_build_migration_block_generates_all_backends() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "test_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        let block_str = result.unwrap().to_string();

        assert!(block_str.contains("DatabaseBackend :: Postgres"));
        assert!(block_str.contains("DatabaseBackend :: MySql"));
        assert!(block_str.contains("DatabaseBackend :: Sqlite"));
    }

    #[test]
    fn test_build_migration_block_with_delete_table() {
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "temp_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let _ = build_migration_block(&create_migration, &mut baseline);
        assert_eq!(baseline.len(), 1);

        let delete_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::DeleteTable {
                table: "temp_table".into(),
            }],
        };

        let result = build_migration_block(&delete_migration, &mut baseline);
        assert!(result.is_ok());
        let block_str = result.unwrap().to_string();
        assert!(block_str.contains("DROP TABLE"));

        assert_eq!(baseline.len(), 0);
    }

    #[test]
    fn test_build_migration_block_with_index() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![
                    test_column("id"),
                    ColumnDef {
                        name: "email".into(),
                        r#type: ColumnType::Simple(SimpleColumnType::Text),
                        nullable: true,
                        default: None,
                        comment: None,
                        primary_key: None,
                        unique: None,
                        index: Some(StrOrBoolOrArray::Bool(true)),
                        foreign_key: None,
                    },
                ],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        let table = &baseline[0];
        let normalized = table.clone().normalize();
        assert!(normalized.is_ok());
    }

    #[test]
    fn test_build_migration_block_error_nonexistent_table() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "nonexistent_table".into(),
                column: Box::new(test_column("new_col")),
                fill_with: None,
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Failed to build queries for migration version 1"));
    }

    #[test]
    fn test_vespertide_migration_impl_loading_error() {
        // The guard restores the original value even if the assertion below panics.
        let env = ManifestDirGuard::new();
        env.remove();

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();

        assert!(
            output_str.contains("Failed to load migrations at compile time"),
            "Expected loading error, got: {}",
            output_str
        );
    }

    #[test]
    fn test_vespertide_migration_impl_with_valid_project() {
        use std::fs;

        let dir = tempdir().unwrap();
        let project_dir = dir.path();

        let config_content = r#"{
            "modelsDir": "models",
            "migrationsDir": "migrations",
            "tableNamingCase": "snake",
            "columnNamingCase": "snake",
            "modelFormat": "json"
        }"#;
        fs::write(project_dir.join("vespertide.json"), config_content).unwrap();

        fs::create_dir_all(project_dir.join("models")).unwrap();
        fs::create_dir_all(project_dir.join("migrations")).unwrap();

        let env = ManifestDirGuard::new();
        env.set(project_dir);

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();

        assert!(
            output_str.contains("async"),
            "Expected async block, got: {}",
            output_str
        );
        assert!(
            output_str.contains("CREATE TABLE IF NOT EXISTS"),
            "Expected version table creation, got: {}",
            output_str
        );
    }

    #[test]
    fn test_vespertide_migration_impl_with_migrations() {
        use std::fs;

        let dir = tempdir().unwrap();
        let project_dir = dir.path();

        let config_content = r#"{
            "modelsDir": "models",
            "migrationsDir": "migrations",
            "tableNamingCase": "snake",
            "columnNamingCase": "snake",
            "modelFormat": "json"
        }"#;
        fs::write(project_dir.join("vespertide.json"), config_content).unwrap();

        fs::create_dir_all(project_dir.join("models")).unwrap();
        fs::create_dir_all(project_dir.join("migrations")).unwrap();

        let migration_content = r#"{
            "version": 1,
            "actions": [
                {
                    "type": "create_table",
                    "table": "users",
                    "columns": [
                        {"name": "id", "type": "integer", "nullable": false}
                    ],
                    "constraints": []
                }
            ]
        }"#;
        fs::write(
            project_dir.join("migrations").join("0001_initial.json"),
            migration_content,
        )
        .unwrap();

        let env = ManifestDirGuard::new();
        env.set(project_dir);

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();

        assert!(
            output_str.contains("async"),
            "Expected async block, got: {}",
            output_str
        );
    }
}