use core::ops::{Add, Mul, Sub};
use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::LiftedAirBuilder;
use super::selectors::memory_chiplet_flag;
use crate::{
Felt, MainTraceRow,
constraints::tagging::{TagGroup, TaggingAirBuilderExt, tagged_assert_zero_integrity},
trace::{
CHIPLETS_OFFSET,
chiplets::{
MEMORY_CLK_COL_IDX, MEMORY_CTX_COL_IDX, MEMORY_D_INV_COL_IDX, MEMORY_D0_COL_IDX,
MEMORY_D1_COL_IDX, MEMORY_FLAG_SAME_CONTEXT_AND_WORD, MEMORY_IDX0_COL_IDX,
MEMORY_IDX1_COL_IDX, MEMORY_IS_READ_COL_IDX, MEMORY_IS_WORD_ACCESS_COL_IDX,
MEMORY_V_COL_RANGE, MEMORY_WORD_COL_IDX,
},
},
};
// CONSTRAINT TAGS
// ================================================================================================
//
// Every tagged memory constraint gets a numeric ID and a namespace string. The constraints are
// organised in groups; each group owns a contiguous run of IDs that starts where the previous
// group ends, so every base ID below is derived from the preceding group's base and size.

/// First tag ID used by the memory constraints: the bitwise base ID plus the bitwise count.
pub const MEMORY_BASE_ID: usize = super::bitwise::BITWISE_BASE_ID + super::bitwise::BITWISE_COUNT;

/// Total number of tagged memory constraints (4 + 2 + 4 + 4 + 1 + 1 + 1 + 4).
pub const MEMORY_COUNT: usize = 21;

// --- binary columns (IDs base+0 ..= base+3) ----------------------------------------------------

const MEMORY_BINARY_NAMESPACE: &str = "chiplets.memory.binary";
const MEMORY_BINARY_NAMES: [&str; 4] = [MEMORY_BINARY_NAMESPACE; 4];
const MEMORY_BINARY_BASE_ID: usize = MEMORY_BASE_ID;
const MEMORY_BINARY_TAGS: TagGroup = TagGroup {
    base: MEMORY_BINARY_BASE_ID,
    names: &MEMORY_BINARY_NAMES,
};

// --- word-access index columns are zero (IDs base+4 ..= base+5) --------------------------------

const MEMORY_WORD_IDX_NAMESPACE: &str = "chiplets.memory.word_idx.zero";
const MEMORY_WORD_IDX_NAMES: [&str; 2] = [MEMORY_WORD_IDX_NAMESPACE; 2];
const MEMORY_WORD_IDX_BASE_ID: usize = MEMORY_BINARY_BASE_ID + MEMORY_BINARY_NAMES.len();
const MEMORY_WORD_IDX_TAGS: TagGroup = TagGroup {
    base: MEMORY_WORD_IDX_BASE_ID,
    names: &MEMORY_WORD_IDX_NAMES,
};

// --- first-row values are zero (IDs base+6 ..= base+9) -----------------------------------------

const MEMORY_FIRST_ROW_NAMESPACE: &str = "chiplets.memory.first_row.zero";
const MEMORY_FIRST_ROW_NAMES: [&str; 4] = [MEMORY_FIRST_ROW_NAMESPACE; 4];
const MEMORY_FIRST_ROW_BASE_ID: usize = MEMORY_WORD_IDX_BASE_ID + MEMORY_WORD_IDX_NAMES.len();
const MEMORY_FIRST_ROW_TAGS: TagGroup = TagGroup {
    base: MEMORY_FIRST_ROW_BASE_ID,
    names: &MEMORY_FIRST_ROW_NAMES,
};

// --- delta inverse (IDs base+10 ..= base+13) ---------------------------------------------------

const MEMORY_DELTA_INV_NAMESPACE: &str = "chiplets.memory.delta.inv";
const MEMORY_DELTA_INV_NAMES: [&str; 4] = [MEMORY_DELTA_INV_NAMESPACE; 4];
const MEMORY_DELTA_INV_BASE_ID: usize = MEMORY_FIRST_ROW_BASE_ID + MEMORY_FIRST_ROW_NAMES.len();
const MEMORY_DELTA_INV_TAGS: TagGroup = TagGroup {
    base: MEMORY_DELTA_INV_BASE_ID,
    names: &MEMORY_DELTA_INV_NAMES,
};

