use std::borrow::Cow;
use std::marker::PhantomData;
use std::ops::{Index, IndexMut, Range};
use crate::errors::{DimensionError, FromDataError, SliceError};
use crate::iterators::{
AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, BroadcastIter, InnerIter, InnerIterMut, Iter,
IterMut, Lanes, LanesMut, MutViewRef, ViewRef,
};
use crate::layout::{
AsIndex, BroadcastLayout, DynLayout, IntoLayout, Layout, MatrixLayout, MutLayout, NdLayout,
OverlapPolicy, ResizeLayout,
};
use crate::transpose::contiguous_data;
use crate::{IntoSliceItems, RandomSource, SliceItem};
/// The base type for multi-dimensional arrays.
///
/// `T` is the element type, `S` is the storage (eg. `Vec<T>` for owned
/// tensors, `&[T]` / `&mut [T]` for views — see the type aliases such as
/// `Tensor` and `TensorView` later in this file) and `L` is the layout that
/// maps N-dimensional indices to offsets within the storage.
#[derive(Debug)]
pub struct TensorBase<T, S: AsRef<[T]>, L: MutLayout> {
    /// Element storage.
    data: S,
    /// Maps indices to offsets in `data`.
    layout: L,
    /// Zero-sized marker tying the element type `T` to this struct.
    element_type: PhantomData<T>,
}
/// Trait of read-only operations shared by tensors and views.
///
/// Most methods have default implementations that delegate to a borrowed
/// view obtained via [`AsView::view`].
pub trait AsView: Layout {
    /// The element type of the tensor.
    type Elem;
    /// The layout type; its index type must match `Layout::Index`.
    type Layout: for<'a> MutLayout<Index<'a> = Self::Index<'a>>;
    /// Return a borrowed view of this tensor.
    fn view(&self) -> TensorBase<Self::Elem, &[Self::Elem], Self::Layout>;
    /// Return the tensor's layout.
    fn layout(&self) -> &Self::Layout;
    /// Return a view with a dynamic-rank layout.
    fn as_dyn(&self) -> TensorBase<Self::Elem, &[Self::Elem], DynLayout> {
        self.view().as_dyn()
    }
    /// Return an iterator over chunks of size `chunk_size` along `dim`.
    fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<Self::Elem, Self::Layout> {
        self.view().axis_chunks(dim, chunk_size)
    }
    /// Return an iterator over slices along `dim`.
    fn axis_iter(&self, dim: usize) -> AxisIter<Self::Elem, Self::Layout> {
        self.view().axis_iter(dim)
    }
    /// Return a view broadcast to `shape`.
    fn broadcast<S: IntoLayout>(&self, shape: S) -> TensorBase<Self::Elem, &[Self::Elem], S::Layout>
    where
        Self::Layout: BroadcastLayout<S::Layout>,
    {
        self.view().broadcast(shape)
    }
    /// Return an iterator over elements as if broadcast to `shape`.
    fn broadcast_iter(&self, shape: &[usize]) -> BroadcastIter<Self::Elem> {
        self.view().broadcast_iter(shape)
    }
    /// Return the underlying data as a slice, if the tensor is contiguous.
    fn data(&self) -> Option<&[Self::Elem]>;
    /// Return a reference to the element at `index`, or `None` if out of bounds.
    fn get<I: AsIndex<Self::Layout>>(&self, index: I) -> Option<&Self::Elem> {
        self.view().get(index)
    }
    /// Return a reference to the element at `index` without bounds checks.
    ///
    /// # Safety
    /// The caller must ensure `index` is valid for this tensor's shape.
    unsafe fn get_unchecked<I: AsIndex<Self::Layout>>(&self, index: I) -> &Self::Elem {
        self.view().get_unchecked(index)
    }
    /// Return an iterator over the innermost `N` dimensions.
    fn inner_iter<const N: usize>(&self) -> InnerIter<Self::Elem, Self::Layout, N> {
        self.view().inner_iter()
    }
    /// Insert a size-1 axis at position `index`.
    fn insert_axis(&mut self, index: usize)
    where
        Self::Layout: ResizeLayout;
    /// Return the single element if this tensor has exactly one, else `None`.
    fn item(&self) -> Option<&Self::Elem> {
        self.view().item()
    }
    /// Return an iterator over all elements in logical order.
    fn iter(&self) -> Iter<Self::Elem>;
    /// Return an iterator over 1D lanes along `dim`.
    fn lanes(&self, dim: usize) -> Lanes<Self::Elem> {
        self.view().lanes(dim)
    }
    /// Return a new tensor created by applying `f` to every element.
    fn map<F, U>(&self, f: F) -> TensorBase<U, Vec<U>, Self::Layout>
    where
        F: Fn(&Self::Elem) -> U,
    {
        self.view().map(f)
    }
    /// Move the axis at position `from` to position `to`.
    fn move_axis(&mut self, from: usize, to: usize);
    /// Return a static-rank (`N`-dimensional) view. Panics if `ndim() != N`.
    fn nd_view<const N: usize>(&self) -> TensorBase<Self::Elem, &[Self::Elem], NdLayout<N>> {
        self.view().nd_view()
    }
    /// Permute the axes in place according to `order`.
    fn permute(&mut self, order: Self::Index<'_>);
    /// Return a view with axes permuted according to `order`.
    fn permuted(
        &self,
        order: Self::Index<'_>,
    ) -> TensorBase<Self::Elem, &[Self::Elem], Self::Layout> {
        self.view().permuted(order)
    }
    /// Return a view with the layout changed to `shape`.
    fn reshaped<S: IntoLayout>(
        &self,
        shape: S,
    ) -> TensorBase<Self::Elem, &[Self::Elem], S::Layout> {
        self.view().reshaped(shape)
    }
    /// Reverse the order of the axes in place.
    fn transpose(&mut self);
    /// Return a view with the order of the axes reversed.
    fn transposed(&self) -> TensorBase<Self::Elem, &[Self::Elem], Self::Layout> {
        self.view().transposed()
    }
    /// Slice this tensor, returning a dynamic-rank view or an error if the
    /// range is invalid.
    fn try_slice_dyn<R: IntoSliceItems>(
        &self,
        range: R,
    ) -> Result<TensorView<Self::Elem>, SliceError> {
        self.view().try_slice_dyn(range)
    }
    /// Slice this tensor, returning an `M`-dimensional view.
    fn slice<const M: usize, R: IntoSliceItems>(&self, range: R) -> NdTensorView<Self::Elem, M> {
        self.view().slice(range)
    }
    /// Slice this tensor, returning a dynamic-rank view.
    fn slice_dyn<R: IntoSliceItems>(&self, range: R) -> TensorView<Self::Elem> {
        self.view().slice_dyn(range)
    }
    /// Return an iterator over the elements selected by `range`.
    fn slice_iter(&self, range: &[SliceItem]) -> Iter<Self::Elem> {
        self.view().slice_iter(range)
    }
    /// Return a view with all size-1 dimensions removed.
    fn squeezed(&self) -> TensorView<Self::Elem> {
        self.view().squeezed()
    }
    /// Return a copy of the elements as a contiguous vector.
    fn to_vec(&self) -> Vec<Self::Elem>
    where
        Self::Elem: Clone;
    /// Return a tensor with the same shape whose data is contiguous, borrowing
    /// if it already is and copying otherwise.
    fn to_contiguous(&self) -> TensorBase<Self::Elem, Cow<[Self::Elem]>, Self::Layout>
    where
        Self::Elem: Clone,
    {
        self.view().to_contiguous()
    }
    /// Return a copy of this tensor with the given shape.
    fn to_shape<S: IntoLayout>(
        &self,
        shape: S,
    ) -> TensorBase<Self::Elem, Vec<Self::Elem>, S::Layout>
    where
        Self::Elem: Clone;
    /// Return an owned copy of this tensor with the same shape.
    fn to_tensor(&self) -> TensorBase<Self::Elem, Vec<Self::Elem>, Self::Layout>
    where
        Self::Elem: Clone,
    {
        let data = self.to_vec();
        TensorBase::from_data(self.layout().shape(), data)
    }
    /// Return a view whose indexing skips per-dimension bounds checks.
    fn weakly_checked_view(&self) -> WeaklyCheckedView<Self::Elem, &[Self::Elem], Self::Layout> {
        self.view().weakly_checked_view()
    }
}
impl<T, S: AsRef<[T]>, L: MutLayout> TensorBase<T, S, L> {
    /// Construct a tensor from the given shape and storage.
    ///
    /// Panics if the storage length does not match the length implied by
    /// `shape`.
    pub fn from_data(shape: L::Index<'_>, data: S) -> TensorBase<T, S, L>
    where
        for<'a> L::Index<'a>: Clone,
    {
        // Capture the length before `data` is moved, for the panic message.
        let len = data.as_ref().len();
        Self::try_from_data(shape.clone(), data).unwrap_or_else(|_| {
            panic!(
                "data length {} does not match shape {:?}",
                len,
                shape.as_ref(),
            );
        })
    }
    /// Construct a tensor from the given shape and storage, or return an
    /// error if the storage length does not exactly match the shape.
    pub fn try_from_data(
        shape: L::Index<'_>,
        data: S,
    ) -> Result<TensorBase<T, S, L>, FromDataError> {
        let layout = L::from_shape(shape);
        if layout.min_data_len() != data.as_ref().len() {
            return Err(FromDataError::StorageLengthMismatch);
        }
        Ok(TensorBase {
            data,
            layout,
            element_type: PhantomData,
        })
    }
    /// Construct a tensor from a shape, storage and explicit strides.
    ///
    /// Fails if the strides would map distinct indices to the same element
    /// (`OverlapPolicy::DisallowOverlap`) or the storage is too short.
    pub fn from_data_with_strides(
        shape: L::Index<'_>,
        data: S,
        strides: L::Index<'_>,
    ) -> Result<TensorBase<T, S, L>, FromDataError> {
        let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::DisallowOverlap)?;
        if layout.min_data_len() > data.as_ref().len() {
            return Err(FromDataError::StorageTooShort);
        }
        Ok(TensorBase {
            data,
            layout,
            element_type: PhantomData,
        })
    }
    /// Convert this tensor into one with a dynamic-rank layout, keeping the
    /// same storage.
    pub fn into_dyn(self) -> TensorBase<T, S, DynLayout>
    where
        L: Into<DynLayout>,
    {
        TensorBase {
            data: self.data,
            layout: self.layout.into(),
            element_type: PhantomData,
        }
    }
    /// Return this tensor's layout as a static-rank `NdLayout<N>`, or `None`
    /// if `ndim() != N`.
    fn nd_layout<const N: usize>(&self) -> Option<NdLayout<N>> {
        if self.ndim() != N {
            return None;
        }
        let shape: [usize; N] = std::array::from_fn(|i| self.size(i));
        let strides: [usize; N] = std::array::from_fn(|i| self.stride(i));
        // `AllowOverlap` because the existing layout was already validated
        // against whatever policy applied when it was constructed.
        let layout =
            NdLayout::try_from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap)
                .expect("invalid layout");
        Some(layout)
    }
}
impl<T, S: AsRef<[T]> + AsMut<[T]>, L: MutLayout> TensorBase<T, S, L> {
    /// Return a mutable iterator over slices along `dim`.
    pub fn axis_iter_mut(&mut self, dim: usize) -> AxisIterMut<T, L> {
        AxisIterMut::new(self.view_mut(), dim)
    }
    /// Return a mutable iterator over chunks of size `chunk_size` along `dim`.
    pub fn axis_chunks_mut(&mut self, dim: usize, chunk_size: usize) -> AxisChunksMut<T, L> {
        AxisChunksMut::new(self.view_mut(), dim, chunk_size)
    }
    /// Replace each element with `f(element)`.
    pub fn apply<F: Fn(&T) -> T>(&mut self, f: F) {
        // Fast path: iterate the contiguous slice directly when possible.
        if let Some(data) = self.data_mut() {
            data.iter_mut().for_each(|x| *x = f(x));
        } else {
            self.iter_mut().for_each(|x| *x = f(x));
        }
    }
    /// Return a mutable view with a dynamic-rank layout.
    pub fn as_dyn_mut(&mut self) -> TensorBase<T, &mut [T], DynLayout> {
        TensorBase {
            layout: DynLayout::from_layout(&self.layout),
            data: self.data.as_mut(),
            element_type: PhantomData,
        }
    }
    /// Copy elements from `other` into `self`. Panics if shapes differ.
    pub fn copy_from<S2: AsRef<[T]>>(&mut self, other: &TensorBase<T, S2, L>)
    where
        T: Clone,
        L: Clone,
    {
        assert!(self.shape() == other.shape());
        for (out, x) in self.iter_mut().zip(other.iter()) {
            *out = x.clone();
        }
    }
    /// Return the underlying data as a mutable slice, if contiguous.
    pub fn data_mut(&mut self) -> Option<&mut [T]> {
        self.layout.is_contiguous().then_some(self.data.as_mut())
    }
    /// Set every element to `value`.
    pub fn fill(&mut self, value: T)
    where
        T: Clone,
    {
        self.apply(|_| value.clone())
    }
    /// Return a mutable reference to the element at `index`, or `None` if
    /// out of bounds.
    pub fn get_mut<I: AsIndex<L>>(&mut self, index: I) -> Option<&mut T> {
        self.try_offset(index.as_index())
            .map(|offset| &mut self.data.as_mut()[offset])
    }
    /// Return a mutable reference to the element at `index` without bounds
    /// checks.
    ///
    /// # Safety
    /// The caller must ensure `index` is valid for this tensor's shape.
    pub unsafe fn get_unchecked_mut<I: AsIndex<L>>(&mut self, index: I) -> &mut T {
        self.data
            .as_mut()
            .get_unchecked_mut(self.layout.offset_unchecked(index.as_index()))
    }
    /// Return a lightweight mutable (data, layout) pair used by mutable
    /// iterators.
    pub(crate) fn mut_view_ref(&mut self) -> MutViewRef<T, L> {
        MutViewRef::new(self.data.as_mut(), &self.layout)
    }
    /// Return a mutable iterator over the innermost `N` dimensions.
    pub fn inner_iter_mut<const N: usize>(&mut self) -> InnerIterMut<T, L, N> {
        InnerIterMut::new(self.view_mut())
    }
    /// Return a mutable iterator over all elements in logical order.
    pub fn iter_mut(&mut self) -> IterMut<T> {
        IterMut::new(self.mut_view_ref())
    }
    /// Return a mutable iterator over 1D lanes along `dim`.
    pub fn lanes_mut(&mut self, dim: usize) -> LanesMut<T> {
        LanesMut::new(self.mut_view_ref(), dim)
    }
    /// Return a static-rank mutable view. Panics if `ndim() != N`.
    pub fn nd_view_mut<const N: usize>(&mut self) -> TensorBase<T, &mut [T], NdLayout<N>> {
        assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N);
        TensorBase {
            // `unwrap` cannot fail here: the assert above established
            // `ndim() == N`, which is `nd_layout`'s only `None` case.
            layout: self.nd_layout().unwrap(),
            data: self.data.as_mut(),
            element_type: PhantomData,
        }
    }
    /// Return a mutable view with axes permuted according to `order`.
    pub fn permuted_mut(&mut self, order: L::Index<'_>) -> TensorBase<T, &mut [T], L> {
        TensorBase {
            layout: self.layout.permuted(order),
            data: self.data.as_mut(),
            element_type: PhantomData,
        }
    }
    /// Return a mutable view with the layout changed to `shape`.
    pub fn reshaped_mut<SH: IntoLayout>(
        &mut self,
        shape: SH,
    ) -> TensorBase<T, &mut [T], SH::Layout> {
        TensorBase {
            layout: self.layout.reshaped(shape),
            data: self.data.as_mut(),
            element_type: PhantomData,
        }
    }
    /// Slice this tensor, returning an `M`-dimensional mutable view.
    pub fn slice_mut<const M: usize, R: IntoSliceItems>(
        &mut self,
        range: R,
    ) -> NdTensorViewMut<T, M> {
        let range = range.into_slice_items();
        let (offset_range, sliced_layout) = self.layout.slice(range.as_ref());
        NdTensorViewMut {
            data: &mut self.data.as_mut()[offset_range],
            layout: sliced_layout,
            element_type: PhantomData,
        }
    }
    /// Slice this tensor, returning a dynamic-rank mutable view.
    pub fn slice_mut_dyn<R: IntoSliceItems>(&mut self, range: R) -> TensorViewMut<T> {
        let range = range.into_slice_items();
        let (offset_range, sliced_layout) = self.layout.slice_dyn(range.as_ref());
        TensorViewMut {
            data: &mut self.data.as_mut()[offset_range],
            layout: sliced_layout,
            element_type: PhantomData,
        }
    }
    /// Slice this tensor, returning a mutable view or an error if the range
    /// is invalid.
    pub fn try_slice_mut<R: IntoSliceItems>(
        &mut self,
        range: R,
    ) -> Result<TensorViewMut<T>, SliceError> {
        let (offset_range, layout) = self.layout.try_slice(range)?;
        Ok(TensorBase {
            data: &mut self.data.as_mut()[offset_range],
            layout,
            element_type: PhantomData,
        })
    }
    /// Return a mutable view of the whole tensor.
    pub fn view_mut(&mut self) -> TensorBase<T, &mut [T], L>
    where
        L: Clone,
    {
        TensorBase {
            data: self.data.as_mut(),
            layout: self.layout.clone(),
            element_type: PhantomData,
        }
    }
    /// Return a mutable view whose indexing skips per-dimension bounds checks.
    pub fn weakly_checked_view_mut(&mut self) -> WeaklyCheckedView<T, &mut [T], L> {
        WeaklyCheckedView {
            base: self.view_mut(),
        }
    }
}
impl<T, L: Clone + MutLayout> TensorBase<T, Vec<T>, L> {
    /// Create a 1D tensor containing the values `start`, `start + step`, …
    /// up to but excluding `end`. `step` defaults to one (`true.into()`).
    ///
    /// # Panics
    ///
    /// Panics if `start < end` and `step` is not positive; such a step would
    /// never reach `end` and previously caused an infinite loop.
    pub fn arange(start: T, end: T, step: Option<T>) -> TensorBase<T, Vec<T>, L>
    where
        T: Copy + PartialOrd + From<bool> + std::ops::Add<Output = T>,
        [usize; 1]: AsIndex<L>,
    {
        let step = step.unwrap_or((true).into());
        // `T::from(false)` is the zero value for the numeric types that
        // implement `From<bool>`. A non-positive step with `start < end`
        // would loop forever below, so fail loudly instead.
        assert!(
            start >= end || step > false.into(),
            "arange step must be positive"
        );
        let mut data = Vec::new();
        let mut curr = start;
        while curr < end {
            data.push(curr);
            curr = curr + step;
        }
        TensorBase::from_data([data.len()].as_index(), data)
    }
    /// Create a 1D tensor from a vector of elements.
    pub fn from_vec(vec: Vec<T>) -> TensorBase<T, Vec<T>, L>
    where
        [usize; 1]: AsIndex<L>,
    {
        TensorBase::from_data([vec.len()].as_index(), vec)
    }
    /// Shrink dimension `dim` to the sub-range `range`, moving the retained
    /// elements to the front of the buffer and truncating the rest.
    pub fn clip_dim(&mut self, dim: usize, range: Range<usize>)
    where
        T: Copy,
    {
        let (start, end) = (range.start, range.end);
        assert!(start <= end, "start must be <= end");
        assert!(end <= self.size(dim), "end must be <= dim size");
        let start_offset = self.layout.stride(dim) * start;
        self.layout.resize_dim(dim, end - start);
        // After resizing the layout, only `min_data_len` elements starting at
        // `start_offset` are still addressable; compact them to the front.
        let range = start_offset..start_offset + self.layout.min_data_len();
        self.data.copy_within(range.clone(), 0);
        self.data.truncate(range.end - range.start);
    }
    /// Consume this tensor and return its elements as a contiguous vector.
    pub fn into_data(self) -> Vec<T>
    where
        T: Clone,
    {
        if self.is_contiguous() {
            self.data
        } else {
            self.to_vec()
        }
    }
    /// Consume this tensor and return one with the same elements but a new
    /// shape.
    ///
    /// NOTE(review): the new layout's length is not validated against the
    /// data length here; confirm callers always pass a shape with the same
    /// element count.
    pub fn into_shape<S: IntoLayout>(self, shape: S) -> TensorBase<T, Vec<T>, S::Layout>
    where
        T: Clone,
    {
        TensorBase {
            data: self.into_data(),
            layout: shape.into_layout(),
            element_type: PhantomData,
        }
    }
    /// Create a tensor of the given shape whose elements are produced by
    /// calling `f` with each index.
    pub fn from_fn<F: FnMut(L::Index<'_>) -> T, Idx>(
        shape: L::Index<'_>,
        mut f: F,
    ) -> TensorBase<T, Vec<T>, L>
    where
        L::Indices: Iterator<Item = Idx>,
        Idx: AsIndex<L>,
    {
        let layout = L::from_shape(shape);
        let data: Vec<T> = layout.indices().map(|idx| f(idx.as_index())).collect();
        TensorBase {
            data,
            layout,
            element_type: PhantomData,
        }
    }
    /// Create a tensor of the given shape whose elements are produced by
    /// repeatedly calling `f`, in contiguous (row-major) order.
    pub fn from_simple_fn<F: FnMut() -> T>(
        shape: L::Index<'_>,
        mut f: F,
    ) -> TensorBase<T, Vec<T>, L> {
        let len = shape.as_ref().iter().product();
        let data: Vec<T> = std::iter::from_fn(|| Some(f())).take(len).collect();
        TensorBase::from_data(shape, data)
    }
    /// Create a zero-dimensional tensor containing a single value.
    pub fn from_scalar(value: T) -> TensorBase<T, Vec<T>, L>
    where
        [usize; 0]: AsIndex<L>,
    {
        TensorBase::from_data([].as_index(), vec![value])
    }
    /// Create a tensor of the given shape with every element set to `value`.
    pub fn full(shape: L::Index<'_>, value: T) -> TensorBase<T, Vec<T>, L>
    where
        T: Clone,
    {
        let n_elts = shape.as_ref().iter().product();
        let data = vec![value; n_elts];
        TensorBase::from_data(shape, data)
    }
    /// Rearrange the storage so that elements are contiguous in logical
    /// order. Does nothing if the tensor is already contiguous.
    pub fn make_contiguous(&mut self)
    where
        T: Clone,
    {
        if self.is_contiguous() {
            return;
        }
        self.data = self.to_vec();
        self.layout = L::from_shape(self.layout.shape());
    }
    /// Create a tensor of the given shape filled with values drawn from
    /// `rand_src`.
    pub fn rand<R: RandomSource<T>>(
        shape: L::Index<'_>,
        rand_src: &mut R,
    ) -> TensorBase<T, Vec<T>, L> {
        Self::from_simple_fn(shape, || rand_src.next())
    }
    /// Create a tensor of the given shape filled with `T::default()`.
    pub fn zeros(shape: L::Index<'_>) -> TensorBase<T, Vec<T>, L>
    where
        T: Clone + Default,
    {
        Self::full(shape, T::default())
    }
}
impl<'a, T, L: Clone + MutLayout> TensorBase<T, &'a [T], L> {
    /// Return an iterator over slices along `dim`, with views borrowing for
    /// the underlying data's lifetime `'a`.
    pub fn axis_iter(&self, dim: usize) -> AxisIter<'a, T, L> {
        AxisIter::new(self, dim)
    }
    /// Return an iterator over chunks of size `chunk_size` along `dim`.
    pub fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<'a, T, L> {
        AxisChunks::new(self, dim, chunk_size)
    }
    /// Return a view of this view with a dynamic-rank layout.
    pub fn as_dyn(&self) -> TensorBase<T, &'a [T], DynLayout> {
        TensorBase {
            data: self.data,
            layout: DynLayout::from_layout(&self.layout),
            element_type: PhantomData,
        }
    }
    /// Return a view broadcast to `shape`.
    pub fn broadcast<S: IntoLayout>(&self, shape: S) -> TensorBase<T, &'a [T], S::Layout>
    where
        L: BroadcastLayout<S::Layout>,
    {
        TensorBase {
            layout: self.layout.broadcast(shape),
            data: self.data,
            element_type: PhantomData,
        }
    }
    /// Return an iterator over elements as if broadcast to `shape`.
    pub fn broadcast_iter(&self, shape: &[usize]) -> BroadcastIter<'a, T> {
        BroadcastIter::new(self.view_ref(), shape)
    }
    /// Return the underlying data as a slice, if the view is contiguous.
    pub fn data(&self) -> Option<&'a [T]> {
        self.layout.is_contiguous().then_some(self.data)
    }
    /// Return a reference to the element at `index`, or `None` if out of
    /// bounds.
    pub fn get<I: AsIndex<L>>(&self, index: I) -> Option<&'a T> {
        self.try_offset(index.as_index())
            .map(|offset| &self.data[offset])
    }
    /// Return the underlying data slice regardless of whether the view is
    /// contiguous. Element order in the slice may not match logical order.
    pub fn non_contiguous_data(&self) -> &'a [T] {
        self.data
    }
    /// Construct a view from a slice with explicit strides.
    ///
    /// Unlike `from_data_with_strides`, overlapping strides are allowed here
    /// because the view is immutable.
    pub fn from_slice_with_strides(
        shape: L::Index<'_>,
        data: &'a [T],
        strides: L::Index<'_>,
    ) -> Result<TensorBase<T, &'a [T], L>, FromDataError> {
        let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap)?;
        if layout.min_data_len() > data.as_ref().len() {
            return Err(FromDataError::StorageTooShort);
        }
        Ok(TensorBase {
            data,
            layout,
            element_type: PhantomData,
        })
    }
    /// Return a reference to the element at `index` without bounds checks.
    ///
    /// # Safety
    /// The caller must ensure `index` is valid for this view's shape.
    pub unsafe fn get_unchecked<I: AsIndex<L>>(&self, index: I) -> &'a T {
        self.data
            .get_unchecked(self.layout.offset_unchecked(index.as_index()))
    }
    /// Return an iterator over the innermost `N` dimensions.
    pub fn inner_iter<const N: usize>(&self) -> InnerIter<'a, T, L, N> {
        InnerIter::new(self.view())
    }
    /// Return the single element if this view has exactly one, else `None`.
    pub fn item(&self) -> Option<&'a T> {
        match self.ndim() {
            // 0-dim tensors always have exactly one element at offset 0.
            0 => Some(&self.data[0]),
            _ if self.len() == 1 => self.iter().next(),
            _ => None,
        }
    }
    /// Return an iterator over all elements in logical order.
    pub fn iter(&self) -> Iter<'a, T> {
        Iter::new(self.view_ref())
    }
    /// Return an iterator over 1D lanes along `dim`.
    pub fn lanes(&self, dim: usize) -> Lanes<'a, T> {
        Lanes::new(self.view_ref(), dim)
    }
    /// Return a static-rank view. Panics if `ndim() != N`.
    pub fn nd_view<const N: usize>(&self) -> TensorBase<T, &'a [T], NdLayout<N>> {
        assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N);
        TensorBase {
            data: self.data,
            // `unwrap` cannot fail: the assert established `ndim() == N`.
            layout: self.nd_layout().unwrap(),
            element_type: PhantomData,
        }
    }
    /// Return a view with axes permuted according to `order`.
    pub fn permuted(&self, order: L::Index<'_>) -> TensorBase<T, &'a [T], L> {
        TensorBase {
            data: self.data,
            layout: self.layout.permuted(order),
            element_type: PhantomData,
        }
    }
    /// Return a view with the layout changed to `shape`.
    pub fn reshaped<S: IntoLayout>(&self, shape: S) -> TensorBase<T, &'a [T], S::Layout> {
        TensorBase {
            data: self.data,
            layout: self.layout.reshaped(shape),
            element_type: PhantomData,
        }
    }
    /// Slice this view, returning an `M`-dimensional view.
    pub fn slice<const M: usize, R: IntoSliceItems>(&self, range: R) -> NdTensorView<'a, T, M> {
        let range = range.into_slice_items();
        let (offset_range, sliced_layout) = self.layout.slice(range.as_ref());
        NdTensorView {
            data: &self.data[offset_range],
            layout: sliced_layout,
            element_type: PhantomData,
        }
    }
    /// Slice this view, returning a dynamic-rank view.
    pub fn slice_dyn<R: IntoSliceItems>(&self, range: R) -> TensorView<'a, T> {
        let range = range.into_slice_items();
        let (offset_range, sliced_layout) = self.layout.slice_dyn(range.as_ref());
        TensorView {
            data: &self.data[offset_range],
            layout: sliced_layout,
            element_type: PhantomData,
        }
    }
    /// Return an iterator over the elements selected by `range`.
    pub fn slice_iter(&self, range: &[SliceItem]) -> Iter<'a, T> {
        Iter::slice(self.view_ref(), range)
    }
    /// Return a view with all size-1 dimensions removed.
    pub fn squeezed(&self) -> TensorView<'a, T> {
        TensorBase {
            data: self.data,
            layout: self.layout.squeezed(),
            element_type: PhantomData,
        }
    }
    /// Return a tensor with the same shape and contiguous data, borrowing
    /// the existing data if it is already contiguous and copying otherwise.
    pub fn to_contiguous(&self) -> TensorBase<T, Cow<'a, [T]>, L>
    where
        T: Clone,
    {
        if self.is_contiguous() {
            TensorBase {
                data: Cow::Borrowed(self.data),
                layout: self.layout.clone(),
                element_type: PhantomData,
            }
        } else {
            let data = self.to_vec();
            TensorBase {
                data: Cow::Owned(data),
                layout: L::from_shape(self.layout.shape()),
                element_type: PhantomData,
            }
        }
    }
    /// Return a view with the order of the axes reversed.
    pub fn transposed(&self) -> TensorBase<T, &'a [T], L> {
        TensorBase {
            data: self.data,
            layout: self.layout.transposed(),
            element_type: PhantomData,
        }
    }
    /// Slice this view, returning a dynamic-rank view or an error if the
    /// range is invalid.
    pub fn try_slice_dyn<R: IntoSliceItems>(
        &self,
        range: R,
    ) -> Result<TensorView<'a, T>, SliceError> {
        let (offset_range, layout) = self.layout.try_slice(range)?;
        Ok(TensorBase {
            data: &self.data[offset_range],
            layout,
            element_type: PhantomData,
        })
    }
    /// Return a copy of this view that borrows for the lifetime `'a` of the
    /// underlying data rather than the lifetime of `self`.
    pub fn view(&self) -> TensorBase<T, &'a [T], L> {
        TensorBase {
            data: self.data,
            layout: self.layout.clone(),
            element_type: PhantomData,
        }
    }
    /// Return a lightweight (data, layout) pair used by iterators.
    pub(crate) fn view_ref(&self) -> ViewRef<'a, '_, T, L> {
        ViewRef::new(self.data, &self.layout)
    }
    /// Return a view whose indexing skips per-dimension bounds checks.
    pub fn weakly_checked_view(&self) -> WeaklyCheckedView<T, &'a [T], L> {
        WeaklyCheckedView { base: self.view() }
    }
}
/// `Layout` is implemented by delegating every method to the tensor's layout,
/// so shape/stride queries can be made directly on a tensor or view.
impl<T, S: AsRef<[T]>, L: MutLayout> Layout for TensorBase<T, S, L> {
    type Index<'a> = L::Index<'a>;
    type Indices = L::Indices;
    fn ndim(&self) -> usize {
        self.layout.ndim()
    }
    fn len(&self) -> usize {
        self.layout.len()
    }
    fn is_empty(&self) -> bool {
        self.layout.is_empty()
    }
    fn shape(&self) -> Self::Index<'_> {
        self.layout.shape()
    }
    fn size(&self, dim: usize) -> usize {
        self.layout.size(dim)
    }
    fn strides(&self) -> Self::Index<'_> {
        self.layout.strides()
    }
    fn stride(&self, dim: usize) -> usize {
        self.layout.stride(dim)
    }
    fn indices(&self) -> Self::Indices {
        self.layout.indices()
    }
    fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
        self.layout.try_offset(index)
    }
}
/// Matrix (2D) layout queries, delegated to the layout when it supports them.
impl<T, S: AsRef<[T]>, L: MutLayout + MatrixLayout> MatrixLayout for TensorBase<T, S, L> {
    fn rows(&self) -> usize {
        self.layout.rows()
    }
    fn cols(&self) -> usize {
        self.layout.cols()
    }
    fn row_stride(&self) -> usize {
        self.layout.row_stride()
    }
    fn col_stride(&self) -> usize {
        self.layout.col_stride()
    }
}
impl<T, S: AsRef<[T]>, L: MutLayout + Clone> AsView for TensorBase<T, S, L> {
    type Elem = T;
    type Layout = L;
    fn iter(&self) -> Iter<T> {
        self.view().iter()
    }
    fn data(&self) -> Option<&[Self::Elem]> {
        self.view().data()
    }
    fn insert_axis(&mut self, index: usize)
    where
        L: ResizeLayout,
    {
        self.layout.insert_axis(index)
    }
    fn layout(&self) -> &L {
        &self.layout
    }
    /// Apply `f` to every element, producing a new owned tensor with the
    /// same shape.
    fn map<F, U>(&self, f: F) -> TensorBase<U, Vec<U>, L>
    where
        F: Fn(&Self::Elem) -> U,
    {
        // Fast path: map over the contiguous slice directly when possible.
        let data: Vec<U> = if let Some(data) = self.data() {
            data.iter().map(f).collect()
        } else {
            self.iter().map(f).collect()
        };
        TensorBase::from_data(self.shape(), data)
    }
    fn move_axis(&mut self, from: usize, to: usize) {
        self.layout.move_axis(from, to);
    }
    fn view(&self) -> TensorBase<T, &[T], L> {
        TensorBase {
            data: self.data.as_ref(),
            layout: self.layout.clone(),
            element_type: PhantomData,
        }
    }
    fn get<I: AsIndex<L>>(&self, index: I) -> Option<&Self::Elem> {
        self.try_offset(index.as_index())
            .map(|offset| &self.data.as_ref()[offset])
    }
    unsafe fn get_unchecked<I: AsIndex<L>>(&self, index: I) -> &T {
        let offset = self.layout.offset_unchecked(index.as_index());
        self.data.as_ref().get_unchecked(offset)
    }
    fn permute(&mut self, order: Self::Index<'_>) {
        self.layout = self.layout.permuted(order);
    }
    /// Copy the elements into a vector in logical (row-major) order.
    fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        if let Some(data) = self.data() {
            data.to_vec()
        } else {
            // Non-contiguous: gather elements via the transpose helper.
            contiguous_data(self.as_dyn())
        }
    }
    fn to_shape<SH: IntoLayout>(
        &self,
        shape: SH,
    ) -> TensorBase<Self::Elem, Vec<Self::Elem>, SH::Layout>
    where
        T: Clone,
    {
        TensorBase {
            data: self.to_vec(),
            layout: shape.into_layout(),
            element_type: PhantomData,
        }
    }
    fn transpose(&mut self) {
        self.layout = self.layout.transposed();
    }
}
impl<T, S: AsRef<[T]>, const N: usize> TensorBase<T, S, NdLayout<N>> {
    /// Read `M` consecutive (along `dim`) elements starting at index `base`
    /// into a fixed-size array.
    ///
    /// Panics (inside `array_offsets`) if `base[dim] + M` exceeds the size
    /// of dimension `dim`.
    #[inline]
    pub fn get_array<const M: usize>(&self, base: [usize; N], dim: usize) -> [T; M]
    where
        T: Copy + Default,
    {
        let offsets: [usize; M] = array_offsets(&self.layout, base, dim);
        let data = self.data.as_ref();
        let mut result = [T::default(); M];
        for i in 0..M {
            // SAFETY: `array_offsets` asserts the index range fits within
            // `dim`; it assumes the layout only produces in-bounds offsets
            // for valid indices.
            result[i] = unsafe { *data.get_unchecked(offsets[i]) };
        }
        result
    }
}
impl<T> TensorBase<T, Vec<T>, DynLayout> {
    /// Change the shape of this tensor in place, copying the data into
    /// contiguous order first if necessary.
    pub fn reshape(&mut self, shape: &[usize])
    where
        T: Clone,
    {
        if !self.is_contiguous() {
            self.data = self.to_vec();
        }
        self.layout = DynLayout::from_shape(shape);
    }
}
impl<'a, T> TensorBase<T, &'a [T], DynLayout> {
    /// Change the shape of this view in place.
    ///
    /// Panics if the view is not contiguous, since a view cannot reorder the
    /// data it borrows.
    ///
    /// The previous `T: Clone` bound was unused (the body never clones) and
    /// has been removed, so views of non-`Clone` element types can now be
    /// reshaped as well.
    pub fn reshape(&mut self, shape: &[usize]) {
        assert!(self.is_contiguous(), "can only reshape contiguous views");
        self.layout = DynLayout::from_shape(shape);
    }
}
impl<'a, T> TensorBase<T, &'a mut [T], DynLayout> {
    /// Change the shape of this mutable view in place.
    ///
    /// Panics if the view is not contiguous, since a view cannot reorder the
    /// data it borrows.
    ///
    /// The previous `T: Clone` bound was unused (the body never clones) and
    /// has been removed, so views of non-`Clone` element types can now be
    /// reshaped as well.
    pub fn reshape(&mut self, shape: &[usize]) {
        assert!(self.is_contiguous(), "can only reshape contiguous views");
        self.layout = DynLayout::from_shape(shape);
    }
}
/// Collect an iterator into a 1D tensor.
impl<T, L: Clone + MutLayout> FromIterator<T> for TensorBase<T, Vec<T>, L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> TensorBase<T, Vec<T>, L> {
        let data: Vec<T> = iter.into_iter().collect();
        TensorBase::from_data([data.len()].as_index(), data)
    }
}
/// Convert a vector into a 1D tensor that owns the data.
impl<T, L: Clone + MutLayout> From<Vec<T>> for TensorBase<T, Vec<T>, L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from(vec: Vec<T>) -> Self {
        Self::from_data([vec.len()].as_index(), vec)
    }
}
/// Convert a slice into a 1D view borrowing the data.
impl<'a, T, L: Clone + MutLayout> From<&'a [T]> for TensorBase<T, &'a [T], L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from(slice: &'a [T]) -> Self {
        Self::from_data([slice.len()].as_index(), slice)
    }
}
/// Convert a fixed-size array reference into a 1D view borrowing the data.
impl<'a, T, L: Clone + MutLayout, const N: usize> From<&'a [T; N]> for TensorBase<T, &'a [T], L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from(slice: &'a [T; N]) -> Self {
        Self::from_data([slice.len()].as_index(), slice.as_slice())
    }
}
/// Compute the storage offsets of `M` successive elements along `dim`,
/// starting from the element at index `base`.
///
/// Panics if fewer than `M` elements are available along `dim` starting at
/// `base[dim]` (the first condition also guards the sum against overflow).
fn array_offsets<const N: usize, const M: usize>(
    layout: &NdLayout<N>,
    base: [usize; N],
    dim: usize,
) -> [usize; M] {
    assert!(
        base[dim] < usize::MAX - M && layout.size(dim) >= base[dim] + M,
        "array indices invalid"
    );
    let first = layout.offset(base);
    let step = layout.stride(dim);
    // Offsets advance by the stride of `dim` for each successive element.
    std::array::from_fn(|i| first + i * step)
}
impl<T, S: AsRef<[T]> + AsMut<[T]>, const N: usize> TensorBase<T, S, NdLayout<N>> {
    /// Write the `M` values in `values` to consecutive (along `dim`)
    /// positions starting at index `base`.
    ///
    /// Panics (inside `array_offsets`) if `base[dim] + M` exceeds the size
    /// of dimension `dim`.
    #[inline]
    pub fn set_array<const M: usize>(&mut self, base: [usize; N], dim: usize, values: [T; M])
    where
        T: Copy,
    {
        let offsets: [usize; M] = array_offsets(&self.layout, base, dim);
        let data = self.data.as_mut();
        for i in 0..M {
            // SAFETY: `array_offsets` asserts the index range fits within
            // `dim`; it assumes the layout only produces in-bounds offsets
            // for valid indices.
            unsafe { *data.get_unchecked_mut(offsets[i]) = values[i] };
        }
    }
}
impl<T, S: AsRef<[T]>> TensorBase<T, S, NdLayout<1>> {
    /// Copy the first `M` elements of this 1D tensor into a fixed-size array.
    ///
    /// Panics if the tensor has fewer than `M` elements.
    #[inline]
    pub fn to_array<const M: usize>(&self) -> [T; M]
    where
        T: Copy + Default,
    {
        self.get_array([0], 0)
    }
}
impl<T, S: AsRef<[T]> + AsMut<[T]>> TensorBase<T, S, NdLayout<1>> {
    /// Copy the `M` values in `values` into the first `M` positions of this
    /// 1D tensor.
    ///
    /// Panics if the tensor has fewer than `M` elements.
    ///
    /// The previous `T: Default` bound was unused — `set_array` only
    /// requires `T: Copy` — so it has been removed, which is backward
    /// compatible and lets non-`Default` element types use this method.
    #[inline]
    pub fn assign_array<const M: usize>(&mut self, values: [T; M])
    where
        T: Copy,
    {
        self.set_array([0], 0, values)
    }
}
/// Borrowed view of an `N`-dimensional tensor.
pub type NdTensorView<'a, T, const N: usize> = TensorBase<T, &'a [T], NdLayout<N>>;
/// Owned `N`-dimensional tensor.
pub type NdTensor<T, const N: usize> = TensorBase<T, Vec<T>, NdLayout<N>>;
/// Mutable borrowed view of an `N`-dimensional tensor.
pub type NdTensorViewMut<'a, T, const N: usize> = TensorBase<T, &'a mut [T], NdLayout<N>>;
/// Borrowed view of a 2D tensor.
pub type Matrix<'a, T = f32> = NdTensorView<'a, T, 2>;
/// Mutable borrowed view of a 2D tensor.
pub type MatrixMut<'a, T = f32> = NdTensorViewMut<'a, T, 2>;
/// Owned tensor whose rank is determined at runtime.
pub type Tensor<T = f32> = TensorBase<T, Vec<T>, DynLayout>;
/// Borrowed view of a tensor whose rank is determined at runtime.
pub type TensorView<'a, T = f32> = TensorBase<T, &'a [T], DynLayout>;
/// Mutable borrowed view of a tensor whose rank is determined at runtime.
pub type TensorViewMut<'a, T = f32> = TensorBase<T, &'a mut [T], DynLayout>;
/// Indexing with `tensor[index]`. Panics if the index is out of bounds
/// (`layout.offset` performs the validation).
impl<T, S: AsRef<[T]>, L: MutLayout, I: AsIndex<L>> Index<I> for TensorBase<T, S, L> {
    type Output = T;
    fn index(&self, index: I) -> &Self::Output {
        let offset = self.layout.offset(index.as_index());
        &self.data.as_ref()[offset]
    }
}
/// Mutable indexing with `tensor[index] = value`. Panics if the index is out
/// of bounds.
impl<T, S: AsRef<[T]> + AsMut<[T]>, L: MutLayout, I: AsIndex<L>> IndexMut<I>
    for TensorBase<T, S, L>
{
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        let offset = self.layout.offset(index.as_index());
        &mut self.data.as_mut()[offset]
    }
}
/// Cloning copies the storage and layout (a deep copy for owned tensors, a
/// cheap copy for views whose storage is a reference).
impl<T, S: AsRef<[T]> + Clone, L: MutLayout + Clone> Clone for TensorBase<T, S, L> {
    fn clone(&self) -> TensorBase<T, S, L> {
        let data = self.data.clone();
        TensorBase {
            data,
            layout: self.layout.clone(),
            element_type: PhantomData,
        }
    }
}
/// Views (whose storage and layout are `Copy`) can be copied freely.
impl<T, S: AsRef<[T]> + Copy, L: MutLayout + Copy> Copy for TensorBase<T, S, L> {}
/// Two tensors are equal if they have the same shape and equal elements in
/// logical order; storage and layout types may differ.
impl<T: PartialEq, S: AsRef<[T]>, L: MutLayout, V: AsView<Elem = T>> PartialEq<V>
    for TensorBase<T, S, L>
{
    fn eq(&self, other: &V) -> bool {
        self.shape().as_ref() == other.shape().as_ref() && self.iter().eq(other.iter())
    }
}
/// Convert a static-rank tensor to a dynamic-rank one (always succeeds).
impl<T, S: AsRef<[T]>, const N: usize> From<TensorBase<T, S, NdLayout<N>>>
    for TensorBase<T, S, DynLayout>
{
    fn from(tensor: TensorBase<T, S, NdLayout<N>>) -> Self {
        Self {
            data: tensor.data,
            layout: tensor.layout.into(),
            element_type: PhantomData,
        }
    }
}
/// Convert a dynamic-rank tensor to a static-rank one, failing with a
/// `DimensionError` if the number of dimensions does not match `N`.
impl<T, S1: AsRef<[T]>, S2: AsRef<[T]>, const N: usize> TryFrom<TensorBase<T, S1, DynLayout>>
    for TensorBase<T, S2, NdLayout<N>>
