use std::{
fmt::Debug,
ops::{Add, Mul},
};
pub use ff::Field;
use group::{
ff::{BatchInvert, PrimeField},
prime::{PrimeCurve, PrimeCurveAffine},
GroupOpsOwned, ScalarMulOwned,
};
use midnight_curves::{fft::best_fft, pairing::MultiMillerLoop};
pub use midnight_curves::{CurveAffine, CurveExt};
pub trait FftGroup<Scalar: Field>:
Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
{
}
impl<T, Scalar> FftGroup<Scalar> for T
where
Scalar: Field,
T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
{
}
pub fn g_to_lagrange<C: PrimeCurve>(g_projective: &[C], k: u32) -> Vec<C> {
let n_inv = C::Scalar::TWO_INV.pow_vartime([k as u64, 0, 0, 0]);
let mut omega_inv = C::Scalar::ROOT_OF_UNITY_INV;
for _ in k..C::Scalar::S {
omega_inv = omega_inv.square();
}
let mut g_lagrange = g_projective.to_vec();
best_fft(&mut g_lagrange, omega_inv, k);
parallelize(&mut g_lagrange, |g, _| {
for g in g.iter_mut() {
*g *= n_inv;
}
});
g_lagrange.to_vec()
}
pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
fn evaluate<F: Field>(poly: &[F], point: F) -> F {
poly.iter().rev().fold(F::ZERO, |acc, coeff| acc * point + coeff)
}
let n = poly.len();
let num_threads = rayon::current_num_threads();
if n * 2 < num_threads {
evaluate(poly, point)
} else {
let chunk_size = n.div_ceil(num_threads);
let mut parts = vec![F::ZERO; num_threads];
rayon::scope(|scope| {
for (chunk_idx, (out, poly)) in
parts.chunks_mut(1).zip(poly.chunks(chunk_size)).enumerate()
{
scope.spawn(move |_| {
let start = chunk_idx * chunk_size;
out[0] = evaluate(poly, point) * point.pow_vartime([start as u64, 0, 0, 0]);
});
}
});
parts.iter().fold(F::ZERO, |acc, coeff| acc + coeff)
}
}
pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
assert_eq!(a.len(), b.len());
let mut acc = F::ZERO;
for (a, b) in a.iter().zip(b.iter()) {
acc += (*a) * (*b);
}
acc
}
pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
where
I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
{
b = -b;
let a = a.into_iter();
let mut q = vec![F::ZERO; a.len() - 1];
let mut tmp = F::ZERO;
for (q, r) in q.iter_mut().rev().zip(a.rev()) {
let mut lead_coeff = *r;
lead_coeff.sub_assign(&tmp);
*q = lead_coeff;
tmp = lead_coeff;
tmp.mul_assign(&b);
}
q
}
pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
let f = &f;
let total_iters = v.len();
let num_threads = rayon::current_num_threads();
let base_chunk_size = total_iters / num_threads;
let cutoff_chunk_id = total_iters % num_threads;
let split_pos = cutoff_chunk_id * (base_chunk_size + 1);
let (v_hi, v_lo) = v.split_at_mut(split_pos);
rayon::scope(|scope| {
if cutoff_chunk_id != 0 {
for (chunk_id, chunk) in v_hi.chunks_exact_mut(base_chunk_size + 1).enumerate() {
let offset = chunk_id * (base_chunk_size + 1);
scope.spawn(move |_| f(chunk, offset));
}
}
if base_chunk_size != 0 {
for (chunk_id, chunk) in v_lo.chunks_exact_mut(base_chunk_size).enumerate() {
let offset = split_pos + (chunk_id * base_chunk_size);
scope.spawn(move |_| f(chunk, offset));
}
}
});
}
pub fn lagrange_interpolate<F: Field + Ord>(points: &[F], evals: &[F]) -> Vec<F> {
assert_eq!(points.len(), evals.len());
{
let mut sorted_points = points.to_vec();
sorted_points.sort();
assert!(!sorted_points.windows(2).any(|w| w[0] == w[1]));
}
if points.len() == 1 {
vec![evals[0]]
} else {
let mut denoms = Vec::with_capacity(points.len());
for (j, x_j) in points.iter().enumerate() {
let mut denom = Vec::with_capacity(points.len() - 1);
for x_k in points.iter().enumerate().filter(|&(k, _)| k != j).map(|a| a.1) {
denom.push(*x_j - x_k);
}
denoms.push(denom);
}
denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();
let mut final_poly = vec![F::ZERO; points.len()];
for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
let mut tmp: Vec<F> = Vec::with_capacity(points.len());
let mut product = Vec::with_capacity(points.len() - 1);
tmp.push(F::ONE);
for (x_k, denom) in points
.iter()
.enumerate()
.filter(|&(k, _)| k != j)
.map(|a| a.1)
.zip(denoms.into_iter())
{
product.resize(tmp.len() + 1, F::ZERO);
for ((a, b), product) in tmp
.iter()
.chain(std::iter::once(&F::ZERO))
.zip(std::iter::once(&F::ZERO).chain(tmp.iter()))
.zip(product.iter_mut())
{
*product = *a * (-denom * x_k) + *b * denom;
}
std::mem::swap(&mut tmp, &mut product);
}
assert_eq!(tmp.len(), points.len());
assert_eq!(product.len(), points.len() - 1);
for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp.into_iter()) {
*final_coeff += interpolation_coeff * eval;
}
}
final_poly
}
}
#[cfg(feature = "truncated-challenges")]
use num_bigint::BigUint;
#[cfg(feature = "truncated-challenges")]
pub(crate) fn truncate<F: PrimeField>(scalar: F) -> F {
let nb_bytes = F::NUM_BITS.div_ceil(8).div_ceil(2) as usize;
let bytes = scalar.to_repr().as_ref()[..nb_bytes].to_vec();
let bi = BigUint::from_bytes_le(&bytes);
F::from_str_vartime(&BigUint::to_string(&bi)).unwrap()
}
#[cfg(feature = "truncated-challenges")]
pub(crate) fn truncated_powers<F: PrimeField>(base: F) -> impl Iterator<Item = F> {
powers(base).map(truncate)
}
pub(crate) fn powers<F: Field>(base: F) -> impl Iterator<Item = F> {
std::iter::successors(Some(F::ONE), move |power| Some(base * power))
}
pub(crate) fn inner_product<F: PrimeField, T: Mul<F, Output = T> + Add<T, Output = T> + Clone>(
polys: &[T],
scalars: impl Iterator<Item = F>,
) -> T {
polys
.iter()
.zip(scalars)
.map(|(p, s)| p.clone() * s)
.reduce(|acc, p| acc + p)
.unwrap()
}
pub(crate) fn msm_inner_product<E>(mut msms: Vec<MSMKZG<E>>, scalars: &[E::Fr]) -> MSMKZG<E>
where
E: MultiMillerLoop + Debug,
E::G1Affine: CurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
E::Fr: Ord,
{
let len: usize = msms.iter().map(|m| m.scalars.len()).sum();
let mut new_scalars = Vec::with_capacity(len);
let mut new_bases = Vec::with_capacity(len);
msms.iter_mut().zip(scalars.iter()).for_each(|(msm, s)| {
msm.scale(*s);
new_scalars.extend(&msm.scalars);
new_bases.extend(&msm.bases);
});
MSMKZG {
scalars: new_scalars,
bases: new_bases,
}
}
pub(crate) fn evals_inner_product<F: PrimeField + Clone>(
evals_set: &[Vec<F>],
scalars: &[F],
) -> Vec<F> {
let mut res = vec![F::ZERO; evals_set[0].len()];
for (poly_evals, s) in evals_set.iter().zip(scalars) {
for i in 0..res.len() {
res[i] += poly_evals[i] * s;
}
}
res
}
pub trait MSM<C: PrimeCurveAffine>: Clone + Debug + Send + Sized + Sync {
fn append_term(&mut self, scalar: C::Scalar, point: C::Curve);
fn add_msm(&mut self, other: &Self);
fn scale(&mut self, factor: C::Scalar);
fn check(&self) -> bool;
fn eval(&self) -> C::Curve;
fn bases(&self) -> Vec<C::Curve>;
fn scalars(&self) -> Vec<C::Scalar>;
}
#[cfg(test)]
use midnight_curves::Fq as Scalar;
#[cfg(test)]
use rand_core::OsRng;
use crate::poly::kzg::msm::MSMKZG;
#[test]
fn test_lagrange_interpolate() {
let rng = OsRng;
let points = (0..5).map(|_| Scalar::random(rng)).collect::<Vec<_>>();
let evals = (0..5).map(|_| Scalar::random(rng)).collect::<Vec<_>>();
for coeffs in 0..5 {
let points = &points[0..coeffs];
let evals = &evals[0..coeffs];
let poly = lagrange_interpolate(points, evals);
assert_eq!(poly.len(), points.len());
for (point, eval) in points.iter().zip(evals) {
assert_eq!(eval_polynomial(&poly, *point), *eval);
}
}
}