//! arcium-anchor 0.4.0
//!
//! A helper crate for integrating Arcium into Solana programs.
//!
//! Documentation
use anchor_lang::prelude::*;
use arcium_client::{
    idl::arcium::{
        cpi::{accounts::InitComputationDefinition, init_computation_definition},
        types::{
            Argument,
            CallbackInstruction,
            CircuitSource,
            ComputationDefinitionMeta,
            ComputationSignature,
            Parameter,
        },
    },
    pda::{CLOCK_PDA, FEE_POOL_PDA},
};
use traits::{InitCompDefAccs, QueueCompAccs};

pub mod traits;

/// Convenience re-exports for programs that integrate with Arcium.
///
/// Brings into scope everything declared at this crate's root, a selection of
/// `arcium_client` IDL items, the macros exported by `arcium_macros`, and the
/// `CallbackCompAccs` trait.
pub mod prelude {
    // Everything declared at the crate root (helpers, seed constants, types, ...).
    pub use super::*;
    // Arcium IDL accounts, program type and `Argument`. The IDL module's
    // `ID_CONST` is re-exported under the alias `ARCIUM_PROG_ID`, which is the
    // name the PDA-derivation macros in this crate use as the program id.
    pub use arcium_client::idl::arcium::{
        accounts::{ClockAccount, Cluster, ComputationDefinitionAccount, FeePool, MXEAccount},
        program::Arcium,
        types::Argument,
        ID_CONST as ARCIUM_PROG_ID,
    };
    // Macros provided by the companion `arcium_macros` crate.
    pub use arcium_macros::{
        arcium_callback,
        arcium_program,
        callback_accounts,
        check_args,
        init_computation_definition_accounts,
        queue_computation_accounts,
    };
    // Trait from this crate's `traits` module.
    pub use traits::CallbackCompAccs;
}

/// Anchor-(de)serializable container holding a 32-byte `encryption_key`, a
/// `u128` nonce and `LEN` ciphertexts of 32 bytes each.
///
/// NOTE(review): the name suggests data encrypted under a key shared between a
/// client and the MXE — confirm against the Arcium documentation.
#[derive(AnchorSerialize, AnchorDeserialize)]
pub struct SharedEncryptedStruct<const LEN: usize> {
    /// 32-byte key; presumably the public key used for the shared encryption — TODO confirm.
    pub encryption_key: [u8; 32],
    /// Nonce stored alongside the ciphertexts.
    pub nonce: u128,
    /// `LEN` ciphertexts, 32 bytes each.
    pub ciphertexts: [[u8; 32]; LEN],
}

/// Anchor-(de)serializable container holding a `u128` nonce and `LEN`
/// ciphertexts of 32 bytes each (no encryption key field).
///
/// NOTE(review): the name suggests data encrypted for the MXE itself — confirm
/// against the Arcium documentation.
#[derive(AnchorSerialize, AnchorDeserialize)]
pub struct MXEEncryptedStruct<const LEN: usize> {
    /// Nonce stored alongside the ciphertexts.
    pub nonce: u128,
    /// `LEN` ciphertexts, 32 bytes each.
    pub ciphertexts: [[u8; 32]; LEN],
}

/// Anchor-(de)serializable container holding only `LEN` ciphertexts of
/// 32 bytes each (no key and no nonce).
#[derive(AnchorSerialize, AnchorDeserialize)]
pub struct EncDataStruct<const LEN: usize> {
    /// `LEN` ciphertexts, 32 bytes each.
    pub ciphertexts: [[u8; 32]; LEN],
}

// A type matching this one is present in arx/src/utils/arx_computation_output.rs.
// When making changes here (e.g. adding a new variant), make sure to update the
// arx version as well.
/// Result of a computation: either a successful output of type `O` or a
/// failure that carries no payload.
#[derive(Debug, AnchorSerialize, AnchorDeserialize)]
pub enum ComputationOutputs<O> {
    /// The computation succeeded and produced an output value.
    Success(O),
    /// The computation failed; no further detail is carried.
    Failure,
}

