Skip to main content

cjc_runtime/
complex.rs

1//! Complex BLAS — ComplexF64 with Fixed-Sequence Arithmetic.
2//!
3//! # Design
4//!
//! Complex multiplication is lowered to a **fixed sequence** of four
//! multiplications followed by one subtraction and one addition, explicitly
//! ordered to prevent cross-architecture FMA drift. This ensures bit-parity
//! between x86 and ARM platforms.
9//!
10//! Complex reductions (dot products, sums) feed real and imaginary parts
11//! separately into BinnedAccumulators for deterministic results.
12//!
13//! # Fixed-Sequence Complex Multiply
14//!
15//! ```text
16//! (a + bi)(c + di) = (ac - bd) + (ad + bc)i
17//! ```
18//!
19//! The four multiplications are computed first, then the two additions:
20//! ```text
21//! t1 = a * c   (mul #1)
22//! t2 = b * d   (mul #2)
23//! t3 = a * d   (mul #3)
24//! t4 = b * c   (mul #4)
25//! re = t1 - t2 (sub #1)
26//! im = t3 + t4 (add #1)
27//! ```
28//!
29//! This explicit ordering prevents the compiler from fusing `a*c - b*d`
30//! into an FMA (which would change the rounding behavior).
31
32use crate::accumulator::BinnedAccumulatorF64;
33
34// ---------------------------------------------------------------------------
35// ComplexF64
36// ---------------------------------------------------------------------------
37
/// Double-precision complex value stored as a `(re, im)` pair.
///
/// Every arithmetic method binds each product to its own local before any
/// sum or difference is formed, following the fixed-sequence protocol whose
/// stated goal is to avoid FMA-induced rounding drift.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ComplexF64 {
    /// Real component.
    pub re: f64,
    /// Imaginary component.
    pub im: f64,
}

impl ComplexF64 {
    /// Additive identity, `0 + 0i`.
    pub const ZERO: Self = Self { re: 0.0, im: 0.0 };

    /// Multiplicative identity, `1 + 0i`.
    pub const ONE: Self = Self { re: 1.0, im: 0.0 };

    /// Imaginary unit, `0 + 1i`.
    pub const I: Self = Self { re: 0.0, im: 1.0 };

    /// Builds a complex number from its real and imaginary components.
    #[inline]
    pub fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// Builds `re + 0i`.
    #[inline]
    pub fn real(re: f64) -> Self {
        Self::new(re, 0.0)
    }

    /// Builds `0 + im·i`.
    #[inline]
    pub fn imag(im: f64) -> Self {
        Self::new(0.0, im)
    }

    /// Returns `|z|^2 = re^2 + im^2`.
    ///
    /// The two squares are rounded individually and then added.
    #[inline]
    pub fn norm_sq(self) -> f64 {
        let Self { re, im } = self;
        let re_sq = re * re;
        let im_sq = im * im;
        re_sq + im_sq
    }

    /// Returns `|z|`, computed as the square root of [`Self::norm_sq`].
    #[inline]
    pub fn abs(self) -> f64 {
        f64::sqrt(self.norm_sq())
    }

