cubecl_cpp/shared/
unary.rs

1use super::{Component, Dialect, Elem, FmtLeft, Variable};
2use std::fmt::Display;
3
4pub trait Unary<D: Dialect> {
5    fn format(
6        f: &mut std::fmt::Formatter<'_>,
7        input: &Variable<D>,
8        out: &Variable<D>,
9    ) -> std::fmt::Result {
10        let out_item = out.item();
11
12        if out_item.vectorization == 1 {
13            write!(f, "{} = ", out.fmt_left())?;
14            Self::format_scalar(f, *input, out_item.elem)?;
15            f.write_str(";\n")
16        } else {
17            Self::unroll_vec(f, input, out, out_item.elem, out_item.vectorization)
18        }
19    }
20
21    fn format_scalar<Input: Component<D>>(
22        f: &mut std::fmt::Formatter<'_>,
23        input: Input,
24        out_elem: Elem<D>,
25    ) -> std::fmt::Result;
26
27    fn unroll_vec(
28        f: &mut std::fmt::Formatter<'_>,
29        input: &Variable<D>,
30        out: &Variable<D>,
31        out_elem: Elem<D>,
32        index: usize,
33    ) -> std::fmt::Result {
34        let mut write_op = |index, out_elem, input: &Variable<D>, out: &Variable<D>| {
35            let out_item = out.item();
36            let out = out.fmt_left();
37            writeln!(f, "{out} = {out_item}{{")?;
38
39            for i in 0..index {
40                let inputi = input.index(i);
41
42                Self::format_scalar(f, inputi, out_elem)?;
43                f.write_str(",")?;
44            }
45
46            f.write_str("};\n")
47        };
48
49        if Self::can_optimize() {
50            let optimized = Variable::optimized_args([*input, *out]);
51            let [input, out_optimized] = optimized.args;
52
53            let item_out_original = out.item();
54            let item_out_optimized = out_optimized.item();
55
56            let (index, out_elem) = match optimized.optimization_factor {
57                Some(factor) => (index / factor, out_optimized.elem()),
58                None => (index, out_elem),
59            };
60
61            if item_out_original != item_out_optimized {
62                let out_tmp = Variable::tmp(item_out_optimized);
63
64                write_op(index, out_elem, &input, &out_tmp)?;
65                let qualifier = out.const_qualifier();
66                let addr_space = D::address_space_for_variable(out);
67                let out_fmt = out.fmt_left();
68                writeln!(
69                    f,
70                    "{out_fmt} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
71                )
72            } else {
73                write_op(index, out_elem, &input, &out_optimized)
74            }
75        } else {
76            write_op(index, out_elem, input, out)
77        }
78    }
79
80    fn can_optimize() -> bool {
81        true
82    }
83}
84
85pub trait FunctionFmt<D: Dialect> {
86    fn base_function_name() -> &'static str;
87    fn function_name(elem: Elem<D>) -> String {
88        if Self::half_support() {
89            let prefix = match elem {
90                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
91                Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
92                _ => "",
93            };
94            format!("{prefix}{}", Self::base_function_name())
95        } else {
96            Self::base_function_name().into()
97        }
98    }
99    fn format_unary<Input: Display>(
100        f: &mut std::fmt::Formatter<'_>,
101        input: Input,
102        elem: Elem<D>,
103    ) -> std::fmt::Result {
104        if Self::half_support() {
105            write!(f, "{}({input})", Self::function_name(elem))
106        } else {
107            match elem {
108                Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
109                    write!(f, "{}({}(float({input})))", elem, Self::function_name(elem))
110                }
111                _ => write!(f, "{}({input})", Self::function_name(elem)),
112            }
113        }
114    }
115
116    fn half_support() -> bool;
117}
118
/// Declares a unit struct for a function-call style unary operation and implements both
/// `FunctionFmt` and `Unary` for it.
///
/// * `function!(Name, "fn")` is shorthand for `function!(Name, "fn", true)`.
/// * `function!(Name, "fn", half)` uses `"fn"` as the base function name and `half` as the
///   value returned by both `half_support()` and `can_optimize()`.
macro_rules! function {
    ($op:ident, $symbol:expr) => {
        function!($op, $symbol, true);
    };
    ($op:ident, $symbol:expr, $half:expr) => {
        /// Unary operation emitted as a function call; generated by `function!`.
        pub struct $op;

        impl<D: Dialect> FunctionFmt<D> for $op {
            fn base_function_name() -> &'static str {
                $symbol
            }

            fn half_support() -> bool {
                $half
            }
        }

        impl<D: Dialect> Unary<D> for $op {
            fn format_scalar<Input: Display>(
                f: &mut std::fmt::Formatter<'_>,
                input: Input,
                elem: Elem<D>,
            ) -> std::fmt::Result {
                <Self as FunctionFmt<D>>::format_unary(f, input, elem)
            }

            // Operand optimization is tied to half support: both report the same flag.
            fn can_optimize() -> bool {
                <Self as FunctionFmt<D>>::half_support()
            }
        }
    };
}
150
// Operations declared with the two-argument form, i.e. with half support enabled: the
// function name receives the dialect's half/half2 prefix for half-family elements.
function!(Log, "log");
function!(Cos, "cos");
function!(Sin, "sin");
function!(Sqrt, "sqrt");
function!(Exp, "exp");
function!(Ceil, "ceil");
function!(Trunc, "trunc");
function!(Floor, "floor");
// `Round` is emitted as a call to `rint`.
function!(Round, "rint");

