1use alloc::{borrow::Cow, rc::Rc, string::String, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7 BarrierLevel, CubeFnSource, DeviceProperties, ExpandElement, FastMath, Matrix, Processor,
8 SemanticType, SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12 Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
/// Shared, interior-mutable map from a Rust [`TypeId`] to the [`StorageType`] registered for it
/// (see `Scope::register_type` / `Scope::resolve_type`).
pub type TypeMap = Rc<RefCell<HashMap<TypeId, StorageType>>>;
16
/// A list of instructions together with the variables declared while building them.
///
/// Scopes nest: `Scope::child` creates a scope one level deeper that starts with empty
/// instruction/variable lists but shares the error list, allocator, debug info, type map,
/// runtime properties, instruction modes and device properties with its parent.
///
/// The derived `PartialEq` compares every field, whereas the manual `Hash` implementation only
/// covers `depth`, `instructions` and the variable lists.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
#[allow(missing_docs)]
pub struct Scope {
    /// Error messages collected through `push_error`; the underlying list is shared with
    /// child scopes.
    validation_errors: ValidationErrors,
    /// Nesting level: `0` for a root scope, parent depth + 1 for a child.
    pub depth: u8,
    /// Instructions registered on this scope, in registration order.
    pub instructions: Vec<Instruction>,
    /// Variables added through `add_local_mut` (each variable is stored at most once).
    pub locals: Vec<Variable>,
    /// Matrix variables recorded through `add_matrix` / `create_matrix`.
    matrices: Vec<Variable>,
    /// Pipeline variables recorded through `add_pipeline` / `create_pipeline`.
    pipelines: Vec<Variable>,
    /// Variables created by `create_shared` and `create_shared_array`.
    shared: Vec<Variable>,
    /// Constant arrays, each paired with the `data` variables it was created from.
    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
    /// Local array variables recorded through `add_local_array` / `create_local_array`.
    local_arrays: Vec<Variable>,
    /// Initialised empty by `root` and `child`; no method visible in this file populates it.
    // NOTE(review): purpose is not evident from this file — confirm against other users.
    index_offset_with_output_layout_position: Vec<usize>,
    /// Allocator used to create variables; child scopes receive a clone of it.
    pub allocator: Allocator,
    /// Debug bookkeeping (source locations, variable names).
    pub debug: DebugInfo,
    /// Type registrations shared between scopes. Skipped by `TypeHash` and by serde.
    #[type_hash(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub typemap: TypeMap,
    /// Target properties; defaulted by `root` and shared with child scopes.
    pub runtime_properties: Rc<TargetProperties>,
    /// Instruction modes copied onto every instruction by `register`; shared with child scopes.
    pub modes: Rc<RefCell<InstructionModes>>,
    /// Device properties, if set through `device_properties`. Skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    pub properties: Option<Rc<DeviceProperties>>,
}
48
/// Reference-counted list of validation error messages.
///
/// Cloning this value clones the `Rc`, so every clone pushes to and pops from the same list.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct ValidationErrors {
    /// The collected messages, in the order they were pushed.
    errors: Rc<RefCell<Vec<String>>>,
}
54
/// Debug bookkeeping carried by a `Scope`.
///
/// `sources` and `variable_names` are reference-counted, so clones of this struct (for example
/// the one handed to a child scope) share them; the two locations are plain per-clone values.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct DebugInfo {
    /// When `false`, `Scope::update_source` and `Scope::update_variable_name` do nothing.
    pub enabled: bool,
    /// Every source passed to `Scope::update_source` while debugging was enabled.
    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
    /// Names recorded for variables through `Scope::update_variable_name`.
    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
    /// Current source location; cloned onto each instruction by `Scope::register`.
    pub source_loc: Option<SourceLoc>,
    /// The first source location ever set on this scope; later updates leave it untouched.
    pub entry_loc: Option<SourceLoc>,
}
65
/// Modes attached to an instruction; `Scope::register` copies the scope's current value onto
/// every instruction it records.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
pub struct InstructionModes {
    /// Set of `FastMath` flags; the default is the empty set.
    // NOTE(review): the meaning of individual flags is defined by `FastMath`, not visible here.
    pub fp_math_mode: EnumSet<FastMath>,
}
72
73impl core::hash::Hash for Scope {
74 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
75 self.depth.hash(ra_expand_state);
76 self.instructions.hash(ra_expand_state);
77 self.locals.hash(ra_expand_state);
78 self.matrices.hash(ra_expand_state);
79 self.pipelines.hash(ra_expand_state);
80 self.shared.hash(ra_expand_state);
81 self.const_arrays.hash(ra_expand_state);
82 self.local_arrays.hash(ra_expand_state);
83 self.index_offset_with_output_layout_position
84 .hash(ra_expand_state);
85 }
86}
87
/// Selector between two reading strategies.
///
/// Nothing else in the visible part of this file refers to this enum.
// NOTE(review): variant semantics below are inferred from the names only — confirm against
// the code that consumes this enum.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
#[allow(missing_docs)]
pub enum ReadingStrategy {
    /// Presumably: read following the output layout.
    OutputLayout,
    /// Presumably: read without any layout adjustment.
    Plain,
}
97
98impl Scope {
99 pub fn device_properties(&mut self, properties: &DeviceProperties) {
101 self.properties = Some(Rc::new(properties.clone()));
102 }
103 pub fn root(debug_enabled: bool) -> Self {
107 Self {
108 validation_errors: ValidationErrors {
109 errors: Rc::new(RefCell::new(Vec::new())),
110 },
111 depth: 0,
112 instructions: Vec::new(),
113 locals: Vec::new(),
114 matrices: Vec::new(),
115 pipelines: Vec::new(),
116 local_arrays: Vec::new(),
117 shared: Vec::new(),
118 const_arrays: Vec::new(),
119 index_offset_with_output_layout_position: Vec::new(),
120 allocator: Allocator::default(),
121 debug: DebugInfo {
122 enabled: debug_enabled,
123 sources: Default::default(),
124 variable_names: Default::default(),
125 source_loc: None,
126 entry_loc: None,
127 },
128 typemap: Default::default(),
129 runtime_properties: Rc::new(Default::default()),
130 modes: Default::default(),
131 properties: None,
132 }
133 }
134
135 pub fn with_allocator(mut self, allocator: Allocator) -> Self {
137 self.allocator = allocator;
138 self
139 }
140
141 pub fn with_types(mut self, typemap: TypeMap) -> Self {
142 self.typemap = typemap;
143 self
144 }
145
146 pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
148 let matrix = self.allocator.create_matrix(matrix);
149 self.add_matrix(*matrix);
150 matrix
151 }
152
153 pub fn add_matrix(&mut self, variable: Variable) {
154 self.matrices.push(variable);
155 }
156
157 pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
159 let pipeline = self.allocator.create_pipeline(num_stages);
160 self.add_pipeline(*pipeline);
161 pipeline
162 }
163
164 pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ExpandElement {
166 let token = Variable::new(
167 VariableKind::BarrierToken { id, level },
168 Type::semantic(SemanticType::BarrierToken),
169 );
170 ExpandElement::Plain(token)
171 }
172
173 pub fn add_pipeline(&mut self, variable: Variable) {
174 self.pipelines.push(variable);
175 }
176
177 pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
179 self.allocator.create_local_mut(item.into())
180 }
181
182 pub fn add_local_mut(&mut self, var: Variable) {
184 if !self.locals.contains(&var) {
185 self.locals.push(var);
186 }
187 }
188
189 pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
192 self.allocator.create_local_restricted(item)
193 }
194
195 pub fn create_local(&mut self, item: Type) -> ExpandElement {
197 self.allocator.create_local(item)
198 }
199
200 pub fn last_local_index(&self) -> Option<&Variable> {
202 self.locals.last()
203 }
204
205 pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
207 let mut inst = instruction.into();
208 inst.source_loc = self.debug.source_loc.clone();
209 inst.modes = *self.modes.borrow();
210 self.instructions.push(inst)
211 }
212
213 pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
215 let map = self.typemap.borrow();
216 let result = map.get(&TypeId::of::<T>());
217
218 result.cloned()
219 }
220
221 pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
223 let mut map = self.typemap.borrow_mut();
224
225 map.insert(TypeId::of::<T>(), elem);
226 }
227
228 pub fn child(&mut self) -> Self {
230 Self {
231 validation_errors: self.validation_errors.clone(),
232 depth: self.depth + 1,
233 instructions: Vec::new(),
234 locals: Vec::new(),
235 matrices: Vec::new(),
236 pipelines: Vec::new(),
237 shared: Vec::new(),
238 const_arrays: Vec::new(),
239 local_arrays: Vec::new(),
240 index_offset_with_output_layout_position: Vec::new(),
241 allocator: self.allocator.clone(),
242 debug: self.debug.clone(),
243 typemap: self.typemap.clone(),
244 runtime_properties: self.runtime_properties.clone(),
245 modes: self.modes.clone(),
246 properties: self.properties.clone(),
247 }
248 }
249
250 pub fn push_error(&mut self, msg: impl Into<String>) {
252 self.validation_errors.errors.borrow_mut().push(msg.into());
253 }
254
255 pub fn pop_errors(&mut self) -> Vec<String> {
257 self.validation_errors.errors.replace_with(|_| Vec::new())
258 }
259
260 pub fn process<'a>(
267 &mut self,
268 processors: impl IntoIterator<Item = &'a dyn Processor>,
269 ) -> ScopeProcessing {
270 let mut variables = core::mem::take(&mut self.locals);
271
272 for var in self.matrices.drain(..) {
273 variables.push(var);
274 }
275
276 let mut instructions = Vec::new();
277
278 for inst in self.instructions.drain(..) {
279 instructions.push(inst);
280 }
281
282 variables.extend(self.allocator.take_variables());
283
284 let mut processing = ScopeProcessing {
285 variables,
286 instructions,
287 typemap: self.typemap.clone(),
288 };
289
290 for p in processors {
291 processing = p.transform(processing, self.allocator.clone());
292 }
293
294 processing.variables.extend(self.allocator.take_variables());
296
297 processing
298 }
299
300 pub fn new_local_index(&self) -> u32 {
301 self.allocator.new_local_index()
302 }
303
304 pub fn create_shared_array<I: Into<Type>>(
306 &mut self,
307 item: I,
308 shared_memory_size: usize,
309 alignment: Option<usize>,
310 ) -> ExpandElement {
311 let item = item.into();
312 let index = self.new_local_index();
313 let shared_array = Variable::new(
314 VariableKind::SharedArray {
315 id: index,
316 length: shared_memory_size,
317 unroll_factor: 1,
318 alignment,
319 },
320 item,
321 );
322 self.shared.push(shared_array);
323 ExpandElement::Plain(shared_array)
324 }
325
326 pub fn create_shared<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
328 let item = item.into();
329 let index = self.new_local_index();
330 let shared = Variable::new(VariableKind::Shared { id: index }, item);
331 self.shared.push(shared);
332 ExpandElement::Plain(shared)
333 }
334
335 pub fn create_const_array<I: Into<Type>>(
337 &mut self,
338 item: I,
339 data: Vec<Variable>,
340 ) -> ExpandElement {
341 let item = item.into();
342 let index = self.new_local_index();
343 let const_array = Variable::new(
344 VariableKind::ConstantArray {
345 id: index,
346 length: data.len(),
347 unroll_factor: 1,
348 },
349 item,
350 );
351 self.const_arrays.push((const_array, data));
352 ExpandElement::Plain(const_array)
353 }
354
355 pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
357 ExpandElement::Plain(crate::Variable::new(
358 VariableKind::GlobalInputArray(id),
359 item,
360 ))
361 }
362
363 pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
365 let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
366 ExpandElement::Plain(var)
367 }
368
369 pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
371 ExpandElement::Plain(crate::Variable::new(
372 VariableKind::GlobalScalar(id),
373 Type::new(storage),
374 ))
375 }
376
377 pub fn create_local_array<I: Into<Type>>(
379 &mut self,
380 item: I,
381 array_size: usize,
382 ) -> ExpandElement {
383 let local_array = self.allocator.create_local_array(item.into(), array_size);
384 self.add_local_array(*local_array);
385 local_array
386 }
387
388 pub fn add_local_array(&mut self, var: Variable) {
389 self.local_arrays.push(var);
390 }
391
392 pub fn update_source(&mut self, source: CubeFnSource) {
393 if self.debug.enabled {
394 self.debug.sources.borrow_mut().insert(source.clone());
395 self.debug.source_loc = Some(SourceLoc {
396 line: source.line,
397 column: source.column,
398 source,
399 });
400 if self.debug.entry_loc.is_none() {
401 self.debug.entry_loc = self.debug.source_loc.clone();
402 }
403 }
404 }
405
406 pub fn update_span(&mut self, line: u32, col: u32) {
407 if let Some(loc) = self.debug.source_loc.as_mut() {
408 loc.line = line;
409 loc.column = col;
410 }
411 }
412
413 pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
414 if self.debug.enabled {
415 self.debug
416 .variable_names
417 .borrow_mut()
418 .insert(variable, name.into());
419 }
420 }
421}
422
423impl Display for Scope {
424 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
425 writeln!(f, "{{")?;
426 for instruction in self.instructions.iter() {
427 let instruction_str = instruction.to_string();
428 if !instruction_str.is_empty() {
429 writeln!(
430 f,
431 "{}{}",
432 " ".repeat(self.depth as usize + 1),
433 instruction_str,
434 )?;
435 }
436 }
437 write!(f, "{}}}", " ".repeat(self.depth as usize))?;
438 Ok(())
439 }
440}