Skip to main content

tract_data/
opaque.rs

1#![allow(clippy::derived_hash_with_manual_eq)]
2use crate::blob::Blob;
3use crate::datum::DatumType;
4use crate::dim::TDim;
5use crate::internal::{TVec, Tensor, TractResult};
6use std::fmt::{Debug, Display};
7use std::hash::Hash;
8use std::ops::Deref;
9use std::sync::Arc;
10
11use downcast_rs::{Downcast, impl_downcast};
12use dyn_hash::DynHash;
13
14pub trait OpaquePayload: DynHash + Send + Sync + Debug + Display + Downcast {
15    fn clarify_to_tensor(&self) -> TractResult<Option<Arc<Tensor>>> {
16        Ok(None)
17    }
18
19    fn same_as(&self, other: &dyn OpaquePayload) -> bool;
20}
21impl_downcast!(OpaquePayload);
22dyn_hash::hash_trait_object!(OpaquePayload);
23
24pub trait OpaqueFact: DynHash + Send + Sync + Debug + dyn_clone::DynClone + Downcast {
25    fn same_as(&self, other: &dyn OpaqueFact) -> bool;
26
27    /// Whether or not it is acceptable for a Patch to substitute `self` by `other`.
28    ///
29    /// In other terms, all operators consuming `self` MUST accept also accept `other` without being altered.
30    fn compatible_with(&self, other: &dyn OpaqueFact) -> bool {
31        self.same_as(other)
32    }
33
34    fn clarify_dt_shape(&self) -> Option<(DatumType, TVec<TDim>)> {
35        None
36    }
37
38    fn buffer_sizes(&self) -> TVec<TDim>;
39
40    fn mem_size(&self) -> TDim {
41        self.buffer_sizes().iter().sum::<TDim>()
42    }
43}
44
45impl_downcast!(OpaqueFact);
46dyn_hash::hash_trait_object!(OpaqueFact);
47dyn_clone::clone_trait_object!(OpaqueFact);
48
49impl<T: OpaqueFact> From<T> for Box<dyn OpaqueFact> {
50    fn from(v: T) -> Self {
51        Box::new(v)
52    }
53}
54
55impl PartialEq for Box<dyn OpaqueFact> {
56    fn eq(&self, other: &Self) -> bool {
57        self.as_ref().same_as(other.as_ref())
58    }
59}
60
61impl Eq for Box<dyn OpaqueFact> {}
62
63impl OpaqueFact for TVec<Box<dyn OpaqueFact>> {
64    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
65        other.downcast_ref::<Self>().is_some_and(|o| self == o)
66    }
67
68    fn buffer_sizes(&self) -> TVec<TDim> {
69        self.iter().flat_map(|it| it.buffer_sizes()).collect()
70    }
71}
72impl OpaqueFact for TVec<Option<Box<dyn OpaqueFact>>> {
73    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
74        other.downcast_ref::<Self>().is_some_and(|o| self == o)
75    }
76
77    fn buffer_sizes(&self) -> TVec<TDim> {
78        self.iter().flatten().flat_map(|it| it.buffer_sizes()).collect()
79    }
80}
81
82#[derive(Debug, Hash, PartialEq, Eq)]
83pub struct DummyPayload;
84
85impl OpaquePayload for DummyPayload {
86    fn same_as(&self, other: &dyn OpaquePayload) -> bool {
87        other.downcast_ref::<Self>().is_some()
88    }
89}
90
91impl Display for DummyPayload {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        write!(f, "DummyPayload")
94    }
95}
96
97#[derive(Clone, Debug, Hash)]
98pub struct Opaque(pub Arc<dyn OpaquePayload>);
99
100impl Opaque {
101    pub fn downcast_ref<T: OpaquePayload>(&self) -> Option<&T> {
102        (*self.0).downcast_ref::<T>()
103    }
104
105    pub fn downcast_mut<T: OpaquePayload>(&mut self) -> Option<&mut T> {
106        Arc::get_mut(&mut self.0).and_then(|it| it.downcast_mut::<T>())
107    }
108}
109
110impl Deref for Opaque {
111    type Target = dyn OpaquePayload;
112    fn deref(&self) -> &Self::Target {
113        &*self.0
114    }
115}
116
117impl Display for Opaque {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        write!(f, "{}", self.0)
120    }
121}
122
123impl Default for Opaque {
124    fn default() -> Self {
125        Opaque(Arc::new(DummyPayload))
126    }
127}
128
129impl PartialEq for Opaque {
130    fn eq(&self, other: &Self) -> bool {
131        Arc::ptr_eq(&self.0, &other.0) && self.0.same_as(&*other.0)
132    }
133}
134
135#[derive(Clone, Hash)]
136pub struct BlobWithFact {
137    pub fact: Box<dyn OpaqueFact>,
138    pub value: Arc<Blob>,
139}
140
141impl OpaquePayload for BlobWithFact {
142    fn same_as(&self, other: &dyn OpaquePayload) -> bool {
143        other
144            .downcast_ref::<Self>()
145            .is_some_and(|o| o.fact == self.fact.clone() && o.value == self.value)
146    }
147}
148
149impl std::fmt::Debug for BlobWithFact {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        write!(f, "{:?} {:?}", self.fact, self.value)
152    }
153}
154
155impl std::fmt::Display for BlobWithFact {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        write!(f, "{self:?}")
158    }
159}