// deep_delta_learning/error.rs

use std::fmt::{Display, Formatter};

/// Reasons a model configuration can be rejected during validation.
///
/// Each variant records the offending value(s) so the `Display` rendering can
/// state exactly which check failed. This is a leaf error: it wraps no other
/// error, so `source()` is always `None`.
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigValidationError {
    /// An integer (`usize`) field was not positive.
    NonPositiveUsize {
        /// Name of the rejected field.
        field: &'static str,
        /// Value that was supplied.
        value: usize,
    },
    /// A 64-bit floating-point field was not positive.
    NonPositiveF64 {
        /// Name of the rejected field.
        field: &'static str,
        /// Value that was supplied.
        value: f64,
    },
    /// A 32-bit floating-point field was not positive.
    NonPositiveF32 {
        /// Name of the rejected field.
        field: &'static str,
        /// Value that was supplied.
        value: f32,
    },
    /// `d_model` is not a multiple of `num_heads`.
    DModelNotDivisibleByHeads {
        /// `d_model` value that was supplied.
        d_model: usize,
        /// `num_heads` value that was supplied.
        num_heads: usize,
    },
    /// `head_dim * num_heads` does not equal `d_model`.
    HeadDimMismatch {
        /// `head_dim` value that was supplied.
        head_dim: usize,
        /// `num_heads` value that was supplied.
        num_heads: usize,
        /// `d_model` value that was supplied.
        d_model: usize,
    },
    /// `beta_init` lies outside the open interval (0, 2).
    BetaInitOutOfRange {
        /// Value that was supplied for `beta_init`.
        value: f64,
    },
}

impl Display for ConfigValidationError {
    /// Writes a one-line, lowercase description of the failed check,
    /// including the offending value(s).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` binds by value.
        match *self {
            Self::NonPositiveUsize { field, value } => {
                write!(f, "{field} must be positive, got {value}")
            }
            Self::NonPositiveF64 { field, value } => {
                write!(f, "{field} must be positive, got {value}")
            }
            Self::NonPositiveF32 { field, value } => {
                write!(f, "{field} must be positive, got {value}")
            }
            Self::DModelNotDivisibleByHeads { d_model, num_heads } => {
                write!(f, "d_model ({d_model}) must be divisible by num_heads ({num_heads})")
            }
            Self::HeadDimMismatch { head_dim, num_heads, d_model } => {
                write!(
                    f,
                    "head_dim ({head_dim}) * num_heads ({num_heads}) must equal d_model ({d_model})"
                )
            }
            Self::BetaInitOutOfRange { value } => {
                write!(f, "beta_init must be in (0, 2), got {value}")
            }
        }
    }
}

impl std::error::Error for ConfigValidationError {
    /// Leaf error: there is no underlying cause to expose.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}

/// Validation failures for the `batch_size`, `seq_len` and `stride`
/// parameters. Each variant carries the rejected value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DataValidationError {
    /// `batch_size` was not positive.
    InvalidBatchSize(usize),
    /// `seq_len` was not positive.
    InvalidSeqLen(usize),
    /// `stride` was not positive.
    InvalidStride(usize),
}

impl Display for DataValidationError {
    /// Writes `"<parameter> must be positive, got <value>"`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The payload is a plain `usize`, so it can be copied out of `*self`.
        match *self {
            Self::InvalidBatchSize(value) => {
                write!(f, "batch_size must be positive, got {value}")
            }
            Self::InvalidSeqLen(value) => {
                write!(f, "seq_len must be positive, got {value}")
            }
            Self::InvalidStride(value) => {
                write!(f, "stride must be positive, got {value}")
            }
        }
    }
}

impl std::error::Error for DataValidationError {
    /// Leaf error: there is no underlying cause to expose.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}

85#[derive(Debug, Clone, PartialEq)]
86pub enum BaselineConfigError {
87    InvalidConfig(ConfigValidationError),
88    UnsupportedValueDimension { d_value: usize },
89}
90
91impl Display for BaselineConfigError {
92    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
93        match self {
94            Self::InvalidConfig(error) => write!(f, "invalid baseline configuration: {error}"),
95            Self::UnsupportedValueDimension { d_value } => write!(
96                f,
97                "baseline transformer only supports vector-state configs (d_value == 1), got {d_value}"
98            ),
99        }
100    }
101}
102
103impl std::error::Error for BaselineConfigError {
104    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
105        match self {
106            Self::InvalidConfig(error) => Some(error),
107            Self::UnsupportedValueDimension { .. } => None,
108        }
109    }
110}
111
112impl From<ConfigValidationError> for BaselineConfigError {
113    fn from(value: ConfigValidationError) -> Self {
114        Self::InvalidConfig(value)
115    }
116}
117
/// Validation error for spectral settings; currently only `max_history`
/// can be rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SpectralError {
    /// `max_history` was not positive; carries the rejected value.
    InvalidMaxHistory(usize),
}

impl Display for SpectralError {
    /// Writes `"max_history must be positive, got <value>"`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum: this pattern is irrefutable, so no `match` is needed.
        let Self::InvalidMaxHistory(value) = *self;
        write!(f, "max_history must be positive, got {value}")
    }
}

impl std::error::Error for SpectralError {
    /// Leaf error: there is no underlying cause to expose.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}