1use core::fmt;
2use std::{ops::{Add, Div, Mul, Sub}, str::FromStr};
3
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5
6use crate::Number;
7
/// A complex number whose real and imaginary components are [`Number`]s.
///
/// The textual form used by the `Display` / `FromStr` impls in this module writes the
/// imaginary unit as a trailing `j`, e.g. `1.5+2.5uj`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Complex {
    /// Real component.
    pub re: Number,
    /// Imaginary component (the coefficient of `j`).
    pub im: Number,
}
13
14impl Complex {
15 pub fn new<N1: Into<Number>, N2: Into<Number>>(re: N1, im: N2) -> Self {
16 Self { re: re.into(), im: im.into() }
17 }
18
19 pub fn parts(&self) -> (Number, Number) {
20 (self.re, self.im)
21 }
22
23 pub fn conjugate(&self) -> Self {
24 Self {
25 re: self.re,
26 im: -self.im,
27 }
28 }
29
30 pub fn norm_sqr(&self) -> Number {
31 self.re * self.re + self.im * self.im
32 }
33
34 pub fn abs(&self) -> Number {
35 (self.re.powf(2.) + self.im.powf(2.)).sqrt()
36 }
37
38 pub fn arg(&self) -> Number {
39 self.im.atan2(self.re)
40 }
41}
42
43impl Add for Complex {
44 type Output = Self;
45 fn add(self, rhs: Self) -> Self {
46 let (a, b) = self.parts();
47 let (c, d) = rhs.parts();
48 Complex::new(a + c, b + d)
49 }
50}
51
52impl Sub for Complex {
53 type Output = Self;
54 fn sub(self, rhs: Self) -> Self {
55 let (a, b) = self.parts();
56 let (c, d) = rhs.parts();
57 Complex::new(a - c, b - d)
58 }
59}
60
61impl Mul for Complex {
62 type Output = Self;
63 fn mul(self, rhs: Self) -> Self {
64 let (a, b) = self.parts();
65 let (c, d) = rhs.parts();
66 Complex::new(a * c - b * d, a * d + b * c)
68 }
69}
70
71impl Div for Complex {
72 type Output = Self;
73 fn div(self, rhs: Self) -> Self {
74 let (a, b) = self.parts();
75 let (c, d) = rhs.parts();
76 let denom = c * c + d * d;
77 if denom.is_zero() {
78 panic!("Divide by zero in complex division");
79 }
80 let re = (a * c + b * d) / denom;
81 let im = (b * c - a * d) / denom;
82 Complex::new(re, im)
83 }
84}
85
86impl fmt::Display for Complex {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 let re_is_zero = self.re.to_f64() == 0.0;
89 let im_is_zero = self.im.to_f64() == 0.0;
90
91 let precision = f.precision();
92
93 match (re_is_zero, im_is_zero) {
94 (true, true) => write!(f, "0"),
95 (false, true) => {
96 match precision {
97 Some(p) => write!(f, "{:.*}", p, self.re),
98 None => write!(f, "{}", self.re),
99 }
100 }
101 (true, false) => {
102 if self.im.to_f64() >= 0.0 {
103 match precision {
104 Some(p) => write!(f, "{:.*}j", p, self.im),
105 None => write!(f, "{}j", self.im),
106 }
107 } else {
108 match precision {
109 Some(p) => write!(f, "-{:.*}j", p, -self.im),
110 None => write!(f, "-{}j", -self.im),
111 }
112 }
113 }
114 (false, false) => {
115 let re_fmt = match precision {
116 Some(p) => format!("{:.*}", p, self.re),
117 None => format!("{}", self.re),
118 };
119 let im_fmt = match precision {
120 Some(p) => format!("{:.*}", p, self.im),
121 None => format!("{}", self.im),
122 };
123
124 if self.im.to_f64() >= 0.0 {
125 write!(f, "{}+{}j", re_fmt, im_fmt)
126 } else {
127 write!(f, "{}-{}j", re_fmt, im_fmt.trim_start_matches('-'))
128 }
129 }
130 }
131 }
132}
133
134impl FromStr for Complex {
135 type Err = String;
136
137 fn from_str(s: &str) -> Result<Self, Self::Err> {
138 fn find_real_imag_separator(s: &str) -> Option<usize> {
139 let mut chars = s.char_indices().peekable();
140 chars.next();
141
142 while let Some((i, c)) = chars.next() {
143 if (c == '+' || c == '-') && s[i+1..].contains('j') {
144 return Some(i);
145 }
146 }
147 None
148 }
149
150 let s = s.trim();
151
152 if let Some(idx) = find_real_imag_separator(s) {
154 let (real_part, imag_part) = s.split_at(idx);
155 let real = real_part.trim().parse::<Number>()
156 .map_err(|e| format!("Parse real part error: {}", e))?;
157 let imag_str = imag_part.trim_end_matches('j');
158 let imag = imag_str.parse::<Number>()
159 .map_err(|e| format!("Parse imaginary part error: {}", e))?;
160 return Ok(Complex { re: real, im: imag });
161 }
162
163 if s.ends_with('j') {
165 let imag_part = &s[..s.len()-1];
166 let im = imag_part.parse::<Number>()
167 .map_err(|e| format!("Parse imaginary part error: {}", e))?;
168 return Ok(Complex { re: Number::zero(), im });
169 }
170
171 let re = s.parse::<Number>()
172 .map_err(|e| format!("Parse real number error: {}", e))?;
173 Ok(Complex { re, im: Number::zero() })
174 }
175}
176
177impl Serialize for Complex {
178 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
179 where
180 S: Serializer,
181 {
182 let s = self.to_string();
183 serializer.serialize_str(&s)
184 }
185}
186
187impl<'de> Deserialize<'de> for Complex {
188 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
189 where
190 D: Deserializer<'de>,
191 {
192 let s = String::deserialize(deserializer)?;
193 Complex::from_str(&s).map_err(serde::de::Error::custom)
194 }
195}
196
#[cfg(test)]
mod tests {
    use crate::{complex, num, Suffix};

