1use crate::imp_prelude::*;
10
11#[cfg(feature = "blas")]
12use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
13use crate::numeric_util;
14use crate::ArrayRef1;
15use crate::ArrayRef2;
16
17use crate::{LinalgScalar, Zip};
18
19#[cfg(not(feature = "std"))]
20use alloc::vec;
21#[cfg(not(feature = "std"))]
22use alloc::vec::Vec;
23
24use std::any::TypeId;
25use std::mem::MaybeUninit;
26
27use num_complex::Complex;
28use num_complex::{Complex32 as c32, Complex64 as c64};
29
30#[cfg(feature = "blas")]
31use libc::c_int;
32
33#[cfg(feature = "blas")]
34use cblas_sys as blas_sys;
35#[cfg(feature = "blas")]
36use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE};
37
/// Minimum vector length at which the 1-D dot product tries BLAS
/// (`cblas_sdot`/`cblas_ddot`); shorter vectors always use the generic code.
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
/// Matrix multiplication is only considered for BLAS `gemm` when at least one
/// of the dimensions `m`, `n` or `k` is larger than this value.
#[cfg(feature = "blas")]
const GEMM_BLAS_CUTOFF: usize = 7;
44#[cfg(feature = "blas")]
45#[allow(non_camel_case_types)]
46type blas_index = c_int; impl<A> ArrayRef<A, Ix1>
49{
50 #[track_caller]
68 pub fn dot<Rhs: ?Sized>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
69 where Self: Dot<Rhs>
70 {
71 Dot::dot(self, rhs)
72 }
73
74 fn dot_generic(&self, rhs: &ArrayRef<A, Ix1>) -> A
75 where A: LinalgScalar
76 {
77 debug_assert_eq!(self.len(), rhs.len());
78 assert!(self.len() == rhs.len());
79 if let Some(self_s) = self.as_slice() {
80 if let Some(rhs_s) = rhs.as_slice() {
81 return numeric_util::unrolled_dot(self_s, rhs_s);
82 }
83 }
84 let mut sum = A::zero();
85 for i in 0..self.len() {
86 unsafe {
87 sum = sum + *self.uget(i) * *rhs.uget(i);
88 }
89 }
90 sum
91 }
92
93 #[cfg(not(feature = "blas"))]
94 fn dot_impl(&self, rhs: &ArrayRef<A, Ix1>) -> A
95 where A: LinalgScalar
96 {
97 self.dot_generic(rhs)
98 }
99
100 #[cfg(feature = "blas")]
101 fn dot_impl(&self, rhs: &ArrayRef<A, Ix1>) -> A
102 where A: LinalgScalar
103 {
104 if self.len() >= DOT_BLAS_CUTOFF {
106 debug_assert_eq!(self.len(), rhs.len());
107 assert!(self.len() == rhs.len());
108 macro_rules! dot {
109 ($ty:ty, $func:ident) => {{
110 if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
111 unsafe {
112 let (lhs_ptr, n, incx) =
113 blas_1d_params(self._ptr().as_ptr(), self.len(), self.strides()[0]);
114 let (rhs_ptr, _, incy) =
115 blas_1d_params(rhs._ptr().as_ptr(), rhs.len(), rhs.strides()[0]);
116 let ret = blas_sys::$func(
117 n,
118 lhs_ptr as *const $ty,
119 incx,
120 rhs_ptr as *const $ty,
121 incy,
122 );
123 return cast_as::<$ty, A>(&ret);
124 }
125 }
126 }};
127 }
128
129 dot! {f32, cblas_sdot};
130 dot! {f64, cblas_ddot};
131 }
132 self.dot_generic(rhs)
133 }
134}
135
/// Convert an ndarray 1-D description (logical first-element pointer, length,
/// stride) into the `(pointer, n, inc)` triple handed to CBLAS.
///
/// CBLAS expects the pointer to the element with the lowest address. For a
/// non-negative stride, or an empty vector, that is `ptr` itself. For a
/// negative stride it is the logically last element, `len - 1` strides away.
///
/// # Safety
///
/// `ptr`, `len` and `stride` must describe a valid array, so that the pointer
/// offset stays inside the same allocation.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index)
{
    //  [x x x x]
    //         ^--ptr (stride = -1)
    //   ^--low address = ptr + (len - 1) * stride
    let low_addr = if stride < 0 && len > 0 {
        ptr.offset((len - 1) as isize * stride)
    } else {
        ptr
    };
    (low_addr, len as blas_index, stride as blas_index)
}
155
/// Matrix multiplication and dot product of arrays.
///
/// In this module it is implemented for one- and two-dimensional array
/// operands (both `ArrayRef` and `ArrayBase`), and for `IxDyn` arrays whose
/// runtime rank is 1 or 2.
pub trait Dot<Rhs: ?Sized>
{
    /// The result of the operation.
    ///
    /// For the implementations in this module: a scalar for vector · vector,
    /// a one-dimensional array for matrix · vector and vector · matrix, a
    /// two-dimensional array for matrix · matrix, and a dynamic-dimensional
    /// array for the `IxDyn` case.
    type Output;

