cubecl_core/frontend/container/line/
ops.rs1use cubecl_ir::{Bitwise, ElemType, Instruction, Type, UIntKind, UnaryOperator};
2use cubecl_macros::{cube, intrinsic};
3use num_traits::{NumCast, ToPrimitive};
4
5use crate::{
6 self as cubecl,
7 prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub, Trunc},
8};
9use crate::{
10 frontend::{
11 Abs, Ceil, Clamp, Cos, CubePrimitive, Erf, Exp, ExpandElementTyped, Floor, Log, Log1p, Max,
12 Min, Powf, Recip, Remainder, Round, Sin, Sqrt, Tanh,
13 },
14 prelude::{BitwiseNot, CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
15 unexpanded,
16};
17
18use super::Line;
/// Shorthand alias for `ExpandElementTyped<Line<E>>`.
// NOTE(review): this alias is private and is not referenced anywhere in the
// visible part of this file; confirm it is still used before keeping it.
type LineExpand<E> = ExpandElementTyped<Line<E>>;
20
21impl<P> core::ops::Add<Self> for Line<P>
22where
23 P: CubePrimitive,
24 P: core::ops::Add<P, Output = P>,
25{
26 type Output = Self;
27
28 fn add(self, rhs: Self) -> Self::Output {
29 Self::new(self.val + rhs.val)
30 }
31}
32
33impl<P> core::ops::Sub<Self> for Line<P>
34where
35 P: CubePrimitive,
36 P: core::ops::Sub<P, Output = P>,
37{
38 type Output = Self;
39
40 fn sub(self, rhs: Self) -> Self::Output {
41 Self::new(self.val - rhs.val)
42 }
43}
44
45impl<P> core::ops::Mul<Self> for Line<P>
46where
47 P: CubePrimitive,
48 P: core::ops::Mul<P, Output = P>,
49{
50 type Output = Self;
51
52 fn mul(self, rhs: Self) -> Self::Output {
53 Self::new(self.val * rhs.val)
54 }
55}
56
57impl<P> core::ops::Div<Self> for Line<P>
58where
59 P: CubePrimitive,
60 P: core::ops::Div<P, Output = P>,
61{
62 type Output = Self;
63
64 fn div(self, rhs: Self) -> Self::Output {
65 Self::new(self.val / rhs.val)
66 }
67}
68
69impl<P> core::ops::AddAssign<Self> for Line<P>
70where
71 P: CubePrimitive,
72 P: core::ops::AddAssign,
73{
74 fn add_assign(&mut self, rhs: Self) {
75 self.val += rhs.val;
76 }
77}
78
79impl<P> core::ops::SubAssign<Self> for Line<P>
80where
81 P: CubePrimitive,
82 P: core::ops::SubAssign,
83{
84 fn sub_assign(&mut self, rhs: Self) {
85 self.val -= rhs.val;
86 }
87}
88
89impl<P> core::ops::DivAssign<Self> for Line<P>
90where
91 P: CubePrimitive,
92 P: core::ops::DivAssign,
93{
94 fn div_assign(&mut self, rhs: Self) {
95 self.val /= rhs.val;
96 }
97}
98
99impl<P> core::ops::MulAssign<Self> for Line<P>
100where
101 P: CubePrimitive,
102 P: core::ops::MulAssign,
103{
104 fn mul_assign(&mut self, rhs: Self) {
105 self.val *= rhs.val;
106 }
107}
108
109impl<P> core::cmp::PartialEq for Line<P>
110where
111 P: CubePrimitive,
112 P: core::cmp::PartialEq,
113{
114 fn eq(&self, other: &Self) -> bool {
115 self.val.eq(&other.val)
116 }
117}
118
119impl<P> core::cmp::PartialOrd for Line<P>
120where
121 P: CubePrimitive,
122 P: core::cmp::PartialOrd,
123{
124 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
125 self.val.partial_cmp(&other.val)
126 }
127}
128
129impl<P> core::ops::BitAnd<Self> for Line<P>
130where
131 P: CubePrimitive,
132 P: core::ops::BitAnd<P, Output = P>,
133{
134 type Output = Self;
135
136 fn bitand(self, rhs: Self) -> Self::Output {
137 Self::new(self.val & rhs.val)
138 }
139}
140
141impl<P> core::ops::BitOr<Self> for Line<P>
142where
143 P: CubePrimitive,
144 P: core::ops::BitOr<P, Output = P>,
145{
146 type Output = Self;
147
148 fn bitor(self, rhs: Self) -> Self::Output {
149 Self::new(self.val | rhs.val)
150 }
151}
152
153impl<P> core::ops::BitXor<Self> for Line<P>
154where
155 P: CubePrimitive,
156 P: core::ops::BitXor<P, Output = P>,
157{
158 type Output = Self;
159
160 fn bitxor(self, rhs: Self) -> Self::Output {
161 Self::new(self.val ^ rhs.val)
162 }
163}
164
165impl<P> core::ops::Shl<Self> for Line<P>
166where
167 P: CubePrimitive,
168 P: core::ops::Shl<P, Output = P>,
169{
170 type Output = Self;
171
172 fn shl(self, rhs: Self) -> Self::Output {
173 Self::new(self.val << rhs.val)
174 }
175}
176
177impl<P> core::ops::Shr<Self> for Line<P>
178where
179 P: CubePrimitive,
180 P: core::ops::Shr<P, Output = P>,
181{
182 type Output = Self;
183
184 fn shr(self, rhs: Self) -> Self::Output {
185 Self::new(self.val >> rhs.val)
186 }
187}
188
189impl<P> core::ops::BitAndAssign<Self> for Line<P>
190where
191 P: CubePrimitive,
192 P: core::ops::BitAndAssign,
193{
194 fn bitand_assign(&mut self, rhs: Self) {
195 self.val &= rhs.val;
196 }
197}
198
199impl<P> core::ops::BitOrAssign<Self> for Line<P>
200where
201 P: CubePrimitive,
202 P: core::ops::BitOrAssign,
203{
204 fn bitor_assign(&mut self, rhs: Self) {
205 self.val |= rhs.val;
206 }
207}
208
209impl<P> core::ops::BitXorAssign<Self> for Line<P>
210where
211 P: CubePrimitive,
212 P: core::ops::BitXorAssign,
213{
214 fn bitxor_assign(&mut self, rhs: Self) {
215 self.val ^= rhs.val;
216 }
217}
218
219impl<P> core::ops::ShlAssign<Self> for Line<P>
220where
221 P: CubePrimitive,
222 P: core::ops::ShlAssign,
223{
224 fn shl_assign(&mut self, rhs: Self) {
225 self.val <<= rhs.val;
226 }
227}
228
229impl<P> core::ops::ShrAssign<Self> for Line<P>
230where
231 P: CubePrimitive,
232 P: core::ops::ShrAssign,
233{
234 fn shr_assign(&mut self, rhs: Self) {
235 self.val >>= rhs.val;
236 }
237}
238
239impl<P: CubePrimitive + Abs> Abs for Line<P> {}
240impl<P: CubePrimitive + Max> Max for Line<P> {}
241impl<P: CubePrimitive + Min> Min for Line<P> {}
242impl<P: CubePrimitive + Clamp> Clamp for Line<P> {}
243impl<P: CubePrimitive + Log> Log for Line<P> {}
244impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
245impl<P: CubePrimitive + Erf> Erf for Line<P> {}
246impl<P: CubePrimitive + Exp> Exp for Line<P> {}
247impl<P: CubePrimitive + Powf> Powf for Line<P> {}
248impl<P: CubePrimitive + Powi<I>, I: CubePrimitive> Powi<Line<I>> for Line<P> {}
249impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
250impl<P: CubePrimitive + Cos> Cos for Line<P> {}
251impl<P: CubePrimitive + Sin> Sin for Line<P> {}
252impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
253impl<P: CubePrimitive + Recip> Recip for Line<P> {}
254impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
255impl<P: CubePrimitive + Round> Round for Line<P> {}
256impl<P: CubePrimitive + Floor> Floor for Line<P> {}
257impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
258impl<P: CubePrimitive + Trunc> Trunc for Line<P> {}
259impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
260impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
261impl<P: CubePrimitive + SaturatingAdd> SaturatingAdd for Line<P> {}
262impl<P: CubePrimitive + SaturatingSub> SaturatingSub for Line<P> {}
263impl<P: CubePrimitive + IsNan> IsNan for Line<P> {}
264impl<P: CubePrimitive + IsInf> IsInf for Line<P> {}
265
#[cube]
impl<P: CountOnes> Line<P> {
    /// Emits the IR `Bitwise::CountOnes` operation on this line
    /// (presumably counting set bits lane by lane — confirm against the IR docs).
    ///
    /// The result is always a `u32` line with the same line size as `self`,
    /// independent of the element type `P`.
    pub fn count_ones(self) -> Line<u32> {
        intrinsic!(|scope| {
            // Output type: u32 elements, line size copied from the input.
            let out_item =
                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
            let out = scope.create_local(out_item);
            // Register the instruction `out = CountOnes(input)` in the scope.
            scope.register(Instruction::new(
                Bitwise::CountOnes(UnaryOperator {
                    input: *self.expand,
                }),
                *out,
            ));
            out.into()
        })
    }
}
283
#[cube]
impl<P: LeadingZeros> Line<P> {
    /// Emits the IR `Bitwise::LeadingZeros` operation on this line
    /// (presumably counted lane by lane — confirm against the IR docs).
    ///
    /// The result is always a `u32` line with the same line size as `self`,
    /// independent of the element type `P`.
    pub fn leading_zeros(self) -> Line<u32> {
        intrinsic!(|scope| {
            // Output type: u32 elements, line size copied from the input.
            let out_item =
                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
            let out = scope.create_local(out_item);
            // Register the instruction `out = LeadingZeros(input)` in the scope.
            scope.register(Instruction::new(
                Bitwise::LeadingZeros(UnaryOperator {
                    input: *self.expand,
                }),
                *out,
            ));
            out.into()
        })
    }
}
301
#[cube]
impl<P: FindFirstSet> Line<P> {
    /// Emits the IR `Bitwise::FindFirstSet` operation on this line.
    ///
    /// The result is always a `u32` line with the same line size as `self`,
    /// independent of the element type `P`.
    // NOTE(review): the exact convention (0- vs 1-based index, value for an
    // all-zero input) is defined by the IR/backends, not here — confirm there.
    pub fn find_first_set(self) -> Line<u32> {
        intrinsic!(|scope| {
            // Output type: u32 elements, line size copied from the input.
            let out_item =
                Type::scalar(ElemType::UInt(UIntKind::U32)).line(self.expand.ty.line_size());
            let out = scope.create_local(out_item);
            // Register the instruction `out = FindFirstSet(input)` in the scope.
            scope.register(Instruction::new(
                Bitwise::FindFirstSet(UnaryOperator {
                    input: *self.expand,
                }),
                *out,
            ));
            out.into()
        })
    }
}
319
320impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
321 fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
322 let val: P = NumCast::from(n)?;
323 Some(Self { val })
324 }
325}
326impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
327 fn to_i64(&self) -> Option<i64> {
328 self.val.to_i64()
329 }
330
331 fn to_u64(&self) -> Option<u64> {
332 self.val.to_u64()
333 }
334}
335
336#[allow(clippy::from_over_into)]
337impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
338 fn into(self) -> ExpandElementTyped<Self> {
339 let elem: ExpandElementTyped<P> = self.val.into();
340 elem.expand.into()
341 }
342}
343
344macro_rules! operation_literal {
345 ($lit:ty) => {
346 impl<P> core::ops::Add<$lit> for Line<P>
347 where
348 P: CubePrimitive,
349 {
350 type Output = Self;
351
352 fn add(self, _rhs: $lit) -> Self::Output {
353 unexpanded!();
354 }
355 }
356
357 impl<P> core::ops::Sub<$lit> for Line<P>
358 where
359 P: CubePrimitive,
360 {
361 type Output = Self;
362
363 fn sub(self, _rhs: $lit) -> Self::Output {
364 unexpanded!();
365 }
366 }
367
368 impl<P> core::ops::Mul<$lit> for Line<P>
369 where
370 P: CubePrimitive,
371 {
372 type Output = Self;
373
374 fn mul(self, _rhs: $lit) -> Self::Output {
375 unexpanded!();
376 }
377 }
378
379 impl<P> core::ops::Div<$lit> for Line<P>
380 where
381 P: CubePrimitive,
382 {
383 type Output = Self;
384
385 fn div(self, _rhs: $lit) -> Self::Output {
386 unexpanded!();
387 }
388 }
389 };
390}
391
392operation_literal!(f32);
393operation_literal!(f64);
394operation_literal!(usize);
395operation_literal!(i32);
396operation_literal!(i64);