use anchor_lang::prelude::*;
use arcium_client::{
idl::arcium::{
accounts::Cluster,
cpi::{accounts::InitComputationDefinition, init_computation_definition},
types::{
AccountArgument,
ArgumentList,
ArgumentRef,
CallbackInstruction,
CircuitSource,
ComputationDefinitionMeta,
ComputationSignature,
Parameter,
SetUnset,
},
},
pda::{CLOCK_PDA, FEE_POOL_PDA},
};
pub use solana_address_lookup_table_interface;
use solana_alt_bn128_bls::Sha256Normalized;
use traits::{InitCompDefAccs, QueueCompAccs};
// Account-bundle traits (`InitCompDefAccs`, `QueueCompAccs`, `CallbackCompAccs`) used by the CPI helpers below.
pub mod traits;
/// Convenience re-exports intended to be glob-imported by programs using this crate.
///
/// Bundles every public item of the crate root (`super::*`) together with the Arcium
/// account types, program handle and argument types, the Arcium proc macros, and
/// `CallbackCompAccs`.
pub mod prelude {
pub use super::*;
// The Arcium program ID is exposed as `ARCIUM_PROG_ID`, the name the `derive_*_pda!`
// macros reference at their call sites.
pub use arcium_client::idl::arcium::{
accounts::{ClockAccount, Cluster, ComputationDefinitionAccount, FeePool, MXEAccount},
program::Arcium,
types::{AccountArgument, ArgumentList, ArgumentRef},
ID_CONST as ARCIUM_PROG_ID,
};
pub use arcium_macros::{
arcium_callback,
arcium_program,
callback_accounts,
check_args,
init_computation_definition_accounts,
queue_computation_accounts,
};
pub use traits::CallbackCompAccs;
// NOTE(review): these two items appear to be reachable through the `super::*` glob
// above already (both must be `pub` for this re-export to compile); the explicit
// re-exports look redundant — confirm before removing.
pub use ArgBuilder;
pub use LUT_PROGRAM_ID;
}
/// `LEN` 32-byte ciphertexts accompanied by a 32-byte encryption key and a nonce.
///
/// Borsh layout follows field order: `encryption_key` (32 bytes), `nonce` (u128,
/// little-endian), then the `LEN` ciphertexts back to back (fixed-size array, no
/// length prefix). Do not reorder fields.
///
/// NOTE(review): the `Shared`/`MXE` prefixes presumably name which key the
/// ciphertexts are encrypted under — confirm against the Arcium client docs.
#[derive(AnchorSerialize, AnchorDeserialize)]
pub struct SharedEncryptedStruct<const LEN: usize> {
pub encryption_key: [u8; 32],
pub nonce: u128,
pub ciphertexts: [[u8; 32]; LEN],
}
/// `LEN` 32-byte ciphertexts with their nonce, without an accompanying key.
///
/// Borsh layout: `nonce` (u128, little-endian) followed by the `LEN` ciphertexts.
#[derive(AnchorSerialize, AnchorDeserialize)]
pub struct MXEEncryptedStruct<const LEN: usize> {
pub nonce: u128,
pub ciphertexts: [[u8; 32]; LEN],
}
/// Bare `LEN` 32-byte ciphertexts with no key or nonce attached.
#[derive(AnchorSerialize, AnchorDeserialize)]
pub struct EncDataStruct<const LEN: usize> {
pub ciphertexts: [[u8; 32]; LEN],
}
// Custom errors returned by this crate's helpers.
// Anchor numbers `#[error_code]` variants by declaration order, so append new
// variants at the end; reordering changes the on-chain error codes.
// Plain `//` comments are used on the variants because Anchor's `#[error_code]`
// parser may reject attributes other than `#[msg]` (doc comments are attributes).
#[error_code]
pub enum ArciumError {
// The computation outputs were `Failure` (see `verify_output_raw`).
AbortedComputation,
// The cluster's BLS signature over the output did not verify.
BLSSignatureVerificationFailed,
// The cluster account has no BLS public key set.
InvalidClusterBLSPublicKey,
// The computation account data is too short to contain `slot`/`slot_counter`.
InvalidComputationAccount,
// The IDL-only marker variant of `SignedComputationOutputs` reached verification.
MarkerForIdlBuildUsageNotAllowed,
// `num_callback_txs` other than 1 was requested while the feature is disabled.
#[msg("Multi-transaction callbacks disabled; enable 'multi-tx-callbacks' feature")]
MultiTxCallbacksDisabled,
}
/// A computation result carried as an already-deserialized value: `Success(O)` or
/// `Failure`. Uses the derived Borsh layout (one-byte tag, then the payload).
#[derive(Debug, AnchorSerialize, AnchorDeserialize)]
pub enum RawComputationOutputs<O: AnchorDeserialize + AnchorSerialize> {
Success(O),
Failure,
}
/// Number of bytes in the serialized form of `Self`.
///
/// `SignedComputationOutputs` reads exactly `SIZE` output bytes, and `verify_output`
/// decodes them with `try_from_slice`, so `SIZE` must equal the exact Borsh-encoded
/// length of the type.
pub trait HasSize {
const SIZE: usize;
}
/// Computation outputs plus the cluster's BLS signature, prior to verification.
/// Call `verify_output` / `verify_output_raw` to check the signature and obtain the result.
///
/// Wire format (hand-written `AnchorSerialize`/`AnchorDeserialize` impls below, not
/// the derived layout): a one-byte tag, then
/// - `0` `Success`: exactly `O::SIZE` raw output bytes (no length prefix), then a
///   64-byte signature;
/// - `1` `Failure`: no payload;
/// - `2` `MarkerForIdlBuildDoNotUseThis`: a Borsh-encoded `O`.
///
/// The marker variant appears to exist only so that `O` is referenced in the generated
/// IDL; `verify_output_raw` rejects it.
#[derive(Debug)]
pub enum SignedComputationOutputs<O: HasSize + AnchorDeserialize + AnchorSerialize> {
Success(Vec<u8>, [u8; 64]),
Failure,
MarkerForIdlBuildDoNotUseThis(O),
}
impl<O: HasSize + AnchorDeserialize + AnchorSerialize> AnchorDeserialize
    for SignedComputationOutputs<O>
{
    /// Decodes the tagged wire format: tag `0` is followed by exactly `O::SIZE` raw
    /// output bytes and a 64-byte signature, tag `1` has no payload, and tag `2` carries
    /// a Borsh-encoded `O`. Any other tag is rejected with `InvalidData`.
    fn deserialize_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
        let tag = u8::deserialize_reader(reader)?;
        if tag == 1 {
            return Ok(SignedComputationOutputs::Failure);
        }
        if tag == 2 {
            return O::deserialize_reader(reader)
                .map(SignedComputationOutputs::MarkerForIdlBuildDoNotUseThis);
        }
        if tag != 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Invalid SignedComputationOutputs variant",
            ));
        }
        // Success: fixed-size payload, no length prefix.
        let mut output = vec![0u8; O::SIZE];
        reader.read_exact(&mut output)?;
        let mut signature = [0u8; 64];
        reader.read_exact(&mut signature)?;
        Ok(SignedComputationOutputs::Success(output, signature))
    }
}
impl<O: HasSize + AnchorDeserialize + AnchorSerialize> AnchorSerialize
    for SignedComputationOutputs<O>
{
    /// Encodes the tagged wire format read by `deserialize_reader`.
    ///
    /// # Errors
    /// `InvalidInput` when a `Success` payload is not exactly `O::SIZE` bytes. Nothing
    /// is written in that case. Writer errors are propagated.
    fn serialize<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        match self {
            SignedComputationOutputs::Failure => 1u8.serialize(writer),
            SignedComputationOutputs::MarkerForIdlBuildDoNotUseThis(inner) => {
                2u8.serialize(writer)?;
                inner.serialize(writer)
            }
            SignedComputationOutputs::Success(payload, signature) => {
                // Validate before emitting the tag so a bad payload writes nothing.
                if payload.len() != O::SIZE {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "SignedComputationOutputs payload must match O::SIZE",
                    ));
                }
                0u8.serialize(writer)?;
                writer.write_all(payload)?;
                writer.write_all(signature)
            }
        }
    }
}
impl<O: HasSize + AnchorDeserialize + AnchorSerialize> SignedComputationOutputs<O> {
    /// Verifies the cluster's BLS signature over a successful output and returns the raw,
    /// still-serialized output bytes.
    ///
    /// The signed message is `output_bytes || slot || slot_counter`. `slot` (8 bytes) and
    /// `slot_counter` (2 bytes) are read from `computation_account`'s data.
    ///
    /// # Errors
    /// - `ArciumError::AbortedComputation` if the outputs are `Failure`.
    /// - `ArciumError::MarkerForIdlBuildUsageNotAllowed` for the IDL-only marker variant.
    /// - `ArciumError::InvalidClusterBLSPublicKey` if the cluster has no BLS key set.
    /// - `ArciumError::InvalidComputationAccount` if the computation account data is too
    ///   short.
    /// - `ArciumError::BLSSignatureVerificationFailed` if the signature does not verify.
    pub fn verify_output_raw(
        self,
        arcium_cluster_acc: &Cluster,
        computation_account: &UncheckedAccount,
    ) -> Result<Vec<u8>> {
        // Resolve the variant first. A failed computation carries no signature, so it must
        // surface as `AbortedComputation`, not be masked by cluster-key or
        // computation-account errors that only matter when there is something to verify.
        let (o_bytes, bls_sig_bytes) = match self {
            SignedComputationOutputs::Success(o_bytes, bls_sig_bytes) => (o_bytes, bls_sig_bytes),
            SignedComputationOutputs::Failure => {
                return Err(ArciumError::AbortedComputation.into())
            }
            SignedComputationOutputs::MarkerForIdlBuildDoNotUseThis(_) => {
                return Err(ArciumError::MarkerForIdlBuildUsageNotAllowed.into())
            }
        };
        let bls_pubkey = match arcium_cluster_acc.bls_public_key {
            SetUnset::Set(bls_pubkey) => bls_pubkey,
            SetUnset::Unset(..) => return Err(ArciumError::InvalidClusterBLSPublicKey.into()),
        };
        let (slot, slot_counter) = get_slot_and_slot_counter_bytes(computation_account)?;
        // Binding the slot and slot counter into the message ties the signature to this
        // particular computation.
        let message = &[o_bytes.as_slice(), slot.as_ref(), slot_counter.as_ref()].concat();
        let bls_pubkey = solana_alt_bn128_bls::G2CompressedPoint(bls_pubkey.0);
        let bls_sig = solana_alt_bn128_bls::G1Point(bls_sig_bytes);
        bls_pubkey
            .verify_signature::<Sha256Normalized, &[u8], solana_alt_bn128_bls::G1Point>(
                bls_sig, message,
            )
            .map_err(|_| ArciumError::BLSSignatureVerificationFailed)?;
        Ok(o_bytes)
    }

