cubecl_core/frontend/container/registry/
base.rs1use std::{cell::RefCell, collections::BTreeMap, rc::Rc};
2
3use cubecl_ir::Scope;
4
5use crate::prelude::{CubeDebug, CubeType, IntoMut};
6
/// A key/value registry whose entries live in a shared, reference-counted map.
///
/// Storage is `Rc<RefCell<BTreeMap<K, V>>>`: copying the handle copies only the
/// `Rc`, so every handle that shares the `Rc` reads and writes the same entries.
/// The `Rc`/`RefCell` pair also makes the type single-threaded (neither `Send`
/// nor `Sync`).
pub struct Registry<K, V> {
    /// Shared, interior-mutable ordered map holding the registered values.
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
12
/// Key types accepted by `Registry` lookups and insertions.
///
/// Every implementor must convert into the registry key type `K`. In this file
/// only `u32` and `usize` implement it, each for a registry keyed by itself.
// NOTE(review): the explicit, narrow impl list presumably exists so that an
// untyped integer literal used as a key infers to exactly one type (a plain
// `Into<K>` bound would leave it ambiguous) — confirm before adding blanket impls.
pub trait RegistryQuery<K>: Into<K> {}

/// A `u32` key queries a `u32`-keyed registry directly.
impl RegistryQuery<u32> for u32 {}

/// A `usize` key queries a `usize`-keyed registry directly.
impl RegistryQuery<usize> for usize {}
25
26impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn __expand_new(_: &mut Scope) -> Registry<K, V::ExpandType> {
34 Registry {
35 map: Rc::new(RefCell::new(BTreeMap::new())),
36 }
37 }
38
39 pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
45 let key = query.into();
46 let map = self.map.as_ref().borrow();
47
48 match map.get(&key) {
49 Some(val) => val.clone(),
50 None => panic!("No value found for key {key:?}"),
51 }
52 }
53
54 pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
56 let key = query.into();
57 let mut map = self.map.as_ref().borrow_mut();
58
59 map.insert(key, value);
60 }
61
62 pub fn __expand_find<Query: RegistryQuery<K>>(
64 _scope: &mut Scope,
65 state: Registry<K, V::ExpandType>,
66 key: Query,
67 ) -> V::ExpandType {
68 let key = key.into();
69 let map = state.map.as_ref().borrow();
70
71 map.get(&key).unwrap().clone()
72 }
73
74 pub fn __expand_insert<Key: Into<K>>(
76 _scope: &mut Scope,
77 state: Registry<K, V::ExpandType>,
78 key: Key,
79 value: V::ExpandType,
80 ) {
81 let key = key.into();
82 let mut map = state.map.as_ref().borrow_mut();
83
84 map.insert(key, value);
85 }
86}
87
88impl<K: PartialOrd + Ord + core::fmt::Debug, V: Clone> Registry<K, V> {
89 pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
91 let map = self.map.as_ref().borrow();
92
93 match map.get(&key) {
94 Some(val) => val.clone(),
95 None => panic!("No value found for key {key:?}"),
96 }
97 }
98
99 pub fn __expand_insert_method(self, _scope: &mut Scope, key: K, value: V) {
101 let mut map = self.map.as_ref().borrow_mut();
102
103 map.insert(key, value);
104 }
105}
106
107impl<K, V> Default for Registry<K, V> {
108 fn default() -> Self {
109 Self {
110 map: Rc::new(RefCell::new(BTreeMap::default())),
111 }
112 }
113}
114
115impl<K, V> Clone for Registry<K, V> {
116 fn clone(&self) -> Self {
117 Self {
118 map: self.map.clone(),
119 }
120 }
121}
122
123impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
124 type ExpandType = Registry<K, V::ExpandType>;
125}
126
127impl<K: PartialOrd + Ord, V: IntoMut + Clone> IntoMut for Registry<K, V> {
128 fn into_mut(self, scope: &mut crate::ir::Scope) -> Self {
129 let mut map = self.map.borrow_mut();
130 map.iter_mut().for_each(|(_k, v)| {
131 *v = IntoMut::into_mut(v.clone(), scope);
132 });
133 core::mem::drop(map);
134
135 self
136 }
137}
138
139impl<K: PartialOrd + Ord, V> CubeDebug for Registry<K, V> {}