use core::{
fmt::{self, Formatter},
marker::PhantomData,
};
use ff::PrimeField;
use group::{Group, GroupEncoding};
use serde::{
self,
de::{Error as DError, Visitor},
Deserializer, Serializer,
};
/// Serializes a group element through its canonical byte encoding.
///
/// The element is converted with [`GroupEncoding::to_bytes`] and the result is
/// handed to `serialize_`, which writes a hex string for human-readable
/// serializers and raw bytes otherwise.
///
/// # Errors
///
/// Propagates any error reported by the serializer `s`.
pub fn serialize_point<G, S>(g: &G, s: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
    G: Group + GroupEncoding,
{
    let encoded = g.to_bytes();
    serialize_(encoded, s)
}
/// Deserializes a group element from its canonical byte encoding.
///
/// The encoded representation is read with `deserialize_` (hex string for
/// human-readable formats, raw bytes otherwise) and then decoded with
/// [`GroupEncoding::from_bytes`].
///
/// # Errors
///
/// Returns the deserializer's error if the representation cannot be read, or
/// a custom `"invalid group element"` error if `G::from_bytes` does not yield
/// a value for the bytes.
pub fn deserialize_point<'de, G, D>(d: D) -> Result<G, D::Error>
where
    D: Deserializer<'de>,
    G: Group + GroupEncoding,
{
    let bytes = deserialize_(d)?;
    // Build the error lazily so the success path never constructs it.
    Option::from(G::from_bytes(&bytes)).ok_or_else(|| DError::custom("invalid group element"))
}
/// Serializes a prime-field element through its byte representation.
///
/// The element is converted with [`PrimeField::to_repr`] and the result is
/// handed to `serialize_`, which writes a hex string for human-readable
/// serializers and raw bytes otherwise.
///
/// # Errors
///
/// Propagates any error reported by the serializer `s`.
pub fn serialize_scalar<F, S>(f: &F, s: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
    F: PrimeField,
{
    let repr = f.to_repr();
    serialize_(repr, s)
}
/// Deserializes a prime-field element from its byte representation.
///
/// The representation is read with `deserialize_` (hex string for
/// human-readable formats, raw bytes otherwise) and then decoded with
/// [`PrimeField::from_repr`].
///
/// # Errors
///
/// Returns the deserializer's error if the representation cannot be read, or
/// a custom `"invalid prime field element"` error if `F::from_repr` does not
/// yield a value for the representation.
pub fn deserialize_scalar<'de, F, D>(d: D) -> Result<F, D::Error>
where
    D: Deserializer<'de>,
    F: PrimeField,
{
    let repr = deserialize_(d)?;
    // Build the error lazily so the success path never constructs it.
    Option::from(F::from_repr(repr)).ok_or_else(|| DError::custom("invalid prime field element"))
}
/// Writes a byte representation to the serializer `s`.
///
/// Human-readable serializers receive the bytes as a hex string; all other
/// serializers receive the raw bytes via `serialize_bytes`.
fn serialize_<B, S>(bytes: B, s: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
    B: AsRef<[u8]> + AsMut<[u8]> + Default,
{
    let raw = bytes.as_ref();
    if !s.is_human_readable() {
        return s.serialize_bytes(raw);
    }
    let text = hex::encode(raw);
    s.serialize_str(&text)
}
fn deserialize_<'de, B: AsRef<[u8]> + AsMut<[u8]> + Default, D: Deserializer<'de>>(
d: D,
) -> Result<B, D::Error> {
if d.is_human_readable() {
struct StrVisitor<B: AsRef<[u8]> + AsMut<[u8]> + Default>(PhantomData<B>);
impl<'de, B> Visitor<'de> for StrVisitor<B>
where
B: AsRef<[u8]> + AsMut<[u8]> + Default,
{
type Value = B;
fn expecting(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "a {} length hex string", B::default().as_ref().len() * 2)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: DError,
{
let mut repr = B::default();
let length = repr.as_ref().len();
if v.len() != length * 2 {
return Err(DError::custom("invalid length"));
}
hex::decode_to_slice(v, repr.as_mut())
.map_err(|_| DError::custom("invalid input"))?;
Ok(repr)
}
}
d.deserialize_str(StrVisitor(PhantomData))
} else {
struct ByteVisitor<B: AsRef<[u8]> + AsMut<[u8]> + Default>(PhantomData<B>);
impl<'de, B> Visitor<'de> for ByteVisitor<B>
where
B: AsRef<[u8]> + AsMut<[u8]> + Default,
{
type Value = B;
fn expecting(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "a {} byte", B::default().as_ref().len())
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let mut repr = B::default();
if v.len() != repr.as_ref().len() {
return Err(serde::de::Error::custom("invalid length"));
}
repr.as_mut().copy_from_slice(v);
Ok(repr)
}
}
d.deserialize_bytes(ByteVisitor(PhantomData))
}
}