1use cubecl_ir::{
2 Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3 Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9 pub fn visit_out(
10 &mut self,
11 var: &mut Option<Variable>,
12 mut visit_write: impl FnMut(&mut Self, &mut Variable),
13 ) {
14 if let Some(out) = var {
15 visit_write(self, out);
16 }
17 }
18
19 pub fn visit_instruction(
22 &mut self,
23 inst: &mut Instruction,
24 visit_read: impl FnMut(&mut Self, &mut Variable),
25 visit_write: impl FnMut(&mut Self, &mut Variable),
26 ) {
27 self.visit_out(&mut inst.out, visit_write);
28 self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29 }
30
31 pub fn visit_operation(
34 &mut self,
35 op: &mut Operation,
36 out: &mut Option<Variable>,
37 mut visit_read: impl FnMut(&mut Self, &mut Variable),
38 ) {
39 match op {
40 Operation::Copy(variable) => visit_read(self, variable),
41 Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42 Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43 Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45 Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47 Operation::Synchronization(_) => {}
49 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51 Operation::Branch(_) => unreachable!(),
52 Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53 Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54 Operation::NonSemantic(non_semantic) => {
55 self.visit_nonsemantic(non_semantic, visit_read)
56 }
57 Operation::Marker(_) => {}
58 }
59 }
60
61 pub fn visit_arithmetic(
64 &mut self,
65 op: &mut Arithmetic,
66 mut visit_read: impl FnMut(&mut Self, &mut Variable),
67 ) {
68 match op {
69 Arithmetic::Fma(fma_operator) => {
70 visit_read(self, &mut fma_operator.a);
71 visit_read(self, &mut fma_operator.b);
72 visit_read(self, &mut fma_operator.c);
73 }
74 Arithmetic::Add(binary_operator)
75 | Arithmetic::SaturatingAdd(binary_operator)
76 | Arithmetic::Sub(binary_operator)
77 | Arithmetic::SaturatingSub(binary_operator)
78 | Arithmetic::Mul(binary_operator)
79 | Arithmetic::Div(binary_operator)
80 | Arithmetic::Powf(binary_operator)
81 | Arithmetic::Powi(binary_operator)
82 | Arithmetic::Hypot(binary_operator)
83 | Arithmetic::Rhypot(binary_operator)
84 | Arithmetic::Modulo(binary_operator)
85 | Arithmetic::Max(binary_operator)
86 | Arithmetic::Min(binary_operator)
87 | Arithmetic::Remainder(binary_operator)
88 | Arithmetic::Dot(binary_operator)
89 | Arithmetic::MulHi(binary_operator)
90 | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
91
92 Arithmetic::Abs(unary_operator)
93 | Arithmetic::Exp(unary_operator)
94 | Arithmetic::Log(unary_operator)
95 | Arithmetic::Log1p(unary_operator)
96 | Arithmetic::Cos(unary_operator)
97 | Arithmetic::Sin(unary_operator)
98 | Arithmetic::Tan(unary_operator)
99 | Arithmetic::Tanh(unary_operator)
100 | Arithmetic::Sinh(unary_operator)
101 | Arithmetic::Cosh(unary_operator)
102 | Arithmetic::ArcCos(unary_operator)
103 | Arithmetic::ArcSin(unary_operator)
104 | Arithmetic::ArcTan(unary_operator)
105 | Arithmetic::ArcSinh(unary_operator)
106 | Arithmetic::ArcCosh(unary_operator)
107 | Arithmetic::ArcTanh(unary_operator)
108 | Arithmetic::Degrees(unary_operator)
109 | Arithmetic::Radians(unary_operator)
110 | Arithmetic::Sqrt(unary_operator)
111 | Arithmetic::InverseSqrt(unary_operator)
112 | Arithmetic::Round(unary_operator)
113 | Arithmetic::Floor(unary_operator)
114 | Arithmetic::Ceil(unary_operator)
115 | Arithmetic::Trunc(unary_operator)
116 | Arithmetic::Erf(unary_operator)
117 | Arithmetic::Recip(unary_operator)
118 | Arithmetic::Neg(unary_operator)
119 | Arithmetic::Magnitude(unary_operator)
120 | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
121
122 Arithmetic::Clamp(clamp_operator) => {
123 visit_read(self, &mut clamp_operator.input);
124 visit_read(self, &mut clamp_operator.min_value);
125 visit_read(self, &mut clamp_operator.max_value);
126 }
127 }
128 }
129
130 pub fn visit_compare(
133 &mut self,
134 op: &mut Comparison,
135 visit_read: impl FnMut(&mut Self, &mut Variable),
136 ) {
137 match op {
138 Comparison::Equal(binary_operator)
139 | Comparison::NotEqual(binary_operator)
140 | Comparison::LowerEqual(binary_operator)
141 | Comparison::Greater(binary_operator)
142 | Comparison::Lower(binary_operator)
143 | Comparison::GreaterEqual(binary_operator) => {
144 self.visit_binop(binary_operator, visit_read)
145 }
146 Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
147 self.visit_unop(unary_operator, visit_read)
148 }
149 }
150 }
151
152 pub fn visit_bitwise(
155 &mut self,
156 op: &mut Bitwise,
157 visit_read: impl FnMut(&mut Self, &mut Variable),
158 ) {
159 match op {
160 Bitwise::BitwiseAnd(binary_operator)
161 | Bitwise::BitwiseOr(binary_operator)
162 | Bitwise::BitwiseXor(binary_operator)
163 | Bitwise::ShiftLeft(binary_operator)
164 | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
165
166 Bitwise::CountOnes(unary_operator)
167 | Bitwise::BitwiseNot(unary_operator)
168 | Bitwise::ReverseBits(unary_operator)
169 | Bitwise::LeadingZeros(unary_operator)
170 | Bitwise::TrailingZeros(unary_operator)
171 | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
172 }
173 }
174
175 pub fn visit_operator(
178 &mut self,
179 op: &mut Operator,
180 mut visit_read: impl FnMut(&mut Self, &mut Variable),
181 ) {
182 match op {
183 Operator::And(binary_operator) | Operator::Or(binary_operator) => {
184 self.visit_binop(binary_operator, visit_read)
185 }
186 Operator::Not(unary_operator)
187 | Operator::Cast(unary_operator)
188 | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
189 Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
190 visit_read(self, &mut index_operator.list);
191 visit_read(self, &mut index_operator.index);
192 }
193 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
194 visit_read(self, &mut op.index);
195 visit_read(self, &mut op.value);
196 }
197 Operator::InitLine(line_init_operator) => {
198 for input in &mut line_init_operator.inputs {
199 visit_read(self, input)
200 }
201 }
202 Operator::CopyMemory(copy_operator) => {
203 visit_read(self, &mut copy_operator.input);
204 visit_read(self, &mut copy_operator.in_index);
205 visit_read(self, &mut copy_operator.out_index);
206 }
207 Operator::CopyMemoryBulk(copy_bulk_operator) => {
208 visit_read(self, &mut copy_bulk_operator.input);
209 visit_read(self, &mut copy_bulk_operator.in_index);
210 visit_read(self, &mut copy_bulk_operator.out_index);
211 }
212 Operator::Select(select) => {
213 visit_read(self, &mut select.cond);
214 visit_read(self, &mut select.then);
215 visit_read(self, &mut select.or_else);
216 }
217 }
218 }
219
220 fn visit_atomic(
221 &mut self,
222 atomic: &mut AtomicOp,
223 out: &mut Option<Variable>,
224 mut visit_read: impl FnMut(&mut Self, &mut Variable),
225 ) {
226 match atomic {
227 AtomicOp::Add(binary_operator)
228 | AtomicOp::Sub(binary_operator)
229 | AtomicOp::Max(binary_operator)
230 | AtomicOp::Min(binary_operator)
231 | AtomicOp::And(binary_operator)
232 | AtomicOp::Or(binary_operator)
233 | AtomicOp::Xor(binary_operator)
234 | AtomicOp::Swap(binary_operator) => {
235 self.visit_binop(binary_operator, visit_read);
236 }
237 AtomicOp::Load(unary_operator) => {
238 self.visit_unop(unary_operator, visit_read);
239 }
240 AtomicOp::Store(unary_operator) => {
241 visit_read(self, out.as_mut().unwrap());
242 self.visit_unop(unary_operator, visit_read);
243 }
244 AtomicOp::CompareAndSwap(op) => {
245 visit_read(self, &mut op.cmp);
246 visit_read(self, &mut op.cmp);
247 visit_read(self, &mut op.val);
248 }
249 }
250 }
251 fn visit_meta(
252 &mut self,
253 metadata: &mut Metadata,
254 mut visit_read: impl FnMut(&mut Self, &mut Variable),
255 ) {
256 match metadata {
257 Metadata::Rank { var } => {
258 visit_read(self, var);
259 }
260 Metadata::Stride { dim, var } => {
261 visit_read(self, dim);
262 visit_read(self, var);
263 }
264 Metadata::Shape { dim, var } => {
265 visit_read(self, dim);
266 visit_read(self, var);
267 }
268 Metadata::Length { var } => {
269 visit_read(self, var);
270 }
271 Metadata::BufferLength { var } => {
272 visit_read(self, var);
273 }
274 }
275 }
276
277 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
278 match plane {
279 Plane::Elect => {}
280 Plane::Broadcast(binary_operator)
281 | Plane::Shuffle(binary_operator)
282 | Plane::ShuffleXor(binary_operator)
283 | Plane::ShuffleUp(binary_operator)
284 | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
285 Plane::All(unary_operator)
286 | Plane::Any(unary_operator)
287 | Plane::Sum(unary_operator)
288 | Plane::InclusiveSum(unary_operator)
289 | Plane::ExclusiveSum(unary_operator)
290 | Plane::Prod(unary_operator)
291 | Plane::InclusiveProd(unary_operator)
292 | Plane::ExclusiveProd(unary_operator)
293 | Plane::Min(unary_operator)
294 | Plane::Max(unary_operator)
295 | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
296 }
297 }
298
299 fn visit_cmma(
300 &mut self,
301 cmma: &mut CoopMma,
302 mut visit_read: impl FnMut(&mut Self, &mut Variable),
303 ) {
304 match cmma {
305 CoopMma::Fill { value } => {
306 visit_read(self, value);
307 }
308 CoopMma::Load {
309 value,
310 stride,
311 offset,
312 layout: _,
313 } => {
314 visit_read(self, value);
315 visit_read(self, stride);
316 visit_read(self, offset);
317 }
318 CoopMma::Execute {
319 mat_a,
320 mat_b,
321 mat_c,
322 } => {
323 visit_read(self, mat_a);
324 visit_read(self, mat_b);
325 visit_read(self, mat_c);
326 }
327 CoopMma::Store {
328 mat,
329 stride,
330 offset,
331 layout: _,
332 } => {
333 visit_read(self, mat);
334 visit_read(self, stride);
335 visit_read(self, offset);
336 }
337 CoopMma::Cast { input } => {
338 visit_read(self, input);
339 }
340 CoopMma::RowIndex { lane_id, i, .. } => {
341 visit_read(self, lane_id);
342 visit_read(self, i);
343 }
344 CoopMma::ColIndex { lane_id, i, .. } => {
345 visit_read(self, lane_id);
346 visit_read(self, i);
347 }
348 CoopMma::LoadMatrix { buffer, offset, .. } => {
349 visit_read(self, buffer);
350 visit_read(self, offset);
351 }
352 CoopMma::StoreMatrix {
353 offset, registers, ..
354 } => {
355 visit_read(self, offset);
356 visit_read(self, registers);
357 }
358 CoopMma::ExecuteManual {
359 registers_a,
360 registers_b,
361 registers_c,
362 ..
363 } => {
364 visit_read(self, registers_a);
365 visit_read(self, registers_b);
366 visit_read(self, registers_c);
367 }
368 CoopMma::ExecuteScaled {
369 registers_a,
370 registers_b,
371 registers_c,
372 scales_a,
373 scales_b,
374 ..
375 } => {
376 visit_read(self, registers_a);
377 visit_read(self, registers_b);
378 visit_read(self, registers_c);
379 visit_read(self, scales_a);
380 visit_read(self, scales_b);
381 }
382 }
383 }
384
385 fn visit_barrier(
386 &mut self,
387 barrier_ops: &mut BarrierOps,
388 mut visit_read: impl FnMut(&mut Self, &mut Variable),
389 ) {
390 match barrier_ops {
391 BarrierOps::Declare { barrier } => visit_read(self, barrier),
392 BarrierOps::Init {
393 barrier,
394 is_elected,
395 arrival_count,
396 ..
397 } => {
398 visit_read(self, barrier);
399 visit_read(self, is_elected);
400 visit_read(self, arrival_count);
401 }
402 BarrierOps::InitManual {
403 barrier,
404 arrival_count,
405 } => {
406 visit_read(self, barrier);
407 visit_read(self, arrival_count);
408 }
409 BarrierOps::MemCopyAsync {
410 barrier,
411 source,
412 source_length,
413 offset_source,
414 offset_out,
415 } => {
416 visit_read(self, barrier);
417 visit_read(self, source_length);
418 visit_read(self, source);
419 visit_read(self, offset_source);
420 visit_read(self, offset_out);
421 }
422 BarrierOps::MemCopyAsyncCooperative {
423 barrier,
424 source,
425 source_length,
426 offset_source,
427 offset_out,
428 } => {
429 visit_read(self, barrier);
430 visit_read(self, source_length);
431 visit_read(self, source);
432 visit_read(self, offset_source);
433 visit_read(self, offset_out);
434 }
435 BarrierOps::CopyAsync {
436 source,
437 source_length,
438 offset_source,
439 offset_out,
440 ..
441 } => {
442 visit_read(self, source_length);
443 visit_read(self, source);
444 visit_read(self, offset_source);
445 visit_read(self, offset_out);
446 }
447 BarrierOps::MemCopyAsyncTx {
448 barrier,
449 source,
450 source_length,
451 offset_source,
452 offset_out,
453 } => {
454 visit_read(self, barrier);
455 visit_read(self, source_length);
456 visit_read(self, source);
457 visit_read(self, offset_source);
458 visit_read(self, offset_out);
459 }
460 BarrierOps::TmaLoad {
461 barrier,
462 offset_out,
463 tensor_map,
464 indices,
465 } => {
466 visit_read(self, offset_out);
467 visit_read(self, barrier);
468 visit_read(self, tensor_map);
469 for index in indices {
470 visit_read(self, index);
471 }
472 }
473 BarrierOps::TmaLoadIm2col {
474 barrier,
475 tensor_map,
476 indices,
477 offset_out,
478 offsets,
479 } => {
480 visit_read(self, offset_out);
481 visit_read(self, barrier);
482 visit_read(self, tensor_map);
483 for index in indices {
484 visit_read(self, index);
485 }
486 for offset in offsets {
487 visit_read(self, offset);
488 }
489 }
490 BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
491 BarrierOps::Arrive { barrier } => visit_read(self, barrier),
492 BarrierOps::ArriveTx {
493 barrier,
494 arrive_count_update,
495 transaction_count_update,
496 } => {
497 visit_read(self, barrier);
498 visit_read(self, arrive_count_update);
499 visit_read(self, transaction_count_update);
500 }
501 BarrierOps::CommitCopyAsync { barrier } => visit_read(self, barrier),
502 BarrierOps::ExpectTx {
503 barrier,
504 transaction_count_update,
505 } => {
506 visit_read(self, barrier);
507 visit_read(self, transaction_count_update);
508 }
509 BarrierOps::Wait { barrier, token } => {
510 visit_read(self, barrier);
511 visit_read(self, token);
512 }
513 BarrierOps::WaitParity { barrier, phase } => {
514 visit_read(self, barrier);
515 visit_read(self, phase);
516 }
517 }
518 }
519
520 fn visit_tma(
521 &mut self,
522 tma_ops: &mut TmaOps,
523 mut visit_read: impl FnMut(&mut Self, &mut Variable),
524 ) {
525 match tma_ops {
526 TmaOps::TmaStore {
527 source,
528 coordinates,
529 offset_source,
530 } => {
531 visit_read(self, source);
532 visit_read(self, offset_source);
533 for coord in coordinates {
534 visit_read(self, coord)
535 }
536 }
537 TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
538 }
539 }
540
541 fn visit_nonsemantic(
542 &mut self,
543 non_semantic: &mut NonSemantic,
544 mut visit_read: impl FnMut(&mut Self, &mut Variable),
545 ) {
546 match non_semantic {
547 NonSemantic::Comment { .. }
548 | NonSemantic::EnterDebugScope
549 | NonSemantic::ExitDebugScope => {}
550 NonSemantic::Print { args, .. } => {
551 for arg in args {
552 visit_read(self, arg);
553 }
554 }
555 }
556 }
557
558 fn visit_unop(
559 &mut self,
560 unop: &mut UnaryOperator,
561 mut visit_read: impl FnMut(&mut Self, &mut Variable),
562 ) {
563 visit_read(self, &mut unop.input);
564 }
565
566 fn visit_binop(
567 &mut self,
568 binop: &mut BinaryOperator,
569 mut visit_read: impl FnMut(&mut Self, &mut Variable),
570 ) {
571 visit_read(self, &mut binop.lhs);
572 visit_read(self, &mut binop.rhs);
573 }
574}