cubecl_core/frontend/container/line/
ops.rs1use cubecl_ir::{Bitwise, Elem, Instruction, Scope, UIntKind, UnaryOperator};
2use num_traits::{NumCast, ToPrimitive};
3
4use crate::{
5 frontend::{
6 Abs, Ceil, Clamp, Cos, CubeIndex, CubeIndexMut, CubePrimitive, Erf, Exp,
7 ExpandElementTyped, Floor, Log, Log1p, Max, Min, Powf, Recip, Remainder, Round, Sin, Sqrt,
8 Tanh,
9 },
10 prelude::{BitwiseNot, CountOnes, FindFirstSet, LeadingZeros, ReverseBits},
11 unexpanded,
12};
13
14use super::Line;
15
16impl<P> core::ops::Add<Self> for Line<P>
17where
18 P: CubePrimitive,
19 P: core::ops::Add<P, Output = P>,
20{
21 type Output = Self;
22
23 fn add(self, rhs: Self) -> Self::Output {
24 Self::new(self.val + rhs.val)
25 }
26}
27
28impl<P> core::ops::Sub<Self> for Line<P>
29where
30 P: CubePrimitive,
31 P: core::ops::Sub<P, Output = P>,
32{
33 type Output = Self;
34
35 fn sub(self, rhs: Self) -> Self::Output {
36 Self::new(self.val - rhs.val)
37 }
38}
39
40impl<P> core::ops::Mul<Self> for Line<P>
41where
42 P: CubePrimitive,
43 P: core::ops::Mul<P, Output = P>,
44{
45 type Output = Self;
46
47 fn mul(self, rhs: Self) -> Self::Output {
48 Self::new(self.val * rhs.val)
49 }
50}
51
52impl<P> core::ops::Div<Self> for Line<P>
53where
54 P: CubePrimitive,
55 P: core::ops::Div<P, Output = P>,
56{
57 type Output = Self;
58
59 fn div(self, rhs: Self) -> Self::Output {
60 Self::new(self.val / rhs.val)
61 }
62}
63
64impl<P> core::ops::AddAssign<Self> for Line<P>
65where
66 P: CubePrimitive,
67 P: core::ops::AddAssign,
68{
69 fn add_assign(&mut self, rhs: Self) {
70 self.val += rhs.val;
71 }
72}
73
74impl<P> core::ops::SubAssign<Self> for Line<P>
75where
76 P: CubePrimitive,
77 P: core::ops::SubAssign,
78{
79 fn sub_assign(&mut self, rhs: Self) {
80 self.val -= rhs.val;
81 }
82}
83
84impl<P> core::ops::DivAssign<Self> for Line<P>
85where
86 P: CubePrimitive,
87 P: core::ops::DivAssign,
88{
89 fn div_assign(&mut self, rhs: Self) {
90 self.val /= rhs.val;
91 }
92}
93
94impl<P> core::ops::MulAssign<Self> for Line<P>
95where
96 P: CubePrimitive,
97 P: core::ops::MulAssign,
98{
99 fn mul_assign(&mut self, rhs: Self) {
100 self.val *= rhs.val;
101 }
102}
103
104impl<P> core::cmp::PartialEq for Line<P>
105where
106 P: CubePrimitive,
107 P: core::cmp::PartialEq,
108{
109 fn eq(&self, other: &Self) -> bool {
110 self.val.eq(&other.val)
111 }
112}
113
114impl<P> core::cmp::PartialOrd for Line<P>
115where
116 P: CubePrimitive,
117 P: core::cmp::PartialOrd,
118{
119 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
120 self.val.partial_cmp(&other.val)
121 }
122}
123
124impl<P> core::ops::BitAnd<Self> for Line<P>
125where
126 P: CubePrimitive,
127 P: core::ops::BitAnd<P, Output = P>,
128{
129 type Output = Self;
130
131 fn bitand(self, rhs: Self) -> Self::Output {
132 Self::new(self.val & rhs.val)
133 }
134}
135
136impl<P> core::ops::BitOr<Self> for Line<P>
137where
138 P: CubePrimitive,
139 P: core::ops::BitOr<P, Output = P>,
140{
141 type Output = Self;
142
143 fn bitor(self, rhs: Self) -> Self::Output {
144 Self::new(self.val | rhs.val)
145 }
146}
147
148impl<P> core::ops::BitXor<Self> for Line<P>
149where
150 P: CubePrimitive,
151 P: core::ops::BitXor<P, Output = P>,
152{
153 type Output = Self;
154
155 fn bitxor(self, rhs: Self) -> Self::Output {
156 Self::new(self.val ^ rhs.val)
157 }
158}
159
160impl<P> core::ops::Shl<Self> for Line<P>
161where
162 P: CubePrimitive,
163 P: core::ops::Shl<P, Output = P>,
164{
165 type Output = Self;
166
167 fn shl(self, rhs: Self) -> Self::Output {
168 Self::new(self.val << rhs.val)
169 }
170}
171
172impl<P> core::ops::Shr<Self> for Line<P>
173where
174 P: CubePrimitive,
175 P: core::ops::Shr<P, Output = P>,
176{
177 type Output = Self;
178
179 fn shr(self, rhs: Self) -> Self::Output {
180 Self::new(self.val >> rhs.val)
181 }
182}
183
184impl<P> core::ops::BitAndAssign<Self> for Line<P>
185where
186 P: CubePrimitive,
187 P: core::ops::BitAndAssign,
188{
189 fn bitand_assign(&mut self, rhs: Self) {
190 self.val &= rhs.val;
191 }
192}
193
194impl<P> core::ops::BitOrAssign<Self> for Line<P>
195where
196 P: CubePrimitive,
197 P: core::ops::BitOrAssign,
198{
199 fn bitor_assign(&mut self, rhs: Self) {
200 self.val |= rhs.val;
201 }
202}
203
204impl<P> core::ops::BitXorAssign<Self> for Line<P>
205where
206 P: CubePrimitive,
207 P: core::ops::BitXorAssign,
208{
209 fn bitxor_assign(&mut self, rhs: Self) {
210 self.val ^= rhs.val;
211 }
212}
213
214impl<P> core::ops::ShlAssign<Self> for Line<P>
215where
216 P: CubePrimitive,
217 P: core::ops::ShlAssign,
218{
219 fn shl_assign(&mut self, rhs: Self) {
220 self.val <<= rhs.val;
221 }
222}
223
224impl<P> core::ops::ShrAssign<Self> for Line<P>
225where
226 P: CubePrimitive,
227 P: core::ops::ShrAssign,
228{
229 fn shr_assign(&mut self, rhs: Self) {
230 self.val >>= rhs.val;
231 }
232}
233
234impl<P: CubePrimitive + Abs> Abs for Line<P> {}
235impl<P: CubePrimitive + Max> Max for Line<P> {}
236impl<P: CubePrimitive + Min> Min for Line<P> {}
237impl<P: CubePrimitive + Clamp> Clamp for Line<P> {}
238impl<P: CubePrimitive + Log> Log for Line<P> {}
239impl<P: CubePrimitive + Log1p> Log1p for Line<P> {}
240impl<P: CubePrimitive + Erf> Erf for Line<P> {}
241impl<P: CubePrimitive + Exp> Exp for Line<P> {}
242impl<P: CubePrimitive + Powf> Powf for Line<P> {}
243impl<P: CubePrimitive + Sqrt> Sqrt for Line<P> {}
244impl<P: CubePrimitive + Cos> Cos for Line<P> {}
245impl<P: CubePrimitive + Sin> Sin for Line<P> {}
246impl<P: CubePrimitive + Tanh> Tanh for Line<P> {}
247impl<P: CubePrimitive + Recip> Recip for Line<P> {}
248impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
249impl<P: CubePrimitive + Round> Round for Line<P> {}
250impl<P: CubePrimitive + Floor> Floor for Line<P> {}
251impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
252impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
253impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
254
255impl<P: CountOnes> Line<P> {
256 pub fn count_ones(self) -> Line<u32> {
257 unexpanded!()
258 }
259
260 pub fn __expand_count_ones(
261 scope: &mut Scope,
262 value: ExpandElementTyped<Self>,
263 ) -> ExpandElementTyped<Line<u32>> {
264 value.__expand_count_ones_method(scope)
265 }
266}
267
268impl<P: CountOnes> ExpandElementTyped<Line<P>> {
269 pub fn __expand_count_ones_method(self, scope: &mut Scope) -> ExpandElementTyped<Line<u32>> {
270 let mut out_item = self.expand.item;
271 out_item.elem = Elem::UInt(UIntKind::U32);
272 let out = scope.create_local(out_item);
273 scope.register(Instruction::new(
274 Bitwise::CountOnes(UnaryOperator {
275 input: *self.expand,
276 }),
277 *out,
278 ));
279 out.into()
280 }
281}
282
283impl<P: LeadingZeros> Line<P> {
284 pub fn leading_zeros(self) -> Line<u32> {
285 unexpanded!()
286 }
287
288 pub fn __expand_leading_zeros(
289 scope: &mut Scope,
290 value: ExpandElementTyped<Self>,
291 ) -> ExpandElementTyped<Line<u32>> {
292 value.__expand_leading_zeros_method(scope)
293 }
294}
295
296impl<P: LeadingZeros> ExpandElementTyped<Line<P>> {
297 pub fn __expand_leading_zeros_method(self, scope: &mut Scope) -> ExpandElementTyped<Line<u32>> {
298 let mut out_item = self.expand.item;
299 out_item.elem = Elem::UInt(UIntKind::U32);
300 let out = scope.create_local(out_item);
301 scope.register(Instruction::new(
302 Bitwise::LeadingZeros(UnaryOperator {
303 input: *self.expand,
304 }),
305 *out,
306 ));
307 out.into()
308 }
309}
310
311impl<P: FindFirstSet> Line<P> {
312 pub fn find_first_set(self) -> Line<u32> {
313 unexpanded!()
314 }
315
316 pub fn __expand_find_first_set(
317 scope: &mut Scope,
318 value: ExpandElementTyped<Self>,
319 ) -> ExpandElementTyped<Line<u32>> {
320 value.__expand_find_first_set_method(scope)
321 }
322}
323
324impl<P: FindFirstSet> ExpandElementTyped<Line<P>> {
325 pub fn __expand_find_first_set_method(
326 self,
327 scope: &mut Scope,
328 ) -> ExpandElementTyped<Line<u32>> {
329 let mut out_item = self.expand.item;
330 out_item.elem = Elem::UInt(UIntKind::U32);
331 let out = scope.create_local(out_item);
332 scope.register(Instruction::new(
333 Bitwise::FindFirstSet(UnaryOperator {
334 input: *self.expand,
335 }),
336 *out,
337 ));
338 out.into()
339 }
340}
341
342impl<P: CubePrimitive + NumCast> NumCast for Line<P> {
343 fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
344 let val: P = NumCast::from(n)?;
345 Some(Self { val })
346 }
347}
348impl<P: CubePrimitive + NumCast> ToPrimitive for Line<P> {
349 fn to_i64(&self) -> Option<i64> {
350 self.val.to_i64()
351 }
352
353 fn to_u64(&self) -> Option<u64> {
354 self.val.to_u64()
355 }
356}
357
358impl<P> CubeIndex<u32> for Line<P>
359where
360 P: CubePrimitive,
361{
362 type Output = P;
363
364 fn cube_idx(&self, _i: u32) -> &Self::Output {
365 unexpanded!()
366 }
367}
368
369impl<P> CubeIndexMut<u32> for Line<P>
370where
371 P: CubePrimitive,
372{
373 fn cube_idx_mut(&mut self, _i: u32) -> &mut Self::Output {
374 unexpanded!()
375 }
376}
377
378impl<P> CubeIndex<ExpandElementTyped<u32>> for Line<P>
379where
380 P: CubePrimitive,
381{
382 type Output = P;
383
384 fn cube_idx(&self, _i: ExpandElementTyped<u32>) -> &Self::Output {
385 unexpanded!()
386 }
387}
388
389impl<P> CubeIndexMut<ExpandElementTyped<u32>> for Line<P>
390where
391 P: CubePrimitive,
392{
393 fn cube_idx_mut(&mut self, _i: ExpandElementTyped<u32>) -> &mut Self::Output {
394 unexpanded!()
395 }
396}
397
398#[allow(clippy::from_over_into)]
399impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Into<ExpandElementTyped<Self>> for Line<P> {
400 fn into(self) -> ExpandElementTyped<Self> {
401 let elem: ExpandElementTyped<P> = self.val.into();
402 elem.expand.into()
403 }
404}
405
406macro_rules! operation_literal {
407 ($lit:ty) => {
408 impl<P> core::ops::Add<$lit> for Line<P>
409 where
410 P: CubePrimitive,
411 {
412 type Output = Self;
413
414 fn add(self, _rhs: $lit) -> Self::Output {
415 unexpanded!();
416 }
417 }
418
419 impl<P> core::ops::Sub<$lit> for Line<P>
420 where
421 P: CubePrimitive,
422 {
423 type Output = Self;
424
425 fn sub(self, _rhs: $lit) -> Self::Output {
426 unexpanded!();
427 }
428 }
429
430 impl<P> core::ops::Mul<$lit> for Line<P>
431 where
432 P: CubePrimitive,
433 {
434 type Output = Self;
435
436 fn mul(self, _rhs: $lit) -> Self::Output {
437 unexpanded!();
438 }
439 }
440
441 impl<P> core::ops::Div<$lit> for Line<P>
442 where
443 P: CubePrimitive,
444 {
445 type Output = Self;
446
447 fn div(self, _rhs: $lit) -> Self::Output {
448 unexpanded!();
449 }
450 }
451 };
452}
453
454operation_literal!(f32);
455operation_literal!(f64);
456operation_literal!(usize);
457operation_literal!(i32);
458operation_literal!(i64);