locate_error_derive/
lib.rs1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::{quote, quote_spanned};
4use syn::{
5 Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Generics, Ident, Type,
6 parse_macro_input, parse_quote, spanned::Spanned,
7};
8
9#[proc_macro_derive(Locate, attributes(locate_from))]
12pub fn locate(input: TokenStream) -> TokenStream {
13 let input = parse_macro_input!(input as DeriveInput);
14 let ident = &input.ident;
15 let generics = &input.generics;
16
17 let from_attributes: Vec<Attribute> = parse_quote!(
18 #[allow(
19 deprecated,
20 unused_qualifications,
21 clippy::elidable_lifetime_names,
22 clippy::needless_lifetimes,
23 )]
24 #[automatically_derived]
25 );
26
27 match &input.data {
28 Data::Enum(data) => process_enum(data, &from_attributes, generics, ident),
29 Data::Struct(data) => process_struct(data, &from_attributes, generics, ident),
30 _ => TokenStream::from(quote! {
31 compile_error!("Locate can only be derived for enums or structs");
32 }),
33 }
34}
35
36fn process_enum(
37 data: &DataEnum,
38 from_attributes: &[Attribute],
39 generics: &Generics,
40 ident: &Ident,
41) -> TokenStream {
42 let mut from_impls = vec![];
43 let mut n_has_locate_from = 0;
44 for variant in &data.variants {
45 let variant_name = &variant.ident;
46 let fields = &variant.fields;
47
48 match &fields {
49 Fields::Unnamed(fields) => {
50 for field in fields.unnamed.iter() {
51 if let Some(index) = locate_from_attr_index(&field.attrs) {
52 if fields.unnamed.len() != 2 {
53 return TokenStream::from(quote_spanned! {
54 variant.ident.span() => compile_error!("Locate requires enums variants with the #[locate_from] attribute to have exactly two fields, one for the source and one for the location");
55 });
56 }
57 if let Some(other_field) = fields.unnamed.iter().nth((index + 1) % 2) {
58 if !is_location_type(&other_field.ty) {
59 return TokenStream::from(quote_spanned! {
60 other_field.ident.span() => compile_error!("Variants with #[locate_from] must have a field of type `locate_from::Location`");
61 });
62 }
63 }
64 n_has_locate_from += 1;
65 if let Type::Path(path) = &field.ty {
66 let field_type = &path.path;
67 from_impls.push(quote! {
68 #(#from_attributes)*
69 impl #generics ::core::convert::From<#field_type> for #ident #generics {
70 #[track_caller]
71 fn from(value: #field_type) -> Self {
72 let location = ::std::panic::Location::caller();
73 #ident::#variant_name {
74 0: value,
75 1: ::locate_error::Location {
76 file: location.file().to_string(),
77 line: location.line(),
78 column: location.column(),
79 }
80 }
81 }
82 }
83 });
84 }
85 }
86 }
87 }
88 Fields::Named(fields) => {
89 for field in fields.named.iter() {
90 let field_name = field.ident.as_ref().unwrap();
92 if locate_from_attr_index(&field.attrs).is_some() {
93 let has_location_field = fields.named.iter().any(|f| {
94 f.ident.as_ref().is_some_and(|name| name == "location")
95 && is_location_type(&f.ty)
96 });
97
98 if !has_location_field {
99 return TokenStream::from(quote_spanned! {
100 variant.ident.span() => compile_error!("Variants with #[locate_from] must have a field named 'location' of type `locate_from::Location`");
101 });
102 }
103
104 if fields.named.len() != 2 {
105 return TokenStream::from(quote_spanned! {
106 variant.ident.span() => compile_error!("Locate requires enums variants with the #[locate_from] attribute to have exactly two fields, one for the source and one for the location");
107 });
108 }
109
110 n_has_locate_from += 1;
111 if let Type::Path(path) = &field.ty {
112 let field_type = &path.path;
113 from_impls.push(quote! {
114 #(#from_attributes)*
115 impl #generics ::core::convert::From<#field_type> for #ident #generics {
116 #[track_caller]
117 fn from(value: #field_type) -> Self {
118 let location = ::std::panic::Location::caller();
119 #ident::#variant_name {
120 #field_name:value,
121 location: ::locate_error::Location {
122 file: location.file().to_string(),
123 line: location.line(),
124 column: location.column(),
125 }
126 }
127 }
128 }
129 });
130 }
131 }
132 }
133 }
134 Fields::Unit => {}
135 }
136 }
137
138 if n_has_locate_from == 0 {
139 return TokenStream::from(quote! {
140 compile_error!("Locate requires at least one variant with the #[locate_from] attribute (otherwise this macro is effectively a no-op)");
141 });
142 }
143
144 let expanded = quote! {
145 #(#from_impls)*
146 };
147
148 TokenStream::from(expanded)
149}
150
151fn process_struct(
152 data: &DataStruct,
153 from_attributes: &[Attribute],
154 generics: &Generics,
155 ident: &Ident,
156) -> TokenStream {
157 let mut from_impl: proc_macro2::TokenStream = quote! {};
158 let locate_from_fields: Vec<_> = data
160 .fields
161 .iter()
162 .filter(|field| locate_from_attr_index(&field.attrs).is_some())
163 .collect();
164
165 if locate_from_fields.is_empty() || locate_from_fields.len() > 1 {
167 let error_message = format!(
168 "Locate requires exactly one field marked with #[locate_from], found {:?}",
169 locate_from_fields.len()
170 );
171 return TokenStream::from(quote! {
172 compile_error!(#error_message);
173 });
174 }
175
176 if data.fields.len() > 2 {
178 return TokenStream::from(quote! {
179 compile_error!("Locate requires structs to have only a 'source' field (with the #[locate_from] attribute) and a 'location' field");
180 });
181 }
182
183 let has_location_field = data
185 .fields
186 .iter()
187 .any(|field| field.ident.as_ref().is_some_and(|name| name == "location"));
188
189 if !has_location_field {
190 return TokenStream::from(quote! {
191 compile_error!("Locate requires structs to have a field named 'location' of type `locate_from::Location`");
192 });
193 }
194
195 let field = locate_from_fields.first().unwrap();
196 let field_name = field.ident.as_ref().unwrap();
197
198 if let Type::Path(path) = &field.ty {
199 let field_type = &path.path;
200 from_impl = quote! {
201 #(#from_attributes)*
202 impl #generics ::core::convert::From<#field_type> for #ident #generics {
203 #[track_caller]
204 fn from(value: #field_type) -> Self {
205 let location = ::std::panic::Location::caller();
206 #ident {
207 #field_name: value,
208 location: ::locate_error::Location {
209 file: location.file().to_string(),
210 line: location.line(),
211 column: location.column(),
212 }
213 }
214 }
215 }
216 };
217 }
218
219 TokenStream::from(from_impl)
220}
221
222fn locate_from_attr_index(attributes: &[Attribute]) -> Option<usize> {
223 attributes.iter().position(|attr| {
224 if !attr.path().is_ident("locate_from") {
225 return false;
226 }
227 matches!(attr.meta, syn::Meta::Path(_))
229 })
230}
231
232fn is_location_type(ty: &Type) -> bool {
234 if let Type::Path(type_path) = ty {
235 if let Some(last_segment) = type_path.path.segments.last() {
236 if last_segment.ident == "Location" {
238 return true;
240 }
241 }
242 }
243 false
244}
245
#[cfg(test)]
mod tests {
    // `Attribute`, `Type`, `parse_quote` and the helpers under test all come from
    // the parent module's imports/items.
    use super::*;

