use alloc::vec::Vec;
use core::borrow::BorrowMut;
use miden_air::{
AceCols, QuadFeltExpr,
trace::{RowIndex, chiplets::ace::ACE_CHIPLET_NUM_COLS},
};
use miden_core::{
Felt, Word,
field::{BasedVectorSpace, QuadFelt},
};
use super::{
MAX_NUM_ACE_WIRES,
instruction::{Op, decode_instruction},
};
use crate::{ContextId, errors::AceError};
/// A single READ row of the trace: one word recorded by `do_read`, split into two
/// extension-field wires.
#[derive(Clone, Copy, Debug)]
struct ReadNode {
    /// Pointer value recorded for this row, as passed to `do_read`.
    ptr: Felt,
    /// Wire id assigned to the first half of the word.
    id_0: Felt,
    /// Extension-field value built from `word[0]` and `word[1]`.
    v_0: QuadFelt,
    /// Wire id assigned to the second half of the word.
    id_1: Felt,
    /// Extension-field value built from `word[2]` and `word[3]`.
    v_1: QuadFelt,
}
/// A single EVAL row of the trace: one arithmetic gate with two input wires and one output wire.
#[derive(Clone, Copy, Debug)]
struct EvalNode {
    /// Pointer value recorded for this row, as passed to `do_eval`.
    ptr: Felt,
    /// Operation selector written to the trace: -1 for Sub, 0 for Mul, 1 for Add.
    eval_op: Felt,
    /// Id of the output wire produced by this gate.
    id_0: Felt,
    /// Value of the output wire.
    v_0: QuadFelt,
    /// Id of the left input wire.
    id_1: Felt,
    /// Value of the left input wire.
    v_1: QuadFelt,
    /// Id of the right input wire.
    id_2: Felt,
    /// Value of the right input wire.
    v_2: QuadFelt,
}
/// The rows produced while evaluating one arithmetic circuit, plus the wire bus that connects
/// them.
///
/// READ rows come first in the trace, followed by EVAL rows; see `fill`.
#[derive(Clone, Debug)]
pub struct CircuitEvaluation {
    /// Execution context copied into every row.
    ctx: ContextId,
    /// Clock value copied into every row.
    clk: RowIndex,
    /// Wire values and their read multiplicities.
    wire_bus: WireBus,
    /// READ rows, in the order `do_read` was called.
    read_nodes: Vec<ReadNode>,
    /// EVAL rows, in the order `do_eval` was called.
    eval_nodes: Vec<EvalNode>,
}
impl CircuitEvaluation {
    /// Creates an empty evaluation sized for `num_read_rows` READ rows and `num_eval_rows` EVAL
    /// rows.
    ///
    /// Every READ row inserts two wires and every EVAL row inserts one, so the wire bus is sized
    /// for `2 * num_read_rows + num_eval_rows` wires.
    ///
    /// # Panics
    ///
    /// Panics if that wire count exceeds `MAX_NUM_ACE_WIRES`.
    pub fn new(ctx: ContextId, clk: RowIndex, num_read_rows: u32, num_eval_rows: u32) -> Self {
        // Computed in u64 so the sum cannot overflow before it is range-checked.
        let num_wires = 2 * (num_read_rows as u64) + (num_eval_rows as u64);
        assert!(num_wires <= MAX_NUM_ACE_WIRES as u64, "too many wires");
        Self {
            ctx,
            clk,
            wire_bus: WireBus::new(num_wires as u32),
            read_nodes: Vec::with_capacity(num_read_rows as usize),
            eval_nodes: Vec::with_capacity(num_eval_rows as usize),
        }
    }

    /// Returns the number of rows recorded so far (READ rows plus EVAL rows).
    pub fn num_rows(&self) -> usize {
        self.read_nodes.len() + self.eval_nodes.len()
    }

    /// Returns the clock value this evaluation was created with.
    pub fn clk(&self) -> u32 {
        self.clk.into()
    }

    /// Returns the context id this evaluation was created with.
    pub fn ctx(&self) -> u32 {
        self.ctx.into()
    }

    /// Returns the number of READ rows recorded so far.
    pub fn num_read_rows(&self) -> u32 {
        self.read_nodes.len() as u32
    }

    /// Returns the number of EVAL rows recorded so far.
    pub fn num_eval_rows(&self) -> u32 {
        self.eval_nodes.len() as u32
    }

