use super::{super::QuadFelt, ExecutionError, Felt, Host, Operation, Process};
use vm_core::{ExtensionOf, FieldElement, StarkField, ONE, ZERO};
/// The field element `2`.
const TWO: Felt = Felt::new(2);

/// Multiplicative inverse of `TWO` (checked by the `constants` unit test); multiplying by it
/// halves a value, which `fold2` relies on.
const TWO_INV: Felt = Felt::new(9_223_372_034_707_292_161);

/// Offset applied to the evaluation domain: the query point is `poe * tau_factor * DOMAIN_OFFSET`.
const DOMAIN_OFFSET: Felt = Felt::GENERATOR;

// Inverses of the powers of `tau = Felt::get_root_of_unity(2)`, a root of unity of order 4.
// The `constants` unit test verifies each of these literals.

/// `tau^-1`
const TAU_INV: Felt = Felt::new(18_446_462_594_437_873_665);
/// `tau^-2`
const TAU2_INV: Felt = Felt::new(18_446_744_069_414_584_320);
/// `tau^-3`
const TAU3_INV: Felt = Felt::new(281_474_976_710_656);
impl<H> Process<H>
where
    H: Host,
{
    /// Executes `FriE2F4`, a single 4-to-1 FRI folding step over the quadratic extension field.
    ///
    /// Operands are read from the stack (position 0 is the top):
    /// - `0..=7`: four query values stored as `v7, v6, ..., v0`; the pairs `(v0, v1)`,
    ///   `(v2, v3)`, `(v4, v5)`, `(v6, v7)` form the four extension-field values,
    /// - `8`: folded position,
    /// - `9`: domain segment, which must be at most 3,
    /// - `10`: `poe` (presumably the current power of the domain generator — TODO confirm),
    /// - `11, 12`: previous folded value, stored as `(pe1, pe0)`,
    /// - `13, 14`: folding challenge alpha, stored as `(a1, a0)`,
    /// - `15`: layer pointer.
    ///
    /// On success, stack positions `0..=14` are overwritten with, in order:
    /// - the two intermediate folds `tmp0` and `tmp1`, each written second base element first,
    /// - the four domain-segment flags in reverse order,
    /// - `poe^2`, the tau factor, `layer_ptr + 2`, `poe^4` and the folded position,
    /// - the folded value, second base element first.
    ///
    /// The decoder helper registers then receive `ev`, `es`, `x` and `x^-1`, and
    /// `self.stack.shift_left(16)` is applied.
    ///
    /// # Errors
    /// - `ExecutionError::InvalidFriDomainSegment` if the domain segment exceeds 3.
    /// - `ExecutionError::InvalidFriLayerFolding` if the query value selected by the domain
    ///   segment differs from the previous folded value.
    pub(super) fn op_fri_ext2fold4(&mut self) -> Result<(), ExecutionError> {
        // Gather every operand first; validation happens afterwards.
        let query_values = self.get_query_values();
        let folded_pos = self.get_folded_position();
        let segment = self.get_domain_segment().as_int();
        let poe = self.get_poe();
        let prev_value = self.get_previous_value();
        let alpha = self.get_alpha();
        let layer_ptr = self.get_layer_ptr();

        // A 4-to-1 fold only has segments 0..=3.
        if segment > 3 {
            return Err(ExecutionError::InvalidFriDomainSegment(segment));
        }
        let segment = segment as usize;

        // The value produced by the previous layer must reappear among this layer's queries.
        let selected = query_values[segment];
        if selected != prev_value {
            return Err(ExecutionError::InvalidFriLayerFolding(prev_value, selected));
        }

        // Domain point for this query, and alpha scaled by its inverse (plus that value squared).
        let tau_factor = get_tau_factor(segment);
        let x = poe * tau_factor * DOMAIN_OFFSET;
        let x_inv = x.inv();
        let (ev, es) = compute_evaluation_points(alpha, x_inv);
        let (folded, tmp0, tmp1) = fold4(query_values, ev, es);

        let tmp0 = tmp0.to_base_elements();
        let tmp1 = tmp1.to_base_elements();
        let folded = folded.to_base_elements();
        let flags = get_domain_segment_flags(segment);
        let poe2 = poe.square();
        let poe4 = poe2.square();

        // Next values of stack positions 0..=14, listed in position order.
        let outputs = [
            tmp0[1],
            tmp0[0],
            tmp1[1],
            tmp1[0],
            flags[3],
            flags[2],
            flags[1],
            flags[0],
            poe2,
            tau_factor,
            layer_ptr + TWO,
            poe4,
            folded_pos,
            folded[1],
            folded[0],
        ];
        for (position, value) in outputs.iter().enumerate() {
            self.stack.set(position, *value);
        }

        self.set_helper_registers(ev, es, x, x_inv);
        self.stack.shift_left(16);
        Ok(())
    }

    /// Reads the four extension-field query values from stack positions `0..=7`.
    ///
    /// The stack holds `v7` on top down to `v0` at position 7. The result is
    /// `[new(v0, v1), new(v2, v3), new(v4, v5), new(v6, v7)]`.
    fn get_query_values(&self) -> [QuadFelt; 4] {
        [
            QuadFelt::new(self.stack.get(7), self.stack.get(6)),
            QuadFelt::new(self.stack.get(5), self.stack.get(4)),
            QuadFelt::new(self.stack.get(3), self.stack.get(2)),
            QuadFelt::new(self.stack.get(1), self.stack.get(0)),
        ]
    }

    /// Returns the folded position stored at stack position 8.
    fn get_folded_position(&self) -> Felt {
        self.stack.get(8)
    }

    /// Returns the (not yet validated) domain segment stored at stack position 9.
    fn get_domain_segment(&self) -> Felt {
        self.stack.get(9)
    }

    /// Returns `poe` from stack position 10.
    fn get_poe(&self) -> Felt {
        self.stack.get(10)
    }

    /// Returns the previous layer's folded value; position 12 holds the first base element
    /// and position 11 the second.
    fn get_previous_value(&self) -> QuadFelt {
        QuadFelt::new(self.stack.get(12), self.stack.get(11))
    }

    /// Returns the folding challenge alpha; position 14 holds the first base element and
    /// position 13 the second.
    fn get_alpha(&self) -> QuadFelt {
        QuadFelt::new(self.stack.get(14), self.stack.get(13))
    }

    /// Returns the layer pointer stored at stack position 15.
    fn get_layer_ptr(&self) -> Felt {
        self.stack.get(15)
    }

    /// Stores `[ev_0, ev_1, es_0, es_1, x, x_inv]` into the decoder's user-op helper registers
    /// for `Operation::FriE2F4`.
    fn set_helper_registers(&mut self, ev: QuadFelt, es: QuadFelt, x: Felt, x_inv: Felt) {
        let points = [ev, es];
        // Flattens both extension elements into their base-field components, ev first.
        let elements = QuadFelt::slice_as_base_elements(&points);
        let helpers = [elements[0], elements[1], elements[2], elements[3], x, x_inv];
        self.decoder.set_user_op_helpers(Operation::FriE2F4, &helpers);
    }
}
/// Returns `tau^-domain_segment`, i.e. `1`, `TAU_INV`, `TAU2_INV` or `TAU3_INV`.
///
/// # Panics
/// Panics if `domain_segment` is greater than 3.
fn get_tau_factor(domain_segment: usize) -> Felt {
    const FACTORS: [Felt; 4] = [ONE, TAU_INV, TAU2_INV, TAU3_INV];
    match FACTORS.get(domain_segment) {
        Some(&factor) => factor,
        None => panic!("invalid domain segment {domain_segment}"),
    }
}
/// Returns a one-hot array of four flags with `ONE` at index `domain_segment` and `ZERO`
/// elsewhere.
///
/// # Panics
/// Panics if `domain_segment` is greater than 3.
fn get_domain_segment_flags(domain_segment: usize) -> [Felt; 4] {
    if domain_segment > 3 {
        panic!("invalid domain segment {domain_segment}");
    }
    let mut flags = [ZERO; 4];
    flags[domain_segment] = ONE;
    flags
}
/// Scales the challenge `alpha` by `x_inv`.
///
/// Returns `(alpha * x_inv, (alpha * x_inv)^2)`.
fn compute_evaluation_points(alpha: QuadFelt, x_inv: Felt) -> (QuadFelt, QuadFelt) {
    let scaled = alpha.mul_base(x_inv);
    let scaled_sq = scaled.square();
    (scaled, scaled_sq)
}
/// Folds four query values into one using two rounds of `fold2`.
///
/// - `values[0]`/`values[2]` are folded with `ev`.
/// - `values[1]`/`values[3]` are folded with `ev * TAU_INV`.
/// - The two intermediate results are then folded with `es`.
///
/// Returns `(folded_value, tmp0, tmp1)`, where `tmp0` and `tmp1` are the intermediate results.
fn fold4(values: [QuadFelt; 4], ev: QuadFelt, es: QuadFelt) -> (QuadFelt, QuadFelt, QuadFelt) {
    let [v0, v1, v2, v3] = values;
    let ev_tau = ev.mul_base(TAU_INV);

    let tmp0 = fold2(v0, v2, ev);
    let tmp1 = fold2(v1, v3, ev_tau);

    (fold2(tmp0, tmp1, es), tmp0, tmp1)
}
/// Combines a pair of values: `((f_x + f_neg_x) + (f_x - f_neg_x) * ep) / 2`.
///
/// The division by two is done by multiplying by `TWO_INV`.
fn fold2(f_x: QuadFelt, f_neg_x: QuadFelt, ep: QuadFelt) -> QuadFelt {
    let sum = f_x + f_neg_x;
    let diff = f_x - f_neg_x;
    let combined = sum + diff * ep;
    combined.mul_base(TWO_INV)
}
#[cfg(test)]
mod tests {
    use super::{
        ExtensionOf, Felt, FieldElement, Operation, Process, QuadFelt, StarkField, TWO, TWO_INV,
    };
    use alloc::vec::Vec;
    use test_utils::rand::{rand_array, rand_value, rand_vector};
    use vm_core::StackInputs;
    use winter_prover::math::{fft, get_power_series_with_offset};
    use winter_utils::transpose_slice;

