Skip to main content

cubecl_ir/
scope.rs

1use alloc::{borrow::Cow, rc::Rc, string::String, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7    BarrierLevel, CubeFnSource, DeviceProperties, FastMath, ManagedVariable, Matrix, Processor,
8    SemanticType, SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12    Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
/// Shared, interior-mutable table from a Rust [`TypeId`] to the [`StorageType`] registered
/// for it.
///
/// Written by [`Scope::register_type`] and read by [`Scope::resolve_type`]. Child scopes clone
/// the `Rc`, so a parent and its children see the same table.
pub type TypeMap = Rc<RefCell<HashMap<TypeId, StorageType>>>;
/// Shared, interior-mutable table from a Rust [`TypeId`] to the `usize` size registered for it.
///
/// Written by [`Scope::register_size`] and read by [`Scope::resolve_size`]. Child scopes clone
/// the `Rc`, so a parent and its children see the same table.
pub type SizeMap = Rc<RefCell<HashMap<TypeId, usize>>>;
17
/// The scope is the main [`crate::Operation`] and [`crate::Variable`] container that simplifies
/// the process of reading inputs, creating local variables and adding new operations.
///
/// Notes:
///
/// This type isn't responsible for creating shader bindings and figuring out which
/// variable can be written to.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
#[allow(missing_docs)]
pub struct Scope {
    /// Validation messages collected through `push_error`. The handle is cloned into child
    /// scopes, so a parent and its children append to the same list.
    validation_errors: ValidationErrors,
    /// Nesting level: `0` for a root scope, the parent's depth plus one for a child.
    pub depth: u8,
    /// Instructions registered in this scope, in registration order.
    pub instructions: Vec<Instruction>,
    /// Variables recorded through `add_local_mut` (duplicates are rejected on insertion).
    pub locals: Vec<Variable>,
    /// Matrix variables created or added in this scope.
    matrices: Vec<Variable>,
    /// Pipeline variables created or added in this scope.
    pipelines: Vec<Variable>,
    /// Shared variables and shared arrays created in this scope.
    shared: Vec<Variable>,
    /// Constant arrays, each paired with the variables supplied as its data.
    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
    /// Local array variables created or added in this scope.
    local_arrays: Vec<Variable>,
    // NOTE(review): in the code visible here this list is only initialised empty and fed to
    // the `Hash` impl; nothing pushes to it. Confirm whether it is still needed.
    index_offset_with_output_layout_position: Vec<usize>,
    /// Allocator used to create variables and hand out local indices; cloned into child scopes.
    pub allocator: Allocator,
    /// Debug bookkeeping (sources, variable names, source locations).
    pub debug: DebugInfo,
    /// Generic type -> storage type table; excluded from type hashing and serialization.
    #[type_hash(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub typemap: TypeMap,
    /// Generic size -> comptime size table; excluded from type hashing and serialization.
    #[type_hash(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub sizemap: SizeMap,
    /// Target properties; a default value in a root scope, shared with child scopes.
    pub runtime_properties: Rc<TargetProperties>,
    /// Instruction modes copied onto every instruction when it is registered; the cell is
    /// shared with child scopes.
    pub modes: Rc<RefCell<InstructionModes>>,
    /// Device properties; `None` until `device_properties` is called. Not serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    pub properties: Option<Rc<DeviceProperties>>,
}
52
/// Reference-counted, interior-mutable list of validation error messages.
///
/// Cloning this value clones the `Rc`, so every clone pushes to and drains the same list.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct ValidationErrors {
    /// The collected messages, in the order they were pushed.
    errors: Rc<RefCell<Vec<String>>>,
}
58
/// Debug related fields, most of these are global
///
/// The `Rc` fields are shared between a scope and the child scopes cloned from it; the
/// location fields are plain values copied at clone time.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct DebugInfo {
    /// Master switch: `Scope::update_source` and `Scope::update_variable_name` do nothing
    /// when this is `false`.
    pub enabled: bool,
    /// Every source passed to `Scope::update_source` while debugging is enabled.
    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
    /// Names attached to variables through `Scope::update_variable_name`.
    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
    /// Current source location; copied onto each instruction when it is registered.
    pub source_loc: Option<SourceLoc>,
    /// Source location recorded by `Scope::update_source` when none was set yet.
    pub entry_loc: Option<SourceLoc>,
}
69
/// Modes set and reset during expansion
///
/// `Scope::register` copies the current value onto every instruction it records.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
pub struct InstructionModes {
    /// Set of [`FastMath`] flags in effect; empty by default.
    pub fp_math_mode: EnumSet<FastMath>,
}
76
77impl core::hash::Hash for Scope {
78    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
79        self.depth.hash(ra_expand_state);
80        self.instructions.hash(ra_expand_state);
81        self.locals.hash(ra_expand_state);
82        self.matrices.hash(ra_expand_state);
83        self.pipelines.hash(ra_expand_state);
84        self.shared.hash(ra_expand_state);
85        self.const_arrays.hash(ra_expand_state);
86        self.local_arrays.hash(ra_expand_state);
87        self.index_offset_with_output_layout_position
88            .hash(ra_expand_state);
89    }
90}
91
/// How elements are read with respect to layout.
// NOTE(review): nothing in this file uses this enum; confirm where it is consumed.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
#[allow(missing_docs)]
pub enum ReadingStrategy {
    /// Each element will be read in a way to be compatible with the output layout.
    OutputLayout,
    /// Keep the current layout.
    Plain,
}
101
102impl Scope {
103    /// Set the device properties.
104    pub fn device_properties(&mut self, properties: &DeviceProperties) {
105        self.properties = Some(Rc::new(properties.clone()));
106    }
107    /// Create a scope that is at the root of a kernel definition.
108    ///
109    /// A local scope can be created with the [child](Self::child) method.
110    pub fn root(debug_enabled: bool) -> Self {
111        Self {
112            validation_errors: ValidationErrors {
113                errors: Rc::new(RefCell::new(Vec::new())),
114            },
115            depth: 0,
116            instructions: Vec::new(),
117            locals: Vec::new(),
118            matrices: Vec::new(),
119            pipelines: Vec::new(),
120            local_arrays: Vec::new(),
121            shared: Vec::new(),
122            const_arrays: Vec::new(),
123            index_offset_with_output_layout_position: Vec::new(),
124            allocator: Allocator::default(),
125            debug: DebugInfo {
126                enabled: debug_enabled,
127                sources: Default::default(),
128                variable_names: Default::default(),
129                source_loc: None,
130                entry_loc: None,
131            },
132            typemap: Default::default(),
133            sizemap: Default::default(),
134            runtime_properties: Rc::new(Default::default()),
135            modes: Default::default(),
136            properties: None,
137        }
138    }
139
140    /// Shift variable ids.
141    pub fn with_allocator(mut self, allocator: Allocator) -> Self {
142        self.allocator = allocator;
143        self
144    }
145
146    pub fn with_types(mut self, typemap: TypeMap) -> Self {
147        self.typemap = typemap;
148        self
149    }
150
151    /// Create a new matrix element.
152    pub fn create_matrix(&mut self, matrix: Matrix) -> ManagedVariable {
153        let matrix = self.allocator.create_matrix(matrix);
154        self.add_matrix(*matrix);
155        matrix
156    }
157
    /// Record an existing matrix variable in this scope's matrix list.
    ///
    /// No duplicate check is performed.
    pub fn add_matrix(&mut self, variable: Variable) {
        self.matrices.push(variable);
    }
161
162    /// Create a new pipeline element.
163    pub fn create_pipeline(&mut self, num_stages: u8) -> ManagedVariable {
164        let pipeline = self.allocator.create_pipeline(num_stages);
165        self.add_pipeline(*pipeline);
166        pipeline
167    }
168
169    /// Create a new barrier element.
170    pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ManagedVariable {
171        let token = Variable::new(
172            VariableKind::BarrierToken { id, level },
173            Type::semantic(SemanticType::BarrierToken),
174        );
175        ManagedVariable::Plain(token)
176    }
177
    /// Record an existing pipeline variable in this scope's pipeline list.
    ///
    /// No duplicate check is performed.
    pub fn add_pipeline(&mut self, variable: Variable) {
        self.pipelines.push(variable);
    }
181
182    /// Create a mutable variable of the given item type.
183    pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ManagedVariable {
184        self.allocator.create_local_mut(item.into())
185    }
186
187    /// Create a mutable variable of the given item type.
188    pub fn add_local_mut(&mut self, var: Variable) {
189        if !self.locals.contains(&var) {
190            self.locals.push(var);
191        }
192    }
193
    /// Create a new restricted variable.
    ///
    /// Useful for _for loops_ and other algorithms that require control over initialization.
    /// The variable is produced by the allocator; nothing is recorded in the scope's own lists.
    pub fn create_local_restricted(&mut self, item: Type) -> ManagedVariable {
        self.allocator.create_local_restricted(item)
    }
199
    /// Create a new immutable variable.
    ///
    /// The variable is produced by the allocator; nothing is recorded in the scope's own lists.
    pub fn create_local(&mut self, item: Type) -> ManagedVariable {
        self.allocator.create_local(item)
    }
204
    /// Retrieve the last variable recorded in `locals`, or `None` when the list is empty.
    ///
    /// Despite the name, this returns the variable itself rather than an index.
    pub fn last_local_index(&self) -> Option<&Variable> {
        self.locals.last()
    }
209
210    /// Register an [`Instruction`] into the scope.
211    pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
212        let mut inst = instruction.into();
213        inst.source_loc = self.debug.source_loc.clone();
214        inst.modes = *self.modes.borrow();
215        self.instructions.push(inst)
216    }
217
218    /// Resolve the element type of the given generic type.
219    pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
220        let map = self.typemap.borrow();
221        let result = map.get(&TypeId::of::<T>());
222
223        result.cloned()
224    }
225
226    /// Resolve the comptime size of the given generic size.
227    pub fn resolve_size<T: 'static>(&self) -> Option<usize> {
228        let map = self.sizemap.borrow();
229        let result = map.get(&TypeId::of::<T>());
230
231        result.cloned()
232    }
233
234    /// Register the element type for the given generic type.
235    pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
236        let mut map = self.typemap.borrow_mut();
237
238        map.insert(TypeId::of::<T>(), elem);
239    }
240
241    /// Register the comptime size for the given generic size.
242    pub fn register_size<T: 'static>(&mut self, size: usize) {
243        let mut map = self.sizemap.borrow_mut();
244
245        map.insert(TypeId::of::<T>(), size);
246    }
247
248    /// Create an empty child scope.
249    pub fn child(&mut self) -> Self {
250        Self {
251            validation_errors: self.validation_errors.clone(),
252            depth: self.depth + 1,
253            instructions: Vec::new(),
254            locals: Vec::new(),
255            matrices: Vec::new(),
256            pipelines: Vec::new(),
257            shared: Vec::new(),
258            const_arrays: Vec::new(),
259            local_arrays: Vec::new(),
260            index_offset_with_output_layout_position: Vec::new(),
261            allocator: self.allocator.clone(),
262            debug: self.debug.clone(),
263            typemap: self.typemap.clone(),
264            sizemap: self.sizemap.clone(),
265            runtime_properties: self.runtime_properties.clone(),
266            modes: self.modes.clone(),
267            properties: self.properties.clone(),
268        }
269    }
270
271    // Adds a validation error.
272    pub fn push_error(&mut self, msg: impl Into<String>) {
273        self.validation_errors.errors.borrow_mut().push(msg.into());
274    }
275
276    /// Returns all validation errors.
277    pub fn pop_errors(&mut self) -> Vec<String> {
278        self.validation_errors.errors.replace_with(|_| Vec::new())
279    }
280
281    /// Returns the variables and operations to be declared and executed.
282    ///
283    /// Notes:
284    ///
285    /// New operations and variables can be created within the same scope without having name
286    /// conflicts.
287    pub fn process<'a>(
288        &mut self,
289        processors: impl IntoIterator<Item = &'a dyn Processor>,
290    ) -> ScopeProcessing {
291        let mut variables = core::mem::take(&mut self.locals);
292
293        for var in self.matrices.drain(..) {
294            variables.push(var);
295        }
296
297        let mut instructions = Vec::new();
298
299        for inst in self.instructions.drain(..) {
300            instructions.push(inst);
301        }
302
303        variables.extend(self.allocator.take_variables());
304
305        let mut processing = ScopeProcessing {
306            variables,
307            instructions,
308            typemap: self.typemap.clone(),
309        };
310
311        for p in processors {
312            processing = p.transform(processing, self.allocator.clone());
313        }
314
315        // Add variables added from processors
316        processing.variables.extend(self.allocator.take_variables());
317
318        processing
319    }
320
    /// Ask the allocator for a new local index.
    ///
    /// The shared-variable and constant-array constructors in this type use the returned
    /// value as the `id` of the variable they build.
    pub fn new_local_index(&self) -> u32 {
        self.allocator.new_local_index()
    }
