arcium-macros 0.1.46

Arcium Macros
Documentation
use crate::{
    utils::{check_encrypted_ix_path, ArciumCallbackArgs},
    validation::{
        always_valid_check,
        is_valid_arcium_program_type,
        is_valid_comp_def_acc_type,
        is_valid_signer_type,
        validate_struct_fields,
        ValidateFunction,
    },
};
use quote::quote;
use syn::{parse::Parse, DeriveInput, Ident, ItemFn, LitStr, PatType, Token};

/// Arguments parsed for [`callback_accs_derive`], written as `"<encrypted_ix>", <payer_field>`.
pub struct CallbackAccArgs {
    /// Name of the encrypted (confidential) instruction, given as a string literal.
    /// Its value is passed to `check_encrypted_ix_path`.
    pub encrypted_ix: LitStr,
    /// Name of the accounts-struct field holding the payer.
    /// That field must satisfy `is_valid_signer_type`.
    pub payer_field: Ident,
}

/// Parses `"<encrypted_ix>", <payer_field>`, optionally followed by a trailing comma.
impl Parse for CallbackAccArgs {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let encrypted_ix: LitStr = input.parse()?;
        input.parse::<Token![,]>()?;
        let payer_field: Ident = input.parse()?;
        // Accept a trailing comma like ordinary Rust argument lists do. Without consuming it
        // here, syn reports "unexpected token" for `"ix", payer,`.
        input.parse::<Option<Token![,]>>()?;

        Ok(CallbackAccArgs {
            encrypted_ix,
            payer_field,
        })
    }
}

/// Validates a callback accounts struct and re-emits it unchanged.
///
/// First runs `check_encrypted_ix_path` on the encrypted instruction named in `args`. It then
/// checks that the struct declares these fields:
/// - the payer field named by `args.payer_field`, checked with `is_valid_signer_type`;
/// - `arcium_program`, checked with `is_valid_arcium_program_type`;
/// - `comp_def_account`, checked with `is_valid_comp_def_acc_type`;
/// - `instructions_sysvar`, accepted with any type (`always_valid_check`).
///
/// If `validate_struct_fields` reports an error, the output is a `compile_error!` carrying that
/// message instead of the struct.
pub fn callback_accs_derive(input: &DeriveInput, args: CallbackAccArgs) -> proc_macro::TokenStream {
    // Ensure the /build directory already contains the confidential instruction.
    check_encrypted_ix_path(&args.encrypted_ix.value());

    let payer_name = args.payer_field.to_string();
    // Each entry is (field name, type validator, flag forwarded to `validate_struct_fields`).
    // Only the payer entry sets the flag.
    let required_fields: Vec<(&str, ValidateFunction, bool)> = vec![
        (payer_name.as_str(), is_valid_signer_type, true),
        ("arcium_program", is_valid_arcium_program_type, false),
        ("comp_def_account", is_valid_comp_def_acc_type, false),
        ("instructions_sysvar", always_valid_check, false),
    ];

    let expanded = match validate_struct_fields(&input.data, &required_fields) {
        Ok(_) => quote! { #input },
        Err(error_msg) => quote! { compile_error!(#error_msg); },
    };
    expanded.into()
}

pub fn callback_ix_derive(input_fn: ItemFn, args: ArciumCallbackArgs) -> proc_macro::TokenStream {
    let fn_name = &input_fn.sig.ident;
    let fn_body = &input_fn.block;
    let fn_params = &input_fn.sig.inputs;

    // Check if the /build directory already contains the confidential instruction
    check_encrypted_ix_path(&args.encrypted_ix);

    // Function name should be "<encrypted_ix>_callback"
    if *fn_name.to_string() != format!("{}_callback", &args.encrypted_ix) {
        return syn::Error::new_spanned(
            fn_name,
            "function name must be `<encrypted_ix_name>_callback`",
        )
        .to_compile_error()
        .into();
    }

    // The function must have exactly two parameters
    if fn_params.len() != 2 {
        return syn::Error::new_spanned(
            input_fn.sig.inputs,
            "expected exactly two parameters, `ctx` and `output`",
        )
        .to_compile_error()
        .into();
    }

    // The first parameter must be a Context<T> type where T is any struct
    let ctx_param = fn_params
        .first()
        .expect("First parameter must be a Context<T>");
    if let syn::FnArg::Typed(PatType { ty, .. }) = ctx_param {
        if let syn::Type::Path(type_path) = ty.as_ref() {
            if let Some(segment) = type_path.path.segments.last() {
                if segment.ident != "Context" {
                    return syn::Error::new_spanned(ty, "parameter must be of type `Context<T>`")
                        .to_compile_error()
                        .into();
                }
                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
                    if args.args.len() != 1 {
                        return syn::Error::new_spanned(
                            ty,
                            "`Context` must have exactly one type argument",
                        )
                        .to_compile_error()
                        .into();
                    }
                } else {
                    return syn::Error::new_spanned(ty, "`Context` must have a type argument")
                        .to_compile_error()
                        .into();
                }
            }
        } else {
            return syn::Error::new_spanned(ty, "parameter must be of type `Context<T>`")
                .to_compile_error()
                .into();
        }
    } else {
        return syn::Error::new_spanned(ctx_param, "parameter must be of type `Context<T>`")
            .to_compile_error()
            .into();
    }

    // The second parameter must be a ComputationOutputs
    let output_param = fn_params.iter().nth(1).unwrap();
    if let syn::FnArg::Typed(PatType { ty, .. }) = output_param {
        if let syn::Type::Path(type_path) = ty.as_ref() {
            if let Some(segment) = type_path.path.segments.last() {
                if segment.ident != "ComputationOutputs" {
                    return syn::Error::new_spanned(
                        ty,
                        "second parameter must be of type `ComputationOutputs`",
                    )
                    .to_compile_error()
                    .into();
                }
            } else {
                return syn::Error::new_spanned(
                    ty,
                    "second parameter must be of type `ComputationOutputs`",
                )
                .to_compile_error()
                .into();
            }
        } else {
            return syn::Error::new_spanned(
                ty,
                "second parameter must be of type `ComputationOutputs`",
            )
            .to_compile_error()
            .into();
        }
    } else {
        return syn::Error::new_spanned(
            output_param,
            "second parameter must be of type `ComputationOutputs`",
        )
        .to_compile_error()
        .into();
    }

    // Check if the function returns a Result type
    let return_type = &input_fn.sig.output;

    if let syn::ReturnType::Type(_, ty) = return_type {
        if let syn::Type::Path(type_path) = ty.as_ref() {
            if let Some(segment) = type_path.path.segments.last() {
                if segment.ident != "Result" {
                    return syn::Error::new_spanned(ty, "function must return a `Result` type")
                        .to_compile_error()
                        .into();
                }
            }
        } else {
            return syn::Error::new_spanned(ty, "function must return a `Result` type")
                .to_compile_error()
                .into();
        }
    } else {
        return syn::Error::new_spanned(return_type, "function must return a `Result` type")
            .to_compile_error()
            .into();
    }

    quote! {
        pub fn #fn_name (#fn_params) -> ::anchor_lang::Result<()> {
            validate_callback_ixs(&ctx.accounts.instructions_sysvar, &ctx.accounts.arcium_program.key())?;

            #fn_body
        }
    }
    .into()
}