use anyhow::{anyhow, bail, Result};
use dry::macro_for;
#[cfg(feature = "serde")]
use krnl::buffer::{CowBuffer, ScalarCowBuffer};
#[cfg(feature = "device")]
use krnl::krnl_core::half::bf16;
#[cfg(doc)]
use krnl::{buffer::ArcBuffer, device::error::DeviceLost};
use krnl::{
buffer::{
ArcBufferRepr, Buffer, BufferBase, BufferRepr, CowBufferRepr, Data, DataMut, DataOwned,
ScalarArcBufferRepr, ScalarBuffer, ScalarBufferBase, ScalarBufferRepr, ScalarCowBufferRepr,
ScalarData, ScalarDataMut, ScalarDataOwned, ScalarSlice, ScalarSliceMut,
ScalarSliceMutRepr, ScalarSliceRepr, Slice, SliceMut, SliceMutRepr, SliceRepr,
},
device::Device,
scalar::{Scalar, ScalarElem, ScalarType},
};
use ndarray::{
Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Dimension, IntoDimension, Ix0, Ix1, Ix2, Ix3,
Ix4, Ix5, Ix6, IxDyn, RawArrayView, RemoveAxis, ShapeError, StrideShape,
};
#[cfg(feature = "device")]
use num_traits::ToPrimitive;
use paste::paste;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::{self, Debug};
mod linalg;
mod ops;
pub(crate) mod parallel;
mod reduce;
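/// Copies the strides of `array` into a `D`-typed dimension.
///
/// The signed strides are stored as `usize`; callers reinterpret them as
/// `isize` with `bytemuck::cast_slice` where a signed view is needed.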
fn strides_from_array<S, D>(array: &ArrayBase<S, D>) -> D
where
S: ndarray::RawData,
D: Dimension,
{
let strides_slice: &[usize] = bytemuck::cast_slice(array.strides());
let mut strides = D::zeros(strides_slice.len());
for (i, s) in strides_slice.iter().copied().enumerate() {
strides[i] = s;
}
strides
}
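/// Computes the dimension and strides described by `shape` without allocating,
/// by building a `RawArrayView` of `()` elements just to let ndarray resolve
/// the requested (C or Fortran) ordering.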
fn dim_strides_from_shape<D: Dimension>(shape: impl Into<StrideShape<D>>) -> (D, D) {
let array = unsafe { RawArrayView::from_shape_ptr(shape, &()) };
let dim = array.raw_dim();
let strides = strides_from_array(&array);
(dim, strides)
}
fn into_dimensionality<D1, D2>(dim: &D1, strides: &D1) -> Result<(D2, D2), ShapeError>
where
D1: Dimension,
D2: Dimension,
{
D2::from_dimension(dim)
.and_then(|dim| D2::from_dimension(strides).map(|strides| (dim, strides)))
.ok_or(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape))
}
fn into_shape<D1, E>(dim: &D1, strides: &D1, shape: E) -> Result<(E::Dim, E::Dim), ShapeError>
where
D1: Dimension,
E: IntoDimension,
{
use ndarray::ErrorKind;
let shape = shape.into_dimension();
if size_of_shape_checked(&shape)? != dim.size() {
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
} else if is_standard_layout(dim, strides) {
let strides = shape.default_strides();
Ok((shape, strides))
} else if is_fortran_layout(dim, strides) {
let strides = shape.fortran_strides();
Ok((shape, strides))
} else {
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout))
}
}
pub(crate) fn flatten(shape: &[usize]) -> [usize; 2] {
let mut iter = shape.iter().copied();
let rows = iter.next().unwrap_or(1);
let cols = iter.product();
[rows, cols]
}
fn is_contiguous<D: Dimension>(dim: &D, strides: &D) -> bool {
is_standard_layout(dim, strides) || is_fortran_layout(dim, strides)
}
fn is_standard_layout<D: Dimension>(dim: &D, strides: &D) -> bool {
debug_assert_eq!(dim.ndim(), strides.ndim());
for d in dim.slice().iter().copied() {
if d == 0 {
return true;
}
}
let mut acc = 1isize;
let strides: &[isize] = bytemuck::cast_slice(strides.slice());
for (d, s) in dim
.slice()
.iter()
.copied()
.zip(strides.iter().copied())
.rev()
{
if !(d == 1 || s == acc) {
return false;
}
acc *= d as isize;
}
true
}
fn is_fortran_layout<D: Dimension>(dim: &D, strides: &D) -> bool {
debug_assert_eq!(dim.ndim(), strides.ndim());
for d in dim.slice().iter().copied() {
if d == 0 {
return true;
}
}
let mut acc = 1;
for (d, s) in dim
.slice()
.iter()
.copied()
.zip(strides.slice().iter().copied())
{
if !(d == 1 || s == acc) {
return false;
}
acc *= d;
}
true
}
fn permuted_axes<D: Dimension>(dim: D, strides: D, axes: D) -> (D, D) {
let mut usage_counts = D::zeros(dim.ndim());
for axis in axes.slice() {
usage_counts[*axis] += 1;
}
for count in usage_counts.slice() {
assert_eq!(*count, 1, "each axis must be listed exactly once");
}
    let mut new_dim = usage_counts;
    let mut new_strides = D::zeros(dim.ndim());
{
let dim = dim.slice();
let strides = strides.slice();
for (new_axis, &axis) in axes.slice().iter().enumerate() {
new_dim[new_axis] = dim[axis];
new_strides[new_axis] = strides[axis];
}
}
(new_dim, new_strides)
}
fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError> {
use ndarray::ErrorKind;
let size_nonzero = dim
.slice()
.iter()
.filter(|&&d| d != 0)
.try_fold(1usize, |acc, &d| acc.checked_mul(d))
.ok_or_else(|| ShapeError::from_kind(ErrorKind::Overflow))?;
if size_nonzero > isize::MAX as usize {
Err(ShapeError::from_kind(ErrorKind::Overflow))
} else {
Ok(dim.size())
}
}
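/// Broadcasts `from`/`strides` to the target dimension `dim`, following
/// ndarray's rules: axes of length 1 and newly added leading axes get a
/// stride of 0. Returns `None` if the shapes are incompatible.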
fn broadcast<D: Dimension, E: IntoDimension>(
from: &D,
strides: &D,
dim: E,
) -> Option<(E::Dim, E::Dim)> {
fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
let _ = size_of_shape_checked(to).ok()?;
let mut new_stride = to.clone();
if to.ndim() < from.ndim() {
return None;
}
{
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
for ((er, es), dr) in from
.slice()
.iter()
.rev()
.zip(stride.slice().iter().rev())
.zip(new_stride_iter.by_ref())
{
if *dr == *er {
*dr = *es;
} else if *er == 1 {
*dr = 0
} else {
return None;
}
}
for dr in new_stride_iter {
*dr = 0;
}
}
Some(new_stride)
}
let dim = dim.into_dimension();
    let broadcast_strides = upcast(&dim, from, strides)?;
Some((dim, broadcast_strides))
}
fn collapse_axis<D: Dimension>(dims: &mut D, strides: &D, Axis(axis): Axis, index: usize) -> isize {
let dim = dims[axis];
assert!(index < dim);
dims.slice_mut()[axis] = 1;
index as isize * strides[axis] as isize
}
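/// Returns the buffer length required by a view with the given `offset`,
/// `shape`, and `strides`: `Some(0)` for an empty shape, `None` if any stride
/// is negative, otherwise one past the index of the last addressed element.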
fn tensor_buffer_len(offset: usize, shape: &[usize], strides: &[isize]) -> Option<usize> {
if shape.iter().any(|x| *x == 0) {
Some(0)
} else if strides.iter().any(|x| *x < 0) {
None
} else {
let len = (shape
.iter()
.zip(strides)
.map(|(d, s)| (*d as isize - 1) * *s)
.sum::<isize>()
+ offset as isize
+ 1)
.try_into()
.unwrap();
Some(len)
}
}
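/// A dynamically typed tensor: a [`ScalarBufferBase`] paired with a dimension,
/// strides (in elements), and an offset into the buffer.
///
/// The element type is only known at runtime (see [`ScalarType`]); the
/// statically typed counterpart is [`TensorBase`].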
#[derive(Clone)]
pub struct ScalarTensorBase<S: ScalarData, D: Dimension> {
dim: D,
strides: D,
buffer: ScalarBufferBase<S>,
offset: usize,
}
pub type ScalarTensor<D> = ScalarTensorBase<ScalarBufferRepr, D>;
pub type ScalarTensor0 = ScalarTensor<Ix0>;
pub type ScalarTensor1 = ScalarTensor<Ix1>;
pub type ScalarTensor2 = ScalarTensor<Ix2>;
pub type ScalarTensor3 = ScalarTensor<Ix3>;
pub type ScalarTensor4 = ScalarTensor<Ix4>;
pub type ScalarTensor5 = ScalarTensor<Ix5>;
pub type ScalarTensor6 = ScalarTensor<Ix6>;
pub type ScalarTensorD = ScalarTensor<IxDyn>;
pub type ScalarArcTensor<D> = ScalarTensorBase<ScalarArcBufferRepr, D>;
pub type ScalarArcTensor0 = ScalarArcTensor<Ix0>;
pub type ScalarArcTensor1 = ScalarArcTensor<Ix1>;
pub type ScalarArcTensor2 = ScalarArcTensor<Ix2>;
pub type ScalarArcTensor3 = ScalarArcTensor<Ix3>;
pub type ScalarArcTensor4 = ScalarArcTensor<Ix4>;
pub type ScalarArcTensor5 = ScalarArcTensor<Ix5>;
pub type ScalarArcTensor6 = ScalarArcTensor<Ix6>;
pub type ScalarArcTensorD = ScalarArcTensor<IxDyn>;
pub type ScalarTensorView<'a, D> = ScalarTensorBase<ScalarSliceRepr<'a>, D>;
pub type ScalarTensorView0<'a> = ScalarTensorView<'a, Ix0>;
pub type ScalarTensorView1<'a> = ScalarTensorView<'a, Ix1>;
pub type ScalarTensorView2<'a> = ScalarTensorView<'a, Ix2>;
pub type ScalarTensorView3<'a> = ScalarTensorView<'a, Ix3>;
pub type ScalarTensorView4<'a> = ScalarTensorView<'a, Ix4>;
pub type ScalarTensorView5<'a> = ScalarTensorView<'a, Ix5>;
pub type ScalarTensorView6<'a> = ScalarTensorView<'a, Ix6>;
pub type ScalarTensorViewD<'a> = ScalarTensorView<'a, IxDyn>;
pub type ScalarTensorViewMut<'a, D> = ScalarTensorBase<ScalarSliceMutRepr<'a>, D>;
pub type ScalarTensorViewMut0<'a> = ScalarTensorViewMut<'a, Ix0>;
pub type ScalarTensorViewMut1<'a> = ScalarTensorViewMut<'a, Ix1>;
pub type ScalarTensorViewMut2<'a> = ScalarTensorViewMut<'a, Ix2>;
pub type ScalarTensorViewMut3<'a> = ScalarTensorViewMut<'a, Ix3>;
pub type ScalarTensorViewMut4<'a> = ScalarTensorViewMut<'a, Ix4>;
pub type ScalarTensorViewMut5<'a> = ScalarTensorViewMut<'a, Ix5>;
pub type ScalarTensorViewMut6<'a> = ScalarTensorViewMut<'a, Ix6>;
pub type ScalarTensorViewMutD<'a> = ScalarTensorViewMut<'a, IxDyn>;
pub type ScalarCowTensor<'a, D> = ScalarTensorBase<ScalarCowBufferRepr<'a>, D>;
pub type ScalarCowTensor0<'a> = ScalarCowTensor<'a, Ix0>;
pub type ScalarCowTensor1<'a> = ScalarCowTensor<'a, Ix1>;
pub type ScalarCowTensor2<'a> = ScalarCowTensor<'a, Ix2>;
pub type ScalarCowTensor3<'a> = ScalarCowTensor<'a, Ix3>;
pub type ScalarCowTensor4<'a> = ScalarCowTensor<'a, Ix4>;
pub type ScalarCowTensor5<'a> = ScalarCowTensor<'a, Ix5>;
pub type ScalarCowTensor6<'a> = ScalarCowTensor<'a, Ix6>;
pub type ScalarCowTensorD<'a> = ScalarCowTensor<'a, IxDyn>;
impl<S: ScalarDataOwned, D: Dimension> ScalarTensorBase<S, D> {
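    /// Creates a tensor with uninitialized data.
    ///
    /// # Safety
    ///
    /// The elements are not initialized; they must be written (for example via
    /// `assign` or a kernel) before being read.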
pub unsafe fn uninit<Sh>(device: Device, shape: Sh, scalar_type: ScalarType) -> Result<Self>
where
Sh: ndarray::ShapeBuilder<Dim = D>,
{
let (dim, strides) = dim_strides_from_shape(shape.into_shape());
let buffer = unsafe { ScalarBufferBase::uninit(device, dim.size(), scalar_type)? };
Ok(Self {
dim,
strides,
buffer,
offset: 0,
})
}
pub fn from_elem<Sh>(device: Device, shape: Sh, elem: ScalarElem) -> Result<Self>
where
Sh: ndarray::ShapeBuilder<Dim = D>,
{
let (dim, strides) = dim_strides_from_shape(shape.into_shape());
let buffer = ScalarBufferBase::from_elem(device, dim.size(), elem)?;
Ok(Self {
dim,
strides,
buffer,
offset: 0,
})
}
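    /// Creates a tensor filled with zeros of the given `scalar_type`.
    ///
    /// A minimal sketch (not compiled as a doctest), assuming a host device and
    /// the `ScalarType::F32` variant from `krnl`:
    ///
    /// ```ignore
    /// let t = ScalarTensor2::zeros(Device::host(), [2, 3], ScalarType::F32)?;
    /// assert_eq!(t.shape(), &[2, 3]);
    /// ```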
pub fn zeros<Sh>(device: Device, shape: Sh, scalar_type: ScalarType) -> Result<Self>
where
Sh: ndarray::ShapeBuilder<Dim = D>,
{
Self::from_elem(device, shape, ScalarElem::zero(scalar_type))
}
pub fn ones<Sh>(device: Device, shape: Sh, scalar_type: ScalarType) -> Result<Self>
where
Sh: ndarray::ShapeBuilder<Dim = D>,
{
Self::from_elem(device, shape, ScalarElem::one(scalar_type))
}
}
impl<S: ScalarData, D: Dimension> ScalarTensorBase<S, D> {
pub fn device(&self) -> Device {
self.buffer.device()
}
pub fn scalar_type(&self) -> ScalarType {
self.buffer.scalar_type()
}
pub fn dim(&self) -> D::Pattern {
self.dim.clone().into_pattern()
}
pub fn raw_dim(&self) -> D {
self.dim.clone()
}
pub fn shape(&self) -> &[usize] {
self.dim.slice()
}
pub fn strides(&self) -> &[isize] {
bytemuck::cast_slice(self.strides.slice())
}
pub fn len(&self) -> usize {
self.dim.size()
}
pub fn is_empty(&self) -> bool {
self.shape().iter().any(|x| *x == 0)
}
pub fn ndim(&self) -> usize {
self.dim.ndim()
}
pub fn into_dimensionality<D2>(self) -> Result<ScalarTensorBase<S, D2>, ShapeError>
where
D2: Dimension,
{
let (dim, strides) = into_dimensionality(&self.dim, &self.strides)?;
Ok(ScalarTensorBase {
dim,
strides,
buffer: self.buffer,
offset: self.offset,
})
}
pub fn into_dyn(self) -> ScalarTensorBase<S, IxDyn> {
ScalarTensorBase {
dim: self.dim.into_dyn(),
strides: self.strides.into_dyn(),
buffer: self.buffer,
offset: self.offset,
}
}
pub fn into_shape<E>(self, shape: E) -> Result<ScalarTensorBase<S, E::Dim>, ShapeError>
where
E: IntoDimension,
{
let shape = shape.into_dimension();
let (dim, strides) = into_shape(&self.dim, &self.strides, shape)?;
assert_eq!(self.offset, 0);
Ok(ScalarTensorBase {
dim,
strides,
buffer: self.buffer,
offset: self.offset,
})
}
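    /// Broadcasts the tensor into `dim`, returning a view, or `None` if the
    /// shapes are incompatible. Broadcast axes get a stride of 0, so the view
    /// aliases the existing elements.
    ///
    /// An illustrative sketch (not compiled as a doctest), assuming a host
    /// device and `ScalarType::F32`:
    ///
    /// ```ignore
    /// let a = ScalarTensor1::ones(Device::host(), 3, ScalarType::F32)?;
    /// let b = a.broadcast([2, 3]).unwrap();
    /// assert_eq!(b.shape(), &[2, 3]);
    /// ```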
pub fn broadcast<E>(&self, dim: E) -> Option<ScalarTensorView<E::Dim>>
where
E: IntoDimension,
{
let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
Some(ScalarTensorView {
dim,
strides,
buffer: self.buffer.as_scalar_slice(),
offset: self.offset,
})
}
pub fn view(&self) -> ScalarTensorView<D> {
ScalarTensorView {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.as_scalar_slice(),
offset: self.offset,
}
}
pub fn view_mut(&mut self) -> ScalarTensorViewMut<D>
where
S: ScalarDataMut,
{
ScalarTensorViewMut {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.as_scalar_slice_mut(),
offset: self.offset,
}
}
pub fn get_view_mut(&mut self) -> Option<ScalarTensorViewMut<D>> {
if self.offset == 0 && self.is_contiguous() {
let buffer = self.buffer.get_scalar_slice_mut()?;
Some(ScalarTensorViewMut {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer,
offset: 0,
})
} else {
None
}
}
pub fn make_view_mut(&mut self) -> Result<ScalarTensorViewMut<D>>
where
S: ScalarDataOwned,
{
if self.offset == 0 && self.is_contiguous() {
Ok(ScalarTensorViewMut {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.make_scalar_slice_mut()?,
offset: 0,
})
} else {
let tensor = self.to_owned()?;
*self = Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: ScalarBufferBase::from_scalar_buffer(tensor.buffer),
offset: 0,
};
Ok(ScalarTensorViewMut {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.get_scalar_slice_mut().unwrap(),
offset: 0,
})
}
}
pub fn is_contiguous(&self) -> bool {
is_contiguous(&self.dim, &self.strides)
}
pub fn is_standard_layout(&self) -> bool {
is_standard_layout(&self.dim, &self.strides)
}
pub fn permuted_axes<A>(self, axes: A) -> Self
where
A: IntoDimension<Dim = D>,
{
let (dim, strides) = permuted_axes(self.dim, self.strides, axes.into_dimension());
Self {
dim,
strides,
..self
}
}
pub fn reversed_axes(mut self) -> Self {
self.dim.slice_mut().reverse();
self.strides.slice_mut().reverse();
self
}
pub fn t(&self) -> ScalarTensorView<D> {
self.view().reversed_axes()
}
pub fn index_axis(&self, axis: Axis, index: usize) -> ScalarTensorView<D::Smaller>
where
D: RemoveAxis,
{
self.view().index_axis_into(axis, index)
}
pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> ScalarTensorViewMut<D::Smaller>
where
S: ScalarDataMut,
D: RemoveAxis,
{
self.view_mut().index_axis_into(axis, index)
}
pub fn index_axis_into(mut self, axis: Axis, index: usize) -> ScalarTensorBase<S, D::Smaller>
where
D: RemoveAxis,
{
self.collapse_axis(axis, index);
let dim = self.dim.remove_axis(axis);
let strides = self.strides.remove_axis(axis);
ScalarTensorBase {
dim,
strides,
buffer: self.buffer,
offset: self.offset,
}
}
pub fn collapse_axis(&mut self, axis: Axis, index: usize) {
let offset =
collapse_axis(&mut self.dim, &self.strides, axis, index) + self.offset as isize;
debug_assert!(offset >= 0);
self.offset = offset as usize;
debug_assert!(self.offset < self.buffer.len());
}
pub fn as_scalar_slice(&self) -> Option<ScalarSlice> {
if self.is_standard_layout() {
let (slice, _offset) = self.as_raw_scalar_slice_offset();
Some(slice)
} else {
None
}
}
pub fn as_scalar_slice_memory_order(&self) -> Option<ScalarSlice> {
if self.is_contiguous() {
let (slice, _offset) = self.as_raw_scalar_slice_offset();
Some(slice)
} else {
None
}
}
pub fn as_scalar_slice_mut(&mut self) -> Option<ScalarSliceMut>
where
S: ScalarDataMut,
{
if self.is_standard_layout() {
let (slice, _offset) = self.as_raw_scalar_slice_offset_mut();
Some(slice)
} else {
None
}
}
pub fn as_scalar_slice_memory_order_mut(&mut self) -> Option<ScalarSliceMut>
where
S: ScalarDataMut,
{
if self.is_contiguous() {
let (slice, _offset) = self.as_raw_scalar_slice_offset_mut();
Some(slice)
} else {
None
}
}
pub fn as_raw_scalar_slice_offset(&self) -> (ScalarSlice, usize) {
let strides: &[isize] = Self::strides(self);
if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
let slice = self.buffer.slice(self.offset..self.offset + len).unwrap();
(slice, 0)
} else {
(self.buffer.as_scalar_slice(), self.offset)
}
}
pub fn as_raw_scalar_slice_offset_mut(&mut self) -> (ScalarSliceMut, usize)
where
S: ScalarDataMut,
{
let strides: &[isize] = Self::strides(self);
if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
let slice = self
.buffer
.slice_mut(self.offset..self.offset + len)
.unwrap();
(slice, 0)
} else {
(self.buffer.as_scalar_slice_mut(), self.offset)
}
}
pub fn into_device(self, device: Device) -> Result<ScalarTensor<D>> {
if self.device() == device {
self.into_owned()
} else if let Some(slice) = self.as_scalar_slice_memory_order() {
let buffer = slice.to_device(device)?;
Ok(ScalarTensor {
dim: self.dim,
strides: self.strides,
buffer,
offset: 0,
})
} else {
self.into_owned()?.into_device(device)
}
}
pub fn to_device(&self, device: Device) -> Result<ScalarTensor<D>> {
if self.device() == device {
self.to_owned()
} else {
self.view().into_device(device)
}
}
pub fn to_device_mut(&mut self, device: Device) -> Result<()>
where
S: ScalarDataOwned,
{
if self.device() == device {
return Ok(());
}
let ScalarTensor {
dim,
strides,
buffer,
offset,
} = self.to_device(device)?;
*self = Self {
dim,
strides,
buffer: ScalarBufferBase::from_scalar_buffer(buffer),
offset,
};
Ok(())
}
pub fn into_device_shared(self, device: Device) -> Result<ScalarArcTensor<D>> {
if self.device() == device {
self.into_shared()
} else {
self.to_device(device).map(Into::into)
}
}
pub fn to_device_shared(&self, device: Device) -> Result<ScalarArcTensor<D>> {
if device == self.device() {
self.to_shared()
} else {
self.to_device(device).map(Into::into)
}
}
pub fn into_owned(self) -> Result<ScalarTensor<D>> {
if self.offset == 0 && self.is_contiguous() {
return Ok(ScalarTensorBase {
dim: self.dim,
strides: self.strides,
buffer: self.buffer.into_owned()?,
offset: 0,
});
}
if let Some(slice) = self.as_scalar_slice_memory_order() {
let buffer = slice.to_owned()?;
return Ok(ScalarTensorBase {
dim: self.dim,
strides: self.strides,
buffer,
offset: 0,
});
}
let mut output =
unsafe { ScalarTensor::uninit(self.device(), self.raw_dim(), self.scalar_type())? };
output.assign(&self)?;
Ok(output)
}
pub fn to_owned(&self) -> Result<ScalarTensor<D>> {
self.view().into_owned()
}
pub fn into_shared(self) -> Result<ScalarArcTensor<D>> {
if self.offset == 0 && self.is_contiguous() {
Ok(ScalarTensorBase {
dim: self.dim,
strides: self.strides,
buffer: self.buffer.into_shared()?,
offset: 0,
})
} else {
self.as_standard_layout()?.into_shared()
}
}
pub fn to_shared(&self) -> Result<ScalarArcTensor<D>> {
if !self.is_contiguous() {
return self.as_standard_layout()?.to_shared();
}
Ok(ScalarTensorBase {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.to_shared()?,
offset: 0,
})
}
}
impl<D: Dimension> ScalarTensor<D> {
pub fn try_into_tensor<T: Scalar>(self) -> Result<Tensor<T, D>, Self> {
self.try_into()
}
}
impl<D: Dimension> ScalarArcTensor<D> {
pub fn try_into_arc_tensor<T: Scalar>(self) -> Result<ArcTensor<T, D>, Self> {
self.try_into()
}
}
impl<'a, D: Dimension> ScalarTensorView<'a, D> {
pub fn try_into_tensor_view<T: Scalar>(self) -> Result<TensorView<'a, T, D>, Self> {
self.try_into()
}
}
impl<'a, D: Dimension> ScalarTensorViewMut<'a, D> {
pub fn try_into_tensor_view_mut<T: Scalar>(self) -> Result<TensorViewMut<'a, T, D>, Self> {
self.try_into()
}
}
impl<D: Dimension> ScalarArcTensor<D> {
pub fn broadcast_shared<E>(&self, dim: E) -> Option<ScalarArcTensor<E::Dim>>
where
E: IntoDimension,
{
let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
Some(ScalarArcTensor {
dim,
strides,
buffer: self.buffer.clone(),
offset: self.offset,
})
}
}
impl<S: ScalarDataOwned> From<ScalarBuffer> for ScalarTensorBase<S, Ix1> {
fn from(buffer: ScalarBuffer) -> Self {
let dim = buffer.len().into_dimension();
let strides = dim.default_strides();
let buffer = ScalarBufferBase::from_scalar_buffer(buffer);
Self {
dim,
strides,
buffer,
offset: 0,
}
}
}
impl<S: ScalarDataOwned, T: Scalar, D: Dimension> From<Tensor<T, D>> for ScalarTensorBase<S, D> {
fn from(tensor: Tensor<T, D>) -> Self {
Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: tensor.buffer.into(),
offset: tensor.offset,
}
}
}
impl<D: Dimension> From<ScalarTensor<D>> for ScalarArcTensor<D> {
fn from(tensor: ScalarTensor<D>) -> Self {
Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: tensor.buffer.into(),
offset: tensor.offset,
}
}
}
impl<T: Scalar, D: Dimension> From<ArcTensor<T, D>> for ScalarArcTensor<D> {
fn from(tensor: ArcTensor<T, D>) -> Self {
Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: tensor.buffer.into(),
offset: tensor.offset,
}
}
}
impl<D: Dimension> From<ScalarTensor<D>> for ScalarCowTensor<'_, D> {
fn from(tensor: ScalarTensor<D>) -> Self {
Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: tensor.buffer.into(),
offset: tensor.offset,
}
}
}
impl<'a, D: Dimension> From<ScalarTensorView<'a, D>> for ScalarCowTensor<'a, D> {
fn from(tensor: ScalarTensorView<'a, D>) -> Self {
Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: tensor.buffer.into(),
offset: tensor.offset,
}
}
}
macro_for!($Tensor in [Tensor, ArcTensor] {
paste! {
impl<T: Scalar, D: Dimension> TryFrom<[<Scalar $Tensor>]<D>> for $Tensor<T, D> {
type Error = [<Scalar $Tensor>]<D>;
fn try_from(tensor: [<Scalar $Tensor>]<D>) -> Result<Self, Self::Error> {
match tensor.buffer.try_into() {
Ok(buffer) => Ok(Self {
dim: tensor.dim,
strides: tensor.strides,
buffer,
offset: tensor.offset,
}),
Err(buffer) => Err(Self::Error {
dim: tensor.dim,
strides: tensor.strides,
buffer,
offset: tensor.offset,
})
}
}
}
}
});
macro_for!($Tensor in [TensorView, TensorViewMut, CowTensor] {
paste! {
impl<'a, T: Scalar, D: Dimension> From<$Tensor<'a, T, D>> for [<Scalar $Tensor>]<'a, D> {
fn from(tensor: $Tensor<'a, T, D>) -> Self {
Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: tensor.buffer.into(),
offset: tensor.offset,
}
}
}
impl<'a, T: Scalar, D: Dimension> TryFrom<[<Scalar $Tensor>]<'a, D>> for $Tensor<'a, T, D> {
type Error = [<Scalar $Tensor>]<'a, D>;
fn try_from(tensor: [<Scalar $Tensor>]<'a, D>) -> Result<Self, Self::Error> {
match tensor.buffer.try_into() {
Ok(buffer) => Ok(Self {
dim: tensor.dim,
strides: tensor.strides,
buffer,
offset: tensor.offset,
}),
Err(buffer) => Err(Self::Error {
dim: tensor.dim,
strides: tensor.strides,
buffer,
offset: tensor.offset,
})
}
}
}
}
});
impl<S: ScalarData, D: Dimension> Debug for ScalarTensorBase<S, D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut builder = f.debug_struct("TensorBase");
builder
.field("device", &self.device())
.field("scalar_type", &self.scalar_type())
.field("shape", &self.shape());
if self.strides != self.dim.default_strides() {
builder.field("strides", &self.strides());
}
if self.offset > 0 {
builder.field("offset", &self.offset);
}
builder.finish()
}
}
impl<S: ScalarData, D: Dimension> ScalarTensorBase<S, D> {
pub fn cast_into(self, scalar_type: ScalarType) -> Result<ScalarTensor<D>> {
if self.scalar_type() == scalar_type {
self.into_owned()
} else {
self.cast(scalar_type)
}
}
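    /// Casts into a new tensor of `scalar_type`, converting the data.
    ///
    /// An illustrative sketch (not compiled as a doctest), assuming a host
    /// device and the `F32`/`BF16` variants of [`ScalarType`]:
    ///
    /// ```ignore
    /// let x = ScalarTensor1::ones(Device::host(), 4, ScalarType::F32)?;
    /// let y = x.cast(ScalarType::BF16)?;
    /// assert_eq!(y.scalar_type(), ScalarType::BF16);
    /// ```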
pub fn cast(&self, scalar_type: ScalarType) -> Result<ScalarTensor<D>> {
if self.scalar_type() == scalar_type {
self.to_owned()
} else if !self.is_contiguous() {
self.scaled_cast(ScalarElem::one(scalar_type))
} else {
Ok(ScalarTensorBase {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.cast(scalar_type)?,
offset: 0,
})
}
}
pub fn cast_mut(&mut self, scalar_type: ScalarType) -> Result<()>
where
S: ScalarDataOwned,
{
if self.scalar_type() == scalar_type {
return Ok(());
}
let ScalarTensor {
dim,
strides,
buffer,
offset,
} = self.cast(scalar_type)?;
*self = Self {
dim,
strides,
buffer: ScalarBufferBase::from_scalar_buffer(buffer),
offset,
};
Ok(())
}
pub fn cast_into_tensor<T: Scalar>(self) -> Result<Tensor<T, D>> {
Ok(self.cast_into(T::SCALAR_TYPE)?.try_into().unwrap())
}
}
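// Serialization proxy: a tensor is serialized as its dimension plus its buffer.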
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
#[serde(bound(
serialize = "S: ScalarData, D: Dimension + Serialize",
deserialize = "S: ScalarDataOwned, D: Dimension + Deserialize<'de>"
))]
#[serde(rename = "Tensor")]
struct ScalarTensorSerde<S: ScalarData, D: Dimension> {
dim: D,
buffer: ScalarBufferBase<S>,
}
#[cfg(feature = "serde")]
impl<S1: ScalarData, D: Dimension + Serialize> Serialize for ScalarTensorBase<S1, D> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use serde::ser::Error;
let buffer = if let Some(slice) = self.as_scalar_slice() {
ScalarCowBuffer::from(slice)
} else {
self.to_device(Device::host())
.map_err(S::Error::custom)?
.buffer
.into()
};
ScalarTensorSerde {
dim: self.dim.clone(),
buffer,
}
.serialize(serializer)
}
}
#[cfg(feature = "serde")]
impl<'de, S: ScalarDataOwned, D1: Dimension + Deserialize<'de>> Deserialize<'de>
for ScalarTensorBase<S, D1>
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error;
let ScalarTensorSerde { dim, buffer } =
ScalarTensorSerde::<ScalarBufferRepr, D1>::deserialize(deserializer)?;
ScalarTensorBase::from(buffer)
.into_shape(dim)
.map_err(D::Error::custom)
}
}
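/// The statically typed counterpart of [`ScalarTensorBase`]: a [`BufferBase`]
/// paired with a dimension, strides (in elements), and an offset into the
/// buffer. The element type `T` is fixed by the storage `S`.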
#[derive(Clone)]
pub struct TensorBase<S: Data, D: Dimension> {
dim: D,
strides: D,
buffer: BufferBase<S>,
offset: usize,
}
pub type Tensor<T, D> = TensorBase<BufferRepr<T>, D>;
pub type Tensor0<T> = Tensor<T, Ix0>;
pub type Tensor1<T> = Tensor<T, Ix1>;
pub type Tensor2<T> = Tensor<T, Ix2>;
pub type Tensor3<T> = Tensor<T, Ix3>;
pub type Tensor4<T> = Tensor<T, Ix4>;
pub type Tensor5<T> = Tensor<T, Ix5>;
pub type Tensor6<T> = Tensor<T, Ix6>;
pub type TensorD<T> = Tensor<T, IxDyn>;
pub type ArcTensor<T, D> = TensorBase<ArcBufferRepr<T>, D>;
pub type ArcTensor0<T> = ArcTensor<T, Ix0>;
pub type ArcTensor1<T> = ArcTensor<T, Ix1>;
pub type ArcTensor2<T> = ArcTensor<T, Ix2>;
pub type ArcTensor3<T> = ArcTensor<T, Ix3>;
pub type ArcTensor4<T> = ArcTensor<T, Ix4>;
pub type ArcTensor5<T> = ArcTensor<T, Ix5>;
pub type ArcTensor6<T> = ArcTensor<T, Ix6>;
pub type ArcTensorD<T> = ArcTensor<T, IxDyn>;
pub type TensorView<'a, T, D> = TensorBase<SliceRepr<'a, T>, D>;
pub type TensorView0<'a, T> = TensorView<'a, T, Ix0>;
pub type TensorView1<'a, T> = TensorView<'a, T, Ix1>;
pub type TensorView2<'a, T> = TensorView<'a, T, Ix2>;
pub type TensorView3<'a, T> = TensorView<'a, T, Ix3>;
pub type TensorView4<'a, T> = TensorView<'a, T, Ix4>;
pub type TensorView5<'a, T> = TensorView<'a, T, Ix5>;
pub type TensorView6<'a, T> = TensorView<'a, T, Ix6>;
pub type TensorViewD<'a, T> = TensorView<'a, T, IxDyn>;
pub type TensorViewMut<'a, T, D> = TensorBase<SliceMutRepr<'a, T>, D>;
pub type TensorViewMut0<'a, T> = TensorViewMut<'a, T, Ix0>;
pub type TensorViewMut1<'a, T> = TensorViewMut<'a, T, Ix1>;
pub type TensorViewMut2<'a, T> = TensorViewMut<'a, T, Ix2>;
pub type TensorViewMut3<'a, T> = TensorViewMut<'a, T, Ix3>;
pub type TensorViewMut4<'a, T> = TensorViewMut<'a, T, Ix4>;
pub type TensorViewMut5<'a, T> = TensorViewMut<'a, T, Ix5>;
pub type TensorViewMut6<'a, T> = TensorViewMut<'a, T, Ix6>;
pub type TensorViewMutD<'a, T> = TensorViewMut<'a, T, IxDyn>;
pub type CowTensor<'a, T, D> = TensorBase<CowBufferRepr<'a, T>, D>;
pub type CowTensor0<'a, T> = CowTensor<'a, T, Ix0>;
pub type CowTensor1<'a, T> = CowTensor<'a, T, Ix1>;
pub type CowTensor2<'a, T> = CowTensor<'a, T, Ix2>;
pub type CowTensor3<'a, T> = CowTensor<'a, T, Ix3>;
pub type CowTensor4<'a, T> = CowTensor<'a, T, Ix4>;
pub type CowTensor5<'a, T> = CowTensor<'a, T, Ix5>;
pub type CowTensor6<'a, T> = CowTensor<'a, T, Ix6>;
pub type CowTensorD<'a, T> = CowTensor<'a, T, IxDyn>;
impl<T: Scalar, S: DataOwned<Elem = T>, D: Dimension> TensorBase<S, D> {
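    /// Creates a tensor with uninitialized data.
    ///
    /// # Safety
    ///
    /// The elements are not initialized; they must be written before being read.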
pub unsafe fn uninit<Sh>(device: Device, shape: Sh) -> Result<Self>
where
Sh: ndarray::ShapeBuilder<Dim = D>,
{
let (dim, strides) = dim_strides_from_shape(shape.into_shape());
let buffer = unsafe { BufferBase::uninit(device, dim.size())? };
Ok(Self {
dim,
strides,
buffer,
offset: 0,
})
}
pub fn from_elem<Sh>(device: Device, shape: Sh, elem: T) -> Result<Self>
where
Sh: ndarray::ShapeBuilder<Dim = D>,
{
let (dim, strides) = dim_strides_from_shape(shape.into_shape());
let buffer = BufferBase::from_elem(device, dim.size(), elem)?;
Ok(Self {
dim,
strides,
buffer,
offset: 0,
})
}
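    /// Creates a tensor filled with `T::default()` (zero for all scalar types).
    ///
    /// A minimal sketch (not compiled as a doctest), assuming a host device:
    ///
    /// ```ignore
    /// let t = Tensor2::<f32>::zeros(Device::host(), [2, 3])?;
    /// assert_eq!(t.shape(), &[2, 3]);
    /// assert_eq!(t.len(), 6);
    /// ```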
pub fn zeros<Sh>(device: Device, shape: Sh) -> Result<Self>
where
Sh: ndarray::ShapeBuilder<Dim = D>,
{
Self::from_elem(device, shape, T::default())
}
pub fn ones<Sh>(device: Device, shape: Sh) -> Result<Self>
where
Sh: ndarray::ShapeBuilder<Dim = D>,
{
Self::from_elem(device, shape, T::one())
}
}
impl<T: Scalar, S: Data<Elem = T>, D: Dimension> TensorBase<S, D> {
pub fn device(&self) -> Device {
self.buffer.device()
}
pub fn scalar_type(&self) -> ScalarType {
T::SCALAR_TYPE
}
pub fn dim(&self) -> D::Pattern {
self.dim.clone().into_pattern()
}
pub fn raw_dim(&self) -> D {
self.dim.clone()
}
pub fn shape(&self) -> &[usize] {
self.dim.slice()
}
pub fn strides(&self) -> &[isize] {
bytemuck::cast_slice(self.strides.slice())
}
pub fn len(&self) -> usize {
self.dim.size()
}
pub fn is_empty(&self) -> bool {
self.shape().iter().any(|x| *x == 0)
}
pub fn ndim(&self) -> usize {
self.dim.ndim()
}
pub fn into_dimensionality<D2>(self) -> Result<TensorBase<S, D2>, ShapeError>
where
D2: Dimension,
{
let (dim, strides) = into_dimensionality(&self.dim, &self.strides)?;
Ok(TensorBase {
dim,
strides,
buffer: self.buffer,
offset: self.offset,
})
}
pub fn into_dyn(self) -> TensorBase<S, IxDyn> {
TensorBase {
dim: self.dim.into_dyn(),
strides: self.strides.into_dyn(),
buffer: self.buffer,
offset: self.offset,
}
}
pub fn into_shape<E>(self, shape: E) -> Result<TensorBase<S, E::Dim>, ShapeError>
where
E: IntoDimension,
{
let shape = shape.into_dimension();
let (dim, strides) = into_shape(&self.dim, &self.strides, shape)?;
debug_assert_eq!(self.offset, 0);
Ok(TensorBase {
dim,
strides,
buffer: self.buffer,
offset: self.offset,
})
}
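    /// Reshapes into 2 dimensions, keeping the first axis and collapsing the
    /// remaining axes into the second, e.g. `[2, 3, 4]` becomes `[2, 12]`.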
pub fn flatten(self) -> Result<TensorBase<S, Ix2>, ShapeError> {
let dim = flatten(self.shape());
self.into_shape(dim)
}
pub fn broadcast<E>(&self, dim: E) -> Option<TensorView<T, E::Dim>>
where
E: IntoDimension,
{
let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
Some(TensorView {
dim,
strides,
buffer: self.buffer.as_slice(),
offset: self.offset,
})
}
pub fn view(&self) -> TensorView<T, D> {
TensorView {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.as_slice(),
offset: self.offset,
}
}
pub fn view_mut(&mut self) -> TensorViewMut<T, D>
where
S: DataMut,
{
TensorViewMut {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.as_slice_mut(),
offset: self.offset,
}
}
pub fn get_view_mut(&mut self) -> Option<TensorViewMut<T, D>> {
if self.offset == 0 && self.is_contiguous() {
let buffer = self.buffer.get_slice_mut()?;
Some(TensorViewMut {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer,
offset: 0,
})
} else {
None
}
}
pub fn make_view_mut(&mut self) -> Result<TensorViewMut<T, D>>
where
S: DataOwned,
{
if self.offset == 0 && self.is_contiguous() {
Ok(TensorViewMut {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.make_slice_mut()?,
offset: 0,
})
} else {
let tensor = self.to_owned()?;
*self = Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: BufferBase::from_buffer(tensor.buffer),
offset: 0,
};
Ok(TensorViewMut {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.get_slice_mut().unwrap(),
offset: 0,
})
}
}
pub fn is_contiguous(&self) -> bool {
is_contiguous(&self.dim, &self.strides)
}
pub fn is_standard_layout(&self) -> bool {
is_standard_layout(&self.dim, &self.strides)
}
pub fn permuted_axes<A>(self, axes: A) -> Self
where
A: IntoDimension<Dim = D>,
{
let (dim, strides) = permuted_axes(self.dim, self.strides, axes.into_dimension());
Self {
dim,
strides,
..self
}
}
pub fn reversed_axes(mut self) -> Self {
self.dim.slice_mut().reverse();
self.strides.slice_mut().reverse();
self
}
pub fn t(&self) -> TensorView<T, D> {
self.view().reversed_axes()
}
pub fn index_axis(&self, axis: Axis, index: usize) -> TensorView<T, D::Smaller>
where
D: RemoveAxis,
{
self.view().index_axis_into(axis, index)
}
pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> TensorViewMut<T, D::Smaller>
where
S: DataMut,
D: RemoveAxis,
{
self.view_mut().index_axis_into(axis, index)
}
pub fn index_axis_into(mut self, axis: Axis, index: usize) -> TensorBase<S, D::Smaller>
where
D: RemoveAxis,
{
self.collapse_axis(axis, index);
let dim = self.dim.remove_axis(axis);
let strides = self.strides.remove_axis(axis);
TensorBase {
dim,
strides,
buffer: self.buffer,
offset: self.offset,
}
}
pub fn collapse_axis(&mut self, axis: Axis, index: usize) {
let offset =
collapse_axis(&mut self.dim, &self.strides, axis, index) + self.offset as isize;
debug_assert!(offset >= 0);
let offset = offset as usize;
debug_assert!(offset < self.buffer.len());
self.offset = offset;
}
pub fn as_slice(&self) -> Option<Slice<T>> {
if self.is_standard_layout() {
let (slice, _offset) = self.as_raw_slice_offset();
Some(slice)
} else {
None
}
}
pub fn as_slice_memory_order(&self) -> Option<Slice<T>> {
if self.is_contiguous() {
let (slice, _offset) = self.as_raw_slice_offset();
Some(slice)
} else {
None
}
}
pub fn as_slice_mut(&mut self) -> Option<SliceMut<T>>
where
S: DataMut,
{
if self.is_standard_layout() {
let (slice, _offset) = self.as_raw_slice_offset_mut();
Some(slice)
} else {
None
}
}
pub fn as_slice_memory_order_mut(&mut self) -> Option<SliceMut<T>>
where
S: DataMut,
{
if self.is_contiguous() {
let (slice, _offset) = self.as_raw_slice_offset_mut();
Some(slice)
} else {
None
}
}
pub fn as_raw_slice_offset(&self) -> (Slice<T>, usize) {
let strides: &[isize] = Self::strides(self);
if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
let slice = self.buffer.slice(self.offset..self.offset + len).unwrap();
(slice, 0)
} else {
(self.buffer.as_slice(), self.offset)
}
}
pub fn as_raw_slice_offset_mut(&mut self) -> (SliceMut<T>, usize)
where
S: DataMut,
{
let strides: &[isize] = Self::strides(self);
if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
let slice = self
.buffer
.slice_mut(self.offset..self.offset + len)
.unwrap();
(slice, 0)
} else {
(self.buffer.as_slice_mut(), self.offset)
}
}
pub fn to_device(&self, device: Device) -> Result<Tensor<T, D>> {
if self.device() == device {
self.to_owned()
} else {
self.view().into_device(device)
}
}
pub fn to_device_shared(&self, device: Device) -> Result<ArcTensor<T, D>> {
if self.device() == device {
self.to_shared()
} else {
self.to_device(device).map(Into::into)
}
}
pub fn to_device_mut(&mut self, device: Device) -> Result<()>
where
S: DataOwned,
{
if self.device() == device {
return Ok(());
}
let Tensor {
dim,
strides,
buffer,
offset,
} = self.to_device(device)?;
*self = Self {
dim,
strides,
buffer: BufferBase::from_buffer(buffer),
offset,
};
Ok(())
}
pub fn into_device(self, device: Device) -> Result<Tensor<T, D>> {
if device == self.device() {
self.into_owned()
} else if !self.is_contiguous() {
self.as_standard_layout()?.to_device(device)
} else {
let buffer = self.buffer.to_device(device)?;
Ok(Tensor {
dim: self.dim,
strides: self.strides,
buffer,
offset: 0,
})
}
}
pub fn into_device_shared(self, device: Device) -> Result<ArcTensor<T, D>> {
if device == self.device() {
self.into_shared()
} else if !self.is_contiguous() {
self.view()
.into_standard_layout()?
.into_device_shared(device)
} else {
let buffer = self.buffer.to_device_shared(device)?;
Ok(ArcTensor {
dim: self.dim,
strides: self.strides,
buffer,
offset: 0,
})
}
}
pub fn into_owned(self) -> Result<Tensor<T, D>> {
if !self.is_contiguous() {
return self.into_standard_layout();
}
Ok(TensorBase {
dim: self.dim,
strides: self.strides,
buffer: self.buffer.into_owned()?,
offset: 0,
})
}
pub fn to_owned(&self) -> Result<Tensor<T, D>> {
self.view().into_owned()
}
pub fn into_shared(self) -> Result<ArcTensor<T, D>> {
if self.is_contiguous() {
Ok(TensorBase {
dim: self.dim,
strides: self.strides,
buffer: self.buffer.into_shared()?,
offset: self.offset,
})
} else {
self.as_standard_layout()?.into_shared()
}
}
pub fn to_shared(&self) -> Result<ArcTensor<T, D>> {
if self.is_contiguous() {
Ok(TensorBase {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.to_shared()?,
offset: self.offset,
})
} else {
self.to_owned()?.into_shared()
}
}
pub fn fill(&mut self, elem: T) -> Result<()>
where
S: DataMut,
{
if self.is_contiguous() {
self.buffer.as_slice_mut().fill(elem)
} else if let Some(mut array) = self.as_array_mut() {
array.fill(elem);
Ok(())
} else {
bail!("TensorBase::fill tensor is not contiguous!")
}
}
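    /// Converts into an ndarray [`Array`]. Fails if the tensor is neither
    /// contiguous nor viewable as a host array.
    ///
    /// An illustrative sketch (not compiled as a doctest); `arr2` is
    /// `ndarray::arr2`:
    ///
    /// ```ignore
    /// let array = ndarray::arr2(&[[1.0f32, 2.0], [3.0, 4.0]]);
    /// let tensor = Tensor::from(array.clone());
    /// assert_eq!(tensor.into_array()?, array);
    /// ```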
pub fn into_array(self) -> Result<Array<T, D>> {
if self.is_contiguous() {
use ndarray::ShapeBuilder;
let vec = self.buffer.into_vec()?;
Ok(Array::from_shape_vec(self.dim.strides(self.strides), vec).unwrap())
} else if let Some(array) = self.as_array() {
Ok(array.into_owned())
} else {
bail!("TensorBase::into_array tensor is not contiguous!")
}
}
pub fn as_array(&self) -> Option<ArrayView<T, D>> {
use ndarray::ShapeBuilder;
self.buffer.as_host_slice().map(|host_slice| unsafe {
ArrayView::from_shape_ptr(
self.dim.clone().strides(self.strides.clone()),
&host_slice[self.offset] as *const T,
)
})
}
pub fn as_array_mut(&mut self) -> Option<ArrayViewMut<T, D>>
where
S: DataMut,
{
use ndarray::ShapeBuilder;
if let Some(host_slice) = self.buffer.as_host_slice_mut() {
let host_slice = unsafe {
std::slice::from_raw_parts_mut(host_slice.as_mut_ptr(), host_slice.len())
};
Some(unsafe {
ArrayViewMut::from_shape_ptr(
self.dim.clone().strides(self.strides.clone()),
host_slice[self.offset..].as_mut_ptr(),
)
})
} else {
None
}
}
}
impl<T: Scalar, D: Dimension> Tensor<T, D> {
pub fn into_scalar_tensor(self) -> ScalarTensor<D> {
self.into()
}
}
impl<'a, T: Scalar, D: Dimension> CowTensor<'a, T, D> {
pub fn into_scalar_cow_tensor(self) -> ScalarCowTensor<'a, D> {
self.into()
}
}
impl<T: Scalar, D: Dimension> ArcTensor<T, D> {
pub fn broadcast_shared<E>(&self, dim: E) -> Option<ArcTensor<T, E::Dim>>
where
E: IntoDimension,
{
let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
Some(ArcTensor {
dim,
strides,
buffer: self.buffer.clone(),
offset: self.offset,
})
}
}
impl<T: Scalar, S: DataOwned<Elem = T>> From<Buffer<T>> for TensorBase<S, Ix1> {
fn from(buffer: Buffer<T>) -> Self {
let dim = buffer.len().into_dimension();
let strides = dim.default_strides();
let buffer = BufferBase::from_buffer(buffer);
Self {
dim,
strides,
buffer,
offset: 0,
}
}
}
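/// Creates a 1-dimensional tensor on the host from a [`Vec`].
///
/// An illustrative sketch (not compiled as a doctest):
///
/// ```ignore
/// let t = Tensor1::<u32>::from(vec![1, 2, 3]);
/// assert_eq!(t.len(), 3);
/// ```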
impl<T: Scalar, S: DataOwned<Elem = T>> From<Vec<T>> for TensorBase<S, Ix1> {
fn from(vec: Vec<T>) -> Self {
let dim = vec.len().into_dimension();
let strides = dim.default_strides();
let buffer = BufferBase::from_buffer(Buffer::from(vec));
Self {
dim,
strides,
buffer,
offset: 0,
}
}
}
impl<'a, T: Scalar> From<Slice<'a, T>> for TensorView<'a, T, Ix1> {
fn from(slice: Slice<'a, T>) -> Self {
let dim = slice.len().into_dimension();
let strides = dim.default_strides();
Self {
dim,
strides,
buffer: slice,
offset: 0,
}
}
}
impl<'a, T: Scalar> From<SliceMut<'a, T>> for TensorViewMut<'a, T, Ix1> {
fn from(slice: SliceMut<'a, T>) -> Self {
let dim = slice.len().into_dimension();
let strides = dim.default_strides();
Self {
dim,
strides,
buffer: slice,
offset: 0,
}
}
}
impl<T: Scalar, S: DataOwned<Elem = T>, D: Dimension> From<Array<T, D>> for TensorBase<S, D> {
fn from(array: Array<T, D>) -> Self {
let dim = array.raw_dim();
let strides = strides_from_array(&array);
let buffer = BufferBase::from_vec(array.into_raw_vec());
Self {
dim,
strides,
buffer,
offset: 0,
}
}
}
impl<'a, T: Scalar, D: Dimension> From<ArrayView<'a, T, D>> for CowTensor<'a, T, D> {
fn from(array: ArrayView<'a, T, D>) -> Self {
if let Some(slice) = array.to_slice_memory_order() {
let dim = array.raw_dim();
let strides = strides_from_array(&array);
let buffer = Slice::from(slice).into();
Self {
dim,
strides,
buffer,
offset: 0,
}
} else {
Self::from(array.to_owned())
}
}
}
impl<'a, T: Scalar, D: Dimension> TryFrom<ArrayView<'a, T, D>> for TensorView<'a, T, D> {
type Error = anyhow::Error;
fn try_from(array: ArrayView<'a, T, D>) -> Result<Self> {
let slice = array
.as_slice_memory_order()
.ok_or_else(|| anyhow!("Not contiguous!"))?;
let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) };
let dim = array.raw_dim();
let strides = strides_from_array(&array);
Ok(Self {
dim,
strides,
buffer: slice.into(),
offset: 0,
})
}
}
impl<'a, T: Scalar, D: Dimension> From<TensorView<'a, T, D>> for CowTensor<'a, T, D> {
fn from(view: TensorView<'a, T, D>) -> Self {
Self {
dim: view.dim,
strides: view.strides,
buffer: view.buffer.into(),
offset: view.offset,
}
}
}
impl<T: Scalar, D: Dimension> From<Tensor<T, D>> for CowTensor<'_, T, D> {
fn from(tensor: Tensor<T, D>) -> Self {
Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: tensor.buffer.into(),
offset: tensor.offset,
}
}
}
impl<T: Scalar, D: Dimension> From<Tensor<T, D>> for ArcTensor<T, D> {
fn from(tensor: Tensor<T, D>) -> Self {
Self {
dim: tensor.dim,
strides: tensor.strides,
buffer: tensor.buffer.into(),
offset: tensor.offset,
}
}
}
impl<S: Data, D: Dimension> Debug for TensorBase<S, D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
ScalarTensorView::from(self.view()).fmt(f)
}
}
impl<T: Scalar, S: Data<Elem = T>, D: Dimension> TensorBase<S, D> {
pub fn cast_into<Y: Scalar>(self) -> Result<Tensor<Y, D>> {
if T::SCALAR_TYPE == Y::SCALAR_TYPE && self.is_contiguous() {
Ok(TensorBase {
dim: self.dim,
strides: self.strides,
buffer: self.buffer.cast_into()?,
offset: 0,
})
} else {
self.cast()
}
}
pub fn cast<Y: Scalar>(&self) -> Result<Tensor<Y, D>> {
if !self.is_contiguous() {
return self.scaled_cast(Y::one());
}
Ok(TensorBase {
dim: self.dim.clone(),
strides: self.strides.clone(),
buffer: self.buffer.cast()?,
offset: 0,
})
}
}
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
#[serde(bound(
serialize = "S: Data, D: Dimension + Serialize",
deserialize = "S: DataOwned, D: Dimension + Deserialize<'de>"
))]
#[serde(rename = "Tensor")]
struct TensorSerde<S: Data, D: Dimension> {
dim: D,
buffer: BufferBase<S>,
}
#[cfg(feature = "serde")]
impl<S1: Data, D: Dimension + Serialize> Serialize for TensorBase<S1, D> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use serde::ser::Error;
let buffer = if let Some(slice) = self.as_slice() {
CowBuffer::from(slice)
} else {
self.to_device(Device::host())
.map_err(S::Error::custom)?
.buffer
.into()
};
TensorSerde {
dim: self.dim.clone(),
buffer,
}
.serialize(serializer)
}
}
#[cfg(feature = "serde")]
impl<'de, S: DataOwned, D1: Dimension + Deserialize<'de>> Deserialize<'de> for TensorBase<S, D1> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error;
let TensorSerde { dim, buffer } =
TensorSerde::<BufferRepr<S::Elem>, D1>::deserialize(deserializer)?;
TensorBase::from(buffer)
.into_shape(dim)
.map_err(D::Error::custom)
}
}
#[cfg(all(test, feature = "serde"))]
mod tests {
use super::*;
use serde_test::{assert_tokens, Token};
#[test]
fn tensor_serde() {
let data = vec![1u32, 2, 3, 4];
let items: Vec<u64> = bytemuck::cast_slice(data.as_slice()).to_vec();
let tensor = Tensor::from(Buffer::from(data));
let tokens = [
Token::Struct {
name: "Tensor",
len: 2,
},
Token::Str("dim"),
Token::Tuple { len: 1 },
Token::U64(4),
Token::TupleEnd,
Token::Str("buffer"),
Token::TupleStruct {
name: "Buffer",
len: 3,
},
Token::Str("U32"),
Token::U64(4),
Token::Seq { len: Some(2) },
Token::U64(items[0].to_be()),
Token::U64(items[1].to_be()),
Token::SeqEnd,
Token::TupleStructEnd,
Token::StructEnd,
];
#[derive(Debug, Serialize, Deserialize)]
#[serde(transparent)]
struct TensorWrap(Tensor1<u32>);
impl PartialEq for TensorWrap {
fn eq(&self, other: &Self) -> bool {
self.0.as_array().unwrap() == other.0.as_array().unwrap()
}
}
impl Eq for TensorWrap {}
assert_tokens(&TensorWrap(tensor), &tokens);
}
}