1use std::fmt;
2use tract_core::internal::*;
3
/// Records where the data behind a [`DeviceFact`] comes from.
///
/// It is queried through `DeviceFact::is_from_device` and `DeviceFact::is_from_host`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub enum DeviceTensorOrigin {
    /// The tensor is tagged as originating on the device.
    FromDevice,
    /// The tensor is tagged as originating on the host (see `DeviceFact::from_host`).
    FromHost,
}
16
17#[derive(Clone, PartialEq, Eq, Hash)]
18pub struct DeviceFact {
19 pub origin: DeviceTensorOrigin,
20 pub fact: TypedFact,
21 pub state_owned: bool,
22}
23
24impl DeviceFact {
25 pub fn new(origin: DeviceTensorOrigin, fact: TypedFact) -> TractResult<Self> {
26 ensure!(fact.as_device_fact().is_none());
27 let new_fact = fact.without_value();
28 Ok(Self { origin, fact: new_fact, state_owned: false })
29 }
30
31 pub fn from_host(fact: TypedFact) -> TractResult<Self> {
32 Self::new(DeviceTensorOrigin::FromHost, fact)
33 }
34
35 pub fn is_from_device(&self) -> bool {
36 matches!(self.origin, DeviceTensorOrigin::FromDevice)
37 }
38
39 pub fn is_state_owned(&self) -> bool {
40 self.state_owned
41 }
42
43 pub fn is_from_host(&self) -> bool {
44 matches!(self.origin, DeviceTensorOrigin::FromHost)
45 }
46
47 pub fn into_typed_fact(self) -> TypedFact {
48 self.fact
49 }
50
51 pub fn into_opaque_fact(self) -> TypedFact {
52 TypedFact::dt_scalar(DatumType::Opaque).with_opaque_fact(self)
53 }
54}
55
56impl OpaqueFact for DeviceFact {
57 fn clarify_dt_shape(&self) -> Option<(DatumType, TVec<TDim>)> {
58 Some((self.fact.datum_type, self.fact.shape.to_tvec()))
59 }
60
61 fn buffer_sizes(&self) -> TVec<TDim> {
62 let inner_fact = &self.fact;
63 let mut sizes = tvec!(inner_fact.shape.volume() * inner_fact.datum_type.size_of());
64 if let Some(of) = inner_fact.opaque_fact() {
65 sizes.extend(of.buffer_sizes());
66 }
67 sizes
68 }
69
70 fn same_as(&self, other: &dyn OpaqueFact) -> bool {
71 other.downcast_ref::<Self>().is_some_and(|o| o == self)
72 }
73 fn compatible_with(&self, other: &dyn OpaqueFact) -> bool {
74 other.is::<Self>()
75 }
76}
77
78impl fmt::Debug for DeviceFact {
79 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
80 match self.origin {
81 DeviceTensorOrigin::FromHost => write!(fmt, "FromHost({:?})", self.without_value()),
82 DeviceTensorOrigin::FromDevice => {
83 write!(fmt, "FromDevice({:?})", self.fact.without_value())
84 }
85 }
86 }
87}
88
89pub trait DeviceTypedFactExt {
90 fn to_device_fact(&self) -> TractResult<&DeviceFact>;
91 fn as_device_fact(&self) -> Option<&DeviceFact>;
92 fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact>;
93}
94
95impl DeviceTypedFactExt for TypedFact {
96 fn to_device_fact(&self) -> TractResult<&DeviceFact> {
97 ensure!(
98 self.datum_type == DatumType::Opaque,
99 "Cannot retrieve DeviceFact from a non Opaque Tensor"
100 );
101 self.opaque_fact
102 .as_ref()
103 .and_then(|m| m.downcast_ref::<DeviceFact>())
104 .ok_or_else(|| anyhow!("DeviceFact not found in Opaque Tensor"))
105 }
106 fn as_device_fact(&self) -> Option<&DeviceFact> {
107 self.opaque_fact.as_ref().and_then(|m| m.downcast_ref::<DeviceFact>())
108 }
109 fn as_device_fact_mut(&mut self) -> Option<&mut DeviceFact> {
110 self.opaque_fact.as_mut().and_then(|m| m.downcast_mut::<DeviceFact>())
111 }
112}
113
114impl std::ops::Deref for DeviceFact {
115 type Target = TypedFact;
116 fn deref(&self) -> &Self::Target {
117 &self.fact
118 }
119}
120
121impl std::convert::AsRef<TypedFact> for DeviceFact {
122 fn as_ref(&self) -> &TypedFact {
123 &self.fact
124 }
125}