cubecl_core/frontend/element/
cast.rs1use cubecl_ir::{ExpandElement, Operator};
2
3use crate::frontend::{CubePrimitive, CubeType, cast};
4use crate::ir::{Instruction, Scope, Type, UnaryOperator, Variable};
5use crate::unexpanded;
6
7use super::ExpandElementTyped;
8
9pub trait Cast: CubePrimitive {
11 fn cast_from<From: CubePrimitive>(value: From) -> Self;
12
13 fn __expand_cast_from<From: CubePrimitive>(
14 scope: &mut Scope,
15 value: ExpandElementTyped<From>,
16 ) -> <Self as CubeType>::ExpandType {
17 if core::any::TypeId::of::<Self>() == core::any::TypeId::of::<From>() {
18 return value.expand.into();
19 }
20 let line_size_in = value.expand.ty.line_size();
21 let line_size_out = line_size_in * value.expand.ty.storage_type().packing_factor()
22 / Self::as_type(scope).packing_factor();
23 let new_var = scope
24 .create_local(Type::new(<Self as CubePrimitive>::as_type(scope)).line(line_size_out));
25 cast::expand(scope, value, new_var.clone().into());
26 new_var.into()
27 }
28}
29
30impl<P: CubePrimitive> Cast for P {
31 fn cast_from<From: CubePrimitive>(_value: From) -> Self {
32 unexpanded!()
33 }
34}
35
36pub trait Reinterpret: CubePrimitive {
38 #[allow(unused_variables)]
40 fn reinterpret<From: CubePrimitive>(value: From) -> Self {
41 unexpanded!()
42 }
43
44 fn __expand_reinterpret<From: CubePrimitive>(
45 scope: &mut Scope,
46 value: ExpandElementTyped<From>,
47 ) -> <Self as CubeType>::ExpandType {
48 let value: ExpandElement = value.into();
49 let var: Variable = *value;
50 let line_size = var.ty.size() / Self::as_type(scope).size();
51 let new_var = scope.create_local(
52 Type::new(<Self as CubePrimitive>::as_type(scope)).line(line_size as u32),
53 );
54 scope.register(Instruction::new(
55 Operator::Reinterpret(UnaryOperator { input: *value }),
56 *new_var.clone(),
57 ));
58 new_var.into()
59 }
60}
61
62impl<P: CubePrimitive> Reinterpret for P {}