1use cubecl_ir::{
2 Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3 Metadata, NonSemantic, Operation, Operator, Plane, TmaOps, UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9 pub fn visit_out(
10 &mut self,
11 var: &mut Option<Variable>,
12 mut visit_write: impl FnMut(&mut Self, &mut Variable),
13 ) {
14 if let Some(out) = var {
15 visit_write(self, out);
16 }
17 }
18
19 pub fn visit_instruction(
22 &mut self,
23 inst: &mut Instruction,
24 visit_read: impl FnMut(&mut Self, &mut Variable),
25 visit_write: impl FnMut(&mut Self, &mut Variable),
26 ) {
27 self.visit_out(&mut inst.out, visit_write);
28 self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
29 }
30
31 pub fn visit_operation(
34 &mut self,
35 op: &mut Operation,
36 out: &mut Option<Variable>,
37 mut visit_read: impl FnMut(&mut Self, &mut Variable),
38 ) {
39 match op {
40 Operation::Copy(variable) => visit_read(self, variable),
41 Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
42 Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
43 Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
44 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
45 Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
46 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
47 Operation::Synchronization(_) => {}
49 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
50 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
51 Operation::Branch(_) => unreachable!(),
52 Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
53 Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
54 Operation::NonSemantic(non_semantic) => {
55 self.visit_nonsemantic(non_semantic, visit_read)
56 }
57 Operation::Free(_) => {}
58 }
59 }
60
61 pub fn visit_arithmetic(
64 &mut self,
65 op: &mut Arithmetic,
66 mut visit_read: impl FnMut(&mut Self, &mut Variable),
67 ) {
68 match op {
69 Arithmetic::Fma(fma_operator) => {
70 visit_read(self, &mut fma_operator.a);
71 visit_read(self, &mut fma_operator.b);
72 visit_read(self, &mut fma_operator.c);
73 }
74 Arithmetic::Add(binary_operator)
75 | Arithmetic::SaturatingAdd(binary_operator)
76 | Arithmetic::Sub(binary_operator)
77 | Arithmetic::SaturatingSub(binary_operator)
78 | Arithmetic::Mul(binary_operator)
79 | Arithmetic::Div(binary_operator)
80 | Arithmetic::Powf(binary_operator)
81 | Arithmetic::Powi(binary_operator)
82 | Arithmetic::Modulo(binary_operator)
83 | Arithmetic::Max(binary_operator)
84 | Arithmetic::Min(binary_operator)
85 | Arithmetic::Remainder(binary_operator)
86 | Arithmetic::Dot(binary_operator)
87 | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read),
88
89 Arithmetic::Abs(unary_operator)
90 | Arithmetic::Exp(unary_operator)
91 | Arithmetic::Log(unary_operator)
92 | Arithmetic::Log1p(unary_operator)
93 | Arithmetic::Cos(unary_operator)
94 | Arithmetic::Sin(unary_operator)
95 | Arithmetic::Tanh(unary_operator)
96 | Arithmetic::Sqrt(unary_operator)
97 | Arithmetic::Round(unary_operator)
98 | Arithmetic::Floor(unary_operator)
99 | Arithmetic::Ceil(unary_operator)
100 | Arithmetic::Trunc(unary_operator)
101 | Arithmetic::Erf(unary_operator)
102 | Arithmetic::Recip(unary_operator)
103 | Arithmetic::Neg(unary_operator)
104 | Arithmetic::Magnitude(unary_operator)
105 | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
106
107 Arithmetic::Clamp(clamp_operator) => {
108 visit_read(self, &mut clamp_operator.input);
109 visit_read(self, &mut clamp_operator.min_value);
110 visit_read(self, &mut clamp_operator.max_value);
111 }
112 }
113 }
114
115 pub fn visit_compare(
118 &mut self,
119 op: &mut Comparison,
120 visit_read: impl FnMut(&mut Self, &mut Variable),
121 ) {
122 match op {
123 Comparison::Equal(binary_operator)
124 | Comparison::NotEqual(binary_operator)
125 | Comparison::LowerEqual(binary_operator)
126 | Comparison::Greater(binary_operator)
127 | Comparison::Lower(binary_operator)
128 | Comparison::GreaterEqual(binary_operator) => {
129 self.visit_binop(binary_operator, visit_read)
130 }
131 Comparison::IsNan(unary_operator) | Comparison::IsInf(unary_operator) => {
132 self.visit_unop(unary_operator, visit_read)
133 }
134 }
135 }
136
137 pub fn visit_bitwise(
140 &mut self,
141 op: &mut Bitwise,
142 visit_read: impl FnMut(&mut Self, &mut Variable),
143 ) {
144 match op {
145 Bitwise::BitwiseAnd(binary_operator)
146 | Bitwise::BitwiseOr(binary_operator)
147 | Bitwise::BitwiseXor(binary_operator)
148 | Bitwise::ShiftLeft(binary_operator)
149 | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
150
151 Bitwise::CountOnes(unary_operator)
152 | Bitwise::BitwiseNot(unary_operator)
153 | Bitwise::ReverseBits(unary_operator)
154 | Bitwise::LeadingZeros(unary_operator)
155 | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
156 }
157 }
158
159 pub fn visit_operator(
162 &mut self,
163 op: &mut Operator,
164 mut visit_read: impl FnMut(&mut Self, &mut Variable),
165 ) {
166 match op {
167 Operator::And(binary_operator) | Operator::Or(binary_operator) => {
168 self.visit_binop(binary_operator, visit_read)
169 }
170 Operator::Not(unary_operator)
171 | Operator::Cast(unary_operator)
172 | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
173 Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => {
174 visit_read(self, &mut index_operator.list);
175 visit_read(self, &mut index_operator.index);
176 }
177 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
178 visit_read(self, &mut op.index);
179 visit_read(self, &mut op.value);
180 }
181 Operator::InitLine(line_init_operator) => {
182 for input in &mut line_init_operator.inputs {
183 visit_read(self, input)
184 }
185 }
186 Operator::CopyMemory(copy_operator) => {
187 visit_read(self, &mut copy_operator.input);
188 visit_read(self, &mut copy_operator.in_index);
189 visit_read(self, &mut copy_operator.out_index);
190 }
191 Operator::CopyMemoryBulk(copy_bulk_operator) => {
192 visit_read(self, &mut copy_bulk_operator.input);
193 visit_read(self, &mut copy_bulk_operator.in_index);
194 visit_read(self, &mut copy_bulk_operator.out_index);
195 }
196 Operator::Select(select) => {
197 visit_read(self, &mut select.cond);
198 visit_read(self, &mut select.then);
199 visit_read(self, &mut select.or_else);
200 }
201 }
202 }
203
204 fn visit_atomic(
205 &mut self,
206 atomic: &mut AtomicOp,
207 out: &mut Option<Variable>,
208 mut visit_read: impl FnMut(&mut Self, &mut Variable),
209 ) {
210 match atomic {
211 AtomicOp::Add(binary_operator)
212 | AtomicOp::Sub(binary_operator)
213 | AtomicOp::Max(binary_operator)
214 | AtomicOp::Min(binary_operator)
215 | AtomicOp::And(binary_operator)
216 | AtomicOp::Or(binary_operator)
217 | AtomicOp::Xor(binary_operator)
218 | AtomicOp::Swap(binary_operator) => {
219 self.visit_binop(binary_operator, visit_read);
220 }
221 AtomicOp::Load(unary_operator) => {
222 self.visit_unop(unary_operator, visit_read);
223 }
224 AtomicOp::Store(unary_operator) => {
225 visit_read(self, out.as_mut().unwrap());
226 self.visit_unop(unary_operator, visit_read);
227 }
228 AtomicOp::CompareAndSwap(op) => {
229 visit_read(self, &mut op.cmp);
230 visit_read(self, &mut op.cmp);
231 visit_read(self, &mut op.val);
232 }
233 }
234 }
235 fn visit_meta(
236 &mut self,
237 metadata: &mut Metadata,
238 mut visit_read: impl FnMut(&mut Self, &mut Variable),
239 ) {
240 match metadata {
241 Metadata::Rank { var } => {
242 visit_read(self, var);
243 }
244 Metadata::Stride { dim, var } => {
245 visit_read(self, dim);
246 visit_read(self, var);
247 }
248 Metadata::Shape { dim, var } => {
249 visit_read(self, dim);
250 visit_read(self, var);
251 }
252 Metadata::Length { var } => {
253 visit_read(self, var);
254 }
255 Metadata::BufferLength { var } => {
256 visit_read(self, var);
257 }
258 }
259 }
260
261 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
262 match plane {
263 Plane::Elect => {}
264 Plane::Broadcast(binary_operator)
265 | Plane::Shuffle(binary_operator)
266 | Plane::ShuffleXor(binary_operator)
267 | Plane::ShuffleUp(binary_operator)
268 | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read),
269 Plane::All(unary_operator)
270 | Plane::Any(unary_operator)
271 | Plane::Sum(unary_operator)
272 | Plane::InclusiveSum(unary_operator)
273 | Plane::ExclusiveSum(unary_operator)
274 | Plane::Prod(unary_operator)
275 | Plane::InclusiveProd(unary_operator)
276 | Plane::ExclusiveProd(unary_operator)
277 | Plane::Min(unary_operator)
278 | Plane::Max(unary_operator)
279 | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
280 }
281 }
282
283 fn visit_cmma(
284 &mut self,
285 cmma: &mut CoopMma,
286 mut visit_read: impl FnMut(&mut Self, &mut Variable),
287 ) {
288 match cmma {
289 CoopMma::Fill { value } => {
290 visit_read(self, value);
291 }
292 CoopMma::Load {
293 value,
294 stride,
295 offset,
296 layout: _,
297 } => {
298 visit_read(self, value);
299 visit_read(self, stride);
300 visit_read(self, offset);
301 }
302 CoopMma::Execute {
303 mat_a,
304 mat_b,
305 mat_c,
306 } => {
307 visit_read(self, mat_a);
308 visit_read(self, mat_b);
309 visit_read(self, mat_c);
310 }
311 CoopMma::Store {
312 mat,
313 stride,
314 offset,
315 layout: _,
316 } => {
317 visit_read(self, mat);
318 visit_read(self, stride);
319 visit_read(self, offset);
320 }
321 CoopMma::Cast { input } => {
322 visit_read(self, input);
323 }
324 CoopMma::RowIndex { lane_id, i, .. } => {
325 visit_read(self, lane_id);
326 visit_read(self, i);
327 }
328 CoopMma::ColIndex { lane_id, i, .. } => {
329 visit_read(self, lane_id);
330 visit_read(self, i);
331 }
332 CoopMma::ExecuteManual {
333 registers_a,
334 registers_b,
335 registers_c,
336 ..
337 } => {
338 for reg in registers_a {
339 visit_read(self, reg);
340 }
341 for reg in registers_b {
342 visit_read(self, reg);
343 }
344 for reg in registers_c {
345 visit_read(self, reg);
346 }
347 }
348 CoopMma::ExecuteScaled {
349 registers_a,
350 registers_b,
351 registers_c,
352 scales_a,
353 scales_b,
354 ..
355 } => {
356 for reg in registers_a {
357 visit_read(self, reg);
358 }
359 for reg in registers_b {
360 visit_read(self, reg);
361 }
362 for reg in registers_c {
363 visit_read(self, reg);
364 }
365 visit_read(self, scales_a);
366 visit_read(self, scales_b);
367 }
368 }
369 }
370
371 fn visit_barrier(
372 &mut self,
373 barrier_ops: &mut BarrierOps,
374 mut visit_read: impl FnMut(&mut Self, &mut Variable),
375 ) {
376 match barrier_ops {
377 BarrierOps::Init { barrier, .. } => {
378 visit_read(self, barrier);
379 }
380 BarrierOps::MemCopyAsync {
381 barrier,
382 source,
383 source_length,
384 offset_source,
385 offset_out,
386 } => {
387 visit_read(self, barrier);
388 visit_read(self, source_length);
389 visit_read(self, source);
390 visit_read(self, offset_source);
391 visit_read(self, offset_out);
392 }
393 BarrierOps::TmaLoad {
394 barrier,
395 offset_out,
396 tensor_map,
397 indices,
398 } => {
399 visit_read(self, offset_out);
400 visit_read(self, barrier);
401 visit_read(self, tensor_map);
402 for index in indices {
403 visit_read(self, index);
404 }
405 }
406 BarrierOps::TmaLoadIm2col {
407 barrier,
408 tensor_map,
409 indices,
410 offset_out,
411 offsets,
412 } => {
413 visit_read(self, offset_out);
414 visit_read(self, barrier);
415 visit_read(self, tensor_map);
416 for index in indices {
417 visit_read(self, index);
418 }
419 for offset in offsets {
420 visit_read(self, offset);
421 }
422 }
423 BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
424 BarrierOps::Arrive { barrier } => visit_read(self, barrier),
425 BarrierOps::ArriveTx {
426 barrier,
427 arrive_count_update,
428 transaction_count_update,
429 } => {
430 visit_read(self, barrier);
431 visit_read(self, arrive_count_update);
432 visit_read(self, transaction_count_update);
433 }
434 BarrierOps::ExpectTx {
435 barrier,
436 transaction_count_update,
437 } => {
438 visit_read(self, barrier);
439 visit_read(self, transaction_count_update);
440 }
441 BarrierOps::Wait { barrier } => {
442 visit_read(self, barrier);
443 }
444 }
445 }
446
447 fn visit_tma(
448 &mut self,
449 tma_ops: &mut TmaOps,
450 mut visit_read: impl FnMut(&mut Self, &mut Variable),
451 ) {
452 match tma_ops {
453 TmaOps::TmaStore {
454 source,
455 coordinates,
456 offset_source,
457 } => {
458 visit_read(self, source);
459 visit_read(self, offset_source);
460 for coord in coordinates {
461 visit_read(self, coord)
462 }
463 }
464 TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
465 }
466 }
467
468 fn visit_nonsemantic(
469 &mut self,
470 non_semantic: &mut NonSemantic,
471 mut visit_read: impl FnMut(&mut Self, &mut Variable),
472 ) {
473 match non_semantic {
474 NonSemantic::Comment { .. }
475 | NonSemantic::EnterDebugScope
476 | NonSemantic::ExitDebugScope => {}
477 NonSemantic::Print { args, .. } => {
478 for arg in args {
479 visit_read(self, arg);
480 }
481 }
482 }
483 }
484
485 fn visit_unop(
486 &mut self,
487 unop: &mut UnaryOperator,
488 mut visit_read: impl FnMut(&mut Self, &mut Variable),
489 ) {
490 visit_read(self, &mut unop.input);
491 }
492
493 fn visit_binop(
494 &mut self,
495 binop: &mut BinaryOperator,
496 mut visit_read: impl FnMut(&mut Self, &mut Variable),
497 ) {
498 visit_read(self, &mut binop.lhs);
499 visit_read(self, &mut binop.rhs);
500 }
501}