arcium-macros 0.9.3

Helper macros for developing Solana programs that integrate with the Arcium network.
Documentation
//! Compile-time validation of computation arguments using const evaluation.
//!
//! This module uses syn's visitor pattern to find `#[args("circuit_name")]` attributes in the code,
//! extracts the argument list, and validates it against the circuit interface using const
//! evaluation. Arguments are replaced with placeholder constants to enable this validation to run
//! at compile-time.

use crate::utils::{get_param_tokens_from_interface, read_conf_ix_interface};
use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use std::collections::HashMap;
use syn::{
    parse_quote,
    punctuated::Punctuated,
    visit_mut::VisitMut,
    Attribute,
    Expr,
    ExprCall,
    ExprMethodCall,
    ItemFn,
    Local,
    Stmt,
};

/// Identifier of the marker attribute (`#[args("...")]`) that the visitor in this module searches
/// for and strips from the function body.
const ARGS_ATTRIBUTE_NAME: &str = "args";

/// Extracts method calls and accounts from ArgBuilder chain and converts them to ArgumentRef::...
/// expressions
fn extract_builder_args(
    expr: &Expr,
) -> syn::Result<(
    Punctuated<Expr, syn::token::Comma>,
    Punctuated<Expr, syn::token::Comma>,
)> {
    let mut method_calls = Vec::new();
    let mut current = expr;

    // Traverse the method call chain backwards (from build() to new())
    loop {
        match current {
            Expr::MethodCall(ExprMethodCall {
                method,
                receiver,
                args,
                ..
            }) => {
                if method == "build" {
                    // Start from the receiver of build()
                    current = receiver.as_ref();
                    continue;
                }
                // Collect method calls in reverse order
                method_calls.push((method.to_string(), args.clone()));
                current = receiver.as_ref();
            }
            Expr::Call(ExprCall { func, .. }) => {
                // Check if this is ArgBuilder::new()
                if let Expr::Path(path_expr) = func.as_ref() {
                    if let Some(seg) = path_expr.path.segments.last() {
                        if seg.ident == "new" {
                            // Found the start, break
                            break;
                        }
                    }
                }
                return Err(syn::Error::new_spanned(
                    expr,
                    "Expected ArgBuilder::new()...build() pattern",
                ));
            }
            _ => {
                return Err(syn::Error::new_spanned(
                    expr,
                    "Expected ArgBuilder::new()...build() pattern",
                ));
            }
        }
    }

    // Reverse to get correct order and convert to ArgumentRef::... expressions
    method_calls.reverse();
    let mut arguments = Punctuated::new();
    let mut account_count = 0u8;
    let mut accounts = Punctuated::new();

    for (method_name, method_args) in method_calls {
        let arg_expr = match method_name.as_str() {
            "x25519_pubkey" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::X25519Pubkey(0) }
            }
            "plaintext_u128" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextU128(0) }
            }
            "plaintext_u64" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextU64(0) }
            }
            "plaintext_u32" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextU32(0) }
            }
            "plaintext_u16" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextU16(0) }
            }
            "plaintext_u8" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextU8(0) }
            }
            "plaintext_i128" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextI128(0) }
            }
            "plaintext_i64" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextI64(0) }
            }
            "plaintext_i32" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextI32(0) }
            }
            "plaintext_i16" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextI16(0) }
            }
            "plaintext_i8" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextI8(0) }
            }
            "plaintext_bool" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextBool(0) }
            }
            "plaintext_float" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::PlaintextFloat(0) }
            }
            "encrypted_u128" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedU128(0) }
            }
            "encrypted_u64" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedU64(0) }
            }
            "encrypted_u32" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedU32(0) }
            }
            "encrypted_u16" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedU16(0) }
            }
            "encrypted_u8" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedU8(0) }
            }
            "encrypted_i128" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedI128(0) }
            }
            "encrypted_i64" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedI64(0) }
            }
            "encrypted_i32" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedI32(0) }
            }
            "encrypted_i16" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedI16(0) }
            }
            "encrypted_i8" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedI8(0) }
            }
            "encrypted_bool" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedBool(0) }
            }
            "encrypted_float" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::EncryptedFloat(0) }
            }
            "arcis_ed25519_signature" => {
                // We don't care about the index since we're just checking if it matches the param
                parse_quote! { ArgumentRef::ArcisEd25519Signature(0) }
            }
            "account" => {
                if method_args.len() != 3 {
                    return Err(syn::Error::new_spanned(
                        &method_args,
                        "account expects 3 arguments",
                    ));
                }
                let offset = method_args.iter().nth(1).unwrap();
                let length = method_args.iter().nth(2).unwrap();
                let account_index = account_count;
                let res = parse_quote! { ArgumentRef::Account(#account_index) };
                account_count += 1;
                // Set pubkey to zero address so it can run in const context
                accounts.push(parse_quote! { AccountArgument { pubkey: anchor_lang::solana_program::pubkey::Pubkey::new_from_array([0;32]), offset: #offset, length: #length } });
                res
            }
            _ => {
                return Err(syn::Error::new_spanned(
                    &method_args,
                    format!("Unknown builder method: {}", method_name),
                ));
            }
        };
        arguments.push(arg_expr);
    }

    Ok((arguments, accounts))
}

/// Placeholder expressions extracted from one builder chain, as returned by
/// `extract_builder_args`.
#[derive(Default)]
struct ArgInfo {
    /// `ArgumentRef::...` expressions, one per builder call, in source order.
    args: Punctuated<Expr, syn::token::Comma>,
    /// `AccountArgument { .. }` expressions, one per `.account(..)` call, in source order.
    accounts: Punctuated<Expr, syn::token::Comma>,
}

/// Visitor state used while scanning a function body for `#[args("...")]` markers and the
/// builder chains that follow them.
#[derive(Default)]
struct IxArgsFinder {
    /// Circuit name taken from the most recently seen `#[args("...")]` attribute; reset to `None`
    /// once a builder chain has been recorded for it.
    current_ix: Option<String>,
    // Mapping to args and accounts for each circuit
    found: HashMap<String, ArgInfo>,
    /// Errors gathered during the scan; they are emitted as compile errors by `check_args_fn`.
    errors: Vec<syn::Error>,
}

impl VisitMut for IxArgsFinder {
    /// Looks for the first `#[args(...)]` attribute in `i`, removes it, and records the circuit
    /// name it carries in `current_ix`.
    ///
    /// The attribute is removed even when its payload is malformed: the problem is reported
    /// through `errors`, and leaving the marker in place would presumably only add an unrelated
    /// unknown-attribute error on top of it.
    fn visit_attributes_mut(&mut self, i: &mut Vec<Attribute>) {
        let Some(idx) = i
            .iter()
            .position(|attr| attr.meta.path().is_ident(ARGS_ATTRIBUTE_NAME))
        else {
            return;
        };
        let attr = i.remove(idx);

        // The payload must be a list containing a single string literal: `#[args("name")]`.
        let circuit_name = attr
            .meta
            .require_list()
            .and_then(|list| syn::parse2::<syn::LitStr>(list.tokens.clone()));
        match circuit_name {
            Ok(name) => self.current_ix = Some(name.value()),
            Err(e) => self.errors.push(e),
        }
    }

    /// Processes the attributes of a `let` statement before descending into it, so that the
    /// circuit name is known by the time the initializer expression is visited.
    fn visit_local_mut(&mut self, i: &mut Local) {
        self.visit_attributes_mut(&mut i.attrs);
        syn::visit_mut::visit_local_mut(self, i);
    }

    /// While a circuit name is pending and has no recorded arguments yet, tries to interpret
    /// each visited expression as an `ArgBuilder::new()...build()` chain. The first expression
    /// that parses as one is recorded under the pending name, which is then cleared.
    ///
    /// Expressions that do not parse as a builder chain are traversed normally.
    // NOTE(review): the error returned by `extract_builder_args` is discarded here, so a chain
    // that is rejected (e.g. for an unknown method) is indistinguishable from an unrelated
    // expression and its sub-expressions are tried next.
    fn visit_expr_mut(&mut self, i: &mut Expr) {
        let awaiting_builder = self
            .current_ix
            .as_ref()
            .map_or(false, |name| !self.found.contains_key(name));
        if !awaiting_builder {
            syn::visit_mut::visit_expr_mut(self, i);
            return;
        }

        // Check if this is an inline ArgBuilder::new()...build() expression
        match extract_builder_args(i) {
            Ok((arguments, accounts)) => {
                // `take` both yields the pending name and clears it in one step.
                if let Some(circuit) = self.current_ix.take() {
                    self.found.insert(
                        circuit,
                        ArgInfo {
                            args: arguments,
                            accounts,
                        },
                    );
                }
            }
            Err(_) => {
                // Not an ArgBuilder pattern, continue visiting
                syn::visit_mut::visit_expr_mut(self, i);
            }
        }
    }

    /// Delegates to the default statement traversal.
    fn visit_stmt_mut(&mut self, i: &mut Stmt) {
        syn::visit_mut::visit_stmt_mut(self, i);
    }
}

pub fn check_args_fn(mut item_fn: ItemFn) -> TokenStream {
    let mut ix_args_finder = IxArgsFinder::default();
    ix_args_finder.visit_item_fn_mut(&mut item_fn);
    if ix_args_finder.found.is_empty() {
        ix_args_finder.errors.push(syn::Error::new(
            Span::call_site(),
            "No `#[args(\"your_instruction\")]` found.",
        ));
    }

    let extra_stmts = ix_args_finder
        .found
        .into_iter()
        .map(|(ix, ArgInfo { args, accounts })| {
            let conf_ix_interface = read_conf_ix_interface(&ix);
            let param_tokens = get_param_tokens_from_interface(&conf_ix_interface);
            let quote_args = args.iter();
            let quote_accounts = accounts.iter();

            let res = parse_quote! {
                const {
                    let accounts = [#(#quote_accounts),*];
                    let params = [#(#param_tokens),*];
                    let args = [#(#quote_args),*];
                    const_match_computation(&args, &accounts, &params);
                };
            };

            res
        });

    item_fn.block.stmts.splice(0..0, extra_stmts);
    let mut res = item_fn.to_token_stream();
    for err in ix_args_finder.errors {
        res.extend(err.to_compile_error());
    }
    res
}

// Debugging aid only: the single test below is `#[ignore]`d and makes no assertions. It feeds a
// sample instruction handler through `check_args_fn` and prints the expanded tokens.
#[cfg(test)]
mod tests {
    use super::*;
    #[ignore = "Used for debugging, not for testing."]
    #[test]
    fn debug_this_macro() {
        // Sample handler with one `#[args(..)]`-annotated builder chain covering a pubkey, a
        // plaintext value and an account argument.
        let input = parse_quote! {
            pub fn find_next_match(ctx: Context<NextMatch>, computation_offset: u64) -> Result<()> {
                ctx.accounts.sign_pda_account.bump = ctx.bumps.sign_pda_account;

                #[args("find_next_match")]
                let args = ArgBuilder::new()
                    .x25519_pubkey(ctx.accounts.orderbook.encryption_pubkey)
                    .plaintext_u128(ctx.accounts.orderbook.nonce)
                    .account(
                        ctx.accounts.orderbook.key(),
                        // Offset of 8 (discriminator) + 1 (bump) + 16 (nonce) + 32 (encryption pubkey)
                        8 + 1 + 16 + 32,
                        32 * 3 * ORDERBOOK_SIZE as u32,
                    )
                    .build();

                // Call the Arcium program to queue the computations
                queue_computation(
                    ctx.accounts,
                    computation_offset,
                    args,
                    vec![CallbackInstruction{
                        program_id: ID_CONST,
                        discriminator: instruction::FindNextMatchCallback::DISCRIMINATOR.to_vec(),
                        accounts: vec![
                            CallbackAccount{
                                pubkey: ARCIUM_PROGRAM_ID,
                                is_writable: false,
                            },
                            CallbackAccount{
                                pubkey: derive_comp_def_pda!(COMP_DEF_OFFSET_FIND_MATCH),
                                is_writable: false,
                            },
                            CallbackAccount{
                                pubkey: INSTRUCTIONS_SYSVAR_ID,
                                is_writable: false,
                            },
                        ],
                    }],
                    1,
                    0,
                )?;
                Ok(())
            }
        };
        // Print the expanded function so the generated `const { .. }` block can be inspected.
        let res = check_args_fn(input);
        println!("{}", res);
    }
}