1use crate::shared::FmtLeft;
2
3use super::{
4 Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5 barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8 borrow::Cow,
9 fmt::{Display, Formatter, Write},
10 marker::PhantomData,
11};
12
/// Name of the runtime metadata buffer variable referenced by the generated
/// `Metadata`/`ExtendedMetadata` reads (e.g. `out = info[offset];`).
pub(crate) const INFO_NAME: &str = "info";
/// Name of the static (compile-time split) metadata variable; its `.x` field is
/// indexed when `split_meta` is set on a metadata instruction.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
/// Operands of a two-input instruction: `out = op(lhs, rhs)`.
#[derive(Debug, Clone, Copy)]
pub struct BinaryInstruction<D: Dialect> {
    /// Left-hand operand.
    pub lhs: Variable<D>,
    /// Right-hand operand.
    pub rhs: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
22
/// Operands of an indexed read: `out = list[index]`.
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    /// Buffer/list being read.
    pub list: Variable<D>,
    /// Position to read from.
    pub index: Variable<D>,
    /// Line (vectorization) width forwarded to the `Index` formatter.
    pub line_size: u32,
    /// Destination variable.
    pub out: Variable<D>,
}
30
/// Operands of an indexed write: `out[index] = value`.
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    /// Position to write to.
    pub index: Variable<D>,
    /// Value being stored.
    pub value: Variable<D>,
    /// Line (vectorization) width forwarded to the `IndexAssign` formatter.
    pub line_size: u32,
    /// Buffer/list being written.
    pub out: Variable<D>,
}
38
/// Operands of a single-input instruction: `out = op(input)`.
#[derive(Debug, Clone, Copy)]
pub struct UnaryInstruction<D: Dialect> {
    /// Source operand.
    pub input: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
44
/// One statement of generated kernel source code.
///
/// Each variant is rendered to target source text by the `Display`
/// implementation below; dialect-specific constructs (atomics, syncs,
/// WMMA fragments, printf) are delegated to the [`Dialect`] trait.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// Metadata read: `out = info[info_offset]`, or from the static
    /// buffer (`static_info.x[...]`) when `split_meta` is set.
    Metadata {
        info_offset: Variable<D>,
        split_meta: bool,
        out: Variable<D>,
    },
    /// Indirect metadata read: looks up a base offset in the metadata
    /// buffer, then indexes `info` at `base + dim` (additionally offset
    /// by `static_offset` when `split_meta` is set).
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        split_meta: bool,
        static_offset: u32,
        out: Variable<D>,
    },
    /// `out = length` where the length is known at compile time.
    ConstLength {
        length: usize,
        out: Variable<D>,
    },
    /// `out = {input}_length` — reads the length variable emitted next to a slice.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declares `var` without initializing it (WMMA fragments are
    /// declared by the dialect).
    DeclareVariable {
        var: Variable<D>,
    },
    // Arithmetic.
    Modulo(BinaryInstruction<D>),
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    /// Fused multiply-add: `out = fma(a, b, c)`.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    FastDiv(BinaryInstruction<D>),
    FastRecip(UnaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    /// High half of a widening multiplication.
    HiMul(BinaryInstruction<D>),
    // Memory access.
    Index(IndexInstruction<D>),
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    /// FP8/FP6/FP4 conversion; only supported with the `cuda` feature.
    SpecialCast(UnaryInstruction<D>),
    // Control flow.
    /// `for` loop over `i` from `start` to `end` (inclusive upper bound
    /// when `inclusive`), stepping by `step` or `++i` by default.
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Brace-initializes a vectorized value from its components.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// Infinite `while (true)` loop (exited via `Break`/`Return`).
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// Ternary select `out = cond ? then : or_else`, broadcast lane by
    /// lane when any operand is vectorized.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Pointer slice `out = input + start`, also emitting an
    /// `{out}_length` variable.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Like `Slice`, but the emitted length is clamped with `min(len, end)`.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// `reinterpret_cast` of a buffer to the same element type with a
    /// different line (vectorization) size.
    ReinterpretSlice {
        input: Variable<D>,
        line_size: u32,
        out: Variable<D>,
    },
    Return,
    Break,
    // Comparisons.
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    // Bit manipulation.
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    TrailingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    // Math functions (the `Fast*` variants map to the dialect's
    // fast/approximate intrinsics).
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    FastExp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    FastLog(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tan(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Sinh(UnaryInstruction<D>),
    Cosh(UnaryInstruction<D>),
    ArcCos(UnaryInstruction<D>),
    ArcSin(UnaryInstruction<D>),
    ArcTan(UnaryInstruction<D>),
    ArcSinh(UnaryInstruction<D>),
    ArcCosh(UnaryInstruction<D>),
    ArcTanh(UnaryInstruction<D>),
    Degrees(UnaryInstruction<D>),
    Radians(UnaryInstruction<D>),
    ArcTan2(BinaryInstruction<D>),
    FastSin(UnaryInstruction<D>),
    FastCos(UnaryInstruction<D>),
    FastTanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    FastPowf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Hypot(BinaryInstruction<D>),
    Rhypot(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    FastSqrt(UnaryInstruction<D>),
    InverseSqrt(UnaryInstruction<D>),
    FastInverseSqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// `out = max(min_value, min(max_value, input))`, unrolled per lane
    /// for vectorized outputs.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    // Synchronization and async-copy fences.
    SyncThreads,
    SyncWarp,
    ThreadFence,
    ProxyAsyncToSharedFence,
    BulkCommitGroup,
    /// Waits until at most `max_pending` bulk-copy groups remain in flight.
    BulkWaitGroup {
        max_pending: u32,
    },
    BulkWaitGroupRead {
        max_pending: u32,
    },
    /// Copies `tensor_map` into a shared `CUtensorMap` and patches its
    /// global address to `&buffer[offset]` (done by unit 0, then synced).
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    /// Warp-level (shuffle/reduce) operation, formatted by `WarpInstruction`.
    Warp(WarpInstruction<D>),
    /// Matrix-multiply-accumulate operation, formatted by `WmmaInstruction`.
    Wmma(WmmaInstruction<D>),
    /// Same-size reinterpretation of the input's bits as the output type.
    Bitcast(UnaryInstruction<D>),
    // Atomics (all delegated to the dialect).
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    /// Atomic compare-and-swap of `cmp` for `val` at `input`.
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    Neg(UnaryInstruction<D>),
    // Vector math helpers.
    Magnitude(UnaryInstruction<D>),
    FastMagnitude(UnaryInstruction<D>),
    Normalize(UnaryInstruction<D>),
    FastNormalize(UnaryInstruction<D>),
    Dot(BinaryInstruction<D>),
    /// Element copy `out[out_index] = input[in_index]`.
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// Unrolled copy of `len` consecutive elements.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    /// Formatted device-side print, compiled by the dialect.
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// Free-form comment emitted into the generated source.
    Comment {
        content: String,
    },
    Pipeline(PipelineOps<D>),
    Barrier(BarrierOps<D>),
    /// TMA bulk store from shared memory to a global tensor.
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// `#line` directive mapping generated code back to its source location.
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}
289
impl<D: Dialect> Display for Instruction<D> {
    /// Renders the instruction as target (C++-like GPU) source text.
    ///
    /// Simple unary/binary ops delegate to the formatter types from the
    /// `unary`/`binary` modules; dialect-specific constructs delegate to
    /// `D`'s `compile_*` hooks. Every byte written here is generated
    /// program text, so formatting is part of the contract.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Instruction::Return => f.write_str("return;"),
            Instruction::Break => f.write_str("break;"),
            Instruction::DeclareVariable { var } => match var {
                // WMMA fragments need a dialect-specific declaration.
                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
                _ => {
                    let item = var.item();
                    writeln!(f, "{item} {var};")
                }
            },
            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Slice {
                input,
                start,
                end,
                out,
            } => {
                // A slice is a pointer offset plus a companion
                // `{out}_length` variable read back by `SliceLength`.
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = {end} - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            Instruction::CheckedSlice {
                input,
                start,
                end,
                out,
                len,
            } => {
                // Same as `Slice`, but the end is clamped to `len`.
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            Instruction::ReinterpretSlice {
                input,
                line_size,
                out,
            } => {
                // Keep the element type but view the buffer at a
                // different vectorization width.
                let mut item = out.item();
                item.vectorization = *line_size as usize;
                let addr_space = D::address_space_for_variable(input);

                writeln!(
                    f,
                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
                )
            }
            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
            Instruction::TrailingZeros(it) => TrailingZeros::format(f, &it.input, &it.out),
            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
            Instruction::IndexAssign(it) => {
                IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
            }
            Instruction::Copy {
                input,
                in_index,
                out,
                out_index,
            } => {
                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
            }
            Instruction::CopyBulk {
                input,
                in_index,
                out,
                out_index,
                len,
            } => {
                // Unrolled at code-generation time: one assignment per element.
                for i in 0..*len {
                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
                }
                Ok(())
            }
            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
            Instruction::RangeLoop {
                i,
                start,
                end,
                step,
                inclusive,
                instructions,
            } => {
                // No explicit step means pre-increment by one.
                let increment = step
                    .map(|step| format!("{i} += {step}"))
                    .unwrap_or_else(|| format!("++{i}"));
                let cmp = if *inclusive { "<=" } else { "<" };
                let i_ty = i.item();

                write!(
                    f,
                    "
for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
"
                )?;
                for instruction in instructions {
                    write!(f, "{instruction}")?;
                }

                f.write_str("}\n")
            }
            Instruction::Loop { instructions } => {
                writeln!(f, "while (true) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::If { cond, instructions } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::IfElse {
                cond,
                instructions_if,
                instructions_else,
            } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions_if {
                    write!(f, "{i}")?;
                }
                f.write_str("} else {\n")?;
                for i in instructions_else {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::Select {
                cond,
                then,
                or_else,
                out,
            } => {
                let item_or_else = or_else.item();
                let item_then = then.item();
                let item_out = out.item();

                let vf_then = item_then.vectorization;
                let vf_or_else = item_or_else.vectorization;
                let vf_out = item_out.vectorization;
                let vf_cond = cond.item().vectorization;

                let item_out = out.item();
                let cond_elem = cond.item().elem;
                let out = out.fmt_left();

                // A plain ternary only works when nothing is vectorized
                // and all items agree; otherwise select lane by lane.
                let should_broadcast =
                    vf_cond > 1 || item_out != item_or_else || item_out != item_then;

                if should_broadcast {
                    // Unroll to the widest vectorization among operands.
                    let vf = usize::max(vf_cond, vf_out);
                    let vf = usize::max(vf, vf_then);
                    let vf = usize::max(vf, vf_or_else);

                    writeln!(f, "{out} = {item_out} {{")?;
                    for i in 0..vf {
                        let theni = then.index(i);
                        let or_elsei = or_else.index(i);
                        let condi = cond.index(i);
                        // Non-bool conditions get an explicit bool() coercion.
                        let condi = EnsureBoolArg {
                            var: &condi,
                            elem: &cond_elem,
                        };

                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
                    }

                    writeln!(f, "}};")
                } else {
                    let cond = EnsureBoolArg {
                        var: &cond,
                        elem: &cond_elem,
                    };
                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
                }
            }
            Instruction::Switch {
                value,
                instructions_default,
                instructions_cases,
            } => {
                // Each case gets its own scope block and a trailing break.
                writeln!(f, "switch({value}) {{")?;
                for (value, block) in instructions_cases {
                    write!(f, "case {value}:\n{{\n")?;
                    for i in block {
                        i.fmt(f)?;
                    }
                    f.write_str("break;\n}\n")?;
                }
                f.write_str("default:\n{")?;
                for i in instructions_default {
                    i.fmt(f)?;
                }
                f.write_str("}\n}\n")
            }
            Instruction::Metadata {
                info_offset,
                split_meta,
                out,
            } => {
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
                    false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
                }
            }
            Instruction::ExtendedMetadata {
                info_offset,
                dim,
                split_meta,
                static_offset,
                out,
            } => {
                // Two-level lookup: the metadata buffer holds the base
                // offset of the entry being read.
                let out = out.fmt_left();
                match *split_meta {
                    true => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
                    ),
                    false => writeln!(
                        f,
                        "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
                    ),
                }
            }
            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
            Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
            Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
            Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
            Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
            Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
            Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
            Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
            Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
            Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
            Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
            Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
            Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
            Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
            Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
            Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
            Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
            Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
            Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
            Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Clamp {
                input,
                min_value,
                max_value,
                out,
            } => Clamp::format(f, input, min_value, max_value, out),
            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
            Instruction::SliceLength { input, out } => {
                // Reads the `{input}_length` variable emitted by `Slice`/`CheckedSlice`.
                let out = out.fmt_left();
                writeln!(f, "{out} = {input}_length;")
            }
            Instruction::ConstLength { length, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {length};")
            }
            Instruction::Warp(it) => write!(f, "{it}"),
            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
            Instruction::Wmma(it) => write!(f, "{it}"),
            Instruction::Bitcast(UnaryInstruction { input, out }) => {
                let qualifier = out.const_qualifier();
                let input_item = input.item();
                let out_item = out.item();

                // Bitcasting is only valid between types of identical
                // total byte width (elem size * vectorization).
                if out_item.elem.size() * out_item.vectorization
                    != input.item().elem.size() * input.item().vectorization
                {
                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
                } else {
                    let out = out.fmt_left();
                    let addr_space = D::address_space_for_variable(input);
                    writeln!(
                        f,
                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
                    )
                }
            }
            // Atomic operations are entirely dialect-specific.
            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_add(f, lhs, rhs, out)
            }
            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_and(f, lhs, rhs, out)
            }
            Instruction::AtomicCAS {
                input,
                cmp,
                val,
                out,
            } => D::compile_atomic_cas(f, input, cmp, val, out),
            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
                D::compile_atomic_load(f, input, out)
            }
            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_max(f, lhs, rhs, out)
            }
            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_min(f, lhs, rhs, out)
            }
            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_or(f, lhs, rhs, out)
            }
            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
                D::compile_atomic_store(f, input, out)
            }
            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_sub(f, lhs, rhs, out)
            }
            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_swap(f, lhs, rhs, out)
            }
            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_xor(f, lhs, rhs, out)
            }
            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::Neg(UnaryInstruction { input, out }) => {
                let out = out.fmt_left();
                writeln!(f, "{out} = -{input};")
            }
            // Normalize/Magnitude pick the precise or fast sqrt variant
            // via the type parameter.
            Instruction::Normalize(inst) => {
                Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::FastNormalize(inst) => {
                Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
            Instruction::FastMagnitude(inst) => {
                Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::VecInit { inputs, out } => {
                let item = out.item();
                let inputs = inputs
                    .iter()
                    .map(|input| format!("{input}"))
                    .collect::<Vec<_>>();
                let out = out.fmt_left();
                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
            }
            Instruction::Printf {
                format_string,
                args,
            } => D::compile_instruction_printf(f, format_string, args),
            Instruction::Comment { content } => {
                // Multi-line content must use a block comment to stay valid.
                if content.contains('\n') {
                    writeln!(f, "/* {content} */")
                } else {
                    writeln!(f, "// {content}")
                }
            }
            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
            Instruction::ProxyAsyncToSharedFence => {
                writeln!(
                    f,
                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
                )
            }
            Instruction::BulkCommitGroup => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_commit_group();"
            ),
            Instruction::BulkWaitGroup { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
            ),
            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
            ),
            Instruction::TmaReplacePointer {
                buffer,
                offset,
                tensor_map,
                out,
            } => {
                // Only unit 0 copies and patches the map; everyone syncs after.
                let pos = Variable::<D>::UnitPos;
                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
                writeln!(
                    f,
                    "
if({pos} == 0) {{
    {out} = {tensor_map};
    tensormap_replace_global_address({out}, &{buffer}[{offset}]);
}}"
                )?;
                writeln!(f, "__syncthreads();")
            }
            Instruction::MemCopyAsyncTensorSharedToGlobal {
                smem_buffer,
                smem_offset,
                tensor_map,
                indices,
            } => {
                let rank = indices.len();
                let smem_ptr = smem_buffer.fmt_ptr();
                // Indices are emitted in reverse order — presumably the
                // intrinsic expects innermost-dimension first; confirm
                // against the cp_async_bulk_tensor API.
                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
                    let _ = write!(s, "{it}, ");
                    s
                });
                writeln!(
                    f,
                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
                )
            }
            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
                // FP8/FP6/FP4 conversions are only implemented for CUDA;
                // other targets get a #error in the generated source.
                #[cfg(not(feature = "cuda"))]
                {
                    let _ = (input, out);
                    writeln!(
                        f,
                        "#error FP8/FP6/FP4 casting isn't supported outside of CUDA"
                    )
                }
                #[cfg(feature = "cuda")]
                crate::cuda::convert::special_cast::<D>(f, input, out)
            }
        }
    }
}
771
772struct Fma<D: Dialect> {
773 _dialect: PhantomData<D>,
774}
775
776impl<D: Dialect> Fma<D> {
777 fn format(
778 f: &mut core::fmt::Formatter<'_>,
779 a: &Variable<D>,
780 b: &Variable<D>,
781 c: &Variable<D>,
782 out: &Variable<D>,
783 ) -> core::fmt::Result {
784 let out_item = out.item();
785 let num = out_item.vectorization;
786
787 let out = out.fmt_left();
788 if num == 1 {
789 writeln!(f, "{out} = fma({a}, {b}, {c});")
790 } else {
791 writeln!(f, "{out} = {out_item}{{")?;
792
793 for i in 0..num {
794 let ai = a.index(i);
795 let bi = b.index(i);
796 let ci = c.index(i);
797
798 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
799 }
800 f.write_str("};\n")
801 }
802 }
803}
804
/// Formatter for `Instruction::Clamp`: `out = max(min_value, min(max_value, input))`.
struct Clamp<D: Dialect> {
    _dialect: PhantomData<D>,
}

