use crate::{
device::{
buffer::{
ArcBuffer, Buffer, CowBuffer, ReadGuard as BufferReadGuard, Slice, SliceMut, SliceRepr,
},
Device,
},
error::Error,
glsl_shaders,
result::Result,
scalar::{Scalar, ScalarType, Uint},
util::{elem_type_name, size_eq},
};
use anyhow::{anyhow, bail};
use bytemuck::Pod;
use ndarray::{
Array, ArrayBase, ArrayView, Dimension, IntoDimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6,
IxDyn, RawArrayView, ShapeBuilder, StrideShape,
};
use serde::{Deserialize, Serialize};
use std::{
convert::{TryFrom, TryInto},
fmt::{self, Debug},
mem::{size_of, transmute},
};
mod accuracy;
mod linalg;
mod ops;
mod reduce;
mod reorder;
pub mod float;
/// Private module sealing [`DataBase`] so it cannot be implemented
/// outside this crate.
mod sealed {
    use super::Device;
    /// Operations common to every tensor data representation.
    pub trait DataBase {
        /// The device the underlying buffer lives on.
        #[doc(hidden)]
        fn device(&self) -> Device;
        /// Number of elements in the underlying buffer.
        fn len(&self) -> usize;
        /// Whether the underlying buffer holds no elements.
        #[doc(hidden)]
        fn is_empty(&self) -> bool {
            self.len() == 0
        }
    }
}
pub(crate) use sealed::DataBase;
// Implements `DataBase` for each listed representation type by delegating
// to the wrapped buffer (tuple field `.0`). The optional lifetime matcher
// covers borrowed reprs like `ViewRepr<'_>`.
macro_rules! impl_data_base {
    ($($data:ident $(<$a:lifetime>)?),+) => {
        $(
            impl<T> DataBase for $data <$($a,)? T> {
                fn device(&self) -> Device {
                    self.0.device()
                }
                fn len(&self) -> usize {
                    self.0.len()
                }
                fn is_empty(&self) -> bool {
                    self.0.is_empty()
                }
            }
        )+
    };
}
impl_data_base! {OwnedRepr, ArcRepr, ViewRepr<'_>, ViewMutRepr<'_>, CowRepr<'_>}
/// Trait implemented by all tensor data representations
/// (owned, shared, borrowed, mutably borrowed, copy-on-write).
pub trait Data: Sized + DataBase {
    /// The element type stored in the buffer.
    type Elem;
    /// Consumes `self`, yielding the owned buffer when this representation
    /// owns it exclusively; otherwise returns `self` back unchanged.
    #[doc(hidden)]
    fn try_into_buffer(self) -> Result<Buffer<Self::Elem>, Self> {
        Err(self)
    }
    /// Converts into an owned representation, copying the data only when
    /// the buffer is not exclusively owned.
    #[doc(hidden)]
    fn into_owned(self) -> Result<OwnedRepr<Self::Elem>>
    where
        Self::Elem: Copy,
    {
        match self.try_into_buffer() {
            Ok(buffer) => Ok(OwnedRepr(buffer)),
            Err(this) => Ok(OwnedRepr(this.as_slice().to_owned()?)),
        }
    }
    /// Consumes `self`, yielding a reference-counted buffer without copying
    /// when possible.
    #[doc(hidden)]
    fn try_into_arc_buffer(self) -> Result<ArcBuffer<Self::Elem>, Self> {
        self.try_into_buffer().map(Into::into)
    }
    /// Converts into a shared representation, copying only when the data
    /// cannot be moved.
    #[doc(hidden)]
    fn into_shared(self) -> Result<ArcRepr<Self::Elem>>
    where
        Self::Elem: Copy,
    {
        match self.try_into_arc_buffer() {
            Ok(buffer) => Ok(ArcRepr(buffer)),
            Err(this) => Ok(ArcRepr(this.as_slice().to_owned()?.into())),
        }
    }
    /// Copies the data into a new shared representation. `ArcRepr`
    /// overrides this to clone the refcount instead of the data.
    #[doc(hidden)]
    fn to_shared(&self) -> Result<ArcRepr<Self::Elem>>
    where
        Self::Elem: Copy,
    {
        Ok(ArcRepr(self.as_slice().to_owned()?.into()))
    }
    /// Borrows the underlying data as a slice.
    #[doc(hidden)]
    fn as_slice(&self) -> Slice<Self::Elem>;
}
/// Representations that can be constructed from an owned buffer.
pub trait DataOwned: Data {
    /// Wraps an owned buffer in this representation.
    #[doc(hidden)]
    fn from_buffer(buffer: Buffer<Self::Elem>) -> Self;
}
/// Representations that allow mutable access to the underlying data.
pub trait DataMut: Data {
    /// Borrows the underlying data as a mutable slice.
    #[doc(hidden)]
    fn as_slice_mut(&mut self) -> SliceMut<Self::Elem>;
}
/// Exclusively owned tensor data, backed by a [`Buffer`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound(
    serialize = "T: Pod + Serialize",
    deserialize = "T: Pod + Deserialize<'de>"
))]
pub struct OwnedRepr<T>(pub(crate) Buffer<T>);
impl<T> Data for OwnedRepr<T> {
    type Elem = T;
    // Owned data always converts to a buffer without copying.
    fn try_into_buffer(self) -> Result<Buffer<T>, Self> {
        Ok(self.0)
    }
    fn as_slice(&self) -> Slice<T> {
        self.0.as_slice()
    }
}
impl<T> DataOwned for OwnedRepr<T> {
    fn from_buffer(buffer: Buffer<T>) -> Self {
        Self(buffer)
    }
}
impl<T> DataMut for OwnedRepr<T> {
    fn as_slice_mut(&mut self) -> SliceMut<T> {
        self.0.as_slice_mut()
    }
}
/// Shared (reference-counted) tensor data, backed by an [`ArcBuffer`];
/// clones are cheap and share the same storage.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(
    serialize = "T: Pod + Serialize",
    deserialize = "T: Pod + Deserialize<'de>"
))]
pub struct ArcRepr<T>(pub(crate) ArcBuffer<T>);
impl<T> Data for ArcRepr<T> {
    type Elem = T;
    // Succeeds only when this is the sole reference to the buffer.
    fn try_into_buffer(self) -> Result<Buffer<T>, Self> {
        self.0.try_unwrap().map_err(Self)
    }
    fn try_into_arc_buffer(self) -> Result<ArcBuffer<T>, Self> {
        Ok(self.0)
    }
    // Overrides the default: bump the refcount instead of copying data.
    fn to_shared(&self) -> Result<Self>
    where
        Self::Elem: Copy,
    {
        Ok(self.clone())
    }
    fn as_slice(&self) -> Slice<T> {
        self.0.as_slice()
    }
}
impl<T> DataOwned for ArcRepr<T> {
    fn from_buffer(buffer: Buffer<T>) -> Self {
        Self(buffer.into())
    }
}
/// Immutably borrowed tensor data, backed by a [`Slice`].
#[derive(Debug, Clone, Serialize)]
#[serde(bound = "T: Pod + Serialize")]
pub struct ViewRepr<'a, T>(pub(crate) Slice<'a, T>);
impl<T> Data for ViewRepr<'_, T> {
    type Elem = T;
    fn as_slice(&self) -> Slice<T> {
        self.0.as_slice()
    }
}
/// Mutably borrowed tensor data, backed by a [`SliceMut`].
#[derive(Debug, Serialize)]
#[serde(bound = "T: Pod + Serialize")]
pub struct ViewMutRepr<'a, T>(pub(crate) SliceMut<'a, T>);
impl<T> Data for ViewMutRepr<'_, T> {
    type Elem = T;
    fn as_slice(&self) -> Slice<T> {
        self.0.as_slice()
    }
}
impl<T> DataMut for ViewMutRepr<'_, T> {
    fn as_slice_mut(&mut self) -> SliceMut<T> {
        self.0.as_slice_mut()
    }
}
/// Copy-on-write tensor data: either borrowed or owned, backed by a
/// [`CowBuffer`].
#[derive(Debug)]
pub struct CowRepr<'a, T>(pub(crate) CowBuffer<'a, T>);
impl<T> Data for CowRepr<'_, T> {
    type Elem = T;
    fn as_slice(&self) -> Slice<T> {
        self.0.as_slice()
    }
    // Succeeds only when the cow currently owns its buffer.
    fn try_into_buffer(self) -> Result<Buffer<T>, Self> {
        self.0.try_unwrap().map_err(Self)
    }
}
impl<T> DataOwned for CowRepr<'_, T> {
    fn from_buffer(buffer: Buffer<T>) -> Self {
        Self(buffer.into())
    }
}
/// Copies the strides of `array` into a `D`-typed dimension value.
///
/// `ndarray` reports strides as `&[isize]`, while tensors store them in a
/// `Dimension` (element type `usize`), so the slice is reinterpreted via
/// bytemuck and copied element-wise.
fn strides_from_array<S, D>(array: &ArrayBase<S, D>) -> D
where
    S: ndarray::RawData,
    D: Dimension,
{
    let raw: &[usize] = bytemuck::cast_slice(array.strides());
    let mut out = D::zeros(raw.len());
    for (dst, src) in out.slice_mut().iter_mut().zip(raw.iter().copied()) {
        *dst = src;
    }
    out
}
/// Resolves a `StrideShape` into its (dim, strides) pair without owning any
/// data, by constructing a raw ndarray view over a dangling pointer.
fn dim_strides_from_shape<D: Dimension>(shape: impl Into<StrideShape<D>>) -> (D, D) {
    // SAFETY: the raw view is never dereferenced; it only exists so ndarray
    // computes the dim and strides implied by `shape`.
    let array = unsafe { RawArrayView::from_shape_ptr(shape, &()) };
    let dim = array.raw_dim();
    let strides = strides_from_array(&array);
    (dim, strides)
}
/// Converts a (dim, strides) pair from dimension type `D1` to `D2`,
/// erroring when the number of axes does not fit `D2`.
fn into_dimensionality<D1, D2>(dim: &D1, strides: &D1) -> Result<(D2, D2)>
where
    D1: Dimension,
    D2: Dimension,
{
    match (D2::from_dimension(dim), D2::from_dimension(strides)) {
        (Some(dim2), Some(strides2)) => Ok((dim2, strides2)),
        _ => {
            // Report strides as signed values, matching ndarray's convention.
            let strides = bytemuck::cast_slice::<_, isize>(strides.slice());
            Err(anyhow!(
                "Incompatible Shapes! {:?} {:?} => {:?}",
                dim.slice(),
                strides,
                D2::NDIM
            ))
        }
    }
}
/// Computes the (dim, strides) pair for reshaping a contiguous tensor from
/// `dim`/`strides` to `shape`.
///
/// Row-major (C) contiguous input keeps a row-major layout; column-major
/// (Fortran) contiguous input keeps a column-major layout. Broadcast
/// (zero-stride) input is treated as row-major contiguous.
///
/// # Errors
///
/// Returns an error when the element counts differ or the input layout is
/// neither C nor Fortran contiguous.
fn into_shape<D1, E>(dim: &D1, strides: &D1, shape: E) -> Result<(E::Dim, E::Dim)>
where
    D1: Dimension,
    E: IntoDimension,
{
    let shape = shape.into_dimension();
    let zero_strides = strides.slice().iter().any(|s| *s == 0);
    // BUG FIX: the size check must gate *both* layout branches; previously a
    // Fortran-contiguous (or zero-stride) tensor could be "reshaped" to a
    // shape with a different element count than its buffer.
    if shape.size() == dim.size() {
        if zero_strides || strides == &dim.default_strides() {
            let strides = shape.default_strides();
            return Ok((shape, strides));
        }
        if dim.ndim() > 1 && strides == &dim.fortran_strides() {
            let strides = shape.fortran_strides();
            return Ok((shape, strides));
        }
    }
    Err(anyhow!(
        "Incompatible Shapes! {:?} {:?} => {:?}",
        dim.slice(),
        strides.slice(),
        shape.slice()
    ))
}
/// Whether `strides` describe the default row-major layout for `dim`.
/// Broadcast (zero-stride) layouts are treated as standard.
fn is_standard_layout<D: Dimension>(dim: &D, strides: &D) -> bool {
    if strides.slice().contains(&0) {
        return true;
    }
    strides == &dim.default_strides()
}
/// Permutes `dim` and `strides` according to `axes`, so that output axis
/// `i` corresponds to input axis `axes[i]`.
///
/// # Panics
///
/// Panics unless `axes` is a permutation (each axis listed exactly once).
fn permuted_axes<D: Dimension>(dim: D, strides: D, axes: D) -> (D, D) {
    // Validate that `axes` is a permutation of 0..ndim.
    let mut usage_counts = D::zeros(dim.ndim());
    for axis in axes.slice() {
        usage_counts[*axis] += 1;
    }
    for count in usage_counts.slice() {
        assert_eq!(*count, 1, "each axis must be listed exactly once");
    }
    // Reuse the usage-count storage for the output dim (every entry is
    // overwritten below), avoiding one allocation.
    let mut new_dim = usage_counts;
    let mut new_strides = D::zeros(dim.ndim());
    let src_dim = dim.slice();
    let src_strides = strides.slice();
    for (new_axis, &axis) in axes.slice().iter().enumerate() {
        new_dim[new_axis] = src_dim[axis];
        new_strides[new_axis] = src_strides[axis];
    }
    (new_dim, new_strides)
}
/// A multi-dimensional array on a [`Device`], generic over its data
/// representation `S` and dimensionality `D`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorBase<S: Data, D: Dimension> {
    // Shape of the tensor.
    dim: D,
    // Per-axis strides, stored as usize but logically isize.
    strides: D,
    // The underlying data representation.
    data: S,
}
/// Tensor with exclusively owned data.
pub type Tensor<T, D> = TensorBase<OwnedRepr<T>, D>;
pub type Tensor0<T> = Tensor<T, Ix0>;
pub type Tensor1<T> = Tensor<T, Ix1>;
pub type Tensor2<T> = Tensor<T, Ix2>;
pub type Tensor3<T> = Tensor<T, Ix3>;
pub type Tensor4<T> = Tensor<T, Ix4>;
pub type Tensor5<T> = Tensor<T, Ix5>;
pub type Tensor6<T> = Tensor<T, Ix6>;
pub type TensorD<T> = Tensor<T, IxDyn>;
/// Tensor with shared (reference-counted) data.
pub type ArcTensor<T, D> = TensorBase<ArcRepr<T>, D>;
pub type ArcTensor0<T> = ArcTensor<T, Ix0>;
pub type ArcTensor1<T> = ArcTensor<T, Ix1>;
pub type ArcTensor2<T> = ArcTensor<T, Ix2>;
pub type ArcTensor3<T> = ArcTensor<T, Ix3>;
pub type ArcTensor4<T> = ArcTensor<T, Ix4>;
pub type ArcTensor5<T> = ArcTensor<T, Ix5>;
pub type ArcTensor6<T> = ArcTensor<T, Ix6>;
pub type ArcTensorD<T> = ArcTensor<T, IxDyn>;
/// Tensor borrowing its data immutably.
pub type TensorView<'a, T, D> = TensorBase<ViewRepr<'a, T>, D>;
pub type TensorView0<'a, T> = TensorView<'a, T, Ix0>;
pub type TensorView1<'a, T> = TensorView<'a, T, Ix1>;
pub type TensorView2<'a, T> = TensorView<'a, T, Ix2>;
pub type TensorView3<'a, T> = TensorView<'a, T, Ix3>;
pub type TensorView4<'a, T> = TensorView<'a, T, Ix4>;
pub type TensorView5<'a, T> = TensorView<'a, T, Ix5>;
pub type TensorView6<'a, T> = TensorView<'a, T, Ix6>;
pub type TensorViewD<'a, T> = TensorView<'a, T, IxDyn>;
/// Tensor borrowing its data mutably.
pub type TensorViewMut<'a, T, D> = TensorBase<ViewMutRepr<'a, T>, D>;
pub type TensorViewMut0<'a, T> = TensorViewMut<'a, T, Ix0>;
pub type TensorViewMut1<'a, T> = TensorViewMut<'a, T, Ix1>;
pub type TensorViewMut2<'a, T> = TensorViewMut<'a, T, Ix2>;
pub type TensorViewMut3<'a, T> = TensorViewMut<'a, T, Ix3>;
pub type TensorViewMut4<'a, T> = TensorViewMut<'a, T, Ix4>;
pub type TensorViewMut5<'a, T> = TensorViewMut<'a, T, Ix5>;
pub type TensorViewMut6<'a, T> = TensorViewMut<'a, T, Ix6>;
pub type TensorViewMutD<'a, T> = TensorViewMut<'a, T, IxDyn>;
/// Tensor with copy-on-write (borrowed-or-owned) data.
pub type CowTensor<'a, T, D> = TensorBase<CowRepr<'a, T>, D>;
pub type CowTensor0<'a, T> = CowTensor<'a, T, Ix0>;
pub type CowTensor1<'a, T> = CowTensor<'a, T, Ix1>;
pub type CowTensor2<'a, T> = CowTensor<'a, T, Ix2>;
pub type CowTensor3<'a, T> = CowTensor<'a, T, Ix3>;
pub type CowTensor4<'a, T> = CowTensor<'a, T, Ix4>;
pub type CowTensor5<'a, T> = CowTensor<'a, T, Ix5>;
pub type CowTensor6<'a, T> = CowTensor<'a, T, Ix6>;
pub type CowTensorD<'a, T> = CowTensor<'a, T, IxDyn>;
impl<T, S: Data<Elem = T>, D: Dimension> TensorBase<S, D> {
    /// Allocates a tensor with the given shape on `device`, without
    /// initializing its contents.
    ///
    /// # Safety
    ///
    /// The elements are uninitialized; read them only after writing.
    pub unsafe fn alloc<Sh>(device: Device, shape: Sh) -> Result<Self>
    where
        T: Default + Copy,
        S: DataOwned,
        Sh: ShapeBuilder<Dim = D>,
    {
        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
        let data = S::from_buffer(Buffer::alloc(device, dim.size())?);
        Ok(Self { dim, strides, data })
    }
    /// Creates a tensor of the given shape filled with `elem`.
    pub fn from_elem<Sh>(device: Device, shape: Sh, elem: T) -> Result<Self>
    where
        T: Scalar,
        S: DataOwned,
        Sh: ShapeBuilder<Dim = D>,
    {
        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
        let data = S::from_buffer(Buffer::from_elem(device, dim.size(), elem)?);
        Ok(Self { dim, strides, data })
    }
    /// Creates a tensor of the given shape filled with the default value
    /// (zero for the scalar types).
    pub fn zeros<Sh>(device: Device, shape: Sh) -> Result<Self>
    where
        T: Scalar,
        S: DataOwned,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self::from_elem(device, shape, T::default())
    }
    /// Creates a tensor of the given shape filled with ones.
    pub fn ones<Sh>(device: Device, shape: Sh) -> Result<Self>
    where
        T: Scalar,
        S: DataOwned,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self::from_elem(device, shape, T::one())
    }
    /// The device the tensor's data lives on.
    pub fn device(&self) -> Device {
        self.data.device()
    }
    /// The shape as a pattern (e.g. a tuple for fixed dimensionalities).
    pub fn dim(&self) -> D::Pattern {
        self.dim.clone().into_pattern()
    }
    /// The shape as a `D` value.
    pub fn raw_dim(&self) -> D {
        self.dim.clone()
    }
    /// The shape as a slice of axis lengths.
    pub fn shape(&self) -> &[usize] {
        self.dim.slice()
    }
    /// The strides as a slice of signed element offsets.
    pub fn strides(&self) -> &[isize] {
        bytemuck::cast_slice(self.strides.slice())
    }
    /// Total number of elements.
    pub fn len(&self) -> usize {
        debug_assert_eq!(self.data.len(), self.dim.size());
        self.data.len()
    }
    /// Whether the tensor has no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
    /// Number of axes.
    pub fn ndim(&self) -> usize {
        self.dim.ndim()
    }
    /// Converts to dimensionality `D2`, erring when the number of axes
    /// does not fit.
    pub fn into_dimensionality<D2>(self) -> Result<TensorBase<S, D2>>
    where
        D2: Dimension,
    {
        let (dim, strides) = into_dimensionality(&self.dim, &self.strides)?;
        Ok(TensorBase {
            dim,
            strides,
            data: self.data,
        })
    }
    /// Reshapes to `shape` without copying; requires a contiguous layout
    /// and an unchanged element count.
    pub fn into_shape<E>(self, shape: E) -> Result<TensorBase<S, E::Dim>>
    where
        E: IntoDimension,
    {
        let (dim, strides) = into_shape(&self.dim, &self.strides, shape)?;
        Ok(TensorBase {
            dim,
            strides,
            data: self.data,
        })
    }
    /// Converts to a dynamic-dimensional tensor.
    pub fn into_dyn(self) -> TensorBase<S, IxDyn> {
        TensorBase {
            dim: self.dim.into_dyn(),
            strides: self.strides.into_dyn(),
            data: self.data,
        }
    }
    /// Borrows the tensor as an immutable view.
    pub fn view(&self) -> TensorView<T, D> {
        TensorView {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            data: ViewRepr(self.data.as_slice()),
        }
    }
    /// Borrows the tensor as a mutable view.
    pub fn view_mut(&mut self) -> TensorViewMut<T, D>
    where
        S: DataMut,
    {
        TensorViewMut {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            data: ViewMutRepr(self.data.as_slice_mut()),
        }
    }
    /// Whether the layout is row-major contiguous (zero-stride layouts
    /// count as standard).
    pub fn is_standard_layout(&self) -> bool {
        is_standard_layout(&self.dim, &self.strides)
    }
    /// Permutes the axes so output axis `i` is input axis `axes[i]`.
    ///
    /// # Panics
    ///
    /// Panics unless `axes` lists each axis exactly once.
    pub fn permuted_axes<A>(self, axes: A) -> Self
    where
        A: IntoDimension<Dim = D>,
    {
        let (dim, strides) = permuted_axes(self.dim, self.strides, axes.into_dimension());
        Self {
            dim,
            strides,
            data: self.data,
        }
    }
    /// Reverses the order of the axes (a metadata-only transpose).
    pub fn reversed_axes(mut self) -> Self {
        self.dim.slice_mut().reverse();
        self.strides.slice_mut().reverse();
        self
    }
    /// A transposed view of the tensor.
    pub fn t(&self) -> TensorView<T, D> {
        self.view().reversed_axes()
    }
    /// Returns the data as a contiguous slice, copying into standard
    /// layout only when necessary.
    pub fn to_slice(&self) -> Result<CowBuffer<T>>
    where
        T: Scalar,
    {
        if self.strides == self.dim.default_strides() {
            Ok(self.data.as_slice().into())
        } else {
            // `as_standard_layout` is defined in a sibling module (reorder);
            // it yields a tensor whose `.data.0` is an owned cow buffer.
            Ok(self.as_standard_layout()?.data.0)
        }
    }
    /// The raw underlying slice, ignoring strides.
    pub fn as_raw_slice(&self) -> Slice<T> {
        self.data.as_slice()
    }
    /// The raw underlying mutable slice, ignoring strides.
    pub fn as_raw_slice_mut(&mut self) -> SliceMut<T>
    where
        S: DataMut,
    {
        self.data.as_slice_mut()
    }
    /// Moves the tensor to `device`, copying only when the device differs.
    pub async fn into_device(self, device: Device) -> Result<Tensor<T, D>>
    where
        T: Pod,
    {
        if device == self.device() {
            self.into_owned()
        } else {
            let buffer = self.data.as_slice().into_device(device).await?;
            Ok(Tensor {
                dim: self.dim,
                strides: self.strides,
                data: OwnedRepr(buffer),
            })
        }
    }
    /// Reads the tensor's data back to the host, resolving when the
    /// transfer completes.
    pub async fn read(self) -> Result<ReadGuard<S, D>>
    where
        T: Pod,
    {
        ReadGuard::new(self.dim, self.strides, self.data).await
    }
    /// Converts into an owned tensor, copying only when the data is not
    /// exclusively owned.
    pub fn into_owned(self) -> Result<Tensor<T, D>>
    where
        T: Copy,
    {
        Ok(TensorBase {
            dim: self.dim,
            strides: self.strides,
            data: self.data.into_owned()?,
        })
    }
    /// Copies into a new owned tensor.
    pub fn to_owned(&self) -> Result<Tensor<T, D>>
    where
        T: Copy,
    {
        self.view().into_owned()
    }
    /// Converts into a shared tensor, copying only when necessary.
    pub fn into_shared(self) -> Result<ArcTensor<T, D>>
    where
        T: Copy,
    {
        Ok(TensorBase {
            dim: self.dim,
            strides: self.strides,
            data: self.data.into_shared()?,
        })
    }
    /// Copies into a shared tensor (cheap refcount clone for `ArcTensor`).
    pub fn to_shared(&self) -> Result<ArcTensor<T, D>>
    where
        T: Copy,
    {
        Ok(TensorBase {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            data: self.data.to_shared()?,
        })
    }
    /// Converts into a type-erased float tensor.
    #[doc(hidden)]
    pub fn into_float<S2: float::FloatData>(self) -> float::FloatTensorBase<S2, D>
    where
        Self: Into<float::FloatTensorBase<S2, D>>,
    {
        self.into()
    }
    /// Fills the entire tensor with `elem`.
    pub fn fill(&mut self, elem: T) -> Result<()>
    where
        T: Scalar,
        S: DataMut,
    {
        self.data.as_slice_mut().fill(elem)
    }
}
impl<T, S: DataOwned<Elem = T>> From<Buffer<T>> for TensorBase<S, Ix1> {
    /// Wraps a buffer as a 1-d tensor whose length is the buffer length.
    fn from(buffer: Buffer<T>) -> Self {
        let dim = buffer.len().into_dimension();
        Self {
            strides: dim.default_strides(),
            data: S::from_buffer(buffer),
            dim,
        }
    }
}
impl<'a, T> From<Slice<'a, T>> for TensorView<'a, T, Ix1> {
    /// Wraps a borrowed slice as a 1-d tensor view.
    fn from(slice: Slice<'a, T>) -> Self {
        let dim = slice.len().into_dimension();
        Self {
            strides: dim.default_strides(),
            data: ViewRepr(slice),
            dim,
        }
    }
}
impl<T, S: DataOwned<Elem = T>, D: Dimension> From<Array<T, D>> for TensorBase<S, D> {
    /// Moves an ndarray into a tensor, preserving its shape and strides.
    // NOTE(review): `into_raw_vec` discards any internal offset — presumably
    // callers only pass arrays with offset 0 (freshly constructed, not
    // sliced); verify against callers.
    fn from(array: Array<T, D>) -> Self {
        let dim = array.raw_dim();
        let strides = strides_from_array(&array);
        let buffer = Buffer::from(array.into_raw_vec());
        let data = S::from_buffer(buffer);
        Self { dim, strides, data }
    }
}
impl<'a, T: Clone, D: Dimension> From<ArrayView<'a, T, D>> for CowTensor<'a, T, D> {
    /// Borrows a contiguous ndarray view without copying; non-contiguous
    /// views are copied into an owned array first.
    fn from(array: ArrayView<'a, T, D>) -> Self {
        if let Some(slice) = array.as_slice_memory_order() {
            // SAFETY: reconstructs the same slice with the view's lifetime
            // `'a` (the borrow of `array` would otherwise end too early).
            let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) };
            let dim = array.raw_dim();
            let strides = strides_from_array(&array);
            let data = CowRepr(slice.into());
            Self { dim, strides, data }
        } else {
            Self::from(array.to_owned())
        }
    }
}
impl<'a, T, D: Dimension> TryFrom<ArrayView<'a, T, D>> for TensorView<'a, T, D> {
    type Error = Error;
    /// Borrows a contiguous ndarray view as a tensor view.
    ///
    /// # Errors
    ///
    /// Fails when the view is not contiguous in memory.
    fn try_from(array: ArrayView<'a, T, D>) -> Result<Self> {
        let slice = array
            .as_slice_memory_order()
            .ok_or_else(|| anyhow!("Shape not contiguous!"))?;
        // SAFETY: reconstructs the same slice with the view's lifetime `'a`
        // (the borrow of `array` would otherwise end too early).
        let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) };
        let dim = array.raw_dim();
        let strides = strides_from_array(&array);
        let data = ViewRepr(slice.into());
        Ok(Self { dim, strides, data })
    }
}
impl<'a, T, D: Dimension> From<TensorView<'a, T, D>> for CowTensor<'a, T, D> {
fn from(view: TensorView<'a, T, D>) -> Self {
Self {
dim: view.dim,
strides: view.strides,
data: CowRepr(view.data.0.into()),
}
}
}
impl<T, D: Dimension> From<Tensor<T, D>> for CowTensor<'_, T, D> {
    /// Converts an owned tensor into a copy-on-write tensor (no copy).
    fn from(tensor: Tensor<T, D>) -> Self {
        Self {
            dim: tensor.dim,
            strides: tensor.strides,
            data: CowRepr(tensor.data.0.into()),
        }
    }
}
impl<T, D: Dimension> From<Tensor<T, D>> for ArcTensor<T, D> {
    /// Converts an owned tensor into a shared tensor (no copy).
    fn from(tensor: Tensor<T, D>) -> Self {
        Self {
            dim: tensor.dim,
            strides: tensor.strides,
            data: ArcRepr(tensor.data.0.into()),
        }
    }
}
#[allow(unused)]
impl<T: Scalar, S: Data<Elem = T>, D: Dimension> TensorBase<S, D> {
    /// Converts the elements to `T2`, consuming the tensor and reusing the
    /// buffer when exclusively owned.
    pub fn cast_into<T2: Scalar>(self) -> Result<Tensor<T2, D>> {
        let buffer = match self.data.try_into_buffer() {
            Ok(buffer) => buffer.cast_into()?,
            Err(data) => data.as_slice().cast_into()?,
        };
        Ok(TensorBase {
            dim: self.dim,
            strides: self.strides,
            data: OwnedRepr(buffer),
        })
    }
    /// Converts the elements to `T2` without consuming the tensor; when
    /// `T2 == T` the result may borrow instead of copy.
    pub fn cast_to<T2: Scalar>(&self) -> Result<CowTensor<T2, D>> {
        let slice = self.data.as_slice();
        let buffer: CowBuffer<T2> = slice.cast_to::<T2>()?;
        // SAFETY: transmute only adjusts the cow buffer's lifetime
        // parameter to the one borrowed from `self`; the layout of
        // `CowBuffer` does not depend on the lifetime.
        Ok(TensorBase {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            data: CowRepr(unsafe { transmute(buffer) }),
        })
    }
    /// Computes `self * alpha` element-wise while converting to `T2`.
    pub fn scale_into<T2: Scalar>(self, alpha: T2) -> Result<Tensor<T2, D>> {
        let buffer = match self.data.try_into_buffer() {
            Ok(buffer) => buffer.scale_into(alpha)?,
            Err(data) => data.as_slice().scale_into(alpha)?,
        };
        Ok(TensorBase {
            dim: self.dim,
            strides: self.strides,
            data: OwnedRepr(buffer),
        })
    }
}
#[allow(dead_code)]
impl<T: Uint, S: Data<Elem = T>> TensorBase<S, Ix1> {
    /// One-hot encodes a 1-d tensor of class indices into an `[n, nclasses]`
    /// tensor, dispatching a compute shader on the tensor's device.
    pub(crate) fn to_one_hot<T2: Scalar>(&self, nclasses: usize) -> Result<Tensor2<T2>> {
        let n = self.dim();
        let mut output = unsafe { Tensor::alloc(self.device(), [n, nclasses])? };
        // Sub-word (8/16-bit) outputs are zeroed up front; presumably the
        // shader only writes full words for 32-bit types — TODO confirm
        // against the one_hot shader source.
        if size_of::<T2>() < 4 {
            output.fill(T2::zero())?;
        }
        // Shader variant is selected by the input and output scalar names.
        let builder = glsl_shaders::module(&format!(
            "one_hot_{}_{}",
            T::scalar_name(),
            T2::scalar_name()
        ))?
        .compute_pass("main")?
        .slice(self.as_raw_slice())?
        .slice_mut(output.as_raw_slice_mut())?
        .push([n as u32, nclasses as u32])?;
        // SAFETY: submit is unsafe per the compute-pass API; buffers and
        // push constants were bound above.
        unsafe {
            builder.submit([n as u32, 1, 1])?;
        }
        Ok(output)
    }
}
/// Host-side read access to a tensor's data, produced by
/// [`TensorBase::read`].
pub struct ReadGuard<S: Data, D: Dimension>
where
    S::Elem: 'static,
{
    // Shape of the tensor being read.
    dim: D,
    // Strides of the tensor being read.
    strides: D,
    // Host-visible guard over the data; its lifetime is erased to 'static
    // and is kept valid by holding `data` below.
    guard: BufferReadGuard<SliceRepr<'static, S::Elem>>,
    // The representation that owns/borrows the buffer; must outlive `guard`.
    data: S,
}
impl<T: Pod, S: Data<Elem = T>, D: Dimension> ReadGuard<S, D> {
    /// Awaits a host-visible read of `data` and packages it with its
    /// shape and strides.
    async fn new(dim: D, strides: D, data: S) -> Result<Self> {
        let guard: BufferReadGuard<SliceRepr<T>> = data.as_slice().read().await?;
        // SAFETY: the guard's borrow of `data` is erased to 'static;
        // `data` is stored in the same struct, so the borrow remains valid
        // for the guard's entire life. NOTE(review): this relies on `guard`
        // never escaping the struct — confirm no API hands it out by value.
        let guard = unsafe { transmute(guard) };
        Ok(Self {
            dim,
            strides,
            guard,
            data,
        })
    }
    /// Views the data as an ndarray with the original shape and strides.
    pub fn as_array(&self) -> ArrayView<T, D> {
        // SAFETY: `guard` points at `len()` host-visible elements and the
        // dim/strides were taken from the tensor they describe; the view's
        // lifetime is bound to `&self`.
        unsafe {
            RawArrayView::from_shape_ptr(
                self.dim.clone().strides(self.strides.clone()),
                self.guard.as_slice().as_ptr(),
            )
            .deref_into_view()
        }
    }
    /// Converts into an owned ndarray, reusing the host allocation when
    /// the data is exclusively owned and already a host vec; otherwise
    /// copies out of the guard.
    pub fn into_array(self) -> Array<T, D> {
        if let Ok(buffer) = self.data.try_into_buffer() {
            if let Some(vec) = buffer.into_vec() {
                // SAFETY: dim/strides describe exactly this buffer's layout.
                unsafe {
                    return Array::from_shape_vec_unchecked(self.dim.strides(self.strides), vec);
                }
            }
        }
        // SAFETY: dim/strides describe exactly the guarded data's layout.
        unsafe {
            Array::from_shape_vec_unchecked(self.dim.strides(self.strides), self.guard.to_vec())
        }
    }
}
impl<T: Pod + Debug, S: Data<Elem = T>, D: Dimension> Debug for ReadGuard<S, D> {
    // Formats via the ndarray view so the output shows shape and elements.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.as_array().fmt(f)
    }
}
#[cfg(test)]
mod tests {
    #[allow(unused)]
    use super::*;
    #[cfg(feature = "device_tests")]
    use half::bf16;
    #[cfg(feature = "device_tests")]
    use ndarray::{Array1, Array2};
    // Round-trips `x` through a TensorView (both as-is and transposed)
    // and checks the data read back matches the original array.
    async fn tensor_from_array<D: Dimension>(x: Array<u32, D>) -> Result<()> {
        let y = TensorView::try_from(x.view())?.read().await?;
        assert_eq!(x.view(), y.as_array());
        let y_t = TensorView::try_from(x.t())?.read().await?;
        assert_eq!(x.t(), y_t.as_array());
        Ok(())
    }
    #[tokio::test]
    async fn tensor_from_array0() -> Result<()> {
        tensor_from_array(Array::from_elem((), 1)).await
    }
    #[tokio::test]
    async fn tensor_from_array1() -> Result<()> {
        tensor_from_array(Array::from_shape_vec(3, (1..=3).into_iter().collect())?).await
    }
    #[tokio::test]
    async fn tensor_from_array2() -> Result<()> {
        tensor_from_array(Array::from_shape_vec(
            [2, 3],
            (1..=6).into_iter().collect(),
        )?)
        .await
    }
    #[tokio::test]
    async fn tensor_from_array3() -> Result<()> {
        tensor_from_array(Array::from_shape_vec(
            [2, 3, 4],
            (1..=24).into_iter().collect(),
        )?)
        .await
    }
    #[tokio::test]
    async fn tensor_from_array4() -> Result<()> {
        tensor_from_array(Array::from_shape_vec(
            [2, 3, 4, 5],
            (1..=120).into_iter().collect(),
        )?)
        .await
    }
    #[tokio::test]
    async fn test_from_array5() -> Result<()> {
        tensor_from_array(Array::from_shape_vec(
            [2, 3, 4, 5, 6],
            (1..=120 * 6).into_iter().collect(),
        )?)
        .await
    }
    #[tokio::test]
    async fn tensor_from_array6() -> Result<()> {
        tensor_from_array(Array::from_shape_vec(
            [2, 3, 4, 5, 6, 7],
            (1..=120 * 6 * 7).into_iter().collect(),
        )?)
        .await
    }
    #[allow(non_snake_case)]
    #[tokio::test]
    async fn tensor_from_arrayD() -> Result<()> {
        tensor_from_array(Array::from_shape_vec(
            [2, 3, 4, 5, 6, 7, 8].as_ref(),
            (1..=120 * 6 * 7 * 8).into_iter().collect(),
        )?)
        .await
    }
    // Serializes a tensor on `device` with bincode, deserializes it, and
    // checks the round-tripped data.
    async fn tensor_serde(device: Device) -> Result<()> {
        let x = (0..4 * 5 * 6 * 7).into_iter().collect::<Vec<u32>>();
        let array = Array::from(x).into_shape([4, 5, 6, 7])?;
        let tensor = TensorView::try_from(array.view())?
            .into_device(device)
            .await?;
        let tensor: Tensor4<u32> = bincode::deserialize(&bincode::serialize(&tensor)?)?;
        assert_eq!(array.view(), tensor.read().await?.as_array());
        Ok(())
    }
    #[tokio::test]
    async fn tensor_serde_host() -> Result<()> {
        tensor_serde(Device::host()).await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn tensor_serde_device() -> Result<()> {
        let device = Device::new()?;
        let _s = device.acquire().await;
        tensor_serde(device).await
    }
    // Host reference implementation for `scale_into`.
    #[cfg(feature = "device_tests")]
    fn array_scaled_cast<T1: Scalar, T2: Scalar>(x: &Array1<T1>, alpha: f64) -> Array1<T2> {
        x.iter()
            .map(|x| T2::from_f64(x.to_f64().unwrap() * alpha).unwrap())
            .collect()
    }
    // Compares device `scale_into` output against the host reference.
    #[cfg(feature = "device_tests")]
    async fn scaled_cast<T1: Scalar + From<u8>, T2: Scalar + From<u8>>() -> Result<()> {
        let n = 100;
        let alpha = 2;
        let data: Vec<T1> = (0..n as u8).into_iter().map(Into::into).collect();
        let x_array = Array::from(data);
        let y_true = array_scaled_cast(&x_array, alpha.into());
        let device = Device::new()?;
        let _s = device.acquire().await;
        let x = CowTensor::from(x_array.view()).into_device(device).await?;
        let y = x.scale_into::<T2>((alpha as u8).into())?;
        let y_array = y.read().await?;
        assert_eq!(y_array.as_array(), y_true.view());
        Ok(())
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u8_bf16() -> Result<()> {
        scaled_cast::<u8, bf16>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u8_u32() -> Result<()> {
        scaled_cast::<u8, u32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u8_i32() -> Result<()> {
        scaled_cast::<u8, i32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u8_f32() -> Result<()> {
        scaled_cast::<u8, f32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u16_bf16() -> Result<()> {
        scaled_cast::<u16, bf16>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u16_u32() -> Result<()> {
        scaled_cast::<u16, u32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u16_i32() -> Result<()> {
        scaled_cast::<u16, i32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u16_f32() -> Result<()> {
        scaled_cast::<u16, f32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_bf16_bf16() -> Result<()> {
        scaled_cast::<bf16, bf16>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_bf16_u32() -> Result<()> {
        scaled_cast::<bf16, u32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_bf16_i32() -> Result<()> {
        scaled_cast::<bf16, i32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_bf16_f32() -> Result<()> {
        scaled_cast::<bf16, f32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u32_bf16() -> Result<()> {
        scaled_cast::<u32, bf16>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u32_u32() -> Result<()> {
        scaled_cast::<u32, u32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u32_i32() -> Result<()> {
        scaled_cast::<u32, i32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_u32_f32() -> Result<()> {
        scaled_cast::<u32, f32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_i32_bf16() -> Result<()> {
        scaled_cast::<i32, bf16>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_i32_u32() -> Result<()> {
        scaled_cast::<i32, u32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_i32_i32() -> Result<()> {
        scaled_cast::<i32, i32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn scaled_cast_i32_f32() -> Result<()> {
        scaled_cast::<i32, f32>().await
    }
    // Host reference implementation for `to_one_hot`.
    #[cfg(feature = "device_tests")]
    fn to_one_hot<U: Uint, T: Scalar>(x: &Array1<U>, nclasses: usize) -> Array2<T> {
        let mut y = Array::from_elem([x.len(), nclasses], T::zero());
        for (mut y, x) in y.outer_iter_mut().zip(x.iter().copied()) {
            y[x.to_usize().unwrap()] = T::one();
        }
        y
    }
    // Compares device `to_one_hot` output against the host reference.
    #[cfg(feature = "device_tests")]
    async fn one_hot<U: Uint + Into<u64> + From<u8>, T: Scalar>() -> Result<()> {
        let batch_size = 100;
        let nclasses = 10;
        let data: Vec<U> = (0..nclasses as u8)
            .into_iter()
            .cycle()
            .take(batch_size)
            .map(Into::into)
            .collect();
        let x_array = Array::from(data.clone());
        let y_true = to_one_hot(&x_array, nclasses);
        let device = Device::new()?;
        let _s = device.acquire().await;
        let x = CowTensor::from(x_array.view()).into_device(device).await?;
        let y = x.to_one_hot::<T>(nclasses)?;
        let y_array = y.read().await?;
        assert_eq!(y_array.as_array(), y_true.view());
        Ok(())
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u8_bf16() -> Result<()> {
        one_hot::<u8, bf16>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u16_bf16() -> Result<()> {
        one_hot::<u16, bf16>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u32_bf16() -> Result<()> {
        one_hot::<u32, bf16>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u8_u32() -> Result<()> {
        one_hot::<u8, u32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u16_u32() -> Result<()> {
        one_hot::<u16, u32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u32_u32() -> Result<()> {
        one_hot::<u32, u32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u8_i32() -> Result<()> {
        one_hot::<u8, i32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u16_i32() -> Result<()> {
        one_hot::<u16, i32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u32_i32() -> Result<()> {
        one_hot::<u32, i32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u8_f32() -> Result<()> {
        one_hot::<u8, f32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u16_f32() -> Result<()> {
        one_hot::<u16, f32>().await
    }
    #[cfg(feature = "device_tests")]
    #[tokio::test]
    async fn one_hot_u32_f32() -> Result<()> {
        one_hot::<u32, f32>().await
    }
}