cubecl_cpp/shared/
unary.rs

1use super::{Component, Dialect, Elem, FmtLeft, Variable};
2use std::fmt::Display;
3
4pub trait Unary<D: Dialect> {
5    fn format(
6        f: &mut std::fmt::Formatter<'_>,
7        input: &Variable<D>,
8        out: &Variable<D>,
9    ) -> std::fmt::Result {
10        let item = out.item();
11
12        if item.vectorization == 1 {
13            write!(f, "{} = ", out.fmt_left())?;
14            Self::format_scalar(f, *input, item.elem)?;
15            f.write_str(";\n")
16        } else {
17            Self::unroll_vec(f, input, out, item.elem, item.vectorization)
18        }
19    }
20
21    fn format_scalar<Input: Component<D>>(
22        f: &mut std::fmt::Formatter<'_>,
23        input: Input,
24        elem: Elem<D>,
25    ) -> std::fmt::Result;
26
27    fn unroll_vec(
28        f: &mut std::fmt::Formatter<'_>,
29        input: &Variable<D>,
30        out: &Variable<D>,
31        elem: Elem<D>,
32        index: usize,
33    ) -> std::fmt::Result {
34        let mut write_op = |index, elem, input: &Variable<D>, out: &Variable<D>| {
35            let out_item = out.item();
36            let out = out.fmt_left();
37            writeln!(f, "{out} = {out_item}{{")?;
38
39            for i in 0..index {
40                let inputi = input.index(i);
41
42                Self::format_scalar(f, inputi, elem)?;
43                f.write_str(",")?;
44            }
45
46            f.write_str("};\n")
47        };
48
49        if Self::can_optimize() {
50            let optimized = Variable::optimized_args([*input, *out]);
51            let [input, out_optimized] = optimized.args;
52
53            let item_out_original = out.item();
54            let item_out_optimized = out_optimized.item();
55
56            let (index, elem) = match optimized.optimization_factor {
57                Some(factor) => (index / factor, out_optimized.elem()),
58                None => (index, elem),
59            };
60
61            if item_out_original != item_out_optimized {
62                let out_tmp = Variable::tmp(item_out_optimized);
63
64                write_op(index, elem, &input, &out_tmp)?;
65                let qualifier = out.const_qualifier();
66                let out_fmt = out.fmt_left();
67                writeln!(
68                    f,
69                    "{out_fmt} = reinterpret_cast<{item_out_original}{qualifier}&>({out_tmp});\n"
70                )
71            } else {
72                write_op(index, elem, &input, &out_optimized)
73            }
74        } else {
75            write_op(index, elem, input, out)
76        }
77    }
78
79    fn can_optimize() -> bool {
80        true
81    }
82}
83
84pub trait FunctionFmt<D: Dialect> {
85    fn base_function_name() -> &'static str;
86    fn function_name(elem: Elem<D>) -> String {
87        if Self::half_support() {
88            match elem {
89                Elem::F16 | Elem::BF16 => return format!("h{}", Self::base_function_name()),
90                Elem::F162 | Elem::BF162 => return format!("h2{}", Self::base_function_name()),
91                _ => (),
92            };
93        }
94
95        Self::base_function_name().into()
96    }
97    fn format_unary<Input: Display>(
98        f: &mut std::fmt::Formatter<'_>,
99        input: Input,
100        elem: Elem<D>,
101    ) -> std::fmt::Result {
102        if Self::half_support() {
103            return write!(f, "{}({input})", Self::function_name(elem));
104        }
105
106        match elem {
107            Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
108                write!(f, "{}({}(float({input})))", elem, Self::function_name(elem))
109            }
110            _ => write!(f, "{}({input})", Self::function_name(elem)),
111        }
112    }
113
114    fn half_support() -> bool;
115}
116
// Declares a unit struct `$op` that implements both `FunctionFmt` and `Unary` for
// every dialect, formatting the operation as a call to the function named `$call`.
//
// The optional third argument states whether the function has half-precision
// variants (defaults to `true`); the same flag is returned from `can_optimize`.
macro_rules! function {
    ($op:ident, $call:expr) => {
        function!($op, $call, true);
    };
    ($op:ident, $call:expr, $half:expr) => {
        pub struct $op;

        impl<D: Dialect> FunctionFmt<D> for $op {
            fn base_function_name() -> &'static str {
                $call
            }

            fn half_support() -> bool {
                $half
            }
        }

        impl<D: Dialect> Unary<D> for $op {
            fn format_scalar<Input>(
                f: &mut std::fmt::Formatter<'_>,
                input: Input,
                elem: Elem<D>,
            ) -> std::fmt::Result
            where
                Input: Display,
            {
                Self::format_unary(f, input, elem)
            }

            fn can_optimize() -> bool {
                $half
            }
        }
    };
}
148
149function!(Log, "log");
150function!(Log1p, "log1p", false);
151function!(Cos, "cos");
152function!(Sin, "sin");
153function!(Sqrt, "sqrt");
154function!(Exp, "exp");
155function!(Ceil, "ceil");
156function!(Floor, "floor");
157function!(Round, "rint");
158
159function!(Tanh, "tanh", false);
160function!(Erf, "erf", false);
161function!(Abs, "abs", false);
162
163fn zero_extend<D: Dialect>(input: impl Component<D>) -> String {
164    match input.elem() {
165        Elem::I8 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U8),
166        Elem::I16 => format!("{}({}({input}))", Elem::<D>::U32, Elem::<D>::U16),
167        Elem::U8 => format!("{}({input})", Elem::<D>::U32),
168        Elem::U16 => format!("{}({input})", Elem::<D>::U32),
169        _ => unreachable!("zero extend only supports integer < 32 bits"),
170    }
171}
172
173pub struct CountBits;
174
175impl<D: Dialect> Unary<D> for CountBits {
176    fn format_scalar<Input: Component<D>>(
177        f: &mut std::fmt::Formatter<'_>,
178        input: Input,
179        _elem: Elem<D>,
180    ) -> std::fmt::Result {
181        match input.elem() {
182            Elem::I32 | Elem::U32 => write!(f, "__popc({input})"),
183            Elem::I64 | Elem::U64 => write!(f, "__popcll({input})"),
184            _ => write!(f, "__popc({})", zero_extend(input)),
185        }
186    }
187}
188
189pub struct ReverseBits;
190
191impl<D: Dialect> Unary<D> for ReverseBits {
192    fn format_scalar<Input: Component<D>>(
193        f: &mut std::fmt::Formatter<'_>,
194        input: Input,
195        elem: Elem<D>,
196    ) -> std::fmt::Result {
197        match elem {
198            Elem::I32 | Elem::U32 => write!(f, "__brev({input})"),
199            Elem::I64 | Elem::U64 => write!(f, "__brevll({input})"),
200            _ => write!(
201                f,
202                "{elem}(__brev({}) >> {})",
203                zero_extend(input),
204                (size_of::<u32>() - elem.size()) * 8
205            ),
206        }
207    }
208}
209
210pub struct Not;
211
212impl<D: Dialect> Unary<D> for Not {
213    fn format_scalar<Input>(
214        f: &mut std::fmt::Formatter<'_>,
215        input: Input,
216        _elem: Elem<D>,
217    ) -> std::fmt::Result
218    where
219        Input: Component<D>,
220    {
221        write!(f, "!{input}")
222    }
223}
224
225pub struct Assign;
226
227impl<D: Dialect> Unary<D> for Assign {
228    fn format(
229        f: &mut std::fmt::Formatter<'_>,
230        input: &Variable<D>,
231        out: &Variable<D>,
232    ) -> std::fmt::Result {
233        let item = out.item();
234
235        if item.vectorization == 1 || input.item() == item {
236            write!(f, "{} = ", out.fmt_left())?;
237            Self::format_scalar(f, *input, item.elem)?;
238            f.write_str(";\n")
239        } else {
240            Self::unroll_vec(f, input, out, item.elem, item.vectorization)
241        }
242    }
243
244    fn format_scalar<Input>(
245        f: &mut std::fmt::Formatter<'_>,
246        input: Input,
247        elem: Elem<D>,
248    ) -> std::fmt::Result
249    where
250        Input: Component<D>,
251    {
252        // Cast only when necessary.
253        if elem != input.elem() {
254            match elem {
255                Elem::TF32 => write!(f, "nvcuda::wmma::__float_to_tf32({input})"),
256                elem => write!(f, "{elem}({input})"),
257            }
258        } else {
259            write!(f, "{input}")
260        }
261    }
262}