use std::borrow::Cow;
use std::fmt::Debug;
use std::mem::MaybeUninit;
use std::ops::{Index, IndexMut, Range};
use std::sync::Arc;
use crate::assume_init::AssumeInit;
use crate::copy::{
copy_into, copy_into_slice, copy_into_uninit, copy_range_into_slice, map_into_slice,
};
use crate::errors::{DimensionError, ExpandError, FromDataError, ReshapeError, SliceError};
use crate::iterators::{
AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterMut, Iter, IterMut,
Lanes, LanesMut, for_each_mut,
};
use crate::layout::{
AsIndex, BroadcastLayout, DynLayout, IntoLayout, Layout, LayoutExt, MatrixLayout, MutLayout,
NdLayout, OverlapPolicy, RemoveDim, ResizeLayout, SliceWith, TrustedLayout,
};
use crate::overlap::may_have_internal_overlap;
use crate::slice_range::{IntoSliceItems, SliceItem};
use crate::storage::{
Alloc, CowData, GlobalAlloc, IntoStorage, Storage, StorageMut, ViewData, ViewMutData,
};
use crate::type_num::IndexCount;
use crate::{Contiguous, RandomSource};
/// N-dimensional array of elements of type `S::Elem`.
///
/// `S` is the element storage (owned, borrowed, mutable, copy-on-write, …)
/// and `L` is the layout, which maps N-dimensional indices to offsets within
/// the storage and determines whether the rank is static or dynamic.
pub struct TensorBase<S: Storage, L: Layout> {
    // Element storage. Must contain at least `layout.min_data_len()`
    // elements (asserted by the constructors).
    data: S,
    // Maps indices to element offsets within `data`.
    layout: L,
}
/// Read-only tensor operations, available on all tensor kinds.
///
/// Most methods have default implementations which obtain an immutable view
/// via [`AsView::view`] and delegate to it.
pub trait AsView: Layout {
    /// Type of the elements stored in this tensor.
    type Elem;
    /// Layout type used by views of this tensor. The bound ensures the view
    /// layout uses the same index type as this tensor, for any lifetime.
    type Layout: Clone + for<'a> Layout<Index<'a> = Self::Index<'a>>;
    /// Return an immutable view of this tensor.
    fn view(&self) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>;
    /// Return this tensor's layout.
    fn layout(&self) -> &Self::Layout;
    /// Return a copy-on-write view that borrows this tensor's data.
    fn as_cow(&self) -> TensorBase<CowData<'_, Self::Elem>, Self::Layout>
    where
        [Self::Elem]: ToOwned,
    {
        self.view().as_cow()
    }
    /// Return a view with a dynamic-rank layout.
    fn as_dyn(&self) -> TensorBase<ViewData<'_, Self::Elem>, DynLayout> {
        self.view().as_dyn()
    }
    /// Return an iterator over views of chunks of size `chunk_size` along `dim`.
    fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<'_, Self::Elem, Self::Layout>
    where
        Self::Layout: MutLayout,
    {
        self.view().axis_chunks(dim, chunk_size)
    }
    /// Return an iterator over views along `dim`, with that dimension removed.
    fn axis_iter(&self, dim: usize) -> AxisIter<'_, Self::Elem, Self::Layout>
    where
        Self::Layout: MutLayout + RemoveDim,
    {
        self.view().axis_iter(dim)
    }
    /// Return a view broadcast to `shape`. Panics if broadcasting fails.
    fn broadcast<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<'_, Self::Elem>, S::Layout>
    where
        Self::Layout: BroadcastLayout<S::Layout>,
    {
        self.view().broadcast(shape)
    }
    /// Fallible version of [`broadcast`](AsView::broadcast).
    fn try_broadcast<S: IntoLayout>(
        &self,
        shape: S,
    ) -> Result<TensorBase<ViewData<'_, Self::Elem>, S::Layout>, ExpandError>
    where
        Self::Layout: BroadcastLayout<S::Layout>,
    {
        self.view().try_broadcast(shape)
    }
    /// Copy this tensor's elements into `dest` in logical order and return
    /// the initialized slice.
    fn copy_into_slice<'a>(&self, dest: &'a mut [MaybeUninit<Self::Elem>]) -> &'a [Self::Elem]
    where
        Self::Elem: Copy;
    /// Return the underlying elements as a slice, if the layout is contiguous.
    fn data(&self) -> Option<&[Self::Elem]>;
    /// Return the element at `index`, or `None` if the index is out of bounds.
    fn get<I: AsIndex<Self::Layout>>(&self, index: I) -> Option<&Self::Elem>
    where
        Self::Layout: TrustedLayout,
    {
        self.view().get(index)
    }
    /// Return the element at `index`, without bounds checks.
    ///
    /// # Safety
    ///
    /// `index` must be valid for this tensor's shape.
    unsafe fn get_unchecked<I: AsIndex<Self::Layout>>(&self, index: I) -> &Self::Elem {
        let view = self.view();
        // SAFETY: Caller guarantees the index is valid (see above).
        unsafe { view.get_unchecked(index) }
    }
    /// Return a view of entry `index` along `axis`, with `axis` removed.
    fn index_axis(
        &self,
        axis: usize,
        index: usize,
    ) -> TensorBase<ViewData<'_, Self::Elem>, <Self::Layout as RemoveDim>::Output>
    where
        Self::Layout: MutLayout + RemoveDim,
    {
        self.view().index_axis(axis, index)
    }
    /// Return an iterator over views of the innermost `N` dimensions.
    fn inner_iter<const N: usize>(&self) -> InnerIter<'_, Self::Elem, NdLayout<N>> {
        self.view().inner_iter()
    }
    /// Variant of [`inner_iter`](AsView::inner_iter) where the number of
    /// inner dimensions is chosen at runtime.
    fn inner_iter_dyn(&self, n: usize) -> InnerIter<'_, Self::Elem, DynLayout> {
        self.view().inner_iter_dyn(n)
    }
    /// Insert a size-one axis at position `index`.
    fn insert_axis(&mut self, index: usize)
    where
        Self::Layout: ResizeLayout;
    /// Remove the axis at position `index`.
    fn remove_axis(&mut self, index: usize)
    where
        Self::Layout: ResizeLayout;
    /// Return this tensor's single element, if it has exactly one.
    fn item(&self) -> Option<&Self::Elem> {
        self.view().item()
    }
    /// Return an iterator over elements in logical order.
    fn iter(&self) -> Iter<'_, Self::Elem>;
    /// Return an iterator over 1D lanes along `dim`.
    fn lanes(&self, dim: usize) -> Lanes<'_, Self::Elem>
    where
        Self::Layout: RemoveDim,
    {
        self.view().lanes(dim)
    }
    /// Return a new tensor formed by applying `f` to each element.
    fn map<F, U>(&self, f: F) -> TensorBase<Vec<U>, Self::Layout>
    where
        F: Fn(&Self::Elem) -> U,
        Self::Layout: MutLayout,
    {
        self.view().map(f)
    }
    /// Variant of [`map`](AsView::map) which allocates the output from `alloc`.
    fn map_in<A: Alloc, F, U>(&self, alloc: A, f: F) -> TensorBase<Vec<U>, Self::Layout>
    where
        F: Fn(&Self::Elem) -> U,
        Self::Layout: MutLayout,
    {
        self.view().map_in(alloc, f)
    }
    /// Merge adjacent axes where possible (see the layout's `merge_axes`).
    fn merge_axes(&mut self)
    where
        Self::Layout: ResizeLayout;
    /// Move the axis at position `from` to position `to`.
    fn move_axis(&mut self, from: usize, to: usize)
    where
        Self::Layout: MutLayout;
    /// Return a view with a static rank of `N`. Panics if `self.ndim() != N`.
    fn nd_view<const N: usize>(&self) -> TensorBase<ViewData<'_, Self::Elem>, NdLayout<N>> {
        self.view().nd_view()
    }
    /// Permute this tensor's axes according to `order`, in place.
    fn permute(&mut self, order: Self::Index<'_>)
    where
        Self::Layout: MutLayout;
    /// Return a view with axes permuted according to `order`.
    fn permuted(&self, order: Self::Index<'_>) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>
    where
        Self::Layout: MutLayout,
    {
        self.view().permuted(order)
    }
    /// Return a tensor with the same elements and a new shape. The data is
    /// borrowed when the layout permits, otherwise it is copied.
    fn reshaped<S: Copy + IntoLayout>(
        &self,
        shape: S,
    ) -> TensorBase<CowData<'_, Self::Elem>, S::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.view().reshaped(shape)
    }
    /// Variant of [`reshaped`](AsView::reshaped) which allocates any required
    /// copy from `alloc`.
    fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
        &self,
        alloc: A,
        shape: S,
    ) -> TensorBase<CowData<'_, Self::Elem>, S::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.view().reshaped_in(alloc, shape)
    }
    /// Transpose this tensor in place (see the layout's `transposed`).
    fn transpose(&mut self)
    where
        Self::Layout: MutLayout;
    /// Return a transposed view of this tensor.
    fn transposed(&self) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>
    where
        Self::Layout: MutLayout,
    {
        self.view().transposed()
    }
    /// Return a view of the subset of this tensor selected by `range`.
    /// Panics if the range is invalid.
    #[allow(clippy::type_complexity)]
    fn slice<R: IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> TensorBase<ViewData<'_, Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
    where
        Self::Layout: SliceWith<R, R::Count>,
    {
        self.view().slice(range)
    }
    /// Return a view restricted to `range` along `axis`.
    fn slice_axis(
        &self,
        axis: usize,
        range: Range<usize>,
    ) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>
    where
        Self::Layout: MutLayout,
    {
        self.view().slice_axis(axis, range)
    }
    /// Fallible version of [`slice`](AsView::slice).
    #[allow(clippy::type_complexity)]
    fn try_slice<R: IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> Result<
        TensorBase<ViewData<'_, Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>,
        SliceError,
    >
    where
        Self::Layout: SliceWith<R, R::Count>,
    {
        self.view().try_slice(range)
    }
    /// Return a copy of the subset of this tensor selected by `range`.
    #[allow(clippy::type_complexity)]
    fn slice_copy<R: Clone + IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> TensorBase<Vec<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: SliceWith<
            R,
            R::Count,
            Layout: for<'a> Layout<Index<'a>: TryFrom<&'a [usize], Error: Debug>>,
        >,
    {
        self.slice_copy_in(GlobalAlloc::new(), range)
    }
    /// Variant of [`slice_copy`](AsView::slice_copy) which allocates the
    /// output from `pool`.
    #[allow(clippy::type_complexity)]
    fn slice_copy_in<A: Alloc, R: Clone + IntoSliceItems + IndexCount>(
        &self,
        pool: A,
        range: R,
    ) -> TensorBase<Vec<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: SliceWith<
            R,
            R::Count,
            Layout: for<'a> Layout<Index<'a>: TryFrom<&'a [usize], Error: Debug>>,
        >,
    {
        // Fast path: if the range can be expressed as a view, slice first
        // and then copy only the selected elements.
        if let Ok(slice_view) = self.try_slice(range.clone()) {
            return slice_view.to_tensor_in(pool);
        }
        // Fallback path for ranges that `try_slice` rejects.
        // NOTE(review): presumably this handles ranges not expressible as a
        // view (e.g. with steps) — confirm against `copy_range_into_slice`.
        let items = range.into_slice_items();
        // Output shape: one dim per kept range item; indexed dims are removed.
        // `steps()` is the number of elements the range selects.
        let sliced_shape: Vec<_> = items
            .as_ref()
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(dim, item)| match item {
                SliceItem::Index(_) => None,
                SliceItem::Range(range) => Some(range.index_range(self.size(dim)).steps()),
            })
            .collect();
        let sliced_len = sliced_shape.iter().product();
        let mut sliced_data = pool.alloc(sliced_len);
        // Fill the spare capacity with the selected elements.
        copy_range_into_slice(
            self.as_dyn(),
            &mut sliced_data.spare_capacity_mut()[..sliced_len],
            items.as_ref(),
        );
        // SAFETY: `copy_range_into_slice` initialized the first `sliced_len`
        // elements above.
        unsafe {
            sliced_data.set_len(sliced_len);
        }
        // Convert the dynamically-sized shape into the layout's index type.
        let sliced_shape = sliced_shape.as_slice().try_into().expect("slice failed");
        TensorBase::from_data(sliced_shape, sliced_data)
    }
    /// Return a view with size-one dimensions removed.
    fn squeezed(&self) -> TensorView<'_, Self::Elem>
    where
        Self::Layout: MutLayout,
    {
        self.view().squeezed()
    }
    /// Return a copy of this tensor's elements in logical order.
    fn to_vec(&self) -> Vec<Self::Elem>
    where
        Self::Elem: Clone;
    /// Variant of [`to_vec`](AsView::to_vec) which allocates from `alloc`.
    fn to_vec_in<A: Alloc>(&self, alloc: A) -> Vec<Self::Elem>
    where
        Self::Elem: Clone;
    /// Return a contiguous version of this tensor, borrowing the data if it
    /// is already contiguous and copying otherwise.
    fn to_contiguous(&self) -> Contiguous<TensorBase<CowData<'_, Self::Elem>, Self::Layout>>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.view().to_contiguous()
    }
    /// Variant of [`to_contiguous`](AsView::to_contiguous) which allocates
    /// any required copy from `alloc`.
    fn to_contiguous_in<A: Alloc>(
        &self,
        alloc: A,
    ) -> Contiguous<TensorBase<CowData<'_, Self::Elem>, Self::Layout>>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.view().to_contiguous_in(alloc)
    }
    /// Return a copy of this tensor with the given shape.
    fn to_shape<S: IntoLayout>(&self, shape: S) -> TensorBase<Vec<Self::Elem>, S::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout;
    /// Return this tensor's elements as a slice, borrowing if contiguous and
    /// copying otherwise.
    fn to_slice(&self) -> Cow<'_, [Self::Elem]>
    where
        Self::Elem: Clone,
    {
        self.view().to_slice()
    }
    /// Return an owned copy of this tensor.
    fn to_tensor(&self) -> TensorBase<Vec<Self::Elem>, Self::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.to_tensor_in(GlobalAlloc::new())
    }
    /// Variant of [`to_tensor`](AsView::to_tensor) which allocates from `alloc`.
    fn to_tensor_in<A: Alloc>(&self, alloc: A) -> TensorBase<Vec<Self::Elem>, Self::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        TensorBase::from_data(self.layout().shape(), self.to_vec_in(alloc))
    }
    /// Return a wrapper view with weaker index checking (see
    /// [`WeaklyCheckedView`]).
    fn weakly_checked_view(&self) -> WeaklyCheckedView<ViewData<'_, Self::Elem>, Self::Layout> {
        self.view().weakly_checked_view()
    }
}
impl<S: Storage, L: Layout> TensorBase<S, L> {
    /// Construct a tensor from a shape and storage.
    ///
    /// # Panics
    ///
    /// Panics if the storage length does not match the length required by
    /// the shape.
    #[track_caller]
    pub fn from_data<D: IntoStorage<Output = S>>(shape: L::Index<'_>, data: D) -> TensorBase<S, L>
    where
        for<'a> L::Index<'a>: Clone,
        L: MutLayout,
    {
        let data = data.into_storage();
        // Record the length before `data` is moved, for the panic message.
        let len = data.len();
        match Self::try_from_data(shape.clone(), data) {
            Ok(data) => data,
            Err(_) => panic!(
                "data length {} does not match shape {:?}",
                len,
                shape.as_ref()
            ),
        }
    }
    /// Fallible version of [`from_data`](TensorBase::from_data).
    pub fn try_from_data<D: IntoStorage<Output = S>>(
        shape: L::Index<'_>,
        data: D,
    ) -> Result<TensorBase<S, L>, FromDataError>
    where
        L: MutLayout,
    {
        let data = data.into_storage();
        let layout = L::from_shape(shape);
        // The storage length must match the layout's requirement exactly.
        if layout.min_data_len() != data.len() {
            return Err(FromDataError::StorageLengthMismatch);
        }
        Ok(TensorBase { data, layout })
    }
    /// Construct a tensor from existing storage and a layout.
    ///
    /// # Panics
    ///
    /// Panics if the storage is too short for the layout, or if the storage
    /// is mutable and the layout might map multiple indices to the same
    /// element.
    pub fn from_storage_and_layout(data: S, layout: L) -> TensorBase<S, L> {
        assert!(data.len() >= layout.min_data_len());
        // Mutable tensors must not alias elements via different indices.
        assert!(
            !S::MUTABLE
                || !may_have_internal_overlap(layout.shape().as_ref(), layout.strides().as_ref())
        );
        TensorBase { data, layout }
    }
    /// Variant of [`from_storage_and_layout`](TensorBase::from_storage_and_layout)
    /// which only checks the invariants in debug builds.
    ///
    /// # Safety
    ///
    /// The caller must uphold the conditions that
    /// `from_storage_and_layout` asserts.
    pub(crate) unsafe fn from_storage_and_layout_unchecked(data: S, layout: L) -> TensorBase<S, L> {
        debug_assert!(data.len() >= layout.min_data_len());
        debug_assert!(
            !S::MUTABLE
                || !may_have_internal_overlap(layout.shape().as_ref(), layout.strides().as_ref())
        );
        TensorBase { data, layout }
    }
    /// Construct a tensor from a shape, storage and custom strides.
    pub fn from_data_with_strides<D: IntoStorage<Output = S>>(
        shape: L::Index<'_>,
        data: D,
        strides: L::Index<'_>,
    ) -> Result<TensorBase<S, L>, FromDataError>
    where
        L: MutLayout,
    {
        let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::DisallowOverlap)?;
        let data = data.into_storage();
        // Unlike `try_from_data`, extra trailing storage is allowed here.
        if layout.min_data_len() > data.len() {
            return Err(FromDataError::StorageTooShort);
        }
        Ok(TensorBase { data, layout })
    }
    /// Convert this tensor to one with a dynamic-rank layout.
    pub fn into_dyn(self) -> TensorBase<S, DynLayout>
    where
        L: Into<DynLayout>,
    {
        TensorBase {
            data: self.data,
            layout: self.layout.into(),
        }
    }
    /// Consume this tensor and return its storage.
    pub(crate) fn into_storage(self) -> S {
        self.data
    }
    /// Convert this tensor's layout to a static-rank layout with `N` dims.
    /// Returns `None` if `self.ndim() != N`.
    fn nd_layout<const N: usize>(&self) -> Option<NdLayout<N>> {
        if self.ndim() != N {
            return None;
        }
        let shape: [usize; N] = std::array::from_fn(|i| self.size(i));
        let strides: [usize; N] = std::array::from_fn(|i| self.stride(i));
        // Shape and strides come from an existing valid layout, so the
        // conversion is not expected to fail.
        let layout = NdLayout::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap)
            .expect("invalid layout");
        Some(layout)
    }
    /// Return a raw pointer to the start of the element storage.
    pub fn data_ptr(&self) -> *const S::Elem {
        self.data.as_ptr()
    }
}
impl<S: StorageMut, L: Clone + Layout> TensorBase<S, L> {
    /// Return a mutable iterator over slices along `dim`, with `dim` removed.
    pub fn axis_iter_mut(&mut self, dim: usize) -> AxisIterMut<'_, S::Elem, L>
    where
        L: RemoveDim,
    {
        AxisIterMut::new(self.view_mut(), dim)
    }
    /// Return a mutable iterator over chunks of size `chunk_size` along `dim`.
    pub fn axis_chunks_mut(
        &mut self,
        dim: usize,
        chunk_size: usize,
    ) -> AxisChunksMut<'_, S::Elem, L>
    where
        L: MutLayout,
    {
        AxisChunksMut::new(self.view_mut(), dim, chunk_size)
    }
    /// Replace each element `x` with `f(x)`, in place.
    pub fn apply<F: Fn(&S::Elem) -> S::Elem>(&mut self, f: F) {
        if let Some(data) = self.data_mut() {
            // Contiguous fast path.
            data.iter_mut().for_each(|x| *x = f(x));
        } else {
            for_each_mut(self.as_dyn_mut(), |x| *x = f(x));
        }
    }
    /// Return a mutable view with a dynamic-rank layout.
    pub fn as_dyn_mut(&mut self) -> TensorBase<ViewMutData<'_, S::Elem>, DynLayout> {
        TensorBase {
            layout: DynLayout::from(&self.layout),
            data: self.data.view_mut(),
        }
    }
    /// Copy the elements of `other` into `self`.
    ///
    /// # Panics
    ///
    /// Panics if the shapes of `self` and `other` do not match.
    pub fn copy_from<S2: Storage<Elem = S::Elem>>(&mut self, other: &TensorBase<S2, L>)
    where
        S::Elem: Clone,
        L: Clone,
    {
        assert!(
            self.shape() == other.shape(),
            "copy dest shape {:?} != src shape {:?}",
            self.shape(),
            other.shape()
        );
        if let Some(dest) = self.data_mut() {
            if let Some(src) = other.data() {
                // Both contiguous: bulk clone.
                dest.clone_from_slice(src);
            } else {
                // Contiguous dest, non-contiguous src: drop the existing
                // elements, then fill the now-uninitialized dest.
                // SAFETY: `MaybeUninit<T>` has the same layout as `T`.
                let uninit_dest: &mut [MaybeUninit<S::Elem>] = unsafe { std::mem::transmute(dest) };
                for x in &mut *uninit_dest {
                    // SAFETY: Elements were initialized prior to the transmute.
                    unsafe { x.assume_init_drop() }
                }
                // NOTE(review): if `copy_into_slice` were to panic here, the
                // already-dropped elements could be dropped again when the
                // storage is dropped — assumes it does not panic; verify.
                copy_into_slice(other.as_dyn(), uninit_dest);
            }
        } else {
            // Non-contiguous dest: copy element by element.
            copy_into(other.as_dyn(), self.as_dyn_mut());
        }
    }
    /// Return the elements as a mutable slice, if the layout is contiguous.
    pub fn data_mut(&mut self) -> Option<&mut [S::Elem]> {
        let len = self.layout.min_data_len();
        let data = self.data.slice_mut(0..len);
        self.layout.is_contiguous().then(|| unsafe {
            // SAFETY: With a contiguous layout every element in `0..len` is
            // part of the tensor, hence initialized.
            data.to_slice_mut()
        })
    }
    /// Return a mutable view of entry `index` along `axis`, with `axis` removed.
    pub fn index_axis_mut(
        &mut self,
        axis: usize,
        index: usize,
    ) -> TensorBase<ViewMutData<'_, S::Elem>, <L as RemoveDim>::Output>
    where
        L: MutLayout + RemoveDim,
    {
        let (offsets, layout) = self.layout.index_axis(axis, index);
        TensorBase {
            data: self.data.slice_mut(offsets),
            layout,
        }
    }
    /// Return a mutable view of the underlying storage.
    pub fn storage_mut(&mut self) -> ViewMutData<'_, S::Elem> {
        self.data.view_mut()
    }
    /// Set every element to `value`.
    pub fn fill(&mut self, value: S::Elem)
    where
        S::Elem: Clone,
    {
        self.apply(|_| value.clone())
    }
    /// Return a mutable reference to the element at `index`, or `None` if
    /// the index is out of bounds.
    pub fn get_mut<I: AsIndex<L>>(&mut self, index: I) -> Option<&mut S::Elem>
    where
        L: TrustedLayout,
    {
        self.offset(index.as_index()).map(|offset| unsafe {
            // SAFETY: `offset` was validated by `self.offset` above.
            self.data.get_unchecked_mut(offset)
        })
    }
    /// Return a mutable reference to the element at `index`, without checks.
    ///
    /// # Safety
    ///
    /// `index` must be valid for this tensor's shape.
    pub unsafe fn get_unchecked_mut<I: AsIndex<L>>(&mut self, index: I) -> &mut S::Elem {
        let offset = self.layout.offset_unchecked(index.as_index());
        // SAFETY: Caller guarantees the index is valid (see above).
        unsafe { self.data.get_unchecked_mut(offset) }
    }
    /// Return a mutable view whose layout borrows `self`'s layout.
    pub(crate) fn mut_view_ref(&mut self) -> TensorBase<ViewMutData<'_, S::Elem>, &L> {
        TensorBase {
            data: self.data.view_mut(),
            layout: &self.layout,
        }
    }
    /// Return a mutable iterator over views of the innermost `N` dimensions.
    pub fn inner_iter_mut<const N: usize>(&mut self) -> InnerIterMut<'_, S::Elem, NdLayout<N>>
    where
        L: MutLayout,
    {
        InnerIterMut::new(self.view_mut())
    }
    /// Variant of [`inner_iter_mut`](TensorBase::inner_iter_mut) where the
    /// number of inner dims is chosen at runtime.
    pub fn inner_iter_dyn_mut(&mut self, n: usize) -> InnerIterMut<'_, S::Elem, DynLayout>
    where
        L: MutLayout,
    {
        InnerIterMut::new_dyn(self.view_mut(), n)
    }
    /// Return a mutable iterator over elements in logical order.
    pub fn iter_mut(&mut self) -> IterMut<'_, S::Elem> {
        IterMut::new(self.mut_view_ref())
    }
    /// Return a mutable iterator over 1D lanes along `dim`.
    pub fn lanes_mut(&mut self, dim: usize) -> LanesMut<'_, S::Elem>
    where
        L: RemoveDim,
    {
        LanesMut::new(self.mut_view_ref(), dim)
    }
    /// Return a mutable view with a static rank of `N`.
    ///
    /// # Panics
    ///
    /// Panics if `self.ndim() != N`.
    pub fn nd_view_mut<const N: usize>(
        &mut self,
    ) -> TensorBase<ViewMutData<'_, S::Elem>, NdLayout<N>> {
        assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N);
        TensorBase {
            // `unwrap` is safe: ndim was checked above.
            layout: self.nd_layout().unwrap(),
            data: self.data.view_mut(),
        }
    }
    /// Return a mutable view with axes permuted according to `order`.
    pub fn permuted_mut(&mut self, order: L::Index<'_>) -> TensorBase<ViewMutData<'_, S::Elem>, L>
    where
        L: MutLayout,
    {
        TensorBase {
            layout: self.layout.permuted(order),
            data: self.data.view_mut(),
        }
    }
    /// Return a mutable view with a new shape, if the layout permits
    /// reshaping without copying.
    pub fn reshaped_mut<SH: IntoLayout>(
        &mut self,
        shape: SH,
    ) -> Result<TensorBase<ViewMutData<'_, S::Elem>, SH::Layout>, ReshapeError>
    where
        L: MutLayout,
    {
        let layout = self.layout.reshaped_for_view(shape)?;
        Ok(TensorBase {
            layout,
            data: self.data.view_mut(),
        })
    }
    /// Return a mutable view restricted to `range` along `axis`.
    pub fn slice_axis_mut(
        &mut self,
        axis: usize,
        range: Range<usize>,
    ) -> TensorBase<ViewMutData<'_, S::Elem>, L>
    where
        L: MutLayout,
    {
        let (offset_range, sliced_layout) = self.layout.slice_axis(axis, range.clone()).unwrap();
        debug_assert_eq!(sliced_layout.size(axis), range.len());
        TensorBase {
            data: self.data.slice_mut(offset_range),
            layout: sliced_layout,
        }
    }
    /// Return a mutable view of the subset selected by `range`. Panics if
    /// the range is invalid.
    pub fn slice_mut<R: IntoSliceItems + IndexCount>(
        &mut self,
        range: R,
    ) -> TensorBase<ViewMutData<'_, S::Elem>, <L as SliceWith<R, R::Count>>::Layout>
    where
        L: SliceWith<R, R::Count>,
    {
        self.try_slice_mut(range).expect("slice failed")
    }
    /// Fallible version of [`slice_mut`](TensorBase::slice_mut).
    #[allow(clippy::type_complexity)]
    pub fn try_slice_mut<R: IntoSliceItems + IndexCount>(
        &mut self,
        range: R,
    ) -> Result<
        TensorBase<ViewMutData<'_, S::Elem>, <L as SliceWith<R, R::Count>>::Layout>,
        SliceError,
    >
    where
        L: SliceWith<R, R::Count>,
    {
        let (offset_range, sliced_layout) = self.layout.slice_with(range)?;
        Ok(TensorBase {
            data: self.data.slice_mut(offset_range),
            layout: sliced_layout,
        })
    }
    /// Return a mutable view of this tensor.
    pub fn view_mut(&mut self) -> TensorBase<ViewMutData<'_, S::Elem>, L>
    where
        L: Clone,
    {
        TensorBase {
            data: self.data.view_mut(),
            layout: self.layout.clone(),
        }
    }
    /// Return a mutable wrapper view with weaker index checking.
    pub fn weakly_checked_view_mut(&mut self) -> WeaklyCheckedView<ViewMutData<'_, S::Elem>, L> {
        WeaklyCheckedView {
            base: self.view_mut(),
        }
    }
}
impl<T, L: Clone + Layout> TensorBase<Vec<T>, L> {
    /// Create a 1D tensor containing `start, start + step, ...`, excluding
    /// `end`. `step` defaults to one (`true` converted to `T`).
    pub fn arange(start: T, end: T, step: Option<T>) -> TensorBase<Vec<T>, L>
    where
        T: Copy + PartialOrd + From<bool> + std::ops::Add<Output = T>,
        [usize; 1]: AsIndex<L>,
        L: MutLayout,
    {
        let step = step.unwrap_or((true).into());
        let mut data = Vec::new();
        let mut curr = start;
        while curr < end {
            data.push(curr);
            curr = curr + step;
        }
        TensorBase::from_data([data.len()].as_index(), data)
    }
    /// Append the contents of `other` along `axis`, in place.
    ///
    /// Fails if the shapes differ on any axis other than `axis`, or if the
    /// storage lacks capacity for the expanded layout.
    pub fn append<S2: Storage<Elem = T>>(
        &mut self,
        axis: usize,
        other: &TensorBase<S2, L>,
    ) -> Result<(), ExpandError>
    where
        T: Copy,
        L: MutLayout,
    {
        let shape_match = self.ndim() == other.ndim()
            && (0..self.ndim()).all(|d| d == axis || self.size(d) == other.size(d));
        if !shape_match {
            return Err(ExpandError::ShapeMismatch);
        }
        let old_size = self.size(axis);
        let new_size = self.size(axis) + other.size(axis);
        // Find a layout expanded along `axis` which fits in the existing
        // capacity without introducing internal overlap.
        let Some(new_layout) = self.expanded_layout(axis, new_size) else {
            return Err(ExpandError::InsufficientCapacity);
        };
        let new_data_len = new_layout.min_data_len();
        self.layout = new_layout;
        assert!(self.data.capacity() >= new_data_len);
        // SAFETY: Capacity was asserted above. The newly exposed elements
        // are uninitialized until `copy_from` below fills them; `T: Copy`
        // means nothing is dropped in the meantime.
        unsafe {
            self.data.set_len(new_data_len);
        }
        // Fill the appended region with `other`'s elements.
        self.slice_axis_mut(axis, old_size..new_size)
            .copy_from(other);
        Ok(())
    }
    /// Create a 1D tensor from a vector of elements.
    pub fn from_vec(vec: Vec<T>) -> TensorBase<Vec<T>, L>
    where
        [usize; 1]: AsIndex<L>,
        L: MutLayout,
    {
        TensorBase::from_data([vec.len()].as_index(), vec)
    }
    /// Shrink dimension `dim` to `range`, moving retained elements to the
    /// start of the buffer and truncating the excess.
    ///
    /// # Panics
    ///
    /// Panics if `range` is invalid for the current size of `dim`.
    pub fn clip_dim(&mut self, dim: usize, range: Range<usize>)
    where
        T: Copy,
        L: MutLayout,
    {
        let (start, end) = (range.start, range.end);
        assert!(start <= end, "start must be <= end");
        assert!(end <= self.size(dim), "end must be <= dim size");
        self.layout.resize_dim(dim, end - start);
        // Offsets of the retained region within the existing buffer.
        let range = if self.is_empty() {
            0..0
        } else {
            let start_offset = start * self.layout.stride(dim);
            let end_offset = start_offset + self.layout.min_data_len();
            start_offset..end_offset
        };
        // Shift retained elements to the buffer start and drop the rest.
        self.data.copy_within(range.clone(), 0);
        self.data.truncate(range.end - range.start);
    }
    /// Return true if the storage has capacity to resize `axis` to
    /// `new_size` without reallocating.
    pub fn has_capacity(&self, axis: usize, new_size: usize) -> bool
    where
        L: MutLayout,
    {
        self.expanded_layout(axis, new_size).is_some()
    }
    /// Return a layout with `axis` resized to `new_size`, if it fits in the
    /// current capacity without internal overlap.
    fn expanded_layout(&self, axis: usize, new_size: usize) -> Option<L>
    where
        L: MutLayout,
    {
        let mut new_layout = self.layout.clone();
        new_layout.resize_dim(axis, new_size);
        let new_data_len = new_layout.min_data_len();
        let has_capacity = new_data_len <= self.data.capacity()
            && !may_have_internal_overlap(
                new_layout.shape().as_ref(),
                new_layout.strides().as_ref(),
            );
        has_capacity.then_some(new_layout)
    }
    /// Convert to a copy-on-write tensor that owns its data.
    pub fn into_cow(self) -> TensorBase<CowData<'static, T>, L> {
        let TensorBase { data, layout } = self;
        TensorBase {
            layout,
            data: CowData::Owned(data),
        }
    }
    /// Convert to a tensor with reference-counted, shared storage.
    pub fn into_arc(self) -> TensorBase<Arc<Vec<T>>, L> {
        let TensorBase { data, layout } = self;
        TensorBase {
            layout,
            data: Arc::new(data),
        }
    }
    /// Consume this tensor and return its elements in logical order,
    /// copying only if the layout is not contiguous.
    pub fn into_data(self) -> Vec<T>
    where
        T: Clone,
    {
        if self.is_contiguous() {
            self.into_non_contiguous_data()
        } else {
            self.to_vec()
        }
    }
    /// Consume this tensor and return the buffer in its existing (possibly
    /// non-contiguous) order, trimmed to the layout's required length.
    pub fn into_non_contiguous_data(mut self) -> Vec<T> {
        self.data.truncate(self.layout.min_data_len());
        self.data
    }
    /// Consume this tensor and return one with the given shape.
    ///
    /// # Panics
    ///
    /// Panics if the element count of `shape` does not match.
    #[track_caller]
    pub fn into_shape<S: Copy + IntoLayout>(self, shape: S) -> TensorBase<Vec<T>, S::Layout>
    where
        T: Clone,
        L: MutLayout,
    {
        let Ok(layout) = self.layout.reshaped_for_copy(shape) else {
            panic!(
                "element count mismatch reshaping {:?} to {:?}",
                self.shape(),
                shape
            );
        };
        TensorBase {
            layout,
            data: self.into_data(),
        }
    }
    /// Create a tensor by calling `f` with each index, in logical order.
    pub fn from_fn<F: FnMut(L::Index<'_>) -> T, Idx>(
        shape: L::Index<'_>,
        mut f: F,
    ) -> TensorBase<Vec<T>, L>
    where
        L::Indices: Iterator<Item = Idx>,
        Idx: AsIndex<L>,
        L: MutLayout,
    {
        let layout = L::from_shape(shape);
        let data: Vec<T> = layout.indices().map(|idx| f(idx.as_index())).collect();
        TensorBase { data, layout }
    }
    /// Create a tensor by calling `f` once per element.
    pub fn from_simple_fn<F: FnMut() -> T>(shape: L::Index<'_>, f: F) -> TensorBase<Vec<T>, L>
    where
        L: MutLayout,
    {
        Self::from_simple_fn_in(GlobalAlloc::new(), shape, f)
    }
    /// Variant of [`from_simple_fn`](TensorBase::from_simple_fn) which
    /// allocates from `alloc`.
    pub fn from_simple_fn_in<A: Alloc, F: FnMut() -> T>(
        alloc: A,
        shape: L::Index<'_>,
        mut f: F,
    ) -> TensorBase<Vec<T>, L>
    where
        L: MutLayout,
    {
        let len = shape.as_ref().iter().product();
        let mut data = alloc.alloc(len);
        data.extend(std::iter::from_fn(|| Some(f())).take(len));
        TensorBase::from_data(shape, data)
    }
    /// Create a 0-dimensional tensor holding a single value.
    pub fn from_scalar(value: T) -> TensorBase<Vec<T>, L>
    where
        [usize; 0]: AsIndex<L>,
        L: MutLayout,
    {
        TensorBase::from_data([].as_index(), vec![value])
    }
    /// Create a tensor with every element set to `value`.
    pub fn full(shape: L::Index<'_>, value: T) -> TensorBase<Vec<T>, L>
    where
        T: Clone,
        L: MutLayout,
    {
        Self::full_in(GlobalAlloc::new(), shape, value)
    }
    /// Variant of [`full`](TensorBase::full) which allocates from `alloc`.
    pub fn full_in<A: Alloc>(alloc: A, shape: L::Index<'_>, value: T) -> TensorBase<Vec<T>, L>
    where
        T: Clone,
        L: MutLayout,
    {
        let len = shape.as_ref().iter().product();
        let mut data = alloc.alloc(len);
        data.resize(len, value);
        TensorBase::from_data(shape, data)
    }
    /// Ensure elements are stored contiguously in logical order, copying
    /// them if necessary.
    pub fn make_contiguous(&mut self)
    where
        T: Clone,
        L: MutLayout,
    {
        if self.is_contiguous() {
            return;
        }
        self.data = self.to_vec();
        self.layout = L::from_shape(self.layout.shape());
    }
    /// Create a tensor filled with values from `rand_src`.
    pub fn rand<R: RandomSource<T>>(shape: L::Index<'_>, rand_src: &mut R) -> TensorBase<Vec<T>, L>
    where
        L: MutLayout,
    {
        Self::from_simple_fn(shape, || rand_src.next())
    }
    /// Create a tensor filled with the element type's default value.
    pub fn zeros(shape: L::Index<'_>) -> TensorBase<Vec<T>, L>
    where
        T: Clone + Default,
        L: MutLayout,
    {
        Self::zeros_in(GlobalAlloc::new(), shape)
    }
    /// Variant of [`zeros`](TensorBase::zeros) which allocates from `alloc`.
    pub fn zeros_in<A: Alloc>(alloc: A, shape: L::Index<'_>) -> TensorBase<Vec<T>, L>
    where
        T: Clone + Default,
        L: MutLayout,
    {
        Self::full_in(alloc, shape, T::default())
    }
    /// Create a tensor of the given shape with uninitialized elements.
    pub fn uninit(shape: L::Index<'_>) -> TensorBase<Vec<MaybeUninit<T>>, L>
    where
        MaybeUninit<T>: Clone,
        L: MutLayout,
    {
        Self::uninit_in(GlobalAlloc::new(), shape)
    }
    /// Variant of [`uninit`](TensorBase::uninit) which allocates from `alloc`.
    pub fn uninit_in<A: Alloc>(alloc: A, shape: L::Index<'_>) -> TensorBase<Vec<MaybeUninit<T>>, L>
    where
        L: MutLayout,
    {
        let len = shape.as_ref().iter().product();
        let mut data = alloc.alloc(len);
        // SAFETY: `MaybeUninit<T>` requires no initialization, and `len` is
        // within the capacity just allocated.
        unsafe { data.set_len(len) }
        TensorBase::from_data(shape, data)
    }
    /// Create an empty tensor with capacity reserved for growing along
    /// `expand_dim` up to `shape`'s size on that axis.
    pub fn with_capacity(shape: L::Index<'_>, expand_dim: usize) -> TensorBase<Vec<T>, L>
    where
        T: Copy,
        L: MutLayout,
    {
        Self::with_capacity_in(GlobalAlloc::new(), shape, expand_dim)
    }
    /// Variant of [`with_capacity`](TensorBase::with_capacity) which
    /// allocates from `alloc`.
    pub fn with_capacity_in<A: Alloc>(
        alloc: A,
        shape: L::Index<'_>,
        expand_dim: usize,
    ) -> TensorBase<Vec<T>, L>
    where
        T: Copy,
        L: MutLayout,
    {
        let mut tensor = Self::uninit_in(alloc, shape);
        // Clip to zero length along `expand_dim`, leaving the capacity.
        tensor.clip_dim(expand_dim, 0..0);
        // SAFETY: The tensor is empty after clipping, so it exposes no
        // uninitialized elements.
        unsafe { tensor.assume_init() }
    }
}
impl<T, L: Layout> TensorBase<CowData<'_, T>, L> {
    /// Consume this tensor and return the owned element buffer, or `None`
    /// if the storage is borrowed.
    ///
    /// The buffer is trimmed to the minimum length required by the layout
    /// but elements are left in their existing (possibly non-contiguous)
    /// order.
    pub fn into_non_contiguous_data(self) -> Option<Vec<T>> {
        let required_len = self.layout.min_data_len();
        if let CowData::Owned(mut owned) = self.data {
            owned.truncate(required_len);
            Some(owned)
        } else {
            None
        }
    }
}
impl<T, S: Storage<Elem = MaybeUninit<T>> + AssumeInit, L: Layout + Clone> TensorBase<S, L>
where
    <S as AssumeInit>::Output: Storage<Elem = T>,
{
    /// Convert a tensor of `MaybeUninit<T>` elements into one of `T`.
    ///
    /// # Safety
    ///
    /// All elements reachable via the layout must have been initialized.
    pub unsafe fn assume_init(self) -> TensorBase<<S as AssumeInit>::Output, L> {
        TensorBase {
            layout: self.layout,
            // SAFETY: Deferred to the caller's contract above.
            data: unsafe { self.data.assume_init() },
        }
    }
    /// Initialize this tensor by copying elements from `other`, then return
    /// the initialized tensor.
    ///
    /// # Panics
    ///
    /// Panics if the shapes do not match.
    pub fn init_from<S2: Storage<Elem = T>>(
        mut self,
        other: &TensorBase<S2, L>,
    ) -> TensorBase<<S as AssumeInit>::Output, L>
    where
        T: Copy,
        S: StorageMut<Elem = MaybeUninit<T>>,
    {
        assert_eq!(self.shape(), other.shape(), "shape mismatch");
        match (self.data_mut(), other.data()) {
            // Both contiguous: bulk copy.
            (Some(self_data), Some(other_data)) => {
                // SAFETY: `MaybeUninit<T>` has the same layout as `T`.
                let other_data: &[MaybeUninit<T>] = unsafe { std::mem::transmute(other_data) };
                self_data.clone_from_slice(other_data);
            }
            // Contiguous dest, non-contiguous src.
            (Some(self_data), _) => {
                copy_into_slice(other.as_dyn(), self_data);
            }
            // Non-contiguous dest.
            _ => {
                copy_into_uninit(other.as_dyn(), self.as_dyn_mut());
            }
        }
        // SAFETY: Every element was initialized by one of the branches above,
        // since the shapes match.
        unsafe { self.assume_init() }
    }
}
impl<'a, T, L: Clone + Layout> TensorBase<ViewData<'a, T>, L> {
    /// Return an iterator over views along `dim`, with that dimension removed.
    pub fn axis_iter(&self, dim: usize) -> AxisIter<'a, T, L>
    where
        L: MutLayout + RemoveDim,
    {
        AxisIter::new(self, dim)
    }
    /// Return an iterator over views of chunks of size `chunk_size` along `dim`.
    pub fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<'a, T, L>
    where
        L: MutLayout,
    {
        AxisChunks::new(self, dim, chunk_size)
    }
    /// Return a view with a dynamic-rank layout.
    pub fn as_dyn(&self) -> TensorBase<ViewData<'a, T>, DynLayout> {
        TensorBase {
            data: self.data,
            layout: DynLayout::from(&self.layout),
        }
    }
    /// Return a copy-on-write view borrowing this view's data.
    pub fn as_cow(&self) -> TensorBase<CowData<'a, T>, L> {
        TensorBase {
            layout: self.layout.clone(),
            data: CowData::Borrowed(self.data),
        }
    }
    /// Return a view broadcast to `shape`. Panics if broadcasting fails.
    pub fn broadcast<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<'a, T>, S::Layout>
    where
        L: BroadcastLayout<S::Layout>,
    {
        self.try_broadcast(shape).unwrap()
    }
    /// Fallible version of [`broadcast`](TensorBase::broadcast).
    pub fn try_broadcast<S: IntoLayout>(
        &self,
        shape: S,
    ) -> Result<TensorBase<ViewData<'a, T>, S::Layout>, ExpandError>
    where
        L: BroadcastLayout<S::Layout>,
    {
        Ok(TensorBase {
            layout: self.layout.broadcast(shape)?,
            data: self.data,
        })
    }
    /// Return the elements as a slice, if the layout is contiguous.
    pub fn data(&self) -> Option<&'a [T]> {
        let len = self.layout.min_data_len();
        let data = self.data.slice(0..len);
        self.layout.is_contiguous().then(|| unsafe {
            // SAFETY: With a contiguous layout every element in `0..len` is
            // part of the tensor, hence initialized.
            data.as_slice()
        })
    }
    /// Return a view of the underlying storage.
    pub fn storage(&self) -> ViewData<'a, T> {
        self.data.view()
    }
    /// Return the element at `index`, or `None` if out of bounds.
    pub fn get<I: AsIndex<L>>(&self, index: I) -> Option<&'a T>
    where
        L: TrustedLayout,
    {
        self.offset(index.as_index()).map(|offset|
            unsafe {
                // SAFETY: `offset` was validated by `self.offset` above.
                self.data.get_unchecked(offset)
            })
    }
    /// Create a view from a shape, a borrowed slice of elements and custom
    /// strides.
    pub fn from_slice_with_strides(
        shape: L::Index<'_>,
        data: &'a [T],
        strides: L::Index<'_>,
    ) -> Result<TensorBase<ViewData<'a, T>, L>, FromDataError>
    where
        L: MutLayout,
    {
        // Overlap is allowed here since the view is read-only.
        let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap)?;
        if layout.min_data_len() > data.as_ref().len() {
            return Err(FromDataError::StorageTooShort);
        }
        Ok(TensorBase {
            data: data.into_storage(),
            layout,
        })
    }
    /// Return the element at `index` without bounds checks.
    ///
    /// # Safety
    ///
    /// `index` must be valid for this tensor's shape.
    pub unsafe fn get_unchecked<I: AsIndex<L>>(&self, index: I) -> &'a T {
        let offset = self.layout.offset_unchecked(index.as_index());
        // SAFETY: Caller guarantees the index is valid (see above).
        unsafe { self.data.get_unchecked(offset) }
    }
    /// Return a view of entry `index` along `axis`, with `axis` removed.
    pub fn index_axis(
        &self,
        axis: usize,
        index: usize,
    ) -> TensorBase<ViewData<'a, T>, <L as RemoveDim>::Output>
    where
        L: MutLayout + RemoveDim,
    {
        let (offsets, layout) = self.layout.index_axis(axis, index);
        TensorBase {
            data: self.data.slice(offsets),
            layout,
        }
    }
    /// Return an iterator over views of the innermost `N` dimensions.
    pub fn inner_iter<const N: usize>(&self) -> InnerIter<'a, T, NdLayout<N>> {
        InnerIter::new(self.view())
    }
    /// Variant of [`inner_iter`](TensorBase::inner_iter) where the number of
    /// inner dims is chosen at runtime.
    pub fn inner_iter_dyn(&self, n: usize) -> InnerIter<'a, T, DynLayout> {
        InnerIter::new_dyn(self.view(), n)
    }
    /// Return this tensor's single element, if it has exactly one.
    pub fn item(&self) -> Option<&'a T> {
        match self.ndim() {
            0 => unsafe {
                // SAFETY: A 0-dim tensor has a single element at offset 0.
                self.data.get(0)
            },
            _ if self.len() == 1 => self.iter().next(),
            _ => None,
        }
    }
    /// Return an iterator over elements in logical order.
    pub fn iter(&self) -> Iter<'a, T> {
        Iter::new(self.view_ref())
    }
    /// Return an iterator over 1D lanes along `dim`.
    ///
    /// # Panics
    ///
    /// Panics if `dim` is out of bounds.
    pub fn lanes(&self, dim: usize) -> Lanes<'a, T>
    where
        L: RemoveDim,
    {
        assert!(dim < self.ndim());
        Lanes::new(self.view_ref(), dim)
    }
    /// Return a view with a static rank of `N`.
    ///
    /// # Panics
    ///
    /// Panics if `self.ndim() != N`.
    pub fn nd_view<const N: usize>(&self) -> TensorBase<ViewData<'a, T>, NdLayout<N>> {
        assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N);
        TensorBase {
            data: self.data,
            // `unwrap` is safe: ndim was checked above.
            layout: self.nd_layout().unwrap(),
        }
    }
    /// Return a view with axes permuted according to `order`.
    pub fn permuted(&self, order: L::Index<'_>) -> TensorBase<ViewData<'a, T>, L>
    where
        L: MutLayout,
    {
        TensorBase {
            data: self.data,
            layout: self.layout.permuted(order),
        }
    }
    /// Return a tensor with the same elements and a new shape, borrowing
    /// the data when the layout permits and copying otherwise.
    pub fn reshaped<S: Copy + IntoLayout>(&self, shape: S) -> TensorBase<CowData<'a, T>, S::Layout>
    where
        T: Clone,
        L: MutLayout,
    {
        self.reshaped_in(GlobalAlloc::new(), shape)
    }
    /// Variant of [`reshaped`](TensorBase::reshaped) which allocates any
    /// required copy from `alloc`.
    ///
    /// # Panics
    ///
    /// Panics if the element count of `shape` does not match.
    pub fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
        &self,
        alloc: A,
        shape: S,
    ) -> TensorBase<CowData<'a, T>, S::Layout>
    where
        T: Clone,
        L: MutLayout,
    {
        // Fast path: reshape without copying, if the layout allows it.
        if let Ok(layout) = self.layout.reshaped_for_view(shape) {
            TensorBase {
                data: CowData::Borrowed(self.data),
                layout,
            }
        } else {
            // Slow path: copy elements into a new contiguous buffer.
            let Ok(layout) = self.layout.reshaped_for_copy(shape) else {
                panic!(
                    "element count mismatch reshaping {:?} to {:?}",
                    self.shape(),
                    shape
                );
            };
            TensorBase {
                data: CowData::Owned(self.to_vec_in(alloc)),
                layout,
            }
        }
    }
    /// Return a view of the subset selected by `range`. Panics if the range
    /// is invalid.
    pub fn slice<R: IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>
    where
        L: SliceWith<R, R::Count>,
    {
        self.try_slice(range).expect("slice failed")
    }
    /// Return a view restricted to `range` along `axis`.
    pub fn slice_axis(&self, axis: usize, range: Range<usize>) -> TensorBase<ViewData<'a, T>, L>
    where
        L: MutLayout,
    {
        let (offset_range, sliced_layout) = self.layout.slice_axis(axis, range.clone()).unwrap();
        debug_assert_eq!(sliced_layout.size(axis), range.len());
        TensorBase {
            data: self.data.slice(offset_range),
            layout: sliced_layout,
        }
    }
    /// Fallible version of [`slice`](TensorBase::slice).
    #[allow(clippy::type_complexity)]
    pub fn try_slice<R: IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> Result<TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>, SliceError>
    where
        L: SliceWith<R, R::Count>,
    {
        let (offset_range, sliced_layout) = self.layout.slice_with(range)?;
        Ok(TensorBase {
            data: self.data.slice(offset_range),
            layout: sliced_layout,
        })
    }
    /// Return a view with size-one dimensions removed.
    pub fn squeezed(&self) -> TensorView<'a, T>
    where
        L: MutLayout,
    {
        TensorBase {
            data: self.data.view(),
            layout: self.layout.squeezed(),
        }
    }
    /// Split this view into two non-overlapping views at `mid` along `axis`.
    #[allow(clippy::type_complexity)]
    pub fn split_at(
        &self,
        axis: usize,
        mid: usize,
    ) -> (
        TensorBase<ViewData<'a, T>, L>,
        TensorBase<ViewData<'a, T>, L>,
    )
    where
        L: MutLayout,
    {
        let (left, right) = self.layout.split(axis, mid);
        let (left_offset_range, left_layout) = left;
        let (right_offset_range, right_layout) = right;
        let left_data = self.data.slice(left_offset_range.clone());
        let right_data = self.data.slice(right_offset_range.clone());
        debug_assert_eq!(left_data.len(), left_layout.min_data_len());
        let left_view = TensorBase {
            data: left_data,
            layout: left_layout,
        };
        debug_assert_eq!(right_data.len(), right_layout.min_data_len());
        let right_view = TensorBase {
            data: right_data,
            layout: right_layout,
        };
        (left_view, right_view)
    }
    /// Return a contiguous version of this tensor, borrowing the data when
    /// already contiguous and copying otherwise.
    pub fn to_contiguous(&self) -> Contiguous<TensorBase<CowData<'a, T>, L>>
    where
        T: Clone,
        L: MutLayout,
    {
        self.to_contiguous_in(GlobalAlloc::new())
    }
    /// Variant of [`to_contiguous`](TensorBase::to_contiguous) which
    /// allocates any required copy from `alloc`.
    pub fn to_contiguous_in<A: Alloc>(&self, alloc: A) -> Contiguous<TensorBase<CowData<'a, T>, L>>
    where
        T: Clone,
        L: MutLayout,
    {
        let tensor = if let Some(data) = self.data() {
            // Already contiguous: borrow the existing data.
            TensorBase {
                data: CowData::Borrowed(data.into_storage()),
                layout: self.layout.clone(),
            }
        } else {
            // Copy elements into a new contiguous buffer.
            let data = self.to_vec_in(alloc);
            TensorBase {
                data: CowData::Owned(data),
                layout: L::from_shape(self.layout.shape()),
            }
        };
        // `unwrap` is safe: both branches produce a contiguous tensor.
        Contiguous::new(tensor).unwrap()
    }
    /// Return the elements as a slice, borrowing when contiguous and
    /// copying otherwise.
    pub fn to_slice(&self) -> Cow<'a, [T]>
    where
        T: Clone,
    {
        self.data()
            .map(Cow::Borrowed)
            .unwrap_or_else(|| Cow::Owned(self.to_vec()))
    }
    /// Return a transposed view of this tensor.
    pub fn transposed(&self) -> TensorBase<ViewData<'a, T>, L>
    where
        L: MutLayout,
    {
        TensorBase {
            data: self.data,
            layout: self.layout.transposed(),
        }
    }
    /// Slice this view with a dynamically-sized range, producing a
    /// dynamic-rank view.
    pub fn try_slice_dyn<R: IntoSliceItems>(
        &self,
        range: R,
    ) -> Result<TensorView<'a, T>, SliceError>
    where
        L: MutLayout,
    {
        let (offset_range, layout) = self.layout.slice_dyn(range.into_slice_items().as_ref())?;
        Ok(TensorBase {
            data: self.data.slice(offset_range),
            layout,
        })
    }
    /// Return a copy of this view, preserving the view's lifetime `'a`.
    pub fn view(&self) -> TensorBase<ViewData<'a, T>, L> {
        TensorBase {
            data: self.data,
            layout: self.layout.clone(),
        }
    }
    /// Return a view whose layout borrows `self`'s layout.
    pub(crate) fn view_ref(&self) -> TensorBase<ViewData<'a, T>, &L> {
        TensorBase {
            data: self.data,
            layout: &self.layout,
        }
    }
    /// Return a wrapper view with weaker index checking.
    pub fn weakly_checked_view(&self) -> WeaklyCheckedView<ViewData<'a, T>, L> {
        WeaklyCheckedView { base: self.view() }
    }
}
/// Delegate all [`Layout`] queries to the tensor's layout.
impl<S: Storage, L: Layout> Layout for TensorBase<S, L> {
    type Index<'a> = L::Index<'a>;
    type Indices = L::Indices;
    fn ndim(&self) -> usize {
        self.layout.ndim()
    }
    fn len(&self) -> usize {
        self.layout.len()
    }
    fn is_empty(&self) -> bool {
        self.layout.is_empty()
    }
    fn shape(&self) -> Self::Index<'_> {
        self.layout.shape()
    }
    fn size(&self, dim: usize) -> usize {
        self.layout.size(dim)
    }
    fn strides(&self) -> Self::Index<'_> {
        self.layout.strides()
    }
    fn stride(&self, dim: usize) -> usize {
        self.layout.stride(dim)
    }
    fn indices(&self) -> Self::Indices {
        self.layout.indices()
    }
    fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
        self.layout.offset(index)
    }
}
/// Delegate all [`MatrixLayout`] queries to the tensor's layout.
impl<S: Storage, L: Layout + MatrixLayout> MatrixLayout for TensorBase<S, L> {
    fn rows(&self) -> usize {
        self.layout.rows()
    }
    fn cols(&self) -> usize {
        self.layout.cols()
    }
    fn row_stride(&self) -> usize {
        self.layout.row_stride()
    }
    fn col_stride(&self) -> usize {
        self.layout.col_stride()
    }
}
// Read-only tensor operations, available on any `TensorBase` with a
// cloneable layout. Most methods delegate to a borrowed view
// (`self.view()`); `copy_into_slice`, `map_in` and `to_vec_in` additionally
// take a fast slice-based path when the tensor is contiguous.
impl<T, S: Storage<Elem = T>, L: Layout + Clone> AsView for TensorBase<S, L> {
    type Elem = T;
    type Layout = L;

