#![allow(clippy::missing_safety_doc)]
#![allow(clippy::missing_transmute_annotations)]
mod arena_view;
mod owned;
pub use arena_view::*;
pub use owned::*;
use num_traits::AsPrimitive;
use std::ffi::c_void;
use std::fmt::Display;
use tract_core::internal::*;
use tract_data::itertools::Itertools;
use crate::device::{DeviceBuffer, get_context};
/// A tensor whose data lives in device memory.
///
/// Either owns its storage (`Owned`, for which `buffer_offset` is always 0) or is a
/// view into a device buffer at some offset (`ArenaView`) — NOTE(review): presumably a
/// slice of a larger arena allocation; the arena layout lives in `arena_view`, not here.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum DeviceTensor {
/// Storage produced by the device context (e.g. `uninitialized_device_tensor`,
/// `tensor_to_device`).
Owned(Box<dyn OwnedDeviceTensor>),
/// View into a device buffer starting at its own `buffer_offset`.
ArenaView(DeviceArenaView),
}
impl DeviceTensor {
    /// Element types that may be stored in a device tensor.
    ///
    /// Each entry has a short type name available through [`DeviceTensor::tname`].
    pub const SUPPORTED_DT: [DatumType; 11] = [
        DatumType::Bool,
        DatumType::F32,
        DatumType::F16,
        DatumType::I8,
        DatumType::U8,
        DatumType::I16,
        DatumType::U16,
        DatumType::I32,
        DatumType::U32,
        DatumType::I64,
        DatumType::U64,
    ];

