use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::{AirBuilder, LiftedAirBuilder};
use crate::{
MainTraceRow,
constraints::{
op_flags::{ExprDecoderAccess, OpFlags},
tagging::{
TaggingAirBuilderExt,
ids::{
TAG_SYSTEM_CLK_BASE, TAG_SYSTEM_CLK_COUNT, TAG_SYSTEM_CTX_BASE,
TAG_SYSTEM_CTX_COUNT, TAG_SYSTEM_FN_HASH_BASE,
},
},
},
trace::decoder::HASHER_STATE_OFFSET,
};
/// Names attached to the clock constraints, indexed in the same order as their
/// tag ids (`TAG_SYSTEM_CLK_BASE + index`): first-row check, then transition check.
const SYSTEM_CLK_NAMES: [&str; TAG_SYSTEM_CLK_COUNT] =
["system.clk.first_row", "system.clk.transition"];
/// Names attached to the context constraints, indexed in the same order as their
/// tag ids (`TAG_SYSTEM_CTX_BASE + index`): call/dyncall, syscall, default.
const SYSTEM_CTX_NAMES: [&str; TAG_SYSTEM_CTX_COUNT] =
["system.ctx.call_dyncall", "system.ctx.syscall", "system.ctx.default"];
/// Namespace passed to `tagged_list` for the four constraints that load `fn_hash`.
const SYSTEM_FN_HASH_LOAD_NAMESPACE: &str = "system.fn_hash.load";
/// Namespace passed to `tagged_list` for the four constraints that preserve `fn_hash`.
const SYSTEM_FN_HASH_PRESERVE_NAMESPACE: &str = "system.fn_hash.preserve";
/// Enforces all system constraints on the main trace.
///
/// Delegates, in this order, to the clock, execution-context and function-hash
/// constraint groups. The call order determines the order in which constraints
/// are emitted on `builder`, so it should not be rearranged casually.
///
/// * `builder` - constraint builder that receives the assertions.
/// * `local` - the current main trace row.
/// * `next` - the following main trace row.
pub fn enforce_main<AB>(
builder: &mut AB,
local: &MainTraceRow<AB::Var>,
next: &MainTraceRow<AB::Var>,
) where
AB: LiftedAirBuilder,
{
// Clock: two constraints tagged from TAG_SYSTEM_CLK_BASE.
enforce_clock_constraint(builder, local, next);
// Execution context: three constraints tagged from TAG_SYSTEM_CTX_BASE.
enforce_ctx_constraints(builder, local, next);
// Function hash: eight constraints tagged from TAG_SYSTEM_FN_HASH_BASE.
enforce_fn_hash_constraints(builder, local, next);
}
/// Enforces the two constraints on the `clk` column.
///
/// 1. On the first row, `clk` must be zero.
/// 2. On every transition, the next row's `clk` must equal the current `clk` plus one.
///
/// The constraints are tagged with ids `TAG_SYSTEM_CLK_BASE` and
/// `TAG_SYSTEM_CLK_BASE + 1`, using the matching entries of `SYSTEM_CLK_NAMES`.
pub(crate) fn enforce_clock_constraint<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
) where
    AB: LiftedAirBuilder,
{
    let first_row_tag = TAG_SYSTEM_CLK_BASE;
    let first_row_name = SYSTEM_CLK_NAMES[0];
    let transition_tag = TAG_SYSTEM_CLK_BASE + 1;
    let transition_name = SYSTEM_CLK_NAMES[1];

    // Boundary condition: clk = 0 on the first row.
    builder.tagged(first_row_tag, first_row_name, |b| {
        let clk_now = local.clk.clone();
        b.when_first_row().assert_zero(clk_now);
    });

    // Transition condition: clk' = clk + 1.
    builder.tagged(transition_tag, transition_name, |b| {
        let clk_after = next.clk.clone();
        let clk_incremented = local.clk.clone() + AB::Expr::ONE;
        b.when_transition().assert_eq(clk_after, clk_incremented);
    });
}
/// Enforces the three transition constraints on the `ctx` column.
///
/// Operation flags are derived from the current row via `OpFlags`:
///
/// 1. When the call or dyncall flag is set, the next `ctx` must equal `clk + 1`.
/// 2. When the syscall flag is set, the next `ctx` must be zero.
/// 3. When none of the call, syscall, dyncall or end flags is set, `ctx` must be
///    unchanged.
///
/// No constraint on `ctx` is emitted here for the end operation itself; it is only
/// excluded from the "unchanged" rule.
///
/// The constraints are tagged with ids `TAG_SYSTEM_CTX_BASE + {0, 1, 2}` and the
/// matching entries of `SYSTEM_CTX_NAMES`.
pub(crate) fn enforce_ctx_constraints<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
) where
    AB: LiftedAirBuilder,
{
    // Column values lifted into expressions.
    let ctx_now: AB::Expr = local.ctx.clone().into();
    let ctx_after: AB::Expr = next.ctx.clone().into();
    let clk_now: AB::Expr = local.clk.clone().into();

    // Operation flags for the current row.
    let flags = OpFlags::new(ExprDecoderAccess::new(local));
    let call = flags.call();
    let syscall = flags.syscall();
    let dyncall = flags.dyncall();
    let end = flags.end();

    // Rule 1: (call + dyncall) * (ctx' - (clk + 1)) = 0.
    let enters_fresh_ctx = call.clone() + dyncall.clone();
    let fresh_ctx_id = clk_now + AB::Expr::ONE;
    builder.tagged(TAG_SYSTEM_CTX_BASE, SYSTEM_CTX_NAMES[0], |b| {
        let mismatch = ctx_after.clone() - fresh_ctx_id;
        b.when_transition().assert_zero(enters_fresh_ctx * mismatch);
    });

    // Rule 2: syscall * ctx' = 0.
    builder.tagged(TAG_SYSTEM_CTX_BASE + 1, SYSTEM_CTX_NAMES[1], |b| {
        let syscall_ctx = syscall.clone() * ctx_after.clone();
        b.when_transition().assert_zero(syscall_ctx);
    });

    // Rule 3: (1 - (call + syscall + dyncall + end)) * (ctx' - ctx) = 0.
    let switches_ctx = call + syscall + dyncall + end;
    let keeps_ctx = AB::Expr::ONE - switches_ctx;
    builder.tagged(TAG_SYSTEM_CTX_BASE + 2, SYSTEM_CTX_NAMES[2], |b| {
        let drift = ctx_after - ctx_now;
        b.when_transition().assert_zero(keeps_ctx * drift);
    });
}
/// Enforces the eight transition constraints on the four `fn_hash` columns.
///
/// Operation flags are derived from the current row via `OpFlags`:
///
/// - Load (ids `TAG_SYSTEM_FN_HASH_BASE + 0..4`): when the call or dyncall flag is
///   set, each `fn_hash[i]` of the next row must equal
///   `decoder[HASHER_STATE_OFFSET + i]` of the current row.
/// - Preserve (ids `TAG_SYSTEM_FN_HASH_BASE + 4..8`): under the selector
///   `1 - (call + dyncall + end)`, each `fn_hash[i]` must be unchanged between the
///   current and next row.
///
/// No `fn_hash` constraint is emitted here for the end operation itself; it is
/// only excluded from the preserve rule.
pub(crate) fn enforce_fn_hash_constraints<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
) where
    AB: LiftedAirBuilder,
{
    // Operation flags for the current row.
    let flags = OpFlags::new(ExprDecoderAccess::new(local));
    let call = flags.call();
    let dyncall = flags.dyncall();
    let end = flags.end();

    // Selectors for the two groups of constraints.
    let loads_hash = call.clone() + dyncall.clone();
    let keeps_hash = AB::Expr::ONE - (loads_hash.clone() + end);

    // Load group: fn_hash'[i] - decoder[HASHER_STATE_OFFSET + i] = 0.
    let load_tags: [usize; 4] = [
        TAG_SYSTEM_FN_HASH_BASE,
        TAG_SYSTEM_FN_HASH_BASE + 1,
        TAG_SYSTEM_FN_HASH_BASE + 2,
        TAG_SYSTEM_FN_HASH_BASE + 3,
    ];
    builder.tagged_list(load_tags, SYSTEM_FN_HASH_LOAD_NAMESPACE, |b| {
        let load_diffs: [AB::Expr; 4] = core::array::from_fn(|limb| {
            let incoming: AB::Expr = next.fn_hash[limb].clone().into();
            let hasher_limb: AB::Expr = local.decoder[HASHER_STATE_OFFSET + limb].clone().into();
            incoming - hasher_limb
        });
        b.when_transition().when(loads_hash.clone()).assert_zeros(load_diffs);
    });

    // Preserve group: fn_hash'[i] - fn_hash[i] = 0.
    let preserve_tags: [usize; 4] = [
        TAG_SYSTEM_FN_HASH_BASE + 4,
        TAG_SYSTEM_FN_HASH_BASE + 4 + 1,
        TAG_SYSTEM_FN_HASH_BASE + 4 + 2,
        TAG_SYSTEM_FN_HASH_BASE + 4 + 3,
    ];
    builder.tagged_list(preserve_tags, SYSTEM_FN_HASH_PRESERVE_NAMESPACE, |b| {
        let preserve_diffs: [AB::Expr; 4] = core::array::from_fn(|limb| {
            let current: AB::Expr = local.fn_hash[limb].clone().into();
            let incoming: AB::Expr = next.fn_hash[limb].clone().into();
            incoming - current
        });
        b.when_transition().when(keeps_hash.clone()).assert_zeros(preserve_diffs);
    });
}