1use num_traits::{Num, Signed, Unsigned};
2use std::{
3 fmt::{Debug, Display},
4 ops::{Add, AddAssign, Div, Mul, MulAssign, Neg},
5};
6mod cast;
7mod floats;
8mod ops;
9mod primitives;
10
11pub use crate::complex::primitives::Imag;
12
/// A complex number stored as a pair of components of type `T`.
///
/// Tuple field `0` is the real part and tuple field `1` is the imaginary
/// part; [`C::re`] and [`C::im`] return them by value.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub struct C<T>(pub T, pub T)
where
    T: Copy + PartialEq;
15
16impl<T: Copy + PartialEq + Display + Signed> Display for C<T> {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 match self.1.is_positive() {
19 true => write!(f, "{}+{}i", self.0, self.1),
20 false => write!(f, "{}{}i", self.0, self.1),
21 }
22 }
23}
24
25impl<T: Copy + PartialEq> C<T> {
26 pub fn re(&self) -> T {
27 self.0
28 }
29 pub fn im(&self) -> T {
30 self.1
31 }
32}
33
34impl<T: Copy + PartialEq + Neg<Output = T>> C<T> {
35 pub fn conj(&self) -> C<T> {
36 C(self.0, -self.1)
37 }
38}
39
40impl<T: Copy + PartialEq + Add<Output = T> + Mul<Output = T>> C<T> {
41 pub fn r_square(&self) -> T {
42 self.0 * self.0 + self.1 * self.1
43 }
44}
45
46impl<T> C<T>
47where
48 T: Copy + PartialEq + Neg<Output = T> + Div<Output = T> + Mul<Output = T> + Add<Output = T>,
49{
50 pub fn inv(&self) -> Self {
51 let r_sq = self.r_square();
52 C(self.0 / r_sq, -self.1 / r_sq)
53 }
54}
55
56impl<T> C<T>
57where
58 C<T>: MulAssign + Debug,
59 T: Copy + PartialEq + Num,
60{
61 pub fn powi(&self, n: i32) -> Self {
63 if n == 0 {
64 return C(T::one(), T::zero());
65 } else if n > 0 {
66 let mut out = self.clone();
67 for _ in 1..n {
68 out *= *self;
69 }
70 return out;
71 } else {
72 let mut out = *self;
73 for _ in 1..-n {
74 out *= *self;
75 }
76 let out = C(T::one(), T::zero()) / out;
77 return out;
78 }
79 }
80}
81
#[cfg(test)]
mod tests {
    // The fixture constants are named after their value (`_2_n1` is `2 - 1i`),
    // which cannot be spelled in SCREAMING_SNAKE_CASE.
    #![allow(non_upper_case_globals)]
    use std::f32::consts::{FRAC_PI_2, LN_2, PI};
    use std::f64::consts::FRAC_1_SQRT_2;

    use super::*;

    // Fixtures: `_a_b` is `a + b*i`; an `n` prefix marks a negative component.
    pub const _0_0: C<f64> = C(0.0, 0.0);
    pub const _1_0: C<f64> = C(1.0, 0.0);
    pub const _0_1: C<f64> = C(0.0, 1.0);
    pub const _n1_0: C<f64> = C(-1.0, 0.0);
    pub const _0_n1: C<f64> = C(0.0, -1.0);
    pub const _1_1: C<f64> = C(1.0, 1.0);
    pub const _2_n1: C<f64> = C(2.0, -1.0);
    pub const unit: C<f64> = C(FRAC_1_SQRT_2, FRAC_1_SQRT_2);
    pub const all_z: [C<f64>; 8] = [_0_0, _1_0, _0_1, _n1_0, _0_n1, _1_1, _2_n1, unit];

    /// Returns `true` when `a` and `b` are equal or the modulus of their
    /// difference is below `epsilon`; prints both values on a mismatch.
    fn approx_epsilon(a: C<f64>, b: C<f64>, epsilon: f64) -> bool {
        let close = a == b || (a - b).abs() < epsilon;
        if !close {
            println!("Error: {:?} != {:?}", a, b);
        }
        close
    }

    /// [`approx_epsilon`] with a fixed tolerance of `1e-10`.
    fn approx(a: C<f64>, b: C<f64>) -> bool {
        approx_epsilon(a, b, 1e-10)
    }

    /// Addition: scalar + imaginary, imaginary + scalar, complex + complex.
    #[test]
    fn add() {
        assert_eq!(1 + 1.i(), C(1, 1));
        assert_eq!(1.i() + 1, C(1, 1));
        assert_eq!(C(0., 2.) + C(2., 3.), C(2., 5.));
    }

    /// Subtraction in the same three operand combinations as `add`.
    #[test]
    fn sub() {
        assert_eq!(1 - 1.i(), C(1, -1));
        assert_eq!(1.i() - 1, C(-1, 1));
        assert_eq!(C(0., 2.) - C(2., 3.), C(-2., -1.));
    }

    /// (2 + 3i)(4 + 5i) = -7 + 22i.
    #[test]
    fn mul() {
        let c1 = 2 + 3.i();
        let c2 = 4 + 5.i();
        let expected = -7 + 22.i();
        assert_eq!(c1 * c2, expected);
    }

