1#![allow(warnings)]
2use proc_macro::TokenStream;
3use proc_macro2::Ident;
4use quote::quote;
5use syn::{parse_macro_input, Attribute, Data, DeriveInput, Expr, ExprLit, Fields, FieldsNamed, ItemEnum, Lit, Meta, Token};
6use syn::punctuated::Punctuated;
7
8#[proc_macro_attribute]
9pub fn anywrap(_attr: TokenStream, item: TokenStream) -> TokenStream {
10 let mut input_enum = parse_macro_input!(item as ItemEnum);
11
12 if let Err(err) = add_attr_impl(&mut input_enum) {
13 return err.to_compile_error().into();
14 }
15
16 let std_error_ts = from_std_error_impl(&input_enum);
17 let chain_ts = chain_impl(&input_enum);
18 let context_ts = context_impl(&input_enum);
19 let wrap_ts = wrap_impl(&input_enum);
20
21 let output = quote! {
22 #input_enum
23 #std_error_ts
24 #chain_ts
25 #context_ts
26 #wrap_ts
27 };
28
29 output.into()
30}
31
32fn add_attr_impl(input_enum: &mut ItemEnum) -> Result<(), syn::Error> {
34 let enum_ident = &input_enum.ident;
35
36 let extra_fields = quote! {
37 location: anywrap::location::Location,
38 chain: Option<Box<#enum_ident>>
39 };
40
41 for variant in &mut input_enum.variants {
42 if let Fields::Named(fields_named) = &mut variant.fields {
43 let parsed: FieldsNamed = syn::parse2(quote!({ #extra_fields }))?;
44 fields_named.named.extend(parsed.named);
45 } else {
46 return Err(syn::Error::new_spanned(
47 variant,
48 "Only struct-like enum variants are supported",
49 ));
50 }
51 }
52
53 let extra_variants: ItemEnum = syn::parse2(quote! {
54 enum Dummy {
55 #[anywrap_attr(display = "{msg}")]
56 Context {
57 msg: String,
58 #extra_fields
59 },
60 #[anywrap_attr(display = "{source}")]
61 Any {
62 source: Box<dyn std::error::Error + Send + Sync + 'static>,
63 #extra_fields
64 }
65 }
66 })?;
67
68 input_enum.variants.extend(extra_variants.variants);
69
70 Ok(())
71}
72
73fn chain_impl(input_enum: &ItemEnum) -> proc_macro2::TokenStream {
75 let enum_ident = &input_enum.ident;
76 let mut match_arms = Vec::new();
77
78 for variant in &input_enum.variants {
79 let ident = &variant.ident;
80
81 if let Fields::Named(ref fields_named) = variant.fields {
83 let has_chain = fields_named.named.iter().any(|f| {
84 f.ident.as_ref().map(|i| i == "chain").unwrap_or(false)
85 });
86
87 if has_chain {
88 match_arms.push(quote! {
89 #enum_ident::#ident { chain, .. } => {
90 if let Some(chained) = chain {
91 current = chained;
92 } else {
93 *chain = Some(Box::new(next));
94 break;
95 }
96 }
97 });
98 }
99 }
100 }
101
102 quote! {
103 impl #enum_ident {
104 pub fn push_chain(mut self, next: Self) -> Self {
105 let mut current = &mut self;
106 loop {
107 match current {
108 #(#match_arms),*
109 _ => break,
110 }
111 }
112 self
113 }
114 }
115 }
116}
117
118fn from_std_error_impl(input_enum: &ItemEnum) -> proc_macro2::TokenStream {
120 let enum_ident = &input_enum.ident;
121 quote! {
122 impl<E> From<E> for #enum_ident
123 where
124 E: core::error::Error + Send + Sync + 'static,
125 {
126 #[track_caller]
127 fn from(e: E) -> Self {
128 #enum_ident::Any {
129 source: Box::new(e),
130 location: anywrap::location::Location::default(),
131 chain: None,
132 }
133 }
134 }
135 }
136}
137
138fn context_impl(input_enum: &ItemEnum) -> proc_macro2::TokenStream {
140 let enum_ident = &input_enum.ident;
141 quote! {
142 pub trait Context<T, E> {
143 fn context<M>(self, msg: M) -> std::result::Result<T, #enum_ident>
144 where
145 M: std::fmt::Display + Send + Sync + 'static;
146 }
147
148 impl<T, E> Context<T, E> for std::result::Result<T, E>
149 where
150 E: core::error::Error + Send + Sync + 'static,
151 {
152 #[track_caller]
153 fn context<M>(self, msg: M) -> std::result::Result<T, #enum_ident>
154 where
155 M: std::fmt::Display + Send + Sync + 'static,
156 {
157 self.map_err(|e| {
158 let a = #enum_ident::Any {
159 source: Box::new(e),
160 location: anywrap::location::Location::default(),
161 chain: None,
162 };
163 let m = #enum_ident::Context {
164 msg: msg.to_string(),
165 location: anywrap::location::Location::default(),
166 chain: None,
167 };
168 a.push_chain(m)
169 })
170 }
171 }
172
173 impl<T> Context<T, #enum_ident> for std::result::Result<T, #enum_ident> {
174 #[track_caller]
175 fn context<M>(self, msg: M) -> std::result::Result<T, #enum_ident>
176 where
177 M: std::fmt::Display + Send + Sync + 'static,
178 {
179 let location = anywrap::location::Location::default();
180 self.map_err(|e| {
181 let m = #enum_ident::Context {
182 msg: msg.to_string(),
183 location: location,
184 chain: None,
185 };
186 e.push_chain(m)
187 })
188 }
189 }
190 }
191}
192
193fn wrap_impl(input_enum: &ItemEnum) -> proc_macro2::TokenStream {
195 let enum_name = &input_enum.ident;
196 let mut impls = vec![];
197
198 for variant in &input_enum.variants {
199 let variant_name = &variant.ident;
200
201 if let Fields::Named(fields_named) = &variant.fields {
202 let mut error_field_type = None;
203 let mut field_assignments = vec![];
204
205 for field in &fields_named.named {
206 let ident = field.ident.as_ref().unwrap();
207 if ident == "source" {
208 error_field_type = Some(&field.ty);
209 field_assignments.push(quote! { #ident: e });
210 } else if ident == "location" {
211 field_assignments.push(quote! { #ident: location });
212 } else {
213 field_assignments.push(quote! { #ident: Default::default() });
214 }
215 }
216
217 if let Some(error_ty) = error_field_type {
218 impls.push(quote! {
219 impl<T> Wrap<T> for std::result::Result<T, #error_ty> {
220 #[track_caller]
221 fn wrap(self) -> std::result::Result<T, #enum_name> {
222 let location = anywrap::location::Location::default();
223 self.map_err(|e| {
224 #enum_name::#variant_name {
225 #(#field_assignments),*
226 }
227 })
228 }
229 }
230 });
231 }
232 }
233 }
234
235 quote! {
236 pub trait Wrap<T> {
237 fn wrap(self) -> std::result::Result<T, #enum_name>;
238 }
239
240 #(#impls)*
241 }
242}
243
244#[proc_macro_derive(AnyWrap, attributes(anywrap_attr))]
245pub fn derive_anywrap(item: TokenStream) -> TokenStream {
246 let input = parse_macro_input!(item as DeriveInput);
247
248 let enum_ident = &input.ident;
249 let Data::Enum(data_enum) = &input.data else {
250 return syn::Error::new_spanned(&input, "only enums are supported")
251 .to_compile_error()
252 .into();
253 };
254
255 let mut match_lines = Vec::new();
256 let mut chain_lines = Vec::new();
257 let mut chain_arms = Vec::new();
258 let mut from_impls = Vec::new();
259
260 for variant in &data_enum.variants {
261 let variant_ident = &variant.ident;
262
263 let Fields::Named(fields_named) = &variant.fields else {
264 return syn::Error::new_spanned(variant, "only named fields are supported")
265 .to_compile_error()
266 .into();
267 };
268
269 let field_idents: Vec<&Ident> = fields_named
270 .named
271 .iter()
272 .filter_map(|f| f.ident.as_ref())
273 .collect();
274
275 let mut display_format = None;
277 let mut from_field = None;
278
279 for attr in &variant.attrs {
280 if let Some(expr) = get_attr_value(attr, "anywrap_attr", "display") {
281 if let Lit::Str(lit) = &expr.lit {
282 display_format = Some(lit.value());
283 }
284 }
285 if let Some(expr) = get_attr_value(attr, "anywrap_attr", "from") {
286 if let Lit::Str(lit) = &expr.lit {
287 from_field = Some(lit.value());
288 }
289 }
290 }
291 let display_fmt = display_format.unwrap_or_else(|| {
292 panic!(
293 "Missing #[anywrap_attr(display = \"...\")] for variant `{}`",
294 variant_ident
295 )
296 });
297
298 if from_field.is_some() {
299 let fields = match &variant.fields {
301 Fields::Named(fields) => &fields.named,
302 _ => panic!("枚举变体必须使用命名字段"),
303 };
304
305 if fields.len() != 1 {
306 panic!("只有单个字段的变体才能实现 From trait. 变体名: {}", variant_ident);
307 }
308
309 let field = fields.iter().find(|f| {
311 f.ident.as_ref().map(|i| i.to_string()) == from_field
312 }).expect(&format!("找不到指定的字段: {:?}", from_field));
313
314 let field_name = field.ident.as_ref().unwrap();
316 let field_type = &field.ty;
318
319 from_impls.push(quote! {
321 impl From<#field_type> for #enum_ident {
322 fn from(source: #field_type) -> Self {
323 #enum_ident::#variant_ident {
324 #field_name: source,
325 location: Default::default(),
326 chain: None,
327 }
328 }
329 }
330 });
331 }
332
333 let match_arm = quote! {
334 #enum_ident::#variant_ident { #( #field_idents, )* .. } => format!(#display_fmt),
335 };
336 match_lines.push(match_arm);
337
338 let chain_arm = quote! {
339 #enum_ident::#variant_ident { #( #field_idents, )* location, .. } => {
340 format!("{idx}: {}, at {location}", format!(#display_fmt))
341 }
342 };
343 chain_lines.push(chain_arm);
344
345 let chain_extractor = quote! {
346 #enum_ident::#variant_ident { chain, .. } => chain.as_deref(),
347 };
348 chain_arms.push(chain_extractor);
349 }
350
351 let output = quote! {
352 impl std::fmt::Display for #enum_ident {
353 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
354 writeln!(f, "{}", match self {
355 #( #match_lines )*
356 Error::Context { msg, .. } => format!("{msg}"),
357 Error::Any { source, .. } => format!("{source}"),
358 })?;
359 Ok(())
360 }
361 }
362 impl std::fmt::Debug for #enum_ident {
363 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 writeln!(f, "{}", match self {
365 #( #match_lines )*
366 Error::Context { msg, .. } => format!("{msg}"),
367 Error::Any { source, .. } => format!("{source}"),
368 })?;
369
370 fn write_chain(err: &#enum_ident, f: &mut std::fmt::Formatter<'_>, idx: usize) -> std::fmt::Result {
371 let line = match err {
372 #( #chain_lines )*
373 Error::Context { msg, location, .. } => format!("{idx}: {msg}, at {location}"),
374 Error::Any { source, location, .. } => format!("{idx}: {source}, at {location}"),
375 };
376 writeln!(f, "{}", line)?;
377
378 if let Some(inner) = match err {
379 #( #chain_arms )*
380 Error::Context { chain, .. } => chain.as_deref(),
381 Error::Any { chain, .. } => chain.as_deref(),
382 } {
383 write_chain(inner, f, idx + 1)?;
384 }
385
386 Ok(())
387 }
388
389 write_chain(self, f, 0)
390 }
391 }
392
393 #(#from_impls)*
394 };
395
396 output.into()
397}
398
399fn get_attr_value(attr: &Attribute, attr_name: &str, key: &str) -> Option<ExprLit> {
400 if attr.path().is_ident(attr_name) {
401 if let Meta::List(meta) = &attr.meta {
402 for nested in meta.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated).unwrap() {
403 match nested {
404 Meta::NameValue(name_value) => {
405 if name_value.path.is_ident(key) {
406 if let Expr::Lit(expr_lit) = &name_value.value {
407 return Some(expr_lit.clone());
408 }
409 }
410 }
411 _ => {}
412 }
413 }
414 }
415 }
416 None
417}