// cubecl_runtime/id.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5    any::{Any, TypeId},
6    fmt::Display,
7    hash::{Hash, Hasher},
8};
9use cubecl_common::{
10    format::{DebugRaw, format_str},
11    hash::{StableHash, StableHasher},
12};
13use cubecl_ir::AddressType;
14use derive_more::{Eq, PartialEq};
15
16use crate::server::{CubeDim, ExecutionMode};
17
/// Create a new storage ID type.
///
/// Expands to a `Copy` struct named `$name` wrapping a `usize`, together with a `new`
/// constructor and a `Default` impl that both hand out the next value of a counter.
/// The counter is a `static` declared inside the generated `new`, so every expansion
/// of this macro owns its own sequence starting at `0`.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Storage ID.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Create a new ID.
            ///
            /// # Panics
            ///
            /// Panics when the counter has reached `usize::MAX`, i.e. when no unused
            /// value is left.
            pub fn new() -> Self {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                // A plain `fetch_add` wraps the counter back to zero before the overflow
                // can be reported, so callers racing with (or running after) the panic
                // would receive IDs that were already issued. The checked update leaves
                // the counter pinned at `usize::MAX` and makes every later call fail too.
                let next = COUNTER.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                    current.checked_add(1)
                });

                match next {
                    Ok(value) => Self { value },
                    // The path-qualified macro is not rewritten by `local_inner_macros`.
                    Err(_) => core::panic!("Storage ID overflowed"),
                }
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
50
/// Reference to a buffer handle.
///
/// Holds two reference counters: one around the identifier and one shared token.
/// Cloning the handle bumps both; turning it into a [`BindingRef`] keeps only the token.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HandleRef<Id> {
    /// Shared identifier; its strong count is what [`HandleRef::can_mut`] inspects.
    id: Arc<Id>,
    /// Token shared by every clone and every binding; inspected by [`HandleRef::is_free`].
    all: Arc<()>,
}

/// Reference to buffer binding.
///
/// Owns a plain copy of the identifier plus the shared token taken from the handle
/// it was created from.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    /// Identifier copied out of the originating handle.
    id: Id,
    /// Never read; held only so the token's strong count stays raised while the
    /// binding is alive.
    _all: Arc<()>,
}

impl<Id> BindingRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// The id associated to the buffer.
    pub(crate) fn id(&self) -> &Id {
        let Self { id, .. } = self;
        id
    }
}

