use crate::differentiation::{Index, Primitive, Record, WengertList};
use crate::differentiation::{RecordMatrix, RecordTensor};
use crate::matrices::iterators::{ColumnMajorIterator, RowMajorIterator, WithIndex};
use crate::matrices::views::{MatrixRef, MatrixView, NoInteriorMutability};
use crate::matrices::{Column, Matrix, Row};
use crate::numeric::Numeric;
use crate::tensors::indexing::TensorIterator;
use crate::tensors::views::{TensorRef, TensorView};
use crate::tensors::{Dimension, InvalidShapeError, Tensor};
use std::error::Error;
use std::fmt;
use std::fmt::Debug;
use std::iter::{ExactSizeIterator, FusedIterator};
/// Iterator adapter that turns an iterator of `(number, index)` pairs into an iterator of
/// [`Record`]s, attaching the same (optional) history to every record it yields.
///
/// When converted via `with_index`, the wrapped iterator instead yields
/// `(position, (number, index))` items and the adapter yields `(position, Record)` pairs.
pub struct AsRecords<'a, I, T> {
    /// The wrapped iterator supplying the raw `(number, index)` data for each record.
    numbers: I,
    /// History given to every record this adapter yields (`None` for records without one).
    history: Option<&'a WengertList<T>>,
}
impl<'a, 'b, T, S, const D: usize>
    AsRecords<'a, TensorIterator<'b, (T, Index), RecordTensor<'a, T, S, D>, D>, T>
