1use core::fmt::Display;
2
3use alloc::string::ToString;
4use alloc::vec::Vec;
5
6use crate::{Allocator, AtomicOp, Bitwise, Comparison, Operator};
7
8use super::{
9 Arithmetic, Branch, CoopMma, ElemType, Instruction, Metadata, Operation, UIntKind, Variable,
10 VariableKind,
11};
12
/// A transformation pass over a [`ScopeProcessing`].
///
/// Implementors must also implement [`core::fmt::Debug`].
pub trait Processor: core::fmt::Debug {
    /// Consumes `processing` and returns the transformed [`ScopeProcessing`].
    ///
    /// NOTE(review): `allocator` is presumably there so a pass can create new
    /// variables while rewriting instructions — confirm against implementors.
    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing;
}
16
/// The variables and instructions of a scope, as handed to and returned by
/// [`Processor::transform`].
pub struct ScopeProcessing {
    /// Variables associated with the scope.
    /// NOTE(review): presumably the scope's variable declarations — confirm.
    pub variables: Vec<Variable>,
    /// Instructions of the scope, in order; these are what `Display` prints and
    /// what [`ScopeProcessing::optimize`] rewrites.
    pub instructions: Vec<Instruction>,
}
24
25impl Display for ScopeProcessing {
26 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27 writeln!(f, "{{")?;
28 for instruction in self.instructions.iter() {
29 let instruction_str = instruction.to_string();
30 if !instruction_str.is_empty() {
31 writeln!(f, " {instruction_str}")?;
32 }
33 }
34 write!(f, "}}")?;
35 Ok(())
36 }
37}
38
39impl ScopeProcessing {
40 pub fn optimize(self) -> Self {
48 self.sanitize_constant_scalars()
49 }
50
51 fn sanitize_constant_scalars(mut self) -> Self {
54 self.instructions
55 .iter_mut()
56 .for_each(|inst| match &mut inst.operation {
57 Operation::Copy(op) => {
58 sanitize_constant_scalar_ref_var(op, &inst.out.unwrap());
59 }
60 Operation::Arithmetic(op) => match op {
61 Arithmetic::Add(op) => {
62 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
63 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
64 }
65 Arithmetic::SaturatingAdd(op) => {
66 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
67 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
68 }
69 Arithmetic::Fma(op) => {
70 sanitize_constant_scalar_ref_var(&mut op.a, &inst.out.unwrap());
71 sanitize_constant_scalar_ref_var(&mut op.b, &inst.out.unwrap());
72 sanitize_constant_scalar_ref_var(&mut op.c, &inst.out.unwrap());
73 }
74 Arithmetic::Sub(op) => {
75 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
76 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
77 }
78 Arithmetic::SaturatingSub(op) => {
79 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
80 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
81 }
82 Arithmetic::Mul(op) => {
83 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
84 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
85 }
86 Arithmetic::Div(op) => {
87 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
88 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
89 }
90 Arithmetic::MulHi(op) => {
91 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
92 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
93 }
94 Arithmetic::Abs(op) => {
95 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
96 }
97 Arithmetic::Exp(op) => {
98 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
99 }
100 Arithmetic::Log(op) => {
101 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
102 }
103 Arithmetic::Log1p(op) => {
104 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
105 }
106 Arithmetic::Cos(op) => {
107 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
108 }
109 Arithmetic::Sin(op) => {
110 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
111 }
112 Arithmetic::Tanh(op) => {
113 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
114 }
115 Arithmetic::Powf(op) => {
116 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
117 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
118 }
119 Arithmetic::Powi(op) => {
120 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
121 }
122 Arithmetic::Sqrt(op) => {
123 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
124 }
125 Arithmetic::InverseSqrt(op) => {
126 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
127 }
128 Arithmetic::Round(op) => {
129 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
130 }
131 Arithmetic::Floor(op) => {
132 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
133 }
134 Arithmetic::Ceil(op) => {
135 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
136 }
137 Arithmetic::Trunc(op) => {
138 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
139 }
140 Arithmetic::Erf(op) => {
141 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
142 }
143 Arithmetic::Recip(op) => {
144 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
145 }
146 Arithmetic::Clamp(op) => {
147 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
148 sanitize_constant_scalar_ref_var(&mut op.min_value, &inst.out.unwrap());
149 sanitize_constant_scalar_ref_var(&mut op.max_value, &inst.out.unwrap());
150 }
151 Arithmetic::Modulo(op) => {
152 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
153 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
154 }
155 Arithmetic::Neg(op) => {
156 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap())
157 }
158 Arithmetic::Max(op) => {
159 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
160 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
161 }
162 Arithmetic::Min(op) => {
163 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
164 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
165 }
166 Arithmetic::Remainder(op) => {
167 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
168 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
169 }
170 Arithmetic::Magnitude(op) => {
171 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
172 }
173 Arithmetic::Normalize(op) => {
174 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
175 }
176 Arithmetic::Dot(op) => {
177 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
178 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
179 }
180 },
181 Operation::Comparison(op) => match op {
182 Comparison::Greater(op) => {
183 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
184 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
185 }
186 Comparison::LowerEqual(op) => {
187 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
188 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
189 }
190 Comparison::GreaterEqual(op) => {
191 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
192 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
193 }
194 Comparison::Equal(op) => {
195 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
196 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
197 }
198 Comparison::NotEqual(op) => {
199 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
200 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
201 }
202 Comparison::Lower(op) => {
203 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
204 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
205 }
206 Comparison::IsNan(_op) | Comparison::IsInf(_op) => {
207 }
209 },
210 Operation::Bitwise(op) => match op {
211 Bitwise::BitwiseAnd(op) => {
212 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
213 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
214 }
215 Bitwise::BitwiseOr(op) => {
216 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
217 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
218 }
219 Bitwise::BitwiseXor(op) => {
220 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
221 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
222 }
223 Bitwise::CountOnes(_) | Bitwise::LeadingZeros(_) | Bitwise::FindFirstSet(_) => {
224 }
226 Bitwise::ReverseBits(op) => {
227 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
228 }
229 Bitwise::ShiftLeft(op) => {
230 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
231 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
232 }
233 Bitwise::ShiftRight(op) => {
234 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
235 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
236 }
237 Bitwise::BitwiseNot(op) => {
238 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
239 }
240 },
241 Operation::Operator(op) => match op {
242 Operator::Index(op) => {
243 sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
244 sanitize_constant_scalar_ref_elem(
245 &mut op.index,
246 ElemType::UInt(UIntKind::U32),
247 );
248 }
249 Operator::UncheckedIndex(op) => {
250 sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
251 sanitize_constant_scalar_ref_elem(
252 &mut op.index,
253 ElemType::UInt(UIntKind::U32),
254 );
255 }
256 Operator::IndexAssign(op) => {
257 sanitize_constant_scalar_ref_elem(
258 &mut op.index,
259 ElemType::UInt(UIntKind::U32),
260 );
261 sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
262 }
263 Operator::UncheckedIndexAssign(op) => {
264 sanitize_constant_scalar_ref_elem(
265 &mut op.index,
266 ElemType::UInt(UIntKind::U32),
267 );
268 sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
269 }
270 Operator::And(op) => {
271 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
272 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
273 }
274 Operator::Or(op) => {
275 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
276 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
277 }
278 Operator::Not(op) => {
279 sanitize_constant_scalar_ref_elem(&mut op.input, ElemType::Bool);
280 }
281 Operator::InitLine(_) => {
282 }
284 Operator::CopyMemory(op) => {
285 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
286 sanitize_constant_scalar_ref_elem(
287 &mut op.in_index,
288 ElemType::UInt(UIntKind::U32),
289 );
290 sanitize_constant_scalar_ref_elem(
291 &mut op.out_index,
292 ElemType::UInt(UIntKind::U32),
293 );
294 }
295 Operator::CopyMemoryBulk(op) => {
296 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
297 sanitize_constant_scalar_ref_elem(
298 &mut op.in_index,
299 ElemType::UInt(UIntKind::U32),
300 );
301 sanitize_constant_scalar_ref_elem(
302 &mut op.out_index,
303 ElemType::UInt(UIntKind::U32),
304 );
305 }
306 Operator::Select(op) => {
307 sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
308 sanitize_constant_scalar_ref_var(&mut op.then, &inst.out.unwrap());
309 sanitize_constant_scalar_ref_var(&mut op.or_else, &inst.out.unwrap());
310 }
311 Operator::Cast(_) => {}
312 Operator::Reinterpret(_) => {}
313 },
314 Operation::Atomic(op) => match op {
315 AtomicOp::Load(_) => {}
316 AtomicOp::Store(_) => {}
317 AtomicOp::Swap(op) => {
318 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
319 }
320 AtomicOp::CompareAndSwap(op) => {
321 sanitize_constant_scalar_ref_var(&mut op.cmp, &inst.out.unwrap());
322 sanitize_constant_scalar_ref_var(&mut op.val, &inst.out.unwrap());
323 }
324 AtomicOp::Add(op) => {
325 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
326 }
327 AtomicOp::Sub(op) => {
328 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
329 }
330 AtomicOp::Max(op) => {
331 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
332 }
333 AtomicOp::Min(op) => {
334 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
335 }
336 AtomicOp::And(op) => {
337 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
338 }
339 AtomicOp::Or(op) => {
340 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
341 }
342 AtomicOp::Xor(op) => {
343 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
344 }
345 },
346 Operation::Metadata(op) => match op {
347 Metadata::Stride { dim, .. } => {
348 sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
349 }
350 Metadata::Shape { dim, .. } => {
351 sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
352 }
353 Metadata::Length { .. }
354 | Metadata::BufferLength { .. }
355 | Metadata::Rank { .. } => {
356 }
358 },
359 Operation::Branch(op) => match op {
360 Branch::If(op) => {
361 sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
362 }
363 Branch::IfElse(op) => {
364 sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
365 }
366 Branch::RangeLoop(op) => {
367 sanitize_constant_scalar_ref_var(&mut op.end, &op.start);
368 sanitize_constant_scalar_ref_var(&mut op.i, &op.start);
369 if let Some(step) = &mut op.step {
370 sanitize_constant_scalar_ref_elem(step, ElemType::UInt(UIntKind::U32));
371 }
372 }
373 _ => {
374 }
376 },
377 Operation::Synchronization(_) => {
378 }
380 Operation::Plane(_) => {
381 }
383 Operation::CoopMma(op) => match op {
384 CoopMma::Fill { value } => {
385 sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
386 }
387 CoopMma::Load { value, stride, .. } => {
388 sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
389 sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
390 }
391 CoopMma::Execute { .. }
392 | CoopMma::ExecuteManual { .. }
393 | CoopMma::ExecuteScaled { .. } => {
394 }
396 CoopMma::Store { stride, .. } => {
397 sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
398 }
399 CoopMma::Cast { .. } => {
400 }
402 CoopMma::RowIndex { lane_id, i, .. } => {
403 sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
404 sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
405 }
406 CoopMma::ColIndex { lane_id, i, .. } => {
407 sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
408 sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
409 }
410 },
411 Operation::NonSemantic(_) => {
412 }
414 Operation::Barrier(_) => {
415 }
417 Operation::Tma(_) => {
418 }
420 Operation::Marker(_) => {
421 }
423 });
424 self
425 }
426}
427
428fn sanitize_constant_scalar_ref_var(var: &mut Variable, reference: &Variable) {
429 if !reference.ty.is_semantic() {
430 let elem = reference.ty.elem_type();
431 sanitize_constant_scalar_ref_elem(var, elem);
432 }
433}
434
435fn sanitize_constant_scalar_ref_elem(var: &mut Variable, elem: ElemType) {
436 if let VariableKind::ConstantScalar(scalar) = var.kind
437 && scalar.elem_type() != elem
438 {
439 *var = match scalar {
440 super::ConstantScalarValue::Int(val, _) => elem.constant_from_i64(val),
441 super::ConstantScalarValue::Float(val, _) => elem.constant_from_f64(val),
442 super::ConstantScalarValue::UInt(val, _) => elem.constant_from_u64(val),
443 super::ConstantScalarValue::Bool(val) => elem.constant_from_bool(val),
444 };
445 }
446}