1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5 any::{Any, TypeId},
6 fmt::Display,
7 hash::{BuildHasher, Hash, Hasher},
8};
9use cubecl_common::ExecutionMode;
10use cubecl_common::format::format_str;
11
/// Declares a copyable, totally ordered identifier type backed by an atomic counter.
///
/// `storage_id_type!(Name)` generates a `Name` struct wrapping a `usize`. Every call to
/// `Name::new()` (or `Name::default()`) takes the next value of a counter that is private to
/// that generated type, so identifiers created this way are never handed out twice.
///
/// # Panics
///
/// `Name::new()` panics with `"Memory ID overflowed"` once all `usize` values are used.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Unique identifier drawn from a per-type atomic counter.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Creates an identifier distinct from every one previously created for this type.
            ///
            /// # Panics
            ///
            /// Panics with `"Memory ID overflowed"` when the counter is exhausted.
            pub fn new() -> Self {
                use core::sync::atomic::{AtomicUsize, Ordering};

                // One counter per generated type: the static lives inside this type's `new`.
                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                // `checked_add` inside `fetch_update` never lets the counter wrap. Once it
                // reaches `usize::MAX`, every later call panics. A plain `fetch_add` would
                // panic only once and then silently reissue 0, 1, ... as duplicate ids.
                let value = COUNTER
                    .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                        current.checked_add(1)
                    })
                    .unwrap_or_else(|_| core::panic!("Memory ID overflowed"));

                Self { value }
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
44
/// Shared handle to an id, carrying two reference counts that track its usage.
///
/// Cloning a `HandleRef` clones both `Arc`s. `HandleRef::binding` moves the `all` share into a
/// `BindingRef` and releases this handle's `id` share. Within this file, therefore:
/// - the `id` strong count is the number of live `HandleRef`s for the id (read by `can_mut`);
/// - the `all` strong count is the number of live `HandleRef`s plus `BindingRef`s
///   (read by `is_free`).
///
/// The derived `PartialEq` compares the ids by value; the `Arc<()>` field always compares equal.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HandleRef<Id> {
    id: Arc<Id>,
    all: Arc<()>,
}
51
/// Detached view of a `HandleRef`, produced by `HandleRef::binding`.
///
/// It owns its own copy of the id value and a share of the originating handle's `all` counter.
/// While it (or any clone of it) is alive, the handle is not reported as free. It does not hold
/// the `id` counter, so it does not affect `HandleRef::can_mut`.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    id: Id,
    // Held only to keep the shared `all` count raised; never read.
    _all: Arc<()>,
}
58
59impl<Id> BindingRef<Id>
60where
61 Id: Clone + core::fmt::Debug,
62{
63 pub(crate) fn id(&self) -> &Id {
65 &self.id
66 }
67}
68
69impl<Id> HandleRef<Id>
70where
71 Id: Clone + core::fmt::Debug,
72{
73 pub(crate) fn new(id: Id) -> Self {
75 Self {
76 id: Arc::new(id),
77 all: Arc::new(()),
78 }
79 }
80
81 pub(crate) fn id(&self) -> &Id {
83 &self.id
84 }
85
86 pub(crate) fn binding(self) -> BindingRef<Id> {
88 BindingRef {
89 id: self.id.as_ref().clone(),
90 _all: self.all,
91 }
92 }
93
94 pub(crate) fn can_mut(&self) -> bool {
96 Arc::strong_count(&self.id) <= 2
98 }
99
100 pub(crate) fn is_free(&self) -> bool {
102 Arc::strong_count(&self.all) <= 1
103 }
104}
105
/// Declares a memory id type, a reference-counted handle type and, optionally, a binding type.
///
/// * `memory_id_type!(Id, Handle)` generates `Id` (a copyable `usize` wrapper) and `Handle`.
///   `Handle` wraps a `$crate::id::HandleRef<Id>` and dereferences to it.
/// * `memory_id_type!(Id, Handle, Binding)` additionally generates `Binding` and
///   `Handle::binding`. `Binding` wraps a `$crate::id::BindingRef<Id>` and dereferences to it.
///
/// # Panics
///
/// `Handle::new()` / `Handle::default()` panic with `"Memory ID overflowed"` once the
/// handle type's id counter is exhausted.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Reference-counted handle to a unique memory id.
        #[derive(Clone, Debug, PartialEq, Eq)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Memory identifier; values come from the handle type's id counter.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl $handle {
            /// Creates a handle holding a freshly generated id.
            pub(crate) fn new() -> Self {
                let value = Self::gen_id();
                Self {
                    value: $crate::id::HandleRef::new($id { value }),
                }
            }

            /// Returns the next value of this handle type's id counter.
            ///
            /// # Panics
            ///
            /// Panics with `"Memory ID overflowed"` when the counter is exhausted.
            fn gen_id() -> usize {
                static COUNTER: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                // `checked_add` inside `fetch_update` keeps the counter from wrapping. After
                // exhaustion, every call panics instead of reissuing ids starting from 0,
                // which a plain `fetch_add` would do after its single panicking call.
                COUNTER
                    .fetch_update(
                        core::sync::atomic::Ordering::Relaxed,
                        core::sync::atomic::Ordering::Relaxed,
                        |current| current.checked_add(1),
                    )
                    .unwrap_or_else(|_| core::panic!("Memory ID overflowed"))
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        memory_id_type!($id, $handle);

        /// Binding produced by consuming a handle; keeps the handle's `all` counter raised.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            /// Consumes the handle and returns its binding.
            pub(crate) fn binding(self) -> $binding {
                $binding {
                    value: self.value.binding(),
                }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };
}
185
/// Identifier for a kernel, built from the kernel's Rust type plus an optional info key and an
/// optional execution mode.
///
/// In this file, `Hash` and `PartialEq` consider only `type_id`, `info` and `mode`.
/// `type_name` is read by `stable_format` (and appears in the derived `Debug` output).
#[derive(Clone, Debug)]
pub struct KernelId {
    // `TypeId` of the type parameter given to `KernelId::new`.
    pub(crate) type_id: core::any::TypeId,
    // Optional key attached with `KernelId::info`.
    pub(crate) info: Option<Info>,
    // Optional execution mode set with `KernelId::mode`.
    pub(crate) mode: Option<ExecutionMode>,
    // `core::any::type_name` of the same type parameter as `type_id`.
    type_name: &'static str,
}
194
195impl Hash for KernelId {
196 fn hash<H: Hasher>(&self, state: &mut H) {
197 self.type_id.hash(state);
198 self.info.hash(state);
199 self.mode.hash(state);
200 }
201}
202
203impl PartialEq for KernelId {
204 fn eq(&self, other: &Self) -> bool {
205 self.type_id == other.type_id && self.mode == other.mode && self.info == other.info
206 }
207}
208
209impl Eq for KernelId {}
210
211impl Display for KernelId {
212 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
213 match &self.info {
214 Some(info) => f.write_str(
215 format_str(
216 format!("{info:?}").as_str(),
217 &[('(', ')'), ('[', ']'), ('{', '}')],
218 true,
219 )
220 .as_str(),
221 ),
222 None => f.write_str("No info"),
223 }
224 }
225}
226
227impl KernelId {
228 pub fn new<T: 'static>() -> Self {
230 Self {
231 type_id: core::any::TypeId::of::<T>(),
232 type_name: core::any::type_name::<T>(),
233 info: None,
234 mode: None,
235 }
236 }
237
238 pub fn stable_format(&self) -> String {
242 format!("{}-{:?}-{:?}", self.type_name, self.info, self.mode)
243 }
244
245 pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
250 mut self,
251 info: I,
252 ) -> Self {
253 self.info = Some(Info::new(info));
254 self
255 }
256
257 pub fn mode(&mut self, mode: ExecutionMode) {
259 self.mode = Some(mode);
260 }
261}
262
263impl core::fmt::Debug for Info {
264 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
265 f.write_fmt(format_args!("{:?}", self.value))
266 }
267}
268
269impl Info {
270 fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
271 Self {
272 value: Arc::new(id),
273 }
274 }
275}
276
/// Object-safe stand-in for `PartialEq + Eq + Hash + Debug`, so keys of arbitrary concrete
/// types can be stored behind `Arc<dyn DynKey>` (as `Info` does).
///
/// Every `'static + PartialEq + Eq + Hash + Debug + Send + Sync` type implements it through the
/// blanket impl in this file.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// `TypeId` of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// Equality against another trait object. In the blanket impl, this is `false` whenever the
    /// concrete types differ.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feeds a hash of the concrete value into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// Upcasts to `Any` so `dyn_eq` can downcast the other operand.
    fn as_any(&self) -> &dyn Any;
}
288
289impl PartialEq for Info {
290 fn eq(&self, other: &Self) -> bool {
291 self.value.dyn_eq(other.value.as_ref())
292 }
293}
294
/// Type-erased key attached to a `KernelId`. Cloning only bumps the `Arc` reference count.
///
/// Equality and hashing are delegated to the wrapped value through `DynKey`.
#[derive(Clone)]
pub(crate) struct Info {
    value: Arc<dyn DynKey>,
}
// `PartialEq` defers to the wrapped type's own equality (a type required to be `Eq` by the
// `DynKey` blanket impl), so `Info`'s equality is treated as a full equivalence.
impl Eq for Info {}
301
302impl Hash for Info {
303 fn hash<H: Hasher>(&self, state: &mut H) {
304 self.value.dyn_type_id().hash(state);
305 self.value.dyn_hash(state)
306 }
307}
308
309impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
310 fn dyn_eq(&self, other: &dyn DynKey) -> bool {
311 if let Some(other) = other.as_any().downcast_ref::<T>() {
312 self == other
313 } else {
314 false
315 }
316 }
317
318 fn dyn_type_id(&self) -> TypeId {
319 TypeId::of::<T>()
320 }
321
322 fn dyn_hash(&self, state: &mut dyn Hasher) {
323 let hash = foldhash::fast::FixedState::with_seed(0).hash_one(self);
326 state.write_u64(hash);
327 }
328
329 fn as_any(&self) -> &dyn Any {
330 self
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// A clone of an inserted id must be found. An id for the same type but a different
    /// `info` must not be.
    #[test]
    pub fn kernel_id_hash() {
        let inserted = KernelId::new::<()>().info("1");
        let absent = KernelId::new::<()>().info("2");

        let set: HashSet<KernelId> = core::iter::once(inserted.clone()).collect();

        assert!(set.contains(&inserted));
        assert!(!set.contains(&absent));
    }
}