// mpc-macros 0.2.16
//
// Arcium MPC Macros
// Documentation
use proc_macro::TokenStream;
use quote::quote;
use syn::{
    parse::Parser,
    punctuated::Punctuated,
    spanned::Spanned,
    Error,
    GenericArgument,
    Ident,
    ImplItem,
    ItemImpl,
    PathArguments,
    ReturnType,
    Token,
    Type,
};

/// Category of operator trait handled by `op_variants`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OpKind {
    /// `Add`, `Mul`, ...: takes a right-hand side and has an output type.
    Binary,
    /// `AddAssign`, `MulAssign`, ...: takes a right-hand side, returns nothing.
    BinaryAssign,
    /// `Neg`, `Not`, ...: no right-hand side, has an output type.
    Unary,
}

/// Extra impl to generate next to the annotated one, selected by name in the
/// `#[op_variants(...)]` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Variant {
    /// `Self op Self` / `Self op= Self`: right-hand side taken by value.
    Owned,
    /// `&Self op &Self` for binary operators, `op &Self` for unary ones.
    Borrowed,
    /// `&Self op Self`, computed on a clone of the left operand.
    Flipped,
    /// `&Self op Self`, computed by swapping the operands (`rhs op self`).
    FlippedCommutative,
}

impl Variant {
    /// Parses a variant name from the `#[op_variants(...)]` attribute, accepting it only
    /// when it is meaningful for `op_kind`:
    ///
    /// * `owned` — binary and compound-assignment operators
    /// * `borrowed` — binary and unary operators
    /// * `flipped`, `flipped_commutative` — binary operators only
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected variant and the operator kind when the name
    /// is unknown or not allowed for `op_kind`.
    fn from_str(s: &str, op_kind: OpKind) -> Result<Self, String> {
        let is_binary = op_kind == OpKind::Binary;
        let accepted = match s {
            "owned" => {
                matches!(op_kind, OpKind::Binary | OpKind::BinaryAssign).then_some(Self::Owned)
            }
            "borrowed" => {
                matches!(op_kind, OpKind::Binary | OpKind::Unary).then_some(Self::Borrowed)
            }
            "flipped" => is_binary.then_some(Self::Flipped),
            "flipped_commutative" => is_binary.then_some(Self::FlippedCommutative),
            _ => None,
        };
        accepted.ok_or_else(|| format!("Invalid variant '{s}' for {op_kind:?}"))
    }
}

pub fn op_variants(attr: TokenStream, item: TokenStream) -> TokenStream {
    _op_variants(attr, item).unwrap_or_else(|e| e.to_compile_error().into())
}

