cubecl_opt/
instructions.rs

1use cubecl_ir::{
2    Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3    Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9    pub fn visit_out(
10        &mut self,
11        var: &mut Option<Variable>,
12        mut visit_write: impl FnMut(&mut Self, &mut Variable),
13    ) {
14        if let Some(out) = var {
15            visit_write(self, out);
16        }
17    }
18
19    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
20    /// each read or written to variable.
21    pub fn visit_instruction(
22        &mut self,
23        inst: &mut Instruction,
24        visit_read: impl FnMut(&mut Self, &mut Variable),
25        visit_write: impl FnMut(&mut Self, &mut Variable),
26    ) {
27        self.visit_out(&mut inst.out, visit_write);
28        self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29    }
30
31    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
32    /// each read or written to variable.
33    pub fn visit_operation(
34        &mut self,
35        op: &mut Operation,
36        out: &mut Option<Variable>,
37        mut visit_read: impl FnMut(&mut Self, &mut Variable),
38    ) {
39        match op {
40            Operation::Copy(variable) => visit_read(self, variable),
41            Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42            Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43            Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45            Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47            // Sync has no outputs
48            Operation::Synchronization(_) => {}
49            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51            Operation::Branch(_) => unreachable!(),
52            Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53            Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54            Operation::NonSemantic(non_semantic) => {
55                self.visit_nonsemantic(non_semantic, visit_read)
56            }
57            Operation::Free(_) => {}
58        }
59    }
60
61    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
62    /// each read or written to variable.
63    pub fn visit_arithmetic(
64        &mut self,
65        op: &mut Arithmetic,
66        mut visit_read: impl FnMut(&mut Self, &mut Variable),
67    ) {
68        match op {
69            Arithmetic::Fma(fma_operator) => {
70                visit_read(self, &mut fma_operator.a);
71                visit_read(self, &mut fma_operator.b);
72                visit_read(self, &mut fma_operator.c);
73            }
74            Arithmetic::Add(binary_operator)
75            | Arithmetic::SaturatingAdd(binary_operator)
76            | Arithmetic::Sub(binary_operator)
77            | Arithmetic::SaturatingSub(binary_operator)
78            | Arithmetic::Mul(binary_operator)
79            | Arithmetic::Div(binary_operator)
80            | Arithmetic::Powf(binary_operator)
81            | Arithmetic::Powi(binary_operator)
82            | Arithmetic::Modulo(binary_operator)
83            | Arithmetic::Max(binary_operator)
84            | Arithmetic::Min(binary_operator)
85            | Arithmetic::Remainder(binary_operator)
86            | Arithmetic::Dot(binary_operator)
87            | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read),
88
89            Arithmetic::Abs(unary_operator)
90            | Arithmetic::Exp(unary_operator)
91            | Arithmetic::Log(unary_operator)
92            | Arithmetic::Log1p(unary_operator)
93            | Arithmetic::Cos(unary_operator)
94            | Arithmetic::Sin(unary_operator)
95            | Arithmetic::Tanh(unary_operator)
96            | Arithmetic::Sqrt(unary_operator)
97            | Arithmetic::Round(unary_operator)
98            | Arithmetic::Floor(unary_operator)
99            | Arithmetic::Ceil(unary_operator)
100            | Arithmetic::Trunc(unary_operator)
101            | Arithmetic::Erf(unary_operator)
102            | Arithmetic::Recip(unary_operator)
103            | Arithmetic::Neg(unary_operator)
104            | Arithmetic::Magnitude(unary_operator)
105            | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
106
107            Arithmetic::Clamp(clamp_operator) => {
108                visit_read(self, &mut clamp_operator.input);
109                visit_read(self, &mut clamp_operator.min_value);
110                visit_read(self, &mut clamp_operator.max_value);
111            }
112        }
113    }
114
115    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
116    /// each read or written to variable.
117    pub fn visit_compare(
118        &mut self,
119        op: &mut Comparison,
120        visit_read: impl FnMut(&mut Self, &mut Variable),
121    ) {
122        match op {
123            Comparison::Equal(binary_operator)
124            | Comparison::NotEqual(binary_operator)
125            | Comparison::LowerEqual(binary_operator)
126            | Comparison::Greater(binary_operator)
127            | Comparison::Lower(binary_operator)
128            | Comparison::GreaterEqual(binary_operator) => {
129                self.visit_binop(binary_operator, visit_read)
130            }
131            Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
132                self.visit_unop(unary_operator, visit_read)
133            }
134        }
135    }
136
137    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
138    /// each read or written to variable.
139    pub fn visit_bitwise(
140        &mut self,
141        op: &mut Bitwise,
142        visit_read: impl FnMut(&mut Self, &mut Variable),
143    ) {
144        match op {
145            Bitwise::BitwiseAnd(binary_operator)
146            | Bitwise::BitwiseOr(binary_operator)
147            | Bitwise::BitwiseXor(binary_operator)
148            | Bitwise::ShiftLeft(binary_operator)
149            | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
150
151            Bitwise::CountOnes(unary_operator)
152            | Bitwise::BitwiseNot(unary_operator)
153            | Bitwise::ReverseBits(unary_operator)
154            | Bitwise::LeadingZeros(unary_operator)
155            | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
156        }
157    }
158
159    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
160    /// each read or written to variable.
161    pub fn visit_operator(
162        &mut self,
163        op: &mut Operator,
164        mut visit_read: impl FnMut(&mut Self, &mut Variable),
165    ) {
166        match op {
167            Operator::And(binary_operator) | Operator::Or(binary_operator) => {
168                self.visit_binop(binary_operator, visit_read)
169            }
170            Operator::Not(unary_operator)
171            | Operator::Cast(unary_operator)
172            | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
173            Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
174                visit_read(self, &mut index_operator.list);
175                visit_read(self, &mut index_operator.index);
176            }
177            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
178                visit_read(self, &mut op.index);
179                visit_read(self, &mut op.value);
180            }
181            Operator::InitLine(line_init_operator) => {
182                for input in &mut line_init_operator.inputs {
183                    visit_read(self, input)
184                }
185            }
186            Operator::CopyMemory(copy_operator) => {
187                visit_read(self, &mut copy_operator.input);
188                visit_read(self, &mut copy_operator.in_index);
189                visit_read(self, &mut copy_operator.out_index);
190            }
191            Operator::CopyMemoryBulk(copy_bulk_operator) => {
192                visit_read(self, &mut copy_bulk_operator.input);
193                visit_read(self, &mut copy_bulk_operator.in_index);
194                visit_read(self, &mut copy_bulk_operator.out_index);
195            }
196            Operator::Select(select) => {
197                visit_read(self, &mut select.cond);
198                visit_read(self, &mut select.then);
199                visit_read(self, &mut select.or_else);
200            }
201        }
202    }
203
204    fn visit_atomic(
205        &mut self,
206        atomic: &mut AtomicOp,
207        out: &mut Option<Variable>,
208        mut visit_read: impl FnMut(&mut Self, &mut Variable),
209    ) {
210        match atomic {
211            AtomicOp::Add(binary_operator)
212            | AtomicOp::Sub(binary_operator)
213            | AtomicOp::Max(binary_operator)
214            | AtomicOp::Min(binary_operator)
215            | AtomicOp::And(binary_operator)
216            | AtomicOp::Or(binary_operator)
217            | AtomicOp::Xor(binary_operator)
218            | AtomicOp::Swap(binary_operator) => {
219                self.visit_binop(binary_operator, visit_read);
220            }
221            AtomicOp::Load(unary_operator) => {
222                self.visit_unop(unary_operator, visit_read);
223            }
224            AtomicOp::Store(unary_operator) => {
225                visit_read(self, out.as_mut().unwrap());
226                self.visit_unop(unary_operator, visit_read);
227            }
228            AtomicOp::CompareAndSwap(op) => {
229                visit_read(self, &mut op.cmp);
230                visit_read(self, &mut op.cmp);
231                visit_read(self, &mut op.val);
232            }
233        }
234    }
235    fn visit_meta(
236        &mut self,
237        metadata: &mut Metadata,
238        mut visit_read: impl FnMut(&mut Self, &mut Variable),
239    ) {
240        match metadata {
241            Metadata::Rank { var } => {
242                visit_read(self, var);
243            }
244            Metadata::Stride { dim, var } => {
245                visit_read(self, dim);
246                visit_read(self, var);
247            }
248            Metadata::Shape { dim, var } => {
249                visit_read(self, dim);
250                visit_read(self, var);
251            }
252            Metadata::Length { var } => {
253                visit_read(self, var);
254            }
255            Metadata::BufferLength { var } => {
256                visit_read(self, var);
257            }
258        }
259    }
260
261    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
262        match plane {
263            Plane::Elect => {}
264            Plane::Broadcast(binary_operator)
265            | Plane::Shuffle(binary_operator)
266            | Plane::ShuffleXor(binary_operator)
267            | Plane::ShuffleUp(binary_operator)
268            | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
269            Plane::All(unary_operator)
270            | Plane::Any(unary_operator)
271            | Plane::Sum(unary_operator)
272            | Plane::InclusiveSum(unary_operator)
273            | Plane::ExclusiveSum(unary_operator)
274            | Plane::Prod(unary_operator)
275            | Plane::InclusiveProd(unary_operator)
276            | Plane::ExclusiveProd(unary_operator)
277            | Plane::Min(unary_operator)
278            | Plane::Max(unary_operator)
279            | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
280        }
281    }
282
283    fn visit_cmma(
284        &mut self,
285        cmma: &mut CoopMma,
286        mut visit_read: impl FnMut(&mut Self, &mut Variable),
287    ) {
288        match cmma {
289            CoopMma::Fill { value } => {
290                visit_read(self, value);
291            }
292            CoopMma::Load {
293                value,
294                stride,
295                offset,
296                layout: _,
297            } => {
298                visit_read(self, value);
299                visit_read(self, stride);
300                visit_read(self, offset);
301            }
302            CoopMma::Execute {
303                mat_a,
304                mat_b,
305                mat_c,
306            } => {
307                visit_read(self, mat_a);
308                visit_read(self, mat_b);
309                visit_read(self, mat_c);
310            }
311            CoopMma::Store {
312                mat,
313                stride,
314                offset,
315                layout: _,
316            } => {
317                visit_read(self, mat);
318                visit_read(self, stride);
319                visit_read(self, offset);
320            }
321            CoopMma::Cast { input } => {
322                visit_read(self, input);
323            }
324            CoopMma::RowIndex { lane_id, i, .. } => {
325                visit_read(self, lane_id);
326                visit_read(self, i);
327            }
328            CoopMma::ColIndex { lane_id, i, .. } => {
329                visit_read(self, lane_id);
330                visit_read(self, i);
331            }
332            CoopMma::ExecuteManual {
333                registers_a,
334                registers_b,
335                registers_c,
336                ..
337            } => {
338                for reg in registers_a {
339                    visit_read(self, reg);
340                }
341                for reg in registers_b {
342                    visit_read(self, reg);
343                }
344                for reg in registers_c {
345                    visit_read(self, reg);
346                }
347            }
348            CoopMma::ExecuteScaled {
349                registers_a,
350                registers_b,
351                registers_c,
352                scales_a,
353                scales_b,
354                ..
355            } => {
356                for reg in registers_a {
357                    visit_read(self, reg);
358                }
359                for reg in registers_b {
360                    visit_read(self, reg);
361                }
362                for reg in registers_c {
363                    visit_read(self, reg);
364                }
365                visit_read(self, scales_a);
366                visit_read(self, scales_b);
367            }
368        }
369    }
370
371    fn visit_barrier(
372        &mut self,
373        barrier_ops: &mut BarrierOps,
374        mut visit_read: impl FnMut(&mut Self, &mut Variable),
375    ) {
376        match barrier_ops {
377            BarrierOps::Init { barrier, .. } => {
378                visit_read(self, barrier);
379            }
380            BarrierOps::MemCopyAsync {
381                barrier,
382                source,
383                source_length,
384                offset_source,
385                offset_out,
386            } => {
387                visit_read(self, barrier);
388                visit_read(self, source_length);
389                visit_read(self, source);
390                visit_read(self, offset_source);
391                visit_read(self, offset_out);
392            }
393            BarrierOps::TmaLoad {
394                barrier,
395                offset_out,
396                tensor_map,
397                indices,
398            } => {
399                visit_read(self, offset_out);
400                visit_read(self, barrier);
401                visit_read(self, tensor_map);
402                for index in indices {
403                    visit_read(self, index);
404                }
405            }
406            BarrierOps::TmaLoadIm2col {
407                barrier,
408                tensor_map,
409                indices,
410                offset_out,
411                offsets,
412            } => {
413                visit_read(self, offset_out);
414                visit_read(self, barrier);
415                visit_read(self, tensor_map);
416                for index in indices {
417                    visit_read(self, index);
418                }
419                for offset in offsets {
420                    visit_read(self, offset);
421                }
422            }
423            BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
424            BarrierOps::Arrive { barrier } => visit_read(self, barrier),
425            BarrierOps::ArriveTx {
426                barrier,
427                arrive_count_update,
428                transaction_count_update,
429            } => {
430                visit_read(self, barrier);
431                visit_read(self, arrive_count_update);
432                visit_read(self, transaction_count_update);
433            }
434            BarrierOps::ExpectTx {
435                barrier,
436                transaction_count_update,
437            } => {
438                visit_read(self, barrier);
439                visit_read(self, transaction_count_update);
440            }
441            BarrierOps::Wait { barrier } => {
442                visit_read(self, barrier);
443            }
444        }
445    }
446
447    fn visit_tma(
448        &mut self,
449        tma_ops: &mut TmaOps,
450        mut visit_read: impl FnMut(&mut Self, &mut Variable),
451    ) {
452        match tma_ops {
453            TmaOps::TmaStore {
454                source,
455                coordinates,
456                offset_source,
457            } => {
458                visit_read(self, source);
459                visit_read(self, offset_source);
460                for coord in coordinates {
461                    visit_read(self, coord)
462                }
463            }
464            TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
465        }
466    }
467
468    fn visit_nonsemantic(
469        &mut self,
470        non_semantic: &mut NonSemantic,
471        mut visit_read: impl FnMut(&mut Self, &mut Variable),
472    ) {
473        match non_semantic {
474            NonSemantic::Comment { .. }
475            | NonSemantic::EnterDebugScope
476            | NonSemantic::ExitDebugScope => {}
477            NonSemantic::Print { args, .. } => {
478                for arg in args {
479                    visit_read(self, arg);
480                }
481            }
482        }
483    }
484
485    fn visit_unop(
486        &mut self,
487        unop: &mut UnaryOperator,
488        mut visit_read: impl FnMut(&mut Self, &mut Variable),
489    ) {
490        visit_read(self, &mut unop.input);
491    }
492
493    fn visit_binop(
494        &mut self,
495        binop: &mut BinaryOperator,
496        mut visit_read: impl FnMut(&mut Self, &mut Variable),
497    ) {
498        visit_read(self, &mut binop.lhs);
499        visit_read(self, &mut binop.rhs);
500    }
501}