1use thiserror::Error;
4
5#[derive(Debug, Error, Clone, PartialEq, Eq)]
7pub enum LmError {
8 #[error("dimension mismatch: expected {expected}, got {got}")]
10 DimensionMismatch { expected: usize, got: usize },
11
12 #[error("invalid config: {msg}")]
14 InvalidConfig { msg: String },
15
16 #[error("empty input: {context}")]
18 EmptyInput { context: &'static str },
19
20 #[error("out-of-vocabulary token id {token}")]
22 OutOfVocab { token: u32 },
23
24 #[error("tokenizer is not initialised")]
26 TokenizerUninitialized,
27
28 #[error("UTF-8 decode error for token {token}")]
30 Utf8Decode { token: u32 },
31
32 #[error("weight '{name}' not found")]
34 WeightNotFound { name: String },
35
36 #[error("weight '{name}' shape mismatch: expected {expected:?}, got {got:?}")]
38 WeightShapeMismatch {
39 name: String,
40 expected: Vec<usize>,
41 got: Vec<usize>,
42 },
43
44 #[error("layer index {idx} is out of range [0, {n_layers})")]
46 LayerIndexOutOfRange { idx: usize, n_layers: usize },
47
48 #[error(
50 "head dimension mismatch: hidden_dim={hidden_dim} must be divisible by n_heads={n_heads}"
51 )]
52 HeadDimMismatch { hidden_dim: usize, n_heads: usize },
53
54 #[error("KV cache length mismatch: expected past_len={past_len}, got {got}")]
56 KvCacheLengthMismatch { past_len: usize, got: usize },
57
58 #[error("sequence too long: total={total_len} > max_position_embeddings={max_pos}")]
60 SequenceTooLong { total_len: usize, max_pos: usize },
61
62 #[error("invalid BPE merge pair: tokens {a} and {b} not both present in vocabulary")]
64 InvalidMergePair { a: u32, b: u32 },
65
66 #[error("vocab size mismatch: expected {expected}, got {got}")]
68 VocabSizeMismatch { expected: usize, got: usize },
69
70 #[error(
72 "GQA constraint violated: n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}"
73 )]
74 GqaHeadMismatch { n_heads: usize, n_kv_heads: usize },
75
76 #[error("weight data length {data_len} does not match shape {shape:?} product {expected}")]
78 WeightDataLengthMismatch {
79 data_len: usize,
80 shape: Vec<usize>,
81 expected: usize,
82 },
83
84 #[error("internal error: {msg}")]
86 Internal { msg: String },
87}
88
89pub type LmResult<T> = Result<T, LmError>;
91
// Display-format and trait checks for `LmError`.
//
// Assertions match label-anchored fragments (e.g. `n_heads=32`) rather than
// bare numbers: a bare `contains("5")` still passes when two fields are
// swapped in the format string, or when the digit appears elsewhere.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_display_dimension_mismatch() {
        let s = LmError::DimensionMismatch {
            expected: 4,
            got: 3,
        }
        .to_string();
        assert!(s.contains("expected 4"), "unexpected message: {}", s);
        assert!(s.contains("got 3"), "unexpected message: {}", s);
    }

    #[test]
    fn error_display_invalid_config() {
        let s = LmError::InvalidConfig {
            msg: "n_heads must be positive".into(),
        }
        .to_string();
        assert!(
            s.contains("invalid config: n_heads must be positive"),
            "unexpected message: {}",
            s
        );
    }

    #[test]
    fn error_display_weight_not_found() {
        let s = LmError::WeightNotFound {
            name: "layer.0.attn.w_q".into(),
        }
        .to_string();
        // The name is rendered inside single quotes.
        assert!(s.contains("'layer.0.attn.w_q'"), "unexpected message: {}", s);
    }

    #[test]
    fn error_display_weight_shape_mismatch() {
        let s = LmError::WeightShapeMismatch {
            name: "embed".into(),
            expected: vec![128, 64],
            got: vec![64, 128],
        }
        .to_string();
        assert!(s.contains("'embed'"), "unexpected message: {}", s);
        // Anchor each shape to its label so swapped fields are caught.
        assert!(s.contains("expected [128, 64]"), "unexpected message: {}", s);
        assert!(s.contains("got [64, 128]"), "unexpected message: {}", s);
    }

    #[test]
    fn error_display_out_of_vocab() {
        let s = LmError::OutOfVocab { token: 99999 }.to_string();
        assert!(s.contains("token id 99999"), "unexpected message: {}", s);
    }

    #[test]
    fn error_display_sequence_too_long() {
        let s = LmError::SequenceTooLong {
            total_len: 2048,
            max_pos: 1024,
        }
        .to_string();
        assert!(s.contains("total=2048"), "unexpected message: {}", s);
        assert!(
            s.contains("max_position_embeddings=1024"),
            "unexpected message: {}",
            s
        );
    }

    #[test]
    fn error_display_gqa_mismatch() {
        let s = LmError::GqaHeadMismatch {
            n_heads: 32,
            n_kv_heads: 5,
        }
        .to_string();
        assert!(s.contains("n_heads=32"), "unexpected message: {}", s);
        assert!(s.contains("n_kv_heads=5"), "unexpected message: {}", s);
    }

    #[test]
    fn error_display_invalid_merge_pair() {
        let s = LmError::InvalidMergePair { a: 10, b: 20 }.to_string();
        // Checks both values and their order in the message.
        assert!(s.contains("tokens 10 and 20"), "unexpected message: {}", s);
    }

    #[test]
    fn error_display_layer_index_out_of_range() {
        let s = LmError::LayerIndexOutOfRange { idx: 7, n_layers: 4 }.to_string();
        assert!(s.contains("layer index 7"), "unexpected message: {}", s);
        assert!(s.contains("[0, 4)"), "unexpected message: {}", s);
    }

    #[test]
    fn error_display_kv_cache_length_mismatch() {
        let s = LmError::KvCacheLengthMismatch {
            past_len: 16,
            got: 8,
        }
        .to_string();
        assert!(s.contains("past_len=16"), "unexpected message: {}", s);
        assert!(s.contains("got 8"), "unexpected message: {}", s);
    }

    #[test]
    fn error_is_std_error() {
        // Compiles only if `LmError` implements `std::error::Error`.
        let e: Box<dyn std::error::Error> = Box::new(LmError::Internal { msg: "test".into() });
        assert!(e.to_string().contains("internal error: test"));
    }

    #[test]
    fn error_clone_and_eq() {
        let a = LmError::EmptyInput {
            context: "token_ids",
        };
        let b = a.clone();
        assert_eq!(a, b);

        // Equality must depend on field values, not only the variant.
        let c = LmError::EmptyInput { context: "weights" };
        assert_ne!(a, c);
    }
}