use crate::imp_prelude::*;
#[cfg(feature = "blas")]
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::numeric_util;
use crate::ArrayRef1;
use crate::ArrayRef2;
use crate::{LinalgScalar, Zip};
#[cfg(not(feature = "std"))]
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use std::any::TypeId;
use std::mem::MaybeUninit;
use num_complex::Complex;
use num_complex::{Complex32 as c32, Complex64 as c64};
#[cfg(feature = "blas")]
use libc::c_int;
#[cfg(feature = "blas")]
use cblas_sys as blas_sys;
#[cfg(feature = "blas")]
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE};
// Vectors shorter than this always use the generic (non-BLAS) dot product.
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
// Matrix products where m, n and k are all at most this size skip BLAS gemm.
#[cfg(feature = "blas")]
const GEMM_BLAS_CUTOFF: usize = 7;
// Integer type used for BLAS lengths, strides and leading dimensions
// (lower-case to mirror the C-side naming).
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
type blas_index = c_int;
impl<A> ArrayRef<A, Ix1>
{
    /// Perform dot product or matrix multiplication of arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is one-dimensional, the result is the sum of the element-wise
    /// products (complex operands are not conjugated). `self` and `rhs` must
    /// then have the same length.
    ///
    /// If `Rhs` is two-dimensional, `self` is treated as a row vector: when
    /// `self` has length *M* and `rhs` is *M* × *N*, the result has length *N*.
    ///
    /// **Panics** if the array shapes are incompatible.
    #[track_caller]
    pub fn dot<Rhs: ?Sized>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where Self: Dot<Rhs>
    {
        <Self as Dot<Rhs>>::dot(self, rhs)
    }

    /// Portable vector dot product.
    ///
    /// Uses the unrolled slice kernel when both operands are contiguous and
    /// otherwise walks the elements by index.
    fn dot_generic(&self, rhs: &ArrayRef<A, Ix1>) -> A
    where A: LinalgScalar
    {
        debug_assert_eq!(self.len(), rhs.len());
        assert!(self.len() == rhs.len());
        match (self.as_slice(), rhs.as_slice()) {
            (Some(lhs_elems), Some(rhs_elems)) => numeric_util::unrolled_dot(lhs_elems, rhs_elems),
            _ => (0..self.len()).fold(A::zero(), |acc, idx| {
                // SAFETY: idx < self.len() == rhs.len() (asserted above).
                unsafe { acc + *self.uget(idx) * *rhs.uget(idx) }
            }),
        }
    }

    /// Without BLAS every dot product goes through the generic kernel.
    #[cfg(not(feature = "blas"))]
    fn dot_impl(&self, rhs: &ArrayRef<A, Ix1>) -> A
    where A: LinalgScalar
    {
        Self::dot_generic(self, rhs)
    }

