vespertide_macro/
lib.rs

1// MigrationOptions and MigrationError are now in vespertide-core
2
3use proc_macro::TokenStream;
4use quote::quote;
5use std::env;
6use std::fs;
7use std::path::PathBuf;
8use syn::parse::{Parse, ParseStream};
9use syn::{Expr, Ident, Token, parse_macro_input};
10use vespertide_config::VespertideConfig;
11use vespertide_core::MigrationPlan;
12use vespertide_query::build_plan_queries;
13
14struct MacroInput {
15    pool: Expr,
16    version_table: Option<String>,
17}
18
19impl Parse for MacroInput {
20    fn parse(input: ParseStream) -> syn::Result<Self> {
21        let pool = input.parse()?;
22        let mut version_table = None;
23
24        while !input.is_empty() {
25            input.parse::<Token![,]>()?;
26            if input.is_empty() {
27                break;
28            }
29
30            let key: Ident = input.parse()?;
31            if key == "version_table" {
32                input.parse::<Token![=]>()?;
33                let value: syn::LitStr = input.parse()?;
34                version_table = Some(value.value());
35            } else {
36                return Err(syn::Error::new(
37                    key.span(),
38                    "unsupported option for vespertide_migration!",
39                ));
40            }
41        }
42
43        Ok(MacroInput {
44            pool,
45            version_table,
46        })
47    }
48}
49
50/// Zero-runtime migration entry point.
51#[proc_macro]
52pub fn vespertide_migration(input: TokenStream) -> TokenStream {
53    let input = parse_macro_input!(input as MacroInput);
54    let pool = &input.pool;
55    let version_table = input
56        .version_table
57        .unwrap_or_else(|| "vespertide_version".to_string());
58
59    // Load migration files and build SQL at compile time
60    let migrations = match load_migrations_at_compile_time() {
61        Ok(migrations) => migrations,
62        Err(e) => {
63            return syn::Error::new(
64                proc_macro2::Span::call_site(),
65                format!("Failed to load migrations at compile time: {}", e),
66            )
67            .to_compile_error()
68            .into();
69        }
70    };
71
72    // Build SQL for each migration
73    let mut migration_blocks = Vec::new();
74    for migration in &migrations {
75        let version = migration.version;
76        let queries = match build_plan_queries(migration) {
77            Ok(queries) => queries,
78            Err(e) => {
79                return syn::Error::new(
80                    proc_macro2::Span::call_site(),
81                    format!(
82                        "Failed to build queries for migration version {}: {}",
83                        version, e
84                    ),
85                )
86                .to_compile_error()
87                .into();
88            }
89        };
90
91        // Statically embed SQL text and bind parameters (as values)
92        let sql_statements: Vec<_> = queries
93            .iter()
94            .map(|q| {
95                let sql = &q.sql;
96                let binds = &q.binds;
97                let value_tokens = binds.iter().map(|b| {
98                    quote! { sea_orm::Value::String(Some(#b.to_string())) }
99                });
100                quote! { (#sql, vec![#(#value_tokens),*]) }
101            })
102            .collect();
103
104        // Generate version guard and SQL execution block
105        let block = quote! {
106            if version < #version {
107                // Begin transaction
108                let txn = __pool.begin().await.map_err(|e| {
109                    ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
110                })?;
111
112                // Execute SQL statements
113                #(
114                    {
115                        let (sql, values) = #sql_statements;
116                        let stmt = sea_orm::Statement::from_sql_and_values(backend, sql, values);
117                        txn.execute_raw(stmt).await.map_err(|e| {
118                            ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL: {}", e))
119                        })?;
120                    }
121                )*
122
123                // Insert version record for this migration
124                let stmt = sea_orm::Statement::from_sql_and_values(
125                    backend,
126                    &format!("INSERT INTO {} (version) VALUES (?)", version_table),
127                    vec![sea_orm::Value::Int(Some(#version as i32))],
128                );
129                txn.execute_raw(stmt).await.map_err(|e| {
130                    ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
131                })?;
132
133                // Commit transaction
134                txn.commit().await.map_err(|e| {
135                    ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
136                })?;
137            }
138        };
139
140        migration_blocks.push(block);
141    }
142
143    // Emit final generated async block
144    let generated = quote! {
145        async {
146            use sea_orm::{ConnectionTrait, TransactionTrait};
147            let __pool = #pool;
148            let version_table = #version_table;
149            let backend = __pool.get_database_backend();
150
151            // Create version table if it does not exist
152            // Table structure: version (INTEGER PRIMARY KEY), created_at (timestamp)
153            let create_table_sql = format!(
154                "CREATE TABLE IF NOT EXISTS {} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
155                version_table
156            );
157            let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
158            __pool.execute_raw(stmt).await.map_err(|e| {
159                ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
160            })?;
161
162            // Read current maximum version (latest applied migration)
163            let stmt = sea_orm::Statement::from_string(
164                backend,
165                format!("SELECT MAX(version) as version FROM {}", version_table),
166            );
167            let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
168                ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
169            })?;
170
171            let mut version = version_result
172                .and_then(|row| row.try_get::<i32>("", "version").ok())
173                .unwrap_or(0) as u32;
174
175            // Execute each migration block
176            #(#migration_blocks)*
177
178            Ok::<(), ::vespertide::MigrationError>(())
179        }
180    };
181
182    generated.into()
183}
184
185fn load_migrations_at_compile_time() -> Result<Vec<MigrationPlan>, Box<dyn std::error::Error>> {
186    // Locate project root from CARGO_MANIFEST_DIR
187    let manifest_dir = env::var("CARGO_MANIFEST_DIR")
188        .map_err(|_| "CARGO_MANIFEST_DIR environment variable not set")?;
189    let project_root = PathBuf::from(manifest_dir);
190
191    // Read vespertide.json
192    let config_path = project_root.join("vespertide.json");
193    let config: VespertideConfig = if config_path.exists() {
194        let content = fs::read_to_string(&config_path)?;
195        serde_json::from_str(&content)?
196    } else {
197        // Fall back to defaults if config is missing
198        VespertideConfig::default()
199    };
200
201    // Read migrations directory
202    let migrations_dir = project_root.join(config.migrations_dir());
203    if !migrations_dir.exists() {
204        return Ok(Vec::new());
205    }
206
207    let mut plans = Vec::new();
208    let entries = fs::read_dir(&migrations_dir)?;
209
210    for entry in entries {
211        let entry = entry?;
212        let path = entry.path();
213        if path.is_file() {
214            let ext = path.extension().and_then(|s| s.to_str());
215            if ext == Some("json") || ext == Some("yaml") || ext == Some("yml") {
216                let content = fs::read_to_string(&path)?;
217
218                let plan: MigrationPlan = if ext == Some("json") {
219                    serde_json::from_str(&content)?
220                } else {
221                    serde_yaml::from_str(&content)?
222                };
223
224                plans.push(plan);
225            }
226        }
227    }
228
229    // Sort by version
230    plans.sort_by_key(|p| p.version);
231    Ok(plans)
232}