easy_ml/differentiation/container_record/mod.rs
1use crate::differentiation::functions::{Division, FunctionDerivative, Multiplication};
2use crate::differentiation::iterators::{
3 AsRecords, InconsistentHistory, InvalidRecordIteratorError,
4};
5use crate::differentiation::record_operations;
6use crate::differentiation::{Derivatives, Index, Primitive, Record, WengertList};
7use crate::matrices::iterators::{
8 ColumnMajorIterator, RowMajorIterator, RowMajorOwnedIterator, RowMajorReferenceMutIterator,
9};
10use crate::matrices::views::{MatrixMut, MatrixRef, MatrixView, NoInteriorMutability};
11use crate::matrices::{Column, Matrix, Row};
12use crate::numeric::{Numeric, NumericRef};
13use crate::tensors::indexing::{
14 TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceMutIterator,
15};
16use crate::tensors::views::{DataLayout, TensorMut, TensorRef, TensorRename, TensorView};
17use crate::tensors::{Dimension, Tensor};
18
19mod container_operations;
20pub mod iterators;
21
/**
 * A pluralisation of [Record](crate::differentiation::Record) that groups together a
 * **s**ource of numbers of type T and stores the WengertList only once.
 *
 * Typically you would refer to one of the type aliases to disambiguate the type of `S` and
 * use more succinct generics: [RecordMatrix](RecordMatrix), [RecordTensor](RecordTensor).
 *
 * For both Matrix and Tensor source types, the containers implement [`+`](std::ops::Add) and
 * [`-`](std::ops::Sub) and have the methods `elementwise_multiply` and `elementwise_divide`.
 * In all cases the containers must have the same size for the operation and will panic if
 * mismatched.
 * [`*`](std::ops::Mul) is also implemented for 2 dimensional tensors and matrices as matrix
 * multiplication.
 *
 * For convenience, as with Trace and Record, many unary operations including
 * [Cos](crate::numeric::extra::Cos), [Exp](crate::numeric::extra::Exp),
 * [Ln](crate::numeric::extra::Ln), [Neg](std::ops::Neg), [Pow](crate::numeric::extra::Pow),
 * [Sin](crate::numeric::extra::Sin), and [Sqrt](crate::numeric::extra::Sqrt) are implemented as
 * well, applying the unary function to each element in the container.
 *
 * `+`, `-`, `*` and `/` operations with a RecordContainer and a scalar are also implemented,
 * treating the right hand side scalar as a constant. These are also unary functions in terms of
 * the derivative tracking, for example `X + 5` applies the function `+5` to each element in
 * `X`. Due to the orphan rule, these operations cannot be implemented with a standard library
 * scalar on the left hand side, see
 * [SwappedOperations](crate::differentiation::record_operations::SwappedOperations).
 *
 * Both [RecordMatrix](RecordMatrix) and [RecordTensor](RecordTensor) implement
 * [MatrixRef](MatrixRef) and [TensorRef](TensorRef) respectively, which provide read and write
 * access to the underlying numbers and indexes into the WengertList. These APIs along with the
 * `from_existing` constructors for RecordMatrix, RecordTensor, and Record allow arbitrary
 * manipulation of specific elements in a record container if needed. However, any arithmetic
 * performed on the raw data won't be tracked on the WengertList and overwriting data within a
 * record container already tracked on the WengertList could result in nonsense. These APIs exist
 * for easy read access to check the data in a record container and for read/write access when
 * manipulating the shape of a record container, and are designed to be used only for moving data
 * around - you should put it back unchanged in a RecordContainer or Record before doing further
 * arithmetic that needs to be tracked on the WengertList.
 *
 * If you just want to manipulate the data in record containers as if they were Records you can
 * use the iterator APIs of [AsRecords](AsRecords) instead and collect them back into containers
 * when you're done, or to manipulate a single container, the [map](RecordContainer::map) and
 * [map_mut](RecordContainer::map_mut) methods.
 */
#[derive(Debug)]
pub struct RecordContainer<'a, T: Primitive, S, const D: usize> {
    // Opted to store the indexes alongside each number (T, Index) for a number of reasons, the
    // main factor being it makes implementing TensorRef feasible so can utilise the range of
    // existing APIs for Tensor manipulation. It's theoretically possible to only store the first
    // index and calculate the rest, since most of the time all indexes are ascending entries in
    // the WengertList but this would also massively complicate the implementation, especially for
    // handling non row-major operations such as matrix multiplication. It's also not super clear
    // that this would be more efficient because it turns reads into more arithmetic rather
    // than avoiding any work. Just lifting the WengertList out of the tensor should have
    // meaningful improvements to cache line efficiency, and failing that still disallows
    // very questionable states from existing.
    //
    // The source of `(T, Index)` pairs: each number together with the index of its entry in
    // the WengertList.
    numbers: S,
    // The WengertList shared by every element in `numbers`. The `constants` constructors set
    // this to None and the `variables` constructors set it to Some.
    history: Option<&'a WengertList<T>>,
}
80
/**
 * Alias for succinctly referring to RecordContainers backed by a matrix.
 *
 * The source `S` is wrapped in a [MatrixView](MatrixView) of `(T, Index)` pairs and the
 * dimensionality of the container is fixed to 2.
 */
pub type RecordMatrix<'a, T, S> = RecordContainer<'a, T, MatrixView<(T, Index), S>, 2>;
85
/**
 * Alias for succinctly referring to RecordContainers backed by a tensor.
 *
 * The source `S` is wrapped in a [TensorView](TensorView) of `(T, Index)` pairs, and the
 * container shares the same dimensionality `D` as that view.
 */
pub type RecordTensor<'a, T, S, const D: usize> =
    RecordContainer<'a, T, TensorView<(T, Index), S, D>, D>;