    /// Checks `fold4` against winterfell's reference FRI folding (`apply_drp`) at one position.
    #[test]
    fn fold4() {
        let blowup = 4_usize;
        let alpha: QuadFelt = rand_value();
        let poly: Vec<QuadFelt> = rand_vector(8);
        let offset = Felt::GENERATOR;
        let twiddles = fft::get_twiddles(poly.len());
        let evaluations = fft::evaluate_poly_with_offset(&poly, &twiddles, offset, blowup);
        let transposed_evaluations = transpose_slice::<QuadFelt, 4>(&evaluations);
        let folded_evaluations =
            winter_fri::folding::apply_drp(&transposed_evaluations, offset, alpha);
        let n = poly.len() * blowup;
        let g = Felt::get_root_of_unity(n.trailing_zeros());
        let domain = get_power_series_with_offset(g, offset, n);
        let pos = 3;
        let x = domain[pos];
        let ev = alpha.mul_base(x.inv());
        let (result, _, _) = super::fold4(transposed_evaluations[pos], ev, ev.square());
        assert_eq!(folded_evaluations[pos], result)
    }

    /// Checks the hard-coded field constants against their definitions.
    #[test]
    fn constants() {
        let tau = Felt::get_root_of_unity(2);
        assert_eq!(super::TAU_INV, tau.inv());
        assert_eq!(super::TAU2_INV, tau.square().inv());
        assert_eq!(super::TAU3_INV, tau.cube().inv());
        assert_eq!(TWO.inv(), TWO_INV);
    }

