use std::marker::PhantomData;
use crate::linear_algebra;
use crate::numeric::{Numeric, NumericRef};
use crate::tensors::dimensions;
use crate::tensors::indexing::{
TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
TensorReferenceMutIterator, TensorTranspose,
};
use crate::tensors::{Dimension, Tensor};
mod indexes;
mod map;
mod ranges;
mod renamed;
mod reshape;
mod reverse;
pub mod traits;
mod zip;
pub use indexes::*;
pub(crate) use map::*;
pub use ranges::*;
pub use renamed::*;
pub use reshape::*;
pub use reverse::*;
pub use zip::*;
/// Read access to the elements of a tensor-like source with `D` named dimensions.
///
/// # Safety
///
/// This trait is `unsafe` to implement because callers of
/// [`get_reference_unchecked`](TensorRef::get_reference_unchecked) skip bounds checks and rely
/// on the implementation being correct. NOTE(review): the full contract is not spelled out in
/// this file. Presumably every index accepted by `get_reference` must be safe to pass to
/// `get_reference_unchecked`, and `view_shape` must describe exactly that index space. Confirm
/// against the crate documentation.
pub unsafe trait TensorRef<T, const D: usize> {
    /// Returns the `(dimension name, length)` pairs describing this source.
    fn view_shape(&self) -> [(Dimension, usize); D];

    /// Returns a reference to the element at `indexes`, or `None`. Presumably `None` is returned
    /// when an index is out of range for [`view_shape`](TensorRef::view_shape).
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T>;

    /// Returns a reference to the element at `indexes` without checking the indexes.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably the caller must guarantee that `indexes` are in range for
    /// [`view_shape`](TensorRef::view_shape). Confirm.
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T;

    /// Reports how this source's data is laid out; see [`DataLayout`].
    fn data_layout(&self) -> DataLayout<D>;
}
/// Layout information a [`TensorRef`] reports through [`TensorRef::data_layout`].
///
/// NOTE(review): the variant meanings below are inferred from their names and are not
/// established in this file. Confirm before relying on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DataLayout<const D: usize> {
    /// Presumably: the data is stored contiguously, with dimensions in the listed order.
    Linear([Dimension; D]),
    /// Presumably: the data is not laid out as a single linear block.
    NonLinear,
    /// Presumably: the layout is unknown or does not fit the other variants.
    Other,
}
/// Mutable element access on top of [`TensorRef`].
///
/// # Safety
///
/// `unsafe` to implement for the same reason as [`TensorRef`]: callers of
/// [`get_reference_unchecked_mut`](TensorMut::get_reference_unchecked_mut) skip bounds checks.
/// NOTE(review): presumably the same index contract as `TensorRef` applies. Confirm.
pub unsafe trait TensorMut<T, const D: usize>: TensorRef<T, D> {
    /// Returns a mutable reference to the element at `indexes`, or `None` (presumably when an
    /// index is out of range).
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T>;

