1use super::{Component, Dialect, Elem, FmtLeft, Variable};
2use std::fmt::Display;
3
4pub trait Unary<D: Dialect> {
5 fn format(
6 f: &mut std::fmt::Formatter<'_>,
7 input: &Variable<D>,
8 out: &Variable<D>,
9 ) -> std::fmt::Result {
10 let out_item = out.item();
11
12 if out_item.vectorization == 1 {
13 write!(f, "{} = ", out.fmt_left())?;
14 Self::format_scalar(f, *input, out_item.elem)?;
15 f.write_str(";\n")
16 } else {
17 Self::unroll_vec(f, input, out, out_item.elem, out_item.vectorization)
18 }
19 }
20
21 fn format_scalar<Input: Component<D>>(
22 f: &mut std::fmt::Formatter<'_>,
23 input: Input,
24 out_elem: Elem<D>,
25 ) -> std::fmt::Result;
26
27 fn unroll_vec(
28 f: &mut std::fmt::Formatter<'_>,
29 input: &Variable<D>,
30 out: &Variable<D>,
31 out_elem: Elem<D>,
32 index: usize,
33 ) -> std::fmt::Result {
34 let mut write_op = |index, out_elem, input: &Variable<D>, out: &Variable<D>| {
35 let out_item = out.item();
36 let out = out.fmt_left();
37 writeln!(f, "{out} = {out_item}{{")?;
38
39 for i in 0..index {
40 let inputi = input.index(i);
41
42 Self::format_scalar(f, inputi, out_elem)?;
43 f.write_str(",")?;
44 }
45
46 f.write_str("};\n")
47 };
48
49 if Self::can_optimize() {
50 let optimized = Variable::optimized_args([*input, *out]);
51 let [input, out_optimized] = optimized.args;
52
53 let item_out_original = out.item();
54 let item_out_optimized = out_optimized.item();
55
56 let (index, out_elem) = match optimized.optimization_factor {
57 Some(factor) => (index / factor, out_optimized.elem()),
58 None => (index, out_elem),
59 };
60
61 if item_out_original != item_out_optimized {
62 let out_tmp = Variable::tmp(item_out_optimized);
63
64 write_op(index, out_elem, &input, &out_tmp)?;
65 let qualifier = out.const_qualifier();
66 let addr_space = D::address_space_for_variable(out);
67 let out_fmt = out.fmt_left();
68 writeln!(
69 f,
70 "{out_fmt} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
71 )
72 } else {
73 write_op(index, out_elem, &input, &out_optimized)
74 }
75 } else {
76 write_op(index, out_elem, input, out)
77 }
78 }
79
80 fn can_optimize() -> bool {
81 true
82 }
83}
84
85pub trait FunctionFmt<D: Dialect> {
86 fn base_function_name() -> &'static str;
87 fn function_name(elem: Elem<D>) -> String {
88 if Self::half_support() {
89 let prefix = match elem {
90 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
91 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
92 _ => "",
93 };
94 format!("{prefix}{}", Self::base_function_name())
95 } else {
96 Self::base_function_name().into()
97 }
98 }
99 fn format_unary<Input: Display>(
100 f: &mut std::fmt::Formatter<'_>,
101 input: Input,
102 elem: Elem<D>,
103 ) -> std::fmt::Result {
104 if Self::half_support() {
105 write!(f, "{}({input})", Self::function_name(elem))
106 } else {
107 match elem {
108 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
109 write!(f, "{}({}(float({input})))", elem, Self::function_name(elem))
110 }
111 _ => write!(f, "{}({input})", Self::function_name(elem)),
112 }
113 }
114 }
115
116 fn half_support() -> bool;
117}
118
/// Declares a unit struct implementing [`Unary`] and [`FunctionFmt`] for a
/// single-function unary operation.
///
/// Usage: `function!(Name, "func")` (a half-precision variant exists) or
/// `function!(Name, "func", half_support)`. The half-support flag also controls
/// `Unary::can_optimize`.
macro_rules! function {
    // Two-argument form: shorthand for half support enabled.
    ($name:ident, $base:expr) => {
        function!($name, $base, true);
    };
    ($name:ident, $base:expr, $half:expr) => {
        pub struct $name;

        impl<D: Dialect> Unary<D> for $name {
            fn format_scalar<Input>(
                f: &mut std::fmt::Formatter<'_>,
                input: Input,
                elem: Elem<D>,
            ) -> std::fmt::Result
            where
                Input: Display,
            {
                <Self as FunctionFmt<D>>::format_unary(f, input, elem)
            }

            // Operand packing is only attempted for functions with half variants.
            fn can_optimize() -> bool {
                $half
            }
        }

        impl<D: Dialect> FunctionFmt<D> for $name {
            fn base_function_name() -> &'static str {
                $base
            }

            fn half_support() -> bool {
                $half
            }
        }
    };
}
150
151function!(Log, "log");
152function!(FastLog, "__logf", false);
153function!(Sin, "sin");
154function!(Cos, "cos");
155function!(Tan, "tan", false);
156function!(Sinh, "sinh", false);
157function!(Cosh, "cosh", false);
158function!(ArcCos, "acos", false);
159function!(ArcSin, "asin", false);
160function!(ArcTan, "atan", false);
161function!(ArcSinh, "asinh", false);
162function!(ArcCosh, "acosh", false);
163function!(ArcTanh, "atanh", false);
164function!(FastSin, "__sinf", false);
165function!(FastCos, "__cosf", false);
166function!(Sqrt, "sqrt");
167function!(InverseSqrt, "rsqrt");
168function!(FastSqrt, "__fsqrt_rn", false);
169function!(FastInverseSqrt, "__frsqrt_rn", false);
170function!(Exp, "exp");
171function!(FastExp, "__expf", false);
172function!(Ceil, "ceil");
173function!(Trunc, "trunc");
174function!(Floor, "floor");
175function!(Round, "rint");
176function!(FastRecip, "__frcp_rn", false);
177function!(FastTanh, "__tanhf", false);
178
179function!(Erf, "erf", false);
180function!(Abs, "abs", false);
181
182pub struct Log1p;
183
184impl<D: Dialect> Unary<D> for Log1p {
185 fn format_scalar<Input: Component<D>>(
186 f: &mut std::fmt::Formatter<'_>,
187 input: Input,
188 _out_elem: Elem<D>,
189 ) -> std::fmt::Result {
190 D::compile_instruction_log1p_scalar(f, input)
191 }
192
193 fn can_optimize() -> bool {
194 false
195 }
196}
197
198pub struct Tanh;
199
200impl<D: Dialect> Unary<D> for Tanh {
201 fn format_scalar<Input: Component<D>>(
202 f: &mut std::fmt::Formatter<'_>,
203 input: Input,
204 _out_elem: Elem<D>,
205 ) -> std::fmt::Result {
206 D::compile_instruction_tanh_scalar(f, input)
207 }
208
209 fn can_optimize() -> bool {
210 false
211 }
212}
213
214pub struct Degrees;
215
216impl<D: Dialect> Unary<D> for Degrees {
217 fn format_scalar<Input: Component<D>>(
218 f: &mut std::fmt::Formatter<'_>,
219 input: Input,
220 elem: Elem<D>,
221 ) -> std::fmt::Result {
222 write!(f, "{input}*{elem}(57.29577951308232f)")
223 }
224
225 fn can_optimize() -> bool {
226 false
227 }
228}
229
230pub struct Radians;
231
232impl<D: Dialect> Unary<D> for Radians {
233 fn format_scalar<Input: Component<D>>(
234 f: &mut std::fmt::Formatter<'_>,
235 input: Input,
236 elem: Elem<D>,
237 ) -> std::fmt::Result {
238 write!(f, "{input}*{elem}(0.017453292519943295f)")
239 }
240
241 fn can_optimize() -> bool {
242 false
243 }
244}
245
246pub fn zero_extend<D: Dialect>(input: impl Component<D>) -> String {
247 match input.elem() {
248 Elem::I8 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U8),
249 Elem::I16 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U16),
250 Elem::U8 => format!("{}({input})", Elem::<D>::U32),
251 Elem::U16 => format!("{}({input})", Elem::<D>::U32),
252 _ => unreachable!("zero extend only supports integer < 32 bits"),
253 }
254}
255
256pub struct CountBits;
257
258impl<D: Dialect> Unary<D> for CountBits {
259 fn format_scalar<Input: Component<D>>(
260 f: &mut std::fmt::Formatter<'_>,
261 input: Input,
262 elem: Elem<D>,
263 ) -> std::fmt::Result {
264 D::compile_instruction_popcount_scalar(f, input, elem)
265 }
266}
267
268pub struct ReverseBits;
269
270impl<D: Dialect> Unary<D> for ReverseBits {
271 fn format_scalar<Input: Component<D>>(
272 f: &mut std::fmt::Formatter<'_>,
273 input: Input,
274 elem: Elem<D>,
275 ) -> std::fmt::Result {
276 D::compile_instruction_reverse_bits_scalar(f, input, elem)
277 }
278}
279
280pub struct LeadingZeros;
281
282impl<D: Dialect> Unary<D> for LeadingZeros {
283 fn format_scalar<Input: Component<D>>(
284 f: &mut std::fmt::Formatter<'_>,
285 input: Input,
286 elem: Elem<D>,
287 ) -> std::fmt::Result {
288 D::compile_instruction_leading_zeros_scalar(f, input, elem)
289 }
290}
291
292pub struct FindFirstSet;
293
294impl<D: Dialect> Unary<D> for FindFirstSet {
295 fn format_scalar<Input: Component<D>>(
296 f: &mut std::fmt::Formatter<'_>,
297 input: Input,
298 out_elem: Elem<D>,
299 ) -> std::fmt::Result {
300 D::compile_instruction_find_first_set(f, input, out_elem)
301 }
302}
303
304pub struct BitwiseNot;
305
306impl<D: Dialect> Unary<D> for BitwiseNot {
307 fn format_scalar<Input>(
308 f: &mut std::fmt::Formatter<'_>,
309 input: Input,
310 _out_elem: Elem<D>,
311 ) -> std::fmt::Result
312 where
313 Input: Component<D>,
314 {
315 write!(f, "~{input}")
316 }
317}
318
319pub struct Not;
320
321impl<D: Dialect> Unary<D> for Not {
322 fn format_scalar<Input>(
323 f: &mut std::fmt::Formatter<'_>,
324 input: Input,
325 _out_elem: Elem<D>,
326 ) -> std::fmt::Result
327 where
328 Input: Component<D>,
329 {
330 write!(f, "!{input}")
331 }
332}
333
334pub struct Assign;
335
336impl<D: Dialect> Unary<D> for Assign {
337 fn format(
338 f: &mut std::fmt::Formatter<'_>,
339 input: &Variable<D>,
340 out: &Variable<D>,
341 ) -> std::fmt::Result {
342 let item = out.item();
343
344 if item.vectorization == 1 || input.item() == item {
345 write!(f, "{} = ", out.fmt_left())?;
346 Self::format_scalar(f, *input, item.elem)?;
347 f.write_str(";\n")
348 } else {
349 Self::unroll_vec(f, input, out, item.elem, item.vectorization)
350 }
351 }
352
353 fn format_scalar<Input>(
354 f: &mut std::fmt::Formatter<'_>,
355 input: Input,
356 elem: Elem<D>,
357 ) -> std::fmt::Result
358 where
359 Input: Component<D>,
360 {
361 if elem != input.elem() {
363 match elem {
364 Elem::TF32 => write!(f, "nvcuda::wmma::__float_to_tf32({input})"),
365 elem => write!(f, "{elem}({input})"),
366 }
367 } else {
368 write!(f, "{input}")
369 }
370 }
371}
372
373fn elem_function_name<D: Dialect>(base_name: &'static str, elem: Elem<D>) -> String {
374 let prefix = match elem {
376 Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
377 Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
378 _ => "",
379 };
380 if prefix.is_empty() {
381 base_name.to_string()
382 } else if prefix == "h" || prefix == "h2" {
383 format!("__{prefix}{base_name}")
384 } else {
385 panic!("Unknown prefix '{prefix}'");
386 }
387}
388
389pub struct IsNan;
391
392impl<D: Dialect> Unary<D> for IsNan {
393 fn format_scalar<Input: Component<D>>(
394 f: &mut std::fmt::Formatter<'_>,
395 input: Input,
396 _elem: Elem<D>,
397 ) -> std::fmt::Result {
398 let elem = input.elem();
400 write!(f, "{}({input})", elem_function_name("isnan", elem))
401 }
402
403 fn can_optimize() -> bool {
404 true
405 }
406}
407
408pub struct IsInf;
409
410impl<D: Dialect> Unary<D> for IsInf {
411 fn format_scalar<Input: Component<D>>(
412 f: &mut std::fmt::Formatter<'_>,
413 input: Input,
414 _elem: Elem<D>,
415 ) -> std::fmt::Result {
416 let elem = input.elem();
418 write!(f, "{}({input})", elem_function_name("isinf", elem))
419 }
420
421 fn can_optimize() -> bool {
422 true
423 }
424}