1use crate::shared::FmtLeft;
2
3use super::{
4 Component, Dialect, Elem, Item, Variable, WarpInstruction, WmmaInstruction,
5 barrier::BarrierOps, binary::*, pipeline::PipelineOps, unary::*,
6};
7use std::{
8 borrow::Cow,
9 fmt::{Display, Write},
10 marker::PhantomData,
11};
12
/// Identifier, in the emitted source, of the metadata object that the
/// `Metadata` and `ExtendedMetadata` instructions index with `[...]`.
pub(crate) const INFO_NAME: &str = "info";
/// Identifier, in the emitted source, of the static metadata object whose `.x`
/// member is indexed by `ExtendedMetadata` when `split_meta` is set.
pub(crate) const STATIC_INFO_NAME: &str = "static_info";
15
/// Operands of an instruction that takes two inputs and produces one result.
///
/// The `Display` implementation of [`Instruction`] forwards the three fields,
/// in the order `lhs`, `rhs`, `out`, to the matching formatter or dialect hook.
#[derive(Debug, Clone)]
pub struct BinaryInstruction<D: Dialect> {
    /// Left-hand operand.
    pub lhs: Variable<D>,
    /// Right-hand operand.
    pub rhs: Variable<D>,
    /// Variable receiving the result.
    pub out: Variable<D>,
}
22
/// Operands of an instruction that takes one input and produces one result.
#[derive(Debug, Clone)]
pub struct UnaryInstruction<D: Dialect> {
    /// The single operand.
    pub input: Variable<D>,
    /// Variable receiving the result.
    pub out: Variable<D>,
}
28
/// One statement of the emitted source text.
///
/// Every variant is rendered by the `Display` implementation of this type.
/// Most arithmetic, comparison and math variants simply delegate to a formatter
/// of the same name (for example `Add::format`), and atomics, `SyncThreads`
/// and `Printf` delegate to hooks on the dialect `D`.
#[derive(Debug, Clone)]
pub enum Instruction<D: Dialect> {
    /// Reads `info[info_offset]` into `out`, or `static_info.x[info_offset]`
    /// when `split_meta` is set.
    Metadata {
        info_offset: Variable<D>,
        split_meta: bool,
        out: Variable<D>,
    },
    /// Reads a metadata entry through one level of indirection:
    /// `info[info[info_offset] + dim]`, or
    /// `info[static_info.x[info_offset] + dim - static_offset]` when
    /// `split_meta` is set.
    ExtendedMetadata {
        info_offset: Variable<D>,
        dim: Variable<D>,
        split_meta: bool,
        static_offset: u32,
        out: Variable<D>,
    },
    /// Assigns the constant `length` to `out`.
    ConstLength {
        length: u32,
        out: Variable<D>,
    },
    /// Assigns `{input}_length` to `out`; the `Slice` and `CheckedSlice`
    /// variants declare a variable with that suffix for their output.
    SliceLength {
        input: Variable<D>,
        out: Variable<D>,
    },
    /// Declares `var` without an initializer.
    DeclareVariable {
        var: Variable<D>,
    },
    Modulo(BinaryInstruction<D>),
    Remainder(BinaryInstruction<D>),
    Add(BinaryInstruction<D>),
    /// Emits `fma(a, b, c)`, element by element when `out` is vectorized.
    Fma {
        a: Variable<D>,
        b: Variable<D>,
        c: Variable<D>,
        out: Variable<D>,
    },
    Div(BinaryInstruction<D>),
    Mul(BinaryInstruction<D>),
    Sub(BinaryInstruction<D>),
    HiMul(BinaryInstruction<D>),
    Index(BinaryInstruction<D>),
    IndexAssign(BinaryInstruction<D>),
    Assign(UnaryInstruction<D>),
    /// `for` loop over `i` from `start` up to `end`. `inclusive` selects `<=`
    /// instead of `<`; when `step` is `None` the increment is `++i`.
    RangeLoop {
        i: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        step: Option<Variable<D>>,
        inclusive: bool,
        instructions: Vec<Self>,
    },
    /// Assigns to `out` a brace initializer built from `inputs`.
    VecInit {
        inputs: Vec<Variable<D>>,
        out: Variable<D>,
    },
    /// `while (true)` loop around `instructions`.
    Loop {
        instructions: Vec<Self>,
    },
    /// `if (cond)` block.
    If {
        cond: Variable<D>,
        instructions: Vec<Self>,
    },
    /// `if (cond) … else …` block.
    IfElse {
        cond: Variable<D>,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    /// Ternary selection `cond ? then : or_else`, expanded element by element
    /// when `cond` is vectorized or the operand items differ from `out`'s.
    Select {
        cond: Variable<D>,
        then: Variable<D>,
        or_else: Variable<D>,
        out: Variable<D>,
    },
    /// `switch` statement; each entry of `instructions_cases` is a
    /// `(case value, body)` pair and every case body ends with `break;`.
    Switch {
        value: Variable<D>,
        instructions_default: Vec<Self>,
        instructions_cases: Vec<(Variable<D>, Vec<Self>)>,
    },
    /// Declares the pointer `out = input + start` together with
    /// `{out}_length = end - start`.
    Slice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
    },
    /// Same as `Slice`, except that the length is `min(len, end) - start`.
    CheckedSlice {
        input: Variable<D>,
        start: Variable<D>,
        end: Variable<D>,
        out: Variable<D>,
        len: Variable<D>,
    },
    /// Declares `out` as `input` reinterpreted as a pointer to `out`'s item
    /// with its vectorization replaced by `line_size`.
    ReinterpretSlice {
        input: Variable<D>,
        line_size: u32,
        out: Variable<D>,
    },
    /// Emits `return;`.
    Return,
    /// Emits `break;`.
    Break,
    Equal(BinaryInstruction<D>),
    NotEqual(BinaryInstruction<D>),
    Lower(BinaryInstruction<D>),
    Greater(BinaryInstruction<D>),
    LowerEqual(BinaryInstruction<D>),
    GreaterEqual(BinaryInstruction<D>),
    Erf(UnaryInstruction<D>),
    BitwiseOr(BinaryInstruction<D>),
    BitwiseAnd(BinaryInstruction<D>),
    BitwiseXor(BinaryInstruction<D>),
    CountBits(UnaryInstruction<D>),
    ReverseBits(UnaryInstruction<D>),
    ShiftLeft(BinaryInstruction<D>),
    ShiftRight(BinaryInstruction<D>),
    BitwiseNot(UnaryInstruction<D>),
    LeadingZeros(UnaryInstruction<D>),
    FindFirstSet(UnaryInstruction<D>),
    Abs(UnaryInstruction<D>),
    Exp(UnaryInstruction<D>),
    Log(UnaryInstruction<D>),
    Log1p(UnaryInstruction<D>),
    Cos(UnaryInstruction<D>),
    Sin(UnaryInstruction<D>),
    Tanh(UnaryInstruction<D>),
    Powf(BinaryInstruction<D>),
    Sqrt(UnaryInstruction<D>),
    Min(BinaryInstruction<D>),
    Max(BinaryInstruction<D>),
    Not(UnaryInstruction<D>),
    Or(BinaryInstruction<D>),
    And(BinaryInstruction<D>),
    /// Emits `max(min_value, min(max_value, input))`, element by element when
    /// `out` is vectorized.
    Clamp {
        input: Variable<D>,
        min_value: Variable<D>,
        max_value: Variable<D>,
        out: Variable<D>,
    },
    /// Rendered by the dialect hook `compile_instruction_sync_threads`.
    SyncThreads,
    /// Emits `__threadfence();`.
    ThreadFence,
    /// Emits `cuda::device::experimental::fence_proxy_async_shared_cta();`.
    ProxySharedFence,
    /// Emits `cuda::device::experimental::cp_async_bulk_commit_group();`.
    BulkCommitGroup,
    /// Emits `cuda::device::experimental::cp_async_bulk_wait_group<max_pending>();`.
    BulkWaitGroup {
        max_pending: u32,
    },
    /// Emits `cuda::device::experimental::cp_async_bulk_wait_group_read<max_pending>();`.
    BulkWaitGroupRead {
        max_pending: u32,
    },
    Round(UnaryInstruction<D>),
    Ceil(UnaryInstruction<D>),
    Floor(UnaryInstruction<D>),
    /// Rendered through the wrapped instruction's own `Display`.
    Warp(WarpInstruction<D>),
    /// Rendered through the wrapped instruction's own `Display`.
    Wmma(WmmaInstruction<D>),
    /// `reinterpret_cast` of `input` to `out`'s item. Formatting panics when
    /// `elem.size() * vectorization` differs between the two items.
    Bitcast(UnaryInstruction<D>),
    // The atomic variants below are all rendered by the matching
    // `compile_atomic_*` hook of the dialect.
    AtomicLoad(UnaryInstruction<D>),
    AtomicStore(UnaryInstruction<D>),
    AtomicSwap(BinaryInstruction<D>),
    AtomicAdd(BinaryInstruction<D>),
    AtomicSub(BinaryInstruction<D>),
    AtomicMax(BinaryInstruction<D>),
    AtomicMin(BinaryInstruction<D>),
    AtomicAnd(BinaryInstruction<D>),
    AtomicOr(BinaryInstruction<D>),
    AtomicXor(BinaryInstruction<D>),
    AtomicCAS {
        input: Variable<D>,
        cmp: Variable<D>,
        val: Variable<D>,
        out: Variable<D>,
    },
    /// Emits `out = -input;`.
    Neg(UnaryInstruction<D>),
    Magnitude(UnaryInstruction<D>),
    Normalize(UnaryInstruction<D>),
    Dot(BinaryInstruction<D>),
    /// Emits `out[out_index] = input[in_index];`.
    Copy {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
    },
    /// Emits `len` unrolled element copies:
    /// `out[out_index + i] = input[in_index + i];` for `i` in `0..len`.
    CopyBulk {
        input: Variable<D>,
        in_index: Variable<D>,
        out: Variable<D>,
        out_index: Variable<D>,
        len: u32,
    },
    /// Rendered by the dialect hook `compile_instruction_printf`.
    Printf {
        format_string: String,
        args: Vec<Variable<D>>,
    },
    /// Emitted as `// content`, or as `/* content */` when `content` contains
    /// a newline.
    Comment {
        content: String,
    },
    /// Rendered through the wrapped ops' own `Display`.
    Pipeline(PipelineOps<D>),
    /// Rendered through the wrapped ops' own `Display`.
    Barrier(BarrierOps<D>),
    /// Emits a call to
    /// `cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global`,
    /// where `rank` is `indices.len()` and the indices are passed in reverse
    /// order.
    MemCopyAsyncTensorSharedToGlobal {
        smem_buffer: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// Emits a `#line line "file"` directive.
    Line {
        file: Cow<'static, str>,
        line: u32,
    },
}
230
231impl<D: Dialect> Display for Instruction<D> {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 Instruction::Return => f.write_str("return;"),
235 Instruction::Break => f.write_str("break;"),
236 Instruction::DeclareVariable { var } => match var {
237 Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
238 _ => {
239 let item = var.item();
240 writeln!(f, "{item} {var};")
241 }
242 },
243 Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
244 Instruction::Slice {
245 input,
246 start,
247 end,
248 out,
249 } => {
250 let item = out.item();
251 let addr_space = D::address_space_for_variable(input);
252 writeln!(f, "const uint {out}_length = {end} - {start};")?;
253 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
254 }
255 Instruction::CheckedSlice {
256 input,
257 start,
258 end,
259 out,
260 len,
261 } => {
262 let item = out.item();
263 let addr_space = D::address_space_for_variable(input);
264 writeln!(f, "const uint {out}_length = min({len}, {end}) - {start};")?;
265 writeln!(f, "{addr_space}{item} *{out} = {input} + {start};")
266 }
267 Instruction::ReinterpretSlice {
268 input,
269 line_size,
270 out,
271 } => {
272 let mut item = out.item();
273 item.vectorization = *line_size as usize;
274 let addr_space = D::address_space_for_variable(input);
275
276 writeln!(
277 f,
278 "{addr_space}{item} *{out} = reinterpret_cast<{item}*>({input});"
279 )
280 }
281 Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
282 Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
283 Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
284 Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
285 Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
286 Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
287 Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
288 Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
289 Instruction::CountBits(it) => CountBits::format(f, &it.input, &it.out),
290 Instruction::ReverseBits(it) => ReverseBits::format(f, &it.input, &it.out),
291 Instruction::LeadingZeros(it) => LeadingZeros::format(f, &it.input, &it.out),
292 Instruction::FindFirstSet(it) => FindFirstSet::format(f, &it.input, &it.out),
293 Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
294 Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
295 Instruction::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out),
296 Instruction::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out),
297 Instruction::Copy {
298 input,
299 in_index,
300 out,
301 out_index,
302 } => {
303 writeln!(f, "{out}[{out_index}] = {input}[{in_index}];")
304 }
305 Instruction::CopyBulk {
306 input,
307 in_index,
308 out,
309 out_index,
310 len,
311 } => {
312 for i in 0..*len {
313 writeln!(f, "{out}[{out_index} + {i}] = {input}[{in_index} + {i}];")?;
314 }
315 Ok(())
316 }
317 Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
318 Instruction::RangeLoop {
319 i,
320 start,
321 end,
322 step,
323 inclusive,
324 instructions,
325 } => {
326 let increment = step
327 .map(|step| format!("{i} += {step}"))
328 .unwrap_or_else(|| format!("++{i}"));
329 let cmp = if *inclusive { "<=" } else { "<" };
330 let i_ty = i.item();
331
332 write!(
333 f,
334 "
335for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
336"
337 )?;
338 for instruction in instructions {
339 write!(f, "{instruction}")?;
340 }
341
342 f.write_str("}\n")
343 }
344 Instruction::Loop { instructions } => {
345 writeln!(f, "while (true) {{")?;
346 for i in instructions {
347 write!(f, "{i}")?;
348 }
349 f.write_str("}\n")
350 }
351 Instruction::If { cond, instructions } => {
352 writeln!(f, "if ({cond}) {{")?;
353 for i in instructions {
354 write!(f, "{i}")?;
355 }
356 f.write_str("}\n")
357 }
358 Instruction::IfElse {
359 cond,
360 instructions_if,
361 instructions_else,
362 } => {
363 writeln!(f, "if ({cond}) {{")?;
364 for i in instructions_if {
365 write!(f, "{i}")?;
366 }
367 f.write_str("} else {\n")?;
368 for i in instructions_else {
369 write!(f, "{i}")?;
370 }
371 f.write_str("}\n")
372 }
373 Instruction::Select {
374 cond,
375 then,
376 or_else,
377 out,
378 } => {
379 let item_or_else = or_else.item();
380 let item_then = then.item();
381 let item_out = out.item();
382
383 let vf_then = item_then.vectorization;
384 let vf_or_else = item_or_else.vectorization;
385 let vf_out = item_out.vectorization;
386 let vf_cond = cond.item().vectorization;
387
388 let item_out = out.item();
389 let cond_elem = cond.item().elem;
390 let out = out.fmt_left();
391
392 let should_broadcast =
393 vf_cond > 1 || item_out != item_or_else || item_out != item_then;
394
395 if should_broadcast {
396 let vf = usize::max(vf_cond, vf_out);
397 let vf = usize::max(vf, vf_then);
398 let vf = usize::max(vf, vf_or_else);
399
400 writeln!(f, "{out} = {item_out} {{")?;
401 for i in 0..vf {
402 let theni = then.index(i);
403 let or_elsei = or_else.index(i);
404 let condi = cond.index(i);
405 let condi = EnsureBoolArg {
406 var: &condi,
407 elem: &cond_elem,
408 };
409
410 writeln!(f, "({condi}) ? {theni} : {or_elsei},")?;
411 }
412
413 writeln!(f, "}};")
414 } else {
415 let cond = EnsureBoolArg {
416 var: &cond,
417 elem: &cond_elem,
418 };
419 writeln!(f, "{out} = ({cond}) ? {then} : {or_else};")
420 }
421 }
422 Instruction::Switch {
423 value,
424 instructions_default,
425 instructions_cases,
426 } => {
427 writeln!(f, "switch({value}) {{")?;
428 for (value, block) in instructions_cases {
429 write!(f, "case {value}:\n{{\n")?;
430 for i in block {
431 i.fmt(f)?;
432 }
433 f.write_str("break;\n}\n")?;
434 }
435 f.write_str("default:\n{")?;
436 for i in instructions_default {
437 i.fmt(f)?;
438 }
439 f.write_str("}\n}\n")
440 }
441 Instruction::Metadata {
442 info_offset,
443 split_meta,
444 out,
445 } => {
446 let out = out.fmt_left();
447 match *split_meta {
448 true => writeln!(f, "{out} = static_info.x[{info_offset}];"),
449 false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
450 }
451 }
452 Instruction::ExtendedMetadata {
453 info_offset,
454 dim,
455 split_meta,
456 static_offset,
457 out,
458 } => {
459 let out = out.fmt_left();
460 match *split_meta {
461 true => writeln!(
462 f,
463 "{out} = {INFO_NAME}[{STATIC_INFO_NAME}.x[{info_offset}] + {dim} - {static_offset}];"
464 ),
465 false => writeln!(
466 f,
467 "{out} = {INFO_NAME}[{INFO_NAME}[{info_offset}] + {dim}];"
468 ),
469 }
470 }
471 Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
472 Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
473 Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
474 Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
475 Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
476 Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
477 Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
478 Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
479 Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
480 Instruction::Log(it) => Log::format(f, &it.input, &it.out),
481 Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
482 Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
483 Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
484 Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
485 Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
486 Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
487 Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
488 Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
489 Instruction::Not(it) => Not::format(f, &it.input, &it.out),
490 Instruction::BitwiseNot(it) => BitwiseNot::format(f, &it.input, &it.out),
491 Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
492 Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
493 Instruction::Clamp {
494 input,
495 min_value,
496 max_value,
497 out,
498 } => Clamp::format(f, input, min_value, max_value, out),
499 Instruction::SyncThreads => D::compile_instruction_sync_threads(f),
500 Instruction::ThreadFence => f.write_str("__threadfence();\n"),
501 Instruction::Round(it) => Round::format(f, &it.input, &it.out),
502 Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
503 Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
504 Instruction::SliceLength { input, out } => {
505 let out = out.fmt_left();
506 writeln!(f, "{out} = {input}_length;")
507 }
508 Instruction::ConstLength { length, out } => {
509 let out = out.fmt_left();
510 writeln!(f, "{out} = {length};")
511 }
512 Instruction::Warp(it) => write!(f, "{it}"),
513 Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
514 Instruction::Wmma(it) => write!(f, "{it}"),
515 Instruction::Bitcast(UnaryInstruction { input, out }) => {
516 let qualifier = out.const_qualifier();
517 let input_item = input.item();
518 let out_item = out.item();
519
520 if out_item.elem.size() * out_item.vectorization
521 != input.item().elem.size() * input.item().vectorization
522 {
523 panic!("Unsupported type for bitcasting {out_item:?} from {input_item:?}");
524 } else {
525 let out = out.fmt_left();
526 let addr_space = D::address_space_for_variable(input);
527 writeln!(
528 f,
529 "{out} = reinterpret_cast<{addr_space}{out_item}{qualifier}&>({input});"
530 )
531 }
532 }
533 Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => {
534 D::compile_atomic_add(f, lhs, rhs, out)
535 }
536 Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => {
537 D::compile_atomic_and(f, lhs, rhs, out)
538 }
539 Instruction::AtomicCAS {
540 input,
541 cmp,
542 val,
543 out,
544 } => D::compile_atomic_cas(f, input, cmp, val, out),
545 Instruction::AtomicLoad(UnaryInstruction { input, out }) => {
546 D::compile_atomic_load(f, input, out)
547 }
548 Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => {
549 D::compile_atomic_max(f, lhs, rhs, out)
550 }
551 Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => {
552 D::compile_atomic_min(f, lhs, rhs, out)
553 }
554 Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => {
555 D::compile_atomic_or(f, lhs, rhs, out)
556 }
557 Instruction::AtomicStore(UnaryInstruction { input, out }) => {
558 D::compile_atomic_store(f, input, out)
559 }
560 Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => {
561 D::compile_atomic_sub(f, lhs, rhs, out)
562 }
563 Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => {
564 D::compile_atomic_swap(f, lhs, rhs, out)
565 }
566 Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => {
567 D::compile_atomic_xor(f, lhs, rhs, out)
568 }
569 Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
570 Instruction::Neg(UnaryInstruction { input, out }) => {
571 let out = out.fmt_left();
572 writeln!(f, "{out} = -{input};")
573 }
574 Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out),
575 Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out),
576 Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
577 Instruction::VecInit { inputs, out } => {
578 let item = out.item();
579 let inputs = inputs
580 .iter()
581 .map(|input| format!("{input}"))
582 .collect::<Vec<_>>();
583 let out = out.fmt_left();
584 writeln!(f, "{out} = {item}{{{}}};", inputs.join(","))
585 }
586 Instruction::Printf {
587 format_string,
588 args,
589 } => D::compile_instruction_printf(f, format_string, args),
590 Instruction::Comment { content } => {
591 if content.contains('\n') {
592 writeln!(f, "/* {content} */")
593 } else {
594 writeln!(f, "// {content}")
595 }
596 }
597 Instruction::Pipeline(pipeline_ops) => write!(f, "{pipeline_ops}"),
598 Instruction::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
599 Instruction::Line { file, line } => writeln!(f, "#line {line} \"{file}\""),
600 Instruction::ProxySharedFence => {
601 writeln!(
602 f,
603 "cuda::device::experimental::fence_proxy_async_shared_cta();"
604 )
605 }
606 Instruction::BulkCommitGroup => writeln!(
607 f,
608 "cuda::device::experimental::cp_async_bulk_commit_group();"
609 ),
610 Instruction::BulkWaitGroup { max_pending } => writeln!(
611 f,
612 "cuda::device::experimental::cp_async_bulk_wait_group<{max_pending}>();"
613 ),
614 Instruction::BulkWaitGroupRead { max_pending } => writeln!(
615 f,
616 "cuda::device::experimental::cp_async_bulk_wait_group_read<{max_pending}>();"
617 ),
618 Instruction::MemCopyAsyncTensorSharedToGlobal {
619 smem_buffer,
620 tensor_map,
621 indices,
622 } => {
623 let rank = indices.len();
624 let smem_ptr = smem_buffer.fmt_ptr();
625 let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
626 let _ = write!(s, "{it}, ");
627 s
628 });
629 writeln!(
630 f,
631 "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_shared_to_global(&{tensor_map}, {indices} {smem_ptr});"
632 )
633 }
634 }
635 }
636}
637
638struct Fma<D: Dialect> {
639 _dialect: PhantomData<D>,
640}
641
642impl<D: Dialect> Fma<D> {
643 fn format(
644 f: &mut core::fmt::Formatter<'_>,
645 a: &Variable<D>,
646 b: &Variable<D>,
647 c: &Variable<D>,
648 out: &Variable<D>,
649 ) -> core::fmt::Result {
650 let out_item = out.item();
651 let num = out_item.vectorization;
652
653 let out = out.fmt_left();
654 if num == 1 {
655 writeln!(f, "{out} = fma({a}, {b}, {c});")
656 } else {
657 writeln!(f, "{out} = {out_item}{{")?;
658
659 for i in 0..num {
660 let ai = a.index(i);
661 let bi = b.index(i);
662 let ci = c.index(i);
663
664 writeln!(f, "fma({ai}, {bi}, {ci}),")?;
665 }
666 f.write_str("};\n")
667 }
668 }
669}
670
671struct Clamp<D: Dialect> {
672 _dialect: PhantomData<D>,
673}
674
675impl<D: Dialect> Clamp<D> {
676 fn format(
677 f: &mut core::fmt::Formatter<'_>,
678 input: &Variable<D>,
679 min_value: &Variable<D>,
680 max_value: &Variable<D>,
681 out: &Variable<D>,
682 ) -> core::fmt::Result {
683 let input = input.optimized();
684 let min_value = min_value.optimized();
685 let max_value = max_value.optimized();
686 let out = out.optimized();
687 let out_item = out.item();
688 let num = out_item.vectorization;
689
690 let (min, max) = match out.elem() {
691 Elem::F16 | Elem::BF16 => ("__hmin", "__hmax"),
692 Elem::F162 | Elem::BF162 => ("__hmin2", "__hmax2"),
693 _ => ("min", "max"),
694 };
695
696 let out_fmt = out.fmt_left();
697 if num == 1 {
698 writeln!(f, "{out_fmt} = ")?;
699 D::compile_instruction_max_function_name(f, out.item())?;
700 writeln!(f, "({min_value}, ")?;
701 D::compile_instruction_min_function_name(f, out.item())?;
702 writeln!(f, "({max_value}, {input}));")
703 } else {
704 writeln!(f, "{out_fmt} = {out_item}{{")?;
705 for i in 0..num {
706 let inputi = input.index(i);
707 let mini = min_value.index(i);
708 let maxi = max_value.index(i);
709
710 writeln!(f, "{max}({mini}, {min}({maxi}, {inputi})),")?;
711 }
712
713 f.write_str("};\n")
714 }
715 }
716}
717
718struct Remainder<D: Dialect> {
719 _dialect: PhantomData<D>,
720}
721
722impl<D: Dialect> Remainder<D> {
723 fn format(
724 f: &mut core::fmt::Formatter<'_>,
725 lhs: &Variable<D>,
726 rhs: &Variable<D>,
727 out: &Variable<D>,
728 ) -> core::fmt::Result {
729 let floor = |elem| {
730 let prefix = match elem {
731 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
732 Elem::F162 | Elem::BF162 => D::compile_instruction_half2_function_name_prefix(),
733 _ => "",
734 };
735 format!("{prefix}floor")
736 };
737
738 if out.item().vectorization == 1 {
739 let floor = floor(out.elem());
740
741 let out = out.fmt_left();
742 return writeln!(f, "{out} = {lhs} - {rhs} * {floor}({lhs} / {rhs});");
743 }
744
745 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
746 let [lhs, rhs, out_optimized] = optimized.args;
747
748 let item_out_original = out.item();
749 let item_out_optimized = out_optimized.item();
750
751 let index = match optimized.optimization_factor {
752 Some(factor) => item_out_original.vectorization / factor,
753 None => item_out_optimized.vectorization,
754 };
755
756 let floor = floor(*item_out_optimized.elem());
757
758 let mut write_op =
759 |lhs: &Variable<D>, rhs: &Variable<D>, out: &Variable<D>, item_out: Item<D>| {
760 let out = out.fmt_left();
761 writeln!(f, "{out} = {item_out}{{")?;
762 for i in 0..index {
763 let lhsi = lhs.index(i);
764 let rhsi = rhs.index(i);
765
766 writeln!(f, "{lhsi} - {rhsi} * {floor}({lhsi} / {rhsi})")?;
767 f.write_str(", ")?;
768 }
769
770 f.write_str("};\n")
771 };
772
773 if item_out_original == item_out_optimized {
774 write_op(&lhs, &rhs, out, item_out_optimized)
775 } else {
776 let out_tmp = Variable::tmp(item_out_optimized);
777
778 write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;
779
780 let addr_space = D::address_space_for_variable(&out_tmp);
781 let qualifier = out.const_qualifier();
782 let out = out.fmt_left();
783
784 writeln!(
785 f,
786 "{out} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
787 )?;
788
789 Ok(())
790 }
791 }
792}
793
794struct Magnitude<D: Dialect> {
795 _dialect: PhantomData<D>,
796}
797
798impl<D: Dialect> Magnitude<D> {
799 fn format(
800 f: &mut core::fmt::Formatter<'_>,
801 input: &Variable<D>,
802 out: &Variable<D>,
803 ) -> core::fmt::Result {
804 let num = input.item().vectorization;
805 let elem = input.elem();
806
807 let mag = format!("{out}_mag");
808
809 writeln!(f, "{} {mag} = 0.0;", out.item())?;
810
811 for i in 0..num {
812 let input_i = input.index(i);
813 writeln!(f, "{mag} += {input_i} * {input_i};")?;
814 }
815
816 let out = out.fmt_left();
817 write!(f, "{out} = ")?;
818 Sqrt::format_unary(f, &mag, elem)?;
819 f.write_str(";\n")
820 }
821}
822
823struct Normalize<D: Dialect> {
824 _dialect: PhantomData<D>,
825}
826
827impl<D: Dialect> Normalize<D> {
828 fn format(
829 f: &mut core::fmt::Formatter<'_>,
830 input: &Variable<D>,
831 out: &Variable<D>,
832 ) -> core::fmt::Result {
833 let num = input.item().vectorization;
834 let elem = input.elem();
835 let norm = format!("{out}_norm");
836
837 let out_item = out.item();
838 let out = out.fmt_left();
839 writeln!(f, "{elem} {norm} = 0.0;")?;
840
841 for i in 0..num {
842 let input_i = input.index(i);
843 writeln!(f, "{norm} += {input_i} * {input_i};")?;
844 }
845
846 write!(f, "{norm} = ")?;
847 Sqrt::format_unary(f, &norm, elem)?;
848 f.write_str(";\n")?;
849
850 if num == 1 {
851 writeln!(f, "{out} = {input} / {norm};")
852 } else {
853 write!(f, "{out} = {out_item}{{")?;
854 for i in 0..num {
855 let input_i = input.index(i);
856
857 writeln!(f, "{input_i} / {norm},")?;
858 }
859
860 f.write_str("};\n")
861 }
862 }
863}
864
865struct Dot<D: Dialect> {
866 _dialect: PhantomData<D>,
867}
868
869impl<D: Dialect> Dot<D> {
870 fn format(
871 f: &mut core::fmt::Formatter<'_>,
872 lhs: &Variable<D>,
873 rhs: &Variable<D>,
874 out: &Variable<D>,
875 ) -> core::fmt::Result {
876 let num = lhs.item().vectorization;
877
878 let muls = (0..num)
879 .map(|i| {
880 let lhs_i = lhs.index(i);
881 let rhs_i = rhs.index(i);
882 format!("{lhs_i} * {rhs_i}")
883 })
884 .collect::<Vec<_>>();
885
886 let out = out.fmt_left();
887 writeln!(f, "{out} = {};", muls.join(" + "))
888 }
889}
890
891struct EnsureBoolArg<'a, V: Display, D: Dialect> {
892 var: &'a V,
893 elem: &'a Elem<D>,
894}
895
896impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
897 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
898 if self.elem != &Elem::Bool {
899 write!(f, "bool({})", self.var)
900 } else {
901 write!(f, "{}", self.var)
902 }
903 }
904}