cubecl_ir/
scope.rs

1use alloc::{borrow::Cow, rc::Rc, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use enumset::EnumSet;
4use hashbrown::{HashMap, HashSet};
5
6use crate::{
7    BarrierLevel, CubeFnSource, ExpandElement, FastMath, Matrix, Processor, SemanticType,
8    SourceLoc, StorageType, TargetProperties, TypeHash,
9};
10
11use super::{
12    Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
13};
14
15/// The scope is the main [operation](Operation) and [variable](Variable) container that simplify
16/// the process of reading inputs, creating local variables and adding new operations.
17///
18/// Notes:
19///
20/// This type isn't responsible for creating [shader bindings](super::Binding) and figuring out which
21/// variable can be written to.
22#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
23#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
24#[allow(missing_docs)]
25pub struct Scope {
26    pub depth: u8,
27    pub instructions: Vec<Instruction>,
28    pub locals: Vec<Variable>,
29    matrices: Vec<Variable>,
30    pipelines: Vec<Variable>,
31    barriers: Vec<Variable>,
32    shared_memories: Vec<Variable>,
33    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
34    local_arrays: Vec<Variable>,
35    index_offset_with_output_layout_position: Vec<usize>,
36    pub allocator: Allocator,
37    pub debug: DebugInfo,
38    #[type_hash(skip)]
39    #[cfg_attr(feature = "serde", serde(skip))]
40    pub typemap: Rc<RefCell<HashMap<TypeId, StorageType>>>,
41    pub runtime_properties: Rc<TargetProperties>,
42    pub modes: Rc<RefCell<InstructionModes>>,
43}
44
45/// Debug related fields, most of these are global
46#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
47#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
48pub struct DebugInfo {
49    pub enabled: bool,
50    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
51    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
52    pub source_loc: Option<SourceLoc>,
53    pub entry_loc: Option<SourceLoc>,
54}
55
56/// Modes set and reset during expansion
57#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
58#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)]
59pub struct InstructionModes {
60    pub fp_math_mode: EnumSet<FastMath>,
61}
62
63impl core::hash::Hash for Scope {
64    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
65        self.depth.hash(ra_expand_state);
66        self.instructions.hash(ra_expand_state);
67        self.locals.hash(ra_expand_state);
68        self.matrices.hash(ra_expand_state);
69        self.pipelines.hash(ra_expand_state);
70        self.barriers.hash(ra_expand_state);
71        self.shared_memories.hash(ra_expand_state);
72        self.const_arrays.hash(ra_expand_state);
73        self.local_arrays.hash(ra_expand_state);
74        self.index_offset_with_output_layout_position
75            .hash(ra_expand_state);
76    }
77}
78
79#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
81#[allow(missing_docs)]
82pub enum ReadingStrategy {
83    /// Each element will be read in a way to be compatible with the output layout.
84    OutputLayout,
85    /// Keep the current layout.
86    Plain,
87}
88
89impl Scope {
90    /// Create a scope that is at the root of a
91    /// [kernel definition](crate::ir::KernelDefinition).
92    ///
93    /// A local scope can be created with the [child](Self::child) method.
94    pub fn root(debug_enabled: bool) -> Self {
95        Self {
96            depth: 0,
97            instructions: Vec::new(),
98            locals: Vec::new(),
99            matrices: Vec::new(),
100            pipelines: Vec::new(),
101            barriers: Vec::new(),
102            local_arrays: Vec::new(),
103            shared_memories: Vec::new(),
104            const_arrays: Vec::new(),
105            index_offset_with_output_layout_position: Vec::new(),
106            allocator: Allocator::default(),
107            debug: DebugInfo {
108                enabled: debug_enabled,
109                sources: Default::default(),
110                variable_names: Default::default(),
111                source_loc: None,
112                entry_loc: None,
113            },
114            typemap: Default::default(),
115            runtime_properties: Rc::new(Default::default()),
116            modes: Default::default(),
117        }
118    }
119
120    /// Shift variable ids.
121    pub fn with_allocator(mut self, allocator: Allocator) -> Self {
122        self.allocator = allocator;
123        self
124    }
125
126    /// Create a new matrix element.
127    pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
128        let matrix = self.allocator.create_matrix(matrix);
129        self.add_matrix(*matrix);
130        matrix
131    }
132
133    pub fn add_matrix(&mut self, variable: Variable) {
134        self.matrices.push(variable);
135    }
136
137    /// Create a new pipeline element.
138    pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
139        let pipeline = self.allocator.create_pipeline(num_stages);
140        self.add_pipeline(*pipeline);
141        pipeline
142    }
143
144    /// Create a new barrier element.
145    pub fn create_barrier(&mut self, level: BarrierLevel) -> ExpandElement {
146        let barrier = self.allocator.create_barrier(level);
147        self.add_barrier(*barrier);
148        barrier
149    }
150
151    /// Create a new barrier element.
152    pub fn create_barrier_token(&mut self, id: Id, level: BarrierLevel) -> ExpandElement {
153        let token = Variable::new(
154            VariableKind::BarrierToken { id, level },
155            Type::semantic(SemanticType::BarrierToken),
156        );
157        ExpandElement::Plain(token)
158    }
159
160    pub fn add_pipeline(&mut self, variable: Variable) {
161        self.pipelines.push(variable);
162    }
163
164    pub fn add_barrier(&mut self, variable: Variable) {
165        self.barriers.push(variable);
166    }
167
168    /// Create a mutable variable of the given [item type](Item).
169    pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
170        self.allocator.create_local_mut(item.into())
171    }
172
173    /// Create a mutable variable of the given [item type](Item).
174    pub fn add_local_mut(&mut self, var: Variable) {
175        if !self.locals.contains(&var) {
176            self.locals.push(var);
177        }
178    }
179
180    /// Create a new restricted variable. The variable is
181    /// Useful for _for loops_ and other algorithms that require the control over initialization.
182    pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
183        self.allocator.create_local_restricted(item)
184    }
185
186    /// Create a new immutable variable.
187    pub fn create_local(&mut self, item: Type) -> ExpandElement {
188        self.allocator.create_local(item)
189    }
190
191    /// Retrieve the last local variable that was created.
192    pub fn last_local_index(&self) -> Option<&Variable> {
193        self.locals.last()
194    }
195
196    /// Register an [operation](Operation) into the scope.
197    pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
198        let mut inst = instruction.into();
199        inst.source_loc = self.debug.source_loc.clone();
200        inst.modes = *self.modes.borrow();
201        self.instructions.push(inst)
202    }
203
204    /// Resolve the element type of the given generic type.
205    pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
206        let map = self.typemap.borrow();
207        let result = map.get(&TypeId::of::<T>());
208
209        result.cloned()
210    }
211
212    /// Register the element type for the given generic type.
213    pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
214        let mut map = self.typemap.borrow_mut();
215
216        map.insert(TypeId::of::<T>(), elem);
217    }
218
219    /// Create an empty child scope.
220    pub fn child(&mut self) -> Self {
221        Self {
222            depth: self.depth + 1,
223            instructions: Vec::new(),
224            locals: Vec::new(),
225            matrices: Vec::new(),
226            pipelines: Vec::new(),
227            barriers: Vec::new(),
228            shared_memories: Vec::new(),
229            const_arrays: Vec::new(),
230            local_arrays: Vec::new(),
231            index_offset_with_output_layout_position: Vec::new(),
232            allocator: self.allocator.clone(),
233            debug: self.debug.clone(),
234            typemap: self.typemap.clone(),
235            runtime_properties: self.runtime_properties.clone(),
236            modes: self.modes.clone(),
237        }
238    }
239
240    /// Returns the variables and operations to be declared and executed.
241    ///
242    /// Notes:
243    ///
244    /// New operations and variables can be created within the same scope without having name
245    /// conflicts.
246    pub fn process<'a>(
247        &mut self,
248        processors: impl IntoIterator<Item = &'a dyn Processor>,
249    ) -> ScopeProcessing {
250        let mut variables = core::mem::take(&mut self.locals);
251
252        for var in self.matrices.drain(..) {
253            variables.push(var);
254        }
255
256        let mut instructions = Vec::new();
257
258        for inst in self.instructions.drain(..) {
259            instructions.push(inst);
260        }
261
262        variables.extend(self.allocator.take_variables());
263
264        let mut processing = ScopeProcessing {
265            variables,
266            instructions,
267        }
268        .optimize();
269
270        for p in processors {
271            processing = p.transform(processing, self.allocator.clone());
272        }
273
274        // Add variables added from processors
275        processing.variables.extend(self.allocator.take_variables());
276
277        processing
278    }
279
280    pub fn new_local_index(&self) -> u32 {
281        self.allocator.new_local_index()
282    }
283
284    /// Create a shared variable of the given [item type](Item).
285    pub fn create_shared<I: Into<Type>>(
286        &mut self,
287        item: I,
288        shared_memory_size: u32,
289        alignment: Option<u32>,
290    ) -> ExpandElement {
291        let item = item.into();
292        let index = self.new_local_index();
293        let shared_memory = Variable::new(
294            VariableKind::SharedMemory {
295                id: index,
296                length: shared_memory_size,
297                unroll_factor: 1,
298                alignment,
299            },
300            item,
301        );
302        self.shared_memories.push(shared_memory);
303        ExpandElement::Plain(shared_memory)
304    }
305
306    /// Create a shared variable of the given [item type](Item).
307    pub fn create_const_array<I: Into<Type>>(
308        &mut self,
309        item: I,
310        data: Vec<Variable>,
311    ) -> ExpandElement {
312        let item = item.into();
313        let index = self.new_local_index();
314        let const_array = Variable::new(
315            VariableKind::ConstantArray {
316                id: index,
317                length: data.len() as u32,
318                unroll_factor: 1,
319            },
320            item,
321        );
322        self.const_arrays.push((const_array, data));
323        ExpandElement::Plain(const_array)
324    }
325
326    /// Obtain the index-th input
327    pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
328        ExpandElement::Plain(crate::Variable::new(
329            VariableKind::GlobalInputArray(id),
330            item,
331        ))
332    }
333
334    /// Obtain the index-th output
335    pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
336        let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
337        ExpandElement::Plain(var)
338    }
339
340    /// Obtain the index-th scalar
341    pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
342        ExpandElement::Plain(crate::Variable::new(
343            VariableKind::GlobalScalar(id),
344            Type::new(storage),
345        ))
346    }
347
348    /// Create a local array of the given [item type](Item).
349    pub fn create_local_array<I: Into<Type>>(&mut self, item: I, array_size: u32) -> ExpandElement {
350        let local_array = self.allocator.create_local_array(item.into(), array_size);
351        self.add_local_array(*local_array);
352        local_array
353    }
354
355    pub fn add_local_array(&mut self, var: Variable) {
356        self.local_arrays.push(var);
357    }
358
359    pub fn update_source(&mut self, source: CubeFnSource) {
360        if self.debug.enabled {
361            self.debug.sources.borrow_mut().insert(source.clone());
362            self.debug.source_loc = Some(SourceLoc {
363                line: source.line,
364                column: source.column,
365                source,
366            });
367            if self.debug.entry_loc.is_none() {
368                self.debug.entry_loc = self.debug.source_loc.clone();
369            }
370        }
371    }
372
373    pub fn update_span(&mut self, line: u32, col: u32) {
374        if let Some(loc) = self.debug.source_loc.as_mut() {
375            loc.line = line;
376            loc.column = col;
377        }
378    }
379
380    pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
381        if self.debug.enabled {
382            self.debug
383                .variable_names
384                .borrow_mut()
385                .insert(variable, name.into());
386        }
387    }
388}
389
390impl Display for Scope {
391    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
392        writeln!(f, "{{")?;
393        for instruction in self.instructions.iter() {
394            let instruction_str = instruction.to_string();
395            if !instruction_str.is_empty() {
396                writeln!(
397                    f,
398                    "{}{}",
399                    "    ".repeat(self.depth as usize + 1),
400                    instruction_str,
401                )?;
402            }
403        }
404        write!(f, "{}}}", "    ".repeat(self.depth as usize))?;
405        Ok(())
406    }
407}