use crate::shared::FmtLeft;

use super::{
    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
};
use std::{
    borrow::Cow,
    fmt::{Display, Formatter, Write},
    marker::PhantomData,
};

pub(crate) const INFO_NAME: &str = "info";
pub(crate) const STATIC_INFO_NAME: &str = "static_info";

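/// Operands of an instruction with two inputs and a single output.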
#[derive(Debug, Clone, Copy)]
pub struct BinaryInstruction<D: Dialect> {
    pub lhs: Variable<D>,
    pub rhs: Variable<D>,
    pub out: Variable<D>,
}

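/// Operands of an indexed read (`out = list[index]`), with the line size of the
/// elements being read.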
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    pub list: Variable<D>,
    pub index: Variable<D>,
    pub line_size: u32,
    pub out: Variable<D>,
}

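/// Operands of an indexed write (`out[index] = value`), with the line size of the
/// elements being written.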
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    pub index: Variable<D>,
    pub value: Variable<D>,
    pub line_size: u32,
    pub out: Variable<D>,
}

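/// Operands of an instruction with a single input and a single output.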
#[derive(Debug, Clone, Copy)]
pub struct UnaryInstruction<D: Dialect> {
    pub input: Variable<D>,
    pub out: Variable<D>,
}

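/// One instruction of the generated kernel body. The [`Display`] implementation
/// below lowers each variant to source code for the target dialect `D`.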
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    Metadata {
        info_offset: Variable<D>,
        split_meta: bool,
        out: Variable<D>,
    },
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        split_meta: bool,
        static_offset: u32,
        out: Variable<D>,
    },
    ConstLength {
        length: u32,
        out: Variable<D>,
    },
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    DeclareVariable {
        var: Variable<D>,
    },
    Modulo(BinaryInstruction<D>),
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    FastDiv(BinaryInstruction<D>),
    FastRecip(UnaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    Index(IndexInstruction<D>),
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    SpecialCast(UnaryInstruction<D>),
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    ReinterpretSlice {
        input: Variable<D>,
        line_size: u32,
        out: Variable<D>,
    },
    Return,
    Break,
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    FastExp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    FastLog(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tan(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Sinh(UnaryInstruction<D>),
    Cosh(UnaryInstruction<D>),
    ArcCos(UnaryInstruction<D>),
    ArcSin(UnaryInstruction<D>),
    ArcTan(UnaryInstruction<D>),
    ArcSinh(UnaryInstruction<D>),
    ArcCosh(UnaryInstruction<D>),
    ArcTanh(UnaryInstruction<D>),
    Degrees(UnaryInstruction<D>),
    Radians(UnaryInstruction<D>),
    ArcTan2(BinaryInstruction<D>),
    FastSin(UnaryInstruction<D>),
    FastCos(UnaryInstruction<D>),
    FastTanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    FastPowf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Hypot(BinaryInstruction<D>),
    Rhypot(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    FastSqrt(UnaryInstruction<D>),
    InverseSqrt(UnaryInstruction<D>),
    FastInverseSqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    SyncThreads,
    SyncWarp,
    ThreadFence,
    ProxyAsyncToSharedFence,
    BulkCommitGroup,
    BulkWaitGroup {
        max_pending: u32,
    },
    BulkWaitGroupRead {
        max_pending: u32,
    },
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    Warp(WarpInstruction<D>),
    Wmma(WmmaInstruction<D>),
    Bitcast(UnaryInstruction<D>),
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    Neg(UnaryInstruction<D>),
    Magnitude(UnaryInstruction<D>),
    FastMagnitude(UnaryInstruction<D>),
    Normalize(UnaryInstruction<D>),
    FastNormalize(UnaryInstruction<D>),
    Dot(BinaryInstruction<D>),
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    Comment {
        content: String,
    },
    Pipeline(PipelineOps<D>),
    Barrier(BarrierOps<D>),
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}

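// Lowers every instruction variant to its textual source form, delegating the
// dialect-specific pieces (atomics, sync, printf, WMMA declarations, ...) to `D`.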
impl<D: Dialect> Display for Instruction<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Instruction::Return => f.write_str("return;"),
            Instruction::Break => f.write_str("break;"),
            Instruction::DeclareVariable { var } => match var {
                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
                _ => {
                    let item = var.item();
                    writeln!(f, "{item} {var};")
                }
            },
            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Slice {
                input,
                start,
                end,
                out,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = {end} - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            Instruction::CheckedSlice {
                input,
                start,
                end,
                out,
                len,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            Instruction::ReinterpretSlice {
                input,
                line_size,
                out,
            } => {
                let mut item = out.item();
                item.vectorization = *line_size as usize;
                let addr_space = D::address_space_for_variable(input);

                writeln!(
                    f,
                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
                )
            }
            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
            Instruction::IndexAssign(it) => {
                IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
            }
            Instruction::Copy {
                input,
                in_index,
                out,
                out_index,
            } => {
                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
            }
            Instruction::CopyBulk {
                input,
                in_index,
                out,
                out_index,
                len,
            } => {
                for i in 0..*len {
                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
                }
                Ok(())
            }
            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
            Instruction::RangeLoop {
                i,
                start,
                end,
                step,
                inclusive,
                instructions,
            } => {
                let increment = step
                    .map(|step| format!("{i} += {step}"))
                    .unwrap_or_else(|| format!("++{i}"));
                let cmp = if *inclusive { "<=" } else { "<" };
                let i_ty = i.item();

                write!(
                    f,
                    "
for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
"
                )?;
                for instruction in instructions {
                    write!(f, "{instruction}")?;
                }

                f.write_str("}\n")
            }
            Instruction::Loop { instructions } => {
                writeln!(f, "while (true) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::If { cond, instructions } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::IfElse {
                cond,
                instructions_if,
                instructions_else,
            } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions_if {
                    write!(f, "{i}")?;
                }
                f.write_str("} else {\n")?;
                for i in instructions_else {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::Select {
                cond,
                then,
                or_else,
                out,
            } => {
                let item_or_else = or_else.item();
                let item_then = then.item();
                let item_out = out.item();

                let vf_then = item_then.vectorization;
                let vf_or_else = item_or_else.vectorization;
                let vf_out = item_out.vectorization;
                let vf_cond = cond.item().vectorization;

                let item_out = out.item();
                let cond_elem = cond.item().elem;
                let out = out.fmt_left();

                let should_broadcast =
                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;

                if should_broadcast {
                    let vf = usize::max(vf_cond, vf_out);
                    let vf = usize::max(vf, vf_then);
                    let vf = usize::max(vf, vf_or_else);

                    writeln!(f, "{out} = {item_out} {{")?;
                    for i in 0..vf {
                        let theni = then.index(i);
                        let or_elsei = or_else.index(i);
                        let condi = cond.index(i);
                        let condi = EnsureBoolArg {
                            var: &condi,
                            elem: &cond_elem,
                        };

                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
                    }

                    writeln!(f, "}};")
                } else {
                    let cond = EnsureBoolArg {
                        var: &cond,
                        elem: &cond_elem,
                    };
                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
                }
            }
            Instruction::Switch {
                value,
                instructions_default,
                instructions_cases,
            } => {
                writeln!(f, "switch({value}) {{")?;
                for (value, block) in instructions_cases {
                    write!(f, "case {value}:\n{{\n")?;
                    for i in block {
                        i.fmt(f)?;
                    }
                    f.write_str("break;\n}\n")?;
                }
                f.write_str("default:\n{")?;
                for i in instructions_default {
                    i.fmt(f)?;
                }
                f.write_str("}\n}\n")
            }
            Instruction::Metadata {
                info_offset,
                split_meta,
                out,
            } => {
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
                    false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
                }
            }
            Instruction::ExtendedMetadata {
                info_offset,
                dim,
                split_meta,
                static_offset,
                out,
            } => {
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
                    ),
                    false => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
                    ),
                }
            }
            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
            Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
            Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
            Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
            Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
            Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
            Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
            Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
            Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
            Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
            Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
            Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
            Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
            Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
            Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
            Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
            Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
            Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
            Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
            Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Clamp {
                input,
                min_value,
                max_value,
                out,
            } => Clamp::format(f, input, min_value, max_value, out),
            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
            Instruction::SliceLength { input, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {input}_length;")
            }
            Instruction::ConstLength { length, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {length};")
            }
            Instruction::Warp(it) => write!(f, "{it}"),
            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
            Instruction::Wmma(it) => write!(f, "{it}"),
            Instruction::Bitcast(UnaryInstruction { input, out }) => {
                let qualifier = out.const_qualifier();
                let input_item = input.item();
                let out_item = out.item();

                if out_item.elem.size() * out_item.vectorization
                    != input.item().elem.size() * input.item().vectorization
                {
                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
                } else {
                    let out = out.fmt_left();
                    let addr_space = D::address_space_for_variable(input);
                    writeln!(
                        f,
                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
                    )
                }
            }
            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_add(f, lhs, rhs, out)
            }
            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_and(f, lhs, rhs, out)
            }
            Instruction::AtomicCAS {
                input,
                cmp,
                val,
                out,
            } => D::compile_atomic_cas(f, input, cmp, val, out),
            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
                D::compile_atomic_load(f, input, out)
            }
            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_max(f, lhs, rhs, out)
            }
            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_min(f, lhs, rhs, out)
            }
            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_or(f, lhs, rhs, out)
            }
            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
                D::compile_atomic_store(f, input, out)
            }
            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_sub(f, lhs, rhs, out)
            }
            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_swap(f, lhs, rhs, out)
            }
            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_xor(f, lhs, rhs, out)
            }
            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::Neg(UnaryInstruction { input, out }) => {
                let out = out.fmt_left();
                writeln!(f, "{out} = -{input};")
            }
            Instruction::Normalize(inst) => {
                Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::FastNormalize(inst) => {
                Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
            Instruction::FastMagnitude(inst) => {
                Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::VecInit { inputs, out } => {
                let item = out.item();
                let inputs = inputs
                    .iter()
                    .map(|input| format!("{input}"))
                    .collect::<Vec<_>>();
                let out = out.fmt_left();
                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
            }
            Instruction::Printf {
                format_string,
                args,
            } => D::compile_instruction_printf(f, format_string, args),
            Instruction::Comment { content } => {
                if content.contains('\n') {
                    writeln!(f, "/* {content} */")
                } else {
                    writeln!(f, "// {content}")
                }
            }
            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
            Instruction::ProxyAsyncToSharedFence => {
                writeln!(
                    f,
                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
                )
            }
            Instruction::BulkCommitGroup => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_commit_group();"
            ),
            Instruction::BulkWaitGroup { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
            ),
            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
            ),
            Instruction::TmaReplacePointer {
                buffer,
                offset,
                tensor_map,
                out,
            } => {
                let pos = Variable::<D>::UnitPos;
                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
                writeln!(
                    f,
                    "
if({pos} == 0) {{
    {out} = {tensor_map};
    tensormap_replace_global_address({out}, &{buffer}[{offset}]);
}}"
                )?;
                writeln!(f, "__syncthreads();")
            }
            Instruction::MemCopyAsyncTensorSharedToGlobal {
                smem_buffer,
                smem_offset,
                tensor_map,
                indices,
            } => {
                let rank = indices.len();
                let smem_ptr = smem_buffer.fmt_ptr();
                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
                    let _ = write!(s, "{it}, ");
                    s
                });
                writeln!(
                    f,
                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
                )
            }
            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
                #[cfg(not(feature = "cuda"))]
                {
                    let _ = (input, out);
                    writeln!(
                        f,
                        "#error FP8/FP6/FP4 casting isn't supported outside of CUDA"
                    )
                }
                #[cfg(feature = "cuda")]
                crate::cuda::convert::special_cast::<D>(f, input, out)
            }
        }
    }
}

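/// Formats a fused multiply-add, unrolling it per component for vectorized outputs.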
struct Fma<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Fma<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        a: &Variable<D>,
        b: &Variable<D>,
        c: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let out_item = out.item();
        let num = out_item.vectorization;

        let out = out.fmt_left();
        if num == 1 {
            writeln!(f, "{out} = fma({a}, {b}, {c});")
        } else {
            writeln!(f, "{out} = {out_item}{{")?;

            for i in 0..num {
                let ai = a.index(i);
                let bi = b.index(i);
                let ci = c.index(i);

                writeln!(f, "fma({ai}, {bi}, {ci}),")?;
            }
            f.write_str("};\n")
        }
    }
}

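/// Formats a clamp as nested dialect `max`/`min` calls, unrolled per component for
/// vectorized values.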
struct Clamp<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Clamp<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        min_value: &Variable<D>,
        max_value: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let out_item = out.item();
        if out.item().vectorization == 1 {
            let out = out.fmt_left();
            write!(f, "{out} = ")?;
            Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
            f.write_str(";\n")
        } else {
            Self::unroll_vec(f, input, min_value, max_value, out)
        }
    }

    fn format_scalar(
        f: &mut Formatter<'_>,
        input: impl Component<D>,
        min_value: impl Component<D>,
        max_value: impl Component<D>,
        item: Item<D>,
    ) -> std::fmt::Result {
        D::compile_instruction_max_function_name(f, item)?;
        write!(f, "({min_value}, ")?;
        D::compile_instruction_min_function_name(f, item)?;
        write!(f, "({max_value}, {input}))")
    }

    fn unroll_vec(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        min_value: &Variable<D>,
        max_value: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
        let [input, min_value, max_value, out_optimized] = optimized.args;

        let item_out_original = out.item();
        let item_out_optimized = out_optimized.item();

        let index = match optimized.optimization_factor {
            Some(factor) => item_out_original.vectorization / factor,
            None => item_out_optimized.vectorization,
        };

        let mut write_op = |input: &Variable<D>,
                            min_value: &Variable<D>,
                            max_value: &Variable<D>,
                            out: &Variable<D>,
                            item_out: Item<D>| {
            let out = out.fmt_left();
            writeln!(f, "{out} = {item_out}{{")?;
            for i in 0..index {
                let inputi = input.index(i);
                let min_valuei = min_value.index(i);
                let max_valuei = max_value.index(i);

                Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
                f.write_str(", ")?;
            }

            f.write_str("};\n")
        };

        if item_out_original == item_out_optimized {
            write_op(&input, &min_value, &max_value, out, item_out_optimized)
        } else {
            let out_tmp = Variable::tmp(item_out_optimized);
            write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
            let addr_space = D::address_space_for_variable(out);
            let out = out.fmt_left();

            writeln!(
                f,
                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
            )?;

            Ok(())
        }
    }
}

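/// Formats a floored remainder as `lhs - rhs * floor(lhs / rhs)`, casting integer
/// operands to `float` for the division and unrolling vectorized values.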
struct Remainder<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Remainder<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let floor = |elem| {
            let prefix = match elem {
                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
                _ => "",
            };
            format!("{prefix}floor")
        };

        let is_int = matches!(
            out.elem(),
            Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
        );
        let rem_expr = |lhs, rhs, floor| {
            if is_int {
                format!("{lhs} - {rhs} * {floor}((float){lhs} / (float){rhs})")
            } else {
                format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
            }
        };

        if out.item().vectorization == 1 {
            let floor = floor(out.elem());

            let out = out.fmt_left();
            let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
            return writeln!(f, "{out} = {rem};");
        }

        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
        let [lhs, rhs, out_optimized] = optimized.args;

        let item_out_original = out.item();
        let item_out_optimized = out_optimized.item();

        let index = match optimized.optimization_factor {
            Some(factor) => item_out_original.vectorization / factor,
            None => item_out_optimized.vectorization,
        };

        let floor = floor(*item_out_optimized.elem());

        let mut write_op =
            |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
                let out = out.fmt_left();
                writeln!(f, "{out} = {item_out}{{")?;
                for i in 0..index {
                    let lhsi = lhs.index(i);
                    let rhsi = rhs.index(i);

                    let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
                    writeln!(f, "{rem}")?;
                    f.write_str(", ")?;
                }

                f.write_str("};\n")
            };

        if item_out_original == item_out_optimized {
            write_op(&lhs, &rhs, out, item_out_optimized)
        } else {
            let out_tmp = Variable::tmp(item_out_optimized);

            write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;

            let addr_space = D::address_space_for_variable(&out_tmp);
            let qualifier = out.const_qualifier();
            let out = out.fmt_left();

            writeln!(
                f,
                "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
            )?;

            Ok(())
        }
    }
}

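/// Formats the Euclidean magnitude of a vector: sums the squared components into a
/// temporary, then applies the square-root function `S`.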
struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
    _dialect: PhantomData<D>,
    _sqrt: PhantomData<S>,
}

impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let num = input.item().vectorization;
        let elem = input.elem();

        let mag = format!("{out}_mag");

        writeln!(f, "{} {mag} = 0.0;", out.item())?;

        for i in 0..num {
            let input_i = input.index(i);
            writeln!(f, "{mag} += {input_i} * {input_i};")?;
        }

        let out = out.fmt_left();
        write!(f, "{out} = ")?;
        S::format_unary(f, &mag, elem)?;
        f.write_str(";\n")
    }
}

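/// Formats vector normalization: accumulates the squared components, applies the
/// inverse square-root function `InvS`, then scales each component by the result.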
struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
    _dialect: PhantomData<D>,
    _rsqrt: PhantomData<InvS>,
}

impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let num = input.item().vectorization;
        let elem = input.elem();
        let norm = format!("{out}_norm");

        let out_item = out.item();
        let out = out.fmt_left();
        writeln!(f, "{elem} {norm} = 0.0;")?;

        for i in 0..num {
            let input_i = input.index(i);
            writeln!(f, "{norm} += {input_i} * {input_i};")?;
        }

        write!(f, "{norm} = ")?;
        InvS::format_unary(f, &norm, elem)?;
        f.write_str(";\n")?;

        if num == 1 {
            writeln!(f, "{out} = {input} * {norm};")
        } else {
            write!(f, "{out} = {out_item}{{")?;
            for i in 0..num {
                let input_i = input.index(i);

                writeln!(f, "{input_i} * {norm},")?;
            }

            f.write_str("};\n")
        }
    }
}

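/// Formats a dot product as the sum of the component-wise products of `lhs` and `rhs`.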
struct Dot<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Dot<D> {
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let num = lhs.item().vectorization;

        let muls = (0..num)
            .map(|i| {
                let lhs_i = lhs.index(i);
                let rhs_i = rhs.index(i);
                format!("{lhs_i} * {rhs_i}")
            })
            .collect::<Vec<_>>();

        let out = out.fmt_left();
        writeln!(f, "{out} = {};", muls.join(" + "))
    }
}

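/// Wraps a condition operand in a `bool(...)` cast when its element type is not
/// already boolean, so emitted ternaries and selects stay well-typed.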
struct EnsureBoolArg<'a, V: Display, D: Dialect> {
    var: &'a V,
    elem: &'a Elem<D>,
}

impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.elem != &Elem::Bool {
            write!(f, "bool({})", self.var)
        } else {
            write!(f, "{}", self.var)
        }
    }
}