pub(crate) mod call;
pub mod expr;
pub(crate) mod expr_cast;
pub(crate) mod hashmap;
pub mod node;
pub mod sequential;
pub(crate) mod typed_ops;
use std::borrow::Cow;
use rustc_hash::FxHashMap;
use vyre::ir::{InterpCtx, Node, NodeId, NodeStorage, Program, Value as IrValue};
use crate::value::Value;
/// Prepares `program` for the reference interpreter.
///
/// If `program.top_level_region_violation()` reports nothing, the program is
/// returned borrowed as-is. Otherwise, a violating program whose entry is empty
/// or starts with a `Node::Store` is rejected. Any other violating program is
/// cloned and passed through `reconcile_runnable_top_level()`, and the owned
/// result is returned.
///
/// # Errors
/// Returns an interpreter error embedding the violation message when the
/// program cannot be reconciled (empty entry, or a leading `Store`).
pub(crate) fn program_for_interpreter(program: &Program) -> Result<Cow<'_, Program>, vyre::Error> {
    let message = match program.top_level_region_violation() {
        Some(message) => message,
        None => return Ok(Cow::Borrowed(program)),
    };
    let entry = program.entry();
    // Both shapes are refused with the same diagnostic; only other violations
    // are handed to the reconciler.
    let cannot_reconcile =
        entry.is_empty() || matches!(entry.first(), Some(Node::Store { .. }));
    if cannot_reconcile {
        return Err(vyre::Error::interp(format!(
            "reference interpreter requires a top-level Region-wrapped Program: {message}"
        )));
    }
    Ok(Cow::Owned(program.clone().reconcile_runnable_top_level()))
}
pub fn reference_eval(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
run_arena_reference(program, inputs)
}
/// Normalizes `program` via `program_for_interpreter` and runs it through
/// `hashmap::run_hashmap_reference` with `inputs`.
///
/// # Errors
/// Returns the normalization error if the program is rejected; otherwise
/// whatever the hashmap interpreter returns.
pub fn run_arena_reference(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
    program_for_interpreter(program)
        .and_then(|runnable| hashmap::run_hashmap_reference(&runnable, inputs))
}
/// Test-only alias for the reference interpreter entry point.
///
/// # Errors
/// Same as `reference_eval`.
#[cfg(test)]
pub fn eval_hashmap_reference(
    program: &Program,
    inputs: &[Value],
) -> Result<Vec<Value>, vyre::Error> {
    reference_eval(program, inputs)
}
/// Evaluates the dataflow graph described by `nodes` and returns the values of
/// `outputs`, in the same order as `outputs`.
///
/// Nodes are indexed by their `NodeId`. If an id appears more than once, the
/// last entry wins. All outputs share one interpreter context and one visit
/// table, so a node reachable from several outputs is evaluated once.
///
/// # Errors
/// Returns an interpreter error for a missing dependency, a dependency cycle,
/// a node interpretation failure, or a failed context lookup of an output.
pub fn run_storage_graph(
    nodes: &[(NodeId, NodeStorage)],
    outputs: &[NodeId],
) -> Result<Vec<IrValue>, vyre::Error> {
    let mut by_id: FxHashMap<NodeId, &NodeStorage> =
        FxHashMap::with_capacity_and_hasher(nodes.len(), Default::default());
    by_id.extend(nodes.iter().map(|(node_id, storage)| (*node_id, storage)));

    let mut interp = InterpCtx::default();
    let mut visit = FxHashMap::with_capacity_and_hasher(by_id.len(), Default::default());
    outputs
        .iter()
        .try_for_each(|&target| eval_storage_node(target, &by_id, &mut interp, &mut visit))?;

    let mut results = Vec::with_capacity(outputs.len());
    for &target in outputs {
        results.push(interp.get(target).map_err(interp_error)?);
    }
    Ok(results)
}
/// Per-node progress marker used while walking the storage graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VisitState {
    /// Entered but not yet evaluated; reaching the node again in this state
    /// is reported as a dependency cycle.
    Visiting,
    /// Evaluated, with its value already stored in the interpreter context.
    Done,
}
/// Evaluates node `id` and, transitively, every node it depends on, storing
/// each produced value in `ctx`.
///
/// Traversal is a depth-first post-order walk driven by an explicit heap
/// stack rather than native recursion. The previous recursive version used
/// one call frame per level of dependency depth, so a long linear chain could
/// overflow the thread stack. Inputs are still visited left-to-right, and a
/// node is interpreted only after all of its inputs. The visiting order and
/// which error is reported first are therefore unchanged.
///
/// `states` may be shared across calls. Nodes already marked `Done` are
/// skipped. Reaching a node that is still `Visiting` means it is reachable
/// from its own inputs, which is reported as a cycle.
///
/// # Errors
/// - a missing-node error if `id` or any transitive input is absent from `graph`;
/// - a cycle error if the reachable subgraph is cyclic;
/// - the node's interpretation error, converted by `interp_error`.
fn eval_storage_node(
    id: NodeId,
    graph: &FxHashMap<NodeId, &NodeStorage>,
    ctx: &mut InterpCtx,
    states: &mut FxHashMap<NodeId, VisitState>,
) -> Result<(), vyre::Error> {
    // Each frame is a node id plus, once expanded, the node and its operand
    // list. An unexpanded frame (`None`) schedules the node's inputs. An
    // expanded frame (`Some`) is popped only after every input frame pushed
    // above it has completed, and the node is interpreted at that point.
    let mut stack = vec![(id, None)];
    // Scratch buffer, reused across nodes, used to push inputs in reverse.
    let mut pending_inputs = Vec::new();
    while let Some((current, expanded)) = stack.pop() {
        // The `None` arm comes first so the frame payload type is fixed by its
        // `push` before the `Some` arm calls methods on the payload.
        match expanded {
            None => {
                match states.get(&current).copied() {
                    Some(VisitState::Done) => continue,
                    // The expanded frame of a `Visiting` node is still below us
                    // on the stack, so we reached it again through its own
                    // inputs: a cycle.
                    Some(VisitState::Visiting) => return Err(cycle_error(current)),
                    None => {}
                }
                let node = *graph
                    .get(&current)
                    .ok_or_else(|| missing_node_error(current))?;
                states.insert(current, VisitState::Visiting);
                let inputs = node.input_ids();
                pending_inputs.clear();
                for input in &inputs {
                    pending_inputs.push(*input);
                }
                stack.push((current, Some((node, inputs))));
                // Reverse push so the first input is popped (evaluated) first.
                stack.extend(pending_inputs.iter().rev().map(|&input| (input, None)));
            }
            Some((node, inputs)) => {
                // Every input is `Done` here, so its value is already in `ctx`.
                ctx.set_operands(inputs);
                let value = node.interpret(ctx).map_err(interp_error)?;
                ctx.set(current, value);
                states.insert(current, VisitState::Done);
            }
        }
    }
    Ok(())
}
fn interp_error(error: vyre::ir::EvalError) -> vyre::Error {
vyre::Error::interp(error.to_string())
}
fn missing_node_error(id: NodeId) -> vyre::Error {
vyre::Error::interp(format!(
"graph references missing node {}. Fix: include every dependency in the interpreter input graph.",
id.0
))
}
fn cycle_error(id: NodeId) -> vyre::Error {
vyre::Error::interp(format!(
"graph contains a dependency cycle at node {}. Fix: submit an acyclic dataflow graph.",
id.0
))
}
#[cfg(test)]
mod tests {
    use super::*;
    use vyre::ir::{BinOp, NodeStorage};

