autograd/ops/
dot_ops.rs

1/// Some gemm kernel usages are ported from ndarray
2use crate::ndarray_ext::NdArray;
3#[cfg(feature = "mkl")]
4use crate::ndarray_ext::{get_batch_ptrs, get_batch_ptrs_mut};
5#[cfg(feature = "mkl")]
6use crate::ops::mkl_ffi::*;
7use crate::same_type;
8use crate::tensor::Tensor;
9use crate::Float;
10use crate::NdArrayView;
11use crate::{op, NdArrayViewMut};
12use ndarray;
13#[cfg(feature = "mkl")]
14use ndarray::Dimension;
15use ndarray::{ArrayView2, ArrayViewMut2};
16#[cfg(feature = "mkl")]
17use std::cmp;
18#[cfg(feature = "mkl")]
19use std::mem;
20
#[cfg(feature = "mkl")]
#[inline]
fn blas_row_major_2d<T: 'static, F>(a: &ndarray::ArrayView2<F>) -> bool
where
    F: Float,
{
    // A view is BLAS-compatible only when its element type is exactly the
    // kernel type `T` and its layout passes the 2-D row-major checks.
    same_type::<F, T>() && is_blas_2d(&a.raw_dim(), a.strides(), MemoryOrder::C)
}
32
#[cfg(feature = "mkl")]
#[inline]
fn blas_row_major_nd<T: 'static, F>(a: &NdArrayView<F>) -> bool
where
    F: Float,
{
    // Element type must match the kernel type `T` exactly.
    if !same_type::<F, T>() {
        return false;
    }
    // Only the trailing two axes participate in the gemm call, so only their
    // strides are forwarded to the layout check.
    let s = a.strides();
    let rank = s.len();
    is_blas_nd(a.shape(), s[rank - 2], s[rank - 1], MemoryOrder::C)
}
51
#[cfg(feature = "mkl")]
#[inline]
fn blas_row_major_2d_mut<T: 'static, F>(a: &ndarray::ArrayViewMut2<F>) -> bool
where
    F: Float,
{
    // Mutable-view twin of `blas_row_major_2d`: same type and layout checks.
    same_type::<F, T>() && is_blas_2d(&a.raw_dim(), a.strides(), MemoryOrder::C)
}
63
#[cfg(feature = "mkl")]
#[inline]
fn blas_row_major_nd_mut<T: 'static, F>(a: &NdArrayViewMut<F>) -> bool
where
    F: Float,
{
    // Mutable-view twin of `blas_row_major_nd`.
    if !same_type::<F, T>() {
        return false;
    }
    // Only the trailing two axes participate in the gemm call.
    let s = a.strides();
    let rank = s.len();
    is_blas_nd(a.shape(), s[rank - 2], s[rank - 1], MemoryOrder::C)
}
82
/// Returns true if an N-D view whose trailing two axes have strides
/// `stride0`/`stride1` can be handed to a CBLAS batch kernel directly.
///
/// `shape` is the array's full shape; only its trailing two axes describe
/// the matrices being multiplied.
#[cfg(feature = "mkl")]
fn is_blas_nd(shape: &[usize], stride0: isize, stride1: isize, order: MemoryOrder) -> bool {
    // FIX: the matrix dimensions are the *last* two axes. The previous code
    // read shape[0]/shape[1], which for rank > 2 validated batch dimensions
    // instead of the matrix dimensions (e.g. a [B, 1, N] view with a non-unit
    // inner stride was wrongly accepted because shape[1] == 1).
    let rank = shape.len();
    let (m, n) = (shape[rank - 2], shape[rank - 1]);
    let (inner_stride, outer_dim) = match order {
        MemoryOrder::C => (stride1, n),
        MemoryOrder::F => (stride0, m),
    };
    // The inner axis must be contiguous unless it is degenerate (length 1).
    if !(inner_stride == 1 || outer_dim == 1) {
        return false;
    }
    // CBLAS cannot express non-positive strides.
    if stride0 < 1 || stride1 < 1 {
        return false;
    }
    // Strides must fit the MKL integer type.
    if (stride0 > MklInt::max_value() as isize || stride0 < MklInt::min_value() as isize)
        || (stride1 > MklInt::max_value() as isize || stride1 < MklInt::min_value() as isize)
    {
        return false;
    }
    // Matrix dimensions must fit the MKL integer type as well.
    if m > MklInt::max_value() as usize || n > MklInt::max_value() as usize {
        return false;
    }
    true
}
106
/// Returns true if a 2-D view with dimensions `dim` and strides `stride`
/// is directly usable by a CBLAS kernel with the given memory `order`.
#[cfg(feature = "mkl")]
fn is_blas_2d(dim: &ndarray::Ix2, stride: &[isize], order: MemoryOrder) -> bool {
    let (m, n) = dim.into_pattern();
    // `stride` is already &[isize]; the previous `as isize` casts were no-ops.
    let s0 = stride[0];
    let s1 = stride[1];
    let (inner_stride, outer_dim) = match order {
        MemoryOrder::C => (s1, n),
        MemoryOrder::F => (s0, m),
    };
    // The inner axis must be contiguous unless it is degenerate (length 1).
    if !(inner_stride == 1 || outer_dim == 1) {
        return false;
    }
    // CBLAS cannot express non-positive strides.
    if s0 < 1 || s1 < 1 {
        return false;
    }
    // Strides must fit the MKL integer type.
    if (s0 > MklInt::max_value() as isize || s0 < MklInt::min_value() as isize)
        || (s1 > MklInt::max_value() as isize || s1 < MklInt::min_value() as isize)
    {
        return false;
    }
    // Dimensions must fit the MKL integer type as well.
    if m > MklInt::max_value() as usize || n > MklInt::max_value() as usize {
        return false;
    }
    true
}
132
133// Read pointer to type `A` as type `B`.
134//
135// **Panics** if `A` and `B` are not the same type
136#[inline]
137fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
138    assert!(same_type::<A, B>());
139    unsafe { ::std::ptr::read(a as *const _ as *const B) }
140}
141
// mkl version of ndarray's mat_mul_impl
//
// Computes C ← α·lhs·rhs + β·C, dispatching to CBLAS sgemm/dgemm for large
// enough f32/f64 matrices and to `mat_mul_impl_slow` otherwise.
#[cfg(feature = "mkl")]
fn mat_mul_impl_blas<F: Float>(
    alpha: F,
    lhs: &ArrayView2<'_, F>,
    rhs: &ArrayView2<'_, F>,
    beta: F,
    c: &mut ArrayViewMut2<'_, F>,
) {
    const GEMM_BLAS_CUTOFF: usize = 7;

    // size cutoff for using BLAS
    let cut = GEMM_BLAS_CUTOFF;
    let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
    // Small problems and non-f32/f64 element types go to the generic kernel.
    if !(m > cut || n > cut || a > cut) || !(same_type::<F, f32>() || same_type::<F, f64>()) {
        return mat_mul_impl_slow(alpha, lhs, rhs, beta, c);
    }
    {
        // Use `c` for c-order and `f` for an f-order matrix
        // We can handle c * c, f * f generally and
        // c * f and f * c if the `f` matrix is square.
        let mut lhs_ = lhs.view();
        let mut rhs_ = rhs.view();
        let mut c_ = c.view_mut();
        let lhs_s0 = lhs_.strides()[0];
        let rhs_s0 = rhs_.strides()[0];
        let both_f = lhs_s0 == 1 && rhs_s0 == 1;
        let mut lhs_trans = CblasTranspose::CblasNoTrans;
        let mut rhs_trans = CblasTranspose::CblasNoTrans;
        if both_f {
            // A^t B^t = C^t => B A = C
            // Both operands are f-order: compute the transposed product with
            // swapped operands so everything becomes c-order.
            let lhs_t = lhs_.reversed_axes();
            lhs_ = rhs_.reversed_axes();
            rhs_ = lhs_t;
            c_ = c_.reversed_axes();
            mem::swap(&mut m, &mut n);
        } else if lhs_s0 == 1 && m == a {
            // Square f-order lhs: express it as a transposed c-order matrix.
            lhs_ = lhs_.reversed_axes();
            lhs_trans = CblasTranspose::CblasTrans;
        } else if rhs_s0 == 1 && a == n {
            // Square f-order rhs: express it as a transposed c-order matrix.
            rhs_ = rhs_.reversed_axes();
            rhs_trans = CblasTranspose::CblasTrans;
        }

        macro_rules! call_kernel_def {
            ($ty:ty, $f:ident) => {
                // Dispatch only when every operand is a BLAS-compatible
                // row-major matrix of element type $ty.
                if blas_row_major_2d::<$ty, _>(&lhs_)
                    && blas_row_major_2d::<$ty, _>(&rhs_)
                    && blas_row_major_2d_mut::<$ty, _>(&c_)
                {
                    // Dimensions of Op(lhs)/Op(rhs) as seen by gemm.
                    let (m, k) = match lhs_trans {
                        CblasTranspose::CblasNoTrans => lhs_.dim(),
                        _ => {
                            let (rows, cols) = lhs_.dim();
                            (cols, rows)
                        }
                    };
                    let n = match rhs_trans {
                        CblasTranspose::CblasNoTrans => rhs_.raw_dim()[1],
                        _ => rhs_.raw_dim()[0],
                    };
                    // adjust strides, these may be [1, 1] for column matrices
                    let lhs_stride = cmp::max(lhs_.strides()[0] as MklInt, k as MklInt);
                    let rhs_stride = cmp::max(rhs_.strides()[0] as MklInt, n as MklInt);
                    let c_stride = cmp::max(c_.strides()[0] as MklInt, n as MklInt);

                    // gemm is C ← αA^Op B^Op + βC
                    // Where Op is notrans/trans/conjtrans
                    unsafe {
                        $f(
                            CBLAS_ROW_MAJOR,
                            lhs_trans,
                            rhs_trans,
                            m as MklInt,               // m, rows of Op(a)
                            n as MklInt,               // n, cols of Op(b)
                            k as MklInt,               // k, cols of Op(a)
                            cast_as(&alpha),           // alpha
                            lhs_.as_ptr() as *const _, // a
                            lhs_stride,                // lda
                            rhs_.as_ptr() as *const _, // b
                            rhs_stride,                // ldb
                            cast_as(&beta),            // beta
                            c_.as_mut_ptr() as *mut _, // c
                            c_stride,                  // ldc
                        );
                    }
                    return;
                }
            };
        }
        call_kernel_def!(f32, cblas_sgemm);
        call_kernel_def!(f64, cblas_dgemm);
    }
    // No BLAS fast path matched; fall back to the generic kernel.
    mat_mul_impl_slow(alpha, lhs, rhs, beta, c)
}
237
// C_i ← α·lhs_i·rhs_i + β·C_i for each trailing 2-D matrix pair, using MKL's
// batched CBLAS gemm. Falls back to `batch_mat_mul_impl_slow` when the
// operands are not BLAS-compatible f32/f64 arrays.
#[allow(unused_assignments)]
#[cfg(feature = "mkl")]
fn batch_mat_mul_impl<F: Float>(
    alpha: F,
    lhs: &NdArrayView<'_, F>,
    rhs: &NdArrayView<'_, F>,
    beta: F,
    c: &mut NdArrayViewMut<'_, F>,
) {
    let lhs_shape = lhs.shape();
    let rhs_shape = rhs.shape();
    let rank = lhs.ndim();
    // m × a times a × n matrices in every batch.
    let (mut m, a, mut n) = (
        lhs_shape[rank - 2],
        lhs_shape[rank - 1],
        rhs_shape[rank - 1],
    );

    {
        // Use `c` for c-order and `f` for an f-order matrix
        // We can handle c * c, f * f generally and
        // c * f and f * c if the `f` matrix is square.
        let mut lhs_ = lhs.view();
        let mut rhs_ = rhs.view();
        let mut c_ = c.view_mut();
        let mut lhs_strides = lhs_.strides();
        let mut rhs_strides = rhs_.strides();

        // copy if batch dims appear in last two dims.
        // Deep copies keep per-batch pointer offsets valid below.
        let mut copied_lhs = None;
        let mut copied_rhs = None;
        if batch_mat_mul_requires_copy(lhs_strides) {
            copied_lhs = Some(crate::ndarray_ext::deep_copy(&lhs_));
            lhs_ = copied_lhs.as_ref().unwrap().view();
            lhs_strides = lhs_.strides();
        }
        if batch_mat_mul_requires_copy(rhs_strides) {
            copied_rhs = Some(crate::ndarray_ext::deep_copy(&rhs_));
            rhs_ = copied_rhs.as_ref().unwrap().view();
            rhs_strides = rhs_.strides();
        }

        let lhs_s0 = lhs_strides[rank - 2];
        let rhs_s0 = rhs_strides[rank - 2];
        let both_f = lhs_s0 == 1 && rhs_s0 == 1;

        let mut lhs_trans = CblasTranspose::CblasNoTrans;
        let mut rhs_trans = CblasTranspose::CblasNoTrans;

        // Update lhs, rhs info if needed
        if both_f {
            // A^t B^t = C^t => B A = C
            // Both operands f-order: compute the transposed product with
            // swapped operands so everything becomes c-order.
            let mut lhs_t = lhs_;
            lhs_t.swap_axes(rank - 2, rank - 1);
            lhs_ = rhs_;
            lhs_.swap_axes(rank - 2, rank - 1);
            rhs_ = lhs_t;
            c_.swap_axes(rank - 2, rank - 1);
            mem::swap(&mut m, &mut n);
        } else if lhs_s0 == 1 && m == a {
            // Square f-order lhs: express it as a transposed c-order matrix.
            lhs_.swap_axes(rank - 2, rank - 1);
            lhs_trans = CblasTranspose::CblasTrans;
        } else if rhs_s0 == 1 && a == n {
            // Square f-order rhs: express it as a transposed c-order matrix.
            rhs_.swap_axes(rank - 2, rank - 1);
            rhs_trans = CblasTranspose::CblasTrans;
        }
        // All leading axes are flattened into a single batch dimension.
        let batch_size: usize = lhs_shape[..rank - 2].iter().product();

        macro_rules! call_kernel_def {
            ($ty:ty, $f:ident) => {
                // Dispatch only when every operand is a BLAS-compatible
                // row-major stack of matrices of element type $ty.
                if blas_row_major_nd::<$ty, _>(&lhs_)
                    && blas_row_major_nd::<$ty, _>(&rhs_)
                    && blas_row_major_nd_mut::<$ty, _>(&c_)
                {
                    // Dimensions of Op(lhs)/Op(rhs) as seen by gemm.
                    let (m, k) = match lhs_trans {
                        CblasTranspose::CblasNoTrans => {
                            let s = lhs_.shape();
                            (s[rank - 2], s[rank - 1])
                        },
                        _ => {
                            let s = lhs_.shape();
                            (s[rank - 1], s[rank - 2])
                        }
                    };
                    let n = match rhs_trans {
                        CblasTranspose::CblasNoTrans => rhs_.raw_dim()[rank - 1],
                        _ => rhs_.raw_dim()[rank - 2],
                    };
                    // adjust strides, these may be [1, 1] for column matrices
                    let lhs_stride = cmp::max(lhs_.strides()[rank - 2] as MklInt, k as MklInt);
                    let rhs_stride = cmp::max(rhs_.strides()[rank - 2] as MklInt, n as MklInt);
                    let c_stride = cmp::max(c_.strides()[rank - 2] as MklInt, n as MklInt);

                    // gemm_batch is C_i ← αA_i^Op B_i^Op + βC_i per batch i.
                    // The `[x; GROUP_COUNT]` array temporaries live until the
                    // end of this statement, so the pointers passed stay valid
                    // for the duration of the call.
                    unsafe {
                        const GROUP_COUNT: usize = 1;  // Fixed
                        $f(
                            CBLAS_ROW_MAJOR,
                            [lhs_trans; GROUP_COUNT].as_ptr(),
                            [rhs_trans; GROUP_COUNT].as_ptr(),
                            [m as MklInt; GROUP_COUNT].as_ptr(),
                            [n as MklInt; GROUP_COUNT].as_ptr(),
                            [k as MklInt; GROUP_COUNT].as_ptr(),
                            [cast_as(&alpha); GROUP_COUNT].as_ptr(),             // alpha
                            get_batch_ptrs(batch_size, lhs_.as_ptr(), lhs_.len()).as_ptr(), // a array
                            [lhs_stride; GROUP_COUNT].as_ptr(),
                            get_batch_ptrs(batch_size, rhs_.as_ptr(), rhs_.len()).as_ptr(), // b array
                            [rhs_stride; GROUP_COUNT].as_ptr(),
                            [cast_as(&beta); GROUP_COUNT].as_ptr(),               // beta
                            get_batch_ptrs_mut(batch_size, c_.as_mut_ptr(), c_.len()).as_mut_ptr(), // c array
                            [c_stride; GROUP_COUNT].as_ptr(),
                            GROUP_COUNT as MklInt,
                            [batch_size as MklInt; GROUP_COUNT].as_ptr()
                        );
                    }
                    return;
                }
            };
        }
        call_kernel_def!(f32, cblas_sgemm_batch);
        call_kernel_def!(f64, cblas_dgemm_batch);
    }
    // No BLAS fast path matched; fall back to the generic kernel.
    batch_mat_mul_impl_slow(alpha, lhs, rhs, beta, c)
}
361
/// C ← α A B + β C
///
/// Generic 2-D gemm via the `matrixmultiply` crate. Only f32 and f64 are
/// dispatched; for any other `F` neither kernel runs and `c` is left
/// unchanged by this function.
fn mat_mul_impl_slow<F: Float>(
    alpha: F,
    lhs: &ArrayView2<'_, F>,
    rhs: &ArrayView2<'_, F>,
    beta: F,
    c: &mut ArrayViewMut2<'_, F>,
) {
    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());
    // common parameters for gemm
    let ap = lhs.as_ptr();
    let bp = rhs.as_ptr();
    let cp = c.as_mut_ptr();
    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
    macro_rules! kernel_call_def {
        ($ty:ty, $f:ident) => {
            // Run the $ty kernel only when F is exactly $ty.
            if crate::same_type::<F, $ty>() {
                // SAFETY: the pointers and row/column strides come straight
                // from live ndarray views of the correct element type.
                unsafe {
                    ::matrixmultiply::$f(
                        m,
                        k,
                        n,
                        cast_as(&alpha),
                        ap as *const _,
                        lhs.strides()[0],
                        lhs.strides()[1],
                        bp as *const _,
                        rhs.strides()[0],
                        rhs.strides()[1],
                        cast_as(&beta),
                        cp as *mut _,
                        rsc,
                        csc,
                    );
                }
            }
        };
    }
    kernel_call_def!(f32, sgemm);
    kernel_call_def!(f64, dgemm);
}
403
404/// C ← α A B + β C
405#[allow(unused_assignments)]
406#[allow(unused)]
407fn batch_mat_mul_impl_slow<F: Float>(
408    alpha: F,
409    lhs: &NdArrayView<'_, F>,
410    rhs: &NdArrayView<'_, F>,
411    beta: F,
412    c: &mut NdArrayViewMut<'_, F>,
413) {
414    let mut lhs_ = lhs.view();
415    let mut rhs_ = rhs.view();
416    let c_ = c.view_mut();
417    let mut lhs_strides = lhs_.strides();
418    let mut rhs_strides = rhs_.strides();
419    let rank = lhs_strides.len();
420    let lhs_requires_copy = batch_mat_mul_requires_copy(lhs_strides);
421    let rhs_requires_copy = batch_mat_mul_requires_copy(rhs_strides);
422
423    let mut copied_lhs = None;
424    let mut copied_rhs = None;
425    // Update lhs, rhs info with copied ones
426    {
427        if lhs_requires_copy {
428            copied_lhs = Some(crate::ndarray_ext::deep_copy(&lhs_));
429            lhs_ = copied_lhs.as_ref().unwrap().view();
430            lhs_strides = lhs_.strides();
431        }
432        if rhs_requires_copy {
433            copied_rhs = Some(crate::ndarray_ext::deep_copy(&rhs_));
434            rhs_ = copied_rhs.as_ref().unwrap().view();
435            rhs_strides = rhs_.strides();
436        }
437    }
438
439    let lhs_shape = lhs_.shape();
440    let rhs_shape = rhs_.shape();
441    let (m, k, n) = (
442        lhs_shape[rank - 2],
443        lhs_shape[rank - 1],
444        rhs_shape[rank - 1],
445    );
446
447    // common parameters for gemm
448    let (rsa, csa) = (lhs_strides[rank - 2], lhs_strides[rank - 1]);
449    let (rsb, csb) = (rhs_strides[rank - 2], rhs_strides[rank - 1]);
450    let (rsc, csc) = {
451        let strides = c_.strides();
452        (strides[rank - 2], strides[rank - 1])
453    };
454    let num_batches: usize = lhs_shape[..rank - 2].iter().product();
455    let lhs_batch_size = lhs_.len() / num_batches;
456    let rhs_batch_size = rhs_.len() / num_batches;
457    let c_batch_size = c_.len() / num_batches;
458    let ap_init = lhs.as_ptr();
459    let bp_init = rhs.as_ptr();
460    let cp_init = c.as_mut_ptr();
461    unsafe {
462        macro_rules! kernel_call_def {
463            ($ty:ty, $f:ident) => {
464                if crate::same_type::<F, $ty>() {
465                    for batch_i in 0..num_batches {
466                        let a_pos = (lhs_batch_size * batch_i) as isize;
467                        let b_pos = (rhs_batch_size * batch_i) as isize;
468                        let c_pos = (c_batch_size * batch_i) as isize;
469                        let ap = ap_init.offset(a_pos);
470                        let bp = bp_init.offset(b_pos);
471                        let cp = cp_init.offset(c_pos);
472                        ::matrixmultiply::$f(
473                            m,
474                            k,
475                            n,
476                            cast_as(&alpha),
477                            ap as *const _,
478                            rsa,
479                            csa,
480                            bp as *const _,
481                            rsb,
482                            csb,
483                            cast_as(&beta),
484                            cp as *mut _,
485                            rsc,
486                            csc,
487                        );
488                    }
489                }
490            };
491        }
492        kernel_call_def!(f32, sgemm);
493        kernel_call_def!(f64, dgemm);
494    }
495}
496
497#[inline]
498fn batch_mat_mul_requires_copy(stride: &[ndarray::Ixs]) -> bool {
499    let rank = stride.len();
500    // unwrap is ok since stride.len() > 2
501    let min_str = *stride[0..rank - 2].iter().min().unwrap();
502    let row_str = stride[rank - 2];
503    let col_str = stride[rank - 1];
504    min_str < row_str || min_str < col_str
505}
506
/// Builds the error message for an invalid matrix-multiplication shape pair.
///
/// Reports an isize overflow when the m × n output would not be addressable;
/// otherwise reports that (m × k) and (k2 × n) are incompatible.
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> String {
    let output_addressable = matches!(
        m.checked_mul(n),
        Some(len) if len <= ::std::isize::MAX as usize
    );
    if !output_addressable {
        return format!("ndarray: shape {} × {} overflows isize", m, n);
    }
    format!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    )
}
519
520// ========= Op impls =========
521
522use ndarray::ShapeBuilder;
523
/// Matrix-multiplication op for 2-D inputs.
pub struct MatMul {
    /// Transpose the left operand before multiplying.
    pub transpose_a: bool,
    /// Transpose the right operand before multiplying.
    pub transpose_b: bool,
}
528
/// Batched matrix-multiplication op: multiplies the trailing 2-D matrices of
/// two N-D inputs whose leading (batch) dimensions match.
pub struct BatchMatMul {
    /// Transpose each left matrix before multiplying.
    pub transpose_a: bool,
    /// Transpose each right matrix before multiplying.
    pub transpose_b: bool,
}
533
impl<T: Float> op::Op<T> for MatMul {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        // Both inputs must be 2-D matrices.
        let mut a = ctx
            .input(0)
            .into_dimensionality::<ndarray::Ix2>()
            .expect("lhs input for MatMul must be 2D");
        let mut b = ctx
            .input(1)
            .into_dimensionality::<ndarray::Ix2>()
            .expect("rhs input for MatMul must be 2D");
        // Apply the requested transposes as zero-copy axis swaps.
        if self.transpose_a {
            a.swap_axes(0, 1);
        }
        if self.transpose_b {
            b.swap_axes(0, 1);
        }
        let ((m, k), (k2, n)) = (a.dim(), b.dim());
        // Inner dimensions must agree, and the m×n output must be addressable.
        if k != k2 || m.checked_mul(n).is_none() {
            ctx.set_error(op::OpError::IncompatibleShape(dot_shape_error(m, k, k2, n)));
            return;
        }

