cubecl_opt/
instructions.rs

1use cubecl_core::ir::{
2    AtomicOp, BinaryOperator, CoopMma, Instruction, Metadata, Operation, Operator, Plane,
3    UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9    pub fn visit_out(
10        &mut self,
11        var: &mut Option<Variable>,
12        mut visit_write: impl FnMut(&mut Self, &mut Variable),
13    ) {
14        if let Some(out) = var {
15            visit_write(self, out);
16        }
17    }
18
19    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
20    /// each read or written to variable.
21    pub fn visit_instruction(
22        &mut self,
23        inst: &mut Instruction,
24        visit_read: impl FnMut(&mut Self, &mut Variable),
25        visit_write: impl FnMut(&mut Self, &mut Variable),
26    ) {
27        self.visit_out(&mut inst.out, visit_write);
28        self.visit_operation(&mut inst.operation, visit_read);
29    }
30
31    /// Visit an operation with a set of read and write visitors. Each visitor will be called with
32    /// each read or written to variable.
33    pub fn visit_operation(
34        &mut self,
35        op: &mut Operation,
36        mut visit_read: impl FnMut(&mut Self, &mut Variable),
37    ) {
38        match op {
39            Operation::Copy(variable) => visit_read(self, variable),
40            Operation::Operator(operator) => self.visit_operator(operator, visit_read),
41            Operation::Atomic(atomic) => self.visit_atomic(atomic, visit_read),
42            Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
43            // Sync has no outputs
44            Operation::Synchronization(_) | Operation::NonSemantic(_) => {}
45            Operation::Plane(plane) => self.visit_plane(plane, visit_read),
46            Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
47            Operation::Branch(_) => unreachable!(),
48        }
49    }
50
51    /// Visit an operator with a set of read and write visitors. Each visitor will be called with
52    /// each read or written to variable.
53    pub fn visit_operator(
54        &mut self,
55        op: &mut Operator,
56        mut visit_read: impl FnMut(&mut Self, &mut Variable),
57    ) {
58        match op {
59            Operator::Fma(fma_operator) => {
60                visit_read(self, &mut fma_operator.a);
61                visit_read(self, &mut fma_operator.b);
62                visit_read(self, &mut fma_operator.c);
63            }
64            Operator::Add(binary_operator)
65            | Operator::Sub(binary_operator)
66            | Operator::Mul(binary_operator)
67            | Operator::Div(binary_operator)
68            | Operator::Powf(binary_operator)
69            | Operator::Equal(binary_operator)
70            | Operator::NotEqual(binary_operator)
71            | Operator::LowerEqual(binary_operator)
72            | Operator::UncheckedIndex(binary_operator)
73            | Operator::UncheckedIndexAssign(binary_operator)
74            | Operator::Modulo(binary_operator)
75            | Operator::Index(binary_operator)
76            | Operator::IndexAssign(binary_operator)
77            | Operator::And(binary_operator)
78            | Operator::Greater(binary_operator)
79            | Operator::Lower(binary_operator)
80            | Operator::Or(binary_operator)
81            | Operator::Max(binary_operator)
82            | Operator::Min(binary_operator)
83            | Operator::BitwiseAnd(binary_operator)
84            | Operator::BitwiseOr(binary_operator)
85            | Operator::BitwiseXor(binary_operator)
86            | Operator::ShiftLeft(binary_operator)
87            | Operator::ShiftRight(binary_operator)
88            | Operator::Remainder(binary_operator)
89            | Operator::Dot(binary_operator)
90            | Operator::GreaterEqual(binary_operator) => {
91                self.visit_binop(binary_operator, visit_read)
92            }
93
94            Operator::Abs(unary_operator)
95            | Operator::Exp(unary_operator)
96            | Operator::Log(unary_operator)
97            | Operator::Log1p(unary_operator)
98            | Operator::Cos(unary_operator)
99            | Operator::Sin(unary_operator)
100            | Operator::Tanh(unary_operator)
101            | Operator::Sqrt(unary_operator)
102            | Operator::Round(unary_operator)
103            | Operator::Floor(unary_operator)
104            | Operator::Ceil(unary_operator)
105            | Operator::Erf(unary_operator)
106            | Operator::Recip(unary_operator)
107            | Operator::Not(unary_operator)
108            | Operator::Neg(unary_operator)
109            | Operator::Cast(unary_operator)
110            | Operator::Bitcast(unary_operator)
111            | Operator::Magnitude(unary_operator)
112            | Operator::Normalize(unary_operator)
113            | Operator::CountOnes(unary_operator)
114            | Operator::ReverseBits(unary_operator) => self.visit_unop(unary_operator, visit_read),
115
116            Operator::Clamp(clamp_operator) => {
117                visit_read(self, &mut clamp_operator.input);
118                visit_read(self, &mut clamp_operator.min_value);
119                visit_read(self, &mut clamp_operator.max_value);
120            }
121            Operator::Slice(slice_operator) => {
122                visit_read(self, &mut slice_operator.start);
123                visit_read(self, &mut slice_operator.end);
124                visit_read(self, &mut slice_operator.input);
125            }
126            Operator::InitLine(line_init_operator) => {
127                for input in &mut line_init_operator.inputs {
128                    visit_read(self, input)
129                }
130            }
131            Operator::CopyMemory(copy_operator) => {
132                visit_read(self, &mut copy_operator.input);
133                visit_read(self, &mut copy_operator.in_index);
134                visit_read(self, &mut copy_operator.out_index);
135            }
136            Operator::CopyMemoryBulk(copy_bulk_operator) => {
137                visit_read(self, &mut copy_bulk_operator.input);
138                visit_read(self, &mut copy_bulk_operator.in_index);
139                visit_read(self, &mut copy_bulk_operator.out_index);
140            }
141            Operator::Select(select) => {
142                visit_read(self, &mut select.cond);
143                visit_read(self, &mut select.then);
144                visit_read(self, &mut select.or_else);
145            }
146        }
147    }
148
149    fn visit_atomic(
150        &mut self,
151        atomic: &mut AtomicOp,
152        mut visit_read: impl FnMut(&mut Self, &mut Variable),
153    ) {
154        match atomic {
155            AtomicOp::Add(binary_operator)
156            | AtomicOp::Sub(binary_operator)
157            | AtomicOp::Max(binary_operator)
158            | AtomicOp::Min(binary_operator)
159            | AtomicOp::And(binary_operator)
160            | AtomicOp::Or(binary_operator)
161            | AtomicOp::Xor(binary_operator)
162            | AtomicOp::Swap(binary_operator) => {
163                self.visit_binop(binary_operator, visit_read);
164            }
165            AtomicOp::Load(unary_operator) | AtomicOp::Store(unary_operator) => {
166                self.visit_unop(unary_operator, visit_read);
167            }
168            AtomicOp::CompareAndSwap(op) => {
169                visit_read(self, &mut op.cmp);
170                visit_read(self, &mut op.cmp);
171                visit_read(self, &mut op.val);
172            }
173        }
174    }
175    fn visit_meta(
176        &mut self,
177        metadata: &mut Metadata,
178        mut visit_read: impl FnMut(&mut Self, &mut Variable),
179    ) {
180        match metadata {
181            Metadata::Rank { var } => {
182                visit_read(self, var);
183            }
184            Metadata::Stride { dim, var } => {
185                visit_read(self, dim);
186                visit_read(self, var);
187            }
188            Metadata::Shape { dim, var } => {
189                visit_read(self, dim);
190                visit_read(self, var);
191            }
192            Metadata::Length { var } => {
193                visit_read(self, var);
194            }
195            Metadata::BufferLength { var } => {
196                visit_read(self, var);
197            }
198        }
199    }
200
201    fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
202        match plane {
203            Plane::Elect => {}
204            Plane::Broadcast(binary_operator) => self.visit_binop(binary_operator, visit_read),
205            Plane::All(unary_operator)
206            | Plane::Any(unary_operator)
207            | Plane::Sum(unary_operator)
208            | Plane::Prod(unary_operator)
209            | Plane::Min(unary_operator)
210            | Plane::Max(unary_operator) => self.visit_unop(unary_operator, visit_read),
211        }
212    }
213
214    fn visit_cmma(
215        &mut self,
216        cmma: &mut CoopMma,
217        mut visit_read: impl FnMut(&mut Self, &mut Variable),
218    ) {
219        match cmma {
220            CoopMma::Fill { value } => {
221                visit_read(self, value);
222            }
223            CoopMma::Load { value, stride, .. } => {
224                visit_read(self, value);
225                visit_read(self, stride);
226            }
227            CoopMma::Execute {
228                mat_a,
229                mat_b,
230                mat_c,
231            } => {
232                visit_read(self, mat_a);
233                visit_read(self, mat_b);
234                visit_read(self, mat_c);
235            }
236            CoopMma::Store { mat, stride, .. } => {
237                visit_read(self, mat);
238                visit_read(self, stride);
239            }
240            CoopMma::Cast { input } => {
241                visit_read(self, input);
242            }
243        }
244    }
245
246    fn visit_unop(
247        &mut self,
248        unop: &mut UnaryOperator,
249        mut visit_read: impl FnMut(&mut Self, &mut Variable),
250    ) {
251        visit_read(self, &mut unop.input);
252    }
253
254    fn visit_binop(
255        &mut self,
256        binop: &mut BinaryOperator,
257        mut visit_read: impl FnMut(&mut Self, &mut Variable),
258    ) {
259        visit_read(self, &mut binop.lhs);
260        visit_read(self, &mut binop.rhs);
261    }
262}