data_anchor_utils/
encoding.rs

1#[derive(Debug, thiserror::Error)]
2pub enum DataAnchorEncodingError {
3    #[error("Postcard encoding error: {0}")]
4    Postcard(#[from] postcard::Error),
5
6    #[error("Bincode encoding error: {0}")]
7    Bincode(#[from] bincode::Error),
8
9    #[error("JSON encoding error: {0}")]
10    Json(#[from] serde_json::Error),
11
12    #[error("Unknown encoding type")]
13    UnknownEncodingType,
14
15    #[error("Encoding type mismatch expected: {0:?}, found: {1:?}")]
16    EncodingTypeMismatch(EncodingType, EncodingType),
17
18    #[error("No data to decode")]
19    NoDataToDecode,
20
21    #[cfg(feature = "borsh")]
22    #[error("Borsh encoding error: {0}")]
23    Borsh(#[from] borsh::io::Error),
24}
25
26pub type DataAnchorEncodingResult<T = ()> = Result<T, DataAnchorEncodingError>;
27
28#[cfg(not(feature = "borsh"))]
29mod _no_borsh {
30
31    pub trait Encodable: serde::ser::Serialize {}
32
33    impl<T: serde::ser::Serialize> Encodable for T {}
34
35    pub trait Decodable: serde::de::DeserializeOwned {}
36
37    impl<T: serde::de::DeserializeOwned> Decodable for T {}
38}
39
40#[cfg(not(feature = "borsh"))]
41pub use _no_borsh::*;
42
#[cfg(feature = "borsh")]
mod _with_borsh {
    /// Types that every enabled backend can serialize.
    ///
    /// With the `borsh` feature a type must satisfy both serde's `Serialize` and
    /// `borsh::BorshSerialize`; the blanket impl below covers every such type automatically.
    pub trait Encodable: serde::ser::Serialize + borsh::BorshSerialize {}

    impl<T> Encodable for T where T: serde::ser::Serialize + borsh::BorshSerialize {}

    /// Types that every enabled backend can deserialize into an owned value.
    ///
    /// With the `borsh` feature a type must satisfy both `DeserializeOwned` and
    /// `borsh::BorshDeserialize`.
    pub trait Decodable: serde::de::DeserializeOwned + borsh::BorshDeserialize {}

    impl<T> Decodable for T where T: serde::de::DeserializeOwned + borsh::BorshDeserialize {}
}

#[cfg(feature = "borsh")]
pub use _with_borsh::*;
56
57pub trait DataAnchorEncoding {
58    fn encode<T: Encodable>(&self, data: &T) -> DataAnchorEncodingResult<Vec<u8>>;
59    fn decode<T: Decodable>(&self, data: &[u8]) -> DataAnchorEncodingResult<T>;
60}
61
62#[derive(
63    Debug, Clone, Copy, PartialEq, Eq, std::default::Default, serde::Serialize, serde::Deserialize,
64)]
65#[cfg_attr(
66    feature = "borsh",
67    derive(borsh::BorshSerialize, borsh::BorshDeserialize)
68)]
69#[repr(u8)]
70pub enum EncodingType {
71    #[default]
72    Postcard,
73    Bincode,
74    Json,
75    #[cfg(feature = "borsh")]
76    Borsh,
77}
78
79impl std::fmt::Display for EncodingType {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        match self {
82            EncodingType::Postcard => write!(f, "postcard"),
83            EncodingType::Bincode => write!(f, "bincode"),
84            EncodingType::Json => write!(f, "json"),
85            #[cfg(feature = "borsh")]
86            EncodingType::Borsh => write!(f, "borsh"),
87        }
88    }
89}
90
91impl TryFrom<u8> for EncodingType {
92    type Error = DataAnchorEncodingError;
93
94    fn try_from(value: u8) -> Result<Self, Self::Error> {
95        match value {
96            0 => Ok(EncodingType::Postcard),
97            1 => Ok(EncodingType::Bincode),
98            2 => Ok(EncodingType::Json),
99            #[cfg(feature = "borsh")]
100            3 => Ok(EncodingType::Borsh),
101            _ => Err(DataAnchorEncodingError::UnknownEncodingType),
102        }
103    }
104}
105
106impl EncodingType {
107    /// Add a marker byte to the beginning of the data to indicate the encoding type.
108    pub fn mark(self, data: Vec<u8>) -> Vec<u8> {
109        [[self as u8].to_vec(), data].concat()
110    }
111
112    /// Inspect encoding type from a byte slice.
113    pub fn inspect(data: &[u8]) -> DataAnchorEncodingResult<Self> {
114        let Some(encoding_type_byte) = data.first() else {
115            return Err(DataAnchorEncodingError::NoDataToDecode);
116        };
117
118        EncodingType::try_from(*encoding_type_byte)
119    }
120
121    /// Extract the encoding type and the data from a byte slice.
122    pub fn get_encoding_and_data(data: &[u8]) -> DataAnchorEncodingResult<(Self, &[u8])> {
123        let Some((encoding_type_byte, data)) = data.split_first() else {
124            return Err(DataAnchorEncodingError::NoDataToDecode);
125        };
126
127        let encoding_type = EncodingType::try_from(*encoding_type_byte)?;
128
129        Ok((encoding_type, data))
130    }
131
132    /// Assert that the encoding type matches the expected type.
133    pub fn assert_encoding_type<'a>(&self, data: &'a [u8]) -> DataAnchorEncodingResult<&'a [u8]> {
134        let (encoding_type, data) = Self::get_encoding_and_data(data)?;
135        if encoding_type != *self {
136            return Err(DataAnchorEncodingError::EncodingTypeMismatch(
137                *self,
138                encoding_type,
139            ));
140        }
141        Ok(data)
142    }
143}
144
145impl DataAnchorEncoding for EncodingType {
146    fn encode<T: Encodable>(&self, data: &T) -> DataAnchorEncodingResult<Vec<u8>> {
147        match self {
148            EncodingType::Postcard => Postcard.encode(data),
149            EncodingType::Bincode => Bincode.encode(data),
150            EncodingType::Json => Json.encode(data),
151            #[cfg(feature = "borsh")]
152            EncodingType::Borsh => Borsh.encode(data),
153        }
154    }
155
156    fn decode<T: Decodable>(&self, data: &[u8]) -> DataAnchorEncodingResult<T> {
157        let encoding_type = EncodingType::inspect(data)?;
158
159        match encoding_type {
160            EncodingType::Postcard => Postcard.decode(data),
161            EncodingType::Bincode => Bincode.decode(data),
162            EncodingType::Json => Json.decode(data),
163            #[cfg(feature = "borsh")]
164            EncodingType::Borsh => Borsh.decode(data),
165        }
166    }
167}
168
169#[derive(Debug, Clone, Copy, std::default::Default)]
170pub struct Postcard;
171
172pub use Postcard as Default;
173
174impl DataAnchorEncoding for Postcard {
175    fn encode<T: Encodable>(&self, data: &T) -> DataAnchorEncodingResult<Vec<u8>> {
176        Ok(EncodingType::Postcard.mark(postcard::to_allocvec(data)?))
177    }
178
179    fn decode<T: Decodable>(&self, data: &[u8]) -> DataAnchorEncodingResult<T> {
180        Ok(postcard::from_bytes(
181            EncodingType::Postcard.assert_encoding_type(data)?,
182        )?)
183    }
184}
185
186#[derive(Debug, Clone, Copy, std::default::Default)]
187pub struct Bincode;
188
189impl DataAnchorEncoding for Bincode {
190    fn encode<T: Encodable>(&self, data: &T) -> DataAnchorEncodingResult<Vec<u8>> {
191        Ok(EncodingType::Bincode.mark(bincode::serialize(data)?))
192    }
193
194    fn decode<T: Decodable>(&self, data: &[u8]) -> DataAnchorEncodingResult<T> {
195        Ok(bincode::deserialize(
196            EncodingType::Bincode.assert_encoding_type(data)?,
197        )?)
198    }
199}
200
201#[derive(Debug, Clone, Copy, std::default::Default)]
202pub struct Json;
203
204impl DataAnchorEncoding for Json {
205    fn encode<T: Encodable>(&self, data: &T) -> DataAnchorEncodingResult<Vec<u8>> {
206        Ok(EncodingType::Json.mark(serde_json::to_vec(data)?))
207    }
208
209    fn decode<T: Decodable>(&self, data: &[u8]) -> DataAnchorEncodingResult<T> {
210        Ok(serde_json::from_slice(
211            EncodingType::Json.assert_encoding_type(data)?,
212        )?)
213    }
214}
215
/// Borsh backend; only compiled with the `borsh` feature.
#[cfg(feature = "borsh")]
#[derive(Debug, Clone, Copy, std::default::Default)]
pub struct Borsh;

