// asyn_retry_policy_macro/lib.rs

use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, Expr, ExprLit, FnArg, Ident, ItemFn, Lit, Pat, Token};
5#[proc_macro_attribute]
6pub fn retry(attr: TokenStream, item: TokenStream) -> TokenStream {
7    // Supported attribute forms:
8    // - empty: `#[retry]`
9    // - single integer: `#[retry(3)]`
10    // - named args: `#[retry(attempts = 3, base_delay_ms = 100, max_delay_ms = 5000, backoff_factor = 2.0, jitter = true, rng_seed = 42)]`
11
12    let mut attempts: Option<usize> = None;
13    let mut base_delay_ms: Option<u64> = None;
14    let mut max_delay_ms: Option<u64> = None;
15    let mut backoff_factor: Option<f64> = None;
16    let mut jitter_opt: Option<bool> = None;
17    let mut rng_seed: Option<u64> = None;
18    let mut predicate_expr: Option<syn::Expr> = None;
19
20    if !attr.is_empty() {
21        // try simple integer form first
22        if let Ok(Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. })) = syn::parse::<Expr>(attr.clone()) {
23            attempts = Some(litint.base10_parse::<usize>().unwrap_or(3));
24        } else {
25            // parse named args using a simple key = expr parser
26            struct KeyVals(Vec<(syn::Ident, syn::Expr)>);
27
28            impl syn::parse::Parse for KeyVals {
29                fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
30                    let mut out = Vec::new();
31                    while !input.is_empty() {
32                        let key: syn::Ident = input.parse()?;
33                        input.parse::<syn::Token![=]>()?;
34                        let expr: syn::Expr = input.parse()?;
35                        out.push((key, expr));
36                        if input.peek(syn::Token![,]) {
37                            let _ = input.parse::<syn::Token![,]>()?;
38                        }
39                    }
40                    Ok(KeyVals(out))
41                }
42            }
43
44            let args = parse_macro_input!(attr as KeyVals);
45            for (ident, expr) in args.0 {
46                match ident.to_string().as_str() {
47                    "attempts" => match expr {
48                        Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. }) => attempts = Some(litint.base10_parse::<usize>().unwrap()),
49                        _ => return syn::Error::new_spanned(expr, "expected integer literal").to_compile_error().into(),
50                    },
51                    "base_delay_ms" => match expr {
52                        Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. }) => base_delay_ms = Some(litint.base10_parse::<u64>().unwrap()),
53                        _ => return syn::Error::new_spanned(expr, "expected integer literal for base_delay_ms").to_compile_error().into(),
54                    },
55                    "max_delay_ms" => match expr {
56                        Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. }) => max_delay_ms = Some(litint.base10_parse::<u64>().unwrap()),
57                        _ => return syn::Error::new_spanned(expr, "expected integer literal for max_delay_ms").to_compile_error().into(),
58                    },
59                    "backoff_factor" => match expr {
60                        Expr::Lit(syn::ExprLit { lit: Lit::Float(litf), .. }) => backoff_factor = Some(litf.base10_parse::<f64>().unwrap()),
61                        Expr::Lit(syn::ExprLit { lit: Lit::Int(liti), .. }) => backoff_factor = Some(liti.base10_parse::<f64>().unwrap()),
62                        _ => return syn::Error::new_spanned(expr, "expected numeric literal for backoff_factor").to_compile_error().into(),
63                    },
64                    "jitter" => match expr {
65                        Expr::Lit(syn::ExprLit { lit: Lit::Bool(litb), .. }) => jitter_opt = Some(litb.value),
66                        _ => return syn::Error::new_spanned(expr, "expected boolean literal for jitter").to_compile_error().into(),
67                    },
68                    "rng_seed" => match expr {
69                        Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. }) => rng_seed = Some(litint.base10_parse::<u64>().unwrap()),
70                        _ => return syn::Error::new_spanned(expr, "expected integer literal for rng_seed").to_compile_error().into(),
71                    },
72                    "predicate" => {
73                        // Accept a path, a closure, or a string literal with the path
74                        match expr {
75                            Expr::Path(_) => {
76                                predicate_expr = Some(expr);
77                            }
78                            Expr::Closure(_) => {
79                                // inline closure expression is accepted
80                                predicate_expr = Some(expr);
81                            }
82                            Expr::Lit(syn::ExprLit { lit: Lit::Str(lits), .. }) => {
83                                // Parse string into a path
84                                let s = lits.value();
85                                match syn::parse_str::<syn::Path>(&s) {
86                                    Ok(p) => predicate_expr = Some(Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: p })),
87                                    Err(_) => return syn::Error::new_spanned(lits, "invalid path in string").to_compile_error().into(),
88                                }
89                            }
90                            _ => return syn::Error::new_spanned(expr, "expected path, closure, or string literal for predicate").to_compile_error().into(),
91                        }
92                    }
93                    other => return syn::Error::new_spanned(ident, format!("unknown option `{}`", other)).to_compile_error().into(),
94                }
95            }
96        }
97    }
98
99    // Default attempts if not provided
100    let attempts = attempts.unwrap_or(3usize);
101
102    let input = parse_macro_input!(item as ItemFn);
103
104    // Ensure function is async
105    if input.sig.asyncness.is_none() {
106        return syn::Error::new_spanned(input.sig.fn_token, "`#[retry]` can only be applied to `async fn`").to_compile_error().into();
107    }
108
109    let vis = &input.vis;
110    let mut sig = input.sig.clone();
111    let attrs = &input.attrs;
112    let block = &input.block;
113
114    // Collect simple parameter identifiers to clone inside the closure per-attempt
115    let mut clones = Vec::new();
116    for input in sig.inputs.iter() {
117        if let syn::FnArg::Typed(pat_type) = input {
118            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
119                let ident = &pat_ident.ident;
120                clones.push(quote::quote! { let #ident = #ident.clone(); });
121            }
122        }
123    }
124
125    // Build the new function body that wraps the original body inside a RetryPolicy::retry call
126    // We'll reference the runtime crate as `::asyn_retry_policy::RetryPolicy`
127
128    // Build policy initializer fields
129    let mut fields = Vec::new();
130    fields.push(quote! { attempts: #attempts });
131    if let Some(ms) = base_delay_ms {
132        fields.push(quote! { base_delay: ::std::time::Duration::from_millis(#ms) });
133    }
134    if let Some(ms) = max_delay_ms {
135        fields.push(quote! { max_delay: ::std::time::Duration::from_millis(#ms) });
136    }
137    if let Some(f) = backoff_factor {
138        fields.push(quote! { backoff_factor: #f });
139    }
140    if let Some(b) = jitter_opt {
141        fields.push(quote! { jitter: #b });
142    }
143    if let Some(seed) = rng_seed {
144        fields.push(quote! { rng_seed: Some(#seed) });
145    }
146
147    // predicate expression to use as the retry predicate; defaults to `|_| true`
148    let predicate_tokens = if let Some(pred) = predicate_expr {
149        quote! { #pred }
150    } else {
151        quote! { |_| true }
152    };
153
154    let expanded = quote! {
155        #(#attrs)*
156        #vis #sig {
157            let policy = ::asyn_retry_policy::RetryPolicy { #(#fields),*, ..Default::default() };
158            policy.retry(|| {
159                #(#clones)*
160                async move #block
161            }, #predicate_tokens).await
162        }
163    };
164
165    expanded.into()
166}