cubecl_opt/
instructions.rs

1use cubecl_ir::{
2    Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3    Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9    pub fn visit_out(
10        &mut self,
11        var: &mut Option<Variable>,
12        mut visit_write: impl FnMut(&mut Self, &mut Variable),
13    ) {
14        if let Some(out) = var {
15            visit_write(self, out);
16        }
17    }
18
19    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
20    /// each read or written to variable.
21    pub fn visit_instruction(
22        &mut self,
23        inst: &mut Instruction,
24        visit_read: impl FnMut(&mut Self, &mut Variable),
25        visit_write: impl FnMut(&mut Self, &mut Variable),
26    ) {
27        self.visit_out(&mut inst.out, visit_write);
28        self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29    }
30
31    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
32    /// each read or written to variable.
33    pub fn visit_operation(
34        &mut self,
35        op: &mut Operation,
36        out: &mut Option<Variable>,
37        mut visit_read: impl FnMut(&mut Self, &mut Variable),
38    ) {
39        match op {
40            Operation::Copy(variable) => visit_read(self, variable),
41            Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42            Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43            Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45            Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47            // Sync has no outputs
48            Operation::Synchronization(_) => {}
49            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51            Operation::Branch(_) => unreachable!(),
52            Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53            Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54            Operation::NonSemantic(non_semantic) => {
55                self.visit_nonsemantic(non_semantic, visit_read)
56            }
57            Operation::Marker(_) => {}
58        }
59    }
60
61    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
62    /// each read or written to variable.
63    pub fn visit_arithmetic(
64        &mut self,
65        op: &mut Arithmetic,
66        mut visit_read: impl FnMut(&mut Self, &mut Variable),
67    ) {
68        match op {
69            Arithmetic::Fma(fma_operator) => {
70                visit_read(self, &mut fma_operator.a);
71                visit_read(self, &mut fma_operator.b);
72                visit_read(self, &mut fma_operator.c);
73            }
74            Arithmetic::Add(binary_operator)
75            | Arithmetic::SaturatingAdd(binary_operator)
76            | Arithmetic::Sub(binary_operator)
77            | Arithmetic::SaturatingSub(binary_operator)
78            | Arithmetic::Mul(binary_operator)
79            | Arithmetic::Div(binary_operator)
80            | Arithmetic::Powf(binary_operator)
81            | Arithmetic::Powi(binary_operator)
82            | Arithmetic::Hypot(binary_operator)
83            | Arithmetic::Rhypot(binary_operator)
84            | Arithmetic::Modulo(binary_operator)
85            | Arithmetic::Max(binary_operator)
86            | Arithmetic::Min(binary_operator)
87            | Arithmetic::Remainder(binary_operator)
88            | Arithmetic::Dot(binary_operator)
89            | Arithmetic::MulHi(binary_operator)
90            | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
91
92            Arithmetic::Abs(unary_operator)
93            | Arithmetic::Exp(unary_operator)
94            | Arithmetic::Log(unary_operator)
95            | Arithmetic::Log1p(unary_operator)
96            | Arithmetic::Cos(unary_operator)
97            | Arithmetic::Sin(unary_operator)
98            | Arithmetic::Tan(unary_operator)
99            | Arithmetic::Tanh(unary_operator)
100            | Arithmetic::Sinh(unary_operator)
101            | Arithmetic::Cosh(unary_operator)
102            | Arithmetic::ArcCos(unary_operator)
103            | Arithmetic::ArcSin(unary_operator)
104            | Arithmetic::ArcTan(unary_operator)
105            | Arithmetic::ArcSinh(unary_operator)
106            | Arithmetic::ArcCosh(unary_operator)
107            | Arithmetic::ArcTanh(unary_operator)
108            | Arithmetic::Degrees(unary_operator)
109            | Arithmetic::Radians(unary_operator)
110            | Arithmetic::Sqrt(unary_operator)
111            | Arithmetic::InverseSqrt(unary_operator)
112            | Arithmetic::Round(unary_operator)
113            | Arithmetic::Floor(unary_operator)
114            | Arithmetic::Ceil(unary_operator)
115            | Arithmetic::Trunc(unary_operator)
116            | Arithmetic::Erf(unary_operator)
117            | Arithmetic::Recip(unary_operator)
118            | Arithmetic::Neg(unary_operator)
119            | Arithmetic::Magnitude(unary_operator)
120            | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
121
122            Arithmetic::Clamp(clamp_operator) => {
123                visit_read(self, &mut clamp_operator.input);
124                visit_read(self, &mut clamp_operator.min_value);
125                visit_read(self, &mut clamp_operator.max_value);
126            }
127        }
128    }
129
130    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
131    /// each read or written to variable.
132    pub fn visit_compare(
133        &mut self,
134        op: &mut Comparison,
135        visit_read: impl FnMut(&mut Self, &mut Variable),
136    ) {
137        match op {
138            Comparison::Equal(binary_operator)
139            | Comparison::NotEqual(binary_operator)
140            | Comparison::LowerEqual(binary_operator)
141            | Comparison::Greater(binary_operator)
142            | Comparison::Lower(binary_operator)
143            | Comparison::GreaterEqual(binary_operator) => {
144                self.visit_binop(binary_operator, visit_read)
145            }
146            Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
147                self.visit_unop(unary_operator, visit_read)
148            }
149        }
150    }
151
152    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
153    /// each read or written to variable.
154    pub fn visit_bitwise(
155        &mut self,
156        op: &mut Bitwise,
157        visit_read: impl FnMut(&mut Self, &mut Variable),
158    ) {
159        match op {
160            Bitwise::BitwiseAnd(binary_operator)
161            | Bitwise::BitwiseOr(binary_operator)
162            | Bitwise::BitwiseXor(binary_operator)
163            | Bitwise::ShiftLeft(binary_operator)
164            | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
165
166            Bitwise::CountOnes(unary_operator)
167            | Bitwise::BitwiseNot(unary_operator)
168            | Bitwise::ReverseBits(unary_operator)
169            | Bitwise::LeadingZeros(unary_operator)
170            | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
171        }
172    }
173
174    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
175    /// each read or written to variable.
176    pub fn visit_operator(
177        &mut self,
178        op: &mut Operator,
179        mut visit_read: impl FnMut(&mut Self, &mut Variable),
180    ) {
181        match op {
182            Operator::And(binary_operator) | Operator::Or(binary_operator) => {
183                self.visit_binop(binary_operator, visit_read)
184            }
185            Operator::Not(unary_operator)
186            | Operator::Cast(unary_operator)
187            | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
188            Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
189                visit_read(self, &mut index_operator.list);
190                visit_read(self, &mut index_operator.index);
191            }
192            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
193                visit_read(self, &mut op.index);
194                visit_read(self, &mut op.value);
195            }
196            Operator::InitLine(line_init_operator) => {
197                for input in &mut line_init_operator.inputs {
198                    visit_read(self, input)
199                }
200            }
201            Operator::CopyMemory(copy_operator) => {
202                visit_read(self, &mut copy_operator.input);
203                visit_read(self, &mut copy_operator.in_index);
204                visit_read(self, &mut copy_operator.out_index);
205            }
206            Operator::CopyMemoryBulk(copy_bulk_operator) => {
207                visit_read(self, &mut copy_bulk_operator.input);
208                visit_read(self, &mut copy_bulk_operator.in_index);
209                visit_read(self, &mut copy_bulk_operator.out_index);
210            }
211            Operator::Select(select) => {
212                visit_read(self, &mut select.cond);
213                visit_read(self, &mut select.then);
214                visit_read(self, &mut select.or_else);
215            }
216        }
217    }
218
219    fn visit_atomic(
220        &mut self,
221        atomic: &mut AtomicOp,
222        out: &mut Option<Variable>,
223        mut visit_read: impl FnMut(&mut Self, &mut Variable),
224    ) {
225        match atomic {
226            AtomicOp::Add(binary_operator)
227            | AtomicOp::Sub(binary_operator)
228            | AtomicOp::Max(binary_operator)
229            | AtomicOp::Min(binary_operator)
230            | AtomicOp::And(binary_operator)
231            | AtomicOp::Or(binary_operator)
232            | AtomicOp::Xor(binary_operator)
233            | AtomicOp::Swap(binary_operator) => {
234                self.visit_binop(binary_operator, visit_read);
235            }
236            AtomicOp::Load(unary_operator) => {
237                self.visit_unop(unary_operator, visit_read);
238            }
239            AtomicOp::Store(unary_operator) => {
240                visit_read(self, out.as_mut().unwrap());
241                self.visit_unop(unary_operator, visit_read);
242            }
243            AtomicOp::CompareAndSwap(op) => {
244                visit_read(self, &mut op.cmp);
245                visit_read(self, &mut op.cmp);
246                visit_read(self, &mut op.val);
247            }
248        }
249    }
250    fn visit_meta(
251        &mut self,
252        metadata: &mut Metadata,
253        mut visit_read: impl FnMut(&mut Self, &mut Variable),
254    ) {
255        match metadata {
256            Metadata::Rank { var } => {
257                visit_read(self, var);
258            }
259            Metadata::Stride { dim, var } => {
260                visit_read(self, dim);
261                visit_read(self, var);
262            }
263            Metadata::Shape { dim, var } => {
264                visit_read(self, dim);
265                visit_read(self, var);
266            }
267            Metadata::Length { var } => {
268                visit_read(self, var);
269            }
270            Metadata::BufferLength { var } => {
271                visit_read(self, var);
272            }
273        }
274    }
275
276    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
277        match plane {
278            Plane::Elect => {}
279            Plane::Broadcast(binary_operator)
280            | Plane::Shuffle(binary_operator)
281            | Plane::ShuffleXor(binary_operator)
282            | Plane::ShuffleUp(binary_operator)
283            | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
284            Plane::All(unary_operator)
285            | Plane::Any(unary_operator)
286            | Plane::Sum(unary_operator)
287            | Plane::InclusiveSum(unary_operator)
288            | Plane::ExclusiveSum(unary_operator)
289            | Plane::Prod(unary_operator)
290            | Plane::InclusiveProd(unary_operator)
291            | Plane::ExclusiveProd(unary_operator)
292            | Plane::Min(unary_operator)
293            | Plane::Max(unary_operator)
294            | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
295        }
296    }
297
298    fn visit_cmma(
299        &mut self,
300        cmma: &mut CoopMma,
301        mut visit_read: impl FnMut(&mut Self, &mut Variable),
302    ) {
303        match cmma {
304            CoopMma::Fill { value } => {
305                visit_read(self, value);
306            }
307            CoopMma::Load {
308                value,
309                stride,
310                offset,
311                layout: _,
312            } => {
313                visit_read(self, value);
314                visit_read(self, stride);
315                visit_read(self, offset);
316            }
317            CoopMma::Execute {
318                mat_a,
319                mat_b,
320                mat_c,
321            } => {
322                visit_read(self, mat_a);
323                visit_read(self, mat_b);
324                visit_read(self, mat_c);
325            }
326            CoopMma::Store {
327                mat,
328                stride,
329                offset,
330                layout: _,
331            } => {
332                visit_read(self, mat);
333                visit_read(self, stride);
334                visit_read(self, offset);
335            }
336            CoopMma::Cast { input } => {
337                visit_read(self, input);
338            }
339            CoopMma::RowIndex { lane_id, i, .. } => {
340                visit_read(self, lane_id);
341                visit_read(self, i);
342            }
343            CoopMma::ColIndex { lane_id, i, .. } => {
344                visit_read(self, lane_id);
345                visit_read(self, i);
346            }
347            CoopMma::LoadMatrix { buffer, offset, .. } => {
348                visit_read(self, buffer);
349                visit_read(self, offset);
350            }
351            CoopMma::StoreMatrix {
352                offset, registers, ..
353            } => {
354                visit_read(self, offset);
355                visit_read(self, registers);
356            }
357            CoopMma::ExecuteManual {
358                registers_a,
359                registers_b,
360                registers_c,
361                ..
362            } => {
363                visit_read(self, registers_a);
364                visit_read(self, registers_b);
365                visit_read(self, registers_c);
366            }
367            CoopMma::ExecuteScaled {
368                registers_a,
369                registers_b,
370                registers_c,
371                scales_a,
372                scales_b,
373                ..
374            } => {
375                visit_read(self, registers_a);
376                visit_read(self, registers_b);
377                visit_read(self, registers_c);
378                visit_read(self, scales_a);
379                visit_read(self, scales_b);
380            }
381        }
382    }
383
384    fn visit_barrier(
385        &mut self,
386        barrier_ops: &mut BarrierOps,
387        mut visit_read: impl FnMut(&mut Self, &mut Variable),
388    ) {
389        match barrier_ops {
390            BarrierOps::Declare { barrier } => visit_read(self, barrier),
391            BarrierOps::Init {
392                barrier,
393                is_elected,
394                arrival_count,
395                ..
396            } => {
397                visit_read(self, barrier);
398                visit_read(self, is_elected);
399                visit_read(self, arrival_count);
400            }
401            BarrierOps::InitManual {
402                barrier,
403                arrival_count,
404            } => {
405                visit_read(self, barrier);
406                visit_read(self, arrival_count);
407            }
408            BarrierOps::MemCopyAsync {
409                barrier,
410                source,
411                source_length,
412                offset_source,
413                offset_out,
414            } => {
415                visit_read(self, barrier);
416                visit_read(self, source_length);
417                visit_read(self, source);
418                visit_read(self, offset_source);
419                visit_read(self, offset_out);
420            }
421            BarrierOps::MemCopyAsyncCooperative {
422                barrier,
423                source,
424                source_length,
425                offset_source,
426                offset_out,
427            } => {
428                visit_read(self, barrier);
429                visit_read(self, source_length);
430                visit_read(self, source);
431                visit_read(self, offset_source);
432                visit_read(self, offset_out);
433            }
434            BarrierOps::CopyAsync {
435                source,
436                source_length,
437                offset_source,
438                offset_out,
439                ..
440            } => {
441                visit_read(self, source_length);
442                visit_read(self, source);
443                visit_read(self, offset_source);
444                visit_read(self, offset_out);
445            }
446            BarrierOps::MemCopyAsyncTx {
447                barrier,
448                source,
449                source_length,
450                offset_source,
451                offset_out,
452            } => {
453                visit_read(self, barrier);
454                visit_read(self, source_length);
455                visit_read(self, source);
456                visit_read(self, offset_source);
457                visit_read(self, offset_out);
458            }
459            BarrierOps::TmaLoad {
460                barrier,
461                offset_out,
462                tensor_map,
463                indices,
464            } => {
465                visit_read(self, offset_out);
466                visit_read(self, barrier);
467                visit_read(self, tensor_map);
468                for index in indices {
469                    visit_read(self, index);
470                }
471            }
472            BarrierOps::TmaLoadIm2col {
473                barrier,
474                tensor_map,
475                indices,
476                offset_out,
477                offsets,
478            } => {
479                visit_read(self, offset_out);
480                visit_read(self, barrier);
481                visit_read(self, tensor_map);
482                for index in indices {
483                    visit_read(self, index);
484                }
485                for offset in offsets {
486                    visit_read(self, offset);
487                }
488            }
489            BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
490            BarrierOps::Arrive { barrier } => visit_read(self, barrier),
491            BarrierOps::ArriveTx {
492                barrier,
493                arrive_count_update,
494                transaction_count_update,
495            } => {
496                visit_read(self, barrier);
497                visit_read(self, arrive_count_update);
498                visit_read(self, transaction_count_update);
499            }
500            BarrierOps::CommitCopyAsync { barrier } => visit_read(self, barrier),
501            BarrierOps::ExpectTx {
502                barrier,
503                transaction_count_update,
504            } => {
505                visit_read(self, barrier);
506                visit_read(self, transaction_count_update);
507            }
508            BarrierOps::Wait { barrier, token } => {
509                visit_read(self, barrier);
510                visit_read(self, token);
511            }
512            BarrierOps::WaitParity { barrier, phase } => {
513                visit_read(self, barrier);
514                visit_read(self, phase);
515            }
516        }
517    }
518
519    fn visit_tma(
520        &mut self,
521        tma_ops: &mut TmaOps,
522        mut visit_read: impl FnMut(&mut Self, &mut Variable),
523    ) {
524        match tma_ops {
525            TmaOps::TmaStore {
526                source,
527                coordinates,
528                offset_source,
529            } => {
530                visit_read(self, source);
531                visit_read(self, offset_source);
532                for coord in coordinates {
533                    visit_read(self, coord)
534                }
535            }
536            TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
537        }
538    }
539
540    fn visit_nonsemantic(
541        &mut self,
542        non_semantic: &mut NonSemantic,
543        mut visit_read: impl FnMut(&mut Self, &mut Variable),
544    ) {
545        match non_semantic {
546            NonSemantic::Comment { .. }
547            | NonSemantic::EnterDebugScope
548            | NonSemantic::ExitDebugScope => {}
549            NonSemantic::Print { args, .. } => {
550                for arg in args {
551                    visit_read(self, arg);
552                }
553            }
554        }
555    }
556
557    fn visit_unop(
558        &mut self,
559        unop: &mut UnaryOperator,
560        mut visit_read: impl FnMut(&mut Self, &mut Variable),
561    ) {
562        visit_read(self, &mut unop.input);
563    }
564
565    fn visit_binop(
566        &mut self,
567        binop: &mut BinaryOperator,
568        mut visit_read: impl FnMut(&mut Self, &mut Variable),
569    ) {
570        visit_read(self, &mut binop.lhs);
571        visit_read(self, &mut binop.rhs);
572    }
573}