use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::AirBuilder;
use super::selectors::ChipletFlags;
use crate::{
ChipletCols, MidenAirBuilder,
constraints::{
chiplets::columns::MemoryCols,
constants::{F_4, TWO_POW_16},
utils::BoolNot,
},
};
/// Enforces the memory chiplet constraints over the current (`local`) and next (`next`)
/// chiplet rows.
///
/// Constraints are grouped by the selector that gates them:
/// - `flags.is_active`: per-row well-formedness of the current memory row.
/// - `flags.next_is_first`: initial-value constraints on the next row.
/// - `flags.is_transition`: ordering, delta, and value-continuity constraints between the rows.
///
/// `flags.is_last` is not read by this function.
///
/// # Parameters
/// - `builder`: AIR builder that receives every constraint.
/// - `local`: chiplet columns of the current row; memory columns are read via `local.memory()`.
/// - `next`: chiplet columns of the next row; memory columns are read via `next.memory()`.
/// - `flags`: chiplet selector expressions used to gate each group of constraints.
///
/// NOTE(review): constraints are emitted in a fixed order; confirm nothing downstream depends
/// on that order before reordering them.
pub fn enforce_memory_constraints<AB>(
builder: &mut AB,
local: &ChipletCols<AB::Var>,
next: &ChipletCols<AB::Var>,
flags: &ChipletFlags<AB::Expr>,
) where
AB: MidenAirBuilder,
{
let cols = local.memory();
let cols_next = next.memory();
// --- Per-row constraints on the current row, gated by `is_active`. ---
{
let builder = &mut builder.when(flags.is_active.clone());
// Access-type flags and element-index bits are binary.
builder.assert_bool(cols.is_read);
builder.assert_bool(cols.is_word);
builder.assert_bool(cols.idx0);
builder.assert_bool(cols.idx1);
// Bind `word_addr` to its two limb columns: word_addr = (hi * TWO_POW_16 + lo) * F_4.
// NOTE(review): presumably the limbs are range-checked elsewhere, which bounds word_addr — confirm.
let word_addr_lo = local.memory_word_addr_lo();
let word_addr_hi = local.memory_word_addr_hi();
let word_addr = (word_addr_hi * TWO_POW_16 + word_addr_lo) * F_4;
builder.assert_eq(cols.word_addr, word_addr);
// Word-granular accesses must use element index 0 (idx0 = idx1 = 0).
{
let builder = &mut builder.when(cols.is_word);
builder.assert_zero(cols.idx0);
builder.assert_zero(cols.idx1);
}
}
// Per-slot flags for the next row: entry i is 1 when slot i is NOT written by that row.
let not_written = compute_not_written_flags::<AB>(cols_next);
// --- Initial values, gated by `next_is_first`: slots the next row does not write are zero. ---
{
let builder = &mut builder.when(flags.next_is_first.clone());
for (i, nw) in not_written.iter().enumerate() {
builder.when(nw.clone()).assert_zero(cols_next.values[i]);
}
}
// --- Transition constraints between `local` and `next`, gated by `is_transition`. ---
let builder = &mut builder.when(flags.is_transition.clone());
// The next row's `d_inv` is the single inverse witness shared by the ctx/addr/clk deltas.
let d_inv_next = cols_next.d_inv;
let ctx_delta = cols_next.ctx - cols.ctx;
// ctx_changed = Δctx * d_inv'; constrained to be binary below.
let ctx_changed = ctx_delta.clone() * d_inv_next;
let same_ctx = ctx_changed.not();
builder.assert_bool(ctx_changed.clone());
let addr_delta = cols_next.word_addr - cols.word_addr;
let addr_changed = addr_delta.clone() * d_inv_next;
let same_addr = addr_changed.not();
// Within the same context: Δctx must actually be zero, addr_changed is binary, and when the
// address is also unchanged, Δaddr must be zero.
{
let builder = &mut builder.when(same_ctx.clone());
builder.assert_zero(ctx_delta.clone());
builder.assert_bool(addr_changed.clone());
builder.when(same_addr.clone()).assert_zero(addr_delta.clone());
}
// Materialize same_ctx * same_addr as a column so later constraints can use it directly.
// NOTE(review): presumably this keeps the degree of the dependent constraints low — confirm.
let same_ctx_and_addr = cols_next.is_same_ctx_and_addr;
builder.assert_eq(same_ctx_and_addr, same_ctx.clone() * same_addr.clone());
// The delta checked against the (d1, d0) limbs is Δctx if the context changed, otherwise
// Δaddr if the address changed, otherwise Δclk.
// NOTE(review): presumably d0/d1 are range-checked limbs, making this an ordering constraint
// on (ctx, word_addr, clk) — confirm.
let clk_delta = cols_next.clk - cols.clk;
let computed_delta = {
let ctx_term = ctx_changed * ctx_delta;
let addr_term = addr_changed * addr_delta;
let clk_term = same_addr * clk_delta.clone();
ctx_term + same_ctx * (addr_term + clk_term)
};
let delta_next = cols_next.d1 * TWO_POW_16 + cols_next.d0;
builder.assert_eq(computed_delta, delta_next);
// Two accesses to the same ctx and address with Δclk = 0 must both be reads.
{
// Equals 1 whenever Δclk = 0, regardless of the value of d_inv'.
let clk_no_change = AB::Expr::ONE - clk_delta * d_inv_next;
let is_write = cols.is_read.into().not();
let is_write_next = cols_next.is_read.into().not();
// Given boolean is_read in both rows, this sum is zero only if neither row writes.
let any_write = is_write + is_write_next;
builder.when(same_ctx_and_addr).when(clk_no_change).assert_zero(any_write);
}
// Value continuity: each slot the next row does not write keeps its previous value when the
// ctx and address are unchanged, and must be zero otherwise.
let values = cols.values;
let values_next = cols_next.values;
for (i, nw) in not_written.into_iter().enumerate() {
builder.when(nw).assert_eq(values_next[i], same_ctx_and_addr * values[i]);
}
}
/// Computes, for each of the four value slots of a memory row, a flag that is one when the
/// row does *not* write that slot.
///
/// The flag for slot `i` is `is_read + (1 - is_read) * (1 - is_word) * (1 - s_i)`, where
/// `s_i` is the one-hot selector for slot `i` built from the index bits, with slots ordered by
/// `2 * idx1 + idx0`. With binary inputs this gives:
/// - reads: every flag is 1 (nothing is written),
/// - word writes: every flag is 0 (all slots are written),
/// - element writes: only the selected slot's flag is 0.
fn compute_not_written_flags<AB>(cols: &MemoryCols<AB::Var>) -> [AB::Expr; 4]
where
    AB: MidenAirBuilder,
{
    let read: AB::Expr = cols.is_read.into();
    let word: AB::Expr = cols.is_word.into();
    let lo_bit: AB::Expr = cols.idx0.into();
    let hi_bit: AB::Expr = cols.idx1.into();
    let lo_off = lo_bit.clone().not();
    let hi_off = hi_bit.clone().not();

    // One-hot slot selectors, ordered by `2 * idx1 + idx0`.
    let slot_selectors = [
        hi_off.clone() * lo_off.clone(),
        hi_off * lo_bit.clone(),
        hi_bit.clone() * lo_off,
        hi_bit * lo_bit,
    ];

    // Non-zero only for element-granular (single-slot) writes.
    let element_write = read.clone().not() * word.not();

    slot_selectors.map(|selector| read.clone() + element_write.clone() * selector.not())
}
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::borrow::BorrowMut;

    use miden_core::{
        Felt,
        field::{PrimeCharacteristicRing, QuadFelt},
    };
    use miden_crypto::stark::{
        air::{AirBuilder, ExtensionBuilder, PermutationAirBuilder, RowWindow},
        matrix::RowMajorMatrix,
    };

    use super::enforce_memory_constraints;
    use crate::{
        ChipletCols, MemoryCols,
        constraints::chiplets::selectors::ChipletFlags,
        trace::{AUX_TRACE_RAND_CHALLENGES, AUX_TRACE_WIDTH, CHIPLETS_WIDTH, TRACE_WIDTH},
    };

    /// First index of `ChipletCols::chiplets` occupied by `MemoryCols` in the layout these
    /// tests assume.
    const MEMORY_COLS_START: usize = 2;
    /// One past the last index of `ChipletCols::chiplets` occupied by `MemoryCols`.
    const MEMORY_COLS_END: usize = 17;
    /// Index in `ChipletCols::chiplets` of the low limb of the word address.
    const WORD_ADDR_LO_IDX: usize = 17;
    /// Index in `ChipletCols::chiplets` of the high limb of the word address.
    const WORD_ADDR_HI_IDX: usize = 18;

    /// Minimal builder that evaluates constraints on concrete `Felt` values and records each
    /// `assert_zero` / `assert_zero_ext` input, in emission order.
    struct ConstraintEvalBuilder {
        main: RowMajorMatrix<Felt>,
        aux: RowMajorMatrix<QuadFelt>,
        randomness: Vec<QuadFelt>,
        permutation_values: Vec<QuadFelt>,
        periodic_values: Vec<Felt>,
        preprocessed: RowWindow<'static, Felt>,
        evaluations: Vec<QuadFelt>,
    }

    impl ConstraintEvalBuilder {
        fn new() -> Self {
            Self {
                main: RowMajorMatrix::new(vec![Felt::ZERO; TRACE_WIDTH * 2], TRACE_WIDTH),
                aux: RowMajorMatrix::new(
                    vec![QuadFelt::ZERO; AUX_TRACE_WIDTH * 2],
                    AUX_TRACE_WIDTH,
                ),
                randomness: vec![QuadFelt::ZERO; AUX_TRACE_RAND_CHALLENGES],
                permutation_values: vec![QuadFelt::ZERO; AUX_TRACE_WIDTH],
                periodic_values: Vec::new(),
                preprocessed: RowWindow::from_two_rows(&[], &[]),
                evaluations: Vec::new(),
            }
        }
    }

    impl AirBuilder for ConstraintEvalBuilder {
        type F = Felt;
        type Expr = Felt;
        type Var = Felt;
        type PreprocessedWindow = RowWindow<'static, Felt>;
        type MainWindow = RowMajorMatrix<Felt>;
        type PublicVar = Felt;
        type PeriodicVar = Felt;

        fn main(&self) -> Self::MainWindow {
            self.main.clone()
        }

        fn preprocessed(&self) -> &Self::PreprocessedWindow {
            &self.preprocessed
        }

        fn is_first_row(&self) -> Self::Expr {
            Felt::ZERO
        }

        fn is_last_row(&self) -> Self::Expr {
            Felt::ZERO
        }

        fn is_transition(&self) -> Self::Expr {
            Felt::ONE
        }

        fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
            self.evaluations.push(QuadFelt::from(x.into()));
        }

        fn public_values(&self) -> &[Self::PublicVar] {
            &[]
        }

        fn periodic_values(&self) -> &[Self::PeriodicVar] {
            &self.periodic_values
        }
    }

    impl ExtensionBuilder for ConstraintEvalBuilder {
        type EF = QuadFelt;
        type ExprEF = QuadFelt;
        type VarEF = QuadFelt;

        fn assert_zero_ext<I>(&mut self, x: I)
        where
            I: Into<Self::ExprEF>,
        {
            self.evaluations.push(x.into());
        }
    }

    impl PermutationAirBuilder for ConstraintEvalBuilder {
        type MP = RowMajorMatrix<QuadFelt>;
        type RandomVar = QuadFelt;
        type PermutationVar = QuadFelt;

        fn permutation(&self) -> Self::MP {
            self.aux.clone()
        }

        fn permutation_randomness(&self) -> &[Self::RandomVar] {
            &self.randomness
        }

        fn permutation_values(&self) -> &[Self::PermutationVar] {
            &self.permutation_values
        }
    }

    /// Flags that enable only the `is_active` (per-row) memory constraints; the transition
    /// and first-row groups are multiplied by zero.
    fn memory_flags() -> ChipletFlags<Felt> {
        ChipletFlags {
            is_active: Felt::ONE,
            is_transition: Felt::ZERO,
            is_last: Felt::ZERO,
            next_is_first: Felt::ZERO,
        }
    }

    /// An all-zero chiplet row with `chip_clk` set to one.
    fn memory_row() -> ChipletCols<Felt> {
        ChipletCols {
            s_00: Felt::ZERO,
            s_01: Felt::ZERO,
            chip_clk: Felt::ONE,
            chiplets: [Felt::ZERO; CHIPLETS_WIDTH - 3],
        }
    }

    /// Mutable view of the memory columns inside `row`.
    fn memory_cols(row: &mut ChipletCols<Felt>) -> &mut MemoryCols<Felt> {
        row.chiplets[MEMORY_COLS_START..MEMORY_COLS_END].borrow_mut()
    }

    /// Writes only the word-address limb columns, leaving `word_addr` untouched.
    fn set_word_addr_limbs(row: &mut ChipletCols<Felt>, lo: u64, hi: u64) {
        row.chiplets[WORD_ADDR_LO_IDX] = Felt::new_unchecked(lo);
        row.chiplets[WORD_ADDR_HI_IDX] = Felt::new_unchecked(hi);
    }

    /// Sets `word_addr` to `4 * (lo + (hi << 16))` and the limb columns to `lo` / `hi`, so the
    /// address-binding constraint is satisfied.
    fn set_word_addr(row: &mut ChipletCols<Felt>, lo: u64, hi: u64) {
        memory_cols(row).word_addr = Felt::new_unchecked(4 * (lo + (hi << 16)));
        set_word_addr_limbs(row, lo, hi);
    }

    /// Evaluates the memory constraints with `row` as the current row and an all-zero next row.
    fn eval_memory_constraints(row: &ChipletCols<Felt>) -> Vec<QuadFelt> {
        let next = memory_row();
        let mut builder = ConstraintEvalBuilder::new();
        enforce_memory_constraints(&mut builder, row, &next, &memory_flags());
        builder.evaluations
    }

    fn assert_constraints_accept(row: &ChipletCols<Felt>) {
        let evaluations = eval_memory_constraints(row);
        assert!(
            evaluations.iter().all(|value| *value == QuadFelt::ZERO),
            "expected all memory constraints to evaluate to zero; got {evaluations:?}",
        );
    }

    fn assert_constraints_reject(row: &ChipletCols<Felt>) {
        let evaluations = eval_memory_constraints(row);
        assert!(
            evaluations.iter().any(|value| *value != QuadFelt::ZERO),
            "expected at least one nonzero memory constraint evaluation; got {evaluations:?}",
        );
    }

    #[test]
    fn memory_constraints_bind_word_addr_to_range_checked_limbs() {
        let mut valid = memory_row();
        {
            let cols = memory_cols(&mut valid);
            cols.is_read = Felt::ONE;
            cols.is_word = Felt::ONE;
        }
        set_word_addr(&mut valid, 7, 3);
        assert_constraints_accept(&valid);

        // Both limbs together, and each limb on its own, must agree with `word_addr`.
        for (lo, hi) in [(0, 0), (8, 3), (7, 4)] {
            let mut invalid = valid.clone();
            set_word_addr_limbs(&mut invalid, lo, hi);
            assert_constraints_reject(&invalid);
        }
    }

    #[test]
    fn memory_constraints_require_zero_index_for_word_access() {
        let mut word_read = memory_row();
        {
            let cols = memory_cols(&mut word_read);
            cols.is_read = Felt::ONE;
            cols.is_word = Felt::ONE;
        }
        assert_constraints_accept(&word_read);

        let mut with_idx0 = word_read.clone();
        memory_cols(&mut with_idx0).idx0 = Felt::ONE;
        assert_constraints_reject(&with_idx0);

        let mut with_idx1 = word_read.clone();
        memory_cols(&mut with_idx1).idx1 = Felt::ONE;
        assert_constraints_reject(&with_idx1);
    }

    #[test]
    fn memory_constraints_allow_nonzero_index_for_element_access() {
        let mut element_read = memory_row();
        {
            let cols = memory_cols(&mut element_read);
            cols.is_read = Felt::ONE;
            cols.idx0 = Felt::ONE;
            cols.idx1 = Felt::ONE;
        }
        assert_constraints_accept(&element_read);
    }

    #[test]
    fn memory_constraints_require_boolean_access_flags() {
        let mut row = memory_row();
        memory_cols(&mut row).is_read = Felt::new_unchecked(2);
        assert_constraints_reject(&row);
    }
}