1use crate::differentiation::functions::{
2 Addition, Cosine, Division, Exponential, FunctionDerivative, Multiplication, NaturalLogarithm,
3 Negation, Power, Sine, SquareRoot, Subtraction, UnaryFunctionDerivative,
4};
5use crate::differentiation::record_operations::are_same_list;
6use crate::differentiation::{Index, Primitive, WengertList};
7use crate::differentiation::{RecordContainer, RecordMatrix, RecordTensor};
8use crate::matrices::Matrix;
9use crate::matrices::views::{MatrixMap, MatrixMut, MatrixRef, MatrixView, NoInteriorMutability};
10use crate::numeric::{Numeric, NumericRef};
11use crate::tensors::Tensor;
12use crate::tensors::views::{TensorMap, TensorMut, TensorRef, TensorView};
13
14use crate::numeric::extra::{Cos, Exp, Ln, Pow, Real, RealRef, Sin, Sqrt};
15
16use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
17
18mod swapped;
19
20impl<'a, T, S> std::fmt::Display for RecordMatrix<'a, T, S>
24where
25 T: std::fmt::Display + Primitive,
26 S: MatrixRef<(T, Index)>,
27{
28 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
29 write!(
30 f,
31 "{}",
32 MatrixView::from(MatrixMap::from(self.numbers.source_ref(), |(x, _)| x))
33 )
34 }
35}
36
37impl<'a, T, S, const D: usize> std::fmt::Display for RecordTensor<'a, T, S, D>
41where
42 T: std::fmt::Display + Primitive,
43 S: TensorRef<(T, Index), D>,
44{
45 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
46 write!(
47 f,
48 "{}",
49 TensorView::from(TensorMap::from(self.numbers.source_ref(), |(x, _)| x))
50 )
51 }
52}
53
54impl<'a, T, S, const D: usize> Clone for RecordContainer<'a, T, S, D>
58where
59 T: Clone + Primitive,
60 S: Clone,
61{
62 #[inline]
63 fn clone(&self) -> Self {
64 RecordContainer {
65 numbers: self.numbers.clone(),
66 history: self.history,
67 }
68 }
69}
70
/// Generates `impl Trait<RecordTensor> for RecordTensor`, with both operands
/// taken by value. The generated method forwards both operands to `$helper`.
macro_rules! record_tensor_operator_impl_value_value {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource, const D: usize>
            $trait_name<RecordTensor<'a, T, RightSource, D>> for RecordTensor<'a, T, LeftSource, D>
        where
            LeftSource: TensorRef<(T, Index), D>,
            RightSource: TensorRef<(T, Index), D>,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, other: RecordTensor<'a, T, RightSource, D>) -> Self::Output {
                $helper::<T, LeftSource, RightSource, D>(self, other)
            }
        }
    };
}
92
/// Generates an assigning operator impl (`*Assign<RecordTensor>`) for a
/// mutable `RecordTensor`, with the right operand taken by value. The
/// generated method forwards both operands to `$helper`.
macro_rules! record_tensor_operator_impl_value_assign {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource, const D: usize>
            $trait_name<RecordTensor<'a, T, RightSource, D>> for RecordTensor<'a, T, LeftSource, D>
        where
            LeftSource: TensorMut<(T, Index), D>,
            RightSource: TensorRef<(T, Index), D>,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            #[track_caller]
            fn $fn_name(&mut self, other: RecordTensor<'a, T, RightSource, D>) {
                $helper::<T, LeftSource, RightSource, D>(self, other)
            }
        }
    };
}
113
/// Generates `impl Trait<RecordMatrix> for RecordMatrix`, with both operands
/// taken by value. The generated method forwards both operands to `$helper`.
macro_rules! record_matrix_operator_impl_value_value {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource> $trait_name<RecordMatrix<'a, T, RightSource>>
            for RecordMatrix<'a, T, LeftSource>
        where
            LeftSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            RightSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, other: RecordMatrix<'a, T, RightSource>) -> Self::Output {
                $helper::<T, LeftSource, RightSource>(self, other)
            }
        }
    };
}
134
/// Generates an assigning operator impl (`*Assign<RecordMatrix>`) for a
/// mutable `RecordMatrix`, with the right operand taken by value. The
/// generated method forwards both operands to `$helper`.
macro_rules! record_matrix_operator_impl_value_assign {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource> $trait_name<RecordMatrix<'a, T, RightSource>>
            for RecordMatrix<'a, T, LeftSource>
        where
            LeftSource: MatrixMut<(T, Index)> + NoInteriorMutability,
            RightSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            #[track_caller]
            fn $fn_name(&mut self, other: RecordMatrix<'a, T, RightSource>) {
                $helper::<T, LeftSource, RightSource>(self, other)
            }
        }
    };
}
154
/// Generates `impl Trait<&RecordTensor> for RecordTensor`, with an owned left
/// operand and a borrowed right operand. The generated method forwards both
/// operands to `$helper`.
macro_rules! record_tensor_operator_impl_value_reference {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource, const D: usize>
            $trait_name<&RecordTensor<'a, T, RightSource, D>> for RecordTensor<'a, T, LeftSource, D>
        where
            LeftSource: TensorRef<(T, Index), D>,
            RightSource: TensorRef<(T, Index), D>,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, other: &RecordTensor<'a, T, RightSource, D>) -> Self::Output {
                $helper::<T, LeftSource, RightSource, D>(self, other)
            }
        }
    };
}
176
/// Generates an assigning operator impl (`*Assign<&RecordTensor>`) for a
/// mutable `RecordTensor`, with the right operand borrowed. The generated
/// method forwards both operands to `$helper`.
macro_rules! record_tensor_operator_impl_reference_assign {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource, const D: usize>
            $trait_name<&RecordTensor<'a, T, RightSource, D>> for RecordTensor<'a, T, LeftSource, D>
        where
            LeftSource: TensorMut<(T, Index), D>,
            RightSource: TensorRef<(T, Index), D>,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            #[track_caller]
            fn $fn_name(&mut self, other: &RecordTensor<'a, T, RightSource, D>) {
                $helper::<T, LeftSource, RightSource, D>(self, other)
            }
        }
    };
}
197
/// Generates `impl Trait<&RecordMatrix> for RecordMatrix`, with an owned left
/// operand and a borrowed right operand. The generated method forwards both
/// operands to `$helper`.
macro_rules! record_matrix_operator_impl_value_reference {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource> $trait_name<&RecordMatrix<'a, T, RightSource>>
            for RecordMatrix<'a, T, LeftSource>
        where
            LeftSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            RightSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, other: &RecordMatrix<'a, T, RightSource>) -> Self::Output {
                $helper::<T, LeftSource, RightSource>(self, other)
            }
        }
    };
}
218
/// Generates an assigning operator impl (`*Assign<&RecordMatrix>`) for a
/// mutable `RecordMatrix`, with the right operand borrowed. The generated
/// method forwards both operands to `$helper`.
macro_rules! record_matrix_operator_impl_reference_assign {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource> $trait_name<&RecordMatrix<'a, T, RightSource>>
            for RecordMatrix<'a, T, LeftSource>
        where
            LeftSource: MatrixMut<(T, Index)> + NoInteriorMutability,
            RightSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            #[track_caller]
            fn $fn_name(&mut self, other: &RecordMatrix<'a, T, RightSource>) {
                $helper::<T, LeftSource, RightSource>(self, other)
            }
        }
    };
}
238
/// Generates `impl Trait<RecordTensor> for &RecordTensor`, with a borrowed
/// left operand and an owned right operand. The generated method forwards both
/// operands to `$helper`.
macro_rules! record_tensor_operator_impl_reference_value {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource, const D: usize>
            $trait_name<RecordTensor<'a, T, RightSource, D>> for &RecordTensor<'a, T, LeftSource, D>
        where
            LeftSource: TensorRef<(T, Index), D>,
            RightSource: TensorRef<(T, Index), D>,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, other: RecordTensor<'a, T, RightSource, D>) -> Self::Output {
                $helper::<T, LeftSource, RightSource, D>(self, other)
            }
        }
    };
}
260
/// Generates `impl Trait<RecordMatrix> for &RecordMatrix`, with a borrowed
/// left operand and an owned right operand. The generated method forwards both
/// operands to `$helper`.
macro_rules! record_matrix_operator_impl_reference_value {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource> $trait_name<RecordMatrix<'a, T, RightSource>>
            for &RecordMatrix<'a, T, LeftSource>
        where
            LeftSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            RightSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, other: RecordMatrix<'a, T, RightSource>) -> Self::Output {
                $helper::<T, LeftSource, RightSource>(self, other)
            }
        }
    };
}
281
/// Generates `impl Trait<&RecordTensor> for &RecordTensor`, with both operands
/// borrowed. The generated method forwards both operands to `$helper`.
macro_rules! record_tensor_operator_impl_reference_reference {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource, const D: usize>
            $trait_name<&RecordTensor<'a, T, RightSource, D>> for &RecordTensor<'a, T, LeftSource, D>
        where
            LeftSource: TensorRef<(T, Index), D>,
            RightSource: TensorRef<(T, Index), D>,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, other: &RecordTensor<'a, T, RightSource, D>) -> Self::Output {
                $helper::<T, LeftSource, RightSource, D>(self, other)
            }
        }
    };
}
303
/// Generates `impl Trait<&RecordMatrix> for &RecordMatrix`, with both operands
/// borrowed. The generated method forwards both operands to `$helper`.
macro_rules! record_matrix_operator_impl_reference_reference {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, LeftSource, RightSource> $trait_name<&RecordMatrix<'a, T, RightSource>>
            for &RecordMatrix<'a, T, LeftSource>
        where
            LeftSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            RightSource: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, other: &RecordMatrix<'a, T, RightSource>) -> Self::Output {
                $helper::<T, LeftSource, RightSource>(self, other)
            }
        }
    };
}
324
/// Generates a unary operator impl for an owned `RecordTensor`. The generated
/// method forwards the operand to `$helper`.
macro_rules! record_tensor_operator_impl_value {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, Source, const D: usize> $trait_name for RecordTensor<'a, T, Source, D>
        where
            Source: TensorRef<(T, Index), D>,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self) -> Self::Output {
                $helper::<T, Source, D>(self)
            }
        }
    };
}
344
/// Generates a unary operator impl for an owned `RecordMatrix`. The generated
/// method forwards the operand to `$helper`.
macro_rules! record_matrix_operator_impl_value {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, Source> $trait_name for RecordMatrix<'a, T, Source>
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self) -> Self::Output {
                $helper::<T, Source>(self)
            }
        }
    };
}
364
/// Generates a unary operator impl for a borrowed `&RecordTensor`. The
/// generated method forwards the operand to `$helper`.
macro_rules! record_tensor_operator_impl_reference {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, Source, const D: usize> $trait_name for &RecordTensor<'a, T, Source, D>
        where
            Source: TensorRef<(T, Index), D>,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self) -> Self::Output {
                $helper::<T, Source, D>(self)
            }
        }
    };
}
384
/// Generates a unary operator impl for a borrowed `&RecordMatrix`. The
/// generated method forwards the operand to `$helper`.
macro_rules! record_matrix_operator_impl_reference {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $helper:ident) => {
        impl<'a, T, Source> $trait_name for &RecordMatrix<'a, T, Source>
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self) -> Self::Output {
                $helper::<T, Source>(self)
            }
        }
    };
}
404
/// Generates a unary `Real` operation impl (for example a trait like `Sin`)
/// for both `RecordTensor` and `&RecordTensor`.
///
/// Each generated method calls `unary` with `$func::<T>::function` and its
/// derivative `$func::<T>::d_function_dx`.
macro_rules! record_real_tensor_operator_impl_unary {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $func:ident) => {
        impl<'a, T, Source, const D: usize> $trait_name for RecordTensor<'a, T, Source, D>
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self) -> Self::Output {
                self.unary($func::<T>::function, $func::<T>::d_function_dx)
            }
        }

        impl<'a, T, Source, const D: usize> $trait_name for &RecordTensor<'a, T, Source, D>
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self) -> Self::Output {
                self.unary($func::<T>::function, $func::<T>::d_function_dx)
            }
        }
    };
}
440
/// Generates a unary `Real` operation impl (for example a trait like `Sin`)
/// for both `RecordMatrix` and `&RecordMatrix`.
///
/// Each generated method calls `unary` with `$func::<T>::function` and its
/// derivative `$func::<T>::d_function_dx`.
macro_rules! record_real_matrix_operator_impl_unary {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $func:ident) => {
        impl<'a, T, Source> $trait_name for RecordMatrix<'a, T, Source>
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self) -> Self::Output {
                self.unary($func::<T>::function, $func::<T>::d_function_dx)
            }
        }

        impl<'a, T, Source> $trait_name for &RecordMatrix<'a, T, Source>
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self) -> Self::Output {
                self.unary($func::<T>::function, $func::<T>::d_function_dx)
            }
        }
    };
}
476
/// Generates `Trait<T>` and `Trait<&T>` impls for `RecordTensor` and
/// `&RecordTensor`, where the tensor is the left operand and a scalar is the
/// right operand.
///
/// Each element `x` of the tensor is combined as `$func(x, scalar)`. Because
/// the tensor supplies the first argument, the derivative recorded is
/// `d_function_dx`.
macro_rules! record_real_tensor_operator_impl_scalar {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $func:ident) => {
        impl<'a, T, Source, const D: usize> $trait_name<T> for RecordTensor<'a, T, Source, D>
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, scalar: T) -> Self::Output {
                self.unary(
                    |x| $func::<T>::function(x, scalar.clone()),
                    |x| $func::<T>::d_function_dx(x, scalar.clone()),
                )
            }
        }

        impl<'a, T, Source, const D: usize> $trait_name<T> for &RecordTensor<'a, T, Source, D>
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, scalar: T) -> Self::Output {
                self.unary(
                    |x| $func::<T>::function(x, scalar.clone()),
                    |x| $func::<T>::d_function_dx(x, scalar.clone()),
                )
            }
        }

        impl<'a, T, Source, const D: usize> $trait_name<&T> for RecordTensor<'a, T, Source, D>
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, scalar: &T) -> Self::Output {
                self.unary(
                    |x| $func::<T>::function(x, scalar.clone()),
                    |x| $func::<T>::d_function_dx(x, scalar.clone()),
                )
            }
        }

        impl<'a, T, Source, const D: usize> $trait_name<&T> for &RecordTensor<'a, T, Source, D>
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, scalar: &T) -> Self::Output {
                self.unary(
                    |x| $func::<T>::function(x, scalar.clone()),
                    |x| $func::<T>::d_function_dx(x, scalar.clone()),
                )
            }
        }
    };
}
560
/// Generates `Trait<RecordTensor>` and `Trait<&RecordTensor>` impls for `T`
/// and `&T`, where a scalar is the left operand and the tensor is the right
/// operand.
///
/// Each element `x` of the tensor is combined as `$func(scalar, x)`.
macro_rules! record_real_tensor_operator_impl_scalar_no_orphan_rule {
    (impl $trait_name:tt for RecordTensor { fn $fn_name:ident } $func:ident) => {
        impl<'a, T, Source, const D: usize> $trait_name<RecordTensor<'a, T, Source, D>> for T
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, tensor: RecordTensor<'a, T, Source, D>) -> Self::Output {
                // The tensor element is the second (y) argument, so the
                // derivative recorded is d_function_dy.
                tensor.unary(
                    |x| $func::<T>::function(self.clone(), x),
                    |x| $func::<T>::d_function_dy(self.clone(), x),
                )
            }
        }

        impl<'a, T, Source, const D: usize> $trait_name<&RecordTensor<'a, T, Source, D>> for T
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, tensor: &RecordTensor<'a, T, Source, D>) -> Self::Output {
                // The tensor element is the second (y) argument.
                tensor.unary(
                    |x| $func::<T>::function(self.clone(), x),
                    |x| $func::<T>::d_function_dy(self.clone(), x),
                )
            }
        }

        impl<'a, T, Source, const D: usize> $trait_name<RecordTensor<'a, T, Source, D>> for &T
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, tensor: RecordTensor<'a, T, Source, D>) -> Self::Output {
                // The tensor element is the second (y) argument.
                tensor.unary(
                    |x| $func::<T>::function(self.clone(), x),
                    |x| $func::<T>::d_function_dy(self.clone(), x),
                )
            }
        }

        impl<'a, T, Source, const D: usize> $trait_name<&RecordTensor<'a, T, Source, D>> for &T
        where
            Source: TensorRef<(T, Index), D>,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;

            #[track_caller]
            fn $fn_name(self, tensor: &RecordTensor<'a, T, Source, D>) -> Self::Output {
                // The tensor element is the second (y) argument.
                tensor.unary(
                    |x| $func::<T>::function(self.clone(), x),
                    |x| $func::<T>::d_function_dy(self.clone(), x),
                )
            }
        }
    };
}
653
/// Generates `Trait<T>` and `Trait<&T>` impls for `RecordMatrix` and
/// `&RecordMatrix`, where the matrix is the left operand and a scalar is the
/// right operand.
///
/// Each element `x` of the matrix is combined as `$func(x, scalar)`. Because
/// the matrix supplies the first argument, the derivative recorded is
/// `d_function_dx`.
macro_rules! record_real_matrix_operator_impl_scalar {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $func:ident) => {
        impl<'a, T, Source> $trait_name<T> for RecordMatrix<'a, T, Source>
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, scalar: T) -> Self::Output {
                self.unary(
                    |x| $func::<T>::function(x, scalar.clone()),
                    |x| $func::<T>::d_function_dx(x, scalar.clone()),
                )
            }
        }

        impl<'a, T, Source> $trait_name<T> for &RecordMatrix<'a, T, Source>
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, scalar: T) -> Self::Output {
                self.unary(
                    |x| $func::<T>::function(x, scalar.clone()),
                    |x| $func::<T>::d_function_dx(x, scalar.clone()),
                )
            }
        }

        impl<'a, T, Source> $trait_name<&T> for RecordMatrix<'a, T, Source>
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, scalar: &T) -> Self::Output {
                self.unary(
                    |x| $func::<T>::function(x, scalar.clone()),
                    |x| $func::<T>::d_function_dx(x, scalar.clone()),
                )
            }
        }

        impl<'a, T, Source> $trait_name<&T> for &RecordMatrix<'a, T, Source>
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, scalar: &T) -> Self::Output {
                self.unary(
                    |x| $func::<T>::function(x, scalar.clone()),
                    |x| $func::<T>::d_function_dx(x, scalar.clone()),
                )
            }
        }
    };
}
737
/// Generates `Trait<RecordMatrix>` and `Trait<&RecordMatrix>` impls for `T`
/// and `&T`, where a scalar is the left operand and the matrix is the right
/// operand.
///
/// Each element `x` of the matrix is combined as `$func(scalar, x)`.
macro_rules! record_real_matrix_operator_impl_scalar_no_orphan_rule {
    (impl $trait_name:tt for RecordMatrix { fn $fn_name:ident } $func:ident) => {
        impl<'a, T, Source> $trait_name<RecordMatrix<'a, T, Source>> for T
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, matrix: RecordMatrix<'a, T, Source>) -> Self::Output {
                // The matrix element is the second (y) argument, so the
                // derivative recorded is d_function_dy.
                matrix.unary(
                    |x| $func::<T>::function(self.clone(), x),
                    |x| $func::<T>::d_function_dy(self.clone(), x),
                )
            }
        }

        impl<'a, T, Source> $trait_name<&RecordMatrix<'a, T, Source>> for T
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, matrix: &RecordMatrix<'a, T, Source>) -> Self::Output {
                // The matrix element is the second (y) argument.
                matrix.unary(
                    |x| $func::<T>::function(self.clone(), x),
                    |x| $func::<T>::d_function_dy(self.clone(), x),
                )
            }
        }

        impl<'a, T, Source> $trait_name<RecordMatrix<'a, T, Source>> for &T
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, matrix: RecordMatrix<'a, T, Source>) -> Self::Output {
                // The matrix element is the second (y) argument.
                matrix.unary(
                    |x| $func::<T>::function(self.clone(), x),
                    |x| $func::<T>::d_function_dy(self.clone(), x),
                )
            }
        }

        impl<'a, T, Source> $trait_name<&RecordMatrix<'a, T, Source>> for &T
        where
            Source: MatrixRef<(T, Index)> + NoInteriorMutability,
            T: Real + Primitive,
            for<'t> &'t T: RealRef<T>,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;

            #[track_caller]
            fn $fn_name(self, matrix: &RecordMatrix<'a, T, Source>) -> Self::Output {
                // The matrix element is the second (y) argument.
                matrix.unary(
                    |x| $func::<T>::function(self.clone(), x),
                    |x| $func::<T>::d_function_dy(self.clone(), x),
                )
            }
        }
    };
}
830
831#[track_caller]
832fn record_tensor_add_allocate<'a, T, S1, S2, const D: usize>(
833 lhs: &RecordTensor<'a, T, S1, D>,
834 rhs: &RecordTensor<'a, T, S2, D>,
835) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
836where
837 T: Numeric + Primitive,
838 for<'t> &'t T: NumericRef<T>,
839 S1: TensorRef<(T, Index), D>,
840 S2: TensorRef<(T, Index), D>,
841{
842 assert!(
843 are_same_list(lhs.history, rhs.history),
844 "Record containers must be using the same WengertList"
845 );
846 lhs.binary(
847 rhs,
848 Addition::<T>::function,
849 Addition::<T>::d_function_dx,
850 Addition::<T>::d_function_dy,
851 )
852}
853
854#[track_caller]
855fn record_tensor_add_value_value<'a, T, S1, S2, const D: usize>(
856 lhs: RecordTensor<'a, T, S1, D>,
857 rhs: RecordTensor<'a, T, S2, D>,
858) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
859where
860 T: Numeric + Primitive,
861 for<'t> &'t T: NumericRef<T>,
862 S1: TensorRef<(T, Index), D>,
863 S2: TensorRef<(T, Index), D>,
864{
865 record_tensor_add_allocate::<T, S1, S2, D>(&lhs, &rhs)
866}
867
868#[track_caller]
869fn record_tensor_add_value_reference<'a, T, S1, S2, const D: usize>(
870 lhs: RecordTensor<'a, T, S1, D>,
871 rhs: &RecordTensor<'a, T, S2, D>,
872) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
873where
874 T: Numeric + Primitive,
875 for<'t> &'t T: NumericRef<T>,
876 S1: TensorRef<(T, Index), D>,
877 S2: TensorRef<(T, Index), D>,
878{
879 record_tensor_add_allocate::<T, S1, S2, D>(&lhs, rhs)
880}
881
882#[track_caller]
883fn record_tensor_add_reference_value<'a, T, S1, S2, const D: usize>(
884 lhs: &RecordTensor<'a, T, S1, D>,
885 rhs: RecordTensor<'a, T, S2, D>,
886) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
887where
888 T: Numeric + Primitive,
889 for<'t> &'t T: NumericRef<T>,
890 S1: TensorRef<(T, Index), D>,
891 S2: TensorRef<(T, Index), D>,
892{
893 record_tensor_add_allocate::<T, S1, S2, D>(lhs, &rhs)
894}
895
896record_tensor_operator_impl_value_value!(impl Add for RecordTensor { fn add } record_tensor_add_value_value);
897record_tensor_operator_impl_value_reference!(impl Add for RecordTensor { fn add } record_tensor_add_value_reference);
898record_tensor_operator_impl_reference_value!(impl Add for RecordTensor { fn add } record_tensor_add_reference_value);
899record_tensor_operator_impl_reference_reference!(impl Add for RecordTensor { fn add } record_tensor_add_allocate);
900
901#[track_caller]
902fn record_tensor_add_reference_assign<'a, T, S1, S2, const D: usize>(
903 lhs: &mut RecordTensor<'a, T, S1, D>,
904 rhs: &RecordTensor<'a, T, S2, D>,
905) where
906 T: Numeric + Primitive,
907 for<'t> &'t T: NumericRef<T>,
908 S1: TensorMut<(T, Index), D>,
909 S2: TensorRef<(T, Index), D>,
910{
911 assert!(
912 are_same_list(lhs.history, rhs.history),
913 "Record containers must be using the same WengertList"
914 );
915 lhs.binary_left_assign(
916 rhs,
917 Addition::<T>::function,
918 Addition::<T>::d_function_dx,
919 Addition::<T>::d_function_dy,
920 );
921}
922
923#[track_caller]
924fn record_tensor_add_value_assign<'a, T, S1, S2, const D: usize>(
925 lhs: &mut RecordTensor<'a, T, S1, D>,
926 rhs: RecordTensor<'a, T, S2, D>,
927) where
928 T: Numeric + Primitive,
929 for<'t> &'t T: NumericRef<T>,
930 S1: TensorMut<(T, Index), D>,
931 S2: TensorRef<(T, Index), D>,
932{
933 record_tensor_add_reference_assign::<T, S1, S2, D>(lhs, &rhs);
934}
935
936record_tensor_operator_impl_value_assign!(impl AddAssign for RecordTensor { fn add_assign } record_tensor_add_value_assign);
937record_tensor_operator_impl_reference_assign!(impl AddAssign for RecordTensor { fn add_assign } record_tensor_add_reference_assign);
938
939#[track_caller]
940fn record_matrix_add_allocate<'a, T, S1, S2>(
941 lhs: &RecordMatrix<'a, T, S1>,
942 rhs: &RecordMatrix<'a, T, S2>,
943) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
944where
945 T: Numeric + Primitive,
946 for<'t> &'t T: NumericRef<T>,
947 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
948 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
949{
950 assert!(
951 are_same_list(lhs.history, rhs.history),
952 "Record containers must be using the same WengertList"
953 );
954 lhs.binary(
955 rhs,
956 Addition::<T>::function,
957 Addition::<T>::d_function_dx,
958 Addition::<T>::d_function_dy,
959 )
960}
961
962#[track_caller]
963fn record_matrix_add_value_value<'a, T, S1, S2>(
964 lhs: RecordMatrix<'a, T, S1>,
965 rhs: RecordMatrix<'a, T, S2>,
966) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
967where
968 T: Numeric + Primitive,
969 for<'t> &'t T: NumericRef<T>,
970 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
971 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
972{
973 record_matrix_add_allocate::<T, S1, S2>(&lhs, &rhs)
974}
975
976#[track_caller]
977fn record_matrix_add_value_reference<'a, T, S1, S2>(
978 lhs: RecordMatrix<'a, T, S1>,
979 rhs: &RecordMatrix<'a, T, S2>,
980) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
981where
982 T: Numeric + Primitive,
983 for<'t> &'t T: NumericRef<T>,
984 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
985 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
986{
987 record_matrix_add_allocate::<T, S1, S2>(&lhs, rhs)
988}
989
990#[track_caller]
991fn record_matrix_add_reference_value<'a, T, S1, S2>(
992 lhs: &RecordMatrix<'a, T, S1>,
993 rhs: RecordMatrix<'a, T, S2>,
994) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
995where
996 T: Numeric + Primitive,
997 for<'t> &'t T: NumericRef<T>,
998 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
999 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1000{
1001 record_matrix_add_allocate::<T, S1, S2>(lhs, &rhs)
1002}
1003
1004record_matrix_operator_impl_value_value!(impl Add for RecordMatrix { fn add } record_matrix_add_value_value);
1005record_matrix_operator_impl_value_reference!(impl Add for RecordMatrix { fn add } record_matrix_add_value_reference);
1006record_matrix_operator_impl_reference_value!(impl Add for RecordMatrix { fn add } record_matrix_add_reference_value);
1007record_matrix_operator_impl_reference_reference!(impl Add for RecordMatrix { fn add } record_matrix_add_allocate);
1008
1009#[track_caller]
1010fn record_matrix_add_reference_assign<'a, T, S1, S2>(
1011 lhs: &mut RecordMatrix<'a, T, S1>,
1012 rhs: &RecordMatrix<'a, T, S2>,
1013) where
1014 T: Numeric + Primitive,
1015 for<'t> &'t T: NumericRef<T>,
1016 S1: MatrixMut<(T, Index)> + NoInteriorMutability,
1017 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1018{
1019 assert!(
1020 are_same_list(lhs.history, rhs.history),
1021 "Record containers must be using the same WengertList"
1022 );
1023 lhs.binary_left_assign(
1024 rhs,
1025 Addition::<T>::function,
1026 Addition::<T>::d_function_dx,
1027 Addition::<T>::d_function_dy,
1028 )
1029}
1030
1031#[track_caller]
1032fn record_matrix_add_value_assign<'a, T, S1, S2>(
1033 lhs: &mut RecordMatrix<'a, T, S1>,
1034 rhs: RecordMatrix<'a, T, S2>,
1035) where
1036 T: Numeric + Primitive,
1037 for<'t> &'t T: NumericRef<T>,
1038 S1: MatrixMut<(T, Index)> + NoInteriorMutability,
1039 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1040{
1041 record_matrix_add_reference_assign::<T, S1, S2>(lhs, &rhs);
1042}
1043
1044record_matrix_operator_impl_value_assign!(impl AddAssign for RecordMatrix { fn add_assign } record_matrix_add_value_assign);
1045record_matrix_operator_impl_reference_assign!(impl AddAssign for RecordMatrix { fn add_assign } record_matrix_add_reference_assign);
1046
1047#[track_caller]
1048fn record_tensor_sub_allocate<'a, T, S1, S2, const D: usize>(
1049 lhs: &RecordTensor<'a, T, S1, D>,
1050 rhs: &RecordTensor<'a, T, S2, D>,
1051) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1052where
1053 T: Numeric + Primitive,
1054 for<'t> &'t T: NumericRef<T>,
1055 S1: TensorRef<(T, Index), D>,
1056 S2: TensorRef<(T, Index), D>,
1057{
1058 assert!(
1059 are_same_list(lhs.history, rhs.history),
1060 "Record containers must be using the same WengertList"
1061 );
1062 lhs.binary(
1063 rhs,
1064 Subtraction::<T>::function,
1065 Subtraction::<T>::d_function_dx,
1066 Subtraction::<T>::d_function_dy,
1067 )
1068}
1069
1070#[track_caller]
1071fn record_tensor_sub_value_value<'a, T, S1, S2, const D: usize>(
1072 lhs: RecordTensor<'a, T, S1, D>,
1073 rhs: RecordTensor<'a, T, S2, D>,
1074) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1075where
1076 T: Numeric + Primitive,
1077 for<'t> &'t T: NumericRef<T>,
1078 S1: TensorRef<(T, Index), D>,
1079 S2: TensorRef<(T, Index), D>,
1080{
1081 record_tensor_sub_allocate::<T, S1, S2, D>(&lhs, &rhs)
1082}
1083
1084#[track_caller]
1085fn record_tensor_sub_value_reference<'a, T, S1, S2, const D: usize>(
1086 lhs: RecordTensor<'a, T, S1, D>,
1087 rhs: &RecordTensor<'a, T, S2, D>,
1088) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1089where
1090 T: Numeric + Primitive,
1091 for<'t> &'t T: NumericRef<T>,
1092 S1: TensorRef<(T, Index), D>,
1093 S2: TensorRef<(T, Index), D>,
1094{
1095 record_tensor_sub_allocate::<T, S1, S2, D>(&lhs, rhs)
1096}
1097
1098#[track_caller]
1099fn record_tensor_sub_reference_value<'a, T, S1, S2, const D: usize>(
1100 lhs: &RecordTensor<'a, T, S1, D>,
1101 rhs: RecordTensor<'a, T, S2, D>,
1102) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1103where
1104 T: Numeric + Primitive,
1105 for<'t> &'t T: NumericRef<T>,
1106 S1: TensorRef<(T, Index), D>,
1107 S2: TensorRef<(T, Index), D>,
1108{
1109 record_tensor_sub_allocate::<T, S1, S2, D>(lhs, &rhs)
1110}
1111
1112record_tensor_operator_impl_value_value!(impl Sub for RecordTensor { fn sub } record_tensor_sub_value_value);
1113record_tensor_operator_impl_value_reference!(impl Sub for RecordTensor { fn sub } record_tensor_sub_value_reference);
1114record_tensor_operator_impl_reference_value!(impl Sub for RecordTensor { fn sub } record_tensor_sub_reference_value);
1115record_tensor_operator_impl_reference_reference!(impl Sub for RecordTensor { fn sub } record_tensor_sub_allocate);
1116
1117#[track_caller]
1118fn record_tensor_sub_reference_assign<'a, T, S1, S2, const D: usize>(
1119 lhs: &mut RecordTensor<'a, T, S1, D>,
1120 rhs: &RecordTensor<'a, T, S2, D>,
1121) where
1122 T: Numeric + Primitive,
1123 for<'t> &'t T: NumericRef<T>,
1124 S1: TensorMut<(T, Index), D>,
1125 S2: TensorRef<(T, Index), D>,
1126{
1127 assert!(
1128 are_same_list(lhs.history, rhs.history),
1129 "Record containers must be using the same WengertList"
1130 );
1131 lhs.binary_left_assign(
1132 rhs,
1133 Subtraction::<T>::function,
1134 Subtraction::<T>::d_function_dx,
1135 Subtraction::<T>::d_function_dy,
1136 );
1137}
1138
1139#[track_caller]
1140fn record_tensor_sub_value_assign<'a, T, S1, S2, const D: usize>(
1141 lhs: &mut RecordTensor<'a, T, S1, D>,
1142 rhs: RecordTensor<'a, T, S2, D>,
1143) where
1144 T: Numeric + Primitive,
1145 for<'t> &'t T: NumericRef<T>,
1146 S1: TensorMut<(T, Index), D>,
1147 S2: TensorRef<(T, Index), D>,
1148{
1149 record_tensor_sub_reference_assign::<T, S1, S2, D>(lhs, &rhs);
1150}
1151
1152record_tensor_operator_impl_value_assign!(impl SubAssign for RecordTensor { fn sub_assign } record_tensor_sub_value_assign);
1153record_tensor_operator_impl_reference_assign!(impl SubAssign for RecordTensor { fn sub_assign } record_tensor_sub_reference_assign);
1154
1155#[track_caller]
1156fn record_matrix_sub_allocate<'a, T, S1, S2>(
1157 lhs: &RecordMatrix<'a, T, S1>,
1158 rhs: &RecordMatrix<'a, T, S2>,
1159) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1160where
1161 T: Numeric + Primitive,
1162 for<'t> &'t T: NumericRef<T>,
1163 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1164 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1165{
1166 assert!(
1167 are_same_list(lhs.history, rhs.history),
1168 "Record containers must be using the same WengertList"
1169 );
1170 lhs.binary(
1171 rhs,
1172 Subtraction::<T>::function,
1173 Subtraction::<T>::d_function_dx,
1174 Subtraction::<T>::d_function_dy,
1175 )
1176}
1177
1178#[track_caller]
1179fn record_matrix_sub_value_value<'a, T, S1, S2>(
1180 lhs: RecordMatrix<'a, T, S1>,
1181 rhs: RecordMatrix<'a, T, S2>,
1182) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1183where
1184 T: Numeric + Primitive,
1185 for<'t> &'t T: NumericRef<T>,
1186 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1187 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1188{
1189 record_matrix_sub_allocate::<T, S1, S2>(&lhs, &rhs)
1190}
1191
1192#[track_caller]
1193fn record_matrix_sub_value_reference<'a, T, S1, S2>(
1194 lhs: RecordMatrix<'a, T, S1>,
1195 rhs: &RecordMatrix<'a, T, S2>,
1196) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1197where
1198 T: Numeric + Primitive,
1199 for<'t> &'t T: NumericRef<T>,
1200 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1201 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1202{
1203 record_matrix_sub_allocate::<T, S1, S2>(&lhs, rhs)
1204}
1205
1206#[track_caller]
1207fn record_matrix_sub_reference_value<'a, T, S1, S2>(
1208 lhs: &RecordMatrix<'a, T, S1>,
1209 rhs: RecordMatrix<'a, T, S2>,
1210) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1211where
1212 T: Numeric + Primitive,
1213 for<'t> &'t T: NumericRef<T>,
1214 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1215 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1216{
1217 record_matrix_sub_allocate::<T, S1, S2>(lhs, &rhs)
1218}
1219
// `Sub` for `RecordMatrix` in every combination of owned and borrowed operands; each
// combination is routed to the helper function named as the last macro argument.
record_matrix_operator_impl_value_value!(impl Sub for RecordMatrix { fn sub } record_matrix_sub_value_value);
record_matrix_operator_impl_value_reference!(impl Sub for RecordMatrix { fn sub } record_matrix_sub_value_reference);
record_matrix_operator_impl_reference_value!(impl Sub for RecordMatrix { fn sub } record_matrix_sub_reference_value);
record_matrix_operator_impl_reference_reference!(impl Sub for RecordMatrix { fn sub } record_matrix_sub_allocate);
1224
1225#[track_caller]
1226fn record_matrix_sub_reference_assign<'a, T, S1, S2>(
1227 lhs: &mut RecordMatrix<'a, T, S1>,
1228 rhs: &RecordMatrix<'a, T, S2>,
1229) where
1230 T: Numeric + Primitive,
1231 for<'t> &'t T: NumericRef<T>,
1232 S1: MatrixMut<(T, Index)> + NoInteriorMutability,
1233 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1234{
1235 assert!(
1236 are_same_list(lhs.history, rhs.history),
1237 "Record containers must be using the same WengertList"
1238 );
1239 lhs.binary_left_assign(
1240 rhs,
1241 Subtraction::<T>::function,
1242 Subtraction::<T>::d_function_dx,
1243 Subtraction::<T>::d_function_dy,
1244 );
1245}
1246
1247#[track_caller]
1248fn record_matrix_sub_value_assign<'a, T, S1, S2>(
1249 lhs: &mut RecordMatrix<'a, T, S1>,
1250 rhs: RecordMatrix<'a, T, S2>,
1251) where
1252 T: Numeric + Primitive,
1253 for<'t> &'t T: NumericRef<T>,
1254 S1: MatrixMut<(T, Index)> + NoInteriorMutability,
1255 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1256{
1257 record_matrix_sub_reference_assign::<T, S1, S2>(lhs, &rhs);
1258}
1259
// `SubAssign` for `RecordMatrix`, with the right hand side either owned or borrowed.
record_matrix_operator_impl_value_assign!(impl SubAssign for RecordMatrix { fn sub_assign } record_matrix_sub_value_assign);
record_matrix_operator_impl_reference_assign!(impl SubAssign for RecordMatrix { fn sub_assign } record_matrix_sub_reference_assign);
1262
1263fn record_scalar_product<'l, 'r, T, S1, S2>(
1264 left_iter: S1,
1265 right_iter: S2,
1266 history: Option<&WengertList<T>>,
1267) -> (T, Index)
1268where
1269 T: Numeric + Primitive,
1270 T: 'l,
1271 T: 'r,
1272 for<'a> &'a T: NumericRef<T>,
1273 S1: Iterator<Item = &'l (T, Index)>,
1274 S2: Iterator<Item = &'r (T, Index)>,
1275{
1276 match history {
1277 None => (
1278 crate::tensors::operations::scalar_product::<T, _, _>(
1279 left_iter.map(|(x, _)| x),
1280 right_iter.map(|(y, _)| y),
1281 ),
1282 0,
1283 ),
1284 Some(history) => {
1285 let products = left_iter
1286 .zip(right_iter)
1287 .map(|((x, x_index), (y, y_index))| {
1288 let z = Multiplication::<T>::function(x.clone(), y.clone());
1289 (
1290 z,
1291 history.append_binary(
1292 *x_index,
1293 Multiplication::<T>::d_function_dx(x.clone(), y.clone()),
1294 *y_index,
1295 Multiplication::<T>::d_function_dy(x.clone(), y.clone()),
1296 ),
1297 )
1298 });
1299 products
1300 .reduce(|(x, x_index), (y, y_index)| {
1301 let z = Addition::<T>::function(x.clone(), y.clone());
1302 (
1303 z,
1304 history.append_binary(
1305 x_index,
1306 Addition::<T>::d_function_dx(x.clone(), y.clone()),
1307 y_index,
1308 Addition::<T>::d_function_dy(x, y),
1309 ),
1310 )
1311 })
1312 .unwrap() }
1314 }
1315}
1316
1317#[track_caller]
1318fn record_tensor_matrix_multiply<'a, T, S1, S2>(
1319 lhs: &RecordTensor<'a, T, S1, 2>,
1320 rhs: &RecordTensor<'a, T, S2, 2>,
1321) -> RecordTensor<'a, T, Tensor<(T, Index), 2>, 2>
1322where
1323 T: Numeric + Primitive,
1324 for<'t> &'t T: NumericRef<T>,
1325 S1: TensorRef<(T, Index), 2>,
1326 S2: TensorRef<(T, Index), 2>,
1327{
1328 use crate::tensors::indexing::TensorReferenceIterator;
1329 use crate::tensors::views::TensorIndex;
1330
1331 assert!(
1332 are_same_list(lhs.history, rhs.history),
1333 "Record containers must be using the same WengertList"
1334 );
1335
1336 let left_shape = lhs.view_shape();
1338 let right_shape = rhs.view_shape();
1339 if left_shape[1].1 != right_shape[0].1 {
1340 panic!(
1341 "Mismatched record tensors, left is {:?}, right is {:?}, * is only defined for MxN * NxL dimension lengths",
1342 lhs.view_shape(),
1343 rhs.view_shape()
1344 );
1345 }
1346 if left_shape[0].0 == right_shape[1].0 {
1347 panic!(
1348 "Matrix multiplication of record tensors with shapes left {:?} and right {:?} would \
1349 create duplicate dimension names as the shape {:?}. Rename one or both of the \
1350 dimension names in the input to prevent this. * is defined as MxN * NxL = MxL",
1351 left_shape,
1352 right_shape,
1353 [left_shape[0], right_shape[1]]
1354 )
1355 }
1356
1357 let history = match (lhs.history, rhs.history) {
1358 (None, None) => None,
1359 (Some(history), _) => Some(history),
1360 (_, Some(history)) => Some(history),
1361 };
1362
1363 let mut tensor = Tensor::empty([lhs.view_shape()[0], rhs.view_shape()[1]], (T::zero(), 0));
1368 for ([i, j], x) in tensor.iter_reference_mut().with_index() {
1369 let left = TensorIndex::from(&lhs, [(lhs.view_shape()[0].0, i)]);
1371 let right = TensorIndex::from(&rhs, [(rhs.view_shape()[1].0, j)]);
1373 *x = record_scalar_product::<T, _, _>(
1375 TensorReferenceIterator::from(&left),
1376 TensorReferenceIterator::from(&right),
1377 history,
1378 )
1379 }
1380 RecordTensor::from_existing(history, TensorView::from(tensor))
1381}
1382
1383#[track_caller]
1384fn record_tensor_matrix_multiply_value_value<'a, T, S1, S2>(
1385 lhs: RecordTensor<'a, T, S1, 2>,
1386 rhs: RecordTensor<'a, T, S2, 2>,
1387) -> RecordTensor<'a, T, Tensor<(T, Index), 2>, 2>
1388where
1389 T: Numeric + Primitive,
1390 for<'t> &'t T: NumericRef<T>,
1391 S1: TensorRef<(T, Index), 2>,
1392 S2: TensorRef<(T, Index), 2>,
1393{
1394 record_tensor_matrix_multiply::<T, S1, S2>(&lhs, &rhs)
1395}
1396
1397#[track_caller]
1398fn record_tensor_matrix_multiply_value_reference<'a, T, S1, S2>(
1399 lhs: RecordTensor<'a, T, S1, 2>,
1400 rhs: &RecordTensor<'a, T, S2, 2>,
1401) -> RecordTensor<'a, T, Tensor<(T, Index), 2>, 2>
1402where
1403 T: Numeric + Primitive,
1404 for<'t> &'t T: NumericRef<T>,
1405 S1: TensorRef<(T, Index), 2>,
1406 S2: TensorRef<(T, Index), 2>,
1407{
1408 record_tensor_matrix_multiply::<T, S1, S2>(&lhs, rhs)
1409}
1410
1411#[track_caller]
1412fn record_tensor_matrix_multiply_reference_value<'a, T, S1, S2>(
1413 lhs: &RecordTensor<'a, T, S1, 2>,
1414 rhs: RecordTensor<'a, T, S2, 2>,
1415) -> RecordTensor<'a, T, Tensor<(T, Index), 2>, 2>
1416where
1417 T: Numeric + Primitive,
1418 for<'t> &'t T: NumericRef<T>,
1419 S1: TensorRef<(T, Index), 2>,
1420 S2: TensorRef<(T, Index), 2>,
1421{
1422 record_tensor_matrix_multiply::<T, S1, S2>(lhs, &rhs)
1423}
1424
1425impl<'a, T, S1, S2> Mul<&RecordTensor<'a, T, S2, 2>> for &RecordTensor<'a, T, S1, 2>
1429where
1430 T: Numeric + Primitive,
1431 for<'t> &'t T: NumericRef<T>,
1432 S1: TensorRef<(T, Index), 2>,
1433 S2: TensorRef<(T, Index), 2>,
1434{
1435 type Output = RecordTensor<'a, T, Tensor<(T, Index), 2>, 2>;
1436 #[track_caller]
1437 fn mul(self, rhs: &RecordTensor<'a, T, S2, 2>) -> Self::Output {
1438 record_tensor_matrix_multiply::<T, S1, S2>(self, rhs)
1439 }
1440}
1441
1442impl<'a, T, S1, S2> Mul<RecordTensor<'a, T, S2, 2>> for RecordTensor<'a, T, S1, 2>
1446where
1447 T: Numeric + Primitive,
1448 for<'t> &'t T: NumericRef<T>,
1449 S1: TensorRef<(T, Index), 2>,
1450 S2: TensorRef<(T, Index), 2>,
1451{
1452 type Output = RecordTensor<'a, T, Tensor<(T, Index), 2>, 2>;
1453 #[track_caller]
1454 fn mul(self, rhs: RecordTensor<'a, T, S2, 2>) -> Self::Output {
1455 record_tensor_matrix_multiply_value_value::<T, S1, S2>(self, rhs)
1456 }
1457}
1458
1459impl<'a, T, S1, S2> Mul<&RecordTensor<'a, T, S2, 2>> for RecordTensor<'a, T, S1, 2>
1463where
1464 T: Numeric + Primitive,
1465 for<'t> &'t T: NumericRef<T>,
1466 S1: TensorRef<(T, Index), 2>,
1467 S2: TensorRef<(T, Index), 2>,
1468{
1469 type Output = RecordTensor<'a, T, Tensor<(T, Index), 2>, 2>;
1470 #[track_caller]
1471 fn mul(self, rhs: &RecordTensor<'a, T, S2, 2>) -> Self::Output {
1472 record_tensor_matrix_multiply_value_reference::<T, S1, S2>(self, rhs)
1473 }
1474}
1475
1476impl<'a, T, S1, S2> Mul<RecordTensor<'a, T, S2, 2>> for &RecordTensor<'a, T, S1, 2>
1480where
1481 T: Numeric + Primitive,
1482 for<'t> &'t T: NumericRef<T>,
1483 S1: TensorRef<(T, Index), 2>,
1484 S2: TensorRef<(T, Index), 2>,
1485{
1486 type Output = RecordTensor<'a, T, Tensor<(T, Index), 2>, 2>;
1487 #[track_caller]
1488 fn mul(self, rhs: RecordTensor<'a, T, S2, 2>) -> Self::Output {
1489 record_tensor_matrix_multiply_reference_value::<T, S1, S2>(self, rhs)
1490 }
1491}
1492
1493#[track_caller]
1494fn record_matrix_matrix_multiply<'a, T, S1, S2>(
1495 lhs: &RecordMatrix<'a, T, S1>,
1496 rhs: &RecordMatrix<'a, T, S2>,
1497) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1498where
1499 T: Numeric + Primitive,
1500 for<'t> &'t T: NumericRef<T>,
1501 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1502 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1503{
1504 use crate::matrices::iterators::{ColumnReferenceIterator, RowReferenceIterator};
1505 assert!(
1507 lhs.view_columns() == rhs.view_rows(),
1508 "Mismatched Matrices, left is {}x{}, right is {}x{}, * is only defined for MxN * NxL",
1509 lhs.view_rows(),
1510 lhs.view_columns(),
1511 rhs.view_rows(),
1512 rhs.view_columns()
1513 );
1514
1515 let history = match (lhs.history, rhs.history) {
1516 (None, None) => None,
1517 (Some(history), _) => Some(history),
1518 (_, Some(history)) => Some(history),
1519 };
1520
1521 let mut result = Matrix::empty((T::zero(), 0), (lhs.view_rows(), rhs.view_columns()));
1522 for ((i, j), x) in result.row_major_reference_mut_iter().with_index() {
1523 let left = RowReferenceIterator::from(lhs, i);
1525 let right = ColumnReferenceIterator::from(rhs, j);
1527 *x = record_scalar_product::<T, _, _>(left, right, history);
1529 }
1530 RecordMatrix::from_existing(history, MatrixView::from(result))
1531}
1532
1533#[track_caller]
1534fn record_matrix_matrix_multiply_value_value<'a, T, S1, S2>(
1535 lhs: RecordMatrix<'a, T, S1>,
1536 rhs: RecordMatrix<'a, T, S2>,
1537) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1538where
1539 T: Numeric + Primitive,
1540 for<'t> &'t T: NumericRef<T>,
1541 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1542 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1543{
1544 record_matrix_matrix_multiply::<T, S1, S2>(&lhs, &rhs)
1545}
1546
1547#[track_caller]
1548fn record_matrix_matrix_multiply_value_reference<'a, T, S1, S2>(
1549 lhs: RecordMatrix<'a, T, S1>,
1550 rhs: &RecordMatrix<'a, T, S2>,
1551) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1552where
1553 T: Numeric + Primitive,
1554 for<'t> &'t T: NumericRef<T>,
1555 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1556 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1557{
1558 record_matrix_matrix_multiply::<T, S1, S2>(&lhs, rhs)
1559}
1560
1561#[track_caller]
1562fn record_matrix_matrix_multiply_reference_value<'a, T, S1, S2>(
1563 lhs: &RecordMatrix<'a, T, S1>,
1564 rhs: RecordMatrix<'a, T, S2>,
1565) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1566where
1567 T: Numeric + Primitive,
1568 for<'t> &'t T: NumericRef<T>,
1569 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1570 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1571{
1572 record_matrix_matrix_multiply::<T, S1, S2>(lhs, &rhs)
1573}
1574
1575impl<'a, T, S1, S2> Mul<&RecordMatrix<'a, T, S2>> for &RecordMatrix<'a, T, S1>
1579where
1580 T: Numeric + Primitive,
1581 for<'t> &'t T: NumericRef<T>,
1582 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1583 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1584{
1585 type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
1586 #[track_caller]
1587 fn mul(self, rhs: &RecordMatrix<'a, T, S2>) -> Self::Output {
1588 record_matrix_matrix_multiply::<T, S1, S2>(self, rhs)
1589 }
1590}
1591
1592impl<'a, T, S1, S2> Mul<RecordMatrix<'a, T, S2>> for RecordMatrix<'a, T, S1>
1596where
1597 T: Numeric + Primitive,
1598 for<'t> &'t T: NumericRef<T>,
1599 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1600 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1601{
1602 type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
1603 #[track_caller]
1604 fn mul(self, rhs: RecordMatrix<'a, T, S2>) -> Self::Output {
1605 record_matrix_matrix_multiply_value_value::<T, S1, S2>(self, rhs)
1606 }
1607}
1608
1609impl<'a, T, S1, S2> Mul<&RecordMatrix<'a, T, S2>> for RecordMatrix<'a, T, S1>
1613where
1614 T: Numeric + Primitive,
1615 for<'t> &'t T: NumericRef<T>,
1616 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1617 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1618{
1619 type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
1620 #[track_caller]
1621 fn mul(self, rhs: &RecordMatrix<'a, T, S2>) -> Self::Output {
1622 record_matrix_matrix_multiply_value_reference::<T, S1, S2>(self, rhs)
1623 }
1624}
1625
1626impl<'a, T, S1, S2> Mul<RecordMatrix<'a, T, S2>> for &RecordMatrix<'a, T, S1>
1630where
1631 T: Numeric + Primitive,
1632 for<'t> &'t T: NumericRef<T>,
1633 S1: MatrixRef<(T, Index)> + NoInteriorMutability,
1634 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1635{
1636 type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
1637 #[track_caller]
1638 fn mul(self, rhs: RecordMatrix<'a, T, S2>) -> Self::Output {
1639 record_matrix_matrix_multiply_reference_value::<T, S1, S2>(self, rhs)
1640 }
1641}
1642
1643#[track_caller]
1644fn record_tensor_neg_value<'a, T, S, const D: usize>(
1645 lhs: RecordTensor<'a, T, S, D>,
1646) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1647where
1648 T: Numeric + Primitive,
1649 for<'t> &'t T: NumericRef<T>,
1650 S: TensorRef<(T, Index), D>,
1651{
1652 lhs.unary(Negation::<T>::function, Negation::<T>::d_function_dx)
1653}
1654
1655#[track_caller]
1656fn record_tensor_neg_reference<'a, T, S, const D: usize>(
1657 lhs: &RecordTensor<'a, T, S, D>,
1658) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1659where
1660 T: Numeric + Primitive,
1661 for<'t> &'t T: NumericRef<T>,
1662 S: TensorRef<(T, Index), D>,
1663{
1664 lhs.unary(Negation::<T>::function, Negation::<T>::d_function_dx)
1665}
1666
// `Neg` for owned and borrowed `RecordTensor`s.
record_tensor_operator_impl_value!(impl Neg for RecordTensor { fn neg } record_tensor_neg_value);
record_tensor_operator_impl_reference!(impl Neg for RecordTensor { fn neg } record_tensor_neg_reference);
1669
1670#[track_caller]
1671fn record_matrix_neg_reference<'a, T, S>(
1672 lhs: &RecordMatrix<'a, T, S>,
1673) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1674where
1675 T: Numeric + Primitive,
1676 for<'t> &'t T: NumericRef<T>,
1677 S: MatrixRef<(T, Index)> + NoInteriorMutability,
1678{
1679 lhs.unary(Negation::<T>::function, Negation::<T>::d_function_dx)
1680}
1681
1682#[track_caller]
1683fn record_matrix_neg_value<'a, T, S>(
1684 lhs: RecordMatrix<'a, T, S>,
1685) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1686where
1687 T: Numeric + Primitive,
1688 for<'t> &'t T: NumericRef<T>,
1689 S: MatrixRef<(T, Index)> + NoInteriorMutability,
1690{
1691 lhs.unary(Negation::<T>::function, Negation::<T>::d_function_dx)
1692}
1693
// `Neg` for owned and borrowed `RecordMatrix`s.
record_matrix_operator_impl_value!(impl Neg for RecordMatrix { fn neg } record_matrix_neg_value);
record_matrix_operator_impl_reference!(impl Neg for RecordMatrix { fn neg } record_matrix_neg_reference);
1696
// Elementwise real functions for record tensors and record matrices. Each trait method
// (for example `sin`) is implemented with the named function/derivative definition (for
// example `Sine`) imported from `crate::differentiation::functions`.
record_real_tensor_operator_impl_unary!(impl Sin for RecordTensor { fn sin } Sine);
record_real_matrix_operator_impl_unary!(impl Sin for RecordMatrix { fn sin } Sine);

