// cubecl_core/frontend/container/registry/base.rs

use std::{cell::RefCell, collections::BTreeMap, rc::Rc};

use crate::prelude::{CubeContext, CubeType, ExpandElementTyped, Init, IntoRuntime};

/// A map-like container whose keys are stored at comptime, while its values can be runtime
/// variables.
///
/// The entries live behind an `Rc<RefCell<..>>`: cloning a registry produces another handle to
/// the same underlying map, so items inserted through one clone are visible through all of them.
pub struct Registry<K, V> {
    map: Rc<RefCell<BTreeMap<K, V>>>,
}

/// Marker for types that can be used to query a [registry](Registry) with keys of type `K`.
///
/// To [find](Registry::find) an item in the [registry](Registry), the query must be
/// convertible into the actual key type; the `Into<K>` supertrait performs that conversion.
///
/// # Example
///
/// If you use [u32] as a key, and that key may become an [ExpandElementTyped<u32>] during the
/// expansion, both types need to implement [RegistryQuery].
pub trait RegistryQuery<K>: Into<K> {}

// Default query implementations for `u32` keys: both the plain comptime `u32` and its expanded
// form `ExpandElementTyped<u32>` can query a `Registry<u32, _>`. The latter's `Into<u32>` comes
// from the `From<ExpandElementTyped<u32>> for u32` impl below.
impl RegistryQuery<u32> for u32 {}
impl RegistryQuery<u32> for ExpandElementTyped<u32> {}

impl From<ExpandElementTyped<u32>> for u32 {
    /// Extracts the constant value held by an expanded `u32`, which is what lets an
    /// `ExpandElementTyped<u32>` be used as a registry query for `u32` keys.
    ///
    /// # Panics
    ///
    /// Panics if `val` does not hold a constant value (`val.constant()` yields nothing).
    fn from(val: ExpandElementTyped<u32>) -> Self {
        val.constant()
            .expect("a registry key of type ExpandElementTyped<u32> must be a comptime constant")
            .as_u32()
    }
}

impl<K: Ord, V: CubeType + Clone> Registry<K, V> {
    /// Create a new, empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Expand function of [Self::new].
    ///
    /// Returns a new, empty registry whose values are of the expanded type `V::ExpandType`.
    pub fn __expand_new(_: &mut CubeContext) -> Registry<K, V::ExpandType> {
        Registry {
            map: Rc::new(RefCell::new(BTreeMap::new())),
        }
    }

    /// Find an item in the registry and return a clone of it.
    ///
    /// The `query` is converted into the key type `K` before the lookup.
    ///
    /// # Panics
    ///
    /// Panics if no item has been inserted for the resulting key.
    pub fn find<Query: RegistryQuery<K>>(&self, query: Query) -> V {
        let key = query.into();
        let map = self.map.as_ref().borrow();

        map.get(&key)
            .expect("Registry::find: no item registered for the given key")
            .clone()
    }

    /// Insert an item in the registry, replacing any item previously stored under the same key.
    ///
    /// The underlying map is shared between clones of this registry, so the new item is visible
    /// through every clone.
    pub fn insert<Query: RegistryQuery<K>>(&mut self, query: Query, value: V) {
        let key = query.into();
        let mut map = self.map.as_ref().borrow_mut();

        map.insert(key, value);
    }

    /// Expand function of [Self::find].
    ///
    /// # Panics
    ///
    /// Panics if no item has been inserted for the key obtained from `key`.
    pub fn __expand_find<Query: RegistryQuery<K>>(
        _context: &mut CubeContext,
        state: Registry<K, V::ExpandType>,
        key: Query,
    ) -> V::ExpandType {
        let key = key.into();
        let map = state.map.as_ref().borrow();

        map.get(&key)
            .expect("Registry::__expand_find: no item registered for the given key")
            .clone()
    }

    /// Expand function of [Self::insert].
    ///
    /// Replaces any item previously stored under the same key. `state` is only a handle: the
    /// insertion is visible through every other clone of the registry.
    pub fn __expand_insert<Key: Into<K>>(
        _context: &mut CubeContext,
        state: Registry<K, V::ExpandType>,
        key: Key,
        value: V::ExpandType,
    ) {
        let key = key.into();
        let mut map = state.map.as_ref().borrow_mut();

        map.insert(key, value);
    }
}

impl<K: Ord, V: Clone> Registry<K, V> {
    /// Expand method of [Self::find]: returns a clone of the item stored under `key`.
    ///
    /// # Panics
    ///
    /// Panics if no item has been inserted for `key`.
    pub fn __expand_find_method(&self, _context: &mut CubeContext, key: K) -> V {
        let map = self.map.as_ref().borrow();

        map.get(&key)
            .expect("Registry::__expand_find_method: no item registered for the given key")
            .clone()
    }

    /// Expand method of [Self::insert]: stores `value` under `key`, replacing any previous item.
    ///
    /// Although `self` is taken by value, only this handle is consumed; the underlying map is
    /// shared, so the insertion is visible through every other clone of the registry.
    pub fn __expand_insert_method(self, _context: &mut CubeContext, key: K, value: V) {
        let mut map = self.map.as_ref().borrow_mut();

        map.insert(key, value);
    }
}

impl<K, V> Default for Registry<K, V> {
    fn default() -> Self {
        Self {
            map: Rc::new(RefCell::new(BTreeMap::default())),
        }
    }
}

impl<K, V> Clone for Registry<K, V> {
    /// Produces another handle to the *same* map: only the reference count is bumped, the
    /// entries themselves are not copied.
    fn clone(&self) -> Self {
        let shared = Rc::clone(&self.map);
        Self { map: shared }
    }
}

/// During expansion, a registry keeps its comptime keys and holds expanded values.
impl<K, V> CubeType for Registry<K, V>
where
    K: Ord,
    V: CubeType,
{
    type ExpandType = Registry<K, V::ExpandType>;
}

impl<K: PartialOrd + Ord, V> Init for Registry<K, V> {
    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
        self
    }
}

impl<K, V> IntoRuntime for Registry<K, V>
where
    K: Ord,
    V: CubeType,
{
    /// A registry only exists at comptime, so converting it to a runtime value is unsupported.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Registry<K, V::ExpandType> {
        unimplemented!("Comptime registry can't be moved to runtime.")
    }
}