use crate::{
gen_callback_types::gen_callback_output_struct,
utils::{check_encrypted_ix_path, ArciumCallbackArgs},
validation::{
always_valid_check,
is_valid_arcium_program_type,
is_valid_cluster_acc_type,
is_valid_comp_acc_type,
is_valid_comp_def_acc_type,
is_valid_mxe_acc_type,
validate_struct_fields,
ValidateFunction,
},
};
use convert_case::{Case, Casing};
use quote::quote;
use syn::{parse::Parse, DeriveInput, ItemFn, LitStr, PatType};
/// Arguments for [`callback_accs_derive`]: a single string literal naming the
/// encrypted instruction (for example `"add_order"`).
pub struct CallbackAccArgs {
    /// Name of the encrypted instruction whose callback accounts are being derived.
    pub encrypted_ix: LitStr,
}
impl Parse for CallbackAccArgs {
    /// Parses one string literal from the attribute input and wraps it.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        input
            .parse::<LitStr>()
            .map(|encrypted_ix| Self { encrypted_ix })
    }
}
/// Checks that a callback accounts struct is named `<EncryptedIxInPascalCase>Callback`
/// (e.g. `add_order` -> `AddOrderCallback`).
///
/// # Errors
/// Returns a message naming the offending struct, the required name and the
/// encrypted instruction when the names do not match.
fn validate_callback_struct_name(struct_name: &str, encrypted_ix: &str) -> Result<(), String> {
    let pascal_ix = encrypted_ix.to_case(Case::Pascal);
    let required = format!("{}Callback", pascal_ix);
    if struct_name == required {
        return Ok(());
    }
    Err(format!(
        "struct `{}` must be named `{}` for encrypted instruction '{}'",
        struct_name, required, encrypted_ix
    ))
}
/// Expands the callback-accounts derive for the struct in `input`.
///
/// Validation:
/// - the struct must be named `<EncryptedIxPascal>Callback`;
/// - it must declare the fields `arcium_program`, `comp_def_account`, `mxe_account`,
///   `computation_account`, `cluster_account` and `instructions_sysvar`, each checked by the
///   corresponding `crate::validation` helper.
///
/// On success it emits the generated callback output struct, the original struct, and an
/// `::arcium_anchor::traits::CallbackCompAccs` impl. That impl's `callback_ix` lists, in this
/// order and all read-only, the Arcium program id, the computation-definition PDA, the MXE PDA,
/// the computation PDA, the cluster PDA and the instructions sysvar, followed by `extra_accs`.
///
/// Validation failures are returned as `compile_error!` tokens instead of panicking.
pub fn callback_accs_derive(input: &DeriveInput, args: CallbackAccArgs) -> proc_macro::TokenStream {
    let struct_name = &input.ident;
    // Read the literal once; it is reused for validation, the path check and code generation.
    let encrypted_ix_value = args.encrypted_ix.value();
    if let Err(error_msg) =
        validate_callback_struct_name(&struct_name.to_string(), &encrypted_ix_value)
    {
        return syn::Error::new_spanned(struct_name, error_msg)
            .to_compile_error()
            .into();
    }
    check_encrypted_ix_path(&encrypted_ix_value);

    // (field name, type validator, flag forwarded to `validate_struct_fields`; the flag's
    // meaning is defined by that helper — NOTE(review): confirm there).
    let required_fields: Vec<(&str, ValidateFunction, bool)> = vec![
        ("arcium_program", is_valid_arcium_program_type, false),
        ("comp_def_account", is_valid_comp_def_acc_type, false),
        ("mxe_account", is_valid_mxe_acc_type, false),
        ("computation_account", is_valid_comp_acc_type, false),
        ("cluster_account", is_valid_cluster_acc_type, false),
        ("instructions_sysvar", always_valid_check, false),
    ];
    if let Err(error_msg) = validate_struct_fields(&input.data, &required_fields) {
        return quote! {
            compile_error!(#error_msg);
        }
        .into();
    }

    let callback_output_struct = gen_callback_output_struct(&encrypted_ix_value);
    let callback_trait_impl = quote::quote! {
        impl ::arcium_anchor::traits::CallbackCompAccs for #struct_name<'_>{
            fn callback_ix(computation_offset: u64, mxe_account: &::arcium_client::idl::arcium::accounts::MXEAccount, extra_accs: &[::arcium_client::idl::arcium::types::CallbackAccount]) -> ::anchor_lang::prelude::Result<::arcium_client::idl::arcium::types::CallbackInstruction> {
                // Six fixed accounts are pushed below before `extra_accs` is appended.
                let mut accounts = Vec::with_capacity(extra_accs.len() + 6);
                accounts.push(::arcium_client::idl::arcium::types::CallbackAccount{
                    pubkey: ::arcium_client::ARCIUM_PROGRAM_ID,
                    is_writable: false,
                });
                accounts.push(::arcium_client::idl::arcium::types::CallbackAccount{
                    pubkey: ::arcium_anchor::derive_comp_def_pda!(::arcium_anchor::comp_def_offset(#encrypted_ix_value)),
                    is_writable: false,
                });
                accounts.push(::arcium_client::idl::arcium::types::CallbackAccount{
                    pubkey: ::arcium_anchor::derive_mxe_pda!(),
                    is_writable: false,
                });
                accounts.push(::arcium_client::idl::arcium::types::CallbackAccount{
                    pubkey: ::arcium_anchor::derive_comp_pda!(computation_offset, mxe_account, ErrorCode::ClusterNotSet),
                    is_writable: false,
                });
                accounts.push(::arcium_client::idl::arcium::types::CallbackAccount{
                    pubkey: ::arcium_anchor::derive_cluster_pda!(mxe_account, ErrorCode::ClusterNotSet),
                    is_writable: false,
                });
                accounts.push(::arcium_client::idl::arcium::types::CallbackAccount{
                    pubkey: ::anchor_lang::solana_program::sysvar::instructions::ID,
                    is_writable: false,
                });
                accounts.extend_from_slice(extra_accs);
                Ok(::arcium_client::idl::arcium::types::CallbackInstruction{
                    program_id: crate::ID_CONST,
                    discriminator: crate::instruction::#struct_name::DISCRIMINATOR.to_vec(),
                    accounts,
                })
            }
        }
    };
    let expanded = quote! {
        #callback_output_struct
        #input
        #callback_trait_impl
    };
    expanded.into()
}
pub fn callback_ix_derive(input_fn: ItemFn, args: ArciumCallbackArgs) -> proc_macro::TokenStream {
let fn_name = &input_fn.sig.ident;
let fn_body = &input_fn.block;
let fn_params = &input_fn.sig.inputs;
check_encrypted_ix_path(&args.encrypted_ix);
if *fn_name.to_string() != format!("{}_callback", &args.encrypted_ix) {
return syn::Error::new_spanned(
fn_name,
"function name must be `<encrypted_ix_name>_callback`",
)
.to_compile_error()
.into();
}
if fn_params.len() != 2 {
return syn::Error::new_spanned(
input_fn.sig.inputs,
"expected exactly two parameters, `ctx`, and `output`",
)
.to_compile_error()
.into();
}
let ctx_param = fn_params
.first()
.expect("First parameter must be a Context<T>");
if let syn::FnArg::Typed(PatType { ty, .. }) = ctx_param {
if let syn::Type::Path(type_path) = ty.as_ref() {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident != "Context" {
return syn::Error::new_spanned(ty, "parameter must be of type `Context<T>`")
.to_compile_error()
.into();
}
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
if args.args.len() != 1 {
return syn::Error::new_spanned(
ty,
"`Context` must have exactly one type argument",
)
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(ty, "`Context` must have a type argument")
.to_compile_error()
.into();
}
}
} else {
return syn::Error::new_spanned(ty, "parameter must be of type `Context<T>`")
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(ctx_param, "parameter must be of type `Context<T>`")
.to_compile_error()
.into();
}
let output_param = fn_params.iter().nth(1).unwrap();
if let syn::FnArg::Typed(PatType { ty, .. }) = output_param {
if let syn::Type::Path(type_path) = ty.as_ref() {
if let Some(segment) = type_path.path.segments.last() {
let (struct_name, generic_count) = if args.auto_serialize {
if segment.ident == "SignedComputationOutputs" {
("SignedComputationOutputs", 1)
} else {
return syn::Error::new_spanned(
ty,
"second parameter must be `SignedComputationOutputs<T>`",
)
.to_compile_error()
.into();
}
} else if segment.ident == "SignedComputationOutputs" {
("SignedComputationOutputs", 1)
} else if segment.ident == "RawComputationOutputs" {
("RawComputationOutputs", 1)
} else {
return syn::Error::new_spanned(
ty,
"second parameter must be `SignedComputationOutputs<T>` or `RawComputationOutputs<T>` when auto_serialize = false",
)
.to_compile_error()
.into();
};
let type_arg = match &segment.arguments {
syn::PathArguments::AngleBracketed(args_bracket)
if args_bracket.args.len() == generic_count =>
{
args_bracket.args.first()
}
syn::PathArguments::AngleBracketed(_) => {
return syn::Error::new_spanned(
ty,
format!(
"`{}` must have exactly {} argument(s)",
struct_name, generic_count
),
)
.to_compile_error()
.into();
}
_ => {
return syn::Error::new_spanned(
ty,
format!("`{}` must have {} argument(s)", struct_name, generic_count),
)
.to_compile_error()
.into();
}
};
if args.auto_serialize {
if let Some(syn::GenericArgument::Type(syn::Type::Path(inner_path))) = type_arg
{
if let Some(inner_segment) = inner_path.path.segments.last() {
let expected_type_name =
format!("{}Output", args.encrypted_ix.to_case(Case::Pascal));
if inner_segment.ident != expected_type_name {
return syn::Error::new_spanned(
ty,
format!(
"when auto_serialize is true (default), expected type `{}<{}>` but found `{}<{}>`. \
Consider using `auto_serialize = false` if you want to use a custom type.",
struct_name, expected_type_name, struct_name, inner_segment.ident
),
)
.to_compile_error()
.into();
}
}
}
}
} else {
return syn::Error::new_spanned(
ty,
"second parameter must reference a concrete callback output type",
)
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(
ty,
"second parameter must reference a concrete callback output type",
)
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(
output_param,
"second parameter must reference a concrete callback output type",
)
.to_compile_error()
.into();
}
let return_type = &input_fn.sig.output;
if let syn::ReturnType::Type(_, ty) = return_type {
if let syn::Type::Path(type_path) = ty.as_ref() {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident != "Result" {
return syn::Error::new_spanned(ty, "function must return a `Result` type")
.to_compile_error()
.into();
}
}
} else {
return syn::Error::new_spanned(ty, "function must return a `Result` type")
.to_compile_error()
.into();
}
} else {
return syn::Error::new_spanned(return_type, "function must return a `Result` type")
.to_compile_error()
.into();
}
quote! {
pub fn #fn_name (#fn_params) -> ::anchor_lang::Result<()> {
validate_callback_ixs(&ctx.accounts.instructions_sysvar, &ctx.accounts.arcium_program.key())?;
#fn_body
}
}
.into()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every `(struct_name, encrypted_ix)` pair is accepted.
    fn assert_all_accepted(cases: &[(&str, &str)]) {
        for &(struct_name, encrypted_ix) in cases {
            assert!(
                validate_callback_struct_name(struct_name, encrypted_ix).is_ok(),
                "`{}` should be accepted for `{}`",
                struct_name,
                encrypted_ix
            );
        }
    }

    /// Validates a name that must be rejected and returns the resulting error message.
    fn rejection_message(struct_name: &str, encrypted_ix: &str) -> String {
        validate_callback_struct_name(struct_name, encrypted_ix)
            .expect_err("struct name should have been rejected")
    }

    /// Asserts that `message` contains every fragment in `fragments`.
    fn assert_mentions(message: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(
                message.contains(fragment),
                "`{}` missing from `{}`",
                fragment,
                message
            );
        }
    }

    #[test]
    fn test_callback_struct_naming_valid_snake_case() {
        assert_all_accepted(&[
            ("AddOrderCallback", "add_order"),
            ("FindNextMatchCallback", "find_next_match"),
            ("AddTogetherCallback", "add_together"),
            ("VerifySignatureCallback", "verify_signature"),
        ]);
    }

    #[test]
    fn test_callback_struct_naming_valid_edge_cases() {
        assert_all_accepted(&[
            ("TestCallback", "test"),
            ("MyLongCircuitNameCallback", "my_long_circuit_name"),
            ("Ed25519Callback", "ed_25519"),
        ]);
    }

    #[test]
    fn test_callback_struct_naming_invalid_wrong_suffix() {
        let message = rejection_message("AddOrder", "add_order");
        assert_mentions(&message, &["AddOrderCallback"]);
    }

    #[test]
    fn test_callback_struct_naming_invalid_wrong_case() {
        let message = rejection_message("addOrderCallback", "add_order");
        assert_mentions(&message, &["AddOrderCallback"]);
    }

    #[test]
    fn test_callback_struct_naming_invalid_completely_wrong() {
        let message = rejection_message("WrongName", "add_order");
        assert_mentions(&message, &["WrongName", "AddOrderCallback", "add_order"]);
    }

    #[test]
    fn test_callback_struct_naming_error_message_format() {
        let message = rejection_message("WrongCallback", "my_circuit");
        assert_mentions(
            &message,
            &["WrongCallback", "MyCircuitCallback", "my_circuit", "must be named"],
        );
    }

    #[test]
    fn test_pascal_case_conversion() {
        let cases = [
            ("add_order", "AddOrder"),
            ("find_next_match", "FindNextMatch"),
            ("test", "Test"),
            ("my_long_circuit_name", "MyLongCircuitName"),
        ];
        for &(snake, pascal) in cases.iter() {
            assert_eq!(snake.to_case(Case::Pascal), pascal);
        }
    }
}