use super::{Curve, CurveDecodingError, Field, GenericMultiExp, PrimeField};
use crate::common::{Deserial, Serial, Serialize};
use ark_ec::hashing::{HashToCurve, HashToCurveError};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use core::fmt;
/// Newtype wrapper around a field element of type `F`.
///
/// The wrapped value is only accessible inside the crate. The derived `From`
/// gives the `F -> ArkField<F>` conversion used throughout this file.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, fmt::Debug)]
#[derive(derive_more::From, derive_more::FromStr)]
pub struct ArkField<F>(pub(crate) F);
/// Serializing the wrapper delegates directly to the wrapped value.
impl<F> Serial for ArkField<F>
where
    F: Serial,
{
    /// Writes the wrapped value to `out` using `F`'s own `Serial` implementation.
    fn serial<B: crate::common::Buffer>(&self, out: &mut B) {
        F::serial(&self.0, out)
    }
}
impl<F: Deserial> Deserial for ArkField<F> {
fn deserial<R: byteorder::ReadBytesExt>(source: &mut R) -> crate::common::ParseResult<Self> {
let res = F::deserial(source)?;
Ok(res.into())
}
}
impl<F: ark_ff::Field> Field for ArkField<F> {
fn random<R: rand::prelude::RngCore + ?std::marker::Sized>(rng: &mut R) -> Self {
F::rand(rng).into()
}
fn zero() -> Self {
F::zero().into()
}
fn one() -> Self {
F::one().into()
}
fn is_zero(&self) -> bool {
F::is_zero(&self.0)
}
fn square(&mut self) {
self.0.square_in_place();
}
fn double(&mut self) {
self.0.double_in_place();
}
fn negate(&mut self) {
self.0.neg_in_place();
}
fn add_assign(&mut self, other: &Self) {
self.0 += other.0
}
fn sub_assign(&mut self, other: &Self) {
self.0 -= other.0
}
fn mul_assign(&mut self, other: &Self) {
self.0 *= other.0
}
fn inverse(&self) -> Option<Self> {
self.0.inverse().map(|x| x.into())
}
}
impl<F: ark_ff::Field> ArkField<F> {
pub fn into_ark(&self) -> &F {
&self.0
}
}
/// Implements the crate-local `PrimeField` trait on top of `ark_ff::PrimeField`.
impl<F: ark_ff::PrimeField> PrimeField for ArkField<F> {
    /// One bit less than `NUM_BITS`.
    const CAPACITY: u32 = Self::NUM_BITS - 1;
    /// Taken from `F::MODULUS_BIT_SIZE`.
    const NUM_BITS: u32 = F::MODULUS_BIT_SIZE;

    /// Returns the `u64` limbs of the element's big-integer representation.
    fn into_repr(self) -> Vec<u64> {
        self.0.into_bigint().as_ref().to_vec()
    }