    /// Returns the complex conjugate `re - im·i`.
    #[inline]
    pub fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }

    /// Fixed-sequence complex product `self * rhs`.
    ///
    /// With `self = a + bi` and `rhs = c + di` the evaluation is:
    ///
    /// ```text
    /// ac = a * c
    /// bd = b * d
    /// ad = a * d
    /// bc = b * c
    /// re = ac - bd
    /// im = ad + bc
    /// ```
    ///
    /// All four products are materialised as separately rounded values
    /// before the subtraction and the addition, so neither component is
    /// written as a single multiply-add expression.
    #[inline]
    pub fn mul_fixed(self, rhs: Self) -> Self {
        let Self { re: a, im: b } = self;
        let Self { re: c, im: d } = rhs;

        // Products first ...
        let ac = a * c;
        let bd = b * d;
        let ad = a * d;
        let bc = b * c;

        // ... then combine the already-rounded products.
        Self {
            re: ac - bd,
            im: ad + bc,
        }
    }

    /// Component-wise sum `(a + c) + (b + d)i`.
    #[inline]
    pub fn add(self, rhs: Self) -> Self {
        Self::new(self.re + rhs.re, self.im + rhs.im)
    }

    /// Component-wise difference `(a - c) + (b - d)i`.
    #[inline]
    pub fn sub(self, rhs: Self) -> Self {
        Self::new(self.re - rhs.re, self.im - rhs.im)
    }

    /// Negates both components: `(-a) + (-b)i`.
    #[inline]
    pub fn neg(self) -> Self {
        Self::new(-self.re, -self.im)
    }

    /// Fixed-sequence complex quotient `self / rhs`.
    ///
    /// With `self = a + bi` and `rhs = c + di`:
    ///
    /// ```text
    /// denom = c*c + d*d
    /// re    = (a*c + b*d) / denom
    /// im    = (b*c - a*d) / denom
    /// ```
    ///
    /// Each product is rounded on its own before being combined. There is no
    /// zero check: dividing by `0 + 0i` goes through ordinary IEEE 754
    /// division and therefore yields non-finite components instead of
    /// panicking.
    #[inline]
    pub fn div_fixed(self, rhs: Self) -> Self {
        let Self { re: a, im: b } = self;
        let Self { re: c, im: d } = rhs;

        // |rhs|^2
        let denom = {
            let cc = c * c;
            let dd = d * d;
            cc + dd
        };

        // Real numerator: a*c + b*d
        let re_num = {
            let ac = a * c;
            let bd = b * d;
            ac + bd
        };

        // Imaginary numerator: b*c - a*d
        let im_num = {
            let bc = b * c;
            let ad = a * d;
            bc - ad
        };

        Self {
            re: re_num / denom,
            im: im_num / denom,
        }
    }

    /// Multiplies both components by the real scalar `s`.
    #[inline]
    pub fn scale(self, s: f64) -> Self {
        Self::new(s * self.re, s * self.im)
    }

    /// Returns `true` when at least one component is NaN.
    #[inline]
    pub fn is_nan(self) -> bool {
        [self.re, self.im].iter().any(|part| part.is_nan())
    }

    /// Returns `true` only when both components are finite.
    #[inline]
    pub fn is_finite(self) -> bool {
        [self.re, self.im].iter().all(|part| part.is_finite())
    }
}
204
205impl std::fmt::Display for ComplexF64 {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        if self.im >= 0.0 {
208            write!(f, "{}+{}i", self.re, self.im)
209        } else {
210            write!(f, "{}{}i", self.re, self.im)
211        }
212    }
213}
214
215// ---------------------------------------------------------------------------
216// Complex BLAS Operations via BinnedAccumulator
217// ---------------------------------------------------------------------------
218
219/// Complex dot product using BinnedAccumulator for deterministic results.
220///
221/// `dot(a, b) = Σ a[i] * conj(b[i])` (standard Hermitian inner product).
222///
223/// Real and imaginary parts are accumulated separately via BinnedAccumulator.
224pub fn complex_dot(a: &[ComplexF64], b: &[ComplexF64]) -> ComplexF64 {
225    debug_assert_eq!(a.len(), b.len());
226    let mut re_acc = BinnedAccumulatorF64::new();
227    let mut im_acc = BinnedAccumulatorF64::new();
228
229    for i in 0..a.len() {
230        // z = a[i] * conj(b[i])
231        let z = a[i].mul_fixed(b[i].conj());
232        re_acc.add(z.re);
233        im_acc.add(z.im);
234    }
235
236    ComplexF64 {
237        re: re_acc.finalize(),
238        im: im_acc.finalize(),
239    }
240}
241
242/// Complex sum using BinnedAccumulator for deterministic results.
243///
244/// Real and imaginary parts accumulated independently.
245pub fn complex_sum(values: &[ComplexF64]) -> ComplexF64 {
246    let mut re_acc = BinnedAccumulatorF64::new();
247    let mut im_acc = BinnedAccumulatorF64::new();
248
249    for &z in values {
250        re_acc.add(z.re);
251        im_acc.add(z.im);
252    }
253
254    ComplexF64 {
255        re: re_acc.finalize(),
256        im: im_acc.finalize(),
257    }
258}
259
260/// Complex matrix multiply: C[m,n] = A[m,k] × B[k,n] (fixed-sequence).
261///
262/// Each element C[i,j] = Σ_p A[i,p] * B[p,j] with BinnedAccumulator.
263pub fn complex_matmul(
264    a: &[ComplexF64], b: &[ComplexF64], out: &mut [ComplexF64],
265    m: usize, k: usize, n: usize,
266) {
267    debug_assert_eq!(a.len(), m * k);
268    debug_assert_eq!(b.len(), k * n);
269    debug_assert_eq!(out.len(), m * n);
270
271    for i in 0..m {
272        for j in 0..n {
273            let mut re_acc = BinnedAccumulatorF64::new();
274            let mut im_acc = BinnedAccumulatorF64::new();
275            for p in 0..k {
276                let prod = a[i * k + p].mul_fixed(b[p * n + j]);
277                re_acc.add(prod.re);
278                im_acc.add(prod.im);
279            }
280            out[i * n + j] = ComplexF64 {
281                re: re_acc.finalize(),
282                im: im_acc.finalize(),
283            };
284        }
285    }
286}
287
288// ---------------------------------------------------------------------------
289// Inline tests
290// ---------------------------------------------------------------------------
291
/// Unit tests for `ComplexF64` arithmetic and the accumulator-backed
/// reductions (`complex_dot`, `complex_sum`, `complex_matmul`).
#[cfg(test)]
mod tests {
    use super::*;

