use burn_tensor::Element;
use serde::{de::DeserializeOwned, Serialize};
/// Chooses the concrete element types used for float and integer data.
///
/// Both associated element types must be serde-serializable and
/// deserializable. That makes an implementor a knob for the precision at
/// which values are (de)serialized.
///
/// Implementors must be thread-safe (`Send + Sync`), cloneable,
/// default-constructible and debuggable. The implementations provided
/// alongside this trait are all unit structs.
pub trait PrecisionSettings
where
    Self: Send + Sync + Clone + core::default::Default + core::fmt::Debug,
{
    /// Element type used for floating-point data.
    type FloatElem: Element + Serialize + DeserializeOwned;

    /// Element type used for integer data.
    type IntElem: Element + Serialize + DeserializeOwned;
}
/// Precision settings using `f32` for floats and `i32` for integers.
///
/// This is a field-less marker type, so it is also `Copy`, comparable and
/// hashable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FullPrecisionSettings;
/// Precision settings using `half::f16` for floats and `i16` for integers.
///
/// This is a field-less marker type, so it is also `Copy`, comparable and
/// hashable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HalfPrecisionSettings;
/// Precision settings using `f64` for floats and `i64` for integers.
///
/// This is a field-less marker type, so it is also `Copy`, comparable and
/// hashable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DoublePrecisionSettings;
/// Single precision: 32-bit floats and 32-bit signed integers.
impl PrecisionSettings for FullPrecisionSettings {
    type IntElem = i32;
    type FloatElem = f32;
}
/// Double precision: 64-bit floats and 64-bit signed integers.
impl PrecisionSettings for DoublePrecisionSettings {
    type IntElem = i64;
    type FloatElem = f64;
}
/// Half precision: 16-bit floats (`half::f16`) and 16-bit signed integers.
///
/// NOTE(review): `i16` only represents -32768..=32767. How integer values
/// outside that range are handled depends on the element conversion code
/// elsewhere. Confirm before relying on this setting for large integer data.
impl PrecisionSettings for HalfPrecisionSettings {
    type IntElem = i16;
    type FloatElem = half::f16;
}