rapl/complex/mod.rs

1use num_traits::{Num, Signed, Unsigned};
2use std::{
3    fmt::{Debug, Display},
4    ops::{Add, AddAssign, Div, Mul, MulAssign, Neg},
5};
6mod cast;
7mod floats;
8mod ops;
9mod primitives;
10
11pub use crate::complex::primitives::Imag;
12
/// A complex number stored as a pair of scalars of the same type.
///
/// Field `.0` is the real part and field `.1` is the imaginary part; both are public so
/// values can be built and destructured directly, e.g. `C(1.0, -2.0)` or `let C(re, im) = z;`.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub struct C<T>(pub T, pub T)
where
    T: Copy + PartialEq;
15
16impl<T: Copy + PartialEq + Display + Signed> Display for C<T> {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        match self.1.is_positive() {
19            true => write!(f, "{}+{}i", self.0, self.1),
20            false => write!(f, "{}{}i", self.0, self.1),
21        }
22    }
23}
24
25impl<T: Copy + PartialEq> C<T> {
26    pub fn re(&self) -> T {
27        self.0
28    }
29    pub fn im(&self) -> T {
30        self.1
31    }
32}
33
34impl<T: Copy + PartialEq + Neg<Output = T>> C<T> {
35    pub fn conj(&self) -> C<T> {
36        C(self.0, -self.1)
37    }
38}
39
40impl<T: Copy + PartialEq + Add<Output = T> + Mul<Output = T>> C<T> {
41    pub fn r_square(&self) -> T {
42        self.0 * self.0 + self.1 * self.1
43    }
44}
45
46impl<T> C<T>
47where
48    T: Copy + PartialEq + Neg<Output = T> + Div<Output = T> + Mul<Output = T> + Add<Output = T>,
49{
50    pub fn inv(&self) -> Self {
51        let r_sq = self.r_square();
52        C(self.0 / r_sq, -self.1 / r_sq)
53    }
54}
55
56impl<T> C<T>
57where
58    C<T>: MulAssign + Debug,
59    T: Copy + PartialEq + Num,
60{
61    // this seems to to be relatively fast for n < 100, but we should find a better way for larger n's
62    pub fn powi(&self, n: i32) -> Self {
63        if n == 0 {
64            return C(T::one(), T::zero());
65        } else if n > 0 {
66            let mut out = self.clone();
67            for _ in 1..n {
68                out *= *self;
69            }
70            return out;
71        } else {
72            let mut out = *self;
73            for _ in 1..-n {
74                out *= *self;
75            }
76            let out = C(T::one(), T::zero()) / out;
77            return out;
78        }
79    }
80}
81
#[cfg(test)]
mod tests {
    #![allow(non_upper_case_globals)] // fixtures are named after their value, e.g. `_2_n1` = 2 - 1i
    use std::f32::consts::{FRAC_PI_2, LN_2, PI};
    use std::f64::consts::{FRAC_1_SQRT_2, SQRT_2};

    use super::*;

    // Fixture naming: `_<re>_<im>`, where an `n` prefix marks a negative component.
    pub const _0_0: C<f64> = C(0.0, 0.0);
    pub const _1_0: C<f64> = C(1.0, 0.0);
    pub const _0_1: C<f64> = C(0.0, 1.0);
    pub const _n1_0: C<f64> = C(-1.0, 0.0);
    pub const _0_n1: C<f64> = C(0.0, -1.0);
    pub const _1_1: C<f64> = C(1.0, 1.0);
    pub const _2_n1: C<f64> = C(2.0, -1.0);
    /// 1/sqrt(2) + 1/sqrt(2) i, a point on the unit circle.
    pub const unit: C<f64> = C(FRAC_1_SQRT_2, FRAC_1_SQRT_2);
    pub const all_z: [C<f64>; 8] = [_0_0, _1_0, _0_1, _n1_0, _0_n1, _1_1, _2_n1, unit];

    /// Returns whether `a` and `b` are equal or `(a - b).abs()` is below `epsilon`,
    /// printing both values on mismatch so assertion failures are easier to read.
    fn approx_epsilon(a: C<f64>, b: C<f64>, epsilon: f64) -> bool {
        let approx = (a == b) || (a - b).abs() < epsilon;
        if !approx {
            println!("Error: {:?} != {:?}", a, b)
        }
        approx
    }

    /// `approx_epsilon` with a fixed tolerance of 1e-10.
    fn approx(a: C<f64>, b: C<f64>) -> bool {
        approx_epsilon(a, b, 1e-10)
    }

    #[test]
    fn add() {
        assert_eq!(1 + 1.i(), C(1, 1));
        assert_eq!(1.i() + 1, C(1, 1));
        assert_eq!(C(0., 2.) + C(2., 3.), C(2., 5.));
    }

