1use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
/// A complex number `re + im*i` with real part `re` and imaginary part `im`.
///
/// `#[repr(C)]` fixes the field order (`re` first, then `im`) with C-compatible
/// layout. All derived traits carry a `T: Trait` bound, so e.g. `Eq`/`Hash` are
/// available for integer components but not for `f32`/`f64`. `Default` yields
/// `T::default()` for both parts.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Complex<T> {
    /// Real part.
    pub re: T,
    /// Imaginary part.
    pub im: T,
}
10
11impl<T> Complex<T> {
12 #[inline]
13 pub fn new(re: T, im: T) -> Self {
14 Self { re, im }
15 }
16}
17
18impl Complex<f32> {
19 #[inline]
20 pub fn zero() -> Self {
21 Self { re: 0.0, im: 0.0 }
22 }
23
24 #[inline]
25 pub fn arg(self) -> f32 {
26 self.im.atan2(self.re)
27 }
28}
29
30impl Complex<f64> {
31 #[inline]
32 pub fn zero() -> Self {
33 Self { re: 0.0, im: 0.0 }
34 }
35
36 #[inline]
37 pub fn arg(self) -> f64 {
38 self.im.atan2(self.re)
39 }
40}
41
42impl<T> Complex<T>
43where
44 T: Copy + Neg<Output = T>,
45{
46 #[inline]
47 pub fn conj(self) -> Self {
48 Self {
49 re: self.re,
50 im: -self.im,
51 }
52 }
53}
54
55impl<T> Complex<T>
56where
57 T: Copy + Add<Output = T> + Mul<Output = T>,
58{
59 #[inline]
60 pub fn norm_sqr(self) -> T {
61 self.re * self.re + self.im * self.im
62 }
63}
64
65impl<T> Neg for Complex<T>
66where
67 T: Neg<Output = T>,
68{
69 type Output = Self;
70
71 fn neg(self) -> Self::Output {
72 Self {
73 re: -self.re,
74 im: -self.im,
75 }
76 }
77}
78
79impl<T> Add for Complex<T>
80where
81 T: Add<Output = T>,
82{
83 type Output = Self;
84
85 fn add(self, rhs: Self) -> Self::Output {
86 Self {
87 re: self.re + rhs.re,
88 im: self.im + rhs.im,
89 }
90 }
91}
92
93impl<T> Sub for Complex<T>
94where
95 T: Sub<Output = T>,
96{
97 type Output = Self;
98
99 fn sub(self, rhs: Self) -> Self::Output {
100 Self {
101 re: self.re - rhs.re,
102 im: self.im - rhs.im,
103 }
104 }
105}
106
107impl<T> Mul for Complex<T>
108where
109 T: Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T>,
110{
111 type Output = Self;
112
113 fn mul(self, rhs: Self) -> Self::Output {
114 Self {
115 re: self.re * rhs.re - self.im * rhs.im,
116 im: self.re * rhs.im + self.im * rhs.re,
117 }
118 }
119}
120
121impl<T> Mul<T> for Complex<T>
122where
123 T: Copy + Mul<Output = T>,
124{
125 type Output = Self;
126
127 fn mul(self, rhs: T) -> Self::Output {
128 Self {
129 re: self.re * rhs,
130 im: self.im * rhs,
131 }
132 }
133}
134
135impl<T> Div<T> for Complex<T>
136where
137 T: Copy + Div<Output = T>,
138{
139 type Output = Self;
140
141 fn div(self, rhs: T) -> Self::Output {
142 Self {
143 re: self.re / rhs,
144 im: self.im / rhs,
145 }
146 }
147}
148
149impl<T> AddAssign for Complex<T>
150where
151 T: AddAssign,
152{
153 fn add_assign(&mut self, rhs: Self) {
154 self.re += rhs.re;
155 self.im += rhs.im;
156 }
157}
158
159impl<T> SubAssign for Complex<T>
160where
161 T: SubAssign,
162{
163 fn sub_assign(&mut self, rhs: Self) {
164 self.re -= rhs.re;
165 self.im -= rhs.im;
166 }
167}
168
169impl<T> MulAssign for Complex<T>
170where
171 T: Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T>,
172{
173 fn mul_assign(&mut self, rhs: Self) {
174 *self = *self * rhs;
175 }
176}
177
178impl<T> DivAssign<T> for Complex<T>
179where
180 T: Copy + Div<Output = T>,
181{
182 fn div_assign(&mut self, rhs: T) {
183 *self = *self / rhs;
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::Complex;

    #[test]
    fn arithmetic_matches_expected_values() {
        let a = Complex::new(1.0f64, 2.0);
        let b = Complex::new(3.0f64, -1.0);
        assert_eq!(a + b, Complex::new(4.0, 1.0));
        assert_eq!(a - b, Complex::new(-2.0, 3.0));
        assert_eq!(a * b, Complex::new(5.0, 5.0));
        assert_eq!(-a, Complex::new(-1.0, -2.0));
    }

    #[test]
    fn compound_assignment_matches_binary_operators() {
        let a = Complex::new(1.0f64, 2.0);
        let b = Complex::new(3.0f64, -1.0);

        let mut sum = a;
        sum += b;
        assert_eq!(sum, a + b);

        let mut diff = a;
        diff -= b;
        assert_eq!(diff, a - b);

        let mut prod = a;
        prod *= b;
        assert_eq!(prod, a * b);
    }

    #[test]
    fn scalar_ops_and_conjugate_work() {
        let mut z = Complex::new(2.0f32, -4.0);
        z /= 2.0;
        assert_eq!(z, Complex::new(1.0, -2.0));
        assert_eq!(z * 2.0, Complex::new(2.0, -4.0));
        assert_eq!(Complex::new(2.0f32, -4.0) / 2.0, z);
        assert_eq!(z.conj(), Complex::new(1.0, 2.0));
    }

    #[test]
    fn norm_and_phase_match_reference() {
        let z = Complex::new(3.0f64, 4.0);
        assert!((z.norm_sqr() - 25.0).abs() < 1e-12);
        assert!((z.arg() - (4.0f64).atan2(3.0)).abs() < 1e-12);

        let down = Complex::new(0.0f32, -1.0);
        assert!((down.arg() + std::f32::consts::FRAC_PI_2).abs() < 1e-6);
    }

    #[test]
    fn zero_constructs_origin() {
        assert_eq!(Complex::<f64>::zero(), Complex::new(0.0, 0.0));
        assert_eq!(Complex::<f32>::zero(), Complex::new(0.0, 0.0));
    }

    #[test]
    fn integer_components_are_supported() {
        let a = Complex::new(2i32, 3);
        let b = Complex::new(-1i32, 4);
        assert_eq!(a * b, Complex::new(-14, 5));
        assert_eq!(a.norm_sqr(), 13);
        assert_eq!(a.conj(), Complex::new(2, -3));
        assert_eq!(-b, Complex::new(1, -4));
    }
}