use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::{ExtensionBuilder, LiftedAirBuilder, WindowAccess};
use crate::{
MainTraceRow,
constraints::{
bus::indices::P1_BLOCK_STACK,
op_flags::OpFlags,
tagging::{TaggingAirBuilderExt, ids::TAG_DECODER_BUS_BASE},
},
trace::Challenges,
};
/// Base tag id for the decoder bus constraints (alias of `TAG_DECODER_BUS_BASE`).
///
/// The three constraints in this file are tagged with this id plus 0, 1 and 2.
const DECODER_BUS_BASE_ID: usize = TAG_DECODER_BUS_BASE;
/// Tag names of the decoder bus transition constraints.
///
/// Entry `i` is used together with tag id `DECODER_BUS_BASE_ID + i` (p1, p2, p3 in that order).
const DECODER_BUS_NAMES: [&str; 3] = [
"decoder.bus.p1.transition",
"decoder.bus.p2.transition",
"decoder.bus.p3.transition",
];
/// Weights applied to the seven opcode bit columns when rebuilding the opcode value.
///
/// Entry `i` is `2^i`, so the first bit column is the least significant one.
const OP_BIT_WEIGHTS: [u16; 7] = [1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6];
/// Message encoders used by the block stack bus constraint (`p1`).
///
/// Both encoders hand a fixed-size list of base-field expressions to `Challenges::encode`;
/// they only differ in how many elements the message carries.
struct BlockStackEncoders<'a, AB: LiftedAirBuilder> {
    challenges: &'a Challenges<AB::ExprEF>,
}

impl<'a, AB: LiftedAirBuilder> BlockStackEncoders<'a, AB> {
    /// Creates an encoder set that borrows `challenges`.
    fn new(challenges: &'a Challenges<AB::ExprEF>) -> Self {
        Self { challenges }
    }

    /// Encodes the three-element message `[block_id, parent_id, is_loop]`.
    fn simple(&self, block_id: &AB::Expr, parent_id: &AB::Expr, is_loop: &AB::Expr) -> AB::ExprEF {
        let message = [block_id, parent_id, is_loop].map(|elem| elem.clone());
        self.challenges.encode(message)
    }

    /// Encodes the ten-element message
    /// `[block_id, parent_id, is_loop, ctx, depth, overflow, fh[0], fh[1], fh[2], fh[3]]`.
    fn full(
        &self,
        block_id: &AB::Expr,
        parent_id: &AB::Expr,
        is_loop: &AB::Expr,
        ctx: &AB::Expr,
        depth: &AB::Expr,
        overflow: &AB::Expr,
        fh: &[AB::Expr; 4],
    ) -> AB::ExprEF {
        // Borrow the four hash elements individually so they can be listed after the scalars.
        let [fh0, fh1, fh2, fh3] = fh;
        let message = [block_id, parent_id, is_loop, ctx, depth, overflow, fh0, fh1, fh2, fh3]
            .map(|elem| elem.clone());
        self.challenges.encode(message)
    }
}
/// Message encoder used by the block hash bus constraint (`p2`).
struct BlockHashEncoder<'a, AB: LiftedAirBuilder> {
    challenges: &'a Challenges<AB::ExprEF>,
}

impl<'a, AB: LiftedAirBuilder> BlockHashEncoder<'a, AB> {
    /// Creates an encoder that borrows `challenges`.
    fn new(challenges: &'a Challenges<AB::ExprEF>) -> Self {
        Self { challenges }
    }

    /// Encodes the seven-element message
    /// `[parent, hash[0], hash[1], hash[2], hash[3], first_child, loop_body]`
    /// via `Challenges::encode`.
    fn encode(
        &self,
        parent: &AB::Expr,
        hash: [&AB::Expr; 4],
        first_child: &AB::Expr,
        loop_body: &AB::Expr,
    ) -> AB::ExprEF {
        let [hash0, hash1, hash2, hash3] = hash;
        let message = [parent, hash0, hash1, hash2, hash3, first_child, loop_body]
            .map(|elem| elem.clone());
        self.challenges.encode(message)
    }
}
/// Message encoder used by the op group bus constraint (`p3`).
struct OpGroupEncoder<'a, AB: LiftedAirBuilder> {
    challenges: &'a Challenges<AB::ExprEF>,
}