    /// Verifies the signature exactly like `verify_output_raw`, then decodes the output
    /// bytes into `O`.
    ///
    /// # Errors
    /// Everything `verify_output_raw` returns, plus a deserialization error if the bytes
    /// are not exactly one Borsh-encoded `O` (`try_from_slice` rejects trailing bytes).
    pub fn verify_output(
        self,
        arcium_cluster_acc: &Cluster,
        computation_account: &UncheckedAccount,
    ) -> Result<O> {
        let raw = self.verify_output_raw(arcium_cluster_acc, computation_account)?;
        Ok(O::try_from_slice(&raw)?)
    }
}
const SIGNER_ACCOUNT_BUMP_OFFSET: usize = 8;
pub fn queue_computation<'info, T>(
accs: &T,
computation_offset: u64,
args: ArgumentList,
callback_instructions: Vec<CallbackInstruction>,
num_callback_txs: u8,
cu_price_micro: u64,
) -> Result<()>
where
T: QueueCompAccs<'info>,
{
#[cfg(not(feature = "multi-tx-callbacks"))]
if num_callback_txs != 1 {
return Err(error!(ArciumError::MultiTxCallbacksDisabled));
}
let bump = accs.signer_pda_bump();
let signer_seeds: &[&[&[u8]]] = &[&[SIGN_PDA_SEED, &[bump]]];
let queue_comp_accounts = accs.queue_comp_accs();
queue_comp_accounts.sign_seed.try_borrow_mut_data()?[SIGNER_ACCOUNT_BUMP_OFFSET] = bump;
let cpi_context =
CpiContext::new_with_signer(accs.arcium_program(), queue_comp_accounts, signer_seeds);
arcium_client::idl::arcium::cpi::queue_computation(
cpi_context,
computation_offset,
accs.comp_def_offset(),
args,
accs.mxe_program(),
callback_instructions,
num_callback_txs,
0,
cu_price_micro,
)
}
/// Creates this program's computation definition on the Arcium program via CPI.
///
/// The account handles, the circuit's parameter/output signature, its compiled length,
/// the definition offset and the weight all come from `accs`. `circuit_source_override`
/// and `finalize_authority` are forwarded unchanged.
///
/// # Errors
/// Propagates any error returned by the Arcium `init_computation_definition` CPI.
pub fn init_comp_def<'info, T>(
    accs: &T,
    circuit_source_override: Option<CircuitSource>,
    finalize_authority: Option<Pubkey>,
) -> Result<()>
where
    T: InitCompDefAccs<'info>,
{
    let arcium_program = accs.arcium_program();
    let cpi_accounts = InitComputationDefinition {
        signer: accs.signer(),
        system_program: accs.system_program(),
        mxe: accs.mxe_acc(),
        comp_def_acc: accs.comp_def_acc(),
        address_lookup_table: accs.address_lookup_table(),
        lut_program: accs.lut_program(),
    };
    let cpi_context = CpiContext::new(arcium_program, cpi_accounts);

    let parameters = accs.params();
    let outputs = accs.outputs();
    let meta = ComputationDefinitionMeta {
        circuit_len: accs.compiled_circuit_len(),
        signature: ComputationSignature {
            parameters,
            outputs,
        },
    };

    init_computation_definition(
        cpi_context,
        accs.comp_def_offset(),
        accs.mxe_program(),
        meta,
        circuit_source_override,
        accs.weight(),
        finalize_authority,
    )?;
    Ok(())
}
/// A `Vec<T>` serialized with a little-endian `u16` element-count prefix instead of
/// Borsh's `u32`, so it can hold at most `u16::MAX` elements when serialized.
///
/// Despite the name, this is a fixed 2-byte prefix, not Solana's variable-length
/// compact-u16 ("short_vec") encoding.
pub struct ShortVec<T: AnchorSerialize + AnchorDeserialize> {
pub data: Vec<T>,
}
impl<T: AnchorSerialize + AnchorDeserialize> AnchorSerialize for ShortVec<T> {
    /// Writes the element count as a `u16`, followed by each element's Borsh encoding.
    ///
    /// # Errors
    /// `InvalidInput` if there are more than `u16::MAX` elements; writer errors and
    /// element serialization errors are propagated.
    fn serialize<W: std::io::Write>(
        &self,
        writer: &mut W,
    ) -> std::result::Result<(), std::io::Error> {
        let count = match u16::try_from(self.data.len()) {
            Ok(count) => count,
            Err(_) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Length too large, must fit in u16",
                ))
            }
        };
        count.serialize(writer)?;
        self.data.iter().try_for_each(|item| item.serialize(writer))
    }
}
impl<T: AnchorSerialize + AnchorDeserialize> AnchorDeserialize for ShortVec<T> {
    /// Reads a `u16` element count, then that many Borsh-encoded elements.
    ///
    /// The count comes from untrusted input. Reserving `count * size_of::<T>()` bytes up
    /// front could request up to ~64Ki elements of heap before a single element has been
    /// validated, which is significant on Solana's small program heap. Preallocation is
    /// therefore capped, mirroring Borsh's own cautious `Vec` decoding. The vector still
    /// grows as needed for genuinely long inputs.
    ///
    /// # Errors
    /// Propagates reader errors and element deserialization errors (e.g. truncated input).
    fn deserialize_reader<R: std::io::Read>(
        reader: &mut R,
    ) -> std::result::Result<Self, std::io::Error> {
        // Maximum number of bytes reserved before any element has been read.
        const MAX_PREALLOC_BYTES: usize = 4096;
        let len: u16 = u16::deserialize_reader(reader)?;
        // `max(1)` keeps the division defined for zero-sized `T`.
        let elem_size = std::mem::size_of::<T>().max(1);
        let capacity = usize::from(len).min(MAX_PREALLOC_BYTES / elem_size);
        let mut data = Vec::with_capacity(capacity);
        for _ in 0..len {
            data.push(T::deserialize_reader(reader)?);
        }
        Ok(Self { data })
    }
}
/// Reads the raw little-endian `slot` (8 bytes) and `slot_counter` (2 bytes) fields
/// directly from a computation account's data, without deserializing the account.
///
/// # Errors
/// `ArciumError::InvalidComputationAccount` if the data is too short to contain both
/// fields. Borrow errors from `try_borrow_data` are propagated.
///
/// NOTE(review): only the data length is checked — not the account's owner or
/// discriminator. Confirm that callers constrain `computation_account` to the real
/// Arcium computation account, since these bytes are part of the signed message.
fn get_slot_and_slot_counter_bytes(
    computation_account: &UncheckedAccount,
) -> Result<([u8; 8], [u8; 2])> {
    // Offsets into the serialized `ComputationAccount`, discriminator included. The
    // unit test `test_get_slot_and_slot_counter_bytes` pins these against a real
    // serialization.
    const SLOT_OFFSET: usize = 100;
    const SLOT_COUNTER_OFFSET: usize = 108;

    let account_data = computation_account.try_borrow_data()?;
    let slot: Option<[u8; 8]> = account_data
        .get(SLOT_OFFSET..SLOT_OFFSET + 8)
        .and_then(|bytes| bytes.try_into().ok());
    let slot_counter: Option<[u8; 2]> = account_data
        .get(SLOT_COUNTER_OFFSET..SLOT_COUNTER_OFFSET + 2)
        .and_then(|bytes| bytes.try_into().ok());
    match (slot, slot_counter) {
        (Some(slot), Some(slot_counter)) => Ok((slot, slot_counter)),
        _ => Err(ArciumError::InvalidComputationAccount.into()),
    }
}
#[cfg(feature = "idl-build")]
impl<T: AnchorSerialize + AnchorDeserialize> anchor_lang::idl::build::IdlBuild for ShortVec<T> {
    /// IDL description of `ShortVec<T>`: a struct, generic over `T`, with a single field
    /// `data: Vec<T>`.
    ///
    /// NOTE(review): an IDL `Vec` is Borsh-encoded with a `u32` length prefix, while
    /// `ShortVec`'s `AnchorSerialize` impl writes a `u16` prefix. IDL-driven clients may
    /// therefore mis-decode this type — confirm how clients consume it.
    fn create_type() -> Option<anchor_lang::idl::types::IdlTypeDef> {
        use anchor_lang::idl::types::{
            IdlDefinedFields, IdlField, IdlSerialization, IdlType, IdlTypeDef,
            IdlTypeDefGeneric, IdlTypeDefTy,
        };
        Some(IdlTypeDef {
            name: Self::get_full_path(),
            docs: vec![],
            serialization: IdlSerialization::default(),
            repr: None,
            generics: vec![IdlTypeDefGeneric::Type { name: "T".into() }],
            ty: IdlTypeDefTy::Struct {
                fields: Some(IdlDefinedFields::Named(vec![IdlField {
                    name: "data".into(),
                    docs: vec![],
                    ty: IdlType::Vec(Box::new(IdlType::Generic("T".into()))),
                }])),
            },
        })
    }

