Skip to main content

tract_gpu/
fact.rs

1use std::fmt;
2use tract_core::internal::*;
3
/// Where the data behind a GPU tensor was produced.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceTensorOrigin {
    /// The tensor is the output of a device operator.
    ///
    /// Such a tensor can be either a Host or an ArenaView tensor.
    /// Note: tensors marked as `FromDevice` come from asynchronous operations.
    FromDevice,
    /// The tensor was built from a CPU tensor (output of a CPU op, or a Const).
    ///
    /// Such a tensor can only be a Host tensor.
    /// Note: tensors marked as `FromHost` come from synchronous operations.
    FromHost,
}
16
17#[derive(Clone, PartialEq, Eq, Hash)]
18pub struct DeviceFact {
19    pub origin: DeviceTensorOrigin,
20    pub fact: TypedFact,
21    pub state_owned: bool,
22}
23
24impl DeviceFact {
25    pub fn new(origin: DeviceTensorOrigin, fact: TypedFact) -> TractResult<Self> {
26        ensure!(fact.as_device_fact().is_none());
27        let new_fact = fact.without_value();
28        Ok(Self { origin, fact: new_fact, state_owned: false })
29    }
30
31    pub fn from_host(fact: TypedFact) -> TractResult<Self> {
32        Self::new(DeviceTensorOrigin::FromHost, fact)
33    }
34
35    pub fn is_from_device(&self) -> bool {
36        matches!(self.origin, DeviceTensorOrigin::FromDevice)
37    }
38
39    pub fn is_state_owned(&self) -> bool {
40        self.state_owned
41    }
42
43    pub fn is_from_host(&self) -> bool {
44        matches!(self.origin, DeviceTensorOrigin::FromHost)
45    }
46
47    pub fn into_typed_fact(self) -> TypedFact {
48        self.fact
49    }
50
51    pub fn into_opaque_fact(self) -> TypedFact {
52        TypedFact::dt_scalar(DatumType::Opaque).with_opaque_fact(self)
53    }
54}
55
56impl OpaqueFact for DeviceFact {
57    fn clarify_dt_shape(&self) -> Option<(DatumType, TVec<TDim>)> {
58        Some((self.fact.datum_type, self.fact.shape.to_tvec()))
59    }
60
61    fn buffer_sizes(&self) -> TVec<TDim> {
62        let inner_fact = &self.fact;
63        let mut sizes = tvec!(inner_fact.shape.volume() * inner_fact.datum_type.size_of());
64        if let Some(of) = inner_fact.opaque_fact() {
65            sizes.extend(of.buffer_sizes());
66        }
67        sizes
68    }
69
70    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
71        other.downcast_ref::<Self>().is_some_and(|o| o == self)
72    }
73    fn compatible_with(&self, other: &dyn OpaqueFact) -> bool {
74        other.is::<Self>()
75    }
76}
77
78impl fmt::Debug for DeviceFact {
79    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
80        match self.origin {
81            DeviceTensorOrigin::FromHost => write!(fmt, "FromHost({:?})", self.without_value()),
82            DeviceTensorOrigin::FromDevice => {
83                write!(fmt, "FromDevice({:?})", self.fact.without_value())
84            }
85        }
86    }
87}
88
89pub trait DeviceTypedFactExt {
90    fn to_device_fact(&self) -> TractResult<&DeviceFact>;
91    fn as_device_fact(&self) -> Option<&DeviceFact>;
92    fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact>;
93}
94
95impl DeviceTypedFactExt for TypedFact {
96    fn to_device_fact(&self) -> TractResult<&DeviceFact> {
97        ensure!(
98            self.datum_type == DatumType::Opaque,
99            "Cannot retrieve DeviceFact from a non Opaque Tensor"
100        );
101        self.opaque_fact
102            .as_ref()
103            .and_then(|m| m.downcast_ref::<DeviceFact>())
104            .ok_or_else(|| anyhow!("DeviceFact not found in Opaque Tensor"))
105    }
106    fn as_device_fact(&self) -> Option<&DeviceFact> {
107        self.opaque_fact.as_ref().and_then(|m| m.downcast_ref::<DeviceFact>())
108    }
109    fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact> {
110        self.opaque_fact.as_mut().and_then(|m| m.downcast_mut::<DeviceFact>())
111    }
112}
113
114impl std::ops::Deref for DeviceFact {
115    type Target = TypedFact;
116    fn deref(&self) -> &Self::Target {
117        &self.fact
118    }
119}
120
121impl std::convert::AsRef<TypedFact> for DeviceFact {
122    fn as_ref(&self) -> &TypedFact {
123        &self.fact
124    }
125}