impl<'a, AB: LiftedAirBuilder> OpGroupEncoder<'a, AB> {
    /// Creates an encoder that borrows `challenges`.
    fn new(challenges: &'a Challenges<AB::ExprEF>) -> Self {
        Self { challenges }
    }

    /// Encodes the three-element message `[block_id, group_count, value]`
    /// via `Challenges::encode`.
    fn encode(&self, block_id: &AB::Expr, group_count: &AB::Expr, value: &AB::Expr) -> AB::ExprEF {
        let message = [block_id, group_count, value].map(|elem| elem.clone());
        self.challenges.encode(message)
    }
}
/// Indices into a row's `decoder` columns that the bus constraints read.
mod decoder_cols {
    /// Block address column.
    pub const ADDR: usize = 0;
    /// First of the hasher-state columns; `h_i` lives at `HASHER_STATE_OFFSET + i`.
    pub const HASHER_STATE_OFFSET: usize = 8;
    /// Hasher-state column 5, read as the "is loop" flag by the END message.
    pub const IS_LOOP_FLAG: usize = HASHER_STATE_OFFSET + 5;
    /// Hasher-state column 6, read as the "is call" flag by the END message.
    pub const IS_CALL_FLAG: usize = IS_LOOP_FLAG + 1;
    /// Hasher-state column 7, read as the "is syscall" flag by the END message.
    pub const IS_SYSCALL_FLAG: usize = IS_CALL_FLAG + 1;
}
/// Indices into a row's `stack` columns that the bus constraints read.
mod stack_cols {
    /// Column read as `b0` by the block stack constraint.
    pub const B0: usize = 16;
    /// Column read as `b1`; it directly follows `B0`.
    pub const B1: usize = B0 + 1;
}
mod op_group_cols {
const HASHER_STATE_END: usize = super::decoder_cols::HASHER_STATE_OFFSET + 8;
pub const IS_IN_SPAN: usize = HASHER_STATE_END;
pub const GROUP_COUNT: usize = IS_IN_SPAN + 1;
const OP_INDEX: usize = GROUP_COUNT + 1;
const BATCH_FLAGS_OFFSET: usize = OP_INDEX + 1;
pub const BATCH_FLAG_0: usize = BATCH_FLAGS_OFFSET;
pub const BATCH_FLAG_1: usize = BATCH_FLAGS_OFFSET + 1;
pub const BATCH_FLAG_2: usize = BATCH_FLAGS_OFFSET + 2;
}
/// Rebuilds the opcode value of `row` as a base-field expression.
///
/// Decoder column `1 + i` is multiplied by `OP_BIT_WEIGHTS[i]` and the products are summed,
/// starting from zero and adding the terms in increasing `i`.
fn opcode_from_row<AB>(row: &MainTraceRow<AB::Var>) -> AB::Expr
where
    AB: LiftedAirBuilder,
{
    let mut opcode = AB::Expr::ZERO;
    for (bit_idx, &weight) in OP_BIT_WEIGHTS.iter().enumerate() {
        // The opcode bit columns start at decoder column 1 (column 0 is the block address).
        let bit: AB::Expr = row.decoder[1 + bit_idx].clone().into();
        opcode = opcode + bit * AB::Expr::from_u16(weight);
    }
    opcode
}
/// Enforces all three decoder bus transition constraints.
///
/// Calls, in this order, the block stack (`p1`), block hash (`p2`) and op group (`p3`)
/// constraint functions, forwarding every argument unchanged.
pub fn enforce_bus<AB: LiftedAirBuilder>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
    challenges: &Challenges<AB::ExprEF>,
) {
    // p1: block stack table.
    enforce_block_stack_table_constraint::<AB>(builder, local, next, op_flags, challenges);
    // p2: block hash table.
    enforce_block_hash_table_constraint::<AB>(builder, local, next, op_flags, challenges);
    // p3: op group table.
    enforce_op_group_table_constraint::<AB>(builder, local, next, op_flags, challenges);
}
/// Enforces the transition constraint of the `P1_BLOCK_STACK` auxiliary column.
///
/// On transition rows the function asserts
/// `p1_next * request - p1_local * response = 0`, where
///
/// - `response` sums the messages gated by the JOIN, SPLIT, SPAN, DYN, LOOP, RESPAN, CALL,
///   SYSCALL and DYNCALL flags (named "insertions" in the code), and
/// - `request` sums the messages gated by the END and RESPAN flags (named "removals").
///
/// Each side also adds `1 - (sum of its flags)`. Assuming the op flags are binary and
/// mutually exclusive (not established in this function), a side therefore evaluates to the
/// single active message, or to 1 when none of its flags is set.
///
/// # Arguments
/// - `builder`: AIR builder; supplies the auxiliary column window and receives the constraint.
/// - `local`, `next`: main trace rows of the current and the next step.
/// - `op_flags`: operation flags (presumably derived from `local` — confirm at the call site).
/// - `challenges`: challenges used to encode every message.
pub fn enforce_block_stack_table_constraint<AB>(
builder: &mut AB,
local: &MainTraceRow<AB::Var>,
next: &MainTraceRow<AB::Var>,
op_flags: &OpFlags<AB::Expr>,
challenges: &Challenges<AB::ExprEF>,
) where
AB: LiftedAirBuilder,
{
// Current- and next-row values of the p1 auxiliary column. Scoped so that the window
// returned by `builder.permutation()` is dropped before `builder` is used again below.
let (p1_local, p1_next) = {
let aux = builder.permutation();
let aux_local = aux.current_slice();
let aux_next = aux.next_slice();
(aux_local[P1_BLOCK_STACK], aux_next[P1_BLOCK_STACK])
};
let one = AB::Expr::ONE;
let zero = AB::Expr::ZERO;
let one_ef = AB::ExprEF::ONE;
// Lifts a trace variable into an expression.
let to_expr = |v: AB::Var| -> AB::Expr { v.into() };
// --- Values read from the current and next rows -------------------------------------
let addr_local = to_expr(local.decoder[decoder_cols::ADDR].clone());
let addr_next = to_expr(next.decoder[decoder_cols::ADDR].clone());
// Hasher-state column 1 of the NEXT row; used as the parent id in both RESPAN messages.
let h1_next = to_expr(next.decoder[decoder_cols::HASHER_STATE_OFFSET + 1].clone());
// Top of the stack; used as the `is_loop` element of the LOOP message.
let s0 = to_expr(local.stack[0].clone());
let ctx_local = to_expr(local.ctx.clone());
let b0_local = to_expr(local.stack[stack_cols::B0].clone());
let b1_local = to_expr(local.stack[stack_cols::B1].clone());
let fn_hash_local: [AB::Expr; 4] = [
to_expr(local.fn_hash[0].clone()),
to_expr(local.fn_hash[1].clone()),
to_expr(local.fn_hash[2].clone()),
to_expr(local.fn_hash[3].clone()),
];
// Hasher-state columns 4 and 5 take the place of `b0`/`b1` in the DYNCALL message.
// NOTE(review): presumably the caller's stack depth and overflow address — confirm
// against the decoder trace layout.
let h4_local = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 4].clone());
let h5_local = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 5].clone());
// Flags read from the hasher-state columns of the current row; only used by END.
let is_loop_flag = to_expr(local.decoder[decoder_cols::IS_LOOP_FLAG].clone());
let is_call_flag = to_expr(local.decoder[decoder_cols::IS_CALL_FLAG].clone());
let is_syscall_flag = to_expr(local.decoder[decoder_cols::IS_SYSCALL_FLAG].clone());
// Next-row context values; used by the END message of a call/syscall block.
let ctx_next = to_expr(next.ctx.clone());
let b0_next = to_expr(next.stack[stack_cols::B0].clone());
let b1_next = to_expr(next.stack[stack_cols::B1].clone());
let fn_hash_next: [AB::Expr; 4] = [
to_expr(next.fn_hash[0].clone()),
to_expr(next.fn_hash[1].clone()),
to_expr(next.fn_hash[2].clone()),
to_expr(next.fn_hash[3].clone()),
];
let encoders = BlockStackEncoders::<AB>::new(challenges);
// --- Operation flags ----------------------------------------------------------------
let is_join = op_flags.join();
let is_split = op_flags.split();
let is_span = op_flags.span();
let is_dyn = op_flags.dyn_op();
let is_loop = op_flags.loop_op();
let is_respan = op_flags.respan();
let is_call = op_flags.call();
let is_syscall = op_flags.syscall();
let is_dyncall = op_flags.dyncall();
let is_end = op_flags.end();
// --- Response side (insertions) -----------------------------------------------------
// JOIN / SPLIT / SPAN / DYN share one message: (addr', addr, 0).
let msg_simple = encoders.simple(&addr_next, &addr_local, &zero);
let v_join = msg_simple.clone() * is_join.clone();
let v_split = msg_simple.clone() * is_split.clone();
let v_span = msg_simple.clone() * is_span.clone();
let v_dyn = msg_simple.clone() * is_dyn.clone();
// LOOP: (addr', addr, s0).
let msg_loop = encoders.simple(&addr_next, &addr_local, &s0);
let v_loop = msg_loop * is_loop.clone();
// RESPAN insertion: (addr', h1', 0).
let msg_respan_insert = encoders.simple(&addr_next, &h1_next, &zero);
let v_respan = msg_respan_insert * is_respan.clone();
// CALL / SYSCALL: full message carrying the current-row ctx, b0, b1 and fn_hash.
let msg_call = encoders.full(
&addr_next,
&addr_local,
&zero,
&ctx_local,
&b0_local,
&b1_local,
&fn_hash_local,
);
let v_call = msg_call.clone() * is_call.clone();
let v_syscall = msg_call * is_syscall.clone();
// DYNCALL: same shape as CALL, but with h4/h5 in place of b0/b1.
let msg_dyncall = encoders.full(
&addr_next,
&addr_local,
&zero,
&ctx_local,
&h4_local,
&h5_local,
&fn_hash_local,
);
let v_dyncall = msg_dyncall * is_dyncall.clone();
let insert_flag_sum = is_join.clone()
+ is_split.clone()
+ is_span.clone()
+ is_dyn.clone()
+ is_loop.clone()
+ is_respan.clone()
+ is_call.clone()
+ is_syscall.clone()
+ is_dyncall.clone();
let insertion_sum =
v_join + v_split + v_span + v_dyn + v_loop + v_respan + v_call + v_syscall + v_dyncall;
// Gated messages plus `1 - (sum of insertion flags)`.
let response = insertion_sum + (one_ef.clone() - insert_flag_sum);
// --- Request side (removals) --------------------------------------------------------
// RESPAN removal: (addr, h1', 0).
let msg_respan_remove = encoders.simple(&addr_local, &h1_next, &zero);
let u_respan = msg_respan_remove * is_respan.clone();
// END is split in two cases by the call/syscall flags of the current row.
// Case 1, neither flag set: short message (addr, addr', is_loop_flag).
let is_simple_end = one.clone() - is_call_flag.clone() - is_syscall_flag.clone();
let msg_end_simple = encoders.simple(&addr_local, &addr_next, &is_loop_flag);
let end_simple_gate = is_end.clone() * is_simple_end;
let u_end_simple = msg_end_simple * end_simple_gate;
// Case 2, call or syscall flag set: full message using NEXT-row ctx, b0, b1 and fn_hash.
let is_call_or_syscall = is_call_flag.clone() + is_syscall_flag.clone();
let msg_end_call = encoders.full(
&addr_local,
&addr_next,
&is_loop_flag,
&ctx_next,
&b0_next,
&b1_next,
&fn_hash_next,
);
let end_call_gate = is_end.clone() * is_call_or_syscall;
let u_end_call = msg_end_call * end_call_gate;
let u_end = u_end_simple + u_end_call;
let remove_flag_sum = is_end.clone() + is_respan.clone();
let removal_sum = u_end + u_respan;
// Gated messages plus `1 - (sum of removal flags)`.
let request = removal_sum + (one_ef.clone() - remove_flag_sum);
// --- Constraint: p1' * request = p1 * response --------------------------------------
let lhs: AB::ExprEF = p1_next.into() * request;
let rhs: AB::ExprEF = p1_local.into() * response;
builder.tagged(DECODER_BUS_BASE_ID, DECODER_BUS_NAMES[0], |builder| {
builder.when_transition().assert_zero_ext(lhs - rhs);
});
}
/// Enforces the transition constraint of the `P2_BLOCK_HASH` auxiliary column.
///
/// On transition rows the function asserts
/// `p2_next * request - p2_local * response = 0`, where
///
/// - `response` sums the contributions gated by the JOIN, SPLIT, LOOP, REPEAT, DYN, DYNCALL,
///   CALL and SYSCALL flags plus `1 - (sum of those flags)`, and
/// - `request` is the END message gated by the END flag plus `1 - is_end`.
///
/// Every message has the layout `(parent, hash[0..4], first_child, loop_body)`.
/// Reading each side as "the active message, or 1 when idle" assumes the op flags are binary
/// and mutually exclusive, which is not established in this function.
///
/// # Arguments
/// - `builder`: AIR builder; supplies the auxiliary column window and receives the constraint.
/// - `local`, `next`: main trace rows of the current and the next step.
/// - `op_flags`: operation flags (presumably derived from `local`; flags for `next` are
///   rebuilt inside this function).
/// - `challenges`: challenges used to encode every message.
pub fn enforce_block_hash_table_constraint<AB>(
builder: &mut AB,
local: &MainTraceRow<AB::Var>,
next: &MainTraceRow<AB::Var>,
op_flags: &OpFlags<AB::Expr>,
challenges: &Challenges<AB::ExprEF>,
) where
AB: LiftedAirBuilder,
{
// Current- and next-row values of the p2 auxiliary column. Scoped so that the window
// returned by `builder.permutation()` is dropped before `builder` is used again below.
let (p2_local, p2_next) = {
let aux = builder.permutation();
let aux_local = aux.current_slice();
let aux_next = aux.next_slice();
(
aux_local[crate::constraints::bus::indices::P2_BLOCK_HASH],
aux_next[crate::constraints::bus::indices::P2_BLOCK_HASH],
)
};
let one = AB::Expr::ONE;
let zero = AB::Expr::ZERO;
let one_ef = AB::ExprEF::ONE;
// Lifts a trace variable into an expression.
let to_expr = |v: AB::Var| -> AB::Expr { v.into() };
// Parent id of every response-side message: the NEXT row's block address.
let parent_id = to_expr(next.decoder[decoder_cols::ADDR].clone());
// The eight hasher-state columns of the current row.
let h0 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET].clone());
let h1 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 1].clone());
let h2 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 2].clone());
let h3 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 3].clone());
let h4 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 4].clone());
let h5 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 5].clone());
let h6 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 6].clone());
let h7 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 7].clone());
// Top of the stack; selects the branch for SPLIT and gates the LOOP message.
let s0: AB::Expr = to_expr(local.stack[0].clone());
// END message inputs. `end_parent_id` reads the same column as `parent_id`, and the END
// hash is the first half of the hasher state; they are only named separately.
let end_parent_id = to_expr(next.decoder[decoder_cols::ADDR].clone());
let end_hash_0 = h0.clone();
let end_hash_1 = h1.clone();
let end_hash_2 = h2.clone();
let end_hash_3 = h3.clone();
// Hasher-state column 4 of the current row (same column as `h4`), used as the
// `loop_body` element of the END message.
let is_loop_body_flag = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 4].clone());
// Operation flags of the NEXT row, needed to derive `is_first_child`.
let accessor_next =
crate::constraints::op_flags::ExprDecoderAccess::<AB::Var, AB::Expr>::new(next);
let op_flags_next = OpFlags::new(accessor_next);
let is_end_next = op_flags_next.end();
let is_repeat_next = op_flags_next.repeat();
let is_halt_next = op_flags_next.halt();
// is_first_child = 1 - (END' + REPEAT' + HALT'): set unless the next row is one of those.
let is_not_first_child = is_end_next + is_repeat_next + is_halt_next;
let is_first_child = one.clone() - is_not_first_child;
let encoder = BlockHashEncoder::<AB>::new(challenges);
// --- Operation flags of the current row ---------------------------------------------
let is_join = op_flags.join();
let is_split = op_flags.split();
let is_loop = op_flags.loop_op();
let is_repeat = op_flags.repeat();
let is_dyn = op_flags.dyn_op();
let is_dyncall = op_flags.dyncall();
let is_call = op_flags.call();
let is_syscall = op_flags.syscall();
let is_end = op_flags.end();
// --- Response side ------------------------------------------------------------------
// JOIN contributes the PRODUCT of two messages: h0..h3 with first_child = 1 and
// h4..h7 with first_child = 0.
let msg_join_left = encoder.encode(&parent_id, [&h0, &h1, &h2, &h3], &one, &zero);
let msg_join_right = encoder.encode(&parent_id, [&h4, &h5, &h6, &h7], &zero, &zero);
let v_join = (msg_join_left * msg_join_right) * is_join.clone();
// SPLIT contributes one message whose hash is s0 * h[i] + (1 - s0) * h[i + 4],
// i.e. h0..h3 when s0 = 1 and h4..h7 when s0 = 0.
let split_h0 = s0.clone() * h0.clone() + (one.clone() - s0.clone()) * h4.clone();
let split_h1 = s0.clone() * h1.clone() + (one.clone() - s0.clone()) * h5.clone();
let split_h2 = s0.clone() * h2.clone() + (one.clone() - s0.clone()) * h6.clone();
let split_h3 = s0.clone() * h3.clone() + (one.clone() - s0.clone()) * h7.clone();
let msg_split =
encoder.encode(&parent_id, [&split_h0, &split_h1, &split_h2, &split_h3], &zero, &zero);
let v_split = msg_split * is_split.clone();
// LOOP contributes msg * s0 + (1 - s0): the loop-body message (loop_body = 1) when
// s0 = 1, and the neutral value 1 when s0 = 0.
let msg_loop = encoder.encode(&parent_id, [&h0, &h1, &h2, &h3], &zero, &one);
let v_loop = (msg_loop * s0.clone() + (one_ef.clone() - s0.clone())) * is_loop.clone();
// REPEAT contributes the same loop-body message, without the s0 gate.
let msg_repeat = encoder.encode(&parent_id, [&h0, &h1, &h2, &h3], &zero, &one);
let v_repeat = msg_repeat * is_repeat.clone();
// DYN / DYNCALL / CALL / SYSCALL share one message: h0..h3 with both trailing flags 0.
let msg_call_like = encoder.encode(&parent_id, [&h0, &h1, &h2, &h3], &zero, &zero);
let v_dyn = msg_call_like.clone() * is_dyn.clone();
let v_dyncall = msg_call_like.clone() * is_dyncall.clone();
let v_call = msg_call_like.clone() * is_call.clone();
let v_syscall = msg_call_like * is_syscall.clone();
let insert_flag_sum = is_join.clone()
+ is_split.clone()
+ is_loop.clone()
+ is_repeat.clone()
+ is_dyn.clone()
+ is_dyncall.clone()
+ is_call.clone()
+ is_syscall.clone();
// Gated contributions plus `1 - (sum of response flags)`.
let response = v_join
+ v_split
+ v_loop
+ v_repeat
+ v_dyn
+ v_dyncall
+ v_call
+ v_syscall
+ (one_ef.clone() - insert_flag_sum);
// --- Request side -------------------------------------------------------------------
// END: (addr', h0..h3, is_first_child, is_loop_body_flag).
let msg_end = encoder.encode(
&end_parent_id,
[&end_hash_0, &end_hash_1, &end_hash_2, &end_hash_3],
&is_first_child,
&is_loop_body_flag,
);
let u_end = msg_end * is_end.clone();
let request = u_end + (one_ef.clone() - is_end);
// --- Constraint: p2' * request = p2 * response --------------------------------------
let lhs: AB::ExprEF = p2_next.into() * request;
let rhs: AB::ExprEF = p2_local.into() * response;
builder.tagged(DECODER_BUS_BASE_ID + 1, DECODER_BUS_NAMES[1], |builder| {
builder.when_transition().assert_zero_ext(lhs - rhs);
});
}
/// Enforces the transition constraint of the `P3_OP_GROUP` auxiliary column.
///
/// On transition rows the function asserts
/// `p3_next * request - p3_local * response = 0`, where
///
/// - `response` is selected by the three batch flag columns: `v_1`, the product `v_1..v_3`,
///   or the product `v_1..v_7`, with `v_i = encode(addr', gc - i, h_i)`; it adds
///   `1 - (f_g2 + f_g4 + f_g8)` so it is 1 when no selector is set, and
/// - `request` is `encode(addr, gc, group_value)` gated by `f_dg = sp * (gc - gc')`, plus
///   `1 - f_dg`.
///
/// # Arguments
/// - `builder`: AIR builder; supplies the auxiliary column window and receives the constraint.
/// - `local`, `next`: main trace rows of the current and the next step.
/// - `op_flags`: operation flags; only the PUSH flag is read here.
/// - `challenges`: challenges used to encode every message.
pub fn enforce_op_group_table_constraint<AB>(
builder: &mut AB,
local: &MainTraceRow<AB::Var>,
next: &MainTraceRow<AB::Var>,
op_flags: &OpFlags<AB::Expr>,
challenges: &Challenges<AB::ExprEF>,
) where
AB: LiftedAirBuilder,
{
// Current- and next-row values of the p3 auxiliary column. Scoped so that the window
// returned by `builder.permutation()` is dropped before `builder` is used again below.
let (p3_local, p3_next) = {
let aux = builder.permutation();
let aux_local = aux.current_slice();
let aux_next = aux.next_slice();
(
aux_local[crate::constraints::bus::indices::P3_OP_GROUP],
aux_next[crate::constraints::bus::indices::P3_OP_GROUP],
)
};
let one = AB::Expr::ONE;
let one_ef = AB::ExprEF::ONE;
// Lifts a trace variable into an expression.
let to_expr = |v: AB::Var| -> AB::Expr { v.into() };
// Response-side messages use the NEXT row's address, the request uses the current one.
let block_id_insert = to_expr(next.decoder[decoder_cols::ADDR].clone());
let block_id_remove = to_expr(local.decoder[decoder_cols::ADDR].clone());
// Group count in the current and next rows.
let gc = to_expr(local.decoder[op_group_cols::GROUP_COUNT].clone());
let gc_next = to_expr(next.decoder[op_group_cols::GROUP_COUNT].clone());
// Hasher-state columns 1..=7 of the current row; h0 is not part of any response message.
let h1 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 1].clone());
let h2 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 2].clone());
let h3 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 3].clone());
let h4 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 4].clone());
let h5 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 5].clone());
let h6 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 6].clone());
let h7 = to_expr(local.decoder[decoder_cols::HASHER_STATE_OFFSET + 7].clone());
// The three batch flag columns of the current row.
let c0 = to_expr(local.decoder[op_group_cols::BATCH_FLAG_0].clone());
let c1 = to_expr(local.decoder[op_group_cols::BATCH_FLAG_1].clone());
let c2 = to_expr(local.decoder[op_group_cols::BATCH_FLAG_2].clone());
// Next-row values used to rebuild the value of the group being removed.
let h0_next = to_expr(next.decoder[decoder_cols::HASHER_STATE_OFFSET].clone());
let s0_next = to_expr(next.stack[0].clone());
// In-span flag of the current row.
let sp = to_expr(local.decoder[op_group_cols::IS_IN_SPAN].clone());
let encoder = OpGroupEncoder::<AB>::new(challenges);
let is_push = op_flags.push();
// --- Response side ------------------------------------------------------------------
// Batch selectors built from (c0, c1, c2):
//   f_g8 = c0
//   f_g4 = (1 - c0) * c1 * (1 - c2)
//   f_g2 = (1 - c0) * (1 - c1) * c2
// NOTE(review): the names suggest 8-, 4- and 2-group batches — confirm against the
// batch flag encoding used by the trace.
let f_g8 = c0.clone();
let f_g4 = (one.clone() - c0.clone()) * c1.clone() * (one.clone() - c2.clone());
let f_g2 = (one.clone() - c0.clone()) * (one.clone() - c1.clone()) * c2.clone();
let two = AB::Expr::from_u16(2);
let three = AB::Expr::from_u16(3);
let four = AB::Expr::from_u16(4);
let five = AB::Expr::from_u16(5);
let six = AB::Expr::from_u16(6);
let seven = AB::Expr::from_u16(7);
// 128 = 2^7: scales h0' before the opcode value (seven weighted bit columns) is added.
let onetwentyeight = AB::Expr::from_u16(128);
// v_i = encode(addr', gc - i, h_i) for i in 1..=7.
let v_1 = encoder.encode(&block_id_insert, &(gc.clone() - one.clone()), &h1);
let v_2 = encoder.encode(&block_id_insert, &(gc.clone() - two.clone()), &h2);
let v_3 = encoder.encode(&block_id_insert, &(gc.clone() - three.clone()), &h3);
let v_4 = encoder.encode(&block_id_insert, &(gc.clone() - four.clone()), &h4);
let v_5 = encoder.encode(&block_id_insert, &(gc.clone() - five.clone()), &h5);
let v_6 = encoder.encode(&block_id_insert, &(gc.clone() - six.clone()), &h6);
let v_7 = encoder.encode(&block_id_insert, &(gc.clone() - seven.clone()), &h7);
let prod_3 = v_1.clone() * v_2.clone() * v_3.clone();
let prod_7 = v_1.clone() * v_2 * v_3 * v_4 * v_5 * v_6 * v_7;
// response = f_g2 * v_1 + f_g4 * (v_1 v_2 v_3) + f_g8 * (v_1 ... v_7)
//          + (1 - (f_g2 + f_g4 + f_g8))
let response = (v_1.clone() * f_g2.clone())
+ (prod_3 * f_g4.clone())
+ (prod_7 * f_g8.clone())
+ (one_ef.clone() - (f_g2 + f_g4 + f_g8));
// --- Request side -------------------------------------------------------------------
// f_dg = sp * (gc - gc'): non-zero only inside a span when the group count changes.
// NOTE(review): using f_dg as a selector assumes gc - gc' is 0 or 1 here — presumably
// enforced by other decoder constraints.
let delta_gc = gc.clone() - gc_next;
let f_dg = sp * delta_gc;
let op_code_next = opcode_from_row::<AB>(next);
// Group value: s0' when the PUSH flag is set, otherwise h0' * 128 + opcode'.
let group_value_non_push = h0_next * onetwentyeight + op_code_next;
let group_value = is_push.clone() * s0_next + (one.clone() - is_push) * group_value_non_push;
let u = encoder.encode(&block_id_remove, &gc, &group_value);
let request = u * f_dg.clone() + (one_ef.clone() - f_dg);
// --- Constraint: p3' * request = p3 * response --------------------------------------
let lhs: AB::ExprEF = p3_next.into() * request;
let rhs: AB::ExprEF = p3_local.into() * response;
builder.tagged(DECODER_BUS_BASE_ID + 2, DECODER_BUS_NAMES[2], |builder| {
builder.when_transition().assert_zero_ext(lhs - rhs);
});
}