Skip to main content

vyre_reference/execution/
mod.rs

1//! Generic reference interpreter entry points.
2//!
3//! The stable statement-IR [`reference_eval`] entry point remains delegated to
4//! the existing invocation simulator until `Program` stores graph nodes
5//! directly.
6
7pub(crate) mod call;
8pub mod expr;
9pub(crate) mod expr_cast;
10pub(crate) mod hashmap;
11pub mod node;
12pub mod sequential;
13pub(crate) mod typed_ops;
14
15use std::borrow::Cow;
16
17use rustc_hash::FxHashMap;
18use vyre::ir::{InterpCtx, Node, NodeId, NodeStorage, Program, Value as IrValue};
19
20use crate::value::Value;
21
22/// If the program satisfies the public top-level-Region model, return a
23/// byte-identical clone. If not, the usual case is
24/// `optimizer::passes::cleanup::region_inline_engine` having flattened a Category-A wrapper;
25/// in that case [`Program::reconcile_runnable_top_level`] matches
26/// `Program::wrapped` again. When the first entry node is a `Store` (or the
27/// entry is empty), we do **not** auto-wrap: those programs must still use
28/// `Program::wrapped` explicitly, matching `region_gate` negative tests.
29pub(crate) fn program_for_interpreter(program: &Program) -> Result<Cow<'_, Program>, vyre::Error> {
30    if let Some(message) = program.top_level_region_violation() {
31        if program.entry().is_empty() {
32            return Err(vyre::Error::interp(format!(
33                "reference interpreter requires a top-level Region-wrapped Program: {message}"
34            )));
35        }
36        if matches!(program.entry().first(), Some(Node::Store { .. })) {
37            return Err(vyre::Error::interp(format!(
38                "reference interpreter requires a top-level Region-wrapped Program: {message}"
39            )));
40        }
41        return Ok(Cow::Owned(program.clone().reconcile_runnable_top_level()));
42    }
43    Ok(Cow::Borrowed(program))
44}
45
46/// Execute a vyre IR program on the pure Rust reference interpreter.
47///
48/// The current public [`Program`] model is statement-oriented, so this stable
49/// entry point delegates to the statement evaluator. Graph-shaped extension
50/// nodes use [`run_storage_graph`].
51pub fn reference_eval(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
52    run_arena_reference(program, inputs)
53}
54
55/// Execute using the statement-IR reference evaluator.
56pub fn run_arena_reference(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
57    let program = program_for_interpreter(program)?;
58    hashmap::run_hashmap_reference(&program, inputs)
59}
60
/// Test-only differential oracle kept during the generic interpreter
/// transition.
///
/// It normalizes `program` with `program_for_interpreter` and runs it through
/// the hashmap-backed statement evaluator. This is the same path that
/// [`run_arena_reference`] takes.
#[cfg(test)]
pub fn eval_hashmap_reference(
    program: &Program,
    inputs: &[Value],
) -> Result<Vec<Value>, vyre::Error> {
    let runnable = program_for_interpreter(program)?;
    hashmap::run_hashmap_reference(&runnable, inputs)
}
69
70/// Interpret a compact [`NodeStorage`] graph and return output node values.
71pub fn run_storage_graph(
72    nodes: &[(NodeId, NodeStorage)],
73    outputs: &[NodeId],
74) -> Result<Vec<IrValue>, vyre::Error> {
75    let graph = nodes
76        .iter()
77        .map(|(id, node)| (*id, node))
78        .collect::<FxHashMap<_, _>>();
79    let mut ctx = InterpCtx::default();
80    let mut states = FxHashMap::with_capacity_and_hasher(graph.len(), Default::default());
81
82    for output in outputs {
83        eval_storage_node(*output, &graph, &mut ctx, &mut states)?;
84    }
85
86    outputs
87        .iter()
88        .map(|id| ctx.get(*id).map_err(interp_error))
89        .collect()
90}
91
92#[derive(Clone, Copy, Debug, PartialEq, Eq)]
93enum VisitState {
94    Visiting,
95    Done,
96}
97
98fn eval_storage_node(
99    id: NodeId,
100    graph: &FxHashMap<NodeId, &NodeStorage>,
101    ctx: &mut InterpCtx,
102    states: &mut FxHashMap<NodeId, VisitState>,
103) -> Result<(), vyre::Error> {
104    match states.get(&id).copied() {
105        Some(VisitState::Done) => return Ok(()),
106        Some(VisitState::Visiting) => return Err(cycle_error(id)),
107        None => {}
108    }
109
110    let node = *graph.get(&id).ok_or_else(|| missing_node_error(id))?;
111    states.insert(id, VisitState::Visiting);
112    let inputs = node.input_ids();
113    for input in &inputs {
114        eval_storage_node(*input, graph, ctx, states)?;
115    }
116    ctx.set_operands(inputs);
117    let value = node.interpret(ctx).map_err(interp_error)?;
118    ctx.set(id, value);
119    states.insert(id, VisitState::Done);
120    Ok(())
121}
122
123fn interp_error(error: vyre::ir::EvalError) -> vyre::Error {
124    vyre::Error::interp(error.to_string())
125}
126
127fn missing_node_error(id: NodeId) -> vyre::Error {
128    vyre::Error::interp(format!(
129        "graph references missing node {}. Fix: include every dependency in the interpreter input graph.",
130        id.0
131    ))
132}
133
134fn cycle_error(id: NodeId) -> vyre::Error {
135    vyre::Error::interp(format!(
136        "graph contains a dependency cycle at node {}. Fix: submit an acyclic dataflow graph.",
137        id.0
138    ))
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use vyre::ir::{BinOp, NodeStorage};

    #[test]
    fn generic_storage_graph_matches_recursive_oracle_for_10k_programs() {
        let mut rng = 0x9e37_79b9_u64;
        for case in 0..10_000 {
            let graph = random_graph(&mut rng, case);
            let output = graph.last().expect("Fix: generated graph is non-empty").0;
            let expected =
                recursive_value(output, &graph).expect("Fix: recursive oracle evaluates");
            let actual = run_storage_graph(&graph, &[output])
                .expect("Fix: generic graph interpreter evaluates")[0];
            assert_eq!(actual, expected, "case {case}");
        }
    }

    /// Outputs come back in request order; repeated outputs and shared
    /// dependencies are each answered from the single shared context.
    #[test]
    fn storage_graph_returns_values_in_requested_output_order() {
        let graph = vec![
            (NodeId(0), NodeStorage::LitU32(3)),
            (NodeId(1), NodeStorage::LitU32(4)),
            (
                NodeId(2),
                NodeStorage::BinOp {
                    op: BinOp::Mul,
                    left: NodeId(0),
                    right: NodeId(1),
                },
            ),
        ];
        let values = run_storage_graph(&graph, &[NodeId(2), NodeId(0), NodeId(2)])
            .expect("Fix: acyclic graph with all dependencies evaluates");
        assert_eq!(
            values,
            vec![IrValue::U32(12), IrValue::U32(3), IrValue::U32(12)]
        );
    }

    /// A two-node dependency cycle must be reported as an error.
    #[test]
    fn storage_graph_rejects_dependency_cycles() {
        let graph = vec![
            (
                NodeId(0),
                NodeStorage::BinOp {
                    op: BinOp::Add,
                    left: NodeId(1),
                    right: NodeId(1),
                },
            ),
            (
                NodeId(1),
                NodeStorage::BinOp {
                    op: BinOp::Add,
                    left: NodeId(0),
                    right: NodeId(0),
                },
            ),
        ];
        assert!(
            run_storage_graph(&graph, &[NodeId(0)]).is_err(),
            "cyclic graph must be rejected"
        );
    }

    /// A node whose input id is absent from the graph must be reported as an
    /// error.
    #[test]
    fn storage_graph_rejects_missing_dependencies() {
        let graph = vec![
            (NodeId(0), NodeStorage::LitU32(7)),
            (
                NodeId(1),
                NodeStorage::BinOp {
                    op: BinOp::Add,
                    left: NodeId(0),
                    right: NodeId(99),
                },
            ),
        ];
        assert!(
            run_storage_graph(&graph, &[NodeId(1)]).is_err(),
            "dangling input id must be rejected"
        );
        assert!(
            run_storage_graph(&[], &[NodeId(0)]).is_err(),
            "missing output node must be rejected"
        );
    }

    fn random_graph(rng: &mut u64, case: u32) -> Vec<(NodeId, NodeStorage)> {
        let len = 2 + (next(rng) as usize % 31);
        let mut graph = Vec::with_capacity(len);
        graph.push((NodeId(0), NodeStorage::LitU32(case)));
        graph.push((NodeId(1), NodeStorage::LitU32(next(rng))));
        for index in 2..len {
            let left = NodeId(next(rng) % index as u32);
            let right = NodeId(next(rng) % index as u32);
            let op = match next(rng) % 5 {
                0 => BinOp::Add,
                1 => BinOp::Sub,
                2 => BinOp::Mul,
                3 => BinOp::BitXor,
                _ => BinOp::BitAnd,
            };
            graph.push((NodeId(index as u32), NodeStorage::BinOp { op, left, right }));
        }
        graph
    }

    fn recursive_value(
        id: NodeId,
        graph: &[(NodeId, NodeStorage)],
    ) -> Result<IrValue, vyre::Error> {
        let node = graph
            .iter()
            .find(|(node_id, _)| *node_id == id)
            .map(|(_, node)| node)
            .ok_or_else(|| missing_node_error(id))?;
        match node {
            NodeStorage::LitU32(value) => Ok(IrValue::U32(*value)),
            NodeStorage::BinOp { op, left, right } => {
                let left = expect_u32(recursive_value(*left, graph)?)?;
                let right = expect_u32(recursive_value(*right, graph)?)?;
                let value = match op {
                    BinOp::Add => left.wrapping_add(right),
                    BinOp::Sub => left.wrapping_sub(right),
                    BinOp::Mul => left.wrapping_mul(right),
                    BinOp::BitXor => left ^ right,
                    BinOp::BitAnd => left & right,
                    _ => {
                        return Err(vyre::Error::interp(
                            "recursive parity oracle received unsupported op. Fix: keep test generation within the oracle domain.",
                        ));
                    }
                };
                Ok(IrValue::U32(value))
            }
            _ => Err(vyre::Error::interp(
                "recursive parity oracle received unsupported node. Fix: keep test generation within the oracle domain.",
            )),
        }
    }

    fn expect_u32(value: IrValue) -> Result<u32, vyre::Error> {
        match value {
            IrValue::U32(value) => Ok(value),
            other => Err(vyre::Error::interp(format!(
                "recursive parity oracle expected u32, got {other:?}. Fix: keep generated graphs scalar-u32 only."
            ))),
        }
    }

    fn next(rng: &mut u64) -> u32 {
        *rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
        (*rng >> 32) as u32
    }
}