use std::{fmt::Debug, ops::Deref};
use ark_ff::FftField;
use ark_poly::{
univariate::{DenseOrSparsePolynomial, DensePolynomial},
DenseUVPolynomial,
};
/// Owned binary tree whose nodes live in one flat vector, in the order expected by
/// `VecBinarySubTree`: left subtree, right subtree, then the node itself.
///
/// All navigation (`root`, `left_child`, `right_child`) is reached through `Deref` to the
/// borrowed `VecBinarySubTree` view.
#[derive(Clone, PartialEq, Eq)]
struct VecBinaryTree<F>(Vec<F>);

impl<F> Deref for VecBinaryTree<F> {
    type Target = VecBinarySubTree<F>;

    /// Borrows the whole vector as a tree view (no copy).
    #[inline]
    fn deref(&self) -> &Self::Target {
        let VecBinaryTree(nodes) = self;
        VecBinarySubTree::from_slice(nodes.as_slice())
    }
}

impl<F: Debug> Debug for VecBinaryTree<F> {
    /// Formats exactly like the borrowed view of the same nodes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(self.deref(), f)
    }
}
/// Borrowed view of a binary tree stored as a flat slice.
///
/// For a node spanning `len` entries, the left subtree occupies the first `(len - 1) / 2`
/// entries, the right subtree the following `(len - 1) / 2`, and the node's own value is the
/// last entry. A one-entry slice is a leaf.
///
/// `#[repr(transparent)]` guarantees this unsized newtype has exactly the layout (and pointer
/// metadata) of `[F]`, which is what makes the pointer cast in `from_slice` sound.
#[derive(PartialEq, Eq)]
#[repr(transparent)]
struct VecBinarySubTree<F>([F]);

impl<F> VecBinarySubTree<F> {
    /// Reinterprets `slice` as a tree view without copying.
    #[inline]
    const fn from_slice(slice: &[F]) -> &Self {
        // SAFETY: `VecBinarySubTree<F>` is `#[repr(transparent)]` over `[F]`, so both unsized
        // types share layout and slice-length metadata; the returned reference inherits the
        // lifetime of `slice`, so it cannot outlive the borrowed data.
        unsafe { &*(slice as *const [F] as *const Self) }
    }

    /// Returns the value stored at this node (the last entry of the slice).
    ///
    /// # Panics
    /// Panics if the view is empty.
    #[inline]
    pub const fn root(&self) -> &F {
        &self.0[self.0.len() - 1]
    }

    /// Returns the left subtree, or `None` if this node is a leaf.
    #[inline]
    pub fn left_child(&self) -> Option<&Self> {
        (self.0.len() > 1).then(|| Self::from_slice(&self.0[..((self.0.len() - 1) / 2)]))
    }