    /// Compute the dot product or matrix product of `self` and `rhs`.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}
172
173macro_rules! impl_dots {
174 (
175 $shape1:ty,
176 $shape2:ty
177 ) => {
178 impl<A, S, S2> Dot<ArrayBase<S2, $shape2>> for ArrayBase<S, $shape1>
179 where
180 S: Data<Elem = A>,
181 S2: Data<Elem = A>,
182 A: LinalgScalar,
183 {
184 type Output = <ArrayRef<A, $shape1> as Dot<ArrayRef<A, $shape2>>>::Output;
185
186 fn dot(&self, rhs: &ArrayBase<S2, $shape2>) -> Self::Output
187 {
188 Dot::dot(&**self, &**rhs)
189 }
190 }
191
192 impl<A, S> Dot<ArrayRef<A, $shape2>> for ArrayBase<S, $shape1>
193 where
194 S: Data<Elem = A>,
195 A: LinalgScalar,
196 {
197 type Output = <ArrayRef<A, $shape1> as Dot<ArrayRef<A, $shape2>>>::Output;
198
199 fn dot(&self, rhs: &ArrayRef<A, $shape2>) -> Self::Output
200 {
201 (**self).dot(rhs)
202 }
203 }
204
205 impl<A, S> Dot<ArrayBase<S, $shape2>> for ArrayRef<A, $shape1>
206 where
207 S: Data<Elem = A>,
208 A: LinalgScalar,
209 {
210 type Output = <ArrayRef<A, $shape1> as Dot<ArrayRef<A, $shape2>>>::Output;
211
212 fn dot(&self, rhs: &ArrayBase<S, $shape2>) -> Self::Output
213 {
214 self.dot(&**rhs)
215 }
216 }
217 };
218}
219
220impl_dots!(Ix1, Ix1);
221impl_dots!(Ix1, Ix2);
222impl_dots!(Ix2, Ix1);
223impl_dots!(Ix2, Ix2);
224
225impl<A> Dot<ArrayRef<A, Ix1>> for ArrayRef<A, Ix1>
226where A: LinalgScalar
227{
228 type Output = A;
229
230 #[track_caller]
239 fn dot(&self, rhs: &ArrayRef<A, Ix1>) -> A
240 {
241 self.dot_impl(rhs)
242 }
243}
244
245impl<A> Dot<ArrayRef<A, Ix2>> for ArrayRef<A, Ix1>
246where A: LinalgScalar
247{
248 type Output = Array<A, Ix1>;
249
250 #[track_caller]
260 fn dot(&self, rhs: &ArrayRef<A, Ix2>) -> Array<A, Ix1>
261 {
262 (*rhs.t()).dot(self)
263 }
264}
265
impl<A> ArrayRef<A, Ix2>
{
    /// Perform matrix multiplication of rectangular arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is two-dimensional, the array shapes must agree in the way
    /// that if `self` is *M* × *N*, then `rhs` is *N* × *K*; the result has
    /// shape *M* × *K*.
    ///
    /// If `Rhs` is one-dimensional, then `self` is *M* × *N*, `rhs` has
    /// length *N*, and the result has length *M*.
    ///
    /// **Panics** if shapes are incompatible or the result would be too
    /// large to allocate.
    #[track_caller]
    pub fn dot<Rhs: ?Sized>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where Self: Dot<Rhs>
    {
        Dot::dot(self, rhs)
    }
}
304
305impl<A> Dot<ArrayRef<A, Ix2>> for ArrayRef<A, Ix2>
306where A: LinalgScalar
307{
308 type Output = Array2<A>;
309
310 fn dot(&self, b: &ArrayRef<A, Ix2>) -> Array2<A>
311 {
312 let a = self.view();
313 let b = b.view();
314 let ((m, k), (k2, n)) = (a.dim(), b.dim());
315 if k != k2 || m.checked_mul(n).is_none() {
316 dot_shape_error(m, k, k2, n);
317 }
318
319 let lhs_s0 = a.strides()[0];
320 let rhs_s0 = b.strides()[0];
321 let column_major = lhs_s0 == 1 && rhs_s0 == 1;
322 let mut v = Vec::with_capacity(m * n);
324 let mut c;
325 unsafe {
326 v.set_len(m * n);
327 c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
328 }
329 mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
330 c
331 }
332}
333
/// Panic for an incompatible (or overflowing) matrix product
/// `m × k` · `k2 × n`.
///
/// An `isize` overflow of the `m × n` result is reported in preference to a
/// mismatch of the inner dimensions `k` and `k2`.
///
/// `#[track_caller]` makes the panic location that of the (also
/// `#[track_caller]`) caller chain rather than this helper, which is what the
/// public `dot` entry points' own `#[track_caller]` attributes are meant for.
#[cold]
#[inline(never)]
#[track_caller]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> !
{
    match m.checked_mul(n) {
        Some(len) if len <= isize::MAX as usize => {}
        _ => panic!("ndarray: shape {} × {} overflows isize", m, n),
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}

/// Panic for incompatible `general_mat_mul`/`general_mat_vec_mul` operands:
/// inputs `m × k` and `k2 × n` with an output of shape `c1 × c2`.
///
/// `#[track_caller]` propagates the caller's location into the panic.
#[cold]
#[inline(never)]
#[track_caller]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> !
{
    panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
           m, k, k2, n, c1, c2);
}
356
357impl<A> Dot<ArrayRef<A, Ix1>> for ArrayRef<A, Ix2>
367where A: LinalgScalar
368{
369 type Output = Array<A, Ix1>;
370
371 #[track_caller]
372 fn dot(&self, rhs: &ArrayRef<A, Ix1>) -> Array<A, Ix1>
373 {
374 let ((m, a), n) = (self.dim(), rhs.dim());
375 if a != n {
376 dot_shape_error(m, a, n, 1);
377 }
378
379 unsafe {
381 let mut c = Array1::uninit(m);
382 general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
383 c.assume_init()
384 }
385 }
386}
387
388impl<A, D> ArrayRef<A, D>
389where D: Dimension
390{
391 #[track_caller]
399 pub fn scaled_add<E>(&mut self, alpha: A, rhs: &ArrayRef<A, E>)
400 where
401 A: LinalgScalar,
402 E: Dimension,
403 {
404 self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
405 }
406}
407
408#[cfg(not(feature = "blas"))]
411use self::mat_mul_general as mat_mul_impl;
412
/// BLAS-enabled `C ← α A B + β C`.
///
/// Dispatches to `cblas_{s,d,c,z}gemm` when all of the following hold:
/// - any of `m`, `n`, `k` exceeds `GEMM_BLAS_CUTOFF`;
/// - the element type is `f32`, `f64`, `c32` or `c64`;
/// - `a`, `b` and `c` all have a BLAS-compatible layout.
///
/// Otherwise it falls back to `mat_mul_general`.
#[cfg(feature = "blas")]
fn mat_mul_impl<A>(alpha: A, a: &ArrayRef2<A>, b: &ArrayRef2<A>, beta: A, c: &mut ArrayRef2<A>)
where A: LinalgScalar
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    debug_assert_eq!(k, k2);
    if (m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF)
        && (same_type::<A, f32>() || same_type::<A, f64>() || same_type::<A, c32>() || same_type::<A, c64>())
    {
        // Each matrix may be row-major (C) or column-major (F), as long as
        // `get_blas_compatible_layout` accepts it. The CBLAS layout is chosen
        // to match C's memory order. A and B are then flagged as transposed
        // where their own order differs from that layout.
        if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
            (get_blas_compatible_layout(a), get_blas_compatible_layout(b), get_blas_compatible_layout(c))
        {
            let cblas_layout = c_layout.to_cblas_layout();
            let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
            let lda = blas_stride(&a, a_layout);

            let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
            let ldb = blas_stride(&b, b_layout);

            let ldc = blas_stride(&c, c_layout);

            // Real gemm takes alpha/beta by value; complex gemm takes them by pointer.
            macro_rules! gemm_scalar_cast {
                (f32, $var:ident) => {
                    cast_as(&$var)
                };
                (f64, $var:ident) => {
                    cast_as(&$var)
                };
                (c32, $var:ident) => {
                    &$var as *const A as *const _
                };
                (c64, $var:ident) => {
                    &$var as *const A as *const _
                };
            }

            macro_rules! gemm {
                ($ty:tt, $gemm:ident) => {
                    if same_type::<A, $ty>() {
                        // gemm computes C ← α op(A) op(B) + β C.
                        unsafe {
                            blas_sys::$gemm(
                                cblas_layout,
                                a_trans,
                                b_trans,
                                m as blas_index, // rows of op(A) and of C
                                n as blas_index, // columns of op(B) and of C
                                k as blas_index, // columns of op(A), rows of op(B)
                                gemm_scalar_cast!($ty, alpha),
                                a._ptr().as_ptr() as *const _,
                                lda,
                                b._ptr().as_ptr() as *const _,
                                ldb,
                                gemm_scalar_cast!($ty, beta),
                                c._ptr().as_ptr() as *mut _,
                                ldc,
                            );
                        }
                        return;
                    }
                };
            }

            gemm!(f32, cblas_sgemm);
            gemm!(f64, cblas_dgemm);
            gemm!(c32, cblas_cgemm);
            gemm!(c64, cblas_zgemm);

            // The outer condition already restricted A to the four types above.
            unreachable!()
        }
    }
    mat_mul_general(alpha, a, b, beta, c)
}
498
499fn mat_mul_general<A>(alpha: A, lhs: &ArrayRef2<A>, rhs: &ArrayRef2<A>, beta: A, c: &mut ArrayRef2<A>)
501where A: LinalgScalar
502{
503 let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());
504
505 let ap = lhs.as_ptr();
507 let bp = rhs.as_ptr();
508 let cp = c.as_mut_ptr();
509 let (rsc, csc) = (c.strides()[0], c.strides()[1]);
510 if same_type::<A, f32>() {
511 unsafe {
512 matrixmultiply::sgemm(
513 m,
514 k,
515 n,
516 cast_as(&alpha),
517 ap as *const _,
518 lhs.strides()[0],
519 lhs.strides()[1],
520 bp as *const _,
521 rhs.strides()[0],
522 rhs.strides()[1],
523 cast_as(&beta),
524 cp as *mut _,
525 rsc,
526 csc,
527 );
528 }
529 } else if same_type::<A, f64>() {
530 unsafe {
531 matrixmultiply::dgemm(
532 m,
533 k,
534 n,
535 cast_as(&alpha),
536 ap as *const _,
537 lhs.strides()[0],
538 lhs.strides()[1],
539 bp as *const _,
540 rhs.strides()[0],
541 rhs.strides()[1],
542 cast_as(&beta),
543 cp as *mut _,
544 rsc,
545 csc,
546 );
547 }
548 } else if same_type::<A, c32>() {
549 unsafe {
550 matrixmultiply::cgemm(
551 matrixmultiply::CGemmOption::Standard,
552 matrixmultiply::CGemmOption::Standard,
553 m,
554 k,
555 n,
556 complex_array(cast_as(&alpha)),
557 ap as *const _,
558 lhs.strides()[0],
559 lhs.strides()[1],
560 bp as *const _,
561 rhs.strides()[0],
562 rhs.strides()[1],
563 complex_array(cast_as(&beta)),
564 cp as *mut _,
565 rsc,
566 csc,
567 );
568 }
569 } else if same_type::<A, c64>() {
570 unsafe {
571 matrixmultiply::zgemm(
572 matrixmultiply::CGemmOption::Standard,
573 matrixmultiply::CGemmOption::Standard,
574 m,
575 k,
576 n,
577 complex_array(cast_as(&alpha)),
578 ap as *const _,
579 lhs.strides()[0],
580 lhs.strides()[1],
581 bp as *const _,
582 rhs.strides()[0],
583 rhs.strides()[1],
584 complex_array(cast_as(&beta)),
585 cp as *mut _,
586 rsc,
587 csc,
588 );
589 }
590 } else {
591 if c.is_empty() {
593 return;
594 }
595
596 if beta.is_zero() {
598 c.fill(beta);
599 }
600
601 let mut i = 0;
602 let mut j = 0;
603 loop {
604 unsafe {
605 let elt = c.uget_mut((i, j));
606 *elt =
607 *elt * beta + alpha * (0..k).fold(A::zero(), move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j)));
608 }
609 j += 1;
610 if j == n {
611 j = 0;
612 i += 1;
613 if i == m {
614 break;
615 }
616 }
617 }
618 }
619}
620
621#[track_caller]
633pub fn general_mat_mul<A>(alpha: A, a: &ArrayRef2<A>, b: &ArrayRef2<A>, beta: A, c: &mut ArrayRef2<A>)
634where A: LinalgScalar
635{
636 let ((m, k), (k2, n)) = (a.dim(), b.dim());
637 let (m2, n2) = c.dim();
638 if k != k2 || m != m2 || n != n2 {
639 general_dot_shape_error(m, k, k2, n, m2, n2);
640 } else {
641 mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
642 }
643}
644
645#[track_caller]
656#[allow(clippy::collapsible_if)]
657pub fn general_mat_vec_mul<A>(alpha: A, a: &ArrayRef2<A>, x: &ArrayRef1<A>, beta: A, y: &mut ArrayRef1<A>)
658where A: LinalgScalar
659{
660 unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
661}
662
663#[allow(clippy::collapsible_else_if)]
672unsafe fn general_mat_vec_mul_impl<A>(
673 alpha: A, a: &ArrayRef2<A>, x: &ArrayRef1<A>, beta: A, y: RawArrayViewMut<A, Ix1>,
674) where A: LinalgScalar
675{
676 let ((m, k), k2) = (a.dim(), x.dim());
677 let m2 = y.dim();
678 if k != k2 || m != m2 {
679 general_dot_shape_error(m, k, k2, 1, m2, 1);
680 } else {
681 #[cfg(feature = "blas")]
682 macro_rules! gemv {
683 ($ty:ty, $gemv:ident) => {
684 if same_type::<A, $ty>() {
685 if let Some(layout) = get_blas_compatible_layout(&a) {
686 if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y.as_ref()) {
687 let a_trans = CblasNoTrans;
692
693 let a_stride = blas_stride(&a, layout);
694 let cblas_layout = layout.to_cblas_layout();
695
696 let x_offset = offset_from_low_addr_ptr_to_logical_ptr(x._dim(), x._strides());
698 let x_ptr = x._ptr().as_ptr().sub(x_offset);
699 let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.parts.dim, &y.parts.strides);
700 let y_ptr = y.parts.ptr.as_ptr().sub(y_offset);
701
702 let x_stride = x.strides()[0] as blas_index;
703 let y_stride = y.strides()[0] as blas_index;
704
705 blas_sys::$gemv(
706 cblas_layout,
707 a_trans,
708 m as blas_index, k as blas_index, cast_as(&alpha), a._ptr().as_ptr() as *const _, a_stride, x_ptr as *const _, x_stride,
715 cast_as(&beta), y_ptr as *mut _, y_stride,
718 );
719 return;
720 }
721 }
722 }
723 };
724 }
725 #[cfg(feature = "blas")]
726 gemv!(f32, cblas_sgemv);
727 #[cfg(feature = "blas")]
728 gemv!(f64, cblas_dgemv);
729
730 if beta.is_zero() {
733 Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
735 elt.write(row.dot(x) * alpha);
736 });
737 } else {
738 Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
739 *elt = *elt * beta + row.dot(x) * alpha;
740 });
741 }
742 }
743}
744
745pub fn kron<A>(a: &ArrayRef2<A>, b: &ArrayRef2<A>) -> Array<A, Ix2>
750where A: LinalgScalar
751{
752 let dimar = a.shape()[0];
753 let dimac = a.shape()[1];
754 let dimbr = b.shape()[0];
755 let dimbc = b.shape()[1];
756 let mut out: Array2<MaybeUninit<A>> = Array2::uninit((
757 dimar
758 .checked_mul(dimbr)
759 .expect("Dimensions of kronecker product output array overflows usize."),
760 dimac
761 .checked_mul(dimbc)
762 .expect("Dimensions of kronecker product output array overflows usize."),
763 ));
764 Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
765 .and(a)
766 .for_each(|out, &a| {
767 Zip::from(out).and(b).for_each(|out, &b| {
768 *out = MaybeUninit::new(a * b);
769 })
770 });
771 unsafe { out.assume_init() }
772}
773
/// Return `true` when `A` and `B` are the same type.
#[inline(always)]
fn same_type<A: 'static, B: 'static>() -> bool
{
    let (lhs, rhs) = (TypeId::of::<A>(), TypeId::of::<B>());
    lhs == rhs
}

