1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5 any::{Any, TypeId},
6 fmt::Display,
7 hash::{BuildHasher, Hash, Hasher},
8};
9use cubecl_common::format::format_str;
10
11use crate::server::ExecutionMode;
12
/// Declares an opaque id type named `$name`.
///
/// Each generated type owns a private atomic counter. `new()` and `default()` take the
/// next value from that counter, so successive ids are distinct and increasing until
/// the counter wraps.
///
/// # Panics
///
/// The generated `new()` panics with `"Memory ID overflowed"` when the counter yields
/// `usize::MAX`.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Opaque identifier drawn from a per-type atomic counter.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Allocates the next identifier from this type's counter.
            pub fn new() -> Self {
                // One counter per macro expansion; only uniqueness matters, so
                // relaxed ordering is sufficient.
                static NEXT_ID: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                let value = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
                // Path-qualified on purpose: `local_inner_macros` would rewrite a bare
                // `panic!` to `$crate::panic!`.
                if value == usize::MAX {
                    core::panic!("Memory ID overflowed");
                }

                Self { value }
            }
        }

        impl Default for $name {
            /// Same as `new()`: every default value is a fresh id.
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
45
/// Shared handle to an id of type `Id`, carrying two independent reference counters.
///
/// - `id` is shared only among clones of the handle. `HandleRef::binding` copies the id
///   out of it rather than cloning this `Arc`.
/// - `all` is shared among clones of the handle *and* every binding derived from it.
///
/// Derived equality compares the wrapped `Id` values (`Arc` compares contents, and the
/// `()` counters always compare equal).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HandleRef<Id> {
    // Counts handle clones only; read by `can_mut`.
    id: Arc<Id>,
    // Counts handle clones plus bindings; read by `is_free`.
    all: Arc<()>,
}
52
/// Binding produced from a `HandleRef`.
///
/// It owns a plain copy of the id and keeps the handle's `all` counter alive, but it does
/// not hold the handle's `id` counter.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    id: Id,
    // Held only to keep the shared `all` counter elevated; never read.
    _all: Arc<()>,
}
59
60impl<Id> BindingRef<Id>
61where
62 Id: Clone + core::fmt::Debug,
63{
64 pub(crate) fn id(&self) -> &Id {
66 &self.id
67 }
68}
69
70impl<Id> HandleRef<Id>
71where
72 Id: Clone + core::fmt::Debug,
73{
74 pub(crate) fn new(id: Id) -> Self {
76 Self {
77 id: Arc::new(id),
78 all: Arc::new(()),
79 }
80 }
81
82 pub(crate) fn id(&self) -> &Id {
84 &self.id
85 }
86
87 pub(crate) fn binding(self) -> BindingRef<Id> {
89 BindingRef {
90 id: self.id.as_ref().clone(),
91 _all: self.all,
92 }
93 }
94
95 pub(crate) fn can_mut(&self) -> bool {
97 Arc::strong_count(&self.id) <= 2
99 }
100
101 pub(crate) fn is_free(&self) -> bool {
103 Arc::strong_count(&self.all) <= 1
104 }
105}
106
/// Declares a memory id type `$id` and a reference-counted handle `$handle` that wraps
/// `$crate::id::HandleRef<$id>`.
///
/// The three-argument form also declares `$binding`, wrapping `$crate::id::BindingRef<$id>`,
/// and a `$handle::binding` conversion.
///
/// # Panics
///
/// Creating a handle panics with `"Memory ID overflowed"` when the id counter yields
/// `usize::MAX`.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Reference-counted handle to a memory id.
        #[derive(Clone, Debug, PartialEq, Eq)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Plain memory id value.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl $handle {
            /// Creates a handle around a freshly allocated id.
            pub(crate) fn new() -> Self {
                let id = $id {
                    value: Self::gen_id(),
                };
                Self {
                    value: $crate::id::HandleRef::new(id),
                }
            }

            /// Returns the next value of this handle type's private counter.
            fn gen_id() -> usize {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                let next = COUNTER.fetch_add(1, Ordering::Relaxed);
                // Path-qualified because `local_inner_macros` rewrites bare macro names.
                if next == usize::MAX {
                    core::panic!("Memory ID overflowed");
                }
                next
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl Default for $handle {
            /// Same as `new()`: allocates a fresh id.
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        // Resolved as `$crate::memory_id_type!` thanks to `local_inner_macros`.
        memory_id_type!($id, $handle);

        /// Binding derived from a handle; see `$crate::id::BindingRef`.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            /// Consumes the handle and turns it into a binding.
            pub(crate) fn binding(self) -> $binding {
                let value = self.value.binding();
                $binding { value }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };
}
186
/// Identifies a kernel by the Rust type it was created from, plus optional info and
/// execution mode.
///
/// `Hash` and `PartialEq` are implemented by hand over `type_id`, `info` and `mode`;
/// `type_name` takes no part in either.
#[derive(Clone, Debug)]
pub struct KernelId {
    /// `TypeId` of the type passed to [`KernelId::new`].
    pub(crate) type_id: core::any::TypeId,
    /// Optional extra key, set through [`KernelId::info`].
    pub(crate) info: Option<Info>,
    /// Optional execution mode, set through [`KernelId::mode`].
    pub(crate) mode: Option<ExecutionMode>,
    // `core::any::type_name` of the same type. In this file it is only read by
    // `stable_format`.
    type_name: &'static str,
}
195
196impl Hash for KernelId {
197 fn hash<H: Hasher>(&self, state: &mut H) {
198 self.type_id.hash(state);
199 self.info.hash(state);
200 self.mode.hash(state);
201 }
202}
203
204impl PartialEq for KernelId {
205 fn eq(&self, other: &Self) -> bool {
206 self.type_id == other.type_id && self.mode == other.mode && self.info == other.info
207 }
208}
209
210impl Eq for KernelId {}
211
212impl Display for KernelId {
213 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
214 match &self.info {
215 Some(info) => f.write_str(
216 format_str(
217 format!("{info:?}").as_str(),
218 &[('(', ')'), ('[', ']'), ('{', '}')],
219 true,
220 )
221 .as_str(),
222 ),
223 None => f.write_str("No info"),
224 }
225 }
226}
227
228impl KernelId {
229 pub fn new<T: 'static>() -> Self {
231 Self {
232 type_id: core::any::TypeId::of::<T>(),
233 type_name: core::any::type_name::<T>(),
234 info: None,
235 mode: None,
236 }
237 }
238
239 pub fn stable_format(&self) -> String {
243 format!("{}-{:?}-{:?}", self.type_name, self.info, self.mode)
244 }
245
246 pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
251 mut self,
252 info: I,
253 ) -> Self {
254 self.info = Some(Info::new(info));
255 self
256 }
257
258 pub fn mode(&mut self, mode: ExecutionMode) {
260 self.mode = Some(mode);
261 }
262}
263
264impl core::fmt::Debug for Info {
265 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
266 f.write_fmt(format_args!("{:?}", self.value))
267 }
268}
269
270impl Info {
271 fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
272 Self {
273 value: Arc::new(id),
274 }
275 }
276}
277
/// Object-safe stand-in for `PartialEq + Eq + Hash + Any`.
///
/// Lets [`Info`] hold keys of arbitrary concrete types behind `Arc<dyn DynKey>`. A
/// blanket impl in this file covers every `'static + PartialEq + Eq + Hash + Debug +
/// Send + Sync` type.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// `TypeId` of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// Compares against another key; `false` when the concrete types differ.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feeds a hash of `self` into a type-erased hasher.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// Upcasts to `Any` so `dyn_eq` can downcast the other key.
    fn as_any(&self) -> &dyn Any;
}
289
290impl PartialEq for Info {
291 fn eq(&self, other: &Self) -> bool {
292 self.value.dyn_eq(other.value.as_ref())
293 }
294}
295
/// Type-erased, comparable and hashable key attached to a `KernelId`.
///
/// Two `Info` values compare equal only when they wrap the same concrete type and that
/// type's `PartialEq` agrees.
#[derive(Clone)]
pub(crate) struct Info {
    // Shared trait object; cloning an `Info` only bumps the `Arc` count.
    value: Arc<dyn DynKey>,
}
// Sound because `Info::new` only accepts keys that implement `Eq`.
impl Eq for Info {}
302
303impl Hash for Info {
304 fn hash<H: Hasher>(&self, state: &mut H) {
305 self.value.dyn_type_id().hash(state);
306 self.value.dyn_hash(state)
307 }
308}
309
310impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
311 fn dyn_eq(&self, other: &dyn DynKey) -> bool {
312 if let Some(other) = other.as_any().downcast_ref::<T>() {
313 self == other
314 } else {
315 false
316 }
317 }
318
319 fn dyn_type_id(&self) -> TypeId {
320 TypeId::of::<T>()
321 }
322
323 fn dyn_hash(&self, state: &mut dyn Hasher) {
324 let hash = foldhash::fast::FixedState::with_seed(0).hash_one(self);
327 state.write_u64(hash);
328 }
329
330 fn as_any(&self) -> &dyn Any {
331 self
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Ids that differ only in their info occupy different hash-set entries. An
    /// independently built id with identical contents must still be found.
    #[test]
    pub fn kernel_id_hash() {
        let value_1 = KernelId::new::<()>().info("1");
        let value_2 = KernelId::new::<()>().info("2");

        let mut set = HashSet::new();

        set.insert(value_1.clone());

        assert!(set.contains(&value_1));
        assert!(!set.contains(&value_2));
        // Lookup must depend on the wrapped value, not on `Arc` identity.
        assert!(set.contains(&KernelId::new::<()>().info("1")));
    }

    /// Equality considers the kernel type, the info value and the info's concrete type.
    #[test]
    fn kernel_id_eq_distinguishes_type_and_info() {
        let base = KernelId::new::<()>().info("1");

        assert_eq!(base, KernelId::new::<()>().info("1"));
        // Different kernel type, same info.
        assert_ne!(base, KernelId::new::<u32>().info("1"));
        // Same kernel type, info of a different concrete type.
        assert_ne!(base, KernelId::new::<()>().info(1u32));
        // Missing info never equals present info.
        assert_ne!(base, KernelId::new::<()>());
    }
}