1use phantom_type::PhantomType;
78
79use crate::core::Curve;
80
81#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]
110pub struct CurveName<E: Curve>(PhantomType<E>);
111
112impl<E: Curve> CurveName<E> {
113 pub fn new() -> Self {
115 Self(PhantomType::new())
116 }
117}
118
119impl<E: Curve> Default for CurveName<E> {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
#[cfg(feature = "serde")]
impl<E: Curve> serde::Serialize for CurveName<E> {
    /// Writes the curve name (`E::CURVE_NAME`) as a string.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let name = E::CURVE_NAME;
        serializer.serialize_str(name)
    }
}
134
#[cfg(feature = "serde")]
impl<'de, E: Curve> serde::Deserialize<'de> for CurveName<E> {
    /// Reads a string and accepts it only when it equals `E::CURVE_NAME`.
    ///
    /// Any other string is rejected with an `ExpectedCurve` message naming both
    /// the expected and the received curve.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor that recognises exactly one string: the name of curve `E`.
        struct CurveNameVisitor<E: Curve>(PhantomType<E>);

        impl<E: Curve> serde::de::Visitor<'_> for CurveNameVisitor<E> {
            type Value = CurveName<E>;

            fn expecting(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
                write!(f, "curve {name}", name = E::CURVE_NAME)
            }

