// miden-processor 0.24.0
//
// Miden VM processor
// Documentation
use alloc::vec::Vec;
use core::borrow::BorrowMut;

use miden_air::{
    AceCols, QuadFeltExpr,
    trace::{RowIndex, chiplets::ace::ACE_CHIPLET_NUM_COLS},
};
use miden_core::{
    Felt, Word,
    field::{BasedVectorSpace, QuadFelt},
};

use super::{
    MAX_NUM_ACE_WIRES,
    instruction::{Op, decode_instruction},
};
use crate::{ContextId, errors::AceError};

/// One row of the ACE chiplet trace in `READ` mode: two memory-loaded wires per row, plus the
/// pointer of the word that was loaded.
///
/// Rows are created by `CircuitEvaluation::do_read` and written to the trace by
/// `CircuitEvaluation::fill`.
#[derive(Debug, Clone, Copy)]
struct ReadNode {
    /// Memory pointer of the word both wires were loaded from.
    ptr: Felt,
    /// Id assigned by the wire bus to the first wire.
    id_0: Felt,
    /// Value of the first wire, built from elements 0 and 1 of the loaded word.
    v_0: QuadFelt,
    /// Id assigned by the wire bus to the second wire.
    id_1: Felt,
    /// Value of the second wire, built from elements 2 and 3 of the loaded word.
    v_1: QuadFelt,
}

/// One row of the ACE chiplet trace in `EVAL` mode: a single arithmetic gate `(id_0, v_0)` with
/// two inputs `(id_1, v_1)` (left) and `(id_2, v_2)` (right), the instruction pointer that
/// produced it, and the gate's `eval_op` selector.
///
/// Rows are created by `CircuitEvaluation::do_eval` and written to the trace by
/// `CircuitEvaluation::fill`.
#[derive(Debug, Clone, Copy)]
struct EvalNode {
    /// Pointer of the instruction that produced this gate.
    ptr: Felt,
    /// Operation selector: `-1` for subtraction, `0` for multiplication, `1` for addition.
    eval_op: Felt,
    /// Id assigned by the wire bus to the gate's output wire.
    id_0: Felt,
    /// Value of the output wire, i.e. the result of applying the operation to `v_1` and `v_2`.
    v_0: QuadFelt,
    /// Id of the left input wire, as decoded from the instruction.
    id_1: Felt,
    /// Value read from the wire bus for the left input.
    v_1: QuadFelt,
    /// Id of the right input wire, as decoded from the instruction.
    id_2: Felt,
    /// Value read from the wire bus for the right input.
    v_2: QuadFelt,
}

/// Contains the variable and evaluation nodes resulting from the evaluation of a circuit.
/// The output value is checked to be equal to 0.
///
/// The set of nodes is used to fill the ACE chiplet trace.
#[derive(Debug, Clone)]
pub struct CircuitEvaluation {
    /// Execution context, written to the `ctx` column of every row of this evaluation.
    ctx: ContextId,
    /// Clock value, written to the `clk` column of every row of this evaluation.
    clk: RowIndex,
    /// Values and multiplicities of all wires inserted so far, in insertion order.
    wire_bus: WireBus,
    /// `READ` rows recorded so far, in the order they were read.
    read_nodes: Vec<ReadNode>,
    /// `EVAL` rows recorded so far, in the order they were evaluated.
    eval_nodes: Vec<EvalNode>,
}

impl CircuitEvaluation {
    /// Creates an empty evaluation sized for `num_read_rows` `READ` rows and `num_eval_rows`
    /// `EVAL` rows.
    ///
    /// Every `READ` row contributes two wires and every `EVAL` row contributes one, so the wire
    /// bus is sized for `2 * num_read_rows + num_eval_rows` wires. The rows themselves are
    /// appended afterwards with [`Self::do_read`] and [`Self::do_eval`].
    ///
    /// # Panics
    /// This function panics if the number of rows for each section leads to more than
    /// [`MAX_NUM_ACE_WIRES`] wires.
    pub fn new(ctx: ContextId, clk: RowIndex, num_read_rows: u32, num_eval_rows: u32) -> Self {
        // Widen to `u64` so the wire count cannot overflow before it is bounds-checked.
        let num_wires = 2 * (num_read_rows as u64) + (num_eval_rows as u64);
        assert!(num_wires <= MAX_NUM_ACE_WIRES as u64, "too many wires");

        Self {
            ctx,
            clk,
            wire_bus: WireBus::new(num_wires as u32),
            read_nodes: Vec::with_capacity(num_read_rows as usize),
            eval_nodes: Vec::with_capacity(num_eval_rows as usize),
        }
    }

