use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse_macro_input, visit::Visit, DeriveInput};
/// Syntax-tree visitor that remembers whether it has seen any identifier
/// spelled exactly `Label`.
#[derive(Default)]
struct LabelFinder {
    /// Becomes `true` once an identifier equal to `Label` has been visited;
    /// it is never reset afterwards.
    found: bool,
}
impl<'ast> syn::visit::Visit<'ast> for LabelFinder {
    /// Raises the `found` flag when the visited identifier is `Label`.
    ///
    /// Identifiers are leaves of the syntax tree, so there is nothing further
    /// to recurse into here.
    fn visit_ident(&mut self, ident: &'ast syn::Ident) {
        self.found |= ident == "Label";
    }
}
impl LabelFinder {
    /// Walks the whole type `input` and reports whether any identifier in it
    /// is `Label` (for example in `Vec<Label>` or `Option<Label>`).
    fn find(input: &syn::Type) -> bool {
        let mut visitor = Self { found: false };
        visitor.visit_type(input);
        visitor.found
    }
}
fn variant_to_map_labels_case(
special_types: &LabelTypes,
arg_fn_name: &syn::Ident,
enum_name: &syn::Ident,
variant: &syn::Variant,
) -> proc_macro2::TokenStream {
let syn::Fields::Named(fields) = &variant.fields else {
return syn::Error::new_spanned(
&variant.fields,
"The `Gate` enum only supports named fields",
)
.to_compile_error();
};
let variant_name = &variant.ident;
let fields_names = fields.named.iter().map(|field| &field.ident);
let fields_values = fields.named.iter().map(|field| {
special_types.map_labels(
field.ident.as_ref().expect("Named field with no name."),
&field.ty,
arg_fn_name,
)
});
quote! {
#enum_name::#variant_name { #(#fields_names),*} => #enum_name::#variant_name { #(#fields_values),* },
}
}
fn variant_to_for_each_label_case(
special_types: &LabelTypes,
arg_fn_name: &syn::Ident,
enum_name: &syn::Ident,
variant: &syn::Variant,
) -> proc_macro2::TokenStream {
let syn::Fields::Named(fields) = &variant.fields else {
return syn::Error::new_spanned(
&variant.fields,
"The `Gate` enum only supports named fields",
)
.to_compile_error();
};
let variant_name = &variant.ident;
let fields_names = fields.named.iter().map(|field| &field.ident);
let fields_ops = fields.named.iter().map(|field| {
special_types.for_each_label(
field.ident.as_ref().expect("Named field with no name."),
&field.ty,
arg_fn_name,
)
});
quote! {
#enum_name::#variant_name { #(#fields_names),*} => { #(#fields_ops)* },
}
}
/// The field types the derive knows how to handle.
///
/// Field types are compared structurally (`==`) against these pre-parsed
/// types. A spelling such as `std::vec::Vec<Label>` is therefore not
/// recognised as `Vec<Label>`.
struct LabelTypes {
    /// Parsed form of the type `Label`.
    label_type: syn::Type,
    /// Parsed form of the type `Vec<Label>`.
    vec_label_type: syn::Type,
}
impl Default for LabelTypes {
    /// Parses `Label` and `Vec<Label>` into `syn::Type` values.
    ///
    /// # Panics
    /// Panics only if those constant type strings fail to parse. That would
    /// be a bug in this macro, not in user input.
    fn default() -> Self {
        Self {
            label_type: syn::parse_str("Label").expect("Unable to parse label type"),
            vec_label_type: syn::parse_str("Vec<Label>")
                .expect("Unable to parse vec label type"),
        }
    }
}
impl LabelTypes {
    /// Generates the `field: value` initializer used when rebuilding a variant
    /// in the generated `map_labels` method.
    ///
    /// * `Vec<Label>`: every element is cloned, passed through `arg_fn_name`,
    ///   and collected back into a `Vec`.
    /// * `Label`: the cloned label is passed through `arg_fn_name`.
    /// * Any other type mentioning `Label`: the initializer is a
    ///   `compile_error!` spanned at the field type, because the macro does not
    ///   know how to map it.
    /// * Anything else is cloned unchanged.
    ///
    /// `field_name` is expected to be bound by reference in the surrounding
    /// match arm, because the generated methods match on `&self`.
    fn map_labels(
        &self,
        field_name: &syn::Ident,
        field_type: &syn::Type,
        arg_fn_name: &syn::Ident,
    ) -> proc_macro2::TokenStream {
        if *field_type == self.vec_label_type {
            // Lend the closure (`&mut f`) instead of moving it into `map`.
            // Moving it would make any later call in the same variant (e.g.
            // a second label field) a use-after-move error.
            quote! {#field_name: #field_name.iter().cloned().map(&mut #arg_fn_name).collect()}
        } else if *field_type == self.label_type {
            quote! {#field_name: #arg_fn_name(#field_name.clone())}
        } else if LabelFinder::find(field_type) {
            let err = syn::Error::new_spanned(
                field_type,
                "Unknown `Label` type for `GateMethods` macro.",
            )
            .to_compile_error();
            quote! {#field_name: #err}
        } else {
            quote! {#field_name: #field_name.clone()}
        }
    }

    /// Generates the statement(s) that feed a field's labels to `arg_fn_name`
    /// in the generated `for_each_label` method.
    ///
    /// * `Vec<Label>`: every element is cloned and passed to `arg_fn_name`.
    /// * `Label`: the cloned label is passed to `arg_fn_name`.
    /// * Any other type mentioning `Label`: a `compile_error!` spanned at the
    ///   field type.
    /// * Anything else produces no code.
    fn for_each_label(
        &self,
        field_name: &syn::Ident,
        field_type: &syn::Type,
        arg_fn_name: &syn::Ident,
    ) -> proc_macro2::TokenStream {
        if *field_type == self.vec_label_type {
            // `&mut f` keeps the closure usable for later fields of the same
            // variant; passing `f` by value would move it into `for_each`.
            quote! {#field_name.iter().cloned().for_each(&mut #arg_fn_name);}
        } else if *field_type == self.label_type {
            quote! {#arg_fn_name(#field_name.clone());}
        } else if LabelFinder::find(field_type) {
            syn::Error::new_spanned(field_type, "Unknown `Label` type for `GateMethods` macro.")
                .to_compile_error()
        } else {
            quote! {}
        }
    }
}
/// Expands the `GateMethods` derive for an already-parsed `input`.
///
/// Generates an inherent `impl` containing:
/// * `map_labels(&self, f)`: rebuilds the value with every `Label` replaced by
///   `f(label)`.
/// * `for_each_label(&self, f)`: calls `f` on every `Label` in the value.
///
/// Returns a `compile_error!` token stream if `input` is not an enum.
pub fn derive_gate_methods_inner(input: DeriveInput) -> proc_macro2::TokenStream {
    let special_types = LabelTypes::default();
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let enum_name = input.ident;
    let syn::Data::Enum(syn::DataEnum { variants, .. }) = input.data else {
        return syn::Error::new(Span::call_site(), "Only works on enums").to_compile_error();
    };
    let arg_fn_name = format_ident!("f");
    let map_iter = variants.iter().map(|variant| {
        variant_to_map_labels_case(&special_types, &arg_fn_name, &enum_name, variant)
    });
    let for_each_iter = variants.iter().map(|variant| {
        variant_to_for_each_label_case(&special_types, &arg_fn_name, &enum_name, variant)
    });
    // For an enum without variants, `match self {}` on `&Self` is rejected as
    // non-exhaustive, because a reference is not treated as uninhabited.
    // Matching on the place `*self` with no arms is accepted and moves
    // nothing. Enums with variants keep matching on `self` so that fields
    // bind by reference.
    let scrutinee = if variants.is_empty() {
        quote! { *self }
    } else {
        quote! { self }
    };
    let map_labels_doc =
        format!("Creates a new `{enum_name}` by applying `{arg_fn_name}` to every `Label`.");
    let for_each_label_doc =
        format!("Applies `{arg_fn_name}` to every `Label` in the `{enum_name}`.");
    quote! {
        impl #impl_generics #enum_name #ty_generics #where_clause {
            #[doc = #map_labels_doc]
            pub fn map_labels(&self, mut #arg_fn_name: impl FnMut(Label) -> Label) -> Self {
                match #scrutinee {
                    #(#map_iter)*
                }
            }
            #[doc = #for_each_label_doc]
            pub fn for_each_label(&self, mut #arg_fn_name: impl FnMut(Label)) {
                #![allow(unused_variables)]
                match #scrutinee {
                    #(#for_each_iter)*
                }
            }
        }
    }
}
/// Entry point for the `GateMethods` derive: parses the item and delegates to
/// [`derive_gate_methods_inner`].
///
/// If the input cannot be parsed as a `DeriveInput`, `parse_macro_input!`
/// returns the parse error as compile-error tokens right away.
pub fn derive_gate_methods(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    proc_macro::TokenStream::from(derive_gate_methods_inner(ast))
}