use alloc::collections::btree_map::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use p3_commit::PolynomialSpace;
use p3_field::{ExtensionField, TwoAdicField, batch_multiplicative_inverse};
use p3_interpolation::interpolate_coset_with_precomputation;
use p3_matrix::dense::RowMajorMatrix;
use p3_util::log2_strict_usize;
/// Evaluates every periodic column at each point of `quotient_points`.
///
/// A periodic column of length `period` is interpolated as the values of a polynomial on
/// the two-adic subgroup of order `period` (generator powers starting at 1).
///
/// Each point `x` is handled in two steps:
/// - It is unshifted by the inverse of the trace domain's first point.
/// - It is raised to the `2^rate_bits` power, where
///   `rate_bits = log2(trace_height) - log2(period)`.
///
/// The result is the point at which that small-subgroup polynomial is evaluated. The tests
/// compare this against interpolating the column, repeated up to the full trace height, over
/// the trace coset.
///
/// Columns are grouped by length. All columns that share a period reuse one subgroup, one
/// set of differences and one batch inversion per point. The output keeps the input column
/// order: `result[c][i]` is column `c` evaluated at `quotient_points[i]`.
///
/// Returns `None` when `periodic_table` is empty.
///
/// # Panics
///
/// Panics if:
/// - any column is empty;
/// - the trace domain size is not a power of two;
/// - a column length is not a power of two.
///
/// Column lengths are also expected to divide the trace height; this is checked in
/// `subgroup_data`.
///
/// NOTE(review): a point lying on the trace coset would make one of the differences
/// `y - g` zero. Callers are assumed to pass points outside it; the tests use a disjoint
/// quotient domain.
pub(crate) fn compute_periodic_on_quotient_eval_domain<F, EF>(
    periodic_table: Vec<Vec<F>>,
    trace_domain: impl PolynomialSpace<Val = F>,
    quotient_points: &[EF],
) -> Option<Vec<Vec<EF>>>
where
    F: TwoAdicField,
    EF: ExtensionField<F>,
{
    if periodic_table.is_empty() {
        return None;
    }
    let (trace_height, log_trace_height, shift_inv) = trace_context(&trace_domain);
    // The same embedded inverse shift is applied to every point, so embed it once.
    let shift_inv_ef = EF::from(shift_inv);
    let quotient_size = quotient_points.len();

    // Group columns by period, remembering each column's original position.
    let mut grouped: BTreeMap<usize, Vec<(usize, Vec<F>)>> = BTreeMap::new();
    let mut evals = vec![Vec::new(); periodic_table.len()];
    for (idx, col) in periodic_table.into_iter().enumerate() {
        assert!(!col.is_empty(), "Periodic column must not be empty");
        grouped.entry(col.len()).or_default().push((idx, col));
    }

    for (period, cols) in grouped {
        let (rate_bits, subgroup) = subgroup_data::<F>(trace_height, log_trace_height, period);
        let num_cols = cols.len();

        // Row-major layout: row `r` holds entry `r` of every column in this group.
        let mut subgroup_values = Vec::with_capacity(period * num_cols);
        for row in 0..period {
            for (_, col) in cols.iter() {
                subgroup_values.push(col[row]);
            }
        }
        let subgroup_matrix = RowMajorMatrix::new(subgroup_values, num_cols);

        // Build each buffer separately: `vec![v; n]` clones `v`, and a cloned empty `Vec`
        // does not keep the requested capacity.
        let mut group_evals: Vec<Vec<EF>> = (0..num_cols)
            .map(|_| Vec::with_capacity(quotient_size))
            .collect();

        for &x in quotient_points {
            let y = (x * shift_inv_ef).exp_power_of_2(rate_bits);
            let diffs: Vec<EF> = subgroup.iter().map(|&g| y - EF::from(g)).collect();
            let diff_invs = batch_multiplicative_inverse(&diffs);
            let values_at_y = interpolate_coset_with_precomputation(
                &subgroup_matrix,
                F::ONE,
                y,
                &subgroup,
                &diff_invs,
            );
            for (col_evals, value) in group_evals.iter_mut().zip(values_at_y) {
                col_evals.push(value);
            }
        }

        // Move (rather than clone) each column's evaluations into its original slot.
        for ((orig_idx, _), col_evals) in cols.into_iter().zip(group_evals) {
            evals[orig_idx] = col_evals;
        }
    }
    Some(evals)
}
pub(crate) fn evaluate_periodic_at_point<F, EF>(
periodic_table: Vec<Vec<F>>,
trace_domain: impl PolynomialSpace<Val = F>,
zeta: EF,
) -> Vec<EF>
where
F: TwoAdicField,
EF: ExtensionField<F>,
{
if periodic_table.is_empty() {
return Vec::new();
}
let (trace_height, log_trace_height, shift_inv) = trace_context(&trace_domain);
let unshifted_zeta = zeta * EF::from(shift_inv);
periodic_table
.into_iter()
.map(|col| {
if col.is_empty() {
return EF::ZERO;
}
let (rate_bits, subgroup) =
subgroup_data::<F>(trace_height, log_trace_height, col.len());
let y = unshifted_zeta.exp_power_of_2(rate_bits);
let diffs: Vec<_> = subgroup.iter().map(|&g| y - EF::from(g)).collect();
let diff_invs = batch_multiplicative_inverse(&diffs);
interpolate_coset_with_precomputation(
&RowMajorMatrix::new(col, 1),
F::ONE,
y,
&subgroup,
&diff_invs,
)
.pop()
.expect("single-column interpolation should return one value")
})
.collect()
}
/// Gathers the trace-domain quantities shared by the periodic-column evaluators.
///
/// Returns `(trace_height, log2(trace_height), first_point^{-1})`:
/// - the domain size;
/// - its base-2 logarithm;
/// - the inverse of the domain's first point, which is used to unshift evaluation points.
///
/// # Panics
///
/// Panics if the domain size is not a power of two, or if the first point is not
/// invertible.
fn trace_context<F>(trace_domain: &impl PolynomialSpace<Val = F>) -> (usize, usize, F)
where
    F: TwoAdicField,
{
    let height = trace_domain.size();
    (
        height,
        log2_strict_usize(height),
        trace_domain.first_point().inverse(),
    )
}
/// Returns `(rate_bits, subgroup)` for a periodic column of length `period`.
///
/// - `subgroup` holds the `period` successive powers (starting at 1) of the two-adic
///   generator of order `period`.
/// - `rate_bits = log_trace_height - log2(period)` is the number of squarings that take an
///   unshifted point into the matching power used for that subgroup's interpolant.
///
/// # Panics
///
/// Panics if any of the following hold:
/// - `period` does not divide `trace_height` (this includes `period == 0`);
/// - `period` is not a power of two;
/// - `log2(period)` exceeds `log_trace_height`.
fn subgroup_data<F>(trace_height: usize, log_trace_height: usize, period: usize) -> (usize, Vec<F>)
where
    F: TwoAdicField,
{
    // These checks are unconditional: they run once per period, so they are cheap. Without
    // them, a too-long column makes the subtraction below wrap in builds without overflow
    // checks. The resulting huge `rate_bits` would make `exp_power_of_2` loop practically
    // forever instead of failing loudly.
    assert!(
        trace_height.is_multiple_of(period),
        "Periodic column length must divide trace length"
    );
    let log_period = log2_strict_usize(period);
    let rate_bits = log_trace_height
        .checked_sub(log_period)
        .expect("Periodic column period cannot exceed trace height");
    let subgroup: Vec<F> = F::two_adic_generator(log_period)
        .powers()
        .take(period)
        .collect();
    (rate_bits, subgroup)
}
#[cfg(test)]
mod tests {
    use p3_field::coset::TwoAdicMultiplicativeCoset;
    use p3_field::extension::BinomialExtensionField;
    use p3_field::{Field, PrimeCharacteristicRing};
    use p3_goldilocks::Goldilocks;
    use p3_interpolation::interpolate_coset;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    type Val = Goldilocks;
    type Challenge = BinomialExtensionField<Val, 2>;

