use crate::schema::ModelCapability;
use crate::tasks::generate::{parse_tool_calls, render_chat_prompt, GenerateRequest, ToolCall};
use crate::InferenceError;
/// Token-level operations a local decoder exposes to the generation code.
///
/// Implementors must be `Send`.
pub trait TextDecoder: Send {
    /// Tokenizes `text` into token ids.
    fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError>;

    /// Turns token ids back into text.
    fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError>;

    /// Runs the model over `tokens` starting at position `pos` and returns a
    /// vector of `f32` values (presumably next-token logits — confirm against
    /// implementors).
    fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Vec<f32>, InferenceError>;

    /// Token ids that should end generation.
    fn eos_ids(&self) -> Vec<u32>;

    /// Context length reported by the model, in tokens.
    fn context_length(&self) -> usize;

    /// Discards cached key/value state.
    fn clear_kv_cache(&mut self);

    /// Prepares the decoder for a new prompt and returns a `usize`.
    ///
    /// The default clears the KV cache, ignores `prompt_tokens`, and returns
    /// `0`. Implementors may override this; a non-zero return presumably
    /// means some leading prompt tokens were reused from cache — confirm
    /// against overriding implementations.
    fn begin_prompt(&mut self, prompt_tokens: &[u32]) -> usize {
        self.clear_kv_cache();
        // The default never looks at the prompt itself.
        let _ = prompt_tokens;
        0
    }
}
/// Outcome of one local generation run.
///
/// Field meanings below are inferred from their names — confirm against the
/// code that constructs this value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalGeneration {
    /// Generated text.
    pub text: String,
    /// Time to first token in milliseconds, when recorded (presumably).
    pub ttft_ms: Option<u64>,
    /// Why generation stopped, when known (free-form string).
    pub stop_reason: Option<String>,
    /// Number of prompt tokens.
    pub prompt_tokens: usize,
    /// Number of generated (completion) tokens.
    pub completion_tokens: usize,
}
/// An [`InferenceError`] tagged with how it affects the backend that raised it.
///
/// The two variants, judging by their names, distinguish errors after which
/// the backend can keep being used from errors that leave it in a corrupted
/// state — NOTE(review): confirm against the code that constructs these.
pub enum DriveError {
    /// Error that presumably leaves the backend usable.
    Recoverable(InferenceError),
    /// Error that presumably leaves the backend unusable.
    BackendCorrupted(InferenceError),
}

impl DriveError {
    /// Drops the classification and returns the wrapped [`InferenceError`].
    pub fn into_inner(self) -> InferenceError {
        // Both variants carry the same payload type, so one irrefutable
        // or-pattern binds it regardless of which variant this is.
        let (DriveError::Recoverable(inner) | DriveError::BackendCorrupted(inner)) = self;
        inner
    }
}
/// Static description of a local backend: its name and the model types it is
/// associated with.
///
/// All fields are `'static` borrows, so the descriptor is `Copy` and can be
/// declared in `const`/`static` tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackendDescriptor {
    /// Backend identifier (presumably the same string the backend reports via
    /// `backend_name()` — confirm).
    pub backend_name: &'static str,
    /// Model-type strings this descriptor covers (presumably compared against
    /// the `model_type` field of a model's `config.json` — confirm).
    pub model_types: &'static [&'static str],
}
/// A [`TextDecoder`] that can also be driven as a full local inference backend.
///
/// Beyond raw decoding, an implementor:
/// - reports a name,
/// - reports which [`ModelCapability`] values it supports,
/// - may customize prompt rendering for a [`GenerateRequest`],
/// - may customize tool-call parsing of the generated output.
pub trait LocalInferenceBackend: TextDecoder {
    /// Identifier for this backend implementation.
    fn backend_name(&self) -> &'static str;

    /// Whether this backend can serve `cap`.
    fn supports_capability(&self, cap: ModelCapability) -> bool;

