use core::mem::MaybeUninit;
use crate::{
account::AccountView,
address::Address,
error::ProgramError,
result::ProgramResult,
};
/// Fixed-capacity builder that collects a program id, a list of account
/// references with per-account flags, and a byte payload.
///
/// `MAX_ACCTS` bounds how many accounts can be pushed and `MAX_DATA` bounds
/// how many payload bytes can be pushed; both are compile-time constants and
/// all storage lives inline in the struct.
///
/// NOTE(review): the name suggests this assembles a cross-program invocation,
/// but nothing in the visible code consumes `accounts`, `writable` or
/// `signer` — confirm against the rest of the crate.
pub struct DynCpi<'a, const MAX_ACCTS: usize, const MAX_DATA: usize> {
/// Address supplied to `new` and handed back by `program_id()`.
program_id: &'a Address,
/// Account slots; only the first `account_count` entries are initialized
/// (each one written by `push_account`).
accounts: [MaybeUninit<&'a AccountView>; MAX_ACCTS],
/// `writable` flag recorded by `push_account` for the account at the same index.
writable: [bool; MAX_ACCTS],
/// `signer` flag recorded by `push_account` for the account at the same index.
signer: [bool; MAX_ACCTS],
/// Number of accounts pushed so far; never exceeds `MAX_ACCTS`.
account_count: usize,
/// Payload storage; only the first `data_len` bytes are initialized
/// (written by `push_data`).
data: [MaybeUninit<u8>; MAX_DATA],
/// Number of payload bytes pushed so far; never exceeds `MAX_DATA`.
data_len: usize,
}
impl<'a, const MAX_ACCTS: usize, const MAX_DATA: usize> DynCpi<'a, MAX_ACCTS, MAX_DATA> {
    /// Creates an empty builder for `program_id`.
    ///
    /// No accounts and no data bytes are recorded yet; the backing arrays are
    /// left uninitialized and are filled in by the `push_*` methods.
    #[inline]
    pub fn new(program_id: &'a Address) -> Self {
        Self {
            program_id,
            accounts: [const { MaybeUninit::uninit() }; MAX_ACCTS],
            writable: [false; MAX_ACCTS],
            signer: [false; MAX_ACCTS],
            account_count: 0,
            data: [const { MaybeUninit::uninit() }; MAX_DATA],
            data_len: 0,
        }
    }

    /// Appends `account` together with its `writable` and `signer` flags.
    ///
    /// # Errors
    ///
    /// Returns [`ProgramError::InvalidArgument`] when `MAX_ACCTS` accounts
    /// have already been pushed; the builder is left unchanged in that case.
    #[inline]
    pub fn push_account(
        &mut self,
        account: &'a AccountView,
        writable: bool,
        signer: bool,
    ) -> ProgramResult {
        let index = self.account_count;
        // `get_mut` yields `None` exactly when `index >= MAX_ACCTS`, i.e. when
        // every slot is already taken.
        let Some(slot) = self.accounts.get_mut(index) else {
            return Err(ProgramError::InvalidArgument);
        };
        *slot = MaybeUninit::new(account);
        self.writable[index] = writable;
        self.signer[index] = signer;
        // `index < MAX_ACCTS` here, so the increment cannot actually wrap.
        self.account_count = index.wrapping_add(1);
        Ok(())
    }

    /// Appends `bytes` to the end of the data buffer.
    ///
    /// # Errors
    ///
    /// Returns [`ProgramError::InvalidArgument`] when the bytes would not fit
    /// in the remaining capacity (`MAX_DATA` total); nothing is written in
    /// that case.
    #[inline]
    pub fn push_data(&mut self, bytes: &[u8]) -> ProgramResult {
        let start = self.data_len;
        // Saturating keeps an absurdly large `bytes.len()` from wrapping
        // around and slipping past the capacity check.
        let end = start.saturating_add(bytes.len());
        if end > MAX_DATA {
            return Err(ProgramError::InvalidArgument);
        }
        // `start..end` has exactly `bytes.len()` slots, so the zip consumes
        // every input byte.
        for (slot, &byte) in self.data[start..end].iter_mut().zip(bytes) {
            *slot = MaybeUninit::new(byte);
        }
        self.data_len = end;
        Ok(())
    }

    /// Appends a single byte to the data buffer.
    ///
    /// # Errors
    ///
    /// Same as [`Self::push_data`].
    #[inline]
    pub fn push_byte(&mut self, byte: u8) -> ProgramResult {
        self.push_data(&[byte])
    }

    /// Appends `value` as eight little-endian bytes.
    ///
    /// # Errors
    ///
    /// Same as [`Self::push_data`].
    #[inline]
    pub fn push_u64_le(&mut self, value: u64) -> ProgramResult {
        let encoded = value.to_le_bytes();
        self.push_data(&encoded)
    }

    /// Appends the bytes returned by `address.as_array()`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::push_data`].
    #[inline]
    pub fn push_pubkey(&mut self, address: &Address) -> ProgramResult {
        let raw = address.as_array();
        self.push_data(raw)
    }

    /// Number of accounts pushed so far.
    #[inline(always)]
    pub const fn account_count(&self) -> usize {
        self.account_count
    }

    /// The program address this builder was created with.
    #[inline(always)]
    pub const fn program_id(&self) -> &Address {
        self.program_id
    }

    /// Number of data bytes pushed so far.
    #[inline(always)]
    pub const fn data_len(&self) -> usize {
        self.data_len
    }

    /// The data bytes pushed so far, in push order.
    #[inline]
    pub fn data(&self) -> &[u8] {
        let base = self.data.as_ptr().cast::<u8>();
        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, `base` points
        // at the start of `self.data`, and `data_len` is only ever advanced by
        // `push_data` after it has written every byte below the new length and
        // checked it against `MAX_DATA`. The first `data_len` bytes are
        // therefore initialized and in bounds, and the returned slice borrows
        // `self` for its whole lifetime.
        unsafe { core::slice::from_raw_parts(base, self.data_len) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single byte followed by a u64 lands back-to-back in the buffer,
    /// with the u64 encoded little-endian.
    #[test]
    fn byte_push_walks_the_buffer() {
        const TAG: u8 = 0xA1;
        const VALUE: u64 = 0xCAFEBABE_u64;

        let program = Address::from([0u8; 32]);
        let mut cpi: DynCpi<4, 32> = DynCpi::new(&program);

        cpi.push_byte(TAG).unwrap();
        cpi.push_u64_le(VALUE).unwrap();

        assert_eq!(cpi.data_len(), 1 + 8);
        let written = cpi.data();
        assert_eq!(written[0], TAG);
        assert_eq!(&written[1..9], &VALUE.to_le_bytes());
    }

    /// Pushing more bytes than `MAX_DATA` allows is rejected.
    #[test]
    fn data_overflow_rejects() {
        let program = Address::from([0u8; 32]);
        let mut cpi: DynCpi<0, 4> = DynCpi::new(&program);

        let outcome = cpi.push_u64_le(1);
        outcome.expect_err("u64 is 8 bytes, buffer is 4");
    }

    /// A pushed address contributes exactly its 32 bytes to the buffer.
    #[test]
    fn push_pubkey_fills_32_bytes() {
        const FILL: u8 = 0x7A;

        let program = Address::from([0u8; 32]);
        let mut cpi: DynCpi<0, 64> = DynCpi::new(&program);
        let pk = Address::from([FILL; 32]);

        cpi.push_pubkey(&pk).unwrap();

        assert_eq!(cpi.data_len(), 32);
        assert!(cpi.data().iter().all(|&b| b == FILL));
    }
}