/// Reinterpret `*a` as a `B`, which must be exactly the same type as `A`.
///
/// Used to move between a generic `A` and a concrete scalar type after
/// `same_type` has established that they match.
///
/// **Panics** if `A` and `B` are different types.
fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B
{
    if !same_type::<A, B>() {
        panic!("expect type {} and {} to match",
               std::any::type_name::<A>(), std::any::type_name::<B>());
    }
    // SAFETY: `A` and `B` are the same type (checked above), so `a` points to
    // a valid, properly aligned `B`. Both types are `Copy`, so reading does not
    // duplicate ownership.
    unsafe { (a as *const A).cast::<B>().read() }
}
790
791#[inline]
793fn complex_array<A: 'static + Copy>(z: Complex<A>) -> [A; 2]
794{
795 [z.re, z.im]
796}
797
/// Return `true` if the 1-D array `a` can be handed to BLAS as a vector of
/// element type `A`. This requires that:
/// - its element type `B` is `A`;
/// - its length fits in `blas_index`;
/// - its stride is non-zero and representable as `blas_index`.
#[cfg(feature = "blas")]
fn blas_compat_1d<A, B>(a: &RawRef<B, Ix1>) -> bool
where
    A: 'static,
    B: 'static,
{
    if !same_type::<A, B>() || a.len() > blas_index::MAX as usize {
        return false;
    }
    let stride = a.strides()[0];
    stride != 0 && (blas_index::MIN as isize..=blas_index::MAX as isize).contains(&stride)
}
815}
816
817#[cfg(feature = "blas")]
818#[derive(Copy, Clone)]
819#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
820enum BlasOrder
821{
822 C,
823 F,
824}
825
#[cfg(feature = "blas")]
impl BlasOrder
{
    /// The opposite memory order.
    fn transpose(self) -> Self
    {
        match self {
            Self::F => Self::C,
            Self::C => Self::F,
        }
    }

