#![allow(clippy::missing_safety_doc)]
#![allow(clippy::missing_transmute_annotations)]
mod arena_view;
mod owned;
pub use arena_view::*;
pub use owned::*;
use num_traits::AsPrimitive;
use std::ffi::c_void;
use std::fmt::Display;
use tract_core::internal::*;
use tract_data::itertools::Itertools;
use crate::device::{DeviceBuffer, get_context};
#[derive(Debug, Clone, Hash)]
pub enum DeviceTensor {
Owned(Box<dyn OwnedDeviceTensor>),
ArenaView(DeviceArenaView),
}
impl DeviceTensor {
pub const SUPPORTED_DT: [DatumType; 12] = [
DatumType::Bool,
DatumType::F32,
DatumType::F16,
DatumType::I8,
DatumType::U8,
DatumType::I16,
DatumType::U16,
DatumType::I32,
DatumType::U32,
DatumType::I64,
DatumType::U64,
DatumType::Opaque,
];
pub fn tname(dt: DatumType) -> TractResult<&'static str> {
Ok(match dt {
DatumType::F32 => "f32",
DatumType::F16 => "f16",
DatumType::U8 => "u8",
DatumType::U16 => "u16",
DatumType::U32 => "u32",
DatumType::U64 => "u64",
DatumType::I8 => "i8",
DatumType::I16 => "i16",
DatumType::I32 => "i32",
DatumType::I64 => "i64",
DatumType::Bool => "bool",
DatumType::Opaque => "opaque",
_ => bail!("Unsupported dt {:?} for GPU Tensor", dt),
})
}
pub fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<DeviceTensor> {
Ok(DeviceTensor::Owned(get_context()?.uninitialized_device_tensor(shape, dt)?))
}
pub fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<DeviceTensor> {
Self::uninitialized_dt(T::datum_type(), shape)
}
pub fn from_shape<T: Copy + Datum>(shape: &[usize], data: &[T]) -> TractResult<DeviceTensor> {
Tensor::from_shape(shape, data)?.into_device()
}
pub fn is_supported_dt(dt: DatumType) -> bool {
Self::SUPPORTED_DT.contains(&dt)
}
#[inline]
pub fn datum_type(&self) -> DatumType {
match self {
Self::Owned(owned) => owned.datum_type(),
Self::ArenaView(view) => view.datum_type(),
}
}
#[inline]
pub fn rank(&self) -> usize {
self.shape().len()
}
#[inline]
pub fn shape(&self) -> &[usize] {
match self {
Self::Owned(t) => t.shape(),
Self::ArenaView(t) => t.shape(),
}
}
#[inline]
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
match self {
Self::Owned(t) => t.len(),
Self::ArenaView(t) => t.len(),
}
}
#[inline]
pub fn strides(&self) -> &[isize] {
match self {
Self::Owned(t) => t.strides(),
Self::ArenaView(t) => t.strides(),
}
}
pub fn device_buffer(&self) -> &dyn DeviceBuffer {
match self {
Self::Owned(t) => t.device_buffer(),
Self::ArenaView(t) => t.device_buffer(),
}
}
pub fn buffer_offset<I: Copy + 'static>(&self) -> I
where
usize: AsPrimitive<I>,
{
match self {
Self::Owned(_) => 0.as_(),
Self::ArenaView(t) => t.buffer_offset(),
}
}
pub fn device_buffer_ptr(&self) -> *const c_void {
match self {
Self::Owned(t) => t.device_buffer().ptr(),
Self::ArenaView(t) => t.device_buffer().ptr(),
}
}
#[inline]
pub fn view(&self) -> TensorView<'_> {
match self {
Self::Owned(t) => t.view(),
Self::ArenaView(t) => t.view(),
}
}
pub fn description(&self) -> String {
format!("|{},{:?}|", self.shape().iter().join(","), self.datum_type(),)
}
pub fn reshaped(&self, shape: TVec<usize>) -> TractResult<Self> {
match self {
Self::Owned(t) => Ok(t.reshaped(shape)?),
Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)),
}
}
pub fn restrided(&self, strides: TVec<isize>) -> TractResult<Self> {
match self {
Self::Owned(t) => Ok(t.restrided(strides)?),
Self::ArenaView(t) => Ok(Self::ArenaView(t.restrided(strides)?)),
}
}
pub fn into_opaque_tensor(self) -> Tensor {
tensor0::<Opaque>(self.into())
}
pub fn to_host(&self) -> TractResult<Arc<Tensor>> {
get_context()?.synchronize()?;
Ok(match self {
Self::Owned(o) => o.to_host()?,
Self::ArenaView(v) => v.clone().into_tensor().into(),
})
}
}
impl Display for DeviceTensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Owned(o) => o.fmt(f),
Self::ArenaView(v) => {
let content = v
.clone()
.into_tensor()
.dump(false)
.unwrap_or_else(|e| format!("Error : {e:?}"));
write!(f, "ArenaView: {{ {content} }}")
}
}
}
}
pub trait IntoDevice<T> {
fn into_device(self) -> TractResult<T>;
}
impl IntoDevice<DeviceTensor> for Tensor {
fn into_device(self) -> TractResult<DeviceTensor> {
Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
}
}
impl IntoDevice<DeviceTensor> for Arc<Tensor> {
fn into_device(self) -> TractResult<DeviceTensor> {
Ok(DeviceTensor::Owned(get_context()?.tensor_to_device(self.into_tvalue())?))
}
}
impl From<DeviceTensor> for Opaque {
fn from(value: DeviceTensor) -> Self {
Opaque(Arc::new(value))
}
}
impl From<DeviceArenaView> for DeviceTensor {
fn from(view: DeviceArenaView) -> Self {
Self::ArenaView(view)
}
}
impl OpaquePayload for DeviceTensor {
fn same_as(&self, other: &dyn OpaquePayload) -> bool {
other
.downcast_ref::<Self>()
.is_some_and(|other| self.device_buffer_ptr() == other.device_buffer_ptr())
}
fn clarify_to_tensor(&self) -> TractResult<Option<Arc<Tensor>>> {
Ok(Some(self.to_host()?))
}
}
pub trait DeviceTensorExt {
fn to_device_tensor(&self) -> TractResult<&DeviceTensor>;
fn as_device_tensor(&self) -> Option<&DeviceTensor>;
fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor>;
fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor>;
}
impl DeviceTensorExt for Tensor {
fn to_device_tensor_mut(&mut self) -> TractResult<&mut DeviceTensor> {
let opaque = self.to_scalar_mut::<Opaque>()?;
opaque.downcast_mut::<DeviceTensor>().ok_or_else(|| {
anyhow::anyhow!("Could convert opaque tensor to mutable reference on a device tensor")
})
}
fn as_device_tensor_mut(&mut self) -> Option<&mut DeviceTensor> {
let opaque = self.to_scalar_mut::<Opaque>().ok()?;
opaque.downcast_mut::<DeviceTensor>()
}
fn to_device_tensor(&self) -> TractResult<&DeviceTensor> {
let opaque = self.to_scalar::<Opaque>()?;
opaque.downcast_ref::<DeviceTensor>().ok_or_else(|| {
anyhow::anyhow!("Could convert opaque tensor to reference on a device tensor")
})
}
fn as_device_tensor(&self) -> Option<&DeviceTensor> {
let opaque = self.to_scalar::<Opaque>().ok()?;
opaque.downcast_ref::<DeviceTensor>()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_device_tensor() -> TractResult<()> {
let a = DeviceTensor::from_shape(&[1], &[0f32])?;
assert_eq!(a.to_host()?.as_slice::<f32>()?, &[0.0]);
Ok(())
}
}