    /// Records a READ row.
    ///
    /// `word` is split into two extension-field values, `(word[0], word[1])` and
    /// `(word[2], word[3])`. Each value is inserted into the wire bus as a new wire.
    pub fn do_read(&mut self, ptr: Felt, word: Word) {
        let v_0 = QuadFelt::from_basis_coefficients_fn(|i: usize| [word[0], word[1]][i]);
        let id_0 = self.wire_bus.insert(v_0);
        let v_1 = QuadFelt::from_basis_coefficients_fn(|i: usize| [word[2], word[3]][i]);
        let id_1 = self.wire_bus.insert(v_1);
        self.read_nodes.push(ReadNode { ptr, id_0, v_0, id_1, v_1 });
    }

    /// Records an EVAL row for the encoded `instruction`.
    ///
    /// The instruction is decoded into a left input id, a right input id and an operation. Both
    /// inputs are read from the wire bus, which increments their multiplicities. The result is
    /// then inserted as a new wire.
    ///
    /// # Errors
    ///
    /// Returns an [`AceError`] in either of these cases:
    /// - `instruction` cannot be decoded.
    /// - An input id does not refer to a wire already present on the bus.
    pub fn do_eval(&mut self, ptr: Felt, instruction: Felt) -> Result<(), AceError> {
        // Errors are constructed lazily so the success path does not build a message for
        // every instruction.
        let (id_l, id_r, op) = decode_instruction(instruction)
            .ok_or_else(|| AceError("failed to decode instruction".into()))?;
        let v_l = self
            .wire_bus
            .read_value(id_l)
            .ok_or_else(|| AceError("failed to read from the wiring bus".into()))?;
        let v_r = self
            .wire_bus
            .read_value(id_r)
            .ok_or_else(|| AceError("failed to read from the wiring bus".into()))?;

        // `eval_op` is the trace encoding of the operation: -1 = Sub, 0 = Mul, 1 = Add.
        let (v_0, eval_op) = match op {
            Op::Sub => (v_l - v_r, -Felt::ONE),
            Op::Mul => (v_l * v_r, Felt::ZERO),
            Op::Add => (v_l + v_r, Felt::ONE),
        };
        let id_0 = self.wire_bus.insert(v_0);

        self.eval_nodes.push(EvalNode {
            ptr,
            eval_op,
            id_0,
            v_0,
            id_1: Felt::from_u32(id_l),
            v_1: v_l,
            id_2: Felt::from_u32(id_r),
            v_2: v_r,
        });
        Ok(())
    }

    /// Writes all recorded rows into `out`, starting at row `offset`.
    ///
    /// `out` is treated as consecutive rows of `ACE_CHIPLET_NUM_COLS` elements. READ rows are
    /// written first, then EVAL rows. Wire multiplicities are consumed from the bus in insertion
    /// order: two per READ row (`m_0`, `m_1`) and one per EVAL row (`m_0`).
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - No EVAL row has been recorded.
    /// - `out` does not hold enough rows starting at `offset`.
    pub fn fill(&self, offset: usize, out: &mut [Felt]) {
        const W: usize = ACE_CHIPLET_NUM_COLS;
        let (out_rows, _) = out.as_chunks_mut::<W>();
        let num_read_rows = self.read_nodes.len();
        let num_eval_rows = self.eval_nodes.len();
        let ctx_felt: Felt = self.ctx.into();
        let clk_felt: Felt = self.clk.into();
        // `num_eval_rows - 1` is stored in every READ row's `num_eval` column. Because wire ids
        // count down, this is also the id the first EVAL output receives when all declared rows
        // are used. An empty eval section would underflow here, so the invariant is checked
        // explicitly rather than wrapping silently in release builds.
        let eval_section_first_idx = Felt::from_u32(
            (num_eval_rows as u32)
                .checked_sub(1)
                .expect("an ACE circuit must have at least one evaluation row"),
        );
        let mut multiplicities_iter = self.wire_bus.wires.iter().map(|(_v, m)| Felt::from_u32(*m));

        for (i, node) in self.read_nodes.iter().enumerate() {
            let cols: &mut AceCols<Felt> = out_rows[offset + i].as_mut_slice().borrow_mut();
            // Only the very first row of the circuit carries the start selector.
            cols.s_start = if i == 0 { Felt::ONE } else { Felt::ZERO };
            cols.s_block = Felt::ZERO;
            cols.ctx = ctx_felt;
            cols.clk = clk_felt;
            cols.ptr = node.ptr;
            cols.id_0 = node.id_0;
            cols.v_0 = quad_to_expr(node.v_0);
            cols.id_1 = node.id_1;
            cols.v_1 = quad_to_expr(node.v_1);
            let m_0 = multiplicities_iter
                .next()
                .expect("the m0 multiplicities were not constructed properly");
            let m_1 = multiplicities_iter
                .next()
                .expect("the m1 multiplicities were not constructed properly");
            let read = cols.read_mut();
            read.num_eval = eval_section_first_idx;
            read.m_0 = m_0;
            read.m_1 = m_1;
        }

        for (i, node) in self.eval_nodes.iter().enumerate() {
            let cols: &mut AceCols<Felt> =
                out_rows[offset + num_read_rows + i].as_mut_slice().borrow_mut();
            cols.s_start = Felt::ZERO;
            cols.s_block = Felt::ONE;
            cols.ctx = ctx_felt;
            cols.clk = clk_felt;
            cols.ptr = node.ptr;
            cols.eval_op = node.eval_op;
            cols.id_0 = node.id_0;
            cols.v_0 = quad_to_expr(node.v_0);
            cols.id_1 = node.id_1;
            cols.v_1 = quad_to_expr(node.v_1);
            let m_0 = multiplicities_iter
                .next()
                .expect("the m0 multiplicities were not constructed properly");
            let eval = cols.eval_mut();
            eval.id_2 = node.id_2;
            eval.v_2 = quad_to_expr(node.v_2);
            eval.m_0 = m_0;
        }

        // Every wire on the bus must have been matched with exactly one row column.
        debug_assert!(multiplicities_iter.next().is_none());
    }

