1use cubecl_ir::{
2 Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3 Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9 pub fn visit_out(
10 &mut self,
11 var: &mut Option<Variable>,
12 mut visit_write: impl FnMut(&mut Self, &mut Variable),
13 ) {
14 if let Some(out) = var {
15 visit_write(self, out);
16 }
17 }
18
19 pub fn visit_instruction(
22 &mut self,
23 inst: &mut Instruction,
24 visit_read: impl FnMut(&mut Self, &mut Variable),
25 visit_write: impl FnMut(&mut Self, &mut Variable),
26 ) {
27 self.visit_out(&mut inst.out, visit_write);
28 self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29 }
30
31 pub fn visit_operation(
34 &mut self,
35 op: &mut Operation,
36 out: &mut Option<Variable>,
37 mut visit_read: impl FnMut(&mut Self, &mut Variable),
38 ) {
39 match op {
40 Operation::Copy(variable) => visit_read(self, variable),
41 Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42 Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43 Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45 Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47 Operation::Synchronization(_) => {}
49 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51 Operation::Branch(_) => unreachable!(),
52 Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53 Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54 Operation::NonSemantic(non_semantic) => {
55 self.visit_nonsemantic(non_semantic, visit_read)
56 }
57 }
58 }
59
60 pub fn visit_arithmetic(
63 &mut self,
64 op: &mut Arithmetic,
65 mut visit_read: impl FnMut(&mut Self, &mut Variable),
66 ) {
67 match op {
68 Arithmetic::Fma(fma_operator) => {
69 visit_read(self, &mut fma_operator.a);
70 visit_read(self, &mut fma_operator.b);
71 visit_read(self, &mut fma_operator.c);
72 }
73 Arithmetic::Add(binary_operator)
74 | Arithmetic::Sub(binary_operator)
75 | Arithmetic::Mul(binary_operator)
76 | Arithmetic::Div(binary_operator)
77 | Arithmetic::Powf(binary_operator)
78 | Arithmetic::Modulo(binary_operator)
79 | Arithmetic::Max(binary_operator)
80 | Arithmetic::Min(binary_operator)
81 | Arithmetic::Remainder(binary_operator)
82 | Arithmetic::Dot(binary_operator)
83 | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read),
84
85 Arithmetic::Abs(unary_operator)
86 | Arithmetic::Exp(unary_operator)
87 | Arithmetic::Log(unary_operator)
88 | Arithmetic::Log1p(unary_operator)
89 | Arithmetic::Cos(unary_operator)
90 | Arithmetic::Sin(unary_operator)
91 | Arithmetic::Tanh(unary_operator)
92 | Arithmetic::Sqrt(unary_operator)
93 | Arithmetic::Round(unary_operator)
94 | Arithmetic::Floor(unary_operator)
95 | Arithmetic::Ceil(unary_operator)
96 | Arithmetic::Erf(unary_operator)
97 | Arithmetic::Recip(unary_operator)
98 | Arithmetic::Neg(unary_operator)
99 | Arithmetic::Magnitude(unary_operator)
100 | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
101
102 Arithmetic::Clamp(clamp_operator) => {
103 visit_read(self, &mut clamp_operator.input);
104 visit_read(self, &mut clamp_operator.min_value);
105 visit_read(self, &mut clamp_operator.max_value);
106 }
107 }
108 }
109
110 pub fn visit_compare(
113 &mut self,
114 op: &mut Comparison,
115 visit_read: impl FnMut(&mut Self, &mut Variable),
116 ) {
117 match op {
118 Comparison::Equal(binary_operator)
119 | Comparison::NotEqual(binary_operator)
120 | Comparison::LowerEqual(binary_operator)
121 | Comparison::Greater(binary_operator)
122 | Comparison::Lower(binary_operator)
123 | Comparison::GreaterEqual(binary_operator) => {
124 self.visit_binop(binary_operator, visit_read)
125 }
126 }
127 }
128
129 pub fn visit_bitwise(
132 &mut self,
133 op: &mut Bitwise,
134 visit_read: impl FnMut(&mut Self, &mut Variable),
135 ) {
136 match op {
137 Bitwise::BitwiseAnd(binary_operator)
138 | Bitwise::BitwiseOr(binary_operator)
139 | Bitwise::BitwiseXor(binary_operator)
140 | Bitwise::ShiftLeft(binary_operator)
141 | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
142
143 Bitwise::CountOnes(unary_operator)
144 | Bitwise::BitwiseNot(unary_operator)
145 | Bitwise::ReverseBits(unary_operator)
146 | Bitwise::LeadingZeros(unary_operator)
147 | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
148 }
149 }
150
151 pub fn visit_operator(
154 &mut self,
155 op: &mut Operator,
156 mut visit_read: impl FnMut(&mut Self, &mut Variable),
157 ) {
158 match op {
159 Operator::And(binary_operator) | Operator::Or(binary_operator) => {
160 self.visit_binop(binary_operator, visit_read)
161 }
162 Operator::Not(unary_operator)
163 | Operator::Cast(unary_operator)
164 | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
165 Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
166 visit_read(self, &mut index_operator.list);
167 visit_read(self, &mut index_operator.index);
168 }
169 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
170 visit_read(self, &mut op.index);
171 visit_read(self, &mut op.value);
172 }
173 Operator::InitLine(line_init_operator) => {
174 for input in &mut line_init_operator.inputs {
175 visit_read(self, input)
176 }
177 }
178 Operator::CopyMemory(copy_operator) => {
179 visit_read(self, &mut copy_operator.input);
180 visit_read(self, &mut copy_operator.in_index);
181 visit_read(self, &mut copy_operator.out_index);
182 }
183 Operator::CopyMemoryBulk(copy_bulk_operator) => {
184 visit_read(self, &mut copy_bulk_operator.input);
185 visit_read(self, &mut copy_bulk_operator.in_index);
186 visit_read(self, &mut copy_bulk_operator.out_index);
187 }
188 Operator::Select(select) => {
189 visit_read(self, &mut select.cond);
190 visit_read(self, &mut select.then);
191 visit_read(self, &mut select.or_else);
192 }
193 }
194 }
195
196 fn visit_atomic(
197 &mut self,
198 atomic: &mut AtomicOp,
199 out: &mut Option<Variable>,
200 mut visit_read: impl FnMut(&mut Self, &mut Variable),
201 ) {
202 match atomic {
203 AtomicOp::Add(binary_operator)
204 | AtomicOp::Sub(binary_operator)
205 | AtomicOp::Max(binary_operator)
206 | AtomicOp::Min(binary_operator)
207 | AtomicOp::And(binary_operator)
208 | AtomicOp::Or(binary_operator)
209 | AtomicOp::Xor(binary_operator)
210 | AtomicOp::Swap(binary_operator) => {
211 self.visit_binop(binary_operator, visit_read);
212 }
213 AtomicOp::Load(unary_operator) => {
214 self.visit_unop(unary_operator, visit_read);
215 }
216 AtomicOp::Store(unary_operator) => {
217 visit_read(self, out.as_mut().unwrap());
218 self.visit_unop(unary_operator, visit_read);
219 }
220 AtomicOp::CompareAndSwap(op) => {
221 visit_read(self, &mut op.cmp);
222 visit_read(self, &mut op.cmp);
223 visit_read(self, &mut op.val);
224 }
225 }
226 }
227 fn visit_meta(
228 &mut self,
229 metadata: &mut Metadata,
230 mut visit_read: impl FnMut(&mut Self, &mut Variable),
231 ) {
232 match metadata {
233 Metadata::Rank { var } => {
234 visit_read(self, var);
235 }
236 Metadata::Stride { dim, var } => {
237 visit_read(self, dim);
238 visit_read(self, var);
239 }
240 Metadata::Shape { dim, var } => {
241 visit_read(self, dim);
242 visit_read(self, var);
243 }
244 Metadata::Length { var } => {
245 visit_read(self, var);
246 }
247 Metadata::BufferLength { var } => {
248 visit_read(self, var);
249 }
250 }
251 }
252
253 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
254 match plane {
255 Plane::Elect => {}
256 Plane::Broadcast(binary_operator) => self.visit_binop(binary_operator, visit_read),
257 Plane::All(unary_operator)
258 | Plane::Any(unary_operator)
259 | Plane::Sum(unary_operator)
260 | Plane::InclusiveSum(unary_operator)
261 | Plane::ExclusiveSum(unary_operator)
262 | Plane::Prod(unary_operator)
263 | Plane::InclusiveProd(unary_operator)
264 | Plane::ExclusiveProd(unary_operator)
265 | Plane::Min(unary_operator)
266 | Plane::Max(unary_operator)
267 | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
268 }
269 }
270
271 fn visit_cmma(
272 &mut self,
273 cmma: &mut CoopMma,
274 mut visit_read: impl FnMut(&mut Self, &mut Variable),
275 ) {
276 match cmma {
277 CoopMma::Fill { value } => {
278 visit_read(self, value);
279 }
280 CoopMma::Load {
281 value,
282 stride,
283 offset,
284 layout: _,
285 } => {
286 visit_read(self, value);
287 visit_read(self, stride);
288 visit_read(self, offset);
289 }
290 CoopMma::Execute {
291 mat_a,
292 mat_b,
293 mat_c,
294 } => {
295 visit_read(self, mat_a);
296 visit_read(self, mat_b);
297 visit_read(self, mat_c);
298 }
299 CoopMma::Store {
300 mat,
301 stride,
302 offset,
303 layout: _,
304 } => {
305 visit_read(self, mat);
306 visit_read(self, stride);
307 visit_read(self, offset);
308 }
309 CoopMma::Cast { input } => {
310 visit_read(self, input);
311 }
312 }
313 }
314
315 fn visit_barrier(
316 &mut self,
317 barrier_ops: &mut BarrierOps,
318 mut visit_read: impl FnMut(&mut Self, &mut Variable),
319 ) {
320 match barrier_ops {
321 BarrierOps::Init { barrier, .. } => {
322 visit_read(self, barrier);
323 }
324 BarrierOps::MemCopyAsync {
325 barrier,
326 source,
327 source_length,
328 offset_source,
329 offset_out,
330 } => {
331 visit_read(self, barrier);
332 visit_read(self, source_length);
333 visit_read(self, source);
334 visit_read(self, offset_source);
335 visit_read(self, offset_out);
336 }
337 BarrierOps::TmaLoad {
338 barrier,
339 offset_out,
340 tensor_map,
341 indices,
342 } => {
343 visit_read(self, offset_out);
344 visit_read(self, barrier);
345 visit_read(self, tensor_map);
346 for index in indices {
347 visit_read(self, index);
348 }
349 }
350 BarrierOps::TmaLoadIm2col {
351 barrier,
352 tensor_map,
353 indices,
354 offset_out,
355 offsets,
356 } => {
357 visit_read(self, offset_out);
358 visit_read(self, barrier);
359 visit_read(self, tensor_map);
360 for index in indices {
361 visit_read(self, index);
362 }
363 for offset in offsets {
364 visit_read(self, offset);
365 }
366 }
367 BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
368 BarrierOps::Arrive { barrier } => visit_read(self, barrier),
369 BarrierOps::ArriveTx {
370 barrier,
371 arrive_count_update,
372 transaction_count_update,
373 } => {
374 visit_read(self, barrier);
375 visit_read(self, arrive_count_update);
376 visit_read(self, transaction_count_update);
377 }
378 BarrierOps::ExpectTx {
379 barrier,
380 transaction_count_update,
381 } => {
382 visit_read(self, barrier);
383 visit_read(self, transaction_count_update);
384 }
385 BarrierOps::Wait { barrier } => {
386 visit_read(self, barrier);
387 }
388 }
389 }
390
391 fn visit_tma(
392 &mut self,
393 tma_ops: &mut TmaOps,
394 mut visit_read: impl FnMut(&mut Self, &mut Variable),
395 ) {
396 match tma_ops {
397 TmaOps::TmaStore {
398 source,
399 coordinates,
400 offset_source,
401 } => {
402 visit_read(self, source);
403 visit_read(self, offset_source);
404 for coord in coordinates {
405 visit_read(self, coord)
406 }
407 }
408 TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
409 }
410 }
411
412 fn visit_nonsemantic(
413 &mut self,
414 non_semantic: &mut NonSemantic,
415 mut visit_read: impl FnMut(&mut Self, &mut Variable),
416 ) {
417 match non_semantic {
418 NonSemantic::Comment { .. }
419 | NonSemantic::EnterDebugScope
420 | NonSemantic::ExitDebugScope => {}
421 NonSemantic::Print { args, .. } => {
422 for arg in args {
423 visit_read(self, arg);
424 }
425 }
426 }
427 }
428
429 fn visit_unop(
430 &mut self,
431 unop: &mut UnaryOperator,
432 mut visit_read: impl FnMut(&mut Self, &mut Variable),
433 ) {
434 visit_read(self, &mut unop.input);
435 }
436
437 fn visit_binop(
438 &mut self,
439 binop: &mut BinaryOperator,
440 mut visit_read: impl FnMut(&mut Self, &mut Variable),
441 ) {
442 visit_read(self, &mut binop.lhs);
443 visit_read(self, &mut binop.rhs);
444 }
445}