cubecl_opt/
instructions.rs

1use cubecl_ir::{
2    Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3    Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9    pub fn visit_out(
10        &mut self,
11        var: &mut Option<Variable>,
12        mut visit_write: impl FnMut(&mut Self, &mut Variable),
13    ) {
14        if let Some(out) = var {
15            visit_write(self, out);
16        }
17    }
18
19    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
20    /// each read or written to variable.
21    pub fn visit_instruction(
22        &mut self,
23        inst: &mut Instruction,
24        visit_read: impl FnMut(&mut Self, &mut Variable),
25        visit_write: impl FnMut(&mut Self, &mut Variable),
26    ) {
27        self.visit_out(&mut inst.out, visit_write);
28        self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29    }
30
31    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
32    /// each read or written to variable.
33    pub fn visit_operation(
34        &mut self,
35        op: &mut Operation,
36        out: &mut Option<Variable>,
37        mut visit_read: impl FnMut(&mut Self, &mut Variable),
38    ) {
39        match op {
40            Operation::Copy(variable) => visit_read(self, variable),
41            Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42            Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43            Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45            Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47            // Sync has no outputs
48            Operation::Synchronization(_) => {}
49            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51            Operation::Branch(_) => unreachable!(),
52            Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53            Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54            Operation::NonSemantic(non_semantic) => {
55                self.visit_nonsemantic(non_semantic, visit_read)
56            }
57            Operation::Marker(_) => {}
58        }
59    }
60
61    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
62    /// each read or written to variable.
63    pub fn visit_arithmetic(
64        &mut self,
65        op: &mut Arithmetic,
66        mut visit_read: impl FnMut(&mut Self, &mut Variable),
67    ) {
68        match op {
69            Arithmetic::Fma(fma_operator) => {
70                visit_read(self, &mut fma_operator.a);
71                visit_read(self, &mut fma_operator.b);
72                visit_read(self, &mut fma_operator.c);
73            }
74            Arithmetic::Add(binary_operator)
75            | Arithmetic::SaturatingAdd(binary_operator)
76            | Arithmetic::Sub(binary_operator)
77            | Arithmetic::SaturatingSub(binary_operator)
78            | Arithmetic::Mul(binary_operator)
79            | Arithmetic::Div(binary_operator)
80            | Arithmetic::Powf(binary_operator)
81            | Arithmetic::Powi(binary_operator)
82            | Arithmetic::Modulo(binary_operator)
83            | Arithmetic::Max(binary_operator)
84            | Arithmetic::Min(binary_operator)
85            | Arithmetic::Remainder(binary_operator)
86            | Arithmetic::Dot(binary_operator)
87            | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read),
88
89            Arithmetic::Abs(unary_operator)
90            | Arithmetic::Exp(unary_operator)
91            | Arithmetic::Log(unary_operator)
92            | Arithmetic::Log1p(unary_operator)
93            | Arithmetic::Cos(unary_operator)
94            | Arithmetic::Sin(unary_operator)
95            | Arithmetic::Tanh(unary_operator)
96            | Arithmetic::Sqrt(unary_operator)
97            | Arithmetic::InverseSqrt(unary_operator)
98            | Arithmetic::Round(unary_operator)
99            | Arithmetic::Floor(unary_operator)
100            | Arithmetic::Ceil(unary_operator)
101            | Arithmetic::Trunc(unary_operator)
102            | Arithmetic::Erf(unary_operator)
103            | Arithmetic::Recip(unary_operator)
104            | Arithmetic::Neg(unary_operator)
105            | Arithmetic::Magnitude(unary_operator)
106            | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
107
108            Arithmetic::Clamp(clamp_operator) => {
109                visit_read(self, &mut clamp_operator.input);
110                visit_read(self, &mut clamp_operator.min_value);
111                visit_read(self, &mut clamp_operator.max_value);
112            }
113        }
114    }
115
116    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
117    /// each read or written to variable.
118    pub fn visit_compare(
119        &mut self,
120        op: &mut Comparison,
121        visit_read: impl FnMut(&mut Self, &mut Variable),
122    ) {
123        match op {
124            Comparison::Equal(binary_operator)
125            | Comparison::NotEqual(binary_operator)
126            | Comparison::LowerEqual(binary_operator)
127            | Comparison::Greater(binary_operator)
128            | Comparison::Lower(binary_operator)
129            | Comparison::GreaterEqual(binary_operator) => {
130                self.visit_binop(binary_operator, visit_read)
131            }
132            Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
133                self.visit_unop(unary_operator, visit_read)
134            }
135        }
136    }
137
138    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
139    /// each read or written to variable.
140    pub fn visit_bitwise(
141        &mut self,
142        op: &mut Bitwise,
143        visit_read: impl FnMut(&mut Self, &mut Variable),
144    ) {
145        match op {
146            Bitwise::BitwiseAnd(binary_operator)
147            | Bitwise::BitwiseOr(binary_operator)
148            | Bitwise::BitwiseXor(binary_operator)
149            | Bitwise::ShiftLeft(binary_operator)
150            | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
151
152            Bitwise::CountOnes(unary_operator)
153            | Bitwise::BitwiseNot(unary_operator)
154            | Bitwise::ReverseBits(unary_operator)
155            | Bitwise::LeadingZeros(unary_operator)
156            | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
157        }
158    }
159
160    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
161    /// each read or written to variable.
162    pub fn visit_operator(
163        &mut self,
164        op: &mut Operator,
165        mut visit_read: impl FnMut(&mut Self, &mut Variable),
166    ) {
167        match op {
168            Operator::And(binary_operator) | Operator::Or(binary_operator) => {
169                self.visit_binop(binary_operator, visit_read)
170            }
171            Operator::Not(unary_operator)
172            | Operator::Cast(unary_operator)
173            | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
174            Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
175                visit_read(self, &mut index_operator.list);
176                visit_read(self, &mut index_operator.index);
177            }
178            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
179                visit_read(self, &mut op.index);
180                visit_read(self, &mut op.value);
181            }
182            Operator::InitLine(line_init_operator) => {
183                for input in &mut line_init_operator.inputs {
184                    visit_read(self, input)
185                }
186            }
187            Operator::CopyMemory(copy_operator) => {
188                visit_read(self, &mut copy_operator.input);
189                visit_read(self, &mut copy_operator.in_index);
190                visit_read(self, &mut copy_operator.out_index);
191            }
192            Operator::CopyMemoryBulk(copy_bulk_operator) => {
193                visit_read(self, &mut copy_bulk_operator.input);
194                visit_read(self, &mut copy_bulk_operator.in_index);
195                visit_read(self, &mut copy_bulk_operator.out_index);
196            }
197            Operator::Select(select) => {
198                visit_read(self, &mut select.cond);
199                visit_read(self, &mut select.then);
200                visit_read(self, &mut select.or_else);
201            }
202        }
203    }
204
205    fn visit_atomic(
206        &mut self,
207        atomic: &mut AtomicOp,
208        out: &mut Option<Variable>,
209        mut visit_read: impl FnMut(&mut Self, &mut Variable),
210    ) {
211        match atomic {
212            AtomicOp::Add(binary_operator)
213            | AtomicOp::Sub(binary_operator)
214            | AtomicOp::Max(binary_operator)
215            | AtomicOp::Min(binary_operator)
216            | AtomicOp::And(binary_operator)
217            | AtomicOp::Or(binary_operator)
218            | AtomicOp::Xor(binary_operator)
219            | AtomicOp::Swap(binary_operator) => {
220                self.visit_binop(binary_operator, visit_read);
221            }
222            AtomicOp::Load(unary_operator) => {
223                self.visit_unop(unary_operator, visit_read);
224            }
225            AtomicOp::Store(unary_operator) => {
226                visit_read(self, out.as_mut().unwrap());
227                self.visit_unop(unary_operator, visit_read);
228            }
229            AtomicOp::CompareAndSwap(op) => {
230                visit_read(self, &mut op.cmp);
231                visit_read(self, &mut op.cmp);
232                visit_read(self, &mut op.val);
233            }
234        }
235    }
236    fn visit_meta(
237        &mut self,
238        metadata: &mut Metadata,
239        mut visit_read: impl FnMut(&mut Self, &mut Variable),
240    ) {
241        match metadata {
242            Metadata::Rank { var } => {
243                visit_read(self, var);
244            }
245            Metadata::Stride { dim, var } => {
246                visit_read(self, dim);
247                visit_read(self, var);
248            }
249            Metadata::Shape { dim, var } => {
250                visit_read(self, dim);
251                visit_read(self, var);
252            }
253            Metadata::Length { var } => {
254                visit_read(self, var);
255            }
256            Metadata::BufferLength { var } => {
257                visit_read(self, var);
258            }
259        }
260    }
261
262    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
263        match plane {
264            Plane::Elect => {}
265            Plane::Broadcast(binary_operator)
266            | Plane::Shuffle(binary_operator)
267            | Plane::ShuffleXor(binary_operator)
268            | Plane::ShuffleUp(binary_operator)
269            | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
270            Plane::All(unary_operator)
271            | Plane::Any(unary_operator)
272            | Plane::Sum(unary_operator)
273            | Plane::InclusiveSum(unary_operator)
274            | Plane::ExclusiveSum(unary_operator)
275            | Plane::Prod(unary_operator)
276            | Plane::InclusiveProd(unary_operator)
277            | Plane::ExclusiveProd(unary_operator)
278            | Plane::Min(unary_operator)
279            | Plane::Max(unary_operator)
280            | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
281        }
282    }
283
284    fn visit_cmma(
285        &mut self,
286        cmma: &mut CoopMma,
287        mut visit_read: impl FnMut(&mut Self, &mut Variable),
288    ) {
289        match cmma {
290            CoopMma::Fill { value } => {
291                visit_read(self, value);
292            }
293            CoopMma::Load {
294                value,
295                stride,
296                offset,
297                layout: _,
298            } => {
299                visit_read(self, value);
300                visit_read(self, stride);
301                visit_read(self, offset);
302            }
303            CoopMma::Execute {
304                mat_a,
305                mat_b,
306                mat_c,
307            } => {
308                visit_read(self, mat_a);
309                visit_read(self, mat_b);
310                visit_read(self, mat_c);
311            }
312            CoopMma::Store {
313                mat,
314                stride,
315                offset,
316                layout: _,
317            } => {
318                visit_read(self, mat);
319                visit_read(self, stride);
320                visit_read(self, offset);
321            }
322            CoopMma::Cast { input } => {
323                visit_read(self, input);
324            }
325            CoopMma::RowIndex { lane_id, i, .. } => {
326                visit_read(self, lane_id);
327                visit_read(self, i);
328            }
329            CoopMma::ColIndex { lane_id, i, .. } => {
330                visit_read(self, lane_id);
331                visit_read(self, i);
332            }
333            CoopMma::ExecuteManual {
334                registers_a,
335                registers_b,
336                registers_c,
337                ..
338            } => {
339                for reg in registers_a {
340                    visit_read(self, reg);
341                }
342                for reg in registers_b {
343                    visit_read(self, reg);
344                }
345                for reg in registers_c {
346                    visit_read(self, reg);
347                }
348            }
349            CoopMma::ExecuteScaled {
350                registers_a,
351                registers_b,
352                registers_c,
353                scales_a,
354                scales_b,
355                ..
356            } => {
357                for reg in registers_a {
358                    visit_read(self, reg);
359                }
360                for reg in registers_b {
361                    visit_read(self, reg);
362                }
363                for reg in registers_c {
364                    visit_read(self, reg);
365                }
366                visit_read(self, scales_a);
367                visit_read(self, scales_b);
368            }
369        }
370    }
371
372    fn visit_barrier(
373        &mut self,
374        barrier_ops: &mut BarrierOps,
375        mut visit_read: impl FnMut(&mut Self, &mut Variable),
376    ) {
377        match barrier_ops {
378            BarrierOps::Declare { barrier } => visit_read(self, barrier),
379            BarrierOps::Init {
380                barrier,
381                is_elected,
382                arrival_count,
383                ..
384            } => {
385                visit_read(self, barrier);
386                visit_read(self, is_elected);
387                visit_read(self, arrival_count);
388            }
389            BarrierOps::InitManual {
390                barrier,
391                arrival_count,
392            } => {
393                visit_read(self, barrier);
394                visit_read(self, arrival_count);
395            }
396            BarrierOps::MemCopyAsync {
397                barrier,
398                source,
399                source_length,
400                offset_source,
401                offset_out,
402            } => {
403                visit_read(self, barrier);
404                visit_read(self, source_length);
405                visit_read(self, source);
406                visit_read(self, offset_source);
407                visit_read(self, offset_out);
408            }
409            BarrierOps::MemCopyAsyncCooperative {
410                barrier,
411                source,
412                source_length,
413                offset_source,
414                offset_out,
415            } => {
416                visit_read(self, barrier);
417                visit_read(self, source_length);
418                visit_read(self, source);
419                visit_read(self, offset_source);
420                visit_read(self, offset_out);
421            }
422            BarrierOps::MemCopyAsyncTx {
423                barrier,
424                source,
425                source_length,
426                offset_source,
427                offset_out,
428            } => {
429                visit_read(self, barrier);
430                visit_read(self, source_length);
431                visit_read(self, source);
432                visit_read(self, offset_source);
433                visit_read(self, offset_out);
434            }
435            BarrierOps::TmaLoad {
436                barrier,
437                offset_out,
438                tensor_map,
439                indices,
440            } => {
441                visit_read(self, offset_out);
442                visit_read(self, barrier);
443                visit_read(self, tensor_map);
444                for index in indices {
445                    visit_read(self, index);
446                }
447            }
448            BarrierOps::TmaLoadIm2col {
449                barrier,
450                tensor_map,
451                indices,
452                offset_out,
453                offsets,
454            } => {
455                visit_read(self, offset_out);
456                visit_read(self, barrier);
457                visit_read(self, tensor_map);
458                for index in indices {
459                    visit_read(self, index);
460                }
461                for offset in offsets {
462                    visit_read(self, offset);
463                }
464            }
465            BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
466            BarrierOps::Arrive { barrier } => visit_read(self, barrier),
467            BarrierOps::ArriveTx {
468                barrier,
469                arrive_count_update,
470                transaction_count_update,
471            } => {
472                visit_read(self, barrier);
473                visit_read(self, arrive_count_update);
474                visit_read(self, transaction_count_update);
475            }
476            BarrierOps::ExpectTx {
477                barrier,
478                transaction_count_update,
479            } => {
480                visit_read(self, barrier);
481                visit_read(self, transaction_count_update);
482            }
483            BarrierOps::Wait { barrier, token } => {
484                visit_read(self, barrier);
485                visit_read(self, token);
486            }
487            BarrierOps::WaitParity { barrier, phase } => {
488                visit_read(self, barrier);
489                visit_read(self, phase);
490            }
491        }
492    }
493
494    fn visit_tma(
495        &mut self,
496        tma_ops: &mut TmaOps,
497        mut visit_read: impl FnMut(&mut Self, &mut Variable),
498    ) {
499        match tma_ops {
500            TmaOps::TmaStore {
501                source,
502                coordinates,
503                offset_source,
504            } => {
505                visit_read(self, source);
506                visit_read(self, offset_source);
507                for coord in coordinates {
508                    visit_read(self, coord)
509                }
510            }
511            TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
512        }
513    }
514
515    fn visit_nonsemantic(
516        &mut self,
517        non_semantic: &mut NonSemantic,
518        mut visit_read: impl FnMut(&mut Self, &mut Variable),
519    ) {
520        match non_semantic {
521            NonSemantic::Comment { .. }
522            | NonSemantic::EnterDebugScope
523            | NonSemantic::ExitDebugScope => {}
524            NonSemantic::Print { args, .. } => {
525                for arg in args {
526                    visit_read(self, arg);
527                }
528            }
529        }
530    }
531
532    fn visit_unop(
533        &mut self,
534        unop: &mut UnaryOperator,
535        mut visit_read: impl FnMut(&mut Self, &mut Variable),
536    ) {
537        visit_read(self, &mut unop.input);
538    }
539
540    fn visit_binop(
541        &mut self,
542        binop: &mut BinaryOperator,
543        mut visit_read: impl FnMut(&mut Self, &mut Variable),
544    ) {
545        visit_read(self, &mut binop.lhs);
546        visit_read(self, &mut binop.rhs);
547    }
548}