cubecl_core/frontend/container/registry/
base.rs

1use std::{cell::RefCell, collections::BTreeMap, rc::Rc};
2
3use cubecl_ir::Scope;
4
5use crate::prelude::{CubeDebug, CubeType, Init};
6
/// A map-like container whose keys are known at comptime, while its values may be
/// runtime variables.
///
/// The entries live in a [`BTreeMap`] held behind `Rc<RefCell<..>>`.
pub struct Registry<K, V> {
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
12
/// Marker for types that can be used to query a [registry](Registry).
///
/// To [find](Registry::find) an item, the query must be convertible into the actual
/// key type, which is why this trait requires `Into<K>`.
///
/// # Example
///
/// If you use [u32] as the key and it may become [`ExpandElementTyped<u32>`] during the
/// expansion, both types need to implement [RegistryQuery].
pub trait RegistryQuery<K>: Into<K> {}
21
22// We provide default implementations for some types.
23impl RegistryQuery<u32> for u32 {}
24
25impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
26    /// Create a new registry.
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    /// Expand function of [Self::new].
32    pub fn __expand_new(_: &mut Scope) -> Registry<K, V::ExpandType> {
33        Registry {
34            map: Rc::new(RefCell::new(BTreeMap::new())),
35        }
36    }
37
38    /// Find an item in the registry.
39    ///
40    /// # Notes
41    ///
42    /// If the item isn't present in the registry, the function will panic.
43    pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
44        let key = query.into();
45        let map = self.map.as_ref().borrow();
46
47        match map.get(&key) {
48            Some(val) => val.clone(),
49            None => panic!("No value found for key {key:?}"),
50        }
51    }
52
53    /// Insert an item in the registry.
54    pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
55        let key = query.into();
56        let mut map = self.map.as_ref().borrow_mut();
57
58        map.insert(key, value);
59    }
60
61    /// Expand function of [Self::find].
62    pub fn __expand_find<Query: RegistryQuery<K>>(
63        _scope: &mut Scope,
64        state: Registry<K, V::ExpandType>,
65        key: Query,
66    ) -> V::ExpandType {
67        let key = key.into();
68        let map = state.map.as_ref().borrow();
69
70        map.get(&key).unwrap().clone()
71    }
72
73    /// Expand function of [Self::insert].
74    pub fn __expand_insert<Key: Into<K>>(
75        _scope: &mut Scope,
76        state: Registry<K, V::ExpandType>,
77        key: Key,
78        value: V::ExpandType,
79    ) {
80        let key = key.into();
81        let mut map = state.map.as_ref().borrow_mut();
82
83        map.insert(key, value);
84    }
85}
86
87impl<K: PartialOrd + Ord + core::fmt::Debug, V: Clone> Registry<K, V> {
88    /// Expand method of [Self::find].
89    pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
90        let map = self.map.as_ref().borrow();
91
92        match map.get(&key) {
93            Some(val) => val.clone(),
94            None => panic!("No value found for key {key:?}"),
95        }
96    }
97
98    /// Expand method of [Self::insert].
99    pub fn __expand_insert_method(self, _scope: &mut Scope, key: K, value: V) {
100        let mut map = self.map.as_ref().borrow_mut();
101
102        map.insert(key, value);
103    }
104}
105
106impl<K, V> Default for Registry<K, V> {
107    fn default() -> Self {
108        Self {
109            map: Rc::new(RefCell::new(BTreeMap::default())),
110        }
111    }
112}
113
114impl<K, V> Clone for Registry<K, V> {
115    fn clone(&self) -> Self {
116        Self {
117            map: self.map.clone(),
118        }
119    }
120}
121
122impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
123    type ExpandType = Registry<K, V::ExpandType>;
124}
125
126impl<K: PartialOrd + Ord, V> Init for Registry<K, V> {
127    fn init(self, _scope: &mut crate::ir::Scope) -> Self {
128        self
129    }
130}
131
132impl<K: PartialOrd + Ord, V> CubeDebug for Registry<K, V> {}