use bytemuck::Pod;
use solana_program::program_error::ProgramError;
/// Supplies a one-byte tag that identifies the implementing type.
///
/// The blanket [`AccountDeserialize`] implementation in this file compares this
/// tag against the first byte of the input buffer before reinterpreting the
/// rest of the buffer as the type.
pub trait Discriminator {
/// Returns the tag byte for the implementing type.
fn discriminator() -> u8;
}
/// Borrowing conversions from a raw byte buffer to a typed value.
///
/// Both methods return a reference into `data` rather than an owned copy.
pub trait AccountDeserialize {
/// Views `data` as a shared reference to `Self`.
///
/// # Errors
///
/// Returns a [`ProgramError`] when `data` cannot be interpreted as `Self`.
fn try_from_bytes(data: &[u8]) -> Result<&Self, ProgramError>;
/// Views `data` as an exclusive reference to `Self`, so that writes through
/// the result land directly in `data`.
///
/// # Errors
///
/// Returns a [`ProgramError`] when `data` cannot be interpreted as `Self`.
fn try_from_bytes_mut(data: &mut [u8]) -> Result<&mut Self, ProgramError>;
}
impl<T> AccountDeserialize for T
where
T: Discriminator + Pod,
{
fn try_from_bytes(data: &[u8]) -> Result<&Self, ProgramError> {
if Self::discriminator().ne(&data[0]) {
return Err(solana_program::program_error::ProgramError::InvalidAccountData);
}
bytemuck::try_from_bytes::<Self>(&data[8..]).or(Err(
solana_program::program_error::ProgramError::InvalidAccountData,
))
}
fn try_from_bytes_mut(data: &mut [u8]) -> Result<&mut Self, ProgramError> {
if Self::discriminator().ne(&data[0]) {
return Err(solana_program::program_error::ProgramError::InvalidAccountData);
}
bytemuck::try_from_bytes_mut::<Self>(&mut data[8..]).or(Err(
solana_program::program_error::ProgramError::InvalidAccountData,
))
}
}
/// Borrowing conversions that split a raw byte buffer into a typed header and
/// the bytes that follow it.
///
/// Both methods return references into `data` rather than owned copies.
pub trait AccountHeaderDeserialize {
/// Views the front of `data` as a shared reference to `Self` and returns it
/// together with the remaining bytes.
///
/// # Errors
///
/// Returns a [`ProgramError`] when the header cannot be read from `data`.
fn try_header_from_bytes(data: &[u8]) -> Result<(&Self, &[u8]), ProgramError>;
/// Views the front of `data` as an exclusive reference to `Self` and returns
/// it together with the remaining bytes, both writable.
///
/// # Errors
///
/// Returns a [`ProgramError`] when the header cannot be read from `data`.
fn try_header_from_bytes_mut(data: &mut [u8]) -> Result<(&mut Self, &mut [u8]), ProgramError>;
}
impl<T> AccountHeaderDeserialize for T
where
T: Discriminator + Pod,
{
fn try_header_from_bytes(data: &[u8]) -> Result<(&Self, &[u8]), ProgramError> {
if Self::discriminator().ne(&data[0]) {
return Err(solana_program::program_error::ProgramError::InvalidAccountData);
}
let (prefix, remainder) = data[8..].split_at(std::mem::size_of::<T>());
Ok((
bytemuck::try_from_bytes::<Self>(prefix).or(Err(
solana_program::program_error::ProgramError::InvalidAccountData,
))?,
remainder,
))
}
fn try_header_from_bytes_mut(data: &mut [u8]) -> Result<(&mut Self, &mut [u8]), ProgramError> {
let (prefix, remainder) = data[8..].split_at_mut(std::mem::size_of::<T>());
Ok((
bytemuck::try_from_bytes_mut::<Self>(prefix).or(Err(
solana_program::program_error::ProgramError::InvalidAccountData,
))?,
remainder,
))
}
}
#[cfg(test)]
mod tests {
    use crate::AccountDeserialize;

    use super::*;
    use bytemuck::{Pod, Zeroable};

    /// Byte buffer forced to 8-byte alignment.
    ///
    /// `bytemuck` refuses to cast a slice whose address is not aligned for the
    /// target type. A bare `[u8; N]` local only has 1-byte alignment by type,
    /// so casting from it would depend on where the compiler happens to place
    /// it on the stack. With this wrapper, offsets 8 and 16 into the buffer
    /// are always 8-byte aligned.
    #[repr(C, align(8))]
    struct AlignedBytes<const N: usize>([u8; N]);

    /// Variable-length payload that follows a [`GenericallySizedTypeHeader`].
    #[repr(C)]
    #[derive(Copy, Clone)]
    struct GenericallySizedType<const N: usize> {
        field: [u32; N],
    }

    // SAFETY: `[u32; N]` is valid when zeroed, and the struct is `repr(C)`
    // with that array as its only field.
    unsafe impl<const N: usize> Zeroable for GenericallySizedType<N> {}
    // SAFETY: `repr(C)` struct whose only field is `[u32; N]`, which is `Pod`;
    // a single field leaves no padding.
    unsafe impl<const N: usize> Pod for GenericallySizedType<N> {}

    /// Fixed-size header announcing how many `u32`s follow it.
    #[repr(C)]
    #[derive(Copy, Clone, Zeroable, Pod)]
    struct GenericallySizedTypeHeader {
        field_len: u64,
    }

    impl Discriminator for GenericallySizedTypeHeader {
        fn discriminator() -> u8 {
            0
        }
    }

    /// Reads a header, then uses its length field to pick the payload type.
    #[test]
    fn account_headers() {
        let mut data = AlignedBytes([0u8; 32]);
        // Byte 0 stays 0, matching the header's discriminator.
        data.0[8] = 4; // header.field_len
        data.0[16] = 5; // payload.field[0]
        let (_foo_header, foo) = GenericallySizedTypeHeader::try_header_from_bytes(&data.0)
            .map(|(header, remainder)| {
                let foo = match header.field_len {
                    4 => bytemuck::try_from_bytes::<GenericallySizedType<4>>(remainder).unwrap(),
                    x => panic!("unknown field len, {}", x),
                };
                (header, foo)
            })
            .unwrap();
        assert_eq!(5, foo.field[0]);
    }

    /// Plain two-field account used to exercise [`AccountDeserialize`].
    #[repr(C)]
    #[derive(Copy, Clone, Zeroable, Pod)]
    struct TestType {
        field0: u64,
        field1: u64,
    }

    impl Discriminator for TestType {
        fn discriminator() -> u8 {
            7
        }
    }

    /// A correctly tagged buffer deserializes into the expected field values.
    #[test]
    fn account_deserialize() {
        let mut data = AlignedBytes([0u8; 24]);
        data.0[0] = 7; // discriminator
        data.0[8] = 42; // field0
        data.0[16] = 43; // field1
        let foo = TestType::try_from_bytes(&data.0).unwrap();
        assert_eq!(42, foo.field0);
        assert_eq!(43, foo.field1);
    }

    /// A buffer tagged with a different discriminator is rejected.
    #[test]
    fn account_deserialize_rejects_wrong_discriminator() {
        let mut data = AlignedBytes([0u8; 24]);
        data.0[0] = 8; // anything other than `TestType::discriminator()`
        assert!(TestType::try_from_bytes(&data.0).is_err());
    }
}