use proc_macro::TokenStream;
use quote::{quote, format_ident};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
Data, DataEnum, DeriveInput, Fields,
Ident, ItemTrait, Path, Result, Token, TraitItem, TraitItemFn,
Type,
};
use heck::ToSnakeCase;
use proc_macro2::TokenStream as TokenStream2;
fn generate_allocator_arms(field_name: &Ident, ty: &Type, arena_type_name: &Ident) -> TokenStream2 {
#[cfg(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo"))]
let mut arms = vec![];
#[cfg(not(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo")))]
let arms: Vec<TokenStream2> = vec![];
#[cfg(feature = "allocator-typed-arena")]
arms.push(quote! {
#arena_type_name::Typed { #field_name, .. } => {
#field_name.alloc(value) as *mut #ty as *mut ()
}
});
#[cfg(feature = "allocator-bumpalo")]
arms.push(quote! {
#arena_type_name::Bumpalo { arena, .. } => {
unsafe {
let arena_ref = &**arena;
arena_ref.alloc(value) as *mut #ty as *mut ()
}
}
});
if arms.is_empty() {
let _ = (field_name, ty, arena_type_name); quote! {
_ => compile_error!("At least one allocator feature must be enabled (allocator-typed-arena or allocator-bumpalo)")
}
} else {
quote! { #(#arms)* }
}
}
/// Emits the hidden `<Enum>ArenaType<'a>` enum recording which allocator backend a
/// builder uses.
///
/// One variant is generated per enabled allocator feature:
/// - `allocator-typed-arena`: `Typed { <one field per typed_arena_fields entry>, _phantom }`;
/// - `allocator-bumpalo`: `Bumpalo { arena: *mut Bump, owned: bool, _phantom }`.
///
/// Both variants carry `PhantomData<&'a ()>`. This keeps the lifetime parameter
/// used even in a typed-arena-only build whose payload types do not mention it;
/// without that field rustc rejects the enum with E0392 ("lifetime parameter is
/// never used").
///
/// With no allocator feature enabled, a `compile_error!` is emitted instead.
fn generate_arena_enum(arena_type_name: &Ident, lifetime: &TokenStream2, typed_arena_fields: &[TokenStream2]) -> TokenStream2 {
    #[cfg(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo"))]
    let mut variants = vec![];
    #[cfg(not(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo")))]
    let variants: Vec<TokenStream2> = vec![];
    #[cfg(feature = "allocator-typed-arena")]
    variants.push(quote! {
        Typed {
            #(#typed_arena_fields,)*
            _phantom: ::core::marker::PhantomData<&#lifetime ()>,
        }
    });
    #[cfg(feature = "allocator-bumpalo")]
    variants.push(quote! {
        Bumpalo {
            arena: *mut ::tagged_dispatch::bumpalo::Bump,
            owned: bool,
            _phantom: ::core::marker::PhantomData<&#lifetime ()>,
        }
    });
    if variants.is_empty() {
        let _ = typed_arena_fields;
        quote! {
            compile_error!("At least one allocator feature must be enabled");
        }
    } else {
        quote! {
            #[doc(hidden)]
            enum #arena_type_name<#lifetime> {
                #(#variants,)*
            }
        }
    }
}

/// Emits the body of the arena builder's `new()` constructor.
///
/// The bumpalo backend is chosen whenever its feature is enabled. Otherwise the
/// typed-arena backend is used. With neither feature, `compile_error!` is emitted.
fn generate_builder_new() -> TokenStream2 {
    #[cfg(feature = "allocator-bumpalo")]
    return quote! {
        Self::with_bumpalo()
    };
    #[cfg(all(feature = "allocator-typed-arena", not(feature = "allocator-bumpalo")))]
    return quote! {
        Self::with_typed_arena()
    };
    #[cfg(not(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo")))]
    quote! {
        compile_error!("At least one allocator feature must be enabled (allocator-typed-arena or allocator-bumpalo)")
    }
}