// --- delta transition (ID base+14) -------------------------------------------------------------

const MEMORY_DELTA_TRANSITION_NAMESPACE: &str = "chiplets.memory.delta.transition";
const MEMORY_DELTA_TRANSITION_NAMES: [&str; 1] = [MEMORY_DELTA_TRANSITION_NAMESPACE; 1];
const MEMORY_DELTA_TRANSITION_ID: usize = MEMORY_DELTA_INV_BASE_ID + MEMORY_DELTA_INV_NAMES.len();
const MEMORY_DELTA_TRANSITION_TAGS: TagGroup = TagGroup {
    base: MEMORY_DELTA_TRANSITION_ID,
    names: &MEMORY_DELTA_TRANSITION_NAMES,
};

// --- same-context-and-word flag (ID base+15) ---------------------------------------------------

const MEMORY_SCW_FLAG_NAMESPACE: &str = "chiplets.memory.scw.flag";
const MEMORY_SCW_FLAG_NAMES: [&str; 1] = [MEMORY_SCW_FLAG_NAMESPACE; 1];
const MEMORY_SCW_FLAG_ID: usize = MEMORY_DELTA_TRANSITION_ID + MEMORY_DELTA_TRANSITION_NAMES.len();
const MEMORY_SCW_FLAG_TAGS: TagGroup = TagGroup {
    base: MEMORY_SCW_FLAG_ID,
    names: &MEMORY_SCW_FLAG_NAMES,
};

// --- same-context-and-word, same-clock accesses (ID base+16) -----------------------------------

const MEMORY_SCW_READS_NAMESPACE: &str = "chiplets.memory.scw.reads";
const MEMORY_SCW_READS_NAMES: [&str; 1] = [MEMORY_SCW_READS_NAMESPACE; 1];
const MEMORY_SCW_READS_ID: usize = MEMORY_SCW_FLAG_ID + MEMORY_SCW_FLAG_NAMES.len();
const MEMORY_SCW_READS_TAGS: TagGroup = TagGroup {
    base: MEMORY_SCW_READS_ID,
    names: &MEMORY_SCW_READS_NAMES,
};

// --- value consistency (IDs base+17 ..= base+20) -----------------------------------------------