    /// Returns the total number of rows (`READ` plus `EVAL`) recorded so far.
    pub fn num_rows(&self) -> usize {
        self.read_nodes.len() + self.eval_nodes.len()
    }

    /// Returns the clock value this evaluation was created with, as a `u32`.
    pub fn clk(&self) -> u32 {
        self.clk.into()
    }

    /// Returns the execution context this evaluation was created with, as a `u32`.
    pub fn ctx(&self) -> u32 {
        self.ctx.into()
    }

    /// Returns the number of `READ` rows recorded so far.
    pub fn num_read_rows(&self) -> u32 {
        self.read_nodes.len() as u32
    }

    /// Returns the number of `EVAL` rows recorded so far.
    pub fn num_eval_rows(&self) -> u32 {
        self.eval_nodes.len() as u32
    }

    /// Reads the word from memory at `ptr`, interpreting it as `[v_00, v_01, v_10, v_11]`, and
    /// adds wires with values `v_0 = QuadFelt(v_00, v_01)` and `v_1 = QuadFelt(v_10, v_11)`.
    ///
    /// The two wires are inserted into the wire bus in that order and a `READ` row is recorded.
    pub fn do_read(&mut self, ptr: Felt, word: Word) {
        let v_0 = QuadFelt::from_basis_coefficients_fn(|i: usize| [word[0], word[1]][i]);
        let id_0 = self.wire_bus.insert(v_0);

        let v_1 = QuadFelt::from_basis_coefficients_fn(|i: usize| [word[2], word[3]][i]);
        let id_1 = self.wire_bus.insert(v_1);

        self.read_nodes.push(ReadNode { ptr, id_0, v_0, id_1, v_1 });
    }

    /// Reads the next instruction at `ptr`, requests the inputs from the wire bus
    /// and inserts a new wire with the result.
    ///
    /// The instruction is decoded into a left input id, a right input id and an operation. Both
    /// input values are read from the wire bus (incrementing their multiplicities), the
    /// operation is applied, and the result is inserted as a new wire. An `EVAL` row describing
    /// the gate is then recorded.
    ///
    /// # Errors
    /// Returns an [`AceError`] if the instruction cannot be decoded, or if either input id
    /// refers to a wire that is not available on the bus.
    pub fn do_eval(&mut self, ptr: Felt, instruction: Felt) -> Result<(), AceError> {
        // Errors are built lazily with `ok_or_else` so the success path, taken once per gate,
        // does not construct and immediately drop an error value.
        let (id_l, id_r, op) = decode_instruction(instruction)
            .ok_or_else(|| AceError("failed to decode instruction".into()))?;

        let v_l = self
            .wire_bus
            .read_value(id_l)
            .ok_or_else(|| AceError("failed to read from the wiring bus".into()))?;
        let id_1 = Felt::from_u32(id_l);

        let v_r = self
            .wire_bus
            .read_value(id_r)
            .ok_or_else(|| AceError("failed to read from the wiring bus".into()))?;
        let id_2 = Felt::from_u32(id_r);

        // Gate output together with the selector recorded in the trace
        // (`-1` = subtraction, `0` = multiplication, `1` = addition).
        let (v_0, eval_op) = match op {
            Op::Sub => (v_l - v_r, -Felt::ONE),
            Op::Mul => (v_l * v_r, Felt::ZERO),
            Op::Add => (v_l + v_r, Felt::ONE),
        };
        let id_0 = self.wire_bus.insert(v_0);

        self.eval_nodes.push(EvalNode {
            ptr,
            eval_op,
            id_0,
            v_0,
            id_1,
            v_1: v_l,
            id_2,
            v_2: v_r,
        });
        Ok(())
    }

