// cubecl_core/frontend/container/registry/base.rs
use std::{cell::RefCell, collections::BTreeMap, rc::Rc};
use crate::prelude::{CubeContext, CubeType, ExpandElementTyped, Init, IntoRuntime};
/// A comptime key/value container.
///
/// The backing map sits behind `Rc<RefCell<..>>`. Every clone of a registry therefore
/// refers to the same storage, and mutations go through `&self` borrows of the cell.
pub struct Registry<K, V> {
    /// Ordered key/value storage, shared by all clones of this registry.
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
/// Marker for types that may be used to look up or insert into a `Registry` keyed by `K`.
///
/// Implementors must be convertible into `K`. The registry calls `into()` on the query
/// before touching the map.
pub trait RegistryQuery<K>: Into<K> {}

// `u32`-keyed registries accept either a plain `u32` or an expanded `u32`. The expanded
// form converts through the `From<ExpandElementTyped<u32>> for u32` impl in this file.
impl RegistryQuery<u32> for u32 {}
impl RegistryQuery<u32> for ExpandElementTyped<u32> {}
/// Recovers the comptime value held by an expanded `u32`, so that it can be used as a
/// `Registry` key.
impl From<ExpandElementTyped<u32>> for u32 {
    /// Converts `val` by reading its constant and narrowing it with `as_u32()`.
    ///
    /// # Panics
    ///
    /// Panics when `val.constant()` yields no value. Presumably this happens when `val` is a
    /// runtime value rather than a comptime constant (TODO confirm against `constant()`).
    fn from(val: ExpandElementTyped<u32>) -> Self {
        val.constant()
            .expect("Can't convert ExpandElementTyped<u32> into u32: the value is not a comptime constant")
            .as_u32()
    }
}
impl<K: PartialOrd + Ord, V: CubeType + Clone> Registry<K, V> {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Expand-time counterpart of `new`: creates an empty registry whose values are
    /// `V::ExpandType`. The context argument is unused.
    pub fn __expand_new(_: &mut CubeContext) -> Registry<K, V::ExpandType> {
        Registry {
            map: Rc::new(RefCell::new(BTreeMap::new())),
        }
    }

    /// Returns a clone of the value stored under the key obtained from `query`.
    ///
    /// # Panics
    ///
    /// Panics if no value is stored under that key, if converting `query` into a key panics,
    /// or if the underlying `RefCell` is currently mutably borrowed.
    pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
        let key = query.into();
        let map = self.map.as_ref().borrow();
        map.get(&key)
            .expect("Registry lookup failed: no value registered for the requested key")
            .clone()
    }

    /// Stores `value` under the key obtained from `query`. Any value previously stored
    /// under that key is replaced.
    ///
    /// The map is shared through `Rc`, so the insertion is visible through every clone of
    /// this registry.
    ///
    /// # Panics
    ///
    /// Panics if converting `query` into a key panics, or if the underlying `RefCell` is
    /// currently borrowed.
    pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
        let key = query.into();
        let mut map = self.map.as_ref().borrow_mut();
        map.insert(key, value);
    }

    /// Expand-time counterpart of `find`: returns a clone of the expanded value stored in
    /// `state` under the key obtained from `key`. The context argument is unused.
    ///
    /// # Panics
    ///
    /// Panics if no value is stored under that key, if the key conversion panics, or if the
    /// underlying `RefCell` is currently mutably borrowed.
    pub fn __expand_find<Query: RegistryQuery<K>>(
        _context: &mut CubeContext,
        state: Registry<K, V::ExpandType>,
        key: Query,
    ) -> V::ExpandType {
        let key = key.into();
        let map = state.map.as_ref().borrow();
        map.get(&key)
            .expect("Registry lookup failed: no value registered for the requested key")
            .clone()
    }

    /// Expand-time counterpart of `insert`: stores `value` in `state` under the key obtained
    /// from `key`, replacing any previous value. `state` is taken by value, but its map is
    /// shared with every other clone, so the insertion remains visible through them.
    ///
    /// # Panics
    ///
    /// Panics if the key conversion panics, or if the underlying `RefCell` is currently
    /// borrowed.
    pub fn __expand_insert<Key: Into<K>>(
        _context: &mut CubeContext,
        state: Registry<K, V::ExpandType>,
        key: Key,
        value: V::ExpandType,
    ) {
        let key = key.into();
        let mut map = state.map.as_ref().borrow_mut();
        map.insert(key, value);
    }
}
impl<K: PartialOrd + Ord, V: Clone> Registry<K, V> {
    /// Method-call form of the expand-time lookup: returns a clone of the value stored under
    /// `key`. The context argument is unused.
    ///
    /// # Panics
    ///
    /// Panics if no value is stored under `key`, or if the underlying `RefCell` is currently
    /// mutably borrowed.
    pub fn __expand_find_method(&self, _context: &mut CubeContext, key: K) -> V {
        let map = self.map.as_ref().borrow();
        map.get(&key)
            .expect("Registry lookup failed: no value registered for the requested key")
            .clone()
    }

    /// Method-call form of the expand-time insertion: stores `value` under `key`, replacing
    /// any previous value. `self` is consumed, but the map is shared through `Rc`, so the
    /// insertion remains visible through every other clone of this registry.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `RefCell` is currently borrowed.
    pub fn __expand_insert_method(self, _context: &mut CubeContext, key: K, value: V) {
        let mut map = self.map.as_ref().borrow_mut();
        map.insert(key, value);
    }
}
impl<K, V> Default for Registry<K, V> {
fn default() -> Self {
Self {
map: Rc::new(RefCell::new(BTreeMap::default())),
}
}
}
// Written by hand rather than derived: `#[derive(Clone)]` would require `K: Clone` and
// `V: Clone`, but only the `Rc` handle is duplicated here.
impl<K, V> Clone for Registry<K, V> {
    /// Returns a new handle to the *same* map. Inserts made through either handle are seen
    /// by both.
    fn clone(&self) -> Self {
        let shared = Rc::clone(&self.map);
        Self { map: shared }
    }
}
/// Expanding a registry expands only its values; keys keep their comptime type `K`.
impl<K, V> CubeType for Registry<K, V>
where
    K: PartialOrd + Ord,
    V: CubeType,
{
    type ExpandType = Registry<K, V::ExpandType>;
}
impl<K: PartialOrd + Ord, V> Init for Registry<K, V> {
fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
self
}
}
impl<K: PartialOrd + Ord, V: CubeType> IntoRuntime for Registry<K, V> {
    /// Unsupported: a registry is a comptime-only container and cannot be moved to runtime.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Registry<K, V::ExpandType> {
        // This is a deliberate, permanent limitation rather than missing work, so use
        // `panic!` instead of `unimplemented!` (which signals a TODO/stub).
        panic!("Comptime registry can't be moved to runtime.");
    }
}