use super::{field_adapters::FFField, Curve, CurveDecodingError, Field, MultiExp, PrimeField};
use crate::common::{Buffer, Deserial, Serial};
use byteorder::{ByteOrder, LittleEndian};
use curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_POINT,
ristretto::{CompressedRistretto, RistrettoPoint, VartimeRistrettoPrecomputation},
scalar::Scalar,
traits::{Identity, VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul},
};
use sha2::Sha512;
use std::{borrow::Borrow, result::Result};
impl Serial for Scalar {
    /// Serializes the scalar as the 32 bytes returned by `Scalar::as_bytes`.
    ///
    /// # Panics
    /// Panics if the output buffer reports a write error.
    fn serial<B: Buffer>(&self, out: &mut B) {
        out.write_all(&self.as_bytes()[..])
            .expect("Writing to a buffer should not fail.");
    }
}
impl Deserial for Scalar {
    /// Reads exactly 32 bytes and decodes them as a canonical scalar.
    ///
    /// # Errors
    /// Fails if fewer than 32 bytes can be read from `source`, or if the bytes
    /// are rejected by `Scalar::from_canonical_bytes` (i.e. they are not a
    /// canonical encoding of a field value).
    fn deserial<R: byteorder::ReadBytesExt>(source: &mut R) -> crate::common::ParseResult<Self> {
        let mut buf = [0u8; 32];
        source.read_exact(&mut buf)?;
        let decoded: Option<Scalar> = Scalar::from_canonical_bytes(buf).into();
        // `ok_or_else` so the error is only constructed on the failure path.
        decoded.ok_or_else(|| anyhow::anyhow!("Deserialization failed! Not a field value!"))
    }
}
impl PrimeField for FFField<Scalar> {
    const CAPACITY: u32 = <Scalar as ff::PrimeField>::CAPACITY;
    const NUM_BITS: u32 = <Scalar as ff::PrimeField>::NUM_BITS;

    /// Returns the scalar's `ff` byte representation split into 64-bit limbs.
    ///
    /// Each 8-byte chunk is read as a little-endian `u64`, and the limbs are
    /// returned in the same order as the chunks (least significant first).
    fn into_repr(self) -> Vec<u64> {
        let bytes = <Scalar as ff::PrimeField>::to_repr(&self.0);
        bytes
            .chunks_exact(8)
            .map(LittleEndian::read_u64)
            .collect()
    }

    /// Builds a scalar from exactly four little-endian 64-bit limbs
    /// (least significant first). This is the inverse of `into_repr`.
    ///
    /// # Errors
    /// Returns `CurveDecodingError::NotInField` if `r` does not have exactly
    /// four limbs, or if the resulting 32 bytes are not a canonical scalar.
    fn from_repr(r: &[u64]) -> Result<Self, super::CurveDecodingError> {
        let limbs: [u64; 4] = r
            .try_into()
            .map_err(|_| super::CurveDecodingError::NotInField(format!("{:?}", r)))?;
        let mut s_bytes = [0u8; 32];
        for (chunk, limb) in s_bytes.chunks_exact_mut(8).zip(limbs.iter()) {
            LittleEndian::write_u64(chunk, *limb);
        }
        let decoded: Option<Scalar> = Scalar::from_canonical_bytes(s_bytes).into();
        // Lazily format the error so successful decodes do not allocate.
        let scalar = decoded.ok_or_else(|| {
            super::CurveDecodingError::NotInField(format!("{:?}", s_bytes))
        })?;
        Ok(scalar.into())
    }
}
impl Serial for RistrettoPoint {
    /// Serializes the point as its 32-byte compressed Ristretto encoding.
    ///
    /// # Panics
    /// Panics if the output buffer reports a write error.
    fn serial<B: Buffer>(&self, out: &mut B) {
        out.write_all(self.compress().as_bytes())
            .expect("Writing to a buffer should not fail.");
    }
}
impl Deserial for RistrettoPoint {
fn deserial<R: byteorder::ReadBytesExt>(source: &mut R) -> crate::common::ParseResult<Self> {
let mut buf: [u8; 32] = [0; 32];
source.read_exact(&mut buf)?;
let res = CompressedRistretto::from_slice(&buf)?;
let point = res.decompress().ok_or(anyhow::anyhow!("Failed!"))?;
Ok(point)
}
}
impl Curve for RistrettoPoint {
    type MultiExpType = RistrettoMultiExpNoPrecompute;
    type Scalar = FFField<Scalar>;
    const GROUP_ELEMENT_LENGTH: usize = 32;
    const SCALAR_LENGTH: usize = 32;

