Skip to main content

neco_complex/
lib.rs

1use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
/// Lightweight complex number type shared by FFT- and solver-adjacent crates.
///
/// The value is stored in Cartesian form as `re + im·i`. Because the struct is
/// `#[repr(C)]`, the fields are laid out in declaration order (`re`, then `im`).
///
/// `Default` is derived so that generic code can obtain the origin
/// (`T::default()` in both components) for any component type that
/// implements `Default`.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Complex<T> {
    /// Real part.
    pub re: T,
    /// Imaginary part.
    pub im: T,
}

impl<T> Complex<T> {
    /// Creates a complex number from its real part `re` and imaginary part `im`.
    #[inline]
    pub fn new(re: T, im: T) -> Self {
        Self { re, im }
    }
}
17
18impl Complex<f32> {
19    #[inline]
20    pub fn zero() -> Self {
21        Self { re: 0.0, im: 0.0 }
22    }
23
24    #[inline]
25    pub fn arg(self) -> f32 {
26        self.im.atan2(self.re)
27    }
28}
29
30impl Complex<f64> {
31    #[inline]
32    pub fn zero() -> Self {
33        Self { re: 0.0, im: 0.0 }
34    }
35
36    #[inline]
37    pub fn arg(self) -> f64 {
38        self.im.atan2(self.re)
39    }
40}
41
42impl<T> Complex<T>
43where
44    T: Copy + Neg<Output = T>,
45{
46    #[inline]
47    pub fn conj(self) -> Self {
48        Self {
49            re: self.re,
50            im: -self.im,
51        }
52    }
53}
54
55impl<T> Complex<T>
56where
57    T: Copy + Add<Output = T> + Mul<Output = T>,
58{
59    #[inline]
60    pub fn norm_sqr(self) -> T {
61        self.re * self.re + self.im * self.im
62    }
63}
64
65impl<T> Neg for Complex<T>
66where
67    T: Neg<Output = T>,
68{
69    type Output = Self;
70
71    fn neg(self) -> Self::Output {
72        Self {
73            re: -self.re,
74            im: -self.im,
75        }
76    }
77}
78
79impl<T> Add for Complex<T>
80where
81    T: Add<Output = T>,
82{
83    type Output = Self;
84
85    fn add(self, rhs: Self) -> Self::Output {
86        Self {
87            re: self.re + rhs.re,
88            im: self.im + rhs.im,
89        }
90    }
91}
92
93impl<T> Sub for Complex<T>
94where
95    T: Sub<Output = T>,
96{
97    type Output = Self;
98
99    fn sub(self, rhs: Self) -> Self::Output {
100        Self {
101            re: self.re - rhs.re,
102            im: self.im - rhs.im,
103        }
104    }
105}
106
107impl<T> Mul for Complex<T>
108where
109    T: Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T>,
110{
111    type Output = Self;
112
113    fn mul(self, rhs: Self) -> Self::Output {
114        Self {
115            re: self.re * rhs.re - self.im * rhs.im,
116            im: self.re * rhs.im + self.im * rhs.re,
117        }
118    }
119}
120
121impl<T> Mul<T> for Complex<T>
122where
123    T: Copy + Mul<Output = T>,
124{
125    type Output = Self;
126
127    fn mul(self, rhs: T) -> Self::Output {
128        Self {
129            re: self.re * rhs,
130            im: self.im * rhs,
131        }
132    }
133}
134
135impl<T> Div<T> for Complex<T>
136where
137    T: Copy + Div<Output = T>,
138{
139    type Output = Self;
140
141    fn div(self, rhs: T) -> Self::Output {
142        Self {
143            re: self.re / rhs,
144            im: self.im / rhs,
145        }
146    }
147}
148
149impl<T> AddAssign for Complex<T>
150where
151    T: AddAssign,
152{
153    fn add_assign(&mut self, rhs: Self) {
154        self.re += rhs.re;
155        self.im += rhs.im;
156    }
157}
158
159impl<T> SubAssign for Complex<T>
160where
161    T: SubAssign,
162{
163    fn sub_assign(&mut self, rhs: Self) {
164        self.re -= rhs.re;
165        self.im -= rhs.im;
166    }
167}
168
169impl<T> MulAssign for Complex<T>
170where
171    T: Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T>,
172{
173    fn mul_assign(&mut self, rhs: Self) {
174        *self = *self * rhs;
175    }
176}
177
178impl<T> DivAssign<T> for Complex<T>
179where
180    T: Copy + Div<Output = T>,
181{
182    fn div_assign(&mut self, rhs: T) {
183        *self = *self / rhs;
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::Complex;

    /// Addition, subtraction and complex multiplication on `f64` operands.
    #[test]
    fn arithmetic_matches_expected_values() {
        let lhs = Complex::new(1.0f64, 2.0);
        let rhs = Complex::new(3.0f64, -1.0);

        let sum = lhs + rhs;
        let difference = lhs - rhs;
        let product = lhs * rhs;

        assert_eq!(sum, Complex::new(4.0, 1.0));
        assert_eq!(difference, Complex::new(-2.0, 3.0));
        // (1 + 2i)(3 - i) = 3 - i + 6i + 2 = 5 + 5i
        assert_eq!(product, Complex::new(5.0, 5.0));
    }

    /// In-place scalar division, scalar multiplication and conjugation on `f32`.
    #[test]
    fn scalar_ops_and_conjugate_work() {
        let mut value = Complex::new(2.0f32, -4.0);
        value /= 2.0;

        let halved = value;
        assert_eq!(halved, Complex::new(1.0, -2.0));
        assert_eq!(halved * 2.0, Complex::new(2.0, -4.0));
        assert_eq!(halved.conj(), Complex::new(1.0, 2.0));
    }

    /// Squared norm and phase of `3 + 4i` against directly computed references.
    #[test]
    fn norm_and_phase_match_reference() {
        let z = Complex::new(3.0f64, 4.0);
        // 3² + 4² = 25
        assert!((z.norm_sqr() - 25.0).abs() < 1e-12);
        assert!((z.arg() - (4.0f64).atan2(3.0)).abs() < 1e-12);
    }

    /// `zero()` yields the same value as constructing `0 + 0i` explicitly.
    #[test]
    fn zero_constructs_origin() {
        let origin = Complex::<f64>::zero();
        assert_eq!(origin, Complex::new(0.0, 0.0));
    }
}