use crate::{
Felt, ONE, ZERO,
errors::OperationError,
field::{BasedVectorSpace, Field, QuadFelt},
processor::{Processor, StackInterface},
tracer::OperationHelperRegisters,
};
#[cfg(test)]
mod tests;
/// Performs one FRI folding step over the quadratic extension field, folding four query
/// values into a single value.
///
/// Stack layout read by this operation (position: meaning):
/// - 0..=7: four query values, each a `QuadFelt` assembled from two stack elements
///   (see `get_query_values`).
/// - 8: `folded_pos`, written back unchanged to output position 12.
/// - 9: domain segment index; must be at most 3.
/// - 10: `poe`; must be non-zero. NOTE(review): by its name presumably the current power of
///   the domain generator — confirm.
/// - 11, 12: previous folded value as `(pe1, pe0)`, i.e. basis coefficient 1 above
///   coefficient 0.
/// - 13, 14: folding challenge `alpha` as `(a1, a0)`, same coefficient layout.
/// - 15: `layer_ptr`, advanced by `EIGHT` in the output.
///
/// After checking that the query value selected by the domain segment equals the previous
/// folded value, the query values are folded with `fold4`. The stack size is then decremented
/// and positions 0..=14 are overwritten with:
/// - 0..=3: intermediate folds `tmp0` and `tmp1`, each as (coefficient 1, coefficient 0);
/// - 4..=7: one-hot domain segment flags;
/// - 8: `poe^2`; 9: the tau factor of the segment; 10: `layer_ptr + 8`; 11: `poe^4`;
/// - 12: `folded_pos`; 13, 14: the folded value as (coefficient 1, coefficient 0).
///
/// # Errors
/// Returns `OperationError::FriError` if `poe` is zero, if the domain segment exceeds 3, or
/// if the selected query value differs from the previous folded value. Any error returned by
/// `decrement_size` is propagated.
///
/// On success returns `OperationHelperRegisters::FriExt2Fold4` carrying `ev`, `es`, `x` and
/// `x_inv`.
#[inline(always)]
pub(super) fn op_fri_ext2fold4<P>(
    processor: &mut P,
) -> Result<OperationHelperRegisters, OperationError>
where
    P: Processor,
{
    // --- read inputs from the stack ---------------------------------------------------------
    // Positions 0..=7 hold the four query values.
    let query_values = get_query_values(processor);
    let folded_pos = processor.stack().get(8);
    let domain_segment = processor.stack().get(9).as_canonical_u64();
    let poe = processor.stack().get(10);
    // `poe` is a factor of `x`, which is inverted below, so it must be non-zero.
    if poe.is_zero() {
        return Err(OperationError::FriError("domain size was 0".into()));
    }
    // Previous folded value: coefficient 1 sits above coefficient 0 on the stack.
    let prev_value = {
        let pe1 = processor.stack().get(11);
        let pe0 = processor.stack().get(12);
        QuadFelt::from_basis_coefficients_fn(|i: usize| [pe0, pe1][i])
    };
    // Folding challenge, laid out the same way as `prev_value`.
    let alpha = {
        let a1 = processor.stack().get(13);
        let a0 = processor.stack().get(14);
        QuadFelt::from_basis_coefficients_fn(|i: usize| [a0, a1][i])
    };
    let layer_ptr = processor.stack().get(15);

    // --- check consistency with the previous folding -----------------------------------------
    // Validating the range here keeps the `d_seg` indexing and the segment helpers below from
    // reaching their panic paths.
    if domain_segment > 3 {
        return Err(OperationError::FriError(format!(
            "domain segment value cannot exceed 3, but was {domain_segment}"
        )));
    }
    let d_seg = domain_segment as usize;
    // The query value selected by the domain segment must equal the previous folded value.
    if query_values[d_seg] != prev_value {
        return Err(OperationError::FriError(format!(
            "degree-respecting projection is inconsistent: expected {} but was {}",
            prev_value, query_values[d_seg]
        )));
    }

    // --- fold the query values -----------------------------------------------------------------
    let f_tau = get_tau_factor(d_seg);
    // `x` is `poe` scaled by the segment's tau factor; only its inverse feeds the folding.
    let x = poe * f_tau;
    let x_inv = x.inverse();
    let (ev, es) = compute_evaluation_points(alpha, x_inv);
    let (folded_value, tmp0, tmp1) = fold4(query_values, ev, es);
    let tmp0 = tmp0.as_basis_coefficients_slice();
    let tmp1 = tmp1.as_basis_coefficients_slice();
    let ds = get_domain_segment_flags(d_seg);
    let folded_value = folded_value.as_basis_coefficients_slice();
    let poe2 = poe * poe;
    let poe4 = poe2 * poe2;

    // --- write outputs -------------------------------------------------------------------------
    // The stack shrinks by one element; positions 0..=14 are then overwritten. Extension
    // elements are written with coefficient 1 above coefficient 0, matching the input layout.
    processor.stack_mut().decrement_size()?;
    processor.stack_mut().set(0, tmp0[1]);
    processor.stack_mut().set(1, tmp0[0]);
    processor.stack_mut().set(2, tmp1[1]);
    processor.stack_mut().set(3, tmp1[0]);
    processor.stack_mut().set_word(4, &ds.into());
    processor.stack_mut().set(8, poe2);
    processor.stack_mut().set(9, f_tau);
    processor.stack_mut().set(10, layer_ptr + EIGHT);
    processor.stack_mut().set(11, poe4);
    processor.stack_mut().set(12, folded_pos);
    processor.stack_mut().set(13, folded_value[1]);
    processor.stack_mut().set(14, folded_value[0]);
    Ok(OperationHelperRegisters::FriExt2Fold4 { ev, es, x, x_inv })
}
/// Reads the four quadratic-extension query values from stack positions 0..=7.
///
/// The word at position 0 supplies the first two values and the word at position 4 the last
/// two. Within each word (as converted to `[Felt; 4]`), elements 0/1 form one value and
/// elements 2/3 the next, with the first element of each pair as basis coefficient 0.
#[inline(always)]
fn get_query_values<P: Processor>(processor: &mut P) -> [QuadFelt; 4] {
    // Build an extension element from its two basis coefficients.
    let to_ext =
        |c0: Felt, c1: Felt| QuadFelt::from_basis_coefficients_fn(|i: usize| [c0, c1][i]);

    let [w0, w1, w2, w3]: [Felt; 4] = processor.stack().get_word(0).into();
    let [w4, w5, w6, w7]: [Felt; 4] = processor.stack().get_word(4).into();

    [to_ext(w0, w1), to_ext(w2, w3), to_ext(w4, w5), to_ext(w6, w7)]
}
/// Amount added to the layer pointer on every fold step.
const EIGHT: Felt = Felt::new(8);