    /// Index of the axis whose stride is the BLAS leading dimension:
    /// axis 0 for row-major, axis 1 for column-major.
    #[inline]
    fn get_blas_lead_axis(self) -> usize
    {
        match self {
            Self::F => 1,
            Self::C => 0,
        }
    }

    /// The CBLAS layout constant corresponding to this order.
    fn to_cblas_layout(self) -> CBLAS_LAYOUT
    {
        match self {
            Self::F => CBLAS_LAYOUT::CblasColMajor,
            Self::C => CBLAS_LAYOUT::CblasRowMajor,
        }
    }

    /// Transpose flag for an operand stored in this order when the BLAS call
    /// uses `for_layout`. The flag is "no transpose" when the operand's order
    /// matches the layout, and "transpose" otherwise.
    fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE
    {
        let effective_order = if matches!(for_layout, CBLAS_LAYOUT::CblasColMajor) {
            self.transpose()
        } else {
            self
        };

        match effective_order {
            Self::F => CblasTrans,
            Self::C => CblasNoTrans,
        }
    }
}
/// Return `true` if a 2-D array with shape `dim` and strides `stride` can be
/// passed to BLAS in memory order `order`.
///
/// The checks are:
/// - the contiguous axis (axis 1 for `C`, axis 0 for `F`) has unit stride,
///   unless that axis has length 1;
/// - both strides are positive and representable as `blas_index`;
/// - when the leading axis has more than one element, the leading stride is at
///   least the length of the contiguous axis, so rows/columns do not overlap;
/// - both dimensions are representable as `blas_index`.
#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool
{
    let (m, n) = dim.into_pattern();
    // Strides are stored as `usize` in `Ix2`; reinterpret them as the signed
    // values they represent.
    let s0 = stride[0] as isize;
    let s1 = stride[1] as isize;
    // For `C`, `inner_stride`/`outer_dim` belong to axis 1 (the contiguous
    // axis) and `outer_stride`/`inner_dim` to axis 0 (the leading axis).
    // `F` swaps the roles of the two axes.
    let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
        BlasOrder::C => (s1, s0, m, n),
        BlasOrder::F => (s0, s1, n, m),
    };

    // The contiguous axis needs unit stride, unless it has length 1.
    if !(inner_stride == 1 || outer_dim == 1) {
        return false;
    }

    // Zero and negative strides are rejected.
    if s0 < 1 || s1 < 1 {
        return false;
    }

    // Both strides must fit in BLAS's integer type.
    if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize)
        || (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize)
    {
        return false;
    }

    // The leading stride must cover the contiguous axis, otherwise consecutive
    // rows (C) / columns (F) would overlap.
    if inner_dim > 1 && (outer_stride as usize) < outer_dim {
        return false;
    }

    // Both dimensions must fit in BLAS's integer type.
    if m > blas_index::MAX as usize || n > blas_index::MAX as usize {
        return false;
    }

    true
}
907
/// Return the memory order in which `a` can be passed to BLAS, or `None` if
/// there is none. Row-major (`C`) is preferred when both orders qualify.
#[cfg(feature = "blas")]
fn get_blas_compatible_layout<A>(a: &ArrayRef<A, Ix2>) -> Option<BlasOrder>
{
    let (dim, strides) = (a._dim(), a._strides());
    [BlasOrder::C, BlasOrder::F]
        .iter()
        .copied()
        .find(|&order| is_blas_2d(dim, strides, order))
}
920
/// Leading dimension (`lda`/`ldb`/`ldc`) to pass to BLAS for `a` stored in
/// memory order `order`.
///
/// This is the stride of the leading axis. When that axis has length 0 or 1,
/// its stride is irrelevant to ndarray and may be anything. In that case it is
/// raised to at least the length of the other axis, so that BLAS receives a
/// sensible value.
///
/// `a` should already be BLAS compatible in `order`, so the final cast does
/// not truncate.
#[cfg(feature = "blas")]
fn blas_stride<A>(a: &ArrayRef<A, Ix2>, order: BlasOrder) -> blas_index
{
    let lead = order.get_blas_lead_axis();
    let shape = a.shape();
    let stride = a.strides()[lead];
    let leading_dim = if shape[lead] > 1 {
        stride
    } else {
        stride.max(shape[1 - lead] as isize)
    };
    leading_dim as blas_index
}
944
/// Test helper: `true` if `a` has element type `A` and is usable by BLAS in
/// row-major order.
#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, B>(a: &ArrayRef2<B>) -> bool
where
    A: 'static,
    B: 'static,
{
    same_type::<A, B>() && is_blas_2d(a._dim(), a._strides(), BlasOrder::C)
}