impl<D: Dialect> Clamp<D> {
    /// Writes the clamp as a single expression for scalars, or unrolls
    /// it per lane for vectorized outputs.
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        min_value: &Variable<D>,
        max_value: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let out_item = out.item();
        if out.item().vectorization == 1 {
            let out = out.fmt_left();
            write!(f, "{out} = ")?;
            Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
            f.write_str(";\n")
        } else {
            Self::unroll_vec(f, input, min_value, max_value, out)
        }
    }

    /// Writes `max(min_value, min(max_value, input))` using the
    /// dialect's min/max function names for `item`.
    fn format_scalar(
        f: &mut Formatter<'_>,
        input: impl Component<D>,
        min_value: impl Component<D>,
        max_value: impl Component<D>,
        item: Item<D>,
    ) -> std::fmt::Result {
        D::compile_instruction_max_function_name(f, item)?;
        write!(f, "({min_value}, ")?;
        D::compile_instruction_min_function_name(f, item)?;
        write!(f, "({max_value}, {input}))")
    }

    /// Per-lane unroll, possibly through an optimized (repacked) item
    /// type; the result is reinterpret_cast back when repacking occurred.
    fn unroll_vec(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        min_value: &Variable<D>,
        max_value: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        // `optimized_args` may repack the operands into a narrower
        // vectorized representation (e.g. half2-style pairs).
        let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
        let [input, min_value, max_value, out_optimized] = optimized.args;

        let item_out_original = out.item();
        let item_out_optimized = out_optimized.item();

        // Number of lanes to emit in the optimized representation.
        let index = match optimized.optimization_factor {
            Some(factor) => item_out_original.vectorization / factor,
            None => item_out_optimized.vectorization,
        };

        let mut write_op = |input: &Variable<D>,
                            min_value: &Variable<D>,
                            max_value: &Variable<D>,
                            out: &Variable<D>,
                            item_out: Item<D>| {
            let out = out.fmt_left();
            writeln!(f, "{out} = {item_out}{{")?;
            for i in 0..index {
                let inputi = input.index(i);
                let min_valuei = min_value.index(i);
                let max_valuei = max_value.index(i);

                Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
                f.write_str(", ")?;
            }

            f.write_str("};\n")
        };

        if item_out_original == item_out_optimized {
            write_op(&input, &min_value, &max_value, out, item_out_optimized)
        } else {
            // Write into a temporary of the optimized type, then
            // reinterpret it back as the original output type.
            let out_tmp = Variable::tmp(item_out_optimized);
            write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
            let addr_space = D::address_space_for_variable(out);
            let out = out.fmt_left();

            // NOTE(review): unlike `Remainder`, this cast omits
            // `out.const_qualifier()` — confirm whether a const-qualified
            // output can reach this path.
            writeln!(
                f,
                "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
            )?;

            Ok(())
        }
    }
}
895
896struct Remainder<D: Dialect> {
897 _dialect: PhantomData<D>,
898}
899
900impl<D: Dialect> Remainder<D> {
901 fn format(
902 f: &mut core::fmt::Formatter<'_>,
903 lhs: &Variable<D>,
904 rhs: &Variable<D>,
905 out: &Variable<D>,
906 ) -> core::fmt::Result {
907 let floor = |elem| {
908 let prefix = match elem {
909 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
910 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
911 _ => "",
912 };
913 format!("{prefix}floor")
914 };
915
916 let is_int = matches!(
917 out.elem(),
918 Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
919 );
920 let out_elem = out.elem();
921 let rem_expr = |lhs, rhs, floor: &str| {
922 if is_int {
923 format!("{lhs} - {rhs} * ({out_elem}){floor}((float){lhs} / (float){rhs})")
924 } else {
925 format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
926 }
927 };
928
929 if out.item().vectorization == 1 {
930 let floor = floor(out.elem());
931
932 let out = out.fmt_left();
933 let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
934 return writeln!(f, "{out} = {rem};");
935 }
936
937 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
938 let [lhs, rhs, out_optimized] = optimized.args;
939
940 let item_out_original = out.item();
941 let item_out_optimized = out_optimized.item();
942
943 let index = match optimized.optimization_factor {
944 Some(factor) => item_out_original.vectorization / factor,
945 None => item_out_optimized.vectorization,
946 };
947
948 let floor = floor(*item_out_optimized.elem());
949
950 let mut write_op =
951 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
952 let out = out.fmt_left();
953 writeln!(f, "{out} = {item_out}{{")?;
954 for i in 0..index {
955 let lhsi = lhs.index(i);
956 let rhsi = rhs.index(i);
957
958 let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
959 writeln!(f, "{rem}")?;
960 f.write_str(", ")?;
961 }
962
963 f.write_str("};\n")
964 };
965
966 if item_out_original == item_out_optimized {
967 write_op(&lhs, &rhs, out, item_out_optimized)
968 } else {
969 let out_tmp = Variable::tmp(item_out_optimized);
970
971 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
972
973 let addr_space = D::address_space_for_variable(&out_tmp);
974 let qualifier = out.const_qualifier();
975 let out = out.fmt_left();
976
977 writeln!(
978 f,
979 "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
980 )?;
981
982 Ok(())
983 }
984 }
985}
986
987struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
988 _dialect: PhantomData<D>,
989 _sqrt: PhantomData<S>,
990}
991
992impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
993 fn format(
994 f: &mut core::fmt::Formatter<'_>,
995 input: &Variable<D>,
996 out: &Variable<D>,
997 ) -> core::fmt::Result {
998 let num = input.item().vectorization;
999 let elem = input.elem();
1000
1001 let mag = format!("{out}_mag");
1002
1003 writeln!(f, "{} {mag} = 0.0;", out.item())?;
1004
1005 for i in 0..num {
1006 let input_i = input.index(i);
1007 writeln!(f, "{mag} += {input_i} * {input_i};")?;
1008 }
1009
1010 let out = out.fmt_left();
1011 write!(f, "{out} = ")?;
1012 S::format_unary(f, &mag, elem)?;
1013 f.write_str(";\n")
1014 }
1015}
1016
1017struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
1018 _dialect: PhantomData<D>,
1019 _rsqrt: PhantomData<InvS>,
1020}
1021
1022impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
1023 fn format(
1024 f: &mut core::fmt::Formatter<'_>,
1025 input: &Variable<D>,
1026 out: &Variable<D>,
1027 ) -> core::fmt::Result {
1028 let num = input.item().vectorization;
1029 let elem = input.elem();
1030 let norm = format!("{out}_norm");
1031
1032 let out_item = out.item();
1033 let out = out.fmt_left();
1034 writeln!(f, "{elem} {norm} = 0.0;")?;
1035
1036 for i in 0..num {
1037 let input_i = input.index(i);
1038 writeln!(f, "{norm} += {input_i} * {input_i};")?;
1039 }
1040
1041 write!(f, "{norm} = ")?;
1042 InvS::format_unary(f, &norm, elem)?;
1043 f.write_str(";\n")?;
1044
1045 if num == 1 {
1046 writeln!(f, "{out} = {input} * {norm};")
1047 } else {
1048 write!(f, "{out} = {out_item}{{")?;
1049 for i in 0..num {
1050 let input_i = input.index(i);
1051
1052 writeln!(f, "{input_i} * {norm},")?;
1053 }
1054
1055 f.write_str("};\n")
1056 }
1057 }
1058}
1059
1060struct Dot<D: Dialect> {
1061 _dialect: PhantomData<D>,
1062}
1063
1064impl<D: Dialect> Dot<D> {
1065 fn format(
1066 f: &mut core::fmt::Formatter<'_>,
1067 lhs: &Variable<D>,
1068 rhs: &Variable<D>,
1069 out: &Variable<D>,
1070 ) -> core::fmt::Result {
1071 let num = lhs.item().vectorization;
1072
1073 let muls = (0..num)
1074 .map(|i| {
1075 let lhs_i = lhs.index(i);
1076 let rhs_i = rhs.index(i);
1077 format!("{lhs_i} * {rhs_i}")
1078 })
1079 .collect::<Vec<_>>();
1080
1081 let out = out.fmt_left();
1082 writeln!(f, "{out} = {};", muls.join(" + "))
1083 }
1084}
1085
1086struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1087 var: &'a V,
1088 elem: &'a Elem<D>,
1089}
1090
1091impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1092 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1093 if self.elem != &Elem::Bool {
1094 write!(f, "bool({})", self.var)
1095 } else {
1096 write!(f, "{}", self.var)
1097 }
1098 }
1099}