mpc-macros 0.3.0

Arcium MPC Macros
Documentation
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse_macro_input, visit::Visit, DeriveInput};

/// Syntax-tree visitor that records whether an identifier spelled `GateIndex` was visited.
#[derive(Default)]
struct GateIndexFinder {
    /// Becomes `true` once a `GateIndex` identifier has been visited; it is never reset.
    found: bool,
}

impl syn::visit::Visit<'_> for GateIndexFinder {
    fn visit_ident(&mut self, id: &syn::Ident) {
        // Sticky flag: once set, later identifiers cannot clear it.
        self.found |= id == "GateIndex";
    }
}

impl GateIndexFinder {
    /// Returns `true` if the type `input` mentions `GateIndex` anywhere inside it
    /// (for example as a generic argument), by walking it with a fresh finder.
    fn find(input: &syn::Type) -> bool {
        let mut visitor = Self { found: false };
        visitor.visit_type(input);
        visitor.found
    }
}

/// Returns the `match` case in `map_gate_indices` corresponding to `variant`.
///
/// The generated arm destructures every named field of `variant` and rebuilds the same
/// variant. Each field value is produced by `GateIndexTypes::map_gate_indices`. `GateIndex`
/// data is passed through `arg_fn_name`, and every other field is cloned.
///
/// Variants without named fields (tuple or unit variants) produce a `compile_error!`
/// spanned on their fields instead of an arm.
fn variant_to_map_gate_indices_case(
    special_types: &GateIndexTypes,
    arg_fn_name: &syn::Ident,
    enum_name: &syn::Ident,
    variant: &syn::Variant,
) -> proc_macro2::TokenStream {
    let syn::Fields::Named(fields) = &variant.fields else {
        return syn::Error::new_spanned(
            &variant.fields,
            "The `Gate` enum only supports named fields",
        )
        .to_compile_error();
    };
    let variant_name = &variant.ident;
    // Pattern bindings: one per field, named after the field itself.
    let fields_names = fields.named.iter().map(|field| &field.ident);
    // Struct-literal entries `name: <new value>` in the same field order.
    let fields_values = fields.named.iter().map(|field| {
        special_types.map_gate_indices(
            field.ident.as_ref().expect("Named field with no name."),
            &field.ty,
            arg_fn_name,
        )
    });
    quote! {
         #enum_name::#variant_name { #(#fields_names),*} => #enum_name::#variant_name { #(#fields_values),* },
    }
}

/// Returns the `match` case in `for_each_gate_index` corresponding to `variant`.
fn variant_to_for_each_gate_index_case(
    special_types: &GateIndexTypes,
    arg_fn_name: &syn::Ident,
    enum_name: &syn::Ident,
    variant: &syn::Variant,
) -> proc_macro2::TokenStream {
    let syn::Fields::Named(fields) = &variant.fields else {
        return syn::Error::new_spanned(
            &variant.fields,
            "The `Gate` enum only supports named fields",
        )
        .to_compile_error();
    };
    let variant_name = &variant.ident;
    let fields_names = fields.named.iter().map(|field| &field.ident);
    let fields_ops = fields.named.iter().map(|field| {
        special_types.for_each_gate_index(
            field.ident.as_ref().expect("Named field with no name."),
            &field.ty,
            arg_fn_name,
        )
    });
    quote! {
         #enum_name::#variant_name { #(#fields_names),*} => { #(#fields_ops)* },
    }
}

/// The field types containing `GateIndex` that the `GateMethods` macro knows how to traverse.
///
/// Field types are compared against these parsed types with `==`. Only these exact spellings
/// are recognized; for example, `std::vec::Vec<GateIndex>` is not.
struct GateIndexTypes {
    /// Parsed form of `GateIndex`.
    gate_index_type: syn::Type,
    /// Parsed form of `Vec<GateIndex>`.
    vec_gate_index_type: syn::Type,
}

impl Default for GateIndexTypes {
    /// Parses the two supported type spellings.
    ///
    /// Both are fixed, valid Rust types, so the `expect`s only fire if `syn` itself misbehaves.
    fn default() -> Self {
        Self {
            gate_index_type: syn::parse_str("GateIndex").expect("Unable to parse gate_index type"),
            vec_gate_index_type: syn::parse_str("Vec<GateIndex>")
                .expect("Unable to parse vec gate_index type"),
        }
    }
}

