use std::fmt;
use tract_core::internal::*;
/// Origin tag carried by a [`DeviceFact`].
///
/// `DeviceFact::from_host` produces the `FromHost` tag; the tag is queried
/// through `DeviceFact::is_from_host` / `DeviceFact::is_from_device`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum DeviceTensorOrigin {
    /// Tagged as originating on the device side.
    FromDevice,
    /// Tagged as originating on the host side.
    FromHost,
}
/// A [`TypedFact`] annotated with device-related metadata.
///
/// `Debug` is implemented by hand rather than derived.
#[derive(Clone, Eq, Hash, PartialEq)]
pub struct DeviceFact {
    /// Origin tag of the described tensor.
    pub origin: DeviceTensorOrigin,
    /// The wrapped fact; `DeviceFact::new` stores it after `without_value()`.
    pub fact: TypedFact,
    /// Flag reported by `is_state_owned`; `DeviceFact::new` initialises it to `false`.
    /// NOTE(review): presumably marks tensors owned by a stateful op — confirm with callers.
    pub state_owned: bool,
}
impl DeviceFact {
    /// Wraps `fact` as a device fact tagged with `origin`.
    ///
    /// The stored fact is `fact.without_value()`, and `state_owned` starts as `false`.
    ///
    /// # Errors
    /// Fails if `fact` already carries a [`DeviceFact`] as its exotic fact
    /// (nesting device facts is rejected).
    pub fn new(origin: DeviceTensorOrigin, fact: TypedFact) -> TractResult<Self> {
        // A bare `ensure!` only reports "Condition failed: ..."; name the problem
        // and show the offending fact so the failure is actionable.
        ensure!(
            fact.as_device_fact().is_none(),
            "Cannot build a DeviceFact from a fact that already wraps a DeviceFact: {:?}",
            fact
        );
        let fact = fact.without_value();
        Ok(Self { origin, fact, state_owned: false })
    }

    /// Shorthand for [`DeviceFact::new`] with [`DeviceTensorOrigin::FromHost`].
    ///
    /// # Errors
    /// Same as [`DeviceFact::new`].
    pub fn from_host(fact: TypedFact) -> TractResult<Self> {
        Self::new(DeviceTensorOrigin::FromHost, fact)
    }

    /// Returns `true` when the origin tag is [`DeviceTensorOrigin::FromDevice`].
    pub fn is_from_device(&self) -> bool {
        matches!(self.origin, DeviceTensorOrigin::FromDevice)
    }

    /// Returns the `state_owned` flag.
    pub fn is_state_owned(&self) -> bool {
        self.state_owned
    }

    /// Returns `true` when the origin tag is [`DeviceTensorOrigin::FromHost`].
    pub fn is_from_host(&self) -> bool {
        matches!(self.origin, DeviceTensorOrigin::FromHost)
    }

    /// Consumes the device fact and returns the wrapped [`TypedFact`].
    pub fn into_typed_fact(self) -> TypedFact {
        self.fact
    }

    /// Consumes the device fact and returns a [`TypedFact`] with the same datum
    /// type and shape whose exotic fact is this `DeviceFact`.
    pub fn into_exotic_fact(self) -> TypedFact {
        let dt = self.fact.datum_type;
        // The shape must be cloned because `self` is moved into the exotic slot below.
        let shape = self.fact.shape.clone();
        TypedFact::dt_shape(dt, shape).with_exotic_fact(self)
    }
}
impl ExoticFact for DeviceFact {
    /// Reports the datum type and shape of the wrapped fact.
    fn clarify_dt_shape(&self) -> Option<(DatumType, TVec<TDim>)> {
        let inner = &self.fact;
        let dims: TVec<TDim> = inner.shape.to_tvec();
        Some((inner.datum_type, dims))
    }

    /// The first entry is the wrapped fact's volume times its element size.
    /// If the wrapped fact has its own exotic fact, that fact's buffer sizes follow.
    fn buffer_sizes(&self) -> TVec<TDim> {
        let inner = &self.fact;
        let own_size = inner.shape.volume() * inner.datum_type.size_of();
        let nested = inner.exotic_fact().map(|of| of.buffer_sizes()).unwrap_or_default();
        std::iter::once(own_size).chain(nested).collect()
    }

    /// Any other `DeviceFact` is compatible, regardless of origin or shape.
    fn compatible_with(&self, other: &dyn ExoticFact) -> bool {
        other.is::<DeviceFact>()
    }
}
impl fmt::Debug for DeviceFact {
    /// Renders as `FromHost(<fact>)` or `FromDevice(<fact>)`, where `<fact>` is
    /// the wrapped fact passed through `without_value()`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let inner = self.fact.without_value();
        match self.origin {
            DeviceTensorOrigin::FromHost => write!(f, "FromHost({:?})", inner),
            DeviceTensorOrigin::FromDevice => write!(f, "FromDevice({:?})", inner),
        }
    }
}
/// Accessors for a [`DeviceFact`] stored as the exotic fact of a [`TypedFact`].
pub trait DeviceTypedFactExt {
    /// Borrows the exotic fact as a [`DeviceFact`], if it is one.
    fn as_device_fact(&self) -> Option<&DeviceFact>;
    /// Mutably borrows the exotic fact as a [`DeviceFact`], if it is one.
    fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact>;
    /// Like [`DeviceTypedFactExt::as_device_fact`], but returns an error when no
    /// `DeviceFact` is present.
    fn to_device_fact(&self) -> TractResult<&DeviceFact>;
}
impl DeviceTypedFactExt for TypedFact {
    /// Errors with "DeviceFact not found" when the exotic fact is absent or is
    /// not a `DeviceFact`.
    fn to_device_fact(&self) -> TractResult<&DeviceFact> {
        match self.as_device_fact() {
            Some(device_fact) => Ok(device_fact),
            None => Err(anyhow!("DeviceFact not found")),
        }
    }

    /// Downcasts the exotic fact, returning `None` if it is absent or of another type.
    fn as_device_fact(&self) -> Option<&DeviceFact> {
        let exotic = self.exotic_fact.as_ref()?;
        exotic.downcast_ref::<DeviceFact>()
    }

    /// Mutable counterpart of `as_device_fact`.
    fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact> {
        let exotic = self.exotic_fact.as_mut()?;
        exotic.downcast_mut::<DeviceFact>()
    }
}
impl std::ops::Deref for DeviceFact {
type Target = TypedFact;
fn deref(&self) -> &Self::Target {
&self.fact
}
}
/// Cheap borrow of the wrapped [`TypedFact`].
impl std::convert::AsRef<TypedFact> for DeviceFact {
    fn as_ref(&self) -> &TypedFact {
        let Self { fact, .. } = self;
        fact
    }
}