// cubecl_core/frontend/scalar.rs

1use alloc::vec::Vec;
2use cubecl::prelude::*;
3use cubecl_common::{e4m3, e5m2, ue8m0};
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    self as cubecl, CubeScalar, intrinsic,
8    ir::{ElemType, ExpandElement, FloatKind, IntKind, Type, UIntKind},
9};
10
11#[derive(Clone, Copy)]
12/// A way to define an input scalar without a generic attached to it.
13///
14/// It uses comptime enum with zero-cost runtime abstraction for kernel generation.
15pub struct InputScalar {
16    data: [u8; 8],
17    dtype: StorageType,
18}
19
20#[derive(Clone)]
21pub struct InputScalarExpand {
22    pub expand: ExpandElement,
23}
24
25impl CubeType for InputScalar {
26    type ExpandType = InputScalarExpand;
27}
28
29impl IntoMut for InputScalarExpand {
30    fn into_mut(self, _scope: &mut Scope) -> Self {
31        self
32    }
33}
34
35impl CubeDebug for InputScalarExpand {}
36
37impl InputScalar {
38    /// Creates an [`InputScalar`] from the given element and dtype.
39    ///
40    /// # Panics
41    ///
42    /// If the given numeric element can't be transformed into the passed [`ElemType`].
43    pub fn new<E: num_traits::ToPrimitive>(val: E, dtype: impl Into<StorageType>) -> Self {
44        let dtype: StorageType = dtype.into();
45        let mut out = InputScalar {
46            data: Default::default(),
47            dtype,
48        };
49        fn write<E: CubeScalar>(val: impl num_traits::ToPrimitive, out: &mut [u8]) {
50            let val = [E::from(val).unwrap()];
51            let bytes = E::as_bytes(&val);
52            out[..bytes.len()].copy_from_slice(bytes);
53        }
54        match dtype {
55            StorageType::Scalar(elem) => match elem {
56                ElemType::Float(float_kind) => match float_kind {
57                    FloatKind::F16 => write::<half::f16>(val, &mut out.data),
58                    FloatKind::BF16 => write::<half::bf16>(val, &mut out.data),
59                    FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => {
60                        write::<f32>(val, &mut out.data)
61                    }
62                    FloatKind::F64 => write::<f64>(val, &mut out.data),
63                    FloatKind::E2M1 | FloatKind::E2M3 | FloatKind::E3M2 => {
64                        unimplemented!("fp6 CPU conversion not yet implemented")
65                    }
66                    FloatKind::E4M3 => write::<e4m3>(val, &mut out.data),
67                    FloatKind::E5M2 => write::<e5m2>(val, &mut out.data),
68                    FloatKind::UE8M0 => write::<ue8m0>(val, &mut out.data),
69                },
70                ElemType::Int(int_kind) => match int_kind {
71                    IntKind::I8 => write::<i8>(val, &mut out.data),
72                    IntKind::I16 => write::<i16>(val, &mut out.data),
73                    IntKind::I32 => write::<i32>(val, &mut out.data),
74                    IntKind::I64 => write::<i64>(val, &mut out.data),
75                },
76                ElemType::UInt(uint_kind) => match uint_kind {
77                    UIntKind::U8 => write::<u8>(val, &mut out.data),
78                    UIntKind::U16 => write::<u16>(val, &mut out.data),
79                    UIntKind::U32 => write::<u32>(val, &mut out.data),
80                    UIntKind::U64 => write::<u64>(val, &mut out.data),
81                },
82                ElemType::Bool => panic!("Bool isn't a scalar"),
83            },
84            other => unimplemented!("{other} not supported for scalars"),
85        };
86        out
87    }
88}
89
90#[cube]
91impl InputScalar {
92    /// Reads the scalar with the given element type.
93    ///
94    /// Performs casting if necessary.
95    pub fn get<C: CubePrimitive>(&self) -> C {
96        intrinsic!(|scope| {
97            let dtype = C::as_type(scope);
98            if self.expand.storage_type() == dtype {
99                return self.expand.into();
100            }
101            let new_var = scope.create_local(Type::new(dtype));
102            cast::expand::<C, C>(scope, self.expand.into(), new_var.clone().into());
103            new_var.into()
104        })
105    }
106}
107
108impl InputScalar {
109    pub fn as_bytes(&self) -> Vec<u8> {
110        self.data[..self.dtype.size()].to_vec()
111    }
112}
113
114impl LaunchArg for InputScalar {
115    type RuntimeArg<'a, R: Runtime> = InputScalar;
116    type CompilationArg = InputScalarCompilationArg;
117
118    fn compilation_arg<R: Runtime>(arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
119        InputScalarCompilationArg::new(arg.dtype)
120    }
121
122    fn expand(
123        arg: &Self::CompilationArg,
124        builder: &mut KernelBuilder,
125    ) -> <Self as CubeType>::ExpandType {
126        let expand = builder.scalar(arg.ty);
127        InputScalarExpand { expand }
128    }
129}
130
131#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
132pub struct InputScalarCompilationArg {
133    ty: StorageType,
134}
135
136impl InputScalarCompilationArg {
137    pub fn new(ty: StorageType) -> Self {
138        Self { ty }
139    }
140}
141
142impl CompilationArg for InputScalarCompilationArg {}
143
144impl<R: Runtime> ArgSettings<R> for InputScalar {
145    fn register(&self, launcher: &mut KernelLauncher<R>) {
146        let dtype = self.dtype;
147        launcher.register_scalar_raw(&self.data[..dtype.size()], dtype);
148    }
149}