    /// Turns `req` into prompt text.
    ///
    /// The default uses the generic `render_chat_prompt` and never fails.
    /// Backends with a model-specific chat template can override it.
    fn render_prompt(&self, req: &GenerateRequest) -> Result<String, InferenceError> {
        let prompt = render_chat_prompt(req);
        Ok(prompt)
    }

    /// Extracts tool calls from generated `text`.
    ///
    /// Returns a `(String, Vec<ToolCall>)` pair — presumably the text with
    /// tool-call markup removed plus the parsed calls; confirm against the
    /// generic `parse_tool_calls`. The default delegates to that generic parser.
    fn parse_tool_calls(&self, text: &str) -> (String, Vec<ToolCall>) {
        // A bare path inside a method body resolves to the imported free
        // function, not to this trait method.
        parse_tool_calls(text)
    }
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
mod mlx_impl {
    //! Trait adapters for the in-process MLX backends.
    //!
    //! Two backends are exposed through [`TextDecoder`] and [`LocalInferenceBackend`]:
    //! - `MlxBackend`, reported as `native-mlx-qwen3`;
    //! - `Gemma4Backend`, reported as `native-mlx-gemma4`.
    //!
    //! Most trait methods forward to the backend's method of the same name
    //! via the `Type::method(self, ..)` form. NOTE(review): that forwarding
    //! relies on inherent methods existing with those names; without one,
    //! the path would resolve to the trait method and recurse — confirm.

    use super::*;
    use crate::backend::mlx_gemma4::{parse_gemma4_tool_calls, Gemma4Backend};
    use crate::backend::MlxBackend;

    /// Qwen3 decoder.
    ///
    /// Every method forwards to `MlxBackend` except `eos_ids`, which merges two
    /// stop tokens. `begin_prompt` is not overridden, so the trait default
    /// (clear the KV cache, return 0) applies.
    impl TextDecoder for MlxBackend {
        fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
            MlxBackend::encode(self, text)
        }

        fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError> {
            MlxBackend::decode(self, tokens)
        }

        fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Vec<f32>, InferenceError> {
            MlxBackend::forward(self, tokens, pos)
        }

        fn eos_ids(&self) -> Vec<u32> {
            // Stop on the model's configured EOS id and on the `<|im_end|>`
            // token, each at most once, configured EOS first.
            let configured = self.eos_token_id();
            let chat_end = self
                .token_id("<|im_end|>")
                .filter(|&id| configured != Some(id));
            configured.into_iter().chain(chat_end).collect()
        }

        fn context_length(&self) -> usize {
            MlxBackend::context_length(self)
        }

        fn clear_kv_cache(&mut self) {
            MlxBackend::clear_kv_cache(self)
        }
    }

    /// Qwen3 backend.
    ///
    /// Uses the trait defaults for prompt rendering and tool-call parsing.
    impl LocalInferenceBackend for MlxBackend {
        fn backend_name(&self) -> &'static str {
            "native-mlx-qwen3"
        }

