1use crate::accumulator::BinnedAccumulatorF64;
33
/// A complex number `re + im·i` with `f64` components.
///
/// The type is `Copy`, so every arithmetic method takes `self` by value.
/// `PartialEq` is the derived component-wise comparison, so (as with plain
/// `f64`) a value with a NaN component is not equal to itself.
/// `Default` is `0 + 0i`, the same value as `ComplexF64::ZERO`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct ComplexF64 {
    /// Real part.
    pub re: f64,
    /// Imaginary part.
    pub im: f64,
}
46
47impl ComplexF64 {
48 #[inline]
50 pub fn new(re: f64, im: f64) -> Self {
51 ComplexF64 { re, im }
52 }
53
54 #[inline]
56 pub fn real(re: f64) -> Self {
57 ComplexF64 { re, im: 0.0 }
58 }
59
60 #[inline]
62 pub fn imag(im: f64) -> Self {
63 ComplexF64 { re: 0.0, im }
64 }
65
66 pub const ZERO: ComplexF64 = ComplexF64 { re: 0.0, im: 0.0 };
68
69 pub const ONE: ComplexF64 = ComplexF64 { re: 1.0, im: 0.0 };
71
72 pub const I: ComplexF64 = ComplexF64 { re: 0.0, im: 1.0 };
74
75 #[inline]
77 pub fn norm_sq(self) -> f64 {
78 let r2 = self.re * self.re;
80 let i2 = self.im * self.im;
81 r2 + i2
82 }
83
84 #[inline]
86 pub fn abs(self) -> f64 {
87 self.norm_sq().sqrt()
88 }
89
90 #[inline]
92 pub fn conj(self) -> Self {
93 ComplexF64 { re: self.re, im: -self.im }
94 }
95
96 #[inline]
117 pub fn mul_fixed(self, rhs: Self) -> Self {
118 let t1 = self.re * rhs.re; let t2 = self.im * rhs.im; let t3 = self.re * rhs.im; let t4 = self.im * rhs.re; let re = t1 - t2; let im = t3 + t4; ComplexF64 { re, im }
129 }
130
131 #[inline]
133 pub fn add(self, rhs: Self) -> Self {
134 ComplexF64 {
135 re: self.re + rhs.re,
136 im: self.im + rhs.im,
137 }
138 }
139
140 #[inline]
142 pub fn sub(self, rhs: Self) -> Self {
143 ComplexF64 {
144 re: self.re - rhs.re,
145 im: self.im - rhs.im,
146 }
147 }
148
149 #[inline]
151 pub fn neg(self) -> Self {
152 ComplexF64 { re: -self.re, im: -self.im }
153 }
154
155 #[inline]
167 pub fn div_fixed(self, rhs: Self) -> Self {
168 let cc = rhs.re * rhs.re;
170 let dd = rhs.im * rhs.im;
171 let denom = cc + dd;
172
173 let ac = self.re * rhs.re;
175 let bd = self.im * rhs.im;
176 let re = (ac + bd) / denom;
177
178 let bc = self.im * rhs.re;
180 let ad = self.re * rhs.im;
181 let im = (bc - ad) / denom;
182
183 ComplexF64 { re, im }
184 }
185
186 #[inline]
188 pub fn scale(self, s: f64) -> Self {
189 ComplexF64 { re: s * self.re, im: s * self.im }
190 }
191
192 #[inline]
194 pub fn is_nan(self) -> bool {
195 self.re.is_nan() || self.im.is_nan()
196 }
197
198 #[inline]
200 pub fn is_finite(self) -> bool {
201 self.re.is_finite() && self.im.is_finite()
202 }
203}
204
205impl std::fmt::Display for ComplexF64 {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 if self.im >= 0.0 {
208 write!(f, "{}+{}i", self.re, self.im)
209 } else {
210 write!(f, "{}{}i", self.re, self.im)
211 }
212 }
213}
214
215pub fn complex_dot(a: &[ComplexF64], b: &[ComplexF64]) -> ComplexF64 {
225 debug_assert_eq!(a.len(), b.len());
226 let mut re_acc = BinnedAccumulatorF64::new();
227 let mut im_acc = BinnedAccumulatorF64::new();
228
229 for i in 0..a.len() {
230 let z = a[i].mul_fixed(b[i].conj());
232 re_acc.add(z.re);
233 im_acc.add(z.im);
234 }
235
236 ComplexF64 {
237 re: re_acc.finalize(),
238 im: im_acc.finalize(),
239 }
240}
241
242pub fn complex_sum(values: &[ComplexF64]) -> ComplexF64 {
246 let mut re_acc = BinnedAccumulatorF64::new();
247 let mut im_acc = BinnedAccumulatorF64::new();
248
249 for &z in values {
250 re_acc.add(z.re);
251 im_acc.add(z.im);
252 }
253
254 ComplexF64 {
255 re: re_acc.finalize(),
256 im: im_acc.finalize(),
257 }
258}
259
260pub fn complex_matmul(
264 a: &[ComplexF64], b: &[ComplexF64], out: &mut [ComplexF64],
265 m: usize, k: usize, n: usize,
266) {
267 debug_assert_eq!(a.len(), m * k);
268 debug_assert_eq!(b.len(), k * n);
269 debug_assert_eq!(out.len(), m * n);
270
271 for i in 0..m {
272 for j in 0..n {
273 let mut re_acc = BinnedAccumulatorF64::new();
274 let mut im_acc = BinnedAccumulatorF64::new();
275 for p in 0..k {
276 let prod = a[i * k + p].mul_fixed(b[p * n + j]);
277 re_acc.add(prod.re);
278 im_acc.add(prod.im);
279 }
280 out[i * n + j] = ComplexF64 {
281 re: re_acc.finalize(),
282 im: im_acc.finalize(),
283 };
284 }
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Accumulates each chunk separately, merges the partial accumulators in
    /// the order the chunks are yielded, and returns the finalized
    /// `(re, im)` totals.
    fn merged_chunk_sums<'a>(chunks: impl Iterator<Item = &'a [ComplexF64]>) -> (f64, f64) {
        let mut re_total = BinnedAccumulatorF64::new();
        let mut im_total = BinnedAccumulatorF64::new();
        for chunk in chunks {
            let mut re_c = BinnedAccumulatorF64::new();
            let mut im_c = BinnedAccumulatorF64::new();
            for z in chunk {
                re_c.add(z.re);
                im_c.add(z.im);
            }
            re_total.merge(&re_c);
            im_total.merge(&im_c);
        }
        (re_total.finalize(), im_total.finalize())
    }

    #[test]
    fn test_complex_mul_basic() {
        let a = ComplexF64::new(1.0, 2.0);
        let b = ComplexF64::new(3.0, 4.0);
        let c = a.mul_fixed(b);
        assert_eq!(c.re, -5.0);
        assert_eq!(c.im, 10.0);
    }

    #[test]
    fn test_complex_mul_commutative() {
        let a = ComplexF64::new(1.23456789, -9.87654321);
        let b = ComplexF64::new(-3.14159265, 2.71828183);
        let ab = a.mul_fixed(b);
        let ba = b.mul_fixed(a);
        assert_eq!(ab.re.to_bits(), ba.re.to_bits());
        assert_eq!(ab.im.to_bits(), ba.im.to_bits());
    }

    #[test]
    fn test_complex_mul_identity() {
        let a = ComplexF64::new(7.0, -3.0);
        let one = ComplexF64::ONE;
        let result = a.mul_fixed(one);
        assert_eq!(result.re, a.re);
        assert_eq!(result.im, a.im);
    }

    #[test]
    fn test_complex_mul_i_squared() {
        let i = ComplexF64::I;
        let result = i.mul_fixed(i);
        assert_eq!(result.re, -1.0);
        assert_eq!(result.im, 0.0);
    }

    #[test]
    fn test_complex_conj() {
        let z = ComplexF64::new(3.0, 4.0);
        let c = z.conj();
        assert_eq!(c.re, 3.0);
        assert_eq!(c.im, -4.0);
    }

    #[test]
    fn test_complex_abs() {
        let z = ComplexF64::new(3.0, 4.0);
        assert_eq!(z.abs(), 5.0);
    }

    #[test]
    fn test_complex_dot_basic() {
        // 1·conj(1) + i·conj(i) = 1 + 1 = 2.
        let a = vec![ComplexF64::new(1.0, 0.0), ComplexF64::new(0.0, 1.0)];
        let b = vec![ComplexF64::new(1.0, 0.0), ComplexF64::new(0.0, 1.0)];
        let result = complex_dot(&a, &b);
        assert_eq!(result.re, 2.0);
        assert_eq!(result.im, 0.0);
    }

    #[test]
    fn test_complex_dot_deterministic() {
        let n = 500;
        let a: Vec<ComplexF64> = (0..n)
            .map(|i| ComplexF64::new(i as f64 * 0.001, -(i as f64 * 0.002)))
            .collect();
        let b: Vec<ComplexF64> = (0..n)
            .map(|i| ComplexF64::new((n - i) as f64 * 0.003, i as f64 * 0.004))
            .collect();

        let r1 = complex_dot(&a, &b);
        let r2 = complex_dot(&a, &b);
        assert_eq!(r1.re.to_bits(), r2.re.to_bits());
        assert_eq!(r1.im.to_bits(), r2.im.to_bits());
    }

    #[test]
    fn test_complex_sum_deterministic() {
        let values: Vec<ComplexF64> = (0..1000)
            .map(|i| ComplexF64::new(i as f64 * 0.7 - 350.0, -(i as f64) * 0.3 + 150.0))
            .collect();
        let r1 = complex_sum(&values);
        let r2 = complex_sum(&values);
        assert_eq!(r1.re.to_bits(), r2.re.to_bits());
        assert_eq!(r1.im.to_bits(), r2.im.to_bits());
    }

    #[test]
    fn test_complex_sum_near_order_invariant() {
        let values: Vec<ComplexF64> = (0..100)
            .map(|i| ComplexF64::new(i as f64 * 1.1 - 50.0, -(i as f64) * 0.9 + 45.0))
            .collect();
        let mut reversed = values.clone();
        reversed.reverse();

        let r1 = complex_sum(&values);
        let r2 = complex_sum(&reversed);
        let re_ulps = (r1.re.to_bits() as i64 - r2.re.to_bits() as i64).unsigned_abs();
        let im_ulps = (r1.im.to_bits() as i64 - r2.im.to_bits() as i64).unsigned_abs();
        assert!(re_ulps < 10, "Real parts near-order-invariant: {re_ulps} ULPs");
        assert!(im_ulps < 10, "Imaginary parts near-order-invariant: {im_ulps} ULPs");
    }

    #[test]
    fn test_complex_sum_merge_order_invariant() {
        let values: Vec<ComplexF64> = (0..100)
            .map(|i| ComplexF64::new(i as f64 * 1.1 - 50.0, -(i as f64) * 0.9 + 45.0))
            .collect();

        // Same chunks, merged front-to-back versus back-to-front.
        let (re_fwd, im_fwd) = merged_chunk_sums(values.chunks(10));
        let (re_rev, im_rev) = merged_chunk_sums(values.chunks(10).rev());

        assert_eq!(re_fwd.to_bits(), re_rev.to_bits(),
            "Complex real merge must be order-invariant");
        assert_eq!(im_fwd.to_bits(), im_rev.to_bits(),
            "Complex imaginary merge must be order-invariant");
    }

    #[test]
    fn test_complex_matmul_identity() {
        let identity = vec![
            ComplexF64::ONE, ComplexF64::ZERO,
            ComplexF64::ZERO, ComplexF64::ONE,
        ];
        let b = vec![
            ComplexF64::new(1.0, 2.0), ComplexF64::new(3.0, 4.0),
            ComplexF64::new(5.0, 6.0), ComplexF64::new(7.0, 8.0),
        ];
        let mut out = vec![ComplexF64::ZERO; 4];
        complex_matmul(&identity, &b, &mut out, 2, 2, 2);
        for (i, &v) in out.iter().enumerate() {
            assert_eq!(v.re, b[i].re);
            assert_eq!(v.im, b[i].im);
        }
    }

    #[test]
    fn test_complex_matmul_deterministic() {
        let a: Vec<ComplexF64> = (0..9)
            .map(|i| ComplexF64::new(i as f64 * 0.3, -(i as f64) * 0.2))
            .collect();
        let b: Vec<ComplexF64> = (0..9)
            .map(|i| ComplexF64::new(-(i as f64) * 0.1, i as f64 * 0.4))
            .collect();
        let mut out1 = vec![ComplexF64::ZERO; 9];
        let mut out2 = vec![ComplexF64::ZERO; 9];
        complex_matmul(&a, &b, &mut out1, 3, 3, 3);
        complex_matmul(&a, &b, &mut out2, 3, 3, 3);
        for i in 0..9 {
            assert_eq!(out1[i].re.to_bits(), out2[i].re.to_bits());
            assert_eq!(out1[i].im.to_bits(), out2[i].im.to_bits());
        }
    }

    #[test]
    fn test_complex_div_basic() {
        let a = ComplexF64::new(1.0, 2.0);
        let one = ComplexF64::new(1.0, 0.0);
        let c = a.div_fixed(one);
        assert_eq!(c.re, 1.0);
        assert_eq!(c.im, 2.0);
    }

    #[test]
    fn test_complex_div_nontrivial() {
        // (3 + 4i) / (1 + 2i) = (11 - 2i) / 5 = 2.2 - 0.4i.
        let a = ComplexF64::new(3.0, 4.0);
        let b = ComplexF64::new(1.0, 2.0);
        let c = a.div_fixed(b);
        let tol = 1e-15;
        assert!((c.re - 2.2).abs() < tol, "re: {} vs 2.2", c.re);
        assert!((c.im - (-0.4)).abs() < tol, "im: {} vs -0.4", c.im);
    }

    #[test]
    fn test_complex_div_by_zero() {
        let a = ComplexF64::new(1.0, 2.0);
        let zero = ComplexF64::ZERO;
        let c = a.div_fixed(zero);
        // Only require a non-finite result (NaN counts as non-finite).
        assert!(!c.re.is_finite());
        assert!(!c.im.is_finite());
    }

    #[test]
    fn test_complex_div_roundtrip() {
        let z = ComplexF64::new(3.7, -2.1);
        let w = ComplexF64::new(1.5, 0.8);
        let product = z.mul_fixed(w);
        let back = product.div_fixed(w);
        let tol = 1e-12;
        assert!((back.re - z.re).abs() < tol, "re roundtrip: {} vs {}", back.re, z.re);
        assert!((back.im - z.im).abs() < tol, "im roundtrip: {} vs {}", back.im, z.im);
    }

    #[test]
    fn test_complex_signed_zero_preserved() {
        // IEEE-754 (round-to-nearest): (+0) + (-0) = +0, while (-0) + (-0)
        // keeps its negative sign. `x == 0.0` alone cannot tell these apart,
        // so the sign bit is checked explicitly.
        let pos = ComplexF64::new(0.0, 0.0);
        let neg = ComplexF64::new(-0.0, -0.0);

        let mixed = pos.add(neg);
        assert!(mixed.re == 0.0 && mixed.re.is_sign_positive());
        assert!(mixed.im == 0.0 && mixed.im.is_sign_positive());

        let both_neg = neg.add(neg);
        assert!(both_neg.re == 0.0 && both_neg.re.is_sign_negative());
        assert!(both_neg.im == 0.0 && both_neg.im.is_sign_negative());

        // Sign-flipping operations must flip the sign of zero as well.
        assert!(pos.conj().im.is_sign_negative());
        assert!(pos.neg().re.is_sign_negative());
        assert!(pos.neg().im.is_sign_negative());
    }

    #[test]
    fn test_complex_nan_propagation() {
        let nan_z = ComplexF64::new(f64::NAN, 1.0);
        let normal = ComplexF64::new(1.0, 1.0);
        let result = nan_z.mul_fixed(normal);
        assert!(result.is_nan());
    }

    #[test]
    fn test_complex_display() {
        let z = ComplexF64::new(3.0, -4.0);
        assert_eq!(format!("{z}"), "3-4i");
        let z2 = ComplexF64::new(1.0, 2.0);
        assert_eq!(format!("{z2}"), "1+2i");
    }
}