    #[test]
    fn test_locate_from_attr_index() {
        // No attributes at all.
        let attributes: Vec<Attribute> = Vec::new();
        assert_eq!(locate_from_attr_index(&attributes), None);

        // A bare `#[locate_from]` is reported at its position.
        let locate_from_attr: Attribute = parse_quote!(#[locate_from]);
        let attributes = vec![locate_from_attr];
        assert_eq!(locate_from_attr_index(&attributes), Some(0));

        // The name-value form is not recognised.
        let locate_from_attr: Attribute = parse_quote!(#[locate_from = "some_path"]);
        let attributes = vec![locate_from_attr];
        assert_eq!(locate_from_attr_index(&attributes), None);

        // The list form is not recognised either.
        let locate_from_attr: Attribute = parse_quote!(#[locate_from(some_path)]);
        let attributes = vec![locate_from_attr];
        assert_eq!(locate_from_attr_index(&attributes), None);

        // The result is the attribute's position, not just "found".
        let other_attr: Attribute = parse_quote!(#[derive(Debug)]);
        let locate_from_attr: Attribute = parse_quote!(#[locate_from]);
        let attributes = vec![other_attr, locate_from_attr];
        assert_eq!(locate_from_attr_index(&attributes), Some(1));

        // Unrelated attributes only.
        let attr1: Attribute = parse_quote!(#[derive(Debug)]);
        let attr2: Attribute = parse_quote!(#[derive(Clone)]);
        let attributes = vec![attr1, attr2];
        assert_eq!(locate_from_attr_index(&attributes), None);
    }

    #[test]
    fn test_is_location_type() {
        let bare: Type = parse_quote!(Location);
        let qualified: Type = parse_quote!(::locate_error::Location);
        let other_path: Type = parse_quote!(std::io::Error);
        let reference: Type = parse_quote!(&'static Location);

        assert!(is_location_type(&bare));
        assert!(is_location_type(&qualified));
        assert!(!is_location_type(&other_path));
        // Only plain path types match; a reference to `Location` does not.
        assert!(!is_location_type(&reference));
    }
}