1use super::{Component, Dialect, Elem, FmtLeft, Variable};
2use std::fmt::Display;
3
4pub trait Unary<D: Dialect> {
5 fn format(
6 f: &mut std::fmt::Formatter<'_>,
7 input: &Variable<D>,
8 out: &Variable<D>,
9 ) -> std::fmt::Result {
10 let item = out.item();
11
12 if item.vectorization == 1 {
13 write!(f, "{} = ", out.fmt_left())?;
14 Self::format_scalar(f, *input, item.elem)?;
15 f.write_str(";\n")
16 } else {
17 Self::unroll_vec(f, input, out, item.elem, item.vectorization)
18 }
19 }
20
21 fn format_scalar<Input: Component<D>>(
22 f: &mut std::fmt::Formatter<'_>,
23 input: Input,
24 elem: Elem<D>,
25 ) -> std::fmt::Result;
26
27 fn unroll_vec(
28 f: &mut std::fmt::Formatter<'_>,
29 input: &Variable<D>,
30 out: &Variable<D>,
31 elem: Elem<D>,
32 index: usize,
33 ) -> std::fmt::Result {
34 let mut write_op = |index, elem, input: &Variable<D>, out: &Variable<D>| {
35 let out_item = out.item();
36 let out = out.fmt_left();
37 writeln!(f, "{out} = {out_item}{{")?;
38
39 for i in 0..index {
40 let inputi = input.index(i);
41
42 Self::format_scalar(f, inputi, elem)?;
43 f.write_str(",")?;
44 }
45
46 f.write_str("};\n")
47 };
48
49 if Self::can_optimize() {
50 let optimized = Variable::optimized_args([*input, *out]);
51 let [input, out_optimized] = optimized.args;
52
53 let item_out_original = out.item();
54 let item_out_optimized = out_optimized.item();
55
56 let (index, elem) = match optimized.optimization_factor {
57 Some(factor) => (index / factor, out_optimized.elem()),
58 None => (index, elem),
59 };
60
61 if item_out_original != item_out_optimized {
62 let out_tmp = Variable::tmp(item_out_optimized);
63
64 write_op(index, elem, &input, &out_tmp)?;
65 let qualifier = out.const_qualifier();
66 let out_fmt = out.fmt_left();
67 writeln!(
68 f,
69 "{out_fmt} = reinterpret_cast<{item_out_original}{qualifier}&>({out_tmp});\n"
70 )
71 } else {
72 write_op(index, elem, &input, &out_optimized)
73 }
74 } else {
75 write_op(index, elem, input, out)
76 }
77 }
78
79 fn can_optimize() -> bool {
80 true
81 }
82}
83
84pub trait FunctionFmt<D: Dialect> {
85 fn base_function_name() -> &'static str;
86 fn function_name(elem: Elem<D>) -> String {
87 if Self::half_support() {
88 match elem {
89 Elem::F16 | Elem::BF16 => return format!("h{}", Self::base_function_name()),
90 Elem::F162 | Elem::BF162 => return format!("h2{}", Self::base_function_name()),
91 _ => (),
92 };
93 }
94
95 Self::base_function_name().into()
96 }
97 fn format_unary<Input: Display>(
98 f: &mut std::fmt::Formatter<'_>,
99 input: Input,
100 elem: Elem<D>,
101 ) -> std::fmt::Result {
102 if Self::half_support() {
103 return write!(f, "{}({input})", Self::function_name(elem));
104 }
105
106 match elem {
107 Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
108 write!(f, "{}({}(float({input})))", elem, Self::function_name(elem))
109 }
110 _ => write!(f, "{}({input})", Self::function_name(elem)),
111 }
112 }
113
114 fn half_support() -> bool;
115}
116
/// Declares a unit struct `$name` that is emitted as a call to the function
/// `$func`, implementing both `FunctionFmt` and `Unary` for it.
///
/// The optional third argument is the value returned by `half_support()` and
/// by `can_optimize()`; it defaults to `true` when omitted.
macro_rules! function {
    // Two-argument shorthand: half support enabled.
    ($name:ident, $func:expr) => {
        function!($name, $func, true);
    };
    ($name:ident, $func:expr, $half:expr) => {
        pub struct $name;

        impl<D: Dialect> FunctionFmt<D> for $name {
            fn base_function_name() -> &'static str {
                $func
            }

            fn half_support() -> bool {
                $half
            }
        }

        impl<D: Dialect> Unary<D> for $name {
            fn format_scalar<Input>(
                f: &mut std::fmt::Formatter<'_>,
                input: Input,
                elem: Elem<D>,
            ) -> std::fmt::Result
            where
                Input: Display,
            {
                <Self as FunctionFmt<D>>::format_unary(f, input, elem)
            }

            // The vector optimization is only attempted when the function
            // has half-precision variants.
            fn can_optimize() -> bool {
                $half
            }
        }
    };
}
148
149function!(Log, "log");
150function!(Log1p, "log1p", false);
151function!(Cos, "cos");
152function!(Sin, "sin");
153function!(Sqrt, "sqrt");
154function!(Exp, "exp");
155function!(Ceil, "ceil");
156function!(Floor, "floor");
157function!(Round, "rint");
158
159function!(Tanh, "tanh", false);
160function!(Erf, "erf", false);
161function!(Abs, "abs", false);
162
163fn zero_extend<D: Dialect>(input: impl Component<D>) -> String {
164 match input.elem() {
165 Elem::I8 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U8),
166 Elem::I16 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U16),
167 Elem::U8 => format!("{}({input})", Elem::<D>::U32),
168 Elem::U16 => format!("{}({input})", Elem::<D>::U32),
169 _ => unreachable!("zero extend only supports integer < 32 bits"),
170 }
171}
172
173pub struct CountBits;
174
175impl<D: Dialect> Unary<D> for CountBits {
176 fn format_scalar<Input: Component<D>>(
177 f: &mut std::fmt::Formatter<'_>,
178 input: Input,
179 _elem: Elem<D>,
180 ) -> std::fmt::Result {
181 match input.elem() {
182 Elem::I32 | Elem::U32 => write!(f, "__popc({input})"),
183 Elem::I64 | Elem::U64 => write!(f, "__popcll({input})"),
184 _ => write!(f, "__popc({})", zero_extend(input)),
185 }
186 }
187}
188
189pub struct ReverseBits;
190
191impl<D: Dialect> Unary<D> for ReverseBits {
192 fn format_scalar<Input: Component<D>>(
193 f: &mut std::fmt::Formatter<'_>,
194 input: Input,
195 elem: Elem<D>,
196 ) -> std::fmt::Result {
197 match elem {
198 Elem::I32 | Elem::U32 => write!(f, "__brev({input})"),
199 Elem::I64 | Elem::U64 => write!(f, "__brevll({input})"),
200 _ => write!(
201 f,
202 "{elem}(__brev({}) >> {})",
203 zero_extend(input),
204 (size_of::<u32>() - elem.size()) * 8
205 ),
206 }
207 }
208}
209
210pub struct Not;
211
212impl<D: Dialect> Unary<D> for Not {
213 fn format_scalar<Input>(
214 f: &mut std::fmt::Formatter<'_>,
215 input: Input,
216 _elem: Elem<D>,
217 ) -> std::fmt::Result
218 where
219 Input: Component<D>,
220 {
221 write!(f, "!{input}")
222 }
223}
224
225pub struct Assign;
226
227impl<D: Dialect> Unary<D> for Assign {
228 fn format(
229 f: &mut std::fmt::Formatter<'_>,
230 input: &Variable<D>,
231 out: &Variable<D>,
232 ) -> std::fmt::Result {
233 let item = out.item();
234
235 if item.vectorization == 1 || input.item() == item {
236 write!(f, "{} = ", out.fmt_left())?;
237 Self::format_scalar(f, *input, item.elem)?;
238 f.write_str(";\n")
239 } else {
240 Self::unroll_vec(f, input, out, item.elem, item.vectorization)
241 }
242 }
243
244 fn format_scalar<Input>(
245 f: &mut std::fmt::Formatter<'_>,
246 input: Input,
247 elem: Elem<D>,
248 ) -> std::fmt::Result
249 where
250 Input: Component<D>,
251 {
252 if elem != input.elem() {
254 match elem {
255 Elem::TF32 => write!(f, "nvcuda::wmma::__float_to_tf32({input})"),
256 elem => write!(f, "{elem}({input})"),
257 }
258 } else {
259 write!(f, "{input}")
260 }
261 }
262}