1mod loader;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::parse::{Parse, ParseStream};
8use syn::{Expr, Ident, Token, parse_macro_input};
9use vespertide_query::{DatabaseBackend, build_plan_queries};
10
11struct MacroInput {
12 pool: Expr,
13 version_table: Option<String>,
14}
15
16impl Parse for MacroInput {
17 fn parse(input: ParseStream) -> syn::Result<Self> {
18 let pool = input.parse()?;
19 let mut version_table = None;
20
21 while !input.is_empty() {
22 input.parse::<Token![,]>()?;
23 if input.is_empty() {
24 break;
25 }
26
27 let key: Ident = input.parse()?;
28 if key == "version_table" {
29 input.parse::<Token![=]>()?;
30 let value: syn::LitStr = input.parse()?;
31 version_table = Some(value.value());
32 } else {
33 return Err(syn::Error::new(
34 key.span(),
35 "unsupported option for vespertide_migration!",
36 ));
37 }
38 }
39
40 Ok(MacroInput {
41 pool,
42 version_table,
43 })
44 }
45}
46
47#[proc_macro]
49pub fn vespertide_migration(input: TokenStream) -> TokenStream {
50 let input = parse_macro_input!(input as MacroInput);
51 let pool = &input.pool;
52 let version_table = input
53 .version_table
54 .unwrap_or_else(|| "vespertide_version".to_string());
55
56 let migrations = match loader::load_migrations_at_compile_time() {
58 Ok(migrations) => migrations,
59 Err(e) => {
60 return syn::Error::new(
61 proc_macro2::Span::call_site(),
62 format!("Failed to load migrations at compile time: {}", e),
63 )
64 .to_compile_error()
65 .into();
66 }
67 };
68
69 let mut migration_blocks = Vec::new();
71 for migration in &migrations {
72 let version = migration.version;
73 let queries = match build_plan_queries(migration) {
74 Ok(queries) => queries,
75 Err(e) => {
76 return syn::Error::new(
77 proc_macro2::Span::call_site(),
78 format!(
79 "Failed to build queries for migration version {}: {}",
80 version, e
81 ),
82 )
83 .to_compile_error()
84 .into();
85 }
86 };
87
88 let sql_statements: Vec<_> = queries
90 .iter()
91 .map(|q| {
92 let pg_sql = q.build(DatabaseBackend::Postgres);
93 let mysql_sql = q.build(DatabaseBackend::MySql);
94 let sqlite_sql = q.build(DatabaseBackend::Sqlite);
95 quote! {
96 match backend {
97 sea_orm::DatabaseBackend::Postgres => #pg_sql,
98 sea_orm::DatabaseBackend::MySql => #mysql_sql,
99 sea_orm::DatabaseBackend::Sqlite => #sqlite_sql,
100 _ => #pg_sql, }
102 }
103 })
104 .collect();
105
106 let block = quote! {
108 if version < #version {
109 let txn = __pool.begin().await.map_err(|e| {
111 ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
112 })?;
113
114 #(
116 {
117 let sql: &str = #sql_statements;
118 let stmt = sea_orm::Statement::from_string(backend, sql);
119 txn.execute_raw(stmt).await.map_err(|e| {
120 ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
121 })?;
122 }
123 )*
124
125 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
127 let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
128 let stmt = sea_orm::Statement::from_string(backend, insert_sql);
129 txn.execute_raw(stmt).await.map_err(|e| {
130 ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
131 })?;
132
133 txn.commit().await.map_err(|e| {
135 ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
136 })?;
137 }
138 };
139
140 migration_blocks.push(block);
141 }
142
143 let generated = quote! {
145 async {
146 use sea_orm::{ConnectionTrait, TransactionTrait};
147 let __pool = #pool;
148 let version_table = #version_table;
149 let backend = __pool.get_database_backend();
150
151 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
154 let create_table_sql = format!(
155 "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
156 version_table
157 );
158 let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
159 __pool.execute_raw(stmt).await.map_err(|e| {
160 ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
161 })?;
162
163 let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
165 let stmt = sea_orm::Statement::from_string(backend, select_sql);
166 let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
167 ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
168 })?;
169
170 let mut version = version_result
171 .and_then(|row| row.try_get::<i32>("", "version").ok())
172 .unwrap_or(0) as u32;
173
174 #(#migration_blocks)*
176
177 Ok::<(), ::vespertide::MigrationError>(())
178 }
179 };
180
181 generated.into()
182}