use anyhow::Result;
use async_trait::async_trait;
use super::{GenOptions, InferenceEngine, LoadedModel, ModelSpec};
#[cfg(feature = "huggingface")]
use super::{UniversalEngine, UniversalModel, UniversalModelSpec};
/// Single entry point that owns one engine per compiled-in backend.
///
/// The Hugging Face, llama and MLX engines are only present when the Cargo
/// feature of the same name is enabled; the SafeTensors engine is always
/// present.
pub struct InferenceEngineAdapter {
    /// Engine used for models routed to the Hugging Face backend.
    #[cfg(feature = "huggingface")]
    huggingface_engine: super::huggingface::HuggingFaceEngine,

    /// Engine used for models routed to the llama backend.
    #[cfg(feature = "llama")]
    llama_engine: super::llama::LlamaEngine,

    /// Engine used for models routed to the MLX backend.
    #[cfg(feature = "mlx")]
    mlx_engine: super::mlx::MLXEngine,

    /// Engine used for `.safetensors` files; not feature-gated.
    safetensors_engine: super::safetensors_native::SafeTensorsEngine,
}
impl Default for InferenceEngineAdapter {
fn default() -> Self {
Self::new()
}
}
impl InferenceEngineAdapter {
    /// Builds an adapter in which every compiled-in engine is created through
    /// its own `new()` constructor.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "huggingface")]
            huggingface_engine: super::huggingface::HuggingFaceEngine::new(),
            #[cfg(feature = "llama")]
            llama_engine: super::llama::LlamaEngine::new(),
            #[cfg(feature = "mlx")]
            mlx_engine: super::mlx::MLXEngine::new(),
            safetensors_engine: super::safetensors_native::SafeTensorsEngine::new(),
        }
    }

    /// Builds an adapter whose llama engine is created with
    /// `LlamaEngine::new_with_backend(gpu_backend)`.
    ///
    /// Every other engine is created through its `new()` constructor, exactly
    /// as in `new()`. When the `llama` feature is disabled no engine receives
    /// `gpu_backend`, so the argument is ignored.
    // Without the `llama` feature nothing reads the parameter; silence the
    // resulting `unused_variables` warning for that configuration only.
    #[cfg_attr(not(feature = "llama"), allow(unused_variables))]
    pub fn new_with_backend(gpu_backend: Option<&str>) -> Self {
        Self {
            #[cfg(feature = "huggingface")]
            huggingface_engine: super::huggingface::HuggingFaceEngine::new(),
            #[cfg(feature = "llama")]
            llama_engine: super::llama::LlamaEngine::new_with_backend(gpu_backend),
            #[cfg(feature = "mlx")]
            mlx_engine: super::mlx::MLXEngine::new(),
            safetensors_engine: super::safetensors_native::SafeTensorsEngine::new(),
        }
    }

    /// Replaces the llama engine with the result of
    /// `LlamaEngine::with_moe_config(cpu_moe_all, n_cpu_moe)` and returns the
    /// adapter, builder-style.
    ///
    /// Only available with the `llama` feature. The meaning of both settings
    /// is defined by the llama engine; presumably they control CPU placement
    /// of mixture-of-experts layers — confirm against `LlamaEngine`.
    #[cfg(feature = "llama")]
    pub fn with_moe_config(mut self, cpu_moe_all: bool, n_cpu_moe: Option<usize>) -> Self {
        self.llama_engine = self.llama_engine.with_moe_config(cpu_moe_all, n_cpu_moe);
        self
    }

    /// Chooses the backend that should load `spec`.
    ///
    /// The rules are heuristics evaluated in this order; the first match wins:
    ///
    /// 1. (`huggingface`) a path containing `/` but neither `\` nor `.` is
    ///    treated as a Hugging Face model id.
    /// 2. (`mlx`) a `.npz` / `.mlx` extension, or — on macOS/aarch64 — a
    ///    lower-cased model name containing `llama`, `mistral`, `phi` or
    ///    `qwen`, selects MLX. Because this runs before the extension checks
    ///    below, such a name wins over a `.gguf` / `.safetensors` extension.
    /// 3. a `.safetensors` extension selects the SafeTensors engine.
    /// 4. a `.gguf` extension selects llama.
    /// 5. a path that looks like an Ollama blob (`ollama`, `blobs` and
    ///    `sha256-` all present) selects llama, else Hugging Face.
    /// 6. a path containing `.gguf`, or a model name containing a known
    ///    family keyword, selects llama, else Hugging Face.
    /// 7. otherwise Hugging Face if compiled in, else llama.
    ///
    /// # Panics
    ///
    /// Panics when the matching rule needs a backend that was not compiled
    /// in: rule 4 without `llama`, and rules 5–7 with neither `llama` nor
    /// `huggingface`.
    fn select_backend(&self, spec: &ModelSpec) -> BackendChoice {
        let path_str = spec.base_path.to_string_lossy();
        // Looked up once and shared by every extension-based rule below.
        let extension = spec.base_path.extension().and_then(|ext| ext.to_str());

        #[cfg(feature = "huggingface")]
        {
            // `org/model` style ids: forward slash, no Windows separator, no dot.
            let looks_like_model_id =
                path_str.contains('/') && !path_str.contains('\\') && !path_str.contains('.');
            if looks_like_model_id {
                return BackendChoice::HuggingFace;
            }
        }

        #[cfg(feature = "mlx")]
        {
            const MLX_NAME_HINTS: [&str; 4] = ["llama", "mistral", "phi", "qwen"];

            if matches!(extension, Some("npz") | Some("mlx")) {
                return BackendChoice::MLX;
            }
            if cfg!(target_os = "macos") && std::env::consts::ARCH == "aarch64" {
                let model_name = spec.name.to_lowercase();
                if MLX_NAME_HINTS.iter().any(|hint| model_name.contains(*hint)) {
                    return BackendChoice::MLX;
                }
            }
        }

        if extension == Some("safetensors") {
            return BackendChoice::SafeTensors;
        }

        if extension == Some("gguf") {
            #[cfg(feature = "llama")]
            {
                return BackendChoice::Llama;
            }
            #[cfg(not(feature = "llama"))]
            {
                panic!("GGUF file detected but llama feature not enabled. Please install with --features llama");
            }
        }

        let looks_like_ollama_blob = path_str.contains("ollama")
            && path_str.contains("blobs")
            && path_str.contains("sha256-");
        if looks_like_ollama_blob {
            #[cfg(feature = "llama")]
            {
                return BackendChoice::Llama;
            }
            #[cfg(not(feature = "llama"))]
            {
                #[cfg(feature = "huggingface")]
                {
                    return BackendChoice::HuggingFace;
                }
                #[cfg(not(feature = "huggingface"))]
                {
                    panic!("Ollama blob detected but no backend enabled");
                }
            }
        }

        // NOTE(review): unlike the MLX rule above, this name match is
        // case-sensitive ("Llama" does not match) — confirm whether intended.
        const GGUF_NAME_HINTS: [&str; 5] = ["llama", "phi", "qwen", "gemma", "mistral"];
        let named_like_gguf_family = GGUF_NAME_HINTS
            .iter()
            .any(|hint| spec.name.contains(*hint));
        if path_str.contains(".gguf") || named_like_gguf_family {
            #[cfg(feature = "llama")]
            {
                return BackendChoice::Llama;
            }
            #[cfg(not(feature = "llama"))]
            {
                #[cfg(feature = "huggingface")]
                {
                    return BackendChoice::HuggingFace;
                }
                #[cfg(not(feature = "huggingface"))]
                {
                    panic!("GGUF model detected but no backend enabled");
                }
            }
        }

        // Nothing matched: fall back to whichever general backend exists.
        #[cfg(feature = "huggingface")]
        {
            BackendChoice::HuggingFace
        }
        #[cfg(not(feature = "huggingface"))]
        {
            #[cfg(feature = "llama")]
            {
                BackendChoice::Llama
            }
            #[cfg(not(feature = "llama"))]
            {
                panic!("No backend features enabled. Please compile with --features llama or --features huggingface");
            }
        }
    }
}
/// Outcome of backend selection: which engine should load a model.
///
/// Variants for optional engines exist only when the corresponding Cargo
/// feature is enabled, so a `match` over this enum must gate those arms with
/// the same `cfg` attributes.
#[derive(Debug, Clone, PartialEq)]
enum BackendChoice {
    /// Route to the llama engine.
    #[cfg(feature = "llama")]
    Llama,

