gmsol_decode/value/anchor.rs

1use std::marker::PhantomData;
2
3use crate::Visitor;
4
/// Visitor that produces a [`ZeroCopy`](anchor_lang::ZeroCopy) value from raw bytes.
///
/// The visitor holds no data; `T` only selects the type being produced.
pub struct ZeroCopyVisitor<T>(PhantomData<T>);

impl<T> Default for ZeroCopyVisitor<T> {
    /// Creates a visitor targeting `T`.
    fn default() -> Self {
        ZeroCopyVisitor(PhantomData)
    }
}
13
14impl<T> Visitor for ZeroCopyVisitor<T>
15where
16    T: anchor_lang::ZeroCopy,
17{
18    type Value = T;
19
20    fn visit_bytes(self, data: &[u8]) -> Result<Self::Value, crate::DecodeError> {
21        use anchor_lang::prelude::{Error, ErrorCode};
22        use bytemuck::PodCastError;
23
24        let disc = T::DISCRIMINATOR;
25        if data.len() < disc.len() {
26            return Err(Error::from(ErrorCode::AccountDiscriminatorNotFound).into());
27        }
28        let given_disc = &data[..8];
29        if disc != given_disc {
30            return Err(Error::from(ErrorCode::AccountDiscriminatorMismatch).into());
31        }
32        let end = std::mem::size_of::<T>() + 8;
33        if data.len() < end {
34            return Err(Error::from(ErrorCode::AccountDidNotDeserialize).into());
35        }
36        let data_without_discriminator = &data[8..end];
37
38        match bytemuck::try_from_bytes(data_without_discriminator) {
39            Ok(data) => Ok(*data),
40            Err(PodCastError::TargetAlignmentGreaterAndInputNotAligned) => {
41                bytemuck::try_pod_read_unaligned(data_without_discriminator)
42                    .map_err(|_| Error::from(ErrorCode::AccountDidNotDeserialize).into())
43            }
44            Err(error) => Err(crate::DecodeError::custom(format!("bytemuck: {error}"))),
45        }
46    }
47}
48
/// Implement [`Decode`](crate::Decode) for a [`ZeroCopy`](anchor_lang::ZeroCopy) type.
///
/// The generated impl decodes the input bytes with
/// [`ZeroCopyVisitor`](crate::value::ZeroCopyVisitor).
#[macro_export]
macro_rules! impl_decode_for_zero_copy {
    ($decoded:ty) => {
        impl $crate::Decode for $decoded {
            // `Result` is spelled as an absolute path: paths in a macro expansion resolve
            // at the call site, and a module that glob-imports `anchor_lang::prelude::*`
            // sees a one-parameter `Result` alias that would reject two arguments.
            fn decode<D: $crate::Decoder>(
                decoder: D,
            ) -> ::std::result::Result<Self, $crate::DecodeError> {
                decoder.decode_bytes($crate::value::ZeroCopyVisitor::<$decoded>::default())
            }
        }
    };
}
60
/// Visitor that produces an [`AccountDeserialize`](anchor_lang::AccountDeserialize) value.
///
/// Holds no data; `T` only selects the account type to deserialize.
pub struct AccountDeserializeVisitor<T>(PhantomData<T>);

impl<T> Default for AccountDeserializeVisitor<T> {
    /// Creates a visitor targeting `T`.
    fn default() -> Self {
        AccountDeserializeVisitor(PhantomData)
    }
}
69
70impl<T> Visitor for AccountDeserializeVisitor<T>
71where
72    T: anchor_lang::AccountDeserialize,
73{
74    type Value = T;
75
76    fn visit_bytes(self, mut data: &[u8]) -> Result<Self::Value, crate::DecodeError> {
77        Ok(T::try_deserialize(&mut data)?)
78    }
79}
80
/// Implement [`Decode`](crate::Decode) for an
/// [`AccountDeserialize`](anchor_lang::AccountDeserialize) type.
///
/// The generated impl decodes the input bytes with
/// [`AccountDeserializeVisitor`](crate::value::AccountDeserializeVisitor).
#[macro_export]
macro_rules! impl_decode_for_account_deserialize {
    ($decoded:ty) => {
        impl $crate::Decode for $decoded {
            // `Result` is spelled as an absolute path: paths in a macro expansion resolve
            // at the call site, and a module that glob-imports `anchor_lang::prelude::*`
            // sees a one-parameter `Result` alias that would reject two arguments.
            fn decode<D: $crate::Decoder>(
                decoder: D,
            ) -> ::std::result::Result<Self, $crate::DecodeError> {
                decoder
                    .decode_bytes($crate::value::AccountDeserializeVisitor::<$decoded>::default())
            }
        }
    };
}
93
/// Visitor that produces a CPI [`Event`](anchor_lang::Event).
///
/// Carries no data; `T` only fixes the event type being decoded.
pub struct CPIEventVisitor<T>(PhantomData<T>);

impl<T> Default for CPIEventVisitor<T> {
    /// Creates a visitor targeting `T`.
    fn default() -> Self {
        CPIEventVisitor(PhantomData)
    }
}
102
103impl<T> Visitor for CPIEventVisitor<T>
104where
105    T: anchor_lang::Event,
106{
107    type Value = T;
108
109    fn visit_bytes(self, data: &[u8]) -> Result<Self::Value, crate::DecodeError> {
110        use anchor_lang::{
111            event::EVENT_IX_TAG_LE,
112            prelude::{Error, ErrorCode},
113        };
114
115        // Valdiate the ix tag.
116        if data.len() < EVENT_IX_TAG_LE.len() {
117            return Err(Error::from(ErrorCode::InstructionDidNotDeserialize).into());
118        }
119        let given_tag = &data[..8];
120        if given_tag != EVENT_IX_TAG_LE {
121            return Err(crate::DecodeError::custom("not an anchor event ix"));
122        }
123
124        let data = &data[8..];
125
126        // Validate the discriminator.
127        let disc = T::DISCRIMINATOR;
128        if data.len() < disc.len() {
129            return Err(Error::from(ErrorCode::InstructionDidNotDeserialize).into());
130        }
131        let given_disc = &data[..8];
132        if disc != given_disc {
133            return Err(Error::from(ErrorCode::InstructionDidNotDeserialize).into());
134        }
135
136        // Deserialize.
137        Ok(T::try_from_slice(&data[8..]).map_err(anchor_lang::prelude::Error::from)?)
138    }
139}
140
/// Implement [`Decode`](crate::Decode) for CPI events.
///
/// The generated impl decodes the input bytes with
/// [`CPIEventVisitor`](crate::value::CPIEventVisitor).
#[macro_export]
macro_rules! impl_decode_for_cpi_event {
    ($decoded:ty) => {
        impl $crate::Decode for $decoded {
            // `Result` is spelled as an absolute path: paths in a macro expansion resolve
            // at the call site, and a module that glob-imports `anchor_lang::prelude::*`
            // sees a one-parameter `Result` alias that would reject two arguments.
            fn decode<D: $crate::Decoder>(
                decoder: D,
            ) -> ::std::result::Result<Self, $crate::DecodeError> {
                decoder.decode_bytes($crate::value::CPIEventVisitor::<$decoded>::default())
            }
        }
    };
}