fn _op_variants(attr: TokenStream, item: TokenStream) -> Result<TokenStream, Error> {
    // 1. PARSE AND VALIDATE
    let impl_block: ItemImpl = syn::parse(item)?;
    let attrs = Punctuated::<Ident, Token![,]>::parse_terminated.parse(attr)?;

    // 1.1 Extract trait info
    let (_, trait_path, _) = (impl_block.trait_.as_ref())
        .ok_or_else(|| Error::new_spanned(&impl_block, "Expected trait impl"))?;
    let trait_name = trait_path.segments.last().unwrap().ident.clone();

    // 1.2 Extract RHS type
    let rhs_ty = trait_path.segments.last().and_then(|seg| {
        let PathArguments::AngleBracketed(args) = &seg.arguments else {
            return None;
        };
        args.args.first().and_then(|arg| match arg {
            GenericArgument::Type(ty) => Some(ty.clone()),
            _ => None,
        })
    });

    // 1.3 Extract method info
    let methods: Vec<_> = (impl_block.items.iter())
        .filter_map(|item| match item {
            ImplItem::Fn(m) => Some(m),
            _ => None,
        })
        .collect();
    if methods.is_empty() {
        return Err(Error::new_spanned(&impl_block, "No method found"));
    }
    if methods.len() > 1 {
        return Err(Error::new_spanned(&impl_block, "Multiple methods found"));
    }
    let op_fn = methods.first().unwrap();
    let op = op_fn.sig.ident.clone();
    let has_return = !matches!(op_fn.sig.output, ReturnType::Default);

    // 1.4 Determine op kind
    let op_kind = match (rhs_ty.is_some(), has_return) {
        (true, true) => OpKind::Binary,
        (true, false) => OpKind::BinaryAssign,
        (false, true) => OpKind::Unary,
        _ => return Err(Error::new_spanned(op_fn, "Invalid trait signature")),
    };

    // 1.5 Extract output type
    let out_ty = if has_return {
        (impl_block.items.iter())
            .find_map(|item| match (item, &op_fn.sig.output) {
                (ImplItem::Type(t), _) if t.ident == "Output" => Some(t.ty.clone()),
                (_, ReturnType::Type(_, ty)) => Some((**ty).clone()),
                _ => None,
            })
            .ok_or_else(|| Error::new_spanned(op_fn, "Cannot determine output type"))?
    } else {
        syn::parse_quote!(())
    };

    // 1.6 Parse variant attributes
    let variants: Result<Vec<_>, _> = (attrs.iter())
        .map(|ident| Variant::from_str(&ident.to_string(), op_kind))
        .collect();
    let variants = variants.map_err(|e| Error::new(attrs.span(), e))?;

    // 1.7 Check Self is owned
    let ty = (*impl_block.self_ty).clone();
    if matches!(ty, Type::Reference(_)) {
        return Err(Error::new_spanned(
            &ty,
            "impl must be for an owned type, not a reference",
        ));
    }

    let gen = impl_block.generics.clone();
    let wc = gen.where_clause.clone();
    let mut output = quote! { #impl_block };
    let ref_ty: Type = syn::parse_quote!(&#ty);

    // 2. IMPLEMENT
    for variant in variants {
        // Prepare Self & rhs types and op bodies
        let (rhs_own, rhs_ref) = owned_and_ref(rhs_ty.as_ref().unwrap_or(&ty));
        let (impl_ty, rhs_ty, op_body) = match (variant, op_kind) {
            (Variant::Owned, _) => (&ty, Some(rhs_own), quote! { self.#op(&rhs) }),
            (Variant::Borrowed, OpKind::Binary) => {
                (&ref_ty, Some(rhs_ref), quote! { self.to_owned().#op(rhs) })
            }
            (Variant::Borrowed, OpKind::Unary) => (&ref_ty, None, quote! { self.to_owned().#op() }),
            (Variant::Flipped, OpKind::Binary) => {
                (&ref_ty, Some(rhs_own), quote! { self.to_owned().#op(&rhs) })
            }
            (Variant::FlippedCommutative, OpKind::Binary) => {
                (&ref_ty, Some(rhs_own), quote! { rhs.#op(self) })
            }
            _ => unreachable!(),
        };

        // Prepare where clause
        let needs_clone = matches!(variant, Variant::Borrowed | Variant::Flipped);
        let wc = if needs_clone {
            let mut wc = wc.clone().unwrap_or_else(|| syn::parse_quote!(where));
            wc.predicates.push(syn::parse_quote!(#ty: Clone));
            Some(wc)
        } else {
            wc.clone()
        };

        // Generate impl block
        let impl_block = match (rhs_ty, op_kind) {
            (Some(rhs), OpKind::Binary) => quote! {
                impl #gen #trait_name<#rhs> for #impl_ty #wc {
                    type Output = #out_ty;
                    #[inline]
                    fn #op(self, rhs: #rhs) -> Self::Output { #op_body }
                }
            },
            (Some(rhs), OpKind::BinaryAssign) => quote! {
                impl #gen #trait_name<#rhs> for #impl_ty #wc {
                    #[inline]
                    fn #op(&mut self, rhs: #rhs) { #op_body }
                }
            },
            (None, OpKind::Unary) => quote! {
                impl #gen #trait_name for #impl_ty #wc {
                    type Output = #out_ty;
                    #[inline]
                    fn #op(self) -> Self::Output { #op_body }
                }
            },
            _ => unreachable!(),
        };
        output.extend(impl_block);
    }

    Ok(output.into())
}

/// Splits `ty` into an owned form and a borrowed form.
///
/// * A reference type `&T` (including `&'a T` / `&mut T`) yields `(T, ty)`; the
///   reference itself is kept unchanged as the borrowed form.
/// * Any other type `T` yields `(T, &T)`.
///
/// Invisible groups (`Type::Group`, e.g. a type substituted from a `macro_rules!` `$t:ty`
/// fragment) and parentheses around a reference are looked through, so `(&T)` is
/// treated like `&T` instead of producing `&(&T)`.
fn owned_and_ref(ty: &Type) -> (Type, Type) {
    // Peel wrappers that don't change what the type is.
    let mut inner = ty;
    loop {
        inner = match inner {
            Type::Group(group) => &*group.elem,
            Type::Paren(paren) => &*paren.elem,
            _ => break,
        };
    }
    match inner {
        Type::Reference(r) => ((*r.elem).clone(), ty.clone()),
        // Borrow the original `ty`, keeping any grouping so e.g. `(dyn A + B)` stays
        // correctly delimited inside `&...`.
        _ => (ty.clone(), syn::parse_quote!(&#ty)),
    }
}