cubecl_core/frontend/container/line/
ops.rs1use core::ops::Not;
2use cubecl_ir::{Bitwise, ElemType, Instruction, Type, UIntKind, UnaryOperator};
3use cubecl_macros::{cube, intrinsic};
4use num_traits::{NumCast, ToPrimitive};
5
6use crate::{
7 self as cubecl,
8 prelude::{
9 ArcTan2, InverseSqrt, IsInf, IsNan, Powf, Powi, SaturatingAdd, SaturatingSub, Trunc,
10 },
11};
12use crate::{
13 frontend::{
14 Abs, ArcCos, ArcCosh, ArcSin, ArcSinh, ArcTan, ArcTanh, Ceil, Cos, Cosh, CubeNot,
15 CubePrimitive, Erf, Exp, ExpandElementTyped, Floor, Log, Log1p, Recip, Remainder, Round,
16 Sin, Sinh, Sqrt, Tan, Tanh,
17 },
18 prelude::{CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
19 unexpanded,
20};
21
22use super::Line;
23type LineExpand<E> = ExpandElementTyped<Line<E>>;
24
25impl<P> core::ops::Add<Self> for Line<P>
26where
27 P: CubePrimitive,
28 P: core::ops::Add<P, Output = P>,
29{
30 type Output = Self;
31
32 fn add(self, rhs: Self) -> Self::Output {
33 Self::new(self.val + rhs.val)
34 }
35}
36
37impl<P> core::ops::Sub<Self> for Line<P>
38where
39 P: CubePrimitive,
40 P: core::ops::Sub<P, Output = P>,
41{
42 type Output = Self;
43
44 fn sub(self, rhs: Self) -> Self::Output {
45 Self::new(self.val - rhs.val)
46 }
47}
48
49impl<P> core::ops::Mul<Self> for Line<P>
50where
51 P: CubePrimitive,
52 P: core::ops::Mul<P, Output = P>,
53{
54 type Output = Self;
55
56 fn mul(self, rhs: Self) -> Self::Output {
57 Self::new(self.val * rhs.val)
58 }
59}
60
61impl<P> core::ops::Div<Self> for Line<P>
62where
63 P: CubePrimitive,
64 P: core::ops::Div<P, Output = P>,
65{
66 type Output = Self;
67
68 fn div(self, rhs: Self) -> Self::Output {
69 Self::new(self.val / rhs.val)
70 }
71}
72
73impl<P> core::ops::AddAssign<Self> for Line<P>
74where
75 P: CubePrimitive,
76 P: core::ops::AddAssign,
77{
78 fn add_assign(&mut self, rhs: Self) {
79 self.val += rhs.val;
80 }
81}
82
83impl<P> core::ops::SubAssign<Self> for Line<P>
84where
85 P: CubePrimitive,
86 P: core::ops::SubAssign,
87{
88 fn sub_assign(&mut self, rhs: Self) {
89 self.val -= rhs.val;
90 }
91}
92
93impl<P> core::ops::DivAssign<Self> for Line<P>
94where
95 P: CubePrimitive,
96 P: core::ops::DivAssign,
97{
98 fn div_assign(&mut self, rhs: Self) {
99 self.val /= rhs.val;
100 }
101}
102
103impl<P> core::ops::MulAssign<Self> for Line<P>
104where
105 P: CubePrimitive,
106 P: core::ops::MulAssign,
107{
108 fn mul_assign(&mut self, rhs: Self) {
109 self.val *= rhs.val;
110 }
111}
112
113impl<P> core::cmp::PartialEq for Line<P>
114where
115 P: CubePrimitive,
116 P: core::cmp::PartialEq,
117{
118 fn eq(&self, other: &Self) -> bool {
119 self.val.eq(&other.val)
120 }
121}
122
123impl<P> core::cmp::PartialOrd for Line<P>
124where
125 P: CubePrimitive,
126 P: core::cmp::PartialOrd,
127{
128 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
129 self.val.partial_cmp(&other.val)
130 }
131}
132
133impl<P> core::ops::BitAnd<Self> for Line<P>
134where
135 P: CubePrimitive,
136 P: core::ops::BitAnd<P, Output = P>,
137{
138 type Output = Self;
139
140 fn bitand(self, rhs: Self) -> Self::Output {
141 Self::new(self.val & rhs.val)
142 }
143}
144
145impl<P> core::ops::BitOr<Self> for Line<P>
146where
147 P: CubePrimitive,
148 P: core::ops::BitOr<P, Output = P>,
149{
150 type Output = Self;
151
152 fn bitor(self, rhs: Self) -> Self::Output {
153 Self::new(self.val | rhs.val)
154 }
155}
156
157impl<P> core::ops::BitXor<Self> for Line<P>
158where
159 P: CubePrimitive,
160 P: core::ops::BitXor<P, Output = P>,
161{
162 type Output = Self;
163
164 fn bitxor(self, rhs: Self) -> Self::Output {
165 Self::new(self.val ^ rhs.val)
166 }
167}
168
169impl<P> core::ops::Shl<Self> for Line<P>
170where
171 P: CubePrimitive,
172 P: core::ops::Shl<P, Output = P>,
173{
174 type Output = Self;
175
176 fn shl(self, rhs: Self) -> Self::Output {
177 Self::new(self.val << rhs.val)
178 }
179}
180
181impl<P> core::ops::Shr<Self> for Line<P>
182where
183 P: CubePrimitive,
184 P: core::ops::Shr<P, Output = P>,
185{
186 type Output = Self;
187
188 fn shr(self, rhs: Self) -> Self::Output {
189 Self::new(self.val >> rhs.val)
190 }
191}
192
193impl<P> core::ops::BitAndAssign<Self> for Line<P>
194where
195 P: CubePrimitive,
196 P: core::ops::BitAndAssign,
197{
198 fn bitand_assign(&mut self, rhs: Self) {
199 self.val &= rhs.val;
200 }
201}
202
203impl<P> core::ops::BitOrAssign<Self> for Line<P>
204where
205 P: CubePrimitive,
206 P: core::ops::BitOrAssign,
207{
208 fn bitor_assign(&mut self, rhs: Self) {
209 self.val |= rhs.val;
210 }
211}
212
213impl<P> core::ops::BitXorAssign<Self> for Line<P>
214where
215 P: CubePrimitive,
216 P: core::ops::BitXorAssign,
217{
218 fn bitxor_assign(&mut self, rhs: Self) {
219 self.val ^= rhs.val;
220 }
221}
222
223impl<P> core::ops::ShlAssign<Self> for Line<P>
224where
225 P: CubePrimitive,
226 P: core::ops::ShlAssign,
227{
228 fn shl_assign(&mut self, rhs: Self) {
229 self.val <<= rhs.val;
230 }
231}
232
233impl<P> core::ops::ShrAssign<Self> for Line<P>
234where
235 P: CubePrimitive,
236 P: core::ops::ShrAssign,
237{
238 fn shr_assign(&mut self, rhs: Self) {
239 self.val >>= rhs.val;
240 }
241}
242
243impl<P: CubePrimitive + Abs> Abs for Line<P> {}
244impl<P: CubePrimitive + Log> Log for Line<P> {}
245impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
246impl<P: CubePrimitive + Erf> Erf for Line<P> {}
247impl<P: CubePrimitive + Exp> Exp for Line<P> {}
248impl<P: CubePrimitive + Powf> Powf for Line<P> {}
249impl<P: CubePrimitive + Powi<I>, I: CubePrimitive> Powi<Line<I>> for Line<P> {}
250impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
251impl<P: CubePrimitive + InverseSqrt> InverseSqrt for Line<P> {}
252impl<P: CubePrimitive + Cos> Cos for Line<P> {}
253impl<P: CubePrimitive + Sin> Sin for Line<P> {}
254impl<P: CubePrimitive + Tan> Tan for Line<P> {}
255impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
256impl<P: CubePrimitive + Sinh> Sinh for Line<P> {}
257impl<P: CubePrimitive + Cosh> Cosh for Line<P> {}
258impl<P: CubePrimitive + ArcSin> ArcSin for Line<P> {}
259impl<P: CubePrimitive + ArcCos> ArcCos for Line<P> {}
260impl<P: CubePrimitive + ArcTan> ArcTan for Line<P> {}
261impl<P: CubePrimitive + ArcSinh> ArcSinh for Line<P> {}
262impl<P: CubePrimitive + ArcCosh> ArcCosh for Line<P> {}
263impl<P: CubePrimitive + ArcTanh> ArcTanh for Line<P> {}
264impl<P: CubePrimitive + ArcTan2> ArcTan2 for Line<P> {}
265impl<P: CubePrimitive + Recip> Recip for Line<P> {}
266impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
267impl<P: CubePrimitive + Round> Round for Line<P> {}
268impl<P: CubePrimitive + Floor> Floor for Line<P> {}
269impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
270impl<P: CubePrimitive + Trunc> Trunc for Line<P> {}
271impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
272impl<P: CubePrimitive + CubeNot> CubeNot for Line<P> {}
273impl<P: CubePrimitive + SaturatingAdd> SaturatingAdd for Line<P> {}
274impl<P: CubePrimitive + SaturatingSub> SaturatingSub for Line<P> {}
275impl<P: CubePrimitive + IsNan> IsNan for Line<P> {}
276impl<P: CubePrimitive + IsInf> IsInf for Line<P> {}
277
278impl<P: CubePrimitive + Ord> Ord for Line<P> {
279 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
280 self.val.cmp(&other.val)
281 }
282}
283
284#[cube]
285impl<P: CountOnes + CubePrimitive> Line<P> {
286 pub fn count_ones(self) -> Line<u32> {
287 intrinsic!(|scope| {
288 let out_item =
289 Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
290 let out = scope.create_local(out_item);
291 scope.register(Instruction::new(
292 Bitwise::CountOnes(UnaryOperator {
293 input: *self.expand,
294 }),
295 *out,
296 ));
297 out.into()
298 })
299 }
300}
301
302#[cube]
303impl<P: LeadingZeros + CubePrimitive> Line<P> {
304 pub fn leading_zeros(self) -> Line<u32> {
305 intrinsic!(|scope| {
306 let out_item =
307 Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
308 let out = scope.create_local(out_item);
309 scope.register(Instruction::new(
310 Bitwise::LeadingZeros(UnaryOperator {
311 input: *self.expand,
312 }),
313 *out,
314 ));
315 out.into()
316 })
317 }
318}
319
320#[cube]
321impl<P: FindFirstSet + CubePrimitive> Line<P> {
322 pub fn find_first_set(self) -> Line<u32> {
323 intrinsic!(|scope| {
324 let out_item =
325 Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
326 let out = scope.create_local(out_item);
327 scope.register(Instruction::new(
328 Bitwise::FindFirstSet(UnaryOperator {
329 input: *self.expand,
330 }),
331 *out,
332 ));
333 out.into()
334 })
335 }
336}
337
338impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
339 fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
340 let val: P = NumCast::from(n)?;
341 Some(Self { val })
342 }
343}
344impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
345 fn to_i64(&self) -> Option<i64> {
346 self.val.to_i64()
347 }
348
349 fn to_u64(&self) -> Option<u64> {
350 self.val.to_u64()
351 }
352}
353
354impl<P: Not<Output = P> + CubePrimitive> Not for Line<P> {
355 type Output = Self;
356
357 fn not(self) -> Self::Output {
358 Line::new(self.val.not())
359 }
360}
361
362#[allow(clippy::from_over_into)]
363impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
364 fn into(self) -> ExpandElementTyped<Self> {
365 let elem: ExpandElementTyped<P> = self.val.into();
366 elem.expand.into()
367 }
368}
369
370macro_rules! operation_literal {
371 ($lit:ty) => {
372 impl<P> core::ops::Add<$lit> for Line<P>
373 where
374 P: CubePrimitive,
375 {
376 type Output = Self;
377
378 fn add(self, _rhs: $lit) -> Self::Output {
379 unexpanded!();
380 }
381 }
382
383 impl<P> core::ops::Sub<$lit> for Line<P>
384 where
385 P: CubePrimitive,
386 {
387 type Output = Self;
388
389 fn sub(self, _rhs: $lit) -> Self::Output {
390 unexpanded!();
391 }
392 }
393
394 impl<P> core::ops::Mul<$lit> for Line<P>
395 where
396 P: CubePrimitive,
397 {
398 type Output = Self;
399
400 fn mul(self, _rhs: $lit) -> Self::Output {
401 unexpanded!();
402 }
403 }
404
405 impl<P> core::ops::Div<$lit> for Line<P>
406 where
407 P: CubePrimitive,
408 {
409 type Output = Self;
410
411 fn div(self, _rhs: $lit) -> Self::Output {
412 unexpanded!();
413 }
414 }
415 };
416}
417
418operation_literal!(f32);
419operation_literal!(f64);
420operation_literal!(usize);
421operation_literal!(i32);
422operation_literal!(i64);