const MEMORY_VALUE_CONSIST_NAMESPACE: &str = "chiplets.memory.value.consistency";
const MEMORY_VALUE_CONSIST_NAMES: [&str; 4] = [MEMORY_VALUE_CONSIST_NAMESPACE; 4];
const MEMORY_VALUE_CONSIST_BASE_ID: usize = MEMORY_SCW_READS_ID + MEMORY_SCW_READS_NAMES.len();
const MEMORY_VALUE_CONSIST_TAGS: TagGroup = TagGroup {
    base: MEMORY_VALUE_CONSIST_BASE_ID,
    names: &MEMORY_VALUE_CONSIST_NAMES,
};
/// Entry point for the memory constraints: applies the per-row, first-row and transition
/// constraint groups to the `local`/`next` row pair.
///
/// The first-row and transition groups are additionally gated by `builder.is_transition()`.
pub fn enforce_memory_constraints<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
) where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    // The first chiplet columns are the selector inputs consumed by the flag helpers.
    let sel0: AB::Expr = local.chiplets[0].clone().into();
    let sel1: AB::Expr = local.chiplets[1].clone().into();
    let sel1_next: AB::Expr = next.chiplets[1].clone().into();
    let sel2_next: AB::Expr = next.chiplets[2].clone().into();

    let transition: AB::Expr = builder.is_transition();

    // Both gates are plain expressions over the selectors, so they can be built up front.
    let first_row_gate = transition.clone()
        * flag_next_row_first_memory(sel0.clone(), sel1.clone(), sel1_next, sel2_next.clone());
    let not_last_gate = transition * flag_memory_active_not_last_row(sel0, sel1, sel2_next);

    // The order of the three groups fixes the order in which constraints are emitted.
    enforce_memory_constraints_all_rows(builder, local, next);
    enforce_memory_constraints_first_row(builder, local, next, first_row_gate);
    enforce_memory_constraints_all_rows_except_last(builder, local, next, not_last_gate);
}
/// Constraints that involve only the current row, gated by the memory chiplet flag.
///
/// Emits, in this order:
/// - four binary checks `flag * x * (x - 1) = 0` for `is_read`, `is_word`, `idx0`, `idx1`;
/// - two checks `flag * is_word * idx = 0` for `idx0` and `idx1`.
///
/// `_next` is accepted for signature symmetry with the other groups and is not read.
pub fn enforce_memory_constraints_all_rows<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    _next: &MainTraceRow<AB::Var>,
) where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    let sel0: AB::Expr = local.chiplets[0].clone().into();
    let sel1: AB::Expr = local.chiplets[1].clone().into();
    let sel2: AB::Expr = local.chiplets[2].clone().into();
    let is_memory = memory_chiplet_flag(sel0, sel1, sel2);

    let mem: MemoryColumns<AB::Expr> = MemoryColumns::from_row::<AB>(local);
    let one: AB::Expr = AB::Expr::ONE;

    // Binary columns: x * (x - 1) vanishes only for x in {0, 1}.
    let mut tag_idx = 0;
    for col in [&mem.is_read, &mem.is_word, &mem.idx0, &mem.idx1] {
        tagged_assert_zero_integrity(
            builder,
            &MEMORY_BINARY_TAGS,
            &mut tag_idx,
            is_memory.clone() * col.clone() * (col.clone() - one.clone()),
        );
    }

    // When `is_word` is set, both index columns must be zero.
    let word_gate = is_memory * mem.is_word.clone();
    let mut tag_idx = 0;
    for idx_col in [&mem.idx0, &mem.idx1] {
        tagged_assert_zero_integrity(
            builder,
            &MEMORY_WORD_IDX_TAGS,
            &mut tag_idx,
            word_gate.clone() * idx_col.clone(),
        );
    }
}
/// Constraints on the row passed as `cols_first`, gated by `flag_next_row_first_memory`.
///
/// For each of the four value columns this emits `gate * c_i * v_i = 0`, where `c_i` comes
/// from [`MemoryColumns::compute_value_constraint_flags`] evaluated on that same row.
///
/// `_local` is not read.
pub fn enforce_memory_constraints_first_row<AB>(
    builder: &mut AB,
    _local: &MainTraceRow<AB::Var>,
    cols_first: &MainTraceRow<AB::Var>,
    flag_next_row_first_memory: AB::Expr,
) where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    let first: MemoryColumns<AB::Expr> = MemoryColumns::from_row::<AB>(cols_first);
    let one: AB::Expr = AB::Expr::ONE;
    let value_flags = first.compute_value_constraint_flags(one);

    // One constraint per value column, emitted in column order 0..4.
    let mut tag_idx = 0;
    for (value, c_i) in first.values.iter().zip(value_flags) {
        tagged_assert_zero_integrity(
            builder,
            &MEMORY_FIRST_ROW_TAGS,
            &mut tag_idx,
            flag_next_row_first_memory.clone() * c_i * value.clone(),
        );
    }
}
/// Transition constraints between `local` and `next`, all gated by
/// `flag_memory_active_not_last`.
///
/// Emitted in this order:
/// 1. the four delta-inverse constraints (see `enforce_delta_inverse_constraints`);
/// 2. `computed_delta = delta_next`;
/// 3. `flag_same_ctx_word' = (1 - n0) * (1 - n1)`;
/// 4. the same-context-and-word read-only constraint (see `enforce_scw_readonly_constraint`);
/// 5. four value-consistency constraints `c_i * (v_i' - flag_same_ctx_word' * v_i) = 0`.
pub fn enforce_memory_constraints_all_rows_except_last<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
    flag_memory_active_not_last: AB::Expr,
) where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    let cur: MemoryColumns<AB::Expr> = MemoryColumns::from_row::<AB>(local);
    let nxt: MemoryColumns<AB::Expr> = MemoryColumns::from_row::<AB>(next);
    let one: AB::Expr = AB::Expr::ONE;
    let gate = flag_memory_active_not_last;
    let deltas = MemoryDeltas::new::<AB>(&cur, &nxt, one.clone());

    // (1) n0 / n1 behave as "delta is non-zero" indicators.
    enforce_delta_inverse_constraints::<AB>(builder, gate.clone(), &deltas, one.clone());

    // (2) The delta recomposed from d0'/d1' must equal the selected column delta.
    let delta_mismatch = deltas.computed_delta.clone() - deltas.delta_next.clone();
    let mut tag_idx = 0;
    tagged_assert_zero_integrity(
        builder,
        &MEMORY_DELTA_TRANSITION_TAGS,
        &mut tag_idx,
        gate.clone() * delta_mismatch,
    );

    // (3) The flag column of the next row must match (1 - n0) * (1 - n1).
    let expected_scw = (one.clone() - deltas.n0.clone()) * (one.clone() - deltas.n1.clone());
    let mut tag_idx = 0;
    tagged_assert_zero_integrity(
        builder,
        &MEMORY_SCW_FLAG_TAGS,
        &mut tag_idx,
        gate.clone() * (nxt.flag_same_ctx_word.clone() - expected_scw),
    );

    // (4) Same context/word and unchanged clock: neither access may be a write.
    enforce_scw_readonly_constraint::<AB>(
        builder,
        gate.clone(),
        &cur,
        &nxt,
        &deltas,
        one.clone(),
    );

    // (5) Value consistency, one constraint per value column in column order 0..4.
    let value_flags = nxt.compute_value_constraint_flags(one);
    let mut tag_idx = 0;
    for ((value, value_next), c_i) in cur.values.iter().zip(nxt.values.iter()).zip(value_flags) {
        let carried_over = nxt.flag_same_ctx_word.clone() * value.clone();
        tagged_assert_zero_integrity(
            builder,
            &MEMORY_VALUE_CONSIST_TAGS,
            &mut tag_idx,
            gate.clone() * c_i * (value_next.clone() - carried_over),
        );
    }
}
/// Expressions derived from a pair of consecutive memory rows.
struct MemoryDeltas<E> {
    /// `ctx' - ctx`.
    ctx_delta: E,
    /// `word_addr' - word_addr`.
    addr_delta: E,
    /// `clk' - clk`.
    clk_delta: E,
    /// `ctx_delta * d_inv'`.
    n0: E,
    /// `addr_delta * d_inv'`.
    n1: E,
    /// `d1' * 2^16 + d0'`.
    delta_next: E,
    /// `n0 * ctx_delta + (1 - n0) * (n1 * addr_delta + (1 - n1) * clk_delta)`.
    computed_delta: E,
}

