1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5 any::{Any, TypeId},
6 fmt::Display,
7 hash::{BuildHasher, Hash, Hasher},
8};
9use cubecl_common::format::{DebugRaw, format_str};
10use cubecl_ir::AddressType;
11use derive_more::{Eq, PartialEq};
12
13use crate::server::{CubeDim, ExecutionMode};
14
/// Declares a small, copyable identifier type named `$name`.
///
/// The generated type wraps a `usize` taken from a counter that is private to
/// that type, so every call to `new` (or `default`) yields a value that has not
/// been handed out before by the same type.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Identifier whose value comes from a per-type atomic counter.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Creates the next identifier of this type.
            ///
            /// # Panics
            ///
            /// Panics with "Memory ID overflowed" when the counter hands out
            /// `usize::MAX`.
            pub fn new() -> Self {
                static COUNTER: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                // `fetch_add` returns the value held *before* the increment.
                // The panic macro stays path-qualified because
                // `local_inner_macros` rewrites unqualified macro calls.
                match COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed) {
                    usize::MAX => core::panic!("Memory ID overflowed"),
                    value => Self { value },
                }
            }
        }

        impl Default for $name {
            /// Equivalent to calling `new`, so it also consumes a counter value.
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
47
/// Reference-counted handle around an identifier of type `Id`.
///
/// Cloning the handle clones both pointers, so their strong counts track how
/// many copies are alive.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HandleRef<Id> {
    /// Shared pointer to the identifier; its strong count is read by `can_mut`.
    id: Arc<Id>,
    /// Payload-free pointer used only for counting; it is moved into the
    /// `BindingRef` produced by `binding` and its strong count is read by
    /// `is_free`.
    all: Arc<()>,
}
54
/// Binding produced from a [`HandleRef`]: an owned copy of the identifier plus
/// the handle's counting pointer.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    /// Identifier value, stored directly rather than behind an `Arc`.
    id: Id,
    /// Never read; holding it keeps the strong count of the originating
    /// handle's `all` pointer raised for as long as the binding lives.
    _all: Arc<()>,
}

