use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
use ndarray;
use tvm_sys::{ffi::DLTensor, Context, DataType};
use crate::allocator::Allocation;
use crate::errors::ArrayError;
use std::alloc::LayoutErr;
/// Backing memory for a [`Tensor`]: either an owned heap allocation or a
/// borrowed mutable byte view tagged with its alignment.
#[derive(PartialEq)]
pub enum Storage<'a> {
/// Owned allocation; freed when the storage is dropped.
Owned(Allocation),
/// Borrowed byte slice plus its alignment in bytes.
View(&'a mut [u8], usize), }
impl<'a> Storage<'a> {
/// Allocates `size` bytes of owned storage; `align` overrides the
/// allocator's default alignment when given.
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, LayoutErr> {
Ok(Storage::Owned(Allocation::new(size, align)?))
}
/// Raw mutable pointer to the first byte of the backing memory.
// NOTE(review): takes `&self` yet yields `*mut u8` (the View arm casts
// away const); callers must guarantee exclusive access before writing.
pub fn as_mut_ptr(&self) -> *mut u8 {
match self {
Storage::Owned(alloc) => alloc.as_mut_ptr(),
Storage::View(slice, _) => slice.as_ptr() as *mut u8,
}
}
/// Size of the backing memory in bytes.
pub fn size(&self) -> usize {
match self {
Storage::Owned(alloc) => alloc.size(),
Storage::View(slice, _) => slice.len(),
}
}
/// Alignment of the backing memory in bytes.
pub fn align(&self) -> usize {
match self {
Storage::Owned(alloc) => alloc.align(),
Storage::View(_, align) => *align,
}
}
/// Raw const pointer to the first byte of the backing memory.
pub fn as_ptr(&self) -> *const u8 {
self.as_mut_ptr() as *const _
}
/// Returns a non-owning `Storage::View` over the same bytes.
// NOTE(review): this conjures a `&mut [u8]` from a shared `&self` borrow,
// so two live views can alias mutably — confirm callers serialize access.
pub fn view(&self) -> Storage<'a> {
match self {
Storage::Owned(alloc) => Storage::View(
unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
self.align(),
),
Storage::View(slice, _) => Storage::View(
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
self.align(),
),
}
}
/// `true` if this storage owns its allocation (`Owned` variant).
pub fn is_owned(&self) -> bool {
match self {
Storage::Owned(_) => true,
_ => false,
}
}
/// Deep-copies the bytes into a fresh owned allocation with the same
/// size and alignment.
// Panics (via `unwrap`) if the new allocation cannot be created.
pub fn to_owned(&self) -> Storage<'static> {
let s = Storage::new(self.size(), Some(self.align())).unwrap();
unsafe {
s.as_mut_ptr()
.copy_from_nonoverlapping(self.as_ptr(), self.size());
}
s
}
/// Immutable byte view of the backing memory.
pub fn as_slice(&self) -> &[u8] {
match self {
Storage::Owned(alloc) => alloc.as_slice(),
Storage::View(slice, _) => &*slice,
}
}
/// Mutable byte view of the backing memory.
pub fn as_mut_slice(&mut self) -> &mut [u8] {
match self {
Storage::Owned(alloc) => alloc.as_mut_slice(),
Storage::View(slice, _) => slice,
}
}
}
/// Wraps an existing slice as a borrowed `Storage::View` over its raw bytes.
// NOTE(review): this forges a `&mut [u8]` from a shared `&[T]`, and the
// output lifetime `'s` is unrelated to the input `'d` — the caller is
// responsible for keeping the source alive and unmutated while the view
// exists.
impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
fn from(data: &'d [T]) -> Self {
let data = unsafe {
slice::from_raw_parts_mut(
data.as_ptr() as *const u8 as *mut u8,
// byte length = element count * element size
data.len() * mem::size_of::<T>() as usize,
)
};
// Alignment recorded is that of the element type `T`.
Storage::View(data, mem::align_of::<T>())
}
}
/// A dense tensor: byte storage plus the shape/stride/dtype metadata needed
/// to interpret it, convertible to and from an FFI `DLTensor`.
#[derive(PartialEq)]
pub struct Tensor<'a> {
/// Backing bytes (owned or a borrowed view).
pub(crate) data: Storage<'a>,
/// Device context (CPU/GPU etc.) the tensor claims to live on.
pub(crate) ctx: Context,
/// Element data type.
pub(crate) dtype: DataType,
/// Extent of each axis.
pub(crate) shape: Vec<i64>,
/// Per-axis strides; `None` means contiguous row-major layout.
pub(crate) strides: Option<Vec<usize>>,
/// Offset in bytes from the start of `data` to the first element.
pub(crate) byte_offset: isize,
/// Total number of elements (product of `shape`).
pub(crate) size: usize,
}
// NOTE(review): `Send` is asserted manually because `Storage` holds raw
// views; confirm no tensor is shared across threads while a `View` aliases
// memory owned elsewhere.
unsafe impl<'a> Send for Tensor<'a> {}
impl<'a> Tensor<'a> {
/// Returns a copy of the tensor's shape.
pub fn shape(&self) -> Vec<i64> {
self.shape.clone()
}
/// Borrows the underlying storage.
pub fn data(&self) -> &Storage {
&self.data
}
/// Mutably borrows the underlying storage.
// NOTE(review): the `&'a mut` return ties the borrow to the tensor's full
// lifetime parameter rather than the `&mut self` borrow — verify intended.
pub fn data_mut(&mut self) -> &'a mut Storage {
&mut self.data
}
/// Copies the elements out into a `Vec<T>`.
// Panics if the tensor is non-contiguous or `T` does not match `dtype`.
pub fn to_vec<T: 'static + std::fmt::Debug + Clone>(&self) -> Vec<T> {
assert!(self.is_contiguous());
assert!(self.dtype.is_type::<T>());
unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.size).to_vec() }
}
/// `true` when the layout is contiguous row-major.
pub fn is_contiguous(&self) -> bool {
match self.strides {
// No explicit strides means contiguous by construction.
None => true,
Some(ref strides) => {
// Walk axes from innermost to outermost, checking each stored
// stride against the expected contiguous stride (in elements)
// while accumulating the running product of extents.
self.shape
.iter()
.zip(strides)
.rfold(
(true, 1),
|(is_contig, expected_stride), (shape, stride)| {
(
is_contig && *stride == expected_stride,
expected_stride * (*shape as usize),
)
},
)
.0
}
}
}
/// Copies `other`'s element bytes into `self`.
// Panics unless dtypes and element counts match and both are contiguous.
pub fn copy(&mut self, other: &Tensor) {
assert!(
self.dtype == other.dtype && self.size == other.size,
"Tensor shape/dtype mismatch."
);
assert!(
self.is_contiguous() && other.is_contiguous(),
"copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
self.strides,
other.strides
);
unsafe {
// Both pointers are adjusted by their respective byte offsets;
// the copy length is in bytes (elements * item size).
self.data
.as_mut_ptr()
.offset(self.byte_offset as isize)
.copy_from_nonoverlapping(
other.data.as_mut_ptr().offset(other.byte_offset),
other.size * other.dtype.itemsize(),
);
}
}
/// Deep-copies into a tensor with owned, contiguous storage.
pub fn to_owned(&self) -> Tensor<'static> {
let t = Tensor {
data: self.data.to_owned(),
ctx: self.ctx,
dtype: self.dtype,
size: self.size,
shape: self.shape.clone(),
// The copied data is contiguous, so strides are dropped.
strides: None,
byte_offset: 0,
};
// SAFETY-NOTE(review): launders the lifetime to 'static; sound only
// because `t.data` is the freshly-made `Owned` storage above.
unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
}
/// Builds a `Tensor` over `storage` using `arr`'s shape/strides, with the
/// dtype produced by `dtype_fn(bits, lanes)`.
fn from_array_storage<'s, T, D: ndarray::Dimension>(
arr: &ndarray::Array<T, D>,
storage: Storage<'s>,
dtype_fn: fn(u8, u16) -> DataType,
) -> Tensor<'s> {
let type_width = mem::size_of::<T>() as u8;
Tensor {
data: storage,
ctx: Context::default(),
// bits = 8 * sizeof(T), lanes = 1.
dtype: dtype_fn(8 * type_width, 1),
size: arr.len(),
shape: arr.shape().iter().map(|&v| v as i64).collect(),
strides: Some(arr.strides().iter().map(|&v| v as usize).collect()),
byte_offset: 0,
}
}
/// Produces a non-owning `DLTensor` view; `flatten` exposes it as 1-D.
// Panics if `flatten` is requested on a non-contiguous tensor.
pub fn as_dltensor(&self, flatten: bool) -> DLTensor {
assert!(!flatten || self.is_contiguous());
DLTensor {
// Data pointer already includes the byte offset, so the DLTensor's
// own byte_offset below is 0.
data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
ctx: self.ctx.into(),
ndim: if flatten { 1 } else { self.shape.len() } as i32,
dtype: self.dtype.into(),
shape: if flatten {
// Points the 1-element shape at `self.size` reinterpreted as
// i64 — assumes usize and i64 share a width (64-bit targets);
// the pointer is only valid while `self` is alive.
&self.size as *const _ as *mut i64
} else {
self.shape.as_ptr()
} as *mut i64,
// Null strides signal a compact row-major layout to DLPack.
strides: if flatten || self.is_contiguous() {
ptr::null_mut()
} else {
self.strides.as_ref().unwrap().as_ptr()
} as *mut i64,
byte_offset: 0,
..Default::default()
}
}
}
// Generates `TryFrom<Tensor>` for `ndarray::ArrayD<$type>`, failing with
// `ArrayError::IncompatibleDataType` when the tensor's dtype is not `$dtype`
// and `ArrayError::ShapeError` when the shape/element count disagree.
macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => {
impl<'t> TryFrom<Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = ArrayError;
fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Self::Error> {
// Reject before copying if the element type does not match.
if tensor.dtype != $dtype {
return Err(ArrayError::IncompatibleDataType(tensor.dtype));
}
// Copy the elements out and rebuild with the tensor's shape.
Ok(ndarray::Array::from_shape_vec(
tensor
.shape
.iter()
.map(|s| *s as usize)
.collect::<Vec<usize>>(),
tensor.to_vec::<$type>(),
)
.map_err(|_| ArrayError::ShapeError(tensor.shape.clone()))?)
}
}
};
}
/// Declares a public `DataType` constant with the given identifier and value.
macro_rules! make_dtype_const {
    ($name:ident, $value:expr) => {
        pub const $name: DataType = $value;
    };
}
// Common scalar dtype constants (bit width, lanes = 1).
make_dtype_const!(DTYPE_INT32, DataType::int(32, 1));
make_dtype_const!(DTYPE_UINT32, DataType::uint(32, 1));
make_dtype_const!(DTYPE_FLOAT32, DataType::float(32, 1));
make_dtype_const!(DTYPE_FLOAT64, DataType::float(64, 1));
// Tensor -> ndarray conversions for each supported element type.
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
fn from(tensor: &'a Tensor<'t>) -> Self {
Tensor::as_dltensor(tensor, false )
}
}
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
fn from(tensor: &'a mut Tensor<'t>) -> Self {
Tensor::as_dltensor(tensor, false )
}
}
impl<'a> From<DLTensor> for Tensor<'a> {
    /// Wraps an FFI `DLTensor` as a borrowed `Tensor` view (no copy).
    ///
    /// The resulting tensor aliases `dlt.data`; the caller must keep the
    /// DLTensor's buffers alive for as long as the tensor is used.
    fn from(dlt: DLTensor) -> Self {
        unsafe {
            let dtype = DataType::from(dlt.dtype);
            let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
            // Total element count is the product of the axis extents.
            let size = shape.iter().map(|v| *v as usize).product::<usize>();
            let storage = Storage::from(slice::from_raw_parts(
                dlt.data as *const u8,
                dtype.itemsize() * size,
            ));
            Self {
                data: storage,
                ctx: Context::default(),
                dtype,
                size,
                shape,
                strides: if dlt.strides.is_null() {
                    None
                } else {
                    // BUGFIX: a DLTensor's `strides` array holds exactly
                    // `ndim` entries (one per axis), not `size`; reading
                    // `size` entries overran the array for any tensor with
                    // more elements than dimensions. Cast each i64 stride
                    // instead of reinterpreting the pointer, which is also
                    // correct on 32-bit targets.
                    Some(
                        slice::from_raw_parts(dlt.strides, dlt.ndim as usize)
                            .iter()
                            .map(|&s| s as usize)
                            .collect(),
                    )
                },
                byte_offset: dlt.byte_offset as isize,
            }
        }
    }
}
// Generates two `From` conversions per element type: an owning one (copies
// the array's bytes) and a borrowing one (views the array's buffer).
// Both panic if the ndarray is not contiguous.
macro_rules! impl_tensor_from_ndarray {
($type:ty, $dtype_fn:expr) => {
/// Owning conversion: deep-copies the array's bytes so the tensor
/// outlives the consumed array.
impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
fn from(arr: ndarray::Array<$type, D>) -> Self {
let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
// `to_owned` copies the viewed bytes before `arr` is dropped.
Tensor::from_array_storage(&arr, storage.to_owned(), $dtype_fn)
}
}
/// Borrowing conversion: the tensor views the array's buffer in place.
impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
Tensor::from_array_storage(arr, storage, $dtype_fn)
}
}
};
}
// ndarray -> Tensor conversions for each supported element type; the second
// argument selects the DataType constructor matching the Rust type.
impl_tensor_from_ndarray!(f32, DataType::float);
impl_tensor_from_ndarray!(f64, DataType::float);
impl_tensor_from_ndarray!(i32, DataType::int);
impl_tensor_from_ndarray!(i64, DataType::int);
impl_tensor_from_ndarray!(u32, DataType::uint);
impl_tensor_from_ndarray!(u64, DataType::uint);