use crate::differentiation::{Index, Primitive, Record, RecordTensor};
use crate::numeric::Numeric;
use crate::tensors::dimensions;
use crate::tensors::dimensions::DimensionMappings;
use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
use crate::tensors::{Dimension, Tensor};
use std::error::Error;
use std::fmt;
use std::iter::{ExactSizeIterator, FusedIterator};
use std::marker::PhantomData;
pub use crate::matrices::iterators::WithIndex;
/// Provides indexing into a tensor source in a caller-chosen dimension order.
///
/// Wraps a source implementing `TensorRef`/`TensorMut` together with a mapping
/// from the requested dimension order to the source's actual dimension order.
#[derive(Clone, Debug)]
pub struct TensorAccess<T, S, const D: usize> {
    // The underlying tensor data source.
    source: S,
    // Translation between the requested dimension order and the source order.
    dimension_mapping: DimensionMappings<D>,
    // Marks the element type T, which is not stored directly.
    _type: PhantomData<T>,
}
impl<T, S, const D: usize> TensorAccess<T, S, D>
where
    S: TensorRef<T, D>,
{
    /// Creates a TensorAccess indexed in the requested dimension order.
    ///
    /// # Panics
    ///
    /// Panics if the requested dimensions are not the same set of dimensions
    /// as in the source's shape (see [`try_from`](TensorAccess::try_from)).
    #[track_caller]
    pub fn from(source: S, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
        match TensorAccess::try_from(source, dimensions) {
            Err(error) => panic!("{}", error),
            Ok(success) => success,
        }
    }

    /// Creates a TensorAccess indexed in the requested dimension order,
    /// returning an error instead if the requested dimensions are not the
    /// same set of dimensions as in the source's shape.
    pub fn try_from(
        source: S,
        dimensions: [Dimension; D],
    ) -> Result<TensorAccess<T, S, D>, InvalidDimensionsError<D>> {
        Ok(TensorAccess {
            dimension_mapping: DimensionMappings::new(&source.view_shape(), &dimensions)
                .ok_or_else(|| InvalidDimensionsError {
                    actual: source.view_shape(),
                    requested: dimensions,
                })?,
            source,
            _type: PhantomData,
        })
    }

    /// Creates a TensorAccess that indexes in the same order the source
    /// itself uses, performing no index translation at all.
    pub fn from_source_order(source: S) -> TensorAccess<T, S, D> {
        TensorAccess {
            dimension_mapping: DimensionMappings::no_op_mapping(),
            source,
            _type: PhantomData,
        }
    }

    /// Creates a TensorAccess that indexes in the source's linear memory
    /// order, or None if the source does not report a linear data layout.
    ///
    /// # Panics
    ///
    /// Panics if the source's data_layout names dimensions that are not the
    /// same set as its view_shape, which breaks the `TensorRef` contract.
    pub fn from_memory_order(source: S) -> Option<TensorAccess<T, S, D>> {
        let data_layout = match source.data_layout() {
            DataLayout::Linear(order) => order,
            _ => return None,
        };
        // Captured before `source` is moved so the panic message can use it.
        let shape = source.view_shape();
        Some(TensorAccess::try_from(source, data_layout).unwrap_or_else(|_| panic!(
            "Source implementation contained dimensions {:?} in data_layout that were not the same set as in the view_shape {:?} which breaks the contract of TensorRef",
            data_layout, shape
        )))
    }

    /// The shape of this TensorAccess, in the requested dimension order.
    pub fn shape(&self) -> [(Dimension, usize); D] {
        self.dimension_mapping
            .map_shape_to_requested(&self.source.view_shape())
    }

    /// Consumes this TensorAccess, returning the source it was created from.
    pub fn source(self) -> S {
        self.source
    }

    /// Gives a reference to the source this TensorAccess was created from.
    pub fn source_ref(&self) -> &S {
        &self.source
    }
}
/// An error indicating that a requested set of dimension names did not match
/// the set of dimensions in a source's shape.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct InvalidDimensionsError<const D: usize> {
    /// The shape of the source, as reported by its view_shape.
    pub actual: [(Dimension, usize); D],
    /// The dimension order that was requested.
    pub requested: [Dimension; D],
}
// No underlying cause to report, so the default Error methods suffice.
impl<const D: usize> Error for InvalidDimensionsError<D> {}
impl<const D: usize> fmt::Display for InvalidDimensionsError<D> {
    /// Formats the error, showing the requested dimension order followed by
    /// the actual shape of the source.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fix: the arguments were previously swapped, printing the source
        // shape under "Requested dimension order" and vice versa. The first
        // placeholder is the requested order, the second the actual shape.
        write!(
            f,
            "Requested dimension order: {:?} does not match the shape in the source: {:?}",
            &self.requested, &self.actual
        )
    }
}
#[test]
fn test_sync() {
    // Compile-time check that the error type can be shared between threads.
    fn requires_sync<T: Sync>() {}
    requires_sync::<InvalidDimensionsError<3>>();
}
#[test]
fn test_send() {
    // Compile-time check that the error type can be moved between threads.
    fn requires_send<T: Send>() {}
    requires_send::<InvalidDimensionsError<3>>();
}
impl<T, S, const D: usize> TensorAccess<T, S, D>
where
    S: TensorRef<T, D>,
{
    /// Gets a reference to the value at the index if it is in range,
    /// translating the indexes from the requested order to the source order.
    pub fn try_get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        self.source
            .get_reference(self.dimension_mapping.map_dimensions_to_source(&indexes))
    }

    /// Gets a reference to the value at the index.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of range.
    #[track_caller]
    pub fn get_ref(&self, indexes: [usize; D]) -> &T {
        match self.try_get_reference(indexes) {
            Some(reference) => reference,
            None => panic!(
                "Unable to index with {:?}, Tensor dimensions are {:?}.",
                indexes,
                self.shape()
            ),
        }
    }

    /// Gets a reference to the value at the index without any bounds checking.
    ///
    /// # Safety
    ///
    /// The caller must ensure the index is in range, otherwise the source's
    /// `get_reference_unchecked` is called out of bounds (undefined behavior).
    #[allow(clippy::missing_safety_doc)] pub unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            self.source
                .get_reference_unchecked(self.dimension_mapping.map_dimensions_to_source(&indexes))
        }
    }

    /// Returns an iterator over references to the data in this TensorAccess,
    /// iterating in the requested dimension order.
    pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, TensorAccess<T, S, D>, D> {
        TensorReferenceIterator::from(self)
    }
}
impl<T, S, const D: usize> TensorAccess<T, S, D>
where
    S: TensorRef<T, D>,
    T: Clone,
{
    /// Gets a copy of the value at the index.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of range.
    #[track_caller]
    pub fn get(&self, indexes: [usize; D]) -> T {
        match self.try_get_reference(indexes) {
            Some(reference) => reference.clone(),
            None => panic!(
                "Unable to index with {:?}, Tensor dimensions are {:?}.",
                indexes,
                self.shape()
            ),
        }
    }

    /// Gets a copy of the first value in this tensor.
    pub fn first(&self) -> T {
        self.iter()
            .next()
            .expect("Tensors always have at least 1 element")
    }

    /// Creates and returns a new tensor with the mapping function applied to
    /// every value, iterating in the requested dimension order.
    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
        let mapped = self.iter().map(mapping_function).collect();
        Tensor::from(self.shape(), mapped)
    }

    /// As map, but the mapping function also receives the index of each value.
    pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
        let mapped = self
            .iter()
            .with_index()
            .map(|(i, x)| mapping_function(i, x))
            .collect();
        Tensor::from(self.shape(), mapped)
    }

    /// Returns an iterator over copies of the data in this TensorAccess,
    /// iterating in the requested dimension order.
    pub fn iter(&self) -> TensorIterator<'_, T, TensorAccess<T, S, D>, D> {
        TensorIterator::from(self)
    }
}
impl<T, S, const D: usize> TensorAccess<T, S, D>
where
    S: TensorMut<T, D>,
{
    /// Gets a mutable reference to the value at the index if it is in range,
    /// translating the indexes from the requested order to the source order.
    pub fn try_get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        self.source
            .get_reference_mut(self.dimension_mapping.map_dimensions_to_source(&indexes))
    }

    /// Gets a mutable reference to the value at the index.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of range.
    #[track_caller]
    pub fn get_ref_mut(&mut self, indexes: [usize; D]) -> &mut T {
        match self.try_get_reference_mut(indexes) {
            Some(reference) => reference,
            None => panic!("Unable to index with {:?}", indexes),
        }
    }

    /// Gets a mutable reference to the value at the index without any bounds
    /// checking.
    ///
    /// # Safety
    ///
    /// The caller must ensure the index is in range, otherwise the source's
    /// `get_reference_unchecked_mut` is called out of bounds (undefined
    /// behavior).
    #[allow(clippy::missing_safety_doc)] pub unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            self.source.get_reference_unchecked_mut(
                self.dimension_mapping.map_dimensions_to_source(&indexes),
            )
        }
    }

    /// Returns an iterator over mutable references to the data in this
    /// TensorAccess, iterating in the requested dimension order.
    pub fn iter_reference_mut(
        &mut self,
    ) -> TensorReferenceMutIterator<'_, T, TensorAccess<T, S, D>, D> {
        TensorReferenceMutIterator::from(self)
    }
}
impl<T, S, const D: usize> TensorAccess<T, S, D>
where
    S: TensorMut<T, D>,
    T: Clone,
{
    /// Applies the mapping function to every value in place, replacing each
    /// value with the function's result.
    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
        for element in self.iter_reference_mut() {
            let updated = mapping_function(element.clone());
            *element = updated;
        }
    }

    /// As map_mut, but the mapping function also receives the index of each
    /// value.
    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
        for (index, element) in self.iter_reference_mut().with_index() {
            let updated = mapping_function(index, element.clone());
            *element = updated;
        }
    }
}
impl<'a, T, S, const D: usize> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D>
where
    T: Numeric + Primitive,
    S: TensorRef<(T, Index), D>,
{
    /// Gets the value at the index as a Record, pairing it with the history
    /// of the source RecordTensor.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of range.
    #[track_caller]
    pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
        let number = self.get(indexes);
        Record::from_existing(number, self.source.history())
    }

    /// Gets the value at the index as a Record if the index is in range.
    pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
        let number = self.try_get_reference(indexes)?;
        Some(Record::from_existing(number.clone(), self.source.history()))
    }
}
impl<'a, T, S, const D: usize> TensorAccess<(T, Index), RecordTensor<'a, T, S, D>, D>
where
    T: Numeric + Primitive,
    S: TensorRef<(T, Index), D>,
{
    /// Gets the value at the index as a Record, pairing it with the history
    /// of the source RecordTensor.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of range.
    #[track_caller]
    pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
        let value = self.get(indexes);
        Record::from_existing(value, self.source.history())
    }

    /// Gets the value at the index as a Record if the index is in range.
    pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
        let value = self.try_get_reference(indexes)?;
        Some(Record::from_existing(value.clone(), self.source.history()))
    }
}
impl<'a, T, S, const D: usize> TensorAccess<(T, Index), &mut RecordTensor<'a, T, S, D>, D>
where
    T: Numeric + Primitive,
    S: TensorRef<(T, Index), D>,
{
    /// Gets the value at the index as a Record, pairing it with the history
    /// of the source RecordTensor.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of range.
    #[track_caller]
    pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
        let entry = self.get(indexes);
        Record::from_existing(entry, self.source.history())
    }

    /// Gets the value at the index as a Record if the index is in range.
    pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
        let entry = self.try_get_reference(indexes)?;
        Some(Record::from_existing(entry.clone(), self.source.history()))
    }
}
// # Safety
//
// A TensorAccess delegates to a TensorRef source and only reorders the
// indexes, so it cannot return references the source would not.
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorAccess<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        self.try_get_reference(indexes)
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        // The shape in the requested dimension order.
        self.shape()
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller promises the indexes are in range.
        unsafe { self.get_reference_unchecked(indexes) }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // The source's layout is passed through unchanged; a Linear layout
        // is described by dimension name, so it stays valid when this view
        // only reorders which dimension each index position refers to.
        match self.source.data_layout() {
            DataLayout::Linear(order) => DataLayout::Linear(order),
            DataLayout::NonLinear => DataLayout::NonLinear,
            DataLayout::Other => DataLayout::Other,
        }
    }
}
// # Safety
//
// As with TensorRef, only the index order is translated before delegating.
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorAccess<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        self.try_get_reference_mut(indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller promises the indexes are in range.
        unsafe { self.get_reference_unchecked_mut(indexes) }
    }
}
impl<T, S, const D: usize> std::fmt::Display for TensorAccess<T, S, D>
where
    T: std::fmt::Display,
    S: TensorRef<T, D>,
{
    /// Displays the tensor data followed by its data layout.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        crate::tensors::display::format_view(&self, f)?;
        writeln!(f)?;
        write!(f, "Data Layout = {:?}", self.data_layout())
    }
}
/// An iterator over every index combination of a shape, in row-major order
/// (the last dimension varies fastest).
#[derive(Clone, Debug)]
pub struct ShapeIterator<const D: usize> {
    // The shape whose index combinations are iterated over.
    shape: [(Dimension, usize); D],
    // The next index combination to yield.
    indexes: [usize; D],
    // True once iteration is complete, or if the shape has a zero length
    // dimension and therefore no valid indexes at all.
    finished: bool,
}
impl<const D: usize> ShapeIterator<D> {
    /// Constructs a ShapeIterator for the given shape. A shape containing any
    /// dimension of length 0 yields no index combinations at all.
    pub fn from(shape: [(Dimension, usize); D]) -> ShapeIterator<D> {
        // [0; D] is only a valid starting index if every length is non zero.
        let any_empty_dimension = shape.iter().any(|(_, length)| *length == 0);
        ShapeIterator {
            shape,
            indexes: [0; D],
            finished: any_empty_dimension,
        }
    }
}
impl<const D: usize> Iterator for ShapeIterator<D> {
    type Item = [usize; D];