    /// Differential test: for 10,000 pseudo-random acyclic u32 graphs,
    /// `run_storage_graph` must agree with the naive recursive oracle below.
    #[test]
    fn generic_storage_graph_matches_recursive_oracle_for_10k_programs() {
        let mut rng = 0x9e37_79b9_u64;
        for case in 0..10_000 {
            let graph = random_graph(&mut rng, case);
            let output = graph.last().expect("Fix: generated graph is non-empty").0;
            let expected =
                recursive_value(output, &graph).expect("Fix: recursive oracle evaluates");
            let actual = run_storage_graph(&graph, &[output])
                .expect("Fix: generic graph interpreter evaluates")[0];
            assert_eq!(actual, expected, "case {case}");
        }
    }

    /// A dangling input reference must be reported as an error, not a panic.
    #[test]
    fn storage_graph_reports_missing_input() {
        let graph = vec![(
            NodeId(0),
            NodeStorage::BinOp {
                op: BinOp::Add,
                left: NodeId(7),
                right: NodeId(7),
            },
        )];
        assert!(run_storage_graph(&graph, &[NodeId(0)]).is_err());
    }

    /// Mutually dependent nodes must be rejected rather than looping forever.
    #[test]
    fn storage_graph_reports_cycle() {
        let graph = vec![
            (
                NodeId(0),
                NodeStorage::BinOp {
                    op: BinOp::Add,
                    left: NodeId(1),
                    right: NodeId(1),
                },
            ),
            (
                NodeId(1),
                NodeStorage::BinOp {
                    op: BinOp::Add,
                    left: NodeId(0),
                    right: NodeId(0),
                },
            ),
        ];
        assert!(run_storage_graph(&graph, &[NodeId(0)]).is_err());
    }

