1use std::any::{Any, TypeId};
2use std::hash::{DefaultHasher, Hash, Hasher};
3use std::sync::Arc;
4
5use cubecl_common::ExecutionMode;
6use cubecl_runtime::client::ComputeClient;
7
8#[derive(Clone, Debug)]
10pub struct KernelId {
11 pub(crate) type_id: core::any::TypeId,
12 pub(crate) info: Option<Info>,
13 pub(crate) mode: Option<ExecutionMode>,
14 type_name: &'static str,
15}
16
17impl Hash for KernelId {
18 fn hash<H: Hasher>(&self, state: &mut H) {
19 self.type_id.hash(state);
20 self.info.hash(state);
21 self.mode.hash(state);
22 }
23}
24
25impl PartialEq for KernelId {
26 fn eq(&self, other: &Self) -> bool {
27 self.type_id == other.type_id && self.mode == other.mode && self.info == other.info
28 }
29}
30
31impl Eq for KernelId {}
32
33impl KernelId {
34 pub fn new<T: 'static>() -> Self {
36 Self {
37 type_id: core::any::TypeId::of::<T>(),
38 type_name: core::any::type_name::<T>(),
39 info: None,
40 mode: None,
41 }
42 }
43
44 pub fn stable_format(&self) -> String {
48 format!("{}-{:?}-{:?}", self.type_name, self.info, self.mode)
49 }
50
51 pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
56 mut self,
57 info: I,
58 ) -> Self {
59 self.info = Some(Info::new(info));
60 self
61 }
62
63 pub fn mode(&mut self, mode: ExecutionMode) {
65 self.mode = Some(mode);
66 }
67}
68
/// Type-erased, hashable and comparable extra information attached to a [`KernelId`].
///
/// Wraps any `'static + Eq + Hash + Debug + Send + Sync` value behind an
/// `Arc<dyn DynKey>`, so cloning an `Info` only bumps a reference count.
#[derive(Clone)]
pub(crate) struct Info {
    value: Arc<dyn DynKey>,
}

impl core::fmt::Debug for Info {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Forward the caller's formatter rather than re-formatting through
        // `format_args!("{:?}", ..)`: that started a fresh formatter and dropped flags
        // such as `{:#?}`. Plain `{:?}` output is unchanged.
        core::fmt::Debug::fmt(&*self.value, f)
    }
}

impl Info {
    /// Wraps `id` as type-erased info.
    fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
        Self {
            value: Arc::new(id),
        }
    }
}

/// Object-safe view of a key value, so heterogeneous values can be compared and hashed
/// behind `dyn`. Every eligible type implements it through the blanket impl below.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// `TypeId` of the concrete value.
    fn dyn_type_id(&self) -> TypeId;
    /// `true` only when `other` has the same concrete type and compares equal.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feeds a hash of the concrete value into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// Upcast to `Any`, used by `dyn_eq` to downcast the other operand.
    fn as_any(&self) -> &dyn Any;
}

impl PartialEq for Info {
    fn eq(&self, other: &Self) -> bool {
        self.value.dyn_eq(other.value.as_ref())
    }
}

impl Eq for Info {}

impl Hash for Info {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Mix in the concrete type first: values of different types never compare
        // equal, so this keeps such values apart in the hash as well.
        self.value.dyn_type_id().hash(state);
        self.value.dyn_hash(state)
    }
}

impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
        // A failed downcast means the concrete types differ, hence "not equal".
        if let Some(other) = other.as_any().downcast_ref::<T>() {
            self == other
        } else {
            false
        }
    }

    fn dyn_type_id(&self) -> TypeId {
        TypeId::of::<T>()
    }

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        // Hash the concrete value with a fresh `DefaultHasher`, then feed the resulting
        // `u64` into the caller's (type-erased) hasher.
        let mut default_hasher = DefaultHasher::new();
        self.hash(&mut default_hasher);
        state.write_u64(default_hasher.finish());
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
139
140#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
142pub struct DeviceId {
143 pub type_id: u16,
145 pub index_id: u32,
147}
148
149#[derive(Hash, PartialEq, Eq, Debug, Clone)]
151pub struct CubeTuneId {
152 device: DeviceId,
153 name: &'static str,
154}
155
156impl CubeTuneId {
157 pub fn new<R: crate::Runtime>(
159 client: &ComputeClient<R::Server, R::Channel>,
160 device: &R::Device,
161 ) -> Self {
162 Self {
163 device: R::device_id(device),
164 name: R::name(client),
165 }
166 }
167}
168
169impl core::fmt::Display for CubeTuneId {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.write_fmt(format_args!(
172 "device-{}-{}-{}",
173 self.device.type_id, self.device.index_id, self.name
174 ))
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Ids that differ only in their info value must be distinct set members.
    #[test]
    pub fn kernel_id_hash() {
        let value_1 = KernelId::new::<()>().info("1");
        let value_2 = KernelId::new::<()>().info("2");

        let mut set = HashSet::new();

        set.insert(value_1.clone());

        assert!(set.contains(&value_1));
        assert!(!set.contains(&value_2));
    }

    /// Info values with identical `Debug` output but different types must not compare
    /// equal, because equality goes through a type-checked downcast.
    #[test]
    pub fn kernel_id_info_type_matters() {
        let as_u32 = KernelId::new::<()>().info(1u32);
        let as_u64 = KernelId::new::<()>().info(1u64);

        assert_ne!(as_u32, as_u64);

        let mut set = HashSet::new();
        set.insert(as_u32.clone());
        assert!(set.contains(&as_u32));
        assert!(!set.contains(&as_u64));
    }

    /// The kernel type itself is part of the identity.
    #[test]
    pub fn kernel_id_distinguishes_kernel_types() {
        let unit = KernelId::new::<()>().info("1");
        let byte = KernelId::new::<u8>().info("1");

        assert_ne!(unit, byte);
        assert_eq!(unit, KernelId::new::<()>().info("1"));
    }

    /// `stable_format` ends with the Debug renderings of info and mode. Only the suffix
    /// is checked because `type_name` output is not guaranteed to be stable.
    #[test]
    pub fn kernel_id_stable_format_shape() {
        let id = KernelId::new::<()>().info("1");

        assert!(id.stable_format().ends_with("-Some(\"1\")-None"));
        assert!(KernelId::new::<()>().stable_format().ends_with("-None-None"));
    }
}