easy_ml/differentiation/container_record/mod.rs
1use crate::differentiation::functions::{Division, FunctionDerivative, Multiplication};
2use crate::differentiation::iterators::{
3 AsRecords, InconsistentHistory, InvalidRecordIteratorError,
4};
5use crate::differentiation::record_operations;
6use crate::differentiation::{Derivatives, Index, Primitive, Record, WengertList};
7use crate::matrices::iterators::{
8 ColumnMajorIterator, RowMajorIterator, RowMajorOwnedIterator, RowMajorReferenceMutIterator,
9};
10use crate::matrices::views::{MatrixMut, MatrixRef, MatrixView, NoInteriorMutability};
11use crate::matrices::{Column, Matrix, Row};
12use crate::numeric::{Numeric, NumericRef};
13use crate::tensors::indexing::{
14 TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceMutIterator,
15};
16use crate::tensors::views::{DataLayout, TensorMut, TensorRef, TensorRename, TensorView};
17use crate::tensors::{Dimension, Tensor};
18
19mod container_operations;
20pub mod iterators;
21
/**
 * A pluralisation of [Record](crate::differentiation::Record) that groups together a
 * **s**ource of numbers of type T and stores the WengertList only once.
 *
 * Typically you would refer to one of the type aliases to disambiguate the type of `S` and
 * use more succinct generics: [RecordMatrix](RecordMatrix), [RecordTensor](RecordTensor).
 *
 * For both Matrix and Tensor source types, the containers implement [`+`](std::ops::Add) and
 * [`-`](std::ops::Sub) and have the methods `elementwise_multiply` and `elementwise_divide`.
 * In all cases the containers must have the same size for the operation and will panic if
 * mismatched.
 * [`*`](std::ops::Mul) is also implemented for 2 dimensional tensors and matrices as matrix
 * multiplication.
 *
 * For convenience, as with Trace and Record, many unary operations including
 * [Cos](crate::numeric::extra::Cos), [Exp](crate::numeric::extra::Exp),
 * [Ln](crate::numeric::extra::Ln), [Neg](std::ops::Neg), [Pow](crate::numeric::extra::Pow),
 * [Sin](crate::numeric::extra::Sin), and [Sqrt](crate::numeric::extra::Sqrt) are implemented as
 * well, applying the unary function to each element in the tensor.
 *
 * `+`, `-`, `*` and `/` operations with a RecordContainer and a scalar are also implemented,
 * treating the right hand side scalar as a constant. These are also unary functions in terms of
 * the derivative tracking, for example `X + 5` applies the function `+5` to each element in
 * `X`. Due to the orphan rule, the standard library scalars cannot be implemented for a left hand
 * side scalar, see [SwappedOperations](crate::differentiation::record_operations::SwappedOperations).
 *
 * Both [RecordMatrix](RecordMatrix) and [RecordTensor](RecordTensor) implement
 * [MatrixRef](MatrixRef) and [TensorRef](TensorRef) respectively, which provide read and write
 * access to the underlying numbers and indexes into the WengertList. These APIs along with the
 * `from_existing` constructors for RecordMatrix, RecordTensor, and Record allow arbitrary
 * manipulation of specific elements in a record container if needed. However, any arithmetic
 * performed on the raw data won't be tracked on the WengertList and overwriting data within a
 * record container already tracked on the WengertList could result in nonsense. These APIs exist
 * for easy read access to check the data in a record container and for read/write access when
 * manipulating the shape of a record container, and are designed to be used only for moving data
 * around - you should put it back unchanged in a RecordContainer or Record before doing further
 * arithmetic that needs to be tracked on the WengertList.
 *
 * If you just want to manipulate the data in record containers as if they were Records you can
 * use the iterator APIs of [AsRecords](AsRecords) instead and collect them back into containers
 * when you're done, or to manipulate a single container, the [map](RecordContainer::map) and
 * [map_mut](RecordContainer::map_mut) methods.
 */
#[derive(Debug)]
pub struct RecordContainer<'a, T: Primitive, S, const D: usize> {
    // Opted to store the indexes alongside each number (T, Index) for a number of reasons, the
    // main factor being it makes implementing TensorRef feasible so can utilise the range of
    // existing APIs for Tensor manipulation. It's theoretically possible to only store the first
    // index and calculate the rest, since most of the time all indexes are ascending entries in
    // the WengertList but this would also massively complicate the implementation, especially for
    // handling non row-major operations such as matrix multiplication. It's also not super clear
    // that this would be more efficient because it turns reads into more arithmetic rather
    // than avoiding any work. Just lifting the WengertList out of the tensor should have
    // meaningful improvements to cache line efficiency, and failing that still disallows
    // very questionable states from existing.
    numbers: S,
    // The single shared WengertList for every element in `numbers`. The `constants`
    // constructors in this module store `None` here and the `variables` constructors store
    // `Some(history)`.
    history: Option<&'a WengertList<T>>,
}
80
/**
 * Alias for succinctly referring to RecordContainers backed by a matrix.
 *
 * The numbers are held as a [MatrixView](MatrixView) of `(T, Index)` pairs over the source `S`,
 * and the const dimensionality of the container is fixed to 2.
 */
pub type RecordMatrix<'a, T, S> = RecordContainer<'a, T, MatrixView<(T, Index), S>, 2>;
85
/**
 * Alias for succinctly referring to RecordContainers backed by a tensor.
 *
 * The numbers are held as a [TensorView](TensorView) of `(T, Index)` pairs over the source `S`,
 * with the same dimensionality `D` used for both the view and the container.
 */
pub type RecordTensor<'a, T, S, const D: usize> =
    RecordContainer<'a, T, TensorView<(T, Index), S, D>, D>;