    // Delegates to the shared odometer stepping logic.
    fn next(&mut self) -> Option<Self::Item> {
        iter(&mut self.finished, &mut self.indexes, &self.shape)
    }

    // Delegates to the shared exact remaining-elements calculation.
    fn size_hint(&self) -> (usize, Option<usize>) {
        size_hint(self.finished, &self.indexes, &self.shape)
    }
}

// Once `finished` is set, next always returns None, so the iterator is fused.
impl<const D: usize> FusedIterator for ShapeIterator<D> {}
// size_hint returns an exact count, so the iterator has an exact size.
impl<const D: usize> ExactSizeIterator for ShapeIterator<D> {}
/// Advances a row-major index odometer over the shape, returning the current
/// index combination, or None once iteration has finished.
fn iter<const D: usize>(
    finished: &mut bool,
    indexes: &mut [usize; D],
    shape: &[(Dimension, usize); D],
) -> Option<[usize; D]> {
    if *finished {
        return None;
    }
    // Yield the current position, then step to the next one.
    let value = Some(*indexes);
    if D > 0 {
        // Increment the last dimension, carrying overflow leftwards like an
        // odometer.
        indexes[D - 1] += 1;
        for d in (1..D).rev() {
            if indexes[d] == shape[d].1 {
                indexes[d] = 0;
                indexes[d - 1] += 1;
            }
        }
        // Once the first dimension overflows there are no more combinations.
        if indexes[0] == shape[0].1 {
            *finished = true;
        }
    } else {
        // A zero dimensional shape has exactly one (empty) index.
        *finished = true;
    }
    value
}
/// Steps a row-major index odometer backwards over the shape, returning the
/// current index combination, or None once iteration has finished.
fn iter_back<const D: usize>(
    finished: &mut bool,
    indexes: &mut [usize; D],
    shape: &[(Dimension, usize); D],
) -> Option<[usize; D]> {
    if *finished {
        return None;
    }
    // Yield the current position, then step to the previous one.
    let value = Some(*indexes);
    if D > 0 {
        // Track which dimensions hit their lower bound (0) and must wrap to
        // their maximum, borrowing from the dimension to the left.
        let mut bounds = [false; D];
        if indexes[D - 1] == 0 {
            bounds[D - 1] = true;
        } else {
            indexes[D - 1] -= 1;
        }
        for d in (1..D).rev() {
            if bounds[d] {
                indexes[d] = shape[d].1 - 1;
                if indexes[d - 1] == 0 {
                    bounds[d - 1] = true;
                } else {
                    indexes[d - 1] -= 1;
                }
            }
        }
        // Once the first dimension underflows there are no more combinations.
        if bounds[0] {
            *finished = true;
        }
    } else {
        // A zero dimensional shape has exactly one (empty) index.
        *finished = true;
    }
    value
}
/// Calculates the exact number of index combinations remaining for a
/// forward-only iteration whose next yield would be `indexes`.
fn size_hint<const D: usize>(
    finished: bool,
    indexes: &[usize; D],
    shape: &[(Dimension, usize); D],
) -> (usize, Option<usize>) {
    if finished {
        return (0, Some(0));
    }
    let remaining = if D > 0 {
        // Total elements minus how many the current position has already
        // passed (its row-major linear index).
        let total = dimensions::elements(shape);
        let strides = crate::tensors::compute_strides(shape);
        let seen = crate::tensors::get_index_direct_unchecked(indexes, &strides);
        total - seen
    } else {
        // A zero dimensional shape yields exactly one (empty) index.
        1
    };
    (remaining, Some(remaining))
}
/// Calculates the exact number of index combinations left between the forward
/// and backward positions of a double ended iteration, inclusive of both ends.
fn double_ended_size_hint<const D: usize>(
    finished: bool,
    forward_indexes: &[usize; D],
    back_indexes: &[usize; D],
    shape: &[(Dimension, usize); D],
) -> (usize, Option<usize>) {
    if finished {
        return (0, Some(0));
    }
    let remaining = if D > 0 {
        // Neither position has been yielded yet, so the remaining count is
        // the inclusive distance between their row-major linear indexes.
        let strides = crate::tensors::compute_strides(shape);
        let progress_forward =
            crate::tensors::get_index_direct_unchecked(forward_indexes, &strides);
        let progress_backward = crate::tensors::get_index_direct_unchecked(back_indexes, &strides);
        1 + progress_backward - progress_forward
    } else {
        // A zero dimensional shape yields exactly one (empty) index.
        1
    };
    (remaining, Some(remaining))
}
/// An iterator over every index combination of a shape that can be consumed
/// from both ends.
#[derive(Clone, Debug)]
pub struct DoubleEndedShapeIterator<const D: usize> {
    // The shape whose index combinations are iterated over.
    shape: [(Dimension, usize); D],
    // The next index combination to yield from the front.
    forward_indexes: [usize; D],
    // The next index combination to yield from the back.
    back_indexes: [usize; D],
    // True once both ends have met and the final combination was consumed.
    finished: bool,
}
impl<const D: usize> DoubleEndedShapeIterator<D> {
    /// Constructs a DoubleEndedShapeIterator for the given shape. A shape
    /// containing any dimension of length 0 yields no index combinations.
    pub fn from(shape: [(Dimension, usize); D]) -> DoubleEndedShapeIterator<D> {
        let starting_index_valid = shape.iter().all(|(_, l)| *l > 0);
        DoubleEndedShapeIterator {
            shape,
            forward_indexes: [0; D],
            // saturating_sub avoids the integer underflow `l - 1` caused for
            // a zero length dimension (a panic in debug builds); such an
            // iterator is marked finished below and yields nothing anyway.
            back_indexes: shape.map(|(_, l)| l.saturating_sub(1)),
            finished: !starting_index_valid,
        }
    }
}
/// Returns true when the forward and backward iteration positions have met,
/// meaning the next combination consumed (from either end) is the final one.
fn overlapping_iterators<const D: usize>(
    forward_indexes: &[usize; D],
    back_indexes: &[usize; D],
) -> bool {
    forward_indexes
        .iter()
        .zip(back_indexes.iter())
        .all(|(forward, back)| forward == back)
}
impl<const D: usize> Iterator for DoubleEndedShapeIterator<D> {
    type Item = [usize; D];

