use crate::ndarray_ext::NdArray;
#[cfg(feature = "mkl")]
use crate::ndarray_ext::{get_batch_ptrs, get_batch_ptrs_mut};
#[cfg(feature = "mkl")]
use crate::ops::mkl_ffi::*;
use crate::same_type;
use crate::tensor::Tensor;
use crate::Float;
use crate::NdArrayView;
use crate::{op, NdArrayViewMut};
use ndarray;
#[cfg(feature = "mkl")]
use ndarray::Dimension;
use ndarray::{ArrayView2, ArrayViewMut2};
#[cfg(feature = "mkl")]
use std::cmp;
#[cfg(feature = "mkl")]
use std::mem;

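// The `blas_row_major_*` predicates below decide whether a view can be handed
// directly to a CBLAS/MKL gemm kernel: the element type must match the
// kernel's scalar type, the stride of the last axis must be 1 (or that axis
// must have length 1), and every dimension and stride must fit in `MklInt`.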
#[cfg(feature = "mkl")]
#[inline]
fn blas_row_major_2d<T: 'static, F>(a: &ndarray::ArrayView2<F>) -> bool
where
    F: Float,
{
    if !same_type::<F, T>() {
        return false;
    }
    is_blas_2d(&a.raw_dim(), a.strides(), MemoryOrder::C)
}

#[cfg(feature = "mkl")]
#[inline]
fn blas_row_major_nd<T: 'static, F>(a: &NdArrayView<F>) -> bool
where
    F: Float,
{
    if !same_type::<F, T>() {
        return false;
    }
    let strides = a.strides();
    let rank = strides.len();
    is_blas_nd(
        a.shape(),
        strides[rank - 2],
        strides[rank - 1],
        MemoryOrder::C,
    )
}

#[cfg(feature = "mkl")]
#[inline]
fn blas_row_major_2d_mut<T: 'static, F>(a: &ndarray::ArrayViewMut2<F>) -> bool
where
    F: Float,
{
    if !same_type::<F, T>() {
        return false;
    }
    is_blas_2d(&a.raw_dim(), a.strides(), MemoryOrder::C)
}

#[cfg(feature = "mkl")]
#[inline]
fn blas_row_major_nd_mut<T: 'static, F>(a: &NdArrayViewMut<F>) -> bool
where
    F: Float,
{
    if !same_type::<F, T>() {
        return false;
    }
    let strides = a.strides();
    let rank = strides.len();
    is_blas_nd(
        a.shape(),
        strides[rank - 2],
        strides[rank - 1],
        MemoryOrder::C,
    )
}

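// The two `is_blas_*` checks validate the matrix axes of a view against the
// stride and integer-width constraints of the MKL gemm interface.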
#[cfg(feature = "mkl")]
fn is_blas_nd(shape: &[usize], stride0: isize, stride1: isize, order: MemoryOrder) -> bool {
    // `stride0`/`stride1` are the strides of the last two axes, so the matrix
    // dimensions must also be read from the last two axes (for rank > 2,
    // `shape[0]`/`shape[1]` would be batch dimensions).
    let rank = shape.len();
    let (m, n) = (shape[rank - 2], shape[rank - 1]);
    let (inner_stride, outer_dim) = match order {
        MemoryOrder::C => (stride1, n),
        MemoryOrder::F => (stride0, m),
    };
    if !(inner_stride == 1 || outer_dim == 1) {
        return false;
    }
    if stride0 < 1 || stride1 < 1 {
        return false;
    }
    if (stride0 > MklInt::max_value() as isize || stride0 < MklInt::min_value() as isize)
        || (stride1 > MklInt::max_value() as isize || stride1 < MklInt::min_value() as isize)
    {
        return false;
    }
    if m > MklInt::max_value() as usize || n > MklInt::max_value() as usize {
        return false;
    }
    true
}

#[cfg(feature = "mkl")]
fn is_blas_2d(dim: &ndarray::Ix2, stride: &[isize], order: MemoryOrder) -> bool {
    let (m, n) = dim.into_pattern();
    let s0 = stride[0];
    let s1 = stride[1];
    let (inner_stride, outer_dim) = match order {
        MemoryOrder::C => (s1, n),
        MemoryOrder::F => (s0, m),
    };
    if !(inner_stride == 1 || outer_dim == 1) {
        return false;
    }
    if s0 < 1 || s1 < 1 {
        return false;
    }
    if (s0 > MklInt::max_value() as isize || s0 < MklInt::min_value() as isize)
        || (s1 > MklInt::max_value() as isize || s1 < MklInt::min_value() as isize)
    {
        return false;
    }
    if m > MklInt::max_value() as usize || n > MklInt::max_value() as usize {
        return false;
    }
    true
}

// Bit-for-bit reinterpretation of `*a` as a `B`. The assert guarantees the
// two types are actually identical, so the unsafe read is sound.
#[inline]
fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
    assert!(same_type::<A, B>());
    unsafe { ::std::ptr::read(a as *const _ as *const B) }
}

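// 2-D matmul via MKL `cblas_{s,d}gemm`. Tiny problems (all dims <= 7) and
// non-f32/f64 element types skip the FFI and use the pure-Rust kernel.
// When both operands are column-major, Cᵀ = Bᵀ·Aᵀ is computed instead;
// a single column-major operand is handled by passing `CblasTrans` rather
// than materializing a transposed copy.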
#[cfg(feature = "mkl")]
fn mat_mul_impl_blas<F: Float>(
    alpha: F,
    lhs: &ArrayView2<'_, F>,
    rhs: &ArrayView2<'_, F>,
    beta: F,
    c: &mut ArrayViewMut2<'_, F>,
) {
    const GEMM_BLAS_CUTOFF: usize = 7;

    let cut = GEMM_BLAS_CUTOFF;
    let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
    if !(m > cut || n > cut || a > cut) || !(same_type::<F, f32>() || same_type::<F, f64>()) {
        return mat_mul_impl_slow(alpha, lhs, rhs, beta, c);
    }
    {
        let mut lhs_ = lhs.view();
        let mut rhs_ = rhs.view();
        let mut c_ = c.view_mut();
        let lhs_s0 = lhs_.strides()[0];
        let rhs_s0 = rhs_.strides()[0];
        let both_f = lhs_s0 == 1 && rhs_s0 == 1;
        let mut lhs_trans = CblasTranspose::CblasNoTrans;
        let mut rhs_trans = CblasTranspose::CblasNoTrans;
        if both_f {
            // Both operands are column-major: compute Cᵀ = Bᵀ·Aᵀ instead.
            let lhs_t = lhs_.reversed_axes();
            lhs_ = rhs_.reversed_axes();
            rhs_ = lhs_t;
            c_ = c_.reversed_axes();
            mem::swap(&mut m, &mut n);
        } else if lhs_s0 == 1 && m == a {
            lhs_ = lhs_.reversed_axes();
            lhs_trans = CblasTranspose::CblasTrans;
        } else if rhs_s0 == 1 && a == n {
            rhs_ = rhs_.reversed_axes();
            rhs_trans = CblasTranspose::CblasTrans;
        }

        macro_rules! call_kernel_def {
            ($ty:ty, $f:ident) => {
                if blas_row_major_2d::<$ty, _>(&lhs_)
                    && blas_row_major_2d::<$ty, _>(&rhs_)
                    && blas_row_major_2d_mut::<$ty, _>(&c_)
                {
                    let (m, k) = match lhs_trans {
                        CblasTranspose::CblasNoTrans => lhs_.dim(),
                        _ => {
                            let (rows, cols) = lhs_.dim();
                            (cols, rows)
                        }
                    };
                    let n = match rhs_trans {
                        CblasTranspose::CblasNoTrans => rhs_.raw_dim()[1],
                        _ => rhs_.raw_dim()[0],
                    };
                    let lhs_stride = cmp::max(lhs_.strides()[0] as MklInt, k as MklInt);
                    let rhs_stride = cmp::max(rhs_.strides()[0] as MklInt, n as MklInt);
                    let c_stride = cmp::max(c_.strides()[0] as MklInt, n as MklInt);

                    unsafe {
                        $f(
                            CBLAS_ROW_MAJOR,
                            lhs_trans,
                            rhs_trans,
                            m as MklInt,
                            n as MklInt,
                            k as MklInt,
                            cast_as(&alpha),
                            lhs_.as_ptr() as *const _,
                            lhs_stride,
                            rhs_.as_ptr() as *const _,
                            rhs_stride,
                            cast_as(&beta),
                            c_.as_mut_ptr() as *mut _,
                            c_stride,
                        );
                    }
                    return;
                }
            };
        }
        call_kernel_def!(f32, cblas_sgemm);
        call_kernel_def!(f64, cblas_dgemm);
    }
    mat_mul_impl_slow(alpha, lhs, rhs, beta, c)
}

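// Batched matmul via MKL `cblas_{s,d}gemm_batch`. All matrices share a single
// (trans, m, n, k, alpha, beta, stride) group, so `GROUP_COUNT == 1` and the
// group size equals the number of batches. Operands whose batch axes are laid
// out incompatibly (see `batch_mat_mul_requires_copy`) are deep-copied first.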
#[allow(unused_assignments)]
#[cfg(feature = "mkl")]
fn batch_mat_mul_impl<F: Float>(
    alpha: F,
    lhs: &NdArrayView<'_, F>,
    rhs: &NdArrayView<'_, F>,
    beta: F,
    c: &mut NdArrayViewMut<'_, F>,
) {
    let lhs_shape = lhs.shape();
    let rhs_shape = rhs.shape();
    let rank = lhs.ndim();
    let (mut m, a, mut n) = (
        lhs_shape[rank - 2],
        lhs_shape[rank - 1],
        rhs_shape[rank - 1],
    );

    {
        let mut lhs_ = lhs.view();
        let mut rhs_ = rhs.view();
        let mut c_ = c.view_mut();
        let mut lhs_strides = lhs_.strides();
        let mut rhs_strides = rhs_.strides();

        let mut copied_lhs = None;
        let mut copied_rhs = None;
        if batch_mat_mul_requires_copy(lhs_strides) {
            copied_lhs = Some(crate::ndarray_ext::deep_copy(&lhs_));
            lhs_ = copied_lhs.as_ref().unwrap().view();
            lhs_strides = lhs_.strides();
        }
        if batch_mat_mul_requires_copy(rhs_strides) {
            copied_rhs = Some(crate::ndarray_ext::deep_copy(&rhs_));
            rhs_ = copied_rhs.as_ref().unwrap().view();
            rhs_strides = rhs_.strides();
        }

        let lhs_s0 = lhs_strides[rank - 2];
        let rhs_s0 = rhs_strides[rank - 2];
        let both_f = lhs_s0 == 1 && rhs_s0 == 1;

        let mut lhs_trans = CblasTranspose::CblasNoTrans;
        let mut rhs_trans = CblasTranspose::CblasNoTrans;

        if both_f {
            // Both operands column-major in their matrix axes: Cᵀ = Bᵀ·Aᵀ.
            let mut lhs_t = lhs_;
            lhs_t.swap_axes(rank - 2, rank - 1);
            lhs_ = rhs_;
            lhs_.swap_axes(rank - 2, rank - 1);
            rhs_ = lhs_t;
            c_.swap_axes(rank - 2, rank - 1);
            mem::swap(&mut m, &mut n);
        } else if lhs_s0 == 1 && m == a {
            lhs_.swap_axes(rank - 2, rank - 1);
            lhs_trans = CblasTranspose::CblasTrans;
        } else if rhs_s0 == 1 && a == n {
            rhs_.swap_axes(rank - 2, rank - 1);
            rhs_trans = CblasTranspose::CblasTrans;
        }
        let batch_size: usize = lhs_shape[..rank - 2].iter().product();

        macro_rules! call_kernel_def {
            ($ty:ty, $f:ident) => {
                if blas_row_major_nd::<$ty, _>(&lhs_)
                    && blas_row_major_nd::<$ty, _>(&rhs_)
                    && blas_row_major_nd_mut::<$ty, _>(&c_)
                {
                    let (m, k) = match lhs_trans {
                        CblasTranspose::CblasNoTrans => {
                            let s = lhs_.shape();
                            (s[rank - 2], s[rank - 1])
                        }
                        _ => {
                            let s = lhs_.shape();
                            (s[rank - 1], s[rank - 2])
                        }
                    };
                    let n = match rhs_trans {
                        CblasTranspose::CblasNoTrans => rhs_.raw_dim()[rank - 1],
                        _ => rhs_.raw_dim()[rank - 2],
                    };
                    let lhs_stride = cmp::max(lhs_.strides()[rank - 2] as MklInt, k as MklInt);
                    let rhs_stride = cmp::max(rhs_.strides()[rank - 2] as MklInt, n as MklInt);
                    let c_stride = cmp::max(c_.strides()[rank - 2] as MklInt, n as MklInt);

                    unsafe {
                        const GROUP_COUNT: usize = 1;
                        $f(
                            CBLAS_ROW_MAJOR,
                            [lhs_trans; GROUP_COUNT].as_ptr(),
                            [rhs_trans; GROUP_COUNT].as_ptr(),
                            [m as MklInt; GROUP_COUNT].as_ptr(),
                            [n as MklInt; GROUP_COUNT].as_ptr(),
                            [k as MklInt; GROUP_COUNT].as_ptr(),
                            [cast_as(&alpha); GROUP_COUNT].as_ptr(),
                            get_batch_ptrs(batch_size, lhs_.as_ptr(), lhs_.len()).as_ptr(),
                            [lhs_stride; GROUP_COUNT].as_ptr(),
                            get_batch_ptrs(batch_size, rhs_.as_ptr(), rhs_.len()).as_ptr(),
                            [rhs_stride; GROUP_COUNT].as_ptr(),
                            [cast_as(&beta); GROUP_COUNT].as_ptr(),
                            get_batch_ptrs_mut(batch_size, c_.as_mut_ptr(), c_.len()).as_mut_ptr(),
                            [c_stride; GROUP_COUNT].as_ptr(),
                            GROUP_COUNT as MklInt,
                            [batch_size as MklInt; GROUP_COUNT].as_ptr(),
                        );
                    }
                    return;
                }
            };
        }
        call_kernel_def!(f32, cblas_sgemm_batch);
        call_kernel_def!(f64, cblas_dgemm_batch);
    }
    batch_mat_mul_impl_slow(alpha, lhs, rhs, beta, c)
}

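// Fallback 2-D kernel using the pure-Rust `matrixmultiply` crate; it accepts
// arbitrary row/column strides, so no layout normalization is needed here.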
fn mat_mul_impl_slow<F: Float>(
    alpha: F,
    lhs: &ArrayView2<'_, F>,
    rhs: &ArrayView2<'_, F>,
    beta: F,
    c: &mut ArrayViewMut2<'_, F>,
) {
    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());
    let ap = lhs.as_ptr();
    let bp = rhs.as_ptr();
    let cp = c.as_mut_ptr();
    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
    macro_rules! kernel_call_def {
        ($ty:ty, $f:ident) => {
            if crate::same_type::<F, $ty>() {
                unsafe {
                    ::matrixmultiply::$f(
                        m,
                        k,
                        n,
                        cast_as(&alpha),
                        ap as *const _,
                        lhs.strides()[0],
                        lhs.strides()[1],
                        bp as *const _,
                        rhs.strides()[0],
                        rhs.strides()[1],
                        cast_as(&beta),
                        cp as *mut _,
                        rsc,
                        csc,
                    );
                }
            }
        };
    }
    kernel_call_def!(f32, sgemm);
    kernel_call_def!(f64, dgemm);
}

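// Batched fallback: one `matrixmultiply` call per batch, with operands
// deep-copied first when their batch axes would make constant per-batch
// offsets invalid (see `batch_mat_mul_requires_copy`).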
#[allow(unused_assignments)]
#[allow(unused)]
fn batch_mat_mul_impl_slow<F: Float>(
    alpha: F,
    lhs: &NdArrayView<'_, F>,
    rhs: &NdArrayView<'_, F>,
    beta: F,
    c: &mut NdArrayViewMut<'_, F>,
) {
    let mut lhs_ = lhs.view();
    let mut rhs_ = rhs.view();
    let c_ = c.view_mut();
    let mut lhs_strides = lhs_.strides();
    let mut rhs_strides = rhs_.strides();
    let rank = lhs_strides.len();
    let lhs_requires_copy = batch_mat_mul_requires_copy(lhs_strides);
    let rhs_requires_copy = batch_mat_mul_requires_copy(rhs_strides);

    let mut copied_lhs = None;
    let mut copied_rhs = None;
    {
        if lhs_requires_copy {
            copied_lhs = Some(crate::ndarray_ext::deep_copy(&lhs_));
            lhs_ = copied_lhs.as_ref().unwrap().view();
            lhs_strides = lhs_.strides();
        }
        if rhs_requires_copy {
            copied_rhs = Some(crate::ndarray_ext::deep_copy(&rhs_));
            rhs_ = copied_rhs.as_ref().unwrap().view();
            rhs_strides = rhs_.strides();
        }
    }

    let lhs_shape = lhs_.shape();
    let rhs_shape = rhs_.shape();
    let (m, k, n) = (
        lhs_shape[rank - 2],
        lhs_shape[rank - 1],
        rhs_shape[rank - 1],
    );

    let (rsa, csa) = (lhs_strides[rank - 2], lhs_strides[rank - 1]);
    let (rsb, csb) = (rhs_strides[rank - 2], rhs_strides[rank - 1]);
    let (rsc, csc) = {
        let strides = c_.strides();
        (strides[rank - 2], strides[rank - 1])
    };
    let num_batches: usize = lhs_shape[..rank - 2].iter().product();
    let lhs_batch_size = lhs_.len() / num_batches;
    let rhs_batch_size = rhs_.len() / num_batches;
    let c_batch_size = c_.len() / num_batches;
    // Base pointers must come from the (possibly deep-copied) views so that
    // the strides computed above describe the memory actually referenced.
    let ap_init = lhs_.as_ptr();
    let bp_init = rhs_.as_ptr();
    let cp_init = c.as_mut_ptr();
    unsafe {
        macro_rules! kernel_call_def {
            ($ty:ty, $f:ident) => {
                if crate::same_type::<F, $ty>() {
                    for batch_i in 0..num_batches {
                        let a_pos = (lhs_batch_size * batch_i) as isize;
                        let b_pos = (rhs_batch_size * batch_i) as isize;
                        let c_pos = (c_batch_size * batch_i) as isize;
                        let ap = ap_init.offset(a_pos);
                        let bp = bp_init.offset(b_pos);
                        let cp = cp_init.offset(c_pos);
                        ::matrixmultiply::$f(
                            m,
                            k,
                            n,
                            cast_as(&alpha),
                            ap as *const _,
                            rsa,
                            csa,
                            bp as *const _,
                            rsb,
                            csb,
                            cast_as(&beta),
                            cp as *mut _,
                            rsc,
                            csc,
                        );
                    }
                }
            };
        }
        kernel_call_def!(f32, sgemm);
        kernel_call_def!(f64, dgemm);
    }
}

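// Batches can be addressed by a constant element offset only when every batch
// stride is at least as large as the matrix strides; otherwise the operand is
// deep-copied into standard layout before the batched gemm.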
#[inline]
fn batch_mat_mul_requires_copy(stride: &[ndarray::Ixs]) -> bool {
    let rank = stride.len();
    if rank < 3 {
        // A 2-D input has no batch axes, so there is nothing to copy for
        // (and `min()` below would panic on an empty slice).
        return false;
    }
    let min_str = *stride[0..rank - 2].iter().min().unwrap();
    let row_str = stride[rank - 2];
    let col_str = stride[rank - 1];
    min_str < row_str || min_str < col_str
}

fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> String {
    match m.checked_mul(n) {
        Some(len) if len <= ::std::isize::MAX as usize => {}
        _ => {
            return format!("ndarray: shape {} × {} overflows isize", m, n);
        }
    }
    format!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    )
}

use ndarray::ShapeBuilder;

pub struct MatMul {
    pub transpose_a: bool,
    pub transpose_b: bool,
}

pub struct BatchMatMul {
    pub transpose_a: bool,
    pub transpose_b: bool,
}

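// For C = A·B with upstream gradient G = ∂L/∂C, the input gradients are
// ∂L/∂A = G·Bᵀ and ∂L/∂B = Aᵀ·G; the `grad` implementations below build
// exactly these two products via the op's transpose flags.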
impl<T: Float> op::Op<T> for MatMul {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let mut a = ctx
            .input(0)
            .into_dimensionality::<ndarray::Ix2>()
            .expect("lhs input for MatMul must be 2D");
        let mut b = ctx
            .input(1)
            .into_dimensionality::<ndarray::Ix2>()
            .expect("rhs input for MatMul must be 2D");
        if self.transpose_a {
            a.swap_axes(0, 1);
        }
        if self.transpose_b {
            b.swap_axes(0, 1);
        }
        let ((m, k), (k2, n)) = (a.dim(), b.dim());
        if k != k2 || m.checked_mul(n).is_none() {
            ctx.set_error(op::OpError::IncompatibleShape(dot_shape_error(m, k, k2, n)));
            return;
        }

        let lhs_s0 = a.strides()[0];
        let rhs_s0 = b.strides()[0];
        let column_major = lhs_s0 == 1 && rhs_s0 == 1;
        // The output buffer is left uninitialized here; it is fully
        // overwritten by the gemm below because beta == 0.
        let mut v = Vec::with_capacity(m * n);
        let mut c;
        unsafe {
            v.set_len(m * n);
            c = ndarray::Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
        }

        #[cfg(feature = "mkl")]
        {
            mat_mul_impl_blas(T::one(), &a, &b, T::zero(), &mut c.view_mut());
        }
        #[cfg(not(feature = "mkl"))]
        {
            mat_mul_impl_slow(T::one(), &a, &b, T::zero(), &mut c.view_mut());
        }
        ctx.append_output(c.into_dyn());
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let s = ctx.graph();
        let gy = &ctx.output_grad();
        let opa = Tensor::builder().set_ro_inputs(&[gy, &ctx.input(1)]).build(
            s,
            MatMul {
                transpose_a: false,
                transpose_b: true,
            },
        );

        let opb = Tensor::builder().set_ro_inputs(&[&ctx.input(0), gy]).build(
            s,
            MatMul {
                transpose_a: true,
                transpose_b: false,
            },
        );

        ctx.append_input_grad(Some(opa));
        ctx.append_input_grad(Some(opb));
    }
}

impl<T: Float> op::Op<T> for BatchMatMul {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let mut x0 = ctx.input(0);
        let mut x1 = ctx.input(1);
        let rank0 = x0.ndim();
        let rank1 = x1.ndim();

        if rank0 < 2 {
            ctx.set_error(op::OpError::IncompatibleShape(format!(
                "BatchMatMul: Left-hand-side input's ndim must be >= 2, actual: {}",
                rank0
            )));
            return;
        }
        if rank1 < 2 {
            ctx.set_error(op::OpError::IncompatibleShape(format!(
                "BatchMatMul: Right-hand-side input's ndim must be >= 2, actual: {}",
                rank1
            )));
            return;
        }

        if self.transpose_a {
            x0.swap_axes(rank0 - 2, rank0 - 1);
        }

        if self.transpose_b {
            x1.swap_axes(rank1 - 2, rank1 - 1);
        }

        let shape0 = x0.shape();
        let shape1 = x1.shape();
        if rank0 != rank1 || shape0[..rank0 - 2] != shape1[..rank0 - 2] {
            ctx.set_error(op::OpError::IncompatibleShape(format!(
                "Input shapes mismatch: {:?} vs {:?}",
                shape0, shape1
            )));
            return;
        }

        // Result shape: x0's batch dims and rows, then x1's columns.
        let ret_shape = {
            let mut ret = shape0.to_vec();
            ret[rank0 - 1] = shape1[rank0 - 1];
            ret
        };
        // As in MatMul, the buffer is uninitialized but fully overwritten
        // by the gemm below (beta == 0).
        let size: usize = ret_shape.iter().product();
        let mut v = Vec::with_capacity(size);
        let mut c;
        unsafe {
            v.set_len(size);
            c = ndarray::Array::from_shape_vec_unchecked(ret_shape, v);
        }
        #[cfg(feature = "mkl")]
        {
            batch_mat_mul_impl(T::one(), &x0, &x1, T::zero(), &mut c.view_mut());
        }
        #[cfg(not(feature = "mkl"))]
        {
            batch_mat_mul_impl_slow(T::one(), &x0, &x1, T::zero(), &mut c.view_mut());
        }

        ctx.append_output(c);
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let gy = &ctx.output_grad();
        let opa = Tensor::builder().set_ro_inputs(&[gy, &ctx.input(1)]).build(
            ctx.graph(),
            BatchMatMul {
                transpose_a: false,
                transpose_b: true,
            },
        );

        let opb = Tensor::builder().set_ro_inputs(&[&ctx.input(0), gy]).build(
            ctx.graph(),
            BatchMatMul {
                transpose_a: true,
                transpose_b: false,
            },
        );

        ctx.append_input_grad(Some(opa));
        ctx.append_input_grad(Some(opb));
    }
}

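// `TensordotPreprocess` computes, for each operand, the axis permutation that
// moves the contracted axes to one side and the 2-D shape to reshape into,
// plus the free dimensions of the final result. A tensordot can then be
// evaluated as permute → reshape → matmul → reshape; the five outputs of
// `compute` feed that pipeline.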
pub struct TensordotPreprocess;

#[inline]
fn tensordot_preprocess<T: Float>(
    shape: &[usize],
    axes: &[usize],
    flip: bool,
) -> (Vec<T>, Vec<T>, Vec<T>) {
    let free = (0..shape.len())
        .filter(|i| !axes.contains(i))
        .collect::<Vec<usize>>();
    let mut free_dims = Vec::with_capacity(free.len());
    let mut prod_free_dims = 1;
    for &i in &free {
        prod_free_dims *= shape[i];
        free_dims.push(T::from(shape[i]).unwrap());
    }
    let prod_axes_dims = axes.iter().map(|&i| shape[i]).product::<usize>();

    // Contracted axes come first for the right-hand operand (`flip == true`),
    // last for the left-hand operand.
    let first: &[usize] = if flip { axes } else { &free };
    let second: &[usize] = if flip { &free } else { axes };
    let mut perm = Vec::with_capacity(first.len() + second.len());
    for &a in first {
        perm.push(T::from(a).unwrap());
    }
    for &a in second {
        perm.push(T::from(a).unwrap());
    }

    let new_shape = if flip {
        vec![
            T::from(prod_axes_dims).unwrap(),
            T::from(prod_free_dims).unwrap(),
        ]
    } else {
        vec![
            T::from(prod_free_dims).unwrap(),
            T::from(prod_axes_dims).unwrap(),
        ]
    };

    (perm, new_shape, free_dims)
}

impl<T: Float> op::Op<T> for TensordotPreprocess {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let x0 = ctx.input(0);
        let x1 = &ctx.input(1);
        let axes0 = crate::ndarray_ext::normalize_negative_axes(&ctx.input(2), x0.ndim());
        let axes1 = crate::ndarray_ext::normalize_negative_axes(&ctx.input(3), x1.ndim());

        let (perm0, new_shape0, mut free_dims0) = tensordot_preprocess(x0.shape(), &axes0, false);
        let (perm1, new_shape1, free_dims1) = tensordot_preprocess(x1.shape(), &axes1, true);
        free_dims0.extend(free_dims1);

        // Outputs: final free dims, the two permutations, the two 2-D shapes.
        let r0 = NdArray::from_shape_vec(ndarray::IxDyn(&[free_dims0.len()]), free_dims0).unwrap();
        let r1 = NdArray::from_shape_vec(ndarray::IxDyn(&[perm0.len()]), perm0).unwrap();
        let r2 = NdArray::from_shape_vec(ndarray::IxDyn(&[perm1.len()]), perm1).unwrap();
        let r3 = NdArray::from_shape_vec(ndarray::IxDyn(&[new_shape0.len()]), new_shape0).unwrap();
        let r4 = NdArray::from_shape_vec(ndarray::IxDyn(&[new_shape1.len()]), new_shape1).unwrap();

        ctx.append_output(r0);
        ctx.append_output(r1);
        ctx.append_output(r2);
        ctx.append_output(r3);
        ctx.append_output(r4);
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        // Shape/permutation metadata is not differentiable.
        ctx.append_input_grad(None);
        ctx.append_input_grad(None);
        ctx.append_input_grad(None);
        ctx.append_input_grad(None);
    }
}
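
// A minimal sanity-check sketch for the pure-Rust kernel. This hypothetical
// test module assumes in-crate compilation (`mat_mul_impl_slow` is private)
// and uses arbitrary 2x2 values; it is a sketch, not an exhaustive test.
#[cfg(test)]
mod matmul_sketch_tests {
    use super::*;

    #[test]
    fn mat_mul_impl_slow_2x2() {
        let a = ndarray::arr2(&[[1.0f32, 2.0], [3.0, 4.0]]);
        let b = ndarray::arr2(&[[5.0f32, 6.0], [7.0, 8.0]]);
        let mut c = ndarray::Array2::<f32>::zeros((2, 2));
        // alpha = 1, beta = 0: c is fully overwritten with a.dot(&b).
        mat_mul_impl_slow(1.0, &a.view(), &b.view(), 0.0, &mut c.view_mut());
        assert_eq!(c, ndarray::arr2(&[[19.0, 22.0], [43.0, 50.0]]));
    }
}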