use miden_core::field::PrimeCharacteristicRing;
use super::selectors::{ace_chiplet_flag, memory_chiplet_flag};
use crate::{
Felt, MainTraceRow,
constraints::tagging::{TagGroup, TaggingAirBuilderExt, tagged_assert_zero_integrity},
trace::chiplets::ace::{
CLK_IDX, CTX_IDX, EVAL_OP_IDX, ID_0_IDX, ID_1_IDX, PTR_IDX, READ_NUM_EVAL_IDX,
SELECTOR_BLOCK_IDX, SELECTOR_START_IDX, V_0_0_IDX, V_0_1_IDX, V_1_0_IDX, V_1_1_IDX,
V_2_0_IDX, V_2_1_IDX,
},
};
/// Offset of the first ACE column inside `MainTraceRow::chiplets`; columns `0..ACE_OFFSET` are
/// the chiplet selectors that this module reads as `s0..s3`.
const ACE_OFFSET: usize = 4;

/// First constraint ID owned by the ACE chiplet, directly after the memory chiplet's IDs.
const ACE_BASE_ID: usize = super::memory::MEMORY_BASE_ID + super::memory::MEMORY_COUNT;

// Number of constraints in each ACE tag group, in emission order. The group base IDs below are
// chained from these counts so that resizing one group cannot silently overlap ID ranges.
const ACE_BINARY_COUNT: usize = 2;
const ACE_SECTION_COUNT: usize = 5;
const ACE_WITHIN_SECTION_COUNT: usize = 4;
const ACE_READ_ID_COUNT: usize = 1;
const ACE_READ_TO_EVAL_COUNT: usize = 1;
const ACE_EVAL_OP_COUNT: usize = 1;
const ACE_EVAL_RESULT_COUNT: usize = 2;
const ACE_FINAL_COUNT: usize = 3;
const ACE_FIRST_ROW_COUNT: usize = 1;

/// Total number of ACE constraints (the sum of all group sizes).
pub(super) const ACE_COUNT: usize = ACE_BINARY_COUNT
    + ACE_SECTION_COUNT
    + ACE_WITHIN_SECTION_COUNT
    + ACE_READ_ID_COUNT
    + ACE_READ_TO_EVAL_COUNT
    + ACE_EVAL_OP_COUNT
    + ACE_EVAL_RESULT_COUNT
    + ACE_FINAL_COUNT
    + ACE_FIRST_ROW_COUNT;

// Base constraint ID of each group; each group starts right after the previous one ends.
const ACE_BINARY_BASE_ID: usize = ACE_BASE_ID;
const ACE_SECTION_BASE_ID: usize = ACE_BINARY_BASE_ID + ACE_BINARY_COUNT;
const ACE_WITHIN_SECTION_BASE_ID: usize = ACE_SECTION_BASE_ID + ACE_SECTION_COUNT;
const ACE_READ_ID_ID: usize = ACE_WITHIN_SECTION_BASE_ID + ACE_WITHIN_SECTION_COUNT;
const ACE_READ_TO_EVAL_ID: usize = ACE_READ_ID_ID + ACE_READ_ID_COUNT;
const ACE_EVAL_OP_ID: usize = ACE_READ_TO_EVAL_ID + ACE_READ_TO_EVAL_COUNT;
const ACE_EVAL_RESULT_BASE_ID: usize = ACE_EVAL_OP_ID + ACE_EVAL_OP_COUNT;
const ACE_FINAL_BASE_ID: usize = ACE_EVAL_RESULT_BASE_ID + ACE_EVAL_RESULT_COUNT;
const ACE_FIRST_ROW_ID: usize = ACE_FINAL_BASE_ID + ACE_FINAL_COUNT;

// The chained ID ranges must end exactly at the end of the ACE ID space.
const _: () = assert!(ACE_FIRST_ROW_ID + ACE_FIRST_ROW_COUNT == ACE_BASE_ID + ACE_COUNT);

