1use core::fmt::Display;
2
3use alloc::vec::Vec;
4
5use crate::{Allocator, AtomicOp, Bitwise, Comparison, Operator};
6
7use super::{
8 Arithmetic, Branch, CoopMma, Elem, Instruction, Metadata, Operation, UIntKind, Variable,
9 VariableKind,
10};
11
12pub trait Processor {
13 fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing;
14}
15
16pub struct ScopeProcessing {
18 pub variables: Vec<Variable>,
20 pub instructions: Vec<Instruction>,
22}
23
24impl Display for ScopeProcessing {
25 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
26 for inst in self.instructions.iter() {
27 f.write_fmt(format_args!("{inst}\n"))?;
28 }
29
30 Ok(())
31 }
32}
33
34impl ScopeProcessing {
35 pub fn optimize(self) -> Self {
43 self.sanitize_constant_scalars()
44 }
45
46 fn sanitize_constant_scalars(mut self) -> Self {
49 self.instructions
50 .iter_mut()
51 .for_each(|inst| match &mut inst.operation {
52 Operation::Copy(op) => {
53 sanitize_constant_scalar_ref_var(op, &inst.out.unwrap());
54 }
55 Operation::Arithmetic(op) => match op {
56 Arithmetic::Add(op) => {
57 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
58 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
59 }
60 Arithmetic::Fma(op) => {
61 sanitize_constant_scalar_ref_var(&mut op.a, &inst.out.unwrap());
62 sanitize_constant_scalar_ref_var(&mut op.b, &inst.out.unwrap());
63 sanitize_constant_scalar_ref_var(&mut op.c, &inst.out.unwrap());
64 }
65 Arithmetic::Sub(op) => {
66 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
67 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
68 }
69 Arithmetic::Mul(op) => {
70 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
71 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
72 }
73 Arithmetic::Div(op) => {
74 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
75 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
76 }
77 Arithmetic::MulHi(op) => {
78 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
79 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
80 }
81 Arithmetic::Abs(op) => {
82 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
83 }
84 Arithmetic::Exp(op) => {
85 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
86 }
87 Arithmetic::Log(op) => {
88 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
89 }
90 Arithmetic::Log1p(op) => {
91 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
92 }
93 Arithmetic::Cos(op) => {
94 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
95 }
96 Arithmetic::Sin(op) => {
97 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
98 }
99 Arithmetic::Tanh(op) => {
100 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
101 }
102 Arithmetic::Powf(op) => {
103 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
104 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
105 }
106 Arithmetic::Sqrt(op) => {
107 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
108 }
109 Arithmetic::Round(op) => {
110 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
111 }
112 Arithmetic::Floor(op) => {
113 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
114 }
115 Arithmetic::Ceil(op) => {
116 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
117 }
118 Arithmetic::Erf(op) => {
119 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
120 }
121 Arithmetic::Recip(op) => {
122 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
123 }
124 Arithmetic::Clamp(op) => {
125 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
126 sanitize_constant_scalar_ref_var(&mut op.min_value, &inst.out.unwrap());
127 sanitize_constant_scalar_ref_var(&mut op.max_value, &inst.out.unwrap());
128 }
129 Arithmetic::Modulo(op) => {
130 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
131 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
132 }
133 Arithmetic::Neg(op) => {
134 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap())
135 }
136 Arithmetic::Max(op) => {
137 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
138 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
139 }
140 Arithmetic::Min(op) => {
141 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
142 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
143 }
144 Arithmetic::Remainder(op) => {
145 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
146 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
147 }
148 Arithmetic::Magnitude(op) => {
149 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
150 }
151 Arithmetic::Normalize(op) => {
152 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
153 }
154 Arithmetic::Dot(op) => {
155 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
156 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
157 }
158 },
159 Operation::Comparison(op) => match op {
160 Comparison::Greater(op) => {
161 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
162 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
163 }
164 Comparison::LowerEqual(op) => {
165 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
166 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
167 }
168 Comparison::GreaterEqual(op) => {
169 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
170 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
171 }
172 Comparison::Equal(op) => {
173 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
174 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
175 }
176 Comparison::NotEqual(op) => {
177 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
178 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
179 }
180 Comparison::Lower(op) => {
181 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
182 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
183 }
184 },
185 Operation::Bitwise(op) => match op {
186 Bitwise::BitwiseAnd(op) => {
187 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
188 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
189 }
190 Bitwise::BitwiseOr(op) => {
191 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
192 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
193 }
194 Bitwise::BitwiseXor(op) => {
195 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
196 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
197 }
198 Bitwise::CountOnes(_) | Bitwise::LeadingZeros(_) | Bitwise::FindFirstSet(_) => {
199 }
201 Bitwise::ReverseBits(op) => {
202 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
203 }
204 Bitwise::ShiftLeft(op) => {
205 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
206 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
207 }
208 Bitwise::ShiftRight(op) => {
209 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
210 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
211 }
212 Bitwise::BitwiseNot(op) => {
213 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
214 }
215 },
216 Operation::Operator(op) => match op {
217 Operator::Index(op) => {
218 sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
219 sanitize_constant_scalar_ref_elem(&mut op.index, Elem::UInt(UIntKind::U32));
220 }
221 Operator::UncheckedIndex(op) => {
222 sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
223 sanitize_constant_scalar_ref_elem(&mut op.index, Elem::UInt(UIntKind::U32));
224 }
225 Operator::IndexAssign(op) => {
226 sanitize_constant_scalar_ref_elem(&mut op.index, Elem::UInt(UIntKind::U32));
227 sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
228 }
229 Operator::UncheckedIndexAssign(op) => {
230 sanitize_constant_scalar_ref_elem(&mut op.index, Elem::UInt(UIntKind::U32));
231 sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
232 }
233 Operator::And(op) => {
234 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
235 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
236 }
237 Operator::Or(op) => {
238 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
239 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
240 }
241 Operator::Not(op) => {
242 sanitize_constant_scalar_ref_elem(&mut op.input, Elem::Bool);
243 }
244 Operator::InitLine(_) => {
245 }
247 Operator::CopyMemory(op) => {
248 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
249 sanitize_constant_scalar_ref_elem(
250 &mut op.in_index,
251 Elem::UInt(UIntKind::U32),
252 );
253 sanitize_constant_scalar_ref_elem(
254 &mut op.out_index,
255 Elem::UInt(UIntKind::U32),
256 );
257 }
258 Operator::CopyMemoryBulk(op) => {
259 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
260 sanitize_constant_scalar_ref_elem(
261 &mut op.in_index,
262 Elem::UInt(UIntKind::U32),
263 );
264 sanitize_constant_scalar_ref_elem(
265 &mut op.out_index,
266 Elem::UInt(UIntKind::U32),
267 );
268 }
269 Operator::Select(op) => {
270 sanitize_constant_scalar_ref_elem(&mut op.cond, Elem::Bool);
271 sanitize_constant_scalar_ref_var(&mut op.then, &inst.out.unwrap());
272 sanitize_constant_scalar_ref_var(&mut op.or_else, &inst.out.unwrap());
273 }
274 Operator::Cast(_) => {}
275 Operator::Reinterpret(_) => {}
276 },
277 Operation::Atomic(op) => match op {
278 AtomicOp::Load(_) => {}
279 AtomicOp::Store(_) => {}
280 AtomicOp::Swap(op) => {
281 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
282 }
283 AtomicOp::CompareAndSwap(op) => {
284 sanitize_constant_scalar_ref_var(&mut op.cmp, &inst.out.unwrap());
285 sanitize_constant_scalar_ref_var(&mut op.val, &inst.out.unwrap());
286 }
287 AtomicOp::Add(op) => {
288 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
289 }
290 AtomicOp::Sub(op) => {
291 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
292 }
293 AtomicOp::Max(op) => {
294 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
295 }
296 AtomicOp::Min(op) => {
297 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
298 }
299 AtomicOp::And(op) => {
300 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
301 }
302 AtomicOp::Or(op) => {
303 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
304 }
305 AtomicOp::Xor(op) => {
306 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
307 }
308 },
309 Operation::Metadata(op) => match op {
310 Metadata::Stride { dim, .. } => {
311 sanitize_constant_scalar_ref_elem(dim, Elem::UInt(UIntKind::U32));
312 }
313 Metadata::Shape { dim, .. } => {
314 sanitize_constant_scalar_ref_elem(dim, Elem::UInt(UIntKind::U32));
315 }
316 Metadata::Length { .. }
317 | Metadata::BufferLength { .. }
318 | Metadata::Rank { .. } => {
319 }
321 },
322 Operation::Branch(op) => match op {
323 Branch::If(op) => {
324 sanitize_constant_scalar_ref_elem(&mut op.cond, Elem::Bool);
325 }
326 Branch::IfElse(op) => {
327 sanitize_constant_scalar_ref_elem(&mut op.cond, Elem::Bool);
328 }
329 Branch::RangeLoop(op) => {
330 sanitize_constant_scalar_ref_var(&mut op.end, &op.start);
331 sanitize_constant_scalar_ref_var(&mut op.i, &op.start);
332 if let Some(step) = &mut op.step {
333 sanitize_constant_scalar_ref_elem(step, Elem::UInt(UIntKind::U32));
334 }
335 }
336 _ => {
337 }
339 },
340 Operation::Synchronization(_) => {
341 }
343 Operation::Plane(_) => {
344 }
346 Operation::CoopMma(op) => match op {
347 CoopMma::Fill { value } => {
348 sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
349 }
350 CoopMma::Load { value, stride, .. } => {
351 sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
352 sanitize_constant_scalar_ref_elem(stride, Elem::UInt(UIntKind::U32));
353 }
354 CoopMma::Execute { .. } => {
355 }
357 CoopMma::Store { stride, .. } => {
358 sanitize_constant_scalar_ref_elem(stride, Elem::UInt(UIntKind::U32));
359 }
360 CoopMma::Cast { .. } => {
361 }
363 },
364 Operation::NonSemantic(_) => {
365 }
367 Operation::Barrier(_) => {
368 }
370 Operation::Tma(_) => {
371 }
373 });
374 self
375 }
376}
377
378fn sanitize_constant_scalar_ref_var(var: &mut Variable, reference: &Variable) {
379 let elem = reference.item.elem();
380 sanitize_constant_scalar_ref_elem(var, elem);
381}
382
383fn sanitize_constant_scalar_ref_elem(var: &mut Variable, elem: Elem) {
384 if let VariableKind::ConstantScalar(scalar) = var.kind {
385 if scalar.elem() != elem {
386 *var = match scalar {
387 super::ConstantScalarValue::Int(val, _) => elem.constant_from_i64(val),
388 super::ConstantScalarValue::Float(val, _) => elem.constant_from_f64(val),
389 super::ConstantScalarValue::UInt(val, _) => elem.constant_from_u64(val),
390 super::ConstantScalarValue::Bool(val) => elem.constant_from_bool(val),
391 };
392 }
393 }
394}