Skip to main content

car_inference/
lib.rs

1//! # car-inference
2//!
3//! Local model inference for the Common Agent Runtime.
4//!
//! Provides on-device inference with automatic hardware detection:
//! - **macOS (Apple Silicon)**: native MLX backends
//! - **Other targets (e.g. Linux)**: Candle on CUDA (NVIDIA GPU) or CPU fallback
8//!
9//! Ships with Qwen3 models downloaded on first use from HuggingFace.
10//! Supports remote API models (OpenAI, Anthropic, Google) via the same schema.
11//!
12//! ## Architecture
13//!
14//! Models are first-class typed resources described by `ModelSchema` (analogous
15//! to `ToolSchema`). The `UnifiedRegistry` holds local and remote models.
16//! The `AdaptiveRouter` selects the best model using a three-phase strategy:
17//! filter → score → explore. The `OutcomeTracker` learns from results to
18//! improve routing over time.
19//!
20//! ## Dual purpose
21//!
22//! 1. **Internal** — powers skill learning/repair, semantic memory, policy evaluation
23//! 2. **Service** — exposes `infer`, `embed`, `classify` as built-in CAR tools
24
25pub mod adaptive_router;
26pub mod backend;
27pub mod backend_cache;
28pub mod hardware;
29pub mod intent;
30pub mod key_pool;
31pub mod models;
32pub mod outcome;
33pub mod protocol;
34pub mod registry;
35pub mod remote;
36pub mod router;
37pub mod routing_ext;
38pub mod schema;
39pub mod service;
40pub mod stream;
41pub mod tasks;
42pub mod vllm_mlx;
43
44use std::path::{Path, PathBuf};
45use std::sync::Arc;
46use std::time::Instant;
47use std::time::{SystemTime, UNIX_EPOCH};
48
49use serde::Serialize;
50use reqwest::multipart::{Form, Part};
51use thiserror::Error;
52#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
53use tokio::process::Command;
54#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
55use tokio::sync::Mutex;
56use tokio::sync::RwLock;
57use tracing::{debug, instrument};
58
59// --- New types ---
60pub use adaptive_router::{
61    AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy,
62};
63pub use intent::{IntentHint, TaskHint};
64pub use key_pool::{KeyPool, KeyStats};
65pub use outcome::{
66    CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
67};
68pub use registry::{ModelFilter, ModelInfo, ModelRuntimeRequirement, ModelUpgrade, UnifiedRegistry};
69pub use remote::RemoteBackend;
70pub use routing_ext::{
71    CircuitBreaker, CircuitBreakerRegistry, CircuitState, ImplicitSignal, ImplicitSignalType,
72    RoutingMode, SpendControl, SpendLimitExceeded, SpendLimits, SpendStatus,
73};
74pub use schema::{
75    ApiProtocol, BenchmarkScore, CostModel, ModelCapability, ModelSchema, ModelSource,
76    PerformanceEnvelope, ProprietaryAuth,
77};
78
79// --- Legacy re-exports (kept for backward compatibility) ---
80pub use adaptive_router::TaskComplexity;
81#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
82pub use backend::CandleBackend;
83#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
84pub use backend::EmbeddingBackend;
85#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
86pub use backend::MlxBackend;
87pub use hardware::HardwareInfo;
88pub use models::{ModelRegistry, ModelRole};
89pub use router::{ModelRouter, RoutingDecision};
90pub use stream::{StreamAccumulator, StreamEvent};
91pub use tasks::{
92    parse_boxes, BoundingBox, ClassifyRequest, ClassifyResult, ContentBlock, EmbedRequest,
93    GenerateImageRequest, GenerateImageResult, GenerateParams, GenerateRequest,
94    GenerateVideoRequest, GenerateVideoResult, GroundRequest, GroundResult, Message, RerankRequest,
95    RerankResult, RerankedDocument, RoutingWorkload, SynthesizeRequest, SynthesizeResult, ToolCall,
96    TranscribeRequest, TranscribeResult, VideoMode,
97};
98// TokenUsage is defined in this module and already public
99
/// Errors surfaced by the inference engine and its backends.
///
/// `Display` messages are generated by `thiserror` from the `#[error(...)]`
/// attribute on each variant; the `String` payloads are interpolated verbatim.
#[derive(Error, Debug)]
pub enum InferenceError {
    /// A model name or id could not be resolved. The payload is the
    /// identifier that was looked up.
    #[error("model not found: {0}")]
    ModelNotFound(String),

    /// Fetching model files failed.
    /// NOTE(review): presumably raised by the registry download path, which
    /// is not visible in this part of the file — confirm there.
    #[error("model download failed: {0}")]
    DownloadFailed(String),

    /// A backend attempted the request and failed. Also used by this module
    /// for backend-load panics and unknown schema ids in `warm_up`.
    #[error("inference failed: {0}")]
    InferenceFailed(String),

    /// A request mode is accepted on the public surface but the
    /// selected backend hasn't wired it yet. Distinct from
    /// `InferenceFailed` so callers can distinguish "backend can't"
    /// from "backend tried and something went wrong".
    ///
    /// All three fields are `&'static str`, so the message cannot carry
    /// per-request detail such as a model name.
    #[error("mode {mode} not implemented on backend {backend}: {reason}")]
    UnsupportedMode {
        mode: &'static str,
        backend: &'static str,
        reason: &'static str,
    },

    /// Tokenizing input or detokenizing output failed.
    #[error("tokenization error: {0}")]
    TokenizationError(String),

    /// The compute device could not be used.
    /// NOTE(review): exact trigger conditions live in the backends — confirm.
    #[error("device error: {0}")]
    DeviceError(String),

    /// Wraps a `std::io::Error`. The `#[from]` attribute generates the
    /// `From` impl, so `?` converts I/O errors automatically.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}
131
/// Which device to run inference on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    /// Run on the host CPU.
    Cpu,
    /// Run on the Metal GPU backend.
    Metal,
    /// Run on a CUDA GPU; the payload is the device ordinal.
    Cuda(usize),
}

