use cubecl::{
define_size, flex32,
prelude::{Numeric, Size},
};
// Public size marker `S`, declared with cubecl's `define_size!`. It is used as
// the `SI` associated type for the concrete `ReducePrecision` impls below.
// NOTE(review): presumably a size resolved at kernel compile/launch time (e.g. a
// line/vector width) — confirm against the cubecl `define_size!` docs.
define_size!(pub S);
/// Type-level precision configuration for a reduction.
///
/// Groups three associated types: the element type `EI`, the size marker `SI`
/// that goes with it, and the element type `EA`. Judging by the impls, `EA` is
/// the accumulator type: narrow inputs are widened there (f16/bf16 → f32,
/// i8/i16 → i32, u8/u16 → u32). TODO confirm against the reduce kernels.
pub trait ReducePrecision: 'static {
    /// Element type of the values being reduced (presumably the input type).
    type EI: Numeric;
    /// Element type used for intermediate/accumulated values.
    type EA: Numeric;
    /// Size marker paired with `EI`.
    type SI: Size;
}
/// Lets any `(EI, SI, EA)` triple act directly as a [`ReducePrecision`].
///
/// Each tuple component maps positionally onto the associated type of the
/// same role, so callers can pick an arbitrary combination of types.
impl<In, Sz, Acc> ReducePrecision for (In, Sz, Acc)
where
    In: Numeric,
    Sz: Size,
    Acc: Numeric,
{
    type EI = In;
    type SI = Sz;
    type EA = Acc;
}
impl ReducePrecision for f64 {
type EI = f64;
type EA = f64;
type SI = S;
}
impl ReducePrecision for f32 {
type EI = f32;
type EA = f32;
type SI = S;
}
impl ReducePrecision for flex32 {
type EI = f32;
type EA = f32;
type SI = S;
}
impl ReducePrecision for half::f16 {
type EI = half::f16;
type EA = f32;
type SI = S;
}
impl ReducePrecision for half::bf16 {
type EI = half::bf16;
type EA = f32;
type SI = S;
}
impl ReducePrecision for i64 {
type EI = i64;
type EA = i64;
type SI = S;
}
impl ReducePrecision for i32 {
type EI = i32;
type EA = i32;
type SI = S;
}
impl ReducePrecision for i16 {
type EI = i16;
type EA = i32;
type SI = S;
}
impl ReducePrecision for i8 {
type EI = i8;
type EA = i32;
type SI = S;
}
impl ReducePrecision for u64 {
type EI = u64;
type EA = u64;
type SI = S;
}
impl ReducePrecision for u32 {
type EI = u32;
type EA = u32;
type SI = S;
}
impl ReducePrecision for u16 {
type EI = u16;
type EA = u32;
type SI = S;
}
impl ReducePrecision for u8 {
type EI = u8;
type EA = u32;
type SI = S;
}