// cubecl_core/frontend/container/line/ops.rs

1use cubecl_ir::{Bitwise, ElemType, Instruction, Type, UIntKind, UnaryOperator};
2use cubecl_macros::{cube, intrinsic};
3use num_traits::{NumCast, ToPrimitive};
4
5use crate::{
6    self as cubecl,
7    prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub, Trunc},
8};
9use crate::{
10    frontend::{
11        Abs, Ceil, Clamp, Cos, CubePrimitive, Erf, Exp, ExpandElementTyped, Floor, Log, Log1p, Max,
12        Min, Powf, Recip, Remainder, Round, Sin, Sqrt, Tanh,
13    },
14    prelude::{BitwiseNot, CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
15    unexpanded,
16};
17
18use super::Line;
19type LineExpand<E> = ExpandElementTyped<Line<E>>;
20
21impl<P> core::ops::Add<Self> for Line<P>
22where
23    P: CubePrimitive,
24    P: core::ops::Add<P, Output = P>,
25{
26    type Output = Self;
27
28    fn add(self, rhs: Self) -> Self::Output {
29        Self::new(self.val + rhs.val)
30    }
31}
32
33impl<P> core::ops::Sub<Self> for Line<P>
34where
35    P: CubePrimitive,
36    P: core::ops::Sub<P, Output = P>,
37{
38    type Output = Self;
39
40    fn sub(self, rhs: Self) -> Self::Output {
41        Self::new(self.val - rhs.val)
42    }
43}
44
45impl<P> core::ops::Mul<Self> for Line<P>
46where
47    P: CubePrimitive,
48    P: core::ops::Mul<P, Output = P>,
49{
50    type Output = Self;
51
52    fn mul(self, rhs: Self) -> Self::Output {
53        Self::new(self.val * rhs.val)
54    }
55}
56
57impl<P> core::ops::Div<Self> for Line<P>
58where
59    P: CubePrimitive,
60    P: core::ops::Div<P, Output = P>,
61{
62    type Output = Self;
63
64    fn div(self, rhs: Self) -> Self::Output {
65        Self::new(self.val / rhs.val)
66    }
67}
68
69impl<P> core::ops::AddAssign<Self> for Line<P>
70where
71    P: CubePrimitive,
72    P: core::ops::AddAssign,
73{
74    fn add_assign(&mut self, rhs: Self) {
75        self.val += rhs.val;
76    }
77}
78
79impl<P> core::ops::SubAssign<Self> for Line<P>
80where
81    P: CubePrimitive,
82    P: core::ops::SubAssign,
83{
84    fn sub_assign(&mut self, rhs: Self) {
85        self.val -= rhs.val;
86    }
87}
88
89impl<P> core::ops::DivAssign<Self> for Line<P>
90where
91    P: CubePrimitive,
92    P: core::ops::DivAssign,
93{
94    fn div_assign(&mut self, rhs: Self) {
95        self.val /= rhs.val;
96    }
97}
98
99impl<P> core::ops::MulAssign<Self> for Line<P>
100where
101    P: CubePrimitive,
102    P: core::ops::MulAssign,
103{
104    fn mul_assign(&mut self, rhs: Self) {
105        self.val *= rhs.val;
106    }
107}
108
109impl<P> core::cmp::PartialEq for Line<P>
110where
111    P: CubePrimitive,
112    P: core::cmp::PartialEq,
113{
114    fn eq(&self, other: &Self) -> bool {
115        self.val.eq(&other.val)
116    }
117}
118
119impl<P> core::cmp::PartialOrd for Line<P>
120where
121    P: CubePrimitive,
122    P: core::cmp::PartialOrd,
123{
124    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
125        self.val.partial_cmp(&other.val)
126    }
127}
128
129impl<P> core::ops::BitAnd<Self> for Line<P>
130where
131    P: CubePrimitive,
132    P: core::ops::BitAnd<P, Output = P>,
133{
134    type Output = Self;
135
136    fn bitand(self, rhs: Self) -> Self::Output {
137        Self::new(self.val & rhs.val)
138    }
139}
140
141impl<P> core::ops::BitOr<Self> for Line<P>
142where
143    P: CubePrimitive,
144    P: core::ops::BitOr<P, Output = P>,
145{
146    type Output = Self;
147
148    fn bitor(self, rhs: Self) -> Self::Output {
149        Self::new(self.val | rhs.val)
150    }
151}
152
153impl<P> core::ops::BitXor<Self> for Line<P>
154where
155    P: CubePrimitive,
156    P: core::ops::BitXor<P, Output = P>,
157{
158    type Output = Self;
159
160    fn bitxor(self, rhs: Self) -> Self::Output {
161        Self::new(self.val ^ rhs.val)
162    }
163}
164
165impl<P> core::ops::Shl<Self> for Line<P>
166where
167    P: CubePrimitive,
168    P: core::ops::Shl<P, Output = P>,
169{
170    type Output = Self;
171
172    fn shl(self, rhs: Self) -> Self::Output {
173        Self::new(self.val << rhs.val)
174    }
175}
176
177impl<P> core::ops::Shr<Self> for Line<P>
178where
179    P: CubePrimitive,
180    P: core::ops::Shr<P, Output = P>,
181{
182    type Output = Self;
183
184    fn shr(self, rhs: Self) -> Self::Output {
185        Self::new(self.val >> rhs.val)
186    }
187}
188
189impl<P> core::ops::BitAndAssign<Self> for Line<P>
190where
191    P: CubePrimitive,
192    P: core::ops::BitAndAssign,
193{
194    fn bitand_assign(&mut self, rhs: Self) {
195        self.val &= rhs.val;
196    }
197}
198
199impl<P> core::ops::BitOrAssign<Self> for Line<P>
200where
201    P: CubePrimitive,
202    P: core::ops::BitOrAssign,
203{
204    fn bitor_assign(&mut self, rhs: Self) {
205        self.val |= rhs.val;
206    }
207}
208
209impl<P> core::ops::BitXorAssign<Self> for Line<P>
210where
211    P: CubePrimitive,
212    P: core::ops::BitXorAssign,
213{
214    fn bitxor_assign(&mut self, rhs: Self) {
215        self.val ^= rhs.val;
216    }
217}
218
219impl<P> core::ops::ShlAssign<Self> for Line<P>
220where
221    P: CubePrimitive,
222    P: core::ops::ShlAssign,
223{
224    fn shl_assign(&mut self, rhs: Self) {
225        self.val <<= rhs.val;
226    }
227}
228
229impl<P> core::ops::ShrAssign<Self> for Line<P>
230where
231    P: CubePrimitive,
232    P: core::ops::ShrAssign,
233{
234    fn shr_assign(&mut self, rhs: Self) {
235        self.val >>= rhs.val;
236    }
237}
238
239impl<P: CubePrimitive + Abs> Abs for Line<P> {}
240impl<P: CubePrimitive + Max> Max for Line<P> {}
241impl<P: CubePrimitive + Min> Min for Line<P> {}
242impl<P: CubePrimitive + Clamp> Clamp for Line<P> {}
243impl<P: CubePrimitive + Log> Log for Line<P> {}
244impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
245impl<P: CubePrimitive + Erf> Erf for Line<P> {}
246impl<P: CubePrimitive + Exp> Exp for Line<P> {}
247impl<P: CubePrimitive + Powf> Powf for Line<P> {}
248impl<P: CubePrimitive + Powi<I>, I: CubePrimitive> Powi<Line<I>> for Line<P> {}
249impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
250impl<P: CubePrimitive + Cos> Cos for Line<P> {}
251impl<P: CubePrimitive + Sin> Sin for Line<P> {}
252impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
253impl<P: CubePrimitive + Recip> Recip for Line<P> {}
254impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
255impl<P: CubePrimitive + Round> Round for Line<P> {}
256impl<P: CubePrimitive + Floor> Floor for Line<P> {}
257impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
258impl<P: CubePrimitive + Trunc> Trunc for Line<P> {}
259impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
260impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
261impl<P: CubePrimitive + SaturatingAdd> SaturatingAdd for Line<P> {}
262impl<P: CubePrimitive + SaturatingSub> SaturatingSub for Line<P> {}
263impl<P: CubePrimitive + IsNan> IsNan for Line<P> {}
264impl<P: CubePrimitive + IsInf> IsInf for Line<P> {}
265
266#[cube]
267impl<P: CountOnes> Line<P> {
268    pub fn count_ones(self) -> Line<u32> {
269        intrinsic!(|scope| {
270            let out_item =
271                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
272            let out = scope.create_local(out_item);
273            scope.register(Instruction::new(
274                Bitwise::CountOnes(UnaryOperator {
275                    input: *self.expand,
276                }),
277                *out,
278            ));
279            out.into()
280        })
281    }
282}
283
284#[cube]
285impl<P: LeadingZeros> Line<P> {
286    pub fn leading_zeros(self) -> Line<u32> {
287        intrinsic!(|scope| {
288            let out_item =
289                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
290            let out = scope.create_local(out_item);
291            scope.register(Instruction::new(
292                Bitwise::LeadingZeros(UnaryOperator {
293                    input: *self.expand,
294                }),
295                *out,
296            ));
297            out.into()
298        })
299    }
300}
301
302#[cube]
303impl<P: FindFirstSet> Line<P> {
304    pub fn find_first_set(self) -> Line<u32> {
305        intrinsic!(|scope| {
306            let out_item =
307                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
308            let out = scope.create_local(out_item);
309            scope.register(Instruction::new(
310                Bitwise::FindFirstSet(UnaryOperator {
311                    input: *self.expand,
312                }),
313                *out,
314            ));
315            out.into()
316        })
317    }
318}
319
320impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
321    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
322        let val: P = NumCast::from(n)?;
323        Some(Self { val })
324    }
325}
326impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
327    fn to_i64(&self) -> Option<i64> {
328        self.val.to_i64()
329    }
330
331    fn to_u64(&self) -> Option<u64> {
332        self.val.to_u64()
333    }
334}
335
336#[allow(clippy::from_over_into)]
337impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
338    fn into(self) -> ExpandElementTyped<Self> {
339        let elem: ExpandElementTyped<P> = self.val.into();
340        elem.expand.into()
341    }
342}
343
344macro_rules! operation_literal {
345    ($lit:ty) => {
346        impl<P> core::ops::Add<$lit> for Line<P>
347        where
348            P: CubePrimitive,
349        {
350            type Output = Self;
351
352            fn add(self, _rhs: $lit) -> Self::Output {
353                unexpanded!();
354            }
355        }
356
357        impl<P> core::ops::Sub<$lit> for Line<P>
358        where
359            P: CubePrimitive,
360        {
361            type Output = Self;
362
363            fn sub(self, _rhs: $lit) -> Self::Output {
364                unexpanded!();
365            }
366        }
367
368        impl<P> core::ops::Mul<$lit> for Line<P>
369        where
370            P: CubePrimitive,
371        {
372            type Output = Self;
373
374            fn mul(self, _rhs: $lit) -> Self::Output {
375                unexpanded!();
376            }
377        }
378
379        impl<P> core::ops::Div<$lit> for Line<P>
380        where
381            P: CubePrimitive,
382        {
383            type Output = Self;
384
385            fn div(self, _rhs: $lit) -> Self::Output {
386                unexpanded!();
387            }
388        }
389    };
390}
391
392operation_literal!(f32);
393operation_literal!(f64);
394operation_literal!(usize);
395operation_literal!(i32);
396operation_literal!(i64);