use anyhow::Result;
use async_trait::async_trait;
use super::{GenOptions, InferenceEngine, LoadedModel, ModelSpec};
#[cfg(feature = "huggingface")]
use super::{UniversalEngine, UniversalModel, UniversalModelSpec};
/// Dispatching [`InferenceEngine`] that holds one engine per backend compiled
/// into this build and routes each `load` call to one of them based on the
/// model spec.
pub struct InferenceEngineAdapter {
/// HuggingFace engine; exists only when the `huggingface` feature is enabled.
#[cfg(feature = "huggingface")]
huggingface_engine: super::huggingface::HuggingFaceEngine,
/// llama engine; exists only when the `llama` feature is enabled.
#[cfg(feature = "llama")]
llama_engine: super::llama::LlamaEngine,
/// MLX engine; exists only when the `mlx` feature is enabled.
#[cfg(feature = "mlx")]
mlx_engine: super::mlx::MLXEngine,
/// Native SafeTensors engine; always compiled in, regardless of features.
safetensors_engine: super::safetensors_native::SafeTensorsEngine,
}
impl Default for InferenceEngineAdapter {
fn default() -> Self {
Self::new()
}
}
impl InferenceEngineAdapter {
pub fn new() -> Self {
Self {
#[cfg(feature = "huggingface")]
huggingface_engine: super::huggingface::HuggingFaceEngine::new(),
#[cfg(feature = "llama")]
llama_engine: super::llama::LlamaEngine::new(),
#[cfg(feature = "mlx")]
mlx_engine: super::mlx::MLXEngine::new(),
safetensors_engine: super::safetensors_native::SafeTensorsEngine::new(),
}
}
fn select_backend(&self, spec: &ModelSpec) -> BackendChoice {
let path_str = spec.base_path.to_string_lossy();
#[cfg(feature = "mlx")]
{
if let Some(ext) = spec.base_path.extension().and_then(|s| s.to_str()) {
if ext == "npz" || ext == "mlx" {
return BackendChoice::MLX;
}
}
if cfg!(target_os = "macos") && std::env::consts::ARCH == "aarch64" {
let model_name = spec.name.to_lowercase();
if model_name.contains("llama") || model_name.contains("mistral")
|| model_name.contains("phi") || model_name.contains("qwen") {
return BackendChoice::MLX;
}
}
}
if let Some(ext) = spec.base_path.extension().and_then(|s| s.to_str()) {
if ext == "safetensors" {
return BackendChoice::SafeTensors;
}
}
if let Some(ext) = spec.base_path.extension().and_then(|s| s.to_str()) {
if ext == "gguf" {
#[cfg(feature = "llama")]
{
return BackendChoice::Llama;
}
#[cfg(not(feature = "llama"))]
{
panic!("GGUF file detected but llama feature not enabled. Please install with --features llama");
}
}
}
if path_str.contains("ollama") && path_str.contains("blobs") && path_str.contains("sha256-")
{
#[cfg(feature = "llama")]
{
return BackendChoice::Llama;
}
#[cfg(not(feature = "llama"))]
{
#[cfg(feature = "huggingface")]
{
return BackendChoice::HuggingFace;
}
#[cfg(not(feature = "huggingface"))]
{
panic!("Ollama blob detected but no backend enabled");
}
}
}
if path_str.contains(".gguf")
|| spec.name.contains("llama")
|| spec.name.contains("phi")
|| spec.name.contains("qwen")
|| spec.name.contains("gemma")
|| spec.name.contains("mistral")
{
#[cfg(feature = "llama")]
{
return BackendChoice::Llama;
}
#[cfg(not(feature = "llama"))]
{
#[cfg(feature = "huggingface")]
{
return BackendChoice::HuggingFace;
}
#[cfg(not(feature = "huggingface"))]
{
panic!("GGUF model detected but no backend enabled");
}
}
}
#[cfg(feature = "huggingface")]
{
BackendChoice::HuggingFace
}
#[cfg(not(feature = "huggingface"))]
{
#[cfg(feature = "llama")]
{
BackendChoice::Llama
}
#[cfg(not(feature = "llama"))]
{
panic!("No backend features enabled. Please compile with --features llama or --features huggingface");
}
}
}
}
#[derive(Debug, Clone, PartialEq)]
enum BackendChoice {
#[cfg(feature = "llama")]
Llama,
#[cfg(feature = "huggingface")]
HuggingFace,
#[cfg(feature = "mlx")]
MLX,
SafeTensors,
}
#[async_trait]
impl InferenceEngine for InferenceEngineAdapter {
async fn load(&self, spec: &ModelSpec) -> Result<Box<dyn LoadedModel>> {
let backend = self.select_backend(spec);
match backend {
BackendChoice::SafeTensors => {
self.safetensors_engine.load(spec).await
}
#[cfg(feature = "mlx")]
BackendChoice::MLX => {
self.mlx_engine.load(spec).await
}
#[cfg(feature = "llama")]
BackendChoice::Llama => self.llama_engine.load(spec).await,
#[cfg(feature = "huggingface")]
BackendChoice::HuggingFace => {
let universal_spec = UniversalModelSpec {
name: spec.name.clone(),
backend: super::ModelBackend::HuggingFace {
base_model_id: spec.base_path.to_string_lossy().to_string(),
peft_path: spec.lora_path.as_ref().map(|p| p.to_path_buf()),
use_local: true,
},
template: spec.template.clone(),
ctx_len: spec.ctx_len,
device: "cpu".to_string(),
n_threads: spec.n_threads,
};
let universal_model = self.huggingface_engine.load(&universal_spec).await?;
Ok(Box::new(UniversalModelWrapper {
model: universal_model,
}))
}
}
}
}
/// Adapts the boxed [`UniversalModel`] produced by the HuggingFace engine so it
/// can be returned where a [`LoadedModel`] is expected.
#[cfg(feature = "huggingface")]
struct UniversalModelWrapper {
/// The model returned by the HuggingFace engine's `load`.
model: Box<dyn UniversalModel>,
}
#[cfg(feature = "huggingface")]
#[async_trait]
impl LoadedModel for UniversalModelWrapper {
    /// Delegates generation to the wrapped universal model, passing the prompt,
    /// options and optional per-token callback through untouched.
    async fn generate(
        &self,
        prompt: &str,
        opts: GenOptions,
        on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> Result<String> {
        let Self { model } = self;
        model.generate(prompt, opts, on_token).await
    }
}