use std::fmt::Display;
use tract_core::internal::*;
use tract_gpu::device::DeviceBuffer;
use tract_gpu::tensor::{DeviceTensor, OwnedDeviceTensor};
use tract_gpu::utils::check_strides_validity;
use crate::context::MetalBuffer;
/// Host-side value backing a [`MetalTensor`].
///
/// `Natural` uses a tensor with its own shape and strides. `Reshaped` is a view
/// over a shared tensor `t` that carries its own logical `shape` and `strides`
/// (built by `MValue::reshaped` / `MValue::restrided`); `t` itself is shared via
/// `Arc::clone`, not copied.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum MValue {
/// Tensor used as-is: shape and strides are those of the tensor itself.
Natural(Arc<Tensor>),
/// View over the shared tensor `t` with an overriding logical `shape` and `strides`.
Reshaped { t: Arc<Tensor>, shape: TVec<usize>, strides: TVec<isize> },
}
impl MValue {
    /// Element type of the underlying tensor (identical for both variants).
    #[inline]
    pub fn datum_type(&self) -> DatumType {
        match self {
            MValue::Natural(t) | MValue::Reshaped { t, .. } => t.datum_type(),
        }
    }

    /// Logical shape of the value: the tensor's own shape for `Natural`, the
    /// overriding view shape for `Reshaped`.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        match self {
            MValue::Natural(t) => t.shape(),
            MValue::Reshaped { shape, .. } => shape,
        }
    }

    /// Number of elements described by the logical shape (product of its dimensions).
    #[inline]
    // Mirrors `OwnedDeviceTensor::len`; this type deliberately exposes no `is_empty`.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Returns a view of the same underlying tensor with logical shape `shape` and
    /// the natural strides for that shape (`Tensor::natural_strides`).
    ///
    /// If `shape` equals the current shape, `self` is returned unchanged (including
    /// any custom strides it may carry).
    ///
    /// # Errors
    /// Fails when `shape` does not hold the same number of elements as `self`.
    pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
        let shape = shape.into();
        if self.len() != shape.iter().product::<usize>() {
            bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
        }
        if shape.as_slice() == self.shape() {
            return Ok(self.clone());
        }
        let t = match self {
            MValue::Natural(t) | MValue::Reshaped { t, .. } => t,
        };
        Ok(Self::Reshaped { t: Arc::clone(t), strides: Tensor::natural_strides(&shape), shape })
    }

    /// Returns a view of the same underlying tensor with the current logical shape
    /// and the given `strides`.
    ///
    /// When `strides` already equal the current strides, `self` is returned
    /// unchanged for both variants. A `Natural` value therefore stays `Natural`
    /// (keeping `as_arc_tensor` usable) instead of becoming an equivalent `Reshaped`.
    ///
    /// # Errors
    /// Fails when `check_strides_validity` rejects `strides` for the current shape.
    pub fn restrided(&self, strides: impl Into<TVec<isize>>) -> TractResult<Self> {
        let strides = strides.into();
        check_strides_validity(self.shape().into(), strides.clone())?;
        let (t, current): (&Arc<Tensor>, &[isize]) = match self {
            MValue::Natural(t) => (t, t.strides()),
            MValue::Reshaped { t, strides: old, .. } => (t, old.as_slice()),
        };
        if strides.as_slice() == current {
            return Ok(self.clone());
        }
        Ok(Self::Reshaped { t: Arc::clone(t), strides, shape: self.shape().into() })
    }

    /// Returns the underlying tensor only for `Natural` values.
    ///
    /// `Reshaped` views yield `None`: the tensor's own shape/strides do not
    /// describe the view.
    pub fn as_arc_tensor(&self) -> Option<&Arc<Tensor>> {
        match self {
            MValue::Natural(t) => Some(t),
            MValue::Reshaped { .. } => None,
        }
    }
}
impl IntoTensor for MValue {
    /// Converts the value into an owned host `Tensor`.
    ///
    /// The shared tensor is moved out when this is the last reference and cloned
    /// otherwise. For a `Reshaped` view, the view's shape is applied to the result.
    ///
    /// NOTE(review): the view's strides are NOT applied. A restrided value comes
    /// back in the underlying tensor's layout, relabelled with the view shape.
    ///
    /// # Panics
    /// Panics if the view shape cannot be applied to the underlying tensor.
    fn into_tensor(self) -> Tensor {
        fn take_or_clone(shared: Arc<Tensor>) -> Tensor {
            Arc::try_unwrap(shared).unwrap_or_else(|still_shared| (*still_shared).clone())
        }

        let (tensor, view_shape) = match self {
            Self::Natural(t) => (t, None),
            Self::Reshaped { t, shape, .. } => (t, Some(shape)),
        };
        let mut owned = take_or_clone(tensor);
        if let Some(shape) = view_shape {
            owned.set_shape(&shape).expect("Could not apply shape to reshaped GPU tensor");
        }
        owned
    }
}
impl From<Tensor> for MValue {
fn from(v: Tensor) -> Self {
Self::Natural(Arc::new(v))
}
}
impl From<Arc<Tensor>> for MValue {
fn from(v: Arc<Tensor>) -> Self {
Self::Natural(v)
}
}
/// Device tensor for the Metal backend: a host-side [`MValue`] paired with a
/// `MetalBuffer` (presumably the device-side copy of the data; confirm in
/// `crate::context`).
///
/// Equality compares all fields, while `Hash` only hashes `inner`. Equal values
/// therefore always hash equally.
#[derive(Clone, PartialEq, Eq)]
pub struct MetalTensor {
/// Host-side value; it determines the datum type, shape and strides reported
/// through `OwnedDeviceTensor`.
pub inner: MValue,
/// Buffer returned by `device_buffer()`; reshaped/restrided views clone it
/// rather than allocating a new one.
pub device_buffer: MetalBuffer,
/// Optional exotic fact, carried unchanged into reshaped/restrided views.
pub exotic_fact: Option<Box<dyn ExoticFact>>,
}
impl std::fmt::Debug for MetalTensor {
    /// Prints `MetalTensor: ` followed by the `Debug` form of the inner value.
    /// The device buffer and exotic fact are intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner = &self.inner;
        f.write_fmt(format_args!("MetalTensor: {:?}", inner))
    }
}
impl Hash for MetalTensor {
    /// Hashes only the host-side value.
    ///
    /// This is consistent with the derived `PartialEq`, which compares `inner`
    /// among other fields.
    #[inline]
    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
        Hash::hash(&self.inner, hasher);
    }
}
impl OwnedDeviceTensor for MetalTensor {
    /// Element type of the inner value.
    fn datum_type(&self) -> DatumType {
        self.inner.datum_type()
    }

