cubecl_core/frontend/element/
cast.rs1use cubecl_ir::{ExpandElement, Operator};
2
3use crate::frontend::{CubePrimitive, CubeType, cast};
4use crate::ir::{Instruction, Scope, Type, UnaryOperator, Variable};
5use crate::unexpanded;
6
7use super::ExpandElementTyped;
8
9pub trait Cast: CubePrimitive {
11    fn cast_from<From: CubePrimitive>(value: From) -> Self;
12
13    fn __expand_cast_from<From: CubePrimitive>(
14        scope: &mut Scope,
15        value: ExpandElementTyped<From>,
16    ) -> <Self as CubeType>::ExpandType {
17        if core::any::TypeId::of::<Self>() == core::any::TypeId::of::<From>() {
18            return value.expand.into();
19        }
20        let line_size_in = value.expand.ty.line_size();
21        let line_size_out = line_size_in * value.expand.ty.storage_type().packing_factor()
22            / Self::as_type(scope).packing_factor();
23        let new_var = scope
24            .create_local(Type::new(<Self as CubePrimitive>::as_type(scope)).line(line_size_out));
25        cast::expand(scope, value, new_var.clone().into());
26        new_var.into()
27    }
28}
29
30impl<P: CubePrimitive> Cast for P {
31    fn cast_from<From: CubePrimitive>(_value: From) -> Self {
32        unexpanded!()
33    }
34}
35
36pub trait Reinterpret: CubePrimitive {
38    #[allow(unused_variables)]
40    fn reinterpret<From: CubePrimitive>(value: From) -> Self {
41        unexpanded!()
42    }
43
44    fn __expand_reinterpret<From: CubePrimitive>(
45        scope: &mut Scope,
46        value: ExpandElementTyped<From>,
47    ) -> <Self as CubeType>::ExpandType {
48        let value: ExpandElement = value.into();
49        let var: Variable = *value;
50        let line_size = var.ty.size() / Self::as_type(scope).size();
51        let new_var = scope.create_local(
52            Type::new(<Self as CubePrimitive>::as_type(scope)).line(line_size as u32),
53        );
54        scope.register(Instruction::new(
55            Operator::Reinterpret(UnaryOperator { input: *value }),
56            *new_var.clone(),
57        ));
58        new_var.into()
59    }
60}
61
62impl<P: CubePrimitive> Reinterpret for P {}