    fn next(&mut self) -> Option<Self::Item> {
        // If the two ends have already met, the combination yielded now is
        // the last one, so mark the iterator finished after stepping forward
        // to prevent the back end re-yielding it.
        let will_finish = overlapping_iterators(&self.forward_indexes, &self.back_indexes);
        let item = iter(&mut self.finished, &mut self.forward_indexes, &self.shape);
        if will_finish {
            self.finished = true;
        }
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        double_ended_size_hint(
            self.finished,
            &self.forward_indexes,
            &self.back_indexes,
            &self.shape,
        )
    }
}
impl<const D: usize> DoubleEndedIterator for DoubleEndedShapeIterator<D> {
    fn next_back(&mut self) -> Option<Self::Item> {
        // Mirror of next: when the two ends have met this is the final
        // combination, so finish to prevent the front end re-yielding it.
        let will_finish = overlapping_iterators(&self.forward_indexes, &self.back_indexes);
        let item = iter_back(&mut self.finished, &mut self.back_indexes, &self.shape);
        if will_finish {
            self.finished = true;
        }
        item
    }
}

// Once `finished` is set, next always returns None, so the iterator is fused.
impl<const D: usize> FusedIterator for DoubleEndedShapeIterator<D> {}
// size_hint returns an exact count, so the iterator has an exact size.
impl<const D: usize> ExactSizeIterator for DoubleEndedShapeIterator<D> {}
/// An iterator over copies of all the values in a tensor source.
#[derive(Debug)]
pub struct TensorIterator<'a, T, S, const D: usize> {
    // Produces the index of each element in turn.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The source the values are cloned from.
    source: &'a S,
    // Marks the element type T.
    _type: PhantomData<T>,
}
impl<'a, T, S, const D: usize> TensorIterator<'a, T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
    /// Constructs an iterator over copies of all the values in the source.
    pub fn from(source: &S) -> TensorIterator<'_, T, S, D> {
        let shape_iterator = DoubleEndedShapeIterator::from(source.view_shape());
        TensorIterator {
            source,
            shape_iterator,
            _type: PhantomData,
        }
    }

    /// Wraps this iterator so it also yields the index of each value.
    pub fn with_index(self) -> WithIndex<Self> {
        WithIndex { iterator: self }
    }
}

