vespertide_macro/
lib.rs

1// MigrationOptions and MigrationError are now in vespertide-core
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Token};
7use vespertide_loader::{load_migrations_at_compile_time, load_models_at_compile_time};
8use vespertide_planner::apply_action;
9use vespertide_query::{DatabaseBackend, build_plan_queries};
10
11struct MacroInput {
12    pool: Expr,
13    version_table: Option<String>,
14}
15
16impl Parse for MacroInput {
17    fn parse(input: ParseStream) -> syn::Result<Self> {
18        let pool = input.parse()?;
19        let mut version_table = None;
20
21        while !input.is_empty() {
22            input.parse::<Token![,]>()?;
23            if input.is_empty() {
24                break;
25            }
26
27            let key: Ident = input.parse()?;
28            if key == "version_table" {
29                input.parse::<Token![=]>()?;
30                let value: syn::LitStr = input.parse()?;
31                version_table = Some(value.value());
32            } else {
33                return Err(syn::Error::new(
34                    key.span(),
35                    "unsupported option for vespertide_migration!",
36                ));
37            }
38        }
39
40        Ok(MacroInput {
41            pool,
42            version_table,
43        })
44    }
45}
46
47/// Build a migration block for a single migration version.
48/// Returns the generated code block and updates the baseline schema.
49pub(crate) fn build_migration_block(
50    migration: &vespertide_core::MigrationPlan,
51    baseline_schema: &mut Vec<vespertide_core::TableDef>,
52) -> Result<proc_macro2::TokenStream, String> {
53    let version = migration.version;
54
55    // Use the current baseline schema (from all previous migrations)
56    let queries = build_plan_queries(migration, baseline_schema).map_err(|e| {
57        format!(
58            "Failed to build queries for migration version {}: {}",
59            version, e
60        )
61    })?;
62
63    // Update baseline schema incrementally by applying each action
64    for action in &migration.actions {
65        let _ = apply_action(baseline_schema, action);
66    }
67
68    // Pre-generate SQL for all backends at compile time
69    // Each query may produce multiple SQL statements, so we flatten them
70    let mut pg_sqls = Vec::new();
71    let mut mysql_sqls = Vec::new();
72    let mut sqlite_sqls = Vec::new();
73
74    for q in &queries {
75        for stmt in &q.postgres {
76            pg_sqls.push(stmt.build(DatabaseBackend::Postgres));
77        }
78        for stmt in &q.mysql {
79            mysql_sqls.push(stmt.build(DatabaseBackend::MySql));
80        }
81        for stmt in &q.sqlite {
82            sqlite_sqls.push(stmt.build(DatabaseBackend::Sqlite));
83        }
84    }
85
86    // Generate version guard and SQL execution block
87    let block = quote! {
88        if version < #version {
89            // Begin transaction
90            let txn = __pool.begin().await.map_err(|e| {
91                ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
92            })?;
93
94            // Select SQL statements based on backend
95            let sqls: &[&str] = match backend {
96                sea_orm::DatabaseBackend::Postgres => &[#(#pg_sqls),*],
97                sea_orm::DatabaseBackend::MySql => &[#(#mysql_sqls),*],
98                sea_orm::DatabaseBackend::Sqlite => &[#(#sqlite_sqls),*],
99                _ => &[#(#pg_sqls),*], // Fallback to PostgreSQL syntax for unknown backends
100            };
101
102            // Execute SQL statements
103            for sql in sqls {
104                if !sql.is_empty() {
105                    let stmt = sea_orm::Statement::from_string(backend, *sql);
106                    txn.execute_raw(stmt).await.map_err(|e| {
107                        ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
108                    })?;
109                }
110            }
111
112            // Insert version record for this migration
113            let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
114            let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
115            let stmt = sea_orm::Statement::from_string(backend, insert_sql);
116            txn.execute_raw(stmt).await.map_err(|e| {
117                ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
118            })?;
119
120            // Commit transaction
121            txn.commit().await.map_err(|e| {
122                ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
123            })?;
124        }
125    };
126
127    Ok(block)
128}
129
130/// Generate the final async migration block with all migrations.
131pub(crate) fn generate_migration_code(
132    pool: &Expr,
133    version_table: &str,
134    migration_blocks: Vec<proc_macro2::TokenStream>,
135) -> proc_macro2::TokenStream {
136    quote! {
137        async {
138            use sea_orm::{ConnectionTrait, TransactionTrait};
139            let __pool = #pool;
140            let version_table = #version_table;
141            let backend = __pool.get_database_backend();
142
143            // Create version table if it does not exist
144            // Table structure: version (INTEGER PRIMARY KEY), created_at (timestamp)
145            let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
146            let create_table_sql = format!(
147                "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
148                version_table
149            );
150            let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
151            __pool.execute_raw(stmt).await.map_err(|e| {
152                ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
153            })?;
154
155            // Read current maximum version (latest applied migration)
156            let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
157            let stmt = sea_orm::Statement::from_string(backend, select_sql);
158            let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
159                ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
160            })?;
161
162            let mut version = version_result
163                .and_then(|row| row.try_get::<i32>("", "version").ok())
164                .unwrap_or(0) as u32;
165
166            // Execute each migration block
167            #(#migration_blocks)*
168
169            Ok::<(), ::vespertide::MigrationError>(())
170        }
171    }
172}
173
174/// Inner implementation that works with proc_macro2::TokenStream for testability.
175pub(crate) fn vespertide_migration_impl(
176    input: proc_macro2::TokenStream,
177) -> proc_macro2::TokenStream {
178    let input: MacroInput = match syn::parse2(input) {
179        Ok(input) => input,
180        Err(e) => return e.to_compile_error(),
181    };
182    let pool = &input.pool;
183    let version_table = input
184        .version_table
185        .unwrap_or_else(|| "vespertide_version".to_string());
186
187    // Load migration files and build SQL at compile time
188    let migrations = match load_migrations_at_compile_time() {
189        Ok(migrations) => migrations,
190        Err(e) => {
191            return syn::Error::new(
192                proc_macro2::Span::call_site(),
193                format!("Failed to load migrations at compile time: {}", e),
194            )
195            .to_compile_error();
196        }
197    };
198    let _models = match load_models_at_compile_time() {
199        Ok(models) => models,
200        #[cfg(not(tarpaulin_include))]
201        Err(e) => {
202            return syn::Error::new(
203                proc_macro2::Span::call_site(),
204                format!("Failed to load models at compile time: {}", e),
205            )
206            .to_compile_error();
207        }
208    };
209
210    // Build SQL for each migration using incremental baseline schema
211    let mut baseline_schema = Vec::new();
212    let mut migration_blocks = Vec::new();
213
214    #[cfg(not(tarpaulin_include))]
215    for migration in &migrations {
216        match build_migration_block(migration, &mut baseline_schema) {
217            Ok(block) => migration_blocks.push(block),
218            Err(e) => {
219                return syn::Error::new(proc_macro2::Span::call_site(), e).to_compile_error();
220            }
221        }
222    }
223
224    generate_migration_code(pool, &version_table, migration_blocks)
225}
226
227/// Zero-runtime migration entry point.
228#[cfg(not(tarpaulin_include))]
229#[proc_macro]
230pub fn vespertide_migration(input: TokenStream) -> TokenStream {
231    vespertide_migration_impl(input.into()).into()
232}
233
234#[cfg(test)]
235mod tests {
236    use super::*;
237    use std::fs::File;
238    use std::io::Write;
239    use tempfile::tempdir;
240    use vespertide_core::{
241        ColumnDef, ColumnType, MigrationAction, MigrationPlan, SimpleColumnType, StrOrBoolOrArray,
242    };
243
244    #[test]
245    fn test_macro_expansion_with_runtime_macros() {
246        // Create a temporary directory with test files
247        let dir = tempdir().unwrap();
248
249        // Create a test file that uses the macro
250        let test_file_path = dir.path().join("test_macro.rs");
251        let mut test_file = File::create(&test_file_path).unwrap();
252        writeln!(
253            test_file,
254            r#"vespertide_migration!(pool, version_table = "test_versions");"#
255        )
256        .unwrap();
257
258        // Use runtime-macros to emulate macro expansion
259        let file = File::open(&test_file_path).unwrap();
260        let result = runtime_macros::emulate_functionlike_macro_expansion(
261            file,
262            &[("vespertide_migration", vespertide_migration_impl)],
263        );
264
265        // The macro will fail because there's no vespertide config, but
266        // the important thing is that it runs and covers the macro code
267        // We expect an error due to missing config
268        assert!(result.is_ok() || result.is_err());
269    }
270
271    #[test]
272    fn test_macro_with_simple_pool() {
273        let dir = tempdir().unwrap();
274        let test_file_path = dir.path().join("test_simple.rs");
275        let mut test_file = File::create(&test_file_path).unwrap();
276        writeln!(test_file, r#"vespertide_migration!(db_pool);"#).unwrap();
277
278        let file = File::open(&test_file_path).unwrap();
279        let result = runtime_macros::emulate_functionlike_macro_expansion(
280            file,
281            &[("vespertide_migration", vespertide_migration_impl)],
282        );
283
284        assert!(result.is_ok() || result.is_err());
285    }
286
287    #[test]
288    fn test_macro_parsing_invalid_option() {
289        // Test that invalid options produce a compile error
290        let input: proc_macro2::TokenStream = "pool, invalid_option = \"value\"".parse().unwrap();
291        let output = vespertide_migration_impl(input);
292        let output_str = output.to_string();
293        // Should contain an error message about unsupported option
294        assert!(output_str.contains("unsupported option"));
295    }
296
297    #[test]
298    fn test_macro_parsing_valid_input() {
299        // Test that valid input is parsed correctly
300        // The macro will either succeed (if migrations dir exists and is empty)
301        // or fail with a migration loading error
302        let input: proc_macro2::TokenStream = "my_pool".parse().unwrap();
303        let output = vespertide_migration_impl(input);
304        let output_str = output.to_string();
305        // Should produce output (either success or migration loading error)
306        assert!(!output_str.is_empty());
307        // If error, it should mention "Failed to load"
308        // If success, it should contain "async"
309        assert!(
310            output_str.contains("async") || output_str.contains("Failed to load"),
311            "Unexpected output: {}",
312            output_str
313        );
314    }
315
316    #[test]
317    fn test_macro_parsing_with_version_table() {
318        let input: proc_macro2::TokenStream =
319            r#"pool, version_table = "custom_versions""#.parse().unwrap();
320        let output = vespertide_migration_impl(input);
321        let output_str = output.to_string();
322        assert!(!output_str.is_empty());
323    }
324
325    #[test]
326    fn test_macro_parsing_trailing_comma() {
327        let input: proc_macro2::TokenStream = "pool,".parse().unwrap();
328        let output = vespertide_migration_impl(input);
329        let output_str = output.to_string();
330        assert!(!output_str.is_empty());
331    }
332
333    fn test_column(name: &str) -> ColumnDef {
334        ColumnDef {
335            name: name.into(),
336            r#type: ColumnType::Simple(SimpleColumnType::Integer),
337            nullable: false,
338            default: None,
339            comment: None,
340            primary_key: None,
341            unique: None,
342            index: None,
343            foreign_key: None,
344        }
345    }
346
347    #[test]
348    fn test_build_migration_block_create_table() {
349        let migration = MigrationPlan {
350            version: 1,
351            comment: None,
352            created_at: None,
353            actions: vec![MigrationAction::CreateTable {
354                table: "users".into(),
355                columns: vec![test_column("id")],
356                constraints: vec![],
357            }],
358        };
359
360        let mut baseline = Vec::new();
361        let result = build_migration_block(&migration, &mut baseline);
362
363        assert!(result.is_ok());
364        let block = result.unwrap();
365        let block_str = block.to_string();
366
367        // Verify the generated block contains expected elements
368        assert!(block_str.contains("version < 1u32"));
369        assert!(block_str.contains("CREATE TABLE"));
370
371        // Verify baseline schema was updated
372        assert_eq!(baseline.len(), 1);
373        assert_eq!(baseline[0].name, "users");
374    }
375
376    #[test]
377    fn test_build_migration_block_add_column() {
378        // First create the table
379        let create_migration = MigrationPlan {
380            version: 1,
381            comment: None,
382            created_at: None,
383            actions: vec![MigrationAction::CreateTable {
384                table: "users".into(),
385                columns: vec![test_column("id")],
386                constraints: vec![],
387            }],
388        };
389
390        let mut baseline = Vec::new();
391        let _ = build_migration_block(&create_migration, &mut baseline);
392
393        // Now add a column
394        let add_column_migration = MigrationPlan {
395            version: 2,
396            comment: None,
397            created_at: None,
398            actions: vec![MigrationAction::AddColumn {
399                table: "users".into(),
400                column: Box::new(ColumnDef {
401                    name: "email".into(),
402                    r#type: ColumnType::Simple(SimpleColumnType::Text),
403                    nullable: true,
404                    default: None,
405                    comment: None,
406                    primary_key: None,
407                    unique: None,
408                    index: None,
409                    foreign_key: None,
410                }),
411                fill_with: None,
412            }],
413        };
414
415        let result = build_migration_block(&add_column_migration, &mut baseline);
416        assert!(result.is_ok());
417        let block = result.unwrap();
418        let block_str = block.to_string();
419
420        assert!(block_str.contains("version < 2u32"));
421        assert!(block_str.contains("ALTER TABLE"));
422        assert!(block_str.contains("ADD COLUMN"));
423    }
424
425    #[test]
426    fn test_build_migration_block_multiple_actions() {
427        let migration = MigrationPlan {
428            version: 1,
429            comment: None,
430            created_at: None,
431            actions: vec![
432                MigrationAction::CreateTable {
433                    table: "users".into(),
434                    columns: vec![test_column("id")],
435                    constraints: vec![],
436                },
437                MigrationAction::CreateTable {
438                    table: "posts".into(),
439                    columns: vec![test_column("id")],
440                    constraints: vec![],
441                },
442            ],
443        };
444
445        let mut baseline = Vec::new();
446        let result = build_migration_block(&migration, &mut baseline);
447
448        assert!(result.is_ok());
449        assert_eq!(baseline.len(), 2);
450    }
451
452    #[test]
453    fn test_generate_migration_code() {
454        let pool: Expr = syn::parse_str("db_pool").unwrap();
455        let version_table = "test_versions";
456
457        // Create a simple migration block
458        let migration = MigrationPlan {
459            version: 1,
460            comment: None,
461            created_at: None,
462            actions: vec![MigrationAction::CreateTable {
463                table: "users".into(),
464                columns: vec![test_column("id")],
465                constraints: vec![],
466            }],
467        };
468
469        let mut baseline = Vec::new();
470        let block = build_migration_block(&migration, &mut baseline).unwrap();
471
472        let generated = generate_migration_code(&pool, version_table, vec![block]);
473        let generated_str = generated.to_string();
474
475        // Verify the generated code structure
476        assert!(generated_str.contains("async"));
477        assert!(generated_str.contains("db_pool"));
478        assert!(generated_str.contains("test_versions"));
479        assert!(generated_str.contains("CREATE TABLE IF NOT EXISTS"));
480        assert!(generated_str.contains("SELECT MAX"));
481    }
482
483    #[test]
484    fn test_generate_migration_code_empty_migrations() {
485        let pool: Expr = syn::parse_str("pool").unwrap();
486        let version_table = "vespertide_version";
487
488        let generated = generate_migration_code(&pool, version_table, vec![]);
489        let generated_str = generated.to_string();
490
491        // Should still generate the wrapper code
492        assert!(generated_str.contains("async"));
493        assert!(generated_str.contains("vespertide_version"));
494    }
495
496    #[test]
497    fn test_generate_migration_code_multiple_blocks() {
498        let pool: Expr = syn::parse_str("connection").unwrap();
499
500        let mut baseline = Vec::new();
501
502        let migration1 = MigrationPlan {
503            version: 1,
504            comment: None,
505            created_at: None,
506            actions: vec![MigrationAction::CreateTable {
507                table: "users".into(),
508                columns: vec![test_column("id")],
509                constraints: vec![],
510            }],
511        };
512        let block1 = build_migration_block(&migration1, &mut baseline).unwrap();
513
514        let migration2 = MigrationPlan {
515            version: 2,
516            comment: None,
517            created_at: None,
518            actions: vec![MigrationAction::CreateTable {
519                table: "posts".into(),
520                columns: vec![test_column("id")],
521                constraints: vec![],
522            }],
523        };
524        let block2 = build_migration_block(&migration2, &mut baseline).unwrap();
525
526        let generated = generate_migration_code(&pool, "migrations", vec![block1, block2]);
527        let generated_str = generated.to_string();
528
529        // Both version checks should be present
530        assert!(generated_str.contains("version < 1u32"));
531        assert!(generated_str.contains("version < 2u32"));
532    }
533
534    #[test]
535    fn test_build_migration_block_generates_all_backends() {
536        let migration = MigrationPlan {
537            version: 1,
538            comment: None,
539            created_at: None,
540            actions: vec![MigrationAction::CreateTable {
541                table: "test_table".into(),
542                columns: vec![test_column("id")],
543                constraints: vec![],
544            }],
545        };
546
547        let mut baseline = Vec::new();
548        let result = build_migration_block(&migration, &mut baseline);
549        assert!(result.is_ok());
550
551        let block_str = result.unwrap().to_string();
552
553        // The generated block should have backend matching
554        assert!(block_str.contains("DatabaseBackend :: Postgres"));
555        assert!(block_str.contains("DatabaseBackend :: MySql"));
556        assert!(block_str.contains("DatabaseBackend :: Sqlite"));
557    }
558
559    #[test]
560    fn test_build_migration_block_with_delete_table() {
561        // First create the table
562        let create_migration = MigrationPlan {
563            version: 1,
564            comment: None,
565            created_at: None,
566            actions: vec![MigrationAction::CreateTable {
567                table: "temp_table".into(),
568                columns: vec![test_column("id")],
569                constraints: vec![],
570            }],
571        };
572
573        let mut baseline = Vec::new();
574        let _ = build_migration_block(&create_migration, &mut baseline);
575        assert_eq!(baseline.len(), 1);
576
577        // Now delete it
578        let delete_migration = MigrationPlan {
579            version: 2,
580            comment: None,
581            created_at: None,
582            actions: vec![MigrationAction::DeleteTable {
583                table: "temp_table".into(),
584            }],
585        };
586
587        let result = build_migration_block(&delete_migration, &mut baseline);
588        assert!(result.is_ok());
589        let block_str = result.unwrap().to_string();
590        assert!(block_str.contains("DROP TABLE"));
591
592        // Baseline should be empty after delete
593        assert_eq!(baseline.len(), 0);
594    }
595
596    #[test]
597    fn test_build_migration_block_with_index() {
598        let migration = MigrationPlan {
599            version: 1,
600            comment: None,
601            created_at: None,
602            actions: vec![MigrationAction::CreateTable {
603                table: "users".into(),
604                columns: vec![
605                    test_column("id"),
606                    ColumnDef {
607                        name: "email".into(),
608                        r#type: ColumnType::Simple(SimpleColumnType::Text),
609                        nullable: true,
610                        default: None,
611                        comment: None,
612                        primary_key: None,
613                        unique: None,
614                        index: Some(StrOrBoolOrArray::Bool(true)),
615                        foreign_key: None,
616                    },
617                ],
618                constraints: vec![],
619            }],
620        };
621
622        let mut baseline = Vec::new();
623        let result = build_migration_block(&migration, &mut baseline);
624        assert!(result.is_ok());
625
626        // Table should be normalized with index
627        let table = &baseline[0];
628        let normalized = table.clone().normalize();
629        assert!(normalized.is_ok());
630    }
631
632    #[test]
633    fn test_build_migration_block_error_nonexistent_table() {
634        // Try to add column to a table that doesn't exist - should fail
635        let migration = MigrationPlan {
636            version: 1,
637            comment: None,
638            created_at: None,
639            actions: vec![MigrationAction::AddColumn {
640                table: "nonexistent_table".into(),
641                column: Box::new(test_column("new_col")),
642                fill_with: None,
643            }],
644        };
645
646        let mut baseline = Vec::new();
647        let result = build_migration_block(&migration, &mut baseline);
648
649        assert!(result.is_err());
650        let err = result.unwrap_err();
651        assert!(err.contains("Failed to build queries for migration version 1"));
652    }
653
654    #[test]
655    fn test_vespertide_migration_impl_loading_error() {
656        // Save original CARGO_MANIFEST_DIR
657        let original = std::env::var("CARGO_MANIFEST_DIR").ok();
658
659        // Remove CARGO_MANIFEST_DIR to trigger loading error
660        unsafe {
661            std::env::remove_var("CARGO_MANIFEST_DIR");
662        }
663
664        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
665        let output = vespertide_migration_impl(input);
666        let output_str = output.to_string();
667
668        // Should contain error about failed loading
669        assert!(
670            output_str.contains("Failed to load migrations at compile time"),
671            "Expected loading error, got: {}",
672            output_str
673        );
674
675        // Restore CARGO_MANIFEST_DIR
676        if let Some(val) = original {
677            unsafe {
678                std::env::set_var("CARGO_MANIFEST_DIR", val);
679            }
680        }
681    }
682
683    #[test]
684    fn test_vespertide_migration_impl_with_valid_project() {
685        use std::fs;
686
687        // Create a temporary directory with a valid vespertide project
688        let dir = tempdir().unwrap();
689        let project_dir = dir.path();
690
691        // Create vespertide.json config
692        let config_content = r#"{
693            "modelsDir": "models",
694            "migrationsDir": "migrations",
695            "tableNamingCase": "snake",
696            "columnNamingCase": "snake",
697            "modelFormat": "json"
698        }"#;
699        fs::write(project_dir.join("vespertide.json"), config_content).unwrap();
700
701        // Create empty models and migrations directories
702        fs::create_dir_all(project_dir.join("models")).unwrap();
703        fs::create_dir_all(project_dir.join("migrations")).unwrap();
704
705        // Save original CARGO_MANIFEST_DIR and set to temp dir
706        let original = std::env::var("CARGO_MANIFEST_DIR").ok();
707        unsafe {
708            std::env::set_var("CARGO_MANIFEST_DIR", project_dir);
709        }
710
711        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
712        let output = vespertide_migration_impl(input);
713        let output_str = output.to_string();
714
715        // Should produce valid async code since there are no migrations
716        assert!(
717            output_str.contains("async"),
718            "Expected async block, got: {}",
719            output_str
720        );
721        assert!(
722            output_str.contains("CREATE TABLE IF NOT EXISTS"),
723            "Expected version table creation, got: {}",
724            output_str
725        );
726
727        // Restore CARGO_MANIFEST_DIR
728        if let Some(val) = original {
729            unsafe {
730                std::env::set_var("CARGO_MANIFEST_DIR", val);
731            }
732        } else {
733            unsafe {
734                std::env::remove_var("CARGO_MANIFEST_DIR");
735            }
736        }
737    }
738
739    #[test]
740    fn test_vespertide_migration_impl_with_migrations() {
741        use std::fs;
742
743        // Create a temporary directory with a valid vespertide project and migrations
744        let dir = tempdir().unwrap();
745        let project_dir = dir.path();
746
747        // Create vespertide.json config
748        let config_content = r#"{
749            "modelsDir": "models",
750            "migrationsDir": "migrations",
751            "tableNamingCase": "snake",
752            "columnNamingCase": "snake",
753            "modelFormat": "json"
754        }"#;
755        fs::write(project_dir.join("vespertide.json"), config_content).unwrap();
756
757        // Create models and migrations directories
758        fs::create_dir_all(project_dir.join("models")).unwrap();
759        fs::create_dir_all(project_dir.join("migrations")).unwrap();
760
761        // Create a migration file
762        let migration_content = r#"{
763            "version": 1,
764            "actions": [
765                {
766                    "type": "create_table",
767                    "table": "users",
768                    "columns": [
769                        {"name": "id", "type": "integer", "nullable": false}
770                    ],
771                    "constraints": []
772                }
773            ]
774        }"#;
775        fs::write(
776            project_dir.join("migrations").join("0001_initial.json"),
777            migration_content,
778        )
779        .unwrap();
780
781        // Save original CARGO_MANIFEST_DIR and set to temp dir
782        let original = std::env::var("CARGO_MANIFEST_DIR").ok();
783        unsafe {
784            std::env::set_var("CARGO_MANIFEST_DIR", project_dir);
785        }
786
787        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
788        let output = vespertide_migration_impl(input);
789        let output_str = output.to_string();
790
791        // Should produce valid async code with migration
792        assert!(
793            output_str.contains("async"),
794            "Expected async block, got: {}",
795            output_str
796        );
797
798        // Restore CARGO_MANIFEST_DIR
799        if let Some(val) = original {
800            unsafe {
801                std::env::set_var("CARGO_MANIFEST_DIR", val);
802            }
803        } else {
804            unsafe {
805                std::env::remove_var("CARGO_MANIFEST_DIR");
806            }
807        }
808    }
809}