burn_core/record/settings.rs

1use burn_tensor::Element;
2use serde::{Serialize, de::DeserializeOwned};
3
4/// Settings allowing to control the precision when (de)serializing items.
5pub trait PrecisionSettings:
6    Send + Sync + core::fmt::Debug + core::default::Default + Clone
7{
8    /// Float element type.
9    type FloatElem: Element + Serialize + DeserializeOwned;
10
11    /// Integer element type.
12    type IntElem: Element + Serialize + DeserializeOwned;
13}
14
/// Default precision settings: `f32` for floats and `i32` for integers.
#[derive(Clone, Debug, Default)]
pub struct FullPrecisionSettings;
18
/// Precision settings optimized for compactness: `half::f16` for floats and `i16`
/// for integers.
#[derive(Clone, Debug, Default)]
pub struct HalfPrecisionSettings;
22
/// Precision settings favoring accuracy over size: `f64` for floats and `i64` for
/// integers.
#[derive(Clone, Debug, Default)]
pub struct DoublePrecisionSettings;
26
27impl PrecisionSettings for FullPrecisionSettings {
28    type FloatElem = f32;
29    type IntElem = i32;
30}
31
32impl PrecisionSettings for DoublePrecisionSettings {
33    type FloatElem = f64;
34    type IntElem = i64;
35}
36
37impl PrecisionSettings for HalfPrecisionSettings {
38    type FloatElem = half::f16;
39    type IntElem = i16;
40}