// cubecl_ir/scope.rs

1use alloc::{borrow::Cow, boxed::Box, rc::Rc, string::ToString, vec::Vec};
2use core::{any::TypeId, cell::RefCell, fmt::Display};
3use hashbrown::{HashMap, HashSet};
4
5use crate::{BarrierLevel, CubeFnSource, ExpandElement, Matrix, Processor, SourceLoc, TypeHash};
6
7use super::{
8    Allocator, Elem, Id, Instruction, Item, Variable, VariableKind, processing::ScopeProcessing,
9};
10
11/// The scope is the main [operation](Operation) and [variable](Variable) container that simplify
12/// the process of reading inputs, creating local variables and adding new operations.
13///
14/// Notes:
15///
16/// This type isn't responsible for creating [shader bindings](super::Binding) and figuring out which
17/// variable can be written to.
18#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
19#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
20#[allow(missing_docs)]
21pub struct Scope {
22    pub depth: u8,
23    pub instructions: Vec<Instruction>,
24    pub locals: Vec<Variable>,
25    matrices: Vec<Variable>,
26    pipelines: Vec<Variable>,
27    barriers: Vec<Variable>,
28    shared_memories: Vec<Variable>,
29    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
30    local_arrays: Vec<Variable>,
31    index_offset_with_output_layout_position: Vec<usize>,
32    pub allocator: Allocator,
33    pub debug: DebugInfo,
34    #[type_hash(skip)]
35    #[cfg_attr(feature = "serde", serde(skip))]
36    pub typemap: Rc<RefCell<HashMap<TypeId, Elem>>>,
37}
38
39/// Debug related fields, most of these are global
40#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
41#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
42pub struct DebugInfo {
43    pub enabled: bool,
44    pub sources: Rc<RefCell<HashSet<CubeFnSource>>>,
45    pub variable_names: Rc<RefCell<HashMap<Variable, Cow<'static, str>>>>,
46    pub source_loc: Option<SourceLoc>,
47    pub entry_loc: Option<SourceLoc>,
48}
49
50impl core::hash::Hash for Scope {
51    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
52        self.depth.hash(ra_expand_state);
53        self.instructions.hash(ra_expand_state);
54        self.locals.hash(ra_expand_state);
55        self.matrices.hash(ra_expand_state);
56        self.pipelines.hash(ra_expand_state);
57        self.barriers.hash(ra_expand_state);
58        self.shared_memories.hash(ra_expand_state);
59        self.const_arrays.hash(ra_expand_state);
60        self.local_arrays.hash(ra_expand_state);
61        self.index_offset_with_output_layout_position
62            .hash(ra_expand_state);
63    }
64}
65
66#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash)]
68#[allow(missing_docs)]
69pub enum ReadingStrategy {
70    /// Each element will be read in a way to be compatible with the output layout.
71    OutputLayout,
72    /// Keep the current layout.
73    Plain,
74}
75
76impl Scope {
77    /// Create a scope that is at the root of a
78    /// [kernel definition](crate::ir::KernelDefinition).
79    ///
80    /// A local scope can be created with the [child](Self::child) method.
81    pub fn root(debug_enabled: bool) -> Self {
82        Self {
83            depth: 0,
84            instructions: Vec::new(),
85            locals: Vec::new(),
86            matrices: Vec::new(),
87            pipelines: Vec::new(),
88            barriers: Vec::new(),
89            local_arrays: Vec::new(),
90            shared_memories: Vec::new(),
91            const_arrays: Vec::new(),
92            index_offset_with_output_layout_position: Vec::new(),
93            allocator: Allocator::default(),
94            debug: DebugInfo {
95                enabled: debug_enabled,
96                sources: Default::default(),
97                variable_names: Default::default(),
98                source_loc: None,
99                entry_loc: None,
100            },
101            typemap: Default::default(),
102        }
103    }
104
105    /// Shift variable ids.
106    pub fn with_allocator(mut self, allocator: Allocator) -> Self {
107        self.allocator = allocator;
108        self
109    }
110
111    /// Create a new matrix element.
112    pub fn create_matrix(&mut self, matrix: Matrix) -> ExpandElement {
113        let matrix = self.allocator.create_matrix(matrix);
114        self.add_matrix(*matrix);
115        matrix
116    }
117
118    pub fn add_matrix(&mut self, variable: Variable) {
119        self.matrices.push(variable);
120    }
121
122    /// Create a new pipeline element.
123    pub fn create_pipeline(&mut self, item: Item, num_stages: u8) -> ExpandElement {
124        let pipeline = self.allocator.create_pipeline(item, num_stages);
125        self.add_pipeline(*pipeline);
126        pipeline
127    }
128
129    /// Create a new barrier element.
130    pub fn create_barrier(&mut self, item: Item, level: BarrierLevel) -> ExpandElement {
131        let barrier = self.allocator.create_barrier(item, level);
132        self.add_barrier(*barrier);
133        barrier
134    }
135
136    pub fn add_pipeline(&mut self, variable: Variable) {
137        self.pipelines.push(variable);
138    }
139
140    pub fn add_barrier(&mut self, variable: Variable) {
141        self.barriers.push(variable);
142    }
143
144    /// Create a mutable variable of the given [item type](Item).
145    pub fn create_local_mut<I: Into<Item>>(&mut self, item: I) -> ExpandElement {
146        self.allocator.create_local_mut(item.into())
147    }
148
149    /// Create a mutable variable of the given [item type](Item).
150    pub fn add_local_mut(&mut self, var: Variable) {
151        if !self.locals.contains(&var) {
152            self.locals.push(var);
153        }
154    }
155
156    /// Create a new restricted variable. The variable is
157    /// Useful for _for loops_ and other algorithms that require the control over initialization.
158    pub fn create_local_restricted(&mut self, item: Item) -> ExpandElement {
159        self.allocator.create_local_restricted(item)
160    }
161
162    /// Create a new immutable variable.
163    pub fn create_local(&mut self, item: Item) -> ExpandElement {
164        self.allocator.create_local(item)
165    }
166
167    /// Retrieve the last local variable that was created.
168    pub fn last_local_index(&self) -> Option<&Variable> {
169        self.locals.last()
170    }
171
172    /// Register an [operation](Operation) into the scope.
173    pub fn register<T: Into<Instruction>>(&mut self, instruction: T) {
174        let mut inst = instruction.into();
175        inst.source_loc = self.debug.source_loc.clone();
176        self.instructions.push(inst)
177    }
178
179    /// Resolve the element type of the given generic type.
180    pub fn resolve_elem<T: 'static>(&self) -> Option<Elem> {
181        let map = self.typemap.borrow();
182        let result = map.get(&TypeId::of::<T>());
183
184        result.cloned()
185    }
186
187    /// Register the element type for the given generic type.
188    pub fn register_elem<T: 'static>(&mut self, elem: Elem) {
189        let mut map = self.typemap.borrow_mut();
190
191        map.insert(TypeId::of::<T>(), elem);
192    }
193
194    /// Create an empty child scope.
195    pub fn child(&mut self) -> Self {
196        Self {
197            depth: self.depth + 1,
198            instructions: Vec::new(),
199            locals: Vec::new(),
200            matrices: Vec::new(),
201            pipelines: Vec::new(),
202            barriers: Vec::new(),
203            shared_memories: Vec::new(),
204            const_arrays: Vec::new(),
205            local_arrays: Vec::new(),
206            index_offset_with_output_layout_position: Vec::new(),
207            allocator: self.allocator.clone(),
208            debug: self.debug.clone(),
209            typemap: self.typemap.clone(),
210        }
211    }
212
213    /// Returns the variables and operations to be declared and executed.
214    ///
215    /// Notes:
216    ///
217    /// New operations and variables can be created within the same scope without having name
218    /// conflicts.
219    pub fn process(
220        &mut self,
221        processors: impl IntoIterator<Item = Box<dyn Processor>>,
222    ) -> ScopeProcessing {
223        let mut variables = core::mem::take(&mut self.locals);
224
225        for var in self.matrices.drain(..) {
226            variables.push(var);
227        }
228
229        let mut instructions = Vec::new();
230
231        for inst in self.instructions.drain(..) {
232            instructions.push(inst);
233        }
234
235        variables.extend(self.allocator.take_variables());
236
237        let mut processing = ScopeProcessing {
238            variables,
239            instructions,
240        }
241        .optimize();
242
243        for p in processors {
244            processing = p.transform(processing, self.allocator.clone());
245        }
246
247        processing
248    }
249
250    pub fn new_local_index(&self) -> u32 {
251        self.allocator.new_local_index()
252    }
253
254    /// Create a shared variable of the given [item type](Item).
255    pub fn create_shared<I: Into<Item>>(
256        &mut self,
257        item: I,
258        shared_memory_size: u32,
259        alignment: Option<u32>,
260    ) -> ExpandElement {
261        let item = item.into();
262        let index = self.new_local_index();
263        let shared_memory = Variable::new(
264            VariableKind::SharedMemory {
265                id: index,
266                length: shared_memory_size,
267                alignment,
268            },
269            item,
270        );
271        self.shared_memories.push(shared_memory);
272        ExpandElement::Plain(shared_memory)
273    }
274
275    /// Create a shared variable of the given [item type](Item).
276    pub fn create_const_array<I: Into<Item>>(
277        &mut self,
278        item: I,
279        data: Vec<Variable>,
280    ) -> ExpandElement {
281        let item = item.into();
282        let index = self.new_local_index();
283        let const_array = Variable::new(
284            VariableKind::ConstantArray {
285                id: index,
286                length: data.len() as u32,
287            },
288            item,
289        );
290        self.const_arrays.push((const_array, data));
291        ExpandElement::Plain(const_array)
292    }
293
294    /// Obtain the index-th input
295    pub fn input(&mut self, id: Id, item: Item) -> ExpandElement {
296        ExpandElement::Plain(crate::Variable::new(
297            VariableKind::GlobalInputArray(id),
298            item,
299        ))
300    }
301
302    /// Obtain the index-th output
303    pub fn output(&mut self, id: Id, item: Item) -> ExpandElement {
304        let var = crate::Variable::new(VariableKind::GlobalOutputArray(id), item);
305        ExpandElement::Plain(var)
306    }
307
308    /// Obtain the index-th scalar
309    pub fn scalar(&self, id: Id, elem: Elem) -> ExpandElement {
310        ExpandElement::Plain(crate::Variable::new(
311            VariableKind::GlobalScalar(id),
312            Item::new(elem),
313        ))
314    }
315
316    /// Create a local array of the given [item type](Item).
317    pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> ExpandElement {
318        let local_array = self.allocator.create_local_array(item.into(), array_size);
319        self.add_local_array(*local_array);
320        local_array
321    }
322
323    pub fn add_local_array(&mut self, var: Variable) {
324        self.local_arrays.push(var);
325    }
326
327    pub fn update_source(&mut self, source: CubeFnSource) {
328        if self.debug.enabled {
329            self.debug.sources.borrow_mut().insert(source.clone());
330            self.debug.source_loc = Some(SourceLoc {
331                line: source.line,
332                column: source.column,
333                source,
334            });
335            if self.debug.entry_loc.is_none() {
336                self.debug.entry_loc = self.debug.source_loc.clone();
337            }
338        }
339    }
340
341    pub fn update_span(&mut self, line: u32, col: u32) {
342        if let Some(loc) = self.debug.source_loc.as_mut() {
343            loc.line = line;
344            loc.column = col;
345        }
346    }
347
348    pub fn update_variable_name(&self, variable: Variable, name: impl Into<Cow<'static, str>>) {
349        if self.debug.enabled {
350            self.debug
351                .variable_names
352                .borrow_mut()
353                .insert(variable, name.into());
354        }
355    }
356}
357
358impl Display for Scope {
359    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
360        writeln!(f, "{{")?;
361        for instruction in self.instructions.iter() {
362            let instruction_str = instruction.to_string();
363            if !instruction_str.is_empty() {
364                writeln!(
365                    f,
366                    "{}{}",
367                    "    ".repeat(self.depth as usize + 1),
368                    instruction_str,
369                )?;
370            }
371        }
372        write!(f, "{}}}", "    ".repeat(self.depth as usize))?;
373        Ok(())
374    }
375}