pub mod columns;
use miden_crypto::stark::air::AirBuilder;
use crate::{
Felt, MainCols, MidenAirBuilder,
constraints::{
constants::{F_1, F_128},
decoder::columns::DecoderCols,
op_flags::OpFlags,
utils::{BoolNot, horner_eval_bits},
},
trace::chiplets::hasher::CONTROLLER_ROWS_PER_PERM_FELT,
};
/// Enforces the decoder constraints on the main trace for one pair of adjacent rows.
///
/// `local` is the current row and `next` the following row; both are read through their
/// `decoder` columns, plus `local.stack.get(0)` for the SPLIT/LOOP/REPEAT/END conditions.
/// `op_flags` supplies the operation-selector expressions (`span()`, `end()`, `push()`,
/// `*_next()`, ...) used to gate individual constraints.
///
/// The constraints cover, in order:
/// - the `in_span` flag (zero on the first row, boolean, forced to 1 after SPAN/RESPAN);
/// - opcode bits and the two `extra` columns that cache products of the high opcode bits;
/// - per-operation checks for SPLIT/LOOP, DYN, REPEAT, END and HALT;
/// - group-count, op-group (`h0`) and `op_index` bookkeeping inside a basic block;
/// - batch flags for SPAN/RESPAN and the hasher lanes they force to zero;
/// - block-address rules, the `in_span`/control-flow exclusivity, and the final-row HALT.
///
/// NOTE(review): constraint emission order (and the operand order of `assert_eq`) presumably
/// determines how constraints are folded together by the prover/verifier and any tooling that
/// mirrors them — keep statement order stable unless confirmed otherwise.
pub fn enforce_main<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    // Current-row decoder columns.
    let DecoderCols {
        addr,
        op_bits,
        hasher_state,
        in_span,
        group_count,
        op_index,
        batch_flags,
        extra,
    } = local.decoder;
    // Opcode bits b2 and b3 are not referenced by any constraint in this function.
    let [b0, b1, _, _, b4, b5, b6] = op_bits;
    let [bc0, bc1, bc2] = batch_flags;
    let [e0, e1] = extra;
    // h0 is consumed one opcode at a time by the op-group (h0-shift) constraint further down.
    let h0 = hasher_state[0];
    // Only `is_loop_body` and `is_loop` are used here.
    // NOTE(review): which columns back these flags is defined in `end_block_flags()`.
    let end_flags = local.decoder.end_block_flags();
    let is_loop_body = end_flags.is_loop_body;
    let is_loop = end_flags.is_loop;
    // Next-row decoder columns (batch flags and extra columns are not needed for `next`).
    let DecoderCols {
        addr: addr_next,
        op_bits: op_bits_next,
        hasher_state: hasher_state_next,
        in_span: in_span_next,
        group_count: group_count_next,
        op_index: op_index_next,
        ..
    } = next.decoder;
    let h0_next = hasher_state_next[0];
    // gc - gc': how much the group count drops between this row and the next.
    let delta_group_count: AB::Expr = group_count - group_count_next;
    let is_push = op_flags.push();

    // ---- in_span flag ----
    // Execution does not start inside a basic block.
    builder.when_first_row().assert_zero(in_span);
    builder.assert_bool(in_span);
    // SPAN and RESPAN both put the next row inside a basic block.
    builder.when_transition().when(op_flags.span()).assert_one(in_span_next);
    builder.when_transition().when(op_flags.respan()).assert_one(in_span_next);

    // ---- opcode bits ----
    builder.assert_bools(op_bits);
    // e0 = b6 * (1 - b5) * b4 and e1 = b6 * b5 (assuming `not()` is 1 - x). These cached
    // products presumably keep the degree of downstream op-flag expressions low.
    let e0_expected = b6 * b5.into().not() * b4;
    builder.assert_eq(e0, e0_expected);
    let e1_expected = b6 * b5;
    builder.assert_eq(e1, e1_expected);
    // Given the two equalities above, b6 - e1 - e0 == b6 * (1 - b5) * (1 - b4): opcodes with
    // that high-bit prefix must have b0 = 0.
    builder.when(b6 - e1 - e0).assert_zero(b0);
    {
        // When e1 = 1 (b6 = b5 = 1), the two lowest opcode bits must both be zero.
        let builder = &mut builder.when(e1);
        builder.assert_zero(b0);
        builder.assert_zero(b1);
    }

    // ---- control-flow operations ----
    // SPLIT and LOOP branch on the top stack element, which must be boolean.
    let branch_condition = local.stack.get(0);
    builder
        .when(op_flags.split() + op_flags.loop_op())
        .assert_bool(branch_condition);
    {
        // DYN requires hasher lanes 4..8 to be zero.
        let builder = &mut builder.when(op_flags.dyn_op());
        let hasher_zeros = [hasher_state[4], hasher_state[5], hasher_state[6], hasher_state[7]];
        builder.assert_zeros(hasher_zeros)
    }
    {
        // REPEAT requires the top stack element to be 1 and `is_loop_body` to be set.
        let loop_condition = local.stack.get(0);
        let builder = &mut builder.when(op_flags.repeat());
        builder.assert_one(loop_condition);
        builder.assert_one(is_loop_body);
    }
    // END of a block flagged `is_loop` requires the top stack element to be zero.
    let loop_condition = local.stack.get(0);
    builder.when(op_flags.end()).when(is_loop).assert_zero(loop_condition);
    {
        // END followed by REPEAT: hasher lanes 0..5 are copied unchanged into the next row.
        let gate = builder.is_transition() * op_flags.end() * op_flags.repeat_next();
        let builder = &mut builder.when(gate);
        for i in 0..5 {
            builder.assert_eq(hasher_state_next[i], hasher_state[i]);
        }
    }
    // HALT can only be followed by HALT.
    builder.when_transition().when(op_flags.halt()).assert_one(op_flags.halt_next());

    // ---- group count ----
    {
        // Inside a basic block the group count drops by 0 or 1 per row. A drop on a non-PUSH
        // row additionally requires h0 == 0.
        let gate = builder.is_transition() * in_span;
        let builder = &mut builder.when(gate);
        builder.assert_bool(delta_group_count.clone());
        builder.when(delta_group_count.clone()).when(is_push.not()).assert_zero(h0);
    }
    // SPAN, RESPAN and PUSH each decrement the group count by exactly one.
    builder
        .when_transition()
        .when(op_flags.span() + op_flags.respan() + is_push.clone())
        .assert_one(delta_group_count.clone());
    // If the group count drops, the next operation may be neither END nor RESPAN.
    builder
        .when_transition()
        .when(delta_group_count.clone())
        .assert_zero(op_flags.end_next() + op_flags.respan_next());
    // END requires the group count to be zero.
    builder.when(op_flags.end()).assert_zero(group_count);

    // ---- op group decoding (h0) ----
    {
        let f_span = op_flags.span();
        let f_respan = op_flags.respan();
        // Nonzero only when both rows are inside a basic block and the group count is unchanged.
        let same_group_count: AB::Expr = in_span * in_span_next * delta_group_count.not();
        // NOTE(review): presumably the next row's opcode value recomposed from its 7 bits.
        let op_next: AB::Expr = horner_eval_bits(&op_bits_next);
        // h0 == h0' * F_128 + op'. With F_128 presumably 2^7 (matching the 7 opcode bits), each
        // such step strips the next opcode off the low end of h0.
        let h0_shift = h0 - h0_next * F_128 - op_next;
        let h0_active = f_span + f_respan + is_push.clone() + same_group_count;
        builder.when_transition().when(h0_active).assert_zero(h0_shift);
        // Before the block is closed or re-spanned (next op END or RESPAN), h0 must be zero.
        let end_or_respan_next = op_flags.end_next() + op_flags.respan_next();
        builder.when_transition().when(in_span).when(end_or_respan_next).assert_zero(h0);
    }

    // ---- op index ----
    {
        // A new operation group starts when the group count drops for a reason other than PUSH.
        let new_group: AB::Expr = delta_group_count - is_push;
        // SPAN and RESPAN reset the op index on the next row.
        builder
            .when_transition()
            .when(op_flags.span() + op_flags.respan())
            .assert_zero(op_index_next);
        // Starting a new group inside a basic block also resets it.
        builder
            .when_transition()
            .when(in_span)
            .when(new_group.clone())
            .assert_zero(op_index_next);
        // Otherwise, while staying inside the block, the op index increments by one.
        builder
            .when_transition()
            .when(in_span)
            .when(in_span_next)
            .when(new_group.not())
            .assert_eq(op_index_next, op_index + F_1);
        // Range check on every row: op_index * (op_index - 1) * ... * (op_index - 8) == 0,
        // i.e. op_index is one of 0..=8.
        let mut range_check: AB::Expr = op_index.into();
        for i in 1..=8u64 {
            range_check *= op_index - Felt::new_unchecked(i);
        }
        builder.assert_zero(range_check);
    }

    // ---- batch flags ----
    {
        builder.assert_bools([bc0, bc1, bc2]);
        // Batch-size selectors decoded from (bc0, bc1, bc2), assuming `not()` is 1 - x:
        // 8 groups when bc0 = 1; otherwise 4 -> (0,1,0), 2 -> (0,0,1), 1 -> (0,1,1).
        let groups_8 = bc0;
        let not_bc0 = bc0.into().not();
        let groups_4 = not_bc0.clone() * bc1 * bc2.into().not();
        let groups_2 = not_bc0.clone() * bc1.into().not() * bc2;
        let groups_1 = not_bc0 * bc1 * bc2;
        let groups_1_or_2 = groups_1.clone() + groups_2;
        let groups_1_or_2_or_4 = groups_1_or_2.clone() + groups_4;
        // The sum of the batch-size selectors equals the SPAN/RESPAN flag.
        let span_or_respan = op_flags.span() + op_flags.respan();
        builder.assert_eq(span_or_respan.clone(), groups_1_or_2_or_4.clone() + groups_8);
        // Outside SPAN/RESPAN all batch flags must be zero.
        builder.when(span_or_respan.not()).assert_zero(bc0 + bc1 + bc2);
        {
            // Selectors for 1, 2 or 4 groups: hasher lanes 4..8 must be zero.
            let builder = &mut builder.when(groups_1_or_2_or_4);
            for i in 0..4 {
                builder.assert_zero(hasher_state[4 + i]);
            }
        }
        {
            // Selectors for 1 or 2 groups: hasher lanes 2..4 must also be zero.
            let builder = &mut builder.when(groups_1_or_2);
            for i in 0..2 {
                builder.assert_zero(hasher_state[2 + i]);
            }
        }
        // Selector for 1 group: hasher lane 1 must also be zero.
        builder.when(groups_1).assert_zero(hasher_state[1]);
    }

    // ---- block address and global row rules ----
    // The block address is unchanged from one row to the next while inside a basic block.
    builder.when_transition().when(in_span).assert_eq(addr_next, addr);
    // RESPAN advances the block address by CONTROLLER_ROWS_PER_PERM_FELT.
    builder
        .when_transition()
        .when(op_flags.respan())
        .assert_eq(addr_next, addr + CONTROLLER_ROWS_PER_PERM_FELT);
    // HALT requires the block address to be zero.
    builder.when(op_flags.halt()).assert_zero(addr);
    // in_span + control_flow == 1: assuming `control_flow()` is boolean, every row is either
    // inside a basic block or executing a control-flow operation, never both.
    builder.assert_one(in_span + op_flags.control_flow());
    // The trace must end on a HALT row.
    builder.when_last_row().assert_one(op_flags.halt());
}