    /// Route to the Hugging Face engine.
    #[cfg(feature = "huggingface")]
    HuggingFace,

    /// Route to the MLX engine.
    #[cfg(feature = "mlx")]
    MLX,

    /// Route to the SafeTensors engine (always compiled in).
    SafeTensors,
}
#[async_trait]
impl InferenceEngine for InferenceEngineAdapter {
    /// Loads `spec` with the engine picked by `select_backend`.
    ///
    /// The llama, MLX and SafeTensors engines receive `spec` unchanged. For
    /// the Hugging Face engine the spec is first translated into a
    /// `UniversalModelSpec` (device fixed to `"cpu"`, `use_local` set to
    /// `true`) and the resulting model is wrapped so that it satisfies
    /// `LoadedModel`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected engine's `load` reports.
    async fn load(&self, spec: &ModelSpec) -> Result<Box<dyn LoadedModel>> {
        match self.select_backend(spec) {
            #[cfg(feature = "llama")]
            BackendChoice::Llama => self.llama_engine.load(spec).await,

            #[cfg(feature = "mlx")]
            BackendChoice::MLX => self.mlx_engine.load(spec).await,

            BackendChoice::SafeTensors => self.safetensors_engine.load(spec).await,

            #[cfg(feature = "huggingface")]
            BackendChoice::HuggingFace => {
                // The base path doubles as the model id; the LoRA path, if
                // any, becomes the PEFT adapter path.
                let hf_backend = super::ModelBackend::HuggingFace {
                    base_model_id: spec.base_path.to_string_lossy().to_string(),
                    peft_path: spec.lora_path.as_ref().map(|p| p.to_path_buf()),
                    use_local: true,
                };
                let universal_spec = UniversalModelSpec {
                    name: spec.name.clone(),
                    backend: hf_backend,
                    template: spec.template.clone(),
                    ctx_len: spec.ctx_len,
                    device: "cpu".to_string(),
                    n_threads: spec.n_threads,
                };

                let model = self.huggingface_engine.load(&universal_spec).await?;
                let wrapped = UniversalModelWrapper { model };
                Ok(Box::new(wrapped))
            }
        }
    }
}
/// Adapter that lets a boxed `UniversalModel` be used where a `LoadedModel`
/// is expected. Only compiled with the `huggingface` feature.
#[cfg(feature = "huggingface")]
struct UniversalModelWrapper {
    /// The wrapped model; all generation requests are forwarded to it.
    model: Box<dyn UniversalModel>,
}
#[cfg(feature = "huggingface")]
#[async_trait]
impl LoadedModel for UniversalModelWrapper {
    /// Forwards the request unchanged to the wrapped `UniversalModel` and
    /// returns its result as-is.
    async fn generate(
        &self,
        prompt: &str,
        opts: GenOptions,
        on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> Result<String> {
        let inner = &self.model;
        inner.generate(prompt, opts, on_token).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a `ModelSpec` with the given name and base path; the remaining
    /// fields are fixed (no LoRA path, no template, `ctx_len` 2048, no thread
    /// count).
    fn create_test_spec(name: &str, path: &str) -> ModelSpec {
        ModelSpec {
            name: name.to_string(),
            base_path: PathBuf::from(path),
            lora_path: None,
            template: None,
            ctx_len: 2048,
            n_threads: None,
        }
    }

    // Gated on `huggingface`: without that feature these specs fall through
    // to the name heuristics in `select_backend`, which then picks another
    // backend (leaving nothing to assert) or panics when no general backend
    // is compiled in.
    #[cfg(feature = "huggingface")]
    #[test]
    fn test_huggingface_model_id_detection() {
        let adapter = InferenceEngineAdapter::new();

        let hf_spec = create_test_spec("qwen", "Qwen/Qwen3-Next-80B-A3B-Instruct");
        assert_eq!(adapter.select_backend(&hf_spec), BackendChoice::HuggingFace);

        let hf_spec2 = create_test_spec("llama", "meta-llama/Llama-2-7b-chat-hf");
        assert_eq!(adapter.select_backend(&hf_spec2), BackendChoice::HuggingFace);
    }

    #[test]
    fn test_local_file_detection() {
        let adapter = InferenceEngineAdapter::new();

        // A bare `.gguf` file name goes to llama when that backend exists.
        #[cfg(feature = "llama")]
        {
            let gguf_spec = create_test_spec("local", "model.gguf");
            assert_eq!(adapter.select_backend(&gguf_spec), BackendChoice::Llama);
        }

        // `.safetensors` is handled by the always-available native engine.
        let safetensors_spec = create_test_spec("local", "model.safetensors");
        assert_eq!(
            adapter.select_backend(&safetensors_spec),
            BackendChoice::SafeTensors
        );

        // Windows-style paths must still resolve through the `.gguf` rules.
        #[cfg(feature = "llama")]
        {
            let windows_spec = create_test_spec("local", "C:\\path\\to\\model.gguf");
            assert_eq!(adapter.select_backend(&windows_spec), BackendChoice::Llama);
        }
    }
}