1use cubecl::prelude::*;
2use cubecl_core::{
3 self as cubecl, intrinsic,
4 ir::{ElemType, ExpandElement, FloatKind, IntKind, UIntKind},
5};
6use half::{bf16, f16};
7use serde::{Deserialize, Serialize};
8
/// A scalar value tagged with its concrete numeric type.
///
/// Each variant wraps a host-side value of exactly one primitive type
/// (four float widths, four signed and four unsigned integer widths).
/// This file also implements `LaunchArg` and `ArgSettings` for the type, so a
/// value can be registered with a kernel launcher according to its variant.
#[derive(CubeType, Clone)]
pub enum InputScalar {
    // Floating-point variants.
    F64(f64),
    F32(f32),
    F16(f16),
    BF16(bf16),
    // Signed integer variants.
    I64(i64),
    I32(i32),
    I16(i16),
    I8(i8),
    // Unsigned integer variants.
    U64(u64),
    U32(u32),
    U16(u16),
    U8(u8),
}
27
28impl InputScalar {
29 pub fn new<E: Numeric>(val: E, dtype: impl Into<ElemType>) -> Self {
35 let dtype: ElemType = dtype.into();
36 match dtype {
37 ElemType::Float(float_kind) => match float_kind {
38 FloatKind::F16 => Self::F16(half::f16::from_f32(val.to_f32().unwrap())),
39 FloatKind::BF16 => Self::BF16(half::bf16::from_f32(val.to_f32().unwrap())),
40 FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => {
41 Self::F32(val.to_f32().unwrap())
42 }
43 FloatKind::F64 => Self::F64(val.to_f64().unwrap()),
44 _ => panic!("Unsupported float element type"),
45 },
46 ElemType::Int(int_kind) => match int_kind {
47 IntKind::I8 => Self::I8(val.to_i8().unwrap()),
48 IntKind::I16 => Self::I16(val.to_i16().unwrap()),
49 IntKind::I32 => Self::I32(val.to_i32().unwrap()),
50 IntKind::I64 => Self::I64(val.to_i64().unwrap()),
51 },
52 ElemType::UInt(uint_kind) => match uint_kind {
53 UIntKind::U8 => Self::U8(val.to_u8().unwrap()),
54 UIntKind::U16 => Self::U16(val.to_u16().unwrap()),
55 UIntKind::U32 => Self::U32(val.to_u32().unwrap()),
56 UIntKind::U64 => Self::U64(val.to_u64().unwrap()),
57 },
58 ElemType::Bool => panic!("Bool isn't a scalar"),
59 }
60 }
61}
62
63#[cube]
64impl InputScalar {
65 pub fn get<C: CubePrimitive>(&self) -> C {
69 intrinsic!(|scope| {
70 let dtype = C::as_type(scope).elem_type();
71
72 match self {
73 InputScalarExpand::U64(val) => {
74 if dtype == ElemType::UInt(cubecl::ir::UIntKind::U64) {
75 let expand: ExpandElement = val.clone().into();
76 ExpandElementTyped::from(expand.clone())
77 } else {
78 C::__expand_cast_from(scope, val.clone())
79 }
80 }
81 InputScalarExpand::U32(val) => {
82 if dtype == ElemType::UInt(cubecl::ir::UIntKind::U32) {
83 let expand: ExpandElement = val.clone().into();
84 ExpandElementTyped::from(expand.clone())
85 } else {
86 C::__expand_cast_from(scope, val.clone())
87 }
88 }
89 InputScalarExpand::U16(val) => {
90 if dtype == ElemType::UInt(cubecl::ir::UIntKind::U16) {
91 let expand: ExpandElement = val.clone().into();
92 ExpandElementTyped::from(expand.clone())
93 } else {
94 C::__expand_cast_from(scope, val.clone())
95 }
96 }
97 InputScalarExpand::F64(val) => {
98 if dtype == ElemType::Float(cubecl::ir::FloatKind::F64) {
99 let expand: ExpandElement = val.clone().into();
100 ExpandElementTyped::from(expand.clone())
101 } else {
102 C::__expand_cast_from(scope, val.clone())
103 }
104 }
105 InputScalarExpand::F32(val) => {
106 if dtype == ElemType::Float(cubecl::ir::FloatKind::F32) {
107 let expand: ExpandElement = val.clone().into();
108 ExpandElementTyped::from(expand.clone())
109 } else {
110 C::__expand_cast_from(scope, val.clone())
111 }
112 }
113 InputScalarExpand::F16(val) => {
114 if dtype == ElemType::Float(cubecl::ir::FloatKind::F16) {
115 let expand: ExpandElement = val.clone().into();
116 ExpandElementTyped::from(expand.clone())
117 } else {
118 C::__expand_cast_from(scope, val.clone())
119 }
120 }
121 InputScalarExpand::BF16(val) => {
122 if dtype == ElemType::Float(cubecl::ir::FloatKind::BF16) {
123 let expand: ExpandElement = val.clone().into();
124 ExpandElementTyped::from(expand.clone())
125 } else {
126 C::__expand_cast_from(scope, val.clone())
127 }
128 }
129 InputScalarExpand::U8(val) => {
130 if dtype == ElemType::UInt(cubecl::ir::UIntKind::U8) {
131 let expand: ExpandElement = val.clone().into();
132 ExpandElementTyped::from(expand.clone())
133 } else {
134 C::__expand_cast_from(scope, val.clone())
135 }
136 }
137
138 InputScalarExpand::I64(val) => {
139 if dtype == ElemType::Int(cubecl::ir::IntKind::I64) {
140 let expand: ExpandElement = val.clone().into();
141 ExpandElementTyped::from(expand.clone())
142 } else {
143 C::__expand_cast_from(scope, val.clone())
144 }
145 }
146 InputScalarExpand::I32(val) => {
147 if dtype == ElemType::Int(cubecl::ir::IntKind::I32) {
148 let expand: ExpandElement = val.clone().into();
149 ExpandElementTyped::from(expand.clone())
150 } else {
151 C::__expand_cast_from(scope, val.clone())
152 }
153 }
154 InputScalarExpand::I16(val) => {
155 if dtype == ElemType::Int(cubecl::ir::IntKind::I16) {
156 let expand: ExpandElement = val.clone().into();
157 ExpandElementTyped::from(expand.clone())
158 } else {
159 C::__expand_cast_from(scope, val.clone())
160 }
161 }
162 InputScalarExpand::I8(val) => {
163 if dtype == ElemType::Int(cubecl::ir::IntKind::I8) {
164 let expand: ExpandElement = val.clone().into();
165 ExpandElementTyped::from(expand.clone())
166 } else {
167 C::__expand_cast_from(scope, val.clone())
168 }
169 }
170 }
171 })
172 }
173}
174
175impl LaunchArg for InputScalar {
176 type RuntimeArg<'a, R: Runtime> = InputScalar;
177 type CompilationArg = InputScalarCompilationArg;
178
179 fn compilation_arg<R: Runtime>(arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
180 match arg {
181 InputScalar::F64(_) => {
182 InputScalarCompilationArg::new(ElemType::Float(FloatKind::F64).into())
183 }
184 InputScalar::F32(_) => {
185 InputScalarCompilationArg::new(ElemType::Float(FloatKind::F32).into())
186 }
187 InputScalar::F16(_) => {
188 InputScalarCompilationArg::new(ElemType::Float(FloatKind::F16).into())
189 }
190 InputScalar::BF16(_) => {
191 InputScalarCompilationArg::new(ElemType::Float(FloatKind::BF16).into())
192 }
193 InputScalar::I64(_) => {
194 InputScalarCompilationArg::new(ElemType::Int(IntKind::I64).into())
195 }
196 InputScalar::I32(_) => {
197 InputScalarCompilationArg::new(ElemType::Int(IntKind::I32).into())
198 }
199 InputScalar::I16(_) => {
200 InputScalarCompilationArg::new(ElemType::Int(IntKind::I16).into())
201 }
202 InputScalar::I8(_) => InputScalarCompilationArg::new(ElemType::Int(IntKind::I8).into()),
203 InputScalar::U64(_) => {
204 InputScalarCompilationArg::new(ElemType::UInt(UIntKind::U64).into())
205 }
206 InputScalar::U32(_) => {
207 InputScalarCompilationArg::new(ElemType::UInt(UIntKind::U32).into())
208 }
209 InputScalar::U16(_) => {
210 InputScalarCompilationArg::new(ElemType::UInt(UIntKind::U16).into())
211 }
212 InputScalar::U8(_) => {
213 InputScalarCompilationArg::new(ElemType::UInt(UIntKind::U8).into())
214 }
215 }
216 }
217
218 fn expand(
219 arg: &Self::CompilationArg,
220 builder: &mut KernelBuilder,
221 ) -> <Self as CubeType>::ExpandType {
222 let expand = builder.scalar(arg.ty);
223 match arg.ty.elem_type() {
224 ElemType::Float(float_kind) => match float_kind {
225 FloatKind::F16 => InputScalarExpand::F16(expand.into()),
226 FloatKind::BF16 => InputScalarExpand::BF16(expand.into()),
227 FloatKind::Flex32 => InputScalarExpand::F32(expand.into()),
228 FloatKind::F32 => InputScalarExpand::F32(expand.into()),
229 FloatKind::TF32 => InputScalarExpand::F32(expand.into()),
230 FloatKind::F64 => InputScalarExpand::F32(expand.into()),
231 FloatKind::E2M1
232 | FloatKind::E2M3
233 | FloatKind::E3M2
234 | FloatKind::E4M3
235 | FloatKind::E5M2
236 | FloatKind::UE8M0 => unimplemented!("FP8 can't be passed as scalar"),
237 },
238 ElemType::Int(int_kind) => match int_kind {
239 IntKind::I8 => InputScalarExpand::I8(expand.into()),
240 IntKind::I16 => InputScalarExpand::I16(expand.into()),
241 IntKind::I32 => InputScalarExpand::I32(expand.into()),
242 IntKind::I64 => InputScalarExpand::I64(expand.into()),
243 },
244 ElemType::UInt(uint_kind) => match uint_kind {
245 UIntKind::U8 => InputScalarExpand::U8(expand.into()),
246 UIntKind::U16 => InputScalarExpand::U16(expand.into()),
247 UIntKind::U32 => InputScalarExpand::U32(expand.into()),
248 UIntKind::U64 => InputScalarExpand::U64(expand.into()),
249 },
250 ElemType::Bool => panic!("Bool should be converted first."),
251 }
252 }
253}
254
255#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
256pub struct InputScalarCompilationArg {
257 ty: StorageType,
258}
259
260impl InputScalarCompilationArg {
261 pub fn new(ty: StorageType) -> Self {
262 Self { ty }
263 }
264}
265
// Empty impl: no trait items are overridden for `InputScalarCompilationArg`.
impl CompilationArg for InputScalarCompilationArg {}
267
268impl<R: Runtime> ArgSettings<R> for InputScalar {
269 fn register(&self, launcher: &mut KernelLauncher<R>) {
270 match self {
271 InputScalar::F64(val) => launcher.register_f64(*val),
272 InputScalar::F32(val) => launcher.register_f32(*val),
273 InputScalar::F16(val) => launcher.register_f16(*val),
274 InputScalar::BF16(val) => launcher.register_bf16(*val),
275 InputScalar::I64(val) => launcher.register_i64(*val),
276 InputScalar::I32(val) => launcher.register_i32(*val),
277 InputScalar::I16(val) => launcher.register_i16(*val),
278 InputScalar::I8(val) => launcher.register_i8(*val),
279 InputScalar::U64(val) => launcher.register_u64(*val),
280 InputScalar::U32(val) => launcher.register_u32(*val),
281 InputScalar::U16(val) => launcher.register_u16(*val),
282 InputScalar::U8(val) => launcher.register_u8(*val),
283 }
284 }
285}