    /// Logical shape of the inner value (the view shape for reshaped tensors).
    #[inline]
    fn shape(&self) -> &[usize] {
        self.inner.shape()
    }

    /// Number of elements in the logical shape.
    #[inline]
    // Trait-mandated `len`; no `is_empty` is part of this interface.
    #[allow(clippy::len_without_is_empty)]
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Strides of the inner value: the tensor's own strides for `Natural`, the
    /// view strides for `Reshaped`.
    #[inline]
    fn strides(&self) -> &[isize] {
        match &self.inner {
            MValue::Natural(t) => t.strides(),
            MValue::Reshaped { strides, .. } => strides,
        }
    }

    /// The Metal buffer associated with this tensor.
    #[inline]
    fn device_buffer(&self) -> &dyn DeviceBuffer {
        &self.device_buffer
    }

    /// New tensor viewing the same buffer with logical shape `shape`.
    ///
    /// # Errors
    /// Propagates `MValue::reshaped` errors (element-count mismatch).
    #[inline]
    fn reshaped(&self, shape: TVec<usize>) -> TractResult<DeviceTensor> {
        Ok(DeviceTensor::Owned(Box::new(Self {
            inner: self.inner.reshaped(shape)?,
            device_buffer: self.device_buffer.clone(),
            exotic_fact: self.exotic_fact.clone(),
        })))
    }

    /// New tensor viewing the same buffer with the given `strides`.
    ///
    /// # Errors
    /// Propagates `MValue::restrided` errors (invalid strides).
    #[inline]
    fn restrided(&self, strides: TVec<isize>) -> TractResult<DeviceTensor> {
        Ok(DeviceTensor::Owned(Box::new(Self {
            inner: self.inner.restrided(strides)?,
            device_buffer: self.device_buffer.clone(),
            exotic_fact: self.exotic_fact.clone(),
        })))
    }

    /// Host tensor for this value.
    ///
    /// A `Natural` value shares its tensor directly. A `Reshaped` view is
    /// converted via `IntoTensor`, which may clone the data.
    fn to_host(&self) -> TractResult<Arc<Tensor>> {
        Ok(self
            .inner
            .as_arc_tensor()
            .cloned()
            .unwrap_or_else(|| self.inner.clone().into_tensor().into_arc_tensor()))
    }

    /// Optional exotic fact attached to this tensor.
    fn exotic_fact(&self) -> Option<&dyn ExoticFact> {
        self.exotic_fact.as_deref()
    }

    /// Copies `len` raw bytes starting at byte `offset` of the underlying tensor.
    ///
    /// Reshaped/restrided views share the underlying tensor (and device buffer)
    /// with the value they were derived from, so the bytes are read from that
    /// tensor for both variants. Previously any `Reshaped` view panicked here.
    ///
    /// # Panics
    /// Panics if `offset..offset + len` is out of bounds of the tensor's bytes.
    fn get_bytes_slice(&self, offset: usize, len: usize) -> Vec<u8> {
        let tensor = match &self.inner {
            MValue::Natural(t) | MValue::Reshaped { t, .. } => t,
        };
        tensor.as_bytes()[offset..offset + len].to_vec()
    }
}
impl Display for MetalTensor {
    /// Renders the underlying tensor dump inside `GPU { ... }`.
    ///
    /// Reshaped views are prefixed with their view shape. A failed dump is shown
    /// as `Error : <err>` instead of aborting the formatting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (tensor, view_shape) = match &self.inner {
            MValue::Natural(t) => (t, None),
            MValue::Reshaped { t, shape, .. } => (t, Some(shape)),
        };
        let content = tensor.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
        match view_shape {
            None => write!(f, "GPU {{ {content} }}"),
            Some(shape) => write!(f, "GPU reshaped: {:?} - {{ {content} }}", shape),
        }
    }
}