use thiserror::Error;
#[derive(Error, Debug)]
pub enum XybridError {
#[error("Inference error: {0}")]
Inference(#[from] InferenceError),
#[error("Pipeline error: {0}")]
Pipeline(#[from] PipelineError),
#[error("Not found: {0}")]
NotFound(String),
#[error("Configuration error: {0}")]
Config(String),
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("Serialization error: {0}")]
Serialization(String),
}
#[derive(Error, Debug)]
pub enum InferenceError {
#[error("Model not loaded: {0}")]
ModelNotLoaded(String),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Backend error: {0}")]
Backend(String),
#[error("Preprocessing failed: {0}")]
Preprocessing(String),
#[error("Postprocessing failed: {0}")]
Postprocessing(String),
}
#[derive(Error, Debug)]
pub enum PipelineError {
#[error("Stage '{stage}' failed: {reason}")]
StageFailed {
stage: String,
reason: String,
},
#[error("Invalid target: {0}")]
InvalidTarget(String),
#[error("Provider error: {0}")]
Provider(String),
#[error("Policy denied: {0}")]
PolicyDenied(String),
#[error("Resolution error: {0}")]
Resolution(String),
}
/// Convenience alias for results whose error type is [`XybridError`].
pub type XybridResult<T> = std::result::Result<T, XybridError>;
impl From<crate::runtime_adapter::AdapterError> for XybridError {
    /// Maps a runtime-adapter failure onto the crate-wide error type.
    ///
    /// Missing models, I/O failures and serialization failures keep their
    /// own top-level variants. Every other adapter failure is reported as an
    /// [`InferenceError`] wrapped in [`XybridError::Inference`].
    fn from(err: crate::runtime_adapter::AdapterError) -> Self {
        use crate::runtime_adapter::AdapterError as Adapter;

        let inference = match err {
            // Failures that have a dedicated top-level variant.
            Adapter::ModelNotFound(name) => return Self::NotFound(name),
            Adapter::IOError(io_err) => return Self::Io(io_err),
            Adapter::SerializationError(msg) => return Self::Serialization(msg),
            // Everything else is an inference-stage failure.
            Adapter::ModelNotLoaded(name) => InferenceError::ModelNotLoaded(name),
            Adapter::InvalidInput(msg) => InferenceError::InvalidInput(msg),
            Adapter::InferenceFailed(msg) | Adapter::RuntimeError(msg) => {
                InferenceError::Backend(msg)
            }
            Adapter::AbortedForCloudFallback { reason } => {
                InferenceError::Backend(format!("aborted for cloud fallback: {reason}"))
            }
        };
        Self::Inference(inference)
    }
}
impl From<serde_json::Error> for XybridError {
fn from(e: serde_json::Error) -> Self {
XybridError::Serialization(e.to_string())
}
}
impl From<serde_yaml::Error> for XybridError {
fn from(e: serde_yaml::Error) -> Self {
XybridError::Serialization(e.to_string())
}
}
/// Convenience constructors that accept any string-like message.
///
/// Each is `#[must_use]`: building an error and then dropping it is always a
/// mistake (it was meant to be returned or propagated).
impl XybridError {
    /// Builds a [`XybridError::NotFound`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn not_found(msg: impl Into<String>) -> Self {
        XybridError::NotFound(msg.into())
    }

