1use alloc::{borrow::Cow, rc::Rc, string::String, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7 BarrierLevel, CubeFnSource, DeviceProperties, FastMath, ManagedVariable, Matrix, Processor,
8 SemanticType, SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12 Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
15pub type TypeMap = Rc<RefCell<HashMap<TypeId, StorageType>>>;
16pub type SizeMap = Rc<RefCell<HashMap<TypeId, usize>>>;
17
18#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
27#[allow(missing_docs)]
28pub struct Scope {
29 validation_errors: ValidationErrors,
30 pub depth: u8,
31 pub instructions: Vec<Instruction>,
32 pub locals: Vec<Variable>,
33 matrices: Vec<Variable>,
34 pipelines: Vec<Variable>,
35 shared: Vec<Variable>,
36 pub const_arrays: Vec<(Variable, Vec<Variable>)>,
37 local_arrays: Vec<Variable>,
38 index_offset_with_output_layout_position: Vec<usize>,
39 pub allocator: Allocator,
40 pub debug: DebugInfo,
41 #[type_hash(skip)]
42 #[cfg_attr(feature = "serde", serde(skip))]
43 pub typemap: TypeMap,
44 #[type_hash(skip)]
45 #[cfg_attr(feature = "serde", serde(skip))]
46 pub sizemap: SizeMap,
47 pub runtime_properties: Rc<TargetProperties>,
48 pub modes: Rc<RefCell<InstructionModes>>,
49 #[cfg_attr(feature = "serde", serde(skip))]
50 pub properties: Option<Rc<DeviceProperties>>,
51}
52
53#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
54#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
55pub struct ValidationErrors {
56 errors: Rc<RefCell<Vec<String>>>,
57}
58
59#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
61#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
62pub struct DebugInfo {
63 pub enabled: bool,
64 pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
65 pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
66 pub source_loc: Option<SourceLoc>,
67 pub entry_loc: Option<SourceLoc>,
68}
69
70#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
72#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
73pub struct InstructionModes {
74 pub fp_math_mode: EnumSet<FastMath>,
75}
76
77impl core::hash::Hash for Scope {
78 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
79 self.depth.hash(ra_expand_state);
80 self.instructions.hash(ra_expand_state);
81 self.locals.hash(ra_expand_state);
82 self.matrices.hash(ra_expand_state);
83 self.pipelines.hash(ra_expand_state);
84 self.shared.hash(ra_expand_state);
85 self.const_arrays.hash(ra_expand_state);
86 self.local_arrays.hash(ra_expand_state);
87 self.index_offset_with_output_layout_position
88 .hash(ra_expand_state);
89 }
90}
91
92#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
93#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
94#[allow(missing_docs)]
95pub enum ReadingStrategy {
96 OutputLayout,
98 Plain,
100}
101
102impl Scope {
103 pub fn device_properties(&mut self, properties: &DeviceProperties) {
105 self.properties = Some(Rc::new(properties.clone()));
106 }
107 pub fn root(debug_enabled: bool) -> Self {
111 Self {
112 validation_errors: ValidationErrors {
113 errors: Rc::new(RefCell::new(Vec::new())),
114 },
115 depth: 0,
116 instructions: Vec::new(),
117 locals: Vec::new(),
118 matrices: Vec::new(),
119 pipelines: Vec::new(),
120 local_arrays: Vec::new(),
121 shared: Vec::new(),
122 const_arrays: Vec::new(),
123 index_offset_with_output_layout_position: Vec::new(),
124 allocator: Allocator::default(),
125 debug: DebugInfo {
126 enabled: debug_enabled,
127 sources: Default::default(),
128 variable_names: Default::default(),
129 source_loc: None,
130 entry_loc: None,
131 },
132 typemap: Default::default(),
133 sizemap: Default::default(),
134 runtime_properties: Rc::new(Default::default()),
135 modes: Default::default(),
136 properties: None,
137 }
138 }
139
140 pub fn with_allocator(mut self, allocator: Allocator) -> Self {
142 self.allocator = allocator;
143 self
144 }
145
146 pub fn with_types(mut self, typemap: TypeMap) -> Self {
147 self.typemap = typemap;
148 self
149 }
150
151 pub fn create_matrix(&mut self, matrix: Matrix) -> ManagedVariable {
153 let matrix = self.allocator.create_matrix(matrix);
154 self.add_matrix(*matrix);
155 matrix
156 }
157
158 pub fn add_matrix(&mut self, variable: Variable) {
159 self.matrices.push(variable);
160 }
161
162 pub fn create_pipeline(&mut self, num_stages: u8) -> ManagedVariable {
164 let pipeline = self.allocator.create_pipeline(num_stages);
165 self.add_pipeline(*pipeline);
166 pipeline
167 }
168
169 pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ManagedVariable {
171 let token = Variable::new(
172 VariableKind::BarrierToken { id, level },
173 Type::semantic(SemanticType::BarrierToken),
174 );
175 ManagedVariable::Plain(token)
176 }
177
178 pub fn add_pipeline(&mut self, variable: Variable) {
179 self.pipelines.push(variable);
180 }
181
182 pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ManagedVariable {
184 self.allocator.create_local_mut(item.into())
185 }
186
187 pub fn add_local_mut(&mut self, var: Variable) {
189 if !self.locals.contains(&var) {
190 self.locals.push(var);
191 }
192 }
193
194 pub fn create_local_restricted(&mut self, item: Type) -> ManagedVariable {
197 self.allocator.create_local_restricted(item)
198 }
199
200 pub fn create_local(&mut self, item: Type) -> ManagedVariable {
202 self.allocator.create_local(item)
203 }
204
205 pub fn last_local_index(&self) -> Option<&Variable> {
207 self.locals.last()
208 }
209
210 pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
212 let mut inst = instruction.into();
213 inst.source_loc = self.debug.source_loc.clone();
214 inst.modes = *self.modes.borrow();
215 self.instructions.push(inst)
216 }
217
218 pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
220 let map = self.typemap.borrow();
221 let result = map.get(&TypeId::of::<T>());
222
223 result.cloned()
224 }
225
226 pub fn resolve_size<T: 'static>(&self) -> Option<usize> {
228 let map = self.sizemap.borrow();
229 let result = map.get(&TypeId::of::<T>());
230
231 result.cloned()
232 }
233
234 pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
236 let mut map = self.typemap.borrow_mut();
237
238 map.insert(TypeId::of::<T>(), elem);
239 }
240
241 pub fn register_size<T: 'static>(&mut self, size: usize) {
243 let mut map = self.sizemap.borrow_mut();
244
245 map.insert(TypeId::of::<T>(), size);
246 }
247
248 pub fn child(&mut self) -> Self {
250 Self {
251 validation_errors: self.validation_errors.clone(),
252 depth: self.depth + 1,
253 instructions: Vec::new(),
254 locals: Vec::new(),
255 matrices: Vec::new(),
256 pipelines: Vec::new(),
257 shared: Vec::new(),
258 const_arrays: Vec::new(),
259 local_arrays: Vec::new(),
260 index_offset_with_output_layout_position: Vec::new(),
261 allocator: self.allocator.clone(),
262 debug: self.debug.clone(),
263 typemap: self.typemap.clone(),
264 sizemap: self.sizemap.clone(),
265 runtime_properties: self.runtime_properties.clone(),
266 modes: self.modes.clone(),
267 properties: self.properties.clone(),
268 }
269 }
270
271 pub fn push_error(&mut self, msg: impl Into<String>) {
273 self.validation_errors.errors.borrow_mut().push(msg.into());
274 }
275
276 pub fn pop_errors(&mut self) -> Vec<String> {
278 self.validation_errors.errors.replace_with(|_| Vec::new())
279 }
280
281 pub fn process<'a>(
288 &mut self,
289 processors: impl IntoIterator<Item = &'a dyn Processor>,
290 ) -> ScopeProcessing {
291 let mut variables = core::mem::take(&mut self.locals);
292
293 for var in self.matrices.drain(..) {
294 variables.push(var);
295 }
296
297 let mut instructions = Vec::new();
298
299 for inst in self.instructions.drain(..) {
300 instructions.push(inst);
301 }
302
303 variables.extend(self.allocator.take_variables());
304
305 let mut processing = ScopeProcessing {
306 variables,
307 instructions,
308 typemap: self.typemap.clone(),
309 };
310
311 for p in processors {
312 processing = p.transform(processing, self.allocator.clone());
313 }
314
315 processing.variables.extend(self.allocator.take_variables());
317
318 processing
319 }
320
321 pub fn new_local_index(&self) -> u32 {
322 self.allocator.new_local_index()
323 }
324
325 pub fn create_shared_array<I: Into<Type>>(
327 &mut self,
328 item: I,
329 shared_memory_size: usize,
330 alignment: Option<usize>,
331 ) -> ManagedVariable {
332 let item = item.into();
333 let index = self.new_local_index();
334 let shared_array = Variable::new(
335 VariableKind::SharedArray {
336 id: index,
337 length: shared_memory_size,
338 unroll_factor: 1,
339 alignment,
340 },
341 item,
342 );
343 self.shared.push(shared_array);
344 ManagedVariable::Plain(shared_array)
345 }
346
347 pub fn create_shared<I: Into<Type>>(&mut self, item: I) -> ManagedVariable {
349 let item = item.into();
350 let index = self.new_local_index();
351 let shared = Variable::new(VariableKind::Shared { id: index }, item);
352 self.shared.push(shared);
353 ManagedVariable::Plain(shared)
354 }
355
356 pub fn create_const_array<I: Into<Type>>(
358 &mut self,
359 item: I,
360 data: Vec<Variable>,
361 ) -> ManagedVariable {
362 let item = item.into();
363 let index = self.new_local_index();
364 let const_array = Variable::new(
365 VariableKind::ConstantArray {
366 id: index,
367 length: data.len(),
368 unroll_factor: 1,
369 },
370 item,
371 );
372 self.const_arrays.push((const_array, data));
373 ManagedVariable::Plain(const_array)
374 }
375
376 pub fn input(&mut self, id: Id, item: Type) -> ManagedVariable {
378 ManagedVariable::Plain(crate::Variable::new(
379 VariableKind::GlobalInputArray(id),
380 item,
381 ))
382 }
383
384 pub fn output(&mut self, id: Id, item: Type) -> ManagedVariable {
386 let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
387 ManagedVariable::Plain(var)
388 }
389
390 pub fn scalar(&self, id: Id, storage: StorageType) -> ManagedVariable {
392 ManagedVariable::Plain(crate::Variable::new(
393 VariableKind::GlobalScalar(id),
394 Type::new(storage),
395 ))
396 }
397
398 pub fn create_local_array<I: Into<Type>>(
400 &mut self,
401 item: I,
402 array_size: usize,
403 ) -> ManagedVariable {
404 let local_array = self.allocator.create_local_array(item.into(), array_size);
405 self.add_local_array(*local_array);
406 local_array
407 }
408
409 pub fn add_local_array(&mut self, var: Variable) {
410 self.local_arrays.push(var);
411 }
412
413 pub fn update_source(&mut self, source: CubeFnSource) {
414 if self.debug.enabled {
415 self.debug.sources.borrow_mut().insert(source.clone());
416 self.debug.source_loc = Some(SourceLoc {
417 line: source.line,
418 column: source.column,
419 source,
420 });
421 if self.debug.entry_loc.is_none() {
422 self.debug.entry_loc = self.debug.source_loc.clone();
423 }
424 }
425 }
426
427 pub fn update_span(&mut self, line: u32, col: u32) {
428 if let Some(loc) = self.debug.source_loc.as_mut() {
429 loc.line = line;
430 loc.column = col;
431 }
432 }
433
434 pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
435 if self.debug.enabled {
436 self.debug
437 .variable_names
438 .borrow_mut()
439 .insert(variable, name.into());
440 }
441 }
442}
443
444impl Display for Scope {
445 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
446 writeln!(f, "{{")?;
447 for instruction in self.instructions.iter() {
448 let instruction_str = instruction.to_string();
449 if !instruction_str.is_empty() {
450 writeln!(
451 f,
452 "{}{}",
453 " ".repeat(self.depth as usize + 1),
454 instruction_str,
455 )?;
456 }
457 }
458 write!(f, "{}}}", " ".repeat(self.depth as usize))?;
459 Ok(())
460 }
461}