    /// Product of two small integers-valued complex numbers is exact.
    #[test]
    fn test_complex_mul_basic() {
        // (1+2i)(3+4i) = (3-8) + (4+6)i = -5 + 10i
        let a = ComplexF64::new(1.0, 2.0);
        let b = ComplexF64::new(3.0, 4.0);
        let c = a.mul_fixed(b);
        assert_eq!(c.re, -5.0);
        assert_eq!(c.im, 10.0);
    }

    /// `a*b` and `b*a` must agree bit-for-bit.
    #[test]
    fn test_complex_mul_commutative() {
        let a = ComplexF64::new(1.23456789, -9.87654321);
        let b = ComplexF64::new(-3.14159265, 2.71828183);
        let ab = a.mul_fixed(b);
        let ba = b.mul_fixed(a);
        assert_eq!(ab.re.to_bits(), ba.re.to_bits());
        assert_eq!(ab.im.to_bits(), ba.im.to_bits());
    }

    /// Multiplying by `ONE` returns the operand unchanged.
    #[test]
    fn test_complex_mul_identity() {
        let a = ComplexF64::new(7.0, -3.0);
        let one = ComplexF64::ONE;
        let result = a.mul_fixed(one);
        assert_eq!(result.re, a.re);
        assert_eq!(result.im, a.im);
    }

    /// The imaginary unit squares to -1.
    #[test]
    fn test_complex_mul_i_squared() {
        // i * i = -1
        let i = ComplexF64::I;
        let result = i.mul_fixed(i);
        assert_eq!(result.re, -1.0);
        assert_eq!(result.im, 0.0);
    }

    /// Conjugation negates only the imaginary part.
    #[test]
    fn test_complex_conj() {
        let z = ComplexF64::new(3.0, 4.0);
        let c = z.conj();
        assert_eq!(c.re, 3.0);
        assert_eq!(c.im, -4.0);
    }

    /// Magnitude of the 3-4-5 triangle is exact.
    #[test]
    fn test_complex_abs() {
        let z = ComplexF64::new(3.0, 4.0);
        assert_eq!(z.abs(), 5.0);
    }

    /// Hermitian dot product of a vector with itself is its squared norm.
    #[test]
    fn test_complex_dot_basic() {
        let a = vec![ComplexF64::new(1.0, 0.0), ComplexF64::new(0.0, 1.0)];
        let b = vec![ComplexF64::new(1.0, 0.0), ComplexF64::new(0.0, 1.0)];
        // dot = a[0]*conj(b[0]) + a[1]*conj(b[1])
        //     = 1*1 + i*(-i) = 1 + 1 = 2 + 0i
        let result = complex_dot(&a, &b);
        assert_eq!(result.re, 2.0);
        assert_eq!(result.im, 0.0);
    }

    /// Two runs of `complex_dot` on the same input are bit-identical.
    #[test]
    fn test_complex_dot_deterministic() {
        let n = 500;
        let a: Vec<ComplexF64> = (0..n)
            .map(|i| ComplexF64::new(i as f64 * 0.001, -(i as f64 * 0.002)))
            .collect();
        let b: Vec<ComplexF64> = (0..n)
            .map(|i| ComplexF64::new((n - i) as f64 * 0.003, i as f64 * 0.004))
            .collect();

        let r1 = complex_dot(&a, &b);
        let r2 = complex_dot(&a, &b);
        assert_eq!(r1.re.to_bits(), r2.re.to_bits());
        assert_eq!(r1.im.to_bits(), r2.im.to_bits());
    }

    /// Two runs of `complex_sum` on the same input are bit-identical.
    #[test]
    fn test_complex_sum_deterministic() {
        let values: Vec<ComplexF64> = (0..1000)
            .map(|i| ComplexF64::new(i as f64 * 0.7 - 350.0, -(i as f64) * 0.3 + 150.0))
            .collect();
        let r1 = complex_sum(&values);
        let r2 = complex_sum(&values);
        assert_eq!(r1.re.to_bits(), r2.re.to_bits());
        assert_eq!(r1.im.to_bits(), r2.im.to_bits());
    }

