Skip to main content

provekit_common/utils/
serde_ark_vec.rs

1use {
2    crate::FieldElement,
3    ark_serialize::{CanonicalDeserialize, CanonicalSerialize},
4    serde::{
5        de::{Error as _, SeqAccess, Visitor},
6        ser::{Error as _, SerializeSeq},
7        Deserializer, Serializer,
8    },
9    std::fmt,
10};
11
12pub fn serialize<S>(vec: &Vec<FieldElement>, serializer: S) -> Result<S::Ok, S::Error>
13where
14    S: Serializer,
15{
16    let is_human_readable = serializer.is_human_readable();
17    let mut seq = serializer.serialize_seq(Some(vec.len()))?;
18    for element in vec {
19        let mut buf = Vec::with_capacity(element.compressed_size());
20        element
21            .serialize_compressed(&mut buf)
22            .map_err(|e| S::Error::custom(format!("Failed to serialize: {e}")))?;
23
24        // Write bytes
25        if is_human_readable {
26            // ark_serialize doesn't have human-readable serialization. And Serde
27            // doesn't have good defaults for [u8]. So we implement hexadecimal
28            // serialization.
29            let hex = hex::encode(buf);
30            seq.serialize_element(&hex)?;
31        } else {
32            seq.serialize_element(&buf)?;
33        }
34    }
35    seq.end()
36}
37
38pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<FieldElement>, D::Error>
39where
40    D: Deserializer<'de>,
41{
42    struct VecVisitor {
43        is_human_readable: bool,
44    }
45
46    impl<'de> Visitor<'de> for VecVisitor {
47        type Value = Vec<FieldElement>;
48
49        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
50            formatter.write_str("a sequence of field elements")
51        }
52
53        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
54        where
55            A: SeqAccess<'de>,
56        {
57            let mut vec = Vec::new();
58            if self.is_human_readable {
59                while let Some(hex) = seq.next_element::<String>()? {
60                    let bytes = hex::decode(hex)
61                        .map_err(|e| A::Error::custom(format!("invalid hex: {e}")))?;
62                    let mut reader = &*bytes;
63                    let element = FieldElement::deserialize_compressed(&mut reader)
64                        .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?;
65                    if !reader.is_empty() {
66                        return Err(A::Error::custom("while deserializing: trailing bytes"));
67                    }
68                    vec.push(element);
69                }
70            } else {
71                while let Some(bytes) = seq.next_element::<Vec<u8>>()? {
72                    let mut reader = &*bytes;
73                    let element = FieldElement::deserialize_compressed(&mut reader)
74                        .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?;
75                    if !reader.is_empty() {
76                        return Err(A::Error::custom("while deserializing: trailing bytes"));
77                    }
78                    vec.push(element);
79                }
80            }
81            Ok(vec)
82        }
83    }
84
85    let is_human_readable = deserializer.is_human_readable();
86    deserializer.deserialize_seq(VecVisitor { is_human_readable })
87}