    /// Requesting no outputs evaluates nothing and yields an empty result.
    #[test]
    fn storage_graph_with_no_outputs_is_empty() {
        let graph = vec![(NodeId(0), NodeStorage::LitU32(1))];
        let values = run_storage_graph(&graph, &[]).expect("Fix: empty output list evaluates");
        assert!(values.is_empty());
    }

    /// Generates an acyclic graph of 2..=32 nodes. Nodes 0 and 1 are u32
    /// literals. Every later node is a `BinOp` whose operands are drawn from
    /// strictly earlier ids, which guarantees acyclicity. The order of `next`
    /// calls determines the generated graphs, so keep it stable.
    fn random_graph(rng: &mut u64, case: u32) -> Vec<(NodeId, NodeStorage)> {
        let len = 2 + (next(rng) as usize % 31);
        let mut graph = Vec::with_capacity(len);
        graph.push((NodeId(0), NodeStorage::LitU32(case)));
        graph.push((NodeId(1), NodeStorage::LitU32(next(rng))));
        for index in 2..len {
            let left = NodeId(next(rng) % index as u32);
            let right = NodeId(next(rng) % index as u32);
            let op = match next(rng) % 5 {
                0 => BinOp::Add,
                1 => BinOp::Sub,
                2 => BinOp::Mul,
                3 => BinOp::BitXor,
                _ => BinOp::BitAnd,
            };
            graph.push((NodeId(index as u32), NodeStorage::BinOp { op, left, right }));
        }
        graph
    }

    /// Naive oracle: recursively evaluates `id` with a linear id lookup and no
    /// memoization, supporting only u32 literals and the five ops produced by
    /// `random_graph`, using wrapping arithmetic.
    fn recursive_value(
        id: NodeId,
        graph: &[(NodeId, NodeStorage)],
    ) -> Result<IrValue, vyre::Error> {
        let node = graph
            .iter()
            .find(|(node_id, _)| *node_id == id)
            .map(|(_, node)| node)
            .ok_or_else(|| missing_node_error(id))?;
        match node {
            NodeStorage::LitU32(value) => Ok(IrValue::U32(*value)),
            NodeStorage::BinOp { op, left, right } => {
                let left = expect_u32(recursive_value(*left, graph)?)?;
                let right = expect_u32(recursive_value(*right, graph)?)?;
                let value = match op {
                    BinOp::Add => left.wrapping_add(right),
                    BinOp::Sub => left.wrapping_sub(right),
                    BinOp::Mul => left.wrapping_mul(right),
                    BinOp::BitXor => left ^ right,
                    BinOp::BitAnd => left & right,
                    _ => {
                        return Err(vyre::Error::interp(
                            "recursive parity oracle received unsupported op. Fix: keep test generation within the oracle domain.",
                        ));
                    }
                };
                Ok(IrValue::U32(value))
            }
            _ => Err(vyre::Error::interp(
                "recursive parity oracle received unsupported node. Fix: keep test generation within the oracle domain.",
            )),
        }
    }

    /// Unwraps a u32 value, or returns an interpreter error for any other variant.
    fn expect_u32(value: IrValue) -> Result<u32, vyre::Error> {
        match value {
            IrValue::U32(value) => Ok(value),
            other => Err(vyre::Error::interp(format!(
                "recursive parity oracle expected u32, got {other:?}. Fix: keep generated graphs scalar-u32 only."
            ))),
        }
    }

    /// Linear congruential step (multiplier 6364136223846793005, increment 1)
    /// that returns the high 32 bits of the new state.
    fn next(rng: &mut u64) -> u32 {
        *rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
        (*rng >> 32) as u32
    }
}