            fn visit_str<Error>(self, v: &str) -> Result<Self::Value, Error>
            where
                Error: serde::de::Error,
            {
                // Guard clause: anything but the exact curve name is an error.
                if v != E::CURVE_NAME {
                    return Err(Error::custom(optional::error_msg::ExpectedCurve {
                        expected: E::CURVE_NAME,
                        got: v,
                    }));
                }
                Ok(CurveName::new())
            }
        }

        let visitor = CurveNameVisitor(PhantomType::new());
        deserializer.deserialize_str(visitor)
    }
}
164
165#[cfg(feature = "serde")]
166pub use optional::*;
167#[cfg(feature = "serde")]
168mod optional {
169 use crate::{core::Curve, Point, Scalar, SecretScalar};
170
171 use super::CurveName;
172
173 impl<E: Curve> serde::Serialize for Point<E> {
174 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
175 where
176 S: serde::Serializer,
177 {
178 models::PointUncompressed::from(self).serialize(serializer)
179 }
180 }
181
182 impl<'de, E: Curve> serde::Deserialize<'de> for Point<E> {
183 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
184 where
185 D: serde::Deserializer<'de>,
186 {
187 models::PointUncompressed::deserialize(deserializer)?
188 .try_into()
189 .map_err(<D::Error as serde::de::Error>::custom)
190 }
191 }
192
193 impl<E: Curve> serde::Serialize for Scalar<E> {
194 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
195 where
196 S: serde::Serializer,
197 {
198 models::ScalarUncompressed::from(self).serialize(serializer)
199 }
200 }
201
202 impl<'de, E: Curve> serde::Deserialize<'de> for Scalar<E> {
203 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
204 where
205 D: serde::Deserializer<'de>,
206 {
207 models::ScalarUncompressed::deserialize(deserializer)?
208 .try_into()
209 .map_err(<D::Error as serde::de::Error>::custom)
210 }
211 }
212
213 impl<E: Curve> serde::Serialize for SecretScalar<E> {
214 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
215 where
216 S: serde::Serializer,
217 {
218 self.as_ref().serialize(serializer)
219 }
220 }
221
222 impl<'de, E: Curve> serde::Deserialize<'de> for SecretScalar<E> {
223 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
224 where
225 D: serde::Deserializer<'de>,
226 {
227 Ok(SecretScalar::new(&mut Scalar::deserialize(deserializer)?))
228 }
229 }
230
231 pub struct Compact;
233
234 impl<E: Curve> serde_with::SerializeAs<Point<E>> for Compact {
235 fn serialize_as<S>(source: &Point<E>, serializer: S) -> Result<S::Ok, S::Error>
236 where
237 S: serde::Serializer,
238 {
239 use serde::Serialize;
240 models::PointCompact::from(source).serialize(serializer)
241 }
242 }
243
244 impl<'de, E: Curve> serde_with::DeserializeAs<'de, Point<E>> for Compact {
245 fn deserialize_as<D>(deserializer: D) -> Result<Point<E>, D::Error>
246 where
247 D: serde::Deserializer<'de>,
248 {
249 use serde::Deserialize;
250 models::PointCompact::deserialize(deserializer)?
251 .try_into()
252 .map_err(<D::Error as serde::de::Error>::custom)
253 }
254 }
255
256 impl<E: Curve> serde_with::SerializeAs<Scalar<E>> for Compact {
257 fn serialize_as<S>(source: &Scalar<E>, serializer: S) -> Result<S::Ok, S::Error>
258 where
259 S: serde::Serializer,
260 {
261 use serde::Serialize;
262 models::ScalarCompact::from(source).serialize(serializer)
263 }
264 }
265
266 impl<'de, E: Curve> serde_with::DeserializeAs<'de, Scalar<E>> for Compact {
267 fn deserialize_as<D>(deserializer: D) -> Result<Scalar<E>, D::Error>
268 where
269 D: serde::Deserializer<'de>,
270 {
271 use serde::Deserialize;
272 models::ScalarCompact::deserialize(deserializer)?
273 .try_into()
274 .map_err(<D::Error as serde::de::Error>::custom)
275 }
276 }
277
278 impl<E: Curve> serde_with::SerializeAs<SecretScalar<E>> for Compact {
279 fn serialize_as<S>(source: &SecretScalar<E>, serializer: S) -> Result<S::Ok, S::Error>
280 where
281 S: serde::Serializer,
282 {
283 use serde::Serialize;
284 models::ScalarCompact::from(source.as_ref()).serialize(serializer)
285 }
286 }
287
288 impl<'de, E: Curve> serde_with::DeserializeAs<'de, SecretScalar<E>> for Compact {
289 fn deserialize_as<D>(deserializer: D) -> Result<SecretScalar<E>, D::Error>
290 where
291 D: serde::Deserializer<'de>,
292 {
293 let mut scalar =
294 <Compact as serde_with::DeserializeAs<'de, Scalar<E>>>::deserialize_as(
295 deserializer,
296 )?;
297 Ok(SecretScalar::new(&mut scalar))
298 }
299 }
300
301 impl<T> serde_with::SerializeAs<crate::NonZero<T>> for Compact
302 where
303 Compact: serde_with::SerializeAs<T>,
304 {
305 fn serialize_as<S>(source: &crate::NonZero<T>, serializer: S) -> Result<S::Ok, S::Error>
306 where
307 S: serde::Serializer,
308 {
309 Compact::serialize_as(source.as_ref(), serializer)
310 }
311 }
312
313 impl<'de, T> serde_with::DeserializeAs<'de, crate::NonZero<T>> for Compact
314 where
315 Compact: serde_with::DeserializeAs<'de, T>,
316 crate::NonZero<T>: TryFrom<T>,
317 <crate::NonZero<T> as TryFrom<T>>::Error: core::fmt::Display,
318 {
319 fn deserialize_as<D>(deserializer: D) -> Result<crate::NonZero<T>, D::Error>
320 where
321 D: serde::Deserializer<'de>,
322 {
323 let value = Compact::deserialize_as(deserializer)?;
324 crate::NonZero::try_from(value).map_err(<D::Error as serde::de::Error>::custom)
325 }
326 }
327
328 impl<'a, T> serde_with::SerializeAs<&'a T> for Compact
329 where
330 Compact: serde_with::SerializeAs<T>,
331 {
332 fn serialize_as<S>(source: &&'a T, serializer: S) -> Result<S::Ok, S::Error>
333 where
334 S: serde::Serializer,
335 {
336 Compact::serialize_as(*source, serializer)
337 }
338 }
339
340 pub struct PreferCompact;
353
354 impl<T> serde_with::SerializeAs<T> for PreferCompact
355 where
356 Compact: serde_with::SerializeAs<T>,
357 {
358 fn serialize_as<S>(source: &T, serializer: S) -> Result<S::Ok, S::Error>
359 where
360 S: serde::Serializer,
361 {
362 <Compact as serde_with::SerializeAs<T>>::serialize_as(source, serializer)
363 }
364 }
365
366 impl<'de, T> serde_with::DeserializeAs<'de, T> for PreferCompact
367 where
368 T: serde::Deserialize<'de>,
369 Compact: serde_with::DeserializeAs<'de, T>,
370 {
371 fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
372 where
373 D: serde::Deserializer<'de>,
374 {
375 use serde_with::DeserializeAs;
376
377 struct Visitor<T> {
378 is_human_readable: bool,
379 _out: core::marker::PhantomData<T>,
380 }
381 impl<'de, T> serde::de::Visitor<'de> for Visitor<T>
382 where
383 T: serde::Deserialize<'de>,
384 Compact: serde_with::DeserializeAs<'de, T>,
385 {
386 type Value = T;
387 fn expecting(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
388 f.write_str("preferably compact point/scalar")
389 }
390
391 fn visit_bytes<Err>(self, v: &[u8]) -> Result<Self::Value, Err>
392 where
393 Err: serde::de::Error,
394 {
395 Compact::deserialize_as(NewTypeDeserializer::new(OverrideHumanReadable {
396 deserializer: serde::de::value::BytesDeserializer::<Err>::new(v),
397 is_human_readable: self.is_human_readable,
398 }))
399 }
400 fn visit_str<Err>(self, v: &str) -> Result<Self::Value, Err>
401 where
402 Err: serde::de::Error,
403 {
404 Compact::deserialize_as(NewTypeDeserializer::new(OverrideHumanReadable {
405 deserializer: serde::de::value::StrDeserializer::<Err>::new(v),
406 is_human_readable: self.is_human_readable,
407 }))
408 }
409
410 fn visit_seq<A>(self, _seq: A) -> Result<Self::Value, A::Error>
411 where
412 A: serde::de::SeqAccess<'de>,
413 {
414 Err(<A::Error as serde::de::Error>::custom(
415 "cannot deserialize in `PreferCompact` mode \
416 from sequence: it's ambiguous",
417 ))
418 }
419 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
420 where
421 A: serde::de::MapAccess<'de>,
422 {
423 T::deserialize(OverrideHumanReadable {
424 deserializer: serde::de::value::MapAccessDeserializer::new(map),
425 is_human_readable: self.is_human_readable,
426 })
427 }
428
429 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
430 where
431 D: serde::Deserializer<'de>,
432 {
433 Compact::deserialize_as(NewTypeDeserializer::new(OverrideHumanReadable {
434 deserializer,
435 is_human_readable: self.is_human_readable,
436 }))
437 }
438 }
439
440 let is_human_readable = deserializer.is_human_readable();
441 deserializer.deserialize_any(Visitor {
442 is_human_readable,
443 _out: core::marker::PhantomData::<T>,
444 })
445 }
446 }
447
448 struct OverrideHumanReadable<D> {
450 is_human_readable: bool,
451 deserializer: D,
452 }
453 impl<'de, D> serde::Deserializer<'de> for OverrideHumanReadable<D>
454 where
455 D: serde::Deserializer<'de>,
456 {
457 type Error = <D as serde::Deserializer<'de>>::Error;
458
459 fn is_human_readable(&self) -> bool {
460 self.is_human_readable
461 }
462
463 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
464 where
465 V: serde::de::Visitor<'de>,
466 {
467 self.deserializer.deserialize_any(visitor)
468 }
469
470 serde::forward_to_deserialize_any! {
471 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
472 bytes byte_buf option unit unit_struct newtype_struct seq tuple
473 tuple_struct map struct enum identifier ignored_any
474 }
475 }
476
477 struct NewTypeDeserializer<D> {
479 deserializer: D,
480 }
481 impl<D> NewTypeDeserializer<D> {
482 pub fn new(deserializer: D) -> Self {
483 Self { deserializer }
484 }
485 }
486 impl<'de, D> serde::Deserializer<'de> for NewTypeDeserializer<D>
487 where
488 D: serde::Deserializer<'de>,
489 {
490 type Error = D::Error;
491 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
492 where
493 V: serde::de::Visitor<'de>,
494 {
495 visitor.visit_newtype_struct(self.deserializer)
496 }
497 fn is_human_readable(&self) -> bool {
498 self.deserializer.is_human_readable()
499 }
500 serde::forward_to_deserialize_any! {
501 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
502 bytes byte_buf option unit unit_struct newtype_struct seq tuple
503 tuple_struct map struct enum identifier ignored_any
504 }
505 }
506
507 mod models {
508 use core::convert::TryFrom;
509
510 use serde::{Deserialize, Serialize};
511 use serde_with::serde_as;
512
513 use crate::core::{CompressedEncoding, IntegerEncoding, UncompressedEncoding};
514 use crate::{as_raw::AsRaw, Curve, Point, Scalar};
515
516 use super::{
517 error_msg::{InvalidPoint, InvalidScalar},
518 CurveName,
519 };
520
521 #[serde_as]
522 #[derive(Serialize, Deserialize)]
523 #[serde(bound = "")]
524 pub struct PointUncompressed<E: Curve> {
525 curve: CurveName<E>,
526 #[serde_as(as = "super::utils::Bytes")]
527 point: E::UncompressedPointArray,
528 }
529 impl<E: Curve> From<&Point<E>> for PointUncompressed<E> {
530 fn from(p: &Point<E>) -> Self {
531 let bytes = p.as_raw().to_bytes_uncompressed();
532 Self {
533 curve: CurveName::new(),
534 point: bytes,
535 }
536 }
537 }
538 impl<E: Curve> TryFrom<PointUncompressed<E>> for Point<E> {
539 type Error = InvalidPoint;
540 fn try_from(value: PointUncompressed<E>) -> Result<Self, Self::Error> {
541 Point::from_bytes(value.point).or(Err(InvalidPoint))
542 }
543 }
544
545 #[serde_as]
546 #[derive(Serialize, Deserialize)]
547 #[serde(bound = "")]
548 pub struct PointCompact<E: Curve>(
549 #[serde_as(as = "super::utils::Bytes")] E::CompressedPointArray,
550 );
551 impl<E: Curve> From<&Point<E>> for PointCompact<E> {
552 fn from(p: &Point<E>) -> Self {
553 let bytes = p.as_raw().to_bytes_compressed();
554 Self(bytes)
555 }
556 }
557 impl<E: Curve> TryFrom<PointCompact<E>> for Point<E> {
558 type Error = InvalidPoint;
559 fn try_from(value: PointCompact<E>) -> Result<Self, Self::Error> {
560 Point::from_bytes(value.0).or(Err(InvalidPoint))
561 }
562 }
563
564 #[serde_as]
565 #[derive(Serialize, Deserialize)]
566 #[serde(bound = "")]
567 pub struct ScalarUncompressed<E: Curve> {
568 curve: CurveName<E>,
569 #[serde_as(as = "super::utils::Bytes")]
570 scalar: E::ScalarArray,
571 }
572 impl<E: Curve> From<&Scalar<E>> for ScalarUncompressed<E> {
573 fn from(s: &Scalar<E>) -> Self {
574 let bytes = s.as_raw().to_be_bytes();
575 Self {
576 curve: CurveName::new(),
577 scalar: bytes,
578 }
579 }
580 }
581 impl<E: Curve> TryFrom<ScalarUncompressed<E>> for Scalar<E> {
582 type Error = InvalidScalar;
583 fn try_from(value: ScalarUncompressed<E>) -> Result<Self, Self::Error> {
584 Scalar::from_be_bytes(value.scalar).or(Err(InvalidScalar))
585 }
586 }
587
588 #[serde_as]
589 #[derive(Serialize, Deserialize)]
590 #[serde(bound = "")]
591 pub struct ScalarCompact<E: Curve>(#[serde_as(as = "super::utils::Bytes")] E::ScalarArray);
592 impl<E: Curve> From<&Scalar<E>> for ScalarCompact<E> {
593 fn from(s: &Scalar<E>) -> Self {
594 let bytes = s.as_raw().to_be_bytes();
595 Self(bytes)
596 }
597 }
598 impl<E: Curve> TryFrom<ScalarCompact<E>> for Scalar<E> {
599 type Error = InvalidScalar;
600 fn try_from(value: ScalarCompact<E>) -> Result<Self, Self::Error> {
601 Scalar::from_be_bytes(&value.0).or(Err(InvalidScalar))
602 }
603 }
604 }
605
606 mod utils {
607 use core::fmt;
608
609 use serde::de::{self, Visitor};
610 use serde_with::{DeserializeAs, SerializeAs};
611
612 use crate::core::ByteArray;
613
614 pub struct Bytes;
615
616 impl<T> SerializeAs<T> for Bytes
617 where
618 T: AsRef<[u8]>,
619 {
620 fn serialize_as<S>(source: &T, serializer: S) -> Result<S::Ok, S::Error>
621 where
622 S: serde::Serializer,
623 {
624 if serializer.is_human_readable() {
625 let mut buf = [0u8; 256];
628
629 if source.as_ref().len() * 2 > buf.len() {
630 return Err(<S::Error as serde::ser::Error>::custom(
631 super::error_msg::ByteArrayTooLarge {
632 len: source.as_ref().len(),
633 supported_len: buf.len() / 2,
634 },
635 ));
636 }
637 let buf = &mut buf[..2 * source.as_ref().len()];
638 hex::encode_to_slice(source, buf)
639 .map_err(<S::Error as serde::ser::Error>::custom)?;
640 let buf_str = core::str::from_utf8(buf).map_err(|e| {
641 <S::Error as serde::ser::Error>::custom(super::error_msg::MalformedHex(e))
642 })?;
643 serializer.serialize_str(buf_str)
644 } else {
645 serializer.serialize_bytes(source.as_ref())
646 }
647 }
648 }
649
650 impl<'de, T> DeserializeAs<'de, T> for Bytes
651 where
652 T: ByteArray,
653 {
654 fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
655 where
656 D: serde::Deserializer<'de>,
657 {
658 pub struct BytesVisitor<T>(T);
659 impl<'de, T: AsMut<[u8]>> Visitor<'de> for BytesVisitor<T> {
660 type Value = T;
661 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
662 write!(f, "bytes")
663 }
664 fn visit_str<E>(mut self, v: &str) -> Result<Self::Value, E>
665 where
666 E: serde::de::Error,
667 {
668 hex::decode_to_slice(v, self.0.as_mut()).map_err(E::custom)?;
669 Ok(self.0)
670 }
671 fn visit_bytes<E>(mut self, v: &[u8]) -> Result<Self::Value, E>
672 where
673 E: serde::de::Error,
674 {
675 let expected_len = self.0.as_mut().len();
676 if v.len() != expected_len {
677 return Err(E::invalid_length(
678 v.len(),
679 &super::error_msg::ExpectedLen(expected_len),
680 ));
681 }
682 self.0.as_mut().copy_from_slice(v);
683 Ok(self.0)
684 }
685 fn visit_seq<A>(mut self, mut seq: A) -> Result<Self::Value, A::Error>
686 where
687 A: serde::de::SeqAccess<'de>,
688 {
689 let expected_len = self.0.as_mut().len();
690 let bytes = self.0.as_mut().iter_mut().enumerate();
691
692 for (i, byte_i) in bytes {
693 let byte_parsed = seq.next_element()?.ok_or_else(|| {
694 <A::Error as de::Error>::invalid_length(
695 i,
696 &super::error_msg::ExpectedLen(expected_len),
697 )
698 })?;
699 *byte_i = byte_parsed;
700 }
701
702 let mut unparsed_bytes = 0;
703 while seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
704 unparsed_bytes += 1
705 }
706
707 if unparsed_bytes > 0 {
708 Err(<A::Error as de::Error>::invalid_length(
709 expected_len + unparsed_bytes,
710 &super::error_msg::ExpectedLen(expected_len),
711 ))
712 } else {
713 Ok(self.0)
714 }
715 }
716 }
717 let visitor = BytesVisitor(T::zeroes());
718 if deserializer.is_human_readable() {
719 deserializer.deserialize_str(visitor)
720 } else {
721 deserializer.deserialize_bytes(visitor)
722 }
723 }
724 }
725 }
726
727 pub(super) mod error_msg {
728 use core::fmt;
729
730 use serde::de::Expected;
731
/// Error text for a curve-name mismatch.
pub struct ExpectedCurve<'g> {
    /// Name of the curve that was required.
    pub expected: &'static str,
    /// Name that was actually found in the input.
    pub got: &'g str,
}

