cubecl_core/
id.rs

1use cubecl_runtime::ExecutionMode;
2use std::any::{Any, TypeId};
3use std::hash::{DefaultHasher, Hash, Hasher};
4use std::sync::Arc;
5
6/// Kernel unique identifier.
7#[derive(Hash, PartialEq, Eq, Clone, Debug)]
8pub struct KernelId {
9    pub(crate) type_id: core::any::TypeId,
10    pub(crate) info: Option<Info>,
11    pub(crate) mode: Option<ExecutionMode>,
12}
13
14impl KernelId {
15    /// Create a new [kernel id](KernelId) for a type.
16    pub fn new<T: 'static>() -> Self {
17        Self {
18            type_id: core::any::TypeId::of::<T>(),
19            info: None,
20            mode: None,
21        }
22    }
23
24    /// Add information to the [kernel id](KernelId).
25    ///
26    /// The information is used to differentiate kernels of the same kind but with different
27    /// configurations, which affect the generated code.
28    pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
29        mut self,
30        info: I,
31    ) -> Self {
32        self.info = Some(Info::new(info));
33        self
34    }
35
36    /// Set the [execution mode](ExecutionMode).
37    pub fn mode(&mut self, mode: ExecutionMode) {
38        self.mode = Some(mode);
39    }
40}
41
42/// Extra information
43#[derive(Clone)]
44pub(crate) struct Info {
45    value: Arc<dyn DynKey>,
46}
47
48impl core::fmt::Debug for Info {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        f.write_fmt(format_args!("{:?}", self.value))
51    }
52}
53
54impl Info {
55    fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
56        Self {
57            value: Arc::new(id),
58        }
59    }
60}
61
/// Object-safe key interface that lets values of different types be stored as keys in one
/// data structure.
///
/// The trade-off is that hashing cannot be customised through the usual [`core::hash::Hash`]
/// trait. Each value is first hashed with a fresh [`DefaultHasher`], and only that resulting
/// `u64` is written into the [`Hasher`] supplied by the caller.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// Upcast to [`Any`] so the concrete type can be recovered with a downcast.
    fn as_any(&self) -> &dyn Any;
    /// [`TypeId`] of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// Compare this key with another, possibly differently typed, key.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feed a hash of this key into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
}
73
74impl PartialEq for Info {
75    fn eq(&self, other: &Self) -> bool {
76        self.value.dyn_eq(other.value.as_ref())
77    }
78}
79
80impl Eq for Info {}
81
82impl Hash for Info {
83    fn hash<H: Hasher>(&self, state: &mut H) {
84        self.value.dyn_type_id().hash(state);
85        self.value.dyn_hash(state)
86    }
87}
88
89impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
90    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
91        if let Some(other) = other.as_any().downcast_ref::<T>() {
92            self == other
93        } else {
94            false
95        }
96    }
97
98    fn dyn_type_id(&self) -> TypeId {
99        TypeId::of::<T>()
100    }
101
102    fn dyn_hash(&self, state: &mut dyn Hasher) {
103        let mut default_hasher = DefaultHasher::new();
104        self.hash(&mut default_hasher);
105        state.write_u64(default_hasher.finish());
106    }
107
108    fn as_any(&self) -> &dyn Any {
109        self
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn kernel_id_hash() {
        let value_1 = KernelId::new::<()>().info("1");
        let value_2 = KernelId::new::<()>().info("2");

        let mut set = HashSet::new();

        set.insert(value_1.clone());

        assert!(set.contains(&value_1));
        assert!(!set.contains(&value_2));
    }

    #[test]
    fn kernel_id_equal_info_built_separately() {
        // Each `info` call allocates its own `Arc`. Equality and hashing must follow the
        // wrapped values, not pointer identity.
        let first = KernelId::new::<()>().info(("matmul", 16u32));
        let second = KernelId::new::<()>().info(("matmul", 16u32));

        assert_eq!(first, second);

        let mut set = HashSet::new();
        set.insert(first);
        assert!(set.contains(&second));
    }

    #[test]
    fn kernel_id_info_type_matters() {
        // Same numeric value, different concrete types: the ids must not be equal.
        let as_u32 = KernelId::new::<()>().info(1u32);
        let as_u64 = KernelId::new::<()>().info(1u64);

        assert_ne!(as_u32, as_u64);
    }

    #[test]
    fn kernel_id_kernel_type_matters() {
        let unit = KernelId::new::<()>().info("1");
        let byte = KernelId::new::<u8>().info("1");

        assert_ne!(unit, byte);
    }
}