// ark_serialize/impls/collections.rs

1use crate::{
2    CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate,
3};
4use ark_std::{
5    borrow::*,
6    collections::{BTreeMap, BTreeSet, LinkedList, VecDeque},
7    io::{Read, Write},
8    string::*,
9    vec::*,
10};
11
/// Implements [`Valid`] for a sequence-like container of `T`.
///
/// Any tokens after the `;` are appended to the impl's generic parameter list
/// (arrays pass `const N: usize`). Validation is delegated to `T::batch_check` over the
/// container's elements and is skipped entirely when `T::TRIVIAL_CHECK` holds.
macro_rules! impl_valid_seq {
    ($type:ty  $( ; $($generics:tt)+ )?) => {
        impl<T: Valid $( , $($generics)+ )? > Valid for $type {
            // A container of trivially-valid elements is itself trivially valid.
            const TRIVIAL_CHECK: bool = T::TRIVIAL_CHECK;

            #[inline]
            fn check(&self) -> Result<(), SerializationError> {
                if !Self::TRIVIAL_CHECK {
                    return T::batch_check(self.iter());
                }
                Ok(())
            }

            #[inline]
            fn batch_check<'a>(
                batch: impl Iterator<Item = &'a Self> + Send,
            ) -> Result<(), SerializationError>
            where
                Self: 'a,
            {
                if !Self::TRIVIAL_CHECK {
                    // Flatten every container in the batch into one stream of elements so
                    // `T` checks them all in a single batch.
                    return T::batch_check(batch.flat_map(|container| container.iter()));
                }
                Ok(())
            }
        }
    };
}
43
/// Implements [`CanonicalSerialize`] for a sequence-like container of `T`.
///
/// Encoding: the element count as a `u64`, followed by every element in the container's
/// iteration order. NOTE: for `HashSet` that order is unspecified (it depends on the
/// hasher state), so equal sets are not guaranteed to produce identical bytes.
macro_rules! impl_canonical_serialize_seq {
    ($type:ty) => {
        impl<T: CanonicalSerialize> CanonicalSerialize for $type {
            #[inline]
            fn serialize_with_mode<W: Write>(
                &self,
                mut writer: W,
                compress: Compress,
            ) -> Result<(), SerializationError> {
                (self.len() as u64).serialize_with_mode(&mut writer, compress)?;
                self.iter().try_for_each(|element| {
                    element.borrow().serialize_with_mode(&mut writer, compress)
                })
            }

            #[inline]
            fn serialized_size(&self, compress: Compress) -> usize {
                let payload: usize = self
                    .iter()
                    .map(|element| element.borrow().serialized_size(compress))
                    .sum();
                // 8 bytes for the `u64` length prefix.
                8 + payload
            }
        }
    };
}
71
72macro_rules! impl_canonical_deserialize_seq {
73    ($type:ty $( ; $($t_bounds:tt)+ )?) => {
74        impl<T> CanonicalDeserialize for $type
75        where
76            T: CanonicalDeserialize $(+ $($t_bounds)+ )?
77        {
78            #[inline]
79            fn deserialize_with_mode<R: Read>(mut reader: R, compress: Compress, validate: Validate) -> Result<Self, SerializationError> {
80                let len = u64::deserialize_with_mode(&mut reader, compress, validate)?
81                    .try_into()
82                    .map_err(|_| SerializationError::NotEnoughSpace)?;
83
84                let values = (0..len)
85                    .map(|_| T::deserialize_with_mode(&mut reader, compress, Validate::No))
86                    .collect::<Result<Self, SerializationError>>()?;
87
88                if validate == Validate::Yes {
89                    T::batch_check(values.iter())?;
90                }
91                Ok(values)
92            }
93        }
94    };
95}
96
97impl<T: CanonicalSerialize, const N: usize> CanonicalSerialize for [T; N] {
98    #[inline]
99    fn serialize_with_mode<W: Write>(
100        &self,
101        mut writer: W,
102        compress: Compress,
103    ) -> Result<(), SerializationError> {
104        for item in self {
105            item.serialize_with_mode(&mut writer, compress)?;
106        }
107        Ok(())
108    }
109
110    #[inline]
111    fn serialized_size(&self, compress: Compress) -> usize {
112        self.iter()
113            .map(|item| item.serialized_size(compress))
114            .sum::<usize>()
115    }
116}
// Arrays reuse the generic sequence validation; the trailing `const N: usize` is spliced
// into the generated impl's generic parameter list.
impl_valid_seq!([T; N]; const N: usize);
118
119impl<T: CanonicalDeserialize, const N: usize> CanonicalDeserialize for [T; N] {
120    #[inline]
121    #[allow(unsafe_code)]
122    fn deserialize_with_mode<R: Read>(
123        mut reader: R,
124        compress: Compress,
125        validate: Validate,
126    ) -> Result<Self, SerializationError> {
127        use core::mem::MaybeUninit;
128        let mut data: [MaybeUninit<T>; N] = [const { MaybeUninit::uninit() }; N];
129        for elem in &mut data[..] {
130            elem.write(T::deserialize_with_mode(&mut reader, compress, validate)?);
131        }
132        Ok(data.map(|x| unsafe { x.assume_init() }))
133    }
134}
135
136impl<T: CanonicalSerialize> CanonicalSerialize for Vec<T> {
137    #[inline]
138    fn serialize_with_mode<W: Write>(
139        &self,
140        mut writer: W,
141        compress: Compress,
142    ) -> Result<(), SerializationError> {
143        self.as_slice().serialize_with_mode(&mut writer, compress)
144    }
145
146    #[inline]
147    fn serialized_size(&self, compress: Compress) -> usize {
148        self.as_slice().serialized_size(compress)
149    }
150}
151
152impl_valid_seq!(Vec<T>);
153impl_canonical_deserialize_seq!(Vec<T>);
154
155impl_canonical_serialize_seq!(VecDeque<T>);
156impl_valid_seq!(VecDeque<T>);
157impl_canonical_deserialize_seq!(VecDeque<T>);
158
159impl_canonical_serialize_seq!(LinkedList<T>);
160impl_valid_seq!(LinkedList<T>);
161impl_canonical_deserialize_seq!(LinkedList<T>);
162
163impl_canonical_serialize_seq!([T]);
164impl_canonical_serialize_seq!(&[T]);
165impl_canonical_serialize_seq!(&mut [T]);
166
167impl_canonical_serialize_seq!(ark_std::boxed::Box<[T]>);
168impl_valid_seq!(ark_std::boxed::Box<[T]>);
169impl_canonical_deserialize_seq!(ark_std::boxed::Box<[T]>);
170
171impl_canonical_serialize_seq!(BTreeSet<T>);
172impl_valid_seq!(BTreeSet<T>);
173impl_canonical_deserialize_seq!(BTreeSet<T>; Ord);
174
175#[cfg(feature = "std")]
176impl_canonical_serialize_seq!(std::collections::HashSet<T>);
177#[cfg(feature = "std")]
178impl_valid_seq!(std::collections::HashSet<T>);
179#[cfg(feature = "std")]
180impl_canonical_deserialize_seq!(std::collections::HashSet<T>; core::hash::Hash + Eq);
181
182impl CanonicalSerialize for String {
183    #[inline]
184    fn serialize_with_mode<W: Write>(
185        &self,
186        mut writer: W,
187        compress: Compress,
188    ) -> Result<(), SerializationError> {
189        self.as_bytes().serialize_with_mode(&mut writer, compress)
190    }
191
192    #[inline]
193    fn serialized_size(&self, compress: Compress) -> usize {
194        self.as_bytes().serialized_size(compress)
195    }
196}
197
198impl Valid for String {
199    #[inline]
200    fn check(&self) -> Result<(), SerializationError> {
201        Ok(())
202    }
203}
204
205impl CanonicalDeserialize for String {
206    #[inline]
207    fn deserialize_with_mode<R: Read>(
208        reader: R,
209        compress: Compress,
210        validate: Validate,
211    ) -> Result<Self, SerializationError> {
212        let bytes = <Vec<u8>>::deserialize_with_mode(reader, compress, validate)?;
213        Self::from_utf8(bytes).map_err(|_| SerializationError::InvalidData)
214    }
215}
216
217impl<K, V> CanonicalSerialize for BTreeMap<K, V>
218where
219    K: CanonicalSerialize,
220    V: CanonicalSerialize,
221{
222    /// Serializes a `BTreeMap` as `len(map) || key 1 || value 1 || ... || key n || value n`.
223    fn serialize_with_mode<W: Write>(
224        &self,
225        mut writer: W,
226        compress: Compress,
227    ) -> Result<(), SerializationError> {
228        let len = self.len() as u64;
229        len.serialize_with_mode(&mut writer, compress)?;
230        for (k, v) in self {
231            k.serialize_with_mode(&mut writer, compress)?;
232            v.serialize_with_mode(&mut writer, compress)?;
233        }
234        Ok(())
235    }
236
237    fn serialized_size(&self, compress: Compress) -> usize {
238        8 + self
239            .iter()
240            .map(|(k, v)| k.serialized_size(compress) + v.serialized_size(compress))
241            .sum::<usize>()
242    }
243}
244
245impl<K: Valid, V: Valid> Valid for BTreeMap<K, V> {
246    const TRIVIAL_CHECK: bool = K::TRIVIAL_CHECK & V::TRIVIAL_CHECK;
247    #[inline]
248    fn check(&self) -> Result<(), SerializationError> {
249        if Self::TRIVIAL_CHECK {
250            return Ok(());
251        }
252        if !K::TRIVIAL_CHECK {
253            K::batch_check(self.keys())?;
254        }
255        if !V::TRIVIAL_CHECK {
256            V::batch_check(self.values())?;
257        }
258        Ok(())
259    }
260
261    #[inline]
262    fn batch_check<'a>(batch: impl Iterator<Item = &'a Self>) -> Result<(), SerializationError>
263    where
264        Self: 'a,
265    {
266        if Self::TRIVIAL_CHECK {
267            return Ok(());
268        }
269        let (keys, values): (Vec<_>, Vec<_>) = batch.map(|b| (b.keys(), b.values())).unzip();
270        if !K::TRIVIAL_CHECK {
271            K::batch_check(keys.into_iter().flatten())?;
272        }
273        if !V::TRIVIAL_CHECK {
274            V::batch_check(values.into_iter().flatten())?;
275        }
276        Ok(())
277    }
278}
279
280impl<K, V> CanonicalDeserialize for BTreeMap<K, V>
281where
282    K: Ord + CanonicalDeserialize,
283    V: CanonicalDeserialize,
284{
285    /// Deserializes a `BTreeMap` from `len(map) || key 1 || value 1 || ... || key n || value n`.
286    fn deserialize_with_mode<R: Read>(
287        mut reader: R,
288        compress: Compress,
289        validate: Validate,
290    ) -> Result<Self, SerializationError> {
291        let len = u64::deserialize_with_mode(&mut reader, compress, validate)?;
292        (0..len)
293            .map(|_| {
294                Ok((
295                    K::deserialize_with_mode(&mut reader, compress, validate)?,
296                    V::deserialize_with_mode(&mut reader, compress, validate)?,
297                ))
298            })
299            .collect()
300    }
301}
302
#[cfg(feature = "std")]
impl<K, V> CanonicalSerialize for std::collections::HashMap<K, V>
where
    K: CanonicalSerialize,
    V: CanonicalSerialize,
{
    /// Serializes a `HashMap` as `len(map) || key 1 || value 1 || ... || key n || value n`.
    ///
    /// `len(map)` is written as a `u64`. NOTE: entries are written in the map's iteration
    /// order, which for `HashMap` is unspecified (it depends on the hasher state), so two
    /// equal maps are not guaranteed to produce identical bytes.
    fn serialize_with_mode<W: Write>(
        &self,
        mut writer: W,
        compress: Compress,
    ) -> Result<(), SerializationError> {
        (self.len() as u64).serialize_with_mode(&mut writer, compress)?;
        self.iter().try_for_each(|(key, value)| {
            key.serialize_with_mode(&mut writer, compress)?;
            value.serialize_with_mode(&mut writer, compress)
        })
    }

    /// The 8-byte `u64` length prefix plus the size of every key and value.
    fn serialized_size(&self, compress: Compress) -> usize {
        self.iter().fold(8, |acc, (key, value)| {
            acc + key.serialized_size(compress) + value.serialized_size(compress)
        })
    }
}
331
#[cfg(feature = "std")]
impl<K: Valid, V: Valid> Valid for std::collections::HashMap<K, V> {
    // The map needs validation only if its keys or its values do.
    const TRIVIAL_CHECK: bool = K::TRIVIAL_CHECK && V::TRIVIAL_CHECK;