    /// Converts small integers into a periodic column.
    fn column(values: &[u32]) -> Vec<Val> {
        values.iter().copied().map(Val::from_u32).collect()
    }

    /// Builds the trace domain of height `trace_height`, shifted by the field generator.
    fn make_trace_domain(trace_height: usize) -> TwoAdicMultiplicativeCoset<Val> {
        TwoAdicMultiplicativeCoset::new(Val::GENERATOR, log2_strict_usize(trace_height))
            .expect("valid trace domain")
    }

    /// Lists all `quotient_size` points of the quotient domain disjoint from `trace_domain`.
    fn enumerate_quotient_points(
        trace_domain: &TwoAdicMultiplicativeCoset<Val>,
        quotient_size: usize,
    ) -> Vec<Challenge> {
        let quotient_domain = trace_domain.create_disjoint_domain(quotient_size);
        let mut pts = Vec::with_capacity(quotient_size);
        let mut point = Challenge::from(quotient_domain.first_point());
        pts.push(point);
        for _ in 1..quotient_size {
            point = quotient_domain
                .next_point(point)
                .expect("quotient_domain should yield enough points");
            pts.push(point);
        }
        pts
    }

    /// Reference implementation. It repeats each column up to the full trace height and
    /// interpolates it over the trace coset directly at every point.
    fn naive_evals(
        periodic_table: &[Vec<Val>],
        trace_domain: &TwoAdicMultiplicativeCoset<Val>,
        quotient_points: &[Challenge],
    ) -> Vec<Vec<Challenge>> {
        let trace_height = trace_domain.size();
        let shift = trace_domain.first_point();
        periodic_table
            .iter()
            .map(|periodic_col| {
                let unpacked: Vec<Val> = periodic_col
                    .iter()
                    .copied()
                    .cycle()
                    .take(trace_height)
                    .collect();
                let unpacked_matrix = RowMajorMatrix::new(unpacked, 1);
                quotient_points
                    .iter()
                    .map(|&z| interpolate_coset(&unpacked_matrix, shift, z)[0])
                    .collect()
            })
            .collect()
    }

