use thiserror::Error;
/// Error type for the language-model code: configuration validation, tokenizer
/// and BPE vocabulary handling, weight lookup/validation, and attention / KV-cache
/// shape checks.
///
/// `Display` (and `std::error::Error`) is derived by `thiserror` from each
/// variant's `#[error(...)]` message. The type is `Clone + PartialEq + Eq`, so
/// errors can be duplicated and compared directly (e.g. in tests).
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum LmError {
    /// A size check failed: `got` was supplied where `expected` was required.
    #[error("dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch {
        /// Size that was required.
        expected: usize,
        /// Size that was actually supplied.
        got: usize,
    },
    /// A configuration value was rejected.
    #[error("invalid config: {msg}")]
    InvalidConfig {
        /// Free-form explanation of the problem.
        msg: String,
    },
    /// An input that must be non-empty was empty.
    #[error("empty input: {context}")]
    EmptyInput {
        /// Static label naming the empty input (e.g. `"token_ids"`).
        context: &'static str,
    },
    /// A token id is not part of the vocabulary.
    #[error("out-of-vocabulary token id {token}")]
    OutOfVocab {
        /// The offending token id.
        token: u32,
    },
    /// An operation needed the tokenizer, but it has not been initialised.
    #[error("tokenizer is not initialised")]
    TokenizerUninitialized,
    /// A UTF-8 decoding error occurred for `token`.
    #[error("UTF-8 decode error for token {token}")]
    Utf8Decode {
        /// Token id whose decoding failed.
        token: u32,
    },
    /// No weight named `name` could be found.
    #[error("weight '{name}' not found")]
    WeightNotFound {
        /// Name that was looked up.
        name: String,
    },
    /// The weight `name` exists but its shape differs from the expected shape.
    #[error("weight '{name}' shape mismatch: expected {expected:?}, got {got:?}")]
    WeightShapeMismatch {
        /// Name of the weight.
        name: String,
        /// Shape that was required.
        expected: Vec<usize>,
        /// Shape that was actually found.
        got: Vec<usize>,
    },
    /// A layer index lies outside the half-open range `[0, n_layers)`.
    #[error("layer index {idx} is out of range [0, {n_layers})")]
    LayerIndexOutOfRange {
        /// The rejected layer index.
        idx: usize,
        /// Number of layers (exclusive upper bound for `idx`).
        n_layers: usize,
    },
    /// `hidden_dim` is not divisible by `n_heads`, so there is no integral
    /// per-head dimension.
    #[error(
        "head dimension mismatch: hidden_dim={hidden_dim} must be divisible by n_heads={n_heads}"
    )]
    HeadDimMismatch {
        /// Hidden dimension that must be split across heads.
        hidden_dim: usize,
        /// Number of heads.
        n_heads: usize,
    },
    /// The KV cache length `got` differs from the expected `past_len`.
    #[error("KV cache length mismatch: expected past_len={past_len}, got {got}")]
    KvCacheLengthMismatch {
        /// Expected number of past positions.
        past_len: usize,
        /// Length actually observed.
        got: usize,
    },
    /// The total sequence length exceeds `max_position_embeddings`.
    #[error("sequence too long: total={total_len} > max_position_embeddings={max_pos}")]
    SequenceTooLong {
        /// Total length that was checked.
        total_len: usize,
        /// Maximum number of positions (`max_position_embeddings`).
        max_pos: usize,
    },
    /// A BPE merge rule refers to tokens `a` and `b` that are not both in the
    /// vocabulary.
    #[error("invalid BPE merge pair: tokens {a} and {b} not both present in vocabulary")]
    InvalidMergePair {
        /// Left token id of the merge pair.
        a: u32,
        /// Right token id of the merge pair.
        b: u32,
    },
    /// A vocabulary size differs from the expected size.
    #[error("vocab size mismatch: expected {expected}, got {got}")]
    VocabSizeMismatch {
        /// Vocabulary size that was required.
        expected: usize,
        /// Vocabulary size actually found.
        got: usize,
    },
    /// Grouped-query attention constraint: `n_heads` must be divisible by
    /// `n_kv_heads`.
    #[error(
        "GQA constraint violated: n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}"
    )]
    GqaHeadMismatch {
        /// Number of query heads.
        n_heads: usize,
        /// Number of key/value heads.
        n_kv_heads: usize,
    },
    /// A weight's data length does not equal the product of its declared shape.
    #[error("weight data length {data_len} does not match shape {shape:?} product {expected}")]
    WeightDataLengthMismatch {
        /// Length of the supplied data.
        data_len: usize,
        /// Declared shape.
        shape: Vec<usize>,
        /// Product of `shape`, i.e. the required length.
        expected: usize,
    },
    /// An unexpected internal failure; `msg` describes it.
    #[error("internal error: {msg}")]
    Internal {
        /// Free-form description of the failure.
        msg: String,
    },
}
/// Shorthand for a [`std::result::Result`] whose error type is [`LmError`].
///
/// Written with the fully qualified `std::result::Result` path so the alias is
/// unaffected by any local `Result` alias in scope.
pub type LmResult<T> = std::result::Result<T, LmError>;
#[cfg(test)]
mod tests {
    use super::*;

