use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::{LiftedAirBuilder, WindowAccess};
use crate::{
Felt, MainTraceRow,
constraints::{
bus::indices::B_HASH_KERNEL,
chiplets::hasher::{flags, periodic},
op_flags::OpFlags,
tagging::{
TagGroup, TaggingAirBuilderExt, ids::TAG_HASH_KERNEL_BUS_BASE, tagged_assert_zero_ext,
},
},
trace::{
CHIPLETS_OFFSET, Challenges, LOG_PRECOMPILE_LABEL,
chiplets::{
HASHER_NODE_INDEX_COL_IDX, HASHER_SELECTOR_COL_RANGE, HASHER_STATE_COL_RANGE,
NUM_ACE_SELECTORS,
ace::{
ACE_INSTRUCTION_ID1_OFFSET, ACE_INSTRUCTION_ID2_OFFSET, CLK_IDX, CTX_IDX,
EVAL_OP_IDX, ID_1_IDX, ID_2_IDX, PTR_IDX, SELECTOR_BLOCK_IDX, V_0_0_IDX, V_0_1_IDX,
V_1_0_IDX, V_1_1_IDX,
},
memory::{MEMORY_READ_ELEMENT_LABEL, MEMORY_READ_WORD_LABEL},
},
decoder::USER_OP_HELPERS_OFFSET,
log_precompile::{HELPER_CAP_PREV_RANGE, STACK_CAP_NEXT_RANGE},
},
};
/// Index of the first hasher selector column inside `MainTraceRow::chiplets`
/// (the absolute range start rebased by `CHIPLETS_OFFSET`).
const S_START: usize = HASHER_SELECTOR_COL_RANGE.start - CHIPLETS_OFFSET;
/// Index of the first hasher state column inside `MainTraceRow::chiplets`.
const H_START: usize = HASHER_STATE_COL_RANGE.start - CHIPLETS_OFFSET;
/// Index of the hasher node-index column inside `MainTraceRow::chiplets`.
const IDX_COL: usize = HASHER_NODE_INDEX_COL_IDX - CHIPLETS_OFFSET;
/// Base tag id used for the hash-kernel bus constraint group.
const HASH_KERNEL_BUS_ID: usize = TAG_HASH_KERNEL_BUS_BASE;
/// Name attached to the hash-kernel bus transition constraint.
const HASH_KERNEL_BUS_NAMESPACE: &str = "chiplets.bus.hash_kernel.transition";
/// Names of the constraints in the group; this file emits a single tagged constraint.
const HASH_KERNEL_BUS_NAMES: [&str; 1] = [HASH_KERNEL_BUS_NAMESPACE; 1];
/// Tag group handed to `tagged_assert_zero_ext` when the bus constraint is emitted.
const HASH_KERNEL_BUS_TAGS: TagGroup = TagGroup {
    base: HASH_KERNEL_BUS_ID,
    names: &HASH_KERNEL_BUS_NAMES,
};
pub fn enforce_hash_kernel_constraint<AB>(
builder: &mut AB,
local: &MainTraceRow<AB::Var>,
next: &MainTraceRow<AB::Var>,
op_flags: &OpFlags<AB::Expr>,
challenges: &Challenges<AB::ExprEF>,
) where
AB: TaggingAirBuilderExt<F = Felt>,
{
let (p_local, p_next) = {
let aux = builder.permutation();
let aux_local = aux.current_slice();
let aux_next = aux.next_slice();
(aux_local[B_HASH_KERNEL], aux_next[B_HASH_KERNEL])
};
let one = AB::Expr::ONE;
let one_ef = AB::ExprEF::ONE;
let (cycle_row_0, cycle_row_31) = {
let p = builder.periodic_values();
let cycle_row_0: AB::Expr = p[periodic::P_CYCLE_ROW_0].into();
let cycle_row_31: AB::Expr = p[periodic::P_CYCLE_ROW_31].into();
(cycle_row_0, cycle_row_31)
};
let chiplet_selector: AB::Expr = local.chiplets[0].clone().into();
let is_hasher: AB::Expr = one.clone() - chiplet_selector.clone();
let s0: AB::Expr = local.chiplets[S_START].clone().into();
let s1: AB::Expr = local.chiplets[S_START + 1].clone().into();
let s2: AB::Expr = local.chiplets[S_START + 2].clone().into();
let node_index: AB::Expr = local.chiplets[IDX_COL].clone().into();
let node_index_next: AB::Expr = next.chiplets[IDX_COL].clone().into();
let h: [AB::Expr; 12] = core::array::from_fn(|i| local.chiplets[H_START + i].clone().into());
let h_next: [AB::Expr; 12] =
core::array::from_fn(|i| next.chiplets[H_START + i].clone().into());
let f_mu: AB::Expr =
is_hasher.clone() * flags::f_mu(cycle_row_0.clone(), s0.clone(), s1.clone(), s2.clone());
let f_mua: AB::Expr =
is_hasher.clone() * flags::f_mua(cycle_row_31.clone(), s0.clone(), s1.clone(), s2.clone());
let f_mv: AB::Expr =
is_hasher.clone() * flags::f_mv(cycle_row_0.clone(), s0.clone(), s1.clone(), s2.clone());
let f_mva: AB::Expr = is_hasher.clone() * flags::f_mva(cycle_row_31.clone(), s0, s1, s2);
let b: AB::Expr = node_index.clone() - node_index_next.clone().double();
let is_b_zero = one.clone() - b.clone();
let is_b_one = b;
let v_sibling_curr = compute_sibling_b0::<AB>(challenges, &node_index, &h) * is_b_zero.clone()
+ compute_sibling_b1::<AB>(challenges, &node_index, &h) * is_b_one.clone();
let v_sibling_next = compute_sibling_b0::<AB>(challenges, &node_index, &h_next) * is_b_zero
+ compute_sibling_b1::<AB>(challenges, &node_index, &h_next) * is_b_one;
let s3: AB::Expr = local.chiplets[3].clone().into();
let chiplet_s1: AB::Expr = local.chiplets[1].clone().into();
let chiplet_s2: AB::Expr = local.chiplets[2].clone().into();
let is_ace_row: AB::Expr =
chiplet_selector.clone() * chiplet_s1.clone() * chiplet_s2.clone() * (one.clone() - s3);
let block_selector: AB::Expr =
local.chiplets[NUM_ACE_SELECTORS + SELECTOR_BLOCK_IDX].clone().into();
let f_ace_read: AB::Expr = is_ace_row.clone() * (one.clone() - block_selector.clone());
let f_ace_eval: AB::Expr = is_ace_row * block_selector;
let ace_clk: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + CLK_IDX].clone().into();
let ace_ctx: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + CTX_IDX].clone().into();
let ace_ptr: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + PTR_IDX].clone().into();
let v_ace_word = {
let v0_0: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + V_0_0_IDX].clone().into();
let v0_1: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + V_0_1_IDX].clone().into();
let v1_0: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + V_1_0_IDX].clone().into();
let v1_1: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + V_1_1_IDX].clone().into();
let label: AB::Expr = AB::Expr::from(Felt::from_u8(MEMORY_READ_WORD_LABEL));
challenges.encode([
label,
ace_ctx.clone(),
ace_ptr.clone(),
ace_clk.clone(),
v0_0,
v0_1,
v1_0,
v1_1,
])
};
let v_ace_element = {
let id_1: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + ID_1_IDX].clone().into();
let id_2: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + ID_2_IDX].clone().into();
let eval_op: AB::Expr = local.chiplets[NUM_ACE_SELECTORS + EVAL_OP_IDX].clone().into();
let offset1: AB::Expr = AB::Expr::from(ACE_INSTRUCTION_ID1_OFFSET);
let offset2: AB::Expr = AB::Expr::from(ACE_INSTRUCTION_ID2_OFFSET);
let element = id_1 + id_2 * offset1 + (eval_op + one.clone()) * offset2;
let label: AB::Expr = AB::Expr::from(Felt::from_u8(MEMORY_READ_ELEMENT_LABEL));
challenges.encode([label, ace_ctx, ace_ptr, ace_clk, element])
};
let f_logprecompile: AB::Expr = op_flags.log_precompile();
let cap_prev: [AB::Expr; 4] = core::array::from_fn(|i| {
local.decoder[USER_OP_HELPERS_OFFSET + HELPER_CAP_PREV_RANGE.start + i]
.clone()
.into()
});
let cap_next: [AB::Expr; 4] =
core::array::from_fn(|i| next.stack[STACK_CAP_NEXT_RANGE.start + i].clone().into());
let log_label: AB::Expr = AB::Expr::from(Felt::from_u8(LOG_PRECOMPILE_LABEL));
let v_cap_prev = challenges.encode([
log_label.clone(),
cap_prev[0].clone(),
cap_prev[1].clone(),
cap_prev[2].clone(),
cap_prev[3].clone(),
]);
let v_cap_next = challenges.encode([
log_label,
cap_next[0].clone(),
cap_next[1].clone(),
cap_next[2].clone(),
cap_next[3].clone(),
]);
let request_flag_sum = f_mu.clone()
+ f_mua.clone()
+ f_ace_read.clone()
+ f_ace_eval.clone()
+ f_logprecompile.clone();
let requests: AB::ExprEF = v_sibling_curr.clone() * f_mu.clone()
+ v_sibling_next.clone() * f_mua.clone()
+ v_ace_word * f_ace_read
+ v_ace_element * f_ace_eval
+ v_cap_prev * f_logprecompile.clone()
+ (one_ef.clone() - request_flag_sum);
let response_flag_sum = f_mv.clone() + f_mva.clone() + f_logprecompile.clone();
let responses: AB::ExprEF = v_sibling_curr * f_mv
+ v_sibling_next * f_mva
+ v_cap_next * f_logprecompile
+ (one_ef - response_flag_sum);
let p_local_ef: AB::ExprEF = p_local.into();
let p_next_ef: AB::ExprEF = p_next.into();
let mut idx = 0;
tagged_assert_zero_ext(
builder,
&HASH_KERNEL_BUS_TAGS,
&mut idx,
p_next_ef * requests - p_local_ef * responses,
);
}
/// Layout passed to `Challenges::encode_sparse` by `compute_sibling_b0`: one entry for the
/// node index followed by four entries for `h[4..8]`.
/// NOTE(review): presumably these are the challenge positions of each value - confirm
/// against `encode_sparse`.
const SIBLING_B0_LAYOUT: [usize; 5] = [2, 7, 8, 9, 10];
/// Layout passed to `Challenges::encode_sparse` by `compute_sibling_b1`: one entry for the
/// node index followed by four entries for `h[0..4]`.
const SIBLING_B1_LAYOUT: [usize; 5] = [2, 3, 4, 5, 6];
/// Encodes the sibling message taken from `h[4..8]`.
///
/// The caller weights this value by `1 - b`, i.e. it is the encoding used when the
/// node-index bit is zero. The message consists of `node_index` followed by `h[4]`, `h[5]`,
/// `h[6]`, `h[7]`, encoded sparsely with `SIBLING_B0_LAYOUT`.
fn compute_sibling_b0<AB>(
    challenges: &Challenges<AB::ExprEF>,
    node_index: &AB::Expr,
    h: &[AB::Expr; 12],
) -> AB::ExprEF
where
    AB: LiftedAirBuilder<F = Felt>,
{
    // Skip the first four state elements and bind the next four by reference.
    let [_, _, _, _, w0, w1, w2, w3, ..] = h;
    let message = [node_index.clone(), w0.clone(), w1.clone(), w2.clone(), w3.clone()];
    challenges.encode_sparse(SIBLING_B0_LAYOUT, message)
}
/// Encodes the sibling message taken from `h[0..4]`.
///
/// The caller weights this value by `b`, i.e. it is the encoding used when the node-index
/// bit is one. The message consists of `node_index` followed by `h[0]`, `h[1]`, `h[2]`,
/// `h[3]`, encoded sparsely with `SIBLING_B1_LAYOUT`.
fn compute_sibling_b1<AB>(
    challenges: &Challenges<AB::ExprEF>,
    node_index: &AB::Expr,
    h: &[AB::Expr; 12],
) -> AB::ExprEF
where
    AB: LiftedAirBuilder<F = Felt>,
{
    // Bind the first four state elements by reference; the rest are unused here.
    let [w0, w1, w2, w3, ..] = h;
    let message = [node_index.clone(), w0.clone(), w1.clone(), w2.clone(), w3.clone()];
    challenges.encode_sparse(SIBLING_B1_LAYOUT, message)
}