    /// Registers no additional types.
    fn insert_types(
        _types: &mut std::collections::BTreeMap<String, anchor_lang::idl::types::IdlTypeDef>,
    ) {
    }

    /// Name under which this type appears in the IDL.
    fn get_full_path() -> String {
        "ShortVec".to_string()
    }
}
#[cfg(feature = "idl-build")]
impl<O: HasSize + AnchorSerialize + AnchorDeserialize> anchor_lang::idl::build::IdlBuild
    for SignedComputationOutputs<O>
{
    /// IDL description of `SignedComputationOutputs<O>`: an enum, generic over `O`, with
    /// variants `Success(O, [u8; 64])`, `Failure` and `MarkerForIdlBuildDoNotUseThis(O)`.
    ///
    /// `Success` is described with a generic `O` rather than raw bytes. This matches the
    /// wire format as long as `O`'s Borsh encoding is exactly `O::SIZE` bytes.
    fn create_type() -> Option<anchor_lang::idl::types::IdlTypeDef> {
        use anchor_lang::idl::types::{
            IdlArrayLen, IdlDefinedFields, IdlEnumVariant, IdlSerialization, IdlType,
            IdlTypeDef, IdlTypeDefGeneric, IdlTypeDefTy,
        };
        Some(IdlTypeDef {
            name: Self::get_full_path(),
            docs: vec![],
            serialization: IdlSerialization::default(),
            repr: None,
            generics: vec![IdlTypeDefGeneric::Type { name: "O".into() }],
            ty: IdlTypeDefTy::Enum {
                variants: vec![
                    IdlEnumVariant {
                        name: "Success".into(),
                        fields: Some(IdlDefinedFields::Tuple(vec![
                            IdlType::Generic("O".into()),
                            IdlType::Array(Box::new(IdlType::U8), IdlArrayLen::Value(64)),
                        ])),
                    },
                    IdlEnumVariant {
                        name: "Failure".into(),
                        fields: None,
                    },
                    IdlEnumVariant {
                        name: "MarkerForIdlBuildDoNotUseThis".into(),
                        fields: Some(IdlDefinedFields::Tuple(vec![IdlType::Generic(
                            "O".into(),
                        )])),
                    },
                ],
            },
        })
    }

    /// Registers no additional types.
    fn insert_types(
        _types: &mut std::collections::BTreeMap<String, anchor_lang::idl::types::IdlTypeDef>,
    ) {
    }

    /// Name under which this type appears in the IDL.
    fn get_full_path() -> String {
        "SignedComputationOutputs".to_string()
    }
}
/// Expands to the bytes of an identifier's name as a `&'static [u8]`;
/// e.g. `derive_seed!(Cluster)` is equivalent to `b"Cluster"`. Usable in `const` items.
#[macro_export]
macro_rules! derive_seed {
($name:ident) => {
stringify!($name).as_bytes()
};
}
/// Computes the computation-definition offset for an instruction name at compile time.
///
/// The offset is the first four bytes of `SHA-256(conf_ix_name)`, read as a
/// little-endian `u32`.
pub const fn comp_def_offset(conf_ix_name: &str) -> u32 {
    let digest = ::sha2_const_stable::Sha256::new()
        .update(conf_ix_name.as_bytes())
        .finalize();
    let prefix = [digest[0], digest[1], digest[2], digest[3]];
    u32::from_le_bytes(prefix)
}
/// Seed prefix of the MXE account PDA (`b"MXEAccount"`); see `derive_mxe_pda!`.
pub const MXE_PDA_SEED: &[u8] = derive_seed!(MXEAccount);
/// Seed prefix of the mempool PDA; see `derive_mempool_pda!`.
pub const MEMPOOL_PDA_SEED: &[u8] = b"Mempool";
/// Seed prefix of the execution-pool PDA; see `derive_execpool_pda!`.
pub const EXECPOOL_PDA_SEED: &[u8] = b"Execpool";
/// Seed prefix of computation account PDAs (`b"ComputationAccount"`); see `derive_comp_pda!`.
pub const COMP_PDA_SEED: &[u8] = derive_seed!(ComputationAccount);
/// Seed prefix of computation definition PDAs (`b"ComputationDefinitionAccount"`);
/// see `derive_comp_def_pda!`.
pub const COMP_DEF_PDA_SEED: &[u8] = derive_seed!(ComputationDefinitionAccount);
/// Seed prefix of cluster PDAs (`b"Cluster"`); see `derive_cluster_pda!`.
pub const CLUSTER_PDA_SEED: &[u8] = derive_seed!(Cluster);
/// Seed of the fee pool PDA (`b"FeePool"`).
pub const POOL_PDA_SEED: &[u8] = derive_seed!(FeePool);
/// Seed of the clock account PDA (`b"ClockAccount"`).
pub const CLOCK_PDA_SEED: &[u8] = derive_seed!(ClockAccount);
/// Seed of this program's signer PDA (`b"ArciumSignerAccount"`); `queue_computation`
/// signs its CPI with it.
pub const SIGN_PDA_SEED: &[u8] = derive_seed!(ArciumSignerAccount);
/// Arcium clock account address, taken from the precomputed `CLOCK_PDA`.
pub const ARCIUM_CLOCK_ACCOUNT_ADDRESS: Pubkey = CLOCK_PDA.0;
/// Arcium fee pool account address, taken from the precomputed `FEE_POOL_PDA`.
pub const ARCIUM_FEE_POOL_ACCOUNT_ADDRESS: Pubkey = FEE_POOL_PDA.0;
/// Program ID of Solana's address lookup table program.
pub const LUT_PROGRAM_ID: Pubkey = solana_address_lookup_table_interface::program::ID;
// PDA derivation macros. They refer to `Pubkey`, the `*_PDA_SEED` constants,
// `ARCIUM_PROG_ID`, and `ID`/`ID_CONST` by bare name, so those names must be in scope
// at the invocation site (e.g. via the prelude plus the program's `declare_id!`).
// Every expansion calls `Pubkey::find_program_address`, i.e. it performs a bump
// search at runtime.

