cubecl_core/frontend/element/cast.rs

1use std::num::NonZero;
2
3use cubecl_ir::{ExpandElement, Operator};
4
5use crate::frontend::{CubePrimitive, CubeType, cast};
6use crate::ir::{Instruction, Item, Scope, UnaryOperator, Variable};
7use crate::unexpanded;
8
9use super::ExpandElementTyped;
10
11/// Enable elegant casting from any to any CubeElem
12pub trait Cast: CubePrimitive {
13    fn cast_from<From: CubePrimitive>(value: From) -> Self;
14
15    fn __expand_cast_from<From: CubePrimitive>(
16        scope: &mut Scope,
17        value: ExpandElementTyped<From>,
18    ) -> <Self as CubeType>::ExpandType {
19        if core::any::TypeId::of::<Self>() == core::any::TypeId::of::<From>() {
20            return value.expand.into();
21        }
22        let new_var = scope.create_local(Item::vectorized(
23            <Self as CubePrimitive>::as_elem(scope),
24            value.expand.item.vectorization,
25        ));
26        cast::expand(scope, value, new_var.clone().into());
27        new_var.into()
28    }
29}
30
31impl<P: CubePrimitive> Cast for P {
32    fn cast_from<From: CubePrimitive>(_value: From) -> Self {
33        unexpanded!()
34    }
35}
36
37/// Enables reinterpetring the bits from any value to any other type of the same size.
38pub trait Reinterpret: CubePrimitive {
39    /// Reinterpret the bits of another primitive as this primitive without conversion.
40    #[allow(unused_variables)]
41    fn reinterpret<From: CubePrimitive>(value: From) -> Self {
42        unexpanded!()
43    }
44
45    fn __expand_reinterpret<From: CubePrimitive>(
46        scope: &mut Scope,
47        value: ExpandElementTyped<From>,
48    ) -> <Self as CubeType>::ExpandType {
49        let value: ExpandElement = value.into();
50        let var: Variable = *value;
51        let vectorization = var.elem().size()
52            * var
53                .item
54                .vectorization
55                .unwrap_or(NonZero::new(1).unwrap())
56                .get() as usize
57            / Self::as_elem(scope).size();
58        let new_var = scope.create_local(Item::vectorized(
59            <Self as CubePrimitive>::as_elem(scope),
60            NonZero::new(vectorization as u8),
61        ));
62        scope.register(Instruction::new(
63            Operator::Reinterpret(UnaryOperator { input: *value }),
64            *new_var.clone(),
65        ));
66        new_var.into()
67    }
68}
69
70impl<P: CubePrimitive> Reinterpret for P {}