1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5 any::{Any, TypeId},
6 fmt::Display,
7 hash::{Hash, Hasher},
8};
9use cubecl_common::{
10 format::{DebugRaw, format_str},
11 hash::{StableHash, StableHasher},
12};
13use cubecl_ir::AddressType;
14use derive_more::{Eq, PartialEq};
15
16use crate::server::{CubeDim, ExecutionMode};
17
/// Declares a copyable, totally ordered identifier type named `$name`.
///
/// Every generated type owns its own process-wide atomic counter; `$name::new()` (and
/// `Default::default()`) hand out the next counter value, so identifiers of one type are
/// distinct until the counter is exhausted.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Identifier allocated from a per-type, process-wide counter.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Allocates the next identifier for this type.
            ///
            /// # Panics
            ///
            /// Panics with `"Memory ID overflowed"` when the counter yields `usize::MAX`.
            pub fn new() -> Self {
                static NEXT_VALUE: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                // Only the uniqueness of the returned value matters here, so a relaxed
                // read-modify-write is enough.
                let value = NEXT_VALUE.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
                core::assert!(value != usize::MAX, "Memory ID overflowed");
                Self { value }
            }
        }

        impl Default for $name {
            /// Same as `new`: every default value is a freshly allocated identifier.
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
50
/// Reference-counted handle to an identifier `Id`.
///
/// Two counters are tracked: `id` is shared only between clones of the handle, while `all`
/// is shared between handle clones and every [`BindingRef`] derived from them.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HandleRef<Id> {
    id: Arc<Id>,
    all: Arc<()>,
}

/// Binding obtained by consuming a [`HandleRef`]: it owns a copy of the identifier and keeps
/// the handle's shared `all` counter alive.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    id: Id,
    _all: Arc<()>,
}

impl<Id> BindingRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// Identifier this binding refers to.
    pub(crate) fn id(&self) -> &Id {
        let Self { id, .. } = self;
        id
    }
}

