use ndarray::ArrayD;
use crate::compiler::ir::CompiledExpr;
use crate::error::EvalError;
/// Output of [`eval_broadcast`]: one evaluation result per output element, in
/// row-major order (last axis varies fastest), paired with the broadcast shape.
pub(crate) type BroadcastResult = (Vec<Result<NumericResult, EvalError>>, Vec<usize>);
use crate::eval::input::EvalInput;
use crate::eval::numeric::NumericResult;
use crate::eval::scalar;
/// An evaluation argument after normalisation by [`resolve_input`].
pub(crate) enum ResolvedArg {
    /// A single value that is reused at every output point.
    Scalar(NumericResult),
    /// A flat sequence of values; contributes one axis to the broadcast shape.
    Array(Vec<NumericResult>),
}

impl ResolvedArg {
    /// Number of values held: always `1` for a scalar, the vector length otherwise.
    pub(crate) fn len(&self) -> usize {
        if let ResolvedArg::Array(values) = self {
            values.len()
        } else {
            1
        }
    }

    /// `true` for [`ResolvedArg::Scalar`], `false` for [`ResolvedArg::Array`].
    pub(crate) fn is_scalar(&self) -> bool {
        match self {
            ResolvedArg::Scalar(_) => true,
            ResolvedArg::Array(_) => false,
        }
    }

