1use core::fmt::Display;
2
3use alloc::string::ToString;
4use alloc::vec::Vec;
5
6use crate::{Allocator, AtomicOp, Bitwise, Comparison, Operator};
7
8use super::{
9 Arithmetic, Branch, CoopMma, ElemType, Instruction, Metadata, Operation, UIntKind, Variable,
10 VariableKind,
11};
12
/// A transformation step applied to a [`ScopeProcessing`].
///
/// Implementors must also implement [`core::fmt::Debug`].
pub trait Processor: core::fmt::Debug {
    /// Consumes `processing` and returns the resulting [`ScopeProcessing`].
    ///
    /// `allocator` is handed to the implementation by value.
    /// NOTE(review): presumably used to create new variables during the transformation — confirm.
    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing;
}
16
/// A list of variables bundled with a list of instructions.
///
/// This is the value consumed and produced by [`Processor::transform`].
pub struct ScopeProcessing {
    /// Variables carried alongside the instructions.
    /// NOTE(review): presumably the scope's variable declarations — confirm.
    pub variables: Vec<Variable>,
    /// Instructions, in order; these are what the `Display` implementation prints.
    pub instructions: Vec<Instruction>,
}
24
25impl Display for ScopeProcessing {
26 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27 writeln!(f, "{{")?;
28 for instruction in self.instructions.iter() {
29 let instruction_str = instruction.to_string();
30 if !instruction_str.is_empty() {
31 writeln!(f, " {instruction_str}")?;
32 }
33 }
34 write!(f, "}}")?;
35 Ok(())
36 }
37}
38
39impl ScopeProcessing {
40 pub fn optimize(self) -> Self {
48 self.sanitize_constant_scalars()
49 }
50
51 fn sanitize_constant_scalars(mut self) -> Self {
54 self.instructions
55 .iter_mut()
56 .for_each(|inst| match &mut inst.operation {
57 Operation::Copy(op) => {
58 sanitize_constant_scalar_ref_var(op, &inst.out.unwrap());
59 }
60 Operation::Arithmetic(op) => match op {
61 Arithmetic::Add(op) => {
62 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
63 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
64 }
65 Arithmetic::SaturatingAdd(op) => {
66 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
67 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
68 }
69 Arithmetic::Fma(op) => {
70 sanitize_constant_scalar_ref_var(&mut op.a, &inst.out.unwrap());
71 sanitize_constant_scalar_ref_var(&mut op.b, &inst.out.unwrap());
72 sanitize_constant_scalar_ref_var(&mut op.c, &inst.out.unwrap());
73 }
74 Arithmetic::Sub(op) => {
75 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
76 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
77 }
78 Arithmetic::SaturatingSub(op) => {
79 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
80 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
81 }
82 Arithmetic::Mul(op) => {
83 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
84 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
85 }
86 Arithmetic::Div(op) => {
87 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
88 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
89 }
90 Arithmetic::MulHi(op) => {
91 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
92 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
93 }
94 Arithmetic::Abs(op) => {
95 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
96 }
97 Arithmetic::Exp(op) => {
98 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
99 }
100 Arithmetic::Log(op) => {
101 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
102 }
103 Arithmetic::Log1p(op) => {
104 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
105 }
106 Arithmetic::Cos(op) => {
107 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
108 }
109 Arithmetic::Sin(op) => {
110 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
111 }
112 Arithmetic::Tan(op) => {
113 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
114 }
115 Arithmetic::Tanh(op) => {
116 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
117 }
118 Arithmetic::Sinh(op) => {
119 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
120 }
121 Arithmetic::Cosh(op) => {
122 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
123 }
124 Arithmetic::ArcCos(op) => {
125 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
126 }
127 Arithmetic::ArcSin(op) => {
128 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
129 }
130 Arithmetic::ArcTan(op) => {
131 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
132 }
133 Arithmetic::ArcSinh(op) => {
134 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
135 }
136 Arithmetic::ArcCosh(op) => {
137 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
138 }
139 Arithmetic::ArcTanh(op) => {
140 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
141 }
142 Arithmetic::Degrees(op) => {
143 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
144 }
145 Arithmetic::Radians(op) => {
146 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
147 }
148 Arithmetic::ArcTan2(op) => {
149 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
150 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
151 }
152 Arithmetic::Powf(op) => {
153 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
154 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
155 }
156 Arithmetic::Powi(op) => {
157 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
158 }
159 Arithmetic::Hypot(op) => {
160 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
161 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
162 }
163 Arithmetic::Rhypot(op) => {
164 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
165 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
166 }
167 Arithmetic::Sqrt(op) => {
168 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
169 }
170 Arithmetic::InverseSqrt(op) => {
171 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
172 }
173 Arithmetic::Round(op) => {
174 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
175 }
176 Arithmetic::Floor(op) => {
177 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
178 }
179 Arithmetic::Ceil(op) => {
180 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
181 }
182 Arithmetic::Trunc(op) => {
183 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
184 }
185 Arithmetic::Erf(op) => {
186 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
187 }
188 Arithmetic::Recip(op) => {
189 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
190 }
191 Arithmetic::Clamp(op) => {
192 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
193 sanitize_constant_scalar_ref_var(&mut op.min_value, &inst.out.unwrap());
194 sanitize_constant_scalar_ref_var(&mut op.max_value, &inst.out.unwrap());
195 }
196 Arithmetic::Modulo(op) => {
197 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
198 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
199 }
200 Arithmetic::Neg(op) => {
201 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap())
202 }
203 Arithmetic::Max(op) => {
204 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
205 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
206 }
207 Arithmetic::Min(op) => {
208 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
209 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
210 }
211 Arithmetic::Remainder(op) => {
212 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
213 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
214 }
215 Arithmetic::Magnitude(op) => {
216 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
217 }
218 Arithmetic::Normalize(op) => {
219 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
220 }
221 Arithmetic::Dot(op) => {
222 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
223 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
224 }
225 },
226 Operation::Comparison(op) => match op {
227 Comparison::Greater(op) => {
228 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
229 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
230 }
231 Comparison::LowerEqual(op) => {
232 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
233 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
234 }
235 Comparison::GreaterEqual(op) => {
236 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
237 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
238 }
239 Comparison::Equal(op) => {
240 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
241 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
242 }
243 Comparison::NotEqual(op) => {
244 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
245 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
246 }
247 Comparison::Lower(op) => {
248 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
249 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
250 }
251 Comparison::IsNan(_op) | Comparison::IsInf(_op) => {
252 }
254 },
255 Operation::Bitwise(op) => match op {
256 Bitwise::BitwiseAnd(op) => {
257 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
258 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
259 }
260 Bitwise::BitwiseOr(op) => {
261 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
262 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
263 }
264 Bitwise::BitwiseXor(op) => {
265 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
266 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
267 }
268 Bitwise::CountOnes(_) | Bitwise::LeadingZeros(_) | Bitwise::FindFirstSet(_) => {
269 }
271 Bitwise::ReverseBits(op) => {
272 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
273 }
274 Bitwise::ShiftLeft(op) => {
275 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
276 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
277 }
278 Bitwise::ShiftRight(op) => {
279 sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
280 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
281 }
282 Bitwise::BitwiseNot(op) => {
283 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
284 }
285 },
286 Operation::Operator(op) => match op {
287 Operator::Index(op) => {
288 sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
289 sanitize_constant_scalar_ref_elem(
290 &mut op.index,
291 ElemType::UInt(UIntKind::U32),
292 );
293 }
294 Operator::UncheckedIndex(op) => {
295 sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
296 sanitize_constant_scalar_ref_elem(
297 &mut op.index,
298 ElemType::UInt(UIntKind::U32),
299 );
300 }
301 Operator::IndexAssign(op) => {
302 sanitize_constant_scalar_ref_elem(
303 &mut op.index,
304 ElemType::UInt(UIntKind::U32),
305 );
306 sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
307 }
308 Operator::UncheckedIndexAssign(op) => {
309 sanitize_constant_scalar_ref_elem(
310 &mut op.index,
311 ElemType::UInt(UIntKind::U32),
312 );
313 sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
314 }
315 Operator::And(op) => {
316 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
317 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
318 }
319 Operator::Or(op) => {
320 sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
321 sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
322 }
323 Operator::Not(op) => {
324 sanitize_constant_scalar_ref_elem(&mut op.input, ElemType::Bool);
325 }
326 Operator::InitLine(_) => {
327 }
329 Operator::CopyMemory(op) => {
330 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
331 sanitize_constant_scalar_ref_elem(
332 &mut op.in_index,
333 ElemType::UInt(UIntKind::U32),
334 );
335 sanitize_constant_scalar_ref_elem(
336 &mut op.out_index,
337 ElemType::UInt(UIntKind::U32),
338 );
339 }
340 Operator::CopyMemoryBulk(op) => {
341 sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
342 sanitize_constant_scalar_ref_elem(
343 &mut op.in_index,
344 ElemType::UInt(UIntKind::U32),
345 );
346 sanitize_constant_scalar_ref_elem(
347 &mut op.out_index,
348 ElemType::UInt(UIntKind::U32),
349 );
350 }
351 Operator::Select(op) => {
352 sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
353 sanitize_constant_scalar_ref_var(&mut op.then, &inst.out.unwrap());
354 sanitize_constant_scalar_ref_var(&mut op.or_else, &inst.out.unwrap());
355 }
356 Operator::Cast(_) => {}
357 Operator::Reinterpret(_) => {}
358 },
359 Operation::Atomic(op) => match op {
360 AtomicOp::Load(_) => {}
361 AtomicOp::Store(_) => {}
362 AtomicOp::Swap(op) => {
363 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
364 }
365 AtomicOp::CompareAndSwap(op) => {
366 sanitize_constant_scalar_ref_var(&mut op.cmp, &inst.out.unwrap());
367 sanitize_constant_scalar_ref_var(&mut op.val, &inst.out.unwrap());
368 }
369 AtomicOp::Add(op) => {
370 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
371 }
372 AtomicOp::Sub(op) => {
373 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
374 }
375 AtomicOp::Max(op) => {
376 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
377 }
378 AtomicOp::Min(op) => {
379 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
380 }
381 AtomicOp::And(op) => {
382 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
383 }
384 AtomicOp::Or(op) => {
385 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
386 }
387 AtomicOp::Xor(op) => {
388 sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
389 }
390 },
391 Operation::Metadata(op) => match op {
392 Metadata::Stride { dim, .. } => {
393 sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
394 }
395 Metadata::Shape { dim, .. } => {
396 sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
397 }
398 Metadata::Length { .. }
399 | Metadata::BufferLength { .. }
400 | Metadata::Rank { .. } => {
401 }
403 },
404 Operation::Branch(op) => match op {
405 Branch::If(op) => {
406 sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
407 }
408 Branch::IfElse(op) => {
409 sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
410 }
411 Branch::RangeLoop(op) => {
412 sanitize_constant_scalar_ref_var(&mut op.end, &op.start);
413 sanitize_constant_scalar_ref_var(&mut op.i, &op.start);
414 if let Some(step) = &mut op.step {
415 sanitize_constant_scalar_ref_elem(step, ElemType::UInt(UIntKind::U32));
416 }
417 }
418 _ => {
419 }
421 },
422 Operation::Synchronization(_) => {
423 }
425 Operation::Plane(_) => {
426 }
428 Operation::CoopMma(op) => match op {
429 CoopMma::Fill { value } => {
430 sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
431 }
432 CoopMma::Load { value, stride, .. } => {
433 sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
434 sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
435 }
436 CoopMma::Execute { .. }
437 | CoopMma::ExecuteManual { .. }
438 | CoopMma::ExecuteScaled { .. } => {
439 }
441 CoopMma::Store { stride, .. } => {
442 sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
443 }
444 CoopMma::Cast { .. } => {
445 }
447 CoopMma::RowIndex { lane_id, i, .. } => {
448 sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
449 sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
450 }
451 CoopMma::ColIndex { lane_id, i, .. } => {
452 sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
453 sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
454 }
455 CoopMma::LoadMatrix { .. } | CoopMma::StoreMatrix { .. } => {
456 }
458 },
459 Operation::NonSemantic(_) => {
460 }
462 Operation::Barrier(_) => {
463 }
465 Operation::Tma(_) => {
466 }
468 Operation::Marker(_) => {
469 }
471 });
472 self
473 }
474}
475
476fn sanitize_constant_scalar_ref_var(var: &mut Variable, reference: &Variable) {
477 if !reference.ty.is_semantic() {
478 let elem = reference.ty.elem_type();
479 sanitize_constant_scalar_ref_elem(var, elem);
480 }
481}
482
483fn sanitize_constant_scalar_ref_elem(var: &mut Variable, elem: ElemType) {
484 if let VariableKind::ConstantScalar(scalar) = var.kind
485 && scalar.elem_type() != elem
486 {
487 *var = match scalar {
488 super::ConstantScalarValue::Int(val, _) => elem.constant_from_i64(val),
489 super::ConstantScalarValue::Float(val, _) => elem.constant_from_f64(val),
490 super::ConstantScalarValue::UInt(val, _) => elem.constant_from_u64(val),
491 super::ConstantScalarValue::Bool(val) => elem.constant_from_bool(val),
492 };
493 }
494}