cubecl_core/frontend/element/
cast.rs

1use crate::ir::{Instruction, Item, UnaryOperator, Variable};
2use crate::{frontend::ExpandElement, unexpanded};
3use crate::{
4    frontend::{cast, CubeContext, CubePrimitive, CubeType},
5    ir::Operator,
6};
7
8use super::ExpandElementTyped;
9
10/// Enable elegant casting from any to any CubeElem
11pub trait Cast: CubePrimitive {
12    fn cast_from<From: CubePrimitive>(value: From) -> Self;
13
14    fn __expand_cast_from<From: CubePrimitive>(
15        context: &mut CubeContext,
16        value: ExpandElementTyped<From>,
17    ) -> <Self as CubeType>::ExpandType {
18        if core::any::TypeId::of::<Self>() == core::any::TypeId::of::<From>() {
19            return value.expand.into();
20        }
21        let new_var = context.create_local(Item::vectorized(
22            <Self as CubePrimitive>::as_elem(context),
23            value.expand.item.vectorization,
24        ));
25        cast::expand(context, value, new_var.clone().into());
26        new_var.into()
27    }
28}
29
30impl<P: CubePrimitive> Cast for P {
31    fn cast_from<From: CubePrimitive>(_value: From) -> Self {
32        unexpanded!()
33    }
34}
35
36/// Enables reinterpet-casting/bitcasting from any floating point value to any integer value and vice
37/// versa
38pub trait BitCast: CubePrimitive {
39    /// Reinterpret the bits of another primitive as this primitive without conversion.
40    #[allow(unused_variables)]
41    fn bitcast_from<From: CubePrimitive>(value: From) -> Self {
42        unexpanded!()
43    }
44
45    fn __expand_bitcast_from<From: CubePrimitive>(
46        context: &mut CubeContext,
47        value: ExpandElementTyped<From>,
48    ) -> <Self as CubeType>::ExpandType {
49        let value: ExpandElement = value.into();
50        let var: Variable = *value;
51        let new_var = context.create_local(Item::vectorized(
52            <Self as CubePrimitive>::as_elem(context),
53            var.item.vectorization,
54        ));
55        context.register(Instruction::new(
56            Operator::Bitcast(UnaryOperator { input: *value }),
57            *new_var.clone(),
58        ));
59        new_var.into()
60    }
61}
62
63impl<P: CubePrimitive> BitCast for P {}