1use alloc::format;
2use alloc::string::String;
3use alloc::string::ToString;
4use alloc::sync::Arc;
5use core::{
6 any::{Any, TypeId},
7 fmt::Display,
8 hash::{BuildHasher, Hash, Hasher},
9};
10use cubecl_common::ExecutionMode;
11
/// Declares a small `Copy` identifier type named `$name`.
///
/// Each call to the generated `new()` takes the next value of a counter that
/// is private to that generated function, so ids of one type are handed out
/// in increasing order, starting at zero.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Identifier wrapping a value taken from a per-type atomic counter.
        #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Create a new id from the next counter value.
            ///
            /// # Panics
            ///
            /// Panics when the counter value read is `usize::MAX`.
            pub fn new() -> Self {
                static COUNTER: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                // `core::panic!` stays path-qualified: `local_inner_macros`
                // would otherwise rewrite a bare macro call to `$crate::...`.
                match COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed) {
                    usize::MAX => core::panic!("Memory ID overflowed"),
                    value => Self { value },
                }
            }
        }

        impl Default for $name {
            /// Same as `new()`: every default value is a fresh id.
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
44
/// Reference-counted reference to a handle id.
///
/// Cloning a `HandleRef` bumps two independent counters: the one on `id`
/// (read by [`HandleRef::can_mut`]) and the one on `all` (read by
/// [`HandleRef::is_free`]).
#[derive(Clone, Debug)]
pub struct HandleRef<Id> {
    // Shared id; only handles hold this `Arc`.
    id: Arc<Id>,
    // Token shared by handles and the bindings created from them.
    all: Arc<()>,
}

/// Reference produced by [`HandleRef::binding`].
///
/// It owns a plain copy of the id and keeps the `all` token of the handle it
/// was created from alive.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    id: Id,
    _all: Arc<()>,
}

impl<Id> BindingRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// The id carried by this binding.
    pub(crate) fn id(&self) -> &Id {
        let Self { id, .. } = self;
        id
    }
}

