1use crate::shared::FmtLeft;
2
3use super::{
4 Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5 barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8 borrow::Cow,
9 fmt::{Display, Write},
10 marker::PhantomData,
11};
12
/// Identifier emitted for the metadata array, indexed as `info[...]` in the
/// generated source.
pub(crate) const INFO_NAME: &str = "info";
/// Identifier emitted for static metadata, accessed as `static_info.x[...]`
/// when metadata is split.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
16#[derive(Debug, Clone)]
17pub struct BinaryInstruction<D: Dialect> {
18 pub lhs: Variable<D>,
19 pub rhs: Variable<D>,
20 pub out: Variable<D>,
21}
22
23#[derive(Debug, Clone)]
24pub struct IndexInstruction<D: Dialect> {
25 pub list: Variable<D>,
26 pub index: Variable<D>,
27 pub line_size: u32,
28 pub out: Variable<D>,
29}
30
31#[derive(Debug, Clone)]
32pub struct IndexAssignInstruction<D: Dialect> {
33 pub index: Variable<D>,
34 pub value: Variable<D>,
35 pub line_size: u32,
36 pub out: Variable<D>,
37}
38
39#[derive(Debug, Clone)]
40pub struct UnaryInstruction<D: Dialect> {
41 pub input: Variable<D>,
42 pub out: Variable<D>,
43}
44
45#[derive(Debug, Clone)]
46pub enum Instruction<D: Dialect> {
47 Metadata {
48 info_offset: Variable<D>,
49 split_meta: bool,
50 out: Variable<D>,
51 },
52 ExtendedMetadata {
53 info_offset: Variable<D>,
54 dim: Variable<D>,
55 split_meta: bool,
56 static_offset: u32,
57 out: Variable<D>,
58 },
59 ConstLength {
60 length: u32,
61 out: Variable<D>,
62 },
63 SliceLength {
64 input: Variable<D>,
65 out: Variable<D>,
66 },
67 DeclareVariable {
68 var: Variable<D>,
69 },
70 Modulo(BinaryInstruction<D>),
71 Remainder(BinaryInstruction<D>),
72 Add(BinaryInstruction<D>),
73 Fma {
74 a: Variable<D>,
75 b: Variable<D>,
76 c: Variable<D>,
77 out: Variable<D>,
78 },
79 Div(BinaryInstruction<D>),
80 Mul(BinaryInstruction<D>),
81 Sub(BinaryInstruction<D>),
82 HiMul(BinaryInstruction<D>),
83 Index(IndexInstruction<D>),
84 IndexAssign(IndexAssignInstruction<D>),
85 Assign(UnaryInstruction<D>),
86 SpecialCast(UnaryInstruction<D>),
87 RangeLoop {
88 i: Variable<D>,
89 start: Variable<D>,
90 end: Variable<D>,
91 step: Option<Variable<D>>,
92 inclusive: bool,
93 instructions: Vec<Self>,
94 },
95 VecInit {
96 inputs: Vec<Variable<D>>,
97 out: Variable<D>,
98 },
99 Loop {
100 instructions: Vec<Self>,
101 },
102 If {
103 cond: Variable<D>,
104 instructions: Vec<Self>,
105 },
106 IfElse {
107 cond: Variable<D>,
108 instructions_if: Vec<Self>,
109 instructions_else: Vec<Self>,
110 },
111 Select {
112 cond: Variable<D>,
113 then: Variable<D>,
114 or_else: Variable<D>,
115 out: Variable<D>,
116 },
117 Switch {
118 value: Variable<D>,
119 instructions_default: Vec<Self>,
120 instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
121 },
122 Slice {
123 input: Variable<D>,
124 start: Variable<D>,
125 end: Variable<D>,
126 out: Variable<D>,
127 },
128 CheckedSlice {
129 input: Variable<D>,
130 start: Variable<D>,
131 end: Variable<D>,
132 out: Variable<D>,
133 len: Variable<D>,
134 },
135 ReinterpretSlice {
136 input: Variable<D>,
137 line_size: u32,
138 out: Variable<D>,
139 },
140 Return,
141 Break,
142 Equal(BinaryInstruction<D>),
143 NotEqual(BinaryInstruction<D>),
144 Lower(BinaryInstruction<D>),
145 Greater(BinaryInstruction<D>),
146 LowerEqual(BinaryInstruction<D>),
147 GreaterEqual(BinaryInstruction<D>),
148 Erf(UnaryInstruction<D>),
149 BitwiseOr(BinaryInstruction<D>),
150 BitwiseAnd(BinaryInstruction<D>),
151 BitwiseXor(BinaryInstruction<D>),
152 CountBits(UnaryInstruction<D>),
153 ReverseBits(UnaryInstruction<D>),
154 ShiftLeft(BinaryInstruction<D>),
155 ShiftRight(BinaryInstruction<D>),
156 BitwiseNot(UnaryInstruction<D>),
157 LeadingZeros(UnaryInstruction<D>),
158 FindFirstSet(UnaryInstruction<D>),
159 Abs(UnaryInstruction<D>),
160 Exp(UnaryInstruction<D>),
161 Log(UnaryInstruction<D>),
162 Log1p(UnaryInstruction<D>),
163 Cos(UnaryInstruction<D>),
164 Sin(UnaryInstruction<D>),
165 Tanh(UnaryInstruction<D>),
166 Powf(BinaryInstruction<D>),
167 Sqrt(UnaryInstruction<D>),
168 Min(BinaryInstruction<D>),
169 Max(BinaryInstruction<D>),
170 Not(UnaryInstruction<D>),
171 Or(BinaryInstruction<D>),
172 And(BinaryInstruction<D>),
173 Clamp {
174 input: Variable<D>,
175 min_value: Variable<D>,
176 max_value: Variable<D>,
177 out: Variable<D>,
178 },
179 SyncThreads,
180 SyncWarp,
181 ThreadFence,
182 ProxySharedFence,
183 BulkCommitGroup,
184 BulkWaitGroup {
185 max_pending: u32,
186 },
187 BulkWaitGroupRead {
188 max_pending: u32,
189 },
190 Round(UnaryInstruction<D>),
191 Ceil(UnaryInstruction<D>),
192 Floor(UnaryInstruction<D>),
193 Warp(WarpInstruction<D>),
194 Wmma(WmmaInstruction<D>),
195 Bitcast(UnaryInstruction<D>),
196 AtomicLoad(UnaryInstruction<D>),
197 AtomicStore(UnaryInstruction<D>),
198 AtomicSwap(BinaryInstruction<D>),
199 AtomicAdd(BinaryInstruction<D>),
200 AtomicSub(BinaryInstruction<D>),
201 AtomicMax(BinaryInstruction<D>),
202 AtomicMin(BinaryInstruction<D>),
203 AtomicAnd(BinaryInstruction<D>),
204 AtomicOr(BinaryInstruction<D>),
205 AtomicXor(BinaryInstruction<D>),
206 AtomicCAS {
207 input: Variable<D>,
208 cmp: Variable<D>,
209 val: Variable<D>,
210 out: Variable<D>,
211 },
212 Neg(UnaryInstruction<D>),
213 Magnitude(UnaryInstruction<D>),
214 Normalize(UnaryInstruction<D>),
215 Dot(BinaryInstruction<D>),
216 Copy {
217 input: Variable<D>,
218 in_index: Variable<D>,
219 out: Variable<D>,
220 out_index: Variable<D>,
221 },
222 CopyBulk {
223 input: Variable<D>,
224 in_index: Variable<D>,
225 out: Variable<D>,
226 out_index: Variable<D>,
227 len: u32,
228 },
229 Printf {
230 format_string: String,
231 args: Vec<Variable<D>>,
232 },
233 Comment {
234 content: String,
235 },
236 Pipeline(PipelineOps<D>),
237 Barrier(BarrierOps<D>),
238 MemCopyAsyncTensorSharedToGlobal {
239 smem_buffer: Variable<D>,
240 smem_offset: Variable<D>,
241 tensor_map: Variable<D>,
242 indices: Vec<Variable<D>>,
243 },
244 Line {
245 file: Cow<'static, str>,
246 line: u32,
247 },
248}
249
250impl<D: Dialect> Display for Instruction<D> {
251 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 match self {
253 Instruction::Return => f.write_str("return;"),
254 Instruction::Break => f.write_str("break;"),
255 Instruction::DeclareVariable { var } => match var {
256 Variable::WmmaFragment { .. } => D::compile_wmma_fragment_declaration(f, var),
257 _ => {
258 let item = var.item();
259 writeln!(f, "{item} {var};")
260 }
261 },
262 Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
263 Instruction::Slice {
264 input,
265 start,
266 end,
267 out,
268 } => {
269 let item = out.item();
270 let addr_space = D::address_space_for_variable(input);
271 writeln!(f, "const uint {out}_length = {end} - {start};")?;
272 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
273 }
274 Instruction::CheckedSlice {
275 input,
276 start,
277 end,
278 out,
279 len,
280 } => {
281 let item = out.item();
282 let addr_space = D::address_space_for_variable(input);
283 writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
284 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
285 }
286 Instruction::ReinterpretSlice {
287 input,
288 line_size,
289 out,
290 } => {
291 let mut item = out.item();
292 item.vectorization = *line_size as usize;
293 let addr_space = D::address_space_for_variable(input);
294
295 writeln!(
296 f,
297 "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
298 )
299 }
300 Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
301 Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
302 Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
303 Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
304 Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
305 Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
306 Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
307 Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
308 Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
309 Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
310 Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
311 Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
312 Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
313 Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
314 Instruction::Index(it) => Index::format(f, &it.list, &it.index, &it.out, it.line_size),
315 Instruction::IndexAssign(it) => {
316 IndexAssign::format(f, &it.index, &it.value, &it.out, it.line_size)
317 }
318 Instruction::Copy {
319 input,
320 in_index,
321 out,
322 out_index,
323 } => {
324 writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
325 }
326 Instruction::CopyBulk {
327 input,
328 in_index,
329 out,
330 out_index,
331 len,
332 } => {
333 for i in 0..*len {
334 writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
335 }
336 Ok(())
337 }
338 Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
339 Instruction::RangeLoop {
340 i,
341 start,
342 end,
343 step,
344 inclusive,
345 instructions,
346 } => {
347 let increment = step
348 .map(|step| format!("{i} += {step}"))
349 .unwrap_or_else(|| format!("++{i}"));
350 let cmp = if *inclusive { "<=" } else { "<" };
351 let i_ty = i.item();
352
353 write!(
354 f,
355 "
356for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
357"
358 )?;
359 for instruction in instructions {
360 write!(f, "{instruction}")?;
361 }
362
363 f.write_str("}\n")
364 }
365 Instruction::Loop { instructions } => {
366 writeln!(f, "while (true) {{")?;
367 for i in instructions {
368 write!(f, "{i}")?;
369 }
370 f.write_str("}\n")
371 }
372 Instruction::If { cond, instructions } => {
373 writeln!(f, "if ({cond}) {{")?;
374 for i in instructions {
375 write!(f, "{i}")?;
376 }
377 f.write_str("}\n")
378 }
379 Instruction::IfElse {
380 cond,
381 instructions_if,
382 instructions_else,
383 } => {
384 writeln!(f, "if ({cond}) {{")?;
385 for i in instructions_if {
386 write!(f, "{i}")?;
387 }
388 f.write_str("} else {\n")?;
389 for i in instructions_else {
390 write!(f, "{i}")?;
391 }
392 f.write_str("}\n")
393 }
394 Instruction::Select {
395 cond,
396 then,
397 or_else,
398 out,
399 } => {
400 let item_or_else = or_else.item();
401 let item_then = then.item();
402 let item_out = out.item();
403
404 let vf_then = item_then.vectorization;
405 let vf_or_else = item_or_else.vectorization;
406 let vf_out = item_out.vectorization;
407 let vf_cond = cond.item().vectorization;
408
409 let item_out = out.item();
410 let cond_elem = cond.item().elem;
411 let out = out.fmt_left();
412
413 let should_broadcast =
414 vf_cond > 1 || item_out != item_or_else || item_out != item_then;
415
416 if should_broadcast {
417 let vf = usize::max(vf_cond, vf_out);
418 let vf = usize::max(vf, vf_then);
419 let vf = usize::max(vf, vf_or_else);
420
421 writeln!(f, "{out} = {item_out} {{")?;
422 for i in 0..vf {
423 let theni = then.index(i);
424 let or_elsei = or_else.index(i);
425 let condi = cond.index(i);
426 let condi = EnsureBoolArg {
427 var: &condi,
428 elem: &cond_elem,
429 };
430
431 writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
432 }
433
434 writeln!(f, "}};")
435 } else {
436 let cond = EnsureBoolArg {
437 var: &cond,
438 elem: &cond_elem,
439 };
440 writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
441 }
442 }
443 Instruction::Switch {
444 value,
445 instructions_default,
446 instructions_cases,
447 } => {
448 writeln!(f, "switch({value}) {{")?;
449 for (value, block) in instructions_cases {
450 write!(f, "case {value}:\n{{\n")?;
451 for i in block {
452 i.fmt(f)?;
453 }
454 f.write_str("break;\n}\n")?;
455 }
456 f.write_str("default:\n{")?;
457 for i in instructions_default {
458 i.fmt(f)?;
459 }
460 f.write_str("}\n}\n")
461 }
462 Instruction::Metadata {
463 info_offset,
464 split_meta,
465 out,
466 } => {
467 let out = out.fmt_left();
468 match *split_meta {
469 true => writeln!(f, "{out} = static_info.x[{info_offset}];"),
470 false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
471 }
472 }
473 Instruction::ExtendedMetadata {
474 info_offset,
475 dim,
476 split_meta,
477 static_offset,
478 out,
479 } => {
480 let out = out.fmt_left();
481 match *split_meta {
482 true => writeln!(
483 f,
484 "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
485 ),
486 false => writeln!(
487 f,
488 "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
489 ),
490 }
491 }
492 Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
493 Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
494 Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
495 Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
496 Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
497 Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
498 Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
499 Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
500 Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
501 Instruction::Log(it) => Log::format(f, &it.input, &it.out),
502 Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
503 Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
504 Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
505 Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
506 Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
507 Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
508 Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
509 Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
510 Instruction::Not(it) => Not::format(f, &it.input, &it.out),
511 Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
512 Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
513 Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
514 Instruction::Clamp {
515 input,
516 min_value,
517 max_value,
518 out,
519 } => Clamp::format(f, input, min_value, max_value, out),
520 Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
521 Instruction::SyncWarp => D::compile_instruction_sync_warp(f),
522 Instruction::ThreadFence => f.write_str("__threadfence();\n"),
523 Instruction::Round(it) => Round::format(f, &it.input, &it.out),
524 Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
525 Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
526 Instruction::SliceLength { input, out } => {
527 let out = out.fmt_left();
528 writeln!(f, "{out} = {input}_length;")
529 }
530 Instruction::ConstLength { length, out } => {
531 let out = out.fmt_left();
532 writeln!(f, "{out} = {length};")
533 }
534 Instruction::Warp(it) => write!(f, "{it}"),
535 Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
536 Instruction::Wmma(it) => write!(f, "{it}"),
537 Instruction::Bitcast(UnaryInstruction { input, out }) => {
538 let qualifier = out.const_qualifier();
539 let input_item = input.item();
540 let out_item = out.item();
541
542 if out_item.elem.size() * out_item.vectorization
543 != input.item().elem.size() * input.item().vectorization
544 {
545 panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
546 } else {
547 let out = out.fmt_left();
548 let addr_space = D::address_space_for_variable(input);
549 writeln!(
550 f,
551 "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
552 )
553 }
554 }
555 Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
556 D::compile_atomic_add(f, lhs, rhs, out)
557 }
558 Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
559 D::compile_atomic_and(f, lhs, rhs, out)
560 }
561 Instruction::AtomicCAS {
562 input,
563 cmp,
564 val,
565 out,
566 } => D::compile_atomic_cas(f, input, cmp, val, out),
567 Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
568 D::compile_atomic_load(f, input, out)
569 }
570 Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
571 D::compile_atomic_max(f, lhs, rhs, out)
572 }
573 Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
574 D::compile_atomic_min(f, lhs, rhs, out)
575 }
576 Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
577 D::compile_atomic_or(f, lhs, rhs, out)
578 }
579 Instruction::AtomicStore(UnaryInstruction { input, out }) => {
580 D::compile_atomic_store(f, input, out)
581 }
582 Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
583 D::compile_atomic_sub(f, lhs, rhs, out)
584 }
585 Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
586 D::compile_atomic_swap(f, lhs, rhs, out)
587 }
588 Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
589 D::compile_atomic_xor(f, lhs, rhs, out)
590 }
591 Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
592 Instruction::Neg(UnaryInstruction { input, out }) => {
593 let out = out.fmt_left();
594 writeln!(f, "{out} = -{input};")
595 }
596 Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out),
597 Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out),
598 Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
599 Instruction::VecInit { inputs, out } => {
600 let item = out.item();
601 let inputs = inputs
602 .iter()
603 .map(|input| format!("{input}"))
604 .collect::<Vec<_>>();
605 let out = out.fmt_left();
606 writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
607 }
608 Instruction::Printf {
609 format_string,
610 args,
611 } => D::compile_instruction_printf(f, format_string, args),
612 Instruction::Comment { content } => {
613 if content.contains('\n') {
614 writeln!(f, "/* {content} */")
615 } else {
616 writeln!(f, "// {content}")
617 }
618 }
619 Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
620 Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
621 Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
622 Instruction::ProxySharedFence => {
623 writeln!(
624 f,
625 "cuda::device::experimental::fence_proxy_async_shared_cta();"
626 )
627 }
628 Instruction::BulkCommitGroup => writeln!(
629 f,
630 "cuda::device::experimental::cp_async_bulk_commit_group();"
631 ),
632 Instruction::BulkWaitGroup { max_pending } => writeln!(
633 f,
634 "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
635 ),
636 Instruction::BulkWaitGroupRead { max_pending } => writeln!(
637 f,
638 "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
639 ),
640 Instruction::MemCopyAsyncTensorSharedToGlobal {
641 smem_buffer,
642 smem_offset,
643 tensor_map,
644 indices,
645 } => {
646 let rank = indices.len();
647 let smem_ptr = smem_buffer.fmt_ptr();
648 let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
649 let _ = write!(s, "{it}, ");
650 s
651 });
652 writeln!(
653 f,
654 "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr} + {smem_offset});"
655 )
656 }
657 Instruction::SpecialCast(UnaryInstruction { input, out }) => {
658 #[cfg(not(feature = "cuda"))]
660 {
661 let _ = (input, out);
662 unimplemented!("FP8/FP6/FP4 casting isn't supported outside of CUDA");
663 }
664 #[cfg(feature = "cuda")]
665 crate::cuda::convert::special_cast::<D>(f, input, out)
666 }
667 }
668 }
669}
670
671struct Fma<D: Dialect> {
672 _dialect: PhantomData<D>,
673}
674
675impl<D: Dialect> Fma<D> {
676 fn format(
677 f: &mut core::fmt::Formatter<'_>,
678 a: &Variable<D>,
679 b: &Variable<D>,
680 c: &Variable<D>,
681 out: &Variable<D>,
682 ) -> core::fmt::Result {
683 let out_item = out.item();
684 let num = out_item.vectorization;
685
686 let out = out.fmt_left();
687 if num == 1 {
688 writeln!(f, "{out} = fma({a}, {b}, {c});")
689 } else {
690 writeln!(f, "{out} = {out_item}{{")?;
691
692 for i in 0..num {
693 let ai = a.index(i);
694 let bi = b.index(i);
695 let ci = c.index(i);
696
697 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
698 }
699 f.write_str("};\n")
700 }
701 }
702}
703
704struct Clamp<D: Dialect> {
705 _dialect: PhantomData<D>,
706}
707
708impl<D: Dialect> Clamp<D> {
709 fn format(
710 f: &mut core::fmt::Formatter<'_>,
711 input: &Variable<D>,
712 min_value: &Variable<D>,
713 max_value: &Variable<D>,
714 out: &Variable<D>,
715 ) -> core::fmt::Result {
716 let input = input.optimized();
717 let min_value = min_value.optimized();
718 let max_value = max_value.optimized();
719 let out = out.optimized();
720 let out_item = out.item();
721 let num = out_item.vectorization;
722
723 let out_fmt = out.fmt_left();
724 if num == 1 {
725 writeln!(f, "{out_fmt} = ")?;
726 D::compile_instruction_max_function_name(f, out.item())?;
727 writeln!(f, "({min_value}, ")?;
728 D::compile_instruction_min_function_name(f, out.item())?;
729 writeln!(f, "({max_value}, {input}));")
730 } else {
731 writeln!(f, "{out_fmt} = {out_item}{{")?;
732 let mut item = out.item();
733 item.vectorization = 1;
734
735 for i in 0..num {
736 let inputi = input.index(i);
737 let mini = min_value.index(i);
738 let maxi = max_value.index(i);
739
740 D::compile_instruction_max_function_name(f, item)?;
741 writeln!(f, "({mini}, ")?;
742 D::compile_instruction_min_function_name(f, item)?;
743 writeln!(f, "({maxi}, {inputi})),")?;
744 }
745
746 f.write_str("};\n")
747 }
748 }
749}
750
751struct Remainder<D: Dialect> {
752 _dialect: PhantomData<D>,
753}
754
755impl<D: Dialect> Remainder<D> {
756 fn format(
757 f: &mut core::fmt::Formatter<'_>,
758 lhs: &Variable<D>,
759 rhs: &Variable<D>,
760 out: &Variable<D>,
761 ) -> core::fmt::Result {
762 let floor = |elem| {
763 let prefix = match elem {
764 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
765 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
766 _ => "",
767 };
768 format!("{prefix}floor")
769 };
770
771 if out.item().vectorization == 1 {
772 let floor = floor(out.elem());
773
774 let out = out.fmt_left();
775 return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
776 }
777
778 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
779 let [lhs, rhs, out_optimized] = optimized.args;
780
781 let item_out_original = out.item();
782 let item_out_optimized = out_optimized.item();
783
784 let index = match optimized.optimization_factor {
785 Some(factor) => item_out_original.vectorization / factor,
786 None => item_out_optimized.vectorization,
787 };
788
789 let floor = floor(*item_out_optimized.elem());
790
791 let mut write_op =
792 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
793 let out = out.fmt_left();
794 writeln!(f, "{out} = {item_out}{{")?;
795 for i in 0..index {
796 let lhsi = lhs.index(i);
797 let rhsi = rhs.index(i);
798
799 writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
800 f.write_str(", ")?;
801 }
802
803 f.write_str("};\n")
804 };
805
806 if item_out_original == item_out_optimized {
807 write_op(&lhs, &rhs, out, item_out_optimized)
808 } else {
809 let out_tmp = Variable::tmp(item_out_optimized);
810
811 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
812
813 let addr_space = D::address_space_for_variable(&out_tmp);
814 let qualifier = out.const_qualifier();
815 let out = out.fmt_left();
816
817 writeln!(
818 f,
819 "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
820 )?;
821
822 Ok(())
823 }
824 }
825}
826
827struct Magnitude<D: Dialect> {
828 _dialect: PhantomData<D>,
829}
830
831impl<D: Dialect> Magnitude<D> {
832 fn format(
833 f: &mut core::fmt::Formatter<'_>,
834 input: &Variable<D>,
835 out: &Variable<D>,
836 ) -> core::fmt::Result {
837 let num = input.item().vectorization;
838 let elem = input.elem();
839
840 let mag = format!("{out}_mag");
841
842 writeln!(f, "{} {mag} = 0.0;", out.item())?;
843
844 for i in 0..num {
845 let input_i = input.index(i);
846 writeln!(f, "{mag} += {input_i} * {input_i};")?;
847 }
848
849 let out = out.fmt_left();
850 write!(f, "{out} = ")?;
851 Sqrt::format_unary(f, &mag, elem)?;
852 f.write_str(";\n")
853 }
854}
855
856struct Normalize<D: Dialect> {
857 _dialect: PhantomData<D>,
858}
859
860impl<D: Dialect> Normalize<D> {
861 fn format(
862 f: &mut core::fmt::Formatter<'_>,
863 input: &Variable<D>,
864 out: &Variable<D>,
865 ) -> core::fmt::Result {
866 let num = input.item().vectorization;
867 let elem = input.elem();
868 let norm = format!("{out}_norm");
869
870 let out_item = out.item();
871 let out = out.fmt_left();
872 writeln!(f, "{elem} {norm} = 0.0;")?;
873
874 for i in 0..num {
875 let input_i = input.index(i);
876 writeln!(f, "{norm} += {input_i} * {input_i};")?;
877 }
878
879 write!(f, "{norm} = ")?;
880 Sqrt::format_unary(f, &norm, elem)?;
881 f.write_str(";\n")?;
882
883 if num == 1 {
884 writeln!(f, "{out} = {input} / {norm};")
885 } else {
886 write!(f, "{out} = {out_item}{{")?;
887 for i in 0..num {
888 let input_i = input.index(i);
889
890 writeln!(f, "{input_i} / {norm},")?;
891 }
892
893 f.write_str("};\n")
894 }
895 }
896}
897
898struct Dot<D: Dialect> {
899 _dialect: PhantomData<D>,
900}
901
902impl<D: Dialect> Dot<D> {
903 fn format(
904 f: &mut core::fmt::Formatter<'_>,
905 lhs: &Variable<D>,
906 rhs: &Variable<D>,
907 out: &Variable<D>,
908 ) -> core::fmt::Result {
909 let num = lhs.item().vectorization;
910
911 let muls = (0..num)
912 .map(|i| {
913 let lhs_i = lhs.index(i);
914 let rhs_i = rhs.index(i);
915 format!("{lhs_i} * {rhs_i}")
916 })
917 .collect::<Vec<_>>();
918
919 let out = out.fmt_left();
920 writeln!(f, "{out} = {};", muls.join(" + "))
921 }
922}
923
924struct EnsureBoolArg<'a, V: Display, D: Dialect> {
925 var: &'a V,
926 elem: &'a Elem<D>,
927}
928
929impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
930 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
931 if self.elem != &Elem::Bool {
932 write!(f, "bool({})", self.var)
933 } else {
934 write!(f, "{}", self.var)
935 }
936 }
937}