1use crate::shared::FmtLeft;
2
3use super::{
4 Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5 barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8 borrow::Cow,
9 fmt::{Display, Formatter, Write},
10 marker::PhantomData,
11};
12
/// Identifier of the metadata array in the generated source.
///
/// `Metadata` and `ExtendedMetadata` index it as `info[...]`.
pub(crate) const INFO_NAME: &str = "info";
/// Identifier of the static metadata value in the generated source.
///
/// When metadata is split (`split_meta == true`), offsets are read from
/// `static_info.x[...]` instead of from `info`.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
/// Operands of a two-input instruction (arithmetic, comparison, bitwise,
/// atomic, ...): two inputs and one destination.
#[derive(Debug, Clone, Copy)]
pub struct BinaryInstruction<D: Dialect> {
    /// First (left-hand) operand.
    pub lhs: Variable<D>,
    /// Second (right-hand) operand.
    pub rhs: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
22
/// Operands of an indexed read from `list` at `index` into `out`.
///
/// All fields are forwarded unchanged to the `Index` formatter.
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    /// The array/slice being read from.
    pub list: Variable<D>,
    /// Position to read.
    pub index: Variable<D>,
    /// NOTE(review): presumably the vectorization (line width) of the access;
    /// its exact interpretation lives in the `Index` formatter — confirm there.
    pub line_size: u32,
    /// Destination variable.
    pub out: Variable<D>,
}
30
/// Operands of an indexed write of `value` into `out` at `index`.
///
/// All fields are forwarded unchanged to the `IndexAssign` formatter.
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    /// Position to write.
    pub index: Variable<D>,
    /// Value being stored.
    pub value: Variable<D>,
    /// NOTE(review): presumably the vectorization (line width) of the access;
    /// its exact interpretation lives in the `IndexAssign` formatter — confirm there.
    pub line_size: u32,
    /// The array/slice being written to.
    pub out: Variable<D>,
}
38
/// Operands of a single-input instruction: one input and one destination.
#[derive(Debug, Clone, Copy)]
pub struct UnaryInstruction<D: Dialect> {
    /// The operand.
    pub input: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
44
/// One statement of the generated C++-like kernel source.
///
/// Each variant is turned into source text by this type's `Display`
/// implementation. Many variants delegate to a per-operation formatter or to a
/// `Dialect` hook, so the exact emitted syntax can vary by backend.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// Loads a metadata value at `info_offset` into `out`: from
    /// `static_info.x[...]` when `split_meta` is set, otherwise from `info[...]`.
    Metadata {
        info_offset: Variable<D>,
        split_meta: bool,
        out: Variable<D>,
    },
    /// Loads a per-dimension metadata value.
    ///
    /// Reads a base offset stored at `info_offset`, adds `dim`, and reads `info`
    /// at that position. When `split_meta` is set the base offset comes from
    /// `static_info.x` and `static_offset` is subtracted.
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        split_meta: bool,
        static_offset: u32,
        out: Variable<D>,
    },
    /// Assigns a compile-time constant `length` to `out`.
    ConstLength {
        length: u32,
        out: Variable<D>,
    },
    /// Reads the `<input>_length` companion constant emitted by `Slice` /
    /// `CheckedSlice`.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declares `var` without initializing it.
    DeclareVariable {
        var: Variable<D>,
    },
    Modulo(BinaryInstruction<D>),
    /// Floored remainder (`lhs - rhs * floor(lhs / rhs)`); the sign follows `rhs`.
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    /// Fused multiply-add `fma(a, b, c)`, unrolled per lane for vectors.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    FastDiv(BinaryInstruction<D>),
    FastRecip(UnaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    Index(IndexInstruction<D>),
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    /// Conversion handled by the CUDA-only special-cast path (FP8/FP6/FP4);
    /// other builds emit a `#error`.
    SpecialCast(UnaryInstruction<D>),
    /// `for` loop over `[start, end)`, or `[start, end]` when `inclusive`.
    ///
    /// Advances by `step` when given, otherwise by one (`++i`).
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Builds a vector value of `out`'s item type from `inputs` with an
    /// initializer list.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// Unconditional `while (true)` loop; exited through `Break` or `Return`.
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// Ternary `cond ? then : or_else`.
    ///
    /// Broadcast lane by lane when the condition is vectorized or the operand
    /// items differ from `out`'s.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    /// `switch` statement; every case body ends with an explicit `break`.
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Pointer `input + start` plus a `<out>_length = end - start` constant.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Like `Slice`, but the end is clamped with `min(len, end)`.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// Reinterprets `input` as a pointer whose item vectorization is `line_size`.
    ReinterpretSlice {
        input: Variable<D>,
        line_size: u32,
        out: Variable<D>,
    },
    Return,
    Break,
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    FastExp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    FastLog(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tan(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Sinh(UnaryInstruction<D>),
    Cosh(UnaryInstruction<D>),
    ArcCos(UnaryInstruction<D>),
    ArcSin(UnaryInstruction<D>),
    ArcTan(UnaryInstruction<D>),
    ArcSinh(UnaryInstruction<D>),
    ArcCosh(UnaryInstruction<D>),
    ArcTanh(UnaryInstruction<D>),
    Degrees(UnaryInstruction<D>),
    Radians(UnaryInstruction<D>),
    ArcTan2(BinaryInstruction<D>),
    FastSin(UnaryInstruction<D>),
    FastCos(UnaryInstruction<D>),
    FastTanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    FastPowf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    FastSqrt(UnaryInstruction<D>),
    InverseSqrt(UnaryInstruction<D>),
    FastInverseSqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// Emitted as `max(min_value, min(max_value, input))`, per lane for vectors.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    /// Block-level barrier; syntax provided by the dialect.
    SyncThreads,
    /// Warp-level barrier; syntax provided by the dialect.
    SyncWarp,
    /// Emitted as `__threadfence();`.
    ThreadFence,
    /// Emitted as `cuda::device::experimental::fence_proxy_async_shared_cta();`.
    ProxyAsyncToSharedFence,
    /// Emitted as `cuda::device::experimental::cp_async_bulk_commit_group();`.
    BulkCommitGroup,
    /// Emitted as `cp_async_bulk_wait_group<max_pending>()`.
    BulkWaitGroup {
        max_pending: u32,
    },
    /// Emitted as `cp_async_bulk_wait_group_read<max_pending>()`.
    BulkWaitGroupRead {
        max_pending: u32,
    },
    /// Declares a `__shared__` `CUtensorMap` named `out`.
    ///
    /// The unit with `UnitPos == 0` copies `tensor_map` into it and retargets
    /// its global address to `&buffer[offset]`, followed by `__syncthreads()`.
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    Warp(WarpInstruction<D>),
    Wmma(WmmaInstruction<D>),
    /// `reinterpret_cast` between items of identical total byte size; the
    /// formatter panics when the sizes differ.
    Bitcast(UnaryInstruction<D>),
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    /// Atomic compare-and-swap; syntax provided by the dialect.
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    Neg(UnaryInstruction<D>),
    /// Euclidean length of a vector (accurate square root).
    Magnitude(UnaryInstruction<D>),
    /// Euclidean length of a vector (fast square root).
    FastMagnitude(UnaryInstruction<D>),
    /// Vector scaled by its inverse length (accurate inverse square root).
    Normalize(UnaryInstruction<D>),
    /// Vector scaled by its inverse length (fast inverse square root).
    FastNormalize(UnaryInstruction<D>),
    /// Sum of per-lane products of `lhs` and `rhs`.
    Dot(BinaryInstruction<D>),
    /// Single element copy `out[out_index] = input[in_index]`.
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// `len` consecutive element copies, fully unrolled at generation time.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    /// Formatted print; syntax provided by the dialect.
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// Source comment: `// ...`, or `/* ... */` when `content` spans lines.
    Comment {
        content: String,
    },
    Pipeline(PipelineOps<D>),
    Barrier(BarrierOps<D>),
    /// Async bulk tensor copy from shared memory to global memory through a
    /// tensor map; the rank is `indices.len()`.
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// `#line` directive mapping generated code back to `file:line`.
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}
286
/// Renders an instruction as C++-like kernel source text.
///
/// Most arms emit one statement followed by a newline. Control-flow arms
/// (`RangeLoop`, `Loop`, `If`, `IfElse`, `Switch`) recursively format their
/// nested instruction lists. Backend-specific syntax is delegated to the
/// `Dialect` hooks on `D`.
impl<D: Dialect> Display for Instruction<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // NOTE(review): unlike most arms, these two emit no trailing newline.
            Instruction::Return => f.write_str("return;"),
            Instruction::Break => f.write_str("break;"),
            Instruction::DeclareVariable { var } => match var {
                // WMMA fragments need a dialect-specific declaration.
                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
                _ => {
                    let item = var.item();
                    writeln!(f, "{item} {var};")
                }
            },
            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
            // A slice is a pointer offset into `input` plus a companion
            // `<out>_length` constant, which `SliceLength` reads back.
            Instruction::Slice {
                input,
                start,
                end,
                out,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = {end} - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            // Same as `Slice`, except the end is clamped to `len`.
            Instruction::CheckedSlice {
                input,
                start,
                end,
                out,
                len,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            // Views `input` through a pointer to `out`'s element type with
            // vectorization overridden to `line_size`.
            Instruction::ReinterpretSlice {
                input,
                line_size,
                out,
            } => {
                let mut item = out.item();
                item.vectorization = *line_size as usize;
                let addr_space = D::address_space_for_variable(input);

                writeln!(
                    f,
                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
                )
            }
            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
            Instruction::IndexAssign(it) => {
                IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
            }
            Instruction::Copy {
                input,
                in_index,
                out,
                out_index,
            } => {
                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
            }
            // Unrolled at generation time: one assignment per element.
            Instruction::CopyBulk {
                input,
                in_index,
                out,
                out_index,
                len,
            } => {
                for i in 0..*len {
                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
                }
                Ok(())
            }
            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
            Instruction::RangeLoop {
                i,
                start,
                end,
                step,
                inclusive,
                instructions,
            } => {
                // Without an explicit step, the loop variable is pre-incremented by one.
                let increment = step
                    .map(|step| format!("{i} += {step}"))
                    .unwrap_or_else(|| format!("++{i}"));
                let cmp = if *inclusive { "<=" } else { "<" };
                let i_ty = i.item();

                write!(
                    f,
                    "
for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
"
                )?;
                for instruction in instructions {
                    write!(f, "{instruction}")?;
                }

                f.write_str("}\n")
            }
            Instruction::Loop { instructions } => {
                writeln!(f, "while (true) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::If { cond, instructions } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::IfElse {
                cond,
                instructions_if,
                instructions_else,
            } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions_if {
                    write!(f, "{i}")?;
                }
                f.write_str("} else {\n")?;
                for i in instructions_else {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::Select {
                cond,
                then,
                or_else,
                out,
            } => {
                let item_or_else = or_else.item();
                let item_then = then.item();
                let item_out = out.item();

                let vf_then = item_then.vectorization;
                let vf_or_else = item_or_else.vectorization;
                let vf_out = item_out.vectorization;
                let vf_cond = cond.item().vectorization;

                let item_out = out.item();
                let cond_elem = cond.item().elem;
                let out = out.fmt_left();

                // A single ternary only works when the condition is scalar and
                // both branches already have the output's item type.
                let should_broadcast =
                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;

                if should_broadcast {
                    // Emit one ternary per lane, up to the widest vectorization involved.
                    let vf = usize::max(vf_cond, vf_out);
                    let vf = usize::max(vf, vf_then);
                    let vf = usize::max(vf, vf_or_else);

                    writeln!(f, "{out} = {item_out} {{")?;
                    for i in 0..vf {
                        let theni = then.index(i);
                        let or_elsei = or_else.index(i);
                        let condi = cond.index(i);
                        let condi = EnsureBoolArg {
                            var: &condi,
                            elem: &cond_elem,
                        };

                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
                    }

                    writeln!(f, "}};")
                } else {
                    let cond = EnsureBoolArg {
                        var: &cond,
                        elem: &cond_elem,
                    };
                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
                }
            }
            Instruction::Switch {
                value,
                instructions_default,
                instructions_cases,
            } => {
                writeln!(f, "switch({value}) {{")?;
                for (value, block) in instructions_cases {
                    // Each case body is scoped and ends with an explicit `break`,
                    // so cases never fall through.
                    write!(f, "case {value}:\n{{\n")?;
                    for i in block {
                        i.fmt(f)?;
                    }
                    f.write_str("break;\n}\n")?;
                }
                f.write_str("default:\n{")?;
                for i in instructions_default {
                    i.fmt(f)?;
                }
                f.write_str("}\n}\n")
            }
            Instruction::Metadata {
                info_offset,
                split_meta,
                out,
            } => {
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
                    false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
                }
            }
            // Two-level lookup: fetch a base offset, then index `info` at base + dim.
            Instruction::ExtendedMetadata {
                info_offset,
                dim,
                split_meta,
                static_offset,
                out,
            } => {
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
                    ),
                    false => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
                    ),
                }
            }
            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
            Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
            Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
            Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
            Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
            Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
            Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
            Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
            Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
            Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
            Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
            Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
            Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
            Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
            Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
            Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
            Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
            Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
            Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
            Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Clamp {
                input,
                min_value,
                max_value,
                out,
            } => Clamp::format(f, input, min_value, max_value, out),
            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
            // Reads the `<input>_length` constant emitted by `Slice`/`CheckedSlice`.
            Instruction::SliceLength { input, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {input}_length;")
            }
            Instruction::ConstLength { length, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {length};")
            }
            Instruction::Warp(it) => write!(f, "{it}"),
            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
            Instruction::Wmma(it) => write!(f, "{it}"),
            Instruction::Bitcast(UnaryInstruction { input, out }) => {
                let qualifier = out.const_qualifier();
                let input_item = input.item();
                let out_item = out.item();

                // Reinterpreting is only valid when both sides span the same number of bytes.
                if out_item.elem.size() * out_item.vectorization
                    != input.item().elem.size() * input.item().vectorization
                {
                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
                } else {
                    let out = out.fmt_left();
                    // The reference type carries the address space of the cast operand.
                    let addr_space = D::address_space_for_variable(input);
                    writeln!(
                        f,
                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
                    )
                }
            }
            // Atomic operations are entirely dialect-defined.
            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_add(f, lhs, rhs, out)
            }
            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_and(f, lhs, rhs, out)
            }
            Instruction::AtomicCAS {
                input,
                cmp,
                val,
                out,
            } => D::compile_atomic_cas(f, input, cmp, val, out),
            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
                D::compile_atomic_load(f, input, out)
            }
            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_max(f, lhs, rhs, out)
            }
            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_min(f, lhs, rhs, out)
            }
            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_or(f, lhs, rhs, out)
            }
            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
                D::compile_atomic_store(f, input, out)
            }
            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_sub(f, lhs, rhs, out)
            }
            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_swap(f, lhs, rhs, out)
            }
            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_xor(f, lhs, rhs, out)
            }
            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::Neg(UnaryInstruction { input, out }) => {
                let out = out.fmt_left();
                writeln!(f, "{out} = -{input};")
            }
            Instruction::Normalize(inst) => {
                Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::FastNormalize(inst) => {
                Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
            Instruction::FastMagnitude(inst) => {
                Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
            // Brace-initializes `out`'s item type from the inputs, comma-separated.
            Instruction::VecInit { inputs, out } => {
                let item = out.item();
                let inputs = inputs
                    .iter()
                    .map(|input| format!("{input}"))
                    .collect::<Vec<_>>();
                let out = out.fmt_left();
                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
            }
            Instruction::Printf {
                format_string,
                args,
            } => D::compile_instruction_printf(f, format_string, args),
            // Multi-line content needs a block comment; a `//` comment would
            // only cover the first line.
            // NOTE(review): content containing `*/` would terminate the block early.
            Instruction::Comment { content } => {
                if content.contains('\n') {
                    writeln!(f, "/* {content} */")
                } else {
                    writeln!(f, "// {content}")
                }
            }
            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
            // NOTE(review): `file` is inserted unescaped; backslashes or quotes in
            // the path would be reinterpreted by the C preprocessor.
            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
            Instruction::ProxyAsyncToSharedFence => {
                writeln!(
                    f,
                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
                )
            }
            Instruction::BulkCommitGroup => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_commit_group();"
            ),
            Instruction::BulkWaitGroup { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
            ),
            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
            ),
            // Builds a shared copy of the tensor map retargeted at `&buffer[offset]`:
            // only the unit with `UnitPos == 0` writes it, then all units synchronize
            // before use.
            Instruction::TmaReplacePointer {
                buffer,
                offset,
                tensor_map,
                out,
            } => {
                let pos = Variable::<D>::UnitPos;
                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
                writeln!(
                    f,
                    "
if({pos} == 0) {{
 {out} = {tensor_map};
 tensormap_replace_global_address({out}, &{buffer}[{offset}]);
}}"
                )?;
                writeln!(f, "__syncthreads();")
            }
            Instruction::MemCopyAsyncTensorSharedToGlobal {
                smem_buffer,
                smem_offset,
                tensor_map,
                indices,
            } => {
                let rank = indices.len();
                let smem_ptr = smem_buffer.fmt_ptr();
                // Indices are emitted in reverse order, each followed by ", ".
                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
                    let _ = write!(s, "{it}, ");
                    s
                });
                writeln!(
                    f,
                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
                )
            }
            // FP8/FP6/FP4 conversions exist only in the CUDA backend; other builds
            // emit `#error` so the generated source fails to compile loudly.
            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
                #[cfg(not(feature = "cuda"))]
                {
                    let _ = (input, out);
                    writeln!(
                        f,
                        "#error FP8/FP6/FP4 casting isn't supported outside of CUDA"
                    )
                }
                #[cfg(feature = "cuda")]
                crate::cuda::convert::special_cast::<D>(f, input, out)
            }
        }
    }
}
765
766struct Fma<D: Dialect> {
767 _dialect: PhantomData<D>,
768}
769
770impl<D: Dialect> Fma<D> {
771 fn format(
772 f: &mut core::fmt::Formatter<'_>,
773 a: &Variable<D>,
774 b: &Variable<D>,
775 c: &Variable<D>,
776 out: &Variable<D>,
777 ) -> core::fmt::Result {
778 let out_item = out.item();
779 let num = out_item.vectorization;
780
781 let out = out.fmt_left();
782 if num == 1 {
783 writeln!(f, "{out} = fma({a}, {b}, {c});")
784 } else {
785 writeln!(f, "{out} = {out_item}{{")?;
786
787 for i in 0..num {
788 let ai = a.index(i);
789 let bi = b.index(i);
790 let ci = c.index(i);
791
792 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
793 }
794 f.write_str("};\n")
795 }
796 }
797}
798
799struct Clamp<D: Dialect> {
800 _dialect: PhantomData<D>,
801}
802
803impl<D: Dialect> Clamp<D> {
804 fn format(
805 f: &mut core::fmt::Formatter<'_>,
806 input: &Variable<D>,
807 min_value: &Variable<D>,
808 max_value: &Variable<D>,
809 out: &Variable<D>,
810 ) -> core::fmt::Result {
811 let out_item = out.item();
812 if out.item().vectorization == 1 {
813 let out = out.fmt_left();
814 write!(f, "{out} = ")?;
815 Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
816 f.write_str(";\n")
817 } else {
818 Self::unroll_vec(f, input, min_value, max_value, out)
819 }
820 }
821
822 fn format_scalar(
823 f: &mut Formatter<'_>,
824 input: impl Component<D>,
825 min_value: impl Component<D>,
826 max_value: impl Component<D>,
827 item: Item<D>,
828 ) -> std::fmt::Result {
829 D::compile_instruction_max_function_name(f, item)?;
830 write!(f, "({min_value}, ")?;
831 D::compile_instruction_min_function_name(f, item)?;
832 write!(f, "({max_value}, {input}))")
833 }
834
835 fn unroll_vec(
836 f: &mut core::fmt::Formatter<'_>,
837 input: &Variable<D>,
838 min_value: &Variable<D>,
839 max_value: &Variable<D>,
840 out: &Variable<D>,
841 ) -> std::fmt::Result {
842 let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
843 let [input, min_value, max_value, out_optimized] = optimized.args;
844
845 let item_out_original = out.item();
846 let item_out_optimized = out_optimized.item();
847
848 let index = match optimized.optimization_factor {
849 Some(factor) => item_out_original.vectorization / factor,
850 None => item_out_optimized.vectorization,
851 };
852
853 let mut write_op = |input: &Variable<D>,
854 min_value: &Variable<D>,
855 max_value: &Variable<D>,
856 out: &Variable<D>,
857 item_out: Item<D>| {
858 let out = out.fmt_left();
859 writeln!(f, "{out} = {item_out}{{")?;
860 for i in 0..index {
861 let inputi = input.index(i);
862 let min_valuei = min_value.index(i);
863 let max_valuei = max_value.index(i);
864
865 Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
866 f.write_str(", ")?;
867 }
868
869 f.write_str("};\n")
870 };
871
872 if item_out_original == item_out_optimized {
873 write_op(&input, &min_value, &max_value, out, item_out_optimized)
874 } else {
875 let out_tmp = Variable::tmp(item_out_optimized);
876 write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
877 let addr_space = D::address_space_for_variable(out);
878 let out = out.fmt_left();
879
880 writeln!(
881 f,
882 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
883 )?;
884
885 Ok(())
886 }
887 }
888}
889
890struct Remainder<D: Dialect> {
891 _dialect: PhantomData<D>,
892}
893
894impl<D: Dialect> Remainder<D> {
895 fn format(
896 f: &mut core::fmt::Formatter<'_>,
897 lhs: &Variable<D>,
898 rhs: &Variable<D>,
899 out: &Variable<D>,
900 ) -> core::fmt::Result {
901 let floor = |elem| {
902 let prefix = match elem {
903 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
904 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
905 _ => "",
906 };
907 format!("{prefix}floor")
908 };
909
910 let is_int = matches!(
911 out.elem(),
912 Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
913 );
914 let rem_expr = |lhs, rhs, floor| {
915 if is_int {
916 format!("{lhs} - {rhs} * {floor}((float){lhs} / (float){rhs})")
917 } else {
918 format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
919 }
920 };
921
922 if out.item().vectorization == 1 {
923 let floor = floor(out.elem());
924
925 let out = out.fmt_left();
926 let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
927 return writeln!(f, "{out} = {rem};");
928 }
929
930 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
931 let [lhs, rhs, out_optimized] = optimized.args;
932
933 let item_out_original = out.item();
934 let item_out_optimized = out_optimized.item();
935
936 let index = match optimized.optimization_factor {
937 Some(factor) => item_out_original.vectorization / factor,
938 None => item_out_optimized.vectorization,
939 };
940
941 let floor = floor(*item_out_optimized.elem());
942
943 let mut write_op =
944 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
945 let out = out.fmt_left();
946 writeln!(f, "{out} = {item_out}{{")?;
947 for i in 0..index {
948 let lhsi = lhs.index(i);
949 let rhsi = rhs.index(i);
950
951 let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
952 writeln!(f, "{rem}")?;
953 f.write_str(", ")?;
954 }
955
956 f.write_str("};\n")
957 };
958
959 if item_out_original == item_out_optimized {
960 write_op(&lhs, &rhs, out, item_out_optimized)
961 } else {
962 let out_tmp = Variable::tmp(item_out_optimized);
963
964 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
965
966 let addr_space = D::address_space_for_variable(&out_tmp);
967 let qualifier = out.const_qualifier();
968 let out = out.fmt_left();
969
970 writeln!(
971 f,
972 "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
973 )?;
974
975 Ok(())
976 }
977 }
978}
979
980struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
981 _dialect: PhantomData<D>,
982 _sqrt: PhantomData<S>,
983}
984
985impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
986 fn format(
987 f: &mut core::fmt::Formatter<'_>,
988 input: &Variable<D>,
989 out: &Variable<D>,
990 ) -> core::fmt::Result {
991 let num = input.item().vectorization;
992 let elem = input.elem();
993
994 let mag = format!("{out}_mag");
995
996 writeln!(f, "{} {mag} = 0.0;", out.item())?;
997
998 for i in 0..num {
999 let input_i = input.index(i);
1000 writeln!(f, "{mag} += {input_i} * {input_i};")?;
1001 }
1002
1003 let out = out.fmt_left();
1004 write!(f, "{out} = ")?;
1005 S::format_unary(f, &mag, elem)?;
1006 f.write_str(";\n")
1007 }
1008}
1009
1010struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
1011 _dialect: PhantomData<D>,
1012 _rsqrt: PhantomData<InvS>,
1013}
1014
1015impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
1016 fn format(
1017 f: &mut core::fmt::Formatter<'_>,
1018 input: &Variable<D>,
1019 out: &Variable<D>,
1020 ) -> core::fmt::Result {
1021 let num = input.item().vectorization;
1022 let elem = input.elem();
1023 let norm = format!("{out}_norm");
1024
1025 let out_item = out.item();
1026 let out = out.fmt_left();
1027 writeln!(f, "{elem} {norm} = 0.0;")?;
1028
1029 for i in 0..num {
1030 let input_i = input.index(i);
1031 writeln!(f, "{norm} += {input_i} * {input_i};")?;
1032 }
1033
1034 write!(f, "{norm} = ")?;
1035 InvS::format_unary(f, &norm, elem)?;
1036 f.write_str(";\n")?;
1037
1038 if num == 1 {
1039 writeln!(f, "{out} = {input} * {norm};")
1040 } else {
1041 write!(f, "{out} = {out_item}{{")?;
1042 for i in 0..num {
1043 let input_i = input.index(i);
1044
1045 writeln!(f, "{input_i} * {norm},")?;
1046 }
1047
1048 f.write_str("};\n")
1049 }
1050 }
1051}
1052
1053struct Dot<D: Dialect> {
1054 _dialect: PhantomData<D>,
1055}
1056
1057impl<D: Dialect> Dot<D> {
1058 fn format(
1059 f: &mut core::fmt::Formatter<'_>,
1060 lhs: &Variable<D>,
1061 rhs: &Variable<D>,
1062 out: &Variable<D>,
1063 ) -> core::fmt::Result {
1064 let num = lhs.item().vectorization;
1065
1066 let muls = (0..num)
1067 .map(|i| {
1068 let lhs_i = lhs.index(i);
1069 let rhs_i = rhs.index(i);
1070 format!("{lhs_i} * {rhs_i}")
1071 })
1072 .collect::<Vec<_>>();
1073
1074 let out = out.fmt_left();
1075 writeln!(f, "{out} = {};", muls.join(" + "))
1076 }
1077}
1078
1079struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1080 var: &'a V,
1081 elem: &'a Elem<D>,
1082}
1083
1084impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1085 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1086 if self.elem != &Elem::Bool {
1087 write!(f, "bool({})", self.var)
1088 } else {
1089 write!(f, "{}", self.var)
1090 }
1091 }
1092}