Skip to main content

cubecl_core/frontend/container/registry/
base.rs

1use alloc::{collections::BTreeMap, rc::Rc};
2use core::cell::RefCell;
3
4use cubecl_ir::Scope;
5
6use crate::prelude::{CubeDebug, CubeType, IntoMut};
7
/// A map whose keys are known at comptime while its values may be runtime variables.
///
/// The storage sits behind `Rc<RefCell<..>>`, so the registry is read and updated through
/// shared handles rather than requiring unique ownership.
pub struct Registry<K, V> {
    /// Key/value storage; every handle created by cloning this registry points at it.
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
13
/// Conversion hook used when looking items up in a [registry](Registry).
///
/// Any query passed to [`Registry::find`] must be convertible into the registry's key
/// type `K`.
///
/// # Example
///
/// When [u32] keys may show up as [`crate::frontend::ExpandElementTyped<u32>`] during
/// expansion, both of those types need a [`RegistryQuery`] implementation.
pub trait RegistryQuery<K>: Into<K> {}

// Built-in implementations: plain integer keys act as their own query type.
impl RegistryQuery<usize> for usize {}
impl RegistryQuery<u32> for u32 {}
26
27impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
28    /// Create a new registry.
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    /// Expand function of [`Self::new`].
34    pub fn __expand_new(_: &mut Scope) -> Registry<K, V::ExpandType> {
35        Registry {
36            map: Rc::new(RefCell::new(BTreeMap::new())),
37        }
38    }
39
40    /// Find an item in the registry.
41    ///
42    /// # Notes
43    ///
44    /// If the item isn't present in the registry, the function will panic.
45    pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
46        let key = query.into();
47        let map = self.map.as_ref().borrow();
48
49        match map.get(&key) {
50            Some(val) => val.clone(),
51            None => panic!("No value found for key {key:?}"),
52        }
53    }
54
55    /// Find an item in the registry or return the default value.
56    pub fn find_or_default<Query: RegistryQuery<K>>(&mut self, query: Query) -> V
57    where
58        V: Default,
59        K: Clone,
60    {
61        let key = query.into();
62        let mut map = self.map.as_ref().borrow_mut();
63
64        match map.get(&key) {
65            Some(val) => val.clone(),
66            None => {
67                map.insert(key.clone(), Default::default());
68                map.get(&key).unwrap().clone()
69            }
70        }
71    }
72
73    /// Insert an item in the registry.
74    pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
75        let key = query.into();
76        let mut map = self.map.as_ref().borrow_mut();
77
78        map.insert(key, value);
79    }
80
81    /// Expand function of [`Self::find`].
82    pub fn __expand_find<Query: RegistryQuery<K>>(
83        _scope: &mut Scope,
84        state: Registry<K, V::ExpandType>,
85        key: Query,
86    ) -> V::ExpandType {
87        let key = key.into();
88        let map = state.map.as_ref().borrow();
89
90        map.get(&key).unwrap().clone()
91    }
92
93    /// Expand function of [`Self::find_or_default`].
94    pub fn __expand_find_or_default<Query: RegistryQuery<K>>(
95        _scope: &mut Scope,
96        state: Registry<K, V::ExpandType>,
97        key: Query,
98    ) -> V::ExpandType
99    where
100        V::ExpandType: Default,
101        K: Clone,
102    {
103        let key = key.into();
104        let mut map = state.map.as_ref().borrow_mut();
105
106        match map.get(&key) {
107            Some(val) => val.clone(),
108            None => {
109                map.insert(key.clone(), Default::default());
110                map.get(&key).unwrap().clone()
111            }
112        }
113    }
114
115    /// Expand function of [`Self::insert`].
116    pub fn __expand_insert<Key: Into<K>>(
117        _scope: &mut Scope,
118        state: Registry<K, V::ExpandType>,
119        key: Key,
120        value: V::ExpandType,
121    ) {
122        let key = key.into();
123        let mut map = state.map.as_ref().borrow_mut();
124
125        map.insert(key, value);
126    }
127}
128
129impl<K: PartialOrd + Ord + core::fmt::Debug, V: Clone> Registry<K, V> {
130    /// Expand method of [`Self::find`].
131    pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
132        let map = self.map.as_ref().borrow();
133
134        match map.get(&key) {
135            Some(val) => val.clone(),
136            None => panic!("No value found for key {key:?}"),
137        }
138    }
139
140    /// Expand method of [`Self::insert`].
141    pub fn __expand_insert_method(self, _scope: &mut Scope, key: K, value: V) {
142        let mut map = self.map.as_ref().borrow_mut();
143
144        map.insert(key, value);
145    }
146}
147
148impl<K, V> Default for Registry<K, V> {
149    fn default() -> Self {
150        Self {
151            map: Rc::new(RefCell::new(BTreeMap::default())),
152        }
153    }
154}
155
156impl<K, V> Clone for Registry<K, V> {
157    fn clone(&self) -> Self {
158        Self {
159            map: self.map.clone(),
160        }
161    }
162}
163
164impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
165    type ExpandType = Registry<K, V::ExpandType>;
166}
167
168impl<K: PartialOrd + Ord, V: IntoMut + Clone> IntoMut for Registry<K, V> {
169    fn into_mut(self, scope: &mut crate::ir::Scope) -> Self {
170        let mut map = self.map.borrow_mut();
171        map.iter_mut().for_each(|(_k, v)| {
172            *v = IntoMut::into_mut(v.clone(), scope);
173        });
174        core::mem::drop(map);
175
176        self
177    }
178}
179
180impl<K: PartialOrd + Ord, V> CubeDebug for Registry<K, V> {}