cubecl_core/frontend/element/int/typemap.rs

1use bytemuck::{Pod, Zeroable};
2use core::ops::*;
3use cubecl_ir::{
4    ConstantScalarValue, ElemType, ExpandElement, IntKind, Scope, StorageType, Variable,
5};
6use derive_more::derive::{
7    Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Display, Div,
8    DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub,
9    SubAssign,
10};
11use num_traits::{NumCast, ToPrimitive};
12use serde::Serialize;
13
14use crate::{Runtime, compute::KernelLauncher, prelude::*};
15
16use super::{Int, into_mut_expand_element};
17
18#[repr(transparent)]
19#[derive(
20    Clone,
21    Copy,
22    Default,
23    Serialize,
24    Zeroable,
25    Pod,
26    PartialEq,
27    PartialOrd,
28    Neg,
29    Add,
30    Sub,
31    Mul,
32    Div,
33    Rem,
34    AddAssign,
35    SubAssign,
36    MulAssign,
37    DivAssign,
38    RemAssign,
39    Debug,
40    Display,
41    Shl,
42    ShlAssign,
43    Shr,
44    ShrAssign,
45    BitXor,
46    BitXorAssign,
47    BitAnd,
48    BitAndAssign,
49    BitOr,
50    BitOrAssign,
51    Not,
52)]
53pub struct IntExpand<const POS: u8>(i64);
54
55impl<const POS: u8> Mul for IntExpand<POS> {
56    type Output = Self;
57
58    fn mul(self, rhs: Self) -> Self::Output {
59        IntExpand(self.0 * rhs.0)
60    }
61}
62
63impl<const POS: u8> Div for IntExpand<POS> {
64    type Output = Self;
65
66    fn div(self, rhs: Self) -> Self::Output {
67        IntExpand(self.0 / rhs.0)
68    }
69}
70
71impl<const POS: u8> Rem for IntExpand<POS> {
72    type Output = Self;
73
74    fn rem(self, rhs: Self) -> Self::Output {
75        IntExpand(self.0 % rhs.0)
76    }
77}
78
79impl<const POS: u8> MulAssign for IntExpand<POS> {
80    fn mul_assign(&mut self, rhs: Self) {
81        self.0 *= rhs.0;
82    }
83}
84
85impl<const POS: u8> DivAssign for IntExpand<POS> {
86    fn div_assign(&mut self, rhs: Self) {
87        self.0 /= rhs.0;
88    }
89}
90
91impl<const POS: u8> RemAssign for IntExpand<POS> {
92    fn rem_assign(&mut self, rhs: Self) {
93        self.0 %= rhs.0;
94    }
95}
96
97impl<const POS: u8> Shr for IntExpand<POS> {
98    type Output = Self;
99
100    fn shr(self, rhs: Self) -> Self::Output {
101        IntExpand(self.0 >> rhs.0)
102    }
103}
104
105impl<const POS: u8> Shl for IntExpand<POS> {
106    type Output = Self;
107
108    fn shl(self, rhs: Self) -> Self::Output {
109        IntExpand(self.0 << rhs.0)
110    }
111}
112
113impl<const POS: u8> ToPrimitive for IntExpand<POS> {
114    fn to_i64(&self) -> Option<i64> {
115        Some(self.0)
116    }
117
118    fn to_u64(&self) -> Option<u64> {
119        Some(self.0 as u64)
120    }
121
122    fn to_f32(&self) -> Option<f32> {
123        Some(self.0 as f32)
124    }
125
126    fn to_f64(&self) -> Option<f64> {
127        Some(self.0 as f64)
128    }
129}
130
131impl<const POS: u8> NumCast for IntExpand<POS> {
132    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
133        Some(IntExpand(n.to_i64()?))
134    }
135}
136
137impl<const POS: u8> CubeType for IntExpand<POS> {
138    type ExpandType = ExpandElementTyped<IntExpand<POS>>;
139}
140
141impl<const POS: u8> CubePrimitive for IntExpand<POS> {
142    /// Return the element type to use on GPU
143    fn as_type(scope: &Scope) -> StorageType {
144        scope.resolve_type::<Self>().expect("Type to be registered")
145    }
146
147    fn from_const_value(_value: ConstantScalarValue) -> Self {
148        unimplemented!("Can't turn `IntExpand` into a constant value")
149    }
150}
151
152impl<const POS: u8> From<IntExpand<POS>> for Variable {
153    fn from(val: IntExpand<POS>) -> Self {
154        // TODO: Fix how we create literal.
155        Variable::new(
156            crate::ir::VariableKind::ConstantScalar(crate::ir::ConstantScalarValue::Int(
157                val.0,
158                cubecl_ir::IntKind::I32,
159            )),
160            crate::ir::Type::scalar(ElemType::Int(IntKind::I64)),
161        )
162    }
163}
164
165impl<const POS: u8> From<IntExpand<POS>> for ExpandElementTyped<IntExpand<POS>> {
166    fn from(value: IntExpand<POS>) -> Self {
167        let var: Variable = value.into();
168        ExpandElementTyped::new(ExpandElement::Plain(var))
169    }
170}
171
172impl<const POS: u8> IntoRuntime for IntExpand<POS> {
173    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
174        let elem: ExpandElementTyped<Self> = ExpandElementTyped::from_lit(scope, self.0);
175        into_runtime_expand_element(scope, elem).into()
176    }
177}
178
179impl<const POS: u8> Numeric for IntExpand<POS> {
180    fn min_value() -> Self {
181        panic!("Can't use min value in comptime with dynamic element type");
182    }
183    fn max_value() -> Self {
184        panic!("Can't use max value in comptime with dynamic element type");
185    }
186}
187
188impl<const POS: u8> ExpandElementIntoMut for IntExpand<POS> {
189    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
190        into_mut_expand_element(scope, elem)
191    }
192}
193
194impl<const POS: u8> Remainder for IntExpand<POS> {}
195impl<const POS: u8> Abs for IntExpand<POS> {}
196impl<const POS: u8> Max for IntExpand<POS> {}
197impl<const POS: u8> Min for IntExpand<POS> {}
198impl<const POS: u8> Clamp for IntExpand<POS> {}
199impl<const POS: u8> MulHi for IntExpand<POS> {}
200
201impl<const POS: u8> BitwiseNot for IntExpand<POS> {}
202impl<const POS: u8> ReverseBits for IntExpand<POS> {}
203impl<const POS: u8> CountOnes for IntExpand<POS> {}
204impl<const POS: u8> FindFirstSet for IntExpand<POS> {}
205impl<const POS: u8> LeadingZeros for IntExpand<POS> {}
206impl<const POS: u8> SaturatingAdd for IntExpand<POS> {}
207impl<const POS: u8> SaturatingSub for IntExpand<POS> {}
208
209impl<const POS: u8> Int for IntExpand<POS> {
210    const BITS: u32 = 32;
211
212    fn new(val: i64) -> Self {
213        IntExpand(val)
214    }
215}
216
217impl<const POS: u8> ScalarArgSettings for IntExpand<POS> {
218    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
219        settings.register_i32(self.0 as i32);
220    }
221}