Skip to main content

mathlex_eval/eval/
handle.rs

1use std::collections::HashMap;
2
3use ndarray::ArrayD;
4
5use crate::broadcast::{
6    ResolvedArg, build_args_for_index, compute_shape, eval_broadcast, flat_to_multi, resolve_input,
7    results_to_array, total_elements,
8};
9use crate::compiler::ir::CompiledExpr;
10use crate::error::EvalError;
11use crate::eval::input::EvalInput;
12use crate::eval::numeric::NumericResult;
13use crate::eval::scalar;
14
15/// Lazy evaluation handle returned by [`eval()`].
16///
17/// Computes output shape from input shapes but defers actual evaluation
18/// until the caller chooses a consumption mode.
19pub struct EvalHandle {
20    expr: CompiledExpr,
21    resolved: Vec<ResolvedArg>,
22    shape: Vec<usize>,
23    axis_args: Vec<usize>,
24}
25
26impl std::fmt::Debug for EvalHandle {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("EvalHandle")
29            .field("shape", &self.shape)
30            .field("num_args", &self.resolved.len())
31            .finish()
32    }
33}
34
35impl EvalHandle {
36    /// Output shape. Empty for scalar output, `[n]` for 1-D, `[n, m]` for 2-D, etc.
37    pub fn shape(&self) -> &[usize] {
38        &self.shape
39    }
40
41    /// Total number of output elements.
42    pub fn len(&self) -> usize {
43        total_elements(&self.shape)
44    }
45
46    /// Whether the output is empty (zero elements due to empty input array).
47    pub fn is_empty(&self) -> bool {
48        self.len() == 0
49    }
50
51    /// Consume as scalar. Errors if output is not 0-d.
52    pub fn scalar(self) -> Result<NumericResult, EvalError> {
53        if !self.shape.is_empty() {
54            return Err(EvalError::ShapeMismatch {
55                details: format!("expected scalar output but shape is {:?}", self.shape),
56            });
57        }
58        let args: Vec<NumericResult> = self.resolved.iter().map(|r| r.get(0)).collect();
59        scalar::eval_node(&self.expr.root, &args, &mut vec![])
60    }
61
62    /// Consume eagerly into a full N-dimensional array.
63    pub fn to_array(self) -> Result<ArrayD<NumericResult>, EvalError> {
64        let (results, shape) = eval_broadcast(&self.expr, &self.resolved)?;
65        results_to_array(results, &shape)
66    }
67
68    /// Consume lazily — yields results as they become computable.
69    pub fn iter(self) -> EvalIter {
70        let total = total_elements(&self.shape);
71        EvalIter {
72            expr: self.expr,
73            resolved: self.resolved,
74            shape: self.shape,
75            axis_args: self.axis_args,
76            current: 0,
77            total,
78        }
79    }
80}
81
82/// Streaming result iterator over broadcast evaluation.
83pub struct EvalIter {
84    expr: CompiledExpr,
85    resolved: Vec<ResolvedArg>,
86    shape: Vec<usize>,
87    axis_args: Vec<usize>,
88    current: usize,
89    total: usize,
90}
91
92impl Iterator for EvalIter {
93    type Item = Result<NumericResult, EvalError>;
94
95    fn next(&mut self) -> Option<Self::Item> {
96        if self.current >= self.total {
97            return None;
98        }
99        let multi = flat_to_multi(self.current, &self.shape);
100        let args = build_args_for_index(&self.resolved, &self.axis_args, &multi);
101        self.current += 1;
102        Some(scalar::eval_node(&self.expr.root, &args, &mut vec![]))
103    }
104}
105
106impl EvalIter {
107    /// Number of remaining elements.
108    pub fn remaining(&self) -> usize {
109        self.total - self.current
110    }
111}
112
113/// Create a lazy eval handle from a compiled expression and arguments.
114///
115/// Validates that all expected arguments are provided and no unknown
116/// arguments are present. Resolves inputs and computes output shape,
117/// but defers evaluation until consumed via `scalar()`, `to_array()`,
118/// or `iter()`.
119pub fn eval(
120    expr: &CompiledExpr,
121    mut args: HashMap<&str, EvalInput>,
122) -> Result<EvalHandle, EvalError> {
123    let expected = expr.argument_names();
124
125    // Check for unknown arguments
126    for name in args.keys() {
127        if !expected.iter().any(|e| e == name) {
128            return Err(EvalError::UnknownArgument {
129                name: name.to_string(),
130            });
131        }
132    }
133
134    // Resolve arguments in declaration order
135    let mut resolved = Vec::with_capacity(expected.len());
136    for name in expected {
137        match args.remove_entry(name.as_str()) {
138            Some((_, input)) => resolved.push(resolve_input(input)),
139            None => {
140                return Err(EvalError::MissingArgument { name: name.clone() });
141            }
142        }
143    }
144
145    let (shape, axis_args) = compute_shape(&resolved);
146
147    Ok(EvalHandle {
148        expr: expr.clone(),
149        resolved,
150        shape,
151        axis_args,
152    })
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::ir::{BinaryOp, CompiledNode};
    use approx::assert_abs_diff_eq;

    /// Builds a real-valued `CompiledExpr` whose arguments are declared in
    /// `arg_names` order (`CompiledNode::Argument(i)` refers to `arg_names[i]`).
    fn make_expr(root: CompiledNode, arg_names: Vec<&str>) -> CompiledExpr {
        CompiledExpr {
            root,
            argument_names: arg_names.into_iter().map(String::from).collect(),
            is_complex: false,
        }
    }

    // x^2 + y  (x is argument 0, y is argument 1)
    fn x_sq_plus_y() -> CompiledExpr {
        let x_sq = CompiledNode::Binary {
            op: BinaryOp::Pow,
            left: Box::new(CompiledNode::Argument(0)),
            right: Box::new(CompiledNode::Literal(2.0)),
        };
        make_expr(
            CompiledNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(x_sq),
                right: Box::new(CompiledNode::Argument(1)),
            },
            vec!["x", "y"],
        )
    }

    #[test]
    fn eval_scalar_result() {
        let expr = x_sq_plus_y();
        let mut args = HashMap::new();
        args.insert("x", EvalInput::Scalar(3.0));
        args.insert("y", EvalInput::Scalar(10.0));
        let handle = eval(&expr, args).unwrap();
        assert!(handle.shape().is_empty());
        assert_eq!(handle.len(), 1);
        let result = handle.scalar().unwrap();
        assert_abs_diff_eq!(result.to_f64().unwrap(), 19.0, epsilon = 1e-10);
    }

    #[test]
    fn eval_1d_array() {
        let expr = make_expr(
            CompiledNode::Binary {
                op: BinaryOp::Pow,
                left: Box::new(CompiledNode::Argument(0)),
                right: Box::new(CompiledNode::Literal(2.0)),
            },
            vec!["x"],
        );
        let mut args = HashMap::new();
        args.insert("x", EvalInput::from(vec![1.0, 2.0, 3.0]));
        let handle = eval(&expr, args).unwrap();
        assert_eq!(handle.shape(), &[3]);
        let arr = handle.to_array().unwrap();
        assert_eq!(arr.shape(), &[3]);
        assert_eq!(*arr.get([0]).unwrap(), NumericResult::Real(1.0));
        assert_eq!(*arr.get([1]).unwrap(), NumericResult::Real(4.0));
        assert_eq!(*arr.get([2]).unwrap(), NumericResult::Real(9.0));
    }

    #[test]
    fn eval_2d_cartesian() {
        let expr = x_sq_plus_y();
        let mut args = HashMap::new();
        args.insert("x", EvalInput::from(vec![1.0, 2.0, 3.0]));
        args.insert("y", EvalInput::from(vec![10.0, 20.0]));
        let handle = eval(&expr, args).unwrap();
        assert_eq!(handle.shape(), &[3, 2]);
        assert_eq!(handle.len(), 6);
        let arr = handle.to_array().unwrap();
        // Axis 0 follows `x`, axis 1 follows `y`.
        assert_eq!(*arr.get([0, 0]).unwrap(), NumericResult::Real(11.0));
        assert_eq!(*arr.get([0, 1]).unwrap(), NumericResult::Real(21.0));
        assert_eq!(*arr.get([1, 0]).unwrap(), NumericResult::Real(14.0));
        assert_eq!(*arr.get([2, 0]).unwrap(), NumericResult::Real(19.0));
        assert_eq!(*arr.get([2, 1]).unwrap(), NumericResult::Real(29.0));
    }

    #[test]
    fn eval_iter_matches_to_array() {
        let expr = x_sq_plus_y();
        let mut args1 = HashMap::new();
        args1.insert("x", EvalInput::from(vec![1.0, 2.0]));
        args1.insert("y", EvalInput::from(vec![10.0, 20.0]));

        let mut args2 = HashMap::new();
        args2.insert("x", EvalInput::from(vec![1.0, 2.0]));
        args2.insert("y", EvalInput::from(vec![10.0, 20.0]));

        let handle1 = eval(&expr, args1).unwrap();
        let handle2 = eval(&expr, args2).unwrap();

        let arr_results: Vec<NumericResult> = handle1.to_array().unwrap().iter().copied().collect();
        let iter_results: Vec<NumericResult> = handle2.iter().map(|r| r.unwrap()).collect();

        assert_eq!(arr_results, iter_results);
    }

    #[test]
    fn eval_unknown_argument_error() {
        let expr = make_expr(CompiledNode::Argument(0), vec!["x"]);
        let mut args = HashMap::new();
        args.insert("x", EvalInput::Scalar(1.0));
        args.insert("z", EvalInput::Scalar(2.0));
        let err = eval(&expr, args).unwrap_err();
        assert!(matches!(err, EvalError::UnknownArgument { ref name } if name == "z"));
    }

    #[test]
    fn eval_missing_argument_error() {
        let expr = x_sq_plus_y();
        let mut args = HashMap::new();
        args.insert("x", EvalInput::Scalar(1.0));
        let err = eval(&expr, args).unwrap_err();
        assert!(matches!(err, EvalError::MissingArgument { ref name } if name == "y"));
    }

    #[test]
    fn eval_scalar_on_nonscalar_errors() {
        let expr = make_expr(CompiledNode::Argument(0), vec!["x"]);
        let mut args = HashMap::new();
        args.insert("x", EvalInput::from(vec![1.0, 2.0]));
        let handle = eval(&expr, args).unwrap();
        let err = handle.scalar().unwrap_err();
        assert!(matches!(err, EvalError::ShapeMismatch { ref details } if details.contains("[2]")));
    }

    #[test]
    fn eval_no_args_expression() {
        // Constant expression: 42
        let expr = make_expr(CompiledNode::Literal(42.0), vec![]);
        let args = HashMap::new();
        let handle = eval(&expr, args).unwrap();
        assert!(handle.shape().is_empty());
        assert_eq!(handle.len(), 1);
        let result = handle.scalar().unwrap();
        assert_eq!(result, NumericResult::Real(42.0));
    }

    #[test]
    fn eval_iter_remaining() {
        let expr = make_expr(CompiledNode::Argument(0), vec!["x"]);
        let mut args = HashMap::new();
        args.insert("x", EvalInput::from(vec![1.0, 2.0, 3.0]));
        let handle = eval(&expr, args).unwrap();
        let mut iter = handle.iter();
        assert_eq!(iter.remaining(), 3);
        iter.next();
        assert_eq!(iter.remaining(), 2);
        iter.next();
        iter.next();
        assert_eq!(iter.remaining(), 0);
        // Exhausted: no further elements, and `remaining` stays at zero.
        assert!(iter.next().is_none());
        assert_eq!(iter.remaining(), 0);
    }

    #[test]
    fn eval_with_iterator_input() {
        let expr = make_expr(
            CompiledNode::Binary {
                op: BinaryOp::Mul,
                left: Box::new(CompiledNode::Argument(0)),
                right: Box::new(CompiledNode::Literal(2.0)),
            },
            vec!["x"],
        );
        let mut args = HashMap::new();
        args.insert(
            "x",
            EvalInput::Iter(Box::new(vec![1.0, 2.0, 3.0].into_iter())),
        );
        let handle = eval(&expr, args).unwrap();
        let arr = handle.to_array().unwrap();
        assert_eq!(arr.shape(), &[3]);
        assert_eq!(*arr.get([0]).unwrap(), NumericResult::Real(2.0));
        assert_eq!(*arr.get([1]).unwrap(), NumericResult::Real(4.0));
        assert_eq!(*arr.get([2]).unwrap(), NumericResult::Real(6.0));
    }

    #[test]
    fn eval_empty_array() {
        let expr = make_expr(CompiledNode::Argument(0), vec!["x"]);
        let mut args = HashMap::new();
        args.insert("x", EvalInput::from(vec![] as Vec<f64>));
        let handle = eval(&expr, args).unwrap();
        assert!(handle.is_empty());
        assert_eq!(handle.len(), 0);
        assert_eq!(handle.shape(), &[0]);
        // Streaming an empty output yields nothing.
        assert_eq!(handle.iter().count(), 0);
    }
}