reda_unit/
complex.rs

1use core::fmt;
2use std::{ops::{Add, Div, Mul, Sub}, str::FromStr};
3
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5
6use crate::Number;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq)]
9pub struct Complex {
10    pub re: Number,
11    pub im: Number,
12}
13
14impl Complex {
15    pub fn new<N1: Into<Number>, N2: Into<Number>>(re: N1, im: N2) -> Self {
16        Self { re: re.into(), im: im.into() }
17    }
18
19    pub fn parts(&self) -> (Number, Number) {
20        (self.re, self.im)
21    }
22
23    pub fn conjugate(&self) -> Self {
24        Self {
25            re: self.re,
26            im: -self.im,
27        }
28    }
29
30    pub fn norm_sqr(&self) -> Number {
31        self.re * self.re + self.im * self.im
32    }
33
34    pub fn abs(&self) -> Number {
35        (self.re.powf(2.) + self.im.powf(2.)).sqrt()
36    }
37
38    pub fn arg(&self) -> Number {
39        self.im.atan2(self.re)
40    }
41}
42
43impl Add for Complex {
44    type Output = Self;
45    fn add(self, rhs: Self) -> Self {
46        let (a, b) = self.parts();
47        let (c, d) = rhs.parts();
48        Complex::new(a + c, b + d)
49    }
50}
51
52impl Sub for Complex {
53    type Output = Self;
54    fn sub(self, rhs: Self) -> Self {
55        let (a, b) = self.parts();
56        let (c, d) = rhs.parts();
57        Complex::new(a - c, b - d)
58    }
59}
60
61impl Mul for Complex {
62    type Output = Self;
63    fn mul(self, rhs: Self) -> Self {
64        let (a, b) = self.parts();
65        let (c, d) = rhs.parts();
66        // (a + bj) * (c + dj) = (ac - bd) + (ad + bc)j
67        Complex::new(a * c - b * d, a * d + b * c)
68    }
69}
70
71impl Div for Complex {
72    type Output = Self;
73    fn div(self, rhs: Self) -> Self {
74        let (a, b) = self.parts();
75        let (c, d) = rhs.parts();
76        let denom = c * c + d * d;
77        if denom.is_zero() {
78            panic!("Divide by zero in complex division");
79        }
80        let re = (a * c + b * d) / denom;
81        let im = (b * c - a * d) / denom;
82        Complex::new(re, im)
83    }
84}
85
86impl fmt::Display for Complex {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        let re_is_zero = self.re.to_f64() == 0.0;
89        let im_is_zero = self.im.to_f64() == 0.0;
90
91        let precision = f.precision();
92
93        match (re_is_zero, im_is_zero) {
94            (true, true) => write!(f, "0"),
95            (false, true) => {
96                match precision {
97                    Some(p) => write!(f, "{:.*}", p, self.re),
98                    None => write!(f, "{}", self.re),
99                }
100            }
101            (true, false) => {
102                if self.im.to_f64() >= 0.0 {
103                    match precision {
104                        Some(p) => write!(f, "{:.*}j", p, self.im),
105                        None => write!(f, "{}j", self.im),
106                    }
107                } else {
108                    match precision {
109                        Some(p) => write!(f, "-{:.*}j", p, -self.im),
110                        None => write!(f, "-{}j", -self.im),
111                    }
112                }
113            }
114            (false, false) => {
115                let re_fmt = match precision {
116                    Some(p) => format!("{:.*}", p, self.re),
117                    None => format!("{}", self.re),
118                };
119                let im_fmt = match precision {
120                    Some(p) => format!("{:.*}", p, self.im),
121                    None => format!("{}", self.im),
122                };
123
124                if self.im.to_f64() >= 0.0 {
125                    write!(f, "{}+{}j", re_fmt, im_fmt)
126                } else {
127                    write!(f, "{}-{}j", re_fmt, im_fmt.trim_start_matches('-'))
128                }
129            }
130        }
131    }
132}
133
134impl FromStr for Complex {
135    type Err = String;
136
137    fn from_str(s: &str) -> Result<Self, Self::Err> {
138        fn find_real_imag_separator(s: &str) -> Option<usize> {
139            let mut chars = s.char_indices().peekable();
140            chars.next(); 
141
142            while let Some((i, c)) = chars.next() {
143                if (c == '+' || c == '-') && s[i+1..].contains('j') {
144                    return Some(i);
145                }
146            }
147            None
148        }
149        
150        let s = s.trim();
151        
152        // '+'  '-' 
153        if let Some(idx) = find_real_imag_separator(s) {
154            let (real_part, imag_part) = s.split_at(idx);
155            let real = real_part.trim().parse::<Number>()
156                .map_err(|e| format!("Parse real part error: {}", e))?;
157            let imag_str = imag_part.trim_end_matches('j');
158            let imag = imag_str.parse::<Number>()
159                .map_err(|e| format!("Parse imaginary part error: {}", e))?;
160            return Ok(Complex { re: real, im: imag });
161        }
162        
163        // 1.2j、3uj
164        if s.ends_with('j') {
165            let imag_part = &s[..s.len()-1];
166            let im = imag_part.parse::<Number>()
167                .map_err(|e| format!("Parse imaginary part error: {}", e))?;
168            return Ok(Complex { re: Number::zero(), im });
169        }
170
171        let re = s.parse::<Number>()
172            .map_err(|e| format!("Parse real number error: {}", e))?;
173        Ok(Complex { re, im: Number::zero() })
174    }
175}
176
177impl Serialize for Complex {
178    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
179    where
180        S: Serializer,
181    {
182        let s = self.to_string();
183        serializer.serialize_str(&s)
184    }
185}
186
187impl<'de> Deserialize<'de> for Complex {
188    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
189    where
190        D: Deserializer<'de>,
191    {
192        let s = String::deserialize(deserializer)?;
193        Complex::from_str(&s).map_err(serde::de::Error::custom)
194    }
195}
196
#[cfg(test)]
mod tests {
    use crate::{complex, num, Suffix};

