1use std::collections::HashMap;
2
3use ndarray::ArrayD;
4
5use crate::broadcast::{
6 ResolvedArg, build_args_for_index, compute_shape, eval_broadcast, flat_to_multi, resolve_input,
7 results_to_array, total_elements,
8};
9use crate::compiler::ir::CompiledExpr;
10use crate::error::EvalError;
11use crate::eval::input::EvalInput;
12use crate::eval::numeric::NumericResult;
13use crate::eval::scalar;
14
15pub struct EvalHandle {
20 expr: CompiledExpr,
21 resolved: Vec<ResolvedArg>,
22 shape: Vec<usize>,
23 axis_args: Vec<usize>,
24}
25
26impl std::fmt::Debug for EvalHandle {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("EvalHandle")
29 .field("shape", &self.shape)
30 .field("num_args", &self.resolved.len())
31 .finish()
32 }
33}
34
35impl EvalHandle {
36 pub fn shape(&self) -> &[usize] {
38 &self.shape
39 }
40
41 pub fn len(&self) -> usize {
43 total_elements(&self.shape)
44 }
45
46 pub fn is_empty(&self) -> bool {
48 self.len() == 0
49 }
50
51 pub fn scalar(self) -> Result<NumericResult, EvalError> {
53 if !self.shape.is_empty() {
54 return Err(EvalError::ShapeMismatch {
55 details: format!("expected scalar output but shape is {:?}", self.shape),
56 });
57 }
58 let args: Vec<NumericResult> = self.resolved.iter().map(|r| r.get(0)).collect();
59 scalar::eval_node(&self.expr.root, &args, &mut vec![])
60 }
61
62 pub fn to_array(self) -> Result<ArrayD<NumericResult>, EvalError> {
64 let (results, shape) = eval_broadcast(&self.expr, &self.resolved)?;
65 results_to_array(results, &shape)
66 }
67
68 pub fn iter(self) -> EvalIter {
70 let total = total_elements(&self.shape);
71 EvalIter {
72 expr: self.expr,
73 resolved: self.resolved,
74 shape: self.shape,
75 axis_args: self.axis_args,
76 current: 0,
77 total,
78 }
79 }
80}
81
82pub struct EvalIter {
84 expr: CompiledExpr,
85 resolved: Vec<ResolvedArg>,
86 shape: Vec<usize>,
87 axis_args: Vec<usize>,
88 current: usize,
89 total: usize,
90}
91
92impl Iterator for EvalIter {
93 type Item = Result<NumericResult, EvalError>;
94
95 fn next(&mut self) -> Option<Self::Item> {
96 if self.current >= self.total {
97 return None;
98 }
99 let multi = flat_to_multi(self.current, &self.shape);
100 let args = build_args_for_index(&self.resolved, &self.axis_args, &multi);
101 self.current += 1;
102 Some(scalar::eval_node(&self.expr.root, &args, &mut vec![]))
103 }
104}
105
106impl EvalIter {
107 pub fn remaining(&self) -> usize {
109 self.total - self.current
110 }
111}
112
113pub fn eval(
120 expr: &CompiledExpr,
121 mut args: HashMap<&str, EvalInput>,
122) -> Result<EvalHandle, EvalError> {
123 let expected = expr.argument_names();
124
125 for name in args.keys() {
127 if !expected.iter().any(|e| e == name) {
128 return Err(EvalError::UnknownArgument {
129 name: name.to_string(),
130 });
131 }
132 }
133
134 let mut resolved = Vec::with_capacity(expected.len());
136 for name in expected {
137 match args.remove_entry(name.as_str()) {
138 Some((_, input)) => resolved.push(resolve_input(input)),
139 None => {
140 return Err(EvalError::MissingArgument { name: name.clone() });
141 }
142 }
143 }
144
145 let (shape, axis_args) = compute_shape(&resolved);
146
147 Ok(EvalHandle {
148 expr: expr.clone(),
149 resolved,
150 shape,
151 axis_args,
152 })
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::ir::{BinaryOp, CompiledNode};
    use approx::assert_abs_diff_eq;

    /// Wraps `root` in a real-valued expression whose arguments are named `arg_names`.
    fn make_expr(root: CompiledNode, arg_names: Vec<&str>) -> CompiledExpr {
        let argument_names = arg_names.into_iter().map(str::to_owned).collect();
        CompiledExpr {
            root,
            argument_names,
            is_complex: false,
        }
    }

    /// Builds the node `lhs <op> rhs`.
    fn binary(op: BinaryOp, lhs: CompiledNode, rhs: CompiledNode) -> CompiledNode {
        CompiledNode::Binary {
            op,
            left: Box::new(lhs),
            right: Box::new(rhs),
        }
    }

    /// `x^2 + y`, with arguments `["x", "y"]`.
    fn x_sq_plus_y() -> CompiledExpr {
        let x_squared = binary(
            BinaryOp::Pow,
            CompiledNode::Argument(0),
            CompiledNode::Literal(2.0),
        );
        make_expr(
            binary(BinaryOp::Add, x_squared, CompiledNode::Argument(1)),
            vec!["x", "y"],
        )
    }

    #[test]
    fn eval_scalar_result() {
        let inputs = HashMap::from([
            ("x", EvalInput::Scalar(3.0)),
            ("y", EvalInput::Scalar(10.0)),
        ]);
        let handle = eval(&x_sq_plus_y(), inputs).unwrap();
        assert!(handle.shape().is_empty());
        assert_eq!(handle.len(), 1);
        let value = handle.scalar().unwrap();
        assert_abs_diff_eq!(value.to_f64().unwrap(), 19.0, epsilon = 1e-10);
    }

    #[test]
    fn eval_1d_array() {
        let squared = make_expr(
            binary(
                BinaryOp::Pow,
                CompiledNode::Argument(0),
                CompiledNode::Literal(2.0),
            ),
            vec!["x"],
        );
        let inputs = HashMap::from([("x", EvalInput::from(vec![1.0, 2.0, 3.0]))]);
        let handle = eval(&squared, inputs).unwrap();
        assert_eq!(handle.shape(), &[3]);

        let out = handle.to_array().unwrap();
        assert_eq!(out.shape(), &[3]);
        for (i, &want) in [1.0, 4.0, 9.0].iter().enumerate() {
            assert_eq!(*out.get([i]).unwrap(), NumericResult::Real(want));
        }
    }

    #[test]
    fn eval_2d_cartesian() {
        let inputs = HashMap::from([
            ("x", EvalInput::from(vec![1.0, 2.0, 3.0])),
            ("y", EvalInput::from(vec![10.0, 20.0])),
        ]);
        let handle = eval(&x_sq_plus_y(), inputs).unwrap();
        assert_eq!(handle.shape(), &[3, 2]);
        assert_eq!(handle.len(), 6);

        let out = handle.to_array().unwrap();
        let cases: [([usize; 2], f64); 3] = [([0, 0], 11.0), ([0, 1], 21.0), ([2, 1], 29.0)];
        for (index, want) in cases {
            assert_eq!(*out.get(index).unwrap(), NumericResult::Real(want));
        }
    }

    #[test]
    fn eval_iter_matches_to_array() {
        let expr = x_sq_plus_y();
        let fresh_inputs = || {
            HashMap::from([
                ("x", EvalInput::from(vec![1.0, 2.0])),
                ("y", EvalInput::from(vec![10.0, 20.0])),
            ])
        };

        let via_array: Vec<NumericResult> = eval(&expr, fresh_inputs())
            .unwrap()
            .to_array()
            .unwrap()
            .iter()
            .copied()
            .collect();
        let via_iter: Vec<NumericResult> = eval(&expr, fresh_inputs())
            .unwrap()
            .iter()
            .map(Result::unwrap)
            .collect();

        assert_eq!(via_array, via_iter);
    }

    #[test]
    fn eval_unknown_argument_error() {
        let identity = make_expr(CompiledNode::Argument(0), vec!["x"]);
        let inputs = HashMap::from([
            ("x", EvalInput::Scalar(1.0)),
            ("z", EvalInput::Scalar(2.0)),
        ]);
        let outcome = eval(&identity, inputs);
        assert!(matches!(outcome, Err(EvalError::UnknownArgument { .. })));
    }

    #[test]
    fn eval_missing_argument_error() {
        let inputs = HashMap::from([("x", EvalInput::Scalar(1.0))]);
        let outcome = eval(&x_sq_plus_y(), inputs);
        assert!(matches!(outcome, Err(EvalError::MissingArgument { .. })));
    }

    #[test]
    fn eval_scalar_on_nonscalar_errors() {
        let identity = make_expr(CompiledNode::Argument(0), vec!["x"]);
        let inputs = HashMap::from([("x", EvalInput::from(vec![1.0, 2.0]))]);
        let outcome = eval(&identity, inputs).unwrap().scalar();
        assert!(matches!(outcome, Err(EvalError::ShapeMismatch { .. })));
    }

    #[test]
    fn eval_no_args_expression() {
        let constant = make_expr(CompiledNode::Literal(42.0), vec![]);
        let inputs = HashMap::new();
        let value = eval(&constant, inputs).unwrap().scalar().unwrap();
        assert_eq!(value, NumericResult::Real(42.0));
    }

    #[test]
    fn eval_iter_remaining() {
        let identity = make_expr(CompiledNode::Argument(0), vec!["x"]);
        let inputs = HashMap::from([("x", EvalInput::from(vec![1.0, 2.0, 3.0]))]);
        let mut it = eval(&identity, inputs).unwrap().iter();
        assert_eq!(it.remaining(), 3);
        let _ = it.next();
        assert_eq!(it.remaining(), 2);
    }

    #[test]
    fn eval_with_iterator_input() {
        let doubled = make_expr(
            binary(
                BinaryOp::Mul,
                CompiledNode::Argument(0),
                CompiledNode::Literal(2.0),
            ),
            vec!["x"],
        );
        let streamed = EvalInput::Iter(Box::new(vec![1.0, 2.0, 3.0].into_iter()));
        let out = eval(&doubled, HashMap::from([("x", streamed)]))
            .unwrap()
            .to_array()
            .unwrap();
        assert_eq!(out.shape(), &[3]);
    }

    #[test]
    fn eval_empty_array() {
        let identity = make_expr(CompiledNode::Argument(0), vec!["x"]);
        let inputs = HashMap::from([("x", EvalInput::from(Vec::<f64>::new()))]);
        let handle = eval(&identity, inputs).unwrap();
        assert!(handle.is_empty());
        assert_eq!(handle.shape(), &[0]);
    }
}