use thiserror::Error;
/// Error enum for this module, grouping failures into shape/type validation,
/// graph structure, backend execution, memory, model loading, I/O,
/// runtime/scheduling state, telemetry, and internal errors.
///
/// Each variant's `Display` text is defined by its `#[error(...)]` attribute,
/// which `thiserror`'s `Error` derive turns into a `Display` impl. The crate's
/// [`Result`] alias uses this type as its error.
#[derive(Debug, Error)]
pub enum SapientError {
    // ---- Shape / type validation ----
    /// A shape differed from the required one; both are printed in list form
    /// (`{:?}` formatting).
    #[error("Shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },
    /// A rank differed from the expected rank.
    #[error("Rank mismatch: expected {expected}, got {got}")]
    RankMismatch { expected: usize, got: usize },
    /// A type differed from the expected one; both types are carried as names.
    #[error("Type mismatch: expected {expected}, got {got}")]
    TypeMismatch { expected: String, got: String },
    /// Two shapes (`lhs`, `rhs`) could not be broadcast together.
    #[error("Incompatible shapes for broadcasting: {lhs:?} and {rhs:?}")]
    BroadcastError { lhs: Vec<usize>, rhs: Vec<usize> },

    // ---- Graph structure ----
    /// The graph contains a cycle, so (per the message) it cannot be executed.
    #[error("Graph contains a cycle — execution is impossible")]
    CyclicGraph,
    /// A referenced node was not found. The id is printed with `Debug`
    /// formatting (`{0:?}`), so it appears quoted in the message.
    #[error("Node {0:?} not found in graph")]
    NodeNotFound(String),
    /// Graph validation failed; the payload describes why.
    #[error("Graph validation failed: {0}")]
    InvalidGraph(String),
    /// Shape inference failed for the operation named `op`, for `reason`.
    #[error("Shape inference failed for op '{op}': {reason}")]
    ShapeInferenceFailed { op: String, reason: String },

    // ---- Backend execution ----
    /// The named `backend` does not support `op`.
    /// Usually built via [`SapientError::unsupported_op`].
    #[error("Backend '{backend}' does not support op '{op}'")]
    UnsupportedOp { backend: String, op: String },
    /// The named `backend` reported a failure described by `message`.
    /// Usually built via [`SapientError::backend`].
    #[error("Backend error from '{backend}': {message}")]
    BackendError { backend: String, message: String },
    /// No backend suitable for execution was found.
    #[error("No suitable backend found for execution")]
    NoBackendAvailable,

    // ---- Memory ----
    /// An allocation of `bytes` bytes with alignment `align` failed.
    #[error("Allocation failed: requested {bytes} bytes (alignment {align})")]
    AllocationFailed { bytes: usize, align: usize },
    /// A buffer's size did not match the expected byte count.
    // NOTE(review): the message labels only `expected` with "bytes"; `got` is
    // presumably also a byte count — confirm and consider aligning the text.
    #[error("Buffer size mismatch: expected {expected} bytes, got {got}")]
    BufferSizeMismatch { expected: usize, got: usize },
    /// The memory pool has no capacity left; the message suggests enlarging it.
    #[error("Memory pool exhausted — consider increasing pool capacity")]
    PoolExhausted,

    // ---- Model loading ----
    /// Parsing an ONNX model failed; the payload describes the failure.
    #[error("ONNX parse error: {0}")]
    OnnxParseError(String),
    /// Parsing a GGUF model failed; the payload describes the failure.
    #[error("GGUF parse error: {0}")]
    GgufParseError(String),
    /// Parsing a Safetensors model failed; the payload describes the failure.
    #[error("Safetensors parse error: {0}")]
    SafetensorsParseError(String),
    /// The model format is not supported; the payload identifies the format.
    #[error("Unsupported model format: {0}")]
    UnsupportedFormat(String),
    /// No model exists at the given path (the payload).
    #[error("Model not found at path '{0}'")]
    ModelNotFound(String),

