vespertide_macro/
lib.rs

1// MigrationOptions and MigrationError are now in vespertide-core
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Token};
7use vespertide_loader::{load_migrations_at_compile_time, load_models_at_compile_time};
8use vespertide_planner::apply_action;
9use vespertide_query::{DatabaseBackend, build_plan_queries};
10
11struct MacroInput {
12    pool: Expr,
13    version_table: Option<String>,
14}
15
16impl Parse for MacroInput {
17    fn parse(input: ParseStream) -> syn::Result<Self> {
18        let pool = input.parse()?;
19        let mut version_table = None;
20
21        while !input.is_empty() {
22            input.parse::<Token![,]>()?;
23            if input.is_empty() {
24                break;
25            }
26
27            let key: Ident = input.parse()?;
28            if key == "version_table" {
29                input.parse::<Token![=]>()?;
30                let value: syn::LitStr = input.parse()?;
31                version_table = Some(value.value());
32            } else {
33                return Err(syn::Error::new(
34                    key.span(),
35                    "unsupported option for vespertide_migration!",
36                ));
37            }
38        }
39
40        Ok(MacroInput {
41            pool,
42            version_table,
43        })
44    }
45}
46
47/// Build a migration block for a single migration version.
48/// Returns the generated code block and updates the baseline schema.
49pub(crate) fn build_migration_block(
50    migration: &vespertide_core::MigrationPlan,
51    baseline_schema: &mut Vec<vespertide_core::TableDef>,
52) -> Result<proc_macro2::TokenStream, String> {
53    let version = migration.version;
54
55    // Use the current baseline schema (from all previous migrations)
56    let queries = build_plan_queries(migration, baseline_schema).map_err(|e| {
57        format!(
58            "Failed to build queries for migration version {}: {}",
59            version, e
60        )
61    })?;
62
63    // Update baseline schema incrementally by applying each action
64    for action in &migration.actions {
65        let _ = apply_action(baseline_schema, action);
66    }
67
68    // Pre-generate SQL for all backends at compile time
69    // Each query may produce multiple SQL statements, so we flatten them
70    let mut pg_sqls = Vec::new();
71    let mut mysql_sqls = Vec::new();
72    let mut sqlite_sqls = Vec::new();
73
74    for q in &queries {
75        for stmt in &q.postgres {
76            pg_sqls.push(stmt.build(DatabaseBackend::Postgres));
77        }
78        for stmt in &q.mysql {
79            mysql_sqls.push(stmt.build(DatabaseBackend::MySql));
80        }
81        for stmt in &q.sqlite {
82            sqlite_sqls.push(stmt.build(DatabaseBackend::Sqlite));
83        }
84    }
85
86    // Generate version guard and SQL execution block
87    let block = quote! {
88        if version < #version {
89            // Begin transaction
90            let txn = __pool.begin().await.map_err(|e| {
91                ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
92            })?;
93
94            // Select SQL statements based on backend
95            let sqls: &[&str] = match backend {
96                sea_orm::DatabaseBackend::Postgres => &[#(#pg_sqls),*],
97                sea_orm::DatabaseBackend::MySql => &[#(#mysql_sqls),*],
98                sea_orm::DatabaseBackend::Sqlite => &[#(#sqlite_sqls),*],
99                _ => &[#(#pg_sqls),*], // Fallback to PostgreSQL syntax for unknown backends
100            };
101
102            // Execute SQL statements
103            for sql in sqls {
104                if !sql.is_empty() {
105                    let stmt = sea_orm::Statement::from_string(backend, *sql);
106                    txn.execute_raw(stmt).await.map_err(|e| {
107                        ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
108                    })?;
109                }
110            }
111
112            // Insert version record for this migration
113            let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
114            let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
115            let stmt = sea_orm::Statement::from_string(backend, insert_sql);
116            txn.execute_raw(stmt).await.map_err(|e| {
117                ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
118            })?;
119
120            // Commit transaction
121            txn.commit().await.map_err(|e| {
122                ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
123            })?;
124        }
125    };
126
127    Ok(block)
128}
129
130/// Generate the final async migration block with all migrations.
131pub(crate) fn generate_migration_code(
132    pool: &Expr,
133    version_table: &str,
134    migration_blocks: Vec<proc_macro2::TokenStream>,
135) -> proc_macro2::TokenStream {
136    quote! {
137        async {
138            use sea_orm::{ConnectionTrait, TransactionTrait};
139            let __pool = #pool;
140            let version_table = #version_table;
141            let backend = __pool.get_database_backend();
142
143            // Create version table if it does not exist
144            // Table structure: version (INTEGER PRIMARY KEY), created_at (timestamp)
145            let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
146            let create_table_sql = format!(
147                "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
148                version_table
149            );
150            let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
151            __pool.execute_raw(stmt).await.map_err(|e| {
152                ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
153            })?;
154
155            // Read current maximum version (latest applied migration)
156            let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
157            let stmt = sea_orm::Statement::from_string(backend, select_sql);
158            let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
159                ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
160            })?;
161
162            let mut version = version_result
163                .and_then(|row| row.try_get::<i32>("", "version").ok())
164                .unwrap_or(0) as u32;
165
166            // Execute each migration block
167            #(#migration_blocks)*
168
169            Ok::<(), ::vespertide::MigrationError>(())
170        }
171    }
172}
173
174/// Inner implementation that works with proc_macro2::TokenStream for testability.
175pub(crate) fn vespertide_migration_impl(
176    input: proc_macro2::TokenStream,
177) -> proc_macro2::TokenStream {
178    let input: MacroInput = match syn::parse2(input) {
179        Ok(input) => input,
180        Err(e) => return e.to_compile_error(),
181    };
182    let pool = &input.pool;
183    let version_table = input
184        .version_table
185        .unwrap_or_else(|| "vespertide_version".to_string());
186
187    // Load migration files and build SQL at compile time
188    let migrations = match load_migrations_at_compile_time() {
189        Ok(migrations) => migrations,
190        Err(e) => {
191            return syn::Error::new(
192                proc_macro2::Span::call_site(),
193                format!("Failed to load migrations at compile time: {}", e),
194            )
195            .to_compile_error();
196        }
197    };
198    let _models = match load_models_at_compile_time() {
199        Ok(models) => models,
200        Err(e) => {
201            return syn::Error::new(
202                proc_macro2::Span::call_site(),
203                format!("Failed to load models at compile time: {}", e),
204            )
205            .to_compile_error();
206        }
207    };
208
209    // Build SQL for each migration using incremental baseline schema
210    let mut baseline_schema = Vec::new();
211    let mut migration_blocks = Vec::new();
212
213    for migration in &migrations {
214        match build_migration_block(migration, &mut baseline_schema) {
215            Ok(block) => migration_blocks.push(block),
216            Err(e) => {
217                return syn::Error::new(proc_macro2::Span::call_site(), e).to_compile_error();
218            }
219        }
220    }
221
222    generate_migration_code(pool, &version_table, migration_blocks)
223}
224
225/// Zero-runtime migration entry point.
226#[proc_macro]
227pub fn vespertide_migration(input: TokenStream) -> TokenStream {
228    vespertide_migration_impl(input.into()).into()
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::fs::File;
    use std::io::Write;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, MigrationAction, MigrationPlan, SimpleColumnType, StrOrBoolOrArray,
    };

    /// Serializes tests that expand the macro or modify the environment.
    ///
    /// Loading migrations depends on `CARGO_MANIFEST_DIR` (removing it makes loading
    /// fail; see `test_vespertide_migration_impl_loading_error`). The test harness
    /// runs tests on parallel threads, so without this lock one test could remove
    /// the variable while another is expanding the macro.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    fn env_lock() -> MutexGuard<'static, ()> {
        // A panicking test poisons the mutex; the guarded data is `()`, so
        // continuing is harmless.
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Removes an environment variable and restores its previous value on drop,
    /// so the variable comes back even if the test panics.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<OsString>,
    }

    impl EnvVarGuard {
        fn remove(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: the caller holds `ENV_LOCK`, and every test in this module that
            // expands the macro (and so reads the environment) takes that lock too.
            unsafe {
                std::env::remove_var(key);
            }
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same as in `EnvVarGuard::remove`; the guard is declared after
            // the lock guard, so it is dropped while `ENV_LOCK` is still held.
            unsafe {
                match &self.original {
                    Some(val) => std::env::set_var(self.key, val),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_macro_expansion_with_runtime_macros() {
        let _lock = env_lock();
        // Create a temporary directory with test files
        let dir = tempdir().unwrap();

        // Create a test file that uses the macro
        let test_file_path = dir.path().join("test_macro.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(
            test_file,
            r#"vespertide_migration!(pool, version_table = "test_versions");"#
        )
        .unwrap();

        // Use runtime-macros to emulate macro expansion
        let file = File::open(&test_file_path).unwrap();
        let result = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );

        // Coverage smoke test: depending on whether a vespertide config is present,
        // expansion may succeed or report an error. Either outcome is acceptable;
        // the test only fails if expansion panics.
        let _ = result;
    }

    #[test]
    fn test_macro_with_simple_pool() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();
        let test_file_path = dir.path().join("test_simple.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(test_file, r#"vespertide_migration!(db_pool);"#).unwrap();

        let file = File::open(&test_file_path).unwrap();
        let result = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );

        // Coverage smoke test; see `test_macro_expansion_with_runtime_macros`.
        let _ = result;
    }

    #[test]
    fn test_macro_parsing_invalid_option() {
        // Test that invalid options produce a compile error
        let input: proc_macro2::TokenStream = "pool, invalid_option = \"value\"".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        // Should contain an error message about unsupported option
        assert!(output_str.contains("unsupported option"));
    }

    #[test]
    fn test_macro_parsing_valid_input() {
        let _lock = env_lock();
        // The macro will either succeed (if migrations dir exists and is empty)
        // or fail with a migration loading error
        let input: proc_macro2::TokenStream = "my_pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        // Should produce output (either success or migration loading error)
        assert!(!output_str.is_empty());
        // If error, it should mention "Failed to load"
        // If success, it should contain "async"
        assert!(
            output_str.contains("async") || output_str.contains("Failed to load"),
            "Unexpected output: {}",
            output_str
        );
    }

    #[test]
    fn test_macro_parsing_with_version_table() {
        let _lock = env_lock();
        let input: proc_macro2::TokenStream =
            r#"pool, version_table = "custom_versions""#.parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
    }

    #[test]
    fn test_macro_parsing_trailing_comma() {
        let _lock = env_lock();
        let input: proc_macro2::TokenStream = "pool,".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
    }

    fn test_column(name: &str) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    #[test]
    fn test_build_migration_block_create_table() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        let block = result.unwrap();
        let block_str = block.to_string();

        // Verify the generated block contains expected elements
        assert!(block_str.contains("version < 1u32"));
        assert!(block_str.contains("CREATE TABLE"));

        // Verify baseline schema was updated
        assert_eq!(baseline.len(), 1);
        assert_eq!(baseline[0].name, "users");
    }

    #[test]
    fn test_build_migration_block_add_column() {
        // First create the table
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        build_migration_block(&create_migration, &mut baseline)
            .expect("creating the users table should succeed");

        // Now add a column
        let add_column_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "users".into(),
                column: Box::new(ColumnDef {
                    name: "email".into(),
                    r#type: ColumnType::Simple(SimpleColumnType::Text),
                    nullable: true,
                    default: None,
                    comment: None,
                    primary_key: None,
                    unique: None,
                    index: None,
                    foreign_key: None,
                }),
                fill_with: None,
            }],
        };

        let result = build_migration_block(&add_column_migration, &mut baseline);
        assert!(result.is_ok());
        let block = result.unwrap();
        let block_str = block.to_string();

        assert!(block_str.contains("version < 2u32"));
        assert!(block_str.contains("ALTER TABLE"));
        assert!(block_str.contains("ADD COLUMN"));
    }

    #[test]
    fn test_build_migration_block_multiple_actions() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![
                MigrationAction::CreateTable {
                    table: "users".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
                MigrationAction::CreateTable {
                    table: "posts".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
            ],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        assert_eq!(baseline.len(), 2);
    }

    #[test]
    fn test_generate_migration_code() {
        let pool: Expr = syn::parse_str("db_pool").unwrap();
        let version_table = "test_versions";

        // Create a simple migration block
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let block = build_migration_block(&migration, &mut baseline).unwrap();

        let generated = generate_migration_code(&pool, version_table, vec![block]);
        let generated_str = generated.to_string();

        // Verify the generated code structure
        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("db_pool"));
        assert!(generated_str.contains("test_versions"));
        assert!(generated_str.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(generated_str.contains("SELECT MAX"));
    }

    #[test]
    fn test_generate_migration_code_empty_migrations() {
        let pool: Expr = syn::parse_str("pool").unwrap();
        let version_table = "vespertide_version";

        let generated = generate_migration_code(&pool, version_table, vec![]);
        let generated_str = generated.to_string();

        // Should still generate the wrapper code
        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("vespertide_version"));
    }

    #[test]
    fn test_generate_migration_code_multiple_blocks() {
        let pool: Expr = syn::parse_str("connection").unwrap();

        let mut baseline = Vec::new();

        let migration1 = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block1 = build_migration_block(&migration1, &mut baseline).unwrap();

        let migration2 = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "posts".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block2 = build_migration_block(&migration2, &mut baseline).unwrap();

        let generated = generate_migration_code(&pool, "migrations", vec![block1, block2]);
        let generated_str = generated.to_string();

        // Both version checks should be present
        assert!(generated_str.contains("version < 1u32"));
        assert!(generated_str.contains("version < 2u32"));
    }

    #[test]
    fn test_build_migration_block_generates_all_backends() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "test_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        let block_str = result.unwrap().to_string();

        // The generated block should have backend matching
        assert!(block_str.contains("DatabaseBackend :: Postgres"));
        assert!(block_str.contains("DatabaseBackend :: MySql"));
        assert!(block_str.contains("DatabaseBackend :: Sqlite"));
    }

    #[test]
    fn test_build_migration_block_with_delete_table() {
        // First create the table
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "temp_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        build_migration_block(&create_migration, &mut baseline)
            .expect("creating temp_table should succeed");
        assert_eq!(baseline.len(), 1);

        // Now delete it
        let delete_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::DeleteTable {
                table: "temp_table".into(),
            }],
        };

        let result = build_migration_block(&delete_migration, &mut baseline);
        assert!(result.is_ok());
        let block_str = result.unwrap().to_string();
        assert!(block_str.contains("DROP TABLE"));

        // Baseline should be empty after delete
        assert_eq!(baseline.len(), 0);
    }

    #[test]
    fn test_build_migration_block_with_index() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![
                    test_column("id"),
                    ColumnDef {
                        name: "email".into(),
                        r#type: ColumnType::Simple(SimpleColumnType::Text),
                        nullable: true,
                        default: None,
                        comment: None,
                        primary_key: None,
                        unique: None,
                        index: Some(StrOrBoolOrArray::Bool(true)),
                        foreign_key: None,
                    },
                ],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        // Table should be normalized with index
        let table = &baseline[0];
        let normalized = table.clone().normalize();
        assert!(normalized.is_ok());
    }

    #[test]
    fn test_build_migration_block_error_nonexistent_table() {
        // Try to add column to a table that doesn't exist - should fail
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "nonexistent_table".into(),
                column: Box::new(test_column("new_col")),
                fill_with: None,
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Failed to build queries for migration version 1"));
    }

    #[test]
    fn test_vespertide_migration_impl_loading_error() {
        // Hold the lock for the whole test so no other macro-expanding test observes
        // the missing variable. `_env` is declared after `_lock`, so the variable is
        // restored before the lock is released, even if an assertion panics.
        let _lock = env_lock();
        let _env = EnvVarGuard::remove("CARGO_MANIFEST_DIR");

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();

        // Should contain error about failed loading
        assert!(
            output_str.contains("Failed to load migrations at compile time"),
            "Expected loading error, got: {}",
            output_str
        );
    }
}