Skip to main content

cubecl_core/frontend/container/registry/
base.rs

1use std::{cell::RefCell, collections::BTreeMap, rc::Rc};
2
3use cubecl_ir::Scope;
4
5use crate::prelude::{CubeDebug, CubeType, IntoMut};
6
/// A map-like container whose keys are stored at comptime while its values can be runtime
/// variables.
///
/// Entries live behind an `Rc<RefCell<..>>`, so cloning a [`Registry`] produces another
/// handle to the same entries rather than an independent copy.
pub struct Registry<K, V> {
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
12
/// Conversion bound for the keys passed to [`Registry::find`], [`Registry::insert`] and the
/// related lookup methods.
///
/// A query only has to be convertible into the registry's real key type `K`. This is what
/// allows a [registry](Registry) to be queried with something other than `K` itself.
///
/// # Example
///
/// When [u32] is used as a key and may turn into
/// [`crate::frontend::ExpandElementTyped<u32>`] during the expansion, both of those types
/// must implement [`RegistryQuery`].
pub trait RegistryQuery<K>
where
    Self: Into<K>,
{
}
21
22// We provide default implementations for some types.
23impl RegistryQuery<u32> for u32 {}
24impl RegistryQuery<usize> for usize {}
25
26impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
27    /// Create a new registry.
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    /// Expand function of [`Self::new`].
33    pub fn __expand_new(_: &mut Scope) -> Registry<K, V::ExpandType> {
34        Registry {
35            map: Rc::new(RefCell::new(BTreeMap::new())),
36        }
37    }
38
39    /// Find an item in the registry.
40    ///
41    /// # Notes
42    ///
43    /// If the item isn't present in the registry, the function will panic.
44    pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
45        let key = query.into();
46        let map = self.map.as_ref().borrow();
47
48        match map.get(&key) {
49            Some(val) => val.clone(),
50            None => panic!("No value found for key {key:?}"),
51        }
52    }
53
54    /// Find an item in the registry or return the default value.
55    pub fn find_or_default<Query: RegistryQuery<K>>(&mut self, query: Query) -> V
56    where
57        V: Default,
58        K: Clone,
59    {
60        let key = query.into();
61        let mut map = self.map.as_ref().borrow_mut();
62
63        match map.get(&key) {
64            Some(val) => val.clone(),
65            None => {
66                map.insert(key.clone(), Default::default());
67                map.get(&key).unwrap().clone()
68            }
69        }
70    }
71
72    /// Insert an item in the registry.
73    pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
74        let key = query.into();
75        let mut map = self.map.as_ref().borrow_mut();
76
77        map.insert(key, value);
78    }
79
80    /// Expand function of [`Self::find`].
81    pub fn __expand_find<Query: RegistryQuery<K>>(
82        _scope: &mut Scope,
83        state: Registry<K, V::ExpandType>,
84        key: Query,
85    ) -> V::ExpandType {
86        let key = key.into();
87        let map = state.map.as_ref().borrow();
88
89        map.get(&key).unwrap().clone()
90    }
91
92    /// Expand function of [`Self::find_or_default`].
93    pub fn __expand_find_or_default<Query: RegistryQuery<K>>(
94        _scope: &mut Scope,
95        state: Registry<K, V::ExpandType>,
96        key: Query,
97    ) -> V::ExpandType
98    where
99        V::ExpandType: Default,
100        K: Clone,
101    {
102        let key = key.into();
103        let mut map = state.map.as_ref().borrow_mut();
104
105        match map.get(&key) {
106            Some(val) => val.clone(),
107            None => {
108                map.insert(key.clone(), Default::default());
109                map.get(&key).unwrap().clone()
110            }
111        }
112    }
113
114    /// Expand function of [`Self::insert`].
115    pub fn __expand_insert<Key: Into<K>>(
116        _scope: &mut Scope,
117        state: Registry<K, V::ExpandType>,
118        key: Key,
119        value: V::ExpandType,
120    ) {
121        let key = key.into();
122        let mut map = state.map.as_ref().borrow_mut();
123
124        map.insert(key, value);
125    }
126}
127
128impl<K: PartialOrd + Ord + core::fmt::Debug, V: Clone> Registry<K, V> {
129    /// Expand method of [`Self::find`].
130    pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
131        let map = self.map.as_ref().borrow();
132
133        match map.get(&key) {
134            Some(val) => val.clone(),
135            None => panic!("No value found for key {key:?}"),
136        }
137    }
138
139    /// Expand method of [`Self::insert`].
140    pub fn __expand_insert_method(self, _scope: &mut Scope, key: K, value: V) {
141        let mut map = self.map.as_ref().borrow_mut();
142
143        map.insert(key, value);
144    }
145}
146
147impl<K, V> Default for Registry<K, V> {
148    fn default() -> Self {
149        Self {
150            map: Rc::new(RefCell::new(BTreeMap::default())),
151        }
152    }
153}
154
155impl<K, V> Clone for Registry<K, V> {
156    fn clone(&self) -> Self {
157        Self {
158            map: self.map.clone(),
159        }
160    }
161}
162
163impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
164    type ExpandType = Registry<K, V::ExpandType>;
165}
166
167impl<K: PartialOrd + Ord, V: IntoMut + Clone> IntoMut for Registry<K, V> {
168    fn into_mut(self, scope: &mut crate::ir::Scope) -> Self {
169        let mut map = self.map.borrow_mut();
170        map.iter_mut().for_each(|(_k, v)| {
171            *v = IntoMut::into_mut(v.clone(), scope);
172        });
173        core::mem::drop(map);
174
175        self
176    }
177}
178
179impl<K: PartialOrd + Ord, V> CubeDebug for Registry<K, V> {}