// tract-gpu 0.23.0-dev.4
//
// Tiny, no-nonsense, self-contained TensorFlow and ONNX inference.
use std::fmt;
use tract_core::internal::*;

/// Where a tensor described by a device (GPU) fact comes from.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum DeviceTensorOrigin {
    /// Output of a device operator.
    ///
    /// May be either a Host or an ArenaView tensor. Tensors with this origin
    /// come from asynchronous operations.
    FromDevice,
    /// Built from a CPU tensor (the output of a CPU op, or a Const).
    ///
    /// Can only be a Host tensor. Tensors with this origin come from
    /// synchronous operations.
    FromHost,
}

/// A `TypedFact` annotated with device-side metadata (its origin and a
/// `state_owned` flag).
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct DeviceFact {
    /// Where the tensor described by this fact originates.
    pub origin: DeviceTensorOrigin,
    /// The underlying fact (datum type, shape, ...). `DeviceFact::new` stores
    /// it through `without_value()`.
    pub fact: TypedFact,
    /// Flag reported by `is_state_owned`; `DeviceFact::new` sets it to `false`.
    // NOTE(review): presumably marks tensors owned by an op state — confirm.
    pub state_owned: bool,
}

impl DeviceFact {
    /// Wraps a plain `TypedFact` into a `DeviceFact` with the given `origin`.
    ///
    /// The stored fact is `fact.without_value()`, and `state_owned` starts out
    /// as `false`.
    ///
    /// # Errors
    /// Fails if `fact` already carries a `DeviceFact` as its exotic fact, so
    /// device facts are never nested inside one another.
    pub fn new(origin: DeviceTensorOrigin, fact: TypedFact) -> TractResult<Self> {
        ensure!(fact.as_device_fact().is_none());
        Ok(Self { fact: fact.without_value(), origin, state_owned: false })
    }

    /// Shorthand for [`DeviceFact::new`] with [`DeviceTensorOrigin::FromHost`].
    pub fn from_host(fact: TypedFact) -> TractResult<Self> {
        Self::new(DeviceTensorOrigin::FromHost, fact)
    }

    /// `true` when the tensor was produced by a device operator.
    pub fn is_from_device(&self) -> bool {
        self.origin == DeviceTensorOrigin::FromDevice
    }

    /// `true` when the tensor was built from a host (CPU) tensor.
    pub fn is_from_host(&self) -> bool {
        self.origin == DeviceTensorOrigin::FromHost
    }

    /// Returns the `state_owned` flag.
    pub fn is_state_owned(&self) -> bool {
        self.state_owned
    }

    /// Unwraps the inner `TypedFact`, discarding the device metadata.
    pub fn into_typed_fact(self) -> TypedFact {
        self.fact
    }

    /// Builds a fresh `TypedFact` with the inner fact's datum type and shape,
    /// and attaches `self` to it as the exotic fact.
    pub fn into_exotic_fact(self) -> TypedFact {
        // Copy dt/shape out first: `self` is moved into `with_exotic_fact`.
        let (datum_type, shape) = (self.fact.datum_type, self.fact.shape.clone());
        TypedFact::dt_shape(datum_type, shape).with_exotic_fact(self)
    }
}

impl ExoticFact for DeviceFact {
    /// Reports the inner fact's datum type and shape.
    fn clarify_dt_shape(&self) -> Option<(DatumType, TVec<TDim>)> {
        let inner = &self.fact;
        Some((inner.datum_type, inner.shape.to_tvec()))
    }

    /// The first entry is the inner tensor's byte size (volume × element size).
    /// It is followed by the buffer sizes of the inner fact's own exotic fact,
    /// when there is one.
    fn buffer_sizes(&self) -> TVec<TDim> {
        let main_buffer = self.fact.shape.volume() * self.fact.datum_type.size_of();
        let mut sizes: TVec<TDim> = tvec!(main_buffer);
        if let Some(nested) = self.fact.exotic_fact() {
            sizes.extend(nested.buffer_sizes());
        }
        sizes
    }

    /// Any other `DeviceFact` counts as compatible, whatever its origin or inner fact.
    fn compatible_with(&self, other: &dyn ExoticFact) -> bool {
        other.is::<DeviceFact>()
    }
}

impl fmt::Debug for DeviceFact {
    /// Formats as `FromHost(<fact>)` or `FromDevice(<fact>)`, where `<fact>` is
    /// the Debug output of the inner `TypedFact` without its value.
    /// `state_owned` is not shown.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let origin = match self.origin {
            DeviceTensorOrigin::FromHost => "FromHost",
            DeviceTensorOrigin::FromDevice => "FromDevice",
        };
        // Go through `self.fact` explicitly for both origins instead of relying
        // on `Deref` to reach `TypedFact::without_value`.
        write!(fmt, "{}({:?})", origin, self.fact.without_value())
    }
}

/// Accessors for a [`DeviceFact`] attached as the exotic fact of a fact type.
pub trait DeviceTypedFactExt {
    /// Returns the attached `DeviceFact`, or `None` if there is none.
    fn as_device_fact(&self) -> Option<&DeviceFact>;
    /// Mutable counterpart of [`DeviceTypedFactExt::as_device_fact`].
    fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact>;
    /// Like `as_device_fact`, but reports an error instead of returning `None`.
    fn to_device_fact(&self) -> TractResult<&DeviceFact>;
}

impl DeviceTypedFactExt for TypedFact {
    /// Returns the `DeviceFact` stored in `exotic_fact`.
    ///
    /// # Errors
    /// Fails with "DeviceFact not found" when there is no exotic fact or it is
    /// not a `DeviceFact`.
    fn to_device_fact(&self) -> TractResult<&DeviceFact> {
        // Reuse the fallible lookup instead of duplicating the downcast logic.
        self.as_device_fact().ok_or_else(|| anyhow!("DeviceFact not found"))
    }

    /// Downcasts `exotic_fact` to a `DeviceFact`, if present and of that type.
    fn as_device_fact(&self) -> Option<&DeviceFact> {
        self.exotic_fact.as_ref().and_then(|m| m.downcast_ref::<DeviceFact>())
    }

    /// Mutable variant of `as_device_fact`.
    fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact> {
        self.exotic_fact.as_mut().and_then(|m| m.downcast_mut::<DeviceFact>())
    }
}

impl std::ops::Deref for DeviceFact {
    type Target = TypedFact;
    fn deref(&self) -> &Self::Target {
        &self.fact
    }
}

impl std::convert::AsRef<TypedFact> for DeviceFact {
    fn as_ref(&self) -> &TypedFact {
        &self.fact
    }
}