burn_tensor/tensor/api/
cast.rs1use burn_backend::tensor::{Bool, Float, Int, TensorKind};
2use burn_backend::{Backend, DType, FloatDType, IntDType, TensorMetadata, TensorPrimitive};
3
4pub trait Cast<B: Backend, K: TensorKind<B>> {
9 type OutputKind: TensorKind<B>;
11
12 fn cast(primitive: K::Primitive, dtype: Self)
14 -> <Self::OutputKind as TensorKind<B>>::Primitive;
15}
16
17impl<B: Backend> Cast<B, Float> for FloatDType {
20 type OutputKind = Float;
21
22 fn cast(primitive: TensorPrimitive<B>, dtype: Self) -> TensorPrimitive<B> {
23 if let TensorPrimitive::Float(ref tensor) = primitive {
24 let current: FloatDType = tensor.dtype().into();
25 if current == dtype {
26 return primitive;
27 }
28 }
29 TensorPrimitive::Float(B::float_cast(primitive.tensor(), dtype))
30 }
31}
32
33impl<B: Backend> Cast<B, Float> for IntDType {
34 type OutputKind = Int;
35
36 fn cast(primitive: TensorPrimitive<B>, dtype: Self) -> B::IntTensorPrimitive {
37 B::float_into_int(primitive.tensor(), dtype)
38 }
39}
40
41impl<B: Backend> Cast<B, Float> for DType {
48 type OutputKind = Float;
49
50 fn cast(primitive: TensorPrimitive<B>, dtype: Self) -> TensorPrimitive<B> {
51 let float_dtype: FloatDType = dtype.into();
52 <FloatDType as Cast<B, Float>>::cast(primitive, float_dtype)
53 }
54}
55
56impl<B: Backend> Cast<B, Int> for IntDType {
59 type OutputKind = Int;
60
61 fn cast(primitive: B::IntTensorPrimitive, dtype: Self) -> B::IntTensorPrimitive {
62 let current: IntDType = primitive.dtype().into();
63 if current == dtype {
64 return primitive;
65 }
66 B::int_cast(primitive, dtype)
67 }
68}
69
70impl<B: Backend> Cast<B, Int> for FloatDType {
71 type OutputKind = Float;
72
73 fn cast(primitive: B::IntTensorPrimitive, dtype: Self) -> TensorPrimitive<B> {
74 TensorPrimitive::Float(B::int_into_float(primitive, dtype))
75 }
76}
77
78impl<B: Backend> Cast<B, Int> for DType {
85 type OutputKind = Int;
86
87 fn cast(primitive: B::IntTensorPrimitive, dtype: Self) -> B::IntTensorPrimitive {
88 let int_dtype: IntDType = dtype.into();
89 <IntDType as Cast<B, Int>>::cast(primitive, int_dtype)
90 }
91}
92
93impl<B: Backend> Cast<B, Bool> for IntDType {
96 type OutputKind = Int;
97
98 fn cast(primitive: B::BoolTensorPrimitive, dtype: Self) -> B::IntTensorPrimitive {
99 B::bool_into_int(primitive, dtype)
100 }
101}
102
103impl<B: Backend> Cast<B, Bool> for FloatDType {
104 type OutputKind = Float;
105
106 fn cast(primitive: B::BoolTensorPrimitive, dtype: Self) -> TensorPrimitive<B> {
107 TensorPrimitive::Float(B::bool_into_float(primitive, dtype))
108 }
109}