dibs_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use unsynn::{LiteralString, Parse, ToTokens, TokenIter};
4
5/// Register a migration function.
6///
7/// # Example
8///
9/// ```ignore
10/// #[dibs::migration("2026-01-17-create-users")]
11/// async fn create_users(ctx: &mut MigrationContext) -> Result<()> {
12///     ctx.execute("CREATE TABLE users (...)").await?;
13///     Ok(())
14/// }
15/// ```
16#[proc_macro_attribute]
17pub fn migration(attr: TokenStream, item: TokenStream) -> TokenStream {
18    // Convert to proc_macro2 and create unsynn TokenIter
19    let attr2: proc_macro2::TokenStream = attr.into();
20    let mut tokens = TokenIter::new(attr2);
21
22    let version = match LiteralString::parse(&mut tokens) {
23        Ok(v) => v,
24        Err(e) => {
25            let msg = format!("expected string literal for migration version: {e}");
26            return quote! { compile_error!(#msg); }.into();
27        }
28    };
29
30    let version_lit = version.to_token_stream();
31
32    let item: proc_macro2::TokenStream = item.into();
33
34    // Extract function name from the item
35    let item_str = item.to_string();
36    let fn_name = match extract_fn_name(&item_str) {
37        Some(name) => name,
38        None => {
39            return quote! { compile_error!("expected function"); }.into();
40        }
41    };
42    let fn_ident = quote::format_ident!("{}", fn_name);
43    let registration_ident = quote::format_ident!(
44        "__DIBS_MIGRATION_{}",
45        fn_name.to_uppercase().replace('-', "_")
46    );
47
48    quote! {
49        #item
50
51        #[allow(non_upper_case_globals)]
52        #[::dibs::inventory::collect]
53        static #registration_ident: ::dibs::Migration = ::dibs::Migration {
54            version: #version_lit,
55            name: stringify!(#fn_ident),
56            run: |ctx| Box::pin(#fn_ident(ctx)),
57        };
58    }
59    .into()
60}
61
62fn extract_fn_name(s: &str) -> Option<&str> {
63    // Simple extraction: find "fn " and take the next identifier
64    let idx = s.find("fn ")?;
65    let rest = &s[idx + 3..].trim_start();
66    let end = rest.find(|c: char| !c.is_alphanumeric() && c != '_')?;
67    Some(&rest[..end])
68}