impl Device {
    /// Auto-detect the best available device for this build.
    ///
    /// The choice is made from compile-time Cargo features only, in priority
    /// order: `metal` → [`Device::Metal`], `cuda` → `Device::Cuda(0)`,
    /// otherwise [`Device::Cpu`]. No runtime hardware probing happens here.
    ///
    /// `cfg!` is used instead of stacked `#[cfg]` blocks with early `return`s
    /// so that enabling both features does not leave an unreachable CUDA
    /// branch behind; the result is the same in every feature combination.
    pub fn auto() -> Self {
        if cfg!(feature = "metal") {
            Device::Metal
        } else if cfg!(feature = "cuda") {
            Device::Cuda(0)
        } else {
            Device::Cpu
        }
    }
}
157
/// Configuration for the inference engine.
///
/// Each task family has a default model plus an optional `preferred_*`
/// override. Where this file consults the override (embedding backend
/// selection, generation routing) it takes precedence over the default.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Where to store downloaded models. Defaults to ~/.car/models/
    /// (or `./.car/models` when no home directory can be resolved).
    /// The engine also reads `outcome_profiles.json` and
    /// `key_pool_stats.json` from this directory.
    pub models_dir: std::path::PathBuf,
    /// Device override. None = auto-detect.
    pub device: Option<Device>,
    /// Default model for generation tasks. The `Default` impl fills this
    /// from the detected hardware's recommended model.
    pub generation_model: String,
    /// Optional preferred model override for generation tasks.
    pub preferred_generation_model: Option<String>,
    /// Default model for embedding tasks.
    pub embedding_model: String,
    /// Optional preferred model override for embedding tasks.
    pub preferred_embedding_model: Option<String>,
    /// Default model for classification tasks.
    pub classification_model: String,
    /// Optional preferred model override for classification tasks.
    pub preferred_classification_model: Option<String>,
}
178
179impl Default for InferenceConfig {
180    fn default() -> Self {
181        let models_dir = dirs_next()
182            .unwrap_or_else(|| std::path::PathBuf::from("."))
183            .join(".car")
184            .join("models");
185
186        let hw = HardwareInfo::detect();
187
188        Self {
189            models_dir,
190            device: None,
191            generation_model: hw.recommended_model,
192            preferred_generation_model: None,
193            embedding_model: "Qwen3-Embedding-0.6B".to_string(),
194            preferred_embedding_model: None,
195            classification_model: "Qwen3-0.6B".to_string(),
196            preferred_classification_model: None,
197        }
198    }
199}
200
/// Resolve the current user's home directory from the environment.
///
/// `HOME` is consulted first, then `USERPROFILE` as a fallback for
/// environments that do not export `HOME`. Values are read with
/// [`std::env::var_os`], so a home path that is not valid UTF-8 is still
/// accepted. A variable that is set but empty is treated as unset.
///
/// Returns `None` when neither variable holds a non-empty value; the caller
/// decides what to fall back to.
fn dirs_next() -> Option<std::path::PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|key| std::env::var_os(key))
        .find(|value| !value.is_empty())
        .map(std::path::PathBuf::from)
}
204
/// Token usage statistics from a model response.
///
/// All counters are plain stored values; `total_tokens` is not computed from
/// the other fields, so whoever constructs the struct must keep it
/// consistent. The derived `Default` yields all zeros.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    /// Number of tokens in the prompt/input.
    pub prompt_tokens: u64,
    /// Number of tokens in the completion/output.
    pub completion_tokens: u64,
    /// Total tokens (prompt + completion).
    pub total_tokens: u64,
    /// Model's maximum context window size.
    pub context_window: u64,
}
217
/// Result of an inference call, including trace ID for outcome tracking.
///
/// Serializable in both directions. Fields without `#[serde(default)]` that
/// are not `Option` (`text`, `tool_calls`, `trace_id`, `model_used`,
/// `latency_ms`) must be present when deserializing.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InferenceResult {
    /// The generated text (empty if tool_calls are present).
    pub text: String,
    /// Tool calls returned by the model (when tools were provided in the request).
    pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
    /// Structured bounding boxes when the model emitted Qwen2.5-VL
    /// grounding spans (`<|box_*|>`, `<|object_ref_*|>`) in its text.
    /// Parsed from the same `text` field — the raw span markers remain
    /// visible in `text` for callers that need to see them verbatim.
    /// Empty vec when the model didn't ground anything (typical for
    /// non-VL models or prompts that only ask for description).
    ///
    /// Omitted from serialized output when empty, and defaults to empty when
    /// absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub bounding_boxes: Vec<crate::tasks::grounding::BoundingBox>,
    /// Trace ID for reporting outcomes back to the tracker.
    pub trace_id: String,
    /// Which model was used.
    pub model_used: String,
    /// Wall-clock latency in ms.
    pub latency_ms: u64,
    /// Time to first token in milliseconds. Populated by the local
    /// generate paths (Candle/MLX) which observe the prefill→first-decode
    /// transition directly. `None` for paths that can't measure it
    /// honestly without streaming — currently the non-streaming remote
    /// paths. Callers needing TTFT on remote models should use
    /// [`InferenceEngine::generate_tracked_stream`] and time the first
    /// `text` event arrival themselves.
    ///
    /// Always serialized (as `null` when `None`) so downstream
    /// validation harnesses can distinguish "wasn't measured" from
    /// "field doesn't exist on this client's protocol version".
    #[serde(default)]
    pub time_to_first_token_ms: Option<u64>,
    /// Token usage statistics from the API response (None for local models).
    pub usage: Option<TokenUsage>,
}
255
256impl InferenceResult {
257    /// Returns true if the model chose to call tools instead of generating text.
258    pub fn has_tool_calls(&self) -> bool {
259        !self.tool_calls.is_empty()
260    }
261}
262
/// Serializable snapshot of the local speech runtime's installation state.
///
/// NOTE(review): the code that fills this struct is outside this part of the
/// file; the field notes below are inferred from names and types — confirm
/// against the producer.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechRuntimeHealth {
    /// Presumably the runtime's root directory.
    pub root: PathBuf,
    /// Whether the runtime was found installed.
    pub installed: bool,
    /// Path of the Python interpreter associated with the runtime.
    pub python: PathBuf,
    /// Path of the speech-to-text command.
    pub stt_command: PathBuf,
    /// Path of the text-to-speech command.
    pub tts_command: PathBuf,
    /// Presumably an explicitly configured interpreter, if any.
    pub configured_python: Option<String>,
    /// Presumably an auto-detected interpreter, if any.
    pub detected_python: Option<String>,
}
273
/// Serializable health entry for a single speech model.
///
/// NOTE(review): populated elsewhere in the file; field notes are inferred
/// from names — confirm against the producer.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechModelHealth {
    /// Model identifier.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Provider name.
    pub provider: String,
    /// The capability this entry is reported under.
    pub capability: ModelCapability,
    /// Whether the model runs locally rather than via a remote API.
    pub is_local: bool,
    /// Whether the model is currently usable.
    pub available: bool,
    /// Presumably whether local weights are already downloaded.
    pub cached: bool,
    /// Presumably whether default selection would pick this model.
    pub selected_by_default: bool,
    /// Free-form description of where the model comes from.
    pub source: String,
}
286
/// Serializable health report for the speech (STT/TTS) subsystem.
///
/// Combines runtime state, per-model entries, and policy settings. The
/// `prefer_local` … `preferred_remote_tts` fields have the same names and
/// types as the fields of [`SpeechPolicy`].
/// NOTE(review): they presumably echo the engine's active policy — confirm
/// against the code that builds this report.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechHealthReport {
    /// State of the local speech runtime.
    pub runtime: SpeechRuntimeHealth,
    /// Entries for locally hosted speech models.
    pub local_models: Vec<SpeechModelHealth>,
    /// Entries for remotely hosted speech models.
    pub remote_models: Vec<SpeechModelHealth>,
    /// Presumably whether ElevenLabs credentials are present.
    pub elevenlabs_configured: bool,
    pub prefer_local: bool,
    pub allow_remote_fallback: bool,
    pub preferred_local_stt: Option<String>,
    pub preferred_local_tts: Option<String>,
    pub preferred_remote_stt: Option<String>,
    pub preferred_remote_tts: Option<String>,
    // The four `*_default` fields below presumably name the model that
    // would be chosen per path when no preference is set — TODO confirm.
    pub local_stt_default: Option<String>,
    pub local_tts_default: Option<String>,
    pub remote_stt_default: Option<String>,
    pub remote_tts_default: Option<String>,
}
304
/// Serializable health entry for one configured default model.
///
/// NOTE(review): populated elsewhere in the file; field notes are inferred
/// from names — confirm against the producer.
#[derive(Debug, Clone, Serialize)]
pub struct ModelDefaultHealth {
    /// Capability the default applies to.
    pub capability: ModelCapability,
    /// Model name taken from configuration.
    pub configured_model: String,
    /// Whether that model is currently usable.
    pub available: bool,
    /// Whether that model runs locally.
    pub is_local: bool,
    /// Provider of the model, when known.
    pub provider: Option<String>,
}
313
/// Serializable per-provider summary of registered models.
///
/// NOTE(review): populated elsewhere in the file; field notes are inferred
/// from names — confirm against the producer.
#[derive(Debug, Clone, Serialize)]
pub struct ModelProviderHealth {
    /// Provider name.
    pub provider: String,
    /// Presumably whether credentials/config for the provider are present.
    pub configured: bool,
    /// Count of this provider's local models.
    pub local_models: usize,
    /// Count of this provider's remote models.
    pub remote_models: usize,
    /// Count of this provider's models that are currently usable.
    pub available_models: usize,
    /// Capabilities offered across this provider's models.
    pub capabilities: Vec<ModelCapability>,
}
323
/// Serializable per-capability summary of registered models.
///
/// NOTE(review): populated elsewhere in the file; whether
/// `local_available_models + remote_available_models == available_models`
/// always holds is not visible here — confirm against the producer.
#[derive(Debug, Clone, Serialize)]
pub struct ModelCapabilityHealth {
    /// Capability being summarized.
    pub capability: ModelCapability,
    /// Count of models declaring this capability.
    pub total_models: usize,
    /// Count of those that are currently usable.
    pub available_models: usize,
    /// Usable models that run locally.
    pub local_available_models: usize,
    /// Usable models that are remote.
    pub remote_available_models: usize,
}
332
/// Serializable record of what the router chose for one named scenario.
///
/// NOTE(review): populated elsewhere in the file. The inputs/outputs split
/// below is inferred from field names — confirm against the producer.
#[derive(Debug, Clone, Serialize)]
pub struct RoutingScenarioHealth {
    /// Scenario label.
    pub name: String,
    // --- Presumed scenario inputs ---
    pub workload: RoutingWorkload,
    pub task_family: String,
    pub has_tools: bool,
    pub has_vision: bool,
    pub prefer_local: bool,
    pub quality_first_cold_start: bool,
    pub bootstrap_min_task_observations: u64,
    pub bootstrap_quality_floor: f64,
    // --- Presumed routing outcome ---
    /// Id of the selected model.
    pub model_id: String,
    /// Name of the selected model.
    pub model_name: String,
    /// Router-provided explanation for the choice.
    pub reason: String,
    /// Strategy the router reported for the choice.
    pub strategy: RoutingStrategy,
}
349
/// Serializable view of one model's benchmark prior.
///
/// NOTE(review): populated elsewhere in the file; score scale and map key
/// format are not visible here — confirm against the benchmark-prior loader.
#[derive(Debug, Clone, Serialize)]
pub struct ModelBenchmarkPriorHealth {
    /// Id of the model the prior applies to.
    pub model_id: String,
    /// Model name, when it could be resolved.
    pub model_name: Option<String>,
    /// Aggregate score across tasks.
    pub overall_score: f64,
    /// Aggregate latency in milliseconds, when recorded.
    pub overall_latency_ms: Option<f64>,
    /// Per-task scores; keys are presumably task names.
    pub task_scores: std::collections::HashMap<String, f64>,
    /// Per-task latencies in milliseconds; keys presumably match `task_scores`.
    pub task_latency_ms: std::collections::HashMap<String, f64>,
    /// File the prior was read from.
    pub source_path: PathBuf,
}
360
/// Top-level serializable health report for the model subsystem.
///
/// Aggregates model counts, per-default/provider/capability summaries,
/// routing settings and scenarios, benchmark priors, and the speech report.
/// NOTE(review): populated elsewhere in the file; the `routing_*` fields
/// presumably echo the active `RoutingConfig` — confirm against the producer.
#[derive(Debug, Clone, Serialize)]
pub struct ModelHealthReport {
    /// Count of all registered models.
    pub total_models: usize,
    /// Count of models that are currently usable.
    pub available_models: usize,
    /// Count of local models.
    pub local_models: usize,
    /// Count of remote models.
    pub remote_models: usize,
    /// One entry per configured default model.
    pub defaults: Vec<ModelDefaultHealth>,
    /// One entry per provider.
    pub providers: Vec<ModelProviderHealth>,
    /// One entry per capability.
    pub capabilities: Vec<ModelCapabilityHealth>,
    pub routing_prefer_local: bool,
    pub routing_quality_first_cold_start: bool,
    pub routing_min_observations: u64,
    pub routing_bootstrap_min_task_observations: u64,
    pub routing_bootstrap_quality_floor: f64,
    pub routing_quality_weight: f64,
    pub routing_latency_weight: f64,
    pub routing_cost_weight: f64,
    /// Routing decisions for a set of named scenarios.
    pub routing_scenarios: Vec<RoutingScenarioHealth>,
    /// Benchmark priors known to the engine.
    pub benchmark_priors: Vec<ModelBenchmarkPriorHealth>,
    /// Nested report for the speech subsystem.
    pub speech: SpeechHealthReport,
}
382
/// Serializable result of installing one speech model.
///
/// NOTE(review): populated elsewhere in the file; field notes are inferred
/// from names — confirm against the installer.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechInstallReport {
    /// Model name.
    pub name: String,
    /// HuggingFace repository the model was fetched from.
    pub hf_repo: String,
    /// Local path of the downloaded snapshot.
    pub snapshot_path: PathBuf,
    /// Number of files downloaded.
    pub files_downloaded: usize,
}
390
/// Serializable result of a speech smoke test along one path.
///
/// NOTE(review): populated elsewhere in the file. Presumably the test
/// synthesizes audio with `tts_model` and transcribes it back with
/// `stt_model` — confirm against the smoke-test code.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechSmokePathReport {
    /// Label of the path exercised (see the `local` / `remote` fields of
    /// [`SpeechSmokeReport`]).
    pub path: String,
    /// Text-to-speech model used.
    pub tts_model: String,
    /// Speech-to-text model used.
    pub stt_model: String,
    /// Location of the audio file involved in the test.
    pub audio_path: PathBuf,
    /// Transcript produced by the test.
    pub transcript: String,
}
399
/// Serializable summary of a speech smoke test across local and remote paths.
///
/// The derived `Default` is an empty report: both paths `None` and no skip
/// entries.
#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechSmokeReport {
    /// Result for the local path, if it was exercised.
    pub local: Option<SpeechSmokePathReport>,
    /// Result for the remote path, if it was exercised.
    pub remote: Option<SpeechSmokePathReport>,
    /// Free-form entries for steps that were skipped.
    /// NOTE(review): exact content (path names vs. reasons) is set by the
    /// smoke-test code — confirm there.
    pub skipped: Vec<String>,
}
406
/// Policy controlling how speech (STT/TTS) requests choose between local and
/// remote models.
///
/// Note that the derived `Default` sets both booleans to `false` and every
/// preference to `None`, which differs from the policy `InferenceEngine::new`
/// installs (`allow_remote_fallback: true`, and `prefer_local: true` on
/// Apple Silicon macOS).
#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechPolicy {
    /// Prefer locally hosted speech models over remote ones.
    pub prefer_local: bool,
    /// Permit falling back to a remote model.
    /// NOTE(review): the exact fallback trigger is decided by the speech
    /// code paths, not visible here — confirm there.
    pub allow_remote_fallback: bool,
    /// Preferred local speech-to-text model, if any.
    pub preferred_local_stt: Option<String>,
    /// Preferred local text-to-speech model, if any.
    pub preferred_local_tts: Option<String>,
    /// Preferred remote speech-to-text model, if any.
    pub preferred_remote_stt: Option<String>,
    /// Preferred remote text-to-speech model, if any.
    pub preferred_remote_tts: Option<String>,
}
416
/// The main inference engine. Thread-safe, lazily loads models.
///
/// Now includes the unified registry, adaptive router, and outcome tracker
/// for schema-driven model selection with learned performance profiles.
///
/// The set of backend fields is target-dependent: Apple Silicon macOS builds
/// carry the MLX caches, every other target carries the Candle backends and
/// the speech runtime slot.
pub struct InferenceEngine {
    /// Configuration the engine was constructed with.
    pub config: InferenceConfig,
    /// Unified model registry (local + remote).
    pub unified_registry: UnifiedRegistry,
    /// Adaptive router with three-phase selection.
    pub adaptive_router: AdaptiveRouter,
    /// Outcome tracker for learning from results.
    pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
    /// HTTP client for remote API models.
    remote_backend: RemoteBackend,
    /// Native MLX text-gen / embedding backends keyed by model id.
    /// Same cache shape as `flux_cache` / `ltx_cache` / `kokoro_cache`:
    /// per-entry `Arc<Mutex<MlxBackend>>` so concurrent calls for the
    /// same model serialize, while calls for different models proceed
    /// in parallel. Also bounded by `CAR_INFERENCE_MODEL_CACHE_MB`.
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    mlx_backends: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
    /// LRU-evicting, mutex-serialized cache of loaded Flux image backends.
    /// Avoids reloading the 4–5 GB model on every generate_image call,
    /// and serializes concurrent calls onto the same backend (MLX ops
    /// are not `Sync`).
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    flux_cache: Arc<backend_cache::BackendCache<backend::mlx_flux::FluxBackend>>,
    /// Same for LTX video (~9 GB: transformer + Gemma 3 12B + VAE + vocoder).
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    ltx_cache: Arc<backend_cache::BackendCache<backend::mlx_ltx::LtxBackend>>,
    /// Same for Kokoro TTS (~160 MB). Small but reloading per-utterance
    /// added ~1 s of latency to every `synth` call.
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    kokoro_cache: Arc<backend_cache::BackendCache<backend::mlx_kokoro::KokoroBackend>>,
    // Legacy fields kept for backward compatibility
    /// Legacy model registry; the Candle load paths resolve model files
    /// through it.
    pub registry: models::ModelRegistry,
    /// Legacy router, built from the same hardware info as `adaptive_router`.
    pub router: ModelRouter,
    /// Single-slot, lazily filled Candle generation backend (`None` until
    /// the first load).
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
    backend: Arc<RwLock<Option<CandleBackend>>>,
    /// Single-slot, lazily filled Candle embedding backend.
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
    embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
    /// Lazily filled speech runtime slot; starts as `None`.
    /// NOTE(review): `SpeechRuntime` is defined outside this part of the
    /// file — confirm its role there.
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
    speech_runtime: Arc<Mutex<Option<SpeechRuntime>>>,
    /// Local-vs-remote policy for speech requests.
    speech_policy: SpeechPolicy,
}
462
463impl InferenceEngine {
464    fn preferred_model_for_capability(&self, capability: ModelCapability) -> Option<&str> {
465        match capability {
466            ModelCapability::Generate => self.config.preferred_generation_model.as_deref(),
467            ModelCapability::Embed => self.config.preferred_embedding_model.as_deref(),
468            ModelCapability::Classify => self.config.preferred_classification_model.as_deref(),
469            _ => None,
470        }
471    }
472
473    fn request_needs_vision(req: &GenerateRequest) -> bool {
474        req.images.as_ref().is_some_and(|images| !images.is_empty())
475            || req.messages.as_ref().is_some_and(|messages| {
476                messages
477                    .iter()
478                    .any(|msg| matches!(msg, Message::UserMultimodal { .. }))
479            })
480    }
481
482    /// True when any content block in the request carries video data.
483    /// Backends without a video-tokenization path use this to reject
484    /// the request up front with [`InferenceError::UnsupportedMode`]
485    /// rather than silently dropping the content.
486    fn request_has_video(req: &GenerateRequest) -> bool {
487        let images_have_video = req
488            .images
489            .as_ref()
490            .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
491        let messages_have_video = req.messages.as_ref().is_some_and(|messages| {
492            messages.iter().any(|msg| match msg {
493                Message::UserMultimodal { content } => {
494                    content.iter().any(ContentBlock::is_video)
495                }
496                _ => false,
497            })
498        });
499        images_have_video || messages_have_video
500    }
501
502    /// True when any content block in the request carries audio data.
503    /// Same role as [`request_has_video`] but for the audio path
504    /// (Gemma 4 small variants, Gemini).
505    fn request_has_audio(req: &GenerateRequest) -> bool {
506        let images_have_audio = req
507            .images
508            .as_ref()
509            .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
510        let messages_have_audio = req.messages.as_ref().is_some_and(|messages| {
511            messages.iter().any(|msg| match msg {
512                Message::UserMultimodal { content } => {
513                    content.iter().any(ContentBlock::is_audio)
514                }
515                _ => false,
516            })
517        });
518        images_have_audio || messages_have_audio
519    }
520
521    pub fn new(config: InferenceConfig) -> Self {
522        let registry = models::ModelRegistry::new(config.models_dir.clone());
523        let hw = HardwareInfo::detect();
524        let router = ModelRouter::new(hw.clone());
525        let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
526        let adaptive_router = AdaptiveRouter::with_default_config(hw);
527        let mut tracker = OutcomeTracker::new();
528        // Load persisted profiles from previous sessions (#13)
529        let profiles_path = config.models_dir.join("outcome_profiles.json");
530        if let Ok(n) = tracker.load_from_file(&profiles_path) {
531            if n > 0 {
532                tracing::info!(loaded = n, "loaded persisted model profiles");
533            }
534        }
535        let mut benchmark_models_loaded = 0usize;
536        for path in benchmark_priors_paths(&config.models_dir) {
537            match routing_ext::load_benchmark_priors(&path) {
538                Ok(priors) if !priors.is_empty() => {
539                    benchmark_models_loaded += priors.len();
540                    routing_ext::apply_benchmark_priors(&mut tracker, &priors);
541                    tracing::info!(
542                        path = %path.display(),
543                        loaded = priors.len(),
544                        "loaded benchmark quality priors"
545                    );
546                }
547                Ok(_) => {}
548                Err(error) => {
549                    tracing::warn!(path = %path.display(), %error, "failed to load benchmark priors");
550                }
551            }
552        }
553        if benchmark_models_loaded > 0 {
554            tracing::info!(
555                loaded = benchmark_models_loaded,
556                "applied benchmark priors to cold-start routing"
557            );
558        }
559        let outcome_tracker = Arc::new(RwLock::new(tracker));
560
561        let remote_backend = RemoteBackend::new();
562
563        Self {
564            config,
565            unified_registry,
566            adaptive_router,
567            outcome_tracker,
568            remote_backend,
569            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
570            mlx_backends: Arc::new(backend_cache::BackendCache::from_env()),
571            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
572            flux_cache: Arc::new(backend_cache::BackendCache::from_env()),
573            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
574            ltx_cache: Arc::new(backend_cache::BackendCache::from_env()),
575            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
576            kokoro_cache: Arc::new(backend_cache::BackendCache::from_env()),
577            registry,
578            router,
579            #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
580            backend: Arc::new(RwLock::new(None)),
581            #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
582            embedding_backend: Arc::new(RwLock::new(None)),
583            #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
584            speech_runtime: Arc::new(Mutex::new(None)),
585            speech_policy: SpeechPolicy {
586                prefer_local: cfg!(all(target_os = "macos", target_arch = "aarch64")),
587                allow_remote_fallback: true,
588                preferred_local_stt: None,
589                preferred_local_tts: None,
590                preferred_remote_stt: None,
591                preferred_remote_tts: None,
592            },
593        }
594    }
595
596    /// Initialize key pool: register keys from all remote models and load persisted stats.
597    /// Call this after construction (requires async).
598    pub async fn init_key_pool(&self) {
599        // Register keys from all remote models in the catalog
600        for schema in self.unified_registry.list() {
601            if schema.is_remote() {
602                self.remote_backend.register_model_keys(schema).await;
603            }
604        }
605
606        // Load persisted key stats
607        let stats_path = self.config.models_dir.join("key_pool_stats.json");
608        if let Ok(n) = self.remote_backend.key_pool.load_stats(&stats_path).await {
609            if n > 0 {
610                tracing::info!(loaded = n, "loaded persisted key pool stats");
611            }
612        }
613
614        let total = self.remote_backend.key_pool.total_keys().await;
615        if total > 0 {
616            tracing::info!(keys = total, "key pool initialized");
617        }
618    }
619
620    /// Get or initialize the generative Candle backend, loading the specified model.
621    /// Not used on Apple Silicon where all local inference goes through MLX.
622    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
623    async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
624        let read = self.backend.read().await;
625        if read.is_some() {
626            return Ok(());
627        }
628        drop(read);
629
630        let mut write = self.backend.write().await;
631        if write.is_some() {
632            return Ok(());
633        }
634
635        let model_path = self.registry.ensure_model(model_name).await?;
636        let device = self.config.device.unwrap_or_else(Device::auto);
637        let backend = CandleBackend::load(&model_path, device)?;
638        *write = Some(backend);
639        Ok(())
640    }
641
642    /// Get or initialize the embedding backend.
643    /// On Apple Silicon, uses the MLX backend instead of Candle.
644    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
645    async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
646        let read = self.embedding_backend.read().await;
647        if read.is_some() {
648            return Ok(());
649        }
650        drop(read);
651
652        let mut write = self.embedding_backend.write().await;
653        if write.is_some() {
654            return Ok(());
655        }
656
657        let embedding_model = self
658            .preferred_model_for_capability(ModelCapability::Embed)
659            .unwrap_or(&self.config.embedding_model);
660        let model_path = self
661            .registry
662            .ensure_model(embedding_model)
663            .await?;
664        let device = self.config.device.unwrap_or_else(Device::auto);
665        let backend = EmbeddingBackend::load(&model_path, device)?;
666        *write = Some(backend);
667        Ok(())
668    }
669
670    /// On Apple Silicon, ensure the MLX embedding model is loaded.
671    /// Returns the schema ID of the embedding model for keying into mlx_backends.
672    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
673    async fn ensure_mlx_embedding_backend(&self) -> Result<String, InferenceError> {
674        let embedding_model_name = self
675            .preferred_model_for_capability(ModelCapability::Embed)
676            .unwrap_or(&self.config.embedding_model)
677            .to_string();
678        let schema = self
679            .unified_registry
680            .get(&embedding_model_name)
681            .or_else(|| self.unified_registry.find_by_name(&embedding_model_name))
682            .ok_or_else(|| InferenceError::ModelNotFound(embedding_model_name.clone()))?
683            .clone();
684        self.ensure_mlx_backend(&schema).await?;
685        Ok(schema.id)
686    }
687
688    /// Get or initialize the native MLX backend for a specific model.
689    /// Returns a shared mutex handle — the caller locks it for the
690    /// duration of an inference call so concurrent requests serialize.
691    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
692    async fn ensure_mlx_backend(
693        &self,
694        schema: &ModelSchema,
695    ) -> Result<backend_cache::CachedBackend<backend::MlxBackend>, InferenceError> {
696        if !Self::supports_native_mlx(schema) {
697            return Err(InferenceError::InferenceFailed(format!(
698                "native MLX backend does not support {} ({}) yet; use vLLM-MLX or add a family-specific MLX backend",
699                schema.name, schema.family
700            )));
701        }
702
703        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
704        let size = backend_cache::estimate_model_size(&model_dir);
705        let cache = Arc::clone(&self.mlx_backends);
706        let key = schema.id.clone();
707        // Loader runs inside `get_or_load` only on a cache miss. Wrap it
708        // in `catch_unwind` because MLX/accelerate occasionally panics
709        // at the FFI boundary and we don't want the whole engine to die.
710        cache.get_or_load(&key, size, move || {
711            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
712                backend::MlxBackend::load(&model_dir)
713            }))
714            .map_err(|e| {
715                InferenceError::InferenceFailed(format!(
716                    "MLX backend loading panicked (possible Metal/accelerate exception): {:?}",
717                    e
718                ))
719            })?
720        })
721    }
722
723    /// Pre-load a set of models into the MLX cache so the first real
724    /// inference call doesn't pay the 1–14 s model-load latency. Safe
725    /// to call multiple times; already-loaded models are no-ops.
726    ///
727    /// Handles both text-gen MLX backends (`mlx_backends`) and the
728    /// image/video/tts caches. Pass the full `schema.id` values.
729    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
730    pub async fn warm_up<S: AsRef<str>>(&self, schema_ids: &[S]) -> Vec<Result<(), InferenceError>> {
731        let mut results = Vec::with_capacity(schema_ids.len());
732        for id in schema_ids {
733            let id = id.as_ref();
734            let outcome: Result<(), InferenceError> = async {
735                let schema = self
736                    .unified_registry
737                    .get(id)
738                    .cloned()
739                    .ok_or_else(|| InferenceError::InferenceFailed(
740                        format!("warm_up: unknown schema id {id}"),
741                    ))?;
742                match schema.capabilities.first().copied() {
743                    Some(ModelCapability::ImageGeneration) => {
744                        let model_dir =
745                            self.unified_registry.ensure_local(&schema.id).await?;
746                        let size = backend_cache::estimate_model_size(&model_dir);
747                        let _ = self.flux_cache.get_or_load(&schema.id, size, || {
748                            backend::mlx_flux::FluxBackend::load(&model_dir)
749                        })?;
750                    }
751                    Some(ModelCapability::VideoGeneration) => {
752                        let model_dir =
753                            self.unified_registry.ensure_local(&schema.id).await?;
754                        let size = backend_cache::estimate_model_size(&model_dir);
755                        let _ = self.ltx_cache.get_or_load(&schema.id, size, || {
756                            backend::mlx_ltx::LtxBackend::load(&model_dir)
757                        })?;
758                    }
759                    Some(ModelCapability::TextToSpeech) => {
760                        let model_dir =
761                            self.unified_registry.ensure_local(&schema.id).await?;
762                        let size = backend_cache::estimate_model_size(&model_dir);
763                        let _ = self.kokoro_cache.get_or_load(&schema.id, size, || {
764                            backend::mlx_kokoro::KokoroBackend::load(&model_dir)
765                        })?;
766                    }
767                    _ => {
768                        let _ = self.ensure_mlx_backend(&schema).await?;
769                    }
770                }
771                Ok(())
772            }
773            .await;
774            results.push(outcome);
775        }
776        results
777    }
778
779    /// No-op on non-macOS — MLX doesn't run here.
780    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
781    pub async fn warm_up<S: AsRef<str>>(&self, _schema_ids: &[S]) -> Vec<Result<(), InferenceError>> {
782        Vec::new()
783    }
784
785    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
786    fn supports_native_mlx(schema: &ModelSchema) -> bool {
787        matches!(schema.family.as_str(), "qwen3" | "qwen2.5-vl" | "qwen2-vl")
788    }
789
790    /// Route a prompt using the adaptive router (new). Returns full decision context.
791    pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
792        if let Some(model) = self.preferred_model_for_capability(ModelCapability::Generate) {
793            let ctx_len = self
794                .unified_registry
795                .get(model)
796                .or_else(|| self.unified_registry.find_by_name(model))
797                .map(|s| s.context_length)
798                .unwrap_or(0);
799            return AdaptiveRoutingDecision {
800                model_id: model.to_string(),
801                model_name: model.to_string(),
802                task: InferenceTask::Generate,
803                complexity: TaskComplexity::assess(prompt),
804                reason: "preferred generation model override".into(),
805                strategy: RoutingStrategy::Explicit,
806                predicted_quality: 0.5,
807                fallbacks: vec![],
808                context_length: ctx_len,
809                needs_compaction: false,
810            };
811        }
812        let tracker = self.outcome_tracker.read().await;
813        self.adaptive_router
814            .route(prompt, &self.unified_registry, &tracker)
815    }
816
817    /// Route a prompt to the best model without executing (legacy compat).
818    pub fn route(&self, prompt: &str) -> RoutingDecision {
819        self.router.route_generate(prompt, &self.registry)
820    }
821
822    /// Estimate token count for a request against a specific model's context window.
823    /// Returns (estimated_input_tokens, context_window_tokens, fits).
824    pub fn estimated_tokens(
825        &self,
826        req: &GenerateRequest,
827        model_id: Option<&str>,
828    ) -> (usize, usize, bool) {
829        let prompt_tokens = remote::estimate_tokens(&req.prompt);
830        let context_tokens = req
831            .context
832            .as_ref()
833            .map(|c| remote::estimate_tokens(c))
834            .unwrap_or(0);
835        let tools_tokens = req
836            .tools
837            .as_ref()
838            .map(|t| remote::estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
839            .unwrap_or(0);
840        let total_input = prompt_tokens + context_tokens + tools_tokens;
841
842        let context_window = model_id
843            .and_then(|id| {
844                self.unified_registry
845                    .get(id)
846                    .or_else(|| self.unified_registry.find_by_name(id))
847            })
848            .map(|s| s.context_length)
849            .unwrap_or(0);
850
851        let fits = context_window == 0 || (total_input + req.params.max_tokens) <= context_window;
852        (total_input, context_window, fits)
853    }
854
855    /// Generate text from a prompt with outcome tracking.
856    /// Returns `InferenceResult` with trace_id for reporting outcomes.
857    #[instrument(
858        name = "inference.generate",
859        skip_all,
860        fields(
861            model = tracing::field::Empty,
862            max_tokens = req.params.max_tokens,
863            prompt_tokens = tracing::field::Empty,
864            completion_tokens = tracing::field::Empty,
865            latency_ms = tracing::field::Empty,
866        )
867    )]
868    pub async fn generate_tracked(
869        &self,
870        req: GenerateRequest,
871    ) -> Result<InferenceResult, InferenceError> {
872        let start = Instant::now();
873
874        // Route using adaptive router (context-aware)
875        let (estimated_input, _, _) = self.estimated_tokens(&req, None);
876        let tracker_read = self.outcome_tracker.read().await;
877        let has_tools = req.tools.is_some();
878        let has_vision = Self::request_needs_vision(&req);
879        let preferred_model = self
880            .preferred_model_for_capability(ModelCapability::Generate)
881            .map(str::to_string);
882        let decision = match req.model.clone().or(preferred_model) {
883            Some(m) => {
884                let ctx_len = self
885                    .unified_registry
886                    .get(&m)
887                    .or_else(|| self.unified_registry.find_by_name(&m))
888                    .map(|s| s.context_length)
889                    .unwrap_or(0);
890                AdaptiveRoutingDecision {
891                    model_id: m.clone(),
892                    model_name: m.clone(),
893                    task: InferenceTask::Generate,
894                    complexity: TaskComplexity::assess(&req.prompt),
895                    reason: "explicit model".into(),
896                    strategy: RoutingStrategy::Explicit,
897                    predicted_quality: 0.5,
898                    fallbacks: vec![],
899                    context_length: ctx_len,
900                    needs_compaction: ctx_len > 0 && estimated_input > ctx_len,
901                }
902            }
903            None => match &req.intent {
904                Some(hint) => self.adaptive_router.route_context_aware_with_intent(
905                    &req.prompt,
906                    estimated_input,
907                    &self.unified_registry,
908                    &tracker_read,
909                    has_tools,
910                    has_vision,
911                    req.params.workload,
912                    hint,
913                ),
914                None => self.adaptive_router.route_context_aware(
915                    &req.prompt,
916                    estimated_input,
917                    &self.unified_registry,
918                    &tracker_read,
919                    has_tools,
920                    has_vision,
921                    req.params.workload,
922                ),
923            },
924        };
925        drop(tracker_read);
926
927        if decision.needs_compaction {
928            tracing::info!(
929                model = %decision.model_name,
930                prompt_tokens = estimated_input,
931                context_window = decision.context_length,
932                "prompt exceeds model context window — compaction or truncation needed"
933            );
934        }
935
936        // Record start
937        let trace_id = {
938            let mut tracker = self.outcome_tracker.write().await;
939            tracker.record_start(&decision.model_id, decision.task, &decision.reason)
940        };
941
942        debug!(
943            model = %decision.model_name,
944            strategy = ?decision.strategy,
945            reason = %decision.reason,
946            trace = %trace_id,
947            "adaptive-routed generate request"
948        );
949
950        // Auto-enable extended thinking for complex tasks when the model supports it
951        // and the caller hasn't explicitly set budget_tokens.
952        let mut req = req;
953        if req.params.budget_tokens == 0 && matches!(decision.complexity, TaskComplexity::Complex) {
954            let supports_thinking = self
955                .unified_registry
956                .get(&decision.model_id)
957                .map(|s| {
958                    s.supported_params
959                        .contains(&schema::GenerateParam::ExtendedThinking)
960                })
961                .unwrap_or(false);
962            if supports_thinking {
963                req.params.budget_tokens = 8000;
964                tracing::info!(model = %decision.model_name, budget = 8000, "auto-enabled extended thinking for complex task");
965            }
966        }
967
968        // Execute — dispatch to local or remote backend, with fallback on failure
969        let mut models_to_try = vec![decision.model_id.clone()];
970        models_to_try.extend(decision.fallbacks.iter().cloned());
971
972        let mut last_error = None;
973
974        for candidate_id in &models_to_try {
975            // `mut` is needed on the aarch64-macos cfg branch below;
976            // other targets don't rebind.
977            #[allow(unused_mut)]
978            let mut schema = self
979                .unified_registry
980                .get(candidate_id)
981                .or_else(|| self.unified_registry.find_by_name(candidate_id))
982                .cloned();
983
984            // On Apple Silicon, redirect GGUF/Candle models to their MLX equivalents.
985            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
986            if let Some(ref s) = schema {
987                if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
988                    tracing::info!(
989                        from = %s.id, to = %mlx_equiv.id,
990                        "redirecting GGUF model to MLX equivalent on Apple Silicon"
991                    );
992                    schema = Some(mlx_equiv.clone());
993                }
994            }
995
996            let candidate_name = schema
997                .as_ref()
998                .map(|s| s.name.clone())
999                .unwrap_or_else(|| candidate_id.clone());
1000
1001            let is_remote = schema
1002                .as_ref()
1003                .map(|s| s.is_remote() || s.is_vllm_mlx())
1004                .unwrap_or(false);
1005
1006            let has_tools = req.tools.is_some();
1007
1008            // Reinforce done tool instructions in context (fixes #10: empty done results)
1009            let context = if has_tools
1010                && req.tools.as_ref().map_or(false, |t| {
1011                    t.iter().any(|tool| {
1012                        tool.get("function")
1013                            .and_then(|f| f.get("name"))
1014                            .and_then(|n| n.as_str())
1015                            == Some("done")
1016                    })
1017                }) {
1018                let base = req.context.as_deref().unwrap_or("");
1019                Some(format!(
1020                    "{base}\n\nIMPORTANT: When calling the `done` tool, the `result` field MUST contain a DETAILED summary of everything you found and did. This is the ONLY output the user sees. Do NOT just say 'completed' — include specific findings, data, and conclusions."
1021                ))
1022            } else {
1023                req.context.clone()
1024            };
1025
1026            let result = if is_remote {
1027                let schema_val = schema.unwrap();
1028                let _ctx_len = schema_val.context_length;
1029                // Strip unsupported params based on model schema (#15).
1030                // Use -1.0 as sentinel: remote backends omit temperature entirely.
1031                let temperature = if !schema_val.supported_params.is_empty()
1032                    && !schema_val
1033                        .supported_params
1034                        .contains(&crate::schema::GenerateParam::Temperature)
1035                {
1036                    -1.0
1037                } else {
1038                    req.params.temperature
1039                };
1040
1041                // Always use the multi path so token usage is preserved on
1042                // both tool and non-tool requests. The bare `generate()` helper
1043                // in remote_backend wraps this same call but drops the usage
1044                // tuple, which breaks observability for plain text inference
1045                // (sc-3 in outcome 043).
1046                self.remote_backend
1047                    .generate_with_tools_multi(
1048                        &schema_val,
1049                        &req.prompt,
1050                        context.as_deref(),
1051                        temperature,
1052                        req.params.max_tokens,
1053                        req.tools.as_deref(),
1054                        req.images.as_deref(),
1055                        req.messages.as_deref(),
1056                        req.params.tool_choice.as_deref(),
1057                        req.params.parallel_tool_calls,
1058                        req.params.budget_tokens,
1059                        req.cache_control,
1060                        req.response_format.as_ref(),
1061                    )
1062                    .await
1063                    // Non-streaming remote APIs don't expose a
1064                    // first-token timestamp. Set TTFT=None and let
1065                    // streaming-aware callers measure it themselves
1066                    // via generate_tracked_stream.
1067                    .map(|(t, c, u)| (t, c, u, None::<u64>))
1068            } else {
1069                #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1070                {
1071                    // On Apple Silicon, all local models must go through MLX.
1072                    // GGUF models were redirected to MLX equivalents above;
1073                    // if we still have a non-MLX model here, it has no MLX equivalent.
1074                    let schema_ref = schema
1075                        .as_ref()
1076                        .ok_or_else(|| InferenceError::ModelNotFound(candidate_id.clone()))?;
1077
1078                    // Apple FoundationModels — on-device system model.
1079                    // Tools, vision/audio/video aren't wired in v1; reject
1080                    // upstream so the router falls through to a richer
1081                    // model rather than silently dropping capabilities.
1082                    if schema_ref.is_foundation_models() {
1083                        if has_tools {
1084                            Err(InferenceError::UnsupportedMode {
1085                                mode: "tool-use",
1086                                backend: "foundation-models",
1087                                reason:
1088                                    "tool calling not yet wired through the FoundationModels \
1089                                     bridge — route to a remote model with ToolUse capability",
1090                            })
1091                        } else if Self::request_has_video(&req)
1092                            || Self::request_has_audio(&req)
1093                            || req
1094                                .images
1095                                .as_ref()
1096                                .is_some_and(|imgs| !imgs.is_empty())
1097                        {
1098                            Err(InferenceError::UnsupportedMode {
1099                                mode: "multimodal-content",
1100                                backend: "foundation-models",
1101                                reason:
1102                                    "the FoundationModels bridge currently exposes text-only \
1103                                     generation — route image/audio/video to a remote VL model",
1104                            })
1105                        } else {
1106                            let prompt = req.prompt.clone();
1107                            let instructions = context.clone();
1108                            let max_tokens = req.params.max_tokens as u32;
1109                            let temperature = req.params.temperature;
1110                            tokio::task::spawn_blocking(move || {
1111                                crate::backend::foundation_models::generate(
1112                                    &prompt,
1113                                    instructions.as_deref(),
1114                                    max_tokens,
1115                                    temperature as f32,
1116                                )
1117                            })
1118                            .await
1119                            .map_err(|e| {
1120                                InferenceError::InferenceFailed(format!(
1121                                    "FoundationModels task panicked: {e}"
1122                                ))
1123                            })
1124                            .and_then(|r| r)
1125                            .map(|text| (text, vec![], None, None))
1126                        }
1127                    } else if !schema_ref.is_mlx() {
1128                        Err(InferenceError::InferenceFailed(format!(
1129                            "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1130                            schema_ref.id
1131                        )))
1132                    } else if schema_ref.tags.iter().any(|t| t == "mlx-vlm-cli") {
1133                        // Schemas explicitly tagged for the
1134                        // mlx-vlm CLI shell-out path skip the
1135                        // wasteful native-MLX text-tower load
1136                        // entirely — `mlx_vlm.generate` loads its
1137                        // own weights from the HF cache and
1138                        // performs vision tokenization that the
1139                        // native backend does not. Falls through
1140                        // to the same error message as the
1141                        // post-load fallback when mlx-vlm is not
1142                        // installed, so the user-facing failure
1143                        // is consistent.
1144                        let has_images = req
1145                            .images
1146                            .as_ref()
1147                            .is_some_and(|imgs| !imgs.is_empty());
1148                        if !has_images {
1149                            return Err(InferenceError::UnsupportedMode {
1150                                mode: "text-only-on-mlx-vlm-id",
1151                                backend: "mlx-vlm-cli",
1152                                reason:
1153                                    "the `mlx-vlm/...` model IDs route exclusively \
1154                                     through the mlx-vlm CLI for image inference. \
1155                                     For text-only generation, route to a Qwen3 \
1156                                     text model (`mlx/qwen3-4b:4bit` etc.) — the \
1157                                     CLI shell-out has higher latency than the \
1158                                     in-process MLX text tower.",
1159                            });
1160                        }
1161                        if !crate::backend::mlx_vlm_cli::is_available() {
1162                            return Err(InferenceError::UnsupportedMode {
1163                                mode: "image-content-block",
1164                                backend: "mlx-vlm-cli",
1165                                reason:
1166                                    "mlx-vlm CLI not found on PATH. Install with \
1167                                     `uv tool install mlx-vlm` (or `pip install \
1168                                     mlx-vlm`) and CAR will route through \
1169                                     `mlx_vlm.generate` automatically. \
1170                                     Alternatives: a local vLLM-MLX VLM server, \
1171                                     or a remote VL model. (#115)",
1172                            });
1173                        }
1174                        let repo = match &schema_ref.source {
1175                            crate::schema::ModelSource::Mlx { hf_repo, .. } => {
1176                                hf_repo.clone()
1177                            }
1178                            _ => {
1179                                return Err(InferenceError::InferenceFailed(format!(
1180                                    "model '{}' is tagged mlx-vlm-cli but its \
1181                                     source isn't ModelSource::Mlx — registry bug",
1182                                    schema_ref.id
1183                                )));
1184                            }
1185                        };
1186                        let imgs = req.images.clone().unwrap_or_default();
1187                        let temp = req.params.temperature;
1188                        let max_t = req.params.max_tokens;
1189                        let prompt = req.prompt.clone();
1190                        let text = tokio::task::spawn_blocking(move || {
1191                            crate::backend::mlx_vlm_cli::generate(
1192                                &repo, &prompt, &imgs, temp, max_t,
1193                            )
1194                        })
1195                        .await
1196                        .map_err(|e| {
1197                            InferenceError::InferenceFailed(format!(
1198                                "mlx_vlm CLI task panicked: {e}"
1199                            ))
1200                        })??;
1201                        let bounding_boxes = parse_boxes(&text);
1202                        let latency_ms = start.elapsed().as_millis() as u64;
1203                        {
1204                            let mut tracker = self.outcome_tracker.write().await;
1205                            tracker.record_complete(&trace_id, latency_ms, 0, 0);
1206                        }
1207                        return Ok(InferenceResult {
1208                            text,
1209                            tool_calls: vec![],
1210                            bounding_boxes,
1211                            trace_id: trace_id.clone(),
1212                            model_used: schema_ref.id.clone(),
1213                            latency_ms,
1214                            time_to_first_token_ms: None,
1215                            usage: None,
1216                        });
1217                    } else {
1218                        // Load the backend first so we can ask it what
1219                        // it's actually able to execute. VL checkpoints
1220                        // currently load as text-only towers (see the
1221                        // `language_model.` prefix strip in backend/mlx.rs);
1222                        // the backend's `supports_capability(Vision)`
1223                        // returns false until GH #58 wires the vision
1224                        // tower. A registry-level capability claim is
1225                        // an aspiration for routing; the backend answer
1226                        // is the execution contract.
1227                        let handle = self.ensure_mlx_backend(schema_ref).await?;
1228                        // Native MLX path doesn't have a video
1229                        // tokenization pipeline yet. Reject video
1230                        // content blocks up front with a precise
1231                        // UnsupportedMode so callers don't silently
1232                        // get a text-only reply.
1233                        if Self::request_has_video(&req) {
1234                            return Err(InferenceError::UnsupportedMode {
1235                                mode: "video-content-block",
1236                                backend: "native-mlx-qwen25vl",
1237                                reason:
1238                                    "Qwen2.5-VL video understanding is on the request surface \
1239                                     but the video-tokenization path (frame sampling + merger) \
1240                                     is not yet wired; route to a remote VL provider for now",
1241                            });
1242                        }
1243                        if Self::request_has_audio(&req) {
1244                            return Err(InferenceError::UnsupportedMode {
1245                                mode: "audio-content-block",
1246                                backend: "native-mlx-qwen25vl",
1247                                reason:
1248                                    "audio understanding is on the request surface (Gemma 4 \
1249                                     E2B/E4B and Gemini accept it) but the native MLX path \
1250                                     for this model does not — route to Gemini or Gemma-4",
1251                            });
1252                        }
1253                        let has_images = req
1254                            .images
1255                            .as_ref()
1256                            .is_some_and(|imgs| !imgs.is_empty());
1257                        if has_images {
1258                            let can_do_vision = {
1259                                let guard = handle.lock().map_err(|_| {
1260                                    InferenceError::InferenceFailed(
1261                                        "MLX backend mutex poisoned".into(),
1262                                    )
1263                                })?;
1264                                guard.supports_capability(
1265                                    crate::schema::ModelCapability::Vision,
1266                                )
1267                            };
1268                            if !can_do_vision {
1269                                // The `mlx-vlm-cli`-tagged route handled
1270                                // above is the primary fix for #115; if
1271                                // we landed here it means the schema is
1272                                // a non-tagged `ModelSource::Mlx` (e.g.
1273                                // a user-registered custom model that
1274                                // doesn't advertise the CLI route). The
1275                                // error message points them at the
1276                                // tagged catalog IDs rather than
1277                                // claiming nothing local works.
1278                                return Err(InferenceError::UnsupportedMode {
1279                                    mode: "image-content-block",
1280                                    backend: "native-mlx-text",
1281                                    reason:
1282                                        "this MLX backend is a plain Qwen3 text tower. \
1283                                         For local image inference, route to \
1284                                         `mlx-vlm/qwen3-vl-2b:bf16` or another `mlx-vlm/...` \
1285                                         catalog ID so CAR shells out to `mlx_vlm.generate`. \
1286                                         Alternatives: a local vLLM-MLX VLM server, or a \
1287                                         remote VL model. (#115)",
1288                                });
1289                            }
1290                        }
1291                        self.generate_mlx(req.clone(), &schema_ref.id)
1292                            .await
1293                            .map(|(text, ttft)| (text, vec![], None, ttft))
1294                    }
1295                }
1296
1297                #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1298                {
1299                    match self.ensure_backend(&candidate_name).await {
1300                        Ok(()) => {
1301                            let mut write = self.backend.write().await;
1302                            let backend = write.as_mut().unwrap();
1303                            tasks::generate::generate(backend, req.clone())
1304                                .await
1305                                .map(|(text, ttft)| (text, vec![], None, ttft))
1306                        }
1307                        Err(e) => Err(e),
1308                    }
1309                }
1310            };
1311
1312            match result {
1313                Ok((text, tool_calls, usage, time_to_first_token_ms)) => {
1314                    let latency_ms = start.elapsed().as_millis() as u64;
1315                    let estimated_tokens = usage
1316                        .as_ref()
1317                        .map(|u| u.completion_tokens as usize)
1318                        .unwrap_or_else(|| text.split_whitespace().count());
1319                    {
1320                        let mut tracker = self.outcome_tracker.write().await;
1321                        tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
1322                    }
1323                    // Circuit breaker: record success (#25)
1324                    if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1325                        cb.record_success(candidate_id);
1326                    }
1327                    // Auto-persist profiles after each successful call
1328                    self.auto_save_outcomes().await;
1329
1330                    // Record deferred span fields now that we have the result
1331                    let span = tracing::Span::current();
1332                    span.record("model", candidate_name.as_str());
1333                    span.record("latency_ms", latency_ms);
1334                    if let Some(ttft) = time_to_first_token_ms {
1335                        span.record("ttft_ms", ttft);
1336                    }
1337                    if let Some(ref u) = usage {
1338                        span.record("prompt_tokens", u.prompt_tokens);
1339                        span.record("completion_tokens", u.completion_tokens);
1340                    }
1341
1342                    // Parse Qwen2.5-VL grounding spans out of the
1343                    // text output. Empty vec on anything else.
1344                    let bounding_boxes = tasks::grounding::parse_boxes(&text);
1345                    return Ok(InferenceResult {
1346                        text,
1347                        tool_calls,
1348                        bounding_boxes,
1349                        trace_id,
1350                        model_used: candidate_name,
1351                        latency_ms,
1352                        time_to_first_token_ms,
1353                        usage,
1354                    });
1355                }
1356                Err(e) => {
1357                    tracing::warn!(
1358                        model = %candidate_name,
1359                        error = %e,
1360                        remaining = models_to_try.len().saturating_sub(
1361                            models_to_try.iter().position(|m| m == candidate_id).unwrap_or(0) + 1
1362                        ),
1363                        "model failed, trying next fallback immediately"
1364                    );
1365                    // Record failure for this model so the router learns
1366                    {
1367                        let mut tracker = self.outcome_tracker.write().await;
1368                        let fail_trace =
1369                            tracker.record_start(candidate_id, decision.task, "fallback");
1370                        tracker.record_failure(&fail_trace, &e.to_string());
1371                    }
1372                    // Circuit breaker: record failure (#25).
1373                    // 4xx errors (client errors) use longer cooldown since they indicate
1374                    // permanent incompatibility (wrong endpoint, unsupported param).
1375                    {
1376                        let err_str = e.to_string();
1377                        let is_client_error =
1378                            err_str.contains("API returned 4") && !err_str.contains("429");
1379                        if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1380                            cb.record_failure(candidate_id);
1381                            // For persistent 4xx errors, lower the threshold by
1382                            // recording an extra failure to trip faster
1383                            if is_client_error {
1384                                cb.record_failure(candidate_id);
1385                            }
1386                        }
1387                    }
1388                    // Reset backend so next model can load
1389                    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1390                    {
1391                        let mut write = self.backend.write().await;
1392                        *write = None;
1393                    }
1394                    last_error = Some(e);
1395                }
1396            }
1397        }
1398
1399        // All models failed
1400        let e = last_error.unwrap_or(InferenceError::InferenceFailed(
1401            "no models available".into(),
1402        ));
1403        {
1404            let mut tracker = self.outcome_tracker.write().await;
1405            tracker.record_failure(&trace_id, &e.to_string());
1406        }
1407        self.auto_save_outcomes().await;
1408        Err(e)
1409    }
1410
1411    /// Stream a response with real-time token output.
1412    ///
1413    /// Returns a channel receiver yielding `StreamEvent` variants:
1414    /// - `TextDelta(String)` — partial text token
1415    /// - `ToolCallStart { name, index, id }` — tool call begins
1416    /// - `ToolCallDelta { index, arguments_delta }` — partial tool arguments
1417    ///
1418    /// Use `StreamAccumulator` to collect events into a final result.
1419    ///
1420    /// ## Local streaming behavior
1421    ///
1422    /// Local models (both MLX and Candle backends) emit true incremental `TextDelta`
1423    /// events as each token is generated. This enables:
1424    /// - Token-by-token UI updates
1425    /// - Overlapping speech synthesis with generation (for voice apps)
1426    /// - Early cancellation when a stop sequence is detected
1427    ///
1428    /// The channel capacity is 64 events, providing buffering for burst tokens
1429    /// without blocking the generation loop.
1430    ///
1431    /// ## Example: voice app integration
1432    ///
1433    /// ```rust,ignore
1434    /// let mut rx = engine.generate_tracked_stream(req).await?;
1435    /// let mut text_buf = String::new();
1436    /// while let Some(event) = rx.recv().await {
1437    ///     match event {
1438    ///         StreamEvent::TextDelta(delta) => {
1439    ///             text_buf.push_str(&delta);
1440    ///             // Feed text_buf to TTS when a sentence boundary is reached
1441    ///         }
1442    ///         StreamEvent::Done { text, .. } => break,
1443    ///         _ => {}
1444    ///     }
1445    /// }
1446    /// ```
1447    pub async fn generate_tracked_stream(
1448        &self,
1449        req: GenerateRequest,
1450    ) -> Result<tokio::sync::mpsc::Receiver<stream::StreamEvent>, InferenceError> {
1451        let has_tools = req.tools.is_some();
1452        let has_vision = Self::request_needs_vision(&req);
1453        let preferred_model = self
1454            .preferred_model_for_capability(ModelCapability::Generate)
1455            .map(str::to_string);
1456        let decision = match req.model.clone().or(preferred_model) {
1457            Some(m) => {
1458                let ctx_len = self
1459                    .unified_registry
1460                    .get(&m)
1461                    .or_else(|| self.unified_registry.find_by_name(&m))
1462                    .map(|s| s.context_length)
1463                    .unwrap_or(0);
1464                AdaptiveRoutingDecision {
1465                    model_id: m.clone(),
1466                    model_name: m,
1467                    task: InferenceTask::Generate,
1468                    complexity: TaskComplexity::assess(&req.prompt),
1469                    reason: "explicit model".into(),
1470                    strategy: RoutingStrategy::Explicit,
1471                    predicted_quality: 0.5,
1472                    fallbacks: vec![],
1473                    context_length: ctx_len,
1474                    needs_compaction: false,
1475                }
1476            }
1477            None => {
1478                let tracker_read = self.outcome_tracker.read().await;
1479                if has_vision {
1480                    self.adaptive_router.route_with_vision(
1481                        &req.prompt,
1482                        &self.unified_registry,
1483                        &tracker_read,
1484                        has_tools,
1485                    )
1486                } else if has_tools {
1487                    self.adaptive_router.route_with_tools(
1488                        &req.prompt,
1489                        &self.unified_registry,
1490                        &tracker_read,
1491                    )
1492                } else {
1493                    self.adaptive_router
1494                        .route(&req.prompt, &self.unified_registry, &tracker_read)
1495                }
1496            }
1497        };
1498
1499        // `mut` is needed on the aarch64-macos cfg branch below;
1500        // other targets don't rebind.
1501        #[allow(unused_mut)]
1502        let mut schema = self
1503            .unified_registry
1504            .get(&decision.model_id)
1505            .or_else(|| self.unified_registry.find_by_name(&decision.model_id))
1506            .cloned();
1507
1508        // On Apple Silicon, redirect GGUF/Candle models to their MLX equivalents.
1509        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1510        if let Some(ref s) = schema {
1511            if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
1512                tracing::info!(
1513                    from = %s.id, to = %mlx_equiv.id,
1514                    "redirecting GGUF model to MLX equivalent on Apple Silicon (stream)"
1515                );
1516                schema = Some(mlx_equiv.clone());
1517            }
1518        }
1519
1520        let is_remote = schema
1521            .as_ref()
1522            .map(|s| s.is_remote() || s.is_vllm_mlx())
1523            .unwrap_or(false);
1524
1525        if is_remote {
1526            let schema = schema.unwrap();
1527            // Register keys for this model
1528            self.remote_backend.register_model_keys(&schema).await;
1529
1530            self.remote_backend
1531                .generate_stream(
1532                    &schema,
1533                    &req.prompt,
1534                    req.context.as_deref(),
1535                    req.params.temperature,
1536                    req.params.max_tokens,
1537                    req.tools.as_deref(),
1538                    req.images.as_deref(),
1539                    req.params.tool_choice.as_deref(),
1540                    req.params.parallel_tool_calls,
1541                    req.response_format.as_ref(),
1542                )
1543                .await
1544        } else {
1545            let schema = schema.ok_or_else(|| {
1546                InferenceError::ModelNotFound(decision.model_id.clone())
1547            })?;
1548            let (tx, rx) = tokio::sync::mpsc::channel(64);
1549
1550            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1551            {
1552                // FoundationModels streaming — text-only, no tools.
1553                if schema.is_foundation_models() {
1554                    let prompt = req.prompt.clone();
1555                    let instructions = req.context.clone();
1556                    let max_tokens = req.params.max_tokens as u32;
1557                    let temperature = req.params.temperature;
1558                    let tx_clone = tx.clone();
1559                    tokio::task::spawn_blocking(move || {
1560                        // Share the accumulator between the streaming
1561                        // callback and the post-stream Done event so
1562                        // the FoundationModels path matches Candle/MLX
1563                        // shape — `Done.text` is the full assembled
1564                        // generation, not an empty sentinel that
1565                        // forces consumers to reassemble.
1566                        let accum = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
1567                        let accum_cb = accum.clone();
1568                        let cb = crate::backend::foundation_models::StreamCallback::new(
1569                            move |delta: &str| {
1570                                if let Ok(mut g) = accum_cb.lock() {
1571                                    g.push_str(delta);
1572                                }
1573                                tx_clone
1574                                    .blocking_send(stream::StreamEvent::TextDelta(
1575                                        delta.to_string(),
1576                                    ))
1577                                    .is_ok()
1578                            },
1579                        );
1580                        let result = crate::backend::foundation_models::stream(
1581                            &prompt,
1582                            instructions.as_deref(),
1583                            max_tokens,
1584                            temperature as f32,
1585                            cb,
1586                        );
1587                        let final_text = accum
1588                            .lock()
1589                            .map(|g| g.clone())
1590                            .unwrap_or_default();
1591                        let _ = tx.blocking_send(stream::StreamEvent::Done {
1592                            text: final_text,
1593                            tool_calls: vec![],
1594                        });
1595                        result
1596                    });
1597                    return Ok(rx);
1598                }
1599                // On Apple Silicon, all local models must go through MLX.
1600                if !schema.is_mlx() {
1601                    return Err(InferenceError::InferenceFailed(format!(
1602                        "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1603                        schema.id
1604                    )));
1605                }
1606                let backend = self.ensure_mlx_backend(&schema).await?;
1607                let model_id = schema.id.clone();
1608                let cache = Arc::clone(&self.mlx_backends);
1609                // MLX ops are blocking (GPU-bound) and `MutexGuard<MlxBackend>`
1610                // isn't `Send`, so run the whole generation on a blocking
1611                // worker. `tx.blocking_send` bridges tokens back to the
1612                // async stream consumer without holding the guard across
1613                // an `.await`.
1614                tokio::task::spawn_blocking(move || {
1615                    let _ = Self::stream_local_mlx(backend, cache, model_id, req, tx);
1616                });
1617                return Ok(rx);
1618            }
1619
1620            #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1621            {
1622                self.ensure_backend(&schema.name).await?;
1623                let backend = self.backend.clone();
1624                tokio::spawn(async move {
1625                    let _ = Self::stream_local_candle(backend, req, tx).await;
1626                });
1627                Ok(rx)
1628            }
1629        }
1630    }
1631
1632    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1633    async fn stream_local_candle(
1634        backend_lock: Arc<RwLock<Option<CandleBackend>>>,
1635        req: GenerateRequest,
1636        tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
1637    ) -> Result<(), InferenceError> {
1638        let mut write = backend_lock.write().await;
1639        let backend = write
1640            .as_mut()
1641            .ok_or_else(|| InferenceError::InferenceFailed("backend not initialized".into()))?;
1642        backend.clear_kv_cache();
1643
1644        let formatted = tasks::generate::apply_chat_template(
1645            &req.prompt,
1646            req.context.as_deref(),
1647            req.params.thinking,
1648        );
1649        let tokens = backend.encode(&formatted)?;
1650        let eos = backend.eos_token_id();
1651        let eos_alt = backend.token_id("<|im_end|>");
1652        let params = &req.params;
1653
1654        if tokens.is_empty() {
1655            let _ = tx
1656                .send(stream::StreamEvent::Done {
1657                    text: String::new(),
1658                    tool_calls: vec![],
1659                })
1660                .await;
1661            return Ok(());
1662        }
1663
1664        let max_ctx = backend.context_length().unwrap_or(32768);
1665        let headroom = params.max_tokens.min(max_ctx / 4);
1666        let max_prompt = max_ctx.saturating_sub(headroom);
1667        let tokens = if tokens.len() > max_prompt {
1668            tokens[tokens.len() - max_prompt..].to_vec()
1669        } else {
1670            tokens
1671        };
1672
1673        let mut generated = Vec::new();
1674        let logits = backend.forward(&tokens, 0)?;
1675        let mut next_token = tasks::generate::sample_token(&logits, params)?;
1676
1677        for _ in 0..params.max_tokens {
1678            if eos.map_or(false, |id| next_token == id)
1679                || eos_alt.map_or(false, |id| next_token == id)
1680            {
1681                break;
1682            }
1683
1684            generated.push(next_token);
1685            let delta = backend.decode(&[next_token])?;
1686            if !delta.is_empty()
1687                && tx.send(stream::StreamEvent::TextDelta(delta)).await.is_err()
1688            {
1689                return Ok(());
1690            }
1691
1692            if !params.stop.is_empty() {
1693                let text_so_far = backend.decode(&generated)?;
1694                if params.stop.iter().any(|s| text_so_far.contains(s)) {
1695                    break;
1696                }
1697            }
1698
1699            let pos = tokens.len() + generated.len() - 1;
1700            let logits = backend.forward(&[next_token], pos)?;
1701            next_token = tasks::generate::sample_token(&logits, params)?;
1702        }
1703
1704        let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
1705        let _ = tx
1706            .send(stream::StreamEvent::Done {
1707                text,
1708                tool_calls: vec![],
1709            })
1710            .await;
1711        Ok(())
1712    }
1713
1714    /// Blocking streaming generator — runs on a `spawn_blocking` worker
1715    /// so the sync MLX mutex guard isn't held across an async `.await`.
1716    /// Uses `tx.blocking_send` to push tokens back to the caller.
1717    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1718    fn stream_local_mlx(
1719        handle: backend_cache::CachedBackend<backend::MlxBackend>,
1720        cache: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
1721        model_id: String,
1722        req: GenerateRequest,
1723        tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
1724    ) -> Result<(), InferenceError> {
1725        let mut guard = handle.lock().map_err(|_| {
1726            InferenceError::InferenceFailed(format!(
1727                "MLX backend mutex poisoned for {model_id}"
1728            ))
1729        })?;
1730        let backend: &mut backend::MlxBackend = &mut *guard;
1731        backend.clear_kv_cache();
1732
1733        let formatted = tasks::generate::apply_chat_template(
1734            &req.prompt,
1735            req.context.as_deref(),
1736            req.params.thinking,
1737        );
1738        let tokens = backend.encode(&formatted)?;
1739        let eos = backend.eos_token_id();
1740        let eos_alt = backend.token_id("<|im_end|>");
1741        let params = &req.params;
1742
1743        if tokens.is_empty() {
1744            let _ = tx.blocking_send(stream::StreamEvent::Done {
1745                text: String::new(),
1746                tool_calls: vec![],
1747            });
1748            return Ok(());
1749        }
1750
1751        let max_ctx = backend.context_length();
1752        let headroom = params.max_tokens.min(max_ctx / 4);
1753        let max_prompt = max_ctx.saturating_sub(headroom);
1754        let tokens = if tokens.len() > max_prompt {
1755            tokens[tokens.len() - max_prompt..].to_vec()
1756        } else {
1757            tokens
1758        };
1759
1760        let mut generated = Vec::new();
1761        // Wrap MLX forward calls to catch panics. On panic, drop this
1762        // backend from the cache — its KV cache may be in an
1763        // indeterminate state and subsequent callers would inherit it.
1764        // Outstanding handles continue to work until their Arc drops.
1765        let logits = match Self::catch_mlx("stream prefill", || backend.forward(&tokens, 0)) {
1766            Ok(v) => v,
1767            Err(e) => { cache.invalidate(&model_id); return Err(e); }
1768        };
1769        let mut next_token = Self::sample_from_logits(&logits, params)?;
1770
1771        for _ in 0..params.max_tokens {
1772            if eos.map_or(false, |id| next_token == id)
1773                || eos_alt.map_or(false, |id| next_token == id)
1774            {
1775                break;
1776            }
1777
1778            generated.push(next_token);
1779            let delta = backend.decode(&[next_token])?;
1780            if !delta.is_empty()
1781                && tx.blocking_send(stream::StreamEvent::TextDelta(delta)).is_err()
1782            {
1783                return Ok(());
1784            }
1785
1786            if !params.stop.is_empty() {
1787                let text_so_far = backend.decode(&generated)?;
1788                if params.stop.iter().any(|s| text_so_far.contains(s)) {
1789                    break;
1790                }
1791            }
1792
1793            let pos = tokens.len() + generated.len() - 1;
1794            let logits = match Self::catch_mlx("stream forward", || backend.forward(&[next_token], pos)) {
1795                Ok(v) => v,
1796                Err(e) => { cache.invalidate(&model_id); return Err(e); }
1797            };
1798            next_token = Self::sample_from_logits(&logits, params)?;
1799        }
1800
1801        let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
1802        let _ = tx.blocking_send(stream::StreamEvent::Done {
1803            text,
1804            tool_calls: vec![],
1805        });
1806        Ok(())
1807    }
1808
1809    /// Route a prompt using the adaptive router without executing inference.
1810    pub async fn route_context_snapshot(
1811        &self,
1812        prompt: &str,
1813        workload: RoutingWorkload,
1814        has_tools: bool,
1815        has_vision: bool,
1816    ) -> AdaptiveRoutingDecision {
1817        let tracker = self.outcome_tracker.read().await;
1818        self.adaptive_router.route_context_aware(
1819            prompt,
1820            0,
1821            &self.unified_registry,
1822            &tracker,
1823            has_tools,
1824            has_vision,
1825            workload,
1826        )
1827    }
1828
1829    /// Generate text from a prompt (legacy API, no outcome tracking).
1830    /// When `req.model` is None, uses intelligent routing based on prompt complexity.
1831    pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
1832        Ok(self.generate_tracked(req).await?.text)
1833    }
1834
1835    /// Encode `text` via the named model's tokenizer. Returns raw token IDs
1836    /// without any chat-template wrapping or BOS-prepending — pair with
1837    /// [`Self::detokenize`] for the round-trip property
1838    /// `detokenize(model, tokenize(model, s)) == s` for any UTF-8 `s`.
1839    ///
1840    /// Only local models have a tokenizer the runtime can call directly
1841    /// (Candle/GGUF on Linux/Windows, MLX on Apple Silicon). For remote
1842    /// models the call returns
1843    /// [`InferenceError::UnsupportedMode`] — provider tokenizer endpoints
1844    /// vary too widely to be portable here, and bundling tiktoken-style
1845    /// tables would lock the registry to a fixed set of providers.
1846    pub async fn tokenize(
1847        &self,
1848        model: &str,
1849        text: &str,
1850    ) -> Result<Vec<u32>, InferenceError> {
1851        self.assert_local_for_tokenize(model)?;
1852
1853        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1854        {
1855            let schema = self
1856                .unified_registry
1857                .get(model)
1858                .or_else(|| self.unified_registry.find_by_name(model))
1859                .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1860                .clone();
1861            let handle = self.ensure_mlx_backend(&schema).await?;
1862            let guard = handle.lock().map_err(|_| {
1863                InferenceError::InferenceFailed(format!(
1864                    "MLX backend mutex poisoned for {}", schema.id
1865                ))
1866            })?;
1867            return guard.tokenize_raw(text);
1868        }
1869
1870        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1871        {
1872            self.ensure_backend(model).await?;
1873            let read = self.backend.read().await;
1874            let backend = read.as_ref().ok_or_else(|| {
1875                InferenceError::InferenceFailed(
1876                    "candle backend missing after ensure_backend".to_string(),
1877                )
1878            })?;
1879            backend.tokenize_raw(text)
1880        }
1881    }
1882
1883    /// Inverse of [`Self::tokenize`]: decode token IDs back to text.
1884    pub async fn detokenize(
1885        &self,
1886        model: &str,
1887        tokens: &[u32],
1888    ) -> Result<String, InferenceError> {
1889        self.assert_local_for_tokenize(model)?;
1890
1891        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1892        {
1893            let schema = self
1894                .unified_registry
1895                .get(model)
1896                .or_else(|| self.unified_registry.find_by_name(model))
1897                .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1898                .clone();
1899            let handle = self.ensure_mlx_backend(&schema).await?;
1900            let guard = handle.lock().map_err(|_| {
1901                InferenceError::InferenceFailed(format!(
1902                    "MLX backend mutex poisoned for {}", schema.id
1903                ))
1904            })?;
1905            return guard.detokenize_raw(tokens);
1906        }
1907
1908        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1909        {
1910            self.ensure_backend(model).await?;
1911            let read = self.backend.read().await;
1912            let backend = read.as_ref().ok_or_else(|| {
1913                InferenceError::InferenceFailed(
1914                    "candle backend missing after ensure_backend".to_string(),
1915                )
1916            })?;
1917            backend.detokenize_raw(tokens)
1918        }
1919    }
1920
1921    /// Common pre-flight for [`Self::tokenize`] / [`Self::detokenize`]: bail
1922    /// early on remote models with the same `UnsupportedMode` taxonomy used
1923    /// elsewhere on the engine surface.
1924    fn assert_local_for_tokenize(&self, model: &str) -> Result<(), InferenceError> {
1925        if let Some(schema) = self
1926            .unified_registry
1927            .get(model)
1928            .or_else(|| self.unified_registry.find_by_name(model))
1929        {
1930            if !schema.is_local() {
1931                return Err(InferenceError::UnsupportedMode {
1932                    mode: "tokenize/detokenize",
1933                    backend: "remote",
1934                    reason:
1935                        "remote provider tokenizer is not exposed by the runtime; \
1936                         use a local model (Qwen3 GGUF / MLX) for tokenizer-correctness checks",
1937                });
1938            }
1939        }
1940        // Unknown model name: let the load step surface ModelNotFound below.
1941        Ok(())
1942    }
1943
1944    /// Wrap an MLX FFI call with catch_unwind to catch Rust panics at the boundary.
1945    /// NOTE: This catches Rust panics only, NOT C++ exceptions from Metal/MLX.
1946    /// True C++ exceptions will still abort the process — that requires an upstream
1947    /// fix in mlx-rs to catch C++ exceptions before they cross the FFI boundary.
1948    /// On panic, callers MUST remove the backend from the map — it is poisoned.
1949    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1950    fn catch_mlx<F, T>(context: &str, f: F) -> Result<T, InferenceError>
1951    where
1952        F: FnOnce() -> Result<T, InferenceError>,
1953    {
1954        std::panic::catch_unwind(std::panic::AssertUnwindSafe(f))
1955            .map_err(|e| InferenceError::InferenceFailed(
1956                format!("MLX panicked during {context}: {e:?}")
1957            ))?
1958    }
1959
1960    /// Generate text using the MLX backend.
1961    /// Mirrors the Candle generate loop but uses MlxBackend::forward which returns Vec<f32>.
1962    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1963    async fn generate_mlx(
1964        &self,
1965        req: GenerateRequest,
1966        model_id: &str,
1967    ) -> Result<(String, Option<u64>), InferenceError> {
1968        let start = std::time::Instant::now();
1969
1970        let schema = self
1971            .unified_registry
1972            .get(model_id)
1973            .cloned()
1974            .ok_or_else(|| InferenceError::InferenceFailed(format!(
1975                "generate_mlx: unknown schema id {model_id}"
1976            )))?;
1977        let handle = self.ensure_mlx_backend(&schema).await?;
1978        let mut guard = handle.lock().map_err(|_| {
1979            InferenceError::InferenceFailed(format!(
1980                "MLX backend mutex poisoned for {model_id}"
1981            ))
1982        })?;
1983        let backend: &mut backend::MlxBackend = &mut *guard;
1984        backend.clear_kv_cache();
1985
1986        let formatted = tasks::generate::apply_chat_template(
1987            &req.prompt,
1988            req.context.as_deref(),
1989            req.params.thinking,
1990        );
1991        let tokens = backend.encode(&formatted)?;
1992        let eos = backend.eos_token_id();
1993        let eos_alt = backend.token_id("<|im_end|>");
1994        let params = &req.params;
1995
1996        if tokens.is_empty() {
1997            return Ok((String::new(), None));
1998        }
1999
2000        // Truncate to context length
2001        let max_ctx = backend.context_length();
2002        let headroom = params.max_tokens.min(max_ctx / 4);
2003        let max_prompt = max_ctx.saturating_sub(headroom);
2004        let tokens = if tokens.len() > max_prompt {
2005            tokens[tokens.len() - max_prompt..].to_vec()
2006        } else {
2007            tokens
2008        };
2009
2010        let mut generated = Vec::new();
2011
2012        // Wrap MLX forward calls to catch panics at the FFI boundary.
2013        // On panic, drop the backend from the cache — it's in an
2014        // indeterminate state and we don't want subsequent callers to
2015        // inherit it. Our outstanding `guard` keeps working until we
2016        // return.
2017        let logits = match Self::catch_mlx("prefill", || backend.forward(&tokens, 0)) {
2018            Ok(v) => v,
2019            Err(e) => {
2020                drop(guard);
2021                self.mlx_backends.invalidate(model_id);
2022                return Err(e);
2023            }
2024        };
2025        let mut next_token = Self::sample_from_logits(&logits, params)?;
2026        let ttft_ms = Some(start.elapsed().as_millis() as u64);
2027
2028        for _ in 0..params.max_tokens {
2029            if eos.map_or(false, |id| next_token == id)
2030                || eos_alt.map_or(false, |id| next_token == id)
2031            {
2032                break;
2033            }
2034
2035            generated.push(next_token);
2036
2037            if !params.stop.is_empty() {
2038                let text_so_far = backend.decode(&generated)?;
2039                if params.stop.iter().any(|s| text_so_far.contains(s)) {
2040                    break;
2041                }
2042            }
2043
2044            let pos = tokens.len() + generated.len() - 1;
2045            let logits = match Self::catch_mlx("forward", || backend.forward(&[next_token], pos)) {
2046                Ok(v) => v,
2047                Err(e) => {
2048                    drop(guard);
2049                    self.mlx_backends.invalidate(model_id);
2050                    return Err(e);
2051                }
2052            };
2053            next_token = Self::sample_from_logits(&logits, params)?;
2054        }
2055
2056        let text = backend.decode(&generated)?;
2057        Ok((tasks::generate::strip_thinking(&text, params.thinking), ttft_ms))
2058    }
2059
2060    /// Sample a token from a logits Vec<f32> (for MLX backend).
2061    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2062    fn sample_from_logits(logits: &[f32], params: &GenerateParams) -> Result<u32, InferenceError> {
2063        if params.temperature <= 0.0 {
2064            // Greedy: argmax
2065            let (idx, _) = logits
2066                .iter()
2067                .enumerate()
2068                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
2069                .ok_or_else(|| InferenceError::InferenceFailed("empty logits".into()))?;
2070            return Ok(idx as u32);
2071        }
2072
2073        // Temperature-scaled softmax sampling
2074        let temp = params.temperature as f32;
2075        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
2076        let mut probs: Vec<f32> = logits
2077            .iter()
2078            .map(|&l| ((l - max_logit) / temp).exp())
2079            .collect();
2080        let sum: f32 = probs.iter().sum();
2081        for p in &mut probs {
2082            *p /= sum;
2083        }
2084
2085        // Top-p (nucleus) sampling
2086        if params.top_p < 1.0 {
2087            let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
2088            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
2089            let mut cumsum = 0.0;
2090            let mut cutoff_idx = indexed.len();
2091            for (i, &(_, p)) in indexed.iter().enumerate() {
2092                cumsum += p;
2093                if cumsum > params.top_p as f32 {
2094                    cutoff_idx = i + 1;
2095                    break;
2096                }
2097            }
2098            // Zero out tokens below the cutoff
2099            let allowed: std::collections::HashSet<usize> =
2100                indexed[..cutoff_idx].iter().map(|(i, _)| *i).collect();
2101            for (i, p) in probs.iter_mut().enumerate() {
2102                if !allowed.contains(&i) {
2103                    *p = 0.0;
2104                }
2105            }
2106            let sum: f32 = probs.iter().sum();
2107            if sum > 0.0 {
2108                for p in &mut probs {
2109                    *p /= sum;
2110                }
2111            }
2112        }
2113
2114        // Sample from distribution
2115        use rand::Rng;
2116        let mut rng = rand::rng();
2117        let r: f32 = rng.random();
2118        let mut cumsum = 0.0;
2119        for (i, &p) in probs.iter().enumerate() {
2120            cumsum += p;
2121            if cumsum >= r {
2122                return Ok(i as u32);
2123            }
2124        }
2125        Ok((probs.len() - 1) as u32)
2126    }
2127
2128    /// Generate embeddings for text using the dedicated embedding model.
2129    /// On Apple Silicon, uses the native MLX backend; on other platforms, uses Candle.
2130    pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
2131        let instruction = req
2132            .instruction
2133            .as_deref()
2134            .unwrap_or("Retrieve relevant memory facts");
2135
2136        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2137        {
2138            let model_id = self.ensure_mlx_embedding_backend().await?;
2139            let schema = self
2140                .unified_registry
2141                .get(&model_id)
2142                .cloned()
2143                .ok_or_else(|| InferenceError::InferenceFailed(format!(
2144                    "embed: unknown schema id {model_id}"
2145                )))?;
2146            let handle = self.ensure_mlx_backend(&schema).await?;
2147            let mut guard = handle.lock().map_err(|_| {
2148                InferenceError::InferenceFailed(format!(
2149                    "MLX embedding backend mutex poisoned for {model_id}"
2150                ))
2151            })?;
2152            let backend: &mut backend::MlxBackend = &mut *guard;
2153
2154            let mut results = Vec::with_capacity(req.texts.len());
2155            for text in &req.texts {
2156                let embedding = if req.is_query {
2157                    backend.embed_query(text, instruction)?
2158                } else {
2159                    backend.embed_one(text)?
2160                };
2161                results.push(embedding);
2162            }
2163            return Ok(results);
2164        }
2165
2166        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
2167        {
2168            self.ensure_embedding_backend().await?;
2169            let mut write = self.embedding_backend.write().await;
2170            let backend = write.as_mut().unwrap();
2171
2172            let mut results = Vec::with_capacity(req.texts.len());
2173            for text in &req.texts {
2174                let embedding = if req.is_query {
2175                    backend.embed_query(text, instruction)?
2176                } else {
2177                    backend.embed_one(text)?
2178                };
2179                results.push(embedding);
2180            }
2181            Ok(results)
2182        }
2183    }
2184
2185    /// Rerank candidate documents against a query using a cross-encoder
2186    /// reranker model (Qwen3-Reranker family). Returns documents sorted
2187    /// by descending relevance.
2188    ///
2189    /// ## Scoring
2190    ///
2191    /// Qwen3-Reranker is a Qwen3 base LM fine-tuned so that the first
2192    /// assistant token is `"yes"` or `"no"` given the templated
2193    /// `<Instruct>/<Query>/<Document>` user turn. We run a short
2194    /// greedy decode (≤ 3 tokens, so a leading space, BOS artifact, or
2195    /// the occasional newline don't break us) and score
2196    /// `yes → 1.0`, `no → 0.0`, anything else → `0.5` with a warning.
2197    ///
2198    /// This is a **binary** score — the soft probability
2199    /// `softmax(logit_yes, logit_no)` would give finer ordering but
2200    /// requires per-token logit access on [`backend::MlxBackend`],
2201    /// which isn't exposed publicly yet. Tracked as a follow-up;
2202    /// binary scores still produce a correct partial ordering, just
2203    /// with coarser tiebreaks within the {yes} or {no} groups.
2204    ///
2205    /// ## Prompt template
2206    ///
2207    /// We emit the upstream Qwen3-Reranker chat template verbatim:
2208    /// a dedicated system prompt fixing the yes/no answer space,
2209    /// then the user turn with `<Instruct>/<Query>/<Document>`, then
2210    /// the assistant prefix with a closed empty `<think>` block to
2211    /// suppress thinking (reranker is not a reasoner — it's a
2212    /// classifier). Deviating from this template produces sharply
2213    /// degraded yes/no distributions.
2214    pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResult, InferenceError> {
2215        if req.documents.is_empty() {
2216            return Ok(RerankResult { ranked: Vec::new(), model_used: None });
2217        }
2218
2219        let model_name = match req.model.clone() {
2220            Some(m) => m,
2221            None => self
2222                .preferred_model_for_capability(ModelCapability::Rerank)
2223                .map(str::to_string)
2224                .ok_or_else(|| {
2225                    InferenceError::InferenceFailed(
2226                        "no reranker model available — pull a Qwen3-Reranker model first".into(),
2227                    )
2228                })?,
2229        };
2230
2231        let schema = self
2232            .unified_registry
2233            .find_by_name(&model_name)
2234            .or_else(|| self.unified_registry.get(&model_name))
2235            .cloned()
2236            .ok_or_else(|| {
2237                InferenceError::InferenceFailed(format!(
2238                    "rerank: unknown reranker model {model_name}"
2239                ))
2240            })?;
2241        if !schema.has_capability(ModelCapability::Rerank) {
2242            return Err(InferenceError::InferenceFailed(format!(
2243                "model {} does not declare the Rerank capability",
2244                schema.name
2245            )));
2246        }
2247
2248        let instruction = req.instruction.as_deref().unwrap_or(
2249            "Given a web search query, retrieve relevant passages that answer the query",
2250        );
2251
2252        let mut scored: Vec<RerankedDocument> = Vec::with_capacity(req.documents.len());
2253        for (idx, doc) in req.documents.iter().enumerate() {
2254            let prompt = rerank_prompt(instruction, &req.query, doc);
2255            let gen_req = GenerateRequest {
2256                prompt,
2257                model: Some(schema.id.clone()),
2258                params: tasks::generate::GenerateParams {
2259                    temperature: 0.0,
2260                    // Three tokens is enough to scan past a leading
2261                    // space, BOS, or newline that some tokenizers
2262                    // insert before the real yes/no token.
2263                    max_tokens: 3,
2264                    thinking: tasks::generate::ThinkingMode::Off,
2265                    ..Default::default()
2266                },
2267                context: None,
2268                tools: None,
2269                images: None,
2270                messages: None,
2271                cache_control: false,
2272                response_format: None,
2273                intent: None,
2274            };
2275            let out = self.generate(gen_req).await?;
2276            let score = score_from_rerank_output(&out, &schema.name);
2277            scored.push(RerankedDocument {
2278                index: idx,
2279                score,
2280                document: doc.clone(),
2281            });
2282        }
2283
2284        // Sort descending by score; preserve original index as a
2285        // deterministic tiebreaker. top_n must truncate after sorting.
2286        scored.sort_by(|a, b| {
2287            b.score
2288                .partial_cmp(&a.score)
2289                .unwrap_or(std::cmp::Ordering::Equal)
2290                .then_with(|| a.index.cmp(&b.index))
2291        });
2292        if let Some(n) = req.top_n {
2293            scored.truncate(n);
2294        }
2295
2296        Ok(RerankResult {
2297            ranked: scored,
2298            model_used: Some(schema.name),
2299        })
2300    }
2301
2302    /// Dedicated endpoint for structured visual grounding.
2303    ///
2304    /// Runs a VL generate call under the hood and parses Qwen2.5-VL's
2305    /// inline `<|object_ref_*|>...<|box_*|>(x1,y1),(x2,y2)` spans into
2306    /// typed [`BoundingBox`]es. Distinct from the generic
2307    /// [`InferenceEngine::generate`] + `InferenceResult.bounding_boxes`
2308    /// path so callers can express "I want boxes" as a first-class
2309    /// intent — which also lets the router prefer models that declare
2310    /// the `Grounding` capability.
2311    pub async fn ground(&self, req: GroundRequest) -> Result<GroundResult, InferenceError> {
2312        let model_name = match req.model.clone() {
2313            Some(m) => m,
2314            None => self
2315                .preferred_model_for_capability(ModelCapability::Grounding)
2316                .map(str::to_string)
2317                .ok_or_else(|| {
2318                    InferenceError::InferenceFailed(
2319                        "no grounding-capable model available — pull a Qwen2.5-VL model first"
2320                            .into(),
2321                    )
2322                })?,
2323        };
2324
2325        let gen_req = GenerateRequest {
2326            prompt: req.prompt.clone(),
2327            model: Some(model_name),
2328            params: GenerateParams::default(),
2329            context: None,
2330            tools: None,
2331            images: Some(vec![req.image.clone()]),
2332            messages: None,
2333            cache_control: false,
2334            response_format: None,
2335            intent: None,
2336        };
2337        let result = self.generate_tracked(gen_req).await?;
2338        Ok(GroundResult {
2339            boxes: result.bounding_boxes,
2340            raw_text: result.text,
2341            model_used: Some(result.model_used),
2342        })
2343    }
2344
2345    /// Classify text against candidate labels.
2346    /// When `req.model` is None, routes to the smallest available model.
2347    pub async fn classify(
2348        &self,
2349        req: ClassifyRequest,
2350    ) -> Result<Vec<ClassifyResult>, InferenceError> {
2351        let model = match req
2352            .model
2353            .clone()
2354            .or_else(|| self.preferred_model_for_capability(ModelCapability::Classify).map(str::to_string))
2355        {
2356            Some(m) => m,
2357            None => {
2358                let m = self.router.route_small(&self.registry);
2359                debug!(model = %m, "auto-routed classify request");
2360                m
2361            }
2362        };
2363
2364        // On Apple Silicon, route through the main generate path (which uses MLX)
2365        // instead of the Candle backend directly.
2366        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2367        {
2368            return self.classify_via_generate(req, &model).await;
2369        }
2370
2371        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
2372        {
2373            self.ensure_backend(&model).await?;
2374            let mut write = self.backend.write().await;
2375            let backend = write.as_mut().unwrap();
2376            tasks::classify::classify(backend, req).await
2377        }
2378    }
2379
2380    /// Classify by routing through the main generate path (uses MLX on Apple Silicon).
2381    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2382    async fn classify_via_generate(
2383        &self,
2384        req: ClassifyRequest,
2385        model: &str,
2386    ) -> Result<Vec<ClassifyResult>, InferenceError> {
2387        let labels_str = req
2388            .labels
2389            .iter()
2390            .enumerate()
2391            .map(|(i, l)| format!("{}. {}", i + 1, l))
2392            .collect::<Vec<_>>()
2393            .join("\n");
2394
2395        let prompt = format!(
2396            "Classify the following text into one of these categories:\n\
2397             {labels_str}\n\n\
2398             Text: {}\n\n\
2399             Respond with ONLY the category name, nothing else.",
2400            req.text
2401        );
2402
2403        let gen_req = GenerateRequest {
2404            prompt,
2405            model: Some(model.to_string()),
2406            params: tasks::generate::GenerateParams {
2407                temperature: 0.0,
2408                max_tokens: 32,
2409                // Classification is latency-sensitive and single-label;
2410                // force the fast no-thinking path even on Qwen3.
2411                thinking: tasks::generate::ThinkingMode::Off,
2412                ..Default::default()
2413            },
2414            context: None,
2415            tools: None,
2416            images: None,
2417            messages: None,
2418            cache_control: false,
2419            response_format: None,
2420            intent: None,
2421        };
2422
2423        let response = self.generate(gen_req).await?;
2424        let response_lower = response.trim().to_lowercase();
2425
2426        let mut results: Vec<ClassifyResult> = req
2427            .labels
2428            .iter()
2429            .map(|label| {
2430                let label_lower = label.to_lowercase();
2431                let score = if response_lower == label_lower {
2432                    1.0
2433                } else if response_lower.contains(&label_lower) {
2434                    0.8
2435                } else {
2436                    let label_words: Vec<&str> = label_lower.split_whitespace().collect();
2437                    let matches = label_words
2438                        .iter()
2439                        .filter(|w| response_lower.contains(**w))
2440                        .count();
2441                    if label_words.is_empty() {
2442                        0.0
2443                    } else {
2444                        0.5 * (matches as f64 / label_words.len() as f64)
2445                    }
2446                };
2447                ClassifyResult {
2448                    label: label.clone(),
2449                    score,
2450                }
2451            })
2452            .collect();
2453
2454        results.sort_by(|a, b| {
2455            b.score
2456                .partial_cmp(&a.score)
2457                .unwrap_or(std::cmp::Ordering::Equal)
2458        });
2459
2460        let total: f64 = results.iter().map(|r| r.score).sum();
2461        if total > 0.0 {
2462            for r in &mut results {
2463                r.score /= total;
2464            }
2465        }
2466
2467        Ok(results)
2468    }
2469
2470    /// Transcribe an audio file using the best available STT model.
2471    pub async fn transcribe(
2472        &self,
2473        req: TranscribeRequest,
2474    ) -> Result<TranscribeResult, InferenceError> {
2475        let candidates =
2476            self.speech_candidates(ModelCapability::SpeechToText, req.model.as_deref())?;
2477        let mut last_error = None;
2478
2479        for schema in candidates {
2480            let result = match &schema.source {
2481                ModelSource::Mlx { .. } => self.transcribe_local_mlx(&schema, &req).await,
2482                ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2483                    self.transcribe_elevenlabs(&schema, &req).await
2484                }
2485                _ => Err(InferenceError::InferenceFailed(format!(
2486                    "speech-to-text not implemented for model source: {}",
2487                    schema.id
2488                ))),
2489            };
2490
2491            match result {
2492                Ok(result) => return Ok(result),
2493                Err(err) => last_error = Some(err),
2494            }
2495        }
2496
2497        Err(last_error.unwrap_or_else(|| {
2498            InferenceError::InferenceFailed("no speech-to-text models available".into())
2499        }))
2500    }
2501
2502    /// Synthesize speech using the best available TTS model.
2503    pub async fn synthesize(
2504        &self,
2505        req: SynthesizeRequest,
2506    ) -> Result<SynthesizeResult, InferenceError> {
2507        let candidates =
2508            self.speech_candidates(ModelCapability::TextToSpeech, req.model.as_deref())?;
2509        let mut last_error = None;
2510
2511        for schema in candidates {
2512            let result = match &schema.source {
2513                ModelSource::Mlx { .. } => self.synthesize_local_mlx(&schema, &req).await,
2514                ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2515                    self.synthesize_elevenlabs(&schema, &req).await
2516                }
2517                _ => Err(InferenceError::InferenceFailed(format!(
2518                    "text-to-speech not implemented for model source: {}",
2519                    schema.id
2520                ))),
2521            };
2522
2523            match result {
2524                Ok(result) => return Ok(result),
2525                Err(err) => last_error = Some(err),
2526            }
2527        }
2528
2529        Err(last_error.unwrap_or_else(|| {
2530            InferenceError::InferenceFailed("no text-to-speech models available".into())
2531        }))
2532    }
2533
2534    /// Generate an image using the best available local MLX image model.
2535    pub async fn generate_image(
2536        &self,
2537        req: GenerateImageRequest,
2538    ) -> Result<GenerateImageResult, InferenceError> {
2539        // Dispatch: the native Rust MLX Flux backend now reaches prompt-faithful
2540        // parity at 512×512×4-step against mflux (verified via the parity harness
2541        // in tools/parity/ref_flux.py + diff_flux_blocks.py). Default to native;
2542        // `CAR_IMAGE_BACKEND=external` still routes to the Python mflux fallback
2543        // for A/B comparison, `CAR_IMAGE_BACKEND=auto` reads env preference then
2544        // auto-detects.
2545        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2546        {
2547            use crate::backend::external_flux;
2548            let backend = std::env::var("CAR_IMAGE_BACKEND")
2549                .unwrap_or_else(|_| "native".to_string());
2550            let use_external = match backend.as_str() {
2551                "external" => true,
2552                "native" => false,
2553                // `auto` honors external only when the subprocess CLI is on PATH.
2554                _ => external_flux::is_available() && backend == "auto-external",
2555            };
2556            if use_external {
2557                tracing::info!(
2558                    "routing image generation to external mflux \
2559                     (set CAR_IMAGE_BACKEND=native to use the Rust port)"
2560                );
2561                let mut req = req;
2562                req.model = self.resolve_external_hf_repo(
2563                    req.model.as_deref(),
2564                    ModelCapability::ImageGeneration,
2565                );
2566                return external_flux::generate_image(&req);
2567            }
2568            tracing::info!("using native Rust MLX Flux backend");
2569        }
2570
2571        let candidates =
2572            self.media_generation_candidates(ModelCapability::ImageGeneration, req.model.as_deref())?;
2573        let mut last_error = None;
2574
2575        for schema in candidates {
2576            let result = match &schema.source {
2577                #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2578                ModelSource::Mlx { .. } => self.generate_image_native_mlx(&schema, &req).await,
2579                _ => Err(InferenceError::InferenceFailed(format!(
2580                    "image generation not implemented for model source: {}",
2581                    schema.id
2582                ))),
2583            };
2584
2585            match result {
2586                Ok(result) => return Ok(result),
2587                Err(err) => last_error = Some(err),
2588            }
2589        }
2590
2591        Err(last_error.unwrap_or_else(|| {
2592            InferenceError::InferenceFailed("no image generation models available".into())
2593        }))
2594    }
2595
2596    /// Generate one or more variants in a single call.
2597    ///
2598    /// Returns `req.variant_count` results (defaulting to 1). The
2599    /// current MLX Flux backend doesn't support native batching, so
2600    /// this loops over `generate_image` with the seed advanced per
2601    /// variant for visual diversity. A future hosted backend
2602    /// (gpt-image-2, Replicate) can short-circuit this with one
2603    /// network call producing N coherent images.
2604    ///
2605    /// Per-variant errors abort the batch — there's no partial-
2606    /// success semantics today. Callers needing more lenient
2607    /// behaviour should call `generate_image` directly in their own
2608    /// loop.
2609    ///
2610    /// Closes #110.
2611    pub async fn generate_image_batch(
2612        &self,
2613        req: GenerateImageRequest,
2614    ) -> Result<Vec<GenerateImageResult>, InferenceError> {
2615        let count = req.variant_count.unwrap_or(1).max(1);
2616        if count == 1 {
2617            return self.generate_image(req).await.map(|r| vec![r]);
2618        }
2619        let base_seed = req.seed.unwrap_or(0);
2620        let mut results = Vec::with_capacity(count as usize);
2621        for i in 0..count {
2622            // Vary the seed per variant so backends that key prompt
2623            // → output deterministically actually produce different
2624            // images. Callers wanting reproducible single-seed
2625            // variants override `seed` per call themselves.
2626            let mut variant_req = req.clone();
2627            variant_req.seed = Some(base_seed.wrapping_add(i as u64));
2628            // Suppress variant_count on the inner call to avoid
2629            // recursion — generate_image ignores the field today,
2630            // but this also documents intent.
2631            variant_req.variant_count = Some(1);
2632            results.push(self.generate_image(variant_req).await?);
2633        }
2634        Ok(results)
2635    }
2636
2637    /// Native MLX Flux image generation (no Python shelling).
2638    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2639    async fn generate_image_native_mlx(
2640        &self,
2641        schema: &ModelSchema,
2642        req: &GenerateImageRequest,
2643    ) -> Result<GenerateImageResult, InferenceError> {
2644        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
2645        let size = backend_cache::estimate_model_size(&model_dir);
2646        let cache = Arc::clone(&self.flux_cache);
2647        let key = schema.id.clone();
2648        let handle = cache.get_or_load(&key, size, || {
2649            backend::mlx_flux::FluxBackend::load(&model_dir)
2650        })?;
2651        // Run the synchronous, GPU-bound inference on a blocking worker
2652        // so we don't stall the tokio runtime. The per-model mutex
2653        // serializes concurrent requests for the same schema.
2654        let req = req.clone();
2655        tokio::task::spawn_blocking(move || -> Result<GenerateImageResult, InferenceError> {
2656            let mut guard = handle.lock().map_err(|_| {
2657                InferenceError::InferenceFailed("flux backend mutex poisoned".into())
2658            })?;
2659            guard.generate(&req)
2660        })
2661        .await
2662        .map_err(|e| InferenceError::InferenceFailed(format!("flux task join: {e}")))?
2663    }
2664
2665    /// Native MLX LTX-2.3 video generation (no Python shelling).
2666    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2667    async fn generate_video_native_mlx(
2668        &self,
2669        schema: &ModelSchema,
2670        req: &GenerateVideoRequest,
2671    ) -> Result<GenerateVideoResult, InferenceError> {
2672        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
2673        let size = backend_cache::estimate_model_size(&model_dir);
2674        let cache = Arc::clone(&self.ltx_cache);
2675        let key = schema.id.clone();
2676        let handle = cache.get_or_load(&key, size, || {
2677            backend::mlx_ltx::LtxBackend::load(&model_dir)
2678        })?;
2679        let req = req.clone();
2680        tokio::task::spawn_blocking(move || -> Result<GenerateVideoResult, InferenceError> {
2681            let mut guard = handle.lock().map_err(|_| {
2682                InferenceError::InferenceFailed("ltx backend mutex poisoned".into())
2683            })?;
2684            guard.generate(&req)
2685        })
2686        .await
2687        .map_err(|e| InferenceError::InferenceFailed(format!("ltx task join: {e}")))?
2688    }
2689
2690    /// Generate a video using the best available local MLX video model.
2691    pub async fn generate_video(
2692        &self,
2693        req: GenerateVideoRequest,
2694    ) -> Result<GenerateVideoResult, InferenceError> {
2695        // Validate the request shape up front so callers get a clean
2696        // error rather than a backend failure deep in the stack.
2697        if let Err(msg) = req.validate() {
2698            return Err(InferenceError::InferenceFailed(format!(
2699                "invalid GenerateVideoRequest: {}",
2700                msg
2701            )));
2702        }
2703        // Backend selection (temporary — until the native Rust MLX port
2704        // reaches parity with the upstream Python port):
2705        //   CAR_VIDEO_BACKEND=external  → call the `ltx-2-mlx` Python CLI.
2706        //   CAR_VIDEO_BACKEND=native    → force the native Rust MLX path.
2707        //   <unset> / "auto" (default)  → external when the CLI is on PATH,
2708        //                                  else fall back to native.
2709        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2710        {
2711            use crate::backend::external_ltx;
2712            // Default flipped to `native` now that the Rust port reaches
2713            // quality parity with upstream ltx-2-mlx (matching prompt
2714            // produces a matching subject; see the cascade of fixes in
2715            // #40 / #45). `external` still honored for A/B comparison.
2716            let backend = std::env::var("CAR_VIDEO_BACKEND")
2717                .unwrap_or_else(|_| "native".to_string());
2718            let use_external = match backend.as_str() {
2719                "external" => true,
2720                "native" => false,
2721                // `auto-external` is an opt-in to the old behavior: use
2722                // the Python CLI if it happens to be on PATH, else fall
2723                // back to native.
2724                "auto-external" => external_ltx::is_available(),
2725                _ => false,
2726            };
2727            if use_external {
2728                tracing::info!(
2729                    "routing video generation to external ltx-2-mlx \
2730                     (set CAR_VIDEO_BACKEND=native to opt into the in-progress Rust port)"
2731                );
2732                let mut req = req;
2733                req.model = self.resolve_external_hf_repo(
2734                    req.model.as_deref(),
2735                    ModelCapability::VideoGeneration,
2736                );
2737                return external_ltx::generate_video(&req);
2738            }
2739            tracing::info!(
2740                "using native Rust MLX video backend — Gemma left-pad bug was \
2741                 just fixed; verify output quality before making this the default"
2742            );
2743        }
2744
2745        let candidates =
2746            self.media_generation_candidates(ModelCapability::VideoGeneration, req.model.as_deref())?;
2747        let mut last_error = None;
2748
2749        for schema in candidates {
2750            let result = match &schema.source {
2751                #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2752                ModelSource::Mlx { .. } => self.generate_video_native_mlx(&schema, &req).await,
2753                _ => Err(InferenceError::InferenceFailed(format!(
2754                    "video generation not implemented for model source: {}",
2755                    schema.id
2756                ))),
2757            };
2758
2759            match result {
2760                Ok(result) => return Ok(result),
2761                Err(err) => last_error = Some(err),
2762            }
2763        }
2764
2765        Err(last_error.unwrap_or_else(|| {
2766            InferenceError::InferenceFailed("no video generation models available".into())
2767        }))
2768    }
2769
2770    /// List all known models and their status (new registry).
2771    pub fn list_models_unified(&self) -> Vec<ModelInfo> {
2772        self.unified_registry
2773            .list()
2774            .iter()
2775            .map(|m| ModelInfo::from(*m))
2776            .collect()
2777    }
2778
2779    /// List all known models and their download status (legacy).
2780    /// List all model schemas from the unified registry (full metadata).
2781    pub fn list_schemas(&self) -> Vec<ModelSchema> {
2782        self.unified_registry.list().into_iter().cloned().collect()
2783    }
2784
2785    pub fn list_models(&self) -> Vec<models::ModelInfo> {
2786        self.registry.list_models()
2787    }
2788
2789    /// Download a model if not already present.
2790    pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
2791        let schema = self
2792            .unified_registry
2793            .find_by_name(name)
2794            .or_else(|| self.unified_registry.get(name))
2795            .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2796        self.unified_registry.ensure_local(&schema.id).await
2797    }
2798
2799    /// Remove a downloaded model.
2800    pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
2801        let schema = self
2802            .unified_registry
2803            .get(name)
2804            .or_else(|| {
2805                self.unified_registry
2806                    .list()
2807                    .into_iter()
2808                    .find(|schema| schema.name.eq_ignore_ascii_case(name))
2809            })
2810            .or_else(|| self.unified_registry.find_by_name(name))
2811            .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2812        let model_dir = self.unified_registry.models_dir().join(&schema.name);
2813        if model_dir.exists() {
2814            std::fs::remove_dir_all(&model_dir)?;
2815        }
2816        match &schema.source {
2817            ModelSource::Mlx { hf_repo, .. } => {
2818                remove_huggingface_repo_cache(hf_repo)?;
2819            }
2820            ModelSource::Local {
2821                hf_repo,
2822                tokenizer_repo,
2823                ..
2824            } => {
2825                remove_huggingface_repo_cache(hf_repo)?;
2826                remove_huggingface_repo_cache(tokenizer_repo)?;
2827            }
2828            _ => {}
2829        }
2830        Ok(())
2831    }
2832
2833    /// Register a model in the unified registry.
2834    pub fn register_model(&mut self, schema: ModelSchema) {
2835        self.unified_registry.register(schema);
2836    }
2837
2838    /// Discover generic MLX models from a running vLLM-MLX server and register them.
2839    /// Returns the number of discovered models added or refreshed in the registry.
2840    pub async fn discover_vllm_mlx_models(&mut self) -> usize {
2841        let config = vllm_mlx::VllmMlxConfig::default();
2842        if !config.auto_discover {
2843            return 0;
2844        }
2845        vllm_mlx::discover_and_register(&config, &mut self.unified_registry).await
2846    }
2847
2848    /// Get outcome tracker for external use (e.g., memgine integration).
2849    pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
2850        self.outcome_tracker.clone()
2851    }
2852
2853    /// Auto-save outcomes and key pool stats silently (called after every inference call).
2854    async fn auto_save_outcomes(&self) {
2855        if let Err(e) = self.save_outcomes().await {
2856            tracing::debug!("auto-save outcomes failed: {}", e);
2857        }
2858        if let Err(e) = self.save_key_pool_stats().await {
2859            tracing::debug!("auto-save key pool stats failed: {}", e);
2860        }
2861    }
2862
2863    /// Persist outcome profiles to disk for cross-session learning (#13).
2864    pub async fn save_outcomes(&self) -> Result<(), std::io::Error> {
2865        let tracker = self.outcome_tracker.read().await;
2866        let path = self.config.models_dir.join("outcome_profiles.json");
2867        tracker.save_to_file(&path)
2868    }
2869
2870    /// Persist key pool stats to disk.
2871    pub async fn save_key_pool_stats(&self) -> Result<(), std::io::Error> {
2872        let path = self.config.models_dir.join("key_pool_stats.json");
2873        self.remote_backend.key_pool.save_stats(&path).await
2874    }
2875
2876    /// Get key pool stats for all endpoints.
2877    pub async fn key_pool_stats(
2878        &self,
2879    ) -> std::collections::HashMap<String, Vec<key_pool::KeyStats>> {
2880        self.remote_backend.key_pool.all_stats().await
2881    }
2882
2883    /// Export model performance profiles for persistence.
2884    pub async fn export_profiles(&self) -> Vec<ModelProfile> {
2885        let tracker = self.outcome_tracker.read().await;
2886        tracker.export_profiles()
2887    }
2888
2889    /// Import model performance profiles (from persistence).
2890    pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
2891        let mut tracker = self.outcome_tracker.write().await;
2892        tracker.import_profiles(profiles);
2893    }
2894
2895    /// Ensure the managed local speech runtime exists and return its root directory.
2896    /// On Apple Silicon, speech uses native MLX backends and no Python runtime is needed.
2897    pub async fn prepare_speech_runtime(&self) -> Result<PathBuf, InferenceError> {
2898        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2899        {
2900            // Native MLX speech — no Python runtime needed.
2901            Ok(self.config.models_dir.clone())
2902        }
2903        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
2904        {
2905            Ok(self.ensure_speech_runtime().await?.root)
2906        }
2907    }
2908
2909    /// Override speech routing preferences for the current engine instance.
2910    pub fn set_speech_policy(&mut self, policy: SpeechPolicy) {
2911        self.speech_policy = policy;
2912    }
2913
2914    pub fn set_routing_config(&mut self, config: RoutingConfig) {
2915        self.adaptive_router.set_config(config);
2916    }
2917
2918    /// Download the curated local speech model set into the shared Hugging Face cache.
2919    pub async fn install_curated_speech(
2920        &mut self,
2921    ) -> Result<Vec<SpeechInstallReport>, InferenceError> {
2922        let _runtime_root = self.prepare_speech_runtime().await?;
2923        let schemas = self.list_schemas();
2924        let mut repos = Vec::new();
2925        for schema in &schemas {
2926            if !schema.is_mlx() || !schema.tags.iter().any(|tag| tag == "speech") {
2927                continue;
2928            }
2929            if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
2930                if !repos.iter().any(|existing: &String| existing == hf_repo) {
2931                    repos.push(hf_repo.clone());
2932                }
2933            }
2934        }
2935
2936        let mut installed = Vec::new();
2937        for repo in repos {
2938            let (snapshot_path, files_downloaded) = download_hf_repo_snapshot(&repo).await?;
2939            let name = schemas
2940                .iter()
2941                .find(|schema| {
2942                    matches!(&schema.source, ModelSource::Mlx { hf_repo, .. } if hf_repo == &repo)
2943                })
2944                .map(|schema| schema.name.clone())
2945                .unwrap_or_else(|| repo.clone());
2946            installed.push(SpeechInstallReport {
2947                name,
2948                hf_repo: repo,
2949                snapshot_path,
2950                files_downloaded,
2951            });
2952        }
2953
2954        self.unified_registry.refresh_availability();
2955        Ok(installed)
2956    }
2957
2958
2959    /// Report speech runtime, model cache, and remote-provider health.
2960    pub fn speech_health(&self) -> SpeechHealthReport {
2961        let local_stt_default = self
2962            .speech_health_default_name(ModelCapability::SpeechToText, true, false);
2963        let local_tts_default = self
2964            .speech_health_default_name(ModelCapability::TextToSpeech, true, false);
2965        let remote_stt_default = self
2966            .speech_health_default_name(ModelCapability::SpeechToText, false, true);
2967        let remote_tts_default = self
2968            .speech_health_default_name(ModelCapability::TextToSpeech, false, true);
2969
2970        let mut local_models = Vec::new();
2971        let mut remote_models = Vec::new();
2972        for schema in self.list_schemas() {
2973            let capability = if schema.has_capability(ModelCapability::SpeechToText) {
2974                Some(ModelCapability::SpeechToText)
2975            } else if schema.has_capability(ModelCapability::TextToSpeech) {
2976                Some(ModelCapability::TextToSpeech)
2977            } else {
2978                None
2979            };
2980            let Some(capability) = capability else {
2981                continue;
2982            };
2983
2984            let selected_by_default = local_stt_default
2985                .as_ref()
2986                .is_some_and(|name| name == &schema.name)
2987                || local_tts_default
2988                    .as_ref()
2989                    .is_some_and(|name| name == &schema.name)
2990                || remote_stt_default
2991                    .as_ref()
2992                    .is_some_and(|name| name == &schema.name)
2993                || remote_tts_default
2994                    .as_ref()
2995                    .is_some_and(|name| name == &schema.name);
2996
2997            let health = SpeechModelHealth {
2998                id: schema.id.clone(),
2999                name: schema.name.clone(),
3000                provider: schema.provider.clone(),
3001                capability,
3002                is_local: schema.is_local(),
3003                available: schema.available,
3004                cached: speech_model_cached(&schema),
3005                selected_by_default,
3006                source: speech_model_source_label(&schema),
3007            };
3008            if schema.is_local() {
3009                local_models.push(health);
3010            } else {
3011                remote_models.push(health);
3012            }
3013        }
3014
3015        // On Apple Silicon, native MLX backends are used instead of a Python runtime.
3016        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
3017        let runtime = SpeechRuntimeHealth {
3018            root: self.config.models_dir.clone(),
3019            installed: true,
3020            python: PathBuf::new(),
3021            stt_command: PathBuf::new(),
3022            tts_command: PathBuf::new(),
3023            configured_python: None,
3024            detected_python: None,
3025        };
3026
3027        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3028        let runtime = {
3029            let rt = SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3030            SpeechRuntimeHealth {
3031                root: rt.root.clone(),
3032                installed: rt.is_ready(),
3033                python: rt.python.clone(),
3034                stt_command: rt.stt_program.clone(),
3035                tts_command: rt.tts_program.clone(),
3036                configured_python: std::env::var("CAR_SPEECH_PYTHON")
3037                    .ok()
3038                    .filter(|value| !value.trim().is_empty()),
3039                detected_python: detect_speech_python(),
3040            }
3041        };
3042
3043        SpeechHealthReport {
3044            runtime,
3045            local_models,
3046            remote_models,
3047            elevenlabs_configured: car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY").is_some(),
3048            prefer_local: self.speech_policy.prefer_local,
3049            allow_remote_fallback: self.speech_policy.allow_remote_fallback,
3050            preferred_local_stt: self.speech_policy.preferred_local_stt.clone(),
3051            preferred_local_tts: self.speech_policy.preferred_local_tts.clone(),
3052            preferred_remote_stt: self.speech_policy.preferred_remote_stt.clone(),
3053            preferred_remote_tts: self.speech_policy.preferred_remote_tts.clone(),
3054            local_stt_default,
3055            local_tts_default,
3056            remote_stt_default,
3057            remote_tts_default,
3058        }
3059    }
3060
3061    /// Report the current model catalog, configured defaults, capability coverage,
3062    /// and speech runtime/provider health in one place.
3063    pub async fn model_health(&self) -> ModelHealthReport {
3064        let schemas = self.list_schemas();
3065        let total_models = schemas.len();
3066        let available_models = schemas.iter().filter(|schema| schema.available).count();
3067        let local_models = schemas.iter().filter(|schema| schema.is_local()).count();
3068        let remote_models = total_models.saturating_sub(local_models);
3069
3070        let defaults = vec![
3071            self.model_default_health(
3072                ModelCapability::Generate,
3073                self.preferred_model_for_capability(ModelCapability::Generate)
3074                    .unwrap_or(&self.config.generation_model),
3075            ),
3076            self.model_default_health(
3077                ModelCapability::Embed,
3078                self.preferred_model_for_capability(ModelCapability::Embed)
3079                    .unwrap_or(&self.config.embedding_model),
3080            ),
3081            self.model_default_health(
3082                ModelCapability::Classify,
3083                self.preferred_model_for_capability(ModelCapability::Classify)
3084                    .unwrap_or(&self.config.classification_model),
3085            ),
3086        ];
3087
3088        let mut providers = std::collections::BTreeMap::new();
3089        for schema in &schemas {
3090            let entry = providers
3091                .entry(schema.provider.clone())
3092                .or_insert_with(|| ProviderAccumulator {
3093                    configured: false,
3094                    local_models: 0,
3095                    remote_models: 0,
3096                    available_models: 0,
3097                    capabilities: std::collections::HashSet::new(),
3098                });
3099
3100            entry.configured |= model_source_configured(schema);
3101            if schema.is_local() {
3102                entry.local_models += 1;
3103            } else {
3104                entry.remote_models += 1;
3105            }
3106            if schema.available {
3107                entry.available_models += 1;
3108            }
3109            for capability in &schema.capabilities {
3110                entry.capabilities.insert(*capability);
3111            }
3112        }
3113
3114        let providers = providers
3115            .into_iter()
3116            .map(|(provider, acc)| ModelProviderHealth {
3117                provider,
3118                configured: acc.configured,
3119                local_models: acc.local_models,
3120                remote_models: acc.remote_models,
3121                available_models: acc.available_models,
3122                capabilities: sort_capabilities(acc.capabilities.into_iter().collect()),
3123            })
3124            .collect();
3125
3126        let capabilities = all_model_capabilities()
3127            .into_iter()
3128            .map(|capability| {
3129                let relevant: Vec<&ModelSchema> = schemas
3130                    .iter()
3131                    .filter(|schema| schema.has_capability(capability))
3132                    .collect();
3133                let available: Vec<&ModelSchema> = relevant
3134                    .iter()
3135                    .copied()
3136                    .filter(|schema| schema.available)
3137                    .collect();
3138                ModelCapabilityHealth {
3139                    capability,
3140                    total_models: relevant.len(),
3141                    available_models: available.len(),
3142                    local_available_models: available
3143                        .iter()
3144                        .filter(|schema| schema.is_local())
3145                        .count(),
3146                    remote_available_models: available
3147                        .iter()
3148                        .filter(|schema| !schema.is_local())
3149                        .count(),
3150                }
3151            })
3152            .collect();
3153
3154        let routing = self.routing_scenarios().await;
3155        let routing_config = self.adaptive_router.config().clone();
3156        let benchmark_priors = load_benchmark_prior_health(&self.config.models_dir, &schemas);
3157
3158        ModelHealthReport {
3159            total_models,
3160            available_models,
3161            local_models,
3162            remote_models,
3163            defaults,
3164            providers,
3165            capabilities,
3166            routing_prefer_local: routing_config.prefer_local,
3167            routing_quality_first_cold_start: routing_config.quality_first_cold_start,
3168            routing_min_observations: routing_config.min_observations,
3169            routing_bootstrap_min_task_observations: routing_config.bootstrap_min_task_observations,
3170            routing_bootstrap_quality_floor: routing_config.bootstrap_quality_floor,
3171            routing_quality_weight: routing_config.quality_weight,
3172            routing_latency_weight: routing_config.latency_weight,
3173            routing_cost_weight: routing_config.cost_weight,
3174            routing_scenarios: routing,
3175            benchmark_priors,
3176            speech: self.speech_health(),
3177        }
3178    }
3179
3180    async fn routing_scenarios(&self) -> Vec<RoutingScenarioHealth> {
3181        let tracker = self.outcome_tracker.read().await;
3182        let config = self.adaptive_router.config().clone();
3183        let scenarios = [
3184            (
3185                "interactive_text",
3186                "Summarize the benefits of local-first AI routing in two sentences.",
3187                "text",
3188                RoutingWorkload::Interactive,
3189                false,
3190                false,
3191            ),
3192            (
3193                "background_code",
3194                "Write a Python function named fibonacci(n) that returns the nth Fibonacci number.",
3195                "code",
3196                RoutingWorkload::Background,
3197                false,
3198                false,
3199            ),
3200            (
3201                "interactive_tool_use",
3202                "Use the provided weather tool to get the weather for Boston.",
3203                "tool_use",
3204                RoutingWorkload::Interactive,
3205                true,
3206                false,
3207            ),
3208            (
3209                "interactive_vision",
3210                "What is in this image? Answer in one word.",
3211                "vision",
3212                RoutingWorkload::Interactive,
3213                false,
3214                true,
3215            ),
3216        ];
3217
3218        scenarios
3219            .into_iter()
3220            .map(|(name, prompt, task_family, workload, has_tools, has_vision)| {
3221                let decision = self.adaptive_router.route_context_aware(
3222                    prompt,
3223                    0,
3224                    &self.unified_registry,
3225                    &tracker,
3226                    has_tools,
3227                    has_vision,
3228                    workload,
3229                );
3230                let quality_first_cold_start = if has_tools || has_vision {
3231                    config.quality_first_cold_start
3232                } else if task_family == "code" && matches!(workload, RoutingWorkload::Background) {
3233                    false
3234                } else {
3235                    config.quality_first_cold_start
3236                };
3237                RoutingScenarioHealth {
3238                    name: name.to_string(),
3239                    task_family: task_family.to_string(),
3240                    workload,
3241                    has_tools,
3242                    has_vision,
3243                    prefer_local: if task_family == "speech" {
3244                        self.speech_policy.prefer_local
3245                    } else {
3246                        config.prefer_local
3247                    },
3248                    quality_first_cold_start,
3249                    bootstrap_min_task_observations: config.bootstrap_min_task_observations,
3250                    bootstrap_quality_floor: config.bootstrap_quality_floor,
3251                    model_id: decision.model_id,
3252                    model_name: decision.model_name,
3253                    reason: decision.reason,
3254                    strategy: decision.strategy,
3255                }
3256            })
3257            .collect()
3258    }
3259
3260    /// Run a real speech smoke test through the configured local and/or remote paths.
3261    pub async fn smoke_test_speech(
3262        &self,
3263        local: bool,
3264        remote: bool,
3265    ) -> Result<SpeechSmokeReport, InferenceError> {
3266        let mut report = SpeechSmokeReport::default();
3267
3268        if local {
3269            let tts = self
3270                .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
3271                .ok_or_else(|| {
3272                    InferenceError::InferenceFailed(
3273                        "no local text-to-speech model available".into(),
3274                    )
3275                })?;
3276            let stt = self
3277                .preferred_speech_schema(ModelCapability::SpeechToText, true, false)
3278                .ok_or_else(|| {
3279                    InferenceError::InferenceFailed(
3280                        "no local speech-to-text model available".into(),
3281                    )
3282                })?;
3283            report.local = Some(
3284                self.run_speech_smoke_path("local", &tts, &stt, "Testing CAR local speech path.")
3285                    .await?,
3286            );
3287        } else {
3288            report.skipped.push("local".to_string());
3289        }
3290
3291        if remote {
3292            let tts = self
3293                .preferred_speech_schema(ModelCapability::TextToSpeech, false, true)
3294                .ok_or_else(|| {
3295                    InferenceError::InferenceFailed(
3296                        "no remote text-to-speech model available".into(),
3297                    )
3298                })?;
3299            let stt = self
3300                .preferred_speech_schema(ModelCapability::SpeechToText, false, true)
3301                .ok_or_else(|| {
3302                    InferenceError::InferenceFailed(
3303                        "no remote speech-to-text model available".into(),
3304                    )
3305                })?;
3306            report.remote = Some(
3307                self.run_speech_smoke_path("remote", &tts, &stt, "Testing CAR remote speech path.")
3308                    .await?,
3309            );
3310        } else {
3311            report.skipped.push("remote".to_string());
3312        }
3313
3314        Ok(report)
3315    }
3316
3317    fn speech_candidates(
3318        &self,
3319        capability: ModelCapability,
3320        explicit: Option<&str>,
3321    ) -> Result<Vec<ModelSchema>, InferenceError> {
3322        if let Some(model) = explicit {
3323            let schema = self
3324                .unified_registry
3325                .get(model)
3326                .or_else(|| self.unified_registry.find_by_name(model))
3327                .cloned()
3328                .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3329            if !schema.has_capability(capability) {
3330                return Err(InferenceError::InferenceFailed(format!(
3331                    "model {} does not support {:?}",
3332                    schema.name, capability
3333                )));
3334            }
3335            return Ok(vec![schema]);
3336        }
3337
3338        let mut candidates: Vec<ModelSchema> = self
3339            .unified_registry
3340            .query(&ModelFilter {
3341                capabilities: vec![capability],
3342                ..Default::default()
3343            })
3344            .into_iter()
3345            .cloned()
3346            .collect();
3347
3348        if candidates.is_empty() {
3349            return Err(InferenceError::InferenceFailed(format!(
3350                "no models registered for capability {:?}",
3351                capability
3352            )));
3353        }
3354
3355        candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3356        if !self.speech_policy.allow_remote_fallback && candidates.iter().any(|model| model.is_local()) {
3357            candidates.retain(|model| model.is_local());
3358        }
3359
3360        Ok(candidates)
3361    }
3362
3363    /// Resolve a car-canonical model id (e.g. `mlx/flux-1-lite-8b:q4`) to the
3364    /// HuggingFace repo (`mlx-community/Flux-1.lite-8B-MLX-Q4`) that the
3365    /// external Python CLIs expect. Falls back to the input if no schema
3366    /// matches or the schema is not MLX-sourced.
3367    fn resolve_external_hf_repo(
3368        &self,
3369        explicit: Option<&str>,
3370        capability: ModelCapability,
3371    ) -> Option<String> {
3372        let id = explicit?;
3373        let schema = self
3374            .unified_registry
3375            .get(id)
3376            .or_else(|| self.unified_registry.find_by_name(id))?;
3377        if !schema.has_capability(capability) {
3378            return Some(id.to_string());
3379        }
3380        if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3381            return Some(hf_repo.clone());
3382        }
3383        Some(id.to_string())
3384    }
3385
3386    fn media_generation_candidates(
3387        &self,
3388        capability: ModelCapability,
3389        explicit: Option<&str>,
3390    ) -> Result<Vec<ModelSchema>, InferenceError> {
3391        if let Some(model) = explicit {
3392            let schema = self
3393                .unified_registry
3394                .get(model)
3395                .or_else(|| self.unified_registry.find_by_name(model))
3396                .cloned()
3397                .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3398            if !schema.has_capability(capability) {
3399                return Err(InferenceError::InferenceFailed(format!(
3400                    "model {} does not support {:?}",
3401                    schema.name, capability
3402                )));
3403            }
3404            return Ok(vec![schema]);
3405        }
3406
3407        let mut candidates: Vec<ModelSchema> = self
3408            .unified_registry
3409            .query(&ModelFilter {
3410                capabilities: vec![capability],
3411                local_only: true,
3412                ..Default::default()
3413            })
3414            .into_iter()
3415            .cloned()
3416            .collect();
3417        candidates.sort_by_key(|schema| (!schema.available, schema.size_mb()));
3418        if candidates.is_empty() {
3419            return Err(InferenceError::InferenceFailed(format!(
3420                "no models registered for capability {:?}",
3421                capability
3422            )));
3423        }
3424        Ok(candidates)
3425    }
3426
3427    fn preferred_speech_schema(
3428        &self,
3429        capability: ModelCapability,
3430        local_only: bool,
3431        remote_only: bool,
3432    ) -> Option<ModelSchema> {
3433        let available_only = remote_only;
3434        let mut candidates: Vec<ModelSchema> = self
3435            .unified_registry
3436            .query(&ModelFilter {
3437                capabilities: vec![capability],
3438                available_only,
3439                ..Default::default()
3440            })
3441            .into_iter()
3442            .filter(|schema| (!local_only || schema.is_local()) && (!remote_only || schema.is_remote()))
3443            .cloned()
3444            .collect();
3445        candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3446        candidates.into_iter().next()
3447    }
3448
3449    fn speech_health_default_name(
3450        &self,
3451        capability: ModelCapability,
3452        local_only: bool,
3453        remote_only: bool,
3454    ) -> Option<String> {
3455        let preferred = match capability {
3456            ModelCapability::SpeechToText if local_only => {
3457                self.speech_policy.preferred_local_stt.as_ref()
3458            }
3459            ModelCapability::SpeechToText if remote_only => {
3460                self.speech_policy.preferred_remote_stt.as_ref()
3461            }
3462            ModelCapability::TextToSpeech if local_only => {
3463                self.speech_policy.preferred_local_tts.as_ref()
3464            }
3465            ModelCapability::TextToSpeech if remote_only => {
3466                self.speech_policy.preferred_remote_tts.as_ref()
3467            }
3468            _ => None,
3469        };
3470
3471        preferred
3472            .filter(|name| {
3473                self.unified_registry.list().iter().any(|schema| {
3474                    schema.name == **name
3475                        && schema.has_capability(capability)
3476                        && (!local_only || schema.is_local())
3477                        && (!remote_only || schema.is_remote())
3478                })
3479            })
3480            .cloned()
3481            .or_else(|| {
3482                self.preferred_speech_schema(capability, local_only, remote_only)
3483                    .map(|schema| schema.name)
3484            })
3485    }
3486
3487    fn model_default_health(
3488        &self,
3489        capability: ModelCapability,
3490        configured_model: &str,
3491    ) -> ModelDefaultHealth {
3492        let schema = self
3493            .unified_registry
3494            .find_by_name(configured_model)
3495            .or_else(|| self.unified_registry.get(configured_model));
3496
3497        ModelDefaultHealth {
3498            capability,
3499            configured_model: configured_model.to_string(),
3500            available: schema.is_some_and(|model| model.available),
3501            is_local: schema.is_some_and(ModelSchema::is_local),
3502            provider: schema.map(|model| model.provider.clone()),
3503        }
3504    }
3505
3506    fn speech_sort_key(
3507        &self,
3508        capability: ModelCapability,
3509        model: &ModelSchema,
3510    ) -> (u8, u8, u8, u8, u64, u64) {
3511        let policy_preference = match capability {
3512            ModelCapability::SpeechToText if model.is_local() => self
3513                .speech_policy
3514                .preferred_local_stt
3515                .as_ref(),
3516            ModelCapability::SpeechToText => self.speech_policy.preferred_remote_stt.as_ref(),
3517            ModelCapability::TextToSpeech if model.is_local() => self
3518                .speech_policy
3519                .preferred_local_tts
3520                .as_ref(),
3521            ModelCapability::TextToSpeech => self.speech_policy.preferred_remote_tts.as_ref(),
3522            _ => None,
3523        };
3524        let local_rank = if self.speech_policy.prefer_local {
3525            if model.is_local() {
3526                0
3527            } else {
3528                1
3529            }
3530        } else if model.is_remote() {
3531            0
3532        } else {
3533            1
3534        };
3535        let availability_rank = if model.available {
3536            0
3537        } else if model.is_local() {
3538            1
3539        } else {
3540            2
3541        };
3542        let policy_rank: u8 = if policy_preference.is_some_and(|preferred| preferred == &model.name) {
3543            0
3544        } else {
3545            1
3546        };
3547        let speech_rank = match capability {
3548            ModelCapability::TextToSpeech => {
3549                if model.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
3550                    0
3551                } else if model.name == "Kokoro-82M-bf16" {
3552                    1
3553                } else if model.name == "Kokoro-82M-6bit" {
3554                    2
3555                } else {
3556                    3
3557                }
3558            }
3559            ModelCapability::SpeechToText => {
3560                if model.name == "Parakeet-TDT-0.6B-v3-MLX" {
3561                    0
3562                } else {
3563                    1
3564                }
3565            }
3566            _ => 0,
3567        };
3568        let latency_rank = model.performance.latency_p50_ms.unwrap_or(u64::MAX);
3569        let size_rank = model.cost.size_mb.unwrap_or(u64::MAX);
3570        (
3571            local_rank,
3572            availability_rank,
3573            policy_rank,
3574            speech_rank,
3575            latency_rank,
3576            size_rank,
3577        )
3578    }
3579
3580    async fn run_speech_smoke_path(
3581        &self,
3582        path: &str,
3583        tts: &ModelSchema,
3584        stt: &ModelSchema,
3585        text: &str,
3586    ) -> Result<SpeechSmokePathReport, InferenceError> {
3587        let work_dir = temp_work_dir(&format!("speech-smoke-{path}"))?;
3588        let audio_path = work_dir.join(format!("{path}.wav"));
3589        let synth = self
3590            .synthesize(SynthesizeRequest {
3591                text: text.to_string(),
3592                model: Some(tts.name.clone()),
3593                voice: default_speech_voice(tts),
3594                language: Some("en".to_string()),
3595                output_path: Some(audio_path.display().to_string()),
3596                ..SynthesizeRequest::default()
3597            })
3598            .await?;
3599        let transcript = self
3600            .transcribe(TranscribeRequest {
3601                audio_path: synth.audio_path.clone(),
3602                model: Some(stt.name.clone()),
3603                language: Some("en".to_string()),
3604                prompt: None,
3605                timestamps: false,
3606            })
3607            .await?;
3608
3609        Ok(SpeechSmokePathReport {
3610            path: path.to_string(),
3611            tts_model: synth.model_used.unwrap_or_else(|| tts.name.clone()),
3612            stt_model: transcript.model_used.unwrap_or_else(|| stt.name.clone()),
3613            audio_path: PathBuf::from(synth.audio_path),
3614            transcript: transcript.text,
3615        })
3616    }
3617
3618    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3619    async fn ensure_speech_runtime(&self) -> Result<SpeechRuntime, InferenceError> {
3620        let mut guard = self.speech_runtime.lock().await;
3621        if let Some(runtime) = guard.as_ref() {
3622            if runtime.is_ready() {
3623                return Ok(runtime.clone());
3624            }
3625        }
3626
3627        let runtime = SpeechRuntime::new(speech_runtime_root_from_models_dir(
3628            &self.config.models_dir,
3629        ));
3630        if !runtime.is_ready() {
3631            bootstrap_speech_runtime(&runtime).await?;
3632        }
3633        if !runtime.is_ready() {
3634            return Err(InferenceError::InferenceFailed(format!(
3635                "managed speech runtime is not ready at {}",
3636                runtime.root.display()
3637            )));
3638        }
3639
3640        *guard = Some(runtime.clone());
3641        Ok(runtime)
3642    }
3643
3644    async fn transcribe_local_mlx(
3645        &self,
3646        schema: &ModelSchema,
3647        req: &TranscribeRequest,
3648    ) -> Result<TranscribeResult, InferenceError> {
3649        // Native MLX transcription via Parakeet backend (no Python shelling).
3650        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
3651        {
3652            let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3653            let parakeet = backend::mlx_parakeet::ParakeetBackend::load(&model_dir)?;
3654            // Only pay the word-grouping cost when the caller asked.
3655            let (text, words) = if req.timestamps {
3656                parakeet
3657                    .transcribe_detailed(Path::new(&req.audio_path))
3658                    .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?
3659            } else {
3660                let t = parakeet
3661                    .transcribe(Path::new(&req.audio_path))
3662                    .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?;
3663                (t, Vec::new())
3664            };
3665            return Ok(TranscribeResult {
3666                text,
3667                model_used: Some(schema.name.clone()),
3668                language: req.language.clone(),
3669                words,
3670            });
3671        }
3672
3673        // Non-Apple-Silicon fallback: use Python speech runtime.
3674        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3675        {
3676            let runtime = self.ensure_speech_runtime().await?;
3677            let hf_repo = match &schema.source {
3678                ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3679                _ => unreachable!(),
3680            };
3681            let output_dir = temp_work_dir("stt")?;
3682            let output_prefix = output_dir.join("transcript");
3683            let mut args = vec![
3684                "--model".to_string(),
3685                hf_repo,
3686                "--audio".to_string(),
3687                req.audio_path.clone(),
3688                "--output-path".to_string(),
3689                output_prefix.display().to_string(),
3690                "--format".to_string(),
3691                "json".to_string(),
3692            ];
3693            if let Some(language) = &req.language {
3694                args.push("--language".to_string());
3695                args.push(normalize_lang_code(language));
3696            }
3697            if let Some(prompt) = &req.prompt {
3698                args.push("--context".to_string());
3699                args.push(prompt.clone());
3700            }
3701            if req.timestamps {
3702                args.push("--verbose".to_string());
3703            }
3704
3705            let output = run_mlx_audio_command(&runtime, "stt.generate", &args).await?;
3706            let text = read_transcription_result(&output_prefix)?
3707                .or_else(|| extract_text_from_payload(&output.stdout))
3708                .ok_or_else(|| {
3709                    InferenceError::InferenceFailed(format!(
3710                        "mlx-audio transcription returned no text: {}",
3711                        output.stderr
3712                    ))
3713                })?;
3714
3715            Ok(TranscribeResult {
3716                text,
3717                model_used: Some(schema.name.clone()),
3718                language: req.language.clone(),
3719                words: Vec::new(),
3720            })
3721        }
3722    }
3723
3724    async fn synthesize_local_mlx(
3725        &self,
3726        schema: &ModelSchema,
3727        req: &SynthesizeRequest,
3728    ) -> Result<SynthesizeResult, InferenceError> {
3729        // Single entry-point check for Qwen3-TTS advanced controls.
3730        // Hoisted here so that a Kokoro → Kokoro-bf16 fallback chain
3731        // doesn't double-warn, and so strict callers get one clean
3732        // error instead of being lied to by partial success.
3733        let requested = req.requested_advanced_controls();
3734        let repo_supports_advanced = match &schema.source {
3735            ModelSource::Mlx { hf_repo, .. } => {
3736                hf_repo.to_ascii_lowercase().contains("qwen3-tts")
3737            }
3738            _ => false,
3739        };
3740        if !requested.is_empty() && !repo_supports_advanced {
3741            if req.strict_capabilities {
3742                return Err(InferenceError::InferenceFailed(format!(
3743                    "model {name} does not support Qwen3-TTS advanced controls {requested:?}; \
3744                     route to a Qwen3-TTS model or set strict_capabilities = false to degrade",
3745                    name = schema.name,
3746                )));
3747            }
3748            tracing::warn!(
3749                model = %schema.name,
3750                fields = ?requested,
3751                "Qwen3-TTS advanced controls set on non-Qwen3-TTS backend — ignored \
3752                 (set strict_capabilities=true to error instead)"
3753            );
3754        }
3755
3756        // Native MLX synthesis via Kokoro backend (no Python shelling).
3757        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
3758        {
3759            // The native Apple-Silicon path is Kokoro-only today; a
3760            // Qwen3-TTS schema would have routed here but the backend
3761            // has no cloning support yet. Strict callers are already
3762            // stopped above; degrade-ok callers get a second, narrower
3763            // note about the native-vs-Python capability gap.
3764            if repo_supports_advanced && !requested.is_empty() {
3765                if req.strict_capabilities {
3766                    return Err(InferenceError::InferenceFailed(format!(
3767                        "native MLX TTS backend does not yet implement Qwen3-TTS advanced \
3768                         controls {requested:?}; run on non-Apple-Silicon to use the Python \
3769                         mlx-audio fallback, or set strict_capabilities = false"
3770                    )));
3771                }
3772                tracing::warn!(
3773                    model = %schema.name,
3774                    fields = ?requested,
3775                    "Qwen3-TTS advanced controls are not yet implemented in the native MLX TTS \
3776                     backend; synthesizing without cloning/voice-design"
3777                );
3778            }
3779            let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3780            let size = backend_cache::estimate_model_size(&model_dir);
3781            let cache = Arc::clone(&self.kokoro_cache);
3782            let key = schema.id.clone();
3783            let handle = cache.get_or_load(&key, size, || {
3784                backend::mlx_kokoro::KokoroBackend::load(&model_dir)
3785            })?;
3786
3787            let output_path = req.output_path.clone().unwrap_or_else(|| {
3788                let dir = std::env::temp_dir().join("car_tts");
3789                let _ = std::fs::create_dir_all(&dir);
3790                dir.join("output.wav").display().to_string()
3791            });
3792            let voice = req.voice.as_deref().unwrap_or("af_heart").to_string();
3793            let text = req.text.clone();
3794            let op = tokio::task::spawn_blocking(move || -> Result<PathBuf, InferenceError> {
3795                let mut guard = handle.lock().map_err(|_| {
3796                    InferenceError::InferenceFailed("kokoro backend mutex poisoned".into())
3797                })?;
3798                guard
3799                    .synthesize(&text, Some(&voice), Path::new(&output_path))
3800                    .map_err(|e| InferenceError::InferenceFailed(format!("native TTS: {e}")))
3801            })
3802            .await
3803            .map_err(|e| InferenceError::InferenceFailed(format!("kokoro task join: {e}")))??;
3804
3805            let final_path =
3806                materialize_audio_output(&op, req.output_path.as_deref(), &req.format)?;
3807            return Ok(SynthesizeResult {
3808                audio_path: final_path.display().to_string(),
3809                media_type: media_type_for_format(&req.format),
3810                model_used: Some(schema.name.clone()),
3811                voice_used: req.voice.clone(),
3812            });
3813        }
3814
3815        // Non-Apple-Silicon fallback: use Python speech runtime.
3816        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3817        {
3818            let runtime = self.ensure_speech_runtime().await?;
3819            let primary_hf_repo = match &schema.source {
3820                ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3821                _ => unreachable!(),
3822            };
3823            let (produced, model_used) = match self
3824                .synthesize_local_mlx_repo(&runtime, &primary_hf_repo, schema.name.as_str(), req)
3825                .await
3826            {
3827                Ok(result) => result,
3828                Err(primary_err)
3829                    if primary_hf_repo == "mlx-community/Kokoro-82M-6bit"
3830                        && kokoro_runtime_fallback_enabled() =>
3831                {
3832                    let fallback_repo = "mlx-community/Kokoro-82M-bf16";
3833                    let fallback_name = "Kokoro-82M-bf16";
3834                    match self
3835                        .synthesize_local_mlx_repo(&runtime, fallback_repo, fallback_name, req)
3836                        .await
3837                    {
3838                        Ok(result) => result,
3839                        Err(fallback_err) => {
3840                            return Err(InferenceError::InferenceFailed(format!(
3841                                "{primary_err}; fallback {fallback_name} also failed: {fallback_err}"
3842                            )));
3843                        }
3844                    }
3845                }
3846                Err(err) => return Err(err),
3847            };
3848            let final_path =
3849                materialize_audio_output(&produced, req.output_path.as_deref(), &req.format)?;
3850
3851            Ok(SynthesizeResult {
3852                audio_path: final_path.display().to_string(),
3853                media_type: media_type_for_format(&req.format),
3854                model_used: Some(model_used),
3855                voice_used: req.voice.clone(),
3856            })
3857        }
3858    }
3859
3860    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3861    async fn synthesize_local_mlx_repo(
3862        &self,
3863        runtime: &SpeechRuntime,
3864        hf_repo: &str,
3865        model_name: &str,
3866        req: &SynthesizeRequest,
3867    ) -> Result<(PathBuf, String), InferenceError> {
3868        let output_dir = temp_work_dir("tts")?;
3869        let mut args = vec![
3870            "--model".to_string(),
3871            hf_repo.to_string(),
3872            "--text".to_string(),
3873            req.text.clone(),
3874            "--output_path".to_string(),
3875            output_dir.display().to_string(),
3876        ];
3877        if let Some(voice) = &req.voice {
3878            args.push("--voice".to_string());
3879            args.push(voice.clone());
3880        }
3881        if let Some(speed) = req.speed {
3882            args.push("--speed".to_string());
3883            args.push(speed.to_string());
3884        }
3885        let repo_lower = hf_repo.to_ascii_lowercase();
3886        if repo_lower.contains("kokoro") {
3887            args.push("--lang_code".to_string());
3888            args.push(kokoro_lang_code(req.language.as_deref()).to_string());
3889        } else if let Some(language) = &req.language {
3890            args.push("--lang_code".to_string());
3891            args.push(normalize_lang_code(language));
3892        }
3893
3894        // Qwen3-TTS advanced controls — reference-audio cloning and
3895        // voice-design natural-language instruction. The
3896        // supported/unsupported decision was already made at the
3897        // `synthesize_local_mlx` entry point; here we only need to
3898        // forward the flags to the mlx-audio CLI for Qwen3-TTS repos.
3899        if repo_lower.contains("qwen3-tts") {
3900            if let Some(ref_audio) = &req.reference_audio_path {
3901                args.push("--ref_audio".to_string());
3902                args.push(ref_audio.clone());
3903            }
3904            if let Some(ref_text) = &req.reference_text {
3905                args.push("--ref_text".to_string());
3906                args.push(ref_text.clone());
3907            }
3908            if let Some(instruct) = &req.voice_instruction {
3909                args.push("--instruct".to_string());
3910                args.push(instruct.clone());
3911            }
3912        }
3913
3914        let output = if repo_lower.contains("kokoro") {
3915            let device = std::env::var("CAR_SPEECH_KOKORO_DEVICE")
3916                .or_else(|_| std::env::var("CAR_SPEECH_MLX_DEVICE"))
3917                .unwrap_or_else(|_| "cpu".to_string());
3918            let extra_env = vec![
3919                // Force MLX device (defaults to CPU to avoid Metal/NSRangeException crashes)
3920                ("MLX_DEVICE".to_string(), device),
3921                // Prevent MPS/Metal kernel crashes by enabling CPU fallback
3922                ("PYTORCH_ENABLE_MPS_FALLBACK".to_string(), "1".to_string()),
3923            ];
3924            run_mlx_audio_command_with_env(runtime, "tts.generate", &args, &extra_env).await?
3925        } else {
3926            run_mlx_audio_command(runtime, "tts.generate", &args).await?
3927        };
3928        let produced = find_audio_file(&output_dir)?.ok_or_else(|| {
3929            let hint = if repo_lower.contains("kokoro") {
3930                ". Kokoro models may crash on GPU — try CAR_SPEECH_KOKORO_DEVICE=cpu or use the default Qwen3-TTS model"
3931            } else {
3932                ""
3933            };
3934            InferenceError::InferenceFailed(format!(
3935                "mlx-audio synthesis produced no audio file: {}{}",
3936                output.stderr, hint
3937            ))
3938        })?;
3939        Ok((produced, model_name.to_string()))
3940    }
3941
3942
3943    async fn transcribe_elevenlabs(
3944        &self,
3945        schema: &ModelSchema,
3946        req: &TranscribeRequest,
3947    ) -> Result<TranscribeResult, InferenceError> {
3948        let (endpoint, api_key) = elevenlabs_auth(schema)?;
3949        let file_name = Path::new(&req.audio_path)
3950            .file_name()
3951            .and_then(|f| f.to_str())
3952            .unwrap_or("audio.wav")
3953            .to_string();
3954        let audio_bytes = tokio::fs::read(&req.audio_path).await?;
3955        let file_part = Part::bytes(audio_bytes).file_name(file_name);
3956        let mut form = Form::new()
3957            .text("model_id", schema.name.clone())
3958            .part("file", file_part);
3959        if let Some(language) = &req.language {
3960            form = form.text("language_code", language.clone());
3961        }
3962
3963        let resp = self
3964            .remote_backend
3965            .client
3966            .post(format!(
3967                "{}/v1/speech-to-text",
3968                endpoint.trim_end_matches('/')
3969            ))
3970            .header("xi-api-key", api_key)
3971            .multipart(form)
3972            .send()
3973            .await
3974            .map_err(|e| {
3975                InferenceError::InferenceFailed(format!("ElevenLabs STT request failed: {e}"))
3976            })?;
3977        let status = resp.status();
3978        let body = resp.text().await.map_err(|e| {
3979            InferenceError::InferenceFailed(format!("read ElevenLabs STT body: {e}"))
3980        })?;
3981        if !status.is_success() {
3982            return Err(InferenceError::InferenceFailed(format!(
3983                "ElevenLabs STT returned {status}: {body}"
3984            )));
3985        }
3986        let payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
3987            InferenceError::InferenceFailed(format!("parse ElevenLabs STT response: {e}"))
3988        })?;
3989        let text = payload
3990            .get("text")
3991            .and_then(|v| v.as_str())
3992            .map(str::to_string)
3993            .ok_or_else(|| {
3994                InferenceError::InferenceFailed("ElevenLabs STT response missing text".into())
3995            })?;
3996
3997        Ok(TranscribeResult {
3998            text,
3999            model_used: Some(schema.name.clone()),
4000            language: payload
4001                .get("language_code")
4002                .and_then(|v| v.as_str())
4003                .map(str::to_string),
4004            words: Vec::new(),
4005        })
4006    }
4007
4008    async fn synthesize_elevenlabs(
4009        &self,
4010        schema: &ModelSchema,
4011        req: &SynthesizeRequest,
4012    ) -> Result<SynthesizeResult, InferenceError> {
4013        // ElevenLabs doesn't expose a Qwen3-TTS-style cloning or
4014        // voice-design surface on its `/v1/text-to-speech` endpoint;
4015        // honor the strict_capabilities contract here too.
4016        let requested = req.requested_advanced_controls();
4017        if !requested.is_empty() {
4018            if req.strict_capabilities {
4019                return Err(InferenceError::InferenceFailed(format!(
4020                    "ElevenLabs backend does not support Qwen3-TTS advanced controls \
4021                     {requested:?}; route to a Qwen3-TTS model or set strict_capabilities = false"
4022                )));
4023            }
4024            tracing::warn!(
4025                model = %schema.name,
4026                fields = ?requested,
4027                "Qwen3-TTS advanced controls ignored by ElevenLabs backend"
4028            );
4029        }
4030        let (endpoint, api_key) = elevenlabs_auth(schema)?;
4031        let voice_id = req
4032            .voice
4033            .clone()
4034            .unwrap_or_else(|| "JBFqnCBsd6RMkjVDRZzb".to_string());
4035        let output_format = elevenlabs_output_format(&req.format);
4036        let url = format!(
4037            "{}/v1/text-to-speech/{}?output_format={}",
4038            endpoint.trim_end_matches('/'),
4039            voice_id,
4040            output_format
4041        );
4042
4043        let mut body = serde_json::json!({
4044            "text": req.text,
4045            "model_id": schema.name,
4046        });
4047        if let Some(language) = &req.language {
4048            body["language_code"] = serde_json::Value::String(language.clone());
4049        }
4050
4051        let resp = self
4052            .remote_backend
4053            .client
4054            .post(url)
4055            .header("xi-api-key", api_key)
4056            .header("Content-Type", "application/json")
4057            .json(&body)
4058            .send()
4059            .await
4060            .map_err(|e| {
4061                InferenceError::InferenceFailed(format!("ElevenLabs TTS request failed: {e}"))
4062            })?;
4063        let status = resp.status();
4064        let audio = resp.bytes().await.map_err(|e| {
4065            InferenceError::InferenceFailed(format!("read ElevenLabs TTS body: {e}"))
4066        })?;
4067        if !status.is_success() {
4068            let err_body = String::from_utf8_lossy(&audio);
4069            return Err(InferenceError::InferenceFailed(format!(
4070                "ElevenLabs TTS returned {status}: {err_body}"
4071            )));
4072        }
4073
4074        let final_path = requested_or_temp_output(req.output_path.as_deref(), &req.format)?;
4075        ensure_parent_dir(&final_path)?;
4076        tokio::fs::write(&final_path, &audio).await?;
4077
4078        Ok(SynthesizeResult {
4079            audio_path: final_path.display().to_string(),
4080            media_type: media_type_for_format(&req.format),
4081            model_used: Some(schema.name.clone()),
4082            voice_used: Some(voice_id),
4083        })
4084    }
4085}
4086
/// Per-provider running totals; `Default` starts every counter at zero,
/// `configured` at `false`, and the capability set empty.
///
/// NOTE(review): the code that fills this in is outside this section — the
/// field notes below are read off the names and should be confirmed there.
#[derive(Default)]
struct ProviderAccumulator {
    // Presumably set once any of the provider's models counts as configured — TODO confirm.
    configured: bool,
    // Presumably the number of locally-run models seen for the provider — TODO confirm.
    local_models: usize,
    // Presumably the number of remote/API models seen for the provider — TODO confirm.
    remote_models: usize,
    // Presumably the number of models currently flagged available — TODO confirm.
    available_models: usize,
    // Distinct capabilities collected for the provider (a set, so duplicates collapse).
    capabilities: std::collections::HashSet<ModelCapability>,
}
4095
4096// ─── Python Speech Runtime (non-Apple-Silicon only) ──────────────────────────
4097// On Apple Silicon, speech uses native MLX backends (mlx_parakeet, mlx_kokoro).
4098
/// Captured output of a speech-runtime command that exited successfully.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
struct CommandOutput {
    // Standard output, decoded lossily as UTF-8.
    stdout: String,
    // Standard error, decoded lossily as UTF-8.
    stderr: String,
}
4104
/// Filesystem layout of the managed Python speech runtime.
///
/// All paths are derived from `root` by `SpeechRuntime::new`; the struct only
/// records locations and owns no process or handle.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
#[derive(Debug, Clone)]
struct SpeechRuntime {
    // Root directory of the runtime (the target passed to `uv venv`).
    root: PathBuf,
    // `<root>/bin/python`.
    python: PathBuf,
    // `<root>/bin/mlx_audio.stt.generate`.
    stt_program: PathBuf,
    // `<root>/bin/mlx_audio.tts.generate`.
    tts_program: PathBuf,
}
4113
4114#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4115impl SpeechRuntime {
4116    fn new(root: PathBuf) -> Self {
4117        let bin_dir = root.join("bin");
4118        Self {
4119            root,
4120            python: bin_dir.join("python"),
4121            stt_program: bin_dir.join("mlx_audio.stt.generate"),
4122            tts_program: bin_dir.join("mlx_audio.tts.generate"),
4123        }
4124    }
4125
4126    fn is_ready(&self) -> bool {
4127        self.python.exists() && self.stt_program.exists() && self.tts_program.exists()
4128    }
4129
4130    fn command_for(&self, subcommand: &str) -> Result<&Path, InferenceError> {
4131        match subcommand {
4132            "stt.generate" => Ok(&self.stt_program),
4133            "tts.generate" => Ok(&self.tts_program),
4134            _ => Err(InferenceError::InferenceFailed(format!(
4135                "unknown speech subcommand: {subcommand}"
4136            ))),
4137        }
4138    }
4139}
4140
4141
4142#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4143async fn run_mlx_audio_command(
4144    runtime: &SpeechRuntime,
4145    subcommand: &str,
4146    args: &[String],
4147) -> Result<CommandOutput, InferenceError> {
4148    run_mlx_audio_command_with_env(runtime, subcommand, args, &[]).await
4149}
4150
4151#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4152async fn run_mlx_audio_command_with_env(
4153    runtime: &SpeechRuntime,
4154    subcommand: &str,
4155    args: &[String],
4156    envs: &[(String, String)],
4157) -> Result<CommandOutput, InferenceError> {
4158    let program = runtime.command_for(subcommand)?;
4159    let mut command = Command::new(program);
4160    command.args(args);
4161    for (key, value) in envs {
4162        command.env(key, value);
4163    }
4164    let output = command.output().await.map_err(|err| {
4165        InferenceError::InferenceFailed(format!("{}: {err}", program.display()))
4166    })?;
4167
4168    if output.status.success() {
4169        Ok(CommandOutput {
4170            stdout: String::from_utf8_lossy(&output.stdout).to_string(),
4171            stderr: String::from_utf8_lossy(&output.stderr).to_string(),
4172        })
4173    } else {
4174        Err(InferenceError::InferenceFailed(format!(
4175            "{} exited with {}: {}",
4176            program.display(),
4177            output.status,
4178            String::from_utf8_lossy(&output.stderr)
4179        )))
4180    }
4181}
4182
4183
4184#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4185async fn bootstrap_speech_runtime(runtime: &SpeechRuntime) -> Result<(), InferenceError> {
4186    std::fs::create_dir_all(&runtime.root)?;
4187    let python = select_speech_python()?;
4188
4189    run_command(
4190        "uv",
4191        &[
4192            "venv".to_string(),
4193            "--python".to_string(),
4194            python,
4195            runtime.root.display().to_string(),
4196        ],
4197    )
4198    .await?;
4199
4200    run_command(
4201        "uv",
4202        &[
4203            "pip".to_string(),
4204            "install".to_string(),
4205            "--python".to_string(),
4206            runtime.python.display().to_string(),
4207            speech_runtime_mlx_audio_spec(),
4208            "misaki[en]".to_string(),
4209            speech_runtime_spacy_model_spec(),
4210        ],
4211    )
4212    .await?;
4213
4214    Ok(())
4215}
4216
4217
4218#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4219async fn run_command(program: &str, args: &[String]) -> Result<(), InferenceError> {
4220    let output = Command::new(program)
4221        .args(args)
4222        .output()
4223        .await
4224        .map_err(|err| InferenceError::InferenceFailed(format!("{program}: {err}")))?;
4225
4226    if output.status.success() {
4227        Ok(())
4228    } else {
4229        Err(InferenceError::InferenceFailed(format!(
4230            "{} exited with {}: {}",
4231            program,
4232            output.status,
4233            String::from_utf8_lossy(&output.stderr)
4234        )))
4235    }
4236}
4237
4238#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4239fn select_speech_python() -> Result<String, InferenceError> {
4240    if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4241        if !path.trim().is_empty() {
4242            return Ok(path);
4243        }
4244    }
4245
4246    for candidate in ["python3.13", "python3.12", "python3.11"] {
4247        if command_in_path(candidate) {
4248            return Ok(candidate.to_string());
4249        }
4250    }
4251
4252    Err(InferenceError::InferenceFailed(
4253        "no supported Python found for managed speech runtime (tried python3.13, python3.12, python3.11)".into(),
4254    ))
4255}
4256
4257#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4258fn detect_speech_python() -> Option<String> {
4259    if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4260        if !path.trim().is_empty() {
4261            return Some(path);
4262        }
4263    }
4264
4265    ["python3.13", "python3.12", "python3.11"]
4266        .into_iter()
4267        .find(|candidate| command_in_path(candidate))
4268        .map(str::to_string)
4269}
4270
/// Resolves the root directory of the managed speech runtime.
///
/// A non-blank `CAR_SPEECH_RUNTIME_DIR` is used verbatim. Otherwise the root
/// is `$HOME/.car/speech-runtime`, with `.` standing in for `$HOME` when that
/// variable cannot be read. The `_models_dir` argument is not consulted.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn speech_runtime_root_from_models_dir(_models_dir: &Path) -> PathBuf {
    match std::env::var("CAR_SPEECH_RUNTIME_DIR") {
        Ok(dir) if !dir.trim().is_empty() => return PathBuf::from(dir),
        _ => {}
    }

    let home = match std::env::var("HOME") {
        Ok(home) => PathBuf::from(home),
        Err(_) => PathBuf::from("."),
    };
    home.join(".car").join("speech-runtime")
}
4285
/// Reports whether some directory listed in `PATH` contains a regular file
/// called `name`.
///
/// Only existence as a file is checked (symlinks are followed); the execute
/// permission is not inspected. Returns `false` when `PATH` is unset.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn command_in_path(name: &str) -> bool {
    let search_path = match std::env::var_os("PATH") {
        Some(value) => value,
        None => return false,
    };

    for dir in std::env::split_paths(&search_path) {
        // `is_file` is already false for a missing entry, so it covers the
        // "exists and is a file" condition on its own.
        if dir.join(name).is_file() {
            return true;
        }
    }
    false
}
4297
4298fn speech_model_cached(schema: &ModelSchema) -> bool {
4299    match &schema.source {
4300        ModelSource::Mlx { hf_repo, .. } => {
4301            huggingface_repo_has_snapshot(hf_repo)
4302        }
4303        ModelSource::Proprietary { auth, .. } => match auth {
4304            ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4305            ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4306            ProprietaryAuth::OAuth2Pkce { .. } => false,
4307        },
4308        _ => false,
4309    }
4310}
4311
4312fn remove_huggingface_repo_cache(repo_id: &str) -> Result<(), InferenceError> {
4313    let repo_dir = std::env::var("HF_HOME")
4314        .map(PathBuf::from)
4315        .unwrap_or_else(|_| {
4316            std::env::var("HOME")
4317                .map(PathBuf::from)
4318                .unwrap_or_else(|_| PathBuf::from("."))
4319                .join(".cache")
4320                .join("huggingface")
4321        })
4322        .join("hub")
4323        .join(format!("models--{}", repo_id.replace('/', "--")));
4324
4325    if repo_dir.exists() {
4326        std::fs::remove_dir_all(repo_dir)?;
4327    }
4328    Ok(())
4329}
4330
4331fn model_source_configured(schema: &ModelSchema) -> bool {
4332    match &schema.source {
4333        ModelSource::RemoteApi {
4334            api_key_env,
4335            api_key_envs,
4336            ..
4337        } => {
4338            std::env::var(api_key_env).is_ok()
4339                || api_key_envs
4340                    .iter()
4341                    .any(|env_var| std::env::var(env_var).is_ok())
4342        }
4343        ModelSource::Proprietary { auth, .. } => match auth {
4344            ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4345            ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4346            ProprietaryAuth::OAuth2Pkce { .. } => false,
4347        },
4348        ModelSource::VllmMlx { .. } => std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available,
4349        ModelSource::Ollama { .. } => schema.available,
4350        ModelSource::Mlx { .. } | ModelSource::Local { .. } => {
4351            true
4352        }
4353        ModelSource::AppleFoundationModels { .. } => schema.available,
4354    }
4355}
4356
4357fn all_model_capabilities() -> [ModelCapability; 13] {
4358    [
4359        ModelCapability::Generate,
4360        ModelCapability::Embed,
4361        ModelCapability::Classify,
4362        ModelCapability::Code,
4363        ModelCapability::Reasoning,
4364        ModelCapability::Summarize,
4365        ModelCapability::ToolUse,
4366        ModelCapability::MultiToolCall,
4367        ModelCapability::Vision,
4368        ModelCapability::SpeechToText,
4369        ModelCapability::TextToSpeech,
4370        ModelCapability::ImageGeneration,
4371        ModelCapability::VideoGeneration,
4372    ]
4373}
4374
4375fn sort_capabilities(mut capabilities: Vec<ModelCapability>) -> Vec<ModelCapability> {
4376    capabilities.sort_by_key(|capability| {
4377        all_model_capabilities()
4378            .iter()
4379            .position(|candidate| candidate == capability)
4380            .unwrap_or(usize::MAX)
4381    });
4382    capabilities
4383}
4384
4385fn speech_model_source_label(schema: &ModelSchema) -> String {
4386    match &schema.source {
4387        ModelSource::Mlx { hf_repo, .. } => format!("mlx:{hf_repo}"),
4388        ModelSource::Proprietary {
4389            provider, endpoint, ..
4390        } => format!("proprietary:{provider}:{endpoint}"),
4391        ModelSource::RemoteApi { endpoint, .. } => format!("remote:{endpoint}"),
4392        ModelSource::Local { hf_repo, .. } => format!("local:{hf_repo}"),
4393        ModelSource::VllmMlx {
4394            endpoint,
4395            model_name,
4396        } => format!("vllm-mlx:{endpoint}:{model_name}"),
4397        ModelSource::Ollama { model_tag, host } => format!("ollama:{host}:{model_tag}"),
4398        ModelSource::AppleFoundationModels { use_case } => {
4399            format!("apple-foundation:{}", use_case.as_deref().unwrap_or("default"))
4400        }
4401    }
4402}
4403
/// Assemble the Qwen3-Reranker chat-template prompt for one
/// `(query, document)` pair.
///
/// The prompt has three parts, mirroring upstream
/// `reranker_quick_start.py`:
/// 1. a system turn restricting the answer to yes/no,
/// 2. a user turn carrying `<Instruct>`, `<Query>` and `<Document>`,
/// 3. an assistant prefix with an already-closed, empty `<think>` block so
///    the model answers directly instead of reasoning first.
///
/// The three inputs are inserted verbatim; no escaping is applied.
fn rerank_prompt(instruction: &str, query: &str, document: &str) -> String {
    const SYSTEM: &str = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".";
    const ASSISTANT_PREFIX: &str = "<|im_start|>assistant\n<think>\n\n</think>\n\n";

    let system_turn = format!("<|im_start|>system\n{}<|im_end|>\n", SYSTEM);
    let user_turn = format!(
        "<|im_start|>user\n<Instruct>: {}\n<Query>: {}\n<Document>: {}<|im_end|>\n",
        instruction, query, document
    );

    [system_turn.as_str(), user_turn.as_str(), ASSISTANT_PREFIX].concat()
}
4419
4420/// Interpret the first useful token from a Qwen3-Reranker greedy
4421/// decode as a relevance score. Scans up to the first few tokens so
4422/// a leading space, BOS artifact, or stray newline doesn't poison the
4423/// result. Returns 1.0 for "yes", 0.0 for "no", 0.5 on unexpected
4424/// output (with a warn so the mismatch is visible).
4425fn score_from_rerank_output(text: &str, model_name: &str) -> f32 {
4426    // Replace every non-alphanumeric byte with a space, lowercase,
4427    // and scan the first few whitespace-separated tokens for
4428    // "yes"/"no". This strips chat-template tags (`<|im_end|>`),
4429    // punctuation, and underscores cleanly without special-casing.
4430    let normalized: String = text
4431        .to_ascii_lowercase()
4432        .chars()
4433        .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
4434        .collect();
4435    for tok in normalized.split_ascii_whitespace().take(5) {
4436        match tok {
4437            "yes" => return 1.0,
4438            "no" => return 0.0,
4439            _ => continue,
4440        }
4441    }
4442    tracing::warn!(
4443        model = %model_name,
4444        output = %text,
4445        "rerank: first tokens contain neither `yes` nor `no`; returning neutral 0.5"
4446    );
4447    0.5
4448}
4449
4450fn default_speech_voice(schema: &ModelSchema) -> Option<String> {
4451    if schema.provider == "elevenlabs" {
4452        Some("JBFqnCBsd6RMkjVDRZzb".to_string())
4453    } else if schema.name == "Kokoro-82M-6bit" || schema.name == "Kokoro-82M-bf16" {
4454        Some("af_heart".to_string())
4455    } else if schema.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
4456        Some("Chelsie".to_string())
4457    } else {
4458        None
4459    }
4460}
4461
4462fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
4463    find_latest_huggingface_snapshot(repo_id).is_some()
4464}
4465
/// Computes the local HuggingFace hub cache directory for `repo_id`.
///
/// The cache root is `$HF_HOME/hub` when `HF_HOME` is readable. Otherwise it
/// is `$HOME/.cache/huggingface/hub`, with `.` standing in for `HOME` when
/// that is unreadable too. The repository folder is `models--` followed by
/// the repo id with every `/` replaced by `--`.
///
/// Only the path is computed; nothing is created or checked on disk.
fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(explicit) => PathBuf::from(explicit),
        Err(_) => {
            let user_home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            user_home.join(".cache").join("huggingface")
        }
    };

    let repo_folder = format!("models--{}", repo_id.replace('/', "--"));
    hf_home.join("hub").join(repo_folder)
}
4479
4480fn find_latest_huggingface_snapshot(repo_id: &str) -> Option<PathBuf> {
4481    let snapshots = huggingface_repo_dir(repo_id).join("snapshots");
4482    std::fs::read_dir(snapshots)
4483        .ok()?
4484        .filter_map(Result::ok)
4485        .map(|entry| entry.path())
4486        .find(|path| path.is_dir() && snapshot_looks_ready(path))
4487}
4488
4489fn snapshot_looks_ready(path: &Path) -> bool {
4490    if path.join("config.json").exists() || path.join("model_index.json").exists() {
4491        return true;
4492    }
4493    snapshot_contains_ext(path, "safetensors")
4494}
4495
/// Returns `true` if any file under `root`, at any depth, has the extension
/// `ext` (compared ASCII case-insensitively, without the leading dot).
///
/// Directories that cannot be read and entries that fail to enumerate are
/// skipped rather than reported.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let listing = match std::fs::read_dir(root) {
        Ok(listing) => listing,
        Err(_) => return false,
    };

    for child in listing.flatten() {
        let child = child.path();
        let hit = if child.is_dir() {
            snapshot_contains_ext(&child, ext)
        } else {
            matches!(
                child.extension().and_then(|found| found.to_str()),
                Some(found) if found.eq_ignore_ascii_case(ext)
            )
        };
        if hit {
            return true;
        }
    }
    false
}
4512
/// Counts the regular files beneath `root`, descending into subdirectories.
///
/// Unreadable directories contribute zero, and entries that are neither a
/// directory nor a file are ignored.
fn count_files_recursive(root: &Path) -> usize {
    let mut total = 0usize;
    if let Ok(listing) = std::fs::read_dir(root) {
        for child in listing.flatten() {
            let child = child.path();
            if child.is_dir() {
                total += count_files_recursive(&child);
            } else if child.is_file() {
                total += 1;
            }
        }
    }
    total
}
4531
4532async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
4533    let api = hf_hub::api::tokio::ApiBuilder::from_env()
4534        .with_progress(false)
4535        .build()
4536        .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
4537    let repo = api.model(repo_id.to_string());
4538    let info = repo
4539        .info()
4540        .await
4541        .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
4542
4543    let snapshot_path = huggingface_repo_dir(repo_id)
4544        .join("snapshots")
4545        .join(&info.sha);
4546    let mut downloaded = 0usize;
4547    for sibling in &info.siblings {
4548        let local_path = snapshot_path.join(&sibling.rfilename);
4549        if local_path.exists() {
4550            downloaded += 1;
4551            continue;
4552        }
4553        repo.download(&sibling.rfilename).await.map_err(|e| {
4554            InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
4555        })?;
4556        downloaded += 1;
4557    }
4558
4559    Ok((snapshot_path, downloaded))
4560}
4561
4562fn temp_work_dir(prefix: &str) -> Result<PathBuf, InferenceError> {
4563    let unique = SystemTime::now()
4564        .duration_since(UNIX_EPOCH)
4565        .map_err(|e| InferenceError::InferenceFailed(format!("clock error: {e}")))?
4566        .as_nanos();
4567    let dir = std::env::temp_dir().join(format!("car-inference-{prefix}-{unique}"));
4568    std::fs::create_dir_all(&dir)?;
4569    Ok(dir)
4570}
4571
4572fn ensure_parent_dir(path: &Path) -> Result<(), InferenceError> {
4573    if let Some(parent) = path.parent() {
4574        std::fs::create_dir_all(parent)?;
4575    }
4576    Ok(())
4577}
4578
4579fn requested_or_temp_output(
4580    output_path: Option<&str>,
4581    format: &str,
4582) -> Result<PathBuf, InferenceError> {
4583    if let Some(path) = output_path {
4584        return Ok(PathBuf::from(path));
4585    }
4586    let dir = temp_work_dir("audio-out")?;
4587    Ok(dir.join(format!("speech.{format}")))
4588}
4589
4590fn requested_or_temp_media_output(
4591    output_path: Option<&str>,
4592    format: &str,
4593    stem: &str,
4594) -> Result<PathBuf, InferenceError> {
4595    if let Some(path) = output_path {
4596        return Ok(PathBuf::from(path));
4597    }
4598    let dir = temp_work_dir(&format!("{stem}-out"))?;
4599    Ok(dir.join(format!("{stem}.{format}")))
4600}
4601
4602fn materialize_audio_output(
4603    produced: &Path,
4604    requested: Option<&str>,
4605    format: &str,
4606) -> Result<PathBuf, InferenceError> {
4607    if let Some(path) = requested {
4608        let dest = PathBuf::from(path);
4609        ensure_parent_dir(&dest)?;
4610        std::fs::copy(produced, &dest)?;
4611        Ok(dest)
4612    } else {
4613        let dest = requested_or_temp_output(None, format)?;
4614        ensure_parent_dir(&dest)?;
4615        std::fs::copy(produced, &dest)?;
4616        Ok(dest)
4617    }
4618}
4619
4620fn materialize_binary_output(
4621    produced: &Path,
4622    requested: Option<&str>,
4623    format: &str,
4624    stem: &str,
4625) -> Result<PathBuf, InferenceError> {
4626    let dest = requested_or_temp_media_output(requested, format, stem)?;
4627    ensure_parent_dir(&dest)?;
4628    std::fs::copy(produced, &dest)?;
4629    Ok(dest)
4630}
4631
4632fn find_generated_file(
4633    root: &Path,
4634    extensions: &[&str],
4635) -> Result<Option<PathBuf>, InferenceError> {
4636    let entries = std::fs::read_dir(root)?;
4637    let mut candidates: Vec<PathBuf> = entries
4638        .filter_map(Result::ok)
4639        .map(|entry| entry.path())
4640        .filter(|path| {
4641            path.is_file()
4642                && path
4643                    .extension()
4644                    .and_then(|ext| ext.to_str())
4645                    .map(|ext| extensions.iter().any(|candidate| candidate.eq_ignore_ascii_case(ext)))
4646                    .unwrap_or(false)
4647        })
4648        .collect();
4649    candidates.sort();
4650    Ok(candidates.pop())
4651}
4652
/// Maps an image format name (ASCII case-insensitive) to its media type.
///
/// `jpg`/`jpeg` and `webp` are recognised; anything else is reported as
/// `image/png`.
fn media_type_for_image_format(format: &str) -> String {
    let is = |name: &str| format.eq_ignore_ascii_case(name);

    let mime = if is("jpg") || is("jpeg") {
        "image/jpeg"
    } else if is("webp") {
        "image/webp"
    } else {
        "image/png"
    };
    mime.to_owned()
}
4660
/// Maps a video format name (ASCII case-insensitive) to its media type.
///
/// `mov` and `gif` are recognised (the latter as `image/gif`); anything else
/// is reported as `video/mp4`.
fn media_type_for_video_format(format: &str) -> String {
    const KNOWN: [(&str, &str); 2] = [("mov", "video/quicktime"), ("gif", "image/gif")];

    let mime = KNOWN
        .iter()
        .find(|entry| entry.0.eq_ignore_ascii_case(format))
        .map_or("video/mp4", |entry| entry.1);
    mime.to_owned()
}
4668
4669fn read_transcription_result(output_prefix: &Path) -> Result<Option<String>, InferenceError> {
4670    let candidates = [
4671        output_prefix.with_extension("json"),
4672        output_prefix.to_path_buf(),
4673    ];
4674
4675    for path in candidates {
4676        if path.exists() {
4677            let contents = std::fs::read_to_string(path)?;
4678            if let Some(text) = extract_text_from_payload(&contents) {
4679                return Ok(Some(text));
4680            }
4681        }
4682    }
4683
4684    Ok(None)
4685}
4686
4687fn extract_text_from_payload(payload: &str) -> Option<String> {
4688    let value: serde_json::Value = serde_json::from_str(payload).ok()?;
4689    if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
4690        return Some(text.to_string());
4691    }
4692    if let Some(transcripts) = value.get("transcripts").and_then(|v| v.as_array()) {
4693        let joined = transcripts
4694            .iter()
4695            .filter_map(|item| item.get("text").and_then(|v| v.as_str()))
4696            .collect::<Vec<_>>()
4697            .join("\n");
4698        if !joined.is_empty() {
4699            return Some(joined);
4700        }
4701    }
4702    if let Some(items) = value.as_array() {
4703        let joined = items
4704            .iter()
4705            .filter_map(|item| {
4706                item.get("text")
4707                    .or_else(|| item.get("Content"))
4708                    .and_then(|v| v.as_str())
4709            })
4710            .collect::<Vec<_>>()
4711            .join(" ");
4712        if !joined.is_empty() {
4713            return Some(joined);
4714        }
4715    }
4716    None
4717}
4718
4719fn find_audio_file(output_dir: &Path) -> Result<Option<PathBuf>, InferenceError> {
4720    let mut audio_files = Vec::new();
4721    collect_audio_files(output_dir, &mut audio_files)?;
4722    audio_files.sort();
4723    Ok(audio_files.into_iter().next())
4724}
4725
4726fn collect_audio_files(dir: &Path, audio_files: &mut Vec<PathBuf>) -> Result<(), InferenceError> {
4727    for entry in std::fs::read_dir(dir)? {
4728        let path = entry?.path();
4729        if path.is_dir() {
4730            collect_audio_files(&path, audio_files)?;
4731        } else if matches!(
4732            path.extension().and_then(|ext| ext.to_str()),
4733            Some("wav" | "mp3" | "flac" | "pcm" | "m4a")
4734        ) {
4735            audio_files.push(path);
4736        }
4737    }
4738    Ok(())
4739}
4740
/// Maps an audio format name (ASCII case-insensitive) to its media type.
///
/// `mp3`, `flac`, `pcm` and `m4a` are recognised; anything else is reported
/// as `audio/wav`.
fn media_type_for_format(format: &str) -> String {
    const KNOWN: [(&str, &str); 4] = [
        ("mp3", "audio/mpeg"),
        ("flac", "audio/flac"),
        ("pcm", "audio/L16"),
        ("m4a", "audio/mp4"),
    ];

    let mime = KNOWN
        .iter()
        .find(|entry| entry.0.eq_ignore_ascii_case(format))
        .map_or("audio/wav", |entry| entry.1);
    mime.to_string()
}
4750
/// Translates a language name or tag into the single-letter code returned
/// for Kokoro.
///
/// Matching is ASCII case-insensitive. A missing language is treated as
/// `en`, and anything unrecognised maps to `"a"`.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn kokoro_lang_code(language: Option<&str>) -> &'static str {
    const ALIASES: [(&[&str], &str); 5] = [
        (&["en-gb", "british", "british english"], "b"),
        (&["ja", "japanese"], "j"),
        (&["zh", "zh-cn", "mandarin", "chinese"], "z"),
        (&["es", "spanish"], "e"),
        (&["fr", "french"], "f"),
    ];

    let wanted = language.unwrap_or("en");
    ALIASES
        .iter()
        .find(|entry| entry.0.iter().any(|alias| alias.eq_ignore_ascii_case(wanted)))
        .map_or("a", |entry| entry.1)
}
4762
/// Normalizes a language name or tag to one of the two-letter codes
/// `en`, `es`, `fr`, `ja`, `zh`.
///
/// Matching is ASCII case-insensitive. English names (`"spanish"`,
/// `"mandarin"`, ...) are mapped first. Otherwise the primary subtag (the
/// part before the first `-` or `_`) is used, so region-qualified tags
/// such as `zh-CN`, `fr_FR` or `en-US` resolve to their base language.
/// Unsupported or unrecognised input falls back to `"en"`.
fn normalize_lang_code(language: &str) -> String {
    const SUPPORTED: [&str; 5] = ["en", "es", "fr", "ja", "zh"];

    let lowered = language.to_ascii_lowercase();
    let code = match lowered.as_str() {
        "english" => "en",
        "spanish" => "es",
        "french" => "fr",
        "japanese" => "ja",
        "chinese" | "mandarin" => "zh",
        tag => {
            // `split` always yields at least one piece, so `next()` is the
            // whole tag when no separator is present.
            let primary = tag
                .split(|c: char| c == '-' || c == '_')
                .next()
                .unwrap_or(tag);
            SUPPORTED
                .iter()
                .copied()
                .find(|supported| *supported == primary)
                .unwrap_or("en")
        }
    };
    code.to_string()
}
4776
4777fn elevenlabs_auth(schema: &ModelSchema) -> Result<(String, String), InferenceError> {
4778    match &schema.source {
4779        ModelSource::Proprietary {
4780            endpoint,
4781            auth: schema::ProprietaryAuth::ApiKeyEnv { env_var },
4782            ..
4783        } => {
4784            let key = car_secrets::resolve_env_or_keychain(env_var).ok_or_else(|| {
4785                InferenceError::InferenceFailed(format!(
4786                    "missing API key {env_var}; set the environment variable or \
4787                     store it with `car secrets put {env_var}`"
4788                ))
4789            })?;
4790            Ok((endpoint.clone(), key))
4791        }
4792        _ => Err(InferenceError::InferenceFailed(format!(
4793            "model {} is not an ElevenLabs proprietary model",
4794            schema.id
4795        ))),
4796    }
4797}
4798
/// Maps a generic audio format name (ASCII case-insensitive) to the output
/// format identifier used for ElevenLabs requests.
///
/// `mp3` and `pcm` have dedicated identifiers; everything else yields
/// `wav_44100`.
fn elevenlabs_output_format(format: &str) -> &'static str {
    if format.eq_ignore_ascii_case("mp3") {
        "mp3_44100_128"
    } else if format.eq_ignore_ascii_case("pcm") {
        "pcm_16000"
    } else {
        "wav_44100"
    }
}
4806
/// Lists the candidate locations of `benchmark_priors.json`, without
/// duplicates, in this order:
///
/// 1. inside `models_dir`,
/// 2. inside the parent of `models_dir` (when it has one),
/// 3. the path named by `CAR_BENCHMARK_PRIORS_PATH` (when set).
///
/// The paths are not checked for existence.
fn benchmark_priors_paths(models_dir: &Path) -> Vec<PathBuf> {
    const FILE_NAME: &str = "benchmark_priors.json";

    let candidates = [
        Some(models_dir.join(FILE_NAME)),
        models_dir.parent().map(|parent| parent.join(FILE_NAME)),
        std::env::var_os("CAR_BENCHMARK_PRIORS_PATH").map(PathBuf::from),
    ];

    let mut paths: Vec<PathBuf> = Vec::new();
    for candidate in candidates {
        let Some(candidate) = candidate else {
            continue;
        };
        if !paths.contains(&candidate) {
            paths.push(candidate);
        }
    }
    paths
}
4831
4832fn load_benchmark_prior_health(
4833    models_dir: &Path,
4834    schemas: &[ModelSchema],
4835) -> Vec<ModelBenchmarkPriorHealth> {
4836    let mut priors = std::collections::BTreeMap::new();
4837    for path in benchmark_priors_paths(models_dir) {
4838        let Ok(loaded) = routing_ext::load_benchmark_priors(&path) else {
4839            continue;
4840        };
4841        for (model_id, prior) in loaded {
4842            let model_name = schemas
4843                .iter()
4844                .find(|schema| schema.id == model_id)
4845                .map(|schema| schema.name.clone());
4846            priors.insert(
4847                model_id.clone(),
4848                ModelBenchmarkPriorHealth {
4849                    model_id,
4850                    model_name,
4851                    overall_score: prior.overall_score,
4852                    overall_latency_ms: prior.overall_latency_ms,
4853                    task_scores: prior.task_scores,
4854                    task_latency_ms: prior.task_latency_ms,
4855                    source_path: path.clone(),
4856                },
4857            );
4858        }
4859    }
4860
4861    priors.into_values().collect()
4862}
4863
/// Reads the `CAR_SPEECH_KOKORO_FALLBACK` switch.
///
/// The fallback is enabled unless the variable is readable and its trimmed,
/// lowercased value is `0`, `false` or `off`.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn kokoro_runtime_fallback_enabled() -> bool {
    match std::env::var("CAR_SPEECH_KOKORO_FALLBACK") {
        Ok(raw) => {
            let flag = raw.trim().to_ascii_lowercase();
            let disabled = flag == "0" || flag == "false" || flag == "off";
            !disabled
        }
        Err(_) => true,
    }
}
4871
/// Returns the package spec for `mlx-audio`.
///
/// `CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC` overrides the default when it is
/// readable and not blank; the override is returned untrimmed.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn speech_runtime_mlx_audio_spec() -> String {
    match std::env::var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => "mlx-audio==0.4.2".to_string(),
    }
}
4879
/// Returns the package spec for the spaCy English model.
///
/// `CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC` overrides the default when it is
/// readable and not blank; the override is returned untrimmed. The default
/// pins the `en_core_web_sm` 3.8.0 wheel on GitHub.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn speech_runtime_spacy_model_spec() -> String {
    const DEFAULT_SPEC: &str = "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl";

    match std::env::var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => DEFAULT_SPEC.to_string(),
    }
}
4889
4890#[cfg(test)]
4891mod tests {
4892    use super::*;
4893    use tempfile::TempDir;
4894
    /// Serializes tests that touch process-wide environment variables.
    ///
    /// Environment variables are global mutable state and the test harness
    /// runs tests in parallel, so every test that sets or removes a variable
    /// must hold this lock for as long as it depends on the value. Bind the
    /// guard to a named local (for example `_env`) so it lives until the end
    /// of the test.
    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());