    /// (2 + 3i) / (4 + 5i) = (23 + 2i) / 41.
    #[test]
    fn division() {
        let c1 = C(2., 3.);
        let c2 = C(4., 5.);
        let expected = C(23. / 41., 2. / 41.);
        assert_eq!(c1 / c2, expected);
    }

    /// Compound-assignment operators applied in sequence to one value.
    // NOTE: the name is a misspelling of `assign`; it is kept so existing test
    // filters keep matching.
    #[test]
    fn assing() {
        let mut z = C(0, 0);
        z += 2;
        assert_eq!(z, C(2, 0));
        z -= 4.i();
        assert_eq!(z, C(2, -4));
        z *= 3;
        assert_eq!(z, C(6, -12));
        z /= C(2, 0);
        assert_eq!(z, C(3, -6));
    }

    /// The real part of `z * conj(z)` equals `z.r_square()`.
    #[test]
    fn conj() {
        let a: C<i32> = 2 + 3.i();
        assert_eq!((a * a.conj()).re(), a.r_square());
    }

    /// `C::from` on a plain number yields that number with a zero imaginary part.
    #[test]
    fn from_num() {
        let a: u8 = 42;
        let a_complex = C::from(a);
        assert_eq!(a_complex, C(42, 0));

        let a: f32 = 42.0;
        let a_complex = C::from(a);
        assert_eq!(a_complex, C(42.0, 0.0));

        let a: i32 = -42;
        let a_complex = C::from(a);
        assert_eq!(a_complex, C(-42, 0));
    }

    /// Logarithm of `i`, `2` and `-1`.
    ///
    /// The expected values use the `f32` constants imported above, so every
    /// `a` below is inferred as `C<f32>`.
    #[test]
    fn ln() {
        let a = C(0.0, 1.0);
        assert_eq!(a.ln(), C(0.0, FRAC_PI_2));

        let a = C(2.0, 0.0);
        assert_eq!(a.ln(), C(LN_2, 0.0));

        let a = C(-1.0, 0.0);
        assert_eq!(a.ln(), C(0.0, PI));
    }

    /// Every fixture checked here has modulus exactly `1.0`.
    #[test]
    fn abs() {
        assert_eq!(_0_1.abs(), 1.0);
        assert_eq!(_1_0.abs(), 1.0);
        assert_eq!(_n1_0.abs(), 1.0);
        assert_eq!(_0_n1.abs(), 1.0);
        assert_eq!(unit.abs(), 1.0);
    }

    /// Square roots of real squares (including the sign of a zero imaginary
    /// part), plus conjugate symmetry, round-tripping and the argument range
    /// of the root for every fixture.
    #[test]
    fn sqrt() {
        for n in (0..100).map(f64::from) {
            let n2 = n * n;
            assert!(approx(C(n2, 0.).sqrt(), C(n, 0.)));
            assert!(approx(C(-n2, 0.).sqrt(), C(0., n)));
            // A negative-zero imaginary part selects the root with negative
            // imaginary part.
            assert!(approx(C(-n2, -0.).sqrt(), C(0.0, -n)));
        }
        let z2: C<f64> = 0.25 + 0.0.i();
        assert_eq!(z2.sqrt(), C(0.5, 0.));
        for c in all_z {
            let root = c.sqrt();
            assert!(approx(c.conj().sqrt(), root.conj()));
            assert!(approx(root * root, c));
            let angle = root.arg();
            assert!(
                -std::f64::consts::FRAC_PI_2 <= angle && angle <= std::f64::consts::FRAC_PI_2
            );
        }
    }

    /// Integer powers: positive, zero and negative exponents.
    #[test]
    fn powi() {
        let z1 = C(2, 0);
        assert_eq!(z1.powi(3), C(8, 0));
        let z2 = 2.i();
        assert_eq!(z2.powi(4), C(16, 0));
        let z3 = C(3, -5);
        assert_eq!(z3.powi(3), z3 * z3 * z3);
        assert_eq!(_2_n1.powi(2), _2_n1 * _2_n1);
        assert_eq!(C(5, 10).powi(0), C(1, 0));
        assert_eq!(2.0.i().powi(-2), C(-1. / 4., 0.));
    }

    /// Real-valued exponents, compared within the `approx` tolerance.
    #[test]
    fn powf() {
        assert!(approx(_2_n1.powf(2.), _2_n1 * _2_n1));
        assert!(approx(_2_n1.powf(0.), C(1., 0.)));
        assert!(approx(_0_1.powf(4.), C(1., 0.)));
    }

    /// Complex-valued exponents, compared within the `approx` tolerance.
    #[test]
    fn powc() {
        assert!(approx(_2_n1.powc(C(2., 0.)), _2_n1 * _2_n1));
        assert!(approx(_2_n1.powc(C(0., 0.)), C(1., 0.)));
        let z: C<f64> = 2.0 + 0.5.i();
        assert!(approx(
            z.powc(z),
            C(2.4767939208048335, 2.8290270856372506)
        ));
    }
}