impl GateIndexTypes {
    /// Used in `map_gate_indices` on all variant fields.
    ///
    /// Returns a `field_name: <expr>` struct-literal entry that rebuilds field `field_name` of
    /// type `field_type` after `arg_fn_name` has been applied to every `GateIndex` it holds:
    /// - `Vec<GateIndex>`: each (cloned) element is mapped through `arg_fn_name`;
    /// - `GateIndex`: the cloned value is passed to `arg_fn_name`;
    /// - any other type mentioning `GateIndex` (e.g. `Option<GateIndex>`): a `compile_error!`;
    /// - any other type: the field is cloned unchanged.
    fn map_gate_indices(
        &self,
        field_name: &syn::Ident,
        field_type: &syn::Type,
        arg_fn_name: &syn::Ident,
    ) -> proc_macro2::TokenStream {
        if *field_type == self.vec_gate_index_type {
            // Lend the closure via `&mut` instead of moving it into `map`. Moving it would make
            // any later use in the same variant a "use of moved value" error.
            quote! {#field_name: #field_name.iter().cloned().map(&mut #arg_fn_name).collect()}
        } else if *field_type == self.gate_index_type {
            quote! {#field_name: #arg_fn_name(#field_name.clone())}
        } else if GateIndexFinder::find(field_type) {
            let err = Self::unsupported_type_error(field_type);
            quote! {#field_name: #err}
        } else {
            quote! {#field_name: #field_name.clone()}
        }
    }

    /// Used in `for_each_gate_index` on all variant fields.
    ///
    /// Returns statement(s) that call `arg_fn_name` on every `GateIndex` held by field
    /// `field_name` of type `field_type`. That is each element of a `Vec<GateIndex>`, or the
    /// value of a `GateIndex`. Other types mentioning `GateIndex` produce a `compile_error!`.
    /// Unrelated types produce no code.
    fn for_each_gate_index(
        &self,
        field_name: &syn::Ident,
        field_type: &syn::Type,
        arg_fn_name: &syn::Ident,
    ) -> proc_macro2::TokenStream {
        if *field_type == self.vec_gate_index_type {
            // Same reasoning as in `map_gate_indices`: borrow the closure, don't move it.
            quote! {#field_name.iter().cloned().for_each(&mut #arg_fn_name);}
        } else if *field_type == self.gate_index_type {
            quote! {#arg_fn_name(#field_name.clone());}
        } else if GateIndexFinder::find(field_type) {
            Self::unsupported_type_error(field_type)
        } else {
            quote! {}
        }
    }

    /// Returns a `compile_error!` spanned on `field_type`.
    ///
    /// Used for types that contain `GateIndex` in a shape this macro cannot traverse (anything
    /// other than exactly `GateIndex` or `Vec<GateIndex>`).
    fn unsupported_type_error(field_type: &syn::Type) -> proc_macro2::TokenStream {
        syn::Error::new_spanned(
            field_type,
            "Unknown `GateIndex` type for `GateMethods` macro.",
        )
        .to_compile_error()
    }
}

/// Generates the `GateMethods` impl for an enum (in practice the `Gate` enum).
///
/// The emitted inherent `impl` block contains:
/// - `map_gate_indices(&self, f)`: rebuilds `self`, passing every `GateIndex` (including each
///   element of a `Vec<GateIndex>` field) through `f` and cloning all other fields;
/// - `for_each_gate_index(&self, f)`: calls `f` on every such `GateIndex`.
///
/// If `input` is not an enum, returns a `compile_error!`. Per-variant and per-field errors
/// (non-named fields, unsupported `GateIndex` shapes) are embedded in the generated arms.
pub fn derive_gate_methods_inner(input: DeriveInput) -> proc_macro2::TokenStream {
    let special_types = GateIndexTypes::default();
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let enum_name = input.ident; // `enum_name` is equal to `Gate`.
    let syn::Data::Enum(syn::DataEnum { variants, .. }) = input.data else {
        return syn::Error::new(Span::call_site(), "Only works on enums").to_compile_error();
    };
    // Use a hygienic (`mixed_site`) span for the closure parameter. The match arms bind every
    // field by its own name, so with a call-site span a field named `f` would shadow the closure.
    let arg_fn_name = format_ident!("f", span = Span::mixed_site());
    let map_iter = variants.iter().map(|variant| {
        variant_to_map_gate_indices_case(&special_types, &arg_fn_name, &enum_name, variant)
    });
    let for_each_iter = variants.iter().map(|variant| {
        variant_to_for_each_gate_index_case(&special_types, &arg_fn_name, &enum_name, variant)
    });
    let map_gate_indices_doc =
        format!("Creates a new `{enum_name}` by applying `{arg_fn_name}` to every `GateIndex`.");
    let for_each_gate_index_doc =
        format!("Applies `{arg_fn_name}` to every `GateIndex` in the `{enum_name}`.");
    // The lint allowances keep the generated code warning-free in the user's crate. Field
    // bindings that hold no `GateIndex` go unused in `for_each_gate_index`. If no variant holds
    // a `GateIndex`, the closure (and its `mut`) go unused in both methods.
    quote! {
        impl #impl_generics #enum_name #ty_generics #where_clause {
            #[doc = #map_gate_indices_doc]
            #[allow(unused_mut, unused_variables)]
            pub fn map_gate_indices(&self, mut #arg_fn_name: impl FnMut(GateIndex) -> GateIndex) -> Self {
                match self {
                    #(#map_iter)*
                }
            }

            #[doc = #for_each_gate_index_doc]
            #[allow(unused_mut, unused_variables)]
            pub fn for_each_gate_index(&self, mut #arg_fn_name: impl FnMut(GateIndex)) {
                match self {
                    #(#for_each_iter)*
                }
            }
        }
    }
}

pub fn derive_gate_methods(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    // Parse the input token stream into a DeriveInput struct for analysis
    let input = parse_macro_input!(input as DeriveInput);
    derive_gate_methods_inner(input).into()
}