cubecl_runtime/
id.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::sync::Arc;
4use core::{
5    any::{Any, TypeId},
6    fmt::Display,
7    hash::{BuildHasher, Hash, Hasher},
8};
9use cubecl_common::format::{DebugRaw, format_str};
10use cubecl_ir::AddressType;
11use derive_more::{Eq, PartialEq};
12
13use crate::server::{CubeDim, ExecutionMode};
14
/// Create a new storage ID type.
///
/// Each invocation declares a `Copy` ID type wrapping a `usize`, backed by its own atomic
/// counter, so every call to `new` (or `default`) hands out the next counter value.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Storage ID.
        #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }

        impl $name {
            /// Create a new ID.
            ///
            /// # Panics
            ///
            /// Panics when the value fetched from the counter is `usize::MAX`.
            pub fn new() -> Self {
                // Only fully-qualified macro paths are used here: `local_inner_macros` would
                // rewrite a bare `assert!` into `$crate::assert!`.
                static NEXT_ID: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                let claimed = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
                core::assert!(claimed != usize::MAX, "Memory ID overflowed");

                $name { value: claimed }
            }
        }
    };
}
47
/// Reference to a buffer handle.
///
/// Two reference counts are tracked: `id` is shared by every clone of the handle, while `all` is
/// additionally shared by every binding created from those handles.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HandleRef<Id> {
    id: Arc<Id>,
    all: Arc<()>,
}

/// Reference to buffer binding.
///
/// A binding stores a plain copy of the ID and keeps the shared `all` counter alive, so an
/// outstanding binding prevents the handle from reporting the resource as free.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    id: Id,
    _all: Arc<()>,
}

impl<Id> BindingRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// The id associated to the buffer.
    pub(crate) fn id(&self) -> &Id {
        &self.id
    }
}