91
92fn calculate_incrementing_indexes(starting_index: usize, total: usize) -> Vec<Index> {
93 let mut indexes = vec![0; total];
94 for (i, x) in indexes.iter_mut().enumerate() {
95 *x = starting_index + i;
96 }
97 indexes
98}
99
100impl<'a, T, const D: usize> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
101where
102 T: Numeric + Primitive,
103{
104 /**
105 * Creates multiple untracked Records which have no backing WengertList.
106 *
107 * This is provided for using constants along with Records in operations.
108 *
109 * For example with `Y = X + 4` the computation graph could be conceived as many
110 * `Y[i,j]` nodes with parent nodes of `X[i,j]` and 4 combined with the operation `+`.
111 * However there is no need to record the derivatives of a constant, so
112 * instead the computation graph can be conceived as `Y[i,j]` nodes each with a single
113 * parent node of `X[i,j]` and the unary operation of `+4`.
114 */
115 pub fn constants<S>(c: S) -> Self
116 where
117 S: TensorMut<T, D>,
118 {
119 RecordContainer {
120 numbers: TensorView::from(Tensor::from(
121 c.view_shape(),
122 TensorOwnedIterator::from_numeric(c)
123 .map(|x| (x, 0))
124 .collect(),
125 )),
126 history: None,
127 }
128 }
129
130 /**
131 * Creates multiple records backed by the provided WengertList.
132 *
133 * The records cannot live longer than the WengertList, hence
134 * the following example does not compile
135 *
136 * ```compile_fail
137 * use easy_ml::differentiation::RecordTensor;
138 * use easy_ml::differentiation::WengertList;
139 * use easy_ml::tensors::Tensor;
140 * let record = {
141 * let list = WengertList::new();
142 * RecordTensor::variables(
143 * &list,
144 * Tensor::from([("r", 2), ("c", 2)], vec![ 1.0, 2.0, 3.0, 4.0 ])
145 * )
146 * }; // list no longer in scope
147 * ```
148 */
149 pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
150 where
151 S: TensorMut<T, D>,
152 {
153 let total = crate::tensors::dimensions::elements(&x.view_shape());
154 let starting_index = history.append_nullary_repeating(total);
155 RecordContainer {
156 numbers: TensorView::from(Tensor::from(
157 x.view_shape(),
158 TensorOwnedIterator::from_numeric(x)
159 .zip(calculate_incrementing_indexes(starting_index, total))
160 .collect(),
161 )),
162 history: Some(history),
163 }
164 }
165}
166
167impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
168where
169 T: Numeric + Primitive,
170 S: TensorRef<(T, Index), D>,
171{
172 /**
173 * Returns the number of elements stored by this container's source.
174 *
175 * For a 2 x 3 Tensor, this would return 6, and for a 2 x 3 x 4 Tensor this would return 24
176 * and so on.
177 *
178 * see also [dimensions::elements](crate::tensors::dimensions::elements)
179 */
180 pub fn elements(&self) -> usize {
181 crate::tensors::dimensions::elements(&self.numbers.shape())
182 }
183
184 /**
185 * The shape of this container's source.
186 */
187 pub fn shape(&self) -> [(Dimension, usize); D] {
188 self.numbers.shape()
189 }
190
191 /**
192 * Creates a container from constants/variables directly, most likely obtained by getting a
193 * tensor view of an existing container. **The inputs are not checked for validity**. It is
194 * possible to pass in the wrong Wengert list here or even numbers with indexes that aren't
195 * tracked on the WengertList.
196 *
197 * It is recommended to use this constructor only in conjunction with
198 * resizing or masking an existing container and not for creating new variables. Any variables
199 * created outside of `RecordContainer::variables` would have to be manually added to the
200 * correct Wengert list, and any arithmetic operations would also need tracking correctly.
201 *
202 * ```
203 * use easy_ml::differentiation::RecordTensor;
204 * use easy_ml::differentiation::WengertList;
205 * use easy_ml::tensors::Tensor;
206 * use easy_ml::tensors::views::{TensorView, TensorRange};
207 *
208 * let list = WengertList::new();
209 * let x = RecordTensor::variables(
210 * &list,
211 * Tensor::from_fn([("x", 2), ("y", 2)], |[r, c]| ((r + 3) * (c + 2)) as f64)
212 * );
213 * // oh no wrong shape!
214 * let fixed = TensorView::from(TensorRange::from(x, [("y", 0..1)]).unwrap()); // we can unwrap here because we know the range is valid
215 * let x = RecordTensor::from_existing(Some(&list), fixed);
216 * assert_eq!([("x", 2), ("y", 1)], x.shape());
217 * ```
218 */
219 pub fn from_existing(
220 history: Option<&'a WengertList<T>>,
221 numbers: TensorView<(T, Index), S, D>,
222 ) -> Self {
223 RecordContainer { numbers, history }
224 }
225
226 /**
227 * Returns a record tensor with the dimension names of the shape renamed to the provided
228 * dimensions. The data of this container and the dimension lengths and order remain unchanged.
229 *
230 * This is a shorthand for constructing the RecordTensor via manipulating this TensorView. See
231 * [`RecordTensor::from_existing`](RecordTensor::from_existing).
232 *
233 * # Panics
234 *
235 * If a dimension name is not unique
236 *
237 * ```
238 * use easy_ml::differentiation::RecordTensor;
239 * use easy_ml::differentiation::WengertList;
240 * use easy_ml::tensors::Tensor;
241 * use easy_ml::tensors::views::{TensorView, TensorRename};
242 *
243 * let list = WengertList::new();
244 * let x = RecordTensor::variables(
245 * &list,
246 * Tensor::from_fn([("x", 2), ("y", 2)], |[r, c]| ((r + 3) * (c + 2)) as f64)
247 * );
248 * // oh no wrong dimension names!
249 * let x = x.rename_view(["a", "b"]);
250 * assert_eq!([("a", 2), ("b", 2)], x.shape());
251 * ```
252 */
253 #[track_caller]
254 pub fn rename_view(
255 self,
256 dimensions: [Dimension; D],
257 ) -> RecordTensor<'a, T, TensorRename<(T, Index), S, D>, D> {
258 RecordTensor::from_existing(
259 self.history,
260 TensorView::from(TensorRename::from(self.numbers.source(), dimensions)),
261 )
262 }
263
264 /**
265 * Returns a TensorView of this record container, both the `T` for each record element and
266 * also the index for that record's entry in the WengertList. These can be parsed back into
267 * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
268 * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
269 * numerical operations on the data.
270 */
271 pub fn view(&self) -> TensorView<(T, Index), &RecordTensor<'a, T, S, D>, D> {
272 TensorView::from(self)
273 }
274
275 /**
276 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
277 * to read values from this record container, both the `T` for each record element and
278 * also the index for that record's entry in the WengertList. These can be parsed back into
279 * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
280 * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
281 * numerical operations on the data.
282 *
283 * # Panics
284 *
285 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
286 */
287 #[track_caller]
288 pub fn index_by(
289 &self,
290 dimensions: [Dimension; D],
291 ) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
292 TensorAccess::from(self, dimensions)
293 }
294
295 /**
296 * Creates a TensorAccess which will index into the dimensions this record was created with
297 * in the same order as they were provided, both the `T` for each record element and
298 * also the index for that record's entry in the WengertList. These can be parsed back into
299 * a RecordTensor with [`from_existing`](RecordTensor::from_existing) or individually into
300 * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
301 * numerical operations on the data.
302 *
303 * See [TensorAccess::from_source_order], [get_as_record](TensorAccess::get_as_record).
304 *
305 * ```
306 * use easy_ml::differentiation::RecordTensor;
307 * use easy_ml::differentiation::WengertList;
308 * use easy_ml::tensors::Tensor;
309 *
310 * let list = WengertList::new();
311 * let X = RecordTensor::variables(
312 * &list,
313 * Tensor::from([("a", 3)], vec![ 3.0, 4.0, 5.0 ])
314 * );
315 * let x = X.index().get_as_record([0]);
316 * assert_eq!(x.number, 3.0);
317 * ```
318 */
319 pub fn index(&self) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
320 TensorAccess::from_source_order(self)
321 }
322
323 /**
324 * Returns an iterator over this record tensor as [Record]s instead of the raw `(T, Index)`
325 * data. After manipulating the iterator it can be collected back into a RecordTensor with
326 * [RecordTensor::from_iter](RecordTensor::from_iter).
327 *
328 * This is a shorthand for `AsRecords::from(tensor.history(), TensorIterator::from(&tensor))`
329 */
330 #[allow(clippy::type_complexity)]
331 pub fn iter_as_records<'b>(
332 &'b self,
333 ) -> AsRecords<'a, TensorIterator<'b, (T, Index), RecordTensor<'a, T, S, D>, D>, T> {
334 AsRecords::from_tensor(self)
335 }
336}
337
338impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
339where
340 T: Numeric + Primitive,
341 S: TensorMut<(T, Index), D>,
342{
343 /**
344 * Resets all of the records to place them back on the WengertList, for use
345 * in performing another derivation after clearing the WengertList.
346 *
347 * This is also a preferred shorthand for `map_mut(Record::do_reset)` that can't fail
348 */
349 pub fn reset(&mut self) {
350 match self.history {
351 None => (), // noop
352 Some(history) => {
353 let total = self.elements();
354 let starting_index = history.append_nullary_repeating(total);
355 for (x, i) in self
356 .numbers
357 .iter_reference_mut()
358 .zip(calculate_incrementing_indexes(starting_index, total))
359 {
360 let (_, old_index) = x;
361 *old_index = i;
362 }
363 }
364 };
365 }
366
367 /**
368 * A convenience helper function which takes a RecordContainer by value and
369 * calls [reset](RecordTensor::reset()) on it.
370 */
371 pub fn do_reset(mut x: Self) -> Self {
372 x.reset();
373 x
374 }
375}
376
377impl<'a, T> RecordMatrix<'a, T, Matrix<(T, Index)>>
378where
379 T: Numeric + Primitive,
380{
381 /**
382 * Creates multiple untracked Records which have no backing WengertList.
383 *
384 * This is provided for using constants along with Records in operations.
385 *
386 * For example with `Y = X + 4` the computation graph could be conceived as many
387 * `Y[i,j]` nodes with parent nodes of `X[i,j]` and 4 combined with the operation `+`.
388 * However there is no need to record the derivatives of a constant, so
389 * instead the computation graph can be conceived as `Y[i,j]` nodes each with a single
390 * parent node of `X[i,j]` and the unary operation of `+4`.
391 */
392 pub fn constants<S>(c: S) -> Self
393 where
394 S: MatrixMut<T> + NoInteriorMutability,
395 {
396 RecordContainer {
397 numbers: MatrixView::from(Matrix::from_flat_row_major(
398 (c.view_rows(), c.view_columns()),
399 RowMajorOwnedIterator::from_numeric(c)
400 .map(|x| (x, 0))
401 .collect(),
402 )),
403 history: None,
404 }
405 }
406
407 /**
408 * Creates multiple records backed by the provided WengertList.
409 *
410 * The records cannot live longer than the WengertList, hence
411 * the following example does not compile
412 *
413 * ```compile_fail
414 * use easy_ml::differentiation::RecordMatrix;
415 * use easy_ml::differentiation::WengertList;
416 * use easy_ml::matrices::Matrix;
417 * let record = {
418 * let list = WengertList::new();
419 * RecordMatrix::variables(
420 * &list,
421 * Matrix::from(vec![vec![1.0, 2.0], vec![3.0, 4.0]])
422 * )
423 * }; // list no longer in scope
424 * ```
425 */
426 pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
427 where
428 S: MatrixMut<T> + NoInteriorMutability,
429 {
430 let total = x.view_rows() * x.view_columns();
431 let starting_index = history.append_nullary_repeating(total);
432 RecordContainer {
433 numbers: MatrixView::from(Matrix::from_flat_row_major(
434 (x.view_rows(), x.view_columns()),
435 RowMajorOwnedIterator::from_numeric(x)
436 .zip(calculate_incrementing_indexes(starting_index, total))
437 .collect(),
438 )),
439 history: Some(history),
440 }
441 }
442}
443
444impl<'a, T, S> RecordMatrix<'a, T, S>
445where
446 T: Numeric + Primitive,
447 S: MatrixRef<(T, Index)> + NoInteriorMutability,
448{
449 /**
450 * Returns the number of elements stored by this container's source.
451 *
452 * For a 2 x 3 Matrix, this would return 6, and for a 3 x 4 Matrix this would return 12
453 * and so on.
454 */
455 pub fn elements(&self) -> usize {
456 self.numbers.rows() * self.numbers.columns()
457 }
458
459 /**
460 * Returns the dimensionality of this matrix container in Row, Column format
461 */
462 pub fn size(&self) -> (Row, Column) {
463 self.numbers.size()
464 }
465
466 /**
467 * Gets the number of rows visible to this matrix container.
468 */
469 pub fn rows(&self) -> Row {
470 self.numbers.rows()
471 }
472
473 /**
474 * Gets the number of columns visible to this matrix container.
475 */
476 pub fn columns(&self) -> Column {
477 self.numbers.columns()
478 }
479
480 /**
481 * Creates a container from constants/variables directly, most likely obtained by getting a
482 * matrix view of an existing container. **The inputs are not checked for validity**. It is
483 * possible to pass in the wrong Wengert list here or even numbers with indexes that aren't
484 * tracked on the WengertList.
485 *
486 * It is recommended to use this constructor only in conjunction with
487 * resizing or masking an existing container and not for creating new variables. Any variables
488 * created outside of `RecordContainer::variables` would have to be manually added to the
489 * correct Wengert list, and any arithmetic operations would also need tracking correctly.
490 *
491 * ```
492 * use easy_ml::differentiation::RecordMatrix;
493 * use easy_ml::differentiation::WengertList;
494 * use easy_ml::matrices::Matrix;
495 * use easy_ml::matrices::views::{MatrixView, MatrixRange};
496 *
497 * let list = WengertList::new();
498 * let x = RecordMatrix::variables(
499 * &list,
500 * Matrix::from_fn((2, 2), |(r, c)| ((r + 3) * (c + 2)) as f64)
501 * );
502 * // oh no wrong shape!
503 * let fixed = MatrixView::from(MatrixRange::from(x, 0..2, 0..1));
504 * let x = RecordMatrix::from_existing(Some(&list), fixed);
505 * assert_eq!((2, 1), x.size());
506 * ```
507 */
508 pub fn from_existing(
509 history: Option<&'a WengertList<T>>,
510 numbers: MatrixView<(T, Index), S>,
511 ) -> Self {
512 RecordContainer { numbers, history }
513 }
514
515 /**
516 * Returns a MatrixView of this record container, both the `T` for each record element and
517 * also the index for that record's entry in the WengertList. These can be parsed back into
518 * a RecordMatrix with [`from_existing`](RecordMatrix::from_existing) or individually into
519 * [Record](Record)s with [`Record::from_existing`](Record::from_existing) to continue tracking
520 * numerical operations on the data.
521 */
522 pub fn view(&self) -> MatrixView<(T, Index), &RecordMatrix<'a, T, S>> {
523 MatrixView::from(self)
524 }
525
526 /**
527 * Returns an iterator over this record matrix as [Record]s instead of the raw `(T, Index)`
528 * data. After manipulating the iterator it can be collected back into a RecordMatrix with
529 * [RecordMatrix::from_iter](RecordMatrix::from_iter).
530 *
531 * This is a shorthand for `AsRecords::from(matrix.history(), RowMajorIterator::from(&matrix))`
532 */
533 #[allow(clippy::type_complexity)]
534 pub fn iter_row_major_as_records<'b>(
535 &'b self,
536 ) -> AsRecords<'a, RowMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
537 AsRecords::from_matrix_row_major(self)
538 }
539
540 /**
541 * Returns an iterator over this record matrix as [Record]s instead of the raw `(T, Index)`
542 * data. After manipulating the iterator it can be collected back into a RecordMatrix with
543 * [RecordMatrix::from_iter](RecordMatrix::from_iter).
544 *
545 * This is a shorthand for
546 * `AsRecords::from(matrix.history(), ColumnMajorIterator::from(&matrix))`
547 */
548 #[allow(clippy::type_complexity)]
549 pub fn iter_column_major_as_records<'b>(
550 &'b self,
551 ) -> AsRecords<'a, ColumnMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
552 AsRecords::from_matrix_column_major(self)
553 }
554
555 /**
556 * Returns a copy of the data at the index as a Record. If you need to access all the data
557 * as records instead of just a specific index you should probably use one of the iterator
558 * APIs instead.
559 *
560 * See also: [iter_row_major_as_records](RecordMatrix::iter_row_major_as_records),
561 * [iter_column_major_as_records](RecordMatrix::iter_column_major_as_records)
562 *
563 * # Panics
564 *
565 * If the index is out of range.
566 *
567 * For a non panicking API see [try_get_as_record](RecordMatrix::try_get_as_record)
568 */
569 #[track_caller]
570 pub fn get_as_record(&self, row: Row, column: Column) -> Record<'a, T> {
571 Record::from_existing(self.numbers.get(row, column), self.history)
572 }
573
574 /**
575 * Returns a copy of the data at the index as a Record, or None if the index is
576 * out of range. If you need to access all the data as records instead of just a specific
577 * index you should probably use one of the iterator APIs instead.
578 *
579 * See also: [iter_row_major_as_records](RecordMatrix::iter_row_major_as_records),
580 * [iter_column_major_as_records](RecordMatrix::iter_column_major_as_records)
581 */
582 pub fn try_get_as_record(&self, row: Row, column: Column) -> Option<Record<'a, T>> {
583 self.numbers
584 .try_get_reference(row, column)
585 .map(|r| Record::from_existing(r.clone(), self.history))
586 }
587}
588
589impl<'a, T, S> RecordMatrix<'a, T, S>
590where
591 T: Numeric + Primitive,
592 S: MatrixMut<(T, Index)> + NoInteriorMutability,
593{
594 /**
595 * Resets all of the records to place them back on the WengertList, for use
596 * in performing another derivation after clearing the WengertList.
597 *
598 * This is also a preferred shorthand for `map_mut(Record::do_reset)` that can't fail
599 */
600 pub fn reset(&mut self) {
601 match self.history {
602 None => (), // noop
603 Some(history) => {
604 let total = self.elements();
605 let starting_index = history.append_nullary_repeating(total);
606 for (x, i) in self
607 .numbers
608 .row_major_reference_mut_iter()
609 .zip(calculate_incrementing_indexes(starting_index, total))
610 {
611 let (_, old_index) = x;
612 *old_index = i;
613 }
614 }
615 };
616 }
617
618 /**
619 * A convenience helper function which takes a RecordContainer by value and
620 * calls [reset](RecordMatrix::reset()) on it.
621 */
622 pub fn do_reset(mut x: Self) -> Self {
623 x.reset();
624 x
625 }
626}
627
impl<'a, T, S, const D: usize> RecordContainer<'a, T, S, D>
where
    T: Primitive,
{
    /**
     * Gets the WengertList these records are backed by if variables, and [None](None) if constants.
     *
     * The returned reference is a copy of the one stored in this container, so it has the
     * lifetime `'a` of the WengertList rather than the lifetime of the borrow of `self`.
     */
    pub fn history(&self) -> Option<&'a WengertList<T>> {
        self.history
    }
}
639
640/// Returns the vec of indexes and vec of ys for Y = unary(X), not checking but assuming that the
641/// length of the iterator matches the total.
642fn unary<'a, T, I>(
643 total: usize,
644 history: &WengertList<T>,
645 records: I,
646 fx: impl Fn(T) -> T,
647 dfx_dx: impl Fn(T) -> T,
648) -> Vec<(T, usize)>
649where
650 I: Iterator<Item = (T, Index)>,
651 T: Numeric + Primitive,
652 for<'t> &'t T: NumericRef<T>,
653{
654 let mut ys = vec![(T::zero(), 0); total];
655 history.borrow(|history| {
656 // shadow the name so we can't accidentally try to use history while holding
657 // the borrow
658 // use enumerate not with_index because we need the 1D index for indexing
659 // indexes
660 for (i, (x, parent)) in records.enumerate() {
661 let y = fx(x.clone());
662 let derivative = dfx_dx(x);
663 let new_index = history.append_unary(parent, derivative);
664 ys[i] = (y, new_index)
665 }
666 }); // drop borrow on history
667 ys
668}
669
670/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), not checking but assuming that
671/// the length of the iterators match the total. Also assumes both inputs have the same shape
672fn binary_both_history<'a, T, I1, I2>(
673 total: usize,
674 history: &WengertList<T>,
675 x_records: I1,
676 y_records: I2,
677 fxy: impl Fn(T, T) -> T,
678 dfxy_dx: impl Fn(T, T) -> T,
679 dfxy_dy: impl Fn(T, T) -> T,
680) -> Vec<(T, usize)>
681where
682 I1: Iterator<Item = (T, Index)>,
683 I2: Iterator<Item = (T, Index)>,
684 T: Numeric + Primitive,
685 for<'t> &'t T: NumericRef<T>,
686{
687 let mut zs = vec![(T::zero(), 0); total];
688 history.borrow(|history| {
689 // shadow the name so we can't accidentally try to use history while holding
690 // the borrow
691 // use enumerate not with_index because we need the 1D index for indexing
692 // indexes
693 for (i, ((x, parent1), (y, parent2))) in (x_records.zip(y_records)).enumerate() {
694 let z = fxy(x.clone(), y.clone());
695 let derivative1 = dfxy_dx(x.clone(), y.clone());
696 let derivative2 = dfxy_dy(x, y);
697 let new_index = history.append_binary(parent1, derivative1, parent2, derivative2);
698 zs[i] = (z, new_index);
699 }
700 }); // drop borrow on history
701 zs
702}
703
704/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), as with binary_both_history,
705/// but only tracking the derivatives for X, not Y.
706fn binary_x_history<'a, T, I1, I2>(
707 total: usize,
708 history: &WengertList<T>,
709 x_records: I1,
710 y_records: I2,
711 fxy: impl Fn(T, T) -> T,
712 dfxy_dx: impl Fn(T, T) -> T,
713) -> Vec<(T, usize)>
714where
715 I1: Iterator<Item = (T, Index)>,
716 I2: Iterator<Item = (T, Index)>,
717 T: Numeric + Primitive,
718 for<'t> &'t T: NumericRef<T>,
719{
720 let mut zs = vec![(T::zero(), 0); total];
721 history.borrow(|history| {
722 // shadow the name so we can't accidentally try to use history while holding
723 // the borrow
724 // use enumerate not with_index because we need the 1D index for indexing
725 // indexes
726 for (i, ((x, parent1), (y, _))) in (x_records.zip(y_records)).enumerate() {
727 let z = fxy(x.clone(), y.clone());
728 // if rhs didn't have a history, don't track that derivative
729 let derivative1 = dfxy_dx(x, y);
730 let new_index = history.append_unary(parent1, derivative1);
731 zs[i] = (z, new_index);
732 }
733 }); // drop borrow on history
734 zs
735}
736
737/// Returns the vec of indexes and vec of zs for Z = binary(X, Y), as with binary_both_history,
738/// but only tracking the derivatives for Y, not X.
739fn binary_y_history<'a, T, I1, I2>(
740 total: usize,
741 history: &WengertList<T>,
742 x_records: I1,
743 y_records: I2,
744 fxy: impl Fn(T, T) -> T,
745 dfxy_dy: impl Fn(T, T) -> T,
746) -> Vec<(T, usize)>
747where
748 I1: Iterator<Item = (T, Index)>,
749 I2: Iterator<Item = (T, Index)>,
750 T: Numeric + Primitive,
751 for<'t> &'t T: NumericRef<T>,
752{
753 let mut zs = vec![(T::zero(), 0); total];
754 history.borrow(|history| {
755 // shadow the name so we can't accidentally try to use history while holding
756 // the borrow
757 // use enumerate not with_index because we need the 1D index for indexing
758 // indexes
759 for (i, ((x, _), (y, parent2))) in (x_records.zip(y_records)).enumerate() {
760 let z = fxy(x.clone(), y.clone());
761 // if self didn't have a history, don't track that derivative
762 let derivative2 = dfxy_dy(x, y);
763 let new_index = history.append_unary(parent2, derivative2);
764 zs[i] = (z, new_index);
765 }
766 }); // drop borrow on history
767 zs
768}
769
770impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
771where
772 T: Numeric + Primitive,
773 for<'t> &'t T: NumericRef<T>,
774 S: TensorRef<(T, Index), D>,
775{
776 /**
777 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
778 * some unary function from `T` to `T` to every element in the container.
779 *
780 * To compute the new records, the unary function of some input x to some
781 * output y is needed along with its derivative with respect to its input x.
782 *
783 * For example, tanh is a commonly used activation function, but the Real trait
784 * does not include this operation and Record has no operations for it specifically.
785 * However, you can use this function to compute the tanh for a record container like so:
786 *
787 * ```
788 * use easy_ml::differentiation::{RecordTensor, WengertList};
789 * use easy_ml::tensors::Tensor;
790 * let list = WengertList::new();
791 * let X = RecordTensor::variables(
792 * &list,
793 * Tensor::from_fn(
794 * [("rows", 2), ("columns", 2)],
795 * |[r, c]| 0.15 * ((1 + r + c) as f32)
796 * )
797 * );
798 * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
799 * // 1 / (cosh(x) * cosh(x))
800 * let Y = X.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
801 *
802 * // we can unwrap here because we know Y contains variables not constants
803 * let derivatives = Y.derivatives().unwrap();
804 * let derivatives_indexing = derivatives.index_by(["rows", "columns"]);
805 * assert_eq!(
806 * derivatives_indexing.get_ref([0, 0]).at_tensor(&X),
807 * Tensor::from(
808 * [("rows", 2), ("columns", 2)],
809 * // [0, 0] element in Y only had the one input variable [0, 0] in X
810 * vec![
811 * 0.9778332, 0.0,
812 * 0.0, 0.0
813 * ]
814 * ),
815 * );
816 * assert_eq!(
817 * derivatives_indexing.get_ref([0, 1]).at_tensor(&X),
818 * Tensor::from(
819 * [("rows", 2), ("columns", 2)],
820 * vec![
821 * 0.0, 0.915137,
822 * 0.0, 0.0
823 * ]
824 * ),
825 * );
826 * assert_eq!(
827 * // [0, 1] and [1, 0] elements in X had the same starting value so end up with the same
828 * // derivative for their corresponding input variable in X
829 * derivatives_indexing.get_ref([0, 1]).at_tensor(&X).index().get([0, 1]),
830 * derivatives_indexing.get_ref([1, 0]).at_tensor(&X).index().get([1, 0]),
831 * );
832 * assert_eq!(
833 * derivatives_indexing.get_ref([1, 1]).at_tensor(&X),
834 * Tensor::from(
835 * [("rows", 2), ("columns", 2)],
836 * vec![
837 * 0.0, 0.0,
838 * 0.0, 0.8220013
839 * ]
840 * ),
841 * );
842 * ```
843 */
844 #[track_caller]
845 pub fn unary(
846 &self,
847 fx: impl Fn(T) -> T,
848 dfx_dx: impl Fn(T) -> T,
849 ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D> {
850 let total = self.elements();
851 match self.history {
852 None => RecordTensor::constants(self.numbers.map(|(x, _)| fx(x))),
853 Some(history) => {
854 let ys = unary::<T, _>(total, history, self.numbers.iter(), fx, dfx_dx);
855 RecordContainer {
856 numbers: self.numbers.new_with_same_shape(ys),
857 history: Some(history),
858 }
859 }
860 }
861 }
862
863 /**
864 * Creates a new RecordContainer from two RecordContainers by applying
865 * some binary function from `T` to `T` to every element pair in the containers. Both
866 * containers must have the same shape.
867 *
868 * To compute the new records, the binary function of some inputs x and y to some
869 * output z is needed along with its derivative with respect to its first input x and
870 * its derivative with respect to its second input y.
871 *
872 * For example, atan2 takes two arguments, but the Real trait
873 * does not include this operation and Record has no operations for it specifically.
874 * However, you can use this function to compute the atan2 for two record containers like so:
875 *
876 * ```
877 * use easy_ml::differentiation::{RecordTensor, WengertList};
878 * use easy_ml::tensors::Tensor;
879 * let list = WengertList::new();
880 * let X = RecordTensor::variables(
881 * &list,
882 * Tensor::from_fn(
883 * [("rows", 2), ("columns", 2)],
884 * |[r, c]| ((1 + r + c) as f32)
885 * )
886 * );
887 * let Y = RecordTensor::variables(
888 * &list,
889 * Tensor::from_fn(
890 * [("rows", 2), ("columns", 2)],
891 * |[r, c]| ((1 + r + c) as f32)
892 * )
893 * );
894 * // the derivative of atan2 with respect to x is y/(x*x + y*y)
895 * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
896 * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
897 * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
898 * let Z = X.binary(&Y,
899 * |x, y| x.atan2(y),
900 * |x, y| y/((x*x) + (y*y)),
901 * |x, y| -x/((x*x) + (y*y))
902 * );
903 *
904 *
905 * // we can unwrap here because we know Z contains variables not constants
906 * let derivatives = Z.derivatives().unwrap();
907 * // Just as in the unary example, only one pair of the four inputs in X and Y influence the
908 * // outputs in Z, so we have a lot of 0.0 derivatives, and the inputs in [0, 1] and [1, 0]
909 * // are identical so we see the same derivative.
910 * let dZ_dX = derivatives.map(|d| d.at_tensor(&X));
911 * assert_eq!(
912 * dZ_dX,
913 * Tensor::from([("rows", 2), ("columns", 2)], vec![
914 * Tensor::from([("rows", 2), ("columns", 2)], vec![
915 * 0.5, 0.0,
916 * 0.0, 0.0
917 * ]),
918 * Tensor::from([("rows", 2), ("columns", 2)], vec![
919 * 0.0, 0.25,
920 * 0.0, 0.0
921 * ]),
922 * Tensor::from([("rows", 2), ("columns", 2)], vec![
923 * 0.0, 0.0,
924 * 0.25, 0.0
925 * ]),
926 * Tensor::from([("rows", 2), ("columns", 2)], vec![
927 * 0.0, 0.0,
928 * 0.0, 0.16666667
929 * ])
930 * ])
931 * );
932 * let dZ_dY = derivatives.map(|d| d.at_tensor(&Y));
933 * assert_eq!(
934 * dZ_dY,
935 * Tensor::from([("rows", 2), ("columns", 2)], vec![
936 * Tensor::from([("rows", 2), ("columns", 2)], vec![
937 * -0.5, 0.0,
938 * 0.0, 0.0
939 * ]),
940 * Tensor::from([("rows", 2), ("columns", 2)], vec![
941 * 0.0, -0.25,
942 * 0.0, 0.0
943 * ]),
944 * Tensor::from([("rows", 2), ("columns", 2)], vec![
945 * 0.0, 0.0,
946 * -0.25, 0.0
947 * ]),
948 * Tensor::from([("rows", 2), ("columns", 2)], vec![
949 * 0.0, 0.0,
950 * 0.0, -0.16666667
951 * ])
952 * ])
953 * );
954 * ```
955 *
956 * # Panics
957 *
958 * - If both record containers have a WengertList that are different to each other
959 * - If the record containers have different shapes
960 */
961 #[track_caller]
962 pub fn binary<S2>(
963 &self,
964 rhs: &RecordTensor<'a, T, S2, D>,
965 fxy: impl Fn(T, T) -> T,
966 dfxy_dx: impl Fn(T, T) -> T,
967 dfxy_dy: impl Fn(T, T) -> T,
968 ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
969 where
970 S2: TensorRef<(T, Index), D>,
971 {
972 {
973 let left_shape = self.numbers.shape();
974 let right_shape = rhs.numbers.shape();
975 if left_shape != right_shape {
976 panic!(
977 "Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
978 left_shape, right_shape
979 );
980 }
981 }
982 let total = self.elements();
983 match (self.history, rhs.history) {
984 (None, None) => RecordTensor::constants(
985 // use direct_from here maybe?
986 Tensor::from(
987 self.numbers.shape(),
988 self.numbers
989 .iter()
990 .zip(rhs.numbers.iter())
991 .map(|((x, _), (y, _))| fxy(x, y))
992 .collect(),
993 ),
994 ),
995 (Some(history), None) => {
996 let zs = binary_x_history::<T, _, _>(
997 total,
998 history,
999 self.numbers.iter(),
1000 rhs.numbers.iter(),
1001 fxy,
1002 dfxy_dx,
1003 );
1004 RecordContainer {
1005 numbers: self.numbers.new_with_same_shape(zs),
1006 history: Some(history),
1007 }
1008 }
1009 (None, Some(history)) => {
1010 let zs = binary_y_history::<T, _, _>(
1011 total,
1012 history,
1013 self.numbers.iter(),
1014 rhs.numbers.iter(),
1015 fxy,
1016 dfxy_dy,
1017 );
1018 RecordContainer {
1019 numbers: self.numbers.new_with_same_shape(zs),
1020 history: Some(history),
1021 }
1022 }
1023 (Some(history), Some(h)) => {
1024 assert!(
1025 record_operations::same_lists(history, h),
1026 "Record containers must be using the same WengertList"
1027 );
1028 let zs = binary_both_history::<T, _, _>(
1029 total,
1030 history,
1031 self.numbers.iter(),
1032 rhs.numbers.iter(),
1033 fxy,
1034 dfxy_dx,
1035 dfxy_dy,
1036 );
1037 RecordContainer {
1038 numbers: self.numbers.new_with_same_shape(zs),
1039 history: Some(history),
1040 }
1041 }
1042 }
1043 }
1044
1045 /**
1046 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1047 * some unary function on `Record<T>` to `Record<T>` to every element in the container. This
1048 * will fail if the function would create records with inconsistent histories.
1049 *
1050 * When used with pure functions that can't return different histories for different inputs
     * unwrapping will always succeed.
1052 *
1053 * This API can allow you to call a generic function that operates on
1054 * [Numeric](crate::numeric::Numeric) or [Real](crate::numeric::extra::Real) numbers and
1055 * apply all the correct derivative tracking during the intermediate calculations for you,
1056 * without having to resort to storing the Record types.
1057 *
1058 * ```
1059 * use easy_ml::numeric::extra::Real;
1060 * use easy_ml::tensors::Tensor;
1061 * use easy_ml::differentiation::{RecordTensor, WengertList};
1062 *
1063 * fn sigmoid<T: Real+ Copy>(x: T) -> T {
1064 * T::one() / (T::one() + (-x).exp())
1065 * }
1066 *
1067 * let history = WengertList::new();
1068 * let layer = RecordTensor::variables(&history, Tensor::from([("x", 2)], vec![ 0.2, 0.6 ]));
1069 * let after = layer.map(sigmoid).unwrap(); // sigmoid can't introduce inconsistent histories
1070 * ```
1071 *
1072 * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1073 * after mapping doesn't have to be the same as before, only must be the same for every
1074 * mapped element.
1075 *
1076 * See also: [AsRecords](AsRecords)
1077 */
1078 #[allow(clippy::type_complexity)]
1079 #[track_caller]
1080 pub fn map(
1081 &self,
1082 fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1083 ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
1084 let result = RecordTensor::from_iter(self.shape(), self.iter_as_records().map(fx));
1085 RecordTensor::<'a, T, S, D>::map_collection(result, self.shape())
1086 }
1087
1088 /**
1089 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1090 * some unary function on `Record<T>` and each index of that position in the Record to
1091 * `Record<T>` to every element in the container. This will fail if the function would create
1092 * records with inconsistent histories.
1093 *
1094 * When used with pure functions that can't return different histories for different inputs
1095 * unwrapping with always succeed.
1096 *
1097 * This API can allow you to call a generic function that operates on
1098 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1099 * during the intermediate calculations for you, without having to resort to storing the
1100 * Record types.
1101 *
1102 * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1103 * after mapping doesn't have to be the same as before, only must be the same for every
1104 * mapped element.
1105 */
1106 #[allow(clippy::type_complexity)]
1107 #[track_caller]
1108 pub fn map_with_index(
1109 &self,
1110 fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
1111 ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
1112 let result = RecordTensor::from_iter(
1113 self.shape(),
1114 self.iter_as_records().with_index().map(|(i, x)| fx(i, x)),
1115 );
1116 RecordTensor::<'a, T, S, D>::map_collection(result, self.shape())
1117 }
1118
1119 #[track_caller]
1120 #[allow(clippy::type_complexity)]
1121 fn map_collection(
1122 result: Result<
1123 RecordTensor<'a, T, Tensor<(T, usize), D>, D>,
1124 InvalidRecordIteratorError<'a, T, D>,
1125 >,
1126 shape: [(Dimension, usize); D],
1127 ) -> Result<RecordTensor<'a, T, Tensor<(T, usize), D>, D>, InconsistentHistory<'a, T>> {
1128 use InvalidRecordIteratorError as Error;
1129 match result {
1130 Ok(tensor) => Ok(tensor),
1131 Err(error) => match error {
1132 // These first two should be 100% impossible but provide a sensible error just
1133 // in case some weird things break our invariants
1134 Error::Empty => panic!("Illegal state, record tensor was empty {:?}", shape),
1135 Error::Shape { requested, length } => panic!(
1136 "Illegal state, record tensor shape was inconsistent: requested: {:?}, length of data: {:?}",
1137 requested, length
1138 ),
1139 // This one is theoretically possible but in practise shouldn't happen by accident
1140 // However, it can't implement Debug unless T is debug so to avoid having to
1141 // restrict our function signature we return a Result anyway - this also encourages
1142 // the user to make sure their function isn't going to cause this case, which
1143 // with some of the other variants like with_index might come up more easily
1144 Error::InconsistentHistory(h) => Err(h),
1145 },
1146 }
1147 }
1148
1149 /**
1150 * For each record in the container, peforms a backward pass up its WengertList from it
1151 * as the output, computing all the derivatives for the inputs involving this output.
1152 *
1153 * If this container has no backing WengertList, ie was created as constants, then None is
1154 * returned instead. Otherwise the returned Tensor will have the same shape as this container,
1155 * with the respective derivatives matching each element in this container.
1156 *
1157 * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is Y with M outputs,
1158 * then this computes all the derivatives δy<sub>j</sub>/δx<sub>i</sub> for i = 1 to N and
1159 * j = 1 to M.
1160 *
1161 * If you have a lot of outputs this could be very expensive! Reverse auto diff is optimised
1162 * for domains where there are many more inputs than outputs.
1163 *
1164 * If you only need some of the derivatives then
1165 * [derivatives_for](RecordTensor::derivatives_for) can be used instead to avoid
1166 * calculating the rest.
1167 */
1168 pub fn derivatives(&self) -> Option<Tensor<Derivatives<T>, D>> {
1169 self.history.map(|history| {
1170 self.numbers.map(|(x, i)| {
1171 Record {
1172 number: x,
1173 history: Some(history),
1174 index: i,
1175 }
1176 .derivatives()
1177 })
1178 })
1179 }
1180
1181 /**
1182 * For the record at the index, peforms a backward pass up its WengertList from it
1183 * as the output, computing all the derivatives for the inputs involving this output.
1184 *
1185 * If the index is invalid or this container has no backing WengertList, ie was created
1186 * as constants, then None is returned instead.
1187 *
1188 * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
1189 * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
1190 */
1191 pub fn derivatives_for(&self, indexes: [usize; D]) -> Option<Derivatives<T>> {
1192 let (number, index) = self.get_reference(indexes).map(|(x, i)| (x.clone(), *i))?;
1193 // The nature of reverse autodiff is that we expect to only have a few outputs from
1194 // which we calculate all the derivatives we care about. Therefore just call Record and
1195 // reuse the implementation instead of trying to do anything clever like calculate all
1196 // derivatives for every number in this container.
1197 Record {
1198 number,
1199 history: self.history,
1200 index,
1201 }
1202 .try_derivatives()
1203 }
1204
1205 /**
1206 * Performs elementwise multiplication for two record tensors of the same shape.
1207 *
1208 * # Panics
1209 *
1210 * - If both record containers have a WengertList that are different to each other
1211 * - If the record containers have different shapes
1212 */
1213 // TODO: Assign variants?
1214 pub fn elementwise_multiply<S2>(
1215 &self,
1216 other: &RecordTensor<'a, T, S2, D>,
1217 ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1218 where
1219 S2: TensorRef<(T, Index), D>,
1220 {
1221 self.binary(
1222 other,
1223 Multiplication::<T>::function,
1224 Multiplication::<T>::d_function_dx,
1225 Multiplication::<T>::d_function_dy,
1226 )
1227 }
1228
1229 /**
1230 * Performs elementwise division for two record tensors of the same shape.
1231 *
1232 * # Panics
1233 *
1234 * - If both record containers have a WengertList that are different to each other
1235 * - If the record containers have different shapes
1236 */
1237 pub fn elementwise_divide<S2>(
1238 &self,
1239 other: &RecordTensor<'a, T, S2, D>,
1240 ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
1241 where
1242 S2: TensorRef<(T, Index), D>,
1243 {
1244 self.binary(
1245 other,
1246 Division::<T>::function,
1247 Division::<T>::d_function_dx,
1248 Division::<T>::d_function_dy,
1249 )
1250 }
1251}
1252
1253impl<T: Clone + Primitive> Derivatives<T> {
1254 /**
1255 * Queries the derivative at the provided index into the record tensor as input.
1256 *
1257 * If you construct a Derivatives object for some output y,
1258 * and call .at_tensor_index(i, &xs) on it for some input container xs and index i, this
1259 * returns dy/dx where x = xs\[i\].
1260 *
1261 * If the index into the tensor is invalid, returns None instead.
1262 */
1263 pub fn at_tensor_index<S, const D: usize>(
1264 &self,
1265 indexes: [usize; D],
1266 input: &RecordTensor<T, S, D>,
1267 ) -> Option<T>
1268 where
1269 S: TensorRef<(T, Index), D>,
1270 {
1271 let index = input.get_reference(indexes).map(|(_, i)| *i)?;
1272 Some(self.derivatives[index].clone())
1273 }
1274
1275 /**
1276 * Queries the derivatives at every element in the record tensor input.
1277 *
1278 * If you construct a Derivatives object for some output y,
1279 * and call .at_tensor(&xs) on it for some input container xs this
1280 * returns dy/dx for every x in xs.
1281 */
1282 pub fn at_tensor<S, const D: usize>(&self, input: &RecordTensor<T, S, D>) -> Tensor<T, D>
1283 where
1284 S: TensorRef<(T, Index), D>,
1285 {
1286 input.numbers.map(|(_, i)| self.derivatives[i].clone())
1287 }
1288
1289 /**
1290 * Queries the derivative at the provided index into the record matrix as input.
1291 *
1292 * If you construct a Derivatives object for some output y,
1293 * and call .at_matrix_index(i, j, &xs) on it for some input container xs and indexes i and j,
1294 * this returns dy/dx where x = xs\[i, j\].
1295 *
1296 * If the index into the tensor is invalid, returns None instead.
1297 */
1298 pub fn at_matrix_index<S>(
1299 &self,
1300 row: Row,
1301 column: Column,
1302 input: &RecordMatrix<T, S>,
1303 ) -> Option<T>
1304 where
1305 S: MatrixRef<(T, Index)> + NoInteriorMutability,
1306 {
1307 let index = input.try_get_reference(row, column).map(|(_, i)| *i)?;
1308 Some(self.derivatives[index].clone())
1309 }
1310
1311 /**
1312 * Queries the derivatives at every element in the record matrix input.
1313 *
1314 * If you construct a Derivatives object for some output y,
1315 * and call .at_matrix(&xs) on it for some input container xs this
1316 * returns dy/dx for every x in xs.
1317 */
1318 pub fn at_matrix<S>(&self, input: &RecordMatrix<T, S>) -> Matrix<T>
1319 where
1320 S: MatrixRef<(T, Index)> + NoInteriorMutability,
1321 {
1322 input.numbers.map(|(_, i)| self.derivatives[i].clone())
1323 }
1324}
1325
1326impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
1327where
1328 T: Numeric + Primitive,
1329 for<'t> &'t T: NumericRef<T>,
1330 S: TensorMut<(T, Index), D>,
1331{
1332 /**
1333 * Overwrites a RecordContainer by applying
1334 * some unary function from `T` to `T` to every element in the container.
1335 *
1336 * To compute the new records, the unary function of some input x to some
1337 * output y is needed along with its derivative with respect to its input x.
1338 */
1339 #[track_caller]
1340 pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
1341 let total = self.elements();
1342 match self.history {
1343 None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
1344 Some(history) => {
1345 let ys = unary::<T, _>(total, history, self.numbers.iter(), fx, dfx_dx);
1346 for (element, result) in self.numbers.iter_reference_mut().zip(ys) {
1347 *element = result;
1348 }
1349 self.history = Some(history);
1350 }
1351 }
1352 }
1353
1354 /**
1355 * Overwrites the left hand side of a RecordContainer with the result of applying
1356 * some binary function from `T` to `T` to every element pair in the containers. Both
1357 * containers must have the same shape.
1358 * To compute the new records, the binary function of some inputs x and y to some
1359 * output z is needed along with its derivative with respect to its first input x and
1360 * its derivative with respect to its second input y.
1361 *
1362 * # Panics
1363 *
1364 * - If both record containers have a WengertList that are different to each other
1365 * - If the record containers have different shapes
1366 */
1367 #[track_caller]
1368 pub fn binary_left_assign<S2>(
1369 &mut self,
1370 rhs: &RecordTensor<'a, T, S2, D>,
1371 fxy: impl Fn(T, T) -> T,
1372 dfxy_dx: impl Fn(T, T) -> T,
1373 dfxy_dy: impl Fn(T, T) -> T,
1374 ) where
1375 S2: TensorRef<(T, Index), D>,
1376 {
1377 {
1378 let left_shape = self.numbers.shape();
1379 let right_shape = rhs.numbers.shape();
1380 if left_shape != right_shape {
1381 panic!(
1382 "Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
1383 left_shape, right_shape
1384 );
1385 }
1386 }
1387 let total = self.elements();
1388 match (self.history, rhs.history) {
1389 (None, None) => {
1390 for (x, y) in self.numbers.iter_reference_mut().zip(rhs.numbers.iter()) {
1391 let (left, _) = x;
1392 let (right, _) = y;
1393 *x = (fxy(left.clone(), right), 0);
1394 }
1395 }
1396 (Some(history), None) => {
1397 let zs = binary_x_history::<T, _, _>(
1398 total,
1399 history,
1400 self.numbers.iter(),
1401 rhs.numbers.iter(),
1402 fxy,
1403 dfxy_dx,
1404 );
1405 for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1406 *element = result;
1407 }
1408 self.history = Some(history);
1409 }
1410 (None, Some(history)) => {
1411 let zs = binary_y_history::<T, _, _>(
1412 total,
1413 history,
1414 self.numbers.iter(),
1415 rhs.numbers.iter(),
1416 fxy,
1417 dfxy_dy,
1418 );
1419 for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1420 *element = result;
1421 }
1422 self.history = Some(history);
1423 }
1424 (Some(history), Some(h)) => {
1425 assert!(
1426 record_operations::same_lists(history, h),
1427 "Record containers must be using the same WengertList"
1428 );
1429 let zs = binary_both_history::<T, _, _>(
1430 total,
1431 history,
1432 self.numbers.iter(),
1433 rhs.numbers.iter(),
1434 fxy,
1435 dfxy_dx,
1436 dfxy_dy,
1437 );
1438 for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
1439 *element = result;
1440 }
1441 self.history = Some(history);
1442 }
1443 }
1444 }
1445
1446 /**
1447 * A convenience helper function which takes the RecordContainer value and
1448 * calls [unary_assign](RecordTensor::unary_assign()) on it, returning
1449 * the record container which now contains the result of the operation.
1450 */
1451 #[track_caller]
1452 pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
1453 self.unary_assign(fx, dfx_dx);
1454 self
1455 }
1456
1457 /**
1458 * A convenience helper function which takes the left hand side by value and
1459 * calls [binary_left_assign](RecordTensor::binary_left_assign()) on it, returning
1460 * the left hand side which now contains the result of the operation.
1461 */
1462 #[track_caller]
1463 pub fn do_binary_left_assign<S2>(
1464 mut self,
1465 rhs: &RecordTensor<'a, T, S2, D>,
1466 fxy: impl Fn(T, T) -> T,
1467 dfxy_dx: impl Fn(T, T) -> T,
1468 dfxy_dy: impl Fn(T, T) -> T,
1469 ) -> Self
1470 where
1471 S2: TensorRef<(T, Index), D>,
1472 {
1473 self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
1474 self
1475 }
1476
1477 /**
1478 * Updates this RecordContainer in place by applying some unary function on `Record<T>` to
1479 * `Record<T>` to every element in the container. This will fail if the function would create
1480 * records with inconsistent histories.
1481 *
1482 * When used with pure functions that can't return different histories for different inputs
     * unwrapping will always succeed.
1484 *
1485 * Since this updates the container in place, if Err is returned then the data in this
1486 * RecordContainer is still available but it has been corrupted - at least one of the elements
1487 * should have a different history than what it will have because the mapping function created
1488 * inconsistent histories that couldn't be represented by the container as it only stores
1489 * one.
1490 *
1491 * This API can allow you to call a generic function that operates on
1492 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1493 * during the intermediate calculations for you, without having to resort to storing the
1494 * Record types.
1495 *
1496 * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1497 * after mapping doesn't have to be the same as before, only must be the same for every
1498 * mapped element.
1499 *
1500 * You might also use this function at the end of a training loop to update all the weights
1501 * to reduce their loss.
1502 *
1503 * ```
1504 * use easy_ml::numeric::Numeric;
1505 * use easy_ml::tensors::Tensor;
1506 * use easy_ml::differentiation::{Record, RecordTensor, WengertList};
1507 *
1508 * let history = WengertList::new();
1509 * let mut weights = RecordTensor::variables(
1510 * &history,
1511 * Tensor::from([("w1", 4)], vec![ 0.3, 0.2, -1.2, -0.4 ])
1512 * );
1513 * let error = {
1514 * // this is over-simplified for brevity, obviously in a real scenario we wouldn't have a
1515 * // function that calculates the error like this or we wouldn't be doing machine learning
1516 * // to fit it in the first place
1517 * let mut loss = Record::variable(0.0, &history);
1518 * for r in weights.iter_as_records() {
1519 * loss = loss + r;
1520 * }
1521 * loss
1522 * };
1523 * let derivatives = error.derivatives();
1524 * let learning_rate = 0.1;
1525 * // update the weights to contain less error than they did before
1526 * let result = weights.map_mut(|x| x - (derivatives[&x] * learning_rate));
1527 * assert!(result.is_ok()); // we know we didn't introduce an inconsistent history just updating the weights
1528 * ```
1529 */
1530 #[track_caller]
1531 pub fn map_mut(
1532 &mut self,
1533 fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1534 ) -> Result<(), InconsistentHistory<'a, T>> {
1535 let history = self.history;
1536 let new_history =
1537 map_mut_base::<'a, T, _, _>(TensorReferenceMutIterator::from(self), |x| {
1538 let record = Record::from_existing(x.clone(), history);
1539 let result = fx(record);
1540 *x = (result.number, result.index);
1541 result.history
1542 })?;
1543 self.history = new_history;
1544 Ok(())
1545 }
1546
1547 /**
1548 * Updates this RecordContainer in place by applying some unary function on `Record<T>` and
1549 * each index of that position in the Record to `Record<T>` to every element in the container.
1550 * This will fail if the function would create records with inconsistent histories.
1551 *
1552 * When used with pure functions that can't return different histories for different inputs
1553 * unwrapping with always succeed.
1554 *
1555 * Since this updates the container in place, if Err is returned then the data in this
1556 * RecordContainer is still available but it has been corrupted - at least one of the elements
1557 * should have a different history than what it will have because the mapping function created
1558 * inconsistent histories that couldn't be represented by the container as it only stores
1559 * one.
1560 *
1561 * This API can allow you to call a generic function that operates on
1562 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1563 * during the intermediate calculations for you, without having to resort to storing the
1564 * Record types.
1565 *
1566 * NB: Mapping a RecordTensor of constants to variables is not inconsistent, the history
1567 * after mapping doesn't have to be the same as before, only must be the same for every
1568 * mapped element.
1569 */
1570 #[track_caller]
1571 pub fn map_mut_with_index(
1572 &mut self,
1573 fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
1574 ) -> Result<(), InconsistentHistory<'a, T>> {
1575 let history = self.history;
1576 let new_history = map_mut_base::<'a, T, _, _>(
1577 TensorReferenceMutIterator::from(self).with_index(),
1578 |(i, x)| {
1579 let record = Record::from_existing(x.clone(), history);
1580 let result = fx(i, record);
1581 *x = (result.number, result.index);
1582 result.history
1583 },
1584 )?;
1585 self.history = new_history;
1586 Ok(())
1587 }
1588}
1589
1590#[track_caller]
1591fn map_mut_base<'a, T, I, X>(
1592 mut iter: I,
1593 fx: impl Fn(X) -> Option<&'a WengertList<T>>,
1594) -> Result<Option<&'a WengertList<T>>, InconsistentHistory<'a, T>>
1595where
1596 I: Iterator<Item = X>,
1597 T: Primitive,
1598{
1599 use crate::differentiation::record_operations::are_exact_same_list;
1600 #[rustfmt::skip]
1601 let first_history = fx(iter.next().expect("Illegal state, record container was empty"));
1602 let mut different_history: Option<Option<&WengertList<T>>> = None;
1603 for x in iter {
1604 let history = fx(x);
1605 if !are_exact_same_list(history, first_history) {
1606 different_history = Some(history);
1607 }
1608 }
1609 match different_history {
1610 None => Ok(first_history),
1611 Some(h) => Err(InconsistentHistory {
1612 first: first_history,
1613 later: h,
1614 }),
1615 }
1616}
1617
1618impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
1619where
1620 T: Numeric + Primitive,
1621 for<'t> &'t T: NumericRef<T>,
1622 S: TensorRef<(T, Index), D>,
1623{
1624 /**
1625 * Overwrites the right hand side of a RecordContainer with the result of applying
1626 * some binary function from `T` to `T` to every element pair in the containers. Both
1627 * containers must have the same shape.
1628 * To compute the new records, the binary function of some inputs x and y to some
1629 * output z is needed along with its derivative with respect to its first input x and
1630 * its derivative with respect to its second input y.
1631 *
1632 * # Panics
1633 *
1634 * - If both record containers have a WengertList that are different to each other
1635 * - If the record containers have different shapes
1636 */
1637 #[track_caller]
1638 pub fn binary_right_assign<S2>(
1639 &self,
1640 rhs: &mut RecordTensor<'a, T, S2, D>,
1641 fxy: impl Fn(T, T) -> T,
1642 dfxy_dx: impl Fn(T, T) -> T,
1643 dfxy_dy: impl Fn(T, T) -> T,
1644 ) where
1645 S2: TensorMut<(T, Index), D>,
1646 {
1647 // x is lhs, y is rhs, so calling binary_left_assign on the rhs container
1648 // means we need to swap all the arguments
1649 rhs.binary_left_assign(
1650 self,
1651 |y, x| fxy(x, y),
1652 |y, x| dfxy_dy(x, y),
1653 |y, x| dfxy_dx(x, y),
1654 )
1655 }
1656
1657 /**
1658 * A convenience helper function which takes the right hand side by value and
1659 * calls [binary_right_assign](RecordTensor::binary_right_assign()) on it, returning
1660 * the right hand side which now contains the result of the operation.
1661 */
1662 #[track_caller]
1663 pub fn do_binary_right_assign<S2>(
1664 &self,
1665 mut rhs: RecordTensor<'a, T, S2, D>,
1666 fxy: impl Fn(T, T) -> T,
1667 dfxy_dx: impl Fn(T, T) -> T,
1668 dfxy_dy: impl Fn(T, T) -> T,
1669 ) -> RecordTensor<'a, T, S2, D>
1670 where
1671 S2: TensorMut<(T, Index), D>,
1672 {
1673 self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
1674 rhs
1675 }
1676}
1677
1678impl<'a, T, S> RecordMatrix<'a, T, S>
1679where
1680 T: Numeric + Primitive,
1681 for<'t> &'t T: NumericRef<T>,
1682 S: MatrixRef<(T, Index)> + NoInteriorMutability,
1683{
1684 /**
1685 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1686 * some unary function from `T` to `T` to every element in the container.
1687 *
1688 * To compute the new records, the unary function of some input x to some
1689 * output y is needed along with its derivative with respect to its input x.
1690 *
1691 * For example, tanh is a commonly used activation function, but the Real trait
1692 * does not include this operation and Record has no operations for it specifically.
1693 * However, you can use this function to compute the tanh for a record container like so:
1694 *
1695 * ```
1696 * use easy_ml::differentiation::{RecordMatrix, WengertList};
1697 * use easy_ml::matrices::Matrix;
1698 * let list = WengertList::new();
1699 * let X = RecordMatrix::variables(
1700 * &list,
1701 * Matrix::from_fn((2, 2), |(r, c)| 0.15 * ((1 + r + c) as f32))
1702 * );
1703 * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
1704 * // 1 / (cosh(x) * cosh(x))
1705 * let Y = X.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
1706 *
1707 * // we can unwrap here because we know Y contains variables not constants
1708 * let derivatives = Y.derivatives().unwrap();
1709 * assert_eq!(
1710 * derivatives.get_reference(0, 0).at_matrix(&X),
1711 * Matrix::from(vec![
1712 * // (0, 0) element in Y only had the one input variable (0, 0) in X
1713 * vec![0.9778332, 0.0],
1714 * vec![0.0, 0.0]
1715 * ]),
1716 * );
1717 * assert_eq!(
1718 * derivatives.get_reference(0, 1).at_matrix(&X),
1719 * Matrix::from(vec![
1720 * vec![0.0, 0.915137],
1721 * vec![0.0, 0.0]
1722 * ]),
1723 * );
1724 * assert_eq!(
1725 * // (0, 1) and (1, 0) elements in X had the same starting value so end up with the same
1726 * // derivative for their corresponding input variable in X
1727 * derivatives.get_reference(0, 1).at_matrix(&X).get(0, 1),
1728 * derivatives.get_reference(1, 0).at_matrix(&X).get(1, 0),
1729 * );
1730 * assert_eq!(
1731 * derivatives.get_reference(1, 1).at_matrix(&X),
1732 * Matrix::from(vec![
1733 * vec![0.0, 0.0 ],
1734 * vec![0.0, 0.8220013]
1735 * ]),
1736 * );
1737 * ```
1738 */
1739 #[track_caller]
1740 pub fn unary(
1741 &self,
1742 fx: impl Fn(T) -> T,
1743 dfx_dx: impl Fn(T) -> T,
1744 ) -> RecordMatrix<'a, T, Matrix<(T, Index)>> {
1745 let total = self.elements();
1746 match self.history {
1747 None => RecordMatrix::constants(self.numbers.map(|(x, _)| fx(x))),
1748 Some(history) => {
1749 let ys = unary::<T, _>(total, history, self.numbers.row_major_iter(), fx, dfx_dx);
1750 RecordContainer {
1751 numbers: MatrixView::from(Matrix::from_flat_row_major(self.numbers.size(), ys)),
1752 history: Some(history),
1753 }
1754 }
1755 }
1756 }
1757
1758 /**
1759 * Creates a new RecordContainer from two RecordContainers by applying
1760 * some binary function from `T` to `T` to every element pair in the containers. Both
1761 * containers must have the same shape.
1762 *
1763 * To compute the new records, the binary function of some inputs x and y to some
1764 * output z is needed along with its derivative with respect to its first input x and
1765 * its derivative with respect to its second input y.
1766 *
1767 * For example, atan2 takes two arguments, but the Real trait
1768 * does not include this operation and Record has no operations for it specifically.
1769 * However, you can use this function to compute the atan2 for two record containers like so:
1770 *
1771 * ```
1772 * use easy_ml::differentiation::{RecordMatrix, WengertList};
1773 * use easy_ml::matrices::Matrix;
1774 * let list = WengertList::new();
1775 * let X = RecordMatrix::variables(
1776 * &list,
1777 * Matrix::from_fn((2, 2), |(r, c)| ((1 + r + c) as f32))
1778 * );
1779 * let Y = RecordMatrix::variables(
1780 * &list,
1781 * Matrix::from_fn((2, 2), |(r, c)| ((1 + r + c) as f32))
1782 * );
1783 * // the derivative of atan2 with respect to x is y/(x*x + y*y)
1784 * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
1785 * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
1786 * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
1787 * let Z = X.binary(&Y,
1788 * |x, y| x.atan2(y),
1789 * |x, y| y/((x*x) + (y*y)),
1790 * |x, y| -x/((x*x) + (y*y))
1791 * );
1792 *
1793 * // we can unwrap here because we know Z contains variables not constants
1794 * let derivatives = Z.derivatives().unwrap();
1795 * // Just as in the unary example, only one pair of the four inputs in X and Y influence the
1796 * // outputs in Z, so we have a lot of 0.0 derivatives, and the inputs in [0, 1] and [1, 0]
1797 * // are identical so we see the same derivative.
1798 * let dZ_dX = derivatives.map(|d| d.at_matrix(&X));
1799 * assert_eq!(
1800 * dZ_dX,
1801 * Matrix::from(vec![
1802 * vec![
1803 * Matrix::from(vec![
1804 * vec![ 0.5, 0.0 ],
1805 * vec![ 0.0, 0.0 ]
1806 * ]),
1807 * Matrix::from(vec![
1808 * vec![ 0.0, 0.25 ],
1809 * vec![ 0.0, 0.0 ]
1810 * ])
1811 * ],
1812 * vec![
1813 * Matrix::from(vec![
1814 * vec![ 0.0, 0.0 ],
1815 * vec![ 0.25, 0.0 ]
1816 * ]),
1817 * Matrix::from(vec![
1818 * vec![ 0.0, 0.0 ],
1819 * vec![ 0.0, 0.16666667 ]
1820 * ])
1821 * ]
1822 * ])
1823 * );
1824 * let dZ_dY = derivatives.map(|d| d.at_matrix(&Y));
1825 * assert_eq!(
1826 * dZ_dY,
1827 * Matrix::from(vec![
1828 * vec![
1829 * Matrix::from(vec![
1830 * vec![ -0.5, 0.0 ],
1831 * vec![ 0.0, 0.0 ]
1832 * ]),
1833 * Matrix::from(vec![
1834 * vec![ 0.0, -0.25 ],
1835 * vec![ 0.0, 0.0 ]
1836 * ])
1837 * ],
1838 * vec![
1839 * Matrix::from(vec![
1840 * vec![ 0.0, 0.0 ],
1841 * vec![ -0.25, 0.0 ]
1842 * ]),
1843 * Matrix::from(vec![
1844 * vec![ 0.0, 0.0 ],
1845 * vec![ 0.0, -0.16666667 ]
1846 * ])
1847 * ]
1848 * ])
1849 * );
1850 * ```
1851 *
1852 * # Panics
1853 *
1854 * - If both record containers have a WengertList that are different to each other
1855 * - If the record containers have different shapes
1856 */
1857 #[track_caller]
1858 pub fn binary<S2>(
1859 &self,
1860 rhs: &RecordMatrix<'a, T, S2>,
1861 fxy: impl Fn(T, T) -> T,
1862 dfxy_dx: impl Fn(T, T) -> T,
1863 dfxy_dy: impl Fn(T, T) -> T,
1864 ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
1865 where
1866 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
1867 {
1868 let shape = {
1869 let left_shape = self.numbers.size();
1870 let right_shape = rhs.numbers.size();
1871 if left_shape != right_shape {
1872 panic!(
1873 "Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
1874 left_shape, right_shape
1875 );
1876 }
1877 left_shape
1878 };
1879 let total = self.elements();
1880 match (self.history, rhs.history) {
1881 (None, None) => RecordMatrix::constants(Matrix::from_flat_row_major(
1882 shape,
1883 self.numbers
1884 .row_major_iter()
1885 .zip(rhs.numbers.row_major_iter())
1886 .map(|((x, _), (y, _))| fxy(x, y))
1887 .collect(),
1888 )),
1889 (Some(history), None) => {
1890 let zs = binary_x_history::<T, _, _>(
1891 total,
1892 history,
1893 self.numbers.row_major_iter(),
1894 rhs.numbers.row_major_iter(),
1895 fxy,
1896 dfxy_dx,
1897 );
1898 RecordContainer {
1899 numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1900 history: Some(history),
1901 }
1902 }
1903 (None, Some(history)) => {
1904 let zs = binary_y_history::<T, _, _>(
1905 total,
1906 history,
1907 self.numbers.row_major_iter(),
1908 rhs.numbers.row_major_iter(),
1909 fxy,
1910 dfxy_dy,
1911 );
1912 RecordContainer {
1913 numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1914 history: Some(history),
1915 }
1916 }
1917 (Some(history), Some(h)) => {
1918 assert!(
1919 record_operations::same_lists(history, h),
1920 "Record containers must be using the same WengertList"
1921 );
1922 let zs = binary_both_history::<T, _, _>(
1923 total,
1924 history,
1925 self.numbers.row_major_iter(),
1926 rhs.numbers.row_major_iter(),
1927 fxy,
1928 dfxy_dx,
1929 dfxy_dy,
1930 );
1931 RecordContainer {
1932 numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
1933 history: Some(history),
1934 }
1935 }
1936 }
1937 }
1938
1939 /**
1940 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1941 * some unary function on `Record<T>` to `Record<T>` to every element in the container. This
1942 * will fail if the function would create records with inconsistent histories.
1943 *
1944 * When used with pure functions that can't return different histories for different inputs
1945 * unwrapping with always succeed.
1946 *
1947 * This API can allow you to call a generic function that operates on
1948 * [Numeric](crate::numeric::Numeric) or [Real](crate::numeric::extra::Real) numbers and
1949 * apply all the correct derivative tracking during the intermediate calculations for you,
1950 * without having to resort to storing the Record types.
1951 *
1952 * ```
1953 * use easy_ml::numeric::extra::Real;
1954 * use easy_ml::matrices::Matrix;
1955 * use easy_ml::differentiation::{RecordMatrix, WengertList};
1956 *
1957 * fn sigmoid<T: Real+ Copy>(x: T) -> T {
1958 * T::one() / (T::one() + (-x).exp())
1959 * }
1960 *
1961 * let history = WengertList::new();
1962 * let layer = RecordMatrix::variables(&history, Matrix::from(vec![vec![ 0.2, 0.6 ]]));
1963 * let after = layer.map(sigmoid).unwrap(); // sigmoid can't introduce inconsistent histories
1964 * ```
1965 *
1966 * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
1967 * after mapping doesn't have to be the same as before, only must be the same for every
1968 * mapped element.
1969 *
1970 * See also: [AsRecords](AsRecords)
1971 */
1972 #[allow(clippy::type_complexity)]
1973 #[track_caller]
1974 pub fn map(
1975 &self,
1976 fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
1977 ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
1978 let result = RecordMatrix::from_iter(self.size(), self.iter_row_major_as_records().map(fx));
1979 RecordMatrix::<'a, T, S>::map_collection(result, self.size())
1980 }
1981
1982 /**
1983 * Creates a new RecordContainer from a reference to an existing RecordContainer by applying
1984 * some unary function on `Record<T>` and each index of that position in the Record to
1985 * `Record<T>` to every element in the container. This will fail if the function would
1986 * create records with inconsistent histories.
1987 *
1988 * When used with pure functions that can't return different histories for different inputs
1989 * unwrapping with always succeed.
1990 *
1991 * This API can allow you to call a generic function that operates on
1992 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
1993 * during the intermediate calculations for you, without having to resort to storing the
1994 * Record types.
1995 *
1996 * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
1997 * after mapping doesn't have to be the same as before, only must be the same for every
1998 * mapped element.
1999 */
2000 #[allow(clippy::type_complexity)]
2001 #[track_caller]
2002 pub fn map_with_index(
2003 &self,
2004 fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
2005 ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
2006 let result = RecordMatrix::from_iter(
2007 self.size(),
2008 self.iter_row_major_as_records()
2009 .with_index()
2010 .map(|((r, c), x)| fx(x, r, c)),
2011 );
2012 RecordMatrix::<'a, T, S>::map_collection(result, self.size())
2013 }
2014
2015 #[allow(clippy::type_complexity)]
2016 #[track_caller]
2017 fn map_collection(
2018 result: Result<
2019 RecordMatrix<'a, T, Matrix<(T, usize)>>,
2020 InvalidRecordIteratorError<'a, T, 2>,
2021 >,
2022 size: (Row, Column),
2023 ) -> Result<RecordMatrix<'a, T, Matrix<(T, usize)>>, InconsistentHistory<'a, T>> {
2024 use InvalidRecordIteratorError as Error;
2025 match result {
2026 Ok(matrix) => Ok(matrix),
2027 Err(error) => match error {
2028 // These first two should be 100% impossible but provide a sensible error just
2029 // in case some weird things break our invariants
2030 Error::Empty => panic!("Illegal state, record matrix was empty {:?}", size),
2031 Error::Shape { requested, length } => panic!(
2032 "Illegal state, record matrix shape was inconsistent: requested: {:?}, length of data: {:?}",
2033 requested, length
2034 ),
2035 // This one is theoretically possible but in practise shouldn't happen by accident
2036 // However, it can't implement Debug unless T is debug so to avoid having to
2037 // restrict our function signature we return a Result anyway - this also encourages
2038 // the user to make sure their function isn't going to cause this case, which
2039 // with some of the other variants like with_index might come up more easily
2040 Error::InconsistentHistory(h) => Err(h),
2041 },
2042 }
2043 }
2044
2045 /**
2046 * For each record in the container, peforms a backward pass up its WengertList from it
2047 * as the output, computing all the derivatives for the inputs involving this output.
2048 *
2049 * If this container has no backing WengertList, ie was created as constants, then None is
2050 * returned instead. Otherwise the returned Matrix will have the same size as this container,
2051 * with the respective derivatives matching each element in this container.
2052 *
2053 * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is Y with M outputs,
2054 * then this computes all the derivatives δy<sub>j</sub>/δx<sub>i</sub> for i = 1 to N and
2055 * j = 1 to M.
2056 *
2057 * If you have a lot of outputs this could be very expensive! Reverse auto diff is optimised
2058 * for domains where there are many more inputs than outputs.
2059 *
2060 * If you only need some of the derivatives then
2061 * [derivatives_for](RecordMatrix::derivatives_for) can be used instead to avoid
2062 * calculating the rest.
2063 */
2064 pub fn derivatives(&self) -> Option<Matrix<Derivatives<T>>> {
2065 self.history.map(|history| {
2066 self.numbers.map(|(x, i)| {
2067 Record {
2068 number: x,
2069 history: Some(history),
2070 index: i,
2071 }
2072 .derivatives()
2073 })
2074 })
2075 }
2076
2077 /**
2078 * For the record at the index, peforms a backward pass up its WengertList from it
2079 * as the output, computing all the derivatives for the inputs involving this output.
2080 *
2081 * If the index is invalid or this container has no backing WengertList, ie was created
2082 * as constants, then None is returned instead.
2083 *
2084 * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
2085 * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
2086 */
2087 pub fn derivatives_for(&self, row: Row, column: Column) -> Option<Derivatives<T>> {
2088 let (number, index) = self
2089 .try_get_reference(row, column)
2090 .map(|(x, i)| (x.clone(), *i))?;
2091 // The nature of reverse autodiff is that we expect to only have a few outputs from
2092 // which we calculate all the derivatives we care about. Therefore just call Record and
2093 // reuse the implementation instead of trying to do anything clever like calculate all
2094 // derivatives for every number in this container.
2095 Record {
2096 number,
2097 history: self.history,
2098 index,
2099 }
2100 .try_derivatives()
2101 }
2102
2103 /**
2104 * Performs elementwise multiplication for two record matrices of the same size.
2105 *
2106 * # Panics
2107 *
2108 * - If both record containers have a WengertList that are different to each other
2109 * - If the record containers have different shapes
2110 */
2111 // TODO: Assign variants?
2112 pub fn elementwise_multiply<S2>(
2113 &self,
2114 other: &RecordMatrix<'a, T, S2>,
2115 ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
2116 where
2117 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2118 {
2119 self.binary(
2120 other,
2121 Multiplication::<T>::function,
2122 Multiplication::<T>::d_function_dx,
2123 Multiplication::<T>::d_function_dy,
2124 )
2125 }
2126
2127 /**
2128 * Performs elementwise division for two record matrices of the same size.
2129 *
2130 * # Panics
2131 *
2132 * - If both record containers have a WengertList that are different to each other
2133 * - If the record containers have different shapes
2134 */
2135 pub fn elementwise_divide<S2>(
2136 &self,
2137 other: &RecordMatrix<'a, T, S2>,
2138 ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
2139 where
2140 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2141 {
2142 self.binary(
2143 other,
2144 Division::<T>::function,
2145 Division::<T>::d_function_dx,
2146 Division::<T>::d_function_dy,
2147 )
2148 }
2149}
2150
2151impl<'a, T, S> RecordMatrix<'a, T, S>
2152where
2153 T: Numeric + Primitive,
2154 for<'t> &'t T: NumericRef<T>,
2155 S: MatrixMut<(T, Index)> + NoInteriorMutability,
2156{
2157 /**
2158 * Overwrites a RecordContainer by applying
2159 * some unary function from `T` to `T` to every element in the container.
2160 *
2161 * To compute the new records, the unary function of some input x to some
2162 * output y is needed along with its derivative with respect to its input x.
2163 */
2164 #[track_caller]
2165 pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
2166 let total = self.elements();
2167 match self.history {
2168 None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
2169 Some(history) => {
2170 let ys = unary::<T, _>(total, history, self.numbers.row_major_iter(), fx, dfx_dx);
2171 for (element, result) in self.numbers.row_major_reference_mut_iter().zip(ys) {
2172 *element = result;
2173 }
2174 self.history = Some(history);
2175 }
2176 }
2177 }
2178
2179 /**
2180 * Overwrites the left hand side of a RecordContainer with the result of applying
2181 * some binary function from `T` to `T` to every element pair in the containers. Both
2182 * containers must have the same shape.
2183 * To compute the new records, the binary function of some inputs x and y to some
2184 * output z is needed along with its derivative with respect to its first input x and
2185 * its derivative with respect to its second input y.
2186 *
2187 * # Panics
2188 *
2189 * - If both record containers have a WengertList that are different to each other
2190 * - If the record containers have different shapes
2191 */
2192 #[track_caller]
2193 pub fn binary_left_assign<S2>(
2194 &mut self,
2195 rhs: &RecordMatrix<'a, T, S2>,
2196 fxy: impl Fn(T, T) -> T,
2197 dfxy_dx: impl Fn(T, T) -> T,
2198 dfxy_dy: impl Fn(T, T) -> T,
2199 ) where
2200 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2201 {
2202 {
2203 let left_shape = self.numbers.size();
2204 let right_shape = rhs.numbers.size();
2205 if left_shape != right_shape {
2206 panic!(
2207 "Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
2208 left_shape, right_shape
2209 );
2210 }
2211 }
2212 let total = self.elements();
2213 match (self.history, rhs.history) {
2214 (None, None) => {
2215 for (x, y) in self
2216 .numbers
2217 .row_major_reference_mut_iter()
2218 .zip(rhs.numbers.row_major_iter())
2219 {
2220 let (left, _) = x;
2221 let (right, _) = y;
2222 *x = (fxy(left.clone(), right), 0);
2223 }
2224 }
2225 (Some(history), None) => {
2226 let zs = binary_x_history::<T, _, _>(
2227 total,
2228 history,
2229 self.numbers.row_major_iter(),
2230 rhs.numbers.row_major_iter(),
2231 fxy,
2232 dfxy_dx,
2233 );
2234 for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2235 *element = result;
2236 }
2237 self.history = Some(history);
2238 }
2239 (None, Some(history)) => {
2240 let zs = binary_y_history::<T, _, _>(
2241 total,
2242 history,
2243 self.numbers.row_major_iter(),
2244 rhs.numbers.row_major_iter(),
2245 fxy,
2246 dfxy_dy,
2247 );
2248 for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2249 *element = result;
2250 }
2251 self.history = Some(history);
2252 }
2253 (Some(history), Some(h)) => {
2254 assert!(
2255 record_operations::same_lists(history, h),
2256 "Record containers must be using the same WengertList"
2257 );
2258 let zs = binary_both_history::<T, _, _>(
2259 total,
2260 history,
2261 self.numbers.row_major_iter(),
2262 rhs.numbers.row_major_iter(),
2263 fxy,
2264 dfxy_dx,
2265 dfxy_dy,
2266 );
2267 for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
2268 *element = result;
2269 }
2270 self.history = Some(history);
2271 }
2272 }
2273 }
2274
2275 /**
2276 * A convenience helper function which takes the RecordContainer value and
2277 * calls [unary_assign](RecordMatrix::unary_assign()) on it, returning
2278 * the record container which now contains the result of the operation.
2279 */
2280 #[track_caller]
2281 pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
2282 self.unary_assign(fx, dfx_dx);
2283 self
2284 }
2285
2286 /**
2287 * A convenience helper function which takes the left hand side by value and
2288 * calls [binary_left_assign](RecordMatrix::binary_left_assign()) on it, returning
2289 * the left hand side which now contains the result of the operation.
2290 */
2291 #[track_caller]
2292 pub fn do_binary_left_assign<S2>(
2293 mut self,
2294 rhs: &RecordMatrix<'a, T, S2>,
2295 fxy: impl Fn(T, T) -> T,
2296 dfxy_dx: impl Fn(T, T) -> T,
2297 dfxy_dy: impl Fn(T, T) -> T,
2298 ) -> Self
2299 where
2300 S2: MatrixRef<(T, Index)> + NoInteriorMutability,
2301 {
2302 self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
2303 self
2304 }
2305
2306 /**
2307 * Updates this RecordContainer in place by applying some unary function on `Record<T>` to
2308 * `Record<T>` to every element in the container. This will fail if the function would create
2309 * records with inconsistent histories.
2310 *
2311 * When used with pure functions that can't return different histories for different inputs
2312 * unwrapping with always succeed.
2313 *
2314 * Since this updates the container in place, if Err is returned then the data in this
2315 * RecordContainer is still available but it has been corrupted - at least one of the elements
2316 * should have a different history than what it will have because the mapping function created
2317 * inconsistent histories that couldn't be represented by the container as it only stores
2318 * one.
2319 *
2320 * This API can allow you to call a generic function that operates on
2321 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2322 * during the intermediate calculations for you, without having to resort to storing the
2323 * Record types.
2324 *
2325 * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2326 * after mapping doesn't have to be the same as before, only must be the same for every
2327 * mapped element.
2328 *
2329 * You might also use this function at the end of a training loop to update all the weights
2330 * to reduce their loss.
2331 *
2332 * ```
2333 * use easy_ml::numeric::Numeric;
2334 * use easy_ml::matrices::Matrix;
2335 * use easy_ml::differentiation::{Record, RecordMatrix, WengertList};
2336 *
2337 * let history = WengertList::new();
2338 * let mut weights = RecordMatrix::variables(
2339 * &history,
2340 * Matrix::from(vec![vec![ 0.3, 0.2, -1.2, -0.4 ]])
2341 * );
2342 * let error = {
2343 * // this is over-simplified for brevity, obviously in a real scenario we wouldn't have a
2344 * // function that calculates the error like this or we wouldn't be doing machine learning
2345 * // to fit it in the first place
2346 * let mut loss = Record::variable(0.0, &history);
2347 * for r in weights.iter_row_major_as_records() {
2348 * loss = loss + r;
2349 * }
2350 * loss
2351 * };
2352 * let derivatives = error.derivatives();
2353 * let learning_rate = 0.1;
2354 * // update the weights to contain less error than they did before
2355 * let result = weights.map_mut(|x| x - (derivatives[&x] * learning_rate));
2356 * assert!(result.is_ok()); // we know we didn't introduce an inconsistent history just updating the weights
2357 * ```
2358 */
2359 #[track_caller]
2360 pub fn map_mut(
2361 &mut self,
2362 fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
2363 ) -> Result<(), InconsistentHistory<'a, T>> {
2364 let history = self.history;
2365 let new_history =
2366 map_mut_base::<'a, T, _, _>(RowMajorReferenceMutIterator::from(self), |x| {
2367 let record = Record::from_existing(x.clone(), history);
2368 let result = fx(record);
2369 *x = (result.number, result.index);
2370 result.history
2371 })?;
2372 self.history = new_history;
2373 Ok(())
2374 }
2375
2376 /**
2377 * Updates this RecordContainer in place by applying some unary function on `Record<T>` and
2378 * each index of that position in the Record to `Record<T>` to every element in the container.
2379 * This will fail if the function would create records with inconsistent histories.
2380 *
2381 * When used with pure functions that can't return different histories for different inputs
2382 * unwrapping with always succeed.
2383 *
2384 * Since this updates the container in place, if Err is returned then the data in this
2385 * RecordContainer is still available but it has been corrupted - at least one of the elements
2386 * should have a different history than what it will have because the mapping function created
2387 * inconsistent histories that couldn't be represented by the container as it only stores
2388 * one.
2389 *
2390 * This API can allow you to call a generic function that operates on
2391 * [Numeric](crate::numeric::Numeric) numbers and apply all the correct derivative tracking
2392 * during the intermediate calculations for you, without having to resort to storing the
2393 * Record types.
2394 *
2395 * NB: Mapping a RecordMatrix of constants to variables is not inconsistent, the history
2396 * after mapping doesn't have to be the same as before, only must be the same for every
2397 * mapped element.
2398 */
2399 #[track_caller]
2400 pub fn map_mut_with_index(
2401 &mut self,
2402 fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
2403 ) -> Result<(), InconsistentHistory<'a, T>> {
2404 let history = self.history;
2405 let new_history = map_mut_base::<'a, T, _, _>(
2406 RowMajorReferenceMutIterator::from(self).with_index(),
2407 |((r, c), x)| {
2408 let record = Record::from_existing(x.clone(), history);
2409 let result = fx(record, r, c);
2410 *x = (result.number, result.index);
2411 result.history
2412 },
2413 )?;
2414 self.history = new_history;
2415 Ok(())
2416 }
2417}
2418
2419impl<'a, T, S> RecordMatrix<'a, T, S>
2420where
2421 T: Numeric + Primitive,
2422 for<'t> &'t T: NumericRef<T>,
2423 S: MatrixRef<(T, Index)> + NoInteriorMutability,
2424{
2425 /**
2426 * Overwrites the right hand side of a RecordContainer with the result of applying
2427 * some binary function from `T` to `T` to every element pair in the containers. Both
2428 * containers must have the same shape.
2429 * To compute the new records, the binary function of some inputs x and y to some
2430 * output z is needed along with its derivative with respect to its first input x and
2431 * its derivative with respect to its second input y.
2432 *
2433 * # Panics
2434 *
2435 * - If both record containers have a WengertList that are different to each other
2436 * - If the record containers have different shapes
2437 */
2438 #[track_caller]
2439 pub fn binary_right_assign<S2>(
2440 &self,
2441 rhs: &mut RecordMatrix<'a, T, S2>,
2442 fxy: impl Fn(T, T) -> T,
2443 dfxy_dx: impl Fn(T, T) -> T,
2444 dfxy_dy: impl Fn(T, T) -> T,
2445 ) where
2446 S2: MatrixMut<(T, Index)> + NoInteriorMutability,
2447 {
2448 // x is lhs, y is rhs, so calling binary_left_assign on the rhs container
2449 // means we need to swap all the arguments
2450 rhs.binary_left_assign(
2451 self,
2452 |y, x| fxy(x, y),
2453 |y, x| dfxy_dy(x, y),
2454 |y, x| dfxy_dx(x, y),
2455 )
2456 }
2457
2458 /**
2459 * A convenience helper function which takes the right hand side by value and
2460 * calls [binary_right_assign](RecordMatrix::binary_right_assign()) on it, returning
2461 * the right hand side which now contains the result of the operation.
2462 */
2463 #[track_caller]
2464 pub fn do_binary_right_assign<S2>(
2465 &self,
2466 mut rhs: RecordMatrix<'a, T, S2>,
2467 fxy: impl Fn(T, T) -> T,
2468 dfxy_dx: impl Fn(T, T) -> T,
2469 dfxy_dy: impl Fn(T, T) -> T,
2470 ) -> RecordMatrix<'a, T, S2>
2471 where
2472 S2: MatrixMut<(T, Index)> + NoInteriorMutability,
2473 {
2474 self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
2475 rhs
2476 }
2477}
2478
2479// # Safety
2480//
2481// Our inner `numbers` tensor has to implement TensorRef correctly so by delegating to it
2482// without changing any indexes or introducing interior mutability, we implement TensorRef
2483// correctly as well.
2484/**
2485 * RecordTensor implements TensorRef when the source does, returning references to the tuples
2486 * of `T` and [`Index`](Index).
2487 */
2488unsafe impl<'a, T, S, const D: usize> TensorRef<(T, Index), D> for RecordTensor<'a, T, S, D>
2489where
2490 T: Primitive,
2491 S: TensorRef<(T, Index), D>,
2492{
2493 fn get_reference(&self, indexes: [usize; D]) -> Option<&(T, Index)> {
2494 self.numbers.source_ref().get_reference(indexes)
2495 }
2496
2497 fn view_shape(&self) -> [(Dimension, usize); D] {
2498 self.numbers.source_ref().view_shape()
2499 }
2500
2501 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &(T, Index) {
2502 unsafe { self.numbers.source_ref().get_reference_unchecked(indexes) }
2503 }
2504
2505 fn data_layout(&self) -> DataLayout<D> {
2506 self.numbers.source_ref().data_layout()
2507 }
2508}
2509
2510// # Safety
2511//
2512// Our inner `numbers` tensor has to implement TensorMut correctly so by delegating to it
2513// without changing any indexes or introducing interior mutability, we implement TensorMut
2514// correctly as well.
2515/**
2516 * RecordTensor implements TensorMut when the source does, returning mutable references to the
2517 * tuples of `T` and [`Index`](Index).
2518 */
2519unsafe impl<'a, T, S, const D: usize> TensorMut<(T, Index), D> for RecordTensor<'a, T, S, D>
2520where
2521 T: Primitive,
2522 S: TensorMut<(T, Index), D>,
2523{
2524 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut (T, Index)> {
2525 self.numbers.source_ref_mut().get_reference_mut(indexes)
2526 }
2527
2528 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut (T, Index) {
2529 unsafe {
2530 self.numbers
2531 .source_ref_mut()
2532 .get_reference_unchecked_mut(indexes)
2533 }
2534 }
2535}
2536
2537// # Safety
2538//
2539// Our inner `numbers` matrix has to implement MatrixRef correctly so by delegating to it
2540// without changing any indexes or introducing interior mutability, we implement MatrixRef
2541// correctly as well.
2542/**
2543 * RecordMatrix implements MatrixRef when the source does, returning references to the tuples
2544 * of `T` and [`Index`](Index).
2545 */
2546unsafe impl<'a, T, S> MatrixRef<(T, Index)> for RecordMatrix<'a, T, S>
2547where
2548 T: Primitive,
2549 S: MatrixRef<(T, Index)>,
2550{
2551 fn try_get_reference(&self, row: Row, column: Column) -> Option<&(T, Index)> {
2552 self.numbers.source_ref().try_get_reference(row, column)
2553 }
2554
2555 fn view_rows(&self) -> Row {
2556 self.numbers.source_ref().view_rows()
2557 }
2558
2559 fn view_columns(&self) -> Column {
2560 self.numbers.source_ref().view_columns()
2561 }
2562
2563 unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &(T, Index) {
2564 unsafe {
2565 self.numbers
2566 .source_ref()
2567 .get_reference_unchecked(row, column)
2568 }
2569 }
2570
2571 fn data_layout(&self) -> crate::matrices::views::DataLayout {
2572 self.numbers.source_ref().data_layout()
2573 }
2574}
2575
2576// # Safety
2577//
2578// Our inner `numbers` matrix has to implement NoInteriorMutability correctly so by delegating to
2579// it without introducing interior mutability, we implement NoInteriorMutability
2580// correctly as well.
2581/**
2582 * RecordMatrix implements NoInteriorMutability when the source does.
2583 */
2584unsafe impl<'a, T, S> NoInteriorMutability for RecordMatrix<'a, T, S>
2585where
2586 T: Primitive,
2587 S: NoInteriorMutability,
2588{
2589}
2590
2591// # Safety
2592//
2593// Our inner `numbers` matrix has to implement MatrixMut correctly so by delegating to it
2594// without changing any indexes or introducing interior mutability, we implement MatrixMut
2595// correctly as well.
2596/**
2597 * RecordMatrix implements MatrixMut when the source does, returning mutable references to the
2598 * tuples of `T` and [`Index`].
2599 */
2600unsafe impl<'a, T, S> MatrixMut<(T, Index)> for RecordMatrix<'a, T, S>
2601where
2602 T: Primitive,
2603 S: MatrixMut<(T, Index)>,
2604{
2605 fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut (T, Index)> {
2606 self.numbers
2607 .source_ref_mut()
2608 .try_get_reference_mut(row, column)
2609 }
2610
2611 unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut (T, Index) {
2612 unsafe {
2613 self.numbers
2614 .source_ref_mut()
2615 .get_reference_unchecked_mut(row, column)
2616 }
2617 }
2618}
2619
2620/**
2621 * A zero dimensional record tensor can be converted losslessly into a record.
2622 */
2623impl<'a, T, S> From<RecordTensor<'a, T, S, 0>> for Record<'a, T>
2624where
2625 T: Numeric + Primitive,
2626 S: TensorRef<(T, Index), 0>,
2627{
2628 /**
2629 * Converts the sole element in the zero dimensional record tensor into a record.
2630 */
2631 fn from(scalar: RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
2632 // Not a good way to make this zero copy and just move the data out of the scalar because
2633 // TensorRef API doesn't have by value moves of the data and using TensorOwnedIterator
2634 // requires T: Default and a dummy value (at which point a clone is probably cheaper or
2635 // basically the same?)
2636 Record::from(&scalar)
2637 }
2638}
2639
2640/**
2641 * A zero dimensional record tensor can be converted losslessly into a record.
2642 */
2643impl<'a, T, S> From<&RecordTensor<'a, T, S, 0>> for Record<'a, T>
2644where
2645 T: Numeric + Primitive,
2646 S: TensorRef<(T, Index), 0>,
2647{
2648 /**
2649 * Converts the sole element in the zero dimensional record tensor into a record.
2650 */
2651 fn from(scalar: &RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
2652 Record::from_existing(scalar.view().scalar(), scalar.history)
2653 }
2654}
2655
2656/**
2657 * A record can be converted losslessly into a zero dimensional record tensor.
2658 */
2659impl<'a, T> From<Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
2660where
2661 T: Numeric + Primitive,
2662{
2663 /**
2664 * Converts a record into a zero dimensional record tensor with the single element.
2665 */
2666 fn from(record: Record<'a, T>) -> RecordTensor<'a, T, Tensor<(T, Index), 0>, 0> {
2667 RecordTensor::from_existing(
2668 record.history,
2669 TensorView::from(Tensor::from([], vec![(record.number, record.index)])),
2670 )
2671 }
2672}
2673
2674/**
2675 * A record can be converted losslessly into a zero dimensional record tensor.
2676 */
2677impl<'a, T> From<&Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
2678where
2679 T: Numeric + Primitive,
2680{
2681 /**
2682 * Converts a record into a zero dimensional record tensor with the single element.
2683 */
2684 fn from(record: &Record<'a, T>) -> RecordTensor<'a, T, Tensor<(T, Index), 0>, 0> {
2685 RecordTensor::from_existing(
2686 record.history,
2687 TensorView::from(Tensor::from(
2688 [],
2689 vec![(record.number.clone(), record.index)],
2690 )),
2691 )
2692 }
2693}
2694
/**
 * Multiplies a 4x3 tensor by its transpose twice, once as tensors of individual records and
 * once as record tensors, then checks that both routes agree on the resulting numbers, the
 * derivatives with respect to each input, and the number of operations recorded.
 */
#[test]
fn matrix_multiplication_derivatives_are_the_same() {
    #[rustfmt::skip]
    let left = Tensor::from(
        [("r", 4), ("c", 3)],
        vec![
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
            0.0, 5.0, 2.0
        ]
    );
    let right = left.transpose(["c", "r"]);
    let records_history = WengertList::new();
    let container_history: WengertList<f64> = WengertList::new();

    // First route: tensors where every element is its own Record
    let records_left = left.map(|x| Record::variable(x, &records_history));
    let records_right = right.map(|x| Record::variable(x, &records_history));
    let records_product = &records_left * &records_right;

    // Second route: record tensors, tracked on a separate history
    let container_left = RecordTensor::variables(&container_history, left);
    let container_right = RecordTensor::variables(&container_history, right);
    let container_product = &container_left * &container_right;

    // The product should be calculated the same in terms of the actual number
    let product_numbers = records_product.map(|record| record.number);
    assert_eq!(
        product_numbers,
        TensorView::from(&container_product).map(|(number, _index)| number)
    );

    // Both sets of derivatives are queried with the record tensors as the inputs.
    // NOTE(review): presumably this works for the tensor of records route because both
    // histories were populated in the same order - confirm.
    let derivatives_via_records = records_product.map(|record| record.derivatives());
    let left_derivatives_via_records =
        derivatives_via_records.map(|d| d.at_tensor(&container_left));
    let right_derivatives_via_records =
        derivatives_via_records.map(|d| d.at_tensor(&container_right));

    let derivatives_via_container = container_product.derivatives().unwrap();
    let left_derivatives_via_container =
        derivatives_via_container.map(|d| d.at_tensor(&container_left));
    let right_derivatives_via_container =
        derivatives_via_container.map(|d| d.at_tensor(&container_right));

    // Every calculated derivative should match exactly
    assert_eq!(left_derivatives_via_records, left_derivatives_via_container);
    assert_eq!(right_derivatives_via_records, right_derivatives_via_container);

    // Verify the product is actually calculated correctly, first against literal values
    #[rustfmt::skip]
    assert_eq!(
        product_numbers,
        Tensor::from(
            [("r", 4), ("c", 4)],
            vec![
                14.0,  32.0,  50.0, 16.0,
                32.0,  77.0, 122.0, 37.0,
                50.0, 122.0, 194.0, 58.0,
                16.0,  37.0,  58.0, 29.0
            ]
        )
    );
    // and then against each dot product written out in full
    #[rustfmt::skip]
    assert_eq!(
        product_numbers,
        Tensor::from(
            [("r", 4), ("c", 4)],
            vec![
                (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
                (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
                (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
                (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),

                (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
                (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
                (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
                (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),

                (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
                (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
                (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
                (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),

                (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
                (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
                (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
                (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
            ]
        )
    );

    // Both histories should have recorded the same number of operations
    let operations_via_records = records_history.operations.borrow().clone();
    let operations_via_container = container_history.operations.borrow().clone();
    assert_eq!(
        operations_via_records.len(),
        operations_via_container.len()
    );
}
2789
/**
 * Multiplies a 4x3 matrix by its transpose twice, once as matrices of individual records and
 * once as record matrices, then checks that both routes agree on the resulting numbers, the
 * derivatives with respect to each input, and the number of operations recorded.
 */
#[test]
fn matrix_view_matrix_multiplication_derivatives_are_the_same() {
    #[rustfmt::skip]
    let left = Matrix::from(vec![
        vec![ 1.0, 2.0, 3.0 ],
        vec![ 4.0, 5.0, 6.0 ],
        vec![ 7.0, 8.0, 9.0 ],
        vec![ 0.0, 5.0, 2.0 ]
    ]);
    let right = left.transpose();
    let records_history = WengertList::new();
    let container_history: WengertList<f64> = WengertList::new();

    // First route: matrices where every element is its own Record
    let records_left = left.map(|x| Record::variable(x, &records_history));
    let records_right = right.map(|x| Record::variable(x, &records_history));
    let records_product = &records_left * &records_right;

    // Second route: record matrices, tracked on a separate history
    let container_left = RecordMatrix::variables(&container_history, left);
    let container_right = RecordMatrix::variables(&container_history, right);
    let container_product = &container_left * &container_right;

    // The product should be calculated the same in terms of the actual number
    let product_numbers = records_product.map(|record| record.number);
    assert_eq!(
        product_numbers,
        MatrixView::from(&container_product).map(|(number, _index)| number)
    );

    // Both sets of derivatives are queried with the record matrices as the inputs.
    // NOTE(review): presumably this works for the matrix of records route because both
    // histories were populated in the same order - confirm.
    let derivatives_via_records = records_product.map(|record| record.derivatives());
    let left_derivatives_via_records =
        derivatives_via_records.map(|d| d.at_matrix(&container_left));
    let right_derivatives_via_records =
        derivatives_via_records.map(|d| d.at_matrix(&container_right));

    let derivatives_via_container = container_product.derivatives().unwrap();
    let left_derivatives_via_container =
        derivatives_via_container.map(|d| d.at_matrix(&container_left));
    let right_derivatives_via_container =
        derivatives_via_container.map(|d| d.at_matrix(&container_right));

    // Every calculated derivative should match exactly
    assert_eq!(left_derivatives_via_records, left_derivatives_via_container);
    assert_eq!(right_derivatives_via_records, right_derivatives_via_container);

    // Verify the product is actually calculated correctly, first against literal values
    #[rustfmt::skip]
    assert_eq!(
        product_numbers,
        Matrix::from(vec![
            vec![ 14.0,  32.0,  50.0, 16.0 ],
            vec![ 32.0,  77.0, 122.0, 37.0 ],
            vec![ 50.0, 122.0, 194.0, 58.0 ],
            vec![ 16.0,  37.0,  58.0, 29.0 ]
        ]
        )
    );
    // and then against each dot product written out in full
    #[rustfmt::skip]
    assert_eq!(
        product_numbers,
        Matrix::from(vec![
            vec![
                (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
                (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
                (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
                (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),
            ],
            vec![
                (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
                (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
                (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
                (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),
            ],
            vec![
                (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
                (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
                (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
                (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),
            ],
            vec![
                (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
                (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
                (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
                (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
            ]
        ])
    );

    // Both histories should have recorded the same number of operations
    let operations_via_records = records_history.operations.borrow().clone();
    let operations_via_container = container_history.operations.borrow().clone();
    assert_eq!(
        operations_via_records.len(),
        operations_via_container.len()
    );
}