use std::{any::TypeId, fmt::Debug};
use ff::Field;
use group::{Curve, Group};
use itertools::izip;
use midnight_curves::{
msm::msm_best,
pairing::{Engine, MillerLoopResult, MultiMillerLoop},
CurveAffine, Fq, G1Projective,
};
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
use super::params::ParamsVerifierKZG;
use crate::{
poly::{
commitment::{Guard, PolynomialCommitmentScheme},
kzg::KZGCommitmentScheme,
CommitmentLabel, Error,
},
utils::{
arithmetic::{CurveExt, MSM},
helpers::ProcessedSerdeObject,
},
};
/// A multi-scalar multiplication over the `G1` group of a pairing engine.
///
/// Stored as three parallel vectors: the `i`-th term is
/// `scalars[i] * bases[i]`, tagged with `labels[i]`. The constructors and
/// trait methods for this type push/extend all three vectors together.
#[derive(Debug, Clone, Default)]
pub struct MSMKZG<E: Engine> {
    /// Scalar coefficient of each term.
    pub(crate) scalars: Vec<E::Fr>,
    /// Base point (projective `G1`) of each term.
    pub(crate) bases: Vec<E::G1>,
    /// Label attached to each term (e.g. `NoLabel`, `Advice(_)`).
    pub(crate) labels: Vec<CommitmentLabel>,
}
impl<E: Engine> MSMKZG<E> {
    /// Creates an MSM with no terms.
    pub fn init() -> Self {
        Self {
            scalars: Vec::new(),
            bases: Vec::new(),
            labels: Vec::new(),
        }
    }

    /// Concatenates several MSMs into one, preserving term order.
    ///
    /// All terms of `msms[0]` come first, then those of `msms[1]`, and so on.
    pub fn from_many(msms: Vec<Self>) -> Self {
        let total: usize = msms.iter().map(|msm| msm.scalars.len()).sum();
        let mut merged = Self {
            scalars: Vec::with_capacity(total),
            bases: Vec::with_capacity(total),
            labels: Vec::with_capacity(total),
        };
        for part in msms {
            merged.scalars.extend(part.scalars);
            merged.bases.extend(part.bases);
            merged.labels.extend(part.labels);
        }
        merged
    }

