1use cubecl_runtime::ExecutionMode;
2use std::any::{Any, TypeId};
3use std::hash::{DefaultHasher, Hash, Hasher};
4use std::sync::Arc;
5
6#[derive(Hash, PartialEq, Eq, Clone, Debug)]
8pub struct KernelId {
9 pub(crate) type_id: core::any::TypeId,
10 pub(crate) info: Option<Info>,
11 pub(crate) mode: Option<ExecutionMode>,
12}
13
14impl KernelId {
15 pub fn new<T: 'static>() -> Self {
17 Self {
18 type_id: core::any::TypeId::of::<T>(),
19 info: None,
20 mode: None,
21 }
22 }
23
24 pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
29 mut self,
30 info: I,
31 ) -> Self {
32 self.info = Some(Info::new(info));
33 self
34 }
35
36 pub fn mode(&mut self, mode: ExecutionMode) {
38 self.mode = Some(mode);
39 }
40}
41
42#[derive(Clone)]
44pub(crate) struct Info {
45 value: Arc<dyn DynKey>,
46}
47
48impl core::fmt::Debug for Info {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 f.write_fmt(format_args!("{:?}", self.value))
51 }
52}
53
54impl Info {
55 fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
56 Self {
57 value: Arc::new(id),
58 }
59 }
60}
61
/// Object-safe replacement for `PartialEq + Eq + Hash`.
///
/// Values of different concrete types can sit behind one `Arc<dyn DynKey>` and
/// still be compared and hashed. `Debug + Send + Sync` are required so the
/// trait object can be printed and shared across threads.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// Upcasts to `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;
    /// `TypeId` of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// Equality against another, possibly differently-typed, key.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feeds this key into a type-erased hasher.
    fn dyn_hash(&self, state: &mut dyn Hasher);
}
72}
73
74impl PartialEq for Info {
75 fn eq(&self, other: &Self) -> bool {
76 self.value.dyn_eq(other.value.as_ref())
77 }
78}
79
80impl Eq for Info {}
81
82impl Hash for Info {
83 fn hash<H: Hasher>(&self, state: &mut H) {
84 self.value.dyn_type_id().hash(state);
85 self.value.dyn_hash(state)
86 }
87}
88
89impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
90 fn dyn_eq(&self, other: &dyn DynKey) -> bool {
91 if let Some(other) = other.as_any().downcast_ref::<T>() {
92 self == other
93 } else {
94 false
95 }
96 }
97
98 fn dyn_type_id(&self) -> TypeId {
99 TypeId::of::<T>()
100 }
101
102 fn dyn_hash(&self, state: &mut dyn Hasher) {
103 let mut default_hasher = DefaultHasher::new();
104 self.hash(&mut default_hasher);
105 state.write_u64(default_hasher.finish());
106 }
107
108 fn as_any(&self) -> &dyn Any {
109 self
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::hash::{DefaultHasher, Hash, Hasher};

    /// Hashes `value` with a fresh `DefaultHasher` so two values can be
    /// compared by hash.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    pub fn kernel_id_hash() {
        let value_1 = KernelId::new::<()>().info("1");
        let value_2 = KernelId::new::<()>().info("2");

        let mut set = HashSet::new();

        set.insert(value_1.clone());

        assert!(set.contains(&value_1));
        assert!(!set.contains(&value_2));
    }

    /// Equality and hashing must compare info payloads by value, not by `Arc`
    /// identity. `clone()` shares the `Arc`, so the second id is built
    /// independently here.
    #[test]
    fn kernel_id_equal_when_built_independently() {
        let a = KernelId::new::<()>().info("1");
        let b = KernelId::new::<()>().info("1");

        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));

        let mut set = HashSet::new();
        set.insert(a);
        assert!(set.contains(&b));
    }

    /// Payloads of different concrete types are never equal, even when they
    /// look alike.
    #[test]
    fn kernel_id_info_type_matters() {
        let text = KernelId::new::<()>().info("1");
        let number = KernelId::new::<()>().info(1u32);
        assert_ne!(text, number);
    }

    /// The kernel type parameter is part of the identity; without info, ids of
    /// the same kernel type are equal.
    #[test]
    fn kernel_id_kernel_type_matters() {
        assert_eq!(KernelId::new::<u8>(), KernelId::new::<u8>());
        assert_ne!(KernelId::new::<u8>(), KernelId::new::<u16>());
        assert_ne!(
            KernelId::new::<u8>().info(0u8),
            KernelId::new::<u16>().info(0u8)
        );
    }
}
130}