// Operations declared without half support: half-family inputs are converted to `float`
// for the call, and `can_optimize()` returns `false` for them.
function!(Erf, "erf", false);
function!(Abs, "abs", false);
163
164pub struct Log1p;
165
166impl<D: Dialect> Unary<D> for Log1p {
167    fn format_scalar<Input: Component<D>>(
168        f: &mut std::fmt::Formatter<'_>,
169        input: Input,
170        _out_elem: Elem<D>,
171    ) -> std::fmt::Result {
172        D::compile_instruction_log1p_scalar(f, input)
173    }
174
175    fn can_optimize() -> bool {
176        false
177    }
178}
179
180pub struct Tanh;
181
182impl<D: Dialect> Unary<D> for Tanh {
183    fn format_scalar<Input: Component<D>>(
184        f: &mut std::fmt::Formatter<'_>,
185        input: Input,
186        _out_elem: Elem<D>,
187    ) -> std::fmt::Result {
188        D::compile_instruction_tanh_scalar(f, input)
189    }
190
191    fn can_optimize() -> bool {
192        false
193    }
194}
195
196pub fn zero_extend<D: Dialect>(input: impl Component<D>) -> String {
197    match input.elem() {
198        Elem::I8 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U8),
199        Elem::I16 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U16),
200        Elem::U8 => format!("{}({input})", Elem::<D>::U32),
201        Elem::U16 => format!("{}({input})", Elem::<D>::U32),
202        _ => unreachable!("zero extend only supports integer < 32 bits"),
203    }
204}
205
206pub struct CountBits;
207
208impl<D: Dialect> Unary<D> for CountBits {
209    fn format_scalar<Input: Component<D>>(
210        f: &mut std::fmt::Formatter<'_>,
211        input: Input,
212        elem: Elem<D>,
213    ) -> std::fmt::Result {
214        D::compile_instruction_popcount_scalar(f, input, elem)
215    }
216}
217
218pub struct ReverseBits;
219
220impl<D: Dialect> Unary<D> for ReverseBits {
221    fn format_scalar<Input: Component<D>>(
222        f: &mut std::fmt::Formatter<'_>,
223        input: Input,
224        elem: Elem<D>,
225    ) -> std::fmt::Result {
226        D::compile_instruction_reverse_bits_scalar(f, input, elem)
227    }
228}
229
230pub struct LeadingZeros;
231
232impl<D: Dialect> Unary<D> for LeadingZeros {
233    fn format_scalar<Input: Component<D>>(
234        f: &mut std::fmt::Formatter<'_>,
235        input: Input,
236        elem: Elem<D>,
237    ) -> std::fmt::Result {
238        D::compile_instruction_leading_zeros_scalar(f, input, elem)
239    }
240}
241
242pub struct FindFirstSet;
243
244impl<D: Dialect> Unary<D> for FindFirstSet {
245    fn format_scalar<Input: Component<D>>(
246        f: &mut std::fmt::Formatter<'_>,
247        input: Input,
248        out_elem: Elem<D>,
249    ) -> std::fmt::Result {
250        D::compile_instruction_find_first_set(f, input, out_elem)
251    }
252}
253
254pub struct BitwiseNot;
255
256impl<D: Dialect> Unary<D> for BitwiseNot {
257    fn format_scalar<Input>(
258        f: &mut std::fmt::Formatter<'_>,
259        input: Input,
260        _out_elem: Elem<D>,
261    ) -> std::fmt::Result
262    where
263        Input: Component<D>,
264    {
265        write!(f, "~{input}")
266    }
267}
268
269pub struct Not;
270
271impl<D: Dialect> Unary<D> for Not {
272    fn format_scalar<Input>(
273        f: &mut std::fmt::Formatter<'_>,
274        input: Input,
275        _out_elem: Elem<D>,
276    ) -> std::fmt::Result
277    where
278        Input: Component<D>,
279    {
280        write!(f, "!{input}")
281    }
282}
283
284pub struct Assign;
285
286impl<D: Dialect> Unary<D> for Assign {
287    fn format(
288        f: &mut std::fmt::Formatter<'_>,
289        input: &Variable<D>,
290        out: &Variable<D>,
291    ) -> std::fmt::Result {
292        let item = out.item();
293
294        if item.vectorization == 1 || input.item() == item {
295            write!(f, "{} = ", out.fmt_left())?;
296            Self::format_scalar(f, *input, item.elem)?;
297            f.write_str(";\n")
298        } else {
299            Self::unroll_vec(f, input, out, item.elem, item.vectorization)
300        }
301    }
302
303    fn format_scalar<Input>(
304        f: &mut std::fmt::Formatter<'_>,
305        input: Input,
306        elem: Elem<D>,
307    ) -> std::fmt::Result
308    where
309        Input: Component<D>,
310    {
311        // Cast only when necessary.
312        if elem != input.elem() {
313            match elem {
314                Elem::TF32 => write!(f, "nvcuda::wmma::__float_to_tf32({input})"),
315                elem => write!(f, "{elem}({input})"),
316            }
317        } else {
318            write!(f, "{input}")
319        }
320    }
321}
322
323fn elem_function_name<D: Dialect>(base_name: &'static str, elem: Elem<D>) -> String {
324    // Math functions prefix (no leading underscores)
325    let prefix = match elem {
326        Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
327        Elem::F16x2 | Elem::BF16x2 => D::compile_instruction_half2_function_name_prefix(),
328        _ => "",
329    };
330    if prefix.is_empty() {
331        base_name.to_string()
332    } else if prefix == "h" || prefix == "h2" {
333        format!("__{prefix}{base_name}")
334    } else {
335        panic!("Unknown prefix '{prefix}'");
336    }
337}
338
339// `isnan` / `isinf` are defined for cuda/hip/metal with same prefixes for half/bf16 on cuda/hip
340pub struct IsNan;
341
342impl<D: Dialect> Unary<D> for IsNan {
343    fn format_scalar<Input: Component<D>>(
344        f: &mut std::fmt::Formatter<'_>,
345        input: Input,
346        _elem: Elem<D>,
347    ) -> std::fmt::Result {
348        // Format unary function name based on *input* elem dtype
349        let elem = input.elem();
350        write!(f, "{}({input})", elem_function_name("isnan", elem))
351    }
352
353    fn can_optimize() -> bool {
354        true
355    }
356}
357
358pub struct IsInf;
359
360impl<D: Dialect> Unary<D> for IsInf {
361    fn format_scalar<Input: Component<D>>(
362        f: &mut std::fmt::Formatter<'_>,
363        input: Input,
364        _elem: Elem<D>,
365    ) -> std::fmt::Result {
366        // Format unary function name based on *input* elem dtype
367        let elem = input.elem();
368        write!(f, "{}({input})", elem_function_name("isinf", elem))
369    }
370
371    fn can_optimize() -> bool {
372        true
373    }
374}