/// A WithIndex wrapper can also be created via From.
impl<'a, T, S, const D: usize> From<TensorIterator<'a, T, S, D>>
    for WithIndex<TensorIterator<'a, T, S, D>>
where
    T: Clone,
    S: TensorRef<T, D>,
{
    fn from(iterator: TensorIterator<'a, T, S, D>) -> Self {
        iterator.with_index()
    }
}
impl<'a, T, S, const D: usize> Iterator for TensorIterator<'a, T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // SAFETY: the shape iterator was built from the source's view_shape
        // and only yields in-range indexes, so the unchecked access is valid.
        self.shape_iterator
            .next()
            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) }.clone())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.shape_iterator.size_hint()
    }
}

impl<'a, T, S, const D: usize> DoubleEndedIterator for TensorIterator<'a, T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        // SAFETY: as for next, the yielded indexes are always in range.
        self.shape_iterator
            .next_back()
            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) }.clone())
    }
}

// Fused and exactly sized because the underlying shape iterator is.
impl<'a, T, S, const D: usize> FusedIterator for TensorIterator<'a, T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}

impl<'a, T, S, const D: usize> ExactSizeIterator for TensorIterator<'a, T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}

impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorIterator<'a, T, S, D>>
where
    T: Clone,
    S: TensorRef<T, D>,
{
    type Item = ([usize; D], T);

    fn next(&mut self) -> Option<Self::Item> {
        // Capture the index before next() advances the shape iterator.
        let index = self.iterator.shape_iterator.forward_indexes;
        self.iterator.next().map(|x| (index, x))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iterator.size_hint()
    }
}