    fn iter(&self) -> Iter<'_, T> {
        self.view().iter()
    }

    // Copy all elements into `dest` in logical order, returning `dest` as an
    // initialized slice.
    fn copy_into_slice<'a>(&self, dest: &'a mut [MaybeUninit<T>]) -> &'a [T]
    where
        T: Copy,
    {
        if let Some(data) = self.data() {
            // Fast path for contiguous tensors: a single slice copy.
            //
            // SAFETY: `[T]` and `[MaybeUninit<T>]` have identical layout, and
            // viewing initialized data as possibly-uninitialized is sound.
            let src_uninit = unsafe { std::mem::transmute::<&[T], &[MaybeUninit<T>]>(data) };
            // Panics if `dest` and `data` have different lengths.
            dest.copy_from_slice(src_uninit);
            // SAFETY: every element of `dest` was just written by the copy.
            unsafe { dest.assume_init() }
        } else {
            // Strided tensor: copy element-by-element via a dynamic-rank view.
            copy_into_slice(self.as_dyn(), dest)
        }
    }

    // Returns the underlying data as a slice, or `None` if the tensor is not
    // contiguous.
    fn data(&self) -> Option<&[Self::Elem]> {
        self.view().data()
    }

    fn insert_axis(&mut self, index: usize)
    where
        L: ResizeLayout,
    {
        self.layout.insert_axis(index)
    }

    #[track_caller]
    fn remove_axis(&mut self, index: usize)
    where
        L: ResizeLayout,
    {
        self.layout.remove_axis(index)
    }

    fn merge_axes(&mut self)
    where
        L: ResizeLayout,
    {
        self.layout.merge_axes()
    }

    fn layout(&self) -> &L {
        &self.layout
    }

    fn map<F, U>(&self, f: F) -> TensorBase<Vec<U>, L>
    where
        F: Fn(&Self::Elem) -> U,
        L: MutLayout,
    {
        self.map_in(GlobalAlloc::new(), f)
    }

    // Apply `f` to every element, collecting results into a new tensor of the
    // same shape allocated via `alloc`.
    fn map_in<A: Alloc, F, U>(&self, alloc: A, f: F) -> TensorBase<Vec<U>, L>
    where
        F: Fn(&Self::Elem) -> U,
        L: MutLayout,
    {
        let len = self.len();
        let mut buf = alloc.alloc(len);
        if let Some(data) = self.data() {
            // Contiguous fast path: map the slice directly.
            buf.extend(data.iter().map(f));
        } else {
            // Strided path: write mapped values into the spare capacity, then
            // mark the buffer as initialized.
            let dest = &mut buf.spare_capacity_mut()[..len];
            map_into_slice(self.as_dyn(), dest, f);
            // SAFETY: `map_into_slice` initialized the first `len` elements.
            unsafe {
                buf.set_len(len);
            }
        };
        TensorBase::from_data(self.shape(), buf)
    }

    fn move_axis(&mut self, from: usize, to: usize)
    where
        L: MutLayout,
    {
        self.layout.move_axis(from, to);
    }

    fn view(&self) -> TensorBase<ViewData<'_, T>, L> {
        TensorBase {
            data: self.data.view(),
            layout: self.layout.clone(),
        }
    }

    // Checked element access: `None` if `index` is invalid for this layout.
    fn get<I: AsIndex<L>>(&self, index: I) -> Option<&Self::Elem> {
        // SAFETY: `offset` only returns `Some` for valid in-bounds offsets.
        self.offset(index.as_index()).map(|offset| unsafe {
            self.data.get_unchecked(offset)
        })
    }

    unsafe fn get_unchecked<I: AsIndex<L>>(&self, index: I) -> &T {
        let offset = self.layout.offset_unchecked(index.as_index());
        // SAFETY: the caller guarantees `index` is valid for this layout, so
        // the computed offset is within the storage.
        unsafe { self.data.get_unchecked(offset) }
    }

    fn permute(&mut self, order: Self::Index<'_>)
    where
        L: MutLayout,
    {
        self.layout = self.layout.permuted(order);
    }

    fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        self.to_vec_in(GlobalAlloc::new())
    }

    // Copy all elements, in logical order, into a new `Vec`.
    fn to_vec_in<A: Alloc>(&self, alloc: A) -> Vec<T>
    where
        T: Clone,
    {
        let len = self.len();
        let mut buf = alloc.alloc(len);
        if let Some(data) = self.data() {
            buf.extend_from_slice(data);
        } else {
            copy_into_slice(self.as_dyn(), &mut buf.spare_capacity_mut()[..len]);
            // SAFETY: `copy_into_slice` initialized the first `len` elements.
            unsafe { buf.set_len(len) }
        }
        buf
    }

    // Copy this tensor into a new contiguous tensor with a different shape.
    // Panics ("reshape failed") if the reshape is invalid, e.g. the element
    // counts do not match.
    fn to_shape<SH: IntoLayout>(&self, shape: SH) -> TensorBase<Vec<Self::Elem>, SH::Layout>
    where
        T: Clone,
        L: MutLayout,
    {
        TensorBase {
            data: self.to_vec(),
            layout: self
                .layout
                .reshaped_for_copy(shape)
                .expect("reshape failed"),
        }
    }

    fn transpose(&mut self)
    where
        L: MutLayout,
    {
        self.layout = self.layout.transposed();
    }
}
impl<T, S: Storage<Elem = T>, const N: usize> TensorBase<S, NdLayout<N>> {
    /// Read `M` elements along axis `dim`, starting at index `base`.
    #[inline]
    pub fn get_array<const M: usize>(&self, base: [usize; N], dim: usize) -> [T; M]
    where
        T: Copy + Default,
    {
        // `array_offsets` asserts that all `M` positions lie within the
        // layout before returning.
        let offsets: [usize; M] = array_offsets(&self.layout, base, dim);
        let mut out = [T::default(); M];
        for (slot, &offset) in out.iter_mut().zip(offsets.iter()) {
            // SAFETY: offsets were validated by `array_offsets` above; the
            // storage covering the layout is an invariant of `TensorBase`.
            *slot = unsafe { *self.data.get_unchecked(offset) };
        }
        out
    }
}
impl<T> TensorBase<Vec<T>, DynLayout> {
    /// Reshape this tensor in place to `shape`.
    ///
    /// If the tensor is not contiguous, its data is copied into a new
    /// contiguous buffer.
    ///
    /// Panics if `shape` has a different element count than the current shape.
    #[track_caller]
    pub fn reshape(&mut self, shape: &[usize])
    where
        T: Clone,
    {
        self.reshape_in(GlobalAlloc::new(), shape)
    }