        let lhs_s0 = a.strides()[0];
        let rhs_s0 = b.strides()[0];
        // When both operands are f-order, allocate an f-order output so the
        // gemm implementations can use their f·f fast path.
        let column_major = lhs_s0 == 1 && rhs_s0 == 1;
        // A is Copy so this is safe
        // NOTE(review): the buffer is left uninitialized; this relies on the
        // gemm call below (beta == 0) writing every element of `c`.
        let mut v = Vec::with_capacity(m * n);
        let mut c;
        unsafe {
            v.set_len(m * n);
            c = ndarray::Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
        }

        // C ← 1·A·B + 0·C
        #[cfg(feature = "mkl")]
        {
            mat_mul_impl_blas(T::one(), &a, &b, T::zero(), &mut c.view_mut());
        }
        #[cfg(not(feature = "mkl"))]
        {
            mat_mul_impl_slow(T::one(), &a, &b, T::zero(), &mut c.view_mut());
        }
        ctx.append_output(c.into_dyn());
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let s = ctx.graph();
        let gy = &ctx.output_grad();
        // dL/dA = gy · B^T
        // NOTE(review): the gradient ops ignore self.transpose_a/transpose_b;
        // confirm transposed MatMul nodes are only built by grad() itself.
        let opa = Tensor::builder().set_ro_inputs(&[gy, &ctx.input(1)]).build(
            s,
            MatMul {
                transpose_a: false,
                transpose_b: true,
            },
        );

