1use cubecl_core::ir::{
2 AtomicOp, BinaryOperator, CoopMma, Instruction, Metadata, Operation, Operator, Plane,
3 UnaryOperator, Variable,
4};
5
6use super::Optimizer;
7
8impl Optimizer {
9 pub fn visit_out(
10 &mut self,
11 var: &mut Option<Variable>,
12 mut visit_write: impl FnMut(&mut Self, &mut Variable),
13 ) {
14 if let Some(out) = var {
15 visit_write(self, out);
16 }
17 }
18
19 pub fn visit_instruction(
22 &mut self,
23 inst: &mut Instruction,
24 visit_read: impl FnMut(&mut Self, &mut Variable),
25 visit_write: impl FnMut(&mut Self, &mut Variable),
26 ) {
27 self.visit_out(&mut inst.out, visit_write);
28 self.visit_operation(&mut inst.operation, visit_read);
29 }
30
31 pub fn visit_operation(
34 &mut self,
35 op: &mut Operation,
36 mut visit_read: impl FnMut(&mut Self, &mut Variable),
37 ) {
38 match op {
39 Operation::Copy(variable) => visit_read(self, variable),
40 Operation::Operator(operator) => self.visit_operator(operator, visit_read),
41 Operation::Atomic(atomic) => self.visit_atomic(atomic, visit_read),
42 Operation::Metadata(meta) => self.visit_meta(meta, visit_read),
43 Operation::Synchronization(_) | Operation::NonSemantic(_) => {}
45 Operation::Plane(plane) => self.visit_plane(plane, visit_read),
46 Operation::CoopMma(coop_mma) => self.visit_cmma(coop_mma, visit_read),
47 Operation::Branch(_) => unreachable!(),
48 }
49 }
50
51 pub fn visit_operator(
54 &mut self,
55 op: &mut Operator,
56 mut visit_read: impl FnMut(&mut Self, &mut Variable),
57 ) {
58 match op {
59 Operator::Fma(fma_operator) => {
60 visit_read(self, &mut fma_operator.a);
61 visit_read(self, &mut fma_operator.b);
62 visit_read(self, &mut fma_operator.c);
63 }
64 Operator::Add(binary_operator)
65 | Operator::Sub(binary_operator)
66 | Operator::Mul(binary_operator)
67 | Operator::Div(binary_operator)
68 | Operator::Powf(binary_operator)
69 | Operator::Equal(binary_operator)
70 | Operator::NotEqual(binary_operator)
71 | Operator::LowerEqual(binary_operator)
72 | Operator::UncheckedIndex(binary_operator)
73 | Operator::UncheckedIndexAssign(binary_operator)
74 | Operator::Modulo(binary_operator)
75 | Operator::Index(binary_operator)
76 | Operator::IndexAssign(binary_operator)
77 | Operator::And(binary_operator)
78 | Operator::Greater(binary_operator)
79 | Operator::Lower(binary_operator)
80 | Operator::Or(binary_operator)
81 | Operator::Max(binary_operator)
82 | Operator::Min(binary_operator)
83 | Operator::BitwiseAnd(binary_operator)
84 | Operator::BitwiseOr(binary_operator)
85 | Operator::BitwiseXor(binary_operator)
86 | Operator::ShiftLeft(binary_operator)
87 | Operator::ShiftRight(binary_operator)
88 | Operator::Remainder(binary_operator)
89 | Operator::Dot(binary_operator)
90 | Operator::GreaterEqual(binary_operator) => {
91 self.visit_binop(binary_operator, visit_read)
92 }
93
94 Operator::Abs(unary_operator)
95 | Operator::Exp(unary_operator)
96 | Operator::Log(unary_operator)
97 | Operator::Log1p(unary_operator)
98 | Operator::Cos(unary_operator)
99 | Operator::Sin(unary_operator)
100 | Operator::Tanh(unary_operator)
101 | Operator::Sqrt(unary_operator)
102 | Operator::Round(unary_operator)
103 | Operator::Floor(unary_operator)
104 | Operator::Ceil(unary_operator)
105 | Operator::Erf(unary_operator)
106 | Operator::Recip(unary_operator)
107 | Operator::Not(unary_operator)
108 | Operator::Neg(unary_operator)
109 | Operator::Cast(unary_operator)
110 | Operator::Bitcast(unary_operator)
111 | Operator::Magnitude(unary_operator)
112 | Operator::Normalize(unary_operator)
113 | Operator::CountOnes(unary_operator)
114 | Operator::ReverseBits(unary_operator) => self.visit_unop(unary_operator, visit_read),
115
116 Operator::Clamp(clamp_operator) => {
117 visit_read(self, &mut clamp_operator.input);
118 visit_read(self, &mut clamp_operator.min_value);
119 visit_read(self, &mut clamp_operator.max_value);
120 }
121 Operator::Slice(slice_operator) => {
122 visit_read(self, &mut slice_operator.start);
123 visit_read(self, &mut slice_operator.end);
124 visit_read(self, &mut slice_operator.input);
125 }
126 Operator::InitLine(line_init_operator) => {
127 for input in &mut line_init_operator.inputs {
128 visit_read(self, input)
129 }
130 }
131 Operator::CopyMemory(copy_operator) => {
132 visit_read(self, &mut copy_operator.input);
133 visit_read(self, &mut copy_operator.in_index);
134 visit_read(self, &mut copy_operator.out_index);
135 }
136 Operator::CopyMemoryBulk(copy_bulk_operator) => {
137 visit_read(self, &mut copy_bulk_operator.input);
138 visit_read(self, &mut copy_bulk_operator.in_index);
139 visit_read(self, &mut copy_bulk_operator.out_index);
140 }
141 Operator::Select(select) => {
142 visit_read(self, &mut select.cond);
143 visit_read(self, &mut select.then);
144 visit_read(self, &mut select.or_else);
145 }
146 }
147 }
148
149 fn visit_atomic(
150 &mut self,
151 atomic: &mut AtomicOp,
152 mut visit_read: impl FnMut(&mut Self, &mut Variable),
153 ) {
154 match atomic {
155 AtomicOp::Add(binary_operator)
156 | AtomicOp::Sub(binary_operator)
157 | AtomicOp::Max(binary_operator)
158 | AtomicOp::Min(binary_operator)
159 | AtomicOp::And(binary_operator)
160 | AtomicOp::Or(binary_operator)
161 | AtomicOp::Xor(binary_operator)
162 | AtomicOp::Swap(binary_operator) => {
163 self.visit_binop(binary_operator, visit_read);
164 }
165 AtomicOp::Load(unary_operator) | AtomicOp::Store(unary_operator) => {
166 self.visit_unop(unary_operator, visit_read);
167 }
168 AtomicOp::CompareAndSwap(op) => {
169 visit_read(self, &mut op.cmp);
170 visit_read(self, &mut op.cmp);
171 visit_read(self, &mut op.val);
172 }
173 }
174 }
175 fn visit_meta(
176 &mut self,
177 metadata: &mut Metadata,
178 mut visit_read: impl FnMut(&mut Self, &mut Variable),
179 ) {
180 match metadata {
181 Metadata::Rank { var } => {
182 visit_read(self, var);
183 }
184 Metadata::Stride { dim, var } => {
185 visit_read(self, dim);
186 visit_read(self, var);
187 }
188 Metadata::Shape { dim, var } => {
189 visit_read(self, dim);
190 visit_read(self, var);
191 }
192 Metadata::Length { var } => {
193 visit_read(self, var);
194 }
195 Metadata::BufferLength { var } => {
196 visit_read(self, var);
197 }
198 }
199 }
200
201 fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) {
202 match plane {
203 Plane::Elect => {}
204 Plane::Broadcast(binary_operator) => self.visit_binop(binary_operator, visit_read),
205 Plane::All(unary_operator)
206 | Plane::Any(unary_operator)
207 | Plane::Sum(unary_operator)
208 | Plane::Prod(unary_operator)
209 | Plane::Min(unary_operator)
210 | Plane::Max(unary_operator) => self.visit_unop(unary_operator, visit_read),
211 }
212 }
213
214 fn visit_cmma(
215 &mut self,
216 cmma: &mut CoopMma,
217 mut visit_read: impl FnMut(&mut Self, &mut Variable),
218 ) {
219 match cmma {
220 CoopMma::Fill { value } => {
221 visit_read(self, value);
222 }
223 CoopMma::Load { value, stride, .. } => {
224 visit_read(self, value);
225 visit_read(self, stride);
226 }
227 CoopMma::Execute {
228 mat_a,
229 mat_b,
230 mat_c,
231 } => {
232 visit_read(self, mat_a);
233 visit_read(self, mat_b);
234 visit_read(self, mat_c);
235 }
236 CoopMma::Store { mat, stride, .. } => {
237 visit_read(self, mat);
238 visit_read(self, stride);
239 }
240 CoopMma::Cast { input } => {
241 visit_read(self, input);
242 }
243 }
244 }
245
246 fn visit_unop(
247 &mut self,
248 unop: &mut UnaryOperator,
249 mut visit_read: impl FnMut(&mut Self, &mut Variable),
250 ) {
251 visit_read(self, &mut unop.input);
252 }
253
254 fn visit_binop(
255 &mut self,
256 binop: &mut BinaryOperator,
257 mut visit_read: impl FnMut(&mut Self, &mut Variable),
258 ) {
259 visit_read(self, &mut binop.lhs);
260 visit_read(self, &mut binop.rhs);
261 }
262}