impl<Id> HandleRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// Wraps `id` in a fresh handle whose two counters both start at one.
    pub(crate) fn new(id: Id) -> Self {
        let id = Arc::new(id);
        let all = Arc::new(());
        Self { id, all }
    }

    /// Identifier this handle refers to.
    pub(crate) fn id(&self) -> &Id {
        self.id.as_ref()
    }

    /// Consumes the handle and turns it into a [`BindingRef`].
    ///
    /// The binding copies the identifier out, which releases this handle's share of the `id`
    /// counter, and takes over this handle's share of the `all` counter.
    pub(crate) fn binding(self) -> BindingRef<Id> {
        let Self { id, all } = self;
        BindingRef {
            id: (*id).clone(),
            _all: all,
        }
    }

    /// Whether at most two handles currently share this identifier.
    // NOTE(review): the threshold of two presumably allows one owner plus one bookkeeping
    // reference held elsewhere — confirm against the callers.
    pub(crate) fn can_mut(&self) -> bool {
        let sharing_handles = Arc::strong_count(&self.id);
        sharing_handles <= 2
    }

    /// Whether this handle is the sole holder of the `all` counter, i.e. no other clone of
    /// the handle and no binding derived from it is still alive.
    pub(crate) fn is_free(&self) -> bool {
        let holders = Arc::strong_count(&self.all);
        holders <= 1
    }
}
111
/// Declares a memory identifier together with its reference-counted handle.
///
/// * `memory_id_type!(Id, Handle)` generates a `Copy` identifier `Id` and a `Handle` that
///   derefs to `HandleRef<Id>` and draws fresh identifier values from a per-type counter.
/// * `memory_id_type!(Id, Handle, Binding)` additionally generates `Binding`, created by
///   consuming a `Handle`, which derefs to `BindingRef<Id>`.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Reference-counted handle wrapping an identifier.
        #[derive(Clone, Debug, PartialEq, Eq)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Plain identifier value shared by a handle and its bindings.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl $handle {
            /// Creates a handle around a freshly allocated identifier.
            pub(crate) fn new() -> Self {
                let fresh = $id {
                    value: Self::gen_id(),
                };
                Self {
                    value: $crate::id::HandleRef::new(fresh),
                }
            }

            /// Returns the next value of this handle type's process-wide counter.
            ///
            /// Panics with `"Memory ID overflowed"` when the counter yields `usize::MAX`.
            fn gen_id() -> usize {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static NEXT_VALUE: AtomicUsize = AtomicUsize::new(0);

                let value = NEXT_VALUE.fetch_add(1, Ordering::Relaxed);
                core::assert!(value != usize::MAX, "Memory ID overflowed");
                value
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl Default for $handle {
            /// Same as `new`: a default handle always wraps a fresh identifier.
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        memory_id_type!($id, $handle);

        /// Binding produced by consuming a handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            /// Consumes the handle and turns it into a binding.
            pub(crate) fn binding(self) -> $binding {
                let value = self.value.binding();
                $binding { value }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };
}
191
/// Identifies a kernel: its Rust type plus the address type, cube dimension, execution mode
/// and an optional type-erased info payload.
///
/// Equality is derived through `derive_more` and skips `type_name`; the manual `Hash` impl
/// likewise ignores `type_name` and relies on `type_id` instead.
#[derive(Clone, PartialEq, Eq)]
pub struct KernelId {
    /// Human-readable name of the kernel type. Used by `Debug`, `stable_format` and
    /// `stable_hash`, but excluded from equality.
    #[eq(skip)]
    type_name: &'static str,
    /// Type identity of the kernel; takes part in equality and `Hash`.
    pub(crate) type_id: core::any::TypeId,
    /// Address type; part of the kernel identity.
    pub(crate) address_type: AddressType,
    /// Cube dimension; part of the kernel identity.
    pub cube_dim: CubeDim,
    /// Execution mode (`KernelId::new` starts from `ExecutionMode::Checked`).
    pub(crate) mode: ExecutionMode,
    /// Optional payload that distinguishes otherwise identical kernel ids.
    pub(crate) info: Option<Info>,
}
204
205impl Hash for KernelId {
206 fn hash<H: Hasher>(&self, state: &mut H) {
207 self.type_id.hash(state);
208 self.address_type.hash(state);
209 self.cube_dim.hash(state);
210 self.mode.hash(state);
211 self.info.hash(state);
212 }
213}
214
215impl core::fmt::Debug for KernelId {
216 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
217 let mut debug_str = f.debug_struct("KernelId");
218 debug_str
219 .field("type", &DebugRaw(self.type_name))
220 .field("address_type", &self.address_type);
221 debug_str.field("cube_dim", &self.cube_dim);
222 debug_str.field("mode", &self.mode);
223 match &self.info {
224 Some(info) => debug_str.field("info", info),
225 None => debug_str.field("info", &self.info),
226 };
227 debug_str.finish()
228 }
229}
230
231impl Display for KernelId {
232 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
233 match &self.info {
234 Some(info) => f.write_str(
235 format_str(
236 format!("{info:?}").as_str(),
237 &[('(', ')'), ('[', ']'), ('{', '}')],
238 true,
239 )
240 .as_str(),
241 ),
242 None => f.write_str("No info"),
243 }
244 }
245}
246
247impl KernelId {
248 pub fn new<T: 'static>() -> Self {
250 Self {
251 type_id: core::any::TypeId::of::<T>(),
252 type_name: core::any::type_name::<T>(),
253 info: None,
254 cube_dim: CubeDim::new_single(),
255 mode: ExecutionMode::Checked,
256 address_type: Default::default(),
257 }
258 }
259
260 pub fn stable_format(&self) -> String {
264 format!(
265 "{}-{}-{:?}-{:?}-{:?}",
266 self.type_name, self.address_type, self.cube_dim, self.mode, self.info
267 )
268 }
269
270 pub fn stable_hash(&self) -> StableHash {
274 let mut hasher = StableHasher::new();
275 self.type_name.hash(&mut hasher);
276 self.address_type.hash(&mut hasher);
277 self.cube_dim.hash(&mut hasher);
278 self.mode.hash(&mut hasher);
279 self.info.hash(&mut hasher);
280
281 hasher.finalize()
282 }
283
284 pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
289 mut self,
290 info: I,
291 ) -> Self {
292 self.info = Some(Info::new(info));
293 self
294 }
295
296 pub fn mode(&mut self, mode: ExecutionMode) {
298 self.mode = mode;
299 }
300
301 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
303 self.cube_dim = cube_dim;
304 self
305 }
306
307 pub fn address_type(mut self, addr_ty: AddressType) -> Self {
309 self.address_type = addr_ty;
310 self
311 }
312}
313
314impl core::fmt::Debug for Info {
315 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
316 self.value.fmt(f)
317 }
318}
319
320impl Info {
321 fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
322 Self {
323 value: Arc::new(id),
324 }
325 }
326}
327
/// Object-safe interface over an `Info` payload. It lets a type-erased value be compared,
/// hashed, printed (through the `Debug` supertrait) and downcast without knowing its concrete
/// type.
///
/// A blanket impl in this module implements it for every
/// `'static + PartialEq + Eq + Hash + Debug + Send + Sync` type.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// `TypeId` of the concrete payload type.
    fn dyn_type_id(&self) -> TypeId;
    /// Whether `other` has the same concrete type as `self` and compares equal to it.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feeds the payload's hash into `state` (the blanket impl writes the `dyn_hash_one`
    /// digest as a `u128`).
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// Hash of the payload value alone, computed with `StableHasher`.
    fn dyn_hash_one(&self) -> StableHash;
    /// Upcast to `Any`, enabling downcasts back to the concrete type.
    fn as_any(&self) -> &dyn Any;
}
340
341impl PartialEq for Info {
342 fn eq(&self, other: &Self) -> bool {
343 self.value.dyn_eq(other.value.as_ref())
344 }
345}
346
/// Type-erased payload attached to a `KernelId` through `KernelId::info`.
///
/// Cloning shares the same payload through the `Arc`.
#[derive(Clone)]
pub(crate) struct Info {
    value: Arc<dyn DynKey>,
}
// Payload types are required to be `Eq` (see `Info::new` and the `DynKey` blanket impl), so the
// type-erased `PartialEq` is a full equivalence relation.
impl Eq for Info {}
353
354impl Hash for Info {
355 fn hash<H: Hasher>(&self, state: &mut H) {
356 self.value.dyn_type_id().hash(state);
357 self.value.dyn_hash(state)
358 }
359}
360
361impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
362 fn dyn_eq(&self, other: &dyn DynKey) -> bool {
363 if let Some(other) = other.as_any().downcast_ref::<T>() {
364 self == other
365 } else {
366 false
367 }
368 }
369
370 fn dyn_type_id(&self) -> TypeId {
371 TypeId::of::<T>()
372 }
373
374 fn dyn_hash(&self, state: &mut dyn Hasher) {
375 let hash = self.dyn_hash_one();
376 state.write_u128(hash);
377 }
378
379 fn dyn_hash_one(&self) -> StableHash {
380 let mut hasher = StableHasher::new();
381 self.hash(&mut hasher);
382 hasher.finalize()
383 }
384
385 fn as_any(&self) -> &dyn Any {
386 self
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Ids of the same kernel type that differ only in their info payload must be distinct
    /// hash-set keys.
    #[test_log::test]
    pub fn kernel_id_hash() {
        let with_one = KernelId::new::<()>().info("1");
        let with_two = KernelId::new::<()>().info("2");

        let set: HashSet<KernelId> = std::iter::once(with_one.clone()).collect();

        assert!(set.contains(&with_one));
        assert!(!set.contains(&with_two));
    }
}
407}