midnight-proofs 0.7.1

Fast PLONK-based zero-knowledge proving system
//! This module provides common utilities, traits and structures for group,
//! field and polynomial arithmetic.

use std::{
    fmt::Debug,
    ops::{Add, Mul},
};

pub use ff::Field;
use group::{
    ff::{BatchInvert, PrimeField},
    prime::{PrimeCurve, PrimeCurveAffine},
    GroupOpsOwned, ScalarMulOwned,
};
use midnight_curves::{fft::best_fft, pairing::MultiMillerLoop};
pub use midnight_curves::{CurveAffine, CurveExt};

/// This represents an element of a group with basic operations that can be
/// performed. This allows an FFT implementation (for example) to operate
/// generically over either a field or elliptic curve group.
pub trait FftGroup<Scalar: Field>:
    Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
{
}

impl<T, Scalar> FftGroup<Scalar> for T
where
    Scalar: Field,
    T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
{
}

/// Convert coefficient bases group elements to lagrange basis by inverse FFT.
pub fn g_to_lagrange<C: PrimeCurve>(g_projective: &[C], k: u32) -> Vec<C> {
    let n_inv = C::Scalar::TWO_INV.pow_vartime([k as u64, 0, 0, 0]);
    let mut omega_inv = C::Scalar::ROOT_OF_UNITY_INV;
    for _ in k..C::Scalar::S {
        omega_inv = omega_inv.square();
    }

    let mut g_lagrange = g_projective.to_vec();
    best_fft(&mut g_lagrange, omega_inv, k);
    parallelize(&mut g_lagrange, |g, _| {
        for g in g.iter_mut() {
            *g *= n_inv;
        }
    });

    g_lagrange.to_vec()
}

/// This evaluates a provided polynomial (in coefficient form) at `point`.
pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
    fn evaluate<F: Field>(poly: &[F], point: F) -> F {
        poly.iter().rev().fold(F::ZERO, |acc, coeff| acc * point + coeff)
    }
    let n = poly.len();
    let num_threads = rayon::current_num_threads();
    if n * 2 < num_threads {
        evaluate(poly, point)
    } else {
        let chunk_size = n.div_ceil(num_threads);
        let mut parts = vec![F::ZERO; num_threads];
        rayon::scope(|scope| {
            for (chunk_idx, (out, poly)) in
                parts.chunks_mut(1).zip(poly.chunks(chunk_size)).enumerate()
            {
                scope.spawn(move |_| {
                    let start = chunk_idx * chunk_size;
                    out[0] = evaluate(poly, point) * point.pow_vartime([start as u64, 0, 0, 0]);
                });
            }
        });
        parts.iter().fold(F::ZERO, |acc, coeff| acc + coeff)
    }
}

/// This computes the inner product of two vectors `a` and `b`.
///
/// This function will panic if the two vectors are not the same size.
pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
    // TODO: parallelize?
    assert_eq!(a.len(), b.len());

    let mut acc = F::ZERO;
    for (a, b) in a.iter().zip(b.iter()) {
        acc += (*a) * (*b);
    }

    acc
}

/// Divides polynomial `a` in `X` by `X - b` with
/// no remainder.
pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
where
    I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
{
    b = -b;
    let a = a.into_iter();

    let mut q = vec![F::ZERO; a.len() - 1];

    let mut tmp = F::ZERO;
    for (q, r) in q.iter_mut().rev().zip(a.rev()) {
        let mut lead_coeff = *r;
        lead_coeff.sub_assign(&tmp);
        *q = lead_coeff;
        tmp = lead_coeff;
        tmp.mul_assign(&b);
    }

    q
}

