dtype_variant_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Ident, Token, parse_macro_input, punctuated::Punctuated};
4
5mod derive;
6mod grouped_matcher;
7mod matcher_gen;
8
9pub(crate) fn dtype_variant_path() -> syn::Path {
10    let found_crate = proc_macro_crate::crate_name("dtype_variant")
11        .expect("dtype_variant is present in `Cargo.toml`");
12    match found_crate {
13        proc_macro_crate::FoundCrate::Itself => format_ident!("crate").into(),
14        proc_macro_crate::FoundCrate::Name(name) => {
15            // Parse crate name safely - fall back to simple identifier if parsing fails
16            syn::parse_str(name.as_str()).unwrap_or_else(|_| {
17                // Create a simple path from the crate name
18                syn::Path::from(format_ident!(
19                    "{}",
20                    name.as_str().replace('-', "_")
21                ))
22            })
23        }
24    }
25}
26
27#[proc_macro_derive(DType, attributes(dtype, dtype_grouped_matcher))]
28pub fn dtype_derive(input: TokenStream) -> TokenStream {
29    derive::dtype_derive_impl(input)
30}
31
32struct DTypeInput {
33    variants: Punctuated<Ident, Token![,]>,
34}
35
36impl syn::parse::Parse for DTypeInput {
37    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
38        let content;
39        syn::bracketed!(content in input);
40        Ok(DTypeInput {
41            variants: content.parse_terminated(Ident::parse, Token![,])?,
42        })
43    }
44}
45
46#[proc_macro]
47pub fn build_dtype_tokens(input: TokenStream) -> TokenStream {
48    let DTypeInput { variants } = parse_macro_input!(input as DTypeInput);
49
50    let expanded = variants.iter().map(|variant| {
51        let variant_name = format_ident!("{}Variant", variant);
52
53        quote! {
54            #[derive(Default, Debug)]
55            pub struct #variant_name;
56        }
57    });
58
59    quote! {
60        #(#expanded)*
61    }
62    .into()
63}