// cubecl_core/frontend/container/line/ops.rs

1use cubecl_ir::{Bitwise, ElemType, Instruction, Type, UIntKind, UnaryOperator};
2use cubecl_macros::{cube, intrinsic};
3use num_traits::{NumCast, ToPrimitive};
4
5use crate::{
6    self as cubecl,
7    prelude::{InverseSqrt, IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub, Trunc},
8};
9use crate::{
10    frontend::{
11        Abs, Ceil, Clamp, Cos, CubePrimitive, Erf, Exp, ExpandElementTyped, Floor, Log, Log1p, Max,
12        Min, Powf, Recip, Remainder, Round, Sin, Sqrt, Tanh,
13    },
14    prelude::{BitwiseNot, CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
15    unexpanded,
16};
17
18use super::Line;
19type LineExpand<E> = ExpandElementTyped<Line<E>>;
20
21impl<P> core::ops::Add<Self> for Line<P>
22where
23    P: CubePrimitive,
24    P: core::ops::Add<P, Output = P>,
25{
26    type Output = Self;
27
28    fn add(self, rhs: Self) -> Self::Output {
29        Self::new(self.val + rhs.val)
30    }
31}
32
33impl<P> core::ops::Sub<Self> for Line<P>
34where
35    P: CubePrimitive,
36    P: core::ops::Sub<P, Output = P>,
37{
38    type Output = Self;
39
40    fn sub(self, rhs: Self) -> Self::Output {
41        Self::new(self.val - rhs.val)
42    }
43}
44
45impl<P> core::ops::Mul<Self> for Line<P>
46where
47    P: CubePrimitive,
48    P: core::ops::Mul<P, Output = P>,
49{
50    type Output = Self;
51
52    fn mul(self, rhs: Self) -> Self::Output {
53        Self::new(self.val * rhs.val)
54    }
55}
56
57impl<P> core::ops::Div<Self> for Line<P>
58where
59    P: CubePrimitive,
60    P: core::ops::Div<P, Output = P>,
61{
62    type Output = Self;
63
64    fn div(self, rhs: Self) -> Self::Output {
65        Self::new(self.val / rhs.val)
66    }
67}
68
69impl<P> core::ops::AddAssign<Self> for Line<P>
70where
71    P: CubePrimitive,
72    P: core::ops::AddAssign,
73{
74    fn add_assign(&mut self, rhs: Self) {
75        self.val += rhs.val;
76    }
77}
78
79impl<P> core::ops::SubAssign<Self> for Line<P>
80where
81    P: CubePrimitive,
82    P: core::ops::SubAssign,
83{
84    fn sub_assign(&mut self, rhs: Self) {
85        self.val -= rhs.val;
86    }
87}
88
89impl<P> core::ops::DivAssign<Self> for Line<P>
90where
91    P: CubePrimitive,
92    P: core::ops::DivAssign,
93{
94    fn div_assign(&mut self, rhs: Self) {
95        self.val /= rhs.val;
96    }
97}
98
99impl<P> core::ops::MulAssign<Self> for Line<P>
100where
101    P: CubePrimitive,
102    P: core::ops::MulAssign,
103{
104    fn mul_assign(&mut self, rhs: Self) {
105        self.val *= rhs.val;
106    }
107}
108
109impl<P> core::cmp::PartialEq for Line<P>
110where
111    P: CubePrimitive,
112    P: core::cmp::PartialEq,
113{
114    fn eq(&self, other: &Self) -> bool {
115        self.val.eq(&other.val)
116    }
117}
118
119impl<P> core::cmp::PartialOrd for Line<P>
120where
121    P: CubePrimitive,
122    P: core::cmp::PartialOrd,
123{
124    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
125        self.val.partial_cmp(&other.val)
126    }
127}
128
129impl<P> core::ops::BitAnd<Self> for Line<P>
130where
131    P: CubePrimitive,
132    P: core::ops::BitAnd<P, Output = P>,
133{
134    type Output = Self;
135
136    fn bitand(self, rhs: Self) -> Self::Output {
137        Self::new(self.val & rhs.val)
138    }
139}
140
141impl<P> core::ops::BitOr<Self> for Line<P>
142where
143    P: CubePrimitive,
144    P: core::ops::BitOr<P, Output = P>,
145{
146    type Output = Self;
147
148    fn bitor(self, rhs: Self) -> Self::Output {
149        Self::new(self.val | rhs.val)
150    }
151}
152
153impl<P> core::ops::BitXor<Self> for Line<P>
154where
155    P: CubePrimitive,
156    P: core::ops::BitXor<P, Output = P>,
157{
158    type Output = Self;
159
160    fn bitxor(self, rhs: Self) -> Self::Output {
161        Self::new(self.val ^ rhs.val)
162    }
163}
164
165impl<P> core::ops::Shl<Self> for Line<P>
166where
167    P: CubePrimitive,
168    P: core::ops::Shl<P, Output = P>,
169{
170    type Output = Self;
171
172    fn shl(self, rhs: Self) -> Self::Output {
173        Self::new(self.val << rhs.val)
174    }
175}
176
177impl<P> core::ops::Shr<Self> for Line<P>
178where
179    P: CubePrimitive,
180    P: core::ops::Shr<P, Output = P>,
181{
182    type Output = Self;
183
184    fn shr(self, rhs: Self) -> Self::Output {
185        Self::new(self.val >> rhs.val)
186    }
187}
188
189impl<P> core::ops::BitAndAssign<Self> for Line<P>
190where
191    P: CubePrimitive,
192    P: core::ops::BitAndAssign,
193{
194    fn bitand_assign(&mut self, rhs: Self) {
195        self.val &= rhs.val;
196    }
197}
198
199impl<P> core::ops::BitOrAssign<Self> for Line<P>
200where
201    P: CubePrimitive,
202    P: core::ops::BitOrAssign,
203{
204    fn bitor_assign(&mut self, rhs: Self) {
205        self.val |= rhs.val;
206    }
207}
208
209impl<P> core::ops::BitXorAssign<Self> for Line<P>
210where
211    P: CubePrimitive,
212    P: core::ops::BitXorAssign,
213{
214    fn bitxor_assign(&mut self, rhs: Self) {
215        self.val ^= rhs.val;
216    }
217}
218
219impl<P> core::ops::ShlAssign<Self> for Line<P>
220where
221    P: CubePrimitive,
222    P: core::ops::ShlAssign,
223{
224    fn shl_assign(&mut self, rhs: Self) {
225        self.val <<= rhs.val;
226    }
227}
228
229impl<P> core::ops::ShrAssign<Self> for Line<P>
230where
231    P: CubePrimitive,
232    P: core::ops::ShrAssign,
233{
234    fn shr_assign(&mut self, rhs: Self) {
235        self.val >>= rhs.val;
236    }
237}
238
239impl<P: CubePrimitive + Abs> Abs for Line<P> {}
240impl<P: CubePrimitive + Max> Max for Line<P> {}
241impl<P: CubePrimitive + Min> Min for Line<P> {}
242impl<P: CubePrimitive + Clamp> Clamp for Line<P> {}
243impl<P: CubePrimitive + Log> Log for Line<P> {}
244impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
245impl<P: CubePrimitive + Erf> Erf for Line<P> {}
246impl<P: CubePrimitive + Exp> Exp for Line<P> {}
247impl<P: CubePrimitive + Powf> Powf for Line<P> {}
248impl<P: CubePrimitive + Powi<I>, I: CubePrimitive> Powi<Line<I>> for Line<P> {}
249impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
250impl<P: CubePrimitive + InverseSqrt> InverseSqrt for Line<P> {}
251impl<P: CubePrimitive + Cos> Cos for Line<P> {}
252impl<P: CubePrimitive + Sin> Sin for Line<P> {}
253impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
254impl<P: CubePrimitive + Recip> Recip for Line<P> {}
255impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
256impl<P: CubePrimitive + Round> Round for Line<P> {}
257impl<P: CubePrimitive + Floor> Floor for Line<P> {}
258impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
259impl<P: CubePrimitive + Trunc> Trunc for Line<P> {}
260impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
261impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
262impl<P: CubePrimitive + SaturatingAdd> SaturatingAdd for Line<P> {}
263impl<P: CubePrimitive + SaturatingSub> SaturatingSub for Line<P> {}
264impl<P: CubePrimitive + IsNan> IsNan for Line<P> {}
265impl<P: CubePrimitive + IsInf> IsInf for Line<P> {}
266
267#[cube]
268impl<P: CountOnes> Line<P> {
269    pub fn count_ones(self) -> Line<u32> {
270        intrinsic!(|scope| {
271            let out_item =
272                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
273            let out = scope.create_local(out_item);
274            scope.register(Instruction::new(
275                Bitwise::CountOnes(UnaryOperator {
276                    input: *self.expand,
277                }),
278                *out,
279            ));
280            out.into()
281        })
282    }
283}
284
285#[cube]
286impl<P: LeadingZeros> Line<P> {
287    pub fn leading_zeros(self) -> Line<u32> {
288        intrinsic!(|scope| {
289            let out_item =
290                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
291            let out = scope.create_local(out_item);
292            scope.register(Instruction::new(
293                Bitwise::LeadingZeros(UnaryOperator {
294                    input: *self.expand,
295                }),
296                *out,
297            ));
298            out.into()
299        })
300    }
301}
302
303#[cube]
304impl<P: FindFirstSet> Line<P> {
305    pub fn find_first_set(self) -> Line<u32> {
306        intrinsic!(|scope| {
307            let out_item =
308                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
309            let out = scope.create_local(out_item);
310            scope.register(Instruction::new(
311                Bitwise::FindFirstSet(UnaryOperator {
312                    input: *self.expand,
313                }),
314                *out,
315            ));
316            out.into()
317        })
318    }
319}
320
321impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
322    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
323        let val: P = NumCast::from(n)?;
324        Some(Self { val })
325    }
326}
327impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
328    fn to_i64(&self) -> Option<i64> {
329        self.val.to_i64()
330    }
331
332    fn to_u64(&self) -> Option<u64> {
333        self.val.to_u64()
334    }
335}
336
337#[allow(clippy::from_over_into)]
338impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
339    fn into(self) -> ExpandElementTyped<Self> {
340        let elem: ExpandElementTyped<P> = self.val.into();
341        elem.expand.into()
342    }
343}
344
345macro_rules! operation_literal {
346    ($lit:ty) => {
347        impl<P> core::ops::Add<$lit> for Line<P>
348        where
349            P: CubePrimitive,
350        {
351            type Output = Self;
352
353            fn add(self, _rhs: $lit) -> Self::Output {
354                unexpanded!();
355            }
356        }
357
358        impl<P> core::ops::Sub<$lit> for Line<P>
359        where
360            P: CubePrimitive,
361        {
362            type Output = Self;
363
364            fn sub(self, _rhs: $lit) -> Self::Output {
365                unexpanded!();
366            }
367        }
368
369        impl<P> core::ops::Mul<$lit> for Line<P>
370        where
371            P: CubePrimitive,
372        {
373            type Output = Self;
374
375            fn mul(self, _rhs: $lit) -> Self::Output {
376                unexpanded!();
377            }
378        }
379
380        impl<P> core::ops::Div<$lit> for Line<P>
381        where
382            P: CubePrimitive,
383        {
384            type Output = Self;
385
386            fn div(self, _rhs: $lit) -> Self::Output {
387                unexpanded!();
388            }
389        }
390    };
391}
392
393operation_literal!(f32);
394operation_literal!(f64);
395operation_literal!(usize);
396operation_literal!(i32);
397operation_literal!(i64);