    /// Returns the value of the last wire inserted once the bus holds all declared wires.
    ///
    /// Returns `None` while the bus is not yet finalized.
    pub fn output_value(&self) -> Option<QuadFelt> {
        if !self.wire_bus.is_finalized() {
            return None;
        }
        self.wire_bus.wires.last().map(|(v, _m)| *v)
    }
}
/// Converts an extension-field element into the two-coordinate expression stored in trace
/// columns.
///
/// Basis-coefficient order is preserved: coefficient 0 becomes the first coordinate.
fn quad_to_expr(v: QuadFelt) -> QuadFeltExpr<Felt> {
    let coeffs = v.as_basis_coefficients_slice();
    let (c0, c1) = (coeffs[0], coeffs[1]);
    QuadFeltExpr(c0, c1)
}
/// Bookkeeping for circuit wires: each wire's value and how many times it has been read.
///
/// Wires are stored in insertion order, while their ids count down from `num_wires - 1`. The
/// wire with id `k` therefore lives at index `num_wires - 1 - k` of `wires`.
#[derive(Clone, Debug)]
struct WireBus {
    /// Id the next inserted wire will receive.
    id_next: Felt,
    /// `(value, multiplicity)` for each wire, in insertion order.
    wires: Vec<(QuadFelt, u32)>,
    /// Number of wires the bus was sized for; the bus is finalized once it holds this many.
    num_wires: u32,
}
impl WireBus {
    /// Creates an empty bus sized for `num_wires` wires.
    ///
    /// Ids are handed out in decreasing order, starting at `num_wires - 1`.
    fn new(num_wires: u32) -> Self {
        Self {
            wires: Vec::with_capacity(num_wires as usize),
            num_wires,
            id_next: Felt::from_u32(num_wires - 1),
        }
    }

    /// Appends `value` as a new wire with multiplicity 0 and returns the id assigned to it.
    fn insert(&mut self, value: QuadFelt) -> Felt {
        debug_assert!(!self.is_finalized());
        self.wires.push((value, 0));
        let id = self.id_next;
        self.id_next -= Felt::ONE;
        id
    }

    /// Returns the value of wire `id` and increments that wire's multiplicity.
    ///
    /// Returns `None` in either of these cases:
    /// - `id` is outside `0..num_wires`.
    /// - The wire with that id has not been inserted yet.
    fn read_value(&mut self, id: u32) -> Option<QuadFelt> {
        // Wire `id` sits at index `num_wires - 1 - id`. Subtracting in two checked steps avoids
        // the overflow that `id + 1` would hit for `id == u32::MAX`.
        let index = self.num_wires.checked_sub(id)?.checked_sub(1)?;
        let (v, m) = self.wires.get_mut(index as usize)?;
        *m += 1;
        Some(*v)
    }

    /// Returns `true` once every declared wire has been inserted.
    fn is_finalized(&self) -> bool {
        self.wires.len() == self.num_wires as usize
    }
}