use miden_core::{Felt, field::PrimeCharacteristicRing};
use miden_crypto::stark::air::AirBuilder;
use crate::{
MainCols, MidenAirBuilder,
constraints::{
constants::{F_3, F_4, F_8},
ext_field::{QuadFeltAirBuilder, QuadFeltExpr},
op_flags::OpFlags,
},
};
// `tau = 2^48` is a primitive 4th root of unity in the Goldilocks field `p = 2^64 - 2^32 + 1`
// used by `Felt`: since `2^96 = -1 (mod p)`, we get `tau^2 = -1` and `tau^4 = 1`.
// The constants below are its inverse powers, written in hex so that their structure is visible.

/// `tau^-1 = tau^3 = -2^48 (mod p)`.
const TAU_INV: Felt = Felt::new_unchecked(0xFFFE_FFFF_0000_0001);
/// `tau^-2 = tau^2 = -1 (mod p)`.
const TAU2_INV: Felt = Felt::new_unchecked(0xFFFF_FFFF_0000_0000);
/// `tau^-3 = tau = 2^48`.
const TAU3_INV: Felt = Felt::new_unchecked(0x0001_0000_0000_0000);
/// Enforces the main-trace constraints of the operations covered by this module:
/// `CRYPTOSTREAM`, `HORNERBASE`, `HORNEREXT` and `FRIE2F4`.
///
/// Each helper gates its own constraints on the matching operation flag from `op_flags`. The
/// call sequence below fixes the order in which constraints are emitted to `builder`.
pub fn enforce_main<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    enforce_cryptostream_constraints::<AB>(builder, local, next, op_flags);
    // Horner evaluation with base-field coefficients, then with extension-field coefficients.
    enforce_hornerbase_constraints::<AB>(builder, local, next, op_flags);
    enforce_hornerext_constraints::<AB>(builder, local, next, op_flags);
    // FRI fold-by-4 over the quadratic extension.
    enforce_frie2f4_constraints::<AB>(builder, local, next, op_flags);
}
/// Enforces the stack-transition constraints of the `CRYPTOSTREAM` operation.
///
/// Only the upper half of the stack top is constrained here:
/// - positions 12 and 13 each advance by 8;
/// - positions 8..12 and 14..16 are carried over unchanged.
///
/// Positions 0..8 are not constrained by this function (NOTE(review): presumably handled
/// elsewhere — confirm).
///
/// Every constraint reads the next row, so all of them are gated by the product of the
/// transition selector and the operation flag.
fn enforce_cryptostream_constraints<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    let selector = builder.is_transition() * op_flags.cryptostream();
    let gated = &mut builder.when(selector);
    let (cur, nxt) = (&local.stack.top, &next.stack.top);

    // Constraints are emitted in ascending position order (8, 9, ..., 15).
    for idx in 8..16 {
        if matches!(idx, 12 | 13) {
            gated.assert_eq(nxt[idx], cur[idx].into() + F_8);
        } else {
            gated.assert_eq(nxt[idx], cur[idx]);
        }
    }
}
/// Enforces the constraints of the `HORNERBASE` operation.
///
/// The operation folds the eight base-field values `s[0..8]` into the quadratic-extension
/// accumulator `acc = (s[14], s[15])` at the point `alpha = (h[0], h[1])`, where `h` are the
/// decoder's user-op helper registers. Two intermediate accumulators are witnessed in helpers
/// (presumably to bound the constraint degree):
///
/// - `tmp0 = (h[4], h[5]) = acc * alpha^2 + s[0] * alpha + s[1]`
/// - `tmp1 = (h[2], h[3]) = tmp0 * alpha^3 + s[2] * alpha^2 + s[3] * alpha + s[4]`
/// - `acc' = tmp1 * alpha^3 + s[5] * alpha^2 + s[6] * alpha + s[7]`
///
/// Together these are eight Horner steps. Two kinds of constraint read the next row and are
/// therefore also gated on transition rows: the `acc'` constraint, and the copy constraints
/// for stack positions `0..14`.
fn enforce_hornerbase_constraints<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    let gated = &mut builder.when(op_flags.hornerbase());
    let cur = &local.stack.top;
    let nxt = &next.stack.top;
    let h = local.decoder.user_op_helpers();

    // Stack positions 0..14 are copied verbatim to the next row.
    {
        let copy = &mut gated.when_transition();
        (0..14).for_each(|pos| copy.assert_eq(nxt[pos], cur[pos]));
    }

    // Evaluation point and its powers in the quadratic extension.
    let alpha: QuadFeltExpr<AB::Expr> = QuadFeltExpr::new(h[0], h[1]);
    let alpha2 = alpha.clone().square();
    let alpha3 = alpha2.clone() * alpha.clone();

    // Base-field coefficient at stack position `pos`.
    let base = |pos: usize| -> AB::Expr { cur[pos].into() };

    let tmp1 = QuadFeltExpr::new(h[2], h[3]);
    let tmp0 = QuadFeltExpr::new(h[4], h[5]);
    let acc = QuadFeltExpr::new(cur[14], cur[15]);
    let acc_next = QuadFeltExpr::new(nxt[14], nxt[15]);

    // acc -> tmp0: two Horner steps.
    let expect_tmp0 = acc * alpha2.clone() + alpha.clone() * base(0) + base(1);
    // tmp0 -> tmp1: three Horner steps.
    let expect_tmp1 = tmp0.clone() * alpha3.clone()
        + alpha2.clone() * base(2)
        + alpha.clone() * base(3)
        + base(4);
    // tmp1 -> acc': three Horner steps.
    let expect_acc_next = tmp1.clone() * alpha3 + alpha2 * base(5) + alpha * base(6) + base(7);

    gated.assert_eq_quad(tmp0, expect_tmp0);
    gated.assert_eq_quad(tmp1, expect_tmp1);
    gated.when_transition().assert_eq_quad(acc_next, expect_acc_next);
}
/// Enforces the constraints of the `HORNEREXT` operation.
///
/// The operation folds four quadratic-extension coefficients into the accumulator
/// `acc = (s[14], s[15])` at the point `alpha = (h[0], h[1])`, where `h` are the decoder's
/// user-op helper registers. The coefficients are `c0 = s[0..2]`, `c1 = s[2..4]`,
/// `c2 = s[4..6]` and `c3 = s[6..8]`. One intermediate accumulator is witnessed in
/// `h[4..6]`:
///
/// - `tmp  = acc * alpha^2 + c0 * alpha + c1`
/// - `acc' = tmp * alpha^2 + c2 * alpha + c3`
///
/// Two kinds of constraint read the next row and are also gated on transition rows: the
/// `acc'` constraint, and the copy constraints for stack positions `0..14`.
fn enforce_hornerext_constraints<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    let gated = &mut builder.when(op_flags.hornerext());
    let cur = &local.stack.top;
    let nxt = &next.stack.top;
    let h = local.decoder.user_op_helpers();

    // Stack positions 0..14 are copied verbatim to the next row.
    {
        let copy = &mut gated.when_transition();
        (0..14).for_each(|pos| copy.assert_eq(nxt[pos], cur[pos]));
    }

    let alpha: QuadFeltExpr<AB::Expr> = QuadFeltExpr::new(h[0], h[1]);
    let alpha2 = alpha.clone().square();
    let tmp = QuadFeltExpr::new(h[4], h[5]);
    let acc = QuadFeltExpr::new(cur[14], cur[15]);
    let acc_next = QuadFeltExpr::new(nxt[14], nxt[15]);

    // Extension-field coefficient whose two limbs start at stack position `pos`.
    let coeff =
        |pos: usize| -> QuadFeltExpr<AB::Expr> { QuadFeltExpr::new(cur[pos], cur[pos + 1]) };

    // acc -> tmp: two Horner steps over s[0..2] and s[2..4].
    let expect_tmp = acc * alpha2.clone() + alpha.clone() * coeff(0) + coeff(2);
    // tmp -> acc': two Horner steps over s[4..6] and s[6..8].
    let expect_acc_next = tmp.clone() * alpha2 + alpha * coeff(4) + coeff(6);

    gated.assert_eq_quad(tmp, expect_tmp);
    gated.when_transition().assert_eq_quad(acc_next, expect_acc_next);
}
/// Enforces the constraints of the `FRIE2F4` operation: one fold-by-4 step of FRI over the
/// quadratic extension field, plus the bookkeeping needed to chain into the next layer.
///
/// Current row (`s = local.stack.top`):
/// - `s[0..8]`: four extension-field values, read as `q0 = s[0..2]`, `q2 = s[2..4]`,
///   `q1 = s[4..6]`, `q3 = s[6..8]` (note the interleaved order).
/// - `s[8]`: folded position; `s[9]`: tree index; `s[10]`: `poe`.
/// - `s[11..13]`: evaluation carried over from the previous layer.
/// - `s[13..15]`: folding challenge `alpha`; `s[15]`: layer pointer.
///
/// Next row (`s' = next.stack.top`):
/// - `s'[0..2]`, `s'[2..4]`: the two intermediate fold-by-2 results.
/// - `s'[4..8]`: one-hot segment flags `f0..f3`.
/// - `s'[8]`: `poe^2`; `s'[9]`: tau factor; `s'[10]`: layer pointer + 8; `s'[11]`: `poe^4`.
/// - `s'[12]`: folded position (copied); `s'[13..15]`: the fold-by-4 result.
/// - `s'[15]` is not constrained by this function.
///
/// Decoder helpers (`h`):
/// - `h[0..2]` = `alpha / x` and `h[2..4]` = `(alpha / x)^2`;
/// - `h[4]` = domain point `x = poe * tau_factor` and `h[5]` = `x^-1`.
///
/// Flag `f_k` has three effects:
/// - it selects `q_k` as the value that must equal the previous-layer evaluation;
/// - it sets the tau factor to `tau^-k`;
/// - it contributes segment index `0, 2, 1, 3` for `k = 0, 1, 2, 3`. This is the 2-bit
///   reversal of `k`, matching the stack order `q0, q2, q1, q3`.
///
/// NOTE(review): unlike the other operations in this file, these constraints read the next row
/// without an extra `is_transition` gate. This presumably relies on `FRIE2F4` never being the
/// operation of the last trace row — confirm before reusing this pattern.
fn enforce_frie2f4_constraints<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    // All constraints below are gated only by the FRIE2F4 operation flag.
    let builder = &mut builder.when(op_flags.frie2f4());
    let s = &local.stack.top;
    let s_next = &next.stack.top;
    let helpers = local.decoder.user_op_helpers();
    // Coset evaluations, stored on the stack in the order q0, q2, q1, q3.
    let q0 = QuadFeltExpr::new(s[0], s[1]);
    let q2 = QuadFeltExpr::new(s[2], s[3]);
    let q1 = QuadFeltExpr::new(s[4], s[5]);
    let q3 = QuadFeltExpr::new(s[6], s[7]);
    let folded_pos = s[8];
    let tree_index = s[9];
    let poe = s[10];
    let prev_eval = QuadFeltExpr::new(s[11], s[12]);
    let alpha = QuadFeltExpr::new(s[13], s[14]);
    let layer_ptr = s[15];
    // Segment flags are witnessed in the next row and must be one-hot.
    let seg_flag_0 = s_next[4];
    let seg_flag_1 = s_next[5];
    let seg_flag_2 = s_next[6];
    let seg_flag_3 = s_next[7];
    builder.assert_bools([seg_flag_0, seg_flag_1, seg_flag_2, seg_flag_3]);
    builder.assert_one(seg_flag_0 + seg_flag_1 + seg_flag_2 + seg_flag_3);
    let folded_pos_next = s_next[12];
    // Flags map to segment indices 0, 2, 1, 3 (2-bit reversal of the flag number).
    let segment_index = seg_flag_1.into().double() + seg_flag_2 + seg_flag_3 * F_3;
    builder.assert_eq(tree_index, folded_pos_next * F_4 + segment_index);
    // Tau factor is tau^-k for the active flag f_k.
    let tau_factor = s_next[9];
    let expected_tau =
        seg_flag_0 + seg_flag_1 * TAU_INV + seg_flag_2 * TAU2_INV + seg_flag_3 * TAU3_INV;
    builder.assert_eq(tau_factor, expected_tau);
    // Domain point x = poe * tau^-k, with its inverse witnessed (which also forces x != 0).
    let domain_point = helpers[4];
    let domain_point_inv = helpers[5];
    builder.assert_eq(domain_point, poe * tau_factor);
    builder.assert_one(domain_point * domain_point_inv);
    // First-round folding point alpha / x and its square.
    let eval_point: QuadFeltExpr<AB::Expr> = QuadFeltExpr::new(helpers[0], helpers[1]);
    builder.assert_eq_quad(eval_point.clone(), alpha * domain_point_inv.into());
    let eval_point_sq: QuadFeltExpr<AB::Expr> = QuadFeltExpr::new(helpers[2], helpers[3]);
    builder.assert_eq_quad(eval_point_sq.clone(), eval_point.clone().square());
    // Doubled FRI fold-by-2: for evaluations `a`, `b` at `y` and `-y` and `ep = alpha / y`,
    // `2 * fold = (a + b) + (a - b) * ep`. Comparing doubled values avoids dividing by 2.
    let fold2_doubled = |a: QuadFeltExpr<AB::Expr>,
                         b: QuadFeltExpr<AB::Expr>,
                         ep: QuadFeltExpr<AB::Expr>|
     -> QuadFeltExpr<AB::Expr> { (a.clone() + b.clone()) + (a - b) * ep };
    let fold_mid0 = QuadFeltExpr::new(s_next[0], s_next[1]);
    let fold_mid1 = QuadFeltExpr::new(s_next[2], s_next[3]);
    let fold_result = QuadFeltExpr::new(s_next[13], s_next[14]);
    // Fold (q0, q2) at the pair (x, -x).
    builder.assert_eq_quad(fold_mid0.clone().double(), fold2_doubled(q0, q2, eval_point.clone()));
    // Fold (q1, q3) at the pair (x * tau, -x * tau), i.e. with alpha / (x * tau).
    let eval_point_coset = eval_point * AB::Expr::from(TAU_INV);
    builder.assert_eq_quad(fold_mid1.clone().double(), fold2_doubled(q1, q3, eval_point_coset));
    // Fold the two intermediate results at (x^2, -x^2) with (alpha / x)^2 = alpha^2 / x^2. The
    // combined fold uses challenge powers 1, alpha, alpha^2, alpha^3.
    builder
        .assert_eq_quad(fold_result.double(), fold2_doubled(fold_mid0, fold_mid1, eval_point_sq));
    // The value chosen by the active flag f_k (q0, q1, q2 or q3) must equal the previous
    // layer's evaluation.
    let selected_re = s[0] * seg_flag_0 + s[4] * seg_flag_1 + s[2] * seg_flag_2 + s[6] * seg_flag_3;
    let selected_im = s[1] * seg_flag_0 + s[5] * seg_flag_1 + s[3] * seg_flag_2 + s[7] * seg_flag_3;
    builder.assert_eq_quad(prev_eval, QuadFeltExpr::new(selected_re, selected_im));
    // Powers of poe exposed in the next row.
    let poe_sq = s_next[8];
    let poe_fourth = s_next[11];
    builder.assert_eq(poe_sq, poe * poe);
    builder.assert_eq(poe_fourth, poe_sq * poe_sq);
    // Layer pointer advances by 8; the folded position is carried over unchanged.
    let layer_ptr_next = s_next[10];
    builder.assert_eq(layer_ptr_next, layer_ptr + F_8);
    builder.assert_eq(folded_pos_next, folded_pos);
}