cubecl_std/
scalar.rs

1use cubecl::prelude::*;
2use cubecl_common::{e4m3, e5m2, ue8m0};
3use cubecl_core::{
4    self as cubecl, CubeScalar, intrinsic,
5    ir::{ElemType, ExpandElement, FloatKind, IntKind, Type, UIntKind},
6};
7use serde::{Deserialize, Serialize};
8
9#[derive(Clone, Copy)]
10/// A way to define an input scalar without a generic attached to it.
11///
12/// It uses comptime enum with zero-cost runtime abstraction for kernel generation.
13pub struct InputScalar {
14    data: [u8; 8],
15    dtype: StorageType,
16}
17
18#[derive(Clone)]
19pub struct InputScalarExpand {
20    pub expand: ExpandElement,
21}
22
23impl CubeType for InputScalar {
24    type ExpandType = InputScalarExpand;
25}
26
27impl IntoMut for InputScalarExpand {
28    fn into_mut(self, _scope: &mut Scope) -> Self {
29        self
30    }
31}
32
33impl CubeDebug for InputScalarExpand {}
34
35impl InputScalar {
36    /// Creates an [InputScalar] from the given element and dtype.
37    ///
38    /// # Panics
39    ///
40    /// If the given numeric element can't be transformed into the passed [ElemType].
41    pub fn new<E: num_traits::ToPrimitive>(val: E, dtype: impl Into<StorageType>) -> Self {
42        let dtype: StorageType = dtype.into();
43        let mut out = InputScalar {
44            data: Default::default(),
45            dtype,
46        };
47        fn write<E: CubeScalar>(val: impl num_traits::ToPrimitive, out: &mut [u8]) {
48            let val = [E::from(val).unwrap()];
49            let bytes = E::as_bytes(&val);
50            out[..bytes.len()].copy_from_slice(bytes);
51        }
52        match dtype {
53            StorageType::Scalar(elem) => match elem {
54                ElemType::Float(float_kind) => match float_kind {
55                    FloatKind::F16 => write::<half::f16>(val, &mut out.data),
56                    FloatKind::BF16 => write::<half::bf16>(val, &mut out.data),
57                    FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => {
58                        write::<f32>(val, &mut out.data)
59                    }
60                    FloatKind::F64 => write::<f64>(val, &mut out.data),
61                    FloatKind::E2M1 | FloatKind::E2M3 | FloatKind::E3M2 => {
62                        unimplemented!("fp6 CPU conversion not yet implemented")
63                    }
64                    FloatKind::E4M3 => write::<e4m3>(val, &mut out.data),
65                    FloatKind::E5M2 => write::<e5m2>(val, &mut out.data),
66                    FloatKind::UE8M0 => write::<ue8m0>(val, &mut out.data),
67                },
68                ElemType::Int(int_kind) => match int_kind {
69                    IntKind::I8 => write::<i8>(val, &mut out.data),
70                    IntKind::I16 => write::<i16>(val, &mut out.data),
71                    IntKind::I32 => write::<i32>(val, &mut out.data),
72                    IntKind::I64 => write::<i64>(val, &mut out.data),
73                },
74                ElemType::UInt(uint_kind) => match uint_kind {
75                    UIntKind::U8 => write::<u8>(val, &mut out.data),
76                    UIntKind::U16 => write::<u16>(val, &mut out.data),
77                    UIntKind::U32 => write::<u32>(val, &mut out.data),
78                    UIntKind::U64 => write::<u64>(val, &mut out.data),
79                },
80                ElemType::Bool => panic!("Bool isn't a scalar"),
81            },
82            other => unimplemented!("{other} not supported for scalars"),
83        };
84        out
85    }
86}
87
#[cube]
impl InputScalar {
    /// Reads the scalar as the element type `C`.
    ///
    /// When `C`'s storage type equals the scalar's launch-time storage type,
    /// the underlying variable is returned as-is. Otherwise a new local
    /// variable of type `C` is created and the scalar is cast into it.
    pub fn get<C: CubePrimitive>(&self) -> C {
        intrinsic!(|scope| {
            // Storage type requested by the caller.
            let dtype = C::as_type(scope);
            // Fast path: the stored type already matches, so no conversion is needed.
            if self.expand.storage_type() == dtype {
                return self.expand.into();
            }
            let new_var = scope.create_local(Type::new(dtype));
            // NOTE(review): both type parameters are `C`, even though the source
            // variable's actual type is the scalar's stored dtype; presumably the
            // emitted cast is driven by the variables' own types rather than
            // these generics — confirm.
            cast::expand::<C, C>(scope, self.expand.into(), new_var.clone().into());
            new_var.into()
        })
    }
}
105
106impl LaunchArg for InputScalar {
107    type RuntimeArg<'a, R: Runtime> = InputScalar;
108    type CompilationArg = InputScalarCompilationArg;
109
110    fn compilation_arg<R: Runtime>(arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
111        InputScalarCompilationArg::new(arg.dtype)
112    }
113
114    fn expand(
115        arg: &Self::CompilationArg,
116        builder: &mut KernelBuilder,
117    ) -> <Self as CubeType>::ExpandType {
118        let expand = builder.scalar(arg.ty);
119        InputScalarExpand { expand }
120    }
121}
122
123#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
124pub struct InputScalarCompilationArg {
125    ty: StorageType,
126}
127
128impl InputScalarCompilationArg {
129    pub fn new(ty: StorageType) -> Self {
130        Self { ty }
131    }
132}
133
134impl CompilationArg for InputScalarCompilationArg {}
135
136impl<R: Runtime> ArgSettings<R> for InputScalar {
137    fn register(&self, launcher: &mut KernelLauncher<R>) {
138        let dtype = self.dtype;
139        launcher.register_scalar_raw(&self.data[..dtype.size()], dtype);
140    }
141}