cubecl_core/frontend/container/registry/base.rs

1use std::{cell::RefCell, collections::BTreeMap, rc::Rc};
2
3use cubecl_ir::Scope;
4
5use crate::prelude::{CubeDebug, CubeType, IntoMut};
6
/// A map-like container whose keys are known at comptime while its values may be runtime
/// variables.
///
/// The entries live behind an [`Rc`], so cloning a registry yields another handle to the same
/// underlying map rather than an independent copy.
pub struct Registry<K, V> {
    /// Shared, interior-mutable storage for the registered entries, ordered by key.
    map: Rc<RefCell<BTreeMap<K, V>>>,
}
12
/// Conversion required to look up an entry with [`Registry::find`].
///
/// A query does not have to be the key type itself; it only needs to be convertible into it.
///
/// # Example
///
/// If [`u32`] is used as the key and may become `ExpandElementTyped<u32>` during expansion,
/// both of those types need to implement [`RegistryQuery`].
pub trait RegistryQuery<K>
where
    Self: Into<K>,
{
}
21
22// We provide default implementations for some types.
23impl RegistryQuery<u32> for u32 {}
24impl RegistryQuery<usize> for usize {}
25
26impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
27    /// Create a new registry.
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    /// Expand function of [Self::new].
33    pub fn __expand_new(_: &mut Scope) -> Registry<K, V::ExpandType> {
34        Registry {
35            map: Rc::new(RefCell::new(BTreeMap::new())),
36        }
37    }
38
39    /// Find an item in the registry.
40    ///
41    /// # Notes
42    ///
43    /// If the item isn't present in the registry, the function will panic.
44    pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
45        let key = query.into();
46        let map = self.map.as_ref().borrow();
47
48        match map.get(&key) {
49            Some(val) => val.clone(),
50            None => panic!("No value found for key {key:?}"),
51        }
52    }
53
54    /// Insert an item in the registry.
55    pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
56        let key = query.into();
57        let mut map = self.map.as_ref().borrow_mut();
58
59        map.insert(key, value);
60    }
61
62    /// Expand function of [Self::find].
63    pub fn __expand_find<Query: RegistryQuery<K>>(
64        _scope: &mut Scope,
65        state: Registry<K, V::ExpandType>,
66        key: Query,
67    ) -> V::ExpandType {
68        let key = key.into();
69        let map = state.map.as_ref().borrow();
70
71        map.get(&key).unwrap().clone()
72    }
73
74    /// Expand function of [Self::insert].
75    pub fn __expand_insert<Key: Into<K>>(
76        _scope: &mut Scope,
77        state: Registry<K, V::ExpandType>,
78        key: Key,
79        value: V::ExpandType,
80    ) {
81        let key = key.into();
82        let mut map = state.map.as_ref().borrow_mut();
83
84        map.insert(key, value);
85    }
86}
87
88impl<K: PartialOrd + Ord + core::fmt::Debug, V: Clone> Registry<K, V> {
89    /// Expand method of [Self::find].
90    pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
91        let map = self.map.as_ref().borrow();
92
93        match map.get(&key) {
94            Some(val) => val.clone(),
95            None => panic!("No value found for key {key:?}"),
96        }
97    }
98
99    /// Expand method of [Self::insert].
100    pub fn __expand_insert_method(self, _scope: &mut Scope, key: K, value: V) {
101        let mut map = self.map.as_ref().borrow_mut();
102
103        map.insert(key, value);
104    }
105}
106
107impl<K, V> Default for Registry<K, V> {
108    fn default() -> Self {
109        Self {
110            map: Rc::new(RefCell::new(BTreeMap::default())),
111        }
112    }
113}
114
115impl<K, V> Clone for Registry<K, V> {
116    fn clone(&self) -> Self {
117        Self {
118            map: self.map.clone(),
119        }
120    }
121}
122
123impl<K: PartialOrd + Ord, V: CubeType> CubeType for Registry<K, V> {
124    type ExpandType = Registry<K, V::ExpandType>;
125}
126
127impl<K: PartialOrd + Ord, V: IntoMut + Clone> IntoMut for Registry<K, V> {
128    fn into_mut(self, scope: &mut crate::ir::Scope) -> Self {
129        let mut map = self.map.borrow_mut();
130        map.iter_mut().for_each(|(_k, v)| {
131            *v = IntoMut::into_mut(v.clone(), scope);
132        });
133        core::mem::drop(map);
134
135        self
136    }
137}
138
139impl<K: PartialOrd + Ord, V> CubeDebug for Registry<K, V> {}