cubecl_opt/
instructions.rs

1use cubecl_ir::{
2    Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3    Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9    pub fn visit_out(
10        &mut self,
11        var: &mut Option<Variable>,
12        mut visit_write: impl FnMut(&mut Self, &mut Variable),
13    ) {
14        if let Some(out) = var {
15            visit_write(self, out);
16        }
17    }
18
19    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
20    /// each read or written to variable.
21    pub fn visit_instruction(
22        &mut self,
23        inst: &mut Instruction,
24        visit_read: impl FnMut(&mut Self, &mut Variable),
25        visit_write: impl FnMut(&mut Self, &mut Variable),
26    ) {
27        self.visit_out(&mut inst.out, visit_write);
28        self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29    }
30
31    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
32    /// each read or written to variable.
33    pub fn visit_operation(
34        &mut self,
35        op: &mut Operation,
36        out: &mut Option<Variable>,
37        mut visit_read: impl FnMut(&mut Self, &mut Variable),
38    ) {
39        match op {
40            Operation::Copy(variable) => visit_read(self, variable),
41            Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42            Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43            Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45            Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47            // Sync has no outputs
48            Operation::Synchronization(_) => {}
49            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51            Operation::Branch(_) => unreachable!(),
52            Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53            Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54            Operation::NonSemantic(non_semantic) => {
55                self.visit_nonsemantic(non_semantic, visit_read)
56            }
57            Operation::Marker(_) => {}
58        }
59    }
60
61    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
62    /// each read or written to variable.
63    pub fn visit_arithmetic(
64        &mut self,
65        op: &mut Arithmetic,
66        mut visit_read: impl FnMut(&mut Self, &mut Variable),
67    ) {
68        match op {
69            Arithmetic::Fma(fma_operator) => {
70                visit_read(self, &mut fma_operator.a);
71                visit_read(self, &mut fma_operator.b);
72                visit_read(self, &mut fma_operator.c);
73            }
74            Arithmetic::Add(binary_operator)
75            | Arithmetic::SaturatingAdd(binary_operator)
76            | Arithmetic::Sub(binary_operator)
77            | Arithmetic::SaturatingSub(binary_operator)
78            | Arithmetic::Mul(binary_operator)
79            | Arithmetic::Div(binary_operator)
80            | Arithmetic::Powf(binary_operator)
81            | Arithmetic::Powi(binary_operator)
82            | Arithmetic::Modulo(binary_operator)
83            | Arithmetic::Max(binary_operator)
84            | Arithmetic::Min(binary_operator)
85            | Arithmetic::Remainder(binary_operator)
86            | Arithmetic::Dot(binary_operator)
87            | Arithmetic::MulHi(binary_operator)
88            | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
89
90            Arithmetic::Abs(unary_operator)
91            | Arithmetic::Exp(unary_operator)
92            | Arithmetic::Log(unary_operator)
93            | Arithmetic::Log1p(unary_operator)
94            | Arithmetic::Cos(unary_operator)
95            | Arithmetic::Sin(unary_operator)
96            | Arithmetic::Tan(unary_operator)
97            | Arithmetic::Tanh(unary_operator)
98            | Arithmetic::Sinh(unary_operator)
99            | Arithmetic::Cosh(unary_operator)
100            | Arithmetic::ArcCos(unary_operator)
101            | Arithmetic::ArcSin(unary_operator)
102            | Arithmetic::ArcTan(unary_operator)
103            | Arithmetic::ArcSinh(unary_operator)
104            | Arithmetic::ArcCosh(unary_operator)
105            | Arithmetic::ArcTanh(unary_operator)
106            | Arithmetic::Degrees(unary_operator)
107            | Arithmetic::Radians(unary_operator)
108            | Arithmetic::Sqrt(unary_operator)
109            | Arithmetic::InverseSqrt(unary_operator)
110            | Arithmetic::Round(unary_operator)
111            | Arithmetic::Floor(unary_operator)
112            | Arithmetic::Ceil(unary_operator)
113            | Arithmetic::Trunc(unary_operator)
114            | Arithmetic::Erf(unary_operator)
115            | Arithmetic::Recip(unary_operator)
116            | Arithmetic::Neg(unary_operator)
117            | Arithmetic::Magnitude(unary_operator)
118            | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
119
120            Arithmetic::Clamp(clamp_operator) => {
121                visit_read(self, &mut clamp_operator.input);
122                visit_read(self, &mut clamp_operator.min_value);
123                visit_read(self, &mut clamp_operator.max_value);
124            }
125        }
126    }
127
128    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
129    /// each read or written to variable.
130    pub fn visit_compare(
131        &mut self,
132        op: &mut Comparison,
133        visit_read: impl FnMut(&mut Self, &mut Variable),
134    ) {
135        match op {
136            Comparison::Equal(binary_operator)
137            | Comparison::NotEqual(binary_operator)
138            | Comparison::LowerEqual(binary_operator)
139            | Comparison::Greater(binary_operator)
140            | Comparison::Lower(binary_operator)
141            | Comparison::GreaterEqual(binary_operator) => {
142                self.visit_binop(binary_operator, visit_read)
143            }
144            Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
145                self.visit_unop(unary_operator, visit_read)
146            }
147        }
148    }
149
150    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
151    /// each read or written to variable.
152    pub fn visit_bitwise(
153        &mut self,
154        op: &mut Bitwise,
155        visit_read: impl FnMut(&mut Self, &mut Variable),
156    ) {
157        match op {
158            Bitwise::BitwiseAnd(binary_operator)
159            | Bitwise::BitwiseOr(binary_operator)
160            | Bitwise::BitwiseXor(binary_operator)
161            | Bitwise::ShiftLeft(binary_operator)
162            | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
163
164            Bitwise::CountOnes(unary_operator)
165            | Bitwise::BitwiseNot(unary_operator)
166            | Bitwise::ReverseBits(unary_operator)
167            | Bitwise::LeadingZeros(unary_operator)
168            | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
169        }
170    }
171
172    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
173    /// each read or written to variable.
174    pub fn visit_operator(
175        &mut self,
176        op: &mut Operator,
177        mut visit_read: impl FnMut(&mut Self, &mut Variable),
178    ) {
179        match op {
180            Operator::And(binary_operator) | Operator::Or(binary_operator) => {
181                self.visit_binop(binary_operator, visit_read)
182            }
183            Operator::Not(unary_operator)
184            | Operator::Cast(unary_operator)
185            | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
186            Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
187                visit_read(self, &mut index_operator.list);
188                visit_read(self, &mut index_operator.index);
189            }
190            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
191                visit_read(self, &mut op.index);
192                visit_read(self, &mut op.value);
193            }
194            Operator::InitLine(line_init_operator) => {
195                for input in &mut line_init_operator.inputs {
196                    visit_read(self, input)
197                }
198            }
199            Operator::CopyMemory(copy_operator) => {
200                visit_read(self, &mut copy_operator.input);
201                visit_read(self, &mut copy_operator.in_index);
202                visit_read(self, &mut copy_operator.out_index);
203            }
204            Operator::CopyMemoryBulk(copy_bulk_operator) => {
205                visit_read(self, &mut copy_bulk_operator.input);
206                visit_read(self, &mut copy_bulk_operator.in_index);
207                visit_read(self, &mut copy_bulk_operator.out_index);
208            }
209            Operator::Select(select) => {
210                visit_read(self, &mut select.cond);
211                visit_read(self, &mut select.then);
212                visit_read(self, &mut select.or_else);
213            }
214        }
215    }
216
217    fn visit_atomic(
218        &mut self,
219        atomic: &mut AtomicOp,
220        out: &mut Option<Variable>,
221        mut visit_read: impl FnMut(&mut Self, &mut Variable),
222    ) {
223        match atomic {
224            AtomicOp::Add(binary_operator)
225            | AtomicOp::Sub(binary_operator)
226            | AtomicOp::Max(binary_operator)
227            | AtomicOp::Min(binary_operator)
228            | AtomicOp::And(binary_operator)
229            | AtomicOp::Or(binary_operator)
230            | AtomicOp::Xor(binary_operator)
231            | AtomicOp::Swap(binary_operator) => {
232                self.visit_binop(binary_operator, visit_read);
233            }
234            AtomicOp::Load(unary_operator) => {
235                self.visit_unop(unary_operator, visit_read);
236            }
237            AtomicOp::Store(unary_operator) => {
238                visit_read(self, out.as_mut().unwrap());
239                self.visit_unop(unary_operator, visit_read);
240            }
241            AtomicOp::CompareAndSwap(op) => {
242                visit_read(self, &mut op.cmp);
243                visit_read(self, &mut op.cmp);
244                visit_read(self, &mut op.val);
245            }
246        }
247    }
248    fn visit_meta(
249        &mut self,
250        metadata: &mut Metadata,
251        mut visit_read: impl FnMut(&mut Self, &mut Variable),
252    ) {
253        match metadata {
254            Metadata::Rank { var } => {
255                visit_read(self, var);
256            }
257            Metadata::Stride { dim, var } => {
258                visit_read(self, dim);
259                visit_read(self, var);
260            }
261            Metadata::Shape { dim, var } => {
262                visit_read(self, dim);
263                visit_read(self, var);
264            }
265            Metadata::Length { var } => {
266                visit_read(self, var);
267            }
268            Metadata::BufferLength { var } => {
269                visit_read(self, var);
270            }
271        }
272    }
273
274    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
275        match plane {
276            Plane::Elect => {}
277            Plane::Broadcast(binary_operator)
278            | Plane::Shuffle(binary_operator)
279            | Plane::ShuffleXor(binary_operator)
280            | Plane::ShuffleUp(binary_operator)
281            | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
282            Plane::All(unary_operator)
283            | Plane::Any(unary_operator)
284            | Plane::Sum(unary_operator)
285            | Plane::InclusiveSum(unary_operator)
286            | Plane::ExclusiveSum(unary_operator)
287            | Plane::Prod(unary_operator)
288            | Plane::InclusiveProd(unary_operator)
289            | Plane::ExclusiveProd(unary_operator)
290            | Plane::Min(unary_operator)
291            | Plane::Max(unary_operator)
292            | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
293        }
294    }
295
296    fn visit_cmma(
297        &mut self,
298        cmma: &mut CoopMma,
299        mut visit_read: impl FnMut(&mut Self, &mut Variable),
300    ) {
301        match cmma {
302            CoopMma::Fill { value } => {
303                visit_read(self, value);
304            }
305            CoopMma::Load {
306                value,
307                stride,
308                offset,
309                layout: _,
310            } => {
311                visit_read(self, value);
312                visit_read(self, stride);
313                visit_read(self, offset);
314            }
315            CoopMma::Execute {
316                mat_a,
317                mat_b,
318                mat_c,
319            } => {
320                visit_read(self, mat_a);
321                visit_read(self, mat_b);
322                visit_read(self, mat_c);
323            }
324            CoopMma::Store {
325                mat,
326                stride,
327                offset,
328                layout: _,
329            } => {
330                visit_read(self, mat);
331                visit_read(self, stride);
332                visit_read(self, offset);
333            }
334            CoopMma::Cast { input } => {
335                visit_read(self, input);
336            }
337            CoopMma::RowIndex { lane_id, i, .. } => {
338                visit_read(self, lane_id);
339                visit_read(self, i);
340            }
341            CoopMma::ColIndex { lane_id, i, .. } => {
342                visit_read(self, lane_id);
343                visit_read(self, i);
344            }
345            CoopMma::LoadMatrix { buffer, offset, .. } => {
346                visit_read(self, buffer);
347                visit_read(self, offset);
348            }
349            CoopMma::StoreMatrix {
350                offset, registers, ..
351            } => {
352                visit_read(self, offset);
353                visit_read(self, registers);
354            }
355            CoopMma::ExecuteManual {
356                registers_a,
357                registers_b,
358                registers_c,
359                ..
360            } => {
361                visit_read(self, registers_a);
362                visit_read(self, registers_b);
363                visit_read(self, registers_c);
364            }
365            CoopMma::ExecuteScaled {
366                registers_a,
367                registers_b,
368                registers_c,
369                scales_a,
370                scales_b,
371                ..
372            } => {
373                visit_read(self, registers_a);
374                visit_read(self, registers_b);
375                visit_read(self, registers_c);
376                visit_read(self, scales_a);
377                visit_read(self, scales_b);
378            }
379        }
380    }
381
382    fn visit_barrier(
383        &mut self,
384        barrier_ops: &mut BarrierOps,
385        mut visit_read: impl FnMut(&mut Self, &mut Variable),
386    ) {
387        match barrier_ops {
388            BarrierOps::Declare { barrier } => visit_read(self, barrier),
389            BarrierOps::Init {
390                barrier,
391                is_elected,
392                arrival_count,
393                ..
394            } => {
395                visit_read(self, barrier);
396                visit_read(self, is_elected);
397                visit_read(self, arrival_count);
398            }
399            BarrierOps::InitManual {
400                barrier,
401                arrival_count,
402            } => {
403                visit_read(self, barrier);
404                visit_read(self, arrival_count);
405            }
406            BarrierOps::MemCopyAsync {
407                barrier,
408                source,
409                source_length,
410                offset_source,
411                offset_out,
412            } => {
413                visit_read(self, barrier);
414                visit_read(self, source_length);
415                visit_read(self, source);
416                visit_read(self, offset_source);
417                visit_read(self, offset_out);
418            }
419            BarrierOps::MemCopyAsyncCooperative {
420                barrier,
421                source,
422                source_length,
423                offset_source,
424                offset_out,
425            } => {
426                visit_read(self, barrier);
427                visit_read(self, source_length);
428                visit_read(self, source);
429                visit_read(self, offset_source);
430                visit_read(self, offset_out);
431            }
432            BarrierOps::MemCopyAsyncTx {
433                barrier,
434                source,
435                source_length,
436                offset_source,
437                offset_out,
438            } => {
439                visit_read(self, barrier);
440                visit_read(self, source_length);
441                visit_read(self, source);
442                visit_read(self, offset_source);
443                visit_read(self, offset_out);
444            }
445            BarrierOps::TmaLoad {
446                barrier,
447                offset_out,
448                tensor_map,
449                indices,
450            } => {
451                visit_read(self, offset_out);
452                visit_read(self, barrier);
453                visit_read(self, tensor_map);
454                for index in indices {
455                    visit_read(self, index);
456                }
457            }
458            BarrierOps::TmaLoadIm2col {
459                barrier,
460                tensor_map,
461                indices,
462                offset_out,
463                offsets,
464            } => {
465                visit_read(self, offset_out);
466                visit_read(self, barrier);
467                visit_read(self, tensor_map);
468                for index in indices {
469                    visit_read(self, index);
470                }
471                for offset in offsets {
472                    visit_read(self, offset);
473                }
474            }
475            BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
476            BarrierOps::Arrive { barrier } => visit_read(self, barrier),
477            BarrierOps::ArriveTx {
478                barrier,
479                arrive_count_update,
480                transaction_count_update,
481            } => {
482                visit_read(self, barrier);
483                visit_read(self, arrive_count_update);
484                visit_read(self, transaction_count_update);
485            }
486            BarrierOps::ExpectTx {
487                barrier,
488                transaction_count_update,
489            } => {
490                visit_read(self, barrier);
491                visit_read(self, transaction_count_update);
492            }
493            BarrierOps::Wait { barrier, token } => {
494                visit_read(self, barrier);
495                visit_read(self, token);
496            }
497            BarrierOps::WaitParity { barrier, phase } => {
498                visit_read(self, barrier);
499                visit_read(self, phase);
500            }
501        }
502    }
503
504    fn visit_tma(
505        &mut self,
506        tma_ops: &mut TmaOps,
507        mut visit_read: impl FnMut(&mut Self, &mut Variable),
508    ) {
509        match tma_ops {
510            TmaOps::TmaStore {
511                source,
512                coordinates,
513                offset_source,
514            } => {
515                visit_read(self, source);
516                visit_read(self, offset_source);
517                for coord in coordinates {
518                    visit_read(self, coord)
519                }
520            }
521            TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
522        }
523    }
524
525    fn visit_nonsemantic(
526        &mut self,
527        non_semantic: &mut NonSemantic,
528        mut visit_read: impl FnMut(&mut Self, &mut Variable),
529    ) {
530        match non_semantic {
531            NonSemantic::Comment { .. }
532            | NonSemantic::EnterDebugScope
533            | NonSemantic::ExitDebugScope => {}
534            NonSemantic::Print { args, .. } => {
535                for arg in args {
536                    visit_read(self, arg);
537                }
538            }
539        }
540    }
541
542    fn visit_unop(
543        &mut self,
544        unop: &mut UnaryOperator,
545        mut visit_read: impl FnMut(&mut Self, &mut Variable),
546    ) {
547        visit_read(self, &mut unop.input);
548    }
549
550    fn visit_binop(
551        &mut self,
552        binop: &mut BinaryOperator,
553        mut visit_read: impl FnMut(&mut Self, &mut Variable),
554    ) {
555        visit_read(self, &mut binop.lhs);
556        visit_read(self, &mut binop.rhs);
557    }
558}