Skip to main content

entrenar/autograd/precision/
precision_types.rs

1//! Precision type definitions for mixed-precision training.
2
3use std::fmt;
4
/// Numeric storage format for tensor elements in mixed-precision training.
///
/// [`Precision::Fp32`] is the default variant.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash, Default)]
pub enum Precision {
    /// Full 32-bit floating point; chosen by `Default`.
    #[default]
    Fp32,
    /// IEEE half precision, 16 bits wide.
    Fp16,
    /// 16-bit "brain" float: an fp32 value with its mantissa truncated.
    Bf16,
}
16
17impl Precision {
18    /// Size in bytes
19    pub fn size_bytes(&self) -> usize {
20        match self {
21            Precision::Fp32 => 4,
22            Precision::Fp16 | Precision::Bf16 => 2,
23        }
24    }
25
26    /// Human-readable name
27    pub fn name(&self) -> &'static str {
28        match self {
29            Precision::Fp32 => "fp32",
30            Precision::Fp16 => "fp16",
31            Precision::Bf16 => "bf16",
32        }
33    }
34
35    /// Whether this is a reduced precision type
36    pub fn is_reduced(&self) -> bool {
37        matches!(self, Precision::Fp16 | Precision::Bf16)
38    }
39
40    /// Memory multiplier compared to fp32
41    pub fn memory_multiplier(&self) -> f32 {
42        match self {
43            Precision::Fp32 => 1.0,
44            Precision::Fp16 | Precision::Bf16 => 0.5,
45        }
46    }
47}
48
49impl fmt::Display for Precision {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        write!(f, "{}", self.name())
52    }
53}