    /// The group identity.
    fn zero_point() -> Self {
        Self::identity()
    }

    /// The Ristretto basepoint `RISTRETTO_BASEPOINT_POINT`.
    fn one_point() -> Self {
        RISTRETTO_BASEPOINT_POINT
    }

    fn is_zero_point(&self) -> bool {
        self == &Self::zero_point()
    }

    fn inverse_point(&self) -> Self {
        -self
    }

    fn double_point(&self) -> Self {
        self + self
    }

    fn plus_point(&self, other: &Self) -> Self {
        self + other
    }

    fn minus_point(&self, other: &Self) -> Self {
        self - other
    }

    fn mul_by_scalar(&self, scalar: &Self::Scalar) -> Self {
        self * scalar.0
    }

    /// Samples a point by feeding 64 random bytes from `rng` into
    /// `RistrettoPoint::from_uniform_bytes`.
    fn generate<R: rand::Rng>(rng: &mut R) -> Self {
        let mut uniform_bytes = [0u8; 64];
        rng.fill_bytes(&mut uniform_bytes);
        RistrettoPoint::from_uniform_bytes(&uniform_bytes)
    }

    fn generate_scalar<R: rand::Rng>(rng: &mut R) -> Self::Scalar {
        Self::Scalar::random(rng)
    }

    fn scalar_from_u64(n: u64) -> Self::Scalar {
        Scalar::from(n).into()
    }

    /// Interprets up to the first 32 bytes of `bs` as a little-endian integer
    /// and keeps only its low 252 bits.
    ///
    /// Missing bytes are treated as zero, and bytes beyond the 32nd are ignored.
    /// The top four bits of the most significant limb are cleared. The result is
    /// therefore below 2^252, which is less than the Ristretto group order, so it
    /// is always a canonical scalar.
    fn scalar_from_bytes<A: AsRef<[u8]>>(bs: A) -> Self::Scalar {
        // Clears bits 252..=255 of the 256-bit value.
        const TOP_LIMB_MASK: u64 = u64::MAX >> 4;
        let mut fr = [0u64; 4];
        for (limb, chunk) in fr.iter_mut().zip(bs.as_ref().chunks(8)) {
            let mut v = [0u8; 8];
            v[..chunk.len()].copy_from_slice(chunk);
            *limb = u64::from_le_bytes(v);
        }
        fr[3] &= TOP_LIMB_MASK;
        <Self::Scalar as PrimeField>::from_repr(&fr)
            .expect("The scalar with top four bits erased should be valid.")
    }

    /// Hashes `m` to a group element with `RistrettoPoint::hash_from_bytes`
    /// using SHA-512. This never returns an error.
    fn hash_to_group(m: &[u8]) -> Result<Self, CurveDecodingError> {
        Ok(RistrettoPoint::hash_from_bytes::<Sha512>(m))
    }
}
impl MultiExp for VartimeRistrettoPrecomputation {
    type CurvePoint = RistrettoPoint;

    /// Builds curve25519-dalek's variable-time precomputation for the base
    /// points `gs`.
    fn new<X: Borrow<Self::CurvePoint>>(gs: &[X]) -> Self {
        let bases = gs.iter().map(|base| base.borrow());
        <Self as VartimePrecomputedMultiscalarMul>::new(bases)
    }

    /// Computes the linear combination of the precomputed base points with the
    /// scalars in `exps`, using variable-time arithmetic.
    fn multiexp<X: Borrow<<Self::CurvePoint as Curve>::Scalar>>(
        &self,
        exps: &[X],
    ) -> Self::CurvePoint {
        let scalars = exps.iter().map(|exp| exp.borrow().0);
        self.vartime_multiscalar_mul(scalars)
    }
}
/// Multi-exponentiation over Ristretto points without precomputation.
///
/// The base points are stored as given, and each `multiexp` call passes them
/// to curve25519-dalek's variable-time multiscalar multiplication.
#[derive(Debug, Clone)]
pub struct RistrettoMultiExpNoPrecompute {
    points: Vec<RistrettoPoint>,
}
impl MultiExp for RistrettoMultiExpNoPrecompute {
    type CurvePoint = RistrettoPoint;

