cubecl_core/frontend/container/registry/
base.rs1use std::{cell::RefCell, collections::BTreeMap, rc::Rc};
2
3use cubecl_ir::Scope;
4
5use crate::prelude::{CubeDebug, CubeType, IntoMut};
6
/// An ordered key → value table whose storage is shared between handles.
///
/// The map sits behind `Rc<RefCell<...>>`: cloning a `Registry` produces another handle to the
/// same map, and every read or write goes through a `RefCell` borrow checked at runtime.
pub struct Registry<K, V> {
    /// Ordered storage shared by all clones of this registry.
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
12
/// Types that may be used to look up or insert into a [`Registry`] keyed by `K`.
///
/// Every query must be convertible into `K`; registry methods call `.into()` on the query before
/// touching the map.
pub trait RegistryQuery<K>: Into<K> {}

// Primitive index types can query a registry keyed by that same type.
impl RegistryQuery<u32> for u32 {}
impl RegistryQuery<usize> for usize {}
25
26impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn __expand_new(_: &mut Scope) -> Registry<K, V::ExpandType> {
34 Registry {
35 map: Rc::new(RefCell::new(BTreeMap::new())),
36 }
37 }
38
39 pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
45 let key = query.into();
46 let map = self.map.as_ref().borrow();
47
48 match map.get(&key) {
49 Some(val) => val.clone(),
50 None => panic!("No value found for key {key:?}"),
51 }
52 }
53
54 pub fn find_or_default<Query: RegistryQuery<K>>(&mut self, query: Query) -> V
56 where
57 V: Default,
58 K: Clone,
59 {
60 let key = query.into();
61 let mut map = self.map.as_ref().borrow_mut();
62
63 match map.get(&key) {
64 Some(val) => val.clone(),
65 None => {
66 map.insert(key.clone(), Default::default());
67 map.get(&key).unwrap().clone()
68 }
69 }
70 }
71
72 pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
74 let key = query.into();
75 let mut map = self.map.as_ref().borrow_mut();
76
77 map.insert(key, value);
78 }
79
80 pub fn __expand_find<Query: RegistryQuery<K>>(
82 _scope: &mut Scope,
83 state: Registry<K, V::ExpandType>,
84 key: Query,
85 ) -> V::ExpandType {
86 let key = key.into();
87 let map = state.map.as_ref().borrow();
88
89 map.get(&key).unwrap().clone()
90 }
91
92 pub fn __expand_find_or_default<Query: RegistryQuery<K>>(
94 _scope: &mut Scope,
95 state: Registry<K, V::ExpandType>,
96 key: Query,
97 ) -> V::ExpandType
98 where
99 V::ExpandType: Default,
100 K: Clone,
101 {
102 let key = key.into();
103 let mut map = state.map.as_ref().borrow_mut();
104
105 match map.get(&key) {
106 Some(val) => val.clone(),
107 None => {
108 map.insert(key.clone(), Default::default());
109 map.get(&key).unwrap().clone()
110 }
111 }
112 }
113
114 pub fn __expand_insert<Key: Into<K>>(
116 _scope: &mut Scope,
117 state: Registry<K, V::ExpandType>,
118 key: Key,
119 value: V::ExpandType,
120 ) {
121 let key = key.into();
122 let mut map = state.map.as_ref().borrow_mut();
123
124 map.insert(key, value);
125 }
126}
127
128impl<K: PartialOrd + Ord + core::fmt::Debug, V: Clone> Registry<K, V> {
129 pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
131 let map = self.map.as_ref().borrow();
132
133 match map.get(&key) {
134 Some(val) => val.clone(),
135 None => panic!("No value found for key {key:?}"),
136 }
137 }
138
139 pub fn __expand_insert_method(self, _scope: &mut Scope, key: K, value: V) {
141 let mut map = self.map.as_ref().borrow_mut();
142
143 map.insert(key, value);
144 }
145}
146
147impl<K, V> Default for Registry<K, V> {
148 fn default() -> Self {
149 Self {
150 map: Rc::new(RefCell::new(BTreeMap::default())),
151 }
152 }
153}
154
155impl<K, V> Clone for Registry<K, V> {
156 fn clone(&self) -> Self {
157 Self {
158 map: self.map.clone(),
159 }
160 }
161}
162
163impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
164 type ExpandType = Registry<K, V::ExpandType>;
165}
166
167impl<K: PartialOrd + Ord, V: IntoMut + Clone> IntoMut for Registry<K, V> {
168 fn into_mut(self, scope: &mut crate::ir::Scope) -> Self {
169 let mut map = self.map.borrow_mut();
170 map.iter_mut().for_each(|(_k, v)| {
171 *v = IntoMut::into_mut(v.clone(), scope);
172 });
173 core::mem::drop(map);
174
175 self
176 }
177}
178
179impl<K: PartialOrd + Ord, V> CubeDebug for Registry<K, V> {}