1use crate::shared::FmtLeft;
2
3use super::{
4 Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5 barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8 borrow::Cow,
9 fmt::{Display, Formatter, Write},
10 marker::PhantomData,
11};
12
/// Identifier emitted in generated code for the metadata buffer that the
/// `Metadata` / `ExtendedMetadata` instructions index into.
pub(crate) const INFO_NAME: &str = "info";
/// Identifier emitted in generated code for the static metadata object; its
/// `.x` member is indexed when an instruction has `split_meta` set.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
/// Operands shared by every two-input instruction: the `lhs` and `rhs`
/// inputs and the `out` destination handed to the matching formatter.
#[derive(Debug, Clone, Copy)]
pub struct BinaryInstruction<D: Dialect> {
    /// Left-hand operand.
    pub lhs: Variable<D>,
    /// Right-hand operand.
    pub rhs: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
22
/// Operands of the `Index` instruction; forwarded to `Index::format` as
/// `(list, index, out, line_size)`.
#[derive(Debug, Clone)]
pub struct IndexInstruction<D: Dialect> {
    /// Variable being indexed.
    pub list: Variable<D>,
    /// Index expression.
    pub index: Variable<D>,
    /// Line size passed through to the formatter unchanged.
    pub line_size: u32,
    /// Destination variable.
    pub out: Variable<D>,
}
30
/// Operands of the `IndexAssign` instruction; forwarded to
/// `IndexAssign::format` as `(index, value, out, line_size)`.
#[derive(Debug, Clone)]
pub struct IndexAssignInstruction<D: Dialect> {
    /// Index expression.
    pub index: Variable<D>,
    /// Value operand.
    pub value: Variable<D>,
    /// Line size passed through to the formatter unchanged.
    pub line_size: u32,
    /// Destination variable.
    pub out: Variable<D>,
}
38
/// Operands shared by every single-input instruction: one `input` and the
/// `out` destination.
#[derive(Debug, Clone, Copy)]
pub struct UnaryInstruction<D: Dialect> {
    /// Source operand.
    pub input: Variable<D>,
    /// Destination variable.
    pub out: Variable<D>,
}
44
/// A single instruction of the generated kernel body.
///
/// Each variant is turned into source text by the `Display` implementation
/// for this type; block-structured variants (`RangeLoop`, `Loop`, `If`,
/// `IfElse`, `Switch`) own their nested instructions.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// Reads the metadata entry at `info_offset` into `out`. When
    /// `split_meta` is set the entry is read from the static info object,
    /// otherwise from the info buffer.
    Metadata {
        info_offset: Variable<D>,
        split_meta: bool,
        out: Variable<D>,
    },
    /// Reads a metadata entry through one level of indirection: the entry at
    /// `info_offset` plus `dim` is used as an index into the info buffer.
    /// With `split_meta`, `static_offset` is subtracted from that index.
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        split_meta: bool,
        static_offset: u32,
        out: Variable<D>,
    },
    /// Assigns the literal `length` to `out`.
    ConstLength {
        length: u32,
        out: Variable<D>,
    },
    /// Assigns `<input>_length` (the companion variable emitted by the slice
    /// instructions) to `out`.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declares `var` without initializing it.
    DeclareVariable {
        var: Variable<D>,
    },
    // Arithmetic.
    Modulo(BinaryInstruction<D>),
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    SaturatingAdd(BinaryInstruction<D>),
    /// Emitted as `fma(a, b, c)`, unrolled per component for vectorized outputs.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    FastDiv(BinaryInstruction<D>),
    FastRecip(UnaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    SaturatingSub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    // Indexing, assignment and casts.
    Index(IndexInstruction<D>),
    IndexAssign(IndexAssignInstruction<D>),
    Assign(UnaryInstruction<D>),
    /// Cast handled by the CUDA-specific conversion code; formatting it
    /// without the `cuda` feature hits `unimplemented!`.
    SpecialCast(UnaryInstruction<D>),
    /// A `for` loop over `i` from `start` to `end`. `inclusive` selects `<=`
    /// instead of `<`; a missing `step` increments with `++i`.
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Builds `out` from a brace-initializer listing `inputs` in order.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// An unconditional `while (true)` loop.
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// Ternary selection between `then` and `or_else`, unrolled per component
    /// when the condition is vectorized or the operand items differ from the
    /// output item.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    /// A `switch` on `value`; every case block is terminated with `break`.
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Declares `out` as a pointer to `input + start` together with a
    /// `<out>_length` constant equal to `end - start`.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Like `Slice`, but the length is computed as `min(len, end) - start`.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// Declares `out` as `input` reinterpreted as a pointer to the output
    /// item with its vectorization replaced by `line_size`.
    ReinterpretSlice {
        input: Variable<D>,
        line_size: u32,
        out: Variable<D>,
    },
    // Control transfer.
    Return,
    Break,
    // Comparisons.
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    // Bit manipulation.
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    // Math functions; the `Fast*` variants dispatch to separate formatters.
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    FastExp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    FastLog(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    FastCos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    FastSin(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    FastTanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    FastPowf(BinaryInstruction<D>),
    Powi(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    FastSqrt(UnaryInstruction<D>),
    InverseSqrt(UnaryInstruction<D>),
    FastInverseSqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    // Boolean logic.
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// Emitted as `max(min_value, min(max_value, input))` using the
    /// dialect's min/max function names.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    IsNan(UnaryInstruction<D>),
    IsInf(UnaryInstruction<D>),
    // Synchronization and fences.
    SyncThreads,
    SyncWarp,
    /// Emitted as `__threadfence();`.
    ThreadFence,
    ProxyAsyncToSharedFence,
    // Bulk async copy group management.
    BulkCommitGroup,
    BulkWaitGroup {
        max_pending: u32,
    },
    BulkWaitGroupRead {
        max_pending: u32,
    },
    /// Declares a shared, 128-byte aligned `CUtensorMap` named `out`, copies
    /// `tensor_map` into it and replaces its global address with
    /// `&buffer[offset]` on the unit at position 0, then synchronizes.
    TmaReplacePointer {
        buffer: Variable<D>,
        offset: Variable<D>,
        tensor_map: Variable<D>,
        out: Variable<D>,
    },
    // Rounding.
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Trunc(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    // Instructions formatted by their own `Display` implementations.
    Warp(WarpInstruction<D>),
    Wmma(WmmaInstruction<D>),
    /// Reinterprets `input` as the output item; formatting panics when the
    /// two items differ in total byte size.
    Bitcast(UnaryInstruction<D>),
    // Atomics; formatting is delegated to the dialect.
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    // Vector helpers.
    Neg(UnaryInstruction<D>),
    Magnitude(UnaryInstruction<D>),
    FastMagnitude(UnaryInstruction<D>),
    Normalize(UnaryInstruction<D>),
    FastNormalize(UnaryInstruction<D>),
    Dot(BinaryInstruction<D>),
    /// Copies a single element: `out[out_index] = input[in_index]`.
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// Copies `len` consecutive elements, emitted as `len` unrolled
    /// assignments.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    /// Formatting is delegated to the dialect's printf support.
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// Emitted as a `//` comment, or a `/* */` comment when `content`
    /// contains a newline.
    Comment {
        content: String,
    },
    Pipeline(PipelineOps<D>),
    Barrier(BarrierOps<D>),
    /// Asynchronous bulk tensor copy from shared to global memory; `indices`
    /// are emitted in reverse order and their count selects the `_<rank>d_`
    /// function.
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// Emitted as a `#line` directive pointing at `file`:`line`.
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}
274
275impl<D: Dialect> Display for Instruction<D> {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 match self {
278 Instruction::Return => f.write_str("return;"),
279 Instruction::Break => f.write_str("break;"),
280 Instruction::DeclareVariable { var } => match var {
281 Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
282 _ => {
283 let item = var.item();
284 writeln!(f, "{item} {var};")
285 }
286 },
287 Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
288 Instruction::SaturatingAdd(it) => SaturatingAdd::format(f, &it.lhs, &it.rhs, &it.out),
289 Instruction::Slice {
290 input,
291 start,
292 end,
293 out,
294 } => {
295 let item = out.item();
296 let addr_space = D::address_space_for_variable(input);
297 writeln!(f, "const uint {out}_length = {end} - {start};")?;
298 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
299 }
300 Instruction::CheckedSlice {
301 input,
302 start,
303 end,
304 out,
305 len,
306 } => {
307 let item = out.item();
308 let addr_space = D::address_space_for_variable(input);
309 writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
310 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
311 }
312 Instruction::ReinterpretSlice {
313 input,
314 line_size,
315 out,
316 } => {
317 let mut item = out.item();
318 item.vectorization = *line_size as usize;
319 let addr_space = D::address_space_for_variable(input);
320
321 writeln!(
322 f,
323 "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
324 )
325 }
326 Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
327 Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
328 Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out),
329 Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out),
330 Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
331 Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out),
332 Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
333 Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
334 Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
335 Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
336 Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
337 Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
338 Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
339 Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
340 Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
341 Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
342 Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
343 Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
344 Instruction::IndexAssign(it) => {
345 IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
346 }
347 Instruction::Copy {
348 input,
349 in_index,
350 out,
351 out_index,
352 } => {
353 writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
354 }
355 Instruction::CopyBulk {
356 input,
357 in_index,
358 out,
359 out_index,
360 len,
361 } => {
362 for i in 0..*len {
363 writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
364 }
365 Ok(())
366 }
367 Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
368 Instruction::RangeLoop {
369 i,
370 start,
371 end,
372 step,
373 inclusive,
374 instructions,
375 } => {
376 let increment = step
377 .map(|step| format!("{i} += {step}"))
378 .unwrap_or_else(|| format!("++{i}"));
379 let cmp = if *inclusive { "<=" } else { "<" };
380 let i_ty = i.item();
381
382 write!(
383 f,
384 "
385for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
386"
387 )?;
388 for instruction in instructions {
389 write!(f, "{instruction}")?;
390 }
391
392 f.write_str("}\n")
393 }
394 Instruction::Loop { instructions } => {
395 writeln!(f, "while (true) {{")?;
396 for i in instructions {
397 write!(f, "{i}")?;
398 }
399 f.write_str("}\n")
400 }
401 Instruction::If { cond, instructions } => {
402 writeln!(f, "if ({cond}) {{")?;
403 for i in instructions {
404 write!(f, "{i}")?;
405 }
406 f.write_str("}\n")
407 }
408 Instruction::IfElse {
409 cond,
410 instructions_if,
411 instructions_else,
412 } => {
413 writeln!(f, "if ({cond}) {{")?;
414 for i in instructions_if {
415 write!(f, "{i}")?;
416 }
417 f.write_str("} else {\n")?;
418 for i in instructions_else {
419 write!(f, "{i}")?;
420 }
421 f.write_str("}\n")
422 }
423 Instruction::Select {
424 cond,
425 then,
426 or_else,
427 out,
428 } => {
429 let item_or_else = or_else.item();
430 let item_then = then.item();
431 let item_out = out.item();
432
433 let vf_then = item_then.vectorization;
434 let vf_or_else = item_or_else.vectorization;
435 let vf_out = item_out.vectorization;
436 let vf_cond = cond.item().vectorization;
437
438 let item_out = out.item();
439 let cond_elem = cond.item().elem;
440 let out = out.fmt_left();
441
442 let should_broadcast =
443 vf_cond > 1 || item_out != item_or_else || item_out != item_then;
444
445 if should_broadcast {
446 let vf = usize::max(vf_cond, vf_out);
447 let vf = usize::max(vf, vf_then);
448 let vf = usize::max(vf, vf_or_else);
449
450 writeln!(f, "{out} = {item_out} {{")?;
451 for i in 0..vf {
452 let theni = then.index(i);
453 let or_elsei = or_else.index(i);
454 let condi = cond.index(i);
455 let condi = EnsureBoolArg {
456 var: &condi,
457 elem: &cond_elem,
458 };
459
460 writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
461 }
462
463 writeln!(f, "}};")
464 } else {
465 let cond = EnsureBoolArg {
466 var: &cond,
467 elem: &cond_elem,
468 };
469 writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
470 }
471 }
472 Instruction::Switch {
473 value,
474 instructions_default,
475 instructions_cases,
476 } => {
477 writeln!(f, "switch({value}) {{")?;
478 for (value, block) in instructions_cases {
479 write!(f, "case {value}:\n{{\n")?;
480 for i in block {
481 i.fmt(f)?;
482 }
483 f.write_str("break;\n}\n")?;
484 }
485 f.write_str("default:\n{")?;
486 for i in instructions_default {
487 i.fmt(f)?;
488 }
489 f.write_str("}\n}\n")
490 }
491 Instruction::Metadata {
492 info_offset,
493 split_meta,
494 out,
495 } => {
496 let out = out.fmt_left();
497 match *split_meta {
498 true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
499 false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
500 }
501 }
502 Instruction::ExtendedMetadata {
503 info_offset,
504 dim,
505 split_meta,
506 static_offset,
507 out,
508 } => {
509 let out = out.fmt_left();
510 match *split_meta {
511 true => writeln!(
512 f,
513 "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
514 ),
515 false => writeln!(
516 f,
517 "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
518 ),
519 }
520 }
521 Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
522 Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
523 Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
524 Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
525 Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
526 Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
527 Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
528 Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
529 Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
530 Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out),
531 Instruction::Log(it) => Log::format(f, &it.input, &it.out),
532 Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out),
533 Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
534 Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
535 Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out),
536 Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
537 Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out),
538 Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
539 Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out),
540 Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
541 Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out),
542 Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out),
543 Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
544 Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out),
545 Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out),
546 Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out),
547 Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
548 Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
549 Instruction::Not(it) => Not::format(f, &it.input, &it.out),
550 Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
551 Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
552 Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
553 Instruction::Clamp {
554 input,
555 min_value,
556 max_value,
557 out,
558 } => Clamp::format(f, input, min_value, max_value, out),
559 Instruction::IsNan(it) => IsNan::format(f, &it.input, &it.out),
560 Instruction::IsInf(it) => IsInf::format(f, &it.input, &it.out),
561 Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
562 Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
563 Instruction::ThreadFence => f.write_str("__threadfence();\n"),
564 Instruction::Round(it) => Round::format(f, &it.input, &it.out),
565 Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
566 Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
567 Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
568 Instruction::SliceLength { input, out } => {
569 let out = out.fmt_left();
570 writeln!(f, "{out} = {input}_length;")
571 }
572 Instruction::ConstLength { length, out } => {
573 let out = out.fmt_left();
574 writeln!(f, "{out} = {length};")
575 }
576 Instruction::Warp(it) => write!(f, "{it}"),
577 Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
578 Instruction::Wmma(it) => write!(f, "{it}"),
579 Instruction::Bitcast(UnaryInstruction { input, out }) => {
580 let qualifier = out.const_qualifier();
581 let input_item = input.item();
582 let out_item = out.item();
583
584 if out_item.elem.size() * out_item.vectorization
585 != input.item().elem.size() * input.item().vectorization
586 {
587 panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
588 } else {
589 let out = out.fmt_left();
590 let addr_space = D::address_space_for_variable(input);
591 writeln!(
592 f,
593 "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
594 )
595 }
596 }
597 Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
598 D::compile_atomic_add(f, lhs, rhs, out)
599 }
600 Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
601 D::compile_atomic_and(f, lhs, rhs, out)
602 }
603 Instruction::AtomicCAS {
604 input,
605 cmp,
606 val,
607 out,
608 } => D::compile_atomic_cas(f, input, cmp, val, out),
609 Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
610 D::compile_atomic_load(f, input, out)
611 }
612 Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
613 D::compile_atomic_max(f, lhs, rhs, out)
614 }
615 Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
616 D::compile_atomic_min(f, lhs, rhs, out)
617 }
618 Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
619 D::compile_atomic_or(f, lhs, rhs, out)
620 }
621 Instruction::AtomicStore(UnaryInstruction { input, out }) => {
622 D::compile_atomic_store(f, input, out)
623 }
624 Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
625 D::compile_atomic_sub(f, lhs, rhs, out)
626 }
627 Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
628 D::compile_atomic_swap(f, lhs, rhs, out)
629 }
630 Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
631 D::compile_atomic_xor(f, lhs, rhs, out)
632 }
633 Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
634 Instruction::Neg(UnaryInstruction { input, out }) => {
635 let out = out.fmt_left();
636 writeln!(f, "{out} = -{input};")
637 }
638 Instruction::Normalize(inst) => {
639 Normalize::<D, InverseSqrt>::format(f, &inst.input, &inst.out)
640 }
641 Instruction::FastNormalize(inst) => {
642 Normalize::<D, FastInverseSqrt>::format(f, &inst.input, &inst.out)
643 }
644 Instruction::Magnitude(inst) => Magnitude::<D, Sqrt>::format(f, &inst.input, &inst.out),
645 Instruction::FastMagnitude(inst) => {
646 Magnitude::<D, FastSqrt>::format(f, &inst.input, &inst.out)
647 }
648 Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
649 Instruction::VecInit { inputs, out } => {
650 let item = out.item();
651 let inputs = inputs
652 .iter()
653 .map(|input| format!("{input}"))
654 .collect::<Vec<_>>();
655 let out = out.fmt_left();
656 writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
657 }
658 Instruction::Printf {
659 format_string,
660 args,
661 } => D::compile_instruction_printf(f, format_string, args),
662 Instruction::Comment { content } => {
663 if content.contains('\n') {
664 writeln!(f, "/* {content} */")
665 } else {
666 writeln!(f, "// {content}")
667 }
668 }
669 Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
670 Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
671 Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
672 Instruction::ProxyAsyncToSharedFence => {
673 writeln!(
674 f,
675 "cuda::device::experimental::fence_proxy_async_shared_cta();"
676 )
677 }
678 Instruction::BulkCommitGroup => writeln!(
679 f,
680 "cuda::device::experimental::cp_async_bulk_commit_group();"
681 ),
682 Instruction::BulkWaitGroup { max_pending } => writeln!(
683 f,
684 "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
685 ),
686 Instruction::BulkWaitGroupRead { max_pending } => writeln!(
687 f,
688 "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
689 ),
690 Instruction::TmaReplacePointer {
691 buffer,
692 offset,
693 tensor_map,
694 out,
695 } => {
696 let pos = Variable::<D>::UnitPos;
697 writeln!(f, "__shared__ alignas(128) CUtensorMap {out};")?;
698 writeln!(
699 f,
700 "
701if({pos} == 0) {{
702 {out} = {tensor_map};
703 tensormap_replace_global_address({out}, &{buffer}[{offset}]);
704}}"
705 )?;
706 writeln!(f, "__syncthreads();")
707 }
708 Instruction::MemCopyAsyncTensorSharedToGlobal {
709 smem_buffer,
710 smem_offset,
711 tensor_map,
712 indices,
713 } => {
714 let rank = indices.len();
715 let smem_ptr = smem_buffer.fmt_ptr();
716 let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
717 let _ = write!(s, "{it}, ");
718 s
719 });
720 writeln!(
721 f,
722 "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
723 )
724 }
725 Instruction::SpecialCast(UnaryInstruction { input, out }) => {
726 #[cfg(not(feature = "cuda"))]
728 {
729 let _ = (input, out);
730 unimplemented!("FP8/FP6/FP4 casting isn't supported outside of CUDA");
731 }
732 #[cfg(feature = "cuda")]
733 crate::cuda::convert::special_cast::<D>(f, input, out)
734 }
735 }
736 }
737}
738
739struct Fma<D: Dialect> {
740 _dialect: PhantomData<D>,
741}
742
743impl<D: Dialect> Fma<D> {
744 fn format(
745 f: &mut core::fmt::Formatter<'_>,
746 a: &Variable<D>,
747 b: &Variable<D>,
748 c: &Variable<D>,
749 out: &Variable<D>,
750 ) -> core::fmt::Result {
751 let out_item = out.item();
752 let num = out_item.vectorization;
753
754 let out = out.fmt_left();
755 if num == 1 {
756 writeln!(f, "{out} = fma({a}, {b}, {c});")
757 } else {
758 writeln!(f, "{out} = {out_item}{{")?;
759
760 for i in 0..num {
761 let ai = a.index(i);
762 let bi = b.index(i);
763 let ci = c.index(i);
764
765 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
766 }
767 f.write_str("};\n")
768 }
769 }
770}
771
772struct Clamp<D: Dialect> {
773 _dialect: PhantomData<D>,
774}
775
776impl<D: Dialect> Clamp<D> {
777 fn format(
778 f: &mut core::fmt::Formatter<'_>,
779 input: &Variable<D>,
780 min_value: &Variable<D>,
781 max_value: &Variable<D>,
782 out: &Variable<D>,
783 ) -> core::fmt::Result {
784 let out_item = out.item();
785 if out.item().vectorization == 1 {
786 let out = out.fmt_left();
787 write!(f, "{out} = ")?;
788 Self::format_scalar(f, *input, *min_value, *max_value, out_item)?;
789 f.write_str(";\n")
790 } else {
791 Self::unroll_vec(f, input, min_value, max_value, out)
792 }
793 }
794
795 fn format_scalar(
796 f: &mut Formatter<'_>,
797 input: impl Component<D>,
798 min_value: impl Component<D>,
799 max_value: impl Component<D>,
800 item: Item<D>,
801 ) -> std::fmt::Result {
802 D::compile_instruction_max_function_name(f, item)?;
803 write!(f, "({min_value}, ")?;
804 D::compile_instruction_min_function_name(f, item)?;
805 write!(f, "({max_value}, {input}))")
806 }
807
808 fn unroll_vec(
809 f: &mut core::fmt::Formatter<'_>,
810 input: &Variable<D>,
811 min_value: &Variable<D>,
812 max_value: &Variable<D>,
813 out: &Variable<D>,
814 ) -> std::fmt::Result {
815 let optimized = Variable::optimized_args([*input, *min_value, *max_value, *out]);
816 let [input, min_value, max_value, out_optimized] = optimized.args;
817
818 let item_out_original = out.item();
819 let item_out_optimized = out_optimized.item();
820
821 let index = match optimized.optimization_factor {
822 Some(factor) => item_out_original.vectorization / factor,
823 None => item_out_optimized.vectorization,
824 };
825
826 let mut write_op = |input: &Variable<D>,
827 min_value: &Variable<D>,
828 max_value: &Variable<D>,
829 out: &Variable<D>,
830 item_out: Item<D>| {
831 let out = out.fmt_left();
832 writeln!(f, "{out} = {item_out}{{")?;
833 for i in 0..index {
834 let inputi = input.index(i);
835 let min_valuei = min_value.index(i);
836 let max_valuei = max_value.index(i);
837
838 Self::format_scalar(f, inputi, min_valuei, max_valuei, item_out)?;
839 f.write_str(", ")?;
840 }
841
842 f.write_str("};\n")
843 };
844
845 if item_out_original == item_out_optimized {
846 write_op(&input, &min_value, &max_value, out, item_out_optimized)
847 } else {
848 let out_tmp = Variable::tmp(item_out_optimized);
849 write_op(&input, &min_value, &max_value, &out_tmp, item_out_optimized)?;
850 let addr_space = D::address_space_for_variable(out);
851 let out = out.fmt_left();
852
853 writeln!(
854 f,
855 "{out} = reinterpret_cast<{addr_space}{item_out_original}&>({out_tmp});\n"
856 )?;
857
858 Ok(())
859 }
860 }
861}
862
863struct Remainder<D: Dialect> {
864 _dialect: PhantomData<D>,
865}
866
867impl<D: Dialect> Remainder<D> {
868 fn format(
869 f: &mut core::fmt::Formatter<'_>,
870 lhs: &Variable<D>,
871 rhs: &Variable<D>,
872 out: &Variable<D>,
873 ) -> core::fmt::Result {
874 let floor = |elem| {
875 let prefix = match elem {
876 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
877 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
878 _ => "",
879 };
880 format!("{prefix}floor")
881 };
882
883 if out.item().vectorization == 1 {
884 let floor = floor(out.elem());
885
886 let out = out.fmt_left();
887 return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
888 }
889
890 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
891 let [lhs, rhs, out_optimized] = optimized.args;
892
893 let item_out_original = out.item();
894 let item_out_optimized = out_optimized.item();
895
896 let index = match optimized.optimization_factor {
897 Some(factor) => item_out_original.vectorization / factor,
898 None => item_out_optimized.vectorization,
899 };
900
901 let floor = floor(*item_out_optimized.elem());
902
903 let mut write_op =
904 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
905 let out = out.fmt_left();
906 writeln!(f, "{out} = {item_out}{{")?;
907 for i in 0..index {
908 let lhsi = lhs.index(i);
909 let rhsi = rhs.index(i);
910
911 writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
912 f.write_str(", ")?;
913 }
914
915 f.write_str("};\n")
916 };
917
918 if item_out_original == item_out_optimized {
919 write_op(&lhs, &rhs, out, item_out_optimized)
920 } else {
921 let out_tmp = Variable::tmp(item_out_optimized);
922
923 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
924
925 let addr_space = D::address_space_for_variable(&out_tmp);
926 let qualifier = out.const_qualifier();
927 let out = out.fmt_left();
928
929 writeln!(
930 f,
931 "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
932 )?;
933
934 Ok(())
935 }
936 }
937}
938
939struct Magnitude<D: Dialect, S: FunctionFmt<D>> {
940 _dialect: PhantomData<D>,
941 _sqrt: PhantomData<S>,
942}
943
944impl<D: Dialect, S: FunctionFmt<D>> Magnitude<D, S> {
945 fn format(
946 f: &mut core::fmt::Formatter<'_>,
947 input: &Variable<D>,
948 out: &Variable<D>,
949 ) -> core::fmt::Result {
950 let num = input.item().vectorization;
951 let elem = input.elem();
952
953 let mag = format!("{out}_mag");
954
955 writeln!(f, "{} {mag} = 0.0;", out.item())?;
956
957 for i in 0..num {
958 let input_i = input.index(i);
959 writeln!(f, "{mag} += {input_i} * {input_i};")?;
960 }
961
962 let out = out.fmt_left();
963 write!(f, "{out} = ")?;
964 S::format_unary(f, &mag, elem)?;
965 f.write_str(";\n")
966 }
967}
968
969struct Normalize<D: Dialect, InvS: FunctionFmt<D>> {
970 _dialect: PhantomData<D>,
971 _rsqrt: PhantomData<InvS>,
972}
973
974impl<D: Dialect, InvS: FunctionFmt<D>> Normalize<D, InvS> {
975 fn format(
976 f: &mut core::fmt::Formatter<'_>,
977 input: &Variable<D>,
978 out: &Variable<D>,
979 ) -> core::fmt::Result {
980 let num = input.item().vectorization;
981 let elem = input.elem();
982 let norm = format!("{out}_norm");
983
984 let out_item = out.item();
985 let out = out.fmt_left();
986 writeln!(f, "{elem} {norm} = 0.0;")?;
987
988 for i in 0..num {
989 let input_i = input.index(i);
990 writeln!(f, "{norm} += {input_i} * {input_i};")?;
991 }
992
993 write!(f, "{norm} = ")?;
994 InvS::format_unary(f, &norm, elem)?;
995 f.write_str(";\n")?;
996
997 if num == 1 {
998 writeln!(f, "{out} = {input} * {norm};")
999 } else {
1000 write!(f, "{out} = {out_item}{{")?;
1001 for i in 0..num {
1002 let input_i = input.index(i);
1003
1004 writeln!(f, "{input_i} * {norm},")?;
1005 }
1006
1007 f.write_str("};\n")
1008 }
1009 }
1010}
1011
1012struct Dot<D: Dialect> {
1013 _dialect: PhantomData<D>,
1014}
1015
1016impl<D: Dialect> Dot<D> {
1017 fn format(
1018 f: &mut core::fmt::Formatter<'_>,
1019 lhs: &Variable<D>,
1020 rhs: &Variable<D>,
1021 out: &Variable<D>,
1022 ) -> core::fmt::Result {
1023 let num = lhs.item().vectorization;
1024
1025 let muls = (0..num)
1026 .map(|i| {
1027 let lhs_i = lhs.index(i);
1028 let rhs_i = rhs.index(i);
1029 format!("{lhs_i} * {rhs_i}")
1030 })
1031 .collect::<Vec<_>>();
1032
1033 let out = out.fmt_left();
1034 writeln!(f, "{out} = {};", muls.join(" + "))
1035 }
1036}
1037
1038struct EnsureBoolArg<'a, V: Display, D: Dialect> {
1039 var: &'a V,
1040 elem: &'a Elem<D>,
1041}
1042
1043impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
1044 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1045 if self.elem != &Elem::Bool {
1046 write!(f, "bool({})", self.var)
1047 } else {
1048 write!(f, "{}", self.var)
1049 }
1050 }
1051}