1#![forbid(unsafe_code)]
72
73use proc_macro::TokenStream;
74use quote::{format_ident, quote};
75use syn::parse::{Parse, ParseStream};
76use syn::spanned::Spanned;
77use syn::{Error, FnArg, Ident, ImplItemFn, ItemFn, LitBool, Pat, Path, Signature, Token};
78
79#[proc_macro_attribute]
81pub fn backon(args: TokenStream, input: TokenStream) -> TokenStream {
82 match expand_backon(args, input) {
83 Ok(tokens) => tokens,
84 Err(err) => err.to_compile_error().into(),
85 }
86}
87
88fn expand_backon(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
89 let args = syn::parse2::<BackonArgs>(proc_macro2::TokenStream::from(args))?;
90 let input_tokens = proc_macro2::TokenStream::from(input);
91
92 if let Ok(mut item_fn) = syn::parse2::<ItemFn>(input_tokens.clone()) {
93 if item_fn.sig.receiver().is_some() {
94 let method = syn::parse2::<ImplItemFn>(input_tokens)?;
95 return expand_method(&args, method);
96 }
97 let original_block = (*item_fn.block).clone();
98 let body_tokens = quote!(#original_block);
99 let block = build_function_body(&args, &item_fn.sig, body_tokens, None, false, false)?;
100 item_fn.block = Box::new(block);
101 return Ok(TokenStream::from(quote!(#item_fn)));
102 }
103
104 if let Ok(method) = syn::parse2::<ImplItemFn>(input_tokens) {
105 return expand_method(&args, method);
106 }
107
108 Err(Error::new(
109 proc_macro2::Span::call_site(),
110 "#[backon] may only be applied to free functions or inherent methods",
111 ))
112}
113
114fn expand_method(args: &BackonArgs, method: ImplItemFn) -> syn::Result<TokenStream> {
115 let has_receiver = matches!(method.sig.inputs.first(), Some(FnArg::Receiver(_)));
116
117 if !has_receiver {
118 let mut wrapper = method;
119 wrapper.attrs.retain(|attr| !attr.path().is_ident("backon"));
120 let original_block = wrapper.block.clone();
121 let body_tokens = quote!(#original_block);
122 let block = build_function_body(args, &wrapper.sig, body_tokens, None, false, false)?;
123 wrapper.block = block;
124 return Ok(TokenStream::from(quote!(#wrapper)));
125 }
126
127 let mut helper = method.clone();
128 helper.attrs.retain(|attr| !attr.path().is_ident("backon"));
129 let helper_ident = format_ident!("__backon_{}_inner", helper.sig.ident);
130 helper.sig.ident = helper_ident.clone();
131
132 let mut wrapper = method;
133 wrapper.attrs.retain(|attr| !attr.path().is_ident("backon"));
134
135 let receiver = match wrapper.sig.inputs.first() {
136 Some(FnArg::Receiver(receiver)) => receiver,
137 _ => {
138 return Err(Error::new(
139 wrapper.sig.span(),
140 "failed to determine method receiver",
141 ));
142 }
143 };
144
145 if let Some(mutability) = receiver.mutability.as_ref() {
146 return Err(Error::new(
147 mutability.span(),
148 "`#[backon]` does not yet support methods taking `&mut self`; please fall back to manual `RetryableWithContext` usage",
149 ));
150 }
151
152 if receiver.reference.is_none() {
153 return Err(Error::new(
154 receiver.self_token.span,
155 "`#[backon]` does not support methods that take ownership of `self`; please fall back to manual `RetryableWithContext` usage",
156 ));
157 }
158
159 if args.context {
160 let span = args.context_span.unwrap_or_else(|| receiver.span());
161 return Err(Error::new(
162 span,
163 "`context = true` is not supported for methods taking `&self`",
164 ));
165 }
166
167 let arg_idents = collect_arg_idents(&wrapper.sig)?;
168
169 let receiver_tokens = quote!(self);
170 let helper_args = if arg_idents.is_empty() {
171 quote!(#receiver_tokens)
172 } else {
173 quote!(#receiver_tokens, #(#arg_idents),*)
174 };
175
176 let helper_call = if wrapper.sig.asyncness.is_some() {
177 quote!(Self::#helper_ident(#helper_args).await)
178 } else {
179 quote!(Self::#helper_ident(#helper_args))
180 };
181
182 let body_tokens = quote!({ #helper_call });
183 let block = build_function_body(args, &wrapper.sig, body_tokens, None, false, false)?;
184 wrapper.block = block;
185
186 Ok(TokenStream::from(quote!(#helper #wrapper)))
187}
188
189#[derive(Clone, Default)]
190struct BackonArgs {
191 backoff: Option<Path>,
192 sleep: Option<Path>,
193 when: Option<Path>,
194 notify: Option<Path>,
195 adjust: Option<Path>,
196 context: bool,
197 context_span: Option<proc_macro2::Span>,
198}
199
200impl Parse for BackonArgs {
201 fn parse(input: ParseStream) -> syn::Result<Self> {
202 if input.is_empty() {
203 return Ok(Self::default());
204 }
205
206 let mut args = BackonArgs::default();
207
208 while !input.is_empty() {
209 let ident: Ident = input.parse()?;
210 let key = ident.to_string();
211 input.parse::<Token![=]>()?;
212
213 match key.as_str() {
214 "backoff" => {
215 ensure_path_unset(args.backoff.is_some(), ident.span())?;
216 args.backoff = Some(input.parse()?);
217 }
218 "sleep" => {
219 ensure_path_unset(args.sleep.is_some(), ident.span())?;
220 args.sleep = Some(input.parse()?);
221 }
222 "when" => {
223 ensure_path_unset(args.when.is_some(), ident.span())?;
224 args.when = Some(input.parse()?);
225 }
226 "notify" => {
227 ensure_path_unset(args.notify.is_some(), ident.span())?;
228 args.notify = Some(input.parse()?);
229 }
230 "adjust" => {
231 ensure_path_unset(args.adjust.is_some(), ident.span())?;
232 args.adjust = Some(input.parse()?);
233 }
234 "context" => {
235 if args.context {
236 return Err(Error::new(
237 ident.span(),
238 "`context` cannot be specified more than once",
239 ));
240 }
241 let value: LitBool = input.parse()?;
242 args.context = value.value;
243 args.context_span = Some(value.span());
244 }
245 other => {
246 return Err(Error::new(
247 ident.span(),
248 format!("unknown parameter `{other}`"),
249 ));
250 }
251 }
252
253 if input.peek(Token![,]) {
254 input.parse::<Token![,]>()?;
255 }
256 }
257
258 Ok(args)
259 }
260}
261
262fn ensure_path_unset(already: bool, span: proc_macro2::Span) -> syn::Result<()> {
263 if already {
264 Err(Error::new(span, "parameter already specified"))
265 } else {
266 Ok(())
267 }
268}
269
270fn collect_arg_idents(sig: &Signature) -> syn::Result<Vec<Ident>> {
271 let mut out = Vec::new();
272 for input in sig.inputs.iter() {
273 if let FnArg::Typed(pat_type) = input {
274 match &*pat_type.pat {
275 Pat::Ident(pat_ident) => out.push(pat_ident.ident.clone()),
276 _ => {
277 return Err(Error::new(
278 pat_type.span(),
279 "parameters must bind to identifiers",
280 ));
281 }
282 }
283 }
284 }
285 Ok(out)
286}
287
288fn build_function_body(
289 args: &BackonArgs,
290 sig: &Signature,
291 body: proc_macro2::TokenStream,
292 precomputed_context: Option<ContextInfo>,
293 force_context: bool,
294 include_receiver: bool,
295) -> syn::Result<syn::Block> {
296 let is_async = sig.asyncness.is_some();
297
298 let chain_config = ChainConfig {
299 is_async,
300 backoff: args
301 .backoff
302 .clone()
303 .unwrap_or_else(|| syn::parse_str("::backon::ExponentialBuilder::default").unwrap()),
304 sleep: args.sleep.clone(),
305 when: args.when.clone(),
306 notify: args.notify.clone(),
307 adjust: args.adjust.clone(),
308 };
309
310 if chain_config.adjust.is_some() && !is_async {
311 return Err(Error::new(
312 sig.ident.span(),
313 "`adjust` is only available for async functions",
314 ));
315 }
316
317 let context_data = if let Some(context) = precomputed_context {
318 Some(context)
319 } else if force_context || args.context {
320 Some(prepare_context(sig, include_receiver)?)
321 } else {
322 None
323 };
324
325 let chain_tokens = if let Some(context) = context_data {
326 build_with_context_chain(&chain_config, body.clone(), context)
327 } else {
328 build_simple_chain(&chain_config, body)
329 }?;
330
331 syn::parse2(chain_tokens)
332}
333
334struct ChainConfig {
335 is_async: bool,
336 backoff: Path,
337 sleep: Option<Path>,
338 when: Option<Path>,
339 notify: Option<Path>,
340 adjust: Option<Path>,
341}
342
343#[derive(Clone)]
344struct ContextInfo {
345 pattern: proc_macro2::TokenStream,
346 initial_expr: proc_macro2::TokenStream,
347 return_expr: proc_macro2::TokenStream,
348 ty: proc_macro2::TokenStream,
349}
350
351fn prepare_context(sig: &Signature, include_receiver: bool) -> syn::Result<ContextInfo> {
352 let mut patterns = Vec::new();
353 let mut exprs = Vec::new();
354 let mut return_exprs = Vec::new();
355 let mut types = Vec::new();
356 for input in sig.inputs.iter() {
357 match input {
358 FnArg::Receiver(receiver) => {
359 if !include_receiver {
360 continue;
361 }
362
363 if receiver.reference.is_none() {
364 return Err(Error::new(
365 receiver.self_token.span,
366 "`context = true` does not support methods that take ownership of `self`",
367 ));
368 }
369
370 if receiver.colon_token.is_some() {
371 return Err(Error::new(
372 receiver.span(),
373 "`#[backon]` currently supports only `&self` and `&mut self` receivers",
374 ));
375 }
376
377 let binding = format_ident!("__backon_self");
378 let lifetime = receiver
379 .reference
380 .as_ref()
381 .and_then(|(_, lifetime)| lifetime.as_ref());
382 let ty_tokens = if receiver.mutability.is_some() {
383 if let Some(lifetime) = lifetime {
384 quote!(& #lifetime mut Self)
385 } else {
386 quote!(&mut Self)
387 }
388 } else if let Some(lifetime) = lifetime {
389 quote!(& #lifetime Self)
390 } else {
391 quote!(&Self)
392 };
393
394 patterns.push(quote!(#binding));
395 exprs.push(quote!(self));
396 return_exprs.push(quote!(#binding));
397 types.push(ty_tokens);
398 }
399 FnArg::Typed(pat_type) => match &*pat_type.pat {
400 Pat::Ident(pat_ident) => {
401 let ident = &pat_ident.ident;
402 patterns.push(quote!(#pat_ident));
403 exprs.push(quote!(#ident));
404 return_exprs.push(quote!(#ident));
405 let ty = &pat_type.ty;
406 types.push(quote!(#ty));
407 }
408 _ => {
409 return Err(Error::new(
410 pat_type.pat.span(),
411 "`context = true` requires arguments to bind to identifiers",
412 ));
413 }
414 },
415 }
416 }
417
418 let pattern = if patterns.is_empty() {
419 quote!(())
420 } else {
421 quote!((#(#patterns),*))
422 };
423
424 let initial_expr = if exprs.is_empty() {
425 quote!(())
426 } else {
427 quote!((#(#exprs),*))
428 };
429
430 let return_expr = if return_exprs.is_empty() {
431 quote!(())
432 } else {
433 quote!((#(#return_exprs),*))
434 };
435
436 let ty = if types.is_empty() {
437 quote!(())
438 } else {
439 quote!((#(#types),*))
440 };
441
442 Ok(ContextInfo {
443 pattern,
444 initial_expr,
445 return_expr,
446 ty,
447 })
448}
449
450fn build_simple_chain(
451 config: &ChainConfig,
452 body: proc_macro2::TokenStream,
453) -> syn::Result<proc_macro2::TokenStream> {
454 let backoff_path = &config.backoff;
455
456 let mut chain = if config.is_async {
457 quote! {
458 (|| async move #body)
459 .retry(__backon_builder)
460 }
461 } else {
462 quote! {
463 (|| #body)
464 .retry(__backon_builder)
465 }
466 };
467
468 if let Some(path) = config.sleep.clone() {
469 chain = quote!(#chain.sleep(#path));
470 }
471
472 if let Some(path) = config.when.clone() {
473 chain = quote!(#chain.when(#path));
474 }
475
476 if let Some(path) = config.notify.clone() {
477 chain = quote!(#chain.notify(#path));
478 }
479
480 if let Some(path) = config.adjust.clone() {
481 chain = quote!(#chain.adjust(#path));
482 }
483
484 let executed = if config.is_async {
485 quote!(#chain.await)
486 } else {
487 quote!(#chain.call())
488 };
489
490 let trait_use = if config.is_async {
491 quote!(
492 use ::backon::Retryable as _;
493 )
494 } else {
495 quote!(
496 use ::backon::BlockingRetryable as _;
497 )
498 };
499
500 Ok(quote!({
501 #trait_use
502 let __backon_builder = (#backoff_path)();
503 #executed
504 }))
505}
506
507fn build_with_context_chain(
508 config: &ChainConfig,
509 body: proc_macro2::TokenStream,
510 context: ContextInfo,
511) -> syn::Result<proc_macro2::TokenStream> {
512 let backoff_path = &config.backoff;
513 let initial_context = &context.initial_expr;
514 let return_context = &context.return_expr;
515 let context_ty = &context.ty;
516 let pattern = &context.pattern;
517
518 let mut chain = if config.is_async {
519 quote! {
520 (|__backon_ctx: #context_ty| async move {
521 let #pattern = __backon_ctx;
522 let __backon_result = #body;
523 (#return_context, __backon_result)
524 })
525 .retry(__backon_builder)
526 }
527 } else {
528 quote! {
529 (|__backon_ctx: #context_ty| {
530 let #pattern = __backon_ctx;
531 let __backon_result = #body;
532 (#return_context, __backon_result)
533 })
534 .retry(__backon_builder)
535 }
536 };
537
538 if let Some(path) = config.sleep.clone() {
539 chain = quote!(#chain.sleep(#path));
540 }
541
542 if let Some(path) = config.when.clone() {
543 chain = quote!(#chain.when(#path));
544 }
545
546 if let Some(path) = config.notify.clone() {
547 chain = quote!(#chain.notify(#path));
548 }
549
550 if let Some(path) = config.adjust.clone() {
551 chain = quote!(#chain.adjust(#path));
552 }
553
554 let trait_use = if config.is_async {
555 quote!(
556 use ::backon::RetryableWithContext as _;
557 )
558 } else {
559 quote!(
560 use ::backon::BlockingRetryableWithContext as _;
561 )
562 };
563
564 let tail = if config.is_async {
565 quote!({
566 let (__backon_context, __backon_result) = #chain
567 .context(__backon_initial_context)
568 .await;
569 let _ = __backon_context;
570 __backon_result
571 })
572 } else {
573 quote!({
574 let (__backon_context, __backon_result) = #chain
575 .context(__backon_initial_context)
576 .call();
577 let _ = __backon_context;
578 __backon_result
579 })
580 };
581
582 Ok(quote!({
583 #trait_use
584 let __backon_builder = (#backoff_path)();
585 let __backon_initial_context: #context_ty = #initial_context;
586 #tail
587 }))
588}