deep_delta_learning/
error.rs1use std::fmt::{Display, Formatter};
2
/// Reasons a model configuration is rejected.
///
/// Every variant stores the offending value(s) so that the [`Display`] message can echo
/// them back to the user.
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigValidationError {
    /// An integer setting named by `field` was required to be positive but was `value`.
    NonPositiveUsize { field: &'static str, value: usize },
    /// A double-precision setting named by `field` was required to be positive but was `value`.
    NonPositiveF64 { field: &'static str, value: f64 },
    /// A single-precision setting named by `field` was required to be positive but was `value`.
    NonPositiveF32 { field: &'static str, value: f32 },
    /// `d_model` is not divisible by `num_heads`.
    DModelNotDivisibleByHeads { d_model: usize, num_heads: usize },
    /// `head_dim * num_heads` does not equal `d_model`.
    HeadDimMismatch {
        head_dim: usize,
        num_heads: usize,
        d_model: usize,
    },
    /// `beta_init` lies outside the open interval (0, 2).
    BetaInitOutOfRange { value: f64 },
}

impl Display for ConfigValidationError {
    /// Renders a lowercase, single-sentence description that includes the rejected value(s).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` copies the values out directly.
        match *self {
            // These three arms share one message shape. They stay separate because an
            // or-pattern cannot bind `value` at three different numeric types.
            ConfigValidationError::NonPositiveUsize { field, value } => {
                write!(f, "{} must be positive, got {}", field, value)
            }
            ConfigValidationError::NonPositiveF64 { field, value } => {
                write!(f, "{} must be positive, got {}", field, value)
            }
            ConfigValidationError::NonPositiveF32 { field, value } => {
                write!(f, "{} must be positive, got {}", field, value)
            }
            ConfigValidationError::DModelNotDivisibleByHeads { d_model, num_heads } => write!(
                f,
                "d_model ({}) must be divisible by num_heads ({})",
                d_model, num_heads
            ),
            ConfigValidationError::HeadDimMismatch {
                head_dim,
                num_heads,
                d_model,
            } => write!(
                f,
                "head_dim ({}) * num_heads ({}) must equal d_model ({})",
                head_dim, num_heads, d_model
            ),
            ConfigValidationError::BetaInitOutOfRange { value } => {
                write!(f, "beta_init must be in (0, 2), got {}", value)
            }
        }
    }
}

/// Leaf error: it wraps nothing, so the default `source()` (which returns `None`) is kept.
impl std::error::Error for ConfigValidationError {}
63
/// Rejected data-loading parameters.
///
/// Each variant carries the offending value, which the [`Display`] message echoes back.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DataValidationError {
    /// `batch_size` was required to be positive.
    InvalidBatchSize(usize),
    /// `seq_len` was required to be positive.
    InvalidSeqLen(usize),
    /// `stride` was required to be positive.
    InvalidStride(usize),
}

impl Display for DataValidationError {
    /// Renders "`<setting> must be positive, got <value>`".
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // All variants share a single message template, so only pick the setting name here.
        let (setting, rejected) = match *self {
            DataValidationError::InvalidBatchSize(v) => ("batch_size", v),
            DataValidationError::InvalidSeqLen(v) => ("seq_len", v),
            DataValidationError::InvalidStride(v) => ("stride", v),
        };
        write!(f, "{} must be positive, got {}", setting, rejected)
    }
}

/// Leaf error with no underlying cause.
impl std::error::Error for DataValidationError {}
84
85#[derive(Debug, Clone, PartialEq)]
86pub enum BaselineConfigError {
87 InvalidConfig(ConfigValidationError),
88 UnsupportedValueDimension { d_value: usize },
89}
90
91impl Display for BaselineConfigError {
92 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
93 match self {
94 Self::InvalidConfig(error) => write!(f, "invalid baseline configuration: {error}"),
95 Self::UnsupportedValueDimension { d_value } => write!(
96 f,
97 "baseline transformer only supports vector-state configs (d_value == 1), got {d_value}"
98 ),
99 }
100 }
101}
102
103impl std::error::Error for BaselineConfigError {
104 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
105 match self {
106 Self::InvalidConfig(error) => Some(error),
107 Self::UnsupportedValueDimension { .. } => None,
108 }
109 }
110}
111
112impl From<ConfigValidationError> for BaselineConfigError {
113 fn from(value: ConfigValidationError) -> Self {
114 Self::InvalidConfig(value)
115 }
116}
117
/// Errors raised by the spectral module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SpectralError {
    /// `max_history` was required to be positive; carries the rejected value.
    InvalidMaxHistory(usize),
}

impl Display for SpectralError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The enum has exactly one variant, so this destructuring is irrefutable. Adding a
        // variant later turns it into a compile error, which forces this message to be revisited.
        let SpectralError::InvalidMaxHistory(max_history) = self;
        write!(f, "max_history must be positive, got {}", max_history)
    }
}

/// Leaf error with no underlying cause.
impl std::error::Error for SpectralError {}