cubecl_core/frontend/container/line/ops.rs

1use cubecl_ir::{Bitwise, Elem, Instruction, Scope, UIntKind, UnaryOperator};
2use num_traits::{NumCast, ToPrimitive};
3
4use crate::{
5    frontend::{
6        Abs, Ceil, Clamp, Cos, CubeIndex, CubeIndexMut, CubePrimitive, Erf, Exp,
7        ExpandElementTyped, Floor, Log, Log1p, Max, Min, Powf, Recip, Remainder, Round, Sin, Sqrt,
8        Tanh,
9    },
10    prelude::{BitwiseNot, CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
11    unexpanded,
12};
13
14use super::Line;
15
16impl<P> core::ops::Add<Self> for Line<P>
17where
18    P: CubePrimitive,
19    P: core::ops::Add<P, Output = P>,
20{
21    type Output = Self;
22
23    fn add(self, rhs: Self) -> Self::Output {
24        Self::new(self.val + rhs.val)
25    }
26}
27
28impl<P> core::ops::Sub<Self> for Line<P>
29where
30    P: CubePrimitive,
31    P: core::ops::Sub<P, Output = P>,
32{
33    type Output = Self;
34
35    fn sub(self, rhs: Self) -> Self::Output {
36        Self::new(self.val - rhs.val)
37    }
38}
39
40impl<P> core::ops::Mul<Self> for Line<P>
41where
42    P: CubePrimitive,
43    P: core::ops::Mul<P, Output = P>,
44{
45    type Output = Self;
46
47    fn mul(self, rhs: Self) -> Self::Output {
48        Self::new(self.val * rhs.val)
49    }
50}
51
52impl<P> core::ops::Div<Self> for Line<P>
53where
54    P: CubePrimitive,
55    P: core::ops::Div<P, Output = P>,
56{
57    type Output = Self;
58
59    fn div(self, rhs: Self) -> Self::Output {
60        Self::new(self.val / rhs.val)
61    }
62}
63
64impl<P> core::ops::AddAssign<Self> for Line<P>
65where
66    P: CubePrimitive,
67    P: core::ops::AddAssign,
68{
69    fn add_assign(&mut self, rhs: Self) {
70        self.val += rhs.val;
71    }
72}
73
74impl<P> core::ops::SubAssign<Self> for Line<P>
75where
76    P: CubePrimitive,
77    P: core::ops::SubAssign,
78{
79    fn sub_assign(&mut self, rhs: Self) {
80        self.val -= rhs.val;
81    }
82}
83
84impl<P> core::ops::DivAssign<Self> for Line<P>
85where
86    P: CubePrimitive,
87    P: core::ops::DivAssign,
88{
89    fn div_assign(&mut self, rhs: Self) {
90        self.val /= rhs.val;
91    }
92}
93
94impl<P> core::ops::MulAssign<Self> for Line<P>
95where
96    P: CubePrimitive,
97    P: core::ops::MulAssign,
98{
99    fn mul_assign(&mut self, rhs: Self) {
100        self.val *= rhs.val;
101    }
102}
103
104impl<P> core::cmp::PartialEq for Line<P>
105where
106    P: CubePrimitive,
107    P: core::cmp::PartialEq,
108{
109    fn eq(&self, other: &Self) -> bool {
110        self.val.eq(&other.val)
111    }
112}
113
114impl<P> core::cmp::PartialOrd for Line<P>
115where
116    P: CubePrimitive,
117    P: core::cmp::PartialOrd,
118{
119    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
120        self.val.partial_cmp(&other.val)
121    }
122}
123
124impl<P> core::ops::BitAnd<Self> for Line<P>
125where
126    P: CubePrimitive,
127    P: core::ops::BitAnd<P, Output = P>,
128{
129    type Output = Self;
130
131    fn bitand(self, rhs: Self) -> Self::Output {
132        Self::new(self.val & rhs.val)
133    }
134}
135
136impl<P> core::ops::BitOr<Self> for Line<P>
137where
138    P: CubePrimitive,
139    P: core::ops::BitOr<P, Output = P>,
140{
141    type Output = Self;
142
143    fn bitor(self, rhs: Self) -> Self::Output {
144        Self::new(self.val | rhs.val)
145    }
146}
147
148impl<P> core::ops::BitXor<Self> for Line<P>
149where
150    P: CubePrimitive,
151    P: core::ops::BitXor<P, Output = P>,
152{
153    type Output = Self;
154
155    fn bitxor(self, rhs: Self) -> Self::Output {
156        Self::new(self.val ^ rhs.val)
157    }
158}
159
160impl<P> core::ops::Shl<Self> for Line<P>
161where
162    P: CubePrimitive,
163    P: core::ops::Shl<P, Output = P>,
164{
165    type Output = Self;
166
167    fn shl(self, rhs: Self) -> Self::Output {
168        Self::new(self.val << rhs.val)
169    }
170}
171
172impl<P> core::ops::Shr<Self> for Line<P>
173where
174    P: CubePrimitive,
175    P: core::ops::Shr<P, Output = P>,
176{
177    type Output = Self;
178
179    fn shr(self, rhs: Self) -> Self::Output {
180        Self::new(self.val >> rhs.val)
181    }
182}
183
184impl<P> core::ops::BitAndAssign<Self> for Line<P>
185where
186    P: CubePrimitive,
187    P: core::ops::BitAndAssign,
188{
189    fn bitand_assign(&mut self, rhs: Self) {
190        self.val &= rhs.val;
191    }
192}
193
194impl<P> core::ops::BitOrAssign<Self> for Line<P>
195where
196    P: CubePrimitive,
197    P: core::ops::BitOrAssign,
198{
199    fn bitor_assign(&mut self, rhs: Self) {
200        self.val |= rhs.val;
201    }
202}
203
204impl<P> core::ops::BitXorAssign<Self> for Line<P>
205where
206    P: CubePrimitive,
207    P: core::ops::BitXorAssign,
208{
209    fn bitxor_assign(&mut self, rhs: Self) {
210        self.val ^= rhs.val;
211    }
212}
213
214impl<P> core::ops::ShlAssign<Self> for Line<P>
215where
216    P: CubePrimitive,
217    P: core::ops::ShlAssign,
218{
219    fn shl_assign(&mut self, rhs: Self) {
220        self.val <<= rhs.val;
221    }
222}
223
224impl<P> core::ops::ShrAssign<Self> for Line<P>
225where
226    P: CubePrimitive,
227    P: core::ops::ShrAssign,
228{
229    fn shr_assign(&mut self, rhs: Self) {
230        self.val >>= rhs.val;
231    }
232}
233
234impl<P: CubePrimitive + Abs> Abs for Line<P> {}
235impl<P: CubePrimitive + Max> Max for Line<P> {}
236impl<P: CubePrimitive + Min> Min for Line<P> {}
237impl<P: CubePrimitive + Clamp> Clamp for Line<P> {}
238impl<P: CubePrimitive + Log> Log for Line<P> {}
239impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
240impl<P: CubePrimitive + Erf> Erf for Line<P> {}
241impl<P: CubePrimitive + Exp> Exp for Line<P> {}
242impl<P: CubePrimitive + Powf> Powf for Line<P> {}
243impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
244impl<P: CubePrimitive + Cos> Cos for Line<P> {}
245impl<P: CubePrimitive + Sin> Sin for Line<P> {}
246impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
247impl<P: CubePrimitive + Recip> Recip for Line<P> {}
248impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
249impl<P: CubePrimitive + Round> Round for Line<P> {}
250impl<P: CubePrimitive + Floor> Floor for Line<P> {}
251impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
252impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
253impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
254
255impl<P: CountOnes> Line<P> {
256    pub fn count_ones(self) -> Line<u32> {
257        unexpanded!()
258    }
259
260    pub fn __expand_count_ones(
261        scope: &mut Scope,
262        value: ExpandElementTyped<Self>,
263    ) -> ExpandElementTyped<Line<u32>> {
264        value.__expand_count_ones_method(scope)
265    }
266}
267
268impl<P: CountOnes> ExpandElementTyped<Line<P>> {
269    pub fn __expand_count_ones_method(self, scope: &mut Scope) -> ExpandElementTyped<Line<u32>> {
270        let mut out_item = self.expand.item;
271        out_item.elem = Elem::UInt(UIntKind::U32);
272        let out = scope.create_local(out_item);
273        scope.register(Instruction::new(
274            Bitwise::CountOnes(UnaryOperator {
275                input: *self.expand,
276            }),
277            *out,
278        ));
279        out.into()
280    }
281}
282
283impl<P: LeadingZeros> Line<P> {
284    pub fn leading_zeros(self) -> Line<u32> {
285        unexpanded!()
286    }
287
288    pub fn __expand_leading_zeros(
289        scope: &mut Scope,
290        value: ExpandElementTyped<Self>,
291    ) -> ExpandElementTyped<Line<u32>> {
292        value.__expand_leading_zeros_method(scope)
293    }
294}
295
296impl<P: LeadingZeros> ExpandElementTyped<Line<P>> {
297    pub fn __expand_leading_zeros_method(self, scope: &mut Scope) -> ExpandElementTyped<Line<u32>> {
298        let mut out_item = self.expand.item;
299        out_item.elem = Elem::UInt(UIntKind::U32);
300        let out = scope.create_local(out_item);
301        scope.register(Instruction::new(
302            Bitwise::LeadingZeros(UnaryOperator {
303                input: *self.expand,
304            }),
305            *out,
306        ));
307        out.into()
308    }
309}
310
311impl<P: FindFirstSet> Line<P> {
312    pub fn find_first_set(self) -> Line<u32> {
313        unexpanded!()
314    }
315
316    pub fn __expand_find_first_set(
317        scope: &mut Scope,
318        value: ExpandElementTyped<Self>,
319    ) -> ExpandElementTyped<Line<u32>> {
320        value.__expand_find_first_set_method(scope)
321    }
322}
323
324impl<P: FindFirstSet> ExpandElementTyped<Line<P>> {
325    pub fn __expand_find_first_set_method(
326        self,
327        scope: &mut Scope,
328    ) -> ExpandElementTyped<Line<u32>> {
329        let mut out_item = self.expand.item;
330        out_item.elem = Elem::UInt(UIntKind::U32);
331        let out = scope.create_local(out_item);
332        scope.register(Instruction::new(
333            Bitwise::FindFirstSet(UnaryOperator {
334                input: *self.expand,
335            }),
336            *out,
337        ));
338        out.into()
339    }
340}
341
342impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
343    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
344        let val: P = NumCast::from(n)?;
345        Some(Self { val })
346    }
347}
348impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
349    fn to_i64(&self) -> Option<i64> {
350        self.val.to_i64()
351    }
352
353    fn to_u64(&self) -> Option<u64> {
354        self.val.to_u64()
355    }
356}
357
358impl<P> CubeIndex<u32> for Line<P>
359where
360    P: CubePrimitive,
361{
362    type Output = P;
363
364    fn cube_idx(&self, _i: u32) -> &Self::Output {
365        unexpanded!()
366    }
367}
368
369impl<P> CubeIndexMut<u32> for Line<P>
370where
371    P: CubePrimitive,
372{
373    fn cube_idx_mut(&mut self, _i: u32) -> &mut Self::Output {
374        unexpanded!()
375    }
376}
377
378impl<P> CubeIndex<ExpandElementTyped<u32>> for Line<P>
379where
380    P: CubePrimitive,
381{
382    type Output = P;
383
384    fn cube_idx(&self, _i: ExpandElementTyped<u32>) -> &Self::Output {
385        unexpanded!()
386    }
387}
388
389impl<P> CubeIndexMut<ExpandElementTyped<u32>> for Line<P>
390where
391    P: CubePrimitive,
392{
393    fn cube_idx_mut(&mut self, _i: ExpandElementTyped<u32>) -> &mut Self::Output {
394        unexpanded!()
395    }
396}
397
398#[allow(clippy::from_over_into)]
399impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
400    fn into(self) -> ExpandElementTyped<Self> {
401        let elem: ExpandElementTyped<P> = self.val.into();
402        elem.expand.into()
403    }
404}
405
/// Generates `Add`, `Sub`, `Mul` and `Div` impls for `Line<P>` whose right-hand
/// side is the given type `$lit`, all returning `Line<P>`.
///
/// Every generated body is `unexpanded!()`. NOTE(review): presumably these impls
/// let mixed expressions such as `line * 2.0` type-check before expansion; confirm.
macro_rules! operation_literal {
    // Internal arm: one operator impl for one right-hand-side type.
    (@op $lit:ty, $op_trait:ident, $op_fn:ident) => {
        impl<P> core::ops::$op_trait<$lit> for Line<P>
        where
            P: CubePrimitive,
        {
            type Output = Self;

            fn $op_fn(self, _rhs: $lit) -> Self::Output {
                unexpanded!();
            }
        }
    };
    // Public arm: the four arithmetic operators for `$lit`.
    ($lit:ty) => {
        operation_literal!(@op $lit, Add, add);
        operation_literal!(@op $lit, Sub, sub);
        operation_literal!(@op $lit, Mul, mul);
        operation_literal!(@op $lit, Div, div);
    };
}
453
454operation_literal!(f32);
455operation_literal!(f64);
456operation_literal!(usize);
457operation_literal!(i32);
458operation_literal!(i64);