/// Queues a computation on the Arcium program through a signed CPI.
///
/// The CPI is signed with seeds made of [`SIGN_PDA_SEED`] followed by the bump
/// returned by `accs.signer_pda_bump()`.
///
/// # Arguments
/// * `accs` - account bundle providing the Arcium program, the CPI accounts,
///   the signer PDA bump, the computation definition offset and the MXE program.
/// * `computation_offset` - forwarded unchanged to the CPI.
/// * `args` - computation arguments, forwarded unchanged.
/// * `callback_url` - optional callback URL, forwarded unchanged.
/// * `callback_instructions` - callback instructions, forwarded unchanged.
/// * `num_callback_txs` - number of callback transactions, forwarded unchanged.
///
/// The remaining positional CPI arguments are fixed by this helper to `None`,
/// `0` and `0`.
///
/// # Errors
/// Returns whatever error the underlying `queue_computation` CPI returns.
pub fn queue_computation<'info, T>(
    accs: &T,
    computation_offset: u64,
    args: Vec<Argument>,
    callback_url: Option<String>,
    callback_instructions: Vec<CallbackInstruction>,
    num_callback_txs: u8,
) -> Result<()>
where
    T: QueueCompAccs<'info>,
{
    // Signer seeds: the fixed seed followed by the single bump byte.
    let bump = [accs.signer_pda_bump()];
    let pda_seeds: [&[u8]; 2] = [SIGN_PDA_SEED, &bump];
    let signers: [&[&[u8]]; 1] = [&pda_seeds];

    let program = accs.arcium_program();
    let accounts = accs.queue_comp_accs();
    let ctx = CpiContext::new_with_signer(program, accounts, &signers);

    let definition_offset = accs.comp_def_offset();
    let mxe_program = accs.mxe_program();

    arcium_client::idl::arcium::cpi::queue_computation(
        ctx,
        computation_offset,
        definition_offset,
        None,
        args,
        mxe_program,
        callback_url,
        callback_instructions,
        num_callback_txs,
        0,
        0,
    )
}

/// Initializes a computation definition on the Arcium program via CPI.
///
/// Builds the `InitComputationDefinition` accounts and a
/// `ComputationDefinitionMeta` (circuit length plus parameter/output
/// signature) from `accs`, then invokes `init_computation_definition`.
///
/// # Arguments
/// * `accs` - account bundle providing the accounts and definition metadata.
/// * `cu_amount` - forwarded unchanged to the CPI.
/// * `circuit_source_override` - optional circuit source, forwarded unchanged.
/// * `finalize_authority` - optional authority key, forwarded unchanged.
///
/// # Errors
/// Propagates any error returned by the CPI.
pub fn init_comp_def<'info, T>(
    accs: &T,
    cu_amount: u64,
    circuit_source_override: Option<CircuitSource>,
    finalize_authority: Option<Pubkey>,
) -> Result<()>
where
    T: InitCompDefAccs<'info>,
{
    let program = accs.arcium_program();
    let accounts = InitComputationDefinition {
        signer: accs.signer(),
        system_program: accs.system_program(),
        mxe: accs.mxe_acc(),
        comp_def_acc: accs.comp_def_acc(),
    };
    let ctx = CpiContext::new(program, accounts);

    // The signature is listed first so the accessors are still called in the
    // order params -> outputs -> compiled_circuit_len.
    let definition = ComputationDefinitionMeta {
        signature: ComputationSignature {
            parameters: accs.params(),
            outputs: accs.outputs(),
        },
        circuit_len: accs.compiled_circuit_len(),
    };

    let definition_offset = accs.comp_def_offset();
    let mxe_program = accs.mxe_program();

    init_computation_definition(
        ctx,
        definition_offset,
        mxe_program,
        definition,
        circuit_source_override,
        cu_amount,
        finalize_authority,
    )?;

    Ok(())
}

/// Wrapper around `Vec<T>` with hand-written Anchor (de)serialization that
/// encodes the element count as a `u16` (see the `AnchorSerialize` /
/// `AnchorDeserialize` impls for this type).
pub struct ShortVec<T: AnchorSerialize + AnchorDeserialize> {
    /// The wrapped elements; serialization fails if there are more than `u16::MAX`.
    pub data: Vec<T>,
}

impl<T: AnchorSerialize + AnchorDeserialize> AnchorSerialize for ShortVec<T> {
    /// Writes the element count as a `u16`, followed by every element in order.
    ///
    /// # Errors
    /// Returns an `InvalidInput` I/O error when the element count does not fit
    /// in a `u16`, and propagates any error raised while writing.
    fn serialize<W: std::io::Write>(
        &self,
        writer: &mut W,
    ) -> std::result::Result<(), std::io::Error> {
        let count: u16 = match self.data.len().try_into() {
            Ok(count) => count,
            Err(_) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Length too large, must fit in u16",
                ));
            }
        };

        count.serialize(writer)?;
        self.data
            .iter()
            .try_for_each(|element| element.serialize(writer))
    }
}