impl<E> MemoryDeltas<E>
where
    E: Clone + Add<Output = E> + Sub<Output = E> + Mul<Output = E>,
{
    /// Builds all delta expressions from the current (`cols`) and next (`cols_next`) rows.
    ///
    /// `one` must be the multiplicative identity of `E`; `AB` is only used to construct the
    /// constant `2^16`.
    fn new<AB>(cols: &MemoryColumns<E>, cols_next: &MemoryColumns<E>, one: E) -> Self
    where
        AB: LiftedAirBuilder<F = Felt>,
        AB::Expr: Into<E>,
    {
        let inv_next = &cols_next.d_inv;

        // Column-wise differences between the two rows.
        let ctx_delta = cols_next.ctx.clone() - cols.ctx.clone();
        let addr_delta = cols_next.word_addr.clone() - cols.word_addr.clone();
        let clk_delta = cols_next.clk.clone() - cols.clk.clone();

        let n0 = ctx_delta.clone() * inv_next.clone();
        let n1 = addr_delta.clone() * inv_next.clone();

        // Delta recomposed from its two limbs in the next row.
        let limb_shift: E = AB::Expr::from_u32(1 << 16).into();
        let delta_next = cols_next.d1.clone() * limb_shift + cols_next.d0.clone();

        // n0 selects the context delta; otherwise n1 chooses between address and clock delta.
        let addr_or_clk =
            n1.clone() * addr_delta.clone() + (one.clone() - n1.clone()) * clk_delta.clone();
        let computed_delta = n0.clone() * ctx_delta.clone() + (one - n0.clone()) * addr_or_clk;

        Self {
            ctx_delta,
            addr_delta,
            clk_delta,
            n0,
            n1,
            delta_next,
            computed_delta,
        }
    }
}
/// Typed view over the memory chiplet columns of a single trace row.
pub struct MemoryColumns<E> {
    /// Column at `MEMORY_IS_READ_COL_IDX`; `1 - is_read` is used as the write flag.
    pub is_read: E,
    /// Column at `MEMORY_IS_WORD_ACCESS_COL_IDX`; `1 - is_word` is used as the element flag.
    pub is_word: E,
    /// Column at `MEMORY_CTX_COL_IDX`.
    pub ctx: E,
    /// Column at `MEMORY_WORD_COL_IDX`.
    pub word_addr: E,
    /// Column at `MEMORY_IDX0_COL_IDX` (low bit of the element selector).
    pub idx0: E,
    /// Column at `MEMORY_IDX1_COL_IDX` (high bit of the element selector).
    pub idx1: E,
    /// Column at `MEMORY_CLK_COL_IDX`.
    pub clk: E,
    /// Four consecutive columns starting at `MEMORY_V_COL_RANGE.start`.
    pub values: [E; 4],
    /// Column at `MEMORY_D0_COL_IDX` (low limb of the delta).
    pub d0: E,
    /// Column at `MEMORY_D1_COL_IDX` (high limb of the delta, weighted by `2^16`).
    pub d1: E,
    /// Column at `MEMORY_D_INV_COL_IDX`.
    pub d_inv: E,
    /// Column at `MEMORY_FLAG_SAME_CONTEXT_AND_WORD`.
    pub flag_same_ctx_word: E,
}

