use num_traits::AsPrimitive;
use std::ffi::c_void;
use std::fmt::Display;
use tract_core::internal::*;
use tract_core::tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
use crate::device::{DeviceBuffer, get_context};
use crate::utils::check_strides_validity;
use super::OwnedDeviceTensor;
/// A typed, shaped window onto a region of a shared device arena.
///
/// The view does not own device memory itself: it holds a reference-counted
/// handle to the arena tensor plus the metadata (datum type, shape, strides,
/// byte offset) needed to interpret a region of it.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeviceArenaView {
/// Shared handle to the arena tensor that backs this view.
pub(crate) arena: Arc<Box<dyn OwnedDeviceTensor>>,
/// Datum type of the elements exposed by the view.
pub(crate) dt: DatumType,
/// Number of elements in the view (multiplied by `dt.size_of()` to get the
/// byte size when no exotic fact is attached).
pub(crate) len: usize,
/// Dimensions of the view.
pub(crate) shape: TVec<usize>,
/// Per-axis strides of the view.
// NOTE(review): presumably expressed in elements rather than bytes — confirm.
pub(crate) strides: TVec<isize>,
/// Offset, in bytes, of the view's first byte inside the arena.
pub(crate) offset_bytes: usize,
/// Optional exotic fact; when present its `mem_size()` drives the byte size
/// of the view, and reshaping / restriding are refused.
pub(crate) exotic_fact: Option<Box<dyn ExoticFact>>,
}
impl DeviceArenaView {
#[inline]
pub fn shape(&self) -> &[usize] {
self.shape.as_slice()
}
#[inline]
pub fn datum_type(&self) -> DatumType {
self.dt
}
#[inline]
pub fn strides(&self) -> &[isize] {
self.strides.as_slice()
}
pub fn device_buffer(&self) -> &dyn DeviceBuffer {
self.arena.device_buffer()
}
pub fn device_buffer_ptr(&self) -> *const c_void {
self.arena.device_buffer().ptr()
}
pub fn buffer_offset<I: Copy + 'static>(&self) -> I
where
usize: AsPrimitive<I>,
{
self.offset_bytes.as_()
}
pub fn exotic_fact(&self) -> Option<&dyn ExoticFact> {
self.exotic_fact.as_deref()
}
#[inline]
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.len
}
pub fn as_bytes(&self) -> Vec<u8> {
let len = if let Some(of) = &self.exotic_fact {
of.mem_size().as_i64().unwrap() as usize
} else {
self.len() * self.dt.size_of()
};
self.arena.get_bytes_slice(self.offset_bytes, len)
}
pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
ensure!(self.exotic_fact.is_none(), "Can't reshape exotic tensor");
let shape = shape.into();
if self.len() != shape.iter().product::<usize>() {
bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
}
if shape.as_slice() != self.shape() {
Ok(Self {
arena: Arc::clone(&self.arena),
dt: self.dt,
len: self.len,
strides: Tensor::natural_strides(&shape),
shape,
offset_bytes: self.offset_bytes,
exotic_fact: None,
})
} else {
Ok(self.clone())
}
}
pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
ensure!(self.exotic_fact.is_none(), "Can't restride exotic tensor");
let strides = strides.into();
check_strides_validity(self.shape().into(), strides.clone())?;
if strides.as_slice() != self.strides() {
Ok(Self {
arena: Arc::clone(&self.arena),
dt: self.dt,
len: self.len,
strides,
shape: self.shape.clone(),
offset_bytes: self.offset_bytes,
exotic_fact: None,
})
} else {
Ok(self.clone())
}
}
pub fn to_host(&self) -> TractResult<Tensor> {
get_context()?.synchronize()?;
let content = self.as_bytes();
unsafe {
if let Some(bqf) =
self.exotic_fact.as_ref().and_then(|of| of.downcast_ref::<BlockQuantFact>())
{
Ok(BlockQuantStorage::new(
bqf.format.clone(),
bqf.m(),
bqf.k(),
Arc::new(Blob::from_bytes(&content)?),
)?
.into_tensor_with_shape(self.dt, bqf.shape()))
} else {
Tensor::from_raw_dt(self.dt, &self.shape, &content)
}
}
}
}
impl Display for DeviceArenaView {
    /// Formats the view by copying it to the host and dumping the tensor.
    ///
    /// A failure to bring the data to the host or to dump it is rendered as
    /// an `Error : …` message inside the output rather than panicking, so
    /// formatting a view (e.g. in a log line) cannot abort the program.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `to_host` borrows `self`, so no clone of the view is needed.
        let content = match self.to_host() {
            Ok(tensor) => tensor.dump(false).unwrap_or_else(|e| format!("Error : {e:?}")),
            Err(e) => format!("Error : {e:?}"),
        };
        write!(f, "DeviceArenaView: {{ {content} }}")
    }
}