cubecl_std/
scalar.rs

1use cubecl::prelude::*;
2use cubecl_core::{
3    self as cubecl, intrinsic,
4    ir::{ElemType, ExpandElement, FloatKind, IntKind, UIntKind},
5};
6use half::{bf16, f16};
7use serde::{Deserialize, Serialize};
8
9#[derive(CubeType, Clone)]
10/// A way to define an input scalar without a generic attached to it.
11///
12/// It uses comptime enum with zero-cost runtime abstraction for kernel generation.
13pub enum InputScalar {
14    F64(f64),
15    F32(f32),
16    F16(f16),
17    BF16(bf16),
18    I64(i64),
19    I32(i32),
20    I16(i16),
21    I8(i8),
22    U64(u64),
23    U32(u32),
24    U16(u16),
25    U8(u8),
26}
27
28impl InputScalar {
29    /// Creates an [InputScalar] from the given element and dtype.
30    ///
31    /// # Panics
32    ///
33    /// If the given numeric element can't be transformed into the passed [ElemType].
34    pub fn new<E: Numeric>(val: E, dtype: impl Into<ElemType>) -> Self {
35        let dtype: ElemType = dtype.into();
36        match dtype {
37            ElemType::Float(float_kind) => match float_kind {
38                FloatKind::F16 => Self::F16(half::f16::from_f32(val.to_f32().unwrap())),
39                FloatKind::BF16 => Self::BF16(half::bf16::from_f32(val.to_f32().unwrap())),
40                FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => {
41                    Self::F32(val.to_f32().unwrap())
42                }
43                FloatKind::F64 => Self::F64(val.to_f64().unwrap()),
44                _ => panic!("Unsupported float element type"),
45            },
46            ElemType::Int(int_kind) => match int_kind {
47                IntKind::I8 => Self::I8(val.to_i8().unwrap()),
48                IntKind::I16 => Self::I16(val.to_i16().unwrap()),
49                IntKind::I32 => Self::I32(val.to_i32().unwrap()),
50                IntKind::I64 => Self::I64(val.to_i64().unwrap()),
51            },
52            ElemType::UInt(uint_kind) => match uint_kind {
53                UIntKind::U8 => Self::U8(val.to_u8().unwrap()),
54                UIntKind::U16 => Self::U16(val.to_u16().unwrap()),
55                UIntKind::U32 => Self::U32(val.to_u32().unwrap()),
56                UIntKind::U64 => Self::U64(val.to_u64().unwrap()),
57            },
58            ElemType::Bool => panic!("Bool isn't a scalar"),
59        }
60    }
61}
62
63#[cube]
64impl InputScalar {
65    /// Reads the scalar with the given element type.
66    ///
67    /// Performs casting if necessary.
68    pub fn get<C: CubePrimitive>(&self) -> C {
69        intrinsic!(|scope| {
70            let dtype = C::as_type(scope).elem_type();
71
72            match self {
73                InputScalarExpand::U64(val) => {
74                    if dtype == ElemType::UInt(cubecl::ir::UIntKind::U64) {
75                        let expand: ExpandElement = val.clone().into();
76                        ExpandElementTyped::from(expand.clone())
77                    } else {
78                        C::__expand_cast_from(scope, val.clone())
79                    }
80                }
81                InputScalarExpand::U32(val) => {
82                    if dtype == ElemType::UInt(cubecl::ir::UIntKind::U32) {
83                        let expand: ExpandElement = val.clone().into();
84                        ExpandElementTyped::from(expand.clone())
85                    } else {
86                        C::__expand_cast_from(scope, val.clone())
87                    }
88                }
89                InputScalarExpand::U16(val) => {
90                    if dtype == ElemType::UInt(cubecl::ir::UIntKind::U16) {
91                        let expand: ExpandElement = val.clone().into();
92                        ExpandElementTyped::from(expand.clone())
93                    } else {
94                        C::__expand_cast_from(scope, val.clone())
95                    }
96                }
97                InputScalarExpand::F64(val) => {
98                    if dtype == ElemType::Float(cubecl::ir::FloatKind::F64) {
99                        let expand: ExpandElement = val.clone().into();
100                        ExpandElementTyped::from(expand.clone())
101                    } else {
102                        C::__expand_cast_from(scope, val.clone())
103                    }
104                }
105                InputScalarExpand::F32(val) => {
106                    if dtype == ElemType::Float(cubecl::ir::FloatKind::F32) {
107                        let expand: ExpandElement = val.clone().into();
108                        ExpandElementTyped::from(expand.clone())
109                    } else {
110                        C::__expand_cast_from(scope, val.clone())
111                    }
112                }
113                InputScalarExpand::F16(val) => {
114                    if dtype == ElemType::Float(cubecl::ir::FloatKind::F16) {
115                        let expand: ExpandElement = val.clone().into();
116                        ExpandElementTyped::from(expand.clone())
117                    } else {
118                        C::__expand_cast_from(scope, val.clone())
119                    }
120                }
121                InputScalarExpand::BF16(val) => {
122                    if dtype == ElemType::Float(cubecl::ir::FloatKind::BF16) {
123                        let expand: ExpandElement = val.clone().into();
124                        ExpandElementTyped::from(expand.clone())
125                    } else {
126                        C::__expand_cast_from(scope, val.clone())
127                    }
128                }
129                InputScalarExpand::U8(val) => {
130                    if dtype == ElemType::UInt(cubecl::ir::UIntKind::U8) {
131                        let expand: ExpandElement = val.clone().into();
132                        ExpandElementTyped::from(expand.clone())
133                    } else {
134                        C::__expand_cast_from(scope, val.clone())
135                    }
136                }
137
138                InputScalarExpand::I64(val) => {
139                    if dtype == ElemType::Int(cubecl::ir::IntKind::I64) {
140                        let expand: ExpandElement = val.clone().into();
141                        ExpandElementTyped::from(expand.clone())
142                    } else {
143                        C::__expand_cast_from(scope, val.clone())
144                    }
145                }
146                InputScalarExpand::I32(val) => {
147                    if dtype == ElemType::Int(cubecl::ir::IntKind::I32) {
148                        let expand: ExpandElement = val.clone().into();
149                        ExpandElementTyped::from(expand.clone())
150                    } else {
151                        C::__expand_cast_from(scope, val.clone())
152                    }
153                }
154                InputScalarExpand::I16(val) => {
155                    if dtype == ElemType::Int(cubecl::ir::IntKind::I16) {
156                        let expand: ExpandElement = val.clone().into();
157                        ExpandElementTyped::from(expand.clone())
158                    } else {
159                        C::__expand_cast_from(scope, val.clone())
160                    }
161                }
162                InputScalarExpand::I8(val) => {
163                    if dtype == ElemType::Int(cubecl::ir::IntKind::I8) {
164                        let expand: ExpandElement = val.clone().into();
165                        ExpandElementTyped::from(expand.clone())
166                    } else {
167                        C::__expand_cast_from(scope, val.clone())
168                    }
169                }
170            }
171        })
172    }
173}
174
175impl LaunchArg for InputScalar {
176    type RuntimeArg<'a, R: Runtime> = InputScalar;
177    type CompilationArg = InputScalarCompilationArg;
178
179    fn compilation_arg<R: Runtime>(arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
180        match arg {
181            InputScalar::F64(_) => {
182                InputScalarCompilationArg::new(ElemType::Float(FloatKind::F64).into())
183            }
184            InputScalar::F32(_) => {
185                InputScalarCompilationArg::new(ElemType::Float(FloatKind::F32).into())
186            }
187            InputScalar::F16(_) => {
188                InputScalarCompilationArg::new(ElemType::Float(FloatKind::F16).into())
189            }
190            InputScalar::BF16(_) => {
191                InputScalarCompilationArg::new(ElemType::Float(FloatKind::BF16).into())
192            }
193            InputScalar::I64(_) => {
194                InputScalarCompilationArg::new(ElemType::Int(IntKind::I64).into())
195            }
196            InputScalar::I32(_) => {
197                InputScalarCompilationArg::new(ElemType::Int(IntKind::I32).into())
198            }
199            InputScalar::I16(_) => {
200                InputScalarCompilationArg::new(ElemType::Int(IntKind::I16).into())
201            }
202            InputScalar::I8(_) => InputScalarCompilationArg::new(ElemType::Int(IntKind::I8).into()),
203            InputScalar::U64(_) => {
204                InputScalarCompilationArg::new(ElemType::UInt(UIntKind::U64).into())
205            }
206            InputScalar::U32(_) => {
207                InputScalarCompilationArg::new(ElemType::UInt(UIntKind::U32).into())
208            }
209            InputScalar::U16(_) => {
210                InputScalarCompilationArg::new(ElemType::UInt(UIntKind::U16).into())
211            }
212            InputScalar::U8(_) => {
213                InputScalarCompilationArg::new(ElemType::UInt(UIntKind::U8).into())
214            }
215        }
216    }
217
218    fn expand(
219        arg: &Self::CompilationArg,
220        builder: &mut KernelBuilder,
221    ) -> <Self as CubeType>::ExpandType {
222        let expand = builder.scalar(arg.ty);
223        match arg.ty.elem_type() {
224            ElemType::Float(float_kind) => match float_kind {
225                FloatKind::F16 => InputScalarExpand::F16(expand.into()),
226                FloatKind::BF16 => InputScalarExpand::BF16(expand.into()),
227                FloatKind::Flex32 => InputScalarExpand::F32(expand.into()),
228                FloatKind::F32 => InputScalarExpand::F32(expand.into()),
229                FloatKind::TF32 => InputScalarExpand::F32(expand.into()),
230                FloatKind::F64 => InputScalarExpand::F32(expand.into()),
231                FloatKind::E2M1
232                | FloatKind::E2M3
233                | FloatKind::E3M2
234                | FloatKind::E4M3
235                | FloatKind::E5M2
236                | FloatKind::UE8M0 => unimplemented!("FP8 can't be passed as scalar"),
237            },
238            ElemType::Int(int_kind) => match int_kind {
239                IntKind::I8 => InputScalarExpand::I8(expand.into()),
240                IntKind::I16 => InputScalarExpand::I16(expand.into()),
241                IntKind::I32 => InputScalarExpand::I32(expand.into()),
242                IntKind::I64 => InputScalarExpand::I64(expand.into()),
243            },
244            ElemType::UInt(uint_kind) => match uint_kind {
245                UIntKind::U8 => InputScalarExpand::U8(expand.into()),
246                UIntKind::U16 => InputScalarExpand::U16(expand.into()),
247                UIntKind::U32 => InputScalarExpand::U32(expand.into()),
248                UIntKind::U64 => InputScalarExpand::U64(expand.into()),
249            },
250            ElemType::Bool => panic!("Bool should be converted first."),
251        }
252    }
253}
254
255#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
256pub struct InputScalarCompilationArg {
257    ty: StorageType,
258}
259
260impl InputScalarCompilationArg {
261    pub fn new(ty: StorageType) -> Self {
262        Self { ty }
263    }
264}
265
266impl CompilationArg for InputScalarCompilationArg {}
267
268impl<R: Runtime> ArgSettings<R> for InputScalar {
269    fn register(&self, launcher: &mut KernelLauncher<R>) {
270        match self {
271            InputScalar::F64(val) => launcher.register_f64(*val),
272            InputScalar::F32(val) => launcher.register_f32(*val),
273            InputScalar::F16(val) => launcher.register_f16(*val),
274            InputScalar::BF16(val) => launcher.register_bf16(*val),
275            InputScalar::I64(val) => launcher.register_i64(*val),
276            InputScalar::I32(val) => launcher.register_i32(*val),
277            InputScalar::I16(val) => launcher.register_i16(*val),
278            InputScalar::I8(val) => launcher.register_i8(*val),
279            InputScalar::U64(val) => launcher.register_u64(*val),
280            InputScalar::U32(val) => launcher.register_u32(*val),
281            InputScalar::U16(val) => launcher.register_u16(*val),
282            InputScalar::U8(val) => launcher.register_u8(*val),
283        }
284    }
285}