// nullable-utils-macros 0.2.0
//
// Helpers for working with James Shore's Nullables (proc-macro crate)
// SPDX-FileCopyrightText: 2024 Markus Haug (Korrat)
//
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: MIT

//! This serves as a companion crate to [`nullable-utils`] and provides proc-macros to support working with Nullables.
//!
//! [`nullable-utils`]: https://crates.io/crates/nullable-utils

use proc_macro2::TokenStream;
use quote::format_ident;
use quote::quote;
use quote::quote_spanned;
use quote::ToTokens as _;
use syn::braced;
use syn::parse::Parse;
use syn::parse::ParseStream;
use syn::parse_macro_input;
use syn::parse_quote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned as _;
use syn::token;
use syn::token::Comma;
use syn::token::Enum;
use syn::Attribute;
use syn::Block;
use syn::Field;
use syn::Fields;
use syn::FieldsUnnamed;
use syn::FnArg;
use syn::Ident;
use syn::ItemEnum;
use syn::Signature;
use syn::Token;
use syn::Variant;
use syn::Visibility;

/// Create a wrapper enum for switching (internal) implementations efficiently.
///
/// When integrating with third-party infrastructure components (HTTP clients, database clients, …), Nullables make use
/// of embedded stubs. This macro generates a wrapper enum for seamlessly switching between the real implementation and
/// the embedded stub.
///
/// The macro expects input in the form of an enum declaration, optionally followed by a braced block of function
/// declarations (like in traits). Each enum variant must be either a newtype variant or a unit variant; a unit variant
/// `Foo` is transformed into the newtype variant `Foo(Foo)`, wrapping a type of the same name. Any other variant shape
/// is rejected with a compile error.
///
/// For each variant wrapping a type `T`, the macro implements [`From<T>`] for the generated enum and
/// [`TryFrom`]`<Enum>` for `T` (with `()` as the error type), which also makes [`TryInto<T>`] available on the enum.
///
/// For each function declaration, the generated enum gets a method that automatically forwards the call to whichever
/// variant is currently held. If the declaration has a default body, that body is used instead of the generated one.
///
/// If the enum is declared `pub`, the public type becomes a `#[repr(transparent)]` struct of the given name wrapping
/// a private enum named `<Name>Inner`. The struct forwards every declared method to the inner enum and implements
/// `From<T>` for every `T` the inner enum can be converted from.
#[proc_macro]
pub fn nullable_wrapper(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let wrapper = parse_macro_input!(input as NullableWrapper);
    let expanded = expand(wrapper);
    proc_macro::TokenStream::from(expanded)
}

fn expand(wrapper: NullableWrapper) -> TokenStream {
    let NullableWrapper {
        attrs,
        vis,
        enum_token,
        ident,
        variants,
        fns,
    } = wrapper;

    let (enum_ident, struct_impl) = expand_struct_wrapper(&attrs, &vis, ident, &fns);

    let fns = fns.into_iter().map(
        |WrapperFn {
             attrs,
             sig,
             default,
             ..
         }| {
            let method = &sig.ident;
            let args: Punctuated<_, Comma> = sig
                .inputs
                .iter()
                .filter_map(|arg| match arg {
                    FnArg::Receiver(_) => None,
                    FnArg::Typed(pat) => Some(&pat.pat),
                })
                .collect();

            let body = default.map_or_else(
                || {
                    let matchers = variants.iter().map(
                        |Variant { ident, .. }| quote!(Self::#ident(inner) => inner.#method(#args)),
                    );

                    quote!({
                        match self {
                            #(#matchers),*
                        }
                    })
                },
                Block::into_token_stream,
            );

            quote! {
                #(#attrs)*
                #sig #body
            }
        },
    );

    let from_impls = variants.iter().map(|var @ Variant { ident, fields, .. }| {
        // TODO handle variant attrs
        let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = fields else {
            panic!()
        };
        let Field { ty, .. } = &unnamed[0];

        quote_spanned! { var.span() =>
            impl From<#ty> for #enum_ident {
                fn from(value: #ty) -> Self {
                    Self::#ident(value)
                }
            }
        }
    });

    let try_into_impls = variants.iter().map(|var @ Variant { ident, fields, .. }| {
        // TODO handle variant attrs
        let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = fields else {
            panic!()
        };
        let Field { ty, .. } = &unnamed[0];

        quote_spanned! { var.span() =>
            impl TryFrom<#enum_ident> for #ty {
                type Error = ();

                fn try_from(value: #enum_ident) -> Result<Self, Self::Error> {
                    match value {
                        #enum_ident::#ident(inner) => Ok(inner),
                        _ => Err(())
                    }
                }
            }
        }
    });

    let expanded = quote! {
        #struct_impl

        #(#attrs)*
        #enum_token #enum_ident {
            #variants
        }

        impl #enum_ident {
            #(#fns)*
        }

        #(#from_impls)*

        #(#try_into_impls)*
    };
    expanded
}

