1use std::env;
4use std::path::PathBuf;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::parse::{Parse, ParseStream};
9use syn::{Expr, Ident, Token};
10use vespertide_loader::{
11 load_config_or_default, load_migrations_at_compile_time, load_models_at_compile_time,
12};
13use vespertide_planner::apply_action;
14use vespertide_query::{DatabaseBackend, build_plan_queries};
15
/// Parsed arguments of `vespertide_migration!`.
///
/// Accepted syntax: `vespertide_migration!(<pool expr> [, version_table = "<name>"] [,])`.
struct MacroInput {
    /// Expression for the database connection the generated code runs against;
    /// the expansion evaluates it once and binds it to `__pool`.
    pool: Expr,
    /// Explicit version-tracking table name (before the configured prefix is
    /// applied); `None` means the default `vespertide_version` is used.
    version_table: Option<String>,
}
20
21impl Parse for MacroInput {
22 fn parse(input: ParseStream) -> syn::Result<Self> {
23 let pool = input.parse()?;
24 let mut version_table = None;
25
26 while !input.is_empty() {
27 input.parse::<Token![,]>()?;
28 if input.is_empty() {
29 break;
30 }
31
32 let key: Ident = input.parse()?;
33 if key == "version_table" {
34 input.parse::<Token![=]>()?;
35 let value: syn::LitStr = input.parse()?;
36 version_table = Some(value.value());
37 } else {
38 return Err(syn::Error::new(
39 key.span(),
40 "unsupported option for vespertide_migration!",
41 ));
42 }
43 }
44
45 Ok(MacroInput {
46 pool,
47 version_table,
48 })
49 }
50}
51
52pub(crate) fn build_migration_block(
55 migration: &vespertide_core::MigrationPlan,
56 baseline_schema: &mut Vec<vespertide_core::TableDef>,
57) -> Result<proc_macro2::TokenStream, String> {
58 let version = migration.version;
59
60 let queries = build_plan_queries(migration, baseline_schema).map_err(|e| {
62 format!(
63 "Failed to build queries for migration version {}: {}",
64 version, e
65 )
66 })?;
67
68 for action in &migration.actions {
70 let _ = apply_action(baseline_schema, action);
71 }
72
73 let mut pg_sqls = Vec::new();
76 let mut mysql_sqls = Vec::new();
77 let mut sqlite_sqls = Vec::new();
78
79 for q in &queries {
80 for stmt in &q.postgres {
81 pg_sqls.push(stmt.build(DatabaseBackend::Postgres));
82 }
83 for stmt in &q.mysql {
84 mysql_sqls.push(stmt.build(DatabaseBackend::MySql));
85 }
86 for stmt in &q.sqlite {
87 sqlite_sqls.push(stmt.build(DatabaseBackend::Sqlite));
88 }
89 }
90
91 let block = quote! {
93 if version < #version {
94 let txn = __pool.begin().await.map_err(|e| {
96 ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
97 })?;
98
99 let sqls: &[&str] = match backend {
101 sea_orm::DatabaseBackend::Postgres => &[#(#pg_sqls),*],
102 sea_orm::DatabaseBackend::MySql => &[#(#mysql_sqls),*],
103 sea_orm::DatabaseBackend::Sqlite => &[#(#sqlite_sqls),*],
104 _ => &[#(#pg_sqls),*], };
106
107 for sql in sqls {
109 if !sql.is_empty() {
110 let stmt = sea_orm::Statement::from_string(backend, *sql);
111 txn.execute_raw(stmt).await.map_err(|e| {
112 ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
113 })?;
114 }
115 }
116
117 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
119 let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
120 let stmt = sea_orm::Statement::from_string(backend, insert_sql);
121 txn.execute_raw(stmt).await.map_err(|e| {
122 ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
123 })?;
124
125 txn.commit().await.map_err(|e| {
127 ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
128 })?;
129 }
130 };
131
132 Ok(block)
133}
134
135pub(crate) fn generate_migration_code(
137 pool: &Expr,
138 version_table: &str,
139 migration_blocks: Vec<proc_macro2::TokenStream>,
140) -> proc_macro2::TokenStream {
141 quote! {
142 async {
143 use sea_orm::{ConnectionTrait, TransactionTrait};
144 let __pool = #pool;
145 let version_table = #version_table;
146 let backend = __pool.get_database_backend();
147
148 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
151 let create_table_sql = format!(
152 "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
153 version_table
154 );
155 let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
156 __pool.execute_raw(stmt).await.map_err(|e| {
157 ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
158 })?;
159
160 let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
162 let stmt = sea_orm::Statement::from_string(backend, select_sql);
163 let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
164 ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
165 })?;
166
167 let mut version = version_result
168 .and_then(|row| row.try_get::<i32>("", "version").ok())
169 .unwrap_or(0) as u32;
170
171 #(#migration_blocks)*
173
174 Ok::<(), ::vespertide::MigrationError>(())
175 }
176 }
177}
178
179pub(crate) fn vespertide_migration_impl(
181 input: proc_macro2::TokenStream,
182) -> proc_macro2::TokenStream {
183 let input: MacroInput = match syn::parse2(input) {
184 Ok(input) => input,
185 Err(e) => return e.to_compile_error(),
186 };
187 let pool = &input.pool;
188
189 let project_root = match env::var("CARGO_MANIFEST_DIR") {
191 Ok(dir) => Some(PathBuf::from(dir)),
192 Err(_) => None,
193 };
194
195 let config = match load_config_or_default(project_root) {
197 Ok(config) => config,
198 #[cfg(not(tarpaulin_include))]
199 Err(e) => {
200 return syn::Error::new(
201 proc_macro2::Span::call_site(),
202 format!("Failed to load config at compile time: {}", e),
203 )
204 .to_compile_error();
205 }
206 };
207 let prefix = config.prefix();
208
209 let version_table = input
211 .version_table
212 .map(|vt| config.apply_prefix(&vt))
213 .unwrap_or_else(|| config.apply_prefix("vespertide_version"));
214
215 let migrations = match load_migrations_at_compile_time() {
217 Ok(migrations) => migrations,
218 Err(e) => {
219 return syn::Error::new(
220 proc_macro2::Span::call_site(),
221 format!("Failed to load migrations at compile time: {}", e),
222 )
223 .to_compile_error();
224 }
225 };
226 let _models = match load_models_at_compile_time() {
227 Ok(models) => models,
228 #[cfg(not(tarpaulin_include))]
229 Err(e) => {
230 return syn::Error::new(
231 proc_macro2::Span::call_site(),
232 format!("Failed to load models at compile time: {}", e),
233 )
234 .to_compile_error();
235 }
236 };
237
238 let mut baseline_schema = Vec::new();
240 let mut migration_blocks = Vec::new();
241
242 #[cfg(not(tarpaulin_include))]
243 for migration in &migrations {
244 let prefixed_migration = migration.clone().with_prefix(prefix);
246 match build_migration_block(&prefixed_migration, &mut baseline_schema) {
247 Ok(block) => migration_blocks.push(block),
248 Err(e) => {
249 return syn::Error::new(proc_macro2::Span::call_site(), e).to_compile_error();
250 }
251 }
252 }
253
254 generate_migration_code(pool, &version_table, migration_blocks)
255}
256
257#[cfg(not(tarpaulin_include))]
259#[proc_macro]
260pub fn vespertide_migration(input: TokenStream) -> TokenStream {
261 vespertide_migration_impl(input.into()).into()
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::fs::{self, File};
    use std::io::Write;
    use std::path::Path;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, MigrationAction, MigrationPlan, SimpleColumnType, StrOrBoolOrArray,
    };

    /// Environment variable the macro consults to locate the project.
    const MANIFEST_DIR: &str = "CARGO_MANIFEST_DIR";

    /// Serializes the tests in this module that call `vespertide_migration_impl`
    /// (which reads `CARGO_MANIFEST_DIR`) or modify the environment.
    ///
    /// Test functions run on parallel threads by default. Without this lock, one
    /// test could observe another test's temporary `CARGO_MANIFEST_DIR`, or the
    /// variable could be removed out from under it.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so that a single failing test
    /// does not make every later env-dependent test fail too.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Points `CARGO_MANIFEST_DIR` at `dir`. Taking the guard proves the caller
    /// holds [`ENV_LOCK`].
    fn set_manifest_dir(_lock: &MutexGuard<'_, ()>, dir: &Path) {
        // SAFETY: the caller holds `ENV_LOCK`, so no other test in this module
        // reads or writes the environment concurrently. NOTE(review): tests in
        // other modules of this crate are not covered by this lock.
        unsafe { std::env::set_var(MANIFEST_DIR, dir) }
    }

    /// Removes `CARGO_MANIFEST_DIR`. Taking the guard proves the caller holds
    /// [`ENV_LOCK`].
    fn remove_manifest_dir(_lock: &MutexGuard<'_, ()>) {
        // SAFETY: the caller holds `ENV_LOCK`; see `set_manifest_dir`.
        unsafe { std::env::remove_var(MANIFEST_DIR) }
    }

    /// Restores `CARGO_MANIFEST_DIR` to its captured value when dropped, so the
    /// environment is cleaned up even when an assertion panics.
    ///
    /// Bind it *after* the [`env_lock`] guard so it is dropped (and the variable
    /// restored) before the lock is released.
    struct ManifestDirRestore {
        original: Option<OsString>,
    }

    impl ManifestDirRestore {
        fn capture(_lock: &MutexGuard<'_, ()>) -> Self {
            Self {
                original: std::env::var_os(MANIFEST_DIR),
            }
        }
    }

    impl Drop for ManifestDirRestore {
        fn drop(&mut self) {
            // SAFETY: constructed only while `ENV_LOCK` is held and bound after
            // the lock guard, so this runs before the lock is released.
            unsafe {
                match self.original.take() {
                    Some(value) => std::env::set_var(MANIFEST_DIR, value),
                    None => std::env::remove_var(MANIFEST_DIR),
                }
            }
        }
    }

    /// Writes a minimal project into `project_dir`: a JSON `vespertide.json`
    /// plus empty `models` and `migrations` directories.
    fn write_project_skeleton(project_dir: &Path) {
        let config_content = r#"{
    "modelsDir": "models",
    "migrationsDir": "migrations",
    "tableNamingCase": "snake",
    "columnNamingCase": "snake",
    "modelFormat": "json"
}"#;
        fs::write(project_dir.join("vespertide.json"), config_content).unwrap();
        fs::create_dir_all(project_dir.join("models")).unwrap();
        fs::create_dir_all(project_dir.join("migrations")).unwrap();
    }

    #[test]
    fn test_macro_expansion_with_runtime_macros() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();

        let test_file_path = dir.path().join("test_macro.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(
            test_file,
            r#"vespertide_migration!(pool, version_table = "test_versions");"#
        )
        .unwrap();

        let file = File::open(&test_file_path).unwrap();
        // Smoke test: expansion may succeed or fail depending on the project
        // configuration available here; it only has to run without panicking.
        let _ = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );
    }

    #[test]
    fn test_macro_with_simple_pool() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();
        let test_file_path = dir.path().join("test_simple.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(test_file, r#"vespertide_migration!(db_pool);"#).unwrap();

        let file = File::open(&test_file_path).unwrap();
        // Smoke test: only checks that expansion runs without panicking.
        let _ = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );
    }

    #[test]
    fn test_macro_parsing_invalid_option() {
        let _lock = env_lock();
        let input: proc_macro2::TokenStream = "pool, invalid_option = \"value\"".parse().unwrap();
        let output_str = vespertide_migration_impl(input).to_string();
        assert!(output_str.contains("unsupported option"));
    }

    #[test]
    fn test_macro_parsing_valid_input() {
        let _lock = env_lock();
        let input: proc_macro2::TokenStream = "my_pool".parse().unwrap();
        let output_str = vespertide_migration_impl(input).to_string();
        assert!(!output_str.is_empty());
        // Depending on the environment this is either the expansion or a
        // loading error, but never anything else.
        assert!(
            output_str.contains("async") || output_str.contains("Failed to load"),
            "Unexpected output: {}",
            output_str
        );
    }

    #[test]
    fn test_macro_parsing_with_version_table() {
        let _lock = env_lock();
        let input: proc_macro2::TokenStream =
            r#"pool, version_table = "custom_versions""#.parse().unwrap();
        let output_str = vespertide_migration_impl(input).to_string();
        assert!(!output_str.is_empty());
    }

    #[test]
    fn test_macro_parsing_trailing_comma() {
        let _lock = env_lock();
        let input: proc_macro2::TokenStream = "pool,".parse().unwrap();
        let output_str = vespertide_migration_impl(input).to_string();
        assert!(!output_str.is_empty());
    }

    /// Non-nullable integer column named `name` with every option unset.
    fn test_column(name: &str) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    #[test]
    fn test_build_migration_block_create_table() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        let block_str = result.unwrap().to_string();

        assert!(block_str.contains("version < 1u32"));
        assert!(block_str.contains("CREATE TABLE"));

        assert_eq!(baseline.len(), 1);
        assert_eq!(baseline[0].name, "users");
    }

    #[test]
    fn test_build_migration_block_add_column() {
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        build_migration_block(&create_migration, &mut baseline)
            .expect("create_table migration should build");

        let add_column_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "users".into(),
                column: Box::new(ColumnDef {
                    name: "email".into(),
                    r#type: ColumnType::Simple(SimpleColumnType::Text),
                    nullable: true,
                    default: None,
                    comment: None,
                    primary_key: None,
                    unique: None,
                    index: None,
                    foreign_key: None,
                }),
                fill_with: None,
            }],
        };

        let result = build_migration_block(&add_column_migration, &mut baseline);
        assert!(result.is_ok());
        let block_str = result.unwrap().to_string();

        assert!(block_str.contains("version < 2u32"));
        assert!(block_str.contains("ALTER TABLE"));
        assert!(block_str.contains("ADD COLUMN"));
    }

    #[test]
    fn test_build_migration_block_multiple_actions() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![
                MigrationAction::CreateTable {
                    table: "users".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
                MigrationAction::CreateTable {
                    table: "posts".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
            ],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        assert_eq!(baseline.len(), 2);
    }

    #[test]
    fn test_generate_migration_code() {
        let pool: Expr = syn::parse_str("db_pool").unwrap();
        let version_table = "test_versions";

        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let block = build_migration_block(&migration, &mut baseline).unwrap();

        let generated_str = generate_migration_code(&pool, version_table, vec![block]).to_string();

        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("db_pool"));
        assert!(generated_str.contains("test_versions"));
        assert!(generated_str.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(generated_str.contains("SELECT MAX"));
    }

    #[test]
    fn test_generate_migration_code_empty_migrations() {
        let pool: Expr = syn::parse_str("pool").unwrap();
        let version_table = "vespertide_version";

        let generated_str = generate_migration_code(&pool, version_table, vec![]).to_string();

        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("vespertide_version"));
    }

    #[test]
    fn test_generate_migration_code_multiple_blocks() {
        let pool: Expr = syn::parse_str("connection").unwrap();

        let mut baseline = Vec::new();

        let migration1 = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block1 = build_migration_block(&migration1, &mut baseline).unwrap();

        let migration2 = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "posts".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block2 = build_migration_block(&migration2, &mut baseline).unwrap();

        let generated_str =
            generate_migration_code(&pool, "migrations", vec![block1, block2]).to_string();

        assert!(generated_str.contains("version < 1u32"));
        assert!(generated_str.contains("version < 2u32"));
    }

    #[test]
    fn test_build_migration_block_generates_all_backends() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "test_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        let block_str = result.unwrap().to_string();

        assert!(block_str.contains("DatabaseBackend :: Postgres"));
        assert!(block_str.contains("DatabaseBackend :: MySql"));
        assert!(block_str.contains("DatabaseBackend :: Sqlite"));
    }

    #[test]
    fn test_build_migration_block_with_delete_table() {
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "temp_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        build_migration_block(&create_migration, &mut baseline)
            .expect("create_table migration should build");
        assert_eq!(baseline.len(), 1);

        let delete_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::DeleteTable {
                table: "temp_table".into(),
            }],
        };

        let result = build_migration_block(&delete_migration, &mut baseline);
        assert!(result.is_ok());
        let block_str = result.unwrap().to_string();
        assert!(block_str.contains("DROP TABLE"));

        assert_eq!(baseline.len(), 0);
    }

    #[test]
    fn test_build_migration_block_with_index() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![
                    test_column("id"),
                    ColumnDef {
                        name: "email".into(),
                        r#type: ColumnType::Simple(SimpleColumnType::Text),
                        nullable: true,
                        default: None,
                        comment: None,
                        primary_key: None,
                        unique: None,
                        index: Some(StrOrBoolOrArray::Bool(true)),
                        foreign_key: None,
                    },
                ],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        let table = &baseline[0];
        let normalized = table.clone().normalize();
        assert!(normalized.is_ok());
    }

    #[test]
    fn test_build_migration_block_error_nonexistent_table() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "nonexistent_table".into(),
                column: Box::new(test_column("new_col")),
                fill_with: None,
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Failed to build queries for migration version 1"));
    }

    #[test]
    fn test_vespertide_migration_impl_loading_error() {
        let lock = env_lock();
        // Bound after `lock`, so the variable is restored before the lock drops.
        let _restore = ManifestDirRestore::capture(&lock);
        remove_manifest_dir(&lock);

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output_str = vespertide_migration_impl(input).to_string();

        assert!(
            output_str.contains("Failed to load migrations at compile time"),
            "Expected loading error, got: {}",
            output_str
        );
    }

    #[test]
    fn test_vespertide_migration_impl_with_valid_project() {
        let lock = env_lock();
        let dir = tempdir().unwrap();
        let project_dir = dir.path();
        write_project_skeleton(project_dir);

        // Bound after `dir`, so the variable is restored before the directory is deleted.
        let _restore = ManifestDirRestore::capture(&lock);
        set_manifest_dir(&lock, project_dir);

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output_str = vespertide_migration_impl(input).to_string();

        assert!(
            output_str.contains("async"),
            "Expected async block, got: {}",
            output_str
        );
        assert!(
            output_str.contains("CREATE TABLE IF NOT EXISTS"),
            "Expected version table creation, got: {}",
            output_str
        );
    }

    #[test]
    fn test_vespertide_migration_impl_with_migrations() {
        let lock = env_lock();
        let dir = tempdir().unwrap();
        let project_dir = dir.path();
        write_project_skeleton(project_dir);

        let migration_content = r#"{
    "version": 1,
    "actions": [
        {
            "type": "create_table",
            "table": "users",
            "columns": [
                {"name": "id", "type": "integer", "nullable": false}
            ],
            "constraints": []
        }
    ]
}"#;
        fs::write(
            project_dir.join("migrations").join("0001_initial.json"),
            migration_content,
        )
        .unwrap();

        // Bound after `dir`, so the variable is restored before the directory is deleted.
        let _restore = ManifestDirRestore::capture(&lock);
        set_manifest_dir(&lock, project_dir);

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output_str = vespertide_migration_impl(input).to_string();

        assert!(
            output_str.contains("async"),
            "Expected async block, got: {}",
            output_str
        );
    }
}