    /// Returns the short lowercase name of `dt` (e.g. `"f32"`, `"bool"`).
    ///
    /// # Errors
    /// Fails for every datum type not listed in [`Self::SUPPORTED_DT`].
    pub fn tname(dt: DatumType) -> TractResult<&'static str> {
        let name = match dt {
            DatumType::Bool => "bool",
            DatumType::F16 => "f16",
            DatumType::F32 => "f32",
            DatumType::I8 => "i8",
            DatumType::I16 => "i16",
            DatumType::I32 => "i32",
            DatumType::I64 => "i64",
            DatumType::U8 => "u8",
            DatumType::U16 => "u16",
            DatumType::U32 => "u32",
            DatumType::U64 => "u64",
            unsupported => bail!("Unsupported dt {:?} for GPU Tensor", unsupported),
        };
        Ok(name)
    }

    /// Allocates an uninitialized owned device tensor of element type `dt` and shape
    /// `shape` through the current device context.
    pub fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<DeviceTensor> {
        let storage = get_context()?.uninitialized_device_tensor(shape, dt)?;
        Ok(Self::Owned(storage))
    }

    /// Typed convenience wrapper around [`Self::uninitialized_dt`].
    pub fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<DeviceTensor> {
        let dt = T::datum_type();
        Self::uninitialized_dt(dt, shape)
    }

    /// Allocates an uninitialized owned device tensor described by `exotic_fact`.
    pub fn uninitialized_exotic(exotic_fact: Box<dyn ExoticFact>) -> TractResult<DeviceTensor> {
        let storage = get_context()?.uninitialized_device_exotic_tensor(exotic_fact)?;
        Ok(Self::Owned(storage))
    }

    /// Builds a host tensor from `shape` and `data`, then uploads it to the device.
    pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> TractResult<DeviceTensor> {
        let host = Tensor::from_shape(shape, data)?;
        host.into_device()
    }

    /// Whether `dt` appears in [`Self::SUPPORTED_DT`].
    pub fn is_supported_dt(dt: DatumType) -> bool {
        Self::SUPPORTED_DT.iter().any(|candidate| *candidate == dt)
    }

    /// Element type of the tensor.
    #[inline]
    pub fn datum_type(&self) -> DatumType {
        match self {
            Self::Owned(owned) => owned.datum_type(),
            Self::ArenaView(view) => view.datum_type(),
        }
    }

    /// Number of dimensions, i.e. the length of [`Self::shape`].
    #[inline]
    pub fn rank(&self) -> usize {
        let dims = self.shape();
        dims.len()
    }

    /// Dimensions of the tensor.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        match self {
            Self::Owned(owned) => owned.shape(),
            Self::ArenaView(view) => view.shape(),
        }
    }

    /// Number of elements, as reported by the underlying variant.
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        match self {
            Self::Owned(owned) => owned.len(),
            Self::ArenaView(view) => view.len(),
        }
    }

    /// Per-dimension strides, as reported by the underlying variant.
    #[inline]
    pub fn strides(&self) -> &[isize] {
        match self {
            Self::Owned(owned) => owned.strides(),
            Self::ArenaView(view) => view.strides(),
        }
    }

    /// Device buffer backing this tensor.
    pub fn device_buffer(&self) -> &dyn DeviceBuffer {
        match self {
            Self::Owned(owned) => owned.device_buffer(),
            Self::ArenaView(view) => view.device_buffer(),
        }
    }

    /// Offset of this tensor's data inside its device buffer, as type `I`.
    ///
    /// Owned tensors always report 0; arena views forward their own offset.
    pub fn buffer_offset<I: Copy + 'static>(&self) -> I
    where
        usize: AsPrimitive<I>,
    {
        match self {
            Self::ArenaView(view) => view.buffer_offset(),
            Self::Owned(_) => 0usize.as_(),
        }
    }

    /// Raw pointer of the backing device buffer.
    ///
    /// This is the buffer's base pointer; [`Self::buffer_offset`] is NOT added.
    pub fn device_buffer_ptr(&self) -> *const c_void {
        self.device_buffer().ptr()
    }

    /// Short summary of the tensor, formatted as `|d0,d1,...,DatumType|`.
    pub fn description(&self) -> String {
        let dims = self.shape().iter().join(",");
        let dt = self.datum_type();
        format!("|{dims},{dt:?}|")
    }

    /// Returns a tensor with the same data and a new `shape`; arena views stay views.
    pub fn reshaped(&self, shape: TVec<usize>) -> TractResult<Self> {
        Ok(match self {
            Self::Owned(owned) => owned.reshaped(shape)?,
            Self::ArenaView(view) => Self::ArenaView(view.reshaped(shape)?),
        })
    }

    /// Returns a tensor with the same data and new `strides`; arena views stay views.
    pub fn restrided(&self, strides: TVec<isize>) -> TractResult<Self> {
        Ok(match self {
            Self::Owned(owned) => owned.restrided(strides)?,
            Self::ArenaView(view) => Self::ArenaView(view.restrided(strides)?),
        })
    }

    /// Wraps this device tensor as the storage of a [`Tensor`] with the same datum type
    /// and shape. No data is copied to the host.
    pub fn into_tensor(self) -> Tensor {
        let shape: TVec<usize> = self.shape().into();
        let dt = self.datum_type();
        Tensor::from_storage(dt, &shape, self)
    }

    /// Copies the tensor back to host memory.
    ///
    /// The current device context is synchronized first (presumably so that pending
    /// device work has completed before the data is read).
    pub fn to_host(&self) -> TractResult<Arc<Tensor>> {
        get_context()?.synchronize()?;
        let host: Arc<Tensor> = match self {
            Self::Owned(owned) => owned.to_host()?,
            Self::ArenaView(view) => view.to_host()?.into(),
        };
        Ok(host)
    }
}
impl Display for DeviceTensor {
    /// Human-readable rendering of the tensor.
    ///
    /// Owned tensors delegate to their own formatter. Arena views are copied to the host
    /// and dumped. Any failure, in either the copy or the dump, is rendered inline as
    /// `Error : ...` rather than panicking. Formatting is frequently invoked from logging
    /// and error paths, where a panic would mask the original problem.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Owned(o) => o.fmt(f),
            Self::ArenaView(v) => {
                let content = v
                    .to_host()
                    .and_then(|host| host.dump(false))
                    .unwrap_or_else(|e| format!("Error : {e:?}"));
                write!(f, "ArenaView: {{ {content} }}")
            }
        }
    }
}
/// Conversion of a host-side value into a device-side value of type `T`.
pub trait IntoDevice<T> {
/// Consumes `self` and returns its device counterpart.
///
/// # Errors
/// For the impls in this file: fails if the device context cannot be obtained or the
/// upload fails.
fn into_device(self) -> TractResult<T>;
}
impl IntoDevice<DeviceTensor> for Tensor {
    /// Uploads this host tensor through the current device context, yielding an owned
    /// device tensor.
    fn into_device(self) -> TractResult<DeviceTensor> {
        let ctx = get_context()?;
        let uploaded = ctx.tensor_to_device(self.into_tvalue())?;
        Ok(DeviceTensor::Owned(uploaded))
    }
}
impl IntoDevice<DeviceTensor> for Arc<Tensor> {
    /// Uploads a shared host tensor through the current device context, yielding an
    /// owned device tensor.
    fn into_device(self) -> TractResult<DeviceTensor> {
        let ctx = get_context()?;
        let uploaded = ctx.tensor_to_device(self.into_tvalue())?;
        Ok(DeviceTensor::Owned(uploaded))
    }
}
impl TensorStorage for DeviceTensor {
    /// Payload size in bytes: element size multiplied by element count.
    fn byte_len(&self) -> usize {
        let elem_size = self.datum_type().size_of();
        elem_size * self.len()
    }

