1use num_complex::Complex;
2use std::ops::{Add, Div, Mul, Neg, Sub};
3
4#[derive(Debug, Clone, Copy, PartialEq)]
6#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7pub enum NumericResult {
8 Real(f64),
9 Complex(Complex<f64>),
10}
11
12impl NumericResult {
13 pub fn is_complex(&self) -> bool {
14 matches!(self, NumericResult::Complex(_))
15 }
16
17 pub fn to_complex(self) -> Complex<f64> {
18 match self {
19 NumericResult::Real(r) => Complex::new(r, 0.0),
20 NumericResult::Complex(c) => c,
21 }
22 }
23
24 pub fn to_f64(self) -> Option<f64> {
25 match self {
26 NumericResult::Real(r) => Some(r),
27 NumericResult::Complex(_) => None,
28 }
29 }
30
31 pub fn pow(self, exp: NumericResult) -> NumericResult {
32 match (self, exp) {
33 (NumericResult::Real(base), NumericResult::Real(e)) => {
34 let result = base.powf(e);
35 if result.is_nan() && base < 0.0 {
36 let c = Complex::new(base, 0.0).powc(Complex::new(e, 0.0));
38 NumericResult::Complex(c).simplify()
39 } else {
40 NumericResult::Real(result)
41 }
42 }
43 (base, exp) => {
44 let c = base.to_complex().powc(exp.to_complex());
45 NumericResult::Complex(c).simplify()
46 }
47 }
48 }
49
50 pub fn modulo(self, rhs: NumericResult) -> NumericResult {
51 match (self, rhs) {
52 (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a % b),
53 _ => {
54 NumericResult::Real(f64::NAN)
56 }
57 }
58 }
59
60 pub fn sqrt(self) -> NumericResult {
61 match self {
62 NumericResult::Real(r) if r >= 0.0 => NumericResult::Real(r.sqrt()),
63 NumericResult::Real(r) => {
64 NumericResult::Complex(Complex::new(0.0, (-r).sqrt())).simplify()
65 }
66 NumericResult::Complex(c) => NumericResult::Complex(c.sqrt()).simplify(),
67 }
68 }
69
70 fn simplify(self) -> NumericResult {
71 if let NumericResult::Complex(c) = self {
72 if c.im.abs() < 1e-15 {
73 return NumericResult::Real(c.re);
74 }
75 }
76 self
77 }
78}
79
80impl From<f64> for NumericResult {
81 fn from(v: f64) -> Self {
82 NumericResult::Real(v)
83 }
84}
85
86impl From<Complex<f64>> for NumericResult {
87 fn from(v: Complex<f64>) -> Self {
88 NumericResult::Complex(v)
89 }
90}
91
92impl From<i64> for NumericResult {
93 fn from(v: i64) -> Self {
94 NumericResult::Real(v as f64)
95 }
96}
97
98impl Add for NumericResult {
99 type Output = NumericResult;
100
101 fn add(self, rhs: NumericResult) -> NumericResult {
102 match (self, rhs) {
103 (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a + b),
104 (a, b) => NumericResult::Complex(a.to_complex() + b.to_complex()).simplify(),
105 }
106 }
107}
108
109impl Sub for NumericResult {
110 type Output = NumericResult;
111
112 fn sub(self, rhs: NumericResult) -> NumericResult {
113 match (self, rhs) {
114 (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a - b),
115 (a, b) => NumericResult::Complex(a.to_complex() - b.to_complex()).simplify(),
116 }
117 }
118}
119
120impl Mul for NumericResult {
121 type Output = NumericResult;
122
123 fn mul(self, rhs: NumericResult) -> NumericResult {
124 match (self, rhs) {
125 (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a * b),
126 (a, b) => NumericResult::Complex(a.to_complex() * b.to_complex()).simplify(),
127 }
128 }
129}
130
131impl Div for NumericResult {
132 type Output = NumericResult;
133
134 fn div(self, rhs: NumericResult) -> NumericResult {
135 match (self, rhs) {
136 (NumericResult::Real(a), NumericResult::Real(b)) => NumericResult::Real(a / b),
137 (a, b) => NumericResult::Complex(a.to_complex() / b.to_complex()).simplify(),
138 }
139 }
140}
141
142impl Neg for NumericResult {
143 type Output = NumericResult;
144
145 fn neg(self) -> NumericResult {
146 match self {
147 NumericResult::Real(r) => NumericResult::Real(-r),
148 NumericResult::Complex(c) => NumericResult::Complex(-c),
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn real_add_real_stays_real() {
        let r = NumericResult::Real(2.0) + NumericResult::Real(3.0);
        assert_eq!(r, NumericResult::Real(5.0));
    }

    #[test]
    fn real_add_complex_promotes() {
        let r = NumericResult::Real(1.0) + NumericResult::Complex(Complex::new(2.0, 3.0));
        assert_eq!(r, NumericResult::Complex(Complex::new(3.0, 3.0)));
    }

    #[test]
    fn cancelling_imaginary_parts_simplify_to_real() {
        let r = NumericResult::Complex(Complex::new(1.0, 2.0))
            + NumericResult::Complex(Complex::new(1.0, -2.0));
        assert_eq!(r, NumericResult::Real(2.0));
    }

    #[test]
    fn real_sub_real() {
        let r = NumericResult::Real(5.0) - NumericResult::Real(3.0);
        assert_eq!(r, NumericResult::Real(2.0));
    }

    #[test]
    fn real_sub_complex_promotes() {
        let r = NumericResult::Real(5.0) - NumericResult::Complex(Complex::new(2.0, 3.0));
        assert_eq!(r, NumericResult::Complex(Complex::new(3.0, -3.0)));
    }

    #[test]
    fn real_mul_real() {
        let r = NumericResult::Real(3.0) * NumericResult::Real(4.0);
        assert_eq!(r, NumericResult::Real(12.0));
    }

    #[test]
    fn real_mul_complex_scales_both_components() {
        let z = NumericResult::Complex(Complex::new(1.0, 3.0));
        let expected = NumericResult::Complex(Complex::new(2.0, 6.0));
        assert_eq!(NumericResult::Real(2.0) * z, expected);
        assert_eq!(z * NumericResult::Real(2.0), expected);
    }

    #[test]
    fn real_div_real() {
        let r = NumericResult::Real(10.0) / NumericResult::Real(4.0);
        assert_eq!(r, NumericResult::Real(2.5));
    }

    #[test]
    fn complex_div_real() {
        let r = NumericResult::Complex(Complex::new(2.0, 4.0)) / NumericResult::Real(2.0);
        assert_eq!(r, NumericResult::Complex(Complex::new(1.0, 2.0)));
    }

    #[test]
    fn neg_real() {
        let r = -NumericResult::Real(5.0);
        assert_eq!(r, NumericResult::Real(-5.0));
    }

    #[test]
    fn neg_complex() {
        let r = -NumericResult::Complex(Complex::new(1.0, 2.0));
        assert_eq!(r, NumericResult::Complex(Complex::new(-1.0, -2.0)));
    }

    #[test]
    fn complex_mul_complex() {
        let a = NumericResult::Complex(Complex::new(1.0, 2.0));
        let b = NumericResult::Complex(Complex::new(3.0, 4.0));
        let r = a * b;
        assert_eq!(r, NumericResult::Complex(Complex::new(-5.0, 10.0)));
    }

    #[test]
    fn sqrt_negative_returns_complex() {
        let r = NumericResult::Real(-1.0).sqrt();
        match r {
            NumericResult::Complex(c) => {
                assert_abs_diff_eq!(c.re, 0.0, epsilon = 1e-15);
                assert_abs_diff_eq!(c.im, 1.0, epsilon = 1e-15);
            }
            _ => panic!("expected complex"),
        }
    }

    #[test]
    fn sqrt_positive_stays_real() {
        let r = NumericResult::Real(4.0).sqrt();
        assert_eq!(r, NumericResult::Real(2.0));
    }

    #[test]
    fn complex_with_zero_im_simplifies_to_real() {
        let c = NumericResult::Complex(Complex::new(5.0, 0.0));
        let simplified = c.simplify();
        assert_eq!(simplified, NumericResult::Real(5.0));
    }

    #[test]
    fn pow_real_real() {
        let r = NumericResult::Real(2.0).pow(NumericResult::Real(3.0));
        assert_eq!(r, NumericResult::Real(8.0));
    }

    #[test]
    fn pow_negative_base_fractional_exp_promotes() {
        let r = NumericResult::Real(-8.0).pow(NumericResult::Real(1.0 / 3.0));
        match r {
            // Principal cube root of -8 is 2 * e^(i*pi/3) = 1 + i*sqrt(3).
            NumericResult::Complex(c) => {
                assert_abs_diff_eq!(c.re, 1.0, epsilon = 1e-12);
                assert_abs_diff_eq!(c.im, 3.0_f64.sqrt(), epsilon = 1e-12);
            }
            _ => panic!("expected complex"),
        }
    }

    #[test]
    fn pow_complex_base_collapses_to_real() {
        let r = NumericResult::Complex(Complex::new(0.0, 1.0)).pow(NumericResult::Real(2.0));
        match r {
            // i^2 = -1; the rounding noise in the imaginary part is simplified away.
            NumericResult::Real(x) => assert_abs_diff_eq!(x, -1.0, epsilon = 1e-12),
            _ => panic!("expected real"),
        }
    }

    #[test]
    fn from_f64() {
        let r: NumericResult = 2.75.into();
        assert_eq!(r, NumericResult::Real(2.75));
    }

    #[test]
    fn from_complex() {
        let c = Complex::new(1.0, 2.0);
        let r: NumericResult = c.into();
        assert_eq!(r, NumericResult::Complex(c));
    }

    #[test]
    fn from_i64() {
        let r: NumericResult = 42i64.into();
        assert_eq!(r, NumericResult::Real(42.0));
    }

    #[test]
    fn to_f64_real() {
        assert_eq!(NumericResult::Real(3.0).to_f64(), Some(3.0));
    }

    #[test]
    fn to_f64_complex_returns_none() {
        assert_eq!(
            NumericResult::Complex(Complex::new(1.0, 2.0)).to_f64(),
            None
        );
    }

    #[test]
    fn to_complex_real_has_zero_imaginary() {
        assert_eq!(NumericResult::Real(2.0).to_complex(), Complex::new(2.0, 0.0));
    }

    #[test]
    fn modulo_real() {
        let r = NumericResult::Real(7.0).modulo(NumericResult::Real(3.0));
        assert_eq!(r, NumericResult::Real(1.0));
    }

    #[test]
    fn modulo_with_complex_operand_is_nan() {
        let r = NumericResult::Complex(Complex::new(1.0, 2.0)).modulo(NumericResult::Real(3.0));
        assert!(matches!(r, NumericResult::Real(x) if x.is_nan()));
    }
}