// mpc-macros 0.2.12
//
// Arcium MPC Macros — documentation
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse_macro_input, visit::Visit, DeriveInput};

/// Syntax-tree visitor that records whether the identifier `Label` occurs anywhere in the
/// visited node.
#[derive(Default)]
struct LabelFinder {
    /// Starts out `false`; flipped to `true` the first time an identifier `Label` is visited.
    found: bool,
}

impl syn::visit::Visit<'_> for LabelFinder {
    /// Latches `found` to `true` once an identifier spelled exactly `Label` is visited;
    /// it is never reset back to `false`.
    fn visit_ident(&mut self, ident: &syn::Ident) {
        self.found = self.found || ident == "Label";
    }
}

impl LabelFinder {
    /// Walks the whole type `input` and reports whether the identifier `Label` appears
    /// anywhere inside it (e.g. `Label`, `Vec<Label>`, `Option<Label>`).
    fn find(input: &syn::Type) -> bool {
        let mut visitor = Self { found: false };
        visitor.visit_type(input);
        visitor.found
    }
}

/// Returns the `match` case in `map_labels` corresponding to `variant`.
fn variant_to_map_labels_case(
    special_types: &LabelTypes,
    arg_fn_name: &syn::Ident,
    enum_name: &syn::Ident,
    variant: &syn::Variant,
) -> proc_macro2::TokenStream {
    let syn::Fields::Named(fields) = &variant.fields else {
        return syn::Error::new_spanned(
            &variant.fields,
            "The `Gate` enum only supports named fields",
        )
        .to_compile_error();
    };
    let variant_name = &variant.ident;
    let fields_names = fields.named.iter().map(|field| &field.ident);
    let fields_values = fields.named.iter().map(|field| {
        special_types.map_labels(
            field.ident.as_ref().expect("Named field with no name."),
            &field.ty,
            arg_fn_name,
        )
    });
    quote! {
         #enum_name::#variant_name { #(#fields_names),*} => #enum_name::#variant_name { #(#fields_values),* },
    }
}

/// Returns the `match` case in `for_each_label` corresponding to `variant`.
fn variant_to_for_each_label_case(
    special_types: &LabelTypes,
    arg_fn_name: &syn::Ident,
    enum_name: &syn::Ident,
    variant: &syn::Variant,
) -> proc_macro2::TokenStream {
    let syn::Fields::Named(fields) = &variant.fields else {
        return syn::Error::new_spanned(
            &variant.fields,
            "The `Gate` enum only supports named fields",
        )
        .to_compile_error();
    };
    let variant_name = &variant.ident;
    let fields_names = fields.named.iter().map(|field| &field.ident);
    let fields_ops = fields.named.iter().map(|field| {
        special_types.for_each_label(
            field.ident.as_ref().expect("Named field with no name."),
            &field.ty,
            arg_fn_name,
        )
    });
    quote! {
         #enum_name::#variant_name { #(#fields_names),*} => { #(#fields_ops)* },
    }
}

/// The `Label`-carrying field types that the `GateMethods` derive knows how to traverse.
///
/// A field whose type mentions `Label` but matches neither of these is turned into a
/// compile error by the code generators in `impl LabelTypes`.
struct LabelTypes {
    /// Parsed form of the bare `Label` type.
    label_type: syn::Type,
    /// Parsed form of the `Vec<Label>` type.
    vec_label_type: syn::Type,
}

impl Default for LabelTypes {
    /// Parses the two supported type spellings, `Label` and `Vec<Label>`.
    ///
    /// The inputs are fixed, well-formed literals, so the `expect`s only guard against a
    /// programming error in this macro.
    fn default() -> Self {
        Self {
            label_type: syn::parse_str("Label").expect("Unable to parse label type"),
            vec_label_type: syn::parse_str("Vec<Label>")
                .expect("Unable to parse vec label type"),
        }
    }
}

