1use crate::shared::FmtLeft;
2
3use super::{
4 Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5 barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8 borrow::Cow,
9 fmt::{Display, Formatter, Write},
10 marker::PhantomData,
11};
12
/// The identifier `info` as it appears in generated code.
/// NOTE(review): presumably the kernel's metadata argument — confirm against the kernel prelude.
pub(crate) const INFO_NAME: &str = "info";
/// Name of the array indexed by `Instruction::ExtendedMetadata` when emitting code.
pub(crate) const DYNAMIC_META_NAME: &str = "dynamic_meta";
/// Expression indexed by `Instruction::Metadata` and `Instruction::ExtendedMetadata`
/// when emitting code.
pub(crate) const STATIC_META_NAME: &str = "info.static_meta";
16
/// Operands of a two-input instruction: the result of combining `lhs` and `rhs`
/// is written to `out`.
#[derive(Debug, Clone, Copy)]
pub struct BinaryInstruction<D: Dialect> {
    /// Left-hand operand.
    pub lhs: Variable<D>,
    /// Right-hand operand.
    pub rhs: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
23
/// Operands of an indexed read; all four fields are forwarded to `Index::format`
/// when the instruction is displayed.
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    /// Variable being indexed.
    pub list: Variable<D>,
    /// Index expression.
    pub index: Variable<D>,
    /// Vector size passed through to the formatter.
    pub vector_size: u32,
    /// Destination variable.
    pub out: Variable<D>,
}
31
/// Operands of an indexed write; all four fields are forwarded to
/// `IndexAssign::format` when the instruction is displayed.
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    /// Index expression.
    pub index: Variable<D>,
    /// Value being stored.
    pub value: Variable<D>,
    /// Vector size passed through to the formatter.
    pub vector_size: u32,
    /// Variable being written into.
    pub out: Variable<D>,
}
39
/// Operands of a single-input instruction: a value derived from `input` is
/// written to `out`.
#[derive(Debug, Clone, Copy)]
pub struct UnaryInstruction<D: Dialect> {
    /// Source operand.
    pub input: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
45
/// One instruction of the intermediate representation for dialect `D`.
///
/// The `Display` implementation of this enum turns each variant into source
/// text; control-flow variants own their nested instruction lists.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// Reads the static metadata entry at `info_offset` into `out`.
    Metadata {
        info_offset: Variable<D>,
        out: Variable<D>,
    },
    /// Reads the dynamic metadata entry located at the static entry for
    /// `info_offset` plus `dim` into `out`.
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        out: Variable<D>,
    },
    /// Assigns the constant `length` to `out`.
    ConstLength {
        length: usize,
        out: Variable<D>,
    },
    /// Assigns `<input>_length` (the companion variable emitted by the slice
    /// instructions) to `out`.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declares `var` without initializing it.
    DeclareVariable {
        var: Variable<D>,
    },
    Modulo(BinaryInstruction<D>),
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    /// Fused multiply-add of `a`, `b` and `c`, emitted as `fma(a, b, c)`.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    FastDiv(BinaryInstruction<D>),
    FastRecip(UnaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    Index(IndexInstruction<D>),
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    SpecialCast(UnaryInstruction<D>),
    /// Counted `for` loop over `i` from `start` to `end`; `inclusive` selects
    /// `<=` over `<`, and a missing `step` means `++i`.
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Builds `out` from a brace-initializer listing `inputs`.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// Unconditional `while (true)` loop.
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// Ternary selection between `then` and `or_else` based on `cond`.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    /// `switch` statement; each case pairs its label value with a body.
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Pointer `out = input + start` with companion length `end - start`.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Like `Slice`, but the companion length is `min(len, end) - start`.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// Reinterprets `input` as a pointer to `out`'s item with `vector_size` lanes.
    ReinterpretSlice {
        input: Variable<D>,
        vector_size: u32,
        out: Variable<D>,
    },
    Return,
    Break,
    Unreachable,
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    TrailingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    FastExp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    FastLog(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tan(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Sinh(UnaryInstruction<D>),
    Cosh(UnaryInstruction<D>),
    ArcCos(UnaryInstruction<D>),
    ArcSin(UnaryInstruction<D>),
    ArcTan(UnaryInstruction<D>),
    ArcSinh(UnaryInstruction<D>),
    ArcCosh(UnaryInstruction<D>),
    ArcTanh(UnaryInstruction<D>),
    Degrees(UnaryInstruction<D>),
    Radians(UnaryInstruction<D>),
    ArcTan2(BinaryInstruction<D>),
    FastSin(UnaryInstruction<D>),
    FastCos(UnaryInstruction<D>),
    FastTanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    FastPowf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Hypot(BinaryInstruction<D>),
    Rhypot(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    FastSqrt(UnaryInstruction<D>),
    InverseSqrt(UnaryInstruction<D>),
    FastInverseSqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// Clamps `input` to the range `[min_value, max_value]`.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    SyncThreads,
    SyncWarp,
    /// Emitted as `__threadfence();`.
    ThreadFence,
    ProxyAsyncToSharedFence,
    BulkCommitGroup,
    /// `max_pending` becomes the template argument of the emitted wait call.
    BulkWaitGroup {
        max_pending: u32,
    },
    /// `max_pending` becomes the template argument of the emitted wait call.
    BulkWaitGroupRead {
        max_pending: u32,
    },
    /// Copies `tensor_map` into a newly declared `__shared__` `CUtensorMap`
    /// named `out` and replaces its global address with `&buffer[offset]`.
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    Warp(WarpInstruction<D>),
    Wmma(WmmaInstruction<D>),
    /// Reinterprets the bits of `input` as `out`'s type; formatting panics
    /// when the two total sizes differ.
    Bitcast(UnaryInstruction<D>),
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    Neg(UnaryInstruction<D>),
    Magnitude(UnaryInstruction<D>),
    FastMagnitude(UnaryInstruction<D>),
    Normalize(UnaryInstruction<D>),
    FastNormalize(UnaryInstruction<D>),
    Dot(BinaryInstruction<D>),
    VectorSum(UnaryInstruction<D>),
    /// Single-element copy: `out[out_index] = input[in_index]`.
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// `len` consecutive element copies, emitted fully unrolled.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// Source comment; multi-line content is emitted as a block comment.
    Comment {
        content: String,
    },
    Pipeline(PipelineOps<D>),
    Barrier(BarrierOps<D>),
    /// Bulk tensor copy of rank `indices.len()`; the indices are emitted in
    /// reverse order.
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// Emitted as a `#line` directive.
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}
289
/// Renders an instruction as source text for dialect `D`.
///
/// Most arms delegate to a formatter type (`Add`, `Sqrt`, ...) or to a
/// `Dialect` hook; the remaining arms write the text inline. Nested
/// instruction lists are rendered recursively.
///
/// # Panics
///
/// The `Bitcast` arm panics when the total sizes of the input and output
/// items differ.
impl<D: Dialect> Display for Instruction<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Instruction::Return => f.write_str("return;"),
            Instruction::Break => f.write_str("break;"),
            Instruction::Unreachable => D::compile_unreachable(f),
            // WMMA fragments have a dialect-specific declaration; everything
            // else is a plain `<type> <name>;`.
            Instruction::DeclareVariable { var } => match var {
                Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
                _ => {
                    let item = var.item();
                    writeln!(f, "{item} {var};")
                }
            },
            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
            // A slice is a pointer plus a companion `<out>_length` constant,
            // which `SliceLength` reads back later.
            Instruction::Slice {
                input,
                start,
                end,
                out,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = {end} - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            // Same as `Slice`, but the end bound is first limited to `len`.
            Instruction::CheckedSlice {
                input,
                start,
                end,
                out,
                len,
            } => {
                let item = out.item();
                let addr_space = D::address_space_for_variable(input);
                writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
                writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
            }
            Instruction::ReinterpretSlice {
                input,
                vector_size,
                out,
            } => {
                // The pointee type is `out`'s item with its lane count
                // overridden by `vector_size`.
                let mut item = out.item();
                item.vectorization = *vector_size as usize;
                let addr_space = D::address_space_for_variable(input);

                writeln!(
                    f,
                    "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
                )
            }
            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
            Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
            Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
            Instruction::TrailingZeros(it) => TrailingZeros::format(f, &it.input, &it.out),
            Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Index(it) => {
                Index::format(f, &it.list, &it.index, &it.out, it.vector_size)
            }
            Instruction::IndexAssign(it) => {
                IndexAssign::format(f, &it.index, &it.value, &it.out, it.vector_size)
            }
            Instruction::Copy {
                input,
                in_index,
                out,
                out_index,
            } => {
                writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
            }
            // Unrolled at generation time: one assignment per element.
            Instruction::CopyBulk {
                input,
                in_index,
                out,
                out_index,
                len,
            } => {
                for i in 0..*len {
                    writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
                }
                Ok(())
            }
            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
            Instruction::RangeLoop {
                i,
                start,
                end,
                step,
                inclusive,
                instructions,
            } => {
                // Without an explicit step the loop uses pre-increment.
                let increment = step
                    .map(|step| format!("{i} += {step}"))
                    .unwrap_or_else(|| format!("++{i}"));
                let cmp = if *inclusive { "<=" } else { "<" };
                let i_ty = i.item();

                // The literal starts with a newline, so the loop header is
                // preceded by an empty line in the output.
                write!(
                    f,
                    "
for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
"
                )?;
                for instruction in instructions {
                    write!(f, "{instruction}")?;
                }

                f.write_str("}\n")
            }
            Instruction::Loop { instructions } => {
                writeln!(f, "while (true) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::If { cond, instructions } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::IfElse {
                cond,
                instructions_if,
                instructions_else,
            } => {
                writeln!(f, "if ({cond}) {{")?;
                for i in instructions_if {
                    write!(f, "{i}")?;
                }
                f.write_str("} else {\n")?;
                for i in instructions_else {
                    write!(f, "{i}")?;
                }
                f.write_str("}\n")
            }
            Instruction::Select {
                cond,
                then,
                or_else,
                out,
            } => {
                let item_or_else = or_else.item();
                let item_then = then.item();
                let item_out = out.item();

                let vf_then = item_then.vectorization;
                let vf_or_else = item_or_else.vectorization;
                let vf_out = item_out.vectorization;
                let vf_cond = cond.item().vectorization;

                // NOTE(review): `item_out` is rebound to the same value it
                // already holds; the second binding looks redundant.
                let item_out = out.item();
                let cond_elem = cond.item().elem;
                let out = out.fmt_left();

                // Widest lane count among the condition, both branches and
                // the output.
                let vf = usize::max(vf_cond, vf_out);
                let vf = usize::max(vf, vf_then);
                let vf = usize::max(vf, vf_or_else);
                let should_broadcast = vf > 1;

                if should_broadcast {
                    // One ternary per lane inside a brace-initializer.
                    writeln!(f, "{out} = {item_out} {{")?;
                    for i in 0..vf {
                        let theni = then.index(i);
                        let or_elsei = or_else.index(i);
                        let condi = cond.index(i);
                        let condi = EnsureBoolArg {
                            var: &condi,
                            elem: &cond_elem,
                        };

                        writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
                    }

                    writeln!(f, "}};")
                } else {
                    let cond = EnsureBoolArg {
                        var: &cond,
                        elem: &cond_elem,
                    };
                    writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
                }
            }
            Instruction::Switch {
                value,
                instructions_default,
                instructions_cases,
            } => {
                writeln!(f, "switch({value}) {{")?;
                // Every case body is wrapped in its own braces and ends with
                // `break;`.
                for (value, block) in instructions_cases {
                    write!(f, "case {value}:\n{{\n")?;
                    for i in block {
                        i.fmt(f)?;
                    }
                    f.write_str("break;\n}\n")?;
                }
                // Unlike the cases, the default's opening brace is not
                // followed by a newline.
                f.write_str("default:\n{")?;
                for i in instructions_default {
                    i.fmt(f)?;
                }
                f.write_str("break;\n}\n}\n")
            }
            Instruction::Metadata { info_offset, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {STATIC_META_NAME}[{info_offset}];")
            }
            Instruction::ExtendedMetadata {
                info_offset,
                dim,
                out,
            } => {
                // The static entry supplies a base offset into the dynamic
                // metadata array; `dim` is added to it.
                let out = out.fmt_left();
                writeln!(
                    f,
                    "{out} = {DYNAMIC_META_NAME}[{STATIC_META_NAME}[{info_offset}] + {dim}];"
                )
            }
            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
            Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
            Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
            Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
            Instruction::Tan(it) => Tan::format(f, &it.input, &it.out),
            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
            Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out),
            Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out),
            Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out),
            Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out),
            Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out),
            Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out),
            Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out),
            Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out),
            Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out),
            Instruction::Radians(it) => Radians::format(f, &it.input, &it.out),
            Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
            Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Hypot(it) => Hypot::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Rhypot(it) => Rhypot::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
            Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
            Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
            Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
            Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Clamp {
                input,
                min_value,
                max_value,
                out,
            } => Clamp::format(f, input, min_value, max_value, out),
            Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
            Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
            Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
            Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
            Instruction::ThreadFence => f.write_str("__threadfence();\n"),
            Instruction::Round(it) => Round::format(f, &it.input, &it.out),
            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
            Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
            // Reads the `<input>_length` constant emitted by `Slice` /
            // `CheckedSlice`.
            Instruction::SliceLength { input, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {input}_length;")
            }
            Instruction::ConstLength { length, out } => {
                let out = out.fmt_left();
                writeln!(f, "{out} = {length};")
            }
            Instruction::Warp(it) => write!(f, "{it}"),
            Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
            Instruction::Wmma(it) => write!(f, "{it}"),
            Instruction::Bitcast(UnaryInstruction { input, out }) => {
                let qualifier = out.const_qualifier();
                let input_item = input.item();
                let out_item = out.item();

                // A bitcast is only emitted when both sides have the same
                // total size (`elem.size() * vectorization`); otherwise this
                // panics rather than returning a formatting error.
                if out_item.elem.size() * out_item.vectorization
                    != input.item().elem.size() * input.item().vectorization
                {
                    panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
                } else {
                    let out = out.fmt_left();
                    let addr_space = D::address_space_for_variable(input);
                    writeln!(
                        f,
                        "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
                    )
                }
            }
            Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_add(f, lhs, rhs, out)
            }
            Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_and(f, lhs, rhs, out)
            }
            Instruction::AtomicCAS {
                input,
                cmp,
                val,
                out,
            } => D::compile_atomic_cas(f, input, cmp, val, out),
            Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
                D::compile_atomic_load(f, input, out)
            }
            Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_max(f, lhs, rhs, out)
            }
            Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_min(f, lhs, rhs, out)
            }
            Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_or(f, lhs, rhs, out)
            }
            Instruction::AtomicStore(UnaryInstruction { input, out }) => {
                D::compile_atomic_store(f, input, out)
            }
            Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_sub(f, lhs, rhs, out)
            }
            Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_swap(f, lhs, rhs, out)
            }
            Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
                D::compile_atomic_xor(f, lhs, rhs, out)
            }
            Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::Neg(UnaryInstruction { input, out }) => {
                let out = out.fmt_left();
                writeln!(f, "{out} = -{input};")
            }
            // The fast and precise variants share one formatter and differ
            // only in the (inverse) square-root function plugged in.
            Instruction::Normalize(inst) => {
                Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::FastNormalize(inst) => {
                Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
            Instruction::FastMagnitude(inst) => {
                Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
            }
            Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::VectorSum(inst) => VectorSumFmt::<D>::format(f, &inst.input, &inst.out),
            Instruction::VecInit { inputs, out } => {
                let item = out.item();
                let inputs = inputs
                    .iter()
                    .map(|input| format!("{input}"))
                    .collect::<Vec<_>>();
                let out = out.fmt_left();
                writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
            }
            Instruction::Printf {
                format_string,
                args,
            } => D::compile_instruction_printf(f, format_string, args),
            Instruction::Comment { content } => {
                // A `//` comment would end at the first newline, so
                // multi-line content uses the block form.
                if content.contains('\n') {
                    writeln!(f, "/* {content} */")
                } else {
                    writeln!(f, "// {content}")
                }
            }
            Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
            Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
            Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
            Instruction::ProxyAsyncToSharedFence => {
                writeln!(
                    f,
                    "cuda::device::experimental::fence_proxy_async_shared_cta();"
                )
            }
            Instruction::BulkCommitGroup => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_commit_group();"
            ),
            Instruction::BulkWaitGroup { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
            ),
            Instruction::BulkWaitGroupRead { max_pending } => writeln!(
                f,
                "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
            ),
            Instruction::TmaReplacePointer {
                buffer,
                offset,
                tensor_map,
                out,
            } => {
                // Only the unit whose position is 0 rewrites the shared
                // tensor map; the trailing `__syncthreads()` follows the
                // guarded write.
                let pos = Variable::<D>::UnitPos;
                writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
                writeln!(
                    f,
                    "
if({pos} == 0) {{
 {out} = {tensor_map};
 tensormap_replace_global_address({out}, &{buffer}[{offset}]);
}}"
                )?;
                writeln!(f, "__syncthreads();")
            }
            Instruction::MemCopyAsyncTensorSharedToGlobal {
                smem_buffer,
                smem_offset,
                tensor_map,
                indices,
            } => {
                let rank = indices.len();
                let smem_ptr = smem_buffer.fmt_ptr();
                // Indices are emitted in reverse order, each followed by
                // ", ", so the list flows straight into the pointer argument.
                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
                    let _ = write!(s, "{it}, ");
                    s
                });
                writeln!(
                    f,
                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
                )
            }
            Instruction::SpecialCast(UnaryInstruction { input, out }) => {
                // Without the `cuda` feature the generated code contains an
                // `#error` directive instead of a cast.
                #[cfg(not(feature = "cuda"))]
                {
                    let _ = (input, out);
                    writeln!(
                        f,
                        "#error FP8/FP6/FP4 casting isn't supported outside of CUDA"
                    )
                }
                #[cfg(feature = "cuda")]
                crate::cuda::convert::special_cast::<D>(f, input, out)
            }
        }
    }
}
767
768struct Fma<D: Dialect> {
769 _dialect: PhantomData<D>,
770}
771
772impl<D: Dialect> Fma<D> {
773 fn format(
774 f: &mut core::fmt::Formatter<'_>,
775 a: &Variable<D>,
776 b: &Variable<D>,
777 c: &Variable<D>,
778 out: &Variable<D>,
779 ) -> core::fmt::Result {
780 let out_item = out.item();
781 let num = out_item.vectorization;
782
783 let out = out.fmt_left();
784 if num == 1 {
785 writeln!(f, "{out} = fma({a}, {b}, {c});")
786 } else {
787 writeln!(f, "{out} = {out_item}{{")?;
788
789 for i in 0..num {
790 let ai = a.index(i);
791 let bi = b.index(i);
792 let ci = c.index(i);
793
794 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
795 }
796 f.write_str("};\n")
797 }
798 }
799}
800
801struct Clamp<D: Dialect> {
802 _dialect: PhantomData<D>,
803}
804
805impl<D: Dialect> Clamp<D> {
806 fn format(
807 f: &mut core::fmt::Formatter<'_>,
808 input: &Variable<D>,
809 min_value: &Variable<D>,
810 max_value: &Variable<D>,
811 out: &Variable<D>,
812 ) -> core::fmt::Result {
813 let out_item = out.item();
814 if out.item().vectorization == 1 {
815 let out = out.fmt_left();
816 write!(f, "{out} = ")?;
817 Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
818 f.write_str(";\n")
819 } else {
820 Self::unroll_vec(f, input, min_value, max_value, out)
821 }
822 }
823
824 fn format_scalar(
825 f: &mut Formatter<'_>,
826 input: impl Component<D>,
827 min_value: impl Component<D>,
828 max_value: impl Component<D>,
829 item: Item<D>,
830 ) -> std::fmt::Result {
831 D::compile_instruction_max_function_name(f, item)?;
832 write!(f, "({min_value}, ")?;
833 D::compile_instruction_min_function_name(f, item)?;
834 write!(f, "({max_value}, {input}))")
835 }
836
837 fn unroll_vec(
838 f: &mut core::fmt::Formatter<'_>,
839 input: &Variable<D>,
840 min_value: &Variable<D>,
841 max_value: &Variable<D>,
842 out: &Variable<D>,
843 ) -> std::fmt::Result {
844 let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
845 let [input, min_value, max_value, out_optimized] = optimized.args;
846
847 let item_out_original = out.item();
848 let item_out_optimized = out_optimized.item();
849
850 let index = match optimized.optimization_factor {
851 Some(factor) => item_out_original.vectorization / factor,
852 None => item_out_optimized.vectorization,
853 };
854
855 let mut write_op = |input: &Variable<D>,
856 min_value: &Variable<D>,
857 max_value: &Variable<D>,
858 out: &Variable<D>,
859 item_out: Item<D>| {
860 let out = out.fmt_left();
861 writeln!(f, "{out} = {item_out}{{")?;
862 for i in 0..index {
863 let inputi = input.index(i);
864 let min_valuei = min_value.index(i);
865 let max_valuei = max_value.index(i);
866
867 Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
868 f.write_str(", ")?;
869 }
870
871 f.write_str("};\n")
872 };
873
874 if item_out_original == item_out_optimized {
875 write_op(&input, &min_value, &max_value, out, item_out_optimized)
876 } else {
877 let out_tmp = Variable::tmp(item_out_optimized);
878 write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
879 let addr_space = D::address_space_for_variable(out);
880 let out = out.fmt_left();
881
882 writeln!(
883 f,
884 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
885 )?;
886
887 Ok(())
888 }
889 }
890}
891
892struct Remainder<D: Dialect> {
893 _dialect: PhantomData<D>,
894}
895
896impl<D: Dialect> Remainder<D> {
897 fn format(
898 f: &mut core::fmt::Formatter<'_>,
899 lhs: &Variable<D>,
900 rhs: &Variable<D>,
901 out: &Variable<D>,
902 ) -> core::fmt::Result {
903 let floor = |elem| {
904 let prefix = match elem {
905 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
906 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
907 _ => "",
908 };
909 format!("{prefix}floor")
910 };
911
912 let is_int = matches!(
913 out.elem(),
914 Elem::I8 | Elem::I16 | Elem::I32 | Elem::U8 | Elem::U16 | Elem::U32 | Elem::U64
915 );
916 let out_elem = out.elem();
917 let rem_expr = |lhs, rhs, floor: &str| {
918 if is_int {
919 format!("{lhs} - {rhs} * ({out_elem}){floor}((float){lhs} / (float){rhs})")
920 } else {
921 format!("{lhs} - {rhs} * {floor}({lhs} / {rhs})")
922 }
923 };
924
925 if out.item().vectorization == 1 {
926 let floor = floor(out.elem());
927
928 let out = out.fmt_left();
929 let rem = rem_expr(lhs.to_string(), rhs.to_string(), &floor);
930 return writeln!(f, "{out} = {rem};");
931 }
932
933 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
934 let [lhs, rhs, out_optimized] = optimized.args;
935
936 let item_out_original = out.item();
937 let item_out_optimized = out_optimized.item();
938
939 let index = match optimized.optimization_factor {
940 Some(factor) => item_out_original.vectorization / factor,
941 None => item_out_optimized.vectorization,
942 };
943
944 let floor = floor(*item_out_optimized.elem());
945
946 let mut write_op =
947 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
948 let out = out.fmt_left();
949 writeln!(f, "{out} = {item_out}{{")?;
950 for i in 0..index {
951 let lhsi = lhs.index(i);
952 let rhsi = rhs.index(i);
953
954 let rem = rem_expr(lhsi.to_string(), rhsi.to_string(), &floor);
955 writeln!(f, "{rem}")?;
956 f.write_str(", ")?;
957 }
958
959 f.write_str("};\n")
960 };
961
962 if item_out_original == item_out_optimized {
963 write_op(&lhs, &rhs, out, item_out_optimized)
964 } else {
965 let out_tmp = Variable::tmp(item_out_optimized);
966
967 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
968
969 let addr_space = D::address_space_for_variable(&out_tmp);
970 let qualifier = out.const_qualifier();
971 let out = out.fmt_left();
972
973 writeln!(
974 f,
975 "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
976 )?;
977
978 Ok(())
979 }
980 }
981}
982
983struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
984 _dialect: PhantomData<D>,
985 _sqrt: PhantomData<S>,
986}
987
988impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
989 fn format(
990 f: &mut core::fmt::Formatter<'_>,
991 input: &Variable<D>,
992 out: &Variable<D>,
993 ) -> core::fmt::Result {
994 let num = input.item().vectorization;
995 let elem = input.elem();
996
997 let mag = format!("{out}_mag");
998
999 writeln!(f, "{} {mag} = 0.0;", out.item())?;
1000
1001 for i in 0..num {
1002 let input_i = input.index(i);
1003 writeln!(f, "{mag} += {input_i} * {input_i};")?;
1004 }
1005
1006 let out = out.fmt_left();
1007 write!(f, "{out} = ")?;
1008 S::format_unary(f, &mag, elem)?;
1009 f.write_str(";\n")
1010 }
1011}
1012
1013struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
1014 _dialect: PhantomData<D>,
1015 _rsqrt: PhantomData<InvS>,
1016}
1017
1018impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
1019 fn format(
1020 f: &mut core::fmt::Formatter<'_>,
1021 input: &Variable<D>,
1022 out: &Variable<D>,
1023 ) -> core::fmt::Result {
1024 let num = input.item().vectorization;
1025 let elem = input.elem();
1026 let norm = format!("{out}_norm");
1027
1028 let out_item = out.item();
1029 let out = out.fmt_left();
1030 writeln!(f, "{elem} {norm} = 0.0;")?;
1031
1032 for i in 0..num {
1033 let input_i = input.index(i);
1034 writeln!(f, "{norm} += {input_i} * {input_i};")?;
1035 }
1036
1037 write!(f, "{norm} = ")?;
1038 InvS::format_unary(f, &norm, elem)?;
1039 f.write_str(";\n")?;
1040
1041 if num == 1 {
1042 writeln!(f, "{out} = {input} * {norm};")
1043 } else {
1044 write!(f, "{out} = {out_item}{{")?;
1045 for i in 0..num {
1046 let input_i = input.index(i);
1047
1048 writeln!(f, "{input_i} * {norm},")?;
1049 }
1050
1051 f.write_str("};\n")
1052 }
1053 }
1054}
1055
1056struct Dot<D: Dialect> {
1057 _dialect: PhantomData<D>,
1058}
1059
1060impl<D: Dialect> Dot<D> {
1061 fn format(
1062 f: &mut core::fmt::Formatter<'_>,
1063 lhs: &Variable<D>,
1064 rhs: &Variable<D>,
1065 out: &Variable<D>,
1066 ) -> core::fmt::Result {
1067 let num = lhs.item().vectorization;
1068
1069 let muls = (0..num)
1070 .map(|i| {
1071 let lhs_i = lhs.index(i);
1072 let rhs_i = rhs.index(i);
1073 format!("{lhs_i} * {rhs_i}")
1074 })
1075 .collect::<Vec<_>>();
1076
1077 let out = out.fmt_left();
1078 writeln!(f, "{out} = {};", muls.join(" + "))
1079 }
1080}
1081
1082struct VectorSumFmt<D: Dialect> {
1083 _dialect: PhantomData<D>,
1084}
1085
1086impl<D: Dialect> VectorSumFmt<D> {
1087 fn format(
1088 f: &mut core::fmt::Formatter<'_>,
1089 input: &Variable<D>,
1090 out: &Variable<D>,
1091 ) -> core::fmt::Result {
1092 let num = input.item().vectorization;
1093
1094 let elems = (0..num)
1095 .map(|i| format!("{}", input.index(i)))
1096 .collect::<Vec<_>>();
1097
1098 let out = out.fmt_left();
1099 writeln!(f, "{out} = {};", elems.join(" + "))
1100 }
1101}
1102
1103struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1104 var: &'a V,
1105 elem: &'a Elem<D>,
1106}
1107
1108impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1110 if self.elem != &Elem::Bool {
1111 write!(f, "bool({})", self.var)
1112 } else {
1113 write!(f, "{}", self.var)
1114 }
1115 }
1116}