    /// With BLAS, long `f32`/`f64` vectors with BLAS-representable lengths and
    /// strides use `?dot`; everything else uses the generic kernel.
    #[cfg(feature = "blas")]
    fn dot_impl(&self, rhs: &ArrayRef<A, Ix1>) -> A
    where A: LinalgScalar
    {
        if self.len() >= DOT_BLAS_CUTOFF {
            debug_assert_eq!(self.len(), rhs.len());
            assert!(self.len() == rhs.len());
            // Calls BLAS and returns early when `A` is `$ty` and both operands are compatible.
            macro_rules! blas_dot {
                ($ty:ty, $func:ident) => {{
                    if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
                        unsafe {
                            let (x_ptr, len, incx) =
                                blas_1d_params(self._ptr().as_ptr(), self.len(), self.strides()[0]);
                            let (y_ptr, _, incy) =
                                blas_1d_params(rhs._ptr().as_ptr(), rhs.len(), rhs.strides()[0]);
                            let value = blas_sys::$func(len, x_ptr as *const $ty, incx, y_ptr as *const $ty, incy);
                            return cast_as::<$ty, A>(&value);
                        }
                    }
                }};
            }
            blas_dot! {f32, cblas_sdot};
            blas_dot! {f64, cblas_ddot};
        }
        self.dot_generic(rhs)
    }
}
/// Converts a 1-D array's logical start pointer, length and stride into BLAS
/// `(pointer, n, inc)` arguments.
///
/// For a negative stride BLAS expects the lowest-addressed element, i.e. the
/// logical *last* element, so the pointer is moved there. Zero-length and
/// non-negative-stride inputs are passed through unchanged.
///
/// # Safety
///
/// For a negative `stride`, `ptr + (len - 1) * stride` must lie in the same
/// allocation as `ptr`.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index)
{
    let base = if stride < 0 && len != 0 {
        ptr.offset((len - 1) as isize * stride)
    } else {
        ptr
    };
    (base, len as blas_index, stride as blas_index)
}
/// Dot product / matrix multiplication between `Self` and `Rhs`.
///
/// Implemented in this module for 1-D · 1-D (a scalar), 1-D · 2-D and 2-D · 1-D
/// (a 1-D array), 2-D · 2-D (a 2-D array), and dynamic-dimensional arrays.
pub trait Dot<Rhs: ?Sized>
{
    /// The result of the operation (a scalar or an owned array, depending on the shapes).
    type Output;
    /// Computes `self · rhs`.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}
// Generates `Dot` impls for the mixed `ArrayBase` / `ArrayRef` operand combinations of a
// shape pair. Each one forwards (via deref) to the `ArrayRef<A, $shape1>: Dot<ArrayRef<A,
// $shape2>>` impl, so the arithmetic is written only once per shape pair.
macro_rules! impl_dots {
    (
        $shape1:ty,
        $shape2:ty
    ) => {
        // ArrayBase · ArrayBase
        impl<A, S, S2> Dot<ArrayBase<S2, $shape2>> for ArrayBase<S, $shape1>
        where
            S: Data<Elem = A>,
            S2: Data<Elem = A>,
            A: LinalgScalar,
        {
            type Output = <ArrayRef<A, $shape1> as Dot<ArrayRef<A, $shape2>>>::Output;
            fn dot(&self, rhs: &ArrayBase<S2, $shape2>) -> Self::Output
            {
                Dot::dot(&**self, &**rhs)
            }
        }
        // ArrayBase · ArrayRef
        impl<A, S> Dot<ArrayRef<A, $shape2>> for ArrayBase<S, $shape1>
        where
            S: Data<Elem = A>,
            A: LinalgScalar,
        {
            type Output = <ArrayRef<A, $shape1> as Dot<ArrayRef<A, $shape2>>>::Output;
            fn dot(&self, rhs: &ArrayRef<A, $shape2>) -> Self::Output
            {
                (**self).dot(rhs)
            }
        }
        // ArrayRef · ArrayBase
        impl<A, S> Dot<ArrayBase<S, $shape2>> for ArrayRef<A, $shape1>
        where
            S: Data<Elem = A>,
            A: LinalgScalar,
        {
            type Output = <ArrayRef<A, $shape1> as Dot<ArrayRef<A, $shape2>>>::Output;
            fn dot(&self, rhs: &ArrayBase<S, $shape2>) -> Self::Output
            {
                self.dot(&**rhs)
            }
        }
    };
}
// Every shape pair that has an `ArrayRef`-level `Dot` impl below.
impl_dots!(Ix1, Ix1);
impl_dots!(Ix1, Ix2);
impl_dots!(Ix2, Ix1);
impl_dots!(Ix2, Ix2);
/// Vector · vector: the sum of the element-wise products.
///
/// **Panics** if the two vectors have different lengths.
impl<A> Dot<ArrayRef<A, Ix1>> for ArrayRef<A, Ix1>
where A: LinalgScalar
{
    type Output = A;

    #[track_caller]
    fn dot(&self, rhs: &ArrayRef<A, Ix1>) -> A
    {
        Self::dot_impl(self, rhs)
    }
}
impl<A> Dot<ArrayRef<A, Ix2>> for ArrayRef<A, Ix1>
where A: LinalgScalar
{
type Output = Array<A, Ix1>;
#[track_caller]
fn dot(&self, rhs: &ArrayRef<A, Ix2>) -> Array<A, Ix1>
{
(*rhs.t()).dot(self)
}
}
impl<A> ArrayRef<A, Ix2>
{
    /// Perform matrix multiplication of `self` with `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `rhs` is two-dimensional and `self` is *M* × *N*, then `rhs` must be
    /// *N* × *K* and the result is *M* × *K*.
    ///
    /// If `rhs` is one-dimensional of length *N*, it is treated as a column
    /// vector and the result has length *M*.
    ///
    /// **Panics** if the array shapes are incompatible.
    #[track_caller]
    pub fn dot<Rhs: ?Sized>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where Self: Dot<Rhs>
    {
        <Self as Dot<Rhs>>::dot(self, rhs)
    }
}
/// Matrix · matrix.
///
/// For `self` of shape *M* × *K* and `b` of shape *K* × *N*, returns a new
/// *M* × *N* array. The result is column-major when both inputs have unit
/// stride along axis 0, and row-major otherwise.
///
/// **Panics** if the inner dimensions differ or `M * N` overflows.
impl<A> Dot<ArrayRef<A, Ix2>> for ArrayRef<A, Ix2>
where A: LinalgScalar
{
    type Output = Array2<A>;

    fn dot(&self, b: &ArrayRef<A, Ix2>) -> Array2<A>
    {
        let lhs = self.view();
        let rhs = b.view();
        let ((m, k), (k2, n)) = (lhs.dim(), rhs.dim());
        let len = match m.checked_mul(n) {
            Some(len) if k == k2 => len,
            _ => dot_shape_error(m, k, k2, n),
        };

        let both_col_major = lhs.strides()[0] == 1 && rhs.strides()[0] == 1;
        let mut storage = Vec::with_capacity(len);
        // SAFETY: A is Copy, and `mat_mul_impl` is called with beta == 0, so every
        // element of the output is written before it is ever read.
        let mut out = unsafe {
            storage.set_len(len);
            Array::from_shape_vec_unchecked((m, n).set_f(both_col_major), storage)
        };
        mat_mul_impl(A::one(), &lhs, &rhs, A::zero(), &mut out.view_mut());
        out
    }
}
/// Panics with a diagnostic for an invalid `m × k` · `k2 × n` product.
///
/// If `m * n` overflows `usize` or exceeds `isize::MAX`, the overflow is reported.
/// Otherwise the panic reports the incompatible inner dimensions.
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> !
{
    let result_fits = matches!(m.checked_mul(n), Some(len) if len <= isize::MAX as usize);
    if !result_fits {
        panic!("ndarray: shape {} × {} overflows isize", m, n);
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}
/// Panics with a diagnostic for `m × k` · `k2 × n` → `c1 × c2` shapes that do not agree.
#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> !
{
    panic!(
        "ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
        m, k, k2, n, c1, c2
    )
}
/// Matrix · column vector.
///
/// For `self` of shape *M* × *N* and `rhs` of length *N*, returns a new array of length *M*.
///
/// **Panics** if `N` does not equal the length of `rhs`.
impl<A> Dot<ArrayRef<A, Ix1>> for ArrayRef<A, Ix2>
where A: LinalgScalar
{
    type Output = Array<A, Ix1>;

    #[track_caller]
    fn dot(&self, rhs: &ArrayRef<A, Ix1>) -> Array<A, Ix1>
    {
        let (rows, cols) = self.dim();
        let rhs_len = rhs.dim();
        if cols != rhs_len {
            dot_shape_error(rows, cols, rhs_len, 1);
        }
        // The output starts uninitialized: with beta == 0 the kernel writes every
        // element without reading it.
        unsafe {
            let mut out = Array1::uninit(rows);
            general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), out.raw_view_mut().cast::<A>());
            out.assume_init()
        }
    }
}
impl<A, D> ArrayRef<A, D>
where D: Dimension
{
    /// Performs the in-place update `self ← self + alpha * rhs`, element by element.
    ///
    /// `rhs` is combined with `self` by `zip_mut_with`, which also decides how
    /// mismatched shapes are handled.
    #[track_caller]
    pub fn scaled_add<E>(&mut self, alpha: A, rhs: &ArrayRef<A, E>)
    where
        A: LinalgScalar,
        E: Dimension,
    {
        self.zip_mut_with(rhs, move |dst, &src| *dst = *dst + (alpha * src));
    }
}
#[cfg(not(feature = "blas"))]
use self::mat_mul_general as mat_mul_impl;
/// Computes `c ← alpha · (a · b) + beta · c`, using BLAS `?gemm` when possible.
///
/// BLAS is used only when the element type is `f32`, `f64`, `c32` or `c64`, at least
/// one of m, n, k exceeds `GEMM_BLAS_CUTOFF`, and all three operands have a
/// BLAS-compatible (row- or column-major) layout. Otherwise this falls back to
/// `mat_mul_general`. Callers must have checked that the shapes agree.
#[cfg(feature = "blas")]
fn mat_mul_impl<A>(alpha: A, a: &ArrayRef2<A>, b: &ArrayRef2<A>, beta: A, c: &mut ArrayRef2<A>)
where A: LinalgScalar
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    debug_assert_eq!(k, k2);
    if (m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF)
        && (same_type::<A, f32>() || same_type::<A, f64>() || same_type::<A, c32>() || same_type::<A, c64>())
    {
        if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
            (get_blas_compatible_layout(a), get_blas_compatible_layout(b), get_blas_compatible_layout(c))
        {
            // The CBLAS layout follows C's memory order. A and B get a transpose flag
            // whenever their own order differs from that layout.
            let cblas_layout = c_layout.to_cblas_layout();
            let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
            let lda = blas_stride(a, a_layout);
            let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
            let ldb = blas_stride(b, b_layout);
            let ldc = blas_stride(c, c_layout);
            // Real gemm takes alpha/beta by value; complex gemm takes them by pointer.
            macro_rules! gemm_scalar_cast {
                (f32, $var:ident) => {
                    cast_as(&$var)
                };
                (f64, $var:ident) => {
                    cast_as(&$var)
                };
                (c32, $var:ident) => {
                    &$var as *const A as *const _
                };
                (c64, $var:ident) => {
                    &$var as *const A as *const _
                };
            }
            // Calls the BLAS routine and returns when `A` is `$ty`.
            macro_rules! gemm {
                ($ty:tt, $gemm:ident) => {
                    if same_type::<A, $ty>() {
                        unsafe {
                            blas_sys::$gemm(
                                cblas_layout,
                                a_trans,
                                b_trans,
                                m as blas_index, // rows of op(a)
                                n as blas_index, // cols of op(b)
                                k as blas_index, // cols of op(a)
                                gemm_scalar_cast!($ty, alpha),
                                a._ptr().as_ptr() as *const _,
                                lda,
                                b._ptr().as_ptr() as *const _,
                                ldb,
                                gemm_scalar_cast!($ty, beta),
                                c._ptr().as_ptr() as *mut _,
                                ldc, );
                        }
                        return;
                    }
                };
            }
            gemm!(f32, cblas_sgemm);
            gemm!(f64, cblas_dgemm);
            gemm!(c32, cblas_cgemm);
            gemm!(c64, cblas_zgemm);
            // The outer condition admitted only these four types, and each gemm! returns.
            unreachable!() }
    }
    mat_mul_general(alpha, a, b, beta, c)
}
/// Computes `c ← alpha · (lhs · rhs) + beta · c` without BLAS.
///
/// `f32`, `f64`, `c32` and `c64` are handed to the `matrixmultiply` kernels, which accept
/// arbitrary row/column strides for all three operands. Every other scalar type uses a
/// straightforward triple loop.
///
/// The caller must ensure `lhs` is m × k, `rhs` is k × n and `c` is m × n; the shapes
/// are not re-checked here.
fn mat_mul_general<A>(alpha: A, lhs: &ArrayRef2<A>, rhs: &ArrayRef2<A>, beta: A, c: &mut ArrayRef2<A>)
where A: LinalgScalar
{
    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());

    // Raw pointers and (row, column) strides shared by every gemm call below.
    let ap = lhs.as_ptr();
    let bp = rhs.as_ptr();
    let cp = c.as_mut_ptr();
    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
    let (rsa, csa) = (lhs.strides()[0], lhs.strides()[1]);
    let (rsb, csb) = (rhs.strides()[0], rhs.strides()[1]);

    if same_type::<A, f32>() {
        // SAFETY: A is f32; pointers and strides describe the m×k, k×n and m×n operands.
        unsafe {
            matrixmultiply::sgemm(
                m, k, n,
                cast_as(&alpha),
                ap as *const _, rsa, csa,
                bp as *const _, rsb, csb,
                cast_as(&beta),
                cp as *mut _, rsc, csc,
            );
        }
    } else if same_type::<A, f64>() {
        // SAFETY: as above, with A == f64.
        unsafe {
            matrixmultiply::dgemm(
                m, k, n,
                cast_as(&alpha),
                ap as *const _, rsa, csa,
                bp as *const _, rsb, csb,
                cast_as(&beta),
                cp as *mut _, rsc, csc,
            );
        }
    } else if same_type::<A, c32>() {
        // SAFETY: as above, with A == c32; scalars are passed as [re, im] pairs.
        unsafe {
            matrixmultiply::cgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m, k, n,
                complex_array(cast_as(&alpha)),
                ap as *const _, rsa, csa,
                bp as *const _, rsb, csb,
                complex_array(cast_as(&beta)),
                cp as *mut _, rsc, csc,
            );
        }
    } else if same_type::<A, c64>() {
        // SAFETY: as above, with A == c64; scalars are passed as [re, im] pairs.
        unsafe {
            matrixmultiply::zgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m, k, n,
                complex_array(cast_as(&alpha)),
                ap as *const _, rsa, csa,
                bp as *const _, rsb, csb,
                complex_array(cast_as(&beta)),
                cp as *mut _, rsc, csc,
            );
        }
    } else {
        // Nothing to compute for an empty output.
        if c.is_empty() {
            return;
        }
        // When beta is zero the old contents of `c` (possibly uninitialized) must not
        // leak into the result, so overwrite them first.
        if beta.is_zero() {
            c.fill(beta);
        }
        for i in 0..m {
            for j in 0..n {
                // SAFETY: i < m, j < n and x < k index within the m×k, k×n and m×n operands.
                unsafe {
                    let elt = c.uget_mut((i, j));
                    *elt = *elt * beta
                        + alpha * (0..k).fold(A::zero(), move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j)));
                }
            }
        }
    }
}
/// General matrix multiplication: `c ← alpha · (a · b) + beta · c`.
///
/// If `a` is *M* × *K*, then `b` must be *K* × *N* and `c` must be *M* × *N*.
///
/// **Panics** if the array shapes are not compatible.
#[track_caller]
pub fn general_mat_mul<A>(alpha: A, a: &ArrayRef2<A>, b: &ArrayRef2<A>, beta: A, c: &mut ArrayRef2<A>)
where A: LinalgScalar
{
    let (m, k) = a.dim();
    let (k2, n) = b.dim();
    let (m2, n2) = c.dim();
    let shapes_agree = k == k2 && m == m2 && n == n2;
    if !shapes_agree {
        general_dot_shape_error(m, k, k2, n, m2, n2);
    }
    mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
}
/// General matrix-vector multiplication: `y ← alpha · (a · x) + beta · y`.
///
/// If `a` is *M* × *K*, then `x` must have length *K* and `y` length *M*.
///
/// **Panics** if the array shapes are not compatible.
#[track_caller]
#[allow(clippy::collapsible_if)]
pub fn general_mat_vec_mul<A>(alpha: A, a: &ArrayRef2<A>, x: &ArrayRef1<A>, beta: A, y: &mut ArrayRef1<A>)
where A: LinalgScalar
{
    // SAFETY: `y` is an initialized, exclusively borrowed array, so its raw view is
    // valid for both the reads (beta != 0) and the writes the kernel performs.
    unsafe {
        let y_raw = y.raw_view_mut();
        general_mat_vec_mul_impl(alpha, a, x, beta, y_raw)
    }
}
/// Kernel for `y ← alpha · (a · x) + beta · y` on a raw output view.
///
/// With the `blas` feature, `f32`/`f64` inputs with BLAS-compatible layouts use `?gemv`;
/// otherwise each output element is the dot product of one row of `a` with `x`.
///
/// **Panics** if `a` is not `y.len() × x.len()`.
///
/// # Safety
///
/// `y` must be valid for writes of all its elements. When `beta` is non-zero the
/// elements of `y` are also read, so they must be initialized. When `beta` is zero
/// the generic path only writes to `y`.
#[allow(clippy::collapsible_else_if)]
unsafe fn general_mat_vec_mul_impl<A>(
    alpha: A, a: &ArrayRef2<A>, x: &ArrayRef1<A>, beta: A, y: RawArrayViewMut<A, Ix1>,
) where A: LinalgScalar
{
    let ((m, k), k2) = (a.dim(), x.dim());
    let m2 = y.dim();
    if k != k2 || m != m2 {
        general_dot_shape_error(m, k, k2, 1, m2, 1);
    } else {
        // Calls BLAS and returns early when `A` is `$ty` and all operands are compatible.
        #[cfg(feature = "blas")]
        macro_rules! gemv {
            ($ty:ty, $gemv:ident) => {
                if same_type::<A, $ty>() {
                    if let Some(layout) = get_blas_compatible_layout(&a) {
                        if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y.as_ref()) {
                            // The CBLAS layout follows a's own order, so a needs no transpose.
                            let a_trans = CblasNoTrans;
                            let a_stride = blas_stride(&a, layout);
                            let cblas_layout = layout.to_cblas_layout();
                            // BLAS expects the lowest-addressed element for x and y,
                            // which differs from the logical first element for negative strides.
                            let x_offset = offset_from_low_addr_ptr_to_logical_ptr(x._dim(), x._strides());
                            let x_ptr = x._ptr().as_ptr().sub(x_offset);
                            let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.parts.dim, &y.parts.strides);
                            let y_ptr = y.parts.ptr.as_ptr().sub(y_offset);
                            let x_stride = x.strides()[0] as blas_index;
                            let y_stride = y.strides()[0] as blas_index;
                            blas_sys::$gemv(
                                cblas_layout,
                                a_trans,
                                m as blas_index, // rows of a
                                k as blas_index, // cols of a
                                cast_as(&alpha),
                                a._ptr().as_ptr() as *const _,
                                a_stride, // lda
                                x_ptr as *const _,
                                x_stride,
                                cast_as(&beta),
                                y_ptr as *mut _,
                                y_stride,
                            );
                            return;
                        }
                    }
                }
            };
        }
        #[cfg(feature = "blas")]
        gemv!(f32, cblas_sgemv);
        #[cfg(feature = "blas")]
        gemv!(f64, cblas_dgemv);
        // Generic path: one row-by-vector dot product per output element.
        if beta.is_zero() {
            // beta == 0: y may be uninitialized, so write without reading.
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                elt.write(row.dot(x) * alpha);
            });
        } else {
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                *elt = *elt * beta + row.dot(x) * alpha;
            });
        }
    }
}
/// Kronecker product of two 2-D arrays.
///
/// For `a` of shape (R₁, C₁) and `b` of shape (R₂, C₂), the result has shape
/// (R₁·R₂, C₁·C₂). Block (i, j) of size R₂ × C₂ equals `a[[i, j]] * b`.
///
/// If either input has a zero-length axis, the result is an empty array of that shape.
///
/// **Panics** if an output dimension overflows `usize`.
pub fn kron<A>(a: &ArrayRef2<A>, b: &ArrayRef2<A>) -> Array<A, Ix2>
where A: LinalgScalar
{
    let dimar = a.shape()[0];
    let dimac = a.shape()[1];
    let dimbr = b.shape()[0];
    let dimbc = b.shape()[1];
    let out_rows = dimar
        .checked_mul(dimbr)
        .expect("Dimensions of kronecker product output array overflows usize.");
    let out_cols = dimac
        .checked_mul(dimbc)
        .expect("Dimensions of kronecker product output array overflows usize.");
    // `exact_chunks_mut` panics (divides by the chunk size) when a chunk dimension is
    // zero. An empty `b` therefore needs an early return; the result has no elements.
    if dimbr == 0 || dimbc == 0 {
        return Array2::zeros((out_rows, out_cols));
    }
    let mut out: Array2<MaybeUninit<A>> = Array2::uninit((out_rows, out_cols));
    // `out` splits into exactly dimar × dimac blocks of shape dimbr × dimbc, one per
    // element of `a`.
    Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
        .and(a)
        .for_each(|out_block, &a_elem| {
            Zip::from(out_block).and(b).for_each(|out_elem, &b_elem| {
                *out_elem = MaybeUninit::new(a_elem * b_elem);
            })
        });
    // SAFETY: the blocks tile `out` exactly, so every element was written above.
    unsafe { out.assume_init() }
}
/// Returns `true` when `A` and `B` are the same concrete type.
///
/// Used to route generic code to type-specific kernels.
#[inline(always)]
fn same_type<A: 'static, B: 'static>() -> bool
{
    let (lhs_id, rhs_id) = (TypeId::of::<A>(), TypeId::of::<B>());
    lhs_id == rhs_id
}
/// Copies `*a` out as a `B`, where `B` must be exactly the type `A`.
///
/// This lets generic code pass a value to a type-specific kernel after a
/// `same_type` check.
///
/// **Panics** if `A` and `B` are different types.
fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B
{
    let types_match = same_type::<A, B>();
    assert!(types_match, "expect type {} and {} to match",
        std::any::type_name::<A>(), std::any::type_name::<B>());
    // SAFETY: A and B are the same type (asserted above), so this is a plain copy.
    unsafe { (a as *const A).cast::<B>().read() }
}
#[inline]
fn complex_array<A: 'static + Copy>(z: Complex<A>) -> [A; 2]
{
[z.re, z.im]
}
/// Returns whether `a` can be passed to a BLAS level-1/2 routine for scalar type `A`.
///
/// This requires `B == A`, a length that fits `blas_index`, and a non-zero stride
/// that fits `blas_index`.
#[cfg(feature = "blas")]
fn blas_compat_1d<A, B>(a: &RawRef<B, Ix1>) -> bool
where
    A: 'static,
    B: 'static,
{
    if !same_type::<A, B>() || a.len() > blas_index::MAX as usize {
        return false;
    }
    let stride = a.strides()[0];
    let stride_fits = (blas_index::MIN as isize..=blas_index::MAX as isize).contains(&stride);
    stride != 0 && stride_fits
}
/// Memory order of a BLAS-compatible 2-D array.
#[cfg(feature = "blas")]
#[derive(Copy, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
enum BlasOrder
{
    /// Row-major (axis 1 contiguous); maps to `CblasRowMajor`.
    C,
    /// Column-major (axis 0 contiguous); maps to `CblasColMajor`.
    F,
}
#[cfg(feature = "blas")]
impl BlasOrder
{
    /// The memory order of the transposed array (C ↔ F).
    fn transpose(self) -> Self
    {
        if let Self::C = self {
            Self::F
        } else {
            Self::C
        }
    }