    #[test]
    fn sub() {
        assert_eq!(1 - 1.i(), C(1, -1));
        assert_eq!(1.i() - 1, C(-1, 1));
        assert_eq!(C(0., 2.) - C(2., 3.), C(-2., -1.));
    }

    #[test]
    fn mul() {
        let c1 = 2 + 3.i();
        let c2 = 4 + 5.i();
        let expected = -7 + 22.i();
        assert_eq!(c1 * c2, expected);
    }

    #[test]
    fn division() {
        let c1 = C(2., 3.);
        let c2 = C(4., 5.);
        let expected = C(23. / 41., 2. / 41.);
        assert_eq!(c1 / c2, expected);
    }

    #[test]
    fn assign() {
        let mut z = C(0, 0);
        z += 2;
        assert_eq!(z, C(2, 0));
        z -= 4.i();
        assert_eq!(z, C(2, -4));
        z *= 3;
        assert_eq!(z, C(6, -12));
        z /= C(2, 0);
        assert_eq!(z, C(3, -6));
    }

    #[test]
    fn conj() {
        let a: C<i32> = 2 + 3.i();
        assert!((a * a.conj()).re() == a.r_square())
    }

    #[test]
    fn from_num() {
        let a: u8 = 42;
        let a_complex = C::from(a);
        assert_eq!(a_complex, C(42, 0));

        let a: f32 = 42.0;
        let a_complex = C::from(a);
        assert_eq!(a_complex, C(42.0, 0.0));

        let a: i32 = -42;
        let a_complex = C::from(a);
        assert_eq!(a_complex, C(-42, 0));
    }

    #[test]
    fn ln() {
        let a = C(0.0, 1.0);
        assert_eq!(a.ln(), C(0.0, FRAC_PI_2));

        let a = C(2.0, 0.0);
        assert_eq!(a.ln(), C(LN_2, 0.0));

        let a = C(-1.0, 0.0);
        assert_eq!(a.ln(), C(0.0, PI));
    }

    #[test]
    fn abs() {
        assert_eq!(_0_1.abs(), 1.0);
        assert_eq!(_1_0.abs(), 1.0);
        assert_eq!(_n1_0.abs(), 1.0);
        assert_eq!(_0_n1.abs(), 1.0);
        assert_eq!(unit.abs(), 1.0);
        assert!((_1_1.abs() - SQRT_2).abs() < 1e-12);
    }

    #[test]
    fn sqrt() {
        for n in (0..100).map(f64::from) {
            let n2 = n * n;
            assert!(approx(C(n2, 0.).sqrt(), C(n, 0.)));
            assert!(approx(C(-n2, 0.).sqrt(), C(0., n)));
            assert!(approx(C(-n2, -0.).sqrt(), C(0.0, -n)));
        }
        let z2: C<f64> = 0.25 + 0.0.i();
        assert_eq!(z2.sqrt(), C(0.5, 0.));
        for c in all_z {
            assert!(approx(c.conj().sqrt(), c.sqrt().conj()));
            assert!(approx(c.sqrt() * c.sqrt(), c));
            // The principal square root lies in the right half-plane.
            assert!(
                -std::f64::consts::FRAC_PI_2 <= c.sqrt().arg()
                    && c.sqrt().arg() <= std::f64::consts::FRAC_PI_2
            );
        }
    }

    #[test]
    fn powi() {
        let z1 = C(2, 0);
        assert_eq!(z1.powi(3), C(8, 0));
        let z2 = 2.i();
        assert_eq!(z2.powi(4), C(16, 0));
        let z3 = C(3, -5);
        assert_eq!(z3.powi(3), z3 * z3 * z3);
        assert_eq!(_2_n1.powi(2), _2_n1 * _2_n1);
        assert_eq!(C(5, 10).powi(0), C(1, 0));
        assert_eq!(2.0.i().powi(-2), C(-1. / 4., 0.));
        // Larger exponents: (1 + i)^10 = (2i)^5 = 32i, and i^7 = -i.
        assert_eq!(C(1, 1).powi(10), C(0, 32));
        assert_eq!(C(0, 1).powi(7), C(0, -1));
    }

    #[test]
    fn powf() {
        assert!(approx(_2_n1.powf(2.), _2_n1 * _2_n1));
        assert!(approx(_2_n1.powf(0.), C(1., 0.)));
        assert!(approx(_0_1.powf(4.), C(1., 0.)))
    }

    #[test]
    fn powc() {
        assert!(approx(_2_n1.powc(C(2., 0.)), _2_n1 * _2_n1));
        assert!(approx(_2_n1.powc(C(0., 0.)), C(1., 0.)));
        // Reference value from Python:
        // >>> z = 2.0 + 0.5j
        // >>> z**z
        // (2.4767939208048335+2.8290270856372506j)
        let z: C<f64> = 2.0 + 0.5.i();
        assert!(approx(
            z.powc(z),
            C(2.4767939208048335, 2.8290270856372506)
        ))
    }
}