impl<T: AnchorSerialize + AnchorDeserialize> AnchorDeserialize for ShortVec<T> {
    /// Reads a `u16` element count, then that many `T` values in order.
    ///
    /// # Errors
    /// Propagates the first error raised while reading the count or an element.
    fn deserialize_reader<R: std::io::Read>(
        reader: &mut R,
    ) -> std::result::Result<Self, std::io::Error> {
        let count = u16::deserialize_reader(reader)?;
        // Reserve room for every element up front; the count is at most u16::MAX.
        let mut data = Vec::with_capacity(usize::from(count));
        (0..count)
            .try_for_each(|_| T::deserialize_reader(reader).map(|element| data.push(element)))?;
        Ok(Self { data })
    }
}

#[cfg(feature = "idl-build")]
impl<T: AnchorSerialize + AnchorDeserialize> anchor_lang::idl::build::IdlBuild for ShortVec<T> {
    /// Describes `ShortVec` in the IDL as a generic struct with a single named
    /// field `data` of type `Vec<T>`.
    fn create_type() -> Option<anchor_lang::idl::types::IdlTypeDef> {
        use anchor_lang::idl::types::{
            IdlDefinedFields,
            IdlField,
            IdlSerialization,
            IdlType,
            IdlTypeDef,
            IdlTypeDefGeneric,
            IdlTypeDefTy,
        };

        let data_field = IdlField {
            name: "data".into(),
            docs: vec![],
            ty: IdlType::Vec(Box::new(IdlType::Generic("T".into()))),
        };

        Some(IdlTypeDef {
            name: Self::get_full_path(),
            docs: vec![],
            serialization: IdlSerialization::default(),
            repr: None,
            generics: vec![IdlTypeDefGeneric::Type { name: "T".into() }.into()],
            ty: IdlTypeDefTy::Struct {
                fields: Some(IdlDefinedFields::Named(vec![data_field])),
            },
        })
    }

    /// Intentionally registers no additional types.
    fn insert_types(
        _types: &mut std::collections::BTreeMap<String, anchor_lang::idl::types::IdlTypeDef>,
    ) {
    }

    /// Returns the name under which this type appears in the IDL.
    fn get_full_path() -> String {
        "ShortVec".to_string()
    }
}

/// Turns an identifier into a byte-string seed: `derive_seed!(Foo)` expands to
/// `stringify!(Foo).as_bytes()`, i.e. the UTF-8 bytes of the identifier's name.
///
/// Used in this crate to initialize `const` seed values.
#[macro_export]
macro_rules! derive_seed {
    ($name:ident) => {
        stringify!($name).as_bytes()
    };
}

/// Computes the `u32` offset associated with a confidential instruction name.
///
/// The offset is the first four bytes of the SHA-256 digest of
/// `conf_ix_name`, interpreted as a little-endian `u32`. Being a `const fn`,
/// it can be evaluated at compile time.
pub const fn comp_def_offset(conf_ix_name: &str) -> u32 {
    let digest = ::sha2_const_stable::Sha256::new()
        .update(conf_ix_name.as_bytes())
        .finalize();
    // Only the leading four digest bytes contribute to the offset.
    let prefix = [digest[0], digest[1], digest[2], digest[3]];
    u32::from_le_bytes(prefix)
}

// PDA seed constants. Seeds built with `derive_seed!` are the UTF-8 bytes of
// the given identifier (e.g. `MXE_PDA_SEED` is `b"MXEAccount"`); the mempool
// and execpool seeds are spelled out as byte-string literals.
/// Seed used by `derive_mxe_pda!`.
pub const MXE_PDA_SEED: &[u8] = derive_seed!(MXEAccount);
/// Seed used by `derive_mempool_pda!`.
pub const MEMPOOL_PDA_SEED: &[u8] = b"Mempool";
/// Seed used by `derive_execpool_pda!`.
pub const EXECPOOL_PDA_SEED: &[u8] = b"Execpool";
/// Seed used by `derive_comp_pda!`.
pub const COMP_PDA_SEED: &[u8] = derive_seed!(ComputationAccount);
/// Seed used by `derive_comp_def_pda!`.
pub const COMP_DEF_PDA_SEED: &[u8] = derive_seed!(ComputationDefinitionAccount);
/// Seed used by `derive_cluster_pda!`.
pub const CLUSTER_PDA_SEED: &[u8] = derive_seed!(Cluster);
/// Seed of the fee pool PDA (see `ARCIUM_FEE_POOL_ACCOUNT_ADDRESS`).
pub const POOL_PDA_SEED: &[u8] = derive_seed!(FeePool);
/// Seed of the clock PDA (see `ARCIUM_CLOCK_ACCOUNT_ADDRESS`).
pub const CLOCK_PDA_SEED: &[u8] = derive_seed!(ClockAccount);
/// Seed used by `derive_sign_pda!` and for the CPI signer seeds in `queue_computation`.
pub const SIGN_PDA_SEED: &[u8] = derive_seed!(SignerAccount);

