use thiserror::Error;
#[derive(Debug, Clone, Error, PartialEq)]
pub enum SparseAttentionError {
#[error("invalid window_size: must be > 0, got {0}")]
InvalidWindowSize(usize),
#[error("invalid sequence length: must be > 0, got {0}")]
InvalidSequenceLength(usize),
#[error("global token index {index} is out of bounds for sequence length {seq_len}")]
InvalidGlobalIndices { index: usize, seq_len: usize },
#[error("dimension mismatch: {context} — expected {expected}, got {got}")]
DimensionMismatch {
context: String,
expected: usize,
got: usize,
},
#[error("numerical instability: softmax denominator is zero at position {position}")]
NumericalInstability { position: usize },
}
/// Convenience alias for results whose error type is [`SparseAttentionError`].
pub type SparseAttentionResult<T> = std::result::Result<T, SparseAttentionError>;
/// Bridges a sparse-attention failure into the crate-wide error type.
///
/// The error is rendered through its `Display` implementation and wrapped in
/// the `CompilationError` variant. Only the message text survives the
/// conversion; the structured fields (indices, sizes, context) are not
/// recoverable from the resulting value.
impl From<SparseAttentionError> for crate::error::TrustformerError {
    fn from(sparse_err: SparseAttentionError) -> Self {
        // NOTE(review): `CompilationError` is used as the landing variant for
        // every sparse-attention failure; confirm no closer variant exists
        // for runtime/validation errors.
        let rendered = sparse_err.to_string();
        Self::CompilationError(rendered)
    }
}
#[cfg(test)]
mod tests {
    //! Pins the exact user-facing message of every `SparseAttentionError`
    //! variant and the behaviour of the bridge into `TrustformerError`.
    use super::*;

    #[test]
    fn window_size_error_message() {
        let err = SparseAttentionError::InvalidWindowSize(0);
        assert_eq!(err.to_string(), "invalid window_size: must be > 0, got 0");
    }

    #[test]
    fn sequence_length_error_message() {
        let err = SparseAttentionError::InvalidSequenceLength(0);
        assert_eq!(
            err.to_string(),
            "invalid sequence length: must be > 0, got 0"
        );
    }

    #[test]
    fn bridges_into_trustformer_error() {
        let err = SparseAttentionError::InvalidWindowSize(0);
        let bridged: crate::error::TrustformerError = err.into();
        // The conversion wraps the rendered message in `CompilationError`.
        assert!(matches!(
            &bridged,
            crate::error::TrustformerError::CompilationError(msg)
                if msg.contains("window_size")
        ));
        assert!(bridged.to_string().contains("window_size"));
    }

    #[test]
    fn global_index_error_message() {
        let err = SparseAttentionError::InvalidGlobalIndices {
            index: 42,
            seq_len: 16,
        };
        assert_eq!(
            err.to_string(),
            "global token index 42 is out of bounds for sequence length 16"
        );
    }

    #[test]
    fn dimension_mismatch_message() {
        let err = SparseAttentionError::DimensionMismatch {
            context: "query rows".into(),
            expected: 32,
            got: 16,
        };
        assert_eq!(
            err.to_string(),
            "dimension mismatch: query rows — expected 32, got 16"
        );
    }

    #[test]
    fn numerical_instability_message() {
        let err = SparseAttentionError::NumericalInstability { position: 7 };
        assert_eq!(
            err.to_string(),
            "numerical instability: softmax denominator is zero at position 7"
        );
    }

    #[test]
    fn errors_have_no_underlying_source() {
        // No variant carries a `#[source]`/`#[from]` field, so the error
        // chain must terminate here.
        let err = SparseAttentionError::NumericalInstability { position: 0 };
        assert!(std::error::Error::source(&err).is_none());
    }

    #[test]
    fn clone_compares_equal() {
        let err = SparseAttentionError::DimensionMismatch {
            context: "keys".into(),
            expected: 8,
            got: 4,
        };
        assert_eq!(err.clone(), err);
    }
}