1use crate::numeric::{Numeric, NumericRef};
51use crate::tensors::indexing::{TensorAccess, TensorReferenceIterator};
52use crate::tensors::views::{TensorIndex, TensorMut, TensorRef, TensorView};
53use crate::tensors::{Dimension, Tensor};
54
55use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
56
57#[inline]
59pub(crate) fn tensor_equality<T, S1, S2, const D: usize>(left: &S1, right: &S2) -> bool
60where
61 T: PartialEq,
62 S1: TensorRef<T, D>,
63 S2: TensorRef<T, D>,
64{
65 left.view_shape() == right.view_shape()
66 && TensorReferenceIterator::from(left)
67 .zip(TensorReferenceIterator::from(right))
68 .all(|(x, y)| x == y)
69}
70
71#[inline]
75pub(crate) fn tensor_similarity<T, S1, S2, const D: usize>(left: &S1, right: &S2) -> bool
76where
77 T: PartialEq,
78 S1: TensorRef<T, D>,
79 S2: TensorRef<T, D>,
80{
81 use crate::tensors::dimensions::names_of;
82 let left_shape = left.view_shape();
83 let access_order = names_of(&left_shape);
84 let left_access = TensorAccess::from_source_order(left);
85 let right_access = match TensorAccess::try_from(right, access_order) {
86 Ok(right_access) => right_access,
87 Err(_) => return false,
88 };
89 if left_shape != right_access.shape() {
91 return false;
92 }
93 left_access
94 .iter_reference()
95 .zip(right_access.iter_reference())
96 .all(|(x, y)| x == y)
97}
98
99impl<T: PartialEq, const D: usize> PartialEq for Tensor<T, D> {
127 fn eq(&self, other: &Self) -> bool {
128 tensor_equality(self, other)
129 }
130}
131
132impl<T, S1, S2, const D: usize> PartialEq<TensorView<T, S2, D>> for TensorView<T, S1, D>
137where
138 T: PartialEq,
139 S1: TensorRef<T, D>,
140 S2: TensorRef<T, D>,
141{
142 fn eq(&self, other: &TensorView<T, S2, D>) -> bool {
143 tensor_equality(self.source_ref(), other.source_ref())
144 }
145}
146
147impl<T, S, const D: usize> PartialEq<TensorView<T, S, D>> for Tensor<T, D>
152where
153 T: PartialEq,
154 S: TensorRef<T, D>,
155{
156 fn eq(&self, other: &TensorView<T, S, D>) -> bool {
157 tensor_equality(self, other.source_ref())
158 }
159}
160
161impl<T, S, const D: usize> PartialEq<Tensor<T, D>> for TensorView<T, S, D>
166where
167 T: PartialEq,
168 S: TensorRef<T, D>,
169{
170 fn eq(&self, other: &Tensor<T, D>) -> bool {
171 tensor_equality(self.source_ref(), other)
172 }
173}
174
175pub trait Similar<Rhs: ?Sized = Self>: private::Sealed {
183 #[must_use]
188 fn similar(&self, other: &Rhs) -> bool;
189}
190
191mod private {
192 use crate::tensors::Tensor;
193 use crate::tensors::views::{TensorRef, TensorView};
194
195 pub trait Sealed<Rhs: ?Sized = Self> {}
196
197 impl<T: PartialEq, const D: usize> Sealed for Tensor<T, D> {}
198 impl<T, S1, S2, const D: usize> Sealed<TensorView<T, S2, D>> for TensorView<T, S1, D>
199 where
200 T: PartialEq,
201 S1: TensorRef<T, D>,
202 S2: TensorRef<T, D>,
203 {
204 }
205 impl<T, S, const D: usize> Sealed<TensorView<T, S, D>> for Tensor<T, D>
206 where
207 T: PartialEq,
208 S: TensorRef<T, D>,
209 {
210 }
211 impl<T, S, const D: usize> Sealed<Tensor<T, D>> for TensorView<T, S, D>
212 where
213 T: PartialEq,
214 S: TensorRef<T, D>,
215 {
216 }
217}
218
219impl<T: PartialEq, const D: usize> Similar for Tensor<T, D> {
258 fn similar(&self, other: &Tensor<T, D>) -> bool {
259 tensor_similarity(self, other)
260 }
261}
262
263impl<T, S1, S2, const D: usize> Similar<TensorView<T, S2, D>> for TensorView<T, S1, D>
270where
271 T: PartialEq,
272 S1: TensorRef<T, D>,
273 S2: TensorRef<T, D>,
274{
275 fn similar(&self, other: &TensorView<T, S2, D>) -> bool {
276 tensor_similarity(self.source_ref(), other.source_ref())
277 }
278}
279
280impl<T, S, const D: usize> Similar<TensorView<T, S, D>> for Tensor<T, D>
287where
288 T: PartialEq,
289 S: TensorRef<T, D>,
290{
291 fn similar(&self, other: &TensorView<T, S, D>) -> bool {
292 tensor_similarity(self, other.source_ref())
293 }
294}
295
296impl<T, S, const D: usize> Similar<Tensor<T, D>> for TensorView<T, S, D>
303where
304 T: PartialEq,
305 S: TensorRef<T, D>,
306{
307 fn similar(&self, other: &Tensor<T, D>) -> bool {
308 tensor_similarity(self.source_ref(), other)
309 }
310}
311
312#[track_caller]
313#[inline]
314fn assert_same_dimensions<const D: usize>(
315 left_shape: [(Dimension, usize); D],
316 right_shape: [(Dimension, usize); D],
317) {
318 if left_shape != right_shape {
319 panic!(
320 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
321 left_shape, right_shape
322 );
323 }
324}
325
326#[track_caller]
327#[inline]
328fn tensor_view_addition_iter<'l, 'r, T, S1, S2, const D: usize>(
329 left_iter: S1,
330 left_shape: [(Dimension, usize); D],
331 right_iter: S2,
332 right_shape: [(Dimension, usize); D],
333) -> Tensor<T, D>
334where
335 T: Numeric,
336 T: 'l,
337 T: 'r,
338 for<'a> &'a T: NumericRef<T>,
339 S1: Iterator<Item = &'l T>,
340 S2: Iterator<Item = &'r T>,
341{
342 assert_same_dimensions(left_shape, right_shape);
343 Tensor::from(
345 left_shape,
346 left_iter.zip(right_iter).map(|(x, y)| x + y).collect(),
347 )
348}
349
350#[track_caller]
351#[inline]
352fn tensor_view_subtraction_iter<'l, 'r, T, S1, S2, const D: usize>(
353 left_iter: S1,
354 left_shape: [(Dimension, usize); D],
355 right_iter: S2,
356 right_shape: [(Dimension, usize); D],
357) -> Tensor<T, D>
358where
359 T: Numeric,
360 T: 'l,
361 T: 'r,
362 for<'a> &'a T: NumericRef<T>,
363 S1: Iterator<Item = &'l T>,
364 S2: Iterator<Item = &'r T>,
365{
366 assert_same_dimensions(left_shape, right_shape);
367 Tensor::from(
369 left_shape,
370 left_iter.zip(right_iter).map(|(x, y)| x - y).collect(),
371 )
372}
373
374#[track_caller]
375#[inline]
376fn tensor_view_assign_addition_iter<'l, 'r, T, S1, S2, const D: usize>(
377 left_iter: S1,
378 left_shape: [(Dimension, usize); D],
379 right_iter: S2,
380 right_shape: [(Dimension, usize); D],
381) where
382 T: Numeric,
383 T: 'l,
384 T: 'r,
385 for<'a> &'a T: NumericRef<T>,
386 S1: Iterator<Item = &'l mut T>,
387 S2: Iterator<Item = &'r T>,
388{
389 assert_same_dimensions(left_shape, right_shape);
390 for (x, y) in left_iter.zip(right_iter) {
391 *x = x.clone() + y;
395 }
396}
397
398#[track_caller]
399#[inline]
400fn tensor_view_assign_subtraction_iter<'l, 'r, T, S1, S2, const D: usize>(
401 left_iter: S1,
402 left_shape: [(Dimension, usize); D],
403 right_iter: S2,
404 right_shape: [(Dimension, usize); D],
405) where
406 T: Numeric,
407 T: 'l,
408 T: 'r,
409 for<'a> &'a T: NumericRef<T>,
410 S1: Iterator<Item = &'l mut T>,
411 S2: Iterator<Item = &'r T>,
412{
413 assert_same_dimensions(left_shape, right_shape);
414 for (x, y) in left_iter.zip(right_iter) {
415 *x = x.clone() - y;
419 }
420}
421
422#[inline]
429pub(crate) fn scalar_product<'l, 'r, T, S1, S2>(left_iter: S1, right_iter: S2) -> T
430where
431 T: Numeric,
432 T: 'l,
433 T: 'r,
434 for<'a> &'a T: NumericRef<T>,
435 S1: Iterator<Item = &'l T>,
436 S2: Iterator<Item = &'r T>,
437{
438 left_iter
441 .zip(right_iter)
442 .map(|(x, y)| x * y)
443 .reduce(|x, y| x + y)
444 .unwrap() }
446
447#[track_caller]
448#[inline]
449fn tensor_view_vector_product_iter<'l, 'r, T, S1, S2>(
450 left_iter: S1,
451 left_shape: [(Dimension, usize); 1],
452 right_iter: S2,
453 right_shape: [(Dimension, usize); 1],
454) -> T
455where
456 T: Numeric,
457 T: 'l,
458 T: 'r,
459 for<'a> &'a T: NumericRef<T>,
460 S1: Iterator<Item = &'l T>,
461 S2: Iterator<Item = &'r T>,
462{
463 if left_shape[0].1 != right_shape[0].1 {
464 panic!(
465 "Dimension lengths of left and right tensors are not the same: (left: {:?}, right: {:?})",
466 left_shape, right_shape
467 );
468 }
469 scalar_product::<T, S1, S2>(left_iter, right_iter)
471}
472
473#[track_caller]
474#[inline]
475fn tensor_view_matrix_product<T, S1, S2>(left: S1, right: S2) -> Tensor<T, 2>
476where
477 T: Numeric,
478 for<'a> &'a T: NumericRef<T>,
479 S1: TensorRef<T, 2>,
480 S2: TensorRef<T, 2>,
481{
482 let left_shape = left.view_shape();
483 let right_shape = right.view_shape();
484 if left_shape[1].1 != right_shape[0].1 {
485 panic!(
486 "Mismatched tensors, left is {:?}, right is {:?}, * is only defined for MxN * NxL dimension lengths",
487 left.view_shape(),
488 right.view_shape()
489 );
490 }
491 if left_shape[0].0 == right_shape[1].0 {
492 panic!(
493 "Matrix multiplication of tensors with shapes left {:?} and right {:?} would \
494 create duplicate dimension names as the shape {:?}. Rename one or both of the \
495 dimension names in the input to prevent this. * is defined as MxN * NxL = MxL",
496 left_shape,
497 right_shape,
498 [left_shape[0], right_shape[1]]
499 )
500 }
501 let mut tensor = Tensor::empty([left.view_shape()[0], right.view_shape()[1]], T::zero());
506 for ([i, j], x) in tensor.iter_reference_mut().with_index() {
507 let left = TensorIndex::from(&left, [(left.view_shape()[0].0, i)]);
509 let right = TensorIndex::from(&right, [(right.view_shape()[1].0, j)]);
511 *x = scalar_product::<T, _, _>(
513 TensorReferenceIterator::from(&left),
514 TensorReferenceIterator::from(&right),
515 )
516 }
517 tensor
518}
519
#[test]
fn test_matrix_product() {
    // A 2x3 matrix times a 3x2 matrix gives a 2x2 result with dimensions ("r", "c").
    #[rustfmt::skip]
    let left = Tensor::from([("r", 2), ("c", 3)], vec![
        1, 2, 3,
        4, 5, 6,
    ]);
    #[rustfmt::skip]
    let right = Tensor::from([("r", 3), ("c", 2)], vec![
        10, 11,
        12, 13,
        14, 15,
    ]);
    #[rustfmt::skip]
    let expected = Tensor::from([("r", 2), ("c", 2)], vec![
        1 * 10 + 2 * 12 + 3 * 14, 1 * 11 + 2 * 13 + 3 * 15,
        4 * 10 + 5 * 12 + 6 * 14, 4 * 11 + 5 * 13 + 6 * 15,
    ]);
    let product = tensor_view_matrix_product::<i32, _, _>(left, right);
    assert_eq!(product, expected);
}
546
547macro_rules! tensor_view_reference_tensor_view_reference_operation_iter {
551 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
552 #[doc=$doc]
553 impl<T, S1, S2, const D: usize> $op<&TensorView<T, S2, D>> for &TensorView<T, S1, D>
554 where
555 T: Numeric,
556 for<'a> &'a T: NumericRef<T>,
557 S1: TensorRef<T, D>,
558 S2: TensorRef<T, D>,
559 {
560 type Output = Tensor<T, D>;
561
562 #[track_caller]
563 #[inline]
564 fn $method(self, rhs: &TensorView<T, S2, D>) -> Self::Output {
565 $implementation::<T, _, _, D>(
566 self.iter_reference(),
567 self.shape(),
568 rhs.iter_reference(),
569 rhs.shape(),
570 )
571 }
572 }
573 };
574}
575
576macro_rules! tensor_view_reference_tensor_view_reference_operation {
577 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
578 #[doc=$doc]
579 impl<T, S1, S2> $op<&TensorView<T, S2, $d>> for &TensorView<T, S1, $d>
580 where
581 T: Numeric,
582 for<'a> &'a T: NumericRef<T>,
583 S1: TensorRef<T, $d>,
584 S2: TensorRef<T, $d>,
585 {
586 type Output = Tensor<T, $d>;
587
588 #[track_caller]
589 #[inline]
590 fn $method(self, rhs: &TensorView<T, S2, $d>) -> Self::Output {
591 $implementation::<T, _, _>(self.source_ref(), rhs.source_ref())
592 }
593 }
594 };
595}
596
597tensor_view_reference_tensor_view_reference_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for two referenced tensor views");
598tensor_view_reference_tensor_view_reference_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two referenced tensor views");
599tensor_view_reference_tensor_view_reference_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication of two referenced 2-dimensional tensors");
600
601macro_rules! tensor_view_assign_tensor_view_reference_operation_iter {
602 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
603 #[doc=$doc]
604 impl<T, S1, S2, const D: usize> $op<&TensorView<T, S2, D>> for TensorView<T, S1, D>
605 where
606 T: Numeric,
607 for<'a> &'a T: NumericRef<T>,
608 S1: TensorMut<T, D>,
609 S2: TensorRef<T, D>,
610 {
611 #[track_caller]
612 #[inline]
613 fn $method(&mut self, rhs: &TensorView<T, S2, D>) {
614 let left_shape = self.shape();
615 $implementation::<T, _, _, D>(
616 self.iter_reference_mut(),
617 left_shape,
618 rhs.iter_reference(),
619 rhs.shape(),
620 )
621 }
622 }
623 };
624}
625
626tensor_view_assign_tensor_view_reference_operation_iter!(impl AddAssign for TensorView { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for two referenced tensor views");
627tensor_view_assign_tensor_view_reference_operation_iter!(impl SubAssign for TensorView { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for two referenced tensor views");
628
629macro_rules! tensor_view_reference_tensor_view_value_operation_iter {
630 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
631 #[doc=$doc]
632 impl<T, S1, S2, const D: usize> $op<TensorView<T, S2, D>> for &TensorView<T, S1, D>
633 where
634 T: Numeric,
635 for<'a> &'a T: NumericRef<T>,
636 S1: TensorRef<T, D>,
637 S2: TensorRef<T, D>,
638 {
639 type Output = Tensor<T, D>;
640
641 #[track_caller]
642 #[inline]
643 fn $method(self, rhs: TensorView<T, S2, D>) -> Self::Output {
644 $implementation::<T, _, _, D>(
645 self.iter_reference(),
646 self.shape(),
647 rhs.iter_reference(),
648 rhs.shape(),
649 )
650 }
651 }
652 };
653}
654
655macro_rules! tensor_view_reference_tensor_view_value_operation {
656 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
657 #[doc=$doc]
658 impl<T, S1, S2> $op<TensorView<T, S2, $d>> for &TensorView<T, S1, $d>
659 where
660 T: Numeric,
661 for<'a> &'a T: NumericRef<T>,
662 S1: TensorRef<T, $d>,
663 S2: TensorRef<T, $d>,
664 {
665 type Output = Tensor<T, $d>;
666
667 #[track_caller]
668 #[inline]
669 fn $method(self, rhs: TensorView<T, S2, $d>) -> Self::Output {
670 $implementation::<T, _, _>(self.source_ref(), rhs.source_ref())
671 }
672 }
673 };
674}
675
676tensor_view_reference_tensor_view_value_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for two tensor views with one referenced");
677tensor_view_reference_tensor_view_value_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensor views with one referenced");
678tensor_view_reference_tensor_view_value_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication of two 2-dimensional tensors with one referenced");
679
680macro_rules! tensor_view_assign_tensor_view_value_operation_iter {
681 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
682 #[doc=$doc]
683 impl<T, S1, S2, const D: usize> $op<TensorView<T, S2, D>> for TensorView<T, S1, D>
684 where
685 T: Numeric,
686 for<'a> &'a T: NumericRef<T>,
687 S1: TensorMut<T, D>,
688 S2: TensorRef<T, D>,
689 {
690 #[track_caller]
691 #[inline]
692 fn $method(&mut self, rhs: TensorView<T, S2, D>) {
693 let left_shape = self.shape();
694 $implementation::<T, _, _, D>(
695 self.iter_reference_mut(),
696 left_shape,
697 rhs.iter_reference(),
698 rhs.shape(),
699 )
700 }
701 }
702 };
703}
704
705tensor_view_assign_tensor_view_value_operation_iter!(impl AddAssign for TensorView { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for two tensor views with one referenced");
706tensor_view_assign_tensor_view_value_operation_iter!(impl SubAssign for TensorView { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for two tensor views with one referenced");
707
708macro_rules! tensor_view_value_tensor_view_reference_operation_iter {
709 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
710 #[doc=$doc]
711 impl<T, S1, S2, const D: usize> $op<&TensorView<T, S2, D>> for TensorView<T, S1, D>
712 where
713 T: Numeric,
714 for<'a> &'a T: NumericRef<T>,
715 S1: TensorRef<T, D>,
716 S2: TensorRef<T, D>,
717 {
718 type Output = Tensor<T, D>;
719
720 #[track_caller]
721 #[inline]
722 fn $method(self, rhs: &TensorView<T, S2, D>) -> Self::Output {
723 $implementation::<T, _, _, D>(
724 self.iter_reference(),
725 self.shape(),
726 rhs.iter_reference(),
727 rhs.shape(),
728 )
729 }
730 }
731 };
732}
733
734macro_rules! tensor_view_value_tensor_view_reference_operation {
735 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
736 #[doc=$doc]
737 impl<T, S1, S2> $op<&TensorView<T, S2, $d>> for TensorView<T, S1, $d>
738 where
739 T: Numeric,
740 for<'a> &'a T: NumericRef<T>,
741 S1: TensorRef<T, $d>,
742 S2: TensorRef<T, $d>,
743 {
744 type Output = Tensor<T, $d>;
745
746 #[track_caller]
747 #[inline]
748 fn $method(self, rhs: &TensorView<T, S2, $d>) -> Self::Output {
749 $implementation::<T, _, _>(self.source_ref(), rhs.source_ref())
750 }
751 }
752 };
753}
754
755tensor_view_value_tensor_view_reference_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for two tensor views with one referenced");
756tensor_view_value_tensor_view_reference_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensor views with one referenced");
757tensor_view_value_tensor_view_reference_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication of two 2-dimensional tensors with one referenced");
758
759macro_rules! tensor_view_value_tensor_view_value_operation_iter {
760 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
761 #[doc=$doc]
762 impl<T, S1, S2, const D: usize> $op<TensorView<T, S2, D>> for TensorView<T, S1, D>
763 where
764 T: Numeric,
765 for<'a> &'a T: NumericRef<T>,
766 S1: TensorRef<T, D>,
767 S2: TensorRef<T, D>,
768 {
769 type Output = Tensor<T, D>;
770
771 #[track_caller]
772 #[inline]
773 fn $method(self, rhs: TensorView<T, S2, D>) -> Self::Output {
774 $implementation::<T, _, _, D>(
775 self.iter_reference(),
776 self.shape(),
777 rhs.iter_reference(),
778 rhs.shape(),
779 )
780 }
781 }
782 };
783}
784
785macro_rules! tensor_view_value_tensor_view_value_operation {
786 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
787 #[doc=$doc]
788 impl<T, S1, S2> $op<TensorView<T, S2, $d>> for TensorView<T, S1, $d>
789 where
790 T: Numeric,
791 for<'a> &'a T: NumericRef<T>,
792 S1: TensorRef<T, $d>,
793 S2: TensorRef<T, $d>,
794 {
795 type Output = Tensor<T, $d>;
796
797 #[track_caller]
798 #[inline]
799 fn $method(self, rhs: TensorView<T, S2, $d>) -> Self::Output {
800 $implementation::<T, _, _>(self.source_ref(), rhs.source_ref())
801 }
802 }
803 };
804}
805
806tensor_view_value_tensor_view_value_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for two tensor views");
807tensor_view_value_tensor_view_value_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensor views");
808tensor_view_value_tensor_view_value_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication of two 2-dimensional tensors");
809
810macro_rules! tensor_view_reference_tensor_reference_operation_iter {
811 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
812 #[doc=$doc]
813 impl<T, S, const D: usize> $op<&Tensor<T, D>> for &TensorView<T, S, D>
814 where
815 T: Numeric,
816 for<'a> &'a T: NumericRef<T>,
817 S: TensorRef<T, D>,
818 {
819 type Output = Tensor<T, D>;
820
821 #[track_caller]
822 #[inline]
823 fn $method(self, rhs: &Tensor<T, D>) -> Self::Output {
824 $implementation::<T, _, _, D>(
825 self.iter_reference(),
826 self.shape(),
827 rhs.direct_iter_reference(),
828 rhs.shape(),
829 )
830 }
831 }
832 };
833}
834
835macro_rules! tensor_view_reference_tensor_reference_operation {
836 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
837 #[doc=$doc]
838 impl<T, S> $op<&Tensor<T, $d>> for &TensorView<T, S, $d>
839 where
840 T: Numeric,
841 for<'a> &'a T: NumericRef<T>,
842 S: TensorRef<T, $d>,
843 {
844 type Output = Tensor<T, $d>;
845
846 #[track_caller]
847 #[inline]
848 fn $method(self, rhs: &Tensor<T, $d>) -> Self::Output {
849 $implementation::<T, _, _>(self.source_ref(), rhs)
850 }
851 }
852 };
853}
854
855tensor_view_reference_tensor_reference_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for a referenced tensor view and a referenced tensor");
856tensor_view_reference_tensor_reference_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a referenced tensor view and a referenced tensor");
857tensor_view_reference_tensor_reference_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional referenced tensor view and a referenced tensor");
858
859macro_rules! tensor_view_assign_tensor_reference_operation_iter {
860 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
861 #[doc=$doc]
862 impl<T, S, const D: usize> $op<&Tensor<T, D>> for TensorView<T, S, D>
863 where
864 T: Numeric,
865 for<'a> &'a T: NumericRef<T>,
866 S: TensorMut<T, D>,
867 {
868 #[track_caller]
869 #[inline]
870 fn $method(&mut self, rhs: &Tensor<T, D>) {
871 let left_shape = self.shape();
872 $implementation::<T, _, _, D>(
873 self.iter_reference_mut(),
874 left_shape,
875 rhs.direct_iter_reference(),
876 rhs.shape(),
877 )
878 }
879 }
880 };
881}
882
883tensor_view_assign_tensor_reference_operation_iter!(impl AddAssign for TensorView { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for a referenced tensor view and a referenced tensor");
884tensor_view_assign_tensor_reference_operation_iter!(impl SubAssign for TensorView { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for a referenced tensor view and a referenced tensor");
885
886macro_rules! tensor_view_reference_tensor_value_operation_iter {
887 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
888 #[doc=$doc]
889 impl<T, S, const D: usize> $op<Tensor<T, D>> for &TensorView<T, S, D>
890 where
891 T: Numeric,
892 for<'a> &'a T: NumericRef<T>,
893 S: TensorRef<T, D>,
894 {
895 type Output = Tensor<T, D>;
896
897 #[track_caller]
898 #[inline]
899 fn $method(self, rhs: Tensor<T, D>) -> Self::Output {
900 $implementation::<T, _, _, D>(
901 self.iter_reference(),
902 self.shape(),
903 rhs.direct_iter_reference(),
904 rhs.shape(),
905 )
906 }
907 }
908 };
909}
910
911macro_rules! tensor_view_reference_tensor_value_operation {
912 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
913 #[doc=$doc]
914 impl<T, S> $op<Tensor<T, $d>> for &TensorView<T, S, $d>
915 where
916 T: Numeric,
917 for<'a> &'a T: NumericRef<T>,
918 S: TensorRef<T, $d>,
919 {
920 type Output = Tensor<T, $d>;
921
922 #[track_caller]
923 #[inline]
924 fn $method(self, rhs: Tensor<T, $d>) -> Self::Output {
925 $implementation::<T, _, _>(self.source_ref(), rhs)
926 }
927 }
928 };
929}
930
931tensor_view_reference_tensor_value_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for a referenced tensor view and a tensor");
932tensor_view_reference_tensor_value_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a referenced tensor view and a tensor");
933tensor_view_reference_tensor_value_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional referenced tensor view and a tensor");
934
935macro_rules! tensor_view_assign_tensor_value_operation_iter {
936 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
937 #[doc=$doc]
938 impl<T, S, const D: usize> $op<Tensor<T, D>> for TensorView<T, S, D>
939 where
940 T: Numeric,
941 for<'a> &'a T: NumericRef<T>,
942 S: TensorMut<T, D>,
943 {
944 #[track_caller]
945 #[inline]
946 fn $method(&mut self, rhs: Tensor<T, D>) {
947 let left_shape = self.shape();
948 $implementation::<T, _, _, D>(
949 self.iter_reference_mut(),
950 left_shape,
951 rhs.direct_iter_reference(),
952 rhs.shape(),
953 )
954 }
955 }
956 };
957}
958
959tensor_view_assign_tensor_value_operation_iter!(impl AddAssign for TensorView { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for a referenced tensor view and a tensor");
960tensor_view_assign_tensor_value_operation_iter!(impl SubAssign for TensorView { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for a referenced tensor view and a tensor");
961
962macro_rules! tensor_view_value_tensor_reference_operation_iter {
963 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
964 #[doc=$doc]
965 impl<T, S, const D: usize> $op<&Tensor<T, D>> for TensorView<T, S, D>
966 where
967 T: Numeric,
968 for<'a> &'a T: NumericRef<T>,
969 S: TensorRef<T, D>,
970 {
971 type Output = Tensor<T, D>;
972
973 #[track_caller]
974 #[inline]
975 fn $method(self, rhs: &Tensor<T, D>) -> Self::Output {
976 $implementation::<T, _, _, D>(
977 self.iter_reference(),
978 self.shape(),
979 rhs.direct_iter_reference(),
980 rhs.shape(),
981 )
982 }
983 }
984 };
985}
986
987macro_rules! tensor_view_value_tensor_reference_operation {
988 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
989 #[doc=$doc]
990 impl<T, S> $op<&Tensor<T, $d>> for TensorView<T, S, $d>
991 where
992 T: Numeric,
993 for<'a> &'a T: NumericRef<T>,
994 S: TensorRef<T, $d>,
995 {
996 type Output = Tensor<T, $d>;
997
998 #[track_caller]
999 #[inline]
1000 fn $method(self, rhs: &Tensor<T, $d>) -> Self::Output {
1001 $implementation::<T, _, _>(self.source_ref(), rhs)
1002 }
1003 }
1004 };
1005}
1006
1007tensor_view_value_tensor_reference_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for a tensor view and a referenced tensor");
1008tensor_view_value_tensor_reference_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a tensor view and a referenced tensor");
1009tensor_view_value_tensor_reference_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional tensor view and a referenced tensor");
1010
1011macro_rules! tensor_view_value_tensor_value_operation_iter {
1012 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
1013 #[doc=$doc]
1014 impl<T, S, const D: usize> $op<Tensor<T, D>> for TensorView<T, S, D>
1015 where
1016 T: Numeric,
1017 for<'a> &'a T: NumericRef<T>,
1018 S: TensorRef<T, D>,
1019 {
1020 type Output = Tensor<T, D>;
1021
1022 #[track_caller]
1023 #[inline]
1024 fn $method(self, rhs: Tensor<T, D>) -> Self::Output {
1025 $implementation::<T, _, _, D>(
1026 self.iter_reference(),
1027 self.shape(),
1028 rhs.direct_iter_reference(),
1029 rhs.shape(),
1030 )
1031 }
1032 }
1033 };
1034}
1035
1036macro_rules! tensor_view_value_tensor_value_operation {
1037 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1038 #[doc=$doc]
1039 impl<T, S> $op<Tensor<T, $d>> for TensorView<T, S, $d>
1040 where
1041 T: Numeric,
1042 for<'a> &'a T: NumericRef<T>,
1043 S: TensorRef<T, $d>,
1044 {
1045 type Output = Tensor<T, $d>;
1046
1047 #[track_caller]
1048 #[inline]
1049 fn $method(self, rhs: Tensor<T, $d>) -> Self::Output {
1050 $implementation::<T, _, _>(self.source_ref(), rhs)
1051 }
1052 }
1053 };
1054}
1055
1056tensor_view_value_tensor_value_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for a tensor view and a tensor");
1057tensor_view_value_tensor_value_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a tensor view and a tensor");
1058tensor_view_value_tensor_value_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional tensor view and a tensor");
1059
1060macro_rules! tensor_reference_tensor_view_reference_operation_iter {
1061 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1062 #[doc=$doc]
1063 impl<T, S, const D: usize> $op<&TensorView<T, S, D>> for &Tensor<T, D>
1064 where
1065 T: Numeric,
1066 for<'a> &'a T: NumericRef<T>,
1067 S: TensorRef<T, D>,
1068 {
1069 type Output = Tensor<T, D>;
1070
1071 #[track_caller]
1072 #[inline]
1073 fn $method(self, rhs: &TensorView<T, S, D>) -> Self::Output {
1074 $implementation::<T, _, _, D>(
1075 self.direct_iter_reference(),
1076 self.shape(),
1077 rhs.iter_reference(),
1078 rhs.shape(),
1079 )
1080 }
1081 }
1082 };
1083}
1084
1085macro_rules! tensor_reference_tensor_view_reference_operation {
1086 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1087 #[doc=$doc]
1088 impl<T, S> $op<&TensorView<T, S, $d>> for &Tensor<T, $d>
1089 where
1090 T: Numeric,
1091 for<'a> &'a T: NumericRef<T>,
1092 S: TensorRef<T, $d>,
1093 {
1094 type Output = Tensor<T, $d>;
1095
1096 #[track_caller]
1097 #[inline]
1098 fn $method(self, rhs: &TensorView<T, S, $d>) -> Self::Output {
1099 $implementation::<T, _, _>(self, rhs.source_ref())
1100 }
1101 }
1102 };
1103}
1104
1105tensor_reference_tensor_view_reference_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for a referenced tensor and a referenced tensor view");
1106tensor_reference_tensor_view_reference_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a referenced tensor and a referenced tensor view");
1107tensor_reference_tensor_view_reference_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional referenced tensor and a referenced tensor view");
1108
1109macro_rules! tensor_assign_tensor_view_reference_operation_iter {
1110 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1111 #[doc=$doc]
1112 impl<T, S, const D: usize> $op<&TensorView<T, S, D>> for Tensor<T, D>
1113 where
1114 T: Numeric,
1115 for<'a> &'a T: NumericRef<T>,
1116 S: TensorRef<T, D>,
1117 {
1118 #[track_caller]
1119 #[inline]
1120 fn $method(&mut self, rhs: &TensorView<T, S, D>) {
1121 let left_shape = self.shape();
1122 $implementation::<T, _, _, D>(
1123 self.direct_iter_reference_mut(),
1124 left_shape,
1125 rhs.iter_reference(),
1126 rhs.shape(),
1127 )
1128 }
1129 }
1130 };
1131}
1132
1133tensor_assign_tensor_view_reference_operation_iter!(impl AddAssign for Tensor { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for a referenced tensor and a referenced tensor view");
1134tensor_assign_tensor_view_reference_operation_iter!(impl SubAssign for Tensor { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for a referenced tensor and a referenced tensor view");
1135
1136macro_rules! tensor_reference_tensor_view_value_operation_iter {
1137 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1138 #[doc=$doc]
1139 impl<T, S, const D: usize> $op<TensorView<T, S, D>> for &Tensor<T, D>
1140 where
1141 T: Numeric,
1142 for<'a> &'a T: NumericRef<T>,
1143 S: TensorRef<T, D>,
1144 {
1145 type Output = Tensor<T, D>;
1146
1147 #[track_caller]
1148 #[inline]
1149 fn $method(self, rhs: TensorView<T, S, D>) -> Self::Output {
1150 $implementation::<T, _, _, D>(
1151 self.direct_iter_reference(),
1152 self.shape(),
1153 rhs.iter_reference(),
1154 rhs.shape(),
1155 )
1156 }
1157 }
1158 };
1159}
1160
1161macro_rules! tensor_reference_tensor_view_value_operation {
1162 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1163 #[doc=$doc]
1164 impl<T, S> $op<TensorView<T, S, $d>> for &Tensor<T, $d>
1165 where
1166 T: Numeric,
1167 for<'a> &'a T: NumericRef<T>,
1168 S: TensorRef<T, $d>,
1169 {
1170 type Output = Tensor<T, $d>;
1171
1172 #[track_caller]
1173 #[inline]
1174 fn $method(self, rhs: TensorView<T, S, $d>) -> Self::Output {
1175 $implementation::<T, _, _>(self, rhs.source_ref())
1176 }
1177 }
1178 };
1179}
1180
1181tensor_reference_tensor_view_value_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for a referenced tensor and a tensor view");
1182tensor_reference_tensor_view_value_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a referenced tensor and a tensor view");
1183tensor_reference_tensor_view_value_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional referenced tensor and a tensor view");
1184
1185macro_rules! tensor_assign_tensor_view_value_operation_iter {
1186 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1187 #[doc=$doc]
1188 impl<T, S, const D: usize> $op<TensorView<T, S, D>> for Tensor<T, D>
1189 where
1190 T: Numeric,
1191 for<'a> &'a T: NumericRef<T>,
1192 S: TensorRef<T, D>,
1193 {
1194 #[track_caller]
1195 #[inline]
1196 fn $method(&mut self, rhs: TensorView<T, S, D>) {
1197 let left_shape = self.shape();
1198 $implementation::<T, _, _, D>(
1199 self.direct_iter_reference_mut(),
1200 left_shape,
1201 rhs.iter_reference(),
1202 rhs.shape(),
1203 )
1204 }
1205 }
1206 };
1207}
1208
1209tensor_assign_tensor_view_value_operation_iter!(impl AddAssign for Tensor { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for a referenced tensor and a tensor view");
1210tensor_assign_tensor_view_value_operation_iter!(impl SubAssign for Tensor { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for a referenced tensor and a tensor view");
1211
1212macro_rules! tensor_value_tensor_view_reference_operation_iter {
1213 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1214 #[doc=$doc]
1215 impl<T, S, const D: usize> $op<&TensorView<T, S, D>> for Tensor<T, D>
1216 where
1217 T: Numeric,
1218 for<'a> &'a T: NumericRef<T>,
1219 S: TensorRef<T, D>,
1220 {
1221 type Output = Tensor<T, D>;
1222
1223 #[track_caller]
1224 #[inline]
1225 fn $method(self, rhs: &TensorView<T, S, D>) -> Self::Output {
1226 $implementation::<T, _, _, D>(
1227 self.direct_iter_reference(),
1228 self.shape(),
1229 rhs.iter_reference(),
1230 rhs.shape(),
1231 )
1232 }
1233 }
1234 };
1235}
1236
1237macro_rules! tensor_value_tensor_view_reference_operation {
1238 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1239 #[doc=$doc]
1240 impl<T, S> $op<&TensorView<T, S, $d>> for Tensor<T, $d>
1241 where
1242 T: Numeric,
1243 for<'a> &'a T: NumericRef<T>,
1244 S: TensorRef<T, $d>,
1245 {
1246 type Output = Tensor<T, $d>;
1247
1248 #[track_caller]
1249 #[inline]
1250 fn $method(self, rhs: &TensorView<T, S, $d>) -> Self::Output {
1251 $implementation::<T, _, _>(self, rhs.source_ref())
1252 }
1253 }
1254 };
1255}
1256
1257tensor_value_tensor_view_reference_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for a tensor and a referenced tensor view");
1258tensor_value_tensor_view_reference_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a tensor and a referenced tensor view");
1259tensor_value_tensor_view_reference_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional tensor and a referenced tensor view");
1260
1261macro_rules! tensor_value_tensor_view_value_operation_iter {
1262 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1263 #[doc=$doc]
1264 impl<T, S, const D: usize> $op<TensorView<T, S, D>> for Tensor<T, D>
1265 where
1266 T: Numeric,
1267 for<'a> &'a T: NumericRef<T>,
1268 S: TensorRef<T, D>,
1269 {
1270 type Output = Tensor<T, D>;
1271
1272 #[track_caller]
1273 #[inline]
1274 fn $method(self, rhs: TensorView<T, S, D>) -> Self::Output {
1275 $implementation::<T, _, _, D>(
1276 self.direct_iter_reference(),
1277 self.shape(),
1278 rhs.iter_reference(),
1279 rhs.shape(),
1280 )
1281 }
1282 }
1283 };
1284}
1285
1286macro_rules! tensor_value_tensor_view_value_operation {
1287 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1288 #[doc=$doc]
1289 impl<T, S> $op<TensorView<T, S, $d>> for Tensor<T, $d>
1290 where
1291 T: Numeric,
1292 for<'a> &'a T: NumericRef<T>,
1293 S: TensorRef<T, $d>,
1294 {
1295 type Output = Tensor<T, $d>;
1296
1297 #[track_caller]
1298 #[inline]
1299 fn $method(self, rhs: TensorView<T, S, $d>) -> Self::Output {
1300 $implementation::<T, _, _>(self, rhs.source_ref())
1301 }
1302 }
1303 };
1304}
1305
1306tensor_value_tensor_view_value_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for a tensor and a tensor view");
1307tensor_value_tensor_view_value_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a tensor and a tensor view");
1308tensor_value_tensor_view_value_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional tensor and a tensor view");
1309
1310macro_rules! tensor_reference_tensor_reference_operation_iter {
1311 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1312 #[doc=$doc]
1313 impl<T, const D: usize> $op<&Tensor<T, D>> for &Tensor<T, D>
1314 where
1315 T: Numeric,
1316 for<'a> &'a T: NumericRef<T>,
1317 {
1318 type Output = Tensor<T, D>;
1319
1320 #[track_caller]
1321 #[inline]
1322 fn $method(self, rhs: &Tensor<T, D>) -> Self::Output {
1323 $implementation::<T, _, _, D>(
1324 self.direct_iter_reference(),
1325 self.shape(),
1326 rhs.direct_iter_reference(),
1327 rhs.shape(),
1328 )
1329 }
1330 }
1331 };
1332}
1333
1334macro_rules! tensor_reference_tensor_reference_operation {
1335 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1336 #[doc=$doc]
1337 impl<T> $op<&Tensor<T, $d>> for &Tensor<T, $d>
1338 where
1339 T: Numeric,
1340 for<'a> &'a T: NumericRef<T>,
1341 {
1342 type Output = Tensor<T, $d>;
1343
1344 #[track_caller]
1345 #[inline]
1346 fn $method(self, rhs: &Tensor<T, $d>) -> Self::Output {
1347 $implementation::<T, _, _>(self, rhs)
1348 }
1349 }
1350 };
1351}
1352
1353tensor_reference_tensor_reference_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for two referenced tensors");
1354tensor_reference_tensor_reference_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two referenced tensors");
1355tensor_reference_tensor_reference_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for two 2-dimensional referenced tensors");
1356
1357macro_rules! tensor_assign_tensor_reference_operation_iter {
1358 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1359 #[doc=$doc]
1360 impl<T, const D: usize> $op<&Tensor<T, D>> for Tensor<T, D>
1361 where
1362 T: Numeric,
1363 for<'a> &'a T: NumericRef<T>,
1364 {
1365 #[track_caller]
1366 #[inline]
1367 fn $method(&mut self, rhs: &Tensor<T, D>) {
1368 let left_shape = self.shape();
1369 $implementation::<T, _, _, D>(
1370 self.direct_iter_reference_mut(),
1371 left_shape,
1372 rhs.direct_iter_reference(),
1373 rhs.shape(),
1374 )
1375 }
1376 }
1377 };
1378}
1379
1380tensor_assign_tensor_reference_operation_iter!(impl AddAssign for Tensor { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for two referenced tensors");
1381tensor_assign_tensor_reference_operation_iter!(impl SubAssign for Tensor { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for two referenced tensors");
1382
1383macro_rules! tensor_reference_tensor_value_operation_iter {
1384 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1385 #[doc=$doc]
1386 impl<T, const D: usize> $op<Tensor<T, D>> for &Tensor<T, D>
1387 where
1388 T: Numeric,
1389 for<'a> &'a T: NumericRef<T>,
1390 {
1391 type Output = Tensor<T, D>;
1392
1393 #[track_caller]
1394 #[inline]
1395 fn $method(self, rhs: Tensor<T, D>) -> Self::Output {
1396 $implementation::<T, _, _, D>(
1397 self.direct_iter_reference(),
1398 self.shape(),
1399 rhs.direct_iter_reference(),
1400 rhs.shape(),
1401 )
1402 }
1403 }
1404 };
1405}
1406
1407macro_rules! tensor_reference_tensor_value_operation {
1408 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1409 #[doc=$doc]
1410 impl<T> $op<Tensor<T, $d>> for &Tensor<T, $d>
1411 where
1412 T: Numeric,
1413 for<'a> &'a T: NumericRef<T>,
1414 {
1415 type Output = Tensor<T, $d>;
1416
1417 #[track_caller]
1418 #[inline]
1419 fn $method(self, rhs: Tensor<T, $d>) -> Self::Output {
1420 $implementation::<T, _, _>(self, rhs)
1421 }
1422 }
1423 };
1424}
1425
1426tensor_reference_tensor_value_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for two tensors with one referenced");
1427tensor_reference_tensor_value_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensors with one referenced");
1428tensor_reference_tensor_value_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for two 2-dimensional tensors with one referenced");
1429
1430macro_rules! tensor_assign_tensor_value_operation_iter {
1431 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1432 #[doc=$doc]
1433 impl<T, const D: usize> $op<Tensor<T, D>> for Tensor<T, D>
1434 where
1435 T: Numeric,
1436 for<'a> &'a T: NumericRef<T>,
1437 {
1438 #[track_caller]
1439 #[inline]
1440 fn $method(&mut self, rhs: Tensor<T, D>) {
1441 let left_shape = self.shape();
1442 $implementation::<T, _, _, D>(
1443 self.direct_iter_reference_mut(),
1444 left_shape,
1445 rhs.direct_iter_reference(),
1446 rhs.shape(),
1447 )
1448 }
1449 }
1450 };
1451}
1452
1453tensor_assign_tensor_value_operation_iter!(impl AddAssign for Tensor { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for two tensors with one referenced");
1454tensor_assign_tensor_value_operation_iter!(impl SubAssign for Tensor { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for two tensors with one referenced");
1455
1456macro_rules! tensor_value_tensor_reference_operation_iter {
1457 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1458 #[doc=$doc]
1459 impl<T, const D: usize> $op<&Tensor<T, D>> for Tensor<T, D>
1460 where
1461 T: Numeric,
1462 for<'a> &'a T: NumericRef<T>,
1463 {
1464 type Output = Tensor<T, D>;
1465
1466 #[track_caller]
1467 #[inline]
1468 fn $method(self, rhs: &Tensor<T, D>) -> Self::Output {
1469 $implementation::<T, _, _, D>(
1470 self.direct_iter_reference(),
1471 self.shape(),
1472 rhs.direct_iter_reference(),
1473 rhs.shape(),
1474 )
1475 }
1476 }
1477 };
1478}
1479
1480macro_rules! tensor_value_tensor_reference_operation {
1481 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1482 #[doc=$doc]
1483 impl<T> $op<&Tensor<T, $d>> for Tensor<T, $d>
1484 where
1485 T: Numeric,
1486 for<'a> &'a T: NumericRef<T>,
1487 {
1488 type Output = Tensor<T, $d>;
1489
1490 #[track_caller]
1491 #[inline]
1492 fn $method(self, rhs: &Tensor<T, $d>) -> Self::Output {
1493 $implementation::<T, _, _>(self, rhs)
1494 }
1495 }
1496 };
1497}
1498
1499tensor_value_tensor_reference_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for two tensors with one referenced");
1500tensor_value_tensor_reference_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensors with one referenced");
1501tensor_value_tensor_reference_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for two 2-dimensional tensors with one referenced");
1502
1503macro_rules! tensor_value_tensor_value_operation_iter {
1504 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1505 #[doc=$doc]
1506 impl<T, const D: usize> $op<Tensor<T, D>> for Tensor<T, D>
1507 where
1508 T: Numeric,
1509 for<'a> &'a T: NumericRef<T>,
1510 {
1511 type Output = Tensor<T, D>;
1512
1513 #[track_caller]
1514 #[inline]
1515 fn $method(self, rhs: Tensor<T, D>) -> Self::Output {
1516 $implementation::<T, _, _, D>(
1517 self.direct_iter_reference(),
1518 self.shape(),
1519 rhs.direct_iter_reference(),
1520 rhs.shape(),
1521 )
1522 }
1523 }
1524 };
1525}
1526
1527macro_rules! tensor_value_tensor_value_operation {
1528 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1529 #[doc=$doc]
1530 impl<T> $op<Tensor<T, $d>> for Tensor<T, $d>
1531 where
1532 T: Numeric,
1533 for<'a> &'a T: NumericRef<T>,
1534 {
1535 type Output = Tensor<T, $d>;
1536
1537 #[track_caller]
1538 #[inline]
1539 fn $method(self, rhs: Tensor<T, $d>) -> Self::Output {
1540 $implementation::<T, _, _>(self, rhs)
1541 }
1542 }
1543 };
1544}
1545
1546tensor_value_tensor_value_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for two tensors");
1547tensor_value_tensor_value_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensors");
1548tensor_value_tensor_value_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for two 2-dimensional tensors");
1549
/// Checks `+` for every owned/borrowed pairing of `Tensor` and `TensorView`
/// on each side, using 1 + 1 on a single element tensor.
#[test]
fn elementwise_addition_test_all_16_combinations() {
    fn tensor() -> Tensor<i8, 1> {
        Tensor::from([("a", 1)], vec![1])
    }
    fn tensor_view() -> TensorView<i8, Tensor<i8, 1>, 1> {
        TensorView::from(tensor())
    }
    let results = vec![
        // Tensor on the right hand side.
        tensor() + tensor(),
        tensor() + &tensor(),
        &tensor() + tensor(),
        &tensor() + &tensor(),
        tensor_view() + tensor(),
        tensor_view() + &tensor(),
        &tensor_view() + tensor(),
        &tensor_view() + &tensor(),
        // TensorView on the right hand side.
        tensor() + tensor_view(),
        tensor() + &tensor_view(),
        &tensor() + tensor_view(),
        &tensor() + &tensor_view(),
        tensor_view() + tensor_view(),
        tensor_view() + &tensor_view(),
        &tensor_view() + tensor_view(),
        &tensor_view() + &tensor_view(),
    ];
    for sum in results {
        assert_eq!(sum.index_by(["a"]).get([0]), 2);
    }
}
1579
/// Checks `+=` for an owned `Tensor` and an owned `TensorView` on the left,
/// each with every owned/borrowed `Tensor`/`TensorView` right hand side.
#[test]
fn elementwise_addition_assign_test_all_8_combinations() {
    fn tensor() -> Tensor<i8, 1> {
        Tensor::from([("a", 1)], vec![1])
    }
    fn tensor_view() -> TensorView<i8, Tensor<i8, 1>, 1> {
        TensorView::from(tensor())
    }
    // Four results per left hand side kind; `vec!` sizes each vector exactly.
    let results_tensors = vec![
        {
            let mut x = tensor();
            x += tensor();
            x
        },
        {
            let mut x = tensor();
            x += &tensor();
            x
        },
        {
            let mut x = tensor();
            x += tensor_view();
            x
        },
        {
            let mut x = tensor();
            x += &tensor_view();
            x
        },
    ];
    let results_tensor_views = vec![
        {
            let mut x = tensor_view();
            x += tensor();
            x
        },
        {
            let mut x = tensor_view();
            x += &tensor();
            x
        },
        {
            let mut x = tensor_view();
            x += tensor_view();
            x
        },
        {
            let mut x = tensor_view();
            x += &tensor_view();
            x
        },
    ];
    for total in results_tensors {
        assert_eq!(total.index_by(["a"]).get([0]), 2);
    }
    for total in results_tensor_views {
        assert_eq!(total.index_by(["a"]).get([0]), 2);
    }
}
1637
/// Adds two owned 2x2 tensors with identical dimension order.
#[test]
fn elementwise_addition_test() {
    let shape = [("r", 2), ("c", 2)];
    let left: Tensor<i32, 2> = Tensor::from(shape, vec![1, 2, 3, 4]);
    let right: Tensor<i32, 2> = Tensor::from(shape, vec![3, 2, 8, 1]);
    let expected: Tensor<i32, 2> = Tensor::from(shape, vec![4, 4, 11, 5]);
    let sum: Tensor<i32, 2> = left + right;
    assert_eq!(sum, expected);
}
1645
/// Tensors with the same dimension names in a different order are expected to
/// make elementwise addition panic.
#[test]
#[should_panic]
fn elementwise_addition_test_similar_not_matching() {
    let rows_first: Tensor<i32, 2> = Tensor::from([("r", 2), ("c", 2)], vec![1, 2, 3, 4]);
    let columns_first: Tensor<i32, 2> = Tensor::from([("c", 2), ("r", 2)], vec![3, 8, 2, 1]);
    let _: Tensor<i32, 2> = rows_first + columns_first;
}
1653
/// Checks `*` (matrix multiplication) of a 2x3 by a 3x2 matrix for every
/// owned/borrowed pairing of `Tensor` and `TensorView` on each side.
#[test]
fn matrix_multiplication_test_all_16_combinations() {
    #[rustfmt::skip]
    fn tensor_1() -> Tensor<i8, 2> {
        Tensor::from([("r", 2), ("c", 3)], vec![
            1, 2, 3,
            4, 5, 6
        ])
    }
    fn tensor_1_view() -> TensorView<i8, Tensor<i8, 2>, 2> {
        TensorView::from(tensor_1())
    }
    #[rustfmt::skip]
    fn tensor_2() -> Tensor<i8, 2> {
        Tensor::from([("a", 3), ("b", 2)], vec![
            1, 2,
            3, 4,
            5, 6
        ])
    }
    fn tensor_2_view() -> TensorView<i8, Tensor<i8, 2>, 2> {
        TensorView::from(tensor_2())
    }
    // The product keeps the left operand's row dimension and the right
    // operand's column dimension.
    #[rustfmt::skip]
    let expected: Tensor<i8, 2> = Tensor::from(
        [("r", 2), ("b", 2)],
        vec![
            1 * 1 + 2 * 3 + 3 * 5, 1 * 2 + 2 * 4 + 3 * 6,
            4 * 1 + 5 * 3 + 6 * 5, 4 * 2 + 5 * 4 + 6 * 6,
        ]
    );
    let products = vec![
        tensor_1() * tensor_2(),
        tensor_1() * &tensor_2(),
        &tensor_1() * tensor_2(),
        &tensor_1() * &tensor_2(),
        tensor_1_view() * tensor_2(),
        tensor_1_view() * &tensor_2(),
        &tensor_1_view() * tensor_2(),
        &tensor_1_view() * &tensor_2(),
        tensor_1() * tensor_2_view(),
        tensor_1() * &tensor_2_view(),
        &tensor_1() * tensor_2_view(),
        &tensor_1() * &tensor_2_view(),
        tensor_1_view() * tensor_2_view(),
        tensor_1_view() * &tensor_2_view(),
        &tensor_1_view() * tensor_2_view(),
        &tensor_1_view() * &tensor_2_view(),
    ];
    for product in products {
        assert_eq!(product, expected);
    }
}
1708
/// Checks `-=` for an owned `Tensor` and an owned `TensorView` on the left,
/// each with every owned/borrowed `Tensor`/`TensorView` right hand side.
#[test]
fn elementwise_subtraction_assign_test_all_8_combinations() {
    fn tensor() -> Tensor<i8, 1> {
        Tensor::from([("a", 1)], vec![1])
    }
    fn tensor_view() -> TensorView<i8, Tensor<i8, 1>, 1> {
        TensorView::from(tensor())
    }
    // Four results per left hand side kind; `vec!` sizes each vector exactly.
    let results_tensors = vec![
        {
            let mut x = tensor();
            x -= tensor();
            x
        },
        {
            let mut x = tensor();
            x -= &tensor();
            x
        },
        {
            let mut x = tensor();
            x -= tensor_view();
            x
        },
        {
            let mut x = tensor();
            x -= &tensor_view();
            x
        },
    ];
    let results_tensor_views = vec![
        {
            let mut x = tensor_view();
            x -= tensor();
            x
        },
        {
            let mut x = tensor_view();
            x -= &tensor();
            x
        },
        {
            let mut x = tensor_view();
            x -= tensor_view();
            x
        },
        {
            let mut x = tensor_view();
            x -= &tensor_view();
            x
        },
    ];
    for total in results_tensors {
        assert_eq!(total.index_by(["a"]).get([0]), 0);
    }
    for total in results_tensor_views {
        assert_eq!(total.index_by(["a"]).get([0]), 0);
    }
}
1766
1767impl<T> Tensor<T, 1>
1768where
1769 T: Numeric,
1770 for<'a> &'a T: NumericRef<T>,
1771{
1772 pub(crate) fn scalar_product_less_generic<S>(&self, rhs: TensorView<T, S, 1>) -> T
1773 where
1774 S: TensorRef<T, 1>,
1775 {
1776 let left_shape = self.shape();
1777 let right_shape = rhs.shape();
1778 assert_same_dimensions(left_shape, right_shape);
1779 tensor_view_vector_product_iter::<T, _, _>(
1780 self.direct_iter_reference(),
1781 left_shape,
1782 rhs.iter_reference(),
1783 right_shape,
1784 )
1785 }
1786}
1787
1788impl<T, S> TensorView<T, S, 1>
1789where
1790 T: Numeric,
1791 for<'a> &'a T: NumericRef<T>,
1792 S: TensorRef<T, 1>,
1793{
1794 pub(crate) fn scalar_product_less_generic<S2>(&self, rhs: TensorView<T, S2, 1>) -> T
1795 where
1796 S2: TensorRef<T, 1>,
1797 {
1798 let left_shape = self.shape();
1799 let right_shape = rhs.shape();
1800 assert_same_dimensions(left_shape, right_shape);
1801 tensor_view_vector_product_iter::<T, _, _>(
1802 self.iter_reference(),
1803 left_shape,
1804 rhs.iter_reference(),
1805 right_shape,
1806 )
1807 }
1808}
1809
// Generates all four owned/borrowed combinations of `Tensor <op> scalar`. Each
// impl maps every element through `element.$op_fn(scalar)` and returns the
// results as a `Tensor` of the same dimensionality.
macro_rules! tensor_scalar {
    (impl $op_trait:tt for Tensor { fn $op_fn:ident }) => {
        /// Applies the operator between every element of a referenced tensor and a
        /// referenced scalar, returning the results as a tensor.
        impl<T, const D: usize> $op_trait<&T> for &Tensor<T, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $op_fn(self, rhs: &T) -> Self::Output {
                self.map(|element| element.$op_fn(rhs.clone()))
            }
        }

        /// Applies the operator between every element of a tensor and a referenced
        /// scalar, returning the results as a tensor.
        impl<T, const D: usize> $op_trait<&T> for Tensor<T, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $op_fn(self, rhs: &T) -> Self::Output {
                self.map(|element| element.$op_fn(rhs.clone()))
            }
        }

        /// Applies the operator between every element of a referenced tensor and a
        /// scalar, returning the results as a tensor.
        impl<T, const D: usize> $op_trait<T> for &Tensor<T, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $op_fn(self, rhs: T) -> Self::Output {
                self.map(|element| element.$op_fn(rhs.clone()))
            }
        }

        /// Applies the operator between every element of a tensor and a scalar,
        /// returning the results as a tensor.
        impl<T, const D: usize> $op_trait<T> for Tensor<T, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $op_fn(self, rhs: T) -> Self::Output {
                self.map(|element| element.$op_fn(rhs.clone()))
            }
        }
    };
}
1873
// Generates all four owned/borrowed combinations of `TensorView <op> scalar`.
// Each impl maps every element of the view through `element.$op_fn(scalar)`
// and returns the results as a `Tensor` of the same dimensionality.
macro_rules! tensor_view_scalar {
    (impl $op_trait:tt for TensorView { fn $op_fn:ident }) => {
        /// Applies the operator between every element of a referenced tensor view
        /// and a referenced scalar, returning the results as a tensor.
        impl<T, Src, const D: usize> $op_trait<&T> for &TensorView<T, Src, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
            Src: TensorRef<T, D>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $op_fn(self, rhs: &T) -> Self::Output {
                self.map(|element| element.$op_fn(rhs.clone()))
            }
        }

        /// Applies the operator between every element of a tensor view and a
        /// referenced scalar, returning the results as a tensor.
        impl<T, Src, const D: usize> $op_trait<&T> for TensorView<T, Src, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
            Src: TensorRef<T, D>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $op_fn(self, rhs: &T) -> Self::Output {
                self.map(|element| element.$op_fn(rhs.clone()))
            }
        }

        /// Applies the operator between every element of a referenced tensor view
        /// and a scalar, returning the results as a tensor.
        impl<T, Src, const D: usize> $op_trait<T> for &TensorView<T, Src, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
            Src: TensorRef<T, D>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $op_fn(self, rhs: T) -> Self::Output {
                self.map(|element| element.$op_fn(rhs.clone()))
            }
        }

        /// Applies the operator between every element of a tensor view and a
        /// scalar, returning the results as a tensor.
        impl<T, Src, const D: usize> $op_trait<T> for TensorView<T, Src, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
            Src: TensorRef<T, D>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $op_fn(self, rhs: T) -> Self::Output {
                self.map(|element| element.$op_fn(rhs.clone()))
            }
        }
    };
}
1945
1946tensor_scalar!(impl Add for Tensor { fn add });
1947tensor_scalar!(impl Sub for Tensor { fn sub });
1948tensor_scalar!(impl Mul for Tensor { fn mul });
1949tensor_scalar!(impl Div for Tensor { fn div });
1950
1951tensor_view_scalar!(impl Add for TensorView { fn add });
1952tensor_view_scalar!(impl Sub for TensorView { fn sub });
1953tensor_view_scalar!(impl Mul for TensorView { fn mul });
1954tensor_view_scalar!(impl Div for TensorView { fn div });