// easy_ml/differentiation/container_record/container_operations/swapped.rs

1use crate::differentiation::functions::{Division, FunctionDerivative, Subtraction};
2use crate::differentiation::record_operations::SwappedOperations;
3use crate::differentiation::{Index, Primitive};
4use crate::differentiation::{RecordMatrix, RecordTensor};
5use crate::matrices::Matrix;
6use crate::matrices::views::{MatrixRef, NoInteriorMutability};
7use crate::numeric::{Numeric, NumericRef};
8use crate::tensors::Tensor;
9use crate::tensors::views::TensorRef;
10
11impl<'a, T, S, const D: usize> SwappedOperations<T> for RecordTensor<'a, T, S, D>
12where
13    T: Numeric + Primitive,
14    for<'t> &'t T: NumericRef<T>,
15    S: TensorRef<(T, Index), D>,
16{
17    type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;
18
19    /**
20     * Subtraction for a record tensor and a constant, where the constant
21     * is the left hand side, ie C - record.
22     */
23    #[track_caller]
24    fn sub_swapped(self, lhs: T) -> Self::Output {
25        self.unary(
26            // We want with respect to y because it is the right hand side here that we
27            // need the derivative for (since left is a constant).
28            |x| Subtraction::<T>::function(lhs.clone(), x),
29            |x| Subtraction::<T>::d_function_dy(lhs.clone(), x),
30        )
31    }
32
33    /**
34     * Division for a record tensor and a constant, where the constant
35     * is the left hand side, ie C / record.
36     */
37    #[track_caller]
38    fn div_swapped(self, lhs: T) -> Self::Output {
39        self.unary(
40            // We want with respect to y because it is the right hand side here that we
41            // need the derivative for (since left is a constant).
42            |x| Division::<T>::function(lhs.clone(), x),
43            |x| Division::<T>::d_function_dy(lhs.clone(), x),
44        )
45    }
46}
47
48impl<'a, T, S, const D: usize> SwappedOperations<&T> for RecordTensor<'a, T, S, D>
49where
50    T: Numeric + Primitive,
51    for<'t> &'t T: NumericRef<T>,
52    S: TensorRef<(T, Index), D>,
53{
54    type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;
55
56    /**
57     * Subtraction for a record tensor and a constant, where the constant
58     * is the left hand side, ie C - record.
59     */
60    #[track_caller]
61    fn sub_swapped(self, lhs: &T) -> Self::Output {
62        self.unary(
63            // We want with respect to y because it is the right hand side here that we
64            // need the derivative for (since left is a constant).
65            |x| Subtraction::<T>::function(lhs.clone(), x),
66            |x| Subtraction::<T>::d_function_dy(lhs.clone(), x),
67        )
68    }
69
70    /**
71     * Division for a record tensor and a constant, where the constant
72     * is the left hand side, ie C / record.
73     */
74    #[track_caller]
75    fn div_swapped(self, lhs: &T) -> Self::Output {
76        self.unary(
77            // We want with respect to y because it is the right hand side here that we
78            // need the derivative for (since left is a constant).
79            |x| Division::<T>::function(lhs.clone(), x),
80            |x| Division::<T>::d_function_dy(lhs.clone(), x),
81        )
82    }
83}
84
85impl<'a, T, S, const D: usize> SwappedOperations<T> for &RecordTensor<'a, T, S, D>
86where
87    T: Numeric + Primitive,
88    for<'t> &'t T: NumericRef<T>,
89    S: TensorRef<(T, Index), D>,
90{
91    type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;
92
93    /**
94     * Subtraction for a record tensor and a constant, where the constant
95     * is the left hand side, ie C - record.
96     */
97    #[track_caller]
98    fn sub_swapped(self, lhs: T) -> Self::Output {
99        self.unary(
100            // We want with respect to y because it is the right hand side here that we
101            // need the derivative for (since left is a constant).
102            |x| Subtraction::<T>::function(lhs.clone(), x),
103            |x| Subtraction::<T>::d_function_dy(lhs.clone(), x),
104        )
105    }
106
107    /**
108     * Division for a record tensor and a constant, where the constant
109     * is the left hand side, ie C / record.
110     */
111    #[track_caller]
112    fn div_swapped(self, lhs: T) -> Self::Output {
113        self.unary(
114            // We want with respect to y because it is the right hand side here that we
115            // need the derivative for (since left is a constant).
116            |x| Division::<T>::function(lhs.clone(), x),
117            |x| Division::<T>::d_function_dy(lhs.clone(), x),
118        )
119    }
120}
121
122impl<'a, T, S, const D: usize> SwappedOperations<&T> for &RecordTensor<'a, T, S, D>
123where
124    T: Numeric + Primitive,
125    for<'t> &'t T: NumericRef<T>,
126    S: TensorRef<(T, Index), D>,
127{
128    type Output = RecordTensor<'a, T, Tensor<(T, Index), D>, D>;
129
130    /**
131     * Subtraction for a record tensor and a constant, where the constant
132     * is the left hand side, ie C - record.
133     */
134    #[track_caller]
135    fn sub_swapped(self, lhs: &T) -> Self::Output {
136        self.unary(
137            // We want with respect to y because it is the right hand side here that we
138            // need the derivative for (since left is a constant).
139            |x| Subtraction::<T>::function(lhs.clone(), x),
140            |x| Subtraction::<T>::d_function_dy(lhs.clone(), x),
141        )
142    }
143
144    /**
145     * Division for a record tensor and a constant, where the constant
146     * is the left hand side, ie C / record.
147     */
148    #[track_caller]
149    fn div_swapped(self, lhs: &T) -> Self::Output {
150        self.unary(
151            // We want with respect to y because it is the right hand side here that we
152            // need the derivative for (since left is a constant).
153            |x| Division::<T>::function(lhs.clone(), x),
154            |x| Division::<T>::d_function_dy(lhs.clone(), x),
155        )
156    }
157}
158
159impl<'a, T, S> SwappedOperations<T> for RecordMatrix<'a, T, S>
160where
161    T: Numeric + Primitive,
162    for<'t> &'t T: NumericRef<T>,
163    S: MatrixRef<(T, Index)> + NoInteriorMutability,
164{
165    type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
166
167    /**
168     * Subtraction for a record matrix and a constant, where the constant
169     * is the left hand side, ie C - record.
170     */
171    #[track_caller]
172    fn sub_swapped(self, lhs: T) -> Self::Output {
173        self.unary(
174            // We want with respect to y because it is the right hand side here that we
175            // need the derivative for (since left is a constant).
176            |x| Subtraction::<T>::function(lhs.clone(), x),
177            |x| Subtraction::<T>::d_function_dy(lhs.clone(), x),
178        )
179    }
180
181    /**
182     * Division for a record matrix and a constant, where the constant
183     * is the left hand side, ie C / record.
184     */
185    #[track_caller]
186    fn div_swapped(self, lhs: T) -> Self::Output {
187        self.unary(
188            // We want with respect to y because it is the right hand side here that we
189            // need the derivative for (since left is a constant).
190            |x| Division::<T>::function(lhs.clone(), x),
191            |x| Division::<T>::d_function_dy(lhs.clone(), x),
192        )
193    }
194}
195
196impl<'a, T, S> SwappedOperations<&T> for RecordMatrix<'a, T, S>
197where
198    T: Numeric + Primitive,
199    for<'t> &'t T: NumericRef<T>,
200    S: MatrixRef<(T, Index)> + NoInteriorMutability,
201{
202    type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
203
204    /**
205     * Subtraction for a record matrix and a constant, where the constant
206     * is the left hand side, ie C - record.
207     */
208    #[track_caller]
209    fn sub_swapped(self, lhs: &T) -> Self::Output {
210        self.unary(
211            // We want with respect to y because it is the right hand side here that we
212            // need the derivative for (since left is a constant).
213            |x| Subtraction::<T>::function(lhs.clone(), x),
214            |x| Subtraction::<T>::d_function_dy(lhs.clone(), x),
215        )
216    }
217
218    /**
219     * Division for a record matrix and a constant, where the constant
220     * is the left hand side, ie C / record.
221     */
222    #[track_caller]
223    fn div_swapped(self, lhs: &T) -> Self::Output {
224        self.unary(
225            // We want with respect to y because it is the right hand side here that we
226            // need the derivative for (since left is a constant).
227            |x| Division::<T>::function(lhs.clone(), x),
228            |x| Division::<T>::d_function_dy(lhs.clone(), x),
229        )
230    }
231}
232
233impl<'a, T, S> SwappedOperations<T> for &RecordMatrix<'a, T, S>
234where
235    T: Numeric + Primitive,
236    for<'t> &'t T: NumericRef<T>,
237    S: MatrixRef<(T, Index)> + NoInteriorMutability,
238{
239    type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
240
241    /**
242     * Subtraction for a record matrix and a constant, where the constant
243     * is the left hand side, ie C - record.
244     */
245    #[track_caller]
246    fn sub_swapped(self, lhs: T) -> Self::Output {
247        self.unary(
248            // We want with respect to y because it is the right hand side here that we
249            // need the derivative for (since left is a constant).
250            |x| Subtraction::<T>::function(lhs.clone(), x),
251            |x| Subtraction::<T>::d_function_dy(lhs.clone(), x),
252        )
253    }
254
255    /**
256     * Division for a record matrix and a constant, where the constant
257     * is the left hand side, ie C / record.
258     */
259    #[track_caller]
260    fn div_swapped(self, lhs: T) -> Self::Output {
261        self.unary(
262            // We want with respect to y because it is the right hand side here that we
263            // need the derivative for (since left is a constant).
264            |x| Division::<T>::function(lhs.clone(), x),
265            |x| Division::<T>::d_function_dy(lhs.clone(), x),
266        )
267    }
268}
269
270impl<'a, T, S> SwappedOperations<&T> for &RecordMatrix<'a, T, S>
271where
272    T: Numeric + Primitive,
273    for<'t> &'t T: NumericRef<T>,
274    S: MatrixRef<(T, Index)> + NoInteriorMutability,
275{
276    type Output = RecordMatrix<'a, T, Matrix<(T, Index)>>;
277
278    /**
279     * Subtraction for a record matrix and a constant, where the constant
280     * is the left hand side, ie C - record.
281     */
282    #[track_caller]
283    fn sub_swapped(self, lhs: &T) -> Self::Output {
284        self.unary(
285            // We want with respect to y because it is the right hand side here that we
286            // need the derivative for (since left is a constant).
287            |x| Subtraction::<T>::function(lhs.clone(), x),
288            |x| Subtraction::<T>::d_function_dy(lhs.clone(), x),
289        )
290    }
291
292    /**
293     * Division for a record matrix and a constant, where the constant
294     * is the left hand side, ie C / record.
295     */
296    #[track_caller]
297    fn div_swapped(self, lhs: &T) -> Self::Output {
298        self.unary(
299            // We want with respect to y because it is the right hand side here that we
300            // need the derivative for (since left is a constant).
301            |x| Division::<T>::function(lhs.clone(), x),
302            |x| Division::<T>::d_function_dy(lhs.clone(), x),
303        )
304    }
305}