use crate::differentiation::functions::{Division, FunctionDerivative, Multiplication};
use crate::differentiation::iterators::{
AsRecords, InconsistentHistory, InvalidRecordIteratorError,
};
use crate::differentiation::record_operations;
use crate::differentiation::{Derivatives, Index, Primitive, Record, WengertList};
use crate::matrices::iterators::{
ColumnMajorIterator, RowMajorIterator, RowMajorOwnedIterator, RowMajorReferenceMutIterator,
};
use crate::matrices::views::{MatrixMut, MatrixRef, MatrixView, NoInteriorMutability};
use crate::matrices::{Column, Matrix, Row};
use crate::numeric::{Numeric, NumericRef};
use crate::tensors::indexing::{
TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceMutIterator,
};
use crate::tensors::views::{DataLayout, TensorMut, TensorRef, TensorRename, TensorView};
use crate::tensors::{Dimension, Tensor};
mod container_operations;
pub mod iterators;
/// A container of numbers that are each paired with an [`Index`], plus an optional
/// reference to the [`WengertList`] those indexes refer to.
///
/// The `RecordMatrix` and `RecordTensor` aliases in this file instantiate `S` with a
/// `MatrixView` or `TensorView` over `(T, Index)` pairs respectively.
#[derive(Debug)]
pub struct RecordContainer<'a, T: Primitive, S, const D: usize> {
// The stored `(number, index)` data (a matrix or tensor view in the aliases of this file).
numbers: S,
// The shared history for every element; `None` is what the `constants` constructors use.
history: Option<&'a WengertList<T>>,
}
/// A [`RecordContainer`] whose data is a `MatrixView` over `(T, Index)` pairs; the
/// dimensionality parameter is fixed to 2.
pub type RecordMatrix<'a, T, S> = RecordContainer<'a, T, MatrixView<(T, Index), S>, 2>;
/// A [`RecordContainer`] whose data is a `D` dimensional `TensorView` over `(T, Index)`
/// pairs.
pub type RecordTensor<'a, T, S, const D: usize> =
RecordContainer<'a, T, TensorView<(T, Index), S, D>, D>;
/// Returns the `total` consecutive indexes `starting_index`, `starting_index + 1`, ...
/// in ascending order. An empty vector is returned when `total` is 0.
fn calculate_incrementing_indexes(starting_index: usize, total: usize) -> Vec<Index> {
    (0..total).map(|offset| starting_index + offset).collect()
}
impl<'a, T, const D: usize> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
where
    T: Numeric + Primitive,
{
    /// Creates a record tensor of constants from `c`.
    ///
    /// Every number taken from `c` is paired with the index `0`, and the resulting
    /// container has no history.
    pub fn constants<S>(c: S) -> Self
    where
        S: TensorMut<T, D>,
    {
        // The shape must be read before `c` is consumed by the owned iterator.
        let shape = c.view_shape();
        let data: Vec<(T, Index)> = TensorOwnedIterator::from_numeric(c)
            .map(|number| (number, 0))
            .collect();
        RecordContainer {
            numbers: TensorView::from(Tensor::from(shape, data)),
            history: None,
        }
    }

    /// Creates a record tensor of variables from `x`, tracked by `history`.
    ///
    /// One nullary entry per element is requested from `history` via
    /// `append_nullary_repeating`, and the elements of `x` are paired with consecutive
    /// indexes counting up from the starting index that call returns.
    pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
    where
        S: TensorMut<T, D>,
    {
        // The shape must be read before `x` is consumed by the owned iterator.
        let shape = x.view_shape();
        let total = crate::tensors::dimensions::elements(&shape);
        let starting_index = history.append_nullary_repeating(total);
        let data: Vec<(T, Index)> = TensorOwnedIterator::from_numeric(x)
            .zip(calculate_incrementing_indexes(starting_index, total))
            .collect();
        RecordContainer {
            numbers: TensorView::from(Tensor::from(shape, data)),
            history: Some(history),
        }
    }
}
impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
where
    T: Numeric + Primitive,
    S: TensorRef<(T, Index), D>,
{
    /// Returns the number of elements in this record tensor, as computed from its shape.
    pub fn elements(&self) -> usize {
        let shape = self.numbers.shape();
        crate::tensors::dimensions::elements(&shape)
    }

    /// Returns the shape of this record tensor as `(dimension name, length)` pairs.
    pub fn shape(&self) -> [(Dimension, usize); D] {
        self.numbers.shape()
    }

    /// Builds a record tensor directly from an optional history and a tensor view of
    /// `(number, index)` pairs. No validation is performed here.
    pub fn from_existing(
        history: Option<&'a WengertList<T>>,
        numbers: TensorView<(T, Index), S, D>,
    ) -> Self {
        Self { numbers, history }
    }

    /// Consumes this record tensor and returns one whose source is wrapped in a
    /// `TensorRename` built with the provided `dimensions`. The history is carried over.
    #[track_caller]
    pub fn rename_view(
        self,
        dimensions: [Dimension; D],
    ) -> RecordTensor<'a, T, TensorRename<(T, Index), S, D>, D> {
        let history = self.history;
        let renamed = TensorRename::from(self.numbers.source(), dimensions);
        RecordTensor::from_existing(history, TensorView::from(renamed))
    }

    /// Returns a `TensorView` over a reference to this record tensor, exposing the
    /// `(number, index)` pairs.
    pub fn view(&self) -> TensorView<(T, Index), &RecordTensor<'a, T, S, D>, D> {
        TensorView::from(self)
    }

    /// Returns a `TensorAccess` over a reference to this record tensor using the
    /// provided dimension order.
    #[track_caller]
    pub fn index_by(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
        TensorAccess::from(self, dimensions)
    }

    /// Returns a `TensorAccess` over a reference to this record tensor using the order
    /// of the source's own dimensions.
    pub fn index(&self) -> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D> {
        TensorAccess::from_source_order(self)
    }

    /// Returns an `AsRecords` iterator built from this record tensor.
    #[allow(clippy::type_complexity)]
    pub fn iter_as_records<'b>(
        &'b self,
    ) -> AsRecords<'a, TensorIterator<'b, (T, Index), RecordTensor<'a, T, S, D>, D>, T> {
        AsRecords::from_tensor(self)
    }
}
impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
where
    T: Numeric + Primitive,
    S: TensorMut<(T, Index), D>,
{
    /// Gives every element a fresh index in the history, leaving the numbers untouched.
    ///
    /// Does nothing when this container has no history. Otherwise one nullary entry per
    /// element is requested via `append_nullary_repeating` and each element's index is
    /// overwritten with the next consecutive index from the returned starting index.
    pub fn reset(&mut self) {
        if let Some(history) = self.history {
            let total = self.elements();
            let starting_index = history.append_nullary_repeating(total);
            let fresh_indexes = calculate_incrementing_indexes(starting_index, total);
            for ((_, index), fresh) in self.numbers.iter_reference_mut().zip(fresh_indexes) {
                *index = fresh;
            }
        }
    }

    /// By-value version of [`reset`](Self::reset): resets `x` and returns it.
    pub fn do_reset(mut x: Self) -> Self {
        x.reset();
        x
    }
}
impl<'a, T> RecordMatrix<'a, T, Matrix<(T, Index)>>
where
    T: Numeric + Primitive,
{
    /// Creates a record matrix of constants from `c`.
    ///
    /// Every number taken from `c` (in row major order) is paired with the index `0`,
    /// and the resulting container has no history.
    pub fn constants<S>(c: S) -> Self
    where
        S: MatrixMut<T> + NoInteriorMutability,
    {
        // The size must be read before `c` is consumed by the owned iterator.
        let size = (c.view_rows(), c.view_columns());
        let data: Vec<(T, Index)> = RowMajorOwnedIterator::from_numeric(c)
            .map(|number| (number, 0))
            .collect();
        RecordContainer {
            numbers: MatrixView::from(Matrix::from_flat_row_major(size, data)),
            history: None,
        }
    }

    /// Creates a record matrix of variables from `x`, tracked by `history`.
    ///
    /// One nullary entry per element is requested from `history` via
    /// `append_nullary_repeating`, and the elements of `x` (in row major order) are
    /// paired with consecutive indexes counting up from the returned starting index.
    pub fn variables<S>(history: &'a WengertList<T>, x: S) -> Self
    where
        S: MatrixMut<T> + NoInteriorMutability,
    {
        // The size must be read before `x` is consumed by the owned iterator.
        let rows = x.view_rows();
        let columns = x.view_columns();
        let total = rows * columns;
        let starting_index = history.append_nullary_repeating(total);
        let data: Vec<(T, Index)> = RowMajorOwnedIterator::from_numeric(x)
            .zip(calculate_incrementing_indexes(starting_index, total))
            .collect();
        RecordContainer {
            numbers: MatrixView::from(Matrix::from_flat_row_major((rows, columns), data)),
            history: Some(history),
        }
    }
}
impl<'a, T, S> RecordMatrix<'a, T, S>
where
T: Numeric + Primitive,
S: MatrixRef<(T, Index)> + NoInteriorMutability,
{
pub fn elements(&self) -> usize {
self.numbers.rows() * self.numbers.columns()
}
pub fn size(&self) -> (Row, Column) {
self.numbers.size()
}
pub fn rows(&self) -> Row {
self.numbers.rows()
}
pub fn columns(&self) -> Column {
self.numbers.columns()
}
pub fn from_existing(
history: Option<&'a WengertList<T>>,
numbers: MatrixView<(T, Index), S>,
) -> Self {
RecordContainer { numbers, history }
}
pub fn view(&self) -> MatrixView<(T, Index), &RecordMatrix<'a, T, S>> {
MatrixView::from(self)
}
#[allow(clippy::type_complexity)]
pub fn iter_row_major_as_records<'b>(
&'b self,
) -> AsRecords<'a, RowMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
AsRecords::from_matrix_row_major(self)
}
#[allow(clippy::type_complexity)]
pub fn iter_column_major_as_records<'b>(
&'b self,
) -> AsRecords<'a, ColumnMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T> {
AsRecords::from_matrix_column_major(self)
}
#[track_caller]
pub fn get_as_record(&self, row: Row, column: Column) -> Record<'a, T> {
Record::from_existing(self.numbers.get(row, column), self.history)
}
pub fn try_get_as_record(&self, row: Row, column: Column) -> Option<Record<'a, T>> {
self.numbers
.try_get_reference(row, column)
.map(|r| Record::from_existing(r.clone(), self.history))
}
}
impl<'a, T, S> RecordMatrix<'a, T, S>
where
    T: Numeric + Primitive,
    S: MatrixMut<(T, Index)> + NoInteriorMutability,
{
    /// Gives every element a fresh index in the history, leaving the numbers untouched.
    ///
    /// Does nothing when this container has no history. Otherwise one nullary entry per
    /// element is requested via `append_nullary_repeating` and, in row major order,
    /// each element's index is overwritten with the next consecutive index from the
    /// returned starting index.
    pub fn reset(&mut self) {
        if let Some(history) = self.history {
            let total = self.elements();
            let starting_index = history.append_nullary_repeating(total);
            let fresh_indexes = calculate_incrementing_indexes(starting_index, total);
            for ((_, index), fresh) in self
                .numbers
                .row_major_reference_mut_iter()
                .zip(fresh_indexes)
            {
                *index = fresh;
            }
        }
    }

    /// By-value version of [`reset`](Self::reset): resets `x` and returns it.
    pub fn do_reset(mut x: Self) -> Self {
        x.reset();
        x
    }
}
impl<'a, T, S, const D: usize> RecordContainer<'a, T, S, D>
where
    T: Primitive,
{
    /// Returns the history this container holds, or `None` if it has none.
    pub fn history(&self) -> Option<&'a WengertList<T>> {
        let Self { history, .. } = self;
        *history
    }
}
/// Applies a unary function to every `(number, index)` pair in `records`, recording
/// each result on `history`.
///
/// For each input `x` with index `parent`, `fx(x)` is the output number and
/// `dfx_dx(x)` is appended to the history with `append_unary(parent, ..)`; the index
/// that call returns is paired with the output.
///
/// The output vector always has `total` entries. Slots that `records` does not reach
/// keep their initial `(T::zero(), 0)` value, and indexing panics if `records` yields
/// more than `total` items.
fn unary<'a, T, I>(
    total: usize,
    history: &WengertList<T>,
    records: I,
    fx: impl Fn(T) -> T,
    dfx_dx: impl Fn(T) -> T,
) -> Vec<(T, usize)>
where
    I: Iterator<Item = (T, Index)>,
    T: Numeric + Primitive,
    for<'t> &'t T: NumericRef<T>,
{
    let mut outputs = vec![(T::zero(), 0); total];
    // All appends happen inside a single borrow of the history.
    history.borrow(|list| {
        for (position, (input, parent)) in records.enumerate() {
            let output = fx(input.clone());
            let slope = dfx_dx(input);
            outputs[position] = (output, list.append_unary(parent, slope));
        }
    });
    outputs
}
/// Applies a binary function elementwise to two record iterators, recording each
/// result on `history` against both inputs.
///
/// For each pair `x` (index `parent1`) and `y` (index `parent2`), `fxy(x, y)` is the
/// output number, and `dfxy_dx(x, y)` and `dfxy_dy(x, y)` are appended to the history
/// with `append_binary`; the index that call returns is paired with the output.
///
/// The output vector always has `total` entries. Slots the zipped iterators do not
/// reach keep their initial `(T::zero(), 0)` value, and indexing panics if they yield
/// more than `total` pairs.
fn binary_both_history<'a, T, I1, I2>(
    total: usize,
    history: &WengertList<T>,
    x_records: I1,
    y_records: I2,
    fxy: impl Fn(T, T) -> T,
    dfxy_dx: impl Fn(T, T) -> T,
    dfxy_dy: impl Fn(T, T) -> T,
) -> Vec<(T, usize)>
where
    I1: Iterator<Item = (T, Index)>,
    I2: Iterator<Item = (T, Index)>,
    T: Numeric + Primitive,
    for<'t> &'t T: NumericRef<T>,
{
    let mut outputs = vec![(T::zero(), 0); total];
    // All appends happen inside a single borrow of the history.
    history.borrow(|list| {
        let pairs = x_records.zip(y_records);
        for (position, ((left, left_parent), (right, right_parent))) in pairs.enumerate() {
            let output = fxy(left.clone(), right.clone());
            let left_slope = dfxy_dx(left.clone(), right.clone());
            let right_slope = dfxy_dy(left, right);
            outputs[position] = (
                output,
                list.append_binary(left_parent, left_slope, right_parent, right_slope),
            );
        }
    });
    outputs
}
/// Applies a binary function elementwise to two record iterators, recording each
/// result on `history` against the left input only.
///
/// For each pair `x` (index `parent1`) and `y`, `fxy(x, y)` is the output number and
/// `dfxy_dx(x, y)` is appended to the history with `append_unary(parent1, ..)`. The
/// indexes carried by `y_records` are ignored.
///
/// The output vector always has `total` entries. Slots the zipped iterators do not
/// reach keep their initial `(T::zero(), 0)` value, and indexing panics if they yield
/// more than `total` pairs.
fn binary_x_history<'a, T, I1, I2>(
    total: usize,
    history: &WengertList<T>,
    x_records: I1,
    y_records: I2,
    fxy: impl Fn(T, T) -> T,
    dfxy_dx: impl Fn(T, T) -> T,
) -> Vec<(T, usize)>
where
    I1: Iterator<Item = (T, Index)>,
    I2: Iterator<Item = (T, Index)>,
    T: Numeric + Primitive,
    for<'t> &'t T: NumericRef<T>,
{
    let mut outputs = vec![(T::zero(), 0); total];
    // All appends happen inside a single borrow of the history.
    history.borrow(|list| {
        let pairs = x_records.zip(y_records);
        for (position, ((left, left_parent), (right, _))) in pairs.enumerate() {
            let output = fxy(left.clone(), right.clone());
            let left_slope = dfxy_dx(left, right);
            outputs[position] = (output, list.append_unary(left_parent, left_slope));
        }
    });
    outputs
}
/// Applies a binary function elementwise to two record iterators, recording each
/// result on `history` against the right input only.
///
/// For each pair `x` and `y` (index `parent2`), `fxy(x, y)` is the output number and
/// `dfxy_dy(x, y)` is appended to the history with `append_unary(parent2, ..)`. The
/// indexes carried by `x_records` are ignored.
///
/// The output vector always has `total` entries. Slots the zipped iterators do not
/// reach keep their initial `(T::zero(), 0)` value, and indexing panics if they yield
/// more than `total` pairs.
fn binary_y_history<'a, T, I1, I2>(
    total: usize,
    history: &WengertList<T>,
    x_records: I1,
    y_records: I2,
    fxy: impl Fn(T, T) -> T,
    dfxy_dy: impl Fn(T, T) -> T,
) -> Vec<(T, usize)>
where
    I1: Iterator<Item = (T, Index)>,
    I2: Iterator<Item = (T, Index)>,
    T: Numeric + Primitive,
    for<'t> &'t T: NumericRef<T>,
{
    let mut outputs = vec![(T::zero(), 0); total];
    // All appends happen inside a single borrow of the history.
    history.borrow(|list| {
        let pairs = x_records.zip(y_records);
        for (position, ((left, _), (right, right_parent))) in pairs.enumerate() {
            let output = fxy(left.clone(), right.clone());
            let right_slope = dfxy_dy(left, right);
            outputs[position] = (output, list.append_unary(right_parent, right_slope));
        }
    });
    outputs
}
impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
where
    T: Numeric + Primitive,
    for<'t> &'t T: NumericRef<T>,
    S: TensorRef<(T, Index), D>,
{
    /// Applies the function `fx` with derivative `dfx_dx` to every element, returning
    /// a new record tensor of the same shape.
    ///
    /// Without a history the result is a tensor of constants and `dfx_dx` is never
    /// called. With a history every result is recorded on it and the output shares
    /// that history.
    #[track_caller]
    pub fn unary(
        &self,
        fx: impl Fn(T) -> T,
        dfx_dx: impl Fn(T) -> T,
    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D> {
        if let Some(history) = self.history {
            let ys = unary::<T, _>(
                self.elements(),
                history,
                self.numbers.iter(),
                fx,
                dfx_dx,
            );
            RecordContainer {
                numbers: self.numbers.new_with_same_shape(ys),
                history: Some(history),
            }
        } else {
            RecordTensor::constants(self.numbers.map(|(x, _)| fx(x)))
        }
    }

    /// Applies the function `fxy` elementwise to this tensor and `rhs`, with partial
    /// derivatives `dfxy_dx` and `dfxy_dy`, returning a new record tensor.
    ///
    /// When neither input has a history the result is a tensor of constants. When only
    /// one input has a history only the derivative with respect to that input is
    /// recorded. When both have one, both derivatives are recorded.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes are not equal, or if both inputs have a history and
    /// `record_operations::same_lists` reports them as different.
    #[track_caller]
    pub fn binary<S2>(
        &self,
        rhs: &RecordTensor<'a, T, S2, D>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
    where
        S2: TensorRef<(T, Index), D>,
    {
        let left_shape = self.numbers.shape();
        let right_shape = rhs.numbers.shape();
        if left_shape != right_shape {
            panic!(
                "Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
                left_shape, right_shape
            );
        }
        let total = self.elements();
        // Every case with a history yields the history to attach plus the recorded
        // results; the constants-only case returns early.
        let (history, zs) = match (self.history, rhs.history) {
            (None, None) => {
                let constants = Tensor::from(
                    self.numbers.shape(),
                    self.numbers
                        .iter()
                        .zip(rhs.numbers.iter())
                        .map(|((x, _), (y, _))| fxy(x, y))
                        .collect(),
                );
                return RecordTensor::constants(constants);
            }
            (Some(history), None) => (
                history,
                binary_x_history::<T, _, _>(
                    total,
                    history,
                    self.numbers.iter(),
                    rhs.numbers.iter(),
                    fxy,
                    dfxy_dx,
                ),
            ),
            (None, Some(history)) => (
                history,
                binary_y_history::<T, _, _>(
                    total,
                    history,
                    self.numbers.iter(),
                    rhs.numbers.iter(),
                    fxy,
                    dfxy_dy,
                ),
            ),
            (Some(history), Some(other_history)) => {
                assert!(
                    record_operations::same_lists(history, other_history),
                    "Record containers must be using the same WengertList"
                );
                (
                    history,
                    binary_both_history::<T, _, _>(
                        total,
                        history,
                        self.numbers.iter(),
                        rhs.numbers.iter(),
                        fxy,
                        dfxy_dx,
                        dfxy_dy,
                    ),
                )
            }
        };
        RecordContainer {
            numbers: self.numbers.new_with_same_shape(zs),
            history: Some(history),
        }
    }

    /// Maps every element, viewed as a `Record`, through `fx` and collects the results
    /// into a new record tensor of the same shape.
    ///
    /// # Errors
    ///
    /// Returns the `InconsistentHistory` reported by `RecordTensor::from_iter` when the
    /// mapped records cannot be collected under one history.
    #[allow(clippy::type_complexity)]
    #[track_caller]
    pub fn map(
        &self,
        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
    ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
        let mapped = self.iter_as_records().map(fx);
        let collected = RecordTensor::from_iter(self.shape(), mapped);
        RecordTensor::<'a, T, S, D>::map_collection(collected, self.shape())
    }

    /// Like [`map`](Self::map), but `fx` additionally receives the index of each
    /// element as yielded by the iterator's `with_index`.
    ///
    /// # Errors
    ///
    /// Returns the `InconsistentHistory` reported by `RecordTensor::from_iter` when the
    /// mapped records cannot be collected under one history.
    #[allow(clippy::type_complexity)]
    #[track_caller]
    pub fn map_with_index(
        &self,
        fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
    ) -> Result<RecordTensor<'a, T, Tensor<(T, Index), D>, D>, InconsistentHistory<'a, T>> {
        let mapped = self
            .iter_as_records()
            .with_index()
            .map(|(index, record)| fx(index, record));
        let collected = RecordTensor::from_iter(self.shape(), mapped);
        RecordTensor::<'a, T, S, D>::map_collection(collected, self.shape())
    }

    /// Converts the result of collecting mapped records into the public error type.
    ///
    /// Only an inconsistent history is surfaced as an `Err`. The empty and shape
    /// mismatch errors are treated as illegal states and panic.
    #[track_caller]
    #[allow(clippy::type_complexity)]
    fn map_collection(
        result: Result<
            RecordTensor<'a, T, Tensor<(T, usize), D>, D>,
            InvalidRecordIteratorError<'a, T, D>,
        >,
        shape: [(Dimension, usize); D],
    ) -> Result<RecordTensor<'a, T, Tensor<(T, usize), D>, D>, InconsistentHistory<'a, T>> {
        use InvalidRecordIteratorError as Error;
        match result {
            Ok(tensor) => Ok(tensor),
            Err(Error::InconsistentHistory(inconsistency)) => Err(inconsistency),
            Err(Error::Empty) => panic!("Illegal state, record tensor was empty {:?}", shape),
            Err(Error::Shape { requested, length }) => panic!(
                "Illegal state, record tensor shape was inconsistent: requested: {:?}, length of data: {:?}",
                requested, length
            ),
        }
    }

    /// Computes the derivatives of every element, or returns `None` when this
    /// container has no history.
    pub fn derivatives(&self) -> Option<Tensor<Derivatives<T>, D>> {
        let history = self.history?;
        let all = self.numbers.map(|(number, index)| {
            Record {
                number,
                history: Some(history),
                index,
            }
            .derivatives()
        });
        Some(all)
    }

    /// Computes the derivatives of the single element at `indexes`.
    ///
    /// Returns `None` when there is no element at `indexes`, or when
    /// `Record::try_derivatives` returns `None` for that element.
    pub fn derivatives_for(&self, indexes: [usize; D]) -> Option<Derivatives<T>> {
        let (number, index) = self.get_reference(indexes)?;
        Record {
            number: number.clone(),
            history: self.history,
            index: *index,
        }
        .try_derivatives()
    }

    /// Elementwise multiplication with `other`, implemented as [`binary`](Self::binary)
    /// with the `Multiplication` function and its derivatives.
    pub fn elementwise_multiply<S2>(
        &self,
        other: &RecordTensor<'a, T, S2, D>,
    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
    where
        S2: TensorRef<(T, Index), D>,
    {
        self.binary(
            other,
            Multiplication::<T>::function,
            Multiplication::<T>::d_function_dx,
            Multiplication::<T>::d_function_dy,
        )
    }

    /// Elementwise division by `other`, implemented as [`binary`](Self::binary) with
    /// the `Division` function and its derivatives.
    pub fn elementwise_divide<S2>(
        &self,
        other: &RecordTensor<'a, T, S2, D>,
    ) -> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
    where
        S2: TensorRef<(T, Index), D>,
    {
        self.binary(
            other,
            Division::<T>::function,
            Division::<T>::d_function_dx,
            Division::<T>::d_function_dy,
        )
    }
}
impl<T: Clone + Primitive> Derivatives<T> {
    /// Looks up the derivative stored for the element of `input` at `indexes`.
    ///
    /// Returns `None` when `input` has no element at `indexes`. The stored derivatives
    /// are indexed directly with the element's index, which panics if that index is
    /// out of range for them.
    pub fn at_tensor_index<S, const D: usize>(
        &self,
        indexes: [usize; D],
        input: &RecordTensor<T, S, D>,
    ) -> Option<T>
    where
        S: TensorRef<(T, Index), D>,
    {
        let (_, index) = input.get_reference(indexes)?;
        Some(self.derivatives[*index].clone())
    }

    /// Looks up the derivative stored for every element of `input`, returning them as
    /// a tensor produced by mapping over `input`.
    pub fn at_tensor<S, const D: usize>(&self, input: &RecordTensor<T, S, D>) -> Tensor<T, D>
    where
        S: TensorRef<(T, Index), D>,
    {
        input
            .numbers
            .map(|(_, index)| self.derivatives[index].clone())
    }

    /// Looks up the derivative stored for the element of `input` at `row`, `column`.
    ///
    /// Returns `None` when `input` has no element at that position. The stored
    /// derivatives are indexed directly with the element's index, which panics if that
    /// index is out of range for them.
    pub fn at_matrix_index<S>(
        &self,
        row: Row,
        column: Column,
        input: &RecordMatrix<T, S>,
    ) -> Option<T>
    where
        S: MatrixRef<(T, Index)> + NoInteriorMutability,
    {
        let (_, index) = input.try_get_reference(row, column)?;
        Some(self.derivatives[*index].clone())
    }

    /// Looks up the derivative stored for every element of `input`, returning them as
    /// a matrix produced by mapping over `input`.
    pub fn at_matrix<S>(&self, input: &RecordMatrix<T, S>) -> Matrix<T>
    where
        S: MatrixRef<(T, Index)> + NoInteriorMutability,
    {
        input
            .numbers
            .map(|(_, index)| self.derivatives[index].clone())
    }
}
impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
where
T: Numeric + Primitive,
for<'t> &'t T: NumericRef<T>,
S: TensorMut<(T, Index), D>,
{
#[track_caller]
pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
let total = self.elements();
match self.history {
None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
Some(history) => {
let ys = unary::<T, _>(total, history, self.numbers.iter(), fx, dfx_dx);
for (element, result) in self.numbers.iter_reference_mut().zip(ys) {
*element = result;
}
self.history = Some(history);
}
}
}
#[track_caller]
pub fn binary_left_assign<S2>(
&mut self,
rhs: &RecordTensor<'a, T, S2, D>,
fxy: impl Fn(T, T) -> T,
dfxy_dx: impl Fn(T, T) -> T,
dfxy_dy: impl Fn(T, T) -> T,
) where
S2: TensorRef<(T, Index), D>,
{
{
let left_shape = self.numbers.shape();
let right_shape = rhs.numbers.shape();
if left_shape != right_shape {
panic!(
"Record containers must have the same shape for a binary operation: (left: {:?}, right: {:?})",
left_shape, right_shape
);
}
}
let total = self.elements();
match (self.history, rhs.history) {
(None, None) => {
for (x, y) in self.numbers.iter_reference_mut().zip(rhs.numbers.iter()) {
let (left, _) = x;
let (right, _) = y;
*x = (fxy(left.clone(), right), 0);
}
}
(Some(history), None) => {
let zs = binary_x_history::<T, _, _>(
total,
history,
self.numbers.iter(),
rhs.numbers.iter(),
fxy,
dfxy_dx,
);
for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
*element = result;
}
self.history = Some(history);
}
(None, Some(history)) => {
let zs = binary_y_history::<T, _, _>(
total,
history,
self.numbers.iter(),
rhs.numbers.iter(),
fxy,
dfxy_dy,
);
for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
*element = result;
}
self.history = Some(history);
}
(Some(history), Some(h)) => {
assert!(
record_operations::same_lists(history, h),
"Record containers must be using the same WengertList"
);
let zs = binary_both_history::<T, _, _>(
total,
history,
self.numbers.iter(),
rhs.numbers.iter(),
fxy,
dfxy_dx,
dfxy_dy,
);
for (element, result) in self.numbers.iter_reference_mut().zip(zs) {
*element = result;
}
self.history = Some(history);
}
}
}
#[track_caller]
pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
self.unary_assign(fx, dfx_dx);
self
}
#[track_caller]
pub fn do_binary_left_assign<S2>(
mut self,
rhs: &RecordTensor<'a, T, S2, D>,
fxy: impl Fn(T, T) -> T,
dfxy_dx: impl Fn(T, T) -> T,
dfxy_dy: impl Fn(T, T) -> T,
) -> Self
where
S2: TensorRef<(T, Index), D>,
{
self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
self
}
#[track_caller]
pub fn map_mut(
&mut self,
fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
) -> Result<(), InconsistentHistory<'a, T>> {
let history = self.history;
let new_history =
map_mut_base::<'a, T, _, _>(TensorReferenceMutIterator::from(self), |x| {
let record = Record::from_existing(x.clone(), history);
let result = fx(record);
*x = (result.number, result.index);
result.history
})?;
self.history = new_history;
Ok(())
}
#[track_caller]
pub fn map_mut_with_index(
&mut self,
fx: impl Fn([usize; D], Record<'a, T>) -> Record<'a, T>,
) -> Result<(), InconsistentHistory<'a, T>> {
let history = self.history;
let new_history = map_mut_base::<'a, T, _, _>(
TensorReferenceMutIterator::from(self).with_index(),
|(i, x)| {
let record = Record::from_existing(x.clone(), history);
let result = fx(i, record);
*x = (result.number, result.index);
result.history
},
)?;
self.history = new_history;
Ok(())
}
}
/// Runs `fx` on every item of `iter` and checks that all the histories it returns are
/// the exact same list as the one returned for the first item.
///
/// `fx` is always called on every item, even after a mismatch has been seen. If
/// several items disagree with the first history, the last disagreeing history is the
/// one reported.
///
/// # Errors
///
/// Returns `InconsistentHistory` holding the first history and the last one that
/// differed from it.
///
/// # Panics
///
/// Panics if `iter` yields no items.
#[track_caller]
fn map_mut_base<'a, T, I, X>(
    mut iter: I,
    fx: impl Fn(X) -> Option<&'a WengertList<T>>,
) -> Result<Option<&'a WengertList<T>>, InconsistentHistory<'a, T>>
where
    I: Iterator<Item = X>,
    T: Primitive,
{
    use crate::differentiation::record_operations::are_exact_same_list;
    let first_element = iter
        .next()
        .expect("Illegal state, record container was empty");
    let first_history = fx(first_element);
    // The fold deliberately visits every remaining item and overwrites any earlier
    // mismatch with a later one.
    let last_mismatch: Option<Option<&'a WengertList<T>>> =
        iter.fold(None, |mismatch, element| {
            let history = fx(element);
            if are_exact_same_list(history, first_history) {
                mismatch
            } else {
                Some(history)
            }
        });
    match last_mismatch {
        Some(later) => Err(InconsistentHistory {
            first: first_history,
            later,
        }),
        None => Ok(first_history),
    }
}
impl<'a, T, S, const D: usize> RecordTensor<'a, T, S, D>
where
    T: Numeric + Primitive,
    for<'t> &'t T: NumericRef<T>,
    S: TensorRef<(T, Index), D>,
{
    /// Computes `fxy(self element, rhs element)` elementwise and stores the results
    /// in `rhs`.
    ///
    /// Implemented by calling `binary_left_assign` on `rhs` with the argument order of
    /// `fxy` flipped and the two derivative functions swapped to match.
    #[track_caller]
    pub fn binary_right_assign<S2>(
        &self,
        rhs: &mut RecordTensor<'a, T, S2, D>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) where
        S2: TensorMut<(T, Index), D>,
    {
        // From the point of view of `rhs`, its own element comes first, so each
        // closure receives (right, left) and forwards them as (left, right).
        let flipped = |right, left| fxy(left, right);
        let d_first = |right, left| dfxy_dy(left, right);
        let d_second = |right, left| dfxy_dx(left, right);
        rhs.binary_left_assign(self, flipped, d_first, d_second)
    }

    /// By-value version of [`binary_right_assign`](Self::binary_right_assign): updates
    /// `rhs` and returns it.
    #[track_caller]
    pub fn do_binary_right_assign<S2>(
        &self,
        mut rhs: RecordTensor<'a, T, S2, D>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> RecordTensor<'a, T, S2, D>
    where
        S2: TensorMut<(T, Index), D>,
    {
        self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
        rhs
    }
}
impl<'a, T, S> RecordMatrix<'a, T, S>
where
    T: Numeric + Primitive,
    for<'t> &'t T: NumericRef<T>,
    S: MatrixRef<(T, Index)> + NoInteriorMutability,
{
    /// Applies the function `fx` with derivative `dfx_dx` to every element, returning
    /// a new record matrix of the same size.
    ///
    /// Without a history the result is a matrix of constants and `dfx_dx` is never
    /// called. With a history every result is recorded on it and the output shares
    /// that history.
    #[track_caller]
    pub fn unary(
        &self,
        fx: impl Fn(T) -> T,
        dfx_dx: impl Fn(T) -> T,
    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>> {
        if let Some(history) = self.history {
            let ys = unary::<T, _>(
                self.elements(),
                history,
                self.numbers.row_major_iter(),
                fx,
                dfx_dx,
            );
            let values = Matrix::from_flat_row_major(self.numbers.size(), ys);
            RecordContainer {
                numbers: MatrixView::from(values),
                history: Some(history),
            }
        } else {
            RecordMatrix::constants(self.numbers.map(|(x, _)| fx(x)))
        }
    }

    /// Applies the function `fxy` elementwise to this matrix and `rhs`, with partial
    /// derivatives `dfxy_dx` and `dfxy_dy`, returning a new record matrix.
    ///
    /// When neither input has a history the result is a matrix of constants. When only
    /// one input has a history only the derivative with respect to that input is
    /// recorded. When both have one, both derivatives are recorded.
    ///
    /// # Panics
    ///
    /// Panics if the two sizes are not equal, or if both inputs have a history and
    /// `record_operations::same_lists` reports them as different.
    #[track_caller]
    pub fn binary<S2>(
        &self,
        rhs: &RecordMatrix<'a, T, S2>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
    where
        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
    {
        let shape = self.numbers.size();
        let right_shape = rhs.numbers.size();
        if shape != right_shape {
            panic!(
                "Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
                shape, right_shape
            );
        }
        let total = self.elements();
        // Every case with a history yields the history to attach plus the recorded
        // results; the constants-only case returns early.
        let (history, zs) = match (self.history, rhs.history) {
            (None, None) => {
                let constants = Matrix::from_flat_row_major(
                    shape,
                    self.numbers
                        .row_major_iter()
                        .zip(rhs.numbers.row_major_iter())
                        .map(|((x, _), (y, _))| fxy(x, y))
                        .collect(),
                );
                return RecordMatrix::constants(constants);
            }
            (Some(history), None) => (
                history,
                binary_x_history::<T, _, _>(
                    total,
                    history,
                    self.numbers.row_major_iter(),
                    rhs.numbers.row_major_iter(),
                    fxy,
                    dfxy_dx,
                ),
            ),
            (None, Some(history)) => (
                history,
                binary_y_history::<T, _, _>(
                    total,
                    history,
                    self.numbers.row_major_iter(),
                    rhs.numbers.row_major_iter(),
                    fxy,
                    dfxy_dy,
                ),
            ),
            (Some(history), Some(other_history)) => {
                assert!(
                    record_operations::same_lists(history, other_history),
                    "Record containers must be using the same WengertList"
                );
                (
                    history,
                    binary_both_history::<T, _, _>(
                        total,
                        history,
                        self.numbers.row_major_iter(),
                        rhs.numbers.row_major_iter(),
                        fxy,
                        dfxy_dx,
                        dfxy_dy,
                    ),
                )
            }
        };
        RecordContainer {
            numbers: MatrixView::from(Matrix::from_flat_row_major(shape, zs)),
            history: Some(history),
        }
    }

    /// Maps every element, viewed as a `Record` in row major order, through `fx` and
    /// collects the results into a new record matrix of the same size.
    ///
    /// # Errors
    ///
    /// Returns the `InconsistentHistory` reported by `RecordMatrix::from_iter` when the
    /// mapped records cannot be collected under one history.
    #[allow(clippy::type_complexity)]
    #[track_caller]
    pub fn map(
        &self,
        fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
    ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
        let mapped = self.iter_row_major_as_records().map(fx);
        let collected = RecordMatrix::from_iter(self.size(), mapped);
        RecordMatrix::<'a, T, S>::map_collection(collected, self.size())
    }

    /// Like [`map`](Self::map), but `fx` additionally receives the row and column of
    /// each element as yielded by the iterator's `with_index`.
    ///
    /// # Errors
    ///
    /// Returns the `InconsistentHistory` reported by `RecordMatrix::from_iter` when the
    /// mapped records cannot be collected under one history.
    #[allow(clippy::type_complexity)]
    #[track_caller]
    pub fn map_with_index(
        &self,
        fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
    ) -> Result<RecordMatrix<'a, T, Matrix<(T, Index)>>, InconsistentHistory<'a, T>> {
        let mapped = self
            .iter_row_major_as_records()
            .with_index()
            .map(|((row, column), record)| fx(record, row, column));
        let collected = RecordMatrix::from_iter(self.size(), mapped);
        RecordMatrix::<'a, T, S>::map_collection(collected, self.size())
    }

    /// Converts the result of collecting mapped records into the public error type.
    ///
    /// Only an inconsistent history is surfaced as an `Err`. The empty and shape
    /// mismatch errors are treated as illegal states and panic.
    #[allow(clippy::type_complexity)]
    #[track_caller]
    fn map_collection(
        result: Result<
            RecordMatrix<'a, T, Matrix<(T, usize)>>,
            InvalidRecordIteratorError<'a, T, 2>,
        >,
        size: (Row, Column),
    ) -> Result<RecordMatrix<'a, T, Matrix<(T, usize)>>, InconsistentHistory<'a, T>> {
        use InvalidRecordIteratorError as Error;
        match result {
            Ok(matrix) => Ok(matrix),
            Err(Error::InconsistentHistory(inconsistency)) => Err(inconsistency),
            Err(Error::Empty) => panic!("Illegal state, record matrix was empty {:?}", size),
            Err(Error::Shape { requested, length }) => panic!(
                "Illegal state, record matrix shape was inconsistent: requested: {:?}, length of data: {:?}",
                requested, length
            ),
        }
    }

    /// Computes the derivatives of every element, or returns `None` when this
    /// container has no history.
    pub fn derivatives(&self) -> Option<Matrix<Derivatives<T>>> {
        let history = self.history?;
        let all = self.numbers.map(|(number, index)| {
            Record {
                number,
                history: Some(history),
                index,
            }
            .derivatives()
        });
        Some(all)
    }

    /// Computes the derivatives of the single element at `row`, `column`.
    ///
    /// Returns `None` when there is no element at that position, or when
    /// `Record::try_derivatives` returns `None` for that element.
    pub fn derivatives_for(&self, row: Row, column: Column) -> Option<Derivatives<T>> {
        let (number, index) = self.try_get_reference(row, column)?;
        Record {
            number: number.clone(),
            history: self.history,
            index: *index,
        }
        .try_derivatives()
    }

    /// Elementwise multiplication with `other`, implemented as [`binary`](Self::binary)
    /// with the `Multiplication` function and its derivatives.
    pub fn elementwise_multiply<S2>(
        &self,
        other: &RecordMatrix<'a, T, S2>,
    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
    where
        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
    {
        self.binary(
            other,
            Multiplication::<T>::function,
            Multiplication::<T>::d_function_dx,
            Multiplication::<T>::d_function_dy,
        )
    }

    /// Elementwise division by `other`, implemented as [`binary`](Self::binary) with
    /// the `Division` function and its derivatives.
    pub fn elementwise_divide<S2>(
        &self,
        other: &RecordMatrix<'a, T, S2>,
    ) -> RecordMatrix<'a, T, Matrix<(T, Index)>>
    where
        S2: MatrixRef<(T, Index)> + NoInteriorMutability,
    {
        self.binary(
            other,
            Division::<T>::function,
            Division::<T>::d_function_dx,
            Division::<T>::d_function_dy,
        )
    }
}
impl<'a, T, S> RecordMatrix<'a, T, S>
where
T: Numeric + Primitive,
for<'t> &'t T: NumericRef<T>,
S: MatrixMut<(T, Index)> + NoInteriorMutability,
{
#[track_caller]
pub fn unary_assign(&mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) {
let total = self.elements();
match self.history {
None => self.numbers.map_mut(|(x, i)| (fx(x), i)),
Some(history) => {
let ys = unary::<T, _>(total, history, self.numbers.row_major_iter(), fx, dfx_dx);
for (element, result) in self.numbers.row_major_reference_mut_iter().zip(ys) {
*element = result;
}
self.history = Some(history);
}
}
}
#[track_caller]
pub fn binary_left_assign<S2>(
&mut self,
rhs: &RecordMatrix<'a, T, S2>,
fxy: impl Fn(T, T) -> T,
dfxy_dx: impl Fn(T, T) -> T,
dfxy_dy: impl Fn(T, T) -> T,
) where
S2: MatrixRef<(T, Index)> + NoInteriorMutability,
{
{
let left_shape = self.numbers.size();
let right_shape = rhs.numbers.size();
if left_shape != right_shape {
panic!(
"Record containers must have the same size for a binary operation: (left: {:?}, right: {:?})",
left_shape, right_shape
);
}
}
let total = self.elements();
match (self.history, rhs.history) {
(None, None) => {
for (x, y) in self
.numbers
.row_major_reference_mut_iter()
.zip(rhs.numbers.row_major_iter())
{
let (left, _) = x;
let (right, _) = y;
*x = (fxy(left.clone(), right), 0);
}
}
(Some(history), None) => {
let zs = binary_x_history::<T, _, _>(
total,
history,
self.numbers.row_major_iter(),
rhs.numbers.row_major_iter(),
fxy,
dfxy_dx,
);
for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
*element = result;
}
self.history = Some(history);
}
(None, Some(history)) => {
let zs = binary_y_history::<T, _, _>(
total,
history,
self.numbers.row_major_iter(),
rhs.numbers.row_major_iter(),
fxy,
dfxy_dy,
);
for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
*element = result;
}
self.history = Some(history);
}
(Some(history), Some(h)) => {
assert!(
record_operations::same_lists(history, h),
"Record containers must be using the same WengertList"
);
let zs = binary_both_history::<T, _, _>(
total,
history,
self.numbers.row_major_iter(),
rhs.numbers.row_major_iter(),
fxy,
dfxy_dx,
dfxy_dy,
);
for (element, result) in self.numbers.row_major_reference_mut_iter().zip(zs) {
*element = result;
}
self.history = Some(history);
}
}
}
#[track_caller]
pub fn do_unary_assign(mut self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Self {
self.unary_assign(fx, dfx_dx);
self
}
#[track_caller]
pub fn do_binary_left_assign<S2>(
mut self,
rhs: &RecordMatrix<'a, T, S2>,
fxy: impl Fn(T, T) -> T,
dfxy_dx: impl Fn(T, T) -> T,
dfxy_dy: impl Fn(T, T) -> T,
) -> Self
where
S2: MatrixRef<(T, Index)> + NoInteriorMutability,
{
self.binary_left_assign(rhs, fxy, dfxy_dx, dfxy_dy);
self
}
#[track_caller]
pub fn map_mut(
&mut self,
fx: impl Fn(Record<'a, T>) -> Record<'a, T>,
) -> Result<(), InconsistentHistory<'a, T>> {
let history = self.history;
let new_history =
map_mut_base::<'a, T, _, _>(RowMajorReferenceMutIterator::from(self), |x| {
let record = Record::from_existing(x.clone(), history);
let result = fx(record);
*x = (result.number, result.index);
result.history
})?;
self.history = new_history;
Ok(())
}
#[track_caller]
pub fn map_mut_with_index(
&mut self,
fx: impl Fn(Record<'a, T>, Row, Column) -> Record<'a, T>,
) -> Result<(), InconsistentHistory<'a, T>> {
let history = self.history;
let new_history = map_mut_base::<'a, T, _, _>(
RowMajorReferenceMutIterator::from(self).with_index(),
|((r, c), x)| {
let record = Record::from_existing(x.clone(), history);
let result = fx(record, r, c);
*x = (result.number, result.index);
result.history
},
)?;
self.history = new_history;
Ok(())
}
}
impl<'a, T, S> RecordMatrix<'a, T, S>
where
    T: Numeric + Primitive,
    for<'t> &'t T: NumericRef<T>,
    S: MatrixRef<(T, Index)> + NoInteriorMutability,
{
    /// Computes `fxy(self element, rhs element)` elementwise and stores the results
    /// in `rhs`.
    ///
    /// Implemented by calling `binary_left_assign` on `rhs` with the argument order of
    /// `fxy` flipped and the two derivative functions swapped to match.
    #[track_caller]
    pub fn binary_right_assign<S2>(
        &self,
        rhs: &mut RecordMatrix<'a, T, S2>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) where
        S2: MatrixMut<(T, Index)> + NoInteriorMutability,
    {
        // From the point of view of `rhs`, its own element comes first, so each
        // closure receives (right, left) and forwards them as (left, right).
        let flipped = |right, left| fxy(left, right);
        let d_first = |right, left| dfxy_dy(left, right);
        let d_second = |right, left| dfxy_dx(left, right);
        rhs.binary_left_assign(self, flipped, d_first, d_second)
    }

    /// By-value version of [`binary_right_assign`](Self::binary_right_assign): updates
    /// `rhs` and returns it.
    #[track_caller]
    pub fn do_binary_right_assign<S2>(
        &self,
        mut rhs: RecordMatrix<'a, T, S2>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> RecordMatrix<'a, T, S2>
    where
        S2: MatrixMut<(T, Index)> + NoInteriorMutability,
    {
        self.binary_right_assign(&mut rhs, fxy, dfxy_dx, dfxy_dy);
        rhs
    }
}
// Lets a RecordTensor be read as a tensor of `(T, Index)` pairs. Every method forwards
// unchanged to the source wrapped by the inner TensorView.
// SAFETY: presumably sound because the wrapped `S` is itself bound by `TensorRef` and
// nothing here alters its answers — NOTE(review): confirm against the trait's contract.
unsafe impl<'a, T, S, const D: usize> TensorRef<(T, Index), D> for RecordTensor<'a, T, S, D>
where
T: Primitive,
S: TensorRef<(T, Index), D>,
{
// Forwards the checked lookup to the wrapped source.
fn get_reference(&self, indexes: [usize; D]) -> Option<&(T, Index)> {
self.numbers.source_ref().get_reference(indexes)
}
// Forwards the shape query to the wrapped source.
fn view_shape(&self) -> [(Dimension, usize); D] {
self.numbers.source_ref().view_shape()
}
// Forwards the unchecked lookup to the wrapped source; the caller's obligations are
// passed straight through.
unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &(T, Index) {
unsafe { self.numbers.source_ref().get_reference_unchecked(indexes) }
}
// Forwards the data layout query to the wrapped source.
fn data_layout(&self) -> DataLayout<D> {
self.numbers.source_ref().data_layout()
}
}
// Lets a RecordTensor be mutated as a tensor of `(T, Index)` pairs. Every method
// forwards unchanged to the source wrapped by the inner TensorView.
// SAFETY: presumably sound because the wrapped `S` is itself bound by `TensorMut` and
// nothing here alters its answers — NOTE(review): confirm against the trait's contract.
unsafe impl<'a, T, S, const D: usize> TensorMut<(T, Index), D> for RecordTensor<'a, T, S, D>
where
T: Primitive,
S: TensorMut<(T, Index), D>,
{
// Forwards the checked mutable lookup to the wrapped source.
fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut (T, Index)> {
self.numbers.source_ref_mut().get_reference_mut(indexes)
}
// Forwards the unchecked mutable lookup to the wrapped source; the caller's
// obligations are passed straight through.
unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut (T, Index) {
unsafe {
self.numbers
.source_ref_mut()
.get_reference_unchecked_mut(indexes)
}
}
}
// Lets a RecordMatrix be read as a matrix of `(T, Index)` pairs. Every method forwards
// unchanged to the source wrapped by the inner MatrixView.
// SAFETY: presumably sound because the wrapped `S` is itself bound by `MatrixRef` and
// nothing here alters its answers — NOTE(review): confirm against the trait's contract.
unsafe impl<'a, T, S> MatrixRef<(T, Index)> for RecordMatrix<'a, T, S>
where
T: Primitive,
S: MatrixRef<(T, Index)>,
{
// Forwards the checked lookup to the wrapped source.
fn try_get_reference(&self, row: Row, column: Column) -> Option<&(T, Index)> {
self.numbers.source_ref().try_get_reference(row, column)
}
// Forwards the row count query to the wrapped source.
fn view_rows(&self) -> Row {
self.numbers.source_ref().view_rows()
}
// Forwards the column count query to the wrapped source.
fn view_columns(&self) -> Column {
self.numbers.source_ref().view_columns()
}
// Forwards the unchecked lookup to the wrapped source; the caller's obligations are
// passed straight through.
unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &(T, Index) {
unsafe {
self.numbers
.source_ref()
.get_reference_unchecked(row, column)
}
}
// Forwards the data layout query to the wrapped source.
fn data_layout(&self) -> crate::matrices::views::DataLayout {
self.numbers.source_ref().data_layout()
}
}
// Marks a RecordMatrix as `NoInteriorMutability` whenever its source `S` carries the
// same marker.
// SAFETY: NOTE(review): this relies on the `history` reference not counting as
// interior mutability for the purposes of this marker — confirm against the trait's
// contract.
unsafe impl<'a, T, S> NoInteriorMutability for RecordMatrix<'a, T, S>
where
T: Primitive,
S: NoInteriorMutability,
{
}
// Lets a RecordMatrix be mutated as a matrix of `(T, Index)` pairs. Every method
// forwards unchanged to the source wrapped by the inner MatrixView.
// SAFETY: presumably sound because the wrapped `S` is itself bound by `MatrixMut` and
// nothing here alters its answers — NOTE(review): confirm against the trait's contract.
unsafe impl<'a, T, S> MatrixMut<(T, Index)> for RecordMatrix<'a, T, S>
where
T: Primitive,
S: MatrixMut<(T, Index)>,
{
// Forwards the checked mutable lookup to the wrapped source.
fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut (T, Index)> {
self.numbers
.source_ref_mut()
.try_get_reference_mut(row, column)
}
// Forwards the unchecked mutable lookup to the wrapped source; the caller's
// obligations are passed straight through.
unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut (T, Index) {
unsafe {
self.numbers
.source_ref_mut()
.get_reference_unchecked_mut(row, column)
}
}
}
/// Converts an owned zero dimensional record tensor into a `Record` by delegating to
/// the conversion from a reference.
impl<'a, T, S> From<RecordTensor<'a, T, S, 0>> for Record<'a, T>
where
    T: Numeric + Primitive,
    S: TensorRef<(T, Index), 0>,
{
    fn from(scalar: RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
        let borrowed = &scalar;
        Record::from(borrowed)
    }
}
/// Converts a reference to a zero dimensional record tensor into a `Record` holding
/// the tensor's single `(number, index)` pair and its history.
impl<'a, T, S> From<&RecordTensor<'a, T, S, 0>> for Record<'a, T>
where
    T: Numeric + Primitive,
    S: TensorRef<(T, Index), 0>,
{
    fn from(scalar: &RecordTensor<'a, T, S, 0>) -> Record<'a, T> {
        let pair = scalar.view().scalar();
        Record::from_existing(pair, scalar.history)
    }
}
/// Converts an owned `Record` into a zero dimensional record tensor holding the
/// record's number and index, with the same history.
impl<'a, T> From<Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
where
    T: Numeric + Primitive,
{
    fn from(record: Record<'a, T>) -> RecordTensor<'a, T, Tensor<(T, Index), 0>, 0> {
        let history = record.history;
        let scalar = Tensor::from([], vec![(record.number, record.index)]);
        RecordTensor::from_existing(history, TensorView::from(scalar))
    }
}
/// Converts a borrowed `Record` into a zero-dimensional `RecordTensor`.
impl<'a, T> From<&Record<'a, T>> for RecordTensor<'a, T, Tensor<(T, Index), 0>, 0>
where
    T: Numeric + Primitive,
{
    /// Clones the record's number, copies its index, and stores the pair as the only element
    /// of a shapeless `Tensor` that keeps the record's history.
    fn from(record: &Record<'a, T>) -> Self {
        let history = record.history;
        let element = (record.number.clone(), record.index);
        let numbers = TensorView::from(Tensor::from([], vec![element]));
        Self::from_existing(history, numbers)
    }
}
/// Multiplies a 4x3 tensor by its transpose twice, once as a `Tensor` of individual
/// `Record`s and once as `RecordTensor`s, and checks that both approaches agree on the
/// values, the derivatives, and the number of operations written to their histories.
#[test]
fn matrix_multiplication_derivatives_are_the_same() {
    #[rustfmt::skip]
    let a = Tensor::from(
        [("r", 4), ("c", 3)],
        vec![
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
            0.0, 5.0, 2.0
        ]
    );
    let b = a.transpose(["c", "r"]);

    // Each approach records onto its own history so their lengths can be compared at the end.
    let records_history = WengertList::new();
    let container_history: WengertList<f64> = WengertList::new();

    // Approach 1: a plain tensor whose elements are individual records.
    let records_lhs = a.map(|x| Record::variable(x, &records_history));
    let records_rhs = b.map(|x| Record::variable(x, &records_history));
    let records_product = &records_lhs * &records_rhs;

    // Approach 2: record containers that share one history reference.
    let container_lhs = RecordTensor::variables(&container_history, a);
    let container_rhs = RecordTensor::variables(&container_history, b);
    let container_product = &container_lhs * &container_rhs;

    // The forward values must agree between the two approaches.
    let forward_values = records_product.map(|r| r.number);
    assert_eq!(
        forward_values,
        TensorView::from(&container_product).map(|(n, _)| n)
    );

    // Derivatives of every output element with respect to each input.
    // NOTE(review): the per-record derivatives are looked up using the container inputs,
    // which presumably works because both histories hand out the same indexes in the same
    // order - confirm.
    let records_gradients = records_product.map(|r| r.derivatives());
    let records_wrt_lhs = records_gradients.map(|d| d.at_tensor(&container_lhs));
    let records_wrt_rhs = records_gradients.map(|d| d.at_tensor(&container_rhs));
    let container_gradients = container_product.derivatives().unwrap();
    let container_wrt_lhs = container_gradients.map(|d| d.at_tensor(&container_lhs));
    let container_wrt_rhs = container_gradients.map(|d| d.at_tensor(&container_rhs));
    assert_eq!(records_wrt_lhs, container_wrt_lhs);
    assert_eq!(records_wrt_rhs, container_wrt_rhs);

    // The product written out as literal values.
    #[rustfmt::skip]
    let expected_literals = Tensor::from(
        [("r", 4), ("c", 4)],
        vec![
            14.0, 32.0, 50.0, 16.0,
            32.0, 77.0, 122.0, 37.0,
            50.0, 122.0, 194.0, 58.0,
            16.0, 37.0, 58.0, 29.0
        ]
    );
    assert_eq!(forward_values, expected_literals);

    // The same product spelled out as the sums of products it is made of.
    #[rustfmt::skip]
    let expected_expanded = Tensor::from(
        [("r", 4), ("c", 4)],
        vec![
            (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
            (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
            (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
            (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),
            (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
            (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
            (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
            (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),
            (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
            (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
            (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
            (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),
            (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
            (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
            (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
            (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
        ]
    );
    assert_eq!(forward_values, expected_expanded);

    // Both approaches must have recorded the same number of operations.
    let records_tape_len = records_history.operations.borrow().len();
    let container_tape_len = container_history.operations.borrow().len();
    assert_eq!(records_tape_len, container_tape_len);
}
/// Multiplies a 4x3 matrix by its transpose twice, once as a `Matrix` of individual
/// `Record`s and once as `RecordMatrix` containers, and checks that both approaches agree
/// on the values, the derivatives, and the number of operations written to their histories.
#[test]
fn matrix_view_matrix_multiplication_derivatives_are_the_same() {
    #[rustfmt::skip]
    let a = Matrix::from(vec![
        vec![ 1.0, 2.0, 3.0 ],
        vec![ 4.0, 5.0, 6.0 ],
        vec![ 7.0, 8.0, 9.0 ],
        vec![ 0.0, 5.0, 2.0 ]
    ]);
    let b = a.transpose();

    // Each approach records onto its own history so their lengths can be compared at the end.
    let records_history = WengertList::new();
    let container_history: WengertList<f64> = WengertList::new();

    // Approach 1: a plain matrix whose elements are individual records.
    let records_lhs = a.map(|x| Record::variable(x, &records_history));
    let records_rhs = b.map(|x| Record::variable(x, &records_history));
    let records_product = &records_lhs * &records_rhs;

    // Approach 2: record containers that share one history reference.
    let container_lhs = RecordMatrix::variables(&container_history, a);
    let container_rhs = RecordMatrix::variables(&container_history, b);
    let container_product = &container_lhs * &container_rhs;

    // The forward values must agree between the two approaches.
    let forward_values = records_product.map(|r| r.number);
    assert_eq!(
        forward_values,
        MatrixView::from(&container_product).map(|(n, _)| n)
    );

    // Derivatives of every output element with respect to each input.
    // NOTE(review): the per-record derivatives are looked up using the container inputs,
    // which presumably works because both histories hand out the same indexes in the same
    // order - confirm.
    let records_gradients = records_product.map(|r| r.derivatives());
    let records_wrt_lhs = records_gradients.map(|d| d.at_matrix(&container_lhs));
    let records_wrt_rhs = records_gradients.map(|d| d.at_matrix(&container_rhs));
    let container_gradients = container_product.derivatives().unwrap();
    let container_wrt_lhs = container_gradients.map(|d| d.at_matrix(&container_lhs));
    let container_wrt_rhs = container_gradients.map(|d| d.at_matrix(&container_rhs));
    assert_eq!(records_wrt_lhs, container_wrt_lhs);
    assert_eq!(records_wrt_rhs, container_wrt_rhs);

    // The product written out as literal values.
    #[rustfmt::skip]
    let expected_literals = Matrix::from(vec![
        vec![ 14.0, 32.0, 50.0, 16.0 ],
        vec![ 32.0, 77.0, 122.0, 37.0 ],
        vec![ 50.0, 122.0, 194.0, 58.0 ],
        vec![ 16.0, 37.0, 58.0, 29.0 ]
    ]);
    assert_eq!(forward_values, expected_literals);

    // The same product spelled out as the sums of products it is made of.
    #[rustfmt::skip]
    let expected_expanded = Matrix::from(vec![
        vec![
            (1.0 * 1.0) + (2.0 * 2.0) + (3.0 * 3.0),
            (1.0 * 4.0) + (2.0 * 5.0) + (3.0 * 6.0),
            (1.0 * 7.0) + (2.0 * 8.0) + (3.0 * 9.0),
            (1.0 * 0.0) + (2.0 * 5.0) + (3.0 * 2.0),
        ],
        vec![
            (4.0 * 1.0) + (5.0 * 2.0) + (6.0 * 3.0),
            (4.0 * 4.0) + (5.0 * 5.0) + (6.0 * 6.0),
            (4.0 * 7.0) + (5.0 * 8.0) + (6.0 * 9.0),
            (4.0 * 0.0) + (5.0 * 5.0) + (6.0 * 2.0),
        ],
        vec![
            (7.0 * 1.0) + (8.0 * 2.0) + (9.0 * 3.0),
            (7.0 * 4.0) + (8.0 * 5.0) + (9.0 * 6.0),
            (7.0 * 7.0) + (8.0 * 8.0) + (9.0 * 9.0),
            (7.0 * 0.0) + (8.0 * 5.0) + (9.0 * 2.0),
        ],
        vec![
            (0.0 * 1.0) + (5.0 * 2.0) + (2.0 * 3.0),
            (0.0 * 4.0) + (5.0 * 5.0) + (2.0 * 6.0),
            (0.0 * 7.0) + (5.0 * 8.0) + (2.0 * 9.0),
            (0.0 * 0.0) + (5.0 * 5.0) + (2.0 * 2.0)
        ]
    ]);
    assert_eq!(forward_values, expected_expanded);

    // Both approaches must have recorded the same number of operations.
    let records_tape_len = records_history.operations.borrow().len();
    let container_tape_len = container_history.operations.borrow().len();
    assert_eq!(records_tape_len, container_tape_len);
}