Skip to main content

vyre_reference/execution/
mod.rs

//! Generic reference interpreter entry points.
//!
//! The stable statement-IR [`reference_eval`] entry point still delegates to
//! the existing invocation simulator until `Program` stores graph nodes
//! directly. Compact [`NodeStorage`] dataflow graphs are evaluated by
//! [`run_storage_graph`].
6
7pub(crate) mod call;
8pub mod expr;
9pub(crate) mod expr_cast;
10pub(crate) mod hashmap;
11pub mod node;
12pub(crate) mod node_tree;
13pub mod sequential;
14pub(crate) mod typed_ops;
15
16use std::borrow::Cow;
17
18use rustc_hash::FxHashMap;
19use vyre::ir::{InterpCtx, Node, NodeId, NodeStorage, Program, Value as IrValue};
20
21use crate::value::Value;
22
23/// If the program satisfies the public top-level-Region model, return a
24/// byte-identical clone. If not, the usual case is
25/// `optimizer::passes::cleanup::region_inline_engine` having flattened a Category-A wrapper;
26/// in that case [`Program::reconcile_runnable_top_level`] matches
27/// `Program::wrapped` again. When the first entry node is a `Store` (or the
28/// entry is empty), we do **not** auto-wrap: those programs must still use
29/// `Program::wrapped` explicitly, matching `region_gate` negative tests.
30pub(crate) fn program_for_interpreter(program: &Program) -> Result<Cow<'_, Program>, vyre::Error> {
31    let normalized = if let Some(message) = program.top_level_region_violation() {
32        if program.entry().is_empty() {
33            return Err(vyre::Error::interp(format!(
34                "reference interpreter requires a top-level Region-wrapped Program: {message}"
35            )));
36        }
37        if matches!(program.entry().first(), Some(Node::Store { .. })) {
38            return Err(vyre::Error::interp(format!(
39                "reference interpreter requires a top-level Region-wrapped Program: {message}"
40            )));
41        }
42        Cow::Owned(program.clone().reconcile_runnable_top_level())
43    } else {
44        Cow::Borrowed(program)
45    };
46    match vyre_foundation::transform::collectives::lower_single_rank_collectives(
47        normalized.as_ref(),
48    ) {
49        Ok(Some(lowered)) => Ok(Cow::Owned(lowered)),
50        Ok(None) => Ok(normalized),
51        Err(error) => Err(vyre::Error::interp(error.to_string())),
52    }
53}
54
55/// Execute a vyre IR program on the pure Rust reference interpreter.
56///
57/// The current public [`Program`] model is statement-oriented, so this stable
58/// entry point delegates to the statement evaluator. Graph-shaped extension
59/// nodes use [`run_storage_graph`].
60pub fn reference_eval(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
61    run_arena_reference(program, inputs)
62}
63
64/// Execute using the statement-IR reference evaluator.
65pub fn run_arena_reference(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
66    let program = program_for_interpreter(program)?;
67    hashmap::run_hashmap_reference(&program, inputs)
68}
69
/// Differential oracle retained for tests during the generic interpreter transition.
///
/// Normalizes `program` with [`program_for_interpreter`] and runs it on the
/// `hashmap` statement evaluator, the same path as [`run_arena_reference`].
#[cfg(test)]
pub fn eval_hashmap_reference(
    program: &Program,
    inputs: &[Value],
) -> Result<Vec<Value>, vyre::Error> {
    let runnable = program_for_interpreter(program)?;
    hashmap::run_hashmap_reference(&runnable, inputs)
}
78
79/// Interpret a compact [`NodeStorage`] graph and return output node values.
80pub fn run_storage_graph(
81    nodes: &[(NodeId, NodeStorage)],
82    outputs: &[NodeId],
83) -> Result<Vec<IrValue>, vyre::Error> {
84    let mut graph = FxHashMap::with_capacity_and_hasher(nodes.len(), Default::default());
85    for (id, node) in nodes {
86        if graph.insert(*id, node).is_some() {
87            return Err(duplicate_node_error(*id));
88        }
89    }
90    let mut ctx = InterpCtx::default();
91    let mut states = FxHashMap::with_capacity_and_hasher(graph.len(), Default::default());
92
93    for output in outputs {
94        eval_storage_node(*output, &graph, &mut ctx, &mut states)?;
95    }
96
97    outputs
98        .iter()
99        .map(|id| ctx.get(*id).map_err(interp_error))
100        .collect()
101}
102
/// Per-node progress marker used by the storage-graph dependency traversal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VisitState {
    /// The node's inputs are being evaluated; reaching it again signals a cycle.
    Visiting,
    /// The node's value has been written to the interpreter context.
    Done,
}
108
109fn eval_storage_node(
110    id: NodeId,
111    graph: &FxHashMap<NodeId, &NodeStorage>,
112    ctx: &mut InterpCtx,
113    states: &mut FxHashMap<NodeId, VisitState>,
114) -> Result<(), vyre::Error> {
115    match states.get(&id).copied() {
116        Some(VisitState::Done) => return Ok(()),
117        Some(VisitState::Visiting) => return Err(cycle_error(id)),
118        None => {}
119    }
120
121    let node = *graph.get(&id).ok_or_else(|| missing_node_error(id))?;
122    states.insert(id, VisitState::Visiting);
123    let inputs = node.input_ids();
124    for input in &inputs {
125        eval_storage_node(*input, graph, ctx, states)?;
126    }
127    ctx.set_operands(inputs);
128    let value = node.interpret(ctx).map_err(interp_error)?;
129    ctx.set(id, value);
130    states.insert(id, VisitState::Done);
131    Ok(())
132}
133
134fn interp_error(error: vyre::ir::EvalError) -> vyre::Error {
135    vyre::Error::interp(error.to_string())
136}
137
138fn missing_node_error(id: NodeId) -> vyre::Error {
139    vyre::Error::interp(format!(
140        "graph references missing node {}. Fix: include every dependency in the interpreter input graph.",
141        id.0
142    ))
143}
144
145fn cycle_error(id: NodeId) -> vyre::Error {
146    vyre::Error::interp(format!(
147        "graph contains a dependency cycle at node {}. Fix: submit an acyclic dataflow graph.",
148        id.0
149    ))
150}
151
152fn duplicate_node_error(id: NodeId) -> vyre::Error {
153    vyre::Error::interp(format!(
154        "graph contains duplicate node {}. Fix: submit exactly one storage record for each NodeId before reference execution.",
155        id.0
156    ))
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use vyre::ir::{BinOp, BufferAccess, BufferDecl, DataType, Expr, Node, NodeStorage};

    /// Multiplier of the 64-bit linear congruential generator behind `next`.
    const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;

    #[test]
    fn reference_eval_dispatches_singleton_atomic_flags_across_dynamic_byte_input() {
        // Constructors are built in the same order the nested literal would
        // evaluate them.
        let buffers = vec![
            BufferDecl::storage("bytes_in", 0, BufferAccess::ReadOnly, DataType::U8)
                .with_count(0),
            BufferDecl::storage("flag", 1, BufferAccess::ReadWrite, DataType::U32)
                .with_count(1),
        ];
        let bind_invocation = Node::let_bind("i", Expr::InvocationId { axis: 0 });
        let in_bounds = Expr::lt(Expr::var("i"), Expr::buf_len("bytes_in"));
        let byte_is_nonzero = Expr::ne(
            Expr::cast(DataType::U32, Expr::load("bytes_in", Expr::var("i"))),
            Expr::u32(0),
        );
        let set_flag = Node::let_bind(
            "flag_old",
            Expr::atomic_or("flag", Expr::u32(0), Expr::u32(1)),
        );
        let program = Program::wrapped(
            buffers,
            [256, 1, 1],
            vec![
                bind_invocation,
                Node::if_then(in_bounds, vec![Node::if_then(byte_is_nonzero, vec![set_flag])]),
            ],
        );

        // Only the last of 4097 bytes is non-zero, so the flag is set only if
        // invocation 4096 actually runs.
        let mut bytes = vec![0u8; 4096];
        bytes.push(1);

        let outputs = reference_eval(&program, &[Value::from(bytes), Value::from(vec![0u8; 4])])
            .expect("Fix: reference interpreter should execute singleton atomic flag scans.");
        let flag = outputs[0].to_bytes();
        let word = [flag[0], flag[1], flag[2], flag[3]];

        assert_eq!(u32::from_le_bytes(word), 1);
    }

    #[test]
    fn generic_storage_graph_matches_recursive_oracle_for_10k_programs() {
        let mut rng = 0x9e37_79b9_u64;
        for case in 0..10_000 {
            let graph = random_graph(&mut rng, case);
            let &(output, _) = graph.last().expect("Fix: generated graph is non-empty");
            let expected =
                recursive_value(output, &graph).expect("Fix: recursive oracle evaluates");
            let actual = run_storage_graph(&graph, &[output])
                .expect("Fix: generic graph interpreter evaluates")[0];
            assert_eq!(actual, expected, "case {case}");
        }
    }

    /// Build a random acyclic graph: two literal leaves followed by binary ops
    /// whose operands always point at strictly earlier nodes.
    fn random_graph(rng: &mut u64, case: u32) -> Vec<(NodeId, NodeStorage)> {
        let node_count = 2 + next(rng) as usize % 31;
        let mut graph = Vec::with_capacity(node_count);
        graph.push((NodeId(0), NodeStorage::LitU32(case)));
        graph.push((NodeId(1), NodeStorage::LitU32(next(rng))));
        for position in 2..node_count {
            let bound = position as u32;
            // Draw order (left, right, op) is fixed so generated graphs are stable.
            let left = NodeId(next(rng) % bound);
            let right = NodeId(next(rng) % bound);
            let op = pick_op(next(rng));
            graph.push((NodeId(bound), NodeStorage::BinOp { op, left, right }));
        }
        graph
    }

    /// Map a random roll onto one of the five operators the oracle understands.
    fn pick_op(roll: u32) -> BinOp {
        match roll % 5 {
            0 => BinOp::Add,
            1 => BinOp::Sub,
            2 => BinOp::Mul,
            3 => BinOp::BitXor,
            _ => BinOp::BitAnd,
        }
    }

    /// Naive recursive evaluator used as the parity oracle (no memoisation).
    fn recursive_value(
        id: NodeId,
        graph: &[(NodeId, NodeStorage)],
    ) -> Result<IrValue, vyre::Error> {
        let node = match graph.iter().find(|(candidate, _)| *candidate == id) {
            Some((_, storage)) => storage,
            None => return Err(missing_node_error(id)),
        };
        match node {
            NodeStorage::LitU32(literal) => Ok(IrValue::U32(*literal)),
            NodeStorage::BinOp { op, left, right } => {
                let lhs = expect_u32(recursive_value(*left, graph)?)?;
                let rhs = expect_u32(recursive_value(*right, graph)?)?;
                apply_oracle_op(op, lhs, rhs).map(IrValue::U32)
            }
            _ => Err(vyre::Error::interp(
                "recursive parity oracle received unsupported node. Fix: keep test generation within the oracle domain.",
            )),
        }
    }

    /// Apply `op` with wrapping `u32` semantics, rejecting operators outside the oracle domain.
    fn apply_oracle_op(op: &BinOp, lhs: u32, rhs: u32) -> Result<u32, vyre::Error> {
        let result = match op {
            BinOp::Add => lhs.wrapping_add(rhs),
            BinOp::Sub => lhs.wrapping_sub(rhs),
            BinOp::Mul => lhs.wrapping_mul(rhs),
            BinOp::BitXor => lhs ^ rhs,
            BinOp::BitAnd => lhs & rhs,
            _ => {
                return Err(vyre::Error::interp(
                    "recursive parity oracle received unsupported op. Fix: keep test generation within the oracle domain.",
                ));
            }
        };
        Ok(result)
    }

    /// Unwrap a scalar `u32` value or report the unexpected variant.
    fn expect_u32(value: IrValue) -> Result<u32, vyre::Error> {
        match value {
            IrValue::U32(scalar) => Ok(scalar),
            other => Err(vyre::Error::interp(format!(
                "recursive parity oracle expected u32, got {other:?}. Fix: keep generated graphs scalar-u32 only."
            ))),
        }
    }

    /// Advance the LCG state and return its high 32 bits.
    fn next(rng: &mut u64) -> u32 {
        let advanced = rng.wrapping_mul(LCG_MULTIPLIER).wrapping_add(1);
        *rng = advanced;
        (advanced >> 32) as u32
    }
}