use std::{
borrow::{Borrow, Cow},
ops::{Deref, Index},
};
use ark_ff::FftField;
use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
use crate::{utils::horner_evaluate, AssertPowerOfTwo};
/// Owned evaluations stored in "folded" order.
///
/// The flat vector is a sequence of back-to-back batches of `N` elements. Each batch gathers the
/// values that are combined together when folding by a factor of `N` (see [`Self::new`] for the
/// exact layout). Dereferences to [`FoldedEvaluationsSlice`] for batch-wise access.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct FoldedEvaluations<const N: usize, F>(Vec<F>);

impl<const N: usize, F> FoldedEvaluations<N, F> {
    /// Regroups evaluations given in natural order into folded order.
    ///
    /// With `D = evaluations.len()` and `s = D / N`, when `D` is a multiple of `N` the output
    /// batch `i` (for `i` in `0..s`) is `[e[i], e[i + s], e[i + 2s], ..., e[i + (N - 1)s]]`, and
    /// the batches are stored consecutively.
    ///
    /// # Panics
    /// Panics if `N == 0`. In debug builds, also panics if `D` is not a multiple of `N`; release
    /// builds skip that check.
    pub fn new(evaluations: &[F]) -> Self
    where
        F: Clone,
    {
        Self::check_slice(evaluations);
        let stride = evaluations.len() / N;
        let mut folded = Vec::with_capacity(evaluations.len());
        // Row `r` collects every `stride`-th element starting at offset `r`.
        folded.extend(
            (0..stride).flat_map(|row| evaluations[row..].iter().step_by(stride).cloned()),
        );
        Self::from_flat_evaluations_unchecked(folded)
    }

    /// Consumes `self` and returns the underlying flat vector, still in folded order.
    #[inline]
    pub fn into_flat_evaluations(self) -> Vec<F> {
        let Self(flat) = self;
        flat
    }

    /// Wraps a vector that is already in folded order.
    ///
    /// In debug builds, asserts that the length is a multiple of `N`. Release builds do not check.
    #[inline]
    pub fn from_flat_evaluations(evaluations: Vec<F>) -> Self {
        Self::check_slice(evaluations.as_slice());
        Self::from_flat_evaluations_unchecked(evaluations)
    }

    /// Wraps a vector that is already in folded order, without any length check.
    #[inline]
    pub fn from_flat_evaluations_unchecked(evaluations: Vec<F>) -> Self {
        Self(evaluations)
    }

    /// Debug-only check that `evaluations.len()` is a multiple of `N`; a no-op in release builds.
    #[inline]
    fn check_slice(evaluations: &[F]) {
        debug_assert!(
            evaluations.len() % N == 0,
            "Domain size must be a multiple of `N`"
        );
    }
}
impl<const N: usize, F> Deref for FoldedEvaluations<N, F> {
    type Target = FoldedEvaluationsSlice<N, F>;

    /// Borrows the owned evaluations as a [`FoldedEvaluationsSlice`] view without copying.
    /// No length check is performed here.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let flat: &[F] = self.0.as_slice();
        FoldedEvaluationsSlice::<N, F>::from_flat_evaluations_unchecked(flat)
    }
}
impl<const N: usize, F> Borrow<FoldedEvaluationsSlice<N, F>> for FoldedEvaluations<N, F> {
#[inline]
fn borrow(&self) -> &FoldedEvaluationsSlice<N, F> {
self
}
}
impl<const N: usize, F> AsRef<FoldedEvaluationsSlice<N, F>> for FoldedEvaluations<N, F> {
    /// Views the evaluations as a [`FoldedEvaluationsSlice`]; identical to the `Deref` target.
    #[inline]
    fn as_ref(&self) -> &FoldedEvaluationsSlice<N, F> {
        &**self
    }
}
/// Forwards every `AsRef` conversion supported by [`FoldedEvaluationsSlice`] to the owned type.
impl<const N: usize, F, R: ?Sized> AsRef<R> for FoldedEvaluations<N, F>
where
    FoldedEvaluationsSlice<N, F>: AsRef<R>,
{
    #[inline]
    fn as_ref(&self) -> &R {
        // `self` deref-coerces to the slice view expected by the fully-qualified call.
        <FoldedEvaluationsSlice<N, F> as AsRef<R>>::as_ref(self)
    }
}
/// Iterates the batches of the owned evaluations, exactly as iterating the slice view does.
impl<'a, const N: usize, F> IntoIterator for &'a FoldedEvaluations<N, F> {
    type Item = <&'a FoldedEvaluationsSlice<N, F> as IntoIterator>::Item;
    type IntoIter = <&'a FoldedEvaluationsSlice<N, F> as IntoIterator>::IntoIter;

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let slice: &'a FoldedEvaluationsSlice<N, F> = self;
        slice.into_iter()
    }
}
/// Borrowed, unsized view over flat evaluations grouped into consecutive batches of `N`.
///
/// `#[repr(transparent)]` is load-bearing: `from_flat_evaluations_unchecked` casts `&[F]` to
/// `&Self`, which is only sound because this type has exactly the layout of `[F]`.
#[derive(PartialEq, Eq, Debug)]
#[repr(transparent)]
pub struct FoldedEvaluationsSlice<const N: usize, F>([F]);
impl<const N: usize, F> Index<usize> for FoldedEvaluationsSlice<N, F> {
    type Output = [F; N];