// The literals below are field-specific. They are consistent with the Goldilocks prime
// p = 2^64 - 2^32 + 1, in which tau = 2^48 is a primitive 4th root of unity
// (2^96 = -1 mod p). NOTE(review): confirm that `Felt` is this field.

/// Inverse of 2 mod p, i.e. (p + 1) / 2.
const TWO_INV: Felt = Felt::new(9223372034707292161);

/// tau^{-1} = tau^3 = -2^48 mod p, i.e. p - 2^48.
const TAU_INV: Felt = Felt::new(18446462594437873665);

/// tau^{-2} = tau^2 = -1 mod p, i.e. p - 1.
const TAU2_INV: Felt = Felt::new(18446744069414584320);

/// tau^{-3} = tau = 2^48.
const TAU3_INV: Felt = Felt::new(281474976710656);
/// Returns tau^{-k} for domain segment `k`: `ONE` for segment 0 and `TAU_INV`,
/// `TAU2_INV`, `TAU3_INV` for segments 1 to 3.
///
/// # Panics
/// Panics if `domain_segment` is greater than 3.
fn get_tau_factor(domain_segment: usize) -> Felt {
    let factors = [ONE, TAU_INV, TAU2_INV, TAU3_INV];
    match factors.get(domain_segment) {
        Some(&factor) => factor,
        None => panic!("invalid domain segment {domain_segment}"),
    }
}
/// Returns a one-hot flag array with `ONE` at index `domain_segment` and `ZERO` elsewhere.
///
/// # Panics
/// Panics if `domain_segment` is greater than 3.
fn get_domain_segment_flags(domain_segment: usize) -> [Felt; 4] {
    assert!(domain_segment < 4, "invalid domain segment {domain_segment}");
    let mut flags = [ZERO; 4];
    flags[domain_segment] = ONE;
    flags
}
/// Returns `(alpha * x_inv, (alpha * x_inv)^2)`, the two points used by `fold4`.
fn compute_evaluation_points(alpha: QuadFelt, x_inv: Felt) -> (QuadFelt, QuadFelt) {
    let scaled = alpha * x_inv;
    (scaled, scaled * scaled)
}
/// Folds four values in two rounds of `fold2`.
///
/// The first round pairs `values[0]` with `values[2]` (point `ev`), and `values[1]` with
/// `values[3]` (point `ev * TAU_INV`). The second round folds those two results at point
/// `es`.
///
/// Returns `(folded_value, first_pair_fold, second_pair_fold)`.
fn fold4(values: [QuadFelt; 4], ev: QuadFelt, es: QuadFelt) -> (QuadFelt, QuadFelt, QuadFelt) {
    let [v0, v1, v2, v3] = values;
    let first_pair = fold2(v0, v2, ev);
    let second_pair = fold2(v1, v3, ev * TAU_INV);
    (fold2(first_pair, second_pair, es), first_pair, second_pair)
}
/// Folds a pair of values at point `ep`: returns
/// `((f_x + f_neg_x) + (f_x - f_neg_x) * ep) * TWO_INV`.
///
/// `TWO_INV` is, by its name, the inverse of 2, so the result is half the bracketed sum.
fn fold2(f_x: QuadFelt, f_neg_x: QuadFelt, ep: QuadFelt) -> QuadFelt {
    let sum = f_x + f_neg_x;
    let diff = f_x - f_neg_x;
    (sum + diff * ep) * TWO_INV
}