cubecl_core/frontend/element/
cast.rs1use crate::unexpanded;
2use crate::{
3 expand_assert,
4 ir::{Instruction, Operator, Scope, UnaryOperator},
5};
6use crate::{
7 expand_error,
8 frontend::{CubePrimitive, CubeType, cast},
9};
10
11use super::NativeExpand;
12
13pub trait Cast: CubePrimitive {
15 fn cast_from<From: CubePrimitive>(value: From) -> Self;
16
17 fn __expand_cast_from<From: CubePrimitive>(
18 scope: &mut Scope,
19 value: NativeExpand<From>,
20 ) -> <Self as CubeType>::ExpandType {
21 if Self::as_type(scope) == From::as_type(scope) {
22 return value.expand.into();
23 }
24 let vec_in = value.expand.vector_size();
25 let elems_in = vec_in * value.expand.ty.packing_factor();
26 let elems_out = Self::__expand_vector_size(scope) * Self::__expand_packing_factor(scope);
27 if vec_in > 1 && elems_in != elems_out {
28 expand_error!("Cast element count must match if input is not scalar");
29 }
30 let new_var = scope.create_local(<Self as CubePrimitive>::as_type(scope));
31 cast::expand::<From, Self>(scope, value, new_var.clone().into());
32 new_var.into()
33 }
34}
35
36impl<P: CubePrimitive> Cast for P {
37 fn cast_from<From: CubePrimitive>(_value: From) -> Self {
38 unexpanded!()
39 }
40}
41
42pub trait Reinterpret: CubePrimitive {
44 #[allow(unused_variables)]
46 fn reinterpret<From: CubePrimitive>(value: From) -> Self {
47 unexpanded!()
48 }
49
50 fn reinterpret_vectorization<From: CubePrimitive>() -> usize {
52 unexpanded!()
53 }
54
55 fn __expand_reinterpret<From: CubePrimitive>(
56 scope: &mut Scope,
57 value: NativeExpand<From>,
58 ) -> <Self as CubeType>::ExpandType {
59 let size_in = value.expand.ty.size();
60 let size_out = Self::__expand_type_size(scope);
61 expand_assert!(size_in == size_out, "Reinterpret type sizes must match");
62 let new_var = scope.create_local(<Self as CubePrimitive>::as_type(scope));
63 scope.register(Instruction::new(
64 Operator::Reinterpret(UnaryOperator {
65 input: *value.expand,
66 }),
67 *new_var.clone(),
68 ));
69 new_var.into()
70 }
71
72 fn __expand_reinterpret_vectorization<From: CubePrimitive>(scope: &mut Scope) -> usize {
73 let type_size = From::__expand_type_size(scope);
74 type_size / Self::Scalar::__expand_type_size(scope)
75 }
76}
77
78impl<P: CubePrimitive> Reinterpret for P {}