    /// Returns the `index`-th batch: the `N` consecutive elements starting at `index * N`.
    ///
    /// # Panics
    /// Panics if `(index + 1) * N` exceeds the number of stored evaluations.
    #[inline]
    fn index(&self, index: usize) -> &Self::Output {
        let start = index * N;
        let end = (index + 1) * N;
        let batch = &self.0[start..end];
        // The range has exactly `N` elements, so the conversion cannot fail.
        batch.try_into().unwrap()
    }
}
/// Constructors and accessors for the borrowed folded view.
impl<const N: usize, F> FoldedEvaluationsSlice<N, F> {
    /// Reinterprets a flat slice (already in folded order) as a folded view.
    ///
    /// In debug builds this asserts that `evaluations.len()` is a multiple of `N` (via
    /// `FoldedEvaluations::check_slice`); release builds perform no check.
    #[inline]
    pub fn from_flat_evaluations(evaluations: &[F]) -> &Self {
        FoldedEvaluations::<N, _>::check_slice(evaluations);
        Self::from_flat_evaluations_unchecked(evaluations)
    }
    /// Reinterprets a flat slice as a folded view without any length check.
    ///
    /// If the length is not a multiple of `N`, `folded_len` rounds down, so the trailing
    /// remainder is not part of any batch.
    #[inline]
    pub const fn from_flat_evaluations_unchecked(evaluations: &[F]) -> &Self {
        // SAFETY: `FoldedEvaluationsSlice<N, F>` is `#[repr(transparent)]` over `[F]`, so a
        // `*const [F]` (including its length metadata) may be reinterpreted as `*const Self`.
        // The returned reference borrows `evaluations` for the same lifetime.
        unsafe { &*(evaluations as *const [F] as *const Self) }
    }
    /// Returns the underlying flat slice unchanged.
    #[inline]
    pub const fn as_flat_evaluations(&self) -> &[F] {
        &self.0
    }
    /// Total number of flat evaluations (the length of the underlying slice).
    #[inline]
    pub const fn domain_size(&self) -> usize {
        self.0.len()
    }
    /// Number of complete `N`-element batches, i.e. `domain_size() / N` rounded down.
    /// Panics if `N == 0` (division by zero).
    #[inline]
    pub const fn folded_len(&self) -> usize {
        self.0.len() / N
    }
}
impl<const N: usize, F> AsRef<[[F; N]]> for FoldedEvaluationsSlice<N, F> {
    /// Views the evaluations as a slice of `N`-element batches.
    ///
    /// Only the first `folded_len()` complete batches are included. Any trailing remainder
    /// (when the length is not a multiple of `N`) is omitted.
    #[inline]
    fn as_ref(&self) -> &[[F; N]] {
        // SAFETY: `[F; N]` has the alignment of `F` and size `N * size_of::<F>()`, so
        // `folded_len()` arrays exactly cover the first `folded_len() * N <= self.0.len()`
        // elements of the underlying slice. The pointer is non-null, aligned and valid for that
        // range, and the result borrows `self`.
        unsafe { core::slice::from_raw_parts(self.0.as_ptr() as *const [F; N], self.folded_len()) }
    }
}
impl<'a, const N: usize, F> IntoIterator for &'a FoldedEvaluationsSlice<N, F> {
type Item = &'a [F; N];
type IntoIter = core::slice::Iter<'a, [F; N]>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.as_ref().iter()
}
}
/// Folds each batch of `N` evaluations into a single value using the challenge `alpha`.
///
/// Let `D = evaluations.domain_size()` and `ω = F::get_root_of_unity(D)`. For batch `i` (the
/// layout produced by [`FoldedEvaluations::new`]), the `N` values are interpolated with an
/// inverse FFT over `domain`. The resulting coefficients are then evaluated with
/// `horner_evaluate` at `alpha * ω^{-i}`. The output has one entry per batch, in batch order.
/// NOTE(review): this looks like the FRI folding step, and `horner_evaluate(coeffs, x)` is
/// assumed to evaluate the coefficient slice at `x` — confirm in `utils`.
///
/// `domain` may be supplied to reuse a precomputed size-`N` domain; otherwise one is built.
/// Empty `evaluations` yield an empty result.
///
/// # Panics
/// - If `F` has no evaluation domain of size `N`, or no root of unity of order `D`.
/// - In debug builds, if `domain` does not have size `N`.
pub fn reduce_polynomial<const N: usize, F: FftField>(
    evaluations: &FoldedEvaluations<N, F>,
    alpha: F,
    domain: Option<&GeneralEvaluationDomain<F>>,
) -> Vec<F> {
    // Compile-time guard: the folding factor must be a power of two.
    let _: () = AssertPowerOfTwo::<N>::OK;
    let domain = domain.map_or_else(
        || {
            Cow::Owned(
                GeneralEvaluationDomain::new(N)
                    .expect("field must support an evaluation domain of size N"),
            )
        },
        Cow::Borrowed,
    );
    debug_assert_eq!(domain.size(), N, "Evaluation domain must be of size N");

    let domain_size = evaluations.domain_size();
    // Nothing to fold. Returning early also avoids `get_root_of_unity(0)` (which has no root)
    // and the `0 - 1` underflow in the exponent below.
    if domain_size == 0 {
        return Vec::new();
    }

    // ω^(D-1) is ω^{-1} because ω^D = 1.
    let root_inv = F::get_root_of_unity(domain_size as u64)
        .expect("field must contain a root of unity of order `domain_size`")
        .pow([domain_size as u64 - 1]);

    let mut buffer = Vec::with_capacity(N);
    let mut new_evaluations = Vec::with_capacity(evaluations.folded_len());
    // `offset` tracks ω^{-i} for the current batch index `i`.
    let mut offset = F::ONE;
    for batch in evaluations {
        buffer.extend_from_slice(batch);
        domain.ifft_in_place(&mut buffer);
        new_evaluations.push(horner_evaluate(&buffer, alpha * offset));
        offset *= root_inv;
        buffer.clear();
    }
    new_evaluations
}
/// Maps query positions onto a folded domain of size `folded_domain_size`.
///
/// Duplicates are removed while first-occurrence order is kept. Position `p` lands at
/// `p % folded_domain_size`. This matches the batch layout of [`FoldedEvaluations::new`] when
/// `folded_domain_size` is the number of batches (`D / N`): batch `i` gathers positions
/// `i, i + s, i + 2s, ...` with `s = D / N`. For power-of-two sizes this equals
/// `p & (folded_domain_size - 1)`.
///
/// # Panics
/// Panics if `folded_domain_size` is zero.
pub fn fold_positions(positions: &[usize], folded_domain_size: usize) -> Vec<usize> {
    assert!(folded_domain_size > 0, "folded domain size must be non-zero");
    // A set gives O(1) duplicate checks instead of rescanning the output for every position.
    let mut seen = std::collections::HashSet::with_capacity(positions.len());
    positions
        .iter()
        .map(|&position| position % folded_domain_size)
        .filter(|&folded| seen.insert(folded))
        .collect()
}