cubecl_ir/
scope.rs

1use alloc::{borrow::Cow, rc::Rc, string::String, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7    BarrierLevel, CubeFnSource, ExpandElement, FastMath, Matrix, Processor, SemanticType,
8    SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12    Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
15/// The scope is the main [operation](Operation) and [variable](Variable) container that simplify
16/// the process of reading inputs, creating local variables and adding new operations.
17///
18/// Notes:
19///
20/// This type isn't responsible for creating [shader bindings](super::Binding) and figuring out which
21/// variable can be written to.
22#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
23#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
24#[allow(missing_docs)]
25pub struct Scope {
26    validation_errors: ValidationErrors,
27    pub depth: u8,
28    pub instructions: Vec<Instruction>,
29    pub locals: Vec<Variable>,
30    matrices: Vec<Variable>,
31    pipelines: Vec<Variable>,
32    shared: Vec<Variable>,
33    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
34    local_arrays: Vec<Variable>,
35    index_offset_with_output_layout_position: Vec<usize>,
36    pub allocator: Allocator,
37    pub debug: DebugInfo,
38    #[type_hash(skip)]
39    #[cfg_attr(feature = "serde", serde(skip))]
40    pub typemap: Rc<RefCell<HashMap<TypeId, StorageType>>>,
41    pub runtime_properties: Rc<TargetProperties>,
42    pub modes: Rc<RefCell<InstructionModes>>,
43}
44
45#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
46#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
47pub struct ValidationErrors {
48    errors: Rc<RefCell<Vec<String>>>,
49}
50
51/// Debug related fields, most of these are global
52#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
53#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
54pub struct DebugInfo {
55    pub enabled: bool,
56    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
57    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
58    pub source_loc: Option<SourceLoc>,
59    pub entry_loc: Option<SourceLoc>,
60}
61
62/// Modes set and reset during expansion
63#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
64#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
65pub struct InstructionModes {
66    pub fp_math_mode: EnumSet<FastMath>,
67}
68
69impl core::hash::Hash for Scope {
70    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
71        self.depth.hash(ra_expand_state);
72        self.instructions.hash(ra_expand_state);
73        self.locals.hash(ra_expand_state);
74        self.matrices.hash(ra_expand_state);
75        self.pipelines.hash(ra_expand_state);
76        self.shared.hash(ra_expand_state);
77        self.const_arrays.hash(ra_expand_state);
78        self.local_arrays.hash(ra_expand_state);
79        self.index_offset_with_output_layout_position
80            .hash(ra_expand_state);
81    }
82}
83
84#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
85#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
86#[allow(missing_docs)]
87pub enum ReadingStrategy {
88    /// Each element will be read in a way to be compatible with the output layout.
89    OutputLayout,
90    /// Keep the current layout.
91    Plain,
92}
93
94impl Scope {
95    /// Create a scope that is at the root of a
96    /// [kernel definition](crate::ir::KernelDefinition).
97    ///
98    /// A local scope can be created with the [child](Self::child) method.
99    pub fn root(debug_enabled: bool) -> Self {
100        Self {
101            validation_errors: ValidationErrors {
102                errors: Rc::new(RefCell::new(Vec::new())),
103            },
104            depth: 0,
105            instructions: Vec::new(),
106            locals: Vec::new(),
107            matrices: Vec::new(),
108            pipelines: Vec::new(),
109            local_arrays: Vec::new(),
110            shared: Vec::new(),
111            const_arrays: Vec::new(),
112            index_offset_with_output_layout_position: Vec::new(),
113            allocator: Allocator::default(),
114            debug: DebugInfo {
115                enabled: debug_enabled,
116                sources: Default::default(),
117                variable_names: Default::default(),
118                source_loc: None,
119                entry_loc: None,
120            },
121            typemap: Default::default(),
122            runtime_properties: Rc::new(Default::default()),
123            modes: Default::default(),
124        }
125    }
126
127    /// Shift variable ids.
128    pub fn with_allocator(mut self, allocator: Allocator) -> Self {
129        self.allocator = allocator;
130        self
131    }
132
133    /// Create a new matrix element.
134    pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
135        let matrix = self.allocator.create_matrix(matrix);
136        self.add_matrix(*matrix);
137        matrix
138    }
139
140    pub fn add_matrix(&mut self, variable: Variable) {
141        self.matrices.push(variable);
142    }
143
144    /// Create a new pipeline element.
145    pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
146        let pipeline = self.allocator.create_pipeline(num_stages);
147        self.add_pipeline(*pipeline);
148        pipeline
149    }
150
151    /// Create a new barrier element.
152    pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ExpandElement {
153        let token = Variable::new(
154            VariableKind::BarrierToken { id, level },
155            Type::semantic(SemanticType::BarrierToken),
156        );
157        ExpandElement::Plain(token)
158    }
159
160    pub fn add_pipeline(&mut self, variable: Variable) {
161        self.pipelines.push(variable);
162    }
163
164    /// Create a mutable variable of the given [item type](Item).
165    pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
166        self.allocator.create_local_mut(item.into())
167    }
168
169    /// Create a mutable variable of the given [item type](Item).
170    pub fn add_local_mut(&mut self, var: Variable) {
171        if !self.locals.contains(&var) {
172            self.locals.push(var);
173        }
174    }
175
176    /// Create a new restricted variable. The variable is
177    /// Useful for _for loops_ and other algorithms that require the control over initialization.
178    pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
179        self.allocator.create_local_restricted(item)
180    }
181
182    /// Create a new immutable variable.
183    pub fn create_local(&mut self, item: Type) -> ExpandElement {
184        self.allocator.create_local(item)
185    }
186
187    /// Retrieve the last local variable that was created.
188    pub fn last_local_index(&self) -> Option<&Variable> {
189        self.locals.last()
190    }
191
192    /// Register an [operation](Operation) into the scope.
193    pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
194        let mut inst = instruction.into();
195        inst.source_loc = self.debug.source_loc.clone();
196        inst.modes = *self.modes.borrow();
197        self.instructions.push(inst)
198    }
199
200    /// Resolve the element type of the given generic type.
201    pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
202        let map = self.typemap.borrow();
203        let result = map.get(&TypeId::of::<T>());
204
205        result.cloned()
206    }
207
208    /// Register the element type for the given generic type.
209    pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
210        let mut map = self.typemap.borrow_mut();
211
212        map.insert(TypeId::of::<T>(), elem);
213    }
214
215    /// Create an empty child scope.
216    pub fn child(&mut self) -> Self {
217        Self {
218            validation_errors: self.validation_errors.clone(),
219            depth: self.depth + 1,
220            instructions: Vec::new(),
221            locals: Vec::new(),
222            matrices: Vec::new(),
223            pipelines: Vec::new(),
224            shared: Vec::new(),
225            const_arrays: Vec::new(),
226            local_arrays: Vec::new(),
227            index_offset_with_output_layout_position: Vec::new(),
228            allocator: self.allocator.clone(),
229            debug: self.debug.clone(),
230            typemap: self.typemap.clone(),
231            runtime_properties: self.runtime_properties.clone(),
232            modes: self.modes.clone(),
233        }
234    }
235
236    // Adds a validation error.
237    pub fn push_error(&mut self, msg: impl Into<String>) {
238        self.validation_errors.errors.borrow_mut().push(msg.into());
239    }
240
241    /// Returns all validation errors.
242    pub fn pop_errors(&mut self) -> Vec<String> {
243        self.validation_errors.errors.replace_with(|_| Vec::new())
244    }
245
246    /// Returns the variables and operations to be declared and executed.
247    ///
248    /// Notes:
249    ///
250    /// New operations and variables can be created within the same scope without having name
251    /// conflicts.
252    pub fn process<'a>(
253        &mut self,
254        processors: impl IntoIterator<Item = &'a dyn Processor>,
255    ) -> ScopeProcessing {
256        let mut variables = core::mem::take(&mut self.locals);
257
258        for var in self.matrices.drain(..) {
259            variables.push(var);
260        }
261
262        let mut instructions = Vec::new();
263
264        for inst in self.instructions.drain(..) {
265            instructions.push(inst);
266        }
267
268        variables.extend(self.allocator.take_variables());
269
270        let mut processing = ScopeProcessing {
271            variables,
272            instructions,
273        }
274        .optimize();
275
276        for p in processors {
277            processing = p.transform(processing, self.allocator.clone());
278        }
279
280        // Add variables added from processors
281        processing.variables.extend(self.allocator.take_variables());
282
283        processing
284    }
285
286    pub fn new_local_index(&self) -> u32 {
287        self.allocator.new_local_index()
288    }
289
290    /// Create a shared array variable of the given [item type](Item).
291    pub fn create_shared_array<I: Into<Type>>(
292        &mut self,
293        item: I,
294        shared_memory_size: u32,
295        alignment: Option<u32>,
296    ) -> ExpandElement {
297        let item = item.into();
298        let index = self.new_local_index();
299        let shared_array = Variable::new(
300            VariableKind::SharedArray {
301                id: index,
302                length: shared_memory_size,
303                unroll_factor: 1,
304                alignment,
305            },
306            item,
307        );
308        self.shared.push(shared_array);
309        ExpandElement::Plain(shared_array)
310    }
311
312    /// Create a shared variable of the given [item type](Item).
313    pub fn create_shared<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
314        let item = item.into();
315        let index = self.new_local_index();
316        let shared = Variable::new(VariableKind::Shared { id: index }, item);
317        self.shared.push(shared);
318        ExpandElement::Plain(shared)
319    }
320
321    /// Create a shared variable of the given [item type](Item).
322    pub fn create_const_array<I: Into<Type>>(
323        &mut self,
324        item: I,
325        data: Vec<Variable>,
326    ) -> ExpandElement {
327        let item = item.into();
328        let index = self.new_local_index();
329        let const_array = Variable::new(
330            VariableKind::ConstantArray {
331                id: index,
332                length: data.len() as u32,
333                unroll_factor: 1,
334            },
335            item,
336        );
337        self.const_arrays.push((const_array, data));
338        ExpandElement::Plain(const_array)
339    }
340
341    /// Obtain the index-th input
342    pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
343        ExpandElement::Plain(crate::Variable::new(
344            VariableKind::GlobalInputArray(id),
345            item,
346        ))
347    }
348
349    /// Obtain the index-th output
350    pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
351        let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
352        ExpandElement::Plain(var)
353    }
354
355    /// Obtain the index-th scalar
356    pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
357        ExpandElement::Plain(crate::Variable::new(
358            VariableKind::GlobalScalar(id),
359            Type::new(storage),
360        ))
361    }
362
363    /// Create a local array of the given [item type](Item).
364    pub fn create_local_array<I: Into<Type>>(&mut self, item: I, array_size: u32) -> ExpandElement {
365        let local_array = self.allocator.create_local_array(item.into(), array_size);
366        self.add_local_array(*local_array);
367        local_array
368    }
369
370    pub fn add_local_array(&mut self, var: Variable) {
371        self.local_arrays.push(var);
372    }
373
374    pub fn update_source(&mut self, source: CubeFnSource) {
375        if self.debug.enabled {
376            self.debug.sources.borrow_mut().insert(source.clone());
377            self.debug.source_loc = Some(SourceLoc {
378                line: source.line,
379                column: source.column,
380                source,
381            });
382            if self.debug.entry_loc.is_none() {
383                self.debug.entry_loc = self.debug.source_loc.clone();
384            }
385        }
386    }
387
388    pub fn update_span(&mut self, line: u32, col: u32) {
389        if let Some(loc) = self.debug.source_loc.as_mut() {
390            loc.line = line;
391            loc.column = col;
392        }
393    }
394
395    pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
396        if self.debug.enabled {
397            self.debug
398                .variable_names
399                .borrow_mut()
400                .insert(variable, name.into());
401        }
402    }
403}
404
405impl Display for Scope {
406    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
407        writeln!(f, "{{")?;
408        for instruction in self.instructions.iter() {
409            let instruction_str = instruction.to_string();
410            if !instruction_str.is_empty() {
411                writeln!(
412                    f,
413                    "{}{}",
414                    "    ".repeat(self.depth as usize + 1),
415                    instruction_str,
416                )?;
417            }
418        }
419        write!(f, "{}}}", "    ".repeat(self.depth as usize))?;
420        Ok(())
421    }
422}