impl<'a, T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorIterator<'a, T, S, D>>
where
    T: Clone,
    S: TensorRef<T, D>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        // Capture the index before next_back() steps the shape iterator back.
        let index = self.iterator.shape_iterator.back_indexes;
        self.iterator.next_back().map(|x| (index, x))
    }
}

impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorIterator<'a, T, S, D>>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}

impl<'a, T, S, const D: usize> ExactSizeIterator for WithIndex<TensorIterator<'a, T, S, D>>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}
/// An iterator over references to all the values in a tensor source.
#[derive(Debug)]
pub struct TensorReferenceIterator<'a, T, S, const D: usize> {
    // Produces the index of each element in turn.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The source the references point into.
    source: &'a S,
    // Marks that this iterator hands out &'a T.
    _type: PhantomData<&'a T>,
}
impl<'a, T, S, const D: usize> TensorReferenceIterator<'a, T, S, D>
where
    S: TensorRef<T, D>,
{
    /// Constructs an iterator over references to all the values in the source.
    pub fn from(source: &S) -> TensorReferenceIterator<'_, T, S, D> {
        let shape_iterator = DoubleEndedShapeIterator::from(source.view_shape());
        TensorReferenceIterator {
            source,
            shape_iterator,
            _type: PhantomData,
        }
    }

    /// Wraps this iterator so it also yields the index of each value.
    pub fn with_index(self) -> WithIndex<Self> {
        WithIndex { iterator: self }
    }
}

