use crate::{
utils::{check_encrypted_ix_path, ArciumCallbackArgs},
validation::{
always_valid_check,
is_valid_arcium_program_type,
is_valid_comp_def_acc_type,
is_valid_signer_type,
validate_struct_fields,
ValidateFunction,
},
};
use quote::quote;
use syn::{parse::Parse, DeriveInput, Ident, ItemFn, LitStr, PatType, Token};
/// Parsed arguments of the callback-accounts attribute, written as
/// `"<encrypted_ix>", <payer_field>`.
pub struct CallbackAccArgs {
    /// Encrypted instruction name, supplied as a string literal.
    pub encrypted_ix: LitStr,
    /// Name of the struct field that `callback_accs_derive` validates with
    /// `is_valid_signer_type`.
    pub payer_field: Ident,
}
impl Parse for CallbackAccArgs {
    /// Parses `"<encrypted_ix>", <payer_field>`: a string literal, a comma, then an
    /// identifier. Returns the first parse error encountered.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let encrypted_ix = input.parse::<LitStr>()?;
        let _separator: Token![,] = input.parse()?;
        let payer_field = input.parse::<Ident>()?;
        Ok(Self {
            encrypted_ix,
            payer_field,
        })
    }
}
pub fn callback_accs_derive(input: &DeriveInput, args: CallbackAccArgs) -> proc_macro::TokenStream {
check_encrypted_ix_path(&args.encrypted_ix.value());
let payer_field = &args.payer_field.to_string();
let required_fields: Vec<(&str, ValidateFunction, bool)> = vec![
(payer_field, is_valid_signer_type, true),
("arcium_program", is_valid_arcium_program_type, false),
("comp_def_account", is_valid_comp_def_acc_type, false),
("instructions_sysvar", always_valid_check, false),
];
if let Err(error_msg) = validate_struct_fields(&input.data, &required_fields) {
return quote! {
compile_error!(#error_msg);
}
.into();
}
let expanded = quote! {
#input
};
expanded.into()
}
pub fn callback_ix_derive(input_fn: ItemFn, args: ArciumCallbackArgs) -> proc_macro::TokenStream {
let fn_name = &input_fn.sig.ident;
let fn_body = &input_fn.block;
let fn_params = &input_fn.sig.inputs;
check_encrypted_ix_path(&args.encrypted_ix);
if *fn_name.to_string() != format!("{}_callback", &args.encrypted_ix) {
return syn::Error::new_spanned(
fn_name,
"function name must be `<encrypted_ix_name>_callback`",
)
.to_compile_error()
.into();
}
if fn_params.len() != 2 {
return syn::Error::new_spanned(
input_fn.sig.inputs,
"expected exactly two parameters, `ctx` and `output`",
)
.to_compile_error()
.into();
}
let ctx_param = fn_params
.first()
.expect("First parameter must be a Context<T>");
if let syn::FnArg::Typed(PatType { ty, .. }) = ctx_param {
if let syn::Type::Path(type_path) = ty.as_ref() {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident != "Context" {
return syn::Error::new_spanned(ty, "parameter must be of type `Context<T>`")
.to_compile_error()
.into();
}
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
if args.args.len() != 1 {
return syn::Error::new_spanned(
ty,
"`Context` must have exactly one type argument",
)
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(ty, "`Context` must have a type argument")
.to_compile_error()
.into();
}
}
} else {
return syn::Error::new_spanned(ty, "parameter must be of type `Context<T>`")
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(ctx_param, "parameter must be of type `Context<T>`")
.to_compile_error()
.into();
}
let output_param = fn_params.iter().nth(1).unwrap();
if let syn::FnArg::Typed(PatType { ty, .. }) = output_param {
if let syn::Type::Path(type_path) = ty.as_ref() {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident != "ComputationOutputs" {
return syn::Error::new_spanned(
ty,
"second parameter must be of type `ComputationOutputs`",
)
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(
ty,
"second parameter must be of type `ComputationOutputs`",
)
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(
ty,
"second parameter must be of type `ComputationOutputs`",
)
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(
output_param,
"second parameter must be of type `ComputationOutputs`",
)
.to_compile_error()
.into();
}
let return_type = &input_fn.sig.output;
if let syn::ReturnType::Type(_, ty) = return_type {
if let syn::Type::Path(type_path) = ty.as_ref() {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident != "Result" {
return syn::Error::new_spanned(ty, "function must return a `Result` type")
.to_compile_error()
.into();
}
}
} else {
return syn::Error::new_spanned(ty, "function must return a `Result` type")
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(return_type, "function must return a `Result` type")
.to_compile_error()
.into();
}
quote! {
pub fn #fn_name (#fn_params) -> ::anchor_lang::Result<()> {
validate_callback_ixs(&ctx.accounts.instructions_sysvar, &ctx.accounts.arcium_program.key())?;
#fn_body
}
}
.into()
}