1#![allow(clippy::derived_hash_with_manual_eq)]
2use crate::blob::Blob;
3use crate::datum::DatumType;
4use crate::dim::TDim;
5use crate::internal::{TVec, Tensor, TractResult};
6use std::fmt::{Debug, Display};
7use std::hash::Hash;
8use std::ops::Deref;
9use std::sync::Arc;
10
11use downcast_rs::{Downcast, impl_downcast};
12use dyn_hash::DynHash;
13
14pub trait OpaquePayload: DynHash + Send + Sync + Debug + Display + Downcast {
15 fn clarify_to_tensor(&self) -> TractResult<Option<Arc<Tensor>>> {
16 Ok(None)
17 }
18
19 fn same_as(&self, other: &dyn OpaquePayload) -> bool;
20}
21impl_downcast!(OpaquePayload);
22dyn_hash::hash_trait_object!(OpaquePayload);
23
24pub trait OpaqueFact: DynHash + Send + Sync + Debug + dyn_clone::DynClone + Downcast {
25 fn same_as(&self, other: &dyn OpaqueFact) -> bool;
26
27 fn compatible_with(&self, other: &dyn OpaqueFact) -> bool {
31 self.same_as(other)
32 }
33
34 fn clarify_dt_shape(&self) -> Option<(DatumType, TVec<TDim>)> {
35 None
36 }
37
38 fn buffer_sizes(&self) -> TVec<TDim>;
39
40 fn mem_size(&self) -> TDim {
41 self.buffer_sizes().iter().sum::<TDim>()
42 }
43}
44
45impl_downcast!(OpaqueFact);
46dyn_hash::hash_trait_object!(OpaqueFact);
47dyn_clone::clone_trait_object!(OpaqueFact);
48
49impl<T: OpaqueFact> From<T> for Box<dyn OpaqueFact> {
50 fn from(v: T) -> Self {
51 Box::new(v)
52 }
53}
54
55impl PartialEq for Box<dyn OpaqueFact> {
56 fn eq(&self, other: &Self) -> bool {
57 self.as_ref().same_as(other.as_ref())
58 }
59}
60
61impl Eq for Box<dyn OpaqueFact> {}
62
63impl OpaqueFact for TVec<Box<dyn OpaqueFact>> {
64 fn same_as(&self, other: &dyn OpaqueFact) -> bool {
65 other.downcast_ref::<Self>().is_some_and(|o| self == o)
66 }
67
68 fn buffer_sizes(&self) -> TVec<TDim> {
69 self.iter().flat_map(|it| it.buffer_sizes()).collect()
70 }
71}
72impl OpaqueFact for TVec<Option<Box<dyn OpaqueFact>>> {
73 fn same_as(&self, other: &dyn OpaqueFact) -> bool {
74 other.downcast_ref::<Self>().is_some_and(|o| self == o)
75 }
76
77 fn buffer_sizes(&self) -> TVec<TDim> {
78 self.iter().flatten().flat_map(|it| it.buffer_sizes()).collect()
79 }
80}
81
82#[derive(Debug, Hash, PartialEq, Eq)]
83pub struct DummyPayload;
84
85impl OpaquePayload for DummyPayload {
86 fn same_as(&self, other: &dyn OpaquePayload) -> bool {
87 other.downcast_ref::<Self>().is_some()
88 }
89}
90
91impl Display for DummyPayload {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 write!(f, "DummyPayload")
94 }
95}
96
97#[derive(Clone, Debug, Hash)]
98pub struct Opaque(pub Arc<dyn OpaquePayload>);
99
100impl Opaque {
101 pub fn downcast_ref<T: OpaquePayload>(&self) -> Option<&T> {
102 (*self.0).downcast_ref::<T>()
103 }
104
105 pub fn downcast_mut<T: OpaquePayload>(&mut self) -> Option<&mut T> {
106 Arc::get_mut(&mut self.0).and_then(|it| it.downcast_mut::<T>())
107 }
108}
109
110impl Deref for Opaque {
111 type Target = dyn OpaquePayload;
112 fn deref(&self) -> &Self::Target {
113 &*self.0
114 }
115}
116
117impl Display for Opaque {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 write!(f, "{}", self.0)
120 }
121}
122
123impl Default for Opaque {
124 fn default() -> Self {
125 Opaque(Arc::new(DummyPayload))
126 }
127}
128
129impl PartialEq for Opaque {
130 fn eq(&self, other: &Self) -> bool {
131 Arc::ptr_eq(&self.0, &other.0) && self.0.same_as(&*other.0)
132 }
133}
134
135#[derive(Clone, Hash)]
136pub struct BlobWithFact {
137 pub fact: Box<dyn OpaqueFact>,
138 pub value: Arc<Blob>,
139}
140
141impl OpaquePayload for BlobWithFact {
142 fn same_as(&self, other: &dyn OpaquePayload) -> bool {
143 other
144 .downcast_ref::<Self>()
145 .is_some_and(|o| o.fact == self.fact.clone() && o.value == self.value)
146 }
147}
148
149impl std::fmt::Debug for BlobWithFact {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 write!(f, "{:?} {:?}", self.fact, self.value)
152 }
153}
154
155impl std::fmt::Display for BlobWithFact {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 write!(f, "{self:?}")
158 }
159}