use crate::linear_algebra;
use crate::numeric::extra::{Real, RealRef};
use crate::numeric::{Numeric, NumericRef};
use crate::tensors::indexing::{
ShapeIterator, TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
TensorReferenceMutIterator, TensorTranspose,
};
use crate::tensors::views::{
DataLayout, IndexRange, IndexRangeValidationError, TensorExpansion, TensorIndex, TensorMask,
TensorMut, TensorRange, TensorRef, TensorRename, TensorReshape, TensorReverse, TensorView,
};
use std::error::Error;
use std::fmt;
#[cfg(feature = "serde")]
use serde::Serialize;
pub mod dimensions;
mod display;
pub mod indexing;
pub mod operations;
pub mod views;
#[cfg(feature = "serde")]
pub use serde_impls::TensorDeserialize;
/// The name of a dimension in a tensor's shape; dimension names are static strings.
pub type Dimension = &'static str;
/// An error indicating that a shape is not valid for constructing a Tensor
/// (duplicate dimension names or a dimension of length 0), or that its number
/// of elements does not match the provided data.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InvalidShapeError<const D: usize> {
    shape: [(Dimension, usize); D],
}
impl<const D: usize> InvalidShapeError<D> {
    /// Checks if this shape is valid on its own terms: no duplicate dimension
    /// names and no dimension of length 0. Note this does not (and cannot)
    /// check the data length, unlike `validate_dimensions`.
    pub fn is_valid(&self) -> bool {
        !crate::tensors::dimensions::has_duplicates(&self.shape)
            && !self.shape.iter().any(|d| d.1 == 0)
    }
    /// Constructs an InvalidShapeError for the given shape (whether or not
    /// the shape is actually invalid).
    pub fn new(shape: [(Dimension, usize); D]) -> InvalidShapeError<D> {
        InvalidShapeError { shape }
    }
    /// Returns a copy of the shape this error was constructed with.
    pub fn shape(&self) -> [(Dimension, usize); D] {
        self.shape
    }
    /// Returns a reference to the shape this error was constructed with.
    pub fn shape_ref(&self) -> &[(Dimension, usize); D] {
        &self.shape
    }
    // Panics with a specific message for each failure mode unless the shape is
    // valid and its total number of elements matches data_len. The element
    // count is checked first, then duplicate names, then zero-length
    // dimensions — keep this order so callers see consistent messages.
    #[track_caller]
    #[inline]
    fn validate_dimensions_or_panic(shape: &[(Dimension, usize); D], data_len: usize) {
        let elements = crate::tensors::dimensions::elements(shape);
        if data_len != elements {
            panic!(
                "Product of dimension lengths must match size of data. {} != {}",
                elements, data_len
            );
        }
        if crate::tensors::dimensions::has_duplicates(shape) {
            panic!("Dimension names must all be unique: {:?}", &shape);
        }
        if shape.iter().any(|d| d.1 == 0) {
            panic!("No dimension can have 0 elements: {:?}", &shape);
        }
    }
    // Non panicking variant of validate_dimensions_or_panic: true only when
    // the shape's element count matches data_len, all names are unique, and
    // no dimension has 0 elements.
    fn validate_dimensions(shape: &[(Dimension, usize); D], data_len: usize) -> bool {
        let elements = crate::tensors::dimensions::elements(shape);
        data_len == elements
            && !crate::tensors::dimensions::has_duplicates(shape)
            && !shape.iter().any(|d| d.1 == 0)
    }
}
impl<const D: usize> fmt::Display for InvalidShapeError<D> {
    /// Formats a human readable description of the shape requirements along
    /// with the offending shape.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Dimensions must all be at least length 1 with unique names: {:?}",
            self.shape
        )
    }
}
// Marker impl: InvalidShapeError is a std Error via its Display/Debug impls.
impl<const D: usize> Error for InvalidShapeError<D> {}
/// An error indicating that the provided list of dimension names (length P)
/// did not match the set of dimension names that were valid (length D) in
/// some context.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InvalidDimensionsError<const D: usize, const P: usize> {
    valid: [Dimension; D],
    provided: [Dimension; P],
}
impl<const D: usize, const P: usize> InvalidDimensionsError<D, P> {
    /// Constructs an InvalidDimensionsError from the dimension names that
    /// were provided and the names that would have been valid.
    pub fn new(provided: [Dimension; P], valid: [Dimension; D]) -> InvalidDimensionsError<D, P> {
        InvalidDimensionsError { valid, provided }
    }
    /// True if the provided dimension names contain any duplicates.
    pub fn has_duplicates(&self) -> bool {
        crate::tensors::dimensions::has_duplicates_names(&self.provided)
    }
    /// Returns a copy of the dimension names that were provided.
    pub fn provided_names(&self) -> [Dimension; P] {
        self.provided
    }
    /// Returns a reference to the dimension names that were provided.
    pub fn provided_names_ref(&self) -> &[Dimension; P] {
        &self.provided
    }
    /// Returns a copy of the dimension names that would have been valid.
    pub fn valid_names(&self) -> [Dimension; D] {
        self.valid
    }
    /// Returns a reference to the dimension names that would have been valid.
    pub fn valid_names_ref(&self) -> &[Dimension; D] {
        &self.valid
    }
}
impl<const D: usize, const P: usize> fmt::Display for InvalidDimensionsError<D, P> {
    /// Describes the incorrect dimension names, listing the valid names too
    /// when at least one name was provided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if P == 0 {
            write!(f, "Dimensions names {:?} were incorrect", self.provided)
        } else {
            write!(
                f,
                "Dimensions names {:?} were incorrect, valid dimensions in this context are: {:?}",
                self.provided, self.valid
            )
        }
    }
}
// Marker impl: InvalidDimensionsError is a std Error via Display/Debug.
impl<const D: usize, const P: usize> Error for InvalidDimensionsError<D, P> {}
// Compile-time check that the error types are Sync.
#[test]
fn test_sync() {
    fn assert_sync<T: Sync>() {}
    assert_sync::<InvalidShapeError<2>>();
    assert_sync::<InvalidDimensionsError<2, 2>>();
}
// Compile-time check that the error types are Send.
#[test]
fn test_send() {
    fn assert_send<T: Send>() {}
    assert_send::<InvalidShapeError<2>>();
    assert_send::<InvalidDimensionsError<2, 2>>();
}
/// A D dimensional tensor: named dimensions of known lengths over a linear
/// data buffer. The strides cache how many elements to step over per
/// dimension; they are derived from the shape and therefore skipped when
/// serializing.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct Tensor<T, const D: usize> {
    // Elements stored linearly; the length always equals the product of the
    // dimension lengths in shape.
    data: Vec<T>,
    #[cfg_attr(feature = "serde", serde(with = "serde_arrays"))]
    shape: [(Dimension, usize); D],
    // Derived from shape via compute_strides; recomputed on deserialize.
    #[cfg_attr(feature = "serde", serde(skip))]
    strides: [usize; D],
}
impl<T, const D: usize> Tensor<T, D> {
    /// Creates a Tensor with the given shape and elements.
    ///
    /// # Panics
    ///
    /// Panics if the product of the dimension lengths does not equal the
    /// data length, if any dimension name is repeated, or if any dimension
    /// has 0 elements.
    #[track_caller]
    pub fn from(shape: [(Dimension, usize); D], data: Vec<T>) -> Self {
        InvalidShapeError::validate_dimensions_or_panic(&shape, data.len());
        let strides = compute_strides(&shape);
        Tensor {
            data,
            shape,
            strides,
        }
    }
    /// Creates a Tensor with the given shape, calling the producer function
    /// once per index in the shape to fill the data.
    ///
    /// # Panics
    ///
    /// Panics if the shape has duplicate names or any dimension of 0 elements.
    #[track_caller]
    pub fn from_fn<F>(shape: [(Dimension, usize); D], mut producer: F) -> Self
    where
        F: FnMut([usize; D]) -> T,
    {
        let length = dimensions::elements(&shape);
        let mut data = Vec::with_capacity(length);
        let iterator = ShapeIterator::from(shape);
        for index in iterator {
            data.push(producer(index));
        }
        Tensor::from(shape, data)
    }
    /// Returns a copy of this tensor's shape.
    pub fn shape(&self) -> [(Dimension, usize); D] {
        self.shape
    }
    /// Returns the length of the named dimension, or None if it is not in
    /// this tensor's shape.
    pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::length_of(&self.shape, dimension)
    }
    /// Returns the last valid index of the named dimension, or None if it is
    /// not in this tensor's shape.
    pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::last_index_of(&self.shape, dimension)
    }
    /// Non panicking version of [from](Tensor::from); returns an
    /// InvalidShapeError instead if the shape or data length are invalid.
    pub fn try_from(
        shape: [(Dimension, usize); D],
        data: Vec<T>,
    ) -> Result<Self, InvalidShapeError<D>> {
        let valid = InvalidShapeError::validate_dimensions(&shape, data.len());
        if !valid {
            return Err(InvalidShapeError::new(shape));
        }
        let strides = compute_strides(&shape);
        Ok(Tensor {
            data,
            shape,
            strides,
        })
    }
    // Constructs a Tensor from already checked parts without revalidating:
    // the caller must provide a valid shape, matching data length, and
    // strides consistent with the shape.
    pub(crate) fn direct_from(
        data: Vec<T>,
        shape: [(Dimension, usize); D],
        strides: [usize; D],
    ) -> Self {
        Tensor {
            data,
            shape,
            strides,
        }
    }
    // Constructs a Tensor reusing this tensor's shape and strides with new
    // data; the data length is assumed to match and is not validated.
    #[allow(dead_code)] pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> Self {
        Tensor {
            data,
            shape: self.shape,
            strides: self.strides,
        }
    }
}
impl<T> Tensor<T, 0> {
    /// Creates a 0 dimensional tensor holding a single scalar value.
    pub fn from_scalar(value: T) -> Tensor<T, 0> {
        Tensor {
            data: vec![value],
            shape: [],
            strides: [],
        }
    }
    /// Consumes this 0 dimensional tensor, returning its sole element.
    pub fn into_scalar(self) -> T {
        let mut elements = self.data.into_iter();
        match elements.next() {
            Some(scalar) => scalar,
            None => panic!("Tensors always have at least 1 element"),
        }
    }
}
impl<T> Tensor<T, 0>
where
    T: Clone,
{
    /// Returns a copy of the sole element in this 0 dimensional tensor.
    pub fn scalar(&self) -> T {
        match self.data.first() {
            Some(scalar) => scalar.clone(),
            None => panic!("Tensors always have at least 1 element"),
        }
    }
}
/// Any type can be converted into a 0 dimensional tensor holding it.
impl<T> From<T> for Tensor<T, 0> {
    fn from(scalar: T) -> Tensor<T, 0> {
        Tensor::from_scalar(scalar)
    }
}
// NOTE(review): TensorRef is an unsafe trait; Tensor's shape and strides are
// validated/derived at construction, so view_shape always matches the data.
unsafe impl<T, const D: usize> TensorRef<T, D> for Tensor<T, D> {
    // Returns a reference to the element at the index, or None when any
    // index is out of range of the corresponding dimension length.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
        self.data.get(i)
    }
    fn view_shape(&self) -> [(Dimension, usize); D] {
        Tensor::shape(self)
    }
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            // SAFETY: the caller guarantees the indexes are in range of the
            // shape, so get_index_direct returns Some and the computed offset
            // is within the data buffer.
            let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
            self.data.get_unchecked(i)
        }
    }
    fn data_layout(&self) -> DataLayout<D> {
        // The buffer is linear in this tensor's own dimension order.
        DataLayout::Linear(std::array::from_fn(|i| self.shape[i].0))
    }
}
unsafe impl<T, const D: usize> TensorMut<T, D> for Tensor<T, D> {
    // Mutable counterpart of get_reference; None when any index is out of
    // range of the corresponding dimension length.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
        self.data.get_mut(i)
    }
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            // SAFETY: the caller guarantees the indexes are in range of the
            // shape, so get_index_direct returns Some and the computed offset
            // is within the data buffer.
            let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
            self.data.get_unchecked_mut(i)
        }
    }
}
impl<T: Clone, const D: usize> Clone for Tensor<T, D> {
fn clone(&self) -> Self {
self.map(|element| element)
}
}
/// Formats the tensor for display, delegating to the shared view formatter.
impl<T: std::fmt::Display, const D: usize> std::fmt::Display for Tensor<T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        crate::tensors::display::format_view(self, f)
    }
}
/// Converts a 2 dimensional tensor into a Matrix, reusing its data buffer.
/// The first dimension becomes the rows and the second the columns.
impl<T> From<Tensor<T, 2>> for crate::matrices::Matrix<T> {
    fn from(tensor: Tensor<T, 2>) -> Self {
        crate::matrices::Matrix::from_flat_row_major(
            (tensor.shape[0].1, tensor.shape[1].1),
            tensor.data,
        )
    }
}
// Computes the stride of each dimension for a linear (row major) buffer:
// dimension d's stride is the product of the lengths of all later dimensions,
// so the final dimension always has a stride of 1.
pub(crate) fn compute_strides<const D: usize>(shape: &[(Dimension, usize); D]) -> [usize; D] {
    std::array::from_fn(|dimension| {
        shape
            .iter()
            .skip(dimension + 1)
            .map(|&(_, length)| length)
            .product()
    })
}
// Maps a D dimensional index to an offset into the linear data buffer using
// the strides, returning None if any index is out of range of its dimension.
#[inline]
pub(crate) fn get_index_direct<const D: usize>(
    indexes: &[usize; D],
    strides: &[usize; D],
    shape: &[(Dimension, usize); D],
) -> Option<usize> {
    let mut offset = 0;
    for ((&index, &stride), &(_, length)) in indexes.iter().zip(strides.iter()).zip(shape.iter()) {
        if index >= length {
            return None;
        }
        offset += index * stride;
    }
    Some(offset)
}
#[inline]
fn get_index_direct_unchecked<const D: usize>(
indexes: &[usize; D],
strides: &[usize; D],
) -> usize {
let mut index = 0;
for d in 0..D {
let n = indexes[d];
index += n * strides[d];
}
index
}
impl<T, const D: usize> Tensor<T, D> {
    /// Returns a TensorView referencing this Tensor.
    pub fn view(&self) -> TensorView<T, &Tensor<T, D>, D> {
        TensorView::from(self)
    }
    /// Returns a TensorView mutably referencing this Tensor.
    pub fn view_mut(&mut self) -> TensorView<T, &mut Tensor<T, D>, D> {
        TensorView::from(self)
    }
    /// Returns a TensorView that owns this Tensor.
    pub fn view_owned(self) -> TensorView<T, Tensor<T, D>, D> {
        TensorView::from(self)
    }
    /// Returns a TensorAccess which indexes in the provided dimension order.
    ///
    /// # Panics
    ///
    /// Panics if the dimensions are not the same set of names as this
    /// tensor's shape (see TensorAccess::from).
    #[track_caller]
    pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }
    /// Mutable reference version of [index_by](Tensor::index_by).
    ///
    /// # Panics
    ///
    /// Panics if the dimensions are not the same set of names as this
    /// tensor's shape.
    #[track_caller]
    pub fn index_by_mut(
        &mut self,
        dimensions: [Dimension; D],
    ) -> TensorAccess<T, &mut Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }
    /// Owned version of [index_by](Tensor::index_by).
    ///
    /// # Panics
    ///
    /// Panics if the dimensions are not the same set of names as this
    /// tensor's shape.
    #[track_caller]
    pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }
    /// Returns a TensorAccess which indexes in this tensor's own dimension
    /// order; cannot fail.
    pub fn index(&self) -> TensorAccess<T, &Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }
    /// Mutable reference version of [index](Tensor::index).
    pub fn index_mut(&mut self) -> TensorAccess<T, &mut Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }
    /// Owned version of [index](Tensor::index).
    pub fn index_owned(self) -> TensorAccess<T, Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }
    /// Returns an iterator over references to the elements of this Tensor.
    pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, Tensor<T, D>, D> {
        TensorReferenceIterator::from(self)
    }
    /// Returns an iterator over mutable references to the elements of this
    /// Tensor.
    pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<'_, T, Tensor<T, D>, D> {
        TensorReferenceMutIterator::from(self)
    }
    /// Consumes this Tensor, returning an iterator over its elements by
    /// value. T: Default is required so values can be moved out.
    pub fn iter_owned(self) -> TensorOwnedIterator<T, Tensor<T, D>, D>
    where
        T: Default,
    {
        TensorOwnedIterator::from(self)
    }
    // Iterates the linear data buffer directly, which for a Tensor matches
    // its index order; cheaper than TensorReferenceIterator.
    pub(crate) fn direct_iter_reference(&self) -> std::slice::Iter<'_, T> {
        self.data.iter()
    }
    // Mutable version of direct_iter_reference.
    pub(crate) fn direct_iter_reference_mut(&mut self) -> std::slice::IterMut<'_, T> {
        self.data.iter_mut()
    }
    /// Renames the dimensions of this Tensor in place without changing the
    /// data or lengths.
    ///
    /// # Panics
    ///
    /// Panics if any dimension name is repeated.
    #[track_caller]
    pub fn rename(&mut self, dimensions: [Dimension; D]) {
        if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
            panic!("Dimension names must all be unique: {:?}", &dimensions);
        }
        #[allow(clippy::needless_range_loop)]
        for d in 0..D {
            self.shape[d].0 = dimensions[d];
        }
    }
    /// Owned version of [rename](Tensor::rename).
    ///
    /// # Panics
    ///
    /// Panics if any dimension name is repeated.
    #[track_caller]
    pub fn rename_owned(mut self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        self.rename(dimensions);
        self
    }
    /// Returns a TensorView with the dimensions renamed, without modifying
    /// this Tensor.
    ///
    /// # Panics
    ///
    /// Panics if any dimension name is repeated (see TensorRename::from).
    #[track_caller]
    pub fn rename_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorRename<T, &Tensor<T, D>, D>, D> {
        TensorView::from(TensorRename::from(self, dimensions))
    }
    /// Changes the shape of this Tensor in place without changing its data.
    ///
    /// # Panics
    ///
    /// Panics if the new shape's number of elements does not match the data
    /// length, or the shape has duplicate names or 0 length dimensions.
    #[track_caller]
    pub fn reshape_mut(&mut self, shape: [(Dimension, usize); D]) {
        InvalidShapeError::validate_dimensions_or_panic(&shape, self.data.len());
        let strides = compute_strides(&shape);
        self.shape = shape;
        self.strides = strides;
    }
    /// Consumes this Tensor, returning one of possibly different
    /// dimensionality D2 with the same data.
    ///
    /// # Panics
    ///
    /// Panics if the new shape's number of elements does not match the data
    /// length, or the shape has duplicate names or 0 length dimensions.
    #[track_caller]
    pub fn reshape_owned<const D2: usize>(self, shape: [(Dimension, usize); D2]) -> Tensor<T, D2> {
        Tensor::from(shape, self.data)
    }
    /// Returns a TensorView of possibly different dimensionality over this
    /// tensor's data.
    pub fn reshape_view<const D2: usize>(
        &self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self, shape))
    }
    /// Mutable reference version of [reshape_view](Tensor::reshape_view).
    pub fn reshape_view_mut<const D2: usize>(
        &mut self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self, shape))
    }
    /// Owned version of [reshape_view](Tensor::reshape_view).
    pub fn reshape_view_owned<const D2: usize>(
        self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self, shape))
    }
    /// Returns a 1 dimensional TensorView over all of this tensor's data,
    /// using the provided name for the single dimension.
    pub fn flatten_view(
        &self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, 1>, 1> {
        self.reshape_view([(dimension, dimensions::elements(&self.shape))])
    }
    /// Mutable reference version of [flatten_view](Tensor::flatten_view).
    pub fn flatten_view_mut(
        &mut self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, 1>, 1> {
        self.reshape_view_mut([(dimension, dimensions::elements(&self.shape))])
    }
    /// Owned version of [flatten_view](Tensor::flatten_view).
    pub fn flatten_view_owned(
        self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, 1>, 1> {
        // Compute the length before moving self into reshape_view_owned.
        let length = dimensions::elements(&self.shape);
        self.reshape_view_owned([(dimension, length)])
    }
    /// Consumes this Tensor, returning a 1 dimensional Tensor over all of
    /// its data, using the provided name for the single dimension.
    pub fn flatten(self, dimension: Dimension) -> Tensor<T, 1> {
        let length = dimensions::elements(&self.shape);
        self.reshape_owned([(dimension, length)])
    }
    /// Returns a TensorView with the named dimensions restricted to the
    /// provided ranges, or an error if any range is invalid for this shape.
    pub fn range<R, const P: usize>(
        &self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(self, ranges).map(|range| TensorView::from(range))
    }
    /// Mutable reference version of [range](Tensor::range).
    pub fn range_mut<R, const P: usize>(
        &mut self,
        ranges: [(Dimension, R); P],
    ) -> Result<
        TensorView<T, TensorRange<T, &mut Tensor<T, D>, D>, D>,
        IndexRangeValidationError<D, P>,
    >
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(self, ranges).map(|range| TensorView::from(range))
    }
    /// Owned version of [range](Tensor::range).
    pub fn range_owned<R, const P: usize>(
        self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(self, ranges).map(|range| TensorView::from(range))
    }
    /// Returns a TensorView with the provided ranges of the named dimensions
    /// hidden (masked out), or an error if any mask is invalid for this shape.
    pub fn mask<R, const P: usize>(
        &self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
    }
    /// Mutable reference version of [mask](Tensor::mask).
    pub fn mask_mut<R, const P: usize>(
        &mut self,
        masks: [(Dimension, R); P],
    ) -> Result<
        TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D>,
        IndexRangeValidationError<D, P>,
    >
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
    }
    /// Owned version of [mask](Tensor::mask).
    pub fn mask_owned<R, const P: usize>(
        self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
    }
    /// Returns a TensorView keeping only the start and end of the named
    /// dimension, masking out the middle.
    ///
    /// # Panics
    ///
    /// Panics if the mask is invalid for this shape (see
    /// TensorMask::panicking_start_and_end_of).
    #[track_caller]
    pub fn start_and_end_of(
        &self,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D> {
        TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
    }
    /// Mutable reference version of [start_and_end_of](Tensor::start_and_end_of).
    ///
    /// # Panics
    ///
    /// Panics if the mask is invalid for this shape.
    #[track_caller]
    pub fn start_and_end_of_mut(
        &mut self,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D> {
        TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
    }
    /// Owned version of [start_and_end_of](Tensor::start_and_end_of).
    ///
    /// # Panics
    ///
    /// Panics if the mask is invalid for this shape.
    #[track_caller]
    pub fn start_and_end_of_owned(
        self,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, Tensor<T, D>, D>, D> {
        TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
    }
    /// Returns a TensorView with the order of the elements in the named
    /// dimensions reversed.
    ///
    /// # Panics
    ///
    /// Panics if any dimension name is invalid (see TensorReverse::from).
    #[track_caller]
    pub fn reverse(
        &self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, &Tensor<T, D>, D>, D> {
        TensorView::from(TensorReverse::from(self, dimensions))
    }
    /// Mutable reference version of [reverse](Tensor::reverse).
    ///
    /// # Panics
    ///
    /// Panics if any dimension name is invalid.
    #[track_caller]
    pub fn reverse_mut(
        &mut self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, &mut Tensor<T, D>, D>, D> {
        TensorView::from(TensorReverse::from(self, dimensions))
    }
    /// Owned version of [reverse](Tensor::reverse).
    ///
    /// # Panics
    ///
    /// Panics if any dimension name is invalid.
    #[track_caller]
    pub fn reverse_owned(
        self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, Tensor<T, D>, D>, D> {
        TensorView::from(TensorReverse::from(self, dimensions))
    }
    /// Creates a new Tensor by combining each pair of elements from this
    /// tensor and rhs with the mapping function (by reference).
    ///
    /// # Panics
    ///
    /// Panics if the two tensors do not have the same shape.
    #[track_caller]
    pub fn elementwise_reference<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S, D>>,
        S: TensorRef<T, D>,
        M: Fn(&T, &T) -> T,
    {
        self.elementwise_reference_less_generic(rhs.into(), mapping_function)
    }
    /// Version of [elementwise_reference](Tensor::elementwise_reference)
    /// which also passes each element's index to the mapping function.
    ///
    /// # Panics
    ///
    /// Panics if the two tensors do not have the same shape.
    #[track_caller]
    pub fn elementwise_reference_with_index<S, I, M>(
        &self,
        rhs: I,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S, D>>,
        S: TensorRef<T, D>,
        M: Fn([usize; D], &T, &T) -> T,
    {
        self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
    }
    // Monomorphisation-friendly implementation taking the TensorView
    // directly; shapes must match exactly (names, order, and lengths).
    #[track_caller]
    fn elementwise_reference_less_generic<S, M>(
        &self,
        rhs: TensorView<T, S, D>,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        S: TensorRef<T, D>,
        M: Fn(&T, &T) -> T,
    {
        let left_shape = self.shape();
        let right_shape = rhs.shape();
        if left_shape != right_shape {
            panic!(
                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
                left_shape, right_shape
            );
        }
        // Since the shapes match, iterating our linear buffer pairs up with
        // iterating rhs in its index order.
        let mapped = self
            .direct_iter_reference()
            .zip(rhs.iter_reference())
            .map(|(x, y)| mapping_function(x, y))
            .collect();
        Tensor::direct_from(mapped, self.shape, self.strides)
    }
    // As elementwise_reference_less_generic, but also passing each index to
    // the mapping function.
    #[track_caller]
    fn elementwise_reference_less_generic_with_index<S, M>(
        &self,
        rhs: TensorView<T, S, D>,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        S: TensorRef<T, D>,
        M: Fn([usize; D], &T, &T) -> T,
    {
        let left_shape = self.shape();
        let right_shape = rhs.shape();
        if left_shape != right_shape {
            panic!(
                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
                left_shape, right_shape
            );
        }
        let mapped = self
            .direct_iter_reference()
            .zip(rhs.iter_reference().with_index())
            .map(|(x, (i, y))| mapping_function(i, x, y))
            .collect();
        Tensor::direct_from(mapped, self.shape, self.strides)
    }
    /// Returns a TensorView which transposes the data: accessed in the
    /// provided dimension order but keeping the dimension names in place.
    pub fn transpose_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorTranspose<T, &Tensor<T, D>, D>, D> {
        TensorView::from(TensorTranspose::from(self, dimensions))
    }
}
impl<T, const D: usize> Tensor<T, D>
where
    T: Clone,
{
    /// Creates a tensor of the given shape where every element is a clone of
    /// value.
    ///
    /// # Panics
    ///
    /// Panics if the shape has duplicate names or any dimension of 0 elements.
    #[track_caller]
    pub fn empty(shape: [(Dimension, usize); D], value: T) -> Self {
        let elements = crate::tensors::dimensions::elements(&shape);
        Tensor::from(shape, vec![value; elements])
    }
    /// Returns a copy of the first element in this tensor's data.
    pub fn first(&self) -> T {
        self.data
            .first()
            .expect("Tensors always have at least 1 element")
            .clone()
    }
    /// Returns a new Tensor with the data moved as per [reorder](Tensor::reorder)
    /// but the original dimension names kept in their original positions — a
    /// transposition of the data rather than a relabelling.
    ///
    /// # Panics
    ///
    /// Panics if the dimensions are not the same set of names as this
    /// tensor's shape.
    #[track_caller]
    pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        let shape = self.shape;
        let mut reordered = self.reorder(dimensions);
        // Restore the original dimension name ordering so only the data moved.
        #[allow(clippy::needless_range_loop)]
        for d in 0..D {
            reordered.shape[d].0 = shape[d].0;
        }
        reordered
    }
    /// In place version of [transpose](Tensor::transpose).
    ///
    /// # Panics
    ///
    /// Panics if the dimensions are not the same set of names as this
    /// tensor's shape.
    #[track_caller]
    pub fn transpose_mut(&mut self, dimensions: [Dimension; D]) {
        let shape = self.shape;
        self.reorder_mut(dimensions);
        // Restore the original dimension name ordering so only the data moved.
        #[allow(clippy::needless_range_loop)]
        for d in 0..D {
            self.shape[d].0 = shape[d].0;
        }
    }
    /// Returns a new Tensor with the dimensions (names and data) reordered to
    /// the provided order.
    ///
    /// # Panics
    ///
    /// Panics if the dimensions are not the same set of names as this
    /// tensor's shape.
    #[track_caller]
    pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        let reorderd = match TensorAccess::try_from(&self, dimensions) {
            Ok(reordered) => reordered,
            Err(_error) => panic!(
                "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
                dimensions, &self.shape,
            ),
        };
        let reorderd_shape = reorderd.shape();
        Tensor::from(reorderd_shape, reorderd.iter().collect())
    }
    /// In place version of [reorder](Tensor::reorder).
    ///
    /// # Panics
    ///
    /// Panics if the dimensions are not the same set of names as this
    /// tensor's shape.
    #[track_caller]
    pub fn reorder_mut(&mut self, dimensions: [Dimension; D]) {
        use crate::tensors::dimensions::DimensionMappings;
        // Fast path: a square 2D reorder can be done in place by swapping
        // element pairs across the diagonal without allocating a copy.
        if D == 2 && crate::tensors::dimensions::is_square(&self.shape) {
            let dimension_mapping = match DimensionMappings::new(&self.shape, &dimensions) {
                Some(dimension_mapping) => dimension_mapping,
                None => panic!(
                    "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
                    dimensions, &self.shape,
                ),
            };
            let shape = dimension_mapping.map_shape_to_requested(&self.shape);
            let shape_iterator = ShapeIterator::from(shape);
            for index in shape_iterator {
                let i = index[0];
                let j = index[1];
                // Only visit the upper triangle (including the diagonal) so
                // each pair is swapped exactly once.
                if j >= i {
                    let mapped_index = dimension_mapping.map_dimensions_to_source(&index);
                    let temp = self.get_reference(index).unwrap().clone();
                    *self.get_reference_mut(index).unwrap() =
                        self.get_reference(mapped_index).unwrap().clone();
                    *self.get_reference_mut(mapped_index).unwrap() = temp;
                }
            }
            self.shape = shape;
            self.strides = compute_strides(&shape);
        } else {
            // General case: build the reordered tensor then move its parts in.
            let reordered = self.reorder(dimensions);
            self.data = reordered.data;
            self.shape = reordered.shape;
            self.strides = reordered.strides;
        }
    }
    /// Returns an iterator over copies of the elements of this Tensor.
    pub fn iter(&self) -> TensorIterator<'_, T, Tensor<T, D>, D> {
        TensorIterator::from(self)
    }
    /// Creates a new Tensor by applying the mapping function to a clone of
    /// every element; shape and strides are unchanged.
    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
        let mapped = self
            .data
            .iter()
            .map(|x| mapping_function(x.clone()))
            .collect();
        Tensor::direct_from(mapped, self.shape, self.strides)
    }
    /// Version of [map](Tensor::map) which also passes each element's index
    /// to the mapping function.
    pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
        let mapped = self
            .iter()
            .with_index()
            .map(|(i, x)| mapping_function(i, x))
            .collect();
        Tensor::direct_from(mapped, self.shape, self.strides)
    }
    /// Applies the mapping function to a clone of every element, overwriting
    /// each element in place.
    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
        for value in self.data.iter_mut() {
            *value = mapping_function(value.clone());
        }
    }
    /// Version of [map_mut](Tensor::map_mut) which also passes each element's
    /// index to the mapping function.
    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
        self.iter_reference_mut()
            .with_index()
            .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
    }
    /// Creates a new Tensor by combining each pair of elements from this
    /// tensor and rhs with the mapping function (by value, cloning each pair).
    ///
    /// # Panics
    ///
    /// Panics if the two tensors do not have the same shape.
    #[track_caller]
    pub fn elementwise<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S, D>>,
        S: TensorRef<T, D>,
        M: Fn(T, T) -> T,
    {
        self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
            mapping_function(lhs.clone(), rhs.clone())
        })
    }
    /// Version of [elementwise](Tensor::elementwise) which also passes each
    /// element's index to the mapping function.
    ///
    /// # Panics
    ///
    /// Panics if the two tensors do not have the same shape.
    #[track_caller]
    pub fn elementwise_with_index<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S, D>>,
        S: TensorRef<T, D>,
        M: Fn([usize; D], T, T) -> T,
    {
        self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
            mapping_function(i, lhs.clone(), rhs.clone())
        })
    }
}
impl<T> Tensor<T, 1>
where
    T: Numeric,
    for<'a> &'a T: NumericRef<T>,
{
    /// Computes the scalar (dot) product of this 1 dimensional tensor and
    /// rhs, delegating to the less generic implementation.
    pub fn scalar_product<S, I>(&self, rhs: I) -> T
    where
        I: Into<TensorView<T, S, 1>>,
        S: TensorRef<T, 1>,
    {
        self.scalar_product_less_generic(rhs.into())
    }
}
impl<T> Tensor<T, 2>
where
    T: Numeric,
    for<'a> &'a T: NumericRef<T>,
{
    /// Computes the determinant of this 2 dimensional tensor; delegates to
    /// [linear_algebra] — see that module for when None is returned.
    pub fn determinant(&self) -> Option<T> {
        linear_algebra::determinant_tensor::<T, _, _>(self)
    }
    /// Computes the inverse of this 2 dimensional tensor; delegates to
    /// [linear_algebra] — see that module for when None is returned.
    pub fn inverse(&self) -> Option<Tensor<T, 2>> {
        linear_algebra::inverse_tensor::<T, _, _>(self)
    }
    /// Computes the covariance matrix over the named feature dimension;
    /// delegates to [linear_algebra].
    pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
        linear_algebra::covariance::<T, _, _>(self, feature_dimension)
    }
}
impl<T> Tensor<T, 2>
where
    T: Numeric,
{
    /// Creates a square tensor of the given shape with value along the main
    /// diagonal and zeros everywhere else.
    ///
    /// # Panics
    ///
    /// Panics if the shape is not square (both dimensions the same length).
    #[track_caller]
    pub fn diagonal(shape: [(Dimension, usize); 2], value: T) -> Tensor<T, 2> {
        if !crate::tensors::dimensions::is_square(&shape) {
            panic!("Shape must be square: {:?}", shape);
        }
        let mut tensor = Tensor::empty(shape, T::zero());
        for ([r, c], x) in tensor.iter_reference_mut().with_index() {
            if r == c {
                *x = value.clone();
            }
        }
        tensor
    }
}
impl<T> Tensor<T, 2> {
    /// Converts this 2 dimensional tensor into a Matrix with the same data,
    /// using the `From<Tensor<T, 2>>` impl for Matrix.
    pub fn into_matrix(self) -> crate::matrices::Matrix<T> {
        crate::matrices::Matrix::from(self)
    }
}
impl<T: Real> Tensor<T, 1>
where
    for<'a> &'a T: RealRef<T>,
{
    /// Computes the length (magnitude) of this 1 dimensional tensor: the
    /// square root of the sum of the squares of its elements.
    pub fn euclidean_length(&self) -> T {
        let sum_of_squares: T = self.direct_iter_reference().map(|x| x * x).sum();
        sum_of_squares.sqrt()
    }
}
#[cfg(feature = "serde")]
mod serde_impls {
    use crate::tensors::{Dimension, InvalidShapeError, Tensor};
    use serde::Deserialize;
    use std::convert::TryFrom;
    /// Deserialization helper mirroring Tensor's serialized form but with
    /// borrowed (non 'static) dimension name strings; convert to a real
    /// Tensor via into_tensor or TryFrom.
    #[derive(Deserialize)]
    #[serde(rename = "Tensor")]
    pub struct TensorDeserialize<'a, T, const D: usize> {
        data: Vec<T>,
        #[serde(with = "serde_arrays")]
        #[serde(borrow)]
        shape: [(&'a str, usize); D],
    }
    impl<'a, T, const D: usize> TensorDeserialize<'a, T, D> {
        /// Builds a Tensor using the caller supplied &'static str dimension
        /// names paired with the deserialized lengths, validating the shape
        /// against the data length.
        pub fn into_tensor(
            self,
            dimensions: [Dimension; D],
        ) -> Result<Tensor<T, D>, InvalidShapeError<D>> {
            let shape = std::array::from_fn(|d| (dimensions[d], self.shape[d].1));
            Tensor::try_from(shape, self.data)
        }
    }
    // When the deserialized names are already 'static (e.g. from a static
    // string) the shape can be used directly.
    impl<T, const D: usize> TryFrom<TensorDeserialize<'static, T, D>> for Tensor<T, D> {
        type Error = InvalidShapeError<D>;
        fn try_from(value: TensorDeserialize<'static, T, D>) -> Result<Self, Self::Error> {
            Tensor::try_from(value.shape, value.data)
        }
    }
}
// Compile-time check that Tensor is Serialize at several dimensionalities.
#[cfg(feature = "serde")]
#[test]
fn test_serialize() {
    fn assert_serialize<T: Serialize>() {}
    assert_serialize::<Tensor<f64, 3>>();
    assert_serialize::<Tensor<f64, 2>>();
    assert_serialize::<Tensor<f64, 1>>();
    assert_serialize::<Tensor<f64, 0>>();
}
// Compile-time check that TensorDeserialize is Deserialize at several
// dimensionalities.
#[cfg(feature = "serde")]
#[test]
fn test_deserialize() {
    use serde::Deserialize;
    fn assert_deserialize<'de, T: Deserialize<'de>>() {}
    assert_deserialize::<TensorDeserialize<f64, 3>>();
    assert_deserialize::<TensorDeserialize<f64, 2>>();
    assert_deserialize::<TensorDeserialize<f64, 1>>();
    assert_deserialize::<TensorDeserialize<f64, 0>>();
}
// Round trips a tensor through TOML serialization, checking both the exact
// encoded text and that deserializing it back yields an equal tensor.
#[cfg(feature = "serde")]
#[test]
fn test_serialization_deserialization_loop() {
    #[rustfmt::skip]
    let tensor = Tensor::from(
        [("rows", 3), ("columns", 4)],
        vec![
            1, 2, 3, 4,
            5, 6, 7, 8,
            9, 10, 11, 12
        ],
    );
    let encoded = toml::to_string(&tensor).unwrap();
    assert_eq!(
        encoded,
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
shape = [["rows", 3], ["columns", 4]]
"#,
    );
    let parsed: Result<TensorDeserialize<i32, 2>, _> = toml::from_str(&encoded);
    assert!(parsed.is_ok());
    let result = parsed.unwrap().into_tensor(["rows", "columns"]);
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), tensor);
}
// Checks that deserialized input with a shape not matching the data length
// (4x4 declared for 12 elements) is rejected by into_tensor's validation.
#[cfg(feature = "serde")]
#[test]
fn test_deserialization_validation() {
    let parsed: Result<TensorDeserialize<i32, 2>, _> = toml::from_str(
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
shape = [["rows", 4], ["columns", 4]]
"#,
    );
    assert!(parsed.is_ok());
    let result = parsed.unwrap().into_tensor(["rows", "columns"]);
    assert!(result.is_err());
}
// Generates select/select_mut/select_owned for a Tensor of dimensionality $d:
// fixing one named dimension to one index yields a view of dimensionality
// $d - 1. A macro is needed because { $d - 1 } must be a const expression.
macro_rules! tensor_select_impl {
    (impl Tensor $d:literal 1) => {
        impl<T> Tensor<T, $d> {
            /// Selects the provided dimension name and index, returning a
            /// TensorView with the selected dimension removed.
            ///
            /// # Panics
            ///
            /// Panics if the dimension name or index is invalid for this
            /// tensor's shape (see TensorIndex::from).
            #[track_caller]
            pub fn select(
                &self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, &Tensor<T, $d>, $d, 1>, { $d - 1 }> {
                TensorView::from(TensorIndex::from(self, provided_indexes))
            }
            /// Mutable reference version of select.
            ///
            /// # Panics
            ///
            /// Panics if the dimension name or index is invalid for this
            /// tensor's shape.
            #[track_caller]
            pub fn select_mut(
                &mut self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, &mut Tensor<T, $d>, $d, 1>, { $d - 1 }> {
                TensorView::from(TensorIndex::from(self, provided_indexes))
            }
            /// Owned version of select.
            ///
            /// # Panics
            ///
            /// Panics if the dimension name or index is invalid for this
            /// tensor's shape.
            #[track_caller]
            pub fn select_owned(
                self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, Tensor<T, $d>, $d, 1>, { $d - 1 }> {
                TensorView::from(TensorIndex::from(self, provided_indexes))
            }
        }
    };
}
// Selection is defined for dimensionalities 1 through 6 (0 would have no
// dimension to select).
tensor_select_impl!(impl Tensor 6 1);
tensor_select_impl!(impl Tensor 5 1);
tensor_select_impl!(impl Tensor 4 1);
tensor_select_impl!(impl Tensor 3 1);
tensor_select_impl!(impl Tensor 2 1);
tensor_select_impl!(impl Tensor 1 1);
// Generates expand/expand_mut/expand_owned for a Tensor of dimensionality $d:
// inserting one named dimension of length 1 at a position yields a view of
// dimensionality $d + 1. A macro is needed because { $d + 1 } must be a
// const expression.
macro_rules! tensor_expand_impl {
    (impl Tensor $d:literal 1) => {
        impl<T> Tensor<T, $d> {
            /// Inserts the named extra dimension at the given position,
            /// returning a TensorView with one more dimension.
            ///
            /// # Panics
            ///
            /// Panics if the insertion is invalid for this tensor's shape
            /// (see TensorExpansion::from).
            #[track_caller]
            pub fn expand(
                &self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, &Tensor<T, $d>, $d, 1>, { $d + 1 }> {
                TensorView::from(TensorExpansion::from(self, extra_dimension_names))
            }
            /// Mutable reference version of expand.
            ///
            /// # Panics
            ///
            /// Panics if the insertion is invalid for this tensor's shape.
            #[track_caller]
            pub fn expand_mut(
                &mut self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, &mut Tensor<T, $d>, $d, 1>, { $d + 1 }> {
                TensorView::from(TensorExpansion::from(self, extra_dimension_names))
            }
            /// Owned version of expand.
            ///
            /// # Panics
            ///
            /// Panics if the insertion is invalid for this tensor's shape.
            #[track_caller]
            pub fn expand_owned(
                self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, Tensor<T, $d>, $d, 1>, { $d + 1 }> {
                TensorView::from(TensorExpansion::from(self, extra_dimension_names))
            }
        }
    };
}
// Expansion is defined for dimensionalities 0 through 5.
tensor_expand_impl!(impl Tensor 0 1);
tensor_expand_impl!(impl Tensor 1 1);
tensor_expand_impl!(impl Tensor 2 1);
tensor_expand_impl!(impl Tensor 3 1);
tensor_expand_impl!(impl Tensor 4 1);
tensor_expand_impl!(impl Tensor 5 1);