ndarray/linalg/impl_linalg.rs

// Copyright 2014-2020 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::imp_prelude::*;

#[cfg(feature = "blas")]
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::numeric_util;
use crate::ArrayRef1;
use crate::ArrayRef2;

use crate::{LinalgScalar, Zip};

#[cfg(not(feature = "std"))]
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

use std::any::TypeId;
use std::mem::MaybeUninit;

use num_complex::Complex;
use num_complex::{Complex32 as c32, Complex64 as c64};

#[cfg(feature = "blas")]
use libc::c_int;

#[cfg(feature = "blas")]
use cblas_sys as blas_sys;
#[cfg(feature = "blas")]
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE};

/// Minimum vector length before we use BLAS
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
/// Minimum matrix side length before we use BLAS
#[cfg(feature = "blas")]
const GEMM_BLAS_CUTOFF: usize = 7;
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
type blas_index = c_int; // BLAS index type

impl<A> ArrayRef<A, Ix1>
{
    /// Perform dot product or matrix multiplication of arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is one-dimensional, then the operation is a vector dot
    /// product, which is the sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product). In this case,
    /// `self` and `rhs` must be the same length.
    ///
    /// If `Rhs` is two-dimensional, then the operation is matrix
    /// multiplication, where `self` is treated as a row vector. In this case,
    /// if `self` is shape *M*, then `rhs` is shape *M* × *N* and the result is
    /// shape *N*.
    ///
    /// **Panics** if the array shapes are incompatible.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
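    ///
    /// A minimal example of both cases:
    ///
    /// ```
    /// use ndarray::{arr1, arr2};
    ///
    /// let v = arr1(&[1., 2., 3.]);
    /// let w = arr1(&[4., 5., 6.]);
    /// assert_eq!(v.dot(&w), 32.);
    ///
    /// // `v` as a row vector times a 3 × 2 matrix gives a length-2 result
    /// let m = arr2(&[[1., 0.],
    ///                [0., 1.],
    ///                [1., 1.]]);
    /// assert_eq!(v.dot(&m), arr1(&[4., 5.]));
    /// ```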
    #[track_caller]
    pub fn dot<Rhs: ?Sized>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where Self: Dot<Rhs>
    {
        Dot::dot(self, rhs)
    }

    fn dot_generic(&self, rhs: &ArrayRef<A, Ix1>) -> A
    where A: LinalgScalar
    {
        debug_assert_eq!(self.len(), rhs.len());
        assert!(self.len() == rhs.len());
        if let Some(self_s) = self.as_slice() {
            if let Some(rhs_s) = rhs.as_slice() {
                return numeric_util::unrolled_dot(self_s, rhs_s);
            }
        }
        let mut sum = A::zero();
        for i in 0..self.len() {
            unsafe {
                sum = sum + *self.uget(i) * *rhs.uget(i);
            }
        }
        sum
    }

    #[cfg(not(feature = "blas"))]
    fn dot_impl(&self, rhs: &ArrayRef<A, Ix1>) -> A
    where A: LinalgScalar
    {
        self.dot_generic(rhs)
    }

    #[cfg(feature = "blas")]
    fn dot_impl(&self, rhs: &ArrayRef<A, Ix1>) -> A
    where A: LinalgScalar
    {
        // Use BLAS only if the vector is large enough to be worth it
        if self.len() >= DOT_BLAS_CUTOFF {
            debug_assert_eq!(self.len(), rhs.len());
            assert!(self.len() == rhs.len());
            macro_rules! dot {
                ($ty:ty, $func:ident) => {{
                    if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
                        unsafe {
                            let (lhs_ptr, n, incx) =
                                blas_1d_params(self._ptr().as_ptr(), self.len(), self.strides()[0]);
                            let (rhs_ptr, _, incy) =
                                blas_1d_params(rhs._ptr().as_ptr(), rhs.len(), rhs.strides()[0]);
                            let ret = blas_sys::$func(
                                n,
                                lhs_ptr as *const $ty,
                                incx,
                                rhs_ptr as *const $ty,
                                incy,
                            );
                            return cast_as::<$ty, A>(&ret);
                        }
                    }
                }};
            }

            dot! {f32, cblas_sdot};
            dot! {f64, cblas_ddot};
        }
        self.dot_generic(rhs)
    }
}

/// Return a pointer to the starting element in BLAS's view.
///
/// BLAS wants a pointer to the element with lowest address,
/// which agrees with our pointer for non-negative strides, but
/// is at the opposite end for negative strides.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index)
{
    // [x x x x]
    //        ^--ptr
    //        stride = -1
    //  ^--blas_ptr = ptr + (len - 1) * stride
    if stride >= 0 || len == 0 {
        (ptr, len as blas_index, stride as blas_index)
    } else {
        let ptr = ptr.offset((len - 1) as isize * stride);
        (ptr, len as blas_index, stride as blas_index)
    }
}

/// Matrix Multiplication
///
/// For two-dimensional arrays, the dot method computes the matrix
/// multiplication.
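///
/// `Dot` can also serve as a generic bound. A minimal sketch (illustrative
/// only, not part of the trait's definition):
///
/// ```
/// use ndarray::arr1;
/// use ndarray::linalg::Dot;
///
/// fn norm_sq<T: Dot<T, Output = f64>>(v: &T) -> f64
/// {
///     v.dot(v)
/// }
///
/// assert_eq!(norm_sq(&arr1(&[3., 4.])), 25.);
/// ```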
pub trait Dot<Rhs: ?Sized>
{
    /// The result of the operation.
    ///
    /// For two-dimensional arrays: a rectangular array.
    type Output;

    /// Compute the dot product of two arrays.
    ///
    /// **Panics** if the arrays' shapes are not compatible.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}

macro_rules! impl_dots {
    (
        $shape1:ty,
        $shape2:ty
    ) => {
        impl<A, S, S2> Dot<ArrayBase<S2, $shape2>> for ArrayBase<S, $shape1>
        where
            S: Data<Elem = A>,
            S2: Data<Elem = A>,
            A: LinalgScalar,
        {
            type Output = <ArrayRef<A, $shape1> as Dot<ArrayRef<A, $shape2>>>::Output;

            fn dot(&self, rhs: &ArrayBase<S2, $shape2>) -> Self::Output
            {
                Dot::dot(&**self, &**rhs)
            }
        }

        impl<A, S> Dot<ArrayRef<A, $shape2>> for ArrayBase<S, $shape1>
        where
            S: Data<Elem = A>,
            A: LinalgScalar,
        {
            type Output = <ArrayRef<A, $shape1> as Dot<ArrayRef<A, $shape2>>>::Output;

            fn dot(&self, rhs: &ArrayRef<A, $shape2>) -> Self::Output
            {
                (**self).dot(rhs)
            }
        }

        impl<A, S> Dot<ArrayBase<S, $shape2>> for ArrayRef<A, $shape1>
        where
            S: Data<Elem = A>,
            A: LinalgScalar,
        {
            type Output = <ArrayRef<A, $shape1> as Dot<ArrayRef<A, $shape2>>>::Output;

            fn dot(&self, rhs: &ArrayBase<S, $shape2>) -> Self::Output
            {
                self.dot(&**rhs)
            }
        }
    };
}

impl_dots!(Ix1, Ix1);
impl_dots!(Ix1, Ix2);
impl_dots!(Ix2, Ix1);
impl_dots!(Ix2, Ix2);

impl<A> Dot<ArrayRef<A, Ix1>> for ArrayRef<A, Ix1>
where A: LinalgScalar
{
    type Output = A;

    /// Compute the dot product of one-dimensional arrays.
    ///
    /// The dot product is a sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product).
    ///
    /// **Panics** if the arrays are not of the same length.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
    #[track_caller]
    fn dot(&self, rhs: &ArrayRef<A, Ix1>) -> A
    {
        self.dot_impl(rhs)
    }
}

impl<A> Dot<ArrayRef<A, Ix2>> for ArrayRef<A, Ix1>
where A: LinalgScalar
{
    type Output = Array<A, Ix1>;

    /// Perform the matrix multiplication of the row vector `self` and
    /// rectangular matrix `rhs`.
    ///
    /// The array shapes must agree in the way that
    /// if `self` is *M*, then `rhs` is *M* × *N*.
    ///
    /// Return a result array with shape *N*.
    ///
    /// **Panics** if shapes are incompatible.
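    ///
    /// A minimal example:
    ///
    /// ```
    /// use ndarray::{arr1, arr2};
    ///
    /// let v = arr1(&[1., 2.]);
    /// let m = arr2(&[[0., 1.],
    ///                [1., 0.]]);
    /// assert_eq!(v.dot(&m), arr1(&[2., 1.]));
    /// ```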
    #[track_caller]
    fn dot(&self, rhs: &ArrayRef<A, Ix2>) -> Array<A, Ix1>
    {
        (*rhs.t()).dot(self)
    }
}

impl<A> ArrayRef<A, Ix2>
{
    /// Perform matrix multiplication of rectangular arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is two-dimensional, the array shapes must agree in the way that
    /// if `self` is *M* × *N*, then `rhs` is *N* × *K*.
    ///
    /// Return a result array with shape *M* × *K*.
    ///
    /// **Panics** if shapes are incompatible or the number of elements in the
    /// result would overflow `isize`.
    ///
    /// *Note:* If enabled, uses blas `gemv/gemm` for elements of `f32, f64`
    /// when memory layout allows. The default matrixmultiply backend
    /// is otherwise used for `f32, f64` for all memory layouts.
    ///
    /// ```
    /// use ndarray::arr2;
    ///
    /// let a = arr2(&[[1., 2.],
    ///                [0., 1.]]);
    /// let b = arr2(&[[1., 2.],
    ///                [2., 3.]]);
    ///
    /// assert!(
    ///     a.dot(&b) == arr2(&[[5., 8.],
    ///                         [2., 3.]])
    /// );
    /// ```
    #[track_caller]
    pub fn dot<Rhs: ?Sized>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where Self: Dot<Rhs>
    {
        Dot::dot(self, rhs)
    }
}

impl<A> Dot<ArrayRef<A, Ix2>> for ArrayRef<A, Ix2>
where A: LinalgScalar
{
    type Output = Array2<A>;

    fn dot(&self, b: &ArrayRef<A, Ix2>) -> Array2<A>
    {
        let a = self.view();
        let b = b.view();
        let ((m, k), (k2, n)) = (a.dim(), b.dim());
        if k != k2 || m.checked_mul(n).is_none() {
            dot_shape_error(m, k, k2, n);
        }

        let lhs_s0 = a.strides()[0];
        let rhs_s0 = b.strides()[0];
        let column_major = lhs_s0 == 1 && rhs_s0 == 1;
        // A is Copy so this is safe
        let mut v = Vec::with_capacity(m * n);
        let mut c;
        unsafe {
            v.set_len(m * n);
            c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
        }
        mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
        c
    }
}

/// Assumes that `m` and `n` are ≤ `isize::MAX`.
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> !
{
    match m.checked_mul(n) {
        Some(len) if len <= isize::MAX as usize => {}
        _ => panic!("ndarray: shape {} × {} overflows isize", m, n),
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}

#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> !
{
    panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
           m, k, k2, n, c1, c2);
}

/// Perform the matrix multiplication of the rectangular array `self` and
/// column vector `rhs`.
///
/// The array shapes must agree in the way that
/// if `self` is *M* × *N*, then `rhs` is *N*.
///
/// Return a result array with shape *M*.
///
/// **Panics** if shapes are incompatible.
impl<A> Dot<ArrayRef<A, Ix1>> for ArrayRef<A, Ix2>
where A: LinalgScalar
{
    type Output = Array<A, Ix1>;

    #[track_caller]
    fn dot(&self, rhs: &ArrayRef<A, Ix1>) -> Array<A, Ix1>
    {
        let ((m, a), n) = (self.dim(), rhs.dim());
        if a != n {
            dot_shape_error(m, a, n, 1);
        }

        // Avoid initializing the memory in vec -- set it during iteration
        unsafe {
            let mut c = Array1::uninit(m);
            general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
            c.assume_init()
        }
    }
}

impl<A, D> ArrayRef<A, D>
where D: Dimension
{
    /// Perform the operation `self += alpha * rhs` efficiently, where
    /// `alpha` is a scalar and `rhs` is another array. This operation is
    /// also known as `axpy` in BLAS.
    ///
    /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
    ///
    /// **Panics** if broadcasting isn't possible.
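    ///
    /// A minimal example:
    ///
    /// ```
    /// use ndarray::arr1;
    ///
    /// let mut y = arr1(&[1., 2.]);
    /// y.scaled_add(2., &arr1(&[3., 4.]));
    /// assert_eq!(y, arr1(&[7., 10.]));
    /// ```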
    #[track_caller]
    pub fn scaled_add<E>(&mut self, alpha: A, rhs: &ArrayRef<A, E>)
    where
        A: LinalgScalar,
        E: Dimension,
    {
        self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
    }
}

// mat_mul_impl uses ArrayRef arguments to send all array kinds into
// the same instantiated implementation.
#[cfg(not(feature = "blas"))]
use self::mat_mul_general as mat_mul_impl;

#[cfg(feature = "blas")]
fn mat_mul_impl<A>(alpha: A, a: &ArrayRef2<A>, b: &ArrayRef2<A>, beta: A, c: &mut ArrayRef2<A>)
where A: LinalgScalar
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    debug_assert_eq!(k, k2);
    if (m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF)
        && (same_type::<A, f32>() || same_type::<A, f64>() || same_type::<A, c32>() || same_type::<A, c64>())
    {
        // Compute A B -> C
        // We require for BLAS compatibility that:
        // A, B, C are contiguous (stride=1) in their fastest dimension,
        // but they can be either row major/"c" or col major/"f".
        //
        // The "normal case" is CblasRowMajor for cblas.
        // Select CblasRowMajor / CblasColMajor to fit C's memory order.
        //
        // Apply transpose to A, B as needed if they differ from the row major case.
        // If C is CblasColMajor then transpose both A, B (again!)
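        //
        // For example, a worked instance of the rule above: if C is col major
        // ("f") and A is row major ("c"), then relative to CblasColMajor A's
        // effective order is transposed, so A is passed as CblasTrans.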

        if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
            (get_blas_compatible_layout(a), get_blas_compatible_layout(b), get_blas_compatible_layout(c))
        {
            let cblas_layout = c_layout.to_cblas_layout();
            let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
            let lda = blas_stride(&a, a_layout);

            let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
            let ldb = blas_stride(&b, b_layout);

            let ldc = blas_stride(&c, c_layout);

            macro_rules! gemm_scalar_cast {
                (f32, $var:ident) => {
                    cast_as(&$var)
                };
                (f64, $var:ident) => {
                    cast_as(&$var)
                };
                (c32, $var:ident) => {
                    &$var as *const A as *const _
                };
                (c64, $var:ident) => {
                    &$var as *const A as *const _
                };
            }

            macro_rules! gemm {
                ($ty:tt, $gemm:ident) => {
                    if same_type::<A, $ty>() {
                        // gemm is C ← αA^Op B^Op + βC
                        // where Op is notrans/trans/conjtrans
                        unsafe {
                            blas_sys::$gemm(
                                cblas_layout,
                                a_trans,
                                b_trans,
                                m as blas_index,               // m, rows of Op(a)
                                n as blas_index,               // n, cols of Op(b)
                                k as blas_index,               // k, cols of Op(a)
                                gemm_scalar_cast!($ty, alpha), // alpha
                                a._ptr().as_ptr() as *const _, // a
                                lda,                           // lda
                                b._ptr().as_ptr() as *const _, // b
                                ldb,                           // ldb
                                gemm_scalar_cast!($ty, beta),  // beta
                                c._ptr().as_ptr() as *mut _,   // c
                                ldc,                           // ldc
                            );
                        }
                        return;
                    }
                };
            }

            gemm!(f32, cblas_sgemm);
            gemm!(f64, cblas_dgemm);
            gemm!(c32, cblas_cgemm);
            gemm!(c64, cblas_zgemm);

            unreachable!() // we checked above that A is one of f32, f64, c32, c64
        }
    }
    mat_mul_general(alpha, a, b, beta, c)
}

/// C ← α A B + β C
fn mat_mul_general<A>(alpha: A, lhs: &ArrayRef2<A>, rhs: &ArrayRef2<A>, beta: A, c: &mut ArrayRef2<A>)
where A: LinalgScalar
{
    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());

    // common parameters for gemm
    let ap = lhs.as_ptr();
    let bp = rhs.as_ptr();
    let cp = c.as_mut_ptr();
    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
    if same_type::<A, f32>() {
        unsafe {
            matrixmultiply::sgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, f64>() {
        unsafe {
            matrixmultiply::dgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, c32>() {
        unsafe {
            matrixmultiply::cgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m,
                k,
                n,
                complex_array(cast_as(&alpha)),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                complex_array(cast_as(&beta)),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, c64>() {
        unsafe {
            matrixmultiply::zgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m,
                k,
                n,
                complex_array(cast_as(&alpha)),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                complex_array(cast_as(&beta)),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else {
        // It's a no-op if `c` has zero length.
        if c.is_empty() {
            return;
        }

        // initialize memory if beta is zero
        if beta.is_zero() {
            c.fill(beta);
        }

        // Naive matrix multiplication: visit every element of c, with j
        // (the column index) as the fastest-moving index.
        let mut i = 0;
        let mut j = 0;
        loop {
            unsafe {
                let elt = c.uget_mut((i, j));
                *elt =
                    *elt * beta + alpha * (0..k).fold(A::zero(), move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j)));
            }
            j += 1;
            if j == n {
                j = 0;
                i += 1;
                if i == m {
                    break;
                }
            }
        }
    }
}

/// General matrix-matrix multiplication.
///
/// Compute C ← α A B + β C
///
/// The array shapes must agree in the way that
/// if `a` is *M* × *N*, then `b` is *N* × *K* and `c` is *M* × *K*.
///
/// ***Panics*** if array shapes are not compatible<br>
/// *Note:* If enabled, uses blas `gemm` for elements of `f32, f64` when memory
/// layout allows.  The default matrixmultiply backend is otherwise used for
/// `f32, f64` for all memory layouts.
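///
/// A minimal example:
///
/// ```
/// use ndarray::arr2;
/// use ndarray::linalg::general_mat_mul;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let b = arr2(&[[1., 0.],
///                [0., 1.]]);
/// let mut c = arr2(&[[1., 1.],
///                    [1., 1.]]);
/// // c ← 1 · a b + 1 · c
/// general_mat_mul(1., &a, &b, 1., &mut c);
/// assert_eq!(c, arr2(&[[2., 3.],
///                      [4., 5.]]));
/// ```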
#[track_caller]
pub fn general_mat_mul<A>(alpha: A, a: &ArrayRef2<A>, b: &ArrayRef2<A>, beta: A, c: &mut ArrayRef2<A>)
where A: LinalgScalar
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    let (m2, n2) = c.dim();
    if k != k2 || m != m2 || n != n2 {
        general_dot_shape_error(m, k, k2, n, m2, n2);
    } else {
        mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
    }
}

/// General matrix-vector multiplication.
///
/// Compute y ← α A x + β y
///
/// where A is an *M* × *N* matrix, x is an *N*-element column vector and
/// y is an *M*-element column vector (one dimensional arrays).
///
/// ***Panics*** if array shapes are not compatible<br>
/// *Note:* If enabled, uses blas `gemv` for elements of `f32, f64` when memory
/// layout allows.
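///
/// A minimal example:
///
/// ```
/// use ndarray::{arr1, arr2};
/// use ndarray::linalg::general_mat_vec_mul;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let x = arr1(&[1., 1.]);
/// let mut y = arr1(&[1., 1.]);
/// // y ← 1 · a x + 0 · y
/// general_mat_vec_mul(1., &a, &x, 0., &mut y);
/// assert_eq!(y, arr1(&[3., 7.]));
/// ```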
#[track_caller]
#[allow(clippy::collapsible_if)]
pub fn general_mat_vec_mul<A>(alpha: A, a: &ArrayRef2<A>, x: &ArrayRef1<A>, beta: A, y: &mut ArrayRef1<A>)
where A: LinalgScalar
{
    unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
}

/// General matrix-vector multiplication
///
/// Use a raw view for the destination vector, so that it can be uninitialized.
///
/// ## Safety
///
/// The caller must ensure that the raw view is valid for writing.
/// The destination may be uninitialized iff beta is zero.
#[allow(clippy::collapsible_else_if)]
unsafe fn general_mat_vec_mul_impl<A>(
    alpha: A, a: &ArrayRef2<A>, x: &ArrayRef1<A>, beta: A, y: RawArrayViewMut<A, Ix1>,
) where A: LinalgScalar
{
    let ((m, k), k2) = (a.dim(), x.dim());
    let m2 = y.dim();
    if k != k2 || m != m2 {
        general_dot_shape_error(m, k, k2, 1, m2, 1);
    } else {
        #[cfg(feature = "blas")]
        macro_rules! gemv {
            ($ty:ty, $gemv:ident) => {
                if same_type::<A, $ty>() {
                    if let Some(layout) = get_blas_compatible_layout(&a) {
                        if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y.as_ref()) {
                            // Determine stride between rows or columns. Note that the stride is
                            // adjusted to at least `k` or `m` to handle the case of a matrix with a
                            // trivial (length 1) dimension, since the stride for the trivial dimension
                            // may be arbitrary.
                            let a_trans = CblasNoTrans;

                            let a_stride = blas_stride(&a, layout);
                            let cblas_layout = layout.to_cblas_layout();

                            // Low addr in memory pointers required for x, y
                            let x_offset = offset_from_low_addr_ptr_to_logical_ptr(x._dim(), x._strides());
                            let x_ptr = x._ptr().as_ptr().sub(x_offset);
                            let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.parts.dim, &y.parts.strides);
                            let y_ptr = y.parts.ptr.as_ptr().sub(y_offset);

                            let x_stride = x.strides()[0] as blas_index;
                            let y_stride = y.strides()[0] as blas_index;

                            blas_sys::$gemv(
                                cblas_layout,
                                a_trans,
                                m as blas_index,               // m, rows of Op(a)
                                k as blas_index,               // n, cols of Op(a)
                                cast_as(&alpha),               // alpha
                                a._ptr().as_ptr() as *const _, // a
                                a_stride,                      // lda
                                x_ptr as *const _,             // x
                                x_stride,                      // incx
                                cast_as(&beta),                // beta
                                y_ptr as *mut _,               // y
                                y_stride,                      // incy
                            );
                            return;
                        }
                    }
                }
            };
        }
        #[cfg(feature = "blas")]
        gemv!(f32, cblas_sgemv);
        #[cfg(feature = "blas")]
        gemv!(f64, cblas_dgemv);

        /* general */

        if beta.is_zero() {
            // when beta is zero, y may be uninitialized
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                elt.write(row.dot(x) * alpha);
            });
        } else {
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                *elt = *elt * beta + row.dot(x) * alpha;
            });
        }
    }
}

/// Kronecker product of 2D matrices.
///
/// The Kronecker product of an *L* × *N* matrix `a` and an *M* × *R* matrix `b`
/// is an (*L*·*M*) × (*N*·*R*) matrix formed from the blocks `a[[i, j]] * b`.
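///
/// A minimal example:
///
/// ```
/// use ndarray::arr2;
/// use ndarray::linalg::kron;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let b = arr2(&[[0., 1.],
///                [1., 0.]]);
/// assert_eq!(kron(&a, &b),
///            arr2(&[[0., 1., 0., 2.],
///                   [1., 0., 2., 0.],
///                   [0., 3., 0., 4.],
///                   [3., 0., 4., 0.]]));
/// ```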
pub fn kron<A>(a: &ArrayRef2<A>, b: &ArrayRef2<A>) -> Array<A, Ix2>
where A: LinalgScalar
{
    let dimar = a.shape()[0];
    let dimac = a.shape()[1];
    let dimbr = b.shape()[0];
    let dimbc = b.shape()[1];
    let mut out: Array2<MaybeUninit<A>> = Array2::uninit((
        dimar
            .checked_mul(dimbr)
            .expect("Dimensions of kronecker product output array overflow usize."),
        dimac
            .checked_mul(dimbc)
            .expect("Dimensions of kronecker product output array overflow usize."),
    ));
    Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
        .and(a)
        .for_each(|out, &a| {
            Zip::from(out).and(b).for_each(|out, &b| {
                *out = MaybeUninit::new(a * b);
            })
        });
    unsafe { out.assume_init() }
}

#[inline(always)]
/// Return `true` if `A` and `B` are the same type
fn same_type<A: 'static, B: 'static>() -> bool
{
    TypeId::of::<A>() == TypeId::of::<B>()
}

// Read pointer to type `A` as type `B`.
//
// **Panics** if `A` and `B` are not the same type
fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B
{
    assert!(same_type::<A, B>(), "expect type {} and {} to match",
            std::any::type_name::<A>(), std::any::type_name::<B>());
    unsafe { ::std::ptr::read(a as *const _ as *const B) }
}

/// Return the complex number in the form of an array [re, im]
#[inline]
fn complex_array<A: 'static + Copy>(z: Complex<A>) -> [A; 2]
{
    [z.re, z.im]
}

#[cfg(feature = "blas")]
fn blas_compat_1d<A, B>(a: &RawRef<B, Ix1>) -> bool
where
    A: 'static,
    B: 'static,
{
    if !same_type::<A, B>() {
        return false;
    }
    if a.len() > blas_index::MAX as usize {
        return false;
    }
    let stride = a.strides()[0];
    if stride == 0 || stride > blas_index::MAX as isize || stride < blas_index::MIN as isize {
        return false;
    }
    true
}

#[cfg(feature = "blas")]
#[derive(Copy, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
enum BlasOrder
{
    C,
    F,
}

#[cfg(feature = "blas")]
impl BlasOrder
{
    fn transpose(self) -> Self
    {
        match self {
            Self::C => Self::F,
            Self::F => Self::C,
        }
    }

    #[inline]
    /// Axis of leading stride (opposite of contiguous axis)
    fn get_blas_lead_axis(self) -> usize
    {
        match self {
            Self::C => 0,
            Self::F => 1,
        }
    }

    fn to_cblas_layout(self) -> CBLAS_LAYOUT
    {
        match self {
            Self::C => CBLAS_LAYOUT::CblasRowMajor,
            Self::F => CBLAS_LAYOUT::CblasColMajor,
        }
    }

    /// When using cblas_sgemm (etc.) with the C matrix using `for_layout`,
    /// how should this `self` matrix be transposed?
    fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE
    {
        let effective_order = match for_layout {
            CBLAS_LAYOUT::CblasRowMajor => self,
            CBLAS_LAYOUT::CblasColMajor => self.transpose(),
        };

        match effective_order {
            Self::C => CblasNoTrans,
            Self::F => CblasTrans,
        }
    }
}

#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool
{
    let (m, n) = dim.into_pattern();
    let s0 = stride[0] as isize;
    let s1 = stride[1] as isize;
    let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
        BlasOrder::C => (s1, s0, m, n),
        BlasOrder::F => (s0, s1, n, m),
    };

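    // Example (C order): a 5 × 5 f32 array sliced as `s![..;2, ..]` has shape
    // (3, 5) and strides (10, 1): the inner stride is 1 and the leading stride
    // 10 ≥ 5, so it passes the checks below; `s![.., ..;2]` (shape (5, 3),
    // strides (5, 2)) fails because its inner stride is 2.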
    if !(inner_stride == 1 || outer_dim == 1) {
        return false;
    }

    if s0 < 1 || s1 < 1 {
        return false;
    }

    if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize)
        || (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize)
    {
        return false;
    }

    // The leading stride must be >= the length of the other dimension
    // (no broadcasting/aliasing)
    if inner_dim > 1 && (outer_stride as usize) < outer_dim {
        return false;
    }

    if m > blas_index::MAX as usize || n > blas_index::MAX as usize {
        return false;
    }

    true
}

/// Get BLAS compatible layout if any (C or F, preferring the former)
#[cfg(feature = "blas")]
fn get_blas_compatible_layout<A>(a: &ArrayRef<A, Ix2>) -> Option<BlasOrder>
{
    if is_blas_2d(a._dim(), a._strides(), BlasOrder::C) {
        Some(BlasOrder::C)
    } else if is_blas_2d(a._dim(), a._strides(), BlasOrder::F) {
        Some(BlasOrder::F)
    } else {
        None
    }
}

/// `a` should be BLAS compatible.
///
/// Return the leading stride (lda, ldb or ldc) of the array.
#[cfg(feature = "blas")]
fn blas_stride<A>(a: &ArrayRef<A, Ix2>, order: BlasOrder) -> blas_index
{
    let axis = order.get_blas_lead_axis();
    let other_axis = 1 - axis;
    let len_this = a.shape()[axis];
    let len_other = a.shape()[other_axis];
    let stride = a.strides()[axis];

    // If the current axis has length == 1, then its stride does not matter for
    // ndarray, but for BLAS we need a stride that makes sense, i.e. one that is
    // >= the length of the other axis.

    // cast: a should already be blas compatible
    (if len_this <= 1 {
        Ord::max(stride, len_other as isize)
    } else {
        stride
    }) as blas_index
}

#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, B>(a: &ArrayRef2<B>) -> bool
where
    A: 'static,
    B: 'static,
{
    if !same_type::<A, B>() {
        return false;
    }
    is_blas_2d(a._dim(), a._strides(), BlasOrder::C)
}

#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, B>(a: &ArrayRef2<B>) -> bool
where
    A: 'static,
    B: 'static,
{
    if !same_type::<A, B>() {
        return false;
    }
    is_blas_2d(a._dim(), a._strides(), BlasOrder::F)
}

#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests
{
    use super::*;

    #[test]
    fn blas_row_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(!blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_transposed_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_row_major_2d_transposed_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_column_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_skip_rows_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![..;2, ..]);
        assert!(blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_row_major_2d_skip_columns_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![.., ..;2]);
        assert!(!blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_col_major_2d_skip_columns_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![.., ..;2]);
        assert!(blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_col_major_2d_skip_rows_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![..;2, ..]);
        assert!(!blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_too_short_stride()
    {
        // The leading stride must be at least as long as the other dimension.
        // For example, in a 5 × 5 matrix, the leading stride must be >= 5 for BLAS.

        const N: usize = 5;
        const MAXSTRIDE: usize = N + 2;
        let mut data = [0; MAXSTRIDE * N];
        let mut iter = 0..data.len();
        data.fill_with(|| iter.next().unwrap());

        for stride in 1..=MAXSTRIDE {
            let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();

            if stride < N {
                assert_eq!(get_blas_compatible_layout(&m), None);
            } else {
                assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C));
            }
        }
    }
}

/// Dot product for dynamic-dimensional arrays (`ArrayD`).
///
/// For one-dimensional arrays, computes the vector dot product, which is the sum
/// of the elementwise products (no conjugation of complex operands).
/// Both arrays must have the same length.
///
/// For two-dimensional arrays, performs matrix multiplication. The array shapes
/// must be compatible in the following ways:
/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication
/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M*
/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N*
/// - If both arrays are one-dimensional of length *N*, returns a scalar
///
/// **Panics** if:
/// - The arrays have dimensions other than 1 or 2
/// - The array shapes are incompatible for the operation
/// - For vector dot product: the vectors have different lengths
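///
/// A minimal usage sketch; the `Dot` trait must be imported, since `IxDyn`
/// arrays have no inherent `dot` method:
///
/// ```
/// use ndarray::{arr1, arr2};
/// use ndarray::linalg::Dot;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]).into_dyn();
/// let x = arr1(&[1., 1.]).into_dyn();
/// // 2-D × 1-D: matrix-vector product
/// assert_eq!(a.dot(&x), arr1(&[3., 7.]).into_dyn());
/// ```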
impl<A> Dot<ArrayRef<A, IxDyn>> for ArrayRef<A, IxDyn>
where A: LinalgScalar
{
    type Output = Array<A, IxDyn>;

    fn dot(&self, rhs: &ArrayRef<A, IxDyn>) -> Self::Output
    {
        match (self.ndim(), rhs.ndim()) {
            (1, 1) => {
                // Vector-vector dot product; wrap the scalar result in a
                // zero-dimensional array
                let a = self.view().into_dimensionality::<Ix1>().unwrap();
                let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
                let result = a.dot(&b);
                ArrayD::from_elem(vec![], result)
            }
            (2, 2) => {
                // Matrix-matrix multiplication
                let a = self.view().into_dimensionality::<Ix2>().unwrap();
                let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
                let result = a.dot(&b);
                result.into_dimensionality::<IxDyn>().unwrap()
            }
            (2, 1) => {
                // Matrix-vector multiplication
                let a = self.view().into_dimensionality::<Ix2>().unwrap();
                let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
                let result = a.dot(&b);
                result.into_dimensionality::<IxDyn>().unwrap()
            }
            (1, 2) => {
                // Vector-matrix multiplication
                let a = self.view().into_dimensionality::<Ix1>().unwrap();
                let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
                let result = a.dot(&b);
                result.into_dimensionality::<IxDyn>().unwrap()
            }
            _ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
        }
    }
}