where
    S1: Into<S2>,
{
    type Error = DimensionError;
    fn try_from(value: TensorBase<T, S1, DynLayout>) -> Result<Self, Self::Error> {
        let layout: NdLayout<N> = value.layout().try_into()?;
        Ok(TensorBase {
            data: value.data.into(),
            layout,
            element_type: PhantomData,
        })
    }
}
/// Marker trait for element types that the nested-array `From` conversions
/// below accept. Restricting the element type avoids overlapping-impl issues
/// with other `From` implementations.
pub trait Scalar {}
impl Scalar for i32 {}
impl Scalar for f32 {}
/// Construct a 1D tensor from a fixed-size array.
impl<T: Clone + Scalar, L: MutLayout, const D0: usize> From<[T; D0]> for TensorBase<T, Vec<T>, L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from(value: [T; D0]) -> Self {
        let data = value.to_vec();
        Self::from_data([D0].as_index(), data)
    }
}
/// Construct a 2D tensor from a nested array, flattening rows in order.
impl<T: Clone + Scalar, L: MutLayout, const D0: usize, const D1: usize> From<[[T; D1]; D0]>
    for TensorBase<T, Vec<T>, L>
where
    [usize; 2]: AsIndex<L>,
{
    fn from(value: [[T; D1]; D0]) -> Self {
        let data: Vec<T> = value.iter().flatten().cloned().collect();
        Self::from_data([D0, D1].as_index(), data)
    }
}
/// Construct a 3D tensor from a nested array, flattening in row-major order.
impl<T: Clone + Scalar, L: MutLayout, const D0: usize, const D1: usize, const D2: usize>
    From<[[[T; D2]; D1]; D0]> for TensorBase<T, Vec<T>, L>
