asyn_retry_policy_macro/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, ItemFn, Lit, Expr};
4
5#[proc_macro_attribute]
6pub fn retry(attr: TokenStream, item: TokenStream) -> TokenStream {
7 let mut attempts: Option<usize> = None;
13 let mut base_delay_ms: Option<u64> = None;
14 let mut max_delay_ms: Option<u64> = None;
15 let mut backoff_factor: Option<f64> = None;
16 let mut jitter_opt: Option<bool> = None;
17 let mut rng_seed: Option<u64> = None;
18 let mut predicate_expr: Option<syn::Expr> = None;
19
20 if !attr.is_empty() {
21 if let Ok(Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. })) = syn::parse::<Expr>(attr.clone()) {
23 attempts = Some(litint.base10_parse::<usize>().unwrap_or(3));
24 } else {
25 struct KeyVals(Vec<(syn::Ident, syn::Expr)>);
27
28 impl syn::parse::Parse for KeyVals {
29 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
30 let mut out = Vec::new();
31 while !input.is_empty() {
32 let key: syn::Ident = input.parse()?;
33 input.parse::<syn::Token![=]>()?;
34 let expr: syn::Expr = input.parse()?;
35 out.push((key, expr));
36 if input.peek(syn::Token![,]) {
37 let _ = input.parse::<syn::Token![,]>()?;
38 }
39 }
40 Ok(KeyVals(out))
41 }
42 }
43
44 let args = parse_macro_input!(attr as KeyVals);
45 for (ident, expr) in args.0 {
46 match ident.to_string().as_str() {
47 "attempts" => match expr {
48 Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. }) => attempts = Some(litint.base10_parse::<usize>().unwrap()),
49 _ => return syn::Error::new_spanned(expr, "expected integer literal").to_compile_error().into(),
50 },
51 "base_delay_ms" => match expr {
52 Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. }) => base_delay_ms = Some(litint.base10_parse::<u64>().unwrap()),
53 _ => return syn::Error::new_spanned(expr, "expected integer literal for base_delay_ms").to_compile_error().into(),
54 },
55 "max_delay_ms" => match expr {
56 Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. }) => max_delay_ms = Some(litint.base10_parse::<u64>().unwrap()),
57 _ => return syn::Error::new_spanned(expr, "expected integer literal for max_delay_ms").to_compile_error().into(),
58 },
59 "backoff_factor" => match expr {
60 Expr::Lit(syn::ExprLit { lit: Lit::Float(litf), .. }) => backoff_factor = Some(litf.base10_parse::<f64>().unwrap()),
61 Expr::Lit(syn::ExprLit { lit: Lit::Int(liti), .. }) => backoff_factor = Some(liti.base10_parse::<f64>().unwrap()),
62 _ => return syn::Error::new_spanned(expr, "expected numeric literal for backoff_factor").to_compile_error().into(),
63 },
64 "jitter" => match expr {
65 Expr::Lit(syn::ExprLit { lit: Lit::Bool(litb), .. }) => jitter_opt = Some(litb.value),
66 _ => return syn::Error::new_spanned(expr, "expected boolean literal for jitter").to_compile_error().into(),
67 },
68 "rng_seed" => match expr {
69 Expr::Lit(syn::ExprLit { lit: Lit::Int(litint), .. }) => rng_seed = Some(litint.base10_parse::<u64>().unwrap()),
70 _ => return syn::Error::new_spanned(expr, "expected integer literal for rng_seed").to_compile_error().into(),
71 },
72 "predicate" => {
73 match expr {
75 Expr::Path(_) => {
76 predicate_expr = Some(expr);
77 }
78 Expr::Closure(_) => {
79 predicate_expr = Some(expr);
81 }
82 Expr::Lit(syn::ExprLit { lit: Lit::Str(lits), .. }) => {
83 let s = lits.value();
85 match syn::parse_str::<syn::Path>(&s) {
86 Ok(p) => predicate_expr = Some(Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: p })),
87 Err(_) => return syn::Error::new_spanned(lits, "invalid path in string").to_compile_error().into(),
88 }
89 }
90 _ => return syn::Error::new_spanned(expr, "expected path, closure, or string literal for predicate").to_compile_error().into(),
91 }
92 }
93 other => return syn::Error::new_spanned(ident, format!("unknown option `{}`", other)).to_compile_error().into(),
94 }
95 }
96 }
97 }
98
99 let attempts = attempts.unwrap_or(3usize);
101
102 let input = parse_macro_input!(item as ItemFn);
103
104 if input.sig.asyncness.is_none() {
106 return syn::Error::new_spanned(input.sig.fn_token, "`#[retry]` can only be applied to `async fn`").to_compile_error().into();
107 }
108
109 let vis = &input.vis;
110 let mut sig = input.sig.clone();
111 let attrs = &input.attrs;
112 let block = &input.block;
113
114 let mut clones = Vec::new();
116 for input in sig.inputs.iter() {
117 if let syn::FnArg::Typed(pat_type) = input {
118 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
119 let ident = &pat_ident.ident;
120 clones.push(quote::quote! { let #ident = #ident.clone(); });
121 }
122 }
123 }
124
125 let mut fields = Vec::new();
130 fields.push(quote! { attempts: #attempts });
131 if let Some(ms) = base_delay_ms {
132 fields.push(quote! { base_delay: ::std::time::Duration::from_millis(#ms) });
133 }
134 if let Some(ms) = max_delay_ms {
135 fields.push(quote! { max_delay: ::std::time::Duration::from_millis(#ms) });
136 }
137 if let Some(f) = backoff_factor {
138 fields.push(quote! { backoff_factor: #f });
139 }
140 if let Some(b) = jitter_opt {
141 fields.push(quote! { jitter: #b });
142 }
143 if let Some(seed) = rng_seed {
144 fields.push(quote! { rng_seed: Some(#seed) });
145 }
146
147 let predicate_tokens = if let Some(pred) = predicate_expr {
149 quote! { #pred }
150 } else {
151 quote! { |_| true }
152 };
153
154 let expanded = quote! {
155 #(#attrs)*
156 #vis #sig {
157 let policy = ::asyn_retry_policy::RetryPolicy { #(#fields),*, ..Default::default() };
158 policy.retry(|| {
159 #(#clones)*
160 async move #block
161 }, #predicate_tokens).await
162 }
163 };
164
165 expanded.into()
166}