use ndarray::ArrayD;
use std::collections::HashMap;
use std::path::Path;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum BackendError {
#[error("Model not loaded")]
ModelNotLoaded,
#[error("Failed to load model: {0}")]
LoadFailed(String),
#[error("Inference failed: {0}")]
InferenceFailed(String),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("IO error: {0}")]
IOError(#[from] std::io::Error),
#[error("Runtime error: {0}")]
RuntimeError(String),
}
/// Shorthand for a `Result` whose error type is [`BackendError`].
///
/// Every fallible method of [`InferenceBackend`] returns this alias.
pub type BackendResult<T> = std::result::Result<T, BackendError>;
/// Identifies the runtime an [`InferenceBackend`] is built on.
///
/// The `Candle` variant only exists when the crate is compiled with the
/// `candle` feature.
// `Hash` (consistent with the derived `Eq`) lets this small `Copy` tag be used
// as a `HashMap`/`HashSet` key, e.g. for registries keyed by runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RuntimeType {
    /// ONNX runtime; identifier `"onnx"`.
    Onnx,
    /// Core ML runtime; identifier `"coreml"`.
    CoreML,
    /// Candle runtime; identifier `"candle"`. Requires the `candle` feature.
    #[cfg(feature = "candle")]
    Candle,
}

impl RuntimeType {
    /// Returns the lowercase string identifier of this runtime.
    ///
    /// This is a `const fn`, so it can also be used to initialise constants.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            RuntimeType::Onnx => "onnx",
            RuntimeType::CoreML => "coreml",
            #[cfg(feature = "candle")]
            RuntimeType::Candle => "candle",
        }
    }
}
/// Common interface implemented by every model-inference backend.
///
/// The `Send + Sync` supertraits require implementors to be safe to move to,
/// and share between, threads.
pub trait InferenceBackend: Send + Sync {
    /// Reports which [`RuntimeType`] this backend is built on.
    fn runtime_type(&self) -> RuntimeType;

    /// Loads the model at `model_path`, optionally with a configuration file
    /// at `config_path`.
    ///
    /// Takes `&mut self`, so loading needs exclusive access to the backend.
    fn load_model(&mut self, model_path: &Path, config_path: Option<&Path>) -> BackendResult<()>;

    /// Runs the model on `inputs`, returning a map of named `f32` output arrays.
    ///
    /// Takes `&self`, so it may be called through a shared reference.
    /// NOTE(review): map keys are presumably the model's tensor names (see
    /// `input_names` / `output_names`) — confirm against implementations.
    fn run_inference(
        &self,
        inputs: HashMap<String, ArrayD<f32>>,
    ) -> BackendResult<HashMap<String, ArrayD<f32>>>;

    /// Reports whether a model is currently loaded (as defined by the implementor).
    fn is_loaded(&self) -> bool;

    /// Lists the names of the loaded model's inputs.
    fn input_names(&self) -> BackendResult<Vec<String>>;

    /// Lists the names of the loaded model's outputs.
    fn output_names(&self) -> BackendResult<Vec<String>>;
}