1use alloc::{borrow::Cow, rc::Rc, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use hashbrown::{HashMap, HashSet};
4
5use crate::{
6 BarrierLevel, CubeFnSource, ExpandElement, Matrix, Processor, SourceLoc, StorageType,
7 TargetProperties, TypeHash,
8};
9
10use super::{
11 Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
12};
13
14#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
23#[allow(missing_docs)]
24pub struct Scope {
25 pub depth: u8,
26 pub instructions: Vec<Instruction>,
27 pub locals: Vec<Variable>,
28 matrices: Vec<Variable>,
29 pipelines: Vec<Variable>,
30 barriers: Vec<Variable>,
31 shared_memories: Vec<Variable>,
32 pub const_arrays: Vec<(Variable, Vec<Variable>)>,
33 local_arrays: Vec<Variable>,
34 index_offset_with_output_layout_position: Vec<usize>,
35 pub allocator: Allocator,
36 pub debug: DebugInfo,
37 #[type_hash(skip)]
38 #[cfg_attr(feature = "serde", serde(skip))]
39 pub typemap: Rc<RefCell<HashMap<TypeId, StorageType>>>,
40 pub runtime_properties: Rc<TargetProperties>,
41}
42
43#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
45#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
46pub struct DebugInfo {
47 pub enabled: bool,
48 pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
49 pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
50 pub source_loc: Option<SourceLoc>,
51 pub entry_loc: Option<SourceLoc>,
52}
53
54impl core::hash::Hash for Scope {
55 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
56 self.depth.hash(ra_expand_state);
57 self.instructions.hash(ra_expand_state);
58 self.locals.hash(ra_expand_state);
59 self.matrices.hash(ra_expand_state);
60 self.pipelines.hash(ra_expand_state);
61 self.barriers.hash(ra_expand_state);
62 self.shared_memories.hash(ra_expand_state);
63 self.const_arrays.hash(ra_expand_state);
64 self.local_arrays.hash(ra_expand_state);
65 self.index_offset_with_output_layout_position
66 .hash(ra_expand_state);
67 }
68}
69
70#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
72#[allow(missing_docs)]
73pub enum ReadingStrategy {
74 OutputLayout,
76 Plain,
78}
79
80impl Scope {
81 pub fn root(debug_enabled: bool) -> Self {
86 Self {
87 depth: 0,
88 instructions: Vec::new(),
89 locals: Vec::new(),
90 matrices: Vec::new(),
91 pipelines: Vec::new(),
92 barriers: Vec::new(),
93 local_arrays: Vec::new(),
94 shared_memories: Vec::new(),
95 const_arrays: Vec::new(),
96 index_offset_with_output_layout_position: Vec::new(),
97 allocator: Allocator::default(),
98 debug: DebugInfo {
99 enabled: debug_enabled,
100 sources: Default::default(),
101 variable_names: Default::default(),
102 source_loc: None,
103 entry_loc: None,
104 },
105 typemap: Default::default(),
106 runtime_properties: Rc::new(Default::default()),
107 }
108 }
109
110 pub fn with_allocator(mut self, allocator: Allocator) -> Self {
112 self.allocator = allocator;
113 self
114 }
115
116 pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
118 let matrix = self.allocator.create_matrix(matrix);
119 self.add_matrix(*matrix);
120 matrix
121 }
122
123 pub fn add_matrix(&mut self, variable: Variable) {
124 self.matrices.push(variable);
125 }
126
127 pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
129 let pipeline = self.allocator.create_pipeline(num_stages);
130 self.add_pipeline(*pipeline);
131 pipeline
132 }
133
134 pub fn create_barrier(&mut self, level: BarrierLevel) -> ExpandElement {
136 let barrier = self.allocator.create_barrier(level);
137 self.add_barrier(*barrier);
138 barrier
139 }
140
141 pub fn add_pipeline(&mut self, variable: Variable) {
142 self.pipelines.push(variable);
143 }
144
145 pub fn add_barrier(&mut self, variable: Variable) {
146 self.barriers.push(variable);
147 }
148
149 pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
151 self.allocator.create_local_mut(item.into())
152 }
153
154 pub fn add_local_mut(&mut self, var: Variable) {
156 if !self.locals.contains(&var) {
157 self.locals.push(var);
158 }
159 }
160
161 pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
164 self.allocator.create_local_restricted(item)
165 }
166
167 pub fn create_local(&mut self, item: Type) -> ExpandElement {
169 self.allocator.create_local(item)
170 }
171
172 pub fn last_local_index(&self) -> Option<&Variable> {
174 self.locals.last()
175 }
176
177 pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
179 let mut inst = instruction.into();
180 inst.source_loc = self.debug.source_loc.clone();
181 self.instructions.push(inst)
182 }
183
184 pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
186 let map = self.typemap.borrow();
187 let result = map.get(&TypeId::of::<T>());
188
189 result.cloned()
190 }
191
192 pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
194 let mut map = self.typemap.borrow_mut();
195
196 map.insert(TypeId::of::<T>(), elem);
197 }
198
199 pub fn child(&mut self) -> Self {
201 Self {
202 depth: self.depth + 1,
203 instructions: Vec::new(),
204 locals: Vec::new(),
205 matrices: Vec::new(),
206 pipelines: Vec::new(),
207 barriers: Vec::new(),
208 shared_memories: Vec::new(),
209 const_arrays: Vec::new(),
210 local_arrays: Vec::new(),
211 index_offset_with_output_layout_position: Vec::new(),
212 allocator: self.allocator.clone(),
213 debug: self.debug.clone(),
214 typemap: self.typemap.clone(),
215 runtime_properties: self.runtime_properties.clone(),
216 }
217 }
218
219 pub fn process<'a>(
226 &mut self,
227 processors: impl IntoIterator<Item = &'a dyn Processor>,
228 ) -> ScopeProcessing {
229 let mut variables = core::mem::take(&mut self.locals);
230
231 for var in self.matrices.drain(..) {
232 variables.push(var);
233 }
234
235 let mut instructions = Vec::new();
236
237 for inst in self.instructions.drain(..) {
238 instructions.push(inst);
239 }
240
241 variables.extend(self.allocator.take_variables());
242
243 let mut processing = ScopeProcessing {
244 variables,
245 instructions,
246 }
247 .optimize();
248
249 for p in processors {
250 processing = p.transform(processing, self.allocator.clone());
251 }
252
253 processing.variables.extend(self.allocator.take_variables());
255
256 processing
257 }
258
259 pub fn new_local_index(&self) -> u32 {
260 self.allocator.new_local_index()
261 }
262
263 pub fn create_shared<I: Into<Type>>(
265 &mut self,
266 item: I,
267 shared_memory_size: u32,
268 alignment: Option<u32>,
269 ) -> ExpandElement {
270 let item = item.into();
271 let index = self.new_local_index();
272 let shared_memory = Variable::new(
273 VariableKind::SharedMemory {
274 id: index,
275 length: shared_memory_size,
276 unroll_factor: 1,
277 alignment,
278 },
279 item,
280 );
281 self.shared_memories.push(shared_memory);
282 ExpandElement::Plain(shared_memory)
283 }
284
285 pub fn create_const_array<I: Into<Type>>(
287 &mut self,
288 item: I,
289 data: Vec<Variable>,
290 ) -> ExpandElement {
291 let item = item.into();
292 let index = self.new_local_index();
293 let const_array = Variable::new(
294 VariableKind::ConstantArray {
295 id: index,
296 length: data.len() as u32,
297 unroll_factor: 1,
298 },
299 item,
300 );
301 self.const_arrays.push((const_array, data));
302 ExpandElement::Plain(const_array)
303 }
304
305 pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
307 ExpandElement::Plain(crate::Variable::new(
308 VariableKind::GlobalInputArray(id),
309 item,
310 ))
311 }
312
313 pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
315 let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
316 ExpandElement::Plain(var)
317 }
318
319 pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
321 ExpandElement::Plain(crate::Variable::new(
322 VariableKind::GlobalScalar(id),
323 Type::new(storage),
324 ))
325 }
326
327 pub fn create_local_array<I: Into<Type>>(&mut self, item: I, array_size: u32) -> ExpandElement {
329 let local_array = self.allocator.create_local_array(item.into(), array_size);
330 self.add_local_array(*local_array);
331 local_array
332 }
333
334 pub fn add_local_array(&mut self, var: Variable) {
335 self.local_arrays.push(var);
336 }
337
338 pub fn update_source(&mut self, source: CubeFnSource) {
339 if self.debug.enabled {
340 self.debug.sources.borrow_mut().insert(source.clone());
341 self.debug.source_loc = Some(SourceLoc {
342 line: source.line,
343 column: source.column,
344 source,
345 });
346 if self.debug.entry_loc.is_none() {
347 self.debug.entry_loc = self.debug.source_loc.clone();
348 }
349 }
350 }
351
352 pub fn update_span(&mut self, line: u32, col: u32) {
353 if let Some(loc) = self.debug.source_loc.as_mut() {
354 loc.line = line;
355 loc.column = col;
356 }
357 }
358
359 pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
360 if self.debug.enabled {
361 self.debug
362 .variable_names
363 .borrow_mut()
364 .insert(variable, name.into());
365 }
366 }
367}
368
369impl Display for Scope {
370 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
371 writeln!(f, "{{")?;
372 for instruction in self.instructions.iter() {
373 let instruction_str = instruction.to_string();
374 if !instruction_str.is_empty() {
375 writeln!(
376 f,
377 "{}{}",
378 " ".repeat(self.depth as usize + 1),
379 instruction_str,
380 )?;
381 }
382 }
383 write!(f, "{}}}", " ".repeat(self.depth as usize))?;
384 Ok(())
385 }
386}