    /// Variant of [`reshape`](Self::reshape) which takes the allocator used
    /// for the contiguous copy, if one is needed.
    #[track_caller]
    pub fn reshape_in<A: Alloc>(&mut self, alloc: A, shape: &[usize])
    where
        T: Clone,
    {
        // Validate the new shape *before* touching `self.data`. Previously the
        // data was made contiguous first, so an invalid shape performed a
        // useless O(n) copy and then panicked with `self.layout` no longer
        // describing `self.data` — observable if the panic is caught.
        let Ok(layout) = self.layout.reshaped_for_copy(shape) else {
            panic!(
                "element count mismatch reshaping {:?} to {:?}",
                self.shape(),
                shape
            );
        };
        if !self.is_contiguous() {
            self.data = self.to_vec_in(alloc);
        }
        self.layout = layout;
    }
}
impl<'a, T, L: Layout> TensorBase<ViewMutData<'a, T>, L> {
    /// Split this mutable view into two non-overlapping halves along `axis`
    /// at position `mid`.
    #[allow(clippy::type_complexity)]
    pub fn split_at_mut(
        self,
        axis: usize,
        mid: usize,
    ) -> (
        TensorBase<ViewMutData<'a, T>, L>,
        TensorBase<ViewMutData<'a, T>, L>,
    )
    where
        L: MutLayout,
    {
        // The layout yields an (offset range, layout) pair for each half; the
        // storage is then split into the corresponding disjoint regions.
        let ((left_range, left_layout), (right_range, right_layout)) =
            self.layout.split(axis, mid);
        let (left_data, right_data) = self
            .data
            .split_mut(left_range.clone(), right_range.clone());
        debug_assert_eq!(left_data.len(), left_layout.min_data_len());
        debug_assert_eq!(right_data.len(), right_layout.min_data_len());
        (
            TensorBase {
                data: left_data,
                layout: left_layout,
            },
            TensorBase {
                data: right_data,
                layout: right_layout,
            },
        )
    }

