1use cubecl_ir::{
2 Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3 Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9 pub fn visit_out(
10 &mut self,
11 var: &mut Option<Variable>,
12 mut visit_write: impl FnMut(&mut Self, &mut Variable),
13 ) {
14 if let Some(out) = var {
15 visit_write(self, out);
16 }
17 }
18
19 pub fn visit_instruction(
22 &mut self,
23 inst: &mut Instruction,
24 visit_read: impl FnMut(&mut Self, &mut Variable),
25 visit_write: impl FnMut(&mut Self, &mut Variable),
26 ) {
27 self.visit_out(&mut inst.out, visit_write);
28 self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29 }
30
31 pub fn visit_operation(
34 &mut self,
35 op: &mut Operation,
36 out: &mut Option<Variable>,
37 mut visit_read: impl FnMut(&mut Self, &mut Variable),
38 ) {
39 match op {
40 Operation::Copy(variable) => visit_read(self, variable),
41 Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42 Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43 Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45 Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47 Operation::Synchronization(_) => {}
49 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51 Operation::Branch(_) => unreachable!(),
52 Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53 Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54 Operation::NonSemantic(non_semantic) => {
55 self.visit_nonsemantic(non_semantic, visit_read)
56 }
57 Operation::Marker(_) => {}
58 }
59 }
60
61 pub fn visit_arithmetic(
64 &mut self,
65 op: &mut Arithmetic,
66 mut visit_read: impl FnMut(&mut Self, &mut Variable),
67 ) {
68 match op {
69 Arithmetic::Fma(fma_operator) => {
70 visit_read(self, &mut fma_operator.a);
71 visit_read(self, &mut fma_operator.b);
72 visit_read(self, &mut fma_operator.c);
73 }
74 Arithmetic::Add(binary_operator)
75 | Arithmetic::SaturatingAdd(binary_operator)
76 | Arithmetic::Sub(binary_operator)
77 | Arithmetic::SaturatingSub(binary_operator)
78 | Arithmetic::Mul(binary_operator)
79 | Arithmetic::Div(binary_operator)
80 | Arithmetic::Powf(binary_operator)
81 | Arithmetic::Powi(binary_operator)
82 | Arithmetic::Modulo(binary_operator)
83 | Arithmetic::Max(binary_operator)
84 | Arithmetic::Min(binary_operator)
85 | Arithmetic::Remainder(binary_operator)
86 | Arithmetic::Dot(binary_operator)
87 | Arithmetic::MulHi(binary_operator)
88 | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read),
89
90 Arithmetic::Abs(unary_operator)
91 | Arithmetic::Exp(unary_operator)
92 | Arithmetic::Log(unary_operator)
93 | Arithmetic::Log1p(unary_operator)
94 | Arithmetic::Cos(unary_operator)
95 | Arithmetic::Sin(unary_operator)
96 | Arithmetic::Tan(unary_operator)
97 | Arithmetic::Tanh(unary_operator)
98 | Arithmetic::Sinh(unary_operator)
99 | Arithmetic::Cosh(unary_operator)
100 | Arithmetic::ArcCos(unary_operator)
101 | Arithmetic::ArcSin(unary_operator)
102 | Arithmetic::ArcTan(unary_operator)
103 | Arithmetic::ArcSinh(unary_operator)
104 | Arithmetic::ArcCosh(unary_operator)
105 | Arithmetic::ArcTanh(unary_operator)
106 | Arithmetic::Degrees(unary_operator)
107 | Arithmetic::Radians(unary_operator)
108 | Arithmetic::Sqrt(unary_operator)
109 | Arithmetic::InverseSqrt(unary_operator)
110 | Arithmetic::Round(unary_operator)
111 | Arithmetic::Floor(unary_operator)
112 | Arithmetic::Ceil(unary_operator)
113 | Arithmetic::Trunc(unary_operator)
114 | Arithmetic::Erf(unary_operator)
115 | Arithmetic::Recip(unary_operator)
116 | Arithmetic::Neg(unary_operator)
117 | Arithmetic::Magnitude(unary_operator)
118 | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
119
120 Arithmetic::Clamp(clamp_operator) => {
121 visit_read(self, &mut clamp_operator.input);
122 visit_read(self, &mut clamp_operator.min_value);
123 visit_read(self, &mut clamp_operator.max_value);
124 }
125 }
126 }
127
128 pub fn visit_compare(
131 &mut self,
132 op: &mut Comparison,
133 visit_read: impl FnMut(&mut Self, &mut Variable),
134 ) {
135 match op {
136 Comparison::Equal(binary_operator)
137 | Comparison::NotEqual(binary_operator)
138 | Comparison::LowerEqual(binary_operator)
139 | Comparison::Greater(binary_operator)
140 | Comparison::Lower(binary_operator)
141 | Comparison::GreaterEqual(binary_operator) => {
142 self.visit_binop(binary_operator, visit_read)
143 }
144 Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
145 self.visit_unop(unary_operator, visit_read)
146 }
147 }
148 }
149
150 pub fn visit_bitwise(
153 &mut self,
154 op: &mut Bitwise,
155 visit_read: impl FnMut(&mut Self, &mut Variable),
156 ) {
157 match op {
158 Bitwise::BitwiseAnd(binary_operator)
159 | Bitwise::BitwiseOr(binary_operator)
160 | Bitwise::BitwiseXor(binary_operator)
161 | Bitwise::ShiftLeft(binary_operator)
162 | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
163
164 Bitwise::CountOnes(unary_operator)
165 | Bitwise::BitwiseNot(unary_operator)
166 | Bitwise::ReverseBits(unary_operator)
167 | Bitwise::LeadingZeros(unary_operator)
168 | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
169 }
170 }
171
172 pub fn visit_operator(
175 &mut self,
176 op: &mut Operator,
177 mut visit_read: impl FnMut(&mut Self, &mut Variable),
178 ) {
179 match op {
180 Operator::And(binary_operator) | Operator::Or(binary_operator) => {
181 self.visit_binop(binary_operator, visit_read)
182 }
183 Operator::Not(unary_operator)
184 | Operator::Cast(unary_operator)
185 | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
186 Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
187 visit_read(self, &mut index_operator.list);
188 visit_read(self, &mut index_operator.index);
189 }
190 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
191 visit_read(self, &mut op.index);
192 visit_read(self, &mut op.value);
193 }
194 Operator::InitLine(line_init_operator) => {
195 for input in &mut line_init_operator.inputs {
196 visit_read(self, input)
197 }
198 }
199 Operator::CopyMemory(copy_operator) => {
200 visit_read(self, &mut copy_operator.input);
201 visit_read(self, &mut copy_operator.in_index);
202 visit_read(self, &mut copy_operator.out_index);
203 }
204 Operator::CopyMemoryBulk(copy_bulk_operator) => {
205 visit_read(self, &mut copy_bulk_operator.input);
206 visit_read(self, &mut copy_bulk_operator.in_index);
207 visit_read(self, &mut copy_bulk_operator.out_index);
208 }
209 Operator::Select(select) => {
210 visit_read(self, &mut select.cond);
211 visit_read(self, &mut select.then);
212 visit_read(self, &mut select.or_else);
213 }
214 }
215 }
216
217 fn visit_atomic(
218 &mut self,
219 atomic: &mut AtomicOp,
220 out: &mut Option<Variable>,
221 mut visit_read: impl FnMut(&mut Self, &mut Variable),
222 ) {
223 match atomic {
224 AtomicOp::Add(binary_operator)
225 | AtomicOp::Sub(binary_operator)
226 | AtomicOp::Max(binary_operator)
227 | AtomicOp::Min(binary_operator)
228 | AtomicOp::And(binary_operator)
229 | AtomicOp::Or(binary_operator)
230 | AtomicOp::Xor(binary_operator)
231 | AtomicOp::Swap(binary_operator) => {
232 self.visit_binop(binary_operator, visit_read);
233 }
234 AtomicOp::Load(unary_operator) => {
235 self.visit_unop(unary_operator, visit_read);
236 }
237 AtomicOp::Store(unary_operator) => {
238 visit_read(self, out.as_mut().unwrap());
239 self.visit_unop(unary_operator, visit_read);
240 }
241 AtomicOp::CompareAndSwap(op) => {
242 visit_read(self, &mut op.cmp);
243 visit_read(self, &mut op.cmp);
244 visit_read(self, &mut op.val);
245 }
246 }
247 }
248 fn visit_meta(
249 &mut self,
250 metadata: &mut Metadata,
251 mut visit_read: impl FnMut(&mut Self, &mut Variable),
252 ) {
253 match metadata {
254 Metadata::Rank { var } => {
255 visit_read(self, var);
256 }
257 Metadata::Stride { dim, var } => {
258 visit_read(self, dim);
259 visit_read(self, var);
260 }
261 Metadata::Shape { dim, var } => {
262 visit_read(self, dim);
263 visit_read(self, var);
264 }
265 Metadata::Length { var } => {
266 visit_read(self, var);
267 }
268 Metadata::BufferLength { var } => {
269 visit_read(self, var);
270 }
271 }
272 }
273
274 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
275 match plane {
276 Plane::Elect => {}
277 Plane::Broadcast(binary_operator)
278 | Plane::Shuffle(binary_operator)
279 | Plane::ShuffleXor(binary_operator)
280 | Plane::ShuffleUp(binary_operator)
281 | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
282 Plane::All(unary_operator)
283 | Plane::Any(unary_operator)
284 | Plane::Sum(unary_operator)
285 | Plane::InclusiveSum(unary_operator)
286 | Plane::ExclusiveSum(unary_operator)
287 | Plane::Prod(unary_operator)
288 | Plane::InclusiveProd(unary_operator)
289 | Plane::ExclusiveProd(unary_operator)
290 | Plane::Min(unary_operator)
291 | Plane::Max(unary_operator)
292 | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
293 }
294 }
295
296 fn visit_cmma(
297 &mut self,
298 cmma: &mut CoopMma,
299 mut visit_read: impl FnMut(&mut Self, &mut Variable),
300 ) {
301 match cmma {
302 CoopMma::Fill { value } => {
303 visit_read(self, value);
304 }
305 CoopMma::Load {
306 value,
307 stride,
308 offset,
309 layout: _,
310 } => {
311 visit_read(self, value);
312 visit_read(self, stride);
313 visit_read(self, offset);
314 }
315 CoopMma::Execute {
316 mat_a,
317 mat_b,
318 mat_c,
319 } => {
320 visit_read(self, mat_a);
321 visit_read(self, mat_b);
322 visit_read(self, mat_c);
323 }
324 CoopMma::Store {
325 mat,
326 stride,
327 offset,
328 layout: _,
329 } => {
330 visit_read(self, mat);
331 visit_read(self, stride);
332 visit_read(self, offset);
333 }
334 CoopMma::Cast { input } => {
335 visit_read(self, input);
336 }
337 CoopMma::RowIndex { lane_id, i, .. } => {
338 visit_read(self, lane_id);
339 visit_read(self, i);
340 }
341 CoopMma::ColIndex { lane_id, i, .. } => {
342 visit_read(self, lane_id);
343 visit_read(self, i);
344 }
345 CoopMma::LoadMatrix { buffer, offset, .. } => {
346 visit_read(self, buffer);
347 visit_read(self, offset);
348 }
349 CoopMma::StoreMatrix {
350 offset, registers, ..
351 } => {
352 visit_read(self, offset);
353 visit_read(self, registers);
354 }
355 CoopMma::ExecuteManual {
356 registers_a,
357 registers_b,
358 registers_c,
359 ..
360 } => {
361 visit_read(self, registers_a);
362 visit_read(self, registers_b);
363 visit_read(self, registers_c);
364 }
365 CoopMma::ExecuteScaled {
366 registers_a,
367 registers_b,
368 registers_c,
369 scales_a,
370 scales_b,
371 ..
372 } => {
373 visit_read(self, registers_a);
374 visit_read(self, registers_b);
375 visit_read(self, registers_c);
376 visit_read(self, scales_a);
377 visit_read(self, scales_b);
378 }
379 }
380 }
381
382 fn visit_barrier(
383 &mut self,
384 barrier_ops: &mut BarrierOps,
385 mut visit_read: impl FnMut(&mut Self, &mut Variable),
386 ) {
387 match barrier_ops {
388 BarrierOps::Declare { barrier } => visit_read(self, barrier),
389 BarrierOps::Init {
390 barrier,
391 is_elected,
392 arrival_count,
393 ..
394 } => {
395 visit_read(self, barrier);
396 visit_read(self, is_elected);
397 visit_read(self, arrival_count);
398 }
399 BarrierOps::InitManual {
400 barrier,
401 arrival_count,
402 } => {
403 visit_read(self, barrier);
404 visit_read(self, arrival_count);
405 }
406 BarrierOps::MemCopyAsync {
407 barrier,
408 source,
409 source_length,
410 offset_source,
411 offset_out,
412 } => {
413 visit_read(self, barrier);
414 visit_read(self, source_length);
415 visit_read(self, source);
416 visit_read(self, offset_source);
417 visit_read(self, offset_out);
418 }
419 BarrierOps::MemCopyAsyncCooperative {
420 barrier,
421 source,
422 source_length,
423 offset_source,
424 offset_out,
425 } => {
426 visit_read(self, barrier);
427 visit_read(self, source_length);
428 visit_read(self, source);
429 visit_read(self, offset_source);
430 visit_read(self, offset_out);
431 }
432 BarrierOps::MemCopyAsyncTx {
433 barrier,
434 source,
435 source_length,
436 offset_source,
437 offset_out,
438 } => {
439 visit_read(self, barrier);
440 visit_read(self, source_length);
441 visit_read(self, source);
442 visit_read(self, offset_source);
443 visit_read(self, offset_out);
444 }
445 BarrierOps::TmaLoad {
446 barrier,
447 offset_out,
448 tensor_map,
449 indices,
450 } => {
451 visit_read(self, offset_out);
452 visit_read(self, barrier);
453 visit_read(self, tensor_map);
454 for index in indices {
455 visit_read(self, index);
456 }
457 }
458 BarrierOps::TmaLoadIm2col {
459 barrier,
460 tensor_map,
461 indices,
462 offset_out,
463 offsets,
464 } => {
465 visit_read(self, offset_out);
466 visit_read(self, barrier);
467 visit_read(self, tensor_map);
468 for index in indices {
469 visit_read(self, index);
470 }
471 for offset in offsets {
472 visit_read(self, offset);
473 }
474 }
475 BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
476 BarrierOps::Arrive { barrier } => visit_read(self, barrier),
477 BarrierOps::ArriveTx {
478 barrier,
479 arrive_count_update,
480 transaction_count_update,
481 } => {
482 visit_read(self, barrier);
483 visit_read(self, arrive_count_update);
484 visit_read(self, transaction_count_update);
485 }
486 BarrierOps::ExpectTx {
487 barrier,
488 transaction_count_update,
489 } => {
490 visit_read(self, barrier);
491 visit_read(self, transaction_count_update);
492 }
493 BarrierOps::Wait { barrier, token } => {
494 visit_read(self, barrier);
495 visit_read(self, token);
496 }
497 BarrierOps::WaitParity { barrier, phase } => {
498 visit_read(self, barrier);
499 visit_read(self, phase);
500 }
501 }
502 }
503
504 fn visit_tma(
505 &mut self,
506 tma_ops: &mut TmaOps,
507 mut visit_read: impl FnMut(&mut Self, &mut Variable),
508 ) {
509 match tma_ops {
510 TmaOps::TmaStore {
511 source,
512 coordinates,
513 offset_source,
514 } => {
515 visit_read(self, source);
516 visit_read(self, offset_source);
517 for coord in coordinates {
518 visit_read(self, coord)
519 }
520 }
521 TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
522 }
523 }
524
525 fn visit_nonsemantic(
526 &mut self,
527 non_semantic: &mut NonSemantic,
528 mut visit_read: impl FnMut(&mut Self, &mut Variable),
529 ) {
530 match non_semantic {
531 NonSemantic::Comment { .. }
532 | NonSemantic::EnterDebugScope
533 | NonSemantic::ExitDebugScope => {}
534 NonSemantic::Print { args, .. } => {
535 for arg in args {
536 visit_read(self, arg);
537 }
538 }
539 }
540 }
541
542 fn visit_unop(
543 &mut self,
544 unop: &mut UnaryOperator,
545 mut visit_read: impl FnMut(&mut Self, &mut Variable),
546 ) {
547 visit_read(self, &mut unop.input);
548 }
549
550 fn visit_binop(
551 &mut self,
552 binop: &mut BinaryOperator,
553 mut visit_read: impl FnMut(&mut Self, &mut Variable),
554 ) {
555 visit_read(self, &mut binop.lhs);
556 visit_read(self, &mut binop.rhs);
557 }
558}