    #[test]
    fn test_compute_periodic_on_quotient_eval_domain_correctness() {
        let trace_height = 16;
        let log_quotient_degree = 2;
        let quotient_size = trace_height << log_quotient_degree;
        let periodic_table = vec![
            column(&[10, 20]),
            column(&[1, 2, 3, 4]),
            column(&[5, 6, 7, 8, 9, 10, 11, 12]),
        ];
        let trace_domain = make_trace_domain(trace_height);
        let quotient_points = enumerate_quotient_points(&trace_domain, quotient_size);

        let optimized_result = compute_periodic_on_quotient_eval_domain::<Val, Challenge>(
            periodic_table.clone(),
            trace_domain,
            &quotient_points,
        )
        .expect("periodic_table should not be empty");

        assert_eq!(
            optimized_result,
            naive_evals(&periodic_table, &trace_domain, &quotient_points)
        );
    }

    #[test]
    fn test_compute_periodic_on_quotient_eval_domain_full_period() {
        let trace_height = 8;
        let log_quotient_degree = 1;
        let quotient_size = trace_height << log_quotient_degree;
        let periodic_table = vec![column(&[1, 2, 3, 4, 5, 6, 7, 8])];
        let trace_domain = make_trace_domain(trace_height);
        let quotient_points = enumerate_quotient_points(&trace_domain, quotient_size);

        let optimized_result = compute_periodic_on_quotient_eval_domain::<Val, Challenge>(
            periodic_table.clone(),
            trace_domain,
            &quotient_points,
        )
        .expect("periodic_table should not be empty");

        assert_eq!(
            optimized_result,
            naive_evals(&periodic_table, &trace_domain, &quotient_points)
        );
    }

    #[test]
    fn test_evaluate_periodic_at_point_matches_quotient_domain() {
        let trace_height = 16;
        let quotient_size = trace_height << 1;
        // Periods repeat, so columns sharing a period are exercised together.
        let periodic_table = vec![
            column(&[10, 20]),
            column(&[3, 7, 1, 9]),
            column(&[4, 5]),
            column(&[1, 2, 3, 4]),
        ];
        let trace_domain = make_trace_domain(trace_height);
        let quotient_points = enumerate_quotient_points(&trace_domain, quotient_size);

        let on_domain = compute_periodic_on_quotient_eval_domain::<Val, Challenge>(
            periodic_table.clone(),
            trace_domain,
            &quotient_points,
        )
        .expect("periodic_table should not be empty");

        for (i, &zeta) in quotient_points.iter().enumerate() {
            let at_point = evaluate_periodic_at_point::<Val, Challenge>(
                periodic_table.clone(),
                trace_domain,
                zeta,
            );
            let expected: Vec<Challenge> = on_domain.iter().map(|col| col[i]).collect();
            assert_eq!(at_point, expected);
        }
    }

    #[test]
    fn test_empty_periodic_table() {
        let trace_domain = make_trace_domain(8);
        let quotient_points = enumerate_quotient_points(&trace_domain, 16);

        assert!(
            compute_periodic_on_quotient_eval_domain::<Val, Challenge>(
                Vec::new(),
                trace_domain,
                &quotient_points,
            )
            .is_none()
        );
        assert!(
            evaluate_periodic_at_point::<Val, Challenge>(
                Vec::new(),
                trace_domain,
                quotient_points[0],
            )
            .is_empty()
        );
    }
}