/// Emits the backend-specific constructors of the arena builder.
///
/// - `allocator-bumpalo`:
///   - `with_bumpalo()` creates a fresh `Bump` marked `owned: true`.
///   - `with_external_bumpalo(&'a Bump)` wraps a caller-provided arena as
///     `owned: false`; `reset()` panics for such a builder.
/// - `allocator-typed-arena`: `with_typed_arena()` creates one empty
///   `typed_arena::Arena` per variant from `typed_arena_inits`.
///
/// NOTE(review): `with_bumpalo` obtains its `Bump` through `Box::leak`, and this file
/// generates no `Drop` impl for the builder, so that arena is never freed.
fn generate_builder_methods(
    builder_name: &Ident,
    arena_type_name: &Ident,
    typed_arena_inits: &[TokenStream2],
    lifetime: &TokenStream2
) -> TokenStream2 {
    #[cfg(not(feature = "allocator-bumpalo"))]
    let _ = (builder_name, lifetime);
    #[cfg(not(feature = "allocator-typed-arena"))]
    let _ = typed_arena_inits;
    #[cfg(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo"))]
    let mut methods = vec![];
    #[cfg(not(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo")))]
    let methods: Vec<TokenStream2> = {
        let _ = (builder_name, arena_type_name, typed_arena_inits, lifetime);
        vec![]
    };
    #[cfg(feature = "allocator-bumpalo")]
    methods.push(quote! {
        pub fn with_bumpalo() -> #builder_name<'static> {
            let arena = Box::leak(Box::new(::tagged_dispatch::bumpalo::Bump::new()));
            #builder_name {
                allocator: #arena_type_name::Bumpalo {
                    arena: arena as *mut _,
                    owned: true,
                    _phantom: ::core::marker::PhantomData,
                },
                _phantom: ::core::marker::PhantomData,
            }
        }
        pub fn with_external_bumpalo(arena: &#lifetime ::tagged_dispatch::bumpalo::Bump) -> Self {
            Self {
                allocator: #arena_type_name::Bumpalo {
                    arena: arena as *const _ as *mut _,
                    owned: false,
                    _phantom: ::core::marker::PhantomData,
                },
                _phantom: ::core::marker::PhantomData,
            }
        }
    });
    #[cfg(feature = "allocator-typed-arena")]
    methods.push(quote! {
        pub fn with_typed_arena() -> Self {
            Self {
                allocator: #arena_type_name::Typed {
                    #(#typed_arena_inits,)*
                    // Marker field that keeps the arena enum's lifetime parameter used.
                    _phantom: ::core::marker::PhantomData,
                },
                _phantom: ::core::marker::PhantomData,
            }
        }
    });
    quote! { #(#methods)* }
}