/// Address of the Arcium clock account, taken from `arcium_client`'s
/// `CLOCK_PDA`. The unit tests in this file check that it equals the PDA
/// derived from `[CLOCK_PDA_SEED]` under the Arcium program id.
pub const ARCIUM_CLOCK_ACCOUNT_ADDRESS: Pubkey = CLOCK_PDA.0;
/// Address of the Arcium fee pool account, taken from `arcium_client`'s
/// `FEE_POOL_PDA`. The unit tests in this file check that it equals the PDA
/// derived from `[POOL_PDA_SEED]` under the Arcium program id.
pub const ARCIUM_FEE_POOL_ACCOUNT_ADDRESS: Pubkey = FEE_POOL_PDA.0;

/// Expands to the PDA (bump discarded) derived from
/// `[MXE_PDA_SEED, ID.to_bytes()]` under `ARCIUM_PROG_ID`.
///
/// The expansion refers to `Pubkey`, `MXE_PDA_SEED`, `ID` and `ARCIUM_PROG_ID`
/// by bare name, so all of them must be in scope at the call site. `ID` is
/// presumably the calling program's id from `declare_id!` — TODO confirm.
#[macro_export]
macro_rules! derive_mxe_pda {
    () => {
        Pubkey::find_program_address(&[MXE_PDA_SEED, ID.to_bytes().as_ref()], &ARCIUM_PROG_ID).0
    };
}

/// Expands to the PDA (bump discarded) derived from
/// `[MEMPOOL_PDA_SEED, ID.to_bytes()]` under `ARCIUM_PROG_ID`.
///
/// The expansion refers to `Pubkey`, `MEMPOOL_PDA_SEED`, `ID` and
/// `ARCIUM_PROG_ID` by bare name, so all of them must be in scope at the call
/// site.
#[macro_export]
macro_rules! derive_mempool_pda {
    () => {
        Pubkey::find_program_address(&[MEMPOOL_PDA_SEED, ID.to_bytes().as_ref()], &ARCIUM_PROG_ID).0
    };
}

/// Expands to the PDA (bump discarded) derived from
/// `[EXECPOOL_PDA_SEED, ID.to_bytes()]` under `ARCIUM_PROG_ID`.
///
/// The expansion refers to `Pubkey`, `EXECPOOL_PDA_SEED`, `ID` and
/// `ARCIUM_PROG_ID` by bare name, so all of them must be in scope at the call
/// site.
#[macro_export]
macro_rules! derive_execpool_pda {
    () => {
        Pubkey::find_program_address(
            &[EXECPOOL_PDA_SEED, ID.to_bytes().as_ref()],
            &ARCIUM_PROG_ID,
        )
        .0
    };
}

/// Expands to the PDA (bump discarded) derived from
/// `[COMP_PDA_SEED, ID.to_bytes(), <offset>.to_le_bytes()]` under
/// `ARCIUM_PROG_ID`.
///
/// `$computation_offset` must be an expression whose type has `to_le_bytes()`
/// (an integer). The expansion refers to `Pubkey`, `COMP_PDA_SEED`, `ID` and
/// `ARCIUM_PROG_ID` by bare name, so all of them must be in scope at the call
/// site.
#[macro_export]
macro_rules! derive_comp_pda {
    ($computation_offset:expr) => {
        Pubkey::find_program_address(
            &[
                COMP_PDA_SEED,
                ID.to_bytes().as_ref(),
                &$computation_offset.to_le_bytes(),
            ],
            &ARCIUM_PROG_ID,
        )
        .0
    };
}

