1use crate::numeric::{Numeric, NumericRef};
52use crate::tensors::indexing::{TensorAccess, TensorReferenceIterator};
53use crate::tensors::views::{TensorIndex, TensorMut, TensorRef, TensorView};
54use crate::tensors::{Dimension, Tensor};
55
56use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
57
58#[inline]
60pub(crate) fn tensor_equality<T, S1, S2, const D: usize>(left: &S1, right: &S2) -> bool
61where
62 T: PartialEq,
63 S1: TensorRef<T, D>,
64 S2: TensorRef<T, D>,
65{
66 left.view_shape() == right.view_shape()
67 && TensorReferenceIterator::from(left)
68 .zip(TensorReferenceIterator::from(right))
69 .all(|(x, y)| x == y)
70}
71
72#[inline]
76pub(crate) fn tensor_similarity<T, S1, S2, const D: usize>(left: &S1, right: &S2) -> bool
77where
78 T: PartialEq,
79 S1: TensorRef<T, D>,
80 S2: TensorRef<T, D>,
81{
82 use crate::tensors::dimensions::names_of;
83 let left_shape = left.view_shape();
84 let access_order = names_of(&left_shape);
85 let left_access = TensorAccess::from_source_order(left);
86 let right_access = match TensorAccess::try_from(right, access_order) {
87 Ok(right_access) => right_access,
88 Err(_) => return false,
89 };
90 if left_shape != right_access.shape() {
92 return false;
93 }
94 left_access
95 .iter_reference()
96 .zip(right_access.iter_reference())
97 .all(|(x, y)| x == y)
98}
99
100impl<T: PartialEq, const D: usize> PartialEq for Tensor<T, D> {
128 fn eq(&self, other: &Self) -> bool {
129 tensor_equality(self, other)
130 }
131}
132
133impl<T, S1, S2, const D: usize> PartialEq<TensorView<T, S2, D>> for TensorView<T, S1, D>
138where
139 T: PartialEq,
140 S1: TensorRef<T, D>,
141 S2: TensorRef<T, D>,
142{
143 fn eq(&self, other: &TensorView<T, S2, D>) -> bool {
144 tensor_equality(self.source_ref(), other.source_ref())
145 }
146}
147
148impl<T, S, const D: usize> PartialEq<TensorView<T, S, D>> for Tensor<T, D>
153where
154 T: PartialEq,
155 S: TensorRef<T, D>,
156{
157 fn eq(&self, other: &TensorView<T, S, D>) -> bool {
158 tensor_equality(self, other.source_ref())
159 }
160}
161
162impl<T, S, const D: usize> PartialEq<Tensor<T, D>> for TensorView<T, S, D>
167where
168 T: PartialEq,
169 S: TensorRef<T, D>,
170{
171 fn eq(&self, other: &Tensor<T, D>) -> bool {
172 tensor_equality(self.source_ref(), other)
173 }
174}
175
176pub trait Similar<Rhs: ?Sized = Self>: private::Sealed {
184 #[must_use]
189 fn similar(&self, other: &Rhs) -> bool;
190}
191
// Sealing support for the `Similar` trait. `private` is not `pub`, so `Sealed` can only be
// named from within the parent module, and only the impls listed here can exist.
mod private {
    use crate::tensors::Tensor;
    use crate::tensors::views::{TensorRef, TensorView};

    /// Marker trait used as the supertrait of `Similar`. `Rhs` mirrors the `Rhs` of a
    /// `Similar<Rhs>` implementation and defaults to `Self`.
    pub trait Sealed<Rhs: ?Sized = Self> {}

    // One impl per `Similar` impl in the parent module, with the same generic bounds.

