1use crate::conversions::array2_to_mat;
7use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder};
8use num_complex::{Complex32, Complex64};
9use oxiblas_blas::level1::{asum, axpy, dot, dotc_c32, dotc_c64, dotu_c32, dotu_c64, nrm2, scal};
10use oxiblas_blas::level2::{GemvTrans, gemv as blas_gemv};
11use oxiblas_blas::level3::{GemmKernel, gemm as blas_gemm};
12use oxiblas_core::scalar::Field;
13use oxiblas_matrix::Mat;
14
15pub fn dot_ndarray<T: Field>(x: &Array1<T>, y: &Array1<T>) -> T {
31 assert_eq!(x.len(), y.len(), "Vector lengths must match");
32
33 if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
35 dot(x_slice, y_slice)
36 } else {
37 let x_vec: Vec<T> = x.iter().cloned().collect();
39 let y_vec: Vec<T> = y.iter().cloned().collect();
40 dot(&x_vec, &y_vec)
41 }
42}
43
44pub fn dot_view<T: Field>(x: &ArrayView1<T>, y: &ArrayView1<T>) -> T {
46 assert_eq!(x.len(), y.len(), "Vector lengths must match");
47
48 if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
49 dot(x_slice, y_slice)
50 } else {
51 let x_vec: Vec<T> = x.iter().cloned().collect();
52 let y_vec: Vec<T> = y.iter().cloned().collect();
53 dot(&x_vec, &y_vec)
54 }
55}
56
57pub fn dotc_c64_ndarray(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
92 assert_eq!(x.len(), y.len(), "Vector lengths must match");
93
94 if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
95 dotc_c64(x_slice, y_slice)
96 } else {
97 let x_vec: Vec<Complex64> = x.iter().copied().collect();
98 let y_vec: Vec<Complex64> = y.iter().copied().collect();
99 dotc_c64(&x_vec, &y_vec)
100 }
101}
102
103pub fn dotc_c32_ndarray(x: &Array1<Complex32>, y: &Array1<Complex32>) -> Complex32 {
110 assert_eq!(x.len(), y.len(), "Vector lengths must match");
111
112 if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
113 dotc_c32(x_slice, y_slice)
114 } else {
115 let x_vec: Vec<Complex32> = x.iter().copied().collect();
116 let y_vec: Vec<Complex32> = y.iter().copied().collect();
117 dotc_c32(&x_vec, &y_vec)
118 }
119}
120
121pub fn dotu_c64_ndarray(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
131 assert_eq!(x.len(), y.len(), "Vector lengths must match");
132
133 if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
134 dotu_c64(x_slice, y_slice)
135 } else {
136 let x_vec: Vec<Complex64> = x.iter().copied().collect();
137 let y_vec: Vec<Complex64> = y.iter().copied().collect();
138 dotu_c64(&x_vec, &y_vec)
139 }
140}
141
142pub fn dotu_c32_ndarray(x: &Array1<Complex32>, y: &Array1<Complex32>) -> Complex32 {
149 assert_eq!(x.len(), y.len(), "Vector lengths must match");
150
151 if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice()) {
152 dotu_c32(x_slice, y_slice)
153 } else {
154 let x_vec: Vec<Complex32> = x.iter().copied().collect();
155 let y_vec: Vec<Complex32> = y.iter().copied().collect();
156 dotu_c32(&x_vec, &y_vec)
157 }
158}
159
160pub fn nrm2_c64_ndarray(x: &Array1<Complex64>) -> f64 {
166 let mut sum = 0.0f64;
167 for xi in x.iter() {
168 sum += xi.norm_sqr();
169 }
170 sum.sqrt()
171}
172
173pub fn nrm2_c32_ndarray(x: &Array1<Complex32>) -> f32 {
177 let mut sum = 0.0f32;
178 for xi in x.iter() {
179 sum += xi.norm_sqr();
180 }
181 sum.sqrt()
182}
183
184pub fn asum_c64_ndarray(x: &Array1<Complex64>) -> f64 {
188 let mut sum = 0.0f64;
189 for xi in x.iter() {
190 sum += xi.norm();
191 }
192 sum
193}
194
195pub fn asum_c32_ndarray(x: &Array1<Complex32>) -> f32 {
199 let mut sum = 0.0f32;
200 for xi in x.iter() {
201 sum += xi.norm();
202 }
203 sum
204}
205
206pub fn nrm2_ndarray<T: Field + oxiblas_core::scalar::Real>(x: &Array1<T>) -> T {
210 if let Some(slice) = x.as_slice() {
211 nrm2(slice)
212 } else {
213 let vec: Vec<T> = x.iter().cloned().collect();
214 nrm2(&vec)
215 }
216}
217
218pub fn asum_ndarray<T: Field + oxiblas_core::scalar::Real>(x: &Array1<T>) -> T {
222 if let Some(slice) = x.as_slice() {
223 asum(slice)
224 } else {
225 let vec: Vec<T> = x.iter().cloned().collect();
226 asum(&vec)
227 }
228}
229
230pub fn axpy_ndarray<T: Field>(alpha: T, x: &Array1<T>, y: &mut Array1<T>) {
237 assert_eq!(x.len(), y.len(), "Vector lengths must match");
238
239 if let (Some(x_slice), Some(y_slice)) = (x.as_slice(), y.as_slice_mut()) {
240 axpy(alpha, x_slice, y_slice);
241 } else {
242 for (yi, xi) in y.iter_mut().zip(x.iter()) {
244 *yi = alpha * (*xi) + *yi;
245 }
246 }
247}
248
249pub fn scal_ndarray<T: Field>(alpha: T, x: &mut Array1<T>) {
251 if let Some(slice) = x.as_slice_mut() {
252 scal(alpha, slice);
253 } else {
254 for xi in x.iter_mut() {
255 *xi = alpha * (*xi);
256 }
257 }
258}
259
/// Selects which form `op(A)` of a matrix `A` a routine applies.
///
/// Converted into the oxiblas level-2 flag by the `From` impl below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Transpose {
    /// `op(A) = A`, the matrix as stored.
    NoTrans,
    /// `op(A) = Aᵀ`, the plain transpose.
    Trans,
    /// `op(A) = Aᴴ`, the conjugate transpose (mathematically equal to `Aᵀ`
    /// for real scalars).
    ConjTrans,
}
274
275impl From<Transpose> for GemvTrans {
276 fn from(t: Transpose) -> Self {
277 match t {
278 Transpose::NoTrans => GemvTrans::NoTrans,
279 Transpose::Trans => GemvTrans::Trans,
280 Transpose::ConjTrans => GemvTrans::ConjTrans,
281 }
282 }
283}
284
285pub fn gemv_ndarray<T: Field + Clone>(
298 trans: Transpose,
299 alpha: T,
300 a: &Array2<T>,
301 x: &Array1<T>,
302 beta: T,
303 y: &mut Array1<T>,
304) where
305 T: bytemuck::Zeroable,
306{
307 let a_mat = array2_to_mat(a);
308 let (m, n) = a.dim();
309
310 let (x_len, y_len) = match trans {
312 Transpose::NoTrans => (n, m),
313 Transpose::Trans | Transpose::ConjTrans => (m, n),
314 };
315
316 assert_eq!(x.len(), x_len, "x dimension mismatch");
317 assert_eq!(y.len(), y_len, "y dimension mismatch");
318
319 let x_vec: Vec<T> = x.iter().cloned().collect();
321
322 if let Some(y_slice) = y.as_slice_mut() {
323 blas_gemv(trans.into(), alpha, a_mat.as_ref(), &x_vec, beta, y_slice);
324 } else {
325 let mut y_vec: Vec<T> = y.iter().cloned().collect();
326 blas_gemv(
327 trans.into(),
328 alpha,
329 a_mat.as_ref(),
330 &x_vec,
331 beta,
332 &mut y_vec,
333 );
334 for (yi, val) in y.iter_mut().zip(y_vec.into_iter()) {
335 *yi = val;
336 }
337 }
338}
339
340pub fn matvec<T: Field + Clone>(a: &Array2<T>, x: &Array1<T>) -> Array1<T>
344where
345 T: bytemuck::Zeroable,
346{
347 let (m, _n) = a.dim();
348 let mut y = Array1::zeros(m);
349 gemv_ndarray(Transpose::NoTrans, T::one(), a, x, T::zero(), &mut y);
350 y
351}
352
353pub fn matvec_t<T: Field + Clone>(a: &Array2<T>, x: &Array1<T>) -> Array1<T>
355where
356 T: bytemuck::Zeroable,
357{
358 let (_m, n) = a.dim();
359 let mut y = Array1::zeros(n);
360 gemv_ndarray(Transpose::Trans, T::one(), a, x, T::zero(), &mut y);
361 y
362}
363
364pub fn gemm_ndarray<T: Field + GemmKernel>(
380 alpha: T,
381 a: &Array2<T>,
382 b: &Array2<T>,
383 beta: T,
384 c: &mut Array2<T>,
385) where
386 T: bytemuck::Zeroable + Clone,
387{
388 let a_mat = array2_to_mat(a);
389 let b_mat = array2_to_mat(b);
390
391 let (m, n) = c.dim();
392 let mut c_mat: Mat<T> = Mat::zeros(m, n);
393
394 if beta != T::zero() {
396 for i in 0..m {
397 for j in 0..n {
398 c_mat[(i, j)] = c[[i, j]];
399 }
400 }
401 }
402
403 blas_gemm(alpha, a_mat.as_ref(), b_mat.as_ref(), beta, c_mat.as_mut());
404
405 for i in 0..m {
407 for j in 0..n {
408 c[[i, j]] = c_mat[(i, j)];
409 }
410 }
411}
412
413pub fn matmul<T: Field + GemmKernel>(a: &Array2<T>, b: &Array2<T>) -> Array2<T>
417where
418 T: bytemuck::Zeroable + Clone,
419{
420 let (m, k1) = a.dim();
421 let (k2, n) = b.dim();
422 assert_eq!(k1, k2, "Inner dimensions must match: {} vs {}", k1, k2);
423
424 let a_mat = array2_to_mat(a);
425 let b_mat = array2_to_mat(b);
426 let mut c_mat: Mat<T> = Mat::zeros(m, n);
427
428 blas_gemm(
429 T::one(),
430 a_mat.as_ref(),
431 b_mat.as_ref(),
432 T::zero(),
433 c_mat.as_mut(),
434 );
435
436 Array2::from_shape_fn((m, n).f(), |(i, j)| c_mat[(i, j)])
438}
439
440pub fn matmul_c<T: Field + GemmKernel>(a: &Array2<T>, b: &Array2<T>) -> Array2<T>
442where
443 T: bytemuck::Zeroable + Clone,
444{
445 let (m, k1) = a.dim();
446 let (k2, n) = b.dim();
447 assert_eq!(k1, k2, "Inner dimensions must match");
448
449 let a_mat = array2_to_mat(a);
450 let b_mat = array2_to_mat(b);
451 let mut c_mat: Mat<T> = Mat::zeros(m, n);
452
453 blas_gemm(
454 T::one(),
455 a_mat.as_ref(),
456 b_mat.as_ref(),
457 T::zero(),
458 c_mat.as_mut(),
459 );
460
461 Array2::from_shape_fn((m, n), |(i, j)| c_mat[(i, j)])
463}
464
465pub fn matmul_into<T: Field + GemmKernel>(a: &Array2<T>, b: &Array2<T>, c: &mut Array2<T>)
467where
468 T: bytemuck::Zeroable + Clone,
469{
470 gemm_ndarray(T::one(), a, b, T::zero(), c);
471}
472
473pub fn frobenius_norm<T: Field + oxiblas_core::scalar::Real>(a: &Array2<T>) -> T {
481 let mut sum = T::zero();
482 for val in a.iter() {
483 sum += (*val) * (*val);
484 }
485 oxiblas_core::scalar::Real::sqrt(sum)
486}
487
488pub fn norm_1(a: &Array2<f64>) -> f64 {
492 let (nrows, ncols) = a.dim();
493 let mut max_sum = 0.0f64;
494
495 for j in 0..ncols {
496 let mut col_sum = 0.0f64;
497 for i in 0..nrows {
498 col_sum += a[[i, j]].abs();
499 }
500 if col_sum > max_sum {
501 max_sum = col_sum;
502 }
503 }
504
505 max_sum
506}
507
508pub fn norm_inf(a: &Array2<f64>) -> f64 {
512 let (nrows, ncols) = a.dim();
513 let mut max_sum = 0.0f64;
514
515 for i in 0..nrows {
516 let mut row_sum = 0.0f64;
517 for j in 0..ncols {
518 row_sum += a[[i, j]].abs();
519 }
520 if row_sum > max_sum {
521 max_sum = row_sum;
522 }
523 }
524
525 max_sum
526}
527
528pub fn norm_max(a: &Array2<f64>) -> f64 {
532 let mut max_val = 0.0f64;
533 for val in a.iter() {
534 let abs_val = val.abs();
535 if abs_val > max_val {
536 max_val = abs_val;
537 }
538 }
539 max_val
540}
541
542pub fn trace<T: Field>(a: &Array2<T>) -> T {
548 let (nrows, ncols) = a.dim();
549 assert_eq!(nrows, ncols, "Matrix must be square for trace");
550
551 let mut sum = T::zero();
552 for i in 0..nrows {
553 sum += a[[i, i]];
554 }
555 sum
556}
557
558pub fn transpose<T: Clone>(a: &Array2<T>) -> Array2<T> {
560 a.t().to_owned()
561}
562
563pub fn eye<T: Field>(n: usize) -> Array2<T>
565where
566 T: Clone,
567{
568 let mut result = Array2::zeros((n, n));
569 for i in 0..n {
570 result[[i, i]] = T::one();
571 }
572 result
573}
574
575pub fn eye_f<T: Field>(n: usize) -> Array2<T>
577where
578 T: Clone,
579{
580 let mut result: Array2<T> = Array2::from_shape_fn((n, n).f(), |_| T::zero());
581 for i in 0..n {
582 result[[i, i]] = T::one();
583 }
584 result
585}
586
587pub fn conj_transpose_c64(a: &Array2<Complex64>) -> Array2<Complex64> {
614 let (m, n) = a.dim();
615 Array2::from_shape_fn((n, m), |(i, j)| a[[j, i]].conj())
616}
617
618pub fn conj_transpose_c32(a: &Array2<Complex32>) -> Array2<Complex32> {
622 let (m, n) = a.dim();
623 Array2::from_shape_fn((n, m), |(i, j)| a[[j, i]].conj())
624}
625
626pub fn frobenius_norm_c64(a: &Array2<Complex64>) -> f64 {
632 let mut sum = 0.0f64;
633 for val in a.iter() {
634 sum += val.norm_sqr();
635 }
636 sum.sqrt()
637}
638
639pub fn frobenius_norm_c32(a: &Array2<Complex32>) -> f32 {
643 let mut sum = 0.0f32;
644 for val in a.iter() {
645 sum += val.norm_sqr();
646 }
647 sum.sqrt()
648}
649
650pub fn norm_1_c64(a: &Array2<Complex64>) -> f64 {
654 let (nrows, ncols) = a.dim();
655 let mut max_sum = 0.0f64;
656
657 for j in 0..ncols {
658 let mut col_sum = 0.0f64;
659 for i in 0..nrows {
660 col_sum += a[[i, j]].norm();
661 }
662 if col_sum > max_sum {
663 max_sum = col_sum;
664 }
665 }
666
667 max_sum
668}
669
670pub fn norm_1_c32(a: &Array2<Complex32>) -> f32 {
672 let (nrows, ncols) = a.dim();
673 let mut max_sum = 0.0f32;
674
675 for j in 0..ncols {
676 let mut col_sum = 0.0f32;
677 for i in 0..nrows {
678 col_sum += a[[i, j]].norm();
679 }
680 if col_sum > max_sum {
681 max_sum = col_sum;
682 }
683 }
684
685 max_sum
686}
687
688pub fn norm_inf_c64(a: &Array2<Complex64>) -> f64 {
692 let (nrows, ncols) = a.dim();
693 let mut max_sum = 0.0f64;
694
695 for i in 0..nrows {
696 let mut row_sum = 0.0f64;
697 for j in 0..ncols {
698 row_sum += a[[i, j]].norm();
699 }
700 if row_sum > max_sum {
701 max_sum = row_sum;
702 }
703 }
704
705 max_sum
706}
707
708pub fn norm_inf_c32(a: &Array2<Complex32>) -> f32 {
710 let (nrows, ncols) = a.dim();
711 let mut max_sum = 0.0f32;
712
713 for i in 0..nrows {
714 let mut row_sum = 0.0f32;
715 for j in 0..ncols {
716 row_sum += a[[i, j]].norm();
717 }
718 if row_sum > max_sum {
719 max_sum = row_sum;
720 }
721 }
722
723 max_sum
724}
725
726pub fn norm_max_c64(a: &Array2<Complex64>) -> f64 {
730 let mut max_val = 0.0f64;
731 for val in a.iter() {
732 let abs_val = val.norm();
733 if abs_val > max_val {
734 max_val = abs_val;
735 }
736 }
737 max_val
738}
739
740pub fn norm_max_c32(a: &Array2<Complex32>) -> f32 {
742 let mut max_val = 0.0f32;
743 for val in a.iter() {
744 let abs_val = val.norm();
745 if abs_val > max_val {
746 max_val = abs_val;
747 }
748 }
749 max_val
750}
751
752pub fn trace_c64(a: &Array2<Complex64>) -> Complex64 {
754 let (nrows, ncols) = a.dim();
755 assert_eq!(nrows, ncols, "Matrix must be square for trace");
756
757 let mut sum = Complex64::new(0.0, 0.0);
758 for i in 0..nrows {
759 sum += a[[i, i]];
760 }
761 sum
762}
763
764pub fn trace_c32(a: &Array2<Complex32>) -> Complex32 {
766 let (nrows, ncols) = a.dim();
767 assert_eq!(nrows, ncols, "Matrix must be square for trace");
768
769 let mut sum = Complex32::new(0.0, 0.0);
770 for i in 0..nrows {
771 sum += a[[i, i]];
772 }
773 sum
774}
775
776pub fn scal_c64_ndarray(alpha: Complex64, x: &mut Array1<Complex64>) {
778 for xi in x.iter_mut() {
779 *xi = alpha * (*xi);
780 }
781}
782
783pub fn scal_c32_ndarray(alpha: Complex32, x: &mut Array1<Complex32>) {
785 for xi in x.iter_mut() {
786 *xi = alpha * (*xi);
787 }
788}
789
790pub fn axpy_c64_ndarray(alpha: Complex64, x: &Array1<Complex64>, y: &mut Array1<Complex64>) {
792 assert_eq!(x.len(), y.len(), "Vector lengths must match");
793
794 for (yi, xi) in y.iter_mut().zip(x.iter()) {
795 *yi = alpha * (*xi) + *yi;
796 }
797}
798
799pub fn axpy_c32_ndarray(alpha: Complex32, x: &Array1<Complex32>, y: &mut Array1<Complex32>) {
801 assert_eq!(x.len(), y.len(), "Vector lengths must match");
802
803 for (yi, xi) in y.iter_mut().zip(x.iter()) {
804 *yi = alpha * (*xi) + *yi;
805 }
806}
807
808pub fn eye_c64(n: usize) -> Array2<Complex64> {
810 let mut result: Array2<Complex64> = Array2::from_elem((n, n), Complex64::new(0.0, 0.0));
811 for i in 0..n {
812 result[[i, i]] = Complex64::new(1.0, 0.0);
813 }
814 result
815}
816
817pub fn eye_c32(n: usize) -> Array2<Complex32> {
819 let mut result: Array2<Complex32> = Array2::from_elem((n, n), Complex32::new(0.0, 0.0));
820 for i in 0..n {
821 result[[i, i]] = Complex32::new(1.0, 0.0);
822 }
823 result
824}
825
#[cfg(test)]
mod tests {
    // Unit tests for the ndarray wrappers. Expected values are worked out by
    // hand in the comments. Besides the contiguous fast paths, several tests
    // use reversed (`invert_axis`) or strided (`s![..;2]`) inputs, so the
    // copy-based fallback branches are exercised too.
    use super::*;
    use ndarray::{array, s, Axis};

