use alloc::{collections::BTreeMap, rc::Rc};
use core::cell::RefCell;
use cubecl_ir::Scope;
use crate::prelude::{CubeDebug, CubeType, IntoMut};
/// An ordered key → value map that can be shared between handles.
///
/// Storage lives behind `Rc<RefCell<..>>`: cloning a `Registry` hands out another pointer to
/// the same map, so an insertion made through one handle is observable through every other.
pub struct Registry<K, V> {
    /// Entries sorted by key; reference-counted and interior-mutable.
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
/// Marker for types accepted as lookup keys by `Registry` methods; any implementor must be
/// convertible into the registry's key type `K`.
pub trait RegistryQuery<K>
where
    Self: Into<K>,
{
}

/// A `u32` can look up entries in a registry keyed by `u32`.
impl RegistryQuery<u32> for u32 {}

/// A `usize` can look up entries in a registry keyed by `usize`.
impl RegistryQuery<usize> for usize {}
impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Expand-time counterpart of [`Registry::new`]: creates an empty registry whose values are
    /// `V`'s expand type. The scope is not used.
    pub fn __expand_new(_: &mut Scope) -> Registry<K, V::ExpandType> {
        Registry {
            map: Rc::new(RefCell::new(BTreeMap::new())),
        }
    }

    /// Returns a clone of the value registered under `query`.
    ///
    /// # Panics
    ///
    /// Panics if no value is registered for the key, or if the map is currently mutably
    /// borrowed.
    pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
        let key = query.into();
        let map = self.map.as_ref().borrow();
        match map.get(&key) {
            Some(val) => val.clone(),
            None => panic!("No value found for key {key:?}"),
        }
    }

    /// Returns a clone of the value registered under `query`, first inserting `V::default()`
    /// if the key is absent.
    ///
    /// # Panics
    ///
    /// Panics if the map is currently borrowed elsewhere.
    pub fn find_or_default<Query: RegistryQuery<K>>(&mut self, query: Query) -> V
    where
        V: Default,
    {
        let key = query.into();
        let mut map = self.map.as_ref().borrow_mut();
        // One lookup via the entry API: the default is inserted only when the key is missing.
        map.entry(key).or_default().clone()
    }

    /// Registers `value` under `query`, replacing any previous value for that key.
    ///
    /// # Panics
    ///
    /// Panics if the map is currently borrowed elsewhere.
    pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
        let key = query.into();
        let mut map = self.map.as_ref().borrow_mut();
        map.insert(key, value);
    }

    /// Expand-time counterpart of [`Registry::find`]: returns a clone of the expanded value
    /// registered under `key` in `state`.
    ///
    /// # Panics
    ///
    /// Panics if no value is registered for the key (with the same message as `find`), or if
    /// the map is currently mutably borrowed.
    pub fn __expand_find<Query: RegistryQuery<K>>(
        _scope: &mut Scope,
        state: Registry<K, V::ExpandType>,
        key: Query,
    ) -> V::ExpandType {
        let key = key.into();
        let map = state.map.as_ref().borrow();
        match map.get(&key) {
            Some(val) => val.clone(),
            None => panic!("No value found for key {key:?}"),
        }
    }

    /// Expand-time counterpart of [`Registry::find_or_default`]: returns a clone of the
    /// expanded value under `key`, inserting `V::ExpandType::default()` first if absent.
    ///
    /// # Panics
    ///
    /// Panics if the map is currently borrowed elsewhere.
    pub fn __expand_find_or_default<Query: RegistryQuery<K>>(
        _scope: &mut Scope,
        state: Registry<K, V::ExpandType>,
        key: Query,
    ) -> V::ExpandType
    where
        V::ExpandType: Default,
    {
        let key = key.into();
        let mut map = state.map.as_ref().borrow_mut();
        // One lookup via the entry API: the default is inserted only when the key is missing.
        map.entry(key).or_default().clone()
    }

    /// Expand-time counterpart of [`Registry::insert`]: registers the expanded `value` under
    /// `key` in `state`, replacing any previous value for that key.
    ///
    /// # Panics
    ///
    /// Panics if the map is currently borrowed elsewhere.
    pub fn __expand_insert<Key: Into<K>>(
        _scope: &mut Scope,
        state: Registry<K, V::ExpandType>,
        key: Key,
        value: V::ExpandType,
    ) {
        let key = key.into();
        let mut map = state.map.as_ref().borrow_mut();
        map.insert(key, value);
    }
}
impl<K, V> Registry<K, V>
where
    K: PartialOrd + Ord + core::fmt::Debug,
    V: Clone,
{
    /// Method-call form of lookup used during expansion: clones the value stored under `key`.
    /// The scope is ignored.
    ///
    /// # Panics
    ///
    /// Panics if `key` has no entry, or if the map is currently mutably borrowed.
    pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
        let entries = self.map.as_ref().borrow();
        entries
            .get(&key)
            .cloned()
            .unwrap_or_else(|| panic!("No value found for key {key:?}"))
    }

    /// Method-call form of insertion used during expansion: stores `value` under `key`,
    /// overwriting any existing entry. The scope is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the map is currently borrowed elsewhere.
    pub fn __expand_insert_method(self, _scope: &mut Scope, key: K, value: V) {
        self.map.as_ref().borrow_mut().insert(key, value);
    }
}
impl<K, V> Default for Registry<K, V> {
fn default() -> Self {
Self {
map: Rc::new(RefCell::new(BTreeMap::default())),
}
}
}
impl<K, V> Clone for Registry<K, V> {
fn clone(&self) -> Self {
Self {
map: self.map.clone(),
}
}
}
/// At expansion time a `Registry<K, V>` becomes a registry with the same key type whose
/// values are `V`'s expand type.
impl<K, V> CubeType for Registry<K, V>
where
    K: PartialOrd + Ord,
    V: CubeType,
{
    type ExpandType = Registry<K, V::ExpandType>;
}
impl<K, V> IntoMut for Registry<K, V>
where
    K: PartialOrd + Ord,
    V: IntoMut + Clone,
{
    /// Replaces every stored value with the result of `IntoMut::into_mut` on a clone of it,
    /// visiting entries in key order, then returns this same handle.
    fn into_mut(self, scope: &mut crate::ir::Scope) -> Self {
        // Scope the mutable borrow so the guard is released before `self` is returned.
        {
            let mut entries = self.map.borrow_mut();
            for slot in entries.values_mut() {
                let converted = IntoMut::into_mut(slot.clone(), scope);
                *slot = converted;
            }
        }
        self
    }
}
/// Relies entirely on `CubeDebug`'s provided items; nothing is overridden for registries.
impl<K, V> CubeDebug for Registry<K, V>
where
    K: PartialOrd + Ord,
{
}