mathlex_eval/eval/
input.rs1use ndarray::ArrayD;
2use num_complex::Complex;
3
4pub enum EvalInput {
10 Scalar(f64),
11 Complex(Complex<f64>),
12 Array(ArrayD<f64>),
13 ComplexArray(ArrayD<Complex<f64>>),
14 Iter(Box<dyn Iterator<Item = f64>>),
15 ComplexIter(Box<dyn Iterator<Item = Complex<f64>>>),
16}
17
18impl EvalInput {
19 pub fn is_scalar(&self) -> bool {
21 matches!(self, EvalInput::Scalar(_) | EvalInput::Complex(_))
22 }
23
24 pub fn is_complex(&self) -> bool {
26 matches!(
27 self,
28 EvalInput::Complex(_) | EvalInput::ComplexArray(_) | EvalInput::ComplexIter(_)
29 )
30 }
31
32 pub fn is_iter(&self) -> bool {
34 matches!(self, EvalInput::Iter(_) | EvalInput::ComplexIter(_))
35 }
36}
37
38impl From<f64> for EvalInput {
39 fn from(v: f64) -> Self {
40 EvalInput::Scalar(v)
41 }
42}
43
44impl From<Complex<f64>> for EvalInput {
45 fn from(v: Complex<f64>) -> Self {
46 EvalInput::Complex(v)
47 }
48}
49
50impl From<Vec<f64>> for EvalInput {
51 fn from(v: Vec<f64>) -> Self {
52 EvalInput::Array(ArrayD::from_shape_vec(vec![v.len()], v).unwrap())
53 }
54}
55
56impl From<Vec<Complex<f64>>> for EvalInput {
57 fn from(v: Vec<Complex<f64>>) -> Self {
58 EvalInput::ComplexArray(ArrayD::from_shape_vec(vec![v.len()], v).unwrap())
59 }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scalar_is_scalar() {
        assert!(EvalInput::Scalar(1.0).is_scalar());
    }

    #[test]
    fn complex_is_scalar() {
        assert!(EvalInput::Complex(Complex::new(1.0, 2.0)).is_scalar());
    }

    #[test]
    fn array_not_scalar() {
        let input: EvalInput = vec![1.0, 2.0, 3.0].into();
        assert!(!input.is_scalar());
    }

    #[test]
    fn complex_input_is_complex() {
        assert!(EvalInput::Complex(Complex::new(1.0, 0.0)).is_complex());
    }

    #[test]
    fn real_scalar_not_complex() {
        assert!(!EvalInput::Scalar(1.0).is_complex());
    }

    #[test]
    fn iter_is_iter() {
        let input = EvalInput::Iter(Box::new(vec![1.0, 2.0].into_iter()));
        assert!(input.is_iter());
    }

    #[test]
    fn array_not_iter() {
        let input: EvalInput = vec![1.0].into();
        assert!(!input.is_iter());
    }

    // Real iterators are neither scalar nor complex.
    #[test]
    fn real_iter_not_scalar_or_complex() {
        let input = EvalInput::Iter(Box::new(std::iter::empty::<f64>()));
        assert!(!input.is_scalar());
        assert!(!input.is_complex());
    }

    // Complex iterators must report both `is_iter` and `is_complex`.
    #[test]
    fn complex_iter_is_iter_and_complex() {
        let input = EvalInput::ComplexIter(Box::new(vec![Complex::new(1.0, 0.0)].into_iter()));
        assert!(input.is_iter());
        assert!(input.is_complex());
        assert!(!input.is_scalar());
    }

    // Complex arrays are complex, but neither scalar nor iterator.
    #[test]
    fn complex_array_is_complex() {
        let input: EvalInput = vec![Complex::new(1.0, 2.0)].into();
        assert!(input.is_complex());
        assert!(!input.is_scalar());
        assert!(!input.is_iter());
    }

    #[test]
    fn from_f64() {
        let input: EvalInput = 2.5_f64.into();
        match input {
            EvalInput::Scalar(x) => assert_eq!(x, 2.5),
            _ => panic!("expected Scalar"),
        }
    }

    #[test]
    fn from_complex() {
        let input: EvalInput = Complex::new(1.0, -1.0).into();
        match input {
            EvalInput::Complex(z) => assert_eq!(z, Complex::new(1.0, -1.0)),
            _ => panic!("expected Complex"),
        }
    }

    #[test]
    fn from_vec_f64() {
        let input: EvalInput = vec![1.0, 2.0, 3.0].into();
        match input {
            EvalInput::Array(arr) => assert_eq!(arr.len(), 3),
            _ => panic!("expected Array"),
        }
    }

    // An empty vector must still convert (no panic) into an empty 1-D array.
    #[test]
    fn from_empty_vec_f64() {
        let input: EvalInput = Vec::<f64>::new().into();
        match input {
            EvalInput::Array(arr) => {
                assert_eq!(arr.ndim(), 1);
                assert_eq!(arr.len(), 0);
            }
            _ => panic!("expected Array"),
        }
    }

    #[test]
    fn from_vec_complex() {
        let input: EvalInput = vec![Complex::new(1.0, 0.0), Complex::new(0.0, 1.0)].into();
        match input {
            EvalInput::ComplexArray(arr) => assert_eq!(arr.len(), 2),
            _ => panic!("expected ComplexArray"),
        }
    }

    // Same empty-input guarantee for the complex conversion.
    #[test]
    fn from_empty_vec_complex() {
        let input: EvalInput = Vec::<Complex<f64>>::new().into();
        match input {
            EvalInput::ComplexArray(arr) => {
                assert_eq!(arr.ndim(), 1);
                assert_eq!(arr.len(), 0);
            }
            _ => panic!("expected ComplexArray"),
        }
    }
}