/// A WithIndex wrapper can also be created via From.
impl<'a, T, S, const D: usize> From<TensorReferenceIterator<'a, T, S, D>>
    for WithIndex<TensorReferenceIterator<'a, T, S, D>>
where
    S: TensorRef<T, D>,
{
    fn from(iterator: TensorReferenceIterator<'a, T, S, D>) -> Self {
        iterator.with_index()
    }
}
impl<'a, T, S, const D: usize> Iterator for TensorReferenceIterator<'a, T, S, D>
where
    S: TensorRef<T, D>,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // SAFETY: the shape iterator was built from the source's view_shape
        // and only yields in-range indexes, so the unchecked access is valid.
        self.shape_iterator
            .next()
            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.shape_iterator.size_hint()
    }
}

impl<'a, T, S, const D: usize> DoubleEndedIterator for TensorReferenceIterator<'a, T, S, D>
where
    S: TensorRef<T, D>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        // SAFETY: as for next, the yielded indexes are always in range.
        self.shape_iterator
            .next_back()
            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) })
    }
}

// Fused and exactly sized because the underlying shape iterator is.
impl<'a, T, S, const D: usize> FusedIterator for TensorReferenceIterator<'a, T, S, D> where
    S: TensorRef<T, D>
{
}

impl<'a, T, S, const D: usize> ExactSizeIterator for TensorReferenceIterator<'a, T, S, D> where
    S: TensorRef<T, D>
{
}

impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorReferenceIterator<'a, T, S, D>>
where
    S: TensorRef<T, D>,
{
    type Item = ([usize; D], &'a T);

    fn next(&mut self) -> Option<Self::Item> {
        // Capture the index before next() advances the shape iterator.
        let index = self.iterator.shape_iterator.forward_indexes;
        self.iterator.next().map(|x| (index, x))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iterator.size_hint()
    }
}

impl<'a, T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>>
where
    S: TensorRef<T, D>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        // Capture the index before next_back() steps the shape iterator back.
        let index = self.iterator.shape_iterator.back_indexes;
        self.iterator.next_back().map(|x| (index, x))
    }
}

impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>> where
    S: TensorRef<T, D>
{
}

impl<'a, T, S, const D: usize> ExactSizeIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>> where
    S: TensorRef<T, D>
{
}
/// An iterator over mutable references to all the values in a tensor source.
#[derive(Debug)]
pub struct TensorReferenceMutIterator<'a, T, S, const D: usize> {
    // Produces the index of each element in turn; never repeats an index,
    // which the Iterator implementation relies on for soundness.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The source the mutable references point into.
    source: &'a mut S,
    // Marks that this iterator hands out &'a mut T.
    _type: PhantomData<&'a mut T>,
}
impl<'a, T, S, const D: usize> TensorReferenceMutIterator<'a, T, S, D>
where
    S: TensorMut<T, D>,
{
    /// Constructs an iterator over mutable references to all the values in
    /// the source.
    pub fn from(source: &mut S) -> TensorReferenceMutIterator<'_, T, S, D> {
        let shape_iterator = DoubleEndedShapeIterator::from(source.view_shape());
        TensorReferenceMutIterator {
            source,
            shape_iterator,
            _type: PhantomData,
        }
    }

    /// Wraps this iterator so it also yields the index of each value.
    pub fn with_index(self) -> WithIndex<Self> {
        WithIndex { iterator: self }
    }
}

/// A WithIndex wrapper can also be created via From.
impl<'a, T, S, const D: usize> From<TensorReferenceMutIterator<'a, T, S, D>>
    for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
where
    S: TensorMut<T, D>,
{
    fn from(iterator: TensorReferenceMutIterator<'a, T, S, D>) -> Self {
        iterator.with_index()
    }
}
impl<'a, T, S, const D: usize> Iterator for TensorReferenceMutIterator<'a, T, S, D>
where
    S: TensorMut<T, D>,
{
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        self.shape_iterator.next().map(|indexes| {
            unsafe {
                // SAFETY: the shape iterator only yields in-range indexes, so
                // the unchecked access is in bounds. The transmute only
                // extends the borrow of self to the source lifetime 'a; this
                // relies on the shape iterator never yielding the same index
                // twice and on the TensorRef/TensorMut contract that distinct
                // indexes reference distinct values, so no two &mut T handed
                // out can alias.
                std::mem::transmute::<&mut T, &mut T>(
                    self.source.get_reference_unchecked_mut(indexes)
                )
            }
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.shape_iterator.size_hint()
    }
}

impl<'a, T, S, const D: usize> DoubleEndedIterator for TensorReferenceMutIterator<'a, T, S, D>
where
    S: TensorMut<T, D>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.shape_iterator.next_back().map(|indexes| {
            unsafe {
                // SAFETY: as for next; back iteration also never repeats an
                // index already yielded from either end.
                std::mem::transmute::<&mut T, &mut T>(
                    self.source.get_reference_unchecked_mut(indexes)
                )
            }
        })
    }
}

