cubecl_runtime/
id.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5    any::{Any, TypeId},
6    fmt::Display,
7    hash::{BuildHasher, Hash, Hasher},
8};
9use cubecl_common::ExecutionMode;
10use cubecl_common::format::format_str;
11
#[macro_export(local_inner_macros)]
/// Create a new storage ID type.
///
/// `storage_id_type!(Name)` declares a `Copy`, totally ordered `Name` wrapping a `usize`.
/// Each call to `Name::new()` (or `Name::default()`) takes the next value of an atomic
/// counter that is private to `Name`.
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Storage ID.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Create a new ID.
            ///
            /// # Panics
            ///
            /// When the value that would be handed out is `usize::MAX`.
            pub fn new() -> Self {
                static COUNTER: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                let value = COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
                core::assert!(value != usize::MAX, "Memory ID overflowed");
                Self { value }
            }
        }

        impl Default for $name {
            /// Same as [`Self::new`]: every default value is a fresh ID.
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
44
/// Reference to a buffer handle.
///
/// Tracks two reference counts:
/// * `id` is shared only between clones of the handle itself (see [`HandleRef::can_mut`]);
/// * `all` is shared between handle clones *and* every [`BindingRef`] derived from them
///   (see [`HandleRef::is_free`]).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HandleRef<Id> {
    id: Arc<Id>,
    all: Arc<()>,
}

/// Reference to buffer binding.
///
/// Owns a copy of the id and keeps the originating handle's `all` counter alive, so the
/// resource is not reported as free while any binding exists.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    id: Id,
    _all: Arc<()>,
}

impl<Id> BindingRef<Id> {
    /// The id associated to the buffer.
    pub(crate) fn id(&self) -> &Id {
        &self.id
    }
}

impl<Id> HandleRef<Id> {
    /// Create a new handle.
    pub(crate) fn new(id: Id) -> Self {
        Self {
            id: Arc::new(id),
            all: Arc::new(()),
        }
    }

    /// The id associated to the handle.
    pub(crate) fn id(&self) -> &Id {
        &self.id
    }

    /// If the handle can be mut.
    ///
    /// Only handle clones are counted here; bindings do not hold the `id` counter.
    pub(crate) fn can_mut(&self) -> bool {
        // 1 memory management reference with 1 tensor reference.
        Arc::strong_count(&self.id) <= 2
    }

    /// If the resource is free, i.e. no other handle clone or binding shares `all`.
    pub(crate) fn is_free(&self) -> bool {
        Arc::strong_count(&self.all) <= 1
    }
}

