// cubecl_runtime/id.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::string::ToString;
4use alloc::sync::Arc;
5use core::{
6    any::{Any, TypeId},
7    fmt::Display,
8    hash::{BuildHasher, Hash, Hasher},
9};
10use cubecl_common::ExecutionMode;
11
/// Create a new storage ID type.
///
/// `storage_id_type!(Name)` expands to a `Copy` struct called `Name` wrapping a
/// `usize`, an associated `Name::new()` constructor and a `Default`
/// implementation that forwards to `new()`.
///
/// Every generated type owns its own process-wide atomic counter, so ids are
/// unique per type and handed out in increasing order starting at zero.
///
/// # Panics
///
/// `Name::new()` panics when the counter reaches `usize::MAX`.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Storage ID.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Create a new ID.
            pub fn new() -> Self {
                // One counter per generated type: the static lives inside this
                // expansion of the macro.
                static NEXT_ID: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                // Path-qualified macro call on purpose: `local_inner_macros`
                // rewrites unqualified macro invocations to `$crate::...`.
                match NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed) {
                    usize::MAX => core::panic!("Memory ID overflowed"),
                    value => Self { value },
                }
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
44
/// Reference to a buffer handle.
///
/// Cloning a handle clones both reference-counted fields, so the strong counts
/// of `id` and `all` grow with every live clone.
#[derive(Debug, Clone)]
pub struct HandleRef<Id> {
    /// Shared identifier; its strong count is what `can_mut` inspects.
    id: Arc<Id>,
    /// Shared unit marker; it is handed over to the binding created from this
    /// handle and its strong count is what `is_free` inspects.
    all: Arc<()>,
}
51
/// Reference to buffer binding.
///
/// Unlike a handle, a binding stores its identifier by value; it only keeps
/// the shared `_all` marker alive.
#[derive(Debug, Clone)]
pub struct BindingRef<Id> {
    /// Identifier of the bound buffer (an owned copy).
    id: Id,
    /// Shared unit marker, held only for its reference count.
    _all: Arc<()>,
}