    /// Summing forward and reversed input differs by fewer than 10 ULPs.
    #[test]
    fn test_complex_sum_near_order_invariant() {
        let values: Vec<ComplexF64> = (0..100)
            .map(|i| ComplexF64::new(i as f64 * 1.1 - 50.0, -(i as f64) * 0.9 + 45.0))
            .collect();
        let mut reversed = values.clone();
        reversed.reverse();

        let r1 = complex_sum(&values);
        let r2 = complex_sum(&reversed);
        // Within-bin accumulation is near-order-invariant (sub-10 ULPs).
        // The bit-pattern difference equals the ULP distance here because
        // both sums have the same sign.
        let re_ulps = (r1.re.to_bits() as i64 - r2.re.to_bits() as i64).unsigned_abs();
        let im_ulps = (r1.im.to_bits() as i64 - r2.im.to_bits() as i64).unsigned_abs();
        assert!(re_ulps < 10, "Real parts near-order-invariant: {re_ulps} ULPs");
        assert!(im_ulps < 10, "Imaginary parts near-order-invariant: {im_ulps} ULPs");
    }

    /// Merging per-chunk accumulators forward or in reverse gives the same bits.
    #[test]
    fn test_complex_sum_merge_order_invariant() {
        // Merge-based complex summation IS fully order-invariant.
        let values: Vec<ComplexF64> = (0..100)
            .map(|i| ComplexF64::new(i as f64 * 1.1 - 50.0, -(i as f64) * 0.9 + 45.0))
            .collect();

        // Chunk into 10s, merge forward.
        let mut re_fwd = BinnedAccumulatorF64::new();
        let mut im_fwd = BinnedAccumulatorF64::new();
        for chunk in values.chunks(10) {
            let mut re_c = BinnedAccumulatorF64::new();
            let mut im_c = BinnedAccumulatorF64::new();
            for z in chunk {
                re_c.add(z.re);
                im_c.add(z.im);
            }
            re_fwd.merge(&re_c);
            im_fwd.merge(&im_c);
        }

        // Chunk into 10s, merge reverse.
        let chunks: Vec<Vec<ComplexF64>> = values.chunks(10).map(|c| c.to_vec()).collect();
        let mut re_rev = BinnedAccumulatorF64::new();
        let mut im_rev = BinnedAccumulatorF64::new();
        for chunk in chunks.iter().rev() {
            let mut re_c = BinnedAccumulatorF64::new();
            let mut im_c = BinnedAccumulatorF64::new();
            for z in chunk.iter() {
                re_c.add(z.re);
                im_c.add(z.im);
            }
            re_rev.merge(&re_c);
            im_rev.merge(&im_c);
        }

        assert_eq!(re_fwd.finalize().to_bits(), re_rev.finalize().to_bits(),
            "Complex real merge must be order-invariant");
        assert_eq!(im_fwd.finalize().to_bits(), im_rev.finalize().to_bits(),
            "Complex imaginary merge must be order-invariant");
    }

    /// Left-multiplying by the identity matrix returns the right operand.
    #[test]
    fn test_complex_matmul_identity() {
        // 2x2 identity × arbitrary = same
        let identity = vec![
            ComplexF64::ONE, ComplexF64::ZERO,
            ComplexF64::ZERO, ComplexF64::ONE,
        ];
        let b = vec![
            ComplexF64::new(1.0, 2.0), ComplexF64::new(3.0, 4.0),
            ComplexF64::new(5.0, 6.0), ComplexF64::new(7.0, 8.0),
        ];
        let mut out = vec![ComplexF64::ZERO; 4];
        complex_matmul(&identity, &b, &mut out, 2, 2, 2);
        for (i, &v) in out.iter().enumerate() {
            assert_eq!(v.re, b[i].re);
            assert_eq!(v.im, b[i].im);
        }
    }

    /// Two runs of `complex_matmul` on the same input are bit-identical.
    #[test]
    fn test_complex_matmul_deterministic() {
        let a: Vec<ComplexF64> = (0..9)
            .map(|i| ComplexF64::new(i as f64 * 0.3, -(i as f64) * 0.2))
            .collect();
        let b: Vec<ComplexF64> = (0..9)
            .map(|i| ComplexF64::new(-(i as f64) * 0.1, i as f64 * 0.4))
            .collect();
        let mut out1 = vec![ComplexF64::ZERO; 9];
        let mut out2 = vec![ComplexF64::ZERO; 9];
        complex_matmul(&a, &b, &mut out1, 3, 3, 3);
        complex_matmul(&a, &b, &mut out2, 3, 3, 3);
        for i in 0..9 {
            assert_eq!(out1[i].re.to_bits(), out2[i].re.to_bits());
            assert_eq!(out1[i].im.to_bits(), out2[i].im.to_bits());
        }
    }