/// Emits the body of the arena builder's `reset()` method.
///
/// - `Typed` backend: replaces every per-variant arena with a fresh one built from
///   `typed_arena_inits2`, dropping the previous arenas.
/// - `Bumpalo` backend: calls `Bump::reset()` on an owned arena, and panics with
///   "Cannot reset builder using external arena" for an arena supplied via
///   `with_external_bumpalo`.
fn generate_reset_impl(
    arena_type_name: &Ident,
    typed_arena_inits2: &[TokenStream2]
) -> TokenStream2 {
    #[cfg(not(feature = "allocator-typed-arena"))]
    let _ = typed_arena_inits2;
    #[cfg(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo"))]
    let mut arms = vec![];
    #[cfg(not(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo")))]
    let arms: Vec<TokenStream2> = {
        let _ = (arena_type_name, typed_arena_inits2);
        vec![]
    };
    #[cfg(feature = "allocator-typed-arena")]
    arms.push(quote! {
        #arena_type_name::Typed { .. } => {
            self.allocator = #arena_type_name::Typed {
                #(#typed_arena_inits2,)*
                _phantom: ::core::marker::PhantomData,
            };
        }
    });
    #[cfg(feature = "allocator-bumpalo")]
    arms.push(quote! {
        #arena_type_name::Bumpalo { arena, owned: true, .. } => {
            unsafe {
                (&mut **arena).reset();
            }
        }
        #arena_type_name::Bumpalo { owned: false, .. } => {
            panic!("Cannot reset builder using external arena");
        }
    });
    quote! {
        match &mut self.allocator {
            #(#arms)*
        }
    }
}
fn generate_stats_impl(arena_type_name: &Ident) -> TokenStream2 {
#[cfg(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo"))]
let mut arms = vec![];
#[cfg(not(any(feature = "allocator-typed-arena", feature = "allocator-bumpalo")))]
let arms: Vec<TokenStream2> = {
let _ = arena_type_name;
vec![]
};
#[cfg(feature = "allocator-typed-arena")]
arms.push(quote! {
#arena_type_name::Typed { .. } => {
Default::default()
}
});
#[cfg(feature = "allocator-bumpalo")]
arms.push(quote! {
#arena_type_name::Bumpalo { arena, .. } => {
unsafe {
let arena_ref = &**arena;
::tagged_dispatch::ArenaStats {
allocated_bytes: arena_ref.allocated_bytes(),
chunk_capacity: arena_ref.chunk_capacity(),
}
}
}
});
quote! {
match &self.allocator {
#(#arms)*
}
}
}
#[proc_macro_attribute]
pub fn tagged_dispatch(args: TokenStream, input: TokenStream) -> TokenStream {
if let Ok(trait_def) = syn::parse::<ItemTrait>(input.clone()) {
process_trait(trait_def)
} else if let Ok(enum_def) = syn::parse::<DeriveInput>(input) {
process_enum(args, enum_def)
} else {
syn::Error::new(
proc_macro2::Span::call_site(),
"tagged_dispatch can only be applied to traits or enums"
)
.to_compile_error()
.into()
}
}
/// Expands `#[tagged_dispatch]` on a trait.
///
/// The trait is re-emitted with every `#[no_dispatch]` marker removed. It is
/// followed by a hidden `macro_rules!` named `__impl_<trait_snake_case>_dispatch`.
///
/// The enum-side expansion invokes that macro with:
/// - the enum name;
/// - its tag enum;
/// - either the `owned` marker or the enum's lifetime;
/// - the `(Variant, Type)` list.
///
/// The macro then emits an inherent `impl` holding one forwarding method per trait
/// method that was not marked `#[no_dispatch]`.
fn process_trait(mut trait_def: ItemTrait) -> TokenStream {
    let macro_name = format_ident!(
        "__impl_{}_dispatch",
        trait_def.ident.to_string().to_snake_case()
    );

    let mut dispatch_impls = Vec::new();
    for item in trait_def.items.iter_mut() {
        if let TraitItem::Fn(method) = item {
            let attr_count = method.attrs.len();
            method.attrs.retain(|attr| !attr.path().is_ident("no_dispatch"));
            // Nothing was removed => the method was not opted out of dispatch.
            if method.attrs.len() == attr_count {
                dispatch_impls.push(generate_dispatch_method(method));
            }
        }
    }

    TokenStream::from(quote! {
        #trait_def
        #[doc(hidden)]
        macro_rules! #macro_name {
            (
                $enum_name:ident,
                $enum_type_name:ident,
                owned,
                [$(($variant:ident, $type:ty)),* $(,)?]
            ) => {
                impl $enum_name {
                    #(#dispatch_impls)*
                }
            };
            (
                $enum_name:ident,
                $enum_type_name:ident,
                $lifetime:lifetime,
                [$(($variant:ident, $type:ty)),* $(,)?]
            ) => {
                impl<$lifetime> $enum_name<$lifetime> {
                    #(#dispatch_impls)*
                }
            };
        }
    })
}
fn process_enum(args: TokenStream, mut enum_def: DeriveInput) -> TokenStream {
let parsed = parse_macro_input!(args as TraitListWithFlags);
let enum_name = &enum_def.ident;
let vis = &enum_def.vis;
let generics = &enum_def.generics;
let has_lifetime = !generics.lifetimes().collect::<Vec<_>>().is_empty();
let lifetime = generics.lifetimes().next().map(|lt| <.lifetime);
let variants = if let Data::Enum(ref mut data_enum) = enum_def.data {
process_enum_variants(data_enum)
} else {
return syn::Error::new_spanned(
enum_def,
"tagged_dispatch can only be applied to enums"
)
.to_compile_error()
.into();
};
if has_lifetime {
generate_arena_impl(enum_name, vis, lifetime.unwrap(), &variants, &parsed.traits, &parsed.flags)
} else {
generate_owned_impl(enum_name, vis, &variants, &parsed.traits, &parsed.flags)
}
}
/// Normalises every variant of `data_enum` to a single unnamed payload. Returns
/// `(variant name, payload type)` pairs in declaration order.
///
/// - A unit variant `Foo` is treated as carrying a value of type `Foo` and is
///   rewritten in place to `Foo(Foo)`.
/// - A single-field tuple variant keeps its field type.
///
/// # Panics
/// On any other variant shape (named fields, or more than one tuple field).
fn process_enum_variants(data_enum: &mut DataEnum) -> Vec<(Ident, Type)> {
    let mut pairs = Vec::with_capacity(data_enum.variants.len());
    for variant in data_enum.variants.iter_mut() {
        let name = variant.ident.clone();
        if let Fields::Unit = variant.fields {
            let payload: Type = syn::parse_quote!(#name);
            variant.fields = Fields::Unnamed(syn::parse_quote!((#payload)));
            pairs.push((name, payload));
            continue;
        }
        let payload = match &variant.fields {
            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => fields.unnamed[0].ty.clone(),
            _ => panic!("Each variant must either be a unit variant or have exactly one unnamed field"),
        };
        pairs.push((name, payload));
    }
    pairs
}
/// Expands an enum without a lifetime parameter into an owned, heap-backed handle.
///
/// Generated items (named after `enum_name`):
/// - `#[repr(transparent)] struct <Enum>(TaggedPtr<()>)`. Each value owns one
///   `Box`-allocated payload, with the variant's declaration index stored as the
///   pointer's tag.
/// - `#[repr(u8)] enum <Enum>Type`, a field-less enum with one variant per input
///   variant, returned by `tag_type()`.
/// - One snake_case constructor per variant and a `From<Payload>` impl per variant.
///   Two variants with the same payload type would produce conflicting `From` impls.
/// - `Drop`, which reboxes and drops the payload.
/// - `Clone`, which clones the payload, so every payload type must implement `Clone`.
/// - `Debug`, `PartialEq + Eq` and `PartialOrd + Ord` impls, each gated by `flags`.
/// - One `__impl_<trait>_dispatch!(.., owned, ..)` invocation per entry of `traits`,
///   plus a compile-time check that every payload type implements every trait.
/// - A compile-time assertion that the handle is 8 bytes.
///
/// NOTE(review): `Eq`/`Ord` delegate to `TaggedPtr`'s own comparisons. These presumably
/// compare pointer values (identity rather than payload equality), in which case a
/// clone never equals its source — confirm against `TaggedPtr`.
fn generate_owned_impl(
    enum_name: &Ident,
    vis: &syn::Visibility,
    variants: &[(Ident, Type)],
    traits: &[Path],
    flags: &TraitGenerationFlags,
) -> TokenStream {
    // Name of the generated tag enum, e.g. `Shape` -> `ShapeType`.
    let enum_type_name = format_ident!("{}Type", enum_name);
    // One constructor per variant (snake_case name). Each boxes the payload and stores
    // the raw pointer tagged with the variant's declaration index.
    let constructors = variants.iter().enumerate().map(|(i, (variant, ty))| {
        let tag = i as u8;
        let method_name = format_ident!("{}", variant.to_string().to_snake_case());
        quote! {
            #[doc = concat!("Create a `", stringify!(#variant), "` variant")]
            #[inline]
            pub fn #method_name(value: #ty) -> Self {
                let boxed = Box::new(value);
                let ptr = Box::into_raw(boxed) as *mut ();
                Self(::tagged_dispatch::TaggedPtr::new(ptr, #tag))
            }
        }
    });
    // `From<Payload>` per variant, using the same boxing and tagging as the constructors.
    let from_impls = variants.iter().enumerate().map(|(i, (_variant, ty))| {
        let tag = i as u8;
        quote! {
            impl From<#ty> for #enum_name {
                fn from(value: #ty) -> Self {
                    let boxed = Box::new(value);
                    let ptr = Box::into_raw(boxed) as *mut ();
                    Self(::tagged_dispatch::TaggedPtr::new(ptr, #tag))
                }
            }
        }
    });
    // Drop arms: recover the payload's concrete type from the tag and rebox it, so the
    // payload's destructor runs and its allocation is released.
    let drop_arms = variants.iter().enumerate().map(|(i, (_variant, ty))| {
        let tag = i as u8;
        quote! {
            #tag => {
                let ptr = self.0.untagged_ptr() as *mut #ty;
                drop(Box::from_raw(ptr));
            }
        }
    });
    // Clone arms: clone the payload through a typed pointer and re-wrap it with the
    // variant's constructor.
    // NOTE(review): Drop reads `untagged_ptr()` while Clone (and the dispatch methods)
    // read `ptr()`; confirm both return the address with the tag bits removed.
    let clone_arms = variants.iter().enumerate().map(|(i, (variant, ty))| {
        let method_name = format_ident!("{}", variant.to_string().to_snake_case());
        let tag = i as u8;
        quote! {
            #tag => {
                let ptr = self.0.ptr() as *const #ty;
                let cloned = (*ptr).clone();
                Self::#method_name(cloned)
            }
        }
    });
    // Tag-enum variants in declaration order, so each implicit discriminant equals the
    // tag stored by the constructors (this is what makes `tag_type`'s transmute valid).
    let enum_variants = variants.iter().map(|(variant, _)| {
        quote! { #variant }
    });
    // `(Variant, Type)` pairs handed to every trait's dispatch macro.
    let variant_list: Vec<_> = variants.iter().map(|(variant, ty)| {
        quote! { (#variant, #ty) }
    }).collect();
    // One dispatch-macro invocation per listed trait. The macro is invoked by a bare
    // name derived from the trait path's last segment, so it must be reachable under
    // that name where the enum is declared.
    let dispatch_invocations = traits.iter().map(|trait_path| {
        let trait_name = &trait_path.segments.last().unwrap().ident;
        let macro_name = format_ident!("__impl_{}_dispatch", trait_name.to_string().to_snake_case());
        let variant_list = variant_list.clone();
        quote! {
            #macro_name!(#enum_name, #enum_type_name, owned, [#(#variant_list),*]);
        }
    });
    // Compile-time proof that every payload type implements every listed trait.
    let trait_checks = traits.iter().flat_map(|trait_path| {
        variants.iter().map(move |(_, ty)| {
            quote! {
                const _: fn() = || {
                    fn assert_impl<T: #trait_path>() {}
                    assert_impl::<#ty>();
                };
            }
        })
    });
    // Debug prints `<Enum>::<Variant>` from the tag only; the payload is not formatted.
    let debug_impl = if flags.should_generate_debug() {
        quote! {
            impl ::core::fmt::Debug for #enum_name {
                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    write!(f, "{}::{:?}", stringify!(#enum_name), self.tag_type())
                }
            }
        }
    } else {
        quote! {}
    };
    // Equality delegates to the wrapped `TaggedPtr`.
    let eq_impl = if flags.should_generate_eq() {
        quote! {
            impl ::core::cmp::PartialEq for #enum_name {
                fn eq(&self, other: &Self) -> bool {
                    self.0.eq(&other.0)
                }
            }
            impl ::core::cmp::Eq for #enum_name {}
        }
    } else {
        quote! {}
    };
    // Ordering delegates to the wrapped `TaggedPtr`.
    let ord_impl = if flags.should_generate_ord() {
        quote! {
            impl ::core::cmp::PartialOrd for #enum_name {
                fn partial_cmp(&self, other: &Self) -> Option<::core::cmp::Ordering> {
                    self.0.partial_cmp(&other.0)
                }
            }
            impl ::core::cmp::Ord for #enum_name {
                fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
                    self.0.cmp(&other.0)
                }
            }
        }
    } else {
        quote! {}
    };
    let output = quote! {
        #[repr(transparent)]
        #vis struct #enum_name(::tagged_dispatch::TaggedPtr<()>);
        #[repr(u8)]
        #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
        #vis enum #enum_type_name {
            #(#enum_variants,)*
        }
        impl #enum_name {
            #(#constructors)*
            #[inline(always)]
            pub fn tag_type(&self) -> #enum_type_name {
                unsafe { ::core::mem::transmute(self.0.tag()) }
            }
        }
        #debug_impl
        #eq_impl
        #ord_impl
        #(#from_impls)*
        impl Drop for #enum_name {
            fn drop(&mut self) {
                if self.0.is_null() {
                    return;
                }
                unsafe {
                    match self.0.tag() {
                        #(#drop_arms)*
                        _ => unreachable!("Invalid tag"),
                    }
                }
            }
        }
        impl Clone for #enum_name {
            fn clone(&self) -> Self {
                unsafe {
                    match self.0.tag() {
                        #(#clone_arms)*
                        _ => unreachable!("Invalid tag"),
                    }
                }
            }
        }
        #(#dispatch_invocations)*
        #(#trait_checks)*
        const _: () = assert!(::core::mem::size_of::<#enum_name>() == 8);
    };
    TokenStream::from(output)
}
/// Expands an enum with a lifetime parameter into an arena-backed, `Copy` handle.
///
/// Generated items (named after `enum_name`):
/// - `struct <Enum><'a>(TaggedPtr<()>, PhantomData<&'a ()>)`. The payload lives in an
///   arena managed by a builder, and the variant index is stored as the pointer's
///   tag. The handle is `Copy` and has no `Drop` impl.
/// - `#[repr(u8)] enum <Enum>Type`, the field-less tag enum returned by `tag_type()`.
/// - The hidden `<Enum>ArenaType<'a>` backend enum.
/// - The `<Enum>ArenaBuilder<'a>` struct, with `new`, backend-specific constructors,
///   `reset`, `clear`, `stats`, and one snake_case allocation method per variant.
///   These methods take `&'a self`, so handles cannot outlive the builder.
/// - `Debug`, `PartialEq + Eq` and `PartialOrd + Ord` impls, each gated by `flags`.
/// - One `__impl_<trait>_dispatch!(.., 'a, ..)` invocation per entry of `traits`.
/// - A compile-time check that every payload type implements every listed trait.
/// - A compile-time assertion that the handle is 8 bytes.
fn generate_arena_impl(
    enum_name: &Ident,
    vis: &syn::Visibility,
    lifetime: &syn::Lifetime,
    variants: &[(Ident, Type)],
    traits: &[Path],
    flags: &TraitGenerationFlags,
) -> TokenStream {
    // Companion type names, e.g. `Node` -> `NodeType`, `NodeArenaBuilder`, `NodeArenaType`.
    let enum_type_name = format_ident!("{}Type", enum_name);
    let builder_name = format_ident!("{}ArenaBuilder", enum_name);
    let arena_type_name = format_ident!("{}ArenaType", enum_name);
    // One `typed_arena::Arena<Payload>` field per variant, named `<variant>_arena`.
    // NOTE(review): this names `::typed_arena` directly, while the bumpalo backend goes
    // through `::tagged_dispatch::bumpalo`. The user crate presumably needs its own
    // `typed-arena` dependency — confirm.
    let typed_arena_fields: Vec<_> = variants.iter().map(|(variant, ty)| {
        let field_name = format_ident!("{}_arena", variant.to_string().to_snake_case());
        quote! { #field_name: ::typed_arena::Arena<#ty> }
    }).collect();
    // Matching `<variant>_arena: Arena::new()` initialisers. The original is used by the
    // builder constructors and the copy by `reset`.
    let typed_arena_inits: Vec<_> = variants.iter().map(|(variant, _ty)| {
        let field_name = format_ident!("{}_arena", variant.to_string().to_snake_case());
        quote! { #field_name: ::typed_arena::Arena::new() }
    }).collect();
    let typed_arena_inits2 = typed_arena_inits.clone();
    // One allocation method per variant on the builder. It allocates through whichever
    // backend is active and tags the resulting pointer with the variant index.
    let builder_methods = variants.iter().enumerate().map(|(i, (variant, ty))| {
        let tag = i as u8;
        let method_name = format_ident!("{}", variant.to_string().to_snake_case());
        let field_name = format_ident!("{}_arena", variant.to_string().to_snake_case());
        let allocator_arms = generate_allocator_arms(&field_name, ty, &arena_type_name);
        quote! {
            #[doc = concat!("Create a `", stringify!(#variant), "` variant in the arena")]
            #[inline]
            pub fn #method_name(&#lifetime self, value: #ty) -> #enum_name<#lifetime> {
                let ptr = match &self.allocator {
                    #allocator_arms
                };
                #enum_name(::tagged_dispatch::TaggedPtr::new(ptr, #tag), ::core::marker::PhantomData)
            }
        }
    });
    // Tag-enum variants in declaration order, so each implicit discriminant equals the
    // tag stored by the allocation methods (this is what makes `tag_type`'s transmute valid).
    let enum_variants = variants.iter().map(|(variant, _)| {
        quote! { #variant }
    });
    // `(Variant, Type)` pairs handed to every trait's dispatch macro.
    let variant_list: Vec<_> = variants.iter().map(|(variant, ty)| {
        quote! { (#variant, #ty) }
    }).collect();
    // One dispatch-macro invocation per listed trait, using the lifetime arm of the macro.
    let dispatch_invocations = traits.iter().map(|trait_path| {
        let trait_name = &trait_path.segments.last().unwrap().ident;
        let macro_name = format_ident!("__impl_{}_dispatch", trait_name.to_string().to_snake_case());
        let variant_list = variant_list.clone();
        quote! {
            #macro_name!(#enum_name, #enum_type_name, #lifetime, [#(#variant_list),*]);
        }
    });
    // Compile-time proof that every payload type implements every listed trait.
    let trait_checks = traits.iter().flat_map(|trait_path| {
        variants.iter().map(move |(_, ty)| {
            quote! {
                const _: fn() = || {
                    fn assert_impl<T: #trait_path>() {}
                    assert_impl::<#ty>();
                };
            }
        })
    });
    // Backend-dependent pieces; which variants and arms exist depends on the enabled
    // allocator features.
    let lifetime_tokens = quote! { #lifetime };
    let arena_enum_definition = generate_arena_enum(&arena_type_name, &lifetime_tokens, &typed_arena_fields);
    let builder_new_impl = generate_builder_new();
    let builder_specific_methods = generate_builder_methods(&builder_name, &arena_type_name, &typed_arena_inits, &lifetime_tokens);
    let reset_impl = generate_reset_impl(&arena_type_name, &typed_arena_inits2);
    let stats_impl = generate_stats_impl(&arena_type_name);
    // Debug prints `<Enum>::<Variant>` from the tag only; the payload is not formatted.
    let debug_impl = if flags.should_generate_debug() {
        quote! {
            impl<#lifetime> ::core::fmt::Debug for #enum_name<#lifetime> {
                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    write!(f, "{}::{:?}", stringify!(#enum_name), self.tag_type())
                }
            }
        }
    } else {
        quote! {}
    };
    // Equality delegates to the wrapped `TaggedPtr`.
    let eq_impl = if flags.should_generate_eq() {
        quote! {
            impl<#lifetime> ::core::cmp::PartialEq for #enum_name<#lifetime> {
                fn eq(&self, other: &Self) -> bool {
                    self.0.eq(&other.0)
                }
            }
            impl<#lifetime> ::core::cmp::Eq for #enum_name<#lifetime> {}
        }
    } else {
        quote! {}
    };
    // Ordering delegates to the wrapped `TaggedPtr`.
    let ord_impl = if flags.should_generate_ord() {
        quote! {
            impl<#lifetime> ::core::cmp::PartialOrd for #enum_name<#lifetime> {
                fn partial_cmp(&self, other: &Self) -> Option<::core::cmp::Ordering> {
                    self.0.partial_cmp(&other.0)
                }
            }
            impl<#lifetime> ::core::cmp::Ord for #enum_name<#lifetime> {
                fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
                    self.0.cmp(&other.0)
                }
            }
        }
    } else {
        quote! {}
    };
    let output = quote! {
        #[repr(transparent)]
        #vis struct #enum_name<#lifetime>(
            ::tagged_dispatch::TaggedPtr<()>,
            ::core::marker::PhantomData<&#lifetime ()>
        );
        #[repr(u8)]
        #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
        #vis enum #enum_type_name {
            #(#enum_variants,)*
        }
        #arena_enum_definition
        #vis struct #builder_name<#lifetime> {
            allocator: #arena_type_name<#lifetime>,
            _phantom: ::core::marker::PhantomData<&#lifetime ()>,
        }
        impl<#lifetime> #builder_name<#lifetime> {
            pub fn new() -> Self {
                #builder_new_impl
            }
            #builder_specific_methods
            pub fn reset(&mut self) {
                #reset_impl
            }
            pub fn clear(&mut self) {
                self.reset();
            }
            pub fn stats(&self) -> ::tagged_dispatch::ArenaStats {
                #stats_impl
            }
            #(#builder_methods)*
        }
        impl<#lifetime> #enum_name<#lifetime> {
            pub fn arena_builder() -> #builder_name<#lifetime> {
                #builder_name::new()
            }
            #[inline(always)]
            pub fn tag_type(&self) -> #enum_type_name {
                unsafe { ::core::mem::transmute(self.0.tag()) }
            }
        }
        impl<#lifetime> Copy for #enum_name<#lifetime> {}
        impl<#lifetime> Clone for #enum_name<#lifetime> {
            #[inline(always)]
            fn clone(&self) -> Self {
                *self
            }
        }
        #debug_impl
        #eq_impl
        #ord_impl
        #(#dispatch_invocations)*
        #(#trait_checks)*
        const _: () = assert!(::core::mem::size_of::<#enum_name<'static>>() == 8);
    };
    TokenStream::from(output)
}
fn generate_dispatch_method(method: &TraitItemFn) -> proc_macro2::TokenStream {
let method_name = &method.sig.ident;
let inputs = &method.sig.inputs;
let output = &method.sig.output;
let args: Vec<_> = inputs.iter().skip(1).collect();
let arg_names: Vec<_> = args.iter().filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
Some(&pat_ident.ident)
} else {
None
}
} else {
None
}
}).collect();
quote! {
#[inline]
pub fn #method_name(&self #(, #args)*) #output {
unsafe {
match self.tag_type() {
$(
$enum_type_name::$variant => {
let ptr = &*(self.0.ptr() as *const $type);
ptr.#method_name(#(#arg_names),*)
}
)*
}
}
}
}
}
/// Opt-out switches for the standard trait impls generated alongside a dispatch enum.
///
/// Every field defaults to `false`, meaning the corresponding impl is generated.
#[derive(Debug, Clone, Default)]
struct TraitGenerationFlags {
    /// Set by `no_debug` (or implied by `no_traits`): suppress `Debug`.
    no_debug: bool,
    /// Set by `no_eq` or `no_cmp`: suppress `PartialEq`/`Eq`, and therefore ordering.
    no_eq: bool,
    /// Set by `no_ord` or `no_cmp`: suppress `PartialOrd`/`Ord`.
    no_ord: bool,
    /// Set by `no_traits`: suppress every optional impl.
    no_traits: bool,
}

