entrenar/autograd/precision/
precision_types.rs1use std::fmt;
/// Floating-point storage precision selector.
///
/// `Fp32` is the default variant. `Fp16` and `Bf16` are the two 16-bit formats.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Precision {
    /// 32-bit floating point; selected by `Precision::default()`.
    #[default]
    Fp32,
    /// 16-bit half-precision floating point.
    Fp16,
    /// 16-bit bfloat16 floating point.
    Bf16,
}
17impl Precision {
18 pub fn size_bytes(&self) -> usize {
20 match self {
21 Precision::Fp32 => 4,
22 Precision::Fp16 | Precision::Bf16 => 2,
23 }
24 }
25
26 pub fn name(&self) -> &'static str {
28 match self {
29 Precision::Fp32 => "fp32",
30 Precision::Fp16 => "fp16",
31 Precision::Bf16 => "bf16",
32 }
33 }
34
35 pub fn is_reduced(&self) -> bool {
37 matches!(self, Precision::Fp16 | Precision::Bf16)
38 }
39
40 pub fn memory_multiplier(&self) -> f32 {
42 match self {
43 Precision::Fp32 => 1.0,
44 Precision::Fp16 | Precision::Bf16 => 0.5,
45 }
46 }
47}
48
49impl fmt::Display for Precision {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 write!(f, "{}", self.name())
52 }
53}