use thiserror::Error;
#[derive(Debug, Clone, Error, PartialEq)]
pub enum SpeculativeDecodingError {
#[error(
"vocab size mismatch between draft ({draft}) and target ({target}) \
speculative-decoding models"
)]
VocabMismatch { draft: usize, target: usize },
#[error("distribution row width mismatch: expected {expected}, got {got}")]
DistributionWidthMismatch { expected: usize, got: usize },
#[error(
"draft proposal shape mismatch: tokens={tokens}, token_logprobs={logprobs}, \
distributions={distributions}"
)]
DraftShapeMismatch {
tokens: usize,
logprobs: usize,
distributions: usize,
},
#[error("target verification shape mismatch: expected {expected} rows (k+1), got {got}")]
TargetShapeMismatch { expected: usize, got: usize },
#[error("invalid configuration: {0}")]
InvalidConfig(String),
#[error("speculative decoding was invoked with an empty prefix")]
EmptyPrefix,
#[error("token id {token} is out of range for vocabulary size {vocab_size}")]
TokenOutOfRange { token: usize, vocab_size: usize },
#[error("no mass left in adjusted distribution and target fallback is also zero")]
DegenerateDistribution,
#[error("model error: {0}")]
ModelError(String),
}
/// Convenience alias for a `Result` whose error type is
/// [`SpeculativeDecodingError`].
pub type SpeculativeDecodingResult<T> = Result<T, SpeculativeDecodingError>;
/// Bridges speculative-decoding failures into the crate-wide error type.
///
/// The error is rendered to its `Display` text and wrapped in
/// `TrustformerError::CompilationError`; only that text is forwarded.
impl From<SpeculativeDecodingError> for crate::error::TrustformerError {
    fn from(err: SpeculativeDecodingError) -> Self {
        // NOTE(review): `CompilationError` looks like a loose fit for decoding
        // failures -- confirm whether a more specific variant exists.
        let message = err.to_string();
        Self::CompilationError(message)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The vocab-mismatch message mentions both sizes and the word "vocab".
    #[test]
    fn display_contains_context() {
        let rendered = SpeculativeDecodingError::VocabMismatch {
            draft: 10,
            target: 20,
        }
        .to_string();

        for needle in ["10", "20", "vocab"].iter() {
            assert!(rendered.contains(*needle));
        }
    }

    /// Converting into the crate error keeps the original message text.
    #[test]
    fn bridges_into_trustformer_error() {
        let source = SpeculativeDecodingError::InvalidConfig(String::from("k must be > 0"));
        let bridged = crate::error::TrustformerError::from(source);
        assert!(bridged.to_string().contains("k must be > 0"));
    }

    /// The empty-prefix variant renders its own dedicated message.
    #[test]
    fn empty_prefix_is_distinct() {
        let rendered = SpeculativeDecodingError::EmptyPrefix.to_string();
        assert!(rendered.contains("empty prefix"));
    }
}