324
325    /// Create a shared array variable of the given item type.
326    pub fn create_shared_array<I: Into<Type>>(
327        &mut self,
328        item: I,
329        shared_memory_size: usize,
330        alignment: Option<usize>,
331    ) -> ManagedVariable {
332        let item = item.into();
333        let index = self.new_local_index();
334        let shared_array = Variable::new(
335            VariableKind::SharedArray {
336                id: index,
337                length: shared_memory_size,
338                unroll_factor: 1,
339                alignment,
340            },
341            item,
342        );
343        self.shared.push(shared_array);
344        ManagedVariable::Plain(shared_array)
345    }
346
347    /// Create a shared variable of the given item type.
348    pub fn create_shared<I: Into<Type>>(&mut self, item: I) -> ManagedVariable {
349        let item = item.into();
350        let index = self.new_local_index();
351        let shared = Variable::new(VariableKind::Shared { id: index }, item);
352        self.shared.push(shared);
353        ManagedVariable::Plain(shared)
354    }
355
356    /// Create a shared variable of the given item type.
357    pub fn create_const_array<I: Into<Type>>(
358        &mut self,
359        item: I,
360        data: Vec<Variable>,
361    ) -> ManagedVariable {
362        let item = item.into();
363        let index = self.new_local_index();
364        let const_array = Variable::new(
365            VariableKind::ConstantArray {
366                id: index,
367                length: data.len(),
368                unroll_factor: 1,
369            },
370            item,
371        );
372        self.const_arrays.push((const_array, data));
373        ManagedVariable::Plain(const_array)
374    }
375
376    /// Obtain the index-th input
377    pub fn input(&mut self, id: Id, item: Type) -> ManagedVariable {
378        ManagedVariable::Plain(crate::Variable::new(
379            VariableKind::GlobalInputArray(id),
380            item,
381        ))
382    }
383
384    /// Obtain the index-th output
385    pub fn output(&mut self, id: Id, item: Type) -> ManagedVariable {
386        let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
387        ManagedVariable::Plain(var)
388    }
389
390    /// Obtain the index-th scalar
391    pub fn scalar(&self, id: Id, storage: StorageType) -> ManagedVariable {
392        ManagedVariable::Plain(crate::Variable::new(
393            VariableKind::GlobalScalar(id),
394            Type::new(storage),
395        ))
396    }
397
398    /// Create a local array of the given item type.
399    pub fn create_local_array<I: Into<Type>>(
400        &mut self,
401        item: I,
402        array_size: usize,
403    ) -> ManagedVariable {
404        let local_array = self.allocator.create_local_array(item.into(), array_size);
405        self.add_local_array(*local_array);
406        local_array
407    }
408
    /// Record an existing local array variable in this scope's local-array list.
    ///
    /// No duplicate check is performed.
    pub fn add_local_array(&mut self, var: Variable) {
        self.local_arrays.push(var);
    }