    /// Writes this circuit evaluation's rows into the row-major buffer `out`
    /// (`ACE_CHIPLET_NUM_COLS` contiguous cells per row), starting at row `offset`. `out`
    /// is assumed zero-initialized, so columns that are zero on a row are left untouched.
    ///
    /// `READ` rows are written first, followed by the `EVAL` rows. Wire multiplicities are
    /// consumed in insertion order: two per `READ` row and one per `EVAL` row. Any trailing
    /// cells of `out` that do not form a complete row are ignored.
    ///
    /// # Panics
    /// Panics if `out` holds fewer than `offset + self.num_rows()` complete rows, or if the wire
    /// bus holds fewer multiplicities than the recorded rows require. At least one `EVAL` row
    /// must have been recorded: the `READ` rows store `num_eval_rows - 1`, and that subtraction
    /// underflows otherwise (a panic when overflow checks are enabled).
    pub fn fill(&self, offset: usize, out: &mut [Felt]) {
        const W: usize = ACE_CHIPLET_NUM_COLS;
        let (out_rows, _) = out.as_chunks_mut::<W>();
        let num_read_rows = self.read_nodes.len();
        let num_eval_rows = self.eval_nodes.len();

        let ctx_felt: Felt = self.ctx.into();
        let clk_felt: Felt = self.clk.into();
        // Ids are handed out in decreasing order, so once every declared row has been recorded
        // this is the id of the first wire inserted by the `EVAL` section.
        let eval_section_first_idx = Felt::from_u32(num_eval_rows as u32 - 1);
        let mut multiplicities_iter = self.wire_bus.wires.iter().map(|(_v, m)| Felt::from_u32(*m));

        // READ rows.
        for (i, node) in self.read_nodes.iter().enumerate() {
            let cols: &mut AceCols<Felt> = out_rows[offset + i].as_mut_slice().borrow_mut();
            // Only the very first row of the evaluation carries the start flag.
            cols.s_start = if i == 0 { Felt::ONE } else { Felt::ZERO };
            cols.s_block = Felt::ZERO;
            cols.ctx = ctx_felt;
            cols.clk = clk_felt;
            cols.ptr = node.ptr;
            cols.id_0 = node.id_0;
            cols.v_0 = quad_to_expr(node.v_0);
            cols.id_1 = node.id_1;
            cols.v_1 = quad_to_expr(node.v_1);

            // Each READ row inserted two wires, so it consumes two multiplicities.
            let m_0 = multiplicities_iter
                .next()
                .expect("the m0 multiplicities were not constructed properly");
            let m_1 = multiplicities_iter
                .next()
                .expect("the m1 multiplicities were not constructed properly");

            let read = cols.read_mut();
            read.num_eval = eval_section_first_idx;
            read.m_0 = m_0;
            read.m_1 = m_1;
        }

        // EVAL rows.
        for (i, node) in self.eval_nodes.iter().enumerate() {
            let cols: &mut AceCols<Felt> =
                out_rows[offset + num_read_rows + i].as_mut_slice().borrow_mut();
            cols.s_start = Felt::ZERO;
            cols.s_block = Felt::ONE;
            cols.ctx = ctx_felt;
            cols.clk = clk_felt;
            cols.ptr = node.ptr;
            cols.eval_op = node.eval_op;
            cols.id_0 = node.id_0;
            cols.v_0 = quad_to_expr(node.v_0);
            cols.id_1 = node.id_1;
            cols.v_1 = quad_to_expr(node.v_1);

            // Each EVAL row inserted a single (output) wire, so it consumes one multiplicity.
            let m_0 = multiplicities_iter
                .next()
                .expect("the m0 multiplicities were not constructed properly");

            let eval = cols.eval_mut();
            eval.id_2 = node.id_2;
            eval.v_2 = quad_to_expr(node.v_2);
            eval.m_0 = m_0;
        }

        // Every inserted wire must belong to exactly one of the rows written above.
        debug_assert!(multiplicities_iter.next().is_none());
    }

