// cubecl_core/frontend/context.rs

1use crate::ir::{self, Elem, Instruction, Item, Scope, Variable, VariableKind};
2use crate::{frontend::ExpandElement, ir::Id};
3use alloc::rc::Rc;
4use core::cell::RefCell;
5use cubecl_runtime::debug::DebugLogger;
6use std::any::TypeId;
7use std::collections::HashMap;
8
9pub struct CubeContext {
10    pub root: Rc<RefCell<Scope>>,
11    pub scope: Rc<RefCell<Scope>>,
12    pub debug_enabled: bool,
13    pub typemap: Rc<RefCell<HashMap<TypeId, Elem>>>,
14}
15
16impl Default for CubeContext {
17    fn default() -> Self {
18        Self::root()
19    }
20}
21
22impl CubeContext {
23    /// Create a new cube context, with a root scope
24    /// A root scope is at the root of a compute shader
25    /// Therefore there is one cube context per shader
26    /// The allocator will define the strategy for creating local intermediates and mutable variables
27    pub fn root() -> CubeContext {
28        let root = Rc::new(RefCell::new(Scope::root()));
29        let typemap = Rc::new(RefCell::new(HashMap::new()));
30        let scope = root.clone();
31
32        Self {
33            scope,
34            root,
35            debug_enabled: DebugLogger::default().is_activated(),
36            typemap,
37        }
38    }
39
40    pub fn register<O: Into<Instruction>>(&mut self, op: O) {
41        self.scope.borrow_mut().register(op)
42    }
43
44    /// Resolve the element type of the given generic type.
45    pub fn resolve_elem<T: 'static>(&self) -> Option<Elem> {
46        let map = self.typemap.borrow();
47        let result = map.get(&TypeId::of::<T>());
48
49        result.cloned()
50    }
51
52    /// Register the element type for the given generic type.
53    pub fn register_elem<T: 'static>(&mut self, elem: Elem) {
54        let mut map = self.typemap.borrow_mut();
55
56        map.insert(TypeId::of::<T>(), elem);
57    }
58
59    pub fn child(&mut self) -> CubeContext {
60        let scope = self.scope.borrow_mut().child();
61
62        Self {
63            scope: Rc::new(RefCell::new(scope)),
64            root: self.root.clone(),
65            debug_enabled: self.debug_enabled,
66            typemap: self.typemap.clone(),
67        }
68    }
69
70    pub fn into_scope(self) -> Scope {
71        core::mem::drop(self.root);
72
73        Rc::into_inner(self.scope)
74            .expect("Only one reference")
75            .into_inner()
76    }
77
78    /// Create a new mutable local variable.
79    pub fn create_local_mut(&mut self, item: Item) -> ExpandElement {
80        let local = self.scope.borrow().allocator.create_local_mut(item);
81        self.scope.borrow_mut().add_local_mut(*local);
82        local
83    }
84
85    /// Create a new immutable local variable.
86    pub fn create_local(&mut self, item: Item) -> ExpandElement {
87        self.scope.borrow().allocator.create_local(item)
88    }
89
90    /// Create a new immutable local binding that must never be a reused variable, regardless of
91    /// allocator
92    pub fn create_local_restricted(&mut self, item: Item) -> ExpandElement {
93        self.scope.borrow().allocator.create_local_restricted(item)
94    }
95
96    /// Create a new matrix element.
97    pub fn create_matrix(&mut self, matrix: ir::Matrix) -> ExpandElement {
98        let matrix = self.scope.borrow().allocator.create_matrix(matrix);
99        self.scope.borrow_mut().add_matrix(*matrix);
100        matrix
101    }
102
103    /// Create a new slice element.
104    pub fn create_slice(&mut self, item: Item) -> ExpandElement {
105        let slice = self.scope.borrow().allocator.create_slice(item);
106        self.scope.borrow_mut().add_slice(*slice);
107        slice
108    }
109
110    pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement {
111        ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size))
112    }
113
114    pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
115        let local_array: ExpandElement =
116            self.root.borrow().allocator.create_local_array(item, size);
117        self.root.borrow_mut().add_local_array(*local_array);
118        local_array
119    }
120
121    pub fn create_const_array(&mut self, item: Item, data: Vec<Variable>) -> ExpandElement {
122        ExpandElement::Plain(self.root.borrow_mut().create_const_array(item, data))
123    }
124
125    /// Obtain the index-th input
126    pub fn input(&mut self, id: Id, item: Item) -> ExpandElement {
127        ExpandElement::Plain(crate::ir::Variable::new(
128            VariableKind::GlobalInputArray(id),
129            item,
130        ))
131    }
132
133    /// Obtain the index-th output
134    pub fn output(&mut self, id: Id, item: Item) -> ExpandElement {
135        let var = crate::ir::Variable::new(VariableKind::GlobalOutputArray(id), item);
136        self.scope.borrow_mut().write_global_custom(var);
137        ExpandElement::Plain(var)
138    }
139
140    /// Obtain the index-th scalar
141    pub fn scalar(&self, id: Id, elem: Elem) -> ExpandElement {
142        ExpandElement::Plain(crate::ir::Variable::new(
143            VariableKind::GlobalScalar(id),
144            Item::new(elem),
145        ))
146    }
147}