use proc_macro2::TokenStream;
use quote::format_ident;
use quote::quote;
use quote::quote_spanned;
use quote::ToTokens as _;
use syn::braced;
use syn::parse::Parse;
use syn::parse::ParseStream;
use syn::parse_macro_input;
use syn::parse_quote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned as _;
use syn::token;
use syn::token::Comma;
use syn::token::Enum;
use syn::Attribute;
use syn::Block;
use syn::Field;
use syn::Fields;
use syn::FieldsUnnamed;
use syn::FnArg;
use syn::Ident;
use syn::ItemEnum;
use syn::Signature;
use syn::Token;
use syn::Variant;
use syn::Visibility;
/// Function-like procedural macro `nullable_wrapper! { ... }`.
///
/// The input is an enum declaration optionally followed by a braced list of
/// method declarations. It is parsed into a `NullableWrapper` and then
/// handed to `expand`, whose output becomes the macro expansion. If parsing
/// fails, `parse_macro_input!` returns the parse error as a compile error.
#[proc_macro]
pub fn nullable_wrapper(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let parsed = parse_macro_input!(input as NullableWrapper);
    expand(parsed).into()
}
fn expand(wrapper: NullableWrapper) -> TokenStream {
let NullableWrapper {
attrs,
vis,
enum_token,
ident,
variants,
fns,
} = wrapper;
let (enum_ident, struct_impl) = expand_struct_wrapper(&attrs, &vis, ident, &fns);
let fns = fns.into_iter().map(
|WrapperFn {
attrs,
sig,
default,
..
}| {
let method = &sig.ident;
let args: Punctuated<_, Comma> = sig
.inputs
.iter()
.filter_map(|arg| match arg {
FnArg::Receiver(_) => None,
FnArg::Typed(pat) => Some(&pat.pat),
})
.collect();
let body = default.map_or_else(
|| {
let matchers = variants.iter().map(
|Variant { ident, .. }| quote!(Self::#ident(inner) => inner.#method(#args)),
);
quote!({
match self {
#(#matchers),*
}
})
},
Block::into_token_stream,
);
quote! {
#(#attrs)*
#sig #body
}
},
);
let from_impls = variants.iter().map(|var @ Variant { ident, fields, .. }| {
let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = fields else {
panic!()
};
let Field { ty, .. } = &unnamed[0];
quote_spanned! { var.span() =>
impl From<#ty> for #enum_ident {
fn from(value: #ty) -> Self {
Self::#ident(value)
}
}
}
});
let try_into_impls = variants.iter().map(|var @ Variant { ident, fields, .. }| {
let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = fields else {
panic!()
};
let Field { ty, .. } = &unnamed[0];
quote_spanned! { var.span() =>
impl TryFrom<#enum_ident> for #ty {
type Error = ();
fn try_from(value: #enum_ident) -> Result<Self, Self::Error> {
match value {
#enum_ident::#ident(inner) => Ok(inner),
_ => Err(())
}
}
}
}
});
let expanded = quote! {
#struct_impl
#(#attrs)*
#enum_token #enum_ident {
#variants
}
impl #enum_ident {
#(#fns)*
}
#(#from_impls)*
#(#try_into_impls)*
};
expanded
}
fn expand_struct_wrapper(
attrs: &[Attribute],
vis: &Visibility,
ident: Ident,
fns: &[WrapperFn],
) -> (Ident, TokenStream) {
let Visibility::Public(pub_token) = vis else {
return (ident, TokenStream::new());
};
let enum_ident = format_ident!("{}Inner", ident);
let fns = fns.iter().map(
|WrapperFn {
attrs, vis, sig, ..
}| {
let method = &sig.ident;
let args: Punctuated<_, Comma> = sig
.inputs
.iter()
.filter_map(|arg| match arg {
FnArg::Receiver(_) => None,
FnArg::Typed(pat) => Some(&pat.pat),
})
.collect();
let body = quote!({
self.0.#method(#args)
});
quote! {
#(#attrs)*
#vis #sig #body
}
},
);
let token_stream = quote! {
#(#attrs)*
#[repr(transparent)]
#pub_token struct #ident(#enum_ident);
impl #ident {
#(#fns)*
}
impl<T> From<T> for #ident where #enum_ident: From<T> {
fn from(value: T) -> Self {
Self(#enum_ident::from(value))
}
}
};
(enum_ident, token_stream)
}
struct NullableWrapper {
attrs: Vec<Attribute>,
vis: Visibility,
enum_token: Enum,
ident: Ident,
variants: Punctuated<Variant, Comma>,
fns: Vec<WrapperFn>,
}
impl Parse for NullableWrapper {
    /// Parses an enum declaration, optionally followed by a braced block of
    /// method declarations (each parsed as a `WrapperFn`).
    ///
    /// Variant handling:
    /// * unit variants `A` are rewritten to `A(A)`, wrapping a type of the
    ///   same name;
    /// * variants with exactly one unnamed field are kept unchanged;
    /// * every other variant shape is rejected with a spanned error.
    ///
    /// Generic parameters and where clauses are rejected as well. The
    /// expansion never re-emits them, so dropping them silently would only
    /// surface later as confusing "undeclared name" errors.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let ItemEnum {
            attrs,
            vis,
            enum_token,
            ident,
            generics,
            mut variants,
            ..
        } = input.parse()?;
        if !generics.params.is_empty() {
            return Err(syn::Error::new_spanned(
                &generics,
                "generic parameters are not supported",
            ));
        }
        if let Some(where_clause) = &generics.where_clause {
            return Err(syn::Error::new_spanned(
                where_clause,
                "where clauses are not supported",
            ));
        }
        for variant in &mut variants {
            match variant.fields {
                Fields::Unit => {
                    let name = &variant.ident;
                    variant.fields = Fields::Unnamed(parse_quote!((#name)));
                }
                Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) if unnamed.len() == 1 => {}
                _ => {
                    return Err(syn::Error::new_spanned(
                        &variant,
                        "only unit and new-type variants are supported",
                    ))
                }
            }
        }
        let mut fns = Vec::new();
        if !input.is_empty() {
            let content;
            braced!(content in input);
            while !content.is_empty() {
                fns.push(content.parse()?);
            }
        }
        Ok(NullableWrapper {
            attrs,
            vis,
            enum_token,
            ident,
            variants,
            fns,
        })
    }
}
struct WrapperFn {
pub attrs: Vec<Attribute>,
pub vis: Visibility,
pub sig: Signature,
pub default: Option<Block>,
pub semi_token: Option<Token![;]>,
}
impl Parse for WrapperFn {
    /// Parses outer attributes, a visibility, and a signature. These must be
    /// followed either by a braced default body or by a `;`. Any other token
    /// is reported through the lookahead's error.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let attrs = Attribute::parse_outer(input)?;
        let vis = input.parse::<Visibility>()?;
        let sig = input.parse::<Signature>()?;
        let mut parsed = Self {
            attrs,
            vis,
            sig,
            default: None,
            semi_token: None,
        };
        let lookahead = input.lookahead1();
        if lookahead.peek(token::Brace) {
            parsed.default = Some(input.parse()?);
        } else if lookahead.peek(Token![;]) {
            parsed.semi_token = Some(input.parse()?);
        } else {
            return Err(lookahead.error());
        }
        Ok(parsed)
    }
}