use crate::shared::FmtLeft;

use super::{
    Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
    barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
};
use std::{
    borrow::Cow,
    fmt::{Display, Formatter, Write},
    marker::PhantomData,
};
12
/// Identifier of the metadata buffer indexed by the code emitted for
/// `Instruction::Metadata` and `Instruction::ExtendedMetadata`.
pub(crate) const INFO_NAME: &str = "info";
/// Identifier of the static metadata object; the emitted code reads its `x`
/// member when an instruction has `split_meta` set.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
16#[derive(Debug, Clone, Copy)]
17pub struct BinaryInstruction<D: Dialect> {
18 pub lhs: Variable<D>,
19 pub rhs: Variable<D>,
20 pub out: Variable<D>,
21}
22
23#[derive(Debug, Clone)]
24pub struct IndexInstruction<D: Dialect> {
25 pub list: Variable<D>,
26 pub index: Variable<D>,
27 pub line_size: u32,
28 pub out: Variable<D>,
29}
30
31#[derive(Debug, Clone)]
32pub struct IndexAssignInstruction<D: Dialect> {
33 pub index: Variable<D>,
34 pub value: Variable<D>,
35 pub line_size: u32,
36 pub out: Variable<D>,
37}
38
39#[derive(Debug, Clone, Copy)]
40pub struct UnaryInstruction<D: Dialect> {
41 pub input: Variable<D>,
42 pub out: Variable<D>,
43}
44
45#[derive(Debug, Clone)]
46pub enum Instruction<D: Dialect> {
47 Metadata {
48 info_offset: Variable<D>,
49 split_meta: bool,
50 out: Variable<D>,
51 },
52 ExtendedMetadata {
53 info_offset: Variable<D>,
54 dim: Variable<D>,
55 split_meta: bool,
56 static_offset: u32,
57 out: Variable<D>,
58 },
59 ConstLength {
60 length: u32,
61 out: Variable<D>,
62 },
63 SliceLength {
64 input: Variable<D>,
65 out: Variable<D>,
66 },
67 DeclareVariable {
68 var: Variable<D>,
69 },
70 Modulo(BinaryInstruction<D>),
71 Remainder(BinaryInstruction<D>),
72 Add(BinaryInstruction<D>),
73 SaturatingAdd(BinaryInstruction<D>),
74 Fma {
75 a: Variable<D>,
76 b: Variable<D>,
77 c: Variable<D>,
78 out: Variable<D>,
79 },
80 Div(BinaryInstruction<D>),
81 FastDiv(BinaryInstruction<D>),
82 FastRecip(UnaryInstruction<D>),
83 Mul(BinaryInstruction<D>),
84 Sub(BinaryInstruction<D>),
85 SaturatingSub(BinaryInstruction<D>),
86 HiMul(BinaryInstruction<D>),
87 Index(IndexInstruction<D>),
88 IndexAssign(IndexAssignInstruction<D>),
89 Assign(UnaryInstruction<D>),
90 SpecialCast(UnaryInstruction<D>),
91 RangeLoop {
92 i: Variable<D>,
93 start: Variable<D>,
94 end: Variable<D>,
95 step: Option<Variable<D>>,
96 inclusive: bool,
97 instructions: Vec<Self>,
98 },
99 VecInit {
100 inputs: Vec<Variable<D>>,
101 out: Variable<D>,
102 },
103 Loop {
104 instructions: Vec<Self>,
105 },
106 If {
107 cond: Variable<D>,
108 instructions: Vec<Self>,
109 },
110 IfElse {
111 cond: Variable<D>,
112 instructions_if: Vec<Self>,
113 instructions_else: Vec<Self>,
114 },
115 Select {
116 cond: Variable<D>,
117 then: Variable<D>,
118 or_else: Variable<D>,
119 out: Variable<D>,
120 },
121 Switch {
122 value: Variable<D>,
123 instructions_default: Vec<Self>,
124 instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
125 },
126 Slice {
127 input: Variable<D>,
128 start: Variable<D>,
129 end: Variable<D>,
130 out: Variable<D>,
131 },
132 CheckedSlice {
133 input: Variable<D>,
134 start: Variable<D>,
135 end: Variable<D>,
136 out: Variable<D>,
137 len: Variable<D>,
138 },
139 ReinterpretSlice {
140 input: Variable<D>,
141 line_size: u32,
142 out: Variable<D>,
143 },
144 Return,
145 Break,
146 Equal(BinaryInstruction<D>),
147 NotEqual(BinaryInstruction<D>),
148 Lower(BinaryInstruction<D>),
149 Greater(BinaryInstruction<D>),
150 LowerEqual(BinaryInstruction<D>),
151 GreaterEqual(BinaryInstruction<D>),
152 Erf(UnaryInstruction<D>),
153 BitwiseOr(BinaryInstruction<D>),
154 BitwiseAnd(BinaryInstruction<D>),
155 BitwiseXor(BinaryInstruction<D>),
156 CountBits(UnaryInstruction<D>),
157 ReverseBits(UnaryInstruction<D>),
158 ShiftLeft(BinaryInstruction<D>),
159 ShiftRight(BinaryInstruction<D>),
160 BitwiseNot(UnaryInstruction<D>),
161 LeadingZeros(UnaryInstruction<D>),
162 FindFirstSet(UnaryInstruction<D>),
163 Abs(UnaryInstruction<D>),
164 Exp(UnaryInstruction<D>),
165 FastExp(UnaryInstruction<D>),
166 Log(UnaryInstruction<D>),
167 FastLog(UnaryInstruction<D>),
168 Log1p(UnaryInstruction<D>),
169 Cos(UnaryInstruction<D>),
170 Sin(UnaryInstruction<D>),
171 Tan(UnaryInstruction<D>),
172 Tanh(UnaryInstruction<D>),
173 Sinh(UnaryInstruction<D>),
174 Cosh(UnaryInstruction<D>),
175 ArcCos(UnaryInstruction<D>),
176 ArcSin(UnaryInstruction<D>),
177 ArcTan(UnaryInstruction<D>),
178 ArcSinh(UnaryInstruction<D>),
179 ArcCosh(UnaryInstruction<D>),
180 ArcTanh(UnaryInstruction<D>),
181 Degrees(UnaryInstruction<D>),
182 Radians(UnaryInstruction<D>),
183 ArcTan2(BinaryInstruction<D>),
184 FastSin(UnaryInstruction<D>),
185 FastCos(UnaryInstruction<D>),
186 FastTanh(UnaryInstruction<D>),
187 Powf(BinaryInstruction<D>),
188 FastPowf(BinaryInstruction<D>),
189 Powi(BinaryInstruction<D>),
190 Sqrt(UnaryInstruction<D>),
191 FastSqrt(UnaryInstruction<D>),
192 InverseSqrt(UnaryInstruction<D>),
193 FastInverseSqrt(UnaryInstruction<D>),
194 Min(BinaryInstruction<D>),
195 Max(BinaryInstruction<D>),
196 Not(UnaryInstruction<D>),
197 Or(BinaryInstruction<D>),
198 And(BinaryInstruction<D>),
199 Clamp {
200 input: Variable<D>,
201 min_value: Variable<D>,
202 max_value: Variable<D>,
203 out: Variable<D>,
204 },
205 IsNan(UnaryInstruction<D>),
206 IsInf(UnaryInstruction<D>),
207 SyncThreads,
208 SyncWarp,
209 ThreadFence,
210 ProxyAsyncToSharedFence,
211 BulkCommitGroup,
212 BulkWaitGroup {
213 max_pending: u32,
214 },
215 BulkWaitGroupRead {
216 max_pending: u32,
217 },
218 TmaReplacePointer {
219 buffer: Variable<D>,
220 offset: Variable<D>,
221 tensor_map: Variable<D>,
222 out: Variable<D>,
223 },
224 Round(UnaryInstruction<D>),
225 Ceil(UnaryInstruction<D>),
226 Trunc(UnaryInstruction<D>),
227 Floor(UnaryInstruction<D>),
228 Warp(WarpInstruction<D>),
229 Wmma(WmmaInstruction<D>),
230 Bitcast(UnaryInstruction<D>),
231 AtomicLoad(UnaryInstruction<D>),
232 AtomicStore(UnaryInstruction<D>),
233 AtomicSwap(BinaryInstruction<D>),
234 AtomicAdd(BinaryInstruction<D>),
235 AtomicSub(BinaryInstruction<D>),
236 AtomicMax(BinaryInstruction<D>),
237 AtomicMin(BinaryInstruction<D>),
238 AtomicAnd(BinaryInstruction<D>),
239 AtomicOr(BinaryInstruction<D>),
240 AtomicXor(BinaryInstruction<D>),
241 AtomicCAS {
242 input: Variable<D>,
243 cmp: Variable<D>,
244 val: Variable<D>,
245 out: Variable<D>,
246 },
247 Neg(UnaryInstruction<D>),
248 Magnitude(UnaryInstruction<D>),
249 FastMagnitude(UnaryInstruction<D>),
250 Normalize(UnaryInstruction<D>),
251 FastNormalize(UnaryInstruction<D>),
252 Dot(BinaryInstruction<D>),
253 Copy {
254 input: Variable<D>,
255 in_index: Variable<D>,
256 out: Variable<D>,
257 out_index: Variable<D>,
258 },
259 CopyBulk {
260 input: Variable<D>,
261 in_index: Variable<D>,
262 out: Variable<D>,
263 out_index: Variable<D>,
264 len: u32,
265 },
266 Printf {
267 format_string: String,
268 args: Vec<Variable<D>>,
269 },
270 Comment {
271 content: String,
272 },
273 Pipeline(PipelineOps<D>),
274 Barrier(BarrierOps<D>),
275 MemCopyAsyncTensorSharedToGlobal {
276 smem_buffer: Variable<D>,
277 smem_offset: Variable<D>,
278 tensor_map: Variable<D>,
279 indices: Vec<Variable<D>>,
280 },
281 Line {
282 file: Cow<'static, str>,
283 line: u32,
284 },
285}
286
287impl<D: Dialect> Display for Instruction<D> {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 match self {
290 Instruction::Return => f.write_str("return;"),
291 Instruction::Break => f.write_str("break;"),
292 Instruction::DeclareVariable { var } => match var {
293 Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
294 _ => {
295 let item = var.item();
296 writeln!(f, "{item} {var};")
297 }
298 },
299 Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
300 Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
301 Instruction::Slice {
302 input,
303 start,
304 end,
305 out,
306 } => {
307 let item = out.item();
308 let addr_space = D::address_space_for_variable(input);
309 writeln!(f, "const uint {out}_length = {end} - {start};")?;
310 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
311 }
312 Instruction::CheckedSlice {
313 input,
314 start,
315 end,
316 out,
317 len,
318 } => {
319 let item = out.item();
320 let addr_space = D::address_space_for_variable(input);
321 writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
322 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
323 }
324 Instruction::ReinterpretSlice {
325 input,
326 line_size,
327 out,
328 } => {
329 let mut item = out.item();
330 item.vectorization = *line_size as usize;
331 let addr_space = D::address_space_for_variable(input);
332
333 writeln!(
334 f,
335 "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
336 )
337 }
338 Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
339 Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
340 Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
341 Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
342 Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
343 Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
344 Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
345 Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
346 Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
347 Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
348 Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
349 Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
350 Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
351 Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
352 Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
353 Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
354 Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
355 Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
356 Instruction::IndexAssign(it) => {
357 IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
358 }
359 Instruction::Copy {
360 input,
361 in_index,
362 out,
363 out_index,
364 } => {
365 writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
366 }
367 Instruction::CopyBulk {
368 input,
369 in_index,
370 out,
371 out_index,
372 len,
373 } => {
374 for i in 0..*len {
375 writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
376 }
377 Ok(())
378 }
379 Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
380 Instruction::RangeLoop {
381 i,
382 start,
383 end,
384 step,
385 inclusive,
386 instructions,
387 } => {
388 let increment = step
389 .map(|step| format!("{i} += {step}"))
390 .unwrap_or_else(|| format!("++{i}"));
391 let cmp = if *inclusive { "<=" } else { "<" };
392 let i_ty = i.item();
393
394 write!(
395 f,
396 "
397for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
398"
399 )?;
400 for instruction in instructions {
401 write!(f, "{instruction}")?;
402 }
403
404 f.write_str("}\n")
405 }
406 Instruction::Loop { instructions } => {
407 writeln!(f, "while (true) {{")?;
408 for i in instructions {
409 write!(f, "{i}")?;
410 }
411 f.write_str("}\n")
412 }
413 Instruction::If { cond, instructions } => {
414 writeln!(f, "if ({cond}) {{")?;
415 for i in instructions {
416 write!(f, "{i}")?;
417 }
418 f.write_str("}\n")
419 }
420 Instruction::IfElse {
421 cond,
422 instructions_if,
423 instructions_else,
424 } => {
425 writeln!(f, "if ({cond}) {{")?;
426 for i in instructions_if {
427 write!(f, "{i}")?;
428 }
429 f.write_str("} else {\n")?;
430 for i in instructions_else {
431 write!(f, "{i}")?;
432 }
433 f.write_str("}\n")
434 }
435 Instruction::Select {
436 cond,
437 then,
438 or_else,
439 out,
440 } => {
441 let item_or_else = or_else.item();
442 let item_then = then.item();
443 let item_out = out.item();
444
445 let vf_then = item_then.vectorization;
446 let vf_or_else = item_or_else.vectorization;
447 let vf_out = item_out.vectorization;
448 let vf_cond = cond.item().vectorization;
449
450 let item_out = out.item();
451 let cond_elem = cond.item().elem;
452 let out = out.fmt_left();
453
454 let should_broadcast =
455 vf_cond > 1 || item_out != item_or_else || item_out != item_then;
456
457 if should_broadcast {
458 let vf = usize::max(vf_cond, vf_out);
459 let vf = usize::max(vf, vf_then);
460 let vf = usize::max(vf, vf_or_else);
461
462 writeln!(f, "{out} = {item_out} {{")?;
463 for i in 0..vf {
464 let theni = then.index(i);
465 let or_elsei = or_else.index(i);
466 let condi = cond.index(i);
467 let condi = EnsureBoolArg {
468 var: &condi,
469 elem: &cond_elem,
470 };
471
472 writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
473 }
474
475 writeln!(f, "}};")
476 } else {
477 let cond = EnsureBoolArg {
478 var: &cond,
479 elem: &cond_elem,
480 };
481 writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
482 }
483 }
484 Instruction::Switch {
485 value,
486 instructions_default,
487 instructions_cases,
488 } => {
489 writeln!(f, "switch({value}) {{")?;
490 for (value, block) in instructions_cases {
491 write!(f, "case {value}:\n{{\n")?;
492 for i in block {
493 i.fmt(f)?;
494 }
495 f.write_str("break;\n}\n")?;
496 }
497 f.write_str("default:\n{")?;
498 for i in instructions_default {
499 i.fmt(f)?;
500 }
501 f.write_str("}\n}\n")
502 }
503 Instruction::Metadata {
504 info_offset,
505 split_meta,
506 out,
507 } => {
508 let out = out.fmt_left();
509 match *split_meta {
510 true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
511 false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
512 }
513 }
514 Instruction::ExtendedMetadata {
515 info_offset,
516 dim,
517 split_meta,
518 static_offset,
519 out,
520 } => {
521 let out = out.fmt_left();
522 match *split_meta {
523 true => writeln!(
524 f,
525 "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
526 ),
527 false => writeln!(
528 f,
529 "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
530 ),
531 }
532 }
533 Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
534 Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
535 Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
536 Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
537 Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
538 Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
539 Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
540 Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
541 Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
542 Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
543 Instruction::Log(it) => Log::format(f, &it.input, &it.out),
544 Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
545 Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
546 Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
547 Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
548 Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
549 Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
550 Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
551 Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
552 Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
553 Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
554 Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
555 Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
556 Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
557 Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
558 Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
559 Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
560 Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
561 Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
562 Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
563 Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
564 Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
565 Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
566 Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
567 Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
568 Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
569 Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
570 Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
571 Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
572 Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
573 Instruction::Not(it) => Not::format(f, &it.input, &it.out),
574 Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
575 Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
576 Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
577 Instruction::Clamp {
578 input,
579 min_value,
580 max_value,
581 out,
582 } => Clamp::format(f, input, min_value, max_value, out),
583 Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
584 Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
585 Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
586 Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
587 Instruction::ThreadFence => f.write_str("__threadfence();\n"),
588 Instruction::Round(it) => Round::format(f, &it.input, &it.out),
589 Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
590 Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
591 Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
592 Instruction::SliceLength { input, out } => {
593 let out = out.fmt_left();
594 writeln!(f, "{out} = {input}_length;")
595 }
596 Instruction::ConstLength { length, out } => {
597 let out = out.fmt_left();
598 writeln!(f, "{out} = {length};")
599 }
600 Instruction::Warp(it) => write!(f, "{it}"),
601 Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
602 Instruction::Wmma(it) => write!(f, "{it}"),
603 Instruction::Bitcast(UnaryInstruction { input, out }) => {
604 let qualifier = out.const_qualifier();
605 let input_item = input.item();
606 let out_item = out.item();
607
608 if out_item.elem.size() * out_item.vectorization
609 != input.item().elem.size() * input.item().vectorization
610 {
611 panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
612 } else {
613 let out = out.fmt_left();
614 let addr_space = D::address_space_for_variable(input);
615 writeln!(
616 f,
617 "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
618 )
619 }
620 }
621 Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
622 D::compile_atomic_add(f, lhs, rhs, out)
623 }
624 Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
625 D::compile_atomic_and(f, lhs, rhs, out)
626 }
627 Instruction::AtomicCAS {
628 input,
629 cmp,
630 val,
631 out,
632 } => D::compile_atomic_cas(f, input, cmp, val, out),
633 Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
634 D::compile_atomic_load(f, input, out)
635 }
636 Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
637 D::compile_atomic_max(f, lhs, rhs, out)
638 }
639 Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
640 D::compile_atomic_min(f, lhs, rhs, out)
641 }
642 Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
643 D::compile_atomic_or(f, lhs, rhs, out)
644 }
645 Instruction::AtomicStore(UnaryInstruction { input, out }) => {
646 D::compile_atomic_store(f, input, out)
647 }
648 Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
649 D::compile_atomic_sub(f, lhs, rhs, out)
650 }
651 Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
652 D::compile_atomic_swap(f, lhs, rhs, out)
653 }
654 Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
655 D::compile_atomic_xor(f, lhs, rhs, out)
656 }
657 Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
658 Instruction::Neg(UnaryInstruction { input, out }) => {
659 let out = out.fmt_left();
660 writeln!(f, "{out} = -{input};")
661 }
662 Instruction::Normalize(inst) => {
663 Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
664 }
665 Instruction::FastNormalize(inst) => {
666 Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
667 }
668 Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
669 Instruction::FastMagnitude(inst) => {
670 Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
671 }
672 Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
673 Instruction::VecInit { inputs, out } => {
674 let item = out.item();
675 let inputs = inputs
676 .iter()
677 .map(|input| format!("{input}"))
678 .collect::<Vec<_>>();
679 let out = out.fmt_left();
680 writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
681 }
682 Instruction::Printf {
683 format_string,
684 args,
685 } => D::compile_instruction_printf(f, format_string, args),
686 Instruction::Comment { content } => {
687 if content.contains('\n') {
688 writeln!(f, "/* {content} */")
689 } else {
690 writeln!(f, "// {content}")
691 }
692 }
693 Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
694 Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
695 Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
696 Instruction::ProxyAsyncToSharedFence => {
697 writeln!(
698 f,
699 "cuda::device::experimental::fence_proxy_async_shared_cta();"
700 )
701 }
702 Instruction::BulkCommitGroup => writeln!(
703 f,
704 "cuda::device::experimental::cp_async_bulk_commit_group();"
705 ),
706 Instruction::BulkWaitGroup { max_pending } => writeln!(
707 f,
708 "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
709 ),
710 Instruction::BulkWaitGroupRead { max_pending } => writeln!(
711 f,
712 "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
713 ),
714 Instruction::TmaReplacePointer {
715 buffer,
716 offset,
717 tensor_map,
718 out,
719 } => {
720 let pos = Variable::<D>::UnitPos;
721 writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
722 writeln!(
723 f,
724 "
725if({pos} == 0) {{
726 {out} = {tensor_map};
727 tensormap_replace_global_address({out}, &{buffer}[{offset}]);
728}}"
729 )?;
730 writeln!(f, "__syncthreads();")
731 }
732 Instruction::MemCopyAsyncTensorSharedToGlobal {
733 smem_buffer,
734 smem_offset,
735 tensor_map,
736 indices,
737 } => {
738 let rank = indices.len();
739 let smem_ptr = smem_buffer.fmt_ptr();
740 let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
741 let _ = write!(s, "{it}, ");
742 s
743 });
744 writeln!(
745 f,
746 "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
747 )
748 }
749 Instruction::SpecialCast(UnaryInstruction { input, out }) => {
750 #[cfg(not(feature = "cuda"))]
752 {
753 let _ = (input, out);
754 unimplemented!("FP8/FP6/FP4 casting isn't supported outside of CUDA");
755 }
756 #[cfg(feature = "cuda")]
757 crate::cuda::convert::special_cast::<D>(f, input, out)
758 }
759 }
760 }
761}
762
763struct Fma<D: Dialect> {
764 _dialect: PhantomData<D>,
765}
766
767impl<D: Dialect> Fma<D> {
768 fn format(
769 f: &mut core::fmt::Formatter<'_>,
770 a: &Variable<D>,
771 b: &Variable<D>,
772 c: &Variable<D>,
773 out: &Variable<D>,
774 ) -> core::fmt::Result {
775 let out_item = out.item();
776 let num = out_item.vectorization;
777
778 let out = out.fmt_left();
779 if num == 1 {
780 writeln!(f, "{out} = fma({a}, {b}, {c});")
781 } else {
782 writeln!(f, "{out} = {out_item}{{")?;
783
784 for i in 0..num {
785 let ai = a.index(i);
786 let bi = b.index(i);
787 let ci = c.index(i);
788
789 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
790 }
791 f.write_str("};\n")
792 }
793 }
794}
795
796struct Clamp<D: Dialect> {
797 _dialect: PhantomData<D>,
798}
799
800impl<D: Dialect> Clamp<D> {
801 fn format(
802 f: &mut core::fmt::Formatter<'_>,
803 input: &Variable<D>,
804 min_value: &Variable<D>,
805 max_value: &Variable<D>,
806 out: &Variable<D>,
807 ) -> core::fmt::Result {
808 let out_item = out.item();
809 if out.item().vectorization == 1 {
810 let out = out.fmt_left();
811 write!(f, "{out} = ")?;
812 Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
813 f.write_str(";\n")
814 } else {
815 Self::unroll_vec(f, input, min_value, max_value, out)
816 }
817 }
818
819 fn format_scalar(
820 f: &mut Formatter<'_>,
821 input: impl Component<D>,
822 min_value: impl Component<D>,
823 max_value: impl Component<D>,
824 item: Item<D>,
825 ) -> std::fmt::Result {
826 D::compile_instruction_max_function_name(f, item)?;
827 write!(f, "({min_value}, ")?;
828 D::compile_instruction_min_function_name(f, item)?;
829 write!(f, "({max_value}, {input}))")
830 }
831
832 fn unroll_vec(
833 f: &mut core::fmt::Formatter<'_>,
834 input: &Variable<D>,
835 min_value: &Variable<D>,
836 max_value: &Variable<D>,
837 out: &Variable<D>,
838 ) -> std::fmt::Result {
839 let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
840 let [input, min_value, max_value, out_optimized] = optimized.args;
841
842 let item_out_original = out.item();
843 let item_out_optimized = out_optimized.item();
844
845 let index = match optimized.optimization_factor {
846 Some(factor) => item_out_original.vectorization / factor,
847 None => item_out_optimized.vectorization,
848 };
849
850 let mut write_op = |input: &Variable<D>,
851 min_value: &Variable<D>,
852 max_value: &Variable<D>,
853 out: &Variable<D>,
854 item_out: Item<D>| {
855 let out = out.fmt_left();
856 writeln!(f, "{out} = {item_out}{{")?;
857 for i in 0..index {
858 let inputi = input.index(i);
859 let min_valuei = min_value.index(i);
860 let max_valuei = max_value.index(i);
861
862 Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
863 f.write_str(", ")?;
864 }
865
866 f.write_str("};\n")
867 };
868
869 if item_out_original == item_out_optimized {
870 write_op(&input, &min_value, &max_value, out, item_out_optimized)
871 } else {
872 let out_tmp = Variable::tmp(item_out_optimized);
873 write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
874 let addr_space = D::address_space_for_variable(out);
875 let out = out.fmt_left();
876
877 writeln!(
878 f,
879 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
880 )?;
881
882 Ok(())
883 }
884 }
885}
886
887struct Remainder<D: Dialect> {
888 _dialect: PhantomData<D>,
889}
890
891impl<D: Dialect> Remainder<D> {
892 fn format(
893 f: &mut core::fmt::Formatter<'_>,
894 lhs: &Variable<D>,
895 rhs: &Variable<D>,
896 out: &Variable<D>,
897 ) -> core::fmt::Result {
898 let floor = |elem| {
899 let prefix = match elem {
900 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
901 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
902 _ => "",
903 };
904 format!("{prefix}floor")
905 };
906
907 let is_int = matches!(
908 out.elem(),
909 Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
910 );
911 let rem_expr = |lhs, rhs, floor| {
912 if is_int {
913 format!("{lhs} - {rhs} * {floor}((float){lhs} / (float){rhs})")
914 } else {
915 format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
916 }
917 };
918
919 if out.item().vectorization == 1 {
920 let floor = floor(out.elem());
921
922 let out = out.fmt_left();
923 let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
924 return writeln!(f, "{out} = {rem};");
925 }
926
927 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
928 let [lhs, rhs, out_optimized] = optimized.args;
929
930 let item_out_original = out.item();
931 let item_out_optimized = out_optimized.item();
932
933 let index = match optimized.optimization_factor {
934 Some(factor) => item_out_original.vectorization / factor,
935 None => item_out_optimized.vectorization,
936 };
937
938 let floor = floor(*item_out_optimized.elem());
939
940 let mut write_op =
941 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
942 let out = out.fmt_left();
943 writeln!(f, "{out} = {item_out}{{")?;
944 for i in 0..index {
945 let lhsi = lhs.index(i);
946 let rhsi = rhs.index(i);
947
948 let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
949 writeln!(f, "{rem}")?;
950 f.write_str(", ")?;
951 }
952
953 f.write_str("};\n")
954 };
955
956 if item_out_original == item_out_optimized {
957 write_op(&lhs, &rhs, out, item_out_optimized)
958 } else {
959 let out_tmp = Variable::tmp(item_out_optimized);
960
961 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
962
963 let addr_space = D::address_space_for_variable(&out_tmp);
964 let qualifier = out.const_qualifier();
965 let out = out.fmt_left();
966
967 writeln!(
968 f,
969 "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
970 )?;
971
972 Ok(())
973 }
974 }
975}
976
977struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
978 _dialect: PhantomData<D>,
979 _sqrt: PhantomData<S>,
980}
981
982impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
983 fn format(
984 f: &mut core::fmt::Formatter<'_>,
985 input: &Variable<D>,
986 out: &Variable<D>,
987 ) -> core::fmt::Result {
988 let num = input.item().vectorization;
989 let elem = input.elem();
990
991 let mag = format!("{out}_mag");
992
993 writeln!(f, "{} {mag} = 0.0;", out.item())?;
994
995 for i in 0..num {
996 let input_i = input.index(i);
997 writeln!(f, "{mag} += {input_i} * {input_i};")?;
998 }
999
1000 let out = out.fmt_left();
1001 write!(f, "{out} = ")?;
1002 S::format_unary(f, &mag, elem)?;
1003 f.write_str(";\n")
1004 }
1005}
1006
1007struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
1008 _dialect: PhantomData<D>,
1009 _rsqrt: PhantomData<InvS>,
1010}
1011
1012impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
1013 fn format(
1014 f: &mut core::fmt::Formatter<'_>,
1015 input: &Variable<D>,
1016 out: &Variable<D>,
1017 ) -> core::fmt::Result {
1018 let num = input.item().vectorization;
1019 let elem = input.elem();
1020 let norm = format!("{out}_norm");
1021
1022 let out_item = out.item();
1023 let out = out.fmt_left();
1024 writeln!(f, "{elem} {norm} = 0.0;")?;
1025
1026 for i in 0..num {
1027 let input_i = input.index(i);
1028 writeln!(f, "{norm} += {input_i} * {input_i};")?;
1029 }
1030
1031 write!(f, "{norm} = ")?;
1032 InvS::format_unary(f, &norm, elem)?;
1033 f.write_str(";\n")?;
1034
1035 if num == 1 {
1036 writeln!(f, "{out} = {input} * {norm};")
1037 } else {
1038 write!(f, "{out} = {out_item}{{")?;
1039 for i in 0..num {
1040 let input_i = input.index(i);
1041
1042 writeln!(f, "{input_i} * {norm},")?;
1043 }
1044
1045 f.write_str("};\n")
1046 }
1047 }
1048}
1049
1050struct Dot<D: Dialect> {
1051 _dialect: PhantomData<D>,
1052}
1053
1054impl<D: Dialect> Dot<D> {
1055 fn format(
1056 f: &mut core::fmt::Formatter<'_>,
1057 lhs: &Variable<D>,
1058 rhs: &Variable<D>,
1059 out: &Variable<D>,
1060 ) -> core::fmt::Result {
1061 let num = lhs.item().vectorization;
1062
1063 let muls = (0..num)
1064 .map(|i| {
1065 let lhs_i = lhs.index(i);
1066 let rhs_i = rhs.index(i);
1067 format!("{lhs_i} * {rhs_i}")
1068 })
1069 .collect::<Vec<_>>();
1070
1071 let out = out.fmt_left();
1072 writeln!(f, "{out} = {};", muls.join(" + "))
1073 }
1074}
1075
1076struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1077 var: &'a V,
1078 elem: &'a Elem<D>,
1079}
1080
1081impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1082 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1083 if self.elem != &Elem::Bool {
1084 write!(f, "bool({})", self.var)
1085 } else {
1086 write!(f, "{}", self.var)
1087 }
1088 }
1089}