impl<Id: Clone> HandleRef<Id> {
    /// Get the binding.
    ///
    /// Consumes this handle: its `all` counter moves into the binding, while its `id`
    /// counter is released when the handle is dropped at the end of this call.
    pub(crate) fn binding(self) -> BindingRef<Id> {
        BindingRef {
            id: self.id.as_ref().clone(),
            _all: self.all,
        }
    }
}
105
#[macro_export(local_inner_macros)]
/// Create new memory ID types.
///
/// * `memory_id_type!(Id, Handle)` declares a `Copy` id type `Id` wrapping a `usize`, plus a
///   `Handle` that wraps a `HandleRef<Id>` and derefs to it.
/// * `memory_id_type!(Id, Handle, Binding)` additionally declares `Binding`, obtained by
///   consuming a `Handle`, which derefs to a `BindingRef<Id>`.
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Memory Handle.
        #[derive(Clone, Debug, PartialEq, Eq)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Memory ID.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl $handle {
            /// Create a new handle whose ID is drawn from this handle type's counter.
            pub(crate) fn new() -> Self {
                let id = $id {
                    value: Self::gen_id(),
                };
                Self {
                    value: $crate::id::HandleRef::new(id),
                }
            }

            /// Take the next value of the atomic counter private to this handle type.
            fn gen_id() -> usize {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                let next = COUNTER.fetch_add(1, Ordering::Relaxed);
                core::assert!(next != usize::MAX, "Memory ID overflowed");
                next
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &$crate::id::HandleRef<$id> {
                &self.value
            }
        }

        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        memory_id_type!($id, $handle);

        /// Binding of a memory handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            /// Consume the handle, turning it into a binding.
            pub(crate) fn binding(self) -> $binding {
                let value = self.value.binding();
                $binding { value }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &$crate::id::BindingRef<$id> {
                &self.value
            }
        }
    };
}
185
186/// Kernel unique identifier.
187#[derive(Clone, Debug)]
188pub struct KernelId {
189    pub(crate) type_id: core::any::TypeId,
190    pub(crate) info: Option<Info>,
191    pub(crate) mode: Option<ExecutionMode>,
192    type_name: &'static str,
193}
194
195impl Hash for KernelId {
196    fn hash<H: Hasher>(&self, state: &mut H) {
197        self.type_id.hash(state);
198        self.info.hash(state);
199        self.mode.hash(state);
200    }
201}
202
203impl PartialEq for KernelId {
204    fn eq(&self, other: &Self) -> bool {
205        self.type_id == other.type_id && self.mode == other.mode && self.info == other.info
206    }
207}
208
209impl Eq for KernelId {}
210
211impl Display for KernelId {
212    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
213        match &self.info {
214            Some(info) => f.write_str(
215                format_str(
216                    format!("{info:?}").as_str(),
217                    &[('(', ')'), ('[', ']'), ('{', '}')],
218                    true,
219                )
220                .as_str(),
221            ),
222            None => f.write_str("No info"),
223        }
224    }
225}
226
227impl KernelId {
228    /// Create a new [kernel id](KernelId) for a type.
229    pub fn new<T: 'static>() -> Self {
230        Self {
231            type_id: core::any::TypeId::of::<T>(),
232            type_name: core::any::type_name::<T>(),
233            info: None,
234            mode: None,
235        }
236    }
237
238    /// Render the key in a standard format that can be used between runs.
239    ///
240    /// Can be used as a persistent kernel cache key.
241    pub fn stable_format(&self) -> String {
242        format!("{}-{:?}-{:?}", self.type_name, self.info, self.mode)
243    }
244
245    /// Add information to the [kernel id](KernelId).
246    ///
247    /// The information is used to differentiate kernels of the same kind but with different
248    /// configurations, which affect the generated code.
249    pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
250        mut self,
251        info: I,
252    ) -> Self {
253        self.info = Some(Info::new(info));
254        self
255    }
256
257    /// Set the [execution mode](ExecutionMode).
258    pub fn mode(&mut self, mode: ExecutionMode) {
259        self.mode = Some(mode);
260    }
261}
262
263impl core::fmt::Debug for Info {
264    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
265        f.write_fmt(format_args!("{:?}", self.value))
266    }
267}
268
269impl Info {
270    fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
271        Self {
272            value: Arc::new(id),
273        }
274    }
275}
276
/// This trait allows various types to be used as keys within a single data structure.
///
/// Keys of different concrete types can live behind one `Arc<dyn DynKey>` because equality and
/// hashing go through type-erased methods:
///
/// * [`dyn_eq`](DynKey::dyn_eq): the blanket implementation downcasts `other` through
///   [`as_any`](DynKey::as_any) and reports `false` for keys of a different concrete type.
/// * [`dyn_hash`](DynKey::dyn_hash): the value's [`Hash`] impl is not fed to the caller's
///   [`Hasher`] directly. The blanket implementation first hashes the value with a fixed-seed
///   `foldhash` hasher (`foldhash::fast::FixedState::with_seed(0)`), then writes only that
///   `u64` into the provided hasher. The value written is therefore reproducible between
///   runs. The downside is that the inner hashing method is hardcoded and cannot be
///   configured by the caller.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// [`TypeId`] of the concrete key type.
    fn dyn_type_id(&self) -> TypeId;
    /// Equality against another key, which may be of a different concrete type.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feed this key's hash into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// Upcast to [`Any`] so implementations can downcast the other key in `dyn_eq`.
    fn as_any(&self) -> &dyn Any;
}
288
289impl PartialEq for Info {
290    fn eq(&self, other: &Self) -> bool {
291        self.value.dyn_eq(other.value.as_ref())
292    }
293}
294
295/// Extra information
296#[derive(Clone)]
297pub(crate) struct Info {
298    value: Arc<dyn DynKey>,
299}
300impl Eq for Info {}
301
302impl Hash for Info {
303    fn hash<H: Hasher>(&self, state: &mut H) {
304        self.value.dyn_type_id().hash(state);
305        self.value.dyn_hash(state)
306    }
307}
308
309impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
310    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
311        if let Some(other) = other.as_any().downcast_ref::<T>() {
312            self == other
313        } else {
314            false
315        }
316    }
317
318    fn dyn_type_id(&self) -> TypeId {
319        TypeId::of::<T>()
320    }
321
322    fn dyn_hash(&self, state: &mut dyn Hasher) {
323        // HashBrown uses foldhash but the default hasher still creates some random state. We need this hash here
324        // to be exactly reproducible.
325        let hash = foldhash::fast::FixedState::with_seed(0).hash_one(self);
326        state.write_u64(hash);
327    }
328
329    fn as_any(&self) -> &dyn Any {
330        self
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Ids of the same type that differ only in their info must hash and compare apart.
    #[test]
    pub fn kernel_id_hash() {
        let stored = KernelId::new::<()>().info("1");
        let absent = KernelId::new::<()>().info("2");

        let set: HashSet<KernelId> = core::iter::once(stored.clone()).collect();

        assert!(set.contains(&stored));
        assert!(!set.contains(&absent));
    }
}