1use cubecl_ir::{Bitwise, Operator};
2use half::{bf16, f16};
3
4use crate::{
5 flex32,
6 ir::{Arithmetic, ExpandElement, Scope},
7 prelude::{CubePrimitive, ExpandElementTyped},
8 tf32, unexpanded,
9};
10
11use super::base::{unary_expand, unary_expand_fixed_output};
12
13pub mod not {
14 use super::*;
15
16 pub fn expand(scope: &mut Scope, x: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
17 unary_expand(scope, x.into(), Operator::Not).into()
18 }
19}
20
21pub mod neg {
22 use super::*;
23
24 pub fn expand<E: CubePrimitive>(
25 scope: &mut Scope,
26 x: ExpandElementTyped<E>,
27 ) -> ExpandElementTyped<E> {
28 unary_expand(scope, x.into(), Arithmetic::Neg).into()
29 }
30}
31
// Declares a public unary-operation trait and implements it for each listed type.
//
// Arguments: trait name, plain method name, expand method name, the operator
// constructor handed to `unary_expand`, then the implementing types.
// The expanded result has the same item (element type and vectorization) as the
// input, since `unary_expand` is called without an explicit output item.
macro_rules! impl_unary_func {
    ($name:ident, $func:ident, $func_expand:ident, $op:expr, $($ty:ty),*) => {
        #[doc = concat!("Unary `", stringify!($func), "` operation on cube primitives.")]
        pub trait $name: CubePrimitive + Sized {
            // The body is `unexpanded!()`; presumably only the expand form is
            // meant to run (during kernel expansion) — confirm.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            #[doc = concat!("Expansion of `", stringify!($func), "`: lowers `x` via `unary_expand`.")]
            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let input: ExpandElement = x.into();
                let output = unary_expand(scope, input, $op);
                output.into()
            }
        }

        $(impl $name for $ty {})*
    };
}
48
// Like `impl_unary_func`, but the output item's vectorization is overridden with
// a fixed value (`$out_vec`). The element type is still taken from the input.
//
// Arguments: trait name, plain method name, expand method name, operator
// constructor, output vectorization, then the implementing types.
macro_rules! impl_unary_func_fixed_out_vectorization {
    ($name:ident, $func:ident, $func_expand:ident, $op:expr, $out_vec:expr, $($ty:ty),*) => {
        #[doc = concat!("Unary `", stringify!($func), "` operation with a fixed output vectorization.")]
        pub trait $name: CubePrimitive + Sized {
            // The body is `unexpanded!()`; presumably only the expand form is
            // meant to run (during kernel expansion) — confirm.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            #[doc = concat!("Expansion of `", stringify!($func), "`: input element type, fixed vectorization.")]
            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let input: ExpandElement = x.into();
                // Copy the input item and only replace its vectorization.
                let mut out_item = input.item;
                out_item.vectorization = $out_vec;
                unary_expand_fixed_output(scope, input, out_item, $op).into()
            }
        }

        $(impl $name for $ty {})*
    };
}
68
// Like `impl_unary_func`, but the result has a fixed element type (`$out`).
// The input's vectorization is carried over unchanged; only the element type is
// replaced with `<$out as CubePrimitive>::as_elem(scope)`.
//
// Arguments: trait name, plain method name, expand method name, output type,
// operator constructor, then the implementing types.
macro_rules! impl_unary_func_fixed_out_ty {
    ($name:ident, $func:ident, $func_expand:ident, $out:ty, $op:expr, $($ty:ty),*) => {
        #[doc = concat!("Unary `", stringify!($func), "` operation producing `", stringify!($out), "`.")]
        pub trait $name: CubePrimitive + Sized {
            // The body is `unexpanded!()`; presumably only the expand form is
            // meant to run (during kernel expansion) — confirm.
            #[allow(unused_variables)]
            fn $func(x: Self) -> $out {
                unexpanded!()
            }

            #[doc = concat!("Expansion of `", stringify!($func), "`: input vectorization, `", stringify!($out), "` elements.")]
            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<$out> {
                let input: ExpandElement = x.into();
                // Copy the input item and only replace its element type.
                let mut out_item = input.item;
                out_item.elem = <$out as CubePrimitive>::as_elem(scope);
                unary_expand_fixed_output(scope, input, out_item, $op).into()
            }
        }

        $(impl $name for $ty {})*
    };
}
88
89impl_unary_func!(
90 Abs,
91 abs,
92 __expand_abs,
93 Arithmetic::Abs,
94 f16,
95 bf16,
96 flex32,
97 tf32,
98 f32,
99 f64,
100 i8,
101 i16,
102 i32,
103 i64,
104 u8,
105 u16,
106 u32,
107 u64
108);
109impl_unary_func!(
110 Exp,
111 exp,
112 __expand_exp,
113 Arithmetic::Exp,
114 f16,
115 bf16,
116 flex32,
117 tf32,
118 f32,
119 f64
120);
121impl_unary_func!(
122 Log,
123 log,
124 __expand_log,
125 Arithmetic::Log,
126 f16,
127 bf16,
128 flex32,
129 tf32,
130 f32,
131 f64
132);
133impl_unary_func!(
134 Log1p,
135 log1p,
136 __expand_log1p,
137 Arithmetic::Log1p,
138 f16,
139 bf16,
140 flex32,
141 tf32,
142 f32,
143 f64
144);
145impl_unary_func!(
146 Cos,
147 cos,
148 __expand_cos,
149 Arithmetic::Cos,
150 f16,
151 bf16,
152 flex32,
153 tf32,
154 f32,
155 f64
156);
157impl_unary_func!(
158 Sin,
159 sin,
160 __expand_sin,
161 Arithmetic::Sin,
162 f16,
163 bf16,
164 flex32,
165 tf32,
166 f32,
167 f64
168);
169impl_unary_func!(
170 Tanh,
171 tanh,
172 __expand_tanh,
173 Arithmetic::Tanh,
174 f16,
175 bf16,
176 flex32,
177 tf32,
178 f32,
179 f64
180);
181impl_unary_func!(
182 Sqrt,
183 sqrt,
184 __expand_sqrt,
185 Arithmetic::Sqrt,
186 f16,
187 bf16,
188 flex32,
189 tf32,
190 f32,
191 f64
192);
193impl_unary_func!(
194 Round,
195 round,
196 __expand_round,
197 Arithmetic::Round,
198 f16,
199 bf16,
200 flex32,
201 tf32,
202 f32,
203 f64
204);
205impl_unary_func!(
206 Floor,
207 floor,
208 __expand_floor,
209 Arithmetic::Floor,
210 f16,
211 bf16,
212 flex32,
213 tf32,
214 f32,
215 f64
216);
217impl_unary_func!(
218 Ceil,
219 ceil,
220 __expand_ceil,
221 Arithmetic::Ceil,
222 f16,
223 bf16,
224 flex32,
225 tf32,
226 f32,
227 f64
228);
229impl_unary_func!(
230 Erf,
231 erf,
232 __expand_erf,
233 Arithmetic::Erf,
234 f16,
235 bf16,
236 flex32,
237 tf32,
238 f32,
239 f64
240);
241impl_unary_func!(
242 Recip,
243 recip,
244 __expand_recip,
245 Arithmetic::Recip,
246 f16,
247 bf16,
248 flex32,
249 tf32,
250 f32,
251 f64
252);
253impl_unary_func_fixed_out_vectorization!(
254 Magnitude,
255 magnitude,
256 __expand_magnitude,
257 Arithmetic::Magnitude,
258 None,
259 f16,
260 bf16,
261 flex32,
262 tf32,
263 f32,
264 f64
265);
266impl_unary_func!(
267 Normalize,
268 normalize,
269 __expand_normalize,
270 Arithmetic::Normalize,
271 f16,
272 bf16,
273 flex32,
274 tf32,
275 f32,
276 f64
277);
278impl_unary_func_fixed_out_ty!(
279 CountOnes,
280 count_ones,
281 __expand_count_ones,
282 u32,
283 Bitwise::CountOnes,
284 u8,
285 i8,
286 u16,
287 i16,
288 u32,
289 i32,
290 u64,
291 i64
292);
293impl_unary_func!(
294 ReverseBits,
295 reverse_bits,
296 __expand_reverse_bits,
297 Bitwise::ReverseBits,
298 u8,
299 i8,
300 u16,
301 i16,
302 u32,
303 i32,
304 u64,
305 i64
306);
307
308impl_unary_func!(
309 BitwiseNot,
310 bitwise_not,
311 __expand_bitwise_not,
312 Bitwise::BitwiseNot,
313 u8,
314 i8,
315 u16,
316 i16,
317 u32,
318 i32,
319 u64,
320 i64
321);
322impl_unary_func_fixed_out_ty!(
323 LeadingZeros,
324 leading_zeros,
325 __expand_leading_zeros,
326 u32,
327 Bitwise::LeadingZeros,
328 u8,
329 i8,
330 u16,
331 i16,
332 u32,
333 i32,
334 u64,
335 i64
336);
337impl_unary_func_fixed_out_ty!(
338 FindFirstSet,
339 find_first_set,
340 __expand_find_first_set,
341 u32,
342 Bitwise::FindFirstSet,
343 u8,
344 i8,
345 u16,
346 i16,
347 u32,
348 i32,
349 u64,
350 i64
351);