1use numra_core::Scalar;
8use std::ops::{Add, Mul, Neg, Sub};
9
10#[derive(Clone, Copy, Debug, PartialEq)]
12pub struct Complex<S: Scalar> {
13 pub re: S,
14 pub im: S,
15}
16
17pub type ComplexF64 = Complex<f64>;
19
20impl<S: Scalar> Complex<S> {
21 pub fn new(re: S, im: S) -> Self {
23 Self { re, im }
24 }
25
26 pub fn zero() -> Self {
28 Self {
29 re: S::ZERO,
30 im: S::ZERO,
31 }
32 }
33
34 pub fn norm_sqr(&self) -> S {
36 self.re * self.re + self.im * self.im
37 }
38
39 pub fn abs(&self) -> S {
41 self.norm_sqr().sqrt()
42 }
43
44 pub fn conj(&self) -> Self {
46 Self {
47 re: self.re,
48 im: -self.im,
49 }
50 }
51
52 pub fn arg(&self) -> S {
54 self.im.atan2(self.re)
55 }
56}
57
58impl<S: Scalar> Add for Complex<S> {
59 type Output = Self;
60 fn add(self, rhs: Self) -> Self {
61 Self {
62 re: self.re + rhs.re,
63 im: self.im + rhs.im,
64 }
65 }
66}
67
68impl<S: Scalar> Sub for Complex<S> {
69 type Output = Self;
70 fn sub(self, rhs: Self) -> Self {
71 Self {
72 re: self.re - rhs.re,
73 im: self.im - rhs.im,
74 }
75 }
76}
77
78impl<S: Scalar> Mul for Complex<S> {
79 type Output = Self;
80 fn mul(self, rhs: Self) -> Self {
81 Self {
82 re: self.re * rhs.re - self.im * rhs.im,
83 im: self.re * rhs.im + self.im * rhs.re,
84 }
85 }
86}
87
88impl<S: Scalar> Mul<S> for Complex<S> {
89 type Output = Self;
90 fn mul(self, rhs: S) -> Self {
91 Self {
92 re: self.re * rhs,
93 im: self.im * rhs,
94 }
95 }
96}
97
98impl<S: Scalar> Neg for Complex<S> {
99 type Output = Self;
100 fn neg(self) -> Self {
101 Self {
102 re: -self.re,
103 im: -self.im,
104 }
105 }
106}
107
108impl From<rustfft::num_complex::Complex<f64>> for Complex<f64> {
110 fn from(c: rustfft::num_complex::Complex<f64>) -> Self {
111 Self { re: c.re, im: c.im }
112 }
113}
114
115impl From<Complex<f64>> for rustfft::num_complex::Complex<f64> {
117 fn from(c: Complex<f64>) -> Self {
118 Self::new(c.re, c.im)
119 }
120}
121
122impl From<faer::complex_native::c64> for Complex<f64> {
124 fn from(c: faer::complex_native::c64) -> Self {
125 Self { re: c.re, im: c.im }
126 }
127}
128
129impl From<Complex<f64>> for faer::complex_native::c64 {
131 fn from(c: Complex<f64>) -> faer::complex_native::c64 {
132 faer::complex_native::c64 { re: c.re, im: c.im }
133 }
134}
135
136impl From<faer::complex_native::c32> for Complex<f32> {
138 fn from(c: faer::complex_native::c32) -> Self {
139 Self { re: c.re, im: c.im }
140 }
141}
142
143impl From<Complex<f32>> for faer::complex_native::c32 {
145 fn from(c: Complex<f32>) -> faer::complex_native::c32 {
146 faer::complex_native::c32 { re: c.re, im: c.im }
147 }
148}
149
150#[cfg(test)]
151mod tests {
152 use super::*;
153
154 #[test]
155 fn test_basic_ops() {
156 let a = Complex::new(1.0, 2.0);
157 let b = Complex::new(3.0, 4.0);
158
159 let sum = a + b;
160 assert!((sum.re - 4.0).abs() < 1e-14);
161 assert!((sum.im - 6.0).abs() < 1e-14);
162
163 let prod = a * b;
164 assert!((prod.re - (-5.0)).abs() < 1e-14);
166 assert!((prod.im - 10.0).abs() < 1e-14);
167 }
168
169 #[test]
170 fn test_magnitude() {
171 let c = Complex::new(3.0, 4.0);
172 assert!((c.abs() - 5.0).abs() < 1e-14);
173 assert!((c.norm_sqr() - 25.0).abs() < 1e-14);
174 }
175
176 #[test]
177 fn test_conjugate() {
178 let c = Complex::new(1.0, 2.0);
179 let cc = c.conj();
180 assert!((cc.re - 1.0).abs() < 1e-14);
181 assert!((cc.im - (-2.0)).abs() < 1e-14);
182 }
183
184 #[test]
185 fn test_faer_c64_roundtrip() {
186 let orig = Complex::new(3.14, -2.71);
187 let c64_val: faer::complex_native::c64 = orig.into();
188 let back: Complex<f64> = c64_val.into();
189 assert!((back.re - orig.re).abs() < 1e-14);
190 assert!((back.im - orig.im).abs() < 1e-14);
191 }
192
193 #[test]
194 fn test_faer_c32_roundtrip() {
195 let orig: Complex<f32> = Complex::new(1.5_f32, -0.5_f32);
196 let c32_val: faer::complex_native::c32 = orig.into();
197 let back: Complex<f32> = c32_val.into();
198 assert!((back.re - 1.5_f32).abs() < 1e-6);
199 assert!((back.im - (-0.5_f32)).abs() < 1e-6);
200 }
201
202 #[test]
203 fn test_f32_complex_basic() {
204 let a: Complex<f32> = Complex::new(1.0_f32, 2.0_f32);
205 let b: Complex<f32> = Complex::new(3.0_f32, 4.0_f32);
206 let sum = a + b;
207 assert!((sum.re - 4.0_f32).abs() < 1e-6);
208 assert!((sum.im - 6.0_f32).abs() < 1e-6);
209 }
210}