    /// Stores copies of the base points `gs`. No precomputation is performed.
    fn new<X: Borrow<Self::CurvePoint>>(gs: &[X]) -> Self {
        Self {
            points: gs.iter().map(|x| *x.borrow()).collect(),
        }
    }

    /// Computes `sum_i exps[i] * points[i]` in variable time.
    ///
    /// Only the first `exps.len()` stored points are used, so `exps` may be
    /// shorter than the list passed to `new`. The unused points are ignored,
    /// which matches how dalek's precomputed multiscalar multiplication treats
    /// a shorter scalar list.
    ///
    /// # Panics
    /// Panics in curve25519-dalek's length check if `exps` is longer than the
    /// stored point list.
    fn multiexp<X: Borrow<<Self::CurvePoint as Curve>::Scalar>>(
        &self,
        exps: &[X],
    ) -> Self::CurvePoint {
        let scalars = exps.iter().map(|p| p.borrow().0);
        // dalek asserts that the scalar and point iterators have equal
        // lengths, so pass only the prefix of points that has matching scalars.
        let points = self.points.iter().take(exps.len());
        Self::CurvePoint::vartime_multiscalar_mul(scalars, points)
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::{
        common::*,
        curve_arithmetic::{field_adapters::FFField, Field},
    };
    use curve25519_dalek::{ristretto::RistrettoPoint, Scalar};
    use rand::{Rng, RngCore};
    use std::io::Cursor;

    type RistrettoScalar = FFField<Scalar>;

    /// Clears the top four bits of the most significant 64-bit limb. These are
    /// the same bits that `scalar_from_bytes` discards.
    const TOP_LIMB_MASK: u64 = u64::MAX >> 4;

    /// Serializing and then deserializing a random scalar yields the same scalar.
    #[test]
    fn test_scalar_serialization() {
        let mut csprng = rand::thread_rng();
        for _ in 0..1000 {
            let mut out = Vec::<u8>::new();
            let scalar = RistrettoScalar::random(&mut csprng);
            scalar.serial(&mut out);
            let scalar_res = RistrettoScalar::deserial(&mut Cursor::new(out));
            assert!(scalar_res.is_ok());
            assert_eq!(scalar, scalar_res.unwrap());
        }
    }

    /// Serializing and then deserializing a random point yields the same point.
    #[test]
    fn test_point_serialization() {
        let mut csprng = rand::thread_rng();
        for _ in 0..1000 {
            let mut out = Vec::<u8>::new();
            let point = RistrettoPoint::generate(&mut csprng);
            point.serial(&mut out);
            let point_res = RistrettoPoint::deserial(&mut Cursor::new(out));
            assert!(point_res.is_ok());
            assert_eq!(point, point_res.unwrap());
        }
    }

    /// A 2^256 - 1 encoding exceeds the group order and must be rejected.
    #[test]
    fn test_scalar_deserialization_rejects_non_canonical() {
        let bytes = [0xFFu8; 32];
        assert!(Scalar::deserial(&mut Cursor::new(&bytes[..])).is_err());
    }

    /// All-0xFF bytes are not a canonical field-element encoding, so they
    /// cannot decompress to a Ristretto point.
    #[test]
    fn test_point_deserialization_rejects_invalid_encoding() {
        let bytes = [0xFFu8; 32];
        assert!(RistrettoPoint::deserial(&mut Cursor::new(&bytes[..])).is_err());
    }

    /// Truncated inputs are rejected rather than zero-padded.
    #[test]
    fn test_deserialization_rejects_short_input() {
        let short = [0u8; 31];
        assert!(Scalar::deserial(&mut Cursor::new(&short[..])).is_err());
        assert!(RistrettoPoint::deserial(&mut Cursor::new(&short[..])).is_err());
    }

    /// `from_repr` inverts `into_repr` for random scalars.
    #[test]
    fn test_into_from_rep() {
        let mut csprng = rand::thread_rng();
        for _ in 0..1000 {
            let scalar = RistrettoScalar::random(&mut csprng);
            let scalar_vec64 = scalar.into_repr();
            let scalar_res = RistrettoScalar::from_repr(&scalar_vec64);
            assert!(scalar_res.is_ok());
            assert_eq!(scalar, scalar_res.unwrap());
        }
    }

    /// `from_repr` requires exactly four limbs.
    #[test]
    fn test_from_repr_rejects_wrong_length() {
        assert!(RistrettoScalar::from_repr(&[1u64, 2, 3]).is_err());
        assert!(RistrettoScalar::from_repr(&[0u64; 5]).is_err());
    }

    /// Limbs are little-endian and least significant first.
    #[test]
    fn test_into() {
        let res: Option<Scalar> = Scalar::from_canonical_bytes([
            1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 254, 255, 255, 255, 255, 255, 255, 255,
            0, 0, 0, 0, 0, 0, 0, 0,
        ])
        .into();
        let s: RistrettoScalar = res.expect("Expected a valid scalar").into();
        assert_eq!(s.into_repr(), [1u64, 0u64, u64::MAX - 1, 0u64]);
    }

    /// `scalar_from_bytes` agrees with the input on all bits except the top
    /// four bits of the last limb.
    #[test]
    fn test_scalar_from_bytes_small() {
        let mut rng = rand::thread_rng();
        for _ in 0..1000 {
            let n = <RistrettoScalar as Field>::random(&mut rng);
            let bytes = to_bytes(&n);
            let m = <RistrettoPoint as Curve>::scalar_from_bytes(&bytes);
            let n = n.into_repr();
            let m = m.into_repr();
            assert_eq!(n[0], m[0], "First limb.");
            assert_eq!(n[1], m[1], "Second limb.");
            assert_eq!(n[2], m[2], "Third limb.");
            assert_eq!(
                n[3] & TOP_LIMB_MASK,
                m[3],
                "Fourth limb with top four bits masked."
            );
        }
    }

    /// Inputs shorter than 32 bytes are treated as zero-extended.
    #[test]
    fn test_scalar_from_bytes_short_input() {
        let one = <RistrettoPoint as Curve>::scalar_from_bytes([1u8]);
        assert_eq!(one, RistrettoPoint::scalar_from_u64(1));
        let two_bytes = <RistrettoPoint as Curve>::scalar_from_bytes([5u8, 1u8]);
        assert_eq!(two_bytes, RistrettoPoint::scalar_from_u64(261));
    }

    /// Setting any of the top four bits does not change the resulting scalar.
    #[test]
    fn test_scalar_from_bytes_big() {
        let mut rng = rand::thread_rng();
        for _ in 0..1000 {
            let mut lower_bytes: [u8; 31] = [0u8; 31];
            rng.fill_bytes(&mut lower_bytes);
            let mut fits_capacity_bytes = [0u8; 32];
            fits_capacity_bytes[0..31].copy_from_slice(&lower_bytes);
            let n = rng.gen_range(0..16);
            fits_capacity_bytes[31] = n;
            let fits_capacity = <RistrettoPoint as Curve>::scalar_from_bytes(fits_capacity_bytes);
            let i = rng.gen_range(1..16);
            let mut bytes: [u8; 32] = [0u8; 32];
            bytes[0..31].copy_from_slice(&lower_bytes);
            bytes[31] = n + (i << 4);
            let over_capacity = <RistrettoPoint as Curve>::scalar_from_bytes(bytes);
            assert_eq!(fits_capacity, over_capacity);
        }
    }

    /// Both `MultiExp` implementations agree with a naive sum of scalar
    /// multiplications when the scalar and point counts are equal.
    #[test]
    fn test_multiexp_matches_naive_sum() {
        let mut csprng = rand::thread_rng();
        let points: Vec<RistrettoPoint> = (0..8)
            .map(|_| RistrettoPoint::generate(&mut csprng))
            .collect();
        let scalars: Vec<RistrettoScalar> = (0..8)
            .map(|_| RistrettoPoint::generate_scalar(&mut csprng))
            .collect();
        let expected = points
            .iter()
            .zip(&scalars)
            .fold(RistrettoPoint::zero_point(), |acc, (p, s)| {
                acc.plus_point(&p.mul_by_scalar(s))
            });
        let no_precompute =
            RistrettoMultiExpNoPrecompute::new(&points[..]).multiexp(&scalars[..]);
        assert_eq!(no_precompute, expected);
        let precomputed =
            <VartimeRistrettoPrecomputation as MultiExp>::new(&points[..]).multiexp(&scalars[..]);
        assert_eq!(precomputed, expected);
    }
}