1use serde::{Deserialize, Serialize};
2
3use crate::ir::ConstantScalarValue;
4
5use super::{
6 cpa, processing::ScopeProcessing, Allocator, Elem, Id, Instruction, Item, Operation, UIntKind,
7 Variable, VariableKind,
8};
9
/// A block of IR instructions together with the variables it declares.
///
/// A scope is either the root (`Scope::root`, depth 0) or a child produced by
/// `Scope::child`, which has depth `parent.depth + 1`, inherits the parent's
/// `layout_ref` and receives a clone of the parent's allocator.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct Scope {
    /// Nesting depth: 0 for the root scope, parent depth + 1 for a child scope.
    pub depth: u8,
    /// Instructions added through `register` (and the `cpa!` helpers), in order.
    pub operations: Vec<Instruction>,
    /// Local variables recorded by `add_local_mut` / `create_local_mut` and by
    /// global-array reads; moved out by `process`.
    pub locals: Vec<Variable>,
    /// Variables added via `add_matrix`; appended to the variable list by `process`.
    matrices: Vec<Variable>,
    /// Variables added via `add_slice`; appended to the variable list by `process`.
    slices: Vec<Variable>,
    /// Shared memories created by `create_shared`; their ids are this vector's
    /// length at creation time.
    shared_memories: Vec<Variable>,
    /// Constant arrays created by `create_const_array`, each paired with its
    /// element values.
    pub const_arrays: Vec<(Variable, Vec<Variable>)>,
    /// Local arrays registered via `add_local_array` / `create_local_array`.
    local_arrays: Vec<Variable>,
    /// Pending global reads as `(input array, strategy, destination local, position)`.
    reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>,
    /// NOTE(review): not read or written by any method visible in this file —
    /// confirm whether it is still used.
    index_offset_with_output_layout_position: Vec<usize>,
    /// Recorded global writes as `(input, output, position)`.
    writes_global: Vec<(Variable, Variable, Variable)>,
    /// Pending scalar reads as `(destination local, global scalar)`; `process`
    /// turns each into a leading `Copy` instruction.
    reads_scalar: Vec<(Variable, Variable)>,
    /// First output passed to `write_global` / `write_global_custom`;
    /// inherited by child scopes.
    pub layout_ref: Option<Variable>,
    /// Variable allocator. Skipped by serde, so deserialization rebuilds it
    /// with `Allocator::default()`.
    #[serde(skip)]
    pub allocator: Allocator,
}
36
/// Strategy recorded with each global input-array read (see `Scope::read_array`
/// and `Scope::update_read`).
///
/// NOTE(review): variant order determines discriminants (used by the derived
/// `Hash`) and may be visible in serialized output; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Hash, Eq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum ReadingStrategy {
    /// Strategy used by `Scope::read_array`. Presumably the read position is
    /// mapped through the output layout — TODO confirm with the consumer of
    /// `reads_global`.
    OutputLayout,
    /// Presumably reads at the given position with no layout translation —
    /// TODO confirm.
    Plain,
}
45
46impl Scope {
47 pub fn root() -> Self {
52 Self {
53 depth: 0,
54 operations: Vec::new(),
55 locals: Vec::new(),
56 matrices: Vec::new(),
57 slices: Vec::new(),
58 local_arrays: Vec::new(),
59 shared_memories: Vec::new(),
60 const_arrays: Vec::new(),
61 reads_global: Vec::new(),
62 index_offset_with_output_layout_position: Vec::new(),
63 writes_global: Vec::new(),
64 reads_scalar: Vec::new(),
65 layout_ref: None,
66 allocator: Allocator::default(),
67 }
68 }
69
70 pub fn zero<I: Into<Item>>(&mut self, item: I) -> Variable {
72 let local = self.create_local(item.into());
73 let zero: Variable = 0u32.into();
74 cpa!(self, local = zero);
75 local
76 }
77
78 pub fn create_with_value<E, I>(&mut self, value: E, item: I) -> Variable
80 where
81 E: num_traits::ToPrimitive,
82 I: Into<Item> + Copy,
83 {
84 let item: Item = item.into();
85 let value = match item.elem() {
86 Elem::Float(kind) | Elem::AtomicFloat(kind) => {
87 ConstantScalarValue::Float(value.to_f64().unwrap(), kind)
88 }
89 Elem::Int(kind) | Elem::AtomicInt(kind) => {
90 ConstantScalarValue::Int(value.to_i64().unwrap(), kind)
91 }
92 Elem::UInt(kind) | Elem::AtomicUInt(kind) => {
93 ConstantScalarValue::UInt(value.to_u64().unwrap(), kind)
94 }
95 Elem::Bool => ConstantScalarValue::Bool(value.to_u32().unwrap() == 1),
96 };
97 let local = self.create_local(item);
98 let value = Variable::constant(value);
99 cpa!(self, local = value);
100 local
101 }
102
103 pub fn add_matrix(&mut self, variable: Variable) {
104 self.matrices.push(variable);
105 }
106
107 pub fn add_slice(&mut self, slice: Variable) {
108 self.slices.push(slice);
109 }
110
111 pub fn create_local_mut<I: Into<Item>>(&mut self, item: I) -> Variable {
113 let id = self.new_local_index();
114 let local = Variable::new(VariableKind::LocalMut { id }, item.into());
115 self.add_local_mut(local);
116 local
117 }
118
119 pub fn add_local_mut(&mut self, var: Variable) {
121 if !self.locals.contains(&var) {
122 self.locals.push(var);
123 }
124 }
125
126 pub fn create_local_restricted(&mut self, item: Item) -> Variable {
129 *self.allocator.create_local_restricted(item)
130 }
131
132 pub fn create_local(&mut self, item: Item) -> Variable {
134 *self.allocator.create_local(item)
135 }
136
137 pub fn read_array<I: Into<Item>>(
141 &mut self,
142 index: Id,
143 item: I,
144 position: Variable,
145 ) -> Variable {
146 self.read_input_strategy(index, item.into(), ReadingStrategy::OutputLayout, position)
147 }
148
149 pub fn read_scalar(&mut self, index: Id, elem: Elem) -> Variable {
153 let id = self.new_local_index();
154 let local = Variable::new(VariableKind::LocalConst { id }, Item::new(elem));
155 let scalar = Variable::new(VariableKind::GlobalScalar(index), Item::new(elem));
156
157 self.reads_scalar.push((local, scalar));
158
159 local
160 }
161
162 pub fn last_local_index(&self) -> Option<&Variable> {
164 self.locals.last()
165 }
166
167 pub fn write_global(&mut self, input: Variable, output: Variable, position: Variable) {
173 if self.layout_ref.is_none() {
175 self.layout_ref = Some(output);
176 }
177 self.writes_global.push((input, output, position));
178 }
179
180 pub fn write_global_custom(&mut self, output: Variable) {
186 if self.layout_ref.is_none() {
188 self.layout_ref = Some(output);
189 }
190 }
191
192 pub(crate) fn update_read(&mut self, index: Id, strategy: ReadingStrategy) {
198 if let Some((_, strategy_old, _, _position)) = self
199 .reads_global
200 .iter_mut()
201 .find(|(var, _, _, _)| var.index() == Some(index))
202 {
203 *strategy_old = strategy;
204 }
205 }
206
207 #[allow(dead_code)]
208 pub fn read_globals(&self) -> Vec<(Id, ReadingStrategy)> {
209 self.reads_global
210 .iter()
211 .map(|(var, strategy, _, _)| match var.kind {
212 VariableKind::GlobalInputArray(id) => (id, *strategy),
213 _ => panic!("Can only read global input arrays."),
214 })
215 .collect()
216 }
217
218 pub fn register<T: Into<Instruction>>(&mut self, operation: T) {
220 self.operations.push(operation.into())
221 }
222
223 pub fn child(&mut self) -> Self {
225 Self {
226 depth: self.depth + 1,
227 operations: Vec::new(),
228 locals: Vec::new(),
229 matrices: Vec::new(),
230 slices: Vec::new(),
231 shared_memories: Vec::new(),
232 const_arrays: Vec::new(),
233 local_arrays: Vec::new(),
234 reads_global: Vec::new(),
235 index_offset_with_output_layout_position: Vec::new(),
236 writes_global: Vec::new(),
237 reads_scalar: Vec::new(),
238 layout_ref: self.layout_ref,
239 allocator: self.allocator.clone(),
240 }
241 }
242
243 pub fn process(&mut self) -> ScopeProcessing {
250 let mut variables = core::mem::take(&mut self.locals);
251
252 for var in self.matrices.drain(..) {
253 variables.push(var);
254 }
255 for var in self.slices.drain(..) {
256 variables.push(var);
257 }
258
259 let mut operations = Vec::new();
260
261 for (local, scalar) in self.reads_scalar.drain(..) {
262 operations.push(Instruction::new(Operation::Copy(scalar), local));
263 variables.push(local);
264 }
265
266 for op in self.operations.drain(..) {
267 operations.push(op);
268 }
269
270 ScopeProcessing {
271 variables,
272 operations,
273 }
274 .optimize()
275 }
276
277 pub fn new_local_index(&self) -> u32 {
278 self.allocator.new_local_index()
279 }
280
281 fn new_shared_index(&self) -> Id {
282 self.shared_memories.len() as Id
283 }
284
285 fn new_const_array_index(&self) -> Id {
286 self.const_arrays.len() as Id
287 }
288
289 fn read_input_strategy(
290 &mut self,
291 index: Id,
292 item: Item,
293 strategy: ReadingStrategy,
294 position: Variable,
295 ) -> Variable {
296 let item_global = match item.elem() {
297 Elem::Bool => Item {
298 elem: Elem::UInt(UIntKind::U32),
299 vectorization: item.vectorization,
300 },
301 _ => item,
302 };
303 let input = Variable::new(VariableKind::GlobalInputArray(index), item_global);
304 let id = self.new_local_index();
305 let local = Variable::new(VariableKind::LocalMut { id }, item);
306 self.reads_global.push((input, strategy, local, position));
307 self.locals.push(local);
308 local
309 }
310
311 pub fn create_shared<I: Into<Item>>(&mut self, item: I, shared_memory_size: u32) -> Variable {
313 let item = item.into();
314 let index = self.new_shared_index();
315 let shared_memory = Variable::new(
316 VariableKind::SharedMemory {
317 id: index,
318 length: shared_memory_size,
319 },
320 item,
321 );
322 self.shared_memories.push(shared_memory);
323 shared_memory
324 }
325
326 pub fn create_const_array<I: Into<Item>>(&mut self, item: I, data: Vec<Variable>) -> Variable {
328 let item = item.into();
329 let index = self.new_const_array_index();
330 let const_array = Variable::new(
331 VariableKind::ConstantArray {
332 id: index,
333 length: data.len() as u32,
334 },
335 item,
336 );
337 self.const_arrays.push((const_array, data));
338 const_array
339 }
340
341 pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
343 let local_array = self.allocator.create_local_array(item.into(), array_size);
344 self.add_local_array(*local_array);
345 *local_array
346 }
347
348 pub fn add_local_array(&mut self, var: Variable) {
349 self.local_arrays.push(var);
350 }
351}