record_real_tensor_operator_impl_unary!(impl Cos for RecordTensor { fn cos } Cosine);
record_real_matrix_operator_impl_unary!(impl Cos for RecordMatrix { fn cos } Cosine);

record_real_tensor_operator_impl_unary!(impl Exp for RecordTensor { fn exp } Exponential);
record_real_matrix_operator_impl_unary!(impl Exp for RecordMatrix { fn exp } Exponential);

record_real_tensor_operator_impl_unary!(impl Ln for RecordTensor { fn ln } NaturalLogarithm);
record_real_matrix_operator_impl_unary!(impl Ln for RecordMatrix { fn ln } NaturalLogarithm);

record_real_tensor_operator_impl_unary!(impl Sqrt for RecordTensor { fn sqrt } SquareRoot);
record_real_matrix_operator_impl_unary!(impl Sqrt for RecordMatrix { fn sqrt } SquareRoot);

// `Pow` with a scalar operand, built from `Power`. The `_no_orphan_rule` macros are
// defined outside this view; presumably they add the impls the orphan rules prevent the
// generic macros from covering. TODO confirm against their definitions.
record_real_tensor_operator_impl_scalar!(impl Pow for RecordTensor { fn pow } Power);
record_real_matrix_operator_impl_scalar!(impl Pow for RecordMatrix { fn pow } Power);
record_real_tensor_operator_impl_scalar_no_orphan_rule!(impl Pow for RecordTensor { fn pow } Power);
record_real_matrix_operator_impl_scalar_no_orphan_rule!(impl Pow for RecordMatrix { fn pow } Power);
1716
// Implements `$op<T>` (a binary operator with a plain scalar `T` on the right hand side)
// for `RecordTensor`. It covers all four combinations of an owned or borrowed record
// tensor with an owned or borrowed scalar. Every impl calls `unary` on the record tensor
// with two closures:
// - `$function::<T>::function(x, rhs)`, which computes each new element;
// - `$function::<T>::d_function_dx(x, rhs)`, the derivative with respect to that element.
// The scalar carries no WengertList index, so no derivative with respect to it is
// supplied. The closures clone `rhs` on each call instead of moving it, so they can be
// invoked more than once.
macro_rules! record_tensor_operator_impl_scalar {
    (impl $op:tt for RecordTensor { fn $method:ident } $function:ident) => {
        // Owned record tensor, owned scalar.
        impl<'a, T, S1, const D: usize> $op<T> for RecordTensor<'a, T, S1, D>
        where
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
            S1: TensorRef<(T, Index), D>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;
            #[track_caller]
            fn $method(self, rhs: T) -> Self::Output {
                self.unary(
                    |x| $function::<T>::function(x, rhs.clone()),
                    |x| $function::<T>::d_function_dx(x, rhs.clone()),
                )
            }
        }

        // Owned record tensor, borrowed scalar.
        impl<'a, T, S1, const D: usize> $op<&T> for RecordTensor<'a, T, S1, D>
        where
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
            S1: TensorRef<(T, Index), D>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;
            #[track_caller]
            fn $method(self, rhs: &T) -> Self::Output {
                self.unary(
                    |x| $function::<T>::function(x, rhs.clone()),
                    |x| $function::<T>::d_function_dx(x, rhs.clone()),
                )
            }
        }

        // Borrowed record tensor, owned scalar.
        impl<'a, T, S1, const D: usize> $op<T> for &RecordTensor<'a, T, S1, D>
        where
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
            S1: TensorRef<(T, Index), D>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;
            #[track_caller]
            fn $method(self, rhs: T) -> Self::Output {
                self.unary(
                    |x| $function::<T>::function(x, rhs.clone()),
                    |x| $function::<T>::d_function_dx(x, rhs.clone()),
                )
            }
        }

        // Borrowed record tensor, borrowed scalar.
        impl<'a, T, S1, const D: usize> $op<&T> for &RecordTensor<'a, T, S1, D>
        where
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
            S1: TensorRef<(T, Index), D>,
        {
            type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;
            #[track_caller]
            fn $method(self, rhs: &T) -> Self::Output {
                self.unary(
                    |x| $function::<T>::function(x, rhs.clone()),
                    |x| $function::<T>::d_function_dx(x, rhs.clone()),
                )
            }
        }
    };
}
1800
// `RecordTensor <op> scalar` for addition, subtraction, multiplication and division,
// with the record tensor as the left hand operand.
record_tensor_operator_impl_scalar!(impl Add for RecordTensor { fn add } Addition);
record_tensor_operator_impl_scalar!(impl Sub for RecordTensor { fn sub } Subtraction);
record_tensor_operator_impl_scalar!(impl Mul for RecordTensor { fn mul } Multiplication);
record_tensor_operator_impl_scalar!(impl Div for RecordTensor { fn div } Division);
1805
// Implements `$op<T>` (a binary operator with a plain scalar `T` on the right hand side)
// for `RecordMatrix`. It covers all four combinations of an owned or borrowed record
// matrix with an owned or borrowed scalar. Every impl calls `unary` on the record matrix
// with two closures:
// - `$function::<T>::function(x, rhs)`, which computes each new element;
// - `$function::<T>::d_function_dx(x, rhs)`, the derivative with respect to that element.
// The scalar carries no WengertList index, so no derivative with respect to it is
// supplied. The closures clone `rhs` on each call instead of moving it, so they can be
// invoked more than once.
macro_rules! record_matrix_operator_impl_scalar {
    (impl $op:tt for RecordMatrix { fn $method:ident } $function:ident) => {
        // Owned record matrix, owned scalar.
        impl<'a, T, S1> $op<T> for RecordMatrix<'a, T, S1>
        where
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
            S1: MatrixRef<(T, Index)> + NoInteriorMutability,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
            #[track_caller]
            fn $method(self, rhs: T) -> Self::Output {
                self.unary(
                    |x| $function::<T>::function(x, rhs.clone()),
                    |x| $function::<T>::d_function_dx(x, rhs.clone()),
                )
            }
        }

        // Owned record matrix, borrowed scalar.
        impl<'a, T, S1> $op<&T> for RecordMatrix<'a, T, S1>
        where
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
            S1: MatrixRef<(T, Index)> + NoInteriorMutability,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
            #[track_caller]
            fn $method(self, rhs: &T) -> Self::Output {
                self.unary(
                    |x| $function::<T>::function(x, rhs.clone()),
                    |x| $function::<T>::d_function_dx(x, rhs.clone()),
                )
            }
        }

        // Borrowed record matrix, owned scalar.
        impl<'a, T, S1> $op<T> for &RecordMatrix<'a, T, S1>
        where
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
            S1: MatrixRef<(T, Index)> + NoInteriorMutability,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
            #[track_caller]
            fn $method(self, rhs: T) -> Self::Output {
                self.unary(
                    |x| $function::<T>::function(x, rhs.clone()),
                    |x| $function::<T>::d_function_dx(x, rhs.clone()),
                )
            }
        }

        // Borrowed record matrix, borrowed scalar.
        impl<'a, T, S1> $op<&T> for &RecordMatrix<'a, T, S1>
        where
            T: Numeric + Primitive,
            for<'t> &'t T: NumericRef<T>,
            S1: MatrixRef<(T, Index)> + NoInteriorMutability,
        {
            type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
            #[track_caller]
            fn $method(self, rhs: &T) -> Self::Output {
                self.unary(
                    |x| $function::<T>::function(x, rhs.clone()),
                    |x| $function::<T>::d_function_dx(x, rhs.clone()),
                )
            }
        }
    };
}
1889
// `RecordMatrix <op> scalar` for addition, subtraction, multiplication and division,
// with the record matrix as the left hand operand.
record_matrix_operator_impl_scalar!(impl Add for RecordMatrix { fn add } Addition);
record_matrix_operator_impl_scalar!(impl Sub for RecordMatrix { fn sub } Subtraction);
record_matrix_operator_impl_scalar!(impl Mul for RecordMatrix { fn mul } Multiplication);
record_matrix_operator_impl_scalar!(impl Div for RecordMatrix { fn div } Division);