cml_core/
serialization.rs

1use crate::error::{DeserializeError, DeserializeFailure};
2use cbor_event::{de::Deserializer, se::Serializer, Sz};
3use std::io::{BufRead, Seek, Write};
4
/// Bookkeeping helper that counts how many elements have been marked as read
/// against a `cbor_event::LenSz` length header.
pub struct CBORReadLen {
    // Length header this tracker was created with (definite or indefinite).
    deser_len: cbor_event::LenSz,
    // Running total of elements marked as read via `read_elems`.
    read: u64,
}
9
10impl CBORReadLen {
11    pub fn new(len: cbor_event::LenSz) -> Self {
12        Self {
13            deser_len: len,
14            read: 0,
15        }
16    }
17
18    pub fn read(&self) -> u64 {
19        self.read
20    }
21
22    // Marks {n} values as being read, and if we go past the available definite length
23    // given by the CBOR, we return an error.
24    pub fn read_elems(&mut self, count: usize) -> Result<(), DeserializeFailure> {
25        match self.deser_len {
26            cbor_event::LenSz::Len(n, _) => {
27                self.read += count as u64;
28                if self.read > n {
29                    Err(DeserializeFailure::DefiniteLenMismatch(n, None))
30                } else {
31                    Ok(())
32                }
33            }
34            cbor_event::LenSz::Indefinite => Ok(()),
35        }
36    }
37
38    pub fn finish(&self) -> Result<(), DeserializeFailure> {
39        match self.deser_len {
40            cbor_event::LenSz::Len(n, _) => {
41                if self.read == n {
42                    Ok(())
43                } else {
44                    Err(DeserializeFailure::DefiniteLenMismatch(n, Some(self.read)))
45                }
46            }
47            cbor_event::LenSz::Indefinite => Ok(()),
48        }
49    }
50}
51
52impl From<cbor_event::Len> for CBORReadLen {
53    // to facilitate mixing with crates that use preserve-encodings=false to generate
54    // we need to create it from cbor_event::Len instead
55    fn from(len: cbor_event::Len) -> Self {
56        Self::new(len_to_len_sz(len))
57    }
58}
59
60pub fn len_to_len_sz(len: cbor_event::Len) -> cbor_event::LenSz {
61    match len {
62        cbor_event::Len::Len(n) => cbor_event::LenSz::Len(n, fit_sz(n, None, true)),
63        cbor_event::Len::Indefinite => cbor_event::LenSz::Indefinite,
64    }
65}
66
/// Deserialization entry point for types that are read as an embedded group.
///
/// NOTE(review): presumably "embedded group" means the fields are read from
/// inside a container whose header was already consumed by the caller —
/// confirm against the implementors.
pub trait DeserializeEmbeddedGroup {
    /// Reads `Self` from `raw`.
    ///
    /// * `raw` - the CBOR deserializer to read from.
    /// * `read_len` - element counter; presumably for the enclosing container,
    ///   updated as fields are consumed (TODO confirm).
    /// * `len` - length header; presumably that of the enclosing container
    ///   (TODO confirm).
    fn deserialize_as_embedded_group<R: BufRead + Seek>(
        raw: &mut Deserializer<R>,
        read_len: &mut CBORReadLen,
        len: cbor_event::LenSz,
    ) -> Result<Self, DeserializeError>
    where
        Self: Sized;
}
76
77#[inline]
78pub fn sz_max(sz: cbor_event::Sz) -> u64 {
79    match sz {
80        Sz::Inline => 23u64,
81        Sz::One => u8::MAX as u64,
82        Sz::Two => u16::MAX as u64,
83        Sz::Four => u32::MAX as u64,
84        Sz::Eight => u64::MAX,
85    }
86}
87
/// How a length header should be encoded when serializing.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum LenEncoding {
    /// Definite length using the canonical size for the value.
    Canonical,
    /// Definite length using the given size (when the value fits in it).
    Definite(cbor_event::Sz),
    /// Indefinite length.
    Indefinite,
}
94
95impl Default for LenEncoding {
96    fn default() -> Self {
97        Self::Canonical
98    }
99}
100
101impl From<cbor_event::LenSz> for LenEncoding {
102    fn from(len_sz: cbor_event::LenSz) -> Self {
103        match len_sz {
104            cbor_event::LenSz::Len(len, sz) => {
105                if cbor_event::Sz::canonical(len) == sz {
106                    Self::Canonical
107                } else {
108                    Self::Definite(sz)
109                }
110            }
111            cbor_event::LenSz::Indefinite => Self::Indefinite,
112        }
113    }
114}
115
/// How the length of a string should be encoded when serializing.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum StringEncoding {
    /// Definite length using the canonical size for the value.
    Canonical,
    /// Indefinite encoding; carries a list of `(u64, Sz)` pairs
    /// (presumably the length and size of each chunk — TODO confirm).
    Indefinite(Vec<(u64, Sz)>),
    /// Definite length using the given size (when the value fits in it).
    Definite(Sz),
}
122
123impl Default for StringEncoding {
124    fn default() -> Self {
125        Self::Canonical
126    }
127}
128
129impl From<cbor_event::StringLenSz> for StringEncoding {
130    fn from(len_sz: cbor_event::StringLenSz) -> Self {
131        match len_sz {
132            cbor_event::StringLenSz::Len(sz) => Self::Definite(sz),
133            cbor_event::StringLenSz::Indefinite(lens) => Self::Indefinite(lens),
134        }
135    }
136}
137
138#[inline]
139pub fn fit_sz(len: u64, sz: Option<cbor_event::Sz>, force_canonical: bool) -> Sz {
140    match sz {
141        Some(sz) => {
142            if !force_canonical && len <= sz_max(sz) {
143                sz
144            } else {
145                Sz::canonical(len)
146            }
147        }
148        None => Sz::canonical(len),
149    }
150}
151
152impl LenEncoding {
153    pub fn to_len_sz(&self, len: u64, force_canonical: bool) -> cbor_event::LenSz {
154        if force_canonical {
155            cbor_event::LenSz::Len(len, cbor_event::Sz::canonical(len))
156        } else {
157            match self {
158                Self::Canonical => cbor_event::LenSz::Len(len, cbor_event::Sz::canonical(len)),
159                Self::Definite(sz) => {
160                    if sz_max(*sz) >= len {
161                        cbor_event::LenSz::Len(len, *sz)
162                    } else {
163                        cbor_event::LenSz::Len(len, cbor_event::Sz::canonical(len))
164                    }
165                }
166                Self::Indefinite => cbor_event::LenSz::Indefinite,
167            }
168        }
169    }
170
171    pub fn end<'a, W: Write + Sized>(
172        &self,
173        serializer: &'a mut Serializer<W>,
174        force_canonical: bool,
175    ) -> cbor_event::Result<&'a mut Serializer<W>> {
176        if !force_canonical && *self == Self::Indefinite {
177            serializer.write_special(cbor_event::Special::Break)?;
178        }
179        Ok(serializer)
180    }
181}
182
183impl StringEncoding {
184    pub fn to_str_len_sz(&self, len: u64, force_canonical: bool) -> cbor_event::StringLenSz {
185        if force_canonical {
186            cbor_event::StringLenSz::Len(cbor_event::Sz::canonical(len))
187        } else {
188            match self {
189                Self::Canonical => cbor_event::StringLenSz::Len(cbor_event::Sz::canonical(len)),
190                Self::Definite(sz) => {
191                    if sz_max(*sz) >= len {
192                        cbor_event::StringLenSz::Len(*sz)
193                    } else {
194                        cbor_event::StringLenSz::Len(cbor_event::Sz::canonical(len))
195                    }
196                }
197                Self::Indefinite(lens) => cbor_event::StringLenSz::Indefinite(lens.clone()),
198            }
199        }
200    }
201}
202
203pub trait Serialize {
204    fn serialize<'a, W: Write + Sized>(
205        &self,
206        serializer: &'a mut Serializer<W>,
207        force_canonical: bool,
208    ) -> cbor_event::Result<&'a mut Serializer<W>>;
209
210    /// Bytes of a structure using the CBOR bytes as per the CDDL spec
211    /// which for foo = bytes will include the CBOR bytes type/len, etc.
212    /// This gives the original bytes in the case where this was created
213    /// from bytes originally, or will use whatever the specific encoding
214    /// details are present in any encoding details struct for the type.
215    fn to_cbor_bytes(&self) -> Vec<u8> {
216        let mut buf = Serializer::new_vec();
217        self.serialize(&mut buf, false).unwrap();
218        buf.finalize()
219    }
220
221    /// Bytes of a structure using the CBOR bytes as per the CDDL spec
222    /// which for foo = bytes will include the CBOR bytes type/len, etc.
223    /// This gives the canonically encoded CBOR bytes always
224    fn to_canonical_cbor_bytes(&self) -> Vec<u8> {
225        let mut buf = Serializer::new_vec();
226        self.serialize(&mut buf, true).unwrap();
227        buf.finalize()
228    }
229}
230
/// Serialization entry point for types that are written as an embedded group.
///
/// NOTE(review): presumably the implementor writes only its fields and leaves
/// the container header to the caller — confirm against the implementors.
pub trait SerializeEmbeddedGroup {
    /// Writes `self` to `serializer` and hands the serializer back.
    ///
    /// `force_canonical` is passed through by callers in the same way as for
    /// `Serialize::serialize`.
    fn serialize_as_embedded_group<'a, W: Write + Sized>(
        &self,
        serializer: &'a mut Serializer<W>,
        force_canonical: bool,
    ) -> cbor_event::Result<&'a mut Serializer<W>>;
}
238
239pub trait Deserialize {
240    fn deserialize<R: BufRead + Seek>(raw: &mut Deserializer<R>) -> Result<Self, DeserializeError>
241    where
242        Self: Sized;
243
244    /// from-bytes using the exact CBOR format specified in the CDDL binary spec.
245    /// For hashes/addresses/etc this will include the CBOR bytes type/len/etc.
246    fn from_cbor_bytes(data: &[u8]) -> Result<Self, DeserializeError>
247    where
248        Self: Sized,
249    {
250        let mut raw = Deserializer::from(std::io::Cursor::new(data));
251        Self::deserialize(&mut raw)
252    }
253}
254
// TODO: remove ToBytes / FromBytes after we regenerate the WASM wrappers.
// This is so the existing generated to/from bytes code works.
// We are, however, using this in CIP25 as a way to get to bytes without
// caring about the encoding. We could move it there or make it more explicit
// that this neither preserves encodings NOR forces canonical - it's just
// whatever CBOR format comes out. All other parts of CML implement our own
// Serialize trait with the assumption that we preserve encodings. This one is
// based off of cbor_event's Serialize trait (`cbor_event::se::Serialize`),
// which is what the blanket impl in this file is bounded on.
pub trait ToBytes {
    /// Returns the CBOR bytes of `self`.
    fn to_bytes(&self) -> Vec<u8>;
}
265
266impl<T: cbor_event::se::Serialize> ToBytes for T {
267    fn to_bytes(&self) -> Vec<u8> {
268        let mut buf = Serializer::new_vec();
269        self.serialize(&mut buf).unwrap();
270        buf.finalize()
271    }
272}
273
// TODO: remove ToBytes / FromBytes after we regenerate the WASM wrappers.
// This is just so the existing generated to/from bytes code works.
pub trait FromBytes {
    /// Parses `Self` from an owned buffer of CBOR bytes.
    fn from_bytes(data: Vec<u8>) -> Result<Self, DeserializeError>
    where
        Self: Sized;
}
281
282impl<T: Deserialize> FromBytes for T {
283    fn from_bytes(data: Vec<u8>) -> Result<Self, DeserializeError>
284    where
285        Self: Sized,
286    {
287        let mut raw = Deserializer::from(std::io::Cursor::new(data));
288        Self::deserialize(&mut raw).map_err(Into::into)
289    }
290}
291pub trait RawBytesEncoding {
292    fn to_raw_bytes(&self) -> &[u8];
293
294    fn from_raw_bytes(bytes: &[u8]) -> Result<Self, DeserializeError>
295    where
296        Self: Sized;
297
298    fn to_raw_hex(&self) -> String {
299        hex::encode(self.to_raw_bytes())
300    }
301
302    fn from_raw_hex(hex_str: &str) -> Result<Self, DeserializeError>
303    where
304        Self: Sized,
305    {
306        let bytes = hex::decode(hex_str).map_err(|e| {
307            DeserializeError::from(DeserializeFailure::InvalidStructure(Box::new(e)))
308        })?;
309        Self::from_raw_bytes(bytes.as_ref())
310    }
311}