cubecl_core/frontend/container/registry/
base.rs1use alloc::{collections::BTreeMap, rc::Rc};
2use core::cell::RefCell;
3
4use cubecl_ir::Scope;
5
6use crate::prelude::{CubeDebug, CubeType, IntoMut};
7
/// An ordered map from keys `K` to values `V` whose storage is shared between handles.
///
/// The map lives behind an `Rc<RefCell<..>>`. Every handle that shares the same `Rc`
/// therefore observes inserts made through any other handle. Entries are kept in
/// ascending key order, as in a [`BTreeMap`].
pub struct Registry<K, V> {
    // Shared, interior-mutable storage. Cloning the `Rc` shares this map; it is not a deep copy.
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
13
/// Bound for anything that can address a [`Registry`] keyed by `K`.
///
/// A query only has to convert into the key type (via the [`Into`] supertrait).
/// The registry methods call `.into()` on the query before looking it up.
pub trait RegistryQuery<K>: Into<K> {}

// Plain integer keys act as their own queries; the conversion is the reflexive `Into<T> for T`.
impl RegistryQuery<u32> for u32 {}
impl RegistryQuery<usize> for usize {}
26
27impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
28 pub fn new() -> Self {
30 Self::default()
31 }
32
33 pub fn __expand_new(_: &mut Scope) -> Registry<K, V::ExpandType> {
35 Registry {
36 map: Rc::new(RefCell::new(BTreeMap::new())),
37 }
38 }
39
40 pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
46 let key = query.into();
47 let map = self.map.as_ref().borrow();
48
49 match map.get(&key) {
50 Some(val) => val.clone(),
51 None => panic!("No value found for key {key:?}"),
52 }
53 }
54
55 pub fn find_or_default<Query: RegistryQuery<K>>(&mut self, query: Query) -> V
57 where
58 V: Default,
59 K: Clone,
60 {
61 let key = query.into();
62 let mut map = self.map.as_ref().borrow_mut();
63
64 match map.get(&key) {
65 Some(val) => val.clone(),
66 None => {
67 map.insert(key.clone(), Default::default());
68 map.get(&key).unwrap().clone()
69 }
70 }
71 }
72
73 pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
75 let key = query.into();
76 let mut map = self.map.as_ref().borrow_mut();
77
78 map.insert(key, value);
79 }
80
81 pub fn __expand_find<Query: RegistryQuery<K>>(
83 _scope: &mut Scope,
84 state: Registry<K, V::ExpandType>,
85 key: Query,
86 ) -> V::ExpandType {
87 let key = key.into();
88 let map = state.map.as_ref().borrow();
89
90 map.get(&key).unwrap().clone()
91 }
92
93 pub fn __expand_find_or_default<Query: RegistryQuery<K>>(
95 _scope: &mut Scope,
96 state: Registry<K, V::ExpandType>,
97 key: Query,
98 ) -> V::ExpandType
99 where
100 V::ExpandType: Default,
101 K: Clone,
102 {
103 let key = key.into();
104 let mut map = state.map.as_ref().borrow_mut();
105
106 match map.get(&key) {
107 Some(val) => val.clone(),
108 None => {
109 map.insert(key.clone(), Default::default());
110 map.get(&key).unwrap().clone()
111 }
112 }
113 }
114
115 pub fn __expand_insert<Key: Into<K>>(
117 _scope: &mut Scope,
118 state: Registry<K, V::ExpandType>,
119 key: Key,
120 value: V::ExpandType,
121 ) {
122 let key = key.into();
123 let mut map = state.map.as_ref().borrow_mut();
124
125 map.insert(key, value);
126 }
127}
128
129impl<K: PartialOrd + Ord + core::fmt::Debug, V: Clone> Registry<K, V> {
130 pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
132 let map = self.map.as_ref().borrow();
133
134 match map.get(&key) {
135 Some(val) => val.clone(),
136 None => panic!("No value found for key {key:?}"),
137 }
138 }
139
140 pub fn __expand_insert_method(self, _scope: &mut Scope, key: K, value: V) {
142 let mut map = self.map.as_ref().borrow_mut();
143
144 map.insert(key, value);
145 }
146}
147
148impl<K, V> Default for Registry<K, V> {
149 fn default() -> Self {
150 Self {
151 map: Rc::new(RefCell::new(BTreeMap::default())),
152 }
153 }
154}
155
156impl<K, V> Clone for Registry<K, V> {
157 fn clone(&self) -> Self {
158 Self {
159 map: self.map.clone(),
160 }
161 }
162}
163
164impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
165 type ExpandType = Registry<K, V::ExpandType>;
166}
167
168impl<K: PartialOrd + Ord, V: IntoMut + Clone> IntoMut for Registry<K, V> {
169 fn into_mut(self, scope: &mut crate::ir::Scope) -> Self {
170 let mut map = self.map.borrow_mut();
171 map.iter_mut().for_each(|(_k, v)| {
172 *v = IntoMut::into_mut(v.clone(), scope);
173 });
174 core::mem::drop(map);
175
176 self
177 }
178}
179
180impl<K: PartialOrd + Ord, V> CubeDebug for Registry<K, V> {}