easy_ml/differentiation/container_record/
mod.rs

1use crate::differentiation::functions::{Division, FunctionDerivative, Multiplication};
2use crate::differentiation::iterators::{
3    AsRecords, InconsistentHistory, InvalidRecordIteratorError,
4};
5use crate::differentiation::record_operations;
6use crate::differentiation::{Derivatives, Index, Primitive, Record, WengertList};
7use crate::matrices::iterators::{
8    ColumnMajorIterator, RowMajorIterator, RowMajorOwnedIterator, RowMajorReferenceMutIterator,
9};
10use crate::matrices::views::{MatrixMut, MatrixRef, MatrixView, NoInteriorMutability};
11use crate::matrices::{Column, Matrix, Row};
12use crate::numeric::{Numeric, NumericRef};
13use crate::tensors::indexing::{
14    TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceMutIterator,
15};
16use crate::tensors::views::{DataLayout, TensorMut, TensorRef, TensorRename, TensorView};
17use crate::tensors::{Dimension, Tensor};
18
19mod container_operations;
20pub mod iterators;
21
22/**
23 * A pluralisation of [Record](crate::differentiation::Record) that groups together a
24 * **s**ource of numbers of type T and stores the WengertList only once.
25 *
26 * Typically you would refer to one of the type aliases to disambiguate the type of `S` and
27 * use more succinct generics: [RecordMatrix](RecordMatrix), [RecordTensor](RecordTensor).
28 *
29 * For both Matrix and Tensor source types, the containers implement [`+`](std::ops::Add) and
30 * [`-`](std::ops::Sub) and have the methods `elementwise_multiply` and `elementwise_divide`.
31 * In all cases the containers must have the same size for the operation and will panic if
32 * mismatched.
33 * [`*`](std::ops::Mul) is also implemented for 2 dimensional tensors and matrices as matrix
34 * multiplication.
35 *
36 * For convenience, as with Trace and Record, many unary operations including
37 * [Cos](crate::numeric::extra::Cos), [Exp](crate::numeric::extra::Exp),
38 * [Ln](crate::numeric::extra::Ln), [Neg](std::ops::Neg), [Pow](crate::numeric::extra::Pow),
39 * [Sin](crate::numeric::extra::Sin), and [Sqrt](crate::numeric::extra::Sqrt) are implemented as
40 * well, applying the unary function to each element in the tensor.
41 *
42 * `+`, `-`, `*` and `/` operations with a RecordContainer and a scalar are also implemented,
43 * treating the right hand side scalar as a constant. These are also unary functions in terms of
44 * the derivative tracking, for example `X + 5` applies the function `+5` to each element in
45 * `X`. Due to the orphan rule, the standard library scalars cannot be implemented for a left hand
46 * side scalar, see [SwappedOperations](crate::differentiation::record_operations::SwappedOperations).
47 *
48 * Both [RecordMatrix](RecordMatrix) and [RecordTensor](RecordTensor) implement
49 * [MatrixRef](MatrixRef) and [TensorRef](TensorRef) respectively, which provide read and write
50 * access to the underlying numbers and indexes into the WengertList. These APIs along with the
51 * `from_existing` constructors for RecordMatrix, RecordTensor, and Record allow arbitary
52 * manipulation of specific elements in a record container if needed. However, any arithmetic
53 * performed on the raw data won't be tracked on the WengertList and overwriting data within a
54 * record container already tracked on the WengertList could result in nonsense. These APIs exist
55 * for easy read access to check the data in a record container and for read/write access when
56 * manipulating the shape of a record container, and are designed to be used only for moving data
57 * around - you should put it back unchanged in a RecordContainer or Record before doing further
58 * arithmetic that needs to be tracked on the WengertList.
59 *
60 * If you just want to manipulate the data in record containers as if they were Records you can
61 * use the iterator APIs of [AsRecords](AsRecords) instead and collect them back into containers
62 * when you're done, or to manipulate a single container, the [map](RecordContainer::map) and
63 * [map_mut](RecordContainer::map_mut) methods.
64 */
65#[derive(Debug)]
66pub struct RecordContainer<'a, T: Primitive, S, const D: usize> {
67    // Opted to store the indexes alongside each number (T, Index) for a number of reasons, the
68    // main factor being it makes implementing TensorRef feasible so can utilise the range of
69    // existing APIs for Tensor manipulation. It's theoretically possible to only store the first
70    // index and calculate the rest, since most of the time all indexes are ascending entries in
71    // the WengertList but this would also massively complicate the implementation, especially for
72    // handling non row-major operations such as matrix multiplication. It's also not super clear
73    // that this would be more efficient because it turns reads into more arithmetic rather
74    // than avoiding any work. Just lifting the WengertList out of the tensor should have
75    // meaningful improvements to cache line efficiency, and failing that still disallows
76    // very questionable states from existing.
77    numbers: S,
78    history: Option<&'a WengertList<T>>,
79}
80
81/**
82 * Alias for succinctly referring to RecordContainers backed by a matrix.
83 */
84pub type RecordMatrix<'a, T, S> = RecordContainer<'a, T, MatrixView<(T, Index), S>, 2>;
85
86/**
87 * Alias for succinctly referring to RecordContainers backed by a tensor.
88 */
89pub type RecordTensor<'a, T, S, const D: usize> =
90    RecordContainer<'a, T, TensorView<(T, Index), S, D>, D>;
91
92fn calculate_incrementing_indexes(starting_index: usize, total: usize) -> Vec<Index> {
93    let mut indexes = vec![0; total];
94    for (i, x) in indexes.iter_mut().enumerate() {
95        *x = starting_index + i;
96    }
97    indexes
98}
99
100impl<'a, T, const D: usize> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
101where
102    T: Numeric + Primitive,
103{
104    /**
105     * Creates multiple untracked Records which have no backing WengertList.
106     *
107     * This is provided for using constants along with Records in operations.
108     *
109     * For example with `Y = X + 4` the computation graph could be conceived as many
110     * `Y[i,j]` nodes with parent nodes of `X[i,j]` and 4 combined with the operation `+`.
111     * However there is no need to record the derivatives of a constant, so
112     * instead the computation graph can be conceived as `Y[i,j]` nodes each with a single
113     * parent node of `X[i,j]` and the unary operation of `+4`.
114     */
115    pub fn constants<S>(c: S) -> Self
116    where
117        S: TensorMut<T, D>,
118    {
119        RecordContainer {
120            numbers: TensorView::from(Tensor::from(
121                c.view_shape(),
122                TensorOwnedIterator::from_numeric(c)
123                    .map(|x| (x, 0))
124                    .collect(),
125            )),
126            history: None,
127        }
128    }
129
130    /**
131     * Creates multiple records backed by the provided WengertList.
132     *
133     * The records cannot live longer than the WengertList, hence
134     * the following example does not compile
135     *
136     * ```compile_fail
137     * use easy_ml::differentiation::RecordTensor;
138     * use easy_ml::differentiation::WengertList;
139     * use easy_ml::tensors::Tensor;
140     * let record = {
141     *     let list = WengertList::new();
142     *     RecordTensor::variables(
143     *         &list,
144     *         Tensor::from([("r", 2), ("c", 2)], vec![ 1.0, 2.0, 3.0, 4.0 ])
145     *     )
146     * }; // list no longer in scope
147     * ```
148     */
149    pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
150    where
151        S: TensorMut<T, D>,
152    {
153        let total = crate::tensors::dimensions::elements(&x.view_shape());
154        let starting_index = history.append_nullary_repeating(total);
155        RecordContainer {
156            numbers: TensorView::from(Tensor::from(
157                x.view_shape(),
158                TensorOwnedIterator::from_numeric(x)
159                    .zip(calculate_incrementing_indexes(starting_index, total))
160                    .collect(),
161            )),
162            history: Some(history),
163        }
164    }
165}
166
167impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
168where
169    T: Numeric + Primitive,
170    S: TensorRef<(T, Index), D>,
171{
172    /**
173     * Returns the number of elements stored by this container's source.
174     *
175     * For a 2 x 3 Tensor, this would return 6, and for a 2 x 3 x 4 Tensor this would return 24
176     * and so on.
177     *
178     * see also [dimensions::elements](crate::tensors::dimensions::elements)
179     */
180    pub fn elements(&self) -> usize {
181        crate::tensors::dimensions::elements(&self.numbers.shape())
182    }
183
184    /**
185     * The shape of this container's source.
186     */
187    pub fn shape(&self) -> [(Dimension, usize); D] {
188        self.numbers.shape()
189    }
190
191    /**
192     * Creates a container from constants/variables directly, most likely obtained by getting a
193     * tensor view of an existing container. **The inputs are not checked for validity**. It is
194     * possible to pass in the wrong Wengert list here or even numbers with indexes that aren't
195     * tracked on the WengertList.
196     *
197     * It is recommended to use this constructor only in conjunction with
198     * resizing or masking an existing container and not for creating new variables. Any variables
199     * created outside of `RecordContainer::variables` would have to be manually added to the
200     * correct Wengert list, and any arithmetic operations would also need tracking correctly.
201     *
202     * ```
203     * use easy_ml::differentiation::RecordTensor;
204     * use easy_ml::differentiation::WengertList;
205     * use easy_ml::tensors::Tensor;
206     * use easy_ml::tensors::views::{TensorView, TensorRange};
207     *
208     * let list = WengertList::new();
209     * let x = RecordTensor::variables(
210     *     &list,
211     *     Tensor::from_fn([("x", 2), ("y", 2)], |[r, c]| ((r + 3) * (c + 2)) as f64)
212     * );
213     * // oh no wrong shape!
214     * let fixed = TensorView::from(TensorRange::from(x, [("y", 0..1)]).unwrap()); // we can unwrap here because we know the range is valid
215     * let x = RecordTensor::from_existing(Some(&list), fixed);
216     * assert_eq!([("x", 2), ("y", 1)], x.shape());
217     * ```
218     */
219    pub fn from_existing(
220        history: Option<&'a WengertList<T>>,
221        numbers: TensorView<(T, Index), S, D>,
222    ) -> Self {
223        RecordContainer { numbers, history }
224    }
225
226    /**
227     * Returns a record tensor with the dimension names of the shape renamed to the provided
228     * dimensions. The data of this container and the dimension lengths and order remain unchanged.
229     *
230     * This is a shorthand for constructing the RecordTensor via manipulating this TensorView. See
231     * [`RecordTensor::from_existing`](RecordTensor::from_existing).
232     *
233     * # Panics
234     *
235     * If a dimension name is not unique
236     *
237     * ```
238     * use easy_ml::differentiation::RecordTensor;
239     * use easy_ml::differentiation::WengertList;
240     * use easy_ml::tensors::Tensor;
241     * use easy_ml::tensors::views::{TensorView, TensorRename};
242     *
243     * let list = WengertList::new();
244     * let x = RecordTensor::variables(
245     *     &list,
246     *     Tensor::from_fn([("x", 2), ("y", 2)], |[r, c]| ((r + 3) * (c + 2)) as f64)
247     * );
248     * // oh no wrong dimension names!
249     * let x = x.rename_view(["a", "b"]);
250     * assert_eq!([("a", 2), ("b", 2)], x.shape());
251     * ```
252     */
253    #[track_caller]
254    pub fn rename_view(
255        self,
256        dimensions: [Dimension; D],
257    ) -> RecordTensor<'a, T, TensorRename<(T, Index), S, D>, D> {
258        RecordTensor::from_existing(
259            self.history,
260            TensorView::from(TensorRename::from(self.numbers.source(), dimensions)),
261        )
262    }
263
264    /**
265     * Returns a TensorView of this record container, both the `T` for each record element and
266     * also the index for that record's entry in the WengertList. These can be parsed back into
267     * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
268     * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
269     * numerical operations on the data.
270     */
271    pub fn view(&self) -> TensorView<(T, Index), &RecordTensor<'a, T, S, D>, D> {
272        TensorView::from(self)
273    }
274
275    /**
276     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
277     * to read values from this record container, both the `T` for each record element and
278     * also the index for that record's entry in the WengertList. These can be parsed back into
279     * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
280     * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
281     * numerical operations on the data.
282     *
283     * # Panics
284     *
285     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
286     */
287    #[track_caller]
288    pub fn index_by(
289        &self,
290        dimensions: [Dimension; D],
291    ) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
292        TensorAccess::from(self, dimensions)
293    }
294
295    /**
296     * Creates a TensorAccess which will index into the dimensions this record was created with
297     * in the same order as they were provided, both the `T` for each record element and
298     * also the index for that record's entry in the WengertList. These can be parsed back into
299     * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
300     * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
301     * numerical operations on the data.
302     *
303     * See [TensorAccess::from_source_order], [get_as_record](TensorAccess::get_as_record).
304     *
305     * ```
306     * use easy_ml::differentiation::RecordTensor;
307     * use easy_ml::differentiation::WengertList;
308     * use easy_ml::tensors::Tensor;
309     *
310     * let list = WengertList::new();
311     * let X = RecordTensor::variables(
312     *     &list,
313     *     Tensor::from([("a", 3)], vec![ 3.0, 4.0, 5.0 ])
314     * );
315     * let x = X.index().get_as_record([0]);
316     * assert_eq!(x.number, 3.0);
317     * ```
318     */
319    pub fn index(&self) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
320        TensorAccess::from_source_order(self)
321    }
322
323    /**
324     * Returns an iterator over this record tensor as [Record]s instead of the raw `(T, Index)`
325     * data. After manipulating the iterator it can be collected back into a RecordTensor with
326     * [RecordTensor::from_iter](RecordTensor::from_iter).
327     *
328     * This is a shorthand for `AsRecords::from(tensor.history(), TensorIterator::from(&tensor))`
329     */
330    #[allow(clippy::type_complexity)]
331    pub fn iter_as_records<'b>(
332        &'b self,
333    ) -> AsRecords<'a, TensorIterator<'b, (T, Index), RecordTensor<'a, T, S, D>, D>, T> {
334        AsRecords::from_tensor(self)
335    }
336}
337
338impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
339where
340    T: Numeric + Primitive,
341    S: TensorMut<(T, Index), D>,
342{
343    /**
344     * Resets all of the records to place them back on the WengertList, for use
345     * in performing another derivation after clearing the WengertList.
346     *
347     * This is also a preferred shorthand for `map_mut(Record::do_reset)` that can't fail
348     */
349    pub fn reset(&mut self) {
350        match self.history {
351            None => (), // noop
352            Some(history) => {
353                let total = self.elements();
354                let starting_index = history.append_nullary_repeating(total);
355                for (x, i) in self
356                    .numbers
357                    .iter_reference_mut()
358                    .zip(calculate_incrementing_indexes(starting_index, total))
359                {
360                    let (_, old_index) = x;
361                    *old_index = i;
362                }
363            }
364        };
365    }
366
367    /**
368     * A convenience helper function which takes a RecordContainer by value and
369     * calls [reset](RecordTensor::reset()) on it.
370     */
371    pub fn do_reset(mut x: Self) -> Self {
372        x.reset();
373        x
374    }
375}
376
377impl<'a, T> RecordMatrix<'a, T, Matrix<(T, Index)>>
378where
379    T: Numeric + Primitive,
380{
381    /**
382     * Creates multiple untracked Records which have no backing WengertList.
383     *
384     * This is provided for using constants along with Records in operations.
385     *
386     * For example with `Y = X + 4` the computation graph could be conceived as many
387     * `Y[i,j]` nodes with parent nodes of `X[i,j]` and 4 combined with the operation `+`.
388     * However there is no need to record the derivatives of a constant, so
389     * instead the computation graph can be conceived as `Y[i,j]` nodes each with a single
390     * parent node of `X[i,j]` and the unary operation of `+4`.
391     */
392    pub fn constants<S>(c: S) -> Self
393    where
394        S: MatrixMut<T> + NoInteriorMutability,
395    {
396        RecordContainer {
397            numbers: MatrixView::from(Matrix::from_flat_row_major(
398                (c.view_rows(), c.view_columns()),
399                RowMajorOwnedIterator::from_numeric(c)
400                    .map(|x| (x, 0))
401                    .collect(),
402            )),
403            history: None,
404        }
405    }
406
407    /**
408     * Creates multiple records backed by the provided WengertList.
409     *
410     * The records cannot live longer than the WengertList, hence
411     * the following example does not compile
412     *
413     * ```compile_fail
414     * use easy_ml::differentiation::RecordMatrix;
415     * use easy_ml::differentiation::WengertList;
416     * use easy_ml::matrices::Matrix;
417     * let record = {
418     *     let list = WengertList::new();
419     *     RecordMatrix::variables(
420     *         &list,
421     *         Matrix::from(vec![vec![1.0, 2.0], vec![3.0, 4.0]])
422     *     )
423     * }; // list no longer in scope
424     * ```
425     */
426    pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
427    where
428        S: MatrixMut<T> + NoInteriorMutability,
429    {
430        let total = x.view_rows() * x.view_columns();
431        let starting_index = history.append_nullary_repeating(total);
432        RecordContainer {
433            numbers: MatrixView::from(Matrix::from_flat_row_major(
434                (x.view_rows(), x.view_columns()),
435                RowMajorOwnedIterator::from_numeric(x)
436                    .zip(calculate_incrementing_indexes(starting_index, total))
437                    .collect(),
438            )),
439            history: Some(history),
440        }
441    }
442}
443
444impl<'a, T, S> RecordMatrix<'a, T, S>
445where
446    T: Numeric + Primitive,
447    S: MatrixRef<(T, Index)> + NoInteriorMutability,
448{
449    /**
450     * Returns the number of elements stored by this container's source.
451     *
452     * For a 2 x 3 Matrix, this would return 6, and for a 3 x 4 Matrix this would return 12
453     * and so on.
454     */
455    pub fn elements(&self) -> usize {
456        self.numbers.rows() * self.numbers.columns()
457    }
458
459    /**
460     * Returns the dimensionality of this matrix container in Row, Column format
461     */
462    pub fn size(&self) -> (Row, Column) {
463        self.numbers.size()
464    }
465
466    /**
467     * Gets the number of rows visible to this matrix container.
468     */
469    pub fn rows(&self) -> Row {
470        self.numbers.rows()
471    }
472
473    /**
474     * Gets the number of columns visible to this matrix container.
475     */
476    pub fn columns(&self) -> Column {
477        self.numbers.columns()
478    }
479
480    /**
481     * Creates a container from constants/variables directly, most likely obtained by getting a
482     * matrix view of an existing container. **The inputs are not checked for validity**. It is
483     * possible to pass in the wrong Wengert list here or even numbers with indexes that aren't
484     * tracked on the WengertList.
485     *
486     * It is recommended to use this constructor only in conjunction with
487     * resizing or masking an existing container and not for creating new variables. Any variables
488     * created outside of `RecordContainer::variables` would have to be manually added to the
489     * correct Wengert list, and any arithmetic operations would also need tracking correctly.
490     *
491     * ```
492     * use easy_ml::differentiation::RecordMatrix;
493     * use easy_ml::differentiation::WengertList;
494     * use easy_ml::matrices::Matrix;
495     * use easy_ml::matrices::views::{MatrixView, MatrixRange};
496     *
497     * let list = WengertList::new();
498     * let x = RecordMatrix::variables(
499     *     &list,
500     *     Matrix::from_fn((2, 2), |(r, c)| ((r + 3) * (c + 2)) as f64)
501     * );
502     * // oh no wrong shape!
503     * let fixed = MatrixView::from(MatrixRange::from(x, 0..2, 0..1));
504     * let x = RecordMatrix::from_existing(Some(&list), fixed);
505     * assert_eq!((2, 1), x.size());
506     * ```
507     */
508    pub fn from_existing(
509        history: Option<&'a WengertList<T>>,
510        numbers: MatrixView<(T, Index), S>,
511    ) -> Self {
512        RecordContainer { numbers, history }
513    }
514
515    /**
516     * Returns a MatrixView of this record container, both the `T` for each record element and
517     * also the index for that record's entry in the WengertList. These can be parsed back into
518     * a RecordMatrix with [`from_existing`](RecordMatrix::from_existing) or individually into
519     * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
520     * numerical operations on the data.
521     */
522    pub fn view(&self) -> MatrixView<(T, Index), &RecordMatrix<'a, T, S>> {
523        MatrixView::from(self)
524    }
525
526    /**
527     * Returns an iterator over this record matrix as [Record]s instead of the raw `(T, Index)`
528     * data. After manipulating the iterator it can be collected back into a RecordMatrix with
529     * [RecordMatrix::from_iter](RecordMatrix::from_iter).
530     *
531     * This is a shorthand for `AsRecords::from(matrix.history(), RowMajorIterator::from(&matrix))`
532     */
533    #[allow(clippy::type_complexity)]
534    pub fn iter_row_major_as_records<'b>(
535        &'b self,
536    ) -> AsRecords<'a, RowMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
537        AsRecords::from_matrix_row_major(self)
538    }
539
540    /**
541     * Returns an iterator over this record matrix as [Record]s instead of the raw `(T, Index)`
542     * data. After manipulating the iterator it can be collected back into a RecordMatrix with
543     * [RecordMatrix::from_iter](RecordMatrix::from_iter).
544     *
545     * This is a shorthand for
546     * `AsRecords::from(matrix.history(), ColumnMajorIterator::from(&matrix))`
547     */
548    #[allow(clippy::type_complexity)]
549    pub fn iter_column_major_as_records<'b>(
550        &'b self,
551    ) -> AsRecords<'a, ColumnMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
552        AsRecords::from_matrix_column_major(self)
553    }
554
555    /**
556     * Returns a copy of the data at the index as a Record. If you need to access all the data
557     * as records instead of just a specific index you should probably use one of the iterator
558     * APIs instead.
559     *
560     * See also: [iter_row_major_as_records](RecordMatrix::iter_row_major_as_records),
561     * [iter_column_major_as_records](RecordMatrix::iter_column_major_as_records)
562     *
563     * # Panics
564     *
565     * If the index is out of range.
566     *
567     * For a non panicking API see [try_get_as_record](RecordMatrix::try_get_as_record)
568     */
569    #[track_caller]
570    pub fn get_as_record(&self, row: Row, column: Column) -> Record<'a, T> {
571        Record::from_existing(self.numbers.get(row, column), self.history)
572    }
573
574    /**
575     * Returns a copy of the data at the index as a Record, or None if the index is
576     * out of range. If you need to access all the data as records instead of just a specific
577     * index you should probably use one of the iterator APIs instead.
578     *
579     * See also: [iter_row_major_as_records](RecordMatrix::iter_row_major_as_records),
580     * [iter_column_major_as_records](RecordMatrix::iter_column_major_as_records)
581     */
582    pub fn try_get_as_record(&self, row: Row, column: Column) -> Option<Record<'a, T>> {
583        self.numbers
584            .try_get_reference(row, column)
585            .map(|r| Record::from_existing(r.clone(), self.history))
586    }
587}
588
589impl<'a, T, S> RecordMatrix<'a, T, S>
590where
591    T: Numeric + Primitive,
592    S: MatrixMut<(T, Index)> + NoInteriorMutability,
593{
594    /**
595     * Resets all of the records to place them back on the WengertList, for use
596     * in performing another derivation after clearing the WengertList.
597     *
598     * This is also a preferred shorthand for `map_mut(Record::do_reset)` that can't fail
599     */
600    pub fn reset(&mut self) {
601        match self.history {
602            None => (), // noop
603            Some(history) => {
604                let total = self.elements();
605                let starting_index = history.append_nullary_repeating(total);
606                for (x, i) in self
607                    .numbers
608                    .row_major_reference_mut_iter()
609                    .zip(calculate_incrementing_indexes(starting_index, total))
610                {
611                    let (_, old_index) = x;
612                    *old_index = i;
613                }
614            }
615        };
616    }
617
618    /**
619     * A convenience helper function which takes a RecordContainer by value and
620     * calls [reset](RecordMatrix::reset()) on it.
621     */
622    pub fn do_reset(mut x: Self) -> Self {
623        x.reset();
624        x
625    }
626}
627
628impl<'a, T, S, const D: usize> RecordContainer<'a, T, S, D>
629where
630    T: Primitive,
631{
632    /**
633     * Gets the WengertList these records are backed by if variables, and [None](None) if constants.
634     */
635    pub fn history(&self) -> Option<&'a WengertList<T>> {
636        self.history
637    }
638}
639
640/// Returns the vec of indexes and vec of ys for Y = unary(X), not checking but assuming that the
641/// length of the iterator matches the total.
642fn unary<'a, T, I>(
643    total: usize,
644    history: &WengertList<T>,
645    records: I,
646    fx: impl Fn(T) -> T,
647    dfx_dx: impl Fn(T) -> T,
648) -> Vec<(T, usize)>
649where
650    I: Iterator<Item = (T, Index)>,
651    T: Numeric + Primitive,
652    for<'t> &'t T: NumericRef<T>,
653{
654    let mut ys = vec![(T::zero(), 0); total];
655    history.borrow(|history| {
656        // shadow the name so we can't accidentally try to use history while holding
657        // the borrow
658        // use enumerate not with_index because we need the 1D index for indexing
659        // indexes
660        for (i, (x, parent)) in records.enumerate() {
661            let y = fx(x.clone());
662            let derivative = dfx_dx(x);
663            let new_index = history.append_unary(parent, derivative);
664            ys[i] = (y, new_index)
665        }
666    }); // drop borrow on history
667    ys
668}
669
670/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), not checking but assuming that
671/// the length of the iterators match the total. Also assumes both inputs have the same shape
672fn binary_both_history<'a, T, I1, I2>(
673    total: usize,
674    history: &WengertList<T>,
675    x_records: I1,
676    y_records: I2,
677    fxy: impl Fn(T, T) -> T,
678    dfxy_dx: impl Fn(T, T) -> T,
679    dfxy_dy: impl Fn(T, T) -> T,
680) -> Vec<(T, usize)>
681where
682    I1: Iterator<Item = (T, Index)>,
683    I2: Iterator<Item = (T, Index)>,
684    T: Numeric + Primitive,
685    for<'t> &'t T: NumericRef<T>,
686{
687    let mut zs = vec![(T::zero(), 0); total];
688    history.borrow(|history| {
689        // shadow the name so we can't accidentally try to use history while holding
690        // the borrow
691        // use enumerate not with_index because we need the 1D index for indexing
692        // indexes
693        for (i, ((x, parent1), (y, parent2))) in (x_records.zip(y_records)).enumerate() {
694            let z = fxy(x.clone(), y.clone());
695            let derivative1 = dfxy_dx(x.clone(), y.clone());
696            let derivative2 = dfxy_dy(x, y);
697            let new_index = history.append_binary(parent1, derivative1, parent2, derivative2);
698            zs[i] = (z, new_index);
699        }
700    }); // drop borrow on history
701    zs
702}
703
704/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), as with binary_both_history,
705/// but only tracking the derivatives for X, not Y.
706fn binary_x_history<'a, T, I1, I2>(
707    total: usize,
708    history: &WengertList<T>,
709    x_records: I1,
710    y_records: I2,
711    fxy: impl Fn(T, T) -> T,
712    dfxy_dx: impl Fn(T, T) -> T,
713) -> Vec<(T, usize)>
714where
715    I1: Iterator<Item = (T, Index)>,
716    I2: Iterator<Item = (T, Index)>,
717    T: Numeric + Primitive,
718    for<'t> &'t T: NumericRef<T>,
719{
720    let mut zs = vec![(T::zero(), 0); total];
721    history.borrow(|history| {
722        // shadow the name so we can't accidentally try to use history while holding
723        // the borrow
724        // use enumerate not with_index because we need the 1D index for indexing
725        // indexes
726        for (i, ((x, parent1), (y, _))) in (x_records.zip(y_records)).enumerate() {
727            let z = fxy(x.clone(), y.clone());
728            // if rhs didn't have a history, don't track that derivative
729            let derivative1 = dfxy_dx(x, y);
730            let new_index = history.append_unary(parent1, derivative1);
731            zs[i] = (z, new_index);
732        }
733    }); // drop borrow on history
734    zs
735}
736
737/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), as with binary_both_history,
738/// but only tracking the derivatives for Y, not X.
739fn binary_y_history<'a, T, I1, I2>(
740    total: usize,
741    history: &WengertList<T>,
742    x_records: I1,
743    y_records: I2,
744    fxy: impl Fn(T, T) -> T,
745    dfxy_dy: impl Fn(T, T) -> T,
746) -> Vec<(T, usize)>
747where
748    I1: Iterator<Item = (T, Index)>,
749    I2: Iterator<Item = (T, Index)>,
750    T: Numeric + Primitive,
751    for<'t> &'t T: NumericRef<T>,
752{
753    let mut zs = vec![(T::zero(), 0); total];
754    history.borrow(|history| {
755        // shadow the name so we can't accidentally try to use history while holding
756        // the borrow
757        // use enumerate not with_index because we need the 1D index for indexing
758        // indexes
759        for (i, ((x, _), (y, parent2))) in (x_records.zip(y_records)).enumerate() {
760            let z = fxy(x.clone(), y.clone());
761            // if self didn't have a history, don't track that derivative
762            let derivative2 = dfxy_dy(x, y);
763            let new_index = history.append_unary(parent2, derivative2);
764            zs[i] = (z, new_index);
765        }
766    }); // drop borrow on history
767    zs
768}
769
770impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
771where
772    T: Numeric + Primitive,
773    for<'t> &'t T: NumericRef<T>,
774    S: TensorRef<(T, Index), D>,
775{
776    /**
777     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
778     * some unary function from `T` to `T` to every element in the container.
779     *
780     * To compute the new records, the unary function of some input x to some
781     * output y is needed along with its derivative with respect to its input x.
782     *
783     * For example, tanh is a commonly used activation function, but the Real trait
784     * does not include this operation and Record has no operations for it specifically.
785     * However, you can use this function to compute the tanh for a record container like so:
786     *
787     * ```
788     * use easy_ml::differentiation::{RecordTensor, WengertList};
789     * use easy_ml::tensors::Tensor;
790     * let list = WengertList::new();
791     * let X = RecordTensor::variables(
792     *     &list,
793     *     Tensor::from_fn(
794     *         [("rows", 2), ("columns", 2)],
795     *         |[r, c]| 0.15 * ((1 + r + c) as f32)
796     *     )
797     * );
798     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
799     * // 1 / (cosh(x) * cosh(x))
800     * let Y = X.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
801     *
802     * // we can unwrap here because we know Y contains variables not constants
803     * let derivatives = Y.derivatives().unwrap();
804     * let derivatives_indexing = derivatives.index_by(["rows", "columns"]);
805     * assert_eq!(
806     *     derivatives_indexing.get_ref([0, 0]).at_tensor(&X),
807     *     Tensor::from(
808     *         [("rows", 2), ("columns", 2)],
809     *         // [0, 0] element in Y only had the one input variable [0, 0] in X
810     *         vec![
811     *             0.9778332, 0.0,
812     *             0.0,       0.0
813     *        ]
814     *     ),
815     * );
816     * assert_eq!(
817     *     derivatives_indexing.get_ref([0, 1]).at_tensor(&X),
818     *     Tensor::from(
819     *         [("rows", 2), ("columns", 2)],
820     *         vec![
821     *             0.0, 0.915137,
822     *             0.0, 0.0
823     *        ]
824     *     ),
825     * );
826     * assert_eq!(
827     *     // [0, 1] and [1, 0] elements in X had the same starting value so end up with the same
828     *     // derivative for their corresponding input variable in X
829     *     derivatives_indexing.get_ref([0, 1]).at_tensor(&X).index().get([0, 1]),
830     *     derivatives_indexing.get_ref([1, 0]).at_tensor(&X).index().get([1, 0]),
831     * );
832     * assert_eq!(
833     *     derivatives_indexing.get_ref([1, 1]).at_tensor(&X),
834     *     Tensor::from(
835     *         [("rows", 2), ("columns", 2)],
836     *         vec![
837     *             0.0, 0.0,
838     *             0.0, 0.8220013
839     *        ]
840     *     ),
841     * );
842     * ```
843     */
844    #[track_caller]
845    pub fn unary(
846        &self,
847        fx: impl Fn(T) -> T,
848        dfx_dx: impl Fn(T) -> T,
849    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D> {
850        let total = self.elements();
851        match self.history {
852            None => RecordTensor::constants(self.numbers.map(|(x, _)| fx(x))),
853            Some(history) => {
854                let ys = unary::<T, _>(total, history, self.numbers.iter(), fx, dfx_dx);
855                RecordContainer {
856                    numbers: self.numbers.new_with_same_shape(ys),
857                    history: Some(history),
858                }
859            }
860        }
861    }
862
863    /**
864     * Creates a new RecordContainer from two RecordContainers by applying
865     * some binary function from `T` to `T` to every element pair in the containers. Both
866     * containers must have the same shape.
867     *
868     * To compute the new records, the binary function of some inputs x and y to some
869     * output z is needed along with its derivative with respect to its first input x and
870     * its derivative with respect to its second input y.
871     *
872     * For example, atan2 takes two arguments, but the Real trait
873     * does not include this operation and Record has no operations for it specifically.
874     * However, you can use this function to compute the atan2 for two record containers like so:
875     *
876     * ```
877     * use easy_ml::differentiation::{RecordTensor, WengertList};
878     * use easy_ml::tensors::Tensor;
879     * let list = WengertList::new();
880     * let X = RecordTensor::variables(
881     *     &list,
882     *     Tensor::from_fn(
883     *         [("rows", 2), ("columns", 2)],
884     *         |[r, c]| ((1 + r + c) as f32)
885     *     )
886     * );
887     * let Y = RecordTensor::variables(
888     *     &list,
889     *     Tensor::from_fn(
890     *         [("rows", 2), ("columns", 2)],
891     *         |[r, c]| ((1 + r + c) as f32)
892     *     )
893     * );
894     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
895     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
896     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
897     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
898     * let Z = X.binary(&Y,
899     *     |x, y| x.atan2(y),
900     *     |x, y| y/((x*x) + (y*y)),
901     *     |x, y| -x/((x*x) + (y*y))
902     * );
903     *
904     *
905     * // we can unwrap here because we know Z contains variables not constants
906     * let derivatives = Z.derivatives().unwrap();
907     * // Just as in the unary example, only one pair of the four inputs in X and Y influence the
908     * // outputs in Z, so we have a lot of 0.0 derivatives, and the inputs in [0, 1] and [1, 0]
909     * // are identical so we see the same derivative.
910     * let dZ_dX = derivatives.map(|d| d.at_tensor(&X));
911     * assert_eq!(
912     *     dZ_dX,
913     *     Tensor::from([("rows", 2), ("columns", 2)], vec![
914     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
915     *             0.5, 0.0,
916     *             0.0, 0.0
917     *         ]),
918     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
919     *             0.0, 0.25,
920     *             0.0, 0.0
921     *         ]),
922     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
923     *             0.0, 0.0,
924     *             0.25, 0.0
925     *         ]),
926     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
927     *             0.0, 0.0,
928     *             0.0, 0.16666667
929     *         ])
930     *     ])
931     * );
932     * let dZ_dY = derivatives.map(|d| d.at_tensor(&Y));
933     * assert_eq!(
934     *     dZ_dY,
935     *     Tensor::from([("rows", 2), ("columns", 2)], vec![
936     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
937     *             -0.5, 0.0,
938     *             0.0, 0.0
939     *         ]),
940     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
941     *             0.0, -0.25,
942     *             0.0, 0.0
943     *         ]),
944     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
945     *             0.0, 0.0,
946     *             -0.25, 0.0
947     *         ]),
948     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
949     *             0.0, 0.0,
950     *             0.0, -0.16666667
951     *         ])
952     *     ])
953     * );
954     * ```
955     *
956     * # Panics
957     *
958     * - If both record containers have a WengertList that are different to each other
959     * - If the record containers have different shapes
960     */
961    #[track_caller]
962    pub fn binary<S2>(
963        &self,
964        rhs: &RecordTensor<'a, T, S2, D>,
965        fxy: impl Fn(T, T) -> T,
966        dfxy_dx: impl Fn(T, T) -> T,
967        dfxy_dy: impl Fn(T, T) -> T,
968    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
969    where
970        S2: TensorRef<(T, Index), D>,
971    {
972        {
973            let left_shape = self.numbers.shape();
974            let right_shape = rhs.numbers.shape();
975            if left_shape != right_shape {
976                panic!(
977                    "Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
978                    left_shape, right_shape
979                );
980            }
981        }
982        let total = self.elements();
983        match (self.history, rhs.history) {
984            (None, None) => RecordTensor::constants(
985                // use direct_from here maybe?
986                Tensor::from(
987                    self.numbers.shape(),
988                    self.numbers
989                        .iter()
990                        .zip(rhs.numbers.iter())
991                        .map(|((x, _), (y, _))| fxy(x, y))
992                        .collect(),
993                ),
994            ),
995            (Some(history), None) => {
996                let zs = binary_x_history::<T, _, _>(
997                    total,
998                    history,
999                    self.numbers.iter(),
1000                    rhs.numbers.iter(),
1001                    fxy,
1002                    dfxy_dx,
1003                );
1004                RecordContainer {
1005                    numbers: self.numbers.new_with_same_shape(zs),
1006                    history: Some(history),
1007                }
1008            }
1009            (None, Some(history)) => {
1010                let zs = binary_y_history::<T, _, _>(
1011                    total,
1012                    history,
1013                    self.numbers.iter(),
1014                    rhs.numbers.iter(),
1015                    fxy,
1016                    dfxy_dy,
1017                );
1018                RecordContainer {
1019                    numbers: self.numbers.new_with_same_shape(zs),
1020                    history: Some(history),
1021                }
1022            }
1023            (Some(history), Some(h)) => {
1024                assert!(
1025                    record_operations::same_lists(history, h),
1026                    "Record containers must be using the same WengertList"
1027                );
1028                let zs = binary_both_history::<T, _, _>(
1029                    total,
1030                    history,
1031                    self.numbers.iter(),
1032                    rhs.numbers.iter(),
1033                    fxy,
1034                    dfxy_dx,
1035                    dfxy_dy,
1036                );
1037                RecordContainer {
1038                    numbers: self.numbers.new_with_same_shape(zs),
1039                    history: Some(history),
1040                }
1041            }
1042        }
1043    }
1044
1045    /**
1046     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1047     * some unary function on `Record<T>` to `Record<T>` to every element in the container. This
1048     * will fail if the function would create records with inconsistent histories.
1049     *
1050     * When used with pure functions that can't return different histories for different inputs
1051     * unwrapping with always succeed.
1052     *
1053     * This API can allow you to call a generic function that operates on
1054     * [Numeric](crate::numeric::Numeric) or [Real](crate::numeric::extra::Real) numbers and
1055     * apply all the correct derivative tracking during the intermediate calculations for you,
1056     * without having to resort to storing the Record types.
1057     *
1058     * ```
1059     * use easy_ml::numeric::extra::Real;
1060     * use easy_ml::tensors::Tensor;
1061     * use easy_ml::differentiation::{RecordTensor, WengertList};
1062     *
1063     * fn sigmoid<T: Real+ Copy>(x: T) -> T {
1064     *     T::one() / (T::one() + (-x).exp())
1065     * }
1066     *
1067     * let history = WengertList::new();
1068     * let layer = RecordTensor::variables(&history, Tensor::from([("x", 2)], vec![ 0.2, 0.6 ]));
1069     * let after = layer.map(sigmoid).unwrap(); // sigmoid can't introduce inconsistent histories
1070     * ```
1071     *
1072     * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1073     * after mapping doesn't have to be the same as before, only must be the same for every
1074     * mapped element.
1075     *
1076     * See also: [AsRecords](AsRecords)
1077     */
1078    #[allow(clippy::type_complexity)]
1079    #[track_caller]
1080    pub fn map(
1081        &self,
1082        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1083    ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
1084        let result = RecordTensor::from_iter(self.shape(), self.iter_as_records().map(fx));
1085        RecordTensor::<'a, T, S, D>::map_collection(result, self.shape())
1086    }
1087
1088    /**
1089     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1090     * some unary function on `Record<T>` and each index of that position in the Record to
1091     * `Record<T>` to every element in the container. This will fail if the function would create
1092     * records with inconsistent histories.
1093     *
1094     * When used with pure functions that can't return different histories for different inputs
1095     * unwrapping with always succeed.
1096     *
1097     * This API can allow you to call a generic function that operates on
1098     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1099     * during the intermediate calculations for you, without having to resort to storing the
1100     * Record types.
1101     *
1102     * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1103     * after mapping doesn't have to be the same as before, only must be the same for every
1104     * mapped element.
1105     */
1106    #[allow(clippy::type_complexity)]
1107    #[track_caller]
1108    pub fn map_with_index(
1109        &self,
1110        fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
1111    ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
1112        let result = RecordTensor::from_iter(
1113            self.shape(),
1114            self.iter_as_records().with_index().map(|(i, x)| fx(i, x)),
1115        );
1116        RecordTensor::<'a, T, S, D>::map_collection(result, self.shape())
1117    }
1118
1119    #[track_caller]
1120    #[allow(clippy::type_complexity)]
1121    fn map_collection(
1122        result: Result<
1123            RecordTensor<'a, T, Tensor<(T, usize), D>, D>,
1124            InvalidRecordIteratorError<'a, T, D>,
1125        >,
1126        shape: [(Dimension, usize); D],
1127    ) -> Result<RecordTensor<'a, T, Tensor<(T, usize), D>, D>, InconsistentHistory<'a, T>> {
1128        use InvalidRecordIteratorError as Error;
1129        match result {
1130            Ok(tensor) => Ok(tensor),
1131            Err(error) => match error {
1132                // These first two should be 100% impossible but provide a sensible error just
1133                // in case some weird things break our invariants
1134                Error::Empty => panic!("Illegal state, record tensor was empty {:?}", shape),
1135                Error::Shape { requested, length } => panic!(
1136                    "Illegal state, record tensor shape was inconsistent: requested: {:?}, length of data: {:?}",
1137                    requested, length
1138                ),
1139                // This one is theoretically possible but in practise shouldn't happen by accident
1140                // However, it can't implement Debug unless T is debug so to avoid having to
1141                // restrict our function signature we return a Result anyway - this also encourages
1142                // the user to make sure their function isn't going to cause this case, which
1143                // with some of the other variants like with_index might come up more easily
1144                Error::InconsistentHistory(h) => Err(h),
1145            },
1146        }
1147    }
1148
1149    /**
1150     * For each record in the container, peforms a backward pass up its WengertList from it
1151     * as the output, computing all the derivatives for the inputs involving this output.
1152     *
1153     * If this container has no backing WengertList, ie was created as constants, then None is
1154     * returned instead. Otherwise the returned Tensor will have the same shape as this container,
1155     * with the respective derivatives matching each element in this container.
1156     *
1157     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is Y with M outputs,
1158     * then this computes all the derivatives δy<sub>j</sub>/δx<sub>i</sub> for i = 1 to N and
1159     * j = 1 to M.
1160     *
1161     * If you have a lot of outputs this could be very expensive! Reverse auto diff is optimised
1162     * for domains where there are many more inputs than outputs.
1163     *
1164     * If you only need some of the derivatives then
1165     * [derivatives_for](RecordTensor::derivatives_for) can be used instead to avoid
1166     * calculating the rest.
1167     */
1168    pub fn derivatives(&self) -> Option<Tensor<Derivatives<T>, D>> {
1169        self.history.map(|history| {
1170            self.numbers.map(|(x, i)| {
1171                Record {
1172                    number: x,
1173                    history: Some(history),
1174                    index: i,
1175                }
1176                .derivatives()
1177            })
1178        })
1179    }
1180
1181    /**
1182     * For the record at the index, peforms a backward pass up its WengertList from it
1183     * as the output, computing all the derivatives for the inputs involving this output.
1184     *
1185     * If the index is invalid or this container has no backing WengertList, ie was created
1186     * as constants, then None is returned instead.
1187     *
1188     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
1189     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
1190     */
1191    pub fn derivatives_for(&self, indexes: [usize; D]) -> Option<Derivatives<T>> {
1192        let (number, index) = self.get_reference(indexes).map(|(x, i)| (x.clone(), *i))?;
1193        // The nature of reverse autodiff is that we expect to only have a few outputs from
1194        // which we calculate all the derivatives we care about. Therefore just call Record and
1195        // reuse the implementation instead of trying to do anything clever like calculate all
1196        // derivatives for every number in this container.
1197        Record {
1198            number,
1199            history: self.history,
1200            index,
1201        }
1202        .try_derivatives()
1203    }
1204
1205    /**
1206     * Performs elementwise multiplication for two record tensors of the same shape.
1207     *
1208     * # Panics
1209     *
1210     * - If both record containers have a WengertList that are different to each other
1211     * - If the record containers have different shapes
1212     */
1213    // TODO: Assign variants?
1214    pub fn elementwise_multiply<S2>(
1215        &self,
1216        other: &RecordTensor<'a, T, S2, D>,
1217    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1218    where
1219        S2: TensorRef<(T, Index), D>,
1220    {
1221        self.binary(
1222            other,
1223            Multiplication::<T>::function,
1224            Multiplication::<T>::d_function_dx,
1225            Multiplication::<T>::d_function_dy,
1226        )
1227    }
1228
1229    /**
1230     * Performs elementwise division for two record tensors of the same shape.
1231     *
1232     * # Panics
1233     *
1234     * - If both record containers have a WengertList that are different to each other
1235     * - If the record containers have different shapes
1236     */
1237    pub fn elementwise_divide<S2>(
1238        &self,
1239        other: &RecordTensor<'a, T, S2, D>,
1240    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1241    where
1242        S2: TensorRef<(T, Index), D>,
1243    {
1244        self.binary(
1245            other,
1246            Division::<T>::function,
1247            Division::<T>::d_function_dx,
1248            Division::<T>::d_function_dy,
1249        )
1250    }
1251}
1252
1253impl<T: Clone + Primitive> Derivatives<T> {
1254    /**
1255     * Queries the derivative at the provided index into the record tensor as input.
1256     *
1257     * If you construct a Derivatives object for some output y,
1258     * and call .at_tensor_index(i, &xs) on it for some input container xs and index i, this
1259     * returns dy/dx where x = xs\[i\].
1260     *
1261     * If the index into the tensor is invalid, returns None instead.
1262     */
1263    pub fn at_tensor_index<S, const D: usize>(
1264        &self,
1265        indexes: [usize; D],
1266        input: &RecordTensor<T, S, D>,
1267    ) -> Option<T>
1268    where
1269        S: TensorRef<(T, Index), D>,
1270    {
1271        let index = input.get_reference(indexes).map(|(_, i)| *i)?;
1272        Some(self.derivatives[index].clone())
1273    }
1274
1275    /**
1276     * Queries the derivatives at every element in the record tensor input.
1277     *
1278     * If you construct a Derivatives object for some output y,
1279     * and call .at_tensor(&xs) on it for some input container xs this
1280     * returns dy/dx for every x in xs.
1281     */
1282    pub fn at_tensor<S, const D: usize>(&self, input: &RecordTensor<T, S, D>) -> Tensor<T, D>
1283    where
1284        S: TensorRef<(T, Index), D>,
1285    {
1286        input.numbers.map(|(_, i)| self.derivatives[i].clone())
1287    }
1288
1289    /**
1290     * Queries the derivative at the provided index into the record matrix as input.
1291     *
1292     * If you construct a Derivatives object for some output y,
1293     * and call .at_matrix_index(i, j, &xs) on it for some input container xs and indexes i and j,
1294     * this returns dy/dx where x = xs\[i, j\].
1295     *
1296     * If the index into the tensor is invalid, returns None instead.
1297     */
1298    pub fn at_matrix_index<S>(
1299        &self,
1300        row: Row,
1301        column: Column,
1302        input: &RecordMatrix<T, S>,
1303    ) -> Option<T>
1304    where
1305        S: MatrixRef<(T, Index)> + NoInteriorMutability,
1306    {
1307        let index = input.try_get_reference(row, column).map(|(_, i)| *i)?;
1308        Some(self.derivatives[index].clone())
1309    }
1310
1311    /**
1312     * Queries the derivatives at every element in the record matrix input.
1313     *
1314     * If you construct a Derivatives object for some output y,
1315     * and call .at_matrix(&xs) on it for some input container xs this
1316     * returns dy/dx for every x in xs.
1317     */
1318    pub fn at_matrix<S>(&self, input: &RecordMatrix<T, S>) -> Matrix<T>
1319    where
1320        S: MatrixRef<(T, Index)> + NoInteriorMutability,
1321    {
1322        input.numbers.map(|(_, i)| self.derivatives[i].clone())
1323    }
1324}
1325
1326impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
1327where
1328    T: Numeric + Primitive,
1329    for<'t> &'t T: NumericRef<T>,
1330    S: TensorMut<(T, Index), D>,
1331{
1332    /**
1333     * Overwrites a RecordContainer by applying
1334     * some unary function from `T` to `T` to every element in the container.
1335     *
1336     * To compute the new records, the unary function of some input x to some
1337     * output y is needed along with its derivative with respect to its input x.
1338     */
1339    #[track_caller]
1340    pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
1341        let total = self.elements();
1342        match self.history {
1343            None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
1344            Some(history) => {
1345                let ys = unary::<T, _>(total, history, self.numbers.iter(), fx, dfx_dx);
1346                for (element, result) in self.numbers.iter_reference_mut().zip(ys) {
1347                    *element = result;
1348                }
1349                self.history = Some(history);
1350            }
1351        }
1352    }
1353
1354    /**
1355     * Overwrites the left hand side of a RecordContainer with the result of applying
1356     * some binary function from `T` to `T` to every element pair in the containers. Both
1357     * containers must have the same shape.
1358     * To compute the new records, the binary function of some inputs x and y to some
1359     * output z is needed along with its derivative with respect to its first input x and
1360     * its derivative with respect to its second input y.
1361     *
1362     * # Panics
1363     *
1364     * - If both record containers have a WengertList that are different to each other
1365     * - If the record containers have different shapes
1366     */
1367    #[track_caller]
1368    pub fn binary_left_assign<S2>(
1369        &mut self,
1370        rhs: &RecordTensor<'a, T, S2, D>,
1371        fxy: impl Fn(T, T) -> T,
1372        dfxy_dx: impl Fn(T, T) -> T,
1373        dfxy_dy: impl Fn(T, T) -> T,
1374    ) where
1375        S2: TensorRef<(T, Index), D>,
1376    {
1377        {
1378            let left_shape = self.numbers.shape();
1379            let right_shape = rhs.numbers.shape();
1380            if left_shape != right_shape {
1381                panic!(
1382                    "Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
1383                    left_shape, right_shape
1384                );
1385            }
1386        }
1387        let total = self.elements();
1388        match (self.history, rhs.history) {
1389            (None, None) => {
1390                for (x, y) in self.numbers.iter_reference_mut().zip(rhs.numbers.iter()) {
1391                    let (left, _) = x;
1392                    let (right, _) = y;
1393                    *x = (fxy(left.clone(), right), 0);
1394                }
1395            }
1396            (Some(history), None) => {
1397                let zs = binary_x_history::<T, _, _>(
1398                    total,
1399                    history,
1400                    self.numbers.iter(),
1401                    rhs.numbers.iter(),
1402                    fxy,
1403                    dfxy_dx,
1404                );
1405                for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1406                    *element = result;
1407                }
1408                self.history = Some(history);
1409            }
1410            (None, Some(history)) => {
1411                let zs = binary_y_history::<T, _, _>(
1412                    total,
1413                    history,
1414                    self.numbers.iter(),
1415                    rhs.numbers.iter(),
1416                    fxy,
1417                    dfxy_dy,
1418                );
1419                for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1420                    *element = result;
1421                }
1422                self.history = Some(history);
1423            }
1424            (Some(history), Some(h)) => {
1425                assert!(
1426                    record_operations::same_lists(history, h),
1427                    "Record containers must be using the same WengertList"
1428                );
1429                let zs = binary_both_history::<T, _, _>(
1430                    total,
1431                    history,
1432                    self.numbers.iter(),
1433                    rhs.numbers.iter(),
1434                    fxy,
1435                    dfxy_dx,
1436                    dfxy_dy,
1437                );
1438                for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1439                    *element = result;
1440                }
1441                self.history = Some(history);
1442            }
1443        }
1444    }
1445
1446    /**
1447     * A convenience helper function which takes the RecordContainer value and
1448     * calls [unary_assign](RecordTensor::unary_assign()) on it, returning
1449     * the record container which now contains the result of the operation.
1450     */
1451    #[track_caller]
1452    pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
1453        self.unary_assign(fx, dfx_dx);
1454        self
1455    }
1456
1457    /**
1458     * A convenience helper function which takes the left hand side by value and
1459     * calls [binary_left_assign](RecordTensor::binary_left_assign()) on it, returning
1460     * the left hand side which now contains the result of the operation.
1461     */
1462    #[track_caller]
1463    pub fn do_binary_left_assign<S2>(
1464        mut self,
1465        rhs: &RecordTensor<'a, T, S2, D>,
1466        fxy: impl Fn(T, T) -> T,
1467        dfxy_dx: impl Fn(T, T) -> T,
1468        dfxy_dy: impl Fn(T, T) -> T,
1469    ) -> Self
1470    where
1471        S2: TensorRef<(T, Index), D>,
1472    {
1473        self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
1474        self
1475    }
1476
1477    /**
1478     * Updates this RecordContainer in place by applying some unary function on `Record<T>` to
1479     * `Record<T>` to every element in the container. This will fail if the function would create
1480     * records with inconsistent histories.
1481     *
1482     * When used with pure functions that can't return different histories for different inputs
1483     * unwrapping with always succeed.
1484     *
1485     * Since this updates the container in place, if Err is returned then the data in this
1486     * RecordContainer is still available but it has been corrupted - at least one of the elements
1487     * should have a different history than what it will have because the mapping function created
1488     * inconsistent histories that couldn't be represented by the container as it only stores
1489     * one.
1490     *
1491     * This API can allow you to call a generic function that operates on
1492     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1493     * during the intermediate calculations for you, without having to resort to storing the
1494     * Record types.
1495     *
1496     * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1497     * after mapping doesn't have to be the same as before, only must be the same for every
1498     * mapped element.
1499     *
1500     * You might also use this function at the end of a training loop to update all the weights
1501     * to reduce their loss.
1502     *
1503     * ```
1504     * use easy_ml::numeric::Numeric;
1505     * use easy_ml::tensors::Tensor;
1506     * use easy_ml::differentiation::{Record, RecordTensor, WengertList};
1507     *
1508     * let history = WengertList::new();
1509     * let mut weights = RecordTensor::variables(
1510     *     &history,
1511     *     Tensor::from([("w1", 4)], vec![ 0.3, 0.2, -1.2, -0.4 ])
1512     * );
1513     * let error = {
1514     *     // this is over-simplified for brevity, obviously in a real scenario we wouldn't have a
1515     *     // function that calculates the error like this or we wouldn't be doing machine learning
1516     *     // to fit it in the first place
1517     *     let mut loss = Record::variable(0.0, &history);
1518     *     for r in weights.iter_as_records() {
1519     *         loss = loss + r;
1520     *     }
1521     *     loss
1522     * };
1523     * let derivatives = error.derivatives();
1524     * let learning_rate = 0.1;
1525     * // update the weights to contain less error than they did before
1526     * let result = weights.map_mut(|x| x - (derivatives[&x] * learning_rate));
1527     * assert!(result.is_ok()); // we know we didn't introduce an inconsistent history just updating the weights
1528     * ```
1529     */
1530    #[track_caller]
1531    pub fn map_mut(
1532        &mut self,
1533        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1534    ) -> Result<(), InconsistentHistory<'a, T>> {
1535        let history = self.history;
1536        let new_history =
1537            map_mut_base::<'a, T, _, _>(TensorReferenceMutIterator::from(self), |x| {
1538                let record = Record::from_existing(x.clone(), history);
1539                let result = fx(record);
1540                *x = (result.number, result.index);
1541                result.history
1542            })?;
1543        self.history = new_history;
1544        Ok(())
1545    }
1546
1547    /**
1548     * Updates this RecordContainer in place by applying some unary function on `Record<T>` and
1549     * each index of that position in the Record to `Record<T>` to every element in the container.
1550     * This will fail if the function would create records with inconsistent histories.
1551     *
1552     * When used with pure functions that can't return different histories for different inputs
1553     * unwrapping with always succeed.
1554     *
1555     * Since this updates the container in place, if Err is returned then the data in this
1556     * RecordContainer is still available but it has been corrupted - at least one of the elements
1557     * should have a different history than what it will have because the mapping function created
1558     * inconsistent histories that couldn't be represented by the container as it only stores
1559     * one.
1560     *
1561     * This API can allow you to call a generic function that operates on
1562     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1563     * during the intermediate calculations for you, without having to resort to storing the
1564     * Record types.
1565     *
1566     * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1567     * after mapping doesn't have to be the same as before, only must be the same for every
1568     * mapped element.
1569     */
1570    #[track_caller]
1571    pub fn map_mut_with_index(
1572        &mut self,
1573        fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
1574    ) -> Result<(), InconsistentHistory<'a, T>> {
1575        let history = self.history;
1576        let new_history = map_mut_base::<'a, T, _, _>(
1577            TensorReferenceMutIterator::from(self).with_index(),
1578            |(i, x)| {
1579                let record = Record::from_existing(x.clone(), history);
1580                let result = fx(i, record);
1581                *x = (result.number, result.index);
1582                result.history
1583            },
1584        )?;
1585        self.history = new_history;
1586        Ok(())
1587    }
1588}
1589
1590#[track_caller]
1591fn map_mut_base<'a, T, I, X>(
1592    mut iter: I,
1593    fx: impl Fn(X) -> Option<&'a WengertList<T>>,
1594) -> Result<Option<&'a WengertList<T>>, InconsistentHistory<'a, T>>
1595where
1596    I: Iterator<Item = X>,
1597    T: Primitive,
1598{
1599    use crate::differentiation::record_operations::are_exact_same_list;
1600    #[rustfmt::skip]
1601    let first_history = fx(iter.next().expect("Illegal state, record container was empty"));
1602    let mut different_history: Option<Option<&WengertList<T>>> = None;
1603    for x in iter {
1604        let history = fx(x);
1605        if !are_exact_same_list(history, first_history) {
1606            different_history = Some(history);
1607        }
1608    }
1609    match different_history {
1610        None => Ok(first_history),
1611        Some(h) => Err(InconsistentHistory {
1612            first: first_history,
1613            later: h,
1614        }),
1615    }
1616}
1617
1618impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
1619where
1620    T: Numeric + Primitive,
1621    for<'t> &'t T: NumericRef<T>,
1622    S: TensorRef<(T, Index), D>,
1623{
1624    /**
1625     * Overwrites the right hand side of a RecordContainer with the result of applying
1626     * some binary function from `T` to `T` to every element pair in the containers. Both
1627     * containers must have the same shape.
1628     * To compute the new records, the binary function of some inputs x and y to some
1629     * output z is needed along with its derivative with respect to its first input x and
1630     * its derivative with respect to its second input y.
1631     *
1632     * # Panics
1633     *
1634     * - If both record containers have a WengertList that are different to each other
1635     * - If the record containers have different shapes
1636     */
1637    #[track_caller]
1638    pub fn binary_right_assign<S2>(
1639        &self,
1640        rhs: &mut RecordTensor<'a, T, S2, D>,
1641        fxy: impl Fn(T, T) -> T,
1642        dfxy_dx: impl Fn(T, T) -> T,
1643        dfxy_dy: impl Fn(T, T) -> T,
1644    ) where
1645        S2: TensorMut<(T, Index), D>,
1646    {
1647        // x is lhs, y is rhs, so calling binary_left_assign on the rhs container
1648        // means we need to swap all the arguments
1649        rhs.binary_left_assign(
1650            self,
1651            |y, x| fxy(x, y),
1652            |y, x| dfxy_dy(x, y),
1653            |y, x| dfxy_dx(x, y),
1654        )
1655    }
1656
1657    /**
1658     * A convenience helper function which takes the right hand side by value and
1659     * calls [binary_right_assign](RecordTensor::binary_right_assign()) on it, returning
1660     * the right hand side which now contains the result of the operation.
1661     */
1662    #[track_caller]
1663    pub fn do_binary_right_assign<S2>(
1664        &self,
1665        mut rhs: RecordTensor<'a, T, S2, D>,
1666        fxy: impl Fn(T, T) -> T,
1667        dfxy_dx: impl Fn(T, T) -> T,
1668        dfxy_dy: impl Fn(T, T) -> T,
1669    ) -> RecordTensor<'a, T, S2, D>
1670    where
1671        S2: TensorMut<(T, Index), D>,
1672    {
1673        self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
1674        rhs
1675    }
1676}
1677
1678impl<'a, T, S> RecordMatrix<'a, T, S>
1679where
1680    T: Numeric + Primitive,
1681    for<'t> &'t T: NumericRef<T>,
1682    S: MatrixRef<(T, Index)> + NoInteriorMutability,
1683{
1684    /**
1685     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1686     * some unary function from `T` to `T` to every element in the container.
1687     *
1688     * To compute the new records, the unary function of some input x to some
1689     * output y is needed along with its derivative with respect to its input x.
1690     *
1691     * For example, tanh is a commonly used activation function, but the Real trait
1692     * does not include this operation and Record has no operations for it specifically.
1693     * However, you can use this function to compute the tanh for a record container like so:
1694     *
1695     * ```
1696     * use easy_ml::differentiation::{RecordMatrix, WengertList};
1697     * use easy_ml::matrices::Matrix;
1698     * let list = WengertList::new();
1699     * let X = RecordMatrix::variables(
1700     *     &list,
1701     *     Matrix::from_fn((2, 2), |(r, c)| 0.15 * ((1 + r + c) as f32))
1702     * );
1703     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
1704     * // 1 / (cosh(x) * cosh(x))
1705     * let Y = X.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
1706     *
1707     * // we can unwrap here because we know Y contains variables not constants
1708     * let derivatives = Y.derivatives().unwrap();
1709     * assert_eq!(
1710     *     derivatives.get_reference(0, 0).at_matrix(&X),
1711     *     Matrix::from(vec![
1712     *         // (0, 0) element in Y only had the one input variable (0, 0) in X
1713     *         vec![0.9778332, 0.0],
1714     *         vec![0.0,       0.0]
1715     *     ]),
1716     * );
1717     * assert_eq!(
1718     *     derivatives.get_reference(0, 1).at_matrix(&X),
1719     *     Matrix::from(vec![
1720     *         vec![0.0, 0.915137],
1721     *         vec![0.0,      0.0]
1722     *     ]),
1723     * );
1724     * assert_eq!(
1725     *     // (0, 1) and (1, 0) elements in X had the same starting value so end up with the same
1726     *     // derivative for their corresponding input variable in X
1727     *     derivatives.get_reference(0, 1).at_matrix(&X).get(0, 1),
1728     *     derivatives.get_reference(1, 0).at_matrix(&X).get(1, 0),
1729     * );
1730     * assert_eq!(
1731     *     derivatives.get_reference(1, 1).at_matrix(&X),
1732     *     Matrix::from(vec![
1733     *         vec![0.0, 0.0      ],
1734     *         vec![0.0, 0.8220013]
1735     *     ]),
1736     * );
1737     * ```
1738     */
1739    #[track_caller]
1740    pub fn unary(
1741        &self,
1742        fx: impl Fn(T) -> T,
1743        dfx_dx: impl Fn(T) -> T,
1744    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>> {
1745        let total = self.elements();
1746        match self.history {
1747            None => RecordMatrix::constants(self.numbers.map(|(x, _)| fx(x))),
1748            Some(history) => {
1749                let ys = unary::<T, _>(total, history, self.numbers.row_major_iter(), fx, dfx_dx);
1750                RecordContainer {
1751                    numbers: MatrixView::from(Matrix::from_flat_row_major(self.numbers.size(), ys)),
1752                    history: Some(history),
1753                }
1754            }
1755        }
1756    }
1757
1758    /**
1759     * Creates a new RecordContainer from two RecordContainers by applying
1760     * some binary function from `T` to `T` to every element pair in the containers. Both
1761     * containers must have the same shape.
1762     *
1763     * To compute the new records, the binary function of some inputs x and y to some
1764     * output z is needed along with its derivative with respect to its first input x and
1765     * its derivative with respect to its second input y.
1766     *
1767     * For example, atan2 takes two arguments, but the Real trait
1768     * does not include this operation and Record has no operations for it specifically.
1769     * However, you can use this function to compute the atan2 for two record containers like so:
1770     *
1771     * ```
1772     * use easy_ml::differentiation::{RecordMatrix, WengertList};
1773     * use easy_ml::matrices::Matrix;
1774     * let list = WengertList::new();
1775     * let X = RecordMatrix::variables(
1776     *     &list,
1777     *     Matrix::from_fn((2, 2), |(r, c)| ((1 + r + c) as f32))
1778     * );
1779     * let Y = RecordMatrix::variables(
1780     *     &list,
1781     *     Matrix::from_fn((2, 2), |(r, c)| ((1 + r + c) as f32))
1782     * );
1783     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
1784     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
1785     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
1786     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
1787     * let Z = X.binary(&Y,
1788     *     |x, y| x.atan2(y),
1789     *     |x, y| y/((x*x) + (y*y)),
1790     *     |x, y| -x/((x*x) + (y*y))
1791     * );
1792     *
1793     * // we can unwrap here because we know Z contains variables not constants
1794     * let derivatives = Z.derivatives().unwrap();
1795     * // Just as in the unary example, only one pair of the four inputs in X and Y influence the
1796     * // outputs in Z, so we have a lot of 0.0 derivatives, and the inputs in [0, 1] and [1, 0]
1797     * // are identical so we see the same derivative.
1798     * let dZ_dX = derivatives.map(|d| d.at_matrix(&X));
1799     * assert_eq!(
1800     *     dZ_dX,
1801     *     Matrix::from(vec![
1802     *          vec![
1803     *              Matrix::from(vec![
1804     *                  vec![ 0.5, 0.0 ],
1805     *                  vec![ 0.0, 0.0 ]
1806     *              ]),
1807     *              Matrix::from(vec![
1808     *                  vec![ 0.0, 0.25 ],
1809     *                  vec![ 0.0, 0.0 ]
1810     *              ])
1811     *          ],
1812     *          vec![
1813     *              Matrix::from(vec![
1814     *                  vec![ 0.0, 0.0 ],
1815     *                  vec![ 0.25, 0.0 ]
1816     *              ]),
1817     *              Matrix::from(vec![
1818     *                  vec![ 0.0, 0.0 ],
1819     *                  vec![ 0.0, 0.16666667 ]
1820     *              ])
1821     *          ]
1822     *     ])
1823     * );
1824     * let dZ_dY = derivatives.map(|d| d.at_matrix(&Y));
1825     * assert_eq!(
1826     *     dZ_dY,
1827     *     Matrix::from(vec![
1828     *          vec![
1829     *              Matrix::from(vec![
1830     *                  vec![ -0.5, 0.0 ],
1831     *                  vec![ 0.0, 0.0 ]
1832     *              ]),
1833     *              Matrix::from(vec![
1834     *                  vec![ 0.0, -0.25 ],
1835     *                  vec![ 0.0, 0.0 ]
1836     *              ])
1837     *          ],
1838     *          vec![
1839     *              Matrix::from(vec![
1840     *                  vec![ 0.0, 0.0 ],
1841     *                  vec![ -0.25, 0.0 ]
1842     *              ]),
1843     *              Matrix::from(vec![
1844     *                  vec![ 0.0, 0.0 ],
1845     *                  vec![ 0.0, -0.16666667 ]
1846     *              ])
1847     *          ]
1848     *     ])
1849     * );
1850     * ```
1851     *
1852     * # Panics
1853     *
1854     * - If both record containers have a WengertList that are different to each other
1855     * - If the record containers have different shapes
1856     */
1857    #[track_caller]
1858    pub fn binary<S2>(
1859        &self,
1860        rhs: &RecordMatrix<'a, T, S2>,
1861        fxy: impl Fn(T, T) -> T,
1862        dfxy_dx: impl Fn(T, T) -> T,
1863        dfxy_dy: impl Fn(T, T) -> T,
1864    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1865    where
1866        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1867    {
1868        let shape = {
1869            let left_shape = self.numbers.size();
1870            let right_shape = rhs.numbers.size();
1871            if left_shape != right_shape {
1872                panic!(
1873                    "Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
1874                    left_shape, right_shape
1875                );
1876            }
1877            left_shape
1878        };
1879        let total = self.elements();
1880        match (self.history, rhs.history) {
1881            (None, None) => RecordMatrix::constants(Matrix::from_flat_row_major(
1882                shape,
1883                self.numbers
1884                    .row_major_iter()
1885                    .zip(rhs.numbers.row_major_iter())
1886                    .map(|((x, _), (y, _))| fxy(x, y))
1887                    .collect(),
1888            )),
1889            (Some(history), None) => {
1890                let zs = binary_x_history::<T, _, _>(
1891                    total,
1892                    history,
1893                    self.numbers.row_major_iter(),
1894                    rhs.numbers.row_major_iter(),
1895                    fxy,
1896                    dfxy_dx,
1897                );
1898                RecordContainer {
1899                    numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1900                    history: Some(history),
1901                }
1902            }
1903            (None, Some(history)) => {
1904                let zs = binary_y_history::<T, _, _>(
1905                    total,
1906                    history,
1907                    self.numbers.row_major_iter(),
1908                    rhs.numbers.row_major_iter(),
1909                    fxy,
1910                    dfxy_dy,
1911                );
1912                RecordContainer {
1913                    numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1914                    history: Some(history),
1915                }
1916            }
1917            (Some(history), Some(h)) => {
1918                assert!(
1919                    record_operations::same_lists(history, h),
1920                    "Record containers must be using the same WengertList"
1921                );
1922                let zs = binary_both_history::<T, _, _>(
1923                    total,
1924                    history,
1925                    self.numbers.row_major_iter(),
1926                    rhs.numbers.row_major_iter(),
1927                    fxy,
1928                    dfxy_dx,
1929                    dfxy_dy,
1930                );
1931                RecordContainer {
1932                    numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1933                    history: Some(history),
1934                }
1935            }
1936        }
1937    }
1938
1939    /**
1940     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1941     * some unary function on `Record<T>` to `Record<T>` to every element in the container. This
1942     * will fail if the function would create records with inconsistent histories.
1943     *
1944     * When used with pure functions that can't return different histories for different inputs
1945     * unwrapping with always succeed.
1946     *
1947     * This API can allow you to call a generic function that operates on
1948     * [Numeric](crate::numeric::Numeric) or [Real](crate::numeric::extra::Real) numbers and
1949     * apply all the correct derivative tracking during the intermediate calculations for you,
1950     * without having to resort to storing the Record types.
1951     *
1952     * ```
1953     * use easy_ml::numeric::extra::Real;
1954     * use easy_ml::matrices::Matrix;
1955     * use easy_ml::differentiation::{RecordMatrix, WengertList};
1956     *
1957     * fn sigmoid<T: Real+ Copy>(x: T) -> T {
1958     *     T::one() / (T::one() + (-x).exp())
1959     * }
1960     *
1961     * let history = WengertList::new();
1962     * let layer = RecordMatrix::variables(&history, Matrix::from(vec![vec![ 0.2, 0.6 ]]));
1963     * let after = layer.map(sigmoid).unwrap(); // sigmoid can't introduce inconsistent histories
1964     * ```
1965     *
1966     * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
1967     * after mapping doesn't have to be the same as before, only must be the same for every
1968     * mapped element.
1969     *
1970     * See also: [AsRecords](AsRecords)
1971     */
1972    #[allow(clippy::type_complexity)]
1973    #[track_caller]
1974    pub fn map(
1975        &self,
1976        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1977    ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
1978        let result = RecordMatrix::from_iter(self.size(), self.iter_row_major_as_records().map(fx));
1979        RecordMatrix::<'a, T, S>::map_collection(result, self.size())
1980    }
1981
1982    /**
1983     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1984     * some unary function on `Record<T>` and each index of that position in the Record to
1985     * `Record<T>` to every element in the container. This will fail if the function would
1986     * create records with inconsistent histories.
1987     *
1988     * When used with pure functions that can't return different histories for different inputs
1989     * unwrapping with always succeed.
1990     *
1991     * This API can allow you to call a generic function that operates on
1992     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1993     * during the intermediate calculations for you, without having to resort to storing the
1994     * Record types.
1995     *
1996     * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
1997     * after mapping doesn't have to be the same as before, only must be the same for every
1998     * mapped element.
1999     */
2000    #[allow(clippy::type_complexity)]
2001    #[track_caller]
2002    pub fn map_with_index(
2003        &self,
2004        fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
2005    ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
2006        let result = RecordMatrix::from_iter(
2007            self.size(),
2008            self.iter_row_major_as_records()
2009                .with_index()
2010                .map(|((r, c), x)| fx(x, r, c)),
2011        );
2012        RecordMatrix::<'a, T, S>::map_collection(result, self.size())
2013    }
2014
2015    #[allow(clippy::type_complexity)]
2016    #[track_caller]
2017    fn map_collection(
2018        result: Result<
2019            RecordMatrix<'a, T, Matrix<(T, usize)>>,
2020            InvalidRecordIteratorError<'a, T, 2>,
2021        >,
2022        size: (Row, Column),
2023    ) -> Result<RecordMatrix<'a, T, Matrix<(T, usize)>>, InconsistentHistory<'a, T>> {
2024        use InvalidRecordIteratorError as Error;
2025        match result {
2026            Ok(matrix) => Ok(matrix),
2027            Err(error) => match error {
2028                // These first two should be 100% impossible but provide a sensible error just
2029                // in case some weird things break our invariants
2030                Error::Empty => panic!("Illegal state, record matrix was empty {:?}", size),
2031                Error::Shape { requested, length } => panic!(
2032                    "Illegal state, record matrix shape was inconsistent: requested: {:?}, length of data: {:?}",
2033                    requested, length
2034                ),
2035                // This one is theoretically possible but in practise shouldn't happen by accident
2036                // However, it can't implement Debug unless T is debug so to avoid having to
2037                // restrict our function signature we return a Result anyway - this also encourages
2038                // the user to make sure their function isn't going to cause this case, which
2039                // with some of the other variants like with_index might come up more easily
2040                Error::InconsistentHistory(h) => Err(h),
2041            },
2042        }
2043    }
2044
2045    /**
2046     * For each record in the container, peforms a backward pass up its WengertList from it
2047     * as the output, computing all the derivatives for the inputs involving this output.
2048     *
2049     * If this container has no backing WengertList, ie was created as constants, then None is
2050     * returned instead. Otherwise the returned Matrix will have the same size as this container,
2051     * with the respective derivatives matching each element in this container.
2052     *
2053     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is Y with M outputs,
2054     * then this computes all the derivatives δy<sub>j</sub>/δx<sub>i</sub> for i = 1 to N and
2055     * j = 1 to M.
2056     *
2057     * If you have a lot of outputs this could be very expensive! Reverse auto diff is optimised
2058     * for domains where there are many more inputs than outputs.
2059     *
2060     * If you only need some of the derivatives then
2061     * [derivatives_for](RecordMatrix::derivatives_for) can be used instead to avoid
2062     * calculating the rest.
2063     */
2064    pub fn derivatives(&self) -> Option<Matrix<Derivatives<T>>> {
2065        self.history.map(|history| {
2066            self.numbers.map(|(x, i)| {
2067                Record {
2068                    number: x,
2069                    history: Some(history),
2070                    index: i,
2071                }
2072                .derivatives()
2073            })
2074        })
2075    }
2076
2077    /**
2078     * For the record at the index, peforms a backward pass up its WengertList from it
2079     * as the output, computing all the derivatives for the inputs involving this output.
2080     *
2081     * If the index is invalid or this container has no backing WengertList, ie was created
2082     * as constants, then None is returned instead.
2083     *
2084     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
2085     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
2086     */
2087    pub fn derivatives_for(&self, row: Row, column: Column) -> Option<Derivatives<T>> {
2088        let (number, index) = self
2089            .try_get_reference(row, column)
2090            .map(|(x, i)| (x.clone(), *i))?;
2091        // The nature of reverse autodiff is that we expect to only have a few outputs from
2092        // which we calculate all the derivatives we care about. Therefore just call Record and
2093        // reuse the implementation instead of trying to do anything clever like calculate all
2094        // derivatives for every number in this container.
2095        Record {
2096            number,
2097            history: self.history,
2098            index,
2099        }
2100        .try_derivatives()
2101    }
2102
2103    /**
2104     * Performs elementwise multiplication for two record matrices of the same size.
2105     *
2106     * # Panics
2107     *
2108     * - If both record containers have a WengertList that are different to each other
2109     * - If the record containers have different shapes
2110     */
2111    // TODO: Assign variants?
2112    pub fn elementwise_multiply<S2>(
2113        &self,
2114        other: &RecordMatrix<'a, T, S2>,
2115    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
2116    where
2117        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2118    {
2119        self.binary(
2120            other,
2121            Multiplication::<T>::function,
2122            Multiplication::<T>::d_function_dx,
2123            Multiplication::<T>::d_function_dy,
2124        )
2125    }
2126
2127    /**
2128     * Performs elementwise division for two record matrices of the same size.
2129     *
2130     * # Panics
2131     *
2132     * - If both record containers have a WengertList that are different to each other
2133     * - If the record containers have different shapes
2134     */
2135    pub fn elementwise_divide<S2>(
2136        &self,
2137        other: &RecordMatrix<'a, T, S2>,
2138    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
2139    where
2140        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2141    {
2142        self.binary(
2143            other,
2144            Division::<T>::function,
2145            Division::<T>::d_function_dx,
2146            Division::<T>::d_function_dy,
2147        )
2148    }
2149}
2150
2151impl<'a, T, S> RecordMatrix<'a, T, S>
2152where
2153    T: Numeric + Primitive,
2154    for<'t> &'t T: NumericRef<T>,
2155    S: MatrixMut<(T, Index)> + NoInteriorMutability,
2156{
2157    /**
2158     * Overwrites a RecordContainer by applying
2159     * some unary function from `T` to `T` to every element in the container.
2160     *
2161     * To compute the new records, the unary function of some input x to some
2162     * output y is needed along with its derivative with respect to its input x.
2163     */
2164    #[track_caller]
2165    pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
2166        let total = self.elements();
2167        match self.history {
2168            None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
2169            Some(history) => {
2170                let ys = unary::<T, _>(total, history, self.numbers.row_major_iter(), fx, dfx_dx);
2171                for (element, result) in self.numbers.row_major_reference_mut_iter().zip(ys) {
2172                    *element = result;
2173                }
2174                self.history = Some(history);
2175            }
2176        }
2177    }
2178
2179    /**
2180     * Overwrites the left hand side of a RecordContainer with the result of applying
2181     * some binary function from `T` to `T` to every element pair in the containers. Both
2182     * containers must have the same shape.
2183     * To compute the new records, the binary function of some inputs x and y to some
2184     * output z is needed along with its derivative with respect to its first input x and
2185     * its derivative with respect to its second input y.
2186     *
2187     * # Panics
2188     *
2189     * - If both record containers have a WengertList that are different to each other
2190     * - If the record containers have different shapes
2191     */
2192    #[track_caller]
2193    pub fn binary_left_assign<S2>(
2194        &mut self,
2195        rhs: &RecordMatrix<'a, T, S2>,
2196        fxy: impl Fn(T, T) -> T,
2197        dfxy_dx: impl Fn(T, T) -> T,
2198        dfxy_dy: impl Fn(T, T) -> T,
2199    ) where
2200        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2201    {
2202        {
2203            let left_shape = self.numbers.size();
2204            let right_shape = rhs.numbers.size();
2205            if left_shape != right_shape {
2206                panic!(
2207                    "Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
2208                    left_shape, right_shape
2209                );
2210            }
2211        }
2212        let total = self.elements();
2213        match (self.history, rhs.history) {
2214            (None, None) => {
2215                for (x, y) in self
2216                    .numbers
2217                    .row_major_reference_mut_iter()
2218                    .zip(rhs.numbers.row_major_iter())
2219                {
2220                    let (left, _) = x;
2221                    let (right, _) = y;
2222                    *x = (fxy(left.clone(), right), 0);
2223                }
2224            }
2225            (Some(history), None) => {
2226                let zs = binary_x_history::<T, _, _>(
2227                    total,
2228                    history,
2229                    self.numbers.row_major_iter(),
2230                    rhs.numbers.row_major_iter(),
2231                    fxy,
2232                    dfxy_dx,
2233                );
2234                for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2235                    *element = result;
2236                }
2237                self.history = Some(history);
2238            }
2239            (None, Some(history)) => {
2240                let zs = binary_y_history::<T, _, _>(
2241                    total,
2242                    history,
2243                    self.numbers.row_major_iter(),
2244                    rhs.numbers.row_major_iter(),
2245                    fxy,
2246                    dfxy_dy,
2247                );
2248                for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2249                    *element = result;
2250                }
2251                self.history = Some(history);
2252            }
2253            (Some(history), Some(h)) => {
2254                assert!(
2255                    record_operations::same_lists(history, h),
2256                    "Record containers must be using the same WengertList"
2257                );
2258                let zs = binary_both_history::<T, _, _>(
2259                    total,
2260                    history,
2261                    self.numbers.row_major_iter(),
2262                    rhs.numbers.row_major_iter(),
2263                    fxy,
2264                    dfxy_dx,
2265                    dfxy_dy,
2266                );
2267                for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2268                    *element = result;
2269                }
2270                self.history = Some(history);
2271            }
2272        }
2273    }
2274
2275    /**
2276     * A convenience helper function which takes the RecordContainer value and
2277     * calls [unary_assign](RecordMatrix::unary_assign()) on it, returning
2278     * the record container which now contains the result of the operation.
2279     */
2280    #[track_caller]
2281    pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
2282        self.unary_assign(fx, dfx_dx);
2283        self
2284    }
2285
2286    /**
2287     * A convenience helper function which takes the left hand side by value and
2288     * calls [binary_left_assign](RecordMatrix::binary_left_assign()) on it, returning
2289     * the left hand side which now contains the result of the operation.
2290     */
2291    #[track_caller]
2292    pub fn do_binary_left_assign<S2>(
2293        mut self,
2294        rhs: &RecordMatrix<'a, T, S2>,
2295        fxy: impl Fn(T, T) -> T,
2296        dfxy_dx: impl Fn(T, T) -> T,
2297        dfxy_dy: impl Fn(T, T) -> T,
2298    ) -> Self
2299    where
2300        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2301    {
2302        self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
2303        self
2304    }
2305
2306    /**
2307     * Updates this RecordContainer in place by applying some unary function on `Record<T>` to
2308     * `Record<T>` to every element in the container. This will fail if the function would create
2309     * records with inconsistent histories.
2310     *
2311     * When used with pure functions that can't return different histories for different inputs
2312     * unwrapping with always succeed.
2313     *
2314     * Since this updates the container in place, if Err is returned then the data in this
2315     * RecordContainer is still available but it has been corrupted - at least one of the elements
2316     * should have a different history than what it will have because the mapping function created
2317     * inconsistent histories that couldn't be represented by the container as it only stores
2318     * one.
2319     *
2320     * This API can allow you to call a generic function that operates on
2321     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2322     * during the intermediate calculations for you, without having to resort to storing the
2323     * Record types.
2324     *
2325     * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2326     * after mapping doesn't have to be the same as before, only must be the same for every
2327     * mapped element.
2328     *
2329     * You might also use this function at the end of a training loop to update all the weights
2330     * to reduce their loss.
2331     *
2332     * ```
2333     * use easy_ml::numeric::Numeric;
2334     * use easy_ml::matrices::Matrix;
2335     * use easy_ml::differentiation::{Record, RecordMatrix, WengertList};
2336     *
2337     * let history = WengertList::new();
2338     * let mut weights = RecordMatrix::variables(
2339     *     &history,
2340     *     Matrix::from(vec![vec![ 0.3, 0.2, -1.2, -0.4 ]])
2341     * );
2342     * let error = {
2343     *     // this is over-simplified for brevity, obviously in a real scenario we wouldn't have a
2344     *     // function that calculates the error like this or we wouldn't be doing machine learning
2345     *     // to fit it in the first place
2346     *     let mut loss = Record::variable(0.0, &history);
2347     *     for r in weights.iter_row_major_as_records() {
2348     *         loss = loss + r;
2349     *     }
2350     *     loss
2351     * };
2352     * let derivatives = error.derivatives();
2353     * let learning_rate = 0.1;
2354     * // update the weights to contain less error than they did before
2355     * let result = weights.map_mut(|x| x - (derivatives[&x] * learning_rate));
2356     * assert!(result.is_ok()); // we know we didn't introduce an inconsistent history just updating the weights
2357     * ```
2358     */
2359    #[track_caller]
2360    pub fn map_mut(
2361        &mut self,
2362        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
2363    ) -> Result<(), InconsistentHistory<'a, T>> {
2364        let history = self.history;
2365        let new_history =
2366            map_mut_base::<'a, T, _, _>(RowMajorReferenceMutIterator::from(self), |x| {
2367                let record = Record::from_existing(x.clone(), history);
2368                let result = fx(record);
2369                *x = (result.number, result.index);
2370                result.history
2371            })?;
2372        self.history = new_history;
2373        Ok(())
2374    }
2375
2376    /**
2377     * Updates this RecordContainer in place by applying some unary function on `Record<T>` and
2378     * each index of that position in the Record to `Record<T>` to every element in the container.
2379     * This will fail if the function would create records with inconsistent histories.
2380     *
2381     * When used with pure functions that can't return different histories for different inputs
2382     * unwrapping with always succeed.
2383     *
2384     * Since this updates the container in place, if Err is returned then the data in this
2385     * RecordContainer is still available but it has been corrupted - at least one of the elements
2386     * should have a different history than what it will have because the mapping function created
2387     * inconsistent histories that couldn't be represented by the container as it only stores
2388     * one.
2389     *
2390     * This API can allow you to call a generic function that operates on
2391     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2392     * during the intermediate calculations for you, without having to resort to storing the
2393     * Record types.
2394     *
2395     * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2396     * after mapping doesn't have to be the same as before, only must be the same for every
2397     * mapped element.
2398     */
2399    #[track_caller]
2400    pub fn map_mut_with_index(
2401        &mut self,
2402        fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
2403    ) -> Result<(), InconsistentHistory<'a, T>> {
2404        let history = self.history;
2405        let new_history = map_mut_base::<'a, T, _, _>(
2406            RowMajorReferenceMutIterator::from(self).with_index(),
2407            |((r, c), x)| {
2408                let record = Record::from_existing(x.clone(), history);
2409                let result = fx(record, r, c);
2410                *x = (result.number, result.index);
2411                result.history
2412            },
2413        )?;
2414        self.history = new_history;
2415        Ok(())
2416    }
2417}
2418
2419impl<'a, T, S> RecordMatrix<'a, T, S>
2420where
2421    T: Numeric + Primitive,
2422    for<'t> &'t T: NumericRef<T>,
2423    S: MatrixRef<(T, Index)> + NoInteriorMutability,
2424{
2425    /**
2426     * Overwrites the right hand side of a RecordContainer with the result of applying
2427     * some binary function from `T` to `T` to every element pair in the containers. Both
2428     * containers must have the same shape.
2429     * To compute the new records, the binary function of some inputs x and y to some
2430     * output z is needed along with its derivative with respect to its first input x and
2431     * its derivative with respect to its second input y.
2432     *
2433     * # Panics
2434     *
2435     * - If both record containers have a WengertList that are different to each other
2436     * - If the record containers have different shapes
2437     */
2438    #[track_caller]
2439    pub fn binary_right_assign<S2>(
2440        &self,
2441        rhs: &mut RecordMatrix<'a, T, S2>,
2442        fxy: impl Fn(T, T) -> T,
2443        dfxy_dx: impl Fn(T, T) -> T,
2444        dfxy_dy: impl Fn(T, T) -> T,
2445    ) where
2446        S2: MatrixMut<(T, Index)> + NoInteriorMutability,
2447    {
2448        // x is lhs, y is rhs, so calling binary_left_assign on the rhs container
2449        // means we need to swap all the arguments
2450        rhs.binary_left_assign(
2451            self,
2452            |y, x| fxy(x, y),
2453            |y, x| dfxy_dy(x, y),
2454            |y, x| dfxy_dx(x, y),
2455        )
2456    }
2457
2458    /**
2459     * A convenience helper function which takes the right hand side by value and
2460     * calls [binary_right_assign](RecordMatrix::binary_right_assign()) on it, returning
2461     * the right hand side which now contains the result of the operation.
2462     */
2463    #[track_caller]
2464    pub fn do_binary_right_assign<S2>(
2465        &self,
2466        mut rhs: RecordMatrix<'a, T, S2>,
2467        fxy: impl Fn(T, T) -> T,
2468        dfxy_dx: impl Fn(T, T) -> T,
2469        dfxy_dy: impl Fn(T, T) -> T,
2470    ) -> RecordMatrix<'a, T, S2>
2471    where
2472        S2: MatrixMut<(T, Index)> + NoInteriorMutability,
2473    {
2474        self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
2475        rhs
2476    }
2477}
2478
// # Safety
//
// Our inner `numbers` tensor has to implement TensorRef correctly so by delegating to it
// without changing any indexes or introducing interior mutability, we implement TensorRef
// correctly as well.
/**
 * RecordTensor implements TensorRef when the source does, returning references to the tuples
 * of `T` and [`Index`](Index).
 *
 * Every method forwards its arguments unchanged to the source of the inner `numbers` field, so
 * the shape, data layout and element references are exactly those of that source.
 */
