// Source path: mathlex_eval/eval/numeric.rs

1use num_complex::Complex;
2use std::ops::{Add, Div, Mul, Neg, Sub};
3
4/// Result of evaluating a compiled expression — either a real or complex number.
5#[derive(Debug, Clone, Copy, PartialEq)]
6#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7pub enum NumericResult {
8    Real(f64),
9    Complex(Complex<f64>),
10}
11
12impl NumericResult {
13    pub fn is_complex(&self) -> bool {
14        matches!(self, NumericResult::Complex(_))
15    }
16
17    pub fn to_complex(self) -> Complex<f64> {
18        match self {
19            NumericResult::Real(r) => Complex::new(r, 0.0),
20            NumericResult::Complex(c) => c,
21        }
22    }
23
24    pub fn to_f64(self) -> Option<f64> {
25        match self {
26            NumericResult::Real(r) => Some(r),
27            NumericResult::Complex(_) => None,
28        }
29    }
30
31    pub fn pow(self, exp: NumericResult) -> NumericResult {
32        match (self, exp) {
33            (NumericResult::Real(base), NumericResult::Real(e)) => {
34                let result = base.powf(e);
35                if result.is_nan() && base < 0.0 {
36                    // Negative base with fractional exponent → complex
37                    let c = Complex::new(base, 0.0).powc(Complex::new(e, 0.0));
38                    NumericResult::Complex(c).simplify()
39                } else {
40                    NumericResult::Real(result)
41                }
42            }
43            (base, exp) => {
44                let c = base.to_complex().powc(exp.to_complex());
45                NumericResult::Complex(c).simplify()
46            }
47        }
48    }
49
50    pub fn modulo(self, rhs: NumericResult) -> NumericResult {
51        match (self, rhs) {
52            (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a % b),
53            _ => {
54                // Complex modulo not standard; return NaN-like behavior
55                NumericResult::Real(f64::NAN)
56            }
57        }
58    }
59
60    pub fn sqrt(self) -> NumericResult {
61        match self {
62            NumericResult::Real(r) if r >= 0.0 => NumericResult::Real(r.sqrt()),
63            NumericResult::Real(r) => {
64                NumericResult::Complex(Complex::new(0.0, (-r).sqrt())).simplify()
65            }
66            NumericResult::Complex(c) => NumericResult::Complex(c.sqrt()).simplify(),
67        }
68    }
69
70    fn simplify(self) -> NumericResult {
71        if let NumericResult::Complex(c) = self {
72            if c.im.abs() < 1e-15 {
73                return NumericResult::Real(c.re);
74            }
75        }
76        self
77    }
78}
79
80impl From<f64> for NumericResult {
81    fn from(v: f64) -> Self {
82        NumericResult::Real(v)
83    }
84}
85
86impl From<Complex<f64>> for NumericResult {
87    fn from(v: Complex<f64>) -> Self {
88        NumericResult::Complex(v)
89    }
90}
91
92impl From<i64> for NumericResult {
93    fn from(v: i64) -> Self {
94        NumericResult::Real(v as f64)
95    }
96}
97
98impl Add for NumericResult {
99    type Output = NumericResult;
100
101    fn add(self, rhs: NumericResult) -> NumericResult {
102        match (self, rhs) {
103            (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a + b),
104            (a, b) => NumericResult::Complex(a.to_complex() + b.to_complex()).simplify(),
105        }
106    }
107}
108
109impl Sub for NumericResult {
110    type Output = NumericResult;
111
112    fn sub(self, rhs: NumericResult) -> NumericResult {
113        match (self, rhs) {
114            (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a - b),
115            (a, b) => NumericResult::Complex(a.to_complex() - b.to_complex()).simplify(),
116        }
117    }
118}
119
120impl Mul for NumericResult {
121    type Output = NumericResult;
122
123    fn mul(self, rhs: NumericResult) -> NumericResult {
124        match (self, rhs) {
125            (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a * b),
126            (a, b) => NumericResult::Complex(a.to_complex() * b.to_complex()).simplify(),
127        }
128    }
129}
130
131impl Div for NumericResult {
132    type Output = NumericResult;
133
134    fn div(self, rhs: NumericResult) -> NumericResult {
135        match (self, rhs) {
136            (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a / b),
137            (a, b) => NumericResult::Complex(a.to_complex() / b.to_complex()).simplify(),
138        }
139    }
140}
141
142impl Neg for NumericResult {
143    type Output = NumericResult;
144
145    fn neg(self) -> NumericResult {
146        match self {
147            NumericResult::Real(r) => NumericResult::Real(-r),
148            NumericResult::Complex(c) => NumericResult::Complex(-c),
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn real_add_real_stays_real() {
        let r = NumericResult::Real(2.0) + NumericResult::Real(3.0);
        assert_eq!(r, NumericResult::Real(5.0));
    }

    #[test]
    fn real_add_complex_promotes() {
        let r = NumericResult::Real(1.0) + NumericResult::Complex(Complex::new(2.0, 3.0));
        assert_eq!(r, NumericResult::Complex(Complex::new(3.0, 3.0)));
    }

    #[test]
    fn real_sub_real() {
        let r = NumericResult::Real(5.0) - NumericResult::Real(3.0);
        assert_eq!(r, NumericResult::Real(2.0));
    }

    #[test]
    fn real_mul_real() {
        let r = NumericResult::Real(3.0) * NumericResult::Real(4.0);
        assert_eq!(r, NumericResult::Real(12.0));
    }

    #[test]
    fn real_div_real() {
        let r = NumericResult::Real(10.0) / NumericResult::Real(4.0);
        assert_eq!(r, NumericResult::Real(2.5));
    }

    #[test]
    fn real_div_complex_promotes() {
        // 1 / i = -i
        let r = NumericResult::Real(1.0) / NumericResult::Complex(Complex::new(0.0, 1.0));
        assert_eq!(r, NumericResult::Complex(Complex::new(0.0, -1.0)));
    }

    #[test]
    fn neg_real() {
        let r = -NumericResult::Real(5.0);
        assert_eq!(r, NumericResult::Real(-5.0));
    }

    #[test]
    fn neg_complex() {
        let r = -NumericResult::Complex(Complex::new(1.0, 2.0));
        assert_eq!(r, NumericResult::Complex(Complex::new(-1.0, -2.0)));
    }

    #[test]
    fn complex_mul_complex() {
        // (1+2i) * (3+4i) = 3+4i+6i+8i² = 3+10i-8 = -5+10i
        let a = NumericResult::Complex(Complex::new(1.0, 2.0));
        let b = NumericResult::Complex(Complex::new(3.0, 4.0));
        let r = a * b;
        assert_eq!(r, NumericResult::Complex(Complex::new(-5.0, 10.0)));
    }

    #[test]
    fn conjugate_product_simplifies_to_real() {
        // (1+2i) * (1-2i) = 1 + 4 = 5, with an exactly-zero imaginary part.
        let a = NumericResult::Complex(Complex::new(1.0, 2.0));
        let b = NumericResult::Complex(Complex::new(1.0, -2.0));
        assert_eq!(a * b, NumericResult::Real(5.0));
    }

    #[test]
    fn sqrt_negative_returns_complex() {
        let r = NumericResult::Real(-1.0).sqrt();
        match r {
            NumericResult::Complex(c) => {
                assert_abs_diff_eq!(c.re, 0.0, epsilon = 1e-15);
                assert_abs_diff_eq!(c.im, 1.0, epsilon = 1e-15);
            }
            _ => panic!("expected complex"),
        }
    }

    #[test]
    fn sqrt_positive_stays_real() {
        let r = NumericResult::Real(4.0).sqrt();
        assert_eq!(r, NumericResult::Real(2.0));
    }

    #[test]
    fn complex_with_zero_im_simplifies_to_real() {
        let c = NumericResult::Complex(Complex::new(5.0, 0.0));
        let simplified = c.simplify();
        assert_eq!(simplified, NumericResult::Real(5.0));
    }

    #[test]
    fn pow_real_real() {
        let r = NumericResult::Real(2.0).pow(NumericResult::Real(3.0));
        assert_eq!(r, NumericResult::Real(8.0));
    }

    #[test]
    fn pow_negative_base_fractional_exp_promotes() {
        // The principal cube root of -8 is 2·e^{iπ/3} = 1 + √3·i.
        let r = NumericResult::Real(-8.0).pow(NumericResult::Real(1.0 / 3.0));
        match r {
            NumericResult::Complex(c) => {
                assert_abs_diff_eq!(c.re, 1.0, epsilon = 1e-12);
                assert_abs_diff_eq!(c.im, 3.0_f64.sqrt(), epsilon = 1e-12);
            }
            _ => panic!("expected complex"),
        }
    }

    #[test]
    fn pow_imaginary_unit_squared_is_real() {
        // i² = -1; the rounding residue in the imaginary part is below the simplify threshold.
        let r = NumericResult::Complex(Complex::new(0.0, 1.0)).pow(NumericResult::Real(2.0));
        match r {
            NumericResult::Real(x) => assert_abs_diff_eq!(x, -1.0, epsilon = 1e-12),
            _ => panic!("expected real"),
        }
    }

    #[test]
    fn from_f64() {
        let r: NumericResult = 2.75.into();
        assert_eq!(r, NumericResult::Real(2.75));
    }

    #[test]
    fn from_complex() {
        let c = Complex::new(1.0, 2.0);
        let r: NumericResult = c.into();
        assert_eq!(r, NumericResult::Complex(c));
    }

    #[test]
    fn from_i64() {
        let r: NumericResult = 42i64.into();
        assert_eq!(r, NumericResult::Real(42.0));
    }

    #[test]
    fn to_f64_real() {
        assert_eq!(NumericResult::Real(3.0).to_f64(), Some(3.0));
    }

    #[test]
    fn to_f64_complex_returns_none() {
        assert_eq!(
            NumericResult::Complex(Complex::new(1.0, 2.0)).to_f64(),
            None
        );
    }

    #[test]
    fn modulo_real() {
        let r = NumericResult::Real(7.0).modulo(NumericResult::Real(3.0));
        assert_eq!(r, NumericResult::Real(1.0));
    }

    #[test]
    fn modulo_negative_dividend_keeps_its_sign() {
        // `%` on f64 is a truncated remainder: the sign follows the dividend.
        let r = NumericResult::Real(-7.0).modulo(NumericResult::Real(3.0));
        assert_eq!(r, NumericResult::Real(-1.0));
    }

    #[test]
    fn modulo_with_complex_operand_is_nan() {
        let r = NumericResult::Real(7.0).modulo(NumericResult::Complex(Complex::new(1.0, 1.0)));
        match r {
            NumericResult::Real(x) => assert!(x.is_nan()),
            _ => panic!("expected real NaN"),
        }
    }
}