fn expand_struct_wrapper(
    attrs: &[Attribute],
    vis: &Visibility,
    ident: Ident,
    fns: &[WrapperFn],
) -> (Ident, TokenStream) {
    let Visibility::Public(pub_token) = vis else {
        return (ident, TokenStream::new());
    };

    let enum_ident = format_ident!("{}Inner", ident);

    let fns = fns.iter().map(
        |WrapperFn {
             attrs, vis, sig, ..
         }| {
            let method = &sig.ident;
            let args: Punctuated<_, Comma> = sig
                .inputs
                .iter()
                .filter_map(|arg| match arg {
                    FnArg::Receiver(_) => None,
                    FnArg::Typed(pat) => Some(&pat.pat),
                })
                .collect();

            let body = quote!({
                self.0.#method(#args)
            });

            quote! {
                #(#attrs)*
                #vis #sig #body
            }
        },
    );

    let token_stream = quote! {
        #(#attrs)*
        #[repr(transparent)]
        #pub_token struct #ident(#enum_ident);

        impl #ident {
            #(#fns)*
        }

        impl<T> From<T> for #ident where #enum_ident: From<T> {
            fn from(value: T) -> Self {
                Self(#enum_ident::from(value))
            }
        }
    };

    (enum_ident, token_stream)
}

/// Parsed input of [`nullable_wrapper`]: an enum declaration plus an optional braced block of
/// method declarations.
struct NullableWrapper {
    /// Outer attributes of the enum declaration.
    attrs: Vec<Attribute>,
    /// Visibility the user declared for the wrapper.
    vis: Visibility,
    /// The `enum` keyword token, re-emitted in the expansion.
    enum_token: Enum,
    /// Name of the wrapper type.
    ident: Ident,
    /// Variants, all normalized to single-field tuple (newtype) form by the parser.
    variants: Punctuated<Variant, Comma>,
    /// Method declarations from the optional trailing braced block.
    fns: Vec<WrapperFn>,
}

// TODO parse syntax ourselves instead of reusing syn types?
impl Parse for NullableWrapper {
    /// Parse `enum Name { Variants… }`, optionally followed by `{ fn …; … }`.
    ///
    /// # Errors
    ///
    /// Fails on generic parameters, where clauses, explicit discriminants, any variant that is
    /// neither a unit nor a single-field tuple variant, and malformed method declarations.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let ItemEnum {
            attrs,
            vis,
            enum_token,
            ident,
            generics,
            mut variants,
            ..
        } = input.parse()?;

        // The expansion cannot carry generics through yet. Dropping them silently would only
        // surface later as confusing "undeclared type/lifetime" errors, so reject them up front.
        if !generics.params.is_empty() {
            return Err(syn::Error::new_spanned(
                &generics,
                "generic wrapper enums are not supported",
            ));
        }
        if let Some(where_clause) = &generics.where_clause {
            return Err(syn::Error::new_spanned(
                where_clause,
                "where clauses on wrapper enums are not supported",
            ));
        }

        for variant in &mut variants {
            // Unit variants are turned into newtype variants below, and those cannot carry an
            // explicit discriminant without a `repr`. Report this clearly instead of emitting invalid code.
            if let Some((_, discriminant)) = &variant.discriminant {
                return Err(syn::Error::new_spanned(
                    discriminant,
                    "explicit discriminants are not supported",
                ));
            }

            match variant.fields {
                Fields::Unit => {
                    // `Foo` becomes `Foo(Foo)`: the variant wraps a type of the same name.
                    let name = &variant.ident;
                    variant.fields = Fields::Unnamed(parse_quote!((#name)));
                }
                Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) if unnamed.len() == 1 => {}
                _ => {
                    return Err(syn::Error::new_spanned(
                        &variant,
                        "only unit and new-type variants are supported",
                    ))
                }
            }
        }
        // TODO support generics (rejected above for now) and preserve the enum's brace token

        let mut fns = Vec::new();
        if !input.is_empty() {
            let content;
            braced!(content in input);

            while !content.is_empty() {
                fns.push(content.parse()?);
            }
        }

        Ok(NullableWrapper {
            attrs,
            vis,
            enum_token,
            ident,
            variants,
            fns,
        })
    }
}

/// One method declaration from the optional block that follows the wrapper enum.
///
/// Mirrors a trait item: attributes, visibility and signature, then either a default body or a
/// terminating `;`.
struct WrapperFn {
    /// Outer attributes (including doc comments) of the declaration.
    pub attrs: Vec<Attribute>,
    /// Declared visibility of the method.
    pub vis: Visibility,
    /// The method signature.
    pub sig: Signature,
    /// The default body, if one was given instead of `;`.
    pub default: Option<Block>,
    /// The terminating `;`, present exactly when no default body was given.
    pub semi_token: Option<Token![;]>,
}

impl Parse for WrapperFn {
    /// Parse `#[attrs] vis fn name(…) -> Ret` followed by either a `{ … }` body or `;`.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let attrs = Attribute::parse_outer(input)?;
        let vis = input.parse::<Visibility>()?;
        let sig = input.parse::<Signature>()?;

        let mut default = None;
        let mut semi_token = None;

        // Peek through a lookahead so that an unexpected token reports the accepted alternatives.
        let lookahead = input.lookahead1();
        if lookahead.peek(token::Brace) {
            default = Some(input.parse::<Block>()?);
        } else if lookahead.peek(Token![;]) {
            semi_token = Some(input.parse::<Token![;]>()?);
        } else {
            return Err(lookahead.error());
        }

        Ok(WrapperFn {
            attrs,
            vis,
            sig,
            default,
            semi_token,
        })
    }
}