vespertide_macro/
lib.rs

1// MigrationOptions and MigrationError are now in vespertide-core
2
3mod loader;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::parse::{Parse, ParseStream};
8use syn::{Expr, Ident, Token, parse_macro_input};
9use vespertide_query::{DatabaseBackend, build_plan_queries};
10
11struct MacroInput {
12    pool: Expr,
13    version_table: Option<String>,
14}
15
16impl Parse for MacroInput {
17    fn parse(input: ParseStream) -> syn::Result<Self> {
18        let pool = input.parse()?;
19        let mut version_table = None;
20
21        while !input.is_empty() {
22            input.parse::<Token![,]>()?;
23            if input.is_empty() {
24                break;
25            }
26
27            let key: Ident = input.parse()?;
28            if key == "version_table" {
29                input.parse::<Token![=]>()?;
30                let value: syn::LitStr = input.parse()?;
31                version_table = Some(value.value());
32            } else {
33                return Err(syn::Error::new(
34                    key.span(),
35                    "unsupported option for vespertide_migration!",
36                ));
37            }
38        }
39
40        Ok(MacroInput {
41            pool,
42            version_table,
43        })
44    }
45}
46
47/// Zero-runtime migration entry point.
48#[proc_macro]
49pub fn vespertide_migration(input: TokenStream) -> TokenStream {
50    let input = parse_macro_input!(input as MacroInput);
51    let pool = &input.pool;
52    let version_table = input
53        .version_table
54        .unwrap_or_else(|| "vespertide_version".to_string());
55
56    // Load migration files and build SQL at compile time
57    let migrations = match loader::load_migrations_at_compile_time() {
58        Ok(migrations) => migrations,
59        Err(e) => {
60            return syn::Error::new(
61                proc_macro2::Span::call_site(),
62                format!("Failed to load migrations at compile time: {}", e),
63            )
64            .to_compile_error()
65            .into();
66        }
67    };
68
69    // Build SQL for each migration
70    let mut migration_blocks = Vec::new();
71    for migration in &migrations {
72        let version = migration.version;
73        let queries = match build_plan_queries(migration) {
74            Ok(queries) => queries,
75            Err(e) => {
76                return syn::Error::new(
77                    proc_macro2::Span::call_site(),
78                    format!(
79                        "Failed to build queries for migration version {}: {}",
80                        version, e
81                    ),
82                )
83                .to_compile_error()
84                .into();
85            }
86        };
87
88        // Pre-generate SQL for all backends at compile time
89        let sql_statements: Vec<_> = queries
90            .iter()
91            .map(|q| {
92                let pg_sql = q.build(DatabaseBackend::Postgres);
93                let mysql_sql = q.build(DatabaseBackend::MySql);
94                let sqlite_sql = q.build(DatabaseBackend::Sqlite);
95                quote! {
96                    match backend {
97                        sea_orm::DatabaseBackend::Postgres => #pg_sql,
98                        sea_orm::DatabaseBackend::MySql => #mysql_sql,
99                        sea_orm::DatabaseBackend::Sqlite => #sqlite_sql,
100                        _ => #pg_sql, // Fallback to PostgreSQL syntax for unknown backends
101                    }
102                }
103            })
104            .collect();
105
106        // Generate version guard and SQL execution block
107        let block = quote! {
108            if version < #version {
109                // Begin transaction
110                let txn = __pool.begin().await.map_err(|e| {
111                    ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
112                })?;
113
114                // Execute SQL statements
115                #(
116                    {
117                        let sql: &str = #sql_statements;
118                        let stmt = sea_orm::Statement::from_string(backend, sql);
119                        txn.execute_raw(stmt).await.map_err(|e| {
120                            ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
121                        })?;
122                    }
123                )*
124
125                // Insert version record for this migration
126                let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
127                let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
128                let stmt = sea_orm::Statement::from_string(backend, insert_sql);
129                txn.execute_raw(stmt).await.map_err(|e| {
130                    ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
131                })?;
132
133                // Commit transaction
134                txn.commit().await.map_err(|e| {
135                    ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
136                })?;
137            }
138        };
139
140        migration_blocks.push(block);
141    }
142
143    // Emit final generated async block
144    let generated = quote! {
145        async {
146            use sea_orm::{ConnectionTrait, TransactionTrait};
147            let __pool = #pool;
148            let version_table = #version_table;
149            let backend = __pool.get_database_backend();
150
151            // Create version table if it does not exist
152            // Table structure: version (INTEGER PRIMARY KEY), created_at (timestamp)
153            let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
154            let create_table_sql = format!(
155                "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
156                version_table
157            );
158            let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
159            __pool.execute_raw(stmt).await.map_err(|e| {
160                ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
161            })?;
162
163            // Read current maximum version (latest applied migration)
164            let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
165            let stmt = sea_orm::Statement::from_string(backend, select_sql);
166            let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
167                ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
168            })?;
169
170            let mut version = version_result
171                .and_then(|row| row.try_get::<i32>("", "version").ok())
172                .unwrap_or(0) as u32;
173
174            // Execute each migration block
175            #(#migration_blocks)*
176
177            Ok::<(), ::vespertide::MigrationError>(())
178        }
179    };
180
181    generated.into()
182}