cubecl_opt/
instructions.rs

1use cubecl_ir::{
2    Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3    Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9    pub fn visit_out(
10        &mut self,
11        var: &mut Option<Variable>,
12        mut visit_write: impl FnMut(&mut Self, &mut Variable),
13    ) {
14        if let Some(out) = var {
15            visit_write(self, out);
16        }
17    }
18
19    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
20    /// each read or written to variable.
21    pub fn visit_instruction(
22        &mut self,
23        inst: &mut Instruction,
24        visit_read: impl FnMut(&mut Self, &mut Variable),
25        visit_write: impl FnMut(&mut Self, &mut Variable),
26    ) {
27        self.visit_out(&mut inst.out, visit_write);
28        self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29    }
30
31    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
32    /// each read or written to variable.
33    pub fn visit_operation(
34        &mut self,
35        op: &mut Operation,
36        out: &mut Option<Variable>,
37        mut visit_read: impl FnMut(&mut Self, &mut Variable),
38    ) {
39        match op {
40            Operation::Copy(variable) => visit_read(self, variable),
41            Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42            Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43            Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45            Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47            // Sync has no outputs
48            Operation::Synchronization(_) => {}
49            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51            Operation::Branch(_) => unreachable!(),
52            Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53            Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54            Operation::NonSemantic(non_semantic) => {
55                self.visit_nonsemantic(non_semantic, visit_read)
56            }
57        }
58    }
59
60    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
61    /// each read or written to variable.
62    pub fn visit_arithmetic(
63        &mut self,
64        op: &mut Arithmetic,
65        mut visit_read: impl FnMut(&mut Self, &mut Variable),
66    ) {
67        match op {
68            Arithmetic::Fma(fma_operator) => {
69                visit_read(self, &mut fma_operator.a);
70                visit_read(self, &mut fma_operator.b);
71                visit_read(self, &mut fma_operator.c);
72            }
73            Arithmetic::Add(binary_operator)
74            | Arithmetic::Sub(binary_operator)
75            | Arithmetic::Mul(binary_operator)
76            | Arithmetic::Div(binary_operator)
77            | Arithmetic::Powf(binary_operator)
78            | Arithmetic::Modulo(binary_operator)
79            | Arithmetic::Max(binary_operator)
80            | Arithmetic::Min(binary_operator)
81            | Arithmetic::Remainder(binary_operator)
82            | Arithmetic::Dot(binary_operator)
83            | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read),
84
85            Arithmetic::Abs(unary_operator)
86            | Arithmetic::Exp(unary_operator)
87            | Arithmetic::Log(unary_operator)
88            | Arithmetic::Log1p(unary_operator)
89            | Arithmetic::Cos(unary_operator)
90            | Arithmetic::Sin(unary_operator)
91            | Arithmetic::Tanh(unary_operator)
92            | Arithmetic::Sqrt(unary_operator)
93            | Arithmetic::Round(unary_operator)
94            | Arithmetic::Floor(unary_operator)
95            | Arithmetic::Ceil(unary_operator)
96            | Arithmetic::Erf(unary_operator)
97            | Arithmetic::Recip(unary_operator)
98            | Arithmetic::Neg(unary_operator)
99            | Arithmetic::Magnitude(unary_operator)
100            | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
101
102            Arithmetic::Clamp(clamp_operator) => {
103                visit_read(self, &mut clamp_operator.input);
104                visit_read(self, &mut clamp_operator.min_value);
105                visit_read(self, &mut clamp_operator.max_value);
106            }
107        }
108    }
109
110    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
111    /// each read or written to variable.
112    pub fn visit_compare(
113        &mut self,
114        op: &mut Comparison,
115        visit_read: impl FnMut(&mut Self, &mut Variable),
116    ) {
117        match op {
118            Comparison::Equal(binary_operator)
119            | Comparison::NotEqual(binary_operator)
120            | Comparison::LowerEqual(binary_operator)
121            | Comparison::Greater(binary_operator)
122            | Comparison::Lower(binary_operator)
123            | Comparison::GreaterEqual(binary_operator) => {
124                self.visit_binop(binary_operator, visit_read)
125            }
126        }
127    }
128
129    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
130    /// each read or written to variable.
131    pub fn visit_bitwise(
132        &mut self,
133        op: &mut Bitwise,
134        visit_read: impl FnMut(&mut Self, &mut Variable),
135    ) {
136        match op {
137            Bitwise::BitwiseAnd(binary_operator)
138            | Bitwise::BitwiseOr(binary_operator)
139            | Bitwise::BitwiseXor(binary_operator)
140            | Bitwise::ShiftLeft(binary_operator)
141            | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
142
143            Bitwise::CountOnes(unary_operator)
144            | Bitwise::BitwiseNot(unary_operator)
145            | Bitwise::ReverseBits(unary_operator)
146            | Bitwise::LeadingZeros(unary_operator)
147            | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
148        }
149    }
150
151    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
152    /// each read or written to variable.
153    pub fn visit_operator(
154        &mut self,
155        op: &mut Operator,
156        mut visit_read: impl FnMut(&mut Self, &mut Variable),
157    ) {
158        match op {
159            Operator::And(binary_operator) | Operator::Or(binary_operator) => {
160                self.visit_binop(binary_operator, visit_read)
161            }
162            Operator::Not(unary_operator)
163            | Operator::Cast(unary_operator)
164            | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
165            Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
166                visit_read(self, &mut index_operator.list);
167                visit_read(self, &mut index_operator.index);
168            }
169            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
170                visit_read(self, &mut op.index);
171                visit_read(self, &mut op.value);
172            }
173            Operator::InitLine(line_init_operator) => {
174                for input in &mut line_init_operator.inputs {
175                    visit_read(self, input)
176                }
177            }
178            Operator::CopyMemory(copy_operator) => {
179                visit_read(self, &mut copy_operator.input);
180                visit_read(self, &mut copy_operator.in_index);
181                visit_read(self, &mut copy_operator.out_index);
182            }
183            Operator::CopyMemoryBulk(copy_bulk_operator) => {
184                visit_read(self, &mut copy_bulk_operator.input);
185                visit_read(self, &mut copy_bulk_operator.in_index);
186                visit_read(self, &mut copy_bulk_operator.out_index);
187            }
188            Operator::Select(select) => {
189                visit_read(self, &mut select.cond);
190                visit_read(self, &mut select.then);
191                visit_read(self, &mut select.or_else);
192            }
193        }
194    }
195
196    fn visit_atomic(
197        &mut self,
198        atomic: &mut AtomicOp,
199        out: &mut Option<Variable>,
200        mut visit_read: impl FnMut(&mut Self, &mut Variable),
201    ) {
202        match atomic {
203            AtomicOp::Add(binary_operator)
204            | AtomicOp::Sub(binary_operator)
205            | AtomicOp::Max(binary_operator)
206            | AtomicOp::Min(binary_operator)
207            | AtomicOp::And(binary_operator)
208            | AtomicOp::Or(binary_operator)
209            | AtomicOp::Xor(binary_operator)
210            | AtomicOp::Swap(binary_operator) => {
211                self.visit_binop(binary_operator, visit_read);
212            }
213            AtomicOp::Load(unary_operator) => {
214                self.visit_unop(unary_operator, visit_read);
215            }
216            AtomicOp::Store(unary_operator) => {
217                visit_read(self, out.as_mut().unwrap());
218                self.visit_unop(unary_operator, visit_read);
219            }
220            AtomicOp::CompareAndSwap(op) => {
221                visit_read(self, &mut op.cmp);
222                visit_read(self, &mut op.cmp);
223                visit_read(self, &mut op.val);
224            }
225        }
226    }
227    fn visit_meta(
228        &mut self,
229        metadata: &mut Metadata,
230        mut visit_read: impl FnMut(&mut Self, &mut Variable),
231    ) {
232        match metadata {
233            Metadata::Rank { var } => {
234                visit_read(self, var);
235            }
236            Metadata::Stride { dim, var } => {
237                visit_read(self, dim);
238                visit_read(self, var);
239            }
240            Metadata::Shape { dim, var } => {
241                visit_read(self, dim);
242                visit_read(self, var);
243            }
244            Metadata::Length { var } => {
245                visit_read(self, var);
246            }
247            Metadata::BufferLength { var } => {
248                visit_read(self, var);
249            }
250        }
251    }
252
253    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
254        match plane {
255            Plane::Elect => {}
256            Plane::Broadcast(binary_operator) => self.visit_binop(binary_operator, visit_read),
257            Plane::All(unary_operator)
258            | Plane::Any(unary_operator)
259            | Plane::Sum(unary_operator)
260            | Plane::InclusiveSum(unary_operator)
261            | Plane::ExclusiveSum(unary_operator)
262            | Plane::Prod(unary_operator)
263            | Plane::InclusiveProd(unary_operator)
264            | Plane::ExclusiveProd(unary_operator)
265            | Plane::Min(unary_operator)
266            | Plane::Max(unary_operator)
267            | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
268        }
269    }
270
271    fn visit_cmma(
272        &mut self,
273        cmma: &mut CoopMma,
274        mut visit_read: impl FnMut(&mut Self, &mut Variable),
275    ) {
276        match cmma {
277            CoopMma::Fill { value } => {
278                visit_read(self, value);
279            }
280            CoopMma::Load {
281                value,
282                stride,
283                offset,
284                layout: _,
285            } => {
286                visit_read(self, value);
287                visit_read(self, stride);
288                visit_read(self, offset);
289            }
290            CoopMma::Execute {
291                mat_a,
292                mat_b,
293                mat_c,
294            } => {
295                visit_read(self, mat_a);
296                visit_read(self, mat_b);
297                visit_read(self, mat_c);
298            }
299            CoopMma::Store {
300                mat,
301                stride,
302                offset,
303                layout: _,
304            } => {
305                visit_read(self, mat);
306                visit_read(self, stride);
307                visit_read(self, offset);
308            }
309            CoopMma::Cast { input } => {
310                visit_read(self, input);
311            }
312        }
313    }
314
315    fn visit_barrier(
316        &mut self,
317        barrier_ops: &mut BarrierOps,
318        mut visit_read: impl FnMut(&mut Self, &mut Variable),
319    ) {
320        match barrier_ops {
321            BarrierOps::Init { barrier, .. } => {
322                visit_read(self, barrier);
323            }
324            BarrierOps::MemCopyAsync {
325                barrier,
326                source,
327                source_length,
328                offset_source,
329                offset_out,
330            } => {
331                visit_read(self, barrier);
332                visit_read(self, source_length);
333                visit_read(self, source);
334                visit_read(self, offset_source);
335                visit_read(self, offset_out);
336            }
337            BarrierOps::TmaLoad {
338                barrier,
339                offset_out,
340                tensor_map,
341                indices,
342            } => {
343                visit_read(self, offset_out);
344                visit_read(self, barrier);
345                visit_read(self, tensor_map);
346                for index in indices {
347                    visit_read(self, index);
348                }
349            }
350            BarrierOps::TmaLoadIm2col {
351                barrier,
352                tensor_map,
353                indices,
354                offset_out,
355                offsets,
356            } => {
357                visit_read(self, offset_out);
358                visit_read(self, barrier);
359                visit_read(self, tensor_map);
360                for index in indices {
361                    visit_read(self, index);
362                }
363                for offset in offsets {
364                    visit_read(self, offset);
365                }
366            }
367            BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
368            BarrierOps::Arrive { barrier } => visit_read(self, barrier),
369            BarrierOps::ArriveTx {
370                barrier,
371                arrive_count_update,
372                transaction_count_update,
373            } => {
374                visit_read(self, barrier);
375                visit_read(self, arrive_count_update);
376                visit_read(self, transaction_count_update);
377            }
378            BarrierOps::ExpectTx {
379                barrier,
380                transaction_count_update,
381            } => {
382                visit_read(self, barrier);
383                visit_read(self, transaction_count_update);
384            }
385            BarrierOps::Wait { barrier } => {
386                visit_read(self, barrier);
387            }
388        }
389    }
390
391    fn visit_tma(
392        &mut self,
393        tma_ops: &mut TmaOps,
394        mut visit_read: impl FnMut(&mut Self, &mut Variable),
395    ) {
396        match tma_ops {
397            TmaOps::TmaStore {
398                source,
399                coordinates,
400                offset_source,
401            } => {
402                visit_read(self, source);
403                visit_read(self, offset_source);
404                for coord in coordinates {
405                    visit_read(self, coord)
406                }
407            }
408            TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
409        }
410    }
411
412    fn visit_nonsemantic(
413        &mut self,
414        non_semantic: &mut NonSemantic,
415        mut visit_read: impl FnMut(&mut Self, &mut Variable),
416    ) {
417        match non_semantic {
418            NonSemantic::Comment { .. }
419            | NonSemantic::EnterDebugScope
420            | NonSemantic::ExitDebugScope => {}
421            NonSemantic::Print { args, .. } => {
422                for arg in args {
423                    visit_read(self, arg);
424                }
425            }
426        }
427    }
428
429    fn visit_unop(
430        &mut self,
431        unop: &mut UnaryOperator,
432        mut visit_read: impl FnMut(&mut Self, &mut Variable),
433    ) {
434        visit_read(self, &mut unop.input);
435    }
436
437    fn visit_binop(
438        &mut self,
439        binop: &mut BinaryOperator,
440        mut visit_read: impl FnMut(&mut Self, &mut Variable),
441    ) {
442        visit_read(self, &mut binop.lhs);
443        visit_read(self, &mut binop.rhs);
444    }
445}