Skip to main content

oxicuda_lm/
error.rs

1//! Error types for the `oxicuda-lm` crate.
2
3use thiserror::Error;
4
5/// All errors that can arise from LLM inference operations.
6#[derive(Debug, Error, Clone, PartialEq, Eq)]
7pub enum LmError {
8    /// Tensor or slice dimension does not match expectation.
9    #[error("dimension mismatch: expected {expected}, got {got}")]
10    DimensionMismatch { expected: usize, got: usize },
11
12    /// Model or layer configuration is invalid.
13    #[error("invalid config: {msg}")]
14    InvalidConfig { msg: String },
15
16    /// An input that must be non-empty was empty.
17    #[error("empty input: {context}")]
18    EmptyInput { context: &'static str },
19
20    /// A token id is outside the vocabulary range.
21    #[error("out-of-vocabulary token id {token}")]
22    OutOfVocab { token: u32 },
23
24    /// Tokenizer was used before being properly initialised.
25    #[error("tokenizer is not initialised")]
26    TokenizerUninitialized,
27
28    /// Token byte sequence is not valid UTF-8.
29    #[error("UTF-8 decode error for token {token}")]
30    Utf8Decode { token: u32 },
31
32    /// A named weight entry is absent from `ModelWeights`.
33    #[error("weight '{name}' not found")]
34    WeightNotFound { name: String },
35
36    /// A named weight has an unexpected shape.
37    #[error("weight '{name}' shape mismatch: expected {expected:?}, got {got:?}")]
38    WeightShapeMismatch {
39        name: String,
40        expected: Vec<usize>,
41        got: Vec<usize>,
42    },
43
44    /// Transformer layer index is out of range.
45    #[error("layer index {idx} is out of range [0, {n_layers})")]
46    LayerIndexOutOfRange { idx: usize, n_layers: usize },
47
48    /// Hidden dimension is not divisible by the number of attention heads.
49    #[error(
50        "head dimension mismatch: hidden_dim={hidden_dim} must be divisible by n_heads={n_heads}"
51    )]
52    HeadDimMismatch { hidden_dim: usize, n_heads: usize },
53
54    /// KV cache contains a different number of past tokens than expected.
55    #[error("KV cache length mismatch: expected past_len={past_len}, got {got}")]
56    KvCacheLengthMismatch { past_len: usize, got: usize },
57
58    /// Input sequence length exceeds the model's `max_position_embeddings`.
59    #[error("sequence too long: total={total_len} > max_position_embeddings={max_pos}")]
60    SequenceTooLong { total_len: usize, max_pos: usize },
61
62    /// A BPE merge references tokens that are not in the vocabulary.
63    #[error("invalid BPE merge pair: tokens {a} and {b} not both present in vocabulary")]
64    InvalidMergePair { a: u32, b: u32 },
65
66    /// Vocabulary size does not match model configuration.
67    #[error("vocab size mismatch: expected {expected}, got {got}")]
68    VocabSizeMismatch { expected: usize, got: usize },
69
70    /// n_kv_heads does not divide n_heads evenly (GQA requirement).
71    #[error(
72        "GQA constraint violated: n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}"
73    )]
74    GqaHeadMismatch { n_heads: usize, n_kv_heads: usize },
75
76    /// A weight tensor element count does not match its declared shape.
77    #[error("weight data length {data_len} does not match shape {shape:?} product {expected}")]
78    WeightDataLengthMismatch {
79        data_len: usize,
80        shape: Vec<usize>,
81        expected: usize,
82    },
83
84    /// Catch-all for internal invariant violations.
85    #[error("internal error: {msg}")]
86    Internal { msg: String },
87}
88
89/// Convenience alias.
90pub type LmResult<T> = Result<T, LmError>;
91
92// ─── Tests ────────────────────────────────────────────────────────────────────
93
// These tests pin the exact `Display` output of each variant. Substring checks such as
// `contains("5")` would still pass if two fields were swapped in a format string or the
// surrounding wording changed, so full-message equality is asserted instead.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_display_dimension_mismatch() {
        let e = LmError::DimensionMismatch {
            expected: 4,
            got: 3,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 4, got 3");
    }

    #[test]
    fn error_display_invalid_config() {
        let e = LmError::InvalidConfig {
            msg: "n_heads must be positive".into(),
        };
        assert_eq!(e.to_string(), "invalid config: n_heads must be positive");
    }

    #[test]
    fn error_display_weight_not_found() {
        let e = LmError::WeightNotFound {
            name: "layer.0.attn.w_q".into(),
        };
        assert_eq!(e.to_string(), "weight 'layer.0.attn.w_q' not found");
    }

    #[test]
    fn error_display_weight_shape_mismatch() {
        let e = LmError::WeightShapeMismatch {
            name: "embed".into(),
            expected: vec![128, 64],
            got: vec![64, 128],
        };
        // Expected and actual shapes are transposed on purpose so a swap is detectable.
        assert_eq!(
            e.to_string(),
            "weight 'embed' shape mismatch: expected [128, 64], got [64, 128]"
        );
    }

    #[test]
    fn error_display_out_of_vocab() {
        let e = LmError::OutOfVocab { token: 99999 };
        assert_eq!(e.to_string(), "out-of-vocabulary token id 99999");
    }

    #[test]
    fn error_display_sequence_too_long() {
        let e = LmError::SequenceTooLong {
            total_len: 2048,
            max_pos: 1024,
        };
        assert_eq!(
            e.to_string(),
            "sequence too long: total=2048 > max_position_embeddings=1024"
        );
    }

    #[test]
    fn error_display_gqa_mismatch() {
        let e = LmError::GqaHeadMismatch {
            n_heads: 32,
            n_kv_heads: 5,
        };
        assert_eq!(
            e.to_string(),
            "GQA constraint violated: n_heads=32 must be divisible by n_kv_heads=5"
        );
    }

    #[test]
    fn error_display_invalid_merge_pair() {
        let e = LmError::InvalidMergePair { a: 10, b: 20 };
        assert_eq!(
            e.to_string(),
            "invalid BPE merge pair: tokens 10 and 20 not both present in vocabulary"
        );
    }

    /// Covers the variants that have no dedicated test above.
    #[test]
    fn error_display_remaining_variants() {
        let cases = [
            (
                LmError::EmptyInput {
                    context: "token_ids",
                },
                "empty input: token_ids",
            ),
            (
                LmError::TokenizerUninitialized,
                "tokenizer is not initialised",
            ),
            (
                LmError::Utf8Decode { token: 7 },
                "UTF-8 decode error for token 7",
            ),
            (
                LmError::LayerIndexOutOfRange {
                    idx: 12,
                    n_layers: 12,
                },
                "layer index 12 is out of range [0, 12)",
            ),
            (
                LmError::HeadDimMismatch {
                    hidden_dim: 100,
                    n_heads: 12,
                },
                "head dimension mismatch: hidden_dim=100 must be divisible by n_heads=12",
            ),
            (
                LmError::KvCacheLengthMismatch {
                    past_len: 16,
                    got: 8,
                },
                "KV cache length mismatch: expected past_len=16, got 8",
            ),
            (
                LmError::VocabSizeMismatch {
                    expected: 32000,
                    got: 50257,
                },
                "vocab size mismatch: expected 32000, got 50257",
            ),
            (
                LmError::WeightDataLengthMismatch {
                    data_len: 7,
                    shape: vec![2, 4],
                    expected: 8,
                },
                "weight data length 7 does not match shape [2, 4] product 8",
            ),
        ];
        for (err, expected) in cases.iter() {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[test]
    fn error_is_std_error() {
        // Boxing as a trait object proves `LmError` implements `std::error::Error`.
        let e: Box<dyn std::error::Error> = Box::new(LmError::Internal { msg: "test".into() });
        assert_eq!(e.to_string(), "internal error: test");
    }

    #[test]
    fn error_clone_and_eq() {
        let a = LmError::EmptyInput {
            context: "token_ids",
        };
        let b = a.clone();
        assert_eq!(a, b);
        // Equality must also distinguish payloads, not just variants.
        assert_ne!(a, LmError::EmptyInput { context: "weights" });
    }
}