        // dL/dB = A^T · gy
        let opb = Tensor::builder().set_ro_inputs(&[&ctx.input(0), gy]).build(
            s,
            MatMul {
                transpose_a: true,
                transpose_b: false,
            },
        );

        ctx.append_input_grad(Some(opa));
        ctx.append_input_grad(Some(opb));
    }
}
601
impl<T: Float> op::Op<T> for BatchMatMul {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let mut x0 = ctx.input(0);
        let mut x1 = ctx.input(1);
        let rank0 = x0.ndim();
        let rank1 = x1.ndim();

        // Each operand needs at least a trailing 2-D matrix.
        if rank0 < 2 {
            ctx.set_error(op::OpError::IncompatibleShape(format!(
                "BatchMatMul: Left-hand-side input's ndim must be >= 2, actual: {}",
                rank0
            )));
            return;
        }
        if rank1 < 2 {
            ctx.set_error(op::OpError::IncompatibleShape(format!(
                "BatchMatMul: Right-hand-side input's ndim must be >= 2, actual: {}",
                rank1
            )));
            return;
        }

        // Apply the requested transposes to the matrix axes (zero-copy).
        if self.transpose_a {
            x0.swap_axes(rank0 - 2, rank0 - 1);
        }

        if self.transpose_b {
            x1.swap_axes(rank1 - 2, rank1 - 1);
        }

