1use crate::shared::FmtLeft;
2
3use super::{
4 Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5 barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8 borrow::Cow,
9 fmt::{Display, Formatter, Write},
10 marker::PhantomData,
11};
12
/// Name emitted in generated code for the metadata array; the `Metadata` and
/// `ExtendedMetadata` instructions index it as `info[...]`.
pub(crate) const INFO_NAME: &str = "info";
/// Name emitted in generated code for the static metadata object; its `x`
/// member is indexed as `static_info.x[...]` when `split_meta` is set.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
/// Operands shared by every two-input instruction (`Add`, `Mul`, `Equal`,
/// the atomics taking two inputs, ...).
#[derive(Debug, Clone)]
pub struct BinaryInstruction<D: Dialect> {
    /// Left-hand operand.
    pub lhs: Variable<D>,
    /// Right-hand operand.
    pub rhs: Variable<D>,
    /// Variable receiving the result.
    pub out: Variable<D>,
}
22
/// Operands of `Instruction::Index`; forwarded to `Index::format` as
/// `(list, index, out, line_size)`.
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    /// Collection being indexed.
    pub list: Variable<D>,
    /// Position inside `list`.
    pub index: Variable<D>,
    /// Line size handed to the formatter.
    // NOTE(review): presumably the vectorization of the access — confirm in `Index::format`.
    pub line_size: u32,
    /// Variable receiving the result.
    pub out: Variable<D>,
}
30
/// Operands of `Instruction::IndexAssign`; forwarded to `IndexAssign::format`
/// as `(index, value, out, line_size)`.
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    /// Position inside `out`.
    pub index: Variable<D>,
    /// Value being stored.
    pub value: Variable<D>,
    /// Line size handed to the formatter.
    // NOTE(review): presumably the vectorization of the access — confirm in `IndexAssign::format`.
    pub line_size: u32,
    /// Collection being written to.
    pub out: Variable<D>,
}
38
/// Operands shared by every single-input instruction (`Abs`, `Exp`, `Assign`,
/// `Bitcast`, ...).
#[derive(Debug, Clone)]
pub struct UnaryInstruction<D: Dialect> {
    /// Operand read by the instruction.
    pub input: Variable<D>,
    /// Variable receiving the result.
    pub out: Variable<D>,
}
44
/// One statement of the generated program.
///
/// The `Display` implementation renders each variant as source text for the
/// dialect `D`; the per-variant notes below describe the text that is emitted.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// Reads one metadata entry into `out`: `static_info.x[info_offset]` when
    /// `split_meta` is set, `info[info_offset]` otherwise.
    Metadata {
        info_offset: Variable<D>,
        split_meta: bool,
        out: Variable<D>,
    },
    /// Reads an indirect metadata entry into `out`:
    /// `info[static_info.x[info_offset] + dim - static_offset]` when
    /// `split_meta` is set, `info[info[info_offset] + dim]` otherwise
    /// (`static_offset` is only used in the split form).
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        split_meta: bool,
        static_offset: u32,
        out: Variable<D>,
    },
    /// Assigns the literal `length` to `out`.
    ConstLength {
        length: u32,
        out: Variable<D>,
    },
    /// Assigns `{input}_length` (the variable declared by `Slice` /
    /// `CheckedSlice`) to `out`.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declares `var` as `{item} {var};`; WMMA fragments are declared by the
    /// dialect instead.
    DeclareVariable {
        var: Variable<D>,
    },
    // Arithmetic. Unless noted, each variant is rendered by the `format`
    // function of the helper type with the same name.
    Modulo(BinaryInstruction<D>),
    /// Rendered as `lhs - rhs * floor(lhs / rhs)`.
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    /// Rendered as `fma(a, b, c)`, once per lane for vectorized outputs.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    Index(IndexInstruction<D>),
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    /// Cast that is only available with the `cuda` feature; rendering it
    /// without that feature hits `unimplemented!`.
    SpecialCast(UnaryInstruction<D>),
    /// `for` loop over `i` from `start` to `end`. `inclusive` selects `<=`
    /// over `<`; a missing `step` increments with `++i`.
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Brace-initializes `out` from the comma-separated `inputs`.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// `while (true)` loop around `instructions`.
    Loop {
        instructions: Vec<Self>,
    },
    /// `if (cond)` block.
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    /// `if (cond) { .. } else { .. }` block.
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// Ternary `(cond) ? then : or_else`, emitted per lane when the operands
    /// need broadcasting.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    /// `switch` statement; every case body is wrapped in braces and ends with
    /// `break;`, followed by the `default` block.
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Declares `{out}_length = end - start` and the pointer
    /// `out = input + start`.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Same as `Slice`, but the length is `min(len, end) - start`.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// Declares `out` as a pointer obtained by `reinterpret_cast` of `input`,
    /// using `out`'s item with its vectorization replaced by `line_size`.
    ReinterpretSlice {
        input: Variable<D>,
        line_size: u32,
        out: Variable<D>,
    },
    /// Emits `return;`.
    Return,
    /// Emits `break;`.
    Break,
    // Comparisons, bit manipulation, math and logic; each is rendered by the
    // `format` function of the helper type with the same name.
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// Rendered as `max(min_value, min(max_value, input))`, with the min/max
    /// function names supplied by the dialect.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    /// Rendered by `D::compile_instruction_sync_threads`.
    SyncThreads,
    /// Rendered by `D::compile_instruction_sync_warp`.
    SyncWarp,
    /// Emits `__threadfence();`.
    ThreadFence,
    /// Emits `cuda::device::experimental::fence_proxy_async_shared_cta();`.
    ProxySharedFence,
    /// Emits `cuda::device::experimental::cp_async_bulk_commit_group();`.
    BulkCommitGroup,
    /// Emits `cp_async_bulk_wait_group<max_pending>();`.
    BulkWaitGroup {
        max_pending: u32,
    },
    /// Emits `cp_async_bulk_wait_group_read<max_pending>();`.
    BulkWaitGroupRead {
        max_pending: u32,
    },
    /// Declares `out` as a `__shared__ alignas(128) CUtensorMap`. Under an
    /// `if(UnitPos == 0)` guard it is assigned from `tensor_map` and
    /// `tensormap_replace_global_address` is called with `&buffer[offset]`;
    /// a `__syncthreads();` follows.
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    /// Rendered through the `Display` implementation of `WarpInstruction`.
    Warp(WarpInstruction<D>),
    /// Rendered through the `Display` implementation of `WmmaInstruction`.
    Wmma(WmmaInstruction<D>),
    /// `reinterpret_cast` of `input` to `out`'s item; rendering panics when
    /// the two items differ in total size.
    Bitcast(UnaryInstruction<D>),
    // Atomics; each is rendered by the matching `D::compile_atomic_*` hook.
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    /// Rendered by `D::compile_atomic_cas`.
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    /// Rendered as `out = -input;`.
    Neg(UnaryInstruction<D>),
    /// Square root of the sum of squared lanes of the input.
    Magnitude(UnaryInstruction<D>),
    /// Each lane of the input divided by the square root of the sum of
    /// squared lanes.
    Normalize(UnaryInstruction<D>),
    /// Sum of lane-wise products of `lhs` and `rhs`.
    Dot(BinaryInstruction<D>),
    /// Rendered as `out[out_index] = input[in_index];`.
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// Like `Copy`, unrolled into `len` consecutive element assignments.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    /// Rendered by `D::compile_instruction_printf`.
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// Source comment: `// ...` for single-line content, `/* ... */` when the
    /// content contains a newline.
    Comment {
        content: String,
    },
    /// Rendered through the `Display` implementation of `PipelineOps`.
    Pipeline(PipelineOps<D>),
    /// Rendered through the `Display` implementation of `BarrierOps`.
    Barrier(BarrierOps<D>),
    /// Emits `cp_async_bulk_tensor_{rank}d_shared_to_global`, where `rank` is
    /// `indices.len()` and the indices are passed in reverse order.
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// Emits a `#line {line} "{file}"` directive.
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}
261
262impl<D: Dialect> Display for Instruction<D> {
263 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 match self {
265 Instruction::Return => f.write_str("return;"),
266 Instruction::Break => f.write_str("break;"),
267 Instruction::DeclareVariable { var } => match var {
268 Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
269 _ => {
270 let item = var.item();
271 writeln!(f, "{item} {var};")
272 }
273 },
274 Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
275 Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
276 Instruction::Slice {
277 input,
278 start,
279 end,
280 out,
281 } => {
282 let item = out.item();
283 let addr_space = D::address_space_for_variable(input);
284 writeln!(f, "const uint {out}_length = {end} - {start};")?;
285 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
286 }
287 Instruction::CheckedSlice {
288 input,
289 start,
290 end,
291 out,
292 len,
293 } => {
294 let item = out.item();
295 let addr_space = D::address_space_for_variable(input);
296 writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
297 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
298 }
299 Instruction::ReinterpretSlice {
300 input,
301 line_size,
302 out,
303 } => {
304 let mut item = out.item();
305 item.vectorization = *line_size as usize;
306 let addr_space = D::address_space_for_variable(input);
307
308 writeln!(
309 f,
310 "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
311 )
312 }
313 Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
314 Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
315 Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
316 Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
317 Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
318 Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
319 Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
320 Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
321 Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
322 Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
323 Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
324 Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
325 Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
326 Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
327 Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
328 Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
329 Instruction::IndexAssign(it) => {
330 IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
331 }
332 Instruction::Copy {
333 input,
334 in_index,
335 out,
336 out_index,
337 } => {
338 writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
339 }
340 Instruction::CopyBulk {
341 input,
342 in_index,
343 out,
344 out_index,
345 len,
346 } => {
347 for i in 0..*len {
348 writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
349 }
350 Ok(())
351 }
352 Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
353 Instruction::RangeLoop {
354 i,
355 start,
356 end,
357 step,
358 inclusive,
359 instructions,
360 } => {
361 let increment = step
362 .map(|step| format!("{i} += {step}"))
363 .unwrap_or_else(|| format!("++{i}"));
364 let cmp = if *inclusive { "<=" } else { "<" };
365 let i_ty = i.item();
366
367 write!(
368 f,
369 "
370for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
371"
372 )?;
373 for instruction in instructions {
374 write!(f, "{instruction}")?;
375 }
376
377 f.write_str("}\n")
378 }
379 Instruction::Loop { instructions } => {
380 writeln!(f, "while (true) {{")?;
381 for i in instructions {
382 write!(f, "{i}")?;
383 }
384 f.write_str("}\n")
385 }
386 Instruction::If { cond, instructions } => {
387 writeln!(f, "if ({cond}) {{")?;
388 for i in instructions {
389 write!(f, "{i}")?;
390 }
391 f.write_str("}\n")
392 }
393 Instruction::IfElse {
394 cond,
395 instructions_if,
396 instructions_else,
397 } => {
398 writeln!(f, "if ({cond}) {{")?;
399 for i in instructions_if {
400 write!(f, "{i}")?;
401 }
402 f.write_str("} else {\n")?;
403 for i in instructions_else {
404 write!(f, "{i}")?;
405 }
406 f.write_str("}\n")
407 }
408 Instruction::Select {
409 cond,
410 then,
411 or_else,
412 out,
413 } => {
414 let item_or_else = or_else.item();
415 let item_then = then.item();
416 let item_out = out.item();
417
418 let vf_then = item_then.vectorization;
419 let vf_or_else = item_or_else.vectorization;
420 let vf_out = item_out.vectorization;
421 let vf_cond = cond.item().vectorization;
422
423 let item_out = out.item();
424 let cond_elem = cond.item().elem;
425 let out = out.fmt_left();
426
427 let should_broadcast =
428 vf_cond > 1 || item_out != item_or_else || item_out != item_then;
429
430 if should_broadcast {
431 let vf = usize::max(vf_cond, vf_out);
432 let vf = usize::max(vf, vf_then);
433 let vf = usize::max(vf, vf_or_else);
434
435 writeln!(f, "{out} = {item_out} {{")?;
436 for i in 0..vf {
437 let theni = then.index(i);
438 let or_elsei = or_else.index(i);
439 let condi = cond.index(i);
440 let condi = EnsureBoolArg {
441 var: &condi,
442 elem: &cond_elem,
443 };
444
445 writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
446 }
447
448 writeln!(f, "}};")
449 } else {
450 let cond = EnsureBoolArg {
451 var: &cond,
452 elem: &cond_elem,
453 };
454 writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
455 }
456 }
457 Instruction::Switch {
458 value,
459 instructions_default,
460 instructions_cases,
461 } => {
462 writeln!(f, "switch({value}) {{")?;
463 for (value, block) in instructions_cases {
464 write!(f, "case {value}:\n{{\n")?;
465 for i in block {
466 i.fmt(f)?;
467 }
468 f.write_str("break;\n}\n")?;
469 }
470 f.write_str("default:\n{")?;
471 for i in instructions_default {
472 i.fmt(f)?;
473 }
474 f.write_str("}\n}\n")
475 }
476 Instruction::Metadata {
477 info_offset,
478 split_meta,
479 out,
480 } => {
481 let out = out.fmt_left();
482 match *split_meta {
483 true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
484 false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
485 }
486 }
487 Instruction::ExtendedMetadata {
488 info_offset,
489 dim,
490 split_meta,
491 static_offset,
492 out,
493 } => {
494 let out = out.fmt_left();
495 match *split_meta {
496 true => writeln!(
497 f,
498 "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
499 ),
500 false => writeln!(
501 f,
502 "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
503 ),
504 }
505 }
506 Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
507 Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
508 Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
509 Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
510 Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
511 Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
512 Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
513 Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
514 Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
515 Instruction::Log(it) => Log::format(f, &it.input, &it.out),
516 Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
517 Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
518 Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
519 Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
520 Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
521 Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
522 Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
523 Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
524 Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
525 Instruction::Not(it) => Not::format(f, &it.input, &it.out),
526 Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
527 Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
528 Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
529 Instruction::Clamp {
530 input,
531 min_value,
532 max_value,
533 out,
534 } => Clamp::format(f, input, min_value, max_value, out),
535 Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
536 Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
537 Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
538 Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
539 Instruction::ThreadFence => f.write_str("__threadfence();\n"),
540 Instruction::Round(it) => Round::format(f, &it.input, &it.out),
541 Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
542 Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
543 Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
544 Instruction::SliceLength { input, out } => {
545 let out = out.fmt_left();
546 writeln!(f, "{out} = {input}_length;")
547 }
548 Instruction::ConstLength { length, out } => {
549 let out = out.fmt_left();
550 writeln!(f, "{out} = {length};")
551 }
552 Instruction::Warp(it) => write!(f, "{it}"),
553 Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
554 Instruction::Wmma(it) => write!(f, "{it}"),
555 Instruction::Bitcast(UnaryInstruction { input, out }) => {
556 let qualifier = out.const_qualifier();
557 let input_item = input.item();
558 let out_item = out.item();
559
560 if out_item.elem.size() * out_item.vectorization
561 != input.item().elem.size() * input.item().vectorization
562 {
563 panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
564 } else {
565 let out = out.fmt_left();
566 let addr_space = D::address_space_for_variable(input);
567 writeln!(
568 f,
569 "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
570 )
571 }
572 }
573 Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
574 D::compile_atomic_add(f, lhs, rhs, out)
575 }
576 Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
577 D::compile_atomic_and(f, lhs, rhs, out)
578 }
579 Instruction::AtomicCAS {
580 input,
581 cmp,
582 val,
583 out,
584 } => D::compile_atomic_cas(f, input, cmp, val, out),
585 Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
586 D::compile_atomic_load(f, input, out)
587 }
588 Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
589 D::compile_atomic_max(f, lhs, rhs, out)
590 }
591 Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
592 D::compile_atomic_min(f, lhs, rhs, out)
593 }
594 Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
595 D::compile_atomic_or(f, lhs, rhs, out)
596 }
597 Instruction::AtomicStore(UnaryInstruction { input, out }) => {
598 D::compile_atomic_store(f, input, out)
599 }
600 Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
601 D::compile_atomic_sub(f, lhs, rhs, out)
602 }
603 Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
604 D::compile_atomic_swap(f, lhs, rhs, out)
605 }
606 Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
607 D::compile_atomic_xor(f, lhs, rhs, out)
608 }
609 Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
610 Instruction::Neg(UnaryInstruction { input, out }) => {
611 let out = out.fmt_left();
612 writeln!(f, "{out} = -{input};")
613 }
614 Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out),
615 Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out),
616 Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
617 Instruction::VecInit { inputs, out } => {
618 let item = out.item();
619 let inputs = inputs
620 .iter()
621 .map(|input| format!("{input}"))
622 .collect::<Vec<_>>();
623 let out = out.fmt_left();
624 writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
625 }
626 Instruction::Printf {
627 format_string,
628 args,
629 } => D::compile_instruction_printf(f, format_string, args),
630 Instruction::Comment { content } => {
631 if content.contains('\n') {
632 writeln!(f, "/* {content} */")
633 } else {
634 writeln!(f, "// {content}")
635 }
636 }
637 Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
638 Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
639 Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
640 Instruction::ProxySharedFence => {
641 writeln!(
642 f,
643 "cuda::device::experimental::fence_proxy_async_shared_cta();"
644 )
645 }
646 Instruction::BulkCommitGroup => writeln!(
647 f,
648 "cuda::device::experimental::cp_async_bulk_commit_group();"
649 ),
650 Instruction::BulkWaitGroup { max_pending } => writeln!(
651 f,
652 "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
653 ),
654 Instruction::BulkWaitGroupRead { max_pending } => writeln!(
655 f,
656 "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
657 ),
658 Instruction::TmaReplacePointer {
659 buffer,
660 offset,
661 tensor_map,
662 out,
663 } => {
664 let pos = Variable::<D>::UnitPos;
665 writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
666 writeln!(
667 f,
668 "
669if({pos} == 0) {{
670 {out} = {tensor_map};
671 tensormap_replace_global_address({out}, &{buffer}[{offset}]);
672}}"
673 )?;
674 writeln!(f, "__syncthreads();")
675 }
676 Instruction::MemCopyAsyncTensorSharedToGlobal {
677 smem_buffer,
678 smem_offset,
679 tensor_map,
680 indices,
681 } => {
682 let rank = indices.len();
683 let smem_ptr = smem_buffer.fmt_ptr();
684 let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
685 let _ = write!(s, "{it}, ");
686 s
687 });
688 writeln!(
689 f,
690 "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
691 )
692 }
693 Instruction::SpecialCast(UnaryInstruction { input, out }) => {
694 #[cfg(not(feature = "cuda"))]
696 {
697 let _ = (input, out);
698 unimplemented!("FP8/FP6/FP4 casting isn't supported outside of CUDA");
699 }
700 #[cfg(feature = "cuda")]
701 crate::cuda::convert::special_cast::<D>(f, input, out)
702 }
703 }
704 }
705}
706
707struct Fma<D: Dialect> {
708 _dialect: PhantomData<D>,
709}
710
711impl<D: Dialect> Fma<D> {
712 fn format(
713 f: &mut core::fmt::Formatter<'_>,
714 a: &Variable<D>,
715 b: &Variable<D>,
716 c: &Variable<D>,
717 out: &Variable<D>,
718 ) -> core::fmt::Result {
719 let out_item = out.item();
720 let num = out_item.vectorization;
721
722 let out = out.fmt_left();
723 if num == 1 {
724 writeln!(f, "{out} = fma({a}, {b}, {c});")
725 } else {
726 writeln!(f, "{out} = {out_item}{{")?;
727
728 for i in 0..num {
729 let ai = a.index(i);
730 let bi = b.index(i);
731 let ci = c.index(i);
732
733 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
734 }
735 f.write_str("};\n")
736 }
737 }
738}
739
740struct Clamp<D: Dialect> {
741 _dialect: PhantomData<D>,
742}
743
744impl<D: Dialect> Clamp<D> {
745 fn format(
746 f: &mut core::fmt::Formatter<'_>,
747 input: &Variable<D>,
748 min_value: &Variable<D>,
749 max_value: &Variable<D>,
750 out: &Variable<D>,
751 ) -> core::fmt::Result {
752 let out_item = out.item();
753 if out.item().vectorization == 1 {
754 let out = out.fmt_left();
755 write!(f, "{out} = ")?;
756 Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
757 f.write_str(";\n")
758 } else {
759 Self::unroll_vec(f, input, min_value, max_value, out)
760 }
761 }
762
763 fn format_scalar(
764 f: &mut Formatter<'_>,
765 input: impl Component<D>,
766 min_value: impl Component<D>,
767 max_value: impl Component<D>,
768 item: Item<D>,
769 ) -> std::fmt::Result {
770 D::compile_instruction_max_function_name(f, item)?;
771 write!(f, "({min_value}, ")?;
772 D::compile_instruction_min_function_name(f, item)?;
773 write!(f, "({max_value}, {input}))")
774 }
775
776 fn unroll_vec(
777 f: &mut core::fmt::Formatter<'_>,
778 input: &Variable<D>,
779 min_value: &Variable<D>,
780 max_value: &Variable<D>,
781 out: &Variable<D>,
782 ) -> std::fmt::Result {
783 let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
784 let [input, min_value, max_value, out_optimized] = optimized.args;
785
786 let item_out_original = out.item();
787 let item_out_optimized = out_optimized.item();
788
789 let index = match optimized.optimization_factor {
790 Some(factor) => item_out_original.vectorization / factor,
791 None => item_out_optimized.vectorization,
792 };
793
794 let mut write_op = |input: &Variable<D>,
795 min_value: &Variable<D>,
796 max_value: &Variable<D>,
797 out: &Variable<D>,
798 item_out: Item<D>| {
799 let out = out.fmt_left();
800 writeln!(f, "{out} = {item_out}{{")?;
801 for i in 0..index {
802 let inputi = input.index(i);
803 let min_valuei = min_value.index(i);
804 let max_valuei = max_value.index(i);
805
806 Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
807 f.write_str(", ")?;
808 }
809
810 f.write_str("};\n")
811 };
812
813 if item_out_original == item_out_optimized {
814 write_op(&input, &min_value, &max_value, out, item_out_optimized)
815 } else {
816 let out_tmp = Variable::tmp(item_out_optimized);
817 write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
818 let addr_space = D::address_space_for_variable(out);
819 let out = out.fmt_left();
820
821 writeln!(
822 f,
823 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
824 )?;
825
826 Ok(())
827 }
828 }
829}
830
831struct Remainder<D: Dialect> {
832 _dialect: PhantomData<D>,
833}
834
835impl<D: Dialect> Remainder<D> {
836 fn format(
837 f: &mut core::fmt::Formatter<'_>,
838 lhs: &Variable<D>,
839 rhs: &Variable<D>,
840 out: &Variable<D>,
841 ) -> core::fmt::Result {
842 let floor = |elem| {
843 let prefix = match elem {
844 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
845 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
846 _ => "",
847 };
848 format!("{prefix}floor")
849 };
850
851 if out.item().vectorization == 1 {
852 let floor = floor(out.elem());
853
854 let out = out.fmt_left();
855 return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
856 }
857
858 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
859 let [lhs, rhs, out_optimized] = optimized.args;
860
861 let item_out_original = out.item();
862 let item_out_optimized = out_optimized.item();
863
864 let index = match optimized.optimization_factor {
865 Some(factor) => item_out_original.vectorization / factor,
866 None => item_out_optimized.vectorization,
867 };
868
869 let floor = floor(*item_out_optimized.elem());
870
871 let mut write_op =
872 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
873 let out = out.fmt_left();
874 writeln!(f, "{out} = {item_out}{{")?;
875 for i in 0..index {
876 let lhsi = lhs.index(i);
877 let rhsi = rhs.index(i);
878
879 writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
880 f.write_str(", ")?;
881 }
882
883 f.write_str("};\n")
884 };
885
886 if item_out_original == item_out_optimized {
887 write_op(&lhs, &rhs, out, item_out_optimized)
888 } else {
889 let out_tmp = Variable::tmp(item_out_optimized);
890
891 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
892
893 let addr_space = D::address_space_for_variable(&out_tmp);
894 let qualifier = out.const_qualifier();
895 let out = out.fmt_left();
896
897 writeln!(
898 f,
899 "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
900 )?;
901
902 Ok(())
903 }
904 }
905}
906
907struct Magnitude<D: Dialect> {
908 _dialect: PhantomData<D>,
909}
910
911impl<D: Dialect> Magnitude<D> {
912 fn format(
913 f: &mut core::fmt::Formatter<'_>,
914 input: &Variable<D>,
915 out: &Variable<D>,
916 ) -> core::fmt::Result {
917 let num = input.item().vectorization;
918 let elem = input.elem();
919
920 let mag = format!("{out}_mag");
921
922 writeln!(f, "{} {mag} = 0.0;", out.item())?;
923
924 for i in 0..num {
925 let input_i = input.index(i);
926 writeln!(f, "{mag} += {input_i} * {input_i};")?;
927 }
928
929 let out = out.fmt_left();
930 write!(f, "{out} = ")?;
931 Sqrt::format_unary(f, &mag, elem)?;
932 f.write_str(";\n")
933 }
934}
935
936struct Normalize<D: Dialect> {
937 _dialect: PhantomData<D>,
938}
939
940impl<D: Dialect> Normalize<D> {
941 fn format(
942 f: &mut core::fmt::Formatter<'_>,
943 input: &Variable<D>,
944 out: &Variable<D>,
945 ) -> core::fmt::Result {
946 let num = input.item().vectorization;
947 let elem = input.elem();
948 let norm = format!("{out}_norm");
949
950 let out_item = out.item();
951 let out = out.fmt_left();
952 writeln!(f, "{elem} {norm} = 0.0;")?;
953
954 for i in 0..num {
955 let input_i = input.index(i);
956 writeln!(f, "{norm} += {input_i} * {input_i};")?;
957 }
958
959 write!(f, "{norm} = ")?;
960 Sqrt::format_unary(f, &norm, elem)?;
961 f.write_str(";\n")?;
962
963 if num == 1 {
964 writeln!(f, "{out} = {input} / {norm};")
965 } else {
966 write!(f, "{out} = {out_item}{{")?;
967 for i in 0..num {
968 let input_i = input.index(i);
969
970 writeln!(f, "{input_i} / {norm},")?;
971 }
972
973 f.write_str("};\n")
974 }
975 }
976}
977
978struct Dot<D: Dialect> {
979 _dialect: PhantomData<D>,
980}
981
982impl<D: Dialect> Dot<D> {
983 fn format(
984 f: &mut core::fmt::Formatter<'_>,
985 lhs: &Variable<D>,
986 rhs: &Variable<D>,
987 out: &Variable<D>,
988 ) -> core::fmt::Result {
989 let num = lhs.item().vectorization;
990
991 let muls = (0..num)
992 .map(|i| {
993 let lhs_i = lhs.index(i);
994 let rhs_i = rhs.index(i);
995 format!("{lhs_i} * {rhs_i}")
996 })
997 .collect::<Vec<_>>();
998
999 let out = out.fmt_left();
1000 writeln!(f, "{out} = {};", muls.join(" + "))
1001 }
1002}
1003
1004struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1005 var: &'a V,
1006 elem: &'a Elem<D>,
1007}
1008
1009impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1010 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1011 if self.elem != &Elem::Bool {
1012 write!(f, "bool({})", self.var)
1013 } else {
1014 write!(f, "{}", self.var)
1015 }
1016 }
1017}