cubecl_core/frontend/element/
cast.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
use crate::ir::{Item, UnaryOperator, Variable};
use crate::{frontend::ExpandElement, unexpanded};
use crate::{
    frontend::{assign, CubeContext, CubePrimitive, CubeType},
    ir::Operator,
};

use super::ExpandElementTyped;

/// Casts a value of any [`CubePrimitive`] into `Self`.
///
/// Unlike [`BitCast`], the expanded form does not register an explicit bitcast operation.
/// Instead it assigns the input into a new binding whose element type is `Self`'s.
pub trait Cast: CubePrimitive {
    fn cast_from<From: CubePrimitive>(value: From) -> Self;

    /// Expansion of [`Cast::cast_from`].
    ///
    /// - If `From` and `Self` are the same type, the input element is returned unchanged and
    ///   no new binding is created.
    /// - Otherwise a local binding is created with `Self`'s element type and the input's
    ///   vectorization, and the input is assigned into it.
    fn __expand_cast_from<From: CubePrimitive>(
        context: &mut CubeContext,
        value: ExpandElementTyped<From>,
    ) -> <Self as CubeType>::ExpandType {
        let same_type = core::any::TypeId::of::<Self>() == core::any::TypeId::of::<From>();
        if same_type {
            value.expand.into()
        } else {
            let target_elem = <Self as CubePrimitive>::as_elem();
            let vectorization = value.expand.item().vectorization;
            let output =
                context.create_local_binding(Item::vectorized(target_elem, vectorization));
            assign::expand(context, value, output.clone().into());
            output.into()
        }
    }
}

/// Blanket implementation giving every [`CubePrimitive`] the [`Cast`] trait.
///
/// Calling `cast_from` directly just invokes `unexpanded!()`. The cast itself is produced by
/// the trait's default `__expand_cast_from`.
impl<P> Cast for P
where
    P: CubePrimitive,
{
    fn cast_from<From: CubePrimitive>(_: From) -> Self {
        unexpanded!()
    }
}

/// Enables reinterpret-casting (bitcasting) between primitives, for example from a floating
/// point value to an integer value and vice versa.
///
/// The bits of the input are reused as-is; no value conversion is performed.
pub trait BitCast: CubePrimitive {
    /// Reinterpret the bits of another primitive as this primitive without conversion.
    // `value` is unused here because this body only calls `unexpanded!()`; the real work is
    // done by `__expand_bitcast_from`.
    #[allow(unused_variables)]
    fn bitcast_from<From: CubePrimitive>(value: From) -> Self {
        unexpanded!()
    }

    /// Expansion of [`BitCast::bitcast_from`].
    ///
    /// Creates a local binding with `Self`'s element type and the same vectorization as the
    /// input. It then registers an [`Operator::Bitcast`] from the input into that binding and
    /// returns the binding.
    fn __expand_bitcast_from<From: CubePrimitive>(
        context: &mut CubeContext,
        value: ExpandElementTyped<From>,
    ) -> <Self as CubeType>::ExpandType {
        let value: ExpandElement = value.into();
        // `Variable` is `Copy`, so a single deref gives us the input for both uses below.
        let input: Variable = *value;
        let new_var = context.create_local_binding(Item::vectorized(
            <Self as CubePrimitive>::as_elem(),
            input.item().vectorization,
        ));
        context.register(Operator::Bitcast(UnaryOperator {
            input,
            // Deref copies the `Variable`; no need to clone the whole `ExpandElement` first.
            out: *new_var,
        }));
        new_var.into()
    }
}

/// Blanket implementation: every [`CubePrimitive`] gets [`BitCast`] through its default methods.
impl<P> BitCast for P where P: CubePrimitive {}