91
92fn calculate_incrementing_indexes(starting_index: usize, total: usize) -> Vec<Index> {
93 let mut indexes = vec![0; total];
94 for (i, x) in indexes.iter_mut().enumerate() {
95 *x = starting_index + i;
96 }
97 indexes
98}
99
100impl<'a, T, const D: usize> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
101where
102 T: Numeric + Primitive,
103{
104 /**
105 * Creates multiple untracked Records which have no backing WengertList.
106 *
107 * This is provided for using constants along with Records in operations.
108 *
109 * For example with `Y = X + 4` the computation graph could be conceived as many
110 * `Y[i,j]` nodes with parent nodes of `X[i,j]` and 4 combined with the operation `+`.
111 * However there is no need to record the derivatives of a constant, so
112 * instead the computation graph can be conceived as `Y[i,j]` nodes each with a single
113 * parent node of `X[i,j]` and the unary operation of `+4`.
114 */
115 pub fn constants<S>(c: S) -> Self
116 where
117 S: TensorMut<T, D>,
118 {
119 RecordContainer {
120 numbers: TensorView::from(Tensor::from(
121 c.view_shape(),
122 TensorOwnedIterator::from_numeric(c)
123 .map(|x| (x, 0))
124 .collect(),
125 )),
126 history: None,
127 }
128 }
129
130 /**
131 * Creates multiple records backed by the provided WengertList.
132 *
133 * The records cannot live longer than the WengertList, hence
134 * the following example does not compile
135 *
136 * ```compile_fail
137 * use easy_ml::differentiation::RecordTensor;
138 * use easy_ml::differentiation::WengertList;
139 * use easy_ml::tensors::Tensor;
140 * let record = {
141 * let list = WengertList::new();
142 * RecordTensor::variables(
143 * &list,
144 * Tensor::from([("r", 2), ("c", 2)], vec![ 1.0, 2.0, 3.0, 4.0 ])
145 * )
146 * }; // list no longer in scope
147 * ```
148 */
149 pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
150 where
151 S: TensorMut<T, D>,
152 {
153 let total = crate::tensors::dimensions::elements(&x.view_shape());
154 let starting_index = history.append_nullary_repeating(total);
155 RecordContainer {
156 numbers: TensorView::from(Tensor::from(
157 x.view_shape(),
158 TensorOwnedIterator::from_numeric(x)
159 .zip(calculate_incrementing_indexes(starting_index, total))
160 .collect(),
161 )),
162 history: Some(history),
163 }
164 }
165}
166
167impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
168where
169 T: Numeric + Primitive,
170 S: TensorRef<(T, Index), D>,
171{
172 /**
173 * Returns the number of elements stored by this container's source.
174 *
175 * For a 2 x 3 Tensor, this would return 6, and for a 2 x 3 x 4 Tensor this would return 24
176 * and so on.
177 *
178 * see also [dimensions::elements](crate::tensors::dimensions::elements)
179 */
180 pub fn elements(&self) -> usize {
181 crate::tensors::dimensions::elements(&self.numbers.shape())
182 }
183
184 /**
185 * The shape of this container's source.
186 */
187 pub fn shape(&self) -> [(Dimension, usize); D] {
188 self.numbers.shape()
189 }
190
191 /**
192 * Creates a container from constants/variables directly, most likely obtained by getting a
193 * tensor view of an existing container. **The inputs are not checked for validity**. It is
194 * possible to pass in the wrong Wengert list here or even numbers with indexes that aren't
195 * tracked on the WengertList.
196 *
197 * It is recommended to use this constructor only in conjunction with
198 * resizing or masking an existing container and not for creating new variables. Any variables
199 * created outside of `RecordContainer::variables` would have to be manually added to the
200 * correct Wengert list, and any arithmetic operations would also need tracking correctly.
201 *
202 * ```
203 * use easy_ml::differentiation::RecordTensor;
204 * use easy_ml::differentiation::WengertList;
205 * use easy_ml::tensors::Tensor;
206 * use easy_ml::tensors::views::{TensorView, TensorRange};
207 *
208 * let list = WengertList::new();
209 * let x = RecordTensor::variables(
210 * &list,
211 * Tensor::from_fn([("x", 2), ("y", 2)], |[r, c]| ((r + 3) * (c + 2)) as f64)
212 * );
213 * // oh no wrong shape!
214 * let fixed = TensorView::from(TensorRange::from(x, [("y", 0..1)]).unwrap()); // we can unwrap here because we know the range is valid
215 * let x = RecordTensor::from_existing(Some(&list), fixed);
216 * assert_eq!([("x", 2), ("y", 1)], x.shape());
217 * ```
218 */
219 pub fn from_existing(
220 history: Option<&'a WengertList<T>>,
221 numbers: TensorView<(T, Index), S, D>,
222 ) -> Self {
223 RecordContainer { numbers, history }
224 }
225
226 /**
227 * Returns a record tensor with the dimension names of the shape renamed to the provided
228 * dimensions. The data of this container and the dimension lengths and order remain unchanged.
229 *
230 * This is a shorthand for constructing the RecordTensor via manipulating this TensorView. See
231 * [`RecordTensor::from_existing`](RecordTensor::from_existing).
232 *
233 * # Panics
234 *
235 * If a dimension name is not unique
236 *
237 * ```
238 * use easy_ml::differentiation::RecordTensor;
239 * use easy_ml::differentiation::WengertList;
240 * use easy_ml::tensors::Tensor;
241 * use easy_ml::tensors::views::{TensorView, TensorRename};
242 *
243 * let list = WengertList::new();
244 * let x = RecordTensor::variables(
245 * &list,
246 * Tensor::from_fn([("x", 2), ("y", 2)], |[r, c]| ((r + 3) * (c + 2)) as f64)
247 * );
248 * // oh no wrong dimension names!
249 * let x = x.rename_view(["a", "b"]);
250 * assert_eq!([("a", 2), ("b", 2)], x.shape());
251 * ```
252 */
253 #[track_caller]
254 pub fn rename_view(
255 self,
256 dimensions: [Dimension; D],
257 ) -> RecordTensor<'a, T, TensorRename<(T, Index), S, D>, D> {
258 RecordTensor::from_existing(
259 self.history,
260 TensorView::from(TensorRename::from(self.numbers.source(), dimensions)),
261 )
262 }
263
264 /**
265 * Returns a TensorView of this record container, both the `T` for each record element and
266 * also the index for that record's entry in the WengertList. These can be parsed back into
267 * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
268 * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
269 * numerical operations on the data.
270 */
271 pub fn view(&self) -> TensorView<(T, Index), &RecordTensor<'a, T, S, D>, D> {
272 TensorView::from(self)
273 }
274
275 /**
276 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
277 * to read values from this record container, both the `T` for each record element and
278 * also the index for that record's entry in the WengertList. These can be parsed back into
279 * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
280 * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
281 * numerical operations on the data.
282 *
283 * # Panics
284 *
285 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
286 */
287 #[track_caller]
288 pub fn index_by(
289 &self,
290 dimensions: [Dimension; D],
291 ) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
292 TensorAccess::from(self, dimensions)
293 }
294
295 /**
296 * Creates a TensorAccess which will index into the dimensions this record was created with
297 * in the same order as they were provided, both the `T` for each record element and
298 * also the index for that record's entry in the WengertList. These can be parsed back into
299 * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
300 * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
301 * numerical operations on the data.
302 *
303 * See [TensorAccess::from_source_order], [get_as_record](TensorAccess::get_as_record).
304 *
305 * ```
306 * use easy_ml::differentiation::RecordTensor;
307 * use easy_ml::differentiation::WengertList;
308 * use easy_ml::tensors::Tensor;
309 *
310 * let list = WengertList::new();
311 * let X = RecordTensor::variables(
312 * &list,
313 * Tensor::from([("a", 3)], vec![ 3.0, 4.0, 5.0 ])
314 * );
315 * let x = X.index().get_as_record([0]);
316 * assert_eq!(x.number, 3.0);
317 * ```
318 */
319 pub fn index(&self) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
320 TensorAccess::from_source_order(self)
321 }
322
323 /**
324 * Returns an iterator over this record tensor as [Record]s instead of the raw `(T, Index)`
325 * data. After manipulating the iterator it can be collected back into a RecordTensor with
326 * [RecordTensor::from_iter](RecordTensor::from_iter).
327 *
328 * This is a shorthand for `AsRecords::from(tensor.history(), TensorIterator::from(&tensor))`
329 */
330 #[allow(clippy::type_complexity)]
331 pub fn iter_as_records<'b>(
332 &'b self,
333 ) -> AsRecords<'a, TensorIterator<'b, (T, Index), RecordTensor<'a, T, S, D>, D>, T> {
334 AsRecords::from_tensor(self)
335 }
336}
337
338impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
339where
340 T: Numeric + Primitive,
341 S: TensorMut<(T, Index), D>,
342{
343 /**
344 * Resets all of the records to place them back on the WengertList, for use
345 * in performing another derivation after clearing the WengertList.
346 *
347 * This is also a preferred shorthand for `map_mut(Record::do_reset)` that can't fail
348 */
349 pub fn reset(&mut self) {
350 match self.history {
351 None => (), // noop
352 Some(history) => {
353 let total = self.elements();
354 let starting_index = history.append_nullary_repeating(total);
355 for (x, i) in self
356 .numbers
357 .iter_reference_mut()
358 .zip(calculate_incrementing_indexes(starting_index, total))
359 {
360 let (_, old_index) = x;
361 *old_index = i;
362 }
363 }
364 };
365 }
366
367 /**
368 * A convenience helper function which takes a RecordContainer by value and
369 * calls [reset](RecordTensor::reset()) on it.
370 */
371 pub fn do_reset(mut x: Self) -> Self {
372 x.reset();
373 x
374 }
375}
376
377impl<'a, T> RecordMatrix<'a, T, Matrix<(T, Index)>>
378where
379 T: Numeric + Primitive,
380{
381 /**
382 * Creates multiple untracked Records which have no backing WengertList.
383 *
384 * This is provided for using constants along with Records in operations.
385 *
386 * For example with `Y = X + 4` the computation graph could be conceived as many
387 * `Y[i,j]` nodes with parent nodes of `X[i,j]` and 4 combined with the operation `+`.
388 * However there is no need to record the derivatives of a constant, so
389 * instead the computation graph can be conceived as `Y[i,j]` nodes each with a single
390 * parent node of `X[i,j]` and the unary operation of `+4`.
391 */
392 pub fn constants<S>(c: S) -> Self
393 where
394 S: MatrixMut<T> + NoInteriorMutability,
395 {
396 RecordContainer {
397 numbers: MatrixView::from(Matrix::from_flat_row_major(
398 (c.view_rows(), c.view_columns()),
399 RowMajorOwnedIterator::from_numeric(c)
400 .map(|x| (x, 0))
401 .collect(),
402 )),
403 history: None,
404 }
405 }
406
407 /**
408 * Creates multiple records backed by the provided WengertList.
409 *
410 * The records cannot live longer than the WengertList, hence
411 * the following example does not compile
412 *
413 * ```compile_fail
414 * use easy_ml::differentiation::RecordMatrix;
415 * use easy_ml::differentiation::WengertList;
416 * use easy_ml::matrices::Matrix;
417 * let record = {
418 * let list = WengertList::new();
419 * RecordMatrix::variables(
420 * &list,
421 * Matrix::from(vec![vec![1.0, 2.0], vec![3.0, 4.0]])
422 * )
423 * }; // list no longer in scope
424 * ```
425 */
426 pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
427 where
428 S: MatrixMut<T> + NoInteriorMutability,
429 {
430 let total = x.view_rows() * x.view_columns();
431 let starting_index = history.append_nullary_repeating(total);
432 RecordContainer {
433 numbers: MatrixView::from(Matrix::from_flat_row_major(
434 (x.view_rows(), x.view_columns()),
435 RowMajorOwnedIterator::from_numeric(x)
436 .zip(calculate_incrementing_indexes(starting_index, total))
437 .collect(),
438 )),
439 history: Some(history),
440 }
441 }
442}
443
444impl<'a, T, S> RecordMatrix<'a, T, S>
445where
446 T: Numeric + Primitive,
447 S: MatrixRef<(T, Index)> + NoInteriorMutability,
448{
449 /**
450 * Returns the number of elements stored by this container's source.
451 *
452 * For a 2 x 3 Matrix, this would return 6, and for a 3 x 4 Matrix this would return 12
453 * and so on.
454 */
455 pub fn elements(&self) -> usize {
456 self.numbers.rows() * self.numbers.columns()
457 }
458
459 /**
460 * Returns the dimensionality of this matrix container in Row, Column format
461 */
462 pub fn size(&self) -> (Row, Column) {
463 self.numbers.size()
464 }
465
466 /**
467 * Gets the number of rows visible to this matrix container.
468 */
469 pub fn rows(&self) -> Row {
470 self.numbers.rows()
471 }
472
473 /**
474 * Gets the number of columns visible to this matrix container.
475 */
476 pub fn columns(&self) -> Column {
477 self.numbers.columns()
478 }
479
480 /**
481 * Creates a container from constants/variables directly, most likely obtained by getting a
482 * matrix view of an existing container. **The inputs are not checked for validity**. It is
483 * possible to pass in the wrong Wengert list here or even numbers with indexes that aren't
484 * tracked on the WengertList.
485 *
486 * It is recommended to use this constructor only in conjunction with
487 * resizing or masking an existing container and not for creating new variables. Any variables
488 * created outside of `RecordContainer::variables` would have to be manually added to the
489 * correct Wengert list, and any arithmetic operations would also need tracking correctly.
490 *
491 * ```
492 * use easy_ml::differentiation::RecordMatrix;
493 * use easy_ml::differentiation::WengertList;
494 * use easy_ml::matrices::Matrix;
495 * use easy_ml::matrices::views::{MatrixView, MatrixRange};
496 *
497 * let list = WengertList::new();
498 * let x = RecordMatrix::variables(
499 * &list,
500 * Matrix::from_fn((2, 2), |(r, c)| ((r + 3) * (c + 2)) as f64)
501 * );
502 * // oh no wrong shape!
503 * let fixed = MatrixView::from(MatrixRange::from(x, 0..2, 0..1));
504 * let x = RecordMatrix::from_existing(Some(&list), fixed);
505 * assert_eq!((2, 1), x.size());
506 * ```
507 */
508 pub fn from_existing(
509 history: Option<&'a WengertList<T>>,
510 numbers: MatrixView<(T, Index), S>,
511 ) -> Self {
512 RecordContainer { numbers, history }
513 }
514
515 /**
516 * Returns a MatrixView of this record container, both the `T` for each record element and
517 * also the index for that record's entry in the WengertList. These can be parsed back into
518 * a RecordMatrix with [`from_existing`](RecordMatrix::from_existing) or individually into
519 * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
520 * numerical operations on the data.
521 */
522 pub fn view(&self) -> MatrixView<(T, Index), &RecordMatrix<'a, T, S>> {
523 MatrixView::from(self)
524 }
525
526 /**
527 * Returns an iterator over this record matrix as [Record]s instead of the raw `(T, Index)`
528 * data. After manipulating the iterator it can be collected back into a RecordMatrix with
529 * [RecordMatrix::from_iter](RecordMatrix::from_iter).
530 *
531 * This is a shorthand for `AsRecords::from(matrix.history(), RowMajorIterator::from(&matrix))`
532 */
533 #[allow(clippy::type_complexity)]
534 pub fn iter_row_major_as_records<'b>(
535 &'b self,
536 ) -> AsRecords<'a, RowMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
537 AsRecords::from_matrix_row_major(self)
538 }
539
540 /**
541 * Returns an iterator over this record matrix as [Record]s instead of the raw `(T, Index)`
542 * data. After manipulating the iterator it can be collected back into a RecordMatrix with
543 * [RecordMatrix::from_iter](RecordMatrix::from_iter).
544 *
545 * This is a shorthand for
546 * `AsRecords::from(matrix.history(), ColumnMajorIterator::from(&matrix))`
547 */
548 #[allow(clippy::type_complexity)]
549 pub fn iter_column_major_as_records<'b>(
550 &'b self,
551 ) -> AsRecords<'a, ColumnMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
552 AsRecords::from_matrix_column_major(self)
553 }
554
555 /**
556 * Returns a copy of the data at the index as a Record. If you need to access all the data
557 * as records instead of just a specific index you should probably use one of the iterator
558 * APIs instead.
559 *
560 * See also: [iter_row_major_as_records](RecordMatrix::iter_row_major_as_records),
561 * [iter_column_major_as_records](RecordMatrix::iter_column_major_as_records)
562 *
563 * # Panics
564 *
565 * If the index is out of range.
566 *
567 * For a non panicking API see [try_get_as_record](RecordMatrix::try_get_as_record)
568 */
569 #[track_caller]
570 pub fn get_as_record(&self, row: Row, column: Column) -> Record<'a, T> {
571 Record::from_existing(self.numbers.get(row, column), self.history)
572 }
573
574 /**
575 * Returns a copy of the data at the index as a Record, or None if the index is
576 * out of range. If you need to access all the data as records instead of just a specific
577 * index you should probably use one of the iterator APIs instead.
578 *
579 * See also: [iter_row_major_as_records](RecordMatrix::iter_row_major_as_records),
580 * [iter_column_major_as_records](RecordMatrix::iter_column_major_as_records)
581 */
582 pub fn try_get_as_record(&self, row: Row, column: Column) -> Option<Record<'a, T>> {
583 self.numbers
584 .try_get_reference(row, column)
585 .map(|r| Record::from_existing(r.clone(), self.history))
586 }
587}
588
589impl<'a, T, S> RecordMatrix<'a, T, S>
590where
591 T: Numeric + Primitive,
592 S: MatrixMut<(T, Index)> + NoInteriorMutability,
593{
594 /**
595 * Resets all of the records to place them back on the WengertList, for use
596 * in performing another derivation after clearing the WengertList.
597 *
598 * This is also a preferred shorthand for `map_mut(Record::do_reset)` that can't fail
599 */
600 pub fn reset(&mut self) {
601 match self.history {
602 None => (), // noop
603 Some(history) => {
604 let total = self.elements();
605 let starting_index = history.append_nullary_repeating(total);
606 for (x, i) in self
607 .numbers
608 .row_major_reference_mut_iter()
609 .zip(calculate_incrementing_indexes(starting_index, total))
610 {
611 let (_, old_index) = x;
612 *old_index = i;
613 }
614 }
615 };
616 }
617
618 /**
619 * A convenience helper function which takes a RecordContainer by value and
620 * calls [reset](RecordMatrix::reset()) on it.
621 */
622 pub fn do_reset(mut x: Self) -> Self {
623 x.reset();
624 x
625 }
626}
627
628impl<'a, T, S, const D: usize> RecordContainer<'a, T, S, D>
629where
630 T: Primitive,
631{
632 /**
633 * Gets the WengertList these records are backed by if variables, and [None](None) if constants.
634 */
635 pub fn history(&self) -> Option<&'a WengertList<T>> {
636 self.history
637 }
638}
639
640/// Returns the vec of indexes and vec of ys for Y = unary(X), not checking but assuming that the
641/// length of the iterator matches the total.
642fn unary<'a, T, I>(
643 total: usize,
644 history: &WengertList<T>,
645 records: I,
646 fx: impl Fn(T) -> T,
647 dfx_dx: impl Fn(T) -> T,
648) -> Vec<(T, usize)>
649where
650 I: Iterator<Item = (T, Index)>,
651 T: Numeric + Primitive,
652 for<'t> &'t T: NumericRef<T>,
653{
654 let mut ys = vec![(T::zero(), 0); total];
655 history.borrow(|history| {
656 // shadow the name so we can't accidentally try to use history while holding
657 // the borrow
658 // use enumerate not with_index because we need the 1D index for indexing
659 // indexes
660 for (i, (x, parent)) in records.enumerate() {
661 let y = fx(x.clone());
662 let derivative = dfx_dx(x);
663 let new_index = history.append_unary(parent, derivative);
664 ys[i] = (y, new_index)
665 }
666 }); // drop borrow on history
667 ys
668}
669
670/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), not checking but assuming that
671/// the length of the iterators match the total. Also assumes both inputs have the same shape
672fn binary_both_history<'a, T, I1, I2>(
673 total: usize,
674 history: &WengertList<T>,
675 x_records: I1,
676 y_records: I2,
677 fxy: impl Fn(T, T) -> T,
678 dfxy_dx: impl Fn(T, T) -> T,
679 dfxy_dy: impl Fn(T, T) -> T,
680) -> Vec<(T, usize)>
681where
682 I1: Iterator<Item = (T, Index)>,
683 I2: Iterator<Item = (T, Index)>,
684 T: Numeric + Primitive,
685 for<'t> &'t T: NumericRef<T>,
686{
687 let mut zs = vec![(T::zero(), 0); total];
688 history.borrow(|history| {
689 // shadow the name so we can't accidentally try to use history while holding
690 // the borrow
691 // use enumerate not with_index because we need the 1D index for indexing
692 // indexes
693 for (i, ((x, parent1), (y, parent2))) in (x_records.zip(y_records)).enumerate() {
694 let z = fxy(x.clone(), y.clone());
695 let derivative1 = dfxy_dx(x.clone(), y.clone());
696 let derivative2 = dfxy_dy(x, y);
697 let new_index = history.append_binary(parent1, derivative1, parent2, derivative2);
698 zs[i] = (z, new_index);
699 }
700 }); // drop borrow on history
701 zs
702}
703
704/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), as with binary_both_history,
705/// but only tracking the derivatives for X, not Y.
706fn binary_x_history<'a, T, I1, I2>(
707 total: usize,
708 history: &WengertList<T>,
709 x_records: I1,
710 y_records: I2,
711 fxy: impl Fn(T, T) -> T,
712 dfxy_dx: impl Fn(T, T) -> T,
713) -> Vec<(T, usize)>
714where
715 I1: Iterator<Item = (T, Index)>,
716 I2: Iterator<Item = (T, Index)>,
717 T: Numeric + Primitive,
718 for<'t> &'t T: NumericRef<T>,
719{
720 let mut zs = vec![(T::zero(), 0); total];
721 history.borrow(|history| {
722 // shadow the name so we can't accidentally try to use history while holding
723 // the borrow
724 // use enumerate not with_index because we need the 1D index for indexing
725 // indexes
726 for (i, ((x, parent1), (y, _))) in (x_records.zip(y_records)).enumerate() {
727 let z = fxy(x.clone(), y.clone());
728 // if rhs didn't have a history, don't track that derivative
729 let derivative1 = dfxy_dx(x, y);
730 let new_index = history.append_unary(parent1, derivative1);
731 zs[i] = (z, new_index);
732 }
733 }); // drop borrow on history
734 zs
735}
736
737/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), as with binary_both_history,
738/// but only tracking the derivatives for Y, not X.
739fn binary_y_history<'a, T, I1, I2>(
740 total: usize,
741 history: &WengertList<T>,
742 x_records: I1,
743 y_records: I2,
744 fxy: impl Fn(T, T) -> T,
745 dfxy_dy: impl Fn(T, T) -> T,
746) -> Vec<(T, usize)>
747where
748 I1: Iterator<Item = (T, Index)>,
749 I2: Iterator<Item = (T, Index)>,
750 T: Numeric + Primitive,
751 for<'t> &'t T: NumericRef<T>,
752{
753 let mut zs = vec![(T::zero(), 0); total];
754 history.borrow(|history| {
755 // shadow the name so we can't accidentally try to use history while holding
756 // the borrow
757 // use enumerate not with_index because we need the 1D index for indexing
758 // indexes
759 for (i, ((x, _), (y, parent2))) in (x_records.zip(y_records)).enumerate() {
760 let z = fxy(x.clone(), y.clone());
761 // if self didn't have a history, don't track that derivative
762 let derivative2 = dfxy_dy(x, y);
763 let new_index = history.append_unary(parent2, derivative2);
764 zs[i] = (z, new_index);
765 }
766 }); // drop borrow on history
767 zs
768}
769
770impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
771where
772 T: Numeric + Primitive,
773 for<'t> &'t T: NumericRef<T>,
774 S: TensorRef<(T, Index), D>,
775{
776 /**
777 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
778 * some unary function from `T` to `T` to every element in the container.
779 *
780 * To compute the new records, the unary function of some input x to some
781 * output y is needed along with its derivative with respect to its input x.
782 *
783 * For example, tanh is a commonly used activation function, but the Real trait
784 * does not include this operation and Record has no operations for it specifically.
785 * However, you can use this function to compute the tanh for a record container like so:
786 *
787 * ```
788 * use easy_ml::differentiation::{RecordTensor, WengertList};
789 * use easy_ml::tensors::Tensor;
790 * let list = WengertList::new();
791 * let X = RecordTensor::variables(
792 * &list,
793 * Tensor::from_fn(
794 * [("rows", 2), ("columns", 2)],
795 * |[r, c]| 0.15 * ((1 + r + c) as f32)
796 * )
797 * );
798 * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
799 * // 1 / (cosh(x) * cosh(x))
800 * let Y = X.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
801 *
802 * // we can unwrap here because we know Y contains variables not constants
803 * let derivatives = Y.derivatives().unwrap();
804 * let derivatives_indexing = derivatives.index_by(["rows", "columns"]);
805 * assert_eq!(
806 * derivatives_indexing.get_ref([0, 0]).at_tensor(&X),
807 * Tensor::from(
808 * [("rows", 2), ("columns", 2)],
809 * // [0, 0] element in Y only had the one input variable [0, 0] in X
810 * vec![
811 * 0.9778332, 0.0,
812 * 0.0, 0.0
813 * ]
814 * ),
815 * );
816 * assert_eq!(
817 * derivatives_indexing.get_ref([0, 1]).at_tensor(&X),
818 * Tensor::from(
819 * [("rows", 2), ("columns", 2)],
820 * vec![
821 * 0.0, 0.915137,
822 * 0.0, 0.0
823 * ]
824 * ),
825 * );
826 * assert_eq!(
827 * // [0, 1] and [1, 0] elements in X had the same starting value so end up with the same
828 * // derivative for their corresponding input variable in X
829 * derivatives_indexing.get_ref([0, 1]).at_tensor(&X).index().get([0, 1]),
830 * derivatives_indexing.get_ref([1, 0]).at_tensor(&X).index().get([1, 0]),
831 * );
832 * assert_eq!(
833 * derivatives_indexing.get_ref([1, 1]).at_tensor(&X),
834 * Tensor::from(
835 * [("rows", 2), ("columns", 2)],
836 * vec![
837 * 0.0, 0.0,
838 * 0.0, 0.8220013
839 * ]
840 * ),
841 * );
842 * ```
843 */
844 #[track_caller]
845 pub fn unary(
846 &self,
847 fx: impl Fn(T) -> T,
848 dfx_dx: impl Fn(T) -> T,
849 ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D> {
850 let total = self.elements();
851 match self.history {
852 None => RecordTensor::constants(self.numbers.map(|(x, _)| fx(x))),
853 Some(history) => {
854 let ys = unary::<T, _>(total, history, self.numbers.iter(), fx, dfx_dx);
855 RecordContainer {
856 numbers: self.numbers.new_with_same_shape(ys),
857 history: Some(history),
858 }
859 }
860 }
861 }
862
863 /**
864 * Creates a new RecordContainer from two RecordContainers by applying
865 * some binary function from `T` to `T` to every element pair in the containers. Both
866 * containers must have the same shape.
867 *
868 * To compute the new records, the binary function of some inputs x and y to some
869 * output z is needed along with its derivative with respect to its first input x and
870 * its derivative with respect to its second input y.
871 *
872 * For example, atan2 takes two arguments, but the Real trait
873 * does not include this operation and Record has no operations for it specifically.
874 * However, you can use this function to compute the atan2 for two record containers like so:
875 *
876 * ```
877 * use easy_ml::differentiation::{RecordTensor, WengertList};
878 * use easy_ml::tensors::Tensor;
879 * let list = WengertList::new();
880 * let X = RecordTensor::variables(
881 * &list,
882 * Tensor::from_fn(
883 * [("rows", 2), ("columns", 2)],
884 * |[r, c]| ((1 + r + c) as f32)
885 * )
886 * );
887 * let Y = RecordTensor::variables(
888 * &list,
889 * Tensor::from_fn(
890 * [("rows", 2), ("columns", 2)],
891 * |[r, c]| ((1 + r + c) as f32)
892 * )
893 * );
894 * // the derivative of atan2 with respect to x is y/(x*x + y*y)
895 * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
896 * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
897 * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
898 * let Z = X.binary(&Y,
899 * |x, y| x.atan2(y),
900 * |x, y| y/((x*x) + (y*y)),
901 * |x, y| -x/((x*x) + (y*y))
902 * );
903 *
904 *
905 * // we can unwrap here because we know Z contains variables not constants
906 * let derivatives = Z.derivatives().unwrap();
907 * // Just as in the unary example, only one pair of the four inputs in X and Y influence the
908 * // outputs in Z, so we have a lot of 0.0 derivatives, and the inputs in [0, 1] and [1, 0]
909 * // are identical so we see the same derivative.
910 * let dZ_dX = derivatives.map(|d| d.at_tensor(&X));
911 * assert_eq!(
912 * dZ_dX,
913 * Tensor::from([("rows", 2), ("columns", 2)], vec![
914 * Tensor::from([("rows", 2), ("columns", 2)], vec![
915 * 0.5, 0.0,
916 * 0.0, 0.0
917 * ]),
918 * Tensor::from([("rows", 2), ("columns", 2)], vec![
919 * 0.0, 0.25,
920 * 0.0, 0.0
921 * ]),
922 * Tensor::from([("rows", 2), ("columns", 2)], vec![
923 * 0.0, 0.0,
924 * 0.25, 0.0
925 * ]),
926 * Tensor::from([("rows", 2), ("columns", 2)], vec![
927 * 0.0, 0.0,
928 * 0.0, 0.16666667
929 * ])
930 * ])
931 * );
932 * let dZ_dY = derivatives.map(|d| d.at_tensor(&Y));
933 * assert_eq!(
934 * dZ_dY,
935 * Tensor::from([("rows", 2), ("columns", 2)], vec![
936 * Tensor::from([("rows", 2), ("columns", 2)], vec![
937 * -0.5, 0.0,
938 * 0.0, 0.0
939 * ]),
940 * Tensor::from([("rows", 2), ("columns", 2)], vec![
941 * 0.0, -0.25,
942 * 0.0, 0.0
943 * ]),
944 * Tensor::from([("rows", 2), ("columns", 2)], vec![
945 * 0.0, 0.0,
946 * -0.25, 0.0
947 * ]),
948 * Tensor::from([("rows", 2), ("columns", 2)], vec![
949 * 0.0, 0.0,
950 * 0.0, -0.16666667
951 * ])
952 * ])
953 * );
954 * ```
955 *
956 * # Panics
957 *
958 * - If both record containers have a WengertList that are different to each other
959 * - If the record containers have different shapes
960 */
961 #[track_caller]
962 pub fn binary<S2>(
963 &self,
964 rhs: &RecordTensor<'a, T, S2, D>,
965 fxy: impl Fn(T, T) -> T,
966 dfxy_dx: impl Fn(T, T) -> T,
967 dfxy_dy: impl Fn(T, T) -> T,
968 ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
969 where
970 S2: TensorRef<(T, Index), D>,
971 {
972 {
973 let left_shape = self.numbers.shape();
974 let right_shape = rhs.numbers.shape();
975 if left_shape != right_shape {
976 panic!(
977 "Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
978 left_shape, right_shape
979 );
980 }
981 }
982 let total = self.elements();
983 match (self.history, rhs.history) {
984 (None, None) => RecordTensor::constants(
985 // use direct_from here maybe?
986 Tensor::from(
987 self.numbers.shape(),
988 self.numbers
989 .iter()
990 .zip(rhs.numbers.iter())
991 .map(|((x, _), (y, _))| fxy(x, y))
992 .collect(),
993 ),
994 ),
995 (Some(history), None) => {
996 let zs = binary_x_history::<T, _, _>(
997 total,
998 history,
999 self.numbers.iter(),
1000 rhs.numbers.iter(),
1001 fxy,
1002 dfxy_dx,
1003 );
1004 RecordContainer {
1005 numbers: self.numbers.new_with_same_shape(zs),
1006 history: Some(history),
1007 }
1008 }
1009 (None, Some(history)) => {
1010 let zs = binary_y_history::<T, _, _>(
1011 total,
1012 history,
1013 self.numbers.iter(),
1014 rhs.numbers.iter(),
1015 fxy,
1016 dfxy_dy,
1017 );
1018 RecordContainer {
1019 numbers: self.numbers.new_with_same_shape(zs),
1020 history: Some(history),
1021 }
1022 }
1023 (Some(history), Some(h)) => {
1024 assert!(
1025 record_operations::same_lists(history, h),
1026 "Record containers must be using the same WengertList"
1027 );
1028 let zs = binary_both_history::<T, _, _>(
1029 total,
1030 history,
1031 self.numbers.iter(),
1032 rhs.numbers.iter(),
1033 fxy,
1034 dfxy_dx,
1035 dfxy_dy,
1036 );
1037 RecordContainer {
1038 numbers: self.numbers.new_with_same_shape(zs),
1039 history: Some(history),
1040 }
1041 }
1042 }
1043 }
1044
1045 /**
1046 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1047 * some unary function on `Record<T>` to `Record<T>` to every element in the container. This
1048 * will fail if the function would create records with inconsistent histories.
1049 *
1050 * When used with pure functions that can't return different histories for different inputs
     * unwrapping will always succeed.
1052 *
1053 * This API can allow you to call a generic function that operates on
1054 * [Numeric](crate::numeric::Numeric) or [Real](crate::numeric::extra::Real) numbers and
1055 * apply all the correct derivative tracking during the intermediate calculations for you,
1056 * without having to resort to storing the Record types.
1057 *
1058 * ```
1059 * use easy_ml::numeric::extra::Real;
1060 * use easy_ml::tensors::Tensor;
1061 * use easy_ml::differentiation::{RecordTensor, WengertList};
1062 *
1063 * fn sigmoid<T: Real+ Copy>(x: T) -> T {
1064 * T::one() / (T::one() + (-x).exp())
1065 * }
1066 *
1067 * let history = WengertList::new();
1068 * let layer = RecordTensor::variables(&history, Tensor::from([("x", 2)], vec![ 0.2, 0.6 ]));
1069 * let after = layer.map(sigmoid).unwrap(); // sigmoid can't introduce inconsistent histories
1070 * ```
1071 *
1072 * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1073 * after mapping doesn't have to be the same as before, only must be the same for every
1074 * mapped element.
1075 *
1076 * See also: [AsRecords](AsRecords)
1077 */
1078 #[allow(clippy::type_complexity)]
1079 #[track_caller]
1080 pub fn map(
1081 &self,
1082 fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1083 ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
1084 let result = RecordTensor::from_iter(self.shape(), self.iter_as_records().map(fx));
1085 RecordTensor::<'a, T, S, D>::map_collection(result, self.shape())
1086 }
1087
1088 /**
1089 * Returns a new Tensor of the same shape as this RecordContainer by
1090 * applying a function to every element from the Record to the Record's number
1091 * type. For example, you could lookup the derivatives of each Record in
1092 * the container with respect to some input.
1093 */
1094 pub fn map_to_tensor(&self, fx: impl Fn(Record<'a, T>) -> T) -> Tensor<T, D> {
1095 Tensor::from(self.shape(), self.iter_as_records().map(fx).collect())
1096 }
1097
1098 /**
1099 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1100 * some unary function on `Record<T>` and each index of that position in the Record to
1101 * `Record<T>` to every element in the container. This will fail if the function would create
1102 * records with inconsistent histories.
1103 *
1104 * When used with pure functions that can't return different histories for different inputs
     * unwrapping will always succeed.
1106 *
1107 * This API can allow you to call a generic function that operates on
1108 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1109 * during the intermediate calculations for you, without having to resort to storing the
1110 * Record types.
1111 *
1112 * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1113 * after mapping doesn't have to be the same as before, only must be the same for every
1114 * mapped element.
1115 */
1116 #[allow(clippy::type_complexity)]
1117 #[track_caller]
1118 pub fn map_with_index(
1119 &self,
1120 fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
1121 ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
1122 let result = RecordTensor::from_iter(
1123 self.shape(),
1124 self.iter_as_records().with_index().map(|(i, x)| fx(i, x)),
1125 );
1126 RecordTensor::<'a, T, S, D>::map_collection(result, self.shape())
1127 }
1128
1129 #[track_caller]
1130 #[allow(clippy::type_complexity)]
1131 fn map_collection(
1132 result: Result<
1133 RecordTensor<'a, T, Tensor<(T, usize), D>, D>,
1134 InvalidRecordIteratorError<'a, T, D>,
1135 >,
1136 shape: [(Dimension, usize); D],
1137 ) -> Result<RecordTensor<'a, T, Tensor<(T, usize), D>, D>, InconsistentHistory<'a, T>> {
1138 use InvalidRecordIteratorError as Error;
1139 match result {
1140 Ok(tensor) => Ok(tensor),
1141 Err(error) => match error {
1142 // These first two should be 100% impossible but provide a sensible error just
1143 // in case some weird things break our invariants
1144 Error::Empty => panic!("Illegal state, record tensor was empty {:?}", shape),
1145 Error::Shape { requested, length } => panic!(
1146 "Illegal state, record tensor shape was inconsistent: requested: {:?}, length of data: {:?}",
1147 requested, length
1148 ),
1149 // This one is theoretically possible but in practise shouldn't happen by accident
1150 // However, it can't implement Debug unless T is debug so to avoid having to
1151 // restrict our function signature we return a Result anyway - this also encourages
1152 // the user to make sure their function isn't going to cause this case, which
1153 // with some of the other variants like with_index might come up more easily
1154 Error::InconsistentHistory(h) => Err(h),
1155 },
1156 }
1157 }
1158
1159 /**
1160 * For each record in the container, peforms a backward pass up its WengertList from it
1161 * as the output, computing all the derivatives for the inputs involving this output.
1162 *
1163 * If this container has no backing WengertList, ie was created as constants, then None is
1164 * returned instead. Otherwise the returned Tensor will have the same shape as this container,
1165 * with the respective derivatives matching each element in this container.
1166 *
1167 * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is Y with M outputs,
1168 * then this computes all the derivatives δy<sub>j</sub>/δx<sub>i</sub> for i = 1 to N and
1169 * j = 1 to M.
1170 *
1171 * If you have a lot of outputs this could be very expensive! Reverse auto diff is optimised
1172 * for domains where there are many more inputs than outputs.
1173 *
1174 * If you only need some of the derivatives then
1175 * [derivatives_for](RecordTensor::derivatives_for) can be used instead to avoid
1176 * calculating the rest.
1177 */
1178 pub fn derivatives(&self) -> Option<Tensor<Derivatives<T>, D>> {
1179 self.history.map(|history| {
1180 self.numbers.map(|(x, i)| {
1181 Record {
1182 number: x,
1183 history: Some(history),
1184 index: i,
1185 }
1186 .derivatives()
1187 })
1188 })
1189 }
1190
1191 /**
1192 * For the record at the index, peforms a backward pass up its WengertList from it
1193 * as the output, computing all the derivatives for the inputs involving this output.
1194 *
1195 * If the index is invalid or this container has no backing WengertList, ie was created
1196 * as constants, then None is returned instead.
1197 *
1198 * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
1199 * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
1200 */
1201 pub fn derivatives_for(&self, indexes: [usize; D]) -> Option<Derivatives<T>> {
1202 let (number, index) = self.get_reference(indexes).map(|(x, i)| (x.clone(), *i))?;
1203 // The nature of reverse autodiff is that we expect to only have a few outputs from
1204 // which we calculate all the derivatives we care about. Therefore just call Record and
1205 // reuse the implementation instead of trying to do anything clever like calculate all
1206 // derivatives for every number in this container.
1207 Record {
1208 number,
1209 history: self.history,
1210 index,
1211 }
1212 .try_derivatives()
1213 }
1214
1215 /**
1216 * Performs elementwise multiplication for two record tensors of the same shape.
1217 *
1218 * # Panics
1219 *
1220 * - If both record containers have a WengertList that are different to each other
1221 * - If the record containers have different shapes
1222 */
1223 // TODO: Assign variants?
1224 pub fn elementwise_multiply<S2>(
1225 &self,
1226 other: &RecordTensor<'a, T, S2, D>,
1227 ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1228 where
1229 S2: TensorRef<(T, Index), D>,
1230 {
1231 self.binary(
1232 other,
1233 Multiplication::<T>::function,
1234 Multiplication::<T>::d_function_dx,
1235 Multiplication::<T>::d_function_dy,
1236 )
1237 }
1238
1239 /**
1240 * Performs elementwise division for two record tensors of the same shape.
1241 *
1242 * # Panics
1243 *
1244 * - If both record containers have a WengertList that are different to each other
1245 * - If the record containers have different shapes
1246 */
1247 pub fn elementwise_divide<S2>(
1248 &self,
1249 other: &RecordTensor<'a, T, S2, D>,
1250 ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1251 where
1252 S2: TensorRef<(T, Index), D>,
1253 {
1254 self.binary(
1255 other,
1256 Division::<T>::function,
1257 Division::<T>::d_function_dx,
1258 Division::<T>::d_function_dy,
1259 )
1260 }
1261}
1262
1263impl<T: Clone + Primitive> Derivatives<T> {
1264 /**
1265 * Queries the derivative at the provided index into the record tensor as input.
1266 *
1267 * If you construct a Derivatives object for some output y,
1268 * and call .at_tensor_index(i, &xs) on it for some input container xs and index i, this
1269 * returns dy/dx where x = xs\[i\].
1270 *
1271 * If the index into the tensor is invalid, returns None instead.
1272 */
1273 pub fn at_tensor_index<S, const D: usize>(
1274 &self,
1275 indexes: [usize; D],
1276 input: &RecordTensor<T, S, D>,
1277 ) -> Option<T>
1278 where
1279 S: TensorRef<(T, Index), D>,
1280 {
1281 let index = input.get_reference(indexes).map(|(_, i)| *i)?;
1282 Some(self.derivatives[index].clone())
1283 }
1284
1285 /**
1286 * Queries the derivatives at every element in the record tensor input.
1287 *
1288 * If you construct a Derivatives object for some output y,
1289 * and call .at_tensor(&xs) on it for some input container xs this
1290 * returns dy/dx for every x in xs.
1291 */
1292 pub fn at_tensor<S, const D: usize>(&self, input: &RecordTensor<T, S, D>) -> Tensor<T, D>
1293 where
1294 S: TensorRef<(T, Index), D>,
1295 {
1296 input.numbers.map(|(_, i)| self.derivatives[i].clone())
1297 }
1298
1299 /**
1300 * Queries the derivative at the provided index into the record matrix as input.
1301 *
1302 * If you construct a Derivatives object for some output y,
1303 * and call .at_matrix_index(i, j, &xs) on it for some input container xs and indexes i and j,
1304 * this returns dy/dx where x = xs\[i, j\].
1305 *
1306 * If the index into the tensor is invalid, returns None instead.
1307 */
1308 pub fn at_matrix_index<S>(
1309 &self,
1310 row: Row,
1311 column: Column,
1312 input: &RecordMatrix<T, S>,
1313 ) -> Option<T>
1314 where
1315 S: MatrixRef<(T, Index)> + NoInteriorMutability,
1316 {
1317 let index = input.try_get_reference(row, column).map(|(_, i)| *i)?;
1318 Some(self.derivatives[index].clone())
1319 }
1320
1321 /**
1322 * Queries the derivatives at every element in the record matrix input.
1323 *
1324 * If you construct a Derivatives object for some output y,
1325 * and call .at_matrix(&xs) on it for some input container xs this
1326 * returns dy/dx for every x in xs.
1327 */
1328 pub fn at_matrix<S>(&self, input: &RecordMatrix<T, S>) -> Matrix<T>
1329 where
1330 S: MatrixRef<(T, Index)> + NoInteriorMutability,
1331 {
1332 input.numbers.map(|(_, i)| self.derivatives[i].clone())
1333 }
1334}
1335
1336impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
1337where
1338 T: Numeric + Primitive,
1339 for<'t> &'t T: NumericRef<T>,
1340 S: TensorMut<(T, Index), D>,
1341{
1342 /**
1343 * Overwrites a RecordContainer by applying
1344 * some unary function from `T` to `T` to every element in the container.
1345 *
1346 * To compute the new records, the unary function of some input x to some
1347 * output y is needed along with its derivative with respect to its input x.
1348 */
1349 #[track_caller]
1350 pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
1351 let total = self.elements();
1352 match self.history {
1353 None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
1354 Some(history) => {
1355 let ys = unary::<T, _>(total, history, self.numbers.iter(), fx, dfx_dx);
1356 for (element, result) in self.numbers.iter_reference_mut().zip(ys) {
1357 *element = result;
1358 }
1359 self.history = Some(history);
1360 }
1361 }
1362 }
1363
1364 /**
1365 * Overwrites the left hand side of a RecordContainer with the result of applying
1366 * some binary function from `T` to `T` to every element pair in the containers. Both
1367 * containers must have the same shape.
1368 * To compute the new records, the binary function of some inputs x and y to some
1369 * output z is needed along with its derivative with respect to its first input x and
1370 * its derivative with respect to its second input y.
1371 *
1372 * # Panics
1373 *
1374 * - If both record containers have a WengertList that are different to each other
1375 * - If the record containers have different shapes
1376 */
1377 #[track_caller]
1378 pub fn binary_left_assign<S2>(
1379 &mut self,
1380 rhs: &RecordTensor<'a, T, S2, D>,
1381 fxy: impl Fn(T, T) -> T,
1382 dfxy_dx: impl Fn(T, T) -> T,
1383 dfxy_dy: impl Fn(T, T) -> T,
1384 ) where
1385 S2: TensorRef<(T, Index), D>,
1386 {
1387 {
1388 let left_shape = self.numbers.shape();
1389 let right_shape = rhs.numbers.shape();
1390 if left_shape != right_shape {
1391 panic!(
1392 "Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
1393 left_shape, right_shape
1394 );
1395 }
1396 }
1397 let total = self.elements();
1398 match (self.history, rhs.history) {
1399 (None, None) => {
1400 for (x, y) in self.numbers.iter_reference_mut().zip(rhs.numbers.iter()) {
1401 let (left, _) = x;
1402 let (right, _) = y;
1403 *x = (fxy(left.clone(), right), 0);
1404 }
1405 }
1406 (Some(history), None) => {
1407 let zs = binary_x_history::<T, _, _>(
1408 total,
1409 history,
1410 self.numbers.iter(),
1411 rhs.numbers.iter(),
1412 fxy,
1413 dfxy_dx,
1414 );
1415 for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1416 *element = result;
1417 }
1418 self.history = Some(history);
1419 }
1420 (None, Some(history)) => {
1421 let zs = binary_y_history::<T, _, _>(
1422 total,
1423 history,
1424 self.numbers.iter(),
1425 rhs.numbers.iter(),
1426 fxy,
1427 dfxy_dy,
1428 );
1429 for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1430 *element = result;
1431 }
1432 self.history = Some(history);
1433 }
1434 (Some(history), Some(h)) => {
1435 assert!(
1436 record_operations::same_lists(history, h),
1437 "Record containers must be using the same WengertList"
1438 );
1439 let zs = binary_both_history::<T, _, _>(
1440 total,
1441 history,
1442 self.numbers.iter(),
1443 rhs.numbers.iter(),
1444 fxy,
1445 dfxy_dx,
1446 dfxy_dy,
1447 );
1448 for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1449 *element = result;
1450 }
1451 self.history = Some(history);
1452 }
1453 }
1454 }
1455
1456 /**
1457 * A convenience helper function which takes the RecordContainer value and
1458 * calls [unary_assign](RecordTensor::unary_assign()) on it, returning
1459 * the record container which now contains the result of the operation.
1460 */
1461 #[track_caller]
1462 pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
1463 self.unary_assign(fx, dfx_dx);
1464 self
1465 }
1466
1467 /**
1468 * A convenience helper function which takes the left hand side by value and
1469 * calls [binary_left_assign](RecordTensor::binary_left_assign()) on it, returning
1470 * the left hand side which now contains the result of the operation.
1471 */
1472 #[track_caller]
1473 pub fn do_binary_left_assign<S2>(
1474 mut self,
1475 rhs: &RecordTensor<'a, T, S2, D>,
1476 fxy: impl Fn(T, T) -> T,
1477 dfxy_dx: impl Fn(T, T) -> T,
1478 dfxy_dy: impl Fn(T, T) -> T,
1479 ) -> Self
1480 where
1481 S2: TensorRef<(T, Index), D>,
1482 {
1483 self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
1484 self
1485 }
1486
1487 /**
1488 * Updates this RecordContainer in place by applying some unary function on `Record<T>` to
1489 * `Record<T>` to every element in the container. This will fail if the function would create
1490 * records with inconsistent histories.
1491 *
1492 * When used with pure functions that can't return different histories for different inputs
     * unwrapping will always succeed.
1494 *
1495 * Since this updates the container in place, if Err is returned then the data in this
1496 * RecordContainer is still available but it has been corrupted - at least one of the elements
1497 * should have a different history than what it will have because the mapping function created
1498 * inconsistent histories that couldn't be represented by the container as it only stores
1499 * one.
1500 *
1501 * This API can allow you to call a generic function that operates on
1502 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1503 * during the intermediate calculations for you, without having to resort to storing the
1504 * Record types.
1505 *
1506 * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1507 * after mapping doesn't have to be the same as before, only must be the same for every
1508 * mapped element.
1509 *
1510 * You might also use this function at the end of a training loop to update all the weights
1511 * to reduce their loss.
1512 *
1513 * ```
1514 * use easy_ml::numeric::Numeric;
1515 * use easy_ml::tensors::Tensor;
1516 * use easy_ml::differentiation::{Record, RecordTensor, WengertList};
1517 *
1518 * let history = WengertList::new();
1519 * let mut weights = RecordTensor::variables(
1520 * &history,
1521 * Tensor::from([("w1", 4)], vec![ 0.3, 0.2, -1.2, -0.4 ])
1522 * );
1523 * let error = {
1524 * // this is over-simplified for brevity, obviously in a real scenario we wouldn't have a
1525 * // function that calculates the error like this or we wouldn't be doing machine learning
1526 * // to fit it in the first place
1527 * let mut loss = Record::variable(0.0, &history);
1528 * for r in weights.iter_as_records() {
1529 * loss = loss + r;
1530 * }
1531 * loss
1532 * };
1533 * let derivatives = error.derivatives();
1534 * let learning_rate = 0.1;
1535 * // update the weights to contain less error than they did before
1536 * let result = weights.map_mut(|x| x - (derivatives[&x] * learning_rate));
1537 * assert!(result.is_ok()); // we know we didn't introduce an inconsistent history just updating the weights
1538 * ```
1539 */
1540 #[track_caller]
1541 pub fn map_mut(
1542 &mut self,
1543 fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1544 ) -> Result<(), InconsistentHistory<'a, T>> {
1545 let history = self.history;
1546 let new_history =
1547 map_mut_base::<'a, T, _, _>(TensorReferenceMutIterator::from(self), |x| {
1548 let record = Record::from_existing(x.clone(), history);
1549 let result = fx(record);
1550 *x = (result.number, result.index);
1551 result.history
1552 })?;
1553 self.history = new_history;
1554 Ok(())
1555 }
1556
1557 /**
1558 * Updates this RecordContainer in place by applying some unary function on `Record<T>` and
1559 * each index of that position in the Record to `Record<T>` to every element in the container.
1560 * This will fail if the function would create records with inconsistent histories.
1561 *
1562 * When used with pure functions that can't return different histories for different inputs
     * unwrapping will always succeed.
1564 *
1565 * Since this updates the container in place, if Err is returned then the data in this
1566 * RecordContainer is still available but it has been corrupted - at least one of the elements
1567 * should have a different history than what it will have because the mapping function created
1568 * inconsistent histories that couldn't be represented by the container as it only stores
1569 * one.
1570 *
1571 * This API can allow you to call a generic function that operates on
1572 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1573 * during the intermediate calculations for you, without having to resort to storing the
1574 * Record types.
1575 *
1576 * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1577 * after mapping doesn't have to be the same as before, only must be the same for every
1578 * mapped element.
1579 */
1580 #[track_caller]
1581 pub fn map_mut_with_index(
1582 &mut self,
1583 fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
1584 ) -> Result<(), InconsistentHistory<'a, T>> {
1585 let history = self.history;
1586 let new_history = map_mut_base::<'a, T, _, _>(
1587 TensorReferenceMutIterator::from(self).with_index(),
1588 |(i, x)| {
1589 let record = Record::from_existing(x.clone(), history);
1590 let result = fx(i, record);
1591 *x = (result.number, result.index);
1592 result.history
1593 },
1594 )?;
1595 self.history = new_history;
1596 Ok(())
1597 }
1598}
1599
1600#[track_caller]
1601fn map_mut_base<'a, T, I, X>(
1602 mut iter: I,
1603 fx: impl Fn(X) -> Option<&'a WengertList<T>>,
1604) -> Result<Option<&'a WengertList<T>>, InconsistentHistory<'a, T>>
1605where
1606 I: Iterator<Item = X>,
1607 T: Primitive,
1608{
1609 use crate::differentiation::record_operations::are_exact_same_list;
1610 #[rustfmt::skip]
1611 let first_history = fx(iter.next().expect("Illegal state, record container was empty"));
1612 let mut different_history: Option<Option<&WengertList<T>>> = None;
1613 for x in iter {
1614 let history = fx(x);
1615 if !are_exact_same_list(history, first_history) {
1616 different_history = Some(history);
1617 }
1618 }
1619 match different_history {
1620 None => Ok(first_history),
1621 Some(h) => Err(InconsistentHistory {
1622 first: first_history,
1623 later: h,
1624 }),
1625 }
1626}
1627
1628impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
1629where
1630 T: Numeric + Primitive,
1631 for<'t> &'t T: NumericRef<T>,
1632 S: TensorRef<(T, Index), D>,
1633{
1634 /**
1635 * Overwrites the right hand side of a RecordContainer with the result of applying
1636 * some binary function from `T` to `T` to every element pair in the containers. Both
1637 * containers must have the same shape.
1638 * To compute the new records, the binary function of some inputs x and y to some
1639 * output z is needed along with its derivative with respect to its first input x and
1640 * its derivative with respect to its second input y.
1641 *
1642 * # Panics
1643 *
1644 * - If both record containers have a WengertList that are different to each other
1645 * - If the record containers have different shapes
1646 */
1647 #[track_caller]
1648 pub fn binary_right_assign<S2>(
1649 &self,
1650 rhs: &mut RecordTensor<'a, T, S2, D>,
1651 fxy: impl Fn(T, T) -> T,
1652 dfxy_dx: impl Fn(T, T) -> T,
1653 dfxy_dy: impl Fn(T, T) -> T,
1654 ) where
1655 S2: TensorMut<(T, Index), D>,
1656 {
1657 // x is lhs, y is rhs, so calling binary_left_assign on the rhs container
1658 // means we need to swap all the arguments
1659 rhs.binary_left_assign(
1660 self,
1661 |y, x| fxy(x, y),
1662 |y, x| dfxy_dy(x, y),
1663 |y, x| dfxy_dx(x, y),
1664 )
1665 }
1666
1667 /**
1668 * A convenience helper function which takes the right hand side by value and
1669 * calls [binary_right_assign](RecordTensor::binary_right_assign()) on it, returning
1670 * the right hand side which now contains the result of the operation.
1671 */
1672 #[track_caller]
1673 pub fn do_binary_right_assign<S2>(
1674 &self,
1675 mut rhs: RecordTensor<'a, T, S2, D>,
1676 fxy: impl Fn(T, T) -> T,
1677 dfxy_dx: impl Fn(T, T) -> T,
1678 dfxy_dy: impl Fn(T, T) -> T,
1679 ) -> RecordTensor<'a, T, S2, D>
1680 where
1681 S2: TensorMut<(T, Index), D>,
1682 {
1683 self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
1684 rhs
1685 }
1686}
1687
1688impl<'a, T, S> RecordMatrix<'a, T, S>
1689where
1690 T: Numeric + Primitive,
1691 for<'t> &'t T: NumericRef<T>,
1692 S: MatrixRef<(T, Index)> + NoInteriorMutability,
1693{
1694 /**
1695 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1696 * some unary function from `T` to `T` to every element in the container.
1697 *
1698 * To compute the new records, the unary function of some input x to some
1699 * output y is needed along with its derivative with respect to its input x.
1700 *
1701 * For example, tanh is a commonly used activation function, but the Real trait
1702 * does not include this operation and Record has no operations for it specifically.
1703 * However, you can use this function to compute the tanh for a record container like so:
1704 *
1705 * ```
1706 * use easy_ml::differentiation::{RecordMatrix, WengertList};
1707 * use easy_ml::matrices::Matrix;
1708 * let list = WengertList::new();
1709 * let X = RecordMatrix::variables(
1710 * &list,
1711 * Matrix::from_fn((2, 2), |(r, c)| 0.15 * ((1 + r + c) as f32))
1712 * );
1713 * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
1714 * // 1 / (cosh(x) * cosh(x))
1715 * let Y = X.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
1716 *
1717 * // we can unwrap here because we know Y contains variables not constants
1718 * let derivatives = Y.derivatives().unwrap();
1719 * assert_eq!(
1720 * derivatives.get_reference(0, 0).at_matrix(&X),
1721 * Matrix::from(vec![
1722 * // (0, 0) element in Y only had the one input variable (0, 0) in X
1723 * vec![0.9778332, 0.0],
1724 * vec![0.0, 0.0]
1725 * ]),
1726 * );
1727 * assert_eq!(
1728 * derivatives.get_reference(0, 1).at_matrix(&X),
1729 * Matrix::from(vec![
1730 * vec![0.0, 0.915137],
1731 * vec![0.0, 0.0]
1732 * ]),
1733 * );
1734 * assert_eq!(
1735 * // (0, 1) and (1, 0) elements in X had the same starting value so end up with the same
1736 * // derivative for their corresponding input variable in X
1737 * derivatives.get_reference(0, 1).at_matrix(&X).get(0, 1),
1738 * derivatives.get_reference(1, 0).at_matrix(&X).get(1, 0),
1739 * );
1740 * assert_eq!(
1741 * derivatives.get_reference(1, 1).at_matrix(&X),
1742 * Matrix::from(vec![
1743 * vec![0.0, 0.0 ],
1744 * vec![0.0, 0.8220013]
1745 * ]),
1746 * );
1747 * ```
1748 */
1749 #[track_caller]
1750 pub fn unary(
1751 &self,
1752 fx: impl Fn(T) -> T,
1753 dfx_dx: impl Fn(T) -> T,
1754 ) -> RecordMatrix<'a, T, Matrix<(T, Index)>> {
1755 let total = self.elements();
1756 match self.history {
1757 None => RecordMatrix::constants(self.numbers.map(|(x, _)| fx(x))),
1758 Some(history) => {
1759 let ys = unary::<T, _>(total, history, self.numbers.row_major_iter(), fx, dfx_dx);
1760 RecordContainer {
1761 numbers: MatrixView::from(Matrix::from_flat_row_major(self.numbers.size(), ys)),
1762 history: Some(history),
1763 }
1764 }
1765 }
1766 }
1767
1768 /**
1769 * Creates a new RecordContainer from two RecordContainers by applying
1770 * some binary function from `T` to `T` to every element pair in the containers. Both
1771 * containers must have the same shape.
1772 *
1773 * To compute the new records, the binary function of some inputs x and y to some
1774 * output z is needed along with its derivative with respect to its first input x and
1775 * its derivative with respect to its second input y.
1776 *
1777 * For example, atan2 takes two arguments, but the Real trait
1778 * does not include this operation and Record has no operations for it specifically.
1779 * However, you can use this function to compute the atan2 for two record containers like so:
1780 *
1781 * ```
1782 * use easy_ml::differentiation::{RecordMatrix, WengertList};
1783 * use easy_ml::matrices::Matrix;
1784 * let list = WengertList::new();
1785 * let X = RecordMatrix::variables(
1786 * &list,
1787 * Matrix::from_fn((2, 2), |(r, c)| ((1 + r + c) as f32))
1788 * );
1789 * let Y = RecordMatrix::variables(
1790 * &list,
1791 * Matrix::from_fn((2, 2), |(r, c)| ((1 + r + c) as f32))
1792 * );
1793 * // the derivative of atan2 with respect to x is y/(x*x + y*y)
1794 * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
1795 * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
1796 * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
1797 * let Z = X.binary(&Y,
1798 * |x, y| x.atan2(y),
1799 * |x, y| y/((x*x) + (y*y)),
1800 * |x, y| -x/((x*x) + (y*y))
1801 * );
1802 *
1803 * // we can unwrap here because we know Z contains variables not constants
1804 * let derivatives = Z.derivatives().unwrap();
1805 * // Just as in the unary example, only one pair of the four inputs in X and Y influence the
1806 * // outputs in Z, so we have a lot of 0.0 derivatives, and the inputs in [0, 1] and [1, 0]
1807 * // are identical so we see the same derivative.
1808 * let dZ_dX = derivatives.map(|d| d.at_matrix(&X));
1809 * assert_eq!(
1810 * dZ_dX,
1811 * Matrix::from(vec![
1812 * vec![
1813 * Matrix::from(vec![
1814 * vec![ 0.5, 0.0 ],
1815 * vec![ 0.0, 0.0 ]
1816 * ]),
1817 * Matrix::from(vec![
1818 * vec![ 0.0, 0.25 ],
1819 * vec![ 0.0, 0.0 ]
1820 * ])
1821 * ],
1822 * vec![
1823 * Matrix::from(vec![
1824 * vec![ 0.0, 0.0 ],
1825 * vec![ 0.25, 0.0 ]
1826 * ]),
1827 * Matrix::from(vec![
1828 * vec![ 0.0, 0.0 ],
1829 * vec![ 0.0, 0.16666667 ]
1830 * ])
1831 * ]
1832 * ])
1833 * );
1834 * let dZ_dY = derivatives.map(|d| d.at_matrix(&Y));
1835 * assert_eq!(
1836 * dZ_dY,
1837 * Matrix::from(vec![
1838 * vec![
1839 * Matrix::from(vec![
1840 * vec![ -0.5, 0.0 ],
1841 * vec![ 0.0, 0.0 ]
1842 * ]),
1843 * Matrix::from(vec![
1844 * vec![ 0.0, -0.25 ],
1845 * vec![ 0.0, 0.0 ]
1846 * ])
1847 * ],
1848 * vec![
1849 * Matrix::from(vec![
1850 * vec![ 0.0, 0.0 ],
1851 * vec![ -0.25, 0.0 ]
1852 * ]),
1853 * Matrix::from(vec![
1854 * vec![ 0.0, 0.0 ],
1855 * vec![ 0.0, -0.16666667 ]
1856 * ])
1857 * ]
1858 * ])
1859 * );
1860 * ```
1861 *
1862 * # Panics
1863 *
1864 * - If both record containers have a WengertList that are different to each other
1865 * - If the record containers have different shapes
1866 */
1867 #[track_caller]
1868 pub fn binary<S2>(
1869 &self,
1870 rhs: &RecordMatrix<'a, T, S2>,
1871 fxy: impl Fn(T, T) -> T,
1872 dfxy_dx: impl Fn(T, T) -> T,
1873 dfxy_dy: impl Fn(T, T) -> T,
1874 ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1875 where
1876 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1877 {
1878 let shape = {
1879 let left_shape = self.numbers.size();
1880 let right_shape = rhs.numbers.size();
1881 if left_shape != right_shape {
1882 panic!(
1883 "Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
1884 left_shape, right_shape
1885 );
1886 }
1887 left_shape
1888 };
1889 let total = self.elements();
1890 match (self.history, rhs.history) {
1891 (None, None) => RecordMatrix::constants(Matrix::from_flat_row_major(
1892 shape,
1893 self.numbers
1894 .row_major_iter()
1895 .zip(rhs.numbers.row_major_iter())
1896 .map(|((x, _), (y, _))| fxy(x, y))
1897 .collect(),
1898 )),
1899 (Some(history), None) => {
1900 let zs = binary_x_history::<T, _, _>(
1901 total,
1902 history,
1903 self.numbers.row_major_iter(),
1904 rhs.numbers.row_major_iter(),
1905 fxy,
1906 dfxy_dx,
1907 );
1908 RecordContainer {
1909 numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1910 history: Some(history),
1911 }
1912 }
1913 (None, Some(history)) => {
1914 let zs = binary_y_history::<T, _, _>(
1915 total,
1916 history,
1917 self.numbers.row_major_iter(),
1918 rhs.numbers.row_major_iter(),
1919 fxy,
1920 dfxy_dy,
1921 );
1922 RecordContainer {
1923 numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1924 history: Some(history),
1925 }
1926 }
1927 (Some(history), Some(h)) => {
1928 assert!(
1929 record_operations::same_lists(history, h),
1930 "Record containers must be using the same WengertList"
1931 );
1932 let zs = binary_both_history::<T, _, _>(
1933 total,
1934 history,
1935 self.numbers.row_major_iter(),
1936 rhs.numbers.row_major_iter(),
1937 fxy,
1938 dfxy_dx,
1939 dfxy_dy,
1940 );
1941 RecordContainer {
1942 numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1943 history: Some(history),
1944 }
1945 }
1946 }
1947 }
1948
1949 /**
1950 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1951 * some unary function on `Record<T>` to `Record<T>` to every element in the container. This
1952 * will fail if the function would create records with inconsistent histories.
1953 *
1954 * When used with pure functions that can't return different histories for different inputs
 * unwrapping will always succeed.
1956 *
1957 * This API can allow you to call a generic function that operates on
1958 * [Numeric](crate::numeric::Numeric) or [Real](crate::numeric::extra::Real) numbers and
1959 * apply all the correct derivative tracking during the intermediate calculations for you,
1960 * without having to resort to storing the Record types.
1961 *
1962 * ```
1963 * use easy_ml::numeric::extra::Real;
1964 * use easy_ml::matrices::Matrix;
1965 * use easy_ml::differentiation::{RecordMatrix, WengertList};
1966 *
1967 * fn sigmoid<T: Real+ Copy>(x: T) -> T {
1968 * T::one() / (T::one() + (-x).exp())
1969 * }
1970 *
1971 * let history = WengertList::new();
1972 * let layer = RecordMatrix::variables(&history, Matrix::from(vec![vec![ 0.2, 0.6 ]]));
1973 * let after = layer.map(sigmoid).unwrap(); // sigmoid can't introduce inconsistent histories
1974 * ```
1975 *
1976 * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
1977 * after mapping doesn't have to be the same as before, only must be the same for every
1978 * mapped element.
1979 *
1980 * See also: [AsRecords](AsRecords)
1981 */
1982 #[allow(clippy::type_complexity)]
1983 #[track_caller]
1984 pub fn map(
1985 &self,
1986 fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1987 ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
1988 let result = RecordMatrix::from_iter(self.size(), self.iter_row_major_as_records().map(fx));
1989 RecordMatrix::<'a, T, S>::map_collection(result, self.size())
1990 }
1991
1992 /**
1993 * Returns a new Matrix of the same shape as this RecordContainer by
1994 * applying a function to every element from the Record to the Record's number
1995 * type. For example, you could lookup the derivatives of each Record in
1996 * the container with respect to some input.
1997 */
1998 pub fn map_to_matrix(&self, fx: impl Fn(Record<'a, T>) -> T) -> Matrix<T> {
1999 Matrix::from_flat_row_major(
2000 self.size(),
2001 self.iter_row_major_as_records().map(fx).collect(),
2002 )
2003 }
2004
2005 /**
2006 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
2007 * some unary function on `Record<T>` and each index of that position in the Record to
2008 * `Record<T>` to every element in the container. This will fail if the function would
2009 * create records with inconsistent histories.
2010 *
2011 * When used with pure functions that can't return different histories for different inputs
 * unwrapping will always succeed.
2013 *
2014 * This API can allow you to call a generic function that operates on
2015 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2016 * during the intermediate calculations for you, without having to resort to storing the
2017 * Record types.
2018 *
2019 * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2020 * after mapping doesn't have to be the same as before, only must be the same for every
2021 * mapped element.
2022 */
2023 #[allow(clippy::type_complexity)]
2024 #[track_caller]
2025 pub fn map_with_index(
2026 &self,
2027 fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
2028 ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
2029 let result = RecordMatrix::from_iter(
2030 self.size(),
2031 self.iter_row_major_as_records()
2032 .with_index()
2033 .map(|((r, c), x)| fx(x, r, c)),
2034 );
2035 RecordMatrix::<'a, T, S>::map_collection(result, self.size())
2036 }
2037
2038 #[allow(clippy::type_complexity)]
2039 #[track_caller]
2040 fn map_collection(
2041 result: Result<
2042 RecordMatrix<'a, T, Matrix<(T, usize)>>,
2043 InvalidRecordIteratorError<'a, T, 2>,
2044 >,
2045 size: (Row, Column),
2046 ) -> Result<RecordMatrix<'a, T, Matrix<(T, usize)>>, InconsistentHistory<'a, T>> {
2047 use InvalidRecordIteratorError as Error;
2048 match result {
2049 Ok(matrix) => Ok(matrix),
2050 Err(error) => match error {
2051 // These first two should be 100% impossible but provide a sensible error just
2052 // in case some weird things break our invariants
2053 Error::Empty => panic!("Illegal state, record matrix was empty {:?}", size),
2054 Error::Shape { requested, length } => panic!(
2055 "Illegal state, record matrix shape was inconsistent: requested: {:?}, length of data: {:?}",
2056 requested, length
2057 ),
2058 // This one is theoretically possible but in practise shouldn't happen by accident
2059 // However, it can't implement Debug unless T is debug so to avoid having to
2060 // restrict our function signature we return a Result anyway - this also encourages
2061 // the user to make sure their function isn't going to cause this case, which
2062 // with some of the other variants like with_index might come up more easily
2063 Error::InconsistentHistory(h) => Err(h),
2064 },
2065 }
2066 }
2067
2068 /**
 * For each record in the container, performs a backward pass up its WengertList from it
2070 * as the output, computing all the derivatives for the inputs involving this output.
2071 *
2072 * If this container has no backing WengertList, ie was created as constants, then None is
2073 * returned instead. Otherwise the returned Matrix will have the same size as this container,
2074 * with the respective derivatives matching each element in this container.
2075 *
2076 * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is Y with M outputs,
2077 * then this computes all the derivatives δy<sub>j</sub>/δx<sub>i</sub> for i = 1 to N and
2078 * j = 1 to M.
2079 *
2080 * If you have a lot of outputs this could be very expensive! Reverse auto diff is optimised
2081 * for domains where there are many more inputs than outputs.
2082 *
2083 * If you only need some of the derivatives then
2084 * [derivatives_for](RecordMatrix::derivatives_for) can be used instead to avoid
2085 * calculating the rest.
2086 */
2087 pub fn derivatives(&self) -> Option<Matrix<Derivatives<T>>> {
2088 self.history.map(|history| {
2089 self.numbers.map(|(x, i)| {
2090 Record {
2091 number: x,
2092 history: Some(history),
2093 index: i,
2094 }
2095 .derivatives()
2096 })
2097 })
2098 }
2099
2100 /**
 * For the record at the index, performs a backward pass up its WengertList from it
2102 * as the output, computing all the derivatives for the inputs involving this output.
2103 *
2104 * If the index is invalid or this container has no backing WengertList, ie was created
2105 * as constants, then None is returned instead.
2106 *
2107 * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
2108 * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
2109 */
2110 pub fn derivatives_for(&self, row: Row, column: Column) -> Option<Derivatives<T>> {
2111 let (number, index) = self
2112 .try_get_reference(row, column)
2113 .map(|(x, i)| (x.clone(), *i))?;
2114 // The nature of reverse autodiff is that we expect to only have a few outputs from
2115 // which we calculate all the derivatives we care about. Therefore just call Record and
2116 // reuse the implementation instead of trying to do anything clever like calculate all
2117 // derivatives for every number in this container.
2118 Record {
2119 number,
2120 history: self.history,
2121 index,
2122 }
2123 .try_derivatives()
2124 }
2125
2126 /**
2127 * Performs elementwise multiplication for two record matrices of the same size.
2128 *
2129 * # Panics
2130 *
2131 * - If both record containers have a WengertList that are different to each other
2132 * - If the record containers have different shapes
2133 */
2134 // TODO: Assign variants?
2135 pub fn elementwise_multiply<S2>(
2136 &self,
2137 other: &RecordMatrix<'a, T, S2>,
2138 ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
2139 where
2140 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2141 {
2142 self.binary(
2143 other,
2144 Multiplication::<T>::function,
2145 Multiplication::<T>::d_function_dx,
2146 Multiplication::<T>::d_function_dy,
2147 )
2148 }
2149
2150 /**
2151 * Performs elementwise division for two record matrices of the same size.
2152 *
2153 * # Panics
2154 *
2155 * - If both record containers have a WengertList that are different to each other
2156 * - If the record containers have different shapes
2157 */
2158 pub fn elementwise_divide<S2>(
2159 &self,
2160 other: &RecordMatrix<'a, T, S2>,
2161 ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
2162 where
2163 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2164 {
2165 self.binary(
2166 other,
2167 Division::<T>::function,
2168 Division::<T>::d_function_dx,
2169 Division::<T>::d_function_dy,
2170 )
2171 }
2172}
2173
2174impl<'a, T, S> RecordMatrix<'a, T, S>
2175where
2176 T: Numeric + Primitive,
2177 for<'t> &'t T: NumericRef<T>,
2178 S: MatrixMut<(T, Index)> + NoInteriorMutability,
2179{
2180 /**
2181 * Overwrites a RecordContainer by applying
2182 * some unary function from `T` to `T` to every element in the container.
2183 *
2184 * To compute the new records, the unary function of some input x to some
2185 * output y is needed along with its derivative with respect to its input x.
2186 */
2187 #[track_caller]
2188 pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
2189 let total = self.elements();
2190 match self.history {
2191 None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
2192 Some(history) => {
2193 let ys = unary::<T, _>(total, history, self.numbers.row_major_iter(), fx, dfx_dx);
2194 for (element, result) in self.numbers.row_major_reference_mut_iter().zip(ys) {
2195 *element = result;
2196 }
2197 self.history = Some(history);
2198 }
2199 }
2200 }
2201
2202 /**
2203 * Overwrites the left hand side of a RecordContainer with the result of applying
2204 * some binary function from `T` to `T` to every element pair in the containers. Both
2205 * containers must have the same shape.
2206 * To compute the new records, the binary function of some inputs x and y to some
2207 * output z is needed along with its derivative with respect to its first input x and
2208 * its derivative with respect to its second input y.
2209 *
2210 * # Panics
2211 *
2212 * - If both record containers have a WengertList that are different to each other
2213 * - If the record containers have different shapes
2214 */
2215 #[track_caller]
2216 pub fn binary_left_assign<S2>(
2217 &mut self,
2218 rhs: &RecordMatrix<'a, T, S2>,
2219 fxy: impl Fn(T, T) -> T,
2220 dfxy_dx: impl Fn(T, T) -> T,
2221 dfxy_dy: impl Fn(T, T) -> T,
2222 ) where
2223 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2224 {
2225 {
2226 let left_shape = self.numbers.size();
2227 let right_shape = rhs.numbers.size();
2228 if left_shape != right_shape {
2229 panic!(
2230 "Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
2231 left_shape, right_shape
2232 );
2233 }
2234 }
2235 let total = self.elements();
2236 match (self.history, rhs.history) {
2237 (None, None) => {
2238 for (x, y) in self
2239 .numbers
2240 .row_major_reference_mut_iter()
2241 .zip(rhs.numbers.row_major_iter())
2242 {
2243 let (left, _) = x;
2244 let (right, _) = y;
2245 *x = (fxy(left.clone(), right), 0);
2246 }
2247 }
2248 (Some(history), None) => {
2249 let zs = binary_x_history::<T, _, _>(
2250 total,
2251 history,
2252 self.numbers.row_major_iter(),
2253 rhs.numbers.row_major_iter(),
2254 fxy,
2255 dfxy_dx,
2256 );
2257 for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2258 *element = result;
2259 }
2260 self.history = Some(history);
2261 }
2262 (None, Some(history)) => {
2263 let zs = binary_y_history::<T, _, _>(
2264 total,
2265 history,
2266 self.numbers.row_major_iter(),
2267 rhs.numbers.row_major_iter(),
2268 fxy,
2269 dfxy_dy,
2270 );
2271 for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2272 *element = result;
2273 }
2274 self.history = Some(history);
2275 }
2276 (Some(history), Some(h)) => {
2277 assert!(
2278 record_operations::same_lists(history, h),
2279 "Record containers must be using the same WengertList"
2280 );
2281 let zs = binary_both_history::<T, _, _>(
2282 total,
2283 history,
2284 self.numbers.row_major_iter(),
2285 rhs.numbers.row_major_iter(),
2286 fxy,
2287 dfxy_dx,
2288 dfxy_dy,
2289 );
2290 for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2291 *element = result;
2292 }
2293 self.history = Some(history);
2294 }
2295 }
2296 }
2297
2298 /**
2299 * A convenience helper function which takes the RecordContainer value and
2300 * calls [unary_assign](RecordMatrix::unary_assign()) on it, returning
2301 * the record container which now contains the result of the operation.
2302 */
2303 #[track_caller]
2304 pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
2305 self.unary_assign(fx, dfx_dx);
2306 self
2307 }
2308
2309 /**
2310 * A convenience helper function which takes the left hand side by value and
2311 * calls [binary_left_assign](RecordMatrix::binary_left_assign()) on it, returning
2312 * the left hand side which now contains the result of the operation.
2313 */
2314 #[track_caller]
2315 pub fn do_binary_left_assign<S2>(
2316 mut self,
2317 rhs: &RecordMatrix<'a, T, S2>,
2318 fxy: impl Fn(T, T) -> T,
2319 dfxy_dx: impl Fn(T, T) -> T,
2320 dfxy_dy: impl Fn(T, T) -> T,
2321 ) -> Self
2322 where
2323 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2324 {
2325 self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
2326 self
2327 }
2328
2329 /**
2330 * Updates this RecordContainer in place by applying some unary function on `Record<T>` to
2331 * `Record<T>` to every element in the container. This will fail if the function would create
2332 * records with inconsistent histories.
2333 *
2334 * When used with pure functions that can't return different histories for different inputs
 * unwrapping will always succeed.
2336 *
2337 * Since this updates the container in place, if Err is returned then the data in this
2338 * RecordContainer is still available but it has been corrupted - at least one of the elements
2339 * should have a different history than what it will have because the mapping function created
2340 * inconsistent histories that couldn't be represented by the container as it only stores
2341 * one.
2342 *
2343 * This API can allow you to call a generic function that operates on
2344 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2345 * during the intermediate calculations for you, without having to resort to storing the
2346 * Record types.
2347 *
2348 * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2349 * after mapping doesn't have to be the same as before, only must be the same for every
2350 * mapped element.
2351 *
2352 * You might also use this function at the end of a training loop to update all the weights
2353 * to reduce their loss.
2354 *
2355 * ```
2356 * use easy_ml::numeric::Numeric;
2357 * use easy_ml::matrices::Matrix;
2358 * use easy_ml::differentiation::{Record, RecordMatrix, WengertList};
2359 *
2360 * let history = WengertList::new();
2361 * let mut weights = RecordMatrix::variables(
2362 * &history,
2363 * Matrix::from(vec![vec![ 0.3, 0.2, -1.2, -0.4 ]])
2364 * );
2365 * let error = {
2366 * // this is over-simplified for brevity, obviously in a real scenario we wouldn't have a
2367 * // function that calculates the error like this or we wouldn't be doing machine learning
2368 * // to fit it in the first place
2369 * let mut loss = Record::variable(0.0, &history);
2370 * for r in weights.iter_row_major_as_records() {
2371 * loss = loss + r;
2372 * }
2373 * loss
2374 * };
2375 * let derivatives = error.derivatives();
2376 * let learning_rate = 0.1;
2377 * // update the weights to contain less error than they did before
2378 * let result = weights.map_mut(|x| x - (derivatives[&x] * learning_rate));
2379 * assert!(result.is_ok()); // we know we didn't introduce an inconsistent history just updating the weights
2380 * ```
2381 */
2382 #[track_caller]
2383 pub fn map_mut(
2384 &mut self,
2385 fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
2386 ) -> Result<(), InconsistentHistory<'a, T>> {
2387 let history = self.history;
2388 let new_history =
2389 map_mut_base::<'a, T, _, _>(RowMajorReferenceMutIterator::from(self), |x| {
2390 let record = Record::from_existing(x.clone(), history);
2391 let result = fx(record);
2392 *x = (result.number, result.index);
2393 result.history
2394 })?;
2395 self.history = new_history;
2396 Ok(())
2397 }
2398
2399 /**
2400 * Updates this RecordContainer in place by applying some unary function on `Record<T>` and
2401 * each index of that position in the Record to `Record<T>` to every element in the container.
2402 * This will fail if the function would create records with inconsistent histories.
2403 *
2404 * When used with pure functions that can't return different histories for different inputs
 * unwrapping will always succeed.
2406 *
2407 * Since this updates the container in place, if Err is returned then the data in this
2408 * RecordContainer is still available but it has been corrupted - at least one of the elements
2409 * should have a different history than what it will have because the mapping function created
2410 * inconsistent histories that couldn't be represented by the container as it only stores
2411 * one.
2412 *
2413 * This API can allow you to call a generic function that operates on
2414 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2415 * during the intermediate calculations for you, without having to resort to storing the
2416 * Record types.
2417 *
2418 * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2419 * after mapping doesn't have to be the same as before, only must be the same for every
2420 * mapped element.
2421 */
2422 #[track_caller]
2423 pub fn map_mut_with_index(
2424 &mut self,
2425 fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
2426 ) -> Result<(), InconsistentHistory<'a, T>> {
2427 let history = self.history;
2428 let new_history = map_mut_base::<'a, T, _, _>(
2429 RowMajorReferenceMutIterator::from(self).with_index(),
2430 |((r, c), x)| {
2431 let record = Record::from_existing(x.clone(), history);
2432 let result = fx(record, r, c);
2433 *x = (result.number, result.index);
2434 result.history
2435 },
2436 )?;
2437 self.history = new_history;
2438 Ok(())
2439 }
2440}
2441
2442impl<'a, T, S> RecordMatrix<'a, T, S>
2443where
2444 T: Numeric + Primitive,
2445 for<'t> &'t T: NumericRef<T>,
2446 S: MatrixRef<(T, Index)> + NoInteriorMutability,
2447{
2448 /**
2449 * Overwrites the right hand side of a RecordContainer with the result of applying
2450 * some binary function from `T` to `T` to every element pair in the containers. Both
2451 * containers must have the same shape.
2452 * To compute the new records, the binary function of some inputs x and y to some
2453 * output z is needed along with its derivative with respect to its first input x and
2454 * its derivative with respect to its second input y.
2455 *
2456 * # Panics
2457 *
2458 * - If both record containers have a WengertList that are different to each other
2459 * - If the record containers have different shapes
2460 */
2461 #[track_caller]
2462 pub fn binary_right_assign<S2>(
2463 &self,
2464 rhs: &mut RecordMatrix<'a, T, S2>,
2465 fxy: impl Fn(T, T) -> T,
2466 dfxy_dx: impl Fn(T, T) -> T,
2467 dfxy_dy: impl Fn(T, T) -> T,
2468 ) where
2469 S2: MatrixMut<(T, Index)> + NoInteriorMutability,
2470 {
2471 // x is lhs, y is rhs, so calling binary_left_assign on the rhs container
2472 // means we need to swap all the arguments
2473 rhs.binary_left_assign(
2474 self,
2475 |y, x| fxy(x, y),
2476 |y, x| dfxy_dy(x, y),
2477 |y, x| dfxy_dx(x, y),
2478 )
2479 }
2480
2481 /**
2482 * A convenience helper function which takes the right hand side by value and
2483 * calls [binary_right_assign](RecordMatrix::binary_right_assign()) on it, returning
2484 * the right hand side which now contains the result of the operation.
2485 */
2486 #[track_caller]
2487 pub fn do_binary_right_assign<S2>(
2488 &self,
2489 mut rhs: RecordMatrix<'a, T, S2>,
2490 fxy: impl Fn(T, T) -> T,
2491 dfxy_dx: impl Fn(T, T) -> T,
2492 dfxy_dy: impl Fn(T, T) -> T,
2493 ) -> RecordMatrix<'a, T, S2>
2494 where
2495 S2: MatrixMut<(T, Index)> + NoInteriorMutability,
2496 {
2497 self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
2498 rhs
2499 }
2500}
2501
2502// # Safety
2503//
2504// Our inner `numbers` tensor has to implement TensorRef correctly so by delegating to it
2505// without changing any indexes or introducing interior mutability, we implement TensorRef
2506// correctly as well.
2507/**
2508 * RecordTensor implements TensorRef when the source does, returning references to the tuples
2509 * of `T` and [`Index`](Index).
2510 */
2511unsafe impl<'a, T, S, const D: usize> TensorRef<(T, Index), D> for RecordTensor<'a, T, S, D>
2512where
2513 T: Primitive,
2514 S: TensorRef<(T, Index), D>,
2515{
2516 fn get_reference(&self, indexes: [usize; D]) -> Option<&(T, Index)> {
2517 self.numbers.source_ref().get_reference(indexes)
2518 }
2519
2520 fn view_shape(&self) -> [(Dimension, usize); D] {
2521 self.numbers.source_ref().view_shape()
2522 }
2523
2524 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &(T, Index) {
2525 unsafe { self.numbers.source_ref().get_reference_unchecked(indexes) }
2526 }
2527
2528 fn data_layout(&self) -> DataLayout<D> {
2529 self.numbers.source_ref().data_layout()
2530 }
2531}
2532
2533// # Safety
2534//
2535// Our inner `numbers` tensor has to implement TensorMut correctly so by delegating to it
2536// without changing any indexes or introducing interior mutability, we implement TensorMut
2537// correctly as well.
2538/**
2539 * RecordTensor implements TensorMut when the source does, returning mutable references to the
2540 * tuples of `T` and [`Index`](Index).
2541 */
2542unsafe impl<'a, T, S, const D: usize> TensorMut<(T, Index), D> for RecordTensor<'a, T, S, D>
2543where
2544 T: Primitive,
2545 S: TensorMut<(T, Index), D>,
2546{
2547 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut (T, Index)> {
2548 self.numbers.source_ref_mut().get_reference_mut(indexes)
2549 }
2550
2551 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut (T, Index) {
2552 unsafe {
2553 self.numbers
2554 .source_ref_mut()
2555 .get_reference_unchecked_mut(indexes)
2556 }
2557 }
2558}
2559
2560// # Safety
2561//
2562// Our inner `numbers` matrix has to implement MatrixRef correctly so by delegating to it
2563// without changing any indexes or introducing interior mutability, we implement MatrixRef
2564// correctly as well.
2565/**
2566 * RecordMatrix implements MatrixRef when the source does, returning references to the tuples
2567 * of `T` and [`Index`](Index).
2568 */
2569unsafe impl<'a, T, S> MatrixRef<(T, Index)> for RecordMatrix<'a, T, S>
2570where
2571 T: Primitive,
2572 S: MatrixRef<(T, Index)>,
2573{
2574 fn try_get_reference(&self, row: Row, column: Column) -> Option<&(T, Index)> {
2575 self.numbers.source_ref().try_get_reference(row, column)
2576 }
2577
2578 fn view_rows(&self) -> Row {
2579 self.numbers.source_ref().view_rows()
2580 }
2581
2582 fn view_columns(&self) -> Column {
2583 self.numbers.source_ref().view_columns()
2584 }
2585
2586 unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &(T, Index) {
2587 unsafe {
2588 self.numbers
2589 .source_ref()
2590 .get_reference_unchecked(row, column)
2591 }
2592 }
2593
2594 fn data_layout(&self) -> crate::matrices::views::DataLayout {
2595 self.numbers.source_ref().data_layout()
2596 }
2597}
2598
2599// # Safety
2600//
2601// Our inner `numbers` matrix has to implement NoInteriorMutability correctly so by delegating to
2602// it without introducing interior mutability, we implement NoInteriorMutability
2603// correctly as well.
2604/**
2605 * RecordMatrix implements NoInteriorMutability when the source does.
2606 */
unsafe impl<'a, T, S> NoInteriorMutability for RecordMatrix<'a, T, S>
where
    T: Primitive,
    S: NoInteriorMutability,
{
    // This impl intentionally has no items; it only applies when the source S is itself
    // NoInteriorMutability.
}
2613
2614// # Safety
2615//
2616// Our inner `numbers` matrix has to implement MatrixMut correctly so by delegating to it
2617// without changing any indexes or introducing interior mutability, we implement MatrixMut
2618// correctly as well.
2619/**
2620 * RecordMatrix implements MatrixMut when the source does, returning mutable references to the
2621 * tuples of `T` and [`Index`].
2622 */
2623unsafe impl<'a, T, S> MatrixMut<(T, Index)> for RecordMatrix<'a, T, S>
2624where
2625 T: Primitive,
2626 S: MatrixMut<(T, Index)>,
2627{
2628 fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut (T, Index)> {
2629 self.numbers
2630 .source_ref_mut()
2631 .try_get_reference_mut(row, column)
2632 }
2633
2634 unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut (T, Index) {
2635 unsafe {
2636 self.numbers
2637 .source_ref_mut()
2638 .get_reference_unchecked_mut(row, column)
2639 }
2640 }
2641}
2642
2643/**
2644 * A zero dimensional record tensor can be converted losslessly into a record.
2645 */
2646impl<'a, T, S> From<RecordTensor<'a, T, S, 0>> for Record<'a, T>
2647where
2648 T: Numeric + Primitive,
2649 S: TensorRef<(T, Index), 0>,
2650{
2651 /**
2652 * Converts the sole element in the zero dimensional record tensor into a record.
2653 */
2654 fn from(scalar: RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
2655 // Not a good way to make this zero copy and just move the data out of the scalar because
2656 // TensorRef API doesn't have by value moves of the data and using TensorOwnedIterator
2657 // requires T: Default and a dummy value (at which point a clone is probably cheaper or
2658 // basically the same?)
2659 Record::from(&scalar)
2660 }
2661}
2662
2663/**
2664 * A zero dimensional record tensor can be converted losslessly into a record.
2665 */
2666impl<'a, T, S> From<&RecordTensor<'a, T, S, 0>> for Record<'a, T>
2667where
2668 T: Numeric + Primitive,
2669 S: TensorRef<(T, Index), 0>,
2670{
2671 /**
2672 * Converts the sole element in the zero dimensional record tensor into a record.
2673 */
2674 fn from(scalar: &RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
2675 Record::from_existing(scalar.view().scalar(), scalar.history)
2676 }
2677}
2678
2679/**
2680 * A record can be converted losslessly into a zero dimensional record tensor.
2681 */
2682impl<'a, T> From<Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
2683where
2684 T: Numeric + Primitive,
2685{
2686 /**
2687 * Converts a record into a zero dimensional record tensor with the single element.
2688 */
2689 fn from(record: Record<'a, T>) -> RecordTensor<'a, T, Tensor<(T, Index), 0>, 0> {
2690 RecordTensor::from_existing(
2691 record.history,
2692 TensorView::from(Tensor::from([], vec![(record.number, record.index)])),
2693 )
2694 }
2695}
2696
2697/**
2698 * A record can be converted losslessly into a zero dimensional record tensor.
2699 */
2700impl<'a, T> From<&Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
2701where
2702 T: Numeric + Primitive,
2703{
2704 /**
2705 * Converts a record into a zero dimensional record tensor with the single element.
2706 */
2707 fn from(record: &Record<'a, T>) -> RecordTensor<'a, T, Tensor<(T, Index), 0>, 0> {
2708 RecordTensor::from_existing(
2709 record.history,
2710 TensorView::from(Tensor::from(
2711 [],
2712 vec![(record.number.clone(), record.index)],
2713 )),
2714 )
2715 }
2716}
2717
// Multiplies a 4x3 tensor by its transpose twice: once as a Tensor whose elements are
// individual Records, and once as a single RecordTensor. The two approaches use separate
// WengertLists and are expected to agree on the resulting numbers, on every derivative, and
// on how many operations ended up in each list.
#[test]
fn matrix_multiplication_derivatives_are_the_same() {
    #[rustfmt::skip]
    let a = Tensor::from(
        [("r", 4), ("c", 3)],
        vec![
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
            0.0, 5.0, 2.0
        ]
    );
    let b = a.transpose(["c", "r"]);
    // One history per approach so that the two recordings can be compared at the end.
    let history = WengertList::new();
    let also_history: WengertList<f64> = WengertList::new();
    // Approach 1: a Tensor where each element is its own Record variable on `history`.
    let tensor_of_records_a = a.map(|x| Record::variable(x, &history));
    let tensor_of_records_b = b.map(|x| Record::variable(x, &history));
    let tensor_of_records_c = &tensor_of_records_a * &tensor_of_records_b;
    // Approach 2: RecordTensors on `also_history`, created in the same a-then-b order.
    let record_tensor_a = RecordTensor::variables(&also_history, a);
    let record_tensor_b = RecordTensor::variables(&also_history, b);
    let record_tensor_c = &record_tensor_a * &record_tensor_b;

    // C should be calculated the same in terms of the actual number
    assert_eq!(
        tensor_of_records_c.map(|r| r.number),
        TensorView::from(&record_tensor_c).map(|(n, _)| n)
    );

    // Derivatives of every element of C from approach 1, looked up at the inputs A and B.
    // NOTE(review): these derivatives were recorded on `history` but are queried with
    // `record_tensor_a` / `record_tensor_b`, which were registered on `also_history`.
    // Presumably this only lines up because both histories had their variables created in the
    // same order - confirm before reordering any of the statements above.
    let tensor_of_records_derivatives = tensor_of_records_c.map(|r| r.derivatives());
    let tensor_of_records_a_derivatives =
        tensor_of_records_derivatives.map(|d| d.at_tensor(&record_tensor_a));
    let tensor_of_records_b_derivatives =
        tensor_of_records_derivatives.map(|d| d.at_tensor(&record_tensor_b));

    // The same derivatives from approach 2. The unwrap panics, failing the test, if
    // `derivatives()` does not produce a value.
    let record_tensor_derivatives = record_tensor_c.derivatives().unwrap();
    let record_tensor_a_derivatives =
        record_tensor_derivatives.map(|d| d.at_tensor(&record_tensor_a));
    let record_tensor_b_derivatives =
        record_tensor_derivatives.map(|d| d.at_tensor(&record_tensor_b));

    // Every calculated derivative should match exactly
    assert_eq!(tensor_of_records_a_derivatives, record_tensor_a_derivatives);
    assert_eq!(tensor_of_records_b_derivatives, record_tensor_b_derivatives);

    // Verify C is actually calculated correctly
    #[rustfmt::skip]
    assert_eq!(
        tensor_of_records_c.map(|r| r.number),
        Tensor::from(
            [("r", 4), ("c", 4)],
            vec![
                14.0, 32.0, 50.0, 16.0,
                32.0, 77.0, 122.0, 37.0,
                50.0, 122.0, 194.0, 58.0,
                16.0, 37.0, 58.0, 29.0
            ]
        )
    );
    // The same expectation again, spelled out as the dot product of each row of A with each
    // row of A (the rows of A being the columns of B).
    #[rustfmt::skip]
    assert_eq!(
        tensor_of_records_c.map(|r| r.number),
        Tensor::from(
            [("r", 4), ("c", 4)],
            vec![
                (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
                (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
                (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
                (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),

                (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
                (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
                (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
                (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),

                (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
                (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
                (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
                (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),

                (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
                (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
                (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
                (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
            ]
        )
    );

    // Despite the names, these two hold copies of the operations recorded on each
    // WengertList. Both approaches should have recorded the same number of them.
    let tensor_of_records_derivatives = history.operations.borrow().clone();
    let record_tensor_derivatives = also_history.operations.borrow().clone();
    assert_eq!(
        tensor_of_records_derivatives.len(),
        record_tensor_derivatives.len()
    );
}
2812
// Multiplies a 4x3 matrix by its transpose twice: once as a Matrix whose elements are
// individual Records, and once as a single RecordMatrix. The two approaches use separate
// WengertLists and are expected to agree on the resulting numbers, on every derivative, and
// on how many operations ended up in each list.
#[test]
fn matrix_view_matrix_multiplication_derivatives_are_the_same() {
    #[rustfmt::skip]
    let a = Matrix::from(vec![
        vec![ 1.0, 2.0, 3.0 ],
        vec![ 4.0, 5.0, 6.0 ],
        vec![ 7.0, 8.0, 9.0 ],
        vec![ 0.0, 5.0, 2.0 ]
    ]);
    let b = a.transpose();
    // One history per approach so that the two recordings can be compared at the end.
    let history = WengertList::new();
    let also_history: WengertList<f64> = WengertList::new();
    // Approach 1: a Matrix where each element is its own Record variable on `history`.
    let matrix_of_records_a = a.map(|x| Record::variable(x, &history));
    let matrix_of_records_b = b.map(|x| Record::variable(x, &history));
    let matrix_of_records_c = &matrix_of_records_a * &matrix_of_records_b;
    // Approach 2: RecordMatrix values on `also_history`, created in the same a-then-b order.
    let record_matrix_a = RecordMatrix::variables(&also_history, a);
    let record_matrix_b = RecordMatrix::variables(&also_history, b);
    let record_matrix_c = &record_matrix_a * &record_matrix_b;

    // C should be calculated the same in terms of the actual number
    assert_eq!(
        matrix_of_records_c.map(|r| r.number),
        MatrixView::from(&record_matrix_c).map(|(n, _)| n)
    );

    // Derivatives of every element of C from approach 1, looked up at the inputs A and B.
    // NOTE(review): these derivatives were recorded on `history` but are queried with
    // `record_matrix_a` / `record_matrix_b`, which were registered on `also_history`.
    // Presumably this only lines up because both histories had their variables created in the
    // same order - confirm before reordering any of the statements above.
    let matrix_of_records_derivatives = matrix_of_records_c.map(|r| r.derivatives());
    let matrix_of_records_a_derivatives =
        matrix_of_records_derivatives.map(|d| d.at_matrix(&record_matrix_a));
    let matrix_of_records_b_derivatives =
        matrix_of_records_derivatives.map(|d| d.at_matrix(&record_matrix_b));

    // The same derivatives from approach 2. The unwrap panics, failing the test, if
    // `derivatives()` does not produce a value.
    let record_matrix_derivatives = record_matrix_c.derivatives().unwrap();
    let record_matrix_a_derivatives =
        record_matrix_derivatives.map(|d| d.at_matrix(&record_matrix_a));
    let record_matrix_b_derivatives =
        record_matrix_derivatives.map(|d| d.at_matrix(&record_matrix_b));

    // Every calculated derivative should match exactly
    assert_eq!(matrix_of_records_a_derivatives, record_matrix_a_derivatives);
    assert_eq!(matrix_of_records_b_derivatives, record_matrix_b_derivatives);

    // Verify C is actually calculated correctly
    #[rustfmt::skip]
    assert_eq!(
        matrix_of_records_c.map(|r| r.number),
        Matrix::from(vec![
            vec![ 14.0, 32.0, 50.0, 16.0 ],
            vec![ 32.0, 77.0, 122.0, 37.0 ],
            vec![ 50.0, 122.0, 194.0, 58.0 ],
            vec![ 16.0, 37.0, 58.0, 29.0 ]
        ]
        )
    );
    // The same expectation again, spelled out as the dot product of each row of A with each
    // row of A (the rows of A being the columns of B).
    #[rustfmt::skip]
    assert_eq!(
        matrix_of_records_c.map(|r| r.number),
        Matrix::from(vec![
            vec![
                (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
                (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
                (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
                (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),
            ],
            vec![
                (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
                (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
                (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
                (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),
            ],
            vec![
                (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
                (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
                (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
                (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),
            ],
            vec![
                (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
                (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
                (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
                (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
            ]
        ])
    );

    // Despite the names, these two hold copies of the operations recorded on each
    // WengertList. Both approaches should have recorded the same number of them.
    let matrix_of_records_derivatives = history.operations.borrow().clone();
    let record_matrix_derivatives = also_history.operations.borrow().clone();
    assert_eq!(
        matrix_of_records_derivatives.len(),
        record_matrix_derivatives.len()
    );
}