use crate::ir::Envelope;
use std::collections::HashMap;
use thiserror::Error;
// Submodules that are always compiled, regardless of target or Cargo features.
// `metadata_driven` and `tensor_utils` are crate-private; the rest are public.
pub mod inference_backend;
pub(crate) mod metadata_driven;
pub(crate) mod tensor_utils;
pub mod traits;
pub mod types;
pub mod onnx;
pub mod cloud;
// `coreml` is built only for macOS/iOS targets, or in test builds on any host.
#[cfg(any(target_os = "macos", target_os = "ios", test))]
pub mod coreml;
// `candle` is opt-in via the `candle` Cargo feature.
#[cfg(feature = "candle")]
pub mod candle;
// `llm` and the crate-private `llm_telemetry` are built when at least one of
// the LLM features (`llm-mistral`, `llm-llamacpp`) is enabled.
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
pub mod llm;
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
pub(crate) mod llm_telemetry;
// Built only with `llm-llamacpp`; enabling `llm-mistral` alone does not pull it in.
#[cfg(feature = "llm-llamacpp")]
pub(crate) mod streaming_postprocess;
// One module per LLM feature, each gated on exactly its own feature.
#[cfg(feature = "llm-mistral")]
pub mod mistral;
#[cfg(feature = "llm-llamacpp")]
pub mod llama_cpp;
// Re-exports: these let callers name the items from this module directly
// instead of spelling out the submodule path.
pub use cloud::{CloudRuntimeAdapter, CloudStreaming};
// The `metadata_driven` module itself is `pub(crate)`; this `pub use` is what
// makes `MetadataDrivenAdapter` nameable through this module's public surface.
pub use metadata_driven::MetadataDrivenAdapter;
pub use onnx::OnnxBackend;
pub use onnx::OnnxRuntimeAdapter;
pub use onnx::{ExecutionProviderKind, ONNXSession, SessionOptions};
// Although `onnx` is always compiled, this item is re-exported only for
// Android targets or test builds.
#[cfg(any(target_os = "android", test))]
pub use onnx::ONNXMobileRuntimeAdapter;
// Each gated re-export below carries the same `cfg` as the module it comes
// from, so the re-export never refers to a module that was compiled out.
#[cfg(any(target_os = "macos", target_os = "ios", test))]
pub use coreml::CoreMLRuntimeAdapter;
#[cfg(feature = "candle")]
pub use candle::{CandleBackend, CandleRuntimeAdapter};
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
pub use llm::{GenerationOutput, LlmBackend, LlmResult, LlmRuntimeAdapter};
#[cfg(feature = "llm-mistral")]
pub use mistral::MistralBackend;
#[cfg(feature = "llm-llamacpp")]
pub use llama_cpp::LlamaCppBackend;
// Accessors for llama.cpp log verbosity, available only with `llm-llamacpp`.
#[cfg(feature = "llm-llamacpp")]
pub use llama_cpp::{llama_log_get_verbosity, llama_log_set_verbosity};
// Ungated re-exports from the always-compiled modules.
pub use inference_backend::{BackendError, BackendResult, InferenceBackend, RuntimeType};
pub use traits::ModelRuntime;
pub use types::{
ChatMessage, GenerationConfig, LlmConfig, PartialToken, StreamingCallback, StreamingError,
};
#[derive(Error, Debug)]
pub enum AdapterError {
#[error("Model not found: {0}")]
ModelNotFound(String),
#[error("Model not loaded: {0}")]
ModelNotLoaded(String),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Inference failed: {0}")]
InferenceFailed(String),
#[error("IO error: {0}")]
IOError(#[from] std::io::Error),
#[error("Serialization error: {0}")]
SerializationError(String),
#[error("Runtime error: {0}")]
RuntimeError(String),
#[error("Aborted for cloud fallback: {reason}")]
AbortedForCloudFallback { reason: crate::abort::AbortReason },
}
impl AdapterError {
pub fn from_streaming_callback_error(error: StreamingError) -> Self {
if let Some(reason) = crate::abort::cloud_fallback_reason_from_error(error.as_ref()) {
return Self::AbortedForCloudFallback { reason };
}
Self::RuntimeError(format!("Streaming callback error: {}", error))
}
pub fn cloud_fallback_abort_reason(&self) -> Option<crate::abort::AbortReason> {
match self {
Self::AbortedForCloudFallback { reason } => Some(*reason),
_ => None,
}
}
}
/// Shorthand for a `Result` whose error type is [`AdapterError`].
pub type AdapterResult<T> = Result<T, AdapterError>;
/// Descriptive record for a model, made of plain owned strings and two
/// name-keyed schema maps.
///
/// The struct performs no validation; every field is public and can be set to
/// any value of its type.
#[derive(Clone, Debug)]
pub struct ModelMetadata {
    /// Identifier of the model.
    pub model_id: String,
    /// Version label, stored as free-form text.
    pub version: String,
    /// Runtime the model targets, stored as a string.
    // NOTE(review): presumably mirrors one of the `RuntimeType` values — confirm.
    pub runtime_type: String,
    /// Location of the model, stored as a string rather than a `PathBuf`.
    pub model_path: String,
    /// Input schema: a name mapped to a list of `u64` values.
    // NOTE(review): the values look like tensor dimensions per input name — confirm.
    pub input_schema: HashMap<String, Vec<u64>>,
    /// Output schema, with the same layout as `input_schema`.
    pub output_schema: HashMap<String, Vec<u64>>,
}
/// Core interface implemented by every runtime adapter.
///
/// Implementors must be `Send + Sync`, so an adapter can be moved to and
/// shared between threads.
pub trait RuntimeAdapter: Send + Sync {
    /// Returns the adapter's name.
    fn name(&self) -> &str;

