1use crate::shared::FmtLeft;
2
3use super::{
4 Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5 barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8 borrow::Cow,
9 fmt::{Display, Formatter, Write},
10 marker::PhantomData,
11};
12
/// Identifier `info` used in generated code.
/// NOTE(review): presumably the name of the kernel's info argument; it is not referenced in
/// this part of the file — confirm against the callers.
pub(crate) const INFO_NAME: &str = "info";
/// Name of the array indexed when rendering `Instruction::ExtendedMetadata`.
pub(crate) const DYNAMIC_META_NAME: &str = "dynamic_meta";
/// Name of the array indexed when rendering `Instruction::Metadata`, and used to compute the
/// base offset for `Instruction::ExtendedMetadata`.
pub(crate) const STATIC_META_NAME: &str = "info.static_meta";
16
/// Operands of an instruction that takes two inputs and writes one output.
#[derive(Debug, Clone, Copy)]
pub struct BinaryInstruction<D: Dialect> {
    /// Left-hand input.
    pub lhs: Variable<D>,
    /// Right-hand input.
    pub rhs: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
23
/// Operands of `Instruction::Index`; all four fields are forwarded to `Index::format`.
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    /// Variable being indexed.
    pub list: Variable<D>,
    /// Index expression.
    pub index: Variable<D>,
    /// Vector size passed through to the formatter.
    /// NOTE(review): exact meaning is defined by `Index::format`, which is outside this file.
    pub vector_size: u32,
    /// Destination variable.
    pub out: Variable<D>,
}
31
/// Operands of `Instruction::IndexAssign`; all four fields are forwarded to
/// `IndexAssign::format`.
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    /// Index expression.
    pub index: Variable<D>,
    /// Value to store.
    pub value: Variable<D>,
    /// Vector size passed through to the formatter.
    /// NOTE(review): exact meaning is defined by `IndexAssign::format`, which is outside this
    /// file.
    pub vector_size: u32,
    /// Variable being written to.
    pub out: Variable<D>,
}
39
/// Operands of an instruction that takes one input and writes one output.
#[derive(Debug, Clone, Copy)]
pub struct UnaryInstruction<D: Dialect> {
    /// Input variable.
    pub input: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
45
/// One instruction of the generated program for dialect `D`.
///
/// The `Display` implementation of this type renders each variant as source text.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// Reads `info.static_meta[info_offset]` into `out`.
    Metadata {
        info_offset: Variable<D>,
        out: Variable<D>,
    },
    /// Reads `dynamic_meta[info.static_meta[info_offset] + dim]` into `out`.
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        out: Variable<D>,
    },
    /// Assigns the constant `length` to `out`.
    ConstLength {
        length: usize,
        out: Variable<D>,
    },
    /// Assigns the `{input}_length` variable (as declared by a slice instruction) to `out`.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declares `var` without initializing it.
    DeclareVariable {
        var: Variable<D>,
    },
    Modulo(BinaryInstruction<D>),
    /// Rendered as `lhs - rhs * floor(lhs / rhs)`.
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    /// Rendered as `out = fma(a, b, c)`, unrolled per component for vectorized outputs.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    FastDiv(BinaryInstruction<D>),
    FastRecip(UnaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    Index(IndexInstruction<D>),
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    /// Rendered by `crate::cuda::convert::special_cast` when the `cuda` feature is enabled;
    /// otherwise rendered as an `#error` directive.
    SpecialCast(UnaryInstruction<D>),
    /// A `for` loop over `i` from `start` to `end`.
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        /// Increment added to `i` each iteration; `++i` is used when `None`.
        step: Option<Variable<D>>,
        /// When `true` the loop condition is `i <= end`, otherwise `i < end`.
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Initializes `out` from a brace-enclosed, comma-separated list of `inputs`.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// A `while (true)` loop.
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// Ternary selection `cond ? then : or_else`, expanded per component when the operands
    /// need broadcasting.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    /// A `switch` statement over `value`; every case and the default end with `break`.
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        /// Pairs of case value and the instructions of that case.
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Declares `out` as a pointer to `input + start` and `{out}_length` as `end - start`.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Like `Slice`, but `{out}_length` is `min(len, end) - start`.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// Declares `out` as `input` reinterpreted as a pointer to `out`'s item with its
    /// vectorization replaced by `vector_size`.
    ReinterpretSlice {
        input: Variable<D>,
        vector_size: u32,
        out: Variable<D>,
    },
    Return,
    Break,
    /// Rendered by `Dialect::compile_unreachable`.
    Unreachable,
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    TrailingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    FastExp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    FastLog(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tan(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Sinh(UnaryInstruction<D>),
    Cosh(UnaryInstruction<D>),
    ArcCos(UnaryInstruction<D>),
    ArcSin(UnaryInstruction<D>),
    ArcTan(UnaryInstruction<D>),
    ArcSinh(UnaryInstruction<D>),
    ArcCosh(UnaryInstruction<D>),
    ArcTanh(UnaryInstruction<D>),
    Degrees(UnaryInstruction<D>),
    Radians(UnaryInstruction<D>),
    ArcTan2(BinaryInstruction<D>),
    FastSin(UnaryInstruction<D>),
    FastCos(UnaryInstruction<D>),
    FastTanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    FastPowf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Hypot(BinaryInstruction<D>),
    Rhypot(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    FastSqrt(UnaryInstruction<D>),
    InverseSqrt(UnaryInstruction<D>),
    FastInverseSqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// Rendered as `max(min_value, min(max_value, input))` using the dialect's min/max
    /// function names.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    /// Rendered by `Dialect::compile_instruction_sync_threads`.
    SyncThreads,
    /// Rendered by `Dialect::compile_instruction_sync_warp`.
    SyncWarp,
    /// Rendered as `__threadfence();`.
    ThreadFence,
    /// Rendered as `cuda::device::experimental::fence_proxy_async_shared_cta();`.
    ProxyAsyncToSharedFence,
    /// Rendered as `cuda::device::experimental::cp_async_bulk_commit_group();`.
    BulkCommitGroup,
    /// Rendered as `cuda::device::experimental::cp_async_bulk_wait_group<max_pending>();`.
    BulkWaitGroup {
        max_pending: u32,
    },
    /// Rendered as `cuda::device::experimental::cp_async_bulk_wait_group_read<max_pending>();`.
    BulkWaitGroupRead {
        max_pending: u32,
    },
    /// Declares `out` as a `__shared__` `CUtensorMap`, copies `tensor_map` into it and calls
    /// `tensormap_replace_global_address` with `&buffer[offset]` when the unit position is 0,
    /// then emits `__syncthreads()`.
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    Warp(WarpInstruction<D>),
    Wmma(WmmaInstruction<D>),
    /// Rendered as a `reinterpret_cast` of `input` to `out`'s item; rendering panics when the
    /// two items differ in total size.
    Bitcast(UnaryInstruction<D>),
    // Atomic operations are rendered by the matching `Dialect::compile_atomic_*` hook.
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    /// Rendered as `out = -input;`.
    Neg(UnaryInstruction<D>),
    /// Square root of the sum of squared components of the input.
    Magnitude(UnaryInstruction<D>),
    /// Same as `Magnitude`, using the `FastSqrt` formatter.
    FastMagnitude(UnaryInstruction<D>),
    /// Each component of the input multiplied by the inverse square root of the sum of
    /// squared components.
    Normalize(UnaryInstruction<D>),
    /// Same as `Normalize`, using the `FastInverseSqrt` formatter.
    FastNormalize(UnaryInstruction<D>),
    /// Sum of the component-wise products of `lhs` and `rhs`.
    Dot(BinaryInstruction<D>),
    /// Rendered as `out[out_index] = input[in_index];`.
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// Rendered as `len` consecutive element copies starting at the given indices.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    /// Rendered by `Dialect::compile_instruction_printf`.
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// A source comment: `/* … */` when `content` contains a newline, `// …` otherwise.
    Comment {
        content: String,
    },
    Pipeline(PipelineOps<D>),
    Barrier(BarrierOps<D>),
    /// Rendered as a call to
    /// `cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global`, where
    /// `rank` is `indices.len()` and the indices are passed in reverse order.
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// Rendered as a `#line {line} "{file}"` directive.
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}
288
289impl<D: Dialect> Display for Instruction<D> {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 match self {
292 Instruction::Return => f.write_str("return;"),
293 Instruction::Break => f.write_str("break;"),
294 Instruction::Unreachable => D::compile_unreachable(f),
295 Instruction::DeclareVariable { var } => match var {
296 Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
297 _ => {
298 let item = var.item();
299 writeln!(f, "{item} {var};")
300 }
301 },
302 Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
303 Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
304 Instruction::Slice {
305 input,
306 start,
307 end,
308 out,
309 } => {
310 let item = out.item();
311 let addr_space = D::address_space_for_variable(input);
312 writeln!(f, "const uint {out}_length = {end} - {start};")?;
313 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
314 }
315 Instruction::CheckedSlice {
316 input,
317 start,
318 end,
319 out,
320 len,
321 } => {
322 let item = out.item();
323 let addr_space = D::address_space_for_variable(input);
324 writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
325 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
326 }
327 Instruction::ReinterpretSlice {
328 input,
329 vector_size,
330 out,
331 } => {
332 let mut item = out.item();
333 item.vectorization = *vector_size as usize;
334 let addr_space = D::address_space_for_variable(input);
335
336 writeln!(
337 f,
338 "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
339 )
340 }
341 Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
342 Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
343 Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
344 Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
345 Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
346 Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
347 Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
348 Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
349 Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
350 Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
351 Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
352 Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
353 Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
354 Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
355 Instruction::TrailingZeros(it) => TrailingZeros::format(f, &it.input, &it.out),
356 Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
357 Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
358 Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
359 Instruction::Index(it) => {
360 Index::format(f, &it.list, &it.index, &it.out, it.vector_size)
361 }
362 Instruction::IndexAssign(it) => {
363 IndexAssign::format(f, &it.index, &it.value, &it.out, it.vector_size)
364 }
365 Instruction::Copy {
366 input,
367 in_index,
368 out,
369 out_index,
370 } => {
371 writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
372 }
373 Instruction::CopyBulk {
374 input,
375 in_index,
376 out,
377 out_index,
378 len,
379 } => {
380 for i in 0..*len {
381 writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
382 }
383 Ok(())
384 }
385 Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
386 Instruction::RangeLoop {
387 i,
388 start,
389 end,
390 step,
391 inclusive,
392 instructions,
393 } => {
394 let increment = step
395 .map(|step| format!("{i} += {step}"))
396 .unwrap_or_else(|| format!("++{i}"));
397 let cmp = if *inclusive { "<=" } else { "<" };
398 let i_ty = i.item();
399
400 write!(
401 f,
402 "
403for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
404"
405 )?;
406 for instruction in instructions {
407 write!(f, "{instruction}")?;
408 }
409
410 f.write_str("}\n")
411 }
412 Instruction::Loop { instructions } => {
413 writeln!(f, "while (true) {{")?;
414 for i in instructions {
415 write!(f, "{i}")?;
416 }
417 f.write_str("}\n")
418 }
419 Instruction::If { cond, instructions } => {
420 writeln!(f, "if ({cond}) {{")?;
421 for i in instructions {
422 write!(f, "{i}")?;
423 }
424 f.write_str("}\n")
425 }
426 Instruction::IfElse {
427 cond,
428 instructions_if,
429 instructions_else,
430 } => {
431 writeln!(f, "if ({cond}) {{")?;
432 for i in instructions_if {
433 write!(f, "{i}")?;
434 }
435 f.write_str("} else {\n")?;
436 for i in instructions_else {
437 write!(f, "{i}")?;
438 }
439 f.write_str("}\n")
440 }
441 Instruction::Select {
442 cond,
443 then,
444 or_else,
445 out,
446 } => {
447 let item_or_else = or_else.item();
448 let item_then = then.item();
449 let item_out = out.item();
450
451 let vf_then = item_then.vectorization;
452 let vf_or_else = item_or_else.vectorization;
453 let vf_out = item_out.vectorization;
454 let vf_cond = cond.item().vectorization;
455
456 let item_out = out.item();
457 let cond_elem = cond.item().elem;
458 let out = out.fmt_left();
459
460 let should_broadcast =
461 vf_cond > 1 || item_out != item_or_else || item_out != item_then;
462
463 if should_broadcast {
464 let vf = usize::max(vf_cond, vf_out);
465 let vf = usize::max(vf, vf_then);
466 let vf = usize::max(vf, vf_or_else);
467
468 writeln!(f, "{out} = {item_out} {{")?;
469 for i in 0..vf {
470 let theni = then.index(i);
471 let or_elsei = or_else.index(i);
472 let condi = cond.index(i);
473 let condi = EnsureBoolArg {
474 var: &condi,
475 elem: &cond_elem,
476 };
477
478 writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
479 }
480
481 writeln!(f, "}};")
482 } else {
483 let cond = EnsureBoolArg {
484 var: &cond,
485 elem: &cond_elem,
486 };
487 writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
488 }
489 }
490 Instruction::Switch {
491 value,
492 instructions_default,
493 instructions_cases,
494 } => {
495 writeln!(f, "switch({value}) {{")?;
496 for (value, block) in instructions_cases {
497 write!(f, "case {value}:\n{{\n")?;
498 for i in block {
499 i.fmt(f)?;
500 }
501 f.write_str("break;\n}\n")?;
502 }
503 f.write_str("default:\n{")?;
504 for i in instructions_default {
505 i.fmt(f)?;
506 }
507 f.write_str("break;\n}\n}\n")
508 }
509 Instruction::Metadata { info_offset, out } => {
510 let out = out.fmt_left();
511 writeln!(f, "{out} = {STATIC_META_NAME}[{info_offset}];")
512 }
513 Instruction::ExtendedMetadata {
514 info_offset,
515 dim,
516 out,
517 } => {
518 let out = out.fmt_left();
519 writeln!(
520 f,
521 "{out} = {DYNAMIC_META_NAME}[{STATIC_META_NAME}[{info_offset}] + {dim}];"
522 )
523 }
524 Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
525 Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
526 Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
527 Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
528 Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
529 Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
530 Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
531 Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
532 Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
533 Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
534 Instruction::Log(it) => Log::format(f, &it.input, &it.out),
535 Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
536 Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
537 Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
538 Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
539 Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
540 Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
541 Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
542 Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
543 Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
544 Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
545 Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
546 Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
547 Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
548 Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
549 Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
550 Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
551 Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
552 Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
553 Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
554 Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
555 Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
556 Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
557 Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
558 Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out),
559 Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out),
560 Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
561 Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
562 Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
563 Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
564 Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
565 Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
566 Instruction::Not(it) => Not::format(f, &it.input, &it.out),
567 Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
568 Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
569 Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
570 Instruction::Clamp {
571 input,
572 min_value,
573 max_value,
574 out,
575 } => Clamp::format(f, input, min_value, max_value, out),
576 Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
577 Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
578 Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
579 Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
580 Instruction::ThreadFence => f.write_str("__threadfence();\n"),
581 Instruction::Round(it) => Round::format(f, &it.input, &it.out),
582 Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
583 Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
584 Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
585 Instruction::SliceLength { input, out } => {
586 let out = out.fmt_left();
587 writeln!(f, "{out} = {input}_length;")
588 }
589 Instruction::ConstLength { length, out } => {
590 let out = out.fmt_left();
591 writeln!(f, "{out} = {length};")
592 }
593 Instruction::Warp(it) => write!(f, "{it}"),
594 Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
595 Instruction::Wmma(it) => write!(f, "{it}"),
596 Instruction::Bitcast(UnaryInstruction { input, out }) => {
597 let qualifier = out.const_qualifier();
598 let input_item = input.item();
599 let out_item = out.item();
600
601 if out_item.elem.size() * out_item.vectorization
602 != input.item().elem.size() * input.item().vectorization
603 {
604 panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
605 } else {
606 let out = out.fmt_left();
607 let addr_space = D::address_space_for_variable(input);
608 writeln!(
609 f,
610 "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
611 )
612 }
613 }
614 Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
615 D::compile_atomic_add(f, lhs, rhs, out)
616 }
617 Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
618 D::compile_atomic_and(f, lhs, rhs, out)
619 }
620 Instruction::AtomicCAS {
621 input,
622 cmp,
623 val,
624 out,
625 } => D::compile_atomic_cas(f, input, cmp, val, out),
626 Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
627 D::compile_atomic_load(f, input, out)
628 }
629 Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
630 D::compile_atomic_max(f, lhs, rhs, out)
631 }
632 Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
633 D::compile_atomic_min(f, lhs, rhs, out)
634 }
635 Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
636 D::compile_atomic_or(f, lhs, rhs, out)
637 }
638 Instruction::AtomicStore(UnaryInstruction { input, out }) => {
639 D::compile_atomic_store(f, input, out)
640 }
641 Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
642 D::compile_atomic_sub(f, lhs, rhs, out)
643 }
644 Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
645 D::compile_atomic_swap(f, lhs, rhs, out)
646 }
647 Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
648 D::compile_atomic_xor(f, lhs, rhs, out)
649 }
650 Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
651 Instruction::Neg(UnaryInstruction { input, out }) => {
652 let out = out.fmt_left();
653 writeln!(f, "{out} = -{input};")
654 }
655 Instruction::Normalize(inst) => {
656 Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
657 }
658 Instruction::FastNormalize(inst) => {
659 Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
660 }
661 Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
662 Instruction::FastMagnitude(inst) => {
663 Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
664 }
665 Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
666 Instruction::VecInit { inputs, out } => {
667 let item = out.item();
668 let inputs = inputs
669 .iter()
670 .map(|input| format!("{input}"))
671 .collect::<Vec<_>>();
672 let out = out.fmt_left();
673 writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
674 }
675 Instruction::Printf {
676 format_string,
677 args,
678 } => D::compile_instruction_printf(f, format_string, args),
679 Instruction::Comment { content } => {
680 if content.contains('\n') {
681 writeln!(f, "/* {content} */")
682 } else {
683 writeln!(f, "// {content}")
684 }
685 }
686 Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
687 Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
688 Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
689 Instruction::ProxyAsyncToSharedFence => {
690 writeln!(
691 f,
692 "cuda::device::experimental::fence_proxy_async_shared_cta();"
693 )
694 }
695 Instruction::BulkCommitGroup => writeln!(
696 f,
697 "cuda::device::experimental::cp_async_bulk_commit_group();"
698 ),
699 Instruction::BulkWaitGroup { max_pending } => writeln!(
700 f,
701 "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
702 ),
703 Instruction::BulkWaitGroupRead { max_pending } => writeln!(
704 f,
705 "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
706 ),
707 Instruction::TmaReplacePointer {
708 buffer,
709 offset,
710 tensor_map,
711 out,
712 } => {
713 let pos = Variable::<D>::UnitPos;
714 writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
715 writeln!(
716 f,
717 "
718if({pos} == 0) {{
719 {out} = {tensor_map};
720 tensormap_replace_global_address({out}, &{buffer}[{offset}]);
721}}"
722 )?;
723 writeln!(f, "__syncthreads();")
724 }
725 Instruction::MemCopyAsyncTensorSharedToGlobal {
726 smem_buffer,
727 smem_offset,
728 tensor_map,
729 indices,
730 } => {
731 let rank = indices.len();
732 let smem_ptr = smem_buffer.fmt_ptr();
733 let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
734 let _ = write!(s, "{it}, ");
735 s
736 });
737 writeln!(
738 f,
739 "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
740 )
741 }
742 Instruction::SpecialCast(UnaryInstruction { input, out }) => {
743 #[cfg(not(feature = "cuda"))]
745 {
746 let _ = (input, out);
747 writeln!(
748 f,
749 "#error FP8/FP6/FP4 casting isn't supported outside of CUDA"
750 )
751 }
752 #[cfg(feature = "cuda")]
753 crate::cuda::convert::special_cast::<D>(f, input, out)
754 }
755 }
756 }
757}
758
759struct Fma<D: Dialect> {
760 _dialect: PhantomData<D>,
761}
762
763impl<D: Dialect> Fma<D> {
764 fn format(
765 f: &mut core::fmt::Formatter<'_>,
766 a: &Variable<D>,
767 b: &Variable<D>,
768 c: &Variable<D>,
769 out: &Variable<D>,
770 ) -> core::fmt::Result {
771 let out_item = out.item();
772 let num = out_item.vectorization;
773
774 let out = out.fmt_left();
775 if num == 1 {
776 writeln!(f, "{out} = fma({a}, {b}, {c});")
777 } else {
778 writeln!(f, "{out} = {out_item}{{")?;
779
780 for i in 0..num {
781 let ai = a.index(i);
782 let bi = b.index(i);
783 let ci = c.index(i);
784
785 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
786 }
787 f.write_str("};\n")
788 }
789 }
790}
791
792struct Clamp<D: Dialect> {
793 _dialect: PhantomData<D>,
794}
795
796impl<D: Dialect> Clamp<D> {
797 fn format(
798 f: &mut core::fmt::Formatter<'_>,
799 input: &Variable<D>,
800 min_value: &Variable<D>,
801 max_value: &Variable<D>,
802 out: &Variable<D>,
803 ) -> core::fmt::Result {
804 let out_item = out.item();
805 if out.item().vectorization == 1 {
806 let out = out.fmt_left();
807 write!(f, "{out} = ")?;
808 Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
809 f.write_str(";\n")
810 } else {
811 Self::unroll_vec(f, input, min_value, max_value, out)
812 }
813 }
814
815 fn format_scalar(
816 f: &mut Formatter<'_>,
817 input: impl Component<D>,
818 min_value: impl Component<D>,
819 max_value: impl Component<D>,
820 item: Item<D>,
821 ) -> std::fmt::Result {
822 D::compile_instruction_max_function_name(f, item)?;
823 write!(f, "({min_value}, ")?;
824 D::compile_instruction_min_function_name(f, item)?;
825 write!(f, "({max_value}, {input}))")
826 }
827
828 fn unroll_vec(
829 f: &mut core::fmt::Formatter<'_>,
830 input: &Variable<D>,
831 min_value: &Variable<D>,
832 max_value: &Variable<D>,
833 out: &Variable<D>,
834 ) -> std::fmt::Result {
835 let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
836 let [input, min_value, max_value, out_optimized] = optimized.args;
837
838 let item_out_original = out.item();
839 let item_out_optimized = out_optimized.item();
840
841 let index = match optimized.optimization_factor {
842 Some(factor) => item_out_original.vectorization / factor,
843 None => item_out_optimized.vectorization,
844 };
845
846 let mut write_op = |input: &Variable<D>,
847 min_value: &Variable<D>,
848 max_value: &Variable<D>,
849 out: &Variable<D>,
850 item_out: Item<D>| {
851 let out = out.fmt_left();
852 writeln!(f, "{out} = {item_out}{{")?;
853 for i in 0..index {
854 let inputi = input.index(i);
855 let min_valuei = min_value.index(i);
856 let max_valuei = max_value.index(i);
857
858 Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
859 f.write_str(", ")?;
860 }
861
862 f.write_str("};\n")
863 };
864
865 if item_out_original == item_out_optimized {
866 write_op(&input, &min_value, &max_value, out, item_out_optimized)
867 } else {
868 let out_tmp = Variable::tmp(item_out_optimized);
869 write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
870 let addr_space = D::address_space_for_variable(out);
871 let out = out.fmt_left();
872
873 writeln!(
874 f,
875 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
876 )?;
877
878 Ok(())
879 }
880 }
881}
882
883struct Remainder<D: Dialect> {
884 _dialect: PhantomData<D>,
885}
886
887impl<D: Dialect> Remainder<D> {
888 fn format(
889 f: &mut core::fmt::Formatter<'_>,
890 lhs: &Variable<D>,
891 rhs: &Variable<D>,
892 out: &Variable<D>,
893 ) -> core::fmt::Result {
894 let floor = |elem| {
895 let prefix = match elem {
896 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
897 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
898 _ => "",
899 };
900 format!("{prefix}floor")
901 };
902
903 let is_int = matches!(
904 out.elem(),
905 Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
906 );
907 let out_elem = out.elem();
908 let rem_expr = |lhs, rhs, floor: &str| {
909 if is_int {
910 format!("{lhs} - {rhs} * ({out_elem}){floor}((float){lhs} / (float){rhs})")
911 } else {
912 format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
913 }
914 };
915
916 if out.item().vectorization == 1 {
917 let floor = floor(out.elem());
918
919 let out = out.fmt_left();
920 let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
921 return writeln!(f, "{out} = {rem};");
922 }
923
924 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
925 let [lhs, rhs, out_optimized] = optimized.args;
926
927 let item_out_original = out.item();
928 let item_out_optimized = out_optimized.item();
929
930 let index = match optimized.optimization_factor {
931 Some(factor) => item_out_original.vectorization / factor,
932 None => item_out_optimized.vectorization,
933 };
934
935 let floor = floor(*item_out_optimized.elem());
936
937 let mut write_op =
938 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
939 let out = out.fmt_left();
940 writeln!(f, "{out} = {item_out}{{")?;
941 for i in 0..index {
942 let lhsi = lhs.index(i);
943 let rhsi = rhs.index(i);
944
945 let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
946 writeln!(f, "{rem}")?;
947 f.write_str(", ")?;
948 }
949
950 f.write_str("};\n")
951 };
952
953 if item_out_original == item_out_optimized {
954 write_op(&lhs, &rhs, out, item_out_optimized)
955 } else {
956 let out_tmp = Variable::tmp(item_out_optimized);
957
958 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
959
960 let addr_space = D::address_space_for_variable(&out_tmp);
961 let qualifier = out.const_qualifier();
962 let out = out.fmt_left();
963
964 writeln!(
965 f,
966 "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
967 )?;
968
969 Ok(())
970 }
971 }
972}
973
974struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
975 _dialect: PhantomData<D>,
976 _sqrt: PhantomData<S>,
977}
978
979impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
980 fn format(
981 f: &mut core::fmt::Formatter<'_>,
982 input: &Variable<D>,
983 out: &Variable<D>,
984 ) -> core::fmt::Result {
985 let num = input.item().vectorization;
986 let elem = input.elem();
987
988 let mag = format!("{out}_mag");
989
990 writeln!(f, "{} {mag} = 0.0;", out.item())?;
991
992 for i in 0..num {
993 let input_i = input.index(i);
994 writeln!(f, "{mag} += {input_i} * {input_i};")?;
995 }
996
997 let out = out.fmt_left();
998 write!(f, "{out} = ")?;
999 S::format_unary(f, &mag, elem)?;
1000 f.write_str(";\n")
1001 }
1002}
1003
1004struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
1005 _dialect: PhantomData<D>,
1006 _rsqrt: PhantomData<InvS>,
1007}
1008
1009impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
1010 fn format(
1011 f: &mut core::fmt::Formatter<'_>,
1012 input: &Variable<D>,
1013 out: &Variable<D>,
1014 ) -> core::fmt::Result {
1015 let num = input.item().vectorization;
1016 let elem = input.elem();
1017 let norm = format!("{out}_norm");
1018
1019 let out_item = out.item();
1020 let out = out.fmt_left();
1021 writeln!(f, "{elem} {norm} = 0.0;")?;
1022
1023 for i in 0..num {
1024 let input_i = input.index(i);
1025 writeln!(f, "{norm} += {input_i} * {input_i};")?;
1026 }
1027
1028 write!(f, "{norm} = ")?;
1029 InvS::format_unary(f, &norm, elem)?;
1030 f.write_str(";\n")?;
1031
1032 if num == 1 {
1033 writeln!(f, "{out} = {input} * {norm};")
1034 } else {
1035 write!(f, "{out} = {out_item}{{")?;
1036 for i in 0..num {
1037 let input_i = input.index(i);
1038
1039 writeln!(f, "{input_i} * {norm},")?;
1040 }
1041
1042 f.write_str("};\n")
1043 }
1044 }
1045}
1046
1047struct Dot<D: Dialect> {
1048 _dialect: PhantomData<D>,
1049}
1050
1051impl<D: Dialect> Dot<D> {
1052 fn format(
1053 f: &mut core::fmt::Formatter<'_>,
1054 lhs: &Variable<D>,
1055 rhs: &Variable<D>,
1056 out: &Variable<D>,
1057 ) -> core::fmt::Result {
1058 let num = lhs.item().vectorization;
1059
1060 let muls = (0..num)
1061 .map(|i| {
1062 let lhs_i = lhs.index(i);
1063 let rhs_i = rhs.index(i);
1064 format!("{lhs_i} * {rhs_i}")
1065 })
1066 .collect::<Vec<_>>();
1067
1068 let out = out.fmt_left();
1069 writeln!(f, "{out} = {};", muls.join(" + "))
1070 }
1071}
1072
1073struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1074 var: &'a V,
1075 elem: &'a Elem<D>,
1076}
1077
1078impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1079 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1080 if self.elem != &Elem::Bool {
1081 write!(f, "bool({})", self.var)
1082 } else {
1083 write!(f, "{}", self.var)
1084 }
1085 }
1086}