1extern crate proc_macro;
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{
11 parse_macro_input, ItemFn, Type, ReturnType, GenericArgument, PathArguments,
12 parse::Parse, parse::ParseStream, Error, Result as SynResult
13};
14
15struct AttributeArgs {
17 strategy_type: Type,
18}
19
20impl Parse for AttributeArgs {
21 fn parse(input: ParseStream) -> SynResult<Self> {
22 let strategy_type: Type = input.parse()?;
23 Ok(AttributeArgs { strategy_type })
24 }
25}
26
27fn extract_result_types(return_type: &Type) -> SynResult<(Type, Type)> {
29 if let Type::Path(type_path) = return_type {
30 if let Some(segment) = type_path.path.segments.last() {
31 if segment.ident == "Result" {
32 if let PathArguments::AngleBracketed(args) = &segment.arguments {
33 if args.args.len() == 2 {
34 if let (
35 GenericArgument::Type(ok_type),
36 GenericArgument::Type(err_type)
37 ) = (&args.args[0], &args.args[1]) {
38 return Ok((ok_type.clone(), err_type.clone()));
39 }
40 }
41 }
42 }
43 }
44 }
45
46 Err(Error::new_spanned(
47 return_type,
48 "Expected function to return Result<T, E>"
49 ))
50}
51
52#[proc_macro_attribute]
81pub fn error_strategy(args: TokenStream, item: TokenStream) -> TokenStream {
82 let input_fn = parse_macro_input!(item as ItemFn);
83 let args = parse_macro_input!(args as AttributeArgs);
84
85 let strategy_type = args.strategy_type;
86 let fn_name = &input_fn.sig.ident;
87 let fn_vis = &input_fn.vis;
88 let fn_inputs = &input_fn.sig.inputs;
89 let fn_body = &input_fn.block;
90 let fn_asyncness = &input_fn.sig.asyncness;
91 let fn_generics = &input_fn.sig.generics;
92 let where_clause = &input_fn.sig.generics.where_clause;
93
94 if fn_asyncness.is_none() {
96 return Error::new_spanned(
97 &input_fn.sig,
98 "error_strategy can only be applied to async functions"
99 ).to_compile_error().into();
100 }
101
102 let (ok_type, err_type) = match &input_fn.sig.output {
104 ReturnType::Type(_, ty) => {
105 match extract_result_types(ty) {
106 Ok(types) => types,
107 Err(e) => return e.to_compile_error().into(),
108 }
109 }
110 ReturnType::Default => {
111 return Error::new_spanned(
112 &input_fn.sig,
113 "Function must return Result<T, E>"
114 ).to_compile_error().into();
115 }
116 };
117
118 let original_impl_name = syn::Ident::new(
120 &format!("{}_original_impl", fn_name),
121 fn_name.span()
122 );
123
124 let strategy_name = quote!(#strategy_type).to_string();
126
127 let param_names: Vec<_> = input_fn.sig.inputs.iter().filter_map(|arg| {
129 if let syn::FnArg::Typed(pat_type) = arg {
130 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
131 Some(&pat_ident.ident)
132 } else {
133 None
134 }
135 } else {
136 None
137 }
138 }).collect();
139
140 let expanded = quote! {
141 #[doc(hidden)]
142 #fn_asyncness fn #original_impl_name #fn_generics (#fn_inputs) -> Result<#ok_type, #err_type> #where_clause
143 #fn_body
144
145 #fn_vis #fn_asyncness fn #fn_name #fn_generics (#fn_inputs) -> crate::PipexResult<#ok_type, #err_type> #where_clause {
146 let result = #original_impl_name(#(#param_names),*).await;
147 crate::PipexResult::new(result, #strategy_name)
148 }
149 };
150
151 TokenStream::from(expanded)
152}
153
154