Skip to main content

miden_ace_codegen/
circuit.rs

1//! ACE circuit emission for the DAG IR.
2//!
3//! The emitted circuit is a flat list of inputs, constants, and arithmetic
4//! ops that matches the ACE chiplet execution model.
5
6use std::collections::HashMap;
7
8use miden_crypto::field::Field;
9
10use crate::{
11    AceError, InputLayout,
12    dag::{AceDag, NodeId, NodeKind},
13};
14
/// The binary arithmetic operations an emitted ACE circuit can contain.
///
/// Each operation combines a left-hand and a right-hand operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AceOp {
    /// Addition: `lhs + rhs`.
    Add,
    /// Subtraction: `lhs - rhs`.
    Sub,
    /// Multiplication: `lhs * rhs`.
    Mul,
}
22
/// Reference to a value inside an emitted ACE circuit.
///
/// Every variant wraps an index into one of the three value tables the
/// circuit works with: the caller-supplied inputs, the constant pool, or the
/// results of previously evaluated operations.
// Variant order matters: the derived `PartialOrd`/`Ord` compare by it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum AceNode {
    /// Index into the input vector passed to evaluation.
    Input(usize),
    /// Index into the circuit's constant pool.
    Constant(usize),
    /// Index of an operation; refers to that operation's result.
    Operation(usize),
}
30
31/// Operation node in the ACE circuit.
32#[derive(Debug, Clone, Copy, PartialEq, Eq)]
33pub(crate) struct AceOpNode {
34    pub op: AceOp,
35    pub lhs: AceNode,
36    pub rhs: AceNode,
37}
38
39/// Emitted ACE circuit with layout and operation list.
40///
41/// This is the off-VM representation used by tests and tools.
42#[derive(Debug, Clone)]
43pub struct AceCircuit<EF> {
44    pub(crate) layout: InputLayout,
45    pub(crate) constants: Vec<EF>,
46    pub(crate) operations: Vec<AceOpNode>,
47    pub(crate) root: AceNode,
48}
49
50impl<EF: Field> AceCircuit<EF> {
51    /// Return the input layout for this circuit.
52    pub fn layout(&self) -> &InputLayout {
53        &self.layout
54    }
55
56    /// Evaluate the circuit against the provided input vector.
57    pub fn eval(&self, inputs: &[EF]) -> Result<EF, AceError> {
58        if inputs.len() != self.layout.total_inputs {
59            return Err(AceError::InvalidInputLength {
60                expected: self.layout.total_inputs,
61                got: inputs.len(),
62            });
63        }
64        let mut op_values = vec![EF::ZERO; self.operations.len()];
65        for (idx, op) in self.operations.iter().enumerate() {
66            let lhs = self.node_value(op.lhs, inputs, &op_values);
67            let rhs = self.node_value(op.rhs, inputs, &op_values);
68            op_values[idx] = match op.op {
69                AceOp::Add => lhs + rhs,
70                AceOp::Sub => lhs - rhs,
71                AceOp::Mul => lhs * rhs,
72            };
73        }
74        Ok(self.node_value(self.root, inputs, &op_values))
75    }
76
77    /// Total number of nodes (inputs + constants + ops).
78    pub fn num_nodes(&self) -> usize {
79        self.layout.total_inputs + self.constants.len() + self.operations.len()
80    }
81
82    fn node_value(&self, node: AceNode, inputs: &[EF], op_values: &[EF]) -> EF {
83        match node {
84            AceNode::Input(index) => inputs[index],
85            AceNode::Constant(index) => self.constants[index],
86            AceNode::Operation(index) => op_values[index],
87        }
88    }
89}
90
91/// Emit an ACE circuit from the DAG and input layout.
92pub(crate) fn emit_circuit<EF>(
93    dag: &AceDag<EF>,
94    layout: InputLayout,
95) -> Result<AceCircuit<EF>, AceError>
96where
97    EF: Field,
98{
99    let mut constants = Vec::new();
100    let mut constant_map = HashMap::<EF, usize>::new();
101    let mut operations = Vec::new();
102    let mut node_map: Vec<Option<AceNode>> = vec![None; dag.nodes.len()];
103
104    for (idx, node) in dag.nodes.iter().enumerate() {
105        let ace_node = match node {
106            NodeKind::Input(key) => {
107                let input_idx = layout.index(*key).expect("input key must be present in layout");
108                AceNode::Input(input_idx)
109            },
110            NodeKind::Constant(value) => {
111                let const_idx = *constant_map.entry(*value).or_insert_with(|| {
112                    constants.push(*value);
113                    constants.len() - 1
114                });
115                AceNode::Constant(const_idx)
116            },
117            NodeKind::Add(a, b) => {
118                let lhs = lookup_node(&node_map, *a);
119                let rhs = lookup_node(&node_map, *b);
120                let op_idx = operations.len();
121                operations.push(AceOpNode { op: AceOp::Add, lhs, rhs });
122                AceNode::Operation(op_idx)
123            },
124            NodeKind::Sub(a, b) => {
125                let lhs = lookup_node(&node_map, *a);
126                let rhs = lookup_node(&node_map, *b);
127                let op_idx = operations.len();
128                operations.push(AceOpNode { op: AceOp::Sub, lhs, rhs });
129                AceNode::Operation(op_idx)
130            },
131            NodeKind::Mul(a, b) => {
132                let lhs = lookup_node(&node_map, *a);
133                let rhs = lookup_node(&node_map, *b);
134                let op_idx = operations.len();
135                operations.push(AceOpNode { op: AceOp::Mul, lhs, rhs });
136                AceNode::Operation(op_idx)
137            },
138            NodeKind::Neg(a) => {
139                let rhs = lookup_node(&node_map, *a);
140                let zero = *constant_map.entry(EF::ZERO).or_insert_with(|| {
141                    constants.push(EF::ZERO);
142                    constants.len() - 1
143                });
144                let op_idx = operations.len();
145                operations.push(AceOpNode {
146                    op: AceOp::Sub,
147                    lhs: AceNode::Constant(zero),
148                    rhs,
149                });
150                AceNode::Operation(op_idx)
151            },
152        };
153        node_map[idx] = Some(ace_node);
154    }
155
156    let root = lookup_node(&node_map, dag.root);
157    Ok(AceCircuit { layout, constants, operations, root })
158}
159
160fn lookup_node(map: &[Option<AceNode>], id: NodeId) -> AceNode {
161    map[id.index()].expect("ACE DAG nodes must be topologically ordered")
162}