    /// Dividing by `1 + 0i` returns the dividend.
    #[test]
    fn test_complex_div_basic() {
        // (1+2i) / (1+0i) = 1+2i
        let a = ComplexF64::new(1.0, 2.0);
        let one = ComplexF64::new(1.0, 0.0);
        let c = a.div_fixed(one);
        assert_eq!(c.re, 1.0);
        assert_eq!(c.im, 2.0);
    }

    /// Division with a non-trivial divisor matches the hand-computed value.
    #[test]
    fn test_complex_div_nontrivial() {
        // (3+4i) / (1+2i) = (3+4i)(1-2i) / (1+4) = (3+8 + 4i-6i)/5 = (11-2i)/5
        let a = ComplexF64::new(3.0, 4.0);
        let b = ComplexF64::new(1.0, 2.0);
        let c = a.div_fixed(b);
        let tol = 1e-15;
        assert!((c.re - 2.2).abs() < tol, "re: {} vs 2.2", c.re);
        assert!((c.im - (-0.4)).abs() < tol, "im: {} vs -0.4", c.im);
    }

    /// Division by `0 + 0i` must not panic and must yield non-finite parts.
    #[test]
    fn test_complex_div_by_zero() {
        let a = ComplexF64::new(1.0, 2.0);
        let zero = ComplexF64::ZERO;
        let c = a.div_fixed(zero);
        // `is_finite()` is false for both NaN and ±Inf, so a single negated
        // check covers every acceptable outcome.
        assert!(!c.re.is_finite(), "re should be NaN/Inf, got {}", c.re);
        assert!(!c.im.is_finite(), "im should be NaN/Inf, got {}", c.im);
    }

    /// Multiplying then dividing by the same non-zero value round-trips.
    #[test]
    fn test_complex_div_roundtrip() {
        // (z * w) / w ≈ z for non-zero w.
        let z = ComplexF64::new(3.7, -2.1);
        let w = ComplexF64::new(1.5, 0.8);
        let product = z.mul_fixed(w);
        let back = product.div_fixed(w);
        let tol = 1e-12;
        assert!((back.re - z.re).abs() < tol, "re roundtrip: {} vs {}", back.re, z.re);
        assert!((back.im - z.im).abs() < tol, "im roundtrip: {} vs {}", back.im, z.im);
    }

    /// Signed zeros follow IEEE 754 rules in both components.
    #[test]
    fn test_complex_signed_zero_preserved() {
        let pos = ComplexF64::new(0.0, 0.0);
        let neg = ComplexF64::new(-0.0, -0.0);

        // Compare bit patterns: `==` cannot tell +0.0 from -0.0.
        // Round-to-nearest: (+0) + (-0) = +0.
        let mixed = pos.add(neg);
        assert_eq!(mixed.re.to_bits(), 0.0_f64.to_bits());
        assert_eq!(mixed.im.to_bits(), 0.0_f64.to_bits());

        // (-0) + (-0) = -0.
        let both_neg = neg.add(neg);
        assert_eq!(both_neg.re.to_bits(), (-0.0_f64).to_bits());
        assert_eq!(both_neg.im.to_bits(), (-0.0_f64).to_bits());

        // Negation flips the sign bit of a zero.
        let flipped = pos.neg();
        assert_eq!(flipped.re.to_bits(), (-0.0_f64).to_bits());
        assert_eq!(flipped.im.to_bits(), (-0.0_f64).to_bits());
    }

    /// A NaN operand makes the product NaN.
    #[test]
    fn test_complex_nan_propagation() {
        let nan_z = ComplexF64::new(f64::NAN, 1.0);
        let normal = ComplexF64::new(1.0, 1.0);
        let result = nan_z.mul_fixed(normal);
        assert!(result.is_nan());
    }

    /// `Display` prints `re`, a sign, and `im` followed by `i`.
    #[test]
    fn test_complex_display() {
        let z = ComplexF64::new(3.0, -4.0);
        assert_eq!(format!("{z}"), "3-4i");
        let z2 = ComplexF64::new(1.0, 2.0);
        assert_eq!(format!("{z2}"), "1+2i");
    }
}