retryable_proc_macros/
mod.rs

1use proc_macro::TokenStream;
2
3extern crate proc_macro;
4use darling::FromMeta;
5use quote::quote;
6use syn::{parse_macro_input, AttributeArgs, ItemFn, Signature};
7
8#[derive(Debug, FromMeta)]
9struct MacroArgs {
10    max_attempts: u16,
11    sleep_seconds: u64,
12}
13
14#[proc_macro_attribute]
15pub fn retryable(args: TokenStream, item: TokenStream) -> TokenStream {
16    let attr_args = parse_macro_input!(args as AttributeArgs);
17    let args = match MacroArgs::from_list(&attr_args) {
18        Ok(v) => v,
19        Err(e) => {
20            return TokenStream::from(e.write_errors());
21        }
22    };
23
24    let max_attempts = args.max_attempts;
25    let sleep_seconds = args.sleep_seconds;
26
27    let function = parse_macro_input!(item as ItemFn);
28    let function_signature = function.sig.clone();
29
30    let ItemFn {
31        attrs,
32        vis,
33        block,
34        sig,
35        ..
36    } = function;
37
38    let Signature {
39        output: return_type,
40        inputs: params,
41        unsafety,
42        asyncness,
43        constness,
44        abi,
45        ident,
46        generics:
47            syn::Generics {
48                params: gen_params,
49                where_clause,
50                ..
51            },
52        ..
53    } = function_signature;
54
55    if sig.asyncness.is_some() {
56        quote!(
57            #(#attrs) *
58            #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type
59            #where_clause
60            {
61                let mut counter: u16 = 0;
62                loop {
63                    match #block {
64                        Ok(result) => return Ok(result),
65                        Err(err) if counter < #max_attempts => {
66                            counter += 1;
67                            ::retryable::async_std::task::sleep(std::time::Duration::from_secs(#sleep_seconds)).await
68                        },
69                        Err(err) if counter == #max_attempts => {
70                            break Err(err)
71                        },
72                        Err(err) => break Err(err)
73                    }
74                }
75            }).into()
76    } else {
77        quote!(
78        #(#attrs) *
79        #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type
80        #where_clause
81        {
82            let mut counter: u16 = 0;
83            loop {
84                match #block {
85                    Ok(result) => return Ok(result),
86                    Err(err) if counter < #max_attempts => {
87                        counter += 1;
88                        std::thread::sleep(std::time::Duration::from_secs(#sleep_seconds));
89                    },
90                    Err(err) if counter == #max_attempts => {
91                        break Err(err)
92                    },
93                    Err(err) => break Err(err)
94                }
95            }
96
97        })
98        .into()
99    }
100}