    /// Returns the output value, if the circuit has finished evaluating.
    ///
    /// The output is the value of the last wire inserted; `None` is returned while the wire bus
    /// has not yet received the number of wires it was sized for.
    pub fn output_value(&self) -> Option<QuadFelt> {
        if !self.wire_bus.is_finalized() {
            return None;
        }
        self.wire_bus.wires.last().map(|(v, _m)| *v)
    }
}

/// Converts a `QuadFelt` into the [`QuadFeltExpr<Felt>`] pair used by the chiplet column
/// structs, taking the first two basis coefficients of `v` in order.
fn quad_to_expr(v: QuadFelt) -> QuadFeltExpr<Felt> {
    let coeffs = v.as_basis_coefficients_slice();
    let (c0, c1) = (coeffs[0], coeffs[1]);
    QuadFeltExpr(c0, c1)
}

/// A LogUp-based bus used for wiring the gates of the circuit.
///
/// Gates are fan-in 2 but can have fan-out up to the field characteristic which, given the bounds
/// on the execution trace length, means practically arbitrary fan-out.
/// The main idea, with some slight variations between the `READ` and `EVAL` sections, is, for each
/// gate, to "receive" the values of the input wires from the bus and to "send" the value of the
/// output wire back with multiplicity equal to the fan-out of the respective gate.
/// Note that the messages include extra data in order to avoid collisions.
#[derive(Debug, Clone)]
struct WireBus {
    // Circuit id, as a `Felt`, of the next wire to be inserted. Ids are handed out in decreasing
    // order, one per insertion.
    id_next: Felt,
    // Pairs of wire values and multiplicities, in insertion order. A wire's multiplicity is the
    // number of times its value has been read.
    // The wire with index `id` is stored at `num_wires - 1 - id`.
    wires: Vec<(QuadFelt, u32)>,
    // Total expected number of wires to be inserted.
    num_wires: u32,
}

impl WireBus {
    /// Creates an empty bus that expects exactly `num_wires` insertions.
    ///
    /// The first inserted wire receives id `num_wires - 1`, and each later insertion receives
    /// the next smaller id.
    fn new(num_wires: u32) -> Self {
        Self {
            wires: Vec::with_capacity(num_wires as usize),
            num_wires,
            // An empty bus (`num_wires == 0`) accepts no insertions, so its `id_next` is never
            // handed out; `saturating_sub` only prevents `0 - 1` from underflowing.
            id_next: Felt::from_u32(num_wires.saturating_sub(1)),
        }
    }

    /// Inserts a new value into the bus, and returns its expected id as `Felt`
    ///
    /// The wire starts with multiplicity 0. In debug builds, inserting into a bus that already
    /// holds `num_wires` wires triggers an assertion failure.
    fn insert(&mut self, value: QuadFelt) -> Felt {
        debug_assert!(!self.is_finalized());
        self.wires.push((value, 0));
        let id = self.id_next;
        self.id_next -= Felt::ONE;
        id
    }

    /// Reads the value of a wire with given `id`, incrementing its multiplicity.
    /// Returns `None` if the requested wire has not been inserted yet, or if `id` is not smaller
    /// than the number of wires the bus was created for.
    fn read_value(&mut self, id: u32) -> Option<QuadFelt> {
        // The wire with id `id` lives at index `num_wires - 1 - id`. Both steps are checked so
        // that any `id >= num_wires` (including `u32::MAX`) yields `None` instead of overflowing.
        let index = id.checked_add(1).and_then(|n| self.num_wires.checked_sub(n))?;
        // A valid index may still point past the wires inserted so far.
        let (v, m) = self.wires.get_mut(index as usize)?;
        *m += 1;
        Some(*v)
    }

    /// Return true if the expected number of wires have been inserted.
    fn is_finalized(&self) -> bool {
        self.wires.len() == self.num_wires as usize
    }
}