// Fused and exactly sized because the underlying shape iterator is.
impl<'a, T, S, const D: usize> FusedIterator for TensorReferenceMutIterator<'a, T, S, D> where
    S: TensorMut<T, D>
{
}

impl<'a, T, S, const D: usize> ExactSizeIterator for TensorReferenceMutIterator<'a, T, S, D> where
    S: TensorMut<T, D>
{
}

impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
where
    S: TensorMut<T, D>,
{
    type Item = ([usize; D], &'a mut T);

    fn next(&mut self) -> Option<Self::Item> {
        // Capture the index before next() advances the shape iterator.
        let index = self.iterator.shape_iterator.forward_indexes;
        self.iterator.next().map(|x| (index, x))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iterator.size_hint()
    }
}

impl<'a, T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
where
    S: TensorMut<T, D>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        // Capture the index before next_back() steps the shape iterator back.
        let index = self.iterator.shape_iterator.back_indexes;
        self.iterator.next_back().map(|x| (index, x))
    }
}

impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>> where
    S: TensorMut<T, D>
{
}

impl<'a, T, S, const D: usize> ExactSizeIterator
    for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
where
    S: TensorMut<T, D>,
{
}
/// An iterator that moves values out of a tensor source, replacing each one
/// with a placeholder produced by `producer` so the source stays valid.
#[derive(Debug)]
pub struct TensorOwnedIterator<T, S, const D: usize> {
    // Produces the index of each element in turn.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The source the values are moved out of; owned by this iterator.
    source: S,
    // Creates the placeholder value swapped in for each element moved out.
    producer: fn() -> T,
}
impl<T, S, const D: usize> TensorOwnedIterator<T, S, D>
where
    S: TensorMut<T, D>,
{
    /// Constructs an iterator that moves values out of the source, replacing
    /// each yielded value with `T::default()`.
    pub fn from(source: S) -> TensorOwnedIterator<T, S, D>
    where
        T: Default,
    {
        let shape_iterator = DoubleEndedShapeIterator::from(source.view_shape());
        TensorOwnedIterator {
            shape_iterator,
            source,
            producer: T::default,
        }
    }

    /// Constructs an iterator that moves values out of the source, replacing
    /// each yielded value with `T::zero()`.
    pub fn from_numeric(source: S) -> TensorOwnedIterator<T, S, D>
    where
        T: crate::numeric::ZeroOne,
    {
        let shape_iterator = DoubleEndedShapeIterator::from(source.view_shape());
        TensorOwnedIterator {
            shape_iterator,
            source,
            producer: T::zero,
        }
    }

    /// Wraps this iterator so it also yields the index of each value.
    pub fn with_index(self) -> WithIndex<Self> {
        WithIndex { iterator: self }
    }
}

/// A WithIndex wrapper can also be created via From.
impl<T, S, const D: usize> From<TensorOwnedIterator<T, S, D>>
    for WithIndex<TensorOwnedIterator<T, S, D>>
where
    S: TensorMut<T, D>,
{
    fn from(iterator: TensorOwnedIterator<T, S, D>) -> Self {
        iterator.with_index()
    }
}
impl<T, S, const D: usize> Iterator for TensorOwnedIterator<T, S, D>
where
    S: TensorMut<T, D>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        self.shape_iterator.next().map(|indexes| {
            // Swap in a placeholder so the element can be moved out without
            // leaving the source holding an invalid value.
            let producer = self.producer;
            let dummy = producer();
            // SAFETY: the shape iterator only yields in-range indexes.
            std::mem::replace(
                unsafe { self.source.get_reference_unchecked_mut(indexes) },
                dummy,
            )
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.shape_iterator.size_hint()
    }
}

impl<T, S, const D: usize> DoubleEndedIterator for TensorOwnedIterator<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.shape_iterator.next_back().map(|indexes| {
            // Mirror of next, consuming from the back of the iteration.
            let producer = self.producer;
            let dummy = producer();
            // SAFETY: the shape iterator only yields in-range indexes.
            std::mem::replace(
                unsafe { self.source.get_reference_unchecked_mut(indexes) },
                dummy,
            )
        })
    }
}

// Fused and exactly sized because the underlying shape iterator is.
impl<T, S, const D: usize> FusedIterator for TensorOwnedIterator<T, S, D> where S: TensorMut<T, D> {}
impl<T, S, const D: usize> ExactSizeIterator for TensorOwnedIterator<T, S, D> where
    S: TensorMut<T, D>
{
}