impl<E: Clone> MemoryColumns<E> {
    /// Reads the memory columns out of `row`.
    ///
    /// The `MEMORY_*` constants are indices into the full trace, so each is rebased by
    /// `CHIPLETS_OFFSET` before indexing `row.chiplets`.
    pub fn from_row<AB>(row: &MainTraceRow<AB::Var>) -> Self
    where
        AB: LiftedAirBuilder<F = Felt>,
        AB::Var: Into<E> + Clone,
    {
        let col = |global_idx: usize| -> E {
            row.chiplets[global_idx - CHIPLETS_OFFSET].clone().into()
        };

        Self {
            is_read: col(MEMORY_IS_READ_COL_IDX),
            is_word: col(MEMORY_IS_WORD_ACCESS_COL_IDX),
            ctx: col(MEMORY_CTX_COL_IDX),
            word_addr: col(MEMORY_WORD_COL_IDX),
            idx0: col(MEMORY_IDX0_COL_IDX),
            idx1: col(MEMORY_IDX1_COL_IDX),
            clk: col(MEMORY_CLK_COL_IDX),
            values: [0, 1, 2, 3].map(|offset| col(MEMORY_V_COL_RANGE.start + offset)),
            d0: col(MEMORY_D0_COL_IDX),
            d1: col(MEMORY_D1_COL_IDX),
            d_inv: col(MEMORY_D_INV_COL_IDX),
            flag_same_ctx_word: col(MEMORY_FLAG_SAME_CONTEXT_AND_WORD),
        }
    }

