1use alloc::{borrow::Cow, boxed::Box, rc::Rc, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use hashbrown::{HashMap, HashSet};
4
5use crate::{BarrierLevel, CubeFnSource, ExpandElement, Matrix, Processor, SourceLoc, TypeHash};
6
7use super::{
8 Allocator, Elem, Id, Instruction, Item, Variable, VariableKind, processing::ScopeProcessing,
9};
10
11#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
19#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
20#[allow(missing_docs)]
21pub struct Scope {
22 pub depth: u8,
23 pub instructions: Vec<Instruction>,
24 pub locals: Vec<Variable>,
25 matrices: Vec<Variable>,
26 pipelines: Vec<Variable>,
27 barriers: Vec<Variable>,
28 shared_memories: Vec<Variable>,
29 pub const_arrays: Vec<(Variable, Vec<Variable>)>,
30 local_arrays: Vec<Variable>,
31 index_offset_with_output_layout_position: Vec<usize>,
32 pub allocator: Allocator,
33 pub debug: DebugInfo,
34 #[type_hash(skip)]
35 #[cfg_attr(feature = "serde", serde(skip))]
36 pub typemap: Rc<RefCell<HashMap<TypeId, Elem>>>,
37}
38
39#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
41#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
42pub struct DebugInfo {
43 pub enabled: bool,
44 pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
45 pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
46 pub source_loc: Option<SourceLoc>,
47 pub entry_loc: Option<SourceLoc>,
48}
49
50impl core::hash::Hash for Scope {
51 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
52 self.depth.hash(ra_expand_state);
53 self.instructions.hash(ra_expand_state);
54 self.locals.hash(ra_expand_state);
55 self.matrices.hash(ra_expand_state);
56 self.pipelines.hash(ra_expand_state);
57 self.barriers.hash(ra_expand_state);
58 self.shared_memories.hash(ra_expand_state);
59 self.const_arrays.hash(ra_expand_state);
60 self.local_arrays.hash(ra_expand_state);
61 self.index_offset_with_output_layout_position
62 .hash(ra_expand_state);
63 }
64}
65
66#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
68#[allow(missing_docs)]
69pub enum ReadingStrategy {
70 OutputLayout,
72 Plain,
74}
75
76impl Scope {
77 pub fn root(debug_enabled: bool) -> Self {
82 Self {
83 depth: 0,
84 instructions: Vec::new(),
85 locals: Vec::new(),
86 matrices: Vec::new(),
87 pipelines: Vec::new(),
88 barriers: Vec::new(),
89 local_arrays: Vec::new(),
90 shared_memories: Vec::new(),
91 const_arrays: Vec::new(),
92 index_offset_with_output_layout_position: Vec::new(),
93 allocator: Allocator::default(),
94 debug: DebugInfo {
95 enabled: debug_enabled,
96 sources: Default::default(),
97 variable_names: Default::default(),
98 source_loc: None,
99 entry_loc: None,
100 },
101 typemap: Default::default(),
102 }
103 }
104
105 pub fn with_allocator(mut self, allocator: Allocator) -> Self {
107 self.allocator = allocator;
108 self
109 }
110
111 pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
113 let matrix = self.allocator.create_matrix(matrix);
114 self.add_matrix(*matrix);
115 matrix
116 }
117
118 pub fn add_matrix(&mut self, variable: Variable) {
119 self.matrices.push(variable);
120 }
121
122 pub fn create_pipeline(&mut self, item: Item, num_stages: u8) -> ExpandElement {
124 let pipeline = self.allocator.create_pipeline(item, num_stages);
125 self.add_pipeline(*pipeline);
126 pipeline
127 }
128
129 pub fn create_barrier(&mut self, item: Item, level: BarrierLevel) -> ExpandElement {
131 let barrier = self.allocator.create_barrier(item, level);
132 self.add_barrier(*barrier);
133 barrier
134 }
135
136 pub fn add_pipeline(&mut self, variable: Variable) {
137 self.pipelines.push(variable);
138 }
139
140 pub fn add_barrier(&mut self, variable: Variable) {
141 self.barriers.push(variable);
142 }
143
144 pub fn create_local_mut<I: Into<Item>>(&mut self, item: I) -> ExpandElement {
146 self.allocator.create_local_mut(item.into())
147 }
148
149 pub fn add_local_mut(&mut self, var: Variable) {
151 if !self.locals.contains(&var) {
152 self.locals.push(var);
153 }
154 }
155
156 pub fn create_local_restricted(&mut self, item: Item) -> ExpandElement {
159 self.allocator.create_local_restricted(item)
160 }
161
162 pub fn create_local(&mut self, item: Item) -> ExpandElement {
164 self.allocator.create_local(item)
165 }
166
167 pub fn last_local_index(&self) -> Option<&Variable> {
169 self.locals.last()
170 }
171
172 pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
174 let mut inst = instruction.into();
175 inst.source_loc = self.debug.source_loc.clone();
176 self.instructions.push(inst)
177 }
178
179 pub fn resolve_elem<T: 'static>(&self) -> Option<Elem> {
181 let map = self.typemap.borrow();
182 let result = map.get(&TypeId::of::<T>());
183
184 result.cloned()
185 }
186
187 pub fn register_elem<T: 'static>(&mut self, elem: Elem) {
189 let mut map = self.typemap.borrow_mut();
190
191 map.insert(TypeId::of::<T>(), elem);
192 }
193
194 pub fn child(&mut self) -> Self {
196 Self {
197 depth: self.depth + 1,
198 instructions: Vec::new(),
199 locals: Vec::new(),
200 matrices: Vec::new(),
201 pipelines: Vec::new(),
202 barriers: Vec::new(),
203 shared_memories: Vec::new(),
204 const_arrays: Vec::new(),
205 local_arrays: Vec::new(),
206 index_offset_with_output_layout_position: Vec::new(),
207 allocator: self.allocator.clone(),
208 debug: self.debug.clone(),
209 typemap: self.typemap.clone(),
210 }
211 }
212
213 pub fn process(
220 &mut self,
221 processors: impl IntoIterator<Item = Box<dyn Processor>>,
222 ) -> ScopeProcessing {
223 let mut variables = core::mem::take(&mut self.locals);
224
225 for var in self.matrices.drain(..) {
226 variables.push(var);
227 }
228
229 let mut instructions = Vec::new();
230
231 for inst in self.instructions.drain(..) {
232 instructions.push(inst);
233 }
234
235 variables.extend(self.allocator.take_variables());
236
237 let mut processing = ScopeProcessing {
238 variables,
239 instructions,
240 }
241 .optimize();
242
243 for p in processors {
244 processing = p.transform(processing, self.allocator.clone());
245 }
246
247 processing
248 }
249
250 pub fn new_local_index(&self) -> u32 {
251 self.allocator.new_local_index()
252 }
253
254 pub fn create_shared<I: Into<Item>>(
256 &mut self,
257 item: I,
258 shared_memory_size: u32,
259 alignment: Option<u32>,
260 ) -> ExpandElement {
261 let item = item.into();
262 let index = self.new_local_index();
263 let shared_memory = Variable::new(
264 VariableKind::SharedMemory {
265 id: index,
266 length: shared_memory_size,
267 alignment,
268 },
269 item,
270 );
271 self.shared_memories.push(shared_memory);
272 ExpandElement::Plain(shared_memory)
273 }
274
275 pub fn create_const_array<I: Into<Item>>(
277 &mut self,
278 item: I,
279 data: Vec<Variable>,
280 ) -> ExpandElement {
281 let item = item.into();
282 let index = self.new_local_index();
283 let const_array = Variable::new(
284 VariableKind::ConstantArray {
285 id: index,
286 length: data.len() as u32,
287 },
288 item,
289 );
290 self.const_arrays.push((const_array, data));
291 ExpandElement::Plain(const_array)
292 }
293
294 pub fn input(&mut self, id: Id, item: Item) -> ExpandElement {
296 ExpandElement::Plain(crate::Variable::new(
297 VariableKind::GlobalInputArray(id),
298 item,
299 ))
300 }
301
302 pub fn output(&mut self, id: Id, item: Item) -> ExpandElement {
304 let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
305 ExpandElement::Plain(var)
306 }
307
308 pub fn scalar(&self, id: Id, elem: Elem) -> ExpandElement {
310 ExpandElement::Plain(crate::Variable::new(
311 VariableKind::GlobalScalar(id),
312 Item::new(elem),
313 ))
314 }
315
316 pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> ExpandElement {
318 let local_array = self.allocator.create_local_array(item.into(), array_size);
319 self.add_local_array(*local_array);
320 local_array
321 }
322
323 pub fn add_local_array(&mut self, var: Variable) {
324 self.local_arrays.push(var);
325 }
326
327 pub fn update_source(&mut self, source: CubeFnSource) {
328 if self.debug.enabled {
329 self.debug.sources.borrow_mut().insert(source.clone());
330 self.debug.source_loc = Some(SourceLoc {
331 line: source.line,
332 column: source.column,
333 source,
334 });
335 if self.debug.entry_loc.is_none() {
336 self.debug.entry_loc = self.debug.source_loc.clone();
337 }
338 }
339 }
340
341 pub fn update_span(&mut self, line: u32, col: u32) {
342 if let Some(loc) = self.debug.source_loc.as_mut() {
343 loc.line = line;
344 loc.column = col;
345 }
346 }
347
348 pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
349 if self.debug.enabled {
350 self.debug
351 .variable_names
352 .borrow_mut()
353 .insert(variable, name.into());
354 }
355 }
356}
357
358impl Display for Scope {
359 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
360 writeln!(f, "{{")?;
361 for instruction in self.instructions.iter() {
362 let instruction_str = instruction.to_string();
363 if !instruction_str.is_empty() {
364 writeln!(
365 f,
366 "{}{}",
367 " ".repeat(self.depth as usize + 1),
368 instruction_str,
369 )?;
370 }
371 }
372 write!(f, "{}}}", " ".repeat(self.depth as usize))?;
373 Ok(())
374 }
375}