    // Tensor compared against Tensor.
    impl<T: PartialEq, const D: usize> Sealed for Tensor<T, D> {}
    // TensorView compared against TensorView (the source types may differ).
    impl<T, S1, S2, const D: usize> Sealed<TensorView<T, S2, D>> for TensorView<T, S1, D>
    where
        T: PartialEq,
        S1: TensorRef<T, D>,
        S2: TensorRef<T, D>,
    {
    }
    // Tensor compared against TensorView.
    impl<T, S, const D: usize> Sealed<TensorView<T, S, D>> for Tensor<T, D>
    where
        T: PartialEq,
        S: TensorRef<T, D>,
    {
    }
    // TensorView compared against Tensor.
    impl<T, S, const D: usize> Sealed<Tensor<T, D>> for TensorView<T, S, D>
    where
        T: PartialEq,
        S: TensorRef<T, D>,
    {
    }
}
219
220impl<T: PartialEq, const D: usize> Similar for Tensor<T, D> {
259 fn similar(&self, other: &Tensor<T, D>) -> bool {
260 tensor_similarity(self, other)
261 }
262}
263
264impl<T, S1, S2, const D: usize> Similar<TensorView<T, S2, D>> for TensorView<T, S1, D>
271where
272 T: PartialEq,
273 S1: TensorRef<T, D>,
274 S2: TensorRef<T, D>,
275{
276 fn similar(&self, other: &TensorView<T, S2, D>) -> bool {
277 tensor_similarity(self.source_ref(), other.source_ref())
278 }
279}
280
281impl<T, S, const D: usize> Similar<TensorView<T, S, D>> for Tensor<T, D>
288where
289 T: PartialEq,
290 S: TensorRef<T, D>,
291{
292 fn similar(&self, other: &TensorView<T, S, D>) -> bool {
293 tensor_similarity(self, other.source_ref())
294 }
295}
296
297impl<T, S, const D: usize> Similar<Tensor<T, D>> for TensorView<T, S, D>
304where
305 T: PartialEq,
306 S: TensorRef<T, D>,
307{
308 fn similar(&self, other: &Tensor<T, D>) -> bool {
309 tensor_similarity(self.source_ref(), other)
310 }
311}
312
313#[track_caller]
314#[inline]
315fn assert_same_dimensions<const D: usize>(
316 left_shape: [(Dimension, usize); D],
317 right_shape: [(Dimension, usize); D],
318) {
319 if left_shape != right_shape {
320 panic!(
321 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
322 left_shape, right_shape
323 );
324 }
325}
326
327#[track_caller]
328#[inline]
329fn tensor_view_addition_iter<'l, 'r, T, S1, S2, const D: usize>(
330 left_iter: S1,
331 left_shape: [(Dimension, usize); D],
332 right_iter: S2,
333 right_shape: [(Dimension, usize); D],
334) -> Tensor<T, D>
335where
336 T: Numeric,
337 T: 'l,
338 T: 'r,
339 for<'a> &'a T: NumericRef<T>,
340 S1: Iterator<Item = &'l T>,
341 S2: Iterator<Item = &'r T>,
342{
343 assert_same_dimensions(left_shape, right_shape);
344 Tensor::from(
346 left_shape,
347 left_iter.zip(right_iter).map(|(x, y)| x + y).collect(),
348 )
349}
350
351#[track_caller]
352#[inline]
353fn tensor_view_subtraction_iter<'l, 'r, T, S1, S2, const D: usize>(
354 left_iter: S1,
355 left_shape: [(Dimension, usize); D],
356 right_iter: S2,
357 right_shape: [(Dimension, usize); D],
358) -> Tensor<T, D>
359where
360 T: Numeric,
361 T: 'l,
362 T: 'r,
363 for<'a> &'a T: NumericRef<T>,
364 S1: Iterator<Item = &'l T>,
365 S2: Iterator<Item = &'r T>,
366{
367 assert_same_dimensions(left_shape, right_shape);
368 Tensor::from(
370 left_shape,
371 left_iter.zip(right_iter).map(|(x, y)| x - y).collect(),
372 )
373}
374
375#[track_caller]
376#[inline]
377fn tensor_view_assign_addition_iter<'l, 'r, T, S1, S2, const D: usize>(
378 left_iter: S1,
379 left_shape: [(Dimension, usize); D],
380 right_iter: S2,
381 right_shape: [(Dimension, usize); D],
382) where
383 T: Numeric,
384 T: 'l,
385 T: 'r,
386 for<'a> &'a T: NumericRef<T>,
387 S1: Iterator<Item = &'l mut T>,
388 S2: Iterator<Item = &'r T>,
389{
390 assert_same_dimensions(left_shape, right_shape);
391 for (x, y) in left_iter.zip(right_iter) {
392 *x = x.clone() + y;
396 }
397}
398
399#[track_caller]
400#[inline]
401fn tensor_view_assign_subtraction_iter<'l, 'r, T, S1, S2, const D: usize>(
402 left_iter: S1,
403 left_shape: [(Dimension, usize); D],
404 right_iter: S2,
405 right_shape: [(Dimension, usize); D],
406) where
407 T: Numeric,
408 T: 'l,
409 T: 'r,
410 for<'a> &'a T: NumericRef<T>,
411 S1: Iterator<Item = &'l mut T>,
412 S2: Iterator<Item = &'r T>,
413{
414 assert_same_dimensions(left_shape, right_shape);
415 for (x, y) in left_iter.zip(right_iter) {
416 *x = x.clone() - y;
420 }
421}
422
423#[inline]
430pub(crate) fn scalar_product<'l, 'r, T, S1, S2>(left_iter: S1, right_iter: S2) -> T
431where
432 T: Numeric,
433 T: 'l,
434 T: 'r,
435 for<'a> &'a T: NumericRef<T>,
436 S1: Iterator<Item = &'l T>,
437 S2: Iterator<Item = &'r T>,
438{
439 left_iter
442 .zip(right_iter)
443 .map(|(x, y)| x * y)
444 .reduce(|x, y| x + y)
445 .unwrap() }
447
448#[track_caller]
449#[inline]
450fn tensor_view_vector_product_iter<'l, 'r, T, S1, S2>(
451 left_iter: S1,
452 left_shape: [(Dimension, usize); 1],
453 right_iter: S2,
454 right_shape: [(Dimension, usize); 1],
455) -> T
456where
457 T: Numeric,
458 T: 'l,
459 T: 'r,
460 for<'a> &'a T: NumericRef<T>,
461 S1: Iterator<Item = &'l T>,
462 S2: Iterator<Item = &'r T>,
463{
464 if left_shape[0].1 != right_shape[0].1 {
465 panic!(
466 "Dimension lengths of left and right tensors are not the same: (left: {:?}, right: {:?})",
467 left_shape, right_shape
468 );
469 }
470 scalar_product::<T, S1, S2>(left_iter, right_iter)
472}
473
474#[track_caller]
475#[inline]
476fn tensor_view_matrix_product<T, S1, S2>(left: S1, right: S2) -> Tensor<T, 2>
477where
478 T: Numeric,
479 for<'a> &'a T: NumericRef<T>,
480 S1: TensorRef<T, 2>,
481 S2: TensorRef<T, 2>,
482{
483 let left_shape = left.view_shape();
484 let right_shape = right.view_shape();
485 if left_shape[1].1 != right_shape[0].1 {
486 panic!(
487 "Mismatched tensors, left is {:?}, right is {:?}, * is only defined for MxN * NxL dimension lengths",
488 left.view_shape(),
489 right.view_shape()
490 );
491 }
492 if left_shape[0].0 == right_shape[1].0 {
493 panic!(
494 "Matrix multiplication of tensors with shapes left {:?} and right {:?} would \
495 create duplicate dimension names as the shape {:?}. Rename one or both of the \
496 dimension names in the input to prevent this. * is defined as MxN * NxL = MxL",
497 left_shape,
498 right_shape,
499 [left_shape[0], right_shape[1]]
500 )
501 }
502 let mut tensor = Tensor::empty([left.view_shape()[0], right.view_shape()[1]], T::zero());
507 for ([i, j], x) in tensor.iter_reference_mut().with_index() {
508 let left = TensorIndex::from(&left, [(left.view_shape()[0].0, i)]);
510 let right = TensorIndex::from(&right, [(right.view_shape()[1].0, j)]);
512 *x = scalar_product::<T, _, _>(
514 TensorReferenceIterator::from(&left),
515 TensorReferenceIterator::from(&right),
516 )
517 }
518 tensor
519}
520
#[test]
fn test_matrix_product() {
    // 2x3 * 3x2 = 2x2: the result keeps `left`'s first dimension and `right`'s second dimension.
    #[rustfmt::skip]
    let left = Tensor::from([("r", 2), ("c", 3)], vec![
        1, 2, 3,
        4, 5, 6
    ]);
    #[rustfmt::skip]
    let right = Tensor::from([("r", 3), ("c", 2)], vec![
        10, 11,
        12, 13,
        14, 15
    ]);
    #[rustfmt::skip]
    let expected = Tensor::from([("r", 2), ("c", 2)], vec![
        1 * 10 + 2 * 12 + 3 * 14, 1 * 11 + 2 * 13 + 3 * 15,
        4 * 10 + 5 * 12 + 6 * 14, 4 * 11 + 5 * 13 + 6 * 15
    ]);
    let product = tensor_view_matrix_product::<i32, _, _>(left, right);
    assert_eq!(product, expected);
}
547
548macro_rules! tensor_view_reference_tensor_view_reference_operation_iter {
552 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
553 #[doc=$doc]
554 impl<T, S1, S2, const D: usize> $op<&TensorView<T, S2, D>> for &TensorView<T, S1, D>
555 where
556 T: Numeric,
557 for<'a> &'a T: NumericRef<T>,
558 S1: TensorRef<T, D>,
559 S2: TensorRef<T, D>,
560 {
561 type Output = Tensor<T, D>;
562
563 #[track_caller]
564 #[inline]
565 fn $method(self, rhs: &TensorView<T, S2, D>) -> Self::Output {
566 $implementation::<T, _, _, D>(
567 self.iter_reference(),
568 self.shape(),
569 rhs.iter_reference(),
570 rhs.shape(),
571 )
572 }
573 }
574 };
575}
576
577macro_rules! tensor_view_reference_tensor_view_reference_operation {
578 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
579 #[doc=$doc]
580 impl<T, S1, S2> $op<&TensorView<T, S2, $d>> for &TensorView<T, S1, $d>
581 where
582 T: Numeric,
583 for<'a> &'a T: NumericRef<T>,
584 S1: TensorRef<T, $d>,
585 S2: TensorRef<T, $d>,
586 {
587 type Output = Tensor<T, $d>;
588
589 #[track_caller]
590 #[inline]
591 fn $method(self, rhs: &TensorView<T, S2, $d>) -> Self::Output {
592 $implementation::<T, _, _>(self.source_ref(), rhs.source_ref())
593 }
594 }
595 };
596}
597
598tensor_view_reference_tensor_view_reference_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for two referenced tensor views");
599tensor_view_reference_tensor_view_reference_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two referenced tensor views");
600tensor_view_reference_tensor_view_reference_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication of two referenced 2-dimensional tensors");
601
602macro_rules! tensor_view_assign_tensor_view_reference_operation_iter {
603 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
604 #[doc=$doc]
605 impl<T, S1, S2, const D: usize> $op<&TensorView<T, S2, D>> for TensorView<T, S1, D>
606 where
607 T: Numeric,
608 for<'a> &'a T: NumericRef<T>,
609 S1: TensorMut<T, D>,
610 S2: TensorRef<T, D>,
611 {
612 #[track_caller]
613 #[inline]
614 fn $method(&mut self, rhs: &TensorView<T, S2, D>) {
615 let left_shape = self.shape();
616 $implementation::<T, _, _, D>(
617 self.iter_reference_mut(),
618 left_shape,
619 rhs.iter_reference(),
620 rhs.shape(),
621 )
622 }
623 }
624 };
625}
626
627tensor_view_assign_tensor_view_reference_operation_iter!(impl AddAssign for TensorView { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for two referenced tensor views");
628tensor_view_assign_tensor_view_reference_operation_iter!(impl SubAssign for TensorView { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for two referenced tensor views");
629
630macro_rules! tensor_view_reference_tensor_view_value_operation_iter {
631 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
632 #[doc=$doc]
633 impl<T, S1, S2, const D: usize> $op<TensorView<T, S2, D>> for &TensorView<T, S1, D>
634 where
635 T: Numeric,
636 for<'a> &'a T: NumericRef<T>,
637 S1: TensorRef<T, D>,
638 S2: TensorRef<T, D>,
639 {
640 type Output = Tensor<T, D>;
641
642 #[track_caller]
643 #[inline]
644 fn $method(self, rhs: TensorView<T, S2, D>) -> Self::Output {
645 $implementation::<T, _, _, D>(
646 self.iter_reference(),
647 self.shape(),
648 rhs.iter_reference(),
649 rhs.shape(),
650 )
651 }
652 }
653 };
654}
655
656macro_rules! tensor_view_reference_tensor_view_value_operation {
657 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
658 #[doc=$doc]
659 impl<T, S1, S2> $op<TensorView<T, S2, $d>> for &TensorView<T, S1, $d>
660 where
661 T: Numeric,
662 for<'a> &'a T: NumericRef<T>,
663 S1: TensorRef<T, $d>,
664 S2: TensorRef<T, $d>,
665 {
666 type Output = Tensor<T, $d>;
667
668 #[track_caller]
669 #[inline]
670 fn $method(self, rhs: TensorView<T, S2, $d>) -> Self::Output {
671 $implementation::<T, _, _>(self.source_ref(), rhs.source_ref())
672 }
673 }
674 };
675}
676
677tensor_view_reference_tensor_view_value_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for two tensor views with one referenced");
678tensor_view_reference_tensor_view_value_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensor views with one referenced");
679tensor_view_reference_tensor_view_value_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication of two 2-dimensional tensors with one referenced");
680
681macro_rules! tensor_view_assign_tensor_view_value_operation_iter {
682 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
683 #[doc=$doc]
684 impl<T, S1, S2, const D: usize> $op<TensorView<T, S2, D>> for TensorView<T, S1, D>
685 where
686 T: Numeric,
687 for<'a> &'a T: NumericRef<T>,
688 S1: TensorMut<T, D>,
689 S2: TensorRef<T, D>,
690 {
691 #[track_caller]
692 #[inline]
693 fn $method(&mut self, rhs: TensorView<T, S2, D>) {
694 let left_shape = self.shape();
695 $implementation::<T, _, _, D>(
696 self.iter_reference_mut(),
697 left_shape,
698 rhs.iter_reference(),
699 rhs.shape(),
700 )
701 }
702 }
703 };
704}
705
706tensor_view_assign_tensor_view_value_operation_iter!(impl AddAssign for TensorView { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for two tensor views with one referenced");
707tensor_view_assign_tensor_view_value_operation_iter!(impl SubAssign for TensorView { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for two tensor views with one referenced");
708
709macro_rules! tensor_view_value_tensor_view_reference_operation_iter {
710 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
711 #[doc=$doc]
712 impl<T, S1, S2, const D: usize> $op<&TensorView<T, S2, D>> for TensorView<T, S1, D>
713 where
714 T: Numeric,
715 for<'a> &'a T: NumericRef<T>,
716 S1: TensorRef<T, D>,
717 S2: TensorRef<T, D>,
718 {
719 type Output = Tensor<T, D>;
720
721 #[track_caller]
722 #[inline]
723 fn $method(self, rhs: &TensorView<T, S2, D>) -> Self::Output {
724 $implementation::<T, _, _, D>(
725 self.iter_reference(),
726 self.shape(),
727 rhs.iter_reference(),
728 rhs.shape(),
729 )
730 }
731 }
732 };
733}
734
735macro_rules! tensor_view_value_tensor_view_reference_operation {
736 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
737 #[doc=$doc]
738 impl<T, S1, S2> $op<&TensorView<T, S2, $d>> for TensorView<T, S1, $d>
739 where
740 T: Numeric,
741 for<'a> &'a T: NumericRef<T>,
742 S1: TensorRef<T, $d>,
743 S2: TensorRef<T, $d>,
744 {
745 type Output = Tensor<T, $d>;
746
747 #[track_caller]
748 #[inline]
749 fn $method(self, rhs: &TensorView<T, S2, $d>) -> Self::Output {
750 $implementation::<T, _, _>(self.source_ref(), rhs.source_ref())
751 }
752 }
753 };
754}
755
756tensor_view_value_tensor_view_reference_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for two tensor views with one referenced");
757tensor_view_value_tensor_view_reference_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensor views with one referenced");
758tensor_view_value_tensor_view_reference_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication of two 2-dimensional tensors with one referenced");
759
760macro_rules! tensor_view_value_tensor_view_value_operation_iter {
761 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
762 #[doc=$doc]
763 impl<T, S1, S2, const D: usize> $op<TensorView<T, S2, D>> for TensorView<T, S1, D>
764 where
765 T: Numeric,
766 for<'a> &'a T: NumericRef<T>,
767 S1: TensorRef<T, D>,
768 S2: TensorRef<T, D>,
769 {
770 type Output = Tensor<T, D>;
771
772 #[track_caller]
773 #[inline]
774 fn $method(self, rhs: TensorView<T, S2, D>) -> Self::Output {
775 $implementation::<T, _, _, D>(
776 self.iter_reference(),
777 self.shape(),
778 rhs.iter_reference(),
779 rhs.shape(),
780 )
781 }
782 }
783 };
784}
785
786macro_rules! tensor_view_value_tensor_view_value_operation {
787 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
788 #[doc=$doc]
789 impl<T, S1, S2> $op<TensorView<T, S2, $d>> for TensorView<T, S1, $d>
790 where
791 T: Numeric,
792 for<'a> &'a T: NumericRef<T>,
793 S1: TensorRef<T, $d>,
794 S2: TensorRef<T, $d>,
795 {
796 type Output = Tensor<T, $d>;
797
798 #[track_caller]
799 #[inline]
800 fn $method(self, rhs: TensorView<T, S2, $d>) -> Self::Output {
801 $implementation::<T, _, _>(self.source_ref(), rhs.source_ref())
802 }
803 }
804 };
805}
806
807tensor_view_value_tensor_view_value_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for two tensor views");
808tensor_view_value_tensor_view_value_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensor views");
809tensor_view_value_tensor_view_value_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication of two 2-dimensional tensors");
810
811macro_rules! tensor_view_reference_tensor_reference_operation_iter {
812 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
813 #[doc=$doc]
814 impl<T, S, const D: usize> $op<&Tensor<T, D>> for &TensorView<T, S, D>
815 where
816 T: Numeric,
817 for<'a> &'a T: NumericRef<T>,
818 S: TensorRef<T, D>,
819 {
820 type Output = Tensor<T, D>;
821
822 #[track_caller]
823 #[inline]
824 fn $method(self, rhs: &Tensor<T, D>) -> Self::Output {
825 $implementation::<T, _, _, D>(
826 self.iter_reference(),
827 self.shape(),
828 rhs.direct_iter_reference(),
829 rhs.shape(),
830 )
831 }
832 }
833 };
834}
835
836macro_rules! tensor_view_reference_tensor_reference_operation {
837 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
838 #[doc=$doc]
839 impl<T, S> $op<&Tensor<T, $d>> for &TensorView<T, S, $d>
840 where
841 T: Numeric,
842 for<'a> &'a T: NumericRef<T>,
843 S: TensorRef<T, $d>,
844 {
845 type Output = Tensor<T, $d>;
846
847 #[track_caller]
848 #[inline]
849 fn $method(self, rhs: &Tensor<T, $d>) -> Self::Output {
850 $implementation::<T, _, _>(self.source_ref(), rhs)
851 }
852 }
853 };
854}
855
856tensor_view_reference_tensor_reference_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for a referenced tensor view and a referenced tensor");
857tensor_view_reference_tensor_reference_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a referenced tensor view and a referenced tensor");
858tensor_view_reference_tensor_reference_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional referenced tensor view and a referenced tensor");
859
860macro_rules! tensor_view_assign_tensor_reference_operation_iter {
861 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
862 #[doc=$doc]
863 impl<T, S, const D: usize> $op<&Tensor<T, D>> for TensorView<T, S, D>
864 where
865 T: Numeric,
866 for<'a> &'a T: NumericRef<T>,
867 S: TensorMut<T, D>,
868 {
869 #[track_caller]
870 #[inline]
871 fn $method(&mut self, rhs: &Tensor<T, D>) {
872 let left_shape = self.shape();
873 $implementation::<T, _, _, D>(
874 self.iter_reference_mut(),
875 left_shape,
876 rhs.direct_iter_reference(),
877 rhs.shape(),
878 )
879 }
880 }
881 };
882}
883
884tensor_view_assign_tensor_reference_operation_iter!(impl AddAssign for TensorView { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for a referenced tensor view and a referenced tensor");
885tensor_view_assign_tensor_reference_operation_iter!(impl SubAssign for TensorView { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for a referenced tensor view and a referenced tensor");
886
887macro_rules! tensor_view_reference_tensor_value_operation_iter {
888 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
889 #[doc=$doc]
890 impl<T, S, const D: usize> $op<Tensor<T, D>> for &TensorView<T, S, D>
891 where
892 T: Numeric,
893 for<'a> &'a T: NumericRef<T>,
894 S: TensorRef<T, D>,
895 {
896 type Output = Tensor<T, D>;
897
898 #[track_caller]
899 #[inline]
900 fn $method(self, rhs: Tensor<T, D>) -> Self::Output {
901 $implementation::<T, _, _, D>(
902 self.iter_reference(),
903 self.shape(),
904 rhs.direct_iter_reference(),
905 rhs.shape(),
906 )
907 }
908 }
909 };
910}
911
912macro_rules! tensor_view_reference_tensor_value_operation {
913 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
914 #[doc=$doc]
915 impl<T, S> $op<Tensor<T, $d>> for &TensorView<T, S, $d>
916 where
917 T: Numeric,
918 for<'a> &'a T: NumericRef<T>,
919 S: TensorRef<T, $d>,
920 {
921 type Output = Tensor<T, $d>;
922
923 #[track_caller]
924 #[inline]
925 fn $method(self, rhs: Tensor<T, $d>) -> Self::Output {
926 $implementation::<T, _, _>(self.source_ref(), rhs)
927 }
928 }
929 };
930}
931
932tensor_view_reference_tensor_value_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for a referenced tensor view and a tensor");
933tensor_view_reference_tensor_value_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a referenced tensor view and a tensor");
934tensor_view_reference_tensor_value_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional referenced tensor view and a tensor");
935
936macro_rules! tensor_view_assign_tensor_value_operation_iter {
937 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
938 #[doc=$doc]
939 impl<T, S, const D: usize> $op<Tensor<T, D>> for TensorView<T, S, D>
940 where
941 T: Numeric,
942 for<'a> &'a T: NumericRef<T>,
943 S: TensorMut<T, D>,
944 {
945 #[track_caller]
946 #[inline]
947 fn $method(&mut self, rhs: Tensor<T, D>) {
948 let left_shape = self.shape();
949 $implementation::<T, _, _, D>(
950 self.iter_reference_mut(),
951 left_shape,
952 rhs.direct_iter_reference(),
953 rhs.shape(),
954 )
955 }
956 }
957 };
958}
959
960tensor_view_assign_tensor_value_operation_iter!(impl AddAssign for TensorView { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for a referenced tensor view and a tensor");
961tensor_view_assign_tensor_value_operation_iter!(impl SubAssign for TensorView { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for a referenced tensor view and a tensor");
962
963macro_rules! tensor_view_value_tensor_reference_operation_iter {
964 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
965 #[doc=$doc]
966 impl<T, S, const D: usize> $op<&Tensor<T, D>> for TensorView<T, S, D>
967 where
968 T: Numeric,
969 for<'a> &'a T: NumericRef<T>,
970 S: TensorRef<T, D>,
971 {
972 type Output = Tensor<T, D>;
973
974 #[track_caller]
975 #[inline]
976 fn $method(self, rhs: &Tensor<T, D>) -> Self::Output {
977 $implementation::<T, _, _, D>(
978 self.iter_reference(),
979 self.shape(),
980 rhs.direct_iter_reference(),
981 rhs.shape(),
982 )
983 }
984 }
985 };
986}
987
988macro_rules! tensor_view_value_tensor_reference_operation {
989 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
990 #[doc=$doc]
991 impl<T, S> $op<&Tensor<T, $d>> for TensorView<T, S, $d>
992 where
993 T: Numeric,
994 for<'a> &'a T: NumericRef<T>,
995 S: TensorRef<T, $d>,
996 {
997 type Output = Tensor<T, $d>;
998
999 #[track_caller]
1000 #[inline]
1001 fn $method(self, rhs: &Tensor<T, $d>) -> Self::Output {
1002 $implementation::<T, _, _>(self.source_ref(), rhs)
1003 }
1004 }
1005 };
1006}
1007
1008tensor_view_value_tensor_reference_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for a tensor view and a referenced tensor");
1009tensor_view_value_tensor_reference_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a tensor view and a referenced tensor");
1010tensor_view_value_tensor_reference_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional tensor view and a referenced tensor");
1011
1012macro_rules! tensor_view_value_tensor_value_operation_iter {
1013 (impl $op:tt for TensorView { fn $method:ident } $implementation:ident $doc:tt) => {
1014 #[doc=$doc]
1015 impl<T, S, const D: usize> $op<Tensor<T, D>> for TensorView<T, S, D>
1016 where
1017 T: Numeric,
1018 for<'a> &'a T: NumericRef<T>,
1019 S: TensorRef<T, D>,
1020 {
1021 type Output = Tensor<T, D>;
1022
1023 #[track_caller]
1024 #[inline]
1025 fn $method(self, rhs: Tensor<T, D>) -> Self::Output {
1026 $implementation::<T, _, _, D>(
1027 self.iter_reference(),
1028 self.shape(),
1029 rhs.direct_iter_reference(),
1030 rhs.shape(),
1031 )
1032 }
1033 }
1034 };
1035}
1036
1037macro_rules! tensor_view_value_tensor_value_operation {
1038 (impl $op:tt for TensorView $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1039 #[doc=$doc]
1040 impl<T, S> $op<Tensor<T, $d>> for TensorView<T, S, $d>
1041 where
1042 T: Numeric,
1043 for<'a> &'a T: NumericRef<T>,
1044 S: TensorRef<T, $d>,
1045 {
1046 type Output = Tensor<T, $d>;
1047
1048 #[track_caller]
1049 #[inline]
1050 fn $method(self, rhs: Tensor<T, $d>) -> Self::Output {
1051 $implementation::<T, _, _>(self.source_ref(), rhs)
1052 }
1053 }
1054 };
1055}
1056
1057tensor_view_value_tensor_value_operation_iter!(impl Add for TensorView { fn add } tensor_view_addition_iter "Elementwise addition for a tensor view and a tensor");
1058tensor_view_value_tensor_value_operation_iter!(impl Sub for TensorView { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a tensor view and a tensor");
1059tensor_view_value_tensor_value_operation!(impl Mul for TensorView 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional tensor view and a tensor");
1060
1061macro_rules! tensor_reference_tensor_view_reference_operation_iter {
1062 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1063 #[doc=$doc]
1064 impl<T, S, const D: usize> $op<&TensorView<T, S, D>> for &Tensor<T, D>
1065 where
1066 T: Numeric,
1067 for<'a> &'a T: NumericRef<T>,
1068 S: TensorRef<T, D>,
1069 {
1070 type Output = Tensor<T, D>;
1071
1072 #[track_caller]
1073 #[inline]
1074 fn $method(self, rhs: &TensorView<T, S, D>) -> Self::Output {
1075 $implementation::<T, _, _, D>(
1076 self.direct_iter_reference(),
1077 self.shape(),
1078 rhs.iter_reference(),
1079 rhs.shape(),
1080 )
1081 }
1082 }
1083 };
1084}
1085
1086macro_rules! tensor_reference_tensor_view_reference_operation {
1087 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1088 #[doc=$doc]
1089 impl<T, S> $op<&TensorView<T, S, $d>> for &Tensor<T, $d>
1090 where
1091 T: Numeric,
1092 for<'a> &'a T: NumericRef<T>,
1093 S: TensorRef<T, $d>,
1094 {
1095 type Output = Tensor<T, $d>;
1096
1097 #[track_caller]
1098 #[inline]
1099 fn $method(self, rhs: &TensorView<T, S, $d>) -> Self::Output {
1100 $implementation::<T, _, _>(self, rhs.source_ref())
1101 }
1102 }
1103 };
1104}
1105
1106tensor_reference_tensor_view_reference_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for a referenced tensor and a referenced tensor view");
1107tensor_reference_tensor_view_reference_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a referenced tensor and a referenced tensor view");
1108tensor_reference_tensor_view_reference_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional referenced tensor and a referenced tensor view");
1109
1110macro_rules! tensor_assign_tensor_view_reference_operation_iter {
1111 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1112 #[doc=$doc]
1113 impl<T, S, const D: usize> $op<&TensorView<T, S, D>> for Tensor<T, D>
1114 where
1115 T: Numeric,
1116 for<'a> &'a T: NumericRef<T>,
1117 S: TensorRef<T, D>,
1118 {
1119 #[track_caller]
1120 #[inline]
1121 fn $method(&mut self, rhs: &TensorView<T, S, D>) {
1122 let left_shape = self.shape();
1123 $implementation::<T, _, _, D>(
1124 self.direct_iter_reference_mut(),
1125 left_shape,
1126 rhs.iter_reference(),
1127 rhs.shape(),
1128 )
1129 }
1130 }
1131 };
1132}
1133
1134tensor_assign_tensor_view_reference_operation_iter!(impl AddAssign for Tensor { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for a referenced tensor and a referenced tensor view");
1135tensor_assign_tensor_view_reference_operation_iter!(impl SubAssign for Tensor { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for a referenced tensor and a referenced tensor view");
1136
1137macro_rules! tensor_reference_tensor_view_value_operation_iter {
1138 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1139 #[doc=$doc]
1140 impl<T, S, const D: usize> $op<TensorView<T, S, D>> for &Tensor<T, D>
1141 where
1142 T: Numeric,
1143 for<'a> &'a T: NumericRef<T>,
1144 S: TensorRef<T, D>,
1145 {
1146 type Output = Tensor<T, D>;
1147
1148 #[track_caller]
1149 #[inline]
1150 fn $method(self, rhs: TensorView<T, S, D>) -> Self::Output {
1151 $implementation::<T, _, _, D>(
1152 self.direct_iter_reference(),
1153 self.shape(),
1154 rhs.iter_reference(),
1155 rhs.shape(),
1156 )
1157 }
1158 }
1159 };
1160}
1161
1162macro_rules! tensor_reference_tensor_view_value_operation {
1163 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1164 #[doc=$doc]
1165 impl<T, S> $op<TensorView<T, S, $d>> for &Tensor<T, $d>
1166 where
1167 T: Numeric,
1168 for<'a> &'a T: NumericRef<T>,
1169 S: TensorRef<T, $d>,
1170 {
1171 type Output = Tensor<T, $d>;
1172
1173 #[track_caller]
1174 #[inline]
1175 fn $method(self, rhs: TensorView<T, S, $d>) -> Self::Output {
1176 $implementation::<T, _, _>(self, rhs.source_ref())
1177 }
1178 }
1179 };
1180}
1181
1182tensor_reference_tensor_view_value_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for a referenced tensor and a tensor view");
1183tensor_reference_tensor_view_value_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a referenced tensor and a tensor view");
1184tensor_reference_tensor_view_value_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional referenced tensor and a tensor view");
1185
1186macro_rules! tensor_assign_tensor_view_value_operation_iter {
1187 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1188 #[doc=$doc]
1189 impl<T, S, const D: usize> $op<TensorView<T, S, D>> for Tensor<T, D>
1190 where
1191 T: Numeric,
1192 for<'a> &'a T: NumericRef<T>,
1193 S: TensorRef<T, D>,
1194 {
1195 #[track_caller]
1196 #[inline]
1197 fn $method(&mut self, rhs: TensorView<T, S, D>) {
1198 let left_shape = self.shape();
1199 $implementation::<T, _, _, D>(
1200 self.direct_iter_reference_mut(),
1201 left_shape,
1202 rhs.iter_reference(),
1203 rhs.shape(),
1204 )
1205 }
1206 }
1207 };
1208}
1209
1210tensor_assign_tensor_view_value_operation_iter!(impl AddAssign for Tensor { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for a referenced tensor and a tensor view");
1211tensor_assign_tensor_view_value_operation_iter!(impl SubAssign for Tensor { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for a referenced tensor and a tensor view");
1212
1213macro_rules! tensor_value_tensor_view_reference_operation_iter {
1214 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1215 #[doc=$doc]
1216 impl<T, S, const D: usize> $op<&TensorView<T, S, D>> for Tensor<T, D>
1217 where
1218 T: Numeric,
1219 for<'a> &'a T: NumericRef<T>,
1220 S: TensorRef<T, D>,
1221 {
1222 type Output = Tensor<T, D>;
1223
1224 #[track_caller]
1225 #[inline]
1226 fn $method(self, rhs: &TensorView<T, S, D>) -> Self::Output {
1227 $implementation::<T, _, _, D>(
1228 self.direct_iter_reference(),
1229 self.shape(),
1230 rhs.iter_reference(),
1231 rhs.shape(),
1232 )
1233 }
1234 }
1235 };
1236}
1237
1238macro_rules! tensor_value_tensor_view_reference_operation {
1239 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1240 #[doc=$doc]
1241 impl<T, S> $op<&TensorView<T, S, $d>> for Tensor<T, $d>
1242 where
1243 T: Numeric,
1244 for<'a> &'a T: NumericRef<T>,
1245 S: TensorRef<T, $d>,
1246 {
1247 type Output = Tensor<T, $d>;
1248
1249 #[track_caller]
1250 #[inline]
1251 fn $method(self, rhs: &TensorView<T, S, $d>) -> Self::Output {
1252 $implementation::<T, _, _>(self, rhs.source_ref())
1253 }
1254 }
1255 };
1256}
1257
1258tensor_value_tensor_view_reference_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for a tensor and a referenced tensor view");
1259tensor_value_tensor_view_reference_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a tensor and a referenced tensor view");
1260tensor_value_tensor_view_reference_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional tensor and a referenced tensor view");
1261
1262macro_rules! tensor_value_tensor_view_value_operation_iter {
1263 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1264 #[doc=$doc]
1265 impl<T, S, const D: usize> $op<TensorView<T, S, D>> for Tensor<T, D>
1266 where
1267 T: Numeric,
1268 for<'a> &'a T: NumericRef<T>,
1269 S: TensorRef<T, D>,
1270 {
1271 type Output = Tensor<T, D>;
1272
1273 #[track_caller]
1274 #[inline]
1275 fn $method(self, rhs: TensorView<T, S, D>) -> Self::Output {
1276 $implementation::<T, _, _, D>(
1277 self.direct_iter_reference(),
1278 self.shape(),
1279 rhs.iter_reference(),
1280 rhs.shape(),
1281 )
1282 }
1283 }
1284 };
1285}
1286
1287macro_rules! tensor_value_tensor_view_value_operation {
1288 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1289 #[doc=$doc]
1290 impl<T, S> $op<TensorView<T, S, $d>> for Tensor<T, $d>
1291 where
1292 T: Numeric,
1293 for<'a> &'a T: NumericRef<T>,
1294 S: TensorRef<T, $d>,
1295 {
1296 type Output = Tensor<T, $d>;
1297
1298 #[track_caller]
1299 #[inline]
1300 fn $method(self, rhs: TensorView<T, S, $d>) -> Self::Output {
1301 $implementation::<T, _, _>(self, rhs.source_ref())
1302 }
1303 }
1304 };
1305}
1306
1307tensor_value_tensor_view_value_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for a tensor and a tensor view");
1308tensor_value_tensor_view_value_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for a tensor and a tensor view");
1309tensor_value_tensor_view_value_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for a 2-dimensional tensor and a tensor view");
1310
1311macro_rules! tensor_reference_tensor_reference_operation_iter {
1312 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1313 #[doc=$doc]
1314 impl<T, const D: usize> $op<&Tensor<T, D>> for &Tensor<T, D>
1315 where
1316 T: Numeric,
1317 for<'a> &'a T: NumericRef<T>,
1318 {
1319 type Output = Tensor<T, D>;
1320
1321 #[track_caller]
1322 #[inline]
1323 fn $method(self, rhs: &Tensor<T, D>) -> Self::Output {
1324 $implementation::<T, _, _, D>(
1325 self.direct_iter_reference(),
1326 self.shape(),
1327 rhs.direct_iter_reference(),
1328 rhs.shape(),
1329 )
1330 }
1331 }
1332 };
1333}
1334
1335macro_rules! tensor_reference_tensor_reference_operation {
1336 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1337 #[doc=$doc]
1338 impl<T> $op<&Tensor<T, $d>> for &Tensor<T, $d>
1339 where
1340 T: Numeric,
1341 for<'a> &'a T: NumericRef<T>,
1342 {
1343 type Output = Tensor<T, $d>;
1344
1345 #[track_caller]
1346 #[inline]
1347 fn $method(self, rhs: &Tensor<T, $d>) -> Self::Output {
1348 $implementation::<T, _, _>(self, rhs)
1349 }
1350 }
1351 };
1352}
1353
1354tensor_reference_tensor_reference_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for two referenced tensors");
1355tensor_reference_tensor_reference_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two referenced tensors");
1356tensor_reference_tensor_reference_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for two 2-dimensional referenced tensors");
1357
1358macro_rules! tensor_assign_tensor_reference_operation_iter {
1359 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1360 #[doc=$doc]
1361 impl<T, const D: usize> $op<&Tensor<T, D>> for Tensor<T, D>
1362 where
1363 T: Numeric,
1364 for<'a> &'a T: NumericRef<T>,
1365 {
1366 #[track_caller]
1367 #[inline]
1368 fn $method(&mut self, rhs: &Tensor<T, D>) {
1369 let left_shape = self.shape();
1370 $implementation::<T, _, _, D>(
1371 self.direct_iter_reference_mut(),
1372 left_shape,
1373 rhs.direct_iter_reference(),
1374 rhs.shape(),
1375 )
1376 }
1377 }
1378 };
1379}
1380
1381tensor_assign_tensor_reference_operation_iter!(impl AddAssign for Tensor { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for two referenced tensors");
1382tensor_assign_tensor_reference_operation_iter!(impl SubAssign for Tensor { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for two referenced tensors");
1383
1384macro_rules! tensor_reference_tensor_value_operation_iter {
1385 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1386 #[doc=$doc]
1387 impl<T, const D: usize> $op<Tensor<T, D>> for &Tensor<T, D>
1388 where
1389 T: Numeric,
1390 for<'a> &'a T: NumericRef<T>,
1391 {
1392 type Output = Tensor<T, D>;
1393
1394 #[track_caller]
1395 #[inline]
1396 fn $method(self, rhs: Tensor<T, D>) -> Self::Output {
1397 $implementation::<T, _, _, D>(
1398 self.direct_iter_reference(),
1399 self.shape(),
1400 rhs.direct_iter_reference(),
1401 rhs.shape(),
1402 )
1403 }
1404 }
1405 };
1406}
1407
1408macro_rules! tensor_reference_tensor_value_operation {
1409 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1410 #[doc=$doc]
1411 impl<T> $op<Tensor<T, $d>> for &Tensor<T, $d>
1412 where
1413 T: Numeric,
1414 for<'a> &'a T: NumericRef<T>,
1415 {
1416 type Output = Tensor<T, $d>;
1417
1418 #[track_caller]
1419 #[inline]
1420 fn $method(self, rhs: Tensor<T, $d>) -> Self::Output {
1421 $implementation::<T, _, _>(self, rhs)
1422 }
1423 }
1424 };
1425}
1426
1427tensor_reference_tensor_value_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for two tensors with one referenced");
1428tensor_reference_tensor_value_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensors with one referenced");
1429tensor_reference_tensor_value_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for two 2-dimensional tensors with one referenced");
1430
1431macro_rules! tensor_assign_tensor_value_operation_iter {
1432 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1433 #[doc=$doc]
1434 impl<T, const D: usize> $op<Tensor<T, D>> for Tensor<T, D>
1435 where
1436 T: Numeric,
1437 for<'a> &'a T: NumericRef<T>,
1438 {
1439 #[track_caller]
1440 #[inline]
1441 fn $method(&mut self, rhs: Tensor<T, D>) {
1442 let left_shape = self.shape();
1443 $implementation::<T, _, _, D>(
1444 self.direct_iter_reference_mut(),
1445 left_shape,
1446 rhs.direct_iter_reference(),
1447 rhs.shape(),
1448 )
1449 }
1450 }
1451 };
1452}
1453
1454tensor_assign_tensor_value_operation_iter!(impl AddAssign for Tensor { fn add_assign } tensor_view_assign_addition_iter "Elementwise assigning addition for two tensors with one referenced");
1455tensor_assign_tensor_value_operation_iter!(impl SubAssign for Tensor { fn sub_assign } tensor_view_assign_subtraction_iter "Elementwise assigning subtraction for two tensors with one referenced");
1456
1457macro_rules! tensor_value_tensor_reference_operation_iter {
1458 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1459 #[doc=$doc]
1460 impl<T, const D: usize> $op<&Tensor<T, D>> for Tensor<T, D>
1461 where
1462 T: Numeric,
1463 for<'a> &'a T: NumericRef<T>,
1464 {
1465 type Output = Tensor<T, D>;
1466
1467 #[track_caller]
1468 #[inline]
1469 fn $method(self, rhs: &Tensor<T, D>) -> Self::Output {
1470 $implementation::<T, _, _, D>(
1471 self.direct_iter_reference(),
1472 self.shape(),
1473 rhs.direct_iter_reference(),
1474 rhs.shape(),
1475 )
1476 }
1477 }
1478 };
1479}
1480
1481macro_rules! tensor_value_tensor_reference_operation {
1482 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1483 #[doc=$doc]
1484 impl<T> $op<&Tensor<T, $d>> for Tensor<T, $d>
1485 where
1486 T: Numeric,
1487 for<'a> &'a T: NumericRef<T>,
1488 {
1489 type Output = Tensor<T, $d>;
1490
1491 #[track_caller]
1492 #[inline]
1493 fn $method(self, rhs: &Tensor<T, $d>) -> Self::Output {
1494 $implementation::<T, _, _>(self, rhs)
1495 }
1496 }
1497 };
1498}
1499
1500tensor_value_tensor_reference_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for two tensors with one referenced");
1501tensor_value_tensor_reference_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensors with one referenced");
1502tensor_value_tensor_reference_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for two 2-dimensional tensors with one referenced");
1503
1504macro_rules! tensor_value_tensor_value_operation_iter {
1505 (impl $op:tt for Tensor { fn $method:ident } $implementation:ident $doc:tt) => {
1506 #[doc=$doc]
1507 impl<T, const D: usize> $op<Tensor<T, D>> for Tensor<T, D>
1508 where
1509 T: Numeric,
1510 for<'a> &'a T: NumericRef<T>,
1511 {
1512 type Output = Tensor<T, D>;
1513
1514 #[track_caller]
1515 #[inline]
1516 fn $method(self, rhs: Tensor<T, D>) -> Self::Output {
1517 $implementation::<T, _, _, D>(
1518 self.direct_iter_reference(),
1519 self.shape(),
1520 rhs.direct_iter_reference(),
1521 rhs.shape(),
1522 )
1523 }
1524 }
1525 };
1526}
1527
1528macro_rules! tensor_value_tensor_value_operation {
1529 (impl $op:tt for Tensor $d:literal { fn $method:ident } $implementation:ident $doc:tt) => {
1530 #[doc=$doc]
1531 impl<T> $op<Tensor<T, $d>> for Tensor<T, $d>
1532 where
1533 T: Numeric,
1534 for<'a> &'a T: NumericRef<T>,
1535 {
1536 type Output = Tensor<T, $d>;
1537
1538 #[track_caller]
1539 #[inline]
1540 fn $method(self, rhs: Tensor<T, $d>) -> Self::Output {
1541 $implementation::<T, _, _>(self, rhs)
1542 }
1543 }
1544 };
1545}
1546
1547tensor_value_tensor_value_operation_iter!(impl Add for Tensor { fn add } tensor_view_addition_iter "Elementwise addition for two tensors");
1548tensor_value_tensor_value_operation_iter!(impl Sub for Tensor { fn sub } tensor_view_subtraction_iter "Elementwise subtraction for two tensors");
1549tensor_value_tensor_value_operation!(impl Mul for Tensor 2 { fn mul } tensor_view_matrix_product "Matrix multiplication for two 2-dimensional tensors");
1550
#[test]
fn elementwise_addition_test_all_16_combinations() {
    fn tensor() -> Tensor<i8, 1> {
        Tensor::from([("a", 1)], vec![1])
    }
    fn tensor_view() -> TensorView<i8, Tensor<i8, 1>, 1> {
        TensorView::from(tensor())
    }
    // Every owned/referenced pairing of Tensor and TensorView on each side of `+`.
    let results = vec![
        tensor() + tensor(),
        tensor() + &tensor(),
        &tensor() + tensor(),
        &tensor() + &tensor(),
        tensor_view() + tensor(),
        tensor_view() + &tensor(),
        &tensor_view() + tensor(),
        &tensor_view() + &tensor(),
        tensor() + tensor_view(),
        tensor() + &tensor_view(),
        &tensor() + tensor_view(),
        &tensor() + &tensor_view(),
        tensor_view() + tensor_view(),
        tensor_view() + &tensor_view(),
        &tensor_view() + tensor_view(),
        &tensor_view() + &tensor_view(),
    ];
    for total in results {
        assert_eq!(total.index_by(["a"]).get([0]), 2);
    }
}
1580
#[test]
fn elementwise_addition_assign_test_all_8_combinations() {
    fn tensor() -> Tensor<i8, 1> {
        Tensor::from([("a", 1)], vec![1])
    }
    fn tensor_view() -> TensorView<i8, Tensor<i8, 1>, 1> {
        TensorView::from(tensor())
    }
    // `+=` with a Tensor or a TensorView on the left and an owned or referenced Tensor or
    // TensorView on the right: 4 combinations per left-hand type, so each vector below
    // receives exactly 4 results.
    let mut results_tensors = Vec::with_capacity(4);
    let mut results_tensor_views = Vec::with_capacity(4);
    results_tensors.push({
        let mut x = tensor();
        x += tensor();
        x
    });
    results_tensors.push({
        let mut x = tensor();
        x += &tensor();
        x
    });
    results_tensor_views.push({
        let mut x = tensor_view();
        x += tensor();
        x
    });
    results_tensor_views.push({
        let mut x = tensor_view();
        x += &tensor();
        x
    });
    results_tensors.push({
        let mut x = tensor();
        x += tensor_view();
        x
    });
    results_tensors.push({
        let mut x = tensor();
        x += &tensor_view();
        x
    });
    results_tensor_views.push({
        let mut x = tensor_view();
        x += tensor_view();
        x
    });
    results_tensor_views.push({
        let mut x = tensor_view();
        x += &tensor_view();
        x
    });
    for total in results_tensors {
        assert_eq!(total.index_by(["a"]).get([0]), 2);
    }
    for total in results_tensor_views {
        assert_eq!(total.index_by(["a"]).get([0]), 2);
    }
}
1638
#[test]
fn elementwise_addition_test() {
    // Both operands and the expected result share the same (r, c) shape.
    let shape = [("r", 2), ("c", 2)];
    let tensor_1: Tensor<i32, 2> = Tensor::from(shape, vec![1, 2, 3, 4]);
    let tensor_2: Tensor<i32, 2> = Tensor::from(shape, vec![3, 2, 8, 1]);
    let added: Tensor<i32, 2> = tensor_1 + tensor_2;
    let expected: Tensor<i32, 2> = Tensor::from(shape, vec![4, 4, 11, 5]);
    assert_eq!(added, expected);
}
1646
#[should_panic]
#[test]
fn elementwise_addition_test_similar_not_matching() {
    // The right operand names its dimensions in the opposite order ([c, r] versus [r, c]);
    // elementwise addition is expected to reject this by panicking.
    let left: Tensor<i32, 2> = Tensor::from([("r", 2), ("c", 2)], vec![1, 2, 3, 4]);
    let right: Tensor<i32, 2> = Tensor::from([("c", 2), ("r", 2)], vec![3, 8, 2, 1]);
    let _: Tensor<i32, 2> = left + right;
}
1654
#[test]
fn matrix_multiplication_test_all_16_combinations() {
    #[rustfmt::skip]
    fn tensor_1() -> Tensor<i8, 2> {
        Tensor::from([("r", 2), ("c", 3)], vec![
            1, 2, 3,
            4, 5, 6
        ])
    }
    fn tensor_1_view() -> TensorView<i8, Tensor<i8, 2>, 2> {
        TensorView::from(tensor_1())
    }
    #[rustfmt::skip]
    fn tensor_2() -> Tensor<i8, 2> {
        Tensor::from([("a", 3), ("b", 2)], vec![
            1, 2,
            3, 4,
            5, 6
        ])
    }
    fn tensor_2_view() -> TensorView<i8, Tensor<i8, 2>, 2> {
        TensorView::from(tensor_2())
    }
    // An (r, c) x (a, b) product is expected to keep the outer dimension names: (r, b).
    #[rustfmt::skip]
    let expected: Tensor<i8, 2> = Tensor::from(
        [("r", 2), ("b", 2)],
        vec![
            1 * 1 + 2 * 3 + 3 * 5, 1 * 2 + 2 * 4 + 3 * 6,
            4 * 1 + 5 * 3 + 6 * 5, 4 * 2 + 5 * 4 + 6 * 6,
        ],
    );
    // Every owned/referenced pairing of Tensor and TensorView on each side of `*`.
    let results = vec![
        tensor_1() * tensor_2(),
        tensor_1() * &tensor_2(),
        &tensor_1() * tensor_2(),
        &tensor_1() * &tensor_2(),
        tensor_1_view() * tensor_2(),
        tensor_1_view() * &tensor_2(),
        &tensor_1_view() * tensor_2(),
        &tensor_1_view() * &tensor_2(),
        tensor_1() * tensor_2_view(),
        tensor_1() * &tensor_2_view(),
        &tensor_1() * tensor_2_view(),
        &tensor_1() * &tensor_2_view(),
        tensor_1_view() * tensor_2_view(),
        tensor_1_view() * &tensor_2_view(),
        &tensor_1_view() * tensor_2_view(),
        &tensor_1_view() * &tensor_2_view(),
    ];
    for total in results {
        assert_eq!(total, expected);
    }
}
1709
#[test]
fn elementwise_subtraction_assign_test_all_8_combinations() {
    fn tensor() -> Tensor<i8, 1> {
        Tensor::from([("a", 1)], vec![1])
    }
    fn tensor_view() -> TensorView<i8, Tensor<i8, 1>, 1> {
        TensorView::from(tensor())
    }
    // `-=` with a Tensor or a TensorView on the left and an owned or referenced Tensor or
    // TensorView on the right: 4 combinations per left-hand type, so each vector below
    // receives exactly 4 results.
    let mut results_tensors = Vec::with_capacity(4);
    let mut results_tensor_views = Vec::with_capacity(4);
    results_tensors.push({
        let mut x = tensor();
        x -= tensor();
        x
    });
    results_tensors.push({
        let mut x = tensor();
        x -= &tensor();
        x
    });
    results_tensor_views.push({
        let mut x = tensor_view();
        x -= tensor();
        x
    });
    results_tensor_views.push({
        let mut x = tensor_view();
        x -= &tensor();
        x
    });
    results_tensors.push({
        let mut x = tensor();
        x -= tensor_view();
        x
    });
    results_tensors.push({
        let mut x = tensor();
        x -= &tensor_view();
        x
    });
    results_tensor_views.push({
        let mut x = tensor_view();
        x -= tensor_view();
        x
    });
    results_tensor_views.push({
        let mut x = tensor_view();
        x -= &tensor_view();
        x
    });
    for total in results_tensors {
        assert_eq!(total.index_by(["a"]).get([0]), 0);
    }
    for total in results_tensor_views {
        assert_eq!(total.index_by(["a"]).get([0]), 0);
    }
}
1767
1768impl<T> Tensor<T, 1>
1769where
1770 T: Numeric,
1771 for<'a> &'a T: NumericRef<T>,
1772{
1773 pub(crate) fn scalar_product_less_generic<S>(&self, rhs: TensorView<T, S, 1>) -> T
1774 where
1775 S: TensorRef<T, 1>,
1776 {
1777 let left_shape = self.shape();
1778 let right_shape = rhs.shape();
1779 assert_same_dimensions(left_shape, right_shape);
1780 tensor_view_vector_product_iter::<T, _, _>(
1781 self.direct_iter_reference(),
1782 left_shape,
1783 rhs.iter_reference(),
1784 right_shape,
1785 )
1786 }
1787}
1788
1789impl<T, S> TensorView<T, S, 1>
1790where
1791 T: Numeric,
1792 for<'a> &'a T: NumericRef<T>,
1793 S: TensorRef<T, 1>,
1794{
1795 pub(crate) fn scalar_product_less_generic<S2>(&self, rhs: TensorView<T, S2, 1>) -> T
1796 where
1797 S2: TensorRef<T, 1>,
1798 {
1799 let left_shape = self.shape();
1800 let right_shape = rhs.shape();
1801 assert_same_dimensions(left_shape, right_shape);
1802 tensor_view_vector_product_iter::<T, _, _>(
1803 self.iter_reference(),
1804 left_shape,
1805 rhs.iter_reference(),
1806 right_shape,
1807 )
1808 }
1809}
1810
// Generates the four `$op` impls between a `Tensor<T, D>` (owned or referenced) and a scalar
// (owned or referenced). Each impl calls `map`, applying `$method` to every element with its
// own clone of the scalar, and returns the resulting `Tensor<T, D>`.
macro_rules! tensor_scalar {
    (impl $op:tt for Tensor { fn $method:ident }) => {
        /// Applies the operation between each element of a referenced tensor and a
        /// referenced scalar, producing a new tensor.
        impl<T: Numeric, const D: usize> $op<&T> for &Tensor<T, D>
        where
            for<'a> &'a T: NumericRef<T>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $method(self, rhs: &T) -> Self::Output {
                self.map(|element| element.$method(rhs.clone()))
            }
        }

        /// Applies the operation between each element of a tensor and a referenced scalar,
        /// producing a new tensor.
        impl<T: Numeric, const D: usize> $op<&T> for Tensor<T, D>
        where
            for<'a> &'a T: NumericRef<T>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $method(self, rhs: &T) -> Self::Output {
                self.map(|element| element.$method(rhs.clone()))
            }
        }

        /// Applies the operation between each element of a referenced tensor and a scalar,
        /// producing a new tensor.
        impl<T: Numeric, const D: usize> $op<T> for &Tensor<T, D>
        where
            for<'a> &'a T: NumericRef<T>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $method(self, rhs: T) -> Self::Output {
                self.map(|element| element.$method(rhs.clone()))
            }
        }

        /// Applies the operation between each element of a tensor and a scalar, producing a
        /// new tensor.
        impl<T: Numeric, const D: usize> $op<T> for Tensor<T, D>
        where
            for<'a> &'a T: NumericRef<T>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $method(self, rhs: T) -> Self::Output {
                self.map(|element| element.$method(rhs.clone()))
            }
        }
    };
}
1874
// Generates the four `$op` impls between a `TensorView<T, S, D>` (owned or referenced) and a
// scalar (owned or referenced). Each impl calls `map`, applying `$method` to every element with
// its own clone of the scalar, and returns the resulting `Tensor<T, D>`.
macro_rules! tensor_view_scalar {
    (impl $op:tt for TensorView { fn $method:ident }) => {
        /// Applies the operation between each element of a referenced tensor view and a
        /// referenced scalar, producing a new tensor.
        impl<T, S, const D: usize> $op<&T> for &TensorView<T, S, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
            S: TensorRef<T, D>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $method(self, rhs: &T) -> Self::Output {
                self.map(|element| element.$method(rhs.clone()))
            }
        }

        /// Applies the operation between each element of a tensor view and a referenced
        /// scalar, producing a new tensor.
        impl<T, S, const D: usize> $op<&T> for TensorView<T, S, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
            S: TensorRef<T, D>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $method(self, rhs: &T) -> Self::Output {
                self.map(|element| element.$method(rhs.clone()))
            }
        }

        /// Applies the operation between each element of a referenced tensor view and a
        /// scalar, producing a new tensor.
        impl<T, S, const D: usize> $op<T> for &TensorView<T, S, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
            S: TensorRef<T, D>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $method(self, rhs: T) -> Self::Output {
                self.map(|element| element.$method(rhs.clone()))
            }
        }

        /// Applies the operation between each element of a tensor view and a scalar,
        /// producing a new tensor.
        impl<T, S, const D: usize> $op<T> for TensorView<T, S, D>
        where
            T: Numeric,
            for<'a> &'a T: NumericRef<T>,
            S: TensorRef<T, D>,
        {
            type Output = Tensor<T, D>;
            #[inline]
            fn $method(self, rhs: T) -> Self::Output {
                self.map(|element| element.$method(rhs.clone()))
            }
        }
    };
}
1946
1947tensor_scalar!(impl Add for Tensor { fn add });
1948tensor_scalar!(impl Sub for Tensor { fn sub });
1949tensor_scalar!(impl Mul for Tensor { fn mul });
1950tensor_scalar!(impl Div for Tensor { fn div });
1951
1952tensor_view_scalar!(impl Add for TensorView { fn add });
1953tensor_view_scalar!(impl Sub for TensorView { fn sub });
1954tensor_view_scalar!(impl Mul for TensorView { fn mul });
1955tensor_view_scalar!(impl Div for TensorView { fn div });