unsafe impl<'a, T, S, const D: usize> TensorRef<(T, Index), D> for RecordTensor<'a, T, S, D>
where
    T: Primitive,
    S: TensorRef<(T, Index), D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&(T, Index)> {
        self.numbers.source_ref().get_reference(indexes)
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        self.numbers.source_ref().view_shape()
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &(T, Index) {
        // SAFETY: our caller must uphold the TensorRef::get_reference_unchecked contract for the
        // shape reported by view_shape. That shape is the source's own shape and the indexes are
        // forwarded untouched, so the same contract holds for the source's implementation.
        unsafe { self.numbers.source_ref().get_reference_unchecked(indexes) }
    }

    fn data_layout(&self) -> DataLayout<D> {
        self.numbers.source_ref().data_layout()
    }
}
2509
// # Safety
//
// Our inner `numbers` tensor has to implement TensorMut correctly so by delegating to it
// without changing any indexes or introducing interior mutability, we implement TensorMut
// correctly as well.
/**
 * RecordTensor implements TensorMut when the source does, returning mutable references to the
 * tuples of `T` and [`Index`](Index).
 *
 * Both methods forward their indexes unchanged to the source of the inner `numbers` field.
 */
unsafe impl<'a, T, S, const D: usize> TensorMut<(T, Index), D> for RecordTensor<'a, T, S, D>
where
    T: Primitive,
    S: TensorMut<(T, Index), D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut (T, Index)> {
        self.numbers.source_ref_mut().get_reference_mut(indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut (T, Index) {
        // SAFETY: our caller must uphold the TensorMut::get_reference_unchecked_mut contract for
        // our view_shape, which is the source's shape; the indexes are forwarded untouched, so
        // the same contract holds for the source's implementation.
        unsafe {
            self.numbers
                .source_ref_mut()
                .get_reference_unchecked_mut(indexes)
        }
    }
}
2536
// # Safety
//
// Our inner `numbers` matrix has to implement MatrixRef correctly so by delegating to it
// without changing any indexes or introducing interior mutability, we implement MatrixRef
// correctly as well.
/**
 * RecordMatrix implements MatrixRef when the source does, returning references to the tuples
 * of `T` and [`Index`](Index).
 *
 * Every method forwards its arguments unchanged to the source of the inner `numbers` field, so
 * the row and column counts, data layout and element references are exactly those of that
 * source.
 */
unsafe impl<'a, T, S> MatrixRef<(T, Index)> for RecordMatrix<'a, T, S>
where
    T: Primitive,
    S: MatrixRef<(T, Index)>,
{
    fn try_get_reference(&self, row: Row, column: Column) -> Option<&(T, Index)> {
        self.numbers.source_ref().try_get_reference(row, column)
    }

    fn view_rows(&self) -> Row {
        self.numbers.source_ref().view_rows()
    }

    fn view_columns(&self) -> Column {
        self.numbers.source_ref().view_columns()
    }

    unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &(T, Index) {
        // SAFETY: our caller must uphold the MatrixRef::get_reference_unchecked contract for our
        // view_rows/view_columns, which are the source's own; row and column are forwarded
        // untouched, so the same contract holds for the source's implementation.
        unsafe {
            self.numbers
                .source_ref()
                .get_reference_unchecked(row, column)
        }
    }

    fn data_layout(&self) -> crate::matrices::views::DataLayout {
        self.numbers.source_ref().data_layout()
    }
}
2575
// # Safety
//
// Our inner `numbers` matrix has to implement NoInteriorMutability correctly so by delegating to
// it without introducing interior mutability, we implement NoInteriorMutability
// correctly as well.
//
// The `history` field is never consulted by the MatrixRef/MatrixMut impls for this type, which
// only hand out references obtained from the `numbers` source, so any interior mutability in the
// referenced WengertList cannot change the data those impls return.
/**
 * RecordMatrix implements NoInteriorMutability when the source does.
 */
unsafe impl<'a, T, S> NoInteriorMutability for RecordMatrix<'a, T, S>
where
    T: Primitive,
    S: NoInteriorMutability,
{
}
2590
// # Safety
//
// Our inner `numbers` matrix has to implement MatrixMut correctly so by delegating to it
// without changing any indexes or introducing interior mutability, we implement MatrixMut
// correctly as well.
/**
 * RecordMatrix implements MatrixMut when the source does, returning mutable references to the
 * tuples of `T` and [`Index`].
 *
 * Both methods forward their row and column unchanged to the source of the inner `numbers`
 * field.
 */
unsafe impl<'a, T, S> MatrixMut<(T, Index)> for RecordMatrix<'a, T, S>
where
    T: Primitive,
    S: MatrixMut<(T, Index)>,
{
    fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut (T, Index)> {
        self.numbers
            .source_ref_mut()
            .try_get_reference_mut(row, column)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut (T, Index) {
        // SAFETY: our caller must uphold the MatrixMut::get_reference_unchecked_mut contract for
        // our view_rows/view_columns, which are the source's own; row and column are forwarded
        // untouched, so the same contract holds for the source's implementation.
        unsafe {
            self.numbers
                .source_ref_mut()
                .get_reference_unchecked_mut(row, column)
        }
    }
}
2619
2620/**
2621 * A zero dimensional record tensor can be converted losslessly into a record.
2622 */
2623impl<'a, T, S> From<RecordTensor<'a, T, S, 0>> for Record<'a, T>
2624where
2625    T: Numeric + Primitive,
2626    S: TensorRef<(T, Index), 0>,
2627{
2628    /**
2629     * Converts the sole element in the zero dimensional record tensor into a record.
2630     */
2631    fn from(scalar: RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
2632        // Not a good way to make this zero copy and just move the data out of the scalar because
2633        // TensorRef API doesn't have by value moves of the data and using TensorOwnedIterator
2634        // requires T: Default and a dummy value (at which point a clone is probably cheaper or
2635        // basically the same?)
2636        Record::from(&scalar)
2637    }
2638}
2639
2640/**
2641 * A zero dimensional record tensor can be converted losslessly into a record.
2642 */
2643impl<'a, T, S> From<&RecordTensor<'a, T, S, 0>> for Record<'a, T>
2644where
2645    T: Numeric + Primitive,
2646    S: TensorRef<(T, Index), 0>,
2647{
2648    /**
2649     * Converts the sole element in the zero dimensional record tensor into a record.
2650     */
2651    fn from(scalar: &RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
2652        Record::from_existing(scalar.view().scalar(), scalar.history)
2653    }
2654}
2655
2656/**
2657 * A record can be converted losslessly into a zero dimensional record tensor.
2658 */
2659impl<'a, T> From<Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
2660where
2661    T: Numeric + Primitive,
2662{
2663    /**
2664     * Converts a record into a zero dimensional record tensor with the single element.
2665     */
2666    fn from(record: Record<'a, T>) -> RecordTensor<'a, T, Tensor<(T, Index), 0>, 0> {
2667        RecordTensor::from_existing(
2668            record.history,
2669            TensorView::from(Tensor::from([], vec![(record.number, record.index)])),
2670        )
2671    }
2672}
2673
2674/**
2675 * A record can be converted losslessly into a zero dimensional record tensor.
2676 */
2677impl<'a, T> From<&Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
2678where
2679    T: Numeric + Primitive,
2680{
2681    /**
2682     * Converts a record into a zero dimensional record tensor with the single element.
2683     */
2684    fn from(record: &Record<'a, T>) -> RecordTensor<'a, T, Tensor<(T, Index), 0>, 0> {
2685        RecordTensor::from_existing(
2686            record.history,
2687            TensorView::from(Tensor::from(
2688                [],
2689                vec![(record.number.clone(), record.index)],
2690            )),
2691        )
2692    }
2693}
2694
// Checks that multiplying two RecordTensors produces the same numbers, derivatives and number of
// recorded operations as multiplying two Tensors of individual Records.
//
// Looking up the tensor-of-records derivatives with `record_tensor_a`/`record_tensor_b` relies on
// both WengertLists having had their variables created in the same order, so that the Index of
// each variable lines up across the two lists.
#[test]
fn matrix_multiplication_derivatives_are_the_same() {
    #[rustfmt::skip]
    let a = Tensor::from(
        [("r", 4), ("c", 3)],
        vec![
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
            0.0, 5.0, 2.0
        ]
    );
    let b = a.transpose(["c", "r"]);
    let history = WengertList::new();
    let also_history: WengertList<f64> = WengertList::new();
    let tensor_of_records_a = a.map(|x| Record::variable(x, &history));
    let tensor_of_records_b = b.map(|x| Record::variable(x, &history));
    let tensor_of_records_c = &tensor_of_records_a * &tensor_of_records_b;
    let record_tensor_a = RecordTensor::variables(&also_history, a);
    let record_tensor_b = RecordTensor::variables(&also_history, b);
    let record_tensor_c = &record_tensor_a * &record_tensor_b;

    // C should be calculated the same in terms of the actual number
    assert_eq!(
        tensor_of_records_c.map(|r| r.number),
        TensorView::from(&record_tensor_c).map(|(n, _)| n)
    );

    let tensor_of_records_derivatives = tensor_of_records_c.map(|r| r.derivatives());
    let tensor_of_records_a_derivatives =
        tensor_of_records_derivatives.map(|d| d.at_tensor(&record_tensor_a));
    let tensor_of_records_b_derivatives =
        tensor_of_records_derivatives.map(|d| d.at_tensor(&record_tensor_b));

    let record_tensor_derivatives = record_tensor_c.derivatives().unwrap();
    let record_tensor_a_derivatives =
        record_tensor_derivatives.map(|d| d.at_tensor(&record_tensor_a));
    let record_tensor_b_derivatives =
        record_tensor_derivatives.map(|d| d.at_tensor(&record_tensor_b));

    // Every calculated derivative should match exactly
    assert_eq!(tensor_of_records_a_derivatives, record_tensor_a_derivatives);
    assert_eq!(tensor_of_records_b_derivatives, record_tensor_b_derivatives);

    // Verify C is actually calculated correctly
    #[rustfmt::skip]
    assert_eq!(
        tensor_of_records_c.map(|r| r.number),
        Tensor::from(
            [("r", 4), ("c", 4)],
            vec![
                14.0, 32.0, 50.0, 16.0,
                32.0, 77.0, 122.0, 37.0,
                50.0, 122.0, 194.0, 58.0,
                16.0, 37.0, 58.0, 29.0
            ]
        )
    );
    #[rustfmt::skip]
    assert_eq!(
        tensor_of_records_c.map(|r| r.number),
        Tensor::from(
            [("r", 4), ("c", 4)],
            vec![
                (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
                (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
                (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
                (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),

                (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
                (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
                (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
                (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),

                (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
                (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
                (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
                (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),

                (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
                (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
                (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
                (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
            ]
        )
    );

    // Both approaches should have recorded the same number of operations on their WengertLists
    let tensor_of_records_operations = history.operations.borrow().clone();
    let record_tensor_operations = also_history.operations.borrow().clone();
    assert_eq!(
        tensor_of_records_operations.len(),
        record_tensor_operations.len()
    );
}
2789
// Checks that multiplying two RecordMatrices produces the same numbers, derivatives and number of
// recorded operations as multiplying two Matrices of individual Records.
//
// Looking up the matrix-of-records derivatives with `record_matrix_a`/`record_matrix_b` relies on
// both WengertLists having had their variables created in the same order, so that the Index of
// each variable lines up across the two lists.
#[test]
fn matrix_view_matrix_multiplication_derivatives_are_the_same() {
    #[rustfmt::skip]
    let a = Matrix::from(vec![
        vec![ 1.0, 2.0, 3.0 ],
        vec![ 4.0, 5.0, 6.0 ],
        vec![ 7.0, 8.0, 9.0 ],
        vec![ 0.0, 5.0, 2.0 ]
    ]);
    let b = a.transpose();
    let history = WengertList::new();
    let also_history: WengertList<f64> = WengertList::new();
    let matrix_of_records_a = a.map(|x| Record::variable(x, &history));
    let matrix_of_records_b = b.map(|x| Record::variable(x, &history));
    let matrix_of_records_c = &matrix_of_records_a * &matrix_of_records_b;
    let record_matrix_a = RecordMatrix::variables(&also_history, a);
    let record_matrix_b = RecordMatrix::variables(&also_history, b);
    let record_matrix_c = &record_matrix_a * &record_matrix_b;

    // C should be calculated the same in terms of the actual number
    assert_eq!(
        matrix_of_records_c.map(|r| r.number),
        MatrixView::from(&record_matrix_c).map(|(n, _)| n)
    );

    let matrix_of_records_derivatives = matrix_of_records_c.map(|r| r.derivatives());
    let matrix_of_records_a_derivatives =
        matrix_of_records_derivatives.map(|d| d.at_matrix(&record_matrix_a));
    let matrix_of_records_b_derivatives =
        matrix_of_records_derivatives.map(|d| d.at_matrix(&record_matrix_b));

    let record_matrix_derivatives = record_matrix_c.derivatives().unwrap();
    let record_matrix_a_derivatives =
        record_matrix_derivatives.map(|d| d.at_matrix(&record_matrix_a));
    let record_matrix_b_derivatives =
        record_matrix_derivatives.map(|d| d.at_matrix(&record_matrix_b));

    // Every calculated derivative should match exactly
    assert_eq!(matrix_of_records_a_derivatives, record_matrix_a_derivatives);
    assert_eq!(matrix_of_records_b_derivatives, record_matrix_b_derivatives);

    // Verify C is actually calculated correctly
    #[rustfmt::skip]
    assert_eq!(
        matrix_of_records_c.map(|r| r.number),
        Matrix::from(vec![
                vec![ 14.0, 32.0, 50.0, 16.0 ],
                vec![ 32.0, 77.0, 122.0, 37.0 ],
                vec![ 50.0, 122.0, 194.0, 58.0 ],
                vec![ 16.0, 37.0, 58.0, 29.0 ]
            ]
        )
    );
    #[rustfmt::skip]
    assert_eq!(
        matrix_of_records_c.map(|r| r.number),
        Matrix::from(vec![
                vec![
                    (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
                    (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
                    (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
                    (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),
                ],
                vec![
                    (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
                    (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
                    (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
                    (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),
                ],
                vec![
                    (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
                    (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
                    (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
                    (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),
                ],
                vec![
                    (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
                    (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
                    (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
                    (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
                ]
        ])
    );

    // Both approaches should have recorded the same number of operations on their WengertLists
    let matrix_of_records_operations = history.operations.borrow().clone();
    let record_matrix_operations = also_history.operations.borrow().clone();
    assert_eq!(
        matrix_of_records_operations.len(),
        record_matrix_operations.len()
    );
}
2880}