Skip to main content

recovery_derive/
lib.rs

//! A derive macro for the `Recovery` trait from the `recovery` crate.
//!
//! Please see the documentation of the `recovery` crate for more information on how to use it.
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{Attribute, DeriveInput, parse_macro_input};
7
8#[proc_macro_derive(Recovery, attributes(recovery))]
9pub fn recovery_derive(tokens: TokenStream) -> TokenStream {
10    let DeriveInput {
11        attrs,
12        ident,
13        data,
14        generics,
15        ..
16    } = parse_macro_input!(tokens);
17    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
18
19    let syn::Data::Enum(data) = data else {
20        panic!("#[derive(Recovery)] only supports enums");
21    };
22
23    let default_recovery = find_recovery_attr(&attrs);
24
25    let mut variants = vec![];
26    for variant in data.variants {
27        let syn::Variant {
28            attrs,
29            ident,
30            fields,
31            ..
32        } = variant;
33
34        let recovery_strategy = find_recovery_attr(&attrs)
35            .or(default_recovery)
36            .unwrap_or_else(|| {
37                panic!("Please add #[recovery(...)] for {ident} or an enum-level default")
38            });
39
40        variants.push(match (fields, recovery_strategy) {
41            (syn::Fields::Named(_), DeriveStrategy::Transparent) => {
42                panic!("#[recovery(transparent)] is not supported for variants with named fields");
43            }
44            (syn::Fields::Named(_), DeriveStrategy::Auto) => {
45                quote! { Self::#ident{..} => recovery::RecoveryStrategy::Auto }
46            }
47            (syn::Fields::Named(_), DeriveStrategy::Manual) => {
48                quote! { Self::#ident{..} => recovery::RecoveryStrategy::Manual }
49            }
50            (syn::Fields::Named(_), DeriveStrategy::Never) => {
51                quote! { Self::#ident{..} => recovery::RecoveryStrategy::Never }
52            }
53            (syn::Fields::Unnamed(_), DeriveStrategy::Transparent) => {
54                quote! { Self::#ident(field, ..) => field.recovery() }
55            }
56            (syn::Fields::Unnamed(_), DeriveStrategy::Auto) => {
57                quote! { Self::#ident(..) => recovery::RecoveryStrategy::Auto }
58            }
59            (syn::Fields::Unnamed(_), DeriveStrategy::Manual) => {
60                quote! { Self::#ident(..) => recovery::RecoveryStrategy::Manual }
61            }
62            (syn::Fields::Unnamed(_), DeriveStrategy::Never) => {
63                quote! { Self::#ident(..) => recovery::RecoveryStrategy::Never }
64            }
65            (syn::Fields::Unit, DeriveStrategy::Transparent) => {
66                panic!("#[recovery(transparent)] is not supported for unit variants");
67            }
68            (syn::Fields::Unit, DeriveStrategy::Auto) => {
69                quote! { Self::#ident => recovery::RecoveryStrategy::Auto }
70            }
71            (syn::Fields::Unit, DeriveStrategy::Manual) => {
72                quote! { Self::#ident => recovery::RecoveryStrategy::Manual }
73            }
74            (syn::Fields::Unit, DeriveStrategy::Never) => {
75                quote! { Self::#ident => recovery::RecoveryStrategy::Never }
76            }
77        });
78    }
79
80    quote! {
81        impl #impl_generics recovery::Recovery for #ident #type_generics #where_clause {
82            fn recovery(&self) -> recovery::RecoveryStrategy {
83                match self {
84                    #( #variants, )*
85                }
86            }
87        }
88    }
89    .into()
90}
91
92fn find_recovery_attr(attrs: &[Attribute]) -> Option<DeriveStrategy> {
93    let attr = attrs.iter().find(|attr| attr.path().is_ident("recovery"))?;
94
95    let mut result = None;
96    attr.parse_nested_meta(|meta| match meta.path.get_ident() {
97        Some(ident) if ident == "transparent" => {
98            result = Some(DeriveStrategy::Transparent);
99            Ok(())
100        }
101        Some(ident) if ident == "auto" => {
102            result = Some(DeriveStrategy::Auto);
103            Ok(())
104        }
105        Some(ident) if ident == "manual" => {
106            result = Some(DeriveStrategy::Manual);
107            Ok(())
108        }
109        Some(ident) if ident == "never" => {
110            result = Some(DeriveStrategy::Never);
111            Ok(())
112        }
113        Some(_) | None => Err(meta
114            .error("One of these is required: \"auto\", \"manual\", \"never\", \"transparent\"")),
115    })
116    .unwrap();
117    result
118}
119
/// The strategy selected by a `#[recovery(...)]` attribute on an enum or on one of its variants.
#[derive(Copy, Clone)]
enum DeriveStrategy {
    /// `#[recovery(transparent)]`: forward to the `recovery()` of a tuple variant's first field.
    Transparent,
    /// `#[recovery(auto)]`: the variant always yields `recovery::RecoveryStrategy::Auto`.
    Auto,
    /// `#[recovery(manual)]`: the variant always yields `recovery::RecoveryStrategy::Manual`.
    Manual,
    /// `#[recovery(never)]`: the variant always yields `recovery::RecoveryStrategy::Never`.
    Never,
}