cubecl_core/frontend/container/registry/
base.rs1use std::{cell::RefCell, collections::BTreeMap, rc::Rc};
2
3use cubecl_ir::Scope;
4
5use crate::prelude::{CubeDebug, CubeType, Init};
6
/// Key-value store backed by a `BTreeMap` that sits behind `Rc<RefCell<..>>`.
///
/// Cloning a registry clones the `Rc`, so every clone reads and writes the same map.
pub struct Registry<K, V> {
    /// Shared, interior-mutable storage for the entries.
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
12
/// Marker trait for types accepted as a query on a registry keyed by `K`.
///
/// It declares no methods of its own; the `Into<K>` supertrait is what turns a query
/// into the key that is actually looked up.
pub trait RegistryQuery<K>: Into<K> {}
21
/// A `u32` can be used directly as a query on a registry keyed by `u32`.
impl RegistryQuery<u32> for u32 {}
24
25impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
26 pub fn new() -> Self {
28 Self::default()
29 }
30
31 pub fn __expand_new(_: &mut Scope) -> Registry<K, V::ExpandType> {
33 Registry {
34 map: Rc::new(RefCell::new(BTreeMap::new())),
35 }
36 }
37
38 pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
44 let key = query.into();
45 let map = self.map.as_ref().borrow();
46
47 match map.get(&key) {
48 Some(val) => val.clone(),
49 None => panic!("No value found for key {key:?}"),
50 }
51 }
52
53 pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
55 let key = query.into();
56 let mut map = self.map.as_ref().borrow_mut();
57
58 map.insert(key, value);
59 }
60
61 pub fn __expand_find<Query: RegistryQuery<K>>(
63 _scope: &mut Scope,
64 state: Registry<K, V::ExpandType>,
65 key: Query,
66 ) -> V::ExpandType {
67 let key = key.into();
68 let map = state.map.as_ref().borrow();
69
70 map.get(&key).unwrap().clone()
71 }
72
73 pub fn __expand_insert<Key: Into<K>>(
75 _scope: &mut Scope,
76 state: Registry<K, V::ExpandType>,
77 key: Key,
78 value: V::ExpandType,
79 ) {
80 let key = key.into();
81 let mut map = state.map.as_ref().borrow_mut();
82
83 map.insert(key, value);
84 }
85}
86
87impl<K: PartialOrd + Ord + core::fmt::Debug, V: Clone> Registry<K, V> {
88 pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
90 let map = self.map.as_ref().borrow();
91
92 match map.get(&key) {
93 Some(val) => val.clone(),
94 None => panic!("No value found for key {key:?}"),
95 }
96 }
97
98 pub fn __expand_insert_method(self, _scope: &mut Scope, key: K, value: V) {
100 let mut map = self.map.as_ref().borrow_mut();
101
102 map.insert(key, value);
103 }
104}
105
106impl<K, V> Default for Registry<K, V> {
107 fn default() -> Self {
108 Self {
109 map: Rc::new(RefCell::new(BTreeMap::default())),
110 }
111 }
112}
113
114impl<K, V> Clone for Registry<K, V> {
115 fn clone(&self) -> Self {
116 Self {
117 map: self.map.clone(),
118 }
119 }
120}
121
/// A registry of `V` expands to a registry with the same key type whose values are
/// `V::ExpandType`.
impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
    type ExpandType = Registry<K, V::ExpandType>;
}
125
/// Initialising a registry does nothing: the registry is handed back as-is.
impl<K: PartialOrd + Ord, V> Init for Registry<K, V> {
    /// Returns `self` unchanged; the scope is not used.
    fn init(self, _scope: &mut crate::ir::Scope) -> Self {
        self
    }
}
131
/// Empty impl: nothing is overridden here, so the registry relies on whatever
/// defaults `CubeDebug` provides.
impl<K: PartialOrd + Ord, V> CubeDebug for Registry<K, V> {}