cubecl_ir/
scope.rs

1use alloc::{borrow::Cow, rc::Rc, string::String, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7    BarrierLevel, CubeFnSource, DeviceProperties, ExpandElement, FastMath, Matrix, Processor,
8    SemanticType, SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12    Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
15pub type TypeMap = Rc<RefCell<HashMap<TypeId, StorageType>>>;
16
17/// The scope is the main [operation](Operation) and [variable](Variable) container that simplify
18/// the process of reading inputs, creating local variables and adding new operations.
19///
20/// Notes:
21///
22/// This type isn't responsible for creating [shader bindings](super::Binding) and figuring out which
23/// variable can be written to.
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
26#[allow(missing_docs)]
27pub struct Scope {
28    validation_errors: ValidationErrors,
29    pub depth: u8,
30    pub instructions: Vec<Instruction>,
31    pub locals: Vec<Variable>,
32    matrices: Vec<Variable>,
33    pipelines: Vec<Variable>,
34    shared: Vec<Variable>,
35    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
36    local_arrays: Vec<Variable>,
37    index_offset_with_output_layout_position: Vec<usize>,
38    pub allocator: Allocator,
39    pub debug: DebugInfo,
40    #[type_hash(skip)]
41    #[cfg_attr(feature = "serde", serde(skip))]
42    pub typemap: TypeMap,
43    pub runtime_properties: Rc<TargetProperties>,
44    pub modes: Rc<RefCell<InstructionModes>>,
45    #[cfg_attr(feature = "serde", serde(skip))]
46    pub properties: Option<Rc<DeviceProperties>>,
47}
48
49#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
50#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
51pub struct ValidationErrors {
52    errors: Rc<RefCell<Vec<String>>>,
53}
54
55/// Debug related fields, most of these are global
56#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
57#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
58pub struct DebugInfo {
59    pub enabled: bool,
60    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
61    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
62    pub source_loc: Option<SourceLoc>,
63    pub entry_loc: Option<SourceLoc>,
64}
65
66/// Modes set and reset during expansion
67#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
68#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
69pub struct InstructionModes {
70    pub fp_math_mode: EnumSet<FastMath>,
71}
72
73impl core::hash::Hash for Scope {
74    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
75        self.depth.hash(ra_expand_state);
76        self.instructions.hash(ra_expand_state);
77        self.locals.hash(ra_expand_state);
78        self.matrices.hash(ra_expand_state);
79        self.pipelines.hash(ra_expand_state);
80        self.shared.hash(ra_expand_state);
81        self.const_arrays.hash(ra_expand_state);
82        self.local_arrays.hash(ra_expand_state);
83        self.index_offset_with_output_layout_position
84            .hash(ra_expand_state);
85    }
86}
87
88#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
89#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
90#[allow(missing_docs)]
91pub enum ReadingStrategy {
92    /// Each element will be read in a way to be compatible with the output layout.
93    OutputLayout,
94    /// Keep the current layout.
95    Plain,
96}
97
98impl Scope {
99    /// Set the device properties.
100    pub fn device_properties(&mut self, properties: &DeviceProperties) {
101        self.properties = Some(Rc::new(properties.clone()));
102    }
103    /// Create a scope that is at the root of a
104    /// [kernel definition](crate::ir::KernelDefinition).
105    ///
106    /// A local scope can be created with the [child](Self::child) method.
107    pub fn root(debug_enabled: bool) -> Self {
108        Self {
109            validation_errors: ValidationErrors {
110                errors: Rc::new(RefCell::new(Vec::new())),
111            },
112            depth: 0,
113            instructions: Vec::new(),
114            locals: Vec::new(),
115            matrices: Vec::new(),
116            pipelines: Vec::new(),
117            local_arrays: Vec::new(),
118            shared: Vec::new(),
119            const_arrays: Vec::new(),
120            index_offset_with_output_layout_position: Vec::new(),
121            allocator: Allocator::default(),
122            debug: DebugInfo {
123                enabled: debug_enabled,
124                sources: Default::default(),
125                variable_names: Default::default(),
126                source_loc: None,
127                entry_loc: None,
128            },
129            typemap: Default::default(),
130            runtime_properties: Rc::new(Default::default()),
131            modes: Default::default(),
132            properties: None,
133        }
134    }
135
136    /// Shift variable ids.
137    pub fn with_allocator(mut self, allocator: Allocator) -> Self {
138        self.allocator = allocator;
139        self
140    }
141
142    pub fn with_types(mut self, typemap: TypeMap) -> Self {
143        self.typemap = typemap;
144        self
145    }
146
147    /// Create a new matrix element.
148    pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
149        let matrix = self.allocator.create_matrix(matrix);
150        self.add_matrix(*matrix);
151        matrix
152    }
153
154    pub fn add_matrix(&mut self, variable: Variable) {
155        self.matrices.push(variable);
156    }
157
158    /// Create a new pipeline element.
159    pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
160        let pipeline = self.allocator.create_pipeline(num_stages);
161        self.add_pipeline(*pipeline);
162        pipeline
163    }
164
165    /// Create a new barrier element.
166    pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ExpandElement {
167        let token = Variable::new(
168            VariableKind::BarrierToken { id, level },
169            Type::semantic(SemanticType::BarrierToken),
170        );
171        ExpandElement::Plain(token)
172    }
173
174    pub fn add_pipeline(&mut self, variable: Variable) {
175        self.pipelines.push(variable);
176    }
177
178    /// Create a mutable variable of the given [item type](Item).
179    pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
180        self.allocator.create_local_mut(item.into())
181    }
182
183    /// Create a mutable variable of the given [item type](Item).
184    pub fn add_local_mut(&mut self, var: Variable) {
185        if !self.locals.contains(&var) {
186            self.locals.push(var);
187        }
188    }
189
190    /// Create a new restricted variable. The variable is
191    /// Useful for _for loops_ and other algorithms that require the control over initialization.
192    pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
193        self.allocator.create_local_restricted(item)
194    }
195
196    /// Create a new immutable variable.
197    pub fn create_local(&mut self, item: Type) -> ExpandElement {
198        self.allocator.create_local(item)
199    }
200
201    /// Retrieve the last local variable that was created.
202    pub fn last_local_index(&self) -> Option<&Variable> {
203        self.locals.last()
204    }
205
206    /// Register an [operation](Operation) into the scope.
207    pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
208        let mut inst = instruction.into();
209        inst.source_loc = self.debug.source_loc.clone();
210        inst.modes = *self.modes.borrow();
211        self.instructions.push(inst)
212    }
213
214    /// Resolve the element type of the given generic type.
215    pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
216        let map = self.typemap.borrow();
217        let result = map.get(&TypeId::of::<T>());
218
219        result.cloned()
220    }
221
222    /// Register the element type for the given generic type.
223    pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
224        let mut map = self.typemap.borrow_mut();
225
226        map.insert(TypeId::of::<T>(), elem);
227    }
228
229    /// Create an empty child scope.
230    pub fn child(&mut self) -> Self {
231        Self {
232            validation_errors: self.validation_errors.clone(),
233            depth: self.depth + 1,
234            instructions: Vec::new(),
235            locals: Vec::new(),
236            matrices: Vec::new(),
237            pipelines: Vec::new(),
238            shared: Vec::new(),
239            const_arrays: Vec::new(),
240            local_arrays: Vec::new(),
241            index_offset_with_output_layout_position: Vec::new(),
242            allocator: self.allocator.clone(),
243            debug: self.debug.clone(),
244            typemap: self.typemap.clone(),
245            runtime_properties: self.runtime_properties.clone(),
246            modes: self.modes.clone(),
247            properties: self.properties.clone(),
248        }
249    }
250
251    // Adds a validation error.
252    pub fn push_error(&mut self, msg: impl Into<String>) {
253        self.validation_errors.errors.borrow_mut().push(msg.into());
254    }
255
256    /// Returns all validation errors.
257    pub fn pop_errors(&mut self) -> Vec<String> {
258        self.validation_errors.errors.replace_with(|_| Vec::new())
259    }
260
261    /// Returns the variables and operations to be declared and executed.
262    ///
263    /// Notes:
264    ///
265    /// New operations and variables can be created within the same scope without having name
266    /// conflicts.
267    pub fn process<'a>(
268        &mut self,
269        processors: impl IntoIterator<Item = &'a dyn Processor>,
270    ) -> ScopeProcessing {
271        let mut variables = core::mem::take(&mut self.locals);
272
273        for var in self.matrices.drain(..) {
274            variables.push(var);
275        }
276
277        let mut instructions = Vec::new();
278
279        for inst in self.instructions.drain(..) {
280            instructions.push(inst);
281        }
282
283        variables.extend(self.allocator.take_variables());
284
285        let mut processing = ScopeProcessing {
286            variables,
287            instructions,
288            typemap: self.typemap.clone(),
289        };
290
291        for p in processors {
292            processing = p.transform(processing, self.allocator.clone());
293        }
294
295        // Add variables added from processors
296        processing.variables.extend(self.allocator.take_variables());
297
298        processing
299    }
300
301    pub fn new_local_index(&self) -> u32 {
302        self.allocator.new_local_index()
303    }
304
305    /// Create a shared array variable of the given [item type](Item).
306    pub fn create_shared_array<I: Into<Type>>(
307        &mut self,
308        item: I,
309        shared_memory_size: usize,
310        alignment: Option<usize>,
311    ) -> ExpandElement {
312        let item = item.into();
313        let index = self.new_local_index();
314        let shared_array = Variable::new(
315            VariableKind::SharedArray {
316                id: index,
317                length: shared_memory_size,
318                unroll_factor: 1,
319                alignment,
320            },
321            item,
322        );
323        self.shared.push(shared_array);
324        ExpandElement::Plain(shared_array)
325    }
326
327    /// Create a shared variable of the given [item type](Item).
328    pub fn create_shared<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
329        let item = item.into();
330        let index = self.new_local_index();
331        let shared = Variable::new(VariableKind::Shared { id: index }, item);
332        self.shared.push(shared);
333        ExpandElement::Plain(shared)
334    }
335
336    /// Create a shared variable of the given [item type](Item).
337    pub fn create_const_array<I: Into<Type>>(
338        &mut self,
339        item: I,
340        data: Vec<Variable>,
341    ) -> ExpandElement {
342        let item = item.into();
343        let index = self.new_local_index();
344        let const_array = Variable::new(
345            VariableKind::ConstantArray {
346                id: index,
347                length: data.len(),
348                unroll_factor: 1,
349            },
350            item,
351        );
352        self.const_arrays.push((const_array, data));
353        ExpandElement::Plain(const_array)
354    }
355
356    /// Obtain the index-th input
357    pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
358        ExpandElement::Plain(crate::Variable::new(
359            VariableKind::GlobalInputArray(id),
360            item,
361        ))
362    }
363
364    /// Obtain the index-th output
365    pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
366        let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
367        ExpandElement::Plain(var)
368    }
369
370    /// Obtain the index-th scalar
371    pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
372        ExpandElement::Plain(crate::Variable::new(
373            VariableKind::GlobalScalar(id),
374            Type::new(storage),
375        ))
376    }
377
378    /// Create a local array of the given [item type](Item).
379    pub fn create_local_array<I: Into<Type>>(
380        &mut self,
381        item: I,
382        array_size: usize,
383    ) -> ExpandElement {
384        let local_array = self.allocator.create_local_array(item.into(), array_size);
385        self.add_local_array(*local_array);
386        local_array
387    }
388
389    pub fn add_local_array(&mut self, var: Variable) {
390        self.local_arrays.push(var);
391    }
392
393    pub fn update_source(&mut self, source: CubeFnSource) {
394        if self.debug.enabled {
395            self.debug.sources.borrow_mut().insert(source.clone());
396            self.debug.source_loc = Some(SourceLoc {
397                line: source.line,
398                column: source.column,
399                source,
400            });
401            if self.debug.entry_loc.is_none() {
402                self.debug.entry_loc = self.debug.source_loc.clone();
403            }
404        }
405    }
406
407    pub fn update_span(&mut self, line: u32, col: u32) {
408        if let Some(loc) = self.debug.source_loc.as_mut() {
409            loc.line = line;
410            loc.column = col;
411        }
412    }
413
414    pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
415        if self.debug.enabled {
416            self.debug
417                .variable_names
418                .borrow_mut()
419                .insert(variable, name.into());
420        }
421    }
422}
423
424impl Display for Scope {
425    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
426        writeln!(f, "{{")?;
427        for instruction in self.instructions.iter() {
428            let instruction_str = instruction.to_string();
429            if !instruction_str.is_empty() {
430                writeln!(
431                    f,
432                    "{}{}",
433                    "    ".repeat(self.depth as usize + 1),
434                    instruction_str,
435                )?;
436            }
437        }
438        write!(f, "{}}}", "    ".repeat(self.depth as usize))?;
439        Ok(())
440    }
441}