cubecl_core/frontend/container/registry/
base.rs1use std::{cell::RefCell, collections::BTreeMap, rc::Rc};
2
3use crate::prelude::{CubeContext, CubeType, ExpandElementTyped, Init, IntoRuntime};
4
/// An ordered key/value store that can be used from comptime code.
///
/// The map lives behind an `Rc<RefCell<..>>`, so cloning a `Registry` produces another
/// handle onto the *same* map rather than a copy of its contents.
pub struct Registry<K, V> {
    /// Shared, interior-mutable storage, ordered by `K` (a `BTreeMap`).
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
10
11pub trait RegistryQuery<K>: Into<K> {}
19
20impl RegistryQuery<u32> for u32 {}
22impl RegistryQuery<u32> for ExpandElementTyped<u32> {}
23
24impl From<ExpandElementTyped<u32>> for u32 {
25 fn from(val: ExpandElementTyped<u32>) -> Self {
26 val.constant().unwrap().as_u32()
27 }
28}
29
30impl<K: PartialOrd + Ord, V: CubeType + Clone> Registry<K, V> {
31 pub fn new() -> Self {
33 Self::default()
34 }
35
36 pub fn __expand_new(_: &mut CubeContext) -> Registry<K, V::ExpandType> {
38 Registry {
39 map: Rc::new(RefCell::new(BTreeMap::new())),
40 }
41 }
42
43 pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
49 let key = query.into();
50 let map = self.map.as_ref().borrow();
51
52 map.get(&key).unwrap().clone()
53 }
54
55 pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
57 let key = query.into();
58 let mut map = self.map.as_ref().borrow_mut();
59
60 map.insert(key, value);
61 }
62
63 pub fn __expand_find<Query: RegistryQuery<K>>(
65 _context: &mut CubeContext,
66 state: Registry<K, V::ExpandType>,
67 key: Query,
68 ) -> V::ExpandType {
69 let key = key.into();
70 let map = state.map.as_ref().borrow();
71
72 map.get(&key).unwrap().clone()
73 }
74
75 pub fn __expand_insert<Key: Into<K>>(
77 _context: &mut CubeContext,
78 state: Registry<K, V::ExpandType>,
79 key: Key,
80 value: V::ExpandType,
81 ) {
82 let key = key.into();
83 let mut map = state.map.as_ref().borrow_mut();
84
85 map.insert(key, value);
86 }
87}
88
89impl<K: PartialOrd + Ord, V: Clone> Registry<K, V> {
90 pub fn __expand_find_method(&self, _context: &mut CubeContext, key: K) -> V {
92 let map = self.map.as_ref().borrow();
93
94 map.get(&key).unwrap().clone()
95 }
96
97 pub fn __expand_insert_method(self, _context: &mut CubeContext, key: K, value: V) {
99 let mut map = self.map.as_ref().borrow_mut();
100
101 map.insert(key, value);
102 }
103}
104
105impl<K, V> Default for Registry<K, V> {
106 fn default() -> Self {
107 Self {
108 map: Rc::new(RefCell::new(BTreeMap::default())),
109 }
110 }
111}
112
113impl<K, V> Clone for Registry<K, V> {
114 fn clone(&self) -> Self {
115 Self {
116 map: self.map.clone(),
117 }
118 }
119}
120
121impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
122 type ExpandType = Registry<K, V::ExpandType>;
123}
124
125impl<K: PartialOrd + Ord, V> Init for Registry<K, V> {
126 fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
127 self
128 }
129}
130
131impl<K: PartialOrd + Ord, V: CubeType> IntoRuntime for Registry<K, V> {
132 fn __expand_runtime_method(self, _context: &mut CubeContext) -> Registry<K, V::ExpandType> {
133 unimplemented!("Comptime registry can't be moved to runtime.");
134 }
135}