Skip to main content

numra_fft/
complex.rs

//! Minimal complex number type for FFT interfaces.
//!
//! Author: Moussa Leblouba
//! Date: 9 February 2026
//! Modified: 2 May 2026
6
7use numra_core::Scalar;
8use std::ops::{Add, Mul, Neg, Sub};
9
/// A complex number `re + i·im` with real and imaginary parts of type `S`.
///
/// The struct definition places no bound on `S`; the constructors, helper
/// methods, operators and conversions are provided by `impl` blocks that
/// require `S: Scalar` (or a concrete `f32`/`f64`).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Complex<S> {
    /// Real part.
    pub re: S,
    /// Imaginary part.
    pub im: S,
}
16
17/// Type alias for backward compatibility: `Complex<f64>`.
18pub type ComplexF64 = Complex<f64>;
19
20impl<S: Scalar> Complex<S> {
21    /// Create a new complex number.
22    pub fn new(re: S, im: S) -> Self {
23        Self { re, im }
24    }
25
26    /// Complex zero.
27    pub fn zero() -> Self {
28        Self {
29            re: S::ZERO,
30            im: S::ZERO,
31        }
32    }
33
34    /// Squared magnitude |z|^2 = re^2 + im^2.
35    pub fn norm_sqr(&self) -> S {
36        self.re * self.re + self.im * self.im
37    }
38
39    /// Magnitude |z|.
40    pub fn abs(&self) -> S {
41        self.norm_sqr().sqrt()
42    }
43
44    /// Complex conjugate.
45    pub fn conj(&self) -> Self {
46        Self {
47            re: self.re,
48            im: -self.im,
49        }
50    }
51
52    /// Phase angle (argument) in radians.
53    pub fn arg(&self) -> S {
54        self.im.atan2(self.re)
55    }
56}
57
58impl<S: Scalar> Add for Complex<S> {
59    type Output = Self;
60    fn add(self, rhs: Self) -> Self {
61        Self {
62            re: self.re + rhs.re,
63            im: self.im + rhs.im,
64        }
65    }
66}
67
68impl<S: Scalar> Sub for Complex<S> {
69    type Output = Self;
70    fn sub(self, rhs: Self) -> Self {
71        Self {
72            re: self.re - rhs.re,
73            im: self.im - rhs.im,
74        }
75    }
76}
77
78impl<S: Scalar> Mul for Complex<S> {
79    type Output = Self;
80    fn mul(self, rhs: Self) -> Self {
81        Self {
82            re: self.re * rhs.re - self.im * rhs.im,
83            im: self.re * rhs.im + self.im * rhs.re,
84        }
85    }
86}
87
88impl<S: Scalar> Mul<S> for Complex<S> {
89    type Output = Self;
90    fn mul(self, rhs: S) -> Self {
91        Self {
92            re: self.re * rhs,
93            im: self.im * rhs,
94        }
95    }
96}
97
98impl<S: Scalar> Neg for Complex<S> {
99    type Output = Self;
100    fn neg(self) -> Self {
101        Self {
102            re: -self.re,
103            im: -self.im,
104        }
105    }
106}
107
108/// Convert from rustfft's Complex type.
109impl From<rustfft::num_complex::Complex<f64>> for Complex<f64> {
110    fn from(c: rustfft::num_complex::Complex<f64>) -> Self {
111        Self { re: c.re, im: c.im }
112    }
113}
114
115/// Convert to rustfft's Complex type.
116impl From<Complex<f64>> for rustfft::num_complex::Complex<f64> {
117    fn from(c: Complex<f64>) -> Self {
118        Self::new(c.re, c.im)
119    }
120}
121
122/// Convert from faer's c64 complex type.
123impl From<faer::complex_native::c64> for Complex<f64> {
124    fn from(c: faer::complex_native::c64) -> Self {
125        Self { re: c.re, im: c.im }
126    }
127}
128
129/// Convert to faer's c64 complex type.
130impl From<Complex<f64>> for faer::complex_native::c64 {
131    fn from(c: Complex<f64>) -> faer::complex_native::c64 {
132        faer::complex_native::c64 { re: c.re, im: c.im }
133    }
134}
135
136/// Convert from faer's c32 complex type.
137impl From<faer::complex_native::c32> for Complex<f32> {
138    fn from(c: faer::complex_native::c32) -> Self {
139        Self { re: c.re, im: c.im }
140    }
141}
142
143/// Convert to faer's c32 complex type.
144impl From<Complex<f32>> for faer::complex_native::c32 {
145    fn from(c: Complex<f32>) -> faer::complex_native::c32 {
146        faer::complex_native::c32 { re: c.re, im: c.im }
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    // Addition and complex multiplication on f64 values.
    #[test]
    fn test_basic_ops() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);

        let sum = a + b;
        assert!((sum.re - 4.0).abs() < 1e-14);
        assert!((sum.im - 6.0).abs() < 1e-14);

        let prod = a * b;
        // (1+2i)(3+4i) = 3+4i+6i+8i^2 = -5+10i
        assert!((prod.re - (-5.0)).abs() < 1e-14);
        assert!((prod.im - 10.0).abs() < 1e-14);
    }

    // Subtraction, negation and real scaling; all results are exact in f64.
    #[test]
    fn test_sub_neg_scalar_mul() {
        let a: Complex<f64> = Complex::new(5.0, -1.0);
        let b: Complex<f64> = Complex::new(2.0, 3.0);

        assert_eq!(a - b, Complex::new(3.0, -4.0));
        assert_eq!(-a, Complex::new(-5.0, 1.0));
        assert_eq!(a * 2.0_f64, Complex::new(10.0, -2.0));
    }

    #[test]
    fn test_magnitude() {
        let c = Complex::new(3.0, 4.0);
        assert!((c.abs() - 5.0).abs() < 1e-14);
        assert!((c.norm_sqr() - 25.0).abs() < 1e-14);
    }

    #[test]
    fn test_conjugate() {
        let c = Complex::new(1.0, 2.0);
        let cc = c.conj();
        assert!((cc.re - 1.0).abs() < 1e-14);
        assert!((cc.im - (-2.0)).abs() < 1e-14);
    }

    // `zero()` and the phase angle on the positive imaginary and negative real axes.
    #[test]
    fn test_zero_and_arg() {
        let z: Complex<f64> = Complex::zero();
        assert_eq!(z, Complex::new(0.0, 0.0));

        let unit_i = Complex::new(0.0_f64, 1.0);
        assert!((unit_i.arg() - std::f64::consts::FRAC_PI_2).abs() < 1e-14);

        let minus_one = Complex::new(-1.0_f64, 0.0);
        assert!((minus_one.arg() - std::f64::consts::PI).abs() < 1e-14);
    }

    // Round trip through faer's `c64`; the conversions copy fields, so equality is exact.
    #[test]
    fn test_faer_c64_roundtrip() {
        // Deliberately not 3.14: clippy's deny-by-default `approx_constant`
        // lint rejects literals that approximate PI.
        let orig = Complex::new(1.25_f64, -2.75);
        let c64_val: faer::complex_native::c64 = orig.into();
        assert_eq!((c64_val.re, c64_val.im), (1.25, -2.75));
        let back: Complex<f64> = c64_val.into();
        assert_eq!(back, orig);
    }

    // Round trip through rustfft's `num_complex::Complex<f64>`; exact field copies.
    #[test]
    fn test_rustfft_roundtrip() {
        let orig = Complex::new(-0.5_f64, 4.0);
        let rf: rustfft::num_complex::Complex<f64> = orig.into();
        assert_eq!((rf.re, rf.im), (-0.5, 4.0));
        let back: Complex<f64> = rf.into();
        assert_eq!(back, orig);
    }

    #[test]
    fn test_faer_c32_roundtrip() {
        let orig: Complex<f32> = Complex::new(1.5_f32, -0.5_f32);
        let c32_val: faer::complex_native::c32 = orig.into();
        let back: Complex<f32> = c32_val.into();
        assert!((back.re - 1.5_f32).abs() < 1e-6);
        assert!((back.im - (-0.5_f32)).abs() < 1e-6);
    }

    #[test]
    fn test_f32_complex_basic() {
        let a: Complex<f32> = Complex::new(1.0_f32, 2.0_f32);
        let b: Complex<f32> = Complex::new(3.0_f32, 4.0_f32);
        let sum = a + b;
        assert!((sum.re - 4.0_f32).abs() < 1e-6);
        assert!((sum.im - 6.0_f32).abs() < 1e-6);
    }
}