oxiblas_ndarray/
blas.rs

1//! BLAS operations on ndarray types.
2//!
3//! This module provides BLAS Level 1, 2, and 3 operations directly on
4//! ndarray types, using OxiBLAS as the backend.
5
6use crate::conversions::array2_to_mat;
7use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder};
8use num_complex::{Complex32, Complex64};
9use oxiblas_blas::level1::{asum, axpy, dot, dotc_c32, dotc_c64, dotu_c32, dotu_c64, nrm2, scal};
10use oxiblas_blas::level2::{GemvTrans, gemv as blas_gemv};
11use oxiblas_blas::level3::{GemmKernel, gemm as blas_gemm};
12use oxiblas_core::scalar::Field;
13use oxiblas_matrix::Mat;
14
15// =============================================================================
16// BLAS Level 1: Vector-Vector Operations
17// =============================================================================
18
19/// Computes the dot product of two 1D arrays.
20///
21/// # Arguments
22/// * `x` - First vector
23/// * `y` - Second vector
24///
25/// # Returns
26/// The dot product x·y
27///
28/// # Panics
29/// Panics if vectors have different lengths.
30pub fn dot_ndarray<T: Field>(x: &Array1<T>, y: &Array1<T>) -> T {
31    assert_eq!(x.len(), y.len(), "Vector lengths must match");
32
33    // Try to get contiguous slices for efficient computation
34    if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
35        dot(x_slice, y_slice)
36    } else {
37        // Non-contiguous: convert to contiguous first
38        let x_vec: Vec<T> = x.iter().cloned().collect();
39        let y_vec: Vec<T> = y.iter().cloned().collect();
40        dot(&x_vec, &y_vec)
41    }
42}
43
44/// Computes the dot product of two array views.
45pub fn dot_view<T: Field>(x: &ArrayView1<T>, y: &ArrayView1<T>) -> T {
46    assert_eq!(x.len(), y.len(), "Vector lengths must match");
47
48    if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
49        dot(x_slice, y_slice)
50    } else {
51        let x_vec: Vec<T> = x.iter().cloned().collect();
52        let y_vec: Vec<T> = y.iter().cloned().collect();
53        dot(&x_vec, &y_vec)
54    }
55}
56
57// =============================================================================
58// Complex Dot Products
59// =============================================================================
60
61/// Computes the conjugate dot product of two Complex64 vectors (ZDOTC).
62///
63/// x^H · y = Σ conj(x\[i\]) * y\[i\]
64///
65/// This is the standard inner product for complex vector spaces.
66///
67/// # Arguments
68/// * `x` - First complex vector (will be conjugated)
69/// * `y` - Second complex vector
70///
71/// # Returns
72/// The conjugate dot product
73///
74/// # Panics
75/// Panics if vectors have different lengths.
76///
77/// # Example
78/// ```
79/// use oxiblas_ndarray::blas::dotc_c64_ndarray;
80/// use ndarray::array;
81/// use num_complex::Complex64;
82///
83/// let x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
84/// let y = array![Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)];
85/// let result = dotc_c64_ndarray(&x, &y);
86/// // conj(1+2i)*(5+6i) + conj(3+4i)*(7+8i) = (1-2i)(5+6i) + (3-4i)(7+8i)
87/// // = (5+12) + (6-10)i + (21+32) + (24-28)i = 70 - 8i
88/// assert!((result.re - 70.0).abs() < 1e-10);
89/// assert!((result.im - (-8.0)).abs() < 1e-10);
90/// ```
91pub fn dotc_c64_ndarray(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
92    assert_eq!(x.len(), y.len(), "Vector lengths must match");
93
94    if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
95        dotc_c64(x_slice, y_slice)
96    } else {
97        let x_vec: Vec<Complex64> = x.iter().copied().collect();
98        let y_vec: Vec<Complex64> = y.iter().copied().collect();
99        dotc_c64(&x_vec, &y_vec)
100    }
101}
102
103/// Computes the conjugate dot product of two Complex32 vectors (CDOTC).
104///
105/// x^H · y = Σ conj(x\[i\]) * y\[i\]
106///
107/// # Panics
108/// Panics if vectors have different lengths.
109pub fn dotc_c32_ndarray(x: &Array1<Complex32>, y: &Array1<Complex32>) -> Complex32 {
110    assert_eq!(x.len(), y.len(), "Vector lengths must match");
111
112    if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
113        dotc_c32(x_slice, y_slice)
114    } else {
115        let x_vec: Vec<Complex32> = x.iter().copied().collect();
116        let y_vec: Vec<Complex32> = y.iter().copied().collect();
117        dotc_c32(&x_vec, &y_vec)
118    }
119}
120
121/// Computes the unconjugated dot product of two Complex64 vectors (ZDOTU).
122///
123/// x · y = Σ x\[i\] * y\[i\]
124///
125/// Note: This is the bilinear form, not the standard inner product.
126/// For the standard inner product (sesquilinear), use `dotc_c64_ndarray`.
127///
128/// # Panics
129/// Panics if vectors have different lengths.
130pub fn dotu_c64_ndarray(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
131    assert_eq!(x.len(), y.len(), "Vector lengths must match");
132
133    if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
134        dotu_c64(x_slice, y_slice)
135    } else {
136        let x_vec: Vec<Complex64> = x.iter().copied().collect();
137        let y_vec: Vec<Complex64> = y.iter().copied().collect();
138        dotu_c64(&x_vec, &y_vec)
139    }
140}
141
142/// Computes the unconjugated dot product of two Complex32 vectors (CDOTU).
143///
144/// x · y = Σ x\[i\] * y\[i\]
145///
146/// # Panics
147/// Panics if vectors have different lengths.
148pub fn dotu_c32_ndarray(x: &Array1<Complex32>, y: &Array1<Complex32>) -> Complex32 {
149    assert_eq!(x.len(), y.len(), "Vector lengths must match");
150
151    if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
152        dotu_c32(x_slice, y_slice)
153    } else {
154        let x_vec: Vec<Complex32> = x.iter().copied().collect();
155        let y_vec: Vec<Complex32> = y.iter().copied().collect();
156        dotu_c32(&x_vec, &y_vec)
157    }
158}
159
160/// Computes the Euclidean norm of a Complex64 vector.
161///
162/// ||x||_2 = sqrt(Σ |x\[i\]|²) = sqrt(Σ (x\[i\].re² + x\[i\].im²))
163///
164/// This is equivalent to sqrt(x^H · x).
165pub fn nrm2_c64_ndarray(x: &Array1<Complex64>) -> f64 {
166    let mut sum = 0.0f64;
167    for xi in x.iter() {
168        sum += xi.norm_sqr();
169    }
170    sum.sqrt()
171}
172
173/// Computes the Euclidean norm of a Complex32 vector.
174///
175/// ||x||_2 = sqrt(Σ |x\[i\]|²)
176pub fn nrm2_c32_ndarray(x: &Array1<Complex32>) -> f32 {
177    let mut sum = 0.0f32;
178    for xi in x.iter() {
179        sum += xi.norm_sqr();
180    }
181    sum.sqrt()
182}
183
184/// Computes the L1 norm of a Complex64 vector (sum of absolute values).
185///
186/// ||x||_1 = Σ |x\[i\]|
187pub fn asum_c64_ndarray(x: &Array1<Complex64>) -> f64 {
188    let mut sum = 0.0f64;
189    for xi in x.iter() {
190        sum += xi.norm();
191    }
192    sum
193}
194
195/// Computes the L1 norm of a Complex32 vector (sum of absolute values).
196///
197/// ||x||_1 = Σ |x\[i\]|
198pub fn asum_c32_ndarray(x: &Array1<Complex32>) -> f32 {
199    let mut sum = 0.0f32;
200    for xi in x.iter() {
201        sum += xi.norm();
202    }
203    sum
204}
205
206/// Computes the Euclidean (L2) norm of a vector.
207///
208/// ||x||_2 = sqrt(sum(x_i^2))
209pub fn nrm2_ndarray<T: Field + oxiblas_core::scalar::Real>(x: &Array1<T>) -> T {
210    if let Some(slice) = x.as_slice() {
211        nrm2(slice)
212    } else {
213        let vec: Vec<T> = x.iter().cloned().collect();
214        nrm2(&vec)
215    }
216}
217
218/// Computes the L1 norm (sum of absolute values) of a vector.
219///
220/// ||x||_1 = sum(|x_i|)
221pub fn asum_ndarray<T: Field + oxiblas_core::scalar::Real>(x: &Array1<T>) -> T {
222    if let Some(slice) = x.as_slice() {
223        asum(slice)
224    } else {
225        let vec: Vec<T> = x.iter().cloned().collect();
226        asum(&vec)
227    }
228}
229
230/// Computes y = α·x + y (AXPY operation).
231///
232/// # Arguments
233/// * `alpha` - Scalar multiplier
234/// * `x` - Input vector
235/// * `y` - Output vector (modified in place)
236pub fn axpy_ndarray<T: Field>(alpha: T, x: &Array1<T>, y: &mut Array1<T>) {
237    assert_eq!(x.len(), y.len(), "Vector lengths must match");
238
239    if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice_mut()) {
240        axpy(alpha, x_slice, y_slice);
241    } else {
242        // Non-contiguous: element-wise
243        for (yi, xi) in y.iter_mut().zip(x.iter()) {
244            *yi = alpha * (*xi) + *yi;
245        }
246    }
247}
248
249/// Scales a vector: x = α·x
250pub fn scal_ndarray<T: Field>(alpha: T, x: &mut Array1<T>) {
251    if let Some(slice) = x.as_slice_mut() {
252        scal(alpha, slice);
253    } else {
254        for xi in x.iter_mut() {
255            *xi = alpha * (*xi);
256        }
257    }
258}
259
260// =============================================================================
261// BLAS Level 2: Matrix-Vector Operations
262// =============================================================================
263
/// Selects which form of the matrix `op(A)` a BLAS routine applies.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Transpose {
    /// Use `A` as stored.
    NoTrans,
    /// Use the transpose `Aᵀ`.
    Trans,
    /// Use the conjugate transpose `Aᴴ` (equals `Aᵀ` for real types).
    ConjTrans,
}
274
275impl From<Transpose> for GemvTrans {
276    fn from(t: Transpose) -> Self {
277        match t {
278            Transpose::NoTrans => GemvTrans::NoTrans,
279            Transpose::Trans => GemvTrans::Trans,
280            Transpose::ConjTrans => GemvTrans::ConjTrans,
281        }
282    }
283}
284
285/// General matrix-vector multiplication: y = α·op(A)·x + β·y
286///
287/// # Arguments
288/// * `trans` - Whether to transpose A
289/// * `alpha` - Scalar multiplier for A·x
290/// * `a` - The matrix (m×n)
291/// * `x` - Input vector
292/// * `beta` - Scalar multiplier for y
293/// * `y` - Output vector (modified in place)
294///
295/// # Panics
296/// Panics if dimensions don't match.
297pub fn gemv_ndarray<T: Field + Clone>(
298    trans: Transpose,
299    alpha: T,
300    a: &Array2<T>,
301    x: &Array1<T>,
302    beta: T,
303    y: &mut Array1<T>,
304) where
305    T: bytemuck::Zeroable,
306{
307    let a_mat = array2_to_mat(a);
308    let (m, n) = a.dim();
309
310    // Determine expected dimensions
311    let (x_len, y_len) = match trans {
312        Transpose::NoTrans => (n, m),
313        Transpose::Trans | Transpose::ConjTrans => (m, n),
314    };
315
316    assert_eq!(x.len(), x_len, "x dimension mismatch");
317    assert_eq!(y.len(), y_len, "y dimension mismatch");
318
319    // Convert vectors to slices
320    let x_vec: Vec<T> = x.iter().cloned().collect();
321
322    if let Some(y_slice) = y.as_slice_mut() {
323        blas_gemv(trans.into(), alpha, a_mat.as_ref(), &x_vec, beta, y_slice);
324    } else {
325        let mut y_vec: Vec<T> = y.iter().cloned().collect();
326        blas_gemv(
327            trans.into(),
328            alpha,
329            a_mat.as_ref(),
330            &x_vec,
331            beta,
332            &mut y_vec,
333        );
334        for (yi, val) in y.iter_mut().zip(y_vec.into_iter()) {
335            *yi = val;
336        }
337    }
338}
339
340/// Matrix-vector multiplication: y = A·x
341///
342/// Simplified version of gemv with alpha=1, beta=0.
343pub fn matvec<T: Field + Clone>(a: &Array2<T>, x: &Array1<T>) -> Array1<T>
344where
345    T: bytemuck::Zeroable,
346{
347    let (m, _n) = a.dim();
348    let mut y = Array1::zeros(m);
349    gemv_ndarray(Transpose::NoTrans, T::one(), a, x, T::zero(), &mut y);
350    y
351}
352
353/// Transposed matrix-vector multiplication: y = A^T·x
354pub fn matvec_t<T: Field + Clone>(a: &Array2<T>, x: &Array1<T>) -> Array1<T>
355where
356    T: bytemuck::Zeroable,
357{
358    let (_m, n) = a.dim();
359    let mut y = Array1::zeros(n);
360    gemv_ndarray(Transpose::Trans, T::one(), a, x, T::zero(), &mut y);
361    y
362}
363
364// =============================================================================
365// BLAS Level 3: Matrix-Matrix Operations
366// =============================================================================
367
368/// General matrix-matrix multiplication: C = α·A·B + β·C
369///
370/// # Arguments
371/// * `alpha` - Scalar multiplier for A·B
372/// * `a` - Left matrix (m×k)
373/// * `b` - Right matrix (k×n)
374/// * `beta` - Scalar multiplier for C
375/// * `c` - Output matrix (m×n), modified in place
376///
377/// # Panics
378/// Panics if matrix dimensions are incompatible.
379pub fn gemm_ndarray<T: Field + GemmKernel>(
380    alpha: T,
381    a: &Array2<T>,
382    b: &Array2<T>,
383    beta: T,
384    c: &mut Array2<T>,
385) where
386    T: bytemuck::Zeroable + Clone,
387{
388    let a_mat = array2_to_mat(a);
389    let b_mat = array2_to_mat(b);
390
391    let (m, n) = c.dim();
392    let mut c_mat: Mat<T> = Mat::zeros(m, n);
393
394    // Copy existing C values if beta != 0
395    if beta != T::zero() {
396        for i in 0..m {
397            for j in 0..n {
398                c_mat[(i, j)] = c[[i, j]];
399            }
400        }
401    }
402
403    blas_gemm(alpha, a_mat.as_ref(), b_mat.as_ref(), beta, c_mat.as_mut());
404
405    // Copy result back
406    for i in 0..m {
407        for j in 0..n {
408            c[[i, j]] = c_mat[(i, j)];
409        }
410    }
411}
412
413/// Matrix multiplication: C = A·B
414///
415/// Simplified version that allocates a new output matrix.
416pub fn matmul<T: Field + GemmKernel>(a: &Array2<T>, b: &Array2<T>) -> Array2<T>
417where
418    T: bytemuck::Zeroable + Clone,
419{
420    let (m, k1) = a.dim();
421    let (k2, n) = b.dim();
422    assert_eq!(k1, k2, "Inner dimensions must match: {} vs {}", k1, k2);
423
424    let a_mat = array2_to_mat(a);
425    let b_mat = array2_to_mat(b);
426    let mut c_mat: Mat<T> = Mat::zeros(m, n);
427
428    blas_gemm(
429        T::one(),
430        a_mat.as_ref(),
431        b_mat.as_ref(),
432        T::zero(),
433        c_mat.as_mut(),
434    );
435
436    // Create output in column-major order for efficiency
437    Array2::from_shape_fn((m, n).f(), |(i, j)| c_mat[(i, j)])
438}
439
440/// Matrix multiplication returning row-major output.
441pub fn matmul_c<T: Field + GemmKernel>(a: &Array2<T>, b: &Array2<T>) -> Array2<T>
442where
443    T: bytemuck::Zeroable + Clone,
444{
445    let (m, k1) = a.dim();
446    let (k2, n) = b.dim();
447    assert_eq!(k1, k2, "Inner dimensions must match");
448
449    let a_mat = array2_to_mat(a);
450    let b_mat = array2_to_mat(b);
451    let mut c_mat: Mat<T> = Mat::zeros(m, n);
452
453    blas_gemm(
454        T::one(),
455        a_mat.as_ref(),
456        b_mat.as_ref(),
457        T::zero(),
458        c_mat.as_mut(),
459    );
460
461    // Row-major output
462    Array2::from_shape_fn((m, n), |(i, j)| c_mat[(i, j)])
463}
464
465/// In-place matrix multiplication: C = A·B (C is reallocated)
466pub fn matmul_into<T: Field + GemmKernel>(a: &Array2<T>, b: &Array2<T>, c: &mut Array2<T>)
467where
468    T: bytemuck::Zeroable + Clone,
469{
470    gemm_ndarray(T::one(), a, b, T::zero(), c);
471}
472
473// =============================================================================
474// Matrix Norms
475// =============================================================================
476
477/// Computes the Frobenius norm of a matrix.
478///
479/// ||A||_F = sqrt(sum(a_ij^2))
480pub fn frobenius_norm<T: Field + oxiblas_core::scalar::Real>(a: &Array2<T>) -> T {
481    let mut sum = T::zero();
482    for val in a.iter() {
483        sum += (*val) * (*val);
484    }
485    oxiblas_core::scalar::Real::sqrt(sum)
486}
487
488/// Computes the 1-norm (maximum column sum) of a matrix.
489///
490/// For real types where `Real = T`.
491pub fn norm_1(a: &Array2<f64>) -> f64 {
492    let (nrows, ncols) = a.dim();
493    let mut max_sum = 0.0f64;
494
495    for j in 0..ncols {
496        let mut col_sum = 0.0f64;
497        for i in 0..nrows {
498            col_sum += a[[i, j]].abs();
499        }
500        if col_sum > max_sum {
501            max_sum = col_sum;
502        }
503    }
504
505    max_sum
506}
507
508/// Computes the infinity-norm (maximum row sum) of a matrix.
509///
510/// For real types where `Real = T`.
511pub fn norm_inf(a: &Array2<f64>) -> f64 {
512    let (nrows, ncols) = a.dim();
513    let mut max_sum = 0.0f64;
514
515    for i in 0..nrows {
516        let mut row_sum = 0.0f64;
517        for j in 0..ncols {
518            row_sum += a[[i, j]].abs();
519        }
520        if row_sum > max_sum {
521            max_sum = row_sum;
522        }
523    }
524
525    max_sum
526}
527
528/// Computes the maximum absolute element of a matrix.
529///
530/// For real types where `Real = T`.
531pub fn norm_max(a: &Array2<f64>) -> f64 {
532    let mut max_val = 0.0f64;
533    for val in a.iter() {
534        let abs_val = val.abs();
535        if abs_val > max_val {
536            max_val = abs_val;
537        }
538    }
539    max_val
540}
541
542// =============================================================================
543// Additional Operations
544// =============================================================================
545
546/// Computes the trace of a square matrix.
547pub fn trace<T: Field>(a: &Array2<T>) -> T {
548    let (nrows, ncols) = a.dim();
549    assert_eq!(nrows, ncols, "Matrix must be square for trace");
550
551    let mut sum = T::zero();
552    for i in 0..nrows {
553        sum += a[[i, i]];
554    }
555    sum
556}
557
558/// Transposes a matrix.
559pub fn transpose<T: Clone>(a: &Array2<T>) -> Array2<T> {
560    a.t().to_owned()
561}
562
563/// Creates an identity matrix.
564pub fn eye<T: Field>(n: usize) -> Array2<T>
565where
566    T: Clone,
567{
568    let mut result = Array2::zeros((n, n));
569    for i in 0..n {
570        result[[i, i]] = T::one();
571    }
572    result
573}
574
575/// Creates an identity matrix in column-major order.
576pub fn eye_f<T: Field>(n: usize) -> Array2<T>
577where
578    T: Clone,
579{
580    let mut result: Array2<T> = Array2::from_shape_fn((n, n).f(), |_| T::zero());
581    for i in 0..n {
582        result[[i, i]] = T::one();
583    }
584    result
585}
586
587// =============================================================================
588// Complex Matrix Operations
589// =============================================================================
590
591/// Computes the Hermitian (conjugate) transpose of a Complex64 matrix.
592///
593/// Returns A^H where (A^H)\[i,j\] = conj(A\[j,i\])
594///
595/// # Example
596/// ```
597/// use oxiblas_ndarray::blas::conj_transpose_c64;
598/// use ndarray::array;
599/// use num_complex::Complex64;
600///
601/// let a = array![
602///     [Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
603///     [Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)]
604/// ];
605/// let ah = conj_transpose_c64(&a);
606/// // ah[0,0] = conj(a[0,0]) = 1 - 2i
607/// assert!((ah[[0, 0]].re - 1.0).abs() < 1e-10);
608/// assert!((ah[[0, 0]].im - (-2.0)).abs() < 1e-10);
609/// // ah[0,1] = conj(a[1,0]) = 5 - 6i
610/// assert!((ah[[0, 1]].re - 5.0).abs() < 1e-10);
611/// assert!((ah[[0, 1]].im - (-6.0)).abs() < 1e-10);
612/// ```
613pub fn conj_transpose_c64(a: &Array2<Complex64>) -> Array2<Complex64> {
614    let (m, n) = a.dim();
615    Array2::from_shape_fn((n, m), |(i, j)| a[[j, i]].conj())
616}
617
618/// Computes the Hermitian (conjugate) transpose of a Complex32 matrix.
619///
620/// Returns A^H where (A^H)\[i,j\] = conj(A\[j,i\])
621pub fn conj_transpose_c32(a: &Array2<Complex32>) -> Array2<Complex32> {
622    let (m, n) = a.dim();
623    Array2::from_shape_fn((n, m), |(i, j)| a[[j, i]].conj())
624}
625
626/// Computes the Frobenius norm of a Complex64 matrix.
627///
628/// ||A||_F = sqrt(Σ |a\[i,j\]|²) = sqrt(Σ (a\[i,j\].re² + a\[i,j\].im²))
629///
630/// This is equivalent to sqrt(trace(A^H * A)).
631pub fn frobenius_norm_c64(a: &Array2<Complex64>) -> f64 {
632    let mut sum = 0.0f64;
633    for val in a.iter() {
634        sum += val.norm_sqr();
635    }
636    sum.sqrt()
637}
638
639/// Computes the Frobenius norm of a Complex32 matrix.
640///
641/// ||A||_F = sqrt(Σ |a\[i,j\]|²)
642pub fn frobenius_norm_c32(a: &Array2<Complex32>) -> f32 {
643    let mut sum = 0.0f32;
644    for val in a.iter() {
645        sum += val.norm_sqr();
646    }
647    sum.sqrt()
648}
649
650/// Computes the 1-norm (maximum column sum of absolute values) of a Complex64 matrix.
651///
652/// ||A||_1 = max_j Σ_i |a\[i,j\]|
653pub fn norm_1_c64(a: &Array2<Complex64>) -> f64 {
654    let (nrows, ncols) = a.dim();
655    let mut max_sum = 0.0f64;
656
657    for j in 0..ncols {
658        let mut col_sum = 0.0f64;
659        for i in 0..nrows {
660            col_sum += a[[i, j]].norm();
661        }
662        if col_sum > max_sum {
663            max_sum = col_sum;
664        }
665    }
666
667    max_sum
668}
669
670/// Computes the 1-norm of a Complex32 matrix.
671pub fn norm_1_c32(a: &Array2<Complex32>) -> f32 {
672    let (nrows, ncols) = a.dim();
673    let mut max_sum = 0.0f32;
674
675    for j in 0..ncols {
676        let mut col_sum = 0.0f32;
677        for i in 0..nrows {
678            col_sum += a[[i, j]].norm();
679        }
680        if col_sum > max_sum {
681            max_sum = col_sum;
682        }
683    }
684
685    max_sum
686}
687
688/// Computes the infinity-norm (maximum row sum of absolute values) of a Complex64 matrix.
689///
690/// ||A||_∞ = max_i Σ_j |a\[i,j\]|
691pub fn norm_inf_c64(a: &Array2<Complex64>) -> f64 {
692    let (nrows, ncols) = a.dim();
693    let mut max_sum = 0.0f64;
694
695    for i in 0..nrows {
696        let mut row_sum = 0.0f64;
697        for j in 0..ncols {
698            row_sum += a[[i, j]].norm();
699        }
700        if row_sum > max_sum {
701            max_sum = row_sum;
702        }
703    }
704
705    max_sum
706}
707
708/// Computes the infinity-norm of a Complex32 matrix.
709pub fn norm_inf_c32(a: &Array2<Complex32>) -> f32 {
710    let (nrows, ncols) = a.dim();
711    let mut max_sum = 0.0f32;
712
713    for i in 0..nrows {
714        let mut row_sum = 0.0f32;
715        for j in 0..ncols {
716            row_sum += a[[i, j]].norm();
717        }
718        if row_sum > max_sum {
719            max_sum = row_sum;
720        }
721    }
722
723    max_sum
724}
725
726/// Computes the maximum absolute element of a Complex64 matrix.
727///
728/// max |a\[i,j\]|
729pub fn norm_max_c64(a: &Array2<Complex64>) -> f64 {
730    let mut max_val = 0.0f64;
731    for val in a.iter() {
732        let abs_val = val.norm();
733        if abs_val > max_val {
734            max_val = abs_val;
735        }
736    }
737    max_val
738}
739
740/// Computes the maximum absolute element of a Complex32 matrix.
741pub fn norm_max_c32(a: &Array2<Complex32>) -> f32 {
742    let mut max_val = 0.0f32;
743    for val in a.iter() {
744        let abs_val = val.norm();
745        if abs_val > max_val {
746            max_val = abs_val;
747        }
748    }
749    max_val
750}
751
752/// Computes the trace of a Complex64 square matrix.
753pub fn trace_c64(a: &Array2<Complex64>) -> Complex64 {
754    let (nrows, ncols) = a.dim();
755    assert_eq!(nrows, ncols, "Matrix must be square for trace");
756
757    let mut sum = Complex64::new(0.0, 0.0);
758    for i in 0..nrows {
759        sum += a[[i, i]];
760    }
761    sum
762}
763
764/// Computes the trace of a Complex32 square matrix.
765pub fn trace_c32(a: &Array2<Complex32>) -> Complex32 {
766    let (nrows, ncols) = a.dim();
767    assert_eq!(nrows, ncols, "Matrix must be square for trace");
768
769    let mut sum = Complex32::new(0.0, 0.0);
770    for i in 0..nrows {
771        sum += a[[i, i]];
772    }
773    sum
774}
775
776/// Scales a Complex64 vector: x = α·x
777pub fn scal_c64_ndarray(alpha: Complex64, x: &mut Array1<Complex64>) {
778    for xi in x.iter_mut() {
779        *xi = alpha * (*xi);
780    }
781}
782
783/// Scales a Complex32 vector: x = α·x
784pub fn scal_c32_ndarray(alpha: Complex32, x: &mut Array1<Complex32>) {
785    for xi in x.iter_mut() {
786        *xi = alpha * (*xi);
787    }
788}
789
790/// AXPY operation for Complex64: y = α·x + y
791pub fn axpy_c64_ndarray(alpha: Complex64, x: &Array1<Complex64>, y: &mut Array1<Complex64>) {
792    assert_eq!(x.len(), y.len(), "Vector lengths must match");
793
794    for (yi, xi) in y.iter_mut().zip(x.iter()) {
795        *yi = alpha * (*xi) + *yi;
796    }
797}
798
799/// AXPY operation for Complex32: y = α·x + y
800pub fn axpy_c32_ndarray(alpha: Complex32, x: &Array1<Complex32>, y: &mut Array1<Complex32>) {
801    assert_eq!(x.len(), y.len(), "Vector lengths must match");
802
803    for (yi, xi) in y.iter_mut().zip(x.iter()) {
804        *yi = alpha * (*xi) + *yi;
805    }
806}
807
808/// Creates a Complex64 identity matrix.
809pub fn eye_c64(n: usize) -> Array2<Complex64> {
810    let mut result: Array2<Complex64> = Array2::from_elem((n, n), Complex64::new(0.0, 0.0));
811    for i in 0..n {
812        result[[i, i]] = Complex64::new(1.0, 0.0);
813    }
814    result
815}
816
817/// Creates a Complex32 identity matrix.
818pub fn eye_c32(n: usize) -> Array2<Complex32> {
819    let mut result: Array2<Complex32> = Array2::from_elem((n, n), Complex32::new(0.0, 0.0));
820    for i in 0..n {
821        result[[i, i]] = Complex32::new(1.0, 0.0);
822    }
823    result
824}
825
826#[cfg(test)]
827mod tests {
828    use super::*;
829    use ndarray::array;
830
831    #[test]
832    fn test_dot_ndarray() {
833        let x = array![1.0f64, 2.0, 3.0];
834        let y = array![4.0f64, 5.0, 6.0];
835        let d = dot_ndarray(&x, &y);
836        assert!((d - 32.0).abs() < 1e-10);
837    }
838
839    #[test]
840    fn test_nrm2_ndarray() {
841        let x = array![3.0f64, 4.0];
842        let norm = nrm2_ndarray(&x);
843        assert!((norm - 5.0).abs() < 1e-10);
844    }
845
846    #[test]
847    fn test_asum_ndarray() {
848        let x = array![-1.0f64, 2.0, -3.0];
849        let sum = asum_ndarray(&x);
850        assert!((sum - 6.0).abs() < 1e-10);
851    }
852
853    #[test]
854    fn test_axpy_ndarray() {
855        let x = array![1.0f64, 2.0, 3.0];
856        let mut y = array![4.0f64, 5.0, 6.0];
857        axpy_ndarray(2.0, &x, &mut y);
858        assert!((y[0] - 6.0).abs() < 1e-10);
859        assert!((y[1] - 9.0).abs() < 1e-10);
860        assert!((y[2] - 12.0).abs() < 1e-10);
861    }
862
863    #[test]
864    fn test_scal_ndarray() {
865        let mut x = array![1.0f64, 2.0, 3.0];
866        scal_ndarray(2.0, &mut x);
867        assert!((x[0] - 2.0).abs() < 1e-10);
868        assert!((x[1] - 4.0).abs() < 1e-10);
869        assert!((x[2] - 6.0).abs() < 1e-10);
870    }
871
872    #[test]
873    fn test_gemv_notrans() {
874        let a = Array2::from_shape_fn((2, 3), |(i, j)| (i * 3 + j + 1) as f64);
875        let x = array![1.0f64, 1.0, 1.0];
876        let mut y = array![0.0f64, 0.0];
877
878        gemv_ndarray(Transpose::NoTrans, 1.0, &a, &x, 0.0, &mut y);
879
880        // y[0] = 1 + 2 + 3 = 6
881        // y[1] = 4 + 5 + 6 = 15
882        assert!((y[0] - 6.0).abs() < 1e-10);
883        assert!((y[1] - 15.0).abs() < 1e-10);
884    }
885
886    #[test]
887    fn test_gemv_trans() {
888        let a = Array2::from_shape_fn((2, 3), |(i, j)| (i * 3 + j + 1) as f64);
889        let x = array![1.0f64, 1.0];
890        let mut y = array![0.0f64, 0.0, 0.0];
891
892        gemv_ndarray(Transpose::Trans, 1.0, &a, &x, 0.0, &mut y);
893
894        // y[0] = 1 + 4 = 5
895        // y[1] = 2 + 5 = 7
896        // y[2] = 3 + 6 = 9
897        assert!((y[0] - 5.0).abs() < 1e-10);
898        assert!((y[1] - 7.0).abs() < 1e-10);
899        assert!((y[2] - 9.0).abs() < 1e-10);
900    }
901
902    #[test]
903    fn test_matvec() {
904        let a = Array2::from_shape_fn((2, 3), |(i, j)| (i * 3 + j + 1) as f64);
905        let x = array![1.0f64, 2.0, 3.0];
906        let y = matvec(&a, &x);
907
908        // y[0] = 1*1 + 2*2 + 3*3 = 14
909        // y[1] = 4*1 + 5*2 + 6*3 = 32
910        assert!((y[0] - 14.0).abs() < 1e-10);
911        assert!((y[1] - 32.0).abs() < 1e-10);
912    }
913
914    #[test]
915    fn test_matmul() {
916        let a = Array2::from_shape_fn((2, 3), |_| 1.0f64);
917        let b = Array2::from_shape_fn((3, 2), |_| 2.0f64);
918        let c = matmul(&a, &b);
919
920        assert_eq!(c.dim(), (2, 2));
921        for i in 0..2 {
922            for j in 0..2 {
923                assert!((c[[i, j]] - 6.0).abs() < 1e-10);
924            }
925        }
926    }
927
928    #[test]
929    fn test_gemm_ndarray() {
930        let a = Array2::from_shape_fn((2, 3), |_| 1.0f64);
931        let b = Array2::from_shape_fn((3, 2), |_| 2.0f64);
932        let mut c = Array2::from_shape_fn((2, 2), |_| 1.0f64);
933
934        gemm_ndarray(1.0, &a, &b, 1.0, &mut c);
935
936        // C = 1 * A * B + 1 * C = 6 + 1 = 7
937        for i in 0..2 {
938            for j in 0..2 {
939                assert!((c[[i, j]] - 7.0).abs() < 1e-10);
940            }
941        }
942    }
943
944    #[test]
945    fn test_frobenius_norm() {
946        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
947        let norm = frobenius_norm(&a);
948        // sqrt(1 + 4 + 9 + 16) = sqrt(30)
949        assert!((norm - 30.0f64.sqrt()).abs() < 1e-10);
950    }
951
952    #[test]
953    fn test_norm_1() {
954        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
955        let norm = norm_1(&a);
956        // max(1+3, 2+4) = max(4, 6) = 6
957        assert!((norm - 6.0).abs() < 1e-10);
958    }
959
960    #[test]
961    fn test_norm_inf() {
962        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
963        let norm = norm_inf(&a);
964        // max(1+2, 3+4) = max(3, 7) = 7
965        assert!((norm - 7.0).abs() < 1e-10);
966    }
967
968    #[test]
969    fn test_trace() {
970        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
971        let tr = trace(&a);
972        assert!((tr - 5.0).abs() < 1e-10);
973    }
974
975    #[test]
976    fn test_eye() {
977        let id: Array2<f64> = eye(3);
978        for i in 0..3 {
979            for j in 0..3 {
980                if i == j {
981                    assert!((id[[i, j]] - 1.0).abs() < 1e-15);
982                } else {
983                    assert!(id[[i, j]].abs() < 1e-15);
984                }
985            }
986        }
987    }
988
989    #[test]
990    fn test_transpose() {
991        let a = array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
992        let at = transpose(&a);
993        assert_eq!(at.dim(), (3, 2));
994        assert!((at[[0, 0]] - 1.0).abs() < 1e-15);
995        assert!((at[[2, 1]] - 6.0).abs() < 1e-15);
996    }
997
998    // =========================================================================
999    // Complex Number Tests
1000    // =========================================================================
1001
1002    #[test]
1003    fn test_dotc_c64_ndarray() {
1004        // x = [1+2i, 3+4i], y = [5+6i, 7+8i]
1005        // conj(x) * y = (1-2i)(5+6i) + (3-4i)(7+8i)
1006        //             = (5+12) + (6-10)i + (21+32) + (24-28)i
1007        //             = 70 - 8i
1008        let x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1009        let y = array![Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)];
1010
1011        let result = dotc_c64_ndarray(&x, &y);
1012        assert!((result.re - 70.0).abs() < 1e-10);
1013        assert!((result.im - (-8.0)).abs() < 1e-10);
1014    }
1015
1016    #[test]
1017    fn test_dotc_c32_ndarray() {
1018        let x = array![Complex32::new(1.0, 2.0), Complex32::new(3.0, 4.0)];
1019        let y = array![Complex32::new(5.0, 6.0), Complex32::new(7.0, 8.0)];
1020
1021        let result = dotc_c32_ndarray(&x, &y);
1022        assert!((result.re - 70.0).abs() < 1e-5);
1023        assert!((result.im - (-8.0)).abs() < 1e-5);
1024    }
1025
1026    #[test]
1027    fn test_dotu_c64_ndarray() {
1028        // x = [1+2i, 3+4i], y = [5+6i, 7+8i]
1029        // x * y = (1+2i)(5+6i) + (3+4i)(7+8i)
1030        //       = (5-12) + (6+10)i + (21-32) + (24+28)i
1031        //       = -18 + 68i
1032        let x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1033        let y = array![Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)];
1034
1035        let result = dotu_c64_ndarray(&x, &y);
1036        assert!((result.re - (-18.0)).abs() < 1e-10);
1037        assert!((result.im - 68.0).abs() < 1e-10);
1038    }
1039
1040    #[test]
1041    fn test_dotu_c32_ndarray() {
1042        let x = array![Complex32::new(1.0, 2.0), Complex32::new(3.0, 4.0)];
1043        let y = array![Complex32::new(5.0, 6.0), Complex32::new(7.0, 8.0)];
1044
1045        let result = dotu_c32_ndarray(&x, &y);
1046        assert!((result.re - (-18.0)).abs() < 1e-5);
1047        assert!((result.im - 68.0).abs() < 1e-5);
1048    }
1049
1050    #[test]
1051    fn test_dotc_c64_self_inner_product() {
1052        // x^H * x should be real and equal to ||x||^2
1053        let x = array![
1054            Complex64::new(1.0, 2.0),
1055            Complex64::new(3.0, 4.0),
1056            Complex64::new(5.0, 6.0)
1057        ];
1058
1059        let result = dotc_c64_ndarray(&x, &x);
1060
1061        // Should be purely real
1062        assert!(result.im.abs() < 1e-10);
1063
1064        // Should equal sum of |x_i|^2 = (1+4) + (9+16) + (25+36) = 5 + 25 + 61 = 91
1065        assert!((result.re - 91.0).abs() < 1e-10);
1066    }
1067
1068    #[test]
1069    fn test_nrm2_c64_ndarray() {
1070        // ||x||_2 = sqrt(sum(|x_i|^2))
1071        let x = array![Complex64::new(3.0, 4.0)]; // |3+4i| = 5
1072        let norm = nrm2_c64_ndarray(&x);
1073        assert!((norm - 5.0).abs() < 1e-10);
1074
1075        let x = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)];
1076        let norm = nrm2_c64_ndarray(&x);
1077        // sqrt(1 + 1) = sqrt(2)
1078        assert!((norm - 2.0f64.sqrt()).abs() < 1e-10);
1079    }
1080
1081    #[test]
1082    fn test_nrm2_c32_ndarray() {
1083        let x = array![Complex32::new(3.0, 4.0)];
1084        let norm = nrm2_c32_ndarray(&x);
1085        assert!((norm - 5.0).abs() < 1e-5);
1086    }
1087
1088    #[test]
1089    fn test_asum_c64_ndarray() {
1090        // sum of |x_i|
1091        let x = array![Complex64::new(3.0, 4.0), Complex64::new(5.0, 12.0)];
1092        // |3+4i| = 5, |5+12i| = 13
1093        let sum = asum_c64_ndarray(&x);
1094        assert!((sum - 18.0).abs() < 1e-10);
1095    }
1096
1097    #[test]
1098    fn test_asum_c32_ndarray() {
1099        let x = array![Complex32::new(3.0, 4.0), Complex32::new(5.0, 12.0)];
1100        let sum = asum_c32_ndarray(&x);
1101        assert!((sum - 18.0).abs() < 1e-5);
1102    }
1103
1104    #[test]
1105    fn test_conj_transpose_c64() {
1106        let a = array![
1107            [Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
1108            [Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)]
1109        ];
1110
1111        let ah = conj_transpose_c64(&a);
1112        assert_eq!(ah.dim(), (2, 2));
1113
1114        // ah[0,0] = conj(a[0,0]) = 1-2i
1115        assert!((ah[[0, 0]].re - 1.0).abs() < 1e-10);
1116        assert!((ah[[0, 0]].im - (-2.0)).abs() < 1e-10);
1117
1118        // ah[0,1] = conj(a[1,0]) = 5-6i
1119        assert!((ah[[0, 1]].re - 5.0).abs() < 1e-10);
1120        assert!((ah[[0, 1]].im - (-6.0)).abs() < 1e-10);
1121
1122        // ah[1,0] = conj(a[0,1]) = 3-4i
1123        assert!((ah[[1, 0]].re - 3.0).abs() < 1e-10);
1124        assert!((ah[[1, 0]].im - (-4.0)).abs() < 1e-10);
1125
1126        // ah[1,1] = conj(a[1,1]) = 7-8i
1127        assert!((ah[[1, 1]].re - 7.0).abs() < 1e-10);
1128        assert!((ah[[1, 1]].im - (-8.0)).abs() < 1e-10);
1129    }
1130
1131    #[test]
1132    fn test_conj_transpose_c64_rectangular() {
1133        let a = array![
1134            [
1135                Complex64::new(1.0, 1.0),
1136                Complex64::new(2.0, 2.0),
1137                Complex64::new(3.0, 3.0)
1138            ],
1139            [
1140                Complex64::new(4.0, 4.0),
1141                Complex64::new(5.0, 5.0),
1142                Complex64::new(6.0, 6.0)
1143            ]
1144        ];
1145
1146        let ah = conj_transpose_c64(&a);
1147        assert_eq!(ah.dim(), (3, 2));
1148
1149        // ah[2,1] = conj(a[1,2]) = 6-6i
1150        assert!((ah[[2, 1]].re - 6.0).abs() < 1e-10);
1151        assert!((ah[[2, 1]].im - (-6.0)).abs() < 1e-10);
1152    }
1153
1154    #[test]
1155    fn test_frobenius_norm_c64() {
1156        let a = array![
1157            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
1158            [Complex64::new(0.0, 1.0), Complex64::new(1.0, 0.0)]
1159        ];
1160        // |1|^2 + |i|^2 + |i|^2 + |1|^2 = 1 + 1 + 1 + 1 = 4
1161        let norm = frobenius_norm_c64(&a);
1162        assert!((norm - 2.0).abs() < 1e-10);
1163    }
1164
1165    #[test]
1166    fn test_frobenius_norm_c32() {
1167        let a = array![
1168            [Complex32::new(3.0, 4.0)] // |3+4i| = 5, |3+4i|^2 = 25
1169        ];
1170        let norm = frobenius_norm_c32(&a);
1171        assert!((norm - 5.0).abs() < 1e-5);
1172    }
1173
1174    #[test]
1175    fn test_norm_1_c64() {
1176        let a = array![
1177            [Complex64::new(3.0, 4.0), Complex64::new(0.0, 1.0)],
1178            [Complex64::new(0.0, 0.0), Complex64::new(5.0, 12.0)]
1179        ];
1180        // col 0: |3+4i| + |0| = 5 + 0 = 5
1181        // col 1: |i| + |5+12i| = 1 + 13 = 14
1182        let norm = norm_1_c64(&a);
1183        assert!((norm - 14.0).abs() < 1e-10);
1184    }
1185
1186    #[test]
1187    fn test_norm_inf_c64() {
1188        let a = array![
1189            [Complex64::new(3.0, 4.0), Complex64::new(0.0, 1.0)],
1190            [Complex64::new(0.0, 0.0), Complex64::new(5.0, 12.0)]
1191        ];
1192        // row 0: |3+4i| + |i| = 5 + 1 = 6
1193        // row 1: |0| + |5+12i| = 0 + 13 = 13
1194        let norm = norm_inf_c64(&a);
1195        assert!((norm - 13.0).abs() < 1e-10);
1196    }
1197
1198    #[test]
1199    fn test_norm_max_c64() {
1200        let a = array![
1201            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 4.0)],
1202            [Complex64::new(5.0, 12.0), Complex64::new(0.0, 1.0)]
1203        ];
1204        // max(|1|, |3+4i|, |5+12i|, |i|) = max(1, 5, 13, 1) = 13
1205        let max = norm_max_c64(&a);
1206        assert!((max - 13.0).abs() < 1e-10);
1207    }
1208
1209    #[test]
1210    fn test_trace_c64() {
1211        let a = array![
1212            [Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
1213            [Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)]
1214        ];
1215        // trace = (1+2i) + (7+8i) = 8 + 10i
1216        let tr = trace_c64(&a);
1217        assert!((tr.re - 8.0).abs() < 1e-10);
1218        assert!((tr.im - 10.0).abs() < 1e-10);
1219    }
1220
1221    #[test]
1222    fn test_trace_c32() {
1223        let a = array![
1224            [Complex32::new(1.0, 2.0), Complex32::new(3.0, 4.0)],
1225            [Complex32::new(5.0, 6.0), Complex32::new(7.0, 8.0)]
1226        ];
1227        let tr = trace_c32(&a);
1228        assert!((tr.re - 8.0).abs() < 1e-5);
1229        assert!((tr.im - 10.0).abs() < 1e-5);
1230    }
1231
1232    #[test]
1233    fn test_scal_c64_ndarray() {
1234        let mut x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1235        let alpha = Complex64::new(2.0, 0.0);
1236        scal_c64_ndarray(alpha, &mut x);
1237
1238        assert!((x[0].re - 2.0).abs() < 1e-10);
1239        assert!((x[0].im - 4.0).abs() < 1e-10);
1240        assert!((x[1].re - 6.0).abs() < 1e-10);
1241        assert!((x[1].im - 8.0).abs() < 1e-10);
1242    }
1243
1244    #[test]
1245    fn test_scal_c64_ndarray_complex_alpha() {
1246        let mut x = array![Complex64::new(1.0, 0.0)];
1247        let alpha = Complex64::new(0.0, 1.0); // i
1248        scal_c64_ndarray(alpha, &mut x);
1249
1250        // i * 1 = i
1251        assert!((x[0].re - 0.0).abs() < 1e-10);
1252        assert!((x[0].im - 1.0).abs() < 1e-10);
1253    }
1254
1255    #[test]
1256    fn test_axpy_c64_ndarray() {
1257        let x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1258        let mut y = array![Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)];
1259        let alpha = Complex64::new(2.0, 0.0);
1260
1261        axpy_c64_ndarray(alpha, &x, &mut y);
1262
1263        // y = 2*x + y = 2*(1+2i) + (5+6i) = (2+4i) + (5+6i) = 7+10i
1264        assert!((y[0].re - 7.0).abs() < 1e-10);
1265        assert!((y[0].im - 10.0).abs() < 1e-10);
1266
1267        // y = 2*(3+4i) + (7+8i) = (6+8i) + (7+8i) = 13+16i
1268        assert!((y[1].re - 13.0).abs() < 1e-10);
1269        assert!((y[1].im - 16.0).abs() < 1e-10);
1270    }
1271
1272    #[test]
1273    fn test_axpy_c32_ndarray() {
1274        let x = array![Complex32::new(1.0, 2.0)];
1275        let mut y = array![Complex32::new(3.0, 4.0)];
1276        let alpha = Complex32::new(0.0, 1.0); // i
1277
1278        axpy_c32_ndarray(alpha, &x, &mut y);
1279
1280        // y = i*(1+2i) + (3+4i) = (i - 2) + (3+4i) = (1) + (5i)
1281        assert!((y[0].re - 1.0).abs() < 1e-5);
1282        assert!((y[0].im - 5.0).abs() < 1e-5);
1283    }
1284
1285    #[test]
1286    fn test_eye_c64() {
1287        let id = eye_c64(3);
1288        assert_eq!(id.dim(), (3, 3));
1289
1290        for i in 0..3 {
1291            for j in 0..3 {
1292                if i == j {
1293                    assert!((id[[i, j]].re - 1.0).abs() < 1e-10);
1294                    assert!(id[[i, j]].im.abs() < 1e-10);
1295                } else {
1296                    assert!(id[[i, j]].re.abs() < 1e-10);
1297                    assert!(id[[i, j]].im.abs() < 1e-10);
1298                }
1299            }
1300        }
1301    }
1302
1303    #[test]
1304    fn test_eye_c32() {
1305        let id = eye_c32(2);
1306        assert_eq!(id.dim(), (2, 2));
1307        assert!((id[[0, 0]].re - 1.0).abs() < 1e-5);
1308        assert!((id[[1, 1]].re - 1.0).abs() < 1e-5);
1309        assert!(id[[0, 1]].re.abs() < 1e-5);
1310        assert!(id[[1, 0]].re.abs() < 1e-5);
1311    }
1312
1313    #[test]
1314    fn test_dotc_c64_large() {
1315        // Test with larger arrays to verify SIMD path
1316        let n = 1000;
1317        let x: Array1<Complex64> =
1318            Array1::from_shape_fn(n, |i| Complex64::new(i as f64, (i as f64) * 0.5));
1319        let y: Array1<Complex64> =
1320            Array1::from_shape_fn(n, |i| Complex64::new(1.0, 0.1 * i as f64));
1321
1322        let result = dotc_c64_ndarray(&x, &y);
1323
1324        // Verify against manual computation
1325        let expected: Complex64 = x.iter().zip(y.iter()).map(|(xi, yi)| xi.conj() * yi).sum();
1326        assert!((result.re - expected.re).abs() < 1e-6);
1327        assert!((result.im - expected.im).abs() < 1e-6);
1328    }
1329
1330    #[test]
1331    fn test_dotu_c64_large() {
1332        let n = 1000;
1333        let x: Array1<Complex64> = Array1::from_shape_fn(n, |i| {
1334            Complex64::new((i % 100) as f64, ((i + 50) % 100) as f64)
1335        });
1336        let y: Array1<Complex64> = Array1::from_shape_fn(n, |i| {
1337            Complex64::new(((i + 25) % 100) as f64, ((i + 75) % 100) as f64)
1338        });
1339
1340        let result = dotu_c64_ndarray(&x, &y);
1341
1342        let expected: Complex64 = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
1343        assert!((result.re - expected.re).abs() < 1e-6);
1344        assert!((result.im - expected.im).abs() < 1e-6);
1345    }
1346
1347    #[test]
1348    fn test_hermitian_property() {
1349        // For a Hermitian matrix A = A^H, verify property holds
1350        let a = array![
1351            [Complex64::new(2.0, 0.0), Complex64::new(1.0, 1.0)],
1352            [Complex64::new(1.0, -1.0), Complex64::new(3.0, 0.0)]
1353        ];
1354
1355        let ah = conj_transpose_c64(&a);
1356
1357        // A should equal A^H for Hermitian matrix
1358        for i in 0..2 {
1359            for j in 0..2 {
1360                assert!((a[[i, j]].re - ah[[i, j]].re).abs() < 1e-10);
1361                assert!((a[[i, j]].im - ah[[i, j]].im).abs() < 1e-10);
1362            }
1363        }
1364    }
1365}