cubecl_core/frontend/container/line/ops.rs

1use cubecl_ir::{Bitwise, Elem, Instruction, UIntKind, UnaryOperator};
2use cubecl_macros::{cube, intrinsic};
3use num_traits::{NumCast, ToPrimitive};
4
5use crate as cubecl;
6use crate::{
7    frontend::{
8        Abs, Ceil, Clamp, Cos, CubePrimitive, Erf, Exp, ExpandElementTyped, Floor, Log, Log1p, Max,
9        Min, Powf, Recip, Remainder, Round, Sin, Sqrt, Tanh,
10    },
11    prelude::{BitwiseNot, CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
12    unexpanded,
13};
14
15use super::Line;
// Shorthand for `ExpandElementTyped<Line<E>>`, the expand-time typed element for a `Line<E>`.
// NOTE(review): not referenced in the visible part of this file; confirm it is still used.
type LineExpand<E> = ExpandElementTyped<Line<E>>;
17
18impl<P> core::ops::Add<Self> for Line<P>
19where
20    P: CubePrimitive,
21    P: core::ops::Add<P, Output = P>,
22{
23    type Output = Self;
24
25    fn add(self, rhs: Self) -> Self::Output {
26        Self::new(self.val + rhs.val)
27    }
28}
29
30impl<P> core::ops::Sub<Self> for Line<P>
31where
32    P: CubePrimitive,
33    P: core::ops::Sub<P, Output = P>,
34{
35    type Output = Self;
36
37    fn sub(self, rhs: Self) -> Self::Output {
38        Self::new(self.val - rhs.val)
39    }
40}
41
42impl<P> core::ops::Mul<Self> for Line<P>
43where
44    P: CubePrimitive,
45    P: core::ops::Mul<P, Output = P>,
46{
47    type Output = Self;
48
49    fn mul(self, rhs: Self) -> Self::Output {
50        Self::new(self.val * rhs.val)
51    }
52}
53
54impl<P> core::ops::Div<Self> for Line<P>
55where
56    P: CubePrimitive,
57    P: core::ops::Div<P, Output = P>,
58{
59    type Output = Self;
60
61    fn div(self, rhs: Self) -> Self::Output {
62        Self::new(self.val / rhs.val)
63    }
64}
65
66impl<P> core::ops::AddAssign<Self> for Line<P>
67where
68    P: CubePrimitive,
69    P: core::ops::AddAssign,
70{
71    fn add_assign(&mut self, rhs: Self) {
72        self.val += rhs.val;
73    }
74}
75
76impl<P> core::ops::SubAssign<Self> for Line<P>
77where
78    P: CubePrimitive,
79    P: core::ops::SubAssign,
80{
81    fn sub_assign(&mut self, rhs: Self) {
82        self.val -= rhs.val;
83    }
84}
85
86impl<P> core::ops::DivAssign<Self> for Line<P>
87where
88    P: CubePrimitive,
89    P: core::ops::DivAssign,
90{
91    fn div_assign(&mut self, rhs: Self) {
92        self.val /= rhs.val;
93    }
94}
95
96impl<P> core::ops::MulAssign<Self> for Line<P>
97where
98    P: CubePrimitive,
99    P: core::ops::MulAssign,
100{
101    fn mul_assign(&mut self, rhs: Self) {
102        self.val *= rhs.val;
103    }
104}
105
106impl<P> core::cmp::PartialEq for Line<P>
107where
108    P: CubePrimitive,
109    P: core::cmp::PartialEq,
110{
111    fn eq(&self, other: &Self) -> bool {
112        self.val.eq(&other.val)
113    }
114}
115
116impl<P> core::cmp::PartialOrd for Line<P>
117where
118    P: CubePrimitive,
119    P: core::cmp::PartialOrd,
120{
121    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122        self.val.partial_cmp(&other.val)
123    }
124}
125
126impl<P> core::ops::BitAnd<Self> for Line<P>
127where
128    P: CubePrimitive,
129    P: core::ops::BitAnd<P, Output = P>,
130{
131    type Output = Self;
132
133    fn bitand(self, rhs: Self) -> Self::Output {
134        Self::new(self.val & rhs.val)
135    }
136}
137
138impl<P> core::ops::BitOr<Self> for Line<P>
139where
140    P: CubePrimitive,
141    P: core::ops::BitOr<P, Output = P>,
142{
143    type Output = Self;
144
145    fn bitor(self, rhs: Self) -> Self::Output {
146        Self::new(self.val | rhs.val)
147    }
148}
149
150impl<P> core::ops::BitXor<Self> for Line<P>
151where
152    P: CubePrimitive,
153    P: core::ops::BitXor<P, Output = P>,
154{
155    type Output = Self;
156
157    fn bitxor(self, rhs: Self) -> Self::Output {
158        Self::new(self.val ^ rhs.val)
159    }
160}
161
162impl<P> core::ops::Shl<Self> for Line<P>
163where
164    P: CubePrimitive,
165    P: core::ops::Shl<P, Output = P>,
166{
167    type Output = Self;
168
169    fn shl(self, rhs: Self) -> Self::Output {
170        Self::new(self.val << rhs.val)
171    }
172}
173
174impl<P> core::ops::Shr<Self> for Line<P>
175where
176    P: CubePrimitive,
177    P: core::ops::Shr<P, Output = P>,
178{
179    type Output = Self;
180
181    fn shr(self, rhs: Self) -> Self::Output {
182        Self::new(self.val >> rhs.val)
183    }
184}
185
186impl<P> core::ops::BitAndAssign<Self> for Line<P>
187where
188    P: CubePrimitive,
189    P: core::ops::BitAndAssign,
190{
191    fn bitand_assign(&mut self, rhs: Self) {
192        self.val &= rhs.val;
193    }
194}
195
196impl<P> core::ops::BitOrAssign<Self> for Line<P>
197where
198    P: CubePrimitive,
199    P: core::ops::BitOrAssign,
200{
201    fn bitor_assign(&mut self, rhs: Self) {
202        self.val |= rhs.val;
203    }
204}
205
206impl<P> core::ops::BitXorAssign<Self> for Line<P>
207where
208    P: CubePrimitive,
209    P: core::ops::BitXorAssign,
210{
211    fn bitxor_assign(&mut self, rhs: Self) {
212        self.val ^= rhs.val;
213    }
214}
215
216impl<P> core::ops::ShlAssign<Self> for Line<P>
217where
218    P: CubePrimitive,
219    P: core::ops::ShlAssign,
220{
221    fn shl_assign(&mut self, rhs: Self) {
222        self.val <<= rhs.val;
223    }
224}
225
226impl<P> core::ops::ShrAssign<Self> for Line<P>
227where
228    P: CubePrimitive,
229    P: core::ops::ShrAssign,
230{
231    fn shr_assign(&mut self, rhs: Self) {
232        self.val >>= rhs.val;
233    }
234}
235
236impl<P: CubePrimitive + Abs> Abs for Line<P> {}
237impl<P: CubePrimitive + Max> Max for Line<P> {}
238impl<P: CubePrimitive + Min> Min for Line<P> {}
239impl<P: CubePrimitive + Clamp> Clamp for Line<P> {}
240impl<P: CubePrimitive + Log> Log for Line<P> {}
241impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
242impl<P: CubePrimitive + Erf> Erf for Line<P> {}
243impl<P: CubePrimitive + Exp> Exp for Line<P> {}
244impl<P: CubePrimitive + Powf> Powf for Line<P> {}
245impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
246impl<P: CubePrimitive + Cos> Cos for Line<P> {}
247impl<P: CubePrimitive + Sin> Sin for Line<P> {}
248impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
249impl<P: CubePrimitive + Recip> Recip for Line<P> {}
250impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
251impl<P: CubePrimitive + Round> Round for Line<P> {}
252impl<P: CubePrimitive + Floor> Floor for Line<P> {}
253impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
254impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
255impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
256
#[cube]
impl<P: CountOnes> Line<P> {
    /// Emits a `Bitwise::CountOnes` IR instruction over this line and returns
    /// the result as a `Line<u32>`.
    ///
    /// The output item is a copy of the input's item with only its element type
    /// replaced by `u32`.
    pub fn count_ones(self) -> Line<u32> {
        intrinsic!(|scope| {
            // Keep everything about the input item except the element type.
            let mut out_item = self.expand.item;
            out_item.elem = Elem::UInt(UIntKind::U32);
            let out = scope.create_local(out_item);
            // Single unary bitwise instruction: input line -> freshly created local `out`.
            scope.register(Instruction::new(
                Bitwise::CountOnes(UnaryOperator {
                    input: *self.expand,
                }),
                *out,
            ));
            out.into()
        })
    }
}
274
#[cube]
impl<P: LeadingZeros> Line<P> {
    /// Emits a `Bitwise::LeadingZeros` IR instruction over this line and
    /// returns the result as a `Line<u32>`.
    ///
    /// The output item is a copy of the input's item with only its element type
    /// replaced by `u32`.
    pub fn leading_zeros(self) -> Line<u32> {
        intrinsic!(|scope| {
            // Keep everything about the input item except the element type.
            let mut out_item = self.expand.item;
            out_item.elem = Elem::UInt(UIntKind::U32);
            let out = scope.create_local(out_item);
            // Single unary bitwise instruction: input line -> freshly created local `out`.
            scope.register(Instruction::new(
                Bitwise::LeadingZeros(UnaryOperator {
                    input: *self.expand,
                }),
                *out,
            ));
            out.into()
        })
    }
}
292
#[cube]
impl<P: FindFirstSet> Line<P> {
    /// Emits a `Bitwise::FindFirstSet` IR instruction over this line and
    /// returns the result as a `Line<u32>`.
    ///
    /// The output item is a copy of the input's item with only its element type
    /// replaced by `u32`. NOTE(review): the index convention (0- vs 1-based,
    /// value when no bit is set) is defined by the backend lowering, not here.
    pub fn find_first_set(self) -> Line<u32> {
        intrinsic!(|scope| {
            // Keep everything about the input item except the element type.
            let mut out_item = self.expand.item;
            out_item.elem = Elem::UInt(UIntKind::U32);
            let out = scope.create_local(out_item);
            // Single unary bitwise instruction: input line -> freshly created local `out`.
            scope.register(Instruction::new(
                Bitwise::FindFirstSet(UnaryOperator {
                    input: *self.expand,
                }),
                *out,
            ));
            out.into()
        })
    }
}
310
311impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
312    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
313        let val: P = NumCast::from(n)?;
314        Some(Self { val })
315    }
316}
317impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
318    fn to_i64(&self) -> Option<i64> {
319        self.val.to_i64()
320    }
321
322    fn to_u64(&self) -> Option<u64> {
323        self.val.to_u64()
324    }
325}
326
327#[allow(clippy::from_over_into)]
328impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
329    fn into(self) -> ExpandElementTyped<Self> {
330        let elem: ExpandElementTyped<P> = self.val.into();
331        elem.expand.into()
332    }
333}
334
// `operation_literal!(T)` implements `Add<T>`, `Sub<T>`, `Mul<T>` and `Div<T>` for
// `Line<P>`, so that expressions mixing a line with a literal of type `T`
// type-check. Every generated body is `unexpanded!()`. NOTE(review): presumably
// these are only reached through `#[cube]` expansion and never executed directly.
macro_rules! operation_literal {
    // Internal rule: one impl per `Trait::method` pair, all with RHS type `$lit`.
    (@impl $lit:ty; $($Trait:ident :: $method:ident),+) => {
        $(
            impl<P> core::ops::$Trait<$lit> for Line<P>
            where
                P: CubePrimitive,
            {
                type Output = Self;

                fn $method(self, _rhs: $lit) -> Self::Output {
                    unexpanded!();
                }
            }
        )+
    };
    // Public entry point.
    ($lit:ty) => {
        operation_literal!(@impl $lit; Add::add, Sub::sub, Mul::mul, Div::div);
    };
}
382
383operation_literal!(f32);
384operation_literal!(f64);
385operation_literal!(usize);
386operation_literal!(i32);
387operation_literal!(i64);