impl<Id> HandleRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// Create a new handle owning `id`, with fresh counters for both fields.
    pub(crate) fn new(id: Id) -> Self {
        let id = Arc::new(id);
        let all = Arc::new(());
        Self { id, all }
    }

    /// The id associated to the handle.
    pub(crate) fn id(&self) -> &Id {
        &*self.id
    }

    /// Get the binding.
    ///
    /// Consumes the handle: the identifier is cloned out of its `Arc` (which is then
    /// released) while the shared token is moved into the binding.
    pub(crate) fn binding(self) -> BindingRef<Id> {
        let Self { id, all } = self;
        BindingRef {
            id: Id::clone(&id),
            _all: all,
        }
    }

    /// If the handle can be mut.
    pub(crate) fn can_mut(&self) -> bool {
        // 1 memory management reference with 1 tensor reference.
        const MAX_REFERENCES_FOR_MUT: usize = 2;

        Arc::strong_count(&self.id) <= MAX_REFERENCES_FOR_MUT
    }

    /// If the resource is free, i.e. this handle is the only holder of the shared token.
    pub(crate) fn is_free(&self) -> bool {
        const MAX_REFERENCES_WHEN_FREE: usize = 1;

        Arc::strong_count(&self.all) <= MAX_REFERENCES_WHEN_FREE
    }
}
111
/// Create new memory ID types.
///
/// * `memory_id_type!(Id, Handle)` generates the `Id` struct (a `usize` wrapper) and a
///   `Handle` struct wrapping `$crate::id::HandleRef<Id>`, which it dereferences to.
/// * `memory_id_type!(Id, Handle, Binding)` additionally generates a `Binding` struct
///   wrapping `$crate::id::BindingRef<Id>` and a `Handle::binding` conversion.
///
/// Each expansion owns its own ID counter, because the counter is a `static` declared
/// inside the generated `gen_id` function.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Memory Handle.
        #[derive(Clone, Debug, PartialEq, Eq)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Memory ID.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl $handle {
            /// Create a new ID.
            ///
            /// # Panics
            ///
            /// Panics when the ID counter has reached `usize::MAX`.
            pub(crate) fn new() -> Self {
                let id = $id {
                    value: Self::gen_id(),
                };
                Self {
                    value: $crate::id::HandleRef::new(id),
                }
            }

            /// Hand out the next unused counter value.
            fn gen_id() -> usize {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                // A plain `fetch_add` wraps the counter back to zero before the overflow
                // can be reported, so later calls would hand out IDs that are already in
                // use. The checked update keeps the counter pinned at `usize::MAX`, which
                // makes every subsequent call fail as well.
                let next = COUNTER.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                    current.checked_add(1)
                });

                match next {
                    Ok(value) => value,
                    // The path-qualified macro is not rewritten by `local_inner_macros`.
                    Err(_) => core::panic!("Memory ID overflowed"),
                }
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        memory_id_type!($id, $handle);

        /// Binding of a memory handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            /// Consume the handle and wrap the binding of its inner reference.
            pub(crate) fn binding(self) -> $binding {
                let value = self.value.binding();
                $binding { value }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };
}
191
192/// Kernel unique identifier.
193#[derive(Clone, PartialEq, Eq)]
194pub struct KernelId {
195    #[eq(skip)]
196    type_name: &'static str,
197    pub(crate) type_id: core::any::TypeId,
198    pub(crate) address_type: AddressType,
199    /// The [`CubeDim`] for this kernel
200    pub cube_dim: CubeDim,
201    pub(crate) mode: ExecutionMode,
202    pub(crate) info: Option<Info>,
203}
204
205impl Hash for KernelId {
206    fn hash<H: Hasher>(&self, state: &mut H) {
207        self.type_id.hash(state);
208        self.address_type.hash(state);
209        self.cube_dim.hash(state);
210        self.mode.hash(state);
211        self.info.hash(state);
212    }
213}
214
215impl core::fmt::Debug for KernelId {
216    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
217        let mut debug_str = f.debug_struct("KernelId");
218        debug_str
219            .field("type", &DebugRaw(self.type_name))
220            .field("address_type", &self.address_type);
221        debug_str.field("cube_dim", &self.cube_dim);
222        debug_str.field("mode", &self.mode);
223        match &self.info {
224            Some(info) => debug_str.field("info", info),
225            None => debug_str.field("info", &self.info),
226        };
227        debug_str.finish()
228    }
229}
230
231impl Display for KernelId {
232    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
233        match &self.info {
234            Some(info) => f.write_str(
235                format_str(
236                    format!("{info:?}").as_str(),
237                    &[('(', ')'), ('[', ']'), ('{', '}')],
238                    true,
239                )
240                .as_str(),
241            ),
242            None => f.write_str("No info"),
243        }
244    }
245}
246
247impl KernelId {
248    /// Create a new [kernel id](KernelId) for a type.
249    pub fn new<T: 'static>() -> Self {
250        Self {
251            type_id: core::any::TypeId::of::<T>(),
252            type_name: core::any::type_name::<T>(),
253            info: None,
254            cube_dim: CubeDim::new_single(),
255            mode: ExecutionMode::Checked,
256            address_type: Default::default(),
257        }
258    }
259
260    /// Render the key in a standard format that can be used between runs.
261    ///
262    /// Can be used as a persistent kernel cache key.
263    pub fn stable_format(&self) -> String {
264        format!(
265            "{}-{}-{:?}-{:?}-{:?}",
266            self.type_name, self.address_type, self.cube_dim, self.mode, self.info
267        )
268    }
269
270    /// Hash the key in a stable way that can be used between runs.
271    ///
272    /// Can be used as a persistent kernel cache key.
273    pub fn stable_hash(&self) -> StableHash {
274        let mut hasher = StableHasher::new();
275        self.type_name.hash(&mut hasher);
276        self.address_type.hash(&mut hasher);
277        self.cube_dim.hash(&mut hasher);
278        self.mode.hash(&mut hasher);
279        self.info.hash(&mut hasher);
280
281        hasher.finalize()
282    }
283
284    /// Add information to the [kernel id](KernelId).
285    ///
286    /// The information is used to differentiate kernels of the same kind but with different
287    /// configurations, which affect the generated code.
288    pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
289        mut self,
290        info: I,
291    ) -> Self {
292        self.info = Some(Info::new(info));
293        self
294    }
295
296    /// Set the [execution mode](ExecutionMode).
297    pub fn mode(&mut self, mode: ExecutionMode) {
298        self.mode = mode;
299    }
300
301    /// Set the [cube dim](CubeDim).
302    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
303        self.cube_dim = cube_dim;
304        self
305    }
306
307    /// Set the [`AddressType`].
308    pub fn address_type(mut self, addr_ty: AddressType) -> Self {
309        self.address_type = addr_ty;
310        self
311    }
312}
313
314impl core::fmt::Debug for Info {
315    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
316        self.value.fmt(f)
317    }
318}
319
320impl Info {
321    fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
322        Self {
323            value: Arc::new(id),
324        }
325    }
326}
327
328/// This trait allows various types to be used as keys within a single data structure.
329///
330/// The downside is that the hashing method is hardcoded and cannot be configured using the
331/// [`core::hash::Hash`] function. The provided [Hasher] will be modified, but only based on the
332/// result of the hash from the [`DefaultHasher`].
333trait DynKey: core::fmt::Debug + Send + Sync {
334    fn dyn_type_id(&self) -> TypeId;
335    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
336    fn dyn_hash(&self, state: &mut dyn Hasher);
337    fn dyn_hash_one(&self) -> StableHash;
338    fn as_any(&self) -> &dyn Any;
339}
340
341impl PartialEq for Info {
342    fn eq(&self, other: &Self) -> bool {
343        self.value.dyn_eq(other.value.as_ref())
344    }
345}
346
347/// Extra information
348#[derive(Clone)]
349pub(crate) struct Info {
350    value: Arc<dyn DynKey>,
351}
352impl Eq for Info {}
353
354impl Hash for Info {
355    fn hash<H: Hasher>(&self, state: &mut H) {
356        self.value.dyn_type_id().hash(state);
357        self.value.dyn_hash(state)
358    }
359}
360
361impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
362    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
363        if let Some(other) = other.as_any().downcast_ref::<T>() {
364            self == other
365        } else {
366            false
367        }
368    }
369
370    fn dyn_type_id(&self) -> TypeId {
371        TypeId::of::<T>()
372    }
373
374    fn dyn_hash(&self, state: &mut dyn Hasher) {
375        let hash = self.dyn_hash_one();
376        state.write_u128(hash);
377    }
378
379    fn dyn_hash_one(&self) -> StableHash {
380        let mut hasher = StableHasher::new();
381        self.hash(&mut hasher);
382        hasher.finalize()
383    }
384
385    fn as_any(&self) -> &dyn Any {
386        self
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Kernel ids of the same type but with different info must not collide in a hash set.
    #[test_log::test]
    pub fn kernel_id_hash() {
        let stored = KernelId::new::<()>().info("1");
        let absent = KernelId::new::<()>().info("2");

        let set: HashSet<KernelId> = HashSet::from([stored.clone()]);

        assert!(set.contains(&stored));
        assert!(!set.contains(&absent));
    }
}