use std::alloc::Layout;
use ndarray::prelude::*;
use crate::datum::{Datum, DatumType};
use crate::internal::*;
use crate::tensor::Tensor;
use super::storage::PlainStorage;
/// Verifies that a tensor whose datum type is `dt` may be accessed as `D`.
///
/// Both datum types are compared through `unquantized()`, so only the
/// unquantized forms have to match.
///
/// # Errors
///
/// Returns an error naming both datum types when their unquantized forms
/// differ.
fn check_for_access<D: Datum>(dt: DatumType) -> TractResult<()> {
    let requested = D::datum_type();
    ensure!(
        dt.unquantized() == requested.unquantized(),
        "Tensor datum type error: tensor is {:?}, accessed as {:?}",
        dt,
        requested,
    );
    Ok(())
}
pub struct PlainView<'a> {
tensor: &'a Tensor,
storage: &'a PlainStorage,
}
impl<'a> PlainView<'a> {
#[inline]
pub(crate) fn new(tensor: &'a Tensor, storage: &'a PlainStorage) -> Self {
PlainView { tensor, storage }
}
#[inline]
pub fn tensor(&self) -> &Tensor {
self.tensor
}
#[inline]
pub fn datum_type(&self) -> DatumType {
self.tensor.datum_type()
}
#[inline]
pub fn shape(&self) -> &[usize] {
self.tensor.shape()
}
#[inline]
pub fn strides(&self) -> &[isize] {
self.tensor.strides()
}
#[inline]
pub fn rank(&self) -> usize {
self.tensor.rank()
}
#[inline]
pub fn len(&self) -> usize {
self.tensor.len()
}
#[inline]
pub fn as_bytes(&self) -> &'a [u8] {
self.storage.as_bytes()
}
#[inline]
pub fn layout(&self) -> &Layout {
self.storage.layout()
}
#[inline]
pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
check_for_access::<D>(self.datum_type())?;
unsafe { Ok(self.as_ptr_unchecked()) }
}
#[inline]
pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
self.storage.as_ptr() as *const D
}
#[inline]
pub fn as_slice<D: Datum>(&self) -> TractResult<&'a [D]> {
check_for_access::<D>(self.datum_type())?;
unsafe { Ok(self.as_slice_unchecked()) }
}
#[inline]
pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &'a [D] {
if self.storage.is_empty() {
&[]
} else {
unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len()) }
}
}
#[inline]
pub fn to_scalar<D: Datum>(&self) -> TractResult<&'a D> {
check_for_access::<D>(self.datum_type())?;
unsafe { Ok(self.to_scalar_unchecked()) }
}
#[inline]
pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &'a D {
unsafe { &*(self.storage.as_ptr() as *const D) }
}
#[inline]
pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'a, D>> {
check_for_access::<D>(self.datum_type())?;
unsafe { Ok(self.to_array_view_unchecked()) }
}
#[inline]
pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'a, D> {
if self.len() != 0 {
unsafe { ArrayViewD::from_shape_ptr(self.shape(), self.storage.as_ptr() as *const D) }
} else {
ArrayViewD::from_shape(self.shape(), &[]).unwrap()
}
}
}
pub struct PlainViewMut<'a> {
dt: DatumType,
shape: &'a [usize],
strides: &'a [isize],
len: usize,
storage: &'a mut PlainStorage,
}
impl<'a> PlainViewMut<'a> {
#[inline]
pub(crate) fn new(
dt: DatumType,
shape: &'a [usize],
strides: &'a [isize],
len: usize,
storage: &'a mut PlainStorage,
) -> Self {
PlainViewMut { dt, shape, strides, len, storage }
}
#[inline]
pub fn datum_type(&self) -> DatumType {
self.dt
}
#[inline]
pub fn shape(&self) -> &[usize] {
self.shape
}
#[inline]
pub fn strides(&self) -> &[isize] {
self.strides
}
#[inline]
pub fn rank(&self) -> usize {
self.shape.len()
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn as_bytes(&self) -> &[u8] {
self.storage.as_bytes()
}
#[inline]
pub fn layout(&self) -> &Layout {
self.storage.layout()
}
#[inline]
pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
check_for_access::<D>(self.dt)?;
unsafe { Ok(self.as_ptr_unchecked()) }
}
#[inline]
pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
self.storage.as_ptr() as *const D
}
#[inline]
pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
check_for_access::<D>(self.dt)?;
unsafe { Ok(self.as_slice_unchecked()) }
}
#[inline]
pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
if self.storage.is_empty() {
&[]
} else {
unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.len) }
}
}
#[inline]
pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
check_for_access::<D>(self.dt)?;
unsafe { Ok(self.to_scalar_unchecked()) }
}
#[inline]
pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
unsafe { &*(self.storage.as_ptr() as *const D) }
}
#[inline]
pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
check_for_access::<D>(self.dt)?;
unsafe { Ok(self.to_array_view_unchecked()) }
}
#[inline]
pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
if self.len != 0 {
unsafe { ArrayViewD::from_shape_ptr(self.shape, self.storage.as_ptr() as *const D) }
} else {
ArrayViewD::from_shape(self.shape, &[]).unwrap()
}
}
#[inline]
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
self.storage.as_bytes_mut()
}
#[inline]
pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
check_for_access::<D>(self.dt)?;
unsafe { Ok(self.as_ptr_mut_unchecked()) }
}
#[inline]
pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
self.storage.as_mut_ptr() as *mut D
}
#[inline]
pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
check_for_access::<D>(self.dt)?;
unsafe { Ok(self.as_slice_mut_unchecked()) }
}
#[inline]
pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
if self.storage.is_empty() {
&mut []
} else {
let len = self.len;
unsafe { std::slice::from_raw_parts_mut(self.as_ptr_mut_unchecked(), len) }
}
}
#[inline]
pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
check_for_access::<D>(self.dt)?;
unsafe { Ok(self.to_scalar_mut_unchecked()) }
}
#[inline]
pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
unsafe { &mut *(self.storage.as_mut_ptr() as *mut D) }
}
#[inline]
pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
check_for_access::<D>(self.dt)?;
unsafe { Ok(self.to_array_view_mut_unchecked()) }
}
#[inline]
pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
if self.len != 0 {
unsafe {
ArrayViewMutD::from_shape_ptr(self.shape, self.storage.as_mut_ptr() as *mut D)
}
} else {
ArrayViewMutD::from_shape(self.shape, &mut []).unwrap()
}
}
}