    /// True when the payload occupies zero bytes.
    fn is_empty(&self) -> bool {
        matches!(self.byte_len(), 0)
    }

    /// Boxes a `Clone` of this tensor. NOTE(review): whether device memory is actually
    /// duplicated depends on the variants' `Clone` impls, which are not visible here.
    fn deep_clone(&self) -> Box<dyn TensorStorage> {
        let cloned: DeviceTensor = self.clone();
        Box::new(cloned)
    }

    /// Device storage is never exposed as plain host storage.
    fn as_plain(&self) -> Option<&PlainStorage> {
        None
    }

    /// Device storage is never exposed as plain host storage.
    fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
        None
    }

    /// Device storage cannot be converted into plain host storage.
    fn into_plain(self: Box<Self>) -> Option<PlainStorage> {
        None
    }

    /// Intentionally feeds nothing into the hasher, so every device tensor hashes
    /// identically through this path. That is valid but weak.
    /// NOTE(review): `DeviceTensor` derives `Hash`. Forwarding to it was not done, possibly
    /// to stay consistent with `Tensor` equality; confirm before changing.
    fn dyn_hash(&self, _state: &mut dyn std::hash::Hasher) {}

    /// Always fails: the storage does not record where the tensor came from, so the
    /// matching fact cannot be rebuilt from it.
    fn exotic_fact(&self, _shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>> {
        bail!(
            "DeviceTensor cannot reconstruct a DeviceFact: origin (FromHost/FromDevice) is not carried by storage"
        )
    }
}
impl From<DeviceArenaView> for DeviceTensor {
fn from(view: DeviceArenaView) -> Self {
Self::ArenaView(view)
}
}
/// Accessors for values whose storage may be a [`DeviceTensor`].
pub trait DeviceTensorExt {
/// Returns the `DeviceTensor` storage, or an error when the storage is something else.
fn to_device_tensor(&self) -> TractResult<&DeviceTensor>;
/// Returns the `DeviceTensor` storage, or `None` when the storage is something else.
fn as_device_tensor(&self) -> Option<&DeviceTensor>;
}
impl DeviceTensorExt for Tensor {
fn to_device_tensor(&self) -> TractResult<&DeviceTensor> {
self.try_storage_as::<DeviceTensor>()
}
fn as_device_tensor(&self) -> Option<&DeviceTensor> {
self.storage_as::<DeviceTensor>()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a single-element f32 tensor through the device and back.
    #[test]
    fn test_device_tensor() -> TractResult<()> {
        let on_device = DeviceTensor::from_shape(&[1], &[0f32])?;
        let back_on_host = on_device.to_host()?;
        assert_eq!(back_on_host.try_as_plain()?.as_slice::<f32>()?, &[0.0]);
        Ok(())
    }
}