Skip to main content

cubecl_cpp/shared/
unary.rs

1use super::{Component, Dialect, Elem, FmtLeft, Variable};
2use std::fmt::Display;
3
4pub trait Unary<D: Dialect> {
5    fn format(
6        f: &mut std::fmt::Formatter<'_>,
7        input: &Variable<D>,
8        out: &Variable<D>,
9    ) -> std::fmt::Result {
10        let out_item = out.item();
11
12        if out_item.vectorization == 1 {
13            write!(f, "{} = ", out.fmt_left())?;
14            Self::format_scalar(f, *input, out_item.elem)?;
15            f.write_str(";\n")
16        } else {
17            Self::unroll_vec(f, input, out, out_item.elem, out_item.vectorization)
18        }
19    }
20
21    fn format_scalar<Input: Component<D>>(
22        f: &mut std::fmt::Formatter<'_>,
23        input: Input,
24        out_elem: Elem<D>,
25    ) -> std::fmt::Result;
26
27    fn unroll_vec(
28        f: &mut std::fmt::Formatter<'_>,
29        input: &Variable<D>,
30        out: &Variable<D>,
31        out_elem: Elem<D>,
32        index: usize,
33    ) -> std::fmt::Result {
34        let mut write_op = |index, out_elem, input: &Variable<D>, out: &Variable<D>| {
35            let out_item = out.item();
36            let out = out.fmt_left();
37            writeln!(f, "{out} = {out_item}{{")?;
38
39            for i in 0..index {
40                let inputi = input.index(i);
41
42                Self::format_scalar(f, inputi, out_elem)?;
43                f.write_str(",")?;
44            }
45
46            f.write_str("};\n")
47        };
48
49        if Self::can_optimize() {
50            let optimized = Variable::optimized_args([*input, *out]);
51            let [input, out_optimized] = optimized.args;
52
53            let item_out_original = out.item();
54            let item_out_optimized = out_optimized.item();
55
56            let (index, out_elem) = match optimized.optimization_factor {
57                Some(factor) => (index / factor, out_optimized.elem()),
58                None => (index, out_elem),
59            };
60
61            if item_out_original != item_out_optimized {
62                let out_tmp = Variable::tmp(item_out_optimized);
63
64                write_op(index, out_elem, &input, &out_tmp)?;
65                let qualifier = out.const_qualifier();
66                let addr_space = D::address_space_for_variable(out);
67                let out_fmt = out.fmt_left();
68                writeln!(
69                    f,
70                    "{out_fmt} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
71                )
72            } else {
73                write_op(index, out_elem, &input, &out_optimized)
74            }
75        } else {
76            write_op(index, out_elem, input, out)
77        }
78    }
79
80    fn can_optimize() -> bool {
81        true
82    }
83}
84
85pub trait FunctionFmt<D: Dialect> {
86    fn base_function_name() -> &'static str;
87    fn function_name(elem: Elem<D>) -> String {
88        if Self::half_support() {
89            let prefix = match elem {
90                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
91                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
92                _ => "",
93            };
94            format!("{prefix}{}", Self::base_function_name())
95        } else {
96            Self::base_function_name().into()
97        }
98    }
99    fn format_unary<Input: Display>(
100        f: &mut std::fmt::Formatter<'_>,
101        input: Input,
102        elem: Elem<D>,
103    ) -> std::fmt::Result {
104        if Self::half_support() {
105            write!(f, "{}({input})", Self::function_name(elem))
106        } else {
107            match elem {
108                Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
109                    write!(f, "{}({}(float({input})))", elem, Self::function_name(elem))
110                }
111                _ => write!(f, "{}({input})", Self::function_name(elem)),
112            }
113        }
114    }
115
116    fn half_support() -> bool;
117}
118
/// Declares a public unit struct `$op` whose code generation is a call to `$symbol`.
///
/// The struct implements `FunctionFmt` (with `$symbol` as base name) and `Unary` (each
/// component is formatted through `FunctionFmt::format_unary`). The optional third
/// argument drives both `half_support()` and `can_optimize()`; it defaults to `true`.
macro_rules! function {
    ($op:ident, $symbol:expr) => {
        function!($op, $symbol, true);
    };
    ($op:ident, $symbol:expr, $half:expr) => {
        pub struct $op;

        impl<D: Dialect> FunctionFmt<D> for $op {
            fn base_function_name() -> &'static str {
                $symbol
            }
            fn half_support() -> bool {
                $half
            }
        }

        impl<D: Dialect> Unary<D> for $op {
            fn format_scalar<Input>(
                f: &mut std::fmt::Formatter<'_>,
                input: Input,
                elem: Elem<D>,
            ) -> std::fmt::Result
            where
                Input: Display,
            {
                Self::format_unary(f, input, elem)
            }

            fn can_optimize() -> bool {
                $half
            }
        }
    };
}
150
151function!(Log, "log");
152function!(FastLog, "__logf", false);
153function!(Sin, "sin");
154function!(Cos, "cos");
155function!(Tan, "tan", false);
156function!(Sinh, "sinh", false);
157function!(Cosh, "cosh", false);
158function!(ArcCos, "acos", false);
159function!(ArcSin, "asin", false);
160function!(ArcTan, "atan", false);
161function!(ArcSinh, "asinh", false);
162function!(ArcCosh, "acosh", false);
163function!(ArcTanh, "atanh", false);
164function!(FastSin, "__sinf", false);
165function!(FastCos, "__cosf", false);
166function!(Sqrt, "sqrt");
167function!(InverseSqrt, "rsqrt");
168function!(FastSqrt, "__fsqrt_rn", false);
169function!(FastInverseSqrt, "__frsqrt_rn", false);
170function!(Exp, "exp");
171function!(FastExp, "__expf", false);
172function!(Ceil, "ceil");
173function!(Trunc, "trunc");
174function!(Floor, "floor");
175function!(Round, "rint");
176function!(FastRecip, "__frcp_rn", false);
177function!(FastTanh, "__tanhf", false);
178
179function!(Erf, "erf", false);
180function!(Abs, "abs", false);
181
182pub struct Log1p;
183
184impl<D: Dialect> Unary<D> for Log1p {
185    fn format_scalar<Input: Component<D>>(
186        f: &mut std::fmt::Formatter<'_>,
187        input: Input,
188        _out_elem: Elem<D>,
189    ) -> std::fmt::Result {
190        D::compile_instruction_log1p_scalar(f, input)
191    }
192
193    fn can_optimize() -> bool {
194        false
195    }
196}
197
198pub struct Tanh;
199
200impl<D: Dialect> Unary<D> for Tanh {
201    fn format_scalar<Input: Component<D>>(
202        f: &mut std::fmt::Formatter<'_>,
203        input: Input,
204        _out_elem: Elem<D>,
205    ) -> std::fmt::Result {
206        D::compile_instruction_tanh_scalar(f, input)
207    }
208
209    fn can_optimize() -> bool {
210        false
211    }
212}
213
214pub struct Degrees;
215
216impl<D: Dialect> Unary<D> for Degrees {
217    fn format_scalar<Input: Component<D>>(
218        f: &mut std::fmt::Formatter<'_>,
219        input: Input,
220        elem: Elem<D>,
221    ) -> std::fmt::Result {
222        write!(f, "{input}*{elem}(57.29577951308232f)")
223    }
224
225    fn can_optimize() -> bool {
226        false
227    }
228}
229
230pub struct Radians;
231
232impl<D: Dialect> Unary<D> for Radians {
233    fn format_scalar<Input: Component<D>>(
234        f: &mut std::fmt::Formatter<'_>,
235        input: Input,
236        elem: Elem<D>,
237    ) -> std::fmt::Result {
238        write!(f, "{input}*{elem}(0.017453292519943295f)")
239    }
240
241    fn can_optimize() -> bool {
242        false
243    }
244}
245
246pub fn zero_extend<D: Dialect>(input: impl Component<D>) -> String {
247    match input.elem() {
248        Elem::I8 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U8),
249        Elem::I16 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U16),
250        Elem::U8 => format!("{}({input})", Elem::<D>::U32),
251        Elem::U16 => format!("{}({input})", Elem::<D>::U32),
252        _ => unreachable!("zero extend only supports integer < 32 bits"),
253    }
254}
255
256pub struct CountBits;
257
258impl<D: Dialect> Unary<D> for CountBits {
259    fn format_scalar<Input: Component<D>>(
260        f: &mut std::fmt::Formatter<'_>,
261        input: Input,
262        elem: Elem<D>,
263    ) -> std::fmt::Result {
264        D::compile_instruction_popcount_scalar(f, input, elem)
265    }
266}
267
268pub struct ReverseBits;
269
270impl<D: Dialect> Unary<D> for ReverseBits {
271    fn format_scalar<Input: Component<D>>(
272        f: &mut std::fmt::Formatter<'_>,
273        input: Input,
274        elem: Elem<D>,
275    ) -> std::fmt::Result {
276        D::compile_instruction_reverse_bits_scalar(f, input, elem)
277    }
278}
279
280pub struct LeadingZeros;
281
282impl<D: Dialect> Unary<D> for LeadingZeros {
283    fn format_scalar<Input: Component<D>>(
284        f: &mut std::fmt::Formatter<'_>,
285        input: Input,
286        elem: Elem<D>,
287    ) -> std::fmt::Result {
288        D::compile_instruction_leading_zeros_scalar(f, input, elem)
289    }
290}
291
292pub struct TrailingZeros;
293
294impl<D: Dialect> Unary<D> for TrailingZeros {
295    fn format_scalar<Input: Component<D>>(
296        f: &mut std::fmt::Formatter<'_>,
297        input: Input,
298        elem: Elem<D>,
299    ) -> std::fmt::Result {
300        D::compile_instruction_trailing_zeros_scalar(f, input, elem)
301    }
302}
303
304pub struct FindFirstSet;
305
306impl<D: Dialect> Unary<D> for FindFirstSet {
307    fn format_scalar<Input: Component<D>>(
308        f: &mut std::fmt::Formatter<'_>,
309        input: Input,
310        out_elem: Elem<D>,
311    ) -> std::fmt::Result {
312        D::compile_instruction_find_first_set(f, input, out_elem)
313    }
314}
315
316pub struct BitwiseNot;
317
318impl<D: Dialect> Unary<D> for BitwiseNot {
319    fn format_scalar<Input>(
320        f: &mut std::fmt::Formatter<'_>,
321        input: Input,
322        _out_elem: Elem<D>,
323    ) -> std::fmt::Result
324    where
325        Input: Component<D>,
326    {
327        write!(f, "~{input}")
328    }
329}
330
331pub struct Not;
332
333impl<D: Dialect> Unary<D> for Not {
334    fn format_scalar<Input>(
335        f: &mut std::fmt::Formatter<'_>,
336        input: Input,
337        _out_elem: Elem<D>,
338    ) -> std::fmt::Result
339    where
340        Input: Component<D>,
341    {
342        write!(f, "!{input}")
343    }
344}
345
346pub struct Assign;
347
348impl<D: Dialect> Unary<D> for Assign {
349    fn format(
350        f: &mut std::fmt::Formatter<'_>,
351        input: &Variable<D>,
352        out: &Variable<D>,
353    ) -> std::fmt::Result {
354        let item = out.item();
355
356        if item.vectorization == 1 || input.item() == item {
357            write!(f, "{} = ", out.fmt_left())?;
358            Self::format_scalar(f, *input, item.elem)?;
359            f.write_str(";\n")
360        } else {
361            Self::unroll_vec(f, input, out, item.elem, item.vectorization)
362        }
363    }
364
365    fn format_scalar<Input>(
366        f: &mut std::fmt::Formatter<'_>,
367        input: Input,
368        elem: Elem<D>,
369    ) -> std::fmt::Result
370    where
371        Input: Component<D>,
372    {
373        // Cast only when necessary.
374        if elem != input.elem() {
375            match elem {
376                Elem::TF32 => write!(f, "nvcuda::wmma::__float_to_tf32({input})"),
377                elem => write!(f, "{elem}({input})"),
378            }
379        } else {
380            write!(f, "{input}")
381        }
382    }
383}
384
385fn elem_function_name<D: Dialect>(base_name: &'static str, elem: Elem<D>) -> String {
386    // Math functions prefix (no leading underscores)
387    let prefix = match elem {
388        Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
389        Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
390        _ => "",
391    };
392    if prefix.is_empty() {
393        base_name.to_string()
394    } else if prefix == "h" || prefix == "h2" {
395        format!("__{prefix}{base_name}")
396    } else {
397        panic!("Unknown prefix '{prefix}'");
398    }
399}
400
401// `isnan` / `isinf` are defined for cuda/hip/metal with same prefixes for half/bf16 on cuda/hip
402pub struct IsNan;
403
404impl<D: Dialect> Unary<D> for IsNan {
405    fn format_scalar<Input: Component<D>>(
406        f: &mut std::fmt::Formatter<'_>,
407        input: Input,
408        _elem: Elem<D>,
409    ) -> std::fmt::Result {
410        // Format unary function name based on *input* elem dtype
411        let elem = input.elem();
412        write!(f, "{}({input})", elem_function_name("isnan", elem))
413    }
414
415    fn can_optimize() -> bool {
416        true
417    }
418}
419
420pub struct IsInf;
421
422impl<D: Dialect> Unary<D> for IsInf {
423    fn format_scalar<Input: Component<D>>(
424        f: &mut std::fmt::Formatter<'_>,
425        input: Input,
426        _elem: Elem<D>,
427    ) -> std::fmt::Result {
428        // Format unary function name based on *input* elem dtype
429        let elem = input.elem();
430        write!(f, "{}({input})", elem_function_name("isinf", elem))
431    }
432
433    fn can_optimize() -> bool {
434        true
435    }
436}