vyre_reference/execution/
mod.rs1pub(crate) mod call;
8pub mod expr;
9pub(crate) mod expr_cast;
10pub(crate) mod hashmap;
11pub mod node;
12pub(crate) mod node_tree;
13pub mod sequential;
14pub(crate) mod typed_ops;
15
16use std::borrow::Cow;
17
18use rustc_hash::FxHashMap;
19use vyre::ir::{InterpCtx, Node, NodeId, NodeStorage, Program, Value as IrValue};
20
21use crate::value::Value;
22
23pub(crate) fn program_for_interpreter(program: &Program) -> Result<Cow<'_, Program>, vyre::Error> {
31 let normalized = if let Some(message) = program.top_level_region_violation() {
32 if program.entry().is_empty() {
33 return Err(vyre::Error::interp(format!(
34 "reference interpreter requires a top-level Region-wrapped Program: {message}"
35 )));
36 }
37 if matches!(program.entry().first(), Some(Node::Store { .. })) {
38 return Err(vyre::Error::interp(format!(
39 "reference interpreter requires a top-level Region-wrapped Program: {message}"
40 )));
41 }
42 Cow::Owned(program.clone().reconcile_runnable_top_level())
43 } else {
44 Cow::Borrowed(program)
45 };
46 match vyre_foundation::transform::collectives::lower_single_rank_collectives(
47 normalized.as_ref(),
48 ) {
49 Ok(Some(lowered)) => Ok(Cow::Owned(lowered)),
50 Ok(None) => Ok(normalized),
51 Err(error) => Err(vyre::Error::interp(error.to_string())),
52 }
53}
54
55pub fn reference_eval(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
61 run_arena_reference(program, inputs)
62}
63
64pub fn run_arena_reference(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
66 let program = program_for_interpreter(program)?;
67 hashmap::run_hashmap_reference(&program, inputs)
68}
69
/// Test-only helper. It normalizes `program` and runs it through
/// `hashmap::run_hashmap_reference`, which is the same path as
/// [`run_arena_reference`].
#[cfg(test)]
pub fn eval_hashmap_reference(
    program: &Program,
    inputs: &[Value],
) -> Result<Vec<Value>, vyre::Error> {
    let runnable = program_for_interpreter(program)?;
    hashmap::run_hashmap_reference(&runnable, inputs)
}
78
79pub fn run_storage_graph(
81 nodes: &[(NodeId, NodeStorage)],
82 outputs: &[NodeId],
83) -> Result<Vec<IrValue>, vyre::Error> {
84 let mut graph = FxHashMap::with_capacity_and_hasher(nodes.len(), Default::default());
85 for (id, node) in nodes {
86 if graph.insert(*id, node).is_some() {
87 return Err(duplicate_node_error(*id));
88 }
89 }
90 let mut ctx = InterpCtx::default();
91 let mut states = FxHashMap::with_capacity_and_hasher(graph.len(), Default::default());
92
93 for output in outputs {
94 eval_storage_node(*output, &graph, &mut ctx, &mut states)?;
95 }
96
97 outputs
98 .iter()
99 .map(|id| ctx.get(*id).map_err(interp_error))
100 .collect()
101}
102
/// Per-node progress marker for the depth-first walk in `eval_storage_node`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VisitState {
    /// The node's dependencies are still being evaluated. Reaching the node
    /// again in this state means the graph has a cycle.
    Visiting,
    /// The node's value has already been written to the interpreter context.
    Done,
}
108
109fn eval_storage_node(
110 id: NodeId,
111 graph: &FxHashMap<NodeId, &NodeStorage>,
112 ctx: &mut InterpCtx,
113 states: &mut FxHashMap<NodeId, VisitState>,
114) -> Result<(), vyre::Error> {
115 match states.get(&id).copied() {
116 Some(VisitState::Done) => return Ok(()),
117 Some(VisitState::Visiting) => return Err(cycle_error(id)),
118 None => {}
119 }
120
121 let node = *graph.get(&id).ok_or_else(|| missing_node_error(id))?;
122 states.insert(id, VisitState::Visiting);
123 let inputs = node.input_ids();
124 for input in &inputs {
125 eval_storage_node(*input, graph, ctx, states)?;
126 }
127 ctx.set_operands(inputs);
128 let value = node.interpret(ctx).map_err(interp_error)?;
129 ctx.set(id, value);
130 states.insert(id, VisitState::Done);
131 Ok(())
132}
133
134fn interp_error(error: vyre::ir::EvalError) -> vyre::Error {
135 vyre::Error::interp(error.to_string())
136}
137
138fn missing_node_error(id: NodeId) -> vyre::Error {
139 vyre::Error::interp(format!(
140 "graph references missing node {}. Fix: include every dependency in the interpreter input graph.",
141 id.0
142 ))
143}
144
145fn cycle_error(id: NodeId) -> vyre::Error {
146 vyre::Error::interp(format!(
147 "graph contains a dependency cycle at node {}. Fix: submit an acyclic dataflow graph.",
148 id.0
149 ))
150}
151
152fn duplicate_node_error(id: NodeId) -> vyre::Error {
153 vyre::Error::interp(format!(
154 "graph contains duplicate node {}. Fix: submit exactly one storage record for each NodeId before reference execution.",
155 id.0
156 ))
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use vyre::ir::{BinOp, BufferAccess, BufferDecl, DataType, Expr, Node, NodeStorage};

    #[test]
    fn reference_eval_dispatches_singleton_atomic_flags_across_dynamic_byte_input() {
        let program = Program::wrapped(
            vec![
                BufferDecl::storage("bytes_in", 0, BufferAccess::ReadOnly, DataType::U8)
                    .with_count(0),
                BufferDecl::storage("flag", 1, BufferAccess::ReadWrite, DataType::U32)
                    .with_count(1),
            ],
            [256, 1, 1],
            vec![
                Node::let_bind("i", Expr::InvocationId { axis: 0 }),
                Node::if_then(
                    Expr::lt(Expr::var("i"), Expr::buf_len("bytes_in")),
                    vec![Node::if_then(
                        Expr::ne(
                            Expr::cast(DataType::U32, Expr::load("bytes_in", Expr::var("i"))),
                            Expr::u32(0),
                        ),
                        vec![Node::let_bind(
                            "flag_old",
                            Expr::atomic_or("flag", Expr::u32(0), Expr::u32(1)),
                        )],
                    )],
                ),
            ],
        );
        let mut bytes = vec![0u8; 4097];
        bytes[4096] = 1;

        let outputs = reference_eval(&program, &[Value::from(bytes), Value::from(vec![0u8; 4])])
            .expect("Fix: reference interpreter should execute singleton atomic flag scans.");
        let flag = outputs[0].to_bytes();

        assert_eq!(u32::from_le_bytes([flag[0], flag[1], flag[2], flag[3]]), 1);
    }

    #[test]
    fn generic_storage_graph_matches_recursive_oracle_for_10k_programs() {
        let mut rng = 0x9e37_79b9_u64;
        for case in 0..10_000 {
            let graph = random_graph(&mut rng, case);
            let output = graph.last().expect("Fix: generated graph is non-empty").0;
            let expected =
                recursive_value(output, &graph).expect("Fix: recursive oracle evaluates");
            let actual = run_storage_graph(&graph, &[output])
                .expect("Fix: generic graph interpreter evaluates")[0];
            assert_eq!(actual, expected, "case {case}");
        }
    }

    // A diamond (node 2 uses node 1 twice, node 1 uses node 0 twice) must
    // reuse shared results. Outputs must come back in request order.
    #[test]
    fn storage_graph_evaluates_shared_dependencies_and_preserves_output_order() {
        let graph = vec![
            (NodeId(0), NodeStorage::LitU32(3)),
            (
                NodeId(1),
                NodeStorage::BinOp {
                    op: BinOp::Add,
                    left: NodeId(0),
                    right: NodeId(0),
                },
            ),
            (
                NodeId(2),
                NodeStorage::BinOp {
                    op: BinOp::Mul,
                    left: NodeId(1),
                    right: NodeId(1),
                },
            ),
        ];
        let outputs = run_storage_graph(&graph, &[NodeId(2), NodeId(0)])
            .expect("Fix: diamond-shaped graph evaluates");
        assert_eq!(outputs, vec![IrValue::U32(36), IrValue::U32(3)]);
    }

    #[test]
    fn storage_graph_with_no_requested_outputs_returns_empty() {
        let graph = vec![(NodeId(0), NodeStorage::LitU32(7))];
        let outputs =
            run_storage_graph(&graph, &[]).expect("Fix: empty output list is not an error");
        assert!(outputs.is_empty());
    }

    #[test]
    fn storage_graph_rejects_duplicate_node_ids() {
        let graph = vec![
            (NodeId(0), NodeStorage::LitU32(1)),
            (NodeId(0), NodeStorage::LitU32(2)),
        ];
        assert!(run_storage_graph(&graph, &[NodeId(0)]).is_err());
    }

    #[test]
    fn storage_graph_rejects_missing_dependencies_and_outputs() {
        let graph = vec![(
            NodeId(0),
            NodeStorage::BinOp {
                op: BinOp::Add,
                left: NodeId(7),
                right: NodeId(7),
            },
        )];
        // Dependency 7 is absent.
        assert!(run_storage_graph(&graph, &[NodeId(0)]).is_err());
        // Requested output 9 is absent.
        assert!(run_storage_graph(&graph, &[NodeId(9)]).is_err());
    }

    #[test]
    fn storage_graph_rejects_dependency_cycles() {
        let two_cycle = vec![
            (
                NodeId(0),
                NodeStorage::BinOp {
                    op: BinOp::Add,
                    left: NodeId(1),
                    right: NodeId(1),
                },
            ),
            (
                NodeId(1),
                NodeStorage::BinOp {
                    op: BinOp::Add,
                    left: NodeId(0),
                    right: NodeId(0),
                },
            ),
        ];
        assert!(run_storage_graph(&two_cycle, &[NodeId(0)]).is_err());

        let self_loop = vec![(
            NodeId(0),
            NodeStorage::BinOp {
                op: BinOp::Add,
                left: NodeId(0),
                right: NodeId(0),
            },
        )];
        assert!(run_storage_graph(&self_loop, &[NodeId(0)]).is_err());
    }

    fn random_graph(rng: &mut u64, case: u32) -> Vec<(NodeId, NodeStorage)> {
        let len = 2 + (next(rng) as usize % 31);
        let mut graph = Vec::with_capacity(len);
        graph.push((NodeId(0), NodeStorage::LitU32(case)));
        graph.push((NodeId(1), NodeStorage::LitU32(next(rng))));
        for index in 2..len {
            let left = NodeId(next(rng) % index as u32);
            let right = NodeId(next(rng) % index as u32);
            let op = match next(rng) % 5 {
                0 => BinOp::Add,
                1 => BinOp::Sub,
                2 => BinOp::Mul,
                3 => BinOp::BitXor,
                _ => BinOp::BitAnd,
            };
            graph.push((NodeId(index as u32), NodeStorage::BinOp { op, left, right }));
        }
        graph
    }

    fn recursive_value(
        id: NodeId,
        graph: &[(NodeId, NodeStorage)],
    ) -> Result<IrValue, vyre::Error> {
        let node = graph
            .iter()
            .find(|(node_id, _)| *node_id == id)
            .map(|(_, node)| node)
            .ok_or_else(|| missing_node_error(id))?;
        match node {
            NodeStorage::LitU32(value) => Ok(IrValue::U32(*value)),
            NodeStorage::BinOp { op, left, right } => {
                let left = expect_u32(recursive_value(*left, graph)?)?;
                let right = expect_u32(recursive_value(*right, graph)?)?;
                let value = match op {
                    BinOp::Add => left.wrapping_add(right),
                    BinOp::Sub => left.wrapping_sub(right),
                    BinOp::Mul => left.wrapping_mul(right),
                    BinOp::BitXor => left ^ right,
                    BinOp::BitAnd => left & right,
                    _ => {
                        return Err(vyre::Error::interp(
                            "recursive parity oracle received unsupported op. Fix: keep test generation within the oracle domain.",
                        ));
                    }
                };
                Ok(IrValue::U32(value))
            }
            _ => Err(vyre::Error::interp(
                "recursive parity oracle received unsupported node. Fix: keep test generation within the oracle domain.",
            )),
        }
    }

    fn expect_u32(value: IrValue) -> Result<u32, vyre::Error> {
        match value {
            IrValue::U32(value) => Ok(value),
            other => Err(vyre::Error::interp(format!(
                "recursive parity oracle expected u32, got {other:?}. Fix: keep generated graphs scalar-u32 only."
            ))),
        }
    }

    fn next(rng: &mut u64) -> u32 {
        *rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
        (*rng >> 32) as u32
    }
}