Skip to main content

cubecl_opt/
instructions.rs

1use cubecl_ir::{
2    Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3    Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use crate::ControlFlow;
7
8use super::Optimizer;
9
10impl Optimizer {
11    pub fn visit_out(
12        &mut self,
13        var: &mut Option<Variable>,
14        mut visit_write: impl FnMut(&mut Self, &mut Variable),
15    ) {
16        if let Some(out) = var {
17            visit_write(self, out);
18        }
19    }
20
21    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
22    /// each read or written to variable.
23    pub fn visit_instruction(
24        &mut self,
25        inst: &mut Instruction,
26        visit_read: impl FnMut(&mut Self, &mut Variable),
27        visit_write: impl FnMut(&mut Self, &mut Variable),
28    ) {
29        self.visit_out(&mut inst.out, visit_write);
30        self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
31    }
32
33    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
34    /// each read or written to variable.
35    pub fn visit_operation(
36        &mut self,
37        op: &mut Operation,
38        out: &mut Option<Variable>,
39        mut visit_read: impl FnMut(&mut Self, &mut Variable),
40    ) {
41        match op {
42            Operation::Copy(variable) => visit_read(self, variable),
43            Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
44            Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
45            Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
46            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
47            Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
48            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
49            // Sync has no outputs
50            Operation::Synchronization(_) => {}
51            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
52            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
53            Operation::Branch(_) => unreachable!(),
54            Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
55            Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
56            Operation::NonSemantic(non_semantic) => {
57                self.visit_nonsemantic(non_semantic, visit_read)
58            }
59            Operation::Marker(_) => {}
60        }
61    }
62
63    /// Visit a control flow finisher with a set of read and write visitors. Each visitor will be called with
64    /// each read or written to variable.
65    pub fn visit_control_flow(
66        &mut self,
67        op: &mut ControlFlow,
68        mut visit_read: impl FnMut(&mut Self, &mut Variable),
69    ) {
70        match op {
71            ControlFlow::IfElse { cond, .. } => visit_read(self, cond),
72            ControlFlow::Switch { value, .. } => visit_read(self, value),
73            ControlFlow::Loop { .. } => {}
74            ControlFlow::LoopBreak { break_cond, .. } => visit_read(self, break_cond),
75            ControlFlow::Return | ControlFlow::Unreachable | ControlFlow::None => {}
76        }
77    }
78
79    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
80    /// each read or written to variable.
81    pub fn visit_arithmetic(
82        &mut self,
83        op: &mut Arithmetic,
84        mut visit_read: impl FnMut(&mut Self, &mut Variable),
85    ) {
86        match op {
87            Arithmetic::Fma(fma_operator) => {
88                visit_read(self, &mut fma_operator.a);
89                visit_read(self, &mut fma_operator.b);
90                visit_read(self, &mut fma_operator.c);
91            }
92            Arithmetic::Add(binary_operator)
93            | Arithmetic::SaturatingAdd(binary_operator)
94            | Arithmetic::Sub(binary_operator)
95            | Arithmetic::SaturatingSub(binary_operator)
96            | Arithmetic::Mul(binary_operator)
97            | Arithmetic::Div(binary_operator)
98            | Arithmetic::Powf(binary_operator)
99            | Arithmetic::Powi(binary_operator)
100            | Arithmetic::Hypot(binary_operator)
101            | Arithmetic::Rhypot(binary_operator)
102            | Arithmetic::Modulo(binary_operator)
103            | Arithmetic::Max(binary_operator)
104            | Arithmetic::Min(binary_operator)
105            | Arithmetic::Remainder(binary_operator)
106            | Arithmetic::Dot(binary_operator)
107            | Arithmetic::MulHi(binary_operator)
108            | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
109
110            Arithmetic::Abs(unary_operator)
111            | Arithmetic::Exp(unary_operator)
112            | Arithmetic::Log(unary_operator)
113            | Arithmetic::Log1p(unary_operator)
114            | Arithmetic::Cos(unary_operator)
115            | Arithmetic::Sin(unary_operator)
116            | Arithmetic::Tan(unary_operator)
117            | Arithmetic::Tanh(unary_operator)
118            | Arithmetic::Sinh(unary_operator)
119            | Arithmetic::Cosh(unary_operator)
120            | Arithmetic::ArcCos(unary_operator)
121            | Arithmetic::ArcSin(unary_operator)
122            | Arithmetic::ArcTan(unary_operator)
123            | Arithmetic::ArcSinh(unary_operator)
124            | Arithmetic::ArcCosh(unary_operator)
125            | Arithmetic::ArcTanh(unary_operator)
126            | Arithmetic::Degrees(unary_operator)
127            | Arithmetic::Radians(unary_operator)
128            | Arithmetic::Sqrt(unary_operator)
129            | Arithmetic::InverseSqrt(unary_operator)
130            | Arithmetic::Round(unary_operator)
131            | Arithmetic::Floor(unary_operator)
132            | Arithmetic::Ceil(unary_operator)
133            | Arithmetic::Trunc(unary_operator)
134            | Arithmetic::Erf(unary_operator)
135            | Arithmetic::Recip(unary_operator)
136            | Arithmetic::Neg(unary_operator)
137            | Arithmetic::Magnitude(unary_operator)
138            | Arithmetic::Normalize(unary_operator)
139            | Arithmetic::VectorSum(unary_operator) => self.visit_unop(unary_operator, visit_read),
140
141            Arithmetic::Clamp(clamp_operator) => {
142                visit_read(self, &mut clamp_operator.input);
143                visit_read(self, &mut clamp_operator.min_value);
144                visit_read(self, &mut clamp_operator.max_value);
145            }
146        }
147    }
148
149    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
150    /// each read or written to variable.
151    pub fn visit_compare(
152        &mut self,
153        op: &mut Comparison,
154        visit_read: impl FnMut(&mut Self, &mut Variable),
155    ) {
156        match op {
157            Comparison::Equal(binary_operator)
158            | Comparison::NotEqual(binary_operator)
159            | Comparison::LowerEqual(binary_operator)
160            | Comparison::Greater(binary_operator)
161            | Comparison::Lower(binary_operator)
162            | Comparison::GreaterEqual(binary_operator) => {
163                self.visit_binop(binary_operator, visit_read)
164            }
165            Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
166                self.visit_unop(unary_operator, visit_read)
167            }
168        }
169    }
170
171    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
172    /// each read or written to variable.
173    pub fn visit_bitwise(
174        &mut self,
175        op: &mut Bitwise,
176        visit_read: impl FnMut(&mut Self, &mut Variable),
177    ) {
178        match op {
179            Bitwise::BitwiseAnd(binary_operator)
180            | Bitwise::BitwiseOr(binary_operator)
181            | Bitwise::BitwiseXor(binary_operator)
182            | Bitwise::ShiftLeft(binary_operator)
183            | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
184
185            Bitwise::CountOnes(unary_operator)
186            | Bitwise::BitwiseNot(unary_operator)
187            | Bitwise::ReverseBits(unary_operator)
188            | Bitwise::LeadingZeros(unary_operator)
189            | Bitwise::TrailingZeros(unary_operator)
190            | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
191        }
192    }
193
194    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
195    /// each read or written to variable.
196    pub fn visit_operator(
197        &mut self,
198        op: &mut Operator,
199        mut visit_read: impl FnMut(&mut Self, &mut Variable),
200    ) {
201        match op {
202            Operator::And(binary_operator) | Operator::Or(binary_operator) => {
203                self.visit_binop(binary_operator, visit_read)
204            }
205            Operator::Not(unary_operator)
206            | Operator::Cast(unary_operator)
207            | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
208            Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
209                visit_read(self, &mut index_operator.list);
210                visit_read(self, &mut index_operator.index);
211            }
212            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
213                visit_read(self, &mut op.index);
214                visit_read(self, &mut op.value);
215            }
216            Operator::InitVector(vector_init_operator) => {
217                for input in &mut vector_init_operator.inputs {
218                    visit_read(self, input)
219                }
220            }
221            Operator::CopyMemory(copy_operator) => {
222                visit_read(self, &mut copy_operator.input);
223                visit_read(self, &mut copy_operator.in_index);
224                visit_read(self, &mut copy_operator.out_index);
225            }
226            Operator::CopyMemoryBulk(copy_bulk_operator) => {
227                visit_read(self, &mut copy_bulk_operator.input);
228                visit_read(self, &mut copy_bulk_operator.in_index);
229                visit_read(self, &mut copy_bulk_operator.out_index);
230            }
231            Operator::Select(select) => {
232                visit_read(self, &mut select.cond);
233                visit_read(self, &mut select.then);
234                visit_read(self, &mut select.or_else);
235            }
236        }
237    }
238
239    fn visit_atomic(
240        &mut self,
241        atomic: &mut AtomicOp,
242        out: &mut Option<Variable>,
243        mut visit_read: impl FnMut(&mut Self, &mut Variable),
244    ) {
245        match atomic {
246            AtomicOp::Add(binary_operator)
247            | AtomicOp::Sub(binary_operator)
248            | AtomicOp::Max(binary_operator)
249            | AtomicOp::Min(binary_operator)
250            | AtomicOp::And(binary_operator)
251            | AtomicOp::Or(binary_operator)
252            | AtomicOp::Xor(binary_operator)
253            | AtomicOp::Swap(binary_operator) => {
254                self.visit_binop(binary_operator, visit_read);
255            }
256            AtomicOp::Load(unary_operator) => {
257                self.visit_unop(unary_operator, visit_read);
258            }
259            AtomicOp::Store(unary_operator) => {
260                visit_read(self, out.as_mut().unwrap());
261                self.visit_unop(unary_operator, visit_read);
262            }
263            AtomicOp::CompareAndSwap(op) => {
264                visit_read(self, &mut op.cmp);
265                visit_read(self, &mut op.cmp);
266                visit_read(self, &mut op.val);
267            }
268        }
269    }
270    fn visit_meta(
271        &mut self,
272        metadata: &mut Metadata,
273        mut visit_read: impl FnMut(&mut Self, &mut Variable),
274    ) {
275        match metadata {
276            Metadata::Rank { var } => {
277                visit_read(self, var);
278            }
279            Metadata::Stride { dim, var } => {
280                visit_read(self, dim);
281                visit_read(self, var);
282            }
283            Metadata::Shape { dim, var } => {
284                visit_read(self, dim);
285                visit_read(self, var);
286            }
287            Metadata::Length { var } => {
288                visit_read(self, var);
289            }
290            Metadata::BufferLength { var } => {
291                visit_read(self, var);
292            }
293        }
294    }
295
296    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
297        match plane {
298            Plane::Elect => {}
299            Plane::Broadcast(binary_operator)
300            | Plane::Shuffle(binary_operator)
301            | Plane::ShuffleXor(binary_operator)
302            | Plane::ShuffleUp(binary_operator)
303            | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
304            Plane::All(unary_operator)
305            | Plane::Any(unary_operator)
306            | Plane::Sum(unary_operator)
307            | Plane::InclusiveSum(unary_operator)
308            | Plane::ExclusiveSum(unary_operator)
309            | Plane::Prod(unary_operator)
310            | Plane::InclusiveProd(unary_operator)
311            | Plane::ExclusiveProd(unary_operator)
312            | Plane::Min(unary_operator)
313            | Plane::Max(unary_operator)
314            | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
315        }
316    }
317
318    fn visit_cmma(
319        &mut self,
320        cmma: &mut CoopMma,
321        mut visit_read: impl FnMut(&mut Self, &mut Variable),
322    ) {
323        match cmma {
324            CoopMma::Fill { value } => {
325                visit_read(self, value);
326            }
327            CoopMma::Load {
328                value,
329                stride,
330                offset,
331                layout: _,
332            } => {
333                visit_read(self, value);
334                visit_read(self, stride);
335                visit_read(self, offset);
336            }
337            CoopMma::Execute {
338                mat_a,
339                mat_b,
340                mat_c,
341            } => {
342                visit_read(self, mat_a);
343                visit_read(self, mat_b);
344                visit_read(self, mat_c);
345            }
346            CoopMma::Store {
347                mat,
348                stride,
349                offset,
350                layout: _,
351            } => {
352                visit_read(self, mat);
353                visit_read(self, stride);
354                visit_read(self, offset);
355            }
356            CoopMma::Cast { input } => {
357                visit_read(self, input);
358            }
359            CoopMma::RowIndex { lane_id, i, .. } => {
360                visit_read(self, lane_id);
361                visit_read(self, i);
362            }
363            CoopMma::ColIndex { lane_id, i, .. } => {
364                visit_read(self, lane_id);
365                visit_read(self, i);
366            }
367            CoopMma::LoadMatrix { buffer, offset, .. } => {
368                visit_read(self, buffer);
369                visit_read(self, offset);
370            }
371            CoopMma::StoreMatrix {
372                offset, registers, ..
373            } => {
374                visit_read(self, offset);
375                visit_read(self, registers);
376            }
377            CoopMma::ExecuteManual {
378                registers_a,
379                registers_b,
380                registers_c,
381                ..
382            } => {
383                visit_read(self, registers_a);
384                visit_read(self, registers_b);
385                visit_read(self, registers_c);
386            }
387            CoopMma::ExecuteScaled {
388                registers_a,
389                registers_b,
390                registers_c,
391                scales_a,
392                scales_b,
393                ..
394            } => {
395                visit_read(self, registers_a);
396                visit_read(self, registers_b);
397                visit_read(self, registers_c);
398                visit_read(self, scales_a);
399                visit_read(self, scales_b);
400            }
401        }
402    }
403
404    fn visit_barrier(
405        &mut self,
406        barrier_ops: &mut BarrierOps,
407        mut visit_read: impl FnMut(&mut Self, &mut Variable),
408    ) {
409        match barrier_ops {
410            BarrierOps::Declare { barrier } => visit_read(self, barrier),
411            BarrierOps::Init {
412                barrier,
413                is_elected,
414                arrival_count,
415                ..
416            } => {
417                visit_read(self, barrier);
418                visit_read(self, is_elected);
419                visit_read(self, arrival_count);
420            }
421            BarrierOps::InitManual {
422                barrier,
423                arrival_count,
424            } => {
425                visit_read(self, barrier);
426                visit_read(self, arrival_count);
427            }
428            BarrierOps::MemCopyAsync {
429                barrier,
430                source,
431                source_length,
432                offset_source,
433                offset_out,
434            } => {
435                visit_read(self, barrier);
436                visit_read(self, source_length);
437                visit_read(self, source);
438                visit_read(self, offset_source);
439                visit_read(self, offset_out);
440            }
441            BarrierOps::MemCopyAsyncCooperative {
442                barrier,
443                source,
444                source_length,
445                offset_source,
446                offset_out,
447            } => {
448                visit_read(self, barrier);
449                visit_read(self, source_length);
450                visit_read(self, source);
451                visit_read(self, offset_source);
452                visit_read(self, offset_out);
453            }
454            BarrierOps::CopyAsync {
455                source,
456                source_length,
457                offset_source,
458                offset_out,
459                ..
460            } => {
461                visit_read(self, source_length);
462                visit_read(self, source);
463                visit_read(self, offset_source);
464                visit_read(self, offset_out);
465            }
466            BarrierOps::MemCopyAsyncTx {
467                barrier,
468                source,
469                source_length,
470                offset_source,
471                offset_out,
472            } => {
473                visit_read(self, barrier);
474                visit_read(self, source_length);
475                visit_read(self, source);
476                visit_read(self, offset_source);
477                visit_read(self, offset_out);
478            }
479            BarrierOps::TmaLoad {
480                barrier,
481                offset_out,
482                tensor_map,
483                indices,
484            } => {
485                visit_read(self, offset_out);
486                visit_read(self, barrier);
487                visit_read(self, tensor_map);
488                for index in indices {
489                    visit_read(self, index);
490                }
491            }
492            BarrierOps::TmaLoadIm2col {
493                barrier,
494                tensor_map,
495                indices,
496                offset_out,
497                offsets,
498            } => {
499                visit_read(self, offset_out);
500                visit_read(self, barrier);
501                visit_read(self, tensor_map);
502                for index in indices {
503                    visit_read(self, index);
504                }
505                for offset in offsets {
506                    visit_read(self, offset);
507                }
508            }
509            BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
510            BarrierOps::Arrive { barrier } => visit_read(self, barrier),
511            BarrierOps::ArriveTx {
512                barrier,
513                arrive_count_update,
514                transaction_count_update,
515            } => {
516                visit_read(self, barrier);
517                visit_read(self, arrive_count_update);
518                visit_read(self, transaction_count_update);
519            }
520            BarrierOps::CommitCopyAsync { barrier } => visit_read(self, barrier),
521            BarrierOps::ExpectTx {
522                barrier,
523                transaction_count_update,
524            } => {
525                visit_read(self, barrier);
526                visit_read(self, transaction_count_update);
527            }
528            BarrierOps::Wait { barrier, token } => {
529                visit_read(self, barrier);
530                visit_read(self, token);
531            }
532            BarrierOps::WaitParity { barrier, phase } => {
533                visit_read(self, barrier);
534                visit_read(self, phase);
535            }
536        }
537    }
538
539    fn visit_tma(
540        &mut self,
541        tma_ops: &mut TmaOps,
542        mut visit_read: impl FnMut(&mut Self, &mut Variable),
543    ) {
544        match tma_ops {
545            TmaOps::TmaStore {
546                source,
547                coordinates,
548                offset_source,
549            } => {
550                visit_read(self, source);
551                visit_read(self, offset_source);
552                for coord in coordinates {
553                    visit_read(self, coord)
554                }
555            }
556            TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
557        }
558    }
559
560    fn visit_nonsemantic(
561        &mut self,
562        non_semantic: &mut NonSemantic,
563        mut visit_read: impl FnMut(&mut Self, &mut Variable),
564    ) {
565        match non_semantic {
566            NonSemantic::Comment { .. }
567            | NonSemantic::EnterDebugScope
568            | NonSemantic::ExitDebugScope => {}
569            NonSemantic::Print { args, .. } => {
570                for arg in args {
571                    visit_read(self, arg);
572                }
573            }
574        }
575    }
576
577    fn visit_unop(
578        &mut self,
579        unop: &mut UnaryOperator,
580        mut visit_read: impl FnMut(&mut Self, &mut Variable),
581    ) {
582        visit_read(self, &mut unop.input);
583    }
584
585    fn visit_binop(
586        &mut self,
587        binop: &mut BinaryOperator,
588        mut visit_read: impl FnMut(&mut Self, &mut Variable),
589    ) {
590        visit_read(self, &mut binop.lhs);
591        visit_read(self, &mut binop.rhs);
592    }
593}