use cubecl_runtime::ExecutionMode;
use std::any::{Any, TypeId};
use std::fmt::Display;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
pub struct KernelId {
type_id: core::any::TypeId,
info: Option<Info>,
mode: Option<ExecutionMode>,
}
impl Display for KernelId {
    /// Prints the attached info, or the literal `No info` when none was set.
    ///
    /// Only the info part is rendered; the type id and execution mode are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(info) = &self.info {
            write!(f, "{info}")
        } else {
            f.write_str("No info")
        }
    }
}
impl KernelId {
pub fn new<T: 'static>() -> Self {
Self {
type_id: core::any::TypeId::of::<T>(),
info: None,
mode: None,
}
}
pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
mut self,
info: I,
) -> Self {
self.info = Some(Info::new(info));
self
}
pub fn mode(&mut self, mode: ExecutionMode) {
self.mode = Some(mode);
}
}
/// Type-erased, comparable and hashable extra key attached to a `KernelId`.
///
/// Clones share the same `Arc`-allocated value.
#[derive(Debug, Clone)]
struct Info {
    value: Arc<dyn DynKey>,
}

impl core::fmt::Display for Info {
    /// Renders the wrapped value through its `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.value)
    }
}

impl Info {
    /// Moves `id` behind an `Arc<dyn DynKey>` so keys of any qualifying type share one
    /// representation.
    fn new<T>(id: T) -> Self
    where
        T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync,
    {
        let value: Arc<dyn DynKey> = Arc::new(id);
        Self { value }
    }
}

/// Dyn-compatible counterpart of `PartialEq + Hash`, letting keys of different
/// concrete types be stored behind the same trait object.
trait DynKey: core::fmt::Debug + Send + Sync {
    /// `TypeId` of the concrete type behind the trait object.
    fn dyn_type_id(&self) -> TypeId;
    /// `true` only when `other` has the same concrete type and compares equal to `self`.
    fn dyn_eq(&self, other: &dyn DynKey) -> bool;
    /// Feeds a digest of the concrete value into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
    /// Exposes the value as `&dyn Any` so `dyn_eq` can downcast the other operand.
    fn as_any(&self) -> &dyn Any;
}

impl PartialEq for Info {
    fn eq(&self, other: &Self) -> bool {
        let rhs: &dyn DynKey = &*other.value;
        self.value.dyn_eq(rhs)
    }
}

impl Eq for Info {}

impl Hash for Info {
    /// Mixes in the concrete type id before the value digest, matching `PartialEq`,
    /// which also treats values of different types as unequal.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let key: &dyn DynKey = &*self.value;
        key.dyn_type_id().hash(state);
        key.dyn_hash(state);
    }
}

impl<K> DynKey for K
where
    K: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync,
{
    fn dyn_type_id(&self) -> TypeId {
        TypeId::of::<K>()
    }

    fn dyn_eq(&self, other: &dyn DynKey) -> bool {
        // A failed downcast means a different concrete type, which is never equal.
        match other.as_any().downcast_ref::<K>() {
            Some(rhs) => self == rhs,
            None => false,
        }
    }

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        // Digest the value with a fresh `DefaultHasher`, then forward only the 64-bit
        // result to the caller's hasher.
        // NOTE(review): std implements `Hasher` for `&mut H` where `H: ?Sized`, so
        // `self.hash(&mut state)` could likely feed the caller's hasher directly and skip
        // this second pass; that changes hash values, so confirm nothing depends on them.
        let digest = {
            let mut scratch = DefaultHasher::new();
            self.hash(&mut scratch);
            scratch.finish()
        };
        state.write_u64(digest);
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Ids that differ only in their info must not be found as each other.
    #[test]
    pub fn kernel_id_hash() {
        let value_1 = KernelId::new::<()>().info("1");
        let value_2 = KernelId::new::<()>().info("2");
        let mut set = HashSet::new();
        set.insert(value_1.clone());
        assert!(set.contains(&value_1));
        assert!(!set.contains(&value_2));
    }

    /// Independently constructed ids with the same type and info must be
    /// interchangeable keys, not just clones of one another.
    #[test]
    fn kernel_id_equal_parts_match() {
        let set = HashSet::from([KernelId::new::<()>().info("1")]);
        assert!(set.contains(&KernelId::new::<()>().info("1")));
        assert_eq!(KernelId::new::<u8>(), KernelId::new::<u8>());
    }

    /// The type parameter, the presence of info, and the concrete info type all
    /// take part in equality.
    #[test]
    fn kernel_id_distinguishes_each_part() {
        assert_ne!(
            KernelId::new::<()>().info("1"),
            KernelId::new::<u8>().info("1")
        );
        assert_ne!(KernelId::new::<()>(), KernelId::new::<()>().info("1"));
        // Same printed value, different concrete info type.
        assert_ne!(
            KernelId::new::<()>().info(1u32),
            KernelId::new::<()>().info(1u64)
        );
    }

    /// `Display` shows only the info (via its `Debug` form) or `No info`.
    #[test]
    fn kernel_id_display() {
        assert_eq!(KernelId::new::<()>().to_string(), "No info");
        assert_eq!(KernelId::new::<()>().info("1").to_string(), "\"1\"");
    }
}