cubecl_core/frontend/element/
cast.rs1use std::num::NonZero;
2
3use cubecl_ir::{ExpandElement, Operator};
4
5use crate::frontend::{CubePrimitive, CubeType, cast};
6use crate::ir::{Instruction, Item, Scope, UnaryOperator, Variable};
7use crate::unexpanded;
8
9use super::ExpandElementTyped;
10
11pub trait Cast: CubePrimitive {
13 fn cast_from<From: CubePrimitive>(value: From) -> Self;
14
15 fn __expand_cast_from<From: CubePrimitive>(
16 scope: &mut Scope,
17 value: ExpandElementTyped<From>,
18 ) -> <Self as CubeType>::ExpandType {
19 if core::any::TypeId::of::<Self>() == core::any::TypeId::of::<From>() {
20 return value.expand.into();
21 }
22 let new_var = scope.create_local(Item::vectorized(
23 <Self as CubePrimitive>::as_elem(scope),
24 value.expand.item.vectorization,
25 ));
26 cast::expand(scope, value, new_var.clone().into());
27 new_var.into()
28 }
29}
30
31impl<P: CubePrimitive> Cast for P {
32 fn cast_from<From: CubePrimitive>(_value: From) -> Self {
33 unexpanded!()
34 }
35}
36
37pub trait Reinterpret: CubePrimitive {
39 #[allow(unused_variables)]
41 fn reinterpret<From: CubePrimitive>(value: From) -> Self {
42 unexpanded!()
43 }
44
45 fn __expand_reinterpret<From: CubePrimitive>(
46 scope: &mut Scope,
47 value: ExpandElementTyped<From>,
48 ) -> <Self as CubeType>::ExpandType {
49 let value: ExpandElement = value.into();
50 let var: Variable = *value;
51 let vectorization = var.elem().size()
52 * var
53 .item
54 .vectorization
55 .unwrap_or(NonZero::new(1).unwrap())
56 .get() as usize
57 / Self::as_elem(scope).size();
58 let new_var = scope.create_local(Item::vectorized(
59 <Self as CubePrimitive>::as_elem(scope),
60 NonZero::new(vectorization as u8),
61 ));
62 scope.register(Instruction::new(
63 Operator::Reinterpret(UnaryOperator { input: *value }),
64 *new_var.clone(),
65 ));
66 new_var.into()
67 }
68}
69
70impl<P: CubePrimitive> Reinterpret for P {}