    const TOL64: f64 = 1e-10;
    const TOL32: f32 = 1e-5;

    /// Asserts `|actual - expected| < TOL64`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL64,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Asserts `|actual - expected| < TOL32`.
    fn assert_close32(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOL32,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Asserts both components of a `Complex64` against `re + im·i`.
    fn assert_c64(actual: Complex64, re: f64, im: f64) {
        assert_close(actual.re, re);
        assert_close(actual.im, im);
    }

    /// Asserts both components of a `Complex32` against `re + im·i`.
    fn assert_c32(actual: Complex32, re: f32, im: f32) {
        assert_close32(actual.re, re);
        assert_close32(actual.im, im);
    }

    /// Asserts a real vector matches `expected` element-wise, in logical order.
    fn assert_vec_close(actual: &Array1<f64>, expected: &[f64]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (&got, &want) in actual.iter().zip(expected) {
            assert_close(got, want);
        }
    }

    /// Asserts a real matrix matches `expected` element-wise, in logical order.
    fn assert_mat_close(actual: &Array2<f64>, expected: &Array2<f64>) {
        assert_eq!(actual.dim(), expected.dim(), "shape mismatch");
        for (&got, &want) in actual.iter().zip(expected.iter()) {
            assert_close(got, want);
        }
    }

    /// The 2 × 3 matrix [[1, 2, 3], [4, 5, 6]].
    fn a_2x3() -> Array2<f64> {
        Array2::from_shape_fn((2, 3), |(i, j)| (i * 3 + j + 1) as f64)
    }

    // ----- real level-1 wrappers -----

    #[test]
    fn test_dot_ndarray() {
        let x = array![1.0f64, 2.0, 3.0];
        let y = array![4.0f64, 5.0, 6.0];
        // 4 + 10 + 18
        assert_close(dot_ndarray(&x, &y), 32.0);
    }

    #[test]
    fn test_dot_view_contiguous_and_strided() {
        let x = array![1.0f64, 2.0, 3.0, 4.0];
        let y = array![5.0f64, 6.0, 7.0, 8.0];
        // 5 + 12 + 21 + 32
        assert_close(dot_view(&x.view(), &y.view()), 70.0);
        // Every other element (non-contiguous views): 1*5 + 3*7
        assert_close(dot_view(&x.slice(s![..;2]), &y.slice(s![..;2])), 26.0);
    }

    #[test]
    fn test_nrm2_ndarray() {
        assert_close(nrm2_ndarray(&array![3.0f64, 4.0]), 5.0);
    }

    #[test]
    fn test_asum_ndarray() {
        assert_close(asum_ndarray(&array![-1.0f64, 2.0, -3.0]), 6.0);
    }

    #[test]
    fn test_axpy_ndarray() {
        let x = array![1.0f64, 2.0, 3.0];
        let mut y = array![4.0f64, 5.0, 6.0];
        axpy_ndarray(2.0, &x, &mut y);
        assert_vec_close(&y, &[6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_axpy_ndarray_non_contiguous() {
        let x = array![1.0f64, 2.0, 3.0];
        let mut y = array![4.0f64, 5.0, 6.0];
        // Reversing the axis gives a negative stride, so `as_slice` is None and
        // the element-wise fallback runs. Logical contents: [6, 5, 4].
        y.invert_axis(Axis(0));
        assert!(y.as_slice().is_none());
        axpy_ndarray(2.0, &x, &mut y);
        assert_vec_close(&y, &[8.0, 9.0, 10.0]);
    }

    #[test]
    fn test_scal_ndarray() {
        let mut x = array![1.0f64, 2.0, 3.0];
        scal_ndarray(2.0, &mut x);
        assert_vec_close(&x, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_scal_ndarray_non_contiguous() {
        let mut x = array![1.0f64, 2.0, 3.0];
        x.invert_axis(Axis(0)); // logical [3, 2, 1]
        scal_ndarray(2.0, &mut x);
        assert_vec_close(&x, &[6.0, 4.0, 2.0]);
    }

    // ----- level-2 wrappers -----

    #[test]
    fn test_gemv_notrans() {
        let x = array![1.0f64, 1.0, 1.0];
        let mut y = array![0.0f64, 0.0];
        gemv_ndarray(Transpose::NoTrans, 1.0, &a_2x3(), &x, 0.0, &mut y);
        // Row sums: 1+2+3, 4+5+6
        assert_vec_close(&y, &[6.0, 15.0]);
    }

    #[test]
    fn test_gemv_trans() {
        let x = array![1.0f64, 1.0];
        let mut y = array![0.0f64, 0.0, 0.0];
        gemv_ndarray(Transpose::Trans, 1.0, &a_2x3(), &x, 0.0, &mut y);
        // Column sums: 1+4, 2+5, 3+6
        assert_vec_close(&y, &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_gemv_non_contiguous() {
        let mut x = array![3.0f64, 2.0, 1.0];
        x.invert_axis(Axis(0)); // logical [1, 2, 3]
        let mut y = array![0.0f64, 0.0];
        y.invert_axis(Axis(0));
        gemv_ndarray(Transpose::NoTrans, 1.0, &a_2x3(), &x, 0.0, &mut y);
        // [1*1 + 2*2 + 3*3, 4*1 + 5*2 + 6*3]
        assert_vec_close(&y, &[14.0, 32.0]);
    }

    #[test]
    fn test_matvec() {
        let y = matvec(&a_2x3(), &array![1.0f64, 2.0, 3.0]);
        assert_vec_close(&y, &[14.0, 32.0]);
    }

    #[test]
    fn test_matvec_t() {
        let y = matvec_t(&a_2x3(), &array![1.0f64, 2.0]);
        // [1+8, 2+10, 3+12]
        assert_vec_close(&y, &[9.0, 12.0, 15.0]);
    }

    // ----- level-3 wrappers -----

    #[test]
    fn test_matmul() {
        let a = Array2::from_elem((2, 3), 1.0f64);
        let b = Array2::from_elem((3, 2), 2.0f64);
        let c = matmul(&a, &b);
        assert_eq!(c.dim(), (2, 2));
        assert_mat_close(&c, &Array2::from_elem((2, 2), 6.0));
    }

    #[test]
    fn test_matmul_layouts() {
        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
        let b = array![[5.0f64, 6.0], [7.0, 8.0]];
        let expected = array![[19.0f64, 22.0], [43.0, 50.0]];

        let f_order = matmul(&a, &b);
        let c_order = matmul_c(&a, &b);
        assert_mat_close(&f_order, &expected);
        assert_mat_close(&c_order, &expected);
        // `matmul` builds a column-major result, `matmul_c` a row-major one.
        assert!(f_order.t().is_standard_layout());
        assert!(c_order.is_standard_layout());
    }

    #[test]
    fn test_gemm_ndarray() {
        let a = Array2::from_elem((2, 3), 1.0f64);
        let b = Array2::from_elem((3, 2), 2.0f64);
        let mut c = Array2::from_elem((2, 2), 1.0f64);
        // 1 * (A·B) + 1 * C = 6 + 1
        gemm_ndarray(1.0, &a, &b, 1.0, &mut c);
        assert_mat_close(&c, &Array2::from_elem((2, 2), 7.0));
    }

    #[test]
    fn test_matmul_into_overwrites_c() {
        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
        let b = array![[5.0f64, 6.0], [7.0, 8.0]];
        // With beta = 0, the previous contents of C must not leak into the result.
        let mut c = Array2::from_elem((2, 2), f64::NAN);
        matmul_into(&a, &b, &mut c);
        assert_mat_close(&c, &array![[19.0, 22.0], [43.0, 50.0]]);
    }

    // ----- real matrix helpers -----

    #[test]
    fn test_frobenius_norm() {
        // sqrt(1 + 4 + 9 + 16)
        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
        assert_close(frobenius_norm(&a), 30.0f64.sqrt());
    }

    #[test]
    fn test_norm_1() {
        // Column sums 4 and 6.
        assert_close(norm_1(&array![[1.0f64, 2.0], [3.0, 4.0]]), 6.0);
    }

    #[test]
    fn test_norm_inf() {
        // Row sums 3 and 7.
        assert_close(norm_inf(&array![[1.0f64, 2.0], [3.0, 4.0]]), 7.0);
    }

    #[test]
    fn test_norm_max() {
        assert_close(norm_max(&array![[1.0f64, -5.0], [3.0, 4.0]]), 5.0);
    }

    #[test]
    fn test_trace() {
        assert_close(trace(&array![[1.0f64, 2.0], [3.0, 4.0]]), 5.0);
    }

    #[test]
    fn test_trace_empty() {
        assert_eq!(trace(&Array2::<f64>::zeros((0, 0))), 0.0);
    }

    #[test]
    fn test_eye() {
        let id: Array2<f64> = eye(3);
        assert_eq!(id.dim(), (3, 3));
        for ((i, j), &v) in id.indexed_iter() {
            assert_eq!(v, if i == j { 1.0 } else { 0.0 });
        }
    }

    #[test]
    fn test_eye_f() {
        let id: Array2<f64> = eye_f(3);
        assert_eq!(id, eye::<f64>(3));
        // Column-major storage: its transpose is in standard (row-major) layout.
        assert!(id.t().is_standard_layout());
    }

    #[test]
    fn test_transpose() {
        let at = transpose(&array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert_eq!(at.dim(), (3, 2));
        assert_eq!(at[[0, 0]], 1.0);
        assert_eq!(at[[2, 1]], 6.0);
    }

    // ----- complex level-1 wrappers -----

    #[test]
    fn test_dotc_c64_ndarray() {
        let x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
        let y = array![Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)];
        // conj(1+2i)(5+6i) + conj(3+4i)(7+8i) = (17-4i) + (53-4i)
        assert_c64(dotc_c64_ndarray(&x, &y), 70.0, -8.0);
    }

    #[test]
    fn test_dotc_c32_ndarray() {
        let x = array![Complex32::new(1.0, 2.0), Complex32::new(3.0, 4.0)];
        let y = array![Complex32::new(5.0, 6.0), Complex32::new(7.0, 8.0)];
        assert_c32(dotc_c32_ndarray(&x, &y), 70.0, -8.0);
    }

    #[test]
    fn test_dotu_c64_ndarray() {
        let x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
        let y = array![Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)];
        // (1+2i)(5+6i) + (3+4i)(7+8i) = (-7+16i) + (-11+52i)
        assert_c64(dotu_c64_ndarray(&x, &y), -18.0, 68.0);
    }

    #[test]
    fn test_dotu_c32_ndarray() {
        let x = array![Complex32::new(1.0, 2.0), Complex32::new(3.0, 4.0)];
        let y = array![Complex32::new(5.0, 6.0), Complex32::new(7.0, 8.0)];
        assert_c32(dotu_c32_ndarray(&x, &y), -18.0, 68.0);
    }

    #[test]
    fn test_dotu_c64_non_contiguous() {
        let mut x = array![Complex64::new(3.0, 4.0), Complex64::new(1.0, 2.0)];
        x.invert_axis(Axis(0)); // logical [1+2i, 3+4i]
        let y = array![Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)];
        assert_c64(dotu_c64_ndarray(&x, &y), -18.0, 68.0);
    }

    #[test]
    fn test_dotc_c64_self_inner_product() {
        let x = array![
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, 4.0),
            Complex64::new(5.0, 6.0)
        ];
        // <x, x> = Σ|x_i|² is real: 5 + 25 + 61
        assert_c64(dotc_c64_ndarray(&x, &x), 91.0, 0.0);
    }

    #[test]
    fn test_dotc_c64_large() {
        let n = 1000;
        let x: Array1<Complex64> =
            Array1::from_shape_fn(n, |i| Complex64::new(i as f64, (i as f64) * 0.5));
        let y: Array1<Complex64> =
            Array1::from_shape_fn(n, |i| Complex64::new(1.0, 0.1 * i as f64));

        let result = dotc_c64_ndarray(&x, &y);

        let expected: Complex64 = x.iter().zip(y.iter()).map(|(xi, yi)| xi.conj() * yi).sum();
        assert!((result.re - expected.re).abs() < 1e-6);
        assert!((result.im - expected.im).abs() < 1e-6);
    }

    #[test]
    fn test_dotu_c64_large() {
        let n = 1000;
        let x: Array1<Complex64> = Array1::from_shape_fn(n, |i| {
            Complex64::new((i % 100) as f64, ((i + 50) % 100) as f64)
        });
        let y: Array1<Complex64> = Array1::from_shape_fn(n, |i| {
            Complex64::new(((i + 25) % 100) as f64, ((i + 75) % 100) as f64)
        });

        let result = dotu_c64_ndarray(&x, &y);

        let expected: Complex64 = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
        assert!((result.re - expected.re).abs() < 1e-6);
        assert!((result.im - expected.im).abs() < 1e-6);
    }

    #[test]
    fn test_nrm2_c64_ndarray() {
        assert_close(nrm2_c64_ndarray(&array![Complex64::new(3.0, 4.0)]), 5.0);

        let x = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)];
        assert_close(nrm2_c64_ndarray(&x), 2.0f64.sqrt());
    }

    #[test]
    fn test_nrm2_c32_ndarray() {
        assert_close32(nrm2_c32_ndarray(&array![Complex32::new(3.0, 4.0)]), 5.0);
    }

    #[test]
    fn test_asum_c64_ndarray() {
        // Moduli 5 and 13.
        let x = array![Complex64::new(3.0, 4.0), Complex64::new(5.0, 12.0)];
        assert_close(asum_c64_ndarray(&x), 18.0);
    }

    #[test]
    fn test_asum_c32_ndarray() {
        let x = array![Complex32::new(3.0, 4.0), Complex32::new(5.0, 12.0)];
        assert_close32(asum_c32_ndarray(&x), 18.0);
    }

    #[test]
    fn test_scal_c64_ndarray() {
        let mut x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
        scal_c64_ndarray(Complex64::new(2.0, 0.0), &mut x);
        assert_c64(x[0], 2.0, 4.0);
        assert_c64(x[1], 6.0, 8.0);
    }

    #[test]
    fn test_scal_c64_ndarray_complex_alpha() {
        let mut x = array![Complex64::new(1.0, 0.0)];
        // i * 1 = i
        scal_c64_ndarray(Complex64::new(0.0, 1.0), &mut x);
        assert_c64(x[0], 0.0, 1.0);
    }

    #[test]
    fn test_scal_c32_ndarray() {
        let mut x = array![Complex32::new(1.0, -1.0)];
        // 3 * (1 - i)
        scal_c32_ndarray(Complex32::new(3.0, 0.0), &mut x);
        assert_c32(x[0], 3.0, -3.0);
    }

    #[test]
    fn test_axpy_c64_ndarray() {
        let x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
        let mut y = array![Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)];
        axpy_c64_ndarray(Complex64::new(2.0, 0.0), &x, &mut y);
        assert_c64(y[0], 7.0, 10.0);
        assert_c64(y[1], 13.0, 16.0);
    }

    #[test]
    fn test_axpy_c32_ndarray() {
        let x = array![Complex32::new(1.0, 2.0)];
        let mut y = array![Complex32::new(3.0, 4.0)];
        // i * (1+2i) + (3+4i) = (-2+i) + (3+4i)
        axpy_c32_ndarray(Complex32::new(0.0, 1.0), &x, &mut y);
        assert_c32(y[0], 1.0, 5.0);
    }

    // ----- complex matrix helpers -----

    #[test]
    fn test_conj_transpose_c64() {
        let a = array![
            [Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
            [Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)]
        ];
        let ah = conj_transpose_c64(&a);
        assert_eq!(ah.dim(), (2, 2));
        assert_c64(ah[[0, 0]], 1.0, -2.0);
        assert_c64(ah[[0, 1]], 5.0, -6.0);
        assert_c64(ah[[1, 0]], 3.0, -4.0);
        assert_c64(ah[[1, 1]], 7.0, -8.0);
    }

    #[test]
    fn test_conj_transpose_c64_rectangular() {
        let a = array![
            [
                Complex64::new(1.0, 1.0),
                Complex64::new(2.0, 2.0),
                Complex64::new(3.0, 3.0)
            ],
            [
                Complex64::new(4.0, 4.0),
                Complex64::new(5.0, 5.0),
                Complex64::new(6.0, 6.0)
            ]
        ];
        let ah = conj_transpose_c64(&a);
        assert_eq!(ah.dim(), (3, 2));
        assert_c64(ah[[2, 1]], 6.0, -6.0);
    }

    #[test]
    fn test_conj_transpose_c32() {
        let a = array![[Complex32::new(1.0, 2.0), Complex32::new(3.0, 4.0)]];
        let ah = conj_transpose_c32(&a);
        assert_eq!(ah.dim(), (2, 1));
        assert_c32(ah[[0, 0]], 1.0, -2.0);
        assert_c32(ah[[1, 0]], 3.0, -4.0);
    }

    #[test]
    fn test_hermitian_property() {
        // A Hermitian matrix equals its own conjugate transpose.
        let a = array![
            [Complex64::new(2.0, 0.0), Complex64::new(1.0, 1.0)],
            [Complex64::new(1.0, -1.0), Complex64::new(3.0, 0.0)]
        ];
        let ah = conj_transpose_c64(&a);
        for (&orig, &herm) in a.iter().zip(ah.iter()) {
            assert_c64(herm, orig.re, orig.im);
        }
    }

    #[test]
    fn test_frobenius_norm_c64() {
        let a = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
            [Complex64::new(0.0, 1.0), Complex64::new(1.0, 0.0)]
        ];
        // sqrt(1 + 1 + 1 + 1)
        assert_close(frobenius_norm_c64(&a), 2.0);
    }

    #[test]
    fn test_frobenius_norm_c32() {
        let a = array![[Complex32::new(3.0, 4.0)]];
        assert_close32(frobenius_norm_c32(&a), 5.0);
    }

    #[test]
    fn test_norms_c64() {
        let a = array![
            [Complex64::new(3.0, 4.0), Complex64::new(0.0, 1.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(5.0, 12.0)]
        ];
        // Column sums of moduli: 5 + 0, 1 + 13.
        assert_close(norm_1_c64(&a), 14.0);
        // Row sums of moduli: 5 + 1, 0 + 13.
        assert_close(norm_inf_c64(&a), 13.0);
    }

    #[test]
    fn test_norm_max_c64() {
        let a = array![
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 4.0)],
            [Complex64::new(5.0, 12.0), Complex64::new(0.0, 1.0)]
        ];
        assert_close(norm_max_c64(&a), 13.0);
    }

