cubecl_core/frontend/element/int/
typemap.rs1use bytemuck::{Pod, Zeroable};
2use core::ops::*;
3use cubecl_ir::{
4 ConstantScalarValue, ElemType, ExpandElement, IntKind, Scope, StorageType, Variable,
5};
6use derive_more::derive::{
7 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Display, Div,
8 DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub,
9 SubAssign,
10};
11use num_traits::{NumCast, ToPrimitive};
12use serde::Serialize;
13
14use crate::{Runtime, compute::KernelLauncher, prelude::*};
15
16use super::{Int, into_mut_expand_element};
17
18#[repr(transparent)]
19#[derive(
20 Clone,
21 Copy,
22 Default,
23 Serialize,
24 Zeroable,
25 Pod,
26 PartialEq,
27 PartialOrd,
28 Neg,
29 Add,
30 Sub,
31 Mul,
32 Div,
33 Rem,
34 AddAssign,
35 SubAssign,
36 MulAssign,
37 DivAssign,
38 RemAssign,
39 Debug,
40 Display,
41 Shl,
42 ShlAssign,
43 Shr,
44 ShrAssign,
45 BitXor,
46 BitXorAssign,
47 BitAnd,
48 BitAndAssign,
49 BitOr,
50 BitOrAssign,
51 Not,
52)]
53pub struct IntExpand<const POS: u8>(i64);
54
55impl<const POS: u8> Mul for IntExpand<POS> {
56 type Output = Self;
57
58 fn mul(self, rhs: Self) -> Self::Output {
59 IntExpand(self.0 * rhs.0)
60 }
61}
62
63impl<const POS: u8> Div for IntExpand<POS> {
64 type Output = Self;
65
66 fn div(self, rhs: Self) -> Self::Output {
67 IntExpand(self.0 / rhs.0)
68 }
69}
70
71impl<const POS: u8> Rem for IntExpand<POS> {
72 type Output = Self;
73
74 fn rem(self, rhs: Self) -> Self::Output {
75 IntExpand(self.0 % rhs.0)
76 }
77}
78
79impl<const POS: u8> MulAssign for IntExpand<POS> {
80 fn mul_assign(&mut self, rhs: Self) {
81 self.0 *= rhs.0;
82 }
83}
84
85impl<const POS: u8> DivAssign for IntExpand<POS> {
86 fn div_assign(&mut self, rhs: Self) {
87 self.0 /= rhs.0;
88 }
89}
90
91impl<const POS: u8> RemAssign for IntExpand<POS> {
92 fn rem_assign(&mut self, rhs: Self) {
93 self.0 %= rhs.0;
94 }
95}
96
97impl<const POS: u8> Shr for IntExpand<POS> {
98 type Output = Self;
99
100 fn shr(self, rhs: Self) -> Self::Output {
101 IntExpand(self.0 >> rhs.0)
102 }
103}
104
105impl<const POS: u8> Shl for IntExpand<POS> {
106 type Output = Self;
107
108 fn shl(self, rhs: Self) -> Self::Output {
109 IntExpand(self.0 << rhs.0)
110 }
111}
112
113impl<const POS: u8> ToPrimitive for IntExpand<POS> {
114 fn to_i64(&self) -> Option<i64> {
115 Some(self.0)
116 }
117
118 fn to_u64(&self) -> Option<u64> {
119 Some(self.0 as u64)
120 }
121
122 fn to_f32(&self) -> Option<f32> {
123 Some(self.0 as f32)
124 }
125
126 fn to_f64(&self) -> Option<f64> {
127 Some(self.0 as f64)
128 }
129}
130
131impl<const POS: u8> NumCast for IntExpand<POS> {
132 fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
133 Some(IntExpand(n.to_i64()?))
134 }
135}
136
137impl<const POS: u8> CubeType for IntExpand<POS> {
138 type ExpandType = ExpandElementTyped<IntExpand<POS>>;
139}
140
141impl<const POS: u8> CubePrimitive for IntExpand<POS> {
142 fn as_type(scope: &Scope) -> StorageType {
144 scope.resolve_type::<Self>().expect("Type to be registered")
145 }
146
147 fn from_const_value(_value: ConstantScalarValue) -> Self {
148 unimplemented!("Can't turn `IntExpand` into a constant value")
149 }
150}
151
152impl<const POS: u8> From<IntExpand<POS>> for Variable {
153 fn from(val: IntExpand<POS>) -> Self {
154 Variable::new(
156 crate::ir::VariableKind::ConstantScalar(crate::ir::ConstantScalarValue::Int(
157 val.0,
158 cubecl_ir::IntKind::I32,
159 )),
160 crate::ir::Type::scalar(ElemType::Int(IntKind::I64)),
161 )
162 }
163}
164
165impl<const POS: u8> From<IntExpand<POS>> for ExpandElementTyped<IntExpand<POS>> {
166 fn from(value: IntExpand<POS>) -> Self {
167 let var: Variable = value.into();
168 ExpandElementTyped::new(ExpandElement::Plain(var))
169 }
170}
171
172impl<const POS: u8> IntoRuntime for IntExpand<POS> {
173 fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
174 let elem: ExpandElementTyped<Self> = ExpandElementTyped::from_lit(scope, self.0);
175 into_runtime_expand_element(scope, elem).into()
176 }
177}
178
179impl<const POS: u8> Numeric for IntExpand<POS> {
180 fn min_value() -> Self {
181 panic!("Can't use min value in comptime with dynamic element type");
182 }
183 fn max_value() -> Self {
184 panic!("Can't use max value in comptime with dynamic element type");
185 }
186}
187
188impl<const POS: u8> ExpandElementIntoMut for IntExpand<POS> {
189 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
190 into_mut_expand_element(scope, elem)
191 }
192}
193
194impl<const POS: u8> Remainder for IntExpand<POS> {}
195impl<const POS: u8> Abs for IntExpand<POS> {}
196impl<const POS: u8> Max for IntExpand<POS> {}
197impl<const POS: u8> Min for IntExpand<POS> {}
198impl<const POS: u8> Clamp for IntExpand<POS> {}
199impl<const POS: u8> MulHi for IntExpand<POS> {}
200
201impl<const POS: u8> BitwiseNot for IntExpand<POS> {}
202impl<const POS: u8> ReverseBits for IntExpand<POS> {}
203impl<const POS: u8> CountOnes for IntExpand<POS> {}
204impl<const POS: u8> FindFirstSet for IntExpand<POS> {}
205impl<const POS: u8> LeadingZeros for IntExpand<POS> {}
206impl<const POS: u8> SaturatingAdd for IntExpand<POS> {}
207impl<const POS: u8> SaturatingSub for IntExpand<POS> {}
208
209impl<const POS: u8> Int for IntExpand<POS> {
210 const BITS: u32 = 32;
211
212 fn new(val: i64) -> Self {
213 IntExpand(val)
214 }
215}
216
217impl<const POS: u8> ScalarArgSettings for IntExpand<POS> {
218 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
219 settings.register_i32(self.0 as i32);
220 }
221}