impl<Id> HandleRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// Wrap `id` in a fresh handle with its own pair of counters.
    pub(crate) fn new(id: Id) -> Self {
        let id = Arc::new(id);
        let all = Arc::new(());

        Self { id, all }
    }

    /// The id carried by this handle.
    pub(crate) fn id(&self) -> &Id {
        &*self.id
    }

    /// Consume the handle and turn it into a binding.
    ///
    /// The id is cloned out of its `Arc`; the handle's own reference to the
    /// shared id is released when this call returns, while the `all` token is
    /// moved into the binding.
    pub(crate) fn binding(self) -> BindingRef<Id> {
        let Self { id, all } = self;

        BindingRef {
            id: Id::clone(&id),
            _all: all,
        }
    }

    /// Whether at most two strong references to the shared id exist.
    ///
    /// NOTE(review): the threshold presumably allows one reference held by
    /// memory management plus one held by the user; confirm with callers.
    pub(crate) fn can_mut(&self) -> bool {
        const MAX_ID_REFS: usize = 2;

        Arc::strong_count(&self.id) <= MAX_ID_REFS
    }

    /// Whether this handle holds the only remaining reference to the `all`
    /// token, i.e. no other handle clone or binding is alive.
    pub(crate) fn is_free(&self) -> bool {
        const MAX_TOKEN_REFS: usize = 1;

        Arc::strong_count(&self.all) <= MAX_TOKEN_REFS
    }
}
105
/// Declares an id type together with its reference-counted handle type and,
/// with the three-argument form, a matching binding type.
///
/// * `memory_id_type!(Id, Handle)` generates `Id` and `Handle`.
/// * `memory_id_type!(Id, Handle, Binding)` additionally generates `Binding`
///   and `Handle::binding`.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Plain id value wrapped by the handle type.
        #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
        pub struct $id {
            pub(crate) value: usize,
        }

        /// Handle wrapping a `HandleRef` over the id type.
        #[derive(Clone, Debug)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        impl $handle {
            /// Create a handle around a freshly generated id.
            pub(crate) fn new() -> Self {
                let id = $id {
                    value: Self::gen_id(),
                };

                Self {
                    value: $crate::id::HandleRef::new(id),
                }
            }

            /// Take the next value of the counter private to this handle type.
            ///
            /// Panics when the counter value read is `usize::MAX`.
            fn gen_id() -> usize {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                // `core::panic!` stays path-qualified: `local_inner_macros`
                // would otherwise rewrite a bare macro call to `$crate::...`.
                match COUNTER.fetch_add(1, Ordering::Relaxed) {
                    usize::MAX => core::panic!("Memory ID overflowed"),
                    value => value,
                }
            }
        }

        impl Default for $handle {
            /// Same as `new()`: every default handle gets a fresh id.
            fn default() -> Self {
                Self::new()
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        // Generate the id and handle first, then layer the binding on top.
        memory_id_type!($id, $handle);

        /// Binding wrapping a `BindingRef` over the id type.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl $handle {
            /// Consume the handle and convert it into its binding type.
            pub(crate) fn binding(self) -> $binding {
                let value = self.value.binding();

                $binding { value }
            }
        }
    };
}
185
/// Identifier of a kernel: a Rust type plus optional extra info and an
/// optional execution mode.
///
/// The `Hash` and `PartialEq` implementations in this file only look at
/// `type_id`, `info` and `mode`; `type_name` is used for
/// [`KernelId::stable_format`] only.
#[derive(Clone, Debug)]
pub struct KernelId {
    /// `TypeId` of the type given to [`KernelId::new`].
    pub(crate) type_id: core::any::TypeId,
    /// Extra key data set through [`KernelId::info`], if any.
    pub(crate) info: Option<Info>,
    /// Execution mode set through [`KernelId::mode`], if any.
    pub(crate) mode: Option<ExecutionMode>,
    /// `core::any::type_name` of the type given to [`KernelId::new`].
    type_name: &'static str,
}
194
195impl Hash for KernelId {
196 fn hash<H: Hasher>(&self, state: &mut H) {
197 self.type_id.hash(state);
198 self.info.hash(state);
199 self.mode.hash(state);
200 }
201}
202
203impl PartialEq for KernelId {
204 fn eq(&self, other: &Self) -> bool {
205 self.type_id == other.type_id && self.mode == other.mode && self.info == other.info
206 }
207}
208
209impl Eq for KernelId {}
210
/// Pretty-prints a compact, `Debug`-like string by breaking it onto indented
/// lines at every marker pair.
///
/// * `kernel_id` - text to format; every space character in it is dropped
///   before formatting.
/// * `markers` - `(open, close)` character pairs that start and end a nesting
///   level. Each level is indented by four more spaces.
/// * `include_space` - when `true`, a space is inserted before an opening
///   marker that follows other text, and after every `:`.
///
/// Inside a nesting level every `,` starts a new line, and a trailing `,` is
/// added before a closing marker unless the group is empty (`()` stays `()`).
///
/// Unbalanced closing markers are tolerated: the nesting depth never drops
/// below zero, so such input is formatted at depth zero instead of panicking.
pub fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> String {
    const INDENTATION: usize = 4;

    let mut result = String::new();
    let mut depth: usize = 0;

    // Last significant character handled. `' '` doubles as "nothing yet" and
    // as "a space was just emitted after a colon".
    let mut prev = ' ';

    for c in kernel_id.chars() {
        if c == ' ' {
            continue;
        }

        let mut found_marker = false;

        for &(start, end) in markers {
            if c == start {
                depth += 1;
                if prev != ' ' && include_space {
                    result.push(' ');
                }
                result.push(start);
                result.push('\n');
                result.push_str(&" ".repeat(INDENTATION * depth));
                found_marker = true;
            } else if c == end {
                // A stray closing marker at depth zero must not underflow.
                depth = depth.saturating_sub(1);
                if prev != start {
                    if prev == ' ' {
                        result.pop();
                    }
                    result.push_str(",\n");
                    result.push_str(&" ".repeat(INDENTATION * depth));
                    result.push(end);
                } else {
                    // Empty group: undo the newline and indentation pushed
                    // right after the opening marker.
                    for _ in 0..INDENTATION * depth + 1 + INDENTATION {
                        result.pop();
                    }
                    result.push(end);
                }
                found_marker = true;
            }
        }

        if found_marker {
            prev = c;
            continue;
        }

        if c == ',' && depth > 0 {
            if prev == ' ' {
                result.pop();
            }

            result.push_str(",\n");
            result.push_str(&" ".repeat(INDENTATION * depth));
            continue;
        }

        if c == ':' && include_space {
            result.push(c);
            result.push(' ');
            prev = ' ';
        } else {
            result.push(c);
            prev = c;
        }
    }

    result
}
285
286impl Display for KernelId {
287 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
288 match &self.info {
289 Some(info) => f.write_str(
290 format_str(
291 format!("{info:?}").as_str(),
292 &[('(', ')'), ('[', ']'), ('{', '}')],
293 true,
294 )
295 .as_str(),
296 ),
297 None => f.write_str("No info"),
298 }
299 }
300}
301
302impl KernelId {
303 pub fn new<T: 'static>() -> Self {
305 Self {
306 type_id: core::any::TypeId::of::<T>(),
307 type_name: core::any::type_name::<T>(),
308 info: None,
309 mode: None,
310 }
311 }
312
313 pub fn stable_format(&self) -> String {
317 format!("{}-{:?}-{:?}", self.type_name, self.info, self.mode)
318 }
319
320 pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
325 mut self,
326 info: I,
327 ) -> Self {
328 self.info = Some(Info::new(info));
329 self
330 }
331
332 pub fn mode(&mut self, mode: ExecutionMode) {
334 self.mode = Some(mode);
335 }
336}
337
338impl core::fmt::Debug for Info {
339 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
340 f.write_fmt(format_args!("{:?}", self.value))
341 }
342}
343
344impl Info {
345 fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
346 Self {
347 value: Arc::new(id),
348 }
349 }
350}
351
/// Object-safe view of a key value.
///
/// It lets values of different concrete types live behind `dyn DynKey` (see
/// `Info`) while still supporting equality, hashing and `Debug` printing.
/// The blanket implementation in this file covers every
/// `'static + Eq + Hash + Debug + Send + Sync` type.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// `TypeId` of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// Whether `other` is equal to `self`; the blanket implementation treats
    /// values of different concrete types as unequal.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feed a hash of the value into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// View the value as `Any`, so it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
}
363
364impl PartialEq for Info {
365 fn eq(&self, other: &Self) -> bool {
366 self.value.dyn_eq(other.value.as_ref())
367 }
368}
369
/// Type-erased extra information attached to a `KernelId`.
///
/// The value sits behind an `Arc`, so cloning an `Info` clones the pointer,
/// not the value. Equality, hashing and `Debug` are forwarded to the wrapped
/// value through `DynKey`.
#[derive(Clone)]
pub(crate) struct Info {
    // The wrapped key value, stored as a trait object.
    value: Arc<dyn DynKey>,
}
impl Eq for Info {}
376
377impl Hash for Info {
378 fn hash<H: Hasher>(&self, state: &mut H) {
379 self.value.dyn_type_id().hash(state);
380 self.value.dyn_hash(state)
381 }
382}
383
384impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
385 fn dyn_eq(&self, other: &dyn DynKey) -> bool {
386 if let Some(other) = other.as_any().downcast_ref::<T>() {
387 self == other
388 } else {
389 false
390 }
391 }
392
393 fn dyn_type_id(&self) -> TypeId {
394 TypeId::of::<T>()
395 }
396
397 fn dyn_hash(&self, state: &mut dyn Hasher) {
398 let hash = foldhash::fast::FixedState::with_seed(0).hash_one(self);
401 state.write_u64(hash);
402 }
403
404 fn as_any(&self) -> &dyn Any {
405 self
406 }
407}
408
/// Identifier of a device, made of a type id and an index id.
///
/// NOTE(review): the `new` derive is not defined in this file; it presumably
/// comes from the `derive-new` crate and generates
/// `DeviceId::new(type_id, index_id)`. Confirm against the crate root.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
pub struct DeviceId {
    /// Presumably identifies the kind of device (TODO confirm).
    pub type_id: u16,
    /// Presumably identifies the device among those of the same kind (TODO confirm).
    pub index_id: u32,
}
417
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Ids that differ only by their info must be told apart by a hash set.
    #[test]
    pub fn kernel_id_hash() {
        let value_1 = KernelId::new::<()>().info("1");
        let value_2 = KernelId::new::<()>().info("2");

        let set: HashSet<KernelId> = HashSet::from([value_1.clone()]);

        assert!(set.contains(&value_1));
        assert!(!set.contains(&value_2));
    }
}