1use cubecl_ir::{
2 Arithmetic, AtomicOp, BarrierOps, BinaryOperator, Bitwise, Comparison, CoopMma, Instruction,
3 Metadata, NonSemantic, Operation, Operator, PipelineOps, Plane, TmaOps, UnaryOperator,
4 Variable,
5};
6
7use super::Optimizer;
8
9impl Optimizer {
10 pub fn visit_out(
11 &mut self,
12 var: &mut Option<Variable>,
13 mut visit_write: impl FnMut(&mut Self, &mut Variable),
14 ) {
15 if let Some(out) = var {
16 visit_write(self, out);
17 }
18 }
19
20 pub fn visit_instruction(
23 &mut self,
24 inst: &mut Instruction,
25 visit_read: impl FnMut(&mut Self, &mut Variable),
26 visit_write: impl FnMut(&mut Self, &mut Variable),
27 ) {
28 self.visit_out(&mut inst.out, visit_write);
29 self.visit_operation(&mut inst.operation, &mut inst.out, visit_read);
30 }
31
32 pub fn visit_operation(
35 &mut self,
36 op: &mut Operation,
37 out: &mut Option<Variable>,
38 mut visit_read: impl FnMut(&mut Self, &mut Variable),
39 ) {
40 match op {
41 Operation::Copy(variable) => visit_read(self, variable),
42 Operation::Arithmetic(arithmetic) => self.visit_arithmetic(arithmetic, visit_read),
43 Operation::Comparison(comparison) => self.visit_compare(comparison, visit_read),
44 Operation::Bitwise(bitwise) => self.visit_bitwise(bitwise, visit_read),
45 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
46 Operation::Atomic(atomic) => self.visit_atomic(atomic, out, visit_read),
47 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
48 Operation::Synchronization(_) => {}
50 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
51 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
52 Operation::Branch(_) => unreachable!(),
53 Operation::Pipeline(pipeline_ops) => self.visit_pipeline(pipeline_ops, visit_read),
54 Operation::Barrier(barrier_ops) => self.visit_barrier(barrier_ops, visit_read),
55 Operation::Tma(tma_ops) => self.visit_tma(tma_ops, visit_read),
56 Operation::NonSemantic(non_semantic) => {
57 self.visit_nonsemantic(non_semantic, visit_read)
58 }
59 }
60 }
61
62 pub fn visit_arithmetic(
65 &mut self,
66 op: &mut Arithmetic,
67 mut visit_read: impl FnMut(&mut Self, &mut Variable),
68 ) {
69 match op {
70 Arithmetic::Fma(fma_operator) => {
71 visit_read(self, &mut fma_operator.a);
72 visit_read(self, &mut fma_operator.b);
73 visit_read(self, &mut fma_operator.c);
74 }
75 Arithmetic::Add(binary_operator)
76 | Arithmetic::Sub(binary_operator)
77 | Arithmetic::Mul(binary_operator)
78 | Arithmetic::Div(binary_operator)
79 | Arithmetic::Powf(binary_operator)
80 | Arithmetic::Modulo(binary_operator)
81 | Arithmetic::Max(binary_operator)
82 | Arithmetic::Min(binary_operator)
83 | Arithmetic::Remainder(binary_operator)
84 | Arithmetic::Dot(binary_operator)
85 | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read),
86
87 Arithmetic::Abs(unary_operator)
88 | Arithmetic::Exp(unary_operator)
89 | Arithmetic::Log(unary_operator)
90 | Arithmetic::Log1p(unary_operator)
91 | Arithmetic::Cos(unary_operator)
92 | Arithmetic::Sin(unary_operator)
93 | Arithmetic::Tanh(unary_operator)
94 | Arithmetic::Sqrt(unary_operator)
95 | Arithmetic::Round(unary_operator)
96 | Arithmetic::Floor(unary_operator)
97 | Arithmetic::Ceil(unary_operator)
98 | Arithmetic::Erf(unary_operator)
99 | Arithmetic::Recip(unary_operator)
100 | Arithmetic::Neg(unary_operator)
101 | Arithmetic::Magnitude(unary_operator)
102 | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read),
103
104 Arithmetic::Clamp(clamp_operator) => {
105 visit_read(self, &mut clamp_operator.input);
106 visit_read(self, &mut clamp_operator.min_value);
107 visit_read(self, &mut clamp_operator.max_value);
108 }
109 }
110 }
111
112 pub fn visit_compare(
115 &mut self,
116 op: &mut Comparison,
117 visit_read: impl FnMut(&mut Self, &mut Variable),
118 ) {
119 match op {
120 Comparison::Equal(binary_operator)
121 | Comparison::NotEqual(binary_operator)
122 | Comparison::LowerEqual(binary_operator)
123 | Comparison::Greater(binary_operator)
124 | Comparison::Lower(binary_operator)
125 | Comparison::GreaterEqual(binary_operator) => {
126 self.visit_binop(binary_operator, visit_read)
127 }
128 }
129 }
130
131 pub fn visit_bitwise(
134 &mut self,
135 op: &mut Bitwise,
136 visit_read: impl FnMut(&mut Self, &mut Variable),
137 ) {
138 match op {
139 Bitwise::BitwiseAnd(binary_operator)
140 | Bitwise::BitwiseOr(binary_operator)
141 | Bitwise::BitwiseXor(binary_operator)
142 | Bitwise::ShiftLeft(binary_operator)
143 | Bitwise::ShiftRight(binary_operator) => self.visit_binop(binary_operator, visit_read),
144
145 Bitwise::CountOnes(unary_operator)
146 | Bitwise::BitwiseNot(unary_operator)
147 | Bitwise::ReverseBits(unary_operator)
148 | Bitwise::LeadingZeros(unary_operator)
149 | Bitwise::FindFirstSet(unary_operator) => self.visit_unop(unary_operator, visit_read),
150 }
151 }
152
153 pub fn visit_operator(
156 &mut self,
157 op: &mut Operator,
158 mut visit_read: impl FnMut(&mut Self, &mut Variable),
159 ) {
160 match op {
161 Operator::UncheckedIndex(binary_operator)
162 | Operator::UncheckedIndexAssign(binary_operator)
163 | Operator::Index(binary_operator)
164 | Operator::IndexAssign(binary_operator)
165 | Operator::And(binary_operator)
166 | Operator::Or(binary_operator) => self.visit_binop(binary_operator, visit_read),
167 Operator::Not(unary_operator)
168 | Operator::Cast(unary_operator)
169 | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read),
170 Operator::Slice(slice_operator) => {
171 visit_read(self, &mut slice_operator.start);
172 visit_read(self, &mut slice_operator.end);
173 visit_read(self, &mut slice_operator.input);
174 }
175 Operator::ReinterpretSlice(_) => {
176 todo!()
177 }
178 Operator::InitLine(line_init_operator) => {
179 for input in &mut line_init_operator.inputs {
180 visit_read(self, input)
181 }
182 }
183 Operator::CopyMemory(copy_operator) => {
184 visit_read(self, &mut copy_operator.input);
185 visit_read(self, &mut copy_operator.in_index);
186 visit_read(self, &mut copy_operator.out_index);
187 }
188 Operator::CopyMemoryBulk(copy_bulk_operator) => {
189 visit_read(self, &mut copy_bulk_operator.input);
190 visit_read(self, &mut copy_bulk_operator.in_index);
191 visit_read(self, &mut copy_bulk_operator.out_index);
192 }
193 Operator::Select(select) => {
194 visit_read(self, &mut select.cond);
195 visit_read(self, &mut select.then);
196 visit_read(self, &mut select.or_else);
197 }
198 }
199 }
200
201 fn visit_atomic(
202 &mut self,
203 atomic: &mut AtomicOp,
204 out: &mut Option<Variable>,
205 mut visit_read: impl FnMut(&mut Self, &mut Variable),
206 ) {
207 match atomic {
208 AtomicOp::Add(binary_operator)
209 | AtomicOp::Sub(binary_operator)
210 | AtomicOp::Max(binary_operator)
211 | AtomicOp::Min(binary_operator)
212 | AtomicOp::And(binary_operator)
213 | AtomicOp::Or(binary_operator)
214 | AtomicOp::Xor(binary_operator)
215 | AtomicOp::Swap(binary_operator) => {
216 self.visit_binop(binary_operator, visit_read);
217 }
218 AtomicOp::Load(unary_operator) => {
219 self.visit_unop(unary_operator, visit_read);
220 }
221 AtomicOp::Store(unary_operator) => {
222 visit_read(self, out.as_mut().unwrap());
223 self.visit_unop(unary_operator, visit_read);
224 }
225 AtomicOp::CompareAndSwap(op) => {
226 visit_read(self, &mut op.cmp);
227 visit_read(self, &mut op.cmp);
228 visit_read(self, &mut op.val);
229 }
230 }
231 }
232 fn visit_meta(
233 &mut self,
234 metadata: &mut Metadata,
235 mut visit_read: impl FnMut(&mut Self, &mut Variable),
236 ) {
237 match metadata {
238 Metadata::Rank { var } => {
239 visit_read(self, var);
240 }
241 Metadata::Stride { dim, var } => {
242 visit_read(self, dim);
243 visit_read(self, var);
244 }
245 Metadata::Shape { dim, var } => {
246 visit_read(self, dim);
247 visit_read(self, var);
248 }
249 Metadata::Length { var } => {
250 visit_read(self, var);
251 }
252 Metadata::BufferLength { var } => {
253 visit_read(self, var);
254 }
255 }
256 }
257
258 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
259 match plane {
260 Plane::Elect => {}
261 Plane::Broadcast(binary_operator) => self.visit_binop(binary_operator, visit_read),
262 Plane::All(unary_operator)
263 | Plane::Any(unary_operator)
264 | Plane::Sum(unary_operator)
265 | Plane::InclusiveSum(unary_operator)
266 | Plane::ExclusiveSum(unary_operator)
267 | Plane::Prod(unary_operator)
268 | Plane::InclusiveProd(unary_operator)
269 | Plane::ExclusiveProd(unary_operator)
270 | Plane::Min(unary_operator)
271 | Plane::Max(unary_operator)
272 | Plane::Ballot(unary_operator) => self.visit_unop(unary_operator, visit_read),
273 }
274 }
275
276 fn visit_cmma(
277 &mut self,
278 cmma: &mut CoopMma,
279 mut visit_read: impl FnMut(&mut Self, &mut Variable),
280 ) {
281 match cmma {
282 CoopMma::Fill { value } => {
283 visit_read(self, value);
284 }
285 CoopMma::Load { value, stride, .. } => {
286 visit_read(self, value);
287 visit_read(self, stride);
288 }
289 CoopMma::Execute {
290 mat_a,
291 mat_b,
292 mat_c,
293 } => {
294 visit_read(self, mat_a);
295 visit_read(self, mat_b);
296 visit_read(self, mat_c);
297 }
298 CoopMma::Store { mat, stride, .. } => {
299 visit_read(self, mat);
300 visit_read(self, stride);
301 }
302 CoopMma::Cast { input } => {
303 visit_read(self, input);
304 }
305 }
306 }
307
308 fn visit_pipeline(
309 &mut self,
310 pipeline_ops: &mut PipelineOps,
311 mut visit_read: impl FnMut(&mut Self, &mut Variable),
312 ) {
313 match pipeline_ops {
314 PipelineOps::MemCopyAsync {
315 pipeline,
316 source,
317 destination,
318 } => {
319 visit_read(self, pipeline);
320 visit_read(self, source);
321 visit_read(self, destination);
322 }
323 PipelineOps::ProducerAcquire { pipeline } => visit_read(self, pipeline),
324 PipelineOps::ProducerCommit { pipeline } => visit_read(self, pipeline),
325 PipelineOps::ConsumerWait { pipeline } => visit_read(self, pipeline),
326 PipelineOps::ConsumerRelease { pipeline } => visit_read(self, pipeline),
327 }
328 }
329
330 fn visit_barrier(
331 &mut self,
332 barrier_ops: &mut BarrierOps,
333 mut visit_read: impl FnMut(&mut Self, &mut Variable),
334 ) {
335 match barrier_ops {
336 BarrierOps::Init { barrier, .. } => {
337 visit_read(self, barrier);
338 }
339 BarrierOps::MemCopyAsync { barrier, source } => {
340 visit_read(self, barrier);
341 visit_read(self, source);
342 }
343 BarrierOps::TmaLoad {
344 barrier,
345 tensor_map,
346 indices,
347 } => {
348 visit_read(self, barrier);
349 visit_read(self, tensor_map);
350 for index in indices {
351 visit_read(self, index);
352 }
353 }
354 BarrierOps::TmaLoadIm2col {
355 barrier,
356 tensor_map,
357 indices,
358 offsets,
359 } => {
360 visit_read(self, barrier);
361 visit_read(self, tensor_map);
362 for index in indices {
363 visit_read(self, index);
364 }
365 for offset in offsets {
366 visit_read(self, offset);
367 }
368 }
369 BarrierOps::ArriveAndWait { barrier } => visit_read(self, barrier),
370 BarrierOps::Arrive { barrier } => visit_read(self, barrier),
371 BarrierOps::ArriveTx {
372 barrier,
373 arrive_count_update,
374 transaction_count_update,
375 } => {
376 visit_read(self, barrier);
377 visit_read(self, arrive_count_update);
378 visit_read(self, transaction_count_update);
379 }
380 BarrierOps::ExpectTx {
381 barrier,
382 transaction_count_update,
383 } => {
384 visit_read(self, barrier);
385 visit_read(self, transaction_count_update);
386 }
387 BarrierOps::Wait { barrier } => {
388 visit_read(self, barrier);
389 }
390 }
391 }
392
393 fn visit_tma(
394 &mut self,
395 tma_ops: &mut TmaOps,
396 mut visit_read: impl FnMut(&mut Self, &mut Variable),
397 ) {
398 match tma_ops {
399 TmaOps::TmaStore {
400 source,
401 coordinates,
402 } => {
403 visit_read(self, source);
404 for coord in coordinates {
405 visit_read(self, coord)
406 }
407 }
408 TmaOps::CommitGroup | TmaOps::WaitGroup { .. } | TmaOps::WaitGroupRead { .. } => {}
409 }
410 }
411
412 fn visit_nonsemantic(
413 &mut self,
414 non_semantic: &mut NonSemantic,
415 mut visit_read: impl FnMut(&mut Self, &mut Variable),
416 ) {
417 match non_semantic {
418 NonSemantic::Comment { .. }
419 | NonSemantic::EnterDebugScope
420 | NonSemantic::ExitDebugScope => {}
421 NonSemantic::Print { args, .. } => {
422 for arg in args {
423 visit_read(self, arg);
424 }
425 }
426 }
427 }
428
429 fn visit_unop(
430 &mut self,
431 unop: &mut UnaryOperator,
432 mut visit_read: impl FnMut(&mut Self, &mut Variable),
433 ) {
434 visit_read(self, &mut unop.input);
435 }
436
437 fn visit_binop(
438 &mut self,
439 binop: &mut BinaryOperator,
440 mut visit_read: impl FnMut(&mut Self, &mut Variable),
441 ) {
442 visit_read(self, &mut binop.lhs);
443 visit_read(self, &mut binop.rhs);
444 }
445}