1use cubecl_ir::{
2 Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3 Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9 pub fn visit_out(
10 &mut self,
11 var: &mut Option<Variable>,
12 mut visit_write: impl FnMut(&mut Self, &mut Variable),
13 ) {
14 if let Some(out) = var {
15 visit_write(self, out);
16 }
17 }
18
19 pub fn visit_instruction(
22 &mut self,
23 inst: &mut Instruction,
24 visit_read: impl FnMut(&mut Self, &mut Variable),
25 visit_write: impl FnMut(&mut Self, &mut Variable),
26 ) {
27 self.visit_out(&mut inst.out, visit_write);
28 self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29 }
30
31 pub fn visit_operation(
34 &mut self,
35 op: &mut Operation,
36 out: &mut Option<Variable>,
37 mut visit_read: impl FnMut(&mut Self, &mut Variable),
38 ) {
39 match op {
40 Operation::Copy(variable) => visit_read(self, variable),
41 Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42 Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43 Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45 Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47 Operation::Synchronization(_) => {}
49 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51 Operation::Branch(_) => unreachable!(),
52 Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53 Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54 Operation::NonSemantic(non_semantic) => {
55 self.visit_nonsemantic(non_semantic, visit_read)
56 }
57 Operation::Marker(_) => {}
58 }
59 }
60
61 pub fn visit_arithmetic(
64 &mut self,
65 op: &mut Arithmetic,
66 mut visit_read: impl FnMut(&mut Self, &mut Variable),
67 ) {
68 match op {
69 Arithmetic::Fma(fma_operator) => {
70 visit_read(self, &mut fma_operator.a);
71 visit_read(self, &mut fma_operator.b);
72 visit_read(self, &mut fma_operator.c);
73 }
74 Arithmetic::Add(binary_operator)
75 | Arithmetic::SaturatingAdd(binary_operator)
76 | Arithmetic::Sub(binary_operator)
77 | Arithmetic::SaturatingSub(binary_operator)
78 | Arithmetic::Mul(binary_operator)
79 | Arithmetic::Div(binary_operator)
80 | Arithmetic::Powf(binary_operator)
81 | Arithmetic::Powi(binary_operator)
82 | Arithmetic::Modulo(binary_operator)
83 | Arithmetic::Max(binary_operator)
84 | Arithmetic::Min(binary_operator)
85 | Arithmetic::Remainder(binary_operator)
86 | Arithmetic::Dot(binary_operator)
87 | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read),
88
89 Arithmetic::Abs(unary_operator)
90 | Arithmetic::Exp(unary_operator)
91 | Arithmetic::Log(unary_operator)
92 | Arithmetic::Log1p(unary_operator)
93 | Arithmetic::Cos(unary_operator)
94 | Arithmetic::Sin(unary_operator)
95 | Arithmetic::Tanh(unary_operator)
96 | Arithmetic::Sqrt(unary_operator)
97 | Arithmetic::InverseSqrt(unary_operator)
98 | Arithmetic::Round(unary_operator)
99 | Arithmetic::Floor(unary_operator)
100 | Arithmetic::Ceil(unary_operator)
101 | Arithmetic::Trunc(unary_operator)
102 | Arithmetic::Erf(unary_operator)
103 | Arithmetic::Recip(unary_operator)
104 | Arithmetic::Neg(unary_operator)
105 | Arithmetic::Magnitude(unary_operator)
106 | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
107
108 Arithmetic::Clamp(clamp_operator) => {
109 visit_read(self, &mut clamp_operator.input);
110 visit_read(self, &mut clamp_operator.min_value);
111 visit_read(self, &mut clamp_operator.max_value);
112 }
113 }
114 }
115
116 pub fn visit_compare(
119 &mut self,
120 op: &mut Comparison,
121 visit_read: impl FnMut(&mut Self, &mut Variable),
122 ) {
123 match op {
124 Comparison::Equal(binary_operator)
125 | Comparison::NotEqual(binary_operator)
126 | Comparison::LowerEqual(binary_operator)
127 | Comparison::Greater(binary_operator)
128 | Comparison::Lower(binary_operator)
129 | Comparison::GreaterEqual(binary_operator) => {
130 self.visit_binop(binary_operator, visit_read)
131 }
132 Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
133 self.visit_unop(unary_operator, visit_read)
134 }
135 }
136 }
137
138 pub fn visit_bitwise(
141 &mut self,
142 op: &mut Bitwise,
143 visit_read: impl FnMut(&mut Self, &mut Variable),
144 ) {
145 match op {
146 Bitwise::BitwiseAnd(binary_operator)
147 | Bitwise::BitwiseOr(binary_operator)
148 | Bitwise::BitwiseXor(binary_operator)
149 | Bitwise::ShiftLeft(binary_operator)
150 | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
151
152 Bitwise::CountOnes(unary_operator)
153 | Bitwise::BitwiseNot(unary_operator)
154 | Bitwise::ReverseBits(unary_operator)
155 | Bitwise::LeadingZeros(unary_operator)
156 | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
157 }
158 }
159
160 pub fn visit_operator(
163 &mut self,
164 op: &mut Operator,
165 mut visit_read: impl FnMut(&mut Self, &mut Variable),
166 ) {
167 match op {
168 Operator::And(binary_operator) | Operator::Or(binary_operator) => {
169 self.visit_binop(binary_operator, visit_read)
170 }
171 Operator::Not(unary_operator)
172 | Operator::Cast(unary_operator)
173 | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
174 Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
175 visit_read(self, &mut index_operator.list);
176 visit_read(self, &mut index_operator.index);
177 }
178 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
179 visit_read(self, &mut op.index);
180 visit_read(self, &mut op.value);
181 }
182 Operator::InitLine(line_init_operator) => {
183 for input in &mut line_init_operator.inputs {
184 visit_read(self, input)
185 }
186 }
187 Operator::CopyMemory(copy_operator) => {
188 visit_read(self, &mut copy_operator.input);
189 visit_read(self, &mut copy_operator.in_index);
190 visit_read(self, &mut copy_operator.out_index);
191 }
192 Operator::CopyMemoryBulk(copy_bulk_operator) => {
193 visit_read(self, &mut copy_bulk_operator.input);
194 visit_read(self, &mut copy_bulk_operator.in_index);
195 visit_read(self, &mut copy_bulk_operator.out_index);
196 }
197 Operator::Select(select) => {
198 visit_read(self, &mut select.cond);
199 visit_read(self, &mut select.then);
200 visit_read(self, &mut select.or_else);
201 }
202 }
203 }
204
205 fn visit_atomic(
206 &mut self,
207 atomic: &mut AtomicOp,
208 out: &mut Option<Variable>,
209 mut visit_read: impl FnMut(&mut Self, &mut Variable),
210 ) {
211 match atomic {
212 AtomicOp::Add(binary_operator)
213 | AtomicOp::Sub(binary_operator)
214 | AtomicOp::Max(binary_operator)
215 | AtomicOp::Min(binary_operator)
216 | AtomicOp::And(binary_operator)
217 | AtomicOp::Or(binary_operator)
218 | AtomicOp::Xor(binary_operator)
219 | AtomicOp::Swap(binary_operator) => {
220 self.visit_binop(binary_operator, visit_read);
221 }
222 AtomicOp::Load(unary_operator) => {
223 self.visit_unop(unary_operator, visit_read);
224 }
225 AtomicOp::Store(unary_operator) => {
226 visit_read(self, out.as_mut().unwrap());
227 self.visit_unop(unary_operator, visit_read);
228 }
229 AtomicOp::CompareAndSwap(op) => {
230 visit_read(self, &mut op.cmp);
231 visit_read(self, &mut op.cmp);
232 visit_read(self, &mut op.val);
233 }
234 }
235 }
236 fn visit_meta(
237 &mut self,
238 metadata: &mut Metadata,
239 mut visit_read: impl FnMut(&mut Self, &mut Variable),
240 ) {
241 match metadata {
242 Metadata::Rank { var } => {
243 visit_read(self, var);
244 }
245 Metadata::Stride { dim, var } => {
246 visit_read(self, dim);
247 visit_read(self, var);
248 }
249 Metadata::Shape { dim, var } => {
250 visit_read(self, dim);
251 visit_read(self, var);
252 }
253 Metadata::Length { var } => {
254 visit_read(self, var);
255 }
256 Metadata::BufferLength { var } => {
257 visit_read(self, var);
258 }
259 }
260 }
261
262 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
263 match plane {
264 Plane::Elect => {}
265 Plane::Broadcast(binary_operator)
266 | Plane::Shuffle(binary_operator)
267 | Plane::ShuffleXor(binary_operator)
268 | Plane::ShuffleUp(binary_operator)
269 | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
270 Plane::All(unary_operator)
271 | Plane::Any(unary_operator)
272 | Plane::Sum(unary_operator)
273 | Plane::InclusiveSum(unary_operator)
274 | Plane::ExclusiveSum(unary_operator)
275 | Plane::Prod(unary_operator)
276 | Plane::InclusiveProd(unary_operator)
277 | Plane::ExclusiveProd(unary_operator)
278 | Plane::Min(unary_operator)
279 | Plane::Max(unary_operator)
280 | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
281 }
282 }
283
284 fn visit_cmma(
285 &mut self,
286 cmma: &mut CoopMma,
287 mut visit_read: impl FnMut(&mut Self, &mut Variable),
288 ) {
289 match cmma {
290 CoopMma::Fill { value } => {
291 visit_read(self, value);
292 }
293 CoopMma::Load {
294 value,
295 stride,
296 offset,
297 layout: _,
298 } => {
299 visit_read(self, value);
300 visit_read(self, stride);
301 visit_read(self, offset);
302 }
303 CoopMma::Execute {
304 mat_a,
305 mat_b,
306 mat_c,
307 } => {
308 visit_read(self, mat_a);
309 visit_read(self, mat_b);
310 visit_read(self, mat_c);
311 }
312 CoopMma::Store {
313 mat,
314 stride,
315 offset,
316 layout: _,
317 } => {
318 visit_read(self, mat);
319 visit_read(self, stride);
320 visit_read(self, offset);
321 }
322 CoopMma::Cast { input } => {
323 visit_read(self, input);
324 }
325 CoopMma::RowIndex { lane_id, i, .. } => {
326 visit_read(self, lane_id);
327 visit_read(self, i);
328 }
329 CoopMma::ColIndex { lane_id, i, .. } => {
330 visit_read(self, lane_id);
331 visit_read(self, i);
332 }
333 CoopMma::ExecuteManual {
334 registers_a,
335 registers_b,
336 registers_c,
337 ..
338 } => {
339 for reg in registers_a {
340 visit_read(self, reg);
341 }
342 for reg in registers_b {
343 visit_read(self, reg);
344 }
345 for reg in registers_c {
346 visit_read(self, reg);
347 }
348 }
349 CoopMma::ExecuteScaled {
350 registers_a,
351 registers_b,
352 registers_c,
353 scales_a,
354 scales_b,
355 ..
356 } => {
357 for reg in registers_a {
358 visit_read(self, reg);
359 }
360 for reg in registers_b {
361 visit_read(self, reg);
362 }
363 for reg in registers_c {
364 visit_read(self, reg);
365 }
366 visit_read(self, scales_a);
367 visit_read(self, scales_b);
368 }
369 }
370 }
371
372 fn visit_barrier(
373 &mut self,
374 barrier_ops: &mut BarrierOps,
375 mut visit_read: impl FnMut(&mut Self, &mut Variable),
376 ) {
377 match barrier_ops {
378 BarrierOps::Declare { barrier } => visit_read(self, barrier),
379 BarrierOps::Init {
380 barrier,
381 is_elected,
382 arrival_count,
383 ..
384 } => {
385 visit_read(self, barrier);
386 visit_read(self, is_elected);
387 visit_read(self, arrival_count);
388 }
389 BarrierOps::InitManual {
390 barrier,
391 arrival_count,
392 } => {
393 visit_read(self, barrier);
394 visit_read(self, arrival_count);
395 }
396 BarrierOps::MemCopyAsync {
397 barrier,
398 source,
399 source_length,
400 offset_source,
401 offset_out,
402 } => {
403 visit_read(self, barrier);
404 visit_read(self, source_length);
405 visit_read(self, source);
406 visit_read(self, offset_source);
407 visit_read(self, offset_out);
408 }
409 BarrierOps::MemCopyAsyncCooperative {
410 barrier,
411 source,
412 source_length,
413 offset_source,
414 offset_out,
415 } => {
416 visit_read(self, barrier);
417 visit_read(self, source_length);
418 visit_read(self, source);
419 visit_read(self, offset_source);
420 visit_read(self, offset_out);
421 }
422 BarrierOps::MemCopyAsyncTx {
423 barrier,
424 source,
425 source_length,
426 offset_source,
427 offset_out,
428 } => {
429 visit_read(self, barrier);
430 visit_read(self, source_length);
431 visit_read(self, source);
432 visit_read(self, offset_source);
433 visit_read(self, offset_out);
434 }
435 BarrierOps::TmaLoad {
436 barrier,
437 offset_out,
438 tensor_map,
439 indices,
440 } => {
441 visit_read(self, offset_out);
442 visit_read(self, barrier);
443 visit_read(self, tensor_map);
444 for index in indices {
445 visit_read(self, index);
446 }
447 }
448 BarrierOps::TmaLoadIm2col {
449 barrier,
450 tensor_map,
451 indices,
452 offset_out,
453 offsets,
454 } => {
455 visit_read(self, offset_out);
456 visit_read(self, barrier);
457 visit_read(self, tensor_map);
458 for index in indices {
459 visit_read(self, index);
460 }
461 for offset in offsets {
462 visit_read(self, offset);
463 }
464 }
465 BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
466 BarrierOps::Arrive { barrier } => visit_read(self, barrier),
467 BarrierOps::ArriveTx {
468 barrier,
469 arrive_count_update,
470 transaction_count_update,
471 } => {
472 visit_read(self, barrier);
473 visit_read(self, arrive_count_update);
474 visit_read(self, transaction_count_update);
475 }
476 BarrierOps::ExpectTx {
477 barrier,
478 transaction_count_update,
479 } => {
480 visit_read(self, barrier);
481 visit_read(self, transaction_count_update);
482 }
483 BarrierOps::Wait { barrier, token } => {
484 visit_read(self, barrier);
485 visit_read(self, token);
486 }
487 BarrierOps::WaitParity { barrier, phase } => {
488 visit_read(self, barrier);
489 visit_read(self, phase);
490 }
491 }
492 }
493
494 fn visit_tma(
495 &mut self,
496 tma_ops: &mut TmaOps,
497 mut visit_read: impl FnMut(&mut Self, &mut Variable),
498 ) {
499 match tma_ops {
500 TmaOps::TmaStore {
501 source,
502 coordinates,
503 offset_source,
504 } => {
505 visit_read(self, source);
506 visit_read(self, offset_source);
507 for coord in coordinates {
508 visit_read(self, coord)
509 }
510 }
511 TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
512 }
513 }
514
515 fn visit_nonsemantic(
516 &mut self,
517 non_semantic: &mut NonSemantic,
518 mut visit_read: impl FnMut(&mut Self, &mut Variable),
519 ) {
520 match non_semantic {
521 NonSemantic::Comment { .. }
522 | NonSemantic::EnterDebugScope
523 | NonSemantic::ExitDebugScope => {}
524 NonSemantic::Print { args, .. } => {
525 for arg in args {
526 visit_read(self, arg);
527 }
528 }
529 }
530 }
531
532 fn visit_unop(
533 &mut self,
534 unop: &mut UnaryOperator,
535 mut visit_read: impl FnMut(&mut Self, &mut Variable),
536 ) {
537 visit_read(self, &mut unop.input);
538 }
539
540 fn visit_binop(
541 &mut self,
542 binop: &mut BinaryOperator,
543 mut visit_read: impl FnMut(&mut Self, &mut Variable),
544 ) {
545 visit_read(self, &mut binop.lhs);
546 visit_read(self, &mut binop.rhs);
547 }
548}