412
413    pub fn update_source(&mut self, source: CubeFnSource) {
414        if self.debug.enabled {
415            self.debug.sources.borrow_mut().insert(source.clone());
416            self.debug.source_loc = Some(SourceLoc {
417                line: source.line,
418                column: source.column,
419                source,
420            });
421            if self.debug.entry_loc.is_none() {
422                self.debug.entry_loc = self.debug.source_loc.clone();
423            }
424        }
425    }
426
427    pub fn update_span(&mut self, line: u32, col: u32) {
428        if let Some(loc) = self.debug.source_loc.as_mut() {
429            loc.line = line;
430            loc.column = col;
431        }
432    }
433
434    pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
435        if self.debug.enabled {
436            self.debug
437                .variable_names
438                .borrow_mut()
439                .insert(variable, name.into());
440        }
441    }
442}
443
444impl Display for Scope {
445    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
446        writeln!(f, "{{")?;
447        for instruction in self.instructions.iter() {
448            let instruction_str = instruction.to_string();
449            if !instruction_str.is_empty() {
450                writeln!(
451                    f,
452                    "{}{}",
453                    "    ".repeat(self.depth as usize + 1),
454                    instruction_str,
455                )?;
456            }
457        }
458        write!(f, "{}}}", "    ".repeat(self.depth as usize))?;
459        Ok(())
460    }
461}