backon_macros/
lib.rs

1//! Attribute macros that integrate with the `backon` retry library.
2//!
3//! # Overview
4//!
5//! This crate provides the `#[backon]` attribute for free functions and inherent
6//! methods. Annotated items are rewritten so their bodies execute inside the
7//! `backon` retry pipeline, matching the fluent builder style from the runtime
8//! crate without hand-written closures.
9//!
10//! The macro inspects the target signature to decide whether to call
11//! [`Retryable`](backon::Retryable) or [`BlockingRetryable`](backon::BlockingRetryable).
12//! When `context = true` is supplied, it switches to the corresponding `*WithContext`
13//! traits so the arguments are preserved across retries.
14//!
15//! # Usage
16//!
17//! ```
18//! use std::time::Duration;
19//!
20//! use backon_macros::backon;
21//!
22//! #[derive(Debug)]
23//! enum ExampleError {
24//!     Temporary,
25//!     Fatal,
26//! }
27//!
28//! fn should_retry(err: &ExampleError) -> bool {
29//!     matches!(err, ExampleError::Temporary)
30//! }
31//!
32//! fn log_retry(err: &ExampleError, dur: Duration) {
33//!     println!("retrying after {dur:?}: {err:?}");
34//! }
35//!
36//! #[backon(
37//!     backoff = backon::ExponentialBuilder::default,
38//!     sleep = tokio::time::sleep,
39//!     when = should_retry,
40//!     notify = log_retry
41//! )]
42//! async fn fetch() -> Result<String, ExampleError> {
43//!     Ok("value".to_string())
44//! }
45//!
46//! #[tokio::main(flavor = "current_thread")]
47//! async fn main() -> Result<(), ExampleError> {
48//!     let value = fetch().await?;
49//!     println!("{value}");
50//!     Ok(())
51//! }
52//! ```
53//!
54//! # Parameters
55//!
56//! * `backoff = path` – Builder that creates a backoff strategy. Defaults to
57//!   `backon::ExponentialBuilder::default`.
58//! * `sleep = path` – Sleeper function used for async or blocking retries.
59//! * `when = path` – Predicate that filters retryable errors.
60//! * `notify = path` – Callback invoked before each sleep.
61//! * `adjust = path` – Async-only hook that can override the delay.
62//! * `context = true` – Capture inputs into a context tuple and use the
63//!   `RetryableWithContext` traits.
64//!
65//! # Limitations
66//!
67//! * Methods that take `&mut self` or own `self` are not supported; fall back to
68//!   manual `RetryableWithContext` usage until support lands.
69//! * Parameters must bind to identifiers; destructuring patterns are rejected.
70//! * `context = true` is unavailable for `&self` methods.
71#![forbid(unsafe_code)]
72
73use proc_macro::TokenStream;
74use quote::{format_ident, quote};
75use syn::parse::{Parse, ParseStream};
76use syn::spanned::Spanned;
77use syn::{Error, FnArg, Ident, ImplItemFn, ItemFn, LitBool, Pat, Path, Signature, Token};
78
79/// Attribute for turning a function into a retried one using backon retry APIs.
80#[proc_macro_attribute]
81pub fn backon(args: TokenStream, input: TokenStream) -> TokenStream {
82    match expand_backon(args, input) {
83        Ok(tokens) => tokens,
84        Err(err) => err.to_compile_error().into(),
85    }
86}
87
88fn expand_backon(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
89    let args = syn::parse2::<BackonArgs>(proc_macro2::TokenStream::from(args))?;
90    let input_tokens = proc_macro2::TokenStream::from(input);
91
92    if let Ok(mut item_fn) = syn::parse2::<ItemFn>(input_tokens.clone()) {
93        if item_fn.sig.receiver().is_some() {
94            let method = syn::parse2::<ImplItemFn>(input_tokens)?;
95            return expand_method(&args, method);
96        }
97        let original_block = (*item_fn.block).clone();
98        let body_tokens = quote!(#original_block);
99        let block = build_function_body(&args, &item_fn.sig, body_tokens, None, false, false)?;
100        item_fn.block = Box::new(block);
101        return Ok(TokenStream::from(quote!(#item_fn)));
102    }
103
104    if let Ok(method) = syn::parse2::<ImplItemFn>(input_tokens) {
105        return expand_method(&args, method);
106    }
107
108    Err(Error::new(
109        proc_macro2::Span::call_site(),
110        "#[backon] may only be applied to free functions or inherent methods",
111    ))
112}
113
114fn expand_method(args: &BackonArgs, method: ImplItemFn) -> syn::Result<TokenStream> {
115    let has_receiver = matches!(method.sig.inputs.first(), Some(FnArg::Receiver(_)));
116
117    if !has_receiver {
118        let mut wrapper = method;
119        wrapper.attrs.retain(|attr| !attr.path().is_ident("backon"));
120        let original_block = wrapper.block.clone();
121        let body_tokens = quote!(#original_block);
122        let block = build_function_body(args, &wrapper.sig, body_tokens, None, false, false)?;
123        wrapper.block = block;
124        return Ok(TokenStream::from(quote!(#wrapper)));
125    }
126
127    let mut helper = method.clone();
128    helper.attrs.retain(|attr| !attr.path().is_ident("backon"));
129    let helper_ident = format_ident!("__backon_{}_inner", helper.sig.ident);
130    helper.sig.ident = helper_ident.clone();
131
132    let mut wrapper = method;
133    wrapper.attrs.retain(|attr| !attr.path().is_ident("backon"));
134
135    let receiver = match wrapper.sig.inputs.first() {
136        Some(FnArg::Receiver(receiver)) => receiver,
137        _ => {
138            return Err(Error::new(
139                wrapper.sig.span(),
140                "failed to determine method receiver",
141            ));
142        }
143    };
144
145    if let Some(mutability) = receiver.mutability.as_ref() {
146        return Err(Error::new(
147            mutability.span(),
148            "`#[backon]` does not yet support methods taking `&mut self`; please fall back to manual `RetryableWithContext` usage",
149        ));
150    }
151
152    if receiver.reference.is_none() {
153        return Err(Error::new(
154            receiver.self_token.span,
155            "`#[backon]` does not support methods that take ownership of `self`; please fall back to manual `RetryableWithContext` usage",
156        ));
157    }
158
159    if args.context {
160        let span = args.context_span.unwrap_or_else(|| receiver.span());
161        return Err(Error::new(
162            span,
163            "`context = true` is not supported for methods taking `&self`",
164        ));
165    }
166
167    let arg_idents = collect_arg_idents(&wrapper.sig)?;
168
169    let receiver_tokens = quote!(self);
170    let helper_args = if arg_idents.is_empty() {
171        quote!(#receiver_tokens)
172    } else {
173        quote!(#receiver_tokens, #(#arg_idents),*)
174    };
175
176    let helper_call = if wrapper.sig.asyncness.is_some() {
177        quote!(Self::#helper_ident(#helper_args).await)
178    } else {
179        quote!(Self::#helper_ident(#helper_args))
180    };
181
182    let body_tokens = quote!({ #helper_call });
183    let block = build_function_body(args, &wrapper.sig, body_tokens, None, false, false)?;
184    wrapper.block = block;
185
186    Ok(TokenStream::from(quote!(#helper #wrapper)))
187}
188
189#[derive(Clone, Default)]
190struct BackonArgs {
191    backoff: Option<Path>,
192    sleep: Option<Path>,
193    when: Option<Path>,
194    notify: Option<Path>,
195    adjust: Option<Path>,
196    context: bool,
197    context_span: Option<proc_macro2::Span>,
198}
199
200impl Parse for BackonArgs {
201    fn parse(input: ParseStream) -> syn::Result<Self> {
202        if input.is_empty() {
203            return Ok(Self::default());
204        }
205
206        let mut args = BackonArgs::default();
207
208        while !input.is_empty() {
209            let ident: Ident = input.parse()?;
210            let key = ident.to_string();
211            input.parse::<Token![=]>()?;
212
213            match key.as_str() {
214                "backoff" => {
215                    ensure_path_unset(args.backoff.is_some(), ident.span())?;
216                    args.backoff = Some(input.parse()?);
217                }
218                "sleep" => {
219                    ensure_path_unset(args.sleep.is_some(), ident.span())?;
220                    args.sleep = Some(input.parse()?);
221                }
222                "when" => {
223                    ensure_path_unset(args.when.is_some(), ident.span())?;
224                    args.when = Some(input.parse()?);
225                }
226                "notify" => {
227                    ensure_path_unset(args.notify.is_some(), ident.span())?;
228                    args.notify = Some(input.parse()?);
229                }
230                "adjust" => {
231                    ensure_path_unset(args.adjust.is_some(), ident.span())?;
232                    args.adjust = Some(input.parse()?);
233                }
234                "context" => {
235                    if args.context {
236                        return Err(Error::new(
237                            ident.span(),
238                            "`context` cannot be specified more than once",
239                        ));
240                    }
241                    let value: LitBool = input.parse()?;
242                    args.context = value.value;
243                    args.context_span = Some(value.span());
244                }
245                other => {
246                    return Err(Error::new(
247                        ident.span(),
248                        format!("unknown parameter `{other}`"),
249                    ));
250                }
251            }
252
253            if input.peek(Token![,]) {
254                input.parse::<Token![,]>()?;
255            }
256        }
257
258        Ok(args)
259    }
260}
261
262fn ensure_path_unset(already: bool, span: proc_macro2::Span) -> syn::Result<()> {
263    if already {
264        Err(Error::new(span, "parameter already specified"))
265    } else {
266        Ok(())
267    }
268}
269
270fn collect_arg_idents(sig: &Signature) -> syn::Result<Vec<Ident>> {
271    let mut out = Vec::new();
272    for input in sig.inputs.iter() {
273        if let FnArg::Typed(pat_type) = input {
274            match &*pat_type.pat {
275                Pat::Ident(pat_ident) => out.push(pat_ident.ident.clone()),
276                _ => {
277                    return Err(Error::new(
278                        pat_type.span(),
279                        "parameters must bind to identifiers",
280                    ));
281                }
282            }
283        }
284    }
285    Ok(out)
286}
287
288fn build_function_body(
289    args: &BackonArgs,
290    sig: &Signature,
291    body: proc_macro2::TokenStream,
292    precomputed_context: Option<ContextInfo>,
293    force_context: bool,
294    include_receiver: bool,
295) -> syn::Result<syn::Block> {
296    let is_async = sig.asyncness.is_some();
297
298    let chain_config = ChainConfig {
299        is_async,
300        backoff: args
301            .backoff
302            .clone()
303            .unwrap_or_else(|| syn::parse_str("::backon::ExponentialBuilder::default").unwrap()),
304        sleep: args.sleep.clone(),
305        when: args.when.clone(),
306        notify: args.notify.clone(),
307        adjust: args.adjust.clone(),
308    };
309
310    if chain_config.adjust.is_some() && !is_async {
311        return Err(Error::new(
312            sig.ident.span(),
313            "`adjust` is only available for async functions",
314        ));
315    }
316
317    let context_data = if let Some(context) = precomputed_context {
318        Some(context)
319    } else if force_context || args.context {
320        Some(prepare_context(sig, include_receiver)?)
321    } else {
322        None
323    };
324
325    let chain_tokens = if let Some(context) = context_data {
326        build_with_context_chain(&chain_config, body.clone(), context)
327    } else {
328        build_simple_chain(&chain_config, body)
329    }?;
330
331    syn::parse2(chain_tokens)
332}
333
334struct ChainConfig {
335    is_async: bool,
336    backoff: Path,
337    sleep: Option<Path>,
338    when: Option<Path>,
339    notify: Option<Path>,
340    adjust: Option<Path>,
341}
342
343#[derive(Clone)]
344struct ContextInfo {
345    pattern: proc_macro2::TokenStream,
346    initial_expr: proc_macro2::TokenStream,
347    return_expr: proc_macro2::TokenStream,
348    ty: proc_macro2::TokenStream,
349}
350
351fn prepare_context(sig: &Signature, include_receiver: bool) -> syn::Result<ContextInfo> {
352    let mut patterns = Vec::new();
353    let mut exprs = Vec::new();
354    let mut return_exprs = Vec::new();
355    let mut types = Vec::new();
356    for input in sig.inputs.iter() {
357        match input {
358            FnArg::Receiver(receiver) => {
359                if !include_receiver {
360                    continue;
361                }
362
363                if receiver.reference.is_none() {
364                    return Err(Error::new(
365                        receiver.self_token.span,
366                        "`context = true` does not support methods that take ownership of `self`",
367                    ));
368                }
369
370                if receiver.colon_token.is_some() {
371                    return Err(Error::new(
372                        receiver.span(),
373                        "`#[backon]` currently supports only `&self` and `&mut self` receivers",
374                    ));
375                }
376
377                let binding = format_ident!("__backon_self");
378                let lifetime = receiver
379                    .reference
380                    .as_ref()
381                    .and_then(|(_, lifetime)| lifetime.as_ref());
382                let ty_tokens = if receiver.mutability.is_some() {
383                    if let Some(lifetime) = lifetime {
384                        quote!(& #lifetime mut Self)
385                    } else {
386                        quote!(&mut Self)
387                    }
388                } else if let Some(lifetime) = lifetime {
389                    quote!(& #lifetime Self)
390                } else {
391                    quote!(&Self)
392                };
393
394                patterns.push(quote!(#binding));
395                exprs.push(quote!(self));
396                return_exprs.push(quote!(#binding));
397                types.push(ty_tokens);
398            }
399            FnArg::Typed(pat_type) => match &*pat_type.pat {
400                Pat::Ident(pat_ident) => {
401                    let ident = &pat_ident.ident;
402                    patterns.push(quote!(#pat_ident));
403                    exprs.push(quote!(#ident));
404                    return_exprs.push(quote!(#ident));
405                    let ty = &pat_type.ty;
406                    types.push(quote!(#ty));
407                }
408                _ => {
409                    return Err(Error::new(
410                        pat_type.pat.span(),
411                        "`context = true` requires arguments to bind to identifiers",
412                    ));
413                }
414            },
415        }
416    }
417
418    let pattern = if patterns.is_empty() {
419        quote!(())
420    } else {
421        quote!((#(#patterns),*))
422    };
423
424    let initial_expr = if exprs.is_empty() {
425        quote!(())
426    } else {
427        quote!((#(#exprs),*))
428    };
429
430    let return_expr = if return_exprs.is_empty() {
431        quote!(())
432    } else {
433        quote!((#(#return_exprs),*))
434    };
435
436    let ty = if types.is_empty() {
437        quote!(())
438    } else {
439        quote!((#(#types),*))
440    };
441
442    Ok(ContextInfo {
443        pattern,
444        initial_expr,
445        return_expr,
446        ty,
447    })
448}
449
450fn build_simple_chain(
451    config: &ChainConfig,
452    body: proc_macro2::TokenStream,
453) -> syn::Result<proc_macro2::TokenStream> {
454    let backoff_path = &config.backoff;
455
456    let mut chain = if config.is_async {
457        quote! {
458            (|| async move #body)
459                .retry(__backon_builder)
460        }
461    } else {
462        quote! {
463            (|| #body)
464                .retry(__backon_builder)
465        }
466    };
467
468    if let Some(path) = config.sleep.clone() {
469        chain = quote!(#chain.sleep(#path));
470    }
471
472    if let Some(path) = config.when.clone() {
473        chain = quote!(#chain.when(#path));
474    }
475
476    if let Some(path) = config.notify.clone() {
477        chain = quote!(#chain.notify(#path));
478    }
479
480    if let Some(path) = config.adjust.clone() {
481        chain = quote!(#chain.adjust(#path));
482    }
483
484    let executed = if config.is_async {
485        quote!(#chain.await)
486    } else {
487        quote!(#chain.call())
488    };
489
490    let trait_use = if config.is_async {
491        quote!(
492            use ::backon::Retryable as _;
493        )
494    } else {
495        quote!(
496            use ::backon::BlockingRetryable as _;
497        )
498    };
499
500    Ok(quote!({
501        #trait_use
502        let __backon_builder = (#backoff_path)();
503        #executed
504    }))
505}
506
507fn build_with_context_chain(
508    config: &ChainConfig,
509    body: proc_macro2::TokenStream,
510    context: ContextInfo,
511) -> syn::Result<proc_macro2::TokenStream> {
512    let backoff_path = &config.backoff;
513    let initial_context = &context.initial_expr;
514    let return_context = &context.return_expr;
515    let context_ty = &context.ty;
516    let pattern = &context.pattern;
517
518    let mut chain = if config.is_async {
519        quote! {
520            (|__backon_ctx: #context_ty| async move {
521                let #pattern = __backon_ctx;
522                let __backon_result = #body;
523                (#return_context, __backon_result)
524            })
525            .retry(__backon_builder)
526        }
527    } else {
528        quote! {
529            (|__backon_ctx: #context_ty| {
530                let #pattern = __backon_ctx;
531                let __backon_result = #body;
532                (#return_context, __backon_result)
533            })
534            .retry(__backon_builder)
535        }
536    };
537
538    if let Some(path) = config.sleep.clone() {
539        chain = quote!(#chain.sleep(#path));
540    }
541
542    if let Some(path) = config.when.clone() {
543        chain = quote!(#chain.when(#path));
544    }
545
546    if let Some(path) = config.notify.clone() {
547        chain = quote!(#chain.notify(#path));
548    }
549
550    if let Some(path) = config.adjust.clone() {
551        chain = quote!(#chain.adjust(#path));
552    }
553
554    let trait_use = if config.is_async {
555        quote!(
556            use ::backon::RetryableWithContext as _;
557        )
558    } else {
559        quote!(
560            use ::backon::BlockingRetryableWithContext as _;
561        )
562    };
563
564    let tail = if config.is_async {
565        quote!({
566            let (__backon_context, __backon_result) = #chain
567                .context(__backon_initial_context)
568                .await;
569            let _ = __backon_context;
570            __backon_result
571        })
572    } else {
573        quote!({
574            let (__backon_context, __backon_result) = #chain
575                .context(__backon_initial_context)
576                .call();
577            let _ = __backon_context;
578            __backon_result
579        })
580    };
581
582    Ok(quote!({
583        #trait_use
584        let __backon_builder = (#backoff_path)();
585        let __backon_initial_context: #context_ty = #initial_context;
586        #tail
587    }))
588}