use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::AirBuilder;
use super::selectors::ChipletFlags;
use crate::{
MainCols, MidenAirBuilder,
constraints::{chiplets::columns::MemoryCols, constants::TWO_POW_16, utils::BoolNot},
};
/// Enforces the memory chiplet constraints on a pair of consecutive rows.
///
/// Three groups of constraints are emitted, each gated on a different flag from `flags`:
///
/// 1. under `flags.is_active`: the selector columns of the current row are boolean and a
///    word access has both index bits cleared;
/// 2. under `flags.next_is_first`: every element of the next row that is not written must
///    be zero;
/// 3. under `flags.is_transition`: context/address change detection, the delta
///    decomposition into `d0`/`d1`, the "no write when the clock does not change" rule and
///    the carry-over of values that are not written.
///
/// # Parameters
/// - `builder`: constraint builder that receives all assertions.
/// - `local`: current row; only its memory columns (`local.memory()`) are read.
/// - `next`: next row; only its memory columns (`next.memory()`) are read.
/// - `flags`: gating expressions. NOTE(review): `ChipletFlags` is defined elsewhere; the
///   names suggest "row belongs to the chiplet", "next row is the chiplet's first row" and
///   "both rows belong to the chiplet" — confirm against its definition.
///
/// NOTE(review): `.not()` comes from `BoolNot`, which is not visible here; the comments
/// below assume it evaluates to `1 - x`.
pub fn enforce_memory_constraints<AB>(
builder: &mut AB,
local: &MainCols<AB::Var>,
next: &MainCols<AB::Var>,
flags: &ChipletFlags<AB::Expr>,
) where
AB: MidenAirBuilder,
{
let cols = local.memory();
let cols_next = next.memory();
// --- Group 1: selector validity on the current row (gated on `is_active`) ---
{
let builder = &mut builder.when(flags.is_active.clone());
builder.assert_bool(cols.is_read);
builder.assert_bool(cols.is_word);
builder.assert_bool(cols.idx0);
builder.assert_bool(cols.idx1);
// A word access addresses the whole word, so both element-index bits must be zero.
{
let builder = &mut builder.when(cols.is_word);
builder.assert_zero(cols.idx0);
builder.assert_zero(cols.idx1);
}
}
// Per-element "not written" flags of the NEXT row; reused by groups 2 and 3.
let not_written = compute_not_written_flags::<AB>(cols_next);
// --- Group 2: gated on `next_is_first` ---
// There is no previous row to inherit from, so every element of the next row that is not
// written is forced to zero.
{
let builder = &mut builder.when(flags.next_is_first.clone());
for (i, nw) in not_written.iter().enumerate() {
builder.when(nw.clone()).assert_zero(cols_next.values[i]);
}
}
// --- Group 3: everything below is gated on `is_transition` ---
// (this binding shadows the outer `builder` for the rest of the function)
let builder = &mut builder.when(flags.is_transition.clone());
let d_inv_next = cols_next.d_inv;
// Context change flag: `ctx_delta * d_inv_next`, asserted boolean.
let ctx_delta = cols_next.ctx - cols.ctx;
let ctx_changed = ctx_delta.clone() * d_inv_next;
let same_ctx = ctx_changed.not();
builder.assert_bool(ctx_changed.clone());
// Address change flag, built the same way from the word address; it is only constrained
// inside the `same_ctx` scope below.
let addr_delta = cols_next.word_addr - cols.word_addr;
let addr_changed = addr_delta.clone() * d_inv_next;
let same_addr = addr_changed.not();
{
let builder = &mut builder.when(same_ctx.clone());
// If the flag says "same context", the context really must not have changed.
builder.assert_zero(ctx_delta.clone());
// With the context unchanged, the address flag must be boolean and consistent with
// the actual address difference.
builder.assert_bool(addr_changed.clone());
builder.when(same_addr.clone()).assert_zero(addr_delta.clone());
}
// The next row's `is_same_ctx_and_addr` column must equal the product of the two
// "same" flags computed above.
let same_ctx_and_addr = cols_next.is_same_ctx_and_addr;
builder.assert_eq(same_ctx_and_addr, same_ctx.clone() * same_addr.clone());
let clk_delta = cols_next.clk - cols.clk;
// Select the delta of the most significant field that changed:
// ctx delta if the context changed, else address delta if the address changed,
// else clock delta.
let computed_delta = {
let ctx_term = ctx_changed * ctx_delta;
let addr_term = addr_changed * addr_delta;
let clk_term = same_addr * clk_delta.clone();
ctx_term + same_ctx * (addr_term + clk_term)
};
// The selected delta must equal `d1 * TWO_POW_16 + d0` taken from the next row.
// NOTE(review): presumably `d0`/`d1` are 16-bit limbs range-checked elsewhere — confirm.
let delta_next = cols_next.d1 * TWO_POW_16 + cols_next.d0;
builder.assert_eq(computed_delta, delta_next);
// Same context, same address and (per `clk_no_change`) same clock: neither the current
// nor the next row may be a write, i.e. `is_write + is_write_next` must vanish.
{
let clk_no_change = AB::Expr::ONE - clk_delta * d_inv_next;
let is_write = cols.is_read.into().not();
let is_write_next = cols_next.is_read.into().not();
let any_write = is_write + is_write_next;
builder.when(same_ctx_and_addr).when(clk_no_change).assert_zero(any_write);
}
// Value consistency: an element that is not written in the next row equals
// `same_ctx_and_addr * previous value` — the previous value when context and address
// are unchanged, zero otherwise.
let values = cols.values;
let values_next = cols_next.values;
for (i, nw) in not_written.into_iter().enumerate() {
builder.when(nw).assert_eq(values_next[i], same_ctx_and_addr * values[i]);
}
}
/// Builds one flag per word element telling whether that element is left untouched by the
/// access described by `cols`.
///
/// For element `i` the returned expression is
/// `is_read + is_write * is_element * not(selector_i)`, where `selector_i` is the product
/// of `idx0`/`idx1` (or their negations) that picks element `i = 2 * idx1 + idx0`.
///
/// NOTE(review): `.not()` comes from `BoolNot`, which is defined elsewhere; it is assumed
/// to evaluate to `1 - x`.
fn compute_not_written_flags<AB>(cols: &MemoryCols<AB::Var>) -> [AB::Expr; 4]
where
    AB: MidenAirBuilder,
{
    // Access kind: a write (negated read flag) that targets a single element
    // (negated word flag).
    let read_flag = cols.is_read;
    let write_flag = read_flag.into().not();
    let element_flag = cols.is_word.into().not();
    let element_write = write_flag * element_flag;

    // Element index bits and their negations.
    let (bit0, bit1) = (cols.idx0, cols.idx1);
    let bit0_clear = bit0.into().not();
    let bit1_clear = bit1.into().not();

    // One selector per element, ordered by index `2 * bit1 + bit0`.
    let element_selectors = [
        bit1_clear.clone() * bit0_clear.clone(),
        bit1_clear * bit0,
        bit1 * bit0_clear,
        bit1 * bit0,
    ];

    // An element is untouched on every read, and on an element write that selects a
    // different element.
    element_selectors.map(|selector| read_flag + element_write.clone() * selector.not())
}