use std::collections::HashMap;
use ndarray::ArrayD;
use crate::broadcast::{
ResolvedArg, build_args_for_index, compute_shape, eval_broadcast, flat_to_multi, resolve_input,
results_to_array, total_elements,
};
use crate::compiler::ir::CompiledExpr;
use crate::error::EvalError;
use crate::eval::input::EvalInput;
use crate::eval::numeric::NumericResult;
use crate::eval::scalar;
pub struct EvalHandle {
expr: CompiledExpr,
resolved: Vec<ResolvedArg>,
shape: Vec<usize>,
axis_args: Vec<usize>,
}
impl std::fmt::Debug for EvalHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EvalHandle")
.field("shape", &self.shape)
.field("num_args", &self.resolved.len())
.finish()
}
}
impl EvalHandle {
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn len(&self) -> usize {
total_elements(&self.shape)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn scalar(self) -> Result<NumericResult, EvalError> {
if !self.shape.is_empty() {
return Err(EvalError::ShapeMismatch {
details: format!("expected scalar output but shape is {:?}", self.shape),
});
}
let args: Vec<NumericResult> = self.resolved.iter().map(|r| r.get(0)).collect();
scalar::eval_node(&self.expr.root, &args, &mut vec![])
}
pub fn to_array(self) -> Result<ArrayD<NumericResult>, EvalError> {
let (results, shape) = eval_broadcast(&self.expr, &self.resolved)?;
results_to_array(results, &shape)
}
pub fn iter(self) -> EvalIter {
let total = total_elements(&self.shape);
EvalIter {
expr: self.expr,
resolved: self.resolved,
shape: self.shape,
axis_args: self.axis_args,
current: 0,
total,
}
}
}
pub struct EvalIter {
expr: CompiledExpr,
resolved: Vec<ResolvedArg>,
shape: Vec<usize>,
axis_args: Vec<usize>,
current: usize,
total: usize,
}
impl Iterator for EvalIter {
type Item = Result<NumericResult, EvalError>;
fn next(&mut self) -> Option<Self::Item> {
if self.current >= self.total {
return None;
}
let multi = flat_to_multi(self.current, &self.shape);
let args = build_args_for_index(&self.resolved, &self.axis_args, &multi);
self.current += 1;
Some(scalar::eval_node(&self.expr.root, &args, &mut vec![]))
}
}
impl EvalIter {
pub fn remaining(&self) -> usize {
self.total - self.current
}
}
pub fn eval(
expr: &CompiledExpr,
mut args: HashMap<&str, EvalInput>,
) -> Result<EvalHandle, EvalError> {
let expected = expr.argument_names();
for name in args.keys() {
if !expected.iter().any(|e| e == name) {
return Err(EvalError::UnknownArgument {
name: name.to_string(),
});
}
}
let mut resolved = Vec::with_capacity(expected.len());
for name in expected {
match args.remove_entry(name.as_str()) {
Some((_, input)) => resolved.push(resolve_input(input)),
None => {
return Err(EvalError::MissingArgument { name: name.clone() });
}
}
}
let (shape, axis_args) = compute_shape(&resolved);
Ok(EvalHandle {
expr: expr.clone(),
resolved,
shape,
axis_args,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::compiler::ir::{BinaryOp, CompiledNode};
use approx::assert_abs_diff_eq;
fn make_expr(root: CompiledNode, arg_names: Vec<&str>) -> CompiledExpr {
CompiledExpr {
root,
argument_names: arg_names.into_iter().map(String::from).collect(),
is_complex: false,
}
}
fn x_sq_plus_y() -> CompiledExpr {
let x_sq = CompiledNode::Binary {
op: BinaryOp::Pow,
left: Box::new(CompiledNode::Argument(0)),
right: Box::new(CompiledNode::Literal(2.0)),
};
make_expr(
CompiledNode::Binary {
op: BinaryOp::Add,
left: Box::new(x_sq),
right: Box::new(CompiledNode::Argument(1)),
},
vec!["x", "y"],
)
}
#[test]
fn eval_scalar_result() {
let expr = x_sq_plus_y();
let mut args = HashMap::new();
args.insert("x", EvalInput::Scalar(3.0));
args.insert("y", EvalInput::Scalar(10.0));
let handle = eval(&expr, args).unwrap();
assert!(handle.shape().is_empty());
assert_eq!(handle.len(), 1);
let result = handle.scalar().unwrap();
assert_abs_diff_eq!(result.to_f64().unwrap(), 19.0, epsilon = 1e-10);
}
#[test]
fn eval_1d_array() {
let expr = make_expr(
CompiledNode::Binary {
op: BinaryOp::Pow,
left: Box::new(CompiledNode::Argument(0)),
right: Box::new(CompiledNode::Literal(2.0)),
},
vec!["x"],
);
let mut args = HashMap::new();
args.insert("x", EvalInput::from(vec![1.0, 2.0, 3.0]));
let handle = eval(&expr, args).unwrap();
assert_eq!(handle.shape(), &[3]);
let arr = handle.to_array().unwrap();
assert_eq!(arr.shape(), &[3]);
assert_eq!(*arr.get([0]).unwrap(), NumericResult::Real(1.0));
assert_eq!(*arr.get([1]).unwrap(), NumericResult::Real(4.0));
assert_eq!(*arr.get([2]).unwrap(), NumericResult::Real(9.0));
}
#[test]
fn eval_2d_cartesian() {
let expr = x_sq_plus_y();
let mut args = HashMap::new();
args.insert("x", EvalInput::from(vec![1.0, 2.0, 3.0]));
args.insert("y", EvalInput::from(vec![10.0, 20.0]));
let handle = eval(&expr, args).unwrap();
assert_eq!(handle.shape(), &[3, 2]);
assert_eq!(handle.len(), 6);
let arr = handle.to_array().unwrap();
assert_eq!(*arr.get([0, 0]).unwrap(), NumericResult::Real(11.0));
assert_eq!(*arr.get([0, 1]).unwrap(), NumericResult::Real(21.0));
assert_eq!(*arr.get([2, 1]).unwrap(), NumericResult::Real(29.0));
}
#[test]
fn eval_iter_matches_to_array() {
let expr = x_sq_plus_y();
let mut args1 = HashMap::new();
args1.insert("x", EvalInput::from(vec![1.0, 2.0]));
args1.insert("y", EvalInput::from(vec![10.0, 20.0]));
let mut args2 = HashMap::new();
args2.insert("x", EvalInput::from(vec![1.0, 2.0]));
args2.insert("y", EvalInput::from(vec![10.0, 20.0]));
let handle1 = eval(&expr, args1).unwrap();
let handle2 = eval(&expr, args2).unwrap();
let arr_results: Vec<NumericResult> = handle1.to_array().unwrap().iter().copied().collect();
let iter_results: Vec<NumericResult> = handle2.iter().map(|r| r.unwrap()).collect();
assert_eq!(arr_results, iter_results);
}
#[test]
fn eval_unknown_argument_error() {
let expr = make_expr(CompiledNode::Argument(0), vec!["x"]);
let mut args = HashMap::new();
args.insert("x", EvalInput::Scalar(1.0));
args.insert("z", EvalInput::Scalar(2.0));
let err = eval(&expr, args).unwrap_err();
assert!(matches!(err, EvalError::UnknownArgument { .. }));
}
#[test]
fn eval_missing_argument_error() {
let expr = x_sq_plus_y();
let mut args = HashMap::new();
args.insert("x", EvalInput::Scalar(1.0));
let err = eval(&expr, args).unwrap_err();
assert!(matches!(err, EvalError::MissingArgument { .. }));
}
#[test]
fn eval_scalar_on_nonscalar_errors() {
let expr = make_expr(CompiledNode::Argument(0), vec!["x"]);
let mut args = HashMap::new();
args.insert("x", EvalInput::from(vec![1.0, 2.0]));
let handle = eval(&expr, args).unwrap();
let err = handle.scalar().unwrap_err();
assert!(matches!(err, EvalError::ShapeMismatch { .. }));
}
#[test]
fn eval_no_args_expression() {
let expr = make_expr(CompiledNode::Literal(42.0), vec![]);
let args = HashMap::new();
let handle = eval(&expr, args).unwrap();
let result = handle.scalar().unwrap();
assert_eq!(result, NumericResult::Real(42.0));
}
#[test]
fn eval_iter_remaining() {
let expr = make_expr(CompiledNode::Argument(0), vec!["x"]);
let mut args = HashMap::new();
args.insert("x", EvalInput::from(vec![1.0, 2.0, 3.0]));
let handle = eval(&expr, args).unwrap();
let mut iter = handle.iter();
assert_eq!(iter.remaining(), 3);
iter.next();
assert_eq!(iter.remaining(), 2);
}
#[test]
fn eval_with_iterator_input() {
let expr = make_expr(
CompiledNode::Binary {
op: BinaryOp::Mul,
left: Box::new(CompiledNode::Argument(0)),
right: Box::new(CompiledNode::Literal(2.0)),
},
vec!["x"],
);
let mut args = HashMap::new();
args.insert(
"x",
EvalInput::Iter(Box::new(vec![1.0, 2.0, 3.0].into_iter())),
);
let handle = eval(&expr, args).unwrap();
let arr = handle.to_array().unwrap();
assert_eq!(arr.shape(), &[3]);
}
#[test]
fn eval_empty_array() {
let expr = make_expr(CompiledNode::Argument(0), vec!["x"]);
let mut args = HashMap::new();
args.insert("x", EvalInput::from(vec![] as Vec<f64>));
let handle = eval(&expr, args).unwrap();
assert!(handle.is_empty());
assert_eq!(handle.shape(), &[0]);
}
}