impl<Id: Clone + core::fmt::Debug> BindingRef<Id> {
    /// Borrows the identifier carried by this binding.
    pub(crate) fn id(&self) -> &Id {
        let Self { id, .. } = self;
        id
    }
}
71
72impl<Id> HandleRef<Id>
73where
74 Id: Clone + core::fmt::Debug,
75{
76 pub(crate) fn new(id: Id) -> Self {
78 Self {
79 id: Arc::new(id),
80 all: Arc::new(()),
81 }
82 }
83
84 pub(crate) fn id(&self) -> &Id {
86 &self.id
87 }
88
89 pub(crate) fn binding(self) -> BindingRef<Id> {
91 BindingRef {
92 id: self.id.as_ref().clone(),
93 _all: self.all,
94 }
95 }
96
97 pub(crate) fn can_mut(&self) -> bool {
99 Arc::strong_count(&self.id) <= 2
101 }
102
103 pub(crate) fn is_free(&self) -> bool {
105 Arc::strong_count(&self.all) <= 1
106 }
107}
108
/// Declares an identifier type together with a reference-counted handle type
/// and, in the three-argument form, a binding type.
///
/// * `memory_id_type!(Id, Handle)` generates `Id` (a `usize` wrapper) and
///   `Handle`, which wraps `HandleRef<Id>` and dereferences to it.
/// * `memory_id_type!(Id, Handle, Binding)` additionally generates `Binding`,
///   which wraps `BindingRef<Id>`, plus `Handle::binding`.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Handle owning a reference-counted identifier.
        #[derive(Clone, Debug, Eq, PartialEq)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Plain identifier value carried by the handle.
        #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl $handle {
            /// Creates a handle around a newly generated identifier.
            pub(crate) fn new() -> Self {
                let id = $id {
                    value: Self::gen_id(),
                };
                Self {
                    value: $crate::id::HandleRef::new(id),
                }
            }

            /// Takes the next value from the counter private to this handle type.
            ///
            /// Panics with "Memory ID overflowed" when the counter hands out
            /// `usize::MAX`.
            fn gen_id() -> usize {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                // `fetch_add` yields the value held before the increment.
                match COUNTER.fetch_add(1, Ordering::Relaxed) {
                    usize::MAX => core::panic!("Memory ID overflowed"),
                    value => value,
                }
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            /// Exposes the wrapped `HandleRef` so its methods are callable on
            /// the handle directly.
            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl Default for $handle {
            /// Same as `new`: allocates a fresh identifier.
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        // Kept as an unqualified call: `local_inner_macros` resolves it
        // through `$crate`.
        memory_id_type!($id, $handle);

        /// Binding obtained by consuming a handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            /// Consumes the handle and converts it into its binding type.
            pub(crate) fn binding(self) -> $binding {
                let value = self.value.binding();
                $binding { value }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            /// Exposes the wrapped `BindingRef`.
            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };
}
188
189#[derive(Clone, PartialEq, Eq)]
191pub struct KernelId {
192 #[eq(skip)]
193 type_name: &'static str,
194 pub(crate) type_id: core::any::TypeId,
195 pub(crate) address_type: AddressType,
196 pub(crate) cube_dim: Option<CubeDim>,
197 pub(crate) mode: ExecutionMode,
198 pub(crate) info: Option<Info>,
199}
200
201impl Hash for KernelId {
202 fn hash<H: Hasher>(&self, state: &mut H) {
203 self.type_id.hash(state);
204 self.address_type.hash(state);
205 self.cube_dim.hash(state);
206 self.mode.hash(state);
207 self.info.hash(state);
208 }
209}
210
211impl core::fmt::Debug for KernelId {
212 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
213 let mut debug_str = f.debug_struct("KernelId");
214 debug_str
215 .field("type", &DebugRaw(self.type_name))
216 .field("address_type", &self.address_type);
217 match &self.cube_dim {
218 Some(cube_dim) => debug_str.field("cube_dim", cube_dim),
219 None => debug_str.field("cube_dim", &self.cube_dim),
220 };
221 debug_str.field("mode", &self.mode);
222 match &self.info {
223 Some(info) => debug_str.field("info", info),
224 None => debug_str.field("info", &self.info),
225 };
226 debug_str.finish()
227 }
228}
229
230impl Display for KernelId {
231 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
232 match &self.info {
233 Some(info) => f.write_str(
234 format_str(
235 format!("{info:?}").as_str(),
236 &[('(', ')'), ('[', ']'), ('{', '}')],
237 true,
238 )
239 .as_str(),
240 ),
241 None => f.write_str("No info"),
242 }
243 }
244}
245
246impl KernelId {
247 pub fn new<T: 'static>() -> Self {
249 Self {
250 type_id: core::any::TypeId::of::<T>(),
251 type_name: core::any::type_name::<T>(),
252 info: None,
253 cube_dim: None,
254 mode: ExecutionMode::Checked,
255 address_type: Default::default(),
256 }
257 }
258
259 pub fn stable_format(&self) -> String {
263 format!(
264 "{}-{}-{:?}-{:?}-{:?}",
265 self.type_name, self.address_type, self.cube_dim, self.mode, self.info
266 )
267 }
268
269 pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
274 mut self,
275 info: I,
276 ) -> Self {
277 self.info = Some(Info::new(info));
278 self
279 }
280
281 pub fn mode(&mut self, mode: ExecutionMode) {
283 self.mode = mode;
284 }
285
286 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
288 self.cube_dim = Some(cube_dim);
289 self
290 }
291
292 pub fn address_type(mut self, addr_ty: AddressType) -> Self {
294 self.address_type = addr_ty;
295 self
296 }
297}
298
299impl core::fmt::Debug for Info {
300 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
301 self.value.fmt(f)
302 }
303}
304
305impl Info {
306 fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
307 Self {
308 value: Arc::new(id),
309 }
310 }
311}
312
/// Object-safe view of a value that can be compared, hashed and debug-printed
/// without knowing its concrete type.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// `TypeId` identifying the implementor.
    fn dyn_type_id(&self) -> TypeId;

    /// Compares `self` with another type-erased key.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;

    /// Feeds this value into a type-erased hasher.
    fn dyn_hash(&self, state: &mut dyn Hasher);

    /// Upcasts to `Any` so that callers can downcast to a concrete type.
    fn as_any(&self) -> &dyn Any;
}
324
325impl PartialEq for Info {
326 fn eq(&self, other: &Self) -> bool {
327 self.value.dyn_eq(other.value.as_ref())
328 }
329}
330
331#[derive(Clone)]
333pub(crate) struct Info {
334 value: Arc<dyn DynKey>,
335}
336impl Eq for Info {}
337
338impl Hash for Info {
339 fn hash<H: Hasher>(&self, state: &mut H) {
340 self.value.dyn_type_id().hash(state);
341 self.value.dyn_hash(state)
342 }
343}
344
345impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
346 fn dyn_eq(&self, other: &dyn DynKey) -> bool {
347 if let Some(other) = other.as_any().downcast_ref::<T>() {
348 self == other
349 } else {
350 false
351 }
352 }
353
354 fn dyn_type_id(&self) -> TypeId {
355 TypeId::of::<T>()
356 }
357
358 fn dyn_hash(&self, state: &mut dyn Hasher) {
359 let hash = foldhash::fast::FixedState::with_seed(0).hash_one(self);
362 state.write_u64(hash);
363 }
364
365 fn as_any(&self) -> &dyn Any {
366 self
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Two keys for the same type with different info must be told apart by a
    /// hash set.
    #[test_log::test]
    pub fn kernel_id_hash() {
        let value_1 = KernelId::new::<()>().info("1");
        let value_2 = KernelId::new::<()>().info("2");

        let set: HashSet<KernelId> = HashSet::from([value_1.clone()]);

        assert!(set.contains(&value_1));
        assert!(!set.contains(&value_2));
    }
}