        // Ranks and all leading (batch) dimensions must match exactly.
        // NOTE(review): the inner matrix dimensions (k vs k2) are not
        // validated here — confirm the multiply impls handle a mismatch.
        let shape0 = x0.shape();
        let shape1 = x1.shape();
        if rank0 != rank1 || shape0[..rank0 - 2] != shape1[..rank0 - 2] {
            ctx.set_error(op::OpError::IncompatibleShape(format!(
                "Input shapes mismatch: {:?} vs {:?}",
                shape0, shape1
            )));
            return;
        }

        // Output shape: the shared batch dims followed by [m, n].
        let ret_shape = {
            let mut ret = shape0.to_vec();
            // (the first assignment is a no-op since `ret` starts as shape0)
            ret[rank0 - 2] = shape0[rank0 - 2];
            ret[rank0 - 1] = shape1[rank0 - 1];
            ret
        };
        // A is Copy so this is safe
        // NOTE(review): the buffer is left uninitialized; this relies on the
        // gemm call below (beta == 0) writing every element of `c`.
        let size: usize = ret_shape.iter().product();
        let mut v = Vec::with_capacity(size);
        let mut c;
        unsafe {
            v.set_len(size);
            // BatchMatMul's ret val is a c-order array.
            c = ndarray::Array::from_shape_vec_unchecked(ret_shape, v);
        }
        // C ← 1·x0·x1 + 0·C for every batch.
        #[cfg(feature = "mkl")]
        {
            batch_mat_mul_impl(T::one(), &x0, &x1, T::zero(), &mut c.view_mut());
        }
        #[cfg(not(feature = "mkl"))]
        {
            batch_mat_mul_impl_slow(T::one(), &x0, &x1, T::zero(), &mut c.view_mut())
        }