#[cfg(feature = "borsh")]
impl DataAnchorEncoding for Borsh {
    /// Serializes with `borsh::to_vec` and prefixes the Borsh marker byte.
    fn encode<T: Encodable>(&self, data: &T) -> DataAnchorEncodingResult<Vec<u8>> {
        let body = borsh::to_vec(data)?;
        Ok(EncodingType::Borsh.mark(body))
    }

    /// Requires a Borsh marker byte, then deserializes the remaining bytes with
    /// `borsh::from_slice`.
    fn decode<T: Decodable>(&self, data: &[u8]) -> DataAnchorEncodingResult<T> {
        let payload = EncodingType::Borsh.assert_encoding_type(data)?;
        borsh::from_slice(payload).map_err(DataAnchorEncodingError::from)
    }
}
232
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Composite payload used to check that structs survive every backend.
    #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    // The borsh derives exist only with the `borsh` feature. Deriving them unconditionally made
    // the whole test module fail to compile in the default feature set.
    #[cfg_attr(
        feature = "borsh",
        derive(borsh::BorshSerialize, borsh::BorshDeserialize)
    )]
    pub struct TestStruct {
        pub field1: String,
        pub field2: u32,
    }

    /// Encodes `data` with `encoding`, decodes the result and checks nothing was lost.
    fn assert_round_trip<T, E>(data: T, encoding: E)
    where
        T: Encodable + Decodable + PartialEq + std::fmt::Debug,
        E: DataAnchorEncoding,
    {
        let encoded = encoding.encode(&data).unwrap();
        let decoded: T = encoding.decode(&encoded).unwrap();
        assert_eq!(data, decoded);
    }

    /// Round trips through every backend that is always compiled in.
    #[rstest]
    #[case::string("Hello, World!".to_string())]
    #[case::bytes(vec![20u8, 30])]
    #[case::tuple((1, 2, 3))]
    #[case::bool(true)]
    #[case::structure(TestStruct {
        field1: "Test".to_string(),
        field2: 42,
    })]
    fn test_encoding<T, E>(
        #[case] data: T,
        #[values(Default, Postcard, Bincode, Json, EncodingType::default())] encoding: E,
    ) where
        T: Encodable + Decodable + PartialEq + std::fmt::Debug,
        E: DataAnchorEncoding,
    {
        assert_round_trip(data, encoding);
    }

    /// Round trips through the feature-gated Borsh backend, both directly and via `EncodingType`.
    #[cfg(feature = "borsh")]
    #[rstest]
    #[case::string("Hello, World!".to_string())]
    #[case::bytes(vec![20u8, 30])]
    #[case::tuple((1, 2, 3))]
    #[case::bool(true)]
    #[case::structure(TestStruct {
        field1: "Test".to_string(),
        field2: 42,
    })]
    fn test_borsh_encoding<T, E>(
        #[case] data: T,
        #[values(Borsh, EncodingType::Borsh)] encoding: E,
    ) where
        T: Encodable + Decodable + PartialEq + std::fmt::Debug,
        E: DataAnchorEncoding,
    {
        assert_round_trip(data, encoding);
    }

    /// A concrete backend must refuse a payload tagged by a different backend.
    #[test]
    fn test_mismatched_marker_is_rejected() {
        let encoded = Postcard.encode(&7u32).unwrap();
        let result = Json.decode::<u32>(&encoded);
        assert!(matches!(
            result,
            Err(DataAnchorEncodingError::EncodingTypeMismatch(
                EncodingType::Json,
                EncodingType::Postcard
            ))
        ));
    }

    /// Empty input has no marker byte and must be reported as such.
    #[test]
    fn test_empty_input_is_rejected() {
        assert!(matches!(
            Postcard.decode::<u32>(&[]),
            Err(DataAnchorEncodingError::NoDataToDecode)
        ));
        assert!(matches!(
            EncodingType::default().decode::<u32>(&[]),
            Err(DataAnchorEncodingError::NoDataToDecode)
        ));
    }

    /// A marker byte that names no backend must be reported as unknown.
    #[test]
    fn test_unknown_marker_is_rejected() {
        assert!(matches!(
            EncodingType::inspect(&[u8::MAX]),
            Err(DataAnchorEncodingError::UnknownEncodingType)
        ));
    }
}