// burn_core/record/settings.rs
use burn_tensor::Element;
2use serde::{Serialize, de::DeserializeOwned};
3
4pub trait PrecisionSettings:
6 Send + Sync + core::fmt::Debug + core::default::Default + Clone
7{
8 type FloatElem: Element + Serialize + DeserializeOwned;
10
11 type IntElem: Element + Serialize + DeserializeOwned;
13}
14
/// Unit marker type for the full-precision configuration.
///
/// Its `PrecisionSettings` implementation in this file selects `f32` for floats
/// and `i32` for integers.
#[derive(Clone, Default, Debug)]
pub struct FullPrecisionSettings;
18
/// Unit marker type for the half-precision configuration.
///
/// Its `PrecisionSettings` implementation in this file selects `half::f16` for
/// floats and `i16` for integers.
#[derive(Clone, Default, Debug)]
pub struct HalfPrecisionSettings;
22
/// Unit marker type for the double-precision configuration.
///
/// Its `PrecisionSettings` implementation in this file selects `f64` for floats
/// and `i64` for integers.
#[derive(Clone, Default, Debug)]
pub struct DoublePrecisionSettings;
26
27impl PrecisionSettings for FullPrecisionSettings {
28 type FloatElem = f32;
29 type IntElem = i32;
30}
31
32impl PrecisionSettings for DoublePrecisionSettings {
33 type FloatElem = f64;
34 type IntElem = i64;
35}
36
37impl PrecisionSettings for HalfPrecisionSettings {
38 type FloatElem = half::f16;
39 type IntElem = i16;
40}