cubecl_ir/
scope.rs

1use alloc::{borrow::Cow, rc::Rc, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use hashbrown::{HashMap, HashSet};
4
5use crate::{
6    BarrierLevel, CubeFnSource, ExpandElement, Matrix, Processor, SourceLoc, StorageType,
7    TargetProperties, TypeHash,
8};
9
10use super::{
11    Allocator, Id, Instruction, Type, Variable, VariableKind, processing::ScopeProcessing,
12};
13
/// The scope is the main [instruction](Instruction) and [variable](Variable) container that
/// simplifies the process of reading inputs, creating local variables and adding new
/// instructions.
///
/// Scopes nest: [`Scope::root`] creates the top-level scope (depth `0`) and [`Scope::child`]
/// creates a scope one level deeper that shares the type map and runtime properties of its
/// parent.
///
/// Notes:
///
/// This type isn't responsible for creating [shader bindings](super::Binding) and figuring out
/// which variable can be written to.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
#[allow(missing_docs)]
pub struct Scope {
    /// Nesting depth: `0` for a root scope, parent depth + 1 for a child scope. Also used for
    /// the indentation of the `Display` output.
    pub depth: u8,
    /// Instructions added with `register`, in registration order.
    pub instructions: Vec<Instruction>,
    /// Mutable locals registered with `add_local_mut` (which skips duplicates).
    pub locals: Vec<Variable>,
    /// Matrix variables registered with `create_matrix` / `add_matrix`.
    matrices: Vec<Variable>,
    /// Pipeline variables registered with `create_pipeline` / `add_pipeline`.
    pipelines: Vec<Variable>,
    /// Barrier variables registered with `create_barrier` / `add_barrier`.
    barriers: Vec<Variable>,
    /// Shared memory variables created with `create_shared`.
    shared_memories: Vec<Variable>,
    /// Constant arrays created with `create_const_array`, each stored with its element values.
    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
    /// Local arrays registered with `create_local_array` / `add_local_array`.
    local_arrays: Vec<Variable>,
    /// NOTE(review): apart from being initialized empty and hashed, this field is not used by
    /// any code in this file; confirm its purpose with its users elsewhere.
    index_offset_with_output_layout_position: Vec<usize>,
    /// Allocator used to create variables and fresh local indices; child scopes receive a
    /// clone of it.
    pub allocator: Allocator,
    /// Debug information; child scopes receive a clone of it.
    pub debug: DebugInfo,
    /// Storage types registered per Rust type (`TypeId`) through `register_type` and looked up
    /// by `resolve_type`. Shared through `Rc` with child scopes; excluded from `TypeHash` and
    /// from serialization.
    #[type_hash(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub typemap: Rc<RefCell<HashMap<TypeId, StorageType>>>,
    /// Target properties; a root scope starts from the default value and child scopes share
    /// their parent's through `Rc`.
    pub runtime_properties: Rc<TargetProperties>,
}
42
/// Debug related fields, most of these are global.
///
/// `Scope::child` clones this struct into the child scope. The `Rc` fields (`sources` and
/// `variable_names`) are therefore shared between a scope and the children created from it,
/// while `enabled`, `source_loc` and `entry_loc` are copied per scope.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct DebugInfo {
    /// Whether debug information is collected; `Scope::update_source` and
    /// `Scope::update_variable_name` do nothing when this is `false`.
    pub enabled: bool,
    /// Every source recorded through `Scope::update_source`.
    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
    /// Debug names given to variables through `Scope::update_variable_name`; a later name for
    /// the same variable replaces the earlier one.
    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
    /// Current source location; copied onto every instruction passed to `Scope::register`.
    pub source_loc: Option<SourceLoc>,
    /// First source location recorded by `Scope::update_source` while this field was still
    /// `None` (a child scope starts with a copy of its parent's value).
    pub entry_loc: Option<SourceLoc>,
}
53
54impl core::hash::Hash for Scope {
55    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
56        self.depth.hash(ra_expand_state);
57        self.instructions.hash(ra_expand_state);
58        self.locals.hash(ra_expand_state);
59        self.matrices.hash(ra_expand_state);
60        self.pipelines.hash(ra_expand_state);
61        self.barriers.hash(ra_expand_state);
62        self.shared_memories.hash(ra_expand_state);
63        self.const_arrays.hash(ra_expand_state);
64        self.local_arrays.hash(ra_expand_state);
65        self.index_offset_with_output_layout_position
66            .hash(ra_expand_state);
67    }
68}
69
/// Layout strategy applied when reading elements (see the variants).
///
/// NOTE(review): nothing in this file uses this enum; confirm its consumers elsewhere.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
#[allow(missing_docs)]
pub enum ReadingStrategy {
    /// Each element will be read in a way to be compatible with the output layout.
    OutputLayout,
    /// Keep the current layout.
    Plain,
}
79
80impl Scope {
81    /// Create a scope that is at the root of a
82    /// [kernel definition](crate::ir::KernelDefinition).
83    ///
84    /// A local scope can be created with the [child](Self::child) method.
85    pub fn root(debug_enabled: bool) -> Self {
86        Self {
87            depth: 0,
88            instructions: Vec::new(),
89            locals: Vec::new(),
90            matrices: Vec::new(),
91            pipelines: Vec::new(),
92            barriers: Vec::new(),
93            local_arrays: Vec::new(),
94            shared_memories: Vec::new(),
95            const_arrays: Vec::new(),
96            index_offset_with_output_layout_position: Vec::new(),
97            allocator: Allocator::default(),
98            debug: DebugInfo {
99                enabled: debug_enabled,
100                sources: Default::default(),
101                variable_names: Default::default(),
102                source_loc: None,
103                entry_loc: None,
104            },
105            typemap: Default::default(),
106            runtime_properties: Rc::new(Default::default()),
107        }
108    }
109
110    /// Shift variable ids.
111    pub fn with_allocator(mut self, allocator: Allocator) -> Self {
112        self.allocator = allocator;
113        self
114    }
115
116    /// Create a new matrix element.
117    pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
118        let matrix = self.allocator.create_matrix(matrix);
119        self.add_matrix(*matrix);
120        matrix
121    }
122
123    pub fn add_matrix(&mut self, variable: Variable) {
124        self.matrices.push(variable);
125    }
126
127    /// Create a new pipeline element.
128    pub fn create_pipeline(&mut self, num_stages: u8) -> ExpandElement {
129        let pipeline = self.allocator.create_pipeline(num_stages);
130        self.add_pipeline(*pipeline);
131        pipeline
132    }
133
134    /// Create a new barrier element.
135    pub fn create_barrier(&mut self, level: BarrierLevel) -> ExpandElement {
136        let barrier = self.allocator.create_barrier(level);
137        self.add_barrier(*barrier);
138        barrier
139    }
140
141    pub fn add_pipeline(&mut self, variable: Variable) {
142        self.pipelines.push(variable);
143    }
144
145    pub fn add_barrier(&mut self, variable: Variable) {
146        self.barriers.push(variable);
147    }
148
149    /// Create a mutable variable of the given [item type](Item).
150    pub fn create_local_mut<I: Into<Type>>(&mut self, item: I) -> ExpandElement {
151        self.allocator.create_local_mut(item.into())
152    }
153
154    /// Create a mutable variable of the given [item type](Item).
155    pub fn add_local_mut(&mut self, var: Variable) {
156        if !self.locals.contains(&var) {
157            self.locals.push(var);
158        }
159    }
160
161    /// Create a new restricted variable. The variable is
162    /// Useful for _for loops_ and other algorithms that require the control over initialization.
163    pub fn create_local_restricted(&mut self, item: Type) -> ExpandElement {
164        self.allocator.create_local_restricted(item)
165    }
166
167    /// Create a new immutable variable.
168    pub fn create_local(&mut self, item: Type) -> ExpandElement {
169        self.allocator.create_local(item)
170    }
171
172    /// Retrieve the last local variable that was created.
173    pub fn last_local_index(&self) -> Option<&Variable> {
174        self.locals.last()
175    }
176
177    /// Register an [operation](Operation) into the scope.
178    pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
179        let mut inst = instruction.into();
180        inst.source_loc = self.debug.source_loc.clone();
181        self.instructions.push(inst)
182    }
183
184    /// Resolve the element type of the given generic type.
185    pub fn resolve_type<T: 'static>(&self) -> Option<StorageType> {
186        let map = self.typemap.borrow();
187        let result = map.get(&TypeId::of::<T>());
188
189        result.cloned()
190    }
191
192    /// Register the element type for the given generic type.
193    pub fn register_type<T: 'static>(&mut self, elem: StorageType) {
194        let mut map = self.typemap.borrow_mut();
195
196        map.insert(TypeId::of::<T>(), elem);
197    }
198
199    /// Create an empty child scope.
200    pub fn child(&mut self) -> Self {
201        Self {
202            depth: self.depth + 1,
203            instructions: Vec::new(),
204            locals: Vec::new(),
205            matrices: Vec::new(),
206            pipelines: Vec::new(),
207            barriers: Vec::new(),
208            shared_memories: Vec::new(),
209            const_arrays: Vec::new(),
210            local_arrays: Vec::new(),
211            index_offset_with_output_layout_position: Vec::new(),
212            allocator: self.allocator.clone(),
213            debug: self.debug.clone(),
214            typemap: self.typemap.clone(),
215            runtime_properties: self.runtime_properties.clone(),
216        }
217    }
218
219    /// Returns the variables and operations to be declared and executed.
220    ///
221    /// Notes:
222    ///
223    /// New operations and variables can be created within the same scope without having name
224    /// conflicts.
225    pub fn process<'a>(
226        &mut self,
227        processors: impl IntoIterator<Item = &'a dyn Processor>,
228    ) -> ScopeProcessing {
229        let mut variables = core::mem::take(&mut self.locals);
230
231        for var in self.matrices.drain(..) {
232            variables.push(var);
233        }
234
235        let mut instructions = Vec::new();
236
237        for inst in self.instructions.drain(..) {
238            instructions.push(inst);
239        }
240
241        variables.extend(self.allocator.take_variables());
242
243        let mut processing = ScopeProcessing {
244            variables,
245            instructions,
246        }
247        .optimize();
248
249        for p in processors {
250            processing = p.transform(processing, self.allocator.clone());
251        }
252
253        // Add variables added from processors
254        processing.variables.extend(self.allocator.take_variables());
255
256        processing
257    }
258
259    pub fn new_local_index(&self) -> u32 {
260        self.allocator.new_local_index()
261    }
262
263    /// Create a shared variable of the given [item type](Item).
264    pub fn create_shared<I: Into<Type>>(
265        &mut self,
266        item: I,
267        shared_memory_size: u32,
268        alignment: Option<u32>,
269    ) -> ExpandElement {
270        let item = item.into();
271        let index = self.new_local_index();
272        let shared_memory = Variable::new(
273            VariableKind::SharedMemory {
274                id: index,
275                length: shared_memory_size,
276                unroll_factor: 1,
277                alignment,
278            },
279            item,
280        );
281        self.shared_memories.push(shared_memory);
282        ExpandElement::Plain(shared_memory)
283    }
284
285    /// Create a shared variable of the given [item type](Item).
286    pub fn create_const_array<I: Into<Type>>(
287        &mut self,
288        item: I,
289        data: Vec<Variable>,
290    ) -> ExpandElement {
291        let item = item.into();
292        let index = self.new_local_index();
293        let const_array = Variable::new(
294            VariableKind::ConstantArray {
295                id: index,
296                length: data.len() as u32,
297                unroll_factor: 1,
298            },
299            item,
300        );
301        self.const_arrays.push((const_array, data));
302        ExpandElement::Plain(const_array)
303    }
304
305    /// Obtain the index-th input
306    pub fn input(&mut self, id: Id, item: Type) -> ExpandElement {
307        ExpandElement::Plain(crate::Variable::new(
308            VariableKind::GlobalInputArray(id),
309            item,
310        ))
311    }
312
313    /// Obtain the index-th output
314    pub fn output(&mut self, id: Id, item: Type) -> ExpandElement {
315        let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
316        ExpandElement::Plain(var)
317    }
318
319    /// Obtain the index-th scalar
320    pub fn scalar(&self, id: Id, storage: StorageType) -> ExpandElement {
321        ExpandElement::Plain(crate::Variable::new(
322            VariableKind::GlobalScalar(id),
323            Type::new(storage),
324        ))
325    }
326
327    /// Create a local array of the given [item type](Item).
328    pub fn create_local_array<I: Into<Type>>(&mut self, item: I, array_size: u32) -> ExpandElement {
329        let local_array = self.allocator.create_local_array(item.into(), array_size);
330        self.add_local_array(*local_array);
331        local_array
332    }
333
334    pub fn add_local_array(&mut self, var: Variable) {
335        self.local_arrays.push(var);
336    }
337
338    pub fn update_source(&mut self, source: CubeFnSource) {
339        if self.debug.enabled {
340            self.debug.sources.borrow_mut().insert(source.clone());
341            self.debug.source_loc = Some(SourceLoc {
342                line: source.line,
343                column: source.column,
344                source,
345            });
346            if self.debug.entry_loc.is_none() {
347                self.debug.entry_loc = self.debug.source_loc.clone();
348            }
349        }
350    }
351
352    pub fn update_span(&mut self, line: u32, col: u32) {
353        if let Some(loc) = self.debug.source_loc.as_mut() {
354            loc.line = line;
355            loc.column = col;
356        }
357    }
358
359    pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
360        if self.debug.enabled {
361            self.debug
362                .variable_names
363                .borrow_mut()
364                .insert(variable, name.into());
365        }
366    }
367}
368
369impl Display for Scope {
370    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
371        writeln!(f, "{{")?;
372        for instruction in self.instructions.iter() {
373            let instruction_str = instruction.to_string();
374            if !instruction_str.is_empty() {
375                writeln!(
376                    f,
377                    "{}{}",
378                    "    ".repeat(self.depth as usize + 1),
379                    instruction_str,
380                )?;
381            }
382        }
383        write!(f, "{}}}", "    ".repeat(self.depth as usize))?;
384        Ok(())
385    }
386}