Skip to main content

cubecl_opt/
instructions.rs

1use cubecl_ir::{
2    Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3    Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use crate::ControlFlow;
7
8use super::Optimizer;
9
10impl Optimizer {
11    pub fn visit_out(
12        &mut self,
13        var: &mut Option<Variable>,
14        mut visit_write: impl FnMut(&mut Self, &mut Variable),
15    ) {
16        if let Some(out) = var {
17            visit_write(self, out);
18        }
19    }
20
21    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
22    /// each read or written to variable.
23    pub fn visit_instruction(
24        &mut self,
25        inst: &mut Instruction,
26        visit_read: impl FnMut(&mut Self, &mut Variable),
27        visit_write: impl FnMut(&mut Self, &mut Variable),
28    ) {
29        self.visit_out(&mut inst.out, visit_write);
30        self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
31    }
32
33    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
34    /// each read or written to variable.
35    pub fn visit_operation(
36        &mut self,
37        op: &mut Operation,
38        out: &mut Option<Variable>,
39        mut visit_read: impl FnMut(&mut Self, &mut Variable),
40    ) {
41        match op {
42            Operation::Copy(variable) => visit_read(self, variable),
43            Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
44            Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
45            Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
46            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
47            Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
48            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
49            // Sync has no outputs
50            Operation::Synchronization(_) => {}
51            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
52            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
53            Operation::Branch(_) => unreachable!(),
54            Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
55            Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
56            Operation::NonSemantic(non_semantic) => {
57                self.visit_nonsemantic(non_semantic, visit_read)
58            }
59            Operation::Marker(_) => {}
60        }
61    }
62
63    /// Visit a control flow finisher with a set of read and write visitors. Each visitor will be called with
64    /// each read or written to variable.
65    pub fn visit_control_flow(
66        &mut self,
67        op: &mut ControlFlow,
68        mut visit_read: impl FnMut(&mut Self, &mut Variable),
69    ) {
70        match op {
71            ControlFlow::IfElse { cond, .. } => visit_read(self, cond),
72            ControlFlow::Switch { value, .. } => visit_read(self, value),
73            ControlFlow::Loop { .. } => {}
74            ControlFlow::LoopBreak { break_cond, .. } => visit_read(self, break_cond),
75            ControlFlow::Return | ControlFlow::Unreachable | ControlFlow::None => {}
76        }
77    }
78
79    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
80    /// each read or written to variable.
81    pub fn visit_arithmetic(
82        &mut self,
83        op: &mut Arithmetic,
84        mut visit_read: impl FnMut(&mut Self, &mut Variable),
85    ) {
86        match op {
87            Arithmetic::Fma(fma_operator) => {
88                visit_read(self, &mut fma_operator.a);
89                visit_read(self, &mut fma_operator.b);
90                visit_read(self, &mut fma_operator.c);
91            }
92            Arithmetic::Add(binary_operator)
93            | Arithmetic::SaturatingAdd(binary_operator)
94            | Arithmetic::Sub(binary_operator)
95            | Arithmetic::SaturatingSub(binary_operator)
96            | Arithmetic::Mul(binary_operator)
97            | Arithmetic::Div(binary_operator)
98            | Arithmetic::Powf(binary_operator)
99            | Arithmetic::Powi(binary_operator)
100            | Arithmetic::Hypot(binary_operator)
101            | Arithmetic::Rhypot(binary_operator)
102            | Arithmetic::Modulo(binary_operator)
103            | Arithmetic::Max(binary_operator)
104            | Arithmetic::Min(binary_operator)
105            | Arithmetic::Remainder(binary_operator)
106            | Arithmetic::Dot(binary_operator)
107            | Arithmetic::MulHi(binary_operator)
108            | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
109
110            Arithmetic::Abs(unary_operator)
111            | Arithmetic::Exp(unary_operator)
112            | Arithmetic::Log(unary_operator)
113            | Arithmetic::Log1p(unary_operator)
114            | Arithmetic::Cos(unary_operator)
115            | Arithmetic::Sin(unary_operator)
116            | Arithmetic::Tan(unary_operator)
117            | Arithmetic::Tanh(unary_operator)
118            | Arithmetic::Sinh(unary_operator)
119            | Arithmetic::Cosh(unary_operator)
120            | Arithmetic::ArcCos(unary_operator)
121            | Arithmetic::ArcSin(unary_operator)
122            | Arithmetic::ArcTan(unary_operator)
123            | Arithmetic::ArcSinh(unary_operator)
124            | Arithmetic::ArcCosh(unary_operator)
125            | Arithmetic::ArcTanh(unary_operator)
126            | Arithmetic::Degrees(unary_operator)
127            | Arithmetic::Radians(unary_operator)
128            | Arithmetic::Sqrt(unary_operator)
129            | Arithmetic::InverseSqrt(unary_operator)
130            | Arithmetic::Round(unary_operator)
131            | Arithmetic::Floor(unary_operator)
132            | Arithmetic::Ceil(unary_operator)
133            | Arithmetic::Trunc(unary_operator)
134            | Arithmetic::Erf(unary_operator)
135            | Arithmetic::Recip(unary_operator)
136            | Arithmetic::Neg(unary_operator)
137            | Arithmetic::Magnitude(unary_operator)
138            | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
139
140            Arithmetic::Clamp(clamp_operator) => {
141                visit_read(self, &mut clamp_operator.input);
142                visit_read(self, &mut clamp_operator.min_value);
143                visit_read(self, &mut clamp_operator.max_value);
144            }
145        }
146    }
147
148    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
149    /// each read or written to variable.
150    pub fn visit_compare(
151        &mut self,
152        op: &mut Comparison,
153        visit_read: impl FnMut(&mut Self, &mut Variable),
154    ) {
155        match op {
156            Comparison::Equal(binary_operator)
157            | Comparison::NotEqual(binary_operator)
158            | Comparison::LowerEqual(binary_operator)
159            | Comparison::Greater(binary_operator)
160            | Comparison::Lower(binary_operator)
161            | Comparison::GreaterEqual(binary_operator) => {
162                self.visit_binop(binary_operator, visit_read)
163            }
164            Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
165                self.visit_unop(unary_operator, visit_read)
166            }
167        }
168    }
169
170    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
171    /// each read or written to variable.
172    pub fn visit_bitwise(
173        &mut self,
174        op: &mut Bitwise,
175        visit_read: impl FnMut(&mut Self, &mut Variable),
176    ) {
177        match op {
178            Bitwise::BitwiseAnd(binary_operator)
179            | Bitwise::BitwiseOr(binary_operator)
180            | Bitwise::BitwiseXor(binary_operator)
181            | Bitwise::ShiftLeft(binary_operator)
182            | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
183
184            Bitwise::CountOnes(unary_operator)
185            | Bitwise::BitwiseNot(unary_operator)
186            | Bitwise::ReverseBits(unary_operator)
187            | Bitwise::LeadingZeros(unary_operator)
188            | Bitwise::TrailingZeros(unary_operator)
189            | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
190        }
191    }
192
193    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
194    /// each read or written to variable.
195    pub fn visit_operator(
196        &mut self,
197        op: &mut Operator,
198        mut visit_read: impl FnMut(&mut Self, &mut Variable),
199    ) {
200        match op {
201            Operator::And(binary_operator) | Operator::Or(binary_operator) => {
202                self.visit_binop(binary_operator, visit_read)
203            }
204            Operator::Not(unary_operator)
205            | Operator::Cast(unary_operator)
206            | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
207            Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
208                visit_read(self, &mut index_operator.list);
209                visit_read(self, &mut index_operator.index);
210            }
211            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
212                visit_read(self, &mut op.index);
213                visit_read(self, &mut op.value);
214            }
215            Operator::InitVector(vector_init_operator) => {
216                for input in &mut vector_init_operator.inputs {
217                    visit_read(self, input)
218                }
219            }
220            Operator::CopyMemory(copy_operator) => {
221                visit_read(self, &mut copy_operator.input);
222                visit_read(self, &mut copy_operator.in_index);
223                visit_read(self, &mut copy_operator.out_index);
224            }
225            Operator::CopyMemoryBulk(copy_bulk_operator) => {
226                visit_read(self, &mut copy_bulk_operator.input);
227                visit_read(self, &mut copy_bulk_operator.in_index);
228                visit_read(self, &mut copy_bulk_operator.out_index);
229            }
230            Operator::Select(select) => {
231                visit_read(self, &mut select.cond);
232                visit_read(self, &mut select.then);
233                visit_read(self, &mut select.or_else);
234            }
235        }
236    }
237
238    fn visit_atomic(
239        &mut self,
240        atomic: &mut AtomicOp,
241        out: &mut Option<Variable>,
242        mut visit_read: impl FnMut(&mut Self, &mut Variable),
243    ) {
244        match atomic {
245            AtomicOp::Add(binary_operator)
246            | AtomicOp::Sub(binary_operator)
247            | AtomicOp::Max(binary_operator)
248            | AtomicOp::Min(binary_operator)
249            | AtomicOp::And(binary_operator)
250            | AtomicOp::Or(binary_operator)
251            | AtomicOp::Xor(binary_operator)
252            | AtomicOp::Swap(binary_operator) => {
253                self.visit_binop(binary_operator, visit_read);
254            }
255            AtomicOp::Load(unary_operator) => {
256                self.visit_unop(unary_operator, visit_read);
257            }
258            AtomicOp::Store(unary_operator) => {
259                visit_read(self, out.as_mut().unwrap());
260                self.visit_unop(unary_operator, visit_read);
261            }
262            AtomicOp::CompareAndSwap(op) => {
263                visit_read(self, &mut op.cmp);
264                visit_read(self, &mut op.cmp);
265                visit_read(self, &mut op.val);
266            }
267        }
268    }
269    fn visit_meta(
270        &mut self,
271        metadata: &mut Metadata,
272        mut visit_read: impl FnMut(&mut Self, &mut Variable),
273    ) {
274        match metadata {
275            Metadata::Rank { var } => {
276                visit_read(self, var);
277            }
278            Metadata::Stride { dim, var } => {
279                visit_read(self, dim);
280                visit_read(self, var);
281            }
282            Metadata::Shape { dim, var } => {
283                visit_read(self, dim);
284                visit_read(self, var);
285            }
286            Metadata::Length { var } => {
287                visit_read(self, var);
288            }
289            Metadata::BufferLength { var } => {
290                visit_read(self, var);
291            }
292        }
293    }
294
295    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
296        match plane {
297            Plane::Elect => {}
298            Plane::Broadcast(binary_operator)
299            | Plane::Shuffle(binary_operator)
300            | Plane::ShuffleXor(binary_operator)
301            | Plane::ShuffleUp(binary_operator)
302            | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
303            Plane::All(unary_operator)
304            | Plane::Any(unary_operator)
305            | Plane::Sum(unary_operator)
306            | Plane::InclusiveSum(unary_operator)
307            | Plane::ExclusiveSum(unary_operator)
308            | Plane::Prod(unary_operator)
309            | Plane::InclusiveProd(unary_operator)
310            | Plane::ExclusiveProd(unary_operator)
311            | Plane::Min(unary_operator)
312            | Plane::Max(unary_operator)
313            | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
314        }
315    }
316
317    fn visit_cmma(
318        &mut self,
319        cmma: &mut CoopMma,
320        mut visit_read: impl FnMut(&mut Self, &mut Variable),
321    ) {
322        match cmma {
323            CoopMma::Fill { value } => {
324                visit_read(self, value);
325            }
326            CoopMma::Load {
327                value,
328                stride,
329                offset,
330                layout: _,
331            } => {
332                visit_read(self, value);
333                visit_read(self, stride);
334                visit_read(self, offset);
335            }
336            CoopMma::Execute {
337                mat_a,
338                mat_b,
339                mat_c,
340            } => {
341                visit_read(self, mat_a);
342                visit_read(self, mat_b);
343                visit_read(self, mat_c);
344            }
345            CoopMma::Store {
346                mat,
347                stride,
348                offset,
349                layout: _,
350            } => {
351                visit_read(self, mat);
352                visit_read(self, stride);
353                visit_read(self, offset);
354            }
355            CoopMma::Cast { input } => {
356                visit_read(self, input);
357            }
358            CoopMma::RowIndex { lane_id, i, .. } => {
359                visit_read(self, lane_id);
360                visit_read(self, i);
361            }
362            CoopMma::ColIndex { lane_id, i, .. } => {
363                visit_read(self, lane_id);
364                visit_read(self, i);
365            }
366            CoopMma::LoadMatrix { buffer, offset, .. } => {
367                visit_read(self, buffer);
368                visit_read(self, offset);
369            }
370            CoopMma::StoreMatrix {
371                offset, registers, ..
372            } => {
373                visit_read(self, offset);
374                visit_read(self, registers);
375            }
376            CoopMma::ExecuteManual {
377                registers_a,
378                registers_b,
379                registers_c,
380                ..
381            } => {
382                visit_read(self, registers_a);
383                visit_read(self, registers_b);
384                visit_read(self, registers_c);
385            }
386            CoopMma::ExecuteScaled {
387                registers_a,
388                registers_b,
389                registers_c,
390                scales_a,
391                scales_b,
392                ..
393            } => {
394                visit_read(self, registers_a);
395                visit_read(self, registers_b);
396                visit_read(self, registers_c);
397                visit_read(self, scales_a);
398                visit_read(self, scales_b);
399            }
400        }
401    }
402
403    fn visit_barrier(
404        &mut self,
405        barrier_ops: &mut BarrierOps,
406        mut visit_read: impl FnMut(&mut Self, &mut Variable),
407    ) {
408        match barrier_ops {
409            BarrierOps::Declare { barrier } => visit_read(self, barrier),
410            BarrierOps::Init {
411                barrier,
412                is_elected,
413                arrival_count,
414                ..
415            } => {
416                visit_read(self, barrier);
417                visit_read(self, is_elected);
418                visit_read(self, arrival_count);
419            }
420            BarrierOps::InitManual {
421                barrier,
422                arrival_count,
423            } => {
424                visit_read(self, barrier);
425                visit_read(self, arrival_count);
426            }
427            BarrierOps::MemCopyAsync {
428                barrier,
429                source,
430                source_length,
431                offset_source,
432                offset_out,
433            } => {
434                visit_read(self, barrier);
435                visit_read(self, source_length);
436                visit_read(self, source);
437                visit_read(self, offset_source);
438                visit_read(self, offset_out);
439            }
440            BarrierOps::MemCopyAsyncCooperative {
441                barrier,
442                source,
443                source_length,
444                offset_source,
445                offset_out,
446            } => {
447                visit_read(self, barrier);
448                visit_read(self, source_length);
449                visit_read(self, source);
450                visit_read(self, offset_source);
451                visit_read(self, offset_out);
452            }
453            BarrierOps::CopyAsync {
454                source,
455                source_length,
456                offset_source,
457                offset_out,
458                ..
459            } => {
460                visit_read(self, source_length);
461                visit_read(self, source);
462                visit_read(self, offset_source);
463                visit_read(self, offset_out);
464            }
465            BarrierOps::MemCopyAsyncTx {
466                barrier,
467                source,
468                source_length,
469                offset_source,
470                offset_out,
471            } => {
472                visit_read(self, barrier);
473                visit_read(self, source_length);
474                visit_read(self, source);
475                visit_read(self, offset_source);
476                visit_read(self, offset_out);
477            }
478            BarrierOps::TmaLoad {
479                barrier,
480                offset_out,
481                tensor_map,
482                indices,
483            } => {
484                visit_read(self, offset_out);
485                visit_read(self, barrier);
486                visit_read(self, tensor_map);
487                for index in indices {
488                    visit_read(self, index);
489                }
490            }
491            BarrierOps::TmaLoadIm2col {
492                barrier,
493                tensor_map,
494                indices,
495                offset_out,
496                offsets,
497            } => {
498                visit_read(self, offset_out);
499                visit_read(self, barrier);
500                visit_read(self, tensor_map);
501                for index in indices {
502                    visit_read(self, index);
503                }
504                for offset in offsets {
505                    visit_read(self, offset);
506                }
507            }
508            BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
509            BarrierOps::Arrive { barrier } => visit_read(self, barrier),
510            BarrierOps::ArriveTx {
511                barrier,
512                arrive_count_update,
513                transaction_count_update,
514            } => {
515                visit_read(self, barrier);
516                visit_read(self, arrive_count_update);
517                visit_read(self, transaction_count_update);
518            }
519            BarrierOps::CommitCopyAsync { barrier } => visit_read(self, barrier),
520            BarrierOps::ExpectTx {
521                barrier,
522                transaction_count_update,
523            } => {
524                visit_read(self, barrier);
525                visit_read(self, transaction_count_update);
526            }
527            BarrierOps::Wait { barrier, token } => {
528                visit_read(self, barrier);
529                visit_read(self, token);
530            }
531            BarrierOps::WaitParity { barrier, phase } => {
532                visit_read(self, barrier);
533                visit_read(self, phase);
534            }
535        }
536    }
537
538    fn visit_tma(
539        &mut self,
540        tma_ops: &mut TmaOps,
541        mut visit_read: impl FnMut(&mut Self, &mut Variable),
542    ) {
543        match tma_ops {
544            TmaOps::TmaStore {
545                source,
546                coordinates,
547                offset_source,
548            } => {
549                visit_read(self, source);
550                visit_read(self, offset_source);
551                for coord in coordinates {
552                    visit_read(self, coord)
553                }
554            }
555            TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
556        }
557    }
558
559    fn visit_nonsemantic(
560        &mut self,
561        non_semantic: &mut NonSemantic,
562        mut visit_read: impl FnMut(&mut Self, &mut Variable),
563    ) {
564        match non_semantic {
565            NonSemantic::Comment { .. }
566            | NonSemantic::EnterDebugScope
567            | NonSemantic::ExitDebugScope => {}
568            NonSemantic::Print { args, .. } => {
569                for arg in args {
570                    visit_read(self, arg);
571                }
572            }
573        }
574    }
575
576    fn visit_unop(
577        &mut self,
578        unop: &mut UnaryOperator,
579        mut visit_read: impl FnMut(&mut Self, &mut Variable),
580    ) {
581        visit_read(self, &mut unop.input);
582    }
583
584    fn visit_binop(
585        &mut self,
586        binop: &mut BinaryOperator,
587        mut visit_read: impl FnMut(&mut Self, &mut Variable),
588    ) {
589        visit_read(self, &mut binop.lhs);
590        visit_read(self, &mut binop.rhs);
591    }
592}