// cubecl_core/frontend/container/line/ops.rs

1use core::ops::Not;
2use cubecl_ir::{Bitwise, ElemType, Instruction, Type, UIntKind, UnaryOperator};
3use cubecl_macros::{cube, intrinsic};
4use num_traits::{NumCast, ToPrimitive};
5
6use crate::{
7    self as cubecl,
8    prelude::{
9        ArcTan2, InverseSqrt, IsInf, IsNan, Powf, Powi, SaturatingAdd, SaturatingSub, Trunc,
10    },
11};
12use crate::{
13    frontend::{
14        Abs, ArcCos, ArcCosh, ArcSin, ArcSinh, ArcTan, ArcTanh, Ceil, Cos, Cosh, CubeNot,
15        CubePrimitive, Erf, Exp, ExpandElementTyped, Floor, Log, Log1p, Recip, Remainder, Round,
16        Sin, Sinh, Sqrt, Tan, Tanh,
17    },
18    prelude::{CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
19    unexpanded,
20};
21
22use super::Line;
23type LineExpand<E> = ExpandElementTyped<Line<E>>;
24
25impl<P> core::ops::Add<Self> for Line<P>
26where
27    P: CubePrimitive,
28    P: core::ops::Add<P, Output = P>,
29{
30    type Output = Self;
31
32    fn add(self, rhs: Self) -> Self::Output {
33        Self::new(self.val + rhs.val)
34    }
35}
36
37impl<P> core::ops::Sub<Self> for Line<P>
38where
39    P: CubePrimitive,
40    P: core::ops::Sub<P, Output = P>,
41{
42    type Output = Self;
43
44    fn sub(self, rhs: Self) -> Self::Output {
45        Self::new(self.val - rhs.val)
46    }
47}
48
49impl<P> core::ops::Mul<Self> for Line<P>
50where
51    P: CubePrimitive,
52    P: core::ops::Mul<P, Output = P>,
53{
54    type Output = Self;
55
56    fn mul(self, rhs: Self) -> Self::Output {
57        Self::new(self.val * rhs.val)
58    }
59}
60
61impl<P> core::ops::Div<Self> for Line<P>
62where
63    P: CubePrimitive,
64    P: core::ops::Div<P, Output = P>,
65{
66    type Output = Self;
67
68    fn div(self, rhs: Self) -> Self::Output {
69        Self::new(self.val / rhs.val)
70    }
71}
72
73impl<P> core::ops::AddAssign<Self> for Line<P>
74where
75    P: CubePrimitive,
76    P: core::ops::AddAssign,
77{
78    fn add_assign(&mut self, rhs: Self) {
79        self.val += rhs.val;
80    }
81}
82
83impl<P> core::ops::SubAssign<Self> for Line<P>
84where
85    P: CubePrimitive,
86    P: core::ops::SubAssign,
87{
88    fn sub_assign(&mut self, rhs: Self) {
89        self.val -= rhs.val;
90    }
91}
92
93impl<P> core::ops::DivAssign<Self> for Line<P>
94where
95    P: CubePrimitive,
96    P: core::ops::DivAssign,
97{
98    fn div_assign(&mut self, rhs: Self) {
99        self.val /= rhs.val;
100    }
101}
102
103impl<P> core::ops::MulAssign<Self> for Line<P>
104where
105    P: CubePrimitive,
106    P: core::ops::MulAssign,
107{
108    fn mul_assign(&mut self, rhs: Self) {
109        self.val *= rhs.val;
110    }
111}
112
113impl<P> core::cmp::PartialEq for Line<P>
114where
115    P: CubePrimitive,
116    P: core::cmp::PartialEq,
117{
118    fn eq(&self, other: &Self) -> bool {
119        self.val.eq(&other.val)
120    }
121}
122
123impl<P> core::cmp::PartialOrd for Line<P>
124where
125    P: CubePrimitive,
126    P: core::cmp::PartialOrd,
127{
128    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
129        self.val.partial_cmp(&other.val)
130    }
131}
132
133impl<P> core::ops::BitAnd<Self> for Line<P>
134where
135    P: CubePrimitive,
136    P: core::ops::BitAnd<P, Output = P>,
137{
138    type Output = Self;
139
140    fn bitand(self, rhs: Self) -> Self::Output {
141        Self::new(self.val & rhs.val)
142    }
143}
144
145impl<P> core::ops::BitOr<Self> for Line<P>
146where
147    P: CubePrimitive,
148    P: core::ops::BitOr<P, Output = P>,
149{
150    type Output = Self;
151
152    fn bitor(self, rhs: Self) -> Self::Output {
153        Self::new(self.val | rhs.val)
154    }
155}
156
157impl<P> core::ops::BitXor<Self> for Line<P>
158where
159    P: CubePrimitive,
160    P: core::ops::BitXor<P, Output = P>,
161{
162    type Output = Self;
163
164    fn bitxor(self, rhs: Self) -> Self::Output {
165        Self::new(self.val ^ rhs.val)
166    }
167}
168
169impl<P> core::ops::Shl<Self> for Line<P>
170where
171    P: CubePrimitive,
172    P: core::ops::Shl<P, Output = P>,
173{
174    type Output = Self;
175
176    fn shl(self, rhs: Self) -> Self::Output {
177        Self::new(self.val << rhs.val)
178    }
179}
180
181impl<P> core::ops::Shr<Self> for Line<P>
182where
183    P: CubePrimitive,
184    P: core::ops::Shr<P, Output = P>,
185{
186    type Output = Self;
187
188    fn shr(self, rhs: Self) -> Self::Output {
189        Self::new(self.val >> rhs.val)
190    }
191}
192
193impl<P> core::ops::BitAndAssign<Self> for Line<P>
194where
195    P: CubePrimitive,
196    P: core::ops::BitAndAssign,
197{
198    fn bitand_assign(&mut self, rhs: Self) {
199        self.val &= rhs.val;
200    }
201}
202
203impl<P> core::ops::BitOrAssign<Self> for Line<P>
204where
205    P: CubePrimitive,
206    P: core::ops::BitOrAssign,
207{
208    fn bitor_assign(&mut self, rhs: Self) {
209        self.val |= rhs.val;
210    }
211}
212
213impl<P> core::ops::BitXorAssign<Self> for Line<P>
214where
215    P: CubePrimitive,
216    P: core::ops::BitXorAssign,
217{
218    fn bitxor_assign(&mut self, rhs: Self) {
219        self.val ^= rhs.val;
220    }
221}
222
223impl<P> core::ops::ShlAssign<Self> for Line<P>
224where
225    P: CubePrimitive,
226    P: core::ops::ShlAssign,
227{
228    fn shl_assign(&mut self, rhs: Self) {
229        self.val <<= rhs.val;
230    }
231}
232
233impl<P> core::ops::ShrAssign<Self> for Line<P>
234where
235    P: CubePrimitive,
236    P: core::ops::ShrAssign,
237{
238    fn shr_assign(&mut self, rhs: Self) {
239        self.val >>= rhs.val;
240    }
241}
242
243impl<P: CubePrimitive + Abs> Abs for Line<P> {}
244impl<P: CubePrimitive + Log> Log for Line<P> {}
245impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
246impl<P: CubePrimitive + Erf> Erf for Line<P> {}
247impl<P: CubePrimitive + Exp> Exp for Line<P> {}
248impl<P: CubePrimitive + Powf> Powf for Line<P> {}
249impl<P: CubePrimitive + Powi<I>, I: CubePrimitive> Powi<Line<I>> for Line<P> {}
250impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
251impl<P: CubePrimitive + InverseSqrt> InverseSqrt for Line<P> {}
252impl<P: CubePrimitive + Cos> Cos for Line<P> {}
253impl<P: CubePrimitive + Sin> Sin for Line<P> {}
254impl<P: CubePrimitive + Tan> Tan for Line<P> {}
255impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
256impl<P: CubePrimitive + Sinh> Sinh for Line<P> {}
257impl<P: CubePrimitive + Cosh> Cosh for Line<P> {}
258impl<P: CubePrimitive + ArcSin> ArcSin for Line<P> {}
259impl<P: CubePrimitive + ArcCos> ArcCos for Line<P> {}
260impl<P: CubePrimitive + ArcTan> ArcTan for Line<P> {}
261impl<P: CubePrimitive + ArcSinh> ArcSinh for Line<P> {}
262impl<P: CubePrimitive + ArcCosh> ArcCosh for Line<P> {}
263impl<P: CubePrimitive + ArcTanh> ArcTanh for Line<P> {}
264impl<P: CubePrimitive + ArcTan2> ArcTan2 for Line<P> {}
265impl<P: CubePrimitive + Recip> Recip for Line<P> {}
266impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
267impl<P: CubePrimitive + Round> Round for Line<P> {}
268impl<P: CubePrimitive + Floor> Floor for Line<P> {}
269impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
270impl<P: CubePrimitive + Trunc> Trunc for Line<P> {}
271impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
272impl<P: CubePrimitive + CubeNot> CubeNot for Line<P> {}
273impl<P: CubePrimitive + SaturatingAdd> SaturatingAdd for Line<P> {}
274impl<P: CubePrimitive + SaturatingSub> SaturatingSub for Line<P> {}
275impl<P: CubePrimitive + IsNan> IsNan for Line<P> {}
276impl<P: CubePrimitive + IsInf> IsInf for Line<P> {}
277
278impl<P: CubePrimitive + Ord> Ord for Line<P> {
279    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
280        self.val.cmp(&other.val)
281    }
282}
283
284#[cube]
285impl<P: CountOnes + CubePrimitive> Line<P> {
286    pub fn count_ones(self) -> Line<u32> {
287        intrinsic!(|scope| {
288            let out_item =
289                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
290            let out = scope.create_local(out_item);
291            scope.register(Instruction::new(
292                Bitwise::CountOnes(UnaryOperator {
293                    input: *self.expand,
294                }),
295                *out,
296            ));
297            out.into()
298        })
299    }
300}
301
302#[cube]
303impl<P: LeadingZeros + CubePrimitive> Line<P> {
304    pub fn leading_zeros(self) -> Line<u32> {
305        intrinsic!(|scope| {
306            let out_item =
307                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
308            let out = scope.create_local(out_item);
309            scope.register(Instruction::new(
310                Bitwise::LeadingZeros(UnaryOperator {
311                    input: *self.expand,
312                }),
313                *out,
314            ));
315            out.into()
316        })
317    }
318}
319
320#[cube]
321impl<P: FindFirstSet + CubePrimitive> Line<P> {
322    pub fn find_first_set(self) -> Line<u32> {
323        intrinsic!(|scope| {
324            let out_item =
325                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
326            let out = scope.create_local(out_item);
327            scope.register(Instruction::new(
328                Bitwise::FindFirstSet(UnaryOperator {
329                    input: *self.expand,
330                }),
331                *out,
332            ));
333            out.into()
334        })
335    }
336}
337
338impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
339    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
340        let val: P = NumCast::from(n)?;
341        Some(Self { val })
342    }
343}
344impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
345    fn to_i64(&self) -> Option<i64> {
346        self.val.to_i64()
347    }
348
349    fn to_u64(&self) -> Option<u64> {
350        self.val.to_u64()
351    }
352}
353
354impl<P: Not<Output = P> + CubePrimitive> Not for Line<P> {
355    type Output = Self;
356
357    fn not(self) -> Self::Output {
358        Line::new(self.val.not())
359    }
360}
361
362#[allow(clippy::from_over_into)]
363impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
364    fn into(self) -> ExpandElementTyped<Self> {
365        let elem: ExpandElementTyped<P> = self.val.into();
366        elem.expand.into()
367    }
368}
369
370macro_rules! operation_literal {
371    ($lit:ty) => {
372        impl<P> core::ops::Add<$lit> for Line<P>
373        where
374            P: CubePrimitive,
375        {
376            type Output = Self;
377
378            fn add(self, _rhs: $lit) -> Self::Output {
379                unexpanded!();
380            }
381        }
382
383        impl<P> core::ops::Sub<$lit> for Line<P>
384        where
385            P: CubePrimitive,
386        {
387            type Output = Self;
388
389            fn sub(self, _rhs: $lit) -> Self::Output {
390                unexpanded!();
391            }
392        }
393
394        impl<P> core::ops::Mul<$lit> for Line<P>
395        where
396            P: CubePrimitive,
397        {
398            type Output = Self;
399
400            fn mul(self, _rhs: $lit) -> Self::Output {
401                unexpanded!();
402            }
403        }
404
405        impl<P> core::ops::Div<$lit> for Line<P>
406        where
407            P: CubePrimitive,
408        {
409            type Output = Self;
410
411            fn div(self, _rhs: $lit) -> Self::Output {
412                unexpanded!();
413            }
414        }
415    };
416}
417
418operation_literal!(f32);
419operation_literal!(f64);
420operation_literal!(usize);
421operation_literal!(i32);
422operation_literal!(i64);