    use super::*;

    #[test]
    fn test_real_only() {
        let c = Complex::from_str("1.5").unwrap();
        assert_eq!(c.re, Number::new(1.5, Suffix::None));
        assert_eq!(c.im, Number::new(0.0, Suffix::None));

        let c = Complex::from_str("2.2u").unwrap();
        assert_eq!(c.re, Number::new(2.2, Suffix::Micro));
        assert_eq!(c.im, Number::new(0.0, Suffix::None));
    }

    #[test]
    fn test_imag_only() {
        let c = Complex::from_str("+3.3j").unwrap();
        assert_eq!(c.re, Number::new(0.0, Suffix::None));
        assert_eq!(c.im, Number::new(3.3, Suffix::None));

        let c = Complex::from_str("-5.5mj").unwrap();
        assert_eq!(c.re, Number::new(0.0, Suffix::None));
        assert_eq!(c.im, Number::new(-5.5, Suffix::Milli));

        let c = Complex::from_str("5.5mj").unwrap();
        assert_eq!(c.re, Number::new(0.0, Suffix::None));
        assert_eq!(c.im, Number::new(5.5, Suffix::Milli));
    }

    #[test]
    fn test_real_imag() {
        let c = Complex::from_str("1.1+2.2j").unwrap();
        assert_eq!(c.re, Number::new(1.1, Suffix::None));
        assert_eq!(c.im, Number::new(2.2, Suffix::None));

        let c = Complex::from_str("-3.0-4.4uj").unwrap();
        assert_eq!(c.re, Number::new(-3.0, Suffix::None));
        assert_eq!(c.im, Number::new(-4.4, Suffix::Micro));

        let c = Complex::from_str("10.5-7.5nj").unwrap();
        assert_eq!(c.re, Number::new(10.5, Suffix::None));
        assert_eq!(c.im, Number::new(-7.5, Suffix::Nano));
    }