    /// Returns the formats this adapter supports, as static strings.
    // NOTE(review): whether these are file extensions or format names is not
    // visible here — confirm against implementors.
    fn supported_formats(&self) -> Vec<&'static str>;

    /// Loads the model identified by `path` into this adapter.
    ///
    /// Takes `&mut self`, so loading requires exclusive access to the adapter.
    ///
    /// # Errors
    ///
    /// Returns an [`AdapterError`] when the implementor cannot load the model.
    fn load_model(&mut self, path: &str) -> AdapterResult<()>;

    /// Runs the adapter on `input` and returns the resulting envelope.
    ///
    /// Takes `&self`: any state an implementor mutates during execution has
    /// to live behind interior mutability.
    ///
    /// # Errors
    ///
    /// Returns an [`AdapterError`] when execution fails.
    fn execute(&self, input: &Envelope) -> AdapterResult<Envelope>;

    /// Optional warm-up hook.
    ///
    /// The default implementation does nothing and returns `Ok(())`;
    /// implementors override it when they have warm-up work to perform.
    fn warmup(&mut self) -> AdapterResult<()> {
        Ok(())
    }
}
/// Additional adapter operations that address models by `model_id`.
///
/// This trait declares no supertraits: it is independent of
/// [`RuntimeAdapter`] at the type level and does not itself require
/// `Send`/`Sync`.
pub trait RuntimeAdapterExt {
    /// Reports whether the model identified by `model_id` is loaded.
    fn is_loaded(&self, model_id: &str) -> bool;

    /// Returns a reference to the metadata of `model_id`.
    ///
    /// The reference borrows from `self`, so the implementor must own the
    /// stored [`ModelMetadata`].
    ///
    /// # Errors
    ///
    /// Returns an [`AdapterError`] when no metadata can be provided.
    fn get_metadata(&self, model_id: &str) -> AdapterResult<&ModelMetadata>;

    /// Runs inference for `model_id` on `input` and returns the output envelope.
    ///
    /// # Errors
    ///
    /// Returns an [`AdapterError`] when inference cannot be completed.
    fn infer(&self, model_id: &str, input: &Envelope) -> AdapterResult<Envelope>;

    /// Unloads the model identified by `model_id`.
    ///
    /// # Errors
    ///
    /// Returns an [`AdapterError`] when the model cannot be unloaded.
    fn unload_model(&mut self, model_id: &str) -> AdapterResult<()>;

    /// Returns the loaded models as owned strings.
    // NOTE(review): presumably these are the same ids accepted by the other
    // methods — confirm against implementors.
    fn list_loaded_models(&self) -> Vec<String>;
}