    #[test]
    fn test_norms_c32() {
        let a = array![
            [Complex32::new(3.0, 4.0), Complex32::new(0.0, 1.0)],
            [Complex32::new(0.0, 0.0), Complex32::new(5.0, 12.0)]
        ];
        assert_close32(norm_1_c32(&a), 14.0);
        assert_close32(norm_inf_c32(&a), 13.0);
        assert_close32(norm_max_c32(&a), 13.0);
    }

    #[test]
    fn test_trace_c64() {
        let a = array![
            [Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
            [Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)]
        ];
        // (1+2i) + (7+8i)
        assert_c64(trace_c64(&a), 8.0, 10.0);
    }

    #[test]
    fn test_trace_c32() {
        let a = array![
            [Complex32::new(1.0, 2.0), Complex32::new(3.0, 4.0)],
            [Complex32::new(5.0, 6.0), Complex32::new(7.0, 8.0)]
        ];
        assert_c32(trace_c32(&a), 8.0, 10.0);
    }

    #[test]
    fn test_eye_c64() {
        let id = eye_c64(3);
        assert_eq!(id.dim(), (3, 3));
        for ((i, j), &z) in id.indexed_iter() {
            assert_c64(z, if i == j { 1.0 } else { 0.0 }, 0.0);
        }
    }

    #[test]
    fn test_eye_c32() {
        let id = eye_c32(2);
        assert_eq!(id.dim(), (2, 2));
        for ((i, j), &z) in id.indexed_iter() {
            assert_c32(z, if i == j { 1.0 } else { 0.0 }, 0.0);
        }
    }
}
1365}