    /// Runs `FriE2F4` on random operands and checks every output stack slot and helper register.
    #[test]
    fn op_fri_ext2fold4() {
        // The operands are laid out so that stack[i] == inputs[16 - i]; StackInputs appears to
        // put the last input on top of the stack (see how `query_values` is built below).
        let mut inputs = rand_array::<Felt, 17>();
        // Domain segment (stack[9]) must be in 0..=3; use 2.
        inputs[7] = TWO;
        // Make the previous value (stack[12], stack[11]) equal query value 2, so the
        // previous-folding consistency check passes for segment 2.
        inputs[4] = inputs[13];
        inputs[5] = inputs[14];

        let end_ptr = inputs[0];
        let layer_ptr = inputs[1];
        let alpha = QuadFelt::new(inputs[2], inputs[3]);
        let poe = inputs[6];
        let d_seg = inputs[7];
        let f_pos = inputs[8];
        let query_values = [
            QuadFelt::new(inputs[9], inputs[10]),
            QuadFelt::new(inputs[11], inputs[12]),
            QuadFelt::new(inputs[13], inputs[14]),
            QuadFelt::new(inputs[15], inputs[16]),
        ];

        let stack_inputs = StackInputs::new(inputs.to_vec()).expect("inputs length too long");
        let mut process = Process::new_dummy_with_decoder_helpers(stack_inputs);
        process.execute_op(Operation::FriE2F4).unwrap();
        let stack_state = process.stack.trace_state();

        // Recompute the expected results with the same helper functions.
        let f_tau = super::get_tau_factor(d_seg.as_int() as usize);
        let x = poe * f_tau * super::DOMAIN_OFFSET;
        let x_inv = x.inv();
        let (ev, es) = super::compute_evaluation_points(alpha, x_inv);
        let (folded_value, tmp0, tmp1) = super::fold4(query_values, ev, es);

        let tmp0 = tmp0.to_base_elements();
        let tmp1 = tmp1.to_base_elements();
        assert_eq!(stack_state[0], tmp0[1]);
        assert_eq!(stack_state[1], tmp0[0]);
        assert_eq!(stack_state[2], tmp1[1]);
        assert_eq!(stack_state[3], tmp1[0]);

        let ds = super::get_domain_segment_flags(d_seg.as_int() as usize);
        assert_eq!(stack_state[4], ds[3]);
        assert_eq!(stack_state[5], ds[2]);
        assert_eq!(stack_state[6], ds[1]);
        assert_eq!(stack_state[7], ds[0]);

        assert_eq!(stack_state[8], poe.square());
        assert_eq!(stack_state[9], f_tau);
        assert_eq!(stack_state[10], layer_ptr + TWO);
        assert_eq!(stack_state[11], poe.exp(4));
        assert_eq!(stack_state[12], f_pos);

        let folded_value = folded_value.to_base_elements();
        assert_eq!(stack_state[13], folded_value[1]);
        assert_eq!(stack_state[14], folded_value[0]);

        // The element that was below the 16 operands moves up into position 15.
        assert_eq!(stack_state[15], end_ptr);

        let mut expected_helpers = QuadFelt::slice_as_base_elements(&[ev, es]).to_vec();
        expected_helpers.push(x);
        expected_helpers.push(x_inv);
        assert_eq!(expected_helpers, process.decoder.get_user_op_helpers().to_vec());
    }
}