/// Expands to the PDA (bump discarded) derived from
/// `[COMP_DEF_PDA_SEED, ID_CONST.to_bytes(), <arg>.to_le_bytes()]` under
/// `ARCIUM_PROG_ID`.
///
/// Despite its name, `$conf_ix_name` must be an expression whose type has
/// `to_le_bytes()`, i.e. an integer rather than a string — presumably the
/// `u32` produced by `comp_def_offset`; TODO confirm against callers.
///
/// The expansion refers to `Pubkey`, `COMP_DEF_PDA_SEED`, `ID_CONST` and
/// `ARCIUM_PROG_ID` by bare name, so all of them must be in scope at the call
/// site. Note that this macro uses `ID_CONST`, whereas the MXE, mempool,
/// execpool and computation macros use `ID`.
#[macro_export]
macro_rules! derive_comp_def_pda {
    ($conf_ix_name:expr) => {
        Pubkey::find_program_address(
            &[
                COMP_DEF_PDA_SEED,
                &ID_CONST.to_bytes(),
                &$conf_ix_name.to_le_bytes(),
            ],
            &ARCIUM_PROG_ID,
        )
        .0
    };
}

/// Expands to the PDA (bump discarded) derived from
/// `[CLUSTER_PDA_SEED, <cluster>.to_le_bytes()]` under `ARCIUM_PROG_ID`.
///
/// `$mxe_account.cluster` must be an `Option` whose payload has
/// `to_le_bytes()`. When it is `None`, the expansion returns early from the
/// enclosing function with `$error_path` via `?`, so the macro can only be
/// used where `?` on a `Result` carrying that error is valid.
///
/// The expansion refers to `Pubkey`, `CLUSTER_PDA_SEED` and `ARCIUM_PROG_ID`
/// by bare name, so all of them must be in scope at the call site.
#[macro_export]
macro_rules! derive_cluster_pda {
    ($mxe_account:expr, $error_path:expr) => {
        Pubkey::find_program_address(
            &[
                CLUSTER_PDA_SEED,
                &$mxe_account.cluster.ok_or($error_path)?.to_le_bytes(),
            ],
            &ARCIUM_PROG_ID,
        )
        .0
    };
}

/// Expands to the PDA (bump discarded) derived from `[SIGN_PDA_SEED]` under
/// `ID_CONST`.
///
/// Unlike the other PDA macros in this file, the derivation uses `ID_CONST`
/// rather than `ARCIUM_PROG_ID` as the program id. `Pubkey`, `SIGN_PDA_SEED`
/// and `ID_CONST` are referenced by bare name and must be in scope at the call
/// site; `ID_CONST` is presumably the calling program's id — TODO confirm.
#[macro_export]
macro_rules! derive_sign_pda {
    () => {
        Pubkey::find_program_address(&[SIGN_PDA_SEED], &ID_CONST).0
    };
}
// Textually includes `arg_match_param.rs` into this module; presumably the
// source of `args_match_params` used below — TODO confirm.
include!("arg_match_param.rs");
/// Checks in a `const` context that `arguments` match `parameters`.
///
/// Calls `args_match_params`; on `Err` it invokes the error's `const_panic()`
/// (presumably aborting const evaluation with a descriptive message — TODO
/// confirm), otherwise it returns normally.
pub const fn const_match_computation(arguments: &[Argument], parameters: &[Parameter]) {
    if let Err(err) = args_match_params(arguments, parameters) {
        err.const_panic();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arcium_client::idl::arcium::ID_CONST as ARCIUM_PROG_ID;

    /// Derives the PDA for `seeds` under the Arcium program id, dropping the bump.
    fn derive_arcium_pda(seeds: &[&[u8]]) -> Pubkey {
        let (address, _bump) = Pubkey::find_program_address(seeds, &ARCIUM_PROG_ID);
        address
    }

    /// Pins the offset computed for a known instruction name.
    #[test]
    fn test_comp_def_offset() {
        assert_eq!(comp_def_offset("add_together"), 4005749700);
    }

    /// The exported clock address must equal the PDA derived from its seed.
    #[test]
    fn test_clock_account_address() {
        assert_eq!(
            derive_arcium_pda(&[CLOCK_PDA_SEED]),
            ARCIUM_CLOCK_ACCOUNT_ADDRESS
        );
    }

    /// The exported fee pool address must equal the PDA derived from its seed.
    #[test]
    fn test_fee_pool_account_address() {
        assert_eq!(
            derive_arcium_pda(&[POOL_PDA_SEED]),
            ARCIUM_FEE_POOL_ACCOUNT_ADDRESS
        );
    }
}