Skip to main content

dibs_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use unsynn::{LiteralString, Parse, ToTokens, TokenIter};
4
5/// Register a migration function.
6///
7/// The version is automatically derived from the filename. For example,
8/// a file named `m_2026_01_18_173711_create_users.rs` will have version
9/// `2026_01_18_173711-create_users`.
10///
11/// # Example
12///
13/// ```ignore
14/// // In file: src/migrations/m_2026_01_18_create_users.rs
15/// #[dibs::migration]
16/// async fn migrate(ctx: &mut MigrationContext) -> MigrationResult<()> {
17///     ctx.execute("CREATE TABLE users (...)").await?;
18///     Ok(())
19/// }
20/// ```
21///
22/// Use `MigrationResult` instead of `Result` to enable `#[track_caller]` -
23/// when an error occurs, the exact source location (file:line:column) is captured.
24#[proc_macro_attribute]
25pub fn migration(attr: TokenStream, item: TokenStream) -> TokenStream {
26    // Convert to proc_macro2 and create unsynn TokenIter
27    let attr2: proc_macro2::TokenStream = attr.into();
28    let mut tokens = TokenIter::new(attr2);
29
30    // Version is optional - if not provided, it will be derived from filename
31    let explicit_version = LiteralString::parse(&mut tokens).ok();
32
33    let item: proc_macro2::TokenStream = item.into();
34
35    // Extract function name from the item
36    let item_str = item.to_string();
37    let fn_name = match extract_fn_name(&item_str) {
38        Some(name) => name,
39        None => {
40            return quote! { compile_error!("expected function"); }.into();
41        }
42    };
43    let fn_ident = quote::format_ident!("{}", fn_name);
44
45    let version_expr = if let Some(version) = explicit_version {
46        let version_lit = version.to_token_stream();
47        quote! { #version_lit }
48    } else {
49        // Derive version from filename at compile time
50        // file!() returns something like "src/migrations/m_2026_01_18_173711_create_users.rs"
51        // We extract the filename, strip the .rs and leading m_, then format as version
52        quote! {
53            {
54                const FILE: &str = file!();
55                // Extract just the filename
56                const fn find_last_slash(s: &[u8]) -> usize {
57                    let mut i = s.len();
58                    while i > 0 {
59                        i -= 1;
60                        if s[i] == b'/' || s[i] == b'\\' {
61                            return i + 1;
62                        }
63                    }
64                    0
65                }
66                const SLASH_POS: usize = find_last_slash(FILE.as_bytes());
67                const FILENAME: &str = unsafe {
68                    // SAFETY: SLASH_POS is always a valid index
69                    std::str::from_utf8_unchecked(FILE.as_bytes().split_at(SLASH_POS).1)
70                };
71                // Strip .rs extension and leading m_
72                ::dibs::__derive_migration_version(FILENAME)
73            }
74        }
75    };
76
77    quote! {
78        #item
79
80        ::dibs::inventory::submit! {
81            ::dibs::Migration {
82                version: #version_expr,
83                name: stringify!(#fn_ident),
84                run: |ctx| Box::pin(#fn_ident(ctx)),
85                source_file: (env!("CARGO_MANIFEST_DIR"), file!()),
86            }
87        }
88    }
89    .into()
90}
91
/// Returns the name of the first function declared in a stringified token stream.
///
/// `s` is the `to_string()` rendering of the annotated item, e.g.
/// `async fn migrate (ctx : & mut MigrationContext) -> ...`.
///
/// `fn` is only recognised as a standalone word outside string literals. This means a doc
/// comment (rendered as `# [doc = " ... this fn ... "]`) or an identifier such as `my_fn`
/// cannot be mistaken for the keyword.
///
/// Returns `None` if no `fn` keyword is found, or if it is not followed by an identifier.
fn extract_fn_name(s: &str) -> Option<&str> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let mut in_string = false;
    let mut escaped = false;
    let mut prev: Option<char> = None;

    for (idx, c) in s.char_indices() {
        if in_string {
            // Skip the contents of "..." literals, honouring backslash escapes.
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
        } else if c == '"' {
            in_string = true;
        } else if !prev.is_some_and(is_ident_char) && s[idx..].starts_with("fn") {
            let after = &s[idx + 2..];
            // `fn` must be a whole word: followed by whitespace, not e.g. `fns` / `fn_x`.
            if after.starts_with(char::is_whitespace) {
                let rest = after.trim_start();
                let end = rest
                    .find(|ch: char| !is_ident_char(ch))
                    .unwrap_or(rest.len());
                // An empty name (e.g. `fn (`) is not a valid identifier.
                return (end > 0).then(|| &rest[..end]);
            }
        }
        prev = Some(c);
    }
    None
}