        // reshape to dst shape with safe unwrapping
        ctx.append_output(c);
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let gy = &ctx.output_grad();
        // dL/dA = gy · B^T, applied per batch.
        // NOTE(review): like MatMul, the gradient ops ignore the op's own
        // transpose flags — confirm this matches the intended usage.
        let opa = Tensor::builder().set_ro_inputs(&[gy, &ctx.input(1)]).build(
            ctx.graph(),
            BatchMatMul {
                transpose_a: false,
                transpose_b: true,
            },
        );

        // dL/dB = A^T · gy, applied per batch.
        let opb = Tensor::builder().set_ro_inputs(&[&ctx.input(0), gy]).build(
            ctx.graph(),
            BatchMatMul {
                transpose_a: true,
                transpose_b: false,
            },
        );

        ctx.append_input_grad(Some(opa));
        ctx.append_input_grad(Some(opb));
    }
}
692
/// Op that computes the permutations and 2-D shapes needed to lower a
/// `tensordot` to a single matrix multiplication.
pub struct TensordotPreprocess;
694
695#[inline]
696fn tensordot_preprocess<T: Float>(
697    shape: &[usize],
698    axes: &[usize],
699    flip: bool,
700) -> (Vec<T>, Vec<T>, Vec<T>) {
701    let free = (0..shape.len())
702        .filter(|i| !axes.contains(i))
703        .collect::<Vec<usize>>();
704    let mut free_dims = Vec::with_capacity(free.len());
705    let mut prod_free_dims = 1;
706    {
707        for &i in &free {
708            prod_free_dims *= shape[i];
709            free_dims.push(T::from(shape[i]).unwrap());
710        }
711    }
712    let prod_axes_dims = axes.iter().map(|&i| shape[i]).product::<usize>();
713
714    // make perm
715    let first = if flip { axes } else { &free };
716    let second = if flip { &free } else { axes };
717    let mut perm = Vec::with_capacity(first.len() + second.len());
718    for &a in first {
719        perm.push(T::from(a).unwrap());
720    }
721    for &a in second {
722        perm.push(T::from(a).unwrap());
723    }
724
725    // make new shape
726    let new_shape = if flip {
727        vec![
728            T::from(prod_axes_dims).unwrap(),
729            T::from(prod_free_dims).unwrap(),
730        ]
731    } else {
732        vec![
733            T::from(prod_free_dims).unwrap(),
734            T::from(prod_axes_dims).unwrap(),
735        ]
736    };
737
738    (perm, new_shape, free_dims)
739}
740
impl<T: Float> op::Op<T> for TensordotPreprocess {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        // Inputs 0/1 are the tensors; inputs 2/3 hold their contraction axes
        // (possibly negative, normalized against each tensor's rank).
        let x0 = ctx.input(0);
        let x1 = &ctx.input(1);
        let axes0 = crate::ndarray_ext::normalize_negative_axes(&ctx.input(2), x0.ndim());
        let axes1 = crate::ndarray_ext::normalize_negative_axes(&ctx.input(3), x1.ndim());

