elastic_elgamal/
serde.rs

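//! (De)serialization utils: byte-like values (keys, openings, group scalars and elements)
//! are encoded as unpadded base64url strings in human-readable formats and as raw byte
//! buffers otherwise.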
2
3use base64ct::{Base64UrlUnpadded, Encoding};
4use serde::{
5    de::{DeserializeOwned, Error as DeError, SeqAccess, Unexpected, Visitor},
6    Deserialize, Deserializer, Serialize, Serializer,
7};
8use zeroize::Zeroizing;
9
10use core::{fmt, marker::PhantomData};
11
12use crate::{
13    alloc::{vec, ToString, Vec},
14    dkg::Opening,
15    group::Group,
16    Keypair, PublicKey, SecretKey,
17};
18
19fn serialize_bytes<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
20where
21    S: Serializer,
22{
23    if serializer.is_human_readable() {
24        serializer.serialize_str(&Base64UrlUnpadded::encode_string(value))
25    } else {
26        serializer.serialize_bytes(value)
27    }
28}
29
30fn deserialize_bytes<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
31where
32    D: Deserializer<'de>,
33{
34    struct Base64Visitor;
35
36    impl Visitor<'_> for Base64Visitor {
37        type Value = Vec<u8>;
38
39        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
40            formatter.write_str("base64url-encoded data")
41        }
42
43        fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
44            Base64UrlUnpadded::decode_vec(value)
45                .map_err(|_| E::invalid_value(Unexpected::Str(value), &self))
46        }
47
48        fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
49            Ok(value.to_vec())
50        }
51
52        fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
53            Ok(value)
54        }
55    }
56
57    struct BytesVisitor;
58
59    impl Visitor<'_> for BytesVisitor {
60        type Value = Vec<u8>;
61
62        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
63            formatter.write_str("byte buffer")
64        }
65
66        fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
67            Ok(value.to_vec())
68        }
69
70        fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
71            Ok(value)
72        }
73    }
74
75    if deserializer.is_human_readable() {
76        deserializer.deserialize_str(Base64Visitor)
77    } else {
78        deserializer.deserialize_byte_buf(BytesVisitor)
79    }
80}
81
82impl Serialize for Opening {
83    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84    where
85        S: Serializer,
86    {
87        serialize_bytes(self.0.as_slice(), serializer)
88    }
89}
90
91impl<'de> Deserialize<'de> for Opening {
92    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
93    where
94        D: Deserializer<'de>,
95    {
96        let bytes = Zeroizing::new(deserialize_bytes(deserializer)?);
97        let mut opening = Opening(Zeroizing::new([0_u8; 32]));
98        if bytes.len() == 32 {
99            opening.0.copy_from_slice(&bytes);
100            Ok(opening)
101        } else {
102            Err(D::Error::invalid_length(bytes.len(), &"32"))
103        }
104    }
105}
106
107impl<G: Group> Serialize for PublicKey<G> {
108    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
109    where
110        S: Serializer,
111    {
112        serialize_bytes(self.as_bytes(), serializer)
113    }
114}
115
116impl<'de, G: Group> Deserialize<'de> for PublicKey<G> {
117    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118    where
119        D: Deserializer<'de>,
120    {
121        let bytes = deserialize_bytes(deserializer)?;
122        Self::from_bytes(&bytes).map_err(D::Error::custom)
123    }
124}
125
126impl<G: Group> Serialize for SecretKey<G> {
127    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128    where
129        S: Serializer,
130    {
131        let mut bytes = Zeroizing::new(vec![0_u8; G::SCALAR_SIZE]);
132        G::serialize_scalar(self.expose_scalar(), &mut bytes);
133        serialize_bytes(&bytes, serializer)
134    }
135}
136
137impl<'de, G: Group> Deserialize<'de> for SecretKey<G> {
138    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
139    where
140        D: Deserializer<'de>,
141    {
142        let bytes = Zeroizing::new(deserialize_bytes(deserializer)?);
143        Self::from_bytes(&bytes)
144            .ok_or_else(|| D::Error::custom("bytes do not represent a group scalar"))
145    }
146}
147
148impl<G: Group> Serialize for Keypair<G> {
149    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
150    where
151        S: Serializer,
152    {
153        self.secret().serialize(serializer)
154    }
155}
156
157impl<'de, G: Group> Deserialize<'de> for Keypair<G> {
158    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
159    where
160        D: Deserializer<'de>,
161    {
162        SecretKey::<G>::deserialize(deserializer).map(From::from)
163    }
164}
165
166/// Common functionality for serialization helpers.
167pub(crate) trait Helper: Serialize + DeserializeOwned {
168    const PLURAL_DESCRIPTION: &'static str;
169    type Target;
170
171    fn from_target(target: &Self::Target) -> Self;
172    fn into_target(self) -> Self::Target;
173}
174
175/// Helper type to deserialize scalars.
176///
177/// **NB.** Scalars are assumed to be public! Secret scalars must be serialized via `SecretKey`.
178#[derive(Debug)]
179pub(crate) struct ScalarHelper<G: Group>(G::Scalar);
180
181impl<G: Group> ScalarHelper<G> {
182    pub fn serialize<S>(scalar: &G::Scalar, serializer: S) -> Result<S::Ok, S::Error>
183    where
184        S: Serializer,
185    {
186        let mut bytes = vec![0_u8; G::SCALAR_SIZE];
187        G::serialize_scalar(scalar, &mut bytes);
188        serialize_bytes(&bytes, serializer)
189    }
190
191    pub fn deserialize<'de, D>(deserializer: D) -> Result<G::Scalar, D::Error>
192    where
193        D: Deserializer<'de>,
194    {
195        let bytes = deserialize_bytes(deserializer)?;
196        if bytes.len() == G::SCALAR_SIZE {
197            G::deserialize_scalar(&bytes)
198                .ok_or_else(|| D::Error::custom("bytes do not represent a group scalar"))
199        } else {
200            let expected_len = G::SCALAR_SIZE.to_string();
201            Err(D::Error::invalid_length(
202                bytes.len(),
203                &expected_len.as_str(),
204            ))
205        }
206    }
207}
208
209impl<G: Group> Serialize for ScalarHelper<G> {
210    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
211    where
212        S: Serializer,
213    {
214        Self::serialize(&self.0, serializer)
215    }
216}
217
218impl<'de, G: Group> Deserialize<'de> for ScalarHelper<G> {
219    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
220    where
221        D: Deserializer<'de>,
222    {
223        Self::deserialize(deserializer).map(Self)
224    }
225}
226
227impl<G: Group> Helper for ScalarHelper<G> {
228    const PLURAL_DESCRIPTION: &'static str = "group scalars";
229    type Target = G::Scalar;
230
231    fn from_target(target: &Self::Target) -> Self {
232        Self(*target)
233    }
234
235    fn into_target(self) -> Self::Target {
236        self.0
237    }
238}
239
240/// Helper type to deserialize group elements.
241#[derive(Debug)]
242pub(crate) struct ElementHelper<G: Group>(G::Element);
243
244impl<G: Group> ElementHelper<G> {
245    pub fn serialize<S>(element: &G::Element, serializer: S) -> Result<S::Ok, S::Error>
246    where
247        S: Serializer,
248    {
249        let mut bytes = vec![0_u8; G::ELEMENT_SIZE];
250        G::serialize_element(element, &mut bytes);
251        serialize_bytes(&bytes, serializer)
252    }
253
254    pub fn deserialize<'de, D>(deserializer: D) -> Result<G::Element, D::Error>
255    where
256        D: Deserializer<'de>,
257    {
258        let bytes = deserialize_bytes(deserializer)?;
259        if bytes.len() == G::ELEMENT_SIZE {
260            G::deserialize_element(&bytes)
261                .ok_or_else(|| D::Error::custom("bytes do not represent a group element"))
262        } else {
263            let expected_len = G::ELEMENT_SIZE.to_string();
264            Err(D::Error::invalid_length(
265                bytes.len(),
266                &expected_len.as_str(),
267            ))
268        }
269    }
270}
271
272impl<G: Group> Serialize for ElementHelper<G> {
273    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
274    where
275        S: Serializer,
276    {
277        Self::serialize(&self.0, serializer)
278    }
279}
280
281impl<'de, G: Group> Deserialize<'de> for ElementHelper<G> {
282    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
283    where
284        D: Deserializer<'de>,
285    {
286        Self::deserialize(deserializer).map(Self)
287    }
288}
289
290impl<G: Group> Helper for ElementHelper<G> {
291    const PLURAL_DESCRIPTION: &'static str = "group elements";
292    type Target = G::Element;
293
294    fn from_target(target: &Self::Target) -> Self {
295        Self(*target)
296    }
297
298    fn into_target(self) -> Self::Target {
299        self.0
300    }
301}
302
303pub(crate) struct VecHelper<T, const MIN: usize>(PhantomData<T>);
304
305impl<T: Helper, const MIN: usize> VecHelper<T, MIN> {
306    fn new() -> Self {
307        Self(PhantomData)
308    }
309
310    pub fn serialize<S>(values: &[T::Target], serializer: S) -> Result<S::Ok, S::Error>
311    where
312        S: Serializer,
313    {
314        debug_assert!(values.len() >= MIN);
315        serializer.collect_seq(values.iter().map(T::from_target))
316    }
317
318    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<T::Target>, D::Error>
319    where
320        D: Deserializer<'de>,
321    {
322        deserializer.deserialize_seq(Self::new())
323    }
324}
325
326impl<'de, T: Helper, const MIN: usize> Visitor<'de> for VecHelper<T, MIN> {
327    type Value = Vec<T::Target>;
328
329    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
330        write!(formatter, "at least {MIN} {}", T::PLURAL_DESCRIPTION)
331    }
332
333    fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
334    where
335        S: SeqAccess<'de>,
336    {
337        let mut scalars: Vec<T::Target> = if let Some(size) = access.size_hint() {
338            if size < MIN {
339                return Err(S::Error::invalid_length(size, &self));
340            }
341            Vec::with_capacity(size)
342        } else {
343            Vec::new()
344        };
345
346        while let Some(value) = access.next_element::<T>()? {
347            scalars.push(value.into_target());
348        }
349        if scalars.len() >= MIN {
350            Ok(scalars)
351        } else {
352            Err(S::Error::invalid_length(scalars.len(), &self))
353        }
354    }
355}
356
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;
    use crate::group::Ristretto;

    /// Unpadded base64url encoding of the 4-byte string `test`; too short for any key,
    /// scalar or group element.
    const SHORT_BASE64: &str = "dGVzdA";

    /// 32 bytes whose last two base64 chars (`_8`) set the upper byte to 0xff, which is
    /// invalid as a Ristretto scalar (all scalars are less than 2^253) and is rejected
    /// as a group element as well.
    const INVALID_32_BYTES_BASE64: &str = "nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8";

    /// Serializes `value` to JSON, checking that it is represented as a single string.
    fn to_json_string<T: Serialize>(value: T) -> serde_json::Value {
        let json = serde_json::to_value(value).unwrap();
        assert!(json.is_string(), "{json:?}");
        json
    }

    /// Wraps `s` in double quotes to form a JSON string literal (no escaping is performed).
    fn quoted(s: &str) -> String {
        format!("\"{s}\"")
    }

    /// Checks that the error message contains `needle`, printing the full message otherwise.
    fn assert_error_contains(err: &serde_json::Error, needle: &str) {
        let message = err.to_string();
        assert!(message.contains(needle), "{message}");
    }

    #[test]
    fn opening_roundtrip() {
        let opening = Opening(Zeroizing::new([6; 32]));
        let json = to_json_string(&opening);
        let restored: Opening = serde_json::from_value(json).unwrap();
        assert_eq!(restored.0, opening.0);
    }

    #[test]
    fn key_roundtrip() {
        let keypair = Keypair::<Ristretto>::generate(&mut thread_rng());

        let json = to_json_string(&keypair);
        let restored_keypair: Keypair<Ristretto> = serde_json::from_value(json).unwrap();
        assert_eq!(restored_keypair.public(), keypair.public());

        let json = to_json_string(keypair.public());
        let public_key: PublicKey<Ristretto> = serde_json::from_value(json).unwrap();
        assert_eq!(public_key, *keypair.public());

        let json = to_json_string(keypair.secret());
        let secret_key: SecretKey<Ristretto> = serde_json::from_value(json).unwrap();
        assert_eq!(secret_key.expose_scalar(), keypair.secret().expose_scalar());
    }

    #[test]
    fn public_key_deserialization_with_incorrect_length() {
        let input = quoted(SHORT_BASE64);
        let err = serde_json::from_str::<PublicKey<Ristretto>>(&input).unwrap_err();
        assert_error_contains(&err, "invalid size of the byte buffer");
    }

    #[test]
    fn public_key_deserialization_of_non_element() {
        let input = quoted("tNDkeYUVQWgh34d-RqaElOk7yFB8d2qCh5f4Vi2euT0");
        let err = serde_json::from_str::<PublicKey<Ristretto>>(&input).unwrap_err();
        assert_error_contains(&err, "does not represent a group element");
    }

    #[test]
    fn secret_key_deserialization_with_incorrect_length() {
        let input = quoted(SHORT_BASE64);
        let err = serde_json::from_str::<SecretKey<Ristretto>>(&input).unwrap_err();
        assert_error_contains(&err, "bytes do not represent a group scalar");
    }

    #[test]
    fn secret_key_deserialization_of_invalid_scalar() {
        let input = quoted(INVALID_32_BYTES_BASE64);
        let err = serde_json::from_str::<SecretKey<Ristretto>>(&input).unwrap_err();
        assert_error_contains(&err, "bytes do not represent a group scalar");
    }

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    #[serde(bound = "")]
    struct TestObject<G: Group> {
        #[serde(with = "ScalarHelper::<G>")]
        scalar: G::Scalar,
        #[serde(with = "ElementHelper::<G>")]
        element: G::Element,
        #[serde(with = "VecHelper::<ScalarHelper<G>, 2>")]
        more_scalars: Vec<G::Scalar>,
    }

    impl TestObject<Ristretto> {
        fn sample() -> Self {
            Self {
                scalar: 12345_u64.into(),
                element: Ristretto::mul_generator(&54321_u64.into()),
                more_scalars: vec![7_u64.into(), 890_u64.into()],
            }
        }
    }

    /// Replaces `field` of a serialized sample object, first with a too-short buffer and
    /// then with 32 invalid bytes, checking the deserialization error in both cases.
    fn assert_invalid_field(field: &str, invalid_bytes_message: &str) {
        let mut json = serde_json::to_value(TestObject::<Ristretto>::sample()).unwrap();
        let cases = [
            (SHORT_BASE64, "invalid length 4, expected 32"),
            (INVALID_32_BYTES_BASE64, invalid_bytes_message),
        ];
        for (replacement, expected_message) in cases {
            json.as_object_mut()
                .unwrap()
                .insert(field.into(), replacement.into());
            let err = serde_json::from_value::<TestObject<Ristretto>>(json.clone()).unwrap_err();
            assert_error_contains(&err, expected_message);
        }
    }

    #[test]
    fn helpers_roundtrip() {
        let object = TestObject::<Ristretto>::sample();
        let json = serde_json::to_value(&object).unwrap();
        let restored: TestObject<Ristretto> = serde_json::from_value(json).unwrap();
        assert_eq!(restored, object);
    }

    #[test]
    fn scalar_helper_invalid_scalar() {
        assert_invalid_field("scalar", "bytes do not represent a group scalar");
    }

    #[test]
    fn element_helper_invalid_element() {
        assert_invalid_field("element", "bytes do not represent a group element");
    }

    #[test]
    fn vec_helper_invalid_length() {
        let mut json = serde_json::to_value(TestObject::<Ristretto>::sample()).unwrap();
        json["more_scalars"].as_array_mut().unwrap().pop();

        let err = serde_json::from_value::<TestObject<Ristretto>>(json).unwrap_err();
        assert_error_contains(&err, "invalid length 1, expected at least 2 group scalars");
    }
}