impl fmt::Display for ExpectedCurve<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ExpectedCurve { expected, got } = self;
        write!(f, "expected {e} curve, got {g}", e = expected, g = got)
    }
}
747
748 pub struct ExpectedLen(pub usize);
749
750 impl Expected for ExpectedLen {
751 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
752 write!(f, "{} bytes", self.0)
753 }
754 }
755
/// Error text used when decoded bytes do not form a valid point.
pub struct InvalidPoint;

impl fmt::Display for InvalidPoint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid point")
    }
}
762
/// Error text used when decoded bytes do not form a valid scalar.
pub struct InvalidScalar;

impl fmt::Display for InvalidScalar {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid scalar")
    }
}
769
/// Error text used when hex output turns out not to be valid UTF-8; wraps the
/// underlying `Utf8Error`.
pub struct MalformedHex(pub core::str::Utf8Error);

impl fmt::Display for MalformedHex {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let MalformedHex(cause) = self;
        write!(f, "malformed hex: {}", cause)
    }
}
776
/// Error text used when a byte array is too long to be hex-encoded.
pub struct ByteArrayTooLarge {
    /// Length of the rejected array, in bytes.
    pub len: usize,
    /// Largest length that can be serialized, in bytes.
    pub supported_len: usize,
}

impl fmt::Display for ByteArrayTooLarge {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ByteArrayTooLarge { len, supported_len } = self;
        write!(f, "byte array is too large: its length is {} bytes, but only up to {} bytes can be serialized", len, supported_len)
    }
}
786 }
787}