Skip to main content

easy_ml/differentiation/container_record/
mod.rs

1use crate::differentiation::functions::{Division, FunctionDerivative, Multiplication};
2use crate::differentiation::iterators::{
3    AsRecords, InconsistentHistory, InvalidRecordIteratorError,
4};
5use crate::differentiation::record_operations;
6use crate::differentiation::{Derivatives, Index, Primitive, Record, WengertList};
7use crate::matrices::iterators::{
8    ColumnMajorIterator, RowMajorIterator, RowMajorOwnedIterator, RowMajorReferenceMutIterator,
9};
10use crate::matrices::views::{MatrixMut, MatrixRef, MatrixView, NoInteriorMutability};
11use crate::matrices::{Column, Matrix, Row};
12use crate::numeric::{Numeric, NumericRef};
13use crate::tensors::indexing::{
14    TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceMutIterator,
15};
16use crate::tensors::views::{DataLayout, TensorMut, TensorRef, TensorRename, TensorView};
17use crate::tensors::{Dimension, Tensor};
18
19mod container_operations;
20pub mod iterators;
21
22/**
23 * A pluralisation of [Record](crate::differentiation::Record) that groups together a
24 * **s**ource of numbers of type T and stores the WengertList only once.
25 *
26 * Typically you would refer to one of the type aliases to disambiguate the type of `S` and
27 * use more succinct generics: [RecordMatrix](RecordMatrix), [RecordTensor](RecordTensor).
28 *
29 * For both Matrix and Tensor source types, the containers implement [`+`](std::ops::Add) and
30 * [`-`](std::ops::Sub) and have the methods `elementwise_multiply` and `elementwise_divide`.
31 * In all cases the containers must have the same size for the operation and will panic if
32 * mismatched.
33 * [`*`](std::ops::Mul) is also implemented for 2 dimensional tensors and matrices as matrix
34 * multiplication.
35 *
36 * For convenience, as with Trace and Record, many unary operations including
37 * [Cos](crate::numeric::extra::Cos), [Exp](crate::numeric::extra::Exp),
38 * [Ln](crate::numeric::extra::Ln), [Neg](std::ops::Neg), [Pow](crate::numeric::extra::Pow),
39 * [Sin](crate::numeric::extra::Sin), and [Sqrt](crate::numeric::extra::Sqrt) are implemented as
40 * well, applying the unary function to each element in the tensor.
41 *
42 * `+`, `-`, `*` and `/` operations with a RecordContainer and a scalar are also implemented,
43 * treating the right hand side scalar as a constant. These are also unary functions in terms of
44 * the derivative tracking, for example `X + 5` applies the function `+5` to each element in
45 * `X`. Due to the orphan rule, the standard library scalars cannot be implemented for a left hand
46 * side scalar, see [SwappedOperations](crate::differentiation::record_operations::SwappedOperations).
47 *
48 * Both [RecordMatrix](RecordMatrix) and [RecordTensor](RecordTensor) implement
49 * [MatrixRef](MatrixRef) and [TensorRef](TensorRef) respectively, which provide read and write
50 * access to the underlying numbers and indexes into the WengertList. These APIs along with the
51 * `from_existing` constructors for RecordMatrix, RecordTensor, and Record allow arbitary
52 * manipulation of specific elements in a record container if needed. However, any arithmetic
53 * performed on the raw data won't be tracked on the WengertList and overwriting data within a
54 * record container already tracked on the WengertList could result in nonsense. These APIs exist
55 * for easy read access to check the data in a record container and for read/write access when
56 * manipulating the shape of a record container, and are designed to be used only for moving data
57 * around - you should put it back unchanged in a RecordContainer or Record before doing further
58 * arithmetic that needs to be tracked on the WengertList.
59 *
60 * If you just want to manipulate the data in record containers as if they were Records you can
61 * use the iterator APIs of [AsRecords](AsRecords) instead and collect them back into containers
62 * when you're done, or to manipulate a single container, the [map](RecordContainer::map) and
63 * [map_mut](RecordContainer::map_mut) methods.
64 */
65#[derive(Debug)]
66pub struct RecordContainer<'a, T: Primitive, S, const D: usize> {
67    // Opted to store the indexes alongside each number (T, Index) for a number of reasons, the
68    // main factor being it makes implementing TensorRef feasible so can utilise the range of
69    // existing APIs for Tensor manipulation. It's theoretically possible to only store the first
70    // index and calculate the rest, since most of the time all indexes are ascending entries in
71    // the WengertList but this would also massively complicate the implementation, especially for
72    // handling non row-major operations such as matrix multiplication. It's also not super clear
73    // that this would be more efficient because it turns reads into more arithmetic rather
74    // than avoiding any work. Just lifting the WengertList out of the tensor should have
75    // meaningful improvements to cache line efficiency, and failing that still disallows
76    // very questionable states from existing.
77    numbers: S,
78    history: Option<&'a WengertList<T>>,
79}
80
81/**
82 * Alias for succinctly referring to RecordContainers backed by a matrix.
83 */
84pub type RecordMatrix<'a, T, S> = RecordContainer<'a, T, MatrixView<(T, Index), S>, 2>;
85
86/**
87 * Alias for succinctly referring to RecordContainers backed by a tensor.
88 */
89pub type RecordTensor<'a, T, S, const D: usize> =
90    RecordContainer<'a, T, TensorView<(T, Index), S, D>, D>;
91
92fn calculate_incrementing_indexes(starting_index: usize, total: usize) -> Vec<Index> {
93    let mut indexes = vec![0; total];
94    for (i, x) in indexes.iter_mut().enumerate() {
95        *x = starting_index + i;
96    }
97    indexes
98}
99
100impl<'a, T, const D: usize> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
101where
102    T: Numeric + Primitive,
103{
104    /**
105     * Creates multiple untracked Records which have no backing WengertList.
106     *
107     * This is provided for using constants along with Records in operations.
108     *
109     * For example with `Y = X + 4` the computation graph could be conceived as many
110     * `Y[i,j]` nodes with parent nodes of `X[i,j]` and 4 combined with the operation `+`.
111     * However there is no need to record the derivatives of a constant, so
112     * instead the computation graph can be conceived as `Y[i,j]` nodes each with a single
113     * parent node of `X[i,j]` and the unary operation of `+4`.
114     */
115    pub fn constants<S>(c: S) -> Self
116    where
117        S: TensorMut<T, D>,
118    {
119        RecordContainer {
120            numbers: TensorView::from(Tensor::from(
121                c.view_shape(),
122                TensorOwnedIterator::from_numeric(c)
123                    .map(|x| (x, 0))
124                    .collect(),
125            )),
126            history: None,
127        }
128    }
129
130    /**
131     * Creates multiple records backed by the provided WengertList.
132     *
133     * The records cannot live longer than the WengertList, hence
134     * the following example does not compile
135     *
136     * ```compile_fail
137     * use easy_ml::differentiation::RecordTensor;
138     * use easy_ml::differentiation::WengertList;
139     * use easy_ml::tensors::Tensor;
140     * let record = {
141     *     let list = WengertList::new();
142     *     RecordTensor::variables(
143     *         &list,
144     *         Tensor::from([("r", 2), ("c", 2)], vec![ 1.0, 2.0, 3.0, 4.0 ])
145     *     )
146     * }; // list no longer in scope
147     * ```
148     */
149    pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
150    where
151        S: TensorMut<T, D>,
152    {
153        let total = crate::tensors::dimensions::elements(&x.view_shape());
154        let starting_index = history.append_nullary_repeating(total);
155        RecordContainer {
156            numbers: TensorView::from(Tensor::from(
157                x.view_shape(),
158                TensorOwnedIterator::from_numeric(x)
159                    .zip(calculate_incrementing_indexes(starting_index, total))
160                    .collect(),
161            )),
162            history: Some(history),
163        }
164    }
165}
166
167impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
168where
169    T: Numeric + Primitive,
170    S: TensorRef<(T, Index), D>,
171{
172    /**
173     * Returns the number of elements stored by this container's source.
174     *
175     * For a 2 x 3 Tensor, this would return 6, and for a 2 x 3 x 4 Tensor this would return 24
176     * and so on.
177     *
178     * see also [dimensions::elements](crate::tensors::dimensions::elements)
179     */
180    pub fn elements(&self) -> usize {
181        crate::tensors::dimensions::elements(&self.numbers.shape())
182    }
183
184    /**
185     * The shape of this container's source.
186     */
187    pub fn shape(&self) -> [(Dimension, usize); D] {
188        self.numbers.shape()
189    }
190
191    /**
192     * Creates a container from constants/variables directly, most likely obtained by getting a
193     * tensor view of an existing container. **The inputs are not checked for validity**. It is
194     * possible to pass in the wrong Wengert list here or even numbers with indexes that aren't
195     * tracked on the WengertList.
196     *
197     * It is recommended to use this constructor only in conjunction with
198     * resizing or masking an existing container and not for creating new variables. Any variables
199     * created outside of `RecordContainer::variables` would have to be manually added to the
200     * correct Wengert list, and any arithmetic operations would also need tracking correctly.
201     *
202     * ```
203     * use easy_ml::differentiation::RecordTensor;
204     * use easy_ml::differentiation::WengertList;
205     * use easy_ml::tensors::Tensor;
206     * use easy_ml::tensors::views::{TensorView, TensorRange};
207     *
208     * let list = WengertList::new();
209     * let x = RecordTensor::variables(
210     *     &list,
211     *     Tensor::from_fn([("x", 2), ("y", 2)], |[r, c]| ((r + 3) * (c + 2)) as f64)
212     * );
213     * // oh no wrong shape!
214     * let fixed = TensorView::from(TensorRange::from(x, [("y", 0..1)]).unwrap()); // we can unwrap here because we know the range is valid
215     * let x = RecordTensor::from_existing(Some(&list), fixed);
216     * assert_eq!([("x", 2), ("y", 1)], x.shape());
217     * ```
218     */
219    pub fn from_existing(
220        history: Option<&'a WengertList<T>>,
221        numbers: TensorView<(T, Index), S, D>,
222    ) -> Self {
223        RecordContainer { numbers, history }
224    }
225
226    /**
227     * Returns a record tensor with the dimension names of the shape renamed to the provided
228     * dimensions. The data of this container and the dimension lengths and order remain unchanged.
229     *
230     * This is a shorthand for constructing the RecordTensor via manipulating this TensorView. See
231     * [`RecordTensor::from_existing`](RecordTensor::from_existing).
232     *
233     * # Panics
234     *
235     * If a dimension name is not unique
236     *
237     * ```
238     * use easy_ml::differentiation::RecordTensor;
239     * use easy_ml::differentiation::WengertList;
240     * use easy_ml::tensors::Tensor;
241     * use easy_ml::tensors::views::{TensorView, TensorRename};
242     *
243     * let list = WengertList::new();
244     * let x = RecordTensor::variables(
245     *     &list,
246     *     Tensor::from_fn([("x", 2), ("y", 2)], |[r, c]| ((r + 3) * (c + 2)) as f64)
247     * );
248     * // oh no wrong dimension names!
249     * let x = x.rename_view(["a", "b"]);
250     * assert_eq!([("a", 2), ("b", 2)], x.shape());
251     * ```
252     */
253    #[track_caller]
254    pub fn rename_view(
255        self,
256        dimensions: [Dimension; D],
257    ) -> RecordTensor<'a, T, TensorRename<(T, Index), S, D>, D> {
258        RecordTensor::from_existing(
259            self.history,
260            TensorView::from(TensorRename::from(self.numbers.source(), dimensions)),
261        )
262    }
263
264    /**
265     * Returns a TensorView of this record container, both the `T` for each record element and
266     * also the index for that record's entry in the WengertList. These can be parsed back into
267     * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
268     * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
269     * numerical operations on the data.
270     */
271    pub fn view(&self) -> TensorView<(T, Index), &RecordTensor<'a, T, S, D>, D> {
272        TensorView::from(self)
273    }
274
275    /**
276     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
277     * to read values from this record container, both the `T` for each record element and
278     * also the index for that record's entry in the WengertList. These can be parsed back into
279     * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
280     * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
281     * numerical operations on the data.
282     *
283     * # Panics
284     *
285     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
286     */
287    #[track_caller]
288    pub fn index_by(
289        &self,
290        dimensions: [Dimension; D],
291    ) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
292        TensorAccess::from(self, dimensions)
293    }
294
295    /**
296     * Creates a TensorAccess which will index into the dimensions this record was created with
297     * in the same order as they were provided, both the `T` for each record element and
298     * also the index for that record's entry in the WengertList. These can be parsed back into
299     * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
300     * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
301     * numerical operations on the data.
302     *
303     * See [TensorAccess::from_source_order], [get_as_record](TensorAccess::get_as_record).
304     *
305     * ```
306     * use easy_ml::differentiation::RecordTensor;
307     * use easy_ml::differentiation::WengertList;
308     * use easy_ml::tensors::Tensor;
309     *
310     * let list = WengertList::new();
311     * let X = RecordTensor::variables(
312     *     &list,
313     *     Tensor::from([("a", 3)], vec![ 3.0, 4.0, 5.0 ])
314     * );
315     * let x = X.index().get_as_record([0]);
316     * assert_eq!(x.number, 3.0);
317     * ```
318     */
319    pub fn index(&self) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
320        TensorAccess::from_source_order(self)
321    }
322
323    /**
324     * Returns an iterator over this record tensor as [Record]s instead of the raw `(T, Index)`
325     * data. After manipulating the iterator it can be collected back into a RecordTensor with
326     * [RecordTensor::from_iter](RecordTensor::from_iter).
327     *
328     * This is a shorthand for `AsRecords::from(tensor.history(), TensorIterator::from(&tensor))`
329     */
330    #[allow(clippy::type_complexity)]
331    pub fn iter_as_records<'b>(
332        &'b self,
333    ) -> AsRecords<'a, TensorIterator<'b, (T, Index), RecordTensor<'a, T, S, D>, D>, T> {
334        AsRecords::from_tensor(self)
335    }
336}
337
338impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
339where
340    T: Numeric + Primitive,
341    S: TensorMut<(T, Index), D>,
342{
343    /**
344     * Resets all of the records to place them back on the WengertList, for use
345     * in performing another derivation after clearing the WengertList.
346     *
347     * This is also a preferred shorthand for `map_mut(Record::do_reset)` that can't fail
348     */
349    pub fn reset(&mut self) {
350        match self.history {
351            None => (), // noop
352            Some(history) => {
353                let total = self.elements();
354                let starting_index = history.append_nullary_repeating(total);
355                for (x, i) in self
356                    .numbers
357                    .iter_reference_mut()
358                    .zip(calculate_incrementing_indexes(starting_index, total))
359                {
360                    let (_, old_index) = x;
361                    *old_index = i;
362                }
363            }
364        };
365    }
366
367    /**
368     * A convenience helper function which takes a RecordContainer by value and
369     * calls [reset](RecordTensor::reset()) on it.
370     */
371    pub fn do_reset(mut x: Self) -> Self {
372        x.reset();
373        x
374    }
375}
376
377impl<'a, T> RecordMatrix<'a, T, Matrix<(T, Index)>>
378where
379    T: Numeric + Primitive,
380{
381    /**
382     * Creates multiple untracked Records which have no backing WengertList.
383     *
384     * This is provided for using constants along with Records in operations.
385     *
386     * For example with `Y = X + 4` the computation graph could be conceived as many
387     * `Y[i,j]` nodes with parent nodes of `X[i,j]` and 4 combined with the operation `+`.
388     * However there is no need to record the derivatives of a constant, so
389     * instead the computation graph can be conceived as `Y[i,j]` nodes each with a single
390     * parent node of `X[i,j]` and the unary operation of `+4`.
391     */
392    pub fn constants<S>(c: S) -> Self
393    where
394        S: MatrixMut<T> + NoInteriorMutability,
395    {
396        RecordContainer {
397            numbers: MatrixView::from(Matrix::from_flat_row_major(
398                (c.view_rows(), c.view_columns()),
399                RowMajorOwnedIterator::from_numeric(c)
400                    .map(|x| (x, 0))
401                    .collect(),
402            )),
403            history: None,
404        }
405    }
406
407    /**
408     * Creates multiple records backed by the provided WengertList.
409     *
410     * The records cannot live longer than the WengertList, hence
411     * the following example does not compile
412     *
413     * ```compile_fail
414     * use easy_ml::differentiation::RecordMatrix;
415     * use easy_ml::differentiation::WengertList;
416     * use easy_ml::matrices::Matrix;
417     * let record = {
418     *     let list = WengertList::new();
419     *     RecordMatrix::variables(
420     *         &list,
421     *         Matrix::from(vec![vec![1.0, 2.0], vec![3.0, 4.0]])
422     *     )
423     * }; // list no longer in scope
424     * ```
425     */
426    pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
427    where
428        S: MatrixMut<T> + NoInteriorMutability,
429    {
430        let total = x.view_rows() * x.view_columns();
431        let starting_index = history.append_nullary_repeating(total);
432        RecordContainer {
433            numbers: MatrixView::from(Matrix::from_flat_row_major(
434                (x.view_rows(), x.view_columns()),
435                RowMajorOwnedIterator::from_numeric(x)
436                    .zip(calculate_incrementing_indexes(starting_index, total))
437                    .collect(),
438            )),
439            history: Some(history),
440        }
441    }
442}
443
444impl<'a, T, S> RecordMatrix<'a, T, S>
445where
446    T: Numeric + Primitive,
447    S: MatrixRef<(T, Index)> + NoInteriorMutability,
448{
449    /**
450     * Returns the number of elements stored by this container's source.
451     *
452     * For a 2 x 3 Matrix, this would return 6, and for a 3 x 4 Matrix this would return 12
453     * and so on.
454     */
455    pub fn elements(&self) -> usize {
456        self.numbers.rows() * self.numbers.columns()
457    }
458
459    /**
460     * Returns the dimensionality of this matrix container in Row, Column format
461     */
462    pub fn size(&self) -> (Row, Column) {
463        self.numbers.size()
464    }
465
466    /**
467     * Gets the number of rows visible to this matrix container.
468     */
469    pub fn rows(&self) -> Row {
470        self.numbers.rows()
471    }
472
473    /**
474     * Gets the number of columns visible to this matrix container.
475     */
476    pub fn columns(&self) -> Column {
477        self.numbers.columns()
478    }
479
480    /**
481     * Creates a container from constants/variables directly, most likely obtained by getting a
482     * matrix view of an existing container. **The inputs are not checked for validity**. It is
483     * possible to pass in the wrong Wengert list here or even numbers with indexes that aren't
484     * tracked on the WengertList.
485     *
486     * It is recommended to use this constructor only in conjunction with
487     * resizing or masking an existing container and not for creating new variables. Any variables
488     * created outside of `RecordContainer::variables` would have to be manually added to the
489     * correct Wengert list, and any arithmetic operations would also need tracking correctly.
490     *
491     * ```
492     * use easy_ml::differentiation::RecordMatrix;
493     * use easy_ml::differentiation::WengertList;
494     * use easy_ml::matrices::Matrix;
495     * use easy_ml::matrices::views::{MatrixView, MatrixRange};
496     *
497     * let list = WengertList::new();
498     * let x = RecordMatrix::variables(
499     *     &list,
500     *     Matrix::from_fn((2, 2), |(r, c)| ((r + 3) * (c + 2)) as f64)
501     * );
502     * // oh no wrong shape!
503     * let fixed = MatrixView::from(MatrixRange::from(x, 0..2, 0..1));
504     * let x = RecordMatrix::from_existing(Some(&list), fixed);
505     * assert_eq!((2, 1), x.size());
506     * ```
507     */
508    pub fn from_existing(
509        history: Option<&'a WengertList<T>>,
510        numbers: MatrixView<(T, Index), S>,
511    ) -> Self {
512        RecordContainer { numbers, history }
513    }
514
515    /**
516     * Returns a MatrixView of this record container, both the `T` for each record element and
517     * also the index for that record's entry in the WengertList. These can be parsed back into
518     * a RecordMatrix with [`from_existing`](RecordMatrix::from_existing) or individually into
519     * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
520     * numerical operations on the data.
521     */
522    pub fn view(&self) -> MatrixView<(T, Index), &RecordMatrix<'a, T, S>> {
523        MatrixView::from(self)
524    }
525
526    /**
527     * Returns an iterator over this record matrix as [Record]s instead of the raw `(T, Index)`
528     * data. After manipulating the iterator it can be collected back into a RecordMatrix with
529     * [RecordMatrix::from_iter](RecordMatrix::from_iter).
530     *
531     * This is a shorthand for `AsRecords::from(matrix.history(), RowMajorIterator::from(&matrix))`
532     */
533    #[allow(clippy::type_complexity)]
534    pub fn iter_row_major_as_records<'b>(
535        &'b self,
536    ) -> AsRecords<'a, RowMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
537        AsRecords::from_matrix_row_major(self)
538    }
539
540    /**
541     * Returns an iterator over this record matrix as [Record]s instead of the raw `(T, Index)`
542     * data. After manipulating the iterator it can be collected back into a RecordMatrix with
543     * [RecordMatrix::from_iter](RecordMatrix::from_iter).
544     *
545     * This is a shorthand for
546     * `AsRecords::from(matrix.history(), ColumnMajorIterator::from(&matrix))`
547     */
548    #[allow(clippy::type_complexity)]
549    pub fn iter_column_major_as_records<'b>(
550        &'b self,
551    ) -> AsRecords<'a, ColumnMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
552        AsRecords::from_matrix_column_major(self)
553    }
554
555    /**
556     * Returns a copy of the data at the index as a Record. If you need to access all the data
557     * as records instead of just a specific index you should probably use one of the iterator
558     * APIs instead.
559     *
560     * See also: [iter_row_major_as_records](RecordMatrix::iter_row_major_as_records),
561     * [iter_column_major_as_records](RecordMatrix::iter_column_major_as_records)
562     *
563     * # Panics
564     *
565     * If the index is out of range.
566     *
567     * For a non panicking API see [try_get_as_record](RecordMatrix::try_get_as_record)
568     */
569    #[track_caller]
570    pub fn get_as_record(&self, row: Row, column: Column) -> Record<'a, T> {
571        Record::from_existing(self.numbers.get(row, column), self.history)
572    }
573
574    /**
575     * Returns a copy of the data at the index as a Record, or None if the index is
576     * out of range. If you need to access all the data as records instead of just a specific
577     * index you should probably use one of the iterator APIs instead.
578     *
579     * See also: [iter_row_major_as_records](RecordMatrix::iter_row_major_as_records),
580     * [iter_column_major_as_records](RecordMatrix::iter_column_major_as_records)
581     */
582    pub fn try_get_as_record(&self, row: Row, column: Column) -> Option<Record<'a, T>> {
583        self.numbers
584            .try_get_reference(row, column)
585            .map(|r| Record::from_existing(r.clone(), self.history))
586    }
587}
588
589impl<'a, T, S> RecordMatrix<'a, T, S>
590where
591    T: Numeric + Primitive,
592    S: MatrixMut<(T, Index)> + NoInteriorMutability,
593{
594    /**
595     * Resets all of the records to place them back on the WengertList, for use
596     * in performing another derivation after clearing the WengertList.
597     *
598     * This is also a preferred shorthand for `map_mut(Record::do_reset)` that can't fail
599     */
600    pub fn reset(&mut self) {
601        match self.history {
602            None => (), // noop
603            Some(history) => {
604                let total = self.elements();
605                let starting_index = history.append_nullary_repeating(total);
606                for (x, i) in self
607                    .numbers
608                    .row_major_reference_mut_iter()
609                    .zip(calculate_incrementing_indexes(starting_index, total))
610                {
611                    let (_, old_index) = x;
612                    *old_index = i;
613                }
614            }
615        };
616    }
617
618    /**
619     * A convenience helper function which takes a RecordContainer by value and
620     * calls [reset](RecordMatrix::reset()) on it.
621     */
622    pub fn do_reset(mut x: Self) -> Self {
623        x.reset();
624        x
625    }
626}
627
628impl<'a, T, S, const D: usize> RecordContainer<'a, T, S, D>
629where
630    T: Primitive,
631{
632    /**
633     * Gets the WengertList these records are backed by if variables, and [None](None) if constants.
634     */
635    pub fn history(&self) -> Option<&'a WengertList<T>> {
636        self.history
637    }
638}
639
640/// Returns the vec of indexes and vec of ys for Y = unary(X), not checking but assuming that the
641/// length of the iterator matches the total.
642fn unary<'a, T, I>(
643    total: usize,
644    history: &WengertList<T>,
645    records: I,
646    fx: impl Fn(T) -> T,
647    dfx_dx: impl Fn(T) -> T,
648) -> Vec<(T, usize)>
649where
650    I: Iterator<Item = (T, Index)>,
651    T: Numeric + Primitive,
652    for<'t> &'t T: NumericRef<T>,
653{
654    let mut ys = vec![(T::zero(), 0); total];
655    history.borrow(|history| {
656        // shadow the name so we can't accidentally try to use history while holding
657        // the borrow
658        // use enumerate not with_index because we need the 1D index for indexing
659        // indexes
660        for (i, (x, parent)) in records.enumerate() {
661            let y = fx(x.clone());
662            let derivative = dfx_dx(x);
663            let new_index = history.append_unary(parent, derivative);
664            ys[i] = (y, new_index)
665        }
666    }); // drop borrow on history
667    ys
668}
669
670/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), not checking but assuming that
671/// the length of the iterators match the total. Also assumes both inputs have the same shape
672fn binary_both_history<'a, T, I1, I2>(
673    total: usize,
674    history: &WengertList<T>,
675    x_records: I1,
676    y_records: I2,
677    fxy: impl Fn(T, T) -> T,
678    dfxy_dx: impl Fn(T, T) -> T,
679    dfxy_dy: impl Fn(T, T) -> T,
680) -> Vec<(T, usize)>
681where
682    I1: Iterator<Item = (T, Index)>,
683    I2: Iterator<Item = (T, Index)>,
684    T: Numeric + Primitive,
685    for<'t> &'t T: NumericRef<T>,
686{
687    let mut zs = vec![(T::zero(), 0); total];
688    history.borrow(|history| {
689        // shadow the name so we can't accidentally try to use history while holding
690        // the borrow
691        // use enumerate not with_index because we need the 1D index for indexing
692        // indexes
693        for (i, ((x, parent1), (y, parent2))) in (x_records.zip(y_records)).enumerate() {
694            let z = fxy(x.clone(), y.clone());
695            let derivative1 = dfxy_dx(x.clone(), y.clone());
696            let derivative2 = dfxy_dy(x, y);
697            let new_index = history.append_binary(parent1, derivative1, parent2, derivative2);
698            zs[i] = (z, new_index);
699        }
700    }); // drop borrow on history
701    zs
702}
703
704/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), as with binary_both_history,
705/// but only tracking the derivatives for X, not Y.
706fn binary_x_history<'a, T, I1, I2>(
707    total: usize,
708    history: &WengertList<T>,
709    x_records: I1,
710    y_records: I2,
711    fxy: impl Fn(T, T) -> T,
712    dfxy_dx: impl Fn(T, T) -> T,
713) -> Vec<(T, usize)>
714where
715    I1: Iterator<Item = (T, Index)>,
716    I2: Iterator<Item = (T, Index)>,
717    T: Numeric + Primitive,
718    for<'t> &'t T: NumericRef<T>,
719{
720    let mut zs = vec![(T::zero(), 0); total];
721    history.borrow(|history| {
722        // shadow the name so we can't accidentally try to use history while holding
723        // the borrow
724        // use enumerate not with_index because we need the 1D index for indexing
725        // indexes
726        for (i, ((x, parent1), (y, _))) in (x_records.zip(y_records)).enumerate() {
727            let z = fxy(x.clone(), y.clone());
728            // if rhs didn't have a history, don't track that derivative
729            let derivative1 = dfxy_dx(x, y);
730            let new_index = history.append_unary(parent1, derivative1);
731            zs[i] = (z, new_index);
732        }
733    }); // drop borrow on history
734    zs
735}
736
737/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), as with binary_both_history,
738/// but only tracking the derivatives for Y, not X.
739fn binary_y_history<'a, T, I1, I2>(
740    total: usize,
741    history: &WengertList<T>,
742    x_records: I1,
743    y_records: I2,
744    fxy: impl Fn(T, T) -> T,
745    dfxy_dy: impl Fn(T, T) -> T,
746) -> Vec<(T, usize)>
747where
748    I1: Iterator<Item = (T, Index)>,
749    I2: Iterator<Item = (T, Index)>,
750    T: Numeric + Primitive,
751    for<'t> &'t T: NumericRef<T>,
752{
753    let mut zs = vec![(T::zero(), 0); total];
754    history.borrow(|history| {
755        // shadow the name so we can't accidentally try to use history while holding
756        // the borrow
757        // use enumerate not with_index because we need the 1D index for indexing
758        // indexes
759        for (i, ((x, _), (y, parent2))) in (x_records.zip(y_records)).enumerate() {
760            let z = fxy(x.clone(), y.clone());
761            // if self didn't have a history, don't track that derivative
762            let derivative2 = dfxy_dy(x, y);
763            let new_index = history.append_unary(parent2, derivative2);
764            zs[i] = (z, new_index);
765        }
766    }); // drop borrow on history
767    zs
768}
769
770impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
771where
772    T: Numeric + Primitive,
773    for<'t> &'t T: NumericRef<T>,
774    S: TensorRef<(T, Index), D>,
775{
776    /**
777     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
778     * some unary function from `T` to `T` to every element in the container.
779     *
780     * To compute the new records, the unary function of some input x to some
781     * output y is needed along with its derivative with respect to its input x.
782     *
783     * For example, tanh is a commonly used activation function, but the Real trait
784     * does not include this operation and Record has no operations for it specifically.
785     * However, you can use this function to compute the tanh for a record container like so:
786     *
787     * ```
788     * use easy_ml::differentiation::{RecordTensor, WengertList};
789     * use easy_ml::tensors::Tensor;
790     * let list = WengertList::new();
791     * let X = RecordTensor::variables(
792     *     &list,
793     *     Tensor::from_fn(
794     *         [("rows", 2), ("columns", 2)],
795     *         |[r, c]| 0.15 * ((1 + r + c) as f32)
796     *     )
797     * );
798     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
799     * // 1 / (cosh(x) * cosh(x))
800     * let Y = X.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
801     *
802     * // we can unwrap here because we know Y contains variables not constants
803     * let derivatives = Y.derivatives().unwrap();
804     * let derivatives_indexing = derivatives.index_by(["rows", "columns"]);
805     * assert_eq!(
806     *     derivatives_indexing.get_ref([0, 0]).at_tensor(&X),
807     *     Tensor::from(
808     *         [("rows", 2), ("columns", 2)],
809     *         // [0, 0] element in Y only had the one input variable [0, 0] in X
810     *         vec![
811     *             0.9778332, 0.0,
812     *             0.0,       0.0
813     *        ]
814     *     ),
815     * );
816     * assert_eq!(
817     *     derivatives_indexing.get_ref([0, 1]).at_tensor(&X),
818     *     Tensor::from(
819     *         [("rows", 2), ("columns", 2)],
820     *         vec![
821     *             0.0, 0.915137,
822     *             0.0, 0.0
823     *        ]
824     *     ),
825     * );
826     * assert_eq!(
827     *     // [0, 1] and [1, 0] elements in X had the same starting value so end up with the same
828     *     // derivative for their corresponding input variable in X
829     *     derivatives_indexing.get_ref([0, 1]).at_tensor(&X).index().get([0, 1]),
830     *     derivatives_indexing.get_ref([1, 0]).at_tensor(&X).index().get([1, 0]),
831     * );
832     * assert_eq!(
833     *     derivatives_indexing.get_ref([1, 1]).at_tensor(&X),
834     *     Tensor::from(
835     *         [("rows", 2), ("columns", 2)],
836     *         vec![
837     *             0.0, 0.0,
838     *             0.0, 0.8220013
839     *        ]
840     *     ),
841     * );
842     * ```
843     */
844    #[track_caller]
845    pub fn unary(
846        &self,
847        fx: impl Fn(T) -> T,
848        dfx_dx: impl Fn(T) -> T,
849    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D> {
850        let total = self.elements();
851        match self.history {
852            None => RecordTensor::constants(self.numbers.map(|(x, _)| fx(x))),
853            Some(history) => {
854                let ys = unary::<T, _>(total, history, self.numbers.iter(), fx, dfx_dx);
855                RecordContainer {
856                    numbers: self.numbers.new_with_same_shape(ys),
857                    history: Some(history),
858                }
859            }
860        }
861    }
862
863    /**
864     * Creates a new RecordContainer from two RecordContainers by applying
865     * some binary function from `T` to `T` to every element pair in the containers. Both
866     * containers must have the same shape.
867     *
868     * To compute the new records, the binary function of some inputs x and y to some
869     * output z is needed along with its derivative with respect to its first input x and
870     * its derivative with respect to its second input y.
871     *
872     * For example, atan2 takes two arguments, but the Real trait
873     * does not include this operation and Record has no operations for it specifically.
874     * However, you can use this function to compute the atan2 for two record containers like so:
875     *
876     * ```
877     * use easy_ml::differentiation::{RecordTensor, WengertList};
878     * use easy_ml::tensors::Tensor;
879     * let list = WengertList::new();
880     * let X = RecordTensor::variables(
881     *     &list,
882     *     Tensor::from_fn(
883     *         [("rows", 2), ("columns", 2)],
884     *         |[r, c]| ((1 + r + c) as f32)
885     *     )
886     * );
887     * let Y = RecordTensor::variables(
888     *     &list,
889     *     Tensor::from_fn(
890     *         [("rows", 2), ("columns", 2)],
891     *         |[r, c]| ((1 + r + c) as f32)
892     *     )
893     * );
894     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
895     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
896     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
897     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
898     * let Z = X.binary(&Y,
899     *     |x, y| x.atan2(y),
900     *     |x, y| y/((x*x) + (y*y)),
901     *     |x, y| -x/((x*x) + (y*y))
902     * );
903     *
904     *
905     * // we can unwrap here because we know Z contains variables not constants
906     * let derivatives = Z.derivatives().unwrap();
907     * // Just as in the unary example, only one pair of the four inputs in X and Y influence the
908     * // outputs in Z, so we have a lot of 0.0 derivatives, and the inputs in [0, 1] and [1, 0]
909     * // are identical so we see the same derivative.
910     * let dZ_dX = derivatives.map(|d| d.at_tensor(&X));
911     * assert_eq!(
912     *     dZ_dX,
913     *     Tensor::from([("rows", 2), ("columns", 2)], vec![
914     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
915     *             0.5, 0.0,
916     *             0.0, 0.0
917     *         ]),
918     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
919     *             0.0, 0.25,
920     *             0.0, 0.0
921     *         ]),
922     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
923     *             0.0, 0.0,
924     *             0.25, 0.0
925     *         ]),
926     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
927     *             0.0, 0.0,
928     *             0.0, 0.16666667
929     *         ])
930     *     ])
931     * );
932     * let dZ_dY = derivatives.map(|d| d.at_tensor(&Y));
933     * assert_eq!(
934     *     dZ_dY,
935     *     Tensor::from([("rows", 2), ("columns", 2)], vec![
936     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
937     *             -0.5, 0.0,
938     *             0.0, 0.0
939     *         ]),
940     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
941     *             0.0, -0.25,
942     *             0.0, 0.0
943     *         ]),
944     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
945     *             0.0, 0.0,
946     *             -0.25, 0.0
947     *         ]),
948     *         Tensor::from([("rows", 2), ("columns", 2)], vec![
949     *             0.0, 0.0,
950     *             0.0, -0.16666667
951     *         ])
952     *     ])
953     * );
954     * ```
955     *
956     * # Panics
957     *
958     * - If both record containers have a WengertList that are different to each other
959     * - If the record containers have different shapes
960     */
961    #[track_caller]
962    pub fn binary<S2>(
963        &self,
964        rhs: &RecordTensor<'a, T, S2, D>,
965        fxy: impl Fn(T, T) -> T,
966        dfxy_dx: impl Fn(T, T) -> T,
967        dfxy_dy: impl Fn(T, T) -> T,
968    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
969    where
970        S2: TensorRef<(T, Index), D>,
971    {
972        {
973            let left_shape = self.numbers.shape();
974            let right_shape = rhs.numbers.shape();
975            if left_shape != right_shape {
976                panic!(
977                    "Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
978                    left_shape, right_shape
979                );
980            }
981        }
982        let total = self.elements();
983        match (self.history, rhs.history) {
984            (None, None) => RecordTensor::constants(
985                // use direct_from here maybe?
986                Tensor::from(
987                    self.numbers.shape(),
988                    self.numbers
989                        .iter()
990                        .zip(rhs.numbers.iter())
991                        .map(|((x, _), (y, _))| fxy(x, y))
992                        .collect(),
993                ),
994            ),
995            (Some(history), None) => {
996                let zs = binary_x_history::<T, _, _>(
997                    total,
998                    history,
999                    self.numbers.iter(),
1000                    rhs.numbers.iter(),
1001                    fxy,
1002                    dfxy_dx,
1003                );
1004                RecordContainer {
1005                    numbers: self.numbers.new_with_same_shape(zs),
1006                    history: Some(history),
1007                }
1008            }
1009            (None, Some(history)) => {
1010                let zs = binary_y_history::<T, _, _>(
1011                    total,
1012                    history,
1013                    self.numbers.iter(),
1014                    rhs.numbers.iter(),
1015                    fxy,
1016                    dfxy_dy,
1017                );
1018                RecordContainer {
1019                    numbers: self.numbers.new_with_same_shape(zs),
1020                    history: Some(history),
1021                }
1022            }
1023            (Some(history), Some(h)) => {
1024                assert!(
1025                    record_operations::same_lists(history, h),
1026                    "Record containers must be using the same WengertList"
1027                );
1028                let zs = binary_both_history::<T, _, _>(
1029                    total,
1030                    history,
1031                    self.numbers.iter(),
1032                    rhs.numbers.iter(),
1033                    fxy,
1034                    dfxy_dx,
1035                    dfxy_dy,
1036                );
1037                RecordContainer {
1038                    numbers: self.numbers.new_with_same_shape(zs),
1039                    history: Some(history),
1040                }
1041            }
1042        }
1043    }
1044
1045    /**
1046     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1047     * some unary function on `Record<T>` to `Record<T>` to every element in the container. This
1048     * will fail if the function would create records with inconsistent histories.
1049     *
1050     * When used with pure functions that can't return different histories for different inputs
1051     * unwrapping with always succeed.
1052     *
1053     * This API can allow you to call a generic function that operates on
1054     * [Numeric](crate::numeric::Numeric) or [Real](crate::numeric::extra::Real) numbers and
1055     * apply all the correct derivative tracking during the intermediate calculations for you,
1056     * without having to resort to storing the Record types.
1057     *
1058     * ```
1059     * use easy_ml::numeric::extra::Real;
1060     * use easy_ml::tensors::Tensor;
1061     * use easy_ml::differentiation::{RecordTensor, WengertList};
1062     *
1063     * fn sigmoid<T: Real+ Copy>(x: T) -> T {
1064     *     T::one() / (T::one() + (-x).exp())
1065     * }
1066     *
1067     * let history = WengertList::new();
1068     * let layer = RecordTensor::variables(&history, Tensor::from([("x", 2)], vec![ 0.2, 0.6 ]));
1069     * let after = layer.map(sigmoid).unwrap(); // sigmoid can't introduce inconsistent histories
1070     * ```
1071     *
1072     * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1073     * after mapping doesn't have to be the same as before, only must be the same for every
1074     * mapped element.
1075     *
1076     * See also: [AsRecords](AsRecords)
1077     */
1078    #[allow(clippy::type_complexity)]
1079    #[track_caller]
1080    pub fn map(
1081        &self,
1082        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1083    ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
1084        let result = RecordTensor::from_iter(self.shape(), self.iter_as_records().map(fx));
1085        RecordTensor::<'a, T, S, D>::map_collection(result, self.shape())
1086    }
1087
1088    /**
1089     * Returns a new Tensor of the same shape as this RecordContainer by
1090     * applying a function to every element from the Record to the Record's number
1091     * type. For example, you could lookup the derivatives of each Record in
1092     * the container with respect to some input.
1093     */
1094    pub fn map_to_tensor(&self, fx: impl Fn(Record<'a, T>) -> T) -> Tensor<T, D> {
1095        Tensor::from(self.shape(), self.iter_as_records().map(fx).collect())
1096    }
1097
1098    /**
1099     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1100     * some unary function on `Record<T>` and each index of that position in the Record to
1101     * `Record<T>` to every element in the container. This will fail if the function would create
1102     * records with inconsistent histories.
1103     *
1104     * When used with pure functions that can't return different histories for different inputs
1105     * unwrapping with always succeed.
1106     *
1107     * This API can allow you to call a generic function that operates on
1108     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1109     * during the intermediate calculations for you, without having to resort to storing the
1110     * Record types.
1111     *
1112     * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1113     * after mapping doesn't have to be the same as before, only must be the same for every
1114     * mapped element.
1115     */
1116    #[allow(clippy::type_complexity)]
1117    #[track_caller]
1118    pub fn map_with_index(
1119        &self,
1120        fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
1121    ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
1122        let result = RecordTensor::from_iter(
1123            self.shape(),
1124            self.iter_as_records().with_index().map(|(i, x)| fx(i, x)),
1125        );
1126        RecordTensor::<'a, T, S, D>::map_collection(result, self.shape())
1127    }
1128
1129    #[track_caller]
1130    #[allow(clippy::type_complexity)]
1131    fn map_collection(
1132        result: Result<
1133            RecordTensor<'a, T, Tensor<(T, usize), D>, D>,
1134            InvalidRecordIteratorError<'a, T, D>,
1135        >,
1136        shape: [(Dimension, usize); D],
1137    ) -> Result<RecordTensor<'a, T, Tensor<(T, usize), D>, D>, InconsistentHistory<'a, T>> {
1138        use InvalidRecordIteratorError as Error;
1139        match result {
1140            Ok(tensor) => Ok(tensor),
1141            Err(error) => match error {
1142                // These first two should be 100% impossible but provide a sensible error just
1143                // in case some weird things break our invariants
1144                Error::Empty => panic!("Illegal state, record tensor was empty {:?}", shape),
1145                Error::Shape { requested, length } => panic!(
1146                    "Illegal state, record tensor shape was inconsistent: requested: {:?}, length of data: {:?}",
1147                    requested, length
1148                ),
1149                // This one is theoretically possible but in practise shouldn't happen by accident
1150                // However, it can't implement Debug unless T is debug so to avoid having to
1151                // restrict our function signature we return a Result anyway - this also encourages
1152                // the user to make sure their function isn't going to cause this case, which
1153                // with some of the other variants like with_index might come up more easily
1154                Error::InconsistentHistory(h) => Err(h),
1155            },
1156        }
1157    }
1158
1159    /**
1160     * For each record in the container, peforms a backward pass up its WengertList from it
1161     * as the output, computing all the derivatives for the inputs involving this output.
1162     *
1163     * If this container has no backing WengertList, ie was created as constants, then None is
1164     * returned instead. Otherwise the returned Tensor will have the same shape as this container,
1165     * with the respective derivatives matching each element in this container.
1166     *
1167     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is Y with M outputs,
1168     * then this computes all the derivatives δy<sub>j</sub>/δx<sub>i</sub> for i = 1 to N and
1169     * j = 1 to M.
1170     *
1171     * If you have a lot of outputs this could be very expensive! Reverse auto diff is optimised
1172     * for domains where there are many more inputs than outputs.
1173     *
1174     * If you only need some of the derivatives then
1175     * [derivatives_for](RecordTensor::derivatives_for) can be used instead to avoid
1176     * calculating the rest.
1177     */
1178    pub fn derivatives(&self) -> Option<Tensor<Derivatives<T>, D>> {
1179        self.history.map(|history| {
1180            self.numbers.map(|(x, i)| {
1181                Record {
1182                    number: x,
1183                    history: Some(history),
1184                    index: i,
1185                }
1186                .derivatives()
1187            })
1188        })
1189    }
1190
1191    /**
1192     * For the record at the index, peforms a backward pass up its WengertList from it
1193     * as the output, computing all the derivatives for the inputs involving this output.
1194     *
1195     * If the index is invalid or this container has no backing WengertList, ie was created
1196     * as constants, then None is returned instead.
1197     *
1198     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
1199     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
1200     */
1201    pub fn derivatives_for(&self, indexes: [usize; D]) -> Option<Derivatives<T>> {
1202        let (number, index) = self.get_reference(indexes).map(|(x, i)| (x.clone(), *i))?;
1203        // The nature of reverse autodiff is that we expect to only have a few outputs from
1204        // which we calculate all the derivatives we care about. Therefore just call Record and
1205        // reuse the implementation instead of trying to do anything clever like calculate all
1206        // derivatives for every number in this container.
1207        Record {
1208            number,
1209            history: self.history,
1210            index,
1211        }
1212        .try_derivatives()
1213    }
1214
1215    /**
1216     * Performs elementwise multiplication for two record tensors of the same shape.
1217     *
1218     * # Panics
1219     *
1220     * - If both record containers have a WengertList that are different to each other
1221     * - If the record containers have different shapes
1222     */
1223    // TODO: Assign variants?
1224    pub fn elementwise_multiply<S2>(
1225        &self,
1226        other: &RecordTensor<'a, T, S2, D>,
1227    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1228    where
1229        S2: TensorRef<(T, Index), D>,
1230    {
1231        self.binary(
1232            other,
1233            Multiplication::<T>::function,
1234            Multiplication::<T>::d_function_dx,
1235            Multiplication::<T>::d_function_dy,
1236        )
1237    }
1238
1239    /**
1240     * Performs elementwise division for two record tensors of the same shape.
1241     *
1242     * # Panics
1243     *
1244     * - If both record containers have a WengertList that are different to each other
1245     * - If the record containers have different shapes
1246     */
1247    pub fn elementwise_divide<S2>(
1248        &self,
1249        other: &RecordTensor<'a, T, S2, D>,
1250    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1251    where
1252        S2: TensorRef<(T, Index), D>,
1253    {
1254        self.binary(
1255            other,
1256            Division::<T>::function,
1257            Division::<T>::d_function_dx,
1258            Division::<T>::d_function_dy,
1259        )
1260    }
1261}
1262
1263impl<T: Clone + Primitive> Derivatives<T> {
1264    /**
1265     * Queries the derivative at the provided index into the record tensor as input.
1266     *
1267     * If you construct a Derivatives object for some output y,
1268     * and call .at_tensor_index(i, &xs) on it for some input container xs and index i, this
1269     * returns dy/dx where x = xs\[i\].
1270     *
1271     * If the index into the tensor is invalid, returns None instead.
1272     */
1273    pub fn at_tensor_index<S, const D: usize>(
1274        &self,
1275        indexes: [usize; D],
1276        input: &RecordTensor<T, S, D>,
1277    ) -> Option<T>
1278    where
1279        S: TensorRef<(T, Index), D>,
1280    {
1281        let index = input.get_reference(indexes).map(|(_, i)| *i)?;
1282        Some(self.derivatives[index].clone())
1283    }
1284
1285    /**
1286     * Queries the derivatives at every element in the record tensor input.
1287     *
1288     * If you construct a Derivatives object for some output y,
1289     * and call .at_tensor(&xs) on it for some input container xs this
1290     * returns dy/dx for every x in xs.
1291     */
1292    pub fn at_tensor<S, const D: usize>(&self, input: &RecordTensor<T, S, D>) -> Tensor<T, D>
1293    where
1294        S: TensorRef<(T, Index), D>,
1295    {
1296        input.numbers.map(|(_, i)| self.derivatives[i].clone())
1297    }
1298
1299    /**
1300     * Queries the derivative at the provided index into the record matrix as input.
1301     *
1302     * If you construct a Derivatives object for some output y,
1303     * and call .at_matrix_index(i, j, &xs) on it for some input container xs and indexes i and j,
1304     * this returns dy/dx where x = xs\[i, j\].
1305     *
1306     * If the index into the tensor is invalid, returns None instead.
1307     */
1308    pub fn at_matrix_index<S>(
1309        &self,
1310        row: Row,
1311        column: Column,
1312        input: &RecordMatrix<T, S>,
1313    ) -> Option<T>
1314    where
1315        S: MatrixRef<(T, Index)> + NoInteriorMutability,
1316    {
1317        let index = input.try_get_reference(row, column).map(|(_, i)| *i)?;
1318        Some(self.derivatives[index].clone())
1319    }
1320
1321    /**
1322     * Queries the derivatives at every element in the record matrix input.
1323     *
1324     * If you construct a Derivatives object for some output y,
1325     * and call .at_matrix(&xs) on it for some input container xs this
1326     * returns dy/dx for every x in xs.
1327     */
1328    pub fn at_matrix<S>(&self, input: &RecordMatrix<T, S>) -> Matrix<T>
1329    where
1330        S: MatrixRef<(T, Index)> + NoInteriorMutability,
1331    {
1332        input.numbers.map(|(_, i)| self.derivatives[i].clone())
1333    }
1334}
1335
1336impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
1337where
1338    T: Numeric + Primitive,
1339    for<'t> &'t T: NumericRef<T>,
1340    S: TensorMut<(T, Index), D>,
1341{
1342    /**
1343     * Overwrites a RecordContainer by applying
1344     * some unary function from `T` to `T` to every element in the container.
1345     *
1346     * To compute the new records, the unary function of some input x to some
1347     * output y is needed along with its derivative with respect to its input x.
1348     */
1349    #[track_caller]
1350    pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
1351        let total = self.elements();
1352        match self.history {
1353            None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
1354            Some(history) => {
1355                let ys = unary::<T, _>(total, history, self.numbers.iter(), fx, dfx_dx);
1356                for (element, result) in self.numbers.iter_reference_mut().zip(ys) {
1357                    *element = result;
1358                }
1359                self.history = Some(history);
1360            }
1361        }
1362    }
1363
1364    /**
1365     * Overwrites the left hand side of a RecordContainer with the result of applying
1366     * some binary function from `T` to `T` to every element pair in the containers. Both
1367     * containers must have the same shape.
1368     * To compute the new records, the binary function of some inputs x and y to some
1369     * output z is needed along with its derivative with respect to its first input x and
1370     * its derivative with respect to its second input y.
1371     *
1372     * # Panics
1373     *
1374     * - If both record containers have a WengertList that are different to each other
1375     * - If the record containers have different shapes
1376     */
1377    #[track_caller]
1378    pub fn binary_left_assign<S2>(
1379        &mut self,
1380        rhs: &RecordTensor<'a, T, S2, D>,
1381        fxy: impl Fn(T, T) -> T,
1382        dfxy_dx: impl Fn(T, T) -> T,
1383        dfxy_dy: impl Fn(T, T) -> T,
1384    ) where
1385        S2: TensorRef<(T, Index), D>,
1386    {
1387        {
1388            let left_shape = self.numbers.shape();
1389            let right_shape = rhs.numbers.shape();
1390            if left_shape != right_shape {
1391                panic!(
1392                    "Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
1393                    left_shape, right_shape
1394                );
1395            }
1396        }
1397        let total = self.elements();
1398        match (self.history, rhs.history) {
1399            (None, None) => {
1400                for (x, y) in self.numbers.iter_reference_mut().zip(rhs.numbers.iter()) {
1401                    let (left, _) = x;
1402                    let (right, _) = y;
1403                    *x = (fxy(left.clone(), right), 0);
1404                }
1405            }
1406            (Some(history), None) => {
1407                let zs = binary_x_history::<T, _, _>(
1408                    total,
1409                    history,
1410                    self.numbers.iter(),
1411                    rhs.numbers.iter(),
1412                    fxy,
1413                    dfxy_dx,
1414                );
1415                for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1416                    *element = result;
1417                }
1418                self.history = Some(history);
1419            }
1420            (None, Some(history)) => {
1421                let zs = binary_y_history::<T, _, _>(
1422                    total,
1423                    history,
1424                    self.numbers.iter(),
1425                    rhs.numbers.iter(),
1426                    fxy,
1427                    dfxy_dy,
1428                );
1429                for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1430                    *element = result;
1431                }
1432                self.history = Some(history);
1433            }
1434            (Some(history), Some(h)) => {
1435                assert!(
1436                    record_operations::same_lists(history, h),
1437                    "Record containers must be using the same WengertList"
1438                );
1439                let zs = binary_both_history::<T, _, _>(
1440                    total,
1441                    history,
1442                    self.numbers.iter(),
1443                    rhs.numbers.iter(),
1444                    fxy,
1445                    dfxy_dx,
1446                    dfxy_dy,
1447                );
1448                for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1449                    *element = result;
1450                }
1451                self.history = Some(history);
1452            }
1453        }
1454    }
1455
1456    /**
1457     * A convenience helper function which takes the RecordContainer value and
1458     * calls [unary_assign](RecordTensor::unary_assign()) on it, returning
1459     * the record container which now contains the result of the operation.
1460     */
1461    #[track_caller]
1462    pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
1463        self.unary_assign(fx, dfx_dx);
1464        self
1465    }
1466
1467    /**
1468     * A convenience helper function which takes the left hand side by value and
1469     * calls [binary_left_assign](RecordTensor::binary_left_assign()) on it, returning
1470     * the left hand side which now contains the result of the operation.
1471     */
1472    #[track_caller]
1473    pub fn do_binary_left_assign<S2>(
1474        mut self,
1475        rhs: &RecordTensor<'a, T, S2, D>,
1476        fxy: impl Fn(T, T) -> T,
1477        dfxy_dx: impl Fn(T, T) -> T,
1478        dfxy_dy: impl Fn(T, T) -> T,
1479    ) -> Self
1480    where
1481        S2: TensorRef<(T, Index), D>,
1482    {
1483        self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
1484        self
1485    }
1486
1487    /**
1488     * Updates this RecordContainer in place by applying some unary function on `Record<T>` to
1489     * `Record<T>` to every element in the container. This will fail if the function would create
1490     * records with inconsistent histories.
1491     *
1492     * When used with pure functions that can't return different histories for different inputs
1493     * unwrapping with always succeed.
1494     *
1495     * Since this updates the container in place, if Err is returned then the data in this
1496     * RecordContainer is still available but it has been corrupted - at least one of the elements
1497     * should have a different history than what it will have because the mapping function created
1498     * inconsistent histories that couldn't be represented by the container as it only stores
1499     * one.
1500     *
1501     * This API can allow you to call a generic function that operates on
1502     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1503     * during the intermediate calculations for you, without having to resort to storing the
1504     * Record types.
1505     *
1506     * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1507     * after mapping doesn't have to be the same as before, only must be the same for every
1508     * mapped element.
1509     *
1510     * You might also use this function at the end of a training loop to update all the weights
1511     * to reduce their loss.
1512     *
1513     * ```
1514     * use easy_ml::numeric::Numeric;
1515     * use easy_ml::tensors::Tensor;
1516     * use easy_ml::differentiation::{Record, RecordTensor, WengertList};
1517     *
1518     * let history = WengertList::new();
1519     * let mut weights = RecordTensor::variables(
1520     *     &history,
1521     *     Tensor::from([("w1", 4)], vec![ 0.3, 0.2, -1.2, -0.4 ])
1522     * );
1523     * let error = {
1524     *     // this is over-simplified for brevity, obviously in a real scenario we wouldn't have a
1525     *     // function that calculates the error like this or we wouldn't be doing machine learning
1526     *     // to fit it in the first place
1527     *     let mut loss = Record::variable(0.0, &history);
1528     *     for r in weights.iter_as_records() {
1529     *         loss = loss + r;
1530     *     }
1531     *     loss
1532     * };
1533     * let derivatives = error.derivatives();
1534     * let learning_rate = 0.1;
1535     * // update the weights to contain less error than they did before
1536     * let result = weights.map_mut(|x| x - (derivatives[&x] * learning_rate));
1537     * assert!(result.is_ok()); // we know we didn't introduce an inconsistent history just updating the weights
1538     * ```
1539     */
1540    #[track_caller]
1541    pub fn map_mut(
1542        &mut self,
1543        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1544    ) -> Result<(), InconsistentHistory<'a, T>> {
1545        let history = self.history;
1546        let new_history =
1547            map_mut_base::<'a, T, _, _>(TensorReferenceMutIterator::from(self), |x| {
1548                let record = Record::from_existing(x.clone(), history);
1549                let result = fx(record);
1550                *x = (result.number, result.index);
1551                result.history
1552            })?;
1553        self.history = new_history;
1554        Ok(())
1555    }
1556
1557    /**
1558     * Updates this RecordContainer in place by applying some unary function on `Record<T>` and
1559     * each index of that position in the Record to `Record<T>` to every element in the container.
1560     * This will fail if the function would create records with inconsistent histories.
1561     *
1562     * When used with pure functions that can't return different histories for different inputs
1563     * unwrapping with always succeed.
1564     *
1565     * Since this updates the container in place, if Err is returned then the data in this
1566     * RecordContainer is still available but it has been corrupted - at least one of the elements
1567     * should have a different history than what it will have because the mapping function created
1568     * inconsistent histories that couldn't be represented by the container as it only stores
1569     * one.
1570     *
1571     * This API can allow you to call a generic function that operates on
1572     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1573     * during the intermediate calculations for you, without having to resort to storing the
1574     * Record types.
1575     *
1576     * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1577     * after mapping doesn't have to be the same as before, only must be the same for every
1578     * mapped element.
1579     */
1580    #[track_caller]
1581    pub fn map_mut_with_index(
1582        &mut self,
1583        fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
1584    ) -> Result<(), InconsistentHistory<'a, T>> {
1585        let history = self.history;
1586        let new_history = map_mut_base::<'a, T, _, _>(
1587            TensorReferenceMutIterator::from(self).with_index(),
1588            |(i, x)| {
1589                let record = Record::from_existing(x.clone(), history);
1590                let result = fx(i, record);
1591                *x = (result.number, result.index);
1592                result.history
1593            },
1594        )?;
1595        self.history = new_history;
1596        Ok(())
1597    }
1598}
1599
1600#[track_caller]
1601fn map_mut_base<'a, T, I, X>(
1602    mut iter: I,
1603    fx: impl Fn(X) -> Option<&'a WengertList<T>>,
1604) -> Result<Option<&'a WengertList<T>>, InconsistentHistory<'a, T>>
1605where
1606    I: Iterator<Item = X>,
1607    T: Primitive,
1608{
1609    use crate::differentiation::record_operations::are_exact_same_list;
1610    #[rustfmt::skip]
1611    let first_history = fx(iter.next().expect("Illegal state, record container was empty"));
1612    let mut different_history: Option<Option<&WengertList<T>>> = None;
1613    for x in iter {
1614        let history = fx(x);
1615        if !are_exact_same_list(history, first_history) {
1616            different_history = Some(history);
1617        }
1618    }
1619    match different_history {
1620        None => Ok(first_history),
1621        Some(h) => Err(InconsistentHistory {
1622            first: first_history,
1623            later: h,
1624        }),
1625    }
1626}
1627
1628impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
1629where
1630    T: Numeric + Primitive,
1631    for<'t> &'t T: NumericRef<T>,
1632    S: TensorRef<(T, Index), D>,
1633{
1634    /**
1635     * Overwrites the right hand side of a RecordContainer with the result of applying
1636     * some binary function from `T` to `T` to every element pair in the containers. Both
1637     * containers must have the same shape.
1638     * To compute the new records, the binary function of some inputs x and y to some
1639     * output z is needed along with its derivative with respect to its first input x and
1640     * its derivative with respect to its second input y.
1641     *
1642     * # Panics
1643     *
1644     * - If both record containers have a WengertList that are different to each other
1645     * - If the record containers have different shapes
1646     */
1647    #[track_caller]
1648    pub fn binary_right_assign<S2>(
1649        &self,
1650        rhs: &mut RecordTensor<'a, T, S2, D>,
1651        fxy: impl Fn(T, T) -> T,
1652        dfxy_dx: impl Fn(T, T) -> T,
1653        dfxy_dy: impl Fn(T, T) -> T,
1654    ) where
1655        S2: TensorMut<(T, Index), D>,
1656    {
1657        // x is lhs, y is rhs, so calling binary_left_assign on the rhs container
1658        // means we need to swap all the arguments
1659        rhs.binary_left_assign(
1660            self,
1661            |y, x| fxy(x, y),
1662            |y, x| dfxy_dy(x, y),
1663            |y, x| dfxy_dx(x, y),
1664        )
1665    }
1666
1667    /**
1668     * A convenience helper function which takes the right hand side by value and
1669     * calls [binary_right_assign](RecordTensor::binary_right_assign()) on it, returning
1670     * the right hand side which now contains the result of the operation.
1671     */
1672    #[track_caller]
1673    pub fn do_binary_right_assign<S2>(
1674        &self,
1675        mut rhs: RecordTensor<'a, T, S2, D>,
1676        fxy: impl Fn(T, T) -> T,
1677        dfxy_dx: impl Fn(T, T) -> T,
1678        dfxy_dy: impl Fn(T, T) -> T,
1679    ) -> RecordTensor<'a, T, S2, D>
1680    where
1681        S2: TensorMut<(T, Index), D>,
1682    {
1683        self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
1684        rhs
1685    }
1686}
1687
1688impl<'a, T, S> RecordMatrix<'a, T, S>
1689where
1690    T: Numeric + Primitive,
1691    for<'t> &'t T: NumericRef<T>,
1692    S: MatrixRef<(T, Index)> + NoInteriorMutability,
1693{
1694    /**
1695     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1696     * some unary function from `T` to `T` to every element in the container.
1697     *
1698     * To compute the new records, the unary function of some input x to some
1699     * output y is needed along with its derivative with respect to its input x.
1700     *
1701     * For example, tanh is a commonly used activation function, but the Real trait
1702     * does not include this operation and Record has no operations for it specifically.
1703     * However, you can use this function to compute the tanh for a record container like so:
1704     *
1705     * ```
1706     * use easy_ml::differentiation::{RecordMatrix, WengertList};
1707     * use easy_ml::matrices::Matrix;
1708     * let list = WengertList::new();
1709     * let X = RecordMatrix::variables(
1710     *     &list,
1711     *     Matrix::from_fn((2, 2), |(r, c)| 0.15 * ((1 + r + c) as f32))
1712     * );
1713     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
1714     * // 1 / (cosh(x) * cosh(x))
1715     * let Y = X.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
1716     *
1717     * // we can unwrap here because we know Y contains variables not constants
1718     * let derivatives = Y.derivatives().unwrap();
1719     * assert_eq!(
1720     *     derivatives.get_reference(0, 0).at_matrix(&X),
1721     *     Matrix::from(vec![
1722     *         // (0, 0) element in Y only had the one input variable (0, 0) in X
1723     *         vec![0.9778332, 0.0],
1724     *         vec![0.0,       0.0]
1725     *     ]),
1726     * );
1727     * assert_eq!(
1728     *     derivatives.get_reference(0, 1).at_matrix(&X),
1729     *     Matrix::from(vec![
1730     *         vec![0.0, 0.915137],
1731     *         vec![0.0,      0.0]
1732     *     ]),
1733     * );
1734     * assert_eq!(
1735     *     // (0, 1) and (1, 0) elements in X had the same starting value so end up with the same
1736     *     // derivative for their corresponding input variable in X
1737     *     derivatives.get_reference(0, 1).at_matrix(&X).get(0, 1),
1738     *     derivatives.get_reference(1, 0).at_matrix(&X).get(1, 0),
1739     * );
1740     * assert_eq!(
1741     *     derivatives.get_reference(1, 1).at_matrix(&X),
1742     *     Matrix::from(vec![
1743     *         vec![0.0, 0.0      ],
1744     *         vec![0.0, 0.8220013]
1745     *     ]),
1746     * );
1747     * ```
1748     */
1749    #[track_caller]
1750    pub fn unary(
1751        &self,
1752        fx: impl Fn(T) -> T,
1753        dfx_dx: impl Fn(T) -> T,
1754    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>> {
1755        let total = self.elements();
1756        match self.history {
1757            None => RecordMatrix::constants(self.numbers.map(|(x, _)| fx(x))),
1758            Some(history) => {
1759                let ys = unary::<T, _>(total, history, self.numbers.row_major_iter(), fx, dfx_dx);
1760                RecordContainer {
1761                    numbers: MatrixView::from(Matrix::from_flat_row_major(self.numbers.size(), ys)),
1762                    history: Some(history),
1763                }
1764            }
1765        }
1766    }
1767
1768    /**
1769     * Creates a new RecordContainer from two RecordContainers by applying
1770     * some binary function from `T` to `T` to every element pair in the containers. Both
1771     * containers must have the same shape.
1772     *
1773     * To compute the new records, the binary function of some inputs x and y to some
1774     * output z is needed along with its derivative with respect to its first input x and
1775     * its derivative with respect to its second input y.
1776     *
1777     * For example, atan2 takes two arguments, but the Real trait
1778     * does not include this operation and Record has no operations for it specifically.
1779     * However, you can use this function to compute the atan2 for two record containers like so:
1780     *
1781     * ```
1782     * use easy_ml::differentiation::{RecordMatrix, WengertList};
1783     * use easy_ml::matrices::Matrix;
1784     * let list = WengertList::new();
1785     * let X = RecordMatrix::variables(
1786     *     &list,
1787     *     Matrix::from_fn((2, 2), |(r, c)| ((1 + r + c) as f32))
1788     * );
1789     * let Y = RecordMatrix::variables(
1790     *     &list,
1791     *     Matrix::from_fn((2, 2), |(r, c)| ((1 + r + c) as f32))
1792     * );
1793     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
1794     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
1795     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
1796     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
1797     * let Z = X.binary(&Y,
1798     *     |x, y| x.atan2(y),
1799     *     |x, y| y/((x*x) + (y*y)),
1800     *     |x, y| -x/((x*x) + (y*y))
1801     * );
1802     *
1803     * // we can unwrap here because we know Z contains variables not constants
1804     * let derivatives = Z.derivatives().unwrap();
1805     * // Just as in the unary example, only one pair of the four inputs in X and Y influence the
1806     * // outputs in Z, so we have a lot of 0.0 derivatives, and the inputs in [0, 1] and [1, 0]
1807     * // are identical so we see the same derivative.
1808     * let dZ_dX = derivatives.map(|d| d.at_matrix(&X));
1809     * assert_eq!(
1810     *     dZ_dX,
1811     *     Matrix::from(vec![
1812     *          vec![
1813     *              Matrix::from(vec![
1814     *                  vec![ 0.5, 0.0 ],
1815     *                  vec![ 0.0, 0.0 ]
1816     *              ]),
1817     *              Matrix::from(vec![
1818     *                  vec![ 0.0, 0.25 ],
1819     *                  vec![ 0.0, 0.0 ]
1820     *              ])
1821     *          ],
1822     *          vec![
1823     *              Matrix::from(vec![
1824     *                  vec![ 0.0, 0.0 ],
1825     *                  vec![ 0.25, 0.0 ]
1826     *              ]),
1827     *              Matrix::from(vec![
1828     *                  vec![ 0.0, 0.0 ],
1829     *                  vec![ 0.0, 0.16666667 ]
1830     *              ])
1831     *          ]
1832     *     ])
1833     * );
1834     * let dZ_dY = derivatives.map(|d| d.at_matrix(&Y));
1835     * assert_eq!(
1836     *     dZ_dY,
1837     *     Matrix::from(vec![
1838     *          vec![
1839     *              Matrix::from(vec![
1840     *                  vec![ -0.5, 0.0 ],
1841     *                  vec![ 0.0, 0.0 ]
1842     *              ]),
1843     *              Matrix::from(vec![
1844     *                  vec![ 0.0, -0.25 ],
1845     *                  vec![ 0.0, 0.0 ]
1846     *              ])
1847     *          ],
1848     *          vec![
1849     *              Matrix::from(vec![
1850     *                  vec![ 0.0, 0.0 ],
1851     *                  vec![ -0.25, 0.0 ]
1852     *              ]),
1853     *              Matrix::from(vec![
1854     *                  vec![ 0.0, 0.0 ],
1855     *                  vec![ 0.0, -0.16666667 ]
1856     *              ])
1857     *          ]
1858     *     ])
1859     * );
1860     * ```
1861     *
1862     * # Panics
1863     *
1864     * - If both record containers have a WengertList that are different to each other
1865     * - If the record containers have different shapes
1866     */
1867    #[track_caller]
1868    pub fn binary<S2>(
1869        &self,
1870        rhs: &RecordMatrix<'a, T, S2>,
1871        fxy: impl Fn(T, T) -> T,
1872        dfxy_dx: impl Fn(T, T) -> T,
1873        dfxy_dy: impl Fn(T, T) -> T,
1874    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1875    where
1876        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1877    {
1878        let shape = {
1879            let left_shape = self.numbers.size();
1880            let right_shape = rhs.numbers.size();
1881            if left_shape != right_shape {
1882                panic!(
1883                    "Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
1884                    left_shape, right_shape
1885                );
1886            }
1887            left_shape
1888        };
1889        let total = self.elements();
1890        match (self.history, rhs.history) {
1891            (None, None) => RecordMatrix::constants(Matrix::from_flat_row_major(
1892                shape,
1893                self.numbers
1894                    .row_major_iter()
1895                    .zip(rhs.numbers.row_major_iter())
1896                    .map(|((x, _), (y, _))| fxy(x, y))
1897                    .collect(),
1898            )),
1899            (Some(history), None) => {
1900                let zs = binary_x_history::<T, _, _>(
1901                    total,
1902                    history,
1903                    self.numbers.row_major_iter(),
1904                    rhs.numbers.row_major_iter(),
1905                    fxy,
1906                    dfxy_dx,
1907                );
1908                RecordContainer {
1909                    numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1910                    history: Some(history),
1911                }
1912            }
1913            (None, Some(history)) => {
1914                let zs = binary_y_history::<T, _, _>(
1915                    total,
1916                    history,
1917                    self.numbers.row_major_iter(),
1918                    rhs.numbers.row_major_iter(),
1919                    fxy,
1920                    dfxy_dy,
1921                );
1922                RecordContainer {
1923                    numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1924                    history: Some(history),
1925                }
1926            }
1927            (Some(history), Some(h)) => {
1928                assert!(
1929                    record_operations::same_lists(history, h),
1930                    "Record containers must be using the same WengertList"
1931                );
1932                let zs = binary_both_history::<T, _, _>(
1933                    total,
1934                    history,
1935                    self.numbers.row_major_iter(),
1936                    rhs.numbers.row_major_iter(),
1937                    fxy,
1938                    dfxy_dx,
1939                    dfxy_dy,
1940                );
1941                RecordContainer {
1942                    numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1943                    history: Some(history),
1944                }
1945            }
1946        }
1947    }
1948
1949    /**
1950     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1951     * some unary function on `Record<T>` to `Record<T>` to every element in the container. This
1952     * will fail if the function would create records with inconsistent histories.
1953     *
1954     * When used with pure functions that can't return different histories for different inputs
1955     * unwrapping with always succeed.
1956     *
1957     * This API can allow you to call a generic function that operates on
1958     * [Numeric](crate::numeric::Numeric) or [Real](crate::numeric::extra::Real) numbers and
1959     * apply all the correct derivative tracking during the intermediate calculations for you,
1960     * without having to resort to storing the Record types.
1961     *
1962     * ```
1963     * use easy_ml::numeric::extra::Real;
1964     * use easy_ml::matrices::Matrix;
1965     * use easy_ml::differentiation::{RecordMatrix, WengertList};
1966     *
1967     * fn sigmoid<T: Real+ Copy>(x: T) -> T {
1968     *     T::one() / (T::one() + (-x).exp())
1969     * }
1970     *
1971     * let history = WengertList::new();
1972     * let layer = RecordMatrix::variables(&history, Matrix::from(vec![vec![ 0.2, 0.6 ]]));
1973     * let after = layer.map(sigmoid).unwrap(); // sigmoid can't introduce inconsistent histories
1974     * ```
1975     *
1976     * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
1977     * after mapping doesn't have to be the same as before, only must be the same for every
1978     * mapped element.
1979     *
1980     * See also: [AsRecords](AsRecords)
1981     */
1982    #[allow(clippy::type_complexity)]
1983    #[track_caller]
1984    pub fn map(
1985        &self,
1986        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1987    ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
1988        let result = RecordMatrix::from_iter(self.size(), self.iter_row_major_as_records().map(fx));
1989        RecordMatrix::<'a, T, S>::map_collection(result, self.size())
1990    }
1991
1992    /**
1993     * Returns a new Matrix of the same shape as this RecordContainer by
1994     * applying a function to every element from the Record to the Record's number
1995     * type. For example, you could lookup the derivatives of each Record in
1996     * the container with respect to some input.
1997     */
1998    pub fn map_to_matrix(&self, fx: impl Fn(Record<'a, T>) -> T) -> Matrix<T> {
1999        Matrix::from_flat_row_major(
2000            self.size(),
2001            self.iter_row_major_as_records().map(fx).collect(),
2002        )
2003    }
2004
2005    /**
2006     * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
2007     * some unary function on `Record<T>` and each index of that position in the Record to
2008     * `Record<T>` to every element in the container. This will fail if the function would
2009     * create records with inconsistent histories.
2010     *
2011     * When used with pure functions that can't return different histories for different inputs
2012     * unwrapping with always succeed.
2013     *
2014     * This API can allow you to call a generic function that operates on
2015     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2016     * during the intermediate calculations for you, without having to resort to storing the
2017     * Record types.
2018     *
2019     * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2020     * after mapping doesn't have to be the same as before, only must be the same for every
2021     * mapped element.
2022     */
2023    #[allow(clippy::type_complexity)]
2024    #[track_caller]
2025    pub fn map_with_index(
2026        &self,
2027        fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
2028    ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
2029        let result = RecordMatrix::from_iter(
2030            self.size(),
2031            self.iter_row_major_as_records()
2032                .with_index()
2033                .map(|((r, c), x)| fx(x, r, c)),
2034        );
2035        RecordMatrix::<'a, T, S>::map_collection(result, self.size())
2036    }
2037
2038    #[allow(clippy::type_complexity)]
2039    #[track_caller]
2040    fn map_collection(
2041        result: Result<
2042            RecordMatrix<'a, T, Matrix<(T, usize)>>,
2043            InvalidRecordIteratorError<'a, T, 2>,
2044        >,
2045        size: (Row, Column),
2046    ) -> Result<RecordMatrix<'a, T, Matrix<(T, usize)>>, InconsistentHistory<'a, T>> {
2047        use InvalidRecordIteratorError as Error;
2048        match result {
2049            Ok(matrix) => Ok(matrix),
2050            Err(error) => match error {
2051                // These first two should be 100% impossible but provide a sensible error just
2052                // in case some weird things break our invariants
2053                Error::Empty => panic!("Illegal state, record matrix was empty {:?}", size),
2054                Error::Shape { requested, length } => panic!(
2055                    "Illegal state, record matrix shape was inconsistent: requested: {:?}, length of data: {:?}",
2056                    requested, length
2057                ),
2058                // This one is theoretically possible but in practise shouldn't happen by accident
2059                // However, it can't implement Debug unless T is debug so to avoid having to
2060                // restrict our function signature we return a Result anyway - this also encourages
2061                // the user to make sure their function isn't going to cause this case, which
2062                // with some of the other variants like with_index might come up more easily
2063                Error::InconsistentHistory(h) => Err(h),
2064            },
2065        }
2066    }
2067
2068    /**
2069     * For each record in the container, peforms a backward pass up its WengertList from it
2070     * as the output, computing all the derivatives for the inputs involving this output.
2071     *
2072     * If this container has no backing WengertList, ie was created as constants, then None is
2073     * returned instead. Otherwise the returned Matrix will have the same size as this container,
2074     * with the respective derivatives matching each element in this container.
2075     *
2076     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is Y with M outputs,
2077     * then this computes all the derivatives δy<sub>j</sub>/δx<sub>i</sub> for i = 1 to N and
2078     * j = 1 to M.
2079     *
2080     * If you have a lot of outputs this could be very expensive! Reverse auto diff is optimised
2081     * for domains where there are many more inputs than outputs.
2082     *
2083     * If you only need some of the derivatives then
2084     * [derivatives_for](RecordMatrix::derivatives_for) can be used instead to avoid
2085     * calculating the rest.
2086     */
2087    pub fn derivatives(&self) -> Option<Matrix<Derivatives<T>>> {
2088        self.history.map(|history| {
2089            self.numbers.map(|(x, i)| {
2090                Record {
2091                    number: x,
2092                    history: Some(history),
2093                    index: i,
2094                }
2095                .derivatives()
2096            })
2097        })
2098    }
2099
2100    /**
2101     * For the record at the index, peforms a backward pass up its WengertList from it
2102     * as the output, computing all the derivatives for the inputs involving this output.
2103     *
2104     * If the index is invalid or this container has no backing WengertList, ie was created
2105     * as constants, then None is returned instead.
2106     *
2107     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
2108     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
2109     */
2110    pub fn derivatives_for(&self, row: Row, column: Column) -> Option<Derivatives<T>> {
2111        let (number, index) = self
2112            .try_get_reference(row, column)
2113            .map(|(x, i)| (x.clone(), *i))?;
2114        // The nature of reverse autodiff is that we expect to only have a few outputs from
2115        // which we calculate all the derivatives we care about. Therefore just call Record and
2116        // reuse the implementation instead of trying to do anything clever like calculate all
2117        // derivatives for every number in this container.
2118        Record {
2119            number,
2120            history: self.history,
2121            index,
2122        }
2123        .try_derivatives()
2124    }
2125
2126    /**
2127     * Performs elementwise multiplication for two record matrices of the same size.
2128     *
2129     * # Panics
2130     *
2131     * - If both record containers have a WengertList that are different to each other
2132     * - If the record containers have different shapes
2133     */
2134    // TODO: Assign variants?
2135    pub fn elementwise_multiply<S2>(
2136        &self,
2137        other: &RecordMatrix<'a, T, S2>,
2138    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
2139    where
2140        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2141    {
2142        self.binary(
2143            other,
2144            Multiplication::<T>::function,
2145            Multiplication::<T>::d_function_dx,
2146            Multiplication::<T>::d_function_dy,
2147        )
2148    }
2149
2150    /**
2151     * Performs elementwise division for two record matrices of the same size.
2152     *
2153     * # Panics
2154     *
2155     * - If both record containers have a WengertList that are different to each other
2156     * - If the record containers have different shapes
2157     */
2158    pub fn elementwise_divide<S2>(
2159        &self,
2160        other: &RecordMatrix<'a, T, S2>,
2161    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
2162    where
2163        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2164    {
2165        self.binary(
2166            other,
2167            Division::<T>::function,
2168            Division::<T>::d_function_dx,
2169            Division::<T>::d_function_dy,
2170        )
2171    }
2172}
2173
2174impl<'a, T, S> RecordMatrix<'a, T, S>
2175where
2176    T: Numeric + Primitive,
2177    for<'t> &'t T: NumericRef<T>,
2178    S: MatrixMut<(T, Index)> + NoInteriorMutability,
2179{
2180    /**
2181     * Overwrites a RecordContainer by applying
2182     * some unary function from `T` to `T` to every element in the container.
2183     *
2184     * To compute the new records, the unary function of some input x to some
2185     * output y is needed along with its derivative with respect to its input x.
2186     */
2187    #[track_caller]
2188    pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
2189        let total = self.elements();
2190        match self.history {
2191            None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
2192            Some(history) => {
2193                let ys = unary::<T, _>(total, history, self.numbers.row_major_iter(), fx, dfx_dx);
2194                for (element, result) in self.numbers.row_major_reference_mut_iter().zip(ys) {
2195                    *element = result;
2196                }
2197                self.history = Some(history);
2198            }
2199        }
2200    }
2201
2202    /**
2203     * Overwrites the left hand side of a RecordContainer with the result of applying
2204     * some binary function from `T` to `T` to every element pair in the containers. Both
2205     * containers must have the same shape.
2206     * To compute the new records, the binary function of some inputs x and y to some
2207     * output z is needed along with its derivative with respect to its first input x and
2208     * its derivative with respect to its second input y.
2209     *
2210     * # Panics
2211     *
2212     * - If both record containers have a WengertList that are different to each other
2213     * - If the record containers have different shapes
2214     */
2215    #[track_caller]
2216    pub fn binary_left_assign<S2>(
2217        &mut self,
2218        rhs: &RecordMatrix<'a, T, S2>,
2219        fxy: impl Fn(T, T) -> T,
2220        dfxy_dx: impl Fn(T, T) -> T,
2221        dfxy_dy: impl Fn(T, T) -> T,
2222    ) where
2223        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2224    {
2225        {
2226            let left_shape = self.numbers.size();
2227            let right_shape = rhs.numbers.size();
2228            if left_shape != right_shape {
2229                panic!(
2230                    "Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
2231                    left_shape, right_shape
2232                );
2233            }
2234        }
2235        let total = self.elements();
2236        match (self.history, rhs.history) {
2237            (None, None) => {
2238                for (x, y) in self
2239                    .numbers
2240                    .row_major_reference_mut_iter()
2241                    .zip(rhs.numbers.row_major_iter())
2242                {
2243                    let (left, _) = x;
2244                    let (right, _) = y;
2245                    *x = (fxy(left.clone(), right), 0);
2246                }
2247            }
2248            (Some(history), None) => {
2249                let zs = binary_x_history::<T, _, _>(
2250                    total,
2251                    history,
2252                    self.numbers.row_major_iter(),
2253                    rhs.numbers.row_major_iter(),
2254                    fxy,
2255                    dfxy_dx,
2256                );
2257                for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2258                    *element = result;
2259                }
2260                self.history = Some(history);
2261            }
2262            (None, Some(history)) => {
2263                let zs = binary_y_history::<T, _, _>(
2264                    total,
2265                    history,
2266                    self.numbers.row_major_iter(),
2267                    rhs.numbers.row_major_iter(),
2268                    fxy,
2269                    dfxy_dy,
2270                );
2271                for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2272                    *element = result;
2273                }
2274                self.history = Some(history);
2275            }
2276            (Some(history), Some(h)) => {
2277                assert!(
2278                    record_operations::same_lists(history, h),
2279                    "Record containers must be using the same WengertList"
2280                );
2281                let zs = binary_both_history::<T, _, _>(
2282                    total,
2283                    history,
2284                    self.numbers.row_major_iter(),
2285                    rhs.numbers.row_major_iter(),
2286                    fxy,
2287                    dfxy_dx,
2288                    dfxy_dy,
2289                );
2290                for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2291                    *element = result;
2292                }
2293                self.history = Some(history);
2294            }
2295        }
2296    }
2297
2298    /**
2299     * A convenience helper function which takes the RecordContainer value and
2300     * calls [unary_assign](RecordMatrix::unary_assign()) on it, returning
2301     * the record container which now contains the result of the operation.
2302     */
2303    #[track_caller]
2304    pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
2305        self.unary_assign(fx, dfx_dx);
2306        self
2307    }
2308
2309    /**
2310     * A convenience helper function which takes the left hand side by value and
2311     * calls [binary_left_assign](RecordMatrix::binary_left_assign()) on it, returning
2312     * the left hand side which now contains the result of the operation.
2313     */
2314    #[track_caller]
2315    pub fn do_binary_left_assign<S2>(
2316        mut self,
2317        rhs: &RecordMatrix<'a, T, S2>,
2318        fxy: impl Fn(T, T) -> T,
2319        dfxy_dx: impl Fn(T, T) -> T,
2320        dfxy_dy: impl Fn(T, T) -> T,
2321    ) -> Self
2322    where
2323        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2324    {
2325        self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
2326        self
2327    }
2328
2329    /**
2330     * Updates this RecordContainer in place by applying some unary function on `Record<T>` to
2331     * `Record<T>` to every element in the container. This will fail if the function would create
2332     * records with inconsistent histories.
2333     *
2334     * When used with pure functions that can't return different histories for different inputs
2335     * unwrapping with always succeed.
2336     *
2337     * Since this updates the container in place, if Err is returned then the data in this
2338     * RecordContainer is still available but it has been corrupted - at least one of the elements
2339     * should have a different history than what it will have because the mapping function created
2340     * inconsistent histories that couldn't be represented by the container as it only stores
2341     * one.
2342     *
2343     * This API can allow you to call a generic function that operates on
2344     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2345     * during the intermediate calculations for you, without having to resort to storing the
2346     * Record types.
2347     *
2348     * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2349     * after mapping doesn't have to be the same as before, only must be the same for every
2350     * mapped element.
2351     *
2352     * You might also use this function at the end of a training loop to update all the weights
2353     * to reduce their loss.
2354     *
2355     * ```
2356     * use easy_ml::numeric::Numeric;
2357     * use easy_ml::matrices::Matrix;
2358     * use easy_ml::differentiation::{Record, RecordMatrix, WengertList};
2359     *
2360     * let history = WengertList::new();
2361     * let mut weights = RecordMatrix::variables(
2362     *     &history,
2363     *     Matrix::from(vec![vec![ 0.3, 0.2, -1.2, -0.4 ]])
2364     * );
2365     * let error = {
2366     *     // this is over-simplified for brevity, obviously in a real scenario we wouldn't have a
2367     *     // function that calculates the error like this or we wouldn't be doing machine learning
2368     *     // to fit it in the first place
2369     *     let mut loss = Record::variable(0.0, &history);
2370     *     for r in weights.iter_row_major_as_records() {
2371     *         loss = loss + r;
2372     *     }
2373     *     loss
2374     * };
2375     * let derivatives = error.derivatives();
2376     * let learning_rate = 0.1;
2377     * // update the weights to contain less error than they did before
2378     * let result = weights.map_mut(|x| x - (derivatives[&x] * learning_rate));
2379     * assert!(result.is_ok()); // we know we didn't introduce an inconsistent history just updating the weights
2380     * ```
2381     */
2382    #[track_caller]
2383    pub fn map_mut(
2384        &mut self,
2385        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
2386    ) -> Result<(), InconsistentHistory<'a, T>> {
2387        let history = self.history;
2388        let new_history =
2389            map_mut_base::<'a, T, _, _>(RowMajorReferenceMutIterator::from(self), |x| {
2390                let record = Record::from_existing(x.clone(), history);
2391                let result = fx(record);
2392                *x = (result.number, result.index);
2393                result.history
2394            })?;
2395        self.history = new_history;
2396        Ok(())
2397    }
2398
2399    /**
2400     * Updates this RecordContainer in place by applying some unary function on `Record<T>` and
2401     * each index of that position in the Record to `Record<T>` to every element in the container.
2402     * This will fail if the function would create records with inconsistent histories.
2403     *
2404     * When used with pure functions that can't return different histories for different inputs
2405     * unwrapping with always succeed.
2406     *
2407     * Since this updates the container in place, if Err is returned then the data in this
2408     * RecordContainer is still available but it has been corrupted - at least one of the elements
2409     * should have a different history than what it will have because the mapping function created
2410     * inconsistent histories that couldn't be represented by the container as it only stores
2411     * one.
2412     *
2413     * This API can allow you to call a generic function that operates on
2414     * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2415     * during the intermediate calculations for you, without having to resort to storing the
2416     * Record types.
2417     *
2418     * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2419     * after mapping doesn't have to be the same as before, only must be the same for every
2420     * mapped element.
2421     */
2422    #[track_caller]
2423    pub fn map_mut_with_index(
2424        &mut self,
2425        fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
2426    ) -> Result<(), InconsistentHistory<'a, T>> {
2427        let history = self.history;
2428        let new_history = map_mut_base::<'a, T, _, _>(
2429            RowMajorReferenceMutIterator::from(self).with_index(),
2430            |((r, c), x)| {
2431                let record = Record::from_existing(x.clone(), history);
2432                let result = fx(record, r, c);
2433                *x = (result.number, result.index);
2434                result.history
2435            },
2436        )?;
2437        self.history = new_history;
2438        Ok(())
2439    }
2440}
2441
2442impl<'a, T, S> RecordMatrix<'a, T, S>
2443where
2444    T: Numeric + Primitive,
2445    for<'t> &'t T: NumericRef<T>,
2446    S: MatrixRef<(T, Index)> + NoInteriorMutability,
2447{
2448    /**
2449     * Overwrites the right hand side of a RecordContainer with the result of applying
2450     * some binary function from `T` to `T` to every element pair in the containers. Both
2451     * containers must have the same shape.
2452     * To compute the new records, the binary function of some inputs x and y to some
2453     * output z is needed along with its derivative with respect to its first input x and
2454     * its derivative with respect to its second input y.
2455     *
2456     * # Panics
2457     *
2458     * - If both record containers have a WengertList that are different to each other
2459     * - If the record containers have different shapes
2460     */
2461    #[track_caller]
2462    pub fn binary_right_assign<S2>(
2463        &self,
2464        rhs: &mut RecordMatrix<'a, T, S2>,
2465        fxy: impl Fn(T, T) -> T,
2466        dfxy_dx: impl Fn(T, T) -> T,
2467        dfxy_dy: impl Fn(T, T) -> T,
2468    ) where
2469        S2: MatrixMut<(T, Index)> + NoInteriorMutability,
2470    {
2471        // x is lhs, y is rhs, so calling binary_left_assign on the rhs container
2472        // means we need to swap all the arguments
2473        rhs.binary_left_assign(
2474            self,
2475            |y, x| fxy(x, y),
2476            |y, x| dfxy_dy(x, y),
2477            |y, x| dfxy_dx(x, y),
2478        )
2479    }
2480
2481    /**
2482     * A convenience helper function which takes the right hand side by value and
2483     * calls [binary_right_assign](RecordMatrix::binary_right_assign()) on it, returning
2484     * the right hand side which now contains the result of the operation.
2485     */
2486    #[track_caller]
2487    pub fn do_binary_right_assign<S2>(
2488        &self,
2489        mut rhs: RecordMatrix<'a, T, S2>,
2490        fxy: impl Fn(T, T) -> T,
2491        dfxy_dx: impl Fn(T, T) -> T,
2492        dfxy_dy: impl Fn(T, T) -> T,
2493    ) -> RecordMatrix<'a, T, S2>
2494    where
2495        S2: MatrixMut<(T, Index)> + NoInteriorMutability,
2496    {
2497        self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
2498        rhs
2499    }
2500}
2501
2502// # Safety
2503//
2504// Our inner `numbers` tensor has to implement TensorRef correctly so by delegating to it
2505// without changing any indexes or introducing interior mutability, we implement TensorRef
2506// correctly as well.
2507/**
2508 * RecordTensor implements TensorRef when the source does, returning references to the tuples
2509 * of `T` and [`Index`](Index).
2510 */
2511unsafe impl<'a, T, S, const D: usize> TensorRef<(T, Index), D> for RecordTensor<'a, T, S, D>
2512where
2513    T: Primitive,
2514    S: TensorRef<(T, Index), D>,
2515{
2516    fn get_reference(&self, indexes: [usize; D]) -> Option<&(T, Index)> {
2517        self.numbers.source_ref().get_reference(indexes)
2518    }
2519
2520    fn view_shape(&self) -> [(Dimension, usize); D] {
2521        self.numbers.source_ref().view_shape()
2522    }
2523
2524    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &(T, Index) {
2525        unsafe { self.numbers.source_ref().get_reference_unchecked(indexes) }
2526    }
2527
2528    fn data_layout(&self) -> DataLayout<D> {
2529        self.numbers.source_ref().data_layout()
2530    }
2531}
2532
2533// # Safety
2534//
2535// Our inner `numbers` tensor has to implement TensorMut correctly so by delegating to it
2536// without changing any indexes or introducing interior mutability, we implement TensorMut
2537// correctly as well.
2538/**
2539 * RecordTensor implements TensorMut when the source does, returning mutable references to the
2540 * tuples of `T` and [`Index`](Index).
2541 */
2542unsafe impl<'a, T, S, const D: usize> TensorMut<(T, Index), D> for RecordTensor<'a, T, S, D>
2543where
2544    T: Primitive,
2545    S: TensorMut<(T, Index), D>,
2546{
2547    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut (T, Index)> {
2548        self.numbers.source_ref_mut().get_reference_mut(indexes)
2549    }
2550
2551    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut (T, Index) {
2552        unsafe {
2553            self.numbers
2554                .source_ref_mut()
2555                .get_reference_unchecked_mut(indexes)
2556        }
2557    }
2558}
2559
2560// # Safety
2561//
2562// Our inner `numbers` matrix has to implement MatrixRef correctly so by delegating to it
2563// without changing any indexes or introducing interior mutability, we implement MatrixRef
2564// correctly as well.
2565/**
2566 * RecordMatrix implements MatrixRef when the source does, returning references to the tuples
2567 * of `T` and [`Index`](Index).
2568 */
2569unsafe impl<'a, T, S> MatrixRef<(T, Index)> for RecordMatrix<'a, T, S>
2570where
2571    T: Primitive,
2572    S: MatrixRef<(T, Index)>,
2573{
2574    fn try_get_reference(&self, row: Row, column: Column) -> Option<&(T, Index)> {
2575        self.numbers.source_ref().try_get_reference(row, column)
2576    }
2577
2578    fn view_rows(&self) -> Row {
2579        self.numbers.source_ref().view_rows()
2580    }
2581
2582    fn view_columns(&self) -> Column {
2583        self.numbers.source_ref().view_columns()
2584    }
2585
2586    unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &(T, Index) {
2587        unsafe {
2588            self.numbers
2589                .source_ref()
2590                .get_reference_unchecked(row, column)
2591        }
2592    }
2593
2594    fn data_layout(&self) -> crate::matrices::views::DataLayout {
2595        self.numbers.source_ref().data_layout()
2596    }
2597}
2598
2599// # Safety
2600//
2601// Our inner `numbers` matrix has to implement NoInteriorMutability correctly so by delegating to
2602// it without introducing interior mutability, we implement NoInteriorMutability
2603// correctly as well.
2604/**
2605 * RecordMatrix implements NoInteriorMutability when the source does.
2606 */
2607unsafe impl<'a, T, S> NoInteriorMutability for RecordMatrix<'a, T, S>
2608where
2609    T: Primitive,
2610    S: NoInteriorMutability,
2611{
2612}
2613
2614// # Safety
2615//
2616// Our inner `numbers` matrix has to implement MatrixMut correctly so by delegating to it
2617// without changing any indexes or introducing interior mutability, we implement MatrixMut
2618// correctly as well.
2619/**
2620 * RecordMatrix implements MatrixMut when the source does, returning mutable references to the
2621 * tuples of `T` and [`Index`].
2622 */
2623unsafe impl<'a, T, S> MatrixMut<(T, Index)> for RecordMatrix<'a, T, S>
2624where
2625    T: Primitive,
2626    S: MatrixMut<(T, Index)>,
2627{
2628    fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut (T, Index)> {
2629        self.numbers
2630            .source_ref_mut()
2631            .try_get_reference_mut(row, column)
2632    }
2633
2634    unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut (T, Index) {
2635        unsafe {
2636            self.numbers
2637                .source_ref_mut()
2638                .get_reference_unchecked_mut(row, column)
2639        }
2640    }
2641}
2642
2643/**
2644 * A zero dimensional record tensor can be converted losslessly into a record.
2645 */
2646impl<'a, T, S> From<RecordTensor<'a, T, S, 0>> for Record<'a, T>
2647where
2648    T: Numeric + Primitive,
2649    S: TensorRef<(T, Index), 0>,
2650{
2651    /**
2652     * Converts the sole element in the zero dimensional record tensor into a record.
2653     */
2654    fn from(scalar: RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
2655        // Not a good way to make this zero copy and just move the data out of the scalar because
2656        // TensorRef API doesn't have by value moves of the data and using TensorOwnedIterator
2657        // requires T: Default and a dummy value (at which point a clone is probably cheaper or
2658        // basically the same?)
2659        Record::from(&scalar)
2660    }
2661}
2662
2663/**
2664 * A zero dimensional record tensor can be converted losslessly into a record.
2665 */
2666impl<'a, T, S> From<&RecordTensor<'a, T, S, 0>> for Record<'a, T>
2667where
2668    T: Numeric + Primitive,
2669    S: TensorRef<(T, Index), 0>,
2670{
2671    /**
2672     * Converts the sole element in the zero dimensional record tensor into a record.
2673     */
2674    fn from(scalar: &RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
2675        Record::from_existing(scalar.view().scalar(), scalar.history)
2676    }
2677}
2678
2679/**
2680 * A record can be converted losslessly into a zero dimensional record tensor.
2681 */
2682impl<'a, T> From<Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
2683where
2684    T: Numeric + Primitive,
2685{
2686    /**
2687     * Converts a record into a zero dimensional record tensor with the single element.
2688     */
2689    fn from(record: Record<'a, T>) -> RecordTensor<'a, T, Tensor<(T, Index), 0>, 0> {
2690        RecordTensor::from_existing(
2691            record.history,
2692            TensorView::from(Tensor::from([], vec![(record.number, record.index)])),
2693        )
2694    }
2695}
2696
2697/**
2698 * A record can be converted losslessly into a zero dimensional record tensor.
2699 */
2700impl<'a, T> From<&Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
2701where
2702    T: Numeric + Primitive,
2703{
2704    /**
2705     * Converts a record into a zero dimensional record tensor with the single element.
2706     */
2707    fn from(record: &Record<'a, T>) -> RecordTensor<'a, T, Tensor<(T, Index), 0>, 0> {
2708        RecordTensor::from_existing(
2709            record.history,
2710            TensorView::from(Tensor::from(
2711                [],
2712                vec![(record.number.clone(), record.index)],
2713            )),
2714        )
2715    }
2716}
2717
#[test]
fn matrix_multiplication_derivatives_are_the_same() {
    // Performs the same matrix multiplication in two ways:
    // - as a Tensor of individual Records on `history`
    // - as RecordTensors on `also_history`
    // It then checks that both agree on the resulting numbers, the derivatives, and the number
    // of operations recorded.
    //
    // NB: keep the statement order as it is. The derivatives computed on `history` are looked
    // up below using the indexes held by `record_tensor_a` / `record_tensor_b`, which belong
    // to `also_history`. This relies on corresponding variables receiving the same index on
    // both lists. That is presumably guaranteed here by creating a's and then b's variables in
    // the same order on each list.
    #[rustfmt::skip]
    let a = Tensor::from(
        [("r", 4), ("c", 3)],
        vec![
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
            0.0, 5.0, 2.0
        ]
    );
    let b = a.transpose(["c", "r"]);
    let history = WengertList::new();
    let also_history: WengertList<f64> = WengertList::new();
    let tensor_of_records_a = a.map(|x| Record::variable(x, &history));
    let tensor_of_records_b = b.map(|x| Record::variable(x, &history));
    let tensor_of_records_c = &tensor_of_records_a * &tensor_of_records_b;
    let record_tensor_a = RecordTensor::variables(&also_history, a);
    let record_tensor_b = RecordTensor::variables(&also_history, b);
    let record_tensor_c = &record_tensor_a * &record_tensor_b;

    // C should be calculated the same in terms of the actual number
    assert_eq!(
        tensor_of_records_c.map(|r| r.number),
        TensorView::from(&record_tensor_c).map(|(n, _)| n)
    );

    let tensor_of_records_derivatives = tensor_of_records_c.map(|r| r.derivatives());
    let tensor_of_records_a_derivatives =
        tensor_of_records_derivatives.map(|d| d.at_tensor(&record_tensor_a));
    let tensor_of_records_b_derivatives =
        tensor_of_records_derivatives.map(|d| d.at_tensor(&record_tensor_b));

    let record_tensor_derivatives = record_tensor_c.derivatives().unwrap();
    let record_tensor_a_derivatives =
        record_tensor_derivatives.map(|d| d.at_tensor(&record_tensor_a));
    let record_tensor_b_derivatives =
        record_tensor_derivatives.map(|d| d.at_tensor(&record_tensor_b));

    // Every calculated derivative should match exactly
    assert_eq!(tensor_of_records_a_derivatives, record_tensor_a_derivatives);
    assert_eq!(tensor_of_records_b_derivatives, record_tensor_b_derivatives);

    // Verify C is actually calculated correctly
    #[rustfmt::skip]
    assert_eq!(
        tensor_of_records_c.map(|r| r.number),
        Tensor::from(
            [("r", 4), ("c", 4)],
            vec![
                14.0, 32.0, 50.0, 16.0,
                32.0, 77.0, 122.0, 37.0,
                50.0, 122.0, 194.0, 58.0,
                16.0, 37.0, 58.0, 29.0
            ]
        )
    );
    #[rustfmt::skip]
    assert_eq!(
        tensor_of_records_c.map(|r| r.number),
        Tensor::from(
            [("r", 4), ("c", 4)],
            vec![
                (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
                (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
                (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
                (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),

                (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
                (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
                (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
                (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),

                (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
                (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
                (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
                (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),

                (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
                (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
                (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
                (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
            ]
        )
    );

    // Both approaches should have recorded the same number of operations. Only the lengths are
    // compared, so the operation lists are not cloned.
    let tensor_of_records_operation_count = history.operations.borrow().len();
    let record_tensor_operation_count = also_history.operations.borrow().len();
    assert_eq!(
        tensor_of_records_operation_count,
        record_tensor_operation_count
    );
}
2812
#[test]
fn matrix_view_matrix_multiplication_derivatives_are_the_same() {
    // Performs the same matrix multiplication in two ways:
    // - as a Matrix of individual Records on `history`
    // - as RecordMatrices on `also_history`
    // It then checks that both agree on the resulting numbers, the derivatives, and the number
    // of operations recorded.
    //
    // NB: keep the statement order as it is. The derivatives computed on `history` are looked
    // up below using the indexes held by `record_matrix_a` / `record_matrix_b`, which belong
    // to `also_history`. This relies on corresponding variables receiving the same index on
    // both lists. That is presumably guaranteed here by creating a's and then b's variables in
    // the same order on each list.
    #[rustfmt::skip]
    let a = Matrix::from(vec![
        vec![ 1.0, 2.0, 3.0 ],
        vec![ 4.0, 5.0, 6.0 ],
        vec![ 7.0, 8.0, 9.0 ],
        vec![ 0.0, 5.0, 2.0 ]
    ]);
    let b = a.transpose();
    let history = WengertList::new();
    let also_history: WengertList<f64> = WengertList::new();
    let matrix_of_records_a = a.map(|x| Record::variable(x, &history));
    let matrix_of_records_b = b.map(|x| Record::variable(x, &history));
    let matrix_of_records_c = &matrix_of_records_a * &matrix_of_records_b;
    let record_matrix_a = RecordMatrix::variables(&also_history, a);
    let record_matrix_b = RecordMatrix::variables(&also_history, b);
    let record_matrix_c = &record_matrix_a * &record_matrix_b;

    // C should be calculated the same in terms of the actual number
    assert_eq!(
        matrix_of_records_c.map(|r| r.number),
        MatrixView::from(&record_matrix_c).map(|(n, _)| n)
    );

    let matrix_of_records_derivatives = matrix_of_records_c.map(|r| r.derivatives());
    let matrix_of_records_a_derivatives =
        matrix_of_records_derivatives.map(|d| d.at_matrix(&record_matrix_a));
    let matrix_of_records_b_derivatives =
        matrix_of_records_derivatives.map(|d| d.at_matrix(&record_matrix_b));

    let record_matrix_derivatives = record_matrix_c.derivatives().unwrap();
    let record_matrix_a_derivatives =
        record_matrix_derivatives.map(|d| d.at_matrix(&record_matrix_a));
    let record_matrix_b_derivatives =
        record_matrix_derivatives.map(|d| d.at_matrix(&record_matrix_b));

    // Every calculated derivative should match exactly
    assert_eq!(matrix_of_records_a_derivatives, record_matrix_a_derivatives);
    assert_eq!(matrix_of_records_b_derivatives, record_matrix_b_derivatives);

    // Verify C is actually calculated correctly
    #[rustfmt::skip]
    assert_eq!(
        matrix_of_records_c.map(|r| r.number),
        Matrix::from(vec![
                vec![ 14.0, 32.0, 50.0, 16.0 ],
                vec![ 32.0, 77.0, 122.0, 37.0 ],
                vec![ 50.0, 122.0, 194.0, 58.0 ],
                vec![ 16.0, 37.0, 58.0, 29.0 ]
            ]
        )
    );
    #[rustfmt::skip]
    assert_eq!(
        matrix_of_records_c.map(|r| r.number),
        Matrix::from(vec![
                vec![
                    (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
                    (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
                    (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
                    (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),
                ],
                vec![
                    (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
                    (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
                    (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
                    (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),
                ],
                vec![
                    (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
                    (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
                    (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
                    (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),
                ],
                vec![
                    (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
                    (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
                    (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
                    (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
                ]
        ])
    );

    // Both approaches should have recorded the same number of operations. Only the lengths are
    // compared, so the operation lists are not cloned.
    let matrix_of_records_operation_count = history.operations.borrow().len();
    let record_matrix_operation_count = also_history.operations.borrow().len();
    assert_eq!(
        matrix_of_records_operation_count,
        record_matrix_operation_count
    );
}