locate_error_derive/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::{quote, quote_spanned};
4use syn::{
5    Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Generics, Ident, Type,
6    parse_macro_input, parse_quote, spanned::Spanned,
7};
8
9/// This macro is used to implement `From` on an enum or struct and locating
10/// where the `From` impl is called. Typically used for tracking sources of bubbling errors with `thiserror`.
11#[proc_macro_derive(Locate, attributes(locate_from))]
12pub fn locate(input: TokenStream) -> TokenStream {
13    let input = parse_macro_input!(input as DeriveInput);
14    let ident = &input.ident;
15    let generics = &input.generics;
16
17    let from_attributes: Vec<Attribute> = parse_quote!(
18        #[allow(
19            deprecated,
20            unused_qualifications,
21            clippy::elidable_lifetime_names,
22            clippy::needless_lifetimes,
23        )]
24        #[automatically_derived]
25    );
26
27    match &input.data {
28        Data::Enum(data) => process_enum(data, &from_attributes, generics, ident),
29        Data::Struct(data) => process_struct(data, &from_attributes, generics, ident),
30        _ => TokenStream::from(quote! {
31            compile_error!("Locate can only be derived for enums or structs");
32        }),
33    }
34}
35
36fn process_enum(
37    data: &DataEnum,
38    from_attributes: &[Attribute],
39    generics: &Generics,
40    ident: &Ident,
41) -> TokenStream {
42    let mut from_impls = vec![];
43    let mut n_has_locate_from = 0;
44    for variant in &data.variants {
45        let variant_name = &variant.ident;
46        let fields = &variant.fields;
47
48        match &fields {
49            Fields::Unnamed(fields) => {
50                for field in fields.unnamed.iter() {
51                    if let Some(index) = locate_from_attr_index(&field.attrs) {
52                        if fields.unnamed.len() != 2 {
53                            return TokenStream::from(quote_spanned! {
54                                variant.ident.span() => compile_error!("Locate requires enums variants with the #[locate_from] attribute to have exactly two fields, one for the source and one for the location");
55                            });
56                        }
57                        if let Some(other_field) = fields.unnamed.iter().nth((index + 1) % 2) {
58                            if !is_location_type(&other_field.ty) {
59                                return TokenStream::from(quote_spanned! {
60                                    other_field.ident.span() => compile_error!("Variants with #[locate_from] must have a field of type `locate_from::Location`");
61                                });
62                            }
63                        }
64                        n_has_locate_from += 1;
65                        if let Type::Path(path) = &field.ty {
66                            let field_type = &path.path;
67                            from_impls.push(quote! {
68                                #(#from_attributes)*
69                                impl #generics ::core::convert::From<#field_type> for #ident #generics {
70                                    #[track_caller]
71                                    fn from(value: #field_type) -> Self {
72                                        let location = ::std::panic::Location::caller();
73                                        #ident::#variant_name {
74                                            0: value,
75                                            1: ::locate_error::Location {
76                                                file: location.file().to_string(),
77                                                line: location.line(),
78                                                column: location.column(),
79                                            }
80                                        }
81                                    }
82                                }
83                            });
84                        }
85                    }
86                }
87            }
88            Fields::Named(fields) => {
89                for field in fields.named.iter() {
90                    // Field name will be present for named fields
91                    let field_name = field.ident.as_ref().unwrap();
92                    if locate_from_attr_index(&field.attrs).is_some() {
93                        let has_location_field = fields.named.iter().any(|f| {
94                            f.ident.as_ref().is_some_and(|name| name == "location")
95                                && is_location_type(&f.ty)
96                        });
97
98                        if !has_location_field {
99                            return TokenStream::from(quote_spanned! {
100                                variant.ident.span() => compile_error!("Variants with #[locate_from] must have a field named 'location' of type `locate_from::Location`");
101                            });
102                        }
103
104                        if fields.named.len() != 2 {
105                            return TokenStream::from(quote_spanned! {
106                                variant.ident.span() => compile_error!("Locate requires enums variants with the #[locate_from] attribute to have exactly two fields, one for the source and one for the location");
107                            });
108                        }
109
110                        n_has_locate_from += 1;
111                        if let Type::Path(path) = &field.ty {
112                            let field_type = &path.path;
113                            from_impls.push(quote! {
114                                #(#from_attributes)*
115                                impl #generics ::core::convert::From<#field_type> for #ident #generics {
116                                    #[track_caller]
117                                    fn from(value: #field_type) -> Self {
118                                        let location = ::std::panic::Location::caller();
119                                        #ident::#variant_name {
120                                            #field_name:value,
121                                            location: ::locate_error::Location {
122                                                file: location.file().to_string(),
123                                                line: location.line(),
124                                                column: location.column(),
125                                            }
126                                        }
127                                    }
128                                }
129                            });
130                        }
131                    }
132                }
133            }
134            Fields::Unit => {}
135        }
136    }
137
138    if n_has_locate_from == 0 {
139        return TokenStream::from(quote! {
140            compile_error!("Locate requires at least one variant with the #[locate_from] attribute (otherwise this macro is effectively a no-op)");
141        });
142    }
143
144    let expanded = quote! {
145        #(#from_impls)*
146    };
147
148    TokenStream::from(expanded)
149}
150
151fn process_struct(
152    data: &DataStruct,
153    from_attributes: &[Attribute],
154    generics: &Generics,
155    ident: &Ident,
156) -> TokenStream {
157    let mut from_impl: proc_macro2::TokenStream = quote! {};
158    // Find fields with locate_from attribute
159    let locate_from_fields: Vec<_> = data
160        .fields
161        .iter()
162        .filter(|field| locate_from_attr_index(&field.attrs).is_some())
163        .collect();
164
165    // Check if there's exactly one field with locate_from
166    if locate_from_fields.is_empty() || locate_from_fields.len() > 1 {
167        let error_message = format!(
168            "Locate requires exactly one field marked with #[locate_from], found {:?}",
169            locate_from_fields.len()
170        );
171        return TokenStream::from(quote! {
172            compile_error!(#error_message);
173        });
174    }
175
176    // There can be at most 2 fields (one with the locate_from attribute and one with the location field)
177    if data.fields.len() > 2 {
178        return TokenStream::from(quote! {
179            compile_error!("Locate requires structs to have only a 'source' field (with the #[locate_from] attribute) and a 'location' field");
180        });
181    }
182
183    // Check if there's a field named "location"
184    let has_location_field = data
185        .fields
186        .iter()
187        .any(|field| field.ident.as_ref().is_some_and(|name| name == "location"));
188
189    if !has_location_field {
190        return TokenStream::from(quote! {
191            compile_error!("Locate requires structs to have a field named 'location' of type `locate_from::Location`");
192        });
193    }
194
195    let field = locate_from_fields.first().unwrap();
196    let field_name = field.ident.as_ref().unwrap();
197
198    if let Type::Path(path) = &field.ty {
199        let field_type = &path.path;
200        from_impl = quote! {
201            #(#from_attributes)*
202            impl #generics ::core::convert::From<#field_type> for #ident #generics {
203                #[track_caller]
204                fn from(value: #field_type) -> Self {
205                    let location = ::std::panic::Location::caller();
206                    #ident {
207                        #field_name: value,
208                        location: ::locate_error::Location {
209                            file: location.file().to_string(),
210                            line: location.line(),
211                            column: location.column(),
212                        }
213                    }
214                }
215            }
216        };
217    }
218
219    TokenStream::from(from_impl)
220}
221
222fn locate_from_attr_index(attributes: &[Attribute]) -> Option<usize> {
223    attributes.iter().position(|attr| {
224        if !attr.path().is_ident("locate_from") {
225            return false;
226        }
227        // Only allow #[locate_from], not #[locate_from = "some_path"]
228        matches!(attr.meta, syn::Meta::Path(_))
229    })
230}
231
232// Helper function to check if a type is Location (may not identify full path correctly, but works in most cases)
233fn is_location_type(ty: &Type) -> bool {
234    if let Type::Path(type_path) = ty {
235        if let Some(last_segment) = type_path.path.segments.last() {
236            // Check if the last segment is "Location"
237            if last_segment.ident == "Location" {
238                // Simplistic check, verify the last segment
239                return true;
240            }
241        }
242    }
243    false
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_locate_from_attr_index() {
        // No attributes at all.
        assert_eq!(locate_from_attr_index(&[]), None);

        // A bare #[locate_from] is found at its position.
        let marker: Attribute = parse_quote!(#[locate_from]);
        assert_eq!(locate_from_attr_index(&[marker]), Some(0));

        // The name-value form is not treated as the marker.
        let name_value: Attribute = parse_quote!(#[locate_from = "some_path"]);
        assert_eq!(locate_from_attr_index(&[name_value]), None);

        // The list form is not treated as the marker either.
        let list: Attribute = parse_quote!(#[locate_from(some_path)]);
        assert_eq!(locate_from_attr_index(&[list]), None);

        // The returned index is the marker's position among other attributes.
        let derive: Attribute = parse_quote!(#[derive(Debug)]);
        let marker: Attribute = parse_quote!(#[locate_from]);
        assert_eq!(locate_from_attr_index(&[derive, marker]), Some(1));

        // Only unrelated attributes.
        let debug: Attribute = parse_quote!(#[derive(Debug)]);
        let clone: Attribute = parse_quote!(#[derive(Clone)]);
        assert_eq!(locate_from_attr_index(&[debug, clone]), None);
    }

    #[test]
    fn test_is_location_type() {
        let bare: Type = parse_quote!(Location);
        assert!(is_location_type(&bare));

        let qualified: Type = parse_quote!(::locate_error::Location);
        assert!(is_location_type(&qualified));

        let other: Type = parse_quote!(std::io::Error);
        assert!(!is_location_type(&other));

        // Only plain path types are inspected; references are rejected.
        let reference: Type = parse_quote!(&Location);
        assert!(!is_location_type(&reference));
    }
}