    /// Returns the right subtree, or `None` if this node is a leaf.
    #[inline]
    pub fn right_child(&self) -> Option<&Self> {
        (self.0.len() > 1)
            .then(|| Self::from_slice(&self.0[((self.0.len() - 1) / 2)..(self.0.len() - 1)]))
    }
}
impl<F: Debug> Debug for VecBinarySubTree<F> {
    /// Renders the node as `"<root>": [<left>, <right>]`. Each child is printed as an `Option`
    /// (`None` below a leaf), and the rendering recurses through the whole subtree.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let root = self.root();
        let left = self.left_child();
        let right = self.right_child();
        write!(f, "\"{:?}\": [{:?}, {:?}]", root, left, right)
    }
}
/// Builds the subproduct tree of the points `x`.
///
/// Each leaf holds `X - x_i` and each internal node holds the product of its two children, so
/// the root is `prod_i (X - x_i)`. Nodes are appended in post-order (left subtree, right
/// subtree, node), which is the layout `VecBinarySubTree` navigates.
///
/// # Panics
/// Panics if `x.len()` is not a power of two (in particular when `x` is empty).
#[inline]
fn subproduct_tree<F: FftField>(x: &[F]) -> VecBinaryTree<DensePolynomial<F>> {
    /// Appends the post-order encoding of the subtree over `points` to `nodes`.
    fn build<F: FftField>(points: &[F], nodes: &mut Vec<DensePolynomial<F>>) {
        if let [point] = points {
            nodes.push(DensePolynomial::from_coefficients_slice(&[-*point, F::ONE]));
            return;
        }
        let (lower, upper) = points.split_at(points.len() / 2);
        build(lower, nodes);
        // The most recently pushed node is the root of the left subtree.
        let left_root = nodes.len() - 1;
        build(upper, nodes);
        let product = &nodes[left_root] * &nodes[nodes.len() - 1];
        nodes.push(product);
    }

    let n = x.len();
    assert!(
        n.is_power_of_two(),
        "Number of points must be a power of two, found {}",
        n,
    );
    let mut nodes = Vec::with_capacity(2 * n - 1);
    build(x, &mut nodes);
    VecBinaryTree(nodes)
}
/// Returns the numerator of `sum_i c[i] / (X - x_i)` over the common denominator held at the
/// root of `subproducts`, i.e. `sum_i c[i] * prod_{j != i} (X - x_j)`.
///
/// This is the single-vector counterpart of `multiple_fast_fraction_sum`.
// The lint allowance indicates this helper is not currently called; it is kept as the simple
// reference version of the batched routine.
#[allow(dead_code)]
fn fast_fraction_sum<F: FftField>(
    subproducts: &VecBinarySubTree<DensePolynomial<F>>,
    c: &[F],
) -> DensePolynomial<F> {
    if let [single] = c {
        // One point: the fraction c / (X - x) has numerator c.
        return DensePolynomial::from_coefficients_slice(&[*single]);
    }
    let left = subproducts.left_child().unwrap();
    let right = subproducts.right_child().unwrap();
    let half = c.len() / 2;
    let left_num = fast_fraction_sum(left, &c[..half]);
    let right_num = fast_fraction_sum(right, &c[half..]);
    // n1 / p1 + n2 / p2 = (n1 * p2 + n2 * p1) / (p1 * p2)
    &left_num * right.root() + &right_num * left.root()
}
/// Computes, for every coefficient vector `c` in `all_c`, the numerator of
/// `sum_i c[i] / (X - x_i)` over the common denominator `prod_i (X - x_i)`, i.e.
/// `sum_i c[i] * prod_{j != i} (X - x_j)`, where the `x_i` are the leaves of `subproducts`.
///
/// All vectors are handled in a single walk of the tree, so each subproduct is visited once.
/// The output has one polynomial per entry of `all_c`, in the same order. An empty `all_c`
/// yields an empty output, and empty coefficient vectors yield zero polynomials.
///
/// `all_c[0].len()` is expected to equal the number of leaves of `subproducts`.
///
/// # Panics
/// Panics if a vector in `all_c` is shorter than the first one, or if the first one has more
/// entries than the tree has leaves.
fn multiple_fast_fraction_sum<F: FftField>(
    subproducts: &VecBinarySubTree<DensePolynomial<F>>,
    all_c: &[Vec<F>],
) -> Vec<DensePolynomial<F>> {
    /// Numerators for the points `start..start + len`, whose subproducts form `subproducts`.
    fn aux<F: FftField>(
        subproducts: &VecBinarySubTree<DensePolynomial<F>>,
        all_c: &[Vec<F>],
        start: usize,
        len: usize,
    ) -> Vec<DensePolynomial<F>> {
        if len == 1 {
            // Single point: the fraction c[i] / (X - x_i) has numerator c[i].
            return all_c
                .iter()
                .map(|c| DensePolynomial::from_coefficients_slice(&[c[start]]))
                .collect();
        }
        let left_tree = subproducts
            .left_child()
            .expect("more coefficients than leaves in the subproduct tree");
        let right_tree = subproducts
            .right_child()
            .expect("more coefficients than leaves in the subproduct tree");
        let half = len / 2;
        let left_nums = aux(left_tree, all_c, start, half);
        let right_nums = aux(right_tree, all_c, start + half, len - half);
        // n1 / p1 + n2 / p2 = (n1 * p2 + n2 * p1) / (p1 * p2)
        let (p1, p2) = (left_tree.root(), right_tree.root());
        left_nums
            .iter()
            .zip(&right_nums)
            .map(|(n1, n2)| n1 * p2 + n2 * p1)
            .collect()
    }

    match all_c.first().map(Vec::len) {
        None => Vec::new(),
        // Empty sum of fractions: every numerator is the zero polynomial.
        Some(0) => all_c
            .iter()
            .map(|_| DensePolynomial::from_coefficients_vec(Vec::new()))
            .collect(),
        Some(len) => aux(subproducts, all_c, 0, len),
    }
}
/// Evaluates `poly` at every point whose linear factor `X - a` is a leaf of `subproducts`.
///
/// Remainder-tree algorithm: at each internal node the polynomial is reduced modulo the
/// subproduct stored at each child, so only a small remainder travels down to each leaf.
/// Values are returned in leaf order (left to right). For a tree built by `subproduct_tree`,
/// that is the order of its input points, and the output always has one value per leaf.
#[inline]
fn multipoint_evaluation<F: FftField>(
    poly: DensePolynomial<F>,
    subproducts: &VecBinarySubTree<DensePolynomial<F>>,
) -> Vec<F> {
    /// Appends the evaluations of `poly` at the leaves of `subproducts` to `output`.
    fn aux<F: FftField>(
        poly: DensePolynomial<F>,
        subproducts: &VecBinarySubTree<DensePolynomial<F>>,
        output: &mut Vec<F>,
    ) {
        // A full tree with `m` entries has `(m + 1) / 2` leaves.
        let num_points = (subproducts.0.len() + 1) / 2;

        // A constant, or the zero polynomial (which has no coefficients), takes the same value
        // at every point of this subtree.
        // - It must still be emitted once per point: a single value would shift every later
        //   output.
        // - Recursing on the zero polynomial would reach a leaf with no children to descend into.
        if poly.coeffs.len() <= 1 {
            let value = poly.coeffs.first().copied().unwrap_or(F::ZERO);
            output.extend(std::iter::repeat(value).take(num_points));
            return;
        }

        match (subproducts.left_child(), subproducts.right_child()) {
            (Some(left_tree), Some(right_tree)) => {
                let poly = DenseOrSparsePolynomial::from(poly);
                let p0 = poly
                    .divide_with_q_and_r(&left_tree.root().into())
                    .expect("subproducts are nonzero divisors")
                    .1;
                let p1 = poly
                    .divide_with_q_and_r(&right_tree.root().into())
                    .expect("subproducts are nonzero divisors")
                    .1;
                aux(p0, left_tree, output);
                aux(p1, right_tree, output);
            }
            _ => {
                // Leaf holding `X - a` (coefficients `[-a, 1]` as built by `subproduct_tree`).
                // With linear leaves, a non-constant polynomial only gets here when the whole
                // tree is a single leaf, so evaluate directly with Horner's rule.
                let point = -subproducts.root().coeffs[0];
                let value = poly
                    .coeffs
                    .iter()
                    .rev()
                    .fold(F::ZERO, |acc, &c| acc * point + c);
                output.push(value);
            }
        }
    }

    let mut output = Vec::with_capacity((subproducts.0.len() + 1) / 2);
    aux(poly, subproducts, &mut output);
    output
}
pub fn interpolate_polynomials<F: FftField>(shards: &[Vec<F>], positions: &[F]) -> Vec<Vec<F>> {
if shards.is_empty() {
return vec![];
}
assert!(
shards.iter().all(|shard| shard.len() == positions.len()),
"The size of all the shards must match the number of positions"
);
let subproducts = subproduct_tree(positions);
let mut root = subproducts.root().clone();
for i in 1..root.coeffs.len() {
root.coeffs[i - 1] = root.coeffs[i] * F::from(i as u64);
}
root.coeffs.pop();
let d_inv = multipoint_evaluation(root, &subproducts)
.into_iter()
.map(|di| di.inverse().unwrap())
.collect::<Vec<_>>();
let mut polynomials = Vec::with_capacity(shards.len());
let all_c: Vec<Vec<F>> = shards
.iter()
.map(|shard| {
shard
.iter()
.zip(d_inv.iter())
.map(|(&x, &y)| x * y)
.collect()
})
.collect::<Vec<_>>();
let r = multiple_fast_fraction_sum(&subproducts, all_c.as_slice());
r.iter().for_each(|v| polynomials.push(v.coeffs.clone()));
polynomials
}
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, iter::zip, time};

    use crate::{
        interpolation::{multipoint_evaluation, subproduct_tree},
        utils::to_evaluations,
    };
    use ark_ff::{FftField, Field, Fp64, MontBackend, MontConfig};
    use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial};
    use dragoonfri_test_utils::{random_file, Fq};
    use rand::{thread_rng, Rng};

    use super::interpolate_polynomials;

    // Prime field F_17 with generator 3, small enough to check results by hand.
    #[derive(MontConfig)]
    #[modulus = "17"]
    #[generator = "3"]
    pub struct TestSmall;
    pub type Fsmall = Fp64<MontBackend<TestSmall, 1>>;

    // Builds an array of `Fsmall` elements from integer literals.
    macro_rules! array_in_f {
        [$($i: expr),*] => {
            [$(Fsmall::from($i)),*]
        };
    }

    #[test]
    fn test_subproducts_and_multipoint() {
        let points = array_in_f![1, 5, 2, 6].to_vec();
        let tree = subproduct_tree(&points);

        // Root: (x - 1)(x - 5)(x - 2)(x - 6) over F_17.
        assert_eq!(&tree.root().coeffs, &array_in_f![9, 7, 14, 3, 1]);

        // Internal nodes: (x - 1)(x - 5) and (x - 2)(x - 6).
        let left = tree.left_child().unwrap();
        let right = tree.right_child().unwrap();
        assert_eq!(&left.root().coeffs, &array_in_f![5, 11, 1]);
        assert_eq!(&right.root().coeffs, &array_in_f![12, 9, 1]);

        // Leaves, left to right: x - 1, x - 5, x - 2, x - 6.
        let leaves = [
            (left.left_child().unwrap(), array_in_f![16, 1]),
            (left.right_child().unwrap(), array_in_f![12, 1]),
            (right.left_child().unwrap(), array_in_f![15, 1]),
            (right.right_child().unwrap(), array_in_f![11, 1]),
        ];
        for (leaf, expected) in leaves {
            assert_eq!(&leaf.root().coeffs, &expected);
        }

        // 4 + 7x + 2x^2 evaluated at 1, 5, 2, 6.
        let poly = DensePolynomial::from_coefficients_vec(array_in_f![4, 7, 2].to_vec());
        assert_eq!(
            &multipoint_evaluation(poly, &tree),
            &array_in_f![13, 4, 9, 16]
        );
    }

    #[test]
    fn test_interpolate_polynomials() {
        const NUM_COEFFS: usize = 1024;
        const DOMAIN_SIZE: usize = 4096;
        const NUM_POLY: usize = 2;

        let polys = random_file::<Fq>(NUM_COEFFS, NUM_POLY);
        let evaluations = polys
            .clone()
            .into_iter()
            .map(|poly| to_evaluations(poly, DOMAIN_SIZE))
            .collect::<Vec<_>>();

        // Sample NUM_COEFFS distinct domain indices. For each, record the domain point and
        // every polynomial's evaluation there.
        let domain_generator = Fq::get_root_of_unity(DOMAIN_SIZE as u64).unwrap();
        let mut rng = thread_rng();
        let mut used_indices = HashSet::with_capacity(NUM_COEFFS);
        let mut positions = Vec::with_capacity(NUM_COEFFS);
        let mut shards: Vec<Vec<_>> = (0..NUM_POLY)
            .map(|_| Vec::with_capacity(NUM_COEFFS))
            .collect();
        while positions.len() < NUM_COEFFS {
            let index = rng.gen_range(0..DOMAIN_SIZE);
            if !used_indices.insert(index) {
                continue;
            }
            positions.push(domain_generator.pow([index as u64]));
            for (evaluation, shard) in zip(&evaluations, &mut shards) {
                shard.push(evaluation[index]);
            }
        }

        let start = time::Instant::now();
        let polys_inter = interpolate_polynomials(&shards, &positions);
        let elapsed = start.elapsed().as_secs_f64();

        assert_eq!(polys, polys_inter);
        println!("Total time: {} seconds", elapsed);
        println!(
            "Average time per polynomial: {} seconds",
            elapsed / NUM_POLY as f64
        );
    }
}