impl TraitGenerationFlags {
    /// Whether a `Debug` impl should be emitted.
    fn should_generate_debug(&self) -> bool {
        !(self.no_traits || self.no_debug)
    }

    /// Whether `PartialEq` and `Eq` impls should be emitted.
    fn should_generate_eq(&self) -> bool {
        !(self.no_traits || self.no_eq)
    }

    /// Whether `PartialOrd` and `Ord` impls should be emitted. `Ord` requires `Eq`,
    /// so disabling equality disables ordering as well.
    fn should_generate_ord(&self) -> bool {
        self.should_generate_eq() && !self.no_ord
    }
}
/// Parsed arguments of `#[tagged_dispatch(...)]` on an enum.
struct TraitListWithFlags {
    /// Trait paths to dispatch, in the order written.
    traits: Vec<Path>,
    /// Opt-out switches for the generated standard trait impls.
    flags: TraitGenerationFlags,
}

impl Parse for TraitListWithFlags {
    /// Parses a comma-separated list (trailing comma allowed).
    ///
    /// Each entry is either a recognised flag (`no_debug`, `no_eq`, `no_ord`,
    /// `no_cmp`, `no_traits`) or a trait path. `no_cmp` sets both `no_eq` and
    /// `no_ord`. An entry that is not a path expression is an error.
    fn parse(input: ParseStream) -> Result<Self> {
        let mut result = TraitListWithFlags {
            traits: Vec::new(),
            flags: TraitGenerationFlags::default(),
        };
        if input.is_empty() {
            return Ok(result);
        }
        for entry in Punctuated::<syn::Expr, Token![,]>::parse_terminated(input)? {
            let path = match entry {
                syn::Expr::Path(expr_path) => expr_path.path,
                other => {
                    return Err(syn::Error::new_spanned(
                        other,
                        "Expected trait name or flag (no_debug, no_eq, no_ord, no_cmp, no_traits)",
                    ));
                }
            };
            // Flags are single bare identifiers; anything else is a trait path.
            let flag = path.get_ident().map(|ident| ident.to_string());
            match flag.as_deref() {
                Some("no_debug") => result.flags.no_debug = true,
                Some("no_eq") => result.flags.no_eq = true,
                Some("no_ord") => result.flags.no_ord = true,
                Some("no_cmp") => {
                    result.flags.no_eq = true;
                    result.flags.no_ord = true;
                }
                Some("no_traits") => result.flags.no_traits = true,
                _ => result.traits.push(path),
            }
        }
        Ok(result)
    }
}