    #[test]
    fn test_error_cases() {
        assert!(Complex::from_str("hello").is_err());
        assert!(Complex::from_str("1.2+badj").is_err());
        assert!(Complex::from_str("1.2+3.3").is_err());
        assert!(Complex::from_str("j3.3").is_err());
    }

    #[test]
    fn test_creation() {
        let c = Complex { re: num!(3.0), im: num!(4.0) };
        // Only checks that the `complex!` macro expands and compiles.
        let _ = complex!(3.0, 4.0);
        assert_eq!(c.re, num!(3.0));
        assert_eq!(c.im, num!(4.0));
    }

    #[test]
    fn test_equality() {
        let a = Complex { re: num!(1.0), im: num!(2.0) };
        let b = Complex { re: num!(1.0), im: num!(2.0) };
        let c = Complex { re: num!(1.0), im: num!(3.0) };
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_addition() {
        let a = Complex { re: num!(1.0), im: num!(2.0) };
        let b = Complex { re: num!(3.0), im: num!(4.0) };
        let sum = Complex { re: num!(4.0), im: num!(6.0) };
        assert_eq!(a + b, sum);
    }

    #[test]
    fn test_subtraction() {
        let a = Complex { re: num!(4.0), im: num!(6.0) };
        let b = Complex { re: num!(3.0), im: num!(4.0) };
        let diff = Complex { re: num!(1.0), im: num!(2.0) };
        assert_eq!(a - b, diff);
    }

    #[test]
    fn test_multiplication() {
        let a = Complex { re: num!(1.0), im: num!(2.0) };
        let b = Complex { re: num!(3.0), im: num!(4.0) };
        // (1 + 2j)(3 + 4j) = (3 - 8) + (4 + 6)j = -5 + 10j
        let product = Complex { re: num!(-5.0), im: num!(10.0) };
        assert_eq!(a * b, product);
    }

    #[test]
    fn test_division() {
        let a = Complex { re: num!(-5.0), im: num!(10.0) };
        let b = Complex { re: num!(3.0), im: num!(4.0) };
        // Inverse of test_multiplication: (-5 + 10j) / (3 + 4j) = 1 + 2j
        let quotient = Complex { re: num!(1.0), im: num!(2.0) };
        assert_eq!(a / b, quotient);
    }

    #[test]
    #[should_panic(expected = "Divide by zero in complex division")]
    fn test_division_by_zero_panics() {
        let a = Complex { re: num!(1.0), im: num!(1.0) };
        let zero = Complex { re: num!(0.0), im: num!(0.0) };
        let _ = a / zero;
    }

    #[test]
    fn test_conjugate() {
        let a = Complex { re: num!(5.0), im: num!(-7.0) };
        let conj = Complex { re: num!(5.0), im: num!(7.0) };
        assert_eq!(a.conjugate(), conj);
    }

    #[test]
    fn test_magnitude_squared() {
        let c = Complex { re: num!(3.0), im: num!(4.0) };
        assert_eq!(c.norm_sqr(), num!(25.0));
    }

    #[test]
    fn test_serialize_deserialize_complex_real_only() {
        let c = Complex::from_str("3.3u").unwrap();
        let json = serde_json::to_string(&c).unwrap();
        assert_eq!(json, "\"3.3u\"");
        let parsed: Complex = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn test_serialize_deserialize_complex_imag_only() {
        let c = Complex::from_str("2.2mj").unwrap();
        let json = serde_json::to_string(&c).unwrap();
        assert_eq!(json, "\"2.2mj\"");
        let parsed: Complex = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn test_serialize_deserialize_complex_full() {
        let c = Complex::from_str("1.5+2.5uj").unwrap();
        let json = serde_json::to_string(&c).unwrap();
        // Complex values always serialize with a lowercase `j`.
        assert_eq!(json, "\"1.5+2.5uj\"");
        let parsed: Complex = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, c);
    }
}