impl<Id: Clone + core::fmt::Debug> BindingRef<Id> {
    /// The id associated to the buffer.
    pub(crate) fn id(&self) -> &Id {
        let Self { id, .. } = self;
        id
    }
}
68
69impl<Id> HandleRef<Id>
70where
71    Id: Clone + core::fmt::Debug,
72{
73    /// Create a new handle.
74    pub(crate) fn new(id: Id) -> Self {
75        Self {
76            id: Arc::new(id),
77            all: Arc::new(()),
78        }
79    }
80
81    /// The id associated to the handle.
82    pub(crate) fn id(&self) -> &Id {
83        &self.id
84    }
85
86    /// Get the binding.
87    pub(crate) fn binding(self) -> BindingRef<Id> {
88        BindingRef {
89            id: self.id.as_ref().clone(),
90            _all: self.all,
91        }
92    }
93
94    /// If the handle can be mut.
95    pub(crate) fn can_mut(&self) -> bool {
96        // 1 memory management reference with 1 tensor reference.
97        Arc::strong_count(&self.id) <= 2
98    }
99
100    /// If the resource is free.
101    pub(crate) fn is_free(&self) -> bool {
102        Arc::strong_count(&self.all) <= 1
103    }
104}
105
/// Create new memory ID types.
///
/// Two forms are accepted:
///
/// * `memory_id_type!(Id, Handle)` defines the `Id` struct (a `usize`
///   wrapper) and a `Handle` struct wrapping `HandleRef<Id>`, together with
///   `Handle::new()`, `Default` and `Deref<Target = HandleRef<Id>>`.
/// * `memory_id_type!(Id, Handle, Binding)` does the same and additionally
///   defines a `Binding` struct wrapping `BindingRef<Id>`, plus
///   `Handle::binding(self)` to convert a handle into it.
///
/// Each generated handle type draws its ids from its own atomic counter.
///
/// # Panics
///
/// `Handle::new()` panics when the counter reaches `usize::MAX`.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Memory Handle.
        #[derive(Clone, Debug)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Memory ID.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl $handle {
            /// Create a new ID.
            pub(crate) fn new() -> Self {
                let id = $id {
                    value: Self::gen_id(),
                };
                Self {
                    value: $crate::id::HandleRef::new(id),
                }
            }

            /// Reserve the next raw id from the per-type counter.
            fn gen_id() -> usize {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                // Path-qualified macro call on purpose: `local_inner_macros`
                // rewrites unqualified macro invocations to `$crate::...`.
                match COUNTER.fetch_add(1, Ordering::Relaxed) {
                    usize::MAX => core::panic!("Memory ID overflowed"),
                    value => value,
                }
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        // Reuse the two-identifier arm for the id and handle definitions.
        memory_id_type!($id, $handle);

        /// Binding of a memory handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            pub(crate) fn binding(self) -> $binding {
                let value = self.value.binding();
                $binding { value }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };
}
185
186/// Kernel unique identifier.
187#[derive(Clone, Debug)]
188pub struct KernelId {
189    pub(crate) type_id: core::any::TypeId,
190    pub(crate) info: Option<Info>,
191    pub(crate) mode: Option<ExecutionMode>,
192    type_name: &'static str,
193}
194
195impl Hash for KernelId {
196    fn hash<H: Hasher>(&self, state: &mut H) {
197        self.type_id.hash(state);
198        self.info.hash(state);
199        self.mode.hash(state);
200    }
201}
202
203impl PartialEq for KernelId {
204    fn eq(&self, other: &Self) -> bool {
205        self.type_id == other.type_id && self.mode == other.mode && self.info == other.info
206    }
207}
208
209impl Eq for KernelId {}
210
/// Format strings for use in identifiers and types.
///
/// Re-flows a compact, single-line description (for example the `Debug`
/// output of a struct) into an indented, multi-line layout:
///
/// * every space in the input is dropped;
/// * a start marker opens a new nesting level: it is emitted (preceded by a
///   space when `include_space` is set and the previous emitted character is
///   not a space), followed by a newline and four spaces per nesting level;
/// * an end marker closes the current level: a trailing `,` and a newline are
///   emitted before it, unless it immediately follows its start marker, in
///   which case the pair is collapsed (e.g. `()`);
/// * a `,` inside at least one nesting level is followed by a newline and the
///   current indentation;
/// * a `:` is followed by a space when `include_space` is set.
///
/// # Arguments
///
/// * `kernel_id` - the text to format.
/// * `markers` - `(start, end)` character pairs that delimit nesting levels.
/// * `include_space` - whether to insert the cosmetic spaces described above.
///
/// An end marker that has no matching start marker (nesting level zero) is
/// copied to the output like any other character. This happens, for instance,
/// when the text contains a quoted `)`.
pub fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> String {
    let kernel_id = kernel_id.to_string();
    let mut result = String::new();
    let mut depth: usize = 0;
    let indentation = 4;

    // Last significant character handled; a space means "a space was just emitted".
    let mut prev = ' ';

    for c in kernel_id.chars() {
        if c == ' ' {
            continue;
        }

        let mut found_marker = false;

        for &(start, end) in markers {
            if c == start {
                depth += 1;
                if prev != ' ' && include_space {
                    result.push(' ');
                }
                result.push(start);
                result.push('\n');
                result.push_str(&" ".repeat(indentation * depth));
                found_marker = true;
            } else if c == end && depth > 0 {
                // The `depth > 0` guard keeps an unmatched closer from
                // underflowing the nesting counter; such a character falls
                // through and is emitted verbatim below.
                depth -= 1;
                if prev != start {
                    if prev == ' ' {
                        result.pop();
                    }
                    result.push_str(",\n");
                    result.push_str(&" ".repeat(indentation * depth));
                    result.push(end);
                } else {
                    // Empty pair: drop the newline and the indentation that
                    // were emitted right after the start marker.
                    for _ in 0..(indentation * depth + 1 + indentation) {
                        result.pop();
                    }
                    result.push(end);
                }
                found_marker = true;
            }
        }

        if found_marker {
            prev = c;
            continue;
        }

        if c == ',' && depth > 0 {
            if prev == ' ' {
                result.pop();
            }

            result.push_str(",\n");
            result.push_str(&" ".repeat(indentation * depth));
            continue;
        }

        if c == ':' && include_space {
            result.push(c);
            result.push(' ');
            prev = ' ';
        } else {
            result.push(c);
            prev = c;
        }
    }

    result
}
285
286impl Display for KernelId {
287    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
288        match &self.info {
289            Some(info) => f.write_str(
290                format_str(
291                    format!("{info:?}").as_str(),
292                    &[('(', ')'), ('[', ']'), ('{', '}')],
293                    true,
294                )
295                .as_str(),
296            ),
297            None => f.write_str("No info"),
298        }
299    }
300}
301
302impl KernelId {
303    /// Create a new [kernel id](KernelId) for a type.
304    pub fn new<T: 'static>() -> Self {
305        Self {
306            type_id: core::any::TypeId::of::<T>(),
307            type_name: core::any::type_name::<T>(),
308            info: None,
309            mode: None,
310        }
311    }
312
313    /// Render the key in a standard format that can be used between runs.
314    ///
315    /// Can be used as a persistent kernel cache key.
316    pub fn stable_format(&self) -> String {
317        format!("{}-{:?}-{:?}", self.type_name, self.info, self.mode)
318    }
319
320    /// Add information to the [kernel id](KernelId).
321    ///
322    /// The information is used to differentiate kernels of the same kind but with different
323    /// configurations, which affect the generated code.
324    pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
325        mut self,
326        info: I,
327    ) -> Self {
328        self.info = Some(Info::new(info));
329        self
330    }
331
332    /// Set the [execution mode](ExecutionMode).
333    pub fn mode(&mut self, mode: ExecutionMode) {
334        self.mode = Some(mode);
335    }
336}
337
338impl core::fmt::Debug for Info {
339    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
340        f.write_fmt(format_args!("{:?}", self.value))
341    }
342}
343
344impl Info {
345    fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
346        Self {
347            value: Arc::new(id),
348        }
349    }
350}
351
/// This trait allows various types to be used as keys within a single data structure.
///
/// It is an object-safe stand-in for `PartialEq + Hash` on a type-erased value.
///
/// The downside is that the hashing method is hardcoded and cannot be configured through
/// the [Hasher] handed to [core::hash::Hash]: the blanket implementation in this file
/// hashes the value with a fixed-seed `foldhash` state and only writes the resulting
/// `u64` into the provided [Hasher].
trait DynKey: core::fmt::Debug + Send + Sync {
    /// [`TypeId`] of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// Compare with another type-erased key.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feed this key into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// View the key as [`Any`] so that it can be downcast.
    fn as_any(&self) -> &dyn Any;
}
363
364impl PartialEq for Info {
365    fn eq(&self, other: &Self) -> bool {
366        self.value.dyn_eq(other.value.as_ref())
367    }
368}
369
370/// Extra information
371#[derive(Clone)]
372pub(crate) struct Info {
373    value: Arc<dyn DynKey>,
374}
375impl Eq for Info {}
376
377impl Hash for Info {
378    fn hash<H: Hasher>(&self, state: &mut H) {
379        self.value.dyn_type_id().hash(state);
380        self.value.dyn_hash(state)
381    }
382}
383
384impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
385    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
386        if let Some(other) = other.as_any().downcast_ref::<T>() {
387            self == other
388        } else {
389            false
390        }
391    }
392
393    fn dyn_type_id(&self) -> TypeId {
394        TypeId::of::<T>()
395    }
396
397    fn dyn_hash(&self, state: &mut dyn Hasher) {
398        // HashBrown uses foldhash but the default hasher still creates some random state. We need this hash here
399        // to be exactly reproducible.
400        let hash = foldhash::fast::FixedState::with_seed(0).hash_one(self);
401        state.write_u64(hash);
402    }
403
404    fn as_any(&self) -> &dyn Any {
405        self
406    }
407}
408
409/// The device id.
410#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
411pub struct DeviceId {
412    /// The type id identifies the type of the device.
413    pub type_id: u16,
414    /// The index id identifies the device number.
415    pub index_id: u32,
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Kernel ids of the same type but with different info must not collide
    /// when used as hash-set keys.
    #[test]
    pub fn kernel_id_hash() {
        let stored = KernelId::new::<()>().info("1");
        let absent = KernelId::new::<()>().info("2");

        let set: HashSet<KernelId> = core::iter::once(stored.clone()).collect();

        assert!(set.contains(&stored));
        assert!(!set.contains(&absent));
    }
}