use burn_backend::tensor::{Bool, Float, Int, TensorKind};
use burn_backend::{Backend, DType, FloatDType, IntDType, TensorMetadata, TensorPrimitive};
/// Conversion of a tensor primitive of kind `K` to the data type described by `Self`.
///
/// The implementing type is the *target* dtype descriptor. The tensor kind of the result is
/// chosen per implementation through [`Cast::OutputKind`], so the type of the dtype argument
/// decides at compile time which kind of primitive is returned.
pub trait Cast<B, K>
where
    B: Backend,
    K: TensorKind<B>,
{
    /// Tensor kind of the primitive produced by [`Cast::cast`].
    type OutputKind: TensorKind<B>;

    /// Converts `primitive` to `dtype` and returns the primitive of [`Cast::OutputKind`].
    fn cast(primitive: K::Primitive, dtype: Self)
    -> <Self::OutputKind as TensorKind<B>>::Primitive;
}
/// Float tensor cast to a float data type; the result stays a float primitive.
impl<B: Backend> Cast<B, Float> for FloatDType {
    type OutputKind = Float;

    /// Casts `primitive` to the float data type `dtype`.
    ///
    /// A `TensorPrimitive::Float` whose data type already equals `dtype` is handed back
    /// untouched. Every other input is passed through `tensor()` and `B::float_cast`, and the
    /// result is wrapped in `TensorPrimitive::Float`.
    fn cast(primitive: TensorPrimitive<B>, dtype: Self) -> TensorPrimitive<B> {
        // Only the `Float` variant is inspected for the no-op shortcut; any other variant
        // always takes the conversion path.
        let already_target = match &primitive {
            TensorPrimitive::Float(tensor) => {
                let source: FloatDType = tensor.dtype().into();
                source == dtype
            }
            _ => false,
        };

        if already_target {
            primitive
        } else {
            TensorPrimitive::Float(B::float_cast(primitive.tensor(), dtype))
        }
    }
}
/// Float tensor cast to an integer data type; the result is an int primitive.
impl<B: Backend> Cast<B, Float> for IntDType {
    type OutputKind = Int;

    /// Converts the float `primitive` into an int tensor of data type `dtype`.
    ///
    /// The float tensor is obtained from `primitive` with `tensor()` and then handed to
    /// `B::float_into_int`.
    fn cast(primitive: TensorPrimitive<B>, dtype: Self) -> B::IntTensorPrimitive {
        let float_tensor = primitive.tensor();
        B::float_into_int(float_tensor, dtype)
    }
}
/// Float tensor cast driven by a generic `DType`; handled as a float-to-float cast.
impl<B: Backend> Cast<B, Float> for DType {
    type OutputKind = Float;

    /// Converts `dtype` into a [`FloatDType`] and forwards to the `FloatDType` implementation
    /// of [`Cast`] for float tensors.
    ///
    /// NOTE(review): what happens when `dtype` is not a float data type is decided by the
    /// `DType` to `FloatDType` conversion, which is not visible here. Confirm before relying
    /// on it.
    fn cast(primitive: TensorPrimitive<B>, dtype: Self) -> TensorPrimitive<B> {
        <FloatDType as Cast<B, Float>>::cast(primitive, dtype.into())
    }
}
/// Int tensor cast to an integer data type; the result stays an int primitive.
impl<B: Backend> Cast<B, Int> for IntDType {
    type OutputKind = Int;

    /// Casts `primitive` to the integer data type `dtype`.
    ///
    /// When the tensor's data type already equals `dtype` the input is returned as is;
    /// otherwise the conversion is delegated to `B::int_cast`.
    fn cast(primitive: B::IntTensorPrimitive, dtype: Self) -> B::IntTensorPrimitive {
        let source: IntDType = primitive.dtype().into();

        if source == dtype {
            primitive
        } else {
            B::int_cast(primitive, dtype)
        }
    }
}
/// Int tensor cast to a float data type; the result is a float primitive.
impl<B: Backend> Cast<B, Int> for FloatDType {
    type OutputKind = Float;

    /// Converts the int `primitive` into a float tensor of data type `dtype` using
    /// `B::int_into_float`, and wraps the result in `TensorPrimitive::Float`.
    fn cast(primitive: B::IntTensorPrimitive, dtype: Self) -> TensorPrimitive<B> {
        let float_tensor = B::int_into_float(primitive, dtype);
        TensorPrimitive::Float(float_tensor)
    }
}
/// Int tensor cast driven by a generic `DType`; handled as an int-to-int cast.
impl<B: Backend> Cast<B, Int> for DType {
    type OutputKind = Int;

    /// Converts `dtype` into an [`IntDType`] and forwards to the `IntDType` implementation
    /// of [`Cast`] for int tensors.
    ///
    /// NOTE(review): what happens when `dtype` is not an integer data type is decided by the
    /// `DType` to `IntDType` conversion, which is not visible here. Confirm before relying
    /// on it.
    fn cast(primitive: B::IntTensorPrimitive, dtype: Self) -> B::IntTensorPrimitive {
        <IntDType as Cast<B, Int>>::cast(primitive, dtype.into())
    }
}
/// Bool tensor cast to an integer data type; the result is an int primitive.
impl<B: Backend> Cast<B, Bool> for IntDType {
type OutputKind = Int;
/// Converts the bool `primitive` into an int tensor of data type `dtype` by delegating to
/// `B::bool_into_int`.
fn cast(primitive: B::BoolTensorPrimitive, dtype: Self) -> B::IntTensorPrimitive {
B::bool_into_int(primitive, dtype)
}
}
/// Bool tensor cast to a float data type; the result is a float primitive.
impl<B: Backend> Cast<B, Bool> for FloatDType {
    type OutputKind = Float;

    /// Converts the bool `primitive` into a float tensor of data type `dtype` using
    /// `B::bool_into_float`, and wraps the result in `TensorPrimitive::Float`.
    fn cast(primitive: B::BoolTensorPrimitive, dtype: Self) -> TensorPrimitive<B> {
        let float_tensor = B::bool_into_float(primitive, dtype);
        TensorPrimitive::Float(float_tensor)
    }
}