1use cubecl_ir::{
2 Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3 Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9 pub fn visit_out(
10 &mut self,
11 var: &mut Option<Variable>,
12 mut visit_write: impl FnMut(&mut Self, &mut Variable),
13 ) {
14 if let Some(out) = var {
15 visit_write(self, out);
16 }
17 }
18
19 pub fn visit_instruction(
22 &mut self,
23 inst: &mut Instruction,
24 visit_read: impl FnMut(&mut Self, &mut Variable),
25 visit_write: impl FnMut(&mut Self, &mut Variable),
26 ) {
27 self.visit_out(&mut inst.out, visit_write);
28 self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29 }
30
31 pub fn visit_operation(
34 &mut self,
35 op: &mut Operation,
36 out: &mut Option<Variable>,
37 mut visit_read: impl FnMut(&mut Self, &mut Variable),
38 ) {
39 match op {
40 Operation::Copy(variable) => visit_read(self, variable),
41 Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42 Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43 Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45 Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47 Operation::Synchronization(_) => {}
49 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51 Operation::Branch(_) => unreachable!(),
52 Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53 Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54 Operation::NonSemantic(non_semantic) => {
55 self.visit_nonsemantic(non_semantic, visit_read)
56 }
57 Operation::Marker(_) => {}
58 }
59 }
60
61 pub fn visit_arithmetic(
64 &mut self,
65 op: &mut Arithmetic,
66 mut visit_read: impl FnMut(&mut Self, &mut Variable),
67 ) {
68 match op {
69 Arithmetic::Fma(fma_operator) => {
70 visit_read(self, &mut fma_operator.a);
71 visit_read(self, &mut fma_operator.b);
72 visit_read(self, &mut fma_operator.c);
73 }
74 Arithmetic::Add(binary_operator)
75 | Arithmetic::SaturatingAdd(binary_operator)
76 | Arithmetic::Sub(binary_operator)
77 | Arithmetic::SaturatingSub(binary_operator)
78 | Arithmetic::Mul(binary_operator)
79 | Arithmetic::Div(binary_operator)
80 | Arithmetic::Powf(binary_operator)
81 | Arithmetic::Powi(binary_operator)
82 | Arithmetic::Hypot(binary_operator)
83 | Arithmetic::Rhypot(binary_operator)
84 | Arithmetic::Modulo(binary_operator)
85 | Arithmetic::Max(binary_operator)
86 | Arithmetic::Min(binary_operator)
87 | Arithmetic::Remainder(binary_operator)
88 | Arithmetic::Dot(binary_operator)
89 | Arithmetic::MulHi(binary_operator)
90 | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
91
92 Arithmetic::Abs(unary_operator)
93 | Arithmetic::Exp(unary_operator)
94 | Arithmetic::Log(unary_operator)
95 | Arithmetic::Log1p(unary_operator)
96 | Arithmetic::Cos(unary_operator)
97 | Arithmetic::Sin(unary_operator)
98 | Arithmetic::Tan(unary_operator)
99 | Arithmetic::Tanh(unary_operator)
100 | Arithmetic::Sinh(unary_operator)
101 | Arithmetic::Cosh(unary_operator)
102 | Arithmetic::ArcCos(unary_operator)
103 | Arithmetic::ArcSin(unary_operator)
104 | Arithmetic::ArcTan(unary_operator)
105 | Arithmetic::ArcSinh(unary_operator)
106 | Arithmetic::ArcCosh(unary_operator)
107 | Arithmetic::ArcTanh(unary_operator)
108 | Arithmetic::Degrees(unary_operator)
109 | Arithmetic::Radians(unary_operator)
110 | Arithmetic::Sqrt(unary_operator)
111 | Arithmetic::InverseSqrt(unary_operator)
112 | Arithmetic::Round(unary_operator)
113 | Arithmetic::Floor(unary_operator)
114 | Arithmetic::Ceil(unary_operator)
115 | Arithmetic::Trunc(unary_operator)
116 | Arithmetic::Erf(unary_operator)
117 | Arithmetic::Recip(unary_operator)
118 | Arithmetic::Neg(unary_operator)
119 | Arithmetic::Magnitude(unary_operator)
120 | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
121
122 Arithmetic::Clamp(clamp_operator) => {
123 visit_read(self, &mut clamp_operator.input);
124 visit_read(self, &mut clamp_operator.min_value);
125 visit_read(self, &mut clamp_operator.max_value);
126 }
127 }
128 }
129
130 pub fn visit_compare(
133 &mut self,
134 op: &mut Comparison,
135 visit_read: impl FnMut(&mut Self, &mut Variable),
136 ) {
137 match op {
138 Comparison::Equal(binary_operator)
139 | Comparison::NotEqual(binary_operator)
140 | Comparison::LowerEqual(binary_operator)
141 | Comparison::Greater(binary_operator)
142 | Comparison::Lower(binary_operator)
143 | Comparison::GreaterEqual(binary_operator) => {
144 self.visit_binop(binary_operator, visit_read)
145 }
146 Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
147 self.visit_unop(unary_operator, visit_read)
148 }
149 }
150 }
151
152 pub fn visit_bitwise(
155 &mut self,
156 op: &mut Bitwise,
157 visit_read: impl FnMut(&mut Self, &mut Variable),
158 ) {
159 match op {
160 Bitwise::BitwiseAnd(binary_operator)
161 | Bitwise::BitwiseOr(binary_operator)
162 | Bitwise::BitwiseXor(binary_operator)
163 | Bitwise::ShiftLeft(binary_operator)
164 | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
165
166 Bitwise::CountOnes(unary_operator)
167 | Bitwise::BitwiseNot(unary_operator)
168 | Bitwise::ReverseBits(unary_operator)
169 | Bitwise::LeadingZeros(unary_operator)
170 | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
171 }
172 }
173
174 pub fn visit_operator(
177 &mut self,
178 op: &mut Operator,
179 mut visit_read: impl FnMut(&mut Self, &mut Variable),
180 ) {
181 match op {
182 Operator::And(binary_operator) | Operator::Or(binary_operator) => {
183 self.visit_binop(binary_operator, visit_read)
184 }
185 Operator::Not(unary_operator)
186 | Operator::Cast(unary_operator)
187 | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
188 Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
189 visit_read(self, &mut index_operator.list);
190 visit_read(self, &mut index_operator.index);
191 }
192 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
193 visit_read(self, &mut op.index);
194 visit_read(self, &mut op.value);
195 }
196 Operator::InitLine(line_init_operator) => {
197 for input in &mut line_init_operator.inputs {
198 visit_read(self, input)
199 }
200 }
201 Operator::CopyMemory(copy_operator) => {
202 visit_read(self, &mut copy_operator.input);
203 visit_read(self, &mut copy_operator.in_index);
204 visit_read(self, &mut copy_operator.out_index);
205 }
206 Operator::CopyMemoryBulk(copy_bulk_operator) => {
207 visit_read(self, &mut copy_bulk_operator.input);
208 visit_read(self, &mut copy_bulk_operator.in_index);
209 visit_read(self, &mut copy_bulk_operator.out_index);
210 }
211 Operator::Select(select) => {
212 visit_read(self, &mut select.cond);
213 visit_read(self, &mut select.then);
214 visit_read(self, &mut select.or_else);
215 }
216 }
217 }
218
219 fn visit_atomic(
220 &mut self,
221 atomic: &mut AtomicOp,
222 out: &mut Option<Variable>,
223 mut visit_read: impl FnMut(&mut Self, &mut Variable),
224 ) {
225 match atomic {
226 AtomicOp::Add(binary_operator)
227 | AtomicOp::Sub(binary_operator)
228 | AtomicOp::Max(binary_operator)
229 | AtomicOp::Min(binary_operator)
230 | AtomicOp::And(binary_operator)
231 | AtomicOp::Or(binary_operator)
232 | AtomicOp::Xor(binary_operator)
233 | AtomicOp::Swap(binary_operator) => {
234 self.visit_binop(binary_operator, visit_read);
235 }
236 AtomicOp::Load(unary_operator) => {
237 self.visit_unop(unary_operator, visit_read);
238 }
239 AtomicOp::Store(unary_operator) => {
240 visit_read(self, out.as_mut().unwrap());
241 self.visit_unop(unary_operator, visit_read);
242 }
243 AtomicOp::CompareAndSwap(op) => {
244 visit_read(self, &mut op.cmp);
245 visit_read(self, &mut op.cmp);
246 visit_read(self, &mut op.val);
247 }
248 }
249 }
250 fn visit_meta(
251 &mut self,
252 metadata: &mut Metadata,
253 mut visit_read: impl FnMut(&mut Self, &mut Variable),
254 ) {
255 match metadata {
256 Metadata::Rank { var } => {
257 visit_read(self, var);
258 }
259 Metadata::Stride { dim, var } => {
260 visit_read(self, dim);
261 visit_read(self, var);
262 }
263 Metadata::Shape { dim, var } => {
264 visit_read(self, dim);
265 visit_read(self, var);
266 }
267 Metadata::Length { var } => {
268 visit_read(self, var);
269 }
270 Metadata::BufferLength { var } => {
271 visit_read(self, var);
272 }
273 }
274 }
275
276 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
277 match plane {
278 Plane::Elect => {}
279 Plane::Broadcast(binary_operator)
280 | Plane::Shuffle(binary_operator)
281 | Plane::ShuffleXor(binary_operator)
282 | Plane::ShuffleUp(binary_operator)
283 | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
284 Plane::All(unary_operator)
285 | Plane::Any(unary_operator)
286 | Plane::Sum(unary_operator)
287 | Plane::InclusiveSum(unary_operator)
288 | Plane::ExclusiveSum(unary_operator)
289 | Plane::Prod(unary_operator)
290 | Plane::InclusiveProd(unary_operator)
291 | Plane::ExclusiveProd(unary_operator)
292 | Plane::Min(unary_operator)
293 | Plane::Max(unary_operator)
294 | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
295 }
296 }
297
298 fn visit_cmma(
299 &mut self,
300 cmma: &mut CoopMma,
301 mut visit_read: impl FnMut(&mut Self, &mut Variable),
302 ) {
303 match cmma {
304 CoopMma::Fill { value } => {
305 visit_read(self, value);
306 }
307 CoopMma::Load {
308 value,
309 stride,
310 offset,
311 layout: _,
312 } => {
313 visit_read(self, value);
314 visit_read(self, stride);
315 visit_read(self, offset);
316 }
317 CoopMma::Execute {
318 mat_a,
319 mat_b,
320 mat_c,
321 } => {
322 visit_read(self, mat_a);
323 visit_read(self, mat_b);
324 visit_read(self, mat_c);
325 }
326 CoopMma::Store {
327 mat,
328 stride,
329 offset,
330 layout: _,
331 } => {
332 visit_read(self, mat);
333 visit_read(self, stride);
334 visit_read(self, offset);
335 }
336 CoopMma::Cast { input } => {
337 visit_read(self, input);
338 }
339 CoopMma::RowIndex { lane_id, i, .. } => {
340 visit_read(self, lane_id);
341 visit_read(self, i);
342 }
343 CoopMma::ColIndex { lane_id, i, .. } => {
344 visit_read(self, lane_id);
345 visit_read(self, i);
346 }
347 CoopMma::LoadMatrix { buffer, offset, .. } => {
348 visit_read(self, buffer);
349 visit_read(self, offset);
350 }
351 CoopMma::StoreMatrix {
352 offset, registers, ..
353 } => {
354 visit_read(self, offset);
355 visit_read(self, registers);
356 }
357 CoopMma::ExecuteManual {
358 registers_a,
359 registers_b,
360 registers_c,
361 ..
362 } => {
363 visit_read(self, registers_a);
364 visit_read(self, registers_b);
365 visit_read(self, registers_c);
366 }
367 CoopMma::ExecuteScaled {
368 registers_a,
369 registers_b,
370 registers_c,
371 scales_a,
372 scales_b,
373 ..
374 } => {
375 visit_read(self, registers_a);
376 visit_read(self, registers_b);
377 visit_read(self, registers_c);
378 visit_read(self, scales_a);
379 visit_read(self, scales_b);
380 }
381 }
382 }
383
384 fn visit_barrier(
385 &mut self,
386 barrier_ops: &mut BarrierOps,
387 mut visit_read: impl FnMut(&mut Self, &mut Variable),
388 ) {
389 match barrier_ops {
390 BarrierOps::Declare { barrier } => visit_read(self, barrier),
391 BarrierOps::Init {
392 barrier,
393 is_elected,
394 arrival_count,
395 ..
396 } => {
397 visit_read(self, barrier);
398 visit_read(self, is_elected);
399 visit_read(self, arrival_count);
400 }
401 BarrierOps::InitManual {
402 barrier,
403 arrival_count,
404 } => {
405 visit_read(self, barrier);
406 visit_read(self, arrival_count);
407 }
408 BarrierOps::MemCopyAsync {
409 barrier,
410 source,
411 source_length,
412 offset_source,
413 offset_out,
414 } => {
415 visit_read(self, barrier);
416 visit_read(self, source_length);
417 visit_read(self, source);
418 visit_read(self, offset_source);
419 visit_read(self, offset_out);
420 }
421 BarrierOps::MemCopyAsyncCooperative {
422 barrier,
423 source,
424 source_length,
425 offset_source,
426 offset_out,
427 } => {
428 visit_read(self, barrier);
429 visit_read(self, source_length);
430 visit_read(self, source);
431 visit_read(self, offset_source);
432 visit_read(self, offset_out);
433 }
434 BarrierOps::CopyAsync {
435 source,
436 source_length,
437 offset_source,
438 offset_out,
439 ..
440 } => {
441 visit_read(self, source_length);
442 visit_read(self, source);
443 visit_read(self, offset_source);
444 visit_read(self, offset_out);
445 }
446 BarrierOps::MemCopyAsyncTx {
447 barrier,
448 source,
449 source_length,
450 offset_source,
451 offset_out,
452 } => {
453 visit_read(self, barrier);
454 visit_read(self, source_length);
455 visit_read(self, source);
456 visit_read(self, offset_source);
457 visit_read(self, offset_out);
458 }
459 BarrierOps::TmaLoad {
460 barrier,
461 offset_out,
462 tensor_map,
463 indices,
464 } => {
465 visit_read(self, offset_out);
466 visit_read(self, barrier);
467 visit_read(self, tensor_map);
468 for index in indices {
469 visit_read(self, index);
470 }
471 }
472 BarrierOps::TmaLoadIm2col {
473 barrier,
474 tensor_map,
475 indices,
476 offset_out,
477 offsets,
478 } => {
479 visit_read(self, offset_out);
480 visit_read(self, barrier);
481 visit_read(self, tensor_map);
482 for index in indices {
483 visit_read(self, index);
484 }
485 for offset in offsets {
486 visit_read(self, offset);
487 }
488 }
489 BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
490 BarrierOps::Arrive { barrier } => visit_read(self, barrier),
491 BarrierOps::ArriveTx {
492 barrier,
493 arrive_count_update,
494 transaction_count_update,
495 } => {
496 visit_read(self, barrier);
497 visit_read(self, arrive_count_update);
498 visit_read(self, transaction_count_update);
499 }
500 BarrierOps::CommitCopyAsync { barrier } => visit_read(self, barrier),
501 BarrierOps::ExpectTx {
502 barrier,
503 transaction_count_update,
504 } => {
505 visit_read(self, barrier);
506 visit_read(self, transaction_count_update);
507 }
508 BarrierOps::Wait { barrier, token } => {
509 visit_read(self, barrier);
510 visit_read(self, token);
511 }
512 BarrierOps::WaitParity { barrier, phase } => {
513 visit_read(self, barrier);
514 visit_read(self, phase);
515 }
516 }
517 }
518
519 fn visit_tma(
520 &mut self,
521 tma_ops: &mut TmaOps,
522 mut visit_read: impl FnMut(&mut Self, &mut Variable),
523 ) {
524 match tma_ops {
525 TmaOps::TmaStore {
526 source,
527 coordinates,
528 offset_source,
529 } => {
530 visit_read(self, source);
531 visit_read(self, offset_source);
532 for coord in coordinates {
533 visit_read(self, coord)
534 }
535 }
536 TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
537 }
538 }
539
540 fn visit_nonsemantic(
541 &mut self,
542 non_semantic: &mut NonSemantic,
543 mut visit_read: impl FnMut(&mut Self, &mut Variable),
544 ) {
545 match non_semantic {
546 NonSemantic::Comment { .. }
547 | NonSemantic::EnterDebugScope
548 | NonSemantic::ExitDebugScope => {}
549 NonSemantic::Print { args, .. } => {
550 for arg in args {
551 visit_read(self, arg);
552 }
553 }
554 }
555 }
556
557 fn visit_unop(
558 &mut self,
559 unop: &mut UnaryOperator,
560 mut visit_read: impl FnMut(&mut Self, &mut Variable),
561 ) {
562 visit_read(self, &mut unop.input);
563 }
564
565 fn visit_binop(
566 &mut self,
567 binop: &mut BinaryOperator,
568 mut visit_read: impl FnMut(&mut Self, &mut Variable),
569 ) {
570 visit_read(self, &mut binop.lhs);
571 visit_read(self, &mut binop.rhs);
572 }
573}