pub mod adaptive_router;
pub mod backend;
pub mod backend_cache;
pub mod handle;
pub mod hardware;
pub mod intent;
pub mod key_pool;
pub mod models;
pub mod outcome;
pub mod protocol;
pub mod registry;
pub mod remote;
pub mod router;
pub mod routing_ext;
pub mod runner;
pub mod schema;
pub mod service;
pub mod stream;
pub mod tasks;
pub mod vllm_mlx;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use reqwest::multipart::{Form, Part};
use serde::Serialize;
use thiserror::Error;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use tokio::process::Command;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tracing::{debug, instrument};
pub use adaptive_router::{
AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy,
};
pub use handle::InferenceHandle;
pub use intent::{IntentHint, TaskHint};
pub use key_pool::{KeyPool, KeyStats};
pub use outcome::{
CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
};
pub use registry::{
ModelFilter, ModelInfo, ModelRuntimeRequirement, ModelUpgrade, UnifiedRegistry,
};
pub use remote::RemoteBackend;
pub use routing_ext::{
CircuitBreaker, CircuitBreakerRegistry, CircuitState, ImplicitSignal, ImplicitSignalType,
RoutingMode, SpendControl, SpendLimitExceeded, SpendLimits, SpendStatus,
};
pub use runner::{
current_inference_runner, set_inference_runner, EventEmitter, InferenceRunner, RunnerError,
RunnerResult,
};
pub use schema::{
ApiProtocol, BenchmarkScore, CostModel, ModelCapability, ModelSchema, ModelSource,
PerformanceEnvelope, ProprietaryAuth,
};
pub use adaptive_router::TaskComplexity;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub use backend::CandleBackend;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub use backend::EmbeddingBackend;
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
pub use backend::MlxBackend;
pub use hardware::HardwareInfo;
pub use models::{ModelRegistry, ModelRole};
pub use router::{ModelRouter, RoutingDecision};
pub use stream::{StreamAccumulator, StreamEvent};
pub use tasks::{
parse_boxes, BoundingBox, ClassifyRequest, ClassifyResult, ContentBlock, EmbedRequest,
GenerateImageRequest, GenerateImageResult, GenerateParams, GenerateRequest,
GenerateVideoRequest, GenerateVideoResult, GroundRequest, GroundResult, Message, RerankRequest,
RerankResult, RerankedDocument, RoutingWorkload, SynthesizeRequest, SynthesizeResult,
ThinkingMode, ToolCall, TranscribeRequest, TranscribeResult, VideoMode,
};
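/// Errors surfaced by the inference engine and its backends.
///
/// `UnsupportedMode` is deliberately structured (mode / backend / reason) so
/// callers can distinguish "this backend cannot do that yet" from genuine
/// failures and reroute instead of retrying.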
#[derive(Error, Debug)]
pub enum InferenceError {
#[error("model not found: {0}")]
ModelNotFound(String),
#[error("model download failed: {0}")]
DownloadFailed(String),
#[error("inference failed: {0}")]
InferenceFailed(String),
#[error("mode {mode} not implemented on backend {backend}: {reason}")]
UnsupportedMode {
mode: &'static str,
backend: &'static str,
reason: &'static str,
},
#[error("tokenization error: {0}")]
TokenizationError(String),
#[error("device error: {0}")]
DeviceError(String),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
}
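/// Compute device used by the local backends. `auto()` resolves it from the
/// enabled cargo features: Metal first, then CUDA (device 0), then CPU.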
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
Cpu,
Metal,
Cuda(usize),
}
impl Device {
pub fn auto() -> Self {
#[cfg(feature = "metal")]
{
return Device::Metal;
}
#[cfg(feature = "cuda")]
{
return Device::Cuda(0);
}
#[cfg(not(any(feature = "metal", feature = "cuda")))]
{
Device::Cpu
}
}
}
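/// Engine configuration. `models_dir` defaults to `~/.car/models`, the
/// `*_model` fields are the built-in defaults, and the `preferred_*` fields
/// are explicit overrides that win over adaptive routing when set.
///
/// A minimal override sketch (assuming the model id exists in the registry):
///
/// ```ignore
/// let mut cfg = InferenceConfig::default();
/// cfg.preferred_generation_model = Some("mlx/qwen3-4b:4bit".to_string());
/// let engine = InferenceEngine::new(cfg);
/// ```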
#[derive(Debug, Clone)]
pub struct InferenceConfig {
pub models_dir: std::path::PathBuf,
pub device: Option<Device>,
pub generation_model: String,
pub preferred_generation_model: Option<String>,
pub embedding_model: String,
pub preferred_embedding_model: Option<String>,
pub classification_model: String,
pub preferred_classification_model: Option<String>,
}
impl Default for InferenceConfig {
fn default() -> Self {
let models_dir = dirs_next()
.unwrap_or_else(|| std::path::PathBuf::from("."))
.join(".car")
.join("models");
let hw = HardwareInfo::detect();
Self {
models_dir,
device: None,
generation_model: hw.recommended_model,
preferred_generation_model: None,
embedding_model: "Qwen3-Embedding-0.6B".to_string(),
preferred_embedding_model: None,
classification_model: "Qwen3-0.6B".to_string(),
preferred_classification_model: None,
}
}
}
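/// Minimal home-directory lookup in the spirit of the `dirs-next` crate: reads
/// `$HOME` directly, so it returns `None` on platforms where it is unset.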
fn dirs_next() -> Option<std::path::PathBuf> {
std::env::var("HOME").ok().map(std::path::PathBuf::from)
}
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
pub prompt_tokens: u64,
pub completion_tokens: u64,
pub total_tokens: u64,
pub context_window: u64,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InferenceResult {
pub text: String,
pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub bounding_boxes: Vec<crate::tasks::grounding::BoundingBox>,
pub trace_id: String,
pub model_used: String,
pub latency_ms: u64,
#[serde(default)]
pub time_to_first_token_ms: Option<u64>,
pub usage: Option<TokenUsage>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub provider_output_items: Vec<serde_json::Value>,
}
impl InferenceResult {
pub fn has_tool_calls(&self) -> bool {
!self.tool_calls.is_empty()
}
}
#[derive(Debug, Clone, Serialize)]
pub struct SpeechRuntimeHealth {
pub root: PathBuf,
pub installed: bool,
pub python: PathBuf,
pub stt_command: PathBuf,
pub tts_command: PathBuf,
pub configured_python: Option<String>,
pub detected_python: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct SpeechModelHealth {
pub id: String,
pub name: String,
pub provider: String,
pub capability: ModelCapability,
pub is_local: bool,
pub available: bool,
pub cached: bool,
pub selected_by_default: bool,
pub source: String,
}
#[derive(Debug, Clone, Serialize)]
pub struct SpeechHealthReport {
pub runtime: SpeechRuntimeHealth,
pub local_models: Vec<SpeechModelHealth>,
pub remote_models: Vec<SpeechModelHealth>,
pub elevenlabs_configured: bool,
pub prefer_local: bool,
pub allow_remote_fallback: bool,
pub preferred_local_stt: Option<String>,
pub preferred_local_tts: Option<String>,
pub preferred_remote_stt: Option<String>,
pub preferred_remote_tts: Option<String>,
pub local_stt_default: Option<String>,
pub local_tts_default: Option<String>,
pub remote_stt_default: Option<String>,
pub remote_tts_default: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct ModelDefaultHealth {
pub capability: ModelCapability,
pub configured_model: String,
pub available: bool,
pub is_local: bool,
pub provider: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct ModelProviderHealth {
pub provider: String,
pub configured: bool,
pub local_models: usize,
pub remote_models: usize,
pub available_models: usize,
pub capabilities: Vec<ModelCapability>,
}
#[derive(Debug, Clone, Serialize)]
pub struct ModelCapabilityHealth {
pub capability: ModelCapability,
pub total_models: usize,
pub available_models: usize,
pub local_available_models: usize,
pub remote_available_models: usize,
}
#[derive(Debug, Clone, Serialize)]
pub struct RoutingScenarioHealth {
pub name: String,
pub workload: RoutingWorkload,
pub task_family: String,
pub has_tools: bool,
pub has_vision: bool,
pub prefer_local: bool,
pub quality_first_cold_start: bool,
pub bootstrap_min_task_observations: u64,
pub bootstrap_quality_floor: f64,
pub model_id: String,
pub model_name: String,
pub reason: String,
pub strategy: RoutingStrategy,
}
#[derive(Debug, Clone, Serialize)]
pub struct ModelBenchmarkPriorHealth {
pub model_id: String,
pub model_name: Option<String>,
pub overall_score: f64,
pub overall_latency_ms: Option<f64>,
pub task_scores: std::collections::HashMap<String, f64>,
pub task_latency_ms: std::collections::HashMap<String, f64>,
pub source_path: PathBuf,
}
#[derive(Debug, Clone, Serialize)]
pub struct ModelHealthReport {
pub total_models: usize,
pub available_models: usize,
pub local_models: usize,
pub remote_models: usize,
pub defaults: Vec<ModelDefaultHealth>,
pub providers: Vec<ModelProviderHealth>,
pub capabilities: Vec<ModelCapabilityHealth>,
pub routing_prefer_local: bool,
pub routing_quality_first_cold_start: bool,
pub routing_min_observations: u64,
pub routing_bootstrap_min_task_observations: u64,
pub routing_bootstrap_quality_floor: f64,
pub routing_quality_weight: f64,
pub routing_latency_weight: f64,
pub routing_cost_weight: f64,
pub routing_scenarios: Vec<RoutingScenarioHealth>,
pub benchmark_priors: Vec<ModelBenchmarkPriorHealth>,
pub speech: SpeechHealthReport,
}
#[derive(Debug, Clone, Serialize)]
pub struct SpeechInstallReport {
pub name: String,
pub hf_repo: String,
pub snapshot_path: PathBuf,
pub files_downloaded: usize,
}
#[derive(Debug, Clone, Serialize)]
pub struct SpeechSmokePathReport {
pub path: String,
pub tts_model: String,
pub stt_model: String,
pub audio_path: PathBuf,
pub transcript: String,
}
#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechSmokeReport {
pub local: Option<SpeechSmokePathReport>,
pub remote: Option<SpeechSmokePathReport>,
pub skipped: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechPolicy {
pub prefer_local: bool,
pub allow_remote_fallback: bool,
pub preferred_local_stt: Option<String>,
pub preferred_local_tts: Option<String>,
pub preferred_remote_stt: Option<String>,
pub preferred_remote_tts: Option<String>,
}
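/// Top-level inference engine: wraps the unified model registry, the adaptive
/// and legacy routers, the outcome tracker used for learned routing, the
/// remote (API) backend, and the platform-gated local backends (MLX caches on
/// Apple Silicon, lazily initialized Candle backends elsewhere).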
pub struct InferenceEngine {
pub config: InferenceConfig,
pub unified_registry: UnifiedRegistry,
pub adaptive_router: AdaptiveRouter,
pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
remote_backend: RemoteBackend,
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
mlx_backends: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
flux_cache: Arc<backend_cache::BackendCache<backend::mlx_flux::FluxBackend>>,
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
ltx_cache: Arc<backend_cache::BackendCache<backend::mlx_ltx::LtxBackend>>,
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
kokoro_cache: Arc<backend_cache::BackendCache<backend::mlx_kokoro::KokoroBackend>>,
pub registry: models::ModelRegistry,
pub router: ModelRouter,
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
backend: Arc<RwLock<Option<CandleBackend>>>,
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
speech_runtime: Arc<Mutex<Option<SpeechRuntime>>>,
speech_policy: SpeechPolicy,
}
impl InferenceEngine {
fn preferred_model_for_capability(&self, capability: ModelCapability) -> Option<&str> {
match capability {
ModelCapability::Generate => self.config.preferred_generation_model.as_deref(),
ModelCapability::Embed => self.config.preferred_embedding_model.as_deref(),
ModelCapability::Classify => self.config.preferred_classification_model.as_deref(),
_ => None,
}
}
fn request_needs_vision(req: &GenerateRequest) -> bool {
req.images.as_ref().is_some_and(|images| !images.is_empty())
|| req.messages.as_ref().is_some_and(|messages| {
messages
.iter()
.any(|msg| matches!(msg, Message::UserMultimodal { .. }))
})
}
fn request_has_video(req: &GenerateRequest) -> bool {
let images_have_video = req
.images
.as_ref()
.is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
let messages_have_video = req.messages.as_ref().is_some_and(|messages| {
messages.iter().any(|msg| match msg {
Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_video),
_ => false,
})
});
images_have_video || messages_have_video
}
fn request_has_audio(req: &GenerateRequest) -> bool {
let images_have_audio = req
.images
.as_ref()
.is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
let messages_have_audio = req.messages.as_ref().is_some_and(|messages| {
messages.iter().any(|msg| match msg {
Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_audio),
_ => false,
})
});
images_have_audio || messages_have_audio
}
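/// Builds an engine from `config`: detects hardware, constructs the routers,
/// restores persisted outcome profiles from `models_dir/outcome_profiles.json`,
/// and seeds cold-start routing with any benchmark priors found on disk.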
pub fn new(config: InferenceConfig) -> Self {
let registry = models::ModelRegistry::new(config.models_dir.clone());
let hw = HardwareInfo::detect();
let router = ModelRouter::new(hw.clone());
let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
let adaptive_router = AdaptiveRouter::with_default_config(hw);
let mut tracker = OutcomeTracker::new();
let profiles_path = config.models_dir.join("outcome_profiles.json");
if let Ok(n) = tracker.load_from_file(&profiles_path) {
if n > 0 {
tracing::info!(loaded = n, "loaded persisted model profiles");
}
}
let mut benchmark_models_loaded = 0usize;
for path in benchmark_priors_paths(&config.models_dir) {
match routing_ext::load_benchmark_priors(&path) {
Ok(priors) if !priors.is_empty() => {
benchmark_models_loaded += priors.len();
routing_ext::apply_benchmark_priors(&mut tracker, &priors);
tracing::info!(
path = %path.display(),
loaded = priors.len(),
"loaded benchmark quality priors"
);
}
Ok(_) => {}
Err(error) => {
tracing::warn!(path = %path.display(), %error, "failed to load benchmark priors");
}
}
}
if benchmark_models_loaded > 0 {
tracing::info!(
loaded = benchmark_models_loaded,
"applied benchmark priors to cold-start routing"
);
}
let outcome_tracker = Arc::new(RwLock::new(tracker));
let remote_backend = RemoteBackend::new();
Self {
config,
unified_registry,
adaptive_router,
outcome_tracker,
remote_backend,
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
mlx_backends: Arc::new(backend_cache::BackendCache::from_env()),
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
flux_cache: Arc::new(backend_cache::BackendCache::from_env()),
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
ltx_cache: Arc::new(backend_cache::BackendCache::from_env()),
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
kokoro_cache: Arc::new(backend_cache::BackendCache::from_env()),
registry,
router,
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
backend: Arc::new(RwLock::new(None)),
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
embedding_backend: Arc::new(RwLock::new(None)),
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
speech_runtime: Arc::new(Mutex::new(None)),
speech_policy: SpeechPolicy {
prefer_local: cfg!(all(
target_os = "macos",
target_arch = "aarch64",
not(car_skip_mlx)
)),
allow_remote_fallback: true,
preferred_local_stt: None,
preferred_local_tts: None,
preferred_remote_stt: None,
preferred_remote_tts: None,
},
}
}
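/// Registers API keys for every remote model in the registry and restores
/// persisted per-key usage stats from `models_dir/key_pool_stats.json`.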
pub async fn init_key_pool(&self) {
for schema in self.unified_registry.list() {
if schema.is_remote() {
self.remote_backend.register_model_keys(schema).await;
}
}
let stats_path = self.config.models_dir.join("key_pool_stats.json");
if let Ok(n) = self.remote_backend.key_pool.load_stats(&stats_path).await {
if n > 0 {
tracing::info!(loaded = n, "loaded persisted key pool stats");
}
}
let total = self.remote_backend.key_pool.total_keys().await;
if total > 0 {
tracing::info!(keys = total, "key pool initialized");
}
}
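/// Lazily initializes the Candle backend with double-checked locking: a read
/// lock for the fast path, then a write lock that re-checks before loading.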
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
let read = self.backend.read().await;
if read.is_some() {
return Ok(());
}
drop(read);
let mut write = self.backend.write().await;
if write.is_some() {
return Ok(());
}
let model_path = self.registry.ensure_model(model_name).await?;
let device = self.config.device.unwrap_or_else(Device::auto);
let backend = CandleBackend::load(&model_path, device)?;
*write = Some(backend);
Ok(())
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
let read = self.embedding_backend.read().await;
if read.is_some() {
return Ok(());
}
drop(read);
let mut write = self.embedding_backend.write().await;
if write.is_some() {
return Ok(());
}
let embedding_model = self
.preferred_model_for_capability(ModelCapability::Embed)
.unwrap_or(&self.config.embedding_model);
let model_path = self.registry.ensure_model(embedding_model).await?;
let device = self.config.device.unwrap_or_else(Device::auto);
let backend = EmbeddingBackend::load(&model_path, device)?;
*write = Some(backend);
Ok(())
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn ensure_mlx_embedding_backend(&self) -> Result<String, InferenceError> {
let embedding_model_name = self
.preferred_model_for_capability(ModelCapability::Embed)
.unwrap_or(&self.config.embedding_model)
.to_string();
let schema = self
.unified_registry
.get(&embedding_model_name)
.or_else(|| self.unified_registry.find_by_name(&embedding_model_name))
.ok_or_else(|| InferenceError::ModelNotFound(embedding_model_name.clone()))?
.clone();
self.ensure_mlx_backend(&schema).await?;
Ok(schema.id)
}
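/// Resolves `schema` to a local model directory and returns a cached MLX
/// backend handle, loading it on a cache miss. The load is wrapped in
/// `catch_unwind` because Metal/Accelerate failures surface as panics.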
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn ensure_mlx_backend(
&self,
schema: &ModelSchema,
) -> Result<backend_cache::CachedBackend<backend::MlxBackend>, InferenceError> {
if !Self::supports_native_mlx(schema) {
return Err(InferenceError::InferenceFailed(format!(
"native MLX backend does not support {} ({}) yet; use vLLM-MLX or add a family-specific MLX backend",
schema.name, schema.family
)));
}
let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
let size = backend_cache::estimate_model_size(&model_dir);
let cache = Arc::clone(&self.mlx_backends);
let key = schema.id.clone();
cache.get_or_load(&key, size, move || {
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
backend::MlxBackend::load(&model_dir)
}))
.map_err(|e| {
InferenceError::InferenceFailed(format!(
"MLX backend loading panicked (possible Metal/accelerate exception): {:?}",
e
))
})?
})
}
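/// Pre-loads the given schema ids into the matching backend cache (Flux for
/// image generation, LTX for video, Kokoro for TTS, the MLX text/VL backend
/// otherwise) so the first request does not pay the load cost; the non-MLX
/// build below is a no-op stub. A hypothetical call sketch:
///
/// ```ignore
/// // Id is illustrative; use whatever the unified registry actually lists.
/// let results = engine.warm_up(&["mlx/qwen3-4b:4bit"]).await;
/// ```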
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
pub async fn warm_up<S: AsRef<str>>(
&self,
schema_ids: &[S],
) -> Vec<Result<(), InferenceError>> {
let mut results = Vec::with_capacity(schema_ids.len());
for id in schema_ids {
let id = id.as_ref();
let outcome: Result<(), InferenceError> = async {
let schema = self.unified_registry.get(id).cloned().ok_or_else(|| {
InferenceError::InferenceFailed(format!("warm_up: unknown schema id {id}"))
})?;
match schema.capabilities.first().copied() {
Some(ModelCapability::ImageGeneration) => {
let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
let size = backend_cache::estimate_model_size(&model_dir);
let _ = self.flux_cache.get_or_load(&schema.id, size, || {
backend::mlx_flux::FluxBackend::load(&model_dir)
})?;
}
Some(ModelCapability::VideoGeneration) => {
let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
let size = backend_cache::estimate_model_size(&model_dir);
let _ = self.ltx_cache.get_or_load(&schema.id, size, || {
backend::mlx_ltx::LtxBackend::load(&model_dir)
})?;
}
Some(ModelCapability::TextToSpeech) => {
let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
let size = backend_cache::estimate_model_size(&model_dir);
let _ = self.kokoro_cache.get_or_load(&schema.id, size, || {
backend::mlx_kokoro::KokoroBackend::load(&model_dir)
})?;
}
_ => {
let _ = self.ensure_mlx_backend(&schema).await?;
}
}
Ok(())
}
.await;
results.push(outcome);
}
results
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub async fn warm_up<S: AsRef<str>>(
&self,
_schema_ids: &[S],
) -> Vec<Result<(), InferenceError>> {
Vec::new()
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn supports_native_mlx(schema: &ModelSchema) -> bool {
matches!(schema.family.as_str(), "qwen3" | "qwen2.5-vl" | "qwen2-vl")
}
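/// Routes a bare prompt. An explicit preferred generation model short-circuits
/// routing; otherwise the adaptive router picks a model from prompt complexity
/// and the recorded outcome history.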
pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
if let Some(model) = self.preferred_model_for_capability(ModelCapability::Generate) {
let ctx_len = self
.unified_registry
.get(model)
.or_else(|| self.unified_registry.find_by_name(model))
.map(|s| s.context_length)
.unwrap_or(0);
return AdaptiveRoutingDecision {
model_id: model.to_string(),
model_name: model.to_string(),
task: InferenceTask::Generate,
complexity: TaskComplexity::assess(prompt),
reason: "preferred generation model override".into(),
strategy: RoutingStrategy::Explicit,
predicted_quality: 0.5,
fallbacks: vec![],
context_length: ctx_len,
needs_compaction: false,
};
}
let tracker = self.outcome_tracker.read().await;
self.adaptive_router
.route(prompt, &self.unified_registry, &tracker)
}
pub fn route(&self, prompt: &str) -> RoutingDecision {
self.router.route_generate(prompt, &self.registry)
}
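/// Estimates `(input_tokens, context_window, fits)` for a request. The input
/// estimate covers prompt, context, and serialized tools; a context window of
/// 0 means "unknown" and is treated as fitting.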
pub fn estimated_tokens(
&self,
req: &GenerateRequest,
model_id: Option<&str>,
) -> (usize, usize, bool) {
let prompt_tokens = remote::estimate_tokens(&req.prompt);
let context_tokens = req
.context
.as_ref()
.map(|c| remote::estimate_tokens(c))
.unwrap_or(0);
let tools_tokens = req
.tools
.as_ref()
.map(|t| remote::estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
.unwrap_or(0);
let total_input = prompt_tokens + context_tokens + tools_tokens;
let context_window = model_id
.and_then(|id| {
self.unified_registry
.get(id)
.or_else(|| self.unified_registry.find_by_name(id))
})
.map(|s| s.context_length)
.unwrap_or(0);
let fits = context_window == 0 || (total_input + req.params.max_tokens) <= context_window;
(total_input, context_window, fits)
}
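/// Full generation entry point: routes the request (explicit model override,
/// intent-aware, or context-aware adaptive routing), then walks the chosen
/// model plus its fallbacks, recording outcomes and circuit-breaker signals
/// for every attempt. A minimal call sketch (the field set mirrors the
/// `GenerateRequest` constructions used elsewhere in this file):
///
/// ```ignore
/// let req = GenerateRequest {
///     prompt: "Summarize the design in two sentences.".to_string(),
///     model: None, // let the adaptive router pick
///     params: GenerateParams::default(),
///     context: None,
///     tools: None,
///     images: None,
///     messages: None,
///     cache_control: false,
///     response_format: None,
///     intent: None,
/// };
/// let result = engine.generate_tracked(req).await?;
/// println!("{} via {}", result.text, result.model_used);
/// ```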
#[instrument(
name = "inference.generate",
skip_all,
fields(
model = tracing::field::Empty,
max_tokens = req.params.max_tokens,
prompt_tokens = tracing::field::Empty,
completion_tokens = tracing::field::Empty,
latency_ms = tracing::field::Empty,
)
)]
pub async fn generate_tracked(
&self,
req: GenerateRequest,
) -> Result<InferenceResult, InferenceError> {
let start = Instant::now();
let (estimated_input, _, _) = self.estimated_tokens(&req, None);
let tracker_read = self.outcome_tracker.read().await;
let has_tools = req.tools.is_some();
let has_vision = Self::request_needs_vision(&req);
let preferred_model = self
.preferred_model_for_capability(ModelCapability::Generate)
.map(str::to_string);
let decision = match req.model.clone().or(preferred_model) {
Some(m) => {
let ctx_len = self
.unified_registry
.get(&m)
.or_else(|| self.unified_registry.find_by_name(&m))
.map(|s| s.context_length)
.unwrap_or(0);
AdaptiveRoutingDecision {
model_id: m.clone(),
model_name: m.clone(),
task: InferenceTask::Generate,
complexity: TaskComplexity::assess(&req.prompt),
reason: "explicit model".into(),
strategy: RoutingStrategy::Explicit,
predicted_quality: 0.5,
fallbacks: vec![],
context_length: ctx_len,
needs_compaction: ctx_len > 0 && estimated_input > ctx_len,
}
}
None => match &req.intent {
Some(hint) => self.adaptive_router.route_context_aware_with_intent(
&req.prompt,
estimated_input,
&self.unified_registry,
&tracker_read,
has_tools,
has_vision,
req.params.workload,
hint,
),
None => self.adaptive_router.route_context_aware(
&req.prompt,
estimated_input,
&self.unified_registry,
&tracker_read,
has_tools,
has_vision,
req.params.workload,
),
},
};
drop(tracker_read);
if decision.needs_compaction {
tracing::info!(
model = %decision.model_name,
prompt_tokens = estimated_input,
context_window = decision.context_length,
"prompt exceeds model context window — compaction or truncation needed"
);
}
let trace_id = {
let mut tracker = self.outcome_tracker.write().await;
tracker.record_start(&decision.model_id, decision.task, &decision.reason)
};
debug!(
model = %decision.model_name,
strategy = ?decision.strategy,
reason = %decision.reason,
trace = %trace_id,
"adaptive-routed generate request"
);
let mut req = req;
if req.params.budget_tokens == 0 && matches!(decision.complexity, TaskComplexity::Complex) {
let supports_thinking = self
.unified_registry
.get(&decision.model_id)
.map(|s| {
s.supported_params
.contains(&schema::GenerateParam::ExtendedThinking)
})
.unwrap_or(false);
if supports_thinking {
req.params.budget_tokens = 8000;
tracing::info!(model = %decision.model_name, budget = 8000, "auto-enabled extended thinking for complex task");
}
}
let mut models_to_try = vec![decision.model_id.clone()];
models_to_try.extend(decision.fallbacks.iter().cloned());
let mut last_error = None;
for candidate_id in &models_to_try {
#[allow(unused_mut)]
let mut schema = self
.unified_registry
.get(candidate_id)
.or_else(|| self.unified_registry.find_by_name(candidate_id))
.cloned();
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
if let Some(ref s) = schema {
if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
tracing::info!(
from = %s.id, to = %mlx_equiv.id,
"redirecting GGUF model to MLX equivalent on Apple Silicon"
);
schema = Some(mlx_equiv.clone());
}
}
let candidate_name = schema
.as_ref()
.map(|s| s.name.clone())
.unwrap_or_else(|| candidate_id.clone());
let is_remote = schema
.as_ref()
.map(|s| s.is_remote() || s.is_vllm_mlx())
.unwrap_or(false);
let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
if is_delegated {
let runner = match runner::current_inference_runner() {
Some(r) => r,
None => {
last_error = Some(InferenceError::InferenceFailed(
"model declares ModelSource::Delegated but no inference runner is registered"
.into(),
));
continue;
}
};
let (tx, mut rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
let emitter = runner::EventEmitter::new(tx);
let runner_req = req.clone();
let runner_handle =
tokio::spawn(async move { runner.run(runner_req, emitter).await });
let mut accumulator = stream::StreamAccumulator::default();
while let Some(evt) = rx.recv().await {
accumulator.push(&evt);
}
let (acc_text, acc_tool_calls) = accumulator.finish();
match runner_handle.await {
Ok(Ok(_runner_result)) => {
let elapsed = start.elapsed().as_millis() as u64;
let mut tracker = self.outcome_tracker.write().await;
tracker.record_complete(&trace_id, elapsed, 0, 0);
return Ok(InferenceResult {
text: acc_text,
tool_calls: acc_tool_calls,
bounding_boxes: vec![],
trace_id,
model_used: candidate_name,
latency_ms: elapsed,
time_to_first_token_ms: None,
usage: None,
provider_output_items: vec![],
});
}
Ok(Err(e)) => {
last_error = Some(InferenceError::InferenceFailed(e.to_string()));
continue;
}
Err(join_err) => {
last_error = Some(InferenceError::InferenceFailed(format!(
"runner task panicked: {join_err}"
)));
continue;
}
}
}
let has_tools = req.tools.is_some();
let context = if has_tools
&& req.tools.as_ref().map_or(false, |t| {
t.iter().any(|tool| {
tool.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str())
== Some("done")
})
}) {
let base = req.context.as_deref().unwrap_or("");
Some(format!(
"{base}\n\nIMPORTANT: When calling the `done` tool, the `result` field MUST contain a DETAILED summary of everything you found and did. This is the ONLY output the user sees. Do NOT just say 'completed' — include specific findings, data, and conclusions."
))
} else {
req.context.clone()
};
let result = if is_remote {
let schema_val = schema.unwrap();
let _ctx_len = schema_val.context_length;
let temperature = if !schema_val.supported_params.is_empty()
&& !schema_val
.supported_params
.contains(&crate::schema::GenerateParam::Temperature)
{
-1.0
} else {
req.params.temperature
};
self.remote_backend
.generate_with_tools_multi(
&schema_val,
&req.prompt,
context.as_deref(),
temperature,
req.params.max_tokens,
req.tools.as_deref(),
req.images.as_deref(),
req.messages.as_deref(),
req.params.tool_choice.as_deref(),
req.params.parallel_tool_calls,
req.params.budget_tokens,
req.cache_control,
req.response_format.as_ref(),
)
.await
.map(|(t, c, u)| (t, c, u, None::<u64>))
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
let schema_ref = schema
.as_ref()
.ok_or_else(|| InferenceError::ModelNotFound(candidate_id.clone()))?;
if schema_ref.is_foundation_models() {
if has_tools {
Err(InferenceError::UnsupportedMode {
mode: "tool-use",
backend: "foundation-models",
reason: "tool calling not yet wired through the FoundationModels \
bridge — route to a remote model with ToolUse capability",
})
} else if Self::request_has_video(&req)
|| Self::request_has_audio(&req)
|| req.images.as_ref().is_some_and(|imgs| !imgs.is_empty())
{
Err(InferenceError::UnsupportedMode {
mode: "multimodal-content",
backend: "foundation-models",
reason: "the FoundationModels bridge currently exposes text-only \
generation — route image/audio/video to a remote VL model",
})
} else {
let prompt = req.prompt.clone();
let instructions = context.clone();
let max_tokens = req.params.max_tokens as u32;
let temperature = req.params.temperature;
tokio::task::spawn_blocking(move || {
crate::backend::foundation_models::generate(
&prompt,
instructions.as_deref(),
max_tokens,
temperature as f32,
)
})
.await
.map_err(|e| {
InferenceError::InferenceFailed(format!(
"FoundationModels task panicked: {e}"
))
})
.and_then(|r| r)
.map(|text| (text, vec![], None, None))
}
} else if !schema_ref.is_mlx() {
Err(InferenceError::InferenceFailed(format!(
"model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
schema_ref.id
)))
} else if schema_ref.tags.iter().any(|t| t == "mlx-vlm-cli") {
let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
if !has_images {
return Err(InferenceError::UnsupportedMode {
mode: "text-only-on-mlx-vlm-id",
backend: "mlx-vlm-cli",
reason: "the `mlx-vlm/...` model IDs route exclusively \
through the mlx-vlm CLI for image inference. \
For text-only generation, route to a Qwen3 \
text model (`mlx/qwen3-4b:4bit` etc.) — the \
CLI shell-out has higher latency than the \
in-process MLX text tower.",
});
}
let vlm_status = crate::backend::mlx_vlm_cli::runtime_status();
if !vlm_status.is_available() {
return Err(InferenceError::InferenceFailed(vlm_status.user_message()));
}
let repo = match &schema_ref.source {
crate::schema::ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
_ => {
return Err(InferenceError::InferenceFailed(format!(
"model '{}' is tagged mlx-vlm-cli but its \
source isn't ModelSource::Mlx — registry bug",
schema_ref.id
)));
}
};
let imgs = req.images.clone().unwrap_or_default();
let temp = req.params.temperature;
let max_t = req.params.max_tokens;
let prompt = req.prompt.clone();
let text = tokio::task::spawn_blocking(move || {
crate::backend::mlx_vlm_cli::generate(
&repo, &prompt, &imgs, temp, max_t,
)
})
.await
.map_err(|e| {
InferenceError::InferenceFailed(format!(
"mlx_vlm CLI task panicked: {e}"
))
})??;
let bounding_boxes = parse_boxes(&text);
let latency_ms = start.elapsed().as_millis() as u64;
{
let mut tracker = self.outcome_tracker.write().await;
tracker.record_complete(&trace_id, latency_ms, 0, 0);
}
return Ok(InferenceResult {
text,
tool_calls: vec![],
bounding_boxes,
trace_id: trace_id.clone(),
model_used: schema_ref.id.clone(),
latency_ms,
time_to_first_token_ms: None,
usage: None,
provider_output_items: Vec::new(),
});
} else {
let handle = self.ensure_mlx_backend(schema_ref).await?;
if Self::request_has_video(&req) {
return Err(InferenceError::UnsupportedMode {
mode: "video-content-block",
backend: "native-mlx-qwen25vl",
reason: "Qwen2.5-VL video understanding is on the request surface \
but the video-tokenization path (frame sampling + merger) \
is not yet wired; route to a remote VL provider for now",
});
}
if Self::request_has_audio(&req) {
return Err(InferenceError::UnsupportedMode {
mode: "audio-content-block",
backend: "native-mlx-qwen25vl",
reason: "audio understanding is on the request surface (Gemma 4 \
E2B/E4B and Gemini accept it) but the native MLX path \
for this model does not — route to Gemini or Gemma-4",
});
}
let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
if has_images {
let can_do_vision = {
let guard = handle.lock().map_err(|_| {
InferenceError::InferenceFailed(
"MLX backend mutex poisoned".into(),
)
})?;
guard.supports_capability(crate::schema::ModelCapability::Vision)
};
if !can_do_vision {
return Err(InferenceError::UnsupportedMode {
mode: "image-content-block",
backend: "native-mlx-text",
reason: "this MLX backend is a plain Qwen3 text tower. \
For local image inference, route to \
`mlx-vlm/qwen3-vl-2b:bf16` or another `mlx-vlm/...` \
catalog ID so CAR shells out to `mlx_vlm.generate`. \
Alternatives: a local vLLM-MLX VLM server, or a \
remote VL model. (#115)",
});
}
}
self.generate_mlx(req.clone(), &schema_ref.id)
.await
.map(|(text, ttft)| (text, vec![], None, ttft))
}
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
{
match self.ensure_backend(&candidate_name).await {
Ok(()) => {
let mut write = self.backend.write().await;
let backend = write.as_mut().unwrap();
tasks::generate::generate(backend, req.clone())
.await
.map(|(text, ttft)| (text, vec![], None, ttft))
}
Err(e) => Err(e),
}
}
};
match result {
Ok((text, tool_calls, usage, time_to_first_token_ms)) => {
let latency_ms = start.elapsed().as_millis() as u64;
let estimated_tokens = usage
.as_ref()
.map(|u| u.completion_tokens as usize)
.unwrap_or_else(|| text.split_whitespace().count());
{
let mut tracker = self.outcome_tracker.write().await;
tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
}
if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
cb.record_success(candidate_id);
}
self.auto_save_outcomes().await;
let span = tracing::Span::current();
span.record("model", candidate_name.as_str());
span.record("latency_ms", latency_ms);
if let Some(ttft) = time_to_first_token_ms {
span.record("ttft_ms", ttft);
}
if let Some(ref u) = usage {
span.record("prompt_tokens", u.prompt_tokens);
span.record("completion_tokens", u.completion_tokens);
}
let bounding_boxes = tasks::grounding::parse_boxes(&text);
return Ok(InferenceResult {
text,
tool_calls,
bounding_boxes,
trace_id,
model_used: candidate_name,
latency_ms,
time_to_first_token_ms,
usage,
provider_output_items: Vec::new(),
});
}
Err(e) => {
tracing::warn!(
model = %candidate_name,
error = %e,
remaining = models_to_try.len().saturating_sub(
models_to_try.iter().position(|m| m == candidate_id).unwrap_or(0) + 1
),
"model failed, trying next fallback immediately"
);
{
let mut tracker = self.outcome_tracker.write().await;
let fail_trace =
tracker.record_start(candidate_id, decision.task, "fallback");
tracker.record_failure(&fail_trace, &e.to_string());
}
{
let err_str = e.to_string();
let is_client_error =
err_str.contains("API returned 4") && !err_str.contains("429");
if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
cb.record_failure(candidate_id);
if is_client_error {
cb.record_failure(candidate_id);
}
}
}
#[cfg(not(all(
target_os = "macos",
target_arch = "aarch64",
not(car_skip_mlx)
)))]
{
let mut write = self.backend.write().await;
*write = None;
}
last_error = Some(e);
}
}
}
let e = last_error.unwrap_or(InferenceError::InferenceFailed(
"no models available".into(),
));
{
let mut tracker = self.outcome_tracker.write().await;
tracker.record_failure(&trace_id, &e.to_string());
}
self.auto_save_outcomes().await;
Err(e)
}
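/// Streaming variant of `generate_tracked`: returns a channel of
/// `StreamEvent`s. Delegated models stream through the registered runner,
/// remote models through the remote backend's streaming path, and local
/// models through a spawned MLX (blocking) or Candle (async) decode loop.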
pub async fn generate_tracked_stream(
&self,
req: GenerateRequest,
) -> Result<tokio::sync::mpsc::Receiver<stream::StreamEvent>, InferenceError> {
let has_tools = req.tools.is_some();
let has_vision = Self::request_needs_vision(&req);
let preferred_model = self
.preferred_model_for_capability(ModelCapability::Generate)
.map(str::to_string);
let decision = match req.model.clone().or(preferred_model) {
Some(m) => {
let ctx_len = self
.unified_registry
.get(&m)
.or_else(|| self.unified_registry.find_by_name(&m))
.map(|s| s.context_length)
.unwrap_or(0);
AdaptiveRoutingDecision {
model_id: m.clone(),
model_name: m,
task: InferenceTask::Generate,
complexity: TaskComplexity::assess(&req.prompt),
reason: "explicit model".into(),
strategy: RoutingStrategy::Explicit,
predicted_quality: 0.5,
fallbacks: vec![],
context_length: ctx_len,
needs_compaction: false,
}
}
None => {
let tracker_read = self.outcome_tracker.read().await;
if has_vision {
self.adaptive_router.route_with_vision(
&req.prompt,
&self.unified_registry,
&tracker_read,
has_tools,
)
} else if has_tools {
self.adaptive_router.route_with_tools(
&req.prompt,
&self.unified_registry,
&tracker_read,
)
} else {
self.adaptive_router
.route(&req.prompt, &self.unified_registry, &tracker_read)
}
}
};
#[allow(unused_mut)]
let mut schema = self
.unified_registry
.get(&decision.model_id)
.or_else(|| self.unified_registry.find_by_name(&decision.model_id))
.cloned();
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
if let Some(ref s) = schema {
if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
tracing::info!(
from = %s.id, to = %mlx_equiv.id,
"redirecting GGUF model to MLX equivalent on Apple Silicon (stream)"
);
schema = Some(mlx_equiv.clone());
}
}
let is_remote = schema
.as_ref()
.map(|s| s.is_remote() || s.is_vllm_mlx())
.unwrap_or(false);
let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
if is_delegated {
let runner = runner::current_inference_runner().ok_or_else(|| {
InferenceError::InferenceFailed(
"model declares ModelSource::Delegated but no inference runner is registered \
(call set_inference_runner / registerInferenceRunner / register_inference_runner)"
.into(),
)
})?;
let (tx, rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
let emitter = runner::EventEmitter::new(tx);
let request = req.clone();
tokio::spawn(async move {
if let Err(e) = runner.run(request, emitter).await {
tracing::warn!(error = %e, "delegated inference runner failed");
}
});
return Ok(rx);
}
if is_remote {
let schema = schema.unwrap();
self.remote_backend.register_model_keys(&schema).await;
self.remote_backend
.generate_stream(
&schema,
&req.prompt,
req.context.as_deref(),
req.params.temperature,
req.params.max_tokens,
req.tools.as_deref(),
req.images.as_deref(),
req.params.tool_choice.as_deref(),
req.params.parallel_tool_calls,
req.response_format.as_ref(),
)
.await
} else {
let schema =
schema.ok_or_else(|| InferenceError::ModelNotFound(decision.model_id.clone()))?;
let (tx, rx) = tokio::sync::mpsc::channel(64);
#[cfg(any(
all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
all(target_os = "ios", target_arch = "aarch64")
))]
{
if schema.is_foundation_models() {
let prompt = req.prompt.clone();
let instructions = req.context.clone();
let max_tokens = req.params.max_tokens as u32;
let temperature = req.params.temperature;
let tx_clone = tx.clone();
tokio::task::spawn_blocking(move || {
let accum = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
let accum_cb = accum.clone();
let cb = crate::backend::foundation_models::StreamCallback::new(
move |delta: &str| {
if let Ok(mut g) = accum_cb.lock() {
g.push_str(delta);
}
tx_clone
.blocking_send(stream::StreamEvent::TextDelta(
delta.to_string(),
))
.is_ok()
},
);
let result = crate::backend::foundation_models::stream(
&prompt,
instructions.as_deref(),
max_tokens,
temperature as f32,
cb,
);
let final_text = accum.lock().map(|g| g.clone()).unwrap_or_default();
let _ = tx.blocking_send(stream::StreamEvent::Done {
text: final_text,
tool_calls: vec![],
});
result
});
return Ok(rx);
}
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
if !schema.is_mlx() {
return Err(InferenceError::InferenceFailed(format!(
"model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
schema.id
)));
}
let backend = self.ensure_mlx_backend(&schema).await?;
let model_id = schema.id.clone();
let cache = Arc::clone(&self.mlx_backends);
tokio::task::spawn_blocking(move || {
let _ = Self::stream_local_mlx(backend, cache, model_id, req, tx);
});
return Ok(rx);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
{
self.ensure_backend(&schema.name).await?;
let backend = self.backend.clone();
tokio::spawn(async move {
let _ = Self::stream_local_candle(backend, req, tx).await;
});
Ok(rx)
}
}
}
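/// Token-by-token streaming loop on the Candle backend: applies the chat
/// template, left-truncates the prompt to leave generation headroom, emits a
/// `TextDelta` per decoded token, and closes with a `Done` event carrying the
/// full (thinking-stripped) text.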
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn stream_local_candle(
backend_lock: Arc<RwLock<Option<CandleBackend>>>,
req: GenerateRequest,
tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
) -> Result<(), InferenceError> {
let mut write = backend_lock.write().await;
let backend = write
.as_mut()
.ok_or_else(|| InferenceError::InferenceFailed("backend not initialized".into()))?;
backend.clear_kv_cache();
let formatted = tasks::generate::apply_chat_template(
&req.prompt,
req.context.as_deref(),
req.params.thinking,
);
let tokens = backend.encode(&formatted)?;
let eos = backend.eos_token_id();
let eos_alt = backend.token_id("<|im_end|>");
let params = &req.params;
if tokens.is_empty() {
let _ = tx
.send(stream::StreamEvent::Done {
text: String::new(),
tool_calls: vec![],
})
.await;
return Ok(());
}
let max_ctx = backend.context_length().unwrap_or(32768);
let headroom = params.max_tokens.min(max_ctx / 4);
let max_prompt = max_ctx.saturating_sub(headroom);
let tokens = if tokens.len() > max_prompt {
tokens[tokens.len() - max_prompt..].to_vec()
} else {
tokens
};
let mut generated = Vec::new();
let logits = backend.forward(&tokens, 0)?;
let mut next_token = tasks::generate::sample_token(&logits, params)?;
for _ in 0..params.max_tokens {
if eos.map_or(false, |id| next_token == id)
|| eos_alt.map_or(false, |id| next_token == id)
{
break;
}
generated.push(next_token);
let delta = backend.decode(&[next_token])?;
if !delta.is_empty()
&& tx
.send(stream::StreamEvent::TextDelta(delta))
.await
.is_err()
{
return Ok(());
}
if !params.stop.is_empty() {
let text_so_far = backend.decode(&generated)?;
if params.stop.iter().any(|s| text_so_far.contains(s)) {
break;
}
}
let pos = tokens.len() + generated.len() - 1;
let logits = backend.forward(&[next_token], pos)?;
next_token = tasks::generate::sample_token(&logits, params)?;
}
let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
let _ = tx
.send(stream::StreamEvent::Done {
text,
tool_calls: vec![],
})
.await;
Ok(())
}
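/// MLX counterpart of `stream_local_candle`, run on a blocking thread (hence
/// `blocking_send`). MLX panics during prefill or decode invalidate the
/// cached backend so the next request reloads it cleanly.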
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn stream_local_mlx(
handle: backend_cache::CachedBackend<backend::MlxBackend>,
cache: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
model_id: String,
req: GenerateRequest,
tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
) -> Result<(), InferenceError> {
let mut guard = handle.lock().map_err(|_| {
InferenceError::InferenceFailed(format!("MLX backend mutex poisoned for {model_id}"))
})?;
let backend: &mut backend::MlxBackend = &mut *guard;
backend.clear_kv_cache();
let formatted = tasks::generate::apply_chat_template(
&req.prompt,
req.context.as_deref(),
req.params.thinking,
);
let tokens = backend.encode(&formatted)?;
let eos = backend.eos_token_id();
let eos_alt = backend.token_id("<|im_end|>");
let params = &req.params;
if tokens.is_empty() {
let _ = tx.blocking_send(stream::StreamEvent::Done {
text: String::new(),
tool_calls: vec![],
});
return Ok(());
}
let max_ctx = backend.context_length();
let headroom = params.max_tokens.min(max_ctx / 4);
let max_prompt = max_ctx.saturating_sub(headroom);
let tokens = if tokens.len() > max_prompt {
tokens[tokens.len() - max_prompt..].to_vec()
} else {
tokens
};
let mut generated = Vec::new();
let logits = match Self::catch_mlx("stream prefill", || backend.forward(&tokens, 0)) {
Ok(v) => v,
Err(e) => {
cache.invalidate(&model_id);
return Err(e);
}
};
let mut next_token = Self::sample_from_logits(&logits, params)?;
for _ in 0..params.max_tokens {
if eos.map_or(false, |id| next_token == id)
|| eos_alt.map_or(false, |id| next_token == id)
{
break;
}
generated.push(next_token);
let delta = backend.decode(&[next_token])?;
if !delta.is_empty()
&& tx
.blocking_send(stream::StreamEvent::TextDelta(delta))
.is_err()
{
return Ok(());
}
if !params.stop.is_empty() {
let text_so_far = backend.decode(&generated)?;
if params.stop.iter().any(|s| text_so_far.contains(s)) {
break;
}
}
let pos = tokens.len() + generated.len() - 1;
let logits =
match Self::catch_mlx("stream forward", || backend.forward(&[next_token], pos)) {
Ok(v) => v,
Err(e) => {
cache.invalidate(&model_id);
return Err(e);
}
};
next_token = Self::sample_from_logits(&logits, params)?;
}
let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
let _ = tx.blocking_send(stream::StreamEvent::Done {
text,
tool_calls: vec![],
});
Ok(())
}
pub async fn route_context_snapshot(
&self,
prompt: &str,
workload: RoutingWorkload,
has_tools: bool,
has_vision: bool,
) -> AdaptiveRoutingDecision {
let tracker = self.outcome_tracker.read().await;
self.adaptive_router.route_context_aware(
prompt,
0,
&self.unified_registry,
&tracker,
has_tools,
has_vision,
workload,
)
}
pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
Ok(self.generate_tracked(req).await?.text)
}
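/// Tokenizes `text` with the local model's own tokenizer (useful for
/// round-trip and context-budget checks). Remote models are rejected by
/// `assert_local_for_tokenize`, since provider tokenizers are not exposed.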
pub async fn tokenize(&self, model: &str, text: &str) -> Result<Vec<u32>, InferenceError> {
self.assert_local_for_tokenize(model)?;
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
let schema = self
.unified_registry
.get(model)
.or_else(|| self.unified_registry.find_by_name(model))
.ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
.clone();
let handle = self.ensure_mlx_backend(&schema).await?;
let guard = handle.lock().map_err(|_| {
InferenceError::InferenceFailed(format!(
"MLX backend mutex poisoned for {}",
schema.id
))
})?;
return guard.tokenize_raw(text);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
{
self.ensure_backend(model).await?;
let read = self.backend.read().await;
let backend = read.as_ref().ok_or_else(|| {
InferenceError::InferenceFailed(
"candle backend missing after ensure_backend".to_string(),
)
})?;
backend.tokenize_raw(text)
}
}
pub async fn detokenize(&self, model: &str, tokens: &[u32]) -> Result<String, InferenceError> {
self.assert_local_for_tokenize(model)?;
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
let schema = self
.unified_registry
.get(model)
.or_else(|| self.unified_registry.find_by_name(model))
.ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
.clone();
let handle = self.ensure_mlx_backend(&schema).await?;
let guard = handle.lock().map_err(|_| {
InferenceError::InferenceFailed(format!(
"MLX backend mutex poisoned for {}",
schema.id
))
})?;
return guard.detokenize_raw(tokens);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
{
self.ensure_backend(model).await?;
let read = self.backend.read().await;
let backend = read.as_ref().ok_or_else(|| {
InferenceError::InferenceFailed(
"candle backend missing after ensure_backend".to_string(),
)
})?;
backend.detokenize_raw(tokens)
}
}
fn assert_local_for_tokenize(&self, model: &str) -> Result<(), InferenceError> {
if let Some(schema) = self
.unified_registry
.get(model)
.or_else(|| self.unified_registry.find_by_name(model))
{
if !schema.is_local() {
return Err(InferenceError::UnsupportedMode {
mode: "tokenize/detokenize",
backend: "remote",
reason: "remote provider tokenizer is not exposed by the runtime; \
use a local model (Qwen3 GGUF / MLX) for tokenizer-correctness checks",
});
}
}
Ok(())
}
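/// Runs an MLX operation under `catch_unwind`, converting panics (typically
/// Metal exceptions) into `InferenceError` so a bad forward pass fails the
/// request instead of aborting the process.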
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn catch_mlx<F, T>(context: &str, f: F) -> Result<T, InferenceError>
where
F: FnOnce() -> Result<T, InferenceError>,
{
std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)).map_err(|e| {
InferenceError::InferenceFailed(format!("MLX panicked during {context}: {e:?}"))
})?
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_mlx(
&self,
req: GenerateRequest,
model_id: &str,
) -> Result<(String, Option<u64>), InferenceError> {
let start = std::time::Instant::now();
let schema = self
.unified_registry
.get(model_id)
.cloned()
.ok_or_else(|| {
InferenceError::InferenceFailed(format!(
"generate_mlx: unknown schema id {model_id}"
))
})?;
let handle = self.ensure_mlx_backend(&schema).await?;
let mut guard = handle.lock().map_err(|_| {
InferenceError::InferenceFailed(format!("MLX backend mutex poisoned for {model_id}"))
})?;
let backend: &mut backend::MlxBackend = &mut *guard;
backend.clear_kv_cache();
let formatted = tasks::generate::apply_chat_template(
&req.prompt,
req.context.as_deref(),
req.params.thinking,
);
let tokens = backend.encode(&formatted)?;
let eos = backend.eos_token_id();
let eos_alt = backend.token_id("<|im_end|>");
let params = &req.params;
if tokens.is_empty() {
return Ok((String::new(), None));
}
let max_ctx = backend.context_length();
let headroom = params.max_tokens.min(max_ctx / 4);
let max_prompt = max_ctx.saturating_sub(headroom);
let tokens = if tokens.len() > max_prompt {
tokens[tokens.len() - max_prompt..].to_vec()
} else {
tokens
};
let mut generated = Vec::new();
let logits = match Self::catch_mlx("prefill", || backend.forward(&tokens, 0)) {
Ok(v) => v,
Err(e) => {
drop(guard);
self.mlx_backends.invalidate(model_id);
return Err(e);
}
};
let mut next_token = Self::sample_from_logits(&logits, params)?;
let ttft_ms = Some(start.elapsed().as_millis() as u64);
for _ in 0..params.max_tokens {
if eos.map_or(false, |id| next_token == id)
|| eos_alt.map_or(false, |id| next_token == id)
{
break;
}
generated.push(next_token);
if !params.stop.is_empty() {
let text_so_far = backend.decode(&generated)?;
if params.stop.iter().any(|s| text_so_far.contains(s)) {
break;
}
}
let pos = tokens.len() + generated.len() - 1;
let logits = match Self::catch_mlx("forward", || backend.forward(&[next_token], pos)) {
Ok(v) => v,
Err(e) => {
drop(guard);
self.mlx_backends.invalidate(model_id);
return Err(e);
}
};
next_token = Self::sample_from_logits(&logits, params)?;
}
let text = backend.decode(&generated)?;
Ok((
tasks::generate::strip_thinking(&text, params.thinking),
ttft_ms,
))
}
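/// Samples the next token from raw logits. With `temperature <= 0` this is a
/// greedy argmax; otherwise logits are temperature-scaled, softmaxed, and
/// optionally top-p (nucleus) filtered before a single CDF draw.
///
/// Rough worked example: logits `[2.0, 1.0, 0.0]` at temperature 1.0 give
/// probabilities of about `[0.665, 0.245, 0.090]`; with `top_p = 0.9` the
/// third entry is zeroed and the rest renormalize to about `[0.73, 0.27]`.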
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn sample_from_logits(logits: &[f32], params: &GenerateParams) -> Result<u32, InferenceError> {
if params.temperature <= 0.0 {
let (idx, _) = logits
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.ok_or_else(|| InferenceError::InferenceFailed("empty logits".into()))?;
return Ok(idx as u32);
}
let temp = params.temperature as f32;
let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut probs: Vec<f32> = logits
.iter()
.map(|&l| ((l - max_logit) / temp).exp())
.collect();
let sum: f32 = probs.iter().sum();
for p in &mut probs {
*p /= sum;
}
if params.top_p < 1.0 {
let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut cumsum = 0.0;
let mut cutoff_idx = indexed.len();
for (i, &(_, p)) in indexed.iter().enumerate() {
cumsum += p;
if cumsum > params.top_p as f32 {
cutoff_idx = i + 1;
break;
}
}
let allowed: std::collections::HashSet<usize> =
indexed[..cutoff_idx].iter().map(|(i, _)| *i).collect();
for (i, p) in probs.iter_mut().enumerate() {
if !allowed.contains(&i) {
*p = 0.0;
}
}
let sum: f32 = probs.iter().sum();
if sum > 0.0 {
for p in &mut probs {
*p /= sum;
}
}
}
use rand::Rng;
let mut rng = rand::rng();
let r: f32 = rng.random();
let mut cumsum = 0.0;
for (i, &p) in probs.iter().enumerate() {
cumsum += p;
if cumsum >= r {
return Ok(i as u32);
}
}
Ok((probs.len() - 1) as u32)
}
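/// Batch embedding. Query texts go through the instruction-prefixed
/// `embed_query` path, documents through `embed_one`; the backend is the MLX
/// embedding model on Apple Silicon and the Candle embedding backend elsewhere.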
pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
let instruction = req
.instruction
.as_deref()
.unwrap_or("Retrieve relevant memory facts");
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
let model_id = self.ensure_mlx_embedding_backend().await?;
let schema = self
.unified_registry
.get(&model_id)
.cloned()
.ok_or_else(|| {
InferenceError::InferenceFailed(format!("embed: unknown schema id {model_id}"))
})?;
let handle = self.ensure_mlx_backend(&schema).await?;
let mut guard = handle.lock().map_err(|_| {
InferenceError::InferenceFailed(format!(
"MLX embedding backend mutex poisoned for {model_id}"
))
})?;
let backend: &mut backend::MlxBackend = &mut *guard;
let mut results = Vec::with_capacity(req.texts.len());
for text in &req.texts {
let embedding = if req.is_query {
backend.embed_query(text, instruction)?
} else {
backend.embed_one(text)?
};
results.push(embedding);
}
return Ok(results);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
{
self.ensure_embedding_backend().await?;
let mut write = self.embedding_backend.write().await;
let backend = write.as_mut().unwrap();
let mut results = Vec::with_capacity(req.texts.len());
for text in &req.texts {
let embedding = if req.is_query {
backend.embed_query(text, instruction)?
} else {
backend.embed_one(text)?
};
results.push(embedding);
}
Ok(results)
}
}
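/// Pointwise reranking: each document is scored by a short temperature-0
/// generation built from `rerank_prompt` and scored via
/// `score_from_rerank_output`, then results are sorted descending (ties
/// broken by original index) and truncated to `top_n` if requested.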
pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResult, InferenceError> {
if req.documents.is_empty() {
return Ok(RerankResult {
ranked: Vec::new(),
model_used: None,
});
}
let model_name = match req.model.clone() {
Some(m) => m,
None => self
.preferred_model_for_capability(ModelCapability::Rerank)
.map(str::to_string)
.ok_or_else(|| {
InferenceError::InferenceFailed(
"no reranker model available — pull a Qwen3-Reranker model first".into(),
)
})?,
};
let schema = self
.unified_registry
.find_by_name(&model_name)
.or_else(|| self.unified_registry.get(&model_name))
.cloned()
.ok_or_else(|| {
InferenceError::InferenceFailed(format!(
"rerank: unknown reranker model {model_name}"
))
})?;
if !schema.has_capability(ModelCapability::Rerank) {
return Err(InferenceError::InferenceFailed(format!(
"model {} does not declare the Rerank capability",
schema.name
)));
}
let instruction = req.instruction.as_deref().unwrap_or(
"Given a web search query, retrieve relevant passages that answer the query",
);
let mut scored: Vec<RerankedDocument> = Vec::with_capacity(req.documents.len());
for (idx, doc) in req.documents.iter().enumerate() {
let prompt = rerank_prompt(instruction, &req.query, doc);
let gen_req = GenerateRequest {
prompt,
model: Some(schema.id.clone()),
params: tasks::generate::GenerateParams {
temperature: 0.0,
max_tokens: 3,
thinking: tasks::generate::ThinkingMode::Off,
..Default::default()
},
context: None,
tools: None,
images: None,
messages: None,
cache_control: false,
response_format: None,
intent: None,
};
let out = self.generate(gen_req).await?;
let score = score_from_rerank_output(&out, &schema.name);
scored.push(RerankedDocument {
index: idx,
score,
document: doc.clone(),
});
}
scored.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.index.cmp(&b.index))
});
if let Some(n) = req.top_n {
scored.truncate(n);
}
Ok(RerankResult {
ranked: scored,
model_used: Some(schema.name),
})
}
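/// Visual grounding: sends the prompt plus a single image through
/// `generate_tracked` and returns the bounding boxes parsed from the model's
/// text output.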
pub async fn ground(&self, req: GroundRequest) -> Result<GroundResult, InferenceError> {
let model_name = match req.model.clone() {
Some(m) => m,
None => self
.preferred_model_for_capability(ModelCapability::Grounding)
.map(str::to_string)
.ok_or_else(|| {
InferenceError::InferenceFailed(
"no grounding-capable model available — pull a Qwen2.5-VL model first"
.into(),
)
})?,
};
let gen_req = GenerateRequest {
prompt: req.prompt.clone(),
model: Some(model_name),
params: GenerateParams::default(),
context: None,
tools: None,
images: Some(vec![req.image.clone()]),
messages: None,
cache_control: false,
response_format: None,
intent: None,
};
let result = self.generate_tracked(gen_req).await?;
Ok(GroundResult {
boxes: result.bounding_boxes,
raw_text: result.text,
model_used: Some(result.model_used),
})
}
pub async fn classify(
&self,
req: ClassifyRequest,
) -> Result<Vec<ClassifyResult>, InferenceError> {
let model = match req.model.clone().or_else(|| {
self.preferred_model_for_capability(ModelCapability::Classify)
.map(str::to_string)
}) {
Some(m) => m,
None => {
let m = self.router.route_small(&self.registry);
debug!(model = %m, "auto-routed classify request");
m
}
};
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
return self.classify_via_generate(req, &model).await;
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
{
self.ensure_backend(&model).await?;
let mut write = self.backend.write().await;
let backend = write.as_mut().unwrap();
tasks::classify::classify(backend, req).await
}
}
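/// Zero-shot classification via prompting: the model is asked to reply with
/// exactly one label; scores come from exact match (1.0), substring match
/// (0.8), or word overlap (up to 0.5), then are normalized to sum to 1.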
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn classify_via_generate(
&self,
req: ClassifyRequest,
model: &str,
) -> Result<Vec<ClassifyResult>, InferenceError> {
let labels_str = req
.labels
.iter()
.enumerate()
.map(|(i, l)| format!("{}. {}", i + 1, l))
.collect::<Vec<_>>()
.join("\n");
let prompt = format!(
"Classify the following text into one of these categories:\n\
{labels_str}\n\n\
Text: {}\n\n\
Respond with ONLY the category name, nothing else.",
req.text
);
let gen_req = GenerateRequest {
prompt,
model: Some(model.to_string()),
params: tasks::generate::GenerateParams {
temperature: 0.0,
max_tokens: 32,
thinking: tasks::generate::ThinkingMode::Off,
..Default::default()
},
context: None,
tools: None,
images: None,
messages: None,
cache_control: false,
response_format: None,
intent: None,
};
let response = self.generate(gen_req).await?;
let response_lower = response.trim().to_lowercase();
let mut results: Vec<ClassifyResult> = req
.labels
.iter()
.map(|label| {
let label_lower = label.to_lowercase();
let score = if response_lower == label_lower {
1.0
} else if response_lower.contains(&label_lower) {
0.8
} else {
let label_words: Vec<&str> = label_lower.split_whitespace().collect();
let matches = label_words
.iter()
.filter(|w| response_lower.contains(**w))
.count();
if label_words.is_empty() {
0.0
} else {
0.5 * (matches as f64 / label_words.len() as f64)
}
};
ClassifyResult {
label: label.clone(),
score,
}
})
.collect();
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let total: f64 = results.iter().map(|r| r.score).sum();
if total > 0.0 {
for r in &mut results {
r.score /= total;
}
}
Ok(results)
}
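/// Transcribe audio to text. Candidates with the SpeechToText capability are tried in
/// policy order across local MLX and ElevenLabs backends; the first success is
/// returned, otherwise the last error.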
pub async fn transcribe(
&self,
req: TranscribeRequest,
) -> Result<TranscribeResult, InferenceError> {
let candidates =
self.speech_candidates(ModelCapability::SpeechToText, req.model.as_deref())?;
let mut last_error = None;
for schema in candidates {
let result = match &schema.source {
ModelSource::Mlx { .. } => self.transcribe_local_mlx(&schema, &req).await,
ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
self.transcribe_elevenlabs(&schema, &req).await
}
_ => Err(InferenceError::InferenceFailed(format!(
"speech-to-text not implemented for model source: {}",
schema.id
))),
};
match result {
Ok(result) => return Ok(result),
Err(err) => last_error = Some(err),
}
}
Err(last_error.unwrap_or_else(|| {
InferenceError::InferenceFailed("no speech-to-text models available".into())
}))
}
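/// Synthesize speech from text. Candidates with the TextToSpeech capability are tried
/// in policy order across local MLX and ElevenLabs backends; the first success is
/// returned, otherwise the last error.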
pub async fn synthesize(
&self,
req: SynthesizeRequest,
) -> Result<SynthesizeResult, InferenceError> {
let candidates =
self.speech_candidates(ModelCapability::TextToSpeech, req.model.as_deref())?;
let mut last_error = None;
for schema in candidates {
let result = match &schema.source {
ModelSource::Mlx { .. } => self.synthesize_local_mlx(&schema, &req).await,
ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
self.synthesize_elevenlabs(&schema, &req).await
}
_ => Err(InferenceError::InferenceFailed(format!(
"text-to-speech not implemented for model source: {}",
schema.id
))),
};
match result {
Ok(result) => return Ok(result),
Err(err) => last_error = Some(err),
}
}
Err(last_error.unwrap_or_else(|| {
InferenceError::InferenceFailed("no text-to-speech models available".into())
}))
}
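/// Generate an image. On Apple Silicon, `CAR_IMAGE_BACKEND` selects between the
/// external mflux CLI (`external`, or `auto-external` when it is installed) and the
/// native Rust MLX Flux backend (`native`, the default); the native path walks the
/// local image-generation candidates and returns the first success.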
pub async fn generate_image(
&self,
req: GenerateImageRequest,
) -> Result<GenerateImageResult, InferenceError> {
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
use crate::backend::external_flux;
let backend =
std::env::var("CAR_IMAGE_BACKEND").unwrap_or_else(|_| "native".to_string());
let use_external = match backend.as_str() {
"external" => true,
"native" => false,
"auto-external" => external_flux::is_available(),
_ => false,
};
if use_external {
tracing::info!(
"routing image generation to external mflux \
(set CAR_IMAGE_BACKEND=native to use the Rust port)"
);
let mut req = req;
req.model = self.resolve_external_hf_repo(
req.model.as_deref(),
ModelCapability::ImageGeneration,
);
return external_flux::generate_image(&req);
}
tracing::info!("using native Rust MLX Flux backend");
}
let candidates = self
.media_generation_candidates(ModelCapability::ImageGeneration, req.model.as_deref())?;
let mut last_error = None;
for schema in candidates {
let result = match &schema.source {
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
ModelSource::Mlx { .. } => self.generate_image_native_mlx(&schema, &req).await,
_ => Err(InferenceError::InferenceFailed(format!(
"image generation not implemented for model source: {}",
schema.id
))),
};
match result {
Ok(result) => return Ok(result),
Err(err) => last_error = Some(err),
}
}
Err(last_error.unwrap_or_else(|| {
InferenceError::InferenceFailed("no image generation models available".into())
}))
}
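/// Generate `variant_count` image variants sequentially, deriving each variant's seed
/// by wrapping-add of the variant index onto the base seed. A count of one delegates
/// directly to `generate_image`.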
pub async fn generate_image_batch(
&self,
req: GenerateImageRequest,
) -> Result<Vec<GenerateImageResult>, InferenceError> {
let count = req.variant_count.unwrap_or(1).max(1);
if count == 1 {
return self.generate_image(req).await.map(|r| vec![r]);
}
let base_seed = req.seed.unwrap_or(0);
let mut results = Vec::with_capacity(count as usize);
for i in 0..count {
let mut variant_req = req.clone();
variant_req.seed = Some(base_seed.wrapping_add(i as u64));
variant_req.variant_count = Some(1);
results.push(self.generate_image(variant_req).await?);
}
Ok(results)
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_image_native_mlx(
&self,
schema: &ModelSchema,
req: &GenerateImageRequest,
) -> Result<GenerateImageResult, InferenceError> {
let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
let size = backend_cache::estimate_model_size(&model_dir);
let cache = Arc::clone(&self.flux_cache);
let key = schema.id.clone();
let handle = cache.get_or_load(&key, size, || {
backend::mlx_flux::FluxBackend::load(&model_dir)
})?;
let req = req.clone();
tokio::task::spawn_blocking(move || -> Result<GenerateImageResult, InferenceError> {
let mut guard = handle.lock().map_err(|_| {
InferenceError::InferenceFailed("flux backend mutex poisoned".into())
})?;
guard.generate(&req)
})
.await
.map_err(|e| InferenceError::InferenceFailed(format!("flux task join: {e}")))?
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_video_native_mlx(
&self,
schema: &ModelSchema,
req: &GenerateVideoRequest,
) -> Result<GenerateVideoResult, InferenceError> {
let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
let size = backend_cache::estimate_model_size(&model_dir);
let cache = Arc::clone(&self.ltx_cache);
let key = schema.id.clone();
let handle = cache.get_or_load(&key, size, || {
backend::mlx_ltx::LtxBackend::load(&model_dir)
})?;
let req = req.clone();
tokio::task::spawn_blocking(move || -> Result<GenerateVideoResult, InferenceError> {
let mut guard = handle.lock().map_err(|_| {
InferenceError::InferenceFailed("ltx backend mutex poisoned".into())
})?;
guard.generate(&req)
})
.await
.map_err(|e| InferenceError::InferenceFailed(format!("ltx task join: {e}")))?
}
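/// Generate a video. The request is validated first; on Apple Silicon, Wan-family
/// models are routed to the external MLX video helper, while other MLX models use
/// either the external LTX CLI (per `CAR_VIDEO_BACKEND`, and always when audio
/// conditioning is requested) or the native MLX LTX backend. Candidates are tried in
/// order and the first success wins.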
pub async fn generate_video(
&self,
req: GenerateVideoRequest,
) -> Result<GenerateVideoResult, InferenceError> {
if let Err(msg) = req.validate() {
return Err(InferenceError::InferenceFailed(format!(
"invalid GenerateVideoRequest: {}",
msg
)));
}
let requires_audio_conditioning = req.requires_audio_passthrough_opt_in();
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
use crate::backend::external_ltx;
let backend =
std::env::var("CAR_VIDEO_BACKEND").unwrap_or_else(|_| "native".to_string());
let use_external = match backend.as_str() {
"external" => true,
"native" => false,
"auto-external" => external_ltx::is_available(),
_ => false,
};
if use_external {
tracing::info!(
"CAR_VIDEO_BACKEND requested external LTX routing for LTX-family models"
);
} else {
tracing::info!("using family-aware MLX video routing");
}
}
let candidates = self
.media_generation_candidates(ModelCapability::VideoGeneration, req.model.as_deref())?;
let mut last_error = None;
for schema in candidates {
let result = match &schema.source {
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
ModelSource::Mlx { hf_repo, .. } => {
if crate::backend::external_mlx_video::is_wan_family(&schema) {
match self.unified_registry.ensure_local(&schema.id).await {
Ok(model_dir) => {
crate::backend::external_mlx_video::generate_wan_video(
&schema, &model_dir, &req,
)
}
Err(err) => Err(err),
}
} else {
let backend = std::env::var("CAR_VIDEO_BACKEND")
.unwrap_or_else(|_| "native".to_string());
let use_external_ltx = match backend.as_str() {
"external" => true,
"native" => false,
"auto-external" => crate::backend::external_ltx::is_available(),
_ => false,
};
let use_external_ltx = use_external_ltx || requires_audio_conditioning;
if requires_audio_conditioning
&& !crate::backend::external_ltx::is_available()
{
return Err(InferenceError::InferenceFailed(
"audio-reference video conditioning requires the external `ltx-2-mlx a2v` CLI on PATH"
.to_string(),
));
}
if use_external_ltx {
let mut req = req.clone();
req.model = Some(hf_repo.clone());
crate::backend::external_ltx::generate_video(&req)
} else {
self.generate_video_native_mlx(&schema, &req).await
}
}
}
_ => Err(InferenceError::InferenceFailed(format!(
"video generation not implemented for model source: {}",
schema.id
))),
};
match result {
Ok(result) => return Ok(result),
Err(err) => last_error = Some(err),
}
}
Err(last_error.unwrap_or_else(|| {
InferenceError::InferenceFailed("no video generation models available".into())
}))
}
pub fn list_models_unified(&self) -> Vec<ModelInfo> {
self.unified_registry
.list()
.iter()
.map(|m| ModelInfo::from(*m))
.collect()
}
pub fn available_model_upgrades(&self) -> Vec<ModelUpgrade> {
self.unified_registry.available_upgrades()
}
pub fn list_schemas(&self) -> Vec<ModelSchema> {
self.unified_registry.list().into_iter().cloned().collect()
}
pub fn list_models(&self) -> Vec<models::ModelInfo> {
self.registry.list_models()
}
pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
let schema = self
.unified_registry
.find_by_name(name)
.or_else(|| self.unified_registry.get(name))
.ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
self.unified_registry.ensure_local(&schema.id).await
}
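/// Remove a model's local artifacts: delete its directory under the models dir and,
/// for MLX/Local sources, the corresponding Hugging Face hub cache entries.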
pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
let schema = self
.unified_registry
.get(name)
.or_else(|| {
self.unified_registry
.list()
.into_iter()
.find(|schema| schema.name.eq_ignore_ascii_case(name))
})
.or_else(|| self.unified_registry.find_by_name(name))
.ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
let model_dir = self.unified_registry.models_dir().join(&schema.name);
if model_dir.exists() {
std::fs::remove_dir_all(&model_dir)?;
}
match &schema.source {
ModelSource::Mlx { hf_repo, .. } => {
remove_huggingface_repo_cache(hf_repo)?;
}
ModelSource::Local {
hf_repo,
tokenizer_repo,
..
} => {
remove_huggingface_repo_cache(hf_repo)?;
remove_huggingface_repo_cache(tokenizer_repo)?;
}
_ => {}
}
Ok(())
}
pub fn register_model(&mut self, schema: ModelSchema) {
self.unified_registry.register(schema);
}
pub async fn discover_vllm_mlx_models(&mut self) -> usize {
let config = vllm_mlx::VllmMlxConfig::default();
if !config.auto_discover {
return 0;
}
vllm_mlx::discover_and_register(&config, &mut self.unified_registry).await
}
pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
self.outcome_tracker.clone()
}
async fn auto_save_outcomes(&self) {
if let Err(e) = self.save_outcomes().await {
tracing::debug!("auto-save outcomes failed: {}", e);
}
if let Err(e) = self.save_key_pool_stats().await {
tracing::debug!("auto-save key pool stats failed: {}", e);
}
}
pub async fn save_outcomes(&self) -> Result<(), std::io::Error> {
let tracker = self.outcome_tracker.read().await;
let path = self.config.models_dir.join("outcome_profiles.json");
tracker.save_to_file(&path)
}
pub async fn save_key_pool_stats(&self) -> Result<(), std::io::Error> {
let path = self.config.models_dir.join("key_pool_stats.json");
self.remote_backend.key_pool.save_stats(&path).await
}
pub async fn key_pool_stats(
&self,
) -> std::collections::HashMap<String, Vec<key_pool::KeyStats>> {
self.remote_backend.key_pool.all_stats().await
}
pub async fn export_profiles(&self) -> Vec<ModelProfile> {
let tracker = self.outcome_tracker.read().await;
tracker.export_profiles()
}
pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
let mut tracker = self.outcome_tracker.write().await;
tracker.import_profiles(profiles);
}
pub async fn prepare_speech_runtime(&self) -> Result<PathBuf, InferenceError> {
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
Ok(self.config.models_dir.clone())
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
{
Ok(self.ensure_speech_runtime().await?.root)
}
}
pub fn set_speech_policy(&mut self, policy: SpeechPolicy) {
self.speech_policy = policy;
}
pub fn set_routing_config(&mut self, config: RoutingConfig) {
self.adaptive_router.set_config(config);
}
pub async fn install_curated_speech(
&mut self,
) -> Result<Vec<SpeechInstallReport>, InferenceError> {
let _runtime_root = self.prepare_speech_runtime().await?;
let schemas = self.list_schemas();
let mut repos = Vec::new();
for schema in &schemas {
if !schema.is_mlx() || !schema.tags.iter().any(|tag| tag == "speech") {
continue;
}
if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
if !repos.iter().any(|existing: &String| existing == hf_repo) {
repos.push(hf_repo.clone());
}
}
}
let mut installed = Vec::new();
for repo in repos {
let (snapshot_path, files_downloaded) = download_hf_repo_snapshot(&repo).await?;
let name = schemas
.iter()
.find(|schema| {
matches!(&schema.source, ModelSource::Mlx { hf_repo, .. } if hf_repo == &repo)
})
.map(|schema| schema.name.clone())
.unwrap_or_else(|| repo.clone());
installed.push(SpeechInstallReport {
name,
hf_repo: repo,
snapshot_path,
files_downloaded,
});
}
self.unified_registry.refresh_availability();
Ok(installed)
}
pub fn speech_health(&self) -> SpeechHealthReport {
let local_stt_default =
self.speech_health_default_name(ModelCapability::SpeechToText, true, false);
let local_tts_default =
self.speech_health_default_name(ModelCapability::TextToSpeech, true, false);
let remote_stt_default =
self.speech_health_default_name(ModelCapability::SpeechToText, false, true);
let remote_tts_default =
self.speech_health_default_name(ModelCapability::TextToSpeech, false, true);
let mut local_models = Vec::new();
let mut remote_models = Vec::new();
for schema in self.list_schemas() {
let capability = if schema.has_capability(ModelCapability::SpeechToText) {
Some(ModelCapability::SpeechToText)
} else if schema.has_capability(ModelCapability::TextToSpeech) {
Some(ModelCapability::TextToSpeech)
} else {
None
};
let Some(capability) = capability else {
continue;
};
let selected_by_default = local_stt_default
.as_ref()
.is_some_and(|name| name == &schema.name)
|| local_tts_default
.as_ref()
.is_some_and(|name| name == &schema.name)
|| remote_stt_default
.as_ref()
.is_some_and(|name| name == &schema.name)
|| remote_tts_default
.as_ref()
.is_some_and(|name| name == &schema.name);
let health = SpeechModelHealth {
id: schema.id.clone(),
name: schema.name.clone(),
provider: schema.provider.clone(),
capability,
is_local: schema.is_local(),
available: schema.available,
cached: speech_model_cached(&schema),
selected_by_default,
source: speech_model_source_label(&schema),
};
if schema.is_local() {
local_models.push(health);
} else {
remote_models.push(health);
}
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
let runtime = SpeechRuntimeHealth {
root: self.config.models_dir.clone(),
installed: true,
python: PathBuf::new(),
stt_command: PathBuf::new(),
tts_command: PathBuf::new(),
configured_python: None,
detected_python: None,
};
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
let runtime = {
let rt =
SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
SpeechRuntimeHealth {
root: rt.root.clone(),
installed: rt.is_ready(),
python: rt.python.clone(),
stt_command: rt.stt_program.clone(),
tts_command: rt.tts_program.clone(),
configured_python: std::env::var("CAR_SPEECH_PYTHON")
.ok()
.filter(|value| !value.trim().is_empty()),
detected_python: detect_speech_python(),
}
};
SpeechHealthReport {
runtime,
local_models,
remote_models,
elevenlabs_configured: car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY")
.is_some(),
prefer_local: self.speech_policy.prefer_local,
allow_remote_fallback: self.speech_policy.allow_remote_fallback,
preferred_local_stt: self.speech_policy.preferred_local_stt.clone(),
preferred_local_tts: self.speech_policy.preferred_local_tts.clone(),
preferred_remote_stt: self.speech_policy.preferred_remote_stt.clone(),
preferred_remote_tts: self.speech_policy.preferred_remote_tts.clone(),
local_stt_default,
local_tts_default,
remote_stt_default,
remote_tts_default,
}
}
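/// Build an aggregate health report: registry totals, per-provider and per-capability
/// availability, configured defaults, adaptive-routing settings and scenario
/// decisions, benchmark priors, and the speech health report.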
pub async fn model_health(&self) -> ModelHealthReport {
let schemas = self.list_schemas();
let total_models = schemas.len();
let available_models = schemas.iter().filter(|schema| schema.available).count();
let local_models = schemas.iter().filter(|schema| schema.is_local()).count();
let remote_models = total_models.saturating_sub(local_models);
let defaults = vec![
self.model_default_health(
ModelCapability::Generate,
self.preferred_model_for_capability(ModelCapability::Generate)
.unwrap_or(&self.config.generation_model),
),
self.model_default_health(
ModelCapability::Embed,
self.preferred_model_for_capability(ModelCapability::Embed)
.unwrap_or(&self.config.embedding_model),
),
self.model_default_health(
ModelCapability::Classify,
self.preferred_model_for_capability(ModelCapability::Classify)
.unwrap_or(&self.config.classification_model),
),
];
let mut providers = std::collections::BTreeMap::new();
for schema in &schemas {
let entry =
providers
.entry(schema.provider.clone())
.or_insert_with(|| ProviderAccumulator {
configured: false,
local_models: 0,
remote_models: 0,
available_models: 0,
capabilities: std::collections::HashSet::new(),
});
entry.configured |= model_source_configured(schema);
if schema.is_local() {
entry.local_models += 1;
} else {
entry.remote_models += 1;
}
if schema.available {
entry.available_models += 1;
}
for capability in &schema.capabilities {
entry.capabilities.insert(*capability);
}
}
let providers = providers
.into_iter()
.map(|(provider, acc)| ModelProviderHealth {
provider,
configured: acc.configured,
local_models: acc.local_models,
remote_models: acc.remote_models,
available_models: acc.available_models,
capabilities: sort_capabilities(acc.capabilities.into_iter().collect()),
})
.collect();
let capabilities = all_model_capabilities()
.into_iter()
.map(|capability| {
let relevant: Vec<&ModelSchema> = schemas
.iter()
.filter(|schema| schema.has_capability(capability))
.collect();
let available: Vec<&ModelSchema> = relevant
.iter()
.copied()
.filter(|schema| schema.available)
.collect();
ModelCapabilityHealth {
capability,
total_models: relevant.len(),
available_models: available.len(),
local_available_models: available
.iter()
.filter(|schema| schema.is_local())
.count(),
remote_available_models: available
.iter()
.filter(|schema| !schema.is_local())
.count(),
}
})
.collect();
let routing = self.routing_scenarios().await;
let routing_config = self.adaptive_router.config().clone();
let benchmark_priors = load_benchmark_prior_health(&self.config.models_dir, &schemas);
ModelHealthReport {
total_models,
available_models,
local_models,
remote_models,
defaults,
providers,
capabilities,
routing_prefer_local: routing_config.prefer_local,
routing_quality_first_cold_start: routing_config.quality_first_cold_start,
routing_min_observations: routing_config.min_observations,
routing_bootstrap_min_task_observations: routing_config.bootstrap_min_task_observations,
routing_bootstrap_quality_floor: routing_config.bootstrap_quality_floor,
routing_quality_weight: routing_config.quality_weight,
routing_latency_weight: routing_config.latency_weight,
routing_cost_weight: routing_config.cost_weight,
routing_scenarios: routing,
benchmark_priors,
speech: self.speech_health(),
}
}
async fn routing_scenarios(&self) -> Vec<RoutingScenarioHealth> {
let tracker = self.outcome_tracker.read().await;
let config = self.adaptive_router.config().clone();
let scenarios = [
(
"interactive_text",
"Summarize the benefits of local-first AI routing in two sentences.",
"text",
RoutingWorkload::Interactive,
false,
false,
),
(
"background_code",
"Write a Python function named fibonacci(n) that returns the nth Fibonacci number.",
"code",
RoutingWorkload::Background,
false,
false,
),
(
"interactive_tool_use",
"Use the provided weather tool to get the weather for Boston.",
"tool_use",
RoutingWorkload::Interactive,
true,
false,
),
(
"interactive_vision",
"What is in this image? Answer in one word.",
"vision",
RoutingWorkload::Interactive,
false,
true,
),
];
scenarios
.into_iter()
.map(
|(name, prompt, task_family, workload, has_tools, has_vision)| {
let decision = self.adaptive_router.route_context_aware(
prompt,
0,
&self.unified_registry,
&tracker,
has_tools,
has_vision,
workload,
);
let quality_first_cold_start = if has_tools || has_vision {
config.quality_first_cold_start
} else if task_family == "code"
&& matches!(workload, RoutingWorkload::Background)
{
false
} else {
config.quality_first_cold_start
};
RoutingScenarioHealth {
name: name.to_string(),
task_family: task_family.to_string(),
workload,
has_tools,
has_vision,
prefer_local: if task_family == "speech" {
self.speech_policy.prefer_local
} else {
config.prefer_local
},
quality_first_cold_start,
bootstrap_min_task_observations: config.bootstrap_min_task_observations,
bootstrap_quality_floor: config.bootstrap_quality_floor,
model_id: decision.model_id,
model_name: decision.model_name,
reason: decision.reason,
strategy: decision.strategy,
}
},
)
.collect()
}
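/// Run a text-to-speech then speech-to-text round trip over the local and/or remote
/// speech paths using the preferred models for each; paths not requested are recorded
/// as skipped.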
pub async fn smoke_test_speech(
&self,
local: bool,
remote: bool,
) -> Result<SpeechSmokeReport, InferenceError> {
let mut report = SpeechSmokeReport::default();
if local {
let tts = self
.preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
.ok_or_else(|| {
InferenceError::InferenceFailed(
"no local text-to-speech model available".into(),
)
})?;
let stt = self
.preferred_speech_schema(ModelCapability::SpeechToText, true, false)
.ok_or_else(|| {
InferenceError::InferenceFailed(
"no local speech-to-text model available".into(),
)
})?;
report.local = Some(
self.run_speech_smoke_path("local", &tts, &stt, "Testing CAR local speech path.")
.await?,
);
} else {
report.skipped.push("local".to_string());
}
if remote {
let tts = self
.preferred_speech_schema(ModelCapability::TextToSpeech, false, true)
.ok_or_else(|| {
InferenceError::InferenceFailed(
"no remote text-to-speech model available".into(),
)
})?;
let stt = self
.preferred_speech_schema(ModelCapability::SpeechToText, false, true)
.ok_or_else(|| {
InferenceError::InferenceFailed(
"no remote speech-to-text model available".into(),
)
})?;
report.remote = Some(
self.run_speech_smoke_path("remote", &tts, &stt, "Testing CAR remote speech path.")
.await?,
);
} else {
report.skipped.push("remote".to_string());
}
Ok(report)
}
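/// Resolve speech model candidates for a capability. An explicit model must declare
/// the capability; otherwise matching schemas are sorted by the speech policy, and
/// remote models are dropped when remote fallback is disabled and a local candidate
/// exists.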
fn speech_candidates(
&self,
capability: ModelCapability,
explicit: Option<&str>,
) -> Result<Vec<ModelSchema>, InferenceError> {
if let Some(model) = explicit {
let schema = self
.unified_registry
.get(model)
.or_else(|| self.unified_registry.find_by_name(model))
.cloned()
.ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
if !schema.has_capability(capability) {
return Err(InferenceError::InferenceFailed(format!(
"model {} does not support {:?}",
schema.name, capability
)));
}
return Ok(vec![schema]);
}
let mut candidates: Vec<ModelSchema> = self
.unified_registry
.query(&ModelFilter {
capabilities: vec![capability],
..Default::default()
})
.into_iter()
.cloned()
.collect();
if candidates.is_empty() {
return Err(InferenceError::InferenceFailed(format!(
"no models registered for capability {:?}",
capability
)));
}
candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
if !self.speech_policy.allow_remote_fallback
&& candidates.iter().any(|model| model.is_local())
{
candidates.retain(|model| model.is_local());
}
Ok(candidates)
}
fn resolve_external_hf_repo(
&self,
explicit: Option<&str>,
capability: ModelCapability,
) -> Option<String> {
let id = explicit?;
let schema = self
.unified_registry
.get(id)
.or_else(|| self.unified_registry.find_by_name(id))?;
if !schema.has_capability(capability) {
return Some(id.to_string());
}
if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
return Some(hf_repo.clone());
}
Some(id.to_string())
}
fn media_generation_candidates(
&self,
capability: ModelCapability,
explicit: Option<&str>,
) -> Result<Vec<ModelSchema>, InferenceError> {
if let Some(model) = explicit {
let schema = self
.unified_registry
.get(model)
.or_else(|| self.unified_registry.find_by_name(model))
.cloned()
.ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
if !schema.has_capability(capability) {
return Err(InferenceError::InferenceFailed(format!(
"model {} does not support {:?}",
schema.name, capability
)));
}
return Ok(vec![schema]);
}
let mut candidates: Vec<ModelSchema> = self
.unified_registry
.query(&ModelFilter {
capabilities: vec![capability],
local_only: true,
..Default::default()
})
.into_iter()
.cloned()
.collect();
candidates.sort_by_key(|schema| (!schema.available, schema.size_mb()));
if candidates.is_empty() {
return Err(InferenceError::InferenceFailed(format!(
"no models registered for capability {:?}",
capability
)));
}
Ok(candidates)
}
fn preferred_speech_schema(
&self,
capability: ModelCapability,
local_only: bool,
remote_only: bool,
) -> Option<ModelSchema> {
let available_only = remote_only;
let mut candidates: Vec<ModelSchema> = self
.unified_registry
.query(&ModelFilter {
capabilities: vec![capability],
available_only,
..Default::default()
})
.into_iter()
.filter(|schema| {
(!local_only || schema.is_local()) && (!remote_only || schema.is_remote())
})
.cloned()
.collect();
candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
candidates.into_iter().next()
}
fn speech_health_default_name(
&self,
capability: ModelCapability,
local_only: bool,
remote_only: bool,
) -> Option<String> {
let preferred = match capability {
ModelCapability::SpeechToText if local_only => {
self.speech_policy.preferred_local_stt.as_ref()
}
ModelCapability::SpeechToText if remote_only => {
self.speech_policy.preferred_remote_stt.as_ref()
}
ModelCapability::TextToSpeech if local_only => {
self.speech_policy.preferred_local_tts.as_ref()
}
ModelCapability::TextToSpeech if remote_only => {
self.speech_policy.preferred_remote_tts.as_ref()
}
_ => None,
};
preferred
.filter(|name| {
self.unified_registry.list().iter().any(|schema| {
schema.name == **name
&& schema.has_capability(capability)
&& (!local_only || schema.is_local())
&& (!remote_only || schema.is_remote())
})
})
.cloned()
.or_else(|| {
self.preferred_speech_schema(capability, local_only, remote_only)
.map(|schema| schema.name)
})
}
fn model_default_health(
&self,
capability: ModelCapability,
configured_model: &str,
) -> ModelDefaultHealth {
let schema = self
.unified_registry
.find_by_name(configured_model)
.or_else(|| self.unified_registry.get(configured_model));
ModelDefaultHealth {
capability,
configured_model: configured_model.to_string(),
available: schema.is_some_and(|model| model.available),
is_local: schema.is_some_and(ModelSchema::is_local),
provider: schema.map(|model| model.provider.clone()),
}
}
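/// Sort key for speech candidates; lower tuples win. Orders by policy locality
/// preference, availability, policy-preferred name, a curated per-capability ranking,
/// then p50 latency and model size.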
fn speech_sort_key(
&self,
capability: ModelCapability,
model: &ModelSchema,
) -> (u8, u8, u8, u8, u64, u64) {
let policy_preference = match capability {
ModelCapability::SpeechToText if model.is_local() => {
self.speech_policy.preferred_local_stt.as_ref()
}
ModelCapability::SpeechToText => self.speech_policy.preferred_remote_stt.as_ref(),
ModelCapability::TextToSpeech if model.is_local() => {
self.speech_policy.preferred_local_tts.as_ref()
}
ModelCapability::TextToSpeech => self.speech_policy.preferred_remote_tts.as_ref(),
_ => None,
};
let local_rank = if self.speech_policy.prefer_local {
if model.is_local() {
0
} else {
1
}
} else if model.is_remote() {
0
} else {
1
};
let availability_rank = if model.available {
0
} else if model.is_local() {
1
} else {
2
};
let policy_rank: u8 = if policy_preference.is_some_and(|preferred| preferred == &model.name)
{
0
} else {
1
};
let speech_rank = match capability {
ModelCapability::TextToSpeech => {
if model.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
0
} else if model.name == "Kokoro-82M-bf16" {
1
} else if model.name == "Kokoro-82M-6bit" {
2
} else {
3
}
}
ModelCapability::SpeechToText => {
if model.name == "Parakeet-TDT-0.6B-v3-MLX" {
0
} else {
1
}
}
_ => 0,
};
let latency_rank = model.performance.latency_p50_ms.unwrap_or(u64::MAX);
let size_rank = model.cost.size_mb.unwrap_or(u64::MAX);
(
local_rank,
availability_rank,
policy_rank,
speech_rank,
latency_rank,
size_rank,
)
}
async fn run_speech_smoke_path(
&self,
path: &str,
tts: &ModelSchema,
stt: &ModelSchema,
text: &str,
) -> Result<SpeechSmokePathReport, InferenceError> {
let work_dir = temp_work_dir(&format!("speech-smoke-{path}"))?;
let audio_path = work_dir.join(format!("{path}.wav"));
let synth = self
.synthesize(SynthesizeRequest {
text: text.to_string(),
model: Some(tts.name.clone()),
voice: default_speech_voice(tts),
language: Some("en".to_string()),
output_path: Some(audio_path.display().to_string()),
..SynthesizeRequest::default()
})
.await?;
let transcript = self
.transcribe(TranscribeRequest {
audio_path: synth.audio_path.clone(),
model: Some(stt.name.clone()),
language: Some("en".to_string()),
prompt: None,
timestamps: false,
})
.await?;
Ok(SpeechSmokePathReport {
path: path.to_string(),
tts_model: synth.model_used.unwrap_or_else(|| tts.name.clone()),
stt_model: transcript.model_used.unwrap_or_else(|| stt.name.clone()),
audio_path: PathBuf::from(synth.audio_path),
transcript: transcript.text,
})
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn ensure_speech_runtime(&self) -> Result<SpeechRuntime, InferenceError> {
let mut guard = self.speech_runtime.lock().await;
if let Some(runtime) = guard.as_ref() {
if runtime.is_ready() {
return Ok(runtime.clone());
}
}
let runtime =
SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
if !runtime.is_ready() {
bootstrap_speech_runtime(&runtime).await?;
}
if !runtime.is_ready() {
return Err(InferenceError::InferenceFailed(format!(
"managed speech runtime is not ready at {}",
runtime.root.display()
)));
}
*guard = Some(runtime.clone());
Ok(runtime)
}
async fn transcribe_local_mlx(
&self,
schema: &ModelSchema,
req: &TranscribeRequest,
) -> Result<TranscribeResult, InferenceError> {
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
let parakeet = backend::mlx_parakeet::ParakeetBackend::load(&model_dir)?;
let (text, words) = if req.timestamps {
parakeet
.transcribe_detailed(Path::new(&req.audio_path))
.map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?
} else {
let t = parakeet
.transcribe(Path::new(&req.audio_path))
.map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?;
(t, Vec::new())
};
return Ok(TranscribeResult {
text,
model_used: Some(schema.name.clone()),
language: req.language.clone(),
words,
});
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
{
let runtime = self.ensure_speech_runtime().await?;
let hf_repo = match &schema.source {
ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
_ => unreachable!(),
};
let output_dir = temp_work_dir("stt")?;
let output_prefix = output_dir.join("transcript");
let mut args = vec![
"--model".to_string(),
hf_repo,
"--audio".to_string(),
req.audio_path.clone(),
"--output-path".to_string(),
output_prefix.display().to_string(),
"--format".to_string(),
"json".to_string(),
];
if let Some(language) = &req.language {
args.push("--language".to_string());
args.push(normalize_lang_code(language));
}
if let Some(prompt) = &req.prompt {
args.push("--context".to_string());
args.push(prompt.clone());
}
if req.timestamps {
args.push("--verbose".to_string());
}
let output = run_mlx_audio_command(&runtime, "stt.generate", &args).await?;
let text = read_transcription_result(&output_prefix)?
.or_else(|| extract_text_from_payload(&output.stdout))
.ok_or_else(|| {
InferenceError::InferenceFailed(format!(
"mlx-audio transcription returned no text: {}",
output.stderr
))
})?;
Ok(TranscribeResult {
text,
model_used: Some(schema.name.clone()),
language: req.language.clone(),
words: Vec::new(),
})
}
}
async fn synthesize_local_mlx(
&self,
schema: &ModelSchema,
req: &SynthesizeRequest,
) -> Result<SynthesizeResult, InferenceError> {
let requested = req.requested_advanced_controls();
let repo_supports_advanced = match &schema.source {
ModelSource::Mlx { hf_repo, .. } => hf_repo.to_ascii_lowercase().contains("qwen3-tts"),
_ => false,
};
if !requested.is_empty() && !repo_supports_advanced {
if req.strict_capabilities {
return Err(InferenceError::InferenceFailed(format!(
"model {name} does not support Qwen3-TTS advanced controls {requested:?}; \
route to a Qwen3-TTS model or set strict_capabilities = false to degrade",
name = schema.name,
)));
}
tracing::warn!(
model = %schema.name,
fields = ?requested,
"Qwen3-TTS advanced controls set on non-Qwen3-TTS backend — ignored \
(set strict_capabilities=true to error instead)"
);
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
{
if repo_supports_advanced && !requested.is_empty() {
if req.strict_capabilities {
return Err(InferenceError::InferenceFailed(format!(
"native MLX TTS backend does not yet implement Qwen3-TTS advanced \
controls {requested:?}; run on non-Apple-Silicon to use the Python \
mlx-audio fallback, or set strict_capabilities = false"
)));
}
tracing::warn!(
model = %schema.name,
fields = ?requested,
"Qwen3-TTS advanced controls are not yet implemented in the native MLX TTS \
backend; synthesizing without cloning/voice-design"
);
}
let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
let size = backend_cache::estimate_model_size(&model_dir);
let cache = Arc::clone(&self.kokoro_cache);
let key = schema.id.clone();
let handle = cache.get_or_load(&key, size, || {
backend::mlx_kokoro::KokoroBackend::load(&model_dir)
})?;
let output_path = req.output_path.clone().unwrap_or_else(|| {
let dir = std::env::temp_dir().join("car_tts");
let _ = std::fs::create_dir_all(&dir);
dir.join("output.wav").display().to_string()
});
let voice = req.voice.as_deref().unwrap_or("af_heart").to_string();
let text = req.text.clone();
let op = tokio::task::spawn_blocking(move || -> Result<PathBuf, InferenceError> {
let mut guard = handle.lock().map_err(|_| {
InferenceError::InferenceFailed("kokoro backend mutex poisoned".into())
})?;
guard
.synthesize(&text, Some(&voice), Path::new(&output_path))
.map_err(|e| InferenceError::InferenceFailed(format!("native TTS: {e}")))
})
.await
.map_err(|e| InferenceError::InferenceFailed(format!("kokoro task join: {e}")))??;
let final_path =
materialize_audio_output(&op, req.output_path.as_deref(), &req.format)?;
return Ok(SynthesizeResult {
audio_path: final_path.display().to_string(),
media_type: media_type_for_format(&req.format),
model_used: Some(schema.name.clone()),
voice_used: req.voice.clone(),
});
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
{
let runtime = self.ensure_speech_runtime().await?;
let primary_hf_repo = match &schema.source {
ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
_ => unreachable!(),
};
let (produced, model_used) = match self
.synthesize_local_mlx_repo(&runtime, &primary_hf_repo, schema.name.as_str(), req)
.await
{
Ok(result) => result,
Err(primary_err)
if primary_hf_repo == "mlx-community/Kokoro-82M-6bit"
&& kokoro_runtime_fallback_enabled() =>
{
let fallback_repo = "mlx-community/Kokoro-82M-bf16";
let fallback_name = "Kokoro-82M-bf16";
match self
.synthesize_local_mlx_repo(&runtime, fallback_repo, fallback_name, req)
.await
{
Ok(result) => result,
Err(fallback_err) => {
return Err(InferenceError::InferenceFailed(format!(
"{primary_err}; fallback {fallback_name} also failed: {fallback_err}"
)));
}
}
}
Err(err) => return Err(err),
};
let final_path =
materialize_audio_output(&produced, req.output_path.as_deref(), &req.format)?;
Ok(SynthesizeResult {
audio_path: final_path.display().to_string(),
media_type: media_type_for_format(&req.format),
model_used: Some(model_used),
voice_used: req.voice.clone(),
})
}
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn synthesize_local_mlx_repo(
&self,
runtime: &SpeechRuntime,
hf_repo: &str,
model_name: &str,
req: &SynthesizeRequest,
) -> Result<(PathBuf, String), InferenceError> {
let output_dir = temp_work_dir("tts")?;
let mut args = vec![
"--model".to_string(),
hf_repo.to_string(),
"--text".to_string(),
req.text.clone(),
"--output_path".to_string(),
output_dir.display().to_string(),
];
if let Some(voice) = &req.voice {
args.push("--voice".to_string());
args.push(voice.clone());
}
if let Some(speed) = req.speed {
args.push("--speed".to_string());
args.push(speed.to_string());
}
let repo_lower = hf_repo.to_ascii_lowercase();
if repo_lower.contains("kokoro") {
args.push("--lang_code".to_string());
args.push(kokoro_lang_code(req.language.as_deref()).to_string());
} else if let Some(language) = &req.language {
args.push("--lang_code".to_string());
args.push(normalize_lang_code(language));
}
if repo_lower.contains("qwen3-tts") {
if let Some(ref_audio) = &req.reference_audio_path {
args.push("--ref_audio".to_string());
args.push(ref_audio.clone());
}
if let Some(ref_text) = &req.reference_text {
args.push("--ref_text".to_string());
args.push(ref_text.clone());
}
if let Some(instruct) = &req.voice_instruction {
args.push("--instruct".to_string());
args.push(instruct.clone());
}
}
let output = if repo_lower.contains("kokoro") {
let device = std::env::var("CAR_SPEECH_KOKORO_DEVICE")
.or_else(|_| std::env::var("CAR_SPEECH_MLX_DEVICE"))
.unwrap_or_else(|_| "cpu".to_string());
let extra_env = vec![
("MLX_DEVICE".to_string(), device),
("PYTORCH_ENABLE_MPS_FALLBACK".to_string(), "1".to_string()),
];
run_mlx_audio_command_with_env(runtime, "tts.generate", &args, &extra_env).await?
} else {
run_mlx_audio_command(runtime, "tts.generate", &args).await?
};
let produced = find_audio_file(&output_dir)?.ok_or_else(|| {
let hint = if repo_lower.contains("kokoro") {
". Kokoro models may crash on GPU — try CAR_SPEECH_KOKORO_DEVICE=cpu or use the default Qwen3-TTS model"
} else {
""
};
InferenceError::InferenceFailed(format!(
"mlx-audio synthesis produced no audio file: {}{}",
output.stderr, hint
))
})?;
Ok((produced, model_name.to_string()))
}
async fn transcribe_elevenlabs(
&self,
schema: &ModelSchema,
req: &TranscribeRequest,
) -> Result<TranscribeResult, InferenceError> {
let (endpoint, api_key) = elevenlabs_auth(schema)?;
let file_name = Path::new(&req.audio_path)
.file_name()
.and_then(|f| f.to_str())
.unwrap_or("audio.wav")
.to_string();
let audio_bytes = tokio::fs::read(&req.audio_path).await?;
let file_part = Part::bytes(audio_bytes).file_name(file_name);
let mut form = Form::new()
.text("model_id", schema.name.clone())
.part("file", file_part);
if let Some(language) = &req.language {
form = form.text("language_code", language.clone());
}
let resp = self
.remote_backend
.client
.post(format!(
"{}/v1/speech-to-text",
endpoint.trim_end_matches('/')
))
.header("xi-api-key", api_key)
.multipart(form)
.send()
.await
.map_err(|e| {
InferenceError::InferenceFailed(format!("ElevenLabs STT request failed: {e}"))
})?;
let status = resp.status();
let body = resp.text().await.map_err(|e| {
InferenceError::InferenceFailed(format!("read ElevenLabs STT body: {e}"))
})?;
if !status.is_success() {
return Err(InferenceError::InferenceFailed(format!(
"ElevenLabs STT returned {status}: {body}"
)));
}
let payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
InferenceError::InferenceFailed(format!("parse ElevenLabs STT response: {e}"))
})?;
let text = payload
.get("text")
.and_then(|v| v.as_str())
.map(str::to_string)
.ok_or_else(|| {
InferenceError::InferenceFailed("ElevenLabs STT response missing text".into())
})?;
Ok(TranscribeResult {
text,
model_used: Some(schema.name.clone()),
language: payload
.get("language_code")
.and_then(|v| v.as_str())
.map(str::to_string),
words: Vec::new(),
})
}
async fn synthesize_elevenlabs(
&self,
schema: &ModelSchema,
req: &SynthesizeRequest,
) -> Result<SynthesizeResult, InferenceError> {
let requested = req.requested_advanced_controls();
if !requested.is_empty() {
if req.strict_capabilities {
return Err(InferenceError::InferenceFailed(format!(
"ElevenLabs backend does not support Qwen3-TTS advanced controls \
{requested:?}; route to a Qwen3-TTS model or set strict_capabilities = false"
)));
}
tracing::warn!(
model = %schema.name,
fields = ?requested,
"Qwen3-TTS advanced controls ignored by ElevenLabs backend"
);
}
let (endpoint, api_key) = elevenlabs_auth(schema)?;
let voice_id = req
.voice
.clone()
.unwrap_or_else(|| "JBFqnCBsd6RMkjVDRZzb".to_string());
let output_format = elevenlabs_output_format(&req.format);
let url = format!(
"{}/v1/text-to-speech/{}?output_format={}",
endpoint.trim_end_matches('/'),
voice_id,
output_format
);
let mut body = serde_json::json!({
"text": req.text,
"model_id": schema.name,
});
if let Some(language) = &req.language {
body["language_code"] = serde_json::Value::String(language.clone());
}
let resp = self
.remote_backend
.client
.post(url)
.header("xi-api-key", api_key)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| {
InferenceError::InferenceFailed(format!("ElevenLabs TTS request failed: {e}"))
})?;
let status = resp.status();
let audio = resp.bytes().await.map_err(|e| {
InferenceError::InferenceFailed(format!("read ElevenLabs TTS body: {e}"))
})?;
if !status.is_success() {
let err_body = String::from_utf8_lossy(&audio);
return Err(InferenceError::InferenceFailed(format!(
"ElevenLabs TTS returned {status}: {err_body}"
)));
}
let final_path = requested_or_temp_output(req.output_path.as_deref(), &req.format)?;
ensure_parent_dir(&final_path)?;
tokio::fs::write(&final_path, &audio).await?;
Ok(SynthesizeResult {
audio_path: final_path.display().to_string(),
media_type: media_type_for_format(&req.format),
model_used: Some(schema.name.clone()),
voice_used: Some(voice_id),
})
}
}
#[derive(Default)]
struct ProviderAccumulator {
configured: bool,
local_models: usize,
remote_models: usize,
available_models: usize,
capabilities: std::collections::HashSet<ModelCapability>,
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
struct CommandOutput {
stdout: String,
stderr: String,
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
#[derive(Debug, Clone)]
struct SpeechRuntime {
root: PathBuf,
python: PathBuf,
stt_program: PathBuf,
tts_program: PathBuf,
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
impl SpeechRuntime {
fn new(root: PathBuf) -> Self {
let bin_dir = root.join("bin");
Self {
root,
python: bin_dir.join("python"),
stt_program: bin_dir.join("mlx_audio.stt.generate"),
tts_program: bin_dir.join("mlx_audio.tts.generate"),
}
}
fn is_ready(&self) -> bool {
self.python.exists() && self.stt_program.exists() && self.tts_program.exists()
}
fn command_for(&self, subcommand: &str) -> Result<&Path, InferenceError> {
match subcommand {
"stt.generate" => Ok(&self.stt_program),
"tts.generate" => Ok(&self.tts_program),
_ => Err(InferenceError::InferenceFailed(format!(
"unknown speech subcommand: {subcommand}"
))),
}
}
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn run_mlx_audio_command(
runtime: &SpeechRuntime,
subcommand: &str,
args: &[String],
) -> Result<CommandOutput, InferenceError> {
run_mlx_audio_command_with_env(runtime, subcommand, args, &[]).await
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn run_mlx_audio_command_with_env(
runtime: &SpeechRuntime,
subcommand: &str,
args: &[String],
envs: &[(String, String)],
) -> Result<CommandOutput, InferenceError> {
let program = runtime.command_for(subcommand)?;
let mut command = Command::new(program);
command.args(args);
for (key, value) in envs {
command.env(key, value);
}
let output = command
.output()
.await
.map_err(|err| InferenceError::InferenceFailed(format!("{}: {err}", program.display())))?;
if output.status.success() {
Ok(CommandOutput {
stdout: String::from_utf8_lossy(&output.stdout).to_string(),
stderr: String::from_utf8_lossy(&output.stderr).to_string(),
})
} else {
Err(InferenceError::InferenceFailed(format!(
"{} exited with {}: {}",
program.display(),
output.status,
String::from_utf8_lossy(&output.stderr)
)))
}
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn bootstrap_speech_runtime(runtime: &SpeechRuntime) -> Result<(), InferenceError> {
std::fs::create_dir_all(&runtime.root)?;
let python = select_speech_python()?;
run_command(
"uv",
&[
"venv".to_string(),
"--python".to_string(),
python,
runtime.root.display().to_string(),
],
)
.await?;
run_command(
"uv",
&[
"pip".to_string(),
"install".to_string(),
"--python".to_string(),
runtime.python.display().to_string(),
speech_runtime_mlx_audio_spec(),
"misaki[en]".to_string(),
speech_runtime_spacy_model_spec(),
],
)
.await?;
Ok(())
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn run_command(program: &str, args: &[String]) -> Result<(), InferenceError> {
let output = Command::new(program)
.args(args)
.output()
.await
.map_err(|err| InferenceError::InferenceFailed(format!("{program}: {err}")))?;
if output.status.success() {
Ok(())
} else {
Err(InferenceError::InferenceFailed(format!(
"{} exited with {}: {}",
program,
output.status,
String::from_utf8_lossy(&output.stderr)
)))
}
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn select_speech_python() -> Result<String, InferenceError> {
if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
if !path.trim().is_empty() {
return Ok(path);
}
}
for candidate in ["python3.13", "python3.12", "python3.11"] {
if command_in_path(candidate) {
return Ok(candidate.to_string());
}
}
Err(InferenceError::InferenceFailed(
"no supported Python found for managed speech runtime (tried python3.13, python3.12, python3.11)".into(),
))
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn detect_speech_python() -> Option<String> {
if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
if !path.trim().is_empty() {
return Some(path);
}
}
["python3.13", "python3.12", "python3.11"]
.into_iter()
.find(|candidate| command_in_path(candidate))
.map(str::to_string)
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_root_from_models_dir(_models_dir: &Path) -> PathBuf {
if let Ok(path) = std::env::var("CAR_SPEECH_RUNTIME_DIR") {
if !path.trim().is_empty() {
return PathBuf::from(path);
}
}
std::env::var("HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("."))
.join(".car")
.join("speech-runtime")
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn command_in_path(name: &str) -> bool {
std::env::var_os("PATH")
.map(|paths| {
std::env::split_paths(&paths).any(|dir| {
let path = dir.join(name);
path.exists() && path.is_file()
})
})
.unwrap_or(false)
}
fn speech_model_cached(schema: &ModelSchema) -> bool {
match &schema.source {
ModelSource::Mlx { hf_repo, .. } => huggingface_repo_has_snapshot(hf_repo),
ModelSource::Proprietary { auth, .. } => match auth {
ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
ProprietaryAuth::OAuth2Pkce { .. } => false,
},
_ => false,
}
}
fn remove_huggingface_repo_cache(repo_id: &str) -> Result<(), InferenceError> {
let repo_dir = std::env::var("HF_HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| {
std::env::var("HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("."))
.join(".cache")
.join("huggingface")
})
.join("hub")
.join(format!("models--{}", repo_id.replace('/', "--")));
if repo_dir.exists() {
std::fs::remove_dir_all(repo_dir)?;
}
Ok(())
}
fn model_source_configured(schema: &ModelSchema) -> bool {
match &schema.source {
ModelSource::RemoteApi {
api_key_env,
api_key_envs,
..
} => {
std::env::var(api_key_env).is_ok()
|| api_key_envs
.iter()
.any(|env_var| std::env::var(env_var).is_ok())
}
ModelSource::Proprietary { auth, .. } => match auth {
ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
ProprietaryAuth::OAuth2Pkce { .. } => false,
},
ModelSource::VllmMlx { .. } => {
std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available
}
ModelSource::Ollama { .. } => schema.available,
ModelSource::Mlx { .. } | ModelSource::Local { .. } => true,
ModelSource::AppleFoundationModels { .. } => schema.available,
ModelSource::Delegated { .. } => true,
}
}
fn all_model_capabilities() -> [ModelCapability; 13] {
[
ModelCapability::Generate,
ModelCapability::Embed,
ModelCapability::Classify,
ModelCapability::Code,
ModelCapability::Reasoning,
ModelCapability::Summarize,
ModelCapability::ToolUse,
ModelCapability::MultiToolCall,
ModelCapability::Vision,
ModelCapability::SpeechToText,
ModelCapability::TextToSpeech,
ModelCapability::ImageGeneration,
ModelCapability::VideoGeneration,
]
}
fn sort_capabilities(mut capabilities: Vec<ModelCapability>) -> Vec<ModelCapability> {
capabilities.sort_by_key(|capability| {
all_model_capabilities()
.iter()
.position(|candidate| candidate == capability)
.unwrap_or(usize::MAX)
});
capabilities
}
fn speech_model_source_label(schema: &ModelSchema) -> String {
match &schema.source {
ModelSource::Mlx { hf_repo, .. } => format!("mlx:{hf_repo}"),
ModelSource::Proprietary {
provider, endpoint, ..
} => format!("proprietary:{provider}:{endpoint}"),
ModelSource::RemoteApi { endpoint, .. } => format!("remote:{endpoint}"),
ModelSource::Local { hf_repo, .. } => format!("local:{hf_repo}"),
ModelSource::VllmMlx {
endpoint,
model_name,
} => format!("vllm-mlx:{endpoint}:{model_name}"),
ModelSource::Ollama { model_tag, host } => format!("ollama:{host}:{model_tag}"),
ModelSource::AppleFoundationModels { use_case } => {
format!(
"apple-foundation:{}",
use_case.as_deref().unwrap_or("default")
)
}
ModelSource::Delegated { hint } => {
format!("delegated:{}", hint.as_deref().unwrap_or("(none)"))
}
}
}
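/// Build the reranker prompt in the Qwen-style chat template expected by Qwen3
/// rerankers: a yes/no judgment system turn, the instruction/query/document user turn,
/// and an empty think block to suppress reasoning output.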
fn rerank_prompt(instruction: &str, query: &str, document: &str) -> String {
const SYSTEM: &str = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".";
format!(
"<|im_start|>system\n{SYSTEM}<|im_end|>\n\
<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n\
<|im_start|>assistant\n<think>\n\n</think>\n\n"
)
}
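/// Map a reranker generation to a relevance score: the first `yes`/`no` token in the
/// output yields 1.0/0.0; any other output logs a warning and returns a neutral 0.5.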
fn score_from_rerank_output(text: &str, model_name: &str) -> f32 {
let normalized: String = text
.to_ascii_lowercase()
.chars()
.map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
.collect();
for tok in normalized.split_ascii_whitespace().take(5) {
match tok {
"yes" => return 1.0,
"no" => return 0.0,
_ => continue,
}
}
tracing::warn!(
model = %model_name,
output = %text,
"rerank: first tokens contain neither `yes` nor `no`; returning neutral 0.5"
);
0.5
}
fn default_speech_voice(schema: &ModelSchema) -> Option<String> {
if schema.provider == "elevenlabs" {
Some("JBFqnCBsd6RMkjVDRZzb".to_string())
} else if schema.name == "Kokoro-82M-6bit" || schema.name == "Kokoro-82M-bf16" {
Some("af_heart".to_string())
} else if schema.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
Some("Chelsie".to_string())
} else {
None
}
}
fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
find_latest_huggingface_snapshot(repo_id).is_some()
}
fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
let cache_root = std::env::var("HF_HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| {
std::env::var("HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("."))
.join(".cache")
.join("huggingface")
})
.join("hub");
cache_root.join(format!("models--{}", repo_id.replace('/', "--")))
}
fn find_latest_huggingface_snapshot(repo_id: &str) -> Option<PathBuf> {
let snapshots = huggingface_repo_dir(repo_id).join("snapshots");
std::fs::read_dir(snapshots)
.ok()?
.filter_map(Result::ok)
.map(|entry| entry.path())
.find(|path| path.is_dir() && snapshot_looks_ready(path))
}
fn snapshot_looks_ready(path: &Path) -> bool {
if path.join("config.json").exists() || path.join("model_index.json").exists() {
return true;
}
snapshot_contains_ext(path, "safetensors")
}
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
let Ok(entries) = std::fs::read_dir(root) else {
return false;
};
entries.filter_map(Result::ok).any(|entry| {
let path = entry.path();
if path.is_dir() {
snapshot_contains_ext(&path, ext)
} else {
path.extension()
.and_then(|value| value.to_str())
.map(|value| value.eq_ignore_ascii_case(ext))
.unwrap_or(false)
}
})
}
fn count_files_recursive(root: &Path) -> usize {
let Ok(entries) = std::fs::read_dir(root) else {
return 0;
};
entries
.filter_map(Result::ok)
.map(|entry| entry.path())
.map(|path| {
if path.is_dir() {
count_files_recursive(&path)
} else if path.is_file() {
1
} else {
0
}
})
.sum()
}
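/// Download every missing file of a Hugging Face repo into the hub cache snapshot for
/// its current revision, returning the snapshot path and the count of repo files now
/// cached locally.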
async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
let api = hf_hub::api::tokio::ApiBuilder::from_env()
.with_progress(false)
.build()
.map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
let repo = api.model(repo_id.to_string());
let info = repo
.info()
.await
.map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
let snapshot_path = huggingface_repo_dir(repo_id)
.join("snapshots")
.join(&info.sha);
let mut downloaded = 0usize;
for sibling in &info.siblings {
let local_path = snapshot_path.join(&sibling.rfilename);
if local_path.exists() {
downloaded += 1;
continue;
}
repo.download(&sibling.rfilename).await.map_err(|e| {
InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
})?;
downloaded += 1;
}
Ok((snapshot_path, downloaded))
}
fn temp_work_dir(prefix: &str) -> Result<PathBuf, InferenceError> {
let unique = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| InferenceError::InferenceFailed(format!("clock error: {e}")))?
.as_nanos();
let dir = std::env::temp_dir().join(format!("car-inference-{prefix}-{unique}"));
std::fs::create_dir_all(&dir)?;
Ok(dir)
}
fn ensure_parent_dir(path: &Path) -> Result<(), InferenceError> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
Ok(())
}
fn requested_or_temp_output(
output_path: Option<&str>,
format: &str,
) -> Result<PathBuf, InferenceError> {
if let Some(path) = output_path {
return Ok(PathBuf::from(path));
}
let dir = temp_work_dir("audio-out")?;
Ok(dir.join(format!("speech.{format}")))
}
fn requested_or_temp_media_output(
output_path: Option<&str>,
format: &str,
stem: &str,
) -> Result<PathBuf, InferenceError> {
if let Some(path) = output_path {
return Ok(PathBuf::from(path));
}
let dir = temp_work_dir(&format!("{stem}-out"))?;
Ok(dir.join(format!("{stem}.{format}")))
}
fn materialize_audio_output(
produced: &Path,
requested: Option<&str>,
format: &str,
) -> Result<PathBuf, InferenceError> {
let dest = requested_or_temp_output(requested, format)?;
ensure_parent_dir(&dest)?;
std::fs::copy(produced, &dest)?;
Ok(dest)
}
fn materialize_binary_output(
produced: &Path,
requested: Option<&str>,
format: &str,
stem: &str,
) -> Result<PathBuf, InferenceError> {
let dest = requested_or_temp_media_output(requested, format, stem)?;
ensure_parent_dir(&dest)?;
std::fs::copy(produced, &dest)?;
Ok(dest)
}
fn find_generated_file(
root: &Path,
extensions: &[&str],
) -> Result<Option<PathBuf>, InferenceError> {
let entries = std::fs::read_dir(root)?;
let mut candidates: Vec<PathBuf> = entries
.filter_map(Result::ok)
.map(|entry| entry.path())
.filter(|path| {
path.is_file()
&& path
.extension()
.and_then(|ext| ext.to_str())
.map(|ext| {
extensions
.iter()
.any(|candidate| candidate.eq_ignore_ascii_case(ext))
})
.unwrap_or(false)
})
.collect();
candidates.sort();
Ok(candidates.pop())
}
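/// Maps an image format name to its MIME type, defaulting to `image/png`.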
fn media_type_for_image_format(format: &str) -> String {
match format.to_ascii_lowercase().as_str() {
"jpg" | "jpeg" => "image/jpeg".to_string(),
"webp" => "image/webp".to_string(),
_ => "image/png".to_string(),
}
}
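/// Maps a video format name to its MIME type, defaulting to `video/mp4`.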
fn media_type_for_video_format(format: &str) -> String {
match format.to_ascii_lowercase().as_str() {
"mov" => "video/quicktime".to_string(),
"gif" => "image/gif".to_string(),
_ => "video/mp4".to_string(),
}
}
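/// Reads the transcription output written next to `output_prefix` (either
/// `<prefix>.json` or the prefix itself) and extracts the transcript text.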
fn read_transcription_result(output_prefix: &Path) -> Result<Option<String>, InferenceError> {
let candidates = [
output_prefix.with_extension("json"),
output_prefix.to_path_buf(),
];
for path in candidates {
if path.exists() {
let contents = std::fs::read_to_string(path)?;
if let Some(text) = extract_text_from_payload(&contents) {
return Ok(Some(text));
}
}
}
Ok(None)
}
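/// Extracts transcript text from a JSON payload, accepting a top-level
/// `text` field, a `transcripts` array, or an array of segment objects.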
fn extract_text_from_payload(payload: &str) -> Option<String> {
let value: serde_json::Value = serde_json::from_str(payload).ok()?;
if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
return Some(text.to_string());
}
if let Some(transcripts) = value.get("transcripts").and_then(|v| v.as_array()) {
let joined = transcripts
.iter()
.filter_map(|item| item.get("text").and_then(|v| v.as_str()))
.collect::<Vec<_>>()
.join("\n");
if !joined.is_empty() {
return Some(joined);
}
}
if let Some(items) = value.as_array() {
let joined = items
.iter()
.filter_map(|item| {
item.get("text")
.or_else(|| item.get("Content"))
.and_then(|v| v.as_str())
})
.collect::<Vec<_>>()
.join(" ");
if !joined.is_empty() {
return Some(joined);
}
}
None
}
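/// Finds the lexicographically first audio file under `output_dir`.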
fn find_audio_file(output_dir: &Path) -> Result<Option<PathBuf>, InferenceError> {
let mut audio_files = Vec::new();
collect_audio_files(output_dir, &mut audio_files)?;
audio_files.sort();
Ok(audio_files.into_iter().next())
}
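/// Recursively collects wav/mp3/flac/pcm/m4a files under `dir`.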
fn collect_audio_files(dir: &Path, audio_files: &mut Vec<PathBuf>) -> Result<(), InferenceError> {
for entry in std::fs::read_dir(dir)? {
let path = entry?.path();
if path.is_dir() {
collect_audio_files(&path, audio_files)?;
} else if matches!(
path.extension().and_then(|ext| ext.to_str()),
Some("wav" | "mp3" | "flac" | "pcm" | "m4a")
) {
audio_files.push(path);
}
}
Ok(())
}
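/// Maps an audio format name to its MIME type, defaulting to `audio/wav`.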
fn media_type_for_format(format: &str) -> String {
match format.to_ascii_lowercase().as_str() {
"mp3" => "audio/mpeg".to_string(),
"flac" => "audio/flac".to_string(),
"pcm" => "audio/L16".to_string(),
"m4a" => "audio/mp4".to_string(),
_ => "audio/wav".to_string(),
}
}
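/// Maps a language hint to the single-letter voice language code expected by
/// Kokoro, defaulting to American English (`"a"`).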
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_lang_code(language: Option<&str>) -> &'static str {
match language.unwrap_or("en").to_ascii_lowercase().as_str() {
"en-gb" | "british" | "british english" => "b",
"ja" | "japanese" => "j",
"zh" | "zh-cn" | "mandarin" | "chinese" => "z",
"es" | "spanish" => "e",
"fr" | "french" => "f",
_ => "a",
}
}
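/// Normalizes a language name or locale to a two-letter code, falling back
/// to `"en"` for anything unrecognized.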
fn normalize_lang_code(language: &str) -> String {
match language.to_ascii_lowercase().as_str() {
"en" | "english" | "en-us" | "en_us" => "en".to_string(),
"es" | "spanish" => "es".to_string(),
"fr" | "french" => "fr".to_string(),
"ja" | "japanese" => "ja".to_string(),
"zh" | "chinese" | "mandarin" => "zh".to_string(),
_ => "en".to_string(),
}
}
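/// Resolves the endpoint and API key for an ElevenLabs proprietary model,
/// looking the key up in the environment or the secrets keychain.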
fn elevenlabs_auth(schema: &ModelSchema) -> Result<(String, String), InferenceError> {
match &schema.source {
ModelSource::Proprietary {
endpoint,
auth: schema::ProprietaryAuth::ApiKeyEnv { env_var },
..
} => {
let key = car_secrets::resolve_env_or_keychain(env_var).ok_or_else(|| {
InferenceError::InferenceFailed(format!(
"missing API key {env_var}; set the environment variable or \
store it with `car secrets put {env_var}`"
))
})?;
Ok((endpoint.clone(), key))
}
_ => Err(InferenceError::InferenceFailed(format!(
"model {} is not an ElevenLabs proprietary model",
schema.id
))),
}
}
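/// Maps a requested audio format to the ElevenLabs `output_format` value,
/// defaulting to 44.1 kHz WAV.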
fn elevenlabs_output_format(format: &str) -> &'static str {
match format.to_ascii_lowercase().as_str() {
"mp3" => "mp3_44100_128",
"pcm" => "pcm_16000",
_ => "wav_44100",
}
}
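/// Candidate locations for `benchmark_priors.json`: the models directory,
/// its parent, and an explicit `CAR_BENCHMARK_PRIORS_PATH` override.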
fn benchmark_priors_paths(models_dir: &Path) -> Vec<PathBuf> {
let mut paths = vec![models_dir.join("benchmark_priors.json")];
if let Some(parent) = models_dir.parent() {
let parent_path = parent.join("benchmark_priors.json");
if !paths.contains(&parent_path) {
paths.push(parent_path);
}
}
if let Some(path) = std::env::var_os("CAR_BENCHMARK_PRIORS_PATH") {
let path = PathBuf::from(path);
if !paths.contains(&path) {
paths.push(path);
}
}
paths
}
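/// Loads benchmark priors from every candidate path and pairs each entry
/// with its model name from the known schemas, for health reporting.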
fn load_benchmark_prior_health(
models_dir: &Path,
schemas: &[ModelSchema],
) -> Vec<ModelBenchmarkPriorHealth> {
let mut priors = std::collections::BTreeMap::new();
for path in benchmark_priors_paths(models_dir) {
let Ok(loaded) = routing_ext::load_benchmark_priors(&path) else {
continue;
};
for (model_id, prior) in loaded {
let model_name = schemas
.iter()
.find(|schema| schema.id == model_id)
.map(|schema| schema.name.clone());
priors.insert(
model_id.clone(),
ModelBenchmarkPriorHealth {
model_id,
model_name,
overall_score: prior.overall_score,
overall_latency_ms: prior.overall_latency_ms,
task_scores: prior.task_scores,
task_latency_ms: prior.task_latency_ms,
source_path: path.clone(),
},
);
}
}
priors.into_values().collect()
}
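/// Whether the Kokoro runtime fallback is allowed; enabled unless
/// `CAR_SPEECH_KOKORO_FALLBACK` is set to `0`, `false`, or `off`.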
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_runtime_fallback_enabled() -> bool {
std::env::var("CAR_SPEECH_KOKORO_FALLBACK")
.ok()
.map(|value| {
!matches!(
value.trim().to_ascii_lowercase().as_str(),
"0" | "false" | "off"
)
})
.unwrap_or(true)
}
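/// The pip requirement spec for `mlx-audio`, overridable via
/// `CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC`.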
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_mlx_audio_spec() -> String {
std::env::var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC")
.ok()
.filter(|value| !value.trim().is_empty())
.unwrap_or_else(|| "mlx-audio==0.4.2".to_string())
}
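/// The pip requirement spec for the spaCy English model wheel, overridable
/// via `CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC`.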
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_spacy_model_spec() -> String {
std::env::var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC")
.ok()
.filter(|value| !value.trim().is_empty())
.unwrap_or_else(|| {
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl".to_string()
})
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());
fn test_config(models_dir: PathBuf) -> InferenceConfig {
InferenceConfig {
models_dir,
device: None,
generation_model: "Qwen3-0.6B".into(),
preferred_generation_model: None,
embedding_model: "Qwen3-Embedding-0.6B".into(),
preferred_embedding_model: None,
classification_model: "Qwen3-0.6B".into(),
preferred_classification_model: None,
}
}
#[tokio::test]
async fn tokenize_rejects_known_remote_model_with_unsupported_mode() {
let tmp = TempDir::new().unwrap();
let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
let remote_id = engine
.list_schemas()
.into_iter()
.find(|s| !s.is_local())
.map(|s| s.id)
.expect("built-in catalog should include at least one remote model schema");
let err = engine
.tokenize(&remote_id, "hello")
.await
.expect_err("remote tokenize must error");
match err {
InferenceError::UnsupportedMode { mode, backend, .. } => {
assert_eq!(mode, "tokenize/detokenize");
assert_eq!(backend, "remote");
}
other => panic!("expected UnsupportedMode, got {other:?}"),
}
let err = engine
.detokenize(&remote_id, &[1, 2, 3])
.await
.expect_err("remote detokenize must error");
assert!(
matches!(err, InferenceError::UnsupportedMode { .. }),
"expected UnsupportedMode, got {err:?}"
);
}
#[test]
fn engine_loads_benchmark_priors_on_startup() {
let _env = ENV_MUTEX.lock().unwrap();
let tmp = TempDir::new().unwrap();
let priors_path = tmp.path().join("benchmark_priors.json");
std::fs::write(
&priors_path,
serde_json::json!({
"model_id": "qwen/qwen3-8b:q4_k_m",
"overall_score": 0.88
})
.to_string(),
)
.unwrap();
unsafe {
std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
}
let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
let tracker = engine.outcome_tracker.blocking_read();
let profile = tracker
.profile("qwen/qwen3-8b:q4_k_m")
.expect("benchmark prior should create a profile");
assert!((profile.ema_quality - 0.88).abs() < 0.01);
unsafe {
std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
}
}
#[test]
fn benchmark_priors_do_not_override_observed_profiles() {
let _env = ENV_MUTEX.lock().unwrap();
let tmp = TempDir::new().unwrap();
let models_dir = tmp.path().join("models");
std::fs::create_dir_all(&models_dir).unwrap();
let observed = vec![ModelProfile {
model_id: "qwen/qwen3-8b:q4_k_m".into(),
total_calls: 12,
success_count: 3,
fail_count: 9,
total_latency_ms: 1200,
total_input_tokens: 0,
total_output_tokens: 0,
task_stats: std::collections::HashMap::new(),
ema_quality: 0.21,
quality_per_1k_tokens: 0.0,
updated_at: 1,
}];
std::fs::write(
models_dir.join("outcome_profiles.json"),
serde_json::to_string(&observed).unwrap(),
)
.unwrap();
let priors_path = tmp.path().join("benchmark_priors.json");
std::fs::write(
&priors_path,
serde_json::json!({
"model_id": "qwen/qwen3-8b:q4_k_m",
"overall_score": 0.95
})
.to_string(),
)
.unwrap();
unsafe {
std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
}
let engine = InferenceEngine::new(test_config(models_dir));
let tracker = engine.outcome_tracker.blocking_read();
let profile = tracker
.profile("qwen/qwen3-8b:q4_k_m")
.expect("observed profile should remain present");
assert!((profile.ema_quality - 0.21).abs() < 0.01);
assert_eq!(profile.total_calls, 12);
unsafe {
std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
}
}
#[test]
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_package_spec_defaults_and_overrides() {
let _env = ENV_MUTEX.lock().unwrap();
unsafe {
std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
}
assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.2");
unsafe {
std::env::set_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC", "mlx-audio==0.4.1");
}
assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.1");
unsafe {
std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
}
}
#[test]
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_spacy_model_spec_defaults_and_overrides() {
let _env = ENV_MUTEX.lock().unwrap();
unsafe {
std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
}
assert!(
speech_runtime_spacy_model_spec().starts_with("en-core-web-sm @ https://github.com/")
);
unsafe {
std::env::set_var(
"CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC",
"en-core-web-sm==3.8.0",
);
}
assert_eq!(speech_runtime_spacy_model_spec(), "en-core-web-sm==3.8.0");
unsafe {
std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
}
}
#[test]
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_runtime_fallback_defaults_on() {
let _env = ENV_MUTEX.lock().unwrap();
unsafe {
std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
}
assert!(kokoro_runtime_fallback_enabled());
unsafe {
std::env::set_var("CAR_SPEECH_KOKORO_FALLBACK", "false");
}
assert!(!kokoro_runtime_fallback_enabled());
unsafe {
std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
}
}
#[test]
fn preferred_local_tts_wins_over_builtin_rank() {
let tmp = TempDir::new().unwrap();
let mut engine = InferenceEngine::new(test_config(tmp.path().join("models")));
engine.set_speech_policy(SpeechPolicy {
prefer_local: true,
allow_remote_fallback: false,
preferred_local_stt: None,
preferred_local_tts: Some("Kokoro-82M-6bit".into()),
preferred_remote_stt: None,
preferred_remote_tts: None,
});
let schema = engine
.preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
.expect("preferred local TTS should resolve");
assert_eq!(schema.name, "Kokoro-82M-6bit");
}
#[test]
fn preferred_discovered_vllm_mlx_model_wins_generate_routing() {
let tmp = TempDir::new().unwrap();
let mut config = test_config(tmp.path().join("models"));
config.preferred_generation_model =
Some("vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit".into());
let mut engine = InferenceEngine::new(config);
let schema = crate::vllm_mlx::to_model_schema(
&crate::vllm_mlx::DiscoveredModel {
id: "mlx-community/gemma-3n-E2B-it-lm-4bit".into(),
owned_by: Some("mlx-community".into()),
},
"http://127.0.0.1:8001",
);
engine.register_model(schema);
let rt = tokio::runtime::Runtime::new().unwrap();
let decision = rt.block_on(engine.route_adaptive("say hello in one sentence"));
assert_eq!(
decision.model_id,
"vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit"
);
assert_eq!(decision.strategy, RoutingStrategy::Explicit);
assert_eq!(decision.reason, "preferred generation model override");
}
#[test]
fn inference_result_serializes_with_full_shape() {
use crate::tasks::generate::ToolCall;
use std::collections::HashMap;
let mut args = HashMap::new();
args.insert("path".to_string(), serde_json::json!("README.md"));
let result = InferenceResult {
text: String::new(),
bounding_boxes: Vec::new(),
tool_calls: vec![ToolCall {
id: None,
name: "read_file".into(),
arguments: args,
}],
trace_id: "trace-abc".into(),
model_used: "test-model".into(),
latency_ms: 1234,
time_to_first_token_ms: Some(180),
usage: Some(TokenUsage {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
context_window: 8192,
}),
provider_output_items: Vec::new(),
};
let json = serde_json::to_value(&result).expect("serialize");
assert_eq!(json["text"].as_str(), Some(""));
assert_eq!(json["trace_id"].as_str(), Some("trace-abc"));
assert_eq!(json["model_used"].as_str(), Some("test-model"));
assert_eq!(json["latency_ms"].as_u64(), Some(1234));
let tool_calls = json["tool_calls"].as_array().expect("tool_calls array");
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0]["name"].as_str(), Some("read_file"));
assert_eq!(
tool_calls[0]["arguments"]["path"].as_str(),
Some("README.md")
);
let usage = &json["usage"];
assert_eq!(usage["prompt_tokens"].as_u64(), Some(100));
assert_eq!(usage["completion_tokens"].as_u64(), Some(50));
assert_eq!(usage["total_tokens"].as_u64(), Some(150));
assert_eq!(usage["context_window"].as_u64(), Some(8192));
assert_eq!(json["time_to_first_token_ms"].as_u64(), Some(180));
}
#[test]
fn inference_result_top_level_keys_are_locked() {
use std::collections::BTreeSet;
let result = InferenceResult {
text: "anything".into(),
bounding_boxes: Vec::new(),
tool_calls: vec![],
trace_id: "t".into(),
model_used: "m".into(),
latency_ms: 0,
time_to_first_token_ms: None,
usage: None,
provider_output_items: Vec::new(),
};
let json = serde_json::to_value(&result).expect("serialize");
let keys: BTreeSet<&str> = json
.as_object()
.expect("top-level object")
.keys()
.map(String::as_str)
.collect();
let expected: BTreeSet<&str> = [
"text",
"tool_calls",
"trace_id",
"model_used",
"latency_ms",
"time_to_first_token_ms",
"usage",
]
.into_iter()
.collect();
assert_eq!(
keys, expected,
"infer response top-level keys drifted -- update both the test \
and the WebSocket protocol documentation if this is intentional"
);
for key in &keys {
assert!(
!key.chars().any(|c| c.is_uppercase()) && !key.contains('-'),
"key '{}' is not snake_case",
key
);
}
}
#[test]
fn inference_result_serializes_plain_text_response() {
let result = InferenceResult {
text: "hello world".into(),
bounding_boxes: Vec::new(),
tool_calls: vec![],
trace_id: "trace-xyz".into(),
model_used: "test-model".into(),
latency_ms: 42,
time_to_first_token_ms: None,
usage: None,
provider_output_items: Vec::new(),
};
let json = serde_json::to_value(&result).expect("serialize");
assert_eq!(json["text"], "hello world");
assert!(json["tool_calls"].is_array());
assert_eq!(json["tool_calls"].as_array().unwrap().len(), 0);
assert_eq!(json["model_used"], "test-model");
assert!(json["usage"].is_null());
assert!(json["time_to_first_token_ms"].is_null());
}
#[test]
fn generate_request_deserializes_intent_field_from_json_rpc_params() {
use crate::intent::{IntentHint, TaskHint};
use crate::schema::ModelCapability;
let params = serde_json::json!({
"prompt": "summarize this email",
"intent": {
"task": "chat",
"prefer_local": true,
"require": ["tool_use"],
},
});
let req: GenerateRequest =
serde_json::from_value(params).expect("GenerateRequest deserialize");
let intent = req.intent.as_ref().expect("intent field deserialized");
assert_eq!(intent.task, Some(TaskHint::Chat));
assert!(intent.prefer_local);
assert_eq!(intent.require, vec![ModelCapability::ToolUse]);
let back: serde_json::Value =
serde_json::to_value(&req).expect("re-serialize GenerateRequest");
assert_eq!(back["intent"]["task"], "chat");
assert_eq!(back["intent"]["prefer_local"], true);
assert_eq!(back["intent"]["require"][0], "tool_use");
let default_req: GenerateRequest = serde_json::from_value(serde_json::json!({
"prompt": "x",
"intent": {},
}))
.unwrap();
let default_intent = default_req.intent.expect("present but empty");
assert_eq!(default_intent.task, None);
assert!(!default_intent.prefer_local);
assert!(default_intent.require.is_empty());
let no_intent: GenerateRequest =
serde_json::from_value(serde_json::json!({"prompt": "x"})).unwrap();
assert!(no_intent.intent.is_none());
}
#[test]
fn rerank_prompt_matches_upstream_template_shape() {
let p = rerank_prompt(
"retrieve relevant passages",
"who runs the treasury?",
"doc x",
);
assert!(p.contains("<|im_start|>system"));
assert!(p.contains("Note that the answer can only be \"yes\" or \"no\"."));
assert!(p.contains("<|im_start|>user\n<Instruct>: retrieve relevant passages"));
assert!(p.contains("<Query>: who runs the treasury?"));
assert!(p.contains("<Document>: doc x<|im_end|>"));
assert!(p.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
}
#[test]
fn rerank_score_yes_and_no_exactly() {
assert_eq!(score_from_rerank_output("yes", "m"), 1.0);
assert_eq!(score_from_rerank_output("no", "m"), 0.0);
}
#[test]
fn rerank_score_handles_case_leading_space_and_chat_sentinels() {
assert_eq!(score_from_rerank_output(" Yes", "m"), 1.0);
assert_eq!(score_from_rerank_output("\nno.", "m"), 0.0);
assert_eq!(score_from_rerank_output("<|im_end|>yes", "m"), 1.0);
}
#[test]
fn rerank_score_scans_up_to_three_tokens() {
assert_eq!(score_from_rerank_output("_bos_ yes", "m"), 1.0);
}
#[test]
fn rerank_score_unexpected_is_neutral() {
assert_eq!(score_from_rerank_output("maybe", "m"), 0.5);
assert_eq!(score_from_rerank_output("", "m"), 0.5);
}
}