where
    T: Numeric + Primitive,
    S: TensorRef<(T, Index), D>,
{
    /// Creates an iterator over the records of `tensor`.
    ///
    /// The `(number, index)` pairs are visited in the order `TensorIterator` produces them,
    /// and every yielded record carries the tensor's history.
    pub fn from_tensor(tensor: &'b RecordTensor<'a, T, S, D>) -> Self {
        Self {
            numbers: TensorIterator::from(tensor),
            history: tensor.history,
        }
    }
}
impl<'a, 'b, T, S> AsRecords<'a, RowMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T>
where
    T: Numeric + Primitive,
    S: MatrixRef<(T, Index)> + NoInteriorMutability,
{
    /// Creates an iterator over the records of `matrix`, visited with a `RowMajorIterator`.
    ///
    /// Every yielded record carries the matrix's history.
    pub fn from_matrix_row_major(matrix: &'b RecordMatrix<'a, T, S>) -> Self {
        Self {
            numbers: RowMajorIterator::from(matrix),
            history: matrix.history,
        }
    }
}
impl<'a, 'b, T, S> AsRecords<'a, ColumnMajorIterator<'b, (T, Index), RecordMatrix<'a, T, S>>, T>
where
    T: Numeric + Primitive,
    S: MatrixRef<(T, Index)> + NoInteriorMutability,
{
    /// Creates an iterator over the records of `matrix`, visited with a `ColumnMajorIterator`.
    ///
    /// Every yielded record carries the matrix's history.
    pub fn from_matrix_column_major(matrix: &'b RecordMatrix<'a, T, S>) -> Self {
        Self {
            numbers: ColumnMajorIterator::from(matrix),
            history: matrix.history,
        }
    }
}
impl<'a, I, T> AsRecords<'a, I, T>
where
    T: Numeric + Primitive,
    I: Iterator<Item = (T, Index)>,
{
    /// Wraps an arbitrary iterator of `(number, index)` pairs.
    ///
    /// Each yielded record is built from one pair plus the given `history`.
    pub fn from(history: Option<&'a WengertList<T>>, numbers: I) -> Self {
        Self { history, numbers }
    }
}
impl<'a, I, T> AsRecords<'a, I, T>
where
    T: Numeric + Primitive,
    I: Iterator<Item = (T, Index)> + Into<WithIndex<I>>,
{
    /// Converts this adapter into one whose inner iterator is `WithIndex<I>`, keeping the
    /// same history.
    ///
    /// The inner iterator is converted via its `Into<WithIndex<I>>` implementation. The
    /// result is then wrapped in `WithIndex` so that the `(O, Record)` iterator impl for
    /// `WithIndex<AsRecords<..>>` applies. This presumes `WithIndex<I>` yields
    /// `(position, (number, index))` items; confirm against the `WithIndex` definition.
    pub fn with_index(self) -> WithIndex<AsRecords<'a, WithIndex<I>, T>> {
        let Self { numbers, history } = self;
        let indexed: WithIndex<I> = numbers.into();
        WithIndex {
            iterator: AsRecords {
                numbers: indexed,
                history,
            },
        }
    }
}
impl<'a, I, O, T> AsRecords<'a, I, T>
where
    T: Numeric + Primitive,
    I: Iterator<Item = (O, (T, Index))>,
{
    /// Wraps an iterator that already yields `(O, (number, index))` items.
    ///
    /// Wrapping the returned adapter in `WithIndex` yields `(O, Record)` pairs that all
    /// share `history`.
    pub fn from_with_index(history: Option<&'a WengertList<T>>, numbers: I) -> Self {
        Self { history, numbers }
    }
}
impl<'a, I, T> From<AsRecords<'a, I, T>> for WithIndex<AsRecords<'a, WithIndex<I>, T>>
where
T: Numeric + Primitive,
I: Iterator<Item = (T, Index)> + Into<WithIndex<I>>,
{
fn from(iterator: AsRecords<'a, I, T>) -> Self {
iterator.with_index()
}
}
impl<'a, I, T> Iterator for AsRecords<'a, I, T>
where
    T: Numeric + Primitive,
    I: Iterator<Item = (T, Index)>,
{
    type Item = Record<'a, T>;

    /// Pulls the next `(number, index)` pair and turns it into a record with this
    /// adapter's history.
    fn next(&mut self) -> Option<Self::Item> {
        let history = self.history;
        let number = self.numbers.next()?;
        Some(Record::from_existing(number, history))
    }

    /// Exactly one record is produced per inner item, so the inner bounds carry over.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.numbers.size_hint()
    }
}
/// `next` forwards to the inner iterator exactly once per call, so a fused inner iterator
/// keeps this adapter fused.
impl<'a, T: Numeric + Primitive, I: FusedIterator<Item = (T, Index)>> FusedIterator
    for AsRecords<'a, I, T>
{
}
/// `size_hint` is forwarded unchanged, so an exact inner length is an exact outer length.
impl<'a, T: Numeric + Primitive, I: ExactSizeIterator<Item = (T, Index)>> ExactSizeIterator
    for AsRecords<'a, I, T>
{
}
impl<'a, I, O, T> Iterator for WithIndex<AsRecords<'a, I, T>>
where
T: Numeric + Primitive,
I: Iterator<Item = (O, (T, Index))>,
{
type Item = (O, Record<'a, T>);
fn next(&mut self) -> Option<Self::Item> {
self.iterator
.numbers
.next()
.map(|(i, number)| (i, Record::from_existing(number, self.iterator.history)))
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.iterator.numbers.size_hint()
}
}
/// `next` forwards to the inner iterator exactly once per call, so a fused inner iterator
/// keeps the indexed adapter fused.
impl<'a, I, O, T> FusedIterator for WithIndex<AsRecords<'a, I, T>>
where
    I: FusedIterator<Item = (O, (T, Index))>,
    T: Numeric + Primitive,
{
}
/// `size_hint` is forwarded unchanged, so an exact inner length is an exact outer length.
impl<'a, I, O, T> ExactSizeIterator for WithIndex<AsRecords<'a, I, T>>
where
    I: ExactSizeIterator<Item = (O, (T, Index))>,
    T: Numeric + Primitive,
{
}
/// Reasons an iterator of records could not be collected into a record tensor or matrix.
///
/// `D` is the number of dimensions of the requested shape (matrices use `D = 2`).
#[derive(Clone, Debug)]
pub enum InvalidRecordIteratorError<'a, T, const D: usize> {
    /// The requested shape was rejected for the amount of data collected.
    Shape {
        /// The shape that was asked for, wrapped as a shape error.
        requested: InvalidShapeError<D>,
        /// How many records the iterator actually yielded.
        length: usize,
    },
    /// The iterator yielded no records at all.
    Empty,
    /// Not every record shared the same history as the first record.
    InconsistentHistory(InconsistentHistory<'a, T>),
}
/// Details of a history mismatch found while collecting records: a later record's history
/// was not the exact same list as the first record's.
#[derive(Clone, Debug)]
pub struct InconsistentHistory<'a, T> {
    /// History of the first record yielded by the iterator.
    pub first: Option<&'a WengertList<T>>,
    /// History of a later record that did not match `first`.
    pub later: Option<&'a WengertList<T>>,
}
impl<'a, T, const D: usize> fmt::Display for InvalidRecordIteratorError<'a, T, D>
where
T: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Shape { requested, length } => write!(
f,
"Shape {:?} does not match size of data {}",
requested.shape(),
length
),
Self::Empty => write!(
f,
"Iterator was empty but all tensors and matrices must contain at least one element"
),
Self::InconsistentHistory(h) => write!(
f,
"First history in iterator of records was {:?} but a later history in iterator was {:?}, record container cannot support different histories for a single tensor or matrix.",
h.first, h.later,
),
}
}
}
impl<'a, T> fmt::Display for InconsistentHistory<'a, T>
where
    T: Debug,
{
    /// Writes both mismatching histories using their `Debug` representations.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { first, later } = self;
        write!(
            f,
            "First history was {:?} but a later history in iterator was {:?}, record container cannot support different histories for a single tensor or matrix.",
            first, later,
        )
    }
}
impl<'a, T, const D: usize> Error for InvalidRecordIteratorError<'a, T, D> where T: Debug {}
impl<'a, T> Error for InconsistentHistory<'a, T> where T: Debug {}
/// The validated pieces collected from an iterator of records, ready to be assembled into a
/// record tensor or matrix.
struct RecordContainerComponents<'a, T> {
    /// The history shared by every collected record.
    history: Option<&'a WengertList<T>>,
    /// Each record's `(number, index)` pair, in the order the records were yielded.
    numbers: Vec<(T, Index)>,
}
/// Drains `iter` into the `(number, index)` pairs of its records plus the single history
/// they all share.
///
/// # Errors
///
/// - `InvalidRecordIteratorError::InconsistentHistory` if any record's history is not the
///   exact same list as the first record's (per `are_exact_same_list`). Every record is still
///   consumed, and if several records mismatch, the last mismatch is the one reported.
/// - `InvalidRecordIteratorError::Empty` if `iter` yields no records.
fn collect_into_components<'a, T, I, const D: usize>(
    iter: I,
) -> Result<RecordContainerComponents<'a, T>, InvalidRecordIteratorError<'a, T, D>>
where
    T: Numeric + Primitive,
    I: IntoIterator<Item = Record<'a, T>>,
{
    use crate::differentiation::record_operations::are_exact_same_list;
    let records = iter.into_iter();
    let mut numbers: Vec<(T, Index)> = Vec::with_capacity(records.size_hint().0);
    // Outer `None` means "no record seen yet"; the inner option is the record's own history.
    let mut first_history: Option<Option<&WengertList<T>>> = None;
    let mut mismatch: Option<InvalidRecordIteratorError<'a, T, D>> = None;
    for record in records {
        if let Some(first) = first_history {
            if !are_exact_same_list(first, record.history) {
                mismatch = Some(InvalidRecordIteratorError::InconsistentHistory(
                    InconsistentHistory {
                        first,
                        later: record.history,
                    },
                ));
            }
        } else {
            first_history = Some(record.history);
        }
        numbers.push((record.number, record.index));
    }
    if let Some(error) = mismatch {
        return Err(error);
    }
    // The first history is recorded together with the first pushed pair, so it is `None`
    // exactly when no records were collected.
    match first_history {
        Some(history) => Ok(RecordContainerComponents { history, numbers }),
        None => Err(InvalidRecordIteratorError::Empty),
    }
}
fn collect_into_n_components<'a, T, I, const D: usize, const N: usize>(
iter: I,
) -> [Result<RecordContainerComponents<'a, T>, InvalidRecordIteratorError<'a, T, D>>; N]
where
T: Numeric + Primitive,
I: IntoIterator<Item = [Record<'a, T>; N]>,
{
use crate::differentiation::record_operations::are_exact_same_list;
let iter = iter.into_iter();
let mut histories: [Option<Option<&WengertList<T>>>; N] = [None; N];
let mut errors: [Option<InvalidRecordIteratorError<'a, T, D>>; N] =
std::array::from_fn(|_| None);
let mut numbers: [Vec<(T, usize)>; N] =
std::array::from_fn(|_| Vec::with_capacity(iter.size_hint().0));
for records in iter {
for (n, record) in records.into_iter().enumerate() {
let history = &mut histories[n];
let error = &mut errors[n];
match history {
None => *history = Some(record.history),
Some(h) => {
if !are_exact_same_list(*h, record.history) {
*error = Some(InvalidRecordIteratorError::InconsistentHistory(
InconsistentHistory {
first: *h,
later: record.history,
},
));
}
}
}
numbers[n].push((record.number, record.index));
}
}
let mut histories = histories.into_iter();
let mut errors = errors.into_iter();
let mut numbers = numbers.into_iter();
std::array::from_fn(|_| {
let history = histories.next().unwrap();
let error = errors.next().unwrap();
let numbers = numbers.next().unwrap();
let data_length = numbers.len();
match error {
Some(error) => Err(error),
None => {
if data_length == 0 {
Err(InvalidRecordIteratorError::Empty)
} else {
Ok(RecordContainerComponents {
history: history.unwrap(),
numbers,
})
}
}
}
})
}
impl<'a, T, const D: usize> RecordTensor<'a, T, Tensor<(T, Index), D>, D>
where
T: Numeric + Primitive,
{
pub fn from_iter<I>(
shape: [(Dimension, usize); D],
iter: I,
) -> Result<Self, InvalidRecordIteratorError<'a, T, D>>
where
I: IntoIterator<Item = Record<'a, T>>,
{
let RecordContainerComponents { history, numbers } = collect_into_components(iter)?;
let data_length = numbers.len();
match Tensor::try_from(shape, numbers) {
Ok(numbers) => Ok(RecordTensor::from_existing(
history,
TensorView::from(numbers),
)),
Err(invalid_shape) => Err(InvalidRecordIteratorError::Shape {
requested: invalid_shape,
length: data_length,
}),
}
}
pub fn from_iters<I, const N: usize>(
shape: [(Dimension, usize); D],
iter: I,
) -> [Result<Self, InvalidRecordIteratorError<'a, T, D>>; N]
where
I: IntoIterator<Item = [Record<'a, T>; N]>,
{
let mut components = collect_into_n_components(iter).into_iter();
std::array::from_fn(|_| match components.next().unwrap() {
Err(error) => Err(error),
Ok(RecordContainerComponents { history, numbers }) => {
let data_length = numbers.len();
match Tensor::try_from(shape, numbers) {
Ok(numbers) => Ok(RecordTensor::from_existing(
history,
TensorView::from(numbers),
)),
Err(invalid_shape) => Err(InvalidRecordIteratorError::Shape {
requested: invalid_shape,
length: data_length,
}),
}
}
})
}
}
impl<'a, T> RecordMatrix<'a, T, Matrix<(T, Index)>>
where
T: Numeric + Primitive,
{
pub fn from_iter<I>(
size: (Row, Column),
iter: I,
) -> Result<Self, InvalidRecordIteratorError<'a, T, 2>>
where
I: IntoIterator<Item = Record<'a, T>>,
{
let RecordContainerComponents { history, numbers } = collect_into_components(iter)?;
let data_length = numbers.len();
if data_length == size.0 * size.1 {
Ok(RecordMatrix::from_existing(
history,
MatrixView::from(Matrix::from_flat_row_major(size, numbers)),
))
} else {
Err(InvalidRecordIteratorError::Shape {
requested: InvalidShapeError::new([("rows", size.0), ("columns", size.1)]),
length: data_length,
})
}
}
pub fn from_iters<I, const N: usize>(
size: (Row, Column),
iter: I,
) -> [Result<Self, InvalidRecordIteratorError<'a, T, 2>>; N]
where
I: IntoIterator<Item = [Record<'a, T>; N]>,
{
let mut components = collect_into_n_components(iter).into_iter();
std::array::from_fn(|_| match components.next().unwrap() {
Err(error) => Err(error),
Ok(RecordContainerComponents { history, numbers }) => {
let data_length = numbers.len();
if data_length == size.0 * size.1 {
Ok(RecordMatrix::from_existing(
history,
MatrixView::from(Matrix::from_flat_row_major(size, numbers)),
))
} else {
Err(InvalidRecordIteratorError::Shape {
requested: InvalidShapeError::new([("rows", size.0), ("columns", size.1)]),
length: data_length,
})
}
}
})
}
}