/// This utility function will parallelize an operation that is to be
/// performed over a mutable slice.
pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
    // Algorithm rationale:
    //
    // Using the stdlib `chunks_mut` will lead to severe load imbalance.
    // From https://github.com/rust-lang/rust/blob/e94bda3/library/core/src/slice/iter.rs#L1607-L1637
    // if the division is not exact, the last chunk will be the remainder.
    //
    // Dividing 40 items on 12 threads will lead to a chunk size of 40/12 = 3,
    // There will be a 13 chunks of size 3 and 1 of size 1 distributed on 12
    // threads. This leads to 1 thread working on 6 iterations, 1 on 4
    // iterations and 10 on 3 iterations, a load imbalance of 2x.
    //
    // Instead we can divide work into chunks of size
    // 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3 = 4*4 + 3*8 = 40
    //
    // This would lead to a 6/4 = 1.5x speedup compared to naive chunks_mut
    //
    // See also OpenMP spec (page 60)
    // http://www.openmp.org/mp-documents/openmp-4.5.pdf
    // "When no chunk_size is specified, the iteration space is divided into chunks
    // that are approximately equal in size, and at most one chunk is distributed to
    // each thread. The size of the chunks is unspecified in this case."
    // This implies chunks are the same size ±1

    let f = &f;
    let total_iters = v.len();
    let num_threads = rayon::current_num_threads();
    let base_chunk_size = total_iters / num_threads;
    let cutoff_chunk_id = total_iters % num_threads;
    let split_pos = cutoff_chunk_id * (base_chunk_size + 1);
    let (v_hi, v_lo) = v.split_at_mut(split_pos);

    rayon::scope(|scope| {
        // Skip special-case: number of iterations is cleanly divided by number of
        // threads.
        if cutoff_chunk_id != 0 {
            for (chunk_id, chunk) in v_hi.chunks_exact_mut(base_chunk_size + 1).enumerate() {
                let offset = chunk_id * (base_chunk_size + 1);
                scope.spawn(move |_| f(chunk, offset));
            }
        }
        // Skip special-case: less iterations than number of threads.
        if base_chunk_size != 0 {
            for (chunk_id, chunk) in v_lo.chunks_exact_mut(base_chunk_size).enumerate() {
                let offset = split_pos + (chunk_id * base_chunk_size);
                scope.spawn(move |_| f(chunk, offset));
            }
        }
    });
}

/// Returns coefficients of an n - 1 degree polynomial given a set of n points
/// and their evaluations. This function will panic if two values in `points`
/// are the same.
pub fn lagrange_interpolate<F: Field + Ord>(points: &[F], evals: &[F]) -> Vec<F> {
    assert_eq!(points.len(), evals.len());
    {
        let mut sorted_points = points.to_vec();
        sorted_points.sort();
        assert!(!sorted_points.windows(2).any(|w| w[0] == w[1]));
    }

    if points.len() == 1 {
        // Constant polynomial
        vec![evals[0]]
    } else {
        let mut denoms = Vec::with_capacity(points.len());
        for (j, x_j) in points.iter().enumerate() {
            let mut denom = Vec::with_capacity(points.len() - 1);
            for x_k in points.iter().enumerate().filter(|&(k, _)| k != j).map(|a| a.1) {
                denom.push(*x_j - x_k);
            }
            denoms.push(denom);
        }
        // Compute (x_j - x_k)^(-1) for each j != i
        denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();

        let mut final_poly = vec![F::ZERO; points.len()];
        for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
            let mut tmp: Vec<F> = Vec::with_capacity(points.len());
            let mut product = Vec::with_capacity(points.len() - 1);
            tmp.push(F::ONE);
            for (x_k, denom) in points
                .iter()
                .enumerate()
                .filter(|&(k, _)| k != j)
                .map(|a| a.1)
                .zip(denoms.into_iter())
            {
                product.resize(tmp.len() + 1, F::ZERO);
                for ((a, b), product) in tmp
                    .iter()
                    .chain(std::iter::once(&F::ZERO))
                    .zip(std::iter::once(&F::ZERO).chain(tmp.iter()))
                    .zip(product.iter_mut())
                {
                    *product = *a * (-denom * x_k) + *b * denom;
                }
                std::mem::swap(&mut tmp, &mut product);
            }
            assert_eq!(tmp.len(), points.len());
            assert_eq!(product.len(), points.len() - 1);
            for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp.into_iter()) {
                *final_coeff += interpolation_coeff * eval;
            }
        }
        final_poly
    }
}

#[cfg(feature = "truncated-challenges")]
use num_bigint::BigUint;