    /// Returns a mutable reference to the element at `indexes` without checking the indexes.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably the caller must guarantee that `indexes` are in range for
    /// [`TensorRef::view_shape`]. Confirm.
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T;
}
/// A view with `D` named dimensions over some tensor-like `source`.
///
/// `T` is the element type. It appears only as a `PhantomData` marker; elements are always
/// reached through `source`.
#[derive(Clone)]
pub struct TensorView<T, S, const D: usize> {
    /// The wrapped source that all element access goes through.
    source: S,
    /// Marker that ties the element type `T` to this view.
    _type: PhantomData<T>,
}
/// Methods available on any [`TensorView`] whose source implements [`TensorRef`].
impl<T, S, const D: usize> TensorView<T, S, D>
where
    S: TensorRef<T, D>,
{
    /// Wraps `source` in a `TensorView`.
    pub fn from(source: S) -> TensorView<T, S, D> {
        TensorView {
            source,
            _type: PhantomData,
        }
    }

    /// Consumes the view and returns the wrapped source.
    pub fn source(self) -> S {
        self.source
    }

    /// Borrows the wrapped source.
    pub fn source_ref(&self) -> &S {
        &self.source
    }

    /// Mutably borrows the wrapped source.
    pub fn source_ref_mut(&mut self) -> &mut S {
        &mut self.source
    }

    /// Returns this view's `(dimension name, length)` pairs, as reported by
    /// [`TensorRef::view_shape`].
    pub fn shape(&self) -> [(Dimension, usize); D] {
        self.source.view_shape()
    }

    /// Returns the length of `dimension` in this view, as computed by `dimensions::length_of`.
    /// Presumably returns `None` when the view has no dimension with that name.
    pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::length_of(&self.source.view_shape(), dimension)
    }

    /// Returns the last valid index of `dimension` in this view, as computed by
    /// `dimensions::last_index_of`. Presumably returns `None` when the name is not present.
    pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::last_index_of(&self.source.view_shape(), dimension)
    }

    /// Borrows the source and returns an accessor that indexes in the given `dimensions` order.
    ///
    /// # Panics
    ///
    /// Panics if `TensorAccess::from` rejects `dimensions`. Presumably it does so when they are
    /// not a reordering of this view's dimension names.
    #[track_caller]
    pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &S, D> {
        TensorAccess::from(&self.source, dimensions)
    }

    /// Mutably borrowing variant of [`index_by`](TensorView::index_by).
    ///
    /// # Panics
    ///
    /// Same conditions as [`index_by`](TensorView::index_by).
    #[track_caller]
    pub fn index_by_mut(&mut self, dimensions: [Dimension; D]) -> TensorAccess<T, &mut S, D> {
        TensorAccess::from(&mut self.source, dimensions)
    }

    /// Consuming variant of [`index_by`](TensorView::index_by).
    ///
    /// # Panics
    ///
    /// Same conditions as [`index_by`](TensorView::index_by).
    #[track_caller]
    pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
        TensorAccess::from(self.source, dimensions)
    }

    /// Borrows the source and returns an accessor that indexes in the source's own dimension
    /// order.
    pub fn index(&self) -> TensorAccess<T, &S, D> {
        TensorAccess::from_source_order(&self.source)
    }

    /// Mutably borrowing variant of [`index`](TensorView::index).
    pub fn index_mut(&mut self) -> TensorAccess<T, &mut S, D> {
        TensorAccess::from_source_order(&mut self.source)
    }

    /// Consuming variant of [`index`](TensorView::index).
    pub fn index_owned(self) -> TensorAccess<T, S, D> {
        TensorAccess::from_source_order(self.source)
    }

    /// Returns an iterator of references to the view's elements.
    pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, S, D> {
        TensorReferenceIterator::from(&self.source)
    }

    /// Returns a view of the same data whose dimensions are renamed to `dimensions`.
    ///
    /// # Panics
    ///
    /// Panics if `TensorRename::from` rejects the new names (presumably duplicate names).
    #[track_caller]
    pub fn rename_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorRename<T, &S, D>, D> {
        TensorView::from(TensorRename::from(&self.source, dimensions))
    }

    /// Returns a borrowing view of this data with a new `D2`-dimensional `shape`.
    pub fn reshape<const D2: usize>(
        &self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &S, D, D2>, D2> {
        TensorView::from(TensorReshape::from(&self.source, shape))
    }

    /// Mutably borrowing variant of [`reshape`](TensorView::reshape).
    pub fn reshape_mut<const D2: usize>(
        &mut self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &mut S, D, D2>, D2> {
        TensorView::from(TensorReshape::from(&mut self.source, shape))
    }

    /// Consuming variant of [`reshape`](TensorView::reshape).
    pub fn reshape_owned<const D2: usize>(
        self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, S, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self.source, shape))
    }

    /// Reshapes into one dimension named `dimension`, whose length is the view's total element
    /// count.
    pub fn flatten(&self, dimension: Dimension) -> TensorView<T, TensorReshape<T, &S, D, 1>, 1> {
        self.reshape([(dimension, dimensions::elements(&self.shape()))])
    }

    /// Mutably borrowing variant of [`flatten`](TensorView::flatten).
    pub fn flatten_mut(
        &mut self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, &mut S, D, 1>, 1> {
        self.reshape_mut([(dimension, dimensions::elements(&self.shape()))])
    }

    /// Consuming variant of [`flatten`](TensorView::flatten).
    pub fn flatten_owned(
        self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, S, D, 1>, 1> {
        // Compute the length before `self` is moved into `reshape_owned`.
        let length = dimensions::elements(&self.shape());
        self.reshape_owned([(dimension, length)])
    }

    /// Copies every element, in iteration order, into a new 1-dimensional [`Tensor`] named
    /// `dimension`.
    pub fn flatten_into_tensor(self, dimension: Dimension) -> Tensor<T, 1>
    where
        T: Clone,
    {
        let length = dimensions::elements(&self.shape());
        Tensor::from([(dimension, length)], self.iter().collect())
    }

    /// Returns a borrowing view restricted to the given index `ranges`.
    ///
    /// # Errors
    ///
    /// Returns the `IndexRangeValidationError` that `TensorRange::from` reports when `ranges`
    /// are not valid for this view.
    pub fn range<R, const P: usize>(
        &self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, &S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(&self.source, ranges).map(|range| TensorView::from(range))
    }

    /// Mutably borrowing variant of [`range`](TensorView::range).
    ///
    /// # Errors
    ///
    /// Same conditions as [`range`](TensorView::range).
    pub fn range_mut<R, const P: usize>(
        &mut self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(&mut self.source, ranges).map(|range| TensorView::from(range))
    }

    /// Consuming variant of [`range`](TensorView::range).
    ///
    /// # Errors
    ///
    /// Same conditions as [`range`](TensorView::range).
    pub fn range_owned<R, const P: usize>(
        self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(self.source, ranges).map(|range| TensorView::from(range))
    }

    /// Returns a borrowing view that hides the given index `masks`.
    ///
    /// # Errors
    ///
    /// Returns the `IndexRangeValidationError` that `TensorMask::from` reports when `masks`
    /// are not valid for this view.
    pub fn mask<R, const P: usize>(
        &self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, &S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(&self.source, masks).map(|mask| TensorView::from(mask))
    }

    /// Mutably borrowing variant of [`mask`](TensorView::mask).
    ///
    /// # Errors
    ///
    /// Same conditions as [`mask`](TensorView::mask).
    pub fn mask_mut<R, const P: usize>(
        &mut self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(&mut self.source, masks).map(|mask| TensorView::from(mask))
    }

    /// Consuming variant of [`mask`](TensorView::mask).
    ///
    /// # Errors
    ///
    /// Same conditions as [`mask`](TensorView::mask).
    pub fn mask_owned<R, const P: usize>(
        self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(self.source, masks).map(|mask| TensorView::from(mask))
    }

    /// Returns a borrowing masked view built by `TensorMask::panicking_start_and_end_of`.
    /// NOTE(review): presumably this hides `start_and_end` indexes at both the start and the
    /// end of `dimension`. Confirm.
    ///
    /// # Panics
    ///
    /// Panics wherever `TensorMask::panicking_start_and_end_of` does.
    #[track_caller]
    pub fn start_and_end_of(
        &self,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, &S, D>, D> {
        // The mask must be wrapped to match the declared `TensorView` return type, as
        // `mask`/`range` do.
        TensorView::from(TensorMask::panicking_start_and_end_of(
            &self.source,
            dimension,
            start_and_end,
        ))
    }

    /// Mutably borrowing variant of [`start_and_end_of`](TensorView::start_and_end_of).
    ///
    /// # Panics
    ///
    /// Same conditions as [`start_and_end_of`](TensorView::start_and_end_of).
    #[track_caller]
    pub fn start_and_end_of_mut(
        &mut self,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, &mut S, D>, D> {
        TensorView::from(TensorMask::panicking_start_and_end_of(
            &mut self.source,
            dimension,
            start_and_end,
        ))
    }

    /// Consuming variant of [`start_and_end_of`](TensorView::start_and_end_of).
    ///
    /// # Panics
    ///
    /// Same conditions as [`start_and_end_of`](TensorView::start_and_end_of).
    #[track_caller]
    pub fn start_and_end_of_owned(
        self,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, S, D>, D> {
        TensorView::from(TensorMask::panicking_start_and_end_of(
            self.source,
            dimension,
            start_and_end,
        ))
    }

    /// Returns a borrowing view with the listed `dimensions` reversed.
    ///
    /// # Panics
    ///
    /// Panics wherever `TensorReverse::from` does (presumably on unknown dimension names).
    #[track_caller]
    pub fn reverse(&self, dimensions: &[Dimension]) -> TensorView<T, TensorReverse<T, &S, D>, D> {
        TensorView::from(TensorReverse::from(&self.source, dimensions))
    }

    /// Mutably borrowing variant of [`reverse`](TensorView::reverse).
    ///
    /// # Panics
    ///
    /// Same conditions as [`reverse`](TensorView::reverse).
    #[track_caller]
    pub fn reverse_mut(
        &mut self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, &mut S, D>, D> {
        TensorView::from(TensorReverse::from(&mut self.source, dimensions))
    }

    /// Consuming variant of [`reverse`](TensorView::reverse).
    ///
    /// # Panics
    ///
    /// Same conditions as [`reverse`](TensorView::reverse).
    #[track_caller]
    pub fn reverse_owned(
        self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, S, D>, D> {
        TensorView::from(TensorReverse::from(self.source, dimensions))
    }

    /// Combines this view with `rhs` element by element, passing references to
    /// `mapping_function`. The result has this view's shape.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes (names and lengths) are not identical.
    #[track_caller]
    pub fn elementwise_reference<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S2, D>>,
        S2: TensorRef<T, D>,
        M: Fn(&T, &T) -> T,
    {
        self.elementwise_reference_less_generic(rhs.into(), mapping_function)
    }

    /// Like [`elementwise_reference`](TensorView::elementwise_reference), but also passes the
    /// index of each element, taken from this view's iteration order.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes are not identical.
    #[track_caller]
    pub fn elementwise_reference_with_index<S2, I, M>(
        &self,
        rhs: I,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S2, D>>,
        S2: TensorRef<T, D>,
        M: Fn([usize; D], &T, &T) -> T,
    {
        self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
    }

    /// Non-`Into` core of `elementwise_reference`: checks the shapes, then zips both iterators.
    #[track_caller]
    fn elementwise_reference_less_generic<S2, M>(
        &self,
        rhs: TensorView<T, S2, D>,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        S2: TensorRef<T, D>,
        M: Fn(&T, &T) -> T,
    {
        let left_shape = self.shape();
        let right_shape = rhs.shape();
        if left_shape != right_shape {
            panic!(
                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
                left_shape, right_shape
            );
        }
        let mapped = self
            .iter_reference()
            .zip(rhs.iter_reference())
            .map(|(x, y)| mapping_function(x, y))
            .collect();
        Tensor::from(left_shape, mapped)
    }

    /// Non-`Into` core of `elementwise_reference_with_index`.
    #[track_caller]
    fn elementwise_reference_less_generic_with_index<S2, M>(
        &self,
        rhs: TensorView<T, S2, D>,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        S2: TensorRef<T, D>,
        M: Fn([usize; D], &T, &T) -> T,
    {
        let left_shape = self.shape();
        let right_shape = rhs.shape();
        if left_shape != right_shape {
            panic!(
                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
                left_shape, right_shape
            );
        }
        let mapped = self
            .iter_reference()
            .with_index()
            .zip(rhs.iter_reference())
            .map(|((i, x), y)| mapping_function(i, x, y))
            .collect();
        Tensor::from(left_shape, mapped)
    }

    /// Returns a borrowing transposed view in the given `dimensions` order, built with
    /// `TensorTranspose::from`.
    pub fn transpose_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorTranspose<T, &S, D>, D> {
        TensorView::from(TensorTranspose::from(&self.source, dimensions))
    }

    /// Builds an owned `Tensor` view with this view's shape and the given row-major `data`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the number of elements in this view's shape. Without
    /// this check, a mismatched `Vec` would produce a `Tensor` whose strides index past its
    /// data, breaking the invariant its unsafe `TensorRef` implementation relies on.
    #[track_caller]
    pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> TensorView<T, Tensor<T, D>, D> {
        let shape = self.shape();
        assert_eq!(
            data.len(),
            dimensions::elements(&shape),
            "Data length must match the number of elements in the shape {:?}",
            shape,
        );
        let strides = crate::tensors::compute_strides(&shape);
        TensorView::from(Tensor {
            data,
            shape,
            strides,
        })
    }
}
/// Methods on views whose source also supports mutable access through [`TensorMut`].
impl<T, S, const D: usize> TensorView<T, S, D>
where
    S: TensorMut<T, D>,
{
    /// Returns an iterator of mutable references to the view's elements.
    pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<'_, T, S, D> {
        let source = &mut self.source;
        TensorReferenceMutIterator::from(source)
    }

    /// Consumes the view and returns an iterator over owned elements.
    ///
    /// `TensorOwnedIterator` requires `T: Default`. NOTE(review): presumably each element is
    /// moved out by swapping in `T::default()`. Confirm.
    pub fn iter_owned(self) -> TensorOwnedIterator<T, S, D>
    where
        T: Default,
    {
        let source = self.source;
        TensorOwnedIterator::from(source)
    }
}
/// Methods on views whose elements can be cloned out of the source.
impl<T, S, const D: usize> TensorView<T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
    /// Returns a clone of the first element yielded by [`iter`](TensorView::iter).
    ///
    /// # Panics
    ///
    /// Panics if the iterator yields nothing.
    pub fn first(&self) -> T {
        let mut elements = self.iter();
        elements
            .next()
            .expect("Tensors always have at least 1 element")
    }

    /// Copies the data into a new `Tensor`, read in the given `dimensions` order. The
    /// dimension names stay in this view's original order; only the data moves.
    ///
    /// # Panics
    ///
    /// Same conditions as [`reorder`](TensorView::reorder).
    #[track_caller]
    pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        let original_shape = self.shape();
        let mut result = self.reorder(dimensions);
        // `reorder` also renames the dimensions; put the original names back in place.
        for (slot, &(name, _)) in result.shape.iter_mut().zip(original_shape.iter()) {
            slot.0 = name;
        }
        result
    }

    /// Copies the data into a new `Tensor` whose dimensions, names included, follow the
    /// `dimensions` order.
    ///
    /// # Panics
    ///
    /// Panics if `dimensions` is not the same set of names as this view's dimensions.
    #[track_caller]
    pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        let access = match TensorAccess::try_from(&self.source, dimensions) {
            Err(_error) => panic!(
                "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
                dimensions,
                self.shape(),
            ),
            Ok(access) => access,
        };
        let new_shape = access.shape();
        let values = access.iter().collect();
        Tensor::from(new_shape, values)
    }

    /// Returns a new `Tensor` with the same shape, where each element is
    /// `mapping_function(element)`.
    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
        let values = self.iter().map(mapping_function).collect();
        let shape = self.shape();
        Tensor::from(shape, values)
    }

    /// Like [`map`](TensorView::map), but also passes each element's index.
    pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
        let values = self
            .iter()
            .with_index()
            .map(|(index, value)| mapping_function(index, value))
            .collect();
        let shape = self.shape();
        Tensor::from(shape, values)
    }

    /// Returns an iterator over clones of the view's elements.
    pub fn iter(&self) -> TensorIterator<'_, T, S, D> {
        let source = &self.source;
        TensorIterator::from(source)
    }

    /// Combines this view with `rhs` element by element, passing cloned values to
    /// `mapping_function`.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes are not identical.
    #[track_caller]
    pub fn elementwise<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S2, D>>,
        S2: TensorRef<T, D>,
        M: Fn(T, T) -> T,
    {
        let other: TensorView<T, S2, D> = rhs.into();
        self.elementwise_reference_less_generic(other, |left, right| {
            mapping_function(left.clone(), right.clone())
        })
    }

    /// Like [`elementwise`](TensorView::elementwise), but also passes each element's index.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes are not identical.
    #[track_caller]
    pub fn elementwise_with_index<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S2, D>>,
        S2: TensorRef<T, D>,
        M: Fn([usize; D], T, T) -> T,
    {
        let other: TensorView<T, S2, D> = rhs.into();
        self.elementwise_reference_less_generic_with_index(other, |index, left, right| {
            mapping_function(index, left.clone(), right.clone())
        })
    }
}
/// Read access to the single value of a 0-dimensional view.
impl<T, S> TensorView<T, S, 0>
where
    T: Clone,
    S: TensorRef<T, 0>,
{
    /// Returns a clone of the only element of this 0-dimensional view.
    ///
    /// # Panics
    ///
    /// Panics if the source returns `None` for the empty index `[]`.
    pub fn scalar(&self) -> T {
        let value: &T = self.source.get_reference([]).unwrap();
        value.clone()
    }
}
/// Consuming access to the single value of a 0-dimensional view.
impl<T, S> TensorView<T, S, 0>
where
    T: Default,
    S: TensorMut<T, 0>,
{
    /// Consumes the view and returns its only element, taken through `TensorOwnedIterator`.
    ///
    /// # Panics
    ///
    /// Panics if the owned iterator yields no element.
    pub fn into_scalar(self) -> T {
        let mut owned = TensorOwnedIterator::from(self.source);
        owned.next().unwrap()
    }
}
/// In-place mapping for views with mutable sources.
impl<T, S, const D: usize> TensorView<T, S, D>
where
    T: Clone,
    S: TensorMut<T, D>,
{
    /// Replaces every element with `mapping_function(element.clone())`.
    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
        for element in self.iter_reference_mut() {
            *element = mapping_function(element.clone());
        }
    }

    /// Like [`map_mut`](TensorView::map_mut), but also passes each element's index.
    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
        for (index, element) in self.iter_reference_mut().with_index() {
            *element = mapping_function(index, element.clone());
        }
    }
}
/// Vector operations on numeric 1-dimensional views.
impl<T, S> TensorView<T, S, 1>
where
    T: Numeric,
    for<'a> &'a T: NumericRef<T>,
    S: TensorRef<T, 1>,
{
    /// Computes the scalar product of this vector with `rhs`, delegating to
    /// `scalar_product_less_generic`.
    pub fn scalar_product<S2, I>(&self, rhs: I) -> T
    where
        I: Into<TensorView<T, S2, 1>>,
        S2: TensorRef<T, 1>,
    {
        let other: TensorView<T, S2, 1> = rhs.into();
        self.scalar_product_less_generic(other)
    }
}
/// Matrix operations on numeric 2-dimensional views. Each method forwards to the matching
/// function in `linear_algebra`.
impl<T, S> TensorView<T, S, 2>
where
    T: Numeric,
    for<'a> &'a T: NumericRef<T>,
    S: TensorRef<T, 2>,
{
    /// Returns `linear_algebra::determinant_tensor` of this matrix. NOTE(review): the
    /// conditions that produce `None` are defined there.
    pub fn determinant(&self) -> Option<T> {
        let matrix = self;
        linear_algebra::determinant_tensor::<T, _, _>(matrix)
    }

    /// Returns `linear_algebra::inverse_tensor` of this matrix. NOTE(review): presumably
    /// `None` when the matrix is not invertible. Confirm.
    pub fn inverse(&self) -> Option<Tensor<T, 2>> {
        let matrix = self;
        linear_algebra::inverse_tensor::<T, _, _>(matrix)
    }

    /// Returns `linear_algebra::covariance` of this matrix. NOTE(review): presumably
    /// `feature_dimension` names the dimension that indexes features. Confirm.
    pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
        let matrix = self;
        linear_algebra::covariance::<T, _, _>(matrix, feature_dimension)
    }
}
/// Generates `select`, `select_mut` and `select_owned` for `TensorView`s of a literal
/// dimensionality `$d`. Each takes one `(Dimension, usize)` pair and returns a
/// `TensorIndex`-backed view with `$d - 1` dimensions.
macro_rules! tensor_view_select_impl {
    (impl TensorView $d:literal 1) => {
        impl<T, S> TensorView<T, S, $d>
        where
            S: TensorRef<T, $d>,
        {
            /// Borrows the source and returns a view with one fewer dimension, built by
            /// `TensorIndex::from` from `provided_indexes`. Presumably that dimension is fixed
            /// at the given index.
            #[track_caller]
            pub fn select(
                &self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, &S, $d, 1>, { $d - 1 }> {
                let indexed = TensorIndex::from(&self.source, provided_indexes);
                TensorView::from(indexed)
            }

            /// Mutably borrowing variant of `select`.
            #[track_caller]
            pub fn select_mut(
                &mut self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, &mut S, $d, 1>, { $d - 1 }> {
                let indexed = TensorIndex::from(&mut self.source, provided_indexes);
                TensorView::from(indexed)
            }

            /// Consuming variant of `select`.
            #[track_caller]
            pub fn select_owned(
                self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, S, $d, 1>, { $d - 1 }> {
                let indexed = TensorIndex::from(self.source, provided_indexes);
                TensorView::from(indexed)
            }
        }
    };
}

tensor_view_select_impl!(impl TensorView 1 1);
tensor_view_select_impl!(impl TensorView 2 1);
tensor_view_select_impl!(impl TensorView 3 1);
tensor_view_select_impl!(impl TensorView 4 1);
tensor_view_select_impl!(impl TensorView 5 1);
tensor_view_select_impl!(impl TensorView 6 1);
/// Generates `expand`, `expand_mut` and `expand_owned` for `TensorView`s of a literal
/// dimensionality `$d`. Each takes one `(usize, Dimension)` pair and returns a
/// `TensorExpansion`-backed view with `$d + 1` dimensions.
macro_rules! tensor_view_expand_impl {
    (impl Tensor $d:literal 1) => {
        impl<T, S> TensorView<T, S, $d>
        where
            S: TensorRef<T, $d>,
        {
            /// Borrows the source and returns a view with one extra dimension, built by
            /// `TensorExpansion::from`. NOTE(review): presumably the new dimension is inserted
            /// at the given position with length 1. Confirm.
            #[track_caller]
            pub fn expand(
                &self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, &S, $d, 1>, { $d + 1 }> {
                let expanded = TensorExpansion::from(&self.source, extra_dimension_names);
                TensorView::from(expanded)
            }

            /// Mutably borrowing variant of `expand`.
            #[track_caller]
            pub fn expand_mut(
                &mut self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, &mut S, $d, 1>, { $d + 1 }> {
                let expanded = TensorExpansion::from(&mut self.source, extra_dimension_names);
                TensorView::from(expanded)
            }

            /// Consuming variant of `expand`.
            #[track_caller]
            pub fn expand_owned(
                self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, S, $d, 1>, { $d + 1 }> {
                let expanded = TensorExpansion::from(self.source, extra_dimension_names);
                TensorView::from(expanded)
            }
        }
    };
}

tensor_view_expand_impl!(impl Tensor 0 1);
tensor_view_expand_impl!(impl Tensor 1 1);
tensor_view_expand_impl!(impl Tensor 2 1);
tensor_view_expand_impl!(impl Tensor 3 1);
tensor_view_expand_impl!(impl Tensor 4 1);
tensor_view_expand_impl!(impl Tensor 5 1);
/// Debug output shows the visible elements, then the view's shape, then the wrapped source.
impl<T, S, const D: usize> std::fmt::Debug for TensorView<T, S, D>
where
    T: std::fmt::Debug,
    S: std::fmt::Debug + TensorRef<T, D>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut out = f.debug_struct("TensorView");
        out.field("visible", &DebugSourceVisible::from(&self.source));
        out.field("shape", &self.source.view_shape());
        out.field("source", &self.source);
        out.finish()
    }
}
/// Private helper whose `Debug` output lists the elements a source exposes, in iteration
/// order.
struct DebugSourceVisible<T, S, const D: usize> {
    source: S,
    _type: PhantomData<T>,
}

impl<T, S, const D: usize> DebugSourceVisible<T, S, D>
where
    T: std::fmt::Debug,
    S: std::fmt::Debug + TensorRef<T, D>,
{
    /// Wraps `source` for list-style debug printing.
    fn from(source: S) -> DebugSourceVisible<T, S, D> {
        let _type = PhantomData;
        DebugSourceVisible { source, _type }
    }
}

impl<T, S, const D: usize> std::fmt::Debug for DebugSourceVisible<T, S, D>
where
    T: std::fmt::Debug,
    S: std::fmt::Debug + TensorRef<T, D>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let elements = TensorReferenceIterator::from(&self.source);
        let mut list = f.debug_list();
        list.entries(elements);
        list.finish()
    }
}
/// The compact `{:?}` output of a tensor and of a full view over it.
#[test]
fn test_debug() {
    let tensor = Tensor::from([("rows", 3), ("columns", 4)], (0..12).collect());
    let full_view = TensorView::from(&tensor);
    let expected = r#"Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] }
TensorView { visible: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], source: Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] } }"#;
    let actual = format!("{:?}\n{:?}", tensor, full_view);
    assert_eq!(actual, expected)
}
/// The pretty `{:#?}` output of a tensor and of a range-clipped view over it. `visible`
/// lists only the elements inside the range, while `source` still shows the full tensor.
/// The expected text keeps the 4-space-per-level indentation that `{:#?}` emits.
#[test]
fn test_debug_clipped() {
    let x = Tensor::from([("rows", 2), ("columns", 3)], (0..6).collect());
    let view = TensorView::from(&x)
        .range_owned([("columns", IndexRange::new(1, 2))])
        .unwrap();
    let debugged = format!("{:#?}\n{:#?}", x, view);
    assert_eq!(
        debugged,
        r#"Tensor {
    data: [
        0,
        1,
        2,
        3,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            3,
        ),
    ],
    strides: [
        3,
        1,
    ],
}
TensorView {
    visible: [
        1,
        2,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            2,
        ),
    ],
    source: TensorRange {
        source: Tensor {
            data: [
                0,
                1,
                2,
                3,
                4,
                5,
            ],
            shape: [
                (
                    "rows",
                    2,
                ),
                (
                    "columns",
                    3,
                ),
            ],
            strides: [
                3,
                1,
            ],
        },
        range: [
            IndexRange {
                start: 0,
                length: 2,
            },
            IndexRange {
                start: 1,
                length: 2,
            },
        ],
        _type: PhantomData<i32>,
    },
}"#
    )
}
/// Formats the visible elements of the view using `crate::tensors::display::format_view`.
// The `T: Display` bound was previously written both inline and in the where clause; it is
// now stated once.
impl<T, S, const D: usize> std::fmt::Display for TensorView<T, S, D>
where
    T: std::fmt::Display,
    S: TensorRef<T, D>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        crate::tensors::display::format_view(&self.source, f)
    }
}
/// Wraps an owned [`Tensor`] in a [`TensorView`].
impl<T, const D: usize> From<Tensor<T, D>> for TensorView<T, Tensor<T, D>, D> {
    fn from(tensor: Tensor<T, D>) -> Self {
        TensorView::from(tensor)
    }
}