    /// Consume this view and return the underlying data as a mutable slice,
    /// or `None` if the elements are not contiguous.
    pub fn into_slice_mut(self) -> Option<&'a mut [T]> {
        if !self.is_contiguous() {
            return None;
        }
        let len = self.layout.min_data_len();
        // NOTE(review): `to_slice_mut` is an unsafe storage accessor —
        // presumably it requires the elements to be initialized; the
        // contiguity check above ensures the first `len` elements cover the
        // whole view. Confirm the exact contract in `crate::storage`.
        let data = unsafe { self.data.to_slice_mut() };
        Some(&mut data[..len])
    }
}
// Collect an iterator of elements into a 1D tensor.
impl<T, L: MutLayout> FromIterator<T> for TensorBase<Vec<T>, L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> TensorBase<Vec<T>, L> {
        let data = Vec::from_iter(iter);
        let len = data.len();
        TensorBase::from_data([len].as_index(), data)
    }
}
// Convert an owned `Vec` into a 1D tensor without copying.
impl<T, L: MutLayout> From<Vec<T>> for TensorBase<Vec<T>, L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from(vec: Vec<T>) -> Self {
        let shape = [vec.len()];
        Self::from_data(shape.as_index(), vec)
    }
}
// Borrow a slice as a 1D tensor view.
impl<'a, T, L: MutLayout> From<&'a [T]> for TensorBase<ViewData<'a, T>, L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from(slice: &'a [T]) -> Self {
        let shape = [slice.len()];
        Self::from_data(shape.as_index(), slice)
    }
}
// Borrow a fixed-size array as a 1D tensor view of length `N`.
impl<'a, T, L: MutLayout, const N: usize> From<&'a [T; N]> for TensorBase<ViewData<'a, T>, L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from(slice: &'a [T; N]) -> Self {
        // The length of `[T; N]` is the const parameter `N`.
        Self::from_data([N].as_index(), slice.as_slice())
    }
}
/// Compute the storage offsets of `M` consecutive elements along axis `dim`,
/// starting at index `base`.
///
/// Panics ("array indices invalid") if any of the `M` positions fall outside
/// the layout.
fn array_offsets<const N: usize, const M: usize>(
    layout: &NdLayout<N>,
    base: [usize; N],
    dim: usize,
) -> [usize; M] {
    // First clause guards the `base[dim] + M` addition against overflow.
    assert!(
        base[dim] < usize::MAX - M && layout.size(dim) >= base[dim] + M,
        "array indices invalid"
    );
    let start = layout.must_offset(base);
    let step = layout.stride(dim);
    std::array::from_fn(|i| start + i * step)
}
impl<T, S: StorageMut<Elem = T>, const N: usize> TensorBase<S, NdLayout<N>> {
    /// Write `M` elements along axis `dim`, starting at index `base`.
    #[inline]
    pub fn set_array<const M: usize>(&mut self, base: [usize; N], dim: usize, values: [T; M])
    where
        T: Copy,
    {
        // `array_offsets` asserts that all `M` positions lie within the
        // layout before returning.
        let offsets: [usize; M] = array_offsets(&self.layout, base, dim);
        for (&offset, value) in offsets.iter().zip(values) {
            // SAFETY: offsets were validated by `array_offsets` above; the
            // storage covering the layout is an invariant of `TensorBase`.
            unsafe { *self.data.get_unchecked_mut(offset) = value };
        }
    }
}
impl<T, S: Storage<Elem = T>> TensorBase<S, NdLayout<1>> {
    /// Copy the first `M` elements of this vector into an array.
    #[inline]
    pub fn to_array<const M: usize>(&self) -> [T; M]
    where
        T: Copy + Default,
    {
        // A rank-1 tensor read from the start is just `get_array` at [0].
        self.get_array::<M>([0], 0)
    }
}
impl<T, S: StorageMut<Elem = T>> TensorBase<S, NdLayout<1>> {
    /// Overwrite the first `M` elements of this vector with `values`.
    #[inline]
    pub fn assign_array<const M: usize>(&mut self, values: [T; M])
    where
        T: Copy + Default,
    {
        // A rank-1 tensor written from the start is just `set_array` at [0].
        self.set_array::<M>([0], 0, values)
    }
}
/// Borrowed view of a tensor with static rank `N`.
pub type NdTensorView<'a, T, const N: usize> = TensorBase<ViewData<'a, T>, NdLayout<N>>;
/// Owned tensor with static rank `N`.
pub type NdTensor<T, const N: usize> = TensorBase<Vec<T>, NdLayout<N>>;
/// Mutable borrowed view of a tensor with static rank `N`.
pub type NdTensorViewMut<'a, T, const N: usize> = TensorBase<ViewMutData<'a, T>, NdLayout<N>>;
/// Static-rank tensor whose data is either owned or borrowed (`CowData`).
pub type CowNdTensor<'a, T, const N: usize> = TensorBase<CowData<'a, T>, NdLayout<N>>;
/// Borrowed view of a rank-2 tensor.
pub type Matrix<'a, T = f32> = NdTensorView<'a, T, 2>;
/// Mutable borrowed view of a rank-2 tensor.
pub type MatrixMut<'a, T = f32> = NdTensorViewMut<'a, T, 2>;
/// Owned tensor whose rank is determined at runtime.
pub type Tensor<T = f32> = TensorBase<Vec<T>, DynLayout>;
/// Borrowed view of a dynamic-rank tensor.
pub type TensorView<'a, T = f32> = TensorBase<ViewData<'a, T>, DynLayout>;
/// Mutable borrowed view of a dynamic-rank tensor.
pub type TensorViewMut<'a, T = f32> = TensorBase<ViewMutData<'a, T>, DynLayout>;
/// Dynamic-rank tensor whose data is either owned or borrowed (`CowData`).
pub type CowTensor<'a, T> = TensorBase<CowData<'a, T>, DynLayout>;
/// Dynamic-rank tensor with reference-counted, shared data.
pub type ArcTensor<T> = TensorBase<Arc<Vec<T>>, DynLayout>;
/// Static-rank tensor with reference-counted, shared data.
pub type ArcNdTensor<T, const N: usize> = TensorBase<Arc<Vec<T>>, NdLayout<N>>;
// Panicking indexed access (`tensor[idx]`) for layouts that guarantee their
// computed offsets are in bounds (`TrustedLayout`).
impl<T, S: Storage<Elem = T>, L: TrustedLayout, I: AsIndex<L>> Index<I> for TensorBase<S, L> {
    type Output = T;