    /// The axis whose stride serves as BLAS's leading dimension:
    /// axis 0 for C order, axis 1 for F order.
    #[inline]
    fn get_blas_lead_axis(self) -> usize
    {
        if let Self::F = self {
            1
        } else {
            0
        }
    }

    /// The CBLAS layout constant for this memory order.
    fn to_cblas_layout(self) -> CBLAS_LAYOUT
    {
        if let Self::C = self {
            CBLAS_LAYOUT::CblasRowMajor
        } else {
            CBLAS_LAYOUT::CblasColMajor
        }
    }

    /// The transpose flag for an operand stored in this order, in a call that uses `for_layout`.
    ///
    /// An operand already in the call's layout needs no transpose. An operand in the
    /// opposite order is read as its transpose.
    fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE
    {
        match (self, for_layout) {
            (Self::C, CBLAS_LAYOUT::CblasRowMajor) | (Self::F, CBLAS_LAYOUT::CblasColMajor) => CblasNoTrans,
            (Self::F, CBLAS_LAYOUT::CblasRowMajor) | (Self::C, CBLAS_LAYOUT::CblasColMajor) => CblasTrans,
        }
    }
}
/// Returns whether an `m × n` array with the given strides can be passed to BLAS in `order`.
///
/// The following must all hold:
/// - the contiguous axis has stride 1, unless the other axis has length 1;
/// - both strides are positive and fit `blas_index`;
/// - the leading stride is at least the length of the contiguous axis, so rows (or
///   columns) do not overlap; this is only checked when there is more than one of them;
/// - both dimensions fit `blas_index`.
#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool
{
    let (m, n) = dim.into_pattern();
    let s0 = stride[0] as isize;
    let s1 = stride[1] as isize;
    // unit_stride: stride that must be 1; lead_stride: BLAS's ld;
    // lead_count: number of rows (C) / columns (F); unit_len: their length.
    let (unit_stride, lead_stride, lead_count, unit_len) = match order {
        BlasOrder::C => (s1, s0, m, n),
        BlasOrder::F => (s0, s1, n, m),
    };
    let fits_index = |s: isize| s >= blas_index::MIN as isize && s <= blas_index::MAX as isize;

    let contiguous = unit_stride == 1 || unit_len == 1;
    let positive = s0 >= 1 && s1 >= 1;
    let strides_fit = fits_index(s0) && fits_index(s1);
    let no_overlap = lead_count <= 1 || lead_stride as usize >= unit_len;
    let dims_fit = m <= blas_index::MAX as usize && n <= blas_index::MAX as usize;

    contiguous && positive && strides_fit && no_overlap && dims_fit
}
/// Returns the first BLAS-compatible memory order of `a`, trying C before F,
/// or `None` if neither fits.
#[cfg(feature = "blas")]
fn get_blas_compatible_layout<A>(a: &ArrayRef<A, Ix2>) -> Option<BlasOrder>
{
    let (dim, strides) = (a._dim(), a._strides());
    [BlasOrder::C, BlasOrder::F]
        .iter()
        .copied()
        .find(|&order| is_blas_2d(dim, strides, order))
}
/// The leading-dimension argument (lda/ldb/ldc) for `a` stored in `order`.
///
/// If the leading axis has length ≤ 1, ndarray may report any stride for it. BLAS
/// still requires ld ≥ the other axis length, so the stride is raised to at least that.
/// Assumes `a` was already found BLAS-compatible, so the cast does not truncate.
#[cfg(feature = "blas")]
fn blas_stride<A>(a: &ArrayRef<A, Ix2>, order: BlasOrder) -> blas_index
{
    let lead_axis = order.get_blas_lead_axis();
    let stride = a.strides()[lead_axis];
    let lead_len = a.shape()[lead_axis];
    let other_len = a.shape()[1 - lead_axis];
    let ld = if lead_len > 1 {
        stride
    } else {
        stride.max(other_len as isize)
    };
    ld as blas_index
}
/// Test helper: `a` has element type `A` and is BLAS-compatible in row-major (C) order.
#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, B>(a: &ArrayRef2<B>) -> bool
where
    A: 'static,
    B: 'static,
{
    same_type::<A, B>() && is_blas_2d(a._dim(), a._strides(), BlasOrder::C)
}
/// Test helper: `a` has element type `A` and is BLAS-compatible in column-major (F) order.
#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, B>(a: &ArrayRef2<B>) -> bool
where
    A: 'static,
    B: 'static,
{
    same_type::<A, B>() && is_blas_2d(a._dim(), a._strides(), BlasOrder::F)
}
// Layout-detection tests for the BLAS helpers (`is_blas_2d` via the row/column-major
// helpers, and `get_blas_compatible_layout`).
#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests
{
    use super::*;

    // A default (C-order) 3×5 matrix is row-major only.
    #[test]
    fn blas_row_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(!blas_column_major_2d::<f32, _>(&m));
    }

    // With a single row, the unit-stride requirement is waived for F order, so both fit.
    #[test]
    fn blas_row_major_2d_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    // A 5×1 column (strides 1, 1) fits both orders.
    #[test]
    fn blas_row_major_2d_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    // Transposing a 1×5 matrix gives 5×1 with strides (1, 5); both orders still fit.
    #[test]
    fn blas_row_major_2d_transposed_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    // Transposing a 5×1 matrix gives 1×5; both orders fit.
    #[test]
    fn blas_row_major_2d_transposed_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    // An F-order 3×5 matrix is column-major only.
    #[test]
    fn blas_column_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    // Skipping rows only enlarges the leading stride, so row-major stays valid.
    #[test]
    fn blas_row_major_2d_skip_rows_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![..;2, ..]);
        assert!(blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    // Skipping columns breaks the unit inner stride in both orders.
    #[test]
    fn blas_row_major_2d_skip_columns_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![.., ..;2]);
        assert!(!blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    // Column-major mirror of `blas_row_major_2d_skip_rows_ok`.
    #[test]
    fn blas_col_major_2d_skip_columns_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![.., ..;2]);
        assert!(blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    // Column-major mirror of `blas_row_major_2d_skip_columns_fail`.
    #[test]
    fn blas_col_major_2d_skip_rows_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![..;2, ..]);
        assert!(!blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    // A row stride smaller than the row length makes rows overlap in memory, which BLAS
    // cannot express; only a stride >= N is accepted.
    #[test]
    fn blas_too_short_stride()
    {
        const N: usize = 5;
        const MAXSTRIDE: usize = N + 2;
        let mut data = [0; MAXSTRIDE * N];
        let mut iter = 0..data.len();
        data.fill_with(|| iter.next().unwrap());
        for stride in 1..=MAXSTRIDE {
            let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();
            if stride < N {
                assert_eq!(get_blas_compatible_layout(&m), None);
            } else {
                assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C));
            }
        }
    }
}
/// Dot product for dynamic-dimensional arrays, dispatched on the runtime number of axes:
///
/// - 1-D · 1-D → a 0-D array holding the scalar product
/// - 2-D · 2-D → matrix product
/// - 2-D · 1-D → matrix-vector product
/// - 1-D · 2-D → vector-matrix product
///
/// **Panics** for any other combination of dimensionalities, or if the shapes are incompatible.
impl<A> Dot<ArrayRef<A, IxDyn>> for ArrayRef<A, IxDyn>
where A: LinalgScalar
{
    type Output = Array<A, IxDyn>;

    fn dot(&self, rhs: &ArrayRef<A, IxDyn>) -> Self::Output
    {
        let lhs_view = self.view();
        let rhs_view = rhs.view();
        match (lhs_view.ndim(), rhs_view.ndim()) {
            (1, 1) => {
                let u = lhs_view.into_dimensionality::<Ix1>().unwrap();
                let v = rhs_view.into_dimensionality::<Ix1>().unwrap();
                let scalar = u.dot(&v);
                ArrayD::from_elem(vec![], scalar)
            }
            (2, 2) => {
                let p = lhs_view.into_dimensionality::<Ix2>().unwrap();
                let q = rhs_view.into_dimensionality::<Ix2>().unwrap();
                p.dot(&q).into_dyn()
            }
            (2, 1) => {
                let p = lhs_view.into_dimensionality::<Ix2>().unwrap();
                let v = rhs_view.into_dimensionality::<Ix1>().unwrap();
                p.dot(&v).into_dyn()
            }
            (1, 2) => {
                let u = lhs_view.into_dimensionality::<Ix1>().unwrap();
                let q = rhs_view.into_dimensionality::<Ix2>().unwrap();
                u.dot(&q).into_dyn()
            }
            _ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
        }
    }
}