Skip to main content

tract_gpu/
fact.rs

1use std::fmt;
2use tract_core::internal::*;
3
/// Origin of a device (GPU) tensor.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceTensorOrigin {
    /// Tensor produced by a device operator.
    /// Can be either a Host or an ArenaView tensor.
    /// Note: tensors marked as `FromDevice` come from asynchronous operations.
    FromDevice,
    /// Tensor built from a CPU tensor (CPU op output or Const).
    /// Can only be a Host tensor.
    /// Note: tensors marked as `FromHost` come from synchronous operations.
    FromHost,
}
16
/// A [`TypedFact`] paired with the [`DeviceTensorOrigin`] of the tensor it describes.
///
/// `Debug` is implemented by hand further down in this file rather than derived.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct DeviceFact {
    /// Where the described tensor comes from (device operator or host tensor).
    pub origin: DeviceTensorOrigin,
    /// The wrapped typed fact.
    pub fact: TypedFact,
}
22
23impl DeviceFact {
24    pub fn new(origin: DeviceTensorOrigin, fact: TypedFact) -> TractResult<Self> {
25        ensure!(fact.as_device_fact().is_none());
26        let mut fact_wo_cst = fact.clone();
27        if fact.opaque_fact.is_some() {
28            fact_wo_cst.konst = None;
29            fact_wo_cst.uniform = None;
30        }
31        Ok(Self { origin, fact: fact_wo_cst })
32    }
33
34    pub fn from_host(fact: TypedFact) -> TractResult<Self> {
35        Self::new(DeviceTensorOrigin::FromHost, fact)
36    }
37
38    pub fn is_from_device(&self) -> bool {
39        matches!(self.origin, DeviceTensorOrigin::FromDevice)
40    }
41
42    pub fn is_from_host(&self) -> bool {
43        matches!(self.origin, DeviceTensorOrigin::FromHost)
44    }
45
46    pub fn into_typed_fact(self) -> TypedFact {
47        self.fact
48    }
49
50    pub fn into_opaque_fact(self) -> TypedFact {
51        TypedFact::dt_scalar(DatumType::Opaque).with_opaque_fact(self)
52    }
53}
54
55impl OpaqueFact for DeviceFact {
56    fn clarify_dt_shape(&self) -> Option<(DatumType, &[usize])> {
57        self.fact.shape.as_concrete().map(|s| (self.fact.datum_type, s))
58    }
59
60    fn mem_size(&self) -> TDim {
61        self.fact.mem_size()
62    }
63    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
64        other.downcast_ref::<Self>().is_some_and(|o| o == self)
65    }
66    fn compatible_with(&self, other: &dyn OpaqueFact) -> bool {
67        other.is::<Self>()
68    }
69}
70
71impl fmt::Debug for DeviceFact {
72    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
73        match self.origin {
74            DeviceTensorOrigin::FromHost => write!(fmt, "FromHost({:?})", self.without_value()),
75            DeviceTensorOrigin::FromDevice => {
76                write!(fmt, "FromDevice({:?})", self.fact.without_value())
77            }
78        }
79    }
80}
81
/// Extension methods giving access to the [`DeviceFact`] a fact may carry.
pub trait DeviceTypedFactExt {
    /// Returns the attached [`DeviceFact`], or an error when it cannot be retrieved.
    fn to_device_fact(&self) -> TractResult<&DeviceFact>;
    /// Returns the attached [`DeviceFact`], or `None` when there is none.
    fn as_device_fact(&self) -> Option<&DeviceFact>;
}
86
87impl DeviceTypedFactExt for TypedFact {
88    fn to_device_fact(&self) -> TractResult<&DeviceFact> {
89        ensure!(
90            self.datum_type == DatumType::Opaque,
91            "Cannot retrieve DeviceFact from a non Opaque Tensor"
92        );
93        self.opaque_fact
94            .as_ref()
95            .and_then(|m| m.downcast_ref::<DeviceFact>())
96            .ok_or_else(|| anyhow!("DeviceFact not found in Opaque Tensor"))
97    }
98    fn as_device_fact(&self) -> Option<&DeviceFact> {
99        self.opaque_fact.as_ref().and_then(|m| m.downcast_ref::<DeviceFact>())
100    }
101}
102
103impl std::ops::Deref for DeviceFact {
104    type Target = TypedFact;
105    fn deref(&self) -> &Self::Target {
106        &self.fact
107    }
108}
109
110impl std::convert::AsRef<TypedFact> for DeviceFact {
111    fn as_ref(&self) -> &TypedFact {
112        &self.fact
113    }
114}