    use super::*;

    #[test]
    fn test_real_only() {
        let c = Complex::from_str("1.5").unwrap();
        assert_eq!(c.re, Number::new(1.5, Suffix::None));
        assert_eq!(c.im, Number::new(0.0, Suffix::None));

        let c = Complex::from_str("2.2u").unwrap();
        assert_eq!(c.re, Number::new(2.2, Suffix::Micro));
        assert_eq!(c.im, Number::new(0.0, Suffix::None));
    }

    #[test]
    fn test_imag_only() {
        let c = Complex::from_str("+3.3j").unwrap();
        assert_eq!(c.re, Number::new(0.0, Suffix::None));
        assert_eq!(c.im, Number::new(3.3, Suffix::None));

        let c = Complex::from_str("-5.5mj").unwrap();
        assert_eq!(c.re, Number::new(0.0, Suffix::None));
        assert_eq!(c.im, Number::new(-5.5, Suffix::Milli));

        let c = Complex::from_str("5.5mj").unwrap();
        assert_eq!(c.re, Number::new(0.0, Suffix::None));
        assert_eq!(c.im, Number::new(5.5, Suffix::Milli));
    }

    #[test]
    fn test_real_imag() {
        let c = Complex::from_str("1.1+2.2j").unwrap();
        assert_eq!(c.re, Number::new(1.1, Suffix::None));
        assert_eq!(c.im, Number::new(2.2, Suffix::None));

        let c = Complex::from_str("-3.0-4.4uj").unwrap();
        assert_eq!(c.re, Number::new(-3.0, Suffix::None));
        assert_eq!(c.im, Number::new(-4.4, Suffix::Micro));

        let c = Complex::from_str("10.5-7.5nj").unwrap();
        assert_eq!(c.re, Number::new(10.5, Suffix::None));
        assert_eq!(c.im, Number::new(-7.5, Suffix::Nano));
    }

