use thiserror::Error;
#[derive(Error, Debug)]
pub enum BitTTTError {
#[error("Shape mismatch: {0}")]
ShapeMismatch(String),
#[error("Device error: {0}")]
DeviceError(String),
#[error("Kernel error: {0}")]
KernelError(String),
#[error("Storage error: {0}")]
StorageError(String),
#[error("Feature not enabled: {0}")]
FeatureNotEnabled(String),
#[error("Candle error: {0}")]
Candle(#[from] candle_core::Error),
}
/// Shorthand for a `Result` whose error type is [`BitTTTError`].
pub type BitResult<T> = ::std::result::Result<T, BitTTTError>;
impl BitTTTError {
pub fn shape_mismatch(msg: impl Into<String>) -> Self {
Self::ShapeMismatch(msg.into())
}
pub fn device_error(msg: impl Into<String>) -> Self {
Self::DeviceError(msg.into())
}
pub fn kernel_error(msg: impl Into<String>) -> Self {
Self::KernelError(msg.into())
}
pub fn storage_error(msg: impl Into<String>) -> Self {
Self::StorageError(msg.into())
}
pub fn feature_not_enabled(feature: impl Into<String>) -> Self {
Self::FeatureNotEnabled(feature.into())
}
}
impl From<BitTTTError> for candle_core::Error {
fn from(err: BitTTTError) -> Self {
candle_core::Error::Msg(err.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each constructor's Display output starts with its variant's prefix.
    #[test]
    fn test_error_messages() {
        let err = BitTTTError::shape_mismatch("Input [4, 8] vs Weight [8, 16]");
        assert!(err.to_string().contains("Shape mismatch"));
        let err = BitTTTError::device_error("Expected CUDA device");
        assert!(err.to_string().contains("Device error"));
        let err = BitTTTError::kernel_error("Kernel not found");
        assert!(err.to_string().contains("Kernel error"));
    }

    // The full message is "<prefix>: <payload>" for every string-carrying
    // variant, including the two constructors not covered above.
    #[test]
    fn test_exact_messages_include_payload() {
        assert_eq!(
            BitTTTError::shape_mismatch("Input [4, 8] vs Weight [8, 16]").to_string(),
            "Shape mismatch: Input [4, 8] vs Weight [8, 16]"
        );
        assert_eq!(
            BitTTTError::device_error("Expected CUDA device").to_string(),
            "Device error: Expected CUDA device"
        );
        assert_eq!(
            BitTTTError::kernel_error("Kernel not found").to_string(),
            "Kernel error: Kernel not found"
        );
        assert_eq!(
            BitTTTError::storage_error("buffer too small").to_string(),
            "Storage error: buffer too small"
        );
        assert_eq!(
            BitTTTError::feature_not_enabled("cuda").to_string(),
            "Feature not enabled: cuda"
        );
    }

    // Constructors accept owned `String`s as well as `&str`.
    #[test]
    fn test_constructors_accept_owned_strings() {
        let err = BitTTTError::storage_error(String::from("owned"));
        assert_eq!(err.to_string(), "Storage error: owned");
    }

    #[test]
    fn test_conversion_to_candle_error() {
        let bit_err = BitTTTError::shape_mismatch("test");
        let candle_err: candle_core::Error = bit_err.into();
        assert!(candle_err.to_string().contains("Shape mismatch"));
    }
}