const ACE_BINARY_NAMESPACE: &str = "chiplets.ace.selector.binary";
const ACE_SECTION_NAMESPACE: &str = "chiplets.ace.section.flags";
const ACE_WITHIN_SECTION_NAMESPACE: &str = "chiplets.ace.section.transition";
const ACE_READ_ID_NAMESPACE: &str = "chiplets.ace.read.ids";
const ACE_READ_TO_EVAL_NAMESPACE: &str = "chiplets.ace.read.to_eval";
const ACE_EVAL_OP_NAMESPACE: &str = "chiplets.ace.eval.op";
const ACE_EVAL_RESULT_NAMESPACE: &str = "chiplets.ace.eval.result";
const ACE_FINAL_NAMESPACE: &str = "chiplets.ace.final.zero";
const ACE_FIRST_ROW_NAMESPACE: &str = "chiplets.ace.first_row.start";

// Every constraint in a group shares the group's namespace.
const ACE_BINARY_NAMES: [&str; ACE_BINARY_COUNT] = [ACE_BINARY_NAMESPACE; ACE_BINARY_COUNT];
const ACE_SECTION_NAMES: [&str; ACE_SECTION_COUNT] = [ACE_SECTION_NAMESPACE; ACE_SECTION_COUNT];
const ACE_WITHIN_SECTION_NAMES: [&str; ACE_WITHIN_SECTION_COUNT] =
    [ACE_WITHIN_SECTION_NAMESPACE; ACE_WITHIN_SECTION_COUNT];
const ACE_READ_ID_NAMES: [&str; ACE_READ_ID_COUNT] = [ACE_READ_ID_NAMESPACE; ACE_READ_ID_COUNT];
const ACE_READ_TO_EVAL_NAMES: [&str; ACE_READ_TO_EVAL_COUNT] =
    [ACE_READ_TO_EVAL_NAMESPACE; ACE_READ_TO_EVAL_COUNT];
const ACE_EVAL_OP_NAMES: [&str; ACE_EVAL_OP_COUNT] = [ACE_EVAL_OP_NAMESPACE; ACE_EVAL_OP_COUNT];
const ACE_EVAL_RESULT_NAMES: [&str; ACE_EVAL_RESULT_COUNT] =
    [ACE_EVAL_RESULT_NAMESPACE; ACE_EVAL_RESULT_COUNT];
const ACE_FINAL_NAMES: [&str; ACE_FINAL_COUNT] = [ACE_FINAL_NAMESPACE; ACE_FINAL_COUNT];
const ACE_FIRST_ROW_NAMES: [&str; ACE_FIRST_ROW_COUNT] =
    [ACE_FIRST_ROW_NAMESPACE; ACE_FIRST_ROW_COUNT];