    /// Batch-checks all keys, then all values, skipping whichever side is trivially valid.
    #[inline]
    fn check(&self) -> Result<(), SerializationError> {
        if !Self::TRIVIAL_CHECK {
            if !K::TRIVIAL_CHECK {
                K::batch_check(self.keys())?;
            }
            if !V::TRIVIAL_CHECK {
                V::batch_check(self.values())?;
            }
        }
        Ok(())
    }

    /// Checks the keys of every map in `batch` as one batch, then all values as another.
    #[inline]
    fn batch_check<'a>(batch: impl Iterator<Item = &'a Self>) -> Result<(), SerializationError>
    where
        Self: 'a,
    {
        if Self::TRIVIAL_CHECK {
            return Ok(());
        }
        // `batch` can only be walked once, but keys and values are checked in two
        // separate passes, so buffer the map references first.
        let maps: Vec<&'a Self> = batch.collect();
        if !K::TRIVIAL_CHECK {
            K::batch_check(maps.iter().flat_map(|map| map.keys()))?;
        }
        if !V::TRIVIAL_CHECK {
            V::batch_check(maps.iter().flat_map(|map| map.values()))?;
        }
        Ok(())
    }
}
367
#[cfg(feature = "std")]
impl<K, V> CanonicalDeserialize for std::collections::HashMap<K, V>
where
    K: core::hash::Hash + Eq + CanonicalDeserialize,
    V: CanonicalDeserialize,
{
    /// Deserializes a `HashMap` from `len(map) || key 1 || value 1 || ... || key n || value n`.
    ///
    /// `len(map)` is read as a `u64`. Every key and value is decoded with the caller's
    /// `validate` mode. Entries are gathered through `FromIterator`, so a repeated key is
    /// collapsed into a single entry rather than reported as an error.
    fn deserialize_with_mode<R: Read>(
        mut reader: R,
        compress: Compress,
        validate: Validate,
    ) -> Result<Self, SerializationError> {
        let entry_count = u64::deserialize_with_mode(&mut reader, compress, validate)?;
        let mut read_entry = || -> Result<(K, V), SerializationError> {
            let key = K::deserialize_with_mode(&mut reader, compress, validate)?;
            let value = V::deserialize_with_mode(&mut reader, compress, validate)?;
            Ok((key, value))
        };
        (0..entry_count).map(|_| read_entry()).collect()
    }
}