/// Wraps a shared borrow of a [`Tensor`] in a [`TensorView`].
impl<'a, T, const D: usize> From<&'a Tensor<T, D>> for TensorView<T, &'a Tensor<T, D>, D> {
    fn from(tensor: &'a Tensor<T, D>) -> Self {
        TensorView::from(tensor)
    }
}

/// Wraps a mutable borrow of a [`Tensor`] in a [`TensorView`].
impl<'a, T, const D: usize> From<&'a mut Tensor<T, D>> for TensorView<T, &'a mut Tensor<T, D>, D> {
    fn from(tensor: &'a mut Tensor<T, D>) -> Self {
        TensorView::from(tensor)
    }
}

/// Builds a view that shares the source of an existing view.
impl<'a, T, S, const D: usize> From<&'a TensorView<T, S, D>> for TensorView<T, &'a S, D>
where
    S: TensorRef<T, D>,
{
    fn from(tensor_view: &'a TensorView<T, S, D>) -> Self {
        let shared_source = tensor_view.source_ref();
        TensorView::from(shared_source)
    }
}

/// Builds a view that mutably borrows the source of an existing view.
impl<'a, T, S, const D: usize> From<&'a mut TensorView<T, S, D>> for TensorView<T, &'a mut S, D>
where
    S: TensorRef<T, D>,
{
    fn from(tensor_view: &'a mut TensorView<T, S, D>) -> Self {
        let exclusive_source = tensor_view.source_ref_mut();
        TensorView::from(exclusive_source)
    }
}