        // lhs keeps its free axes first; rhs leads with contracted axes
        // (flip = true) so the two reshaped matrices line up for matmul.
        let (perm0, new_shape0, mut free_dims0) = tensordot_preprocess(x0.shape(), &axes0, false);
        let (perm1, new_shape1, free_dims1) = tensordot_preprocess(x1.shape(), &axes1, true);
        free_dims0.extend(free_dims1);

        // Emit five 1-D arrays: the combined free dims, both permutations,
        // and both 2-D matmul shapes.
        let r0 = NdArray::from_shape_vec(ndarray::IxDyn(&[free_dims0.len()]), free_dims0).unwrap();
        let r1 = NdArray::from_shape_vec(ndarray::IxDyn(&[perm0.len()]), perm0).unwrap();
        let r2 = NdArray::from_shape_vec(ndarray::IxDyn(&[perm1.len()]), perm1).unwrap();
        let r3 = NdArray::from_shape_vec(ndarray::IxDyn(&[new_shape0.len()]), new_shape0).unwrap();
        let r4 = NdArray::from_shape_vec(ndarray::IxDyn(&[new_shape1.len()]), new_shape1).unwrap();

        ctx.append_output(r0);
        ctx.append_output(r1);
        ctx.append_output(r2);
        ctx.append_output(r3);
        ctx.append_output(r4);
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        // Pure shape bookkeeping — no gradient flows through any input.
        ctx.append_input_grad(None);
        ctx.append_input_grad(None);
        ctx.append_input_grad(None);
        ctx.append_input_grad(None);
    }
}
771}