cubecl_core/frontend/container/line/
ops.rs1use cubecl_ir::{Bitwise, Elem, Instruction, UIntKind, UnaryOperator};
2use cubecl_macros::{cube, intrinsic};
3use num_traits::{NumCast, ToPrimitive};
4
5use crate as cubecl;
6use crate::{
7 frontend::{
8 Abs, Ceil, Clamp, Cos, CubePrimitive, Erf, Exp, ExpandElementTyped, Floor, Log, Log1p, Max,
9 Min, Powf, Recip, Remainder, Round, Sin, Sqrt, Tanh,
10 },
11 prelude::{BitwiseNot, CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
12 unexpanded,
13};
14
15use super::Line;
16type LineExpand<E> = ExpandElementTyped<Line<E>>;
17
18impl<P> core::ops::Add<Self> for Line<P>
19where
20 P: CubePrimitive,
21 P: core::ops::Add<P, Output = P>,
22{
23 type Output = Self;
24
25 fn add(self, rhs: Self) -> Self::Output {
26 Self::new(self.val + rhs.val)
27 }
28}
29
30impl<P> core::ops::Sub<Self> for Line<P>
31where
32 P: CubePrimitive,
33 P: core::ops::Sub<P, Output = P>,
34{
35 type Output = Self;
36
37 fn sub(self, rhs: Self) -> Self::Output {
38 Self::new(self.val - rhs.val)
39 }
40}
41
42impl<P> core::ops::Mul<Self> for Line<P>
43where
44 P: CubePrimitive,
45 P: core::ops::Mul<P, Output = P>,
46{
47 type Output = Self;
48
49 fn mul(self, rhs: Self) -> Self::Output {
50 Self::new(self.val * rhs.val)
51 }
52}
53
54impl<P> core::ops::Div<Self> for Line<P>
55where
56 P: CubePrimitive,
57 P: core::ops::Div<P, Output = P>,
58{
59 type Output = Self;
60
61 fn div(self, rhs: Self) -> Self::Output {
62 Self::new(self.val / rhs.val)
63 }
64}
65
66impl<P> core::ops::AddAssign<Self> for Line<P>
67where
68 P: CubePrimitive,
69 P: core::ops::AddAssign,
70{
71 fn add_assign(&mut self, rhs: Self) {
72 self.val += rhs.val;
73 }
74}
75
76impl<P> core::ops::SubAssign<Self> for Line<P>
77where
78 P: CubePrimitive,
79 P: core::ops::SubAssign,
80{
81 fn sub_assign(&mut self, rhs: Self) {
82 self.val -= rhs.val;
83 }
84}
85
86impl<P> core::ops::DivAssign<Self> for Line<P>
87where
88 P: CubePrimitive,
89 P: core::ops::DivAssign,
90{
91 fn div_assign(&mut self, rhs: Self) {
92 self.val /= rhs.val;
93 }
94}
95
96impl<P> core::ops::MulAssign<Self> for Line<P>
97where
98 P: CubePrimitive,
99 P: core::ops::MulAssign,
100{
101 fn mul_assign(&mut self, rhs: Self) {
102 self.val *= rhs.val;
103 }
104}
105
106impl<P> core::cmp::PartialEq for Line<P>
107where
108 P: CubePrimitive,
109 P: core::cmp::PartialEq,
110{
111 fn eq(&self, other: &Self) -> bool {
112 self.val.eq(&other.val)
113 }
114}
115
116impl<P> core::cmp::PartialOrd for Line<P>
117where
118 P: CubePrimitive,
119 P: core::cmp::PartialOrd,
120{
121 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122 self.val.partial_cmp(&other.val)
123 }
124}
125
126impl<P> core::ops::BitAnd<Self> for Line<P>
127where
128 P: CubePrimitive,
129 P: core::ops::BitAnd<P, Output = P>,
130{
131 type Output = Self;
132
133 fn bitand(self, rhs: Self) -> Self::Output {
134 Self::new(self.val & rhs.val)
135 }
136}
137
138impl<P> core::ops::BitOr<Self> for Line<P>
139where
140 P: CubePrimitive,
141 P: core::ops::BitOr<P, Output = P>,
142{
143 type Output = Self;
144
145 fn bitor(self, rhs: Self) -> Self::Output {
146 Self::new(self.val | rhs.val)
147 }
148}
149
150impl<P> core::ops::BitXor<Self> for Line<P>
151where
152 P: CubePrimitive,
153 P: core::ops::BitXor<P, Output = P>,
154{
155 type Output = Self;
156
157 fn bitxor(self, rhs: Self) -> Self::Output {
158 Self::new(self.val ^ rhs.val)
159 }
160}
161
162impl<P> core::ops::Shl<Self> for Line<P>
163where
164 P: CubePrimitive,
165 P: core::ops::Shl<P, Output = P>,
166{
167 type Output = Self;
168
169 fn shl(self, rhs: Self) -> Self::Output {
170 Self::new(self.val << rhs.val)
171 }
172}
173
174impl<P> core::ops::Shr<Self> for Line<P>
175where
176 P: CubePrimitive,
177 P: core::ops::Shr<P, Output = P>,
178{
179 type Output = Self;
180
181 fn shr(self, rhs: Self) -> Self::Output {
182 Self::new(self.val >> rhs.val)
183 }
184}
185
186impl<P> core::ops::BitAndAssign<Self> for Line<P>
187where
188 P: CubePrimitive,
189 P: core::ops::BitAndAssign,
190{
191 fn bitand_assign(&mut self, rhs: Self) {
192 self.val &= rhs.val;
193 }
194}
195
196impl<P> core::ops::BitOrAssign<Self> for Line<P>
197where
198 P: CubePrimitive,
199 P: core::ops::BitOrAssign,
200{
201 fn bitor_assign(&mut self, rhs: Self) {
202 self.val |= rhs.val;
203 }
204}
205
206impl<P> core::ops::BitXorAssign<Self> for Line<P>
207where
208 P: CubePrimitive,
209 P: core::ops::BitXorAssign,
210{
211 fn bitxor_assign(&mut self, rhs: Self) {
212 self.val ^= rhs.val;
213 }
214}
215
216impl<P> core::ops::ShlAssign<Self> for Line<P>
217where
218 P: CubePrimitive,
219 P: core::ops::ShlAssign,
220{
221 fn shl_assign(&mut self, rhs: Self) {
222 self.val <<= rhs.val;
223 }
224}
225
226impl<P> core::ops::ShrAssign<Self> for Line<P>
227where
228 P: CubePrimitive,
229 P: core::ops::ShrAssign,
230{
231 fn shr_assign(&mut self, rhs: Self) {
232 self.val >>= rhs.val;
233 }
234}
235
236impl<P: CubePrimitive + Abs> Abs for Line<P> {}
237impl<P: CubePrimitive + Max> Max for Line<P> {}
238impl<P: CubePrimitive + Min> Min for Line<P> {}
239impl<P: CubePrimitive + Clamp> Clamp for Line<P> {}
240impl<P: CubePrimitive + Log> Log for Line<P> {}
241impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
242impl<P: CubePrimitive + Erf> Erf for Line<P> {}
243impl<P: CubePrimitive + Exp> Exp for Line<P> {}
244impl<P: CubePrimitive + Powf> Powf for Line<P> {}
245impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
246impl<P: CubePrimitive + Cos> Cos for Line<P> {}
247impl<P: CubePrimitive + Sin> Sin for Line<P> {}
248impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
249impl<P: CubePrimitive + Recip> Recip for Line<P> {}
250impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
251impl<P: CubePrimitive + Round> Round for Line<P> {}
252impl<P: CubePrimitive + Floor> Floor for Line<P> {}
253impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
254impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
255impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
256
#[cube]
impl<P: CountOnes> Line<P> {
    /// Counts the set bits of the line, producing a `Line<u32>`.
    ///
    /// The output reuses `self`'s IR item and only replaces its element type with `u32`.
    /// Presumably it therefore keeps the same line size as `self`; NOTE(review): confirm
    /// against `Item`.
    pub fn count_ones(self) -> Line<u32> {
        intrinsic!(|scope| {
            // Same item as the input, but with a `u32` element to hold the bit count.
            let mut out_item = self.expand.item;
            out_item.elem = Elem::UInt(UIntKind::U32);
            let out = scope.create_local(out_item);
            // Register a `Bitwise::CountOnes` instruction reading `self` and writing `out`.
            scope.register(Instruction::new(
                Bitwise::CountOnes(UnaryOperator {
                    input: *self.expand,
                }),
                *out,
            ));
            out.into()
        })
    }
}
274
#[cube]
impl<P: LeadingZeros> Line<P> {
    /// Counts the leading zero bits of the line, producing a `Line<u32>`.
    ///
    /// The output reuses `self`'s IR item and only replaces its element type with `u32`.
    /// Presumably it therefore keeps the same line size as `self`; NOTE(review): confirm
    /// against `Item`.
    pub fn leading_zeros(self) -> Line<u32> {
        intrinsic!(|scope| {
            // Same item as the input, but with a `u32` element to hold the count.
            let mut out_item = self.expand.item;
            out_item.elem = Elem::UInt(UIntKind::U32);
            let out = scope.create_local(out_item);
            // Register a `Bitwise::LeadingZeros` instruction reading `self` and writing `out`.
            scope.register(Instruction::new(
                Bitwise::LeadingZeros(UnaryOperator {
                    input: *self.expand,
                }),
                *out,
            ));
            out.into()
        })
    }
}
292
#[cube]
impl<P: FindFirstSet> Line<P> {
    /// Finds the first set bit of the line, producing a `Line<u32>`.
    ///
    /// The output reuses `self`'s IR item and only replaces its element type with `u32`.
    /// The backends define the exact convention (index base, result for a zero input).
    /// NOTE(review): document it once confirmed.
    pub fn find_first_set(self) -> Line<u32> {
        intrinsic!(|scope| {
            // Same item as the input, but with a `u32` element to hold the bit position.
            let mut out_item = self.expand.item;
            out_item.elem = Elem::UInt(UIntKind::U32);
            let out = scope.create_local(out_item);
            // Register a `Bitwise::FindFirstSet` instruction reading `self` and writing `out`.
            scope.register(Instruction::new(
                Bitwise::FindFirstSet(UnaryOperator {
                    input: *self.expand,
                }),
                *out,
            ));
            out.into()
        })
    }
}
310
311impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
312 fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
313 let val: P = NumCast::from(n)?;
314 Some(Self { val })
315 }
316}
317impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
318 fn to_i64(&self) -> Option<i64> {
319 self.val.to_i64()
320 }
321
322 fn to_u64(&self) -> Option<u64> {
323 self.val.to_u64()
324 }
325}
326
327#[allow(clippy::from_over_into)]
328impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
329 fn into(self) -> ExpandElementTyped<Self> {
330 let elem: ExpandElementTyped<P> = self.val.into();
331 elem.expand.into()
332 }
333}
334
// Generates `Line<P> {+,-,*,/} $lit` impls whose bodies are `unexpanded!()` placeholders.
// NOTE(review): presumably these exist only so `#[cube]` code such as `line * 2.0`
// type-checks before expansion; confirm.
macro_rules! operation_literal {
    // Internal rule: one operator impl for one literal type.
    (@op $lit:ty, $op_trait:ident, $op_fn:ident) => {
        impl<P> core::ops::$op_trait<$lit> for Line<P>
        where
            P: CubePrimitive,
        {
            type Output = Self;

            fn $op_fn(self, _rhs: $lit) -> Self::Output {
                unexpanded!();
            }
        }
    };
    // Public entry point: all four arithmetic operators for `$lit`.
    ($lit:ty) => {
        operation_literal!(@op $lit, Add, add);
        operation_literal!(@op $lit, Sub, sub);
        operation_literal!(@op $lit, Mul, mul);
        operation_literal!(@op $lit, Div, div);
    };
}
382
383operation_literal!(f32);
384operation_literal!(f64);
385operation_literal!(usize);
386operation_literal!(i32);
387operation_literal!(i64);