    fn index(&self, index: I) -> &Self::Output {
        let idx = index.as_index();
        let offset = self.layout.must_offset(idx);
        // SAFETY: relies on the `TrustedLayout` contract that `must_offset`
        // only returns offsets within the tensor's storage.
        unsafe { self.data.get_unchecked(offset) }
    }
}
// Panicking mutable indexed access (`tensor[idx] = …`); mirrors the `Index`
// impl above.
impl<T, S: StorageMut<Elem = T>, L: TrustedLayout, I: AsIndex<L>> IndexMut<I> for TensorBase<S, L> {
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        let offset = self.layout.must_offset(index.as_index());
        // SAFETY: relies on the `TrustedLayout` contract that `must_offset`
        // only returns offsets within the tensor's storage.
        unsafe { self.data.get_unchecked_mut(offset) }
    }
}
// A tensor is cloneable whenever both its storage and layout are.
impl<T, S: Storage<Elem = T> + Clone, L: Layout + Clone> Clone for TensorBase<S, L> {
    fn clone(&self) -> TensorBase<S, L> {
        TensorBase {
            data: self.data.clone(),
            layout: self.layout.clone(),
        }
    }
}
// A tensor is `Copy` when both its storage (e.g. a borrowed `ViewData`) and
// its layout are `Copy`.
impl<T, S: Storage<Elem = T> + Copy, L: Layout + Copy> Copy for TensorBase<S, L> {}
// Tensors compare equal when their shapes match and their elements are
// pairwise equal in logical order. This works across different layout types
// (e.g. static vs dynamic rank) via `AsView`.
impl<T: PartialEq, S: Storage<Elem = T>, L: Layout + Clone, V: AsView<Elem = T>> PartialEq<V>
    for TensorBase<S, L>
{
    fn eq(&self, other: &V) -> bool {
        if self.shape().as_ref() != other.shape().as_ref() {
            return false;
        }
        self.iter().eq(other.iter())
    }
}
impl<T, S: Storage<Elem = T>, const N: usize> From<TensorBase<S, NdLayout<N>>>
for TensorBase<S, DynLayout>
{
fn from(tensor: TensorBase<S, NdLayout<N>>) -> Self {
Self {
data: tensor.data,
layout: tensor.layout.into(),
}
}
}
impl<T, S1: Storage<Elem = T>, S2: Storage<Elem = T>, const N: usize>
TryFrom<TensorBase<S1, DynLayout>> for TensorBase<S2, NdLayout<N>>
where
S1: Into<S2>,
{
type Error = DimensionError;
fn try_from(value: TensorBase<S1, DynLayout>) -> Result<Self, Self::Error> {
let layout: NdLayout<N> = value.layout().try_into()?;
Ok(TensorBase {
data: value.data.into(),
layout,
})
}
}
/// Marker trait for primitive element types.
///
/// This restricts the scalar / nested-array `From` conversions below so they
/// don't conflict with other `From` impls for `TensorBase`.
pub trait Scalar {}

