1use cubecl::prelude::*;
2use cubecl_common::{e4m3, e5m2, ue8m0};
3use cubecl_core::{
4 self as cubecl, CubeScalar, intrinsic,
5 ir::{ElemType, ExpandElement, FloatKind, IntKind, Type, UIntKind},
6};
7use serde::{Deserialize, Serialize};
8
/// A scalar kernel input whose element type is chosen at runtime.
///
/// The value is kept as raw bytes next to the [`StorageType`] it was encoded
/// as; see `InputScalar::new` for how the bytes are produced.
#[derive(Clone, Copy)]
pub struct InputScalar {
    // Raw bytes of the value, written from the start of the buffer. Only the
    // first `dtype.size()` bytes are handed to the launcher on registration.
    data: [u8; 8],
    // Storage type the bytes in `data` were encoded as.
    dtype: StorageType,
}
17
/// Expand-time counterpart of [`InputScalar`] (its `CubeType::ExpandType`).
#[derive(Clone)]
pub struct InputScalarExpand {
    /// The element backing this scalar; `LaunchArg::expand` fills it with the
    /// result of `KernelBuilder::scalar`.
    pub expand: ExpandElement,
}
22
// Ties the runtime-side `InputScalar` to its expand-time representation.
impl CubeType for InputScalar {
    type ExpandType = InputScalarExpand;
}
26
impl IntoMut for InputScalarExpand {
    /// Returns `self` unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
32
// Empty impl: no `CubeDebug` items are overridden for this type.
impl CubeDebug for InputScalarExpand {}
34
35impl InputScalar {
36 pub fn new<E: num_traits::ToPrimitive>(val: E, dtype: impl Into<StorageType>) -> Self {
42 let dtype: StorageType = dtype.into();
43 let mut out = InputScalar {
44 data: Default::default(),
45 dtype,
46 };
47 fn write<E: CubeScalar>(val: impl num_traits::ToPrimitive, out: &mut [u8]) {
48 let val = [E::from(val).unwrap()];
49 let bytes = E::as_bytes(&val);
50 out[..bytes.len()].copy_from_slice(bytes);
51 }
52 match dtype {
53 StorageType::Scalar(elem) => match elem {
54 ElemType::Float(float_kind) => match float_kind {
55 FloatKind::F16 => write::<half::f16>(val, &mut out.data),
56 FloatKind::BF16 => write::<half::bf16>(val, &mut out.data),
57 FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => {
58 write::<f32>(val, &mut out.data)
59 }
60 FloatKind::F64 => write::<f64>(val, &mut out.data),
61 FloatKind::E2M1 | FloatKind::E2M3 | FloatKind::E3M2 => {
62 unimplemented!("fp6 CPU conversion not yet implemented")
63 }
64 FloatKind::E4M3 => write::<e4m3>(val, &mut out.data),
65 FloatKind::E5M2 => write::<e5m2>(val, &mut out.data),
66 FloatKind::UE8M0 => write::<ue8m0>(val, &mut out.data),
67 },
68 ElemType::Int(int_kind) => match int_kind {
69 IntKind::I8 => write::<i8>(val, &mut out.data),
70 IntKind::I16 => write::<i16>(val, &mut out.data),
71 IntKind::I32 => write::<i32>(val, &mut out.data),
72 IntKind::I64 => write::<i64>(val, &mut out.data),
73 },
74 ElemType::UInt(uint_kind) => match uint_kind {
75 UIntKind::U8 => write::<u8>(val, &mut out.data),
76 UIntKind::U16 => write::<u16>(val, &mut out.data),
77 UIntKind::U32 => write::<u32>(val, &mut out.data),
78 UIntKind::U64 => write::<u64>(val, &mut out.data),
79 },
80 ElemType::Bool => panic!("Bool isn't a scalar"),
81 },
82 other => unimplemented!("{other} not supported for scalars"),
83 };
84 out
85 }
86}
87
#[cube]
impl InputScalar {
    /// Reads this scalar as a value of type `C`.
    ///
    /// When the scalar's storage type already equals the type of `C`, the
    /// underlying element is returned as-is. Otherwise a new local of `C`'s
    /// type is created in the scope, the scalar is cast into it, and that
    /// local is returned.
    pub fn get<C: CubePrimitive>(&self) -> C {
        intrinsic!(|scope| {
            let dtype = C::as_type(scope);
            // Same storage type: no conversion needed.
            if self.expand.storage_type() == dtype {
                return self.expand.into();
            }
            // Different storage type: cast into a fresh local of the requested type.
            let new_var = scope.create_local(Type::new(dtype));
            cast::expand::<C, C>(scope, self.expand.into(), new_var.clone().into());
            new_var.into()
        })
    }
}
105
106impl LaunchArg for InputScalar {
107 type RuntimeArg<'a, R: Runtime> = InputScalar;
108 type CompilationArg = InputScalarCompilationArg;
109
110 fn compilation_arg<R: Runtime>(arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
111 InputScalarCompilationArg::new(arg.dtype)
112 }
113
114 fn expand(
115 arg: &Self::CompilationArg,
116 builder: &mut KernelBuilder,
117 ) -> <Self as CubeType>::ExpandType {
118 let expand = builder.scalar(arg.ty);
119 InputScalarExpand { expand }
120 }
121}
122
/// Compilation argument for [`InputScalar`].
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
pub struct InputScalarCompilationArg {
    // Storage type passed to `KernelBuilder::scalar` when the argument is expanded.
    ty: StorageType,
}
127
128impl InputScalarCompilationArg {
129 pub fn new(ty: StorageType) -> Self {
130 Self { ty }
131 }
132}
133
// Empty impl: no `CompilationArg` items are overridden for this type.
impl CompilationArg for InputScalarCompilationArg {}
135
136impl<R: Runtime> ArgSettings<R> for InputScalar {
137 fn register(&self, launcher: &mut KernelLauncher<R>) {
138 let dtype = self.dtype;
139 launcher.register_scalar_raw(&self.data[..dtype.size()], dtype);
140 }
141}