    /// Builds a [`XybridError::Config`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn config(msg: impl Into<String>) -> Self {
        XybridError::Config(msg.into())
    }

    /// Builds a [`XybridError::Serialization`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn serialization(msg: impl Into<String>) -> Self {
        XybridError::Serialization(msg.into())
    }
}
/// Convenience constructors that accept any string-like message.
///
/// Each is `#[must_use]`: building an error and then dropping it is always a
/// mistake (it was meant to be returned or propagated).
impl InferenceError {
    /// Builds an [`InferenceError::ModelNotLoaded`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn model_not_loaded(msg: impl Into<String>) -> Self {
        InferenceError::ModelNotLoaded(msg.into())
    }

    /// Builds an [`InferenceError::InvalidInput`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn invalid_input(msg: impl Into<String>) -> Self {
        InferenceError::InvalidInput(msg.into())
    }

    /// Builds an [`InferenceError::Backend`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn backend(msg: impl Into<String>) -> Self {
        InferenceError::Backend(msg.into())
    }

    /// Builds an [`InferenceError::Preprocessing`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn preprocessing(msg: impl Into<String>) -> Self {
        InferenceError::Preprocessing(msg.into())
    }

    /// Builds an [`InferenceError::Postprocessing`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn postprocessing(msg: impl Into<String>) -> Self {
        InferenceError::Postprocessing(msg.into())
    }
}
/// Convenience constructors that accept any string-like arguments.
///
/// Each is `#[must_use]`: building an error and then dropping it is always a
/// mistake (it was meant to be returned or propagated).
impl PipelineError {
    /// Builds a [`PipelineError::StageFailed`] for the stage named `stage`,
    /// recording `reason` as the cause.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn stage_failed(stage: impl Into<String>, reason: impl Into<String>) -> Self {
        PipelineError::StageFailed {
            stage: stage.into(),
            reason: reason.into(),
        }
    }

    /// Builds a [`PipelineError::InvalidTarget`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn invalid_target(msg: impl Into<String>) -> Self {
        PipelineError::InvalidTarget(msg.into())
    }

    /// Builds a [`PipelineError::Provider`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn provider(msg: impl Into<String>) -> Self {
        PipelineError::Provider(msg.into())
    }

    /// Builds a [`PipelineError::PolicyDenied`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn policy_denied(msg: impl Into<String>) -> Self {
        PipelineError::PolicyDenied(msg.into())
    }

    /// Builds a [`PipelineError::Resolution`] from `msg`.
    #[must_use = "an error does nothing unless it is returned or propagated"]
    pub fn resolution(msg: impl Into<String>) -> Self {
        PipelineError::Resolution(msg.into())
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for error display strings, conversions and constructors.
    use super::*;

    #[test]
    fn test_xybrid_error_display() {
        let err = XybridError::NotFound("model.onnx".to_string());
        assert_eq!(err.to_string(), "Not found: model.onnx");
    }

    #[test]
    fn test_inference_error_display() {
        let err = InferenceError::Backend("ONNX runtime failed".to_string());
        assert_eq!(err.to_string(), "Backend error: ONNX runtime failed");
    }

    #[test]
    fn test_pipeline_error_display() {
        let err = PipelineError::StageFailed {
            stage: "tts".to_string(),
            reason: "voice not found".to_string(),
        };
        assert_eq!(err.to_string(), "Stage 'tts' failed: voice not found");
    }

    #[test]
    fn test_wrapped_error_display() {
        // `#[from]` variants prefix the inner error's own Display output.
        let err: XybridError = InferenceError::backend("ORT error").into();
        assert_eq!(err.to_string(), "Inference error: Backend error: ORT error");

        let err: XybridError = PipelineError::stage_failed("asr", "timeout").into();
        assert_eq!(err.to_string(), "Pipeline error: Stage 'asr' failed: timeout");
    }

    #[test]
    fn test_adapter_error_conversion() {
        use crate::runtime_adapter::AdapterError;
        let adapter_err = AdapterError::ModelNotFound("whisper.onnx".to_string());
        let xybrid_err: XybridError = adapter_err.into();
        assert!(matches!(xybrid_err, XybridError::NotFound(_)));
        let adapter_err = AdapterError::InferenceFailed("ORT error".to_string());
        let xybrid_err: XybridError = adapter_err.into();
        assert!(matches!(
            xybrid_err,
            XybridError::Inference(InferenceError::Backend(_))
        ));
    }

    #[test]
    fn test_adapter_error_conversion_remaining_variants() {
        // Covers the mappings not exercised by `test_adapter_error_conversion`.
        // `AbortedForCloudFallback` is omitted: its `reason` field type is
        // defined in `runtime_adapter` and not pinned here.
        use crate::runtime_adapter::AdapterError;

        let err: XybridError = AdapterError::ModelNotLoaded("kokoro".to_string()).into();
        assert!(matches!(
            err,
            XybridError::Inference(InferenceError::ModelNotLoaded(ref s)) if s == "kokoro"
        ));

        let err: XybridError = AdapterError::InvalidInput("empty audio".to_string()).into();
        assert!(matches!(
            err,
            XybridError::Inference(InferenceError::InvalidInput(ref s)) if s == "empty audio"
        ));

        let err: XybridError = AdapterError::RuntimeError("crash".to_string()).into();
        assert!(matches!(
            err,
            XybridError::Inference(InferenceError::Backend(ref s)) if s == "crash"
        ));

        let err: XybridError = AdapterError::SerializationError("bad tensor".to_string()).into();
        assert!(matches!(err, XybridError::Serialization(ref s) if s == "bad tensor"));

        let io = std::io::Error::new(std::io::ErrorKind::NotFound, "missing file");
        let err: XybridError = AdapterError::IOError(io).into();
        assert!(matches!(err, XybridError::Io(ref e) if e.kind() == std::io::ErrorKind::NotFound));
    }

    #[test]
    fn test_convenience_constructors() {
        let err = XybridError::not_found("test.onnx");
        assert!(matches!(err, XybridError::NotFound(_)));
        let err = InferenceError::backend("runtime crash");
        assert!(matches!(err, InferenceError::Backend(_)));
        let err = PipelineError::stage_failed("asr", "timeout");
        assert!(matches!(err, PipelineError::StageFailed { .. }));
    }

    #[test]
    fn test_json_error_conversion() {
        let json_str = "invalid json {";
        let result: Result<serde_json::Value, _> = serde_json::from_str(json_str);
        let xybrid_err: XybridError = result.unwrap_err().into();
        assert!(matches!(xybrid_err, XybridError::Serialization(_)));
    }

    #[test]
    fn test_yaml_error_conversion() {
        // Unterminated flow sequence: guaranteed parse failure.
        let result: Result<serde_yaml::Value, _> = serde_yaml::from_str("key: [unclosed");
        let xybrid_err: XybridError = result.unwrap_err().into();
        assert!(matches!(xybrid_err, XybridError::Serialization(_)));
    }

    #[test]
    fn test_question_mark_propagation() {
        fn parse(s: &str) -> XybridResult<serde_json::Value> {
            Ok(serde_json::from_str(s)?)
        }
        assert!(parse("{\"ok\": true}").is_ok());
        assert!(matches!(parse("{"), Err(XybridError::Serialization(_))));
    }
}