cubecl_core/frontend/
context.rs1use crate::ir::{self, Elem, Instruction, Item, Scope, Variable, VariableKind};
2use crate::{frontend::ExpandElement, ir::Id};
3use alloc::rc::Rc;
4use core::cell::RefCell;
5use cubecl_runtime::debug::DebugLogger;
6use std::any::TypeId;
7use std::collections::HashMap;
8
/// Frontend context that tracks the [`Scope`] into which instructions and
/// locals are currently being registered, plus state shared across nested
/// contexts.
pub struct CubeContext {
    /// The outermost scope. Contexts produced by `child()` share this handle
    /// with their parent. Some allocations (shared variables, local arrays,
    /// const arrays) are always made on this scope rather than the current one.
    pub root: Rc<RefCell<Scope>>,
    /// The scope that instructions and most locals are added to. For a context
    /// built by `root()`, this is the same `Rc` as `root`.
    pub scope: Rc<RefCell<Scope>>,
    /// Debug flag. `root()` sets it from `DebugLogger::default().is_activated()`
    /// and `child()` copies it into nested contexts.
    pub debug_enabled: bool,
    /// Mapping from a Rust `TypeId` to the IR [`Elem`] registered for it.
    /// `child()` shares this map by `Rc`, so registrations are visible to both
    /// parent and child.
    pub typemap: Rc<RefCell<HashMap<TypeId, Elem>>>,
}
15
16impl Default for CubeContext {
17 fn default() -> Self {
18 Self::root()
19 }
20}
21
22impl CubeContext {
23 pub fn root() -> CubeContext {
28 let root = Rc::new(RefCell::new(Scope::root()));
29 let typemap = Rc::new(RefCell::new(HashMap::new()));
30 let scope = root.clone();
31
32 Self {
33 scope,
34 root,
35 debug_enabled: DebugLogger::default().is_activated(),
36 typemap,
37 }
38 }
39
40 pub fn register<O: Into<Instruction>>(&mut self, op: O) {
41 self.scope.borrow_mut().register(op)
42 }
43
44 pub fn resolve_elem<T: 'static>(&self) -> Option<Elem> {
46 let map = self.typemap.borrow();
47 let result = map.get(&TypeId::of::<T>());
48
49 result.cloned()
50 }
51
52 pub fn register_elem<T: 'static>(&mut self, elem: Elem) {
54 let mut map = self.typemap.borrow_mut();
55
56 map.insert(TypeId::of::<T>(), elem);
57 }
58
59 pub fn child(&mut self) -> CubeContext {
60 let scope = self.scope.borrow_mut().child();
61
62 Self {
63 scope: Rc::new(RefCell::new(scope)),
64 root: self.root.clone(),
65 debug_enabled: self.debug_enabled,
66 typemap: self.typemap.clone(),
67 }
68 }
69
70 pub fn into_scope(self) -> Scope {
71 core::mem::drop(self.root);
72
73 Rc::into_inner(self.scope)
74 .expect("Only one reference")
75 .into_inner()
76 }
77
78 pub fn create_local_mut(&mut self, item: Item) -> ExpandElement {
80 let local = self.scope.borrow().allocator.create_local_mut(item);
81 self.scope.borrow_mut().add_local_mut(*local);
82 local
83 }
84
85 pub fn create_local(&mut self, item: Item) -> ExpandElement {
87 self.scope.borrow().allocator.create_local(item)
88 }
89
90 pub fn create_local_restricted(&mut self, item: Item) -> ExpandElement {
93 self.scope.borrow().allocator.create_local_restricted(item)
94 }
95
96 pub fn create_matrix(&mut self, matrix: ir::Matrix) -> ExpandElement {
98 let matrix = self.scope.borrow().allocator.create_matrix(matrix);
99 self.scope.borrow_mut().add_matrix(*matrix);
100 matrix
101 }
102
103 pub fn create_slice(&mut self, item: Item) -> ExpandElement {
105 let slice = self.scope.borrow().allocator.create_slice(item);
106 self.scope.borrow_mut().add_slice(*slice);
107 slice
108 }
109
110 pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement {
111 ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size))
112 }
113
114 pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
115 let local_array: ExpandElement =
116 self.root.borrow().allocator.create_local_array(item, size);
117 self.root.borrow_mut().add_local_array(*local_array);
118 local_array
119 }
120
121 pub fn create_const_array(&mut self, item: Item, data: Vec<Variable>) -> ExpandElement {
122 ExpandElement::Plain(self.root.borrow_mut().create_const_array(item, data))
123 }
124
125 pub fn input(&mut self, id: Id, item: Item) -> ExpandElement {
127 ExpandElement::Plain(crate::ir::Variable::new(
128 VariableKind::GlobalInputArray(id),
129 item,
130 ))
131 }
132
133 pub fn output(&mut self, id: Id, item: Item) -> ExpandElement {
135 let var = crate::ir::Variable::new(VariableKind::GlobalOutputArray(id), item);
136 self.scope.borrow_mut().write_global_custom(var);
137 ExpandElement::Plain(var)
138 }
139
140 pub fn scalar(&self, id: Id, elem: Elem) -> ExpandElement {
142 ExpandElement::Plain(crate::ir::Variable::new(
143 VariableKind::GlobalScalar(id),
144 Item::new(elem),
145 ))
146 }
147}