nullable_utils_macros/
lib.rs

1// SPDX-FileCopyrightText: 2024 Markus Haug (Korrat)
2//
3// SPDX-License-Identifier: Apache-2.0
4// SPDX-License-Identifier: MIT
5
6//! This serves as a companion crate to [`nullable-utils`] and provides proc-macros to support working with Nullables.
7//!
8//! [`nullable-utils`]: https://crates.io/crates/nullable-utils
9
10use proc_macro2::TokenStream;
11use quote::format_ident;
12use quote::quote;
13use quote::quote_spanned;
14use quote::ToTokens as _;
15use syn::braced;
16use syn::parse::Parse;
17use syn::parse::ParseStream;
18use syn::parse_macro_input;
19use syn::parse_quote;
20use syn::punctuated::Punctuated;
21use syn::spanned::Spanned as _;
22use syn::token;
23use syn::token::Comma;
24use syn::token::Enum;
25use syn::Attribute;
26use syn::Block;
27use syn::Field;
28use syn::Fields;
29use syn::FieldsUnnamed;
30use syn::FnArg;
31use syn::Ident;
32use syn::ItemEnum;
33use syn::Signature;
34use syn::Token;
35use syn::Variant;
36use syn::Visibility;
37
38/// Create a wrapper enum for switching (internal) implementations efficiently.
39///
40/// When integrating with third-party infrastructure components (HTTP clients, database clients, …), Nullables make use
41/// of embedded stubs. This macro generates a wrapper enum for seamlessly switching between the real implementation and
42/// the embedded stub.
43///
44/// The macro expects input in the form of an enum declaration, optionally followed by a block of function declarations
45/// (like in traits). Each enum variant can either be a newtype variant or a unit variant, which will be transformed a
46/// into a newtype variants. The macro creates [`From<T>`] and [`TryInto<T>`] implementations for each variant.
47///
48/// For each method declaration, the wrapper enum will have a definition that automatically forwards the call to the its
49/// variants. If the method does have a default body, this will be used instead of generating the body automatically.
50#[proc_macro]
51pub fn nullable_wrapper(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
52    let wrapper = parse_macro_input!(input as NullableWrapper);
53    let expanded = expand(wrapper);
54    proc_macro::TokenStream::from(expanded)
55}
56
57fn expand(wrapper: NullableWrapper) -> TokenStream {
58    let NullableWrapper {
59        attrs,
60        vis,
61        enum_token,
62        ident,
63        variants,
64        fns,
65    } = wrapper;
66
67    let (enum_ident, struct_impl) = expand_struct_wrapper(&attrs, &vis, ident, &fns);
68
69    let fns = fns.into_iter().map(
70        |WrapperFn {
71             attrs,
72             sig,
73             default,
74             ..
75         }| {
76            let method = &sig.ident;
77            let args: Punctuated<_, Comma> = sig
78                .inputs
79                .iter()
80                .filter_map(|arg| match arg {
81                    FnArg::Receiver(_) => None,
82                    FnArg::Typed(pat) => Some(&pat.pat),
83                })
84                .collect();
85
86            let body = default.map_or_else(
87                || {
88                    let matchers = variants.iter().map(
89                        |Variant { ident, .. }| quote!(Self::#ident(inner) => inner.#method(#args)),
90                    );
91
92                    quote!({
93                        match self {
94                            #(#matchers),*
95                        }
96                    })
97                },
98                Block::into_token_stream,
99            );
100
101            quote! {
102                #(#attrs)*
103                #sig #body
104            }
105        },
106    );
107
108    let from_impls = variants.iter().map(|var @ Variant { ident, fields, .. }| {
109        // TODO handle variant attrs
110        let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = fields else {
111            panic!()
112        };
113        let Field { ty, .. } = &unnamed[0];
114
115        quote_spanned! { var.span() =>
116            impl From<#ty> for #enum_ident {
117                fn from(value: #ty) -> Self {
118                    Self::#ident(value)
119                }
120            }
121        }
122    });
123
124    let try_into_impls = variants.iter().map(|var @ Variant { ident, fields, .. }| {
125        // TODO handle variant attrs
126        let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = fields else {
127            panic!()
128        };
129        let Field { ty, .. } = &unnamed[0];
130
131        quote_spanned! { var.span() =>
132            impl TryFrom<#enum_ident> for #ty {
133                type Error = ();
134
135                fn try_from(value: #enum_ident) -> Result<Self, Self::Error> {
136                    match value {
137                        #enum_ident::#ident(inner) => Ok(inner),
138                        _ => Err(())
139                    }
140                }
141            }
142        }
143    });
144
145    let expanded = quote! {
146        #struct_impl
147
148        #(#attrs)*
149        #enum_token #enum_ident {
150            #variants
151        }
152
153        impl #enum_ident {
154            #(#fns)*
155        }
156
157        #(#from_impls)*
158
159        #(#try_into_impls)*
160    };
161    expanded
162}
163
164fn expand_struct_wrapper(
165    attrs: &[Attribute],
166    vis: &Visibility,
167    ident: Ident,
168    fns: &[WrapperFn],
169) -> (Ident, TokenStream) {
170    let Visibility::Public(pub_token) = vis else {
171        return (ident, TokenStream::new());
172    };
173
174    let enum_ident = format_ident!("{}Inner", ident);
175
176    let fns = fns.iter().map(
177        |WrapperFn {
178             attrs, vis, sig, ..
179         }| {
180            let method = &sig.ident;
181            let args: Punctuated<_, Comma> = sig
182                .inputs
183                .iter()
184                .filter_map(|arg| match arg {
185                    FnArg::Receiver(_) => None,
186                    FnArg::Typed(pat) => Some(&pat.pat),
187                })
188                .collect();
189
190            let body = quote!({
191                self.0.#method(#args)
192            });
193
194            quote! {
195                #(#attrs)*
196                #vis #sig #body
197            }
198        },
199    );
200
201    let token_stream = quote! {
202        #(#attrs)*
203        #[repr(transparent)]
204        #pub_token struct #ident(#enum_ident);
205
206        impl #ident {
207            #(#fns)*
208        }
209
210        impl<T> From<T> for #ident where #enum_ident: From<T> {
211            fn from(value: T) -> Self {
212                Self(#enum_ident::from(value))
213            }
214        }
215    };
216
217    (enum_ident, token_stream)
218}
219
220struct NullableWrapper {
221    attrs: Vec<Attribute>,
222    vis: Visibility,
223    enum_token: Enum,
224    ident: Ident,
225    variants: Punctuated<Variant, Comma>,
226    fns: Vec<WrapperFn>,
227}
228
229// TODO parse syntax ourselves instead of reusing syn types?
230impl Parse for NullableWrapper {
231    fn parse(input: ParseStream) -> syn::Result<Self> {
232        let ItemEnum {
233            attrs,
234            vis,
235            enum_token,
236            ident,
237            mut variants,
238            ..
239        } = input.parse()?;
240
241        for variant in &mut variants {
242            match variant.fields {
243                Fields::Unit => {
244                    let name = &variant.ident;
245                    variant.fields = Fields::Unnamed(parse_quote!((#name)));
246                }
247                Fields::Unnamed(FieldsUnnamed {
248                    ref mut unnamed, ..
249                }) if unnamed.len() == 1 => {}
250                _ => {
251                    return Err(syn::Error::new_spanned(
252                        &variant,
253                        "only unit and new-type variants are supported",
254                    ))
255                }
256            }
257        }
258        // TODO handle generics, brace token & method definitions
259
260        let mut fns = Vec::new();
261        if !input.is_empty() {
262            let content;
263            braced!(content in input);
264
265            while !content.is_empty() {
266                fns.push(content.parse()?);
267            }
268        }
269
270        Ok(NullableWrapper {
271            attrs,
272            vis,
273            enum_token,
274            ident,
275            variants,
276            fns,
277        })
278    }
279}
280
281struct WrapperFn {
282    pub attrs: Vec<Attribute>,
283    pub vis: Visibility,
284    pub sig: Signature,
285    pub default: Option<Block>,
286    pub semi_token: Option<Token![;]>,
287}
288
289impl Parse for WrapperFn {
290    fn parse(input: ParseStream) -> syn::Result<Self> {
291        let attrs = input.call(Attribute::parse_outer)?;
292        let vis: Visibility = input.parse()?;
293        let sig: Signature = input.parse()?;
294
295        let lookahead = input.lookahead1();
296        let (default, semi_token) = if lookahead.peek(token::Brace) {
297            let block = input.parse()?;
298            (Some(block), None)
299        } else if lookahead.peek(Token![;]) {
300            let semi_token: Token![;] = input.parse()?;
301            (None, Some(semi_token))
302        } else {
303            return Err(lookahead.error());
304        };
305
306        Ok(Self {
307            attrs,
308            vis,
309            sig,
310            default,
311            semi_token,
312        })
313    }
314}