cubecl_opt/
instructions.rs

1use cubecl_ir::{
2    Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3    Metadata, NonSemantic, Operation, Operator, PipelineOps, Plane, TmaOps, UnaryOperator,
4    Variable,
5};
6
7use super::Optimizer;
8
9impl Optimizer {
10    pub fn visit_out(
11        &mut self,
12        var: &mut Option<Variable>,
13        mut visit_write: impl FnMut(&mut Self, &mut Variable),
14    ) {
15        if let Some(out) = var {
16            visit_write(self, out);
17        }
18    }
19
20    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
21    /// each read or written to variable.
22    pub fn visit_instruction(
23        &mut self,
24        inst: &mut Instruction,
25        visit_read: impl FnMut(&mut Self, &mut Variable),
26        visit_write: impl FnMut(&mut Self, &mut Variable),
27    ) {
28        self.visit_out(&mut inst.out, visit_write);
29        self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
30    }
31
32    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
33    /// each read or written to variable.
34    pub fn visit_operation(
35        &mut self,
36        op: &mut Operation,
37        out: &mut Option<Variable>,
38        mut visit_read: impl FnMut(&mut Self, &mut Variable),
39    ) {
40        match op {
41            Operation::Copy(variable) => visit_read(self, variable),
42            Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
43            Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
44            Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
45            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
46            Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
47            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
48            // Sync has no outputs
49            Operation::Synchronization(_) => {}
50            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
51            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
52            Operation::Branch(_) => unreachable!(),
53            Operation::Pipeline(pipeline_ops) => self.visit_pipeline(pipeline_ops, visit_read),
54            Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
55            Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
56            Operation::NonSemantic(non_semantic) => {
57                self.visit_nonsemantic(non_semantic, visit_read)
58            }
59        }
60    }
61
62    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
63    /// each read or written to variable.
64    pub fn visit_arithmetic(
65        &mut self,
66        op: &mut Arithmetic,
67        mut visit_read: impl FnMut(&mut Self, &mut Variable),
68    ) {
69        match op {
70            Arithmetic::Fma(fma_operator) => {
71                visit_read(self, &mut fma_operator.a);
72                visit_read(self, &mut fma_operator.b);
73                visit_read(self, &mut fma_operator.c);
74            }
75            Arithmetic::Add(binary_operator)
76            | Arithmetic::Sub(binary_operator)
77            | Arithmetic::Mul(binary_operator)
78            | Arithmetic::Div(binary_operator)
79            | Arithmetic::Powf(binary_operator)
80            | Arithmetic::Modulo(binary_operator)
81            | Arithmetic::Max(binary_operator)
82            | Arithmetic::Min(binary_operator)
83            | Arithmetic::Remainder(binary_operator)
84            | Arithmetic::Dot(binary_operator)
85            | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read),
86
87            Arithmetic::Abs(unary_operator)
88            | Arithmetic::Exp(unary_operator)
89            | Arithmetic::Log(unary_operator)
90            | Arithmetic::Log1p(unary_operator)
91            | Arithmetic::Cos(unary_operator)
92            | Arithmetic::Sin(unary_operator)
93            | Arithmetic::Tanh(unary_operator)
94            | Arithmetic::Sqrt(unary_operator)
95            | Arithmetic::Round(unary_operator)
96            | Arithmetic::Floor(unary_operator)
97            | Arithmetic::Ceil(unary_operator)
98            | Arithmetic::Erf(unary_operator)
99            | Arithmetic::Recip(unary_operator)
100            | Arithmetic::Neg(unary_operator)
101            | Arithmetic::Magnitude(unary_operator)
102            | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
103
104            Arithmetic::Clamp(clamp_operator) => {
105                visit_read(self, &mut clamp_operator.input);
106                visit_read(self, &mut clamp_operator.min_value);
107                visit_read(self, &mut clamp_operator.max_value);
108            }
109        }
110    }
111
112    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
113    /// each read or written to variable.
114    pub fn visit_compare(
115        &mut self,
116        op: &mut Comparison,
117        visit_read: impl FnMut(&mut Self, &mut Variable),
118    ) {
119        match op {
120            Comparison::Equal(binary_operator)
121            | Comparison::NotEqual(binary_operator)
122            | Comparison::LowerEqual(binary_operator)
123            | Comparison::Greater(binary_operator)
124            | Comparison::Lower(binary_operator)
125            | Comparison::GreaterEqual(binary_operator) => {
126                self.visit_binop(binary_operator, visit_read)
127            }
128        }
129    }
130
131    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
132    /// each read or written to variable.
133    pub fn visit_bitwise(
134        &mut self,
135        op: &mut Bitwise,
136        visit_read: impl FnMut(&mut Self, &mut Variable),
137    ) {
138        match op {
139            Bitwise::BitwiseAnd(binary_operator)
140            | Bitwise::BitwiseOr(binary_operator)
141            | Bitwise::BitwiseXor(binary_operator)
142            | Bitwise::ShiftLeft(binary_operator)
143            | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
144
145            Bitwise::CountOnes(unary_operator)
146            | Bitwise::BitwiseNot(unary_operator)
147            | Bitwise::ReverseBits(unary_operator)
148            | Bitwise::LeadingZeros(unary_operator)
149            | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
150        }
151    }
152
153    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
154    /// each read or written to variable.
155    pub fn visit_operator(
156        &mut self,
157        op: &mut Operator,
158        mut visit_read: impl FnMut(&mut Self, &mut Variable),
159    ) {
160        match op {
161            Operator::UncheckedIndex(binary_operator)
162            | Operator::UncheckedIndexAssign(binary_operator)
163            | Operator::Index(binary_operator)
164            | Operator::IndexAssign(binary_operator)
165            | Operator::And(binary_operator)
166            | Operator::Or(binary_operator) => self.visit_binop(binary_operator, visit_read),
167            Operator::Not(unary_operator)
168            | Operator::Cast(unary_operator)
169            | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
170            Operator::Slice(slice_operator) => {
171                visit_read(self, &mut slice_operator.start);
172                visit_read(self, &mut slice_operator.end);
173                visit_read(self, &mut slice_operator.input);
174            }
175            Operator::ReinterpretSlice(_) => {
176                todo!()
177            }
178            Operator::InitLine(line_init_operator) => {
179                for input in &mut line_init_operator.inputs {
180                    visit_read(self, input)
181                }
182            }
183            Operator::CopyMemory(copy_operator) => {
184                visit_read(self, &mut copy_operator.input);
185                visit_read(self, &mut copy_operator.in_index);
186                visit_read(self, &mut copy_operator.out_index);
187            }
188            Operator::CopyMemoryBulk(copy_bulk_operator) => {
189                visit_read(self, &mut copy_bulk_operator.input);
190                visit_read(self, &mut copy_bulk_operator.in_index);
191                visit_read(self, &mut copy_bulk_operator.out_index);
192            }
193            Operator::Select(select) => {
194                visit_read(self, &mut select.cond);
195                visit_read(self, &mut select.then);
196                visit_read(self, &mut select.or_else);
197            }
198        }
199    }
200
201    fn visit_atomic(
202        &mut self,
203        atomic: &mut AtomicOp,
204        out: &mut Option<Variable>,
205        mut visit_read: impl FnMut(&mut Self, &mut Variable),
206    ) {
207        match atomic {
208            AtomicOp::Add(binary_operator)
209            | AtomicOp::Sub(binary_operator)
210            | AtomicOp::Max(binary_operator)
211            | AtomicOp::Min(binary_operator)
212            | AtomicOp::And(binary_operator)
213            | AtomicOp::Or(binary_operator)
214            | AtomicOp::Xor(binary_operator)
215            | AtomicOp::Swap(binary_operator) => {
216                self.visit_binop(binary_operator, visit_read);
217            }
218            AtomicOp::Load(unary_operator) => {
219                self.visit_unop(unary_operator, visit_read);
220            }
221            AtomicOp::Store(unary_operator) => {
222                visit_read(self, out.as_mut().unwrap());
223                self.visit_unop(unary_operator, visit_read);
224            }
225            AtomicOp::CompareAndSwap(op) => {
226                visit_read(self, &mut op.cmp);
227                visit_read(self, &mut op.cmp);
228                visit_read(self, &mut op.val);
229            }
230        }
231    }
232    fn visit_meta(
233        &mut self,
234        metadata: &mut Metadata,
235        mut visit_read: impl FnMut(&mut Self, &mut Variable),
236    ) {
237        match metadata {
238            Metadata::Rank { var } => {
239                visit_read(self, var);
240            }
241            Metadata::Stride { dim, var } => {
242                visit_read(self, dim);
243                visit_read(self, var);
244            }
245            Metadata::Shape { dim, var } => {
246                visit_read(self, dim);
247                visit_read(self, var);
248            }
249            Metadata::Length { var } => {
250                visit_read(self, var);
251            }
252            Metadata::BufferLength { var } => {
253                visit_read(self, var);
254            }
255        }
256    }
257
258    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
259        match plane {
260            Plane::Elect => {}
261            Plane::Broadcast(binary_operator) => self.visit_binop(binary_operator, visit_read),
262            Plane::All(unary_operator)
263            | Plane::Any(unary_operator)
264            | Plane::Sum(unary_operator)
265            | Plane::InclusiveSum(unary_operator)
266            | Plane::ExclusiveSum(unary_operator)
267            | Plane::Prod(unary_operator)
268            | Plane::InclusiveProd(unary_operator)
269            | Plane::ExclusiveProd(unary_operator)
270            | Plane::Min(unary_operator)
271            | Plane::Max(unary_operator)
272            | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
273        }
274    }
275
276    fn visit_cmma(
277        &mut self,
278        cmma: &mut CoopMma,
279        mut visit_read: impl FnMut(&mut Self, &mut Variable),
280    ) {
281        match cmma {
282            CoopMma::Fill { value } => {
283                visit_read(self, value);
284            }
285            CoopMma::Load { value, stride, .. } => {
286                visit_read(self, value);
287                visit_read(self, stride);
288            }
289            CoopMma::Execute {
290                mat_a,
291                mat_b,
292                mat_c,
293            } => {
294                visit_read(self, mat_a);
295                visit_read(self, mat_b);
296                visit_read(self, mat_c);
297            }
298            CoopMma::Store { mat, stride, .. } => {
299                visit_read(self, mat);
300                visit_read(self, stride);
301            }
302            CoopMma::Cast { input } => {
303                visit_read(self, input);
304            }
305        }
306    }
307
308    fn visit_pipeline(
309        &mut self,
310        pipeline_ops: &mut PipelineOps,
311        mut visit_read: impl FnMut(&mut Self, &mut Variable),
312    ) {
313        match pipeline_ops {
314            PipelineOps::MemCopyAsync {
315                pipeline,
316                source,
317                destination,
318            } => {
319                visit_read(self, pipeline);
320                visit_read(self, source);
321                visit_read(self, destination);
322            }
323            PipelineOps::ProducerAcquire { pipeline } => visit_read(self, pipeline),
324            PipelineOps::ProducerCommit { pipeline } => visit_read(self, pipeline),
325            PipelineOps::ConsumerWait { pipeline } => visit_read(self, pipeline),
326            PipelineOps::ConsumerRelease { pipeline } => visit_read(self, pipeline),
327        }
328    }
329
330    fn visit_barrier(
331        &mut self,
332        barrier_ops: &mut BarrierOps,
333        mut visit_read: impl FnMut(&mut Self, &mut Variable),
334    ) {
335        match barrier_ops {
336            BarrierOps::Init { barrier, .. } => {
337                visit_read(self, barrier);
338            }
339            BarrierOps::MemCopyAsync { barrier, source } => {
340                visit_read(self, barrier);
341                visit_read(self, source);
342            }
343            BarrierOps::TmaLoad {
344                barrier,
345                tensor_map,
346                indices,
347            } => {
348                visit_read(self, barrier);
349                visit_read(self, tensor_map);
350                for index in indices {
351                    visit_read(self, index);
352                }
353            }
354            BarrierOps::TmaLoadIm2col {
355                barrier,
356                tensor_map,
357                indices,
358                offsets,
359            } => {
360                visit_read(self, barrier);
361                visit_read(self, tensor_map);
362                for index in indices {
363                    visit_read(self, index);
364                }
365                for offset in offsets {
366                    visit_read(self, offset);
367                }
368            }
369            BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
370            BarrierOps::Arrive { barrier } => visit_read(self, barrier),
371            BarrierOps::ArriveTx {
372                barrier,
373                arrive_count_update,
374                transaction_count_update,
375            } => {
376                visit_read(self, barrier);
377                visit_read(self, arrive_count_update);
378                visit_read(self, transaction_count_update);
379            }
380            BarrierOps::ExpectTx {
381                barrier,
382                transaction_count_update,
383            } => {
384                visit_read(self, barrier);
385                visit_read(self, transaction_count_update);
386            }
387            BarrierOps::Wait { barrier } => {
388                visit_read(self, barrier);
389            }
390        }
391    }
392
393    fn visit_tma(
394        &mut self,
395        tma_ops: &mut TmaOps,
396        mut visit_read: impl FnMut(&mut Self, &mut Variable),
397    ) {
398        match tma_ops {
399            TmaOps::TmaStore {
400                source,
401                coordinates,
402            } => {
403                visit_read(self, source);
404                for coord in coordinates {
405                    visit_read(self, coord)
406                }
407            }
408            TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
409        }
410    }
411
412    fn visit_nonsemantic(
413        &mut self,
414        non_semantic: &mut NonSemantic,
415        mut visit_read: impl FnMut(&mut Self, &mut Variable),
416    ) {
417        match non_semantic {
418            NonSemantic::Comment { .. }
419            | NonSemantic::EnterDebugScope
420            | NonSemantic::ExitDebugScope => {}
421            NonSemantic::Print { args, .. } => {
422                for arg in args {
423                    visit_read(self, arg);
424                }
425            }
426        }
427    }
428
429    fn visit_unop(
430        &mut self,
431        unop: &mut UnaryOperator,
432        mut visit_read: impl FnMut(&mut Self, &mut Variable),
433    ) {
434        visit_read(self, &mut unop.input);
435    }
436
437    fn visit_binop(
438        &mut self,
439        binop: &mut BinaryOperator,
440        mut visit_read: impl FnMut(&mut Self, &mut Variable),
441    ) {
442        visit_read(self, &mut binop.lhs);
443        visit_read(self, &mut binop.rhs);
444    }
445}