/// Implement [`Scalar`] for each type in a comma-separated list.
macro_rules! impl_scalar {
    ($($ty:ty),+ $(,)?) => {
        $(impl Scalar for $ty {})+
    };
}

impl_scalar!(
    bool, u8, i8, u16, i16, u32, i32, u64, i64, usize, isize, f32, f64, String
);
// Convert a single scalar value into a zero-dimensional tensor.
impl<T: Clone + Scalar, L: MutLayout> From<T> for TensorBase<Vec<T>, L>
where
    [usize; 0]: AsIndex<L>,
{
    fn from(value: T) -> Self {
        Self::from_scalar(value)
    }
}
// Convert a 1D array of scalars into a 1D tensor.
impl<T: Clone + Scalar, L: MutLayout, const D0: usize> From<[T; D0]> for TensorBase<Vec<T>, L>
where
    [usize; 1]: AsIndex<L>,
{
    fn from(value: [T; D0]) -> Self {
        // `Vec::from` moves the elements out of the owned array instead of
        // cloning each one, which matters for element types like `String`.
        Self::from_data([D0].as_index(), Vec::from(value))
    }
}
// Convert a 2D array of scalars into a 2D tensor, flattening in row-major
// order.
impl<T: Clone + Scalar, L: MutLayout, const D0: usize, const D1: usize> From<[[T; D1]; D0]>
    for TensorBase<Vec<T>, L>
