Skip to main content

cubecl_runtime/
id.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5    any::{Any, TypeId},
6    fmt::Display,
7    hash::{Hash, Hasher},
8};
9use cubecl_common::{
10    format::{DebugRaw, format_str},
11    hash::{StableHash, StableHasher},
12};
13use cubecl_ir::AddressType;
14use derive_more::{Eq, PartialEq};
15
16use crate::server::{CubeDim, ExecutionMode};
17
/// Create a new storage ID type.
///
/// Expands to a `Copy` struct named `$name` wrapping a `usize`, together with a `new`
/// constructor and a [`Default`] impl that forwards to it. The counter is a `static` declared
/// inside the generated `new`, so every generated type counts independently.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Storage ID.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Create a new ID.
            ///
            /// # Panics
            ///
            /// Panics once the internal counter hands out `usize::MAX`.
            pub fn new() -> Self {
                use core::sync::atomic::{AtomicUsize, Ordering};

                static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

                // `fetch_add` yields the value held *before* the increment.
                match NEXT_ID.fetch_add(1, Ordering::Relaxed) {
                    usize::MAX => core::panic!("Memory ID overflowed"),
                    value => Self { value },
                }
            }
        }

        impl Default for $name {
            /// Equivalent to calling `new`: every default value is a fresh ID.
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
50
51/// Kernel unique identifier.
52#[derive(Clone, PartialEq, Eq)]
53pub struct KernelId {
54    #[eq(skip)]
55    type_name: &'static str,
56    pub(crate) type_id: core::any::TypeId,
57    pub(crate) address_type: AddressType,
58    /// The [`CubeDim`] for this kernel
59    pub cube_dim: CubeDim,
60    pub(crate) mode: ExecutionMode,
61    pub(crate) info: Option<Info>,
62}
63
64impl Hash for KernelId {
65    fn hash<H: Hasher>(&self, state: &mut H) {
66        self.type_id.hash(state);
67        self.address_type.hash(state);
68        self.cube_dim.hash(state);
69        self.mode.hash(state);
70        self.info.hash(state);
71    }
72}
73
74impl core::fmt::Debug for KernelId {
75    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
76        let mut debug_str = f.debug_struct("KernelId");
77        debug_str
78            .field("type", &DebugRaw(self.type_name))
79            .field("address_type", &self.address_type);
80        debug_str.field("cube_dim", &self.cube_dim);
81        debug_str.field("mode", &self.mode);
82        match &self.info {
83            Some(info) => debug_str.field("info", info),
84            None => debug_str.field("info", &self.info),
85        };
86        debug_str.finish()
87    }
88}
89
90impl Display for KernelId {
91    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
92        match &self.info {
93            Some(info) => f.write_str(
94                format_str(
95                    format!("{info:?}").as_str(),
96                    &[('(', ')'), ('[', ']'), ('{', '}')],
97                    true,
98                )
99                .as_str(),
100            ),
101            None => f.write_str("No info"),
102        }
103    }
104}
105
106impl KernelId {
107    /// Create a new [kernel id](KernelId) for a type.
108    pub fn new<T: 'static>() -> Self {
109        Self {
110            type_id: core::any::TypeId::of::<T>(),
111            type_name: core::any::type_name::<T>(),
112            info: None,
113            cube_dim: CubeDim::new_single(),
114            mode: ExecutionMode::Checked,
115            address_type: Default::default(),
116        }
117    }
118
119    /// Render the key in a standard format that can be used between runs.
120    ///
121    /// Can be used as a persistent kernel cache key.
122    pub fn stable_format(&self) -> String {
123        format!(
124            "{}-{}-{:?}-{:?}-{:?}",
125            self.type_name, self.address_type, self.cube_dim, self.mode, self.info
126        )
127    }
128
129    /// Hash the key in a stable way that can be used between runs.
130    ///
131    /// Can be used as a persistent kernel cache key.
132    pub fn stable_hash(&self) -> StableHash {
133        let mut hasher = StableHasher::new();
134        self.type_name.hash(&mut hasher);
135        self.address_type.hash(&mut hasher);
136        self.cube_dim.hash(&mut hasher);
137        self.mode.hash(&mut hasher);
138        self.info.hash(&mut hasher);
139
140        hasher.finalize()
141    }
142
143    /// Add information to the [kernel id](KernelId).
144    ///
145    /// The information is used to differentiate kernels of the same kind but with different
146    /// configurations, which affect the generated code.
147    pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
148        mut self,
149        info: I,
150    ) -> Self {
151        self.info = Some(Info::new(info));
152        self
153    }
154
155    /// Set the [execution mode](ExecutionMode).
156    pub fn mode(&mut self, mode: ExecutionMode) {
157        self.mode = mode;
158    }
159
160    /// Set the [cube dim](CubeDim).
161    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
162        self.cube_dim = cube_dim;
163        self
164    }
165
166    /// Set the [`AddressType`].
167    pub fn address_type(mut self, addr_ty: AddressType) -> Self {
168        self.address_type = addr_ty;
169        self
170    }
171}
172
173impl core::fmt::Debug for Info {
174    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
175        self.value.fmt(f)
176    }
177}
178
179impl Info {
180    fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
181        Self {
182            value: Arc::new(id),
183        }
184    }
185}
186
/// This trait allows various types to be used as keys within a single data structure.
///
/// The downside is that the hashing method is hardcoded and cannot be configured using the
/// [`core::hash::Hash`] function. The provided [Hasher] will be modified, but only based on the
/// result of the hash from the [`StableHasher`].
///
/// The blanket implementation in this file covers every
/// `'static + PartialEq + Eq + Hash + Debug + Send + Sync` type.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// Returns the [`TypeId`] of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// Compares `self` with another type-erased key.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feeds this key into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// Hashes this key on its own and returns the resulting [`StableHash`].
    fn dyn_hash_one(&self) -> StableHash;
    /// Exposes the value as [`Any`] so that it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
}
199
200impl PartialEq for Info {
201    fn eq(&self, other: &Self) -> bool {
202        self.value.dyn_eq(other.value.as_ref())
203    }
204}
205
/// Extra information
///
/// Holds a type-erased value behind an [`Arc`]; cloning an `Info` clones the `Arc`, not the
/// wrapped value.
#[derive(Clone)]
pub(crate) struct Info {
    /// The wrapped value, accessed only through the `DynKey` trait object.
    value: Arc<dyn DynKey>,
}
// Marker impl: `Info` also implements `PartialEq` (delegating to `DynKey::dyn_eq`).
impl Eq for Info {}
212
213impl Hash for Info {
214    fn hash<H: Hasher>(&self, state: &mut H) {
215        self.value.dyn_type_id().hash(state);
216        self.value.dyn_hash(state)
217    }
218}
219
220impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
221    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
222        if let Some(other) = other.as_any().downcast_ref::<T>() {
223            self == other
224        } else {
225            false
226        }
227    }
228
229    fn dyn_type_id(&self) -> TypeId {
230        TypeId::of::<T>()
231    }
232
233    fn dyn_hash(&self, state: &mut dyn Hasher) {
234        let hash = self.dyn_hash_one();
235        state.write_u128(hash);
236    }
237
238    fn dyn_hash_one(&self) -> StableHash {
239        let mut hasher = StableHasher::new();
240        self.hash(&mut hasher);
241        hasher.finalize()
242    }
243
244    fn as_any(&self) -> &dyn Any {
245        self
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// A set holding one kernel id must report that id as present and an id with different
    /// info as absent.
    #[test_log::test]
    pub fn kernel_id_hash() {
        let stored = KernelId::new::<()>().info("1");
        let absent = KernelId::new::<()>().info("2");

        let set = HashSet::from([stored.clone()]);

        assert!(set.contains(&stored));
        assert!(!set.contains(&absent));
    }
}