Skip to main content

cubecl_opt/
instructions.rs

1use cubecl_ir::{
2    Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3    Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9    pub fn visit_out(
10        &mut self,
11        var: &mut Option<Variable>,
12        mut visit_write: impl FnMut(&mut Self, &mut Variable),
13    ) {
14        if let Some(out) = var {
15            visit_write(self, out);
16        }
17    }
18
19    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
20    /// each read or written to variable.
21    pub fn visit_instruction(
22        &mut self,
23        inst: &mut Instruction,
24        visit_read: impl FnMut(&mut Self, &mut Variable),
25        visit_write: impl FnMut(&mut Self, &mut Variable),
26    ) {
27        self.visit_out(&mut inst.out, visit_write);
28        self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29    }
30
31    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
32    /// each read or written to variable.
33    pub fn visit_operation(
34        &mut self,
35        op: &mut Operation,
36        out: &mut Option<Variable>,
37        mut visit_read: impl FnMut(&mut Self, &mut Variable),
38    ) {
39        match op {
40            Operation::Copy(variable) => visit_read(self, variable),
41            Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42            Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43            Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45            Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47            // Sync has no outputs
48            Operation::Synchronization(_) => {}
49            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51            Operation::Branch(_) => unreachable!(),
52            Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53            Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54            Operation::NonSemantic(non_semantic) => {
55                self.visit_nonsemantic(non_semantic, visit_read)
56            }
57            Operation::Marker(_) => {}
58        }
59    }
60
61    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
62    /// each read or written to variable.
63    pub fn visit_arithmetic(
64        &mut self,
65        op: &mut Arithmetic,
66        mut visit_read: impl FnMut(&mut Self, &mut Variable),
67    ) {
68        match op {
69            Arithmetic::Fma(fma_operator) => {
70                visit_read(self, &mut fma_operator.a);
71                visit_read(self, &mut fma_operator.b);
72                visit_read(self, &mut fma_operator.c);
73            }
74            Arithmetic::Add(binary_operator)
75            | Arithmetic::SaturatingAdd(binary_operator)
76            | Arithmetic::Sub(binary_operator)
77            | Arithmetic::SaturatingSub(binary_operator)
78            | Arithmetic::Mul(binary_operator)
79            | Arithmetic::Div(binary_operator)
80            | Arithmetic::Powf(binary_operator)
81            | Arithmetic::Powi(binary_operator)
82            | Arithmetic::Hypot(binary_operator)
83            | Arithmetic::Rhypot(binary_operator)
84            | Arithmetic::Modulo(binary_operator)
85            | Arithmetic::Max(binary_operator)
86            | Arithmetic::Min(binary_operator)
87            | Arithmetic::Remainder(binary_operator)
88            | Arithmetic::Dot(binary_operator)
89            | Arithmetic::MulHi(binary_operator)
90            | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
91
92            Arithmetic::Abs(unary_operator)
93            | Arithmetic::Exp(unary_operator)
94            | Arithmetic::Log(unary_operator)
95            | Arithmetic::Log1p(unary_operator)
96            | Arithmetic::Cos(unary_operator)
97            | Arithmetic::Sin(unary_operator)
98            | Arithmetic::Tan(unary_operator)
99            | Arithmetic::Tanh(unary_operator)
100            | Arithmetic::Sinh(unary_operator)
101            | Arithmetic::Cosh(unary_operator)
102            | Arithmetic::ArcCos(unary_operator)
103            | Arithmetic::ArcSin(unary_operator)
104            | Arithmetic::ArcTan(unary_operator)
105            | Arithmetic::ArcSinh(unary_operator)
106            | Arithmetic::ArcCosh(unary_operator)
107            | Arithmetic::ArcTanh(unary_operator)
108            | Arithmetic::Degrees(unary_operator)
109            | Arithmetic::Radians(unary_operator)
110            | Arithmetic::Sqrt(unary_operator)
111            | Arithmetic::InverseSqrt(unary_operator)
112            | Arithmetic::Round(unary_operator)
113            | Arithmetic::Floor(unary_operator)
114            | Arithmetic::Ceil(unary_operator)
115            | Arithmetic::Trunc(unary_operator)
116            | Arithmetic::Erf(unary_operator)
117            | Arithmetic::Recip(unary_operator)
118            | Arithmetic::Neg(unary_operator)
119            | Arithmetic::Magnitude(unary_operator)
120            | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
121
122            Arithmetic::Clamp(clamp_operator) => {
123                visit_read(self, &mut clamp_operator.input);
124                visit_read(self, &mut clamp_operator.min_value);
125                visit_read(self, &mut clamp_operator.max_value);
126            }
127        }
128    }
129
130    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
131    /// each read or written to variable.
132    pub fn visit_compare(
133        &mut self,
134        op: &mut Comparison,
135        visit_read: impl FnMut(&mut Self, &mut Variable),
136    ) {
137        match op {
138            Comparison::Equal(binary_operator)
139            | Comparison::NotEqual(binary_operator)
140            | Comparison::LowerEqual(binary_operator)
141            | Comparison::Greater(binary_operator)
142            | Comparison::Lower(binary_operator)
143            | Comparison::GreaterEqual(binary_operator) => {
144                self.visit_binop(binary_operator, visit_read)
145            }
146            Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
147                self.visit_unop(unary_operator, visit_read)
148            }
149        }
150    }
151
152    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
153    /// each read or written to variable.
154    pub fn visit_bitwise(
155        &mut self,
156        op: &mut Bitwise,
157        visit_read: impl FnMut(&mut Self, &mut Variable),
158    ) {
159        match op {
160            Bitwise::BitwiseAnd(binary_operator)
161            | Bitwise::BitwiseOr(binary_operator)
162            | Bitwise::BitwiseXor(binary_operator)
163            | Bitwise::ShiftLeft(binary_operator)
164            | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
165
166            Bitwise::CountOnes(unary_operator)
167            | Bitwise::BitwiseNot(unary_operator)
168            | Bitwise::ReverseBits(unary_operator)
169            | Bitwise::LeadingZeros(unary_operator)
170            | Bitwise::TrailingZeros(unary_operator)
171            | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
172        }
173    }
174
175    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
176    /// each read or written to variable.
177    pub fn visit_operator(
178        &mut self,
179        op: &mut Operator,
180        mut visit_read: impl FnMut(&mut Self, &mut Variable),
181    ) {
182        match op {
183            Operator::And(binary_operator) | Operator::Or(binary_operator) => {
184                self.visit_binop(binary_operator, visit_read)
185            }
186            Operator::Not(unary_operator)
187            | Operator::Cast(unary_operator)
188            | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
189            Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
190                visit_read(self, &mut index_operator.list);
191                visit_read(self, &mut index_operator.index);
192            }
193            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
194                visit_read(self, &mut op.index);
195                visit_read(self, &mut op.value);
196            }
197            Operator::InitLine(line_init_operator) => {
198                for input in &mut line_init_operator.inputs {
199                    visit_read(self, input)
200                }
201            }
202            Operator::CopyMemory(copy_operator) => {
203                visit_read(self, &mut copy_operator.input);
204                visit_read(self, &mut copy_operator.in_index);
205                visit_read(self, &mut copy_operator.out_index);
206            }
207            Operator::CopyMemoryBulk(copy_bulk_operator) => {
208                visit_read(self, &mut copy_bulk_operator.input);
209                visit_read(self, &mut copy_bulk_operator.in_index);
210                visit_read(self, &mut copy_bulk_operator.out_index);
211            }
212            Operator::Select(select) => {
213                visit_read(self, &mut select.cond);
214                visit_read(self, &mut select.then);
215                visit_read(self, &mut select.or_else);
216            }
217        }
218    }
219
220    fn visit_atomic(
221        &mut self,
222        atomic: &mut AtomicOp,
223        out: &mut Option<Variable>,
224        mut visit_read: impl FnMut(&mut Self, &mut Variable),
225    ) {
226        match atomic {
227            AtomicOp::Add(binary_operator)
228            | AtomicOp::Sub(binary_operator)
229            | AtomicOp::Max(binary_operator)
230            | AtomicOp::Min(binary_operator)
231            | AtomicOp::And(binary_operator)
232            | AtomicOp::Or(binary_operator)
233            | AtomicOp::Xor(binary_operator)
234            | AtomicOp::Swap(binary_operator) => {
235                self.visit_binop(binary_operator, visit_read);
236            }
237            AtomicOp::Load(unary_operator) => {
238                self.visit_unop(unary_operator, visit_read);
239            }
240            AtomicOp::Store(unary_operator) => {
241                visit_read(self, out.as_mut().unwrap());
242                self.visit_unop(unary_operator, visit_read);
243            }
244            AtomicOp::CompareAndSwap(op) => {
245                visit_read(self, &mut op.cmp);
246                visit_read(self, &mut op.cmp);
247                visit_read(self, &mut op.val);
248            }
249        }
250    }
251    fn visit_meta(
252        &mut self,
253        metadata: &mut Metadata,
254        mut visit_read: impl FnMut(&mut Self, &mut Variable),
255    ) {
256        match metadata {
257            Metadata::Rank { var } => {
258                visit_read(self, var);
259            }
260            Metadata::Stride { dim, var } => {
261                visit_read(self, dim);
262                visit_read(self, var);
263            }
264            Metadata::Shape { dim, var } => {
265                visit_read(self, dim);
266                visit_read(self, var);
267            }
268            Metadata::Length { var } => {
269                visit_read(self, var);
270            }
271            Metadata::BufferLength { var } => {
272                visit_read(self, var);
273            }
274        }
275    }
276
277    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
278        match plane {
279            Plane::Elect => {}
280            Plane::Broadcast(binary_operator)
281            | Plane::Shuffle(binary_operator)
282            | Plane::ShuffleXor(binary_operator)
283            | Plane::ShuffleUp(binary_operator)
284            | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
285            Plane::All(unary_operator)
286            | Plane::Any(unary_operator)
287            | Plane::Sum(unary_operator)
288            | Plane::InclusiveSum(unary_operator)
289            | Plane::ExclusiveSum(unary_operator)
290            | Plane::Prod(unary_operator)
291            | Plane::InclusiveProd(unary_operator)
292            | Plane::ExclusiveProd(unary_operator)
293            | Plane::Min(unary_operator)
294            | Plane::Max(unary_operator)
295            | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
296        }
297    }
298
299    fn visit_cmma(
300        &mut self,
301        cmma: &mut CoopMma,
302        mut visit_read: impl FnMut(&mut Self, &mut Variable),
303    ) {
304        match cmma {
305            CoopMma::Fill { value } => {
306                visit_read(self, value);
307            }
308            CoopMma::Load {
309                value,
310                stride,
311                offset,
312                layout: _,
313            } => {
314                visit_read(self, value);
315                visit_read(self, stride);
316                visit_read(self, offset);
317            }
318            CoopMma::Execute {
319                mat_a,
320                mat_b,
321                mat_c,
322            } => {
323                visit_read(self, mat_a);
324                visit_read(self, mat_b);
325                visit_read(self, mat_c);
326            }
327            CoopMma::Store {
328                mat,
329                stride,
330                offset,
331                layout: _,
332            } => {
333                visit_read(self, mat);
334                visit_read(self, stride);
335                visit_read(self, offset);
336            }
337            CoopMma::Cast { input } => {
338                visit_read(self, input);
339            }
340            CoopMma::RowIndex { lane_id, i, .. } => {
341                visit_read(self, lane_id);
342                visit_read(self, i);
343            }
344            CoopMma::ColIndex { lane_id, i, .. } => {
345                visit_read(self, lane_id);
346                visit_read(self, i);
347            }
348            CoopMma::LoadMatrix { buffer, offset, .. } => {
349                visit_read(self, buffer);
350                visit_read(self, offset);
351            }
352            CoopMma::StoreMatrix {
353                offset, registers, ..
354            } => {
355                visit_read(self, offset);
356                visit_read(self, registers);
357            }
358            CoopMma::ExecuteManual {
359                registers_a,
360                registers_b,
361                registers_c,
362                ..
363            } => {
364                visit_read(self, registers_a);
365                visit_read(self, registers_b);
366                visit_read(self, registers_c);
367            }
368            CoopMma::ExecuteScaled {
369                registers_a,
370                registers_b,
371                registers_c,
372                scales_a,
373                scales_b,
374                ..
375            } => {
376                visit_read(self, registers_a);
377                visit_read(self, registers_b);
378                visit_read(self, registers_c);
379                visit_read(self, scales_a);
380                visit_read(self, scales_b);
381            }
382        }
383    }
384
385    fn visit_barrier(
386        &mut self,
387        barrier_ops: &mut BarrierOps,
388        mut visit_read: impl FnMut(&mut Self, &mut Variable),
389    ) {
390        match barrier_ops {
391            BarrierOps::Declare { barrier } => visit_read(self, barrier),
392            BarrierOps::Init {
393                barrier,
394                is_elected,
395                arrival_count,
396                ..
397            } => {
398                visit_read(self, barrier);
399                visit_read(self, is_elected);
400                visit_read(self, arrival_count);
401            }
402            BarrierOps::InitManual {
403                barrier,
404                arrival_count,
405            } => {
406                visit_read(self, barrier);
407                visit_read(self, arrival_count);
408            }
409            BarrierOps::MemCopyAsync {
410                barrier,
411                source,
412                source_length,
413                offset_source,
414                offset_out,
415            } => {
416                visit_read(self, barrier);
417                visit_read(self, source_length);
418                visit_read(self, source);
419                visit_read(self, offset_source);
420                visit_read(self, offset_out);
421            }
422            BarrierOps::MemCopyAsyncCooperative {
423                barrier,
424                source,
425                source_length,
426                offset_source,
427                offset_out,
428            } => {
429                visit_read(self, barrier);
430                visit_read(self, source_length);
431                visit_read(self, source);
432                visit_read(self, offset_source);
433                visit_read(self, offset_out);
434            }
435            BarrierOps::CopyAsync {
436                source,
437                source_length,
438                offset_source,
439                offset_out,
440                ..
441            } => {
442                visit_read(self, source_length);
443                visit_read(self, source);
444                visit_read(self, offset_source);
445                visit_read(self, offset_out);
446            }
447            BarrierOps::MemCopyAsyncTx {
448                barrier,
449                source,
450                source_length,
451                offset_source,
452                offset_out,
453            } => {
454                visit_read(self, barrier);
455                visit_read(self, source_length);
456                visit_read(self, source);
457                visit_read(self, offset_source);
458                visit_read(self, offset_out);
459            }
460            BarrierOps::TmaLoad {
461                barrier,
462                offset_out,
463                tensor_map,
464                indices,
465            } => {
466                visit_read(self, offset_out);
467                visit_read(self, barrier);
468                visit_read(self, tensor_map);
469                for index in indices {
470                    visit_read(self, index);
471                }
472            }
473            BarrierOps::TmaLoadIm2col {
474                barrier,
475                tensor_map,
476                indices,
477                offset_out,
478                offsets,
479            } => {
480                visit_read(self, offset_out);
481                visit_read(self, barrier);
482                visit_read(self, tensor_map);
483                for index in indices {
484                    visit_read(self, index);
485                }
486                for offset in offsets {
487                    visit_read(self, offset);
488                }
489            }
490            BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
491            BarrierOps::Arrive { barrier } => visit_read(self, barrier),
492            BarrierOps::ArriveTx {
493                barrier,
494                arrive_count_update,
495                transaction_count_update,
496            } => {
497                visit_read(self, barrier);
498                visit_read(self, arrive_count_update);
499                visit_read(self, transaction_count_update);
500            }
501            BarrierOps::CommitCopyAsync { barrier } => visit_read(self, barrier),
502            BarrierOps::ExpectTx {
503                barrier,
504                transaction_count_update,
505            } => {
506                visit_read(self, barrier);
507                visit_read(self, transaction_count_update);
508            }
509            BarrierOps::Wait { barrier, token } => {
510                visit_read(self, barrier);
511                visit_read(self, token);
512            }
513            BarrierOps::WaitParity { barrier, phase } => {
514                visit_read(self, barrier);
515                visit_read(self, phase);
516            }
517        }
518    }
519
520    fn visit_tma(
521        &mut self,
522        tma_ops: &mut TmaOps,
523        mut visit_read: impl FnMut(&mut Self, &mut Variable),
524    ) {
525        match tma_ops {
526            TmaOps::TmaStore {
527                source,
528                coordinates,
529                offset_source,
530            } => {
531                visit_read(self, source);
532                visit_read(self, offset_source);
533                for coord in coordinates {
534                    visit_read(self, coord)
535                }
536            }
537            TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
538        }
539    }
540
541    fn visit_nonsemantic(
542        &mut self,
543        non_semantic: &mut NonSemantic,
544        mut visit_read: impl FnMut(&mut Self, &mut Variable),
545    ) {
546        match non_semantic {
547            NonSemantic::Comment { .. }
548            | NonSemantic::EnterDebugScope
549            | NonSemantic::ExitDebugScope => {}
550            NonSemantic::Print { args, .. } => {
551                for arg in args {
552                    visit_read(self, arg);
553                }
554            }
555        }
556    }
557
558    fn visit_unop(
559        &mut self,
560        unop: &mut UnaryOperator,
561        mut visit_read: impl FnMut(&mut Self, &mut Variable),
562    ) {
563        visit_read(self, &mut unop.input);
564    }
565
566    fn visit_binop(
567        &mut self,
568        binop: &mut BinaryOperator,
569        mut visit_read: impl FnMut(&mut Self, &mut Variable),
570    ) {
571        visit_read(self, &mut binop.lhs);
572        visit_read(self, &mut binop.rhs);
573    }
574}