where
    [usize; 2]: AsIndex<L>,
{
    fn from(value: [[T; D1]; D0]) -> Self {
        // Iterate the owned arrays by value so elements are moved, not
        // cloned. The UFCS call selects the by-value `IntoIterator` impl for
        // arrays regardless of crate edition.
        let data: Vec<T> = IntoIterator::into_iter(value).flatten().collect();
        Self::from_data([D0, D1].as_index(), data)
    }
}
// Convert a 3D array of scalars into a 3D tensor, flattening in row-major
// order.
impl<T: Clone + Scalar, L: MutLayout, const D0: usize, const D1: usize, const D2: usize>
    From<[[[T; D2]; D1]; D0]> for TensorBase<Vec<T>, L>
where
    [usize; 3]: AsIndex<L>,
{
    fn from(value: [[[T; D2]; D1]; D0]) -> Self {
        // Iterate the owned arrays by value so elements are moved, not
        // cloned. The UFCS call selects the by-value `IntoIterator` impl for
        // arrays regardless of crate edition; each `flatten` unwraps one
        // level of nesting.
        let data: Vec<T> = IntoIterator::into_iter(value).flatten().flatten().collect();
        Self::from_data([D0, D1, D2].as_index(), data)
    }
}
/// A tensor view whose indexing skips per-dimension bounds checks.
///
/// Indexing only validates that the computed linear offset lies within the
/// storage (see the `Index`/`IndexMut` impls), not that each coordinate is
/// within its dimension — hence "weakly checked".
pub struct WeaklyCheckedView<S: Storage, L: Layout> {
    // The underlying, fully checked view.
    base: TensorBase<S, L>,
}
// `WeaklyCheckedView` still reports layout information accurately; only
// element indexing (below) is weakly checked. All queries delegate to the
// wrapped tensor.
impl<T, S: Storage<Elem = T>, L: Layout> Layout for WeaklyCheckedView<S, L> {
    type Index<'a> = L::Index<'a>;
    type Indices = L::Indices;

