Skip to main content

burn_tensor/tensor/api/
cast.rs

1use burn_backend::tensor::{Bool, Float, Int, TensorKind};
2use burn_backend::{Backend, DType, FloatDType, IntDType, TensorMetadata, TensorPrimitive};
3
4/// Trait for types that represent a valid cast target from a tensor of kind `K`.
5///
6/// The generic parameter `K` is the *input* tensor kind ([`Float`], [`Int`], or [`Bool`]).
7/// Implementors declare the output kind and provide the actual cast logic.
8pub trait Cast<B: Backend, K: TensorKind<B>> {
9    /// The output tensor kind after casting.
10    type OutputKind: TensorKind<B>;
11
12    /// Cast a tensor primitive to the target dtype.
13    fn cast(primitive: K::Primitive, dtype: Self)
14    -> <Self::OutputKind as TensorKind<B>>::Primitive;
15}
16
// --- Float input impls ---
18
19impl<B: Backend> Cast<B, Float> for FloatDType {
20    type OutputKind = Float;
21
22    fn cast(primitive: TensorPrimitive<B>, dtype: Self) -> TensorPrimitive<B> {
23        if let TensorPrimitive::Float(ref tensor) = primitive {
24            let current: FloatDType = tensor.dtype().into();
25            if current == dtype {
26                return primitive;
27            }
28        }
29        TensorPrimitive::Float(B::float_cast(primitive.tensor(), dtype))
30    }
31}
32
33impl<B: Backend> Cast<B, Float> for IntDType {
34    type OutputKind = Int;
35
36    fn cast(primitive: TensorPrimitive<B>, dtype: Self) -> B::IntTensorPrimitive {
37        B::float_into_int(primitive.tensor(), dtype)
38    }
39}
40
41/// Backward-compatible impl: only float `DType` variants are accepted.
42///
43/// # Panics
44///
45/// Panics if `dtype` is not a float variant (e.g., `DType::I32`).
46/// Use [`IntDType`] directly for cross-kind casting to int.
47impl<B: Backend> Cast<B, Float> for DType {
48    type OutputKind = Float;
49
50    fn cast(primitive: TensorPrimitive<B>, dtype: Self) -> TensorPrimitive<B> {
51        let float_dtype: FloatDType = dtype.into();
52        <FloatDType as Cast<B, Float>>::cast(primitive, float_dtype)
53    }
54}
55
// --- Int input impls ---
57
58impl<B: Backend> Cast<B, Int> for IntDType {
59    type OutputKind = Int;
60
61    fn cast(primitive: B::IntTensorPrimitive, dtype: Self) -> B::IntTensorPrimitive {
62        let current: IntDType = primitive.dtype().into();
63        if current == dtype {
64            return primitive;
65        }
66        B::int_cast(primitive, dtype)
67    }
68}
69
70impl<B: Backend> Cast<B, Int> for FloatDType {
71    type OutputKind = Float;
72
73    fn cast(primitive: B::IntTensorPrimitive, dtype: Self) -> TensorPrimitive<B> {
74        TensorPrimitive::Float(B::int_into_float(primitive, dtype))
75    }
76}
77
78/// Backward-compatible impl: only int `DType` variants are accepted.
79///
80/// # Panics
81///
82/// Panics if `dtype` is not an int variant (e.g., `DType::F32`).
83/// Use [`FloatDType`] directly for cross-kind casting to float.
84impl<B: Backend> Cast<B, Int> for DType {
85    type OutputKind = Int;
86
87    fn cast(primitive: B::IntTensorPrimitive, dtype: Self) -> B::IntTensorPrimitive {
88        let int_dtype: IntDType = dtype.into();
89        <IntDType as Cast<B, Int>>::cast(primitive, int_dtype)
90    }
91}
92
// --- Bool input impls ---
94
95impl<B: Backend> Cast<B, Bool> for IntDType {
96    type OutputKind = Int;
97
98    fn cast(primitive: B::BoolTensorPrimitive, dtype: Self) -> B::IntTensorPrimitive {
99        B::bool_into_int(primitive, dtype)
100    }
101}
102
103impl<B: Backend> Cast<B, Bool> for FloatDType {
104    type OutputKind = Float;
105
106    fn cast(primitive: B::BoolTensorPrimitive, dtype: Self) -> TensorPrimitive<B> {
107        TensorPrimitive::Float(B::bool_into_float(primitive, dtype))
108    }
109}