cubecl_core/ir/
scope.rs

1use serde::{Deserialize, Serialize};
2
3use crate::ir::ConstantScalarValue;
4
5use super::{
6    cpa, processing::ScopeProcessing, Allocator, Elem, Id, Instruction, Item, Operation, UIntKind,
7    Variable, VariableKind,
8};
9
10/// The scope is the main [operation](Operation) and [variable](Variable) container that simplify
11/// the process of reading inputs, creating local variables and adding new operations.
12///
13/// Notes:
14///
15/// This type isn't responsible for creating [shader bindings](super::Binding) and figuring out which
16/// variable can be written to.
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[allow(missing_docs)]
19pub struct Scope {
20    pub depth: u8,
21    pub operations: Vec<Instruction>,
22    pub locals: Vec<Variable>,
23    matrices: Vec<Variable>,
24    slices: Vec<Variable>,
25    shared_memories: Vec<Variable>,
26    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
27    local_arrays: Vec<Variable>,
28    reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>,
29    index_offset_with_output_layout_position: Vec<usize>,
30    writes_global: Vec<(Variable, Variable, Variable)>,
31    reads_scalar: Vec<(Variable, Variable)>,
32    pub layout_ref: Option<Variable>,
33    #[serde(skip)]
34    pub allocator: Allocator,
35}
36
37#[derive(Debug, Clone, Copy, PartialEq, Hash, Eq, Serialize, Deserialize)]
38#[allow(missing_docs)]
39pub enum ReadingStrategy {
40    /// Each element will be read in a way to be compatible with the output layout.
41    OutputLayout,
42    /// Keep the current layout.
43    Plain,
44}
45
46impl Scope {
47    /// Create a scope that is at the root of a
48    /// [kernel definition](crate::ir::KernelDefinition).
49    ///
50    /// A local scope can be created with the [child](Self::child) method.
51    pub fn root() -> Self {
52        Self {
53            depth: 0,
54            operations: Vec::new(),
55            locals: Vec::new(),
56            matrices: Vec::new(),
57            slices: Vec::new(),
58            local_arrays: Vec::new(),
59            shared_memories: Vec::new(),
60            const_arrays: Vec::new(),
61            reads_global: Vec::new(),
62            index_offset_with_output_layout_position: Vec::new(),
63            writes_global: Vec::new(),
64            reads_scalar: Vec::new(),
65            layout_ref: None,
66            allocator: Allocator::default(),
67        }
68    }
69
70    /// Create a variable initialized at zero.
71    pub fn zero<I: Into<Item>>(&mut self, item: I) -> Variable {
72        let local = self.create_local(item.into());
73        let zero: Variable = 0u32.into();
74        cpa!(self, local = zero);
75        local
76    }
77
78    /// Create a variable initialized at some value.
79    pub fn create_with_value<E, I>(&mut self, value: E, item: I) -> Variable
80    where
81        E: num_traits::ToPrimitive,
82        I: Into<Item> + Copy,
83    {
84        let item: Item = item.into();
85        let value = match item.elem() {
86            Elem::Float(kind) | Elem::AtomicFloat(kind) => {
87                ConstantScalarValue::Float(value.to_f64().unwrap(), kind)
88            }
89            Elem::Int(kind) | Elem::AtomicInt(kind) => {
90                ConstantScalarValue::Int(value.to_i64().unwrap(), kind)
91            }
92            Elem::UInt(kind) | Elem::AtomicUInt(kind) => {
93                ConstantScalarValue::UInt(value.to_u64().unwrap(), kind)
94            }
95            Elem::Bool => ConstantScalarValue::Bool(value.to_u32().unwrap() == 1),
96        };
97        let local = self.create_local(item);
98        let value = Variable::constant(value);
99        cpa!(self, local = value);
100        local
101    }
102
103    pub fn add_matrix(&mut self, variable: Variable) {
104        self.matrices.push(variable);
105    }
106
107    pub fn add_slice(&mut self, slice: Variable) {
108        self.slices.push(slice);
109    }
110
111    /// Create a mutable variable of the given [item type](Item).
112    pub fn create_local_mut<I: Into<Item>>(&mut self, item: I) -> Variable {
113        let id = self.new_local_index();
114        let local = Variable::new(VariableKind::LocalMut { id }, item.into());
115        self.add_local_mut(local);
116        local
117    }
118
119    /// Create a mutable variable of the given [item type](Item).
120    pub fn add_local_mut(&mut self, var: Variable) {
121        if !self.locals.contains(&var) {
122            self.locals.push(var);
123        }
124    }
125
126    /// Create a new restricted variable. The variable is
127    /// Useful for _for loops_ and other algorithms that require the control over initialization.
128    pub fn create_local_restricted(&mut self, item: Item) -> Variable {
129        *self.allocator.create_local_restricted(item)
130    }
131
132    /// Create a new immutable variable.
133    pub fn create_local(&mut self, item: Item) -> Variable {
134        *self.allocator.create_local(item)
135    }
136
137    /// Reads an input array to a local variable.
138    ///
139    /// The index refers to the argument position of the array in the compute shader.
140    pub fn read_array<I: Into<Item>>(
141        &mut self,
142        index: Id,
143        item: I,
144        position: Variable,
145    ) -> Variable {
146        self.read_input_strategy(index, item.into(), ReadingStrategy::OutputLayout, position)
147    }
148
149    /// Reads an input scalar to a local variable.
150    ///
151    /// The index refers to the scalar position for the same [element](Elem) type.
152    pub fn read_scalar(&mut self, index: Id, elem: Elem) -> Variable {
153        let id = self.new_local_index();
154        let local = Variable::new(VariableKind::LocalConst { id }, Item::new(elem));
155        let scalar = Variable::new(VariableKind::GlobalScalar(index), Item::new(elem));
156
157        self.reads_scalar.push((local, scalar));
158
159        local
160    }
161
162    /// Retrieve the last local variable that was created.
163    pub fn last_local_index(&self) -> Option<&Variable> {
164        self.locals.last()
165    }
166
167    /// Writes a variable to given output.
168    ///
169    /// Notes:
170    ///
171    /// This should only be used when doing compilation.
172    pub fn write_global(&mut self, input: Variable, output: Variable, position: Variable) {
173        // This assumes that all outputs have the same layout
174        if self.layout_ref.is_none() {
175            self.layout_ref = Some(output);
176        }
177        self.writes_global.push((input, output, position));
178    }
179
180    /// Writes a variable to given output.
181    ///
182    /// Notes:
183    ///
184    /// This should only be used when doing compilation.
185    pub fn write_global_custom(&mut self, output: Variable) {
186        // This assumes that all outputs have the same layout
187        if self.layout_ref.is_none() {
188            self.layout_ref = Some(output);
189        }
190    }
191
192    /// Update the [reading strategy](ReadingStrategy) for an input array.
193    ///
194    /// Notes:
195    ///
196    /// This should only be used when doing compilation.
197    pub(crate) fn update_read(&mut self, index: Id, strategy: ReadingStrategy) {
198        if let Some((_, strategy_old, _, _position)) = self
199            .reads_global
200            .iter_mut()
201            .find(|(var, _, _, _)| var.index() == Some(index))
202        {
203            *strategy_old = strategy;
204        }
205    }
206
207    #[allow(dead_code)]
208    pub fn read_globals(&self) -> Vec<(Id, ReadingStrategy)> {
209        self.reads_global
210            .iter()
211            .map(|(var, strategy, _, _)| match var.kind {
212                VariableKind::GlobalInputArray(id) => (id, *strategy),
213                _ => panic!("Can only read global input arrays."),
214            })
215            .collect()
216    }
217
218    /// Register an [operation](Operation) into the scope.
219    pub fn register<T: Into<Instruction>>(&mut self, operation: T) {
220        self.operations.push(operation.into())
221    }
222
223    /// Create an empty child scope.
224    pub fn child(&mut self) -> Self {
225        Self {
226            depth: self.depth + 1,
227            operations: Vec::new(),
228            locals: Vec::new(),
229            matrices: Vec::new(),
230            slices: Vec::new(),
231            shared_memories: Vec::new(),
232            const_arrays: Vec::new(),
233            local_arrays: Vec::new(),
234            reads_global: Vec::new(),
235            index_offset_with_output_layout_position: Vec::new(),
236            writes_global: Vec::new(),
237            reads_scalar: Vec::new(),
238            layout_ref: self.layout_ref,
239            allocator: self.allocator.clone(),
240        }
241    }
242
243    /// Returns the variables and operations to be declared and executed.
244    ///
245    /// Notes:
246    ///
247    /// New operations and variables can be created within the same scope without having name
248    /// conflicts.
249    pub fn process(&mut self) -> ScopeProcessing {
250        let mut variables = core::mem::take(&mut self.locals);
251
252        for var in self.matrices.drain(..) {
253            variables.push(var);
254        }
255        for var in self.slices.drain(..) {
256            variables.push(var);
257        }
258
259        let mut operations = Vec::new();
260
261        for (local, scalar) in self.reads_scalar.drain(..) {
262            operations.push(Instruction::new(Operation::Copy(scalar), local));
263            variables.push(local);
264        }
265
266        for op in self.operations.drain(..) {
267            operations.push(op);
268        }
269
270        ScopeProcessing {
271            variables,
272            operations,
273        }
274        .optimize()
275    }
276
277    pub fn new_local_index(&self) -> u32 {
278        self.allocator.new_local_index()
279    }
280
281    fn new_shared_index(&self) -> Id {
282        self.shared_memories.len() as Id
283    }
284
285    fn new_const_array_index(&self) -> Id {
286        self.const_arrays.len() as Id
287    }
288
289    fn read_input_strategy(
290        &mut self,
291        index: Id,
292        item: Item,
293        strategy: ReadingStrategy,
294        position: Variable,
295    ) -> Variable {
296        let item_global = match item.elem() {
297            Elem::Bool => Item {
298                elem: Elem::UInt(UIntKind::U32),
299                vectorization: item.vectorization,
300            },
301            _ => item,
302        };
303        let input = Variable::new(VariableKind::GlobalInputArray(index), item_global);
304        let id = self.new_local_index();
305        let local = Variable::new(VariableKind::LocalMut { id }, item);
306        self.reads_global.push((input, strategy, local, position));
307        self.locals.push(local);
308        local
309    }
310
311    /// Create a shared variable of the given [item type](Item).
312    pub fn create_shared<I: Into<Item>>(&mut self, item: I, shared_memory_size: u32) -> Variable {
313        let item = item.into();
314        let index = self.new_shared_index();
315        let shared_memory = Variable::new(
316            VariableKind::SharedMemory {
317                id: index,
318                length: shared_memory_size,
319            },
320            item,
321        );
322        self.shared_memories.push(shared_memory);
323        shared_memory
324    }
325
326    /// Create a shared variable of the given [item type](Item).
327    pub fn create_const_array<I: Into<Item>>(&mut self, item: I, data: Vec<Variable>) -> Variable {
328        let item = item.into();
329        let index = self.new_const_array_index();
330        let const_array = Variable::new(
331            VariableKind::ConstantArray {
332                id: index,
333                length: data.len() as u32,
334            },
335            item,
336        );
337        self.const_arrays.push((const_array, data));
338        const_array
339    }
340
341    /// Create a local array of the given [item type](Item).
342    pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
343        let local_array = self.allocator.create_local_array(item.into(), array_size);
344        self.add_local_array(*local_array);
345        *local_array
346    }
347
348    pub fn add_local_array(&mut self, var: Variable) {
349        self.local_arrays.push(var);
350    }
351}