1use alloc::{borrow::Cow, rc::Rc, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7 BarrierLevel, CubeFnSource, ExpandElement, FastMath, Matrix, Processor, SemanticType,
8 SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12 Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
/// A block of IR instructions together with the variables and resources it declares.
///
/// Scopes form a tree: [`Scope::root`] creates a depth-0 scope and [`Scope::child`] creates a
/// scope one `depth` deeper. A child shares the type map, target properties and instruction
/// modes with its parent through their `Rc` handles.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
#[allow(missing_docs)]
pub struct Scope {
    /// Nesting level: 0 for the root, incremented by one for each child. `Display` uses it to
    /// indent the printed instructions.
    pub depth: u8,
    /// Instructions appended through `Scope::register`, in registration order.
    pub instructions: Vec<Instruction>,
    /// Mutable locals declared via `Scope::add_local_mut` (no duplicates).
    pub locals: Vec<Variable>,
    // Matrix variables; `Scope::process` merges them into the processed variable list.
    matrices: Vec<Variable>,
    // Pipeline variables created or added in this scope.
    pipelines: Vec<Variable>,
    // Barrier variables created or added in this scope.
    barriers: Vec<Variable>,
    // Shared-memory arrays declared through `Scope::create_shared`.
    shared_memories: Vec<Variable>,
    /// Constant arrays paired with their element values (see `Scope::create_const_array`).
    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
    // Local arrays created or added in this scope.
    local_arrays: Vec<Variable>,
    // NOTE(review): only ever initialised empty in this file; its writers/readers live elsewhere.
    index_offset_with_output_layout_position: Vec<usize>,
    /// Allocator used to create locals, matrices, pipelines, barriers and local arrays.
    pub allocator: Allocator,
    /// Debug metadata (source locations, variable names).
    pub debug: DebugInfo,
    /// Maps Rust `TypeId`s to IR storage types (see `Scope::register_type`). Shared with child
    /// scopes and excluded from both `TypeHash` and serde serialisation.
    #[type_hash(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub typemap: Rc<RefCell<HashMap<TypeId, StorageType>>>,
    /// Properties of the compilation target, shared with child scopes.
    pub runtime_properties: Rc<TargetProperties>,
    /// Current instruction modes; copied onto every instruction passed to `Scope::register`.
    pub modes: Rc<RefCell<InstructionModes>>,
}
44
/// Debug metadata carried by a [`Scope`] and cloned into its child scopes.
///
/// `sources` and `variable_names` sit behind `Rc`, so clones share them: entries recorded from
/// any clone are visible from all of them. The `Option` locations are copied per clone.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct DebugInfo {
    /// Whether debug data is recorded; when `false`, `Scope::update_source` and
    /// `Scope::update_variable_name` do nothing.
    pub enabled: bool,
    /// Every `CubeFnSource` recorded through `Scope::update_source`.
    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
    /// Human-readable names attached to IR variables.
    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
    /// Current source location; copied onto each instruction passed to `Scope::register`.
    pub source_loc: Option<SourceLoc>,
    /// Location recorded the first time `Scope::update_source` runs (it only sets this while
    /// still `None`).
    pub entry_loc: Option<SourceLoc>,
}
55
/// Per-instruction modes; `Scope::register` stamps a copy of the scope's current modes onto
/// every instruction it stores.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
pub struct InstructionModes {
    /// Floating-point fast-math flags applied to instructions registered with these modes.
    pub fp_math_mode: EnumSet<FastMath>,
}
62
63impl core::hash::Hash for Scope {
64 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
65 self.depth.hash(ra_expand_state);
66 self.instructions.hash(ra_expand_state);
67 self.locals.hash(ra_expand_state);
68 self.matrices.hash(ra_expand_state);
69 self.pipelines.hash(ra_expand_state);
70 self.barriers.hash(ra_expand_state);
71 self.shared_memories.hash(ra_expand_state);
72 self.const_arrays.hash(ra_expand_state);
73 self.local_arrays.hash(ra_expand_state);
74 self.index_offset_with_output_layout_position
75 .hash(ra_expand_state);
76 }
77}
78
/// How elements are read (meaning inferred from the variant names).
///
/// NOTE(review): nothing in this file uses this enum; confirm the variant semantics below
/// against its users.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
#[allow(missing_docs)]
pub enum ReadingStrategy {
    /// Presumably reads each element at the position dictated by the output layout — confirm.
    OutputLayout,
    /// Presumably reads elements directly, with no layout remapping — confirm.
    Plain,
}
88
89impl Scope {
90 pub fn root(debug_enabled: bool) -> Self {
95 Self {
96 depth: 0,
97 instructions: Vec::new(),
98 locals: Vec::new(),
99 matrices: Vec::new(),
100 pipelines: Vec::new(),
101 barriers: Vec::new(),
102 local_arrays: Vec::new(),
103 shared_memories: Vec::new(),
104 const_arrays: Vec::new(),
105 index_offset_with_output_layout_position: Vec::new(),
106 allocator: Allocator::default(),
107 debug: DebugInfo {
108 enabled: debug_enabled,
109 sources: Default::default(),
110 variable_names: Default::default(),
111 source_loc: None,
112 entry_loc: None,
113 },
114 typemap: Default::default(),
115 runtime_properties: Rc::new(Default::default()),
116 modes: Default::default(),
117 }
118 }
119
120 pub fn with_allocator(mut self, allocator: Allocator) -> Self {
122 self.allocator = allocator;
123 self
124 }
125
126 pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
128 let matrix = self.allocator.create_matrix(matrix);
129 self.add_matrix(*matrix);
130 matrix
131 }
132
133 pub fn add_matrix(&mut self, variable: Variable) {
134 self.matrices.push(variable);
135 }
136
137 pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
139 let pipeline = self.allocator.create_pipeline(num_stages);
140 self.add_pipeline(*pipeline);
141 pipeline
142 }
143
144 pub fn create_barrier(&mut self, level: BarrierLevel) -> ExpandElement {
146 let barrier = self.allocator.create_barrier(level);
147 self.add_barrier(*barrier);
148 barrier
149 }
150
151 pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ExpandElement {
153 let token = Variable::new(
154 VariableKind::BarrierToken { id, level },
155 Type::semantic(SemanticType::BarrierToken),
156 );
157 ExpandElement::Plain(token)
158 }
159
160 pub fn add_pipeline(&mut self, variable: Variable) {
161 self.pipelines.push(variable);
162 }
163
164 pub fn add_barrier(&mut self, variable: Variable) {
165 self.barriers.push(variable);
166 }
167
168 pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
170 self.allocator.create_local_mut(item.into())
171 }
172
173 pub fn add_local_mut(&mut self, var: Variable) {
175 if !self.locals.contains(&var) {
176 self.locals.push(var);
177 }
178 }
179
180 pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
183 self.allocator.create_local_restricted(item)
184 }
185
186 pub fn create_local(&mut self, item: Type) -> ExpandElement {
188 self.allocator.create_local(item)
189 }
190
191 pub fn last_local_index(&self) -> Option<&Variable> {
193 self.locals.last()
194 }
195
196 pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
198 let mut inst = instruction.into();
199 inst.source_loc = self.debug.source_loc.clone();
200 inst.modes = *self.modes.borrow();
201 self.instructions.push(inst)
202 }
203
204 pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
206 let map = self.typemap.borrow();
207 let result = map.get(&TypeId::of::<T>());
208
209 result.cloned()
210 }
211
212 pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
214 let mut map = self.typemap.borrow_mut();
215
216 map.insert(TypeId::of::<T>(), elem);
217 }
218
219 pub fn child(&mut self) -> Self {
221 Self {
222 depth: self.depth + 1,
223 instructions: Vec::new(),
224 locals: Vec::new(),
225 matrices: Vec::new(),
226 pipelines: Vec::new(),
227 barriers: Vec::new(),
228 shared_memories: Vec::new(),
229 const_arrays: Vec::new(),
230 local_arrays: Vec::new(),
231 index_offset_with_output_layout_position: Vec::new(),
232 allocator: self.allocator.clone(),
233 debug: self.debug.clone(),
234 typemap: self.typemap.clone(),
235 runtime_properties: self.runtime_properties.clone(),
236 modes: self.modes.clone(),
237 }
238 }
239
240 pub fn process<'a>(
247 &mut self,
248 processors: impl IntoIterator<Item = &'a dyn Processor>,
249 ) -> ScopeProcessing {
250 let mut variables = core::mem::take(&mut self.locals);
251
252 for var in self.matrices.drain(..) {
253 variables.push(var);
254 }
255
256 let mut instructions = Vec::new();
257
258 for inst in self.instructions.drain(..) {
259 instructions.push(inst);
260 }
261
262 variables.extend(self.allocator.take_variables());
263
264 let mut processing = ScopeProcessing {
265 variables,
266 instructions,
267 }
268 .optimize();
269
270 for p in processors {
271 processing = p.transform(processing, self.allocator.clone());
272 }
273
274 processing.variables.extend(self.allocator.take_variables());
276
277 processing
278 }
279
280 pub fn new_local_index(&self) -> u32 {
281 self.allocator.new_local_index()
282 }
283
284 pub fn create_shared<I: Into<Type>>(
286 &mut self,
287 item: I,
288 shared_memory_size: u32,
289 alignment: Option<u32>,
290 ) -> ExpandElement {
291 let item = item.into();
292 let index = self.new_local_index();
293 let shared_memory = Variable::new(
294 VariableKind::SharedMemory {
295 id: index,
296 length: shared_memory_size,
297 unroll_factor: 1,
298 alignment,
299 },
300 item,
301 );
302 self.shared_memories.push(shared_memory);
303 ExpandElement::Plain(shared_memory)
304 }
305
306 pub fn create_const_array<I: Into<Type>>(
308 &mut self,
309 item: I,
310 data: Vec<Variable>,
311 ) -> ExpandElement {
312 let item = item.into();
313 let index = self.new_local_index();
314 let const_array = Variable::new(
315 VariableKind::ConstantArray {
316 id: index,
317 length: data.len() as u32,
318 unroll_factor: 1,
319 },
320 item,
321 );
322 self.const_arrays.push((const_array, data));
323 ExpandElement::Plain(const_array)
324 }
325
326 pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
328 ExpandElement::Plain(crate::Variable::new(
329 VariableKind::GlobalInputArray(id),
330 item,
331 ))
332 }
333
334 pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
336 let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
337 ExpandElement::Plain(var)
338 }
339
340 pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
342 ExpandElement::Plain(crate::Variable::new(
343 VariableKind::GlobalScalar(id),
344 Type::new(storage),
345 ))
346 }
347
348 pub fn create_local_array<I: Into<Type>>(&mut self, item: I, array_size: u32) -> ExpandElement {
350 let local_array = self.allocator.create_local_array(item.into(), array_size);
351 self.add_local_array(*local_array);
352 local_array
353 }
354
355 pub fn add_local_array(&mut self, var: Variable) {
356 self.local_arrays.push(var);
357 }
358
359 pub fn update_source(&mut self, source: CubeFnSource) {
360 if self.debug.enabled {
361 self.debug.sources.borrow_mut().insert(source.clone());
362 self.debug.source_loc = Some(SourceLoc {
363 line: source.line,
364 column: source.column,
365 source,
366 });
367 if self.debug.entry_loc.is_none() {
368 self.debug.entry_loc = self.debug.source_loc.clone();
369 }
370 }
371 }
372
373 pub fn update_span(&mut self, line: u32, col: u32) {
374 if let Some(loc) = self.debug.source_loc.as_mut() {
375 loc.line = line;
376 loc.column = col;
377 }
378 }
379
380 pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
381 if self.debug.enabled {
382 self.debug
383 .variable_names
384 .borrow_mut()
385 .insert(variable, name.into());
386 }
387 }
388}
389
390impl Display for Scope {
391 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
392 writeln!(f, "{{")?;
393 for instruction in self.instructions.iter() {
394 let instruction_str = instruction.to_string();
395 if !instruction_str.is_empty() {
396 writeln!(
397 f,
398 "{}{}",
399 " ".repeat(self.depth as usize + 1),
400 instruction_str,
401 )?;
402 }
403 }
404 write!(f, "{}}}", " ".repeat(self.depth as usize))?;
405 Ok(())
406 }
407}