use core::{array, borrow::Borrow};
use miden_core::field::PrimeCharacteristicRing;
use super::selectors::ChipletFlags;
use crate::{
AirBuilder, MainCols, MidenAirBuilder,
constraints::{
chiplets::columns::{BitwiseCols, PeriodicCols},
constants::F_16,
utils::horner_eval_bits,
},
};
/// Enforces the bitwise chiplet constraints (AND / XOR over 4-bit limbs of `a` and `b`).
///
/// Every constraint emitted here is multiplied by `flags.is_active`, so it only binds on rows
/// where the bitwise chiplet is active. Two periodic columns narrow some constraints further:
/// `k_first` and `k_transition`. Going by their names, they mark the first row of an operation
/// cycle and the rows that still have a successor in the same cycle. NOTE(review): confirm
/// this against the periodic column definitions.
///
/// Constraints are emitted in this order:
/// 1. `op_flag` is binary and equals the next row's `op_flag` on `k_transition` rows.
/// 2. Every entry of `a_bits` and `b_bits` is binary.
/// 3. On `k_first` rows, `a` / `b` equal `horner_eval_bits` of their own bits, and
///    `prev_output` is zero.
/// 4. On `k_transition` rows, the next row's `a` / `b` equal the current value times 16 plus
///    `horner_eval_bits` of the next row's bits.
/// 5. On `k_transition` rows, `output` equals the next row's `prev_output`.
/// 6. On every active row, `output = prev_output * 16 + r`. Here `r` is `horner_eval_bits` of
///    the per-bit AND of `a_bits`/`b_bits` when `op_flag = 0`, and of the per-bit XOR when
///    `op_flag = 1`.
///
/// NOTE(review): the emission order is left as-is. Verify that nothing downstream depends on
/// constraint ordering before reordering these assertions.
pub fn enforce_bitwise_constraints<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    flags: &ChipletFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    // Periodic selectors for this chiplet.
    let periodic: &PeriodicCols<_> = builder.periodic_values().borrow();
    let (k_first, k_transition) = (periodic.bitwise.k_first, periodic.bitwise.k_transition);

    let cur: &BitwiseCols<AB::Var> = local.bitwise();
    let nxt: &BitwiseCols<AB::Var> = next.bitwise();

    // Everything below is gated by the chiplet's activity flag.
    let gated = &mut builder.when(flags.is_active.clone());

    // (1) Operation selector: binary, and unchanged across a transition.
    gated.assert_bool(cur.op_flag);
    gated.when(k_transition).assert_eq(cur.op_flag, nxt.op_flag);

    // (2) Bit decompositions of both operands are binary.
    gated.assert_bools(cur.a_bits);
    gated.assert_bools(cur.b_bits);

    // (3) First row of a cycle: operands are exactly their bit recomposition, and the
    //     running output starts from zero.
    {
        let on_first = &mut gated.when(k_first);
        on_first.assert_eq(cur.a, horner_eval_bits(&cur.a_bits));
        on_first.assert_eq(cur.b, horner_eval_bits(&cur.b_bits));
        on_first.assert_zero(cur.prev_output);
    }

    // (4) Across a transition, each operand is scaled by 16 and the next row's limb is added.
    {
        let on_transition = &mut gated.when(k_transition);
        let a_accumulated = cur.a * F_16 + horner_eval_bits(&nxt.a_bits);
        on_transition.assert_eq(nxt.a, a_accumulated);
        let b_accumulated = cur.b * F_16 + horner_eval_bits(&nxt.b_bits);
        on_transition.assert_eq(nxt.b, b_accumulated);
    }

    // (5) This row's output seeds the next row's `prev_output`.
    gated.when(k_transition).assert_eq(cur.output, nxt.prev_output);

    // (6) Per-bit AND is `a·b`; per-bit XOR is `a + b - 2·a·b` (exact for binary inputs).
    let and_bits: [AB::Expr; 4] = array::from_fn(|i| cur.a_bits[i] * cur.b_bits[i]);
    let and_value: AB::Expr = horner_eval_bits(&and_bits);
    let xor_bits: [AB::Expr; 4] =
        array::from_fn(|i| cur.a_bits[i] + cur.b_bits[i] - and_bits[i].clone().double());
    let xor_value: AB::Expr = horner_eval_bits(&xor_bits);

    // `op_flag` linearly selects the limb result: 0 keeps AND, 1 swaps in XOR.
    let expected_output =
        cur.prev_output * F_16 + and_value.clone() + cur.op_flag * (xor_value - and_value);
    gated.assert_eq(cur.output, expected_output);
}