impl<T, S, const D: usize> Iterator for WithIndex<TensorOwnedIterator<T, S, D>>
where
    S: TensorMut<T, D>,
{
    type Item = ([usize; D], T);

    fn next(&mut self) -> Option<Self::Item> {
        // Capture the index before next() advances the shape iterator.
        let index = self.iterator.shape_iterator.forward_indexes;
        self.iterator.next().map(|x| (index, x))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iterator.size_hint()
    }
}

impl<T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorOwnedIterator<T, S, D>>
where
    S: TensorMut<T, D>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        // Capture the index before next_back() steps the shape iterator back.
        let index = self.iterator.shape_iterator.back_indexes;
        self.iterator.next_back().map(|x| (index, x))
    }
}

impl<T, S, const D: usize> FusedIterator for WithIndex<TensorOwnedIterator<T, S, D>> where
    S: TensorMut<T, D>
{
}

impl<T, S, const D: usize> ExactSizeIterator for WithIndex<TensorOwnedIterator<T, S, D>> where
    S: TensorMut<T, D>
{
}
/// A view over a tensor where the order of the data appears transposed, while
/// the dimension names keep their original positions (compare view_shape with
/// TensorAccess, which reorders the names along with the data).
#[derive(Clone)]
pub struct TensorTranspose<T, S, const D: usize> {
    // The reordered access that performs the index translation.
    access: TensorAccess<T, S, D>,
}
impl<T: fmt::Debug, S: fmt::Debug, const D: usize> fmt::Debug for TensorTranspose<T, S, D> {
    /// Debug-formats the TensorTranspose, presenting the fields of the
    /// wrapped TensorAccess as if they were its own.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let access = &self.access;
        f.debug_struct("TensorTranspose")
            .field("source", &access.source)
            .field("dimension_mapping", &access.dimension_mapping)
            .field("_type", &access._type)
            .finish()
    }
}
impl<T, S, const D: usize> TensorTranspose<T, S, D>
where
    S: TensorRef<T, D>,
{
    /// Creates a TensorTranspose which makes the data of the source appear in
    /// the order of the supplied dimensions.
    ///
    /// # Panics
    ///
    /// Panics if the dimensions are not the same set of dimensions as in the
    /// source's shape.
    #[track_caller]
    pub fn from(source: S, dimensions: [Dimension; D]) -> TensorTranspose<T, S, D> {
        TensorTranspose {
            access: match TensorAccess::try_from(source, dimensions) {
                Err(error) => panic!("{}", error),
                Ok(success) => success,
            },
        }
    }

    /// As from, but returning an error instead of panicking when the
    /// dimensions do not match the source's shape.
    pub fn try_from(
        source: S,
        dimensions: [Dimension; D],
    ) -> Result<TensorTranspose<T, S, D>, InvalidDimensionsError<D>> {
        TensorAccess::try_from(source, dimensions).map(|access| TensorTranspose { access })
    }

    /// The shape of the TensorTranspose: the source's dimension names kept in
    /// their original positions, paired with the lengths from the reordered
    /// access, so the data appears transposed rather than renamed.
    pub fn shape(&self) -> [(Dimension, usize); D] {
        let names = self.access.source.view_shape();
        let order = self.access.shape();
        std::array::from_fn(|d| (names[d].0, order[d].1))
    }

    /// Consumes the TensorTranspose, returning the source it was created from.
    pub fn source(self) -> S {
        self.access.source
    }

    /// Gives a reference to the source this TensorTranspose was created from.
    pub fn source_ref(&self) -> &S {
        &self.access.source
    }
}
// # Safety
//
// A TensorTranspose delegates to the TensorAccess it wraps, which in turn
// delegates to a TensorRef source, so the contract is upheld.
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorTranspose<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        self.access.try_get_reference(indexes)
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        // Source dimension names paired with the transposed lengths.
        self.shape()
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller promises the indexes are in range.
        unsafe { self.access.get_reference_unchecked(indexes) }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // A Linear layout has to be remapped so it stays consistent with the
        // transposed view of the data this type presents.
        let data_layout = self.access.data_layout();
        match data_layout {
            DataLayout::Linear(order) => DataLayout::Linear(
                self.access
                    .dimension_mapping
                    .map_linear_data_layout_to_transposed(&order),
            ),
            _ => data_layout,
        }
    }
}

// # Safety
//
// As with TensorRef, all methods delegate to the wrapped TensorAccess.
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorTranspose<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        self.access.try_get_reference_mut(indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller promises the indexes are in range.
        unsafe { self.access.get_reference_unchecked_mut(indexes) }
    }
}
impl<T, S, const D: usize> std::fmt::Display for TensorTranspose<T, S, D>
where
    T: std::fmt::Display,
    S: TensorRef<T, D>,
{
    /// Displays the transposed tensor data followed by its data layout.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        crate::tensors::display::format_view(&self, f)?;
        writeln!(f)?;
        write!(f, "Data Layout = {:?}", self.data_layout())
    }
}