/// Test helper: `true` if `a` has element type `A` and is usable by BLAS in
/// column-major order.
#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, B>(a: &ArrayRef2<B>) -> bool
where
    A: 'static,
    B: 'static,
{
    same_type::<A, B>() && is_blas_2d(a._dim(), a._strides(), BlasOrder::F)
}
970
/// Unit tests for the BLAS layout detection helpers.
#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests
{
    use super::*;

    // A standard-layout matrix qualifies as row-major only.
    #[test]
    fn blas_row_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(!blas_column_major_2d::<f32, _>(&m));
    }

    // A single row is accepted in both orders.
    #[test]
    fn blas_row_major_2d_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    // A single column is accepted in both orders.
    #[test]
    fn blas_row_major_2d_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    // Transposing a single row still yields a matrix accepted in both orders.
    #[test]
    fn blas_row_major_2d_transposed_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    // Transposing a single column still yields a matrix accepted in both orders.
    #[test]
    fn blas_row_major_2d_transposed_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    // A Fortran-layout matrix qualifies as column-major only.
    #[test]
    fn blas_column_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    // Skipping rows keeps each row contiguous, so row-major still works.
    #[test]
    fn blas_row_major_2d_skip_rows_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![..;2, ..]);
        assert!(blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    // Skipping columns breaks unit stride within rows: neither order works.
    #[test]
    fn blas_row_major_2d_skip_columns_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![.., ..;2]);
        assert!(!blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    // Skipping columns of a Fortran matrix keeps columns contiguous.
    #[test]
    fn blas_col_major_2d_skip_columns_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![.., ..;2]);
        assert!(blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    // Skipping rows of a Fortran matrix breaks unit stride within columns.
    #[test]
    fn blas_col_major_2d_skip_rows_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![..;2, ..]);
        assert!(!blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    // A row stride shorter than the row length makes rows overlap, so such
    // views must be rejected; any row stride >= N is accepted as row-major.
    #[test]
    fn blas_too_short_stride()
    {
        const N: usize = 5;
        const MAXSTRIDE: usize = N + 2;
        let mut data = [0; MAXSTRIDE * N];
        let mut iter = 0..data.len();
        data.fill_with(|| iter.next().unwrap());

        for stride in 1..=MAXSTRIDE {
            let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();

            if stride < N {
                assert_eq!(get_blas_compatible_layout(&m), None);
            } else {
                assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C));
            }
        }
    }
}
1086
1087impl<A> Dot<ArrayRef<A, IxDyn>> for ArrayRef<A, IxDyn>
1105where A: LinalgScalar
1106{
1107 type Output = Array<A, IxDyn>;
1108
1109 fn dot(&self, rhs: &ArrayRef<A, IxDyn>) -> Self::Output
1110 {
1111 match (self.ndim(), rhs.ndim()) {
1112 (1, 1) => {
1113 let a = self.view().into_dimensionality::<Ix1>().unwrap();
1114 let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
1115 let result = a.dot(&b);
1116 ArrayD::from_elem(vec![], result)
1117 }
1118 (2, 2) => {
1119 let a = self.view().into_dimensionality::<Ix2>().unwrap();
1121 let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
1122 let result = a.dot(&b);
1123 result.into_dimensionality::<IxDyn>().unwrap()
1124 }
1125 (2, 1) => {
1126 let a = self.view().into_dimensionality::<Ix2>().unwrap();
1128 let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
1129 let result = a.dot(&b);
1130 result.into_dimensionality::<IxDyn>().unwrap()
1131 }
1132 (1, 2) => {
1133 let a = self.view().into_dimensionality::<Ix1>().unwrap();
1135 let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
1136 let result = a.dot(&b);
1137 result.into_dimensionality::<IxDyn>().unwrap()
1138 }
1139 _ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
1140 }
1141 }
1142}