1use base64ct::{Base64UrlUnpadded, Encoding};
4use serde::{
5 de::{DeserializeOwned, Error as DeError, SeqAccess, Unexpected, Visitor},
6 Deserialize, Deserializer, Serialize, Serializer,
7};
8use zeroize::Zeroizing;
9
10use core::{fmt, marker::PhantomData};
11
12use crate::{
13 alloc::{vec, ToString, Vec},
14 dkg::Opening,
15 group::Group,
16 Keypair, PublicKey, SecretKey,
17};
18
19fn serialize_bytes<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
20where
21 S: Serializer,
22{
23 if serializer.is_human_readable() {
24 serializer.serialize_str(&Base64UrlUnpadded::encode_string(value))
25 } else {
26 serializer.serialize_bytes(value)
27 }
28}
29
30fn deserialize_bytes<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
31where
32 D: Deserializer<'de>,
33{
34 struct Base64Visitor;
35
36 impl Visitor<'_> for Base64Visitor {
37 type Value = Vec<u8>;
38
39 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
40 formatter.write_str("base64url-encoded data")
41 }
42
43 fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
44 Base64UrlUnpadded::decode_vec(value)
45 .map_err(|_| E::invalid_value(Unexpected::Str(value), &self))
46 }
47
48 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
49 Ok(value.to_vec())
50 }
51
52 fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
53 Ok(value)
54 }
55 }
56
57 struct BytesVisitor;
58
59 impl Visitor<'_> for BytesVisitor {
60 type Value = Vec<u8>;
61
62 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
63 formatter.write_str("byte buffer")
64 }
65
66 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
67 Ok(value.to_vec())
68 }
69
70 fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
71 Ok(value)
72 }
73 }
74
75 if deserializer.is_human_readable() {
76 deserializer.deserialize_str(Base64Visitor)
77 } else {
78 deserializer.deserialize_byte_buf(BytesVisitor)
79 }
80}
81
82impl Serialize for Opening {
83 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84 where
85 S: Serializer,
86 {
87 serialize_bytes(self.0.as_slice(), serializer)
88 }
89}
90
91impl<'de> Deserialize<'de> for Opening {
92 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
93 where
94 D: Deserializer<'de>,
95 {
96 let bytes = Zeroizing::new(deserialize_bytes(deserializer)?);
97 let mut opening = Opening(Zeroizing::new([0_u8; 32]));
98 if bytes.len() == 32 {
99 opening.0.copy_from_slice(&bytes);
100 Ok(opening)
101 } else {
102 Err(D::Error::invalid_length(bytes.len(), &"32"))
103 }
104 }
105}
106
107impl<G: Group> Serialize for PublicKey<G> {
108 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
109 where
110 S: Serializer,
111 {
112 serialize_bytes(self.as_bytes(), serializer)
113 }
114}
115
116impl<'de, G: Group> Deserialize<'de> for PublicKey<G> {
117 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118 where
119 D: Deserializer<'de>,
120 {
121 let bytes = deserialize_bytes(deserializer)?;
122 Self::from_bytes(&bytes).map_err(D::Error::custom)
123 }
124}
125
126impl<G: Group> Serialize for SecretKey<G> {
127 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128 where
129 S: Serializer,
130 {
131 let mut bytes = Zeroizing::new(vec![0_u8; G::SCALAR_SIZE]);
132 G::serialize_scalar(self.expose_scalar(), &mut bytes);
133 serialize_bytes(&bytes, serializer)
134 }
135}
136
137impl<'de, G: Group> Deserialize<'de> for SecretKey<G> {
138 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
139 where
140 D: Deserializer<'de>,
141 {
142 let bytes = Zeroizing::new(deserialize_bytes(deserializer)?);
143 Self::from_bytes(&bytes)
144 .ok_or_else(|| D::Error::custom("bytes do not represent a group scalar"))
145 }
146}
147
148impl<G: Group> Serialize for Keypair<G> {
149 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
150 where
151 S: Serializer,
152 {
153 self.secret().serialize(serializer)
154 }
155}
156
157impl<'de, G: Group> Deserialize<'de> for Keypair<G> {
158 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
159 where
160 D: Deserializer<'de>,
161 {
162 SecretKey::<G>::deserialize(deserializer).map(From::from)
163 }
164}
165
166pub(crate) trait Helper: Serialize + DeserializeOwned {
168 const PLURAL_DESCRIPTION: &'static str;
169 type Target;
170
171 fn from_target(target: &Self::Target) -> Self;
172 fn into_target(self) -> Self::Target;
173}
174
175#[derive(Debug)]
179pub(crate) struct ScalarHelper<G: Group>(G::Scalar);
180
181impl<G: Group> ScalarHelper<G> {
182 pub fn serialize<S>(scalar: &G::Scalar, serializer: S) -> Result<S::Ok, S::Error>
183 where
184 S: Serializer,
185 {
186 let mut bytes = vec![0_u8; G::SCALAR_SIZE];
187 G::serialize_scalar(scalar, &mut bytes);
188 serialize_bytes(&bytes, serializer)
189 }
190
191 pub fn deserialize<'de, D>(deserializer: D) -> Result<G::Scalar, D::Error>
192 where
193 D: Deserializer<'de>,
194 {
195 let bytes = deserialize_bytes(deserializer)?;
196 if bytes.len() == G::SCALAR_SIZE {
197 G::deserialize_scalar(&bytes)
198 .ok_or_else(|| D::Error::custom("bytes do not represent a group scalar"))
199 } else {
200 let expected_len = G::SCALAR_SIZE.to_string();
201 Err(D::Error::invalid_length(
202 bytes.len(),
203 &expected_len.as_str(),
204 ))
205 }
206 }
207}
208
209impl<G: Group> Serialize for ScalarHelper<G> {
210 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
211 where
212 S: Serializer,
213 {
214 Self::serialize(&self.0, serializer)
215 }
216}
217
218impl<'de, G: Group> Deserialize<'de> for ScalarHelper<G> {
219 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
220 where
221 D: Deserializer<'de>,
222 {
223 Self::deserialize(deserializer).map(Self)
224 }
225}
226
227impl<G: Group> Helper for ScalarHelper<G> {
228 const PLURAL_DESCRIPTION: &'static str = "group scalars";
229 type Target = G::Scalar;
230
231 fn from_target(target: &Self::Target) -> Self {
232 Self(*target)
233 }
234
235 fn into_target(self) -> Self::Target {
236 self.0
237 }
238}
239
240#[derive(Debug)]
242pub(crate) struct ElementHelper<G: Group>(G::Element);
243
244impl<G: Group> ElementHelper<G> {
245 pub fn serialize<S>(element: &G::Element, serializer: S) -> Result<S::Ok, S::Error>
246 where
247 S: Serializer,
248 {
249 let mut bytes = vec![0_u8; G::ELEMENT_SIZE];
250 G::serialize_element(element, &mut bytes);
251 serialize_bytes(&bytes, serializer)
252 }
253
254 pub fn deserialize<'de, D>(deserializer: D) -> Result<G::Element, D::Error>
255 where
256 D: Deserializer<'de>,
257 {
258 let bytes = deserialize_bytes(deserializer)?;
259 if bytes.len() == G::ELEMENT_SIZE {
260 G::deserialize_element(&bytes)
261 .ok_or_else(|| D::Error::custom("bytes do not represent a group element"))
262 } else {
263 let expected_len = G::ELEMENT_SIZE.to_string();
264 Err(D::Error::invalid_length(
265 bytes.len(),
266 &expected_len.as_str(),
267 ))
268 }
269 }
270}
271
272impl<G: Group> Serialize for ElementHelper<G> {
273 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
274 where
275 S: Serializer,
276 {
277 Self::serialize(&self.0, serializer)
278 }
279}
280
281impl<'de, G: Group> Deserialize<'de> for ElementHelper<G> {
282 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
283 where
284 D: Deserializer<'de>,
285 {
286 Self::deserialize(deserializer).map(Self)
287 }
288}
289
290impl<G: Group> Helper for ElementHelper<G> {
291 const PLURAL_DESCRIPTION: &'static str = "group elements";
292 type Target = G::Element;
293
294 fn from_target(target: &Self::Target) -> Self {
295 Self(*target)
296 }
297
298 fn into_target(self) -> Self::Target {
299 self.0
300 }
301}
302
303pub(crate) struct VecHelper<T, const MIN: usize>(PhantomData<T>);
304
305impl<T: Helper, const MIN: usize> VecHelper<T, MIN> {
306 fn new() -> Self {
307 Self(PhantomData)
308 }
309
310 pub fn serialize<S>(values: &[T::Target], serializer: S) -> Result<S::Ok, S::Error>
311 where
312 S: Serializer,
313 {
314 debug_assert!(values.len() >= MIN);
315 serializer.collect_seq(values.iter().map(T::from_target))
316 }
317
318 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<T::Target>, D::Error>
319 where
320 D: Deserializer<'de>,
321 {
322 deserializer.deserialize_seq(Self::new())
323 }
324}
325
326impl<'de, T: Helper, const MIN: usize> Visitor<'de> for VecHelper<T, MIN> {
327 type Value = Vec<T::Target>;
328
329 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
330 write!(formatter, "at least {MIN} {}", T::PLURAL_DESCRIPTION)
331 }
332
333 fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
334 where
335 S: SeqAccess<'de>,
336 {
337 let mut scalars: Vec<T::Target> = if let Some(size) = access.size_hint() {
338 if size < MIN {
339 return Err(S::Error::invalid_length(size, &self));
340 }
341 Vec::with_capacity(size)
342 } else {
343 Vec::new()
344 };
345
346 while let Some(value) = access.next_element::<T>()? {
347 scalars.push(value.into_target());
348 }
349 if scalars.len() >= MIN {
350 Ok(scalars)
351 } else {
352 Err(S::Error::invalid_length(scalars.len(), &self))
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;
    use crate::group::Ristretto;

    /// Asserts that the display form of `err` contains `needle`; prints the full message otherwise.
    fn assert_error_mentions(err: &serde_json::Error, needle: &str) {
        let message = err.to_string();
        assert!(message.contains(needle), "{message}");
    }

    #[test]
    fn opening_roundtrip() {
        let original = Opening(Zeroizing::new([6; 32]));
        let encoded = serde_json::to_value(&original).unwrap();
        assert!(encoded.is_string(), "{encoded:?}");

        let restored: Opening = serde_json::from_value(encoded).unwrap();
        assert_eq!(restored.0, original.0);
    }

    #[test]
    fn key_roundtrip() {
        let keypair = Keypair::<Ristretto>::generate(&mut thread_rng());

        // A keypair is encoded via its secret key only.
        let encoded = serde_json::to_value(&keypair).unwrap();
        assert!(encoded.is_string(), "{encoded:?}");
        let restored: Keypair<Ristretto> = serde_json::from_value(encoded).unwrap();
        assert_eq!(restored.public(), keypair.public());

        let encoded = serde_json::to_value(keypair.public()).unwrap();
        assert!(encoded.is_string(), "{encoded:?}");
        let public_key: PublicKey<Ristretto> = serde_json::from_value(encoded).unwrap();
        assert_eq!(public_key, *keypair.public());

        let encoded = serde_json::to_value(keypair.secret()).unwrap();
        assert!(encoded.is_string(), "{encoded:?}");
        let secret_key: SecretKey<Ristretto> = serde_json::from_value(encoded).unwrap();
        assert_eq!(secret_key.expose_scalar(), keypair.secret().expose_scalar());
    }

    #[test]
    fn public_key_deserialization_with_incorrect_length() {
        let err = serde_json::from_str::<PublicKey<Ristretto>>("\"dGVzdA\"").unwrap_err();
        assert_error_mentions(&err, "invalid size of the byte buffer");
    }

    #[test]
    fn public_key_deserialization_of_non_element() {
        let err = serde_json::from_str::<PublicKey<Ristretto>>(
            "\"tNDkeYUVQWgh34d-RqaElOk7yFB8d2qCh5f4Vi2euT0\"",
        )
        .unwrap_err();
        assert_error_mentions(&err, "does not represent a group element");
    }

    #[test]
    fn secret_key_deserialization_with_incorrect_length() {
        let err = serde_json::from_str::<SecretKey<Ristretto>>("\"dGVzdA\"").unwrap_err();
        assert_error_mentions(&err, "bytes do not represent a group scalar");
    }

    #[test]
    fn secret_key_deserialization_of_invalid_scalar() {
        // NOTE(review): presumably a 32-byte value that is not a canonical Ristretto scalar
        // encoding (e.g., not reduced modulo the group order) — confirm.
        let err = serde_json::from_str::<SecretKey<Ristretto>>(
            "\"nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8\"",
        )
        .unwrap_err();
        assert_error_mentions(&err, "bytes do not represent a group scalar");
    }

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    #[serde(bound = "")]
    struct TestObject<G: Group> {
        #[serde(with = "ScalarHelper::<G>")]
        scalar: G::Scalar,
        #[serde(with = "ElementHelper::<G>")]
        element: G::Element,
        #[serde(with = "VecHelper::<ScalarHelper<G>, 2>")]
        more_scalars: Vec<G::Scalar>,
    }

    impl TestObject<Ristretto> {
        fn sample() -> Self {
            let element = Ristretto::mul_generator(&54321_u64.into());
            Self {
                scalar: 12345_u64.into(),
                element,
                more_scalars: vec![7_u64.into(), 890_u64.into()],
            }
        }
    }

    /// Serializes the sample object and overwrites `field` with the JSON string `value`.
    fn sample_json_with(field: &str, value: &str) -> serde_json::Value {
        let mut json = serde_json::to_value(TestObject::sample()).unwrap();
        json.as_object_mut()
            .unwrap()
            .insert(field.into(), value.into());
        json
    }

    /// Deserializes `json` as a test object, expecting failure.
    fn deserialization_error(json: serde_json::Value) -> serde_json::Error {
        serde_json::from_value::<TestObject<Ristretto>>(json).unwrap_err()
    }

    #[test]
    fn helpers_roundtrip() {
        let original = TestObject::sample();
        let encoded = serde_json::to_value(&original).unwrap();
        let restored: TestObject<Ristretto> = serde_json::from_value(encoded).unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn scalar_helper_invalid_scalar() {
        let err = deserialization_error(sample_json_with("scalar", "dGVzdA"));
        assert_error_mentions(&err, "invalid length 4, expected 32");

        let err = deserialization_error(sample_json_with(
            "scalar",
            "nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8",
        ));
        assert_error_mentions(&err, "bytes do not represent a group scalar");
    }

    #[test]
    fn element_helper_invalid_element() {
        let err = deserialization_error(sample_json_with("element", "dGVzdA"));
        assert_error_mentions(&err, "invalid length 4, expected 32");

        let err = deserialization_error(sample_json_with(
            "element",
            "nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8",
        ));
        assert_error_mentions(&err, "bytes do not represent a group element");
    }

    #[test]
    fn vec_helper_invalid_length() {
        let mut json = serde_json::to_value(TestObject::sample()).unwrap();
        json["more_scalars"].as_array_mut().unwrap().pop();

        let err = deserialization_error(json);
        assert_error_mentions(&err, "invalid length 1, expected at least 2 group scalars");
    }
}