    /// Returns `[c_0, c_1, c_2, c_3]` with
    /// `c_i = is_read + (1 - is_read) * (1 - is_word) * (1 - f_i)`,
    /// where `f_i` is the product of `idx0`/`idx1` (or their complements) that equals one
    /// exactly when `2 * idx1 + idx0 == i` for binary index columns.
    ///
    /// `one` must convert into the multiplicative identity of `E`.
    pub fn compute_value_constraint_flags<One>(&self, one: One) -> [E; 4]
    where
        E: Add<Output = E> + Sub<Output = E> + Mul<Output = E>,
        One: Into<E>,
    {
        let one: E = one.into();
        let complement = |x: &E| one.clone() - x.clone();
        let (lo, hi) = (&self.idx0, &self.idx1);

        // Factor shared by all four flags: a write that targets a single element.
        let element_write = complement(&self.is_read) * complement(&self.is_word);

        // f_i for i = 0..4, i.e. (idx1, idx0) = (0,0), (0,1), (1,0), (1,1).
        let selected = [
            complement(hi) * complement(lo),
            complement(hi) * lo.clone(),
            hi.clone() * complement(lo),
            hi.clone() * lo.clone(),
        ];

        selected.map(|f_i| {
            let unselected = one.clone() - f_i;
            self.is_read.clone() + element_write.clone() * unselected
        })
    }
}
fn enforce_delta_inverse_constraints<AB>(
builder: &mut AB,
flag_memory_active_not_last: AB::Expr,
deltas: &MemoryDeltas<AB::Expr>,
one: AB::Expr,
) where
AB: TaggingAirBuilderExt<F = Felt>,
{
let n0 = deltas.n0.clone();
let n1 = deltas.n1.clone();
let ctx_delta = deltas.ctx_delta.clone();
let addr_delta = deltas.addr_delta.clone();
let not_n0 = one.clone() - n0.clone();
let not_n1 = one.clone() - n1.clone();
let gate = flag_memory_active_not_last;
let gate_not_n0 = gate.clone() * not_n0.clone();
let mut idx = 0;
tagged_assert_zero_integrity(
builder,
&MEMORY_DELTA_INV_TAGS,
&mut idx,
gate * n0.clone() * (n0.clone() - one.clone()),
);
tagged_assert_zero_integrity(
builder,
&MEMORY_DELTA_INV_TAGS,
&mut idx,
gate_not_n0.clone() * ctx_delta.clone(),
);
tagged_assert_zero_integrity(
builder,
&MEMORY_DELTA_INV_TAGS,
&mut idx,
gate_not_n0.clone() * n1.clone() * (n1.clone() - one.clone()),
);
tagged_assert_zero_integrity(
builder,
&MEMORY_DELTA_INV_TAGS,
&mut idx,
gate_not_n0 * not_n1 * addr_delta.clone(),
);
}
/// Emits the single constraint
/// `gate * flag_same_ctx_word' * (1 - clk_delta * d_inv') * ((1 - is_read) + (1 - is_read')) = 0`.
///
/// With binary `is_read` columns, the last factor is zero only when both rows are reads.
/// `one` must be the multiplicative identity.
fn enforce_scw_readonly_constraint<AB>(
    builder: &mut AB,
    flag_memory_active_not_last: AB::Expr,
    cols: &MemoryColumns<AB::Expr>,
    cols_next: &MemoryColumns<AB::Expr>,
    deltas: &MemoryDeltas<AB::Expr>,
    one: AB::Expr,
) where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    // 1 - clk_delta * d_inv'
    let clk_unchanged = one.clone() - deltas.clk_delta.clone() * cols_next.d_inv.clone();

    // Sum of the write flags of the current and the next row.
    let write_count = (one.clone() - cols.is_read.clone()) + (one - cols_next.is_read.clone());

    let violation = flag_memory_active_not_last
        * cols_next.flag_same_ctx_word.clone()
        * clk_unchanged
        * write_count;

    let mut tag_idx = 0;
    tagged_assert_zero_integrity(builder, &MEMORY_SCW_READS_TAGS, &mut tag_idx, violation);
}
/// Returns `s0 * s1 * (1 - s2_next)`.
///
/// `s0` and `s1` are taken from the current row and `s2_next` from the next row.
pub fn flag_memory_active_not_last_row<E: PrimeCharacteristicRing>(s0: E, s1: E, s2_next: E) -> E {
    let s2_next_clear = E::ONE - s2_next;
    s0 * s1 * s2_next_clear
}
/// Returns `(1 - s1) * s0 * s1_next * (1 - s2_next)`.
///
/// `s0` and `s1` are taken from the current row, `s1_next` and `s2_next` from the next row.
pub fn flag_next_row_first_memory<E: PrimeCharacteristicRing>(
    s0: E,
    s1: E,
    s1_next: E,
    s2_next: E,
) -> E {
    let s1_clear = E::ONE - s1;
    let s2_next_clear = E::ONE - s2_next;
    s1_clear * s0 * s1_next * s2_next_clear
}