impl LabelTypes {
    /// Used by the generated `map_labels` method, once per named field of each variant.
    ///
    /// Returns a struct-expression field initializer `field_name: <expr>`, where `<expr>`
    /// computes the field's new value from the binding `field_name`:
    /// - `Vec<Label>`: every element is cloned and passed through `arg_fn_name`.
    /// - `Label`: the label is cloned and passed through `arg_fn_name`.
    /// - Any other type that still mentions `Label`: a `compile_error!`, because the macro
    ///   does not know how to reach the labels inside it.
    /// - Anything else: the field is cloned unchanged.
    ///
    /// The closure is handed to iterator adapters as `&mut arg_fn_name`, not by value. A
    /// variant may have several label-carrying fields, and a by-value `.map(f)` would move
    /// the closure, so any later use in the same arm would fail to compile.
    fn map_labels(
        &self,
        field_name: &syn::Ident,
        field_type: &syn::Type,
        arg_fn_name: &syn::Ident,
    ) -> proc_macro2::TokenStream {
        if *field_type == self.vec_label_type {
            quote! {#field_name: #field_name.iter().cloned().map(&mut #arg_fn_name).collect()}
        } else if *field_type == self.label_type {
            quote! {#field_name: #arg_fn_name(#field_name.clone())}
        } else if LabelFinder::find(field_type) {
            let err = syn::Error::new_spanned(
                field_type,
                "Unknown `Label` type for `GateMethods` macro.",
            )
            .to_compile_error();
            quote! {#field_name: #err}
        } else {
            quote! {#field_name: #field_name.clone()}
        }
    }

    /// Used by the generated `for_each_label` method, once per named field of each variant.
    ///
    /// Returns statements that call `arg_fn_name` on every `Label` held by the field binding
    /// `field_name` of type `field_type`:
    /// - `Vec<Label>`: one call per cloned element.
    /// - `Label`: one call with a clone of the label.
    /// - Any other type that still mentions `Label`: a `compile_error!`.
    /// - Anything else: no statements.
    ///
    /// As in `map_labels`, the closure is passed as `&mut arg_fn_name` so that it is not
    /// moved by `for_each` and can be reused for later fields of the same variant.
    fn for_each_label(
        &self,
        field_name: &syn::Ident,
        field_type: &syn::Type,
        arg_fn_name: &syn::Ident,
    ) -> proc_macro2::TokenStream {
        if *field_type == self.vec_label_type {
            quote! {#field_name.iter().cloned().for_each(&mut #arg_fn_name);}
        } else if *field_type == self.label_type {
            quote! {#arg_fn_name(#field_name.clone());}
        } else if LabelFinder::find(field_type) {
            syn::Error::new_spanned(field_type, "Unknown `Label` type for `GateMethods` macro.")
                .to_compile_error()
        } else {
            quote! {}
        }
    }
}

pub fn derive_gate_methods_inner(input: DeriveInput) -> proc_macro2::TokenStream {
    let special_types = LabelTypes::default();
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let enum_name = input.ident; // `enum_name` is equal to `Gate`.
    let syn::Data::Enum(syn::DataEnum { variants, .. }) = input.data else {
        return syn::Error::new(Span::call_site(), "Only works on enums").to_compile_error();
    };
    let arg_fn_name = format_ident!("f");
    let map_iter = variants.iter().map(|variant| {
        variant_to_map_labels_case(&special_types, &arg_fn_name, &enum_name, variant)
    });
    let for_each_iter = variants.iter().map(|variant| {
        variant_to_for_each_label_case(&special_types, &arg_fn_name, &enum_name, variant)
    });
    let map_labels_doc =
        format!("Creates a new `{enum_name}` by applying `{arg_fn_name}` to every `Label`.");
    let for_each_label_doc =
        format!("Applies `{arg_fn_name}` to every `Label` in the `{enum_name}`.");
    quote! {
        impl #impl_generics #enum_name #ty_generics #where_clause {
            #[doc = #map_labels_doc]
            pub fn map_labels(&self, mut #arg_fn_name: impl FnMut(Label) -> Label) -> Self {
                match self {
                    #(#map_iter)*
                }
            }

            #[doc = #for_each_label_doc]
            pub fn for_each_label(&self, mut #arg_fn_name: impl FnMut(Label)) {
                #![allow(unused_variables)]
                match self {
                    #(#for_each_iter)*
                }

            }
        }
    }
}

pub fn derive_gate_methods(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    // Parse the input token stream into a DeriveInput struct for analysis
    let input = parse_macro_input!(input as DeriveInput);
    derive_gate_methods_inner(input).into()
}