1use alloc::{borrow::Cow, rc::Rc, string::String, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7 BarrierLevel, CubeFnSource, DeviceProperties, ExpandElement, FastMath, Matrix, Processor,
8 SemanticType, SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12 Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
/// Shared, interior-mutable map from a Rust [`TypeId`] to the [`StorageType`]
/// registered for it.
///
/// Written by `Scope::register_type`, read by `Scope::resolve_type`.
/// `Scope::child` gives every child a clone of the same `Rc`, so the whole
/// scope tree observes the same registrations.
pub type TypeMap = Rc<RefCell<HashMap<TypeId, StorageType>>>;
16
/// An ordered list of [`Instruction`]s together with the variables declared
/// in it and the state shared with its parent/child scopes.
///
/// A tree of scopes is built from `Scope::root` and `Scope::child`. Per-scope
/// lists (instructions, locals, matrices, ...) start empty in every child. The
/// `Rc`-backed state (validation errors, debug tables, type map, modes,
/// properties) is cloned by reference into the child.
///
/// Only the structural fields take part in the manual `Hash` impl. The type
/// map is excluded from `TypeHash` and from serde.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
#[allow(missing_docs)]
pub struct Scope {
    /// Error messages added by `push_error` and drained by `pop_errors`.
    /// The inner `Rc` is shared with every child scope.
    validation_errors: ValidationErrors,
    /// Nesting level: 0 for a root scope, parent + 1 for a child.
    /// Used by the `Display` impl for indentation.
    pub depth: u8,
    /// Instructions appended by `register`, in registration order.
    pub instructions: Vec<Instruction>,
    /// Variables added through `add_local_mut` (deduplicated).
    pub locals: Vec<Variable>,
    /// Matrix variables recorded by `create_matrix` / `add_matrix`.
    matrices: Vec<Variable>,
    /// Pipeline variables recorded by `create_pipeline` / `add_pipeline`.
    pipelines: Vec<Variable>,
    /// Variables declared by `create_shared` and `create_shared_array`.
    shared: Vec<Variable>,
    /// Constant arrays from `create_const_array`, each paired with its values.
    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
    /// Local arrays recorded by `create_local_array` / `add_local_array`.
    local_arrays: Vec<Variable>,
    // NOTE(review): only initialized empty and hashed in this file; its
    // purpose is not visible here.
    index_offset_with_output_layout_position: Vec<usize>,
    /// Allocator used for locals, matrices, pipelines and ids.
    /// Children receive a clone of it.
    pub allocator: Allocator,
    /// Debug bookkeeping (source locations, variable names).
    pub debug: DebugInfo,
    /// Rust-type to storage-type registrations, shared across the scope tree.
    #[type_hash(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub typemap: TypeMap,
    /// Target properties. A root scope starts with `Default::default()`.
    pub runtime_properties: Rc<TargetProperties>,
    /// Current instruction modes. `register` copies them onto each instruction.
    pub modes: Rc<RefCell<InstructionModes>>,
    /// Device properties attached via `device_properties`; `None` until set.
    #[cfg_attr(feature = "serde", serde(skip))]
    pub properties: Option<Rc<DeviceProperties>>,
}
48
/// Accumulated validation error messages.
///
/// The list lives behind an `Rc<RefCell<..>>`, so clones of this value
/// (e.g. those handed to child scopes) append to and drain the same list.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct ValidationErrors {
    errors: Rc<RefCell<Vec<String>>>,
}
54
/// Debug bookkeeping carried by a scope.
///
/// `sources` and `variable_names` are `Rc`-shared, so every clone sees the same
/// tables. `source_loc` and `entry_loc` are copied by value on clone.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct DebugInfo {
    /// When false, `update_source` and `update_variable_name` record nothing.
    pub enabled: bool,
    /// Every `CubeFnSource` passed to `update_source` while enabled.
    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
    /// Names assigned through `update_variable_name` while enabled.
    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
    /// Current location; copied onto each instruction by `register`.
    pub source_loc: Option<SourceLoc>,
    /// The first location recorded by `update_source`.
    pub entry_loc: Option<SourceLoc>,
}
65
/// Mode flags attached to instructions.
///
/// `Scope::register` copies the scope's current value onto every instruction
/// it records.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
pub struct InstructionModes {
    /// Floating-point math mode, expressed as a set of [`FastMath`] flags.
    pub fp_math_mode: EnumSet<FastMath>,
}
72
73impl core::hash::Hash for Scope {
74 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
75 self.depth.hash(ra_expand_state);
76 self.instructions.hash(ra_expand_state);
77 self.locals.hash(ra_expand_state);
78 self.matrices.hash(ra_expand_state);
79 self.pipelines.hash(ra_expand_state);
80 self.shared.hash(ra_expand_state);
81 self.const_arrays.hash(ra_expand_state);
82 self.local_arrays.hash(ra_expand_state);
83 self.index_offset_with_output_layout_position
84 .hash(ra_expand_state);
85 }
86}
87
/// Strategy selector for reading data.
///
/// NOTE(review): no code in this file consumes this enum, so the exact
/// semantics of each variant must be confirmed at its use sites.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
#[allow(missing_docs)]
pub enum ReadingStrategy {
    // Presumably: read following the output layout — TODO confirm.
    OutputLayout,
    // Presumably: read without layout translation — TODO confirm.
    Plain,
}
97
98impl Scope {
99 pub fn device_properties(&mut self, properties: &DeviceProperties) {
101 self.properties = Some(Rc::new(properties.clone()));
102 }
103 pub fn root(debug_enabled: bool) -> Self {
108 Self {
109 validation_errors: ValidationErrors {
110 errors: Rc::new(RefCell::new(Vec::new())),
111 },
112 depth: 0,
113 instructions: Vec::new(),
114 locals: Vec::new(),
115 matrices: Vec::new(),
116 pipelines: Vec::new(),
117 local_arrays: Vec::new(),
118 shared: Vec::new(),
119 const_arrays: Vec::new(),
120 index_offset_with_output_layout_position: Vec::new(),
121 allocator: Allocator::default(),
122 debug: DebugInfo {
123 enabled: debug_enabled,
124 sources: Default::default(),
125 variable_names: Default::default(),
126 source_loc: None,
127 entry_loc: None,
128 },
129 typemap: Default::default(),
130 runtime_properties: Rc::new(Default::default()),
131 modes: Default::default(),
132 properties: None,
133 }
134 }
135
136 pub fn with_allocator(mut self, allocator: Allocator) -> Self {
138 self.allocator = allocator;
139 self
140 }
141
142 pub fn with_types(mut self, typemap: TypeMap) -> Self {
143 self.typemap = typemap;
144 self
145 }
146
147 pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
149 let matrix = self.allocator.create_matrix(matrix);
150 self.add_matrix(*matrix);
151 matrix
152 }
153
154 pub fn add_matrix(&mut self, variable: Variable) {
155 self.matrices.push(variable);
156 }
157
158 pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
160 let pipeline = self.allocator.create_pipeline(num_stages);
161 self.add_pipeline(*pipeline);
162 pipeline
163 }
164
165 pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ExpandElement {
167 let token = Variable::new(
168 VariableKind::BarrierToken { id, level },
169 Type::semantic(SemanticType::BarrierToken),
170 );
171 ExpandElement::Plain(token)
172 }
173
174 pub fn add_pipeline(&mut self, variable: Variable) {
175 self.pipelines.push(variable);
176 }
177
178 pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
180 self.allocator.create_local_mut(item.into())
181 }
182
183 pub fn add_local_mut(&mut self, var: Variable) {
185 if !self.locals.contains(&var) {
186 self.locals.push(var);
187 }
188 }
189
190 pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
193 self.allocator.create_local_restricted(item)
194 }
195
196 pub fn create_local(&mut self, item: Type) -> ExpandElement {
198 self.allocator.create_local(item)
199 }
200
201 pub fn last_local_index(&self) -> Option<&Variable> {
203 self.locals.last()
204 }
205
206 pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
208 let mut inst = instruction.into();
209 inst.source_loc = self.debug.source_loc.clone();
210 inst.modes = *self.modes.borrow();
211 self.instructions.push(inst)
212 }
213
214 pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
216 let map = self.typemap.borrow();
217 let result = map.get(&TypeId::of::<T>());
218
219 result.cloned()
220 }
221
222 pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
224 let mut map = self.typemap.borrow_mut();
225
226 map.insert(TypeId::of::<T>(), elem);
227 }
228
229 pub fn child(&mut self) -> Self {
231 Self {
232 validation_errors: self.validation_errors.clone(),
233 depth: self.depth + 1,
234 instructions: Vec::new(),
235 locals: Vec::new(),
236 matrices: Vec::new(),
237 pipelines: Vec::new(),
238 shared: Vec::new(),
239 const_arrays: Vec::new(),
240 local_arrays: Vec::new(),
241 index_offset_with_output_layout_position: Vec::new(),
242 allocator: self.allocator.clone(),
243 debug: self.debug.clone(),
244 typemap: self.typemap.clone(),
245 runtime_properties: self.runtime_properties.clone(),
246 modes: self.modes.clone(),
247 properties: self.properties.clone(),
248 }
249 }
250
251 pub fn push_error(&mut self, msg: impl Into<String>) {
253 self.validation_errors.errors.borrow_mut().push(msg.into());
254 }
255
256 pub fn pop_errors(&mut self) -> Vec<String> {
258 self.validation_errors.errors.replace_with(|_| Vec::new())
259 }
260
261 pub fn process<'a>(
268 &mut self,
269 processors: impl IntoIterator<Item = &'a dyn Processor>,
270 ) -> ScopeProcessing {
271 let mut variables = core::mem::take(&mut self.locals);
272
273 for var in self.matrices.drain(..) {
274 variables.push(var);
275 }
276
277 let mut instructions = Vec::new();
278
279 for inst in self.instructions.drain(..) {
280 instructions.push(inst);
281 }
282
283 variables.extend(self.allocator.take_variables());
284
285 let mut processing = ScopeProcessing {
286 variables,
287 instructions,
288 typemap: self.typemap.clone(),
289 };
290
291 for p in processors {
292 processing = p.transform(processing, self.allocator.clone());
293 }
294
295 processing.variables.extend(self.allocator.take_variables());
297
298 processing
299 }
300
301 pub fn new_local_index(&self) -> u32 {
302 self.allocator.new_local_index()
303 }
304
305 pub fn create_shared_array<I: Into<Type>>(
307 &mut self,
308 item: I,
309 shared_memory_size: usize,
310 alignment: Option<usize>,
311 ) -> ExpandElement {
312 let item = item.into();
313 let index = self.new_local_index();
314 let shared_array = Variable::new(
315 VariableKind::SharedArray {
316 id: index,
317 length: shared_memory_size,
318 unroll_factor: 1,
319 alignment,
320 },
321 item,
322 );
323 self.shared.push(shared_array);
324 ExpandElement::Plain(shared_array)
325 }
326
327 pub fn create_shared<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
329 let item = item.into();
330 let index = self.new_local_index();
331 let shared = Variable::new(VariableKind::Shared { id: index }, item);
332 self.shared.push(shared);
333 ExpandElement::Plain(shared)
334 }
335
336 pub fn create_const_array<I: Into<Type>>(
338 &mut self,
339 item: I,
340 data: Vec<Variable>,
341 ) -> ExpandElement {
342 let item = item.into();
343 let index = self.new_local_index();
344 let const_array = Variable::new(
345 VariableKind::ConstantArray {
346 id: index,
347 length: data.len(),
348 unroll_factor: 1,
349 },
350 item,
351 );
352 self.const_arrays.push((const_array, data));
353 ExpandElement::Plain(const_array)
354 }
355
356 pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
358 ExpandElement::Plain(crate::Variable::new(
359 VariableKind::GlobalInputArray(id),
360 item,
361 ))
362 }
363
364 pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
366 let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
367 ExpandElement::Plain(var)
368 }
369
370 pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
372 ExpandElement::Plain(crate::Variable::new(
373 VariableKind::GlobalScalar(id),
374 Type::new(storage),
375 ))
376 }
377
378 pub fn create_local_array<I: Into<Type>>(
380 &mut self,
381 item: I,
382 array_size: usize,
383 ) -> ExpandElement {
384 let local_array = self.allocator.create_local_array(item.into(), array_size);
385 self.add_local_array(*local_array);
386 local_array
387 }
388
389 pub fn add_local_array(&mut self, var: Variable) {
390 self.local_arrays.push(var);
391 }
392
393 pub fn update_source(&mut self, source: CubeFnSource) {
394 if self.debug.enabled {
395 self.debug.sources.borrow_mut().insert(source.clone());
396 self.debug.source_loc = Some(SourceLoc {
397 line: source.line,
398 column: source.column,
399 source,
400 });
401 if self.debug.entry_loc.is_none() {
402 self.debug.entry_loc = self.debug.source_loc.clone();
403 }
404 }
405 }
406
407 pub fn update_span(&mut self, line: u32, col: u32) {
408 if let Some(loc) = self.debug.source_loc.as_mut() {
409 loc.line = line;
410 loc.column = col;
411 }
412 }
413
414 pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
415 if self.debug.enabled {
416 self.debug
417 .variable_names
418 .borrow_mut()
419 .insert(variable, name.into());
420 }
421 }
422}
423
424impl Display for Scope {
425 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
426 writeln!(f, "{{")?;
427 for instruction in self.instructions.iter() {
428 let instruction_str = instruction.to_string();
429 if !instruction_str.is_empty() {
430 writeln!(
431 f,
432 "{}{}",
433 " ".repeat(self.depth as usize + 1),
434 instruction_str,
435 )?;
436 }
437 }
438 write!(f, "{}}}", " ".repeat(self.depth as usize))?;
439 Ok(())
440 }
441}