4898
4899    fn test_config(models_dir: PathBuf) -> InferenceConfig {
4900        InferenceConfig {
4901            models_dir,
4902            device: None,
4903            generation_model: "Qwen3-0.6B".into(),
4904            preferred_generation_model: None,
4905            embedding_model: "Qwen3-Embedding-0.6B".into(),
4906            preferred_embedding_model: None,
4907            classification_model: "Qwen3-0.6B".into(),
4908            preferred_classification_model: None,
4909        }
4910    }
4911
4912    #[tokio::test]
4913    async fn tokenize_rejects_known_remote_model_with_unsupported_mode() {
4914        // The unified registry's built-in catalog includes remote models like
4915        // OpenAI / Anthropic ones. Regardless of which exact id ships, we just
4916        // need any non-local schema to confirm the pre-flight catches it
4917        // before we try (and fail) to load a non-existent local backend.
4918        let tmp = TempDir::new().unwrap();
4919        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
4920        let remote_id = engine
4921            .list_schemas()
4922            .into_iter()
4923            .find(|s| !s.is_local())
4924            .map(|s| s.id)
4925            .expect(
4926                "built-in catalog should include at least one remote model schema",
4927            );
4928
4929        let err = engine
4930            .tokenize(&remote_id, "hello")
4931            .await
4932            .expect_err("remote tokenize must error");
4933        match err {
4934            InferenceError::UnsupportedMode { mode, backend, .. } => {
4935                assert_eq!(mode, "tokenize/detokenize");
4936                assert_eq!(backend, "remote");
4937            }
4938            other => panic!("expected UnsupportedMode, got {other:?}"),
4939        }
4940
4941        let err = engine
4942            .detokenize(&remote_id, &[1, 2, 3])
4943            .await
4944            .expect_err("remote detokenize must error");
4945        assert!(
4946            matches!(err, InferenceError::UnsupportedMode { .. }),
4947            "expected UnsupportedMode, got {err:?}"
4948        );
4949    }
4950
4951    #[test]
4952    fn engine_loads_benchmark_priors_on_startup() {
4953        let _env = ENV_MUTEX.lock().unwrap();
4954        let tmp = TempDir::new().unwrap();
4955        let priors_path = tmp.path().join("benchmark_priors.json");
4956        std::fs::write(
4957            &priors_path,
4958            serde_json::json!({
4959                "model_id": "qwen/qwen3-8b:q4_k_m",
4960                "overall_score": 0.88
4961            })
4962            .to_string(),
4963        )
4964        .unwrap();
4965
4966        unsafe {
4967            std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
4968        }
4969
4970        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
4971        let tracker = engine.outcome_tracker.blocking_read();
4972        let profile = tracker
4973            .profile("qwen/qwen3-8b:q4_k_m")
4974            .expect("benchmark prior should create a profile");
4975        assert!((profile.ema_quality - 0.88).abs() < 0.01);
4976
4977        unsafe {
4978            std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
4979        }
4980    }
4981
4982    #[test]
4983    fn benchmark_priors_do_not_override_observed_profiles() {
4984        let _env = ENV_MUTEX.lock().unwrap();
4985        let tmp = TempDir::new().unwrap();
4986        let models_dir = tmp.path().join("models");
4987        std::fs::create_dir_all(&models_dir).unwrap();
4988
4989        let observed = vec![ModelProfile {
4990            model_id: "qwen/qwen3-8b:q4_k_m".into(),
4991            total_calls: 12,
4992            success_count: 3,
4993            fail_count: 9,
4994            total_latency_ms: 1200,
4995            total_input_tokens: 0,
4996            total_output_tokens: 0,
4997            task_stats: std::collections::HashMap::new(),
4998            ema_quality: 0.21,
4999            quality_per_1k_tokens: 0.0,
5000            updated_at: 1,
5001        }];
5002        std::fs::write(
5003            models_dir.join("outcome_profiles.json"),
5004            serde_json::to_string(&observed).unwrap(),
5005        )
5006        .unwrap();
5007
5008        let priors_path = tmp.path().join("benchmark_priors.json");
5009        std::fs::write(
5010            &priors_path,
5011            serde_json::json!({
5012                "model_id": "qwen/qwen3-8b:q4_k_m",
5013                "overall_score": 0.95
5014            })
5015            .to_string(),
5016        )
5017        .unwrap();
5018
5019        unsafe {
5020            std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
5021        }
5022
5023        let engine = InferenceEngine::new(test_config(models_dir));
5024        let tracker = engine.outcome_tracker.blocking_read();
5025        let profile = tracker
5026            .profile("qwen/qwen3-8b:q4_k_m")
5027            .expect("observed profile should remain present");
5028        assert!((profile.ema_quality - 0.21).abs() < 0.01);
5029        assert_eq!(profile.total_calls, 12);
5030
5031        unsafe {
5032            std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
5033        }
5034    }
5035
5036    #[test]
5037    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
5038    fn speech_runtime_package_spec_defaults_and_overrides() {
5039        let _env = ENV_MUTEX.lock().unwrap();
5040        unsafe {
5041            std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
5042        }
5043        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.2");
5044
5045        unsafe {
5046            std::env::set_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC", "mlx-audio==0.4.1");
5047        }
5048        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.1");
5049
5050        unsafe {
5051            std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
5052        }
5053    }
5054
5055    #[test]
5056    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
5057    fn speech_runtime_spacy_model_spec_defaults_and_overrides() {
5058        let _env = ENV_MUTEX.lock().unwrap();
5059        unsafe {
5060            std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
5061        }
5062        assert!(
5063            speech_runtime_spacy_model_spec().starts_with("en-core-web-sm @ https://github.com/")
5064        );
5065
5066        unsafe {
5067            std::env::set_var(
5068                "CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC",
5069                "en-core-web-sm==3.8.0",
5070            );
5071        }
5072        assert_eq!(
5073            speech_runtime_spacy_model_spec(),
5074            "en-core-web-sm==3.8.0"
5075        );
5076
5077        unsafe {
5078            std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
5079        }
5080    }
5081
5082    #[test]
5083    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
5084    fn kokoro_runtime_fallback_defaults_on() {
5085        unsafe {
5086            std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
5087        }
5088        assert!(kokoro_runtime_fallback_enabled());
5089
5090        unsafe {
5091            std::env::set_var("CAR_SPEECH_KOKORO_FALLBACK", "false");
5092        }
5093        assert!(!kokoro_runtime_fallback_enabled());
5094
5095        unsafe {
5096            std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
5097        }
5098    }
5099
5100    #[test]
5101    fn preferred_local_tts_wins_over_builtin_rank() {
5102        let tmp = TempDir::new().unwrap();
5103        let mut engine = InferenceEngine::new(test_config(tmp.path().join("models")));
5104        engine.set_speech_policy(SpeechPolicy {
5105            prefer_local: true,
5106            allow_remote_fallback: false,
5107            preferred_local_stt: None,
5108            preferred_local_tts: Some("Kokoro-82M-6bit".into()),
5109            preferred_remote_stt: None,
5110            preferred_remote_tts: None,
5111        });
5112
5113        let schema = engine
5114            .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
5115            .expect("preferred local TTS should resolve");
5116        assert_eq!(schema.name, "Kokoro-82M-6bit");
5117    }
5118
5119    #[test]
5120    fn preferred_discovered_vllm_mlx_model_wins_generate_routing() {
5121        let tmp = TempDir::new().unwrap();
5122        let mut config = test_config(tmp.path().join("models"));
5123        config.preferred_generation_model =
5124            Some("vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit".into());
5125        let mut engine = InferenceEngine::new(config);
5126        let schema = crate::vllm_mlx::to_model_schema(
5127            &crate::vllm_mlx::DiscoveredModel {
5128                id: "mlx-community/gemma-3n-E2B-it-lm-4bit".into(),
5129                owned_by: Some("mlx-community".into()),
5130            },
5131            "http://127.0.0.1:8001",
5132        );
5133        engine.register_model(schema);
5134
5135        let rt = tokio::runtime::Runtime::new().unwrap();
5136        let decision = rt.block_on(engine.route_adaptive("say hello in one sentence"));
5137        assert_eq!(decision.model_id, "vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit");
5138        assert_eq!(decision.strategy, RoutingStrategy::Explicit);
5139        assert_eq!(decision.reason, "preferred generation model override");
5140    }
5141
5142    /// Issue #43 — InferenceResult must serialize with all fields preserved
5143    /// (text, tool_calls, trace_id, model_used, latency_ms, usage) using
5144    /// snake_case field names. The car-server WebSocket handler relies on
5145    /// `serde_json::to_value(&InferenceResult)` producing this exact shape.
5146    #[test]
5147    fn inference_result_serializes_with_full_shape() {
5148        use crate::tasks::generate::ToolCall;
5149        use std::collections::HashMap;
5150
5151        let mut args = HashMap::new();
5152        args.insert("path".to_string(), serde_json::json!("README.md"));
5153
5154        let result = InferenceResult {
5155            text: String::new(),
5156            bounding_boxes: Vec::new(),
5157            tool_calls: vec![ToolCall {
5158                id: None,
5159                name: "read_file".into(),
5160                arguments: args,
5161            }],
5162            trace_id: "trace-abc".into(),
5163            model_used: "test-model".into(),
5164            latency_ms: 1234,
5165            time_to_first_token_ms: Some(180),
5166            usage: Some(TokenUsage {
5167                prompt_tokens: 100,
5168                completion_tokens: 50,
5169                total_tokens: 150,
5170                context_window: 8192,
5171            }),
5172        };
5173
5174        let json = serde_json::to_value(&result).expect("serialize");
5175
5176        // Required snake_case fields with type-strict assertions
5177        assert_eq!(json["text"].as_str(), Some(""));
5178        assert_eq!(json["trace_id"].as_str(), Some("trace-abc"));
5179        assert_eq!(json["model_used"].as_str(), Some("test-model"));
5180        assert_eq!(json["latency_ms"].as_u64(), Some(1234));
5181
5182        // tool_calls is a non-empty array with name + arguments
5183        let tool_calls = json["tool_calls"].as_array().expect("tool_calls array");
5184        assert_eq!(tool_calls.len(), 1);
5185        assert_eq!(tool_calls[0]["name"].as_str(), Some("read_file"));
5186        assert_eq!(tool_calls[0]["arguments"]["path"].as_str(), Some("README.md"));
5187
5188        // usage is an object with all four documented fields
5189        let usage = &json["usage"];
5190        assert_eq!(usage["prompt_tokens"].as_u64(), Some(100));
5191        assert_eq!(usage["completion_tokens"].as_u64(), Some(50));
5192        assert_eq!(usage["total_tokens"].as_u64(), Some(150));
5193        assert_eq!(usage["context_window"].as_u64(), Some(8192));
5194
5195        // TTFT propagates through serialization when populated.
5196        assert_eq!(json["time_to_first_token_ms"].as_u64(), Some(180));
5197    }
5198
5199    /// Issue #43 — Lock the top-level WebSocket `infer` response contract.
5200    /// If a future change adds a field to `InferenceResult`, this test forces
5201    /// the developer to deliberately update the protocol surface and the
5202    /// expected key set here, rather than silently leaking new fields onto
5203    /// the wire.
5204    #[test]
5205    fn inference_result_top_level_keys_are_locked() {
5206        use std::collections::BTreeSet;
5207
5208        let result = InferenceResult {
5209            text: "anything".into(),
5210            bounding_boxes: Vec::new(),
5211            tool_calls: vec![],
5212            trace_id: "t".into(),
5213            model_used: "m".into(),
5214            latency_ms: 0,
5215            time_to_first_token_ms: None,
5216            usage: None,
5217        };
5218
5219        let json = serde_json::to_value(&result).expect("serialize");
5220        let keys: BTreeSet<&str> = json
5221            .as_object()
5222            .expect("top-level object")
5223            .keys()
5224            .map(String::as_str)
5225            .collect();
5226
5227        let expected: BTreeSet<&str> = [
5228            "text",
5229            "tool_calls",
5230            "trace_id",
5231            "model_used",
5232            "latency_ms",
5233            "time_to_first_token_ms",
5234            "usage",
5235        ]
5236        .into_iter()
5237        .collect();
5238
5239        assert_eq!(
5240            keys, expected,
5241            "infer response top-level keys drifted -- update both the test \
5242             and the WebSocket protocol documentation if this is intentional"
5243        );
5244
5245        // All keys are snake_case (constraint c-2 in outcome 043).
5246        for key in &keys {
5247            assert!(
5248                !key.chars().any(|c| c.is_uppercase()) && !key.contains('-'),
5249                "key '{}' is not snake_case",
5250                key
5251            );
5252        }
5253    }
5254
5255    /// Plain text result (no tools) must still serialize cleanly with text
5256    /// populated and tool_calls present as an empty array. Backward compat
5257    /// for clients that only care about `.text`.
5258    #[test]
5259    fn inference_result_serializes_plain_text_response() {
5260        let result = InferenceResult {
5261            text: "hello world".into(),
5262            bounding_boxes: Vec::new(),
5263            tool_calls: vec![],
5264            trace_id: "trace-xyz".into(),
5265            model_used: "test-model".into(),
5266            latency_ms: 42,
5267            time_to_first_token_ms: None,
5268            usage: None,
5269        };
5270
5271        let json = serde_json::to_value(&result).expect("serialize");
5272        assert_eq!(json["text"], "hello world");
5273        assert!(json["tool_calls"].is_array());
5274        assert_eq!(json["tool_calls"].as_array().unwrap().len(), 0);
5275        assert_eq!(json["model_used"], "test-model");
5276        assert!(json["usage"].is_null());
5277        // Honest "wasn't measured" rather than missing key — the field
5278        // is always present at the protocol surface.
5279        assert!(json["time_to_first_token_ms"].is_null());
5280    }
5281
5282    #[test]
5283    fn rerank_prompt_matches_upstream_template_shape() {
5284        let p = rerank_prompt("retrieve relevant passages", "who runs the treasury?", "doc x");
5285        assert!(p.contains("<|im_start|>system"));
5286        assert!(p.contains("Note that the answer can only be \"yes\" or \"no\"."));
5287        assert!(p.contains("<|im_start|>user\n<Instruct>: retrieve relevant passages"));
5288        assert!(p.contains("<Query>: who runs the treasury?"));
5289        assert!(p.contains("<Document>: doc x<|im_end|>"));
5290        assert!(p.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
5291    }
5292
5293    #[test]
5294    fn rerank_score_yes_and_no_exactly() {
5295        assert_eq!(score_from_rerank_output("yes", "m"), 1.0);
5296        assert_eq!(score_from_rerank_output("no", "m"), 0.0);
5297    }
5298
5299    #[test]
5300    fn rerank_score_handles_case_leading_space_and_chat_sentinels() {
5301        // Real decodes often include leading whitespace, punctuation,
5302        // or chat-template sentinels around the answer token.
5303        assert_eq!(score_from_rerank_output(" Yes", "m"), 1.0);
5304        assert_eq!(score_from_rerank_output("\nno.", "m"), 0.0);
5305        assert_eq!(score_from_rerank_output("<|im_end|>yes", "m"), 1.0);
5306    }
5307
5308    #[test]
5309    fn rerank_score_scans_up_to_three_tokens() {
5310        // Tokenizer artifacts can produce a BOS-like leading token
5311        // before the real answer. Don't miss it.
5312        assert_eq!(score_from_rerank_output("_bos_ yes", "m"), 1.0);
5313    }
5314
5315    #[test]
5316    fn rerank_score_unexpected_is_neutral() {
5317        // Plain-base models that aren't reranker-fine-tuned will emit
5318        // arbitrary completion tokens. Don't partition; go neutral.
5319        assert_eq!(score_from_rerank_output("maybe", "m"), 0.5);
5320        assert_eq!(score_from_rerank_output("", "m"), 0.5);
5321    }
5322}