const ACE_BINARY_TAGS: TagGroup = TagGroup {
    base: ACE_BINARY_BASE_ID,
    names: &ACE_BINARY_NAMES,
};
const ACE_SECTION_TAGS: TagGroup = TagGroup {
    base: ACE_SECTION_BASE_ID,
    names: &ACE_SECTION_NAMES,
};
const ACE_WITHIN_SECTION_TAGS: TagGroup = TagGroup {
    base: ACE_WITHIN_SECTION_BASE_ID,
    names: &ACE_WITHIN_SECTION_NAMES,
};
const ACE_READ_ID_TAGS: TagGroup = TagGroup {
    base: ACE_READ_ID_ID,
    names: &ACE_READ_ID_NAMES,
};
const ACE_READ_TO_EVAL_TAGS: TagGroup = TagGroup {
    base: ACE_READ_TO_EVAL_ID,
    names: &ACE_READ_TO_EVAL_NAMES,
};
const ACE_EVAL_OP_TAGS: TagGroup = TagGroup {
    base: ACE_EVAL_OP_ID,
    names: &ACE_EVAL_OP_NAMES,
};
const ACE_EVAL_RESULT_TAGS: TagGroup = TagGroup {
    base: ACE_EVAL_RESULT_BASE_ID,
    names: &ACE_EVAL_RESULT_NAMES,
};
const ACE_FINAL_TAGS: TagGroup = TagGroup {
    base: ACE_FINAL_BASE_ID,
    names: &ACE_FINAL_NAMES,
};
const ACE_FIRST_ROW_TAGS: TagGroup = TagGroup {
    base: ACE_FIRST_ROW_ID,
    names: &ACE_FIRST_ROW_NAMES,
};
/// Enforces all ACE chiplet constraints for the `local` → `next` row pair.
///
/// The constraints that apply on every ACE row are emitted first. They are followed by the
/// constraint on the transition from the memory chiplet into the ACE chiplet, which requires
/// the first ACE row to start a section.
///
/// NOTE(review): the first-row constraint only fires when `memory_chiplet_flag` holds on `local`
/// and `next` has `s2 = 1, s3 = 0`; confirm that the ACE chiplet can never directly follow a
/// chiplet other than memory.
pub fn enforce_ace_constraints<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
) where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    let is_transition: AB::Expr = builder.is_transition();

    enforce_ace_constraints_all_rows(builder, local, next);

    // Reads chiplet selector column `i` of `row` as an expression.
    let selector =
        |row: &MainTraceRow<AB::Var>, i: usize| -> AB::Expr { row.chiplets[i].clone().into() };

    let leaving_memory =
        memory_chiplet_flag(selector(local, 0), selector(local, 1), selector(local, 2));
    // `next` has s2 = 1 and s3 = 0.
    let entering_ace = selector(next, 2) * (AB::Expr::ONE - selector(next, 3));

    enforce_ace_constraints_first_row(
        builder,
        local,
        next,
        is_transition * leaving_memory * entering_ace,
    );
}
pub fn enforce_ace_constraints_all_rows<AB>(
builder: &mut AB,
local: &MainTraceRow<AB::Var>,
next: &MainTraceRow<AB::Var>,
) where
AB: TaggingAirBuilderExt<F = Felt>,
{
let s0: AB::Expr = local.chiplets[0].clone().into();
let s1: AB::Expr = local.chiplets[1].clone().into();
let s2: AB::Expr = local.chiplets[2].clone().into();
let s3: AB::Expr = local.chiplets[3].clone().into();
let s3_next: AB::Expr = next.chiplets[3].clone().into();
let ace_flag = ace_chiplet_flag(s0.clone(), s1.clone(), s2.clone(), s3.clone());
let sstart: AB::Expr = load_ace_col::<AB>(local, SELECTOR_START_IDX);
let sstart_next: AB::Expr = load_ace_col::<AB>(next, SELECTOR_START_IDX);
let sblock: AB::Expr = load_ace_col::<AB>(local, SELECTOR_BLOCK_IDX);
let sblock_next: AB::Expr = load_ace_col::<AB>(next, SELECTOR_BLOCK_IDX);
let ctx: AB::Expr = load_ace_col::<AB>(local, CTX_IDX);
let ctx_next: AB::Expr = load_ace_col::<AB>(next, CTX_IDX);
let ptr: AB::Expr = load_ace_col::<AB>(local, PTR_IDX);
let ptr_next: AB::Expr = load_ace_col::<AB>(next, PTR_IDX);
let clk: AB::Expr = load_ace_col::<AB>(local, CLK_IDX);
let clk_next: AB::Expr = load_ace_col::<AB>(next, CLK_IDX);
let op: AB::Expr = load_ace_col::<AB>(local, EVAL_OP_IDX);
let id0: AB::Expr = load_ace_col::<AB>(local, ID_0_IDX);
let id0_next: AB::Expr = load_ace_col::<AB>(next, ID_0_IDX);
let id1: AB::Expr = load_ace_col::<AB>(local, ID_1_IDX);
let n_eval: AB::Expr = load_ace_col::<AB>(local, READ_NUM_EVAL_IDX);
let n_eval_next: AB::Expr = load_ace_col::<AB>(next, READ_NUM_EVAL_IDX);
let v0_0: AB::Expr = load_ace_col::<AB>(local, V_0_0_IDX);
let v0_1: AB::Expr = load_ace_col::<AB>(local, V_0_1_IDX);
let v1_0: AB::Expr = load_ace_col::<AB>(local, V_1_0_IDX);
let v1_1: AB::Expr = load_ace_col::<AB>(local, V_1_1_IDX);
let v2_0: AB::Expr = load_ace_col::<AB>(local, V_2_0_IDX);
let v2_1: AB::Expr = load_ace_col::<AB>(local, V_2_1_IDX);
let one: AB::Expr = AB::Expr::ONE;
let four: AB::Expr = AB::Expr::from_u32(4);
let is_transition: AB::Expr = builder.is_transition();
let flag_ace_next = is_transition.clone() * (one.clone() - s3_next.clone());
let flag_ace_last = is_transition.clone() * s3_next.clone();
let mut idx = 0;
tagged_assert_zero_integrity(
builder,
&ACE_BINARY_TAGS,
&mut idx,
ace_flag.clone() * sstart.clone() * (sstart.clone() - one.clone()),
);
tagged_assert_zero_integrity(
builder,
&ACE_BINARY_TAGS,
&mut idx,
ace_flag.clone() * sblock.clone() * (sblock.clone() - one.clone()),
);
let f_next = one.clone() - sstart_next.clone();
let f_end = binary_or((one.clone() - s3_next.clone()) * sstart_next.clone(), s3_next.clone());
let mut idx = 0;
tagged_assert_zero_integrity(
builder,
&ACE_SECTION_TAGS,
&mut idx,
ace_flag.clone() * flag_ace_last.clone() * sstart.clone(),
);
tagged_assert_zero_integrity(
builder,
&ACE_SECTION_TAGS,
&mut idx,
ace_flag.clone() * flag_ace_next.clone() * sstart.clone() * sstart_next.clone(),
);
tagged_assert_zero_integrity(
builder,
&ACE_SECTION_TAGS,
&mut idx,
ace_flag.clone() * sstart.clone() * sblock.clone(),
);
tagged_assert_zero_integrity(
builder,
&ACE_SECTION_TAGS,
&mut idx,
ace_flag.clone()
* flag_ace_next.clone()
* f_next.clone()
* sblock.clone()
* (one.clone() - sblock_next.clone()),
);
tagged_assert_zero_integrity(
builder,
&ACE_SECTION_TAGS,
&mut idx,
ace_flag.clone() * is_transition.clone() * f_end.clone() * (one.clone() - sblock.clone()),
);
let flag_within_section = one.clone() - sstart_next.clone();
let f_read = one.clone() - sblock.clone();
let f_eval = sblock.clone();
let within_section_gate =
ace_flag.clone() * flag_ace_next.clone() * flag_within_section.clone();
let expected_ptr_next = ptr.clone() + four.clone() * f_read.clone() + f_eval.clone();
let expected_id0 = id0_next.clone() + f_read.clone().double() + f_eval.clone();
let mut idx = 0;
tagged_assert_zero_integrity(
builder,
&ACE_WITHIN_SECTION_TAGS,
&mut idx,
within_section_gate.clone() * (ctx_next.clone() - ctx.clone()),
);
tagged_assert_zero_integrity(
builder,
&ACE_WITHIN_SECTION_TAGS,
&mut idx,
within_section_gate.clone() * (clk_next.clone() - clk.clone()),
);
tagged_assert_zero_integrity(
builder,
&ACE_WITHIN_SECTION_TAGS,
&mut idx,
within_section_gate.clone() * (ptr_next.clone() - expected_ptr_next),
);
tagged_assert_zero_integrity(
builder,
&ACE_WITHIN_SECTION_TAGS,
&mut idx,
within_section_gate * (id0.clone() - expected_id0),
);
let mut idx = 0;
tagged_assert_zero_integrity(
builder,
&ACE_READ_ID_TAGS,
&mut idx,
ace_flag.clone() * f_read.clone() * (id1.clone() - id0.clone() + one.clone()),
);
let f_read_next = one.clone() - sblock_next.clone();
let f_eval_next = sblock_next.clone();
let selected = f_read_next * n_eval_next.clone() + f_eval_next * id0_next.clone();
let mut idx = 0;
tagged_assert_zero_integrity(
builder,
&ACE_READ_TO_EVAL_TAGS,
&mut idx,
is_transition.clone() * ace_flag.clone() * f_read.clone() * (selected - n_eval),
);
let mut idx = 0;
tagged_assert_zero_integrity(
builder,
&ACE_EVAL_OP_TAGS,
&mut idx,
ace_flag.clone()
* f_eval.clone()
* op.clone()
* (op.clone() - one.clone())
* (op.clone() + one.clone()),
);
let eval_gate = ace_flag.clone() * f_eval.clone();
let (expected_0, expected_1) = compute_arithmetic_expected::<AB>(op, v1_0, v1_1, v2_0, v2_1);
let mut idx = 0;
tagged_assert_zero_integrity(
builder,
&ACE_EVAL_RESULT_TAGS,
&mut idx,
eval_gate.clone() * (expected_0 - v0_0.clone()),
);
tagged_assert_zero_integrity(
builder,
&ACE_EVAL_RESULT_TAGS,
&mut idx,
eval_gate.clone() * (expected_1 - v0_1.clone()),
);
let gate = ace_flag * is_transition * f_end;
let mut idx = 0;
tagged_assert_zero_integrity(builder, &ACE_FINAL_TAGS, &mut idx, gate.clone() * v0_0);
tagged_assert_zero_integrity(builder, &ACE_FINAL_TAGS, &mut idx, gate.clone() * v0_1);
tagged_assert_zero_integrity(builder, &ACE_FINAL_TAGS, &mut idx, gate * id0);
}
/// Requires the first row of the ACE chiplet to open a section (`sstart = 1` on `next`).
///
/// `flag_next_row_first_ace` is the caller-supplied flag that selects the transition into the
/// first ACE row; the constraint is `flag * (sstart' - 1) = 0`. `_local` is not read.
pub fn enforce_ace_constraints_first_row<AB>(
    builder: &mut AB,
    _local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
    flag_next_row_first_ace: AB::Expr,
) where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    let start_next = load_ace_col::<AB>(next, SELECTOR_START_IDX);
    let constraint = flag_next_row_first_ace * (start_next - AB::Expr::ONE);

    let mut tag_idx = 0;
    tagged_assert_zero_integrity(builder, &ACE_FIRST_ROW_TAGS, &mut tag_idx, constraint);
}
/// Reads ACE-local column `ace_col_idx` of `row` as an expression.
///
/// ACE-local indices are shifted by `ACE_OFFSET` to skip the leading chiplet selector columns.
fn load_ace_col<AB>(row: &MainTraceRow<AB::Var>, ace_col_idx: usize) -> AB::Expr
where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    let column = &row.chiplets[ACE_OFFSET + ace_col_idx];
    column.clone().into()
}
/// Computes the two base-field coordinates of the expected EVAL-row output.
///
/// The operands are `v1 = (v1_0, v1_1)` and `v2 = (v2_0, v2_1)`, combined as `QuadFeltExpr`
/// values. When `op ∈ {-1, 0, 1}` (enforced separately on EVAL rows), the result is:
/// - `v1 - v2` for `op = -1`;
/// - the `QuadFeltExpr` product `v1 · v2` for `op = 0`;
/// - `v1 + v2` for `op = 1`.
///
/// `op²` acts as the selector between the linear term `v1 + op·v2` and the product.
fn compute_arithmetic_expected<AB>(
    op: AB::Expr,
    v1_0: AB::Expr,
    v1_1: AB::Expr,
    v2_0: AB::Expr,
    v2_1: AB::Expr,
) -> (AB::Expr, AB::Expr)
where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    use crate::constraints::ext_field::QuadFeltExpr;

    let lhs = QuadFeltExpr(v1_0, v1_1);
    let rhs = QuadFeltExpr(v2_0, v2_1);

    // lhs + rhs for op = 1, lhs - rhs for op = -1.
    let QuadFeltExpr(add_sub_0, add_sub_1) = lhs.clone() + rhs.clone() * op.clone();
    // Used when op = 0.
    let QuadFeltExpr(mul_0, mul_1) = lhs * rhs;
    // 1 for op = ±1, 0 for op = 0.
    let is_linear = op.clone() * op;

    QuadFeltExpr(
        is_linear.clone() * (add_sub_0 - mul_0.clone()) + mul_0,
        is_linear * (add_sub_1 - mul_1.clone()) + mul_1,
    )
    .into_parts()
    .into()
}
/// Logical OR of two values expected to be binary, written as the polynomial `a + b - a·b`.
///
/// For `a, b ∈ {0, 1}` this returns 1 if either input is 1 and 0 otherwise; other inputs are
/// simply evaluated through the same polynomial.
#[inline]
pub fn binary_or<E>(a: E, b: E) -> E
where
    E: Clone + core::ops::Add<Output = E> + core::ops::Sub<Output = E> + core::ops::Mul<Output = E>,
{
    let both = a.clone() * b.clone();
    a + b - both
}