    /// Builds the single-term MSM `1 * base`, labelled `NoLabel`.
    pub fn from_base(base: &E::G1) -> Self {
        Self {
            scalars: vec![E::Fr::ONE],
            bases: vec![*base],
            labels: vec![CommitmentLabel::NoLabel],
        }
    }
}
impl<E: Engine + Debug> MSMKZG<E>
where
E::G1Affine: CurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
{
pub fn collapse(&mut self) {
debug_assert!(
self.labels
.iter()
.all(|l| matches!(l, CommitmentLabel::NoLabel | CommitmentLabel::Advice(_))),
"collapse: all labels must be NoLabel or Advice, found: {:?}",
self.labels,
);
let point = self.eval();
self.scalars = vec![E::Fr::ONE];
self.bases = vec![point];
self.labels = vec![CommitmentLabel::NoLabel];
}
}
impl<E: Engine + Debug> MSM<E::G1Affine> for MSMKZG<E>
where
    E::G1Affine: CurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
{
    /// Appends the term `scalar * point`, tagged with `label`.
    fn append_term(&mut self, scalar: E::Fr, point: E::G1, label: CommitmentLabel) {
        self.scalars.push(scalar);
        self.bases.push(point);
        self.labels.push(label);
    }

    /// Appends all terms of `other`, in order, to `self`.
    ///
    /// Reads `other`'s fields directly: going through the `scalars()` /
    /// `bases()` / `labels()` accessors would clone each vector only to copy
    /// it a second time. `extend_from_slice` reserves on its own.
    fn add_msm(&mut self, other: &Self) {
        self.scalars.extend_from_slice(&other.scalars);
        self.bases.extend_from_slice(&other.bases);
        self.labels.extend_from_slice(&other.labels);
    }

    /// Multiplies every scalar by `factor` (in parallel via rayon).
    fn scale(&mut self, factor: E::Fr) {
        self.scalars.par_iter_mut().for_each(|s| {
            *s *= &factor;
        })
    }

    /// Returns `true` iff the MSM evaluates to the identity.
    fn check(&self) -> bool {
        bool::from(self.eval().is_identity())
    }

    /// Evaluates `sum_i scalars[i] * bases[i]`.
    ///
    /// A lone term with scalar one (the shape produced by `collapse` and
    /// `from_base`) is returned directly without running an MSM. The
    /// comparison is against an array, so no temporary `Vec` is allocated.
    fn eval(&self) -> E::G1 {
        if self.scalars == [E::Fr::ONE] {
            self.bases[0]
        } else {
            msm_specific::<E::G1Affine>(&self.scalars, &self.bases)
        }
    }

    /// Returns a copy of the base points.
    fn bases(&self) -> Vec<E::G1> {
        self.bases.clone()
    }

    /// Returns a copy of the scalars.
    fn scalars(&self) -> Vec<E::Fr> {
        self.scalars.clone()
    }

    /// Returns a copy of the labels.
    fn labels(&self) -> Vec<CommitmentLabel> {
        self.labels.clone()
    }
}
/// Computes the multi-scalar multiplication `sum_i coeffs[i] * bases[i]`.
///
/// Pairs are formed with `zip`, so if the slices differ in length the extra
/// entries of the longer one are ignored. Pairs whose coefficient is zero are
/// discarded up front; if nothing remains, the identity is returned.
///
/// Two backends are used:
/// - when `C` is `midnight_curves::G1Affine` and at most `2 << 18` (= 2^19)
///   non-zero terms remain, the generic slices are reinterpreted as
///   `Fq` / `G1Projective` and handed to `G1Projective::multi_exp`, which
///   takes projective bases directly;
/// - otherwise the bases are batch-normalized to affine points and passed,
///   with the coefficients, to `msm_best`.
// NOTE(review): `2 << 18` is 2^19, not 2^18 — confirm the intended cutoff.
#[allow(unsafe_code)] // needed for the type-punning fast path below
pub fn msm_specific<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C::Curve]) -> C::Curve {
    // Drop zero-coefficient terms; they contribute nothing to the sum.
    let (coeffs, bases): (Vec<C::Scalar>, Vec<C::Curve>) = coeffs
        .iter()
        .zip(bases)
        .filter(|(s, _)| !s.is_zero_vartime())
        .map(|(s, b)| (*s, *b))
        .unzip();
    if coeffs.is_empty() {
        return C::Curve::identity();
    }
    if coeffs.len() <= (2 << 18) && TypeId::of::<C>() == TypeId::of::<midnight_curves::G1Affine>() {
        let coeffs_slice = coeffs.as_slice();
        let bases_slice = bases.as_slice();
        // SAFETY: the `TypeId` check above guarantees `C` is
        // `midnight_curves::G1Affine`. Soundness further relies on that type's
        // `CurveAffine` impl having `Scalar = Fq` and `Curve = G1Projective`
        // (not checked here — confirm in midnight_curves); under that
        // assumption each cast reinterprets a slice as its own element type,
        // with the same length.
        let coeffs = unsafe { &*(coeffs_slice as *const _ as *const [Fq]) };
        let bases = unsafe { &*(bases_slice as *const _ as *const [G1Projective]) };
        let res = G1Projective::multi_exp(bases, coeffs);
        // SAFETY: under the same assumption `C::Curve` is `G1Projective`, so
        // this copies the result into its own type.
        unsafe { std::mem::transmute_copy(&res) }
    } else {
        // Generic path: `msm_best` is called with affine bases, so convert
        // the projective bases in one batch first.
        let mut affine_bases = vec![C::identity(); coeffs.len()];
        C::Curve::batch_normalize(&bases, &mut affine_bases);
        msm_best(&coeffs, &affine_bases)
    }
}
/// A pair of `G1` MSMs, `left` and `right`, that are verified together by a
/// single pairing equation (see `DualMSM::check`).
#[derive(Clone, Debug)]
pub struct DualMSM<E: Engine> {
    /// MSM paired with `s_g2_prepared` from the verifier parameters.
    pub(crate) left: MSMKZG<E>,
    /// MSM paired with `n_g2_prepared` from the verifier parameters.
    pub(crate) right: MSMKZG<E>,
}
/// The terms of a `DualMSM` as borrowed `(label, scalar, base)` triples.
///
/// The first vector holds the left MSM's terms and the second the right's.
pub type SplitDualMSM<'a, E> = (
    Vec<(&'a CommitmentLabel, &'a <E as Engine>::Fr, &'a <E as Engine>::G1)>,
    Vec<(&'a CommitmentLabel, &'a <E as Engine>::Fr, &'a <E as Engine>::G1)>,
);
impl<E: MultiMillerLoop + Debug> Default for DualMSM<E>
where
E::G1Affine: CurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
{
fn default() -> Self {
Self::init()
}
}
impl<E: MultiMillerLoop> Guard<E::Fr, KZGCommitmentScheme<E>> for DualMSM<E>
where
E::G1: Default + CurveExt<ScalarExt = E::Fr> + ProcessedSerdeObject,
E::G1Affine: Default + CurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
{
fn verify(
self,
params: &<KZGCommitmentScheme<E> as PolynomialCommitmentScheme<E::Fr>>::VerifierParameters,
) -> Result<(), Error> {
self.check(params).then_some(()).ok_or(Error::OpeningError)
}
}
impl<E: MultiMillerLoop + Debug> DualMSM<E>
where
    E::G1Affine: CurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
{
    /// Creates a `DualMSM` whose left and right MSMs are both empty.
    pub fn init() -> Self {
        Self {
            left: MSMKZG::init(),
            right: MSMKZG::init(),
        }
    }

    /// Creates a `DualMSM` from the given left and right MSMs.
    pub fn new(left: MSMKZG<E>, right: MSMKZG<E>) -> Self {
        Self { left, right }
    }

    /// Borrows the terms of both sides as `(label, scalar, base)` triples,
    /// left side first.
    ///
    /// `izip!` stops at the shortest of each side's three vectors.
    pub fn split(&self) -> SplitDualMSM<'_, E> {
        let left = izip!(
            self.left.labels.iter(),
            self.left.scalars.iter(),
            self.left.bases.iter(),
        )
        .collect();
        let right = izip!(
            self.right.labels.iter(),
            self.right.scalars.iter(),
            self.right.bases.iter(),
        )
        .collect();
        (left, right)
    }

    /// Multiplies every scalar on both sides by `e`.
    pub fn scale(&mut self, e: E::Fr) {
        self.left.scale(e);
        self.right.scale(e);
    }

    /// Appends the terms of `other` to the corresponding sides of `self`,
    /// preserving order.
    ///
    /// `other` is owned, so its vectors are moved with `Vec::append` instead
    /// of being cloned element by element.
    pub fn add_msm(&mut self, other: Self) {
        let Self { left, right } = other;
        for (dst, mut src) in [(&mut self.left, left), (&mut self.right, right)] {
            dst.scalars.append(&mut src.scalars);
            dst.bases.append(&mut src.bases);
            dst.labels.append(&mut src.labels);
        }
    }

    /// Returns `true` iff `e(left, s_g2) * e(right, n_g2)` is the identity
    /// of the target group.
    ///
    /// `left` and `right` are the evaluated MSMs. `s_g2` and `n_g2` are the
    /// prepared `G2` points `s_g2_prepared` and `n_g2_prepared` from
    /// `params`. Both Miller loops share a single final exponentiation.
    pub fn check(self, params: &ParamsVerifierKZG<E>) -> bool {
        // `eval` already returns a lone `1 * P` term directly, so no special
        // case is needed here.
        let left: E::G1Affine = self.left.eval().into();
        let right: E::G1Affine = self.right.eval().into();
        let terms = [
            (&left, &params.s_g2_prepared),
            (&right, &params.n_g2_prepared),
        ];
        bool::from(E::multi_miller_loop(&terms).final_exponentiation().is_identity())
    }
}