use crate::scalar::{bf16, f8e4m3, i4};
use super::scalar::Scalar;
/// Marker trait: `Self` supports both conversions into the scalar type `D`.
///
/// It bundles the standard [`Into<D>`] with this module's [`Cast<D>`] and adds
/// no methods of its own. An impl only records that the pairing `Self -> D` is
/// supported.
///
/// NOTE(review): the name suggests this bound is used when fetching operands
/// before computing on them; confirm against callers.
pub trait FetchCast<D>: Into<D> + Cast<D>
where
    D: Scalar,
{
}
// Every scalar type can be fetched as itself.
impl<D: Scalar> FetchCast<D> for D {}

// Narrow integer types are fetched as `i32`.
impl FetchCast<i32> for i8 {}
impl FetchCast<i32> for i4 {}

// Reduced-width float types are fetched as `f32`.
impl FetchCast<f32> for bf16 {}
impl FetchCast<f32> for f8e4m3 {}
/// Value conversion from `Self` into the scalar type `D`.
///
/// Impls exist in both directions, including narrowing ones such as
/// `i32 -> i8` and `f32 -> bf16`. A conversion may therefore lose information;
/// see each impl for its exact semantics.
pub trait Cast<D>
where
    D: Scalar,
{
    /// Consumes `self` and returns it converted to `D`.
    fn cast(self) -> D;
}
/// Identity conversion: every scalar casts to its own type unchanged.
impl<D> Cast<D> for D
where
    D: Scalar,
{
    fn cast(self) -> Self {
        self
    }
}
/// Lossless widening of `i8` to `i32`.
impl Cast<i32> for i8 {
    fn cast(self) -> i32 {
        i32::from(self)
    }
}

/// Narrowing of `i32` to `i8`.
///
/// Uses Rust's integer `as` conversion, which keeps only the low 8 bits
/// (two's-complement wrap-around) instead of saturating. For example, `300`
/// becomes `44` and `-129` becomes `127`.
impl Cast<i8> for i32 {
    fn cast(self) -> i8 {
        // NOTE(review): out-of-range values wrap rather than clamp to
        // `i8::MIN..=i8::MAX`. Confirm this is the intended semantics.
        self as i8
    }
}
/// Converts a `bf16` to `f32` by delegating to `bf16::to_f32`.
impl Cast<f32> for bf16 {
    fn cast(self) -> f32 {
        self.to_f32()
    }
}

/// Converts an `f32` to `bf16` by delegating to `bf16::from_f32`.
///
/// NOTE(review): `bf16::from_f32` is not visible here, so it defines the
/// rounding mode and the handling of out-of-range and NaN inputs. Confirm
/// there before relying on them.
impl Cast<bf16> for f32 {
    fn cast(self) -> bf16 {
        bf16::from_f32(self)
    }
}
/// Converts an `f8e4m3` to `f32` by delegating to `f8e4m3::to_f32`.
impl Cast<f32> for f8e4m3 {
    fn cast(self) -> f32 {
        self.to_f32()
    }
}

/// Converts an `f32` to `f8e4m3` by delegating to `f8e4m3::from_f32`.
///
/// NOTE(review): `f8e4m3::from_f32` is not visible here, so it defines the
/// rounding and overflow/NaN behavior. Confirm there.
impl Cast<f8e4m3> for f32 {
    fn cast(self) -> f8e4m3 {
        f8e4m3::from_f32(self)
    }
}
/// Converts an `i4` to `i32` by delegating to `i4::to_i32`.
impl Cast<i32> for i4 {
    fn cast(self) -> i32 {
        self.to_i32()
    }
}

/// Converts an `i32` to `i4` by delegating to `i4::from_i32`.
///
/// NOTE(review): `i4::from_i32` is not visible here, so it decides whether
/// out-of-range values wrap, saturate or panic. Confirm there, and compare
/// with the wrapping `i32 -> i8` cast in this module.
impl Cast<i4> for i32 {
    fn cast(self) -> i4 {
        i4::from_i32(self)
    }
}
/// Associates a scalar type with the scalar type named by [`Self::Output`].
///
/// NOTE(review): the name suggests `Output` is the accumulator type used when
/// contracting (dot-product / matmul-style reductions) over `Self` values.
/// Confirm against callers. Each impl below pairs `Self` with the same type
/// that `Self` has a [`FetchCast`] impl for.
pub trait ContractionCast: Scalar {
    /// Scalar type associated with `Self` for contractions.
    type Output: Scalar;
}

// Integer inputs map to `i32`.
impl ContractionCast for i8 {
    type Output = i32;
}

impl ContractionCast for i4 {
    type Output = i32;
}

// Reduced-width float inputs map to `f32`.
impl ContractionCast for bf16 {
    type Output = f32;
}

impl ContractionCast for f8e4m3 {
    type Output = f32;
}