cubecl_runtime/
id.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5    any::{Any, TypeId},
6    fmt::Display,
7    hash::{BuildHasher, Hash, Hasher},
8};
9use cubecl_common::format::format_str;
10
11use crate::server::ExecutionMode;
12
/// Create a new storage ID type.
///
/// Expands to a `Copy` struct named `$name` that wraps a `usize`, a `new` constructor
/// that hands out the next value of a counter private to that type, and a `Default`
/// implementation that delegates to `new`.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Storage ID.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Create a new ID.
            ///
            /// # Panics
            ///
            /// Panics once the counter has reached `usize::MAX`; every later call panics
            /// as well instead of handing out an already used value.
            pub fn new() -> Self {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                // A plain `fetch_add` would wrap the counter back to zero before the
                // overflow could be reported, so IDs would silently repeat afterwards.
                // `checked_add` leaves the counter untouched once it is saturated.
                let reserved =
                    COUNTER.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                        current.checked_add(1)
                    });

                match reserved {
                    Ok(value) => Self { value },
                    Err(_) => core::panic!("Memory ID overflowed"),
                }
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
45
/// Reference to a buffer handle.
///
/// Cloning a handle clones both inner [`Arc`]s, so their strong counts track how many
/// handles (and, for `all`, bindings) are still alive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandleRef<Id> {
    // Shared id; its strong count is what `can_mut` inspects.
    id: Arc<Id>,
    // Marker that is also handed to bindings; its strong count is what `is_free` inspects.
    all: Arc<()>,
}

/// Reference to buffer binding.
///
/// Owns a plain copy of the id plus the `all` marker taken from the handle it was
/// created from.
#[derive(Debug, Clone)]
pub struct BindingRef<Id> {
    id: Id,
    _all: Arc<()>,
}

impl<Id> BindingRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// Borrow the id of the bound buffer.
    pub(crate) fn id(&self) -> &Id {
        let Self { id, .. } = self;
        id
    }
}