/// Address of this program's MXE account: seeds `[MXE_PDA_SEED, ID]` under the Arcium program.
#[macro_export]
macro_rules! derive_mxe_pda {
() => {
Pubkey::find_program_address(&[MXE_PDA_SEED, ID.to_bytes().as_ref()], &ARCIUM_PROG_ID).0
};
}
/// Mempool PDA for the MXE's cluster: seeds `[MEMPOOL_PDA_SEED, cluster (LE bytes)]`.
///
/// `$mxe_account.cluster` is an `Option`; when it is `None`, `$error_path` is returned
/// through `?`, so this must be expanded inside a function returning a compatible `Result`.
#[macro_export]
macro_rules! derive_mempool_pda {
($mxe_account:expr, $error_path:expr) => {
Pubkey::find_program_address(
&[
MEMPOOL_PDA_SEED,
&$mxe_account.cluster.ok_or($error_path)?.to_le_bytes(),
],
&ARCIUM_PROG_ID,
)
.0
};
}
/// Execution-pool PDA for the MXE's cluster: seeds `[EXECPOOL_PDA_SEED, cluster (LE bytes)]`.
/// Same `?`-based error handling as `derive_mempool_pda!`.
#[macro_export]
macro_rules! derive_execpool_pda {
($mxe_account:expr, $error_path:expr) => {
Pubkey::find_program_address(
&[
EXECPOOL_PDA_SEED,
&$mxe_account.cluster.ok_or($error_path)?.to_le_bytes(),
],
&ARCIUM_PROG_ID,
)
.0
};
}
/// Computation account PDA: seeds
/// `[COMP_PDA_SEED, cluster (LE bytes), computation offset (LE bytes)]`.
/// Same `?`-based error handling as `derive_mempool_pda!`.
#[macro_export]
macro_rules! derive_comp_pda {
($computation_offset:expr, $mxe_account:expr, $error_path:expr) => {
Pubkey::find_program_address(
&[
COMP_PDA_SEED,
&$mxe_account.cluster.ok_or($error_path)?.to_le_bytes(),
&$computation_offset.to_le_bytes(),
],
&ARCIUM_PROG_ID,
)
.0
};
}
/// Computation definition PDA: seeds `[COMP_DEF_PDA_SEED, ID_CONST, offset (LE bytes)]`.
///
/// NOTE(review): despite its name, `$conf_ix_name` must be an integer (it is
/// `.to_le_bytes()`-ed) — presumably the value returned by `comp_def_offset`; confirm.
#[macro_export]
macro_rules! derive_comp_def_pda {
($conf_ix_name:expr) => {
Pubkey::find_program_address(
&[
COMP_DEF_PDA_SEED,
&ID_CONST.to_bytes(),
&$conf_ix_name.to_le_bytes(),
],
&ARCIUM_PROG_ID,
)
.0
};
}
/// Cluster PDA for the MXE's cluster: seeds `[CLUSTER_PDA_SEED, cluster (LE bytes)]`.
/// Same `?`-based error handling as `derive_mempool_pda!`.
#[macro_export]
macro_rules! derive_cluster_pda {
($mxe_account:expr, $error_path:expr) => {
Pubkey::find_program_address(
&[
CLUSTER_PDA_SEED,
&$mxe_account.cluster.ok_or($error_path)?.to_le_bytes(),
],
&ARCIUM_PROG_ID,
)
.0
};
}
/// This program's signer PDA: seed `[SIGN_PDA_SEED]` under the calling program's own
/// ID (`ID_CONST`), not the Arcium program.
#[macro_export]
macro_rules! derive_sign_pda {
() => {
Pubkey::find_program_address(&[SIGN_PDA_SEED], &ID_CONST).0
};
}
/// Address lookup table owned by the MXE PDA (from `derive_mxe_pda!`) at `$lut_offset`.
#[macro_export]
macro_rules! derive_mxe_lut_pda {
($lut_offset:expr) => {{
let mxe_pda = derive_mxe_pda!();
::arcium_anchor::solana_address_lookup_table_interface::instruction::derive_lookup_table_address(&mxe_pda, $lut_offset).0
}};
}
include!("arg_builder.rs");
include!("arg_match_param.rs");
/// Checks that a computation's arguments and accounts match the circuit's declared
/// parameters, panicking via the error's `const_panic` on mismatch.
///
/// Being a `const fn`, it is intended for const contexts, where that panic becomes a
/// compile-time error.
pub const fn const_match_computation(
    arguments: &[ArgumentRef],
    accounts: &[AccountArgument],
    parameters: &[Parameter],
) {
    match args_match_params(arguments, accounts, parameters) {
        Ok(_) => {}
        Err(mismatch) => {
            mismatch.const_panic();
        }
    }
}
// Unit tests for the seed/offset constants, the raw slot readers, and the
// hand-written `SignedComputationOutputs` wire format.
#[cfg(test)]
mod tests {
use super::*;
use arcium_client::idl::arcium::{
accounts::ComputationAccount,
types::{ComputationStatus, ExecutionFee},
ID_CONST as ARCIUM_PROG_ID,
};
use std::{cell::RefCell, rc::Rc};
// Derives an address under the Arcium program from `seeds`, discarding the bump.
fn derive_arcium_pda(seeds: &[&[u8]]) -> Pubkey {
Pubkey::find_program_address(seeds, &ARCIUM_PROG_ID).0
}
// Pins the offset of a known instruction name so changes to the hashing scheme are caught.
#[test]
fn test_comp_def_offset() {
let conf_ix_name = "add_together";
let offset = comp_def_offset(conf_ix_name);
assert_eq!(offset, 4005749700);
}
// The precomputed clock and fee-pool addresses must match the PDAs derived from their seeds.
#[test]
fn test_clock_account_address() {
let address = derive_arcium_pda(&[CLOCK_PDA_SEED]);
assert_eq!(address, ARCIUM_CLOCK_ACCOUNT_ADDRESS);
}
#[test]
fn test_fee_pool_account_address() {
let address = derive_arcium_pda(&[POOL_PDA_SEED]);
assert_eq!(address, ARCIUM_FEE_POOL_ACCOUNT_ADDRESS);
}
// Serializes a real `ComputationAccount` with `try_serialize` and checks that the
// hard-coded offsets in `get_slot_and_slot_counter_bytes` read back the same values.
#[test]
fn test_get_slot_and_slot_counter_bytes() {
let computation_account = ComputationAccount {
payer: Pubkey::default(),
mxe_program_id: Pubkey::default(),
computation_definition_offset: 0,
execution_fee: ExecutionFee {
base_fee: 0,
priority_fee: 0,
output_delivery_fee: 0,
},
slot: 12345,
slot_counter: 5678,
status: ComputationStatus::Queued,
arguments: ArgumentList {
args: vec![],
byte_arrays: vec![],
plaintext_numbers: vec![],
values_128_bit: vec![],
accounts: vec![],
},
custom_callback_instructions: Vec::new(),
callback_transactions_required: 0,
callback_transactions_submitted_bm: 0,
bump: 0,
};
let mut key = Pubkey::default();
let mut lamports = 0;
let mut data = vec![];
let mut owner = Pubkey::default();
computation_account.try_serialize(&mut data).unwrap();
// Minimal in-memory `AccountInfo` wrapping the serialized bytes.
let account_info = AccountInfo {
key: &mut key,
lamports: Rc::new(RefCell::new(&mut lamports)),
data: Rc::new(RefCell::new(&mut data)),
owner: &mut owner,
rent_epoch: 0,
is_signer: false,
is_writable: false,
executable: false,
};
let computation_account = UncheckedAccount::try_from(&account_info);
let (slot_bytes, slot_counter_bytes) =
get_slot_and_slot_counter_bytes(&computation_account).unwrap();
let slot = u64::from_le_bytes(slot_bytes);
let slot_counter = u16::from_le_bytes(slot_counter_bytes);
assert_eq!(slot, 12345);
assert_eq!(slot_counter, 5678);
}
// Output type whose Borsh encoding (a single u64) is exactly `SIZE` = 8 bytes.
#[derive(Debug, PartialEq, AnchorSerialize, AnchorDeserialize)]
struct TestOutput {
value: u64,
}
impl HasSize for TestOutput {
const SIZE: usize = 8;
}
// `Success` round-trips as tag (1) + payload (SIZE) + signature (64).
#[test]
fn test_signed_computation_outputs_roundtrip_success() {
let original = SignedComputationOutputs::<TestOutput>::Success(
vec![1, 2, 3, 4, 5, 6, 7, 8],
[42u8; 64],
);
let mut buf = Vec::new();
original.serialize(&mut buf).unwrap();
assert_eq!(buf.len(), 1 + 8 + 64);
assert_eq!(buf[0], 0);
let deserialized =
SignedComputationOutputs::<TestOutput>::deserialize(&mut &buf[..]).unwrap();
match deserialized {
SignedComputationOutputs::Success(bytes, sig) => {
assert_eq!(bytes, vec![1, 2, 3, 4, 5, 6, 7, 8]);
assert_eq!(sig, [42u8; 64]);
}
_ => panic!("Expected Success variant"),
}
}
// The payload follows the tag directly; unlike a Borsh `Vec`, it has no length prefix.
#[test]
fn test_signed_computation_outputs_no_length_prefix() {
let output = SignedComputationOutputs::<TestOutput>::Success(vec![0xAA; 8], [0xBB; 64]);
let mut buf = Vec::new();
output.serialize(&mut buf).unwrap();
assert_eq!(&buf[1..9], &[0xAA; 8]);
assert_eq!(&buf[9..73], &[0xBB; 64]);
}
// `Failure` is the single tag byte 1.
#[test]
fn test_signed_computation_outputs_roundtrip_failure() {
let original = SignedComputationOutputs::<TestOutput>::Failure;
let mut buf = Vec::new();
original.serialize(&mut buf).unwrap();
assert_eq!(buf.len(), 1);
assert_eq!(buf[0], 1);
let deserialized =
SignedComputationOutputs::<TestOutput>::deserialize(&mut &buf[..]).unwrap();
assert!(matches!(deserialized, SignedComputationOutputs::Failure));
}
// Serializing a `Success` whose payload length differs from `SIZE` must fail.
#[test]
fn test_signed_computation_outputs_serialize_validates_size() {
let invalid = SignedComputationOutputs::<TestOutput>::Success(vec![1, 2, 3], [0u8; 64]);
let mut buf = Vec::new();
let result = invalid.serialize(&mut buf);
assert!(result.is_err());
}
}