1use std::collections::HashMap;
7
8use miden_crypto::field::Field;
9
10use crate::{
11 AceError, InputLayout,
12 dag::{AceDag, NodeId, NodeKind},
13};
14
/// Arithmetic operation performed by a single ACE gate.
///
/// Every gate is binary; negation is not a separate operation and is instead
/// expressed as `0 - x` when a circuit is emitted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum AceOp {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `lhs * rhs`
    Mul,
}
22
/// Reference to a value slot of an ACE circuit.
///
/// Each variant holds an index into its own table: the circuit inputs, the
/// constant pool, or the list of operations (naming that operation's result).
/// The derived ordering places every `Input` before every `Constant`, and every
/// `Constant` before every `Operation`; ties are broken by index.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub(crate) enum AceNode {
    /// Index into the circuit inputs.
    Input(usize),
    /// Index into the circuit constant pool.
    Constant(usize),
    /// Index into the circuit operation list.
    Operation(usize),
}
30
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
33pub(crate) struct AceOpNode {
34 pub op: AceOp,
35 pub lhs: AceNode,
36 pub rhs: AceNode,
37}
38
39#[derive(Debug, Clone)]
43pub struct AceCircuit<EF> {
44 pub(crate) layout: InputLayout,
45 pub(crate) constants: Vec<EF>,
46 pub(crate) operations: Vec<AceOpNode>,
47 pub(crate) root: AceNode,
48}
49
50impl<EF: Field> AceCircuit<EF> {
51 pub fn layout(&self) -> &InputLayout {
53 &self.layout
54 }
55
56 pub fn eval(&self, inputs: &[EF]) -> Result<EF, AceError> {
58 if inputs.len() != self.layout.total_inputs {
59 return Err(AceError::InvalidInputLength {
60 expected: self.layout.total_inputs,
61 got: inputs.len(),
62 });
63 }
64 let mut op_values = vec![EF::ZERO; self.operations.len()];
65 for (idx, op) in self.operations.iter().enumerate() {
66 let lhs = self.node_value(op.lhs, inputs, &op_values);
67 let rhs = self.node_value(op.rhs, inputs, &op_values);
68 op_values[idx] = match op.op {
69 AceOp::Add => lhs + rhs,
70 AceOp::Sub => lhs - rhs,
71 AceOp::Mul => lhs * rhs,
72 };
73 }
74 Ok(self.node_value(self.root, inputs, &op_values))
75 }
76
77 pub fn num_nodes(&self) -> usize {
79 self.layout.total_inputs + self.constants.len() + self.operations.len()
80 }
81
82 fn node_value(&self, node: AceNode, inputs: &[EF], op_values: &[EF]) -> EF {
83 match node {
84 AceNode::Input(index) => inputs[index],
85 AceNode::Constant(index) => self.constants[index],
86 AceNode::Operation(index) => op_values[index],
87 }
88 }
89}
90
91pub(crate) fn emit_circuit<EF>(
93 dag: &AceDag<EF>,
94 layout: InputLayout,
95) -> Result<AceCircuit<EF>, AceError>
96where
97 EF: Field,
98{
99 let mut constants = Vec::new();
100 let mut constant_map = HashMap::<EF, usize>::new();
101 let mut operations = Vec::new();
102 let mut node_map: Vec<Option<AceNode>> = vec![None; dag.nodes.len()];
103
104 for (idx, node) in dag.nodes.iter().enumerate() {
105 let ace_node = match node {
106 NodeKind::Input(key) => {
107 let input_idx = layout.index(*key).expect("input key must be present in layout");
108 AceNode::Input(input_idx)
109 },
110 NodeKind::Constant(value) => {
111 let const_idx = *constant_map.entry(*value).or_insert_with(|| {
112 constants.push(*value);
113 constants.len() - 1
114 });
115 AceNode::Constant(const_idx)
116 },
117 NodeKind::Add(a, b) => {
118 let lhs = lookup_node(&node_map, *a);
119 let rhs = lookup_node(&node_map, *b);
120 let op_idx = operations.len();
121 operations.push(AceOpNode { op: AceOp::Add, lhs, rhs });
122 AceNode::Operation(op_idx)
123 },
124 NodeKind::Sub(a, b) => {
125 let lhs = lookup_node(&node_map, *a);
126 let rhs = lookup_node(&node_map, *b);
127 let op_idx = operations.len();
128 operations.push(AceOpNode { op: AceOp::Sub, lhs, rhs });
129 AceNode::Operation(op_idx)
130 },
131 NodeKind::Mul(a, b) => {
132 let lhs = lookup_node(&node_map, *a);
133 let rhs = lookup_node(&node_map, *b);
134 let op_idx = operations.len();
135 operations.push(AceOpNode { op: AceOp::Mul, lhs, rhs });
136 AceNode::Operation(op_idx)
137 },
138 NodeKind::Neg(a) => {
139 let rhs = lookup_node(&node_map, *a);
140 let zero = *constant_map.entry(EF::ZERO).or_insert_with(|| {
141 constants.push(EF::ZERO);
142 constants.len() - 1
143 });
144 let op_idx = operations.len();
145 operations.push(AceOpNode {
146 op: AceOp::Sub,
147 lhs: AceNode::Constant(zero),
148 rhs,
149 });
150 AceNode::Operation(op_idx)
151 },
152 };
153 node_map[idx] = Some(ace_node);
154 }
155
156 let root = lookup_node(&node_map, dag.root);
157 Ok(AceCircuit { layout, constants, operations, root })
158}
159
160fn lookup_node(map: &[Option<AceNode>], id: NodeId) -> AceNode {
161 map[id.index()].expect("ACE DAG nodes must be topologically ordered")
162}