        fn supports_capability(&self, cap: ModelCapability) -> bool {
            MlxBackend::supports_capability(self, cap)
        }
    }

    /// Gemma 4 decoder.
    ///
    /// Pure forwarding, including `begin_prompt`, which `Gemma4Backend`
    /// implements itself instead of using the trait default.
    impl TextDecoder for Gemma4Backend {
        fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
            Gemma4Backend::encode(self, text)
        }

        fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError> {
            Gemma4Backend::decode(self, tokens)
        }

        fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Vec<f32>, InferenceError> {
            Gemma4Backend::forward(self, tokens, pos)
        }

        fn eos_ids(&self) -> Vec<u32> {
            Gemma4Backend::eos_token_ids(self)
        }

        fn context_length(&self) -> usize {
            Gemma4Backend::context_length(self)
        }

        fn clear_kv_cache(&mut self) {
            Gemma4Backend::clear_kv_cache(self)
        }

        fn begin_prompt(&mut self, prompt_tokens: &[u32]) -> usize {
            Gemma4Backend::begin_prompt(self, prompt_tokens)
        }
    }

    /// Gemma 4 backend: text-only capabilities, the model's own chat template
    /// when one is available, and a Gemma 4–specific tool-call parser.
    impl LocalInferenceBackend for Gemma4Backend {
        fn backend_name(&self) -> &'static str {
            "native-mlx-gemma4"
        }

        fn supports_capability(&self, cap: ModelCapability) -> bool {
            use ModelCapability as C;
            // Kept exhaustive (no `_` arm) so that adding a capability forces a
            // deliberate decision here.
            match cap {
                C::Rerank
                | C::Embed
                | C::Grounding
                | C::Vision
                | C::VideoUnderstanding
                | C::AudioUnderstanding
                | C::SpeechToText
                | C::TextToSpeech
                | C::ImageGeneration
                | C::VideoGeneration => false,
                C::Generate
                | C::ToolUse
                | C::MultiToolCall
                | C::Reasoning
                | C::Summarize
                | C::Code
                | C::Classify => true,
            }
        }

        fn render_prompt(&self, req: &GenerateRequest) -> Result<String, InferenceError> {
            // Prefer the model-provided template; fall back to the generic renderer.
            if let Some(template) = self.chat_template() {
                template.render_request(req)
            } else {
                Ok(render_chat_prompt(req))
            }
        }

        fn parse_tool_calls(&self, text: &str) -> (String, Vec<ToolCall>) {
            parse_gemma4_tool_calls(text)
        }
    }
}
/// Reads `model_dir/config.json` and returns its top-level `model_type`
/// string, ASCII-lowercased.
///
/// A missing or non-string `model_type` yields an empty string, not an error.
///
/// # Errors
///
/// Returns `InferenceError::InferenceFailed` if the file cannot be read or is
/// not valid JSON. The message includes the config file's path.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn read_model_type(model_dir: &std::path::Path) -> Result<String, InferenceError> {
    let cfg_path = model_dir.join("config.json");
    let shown = cfg_path.display();

    let raw = std::fs::read_to_string(&cfg_path)
        .map_err(|e| InferenceError::InferenceFailed(format!("read {shown}: {e}")))?;
    let config: serde_json::Value = serde_json::from_str(&raw)
        .map_err(|e| InferenceError::InferenceFailed(format!("parse {shown}: {e}")))?;

    let model_type = config
        .get("model_type")
        .and_then(serde_json::Value::as_str)
        .unwrap_or_default();
    Ok(model_type.to_ascii_lowercase())
}
/// Loads the in-process MLX backend for the `model_type` declared in
/// `model_dir/config.json`.
///
/// Only the Gemma 4 unified model types (`gemma4_unified` and
/// `gemma4_unified_text`) are handled here; both load a `Gemma4Backend`.
///
/// # Errors
///
/// - Propagates failures from reading or parsing `config.json`.
/// - Propagates failures from `Gemma4Backend::load`.
/// - Returns `InferenceError::InferenceFailed` for any other `model_type`,
///   including a missing one, which is reported as `''`.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
pub fn local_backend_for(
    model_dir: &std::path::Path,
) -> Result<Box<dyn LocalInferenceBackend>, InferenceError> {
    use crate::backend::mlx_gemma4::Gemma4Backend;

    let model_type = read_model_type(model_dir)?;
    let is_gemma4 = matches!(model_type.as_str(), "gemma4_unified" | "gemma4_unified_text");
    if !is_gemma4 {
        // Qwen3 and all other architectures are served by other code paths,
        // per the message below.
        return Err(InferenceError::InferenceFailed(format!(
            "no in-process MLX backend for model_type '{model_type}' ({}); Qwen3 uses the \
             dedicated native path and other architectures route to vLLM-MLX",
            model_dir.display()
        )));
    }

    Ok(Box::new(Gemma4Backend::load(model_dir)?))
}