use std::fmt;
/// Failures that can occur while validating or quantizing an embedding vector.
#[derive(Debug, Clone, PartialEq)]
pub enum QuantizationError {
    /// The embedding's length differs from the required dimensionality.
    DimensionMismatch {
        /// Number of dimensions that was required.
        expected: usize,
        /// Number of dimensions that was actually supplied.
        actual: usize,
    },
    /// A component of the embedding is not finite.
    NonFiniteValue {
        /// Position of the offending component.
        index: usize,
        /// The offending component itself.
        value: f32,
    },
    /// The embedding has no components at all.
    EmptyEmbedding,
    /// Calibration was required but has not been initialized.
    CalibrationNotInitialized,
    /// A scale factor was rejected as invalid.
    InvalidScale {
        /// The rejected scale factor.
        scale: f32,
    },
    /// A computation overflowed.
    ComputationOverflow,
}

impl fmt::Display for QuantizationError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All payload fields are `Copy`, so they can be bound by value straight from `*self`.
        match *self {
            QuantizationError::DimensionMismatch { expected, actual } => {
                write!(f, "Dimension mismatch: expected {expected}, got {actual}")
            }
            QuantizationError::NonFiniteValue { index, value } => {
                write!(f, "Non-finite value {value} at index {index}")
            }
            QuantizationError::EmptyEmbedding => f.write_str("Empty embedding"),
            QuantizationError::CalibrationNotInitialized => {
                f.write_str("Calibration not initialized")
            }
            QuantizationError::InvalidScale { scale } => {
                write!(f, "Invalid scale factor: {scale}")
            }
            QuantizationError::ComputationOverflow => f.write_str("Computation overflow"),
        }
    }
}

impl std::error::Error for QuantizationError {}
/// Checks that `embedding` holds exactly `expected_dims` finite components.
///
/// The checks run in the order listed under `# Errors`, and the first failure is returned.
///
/// # Errors
///
/// - [`QuantizationError::EmptyEmbedding`] if `embedding` is empty, whatever `expected_dims` is.
/// - [`QuantizationError::DimensionMismatch`] if `embedding.len() != expected_dims`.
/// - [`QuantizationError::NonFiniteValue`] for the first component that is NaN or infinite,
///   carrying that component's index and value.
pub fn validate_embedding(
    embedding: &[f32],
    expected_dims: usize,
) -> Result<(), QuantizationError> {
    // Emptiness is checked before length so an empty slice gets the specific error instead of a
    // generic `DimensionMismatch { actual: 0, .. }` (which would make this variant unreachable
    // for every non-zero `expected_dims`).
    if embedding.is_empty() {
        return Err(QuantizationError::EmptyEmbedding);
    }
    if embedding.len() != expected_dims {
        return Err(QuantizationError::DimensionMismatch {
            expected: expected_dims,
            actual: embedding.len(),
        });
    }
    match embedding.iter().enumerate().find(|&(_, v)| !v.is_finite()) {
        Some((index, &value)) => Err(QuantizationError::NonFiniteValue { index, value }),
        None => Ok(()),
    }
}