use core::iter::successors;
use ff::PrimeField;
/// Power polynomial determined by a base `t` and a number of variables `ell`.
///
/// Only the successive squares `t, t^2, t^4, ..., t^(2^(ell-1))` are stored (one per
/// variable). The evaluations over the boolean hypercube, as produced by `evals`, are the
/// consecutive powers `t^0, t^1, ..., t^(2^ell - 1)`.
#[derive(Debug, Clone)]
pub struct PowPolynomial<Scalar: PrimeField> {
    /// `t_pow[i] == t^(2^i)` for `i in 0..ell`.
    t_pow: Vec<Scalar>,
}
impl<Scalar: PrimeField> PowPolynomial<Scalar> {
    /// Creates the power polynomial with base `t` over `ell` variables.
    ///
    /// Stores the `ell` successive squares `t, t^2, t^4, ..., t^(2^(ell-1))`.
    pub fn new(t: &Scalar, ell: usize) -> Self {
        let t_pow = successors(Some(*t), |p: &Scalar| Some(p.square()))
            .take(ell)
            .collect::<Vec<_>>();
        PowPolynomial { t_pow }
    }

    /// Returns the `2^ell` evaluations `[1, t, t^2, ..., t^(2^ell - 1)]`.
    ///
    /// For `ell == 0` this is `[1]`.
    pub fn evals(&self) -> Vec<Scalar> {
        Self::powers(self.base(), 1 << self.t_pow.len())
    }

    /// Returns the stored squares: `coordinates()[i] == t^(2^i)` for `i in 0..ell`.
    pub fn coordinates(&self) -> &[Scalar] {
        &self.t_pow
    }

    /// Returns the evaluations split into two factor tables, concatenated as `left ++ right`:
    ///
    /// * `left  = [1, t, t^2, ..., t^(len_left - 1)]` (length `len_left`)
    /// * `right = [1, t^len_left, t^(2*len_left), ..., t^((len_right - 1)*len_left)]`
    ///   (length `len_right`)
    ///
    /// so that `right[i] * left[j] == evals()[i * len_left + j]`.
    ///
    /// # Panics
    ///
    /// Panics if `len_left * len_right != 2^ell`.
    pub fn split_evals(&self, len_left: usize, len_right: usize) -> Vec<Scalar> {
        let ell = self.t_pow.len();
        assert_eq!(len_left * len_right, 1 << ell);
        let t = self.base();

        let mut out = Self::powers(t, len_left);
        // The assertion forces `len_left >= 1` (the product is a non-zero power of two),
        // so `left` is never empty. `t^len_left` is the stride between `right` entries.
        let stride = *out.last().expect("len_left is non-zero") * t;
        // `powers` handles every `len_right >= 1`, including `len_right == 1`.
        out.extend(Self::powers(stride, len_right));
        out
    }

    /// Returns the base `t`, or `ONE` when `ell == 0`.
    ///
    /// With `ell == 0` no `t` is stored. Only `t^0 == 1` is ever needed in that case, so
    /// the fallback never affects any returned value.
    fn base(&self) -> Scalar {
        self.t_pow.first().copied().unwrap_or(Scalar::ONE)
    }

    /// Returns `[1, base, base^2, ..., base^(n - 1)]`; empty when `n == 0`.
    fn powers(base: Scalar, n: usize) -> Vec<Scalar> {
        successors(Some(Scalar::ONE), |p| Some(*p * base))
            .take(n)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{bn256_grumpkin::bn256, pasta::pallas, secp_secq::secp256k1};
    use rand::rngs::OsRng;

    /// Checks `evals()` against the naive running product `1, t, t^2, ...`.
    fn test_evals_with<Scalar: PrimeField>() {
        let t = Scalar::random(&mut OsRng);
        let ell = 4;
        let pow = PowPolynomial::new(&t, ell);
        let evals = pow.evals();
        assert_eq!(evals.len(), 1 << ell);

        let mut expected = Scalar::ONE;
        for (i, eval) in evals.iter().enumerate() {
            // `assert_eq!` already prints both values on failure; add the index.
            assert_eq!(*eval, expected, "mismatch at index {}", i);
            expected *= t;
        }
    }

    #[test]
    fn test_evals() {
        test_evals_with::<bn256::Scalar>();
        test_evals_with::<pallas::Scalar>();
        test_evals_with::<secp256k1::Scalar>();
    }

    /// Checks, for every split `2^k * 2^(ell - k)` with `k in 0..ell`, that
    /// `right[i] * left[j] == evals()[i * len_left + j]`.
    fn test_split_evals_with<Scalar: PrimeField>() {
        let t = Scalar::random(&mut OsRng);
        let ell = 4;
        let pow = PowPolynomial::new(&t, ell);
        let evals = pow.evals();
        assert_eq!(evals.len(), 1 << ell);

        for k in 0..ell {
            let len_left = 1 << k;
            let len_right = 1 << (ell - k);
            let split_evals = pow.split_evals(len_left, len_right);
            assert_eq!(split_evals.len(), len_left + len_right);

            let (left, right) = split_evals.split_at(len_left);
            for (i, r) in right.iter().enumerate() {
                for (j, l) in left.iter().enumerate() {
                    assert_eq!(
                        evals[i * len_left + j],
                        *r * l,
                        "mismatch at right index {}, left index {} (len_left = {})",
                        i,
                        j,
                        len_left
                    );
                }
            }
        }
    }

    #[test]
    fn test_split_evals() {
        test_split_evals_with::<bn256::Scalar>();
        test_split_evals_with::<pallas::Scalar>();
        test_split_evals_with::<secp256k1::Scalar>();
    }
}