cubecl_cpp/shared/
unary.rs

1use super::{Component, Dialect, Elem, FmtLeft, Variable};
2use std::fmt::Display;
3
4pub trait Unary<D: Dialect> {
5    fn format(
6        f: &mut std::fmt::Formatter<'_>,
7        input: &Variable<D>,
8        out: &Variable<D>,
9    ) -> std::fmt::Result {
10        let out_item = out.item();
11
12        if out_item.vectorization == 1 {
13            write!(f, "{} = ", out.fmt_left())?;
14            Self::format_scalar(f, *input, out_item.elem)?;
15            f.write_str(";\n")
16        } else {
17            Self::unroll_vec(f, input, out, out_item.elem, out_item.vectorization)
18        }
19    }
20
21    fn format_scalar<Input: Component<D>>(
22        f: &mut std::fmt::Formatter<'_>,
23        input: Input,
24        out_elem: Elem<D>,
25    ) -> std::fmt::Result;
26
27    fn unroll_vec(
28        f: &mut std::fmt::Formatter<'_>,
29        input: &Variable<D>,
30        out: &Variable<D>,
31        out_elem: Elem<D>,
32        index: usize,
33    ) -> std::fmt::Result {
34        let mut write_op = |index, out_elem, input: &Variable<D>, out: &Variable<D>| {
35            let out_item = out.item();
36            let out = out.fmt_left();
37            writeln!(f, "{out} = {out_item}{{")?;
38
39            for i in 0..index {
40                let inputi = input.index(i);
41
42                Self::format_scalar(f, inputi, out_elem)?;
43                f.write_str(",")?;
44            }
45
46            f.write_str("};\n")
47        };
48
49        if Self::can_optimize() {
50            let optimized = Variable::optimized_args([*input, *out]);
51            let [input, out_optimized] = optimized.args;
52
53            let item_out_original = out.item();
54            let item_out_optimized = out_optimized.item();
55
56            let (index, out_elem) = match optimized.optimization_factor {
57                Some(factor) => (index / factor, out_optimized.elem()),
58                None => (index, out_elem),
59            };
60
61            if item_out_original != item_out_optimized {
62                let out_tmp = Variable::tmp(item_out_optimized);
63
64                write_op(index, out_elem, &input, &out_tmp)?;
65                let qualifier = out.const_qualifier();
66                let addr_space = D::address_space_for_variable(out);
67                let out_fmt = out.fmt_left();
68                writeln!(
69                    f,
70                    "{out_fmt} = reinterpret_cast<{addr_space}{item_out_original}{qualifier}&>({out_tmp});\n"
71                )
72            } else {
73                write_op(index, out_elem, &input, &out_optimized)
74            }
75        } else {
76            write_op(index, out_elem, input, out)
77        }
78    }
79
80    fn can_optimize() -> bool {
81        true
82    }
83}
84
85pub trait FunctionFmt<D: Dialect> {
86    fn base_function_name() -> &'static str;
87    fn function_name(elem: Elem<D>) -> String {
88        if Self::half_support() {
89            let prefix = match elem {
90                Elem::F16 | Elem::BF16 => D::compile_instruction_half_function_name_prefix(),
91                Elem::F162 | Elem::BF162 => D::compile_instruction_half2_function_name_prefix(),
92                _ => "",
93            };
94            format!("{prefix}{}", Self::base_function_name())
95        } else {
96            Self::base_function_name().into()
97        }
98    }
99    fn format_unary<Input: Display>(
100        f: &mut std::fmt::Formatter<'_>,
101        input: Input,
102        elem: Elem<D>,
103    ) -> std::fmt::Result {
104        if Self::half_support() {
105            write!(f, "{}({input})", Self::function_name(elem))
106        } else {
107            match elem {
108                Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
109                    write!(f, "{}({}(float({input})))", elem, Self::function_name(elem))
110                }
111                _ => write!(f, "{}({input})", Self::function_name(elem)),
112            }
113        }
114    }
115
116    fn half_support() -> bool;
117}
118
// Declares a public unit struct and implements both `FunctionFmt` and `Unary`
// for it, so the operation is emitted as a call to the given function name.
macro_rules! function {
    // Full form: struct name, function name, and whether half element types
    // are supported. The flag is also what `can_optimize` returns.
    ($ty:ident, $fn_name:expr, $half:expr) => {
        pub struct $ty;

        impl<D: Dialect> FunctionFmt<D> for $ty {
            fn base_function_name() -> &'static str {
                $fn_name
            }
            fn half_support() -> bool {
                $half
            }
        }

        impl<D: Dialect> Unary<D> for $ty {
            fn format_scalar<Input: Display>(
                f: &mut std::fmt::Formatter<'_>,
                input: Input,
                elem: Elem<D>,
            ) -> std::fmt::Result {
                Self::format_unary(f, input, elem)
            }

            fn can_optimize() -> bool {
                $half
            }
        }
    };
    // Shorthand: half support defaults to `true`.
    ($ty:ident, $fn_name:expr) => {
        function!($ty, $fn_name, true);
    };
}
150
151function!(Log, "log");
152function!(Cos, "cos");
153function!(Sin, "sin");
154function!(Sqrt, "sqrt");
155function!(Exp, "exp");
156function!(Ceil, "ceil");
157function!(Floor, "floor");
158function!(Round, "rint");
159
160function!(Erf, "erf", false);
161function!(Abs, "abs", false);
162
163pub struct Log1p;
164
165impl<D: Dialect> Unary<D> for Log1p {
166    fn format_scalar<Input: Component<D>>(
167        f: &mut std::fmt::Formatter<'_>,
168        input: Input,
169        _out_elem: Elem<D>,
170    ) -> std::fmt::Result {
171        D::compile_instruction_log1p_scalar(f, input)
172    }
173
174    fn can_optimize() -> bool {
175        false
176    }
177}
178
179pub struct Tanh;
180
181impl<D: Dialect> Unary<D> for Tanh {
182    fn format_scalar<Input: Component<D>>(
183        f: &mut std::fmt::Formatter<'_>,
184        input: Input,
185        _out_elem: Elem<D>,
186    ) -> std::fmt::Result {
187        D::compile_instruction_tanh_scalar(f, input)
188    }
189
190    fn can_optimize() -> bool {
191        false
192    }
193}
194
195pub fn zero_extend<D: Dialect>(input: impl Component<D>) -> String {
196    match input.elem() {
197        Elem::I8 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U8),
198        Elem::I16 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U16),
199        Elem::U8 => format!("{}({input})", Elem::<D>::U32),
200        Elem::U16 => format!("{}({input})", Elem::<D>::U32),
201        _ => unreachable!("zero extend only supports integer < 32 bits"),
202    }
203}
204
205pub struct CountBits;
206
207impl<D: Dialect> Unary<D> for CountBits {
208    fn format_scalar<Input: Component<D>>(
209        f: &mut std::fmt::Formatter<'_>,
210        input: Input,
211        elem: Elem<D>,
212    ) -> std::fmt::Result {
213        D::compile_instruction_popcount_scalar(f, input, elem)
214    }
215}
216
217pub struct ReverseBits;
218
219impl<D: Dialect> Unary<D> for ReverseBits {
220    fn format_scalar<Input: Component<D>>(
221        f: &mut std::fmt::Formatter<'_>,
222        input: Input,
223        elem: Elem<D>,
224    ) -> std::fmt::Result {
225        D::compile_instruction_reverse_bits_scalar(f, input, elem)
226    }
227}
228
229pub struct LeadingZeros;
230
231impl<D: Dialect> Unary<D> for LeadingZeros {
232    fn format_scalar<Input: Component<D>>(
233        f: &mut std::fmt::Formatter<'_>,
234        input: Input,
235        elem: Elem<D>,
236    ) -> std::fmt::Result {
237        D::compile_instruction_leading_zeros_scalar(f, input, elem)
238    }
239}
240
241pub struct FindFirstSet;
242
243impl<D: Dialect> Unary<D> for FindFirstSet {
244    fn format_scalar<Input: Component<D>>(
245        f: &mut std::fmt::Formatter<'_>,
246        input: Input,
247        out_elem: Elem<D>,
248    ) -> std::fmt::Result {
249        D::compile_instruction_find_first_set(f, input, out_elem)
250    }
251}
252
253pub struct BitwiseNot;
254
255impl<D: Dialect> Unary<D> for BitwiseNot {
256    fn format_scalar<Input>(
257        f: &mut std::fmt::Formatter<'_>,
258        input: Input,
259        _out_elem: Elem<D>,
260    ) -> std::fmt::Result
261    where
262        Input: Component<D>,
263    {
264        write!(f, "~{input}")
265    }
266}
267
268pub struct Not;
269
270impl<D: Dialect> Unary<D> for Not {
271    fn format_scalar<Input>(
272        f: &mut std::fmt::Formatter<'_>,
273        input: Input,
274        _out_elem: Elem<D>,
275    ) -> std::fmt::Result
276    where
277        Input: Component<D>,
278    {
279        write!(f, "!{input}")
280    }
281}
282
283pub struct Assign;
284
285impl<D: Dialect> Unary<D> for Assign {
286    fn format(
287        f: &mut std::fmt::Formatter<'_>,
288        input: &Variable<D>,
289        out: &Variable<D>,
290    ) -> std::fmt::Result {
291        let item = out.item();
292
293        if item.vectorization == 1 || input.item() == item {
294            write!(f, "{} = ", out.fmt_left())?;
295            Self::format_scalar(f, *input, item.elem)?;
296            f.write_str(";\n")
297        } else {
298            Self::unroll_vec(f, input, out, item.elem, item.vectorization)
299        }
300    }
301
302    fn format_scalar<Input>(
303        f: &mut std::fmt::Formatter<'_>,
304        input: Input,
305        elem: Elem<D>,
306    ) -> std::fmt::Result
307    where
308        Input: Component<D>,
309    {
310        // Cast only when necessary.
311        if elem != input.elem() {
312            match elem {
313                Elem::TF32 => write!(f, "nvcuda::wmma::__float_to_tf32({input})"),
314                elem => write!(f, "{elem}({input})"),
315            }
316        } else {
317            write!(f, "{input}")
318        }
319    }
320}