a2x_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields};
4
5/// Derive macro for the Spanned trait.
6///
7/// This macro automatically implements the `Spanned` trait by:
8/// - For structs: looking for a field of type `Option<SrcLoc>`
9/// - For enums: delegating to each variant's `span()` method (assumes each variant implements Spanned)
10#[proc_macro_derive(Spanned)]
11pub fn derive_spanned(input: TokenStream) -> TokenStream {
12    let input = parse_macro_input!(input as DeriveInput);
13    let name = &input.ident;
14
15    let expanded = match &input.data {
16        Data::Struct(data) => {
17            let field_name = match &data.fields {
18                Fields::Named(fields) => {
19                    fields.named.iter()
20                        .find(|f| is_option_srcloc(&f.ty))
21                        .and_then(|f| f.ident.as_ref())
22                        .expect("Spanned requires a field of type Option<SrcLoc>")
23                    }
24                _ => panic!("Spanned can only be derived for structs with named fields"),
25            };
26
27            quote! {
28                impl Spanned for #name {
29                    fn span(&self) -> Option<&SrcLoc> {
30                        self.#field_name.as_ref()
31                    }
32                }
33            }
34        }
35        Data::Enum(data) => {
36            let match_arms = data.variants.iter().map(|variant| {
37                let variant_name = &variant.ident;
38                match &variant.fields {
39                    Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
40                        // Single tuple field - delegate to it
41                        quote! {
42                            #name::#variant_name(inner) => inner.span()
43                        }
44                    }
45                    Fields::Unnamed(fields) if fields.unnamed.len() > 1 => {
46                        // Multiple tuple fields - delegate to the first one
47                        quote! {
48                            #name::#variant_name(inner, ..) => inner.span()
49                        }
50                    }
51                    Fields::Named(fields) => {
52                         let field_names: Vec<_> = fields.named.iter()
53                            .filter_map(|f| f.ident.as_ref())
54                            .collect();
55
56                        if let Some(src_loc_field) = fields.named.iter()
57                            .find(|f| is_option_srcloc(&f.ty))
58                            .and_then(|f| f.ident.as_ref())
59                            {
60                            // Has a SrcLoc field
61                            quote! {
62                                #name::#variant_name { ref #src_loc_field, .. } => #src_loc_field.as_ref()
63                            }
64                        } else if let Some(first_field) = field_names.first() {
65                            // Delegate to first field
66                            quote! {
67                                #name::#variant_name { #first_field, .. } => #first_field.span()
68                            }
69                        } else {
70                            panic!("Enum variant {} has no fields", variant_name);
71                        }
72                    }
73                    Fields::Unit => {
74                        panic!("Unit enum variants cannot implement Spanned")
75                    }
76                    &Fields::Unnamed(_) => todo!()
77                }
78            });
79
80            quote! {
81                impl Spanned for #name {
82                    fn span(&self) -> Option<&SrcLoc> {
83                        match self {
84                            #(#match_arms,)*
85                        }
86                    }
87                }
88            }
89        }
90        Data::Union(_) => panic!("Spanned cannot be derived for unions"),
91    };
92
93    TokenStream::from(expanded)
94}
95
96fn is_option_srcloc(ty: &syn::Type) -> bool {
97    if let syn::Type::Path(type_path) = ty {
98        if let Some(segment) = type_path.path.segments.last() {
99            if segment.ident == "Option" {
100                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
101                    if let Some(syn::GenericArgument::Type(syn::Type::Path(inner_path))) = args.args.first() {
102                        return inner_path.path.segments.last()
103                            .map(|seg| seg.ident == "SrcLoc")
104                            .unwrap_or(false);
105                    }
106                }
107            }
108        }
109    }
110    false
111}