where
    [usize; 3]: AsIndex<L>,
{
    fn from(value: [[[T; D2]; D1]; D0]) -> Self {
        // Total element count is known up front, so reserve it once.
        let mut data = Vec::with_capacity(D0 * D1 * D2);
        for plane in &value {
            for row in plane {
                data.extend(row.iter().cloned());
            }
        }
        Self::from_data([D0, D1, D2].as_index(), data)
    }
}
/// A view wrapper whose `Index`/`IndexMut` implementations compute offsets
/// via `offset_unchecked`, trading per-dimension index validation for speed.
/// Out-of-range indices may map to the wrong element or panic on the final
/// slice access rather than failing with a clear index error.
pub struct WeaklyCheckedView<T, S: AsRef<[T]>, L: MutLayout> {
    /// The wrapped tensor or view.
    base: TensorBase<T, S, L>,
}
/// Layout queries are delegated to the wrapped tensor unchanged.
impl<T, S: AsRef<[T]>, L: MutLayout> Layout for WeaklyCheckedView<T, S, L> {
    type Index<'a> = L::Index<'a>;
    type Indices = L::Indices;
    fn ndim(&self) -> usize {
        self.base.ndim()
    }
    fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
        self.base.try_offset(index)
    }
    fn len(&self) -> usize {
        self.base.len()
    }
    fn shape(&self) -> Self::Index<'_> {
        self.base.shape()
    }
    fn strides(&self) -> Self::Index<'_> {
        self.base.strides()
    }
    fn indices(&self) -> Self::Indices {
        self.base.indices()
    }
}
/// Indexing without per-dimension bounds checks; the slice access still
/// bounds-checks the final offset.
impl<T, S: AsRef<[T]>, L: MutLayout, I: AsIndex<L>> Index<I> for WeaklyCheckedView<T, S, L> {
    type Output = T;
    fn index(&self, index: I) -> &Self::Output {
        &self.base.data.as_ref()[self.base.layout.offset_unchecked(index.as_index())]
    }
}
/// Mutable indexing without per-dimension bounds checks.
impl<T, S: AsRef<[T]> + AsMut<[T]>, L: MutLayout, I: AsIndex<L>> IndexMut<I>
    for WeaklyCheckedView<T, S, L>
{
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        let offset = self.base.layout.offset_unchecked(index.as_index());
        &mut self.base.data.as_mut()[offset]
    }
}
#[cfg(test)]
mod tests {
use super::{AsView, NdTensor, NdTensorView, Tensor};
use crate::errors::FromDataError;
use crate::layout::MatrixLayout;
use crate::prelude::*;
use crate::rng::XorShiftRng;
use crate::SliceItem;
#[test]
fn test_apply() {
let data = vec![1., 2., 3., 4.];
let mut tensor = NdTensor::from_data([2, 2], data);
tensor.apply(|x| *x * 2.);
assert_eq!(tensor.to_vec(), &[2., 4., 6., 8.]);
tensor.transpose();
tensor.apply(|x| *x / 2.);
assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);
}
#[test]
fn test_arange() {
let x = Tensor::arange(2, 6, None);
let y = NdTensor::arange(2, 6, None);
assert_eq!(x.data(), Some([2, 3, 4, 5].as_slice()));
assert_eq!(y.data(), Some([2, 3, 4, 5].as_slice()));
}
#[test]
fn test_as_dyn() {
let data = vec![1., 2., 3., 4.];
let tensor = NdTensor::from_data([2, 2], data);
let dyn_view = tensor.as_dyn();
assert_eq!(dyn_view.shape(), tensor.shape().as_ref());
assert_eq!(dyn_view.to_vec(), tensor.to_vec());
}
#[test]
fn test_as_dyn_mut() {
let data = vec![1., 2., 3., 4.];
let mut tensor = NdTensor::from_data([2, 2], data);
let mut dyn_view = tensor.as_dyn_mut();
dyn_view[[0, 0]] = 9.;
assert_eq!(tensor[[0, 0]], 9.);
}
#[test]
fn test_assign_array() {
let mut tensor = NdTensor::zeros([2, 2]);
let mut transposed = tensor.view_mut();
transposed.permute([1, 0]);
transposed.slice_mut(0).assign_array([1, 2]);
transposed.slice_mut(1).assign_array([3, 4]);
assert_eq!(tensor.iter().copied().collect::<Vec<_>>(), [1, 3, 2, 4]);
}
#[test]
fn test_axis_chunks() {
let tensor = NdTensor::arange(0, 8, None).into_shape([4, 2]);
let mut row_chunks = tensor.axis_chunks(0, 2);
let chunk = row_chunks.next().unwrap();
assert_eq!(chunk.shape(), &[2, 2]);
assert_eq!(chunk.to_vec(), &[0, 1, 2, 3]);
let chunk = row_chunks.next().unwrap();
assert_eq!(chunk.shape(), &[2, 2]);
assert_eq!(chunk.to_vec(), &[4, 5, 6, 7]);
assert!(row_chunks.next().is_none());
}
#[test]
fn test_axis_chunks_mut() {
let mut tensor = NdTensor::arange(1, 9, None).into_shape([4, 2]);
let mut row_chunks = tensor.axis_chunks_mut(0, 2);
let mut chunk = row_chunks.next().unwrap();
chunk.apply(|x| x * 2);
let mut chunk = row_chunks.next().unwrap();
chunk.apply(|x| x * -2);
assert!(row_chunks.next().is_none());
assert_eq!(tensor.to_vec(), [2, 4, 6, 8, -10, -12, -14, -16]);
}
#[test]
fn test_axis_iter() {
let tensor = NdTensor::arange(0, 4, None).into_shape([2, 2]);
let mut rows = tensor.axis_iter(0);
let row = rows.next().unwrap();
assert_eq!(row.shape(), &[2]);
assert_eq!(row.to_vec(), &[0, 1]);
let row = rows.next().unwrap();
assert_eq!(row.shape(), &[2]);
assert_eq!(row.to_vec(), &[2, 3]);
assert!(rows.next().is_none());
}
#[test]
fn test_axis_iter_mut() {
let mut tensor = NdTensor::arange(1, 5, None).into_shape([2, 2]);
let mut rows = tensor.axis_iter_mut(0);
let mut row = rows.next().unwrap();
row.apply(|x| x * 2);
let mut row = rows.next().unwrap();
row.apply(|x| x * -2);
assert!(rows.next().is_none());
assert_eq!(tensor.to_vec(), [2, 4, -6, -8]);
}
#[test]
fn test_broadcast() {
let data = vec![1., 2., 3., 4.];
let dest_shape = [3, 1, 2, 2];
let expected_data: Vec<_> = data.iter().copied().cycle().take(data.len() * 3).collect();
let ndtensor = NdTensor::from_data([2, 2], data);
let view = ndtensor.broadcast(dest_shape);
assert_eq!(view.shape(), dest_shape);
assert_eq!(view.to_vec(), expected_data);
let view = ndtensor.broadcast(dest_shape.as_slice());
assert_eq!(view.shape(), dest_shape);
assert_eq!(view.to_vec(), expected_data);
let tensor = ndtensor.as_dyn();
let view = tensor.broadcast(dest_shape);
assert_eq!(view.shape(), dest_shape);
assert_eq!(view.to_vec(), expected_data);
let view = tensor.broadcast(dest_shape.as_slice());
assert_eq!(view.shape(), dest_shape);
assert_eq!(view.to_vec(), expected_data);
}
#[test]
fn test_broadcast_iter() {
let tensor = NdTensor::from_data([1], vec![3]);
let elems: Vec<_> = tensor.broadcast_iter(&[2, 2]).copied().collect();
assert_eq!(elems, &[3, 3, 3, 3]);
}
#[test]
fn test_clip_dim() {
let mut tensor = NdTensor::arange(0, 10, None).into_shape([3, 3]);
tensor.clip_dim(0, 0..3); assert_eq!(tensor.shape(), [3, 3]);
tensor.clip_dim(0, 1..2); assert_eq!(tensor.shape(), [1, 3]);
assert_eq!(tensor.data(), Some([3, 4, 5].as_slice()));
}
#[test]
fn test_clone() {
let data = vec![1., 2., 3., 4.];
let tensor = NdTensor::from_data([2, 2], data);
let cloned = tensor.clone();
assert_eq!(tensor.shape(), cloned.shape());
assert_eq!(tensor.to_vec(), cloned.to_vec());
}
#[test]
fn test_copy_view() {
let data = vec![1., 2., 3., 4.];
let view = NdTensorView::from_data([2, 2], &data);
let view2 = view;
assert_eq!(view.shape(), view2.shape());
}
#[test]
fn test_copy_from() {
let mut dest = Tensor::zeros(&[2, 2]);
let src = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
dest.copy_from(&src);
assert_eq!(dest.to_vec(), &[1., 2., 3., 4.]);
}
#[test]
fn test_data() {
let data = vec![1., 2., 3., 4., 5., 6.];
let tensor = NdTensorView::from_data([2, 3], &data);
assert_eq!(tensor.data(), Some(data.as_slice()));
let permuted = tensor.permuted([1, 0]);
assert_eq!(permuted.shape(), [3, 2]);
assert_eq!(permuted.data(), None);
}
#[test]
fn test_data_mut() {
    // Mutable slice access works while the tensor is contiguous.
    let mut expected = vec![1., 2., 3., 4., 5., 6.];
    let mut tensor = NdTensor::from_data([2, 3], expected.clone());
    assert_eq!(tensor.data_mut(), Some(expected.as_mut_slice()));

    // ...but not once the layout is permuted.
    let mut transposed = tensor.permuted_mut([1, 0]);
    assert_eq!(transposed.shape(), [3, 2]);
    assert_eq!(transposed.data_mut(), None);
}
#[test]
fn test_fill() {
    // `fill` overwrites every element with the given value.
    let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    tensor.fill(9.);
    assert_eq!(tensor.to_vec(), &[9., 9., 9., 9.]);
}
#[test]
fn test_from_fn() {
    // Scalar tensor: the generator runs once with an empty index.
    let scalar = NdTensor::from_fn([], |_| 5);
    assert_eq!(scalar.data(), Some([5].as_slice()));

    // 1D: index is a one-element array.
    let vector = NdTensor::from_fn([5], |i| i[0]);
    assert_eq!(vector.data(), Some([0, 1, 2, 3, 4].as_slice()));

    // 2D: the index array can be destructured in the closure.
    let matrix = NdTensor::from_fn([2, 2], |[y, x]| y * 10 + x);
    assert_eq!(matrix.data(), Some([0, 1, 10, 11].as_slice()));

    // Dynamic-rank equivalents take a shape slice and an index slice.
    let scalar = Tensor::from_fn(&[], |_| 6);
    assert_eq!(scalar.data(), Some([6].as_slice()));
    let matrix = Tensor::from_fn(&[2, 2], |index| index[0] * 10 + index[1]);
    assert_eq!(matrix.data(), Some([0, 1, 10, 11].as_slice()));
}
#[test]
fn test_from_nested_array() {
    // 1D array literal.
    let t = NdTensor::from([1, 2, 3]);
    assert_eq!(t.shape(), [3]);
    assert_eq!(t.data(), Some([1, 2, 3].as_slice()));

    // 2D nested array literal.
    let t = NdTensor::from([[1, 2], [3, 4]]);
    assert_eq!(t.shape(), [2, 2]);
    assert_eq!(t.data(), Some([1, 2, 3, 4].as_slice()));

    // 3D nested array literal.
    let t = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
    assert_eq!(t.shape(), [2, 2, 2]);
    assert_eq!(t.data(), Some([1, 2, 3, 4, 5, 6, 7, 8].as_slice()));
}
#[test]
fn test_from_simple_fn() {
    // Stateful generator; invoked once per element.
    let mut counter = 0;
    let mut next_int = || {
        let value = counter;
        counter += 1;
        value
    };

    let x = NdTensor::from_simple_fn([2, 2], &mut next_int);
    assert_eq!(x.data(), Some([0, 1, 2, 3].as_slice()));

    // A scalar tensor invokes the generator exactly once.
    let x = NdTensor::from_simple_fn([], &mut next_int);
    assert_eq!(x.data(), Some([4].as_slice()));

    // Dynamic-rank variant, consuming the closure by value.
    let x = Tensor::from_simple_fn(&[2, 2], next_int);
    assert_eq!(x.data(), Some([5, 6, 7, 8].as_slice()));
}
#[test]
fn test_from_vec_or_slice() {
    // A Vec converts into an owned 1D tensor.
    let from_vec = NdTensor::from(vec![1, 2, 3, 4]);
    assert_eq!(from_vec.shape(), [4]);
    assert_eq!(from_vec.data(), Some([1, 2, 3, 4].as_slice()));

    // A slice reference converts into a 1D view.
    let from_slice = NdTensorView::from(&[1, 2, 3]);
    assert_eq!(from_slice.shape(), [3]);
    assert_eq!(from_slice.data(), Some([1, 2, 3].as_slice()));
}
#[test]
fn test_dyn_tensor_from_nd_tensor() {
    // Static-rank tensors convert losslessly into dynamic-rank ones.
    let static_rank = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]);
    let dynamic: Tensor<i32> = static_rank.into();
    assert_eq!(dynamic.data(), Some([1, 2, 3, 4].as_slice()));
    assert_eq!(dynamic.shape(), &[2, 2]);
}
#[test]
fn test_nd_tensor_from_dyn_tensor() {
    // Conversion succeeds when the requested rank matches...
    let dynamic = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    let static_rank: NdTensor<i32, 2> = dynamic.try_into().unwrap();
    assert_eq!(static_rank.data(), Some([1, 2, 3, 4].as_slice()));
    assert_eq!(static_rank.shape(), [2, 2]);

    // ...and fails when it does not.
    let dynamic = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    let wrong_rank: Result<NdTensor<i32, 3>, _> = dynamic.try_into();
    assert!(wrong_rank.is_err());
}
#[test]
fn test_from_data() {
    // Construction uses row-major (C-order) strides by default.
    let x = NdTensor::from_data([1, 2, 2], vec![1, 2, 3, 4]);
    assert_eq!(x.shape(), [1, 2, 2]);
    assert_eq!(x.strides(), [4, 2, 1]);
    assert_eq!(x.to_vec(), [1, 2, 3, 4]);
}
#[test]
#[should_panic(expected = "data length 4 does not match shape [2, 2, 2]")]
fn test_from_data_shape_mismatch() {
    // The shape requires 8 elements but only 4 are supplied.
    NdTensor::from_data([2, 2, 2], vec![1, 2, 3, 4]);
}
#[test]
fn test_from_data_with_strides() {
    // Custom (here column-major) strides are accepted when valid.
    let col_major =
        NdTensor::from_data_with_strides([2, 2, 1], vec![1, 2, 3, 4], [1, 2, 4]).unwrap();
    assert_eq!(col_major.shape(), [2, 2, 1]);
    assert_eq!(col_major.strides(), [1, 2, 4]);
    assert_eq!(col_major.to_vec(), [1, 3, 2, 4]);

    // Strides that address past the end of the buffer are rejected.
    let too_short = NdTensor::from_data_with_strides([2, 2, 2], vec![1, 2, 3, 4], [1, 2, 4]);
    assert_eq!(too_short, Err(FromDataError::StorageTooShort));

    // Zero strides would alias elements, which is invalid for owned tensors.
    let aliasing = NdTensor::from_data_with_strides([2, 2], vec![1, 2], [0, 1]);
    assert_eq!(aliasing, Err(FromDataError::MayOverlap));
}
#[test]
fn test_from_slice_with_strides() {
    // Overlapping (zero) strides are allowed for immutable views,
    // unlike for owned tensors.
    let data = [1, 2];
    let view = NdTensorView::from_slice_with_strides([2, 2], &data, [0, 1]).unwrap();
    assert_eq!(view.to_vec(), [1, 2, 1, 2]);
}
#[test]
fn test_from_iter() {
    // Collecting an iterator produces a 1D tensor.
    let dynamic: Tensor = [1., 2., 3., 4.].into_iter().collect();
    assert_eq!(dynamic.shape(), &[4]);
    assert_eq!(dynamic.data(), Some([1., 2., 3., 4.].as_slice()));

    let static_rank: NdTensor<_, 1> = [1., 2., 3., 4.].into_iter().collect();
    assert_eq!(static_rank.shape(), [4]);
    assert_eq!(static_rank.data(), Some([1., 2., 3., 4.].as_slice()));
}
#[test]
fn test_from_scalar() {
    // Scalar constructors for both dynamic- and static-rank tensors.
    let dynamic = Tensor::from_scalar(5.);
    let static_rank = NdTensor::from_scalar(6.);
    assert_eq!(dynamic.item(), Some(&5.));
    assert_eq!(static_rank.item(), Some(&6.));
}
#[test]
fn test_from_vec() {
    // A Vec becomes a 1D tensor of the same length.
    let x = NdTensor::from_vec(vec![1, 2, 3, 4]);
    assert_eq!(x.shape(), [4]);
    assert_eq!(x.data(), Some([1, 2, 3, 4].as_slice()));
}
#[test]
fn test_full() {
    // `full` fills every element with the given value.
    let filled = NdTensor::full([2, 2], 2.);
    assert_eq!(filled.shape(), [2, 2]);
    assert_eq!(filled.data(), Some([2., 2., 2., 2.].as_slice()));
}
#[test]
fn test_get() {
    // Static-rank tensor: in-bounds and out-of-bounds lookups.
    let tensor: NdTensor<f32, 2> = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    assert_eq!(tensor.get([1, 1]), Some(&4.));
    assert_eq!(tensor.get([2, 1]), None);
    assert_eq!(tensor.view().get([1, 1]), Some(&4.));
    assert_eq!(tensor.view().get([2, 1]), None);

    // Dynamic-rank tensor: an index of the wrong length also returns `None`.
    let tensor: Tensor<f32> = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    assert_eq!(tensor.get([1, 1]), Some(&4.));
    assert_eq!(tensor.get([2, 1]), None);
    assert_eq!(tensor.get([1, 2, 3]), None);
    assert_eq!(tensor.view().get([1, 1]), Some(&4.));
    assert_eq!(tensor.view().get([2, 1]), None);
    assert_eq!(tensor.view().get([1, 2, 3]), None);
}
#[test]
fn test_get_array() {
    let tensor = NdTensor::arange(1, 17, None).into_shape([4, 2, 2]);

    // Gather along axis 0 from two different base indices.
    let along_axis0: [i32; 4] = tensor.get_array([0, 0, 0], 0);
    assert_eq!(along_axis0, [1, 5, 9, 13]);
    let along_axis0: [i32; 4] = tensor.get_array([0, 1, 1], 0);
    assert_eq!(along_axis0, [4, 8, 12, 16]);

    // Gather along the innermost axis.
    let along_axis2: [i32; 2] = tensor.get_array([0, 0, 0], 2);
    assert_eq!(along_axis2, [1, 2]);
}
#[test]
fn test_get_mut() {
    let mut tensor: NdTensor<f32, 2> = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);

    // An in-bounds lookup yields a mutable reference.
    if let Some(elem) = tensor.get_mut([1, 1]) {
        *elem = 9.;
    }
    assert_eq!(tensor[[1, 1]], 9.);

    // An out-of-bounds lookup yields `None`.
    assert_eq!(tensor.get_mut([2, 1]), None);
}
#[test]
fn test_get_unchecked() {
    // Static-rank tensor: unchecked access agrees with checked indexing.
    let ndtensor = NdTensor::arange(1, 5, None);
    for i in 0..ndtensor.size(0) {
        assert_eq!(unsafe { ndtensor.get_unchecked([i]) }, &ndtensor[[i]]);
        assert_eq!(
            unsafe { ndtensor.view().get_unchecked([i]) },
            &ndtensor[[i]]
        );
    }

    // Dynamic-rank tensor: compare against the *same* tensor's checked
    // indexing. (This previously compared against `ndtensor`; the values
    // coincide so the test still passed, but the comparison was wrong.)
    let tensor = Tensor::arange(1, 5, None);
    for i in 0..tensor.size(0) {
        assert_eq!(unsafe { tensor.get_unchecked([i]) }, &tensor[[i]]);
        assert_eq!(unsafe { tensor.view().get_unchecked([i]) }, &tensor[[i]]);
    }
}
#[test]
fn test_get_unchecked_mut() {
    // Static-rank tensor.
    let mut ndtensor = NdTensor::arange(1, 5, None);
    for i in 0..ndtensor.size(0) {
        // SAFETY: `i` is within `0..size(0)`.
        unsafe { *ndtensor.get_unchecked_mut([i]) += 1 }
    }
    assert_eq!(ndtensor.to_vec(), &[2, 3, 4, 5]);

    // Dynamic-rank tensor.
    let mut tensor = Tensor::arange(1, 5, None);
    for i in 0..tensor.size(0) {
        // SAFETY: `i` is within `0..size(0)`.
        unsafe { *tensor.get_unchecked_mut([i]) += 1 }
    }
    assert_eq!(tensor.to_vec(), &[2, 3, 4, 5]);
}
#[test]
fn test_index_and_index_mut() {
    // Static-rank tensor: index with a fixed-size array.
    let mut tensor: NdTensor<f32, 2> = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    assert_eq!(tensor[[1, 1]], 4.);
    tensor[[1, 1]] = 9.;
    assert_eq!(tensor[[1, 1]], 9.);

    // Dynamic-rank tensor: indexing also accepts a slice reference.
    let mut tensor: Tensor<f32> = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    assert_eq!(tensor[[1, 1]], 4.);
    tensor[&[1, 1]] = 9.;
    assert_eq!(tensor[[1, 1]], 9.);
}
#[test]
fn test_into_data() {
    // `into_data` consumes the tensor and returns its buffer.
    let tensor = NdTensor::from_data([2], vec![2., 3.]);
    assert_eq!(tensor.into_data(), vec![2., 3.]);
}
#[test]
fn test_into_dyn() {
    // Converting to a dynamic-rank tensor preserves shape and data.
    let static_rank = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    let dynamic = static_rank.into_dyn();
    assert_eq!(dynamic.shape(), &[2, 2]);
    assert_eq!(dynamic.data(), Some([1., 2., 3., 4.].as_slice()));
}
#[test]
fn test_into_shape() {
    // Contiguous case: data order is unchanged.
    let tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    let flat = tensor.into_shape([4]);
    assert_eq!(flat.shape(), [4]);
    assert_eq!(flat.data(), Some([1., 2., 3., 4.].as_slice()));

    // Non-contiguous case: elements are gathered in logical order.
    let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    tensor.transpose();
    let flat = tensor.into_shape([4]);
    assert_eq!(flat.shape(), [4]);
    assert_eq!(flat.data(), Some([1., 3., 2., 4.].as_slice()));
}
#[test]
fn test_inner_iter() {
    // Iterate over the innermost 1D lanes (rows) of a 2x2 tensor.
    let tensor = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    let mut rows = tensor.inner_iter::<1>();

    let first = rows.next().unwrap();
    assert_eq!(first.shape(), [2]);
    assert_eq!(first.to_vec(), &[1, 2]);

    let second = rows.next().unwrap();
    assert_eq!(second.shape(), [2]);
    assert_eq!(second.to_vec(), &[3, 4]);

    assert_eq!(rows.next(), None);
}
#[test]
fn test_inner_iter_mut() {
    let mut tensor = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    let mut rows = tensor.inner_iter_mut::<1>();

    // Double each row in turn through the mutable views.
    let mut row = rows.next().unwrap();
    assert_eq!(row.shape(), [2]);
    row.apply(|x| x * 2);

    let mut row = rows.next().unwrap();
    assert_eq!(row.shape(), [2]);
    row.apply(|x| x * 2);

    assert_eq!(rows.next(), None);
    assert_eq!(tensor.to_vec(), &[2, 4, 6, 8]);
}
#[test]
fn test_insert_axis() {
    let mut tensor = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);

    // Prepend a unit dimension.
    tensor.insert_axis(0);
    assert_eq!(tensor.shape(), &[1, 2, 2]);

    // Append a unit dimension.
    tensor.insert_axis(3);
    assert_eq!(tensor.shape(), &[1, 2, 2, 1]);
}
#[test]
fn test_item() {
    // `item` yields the element for scalar and single-element tensors,
    // and `None` otherwise.
    assert_eq!(NdTensor::from_data([], vec![5.]).item(), Some(&5.));
    assert_eq!(NdTensor::from_data([1], vec![6.]).item(), Some(&6.));
    assert_eq!(NdTensor::from_data([2], vec![2., 3.]).item(), None);
    assert_eq!(Tensor::from_data(&[], vec![5.]).item(), Some(&5.));
    assert_eq!(Tensor::from_data(&[1], vec![6.]).item(), Some(&6.));
    assert_eq!(Tensor::from_data(&[2], vec![2., 3.]).item(), None);
}
#[test]
fn test_iter() {
    // Static-rank tensor: contiguous order, then transposed (logical) order.
    let tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    assert_eq!(
        tensor.iter().copied().collect::<Vec<_>>(),
        &[1., 2., 3., 4.]
    );
    let transposed = tensor.transposed();
    assert_eq!(
        transposed.iter().copied().collect::<Vec<_>>(),
        &[1., 3., 2., 4.]
    );

    // Dynamic-rank tensor: same expectations.
    let tensor = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    assert_eq!(
        tensor.iter().copied().collect::<Vec<_>>(),
        &[1., 2., 3., 4.]
    );
    let transposed = tensor.transposed();
    assert_eq!(
        transposed.iter().copied().collect::<Vec<_>>(),
        &[1., 3., 2., 4.]
    );
}
#[test]
fn test_iter_mut() {
    // Mutable iteration visits and updates every element.
    let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    for x in tensor.iter_mut() {
        *x *= 2.;
    }
    assert_eq!(tensor.to_vec(), &[2., 4., 6., 8.]);
}
#[test]
fn test_lanes() {
    // Lanes along axis 1 of a 2x2 tensor are its rows.
    let tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    let mut rows = tensor.lanes(1);
    assert_eq!(
        rows.next().unwrap().copied().collect::<Vec<_>>(),
        &[1., 2.]
    );
    assert_eq!(
        rows.next().unwrap().copied().collect::<Vec<_>>(),
        &[3., 4.]
    );
}
#[test]
fn test_lanes_mut() {
    // Mutable lanes yield mutable references to each row's elements.
    let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    let mut rows = tensor.lanes_mut(1);
    assert_eq!(rows.next().unwrap().collect::<Vec<_>>(), &[&1., &2.]);
    assert_eq!(rows.next().unwrap().collect::<Vec<_>>(), &[&3., &4.]);
}
#[test]
fn test_make_contiguous() {
    let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    assert!(tensor.is_contiguous());

    // No-op on an already-contiguous tensor.
    tensor.make_contiguous();
    assert!(tensor.is_contiguous());

    // Transposing breaks contiguity; `make_contiguous` restores it by
    // rearranging storage into logical order.
    tensor.transpose();
    assert!(!tensor.is_contiguous());
    tensor.make_contiguous();
    assert!(tensor.is_contiguous());
    assert_eq!(tensor.data(), Some([1., 3., 2., 4.].as_slice()));
}
#[test]
fn test_map() {
    let tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    let doubled = tensor.map(|x| x * 2.);
    assert_eq!(doubled.to_vec(), &[2., 4., 6., 8.]);

    // Mapping a transposed view visits elements in logical order.
    let halved = doubled.transposed().map(|x| x / 2.);
    assert_eq!(halved.to_vec(), &[1., 3., 2., 4.]);
}
#[test]
fn test_matrix_layout() {
    // Row/column counts and strides of a row-major 2x3 matrix.
    let data = vec![1., 2., 3., 4., 5., 6.];
    let matrix = NdTensorView::from_data([2, 3], &data);
    assert_eq!(matrix.rows(), 2);
    assert_eq!(matrix.row_stride(), 3);
    assert_eq!(matrix.cols(), 3);
    assert_eq!(matrix.col_stride(), 1);
}
#[test]
fn test_move_axis() {
    let data = vec![1., 2., 3., 4., 5., 6.];
    let mut view = NdTensorView::from_data([2, 3], &data);

    // Moving axis 1 to position 0 transposes the matrix.
    view.move_axis(1, 0);
    assert_eq!(view.shape(), [3, 2]);
    assert_eq!(view.to_vec(), &[1., 4., 2., 5., 3., 6.]);

    // Moving it back restores the original order.
    view.move_axis(0, 1);
    assert_eq!(view.shape(), [2, 3]);
    assert_eq!(view.to_vec(), &[1., 2., 3., 4., 5., 6.]);
}
#[test]
fn test_nd_view() {
    let tensor: Tensor<f32> = Tensor::zeros(&[1, 4, 5]);

    // A static-rank view shares shape and strides with its source.
    let nd_view = tensor.nd_view::<3>();
    assert_eq!(nd_view.shape(), [1, 4, 5]);
    assert_eq!(nd_view.strides().as_ref(), tensor.strides());

    // Taking an nd_view of an nd_view preserves the shape.
    let nd_view_2 = nd_view.nd_view::<3>();
    assert_eq!(nd_view_2.shape(), nd_view.shape());
}
#[test]
fn test_nd_view_mut() {
    let mut tensor: Tensor<f32> = Tensor::zeros(&[1, 4, 5]);
    let mut nd_view = tensor.nd_view_mut::<3>();
    assert_eq!(nd_view.shape(), [1, 4, 5]);

    // Writes through the view are visible in the source tensor.
    nd_view[[0, 0, 0]] = 9.;
    assert_eq!(tensor[[0, 0, 0]], 9.);
}
#[test]
fn test_non_contiguous_data() {
    let mut tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]);

    // For a contiguous tensor this matches `data`.
    assert_eq!(tensor.data(), Some(tensor.view().non_contiguous_data()));

    // After transposing, `data` is unavailable but the raw storage is
    // still accessible in storage order.
    tensor.transpose();
    assert!(tensor.data().is_none());
    assert_eq!(tensor.view().non_contiguous_data(), [1, 2, 3, 4]);
}
#[test]
fn test_rand() {
    let mut rng = XorShiftRng::new(1234);
    let tensor = NdTensor::rand([2, 2], &mut rng);
    assert_eq!(tensor.shape(), [2, 2]);

    // Generated values are expected to lie in [0, 1].
    for &x in tensor.iter() {
        assert!((0. ..=1.).contains(&x));
    }
}
#[test]
fn test_permute() {
    // In-place permutation of a view's axes.
    let data = vec![1., 2., 3., 4., 5., 6.];
    let mut view = NdTensorView::from_data([2, 3], &data);
    view.permute([1, 0]);
    assert_eq!(view.shape(), [3, 2]);
    assert_eq!(view.to_vec(), &[1., 4., 2., 5., 3., 6.]);
}
#[test]
fn test_permuted() {
    // `permuted` returns a new view; the original is untouched.
    let data = vec![1., 2., 3., 4., 5., 6.];
    let view = NdTensorView::from_data([2, 3], &data);
    let transposed = view.permuted([1, 0]);
    assert_eq!(transposed.shape(), [3, 2]);
    assert_eq!(transposed.to_vec(), &[1., 4., 2., 5., 3., 6.]);
}
#[test]
fn test_permuted_mut() {
    let mut tensor = NdTensor::from_data([2, 3], vec![1., 2., 3., 4., 5., 6.]);

    // Writes through the permuted view hit the underlying tensor.
    let mut transposed = tensor.permuted_mut([1, 0]);
    transposed[[2, 1]] = 8.;
    assert_eq!(transposed.shape(), [3, 2]);
    assert_eq!(transposed.to_vec(), &[1., 4., 2., 5., 3., 8.]);
}
#[test]
fn test_reshape() {
    // Reshaping a non-contiguous tensor copies elements into logical order.
    let mut tensor = Tensor::<f32>::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    tensor.transpose();
    tensor.reshape(&[4]);
    assert_eq!(tensor.shape(), &[4]);
    assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);

    // Immutable and mutable views can be reshaped in place too.
    let mut view = tensor.view();
    view.reshape(&[2, 2]);
    assert_eq!(view.shape(), &[2, 2]);

    let mut view_mut = tensor.view_mut();
    view_mut.reshape(&[2, 2]);
    assert_eq!(view_mut.shape(), &[2, 2]);
}
#[test]
fn test_reshaped() {
    let data = vec![1., 2., 3., 4., 5., 6.];
    let tensor = NdTensorView::from_data([1, 1, 2, 1, 3], &data);

    // Static-rank target shape.
    let flat = tensor.reshaped([6]);
    assert_eq!(flat.shape(), [6]);

    // Dynamic (slice) target shape.
    let flat = tensor.reshaped([6].as_slice());
    assert_eq!(flat.shape(), &[6]);
}
#[test]
fn test_reshaped_mut() {
    let mut tensor = NdTensor::from_data([1, 1, 2, 1, 3], vec![1., 2., 3., 4., 5., 6.]);

    // Writes through the reshaped view affect the source tensor.
    let mut flat = tensor.reshaped_mut([6]);
    flat[[0]] = 0.;
    flat[[5]] = 0.;
    assert_eq!(tensor.data(), Some([0., 2., 3., 4., 5., 0.].as_slice()));
}
#[test]
fn test_set_array() {
    let mut tensor = NdTensor::arange(1, 17, None).into_shape([4, 2, 2]);

    // Scatter four values along axis 0, starting at index [0, 0, 0].
    tensor.set_array([0, 0, 0], 0, [-1, -2, -3, -4]);
    assert_eq!(
        tensor.iter().copied().collect::<Vec<_>>(),
        &[-1, 2, 3, 4, -2, 6, 7, 8, -3, 10, 11, 12, -4, 14, 15, 16]
    );
}
#[test]
fn test_slice_with_ndlayout() {
    // Slicing with a single index removes the leading dimension.
    let tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);

    let row_one = tensor.slice(0);
    assert_eq!(row_one[[0]], 1.);
    assert_eq!(row_one[[1]], 2.);

    let row_two = tensor.slice(1);
    assert_eq!(row_two[[0]], 3.);
    assert_eq!(row_two[[1]], 4.);
}
#[test]
fn test_slice_dyn_with_ndlayout() {
    // `slice_dyn` on a static-rank tensor returns a dynamic-rank row view.
    let tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);

    let row_one = tensor.slice_dyn(0);
    assert_eq!(row_one[[0]], 1.);
    assert_eq!(row_one[[1]], 2.);

    let row_two = tensor.slice_dyn(1);
    assert_eq!(row_two[[0]], 3.);
    assert_eq!(row_two[[1]], 4.);
}
#[test]
fn test_slice_with_dynlayout() {
    // `slice` on a dynamic-rank tensor.
    let tensor = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);

    let row_one = tensor.slice(0);
    assert_eq!(row_one[[0]], 1.);
    assert_eq!(row_one[[1]], 2.);

    let row_two = tensor.slice(1);
    assert_eq!(row_two[[0]], 3.);
    assert_eq!(row_two[[1]], 4.);
}
#[test]
fn test_slice_dyn_with_dynlayout() {
    // `slice_dyn` on a dynamic-rank tensor.
    let tensor = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);

    let row_one = tensor.slice_dyn(0);
    assert_eq!(row_one[[0]], 1.);
    assert_eq!(row_one[[1]], 2.);

    let row_two = tensor.slice_dyn(1);
    assert_eq!(row_two[[0]], 3.);
    assert_eq!(row_two[[1]], 4.);
}
#[test]
fn test_slice_iter() {
    let tensor = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);

    // Select row 0 with a full-range column selector.
    let row_one: Vec<_> = tensor
        .slice_iter(&[SliceItem::Index(0), SliceItem::full_range()])
        .copied()
        .collect();
    assert_eq!(row_one, &[1., 2.]);
}
#[test]
fn test_slice_mut() {
    let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);

    // Writes through the sliced row are visible in the source tensor.
    let mut row = tensor.slice_mut(1);
    row[[0]] = 8.;
    row[[1]] = 9.;
    assert_eq!(tensor.to_vec(), &[1., 2., 8., 9.]);
}
#[test]
fn test_slice_mut_dyn() {
    let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);

    // Same as above, via the dynamic-rank mutable slice.
    let mut row = tensor.slice_mut_dyn(1);
    row[[0]] = 8.;
    row[[1]] = 9.;
    assert_eq!(tensor.to_vec(), &[1., 2., 8., 9.]);
}
#[test]
fn test_squeezed() {
    let data = vec![1., 2., 3., 4., 5., 6.];
    let view = NdTensorView::from_data([1, 1, 2, 1, 3], &data);

    // All size-1 dimensions are removed.
    let squeezed = view.squeezed();
    assert_eq!(squeezed.shape(), &[2, 3]);
}
#[test]
fn test_to_array() {
    let tensor = NdTensor::arange(1., 5., None).into_shape([2, 2]);

    // Matrix columns, extracted as fixed-size arrays via transposed slices.
    let col0: [f32; 2] = tensor.view().transposed().slice::<1, _>(0).to_array();
    let col1: [f32; 2] = tensor.view().transposed().slice::<1, _>(1).to_array();
    assert_eq!(col0, [1., 3.]);
    assert_eq!(col1, [2., 4.]);
}
#[test]
fn test_to_contiguous() {
    let tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);

    // Already contiguous: contents unchanged.
    let mut tensor = tensor.to_contiguous();
    assert_eq!(tensor.to_vec(), &[1., 2., 3., 4.]);

    // Transpose, then restore contiguity via a copy in logical order.
    tensor.transpose();
    assert!(!tensor.is_contiguous());
    assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);
    let tensor = tensor.to_contiguous();
    assert!(tensor.is_contiguous());
    assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);
}
#[test]
fn test_to_shape() {
    // `to_shape` copies into a new tensor with the requested shape.
    let tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]);
    let flat = tensor.to_shape([4]);
    assert_eq!(flat.shape(), [4]);
    assert_eq!(flat.data(), Some([1, 2, 3, 4].as_slice()));
}
#[test]
fn test_to_vec() {
    // Contiguous case: `to_vec` matches storage order.
    let tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]);
    assert_eq!(tensor.to_vec(), &[1, 2, 3, 4]);

    // Non-contiguous case: elements come out in logical order.
    // (Rebind mutably rather than cloning — the original binding is not
    // used again, so the previous `.clone()` was a wasted allocation.)
    let mut tensor = tensor;
    tensor.transpose();
    assert_eq!(tensor.to_vec(), &[1, 3, 2, 4]);
}
#[test]
fn test_to_tensor() {
    let data = vec![1., 2., 3., 4.];
    let view = NdTensorView::from_data([2, 2], &data);

    // `to_tensor` makes an owned copy with identical shape and contents.
    let owned = view.to_tensor();
    assert_eq!(owned.shape(), view.shape());
    assert_eq!(owned.to_vec(), view.to_vec());
}
#[test]
fn test_transpose() {
    let data = vec![1., 2., 3., 4., 5., 6.];
    let mut view = NdTensorView::from_data([2, 3], &data);

    // In-place transpose reverses the dimension order.
    view.transpose();
    assert_eq!(view.shape(), [3, 2]);
    assert_eq!(view.to_vec(), &[1., 4., 2., 5., 3., 6.]);
}
#[test]
fn test_transposed() {
    let data = vec![1., 2., 3., 4., 5., 6.];
    let view = NdTensorView::from_data([2, 3], &data);

    // `transposed` returns a new view; the original is untouched.
    let flipped = view.transposed();
    assert_eq!(flipped.shape(), [3, 2]);
    assert_eq!(flipped.to_vec(), &[1., 4., 2., 5., 3., 6.]);
}
#[test]
fn test_try_from_data() {
    // Valid data length for the shape: unwrap directly instead of the
    // previous `assert!(is_ok())` + `if let` dance, which would silently
    // skip the inner assertions on the error path.
    let x = NdTensor::try_from_data([1, 2, 2], vec![1, 2, 3, 4]).unwrap();
    assert_eq!(x.shape(), [1, 2, 2]);
    assert_eq!(x.strides(), [4, 2, 1]);
    assert_eq!(x.to_vec(), [1, 2, 3, 4]);

    // Too little data for the shape.
    let x = NdTensor::try_from_data([1, 2, 2], vec![1]);
    assert_eq!(x, Err(FromDataError::StorageLengthMismatch));
}
#[test]
fn test_try_slice() {
    let tensor = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);

    // In-bounds rows succeed.
    let row = tensor.try_slice_dyn(0);
    assert!(row.is_ok());
    assert_eq!(row.unwrap().data(), Some([1., 2.].as_slice()));
    assert!(tensor.try_slice_dyn(1).is_ok());

    // An out-of-bounds row index fails.
    assert!(tensor.try_slice_dyn(2).is_err());
}
#[test]
fn test_try_slice_mut() {
    let mut tensor = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);

    // Mutations through the slice are visible via its own data.
    let mut row = tensor.try_slice_mut(0).unwrap();
    row[[0]] += 1.;
    row[[1]] += 1.;
    assert_eq!(row.data(), Some([2., 3.].as_slice()));

    // In-bounds rows succeed; out-of-bounds fail. The error case now
    // exercises `try_slice_mut` itself (it previously called
    // `try_slice_dyn`, a copy-paste from the immutable test).
    let row = tensor.try_slice_mut(1);
    assert!(row.is_ok());
    let row = tensor.try_slice_mut(2);
    assert!(row.is_err());
}
#[test]
fn test_view() {
    // An immutable view exposes the same data as the tensor.
    let tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]);
    let view = tensor.view();
    assert_eq!(view.data(), Some([1, 2, 3, 4].as_slice()));
}
#[test]
fn test_view_mut() {
    let mut tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]);

    // Writes through the mutable view are visible in the tensor.
    let mut view = tensor.view_mut();
    view[[0, 0]] = 0;
    view[[1, 1]] = 0;
    assert_eq!(tensor.data(), Some([0, 2, 3, 0].as_slice()));
}
#[test]
fn test_weakly_checked_view() {
    let tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]);
    let view = tensor.weakly_checked_view();

    // Valid indices behave like normal indexing.
    for y in 0..tensor.size(0) {
        for x in 0..tensor.size(1) {
            assert_eq!(view[[y, x]], tensor[[y, x]]);
        }
    }

    // A per-dimension out-of-bounds index is not rejected: [0, 2]
    // resolves to storage element 3 (the start of the next row).
    assert_eq!(view[[0, 2]], 3);
}
#[test]
fn test_weakly_checked_view_mut() {
    let mut tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]);
    let mut view = tensor.weakly_checked_view_mut();
    view[[0, 0]] = 5;
    view[[1, 1]] = 6;

    // As above, [0, 2] resolves to storage offset 2.
    view[[0, 2]] = 7;
    assert_eq!(tensor.data(), Some([5, 2, 7, 6].as_slice()));
}
#[test]
fn test_zeros() {
    // `zeros` default-fills every element with zero.
    let zeros = NdTensor::zeros([2, 2]);
    assert_eq!(zeros.shape(), [2, 2]);
    assert_eq!(zeros.data(), Some([0, 0, 0, 0].as_slice()));
}
}