    /// Builds a field element from `u64` limbs.
    ///
    /// Each limb is encoded as little-endian bytes and the limbs are
    /// concatenated in slice order, so `repr[0]` is the least significant limb.
    ///
    /// # Errors
    ///
    /// Returns `CurveDecodingError::NotInField` (carrying the debug rendering
    /// of `repr`) when the integer does not convert into `F`'s big-integer
    /// type, or when `F::from_bigint` rejects it.
    fn from_repr(repr: &[u64]) -> Result<Self, super::CurveDecodingError> {
        // Build the error lazily: formatting `repr` allocates, and the success
        // path should not pay for an error it never returns.
        let not_in_field = || CurveDecodingError::NotInField(format!("{:?}", repr));

        let mut buffer = Vec::with_capacity(8 * repr.len());
        for u in repr {
            buffer.extend(u.to_le_bytes());
        }
        let big_int = num_bigint::BigUint::from_bytes_le(&buffer)
            .try_into()
            .map_err(|_| not_in_field())?;
        let res = F::from_bigint(big_int).ok_or_else(not_in_field)?;
        Ok(res.into())
    }
}
/// Newtype wrapper around a group element of type `G`.
///
/// The wrapped value is only accessible inside the crate. The derived `From`
/// gives the `G -> ArkGroup<G>` conversion.
#[derive(Clone, Copy, PartialEq, Eq, fmt::Debug, derive_more::From)]
pub struct ArkGroup<G>(pub(crate) G);
impl<G: ark_ec::CurveGroup> ArkGroup<G> {
pub fn into_ark(&self) -> &G {
&self.0
}
}
/// Group elements are serialized through their affine form in compressed encoding.
impl<G> Serial for ArkGroup<G>
where
    G: ark_ec::CurveGroup,
{
    /// Converts the element to affine form and writes its compressed
    /// serialization to `out`.
    ///
    /// # Panics
    ///
    /// Panics if `serialize_compressed` reports an error.
    fn serial<B: crate::common::Buffer>(&self, out: &mut B) {
        let affine = self.0.into_affine();
        affine
            .serialize_compressed(out)
            .expect("Serialization expected to succeed");
    }
}
/// Group elements are read as compressed affine points and converted back to `G`.
impl<G> Deserial for ArkGroup<G>
where
    G: ark_ec::CurveGroup,
{
    /// Reads a compressed `G::Affine` from `source` and wraps its conversion
    /// into `G`. Deserialization errors are propagated through `?`.
    fn deserial<R: byteorder::ReadBytesExt>(source: &mut R) -> crate::common::ParseResult<Self> {
        Ok(ArkGroup(G::Affine::deserialize_compressed(source)?.into()))
    }
}
impl From<HashToCurveError> for CurveDecodingError {
fn from(_value: HashToCurveError) -> Self {
CurveDecodingError::NotOnCurve
}
}
pub(crate) trait ArkCurveConfig<G: ark_ec::CurveGroup> {
const SCALAR_LENGTH: usize;
const GROUP_ELEMENT_LENGTH: usize;
const DOMAIN_STRING: &'static str;
type Hasher: ark_ec::hashing::HashToCurve<G>;
}
impl<G: ark_ec::CurveGroup + ArkCurveConfig<G>> Curve for ArkGroup<G>
where
<G as ark_ec::Group>::ScalarField: Serialize,
{
type MultiExpType = GenericMultiExp<Self>;
type Scalar = ArkField<<G as ark_ec::Group>::ScalarField>;
const GROUP_ELEMENT_LENGTH: usize = G::GROUP_ELEMENT_LENGTH;
const SCALAR_LENGTH: usize = G::SCALAR_LENGTH;
fn zero_point() -> Self {
ArkGroup(G::zero())
}
fn one_point() -> Self {
ArkGroup(G::generator())
}
fn is_zero_point(&self) -> bool {
self.0.is_zero()
}
fn inverse_point(&self) -> Self {
ArkGroup(-self.0)
}
fn double_point(&self) -> Self {
ArkGroup(self.0.double())
}
fn plus_point(&self, other: &Self) -> Self {
ArkGroup(self.0 + other.0)
}
fn minus_point(&self, other: &Self) -> Self {
ArkGroup(self.0 - other.0)
}
fn mul_by_scalar(&self, scalar: &Self::Scalar) -> Self {
ArkGroup(self.0 * scalar.0)
}
fn generate<R: rand::prelude::Rng>(rng: &mut R) -> Self {
ArkGroup(G::rand(rng))
}
fn generate_scalar<R: rand::prelude::Rng>(rng: &mut R) -> Self::Scalar {
<G::ScalarField as ark_ff::UniformRand>::rand(rng).into()
}
fn scalar_from_u64(n: u64) -> Self::Scalar {
ArkField(G::ScalarField::from(n))
}
fn scalar_from_bytes<A: AsRef<[u8]>>(bs: A) -> Self::Scalar {
let num_chunks = num::integer::div_ceil(Self::Scalar::CAPACITY, 64);
let mut fr = vec![0u64; num_chunks as usize];
for (chunk, place) in bs.as_ref().chunks(8).take(num_chunks as usize).zip(&mut fr) {
let mut v = [0u8; 8];
v[..chunk.len()].copy_from_slice(chunk);
*place = u64::from_le_bytes(v);
}
let total_size_in_bits = num_chunks * 64;
let num_bits_to_remove = total_size_in_bits - Self::Scalar::CAPACITY;
let mask = u64::MAX >> num_bits_to_remove;
*fr.last_mut().expect("Non empty vector expected") &= mask;
<Self::Scalar>::from_repr(&fr).unwrap_or_else(|_| {
panic!(
"The scalar {:?} with top {:} bits erased should be valid.",
fr, num_bits_to_remove
)
})
}
fn hash_to_group(m: &[u8]) -> Result<Self, CurveDecodingError> {
let hasher = G::Hasher::new(G::DOMAIN_STRING.as_ref())?;
let res = G::Hasher::hash(&hasher, m)?;
Ok(ArkGroup(res.into()))
}
}