impl<Id> HandleRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// Wrap `id` into a fresh handle with its own reference counters.
    pub(crate) fn new(id: Id) -> Self {
        let id = Arc::new(id);
        let all = Arc::new(());

        Self { id, all }
    }

    /// Borrow the id stored in the handle.
    pub(crate) fn id(&self) -> &Id {
        &*self.id
    }

    /// Consume the handle and turn it into a binding.
    ///
    /// The id is cloned out of its `Arc`, while the `all` marker is moved into the binding.
    pub(crate) fn binding(self) -> BindingRef<Id> {
        let Self { id, all } = self;

        BindingRef {
            id: (*id).clone(),
            _all: all,
        }
    }

    /// Whether the handle may be mutated: at most two references to the id exist.
    pub(crate) fn can_mut(&self) -> bool {
        // 1 memory management reference with 1 tensor reference.
        let id_refs = Arc::strong_count(&self.id);
        id_refs <= 2
    }

    /// Whether the resource is free: this handle holds the only `all` marker.
    pub(crate) fn is_free(&self) -> bool {
        let live_refs = Arc::strong_count(&self.all);
        live_refs <= 1
    }
}
106
/// Create new memory ID types.
///
/// * `memory_id_type!(Id, Handle)` declares the `Copy` id struct `Id` and the handle
///   struct `Handle`, which wraps a `HandleRef<Id>` and dereferences to it.
/// * `memory_id_type!(Id, Handle, Binding)` additionally declares `Binding`, which wraps
///   a `BindingRef<Id>`, and adds `Handle::binding` to convert a handle into it.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Memory Handle.
        #[derive(Clone, Debug, PartialEq, Eq)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Memory ID.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl $handle {
            /// Create a new ID.
            ///
            /// # Panics
            ///
            /// Panics once the ID counter has reached `usize::MAX`.
            pub(crate) fn new() -> Self {
                let value = Self::gen_id();
                Self {
                    value: $crate::id::HandleRef::new($id { value }),
                }
            }

            /// Reserve the next value of the counter private to this handle type.
            fn gen_id() -> usize {
                static COUNTER: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                // A plain `fetch_add` would wrap the counter back to zero before the
                // overflow could be reported, so IDs would silently repeat afterwards.
                // `checked_add` leaves the counter untouched once it is saturated.
                let reserved = COUNTER.fetch_update(
                    core::sync::atomic::Ordering::Relaxed,
                    core::sync::atomic::Ordering::Relaxed,
                    |current| current.checked_add(1),
                );

                match reserved {
                    Ok(value) => value,
                    Err(_) => core::panic!("Memory ID overflowed"),
                }
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        memory_id_type!($id, $handle);

        /// Binding of a memory handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            /// Consume the handle and convert it into its binding.
            pub(crate) fn binding(self) -> $binding {
                $binding {
                    value: self.value.binding(),
                }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };
}
186
187/// Kernel unique identifier.
188#[derive(Clone, Debug)]
189pub struct KernelId {
190    pub(crate) type_id: core::any::TypeId,
191    pub(crate) info: Option<Info>,
192    pub(crate) mode: Option<ExecutionMode>,
193    type_name: &'static str,
194}
195
196impl Hash for KernelId {
197    fn hash<H: Hasher>(&self, state: &mut H) {
198        self.type_id.hash(state);
199        self.info.hash(state);
200        self.mode.hash(state);
201    }
202}
203
204impl PartialEq for KernelId {
205    fn eq(&self, other: &Self) -> bool {
206        self.type_id == other.type_id && self.mode == other.mode && self.info == other.info
207    }
208}
209
210impl Eq for KernelId {}
211
212impl Display for KernelId {
213    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
214        match &self.info {
215            Some(info) => f.write_str(
216                format_str(
217                    format!("{info:?}").as_str(),
218                    &[('(', ')'), ('[', ']'), ('{', '}')],
219                    true,
220                )
221                .as_str(),
222            ),
223            None => f.write_str("No info"),
224        }
225    }
226}
227
228impl KernelId {
229    /// Create a new [kernel id](KernelId) for a type.
230    pub fn new<T: 'static>() -> Self {
231        Self {
232            type_id: core::any::TypeId::of::<T>(),
233            type_name: core::any::type_name::<T>(),
234            info: None,
235            mode: None,
236        }
237    }
238
239    /// Render the key in a standard format that can be used between runs.
240    ///
241    /// Can be used as a persistent kernel cache key.
242    pub fn stable_format(&self) -> String {
243        format!("{}-{:?}-{:?}", self.type_name, self.info, self.mode)
244    }
245
246    /// Add information to the [kernel id](KernelId).
247    ///
248    /// The information is used to differentiate kernels of the same kind but with different
249    /// configurations, which affect the generated code.
250    pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
251        mut self,
252        info: I,
253    ) -> Self {
254        self.info = Some(Info::new(info));
255        self
256    }
257
258    /// Set the [execution mode](ExecutionMode).
259    pub fn mode(&mut self, mode: ExecutionMode) {
260        self.mode = Some(mode);
261    }
262}
263
264impl core::fmt::Debug for Info {
265    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
266        f.write_fmt(format_args!("{:?}", self.value))
267    }
268}
269
270impl Info {
271    fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
272        Self {
273            value: Arc::new(id),
274        }
275    }
276}
277
/// Object-safe key trait that lets values of different types be stored as keys within a
/// single data structure.
///
/// The downside is that hashing cannot be configured through the usual
/// [core::hash::Hash] machinery: the implementation decides what it feeds into the
/// provided [Hasher].
trait DynKey: core::fmt::Debug + Send + Sync {
    /// Expose the value as [`Any`] so it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
    /// The [`TypeId`] of the concrete key type.
    fn dyn_type_id(&self) -> TypeId;
    /// Compare with another type-erased key.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feed the key into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
}
289
290impl PartialEq for Info {
291    fn eq(&self, other: &Self) -> bool {
292        self.value.dyn_eq(other.value.as_ref())
293    }
294}
295
296/// Extra information
297#[derive(Clone)]
298pub(crate) struct Info {
299    value: Arc<dyn DynKey>,
300}
301impl Eq for Info {}
302
303impl Hash for Info {
304    fn hash<H: Hasher>(&self, state: &mut H) {
305        self.value.dyn_type_id().hash(state);
306        self.value.dyn_hash(state)
307    }
308}
309
310impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
311    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
312        if let Some(other) = other.as_any().downcast_ref::<T>() {
313            self == other
314        } else {
315            false
316        }
317    }
318
319    fn dyn_type_id(&self) -> TypeId {
320        TypeId::of::<T>()
321    }
322
323    fn dyn_hash(&self, state: &mut dyn Hasher) {
324        // HashBrown uses foldhash but the default hasher still creates some random state. We need this hash here
325        // to be exactly reproducible.
326        let hash = foldhash::fast::FixedState::with_seed(0).hash_one(self);
327        state.write_u64(hash);
328    }
329
330    fn as_any(&self) -> &dyn Any {
331        self
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Ids of the same type but with different info must not collide in a hash set.
    #[test]
    pub fn kernel_id_hash() {
        let stored = KernelId::new::<()>().info("1");
        let missing = KernelId::new::<()>().info("2");

        let set: HashSet<KernelId> = std::iter::once(stored.clone()).collect();

        assert!(set.contains(&stored));
        assert!(!set.contains(&missing));
    }
}