1use alloc::{borrow::Cow, rc::Rc, string::String, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7 BarrierLevel, CubeFnSource, ExpandElement, FastMath, Matrix, Processor, SemanticType,
8 SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12 Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
/// A block of IR: the instructions registered in it plus the variables, arrays and
/// matrices it declares.
///
/// Scopes nest: [`Scope::root`] creates a depth-0 scope and [`Scope::child`] creates a
/// scope one level deeper. Several fields are `Rc` handles, so clones and child scopes
/// share the underlying data instead of copying it.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
#[allow(missing_docs)]
pub struct Scope {
    /// Error messages recorded with `push_error`; child scopes share the same list.
    validation_errors: ValidationErrors,
    /// Nesting level: 0 for the root, parent + 1 for a child. Also drives the
    /// indentation of the `Display` output.
    pub depth: u8,
    /// Instructions in registration order (see `register`).
    pub instructions: Vec<Instruction>,
    /// Locals added with `add_local_mut` (kept free of duplicates).
    pub locals: Vec<Variable>,
    /// Matrix variables; `process` appends them to the emitted variable list.
    matrices: Vec<Variable>,
    /// Pipeline variables recorded with `add_pipeline` / `create_pipeline`.
    pipelines: Vec<Variable>,
    /// Variables created by `create_shared` and `create_shared_array`.
    shared: Vec<Variable>,
    /// Constant arrays paired with their element values (see `create_const_array`).
    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
    /// Local arrays recorded with `add_local_array` / `create_local_array`.
    local_arrays: Vec<Variable>,
    // NOTE(review): this file only ever initialises this to an empty Vec; its
    // purpose and writers are not visible here — confirm before relying on it.
    index_offset_with_output_layout_position: Vec<usize>,
    /// Allocator used to create new variables and ids.
    pub allocator: Allocator,
    /// Debug / source-location bookkeeping.
    pub debug: DebugInfo,
    /// Rust `TypeId` -> IR storage type mapping used by `register_type` and
    /// `resolve_type`. Shared with child scopes; skipped by `TypeHash` and serde.
    #[type_hash(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub typemap: Rc<RefCell<HashMap<TypeId, StorageType>>>,
    /// Target properties; the `Rc` is shared with child scopes.
    pub runtime_properties: Rc<TargetProperties>,
    /// Modes copied onto every instruction passed to `register`; shared with child
    /// scopes.
    pub modes: Rc<RefCell<InstructionModes>>,
}
44
/// Collected validation error messages.
///
/// The list sits behind an `Rc<RefCell<..>>`, so cloning a `ValidationErrors` (as
/// `Scope::child` does) shares the list instead of copying it. Messages are appended by
/// `Scope::push_error` and drained by `Scope::pop_errors`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct ValidationErrors {
    // Shared, interior-mutable list of error messages.
    errors: Rc<RefCell<Vec<String>>>,
}
50
/// Debug information tracked while building scopes.
///
/// `sources` and `variable_names` are `Rc` handles, so clones (including the copy a
/// child scope receives) share them. `source_loc` and `entry_loc` are copied per clone.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct DebugInfo {
    /// Whether `Scope::update_source` and `Scope::update_variable_name` record anything.
    pub enabled: bool,
    /// Every `CubeFnSource` passed to `Scope::update_source` while enabled.
    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
    /// Names registered with `Scope::update_variable_name` while enabled.
    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
    /// Current source location; copied onto each instruction by `Scope::register`.
    pub source_loc: Option<SourceLoc>,
    /// First source location recorded by `Scope::update_source`.
    pub entry_loc: Option<SourceLoc>,
}
61
/// Mode flags stamped onto every instruction when it is registered in a scope.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
pub struct InstructionModes {
    /// Set of `FastMath` flags; the derived `Default` is the empty set.
    pub fp_math_mode: EnumSet<FastMath>,
}
68
69impl core::hash::Hash for Scope {
70 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
71 self.depth.hash(ra_expand_state);
72 self.instructions.hash(ra_expand_state);
73 self.locals.hash(ra_expand_state);
74 self.matrices.hash(ra_expand_state);
75 self.pipelines.hash(ra_expand_state);
76 self.shared.hash(ra_expand_state);
77 self.const_arrays.hash(ra_expand_state);
78 self.local_arrays.hash(ra_expand_state);
79 self.index_offset_with_output_layout_position
80 .hash(ra_expand_state);
81 }
82}
83
/// Strategy used when reading values.
///
/// NOTE(review): the variants are not consumed anywhere in this file. The notes below
/// are inferred from the names only; confirm against the code that matches on them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
#[allow(missing_docs)]
pub enum ReadingStrategy {
    /// Presumably: read through the output layout — TODO confirm.
    OutputLayout,
    /// Presumably: read without any layout translation — TODO confirm.
    Plain,
}
93
94impl Scope {
95 pub fn root(debug_enabled: bool) -> Self {
100 Self {
101 validation_errors: ValidationErrors {
102 errors: Rc::new(RefCell::new(Vec::new())),
103 },
104 depth: 0,
105 instructions: Vec::new(),
106 locals: Vec::new(),
107 matrices: Vec::new(),
108 pipelines: Vec::new(),
109 local_arrays: Vec::new(),
110 shared: Vec::new(),
111 const_arrays: Vec::new(),
112 index_offset_with_output_layout_position: Vec::new(),
113 allocator: Allocator::default(),
114 debug: DebugInfo {
115 enabled: debug_enabled,
116 sources: Default::default(),
117 variable_names: Default::default(),
118 source_loc: None,
119 entry_loc: None,
120 },
121 typemap: Default::default(),
122 runtime_properties: Rc::new(Default::default()),
123 modes: Default::default(),
124 }
125 }
126
127 pub fn with_allocator(mut self, allocator: Allocator) -> Self {
129 self.allocator = allocator;
130 self
131 }
132
133 pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
135 let matrix = self.allocator.create_matrix(matrix);
136 self.add_matrix(*matrix);
137 matrix
138 }
139
140 pub fn add_matrix(&mut self, variable: Variable) {
141 self.matrices.push(variable);
142 }
143
144 pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
146 let pipeline = self.allocator.create_pipeline(num_stages);
147 self.add_pipeline(*pipeline);
148 pipeline
149 }
150
151 pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ExpandElement {
153 let token = Variable::new(
154 VariableKind::BarrierToken { id, level },
155 Type::semantic(SemanticType::BarrierToken),
156 );
157 ExpandElement::Plain(token)
158 }
159
160 pub fn add_pipeline(&mut self, variable: Variable) {
161 self.pipelines.push(variable);
162 }
163
164 pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
166 self.allocator.create_local_mut(item.into())
167 }
168
169 pub fn add_local_mut(&mut self, var: Variable) {
171 if !self.locals.contains(&var) {
172 self.locals.push(var);
173 }
174 }
175
176 pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
179 self.allocator.create_local_restricted(item)
180 }
181
182 pub fn create_local(&mut self, item: Type) -> ExpandElement {
184 self.allocator.create_local(item)
185 }
186
187 pub fn last_local_index(&self) -> Option<&Variable> {
189 self.locals.last()
190 }
191
192 pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
194 let mut inst = instruction.into();
195 inst.source_loc = self.debug.source_loc.clone();
196 inst.modes = *self.modes.borrow();
197 self.instructions.push(inst)
198 }
199
200 pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
202 let map = self.typemap.borrow();
203 let result = map.get(&TypeId::of::<T>());
204
205 result.cloned()
206 }
207
208 pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
210 let mut map = self.typemap.borrow_mut();
211
212 map.insert(TypeId::of::<T>(), elem);
213 }
214
215 pub fn child(&mut self) -> Self {
217 Self {
218 validation_errors: self.validation_errors.clone(),
219 depth: self.depth + 1,
220 instructions: Vec::new(),
221 locals: Vec::new(),
222 matrices: Vec::new(),
223 pipelines: Vec::new(),
224 shared: Vec::new(),
225 const_arrays: Vec::new(),
226 local_arrays: Vec::new(),
227 index_offset_with_output_layout_position: Vec::new(),
228 allocator: self.allocator.clone(),
229 debug: self.debug.clone(),
230 typemap: self.typemap.clone(),
231 runtime_properties: self.runtime_properties.clone(),
232 modes: self.modes.clone(),
233 }
234 }
235
236 pub fn push_error(&mut self, msg: impl Into<String>) {
238 self.validation_errors.errors.borrow_mut().push(msg.into());
239 }
240
241 pub fn pop_errors(&mut self) -> Vec<String> {
243 self.validation_errors.errors.replace_with(|_| Vec::new())
244 }
245
246 pub fn process<'a>(
253 &mut self,
254 processors: impl IntoIterator<Item = &'a dyn Processor>,
255 ) -> ScopeProcessing {
256 let mut variables = core::mem::take(&mut self.locals);
257
258 for var in self.matrices.drain(..) {
259 variables.push(var);
260 }
261
262 let mut instructions = Vec::new();
263
264 for inst in self.instructions.drain(..) {
265 instructions.push(inst);
266 }
267
268 variables.extend(self.allocator.take_variables());
269
270 let mut processing = ScopeProcessing {
271 variables,
272 instructions,
273 }
274 .optimize();
275
276 for p in processors {
277 processing = p.transform(processing, self.allocator.clone());
278 }
279
280 processing.variables.extend(self.allocator.take_variables());
282
283 processing
284 }
285
286 pub fn new_local_index(&self) -> u32 {
287 self.allocator.new_local_index()
288 }
289
290 pub fn create_shared_array<I: Into<Type>>(
292 &mut self,
293 item: I,
294 shared_memory_size: u32,
295 alignment: Option<u32>,
296 ) -> ExpandElement {
297 let item = item.into();
298 let index = self.new_local_index();
299 let shared_array = Variable::new(
300 VariableKind::SharedArray {
301 id: index,
302 length: shared_memory_size,
303 unroll_factor: 1,
304 alignment,
305 },
306 item,
307 );
308 self.shared.push(shared_array);
309 ExpandElement::Plain(shared_array)
310 }
311
312 pub fn create_shared<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
314 let item = item.into();
315 let index = self.new_local_index();
316 let shared = Variable::new(VariableKind::Shared { id: index }, item);
317 self.shared.push(shared);
318 ExpandElement::Plain(shared)
319 }
320
321 pub fn create_const_array<I: Into<Type>>(
323 &mut self,
324 item: I,
325 data: Vec<Variable>,
326 ) -> ExpandElement {
327 let item = item.into();
328 let index = self.new_local_index();
329 let const_array = Variable::new(
330 VariableKind::ConstantArray {
331 id: index,
332 length: data.len() as u32,
333 unroll_factor: 1,
334 },
335 item,
336 );
337 self.const_arrays.push((const_array, data));
338 ExpandElement::Plain(const_array)
339 }
340
341 pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
343 ExpandElement::Plain(crate::Variable::new(
344 VariableKind::GlobalInputArray(id),
345 item,
346 ))
347 }
348
349 pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
351 let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
352 ExpandElement::Plain(var)
353 }
354
355 pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
357 ExpandElement::Plain(crate::Variable::new(
358 VariableKind::GlobalScalar(id),
359 Type::new(storage),
360 ))
361 }
362
363 pub fn create_local_array<I: Into<Type>>(&mut self, item: I, array_size: u32) -> ExpandElement {
365 let local_array = self.allocator.create_local_array(item.into(), array_size);
366 self.add_local_array(*local_array);
367 local_array
368 }
369
370 pub fn add_local_array(&mut self, var: Variable) {
371 self.local_arrays.push(var);
372 }
373
374 pub fn update_source(&mut self, source: CubeFnSource) {
375 if self.debug.enabled {
376 self.debug.sources.borrow_mut().insert(source.clone());
377 self.debug.source_loc = Some(SourceLoc {
378 line: source.line,
379 column: source.column,
380 source,
381 });
382 if self.debug.entry_loc.is_none() {
383 self.debug.entry_loc = self.debug.source_loc.clone();
384 }
385 }
386 }
387
388 pub fn update_span(&mut self, line: u32, col: u32) {
389 if let Some(loc) = self.debug.source_loc.as_mut() {
390 loc.line = line;
391 loc.column = col;
392 }
393 }
394
395 pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
396 if self.debug.enabled {
397 self.debug
398 .variable_names
399 .borrow_mut()
400 .insert(variable, name.into());
401 }
402 }
403}
404
405impl Display for Scope {
406 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
407 writeln!(f, "{{")?;
408 for instruction in self.instructions.iter() {
409 let instruction_str = instruction.to_string();
410 if !instruction_str.is_empty() {
411 writeln!(
412 f,
413 "{}{}",
414 " ".repeat(self.depth as usize + 1),
415 instruction_str,
416 )?;
417 }
418 }
419 write!(f, "{}}}", " ".repeat(self.depth as usize))?;
420 Ok(())
421 }
422}