1use cubecl_ir::{
2 Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3 Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use crate::ControlFlow;
7
8use super::Optimizer;
9
10impl Optimizer {
11 pub fn visit_out(
12 &mut self,
13 var: &mut Option<Variable>,
14 mut visit_write: impl FnMut(&mut Self, &mut Variable),
15 ) {
16 if let Some(out) = var {
17 visit_write(self, out);
18 }
19 }
20
21 pub fn visit_instruction(
24 &mut self,
25 inst: &mut Instruction,
26 visit_read: impl FnMut(&mut Self, &mut Variable),
27 visit_write: impl FnMut(&mut Self, &mut Variable),
28 ) {
29 self.visit_out(&mut inst.out, visit_write);
30 self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
31 }
32
33 pub fn visit_operation(
36 &mut self,
37 op: &mut Operation,
38 out: &mut Option<Variable>,
39 mut visit_read: impl FnMut(&mut Self, &mut Variable),
40 ) {
41 match op {
42 Operation::Copy(variable) => visit_read(self, variable),
43 Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
44 Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
45 Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
46 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
47 Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
48 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
49 Operation::Synchronization(_) => {}
51 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
52 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
53 Operation::Branch(_) => unreachable!(),
54 Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
55 Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
56 Operation::NonSemantic(non_semantic) => {
57 self.visit_nonsemantic(non_semantic, visit_read)
58 }
59 Operation::Marker(_) => {}
60 }
61 }
62
63 pub fn visit_control_flow(
66 &mut self,
67 op: &mut ControlFlow,
68 mut visit_read: impl FnMut(&mut Self, &mut Variable),
69 ) {
70 match op {
71 ControlFlow::IfElse { cond, .. } => visit_read(self, cond),
72 ControlFlow::Switch { value, .. } => visit_read(self, value),
73 ControlFlow::Loop { .. } => {}
74 ControlFlow::LoopBreak { break_cond, .. } => visit_read(self, break_cond),
75 ControlFlow::Return | ControlFlow::Unreachable | ControlFlow::None => {}
76 }
77 }
78
79 pub fn visit_arithmetic(
82 &mut self,
83 op: &mut Arithmetic,
84 mut visit_read: impl FnMut(&mut Self, &mut Variable),
85 ) {
86 match op {
87 Arithmetic::Fma(fma_operator) => {
88 visit_read(self, &mut fma_operator.a);
89 visit_read(self, &mut fma_operator.b);
90 visit_read(self, &mut fma_operator.c);
91 }
92 Arithmetic::Add(binary_operator)
93 | Arithmetic::SaturatingAdd(binary_operator)
94 | Arithmetic::Sub(binary_operator)
95 | Arithmetic::SaturatingSub(binary_operator)
96 | Arithmetic::Mul(binary_operator)
97 | Arithmetic::Div(binary_operator)
98 | Arithmetic::Powf(binary_operator)
99 | Arithmetic::Powi(binary_operator)
100 | Arithmetic::Hypot(binary_operator)
101 | Arithmetic::Rhypot(binary_operator)
102 | Arithmetic::Modulo(binary_operator)
103 | Arithmetic::Max(binary_operator)
104 | Arithmetic::Min(binary_operator)
105 | Arithmetic::Remainder(binary_operator)
106 | Arithmetic::Dot(binary_operator)
107 | Arithmetic::MulHi(binary_operator)
108 | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
109
110 Arithmetic::Abs(unary_operator)
111 | Arithmetic::Exp(unary_operator)
112 | Arithmetic::Log(unary_operator)
113 | Arithmetic::Log1p(unary_operator)
114 | Arithmetic::Cos(unary_operator)
115 | Arithmetic::Sin(unary_operator)
116 | Arithmetic::Tan(unary_operator)
117 | Arithmetic::Tanh(unary_operator)
118 | Arithmetic::Sinh(unary_operator)
119 | Arithmetic::Cosh(unary_operator)
120 | Arithmetic::ArcCos(unary_operator)
121 | Arithmetic::ArcSin(unary_operator)
122 | Arithmetic::ArcTan(unary_operator)
123 | Arithmetic::ArcSinh(unary_operator)
124 | Arithmetic::ArcCosh(unary_operator)
125 | Arithmetic::ArcTanh(unary_operator)
126 | Arithmetic::Degrees(unary_operator)
127 | Arithmetic::Radians(unary_operator)
128 | Arithmetic::Sqrt(unary_operator)
129 | Arithmetic::InverseSqrt(unary_operator)
130 | Arithmetic::Round(unary_operator)
131 | Arithmetic::Floor(unary_operator)
132 | Arithmetic::Ceil(unary_operator)
133 | Arithmetic::Trunc(unary_operator)
134 | Arithmetic::Erf(unary_operator)
135 | Arithmetic::Recip(unary_operator)
136 | Arithmetic::Neg(unary_operator)
137 | Arithmetic::Magnitude(unary_operator)
138 | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
139
140 Arithmetic::Clamp(clamp_operator) => {
141 visit_read(self, &mut clamp_operator.input);
142 visit_read(self, &mut clamp_operator.min_value);
143 visit_read(self, &mut clamp_operator.max_value);
144 }
145 }
146 }
147
148 pub fn visit_compare(
151 &mut self,
152 op: &mut Comparison,
153 visit_read: impl FnMut(&mut Self, &mut Variable),
154 ) {
155 match op {
156 Comparison::Equal(binary_operator)
157 | Comparison::NotEqual(binary_operator)
158 | Comparison::LowerEqual(binary_operator)
159 | Comparison::Greater(binary_operator)
160 | Comparison::Lower(binary_operator)
161 | Comparison::GreaterEqual(binary_operator) => {
162 self.visit_binop(binary_operator, visit_read)
163 }
164 Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
165 self.visit_unop(unary_operator, visit_read)
166 }
167 }
168 }
169
170 pub fn visit_bitwise(
173 &mut self,
174 op: &mut Bitwise,
175 visit_read: impl FnMut(&mut Self, &mut Variable),
176 ) {
177 match op {
178 Bitwise::BitwiseAnd(binary_operator)
179 | Bitwise::BitwiseOr(binary_operator)
180 | Bitwise::BitwiseXor(binary_operator)
181 | Bitwise::ShiftLeft(binary_operator)
182 | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
183
184 Bitwise::CountOnes(unary_operator)
185 | Bitwise::BitwiseNot(unary_operator)
186 | Bitwise::ReverseBits(unary_operator)
187 | Bitwise::LeadingZeros(unary_operator)
188 | Bitwise::TrailingZeros(unary_operator)
189 | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
190 }
191 }
192
193 pub fn visit_operator(
196 &mut self,
197 op: &mut Operator,
198 mut visit_read: impl FnMut(&mut Self, &mut Variable),
199 ) {
200 match op {
201 Operator::And(binary_operator) | Operator::Or(binary_operator) => {
202 self.visit_binop(binary_operator, visit_read)
203 }
204 Operator::Not(unary_operator)
205 | Operator::Cast(unary_operator)
206 | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
207 Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
208 visit_read(self, &mut index_operator.list);
209 visit_read(self, &mut index_operator.index);
210 }
211 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
212 visit_read(self, &mut op.index);
213 visit_read(self, &mut op.value);
214 }
215 Operator::InitVector(vector_init_operator) => {
216 for input in &mut vector_init_operator.inputs {
217 visit_read(self, input)
218 }
219 }
220 Operator::CopyMemory(copy_operator) => {
221 visit_read(self, &mut copy_operator.input);
222 visit_read(self, &mut copy_operator.in_index);
223 visit_read(self, &mut copy_operator.out_index);
224 }
225 Operator::CopyMemoryBulk(copy_bulk_operator) => {
226 visit_read(self, &mut copy_bulk_operator.input);
227 visit_read(self, &mut copy_bulk_operator.in_index);
228 visit_read(self, &mut copy_bulk_operator.out_index);
229 }
230 Operator::Select(select) => {
231 visit_read(self, &mut select.cond);
232 visit_read(self, &mut select.then);
233 visit_read(self, &mut select.or_else);
234 }
235 }
236 }
237
238 fn visit_atomic(
239 &mut self,
240 atomic: &mut AtomicOp,
241 out: &mut Option<Variable>,
242 mut visit_read: impl FnMut(&mut Self, &mut Variable),
243 ) {
244 match atomic {
245 AtomicOp::Add(binary_operator)
246 | AtomicOp::Sub(binary_operator)
247 | AtomicOp::Max(binary_operator)
248 | AtomicOp::Min(binary_operator)
249 | AtomicOp::And(binary_operator)
250 | AtomicOp::Or(binary_operator)
251 | AtomicOp::Xor(binary_operator)
252 | AtomicOp::Swap(binary_operator) => {
253 self.visit_binop(binary_operator, visit_read);
254 }
255 AtomicOp::Load(unary_operator) => {
256 self.visit_unop(unary_operator, visit_read);
257 }
258 AtomicOp::Store(unary_operator) => {
259 visit_read(self, out.as_mut().unwrap());
260 self.visit_unop(unary_operator, visit_read);
261 }
262 AtomicOp::CompareAndSwap(op) => {
263 visit_read(self, &mut op.cmp);
264 visit_read(self, &mut op.cmp);
265 visit_read(self, &mut op.val);
266 }
267 }
268 }
269 fn visit_meta(
270 &mut self,
271 metadata: &mut Metadata,
272 mut visit_read: impl FnMut(&mut Self, &mut Variable),
273 ) {
274 match metadata {
275 Metadata::Rank { var } => {
276 visit_read(self, var);
277 }
278 Metadata::Stride { dim, var } => {
279 visit_read(self, dim);
280 visit_read(self, var);
281 }
282 Metadata::Shape { dim, var } => {
283 visit_read(self, dim);
284 visit_read(self, var);
285 }
286 Metadata::Length { var } => {
287 visit_read(self, var);
288 }
289 Metadata::BufferLength { var } => {
290 visit_read(self, var);
291 }
292 }
293 }
294
295 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
296 match plane {
297 Plane::Elect => {}
298 Plane::Broadcast(binary_operator)
299 | Plane::Shuffle(binary_operator)
300 | Plane::ShuffleXor(binary_operator)
301 | Plane::ShuffleUp(binary_operator)
302 | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
303 Plane::All(unary_operator)
304 | Plane::Any(unary_operator)
305 | Plane::Sum(unary_operator)
306 | Plane::InclusiveSum(unary_operator)
307 | Plane::ExclusiveSum(unary_operator)
308 | Plane::Prod(unary_operator)
309 | Plane::InclusiveProd(unary_operator)
310 | Plane::ExclusiveProd(unary_operator)
311 | Plane::Min(unary_operator)
312 | Plane::Max(unary_operator)
313 | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
314 }
315 }
316
317 fn visit_cmma(
318 &mut self,
319 cmma: &mut CoopMma,
320 mut visit_read: impl FnMut(&mut Self, &mut Variable),
321 ) {
322 match cmma {
323 CoopMma::Fill { value } => {
324 visit_read(self, value);
325 }
326 CoopMma::Load {
327 value,
328 stride,
329 offset,
330 layout: _,
331 } => {
332 visit_read(self, value);
333 visit_read(self, stride);
334 visit_read(self, offset);
335 }
336 CoopMma::Execute {
337 mat_a,
338 mat_b,
339 mat_c,
340 } => {
341 visit_read(self, mat_a);
342 visit_read(self, mat_b);
343 visit_read(self, mat_c);
344 }
345 CoopMma::Store {
346 mat,
347 stride,
348 offset,
349 layout: _,
350 } => {
351 visit_read(self, mat);
352 visit_read(self, stride);
353 visit_read(self, offset);
354 }
355 CoopMma::Cast { input } => {
356 visit_read(self, input);
357 }
358 CoopMma::RowIndex { lane_id, i, .. } => {
359 visit_read(self, lane_id);
360 visit_read(self, i);
361 }
362 CoopMma::ColIndex { lane_id, i, .. } => {
363 visit_read(self, lane_id);
364 visit_read(self, i);
365 }
366 CoopMma::LoadMatrix { buffer, offset, .. } => {
367 visit_read(self, buffer);
368 visit_read(self, offset);
369 }
370 CoopMma::StoreMatrix {
371 offset, registers, ..
372 } => {
373 visit_read(self, offset);
374 visit_read(self, registers);
375 }
376 CoopMma::ExecuteManual {
377 registers_a,
378 registers_b,
379 registers_c,
380 ..
381 } => {
382 visit_read(self, registers_a);
383 visit_read(self, registers_b);
384 visit_read(self, registers_c);
385 }
386 CoopMma::ExecuteScaled {
387 registers_a,
388 registers_b,
389 registers_c,
390 scales_a,
391 scales_b,
392 ..
393 } => {
394 visit_read(self, registers_a);
395 visit_read(self, registers_b);
396 visit_read(self, registers_c);
397 visit_read(self, scales_a);
398 visit_read(self, scales_b);
399 }
400 }
401 }
402
403 fn visit_barrier(
404 &mut self,
405 barrier_ops: &mut BarrierOps,
406 mut visit_read: impl FnMut(&mut Self, &mut Variable),
407 ) {
408 match barrier_ops {
409 BarrierOps::Declare { barrier } => visit_read(self, barrier),
410 BarrierOps::Init {
411 barrier,
412 is_elected,
413 arrival_count,
414 ..
415 } => {
416 visit_read(self, barrier);
417 visit_read(self, is_elected);
418 visit_read(self, arrival_count);
419 }
420 BarrierOps::InitManual {
421 barrier,
422 arrival_count,
423 } => {
424 visit_read(self, barrier);
425 visit_read(self, arrival_count);
426 }
427 BarrierOps::MemCopyAsync {
428 barrier,
429 source,
430 source_length,
431 offset_source,
432 offset_out,
433 } => {
434 visit_read(self, barrier);
435 visit_read(self, source_length);
436 visit_read(self, source);
437 visit_read(self, offset_source);
438 visit_read(self, offset_out);
439 }
440 BarrierOps::MemCopyAsyncCooperative {
441 barrier,
442 source,
443 source_length,
444 offset_source,
445 offset_out,
446 } => {
447 visit_read(self, barrier);
448 visit_read(self, source_length);
449 visit_read(self, source);
450 visit_read(self, offset_source);
451 visit_read(self, offset_out);
452 }
453 BarrierOps::CopyAsync {
454 source,
455 source_length,
456 offset_source,
457 offset_out,
458 ..
459 } => {
460 visit_read(self, source_length);
461 visit_read(self, source);
462 visit_read(self, offset_source);
463 visit_read(self, offset_out);
464 }
465 BarrierOps::MemCopyAsyncTx {
466 barrier,
467 source,
468 source_length,
469 offset_source,
470 offset_out,
471 } => {
472 visit_read(self, barrier);
473 visit_read(self, source_length);
474 visit_read(self, source);
475 visit_read(self, offset_source);
476 visit_read(self, offset_out);
477 }
478 BarrierOps::TmaLoad {
479 barrier,
480 offset_out,
481 tensor_map,
482 indices,
483 } => {
484 visit_read(self, offset_out);
485 visit_read(self, barrier);
486 visit_read(self, tensor_map);
487 for index in indices {
488 visit_read(self, index);
489 }
490 }
491 BarrierOps::TmaLoadIm2col {
492 barrier,
493 tensor_map,
494 indices,
495 offset_out,
496 offsets,
497 } => {
498 visit_read(self, offset_out);
499 visit_read(self, barrier);
500 visit_read(self, tensor_map);
501 for index in indices {
502 visit_read(self, index);
503 }
504 for offset in offsets {
505 visit_read(self, offset);
506 }
507 }
508 BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
509 BarrierOps::Arrive { barrier } => visit_read(self, barrier),
510 BarrierOps::ArriveTx {
511 barrier,
512 arrive_count_update,
513 transaction_count_update,
514 } => {
515 visit_read(self, barrier);
516 visit_read(self, arrive_count_update);
517 visit_read(self, transaction_count_update);
518 }
519 BarrierOps::CommitCopyAsync { barrier } => visit_read(self, barrier),
520 BarrierOps::ExpectTx {
521 barrier,
522 transaction_count_update,
523 } => {
524 visit_read(self, barrier);
525 visit_read(self, transaction_count_update);
526 }
527 BarrierOps::Wait { barrier, token } => {
528 visit_read(self, barrier);
529 visit_read(self, token);
530 }
531 BarrierOps::WaitParity { barrier, phase } => {
532 visit_read(self, barrier);
533 visit_read(self, phase);
534 }
535 }
536 }
537
538 fn visit_tma(
539 &mut self,
540 tma_ops: &mut TmaOps,
541 mut visit_read: impl FnMut(&mut Self, &mut Variable),
542 ) {
543 match tma_ops {
544 TmaOps::TmaStore {
545 source,
546 coordinates,
547 offset_source,
548 } => {
549 visit_read(self, source);
550 visit_read(self, offset_source);
551 for coord in coordinates {
552 visit_read(self, coord)
553 }
554 }
555 TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
556 }
557 }
558
559 fn visit_nonsemantic(
560 &mut self,
561 non_semantic: &mut NonSemantic,
562 mut visit_read: impl FnMut(&mut Self, &mut Variable),
563 ) {
564 match non_semantic {
565 NonSemantic::Comment { .. }
566 | NonSemantic::EnterDebugScope
567 | NonSemantic::ExitDebugScope => {}
568 NonSemantic::Print { args, .. } => {
569 for arg in args {
570 visit_read(self, arg);
571 }
572 }
573 }
574 }
575
576 fn visit_unop(
577 &mut self,
578 unop: &mut UnaryOperator,
579 mut visit_read: impl FnMut(&mut Self, &mut Variable),
580 ) {
581 visit_read(self, &mut unop.input);
582 }
583
584 fn visit_binop(
585 &mut self,
586 binop: &mut BinaryOperator,
587 mut visit_read: impl FnMut(&mut Self, &mut Variable),
588 ) {
589 visit_read(self, &mut binop.lhs);
590 visit_read(self, &mut binop.rhs);
591 }
592}