use super::types::*;
use crate::category::core::*;
use crate::definition::Def;
use crate::ssa::parallel_ssa;
use open_hypergraphs::lax::NodeId;
use std::collections::HashMap;
pub fn eval<I: Interpreter>(interpreter: &I, term: Term, values: Vec<Value<I>>) -> ResultValues<I> {
eval_with(interpreter, term, values, |_, _| ())
}
/// Evaluates `term` on the input `values`, reporting every value written to a node.
///
/// The inputs are bound to `term.sources` in order. The operations produced by
/// `parallel_ssa(term.to_strict())` are then run one after another: each operation removes its
/// argument values from the state, is executed, and its results are stored under its target
/// nodes. Finally the values of `term.targets` are removed from the state and returned in
/// target order.
///
/// `on_write` is called with the node id and a reference to the value immediately before the
/// value is stored, both for the initial inputs and for every operation result. It is called
/// before the duplicate-write check, so it also sees a value that triggers `MultipleWrite`.
///
/// # Panics
/// Panics if `values.len()` differs from `term.sources.len()`.
///
/// # Errors
/// * `InterpreterError::MultipleRead` if an operation argument or a final target has no value
///   in the state. Reading removes the value, so a node can be consumed only once.
/// * `InterpreterError::MultipleWrite` if an operation result targets a node that already
///   holds a value.
/// * Any error from `parallel_ssa`, `Interpreter::handle_definition` or `apply_op`.
///
/// Results are paired with targets by `zip`, so surplus results or surplus targets are
/// silently ignored rather than reported.
pub fn eval_with<I: Interpreter, F: FnMut(NodeId, &Value<I>)>(
    interpreter: &I,
    term: Term,
    values: Vec<Value<I>>,
    mut on_write: F,
) -> ResultValues<I> {
    assert_eq!(values.len(), term.sources.len());

    // Values currently available for reading, keyed by the node that holds them.
    let mut live: HashMap<NodeId, Value<I>> = HashMap::new();
    for (&source, input) in term.sources.iter().zip(values) {
        on_write(source, &input);
        live.insert(source, input);
    }

    // `to_strict` takes `term` by value, so keep the target list for the final read-out.
    let outputs = term.targets.clone();

    // Each schedule entry is a group of operations; groups and their members run in order.
    for ssa in parallel_ssa(term.to_strict())?.into_iter().flatten() {
        // Take ownership of every argument; stops at the first node without a value.
        let args = ssa
            .sources
            .iter()
            .map(|(node_id, _)| {
                live.remove(node_id)
                    .ok_or(InterpreterError::MultipleRead(*node_id))
            })
            .collect::<Result<Vec<_>>>()?;

        let produced = match &ssa.op {
            Def::Def(path) => interpreter.handle_definition(&ssa, args, path)?,
            Def::Arr(op) => apply_op(interpreter, &ssa, args, op)?,
        };

        for (&(written, _), value) in ssa.targets.iter().zip(produced) {
            on_write(written, &value);
            if live.insert(written, value).is_some() {
                return Err(InterpreterError::MultipleWrite(written));
            }
        }
    }

    // Move the target values out of the state, in target order.
    outputs
        .iter()
        .map(|target| {
            live.remove(target)
                .ok_or(InterpreterError::MultipleRead(*target))
        })
        .collect()
}
/// Executes a single primitive operation `op` for the SSA step `ssa`.
///
/// Dispatch by operation kind:
/// * `Type` and `Nat` operations are handled by `apply_type_op` and `apply_nat_op`.
/// * `DtypeConstant` yields one `Value::Dtype` built by `I::dtype_constant`; `args` is unused.
/// * `Tensor` operations are forwarded to `interpreter.tensor_op`.
/// * `Copy` is handled by `apply_copy`.
/// * `Load` asks `interpreter.handle_load` for the values at `path`; `args` is unused.
///
/// # Errors
/// Returns `InterpreterError::Load` carrying `ssa.edge_id` and the path when `handle_load`
/// returns `None`; otherwise propagates whatever the selected handler returns.
fn apply_op<I: Interpreter>(
    interpreter: &I,
    ssa: &CoreSSA,
    args: Vec<Value<I>>,
    op: &Operation,
) -> ResultValues<I> {
    match op {
        Operation::Type(type_op) => apply_type_op(ssa, args, type_op),
        Operation::Nat(nat_op) => apply_nat_op(ssa, args, nat_op),
        Operation::DtypeConstant(dtype) => Ok(vec![Value::Dtype(I::dtype_constant(dtype.clone()))]),
        Operation::Tensor(tensor_op) => interpreter.tensor_op(ssa, args, tensor_op),
        Operation::Copy => apply_copy(ssa, args),
        Operation::Load(path) => interpreter
            .handle_load(ssa, path)
            // Build the error lazily so the path is cloned only when the load fails.
            .ok_or_else(|| InterpreterError::Load(ssa.edge_id, path.clone())),
    }
}
/// Duplicates the single argument of a copy operation, one value per target of `ssa`.
///
/// The original value is moved into the first slot and every further slot receives a clone of
/// it. The returned vector always contains the original value, so it has length 1 even when
/// `ssa` has no targets.
///
/// # Errors
/// Propagates the error of `get_exact_arity`, which is asked for exactly one argument.
fn apply_copy<V: Interpreter>(ssa: &CoreSSA, args: Vec<Value<V>>) -> Result<Vec<Value<V>>> {
    let [original] = get_exact_arity(ssa, args)?;
    let fan_out = ssa.targets.len();

    let mut copies = Vec::with_capacity(fan_out);
    copies.push(original);
    // Fill the remaining slots from the first one, which avoids cloning the original once more.
    while copies.len() < fan_out {
        let duplicate = copies[0].clone();
        copies.push(duplicate);
    }
    Ok(copies)
}
use super::util::{get_exact_arity, to_nat, to_shape, to_tensor};
/// Executes a type-level operation `type_op` for the SSA step `ssa`.
///
/// * `Pack`: converts every argument with `to_nat` and combines them into one `Value::Shape`
///   via `V::pack`.
/// * `Unpack`: takes one argument, converts it with `to_shape`, and returns the dimensions
///   given by `V::unpack`, each as a `Value::Nat`.
/// * `Shape`: takes one argument, converts it with `to_tensor`, and returns its `V::shape` as
///   a single `Value::Shape`.
/// * `Dtype`: takes one argument, converts it with `to_tensor`, and returns its `V::dtype` as
///   a single `Value::Dtype`.
///
/// # Errors
/// Returns `InterpreterError::TypeError` carrying `ssa.edge_id` when `V::unpack`, `V::shape`
/// or `V::dtype` returns `None`. Errors from `get_exact_arity`, `to_nat`, `to_shape` and
/// `to_tensor` are propagated unchanged.
fn apply_type_op<V: Interpreter>(
    ssa: &CoreSSA,
    args: Vec<Value<V>>,
    type_op: &TypeOp,
) -> ResultValues<V> {
    // Shared error for the interpreter queries that may decline to answer.
    let type_error = || InterpreterError::TypeError(ssa.edge_id);

    match type_op {
        TypeOp::Pack => {
            let dims = args
                .into_iter()
                .map(|arg| to_nat(ssa, arg))
                .collect::<Result<Vec<V::Nat>>>()?;
            Ok(vec![Value::Shape(V::pack(dims))])
        }
        TypeOp::Unpack => {
            let [operand] = get_exact_arity(ssa, args)?;
            let shape = to_shape(ssa, operand)?;
            let dims = V::unpack(shape).ok_or_else(type_error)?;
            Ok(dims.into_iter().map(Value::Nat).collect())
        }
        TypeOp::Shape => {
            let [operand] = get_exact_arity(ssa, args)?;
            let tensor = to_tensor(ssa, operand)?;
            let shape = V::shape(tensor).ok_or_else(type_error)?;
            Ok(vec![Value::Shape(shape)])
        }
        TypeOp::Dtype => {
            let [operand] = get_exact_arity(ssa, args)?;
            let tensor = to_tensor(ssa, operand)?;
            let dtype = V::dtype(tensor).ok_or_else(type_error)?;
            Ok(vec![Value::Dtype(dtype)])
        }
    }
}
/// Executes a natural-number operation `op` for the SSA step `ssa`.
///
/// All arguments are converted with `to_nat` before the operation is inspected, so a
/// conversion failure is reported ahead of an arity failure.
///
/// * `Constant(n)`: takes no arguments and yields `I::nat_constant(*n)`.
/// * `Add`: takes two arguments and yields `I::nat_add(a, b)`.
/// * `Mul`: takes two arguments and yields `I::nat_mul(a, b)`.
///
/// The result is always a single `Value::Nat`.
///
/// # Errors
/// Propagates errors from `to_nat` and from `get_exact_arity`.
fn apply_nat_op<I: Interpreter>(ssa: &CoreSSA, args: Vec<Value<I>>, op: &NatOp) -> ResultValues<I> {
    let operands = args
        .into_iter()
        .map(|arg| to_nat(ssa, arg))
        .collect::<Result<Vec<I::Nat>>>()?;

    let nat = match op {
        NatOp::Constant(n) => {
            // The empty pattern enforces that a constant receives no arguments.
            let [] = get_exact_arity(ssa, operands)?;
            I::nat_constant(*n)
        }
        NatOp::Add => {
            let [lhs, rhs] = get_exact_arity(ssa, operands)?;
            I::nat_add(lhs, rhs)
        }
        NatOp::Mul => {
            let [lhs, rhs] = get_exact_arity(ssa, operands)?;
            I::nat_mul(lhs, rhs)
        }
    };
    Ok(vec![Value::Nat(nat)])
}