cubecl_core/frontend/container/line/
ops.rs1use cubecl_ir::{Bitwise, ElemType, Instruction, Type, UIntKind, UnaryOperator};
2use cubecl_macros::{cube, intrinsic};
3use num_traits::{NumCast, ToPrimitive};
4
5use crate::{
6 self as cubecl,
7 prelude::{InverseSqrt, IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub, Trunc},
8};
9use crate::{
10 frontend::{
11 Abs, Ceil, Clamp, Cos, CubePrimitive, Erf, Exp, ExpandElementTyped, Floor, Log, Log1p, Max,
12 Min, Powf, Recip, Remainder, Round, Sin, Sqrt, Tanh,
13 },
14 prelude::{BitwiseNot, CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
15 unexpanded,
16};
17
18use super::Line;
/// Shorthand for `ExpandElementTyped<Line<E>>`.
// NOTE(review): not referenced in the visible part of this file — presumably used by
// `#[cube]`-generated code; confirm before removing.
type LineExpand<E> = ExpandElementTyped<Line<E>>;
20
21impl<P> core::ops::Add<Self> for Line<P>
22where
23 P: CubePrimitive,
24 P: core::ops::Add<P, Output = P>,
25{
26 type Output = Self;
27
28 fn add(self, rhs: Self) -> Self::Output {
29 Self::new(self.val + rhs.val)
30 }
31}
32
33impl<P> core::ops::Sub<Self> for Line<P>
34where
35 P: CubePrimitive,
36 P: core::ops::Sub<P, Output = P>,
37{
38 type Output = Self;
39
40 fn sub(self, rhs: Self) -> Self::Output {
41 Self::new(self.val - rhs.val)
42 }
43}
44
45impl<P> core::ops::Mul<Self> for Line<P>
46where
47 P: CubePrimitive,
48 P: core::ops::Mul<P, Output = P>,
49{
50 type Output = Self;
51
52 fn mul(self, rhs: Self) -> Self::Output {
53 Self::new(self.val * rhs.val)
54 }
55}
56
57impl<P> core::ops::Div<Self> for Line<P>
58where
59 P: CubePrimitive,
60 P: core::ops::Div<P, Output = P>,
61{
62 type Output = Self;
63
64 fn div(self, rhs: Self) -> Self::Output {
65 Self::new(self.val / rhs.val)
66 }
67}
68
69impl<P> core::ops::AddAssign<Self> for Line<P>
70where
71 P: CubePrimitive,
72 P: core::ops::AddAssign,
73{
74 fn add_assign(&mut self, rhs: Self) {
75 self.val += rhs.val;
76 }
77}
78
79impl<P> core::ops::SubAssign<Self> for Line<P>
80where
81 P: CubePrimitive,
82 P: core::ops::SubAssign,
83{
84 fn sub_assign(&mut self, rhs: Self) {
85 self.val -= rhs.val;
86 }
87}
88
89impl<P> core::ops::DivAssign<Self> for Line<P>
90where
91 P: CubePrimitive,
92 P: core::ops::DivAssign,
93{
94 fn div_assign(&mut self, rhs: Self) {
95 self.val /= rhs.val;
96 }
97}
98
99impl<P> core::ops::MulAssign<Self> for Line<P>
100where
101 P: CubePrimitive,
102 P: core::ops::MulAssign,
103{
104 fn mul_assign(&mut self, rhs: Self) {
105 self.val *= rhs.val;
106 }
107}
108
109impl<P> core::cmp::PartialEq for Line<P>
110where
111 P: CubePrimitive,
112 P: core::cmp::PartialEq,
113{
114 fn eq(&self, other: &Self) -> bool {
115 self.val.eq(&other.val)
116 }
117}
118
119impl<P> core::cmp::PartialOrd for Line<P>
120where
121 P: CubePrimitive,
122 P: core::cmp::PartialOrd,
123{
124 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
125 self.val.partial_cmp(&other.val)
126 }
127}
128
129impl<P> core::ops::BitAnd<Self> for Line<P>
130where
131 P: CubePrimitive,
132 P: core::ops::BitAnd<P, Output = P>,
133{
134 type Output = Self;
135
136 fn bitand(self, rhs: Self) -> Self::Output {
137 Self::new(self.val & rhs.val)
138 }
139}
140
141impl<P> core::ops::BitOr<Self> for Line<P>
142where
143 P: CubePrimitive,
144 P: core::ops::BitOr<P, Output = P>,
145{
146 type Output = Self;
147
148 fn bitor(self, rhs: Self) -> Self::Output {
149 Self::new(self.val | rhs.val)
150 }
151}
152
153impl<P> core::ops::BitXor<Self> for Line<P>
154where
155 P: CubePrimitive,
156 P: core::ops::BitXor<P, Output = P>,
157{
158 type Output = Self;
159
160 fn bitxor(self, rhs: Self) -> Self::Output {
161 Self::new(self.val ^ rhs.val)
162 }
163}
164
165impl<P> core::ops::Shl<Self> for Line<P>
166where
167 P: CubePrimitive,
168 P: core::ops::Shl<P, Output = P>,
169{
170 type Output = Self;
171
172 fn shl(self, rhs: Self) -> Self::Output {
173 Self::new(self.val << rhs.val)
174 }
175}
176
177impl<P> core::ops::Shr<Self> for Line<P>
178where
179 P: CubePrimitive,
180 P: core::ops::Shr<P, Output = P>,
181{
182 type Output = Self;
183
184 fn shr(self, rhs: Self) -> Self::Output {
185 Self::new(self.val >> rhs.val)
186 }
187}
188
189impl<P> core::ops::BitAndAssign<Self> for Line<P>
190where
191 P: CubePrimitive,
192 P: core::ops::BitAndAssign,
193{
194 fn bitand_assign(&mut self, rhs: Self) {
195 self.val &= rhs.val;
196 }
197}
198
199impl<P> core::ops::BitOrAssign<Self> for Line<P>
200where
201 P: CubePrimitive,
202 P: core::ops::BitOrAssign,
203{
204 fn bitor_assign(&mut self, rhs: Self) {
205 self.val |= rhs.val;
206 }
207}
208
209impl<P> core::ops::BitXorAssign<Self> for Line<P>
210where
211 P: CubePrimitive,
212 P: core::ops::BitXorAssign,
213{
214 fn bitxor_assign(&mut self, rhs: Self) {
215 self.val ^= rhs.val;
216 }
217}
218
219impl<P> core::ops::ShlAssign<Self> for Line<P>
220where
221 P: CubePrimitive,
222 P: core::ops::ShlAssign,
223{
224 fn shl_assign(&mut self, rhs: Self) {
225 self.val <<= rhs.val;
226 }
227}
228
229impl<P> core::ops::ShrAssign<Self> for Line<P>
230where
231 P: CubePrimitive,
232 P: core::ops::ShrAssign,
233{
234 fn shr_assign(&mut self, rhs: Self) {
235 self.val >>= rhs.val;
236 }
237}
238
239impl<P: CubePrimitive + Abs> Abs for Line<P> {}
240impl<P: CubePrimitive + Max> Max for Line<P> {}
241impl<P: CubePrimitive + Min> Min for Line<P> {}
242impl<P: CubePrimitive + Clamp> Clamp for Line<P> {}
243impl<P: CubePrimitive + Log> Log for Line<P> {}
244impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
245impl<P: CubePrimitive + Erf> Erf for Line<P> {}
246impl<P: CubePrimitive + Exp> Exp for Line<P> {}
247impl<P: CubePrimitive + Powf> Powf for Line<P> {}
248impl<P: CubePrimitive + Powi<I>, I: CubePrimitive> Powi<Line<I>> for Line<P> {}
249impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
250impl<P: CubePrimitive + InverseSqrt> InverseSqrt for Line<P> {}
251impl<P: CubePrimitive + Cos> Cos for Line<P> {}
252impl<P: CubePrimitive + Sin> Sin for Line<P> {}
253impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
254impl<P: CubePrimitive + Recip> Recip for Line<P> {}
255impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
256impl<P: CubePrimitive + Round> Round for Line<P> {}
257impl<P: CubePrimitive + Floor> Floor for Line<P> {}
258impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
259impl<P: CubePrimitive + Trunc> Trunc for Line<P> {}
260impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
261impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
262impl<P: CubePrimitive + SaturatingAdd> SaturatingAdd for Line<P> {}
263impl<P: CubePrimitive + SaturatingSub> SaturatingSub for Line<P> {}
264impl<P: CubePrimitive + IsNan> IsNan for Line<P> {}
265impl<P: CubePrimitive + IsInf> IsInf for Line<P> {}
266
267#[cube]
268impl<P: CountOnes> Line<P> {
269 pub fn count_ones(self) -> Line<u32> {
270 intrinsic!(|scope| {
271 let out_item =
272 Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
273 let out = scope.create_local(out_item);
274 scope.register(Instruction::new(
275 Bitwise::CountOnes(UnaryOperator {
276 input: *self.expand,
277 }),
278 *out,
279 ));
280 out.into()
281 })
282 }
283}
284
285#[cube]
286impl<P: LeadingZeros> Line<P> {
287 pub fn leading_zeros(self) -> Line<u32> {
288 intrinsic!(|scope| {
289 let out_item =
290 Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
291 let out = scope.create_local(out_item);
292 scope.register(Instruction::new(
293 Bitwise::LeadingZeros(UnaryOperator {
294 input: *self.expand,
295 }),
296 *out,
297 ));
298 out.into()
299 })
300 }
301}
302
303#[cube]
304impl<P: FindFirstSet> Line<P> {
305 pub fn find_first_set(self) -> Line<u32> {
306 intrinsic!(|scope| {
307 let out_item =
308 Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
309 let out = scope.create_local(out_item);
310 scope.register(Instruction::new(
311 Bitwise::FindFirstSet(UnaryOperator {
312 input: *self.expand,
313 }),
314 *out,
315 ));
316 out.into()
317 })
318 }
319}
320
321impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
322 fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
323 let val: P = NumCast::from(n)?;
324 Some(Self { val })
325 }
326}
327impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
328 fn to_i64(&self) -> Option<i64> {
329 self.val.to_i64()
330 }
331
332 fn to_u64(&self) -> Option<u64> {
333 self.val.to_u64()
334 }
335}
336
337#[allow(clippy::from_over_into)]
338impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
339 fn into(self) -> ExpandElementTyped<Self> {
340 let elem: ExpandElementTyped<P> = self.val.into();
341 elem.expand.into()
342 }
343}
344
/// Implements `Add`, `Sub`, `Mul` and `Div` for `Line<P>` with a right-hand side of the
/// given literal type.
///
/// Every generated method body is `unexpanded!()`.
macro_rules! operation_literal {
    // Internal rule: one operator impl for one literal type.
    (@impl $lit:ty, $op:ident, $method:ident) => {
        impl<P> core::ops::$op<$lit> for Line<P>
        where
            P: CubePrimitive,
        {
            type Output = Self;

            fn $method(self, _rhs: $lit) -> Self::Output {
                unexpanded!();
            }
        }
    };
    // Entry point: all four arithmetic operators for one literal type.
    ($lit:ty) => {
        operation_literal!(@impl $lit, Add, add);
        operation_literal!(@impl $lit, Sub, sub);
        operation_literal!(@impl $lit, Mul, mul);
        operation_literal!(@impl $lit, Div, div);
    };
}
392
// Generate the `Add`/`Sub`/`Mul`/`Div` impls with a literal right-hand side for each of
// these literal types.
operation_literal!(f32);
operation_literal!(f64);
operation_literal!(usize);
operation_literal!(i32);
operation_literal!(i64);