cubecl_core/frontend/container/registry/
base.rs

1use std::{cell::RefCell, collections::BTreeMap, rc::Rc};
2
3use crate::prelude::{CubeContext, CubeType, ExpandElementTyped, Init, IntoRuntime};
4
/// A map-like container whose keys are known at comptime while its values may be runtime
/// variables.
///
/// The underlying [BTreeMap] sits behind an [Rc]<[RefCell]>, so it can be mutated through a
/// shared handle.
pub struct Registry<K, V> {
    /// Shared, interior-mutable storage of the registered items.
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
10
/// Marker trait for values that can be used to query a [registry](Registry).
///
/// A query passed to [find](Registry::find) must be convertible into the actual key type `K`,
/// which is why [`Into<K>`] is a supertrait.
///
/// # Example
///
/// With [u32] keys, a query may show up either as a plain [u32] or as an
/// [`ExpandElementTyped<u32>`] during the expansion, so both types need to implement
/// [RegistryQuery].
pub trait RegistryQuery<K>: Into<K> {}
19
20// We provide default implementations for some types.
21impl RegistryQuery<u32> for u32 {}
22impl RegistryQuery<u32> for ExpandElementTyped<u32> {}
23
24impl From<ExpandElementTyped<u32>> for u32 {
25    fn from(val: ExpandElementTyped<u32>) -> Self {
26        val.constant().unwrap().as_u32()
27    }
28}
29
30impl<K: PartialOrd + Ord, V: CubeType + Clone> Registry<K, V> {
31    /// Create a new registry.
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    /// Expand function of [Self::new].
37    pub fn __expand_new(_: &mut CubeContext) -> Registry<K, V::ExpandType> {
38        Registry {
39            map: Rc::new(RefCell::new(BTreeMap::new())),
40        }
41    }
42
43    /// Find an item in the registry.
44    ///
45    /// # Notes
46    ///
47    /// If the item isn't present in the registry, the function will panic.
48    pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
49        let key = query.into();
50        let map = self.map.as_ref().borrow();
51
52        map.get(&key).unwrap().clone()
53    }
54
55    /// Insert an item in the registry.
56    pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
57        let key = query.into();
58        let mut map = self.map.as_ref().borrow_mut();
59
60        map.insert(key, value);
61    }
62
63    /// Expand function of [Self::find].
64    pub fn __expand_find<Query: RegistryQuery<K>>(
65        _context: &mut CubeContext,
66        state: Registry<K, V::ExpandType>,
67        key: Query,
68    ) -> V::ExpandType {
69        let key = key.into();
70        let map = state.map.as_ref().borrow();
71
72        map.get(&key).unwrap().clone()
73    }
74
75    /// Expand function of [Self::insert].
76    pub fn __expand_insert<Key: Into<K>>(
77        _context: &mut CubeContext,
78        state: Registry<K, V::ExpandType>,
79        key: Key,
80        value: V::ExpandType,
81    ) {
82        let key = key.into();
83        let mut map = state.map.as_ref().borrow_mut();
84
85        map.insert(key, value);
86    }
87}
88
89impl<K: PartialOrd + Ord, V: Clone> Registry<K, V> {
90    /// Expand method of [Self::find].
91    pub fn __expand_find_method(&self, _context: &mut CubeContext, key: K) -> V {
92        let map = self.map.as_ref().borrow();
93
94        map.get(&key).unwrap().clone()
95    }
96
97    /// Expand method of [Self::insert].
98    pub fn __expand_insert_method(self, _context: &mut CubeContext, key: K, value: V) {
99        let mut map = self.map.as_ref().borrow_mut();
100
101        map.insert(key, value);
102    }
103}
104
105impl<K, V> Default for Registry<K, V> {
106    fn default() -> Self {
107        Self {
108            map: Rc::new(RefCell::new(BTreeMap::default())),
109        }
110    }
111}
112
113impl<K, V> Clone for Registry<K, V> {
114    fn clone(&self) -> Self {
115        Self {
116            map: self.map.clone(),
117        }
118    }
119}
120
121impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
122    type ExpandType = Registry<K, V::ExpandType>;
123}
124
125impl<K: PartialOrd + Ord, V> Init for Registry<K, V> {
126    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
127        self
128    }
129}
130
131impl<K: PartialOrd + Ord, V: CubeType> IntoRuntime for Registry<K, V> {
132    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Registry<K, V::ExpandType> {
133        unimplemented!("Comptime registry can't be moved to runtime.");
134    }
135}