    fn ndim(&self) -> usize {
        self.base.ndim()
    }

    fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
        self.base.offset(index)
    }

    fn len(&self) -> usize {
        self.base.len()
    }

    fn shape(&self) -> Self::Index<'_> {
        self.base.shape()
    }

    fn strides(&self) -> Self::Index<'_> {
        self.base.strides()
    }

    fn indices(&self) -> Self::Indices {
        self.base.indices()
    }
}
impl<T, S: Storage<Elem = T>, L: Layout, I: AsIndex<L>> Index<I> for WeaklyCheckedView<S, L> {
    type Output = T;

    // "Weak" checking: the per-dimension bounds checks are skipped
    // (`offset_unchecked`), but the resulting linear offset is still
    // validated against the storage, so a bad index panics ("invalid
    // offset") rather than causing out-of-bounds access.
    fn index(&self, index: I) -> &Self::Output {
        let offset = self.base.layout.offset_unchecked(index.as_index());
        // NOTE(review): `Storage::get` appears to be an unsafe fn returning
        // `Option` (hence the surrounding `unsafe` block); confirm its exact
        // contract in `crate::storage`.
        unsafe {
            self.base.data.get(offset).expect("invalid offset")
        }
    }
}
impl<T, S: StorageMut<Elem = T>, L: Layout, I: AsIndex<L>> IndexMut<I> for WeaklyCheckedView<S, L> {
    // Mutable weakly checked indexing: per-dimension bounds checks are
    // skipped, but an out-of-range linear offset still panics ("invalid
    // offset") instead of causing out-of-bounds access.
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        let offset = self.base.layout.offset_unchecked(index.as_index());
        // NOTE(review): `StorageMut::get_mut` appears to be an unsafe fn
        // returning `Option`; confirm its exact contract in `crate::storage`.
        unsafe {
            self.base.data.get_mut(offset).expect("invalid offset")
        }
    }
}
#[cfg(test)]
mod tests;