Skip to main content

cubecl_core/frontend/element/
cast.rs

1use crate::unexpanded;
2use crate::{
3    expand_assert,
4    ir::{Instruction, Operator, Scope, UnaryOperator},
5};
6use crate::{
7    expand_error,
8    frontend::{CubePrimitive, CubeType, cast},
9};
10
11use super::NativeExpand;
12
13/// Enable elegant casting from any to any `CubeElem`
14pub trait Cast: CubePrimitive {
15    fn cast_from<From: CubePrimitive>(value: From) -> Self;
16
17    fn __expand_cast_from<From: CubePrimitive>(
18        scope: &mut Scope,
19        value: NativeExpand<From>,
20    ) -> <Self as CubeType>::ExpandType {
21        if Self::as_type(scope) == From::as_type(scope) {
22            return value.expand.into();
23        }
24        let vec_in = value.expand.vector_size();
25        let elems_in = vec_in * value.expand.ty.packing_factor();
26        let elems_out = Self::__expand_vector_size(scope) * Self::__expand_packing_factor(scope);
27        if vec_in > 1 && elems_in != elems_out {
28            expand_error!("Cast element count must match if input is not scalar");
29        }
30        let new_var = scope.create_local(<Self as CubePrimitive>::as_type(scope));
31        cast::expand::<From, Self>(scope, value, new_var.clone().into());
32        new_var.into()
33    }
34}
35
36impl<P: CubePrimitive> Cast for P {
37    fn cast_from<From: CubePrimitive>(_value: From) -> Self {
38        unexpanded!()
39    }
40}
41
42/// Enables reinterpetring the bits from any value to any other type of the same size.
43pub trait Reinterpret: CubePrimitive {
44    /// Reinterpret the bits of another primitive as this primitive without conversion.
45    #[allow(unused_variables)]
46    fn reinterpret<From: CubePrimitive>(value: From) -> Self {
47        unexpanded!()
48    }
49
50    /// Calculates the expected vectorization for the reinterpret target
51    fn reinterpret_vectorization<From: CubePrimitive>() -> usize {
52        unexpanded!()
53    }
54
55    fn __expand_reinterpret<From: CubePrimitive>(
56        scope: &mut Scope,
57        value: NativeExpand<From>,
58    ) -> <Self as CubeType>::ExpandType {
59        let size_in = value.expand.ty.size();
60        let size_out = Self::__expand_type_size(scope);
61        expand_assert!(size_in == size_out, "Reinterpret type sizes must match");
62        let new_var = scope.create_local(<Self as CubePrimitive>::as_type(scope));
63        scope.register(Instruction::new(
64            Operator::Reinterpret(UnaryOperator {
65                input: *value.expand,
66            }),
67            *new_var.clone(),
68        ));
69        new_var.into()
70    }
71
72    fn __expand_reinterpret_vectorization<From: CubePrimitive>(scope: &mut Scope) -> usize {
73        let type_size = From::__expand_type_size(scope);
74        type_size / Self::Scalar::__expand_type_size(scope)
75    }
76}
77
78impl<P: CubePrimitive> Reinterpret for P {}