Skip to main content

cubecl_ir/
scope.rs

1use alloc::{borrow::Cow, rc::Rc, string::String, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7    BarrierLevel, CubeFnSource, DeviceProperties, ExpandElement, FastMath, Matrix, Processor,
8    SemanticType, SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12    Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
15pub type TypeMap = Rc<RefCell<HashMap<TypeId, StorageType>>>;
16
17/// The scope is the main [`crate::Operation`] and [`crate::Variable`] container that simplify
18/// the process of reading inputs, creating local variables and adding new operations.
19///
20/// Notes:
21///
22/// This type isn't responsible for creating shader bindings and figuring out which
23/// variable can be written to.
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
26#[allow(missing_docs)]
27pub struct Scope {
28    validation_errors: ValidationErrors,
29    pub depth: u8,
30    pub instructions: Vec<Instruction>,
31    pub locals: Vec<Variable>,
32    matrices: Vec<Variable>,
33    pipelines: Vec<Variable>,
34    shared: Vec<Variable>,
35    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
36    local_arrays: Vec<Variable>,
37    index_offset_with_output_layout_position: Vec<usize>,
38    pub allocator: Allocator,
39    pub debug: DebugInfo,
40    #[type_hash(skip)]
41    #[cfg_attr(feature = "serde", serde(skip))]
42    pub typemap: TypeMap,
43    pub runtime_properties: Rc<TargetProperties>,
44    pub modes: Rc<RefCell<InstructionModes>>,
45    #[cfg_attr(feature = "serde", serde(skip))]
46    pub properties: Option<Rc<DeviceProperties>>,
47}
48
49#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
50#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
51pub struct ValidationErrors {
52    errors: Rc<RefCell<Vec<String>>>,
53}
54
55/// Debug related fields, most of these are global
56#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
57#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
58pub struct DebugInfo {
59    pub enabled: bool,
60    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
61    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
62    pub source_loc: Option<SourceLoc>,
63    pub entry_loc: Option<SourceLoc>,
64}
65
66/// Modes set and reset during expansion
67#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
68#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
69pub struct InstructionModes {
70    pub fp_math_mode: EnumSet<FastMath>,
71}
72
73impl core::hash::Hash for Scope {
74    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
75        self.depth.hash(ra_expand_state);
76        self.instructions.hash(ra_expand_state);
77        self.locals.hash(ra_expand_state);
78        self.matrices.hash(ra_expand_state);
79        self.pipelines.hash(ra_expand_state);
80        self.shared.hash(ra_expand_state);
81        self.const_arrays.hash(ra_expand_state);
82        self.local_arrays.hash(ra_expand_state);
83        self.index_offset_with_output_layout_position
84            .hash(ra_expand_state);
85    }
86}
87
88#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
89#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
90#[allow(missing_docs)]
91pub enum ReadingStrategy {
92    /// Each element will be read in a way to be compatible with the output layout.
93    OutputLayout,
94    /// Keep the current layout.
95    Plain,
96}
97
98impl Scope {
99    /// Set the device properties.
100    pub fn device_properties(&mut self, properties: &DeviceProperties) {
101        self.properties = Some(Rc::new(properties.clone()));
102    }
103    /// Create a scope that is at the root of a kernel definition.
104    ///
105    /// A local scope can be created with the [child](Self::child) method.
106    pub fn root(debug_enabled: bool) -> Self {
107        Self {
108            validation_errors: ValidationErrors {
109                errors: Rc::new(RefCell::new(Vec::new())),
110            },
111            depth: 0,
112            instructions: Vec::new(),
113            locals: Vec::new(),
114            matrices: Vec::new(),
115            pipelines: Vec::new(),
116            local_arrays: Vec::new(),
117            shared: Vec::new(),
118            const_arrays: Vec::new(),
119            index_offset_with_output_layout_position: Vec::new(),
120            allocator: Allocator::default(),
121            debug: DebugInfo {
122                enabled: debug_enabled,
123                sources: Default::default(),
124                variable_names: Default::default(),
125                source_loc: None,
126                entry_loc: None,
127            },
128            typemap: Default::default(),
129            runtime_properties: Rc::new(Default::default()),
130            modes: Default::default(),
131            properties: None,
132        }
133    }
134
135    /// Shift variable ids.
136    pub fn with_allocator(mut self, allocator: Allocator) -> Self {
137        self.allocator = allocator;
138        self
139    }
140
141    pub fn with_types(mut self, typemap: TypeMap) -> Self {
142        self.typemap = typemap;
143        self
144    }
145
146    /// Create a new matrix element.
147    pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
148        let matrix = self.allocator.create_matrix(matrix);
149        self.add_matrix(*matrix);
150        matrix
151    }
152
153    pub fn add_matrix(&mut self, variable: Variable) {
154        self.matrices.push(variable);
155    }
156
157    /// Create a new pipeline element.
158    pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
159        let pipeline = self.allocator.create_pipeline(num_stages);
160        self.add_pipeline(*pipeline);
161        pipeline
162    }
163
164    /// Create a new barrier element.
165    pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ExpandElement {
166        let token = Variable::new(
167            VariableKind::BarrierToken { id, level },
168            Type::semantic(SemanticType::BarrierToken),
169        );
170        ExpandElement::Plain(token)
171    }
172
173    pub fn add_pipeline(&mut self, variable: Variable) {
174        self.pipelines.push(variable);
175    }
176
177    /// Create a mutable variable of the given item type.
178    pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
179        self.allocator.create_local_mut(item.into())
180    }
181
182    /// Create a mutable variable of the given item type.
183    pub fn add_local_mut(&mut self, var: Variable) {
184        if !self.locals.contains(&var) {
185            self.locals.push(var);
186        }
187    }
188
189    /// Create a new restricted variable. The variable is
190    /// Useful for _for loops_ and other algorithms that require the control over initialization.
191    pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
192        self.allocator.create_local_restricted(item)
193    }
194
195    /// Create a new immutable variable.
196    pub fn create_local(&mut self, item: Type) -> ExpandElement {
197        self.allocator.create_local(item)
198    }
199
200    /// Retrieve the last local variable that was created.
201    pub fn last_local_index(&self) -> Option<&Variable> {
202        self.locals.last()
203    }
204
205    /// Register an [`Instruction`] into the scope.
206    pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
207        let mut inst = instruction.into();
208        inst.source_loc = self.debug.source_loc.clone();
209        inst.modes = *self.modes.borrow();
210        self.instructions.push(inst)
211    }
212
213    /// Resolve the element type of the given generic type.
214    pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
215        let map = self.typemap.borrow();
216        let result = map.get(&TypeId::of::<T>());
217
218        result.cloned()
219    }
220
221    /// Register the element type for the given generic type.
222    pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
223        let mut map = self.typemap.borrow_mut();
224
225        map.insert(TypeId::of::<T>(), elem);
226    }
227
228    /// Create an empty child scope.
229    pub fn child(&mut self) -> Self {
230        Self {
231            validation_errors: self.validation_errors.clone(),
232            depth: self.depth + 1,
233            instructions: Vec::new(),
234            locals: Vec::new(),
235            matrices: Vec::new(),
236            pipelines: Vec::new(),
237            shared: Vec::new(),
238            const_arrays: Vec::new(),
239            local_arrays: Vec::new(),
240            index_offset_with_output_layout_position: Vec::new(),
241            allocator: self.allocator.clone(),
242            debug: self.debug.clone(),
243            typemap: self.typemap.clone(),
244            runtime_properties: self.runtime_properties.clone(),
245            modes: self.modes.clone(),
246            properties: self.properties.clone(),
247        }
248    }
249
250    // Adds a validation error.
251    pub fn push_error(&mut self, msg: impl Into<String>) {
252        self.validation_errors.errors.borrow_mut().push(msg.into());
253    }
254
255    /// Returns all validation errors.
256    pub fn pop_errors(&mut self) -> Vec<String> {
257        self.validation_errors.errors.replace_with(|_| Vec::new())
258    }
259
260    /// Returns the variables and operations to be declared and executed.
261    ///
262    /// Notes:
263    ///
264    /// New operations and variables can be created within the same scope without having name
265    /// conflicts.
266    pub fn process<'a>(
267        &mut self,
268        processors: impl IntoIterator<Item = &'a dyn Processor>,
269    ) -> ScopeProcessing {
270        let mut variables = core::mem::take(&mut self.locals);
271
272        for var in self.matrices.drain(..) {
273            variables.push(var);
274        }
275
276        let mut instructions = Vec::new();
277
278        for inst in self.instructions.drain(..) {
279            instructions.push(inst);
280        }
281
282        variables.extend(self.allocator.take_variables());
283
284        let mut processing = ScopeProcessing {
285            variables,
286            instructions,
287            typemap: self.typemap.clone(),
288        };
289
290        for p in processors {
291            processing = p.transform(processing, self.allocator.clone());
292        }
293
294        // Add variables added from processors
295        processing.variables.extend(self.allocator.take_variables());
296
297        processing
298    }
299
300    pub fn new_local_index(&self) -> u32 {
301        self.allocator.new_local_index()
302    }
303
304    /// Create a shared array variable of the given item type.
305    pub fn create_shared_array<I: Into<Type>>(
306        &mut self,
307        item: I,
308        shared_memory_size: usize,
309        alignment: Option<usize>,
310    ) -> ExpandElement {
311        let item = item.into();
312        let index = self.new_local_index();
313        let shared_array = Variable::new(
314            VariableKind::SharedArray {
315                id: index,
316                length: shared_memory_size,
317                unroll_factor: 1,
318                alignment,
319            },
320            item,
321        );
322        self.shared.push(shared_array);
323        ExpandElement::Plain(shared_array)
324    }
325
326    /// Create a shared variable of the given item type.
327    pub fn create_shared<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
328        let item = item.into();
329        let index = self.new_local_index();
330        let shared = Variable::new(VariableKind::Shared { id: index }, item);
331        self.shared.push(shared);
332        ExpandElement::Plain(shared)
333    }
334
335    /// Create a shared variable of the given item type.
336    pub fn create_const_array<I: Into<Type>>(
337        &mut self,
338        item: I,
339        data: Vec<Variable>,
340    ) -> ExpandElement {
341        let item = item.into();
342        let index = self.new_local_index();
343        let const_array = Variable::new(
344            VariableKind::ConstantArray {
345                id: index,
346                length: data.len(),
347                unroll_factor: 1,
348            },
349            item,
350        );
351        self.const_arrays.push((const_array, data));
352        ExpandElement::Plain(const_array)
353    }
354
355    /// Obtain the index-th input
356    pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
357        ExpandElement::Plain(crate::Variable::new(
358            VariableKind::GlobalInputArray(id),
359            item,
360        ))
361    }
362
363    /// Obtain the index-th output
364    pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
365        let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
366        ExpandElement::Plain(var)
367    }
368
369    /// Obtain the index-th scalar
370    pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
371        ExpandElement::Plain(crate::Variable::new(
372            VariableKind::GlobalScalar(id),
373            Type::new(storage),
374        ))
375    }
376
377    /// Create a local array of the given item type.
378    pub fn create_local_array<I: Into<Type>>(
379        &mut self,
380        item: I,
381        array_size: usize,
382    ) -> ExpandElement {
383        let local_array = self.allocator.create_local_array(item.into(), array_size);
384        self.add_local_array(*local_array);
385        local_array
386    }
387
388    pub fn add_local_array(&mut self, var: Variable) {
389        self.local_arrays.push(var);
390    }
391
392    pub fn update_source(&mut self, source: CubeFnSource) {
393        if self.debug.enabled {
394            self.debug.sources.borrow_mut().insert(source.clone());
395            self.debug.source_loc = Some(SourceLoc {
396                line: source.line,
397                column: source.column,
398                source,
399            });
400            if self.debug.entry_loc.is_none() {
401                self.debug.entry_loc = self.debug.source_loc.clone();
402            }
403        }
404    }
405
406    pub fn update_span(&mut self, line: u32, col: u32) {
407        if let Some(loc) = self.debug.source_loc.as_mut() {
408            loc.line = line;
409            loc.column = col;
410        }
411    }
412
413    pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
414        if self.debug.enabled {
415            self.debug
416                .variable_names
417                .borrow_mut()
418                .insert(variable, name.into());
419        }
420    }
421}
422
423impl Display for Scope {
424    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
425        writeln!(f, "{{")?;
426        for instruction in self.instructions.iter() {
427            let instruction_str = instruction.to_string();
428            if !instruction_str.is_empty() {
429                writeln!(
430                    f,
431                    "{}{}",
432                    "    ".repeat(self.depth as usize + 1),
433                    instruction_str,
434                )?;
435            }
436        }
437        write!(f, "{}}}", "    ".repeat(self.depth as usize))?;
438        Ok(())
439    }
440}