use thiserror::Error;
/// Convenience result alias for fallible embedding operations.
///
/// The error type defaults to [`EmbeddingError`], so `Result<T>` keeps its
/// original meaning. It can also be overridden (`Result<T, OtherError>`),
/// which keeps code compiling when this alias shadows
/// `std::result::Result` through a glob import.
pub type Result<T, E = EmbeddingError> = std::result::Result<T, E>;
/// All errors the embedding code in this module can report.
///
/// Variants whose field is marked `#[from]` get a generated `From` impl, so
/// the wrapped library error converts automatically when propagated with `?`.
/// Several of the other variants have helper constructors in the
/// `impl EmbeddingError` block.
#[derive(Error, Debug)]
pub enum EmbeddingError {
    // ----- Errors wrapped from underlying libraries -----
    /// Error returned by ONNX Runtime (`ort`).
    #[error("ONNX Runtime error: {0}")]
    OnnxRuntime(#[from] ort::Error),

    /// Error returned by the `tokenizers` crate.
    #[error("Tokenizer error: {0}")]
    Tokenizer(#[from] tokenizers::tokenizer::Error),

    /// Standard-library I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Error returned by the `reqwest` HTTP client.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),

    /// JSON (de)serialization error from `serde_json`.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// Array shape error from `ndarray`.
    #[error("Shape error: {0}")]
    Shape(#[from] ndarray::ShapeError),

    // ----- Model / tokenizer lookup -----
    /// No model was found at `path`.
    #[error("Model not found: {path}")]
    ModelNotFound { path: String },

    /// No tokenizer was found at `path`.
    #[error("Tokenizer not found: {path}")]
    TokenizerNotFound { path: String },

    /// The model is not in a usable format; `reason` explains why.
    #[error("Invalid model format: {reason}")]
    InvalidModel { reason: String },

    // ----- Input validation -----
    /// A dimension differed from the one that was expected.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    /// The caller supplied empty input.
    #[error("Empty input provided")]
    EmptyInput,

    /// A batch of `size` items exceeded the allowed maximum `max`.
    #[error("Batch size {size} exceeds maximum {max}")]
    BatchSizeExceeded { size: usize, max: usize },

    /// A sequence of `length` exceeded the allowed maximum `max`.
    #[error("Sequence length {length} exceeds maximum {max}")]
    SequenceTooLong { length: usize, max: usize },

    // ----- Download / cache -----
    /// Downloading a model failed; `reason` explains why.
    #[error("Failed to download model: {reason}")]
    DownloadFailed { reason: String },

    /// A cache operation failed; `reason` explains why.
    #[error("Cache error: {reason}")]
    CacheError { reason: String },

    /// A computed checksum differed from the expected one.
    #[error("Checksum mismatch: expected {expected}, got {actual}")]
    ChecksumMismatch { expected: String, actual: String },

    // ----- Configuration -----
    /// The supplied configuration was rejected; `reason` explains why.
    #[error("Invalid configuration: {reason}")]
    InvalidConfig { reason: String },

    /// The requested execution provider cannot be used.
    #[error("Execution provider not available: {provider}")]
    ExecutionProviderNotAvailable { provider: String },

    // ----- Integration -----
    /// Error attributed to RuVector, carried as a plain message.
    #[error("RuVector error: {0}")]
    RuVector(String),

    // ----- GPU -----
    /// GPU initialization failed.
    #[error("GPU initialization failed: {reason}")]
    GpuInitFailed { reason: String },

    /// A named GPU operation failed.
    #[error("GPU operation failed: {operation} - {reason}")]
    GpuOperationFailed { operation: String, reason: String },

    /// Compiling the named shader failed.
    #[error("Shader compilation failed: {shader} - {reason}")]
    ShaderCompilationFailed { shader: String, reason: String },

    /// A GPU buffer operation failed.
    #[error("GPU buffer error: {reason}")]
    GpuBufferError { reason: String },

    /// No usable GPU is available.
    #[error("GPU not available: {reason}")]
    GpuNotAvailable { reason: String },

    // ----- Catch-all -----
    /// Free-form error; its `Display` output is exactly the carried message.
    #[error("{0}")]
    Other(String),
}
impl EmbeddingError {
    /// Builds [`EmbeddingError::ModelNotFound`] for the given model path.
    #[must_use]
    pub fn model_not_found(path: impl Into<String>) -> Self {
        Self::ModelNotFound { path: path.into() }
    }

    /// Builds [`EmbeddingError::TokenizerNotFound`] for the given tokenizer path.
    #[must_use]
    pub fn tokenizer_not_found(path: impl Into<String>) -> Self {
        Self::TokenizerNotFound { path: path.into() }
    }

    /// Builds [`EmbeddingError::InvalidModel`] with a human-readable reason.
    #[must_use]
    pub fn invalid_model(reason: impl Into<String>) -> Self {
        Self::InvalidModel {
            reason: reason.into(),
        }
    }

    /// Builds [`EmbeddingError::DimensionMismatch`] from the expected and actual sizes.
    #[must_use]
    pub fn dimension_mismatch(expected: usize, actual: usize) -> Self {
        Self::DimensionMismatch { expected, actual }
    }

    /// Builds [`EmbeddingError::BatchSizeExceeded`] for a batch of `size`
    /// items when at most `max` are allowed.
    #[must_use]
    pub fn batch_size_exceeded(size: usize, max: usize) -> Self {
        Self::BatchSizeExceeded { size, max }
    }

    /// Builds [`EmbeddingError::SequenceTooLong`] for a sequence of `length`
    /// when at most `max` is allowed.
    #[must_use]
    pub fn sequence_too_long(length: usize, max: usize) -> Self {
        Self::SequenceTooLong { length, max }
    }

    /// Builds [`EmbeddingError::DownloadFailed`] with a human-readable reason.
    #[must_use]
    pub fn download_failed(reason: impl Into<String>) -> Self {
        Self::DownloadFailed {
            reason: reason.into(),
        }
    }

    /// Builds [`EmbeddingError::CacheError`] with a human-readable reason.
    #[must_use]
    pub fn cache_error(reason: impl Into<String>) -> Self {
        Self::CacheError {
            reason: reason.into(),
        }
    }

    /// Builds [`EmbeddingError::ChecksumMismatch`] from the expected and the
    /// actually computed checksum strings.
    #[must_use]
    pub fn checksum_mismatch(expected: impl Into<String>, actual: impl Into<String>) -> Self {
        Self::ChecksumMismatch {
            expected: expected.into(),
            actual: actual.into(),
        }
    }

    /// Builds [`EmbeddingError::InvalidConfig`] with a human-readable reason.
    #[must_use]
    pub fn invalid_config(reason: impl Into<String>) -> Self {
        Self::InvalidConfig {
            reason: reason.into(),
        }
    }

    /// Builds [`EmbeddingError::ExecutionProviderNotAvailable`] naming the provider.
    #[must_use]
    pub fn execution_provider_not_available(provider: impl Into<String>) -> Self {
        Self::ExecutionProviderNotAvailable {
            provider: provider.into(),
        }
    }

    /// Builds [`EmbeddingError::RuVector`] carrying the given message.
    #[must_use]
    pub fn ruvector(msg: impl Into<String>) -> Self {
        Self::RuVector(msg.into())
    }

    /// Builds the catch-all [`EmbeddingError::Other`] carrying the given message.
    #[must_use]
    pub fn other(msg: impl Into<String>) -> Self {
        Self::Other(msg.into())
    }

    /// Builds [`EmbeddingError::GpuInitFailed`] with a human-readable reason.
    #[must_use]
    pub fn gpu_init_failed(reason: impl Into<String>) -> Self {
        Self::GpuInitFailed {
            reason: reason.into(),
        }
    }

    /// Builds [`EmbeddingError::GpuOperationFailed`] naming the failed
    /// operation and the reason.
    #[must_use]
    pub fn gpu_operation_failed(operation: impl Into<String>, reason: impl Into<String>) -> Self {
        Self::GpuOperationFailed {
            operation: operation.into(),
            reason: reason.into(),
        }
    }

    /// Builds [`EmbeddingError::ShaderCompilationFailed`] naming the shader
    /// and the reason.
    #[must_use]
    pub fn shader_compilation_failed(shader: impl Into<String>, reason: impl Into<String>) -> Self {
        Self::ShaderCompilationFailed {
            shader: shader.into(),
            reason: reason.into(),
        }
    }

    /// Builds [`EmbeddingError::GpuBufferError`] with a human-readable reason.
    #[must_use]
    pub fn gpu_buffer_error(reason: impl Into<String>) -> Self {
        Self::GpuBufferError {
            reason: reason.into(),
        }
    }

    /// Builds [`EmbeddingError::GpuNotAvailable`] with a human-readable reason.
    #[must_use]
    pub fn gpu_not_available(reason: impl Into<String>) -> Self {
        Self::GpuNotAvailable {
            reason: reason.into(),
        }
    }

    /// Returns `true` for the GPU-specific variants.
    ///
    /// These are `GpuInitFailed`, `GpuOperationFailed`,
    /// `ShaderCompilationFailed`, `GpuBufferError` and `GpuNotAvailable`.
    /// Note that `ExecutionProviderNotAvailable` is *not* counted as a GPU
    /// error.
    #[must_use]
    pub fn is_gpu_error(&self) -> bool {
        matches!(
            self,
            Self::GpuInitFailed { .. }
                | Self::GpuOperationFailed { .. }
                | Self::ShaderCompilationFailed { .. }
                | Self::GpuBufferError { .. }
                | Self::GpuNotAvailable { .. }
        )
    }

    /// Returns `true` for HTTP, download and cache failures.
    ///
    /// These are presumably transient, so a caller may choose to retry. Every
    /// other variant (including `Io` and `ChecksumMismatch`) returns `false`.
    #[must_use]
    pub fn is_recoverable(&self) -> bool {
        matches!(
            self,
            Self::Http(_) | Self::DownloadFailed { .. } | Self::CacheError { .. }
        )
    }

    /// Returns `true` for `InvalidConfig`, `InvalidModel` and `DimensionMismatch`.
    ///
    /// These variants are grouped here as configuration problems.
    #[must_use]
    pub fn is_config_error(&self) -> bool {
        matches!(
            self,
            Self::InvalidConfig { .. }
                | Self::InvalidModel { .. }
                | Self::DimensionMismatch { .. }
        )
    }
}