use std::marker::PhantomData;
use crate::Visitor;
pub struct ZeroCopyVisitor<T>(PhantomData<T>);
impl<T> Default for ZeroCopyVisitor<T> {
fn default() -> Self {
Self(Default::default())
}
}
impl<T> Visitor for ZeroCopyVisitor<T>
where
T: anchor_lang::ZeroCopy,
{
type Value = T;
fn visit_bytes(self, data: &[u8]) -> Result<Self::Value, crate::DecodeError> {
use anchor_lang::prelude::{Error, ErrorCode};
use bytemuck::PodCastError;
let disc = T::DISCRIMINATOR;
if data.len() < disc.len() {
return Err(Error::from(ErrorCode::AccountDiscriminatorNotFound).into());
}
let given_disc = &data[..8];
if disc != given_disc {
return Err(Error::from(ErrorCode::AccountDiscriminatorMismatch).into());
}
let end = std::mem::size_of::<T>() + 8;
if data.len() < end {
return Err(Error::from(ErrorCode::AccountDidNotDeserialize).into());
}
let data_without_discriminator = &data[8..end];
match bytemuck::try_from_bytes(data_without_discriminator) {
Ok(data) => Ok(*data),
Err(PodCastError::TargetAlignmentGreaterAndInputNotAligned) => {
bytemuck::try_pod_read_unaligned(data_without_discriminator)
.map_err(|_| Error::from(ErrorCode::AccountDidNotDeserialize).into())
}
Err(error) => Err(crate::DecodeError::custom(format!("bytemuck: {error}"))),
}
}
}
#[macro_export]
macro_rules! impl_decode_for_zero_copy {
($decoded:ty) => {
impl $crate::Decode for $decoded {
fn decode<D: $crate::Decoder>(decoder: D) -> Result<Self, $crate::DecodeError> {
decoder.decode_bytes($crate::value::ZeroCopyVisitor::<$decoded>::default())
}
}
};
}
pub struct AccountDeserializeVisitor<T>(PhantomData<T>);
impl<T> Default for AccountDeserializeVisitor<T> {
fn default() -> Self {
Self(Default::default())
}
}
impl<T> Visitor for AccountDeserializeVisitor<T>
where
T: anchor_lang::AccountDeserialize,
{
type Value = T;
fn visit_bytes(self, mut data: &[u8]) -> Result<Self::Value, crate::DecodeError> {
Ok(T::try_deserialize(&mut data)?)
}
}
#[macro_export]
macro_rules! impl_decode_for_account_deserialize {
($decoded:ty) => {
impl $crate::Decode for $decoded {
fn decode<D: $crate::Decoder>(decoder: D) -> Result<Self, $crate::DecodeError> {
decoder
.decode_bytes($crate::value::AccountDeserializeVisitor::<$decoded>::default())
}
}
};
}
pub struct CPIEventVisitor<T>(PhantomData<T>);
impl<T> Default for CPIEventVisitor<T> {
fn default() -> Self {
Self(Default::default())
}
}
impl<T> Visitor for CPIEventVisitor<T>
where
T: anchor_lang::Event,
{
type Value = T;
fn visit_bytes(self, data: &[u8]) -> Result<Self::Value, crate::DecodeError> {
use anchor_lang::{
event::EVENT_IX_TAG_LE,
prelude::{Error, ErrorCode},
};
if data.len() < EVENT_IX_TAG_LE.len() {
return Err(Error::from(ErrorCode::InstructionDidNotDeserialize).into());
}
let given_tag = &data[..8];
if given_tag != EVENT_IX_TAG_LE {
return Err(crate::DecodeError::custom("not an anchor event ix"));
}
let data = &data[8..];
let disc = T::DISCRIMINATOR;
if data.len() < disc.len() {
return Err(Error::from(ErrorCode::InstructionDidNotDeserialize).into());
}
let given_disc = &data[..8];
if disc != given_disc {
return Err(Error::from(ErrorCode::InstructionDidNotDeserialize).into());
}
Ok(T::try_from_slice(&data[8..]).map_err(anchor_lang::prelude::Error::from)?)
}
}
#[macro_export]
macro_rules! impl_decode_for_cpi_event {
($decoded:ty) => {
impl $crate::Decode for $decoded {
fn decode<D: $crate::Decoder>(decoder: D) -> Result<Self, $crate::DecodeError> {
decoder.decode_bytes($crate::value::CPIEventVisitor::<$decoded>::default())
}
}
};
}