// anywrap_macro/lib.rs

1#![allow(warnings)]
2use proc_macro::TokenStream;
3use proc_macro2::Ident;
4use quote::quote;
5use syn::{parse_macro_input, Attribute, Data, DeriveInput, Expr, ExprLit, Fields, FieldsNamed, ItemEnum, Lit, Meta, Token};
6use syn::punctuated::Punctuated;
7
8#[proc_macro_attribute]
9pub fn anywrap(_attr: TokenStream, item: TokenStream) -> TokenStream {
10    let mut input_enum = parse_macro_input!(item as ItemEnum);
11
12    if let Err(err) = add_attr_impl(&mut input_enum) {
13        return err.to_compile_error().into();
14    }
15
16    let std_error_ts = from_std_error_impl(&input_enum);
17    let chain_ts = chain_impl(&input_enum);
18    let context_ts = context_impl(&input_enum);
19    let wrap_ts = wrap_impl(&input_enum);
20
21    let output = quote! {
22        #input_enum
23        #std_error_ts
24        #chain_ts
25        #context_ts
26        #wrap_ts
27    };
28
29    output.into()
30}
31
32// --- 增加属性 逻辑 ---
33fn add_attr_impl(input_enum: &mut ItemEnum) -> Result<(), syn::Error> {
34    let enum_ident = &input_enum.ident;
35
36    let extra_fields = quote! {
37        location: anywrap::location::Location,
38        chain: Option<Box<#enum_ident>>
39    };
40
41    for variant in &mut input_enum.variants {
42        if let Fields::Named(fields_named) = &mut variant.fields {
43            let parsed: FieldsNamed = syn::parse2(quote!({ #extra_fields }))?;
44            fields_named.named.extend(parsed.named);
45        } else {
46            return Err(syn::Error::new_spanned(
47                variant,
48                "Only struct-like enum variants are supported",
49            ));
50        }
51    }
52
53    let extra_variants: ItemEnum = syn::parse2(quote! {
54        enum Dummy {
55            #[anywrap_attr(display = "{msg}")]
56            Context {
57                msg: String,
58                #extra_fields
59            },
60            #[anywrap_attr(display = "{source}")]
61            Any {
62                source: Box<dyn std::error::Error + Send + Sync + 'static>,
63                #extra_fields
64            }
65        }
66    })?;
67
68    input_enum.variants.extend(extra_variants.variants);
69
70    Ok(())
71}
72
73// --- enrich_with_chain 逻辑 ---
74fn chain_impl(input_enum: &ItemEnum) -> proc_macro2::TokenStream {
75    let enum_ident = &input_enum.ident;
76    let mut match_arms = Vec::new();
77
78    for variant in &input_enum.variants {
79        let ident = &variant.ident;
80
81        // 只匹配具名字段
82        if let Fields::Named(ref fields_named) = variant.fields {
83            let has_chain = fields_named.named.iter().any(|f| {
84                f.ident.as_ref().map(|i| i == "chain").unwrap_or(false)
85            });
86
87            if has_chain {
88                match_arms.push(quote! {
89                    #enum_ident::#ident { chain, .. } => {
90                        if let Some(chained) = chain {
91                            current = chained;
92                        } else {
93                            *chain = Some(Box::new(next));
94                            break;
95                        }
96                    }
97                });
98            }
99        }
100    }
101
102    quote! {
103        impl #enum_ident {
104            pub fn push_chain(mut self, next: Self) -> Self {
105                let mut current = &mut self;
106                loop {
107                    match current {
108                        #(#match_arms),*
109                        _ => break,
110                    }
111                }
112                self
113            }
114        }
115    }
116}
117
118// --- 标准Error的From实现 逻辑 ---
119fn from_std_error_impl(input_enum: &ItemEnum) -> proc_macro2::TokenStream {
120    let enum_ident = &input_enum.ident;
121    quote! {
122        impl<E> From<E> for #enum_ident
123        where
124            E: core::error::Error + Send + Sync + 'static,
125        {
126            #[track_caller]
127            fn from(e: E) -> Self {
128                #enum_ident::Any {
129                    source: Box::new(e),
130                    location: anywrap::location::Location::default(),
131                    chain: None,
132                }
133            }
134        }
135    }
136}
137
138// --- Context 逻辑 ---
139fn context_impl(input_enum: &ItemEnum) -> proc_macro2::TokenStream {
140    let enum_ident = &input_enum.ident;
141    quote! {
142        pub trait Context<T, E> {
143            fn context<M>(self, msg: M) -> std::result::Result<T, #enum_ident>
144            where
145                M: std::fmt::Display + Send + Sync + 'static;
146        }
147
148        impl<T, E> Context<T, E> for std::result::Result<T, E>
149        where
150            E: core::error::Error + Send + Sync + 'static,
151        {
152            #[track_caller]
153            fn context<M>(self, msg: M) -> std::result::Result<T, #enum_ident>
154            where
155                M: std::fmt::Display + Send + Sync + 'static,
156            {
157                self.map_err(|e| {
158                    let a = #enum_ident::Any {
159                        source: Box::new(e),
160                        location: anywrap::location::Location::default(),
161                        chain: None,
162                    };
163                    let m = #enum_ident::Context {
164                        msg: msg.to_string(),
165                        location: anywrap::location::Location::default(),
166                        chain: None,
167                    };
168                    a.push_chain(m)
169                })
170            }
171        }
172
173        impl<T> Context<T, #enum_ident> for std::result::Result<T, #enum_ident> {
174            #[track_caller]
175            fn context<M>(self, msg: M) -> std::result::Result<T, #enum_ident>
176            where
177                M: std::fmt::Display + Send + Sync + 'static,
178            {
179                let location = anywrap::location::Location::default();
180                self.map_err(|e| {
181                    let m = #enum_ident::Context {
182                        msg: msg.to_string(),
183                        location: location,
184                        chain: None,
185                    };
186                    e.push_chain(m)
187                })
188            }
189        }
190    }
191}
192
193// --- wrap 逻辑 ---
194fn wrap_impl(input_enum: &ItemEnum) -> proc_macro2::TokenStream {
195    let enum_name = &input_enum.ident;
196    let mut impls = vec![];
197
198    for variant in &input_enum.variants {
199        let variant_name = &variant.ident;
200
201        if let Fields::Named(fields_named) = &variant.fields {
202            let mut error_field_type = None;
203            let mut field_assignments = vec![];
204
205            for field in &fields_named.named {
206                let ident = field.ident.as_ref().unwrap();
207                if ident == "source" {
208                    error_field_type = Some(&field.ty);
209                    field_assignments.push(quote! { #ident: e });
210                } else if ident == "location" {
211                    field_assignments.push(quote! { #ident: location });
212                } else {
213                    field_assignments.push(quote! { #ident: Default::default() });
214                }
215            }
216
217            if let Some(error_ty) = error_field_type {
218                impls.push(quote! {
219                    impl<T> Wrap<T> for std::result::Result<T, #error_ty> {
220                        #[track_caller]
221                        fn wrap(self) -> std::result::Result<T, #enum_name> {
222                            let location = anywrap::location::Location::default();
223                            self.map_err(|e| {
224                                #enum_name::#variant_name {
225                                    #(#field_assignments),*
226                                }
227                            })
228                        }
229                    }
230                });
231            }
232        }
233    }
234
235    quote! {
236        pub trait Wrap<T> {
237            fn wrap(self) -> std::result::Result<T, #enum_name>;
238        }
239
240        #(#impls)*
241    }
242}
243
244#[proc_macro_derive(AnyWrap, attributes(anywrap_attr))]
245pub fn derive_anywrap(item: TokenStream) -> TokenStream {
246    let input = parse_macro_input!(item as DeriveInput);
247
248    let enum_ident = &input.ident;
249    let Data::Enum(data_enum) = &input.data else {
250        return syn::Error::new_spanned(&input, "only enums are supported")
251            .to_compile_error()
252            .into();
253    };
254
255    let mut match_lines = Vec::new();
256    let mut chain_lines = Vec::new();
257    let mut chain_arms = Vec::new();
258    let mut from_impls = Vec::new();
259
260    for variant in &data_enum.variants {
261        let variant_ident = &variant.ident;
262
263        let Fields::Named(fields_named) = &variant.fields else {
264            return syn::Error::new_spanned(variant, "only named fields are supported")
265                .to_compile_error()
266                .into();
267        };
268
269        let field_idents: Vec<&Ident> = fields_named
270            .named
271            .iter()
272            .filter_map(|f| f.ident.as_ref())
273            .collect();
274
275        // 解析属性
276        let mut display_format = None;
277        let mut from_field = None;
278
279        for attr in &variant.attrs {
280            if let Some(expr) = get_attr_value(attr, "anywrap_attr", "display") {
281                if let Lit::Str(lit) = &expr.lit {
282                    display_format = Some(lit.value());
283                }
284            }
285            if let Some(expr) = get_attr_value(attr, "anywrap_attr", "from") {
286                if let Lit::Str(lit) = &expr.lit {
287                    from_field = Some(lit.value());
288                }
289            }
290        }
291        let display_fmt = display_format.unwrap_or_else(|| {
292            panic!(
293                "Missing #[anywrap_attr(display = \"...\")] for variant `{}`",
294                variant_ident
295            )
296        });
297
298        if from_field.is_some() {
299            // 获取变体字段信息
300            let fields = match &variant.fields {
301                Fields::Named(fields) => &fields.named,
302                _ => panic!("枚举变体必须使用命名字段"),
303            };
304
305            if fields.len() != 1 {
306                panic!("只有单个字段的变体才能实现 From trait. 变体名: {}", variant_ident);
307            }
308
309            // 查找指定的字段
310            let field = fields.iter().find(|f| {
311                f.ident.as_ref().map(|i| i.to_string()) == from_field
312            }).expect(&format!("找不到指定的字段: {:?}", from_field));
313
314            // 获取字段名
315            let field_name = field.ident.as_ref().unwrap();
316            // 获取字段类型
317            let field_type = &field.ty;
318
319            // 生成 From 实现
320            from_impls.push(quote! {
321                impl From<#field_type> for #enum_ident {
322                    fn from(source: #field_type) -> Self {
323                        #enum_ident::#variant_ident {
324                            #field_name: source,
325                            location: Default::default(),
326                            chain: None,
327                        }
328                    }
329                }
330            });
331        }
332
333        let match_arm = quote! {
334            #enum_ident::#variant_ident { #( #field_idents, )* .. } => format!(#display_fmt),
335        };
336        match_lines.push(match_arm);
337
338        let chain_arm = quote! {
339            #enum_ident::#variant_ident { #( #field_idents, )* location, .. } => {
340                format!("{idx}: {}, at {location}", format!(#display_fmt))
341            }
342        };
343        chain_lines.push(chain_arm);
344
345        let chain_extractor = quote! {
346            #enum_ident::#variant_ident { chain, .. } => chain.as_deref(),
347        };
348        chain_arms.push(chain_extractor);
349    }
350
351    let output = quote! {
352        impl std::fmt::Display for #enum_ident {
353            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
354                writeln!(f, "{}", match self {
355                    #( #match_lines )*
356                    Error::Context { msg, .. } => format!("{msg}"),
357                    Error::Any { source, .. } => format!("{source}"),
358                })?;
359                Ok(())
360            }
361        }
362        impl std::fmt::Debug for #enum_ident {
363            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364                writeln!(f, "{}", match self {
365                    #( #match_lines )*
366                    Error::Context { msg, .. } => format!("{msg}"),
367                    Error::Any { source, .. } => format!("{source}"),
368                })?;
369
370                fn write_chain(err: &#enum_ident, f: &mut std::fmt::Formatter<'_>, idx: usize) -> std::fmt::Result {
371                    let line = match err {
372                        #( #chain_lines )*
373                        Error::Context { msg, location, .. } => format!("{idx}: {msg}, at {location}"),
374                        Error::Any { source, location, .. } => format!("{idx}: {source}, at {location}"),
375                    };
376                    writeln!(f, "{}", line)?;
377
378                    if let Some(inner) = match err {
379                        #( #chain_arms )*
380                        Error::Context { chain, .. } => chain.as_deref(),
381                        Error::Any { chain, .. } => chain.as_deref(),
382                    } {
383                        write_chain(inner, f, idx + 1)?;
384                    }
385
386                    Ok(())
387                }
388
389                write_chain(self, f, 0)
390            }
391        }
392
393        #(#from_impls)*
394    };
395
396    output.into()
397}
398
399fn get_attr_value(attr: &Attribute, attr_name: &str, key: &str) -> Option<ExprLit> {
400    if attr.path().is_ident(attr_name) {
401        if let Meta::List(meta) = &attr.meta {
402            for nested in meta.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated).unwrap() {
403                match nested {
404                    Meta::NameValue(name_value) => {
405                        if name_value.path.is_ident(key) {
406                            if let Expr::Lit(expr_lit) = &name_value.value {
407                                return Some(expr_lit.clone());
408                            }
409                        }
410                    }
411                    _ => {}
412                }
413            }
414        }
415    }
416    None
417}