1use half::{bf16, f16};
2
3use crate::{
4 flex32,
5 frontend::CubeContext,
6 ir::Operator,
7 prelude::{CubePrimitive, ExpandElement, ExpandElementTyped},
8 tf32, unexpanded,
9};
10
11use super::base::{unary_expand, unary_expand_fixed_output};
12
13pub mod not {
14 use super::*;
15
16 pub fn expand(
17 context: &mut CubeContext,
18 x: ExpandElementTyped<bool>,
19 ) -> ExpandElementTyped<bool> {
20 unary_expand(context, x.into(), Operator::Not).into()
21 }
22}
23
24pub mod neg {
25 use crate::prelude::Numeric;
26
27 use super::*;
28
29 pub fn expand<N: Numeric>(
30 context: &mut CubeContext,
31 x: ExpandElementTyped<N>,
32 ) -> ExpandElementTyped<N> {
33 unary_expand(context, x.into(), Operator::Neg).into()
34 }
35}
36
// Declares a public trait for a unary operation whose output has the same
// type as its input, then implements it (with an empty body, i.e. using the
// provided defaults) for every listed type.
//
// Arguments, in order: trait name, method name, name of the expand method,
// the operator expression handed to `unary_expand`, and the implementing types.
macro_rules! impl_unary_func {
    ($trait:ident, $func:ident, $func_expand:ident, $op:expr, $($implementor:ty),*) => {
        pub trait $trait: CubePrimitive + Sized {
            /// Default body is the `unexpanded!()` placeholder; the real work
            /// is done by the companion expand method.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            /// Passes `x` and the operator to `unary_expand` and returns the
            /// result typed as `Self`.
            fn $func_expand(
                context: &mut CubeContext,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                let output = unary_expand(context, x.into(), $op);
                output.into()
            }
        }

        $(impl $trait for $implementor {})*
    };
}
53
// Same shape as `impl_unary_func!`, except that the output item is the input
// item with its `vectorization` field overwritten by the supplied expression.
//
// Arguments, in order: trait name, method name, name of the expand method,
// the operator expression, the output vectorization expression, and the
// implementing types.
macro_rules! impl_unary_func_fixed_out_vectorization {
    ($trait:ident, $func:ident, $func_expand:ident, $op:expr, $vectorization: expr, $($implementor:ty),*) => {
        pub trait $trait: CubePrimitive + Sized {
            /// Default body is the `unexpanded!()` placeholder; the real work
            /// is done by the companion expand method.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            /// Copies the input's item, replaces its vectorization, and hands
            /// both to `unary_expand_fixed_output`.
            fn $func_expand(
                context: &mut CubeContext,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                let input: ExpandElement = x.into();

                // Start from the input's item so every other field is kept.
                let mut out_item = input.item;
                out_item.vectorization = $vectorization;

                let output = unary_expand_fixed_output(context, input, out_item, $op);
                output.into()
            }
        }

        $(impl $trait for $implementor {})*
    };
}
73
// Same shape as `impl_unary_func!`, except that the operation returns a fixed
// output type: the output item is the input item with its `elem` field
// replaced by the element of that output type.
//
// Arguments, in order: trait name, method name, name of the expand method,
// the output type, the operator expression, and the implementing types.
macro_rules! impl_unary_func_fixed_out_ty {
    ($trait:ident, $func:ident, $func_expand:ident, $output: ty, $op:expr, $($implementor:ty),*) => {
        pub trait $trait: CubePrimitive + Sized {
            /// Default body is the `unexpanded!()` placeholder; the real work
            /// is done by the companion expand method.
            #[allow(unused_variables)]
            fn $func(x: Self) -> $output {
                unexpanded!()
            }

            /// Copies the input's item, swaps in the output type's element,
            /// and hands both to `unary_expand_fixed_output`.
            fn $func_expand(
                context: &mut CubeContext,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<$output> {
                let input: ExpandElement = x.into();
                let out_elem = <$output as CubePrimitive>::as_elem(context);

                // Start from the input's item so every other field is kept.
                let mut out_item = input.item;
                out_item.elem = out_elem;

                let output = unary_expand_fixed_output(context, input, out_item, $op);
                output.into()
            }
        }

        $(impl $trait for $implementor {})*
    };
}
93
94impl_unary_func!(
95 Abs,
96 abs,
97 __expand_abs,
98 Operator::Abs,
99 f16,
100 bf16,
101 flex32,
102 tf32,
103 f32,
104 f64,
105 i8,
106 i16,
107 i32,
108 i64,
109 u8,
110 u16,
111 u32,
112 u64
113);
114impl_unary_func!(
115 Exp,
116 exp,
117 __expand_exp,
118 Operator::Exp,
119 f16,
120 bf16,
121 flex32,
122 tf32,
123 f32,
124 f64
125);
126impl_unary_func!(
127 Log,
128 log,
129 __expand_log,
130 Operator::Log,
131 f16,
132 bf16,
133 flex32,
134 tf32,
135 f32,
136 f64
137);
138impl_unary_func!(
139 Log1p,
140 log1p,
141 __expand_log1p,
142 Operator::Log1p,
143 f16,
144 bf16,
145 flex32,
146 tf32,
147 f32,
148 f64
149);
150impl_unary_func!(
151 Cos,
152 cos,
153 __expand_cos,
154 Operator::Cos,
155 f16,
156 bf16,
157 flex32,
158 tf32,
159 f32,
160 f64
161);
162impl_unary_func!(
163 Sin,
164 sin,
165 __expand_sin,
166 Operator::Sin,
167 f16,
168 bf16,
169 flex32,
170 tf32,
171 f32,
172 f64
173);
174impl_unary_func!(
175 Tanh,
176 tanh,
177 __expand_tanh,
178 Operator::Tanh,
179 f16,
180 bf16,
181 flex32,
182 tf32,
183 f32,
184 f64
185);
186impl_unary_func!(
187 Sqrt,
188 sqrt,
189 __expand_sqrt,
190 Operator::Sqrt,
191 f16,
192 bf16,
193 flex32,
194 tf32,
195 f32,
196 f64
197);
198impl_unary_func!(
199 Round,
200 round,
201 __expand_round,
202 Operator::Round,
203 f16,
204 bf16,
205 flex32,
206 tf32,
207 f32,
208 f64
209);
210impl_unary_func!(
211 Floor,
212 floor,
213 __expand_floor,
214 Operator::Floor,
215 f16,
216 bf16,
217 flex32,
218 tf32,
219 f32,
220 f64
221);
222impl_unary_func!(
223 Ceil,
224 ceil,
225 __expand_ceil,
226 Operator::Ceil,
227 f16,
228 bf16,
229 flex32,
230 tf32,
231 f32,
232 f64
233);
234impl_unary_func!(
235 Erf,
236 erf,
237 __expand_erf,
238 Operator::Erf,
239 f16,
240 bf16,
241 flex32,
242 tf32,
243 f32,
244 f64
245);
246impl_unary_func!(
247 Recip,
248 recip,
249 __expand_recip,
250 Operator::Recip,
251 f16,
252 bf16,
253 flex32,
254 tf32,
255 f32,
256 f64
257);
258impl_unary_func_fixed_out_vectorization!(
259 Magnitude,
260 magnitude,
261 __expand_magnitude,
262 Operator::Magnitude,
263 None,
264 f16,
265 bf16,
266 flex32,
267 tf32,
268 f32,
269 f64
270);
271impl_unary_func!(
272 Normalize,
273 normalize,
274 __expand_normalize,
275 Operator::Normalize,
276 f16,
277 bf16,
278 flex32,
279 tf32,
280 f32,
281 f64
282);
283impl_unary_func_fixed_out_ty!(
284 CountOnes,
285 count_ones,
286 __expand_count_ones,
287 u32,
288 Operator::CountOnes,
289 u8,
290 i8,
291 u16,
292 i16,
293 u32,
294 i32,
295 u64,
296 i64
297);
298impl_unary_func!(
299 ReverseBits,
300 reverse_bits,
301 __expand_reverse_bits,
302 Operator::ReverseBits,
303 u8,
304 i8,
305 u16,
306 i16,
307 u32,
308 i32,
309 u64,
310 i64
311);
312impl_unary_func!(
313 Not,
314 not,
315 __expand_not,
316 Operator::Not,
317 u8,
318 i8,
319 u16,
320 i16,
321 u32,
322 i32,
323 u64,
324 i64
325);