Skip to main content

mathlex_eval/eval/
input.rs

1use ndarray::ArrayD;
2use num_complex::Complex;
3
4/// Input value for an argument during evaluation.
5///
6/// Each argument can be a scalar, array, or iterator — real or complex.
7/// Scalar arguments broadcast to all positions. Array arguments contribute
8/// one axis to the output shape. Iterator arguments are cached incrementally.
9pub enum EvalInput {
10    Scalar(f64),
11    Complex(Complex<f64>),
12    Array(ArrayD<f64>),
13    ComplexArray(ArrayD<Complex<f64>>),
14    Iter(Box<dyn Iterator<Item = f64>>),
15    ComplexIter(Box<dyn Iterator<Item = Complex<f64>>>),
16}
17
18impl EvalInput {
19    /// Whether this input is scalar (contributes no axis to output).
20    pub fn is_scalar(&self) -> bool {
21        matches!(self, EvalInput::Scalar(_) | EvalInput::Complex(_))
22    }
23
24    /// Whether this input contains complex values.
25    pub fn is_complex(&self) -> bool {
26        matches!(
27            self,
28            EvalInput::Complex(_) | EvalInput::ComplexArray(_) | EvalInput::ComplexIter(_)
29        )
30    }
31
32    /// Whether this input is an iterator (unknown length until exhausted).
33    pub fn is_iter(&self) -> bool {
34        matches!(self, EvalInput::Iter(_) | EvalInput::ComplexIter(_))
35    }
36}
37
38impl From<f64> for EvalInput {
39    fn from(v: f64) -> Self {
40        EvalInput::Scalar(v)
41    }
42}
43
44impl From<Complex<f64>> for EvalInput {
45    fn from(v: Complex<f64>) -> Self {
46        EvalInput::Complex(v)
47    }
48}
49
50impl From<Vec<f64>> for EvalInput {
51    fn from(v: Vec<f64>) -> Self {
52        EvalInput::Array(ArrayD::from_shape_vec(vec![v.len()], v).unwrap())
53    }
54}
55
56impl From<Vec<Complex<f64>>> for EvalInput {
57    fn from(v: Vec<Complex<f64>>) -> Self {
58        EvalInput::ComplexArray(ArrayD::from_shape_vec(vec![v.len()], v).unwrap())
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scalar_is_scalar() {
        assert!(EvalInput::Scalar(1.0).is_scalar());
    }

    #[test]
    fn complex_is_scalar() {
        assert!(EvalInput::Complex(Complex::new(1.0, 2.0)).is_scalar());
    }

    #[test]
    fn array_not_scalar() {
        let input: EvalInput = vec![1.0, 2.0, 3.0].into();
        assert!(!input.is_scalar());
    }

    #[test]
    fn complex_array_not_scalar_but_complex() {
        let input: EvalInput = vec![Complex::new(1.0, 0.0)].into();
        assert!(!input.is_scalar());
        assert!(input.is_complex());
        assert!(!input.is_iter());
    }

    #[test]
    fn complex_input_is_complex() {
        assert!(EvalInput::Complex(Complex::new(1.0, 0.0)).is_complex());
    }

    #[test]
    fn real_scalar_not_complex() {
        assert!(!EvalInput::Scalar(1.0).is_complex());
    }

    #[test]
    fn real_array_not_complex() {
        let input: EvalInput = vec![1.0, 2.0].into();
        assert!(!input.is_complex());
    }

    #[test]
    fn iter_is_iter() {
        let input = EvalInput::Iter(Box::new(vec![1.0, 2.0].into_iter()));
        assert!(input.is_iter());
    }

    #[test]
    fn real_iter_not_complex_not_scalar() {
        let input = EvalInput::Iter(Box::new(vec![1.0].into_iter()));
        assert!(!input.is_complex());
        assert!(!input.is_scalar());
    }

    #[test]
    fn complex_iter_is_iter_and_complex() {
        let input = EvalInput::ComplexIter(Box::new(vec![Complex::new(1.0, 1.0)].into_iter()));
        assert!(input.is_iter());
        assert!(input.is_complex());
        assert!(!input.is_scalar());
    }

    #[test]
    fn array_not_iter() {
        let input: EvalInput = vec![1.0].into();
        assert!(!input.is_iter());
    }

    #[test]
    fn from_f64() {
        let input: EvalInput = 2.5_f64.into();
        match input {
            EvalInput::Scalar(v) => assert_eq!(v, 2.5),
            _ => panic!("expected Scalar"),
        }
    }

    #[test]
    fn from_complex() {
        let input: EvalInput = Complex::new(1.0, -2.0).into();
        match input {
            EvalInput::Complex(v) => assert_eq!(v, Complex::new(1.0, -2.0)),
            _ => panic!("expected Complex"),
        }
    }

    #[test]
    fn from_vec_f64() {
        let input: EvalInput = vec![1.0, 2.0, 3.0].into();
        match input {
            EvalInput::Array(arr) => {
                assert_eq!(arr.len(), 3);
                assert_eq!(arr.ndim(), 1);
                assert_eq!(arr.shape(), &[3]);
                assert_eq!(arr.iter().copied().collect::<Vec<f64>>(), vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("expected Array"),
        }
    }

    #[test]
    fn from_empty_vec_f64() {
        let input: EvalInput = Vec::<f64>::new().into();
        match input {
            EvalInput::Array(arr) => {
                assert_eq!(arr.len(), 0);
                assert_eq!(arr.ndim(), 1);
            }
            _ => panic!("expected Array"),
        }
    }

    #[test]
    fn from_vec_complex() {
        let input: EvalInput = vec![Complex::new(1.0, 0.0), Complex::new(0.0, 1.0)].into();
        match input {
            EvalInput::ComplexArray(arr) => {
                assert_eq!(arr.len(), 2);
                assert_eq!(arr.ndim(), 1);
            }
            _ => panic!("expected ComplexArray"),
        }
    }
}