1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5 any::{Any, TypeId},
6 fmt::Display,
7 hash::{Hash, Hasher},
8};
9use cubecl_common::{
10 format::{DebugRaw, format_str},
11 hash::{StableHash, StableHasher},
12};
13use cubecl_ir::AddressType;
14use derive_more::{Eq, PartialEq};
15
16use crate::server::{CubeDim, ExecutionMode};
17
/// Declares a `Copy` identifier type named `$name` whose values are allocated from a
/// per-type atomic counter, so every value returned by `$name::new()` is distinct.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Identifier allocated from a per-type atomic counter; values produced by `new`
        /// are unique and increase in allocation order.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Allocates a fresh, never-before-returned identifier.
            ///
            /// # Panics
            ///
            /// Panics with "Memory ID overflowed" once the counter is exhausted
            /// (the value `usize::MAX` is never handed out).
            pub fn new() -> Self {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                // `checked_add` inside `fetch_update` refuses to advance past `usize::MAX`,
                // so the counter can never wrap back to 0. A plain `fetch_add` would wrap
                // before the overflow check runs. If that panic were then caught, callers
                // would receive duplicate IDs.
                let value = COUNTER
                    .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                        current.checked_add(1)
                    })
                    .unwrap_or_else(|_| core::panic!("Memory ID overflowed"));
                Self { value }
            }
        }

        impl Default for $name {
            /// Equivalent to [`Self::new`]: allocates a fresh identifier.
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
50
/// Identifies a kernel by its Rust type together with the settings and optional extra
/// information that distinguish variants of that kernel.
///
/// Equality and hashing use `type_id` (not `type_name`); the stable key helpers on this type
/// use `type_name` instead.
#[derive(Clone, PartialEq, Eq)]
pub struct KernelId {
    /// Name of the kernel type (from `core::any::type_name`), used for debug output and the
    /// stable key helpers. Skipped for equality (and for `Hash`) since `type_id` already
    /// identifies the type.
    #[eq(skip)]
    type_name: &'static str,
    /// `TypeId` of the kernel type.
    pub(crate) type_id: core::any::TypeId,
    /// Address type, set via `KernelId::address_type` (defaults to `AddressType::default()`).
    pub(crate) address_type: AddressType,
    /// Cube dimensions, set via `KernelId::cube_dim` (defaults to `CubeDim::new_single()`).
    pub cube_dim: CubeDim,
    /// Execution mode, set via `KernelId::mode` (defaults to `ExecutionMode::Checked`).
    pub(crate) mode: ExecutionMode,
    /// Optional type-erased extra information, set via `KernelId::info`.
    pub(crate) info: Option<Info>,
}
63
64impl Hash for KernelId {
65 fn hash<H: Hasher>(&self, state: &mut H) {
66 self.type_id.hash(state);
67 self.address_type.hash(state);
68 self.cube_dim.hash(state);
69 self.mode.hash(state);
70 self.info.hash(state);
71 }
72}
73
74impl core::fmt::Debug for KernelId {
75 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
76 let mut debug_str = f.debug_struct("KernelId");
77 debug_str
78 .field("type", &DebugRaw(self.type_name))
79 .field("address_type", &self.address_type);
80 debug_str.field("cube_dim", &self.cube_dim);
81 debug_str.field("mode", &self.mode);
82 match &self.info {
83 Some(info) => debug_str.field("info", info),
84 None => debug_str.field("info", &self.info),
85 };
86 debug_str.finish()
87 }
88}
89
90impl Display for KernelId {
91 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
92 match &self.info {
93 Some(info) => f.write_str(
94 format_str(
95 format!("{info:?}").as_str(),
96 &[('(', ')'), ('[', ']'), ('{', '}')],
97 true,
98 )
99 .as_str(),
100 ),
101 None => f.write_str("No info"),
102 }
103 }
104}
105
106impl KernelId {
107 pub fn new<T: 'static>() -> Self {
109 Self {
110 type_id: core::any::TypeId::of::<T>(),
111 type_name: core::any::type_name::<T>(),
112 info: None,
113 cube_dim: CubeDim::new_single(),
114 mode: ExecutionMode::Checked,
115 address_type: Default::default(),
116 }
117 }
118
119 pub fn stable_format(&self) -> String {
123 format!(
124 "{}-{}-{:?}-{:?}-{:?}",
125 self.type_name, self.address_type, self.cube_dim, self.mode, self.info
126 )
127 }
128
129 pub fn stable_hash(&self) -> StableHash {
133 let mut hasher = StableHasher::new();
134 self.type_name.hash(&mut hasher);
135 self.address_type.hash(&mut hasher);
136 self.cube_dim.hash(&mut hasher);
137 self.mode.hash(&mut hasher);
138 self.info.hash(&mut hasher);
139
140 hasher.finalize()
141 }
142
143 pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
148 mut self,
149 info: I,
150 ) -> Self {
151 self.info = Some(Info::new(info));
152 self
153 }
154
155 pub fn mode(&mut self, mode: ExecutionMode) {
157 self.mode = mode;
158 }
159
160 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
162 self.cube_dim = cube_dim;
163 self
164 }
165
166 pub fn address_type(mut self, addr_ty: AddressType) -> Self {
168 self.address_type = addr_ty;
169 self
170 }
171}
172
173impl core::fmt::Debug for Info {
174 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
175 self.value.fmt(f)
176 }
177}
178
179impl Info {
180 fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
181 Self {
182 value: Arc::new(id),
183 }
184 }
185}
186
/// Object-safe view of a kernel info value.
///
/// Lets info values of arbitrary concrete types live behind `Arc<dyn DynKey>` while still
/// supporting equality, hashing, debug formatting and downcasting.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// Returns the `TypeId` of the concrete value.
    fn dyn_type_id(&self) -> TypeId;
    /// Compares against another erased key. Expected to return `false` when `other` has a
    /// different concrete type.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feeds this value's hash into a type-erased hasher.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// Computes a standalone `StableHash` digest of this value alone.
    fn dyn_hash_one(&self) -> StableHash;
    /// Upcasts to `&dyn Any` so the concrete type can be recovered via downcasting.
    fn as_any(&self) -> &dyn Any;
}
199
200impl PartialEq for Info {
201 fn eq(&self, other: &Self) -> bool {
202 self.value.dyn_eq(other.value.as_ref())
203 }
204}
205
/// Type-erased extra information attached to a [`KernelId`] through [`KernelId::info`].
///
/// Cloning only bumps the reference count of the shared value.
#[derive(Clone)]
pub(crate) struct Info {
    /// The erased info value, shared between clones.
    value: Arc<dyn DynKey>,
}
// Sound because `Info::new` only accepts values whose type implements `Eq`.
impl Eq for Info {}
212
213impl Hash for Info {
214 fn hash<H: Hasher>(&self, state: &mut H) {
215 self.value.dyn_type_id().hash(state);
216 self.value.dyn_hash(state)
217 }
218}
219
220impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
221 fn dyn_eq(&self, other: &dyn DynKey) -> bool {
222 if let Some(other) = other.as_any().downcast_ref::<T>() {
223 self == other
224 } else {
225 false
226 }
227 }
228
229 fn dyn_type_id(&self) -> TypeId {
230 TypeId::of::<T>()
231 }
232
233 fn dyn_hash(&self, state: &mut dyn Hasher) {
234 let hash = self.dyn_hash_one();
235 state.write_u128(hash);
236 }
237
238 fn dyn_hash_one(&self) -> StableHash {
239 let mut hasher = StableHasher::new();
240 self.hash(&mut hasher);
241 hasher.finalize()
242 }
243
244 fn as_any(&self) -> &dyn Any {
245 self
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Different info values must produce distinct set entries.
    #[test_log::test]
    pub fn kernel_id_hash() {
        let value_1 = KernelId::new::<()>().info("1");
        let value_2 = KernelId::new::<()>().info("2");

        let mut set = HashSet::new();

        set.insert(value_1.clone());

        assert!(set.contains(&value_1));
        assert!(!set.contains(&value_2));
    }

    /// Independently built ids with equal info are equal and hash identically.
    #[test_log::test]
    pub fn kernel_id_eq_same_info() {
        let a = KernelId::new::<()>().info("1");
        let b = KernelId::new::<()>().info("1");
        assert_eq!(a, b);

        let mut set = HashSet::new();
        set.insert(a);
        assert!(set.contains(&b));
    }

    /// The erased comparison must reject infos of different concrete types, even when
    /// the values look the same.
    #[test_log::test]
    pub fn kernel_id_info_type_matters() {
        let a = KernelId::new::<()>().info(1u32);
        let b = KernelId::new::<()>().info(1u64);
        assert_ne!(a, b);
    }

    /// The kernel type participates in equality.
    #[test_log::test]
    pub fn kernel_id_kernel_type_matters() {
        assert_eq!(KernelId::new::<u8>(), KernelId::new::<u8>());
        assert_ne!(KernelId::new::<u8>(), KernelId::new::<u16>());
    }

    /// `stable_hash` is deterministic for equal keys and sensitive to the info value.
    #[test_log::test]
    pub fn kernel_id_stable_hash() {
        let a = KernelId::new::<()>().info("1");
        let b = KernelId::new::<()>().info("1");
        let c = KernelId::new::<()>().info("2");
        assert_eq!(a.stable_hash(), b.stable_hash());
        assert_ne!(a.stable_hash(), c.stable_hash());
    }
}