cubecl_core/frontend/
scalar.rs

1use cubecl::prelude::*;
2use cubecl_common::{e4m3, e5m2, ue8m0};
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    self as cubecl, CubeScalar, intrinsic,
7    ir::{ElemType, ExpandElement, FloatKind, IntKind, Type, UIntKind},
8};
9
10#[derive(Clone, Copy)]
11/// A way to define an input scalar without a generic attached to it.
12///
13/// It uses comptime enum with zero-cost runtime abstraction for kernel generation.
14pub struct InputScalar {
15    data: [u8; 8],
16    dtype: StorageType,
17}
18
19#[derive(Clone)]
20pub struct InputScalarExpand {
21    pub expand: ExpandElement,
22}
23
24impl CubeType for InputScalar {
25    type ExpandType = InputScalarExpand;
26}
27
28impl IntoMut for InputScalarExpand {
29    fn into_mut(self, _scope: &mut Scope) -> Self {
30        self
31    }
32}
33
34impl CubeDebug for InputScalarExpand {}
35
36impl InputScalar {
37    /// Creates an [InputScalar] from the given element and dtype.
38    ///
39    /// # Panics
40    ///
41    /// If the given numeric element can't be transformed into the passed [ElemType].
42    pub fn new<E: num_traits::ToPrimitive>(val: E, dtype: impl Into<StorageType>) -> Self {
43        let dtype: StorageType = dtype.into();
44        let mut out = InputScalar {
45            data: Default::default(),
46            dtype,
47        };
48        fn write<E: CubeScalar>(val: impl num_traits::ToPrimitive, out: &mut [u8]) {
49            let val = [E::from(val).unwrap()];
50            let bytes = E::as_bytes(&val);
51            out[..bytes.len()].copy_from_slice(bytes);
52        }
53        match dtype {
54            StorageType::Scalar(elem) => match elem {
55                ElemType::Float(float_kind) => match float_kind {
56                    FloatKind::F16 => write::<half::f16>(val, &mut out.data),
57                    FloatKind::BF16 => write::<half::bf16>(val, &mut out.data),
58                    FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => {
59                        write::<f32>(val, &mut out.data)
60                    }
61                    FloatKind::F64 => write::<f64>(val, &mut out.data),
62                    FloatKind::E2M1 | FloatKind::E2M3 | FloatKind::E3M2 => {
63                        unimplemented!("fp6 CPU conversion not yet implemented")
64                    }
65                    FloatKind::E4M3 => write::<e4m3>(val, &mut out.data),
66                    FloatKind::E5M2 => write::<e5m2>(val, &mut out.data),
67                    FloatKind::UE8M0 => write::<ue8m0>(val, &mut out.data),
68                },
69                ElemType::Int(int_kind) => match int_kind {
70                    IntKind::I8 => write::<i8>(val, &mut out.data),
71                    IntKind::I16 => write::<i16>(val, &mut out.data),
72                    IntKind::I32 => write::<i32>(val, &mut out.data),
73                    IntKind::I64 => write::<i64>(val, &mut out.data),
74                },
75                ElemType::UInt(uint_kind) => match uint_kind {
76                    UIntKind::U8 => write::<u8>(val, &mut out.data),
77                    UIntKind::U16 => write::<u16>(val, &mut out.data),
78                    UIntKind::U32 => write::<u32>(val, &mut out.data),
79                    UIntKind::U64 => write::<u64>(val, &mut out.data),
80                },
81                ElemType::Bool => panic!("Bool isn't a scalar"),
82            },
83            other => unimplemented!("{other} not supported for scalars"),
84        };
85        out
86    }
87}
88
#[cube]
impl InputScalar {
    /// Reads the scalar as a value of type `C`.
    ///
    /// If the scalar's storage type already equals `C`'s type, the underlying expand
    /// element is returned directly. Otherwise a new local variable of `C`'s type is
    /// created and the scalar is cast into it.
    pub fn get<C: CubePrimitive>(&self) -> C {
        intrinsic!(|scope| {
            let dtype = C::as_type(scope);
            // Stored type already matches the requested one: no cast needed.
            if self.expand.storage_type() == dtype {
                return self.expand.into();
            }
            // Allocate a local of the requested type and cast the scalar into it.
            // NOTE(review): both generic parameters of `cast::expand` are `C`, although the
            // source operand has the scalar's own storage type; presumably the generics only
            // affect typing of the expand wrappers here — confirm.
            let new_var = scope.create_local(Type::new(dtype));
            cast::expand::<C, C>(scope, self.expand.into(), new_var.clone().into());
            new_var.into()
        })
    }
}
106
107impl InputScalar {
108    pub fn as_bytes(&self) -> Vec<u8> {
109        self.data[..self.dtype.size()].to_vec()
110    }
111}
112
113impl LaunchArg for InputScalar {
114    type RuntimeArg<'a, R: Runtime> = InputScalar;
115    type CompilationArg = InputScalarCompilationArg;
116
117    fn compilation_arg<R: Runtime>(arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
118        InputScalarCompilationArg::new(arg.dtype)
119    }
120
121    fn expand(
122        arg: &Self::CompilationArg,
123        builder: &mut KernelBuilder,
124    ) -> <Self as CubeType>::ExpandType {
125        let expand = builder.scalar(arg.ty);
126        InputScalarExpand { expand }
127    }
128}
129
130#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
131pub struct InputScalarCompilationArg {
132    ty: StorageType,
133}
134
135impl InputScalarCompilationArg {
136    pub fn new(ty: StorageType) -> Self {
137        Self { ty }
138    }
139}
140
141impl CompilationArg for InputScalarCompilationArg {}
142
143impl<R: Runtime> ArgSettings<R> for InputScalar {
144    fn register(&self, launcher: &mut KernelLauncher<R>) {
145        let dtype = self.dtype;
146        launcher.register_scalar_raw(&self.data[..dtype.size()], dtype);
147    }
148}