use std::{fmt::Debug, io};
use ff::{Field, PrimeField};
use group::{prime::PrimeCurveAffine, Curve, Group};
use midnight_curves::pairing::{Engine, MultiMillerLoop};
use rand_core::RngCore;
use crate::{
poly::commitment::Params,
utils::{
arithmetic::{g_to_lagrange, parallelize},
helpers::ProcessedSerdeObject,
SerdeFormat,
},
};
#[derive(Debug, Clone)]
pub struct ParamsKZG<E: Engine> {
pub(crate) g: Vec<E::G1>,
pub(crate) g_lagrange: Vec<E::G1>,
pub(crate) g2: E::G2,
pub(crate) s_g2: E::G2,
}
impl<E: Engine> Params for ParamsKZG<E> {
fn max_k(&self) -> u32 {
assert_eq!(self.g.len(), self.g_lagrange.len());
self.g.len().ilog2()
}
fn downsize(&mut self, new_k: u32) {
ParamsKZG::<E>::downsize(self, new_k)
}
}
impl<E: Engine + Debug> ParamsKZG<E> {
pub fn downsize(&mut self, new_k: u32) {
if self.max_k() == new_k {
return;
}
let n = 1 << new_k;
assert!(n < self.g_lagrange.len());
self.g.truncate(n);
self.g_lagrange = g_to_lagrange(&self.g, new_k);
}
pub fn unsafe_setup<R: RngCore>(k: u32, rng: R) -> Self {
assert!(k <= E::Fr::S);
let n: u64 = 1 << k;
let g1 = E::G1::generator();
let s = <E::Fr>::random(rng);
let mut g = vec![E::G1::identity(); n as usize];
parallelize(&mut g, |g, start| {
let mut current_g: E::G1 = g1;
current_g *= s.pow_vartime([start as u64]);
for g in g.iter_mut() {
*g = current_g;
current_g *= s;
}
});
let mut g_lagrange = vec![E::G1::identity(); n as usize];
let mut root = E::Fr::ROOT_OF_UNITY;
for _ in k..E::Fr::S {
root = root.square();
}
let n_inv = E::Fr::from(n).invert().expect("inversion should be ok for n = 1<<k");
let multiplier = (s.pow_vartime([n]) - E::Fr::ONE) * n_inv;
parallelize(&mut g_lagrange, |g, start| {
for (idx, g) in g.iter_mut().enumerate() {
let offset = start + idx;
let root_pow = root.pow_vartime([offset as u64]);
let scalar = multiplier * root_pow * (s - root_pow).invert().unwrap();
*g = g1 * scalar;
}
});
let g2 = E::G2::generator();
let s_g2 = g2 * s;
Self {
g,
g_lagrange,
g2,
s_g2,
}
}
pub fn from_parts(
k: u32,
g: Vec<E::G1>,
g_lagrange: Option<Vec<E::G1>>,
g2: E::G2,
s_g2: E::G2,
) -> Self {
Self {
g_lagrange: match g_lagrange {
Some(g_l) => g_l,
None => g_to_lagrange(&g, k),
},
g,
g2,
s_g2,
}
}
pub fn g_lagrange(&self) -> &[E::G1] {
&self.g_lagrange
}
pub fn g2(&self) -> E::G2 {
self.g2
}
pub fn s_g2(&self) -> E::G2 {
self.s_g2
}
pub fn write_custom<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()>
where
E::G1: Curve + ProcessedSerdeObject,
E::G2: Curve + ProcessedSerdeObject,
{
writer.write_all(&self.g.len().ilog2().to_le_bytes())?;
for el in self.g.iter() {
el.write(writer, format)?;
}
for el in self.g_lagrange.iter() {
el.write(writer, format)?;
}
self.g2.write(writer, format)?;
self.s_g2.write(writer, format)?;
Ok(())
}
pub fn read_custom<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self>
where
E::G1: Curve + ProcessedSerdeObject,
E::G2: Curve + ProcessedSerdeObject,
{
let mut k = [0u8; 4];
reader.read_exact(&mut k[..])?;
let k = u32::from_le_bytes(k);
let n = 1 << k;
let (g, g_lagrange) = match format {
SerdeFormat::Processed => {
use group::GroupEncoding;
let load_points_from_file_parallelly =
|reader: &mut R| -> io::Result<Vec<Option<E::G1>>> {
let mut points_compressed =
vec![<<E as Engine>::G1 as GroupEncoding>::Repr::default(); n];
for points_compressed in points_compressed.iter_mut() {
reader.read_exact((*points_compressed).as_mut())?;
}
let mut points = vec![Option::<E::G1>::None; n];
parallelize(&mut points, |points, chunks| {
for (i, point) in points.iter_mut().enumerate() {
*point =
Option::from(E::G1::from_bytes(&points_compressed[chunks + i]));
}
});
Ok(points)
};
let g = load_points_from_file_parallelly(reader)?;
let g: Vec<<E as Engine>::G1> = g
.iter()
.map(|point| point.ok_or_else(|| io::Error::other("invalid point encoding")))
.collect::<Result<_, _>>()?;
let g_lagrange = load_points_from_file_parallelly(reader)?;
let g_lagrange: Vec<<E as Engine>::G1> = g_lagrange
.iter()
.map(|point| point.ok_or_else(|| io::Error::other("invalid point encoding")))
.collect::<Result<_, _>>()?;
(g, g_lagrange)
}
SerdeFormat::RawBytes => {
let g = (0..n)
.map(|_| <E::G1 as ProcessedSerdeObject>::read(reader, format))
.collect::<Result<Vec<_>, _>>()?;
let g_lagrange = (0..n)
.map(|_| <E::G1 as ProcessedSerdeObject>::read(reader, format))
.collect::<Result<Vec<_>, _>>()?;
(g, g_lagrange)
}
SerdeFormat::RawBytesUnchecked => {
let g = (0..n)
.map(|_| <E::G1 as ProcessedSerdeObject>::read(reader, format).unwrap())
.collect::<Vec<_>>();
let g_lagrange = (0..n)
.map(|_| <E::G1 as ProcessedSerdeObject>::read(reader, format).unwrap())
.collect::<Vec<_>>();
(g, g_lagrange)
}
};
let g2 = E::G2::read(reader, format)?;
let s_g2 = E::G2::read(reader, format)?;
Ok(Self {
g,
g_lagrange,
g2,
s_g2,
})
}
}
#[derive(Clone, Debug)]
pub struct ParamsVerifierKZG<E: MultiMillerLoop> {
pub(crate) s_g2: E::G2,
pub(crate) n_g2_prepared: E::G2Prepared,
pub(crate) s_g2_prepared: E::G2Prepared,
}
impl<E: MultiMillerLoop + Debug> ParamsVerifierKZG<E>
where
E::G2: Curve + ProcessedSerdeObject,
{
pub fn s_g2(&self) -> E::G2 {
self.s_g2
}
pub fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
self.s_g2.write(writer, format)?;
Ok(())
}
pub fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
let s_g2 = E::G2::read(reader, format)?;
let s_g2_prepared = E::G2Prepared::from(s_g2.into());
let n_g2_prepared = E::G2Prepared::from(-E::G2Affine::generator());
Ok(Self {
s_g2,
n_g2_prepared,
s_g2_prepared,
})
}
}
impl<E: MultiMillerLoop + Debug> ParamsKZG<E> {
pub fn verifier_params(&self) -> ParamsVerifierKZG<E> {
let n_g2_prepared = E::G2Prepared::from((-self.g2).into());
let s_g2_prepared = E::G2Prepared::from(self.s_g2.into());
ParamsVerifierKZG {
s_g2: self.s_g2,
n_g2_prepared,
s_g2_prepared,
}
}
}
#[cfg(test)]
mod test {
use rand_core::OsRng;
use crate::{
poly::{
commitment::PolynomialCommitmentScheme,
kzg::{params::ParamsKZG, KZGCommitmentScheme},
},
utils::SerdeFormat,
};
#[test]
fn test_commit_lagrange() {
const K: u32 = 6;
use midnight_curves::{Bls12, Fq};
use crate::poly::EvaluationDomain;
let params: ParamsKZG<Bls12> = ParamsKZG::unsafe_setup(K, OsRng);
let domain = EvaluationDomain::new(1, K);
let mut a = domain.empty_lagrange();
for (i, a) in a.iter_mut().enumerate() {
*a = Fq::from(i as u64);
}
let b = domain.lagrange_to_coeff(a.clone());
let tmp = KZGCommitmentScheme::commit_lagrange(¶ms, &a);
let commitment = KZGCommitmentScheme::commit(¶ms, &b);
assert_eq!(commitment, tmp);
}
#[test]
fn test_parameter_serialisation_roundtrip() {
const K: u32 = 4;
use midnight_curves::Bls12;
let params0: ParamsKZG<Bls12> = ParamsKZG::unsafe_setup(K, OsRng);
let mut data = vec![];
ParamsKZG::write_custom(¶ms0, &mut data, SerdeFormat::RawBytesUnchecked).unwrap();
let params1 =
ParamsKZG::<Bls12>::read_custom::<_>(&mut &data[..], SerdeFormat::RawBytesUnchecked)
.unwrap();
assert_eq!(params0.g.len(), params1.g.len());
assert_eq!(params0.g_lagrange.len(), params1.g_lagrange.len());
assert_eq!(params0.g, params1.g);
assert_eq!(params0.g_lagrange, params1.g_lagrange);
assert_eq!(params0.g2, params1.g2);
assert_eq!(params0.s_g2, params1.s_g2);
}
}