1use cubecl_ir::{
2 Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3 Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use crate::ControlFlow;
7
8use super::Optimizer;
9
10impl Optimizer {
11 pub fn visit_out(
12 &mut self,
13 var: &mut Option<Variable>,
14 mut visit_write: impl FnMut(&mut Self, &mut Variable),
15 ) {
16 if let Some(out) = var {
17 visit_write(self, out);
18 }
19 }
20
21 pub fn visit_instruction(
24 &mut self,
25 inst: &mut Instruction,
26 visit_read: impl FnMut(&mut Self, &mut Variable),
27 visit_write: impl FnMut(&mut Self, &mut Variable),
28 ) {
29 self.visit_out(&mut inst.out, visit_write);
30 self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
31 }
32
33 pub fn visit_operation(
36 &mut self,
37 op: &mut Operation,
38 out: &mut Option<Variable>,
39 mut visit_read: impl FnMut(&mut Self, &mut Variable),
40 ) {
41 match op {
42 Operation::Copy(variable) => visit_read(self, variable),
43 Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
44 Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
45 Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
46 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
47 Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
48 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
49 Operation::Synchronization(_) => {}
51 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
52 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
53 Operation::Branch(_) => unreachable!(),
54 Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
55 Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
56 Operation::NonSemantic(non_semantic) => {
57 self.visit_nonsemantic(non_semantic, visit_read)
58 }
59 Operation::Marker(_) => {}
60 }
61 }
62
63 pub fn visit_control_flow(
66 &mut self,
67 op: &mut ControlFlow,
68 mut visit_read: impl FnMut(&mut Self, &mut Variable),
69 ) {
70 match op {
71 ControlFlow::IfElse { cond, .. } => visit_read(self, cond),
72 ControlFlow::Switch { value, .. } => visit_read(self, value),
73 ControlFlow::Loop { .. } => {}
74 ControlFlow::LoopBreak { break_cond, .. } => visit_read(self, break_cond),
75 ControlFlow::Return | ControlFlow::Unreachable | ControlFlow::None => {}
76 }
77 }
78
79 pub fn visit_arithmetic(
82 &mut self,
83 op: &mut Arithmetic,
84 mut visit_read: impl FnMut(&mut Self, &mut Variable),
85 ) {
86 match op {
87 Arithmetic::Fma(fma_operator) => {
88 visit_read(self, &mut fma_operator.a);
89 visit_read(self, &mut fma_operator.b);
90 visit_read(self, &mut fma_operator.c);
91 }
92 Arithmetic::Add(binary_operator)
93 | Arithmetic::SaturatingAdd(binary_operator)
94 | Arithmetic::Sub(binary_operator)
95 | Arithmetic::SaturatingSub(binary_operator)
96 | Arithmetic::Mul(binary_operator)
97 | Arithmetic::Div(binary_operator)
98 | Arithmetic::Powf(binary_operator)
99 | Arithmetic::Powi(binary_operator)
100 | Arithmetic::Hypot(binary_operator)
101 | Arithmetic::Rhypot(binary_operator)
102 | Arithmetic::Modulo(binary_operator)
103 | Arithmetic::Max(binary_operator)
104 | Arithmetic::Min(binary_operator)
105 | Arithmetic::Remainder(binary_operator)
106 | Arithmetic::Dot(binary_operator)
107 | Arithmetic::MulHi(binary_operator)
108 | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
109
110 Arithmetic::Abs(unary_operator)
111 | Arithmetic::Exp(unary_operator)
112 | Arithmetic::Log(unary_operator)
113 | Arithmetic::Log1p(unary_operator)
114 | Arithmetic::Cos(unary_operator)
115 | Arithmetic::Sin(unary_operator)
116 | Arithmetic::Tan(unary_operator)
117 | Arithmetic::Tanh(unary_operator)
118 | Arithmetic::Sinh(unary_operator)
119 | Arithmetic::Cosh(unary_operator)
120 | Arithmetic::ArcCos(unary_operator)
121 | Arithmetic::ArcSin(unary_operator)
122 | Arithmetic::ArcTan(unary_operator)
123 | Arithmetic::ArcSinh(unary_operator)
124 | Arithmetic::ArcCosh(unary_operator)
125 | Arithmetic::ArcTanh(unary_operator)
126 | Arithmetic::Degrees(unary_operator)
127 | Arithmetic::Radians(unary_operator)
128 | Arithmetic::Sqrt(unary_operator)
129 | Arithmetic::InverseSqrt(unary_operator)
130 | Arithmetic::Round(unary_operator)
131 | Arithmetic::Floor(unary_operator)
132 | Arithmetic::Ceil(unary_operator)
133 | Arithmetic::Trunc(unary_operator)
134 | Arithmetic::Erf(unary_operator)
135 | Arithmetic::Recip(unary_operator)
136 | Arithmetic::Neg(unary_operator)
137 | Arithmetic::Magnitude(unary_operator)
138 | Arithmetic::Normalize(unary_operator)
139 | Arithmetic::VectorSum(unary_operator) => self.visit_unop(unary_operator, visit_read),
140
141 Arithmetic::Clamp(clamp_operator) => {
142 visit_read(self, &mut clamp_operator.input);
143 visit_read(self, &mut clamp_operator.min_value);
144 visit_read(self, &mut clamp_operator.max_value);
145 }
146 }
147 }
148
149 pub fn visit_compare(
152 &mut self,
153 op: &mut Comparison,
154 visit_read: impl FnMut(&mut Self, &mut Variable),
155 ) {
156 match op {
157 Comparison::Equal(binary_operator)
158 | Comparison::NotEqual(binary_operator)
159 | Comparison::LowerEqual(binary_operator)
160 | Comparison::Greater(binary_operator)
161 | Comparison::Lower(binary_operator)
162 | Comparison::GreaterEqual(binary_operator) => {
163 self.visit_binop(binary_operator, visit_read)
164 }
165 Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
166 self.visit_unop(unary_operator, visit_read)
167 }
168 }
169 }
170
171 pub fn visit_bitwise(
174 &mut self,
175 op: &mut Bitwise,
176 visit_read: impl FnMut(&mut Self, &mut Variable),
177 ) {
178 match op {
179 Bitwise::BitwiseAnd(binary_operator)
180 | Bitwise::BitwiseOr(binary_operator)
181 | Bitwise::BitwiseXor(binary_operator)
182 | Bitwise::ShiftLeft(binary_operator)
183 | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
184
185 Bitwise::CountOnes(unary_operator)
186 | Bitwise::BitwiseNot(unary_operator)
187 | Bitwise::ReverseBits(unary_operator)
188 | Bitwise::LeadingZeros(unary_operator)
189 | Bitwise::TrailingZeros(unary_operator)
190 | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
191 }
192 }
193
194 pub fn visit_operator(
197 &mut self,
198 op: &mut Operator,
199 mut visit_read: impl FnMut(&mut Self, &mut Variable),
200 ) {
201 match op {
202 Operator::And(binary_operator) | Operator::Or(binary_operator) => {
203 self.visit_binop(binary_operator, visit_read)
204 }
205 Operator::Not(unary_operator)
206 | Operator::Cast(unary_operator)
207 | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
208 Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
209 visit_read(self, &mut index_operator.list);
210 visit_read(self, &mut index_operator.index);
211 }
212 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
213 visit_read(self, &mut op.index);
214 visit_read(self, &mut op.value);
215 }
216 Operator::InitVector(vector_init_operator) => {
217 for input in &mut vector_init_operator.inputs {
218 visit_read(self, input)
219 }
220 }
221 Operator::CopyMemory(copy_operator) => {
222 visit_read(self, &mut copy_operator.input);
223 visit_read(self, &mut copy_operator.in_index);
224 visit_read(self, &mut copy_operator.out_index);
225 }
226 Operator::CopyMemoryBulk(copy_bulk_operator) => {
227 visit_read(self, &mut copy_bulk_operator.input);
228 visit_read(self, &mut copy_bulk_operator.in_index);
229 visit_read(self, &mut copy_bulk_operator.out_index);
230 }
231 Operator::Select(select) => {
232 visit_read(self, &mut select.cond);
233 visit_read(self, &mut select.then);
234 visit_read(self, &mut select.or_else);
235 }
236 }
237 }
238
239 fn visit_atomic(
240 &mut self,
241 atomic: &mut AtomicOp,
242 out: &mut Option<Variable>,
243 mut visit_read: impl FnMut(&mut Self, &mut Variable),
244 ) {
245 match atomic {
246 AtomicOp::Add(binary_operator)
247 | AtomicOp::Sub(binary_operator)
248 | AtomicOp::Max(binary_operator)
249 | AtomicOp::Min(binary_operator)
250 | AtomicOp::And(binary_operator)
251 | AtomicOp::Or(binary_operator)
252 | AtomicOp::Xor(binary_operator)
253 | AtomicOp::Swap(binary_operator) => {
254 self.visit_binop(binary_operator, visit_read);
255 }
256 AtomicOp::Load(unary_operator) => {
257 self.visit_unop(unary_operator, visit_read);
258 }
259 AtomicOp::Store(unary_operator) => {
260 visit_read(self, out.as_mut().unwrap());
261 self.visit_unop(unary_operator, visit_read);
262 }
263 AtomicOp::CompareAndSwap(op) => {
264 visit_read(self, &mut op.cmp);
265 visit_read(self, &mut op.cmp);
266 visit_read(self, &mut op.val);
267 }
268 }
269 }
270 fn visit_meta(
271 &mut self,
272 metadata: &mut Metadata,
273 mut visit_read: impl FnMut(&mut Self, &mut Variable),
274 ) {
275 match metadata {
276 Metadata::Rank { var } => {
277 visit_read(self, var);
278 }
279 Metadata::Stride { dim, var } => {
280 visit_read(self, dim);
281 visit_read(self, var);
282 }
283 Metadata::Shape { dim, var } => {
284 visit_read(self, dim);
285 visit_read(self, var);
286 }
287 Metadata::Length { var } => {
288 visit_read(self, var);
289 }
290 Metadata::BufferLength { var } => {
291 visit_read(self, var);
292 }
293 }
294 }
295
296 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
297 match plane {
298 Plane::Elect => {}
299 Plane::Broadcast(binary_operator)
300 | Plane::Shuffle(binary_operator)
301 | Plane::ShuffleXor(binary_operator)
302 | Plane::ShuffleUp(binary_operator)
303 | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
304 Plane::All(unary_operator)
305 | Plane::Any(unary_operator)
306 | Plane::Sum(unary_operator)
307 | Plane::InclusiveSum(unary_operator)
308 | Plane::ExclusiveSum(unary_operator)
309 | Plane::Prod(unary_operator)
310 | Plane::InclusiveProd(unary_operator)
311 | Plane::ExclusiveProd(unary_operator)
312 | Plane::Min(unary_operator)
313 | Plane::Max(unary_operator)
314 | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
315 }
316 }
317
318 fn visit_cmma(
319 &mut self,
320 cmma: &mut CoopMma,
321 mut visit_read: impl FnMut(&mut Self, &mut Variable),
322 ) {
323 match cmma {
324 CoopMma::Fill { value } => {
325 visit_read(self, value);
326 }
327 CoopMma::Load {
328 value,
329 stride,
330 offset,
331 layout: _,
332 } => {
333 visit_read(self, value);
334 visit_read(self, stride);
335 visit_read(self, offset);
336 }
337 CoopMma::Execute {
338 mat_a,
339 mat_b,
340 mat_c,
341 } => {
342 visit_read(self, mat_a);
343 visit_read(self, mat_b);
344 visit_read(self, mat_c);
345 }
346 CoopMma::Store {
347 mat,
348 stride,
349 offset,
350 layout: _,
351 } => {
352 visit_read(self, mat);
353 visit_read(self, stride);
354 visit_read(self, offset);
355 }
356 CoopMma::Cast { input } => {
357 visit_read(self, input);
358 }
359 CoopMma::RowIndex { lane_id, i, .. } => {
360 visit_read(self, lane_id);
361 visit_read(self, i);
362 }
363 CoopMma::ColIndex { lane_id, i, .. } => {
364 visit_read(self, lane_id);
365 visit_read(self, i);
366 }
367 CoopMma::LoadMatrix { buffer, offset, .. } => {
368 visit_read(self, buffer);
369 visit_read(self, offset);
370 }
371 CoopMma::StoreMatrix {
372 offset, registers, ..
373 } => {
374 visit_read(self, offset);
375 visit_read(self, registers);
376 }
377 CoopMma::ExecuteManual {
378 registers_a,
379 registers_b,
380 registers_c,
381 ..
382 } => {
383 visit_read(self, registers_a);
384 visit_read(self, registers_b);
385 visit_read(self, registers_c);
386 }
387 CoopMma::ExecuteScaled {
388 registers_a,
389 registers_b,
390 registers_c,
391 scales_a,
392 scales_b,
393 ..
394 } => {
395 visit_read(self, registers_a);
396 visit_read(self, registers_b);
397 visit_read(self, registers_c);
398 visit_read(self, scales_a);
399 visit_read(self, scales_b);
400 }
401 }
402 }
403
404 fn visit_barrier(
405 &mut self,
406 barrier_ops: &mut BarrierOps,
407 mut visit_read: impl FnMut(&mut Self, &mut Variable),
408 ) {
409 match barrier_ops {
410 BarrierOps::Declare { barrier } => visit_read(self, barrier),
411 BarrierOps::Init {
412 barrier,
413 is_elected,
414 arrival_count,
415 ..
416 } => {
417 visit_read(self, barrier);
418 visit_read(self, is_elected);
419 visit_read(self, arrival_count);
420 }
421 BarrierOps::InitManual {
422 barrier,
423 arrival_count,
424 } => {
425 visit_read(self, barrier);
426 visit_read(self, arrival_count);
427 }
428 BarrierOps::MemCopyAsync {
429 barrier,
430 source,
431 source_length,
432 offset_source,
433 offset_out,
434 } => {
435 visit_read(self, barrier);
436 visit_read(self, source_length);
437 visit_read(self, source);
438 visit_read(self, offset_source);
439 visit_read(self, offset_out);
440 }
441 BarrierOps::MemCopyAsyncCooperative {
442 barrier,
443 source,
444 source_length,
445 offset_source,
446 offset_out,
447 } => {
448 visit_read(self, barrier);
449 visit_read(self, source_length);
450 visit_read(self, source);
451 visit_read(self, offset_source);
452 visit_read(self, offset_out);
453 }
454 BarrierOps::CopyAsync {
455 source,
456 source_length,
457 offset_source,
458 offset_out,
459 ..
460 } => {
461 visit_read(self, source_length);
462 visit_read(self, source);
463 visit_read(self, offset_source);
464 visit_read(self, offset_out);
465 }
466 BarrierOps::MemCopyAsyncTx {
467 barrier,
468 source,
469 source_length,
470 offset_source,
471 offset_out,
472 } => {
473 visit_read(self, barrier);
474 visit_read(self, source_length);
475 visit_read(self, source);
476 visit_read(self, offset_source);
477 visit_read(self, offset_out);
478 }
479 BarrierOps::TmaLoad {
480 barrier,
481 offset_out,
482 tensor_map,
483 indices,
484 } => {
485 visit_read(self, offset_out);
486 visit_read(self, barrier);
487 visit_read(self, tensor_map);
488 for index in indices {
489 visit_read(self, index);
490 }
491 }
492 BarrierOps::TmaLoadIm2col {
493 barrier,
494 tensor_map,
495 indices,
496 offset_out,
497 offsets,
498 } => {
499 visit_read(self, offset_out);
500 visit_read(self, barrier);
501 visit_read(self, tensor_map);
502 for index in indices {
503 visit_read(self, index);
504 }
505 for offset in offsets {
506 visit_read(self, offset);
507 }
508 }
509 BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
510 BarrierOps::Arrive { barrier } => visit_read(self, barrier),
511 BarrierOps::ArriveTx {
512 barrier,
513 arrive_count_update,
514 transaction_count_update,
515 } => {
516 visit_read(self, barrier);
517 visit_read(self, arrive_count_update);
518 visit_read(self, transaction_count_update);
519 }
520 BarrierOps::CommitCopyAsync { barrier } => visit_read(self, barrier),
521 BarrierOps::ExpectTx {
522 barrier,
523 transaction_count_update,
524 } => {
525 visit_read(self, barrier);
526 visit_read(self, transaction_count_update);
527 }
528 BarrierOps::Wait { barrier, token } => {
529 visit_read(self, barrier);
530 visit_read(self, token);
531 }
532 BarrierOps::WaitParity { barrier, phase } => {
533 visit_read(self, barrier);
534 visit_read(self, phase);
535 }
536 }
537 }
538
539 fn visit_tma(
540 &mut self,
541 tma_ops: &mut TmaOps,
542 mut visit_read: impl FnMut(&mut Self, &mut Variable),
543 ) {
544 match tma_ops {
545 TmaOps::TmaStore {
546 source,
547 coordinates,
548 offset_source,
549 } => {
550 visit_read(self, source);
551 visit_read(self, offset_source);
552 for coord in coordinates {
553 visit_read(self, coord)
554 }
555 }
556 TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
557 }
558 }
559
560 fn visit_nonsemantic(
561 &mut self,
562 non_semantic: &mut NonSemantic,
563 mut visit_read: impl FnMut(&mut Self, &mut Variable),
564 ) {
565 match non_semantic {
566 NonSemantic::Comment { .. }
567 | NonSemantic::EnterDebugScope
568 | NonSemantic::ExitDebugScope => {}
569 NonSemantic::Print { args, .. } => {
570 for arg in args {
571 visit_read(self, arg);
572 }
573 }
574 }
575 }
576
577 fn visit_unop(
578 &mut self,
579 unop: &mut UnaryOperator,
580 mut visit_read: impl FnMut(&mut Self, &mut Variable),
581 ) {
582 visit_read(self, &mut unop.input);
583 }
584
585 fn visit_binop(
586 &mut self,
587 binop: &mut BinaryOperator,
588 mut visit_read: impl FnMut(&mut Self, &mut Variable),
589 ) {
590 visit_read(self, &mut binop.lhs);
591 visit_read(self, &mut binop.rhs);
592 }
593}