    #[test]
    fn test_error_cases() {
        assert!(Complex::from_str("hello").is_err());
        assert!(Complex::from_str("1.2+badj").is_err());
        assert!(Complex::from_str("1.2+3.3").is_err());
        assert!(Complex::from_str("j3.3").is_err());
    }

    #[test]
    fn test_creation() {
        let c = Complex { re: num!(3.0), im: num!(4.0) };
        let _ = complex!(3.0, 4.0);
        assert_eq!(c.re, num!(3.0));
        assert_eq!(c.im, num!(4.0));
    }

    #[test]
    fn test_equality() {
        let a = Complex { re: num!(1.0), im: num!(2.0) };
        let b = Complex { re: num!(1.0), im: num!(2.0) };
        let c = Complex { re: num!(1.0), im: num!(3.0) };
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_addition() {
        let a = Complex { re: num!(1.0), im: num!(2.0) };
        let b = Complex { re: num!(3.0), im: num!(4.0) };
        let sum = Complex { re: num!(4.0), im: num!(6.0) };
        assert_eq!(a + b, sum);
    }

    #[test]
    fn test_subtraction() {
        let a = Complex { re: num!(4.0), im: num!(6.0) };
        let b = Complex { re: num!(3.0), im: num!(4.0) };
        let difference = Complex { re: num!(1.0), im: num!(2.0) };
        assert_eq!(a - b, difference);
    }

    #[test]
    fn test_multiplication() {
        let a = Complex { re: num!(1.0), im: num!(2.0) };
        let b = Complex { re: num!(3.0), im: num!(4.0) };
        let product = Complex { re: num!(-5.0), im: num!(10.0) };
        assert_eq!(a * b, product);
    }

    #[test]
    fn test_division() {
        // (-5+10j) / (3+4j) = 1+2j; every intermediate value (25, 25, 50) is exact.
        let a = Complex { re: num!(-5.0), im: num!(10.0) };
        let b = Complex { re: num!(3.0), im: num!(4.0) };
        let quotient = Complex { re: num!(1.0), im: num!(2.0) };
        assert_eq!(a / b, quotient);
    }

    #[test]
    #[should_panic(expected = "Divide by zero in complex division")]
    fn test_division_by_zero_panics() {
        let a = Complex { re: num!(1.0), im: num!(1.0) };
        let zero = Complex { re: num!(0.0), im: num!(0.0) };
        let _ = a / zero;
    }

    #[test]
    fn test_conjugate() {
        let a = Complex { re: num!(5.0), im: num!(-7.0) };
        let conj = Complex { re: num!(5.0), im: num!(7.0) };
        assert_eq!(a.conjugate(), conj);
    }

    #[test]
    fn test_magnitude_squared() {
        let c = Complex { re: num!(3.0), im: num!(4.0) };
        assert_eq!(c.norm_sqr(), num!(25.0));
    }

    #[test]
    fn test_serialize_deserialize_complex_real_only() {
        let c = Complex::from_str("3.3u").unwrap();
        let json = serde_json::to_string(&c).unwrap();
        assert_eq!(json, "\"3.3u\"");
        let parsed: Complex = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn test_serialize_deserialize_complex_imag_only() {
        let c = Complex::from_str("2.2mj").unwrap();
        let json = serde_json::to_string(&c).unwrap();
        assert_eq!(json, "\"2.2mj\"");
        let parsed: Complex = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn test_serialize_deserialize_complex_full() {
        let c = Complex::from_str("1.5+2.5uj").unwrap();
        let json = serde_json::to_string(&c).unwrap();
        assert_eq!(json, "\"1.5+2.5uj\"");
        let parsed: Complex = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, c);
    }
}