    /// Value at position `idx`.
    ///
    /// A scalar ignores `idx` and always yields its single value (broadcasting).
    ///
    /// # Panics
    ///
    /// Panics if `self` is an `Array` and `idx` is out of bounds.
    pub(crate) fn get(&self, idx: usize) -> NumericResult {
        match *self {
            ResolvedArg::Scalar(value) => value,
            ResolvedArg::Array(ref values) => values[idx],
        }
    }
}
pub(crate) fn resolve_input(input: EvalInput) -> ResolvedArg {
match input {
EvalInput::Scalar(v) => ResolvedArg::Scalar(NumericResult::Real(v)),
EvalInput::Complex(c) => ResolvedArg::Scalar(NumericResult::Complex(c)),
EvalInput::Array(arr) => {
ResolvedArg::Array(arr.iter().map(|v| NumericResult::Real(*v)).collect())
}
EvalInput::ComplexArray(arr) => {
ResolvedArg::Array(arr.iter().map(|v| NumericResult::Complex(*v)).collect())
}
EvalInput::Iter(iter) => ResolvedArg::Array(iter.map(NumericResult::Real).collect()),
EvalInput::ComplexIter(iter) => {
ResolvedArg::Array(iter.map(NumericResult::Complex).collect())
}
}
}
/// Derives the broadcast shape from the resolved arguments.
///
/// Every non-scalar argument contributes one axis, in argument order.
///
/// Returns `(shape, axis_args)`:
/// - `shape[k]` is the length of the k-th array argument.
/// - `axis_args[k]` is that argument's position in `args`.
///
/// Both vectors are empty when every argument is scalar, and `axis_args` is
/// strictly ascending.
pub(crate) fn compute_shape(args: &[ResolvedArg]) -> (Vec<usize>, Vec<usize>) {
    args.iter()
        .enumerate()
        .filter(|(_, arg)| !arg.is_scalar())
        .map(|(position, arg)| (arg.len(), position))
        .unzip()
}
/// Number of output elements produced by broadcasting over `shape`.
///
/// An empty shape (every argument scalar) yields `1`, the empty product. Any
/// zero-length axis yields `0`, regardless of the other extents.
///
/// # Panics
///
/// Panics if the product of the extents does not fit in a `usize`.
pub(crate) fn total_elements(shape: &[usize]) -> usize {
    // A zero-length axis empties the whole product. Check it first so that a
    // shape like `[usize::MAX, usize::MAX, 0]` yields 0 instead of overflowing
    // on the leading axes.
    if shape.contains(&0) {
        return 0;
    }
    // Checked multiplication: a plain `product()` wraps silently in release
    // builds. The evaluated result count would then disagree with `shape`
    // (e.g. four 65536-element axes multiply to 2^64 and wrap to 0 on 64-bit).
    shape
        .iter()
        .try_fold(1usize, |acc, &extent| acc.checked_mul(extent))
        .unwrap_or_else(|| panic!("broadcast shape {:?} has too many elements to index", shape))
}
/// Converts a flat output index into one coordinate per axis of `shape`.
///
/// The decomposition is row-major: the last axis varies fastest. An empty
/// `shape` yields an empty coordinate vector.
///
/// # Panics
///
/// Panics on division by zero if any extent in `shape` is `0`. Callers never
/// reach this, because a zero-length axis means there are no indices to convert.
pub(crate) fn flat_to_multi(flat: usize, shape: &[usize]) -> Vec<usize> {
    let mut remaining = flat;
    let mut coords = vec![0; shape.len()];
    // Peel coordinates off from the innermost (fastest-varying) axis outward.
    for (slot, &extent) in coords.iter_mut().zip(shape).rev() {
        *slot = remaining % extent;
        remaining /= extent;
    }
    coords
}
/// Assembles the argument vector for a single output point.
///
/// `axis_args` lists, in ascending order, the positions of the array
/// arguments (as produced by [`compute_shape`]). `multi_idx[k]` is the
/// coordinate along the axis owned by `axis_args[k]`. Every other argument is
/// read at index 0, which for a scalar is simply its value.
///
/// # Panics
///
/// Panics if `multi_idx` is shorter than the number of matched axis arguments,
/// or if a coordinate is out of range for its array.
pub(crate) fn build_args_for_index(
    resolved: &[ResolvedArg],
    axis_args: &[usize],
    multi_idx: &[usize],
) -> Vec<NumericResult> {
    // Index into `axis_args`/`multi_idx` of the next axis argument to match.
    let mut cursor = 0;
    resolved
        .iter()
        .enumerate()
        .map(|(position, arg)| {
            if axis_args.get(cursor) == Some(&position) {
                let coord = multi_idx[cursor];
                cursor += 1;
                arg.get(coord)
            } else {
                arg.get(0)
            }
        })
        .collect()
}
/// Evaluates `expr` at every point of the Cartesian product of the array arguments.
///
/// Each [`ResolvedArg::Array`] contributes one axis (in argument order), whose
/// length is that array's length. Scalars are reused at every point.
///
/// Returns the per-element results in row-major order (last axis fastest)
/// together with the broadcast shape. The shape is empty when all arguments
/// are scalar, in which case exactly one result is produced. A failure at one
/// element is recorded in its slot and does not stop the others.
///
/// # Errors
///
/// The current implementation always returns `Ok`. Per-element errors are
/// reported inside the result vector.
pub(crate) fn eval_broadcast(
    expr: &CompiledExpr,
    resolved: &[ResolvedArg],
) -> Result<BroadcastResult, EvalError> {
    let (shape, axis_args) = compute_shape(resolved);
    let element_count = total_elements(&shape);
    let per_element = eval_broadcast_inner(expr, &shape, &axis_args, element_count, resolved);
    Ok((per_element, shape))
}
#[cfg(feature = "parallel")]
/// Evaluates `expr` at each of the `total` flat indices using rayon.
///
/// Returns one result per flat index, in index order.
fn eval_broadcast_inner(
    expr: &CompiledExpr,
    shape: &[usize],
    axis_args: &[usize],
    total: usize,
    resolved: &[ResolvedArg],
) -> Vec<Result<NumericResult, EvalError>> {
    use rayon::prelude::*;

    // Output points are independent, so the flat index range can be split
    // across the thread pool. Collecting an indexed range keeps results in order.
    let evaluate_at = |flat_index: usize| {
        let coords = flat_to_multi(flat_index, shape);
        let point = build_args_for_index(resolved, axis_args, &coords);
        scalar::eval_node(&expr.root, &point, &mut vec![])
    };
    (0..total).into_par_iter().map(evaluate_at).collect()
}
#[cfg(not(feature = "parallel"))]
/// Evaluates `expr` at each of the `total` flat indices sequentially.
///
/// Returns one result per flat index, in index order.
fn eval_broadcast_inner(
    expr: &CompiledExpr,
    shape: &[usize],
    axis_args: &[usize],
    total: usize,
    resolved: &[ResolvedArg],
) -> Vec<Result<NumericResult, EvalError>> {
    // `0..total` has an exact length, so `collect` allocates `total` slots up front.
    (0..total)
        .map(|flat_index| {
            let coords = flat_to_multi(flat_index, shape);
            let point = build_args_for_index(resolved, axis_args, &coords);
            scalar::eval_node(&expr.root, &point, &mut vec![])
        })
        .collect()
}
/// Packs flat per-element results into an n-dimensional array of `shape`.
///
/// An empty `shape` produces a 0-dimensional array holding the single result.
///
/// # Errors
///
/// Returns the first per-element error encountered, in flat order.
///
/// # Panics
///
/// Panics if the number of results differs from the element count implied by
/// `shape`. That indicates a caller bug, not bad input.
pub(crate) fn results_to_array(
    results: Vec<Result<NumericResult, EvalError>>,
    shape: &[usize],
) -> Result<ArrayD<NumericResult>, EvalError> {
    let values = results
        .into_iter()
        .collect::<Result<Vec<NumericResult>, EvalError>>()?;
    let array = ArrayD::from_shape_vec(shape.to_vec(), values)
        .expect("shape mismatch in results_to_array");
    Ok(array)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::ir::{BinaryOp, CompiledNode};

    // Helper: wraps `root` in a `CompiledExpr` with `is_complex: false`.
    fn make_expr(root: CompiledNode, arg_names: Vec<&str>) -> CompiledExpr {
        CompiledExpr {
            root,
            argument_names: arg_names.into_iter().map(String::from).collect(),
            is_complex: false,
        }
    }

    // flat_to_multi: the last axis varies fastest (row-major decomposition).
    #[test]
    fn flat_to_multi_2d() {
        let shape = vec![3, 2];
        assert_eq!(flat_to_multi(0, &shape), vec![0, 0]);
        assert_eq!(flat_to_multi(1, &shape), vec![0, 1]);
        assert_eq!(flat_to_multi(2, &shape), vec![1, 0]);
        assert_eq!(flat_to_multi(5, &shape), vec![2, 1]);
    }

    #[test]
    fn flat_to_multi_3d() {
        let shape = vec![2, 3, 4];
        assert_eq!(flat_to_multi(0, &shape), vec![0, 0, 0]);
        // The last flat index of a 2x3x4 block maps to the last coordinate.
        assert_eq!(flat_to_multi(23, &shape), vec![1, 2, 3]);
    }

    // compute_shape: only array arguments contribute axes.
    #[test]
    fn compute_shape_all_scalars() {
        let args = vec![
            ResolvedArg::Scalar(NumericResult::Real(1.0)),
            ResolvedArg::Scalar(NumericResult::Real(2.0)),
        ];
        let (shape, axis_args) = compute_shape(&args);
        assert!(shape.is_empty());
        assert!(axis_args.is_empty());
    }

    #[test]
    fn compute_shape_one_array() {
        let args = vec![ResolvedArg::Array(vec![
            NumericResult::Real(1.0),
            NumericResult::Real(2.0),
            NumericResult::Real(3.0),
        ])];
        let (shape, axis_args) = compute_shape(&args);
        assert_eq!(shape, vec![3]);
        assert_eq!(axis_args, vec![0]);
    }

    #[test]
    fn compute_shape_mixed() {
        // The scalar at position 1 is skipped; axes come from positions 0 and 2.
        let args = vec![
            ResolvedArg::Array(vec![NumericResult::Real(1.0), NumericResult::Real(2.0)]),
            ResolvedArg::Scalar(NumericResult::Real(5.0)),
            ResolvedArg::Array(vec![
                NumericResult::Real(10.0),
                NumericResult::Real(20.0),
                NumericResult::Real(30.0),
            ]),
        ];
        let (shape, axis_args) = compute_shape(&args);
        assert_eq!(shape, vec![2, 3]);
        assert_eq!(axis_args, vec![0, 2]);
    }

    // eval_broadcast, end to end.
    #[test]
    fn broadcast_all_scalars() {
        // x + y with both scalar: one result and an empty (0-d) shape.
        let expr = make_expr(
            CompiledNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(CompiledNode::Argument(0)),
                right: Box::new(CompiledNode::Argument(1)),
            },
            vec!["x", "y"],
        );
        let resolved = vec![
            ResolvedArg::Scalar(NumericResult::Real(2.0)),
            ResolvedArg::Scalar(NumericResult::Real(3.0)),
        ];
        let (results, shape) = eval_broadcast(&expr, &resolved).unwrap();
        assert!(shape.is_empty());
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].as_ref().unwrap().to_f64().unwrap(), 5.0);
    }

    #[test]
    fn broadcast_one_array() {
        // x ^ 2 over a 3-element array.
        let expr = make_expr(
            CompiledNode::Binary {
                op: BinaryOp::Pow,
                left: Box::new(CompiledNode::Argument(0)),
                right: Box::new(CompiledNode::Literal(2.0)),
            },
            vec!["x"],
        );
        let resolved = vec![ResolvedArg::Array(vec![
            NumericResult::Real(1.0),
            NumericResult::Real(2.0),
            NumericResult::Real(3.0),
        ])];
        let (results, shape) = eval_broadcast(&expr, &resolved).unwrap();
        assert_eq!(shape, vec![3]);
        let vals: Vec<f64> = results
            .into_iter()
            .map(|r| r.unwrap().to_f64().unwrap())
            .collect();
        assert_eq!(vals, vec![1.0, 4.0, 9.0]);
    }

    #[test]
    fn broadcast_two_arrays_cartesian() {
        // x^2 + y with x of length 3 and y of length 2 gives a 3x2 outer
        // product, flattened with y (the last axis) varying fastest.
        let x_sq = CompiledNode::Binary {
            op: BinaryOp::Pow,
            left: Box::new(CompiledNode::Argument(0)),
            right: Box::new(CompiledNode::Literal(2.0)),
        };
        let expr = make_expr(
            CompiledNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(x_sq),
                right: Box::new(CompiledNode::Argument(1)),
            },
            vec!["x", "y"],
        );
        let resolved = vec![
            ResolvedArg::Array(vec![
                NumericResult::Real(1.0),
                NumericResult::Real(2.0),
                NumericResult::Real(3.0),
            ]),
            ResolvedArg::Array(vec![NumericResult::Real(10.0), NumericResult::Real(20.0)]),
        ];
        let (results, shape) = eval_broadcast(&expr, &resolved).unwrap();
        assert_eq!(shape, vec![3, 2]);
        let vals: Vec<f64> = results
            .into_iter()
            .map(|r| r.unwrap().to_f64().unwrap())
            .collect();
        assert_eq!(vals, vec![11.0, 21.0, 14.0, 24.0, 19.0, 29.0]);
    }

    #[test]
    fn broadcast_mixed_scalar_array() {
        // x^2 + y with scalar x = 2: the scalar is reused at every point of y.
        let x_sq = CompiledNode::Binary {
            op: BinaryOp::Pow,
            left: Box::new(CompiledNode::Argument(0)),
            right: Box::new(CompiledNode::Literal(2.0)),
        };
        let expr = make_expr(
            CompiledNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(x_sq),
                right: Box::new(CompiledNode::Argument(1)),
            },
            vec!["x", "y"],
        );
        let resolved = vec![
            ResolvedArg::Scalar(NumericResult::Real(2.0)),
            ResolvedArg::Array(vec![
                NumericResult::Real(10.0),
                NumericResult::Real(20.0),
                NumericResult::Real(30.0),
            ]),
        ];
        let (results, shape) = eval_broadcast(&expr, &resolved).unwrap();
        assert_eq!(shape, vec![3]);
        let vals: Vec<f64> = results
            .into_iter()
            .map(|r| r.unwrap().to_f64().unwrap())
            .collect();
        assert_eq!(vals, vec![14.0, 24.0, 34.0]);
    }

    #[test]
    fn broadcast_empty_array() {
        // A zero-length array gives a zero-length axis and no evaluations.
        let expr = make_expr(CompiledNode::Argument(0), vec!["x"]);
        let resolved = vec![ResolvedArg::Array(vec![])];
        let (results, shape) = eval_broadcast(&expr, &resolved).unwrap();
        assert_eq!(shape, vec![0]);
        assert!(results.is_empty());
    }

    #[test]
    fn broadcast_per_element_error() {
        // 1 / x with x = 0 at index 1: only that element fails, and its
        // neighbours still evaluate successfully.
        let expr = make_expr(
            CompiledNode::Binary {
                op: BinaryOp::Div,
                left: Box::new(CompiledNode::Literal(1.0)),
                right: Box::new(CompiledNode::Argument(0)),
            },
            vec!["x"],
        );
        let resolved = vec![ResolvedArg::Array(vec![
            NumericResult::Real(1.0),
            NumericResult::Real(0.0),
            NumericResult::Real(2.0),
        ])];
        let (results, _shape) = eval_broadcast(&expr, &resolved).unwrap();
        assert!(results[0].is_ok());
        assert!(results[1].is_err());
        assert!(results[2].is_ok());
    }

    // resolve_input: scalar vs. array/iterator inputs.
    #[test]
    fn resolve_input_scalar() {
        let r = resolve_input(EvalInput::Scalar(5.0));
        assert!(r.is_scalar());
    }

    #[test]
    fn resolve_input_array() {
        let r = resolve_input(EvalInput::from(vec![1.0, 2.0, 3.0]));
        assert_eq!(r.len(), 3);
    }

    #[test]
    fn resolve_input_iter() {
        let r = resolve_input(EvalInput::Iter(Box::new(vec![1.0, 2.0].into_iter())));
        assert_eq!(r.len(), 2);
    }

    // results_to_array: reshaping and first-error propagation.
    #[test]
    fn results_to_array_0d() {
        // An empty shape produces a 0-dimensional array holding one value.
        let results = vec![Ok(NumericResult::Real(42.0))];
        let arr = results_to_array(results, &[]).unwrap();
        assert_eq!(arr.ndim(), 0);
        assert_eq!(*arr.first().unwrap(), NumericResult::Real(42.0));
    }

    #[test]
    fn results_to_array_1d() {
        let results = vec![Ok(NumericResult::Real(1.0)), Ok(NumericResult::Real(2.0))];
        let arr = results_to_array(results, &[2]).unwrap();
        assert_eq!(arr.shape(), &[2]);
    }

    #[test]
    fn results_to_array_with_error() {
        // Any per-element error makes the whole conversion fail with that error.
        let results = vec![Ok(NumericResult::Real(1.0)), Err(EvalError::DivisionByZero)];
        let err = results_to_array(results, &[2]).unwrap_err();
        assert!(matches!(err, EvalError::DivisionByZero));
    }
}