    // ---- I/O ----
    /// Wraps a `std::io::Error`. `#[from]` generates `From<std::io::Error>`, so
    /// `.into()` / `?` convert I/O errors into this variant.
    // NOTE(review): `#[from]` also makes the wrapped error this variant's
    // `source()`, while the message embeds it via `{0}`; reporters that print
    // the whole source chain will show the I/O text twice. Confirm intended.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    // ---- Runtime / scheduling ----
    /// A request exceeded its deadline.
    #[error("Request timed out (deadline exceeded)")]
    DeadlineExceeded,
    /// The batch scheduler has been shut down.
    #[error("Batch scheduler is shut down")]
    SchedulerShutdown,
    /// The runtime was used before initialization; the message directs users to
    /// call `Session::new()` first.
    #[error("Runtime is not initialized — call Session::new() first")]
    UninitializedRuntime,

    // ---- Telemetry ----
    /// Exporting telemetry failed; the payload describes the failure.
    #[error("Telemetry export failed: {0}")]
    TelemetryError(String),

    // ---- Internal ----
    /// An internal error not covered by the other variants.
    /// Usually built via [`SapientError::internal`].
    #[error("Internal error: {0}")]
    Internal(String),
}
/// Crate-wide result alias whose error type defaults to [`SapientError`].
///
/// `Result<T>` expands to `std::result::Result<T, SapientError>`, exactly as
/// before. The optional second parameter lets code that imports this alias
/// (which shadows the prelude `Result`) still write `Result<T, OtherError>`
/// without spelling out `std::result::Result`.
pub type Result<T, E = SapientError> = std::result::Result<T, E>;
impl SapientError {
pub fn backend(backend: impl Into<String>, message: impl Into<String>) -> Self {
Self::BackendError {
backend: backend.into(),
message: message.into(),
}
}
pub fn unsupported_op(backend: impl Into<String>, op: impl Into<String>) -> Self {
Self::UnsupportedOp {
backend: backend.into(),
op: op.into(),
}
}
pub fn internal(msg: impl Into<String>) -> Self {
Self::Internal(msg.into())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Struct-variant `Display` output interpolates both shapes in list form.
    #[test]
    fn error_display() {
        let e = SapientError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![2, 4],
        };
        let s = e.to_string();
        assert!(s.contains("Shape mismatch"));
        assert!(s.contains("[2, 3]"));
        // Pin the full message so accidental wording changes are caught.
        assert_eq!(s, "Shape mismatch: expected [2, 3], got [2, 4]");
    }

    /// Converting an `io::Error` yields the `Io` variant, keeps the original
    /// message in `Display`, and exposes the wrapped error via `source()`.
    #[test]
    fn from_io_error() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let e: SapientError = io_err.into();
        assert!(matches!(e, SapientError::Io(_)));
        assert_eq!(e.to_string(), "IO error: file missing");
        // Fully qualified so it cannot collide with the `thiserror::Error`
        // derive brought in by `use super::*`.
        assert!(std::error::Error::source(&e).is_some());
    }

    /// `SapientError::backend` fills both fields and formats them as documented.
    #[test]
    fn backend_helper() {
        let e = SapientError::backend("cuda", String::from("out of memory"));
        match &e {
            SapientError::BackendError { backend, message } => {
                assert_eq!(backend, "cuda");
                assert_eq!(message, "out of memory");
            }
            other => panic!("expected BackendError, got {:?}", other),
        }
        assert_eq!(e.to_string(), "Backend error from 'cuda': out of memory");
    }

    /// `SapientError::unsupported_op` fills both fields and formats them.
    #[test]
    fn unsupported_op_helper() {
        let e = SapientError::unsupported_op("cpu", "Conv3D");
        assert!(matches!(
            &e,
            SapientError::UnsupportedOp { backend, op } if backend == "cpu" && op == "Conv3D"
        ));
        assert_eq!(e.to_string(), "Backend 'cpu' does not support op 'Conv3D'");
    }

    /// `SapientError::internal` accepts owned and borrowed strings alike.
    #[test]
    fn internal_helper() {
        let borrowed = SapientError::internal("invariant broken");
        assert!(matches!(&borrowed, SapientError::Internal(msg) if msg == "invariant broken"));
        assert_eq!(borrowed.to_string(), "Internal error: invariant broken");

        let owned = SapientError::internal(format!("slot {}", 7));
        assert_eq!(owned.to_string(), "Internal error: slot 7");
    }
}