impl<Id> HandleRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// Create a new handle with fresh reference counters.
    pub(crate) fn new(id: Id) -> Self {
        let all = Arc::new(());
        let id = Arc::new(id);
        Self { id, all }
    }

    /// The id associated to the handle.
    pub(crate) fn id(&self) -> &Id {
        self.id.as_ref()
    }

    /// Convert this handle into a binding.
    ///
    /// The handle's share of `id` is released, while its share of `all` moves into the binding.
    pub(crate) fn binding(self) -> BindingRef<Id> {
        let Self { id, all } = self;
        BindingRef {
            id: (*id).clone(),
            _all: all,
        }
    }

    /// If the handle can be mut.
    pub(crate) fn can_mut(&self) -> bool {
        // At most two holders are allowed: the memory-management reference plus a single
        // tensor reference.
        Arc::strong_count(&self.id) < 3
    }

    /// If the resource is free, i.e. no other handle clone or binding shares it.
    pub(crate) fn is_free(&self) -> bool {
        Arc::strong_count(&self.all) < 2
    }
}
108
/// Create new memory ID types.
///
/// `memory_id_type!(Id, Handle)` declares an ID type plus a handle wrapping a
/// `$crate::id::HandleRef` to it. `memory_id_type!(Id, Handle, Binding)` additionally declares a
/// binding type, obtained by consuming a handle.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Memory Handle.
        #[derive(Clone, Debug, PartialEq, Eq)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Memory ID.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl $handle {
            /// Create a new ID.
            pub(crate) fn new() -> Self {
                let id = $id {
                    value: Self::gen_id(),
                };
                $handle {
                    value: $crate::id::HandleRef::new(id),
                }
            }

            /// Draw the next value from the counter shared by every handle of this type.
            ///
            /// # Panics
            ///
            /// Panics when the fetched value is `usize::MAX`.
            fn gen_id() -> usize {
                // Fully-qualified macro paths only: `local_inner_macros` would rewrite a bare
                // `assert!` into `$crate::assert!`.
                static COUNTER: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                let taken = COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
                core::assert!(taken != usize::MAX, "Memory ID overflowed");
                taken
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        memory_id_type!($id, $handle);

        /// Binding of a memory handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl $handle {
            pub(crate) fn binding(self) -> $binding {
                let value = self.value.binding();
                $binding { value }
            }
        }
    };
}
188
189/// Kernel unique identifier.
190#[derive(Clone, PartialEq, Eq)]
191pub struct KernelId {
192    #[eq(skip)]
193    type_name: &'static str,
194    pub(crate) type_id: core::any::TypeId,
195    pub(crate) address_type: AddressType,
196    pub(crate) cube_dim: Option<CubeDim>,
197    pub(crate) mode: ExecutionMode,
198    pub(crate) info: Option<Info>,
199}
200
201impl Hash for KernelId {
202    fn hash<H: Hasher>(&self, state: &mut H) {
203        self.type_id.hash(state);
204        self.address_type.hash(state);
205        self.cube_dim.hash(state);
206        self.mode.hash(state);
207        self.info.hash(state);
208    }
209}
210
211impl core::fmt::Debug for KernelId {
212    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
213        let mut debug_str = f.debug_struct("KernelId");
214        debug_str
215            .field("type", &DebugRaw(self.type_name))
216            .field("address_type", &self.address_type);
217        match &self.cube_dim {
218            Some(cube_dim) => debug_str.field("cube_dim", cube_dim),
219            None => debug_str.field("cube_dim", &self.cube_dim),
220        };
221        debug_str.field("mode", &self.mode);
222        match &self.info {
223            Some(info) => debug_str.field("info", info),
224            None => debug_str.field("info", &self.info),
225        };
226        debug_str.finish()
227    }
228}
229
230impl Display for KernelId {
231    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
232        match &self.info {
233            Some(info) => f.write_str(
234                format_str(
235                    format!("{info:?}").as_str(),
236                    &[('(', ')'), ('[', ']'), ('{', '}')],
237                    true,
238                )
239                .as_str(),
240            ),
241            None => f.write_str("No info"),
242        }
243    }
244}
245
246impl KernelId {
247    /// Create a new [kernel id](KernelId) for a type.
248    pub fn new<T: 'static>() -> Self {
249        Self {
250            type_id: core::any::TypeId::of::<T>(),
251            type_name: core::any::type_name::<T>(),
252            info: None,
253            cube_dim: None,
254            mode: ExecutionMode::Checked,
255            address_type: Default::default(),
256        }
257    }
258
259    /// Render the key in a standard format that can be used between runs.
260    ///
261    /// Can be used as a persistent kernel cache key.
262    pub fn stable_format(&self) -> String {
263        format!(
264            "{}-{}-{:?}-{:?}-{:?}",
265            self.type_name, self.address_type, self.cube_dim, self.mode, self.info
266        )
267    }
268
269    /// Add information to the [kernel id](KernelId).
270    ///
271    /// The information is used to differentiate kernels of the same kind but with different
272    /// configurations, which affect the generated code.
273    pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
274        mut self,
275        info: I,
276    ) -> Self {
277        self.info = Some(Info::new(info));
278        self
279    }
280
281    /// Set the [execution mode](ExecutionMode).
282    pub fn mode(&mut self, mode: ExecutionMode) {
283        self.mode = mode;
284    }
285
286    /// Set the [cube dim](CubeDim).
287    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
288        self.cube_dim = Some(cube_dim);
289        self
290    }
291
292    /// Set the [address_type](AddressType).
293    pub fn address_type(mut self, addr_ty: AddressType) -> Self {
294        self.address_type = addr_ty;
295        self
296    }
297}
298
299impl core::fmt::Debug for Info {
300    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
301        self.value.fmt(f)
302    }
303}
304
305impl Info {
306    fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
307        Self {
308            value: Arc::new(id),
309        }
310    }
311}
312
/// Object-safe view of a comparable, hashable value, allowing keys of different concrete types
/// to be stored behind a single `dyn` type (as done by [`Info`]).
///
/// With the blanket implementation for every `T: 'static + PartialEq + Eq + Hash + Debug + Send +
/// Sync`:
/// - equality only holds between values of the same concrete type (checked by downcasting through
///   [`DynKey::as_any`]);
/// - hashing does not forward the value to the caller's [`Hasher`]. It hashes the value with a
///   fixed-seed `foldhash::fast::FixedState` (seed `0`) and writes the resulting `u64` into the
///   provided hasher. This keeps the value's contribution exactly reproducible, but the hashing
///   algorithm applied to the value itself cannot be configured by the caller.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// The [`TypeId`] of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// Compare with another key; different concrete types never compare equal.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Write a reproducible hash of the value into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// Upcast to [`Any`] so that `dyn_eq` can downcast the other key.
    fn as_any(&self) -> &dyn Any;
}
324
325impl PartialEq for Info {
326    fn eq(&self, other: &Self) -> bool {
327        self.value.dyn_eq(other.value.as_ref())
328    }
329}
330
331/// Extra information
332#[derive(Clone)]
333pub(crate) struct Info {
334    value: Arc<dyn DynKey>,
335}
336impl Eq for Info {}
337
338impl Hash for Info {
339    fn hash<H: Hasher>(&self, state: &mut H) {
340        self.value.dyn_type_id().hash(state);
341        self.value.dyn_hash(state)
342    }
343}
344
345impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
346    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
347        if let Some(other) = other.as_any().downcast_ref::<T>() {
348            self == other
349        } else {
350            false
351        }
352    }
353
354    fn dyn_type_id(&self) -> TypeId {
355        TypeId::of::<T>()
356    }
357
358    fn dyn_hash(&self, state: &mut dyn Hasher) {
359        // HashBrown uses foldhash but the default hasher still creates some random state. We need this hash here
360        // to be exactly reproducible.
361        let hash = foldhash::fast::FixedState::with_seed(0).hash_one(self);
362        state.write_u64(hash);
363    }
364
365    fn as_any(&self) -> &dyn Any {
366        self
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test_log::test]
    pub fn kernel_id_hash() {
        // Same kernel type, different info: the two ids must land in different buckets/equality
        // classes.
        let stored = KernelId::new::<()>().info("1");
        let missing = KernelId::new::<()>().info("2");

        let set = HashSet::from([stored.clone()]);

        assert!(set.contains(&stored));
        assert!(!set.contains(&missing));
    }
}
387}