    // These tests pin the *entire* rendered message. Substring checks such as
    // `contains("5")` would still pass if a field were rendered in the wrong
    // slot or dropped, as long as the digit appeared elsewhere in the text.

    #[test]
    fn error_display_dimension_mismatch() {
        let e = LmError::DimensionMismatch {
            expected: 4,
            got: 3,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 4, got 3");
    }

    #[test]
    fn error_display_invalid_config() {
        let e = LmError::InvalidConfig {
            msg: "n_heads must be positive".into(),
        };
        assert_eq!(e.to_string(), "invalid config: n_heads must be positive");
    }

    #[test]
    fn error_display_empty_input() {
        let e = LmError::EmptyInput {
            context: "token_ids",
        };
        assert_eq!(e.to_string(), "empty input: token_ids");
    }

    #[test]
    fn error_display_weight_not_found() {
        let e = LmError::WeightNotFound {
            name: "layer.0.attn.w_q".into(),
        };
        assert_eq!(e.to_string(), "weight 'layer.0.attn.w_q' not found");
    }

    #[test]
    fn error_display_weight_shape_mismatch() {
        let e = LmError::WeightShapeMismatch {
            name: "embed".into(),
            expected: vec![128, 64],
            got: vec![64, 128],
        };
        // Also checks ordering: `expected` must precede `got`.
        assert_eq!(
            e.to_string(),
            "weight 'embed' shape mismatch: expected [128, 64], got [64, 128]"
        );
    }

    #[test]
    fn error_display_out_of_vocab() {
        let e = LmError::OutOfVocab { token: 99999 };
        assert_eq!(e.to_string(), "out-of-vocabulary token id 99999");
    }

    #[test]
    fn error_display_tokenizer_uninitialized() {
        let e = LmError::TokenizerUninitialized;
        assert_eq!(e.to_string(), "tokenizer is not initialised");
    }

    #[test]
    fn error_display_utf8_decode() {
        let e = LmError::Utf8Decode { token: 7 };
        assert_eq!(e.to_string(), "UTF-8 decode error for token 7");
    }

    #[test]
    fn error_display_layer_index_out_of_range() {
        let e = LmError::LayerIndexOutOfRange {
            idx: 12,
            n_layers: 12,
        };
        assert_eq!(e.to_string(), "layer index 12 is out of range [0, 12)");
    }

    #[test]
    fn error_display_head_dim_mismatch() {
        let e = LmError::HeadDimMismatch {
            hidden_dim: 100,
            n_heads: 12,
        };
        assert_eq!(
            e.to_string(),
            "head dimension mismatch: hidden_dim=100 must be divisible by n_heads=12"
        );
    }

    #[test]
    fn error_display_kv_cache_length_mismatch() {
        let e = LmError::KvCacheLengthMismatch {
            past_len: 8,
            got: 7,
        };
        assert_eq!(
            e.to_string(),
            "KV cache length mismatch: expected past_len=8, got 7"
        );
    }

    #[test]
    fn error_display_sequence_too_long() {
        let e = LmError::SequenceTooLong {
            total_len: 2048,
            max_pos: 1024,
        };
        assert_eq!(
            e.to_string(),
            "sequence too long: total=2048 > max_position_embeddings=1024"
        );
    }

    #[test]
    fn error_display_gqa_mismatch() {
        let e = LmError::GqaHeadMismatch {
            n_heads: 32,
            n_kv_heads: 5,
        };
        assert_eq!(
            e.to_string(),
            "GQA constraint violated: n_heads=32 must be divisible by n_kv_heads=5"
        );
    }

    #[test]
    fn error_display_invalid_merge_pair() {
        let e = LmError::InvalidMergePair { a: 10, b: 20 };
        assert_eq!(
            e.to_string(),
            "invalid BPE merge pair: tokens 10 and 20 not both present in vocabulary"
        );
    }

    #[test]
    fn error_display_vocab_size_mismatch() {
        let e = LmError::VocabSizeMismatch {
            expected: 32000,
            got: 32001,
        };
        assert_eq!(e.to_string(), "vocab size mismatch: expected 32000, got 32001");
    }

    #[test]
    fn error_display_weight_data_length_mismatch() {
        let e = LmError::WeightDataLengthMismatch {
            data_len: 10,
            shape: vec![3, 4],
            expected: 12,
        };
        assert_eq!(
            e.to_string(),
            "weight data length 10 does not match shape [3, 4] product 12"
        );
    }

    #[test]
    fn error_is_std_error() {
        let e: Box<dyn std::error::Error> = Box::new(LmError::Internal { msg: "test".into() });
        assert_eq!(e.to_string(), "internal error: test");
        // No variant wraps an underlying cause.
        assert!(e.source().is_none());
    }

    #[test]
    fn error_clone_and_eq() {
        let a = LmError::EmptyInput {
            context: "token_ids",
        };
        let b = a.clone();
        assert_eq!(a, b);
        // Equality must compare payloads and variants, not just the variant tag.
        assert_ne!(a, LmError::EmptyInput { context: "logits" });
        assert_ne!(a, LmError::TokenizerUninitialized);
    }
}