cubecl_core/
id.rs

1use std::any::{Any, TypeId};
2use std::hash::{DefaultHasher, Hash, Hasher};
3use std::sync::Arc;
4
5use cubecl_common::ExecutionMode;
6use cubecl_runtime::client::ComputeClient;
7
8/// Kernel unique identifier.
9#[derive(Clone, Debug)]
10pub struct KernelId {
11    pub(crate) type_id: core::any::TypeId,
12    pub(crate) info: Option<Info>,
13    pub(crate) mode: Option<ExecutionMode>,
14    type_name: &'static str,
15}
16
17impl Hash for KernelId {
18    fn hash<H: Hasher>(&self, state: &mut H) {
19        self.type_id.hash(state);
20        self.info.hash(state);
21        self.mode.hash(state);
22    }
23}
24
25impl PartialEq for KernelId {
26    fn eq(&self, other: &Self) -> bool {
27        self.type_id == other.type_id && self.mode == other.mode && self.info == other.info
28    }
29}
30
31impl Eq for KernelId {}
32
33impl KernelId {
34    /// Create a new [kernel id](KernelId) for a type.
35    pub fn new<T: 'static>() -> Self {
36        Self {
37            type_id: core::any::TypeId::of::<T>(),
38            type_name: core::any::type_name::<T>(),
39            info: None,
40            mode: None,
41        }
42    }
43
44    /// Render the key in a standard format that can be used between runs.
45    ///
46    /// Can be used as a persistent kernel cache key.
47    pub fn stable_format(&self) -> String {
48        format!("{}-{:?}-{:?}", self.type_name, self.info, self.mode)
49    }
50
51    /// Add information to the [kernel id](KernelId).
52    ///
53    /// The information is used to differentiate kernels of the same kind but with different
54    /// configurations, which affect the generated code.
55    pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
56        mut self,
57        info: I,
58    ) -> Self {
59        self.info = Some(Info::new(info));
60        self
61    }
62
63    /// Set the [execution mode](ExecutionMode).
64    pub fn mode(&mut self, mode: ExecutionMode) {
65        self.mode = Some(mode);
66    }
67}
68
69/// Extra information
70#[derive(Clone)]
71pub(crate) struct Info {
72    value: Arc<dyn DynKey>,
73}
74
75impl core::fmt::Debug for Info {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.write_fmt(format_args!("{:?}", self.value))
78    }
79}
80
81impl Info {
82    fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
83        Self {
84            value: Arc::new(id),
85        }
86    }
87}
88
/// Object-safe key abstraction that lets values of different concrete types be stored in a
/// single data structure behind a `dyn` pointer.
///
/// Trade-off: the hashing algorithm is fixed. The blanket implementation hashes the value with
/// a [`DefaultHasher`] and writes only the resulting `u64` into the [`Hasher`] it receives.
/// The caller's hasher therefore never sees the value's own [`Hash`](core::hash::Hash) input
/// and cannot be used to configure how the value itself is hashed.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// Upcasts to [`Any`] so another key can downcast this one to a concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Returns the [`TypeId`] of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// Compares with another type-erased key.
    ///
    /// The blanket implementation returns `false` when the concrete types differ.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Writes a hash of the value into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
}
100
101impl PartialEq for Info {
102    fn eq(&self, other: &Self) -> bool {
103        self.value.dyn_eq(other.value.as_ref())
104    }
105}
106
107impl Eq for Info {}
108
109impl Hash for Info {
110    fn hash<H: Hasher>(&self, state: &mut H) {
111        self.value.dyn_type_id().hash(state);
112        self.value.dyn_hash(state)
113    }
114}
115
116impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
117    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
118        if let Some(other) = other.as_any().downcast_ref::<T>() {
119            self == other
120        } else {
121            false
122        }
123    }
124
125    fn dyn_type_id(&self) -> TypeId {
126        TypeId::of::<T>()
127    }
128
129    fn dyn_hash(&self, state: &mut dyn Hasher) {
130        let mut default_hasher = DefaultHasher::new();
131        self.hash(&mut default_hasher);
132        state.write_u64(default_hasher.finish());
133    }
134
135    fn as_any(&self) -> &dyn Any {
136        self
137    }
138}
139
140/// The device id.
141#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
142pub struct DeviceId {
143    /// The type id identifies the type of the device.
144    pub type_id: u16,
145    /// The index id identifies the device number.
146    pub index_id: u32,
147}
148
149/// ID used to identify a Just-in-Time environment.
150#[derive(Hash, PartialEq, Eq, Debug, Clone)]
151pub struct CubeTuneId {
152    device: DeviceId,
153    name: &'static str,
154}
155
156impl CubeTuneId {
157    /// Create a new ID.
158    pub fn new<R: crate::Runtime>(
159        client: &ComputeClient<R::Server, R::Channel>,
160        device: &R::Device,
161    ) -> Self {
162        Self {
163            device: R::device_id(device),
164            name: R::name(client),
165        }
166    }
167}
168
169impl core::fmt::Display for CubeTuneId {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        f.write_fmt(format_args!(
172            "device-{}-{}-{}",
173            self.device.type_id, self.device.index_id, self.name
174        ))
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Ids that differ only in their info value must hash and compare differently.
    #[test]
    pub fn kernel_id_hash() {
        let value_1 = KernelId::new::<()>().info("1");
        let value_2 = KernelId::new::<()>().info("2");

        let mut set = HashSet::new();

        set.insert(value_1.clone());

        assert!(set.contains(&value_1));
        assert!(!set.contains(&value_2));
    }

    /// Independently built ids with equal info are equal and found in a hash set.
    #[test]
    fn kernel_id_same_info_is_equal() {
        let a = KernelId::new::<()>().info(42u32);
        let b = KernelId::new::<()>().info(42u32);
        assert_eq!(a, b);

        let set: HashSet<_> = [a].into_iter().collect();
        assert!(set.contains(&b));
    }

    /// The kernel type itself takes part in the identity.
    #[test]
    fn kernel_id_distinguishes_kernel_type() {
        let a = KernelId::new::<u8>().info("1");
        let b = KernelId::new::<u16>().info("1");
        assert_ne!(a, b);
    }

    /// Info values of different concrete types never compare equal, even with equal numbers.
    #[test]
    fn kernel_id_distinguishes_info_type() {
        let a = KernelId::new::<()>().info(1u32);
        let b = KernelId::new::<()>().info(1u64);
        assert_ne!(a, b);
    }

    /// An id without info differs from the same id with info attached.
    #[test]
    fn kernel_id_with_and_without_info_differ() {
        let bare = KernelId::new::<()>();
        let with_info = KernelId::new::<()>().info(());
        assert_ne!(bare, with_info);
    }

    /// The stable format reflects the attached info.
    #[test]
    fn stable_format_reflects_info() {
        let a = KernelId::new::<()>().info("1").stable_format();
        let b = KernelId::new::<()>().info("2").stable_format();
        assert_ne!(a, b);
    }
}