/// Truncates a scalar field element to half its byte size.
///
/// This function reduces a scalar field element `scalar` to half its size by
/// retaining only the lower half of its little-endian byte representation.
///
/// # Note
/// For cryptographically secure elliptic curves, the scalar field is
/// approximately twice the size of the security parameter. When scalars are
/// sampled uniformly at random, truncating to half the field size retains
/// sufficient entropy for security while reducing computational overhead.
///
/// # Warning
/// 128 bits may not be enough entropy depending on the application. For
/// example, it makes a collision attack feasible with 2^64 memory and ~2^64
/// operations.
#[cfg(feature = "truncated-challenges")]
pub(crate) fn truncate<F: PrimeField>(scalar: F) -> F {
    let nb_bytes = F::NUM_BITS.div_ceil(8).div_ceil(2) as usize;
    let bytes = scalar.to_repr().as_ref()[..nb_bytes].to_vec();
    let bi = BigUint::from_bytes_le(&bytes);
    F::from_str_vartime(&BigUint::to_string(&bi)).unwrap()
}

#[cfg(feature = "truncated-challenges")]
pub(crate) fn truncated_powers<F: PrimeField>(base: F) -> impl Iterator<Item = F> {
    powers(base).map(truncate)
}

pub(crate) fn powers<F: Field>(base: F) -> impl Iterator<Item = F> {
    std::iter::successors(Some(F::ONE), move |power| Some(base * power))
}

pub(crate) fn inner_product<F: PrimeField, T: Mul<F, Output = T> + Add<T, Output = T> + Clone>(
    polys: &[T],
    scalars: impl Iterator<Item = F>,
) -> T {
    polys
        .iter()
        .zip(scalars)
        .map(|(p, s)| p.clone() * s)
        .reduce(|acc, p| acc + p)
        .unwrap()
}

pub(crate) fn msm_inner_product<E>(mut msms: Vec<MSMKZG<E>>, scalars: &[E::Fr]) -> MSMKZG<E>
where
    E: MultiMillerLoop + Debug,
    E::G1Affine: CurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
    E::Fr: Ord,
{
    let len: usize = msms.iter().map(|m| m.scalars.len()).sum();

    let mut new_scalars = Vec::with_capacity(len);
    let mut new_bases = Vec::with_capacity(len);

    msms.iter_mut().zip(scalars.iter()).for_each(|(msm, s)| {
        msm.scale(*s);
        new_scalars.extend(&msm.scalars);
        new_bases.extend(&msm.bases);
    });

    MSMKZG {
        scalars: new_scalars,
        bases: new_bases,
    }
}

/// Computes the inner product of a set of polynomial evaluations and a set of
/// scalar values. This function computes the weighted sum of polynomial
/// evaluations. Each vector in `evals_set` is multiplied element-wise by a
/// corresponding scalar from `scalars`, and the results are accumulated
/// into a single vector.
pub(crate) fn evals_inner_product<F: PrimeField + Clone>(
    evals_set: &[Vec<F>],
    scalars: &[F],
) -> Vec<F> {
    let mut res = vec![F::ZERO; evals_set[0].len()];
    for (poly_evals, s) in evals_set.iter().zip(scalars) {
        for i in 0..res.len() {
            res[i] += poly_evals[i] * s;
        }
    }
    res
}

/// Multi scalar multiplication engine
pub trait MSM<C: PrimeCurveAffine>: Clone + Debug + Send + Sized + Sync {
    /// Add arbitrary term (the scalar and the point)
    fn append_term(&mut self, scalar: C::Scalar, point: C::Curve);

    /// Add another multiexp into this one
    fn add_msm(&mut self, other: &Self);

    /// Scale all scalars in the MSM by some scaling factor
    fn scale(&mut self, factor: C::Scalar);

    /// Perform multiexp and check that it results in zero
    fn check(&self) -> bool;

    /// Perform multiexp and return the result
    fn eval(&self) -> C::Curve;

    /// Return base points
    fn bases(&self) -> Vec<C::Curve>;

    /// Scalars
    fn scalars(&self) -> Vec<C::Scalar>;
}

#[cfg(test)]
use midnight_curves::Fq as Scalar;
#[cfg(test)]
use rand_core::OsRng;

use crate::poly::kzg::msm::MSMKZG;

#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;

    let points = (0..5).map(|_| Scalar::random(rng)).collect::<Vec<_>>();
    let evals = (0..5).map(|_| Scalar::random(rng)).collect::<Vec<_>>();

    for coeffs in 0..5 {
        let points = &points[0..coeffs];
        let evals = &evals[0..coeffs];

        let poly = lagrange_interpolate(points, evals);
        assert_eq!(poly.len(), points.len());

        for (point, eval) in points.iter().zip(evals) {
            assert_eq!(eval_polynomial(&poly, *point), *eval);
        }
    }
}