1pub mod adaptive_router;
26pub mod backend;
27pub mod backend_cache;
28pub mod handle;
29pub mod hardware;
30pub mod intent;
31pub mod key_pool;
32pub mod models;
33pub mod outcome;
34pub mod protocol;
35pub mod registry;
36pub mod remote;
37pub mod router;
38pub mod runner;
39pub mod routing_ext;
40pub mod schema;
41pub mod service;
42pub mod stream;
43pub mod tasks;
44pub mod vllm_mlx;
45
46use std::path::{Path, PathBuf};
47use std::sync::Arc;
48use std::time::Instant;
49use std::time::{SystemTime, UNIX_EPOCH};
50
51use serde::Serialize;
52use reqwest::multipart::{Form, Part};
53use thiserror::Error;
54#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
55use tokio::process::Command;
56#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
57use tokio::sync::Mutex;
58use tokio::sync::RwLock;
59use tracing::{debug, instrument};
60
61pub use adaptive_router::{
63 AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy,
64};
65pub use handle::InferenceHandle;
66pub use intent::{IntentHint, TaskHint};
67pub use key_pool::{KeyPool, KeyStats};
68pub use outcome::{
69 CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
70};
71pub use registry::{ModelFilter, ModelInfo, ModelRuntimeRequirement, ModelUpgrade, UnifiedRegistry};
72pub use remote::RemoteBackend;
73pub use runner::{
74 current_inference_runner, set_inference_runner, EventEmitter, InferenceRunner, RunnerError,
75 RunnerResult,
76};
77pub use routing_ext::{
78 CircuitBreaker, CircuitBreakerRegistry, CircuitState, ImplicitSignal, ImplicitSignalType,
79 RoutingMode, SpendControl, SpendLimitExceeded, SpendLimits, SpendStatus,
80};
81pub use schema::{
82 ApiProtocol, BenchmarkScore, CostModel, ModelCapability, ModelSchema, ModelSource,
83 PerformanceEnvelope, ProprietaryAuth,
84};
85
86pub use adaptive_router::TaskComplexity;
88#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
89pub use backend::CandleBackend;
90#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
91pub use backend::EmbeddingBackend;
92#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
93pub use backend::MlxBackend;
94pub use hardware::HardwareInfo;
95pub use models::{ModelRegistry, ModelRole};
96pub use router::{ModelRouter, RoutingDecision};
97pub use stream::{StreamAccumulator, StreamEvent};
98pub use tasks::{
99 parse_boxes, BoundingBox, ClassifyRequest, ClassifyResult, ContentBlock, EmbedRequest,
100 GenerateImageRequest, GenerateImageResult, GenerateParams, GenerateRequest,
101 GenerateVideoRequest, GenerateVideoResult, GroundRequest, GroundResult, Message, RerankRequest,
102 RerankResult, RerankedDocument, RoutingWorkload, SynthesizeRequest, SynthesizeResult,
103 ThinkingMode, ToolCall, TranscribeRequest, TranscribeResult, VideoMode,
104};
/// Errors surfaced by the inference layer.
///
/// The `Display` text of every variant is generated by `thiserror` from the
/// `#[error(...)]` attribute attached to it.
#[derive(Error, Debug)]
pub enum InferenceError {
    /// A model could not be resolved; the payload is interpolated into the message.
    #[error("model not found: {0}")]
    ModelNotFound(String),

    /// Downloading a model failed; the payload carries the detail text.
    #[error("model download failed: {0}")]
    DownloadFailed(String),

    /// Generic inference failure; the payload carries the detail text.
    #[error("inference failed: {0}")]
    InferenceFailed(String),

    /// A requested mode is not implemented by a particular backend.
    ///
    /// All three fields are `&'static str`, so they must be compile-time
    /// strings rather than values built at runtime.
    #[error("mode {mode} not implemented on backend {backend}: {reason}")]
    UnsupportedMode {
        mode: &'static str,
        backend: &'static str,
        reason: &'static str,
    },

    /// Tokenization failed; the payload carries the detail text.
    #[error("tokenization error: {0}")]
    TokenizationError(String),

    /// A compute-device problem; the payload carries the detail text.
    #[error("device error: {0}")]
    DeviceError(String),

    /// Wraps a `std::io::Error`. `#[from]` generates the `From` impl, so `?`
    /// converts I/O errors into this variant automatically.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}
138
/// Compute device an inference backend can be asked to run on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    /// Plain CPU execution.
    Cpu,
    /// Apple Metal.
    Metal,
    /// CUDA; the payload is presumably the device ordinal (`auto` picks `0`).
    Cuda(usize),
}

impl Device {
    /// Picks a device from the crate features enabled at compile time.
    ///
    /// Precedence is `metal`, then `cuda`, then CPU when neither feature is
    /// enabled. The decision is made entirely at compile time; nothing is
    /// probed at runtime.
    pub fn auto() -> Self {
        if cfg!(feature = "metal") {
            Device::Metal
        } else if cfg!(feature = "cuda") {
            Device::Cuda(0)
        } else {
            Device::Cpu
        }
    }
}
164
/// Static configuration consumed by `InferenceEngine::new`.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Directory handed to the model registries; the engine also reads
    /// `outcome_profiles.json` and `key_pool_stats.json` from it.
    pub models_dir: std::path::PathBuf,
    /// Explicit device choice. When `None`, backend loading falls back to
    /// `Device::auto()`.
    pub device: Option<Device>,
    /// Generation model name; `Default` fills it from the detected hardware's
    /// recommended model.
    pub generation_model: String,
    /// Optional override that, when set, is used for generation instead of routing.
    pub preferred_generation_model: Option<String>,
    /// Embedding model name used when no preferred embedding model is set.
    pub embedding_model: String,
    /// Optional override that takes precedence over `embedding_model`.
    pub preferred_embedding_model: Option<String>,
    /// Classification model name.
    pub classification_model: String,
    /// Optional override reported for the `Classify` capability.
    pub preferred_classification_model: Option<String>,
}
185
186impl Default for InferenceConfig {
187 fn default() -> Self {
188 let models_dir = dirs_next()
189 .unwrap_or_else(|| std::path::PathBuf::from("."))
190 .join(".car")
191 .join("models");
192
193 let hw = HardwareInfo::detect();
194
195 Self {
196 models_dir,
197 device: None,
198 generation_model: hw.recommended_model,
199 preferred_generation_model: None,
200 embedding_model: "Qwen3-Embedding-0.6B".to_string(),
201 preferred_embedding_model: None,
202 classification_model: "Qwen3-0.6B".to_string(),
203 preferred_classification_model: None,
204 }
205 }
206}
207
/// Returns the current user's home directory, if one can be determined from
/// the environment.
///
/// `HOME` is consulted first and `USERPROFILE` second, so the lookup also
/// works on Windows where `HOME` is usually unset. Values are read with
/// `var_os`, which keeps home paths that are not valid Unicode instead of
/// discarding them. Empty values are treated as unset.
///
/// Returns `None` when neither variable holds a non-empty value.
fn dirs_next() -> Option<std::path::PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|key| std::env::var_os(key))
        .find(|value| !value.is_empty())
        .map(std::path::PathBuf::from)
}
211
/// Token accounting attached to an inference result.
///
/// `Default` yields all-zero counters.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u64,
    /// Tokens produced in the completion.
    pub completion_tokens: u64,
    /// Total token count (presumably prompt + completion; not enforced here).
    pub total_tokens: u64,
    /// Context window size of the model — NOTE(review): units assumed to be tokens; confirm.
    pub context_window: u64,
}
224
/// Outcome of a single inference call, serializable with serde.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InferenceResult {
    /// Generated text.
    pub text: String,
    /// Tool calls produced by the model; see `has_tool_calls`.
    pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
    /// Bounding boxes attached to the result. Defaults to empty when missing
    /// on deserialize and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub bounding_boxes: Vec<crate::tasks::grounding::BoundingBox>,
    /// Identifier tying this result to an outcome-tracker trace.
    pub trace_id: String,
    /// Name of the model that served the request.
    pub model_used: String,
    /// Latency of the call in milliseconds.
    pub latency_ms: u64,
    /// Time to first token in milliseconds, when known. Defaults to `None`
    /// when the field is missing on deserialize.
    #[serde(default)]
    pub time_to_first_token_ms: Option<u64>,
    /// Token usage, when the backend reported it.
    pub usage: Option<TokenUsage>,
    /// Raw provider output items as JSON values. Defaults to empty when
    /// missing on deserialize and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub provider_output_items: Vec<serde_json::Value>,
}
276
277impl InferenceResult {
278 pub fn has_tool_calls(&self) -> bool {
280 !self.tool_calls.is_empty()
281 }
282}
283
/// Serializable snapshot describing the speech runtime installation.
///
/// NOTE(review): field meanings below are read from the names only; the code
/// that populates this struct is outside this chunk.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechRuntimeHealth {
    /// Root directory of the runtime.
    pub root: PathBuf,
    /// Whether the runtime is considered installed.
    pub installed: bool,
    /// Path of the Python interpreter associated with the runtime.
    pub python: PathBuf,
    /// Path of the speech-to-text command.
    pub stt_command: PathBuf,
    /// Path of the text-to-speech command.
    pub tts_command: PathBuf,
    /// Python configured by the user, if any.
    pub configured_python: Option<String>,
    /// Python found by detection, if any.
    pub detected_python: Option<String>,
}
294
/// Serializable health entry for a single speech model.
///
/// NOTE(review): populated outside this chunk; field meanings are taken from
/// their names.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechModelHealth {
    /// Model identifier.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Provider the model belongs to.
    pub provider: String,
    /// Capability this entry reports on.
    pub capability: ModelCapability,
    /// Whether the model runs locally.
    pub is_local: bool,
    /// Whether the model is currently usable.
    pub available: bool,
    /// Whether the model is cached — presumably on local disk; confirm.
    pub cached: bool,
    /// Whether this model is the default selection.
    pub selected_by_default: bool,
    /// Free-form description of the model's source.
    pub source: String,
}
307
/// Serializable report combining speech runtime state, per-model health and
/// the speech policy preferences.
///
/// The six policy-shaped fields (`prefer_local` through `preferred_remote_tts`)
/// use the same names and types as the fields of `SpeechPolicy`.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechHealthReport {
    /// State of the speech runtime.
    pub runtime: SpeechRuntimeHealth,
    /// Health entries for local speech models.
    pub local_models: Vec<SpeechModelHealth>,
    /// Health entries for remote speech models.
    pub remote_models: Vec<SpeechModelHealth>,
    /// Whether ElevenLabs is configured — presumably via an API key; confirm.
    pub elevenlabs_configured: bool,
    pub prefer_local: bool,
    pub allow_remote_fallback: bool,
    pub preferred_local_stt: Option<String>,
    pub preferred_local_tts: Option<String>,
    pub preferred_remote_stt: Option<String>,
    pub preferred_remote_tts: Option<String>,
    /// Default local speech-to-text model, if any.
    pub local_stt_default: Option<String>,
    /// Default local text-to-speech model, if any.
    pub local_tts_default: Option<String>,
    /// Default remote speech-to-text model, if any.
    pub remote_stt_default: Option<String>,
    /// Default remote text-to-speech model, if any.
    pub remote_tts_default: Option<String>,
}
325
/// Serializable health entry for the model configured as the default for one
/// capability.
#[derive(Debug, Clone, Serialize)]
pub struct ModelDefaultHealth {
    /// Capability the default applies to.
    pub capability: ModelCapability,
    /// Name of the configured model.
    pub configured_model: String,
    /// Whether the configured model is usable.
    pub available: bool,
    /// Whether the configured model runs locally.
    pub is_local: bool,
    /// Provider of the configured model, when known.
    pub provider: Option<String>,
}
334
/// Serializable per-provider summary of model counts and capabilities.
#[derive(Debug, Clone, Serialize)]
pub struct ModelProviderHealth {
    /// Provider name.
    pub provider: String,
    /// Whether the provider is configured — presumably credentials present; confirm.
    pub configured: bool,
    /// Number of local models from this provider.
    pub local_models: usize,
    /// Number of remote models from this provider.
    pub remote_models: usize,
    /// Number of models from this provider that are available.
    pub available_models: usize,
    /// Capabilities offered by this provider's models.
    pub capabilities: Vec<ModelCapability>,
}
344
/// Serializable per-capability summary of model counts.
#[derive(Debug, Clone, Serialize)]
pub struct ModelCapabilityHealth {
    /// Capability being summarized.
    pub capability: ModelCapability,
    /// Number of models with this capability.
    pub total_models: usize,
    /// Number of those models that are available.
    pub available_models: usize,
    /// Number of available models that are local.
    pub local_available_models: usize,
    /// Number of available models that are remote.
    pub remote_available_models: usize,
}
353
/// Serializable description of one routing scenario: its inputs and the
/// model chosen for it.
///
/// NOTE(review): the code that evaluates scenarios is outside this chunk;
/// the input/output split below is inferred from field names and types.
#[derive(Debug, Clone, Serialize)]
pub struct RoutingScenarioHealth {
    /// Scenario name.
    pub name: String,
    /// Workload the scenario was evaluated with.
    pub workload: RoutingWorkload,
    /// Task family label.
    pub task_family: String,
    /// Whether the scenario includes tools.
    pub has_tools: bool,
    /// Whether the scenario includes vision input.
    pub has_vision: bool,
    /// Routing preference for local models.
    pub prefer_local: bool,
    /// Cold-start quality-first setting.
    pub quality_first_cold_start: bool,
    /// Bootstrap threshold on per-task observations.
    pub bootstrap_min_task_observations: u64,
    /// Bootstrap quality floor.
    pub bootstrap_quality_floor: f64,
    /// Identifier of the selected model.
    pub model_id: String,
    /// Name of the selected model.
    pub model_name: String,
    /// Reason text for the selection.
    pub reason: String,
    /// Strategy that produced the selection.
    pub strategy: RoutingStrategy,
}
370
/// Serializable view of one model's benchmark prior and where it came from.
#[derive(Debug, Clone, Serialize)]
pub struct ModelBenchmarkPriorHealth {
    /// Identifier of the model the prior applies to.
    pub model_id: String,
    /// Model name, when known.
    pub model_name: Option<String>,
    /// Overall benchmark score — NOTE(review): scale not visible here; confirm.
    pub overall_score: f64,
    /// Overall latency in milliseconds, when present.
    pub overall_latency_ms: Option<f64>,
    /// Scores keyed by a string — presumably the task name; confirm.
    pub task_scores: std::collections::HashMap<String, f64>,
    /// Latencies in milliseconds, keyed the same way as `task_scores` (presumably).
    pub task_latency_ms: std::collections::HashMap<String, f64>,
    /// File the prior was read from.
    pub source_path: PathBuf,
}
381
/// Top-level serializable health report: model counts, per-capability and
/// per-provider summaries, routing settings, routing scenarios, benchmark
/// priors and the nested speech report.
#[derive(Debug, Clone, Serialize)]
pub struct ModelHealthReport {
    /// Number of models known.
    pub total_models: usize,
    /// Number of models that are available.
    pub available_models: usize,
    /// Number of local models.
    pub local_models: usize,
    /// Number of remote models.
    pub remote_models: usize,
    /// One entry per configured default model.
    pub defaults: Vec<ModelDefaultHealth>,
    /// One entry per provider.
    pub providers: Vec<ModelProviderHealth>,
    /// One entry per capability.
    pub capabilities: Vec<ModelCapabilityHealth>,
    // The `routing_*` fields below mirror routing settings; their source
    // (presumably `RoutingConfig`) is not visible in this chunk.
    pub routing_prefer_local: bool,
    pub routing_quality_first_cold_start: bool,
    pub routing_min_observations: u64,
    pub routing_bootstrap_min_task_observations: u64,
    pub routing_bootstrap_quality_floor: f64,
    pub routing_quality_weight: f64,
    pub routing_latency_weight: f64,
    pub routing_cost_weight: f64,
    /// Evaluated routing scenarios.
    pub routing_scenarios: Vec<RoutingScenarioHealth>,
    /// Benchmark priors that were found.
    pub benchmark_priors: Vec<ModelBenchmarkPriorHealth>,
    /// Speech-specific health report.
    pub speech: SpeechHealthReport,
}
403
/// Serializable summary of a speech model installation.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechInstallReport {
    /// Name of the installed model.
    pub name: String,
    /// Hugging Face repository identifier (per the `hf_` prefix; confirm format).
    pub hf_repo: String,
    /// Path of the downloaded snapshot.
    pub snapshot_path: PathBuf,
    /// Number of files downloaded.
    pub files_downloaded: usize,
}
411
/// Serializable result of one speech smoke-test path.
///
/// NOTE(review): the smoke test itself is outside this chunk; the fields
/// suggest a TTS-then-STT round trip, but confirm against the producer.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechSmokePathReport {
    /// Label of the path exercised (see `SpeechSmokeReport::local` / `remote`).
    pub path: String,
    /// Text-to-speech model used.
    pub tts_model: String,
    /// Speech-to-text model used.
    pub stt_model: String,
    /// Location of the audio file involved.
    pub audio_path: PathBuf,
    /// Transcript text.
    pub transcript: String,
}
420
/// Serializable aggregate of speech smoke-test results.
///
/// `Default` yields no local result, no remote result and no skip entries.
#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechSmokeReport {
    /// Result of the local path, if it ran.
    pub local: Option<SpeechSmokePathReport>,
    /// Result of the remote path, if it ran.
    pub remote: Option<SpeechSmokePathReport>,
    /// Entries describing what was skipped — presumably human-readable reasons; confirm.
    pub skipped: Vec<String>,
}
427
/// Preferences governing how speech requests choose between local and remote
/// models.
///
/// The derived `Default` sets both booleans to `false` and every preference
/// to `None`. `InferenceEngine::new` does not use that default: it enables
/// `allow_remote_fallback` and sets `prefer_local` from the build target.
#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechPolicy {
    /// Prefer local speech models over remote ones.
    pub prefer_local: bool,
    /// Permit falling back to remote models.
    pub allow_remote_fallback: bool,
    /// Preferred local speech-to-text model, if any.
    pub preferred_local_stt: Option<String>,
    /// Preferred local text-to-speech model, if any.
    pub preferred_local_tts: Option<String>,
    /// Preferred remote speech-to-text model, if any.
    pub preferred_remote_stt: Option<String>,
    /// Preferred remote text-to-speech model, if any.
    pub preferred_remote_tts: Option<String>,
}
437
/// Central inference engine: owns configuration, model registries, routers,
/// the outcome tracker and the platform-specific backends.
///
/// Backend fields are selected at compile time. On macOS/aarch64 builds
/// without the `car_skip_mlx` cfg the MLX caches exist; on every other build
/// the Candle, embedding and speech-runtime slots exist instead.
pub struct InferenceEngine {
    /// Configuration the engine was constructed with.
    pub config: InferenceConfig,
    /// Registry used for schema lookups (`get`, `find_by_name`, `ensure_local`).
    pub unified_registry: UnifiedRegistry,
    /// Router used by `route_adaptive` and tracked generation.
    pub adaptive_router: AdaptiveRouter,
    /// Shared outcome tracker guarded by an async `RwLock`.
    pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
    /// Backend for remote models; also owns the key pool.
    remote_backend: RemoteBackend,
    /// Cache of loaded MLX backends, keyed by schema id.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    mlx_backends: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
    /// Cache of loaded Flux backends (used for image generation in `warm_up`).
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    flux_cache: Arc<backend_cache::BackendCache<backend::mlx_flux::FluxBackend>>,
    /// Cache of loaded LTX backends (used for video generation in `warm_up`).
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    ltx_cache: Arc<backend_cache::BackendCache<backend::mlx_ltx::LtxBackend>>,
    /// Cache of loaded Kokoro backends (used for text-to-speech in `warm_up`).
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    kokoro_cache: Arc<backend_cache::BackendCache<backend::mlx_kokoro::KokoroBackend>>,
    /// Model registry used by `route` and by Candle model loading.
    pub registry: models::ModelRegistry,
    /// Router used by the non-adaptive `route`.
    pub router: ModelRouter,
    /// Lazily loaded Candle generation backend; `None` until first use.
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    backend: Arc<RwLock<Option<CandleBackend>>>,
    /// Lazily loaded embedding backend; `None` until first use.
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
    /// Speech runtime slot, initially `None`. `SpeechRuntime` is defined
    /// outside this chunk.
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    speech_runtime: Arc<Mutex<Option<SpeechRuntime>>>,
    /// Speech routing preferences.
    speech_policy: SpeechPolicy,
}
483
484impl InferenceEngine {
485 fn preferred_model_for_capability(&self, capability: ModelCapability) -> Option<&str> {
486 match capability {
487 ModelCapability::Generate => self.config.preferred_generation_model.as_deref(),
488 ModelCapability::Embed => self.config.preferred_embedding_model.as_deref(),
489 ModelCapability::Classify => self.config.preferred_classification_model.as_deref(),
490 _ => None,
491 }
492 }
493
494 fn request_needs_vision(req: &GenerateRequest) -> bool {
495 req.images.as_ref().is_some_and(|images| !images.is_empty())
496 || req.messages.as_ref().is_some_and(|messages| {
497 messages
498 .iter()
499 .any(|msg| matches!(msg, Message::UserMultimodal { .. }))
500 })
501 }
502
503 fn request_has_video(req: &GenerateRequest) -> bool {
508 let images_have_video = req
509 .images
510 .as_ref()
511 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
512 let messages_have_video = req.messages.as_ref().is_some_and(|messages| {
513 messages.iter().any(|msg| match msg {
514 Message::UserMultimodal { content } => {
515 content.iter().any(ContentBlock::is_video)
516 }
517 _ => false,
518 })
519 });
520 images_have_video || messages_have_video
521 }
522
523 fn request_has_audio(req: &GenerateRequest) -> bool {
527 let images_have_audio = req
528 .images
529 .as_ref()
530 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
531 let messages_have_audio = req.messages.as_ref().is_some_and(|messages| {
532 messages.iter().any(|msg| match msg {
533 Message::UserMultimodal { content } => {
534 content.iter().any(ContentBlock::is_audio)
535 }
536 _ => false,
537 })
538 });
539 images_have_audio || messages_have_audio
540 }
541
542 pub fn new(config: InferenceConfig) -> Self {
543 let registry = models::ModelRegistry::new(config.models_dir.clone());
544 let hw = HardwareInfo::detect();
545 let router = ModelRouter::new(hw.clone());
546 let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
547 let adaptive_router = AdaptiveRouter::with_default_config(hw);
548 let mut tracker = OutcomeTracker::new();
549 let profiles_path = config.models_dir.join("outcome_profiles.json");
551 if let Ok(n) = tracker.load_from_file(&profiles_path) {
552 if n > 0 {
553 tracing::info!(loaded = n, "loaded persisted model profiles");
554 }
555 }
556 let mut benchmark_models_loaded = 0usize;
557 for path in benchmark_priors_paths(&config.models_dir) {
558 match routing_ext::load_benchmark_priors(&path) {
559 Ok(priors) if !priors.is_empty() => {
560 benchmark_models_loaded += priors.len();
561 routing_ext::apply_benchmark_priors(&mut tracker, &priors);
562 tracing::info!(
563 path = %path.display(),
564 loaded = priors.len(),
565 "loaded benchmark quality priors"
566 );
567 }
568 Ok(_) => {}
569 Err(error) => {
570 tracing::warn!(path = %path.display(), %error, "failed to load benchmark priors");
571 }
572 }
573 }
574 if benchmark_models_loaded > 0 {
575 tracing::info!(
576 loaded = benchmark_models_loaded,
577 "applied benchmark priors to cold-start routing"
578 );
579 }
580 let outcome_tracker = Arc::new(RwLock::new(tracker));
581
582 let remote_backend = RemoteBackend::new();
583
584 Self {
585 config,
586 unified_registry,
587 adaptive_router,
588 outcome_tracker,
589 remote_backend,
590 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
591 mlx_backends: Arc::new(backend_cache::BackendCache::from_env()),
592 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
593 flux_cache: Arc::new(backend_cache::BackendCache::from_env()),
594 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
595 ltx_cache: Arc::new(backend_cache::BackendCache::from_env()),
596 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
597 kokoro_cache: Arc::new(backend_cache::BackendCache::from_env()),
598 registry,
599 router,
600 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
601 backend: Arc::new(RwLock::new(None)),
602 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
603 embedding_backend: Arc::new(RwLock::new(None)),
604 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
605 speech_runtime: Arc::new(Mutex::new(None)),
606 speech_policy: SpeechPolicy {
607 prefer_local: cfg!(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))),
608 allow_remote_fallback: true,
609 preferred_local_stt: None,
610 preferred_local_tts: None,
611 preferred_remote_stt: None,
612 preferred_remote_tts: None,
613 },
614 }
615 }
616
617 pub async fn init_key_pool(&self) {
620 for schema in self.unified_registry.list() {
622 if schema.is_remote() {
623 self.remote_backend.register_model_keys(schema).await;
624 }
625 }
626
627 let stats_path = self.config.models_dir.join("key_pool_stats.json");
629 if let Ok(n) = self.remote_backend.key_pool.load_stats(&stats_path).await {
630 if n > 0 {
631 tracing::info!(loaded = n, "loaded persisted key pool stats");
632 }
633 }
634
635 let total = self.remote_backend.key_pool.total_keys().await;
636 if total > 0 {
637 tracing::info!(keys = total, "key pool initialized");
638 }
639 }
640
641 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
644 async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
645 let read = self.backend.read().await;
646 if read.is_some() {
647 return Ok(());
648 }
649 drop(read);
650
651 let mut write = self.backend.write().await;
652 if write.is_some() {
653 return Ok(());
654 }
655
656 let model_path = self.registry.ensure_model(model_name).await?;
657 let device = self.config.device.unwrap_or_else(Device::auto);
658 let backend = CandleBackend::load(&model_path, device)?;
659 *write = Some(backend);
660 Ok(())
661 }
662
663 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
666 async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
667 let read = self.embedding_backend.read().await;
668 if read.is_some() {
669 return Ok(());
670 }
671 drop(read);
672
673 let mut write = self.embedding_backend.write().await;
674 if write.is_some() {
675 return Ok(());
676 }
677
678 let embedding_model = self
679 .preferred_model_for_capability(ModelCapability::Embed)
680 .unwrap_or(&self.config.embedding_model);
681 let model_path = self
682 .registry
683 .ensure_model(embedding_model)
684 .await?;
685 let device = self.config.device.unwrap_or_else(Device::auto);
686 let backend = EmbeddingBackend::load(&model_path, device)?;
687 *write = Some(backend);
688 Ok(())
689 }
690
691 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
694 async fn ensure_mlx_embedding_backend(&self) -> Result<String, InferenceError> {
695 let embedding_model_name = self
696 .preferred_model_for_capability(ModelCapability::Embed)
697 .unwrap_or(&self.config.embedding_model)
698 .to_string();
699 let schema = self
700 .unified_registry
701 .get(&embedding_model_name)
702 .or_else(|| self.unified_registry.find_by_name(&embedding_model_name))
703 .ok_or_else(|| InferenceError::ModelNotFound(embedding_model_name.clone()))?
704 .clone();
705 self.ensure_mlx_backend(&schema).await?;
706 Ok(schema.id)
707 }
708
709 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
713 async fn ensure_mlx_backend(
714 &self,
715 schema: &ModelSchema,
716 ) -> Result<backend_cache::CachedBackend<backend::MlxBackend>, InferenceError> {
717 if !Self::supports_native_mlx(schema) {
718 return Err(InferenceError::InferenceFailed(format!(
719 "native MLX backend does not support {} ({}) yet; use vLLM-MLX or add a family-specific MLX backend",
720 schema.name, schema.family
721 )));
722 }
723
724 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
725 let size = backend_cache::estimate_model_size(&model_dir);
726 let cache = Arc::clone(&self.mlx_backends);
727 let key = schema.id.clone();
728 cache.get_or_load(&key, size, move || {
732 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
733 backend::MlxBackend::load(&model_dir)
734 }))
735 .map_err(|e| {
736 InferenceError::InferenceFailed(format!(
737 "MLX backend loading panicked (possible Metal/accelerate exception): {:?}",
738 e
739 ))
740 })?
741 })
742 }
743
744 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
751 pub async fn warm_up<S: AsRef<str>>(&self, schema_ids: &[S]) -> Vec<Result<(), InferenceError>> {
752 let mut results = Vec::with_capacity(schema_ids.len());
753 for id in schema_ids {
754 let id = id.as_ref();
755 let outcome: Result<(), InferenceError> = async {
756 let schema = self
757 .unified_registry
758 .get(id)
759 .cloned()
760 .ok_or_else(|| InferenceError::InferenceFailed(
761 format!("warm_up: unknown schema id {id}"),
762 ))?;
763 match schema.capabilities.first().copied() {
764 Some(ModelCapability::ImageGeneration) => {
765 let model_dir =
766 self.unified_registry.ensure_local(&schema.id).await?;
767 let size = backend_cache::estimate_model_size(&model_dir);
768 let _ = self.flux_cache.get_or_load(&schema.id, size, || {
769 backend::mlx_flux::FluxBackend::load(&model_dir)
770 })?;
771 }
772 Some(ModelCapability::VideoGeneration) => {
773 let model_dir =
774 self.unified_registry.ensure_local(&schema.id).await?;
775 let size = backend_cache::estimate_model_size(&model_dir);
776 let _ = self.ltx_cache.get_or_load(&schema.id, size, || {
777 backend::mlx_ltx::LtxBackend::load(&model_dir)
778 })?;
779 }
780 Some(ModelCapability::TextToSpeech) => {
781 let model_dir =
782 self.unified_registry.ensure_local(&schema.id).await?;
783 let size = backend_cache::estimate_model_size(&model_dir);
784 let _ = self.kokoro_cache.get_or_load(&schema.id, size, || {
785 backend::mlx_kokoro::KokoroBackend::load(&model_dir)
786 })?;
787 }
788 _ => {
789 let _ = self.ensure_mlx_backend(&schema).await?;
790 }
791 }
792 Ok(())
793 }
794 .await;
795 results.push(outcome);
796 }
797 results
798 }
799
800 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
802 pub async fn warm_up<S: AsRef<str>>(&self, _schema_ids: &[S]) -> Vec<Result<(), InferenceError>> {
803 Vec::new()
804 }
805
806 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
807 fn supports_native_mlx(schema: &ModelSchema) -> bool {
808 matches!(schema.family.as_str(), "qwen3" | "qwen2.5-vl" | "qwen2-vl")
809 }
810
811 pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
813 if let Some(model) = self.preferred_model_for_capability(ModelCapability::Generate) {
814 let ctx_len = self
815 .unified_registry
816 .get(model)
817 .or_else(|| self.unified_registry.find_by_name(model))
818 .map(|s| s.context_length)
819 .unwrap_or(0);
820 return AdaptiveRoutingDecision {
821 model_id: model.to_string(),
822 model_name: model.to_string(),
823 task: InferenceTask::Generate,
824 complexity: TaskComplexity::assess(prompt),
825 reason: "preferred generation model override".into(),
826 strategy: RoutingStrategy::Explicit,
827 predicted_quality: 0.5,
828 fallbacks: vec![],
829 context_length: ctx_len,
830 needs_compaction: false,
831 };
832 }
833 let tracker = self.outcome_tracker.read().await;
834 self.adaptive_router
835 .route(prompt, &self.unified_registry, &tracker)
836 }
837
838 pub fn route(&self, prompt: &str) -> RoutingDecision {
840 self.router.route_generate(prompt, &self.registry)
841 }
842
843 pub fn estimated_tokens(
846 &self,
847 req: &GenerateRequest,
848 model_id: Option<&str>,
849 ) -> (usize, usize, bool) {
850 let prompt_tokens = remote::estimate_tokens(&req.prompt);
851 let context_tokens = req
852 .context
853 .as_ref()
854 .map(|c| remote::estimate_tokens(c))
855 .unwrap_or(0);
856 let tools_tokens = req
857 .tools
858 .as_ref()
859 .map(|t| remote::estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
860 .unwrap_or(0);
861 let total_input = prompt_tokens + context_tokens + tools_tokens;
862
863 let context_window = model_id
864 .and_then(|id| {
865 self.unified_registry
866 .get(id)
867 .or_else(|| self.unified_registry.find_by_name(id))
868 })
869 .map(|s| s.context_length)
870 .unwrap_or(0);
871
872 let fits = context_window == 0 || (total_input + req.params.max_tokens) <= context_window;
873 (total_input, context_window, fits)
874 }
875
876 #[instrument(
879 name = "inference.generate",
880 skip_all,
881 fields(
882 model = tracing::field::Empty,
883 max_tokens = req.params.max_tokens,
884 prompt_tokens = tracing::field::Empty,
885 completion_tokens = tracing::field::Empty,
886 latency_ms = tracing::field::Empty,
887 )
888 )]
889 pub async fn generate_tracked(
890 &self,
891 req: GenerateRequest,
892 ) -> Result<InferenceResult, InferenceError> {
893 let start = Instant::now();
894
895 let (estimated_input, _, _) = self.estimated_tokens(&req, None);
897 let tracker_read = self.outcome_tracker.read().await;
898 let has_tools = req.tools.is_some();
899 let has_vision = Self::request_needs_vision(&req);
900 let preferred_model = self
901 .preferred_model_for_capability(ModelCapability::Generate)
902 .map(str::to_string);
903 let decision = match req.model.clone().or(preferred_model) {
904 Some(m) => {
905 let ctx_len = self
906 .unified_registry
907 .get(&m)
908 .or_else(|| self.unified_registry.find_by_name(&m))
909 .map(|s| s.context_length)
910 .unwrap_or(0);
911 AdaptiveRoutingDecision {
912 model_id: m.clone(),
913 model_name: m.clone(),
914 task: InferenceTask::Generate,
915 complexity: TaskComplexity::assess(&req.prompt),
916 reason: "explicit model".into(),
917 strategy: RoutingStrategy::Explicit,
918 predicted_quality: 0.5,
919 fallbacks: vec![],
920 context_length: ctx_len,
921 needs_compaction: ctx_len > 0 && estimated_input > ctx_len,
922 }
923 }
924 None => match &req.intent {
925 Some(hint) => self.adaptive_router.route_context_aware_with_intent(
926 &req.prompt,
927 estimated_input,
928 &self.unified_registry,
929 &tracker_read,
930 has_tools,
931 has_vision,
932 req.params.workload,
933 hint,
934 ),
935 None => self.adaptive_router.route_context_aware(
936 &req.prompt,
937 estimated_input,
938 &self.unified_registry,
939 &tracker_read,
940 has_tools,
941 has_vision,
942 req.params.workload,
943 ),
944 },
945 };
946 drop(tracker_read);
947
948 if decision.needs_compaction {
949 tracing::info!(
950 model = %decision.model_name,
951 prompt_tokens = estimated_input,
952 context_window = decision.context_length,
953 "prompt exceeds model context window — compaction or truncation needed"
954 );
955 }
956
957 let trace_id = {
959 let mut tracker = self.outcome_tracker.write().await;
960 tracker.record_start(&decision.model_id, decision.task, &decision.reason)
961 };
962
963 debug!(
964 model = %decision.model_name,
965 strategy = ?decision.strategy,
966 reason = %decision.reason,
967 trace = %trace_id,
968 "adaptive-routed generate request"
969 );
970
971 let mut req = req;
974 if req.params.budget_tokens == 0 && matches!(decision.complexity, TaskComplexity::Complex) {
975 let supports_thinking = self
976 .unified_registry
977 .get(&decision.model_id)
978 .map(|s| {
979 s.supported_params
980 .contains(&schema::GenerateParam::ExtendedThinking)
981 })
982 .unwrap_or(false);
983 if supports_thinking {
984 req.params.budget_tokens = 8000;
985 tracing::info!(model = %decision.model_name, budget = 8000, "auto-enabled extended thinking for complex task");
986 }
987 }
988
989 let mut models_to_try = vec![decision.model_id.clone()];
991 models_to_try.extend(decision.fallbacks.iter().cloned());
992
993 let mut last_error = None;
994
995 for candidate_id in &models_to_try {
996 #[allow(unused_mut)]
999 let mut schema = self
1000 .unified_registry
1001 .get(candidate_id)
1002 .or_else(|| self.unified_registry.find_by_name(candidate_id))
1003 .cloned();
1004
1005 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1007 if let Some(ref s) = schema {
1008 if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
1009 tracing::info!(
1010 from = %s.id, to = %mlx_equiv.id,
1011 "redirecting GGUF model to MLX equivalent on Apple Silicon"
1012 );
1013 schema = Some(mlx_equiv.clone());
1014 }
1015 }
1016
1017 let candidate_name = schema
1018 .as_ref()
1019 .map(|s| s.name.clone())
1020 .unwrap_or_else(|| candidate_id.clone());
1021
1022 let is_remote = schema
1023 .as_ref()
1024 .map(|s| s.is_remote() || s.is_vllm_mlx())
1025 .unwrap_or(false);
1026 let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
1027
1028 if is_delegated {
1034 let runner = match runner::current_inference_runner() {
1035 Some(r) => r,
1036 None => {
1037 last_error = Some(InferenceError::InferenceFailed(
1038 "model declares ModelSource::Delegated but no inference runner is registered"
1039 .into(),
1040 ));
1041 continue;
1042 }
1043 };
1044 let (tx, mut rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
1045 let emitter = runner::EventEmitter::new(tx);
1046 let runner_req = req.clone();
1047 let runner_handle = tokio::spawn(async move { runner.run(runner_req, emitter).await });
1048 let mut accumulator = stream::StreamAccumulator::default();
1049 while let Some(evt) = rx.recv().await {
1050 accumulator.push(&evt);
1051 }
1052 let (acc_text, acc_tool_calls) = accumulator.finish();
1056 match runner_handle.await {
1057 Ok(Ok(_runner_result)) => {
1058 let elapsed = start.elapsed().as_millis() as u64;
1059 let mut tracker = self.outcome_tracker.write().await;
1060 tracker.record_complete(&trace_id, elapsed, 0, 0);
1061 return Ok(InferenceResult {
1062 text: acc_text,
1063 tool_calls: acc_tool_calls,
1064 bounding_boxes: vec![],
1065 trace_id,
1066 model_used: candidate_name,
1067 latency_ms: elapsed,
1068 time_to_first_token_ms: None,
1069 usage: None,
1070 provider_output_items: vec![],
1071 });
1072 }
1073 Ok(Err(e)) => {
1074 last_error = Some(InferenceError::InferenceFailed(e.to_string()));
1075 continue;
1076 }
1077 Err(join_err) => {
1078 last_error = Some(InferenceError::InferenceFailed(format!(
1079 "runner task panicked: {join_err}"
1080 )));
1081 continue;
1082 }
1083 }
1084 }
1085
1086 let has_tools = req.tools.is_some();
1087
1088 let context = if has_tools
1090 && req.tools.as_ref().map_or(false, |t| {
1091 t.iter().any(|tool| {
1092 tool.get("function")
1093 .and_then(|f| f.get("name"))
1094 .and_then(|n| n.as_str())
1095 == Some("done")
1096 })
1097 }) {
1098 let base = req.context.as_deref().unwrap_or("");
1099 Some(format!(
1100 "{base}\n\nIMPORTANT: When calling the `done` tool, the `result` field MUST contain a DETAILED summary of everything you found and did. This is the ONLY output the user sees. Do NOT just say 'completed' — include specific findings, data, and conclusions."
1101 ))
1102 } else {
1103 req.context.clone()
1104 };
1105
1106 let result = if is_remote {
1107 let schema_val = schema.unwrap();
1108 let _ctx_len = schema_val.context_length;
1109 let temperature = if !schema_val.supported_params.is_empty()
1112 && !schema_val
1113 .supported_params
1114 .contains(&crate::schema::GenerateParam::Temperature)
1115 {
1116 -1.0
1117 } else {
1118 req.params.temperature
1119 };
1120
1121 self.remote_backend
1127 .generate_with_tools_multi(
1128 &schema_val,
1129 &req.prompt,
1130 context.as_deref(),
1131 temperature,
1132 req.params.max_tokens,
1133 req.tools.as_deref(),
1134 req.images.as_deref(),
1135 req.messages.as_deref(),
1136 req.params.tool_choice.as_deref(),
1137 req.params.parallel_tool_calls,
1138 req.params.budget_tokens,
1139 req.cache_control,
1140 req.response_format.as_ref(),
1141 )
1142 .await
1143 .map(|(t, c, u)| (t, c, u, None::<u64>))
1148 } else {
1149 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1150 {
1151 let schema_ref = schema
1155 .as_ref()
1156 .ok_or_else(|| InferenceError::ModelNotFound(candidate_id.clone()))?;
1157
1158 if schema_ref.is_foundation_models() {
1163 if has_tools {
1164 Err(InferenceError::UnsupportedMode {
1165 mode: "tool-use",
1166 backend: "foundation-models",
1167 reason:
1168 "tool calling not yet wired through the FoundationModels \
1169 bridge — route to a remote model with ToolUse capability",
1170 })
1171 } else if Self::request_has_video(&req)
1172 || Self::request_has_audio(&req)
1173 || req
1174 .images
1175 .as_ref()
1176 .is_some_and(|imgs| !imgs.is_empty())
1177 {
1178 Err(InferenceError::UnsupportedMode {
1179 mode: "multimodal-content",
1180 backend: "foundation-models",
1181 reason:
1182 "the FoundationModels bridge currently exposes text-only \
1183 generation — route image/audio/video to a remote VL model",
1184 })
1185 } else {
1186 let prompt = req.prompt.clone();
1187 let instructions = context.clone();
1188 let max_tokens = req.params.max_tokens as u32;
1189 let temperature = req.params.temperature;
1190 tokio::task::spawn_blocking(move || {
1191 crate::backend::foundation_models::generate(
1192 &prompt,
1193 instructions.as_deref(),
1194 max_tokens,
1195 temperature as f32,
1196 )
1197 })
1198 .await
1199 .map_err(|e| {
1200 InferenceError::InferenceFailed(format!(
1201 "FoundationModels task panicked: {e}"
1202 ))
1203 })
1204 .and_then(|r| r)
1205 .map(|text| (text, vec![], None, None))
1206 }
1207 } else if !schema_ref.is_mlx() {
1208 Err(InferenceError::InferenceFailed(format!(
1209 "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1210 schema_ref.id
1211 )))
1212 } else if schema_ref.tags.iter().any(|t| t == "mlx-vlm-cli") {
1213 let has_images = req
1225 .images
1226 .as_ref()
1227 .is_some_and(|imgs| !imgs.is_empty());
1228 if !has_images {
1229 return Err(InferenceError::UnsupportedMode {
1230 mode: "text-only-on-mlx-vlm-id",
1231 backend: "mlx-vlm-cli",
1232 reason:
1233 "the `mlx-vlm/...` model IDs route exclusively \
1234 through the mlx-vlm CLI for image inference. \
1235 For text-only generation, route to a Qwen3 \
1236 text model (`mlx/qwen3-4b:4bit` etc.) — the \
1237 CLI shell-out has higher latency than the \
1238 in-process MLX text tower.",
1239 });
1240 }
1241 let vlm_status = crate::backend::mlx_vlm_cli::runtime_status();
1242 if !vlm_status.is_available() {
1243 return Err(InferenceError::InferenceFailed(vlm_status.user_message()));
1244 }
1245 let repo = match &schema_ref.source {
1246 crate::schema::ModelSource::Mlx { hf_repo, .. } => {
1247 hf_repo.clone()
1248 }
1249 _ => {
1250 return Err(InferenceError::InferenceFailed(format!(
1251 "model '{}' is tagged mlx-vlm-cli but its \
1252 source isn't ModelSource::Mlx — registry bug",
1253 schema_ref.id
1254 )));
1255 }
1256 };
1257 let imgs = req.images.clone().unwrap_or_default();
1258 let temp = req.params.temperature;
1259 let max_t = req.params.max_tokens;
1260 let prompt = req.prompt.clone();
1261 let text = tokio::task::spawn_blocking(move || {
1262 crate::backend::mlx_vlm_cli::generate(
1263 &repo, &prompt, &imgs, temp, max_t,
1264 )
1265 })
1266 .await
1267 .map_err(|e| {
1268 InferenceError::InferenceFailed(format!(
1269 "mlx_vlm CLI task panicked: {e}"
1270 ))
1271 })??;
1272 let bounding_boxes = parse_boxes(&text);
1273 let latency_ms = start.elapsed().as_millis() as u64;
1274 {
1275 let mut tracker = self.outcome_tracker.write().await;
1276 tracker.record_complete(&trace_id, latency_ms, 0, 0);
1277 }
1278 return Ok(InferenceResult {
1279 text,
1280 tool_calls: vec![],
1281 bounding_boxes,
1282 trace_id: trace_id.clone(),
1283 model_used: schema_ref.id.clone(),
1284 latency_ms,
1285 time_to_first_token_ms: None,
1286 usage: None,
1287 provider_output_items: Vec::new(),
1288 });
1289 } else {
1290 let handle = self.ensure_mlx_backend(schema_ref).await?;
1300 if Self::request_has_video(&req) {
1306 return Err(InferenceError::UnsupportedMode {
1307 mode: "video-content-block",
1308 backend: "native-mlx-qwen25vl",
1309 reason:
1310 "Qwen2.5-VL video understanding is on the request surface \
1311 but the video-tokenization path (frame sampling + merger) \
1312 is not yet wired; route to a remote VL provider for now",
1313 });
1314 }
1315 if Self::request_has_audio(&req) {
1316 return Err(InferenceError::UnsupportedMode {
1317 mode: "audio-content-block",
1318 backend: "native-mlx-qwen25vl",
1319 reason:
1320 "audio understanding is on the request surface (Gemma 4 \
1321 E2B/E4B and Gemini accept it) but the native MLX path \
1322 for this model does not — route to Gemini or Gemma-4",
1323 });
1324 }
1325 let has_images = req
1326 .images
1327 .as_ref()
1328 .is_some_and(|imgs| !imgs.is_empty());
1329 if has_images {
1330 let can_do_vision = {
1331 let guard = handle.lock().map_err(|_| {
1332 InferenceError::InferenceFailed(
1333 "MLX backend mutex poisoned".into(),
1334 )
1335 })?;
1336 guard.supports_capability(
1337 crate::schema::ModelCapability::Vision,
1338 )
1339 };
1340 if !can_do_vision {
1341 return Err(InferenceError::UnsupportedMode {
1351 mode: "image-content-block",
1352 backend: "native-mlx-text",
1353 reason:
1354 "this MLX backend is a plain Qwen3 text tower. \
1355 For local image inference, route to \
1356 `mlx-vlm/qwen3-vl-2b:bf16` or another `mlx-vlm/...` \
1357 catalog ID so CAR shells out to `mlx_vlm.generate`. \
1358 Alternatives: a local vLLM-MLX VLM server, or a \
1359 remote VL model. (#115)",
1360 });
1361 }
1362 }
1363 self.generate_mlx(req.clone(), &schema_ref.id)
1364 .await
1365 .map(|(text, ttft)| (text, vec![], None, ttft))
1366 }
1367 }
1368
1369 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1370 {
1371 match self.ensure_backend(&candidate_name).await {
1372 Ok(()) => {
1373 let mut write = self.backend.write().await;
1374 let backend = write.as_mut().unwrap();
1375 tasks::generate::generate(backend, req.clone())
1376 .await
1377 .map(|(text, ttft)| (text, vec![], None, ttft))
1378 }
1379 Err(e) => Err(e),
1380 }
1381 }
1382 };
1383
1384 match result {
1385 Ok((text, tool_calls, usage, time_to_first_token_ms)) => {
1386 let latency_ms = start.elapsed().as_millis() as u64;
1387 let estimated_tokens = usage
1388 .as_ref()
1389 .map(|u| u.completion_tokens as usize)
1390 .unwrap_or_else(|| text.split_whitespace().count());
1391 {
1392 let mut tracker = self.outcome_tracker.write().await;
1393 tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
1394 }
1395 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1397 cb.record_success(candidate_id);
1398 }
1399 self.auto_save_outcomes().await;
1401
1402 let span = tracing::Span::current();
1404 span.record("model", candidate_name.as_str());
1405 span.record("latency_ms", latency_ms);
1406 if let Some(ttft) = time_to_first_token_ms {
1407 span.record("ttft_ms", ttft);
1408 }
1409 if let Some(ref u) = usage {
1410 span.record("prompt_tokens", u.prompt_tokens);
1411 span.record("completion_tokens", u.completion_tokens);
1412 }
1413
1414 let bounding_boxes = tasks::grounding::parse_boxes(&text);
1417 return Ok(InferenceResult {
1418 text,
1419 tool_calls,
1420 bounding_boxes,
1421 trace_id,
1422 model_used: candidate_name,
1423 latency_ms,
1424 time_to_first_token_ms,
1425 usage,
1426 provider_output_items: Vec::new(),
1427 });
1428 }
1429 Err(e) => {
1430 tracing::warn!(
1431 model = %candidate_name,
1432 error = %e,
1433 remaining = models_to_try.len().saturating_sub(
1434 models_to_try.iter().position(|m| m == candidate_id).unwrap_or(0) + 1
1435 ),
1436 "model failed, trying next fallback immediately"
1437 );
1438 {
1440 let mut tracker = self.outcome_tracker.write().await;
1441 let fail_trace =
1442 tracker.record_start(candidate_id, decision.task, "fallback");
1443 tracker.record_failure(&fail_trace, &e.to_string());
1444 }
1445 {
1449 let err_str = e.to_string();
1450 let is_client_error =
1451 err_str.contains("API returned 4") && !err_str.contains("429");
1452 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1453 cb.record_failure(candidate_id);
1454 if is_client_error {
1457 cb.record_failure(candidate_id);
1458 }
1459 }
1460 }
1461 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1463 {
1464 let mut write = self.backend.write().await;
1465 *write = None;
1466 }
1467 last_error = Some(e);
1468 }
1469 }
1470 }
1471
1472 let e = last_error.unwrap_or(InferenceError::InferenceFailed(
1474 "no models available".into(),
1475 ));
1476 {
1477 let mut tracker = self.outcome_tracker.write().await;
1478 tracker.record_failure(&trace_id, &e.to_string());
1479 }
1480 self.auto_save_outcomes().await;
1481 Err(e)
1482 }
1483
1484 pub async fn generate_tracked_stream(
1521 &self,
1522 req: GenerateRequest,
1523 ) -> Result<tokio::sync::mpsc::Receiver<stream::StreamEvent>, InferenceError> {
1524 let has_tools = req.tools.is_some();
1525 let has_vision = Self::request_needs_vision(&req);
1526 let preferred_model = self
1527 .preferred_model_for_capability(ModelCapability::Generate)
1528 .map(str::to_string);
1529 let decision = match req.model.clone().or(preferred_model) {
1530 Some(m) => {
1531 let ctx_len = self
1532 .unified_registry
1533 .get(&m)
1534 .or_else(|| self.unified_registry.find_by_name(&m))
1535 .map(|s| s.context_length)
1536 .unwrap_or(0);
1537 AdaptiveRoutingDecision {
1538 model_id: m.clone(),
1539 model_name: m,
1540 task: InferenceTask::Generate,
1541 complexity: TaskComplexity::assess(&req.prompt),
1542 reason: "explicit model".into(),
1543 strategy: RoutingStrategy::Explicit,
1544 predicted_quality: 0.5,
1545 fallbacks: vec![],
1546 context_length: ctx_len,
1547 needs_compaction: false,
1548 }
1549 }
1550 None => {
1551 let tracker_read = self.outcome_tracker.read().await;
1552 if has_vision {
1553 self.adaptive_router.route_with_vision(
1554 &req.prompt,
1555 &self.unified_registry,
1556 &tracker_read,
1557 has_tools,
1558 )
1559 } else if has_tools {
1560 self.adaptive_router.route_with_tools(
1561 &req.prompt,
1562 &self.unified_registry,
1563 &tracker_read,
1564 )
1565 } else {
1566 self.adaptive_router
1567 .route(&req.prompt, &self.unified_registry, &tracker_read)
1568 }
1569 }
1570 };
1571
1572 #[allow(unused_mut)]
1575 let mut schema = self
1576 .unified_registry
1577 .get(&decision.model_id)
1578 .or_else(|| self.unified_registry.find_by_name(&decision.model_id))
1579 .cloned();
1580
1581 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1583 if let Some(ref s) = schema {
1584 if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
1585 tracing::info!(
1586 from = %s.id, to = %mlx_equiv.id,
1587 "redirecting GGUF model to MLX equivalent on Apple Silicon (stream)"
1588 );
1589 schema = Some(mlx_equiv.clone());
1590 }
1591 }
1592
1593 let is_remote = schema
1594 .as_ref()
1595 .map(|s| s.is_remote() || s.is_vllm_mlx())
1596 .unwrap_or(false);
1597
1598 let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
1599
1600 if is_delegated {
1601 let runner = runner::current_inference_runner().ok_or_else(|| {
1605 InferenceError::InferenceFailed(
1606 "model declares ModelSource::Delegated but no inference runner is registered \
1607 (call set_inference_runner / registerInferenceRunner / register_inference_runner)"
1608 .into(),
1609 )
1610 })?;
1611 let (tx, rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
1612 let emitter = runner::EventEmitter::new(tx);
1613 let request = req.clone();
1614 tokio::spawn(async move {
1615 if let Err(e) = runner.run(request, emitter).await {
1616 tracing::warn!(error = %e, "delegated inference runner failed");
1617 }
1618 });
1619 return Ok(rx);
1620 }
1621
1622 if is_remote {
1623 let schema = schema.unwrap();
1624 self.remote_backend.register_model_keys(&schema).await;
1626
1627 self.remote_backend
1628 .generate_stream(
1629 &schema,
1630 &req.prompt,
1631 req.context.as_deref(),
1632 req.params.temperature,
1633 req.params.max_tokens,
1634 req.tools.as_deref(),
1635 req.images.as_deref(),
1636 req.params.tool_choice.as_deref(),
1637 req.params.parallel_tool_calls,
1638 req.response_format.as_ref(),
1639 )
1640 .await
1641 } else {
1642 let schema = schema.ok_or_else(|| {
1643 InferenceError::ModelNotFound(decision.model_id.clone())
1644 })?;
1645 let (tx, rx) = tokio::sync::mpsc::channel(64);
1646
1647 #[cfg(any(
1653 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
1654 all(target_os = "ios", target_arch = "aarch64")
1655 ))]
1656 {
1657 if schema.is_foundation_models() {
1658 let prompt = req.prompt.clone();
1659 let instructions = req.context.clone();
1660 let max_tokens = req.params.max_tokens as u32;
1661 let temperature = req.params.temperature;
1662 let tx_clone = tx.clone();
1663 tokio::task::spawn_blocking(move || {
1664 let accum = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
1671 let accum_cb = accum.clone();
1672 let cb = crate::backend::foundation_models::StreamCallback::new(
1673 move |delta: &str| {
1674 if let Ok(mut g) = accum_cb.lock() {
1675 g.push_str(delta);
1676 }
1677 tx_clone
1678 .blocking_send(stream::StreamEvent::TextDelta(
1679 delta.to_string(),
1680 ))
1681 .is_ok()
1682 },
1683 );
1684 let result = crate::backend::foundation_models::stream(
1685 &prompt,
1686 instructions.as_deref(),
1687 max_tokens,
1688 temperature as f32,
1689 cb,
1690 );
1691 let final_text = accum
1692 .lock()
1693 .map(|g| g.clone())
1694 .unwrap_or_default();
1695 let _ = tx.blocking_send(stream::StreamEvent::Done {
1696 text: final_text,
1697 tool_calls: vec![],
1698 });
1699 result
1700 });
1701 return Ok(rx);
1702 }
1703 }
1704
1705 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1707 {
1708 if !schema.is_mlx() {
1711 return Err(InferenceError::InferenceFailed(format!(
1712 "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1713 schema.id
1714 )));
1715 }
1716 let backend = self.ensure_mlx_backend(&schema).await?;
1717 let model_id = schema.id.clone();
1718 let cache = Arc::clone(&self.mlx_backends);
1719 tokio::task::spawn_blocking(move || {
1725 let _ = Self::stream_local_mlx(backend, cache, model_id, req, tx);
1726 });
1727 return Ok(rx);
1728 }
1729
1730 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1731 {
1732 self.ensure_backend(&schema.name).await?;
1733 let backend = self.backend.clone();
1734 tokio::spawn(async move {
1735 let _ = Self::stream_local_candle(backend, req, tx).await;
1736 });
1737 Ok(rx)
1738 }
1739 }
1740 }
1741
1742 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1743 async fn stream_local_candle(
1744 backend_lock: Arc<RwLock<Option<CandleBackend>>>,
1745 req: GenerateRequest,
1746 tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
1747 ) -> Result<(), InferenceError> {
1748 let mut write = backend_lock.write().await;
1749 let backend = write
1750 .as_mut()
1751 .ok_or_else(|| InferenceError::InferenceFailed("backend not initialized".into()))?;
1752 backend.clear_kv_cache();
1753
1754 let formatted = tasks::generate::apply_chat_template(
1755 &req.prompt,
1756 req.context.as_deref(),
1757 req.params.thinking,
1758 );
1759 let tokens = backend.encode(&formatted)?;
1760 let eos = backend.eos_token_id();
1761 let eos_alt = backend.token_id("<|im_end|>");
1762 let params = &req.params;
1763
1764 if tokens.is_empty() {
1765 let _ = tx
1766 .send(stream::StreamEvent::Done {
1767 text: String::new(),
1768 tool_calls: vec![],
1769 })
1770 .await;
1771 return Ok(());
1772 }
1773
1774 let max_ctx = backend.context_length().unwrap_or(32768);
1775 let headroom = params.max_tokens.min(max_ctx / 4);
1776 let max_prompt = max_ctx.saturating_sub(headroom);
1777 let tokens = if tokens.len() > max_prompt {
1778 tokens[tokens.len() - max_prompt..].to_vec()
1779 } else {
1780 tokens
1781 };
1782
1783 let mut generated = Vec::new();
1784 let logits = backend.forward(&tokens, 0)?;
1785 let mut next_token = tasks::generate::sample_token(&logits, params)?;
1786
1787 for _ in 0..params.max_tokens {
1788 if eos.map_or(false, |id| next_token == id)
1789 || eos_alt.map_or(false, |id| next_token == id)
1790 {
1791 break;
1792 }
1793
1794 generated.push(next_token);
1795 let delta = backend.decode(&[next_token])?;
1796 if !delta.is_empty()
1797 && tx.send(stream::StreamEvent::TextDelta(delta)).await.is_err()
1798 {
1799 return Ok(());
1800 }
1801
1802 if !params.stop.is_empty() {
1803 let text_so_far = backend.decode(&generated)?;
1804 if params.stop.iter().any(|s| text_so_far.contains(s)) {
1805 break;
1806 }
1807 }
1808
1809 let pos = tokens.len() + generated.len() - 1;
1810 let logits = backend.forward(&[next_token], pos)?;
1811 next_token = tasks::generate::sample_token(&logits, params)?;
1812 }
1813
1814 let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
1815 let _ = tx
1816 .send(stream::StreamEvent::Done {
1817 text,
1818 tool_calls: vec![],
1819 })
1820 .await;
1821 Ok(())
1822 }
1823
/// Streams a generation from a cached MLX backend into `tx`.
///
/// Synchronous: it uses `blocking_send`, and the caller in this file runs it
/// via `spawn_blocking`. Each sampled token is decoded and sent as a
/// `TextDelta`; a final `Done` event carries the full text after
/// `strip_thinking`. Generation stops at either end-of-sequence token, when a
/// stop string appears in the decoded output, or after `max_tokens` tokens.
/// If the receiver goes away mid-stream the function returns `Ok(())` early.
///
/// When a forward pass fails or panics (caught by `catch_mlx`), the backend
/// entry for `model_id` is invalidated in `cache`. The mutex guard is released
/// before invalidating, matching `generate_mlx`, so invalidation never runs
/// while this thread still holds the backend lock.
///
/// # Errors
/// `InferenceFailed` when the backend mutex is poisoned, plus any error from
/// encoding, the forward pass, sampling or decoding.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn stream_local_mlx(
    handle: backend_cache::CachedBackend<backend::MlxBackend>,
    cache: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
    model_id: String,
    req: GenerateRequest,
    tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
) -> Result<(), InferenceError> {
    let mut guard = handle.lock().map_err(|_| {
        InferenceError::InferenceFailed(format!(
            "MLX backend mutex poisoned for {model_id}"
        ))
    })?;
    let backend: &mut backend::MlxBackend = &mut *guard;
    backend.clear_kv_cache();

    let formatted = tasks::generate::apply_chat_template(
        &req.prompt,
        req.context.as_deref(),
        req.params.thinking,
    );
    let tokens = backend.encode(&formatted)?;
    let eos = backend.eos_token_id();
    let eos_alt = backend.token_id("<|im_end|>");
    let params = &req.params;

    // Nothing to condition on: report an empty completion right away.
    if tokens.is_empty() {
        let _ = tx.blocking_send(stream::StreamEvent::Done {
            text: String::new(),
            tool_calls: vec![],
        });
        return Ok(());
    }

    // Reserve room for the completion (at most a quarter of the context) and
    // keep only the most recent prompt tokens that fit in the remainder.
    let max_ctx = backend.context_length();
    let headroom = params.max_tokens.min(max_ctx / 4);
    let max_prompt = max_ctx.saturating_sub(headroom);
    let tokens = if tokens.len() > max_prompt {
        tokens[tokens.len() - max_prompt..].to_vec()
    } else {
        tokens
    };

    let mut generated = Vec::new();
    let logits = match Self::catch_mlx("stream prefill", || backend.forward(&tokens, 0)) {
        Ok(v) => v,
        Err(e) => {
            // Release the backend lock first, then evict the cache entry.
            drop(guard);
            cache.invalidate(&model_id);
            return Err(e);
        }
    };
    let mut next_token = Self::sample_from_logits(&logits, params)?;

    for _ in 0..params.max_tokens {
        if eos.map_or(false, |id| next_token == id)
            || eos_alt.map_or(false, |id| next_token == id)
        {
            break;
        }

        generated.push(next_token);
        let delta = backend.decode(&[next_token])?;
        if !delta.is_empty()
            && tx.blocking_send(stream::StreamEvent::TextDelta(delta)).is_err()
        {
            // Receiver dropped: nobody is listening any more.
            return Ok(());
        }

        // Stop strings are matched against the whole decoded output so far.
        if !params.stop.is_empty() {
            let text_so_far = backend.decode(&generated)?;
            if params.stop.iter().any(|s| text_so_far.contains(s)) {
                break;
            }
        }

        let pos = tokens.len() + generated.len() - 1;
        let logits = match Self::catch_mlx("stream forward", || backend.forward(&[next_token], pos)) {
            Ok(v) => v,
            Err(e) => {
                // Release the backend lock first, then evict the cache entry.
                drop(guard);
                cache.invalidate(&model_id);
                return Err(e);
            }
        };
        next_token = Self::sample_from_logits(&logits, params)?;
    }

    let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
    let _ = tx.blocking_send(stream::StreamEvent::Done {
        text,
        tool_calls: vec![],
    });
    Ok(())
}
1918
1919 pub async fn route_context_snapshot(
1921 &self,
1922 prompt: &str,
1923 workload: RoutingWorkload,
1924 has_tools: bool,
1925 has_vision: bool,
1926 ) -> AdaptiveRoutingDecision {
1927 let tracker = self.outcome_tracker.read().await;
1928 self.adaptive_router.route_context_aware(
1929 prompt,
1930 0,
1931 &self.unified_registry,
1932 &tracker,
1933 has_tools,
1934 has_vision,
1935 workload,
1936 )
1937 }
1938
1939 pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
1942 Ok(self.generate_tracked(req).await?.text)
1943 }
1944
1945 pub async fn tokenize(
1957 &self,
1958 model: &str,
1959 text: &str,
1960 ) -> Result<Vec<u32>, InferenceError> {
1961 self.assert_local_for_tokenize(model)?;
1962
1963 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1964 {
1965 let schema = self
1966 .unified_registry
1967 .get(model)
1968 .or_else(|| self.unified_registry.find_by_name(model))
1969 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1970 .clone();
1971 let handle = self.ensure_mlx_backend(&schema).await?;
1972 let guard = handle.lock().map_err(|_| {
1973 InferenceError::InferenceFailed(format!(
1974 "MLX backend mutex poisoned for {}", schema.id
1975 ))
1976 })?;
1977 return guard.tokenize_raw(text);
1978 }
1979
1980 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1981 {
1982 self.ensure_backend(model).await?;
1983 let read = self.backend.read().await;
1984 let backend = read.as_ref().ok_or_else(|| {
1985 InferenceError::InferenceFailed(
1986 "candle backend missing after ensure_backend".to_string(),
1987 )
1988 })?;
1989 backend.tokenize_raw(text)
1990 }
1991 }
1992
1993 pub async fn detokenize(
1995 &self,
1996 model: &str,
1997 tokens: &[u32],
1998 ) -> Result<String, InferenceError> {
1999 self.assert_local_for_tokenize(model)?;
2000
2001 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2002 {
2003 let schema = self
2004 .unified_registry
2005 .get(model)
2006 .or_else(|| self.unified_registry.find_by_name(model))
2007 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
2008 .clone();
2009 let handle = self.ensure_mlx_backend(&schema).await?;
2010 let guard = handle.lock().map_err(|_| {
2011 InferenceError::InferenceFailed(format!(
2012 "MLX backend mutex poisoned for {}", schema.id
2013 ))
2014 })?;
2015 return guard.detokenize_raw(tokens);
2016 }
2017
2018 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2019 {
2020 self.ensure_backend(model).await?;
2021 let read = self.backend.read().await;
2022 let backend = read.as_ref().ok_or_else(|| {
2023 InferenceError::InferenceFailed(
2024 "candle backend missing after ensure_backend".to_string(),
2025 )
2026 })?;
2027 backend.detokenize_raw(tokens)
2028 }
2029 }
2030
2031 fn assert_local_for_tokenize(&self, model: &str) -> Result<(), InferenceError> {
2035 if let Some(schema) = self
2036 .unified_registry
2037 .get(model)
2038 .or_else(|| self.unified_registry.find_by_name(model))
2039 {
2040 if !schema.is_local() {
2041 return Err(InferenceError::UnsupportedMode {
2042 mode: "tokenize/detokenize",
2043 backend: "remote",
2044 reason:
2045 "remote provider tokenizer is not exposed by the runtime; \
2046 use a local model (Qwen3 GGUF / MLX) for tokenizer-correctness checks",
2047 });
2048 }
2049 }
2050 Ok(())
2052 }
2053
/// Runs `f` and converts a panic inside it into an `InferenceFailed` error.
///
/// `context` names the operation and is included in the error message. The
/// panic message is recovered from the payload when it is a `&str` or a
/// `String` (what `panic!` produces); any other payload type is reported
/// generically, because a `Box<dyn Any>` carries no printable detail.
///
/// The closure is wrapped in `AssertUnwindSafe`: after a caught panic the
/// state touched by `f` may be inconsistent, which is why the callers in this
/// file invalidate the cached backend when this returns an error from a
/// forward pass.
///
/// # Errors
/// Whatever `f` returns, or `InferenceFailed` if `f` panicked.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn catch_mlx<F, T>(context: &str, f: F) -> Result<T, InferenceError>
where
    F: FnOnce() -> Result<T, InferenceError>,
{
    std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)).unwrap_or_else(|payload| {
        // Recover the human-readable message instead of the opaque
        // `Any { .. }` that `{:?}` would print.
        let message = payload
            .downcast_ref::<&str>()
            .map(|s| (*s).to_string())
            .or_else(|| payload.downcast_ref::<String>().cloned())
            .unwrap_or_else(|| "non-string panic payload".to_string());
        Err(InferenceError::InferenceFailed(format!(
            "MLX panicked during {context}: {message}"
        )))
    })
}
2069
/// Runs a complete (non-streaming) generation on the MLX backend for
/// `model_id`.
///
/// Returns the generated text after `strip_thinking`, together with the time
/// in milliseconds from entry until the first token was sampled (`None` when
/// the prompt encoded to no tokens). Generation stops at either
/// end-of-sequence token, when a stop string appears in the decoded output,
/// or after `max_tokens` tokens.
///
/// When a forward pass fails or panics, the backend lock is released and the
/// cached backend for `model_id` is invalidated before the error is returned.
///
/// # Errors
/// `InferenceFailed` when `model_id` is not in the registry or the backend
/// mutex is poisoned, plus any error from preparing the backend, encoding,
/// the forward pass, sampling or decoding.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_mlx(
    &self,
    req: GenerateRequest,
    model_id: &str,
) -> Result<(String, Option<u64>), InferenceError> {
    let started = std::time::Instant::now();

    let Some(schema) = self.unified_registry.get(model_id).cloned() else {
        return Err(InferenceError::InferenceFailed(format!(
            "generate_mlx: unknown schema id {model_id}"
        )));
    };
    let handle = self.ensure_mlx_backend(&schema).await?;
    let Ok(mut guard) = handle.lock() else {
        return Err(InferenceError::InferenceFailed(format!(
            "MLX backend mutex poisoned for {model_id}"
        )));
    };
    let mlx: &mut backend::MlxBackend = &mut *guard;
    mlx.clear_kv_cache();

    let prompt_text = tasks::generate::apply_chat_template(
        &req.prompt,
        req.context.as_deref(),
        req.params.thinking,
    );
    let encoded = mlx.encode(&prompt_text)?;
    let eos = mlx.eos_token_id();
    let eos_alt = mlx.token_id("<|im_end|>");
    let params = &req.params;

    if encoded.is_empty() {
        return Ok((String::new(), None));
    }

    // Reserve room for the completion (at most a quarter of the context) and
    // keep only the most recent prompt tokens that fit in the remainder.
    let max_ctx = mlx.context_length();
    let reserved = params.max_tokens.min(max_ctx / 4);
    let prompt_budget = max_ctx.saturating_sub(reserved);
    let overflow = encoded.len().saturating_sub(prompt_budget);
    let tokens = if overflow > 0 {
        encoded[overflow..].to_vec()
    } else {
        encoded
    };

    let mut generated = Vec::new();

    let prefill_logits = match Self::catch_mlx("prefill", || mlx.forward(&tokens, 0)) {
        Ok(logits) => logits,
        Err(err) => {
            // Release the backend lock before evicting the cache entry.
            drop(guard);
            self.mlx_backends.invalidate(model_id);
            return Err(err);
        }
    };
    let mut next_token = Self::sample_from_logits(&prefill_logits, params)?;
    let ttft_ms = Some(started.elapsed().as_millis() as u64);

    for _ in 0..params.max_tokens {
        let at_end = eos.is_some_and(|id| next_token == id)
            || eos_alt.is_some_and(|id| next_token == id);
        if at_end {
            break;
        }

        generated.push(next_token);

        // Stop strings are matched against the whole decoded output so far.
        if !params.stop.is_empty() {
            let so_far = mlx.decode(&generated)?;
            if params.stop.iter().any(|s| so_far.contains(s)) {
                break;
            }
        }

        let position = tokens.len() + generated.len() - 1;
        let step_logits = match Self::catch_mlx("forward", || mlx.forward(&[next_token], position)) {
            Ok(logits) => logits,
            Err(err) => {
                // Release the backend lock before evicting the cache entry.
                drop(guard);
                self.mlx_backends.invalidate(model_id);
                return Err(err);
            }
        };
        next_token = Self::sample_from_logits(&step_logits, params)?;
    }

    let decoded = mlx.decode(&generated)?;
    let text = tasks::generate::strip_thinking(&decoded, params.thinking);
    Ok((text, ttft_ms))
}
2169
/// Picks the next token id from raw `logits`.
///
/// - `temperature <= 0`: greedy, the index of the largest logit.
/// - otherwise: softmax with temperature, optional nucleus (`top_p < 1.0`)
///   filtering, then a single random draw from the resulting distribution.
///
/// Only tokens with strictly positive probability can be returned by the
/// random draw. If the softmax is degenerate (its sum is zero, infinite or
/// NaN, e.g. because of NaN logits or a temperature that underflows to zero
/// in `f32`), the function falls back to the greedy choice instead of
/// returning an arbitrary index.
///
/// # Errors
/// `InferenceFailed("empty logits")` when `logits` is empty, for every
/// temperature.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn sample_from_logits(logits: &[f32], params: &GenerateParams) -> Result<u32, InferenceError> {
    // Greedy choice; also the fallback whenever sampling is not meaningful.
    let argmax = || {
        logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx as u32)
            .ok_or_else(|| InferenceError::InferenceFailed("empty logits".into()))
    };

    // Empty input yields the "empty logits" error on every path, instead of
    // underflowing `len() - 1` further down.
    if logits.is_empty() || params.temperature <= 0.0 {
        return argmax();
    }

    // Softmax with temperature; subtracting the maximum keeps `exp` bounded.
    let temp = params.temperature as f32;
    let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits
        .iter()
        .map(|&l| ((l - max_logit) / temp).exp())
        .collect();
    let sum: f32 = probs.iter().sum();
    if !sum.is_finite() || sum <= 0.0 {
        return argmax();
    }
    for p in &mut probs {
        *p /= sum;
    }

    if params.top_p < 1.0 {
        // Nucleus filtering: keep the most likely tokens until their
        // cumulative mass exceeds `top_p`, zero the rest, renormalize.
        let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let top_p = params.top_p as f32;
        let mut cumsum = 0.0;
        let mut cutoff_idx = indexed.len();
        for (i, &(_, p)) in indexed.iter().enumerate() {
            cumsum += p;
            if cumsum > top_p {
                cutoff_idx = i + 1;
                break;
            }
        }
        for &(token, _) in &indexed[cutoff_idx..] {
            probs[token] = 0.0;
        }
        let kept: f32 = probs.iter().sum();
        if kept > 0.0 {
            for p in &mut probs {
                *p /= kept;
            }
        }
    }

    use rand::Rng;
    let mut rng = rand::rng();
    let r: f32 = rng.random();
    let mut cumsum = 0.0;
    let mut last_supported = None;
    for (i, &p) in probs.iter().enumerate() {
        // Skip filtered-out tokens so a draw of exactly 0.0 cannot select one.
        if p <= 0.0 {
            continue;
        }
        last_supported = Some(i);
        cumsum += p;
        if cumsum >= r {
            return Ok(i as u32);
        }
    }
    // Rounding can leave the cumulative sum slightly below `r`; settle on the
    // last token that actually had probability mass.
    match last_supported {
        Some(i) => Ok(i as u32),
        None => argmax(),
    }
}
2237
2238 pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
2241 let instruction = req
2242 .instruction
2243 .as_deref()
2244 .unwrap_or("Retrieve relevant memory facts");
2245
2246 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2247 {
2248 let model_id = self.ensure_mlx_embedding_backend().await?;
2249 let schema = self
2250 .unified_registry
2251 .get(&model_id)
2252 .cloned()
2253 .ok_or_else(|| InferenceError::InferenceFailed(format!(
2254 "embed: unknown schema id {model_id}"
2255 )))?;
2256 let handle = self.ensure_mlx_backend(&schema).await?;
2257 let mut guard = handle.lock().map_err(|_| {
2258 InferenceError::InferenceFailed(format!(
2259 "MLX embedding backend mutex poisoned for {model_id}"
2260 ))
2261 })?;
2262 let backend: &mut backend::MlxBackend = &mut *guard;
2263
2264 let mut results = Vec::with_capacity(req.texts.len());
2265 for text in &req.texts {
2266 let embedding = if req.is_query {
2267 backend.embed_query(text, instruction)?
2268 } else {
2269 backend.embed_one(text)?
2270 };
2271 results.push(embedding);
2272 }
2273 return Ok(results);
2274 }
2275
2276 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2277 {
2278 self.ensure_embedding_backend().await?;
2279 let mut write = self.embedding_backend.write().await;
2280 let backend = write.as_mut().unwrap();
2281
2282 let mut results = Vec::with_capacity(req.texts.len());
2283 for text in &req.texts {
2284 let embedding = if req.is_query {
2285 backend.embed_query(text, instruction)?
2286 } else {
2287 backend.embed_one(text)?
2288 };
2289 results.push(embedding);
2290 }
2291 Ok(results)
2292 }
2293 }
2294
2295 pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResult, InferenceError> {
2325 if req.documents.is_empty() {
2326 return Ok(RerankResult { ranked: Vec::new(), model_used: None });
2327 }
2328
2329 let model_name = match req.model.clone() {
2330 Some(m) => m,
2331 None => self
2332 .preferred_model_for_capability(ModelCapability::Rerank)
2333 .map(str::to_string)
2334 .ok_or_else(|| {
2335 InferenceError::InferenceFailed(
2336 "no reranker model available — pull a Qwen3-Reranker model first".into(),
2337 )
2338 })?,
2339 };
2340
2341 let schema = self
2342 .unified_registry
2343 .find_by_name(&model_name)
2344 .or_else(|| self.unified_registry.get(&model_name))
2345 .cloned()
2346 .ok_or_else(|| {
2347 InferenceError::InferenceFailed(format!(
2348 "rerank: unknown reranker model {model_name}"
2349 ))
2350 })?;
2351 if !schema.has_capability(ModelCapability::Rerank) {
2352 return Err(InferenceError::InferenceFailed(format!(
2353 "model {} does not declare the Rerank capability",
2354 schema.name
2355 )));
2356 }
2357
2358 let instruction = req.instruction.as_deref().unwrap_or(
2359 "Given a web search query, retrieve relevant passages that answer the query",
2360 );
2361
2362 let mut scored: Vec<RerankedDocument> = Vec::with_capacity(req.documents.len());
2363 for (idx, doc) in req.documents.iter().enumerate() {
2364 let prompt = rerank_prompt(instruction, &req.query, doc);
2365 let gen_req = GenerateRequest {
2366 prompt,
2367 model: Some(schema.id.clone()),
2368 params: tasks::generate::GenerateParams {
2369 temperature: 0.0,
2370 max_tokens: 3,
2374 thinking: tasks::generate::ThinkingMode::Off,
2375 ..Default::default()
2376 },
2377 context: None,
2378 tools: None,
2379 images: None,
2380 messages: None,
2381 cache_control: false,
2382 response_format: None,
2383 intent: None,
2384 };
2385 let out = self.generate(gen_req).await?;
2386 let score = score_from_rerank_output(&out, &schema.name);
2387 scored.push(RerankedDocument {
2388 index: idx,
2389 score,
2390 document: doc.clone(),
2391 });
2392 }
2393
2394 scored.sort_by(|a, b| {
2397 b.score
2398 .partial_cmp(&a.score)
2399 .unwrap_or(std::cmp::Ordering::Equal)
2400 .then_with(|| a.index.cmp(&b.index))
2401 });
2402 if let Some(n) = req.top_n {
2403 scored.truncate(n);
2404 }
2405
2406 Ok(RerankResult {
2407 ranked: scored,
2408 model_used: Some(schema.name),
2409 })
2410 }
2411
2412 pub async fn ground(&self, req: GroundRequest) -> Result<GroundResult, InferenceError> {
2422 let model_name = match req.model.clone() {
2423 Some(m) => m,
2424 None => self
2425 .preferred_model_for_capability(ModelCapability::Grounding)
2426 .map(str::to_string)
2427 .ok_or_else(|| {
2428 InferenceError::InferenceFailed(
2429 "no grounding-capable model available — pull a Qwen2.5-VL model first"
2430 .into(),
2431 )
2432 })?,
2433 };
2434
2435 let gen_req = GenerateRequest {
2436 prompt: req.prompt.clone(),
2437 model: Some(model_name),
2438 params: GenerateParams::default(),
2439 context: None,
2440 tools: None,
2441 images: Some(vec![req.image.clone()]),
2442 messages: None,
2443 cache_control: false,
2444 response_format: None,
2445 intent: None,
2446 };
2447 let result = self.generate_tracked(gen_req).await?;
2448 Ok(GroundResult {
2449 boxes: result.bounding_boxes,
2450 raw_text: result.text,
2451 model_used: Some(result.model_used),
2452 })
2453 }
2454
2455 pub async fn classify(
2458 &self,
2459 req: ClassifyRequest,
2460 ) -> Result<Vec<ClassifyResult>, InferenceError> {
2461 let model = match req
2462 .model
2463 .clone()
2464 .or_else(|| self.preferred_model_for_capability(ModelCapability::Classify).map(str::to_string))
2465 {
2466 Some(m) => m,
2467 None => {
2468 let m = self.router.route_small(&self.registry);
2469 debug!(model = %m, "auto-routed classify request");
2470 m
2471 }
2472 };
2473
2474 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2477 {
2478 return self.classify_via_generate(req, &model).await;
2479 }
2480
2481 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2482 {
2483 self.ensure_backend(&model).await?;
2484 let mut write = self.backend.write().await;
2485 let backend = write.as_mut().unwrap();
2486 tasks::classify::classify(backend, req).await
2487 }
2488 }
2489
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
/// Classifies by prompting a generative model and string-matching its answer.
///
/// The labels are presented as a numbered list and the model is asked to
/// reply with only the category name (temperature 0, at most 32 tokens,
/// thinking off). Each label is then scored against the lowercased, trimmed
/// response:
///
/// * `1.0` when the response equals the label,
/// * `0.8` when the response contains the label,
/// * otherwise `0.5 ×` the fraction of the label's whitespace-separated words
///   found in the response (`0.0` for a label with no words).
///
/// Results are sorted by descending score and, when the total is positive,
/// normalised so the scores sum to 1.
async fn classify_via_generate(
    &self,
    req: ClassifyRequest,
    model: &str,
) -> Result<Vec<ClassifyResult>, InferenceError> {
    let numbered: Vec<String> = req
        .labels
        .iter()
        .zip(1usize..)
        .map(|(label, position)| format!("{}. {}", position, label))
        .collect();
    let labels_str = numbered.join("\n");

    let prompt = format!(
        "Classify the following text into one of these categories:\n\
         {labels_str}\n\n\
         Text: {}\n\n\
         Respond with ONLY the category name, nothing else.",
        req.text
    );

    let response = self
        .generate(GenerateRequest {
            prompt,
            model: Some(model.to_string()),
            params: tasks::generate::GenerateParams {
                temperature: 0.0,
                max_tokens: 32,
                thinking: tasks::generate::ThinkingMode::Off,
                ..Default::default()
            },
            context: None,
            tools: None,
            images: None,
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        })
        .await?;
    let response_lower = response.trim().to_lowercase();

    let mut results: Vec<ClassifyResult> = Vec::with_capacity(req.labels.len());
    for label in req.labels.iter() {
        let needle = label.to_lowercase();
        let score = if response_lower == needle {
            1.0
        } else if response_lower.contains(&needle) {
            0.8
        } else {
            // Partial credit: share of the label's words present in the response.
            let words: Vec<&str> = needle.split_whitespace().collect();
            if words.is_empty() {
                0.0
            } else {
                let hits = words
                    .iter()
                    .filter(|word| response_lower.contains(**word))
                    .count();
                0.5 * (hits as f64 / words.len() as f64)
            }
        };
        results.push(ClassifyResult {
            label: label.clone(),
            score,
        });
    }

    // Stable sort: labels with equal scores keep their request order.
    results.sort_by(|lhs, rhs| {
        rhs.score
            .partial_cmp(&lhs.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let total: f64 = results.iter().map(|entry| entry.score).sum();
    if total > 0.0 {
        results.iter_mut().for_each(|entry| entry.score /= total);
    }

    Ok(results)
}
2579
2580 pub async fn transcribe(
2582 &self,
2583 req: TranscribeRequest,
2584 ) -> Result<TranscribeResult, InferenceError> {
2585 let candidates =
2586 self.speech_candidates(ModelCapability::SpeechToText, req.model.as_deref())?;
2587 let mut last_error = None;
2588
2589 for schema in candidates {
2590 let result = match &schema.source {
2591 ModelSource::Mlx { .. } => self.transcribe_local_mlx(&schema, &req).await,
2592 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2593 self.transcribe_elevenlabs(&schema, &req).await
2594 }
2595 _ => Err(InferenceError::InferenceFailed(format!(
2596 "speech-to-text not implemented for model source: {}",
2597 schema.id
2598 ))),
2599 };
2600
2601 match result {
2602 Ok(result) => return Ok(result),
2603 Err(err) => last_error = Some(err),
2604 }
2605 }
2606
2607 Err(last_error.unwrap_or_else(|| {
2608 InferenceError::InferenceFailed("no speech-to-text models available".into())
2609 }))
2610 }
2611
2612 pub async fn synthesize(
2614 &self,
2615 req: SynthesizeRequest,
2616 ) -> Result<SynthesizeResult, InferenceError> {
2617 let candidates =
2618 self.speech_candidates(ModelCapability::TextToSpeech, req.model.as_deref())?;
2619 let mut last_error = None;
2620
2621 for schema in candidates {
2622 let result = match &schema.source {
2623 ModelSource::Mlx { .. } => self.synthesize_local_mlx(&schema, &req).await,
2624 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2625 self.synthesize_elevenlabs(&schema, &req).await
2626 }
2627 _ => Err(InferenceError::InferenceFailed(format!(
2628 "text-to-speech not implemented for model source: {}",
2629 schema.id
2630 ))),
2631 };
2632
2633 match result {
2634 Ok(result) => return Ok(result),
2635 Err(err) => last_error = Some(err),
2636 }
2637 }
2638
2639 Err(last_error.unwrap_or_else(|| {
2640 InferenceError::InferenceFailed("no text-to-speech models available".into())
2641 }))
2642 }
2643
2644 pub async fn generate_image(
2646 &self,
2647 req: GenerateImageRequest,
2648 ) -> Result<GenerateImageResult, InferenceError> {
2649 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2656 {
2657 use crate::backend::external_flux;
2658 let backend = std::env::var("CAR_IMAGE_BACKEND")
2659 .unwrap_or_else(|_| "native".to_string());
2660 let use_external = match backend.as_str() {
2661 "external" => true,
2662 "native" => false,
2663 _ => external_flux::is_available() && backend == "auto-external",
2665 };
2666 if use_external {
2667 tracing::info!(
2668 "routing image generation to external mflux \
2669 (set CAR_IMAGE_BACKEND=native to use the Rust port)"
2670 );
2671 let mut req = req;
2672 req.model = self.resolve_external_hf_repo(
2673 req.model.as_deref(),
2674 ModelCapability::ImageGeneration,
2675 );
2676 return external_flux::generate_image(&req);
2677 }
2678 tracing::info!("using native Rust MLX Flux backend");
2679 }
2680
2681 let candidates =
2682 self.media_generation_candidates(ModelCapability::ImageGeneration, req.model.as_deref())?;
2683 let mut last_error = None;
2684
2685 for schema in candidates {
2686 let result = match &schema.source {
2687 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2688 ModelSource::Mlx { .. } => self.generate_image_native_mlx(&schema, &req).await,
2689 _ => Err(InferenceError::InferenceFailed(format!(
2690 "image generation not implemented for model source: {}",
2691 schema.id
2692 ))),
2693 };
2694
2695 match result {
2696 Ok(result) => return Ok(result),
2697 Err(err) => last_error = Some(err),
2698 }
2699 }
2700
2701 Err(last_error.unwrap_or_else(|| {
2702 InferenceError::InferenceFailed("no image generation models available".into())
2703 }))
2704 }
2705
2706 pub async fn generate_image_batch(
2722 &self,
2723 req: GenerateImageRequest,
2724 ) -> Result<Vec<GenerateImageResult>, InferenceError> {
2725 let count = req.variant_count.unwrap_or(1).max(1);
2726 if count == 1 {
2727 return self.generate_image(req).await.map(|r| vec![r]);
2728 }
2729 let base_seed = req.seed.unwrap_or(0);
2730 let mut results = Vec::with_capacity(count as usize);
2731 for i in 0..count {
2732 let mut variant_req = req.clone();
2737 variant_req.seed = Some(base_seed.wrapping_add(i as u64));
2738 variant_req.variant_count = Some(1);
2742 results.push(self.generate_image(variant_req).await?);
2743 }
2744 Ok(results)
2745 }
2746
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
/// Generates an image with the in-process MLX Flux backend.
///
/// Ensures the model is present locally, obtains the Flux backend from
/// `flux_cache` (loading it from the model directory on a miss), then runs
/// generation on a blocking task while holding the backend's mutex.
///
/// # Errors
///
/// Propagates download/load/generation errors, and returns
/// [`InferenceError::InferenceFailed`] if the backend mutex is poisoned or the
/// blocking task fails to join.
async fn generate_image_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateImageRequest,
) -> Result<GenerateImageResult, InferenceError> {
    let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let estimated_size = backend_cache::estimate_model_size(&model_dir);
    let backend = self.flux_cache.get_or_load(&schema.id, estimated_size, || {
        backend::mlx_flux::FluxBackend::load(&model_dir)
    })?;

    // The blocking task must own everything it touches.
    let owned_req = req.clone();
    let joined = tokio::task::spawn_blocking(
        move || -> Result<GenerateImageResult, InferenceError> {
            let mut guard = backend.lock().map_err(|_| {
                InferenceError::InferenceFailed("flux backend mutex poisoned".into())
            })?;
            guard.generate(&owned_req)
        },
    )
    .await;

    match joined {
        Ok(outcome) => outcome,
        Err(e) => Err(InferenceError::InferenceFailed(format!(
            "flux task join: {e}"
        ))),
    }
}
2774
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
/// Generates a video with the in-process MLX LTX backend.
///
/// Ensures the model is present locally, obtains the LTX backend from
/// `ltx_cache` (loading it from the model directory on a miss), then runs
/// generation on a blocking task while holding the backend's mutex.
///
/// # Errors
///
/// Propagates download/load/generation errors, and returns
/// [`InferenceError::InferenceFailed`] if the backend mutex is poisoned or the
/// blocking task fails to join.
async fn generate_video_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateVideoRequest,
) -> Result<GenerateVideoResult, InferenceError> {
    let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let estimated_size = backend_cache::estimate_model_size(&model_dir);
    let backend = self.ltx_cache.get_or_load(&schema.id, estimated_size, || {
        backend::mlx_ltx::LtxBackend::load(&model_dir)
    })?;

    // The blocking task must own everything it touches.
    let owned_req = req.clone();
    let joined = tokio::task::spawn_blocking(
        move || -> Result<GenerateVideoResult, InferenceError> {
            let mut guard = backend.lock().map_err(|_| {
                InferenceError::InferenceFailed("ltx backend mutex poisoned".into())
            })?;
            guard.generate(&owned_req)
        },
    )
    .await;

    match joined {
        Ok(outcome) => outcome,
        Err(e) => Err(InferenceError::InferenceFailed(format!(
            "ltx task join: {e}"
        ))),
    }
}
2799
2800 pub async fn generate_video(
2802 &self,
2803 req: GenerateVideoRequest,
2804 ) -> Result<GenerateVideoResult, InferenceError> {
2805 if let Err(msg) = req.validate() {
2808 return Err(InferenceError::InferenceFailed(format!(
2809 "invalid GenerateVideoRequest: {}",
2810 msg
2811 )));
2812 }
2813 if req.requires_audio_passthrough_opt_in() {
2824 return Err(InferenceError::InferenceFailed(
2825 "audio_path is set but audio_passthrough is false. \
2826 Audio-reference conditioning (driving video timing off the audio bytes) \
2827 is not yet implemented — tracked at Parslee-ai/car#130. \
2828 For the muxing-only workflow (audio path recorded for downstream tooling; \
2829 video frames stay text-only), set audio_passthrough=true on the request \
2830 (the `car video --audio-mux` CLI flag does this implicitly)."
2831 .to_string(),
2832 ));
2833 }
2834 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2839 {
2840 use crate::backend::external_ltx;
2841 let backend = std::env::var("CAR_VIDEO_BACKEND")
2846 .unwrap_or_else(|_| "native".to_string());
2847 let use_external = match backend.as_str() {
2848 "external" => true,
2849 "native" => false,
2850 "auto-external" => external_ltx::is_available(),
2854 _ => false,
2855 };
2856 if use_external {
2857 tracing::info!("CAR_VIDEO_BACKEND requested external LTX routing for LTX-family models");
2858 } else {
2859 tracing::info!("using family-aware MLX video routing");
2860 }
2861 }
2862
2863 let candidates =
2864 self.media_generation_candidates(ModelCapability::VideoGeneration, req.model.as_deref())?;
2865 let mut last_error = None;
2866
2867 for schema in candidates {
2868 let result = match &schema.source {
2869 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2870 ModelSource::Mlx { hf_repo, .. } => {
2871 if crate::backend::external_mlx_video::is_wan_family(&schema) {
2872 match self.unified_registry.ensure_local(&schema.id).await {
2873 Ok(model_dir) => crate::backend::external_mlx_video::generate_wan_video(&schema, &model_dir, &req),
2874 Err(err) => Err(err),
2875 }
2876 } else {
2877 let backend = std::env::var("CAR_VIDEO_BACKEND")
2878 .unwrap_or_else(|_| "native".to_string());
2879 let use_external_ltx = match backend.as_str() {
2880 "external" => true,
2881 "native" => false,
2882 "auto-external" => crate::backend::external_ltx::is_available(),
2883 _ => false,
2884 };
2885 if use_external_ltx {
2886 let mut req = req.clone();
2887 req.model = Some(hf_repo.clone());
2888 crate::backend::external_ltx::generate_video(&req)
2889 } else {
2890 self.generate_video_native_mlx(&schema, &req).await
2891 }
2892 }
2893 }
2894 _ => Err(InferenceError::InferenceFailed(format!(
2895 "video generation not implemented for model source: {}",
2896 schema.id
2897 ))),
2898 };
2899
2900 match result {
2901 Ok(result) => return Ok(result),
2902 Err(err) => last_error = Some(err),
2903 }
2904 }
2905
2906 Err(last_error.unwrap_or_else(|| {
2907 InferenceError::InferenceFailed("no video generation models available".into())
2908 }))
2909 }
2910
2911 pub fn list_models_unified(&self) -> Vec<ModelInfo> {
2913 self.unified_registry
2914 .list()
2915 .iter()
2916 .map(|m| ModelInfo::from(*m))
2917 .collect()
2918 }
2919
2920 pub fn available_model_upgrades(&self) -> Vec<ModelUpgrade> {
2922 self.unified_registry.available_upgrades()
2923 }
2924
2925 pub fn list_schemas(&self) -> Vec<ModelSchema> {
2928 self.unified_registry.list().into_iter().cloned().collect()
2929 }
2930
2931 pub fn list_models(&self) -> Vec<models::ModelInfo> {
2932 self.registry.list_models()
2933 }
2934
2935 pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
2937 let schema = self
2938 .unified_registry
2939 .find_by_name(name)
2940 .or_else(|| self.unified_registry.get(name))
2941 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2942 self.unified_registry.ensure_local(&schema.id).await
2943 }
2944
2945 pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
2947 let schema = self
2948 .unified_registry
2949 .get(name)
2950 .or_else(|| {
2951 self.unified_registry
2952 .list()
2953 .into_iter()
2954 .find(|schema| schema.name.eq_ignore_ascii_case(name))
2955 })
2956 .or_else(|| self.unified_registry.find_by_name(name))
2957 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2958 let model_dir = self.unified_registry.models_dir().join(&schema.name);
2959 if model_dir.exists() {
2960 std::fs::remove_dir_all(&model_dir)?;
2961 }
2962 match &schema.source {
2963 ModelSource::Mlx { hf_repo, .. } => {
2964 remove_huggingface_repo_cache(hf_repo)?;
2965 }
2966 ModelSource::Local {
2967 hf_repo,
2968 tokenizer_repo,
2969 ..
2970 } => {
2971 remove_huggingface_repo_cache(hf_repo)?;
2972 remove_huggingface_repo_cache(tokenizer_repo)?;
2973 }
2974 _ => {}
2975 }
2976 Ok(())
2977 }
2978
2979 pub fn register_model(&mut self, schema: ModelSchema) {
2981 self.unified_registry.register(schema);
2982 }
2983
2984 pub async fn discover_vllm_mlx_models(&mut self) -> usize {
2987 let config = vllm_mlx::VllmMlxConfig::default();
2988 if !config.auto_discover {
2989 return 0;
2990 }
2991 vllm_mlx::discover_and_register(&config, &mut self.unified_registry).await
2992 }
2993
2994 pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
2996 self.outcome_tracker.clone()
2997 }
2998
2999 async fn auto_save_outcomes(&self) {
3001 if let Err(e) = self.save_outcomes().await {
3002 tracing::debug!("auto-save outcomes failed: {}", e);
3003 }
3004 if let Err(e) = self.save_key_pool_stats().await {
3005 tracing::debug!("auto-save key pool stats failed: {}", e);
3006 }
3007 }
3008
3009 pub async fn save_outcomes(&self) -> Result<(), std::io::Error> {
3011 let tracker = self.outcome_tracker.read().await;
3012 let path = self.config.models_dir.join("outcome_profiles.json");
3013 tracker.save_to_file(&path)
3014 }
3015
3016 pub async fn save_key_pool_stats(&self) -> Result<(), std::io::Error> {
3018 let path = self.config.models_dir.join("key_pool_stats.json");
3019 self.remote_backend.key_pool.save_stats(&path).await
3020 }
3021
3022 pub async fn key_pool_stats(
3024 &self,
3025 ) -> std::collections::HashMap<String, Vec<key_pool::KeyStats>> {
3026 self.remote_backend.key_pool.all_stats().await
3027 }
3028
3029 pub async fn export_profiles(&self) -> Vec<ModelProfile> {
3031 let tracker = self.outcome_tracker.read().await;
3032 tracker.export_profiles()
3033 }
3034
3035 pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
3037 let mut tracker = self.outcome_tracker.write().await;
3038 tracker.import_profiles(profiles);
3039 }
3040
3041 pub async fn prepare_speech_runtime(&self) -> Result<PathBuf, InferenceError> {
3044 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3045 {
3046 Ok(self.config.models_dir.clone())
3048 }
3049 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3050 {
3051 Ok(self.ensure_speech_runtime().await?.root)
3052 }
3053 }
3054
3055 pub fn set_speech_policy(&mut self, policy: SpeechPolicy) {
3057 self.speech_policy = policy;
3058 }
3059
3060 pub fn set_routing_config(&mut self, config: RoutingConfig) {
3061 self.adaptive_router.set_config(config);
3062 }
3063
3064 pub async fn install_curated_speech(
3066 &mut self,
3067 ) -> Result<Vec<SpeechInstallReport>, InferenceError> {
3068 let _runtime_root = self.prepare_speech_runtime().await?;
3069 let schemas = self.list_schemas();
3070 let mut repos = Vec::new();
3071 for schema in &schemas {
3072 if !schema.is_mlx() || !schema.tags.iter().any(|tag| tag == "speech") {
3073 continue;
3074 }
3075 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3076 if !repos.iter().any(|existing: &String| existing == hf_repo) {
3077 repos.push(hf_repo.clone());
3078 }
3079 }
3080 }
3081
3082 let mut installed = Vec::new();
3083 for repo in repos {
3084 let (snapshot_path, files_downloaded) = download_hf_repo_snapshot(&repo).await?;
3085 let name = schemas
3086 .iter()
3087 .find(|schema| {
3088 matches!(&schema.source, ModelSource::Mlx { hf_repo, .. } if hf_repo == &repo)
3089 })
3090 .map(|schema| schema.name.clone())
3091 .unwrap_or_else(|| repo.clone());
3092 installed.push(SpeechInstallReport {
3093 name,
3094 hf_repo: repo,
3095 snapshot_path,
3096 files_downloaded,
3097 });
3098 }
3099
3100 self.unified_registry.refresh_availability();
3101 Ok(installed)
3102 }
3103
3104
3105 pub fn speech_health(&self) -> SpeechHealthReport {
3107 let local_stt_default = self
3108 .speech_health_default_name(ModelCapability::SpeechToText, true, false);
3109 let local_tts_default = self
3110 .speech_health_default_name(ModelCapability::TextToSpeech, true, false);
3111 let remote_stt_default = self
3112 .speech_health_default_name(ModelCapability::SpeechToText, false, true);
3113 let remote_tts_default = self
3114 .speech_health_default_name(ModelCapability::TextToSpeech, false, true);
3115
3116 let mut local_models = Vec::new();
3117 let mut remote_models = Vec::new();
3118 for schema in self.list_schemas() {
3119 let capability = if schema.has_capability(ModelCapability::SpeechToText) {
3120 Some(ModelCapability::SpeechToText)
3121 } else if schema.has_capability(ModelCapability::TextToSpeech) {
3122 Some(ModelCapability::TextToSpeech)
3123 } else {
3124 None
3125 };
3126 let Some(capability) = capability else {
3127 continue;
3128 };
3129
3130 let selected_by_default = local_stt_default
3131 .as_ref()
3132 .is_some_and(|name| name == &schema.name)
3133 || local_tts_default
3134 .as_ref()
3135 .is_some_and(|name| name == &schema.name)
3136 || remote_stt_default
3137 .as_ref()
3138 .is_some_and(|name| name == &schema.name)
3139 || remote_tts_default
3140 .as_ref()
3141 .is_some_and(|name| name == &schema.name);
3142
3143 let health = SpeechModelHealth {
3144 id: schema.id.clone(),
3145 name: schema.name.clone(),
3146 provider: schema.provider.clone(),
3147 capability,
3148 is_local: schema.is_local(),
3149 available: schema.available,
3150 cached: speech_model_cached(&schema),
3151 selected_by_default,
3152 source: speech_model_source_label(&schema),
3153 };
3154 if schema.is_local() {
3155 local_models.push(health);
3156 } else {
3157 remote_models.push(health);
3158 }
3159 }
3160
3161 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3163 let runtime = SpeechRuntimeHealth {
3164 root: self.config.models_dir.clone(),
3165 installed: true,
3166 python: PathBuf::new(),
3167 stt_command: PathBuf::new(),
3168 tts_command: PathBuf::new(),
3169 configured_python: None,
3170 detected_python: None,
3171 };
3172
3173 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3174 let runtime = {
3175 let rt = SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3176 SpeechRuntimeHealth {
3177 root: rt.root.clone(),
3178 installed: rt.is_ready(),
3179 python: rt.python.clone(),
3180 stt_command: rt.stt_program.clone(),
3181 tts_command: rt.tts_program.clone(),
3182 configured_python: std::env::var("CAR_SPEECH_PYTHON")
3183 .ok()
3184 .filter(|value| !value.trim().is_empty()),
3185 detected_python: detect_speech_python(),
3186 }
3187 };
3188
3189 SpeechHealthReport {
3190 runtime,
3191 local_models,
3192 remote_models,
3193 elevenlabs_configured: car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY").is_some(),
3194 prefer_local: self.speech_policy.prefer_local,
3195 allow_remote_fallback: self.speech_policy.allow_remote_fallback,
3196 preferred_local_stt: self.speech_policy.preferred_local_stt.clone(),
3197 preferred_local_tts: self.speech_policy.preferred_local_tts.clone(),
3198 preferred_remote_stt: self.speech_policy.preferred_remote_stt.clone(),
3199 preferred_remote_tts: self.speech_policy.preferred_remote_tts.clone(),
3200 local_stt_default,
3201 local_tts_default,
3202 remote_stt_default,
3203 remote_tts_default,
3204 }
3205 }
3206
3207 pub async fn model_health(&self) -> ModelHealthReport {
3210 let schemas = self.list_schemas();
3211 let total_models = schemas.len();
3212 let available_models = schemas.iter().filter(|schema| schema.available).count();
3213 let local_models = schemas.iter().filter(|schema| schema.is_local()).count();
3214 let remote_models = total_models.saturating_sub(local_models);
3215
3216 let defaults = vec![
3217 self.model_default_health(
3218 ModelCapability::Generate,
3219 self.preferred_model_for_capability(ModelCapability::Generate)
3220 .unwrap_or(&self.config.generation_model),
3221 ),
3222 self.model_default_health(
3223 ModelCapability::Embed,
3224 self.preferred_model_for_capability(ModelCapability::Embed)
3225 .unwrap_or(&self.config.embedding_model),
3226 ),
3227 self.model_default_health(
3228 ModelCapability::Classify,
3229 self.preferred_model_for_capability(ModelCapability::Classify)
3230 .unwrap_or(&self.config.classification_model),
3231 ),
3232 ];
3233
3234 let mut providers = std::collections::BTreeMap::new();
3235 for schema in &schemas {
3236 let entry = providers
3237 .entry(schema.provider.clone())
3238 .or_insert_with(|| ProviderAccumulator {
3239 configured: false,
3240 local_models: 0,
3241 remote_models: 0,
3242 available_models: 0,
3243 capabilities: std::collections::HashSet::new(),
3244 });
3245
3246 entry.configured |= model_source_configured(schema);
3247 if schema.is_local() {
3248 entry.local_models += 1;
3249 } else {
3250 entry.remote_models += 1;
3251 }
3252 if schema.available {
3253 entry.available_models += 1;
3254 }
3255 for capability in &schema.capabilities {
3256 entry.capabilities.insert(*capability);
3257 }
3258 }
3259
3260 let providers = providers
3261 .into_iter()
3262 .map(|(provider, acc)| ModelProviderHealth {
3263 provider,
3264 configured: acc.configured,
3265 local_models: acc.local_models,
3266 remote_models: acc.remote_models,
3267 available_models: acc.available_models,
3268 capabilities: sort_capabilities(acc.capabilities.into_iter().collect()),
3269 })
3270 .collect();
3271
3272 let capabilities = all_model_capabilities()
3273 .into_iter()
3274 .map(|capability| {
3275 let relevant: Vec<&ModelSchema> = schemas
3276 .iter()
3277 .filter(|schema| schema.has_capability(capability))
3278 .collect();
3279 let available: Vec<&ModelSchema> = relevant
3280 .iter()
3281 .copied()
3282 .filter(|schema| schema.available)
3283 .collect();
3284 ModelCapabilityHealth {
3285 capability,
3286 total_models: relevant.len(),
3287 available_models: available.len(),
3288 local_available_models: available
3289 .iter()
3290 .filter(|schema| schema.is_local())
3291 .count(),
3292 remote_available_models: available
3293 .iter()
3294 .filter(|schema| !schema.is_local())
3295 .count(),
3296 }
3297 })
3298 .collect();
3299
3300 let routing = self.routing_scenarios().await;
3301 let routing_config = self.adaptive_router.config().clone();
3302 let benchmark_priors = load_benchmark_prior_health(&self.config.models_dir, &schemas);
3303
3304 ModelHealthReport {
3305 total_models,
3306 available_models,
3307 local_models,
3308 remote_models,
3309 defaults,
3310 providers,
3311 capabilities,
3312 routing_prefer_local: routing_config.prefer_local,
3313 routing_quality_first_cold_start: routing_config.quality_first_cold_start,
3314 routing_min_observations: routing_config.min_observations,
3315 routing_bootstrap_min_task_observations: routing_config.bootstrap_min_task_observations,
3316 routing_bootstrap_quality_floor: routing_config.bootstrap_quality_floor,
3317 routing_quality_weight: routing_config.quality_weight,
3318 routing_latency_weight: routing_config.latency_weight,
3319 routing_cost_weight: routing_config.cost_weight,
3320 routing_scenarios: routing,
3321 benchmark_priors,
3322 speech: self.speech_health(),
3323 }
3324 }
3325
3326 async fn routing_scenarios(&self) -> Vec<RoutingScenarioHealth> {
3327 let tracker = self.outcome_tracker.read().await;
3328 let config = self.adaptive_router.config().clone();
3329 let scenarios = [
3330 (
3331 "interactive_text",
3332 "Summarize the benefits of local-first AI routing in two sentences.",
3333 "text",
3334 RoutingWorkload::Interactive,
3335 false,
3336 false,
3337 ),
3338 (
3339 "background_code",
3340 "Write a Python function named fibonacci(n) that returns the nth Fibonacci number.",
3341 "code",
3342 RoutingWorkload::Background,
3343 false,
3344 false,
3345 ),
3346 (
3347 "interactive_tool_use",
3348 "Use the provided weather tool to get the weather for Boston.",
3349 "tool_use",
3350 RoutingWorkload::Interactive,
3351 true,
3352 false,
3353 ),
3354 (
3355 "interactive_vision",
3356 "What is in this image? Answer in one word.",
3357 "vision",
3358 RoutingWorkload::Interactive,
3359 false,
3360 true,
3361 ),
3362 ];
3363
3364 scenarios
3365 .into_iter()
3366 .map(|(name, prompt, task_family, workload, has_tools, has_vision)| {
3367 let decision = self.adaptive_router.route_context_aware(
3368 prompt,
3369 0,
3370 &self.unified_registry,
3371 &tracker,
3372 has_tools,
3373 has_vision,
3374 workload,
3375 );
3376 let quality_first_cold_start = if has_tools || has_vision {
3377 config.quality_first_cold_start
3378 } else if task_family == "code" && matches!(workload, RoutingWorkload::Background) {
3379 false
3380 } else {
3381 config.quality_first_cold_start
3382 };
3383 RoutingScenarioHealth {
3384 name: name.to_string(),
3385 task_family: task_family.to_string(),
3386 workload,
3387 has_tools,
3388 has_vision,
3389 prefer_local: if task_family == "speech" {
3390 self.speech_policy.prefer_local
3391 } else {
3392 config.prefer_local
3393 },
3394 quality_first_cold_start,
3395 bootstrap_min_task_observations: config.bootstrap_min_task_observations,
3396 bootstrap_quality_floor: config.bootstrap_quality_floor,
3397 model_id: decision.model_id,
3398 model_name: decision.model_name,
3399 reason: decision.reason,
3400 strategy: decision.strategy,
3401 }
3402 })
3403 .collect()
3404 }
3405
3406 pub async fn smoke_test_speech(
3408 &self,
3409 local: bool,
3410 remote: bool,
3411 ) -> Result<SpeechSmokeReport, InferenceError> {
3412 let mut report = SpeechSmokeReport::default();
3413
3414 if local {
3415 let tts = self
3416 .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
3417 .ok_or_else(|| {
3418 InferenceError::InferenceFailed(
3419 "no local text-to-speech model available".into(),
3420 )
3421 })?;
3422 let stt = self
3423 .preferred_speech_schema(ModelCapability::SpeechToText, true, false)
3424 .ok_or_else(|| {
3425 InferenceError::InferenceFailed(
3426 "no local speech-to-text model available".into(),
3427 )
3428 })?;
3429 report.local = Some(
3430 self.run_speech_smoke_path("local", &tts, &stt, "Testing CAR local speech path.")
3431 .await?,
3432 );
3433 } else {
3434 report.skipped.push("local".to_string());
3435 }
3436
3437 if remote {
3438 let tts = self
3439 .preferred_speech_schema(ModelCapability::TextToSpeech, false, true)
3440 .ok_or_else(|| {
3441 InferenceError::InferenceFailed(
3442 "no remote text-to-speech model available".into(),
3443 )
3444 })?;
3445 let stt = self
3446 .preferred_speech_schema(ModelCapability::SpeechToText, false, true)
3447 .ok_or_else(|| {
3448 InferenceError::InferenceFailed(
3449 "no remote speech-to-text model available".into(),
3450 )
3451 })?;
3452 report.remote = Some(
3453 self.run_speech_smoke_path("remote", &tts, &stt, "Testing CAR remote speech path.")
3454 .await?,
3455 );
3456 } else {
3457 report.skipped.push("remote".to_string());
3458 }
3459
3460 Ok(report)
3461 }
3462
3463 fn speech_candidates(
3464 &self,
3465 capability: ModelCapability,
3466 explicit: Option<&str>,
3467 ) -> Result<Vec<ModelSchema>, InferenceError> {
3468 if let Some(model) = explicit {
3469 let schema = self
3470 .unified_registry
3471 .get(model)
3472 .or_else(|| self.unified_registry.find_by_name(model))
3473 .cloned()
3474 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3475 if !schema.has_capability(capability) {
3476 return Err(InferenceError::InferenceFailed(format!(
3477 "model {} does not support {:?}",
3478 schema.name, capability
3479 )));
3480 }
3481 return Ok(vec![schema]);
3482 }
3483
3484 let mut candidates: Vec<ModelSchema> = self
3485 .unified_registry
3486 .query(&ModelFilter {
3487 capabilities: vec![capability],
3488 ..Default::default()
3489 })
3490 .into_iter()
3491 .cloned()
3492 .collect();
3493
3494 if candidates.is_empty() {
3495 return Err(InferenceError::InferenceFailed(format!(
3496 "no models registered for capability {:?}",
3497 capability
3498 )));
3499 }
3500
3501 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3502 if !self.speech_policy.allow_remote_fallback && candidates.iter().any(|model| model.is_local()) {
3503 candidates.retain(|model| model.is_local());
3504 }
3505
3506 Ok(candidates)
3507 }
3508
3509 fn resolve_external_hf_repo(
3514 &self,
3515 explicit: Option<&str>,
3516 capability: ModelCapability,
3517 ) -> Option<String> {
3518 let id = explicit?;
3519 let schema = self
3520 .unified_registry
3521 .get(id)
3522 .or_else(|| self.unified_registry.find_by_name(id))?;
3523 if !schema.has_capability(capability) {
3524 return Some(id.to_string());
3525 }
3526 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3527 return Some(hf_repo.clone());
3528 }
3529 Some(id.to_string())
3530 }
3531
3532 fn media_generation_candidates(
3533 &self,
3534 capability: ModelCapability,
3535 explicit: Option<&str>,
3536 ) -> Result<Vec<ModelSchema>, InferenceError> {
3537 if let Some(model) = explicit {
3538 let schema = self
3539 .unified_registry
3540 .get(model)
3541 .or_else(|| self.unified_registry.find_by_name(model))
3542 .cloned()
3543 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3544 if !schema.has_capability(capability) {
3545 return Err(InferenceError::InferenceFailed(format!(
3546 "model {} does not support {:?}",
3547 schema.name, capability
3548 )));
3549 }
3550 return Ok(vec![schema]);
3551 }
3552
3553 let mut candidates: Vec<ModelSchema> = self
3554 .unified_registry
3555 .query(&ModelFilter {
3556 capabilities: vec![capability],
3557 local_only: true,
3558 ..Default::default()
3559 })
3560 .into_iter()
3561 .cloned()
3562 .collect();
3563 candidates.sort_by_key(|schema| (!schema.available, schema.size_mb()));
3564 if candidates.is_empty() {
3565 return Err(InferenceError::InferenceFailed(format!(
3566 "no models registered for capability {:?}",
3567 capability
3568 )));
3569 }
3570 Ok(candidates)
3571 }
3572
3573 fn preferred_speech_schema(
3574 &self,
3575 capability: ModelCapability,
3576 local_only: bool,
3577 remote_only: bool,
3578 ) -> Option<ModelSchema> {
3579 let available_only = remote_only;
3580 let mut candidates: Vec<ModelSchema> = self
3581 .unified_registry
3582 .query(&ModelFilter {
3583 capabilities: vec![capability],
3584 available_only,
3585 ..Default::default()
3586 })
3587 .into_iter()
3588 .filter(|schema| (!local_only || schema.is_local()) && (!remote_only || schema.is_remote()))
3589 .cloned()
3590 .collect();
3591 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3592 candidates.into_iter().next()
3593 }
3594
3595 fn speech_health_default_name(
3596 &self,
3597 capability: ModelCapability,
3598 local_only: bool,
3599 remote_only: bool,
3600 ) -> Option<String> {
3601 let preferred = match capability {
3602 ModelCapability::SpeechToText if local_only => {
3603 self.speech_policy.preferred_local_stt.as_ref()
3604 }
3605 ModelCapability::SpeechToText if remote_only => {
3606 self.speech_policy.preferred_remote_stt.as_ref()
3607 }
3608 ModelCapability::TextToSpeech if local_only => {
3609 self.speech_policy.preferred_local_tts.as_ref()
3610 }
3611 ModelCapability::TextToSpeech if remote_only => {
3612 self.speech_policy.preferred_remote_tts.as_ref()
3613 }
3614 _ => None,
3615 };
3616
3617 preferred
3618 .filter(|name| {
3619 self.unified_registry.list().iter().any(|schema| {
3620 schema.name == **name
3621 && schema.has_capability(capability)
3622 && (!local_only || schema.is_local())
3623 && (!remote_only || schema.is_remote())
3624 })
3625 })
3626 .cloned()
3627 .or_else(|| {
3628 self.preferred_speech_schema(capability, local_only, remote_only)
3629 .map(|schema| schema.name)
3630 })
3631 }
3632
3633 fn model_default_health(
3634 &self,
3635 capability: ModelCapability,
3636 configured_model: &str,
3637 ) -> ModelDefaultHealth {
3638 let schema = self
3639 .unified_registry
3640 .find_by_name(configured_model)
3641 .or_else(|| self.unified_registry.get(configured_model));
3642
3643 ModelDefaultHealth {
3644 capability,
3645 configured_model: configured_model.to_string(),
3646 available: schema.is_some_and(|model| model.available),
3647 is_local: schema.is_some_and(ModelSchema::is_local),
3648 provider: schema.map(|model| model.provider.clone()),
3649 }
3650 }
3651
3652 fn speech_sort_key(
3653 &self,
3654 capability: ModelCapability,
3655 model: &ModelSchema,
3656 ) -> (u8, u8, u8, u8, u64, u64) {
3657 let policy_preference = match capability {
3658 ModelCapability::SpeechToText if model.is_local() => self
3659 .speech_policy
3660 .preferred_local_stt
3661 .as_ref(),
3662 ModelCapability::SpeechToText => self.speech_policy.preferred_remote_stt.as_ref(),
3663 ModelCapability::TextToSpeech if model.is_local() => self
3664 .speech_policy
3665 .preferred_local_tts
3666 .as_ref(),
3667 ModelCapability::TextToSpeech => self.speech_policy.preferred_remote_tts.as_ref(),
3668 _ => None,
3669 };
3670 let local_rank = if self.speech_policy.prefer_local {
3671 if model.is_local() {
3672 0
3673 } else {
3674 1
3675 }
3676 } else if model.is_remote() {
3677 0
3678 } else {
3679 1
3680 };
3681 let availability_rank = if model.available {
3682 0
3683 } else if model.is_local() {
3684 1
3685 } else {
3686 2
3687 };
3688 let policy_rank: u8 = if policy_preference.is_some_and(|preferred| preferred == &model.name) {
3689 0
3690 } else {
3691 1
3692 };
3693 let speech_rank = match capability {
3694 ModelCapability::TextToSpeech => {
3695 if model.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
3696 0
3697 } else if model.name == "Kokoro-82M-bf16" {
3698 1
3699 } else if model.name == "Kokoro-82M-6bit" {
3700 2
3701 } else {
3702 3
3703 }
3704 }
3705 ModelCapability::SpeechToText => {
3706 if model.name == "Parakeet-TDT-0.6B-v3-MLX" {
3707 0
3708 } else {
3709 1
3710 }
3711 }
3712 _ => 0,
3713 };
3714 let latency_rank = model.performance.latency_p50_ms.unwrap_or(u64::MAX);
3715 let size_rank = model.cost.size_mb.unwrap_or(u64::MAX);
3716 (
3717 local_rank,
3718 availability_rank,
3719 policy_rank,
3720 speech_rank,
3721 latency_rank,
3722 size_rank,
3723 )
3724 }
3725
3726 async fn run_speech_smoke_path(
3727 &self,
3728 path: &str,
3729 tts: &ModelSchema,
3730 stt: &ModelSchema,
3731 text: &str,
3732 ) -> Result<SpeechSmokePathReport, InferenceError> {
3733 let work_dir = temp_work_dir(&format!("speech-smoke-{path}"))?;
3734 let audio_path = work_dir.join(format!("{path}.wav"));
3735 let synth = self
3736 .synthesize(SynthesizeRequest {
3737 text: text.to_string(),
3738 model: Some(tts.name.clone()),
3739 voice: default_speech_voice(tts),
3740 language: Some("en".to_string()),
3741 output_path: Some(audio_path.display().to_string()),
3742 ..SynthesizeRequest::default()
3743 })
3744 .await?;
3745 let transcript = self
3746 .transcribe(TranscribeRequest {
3747 audio_path: synth.audio_path.clone(),
3748 model: Some(stt.name.clone()),
3749 language: Some("en".to_string()),
3750 prompt: None,
3751 timestamps: false,
3752 })
3753 .await?;
3754
3755 Ok(SpeechSmokePathReport {
3756 path: path.to_string(),
3757 tts_model: synth.model_used.unwrap_or_else(|| tts.name.clone()),
3758 stt_model: transcript.model_used.unwrap_or_else(|| stt.name.clone()),
3759 audio_path: PathBuf::from(synth.audio_path),
3760 transcript: transcript.text,
3761 })
3762 }
3763
3764 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3765 async fn ensure_speech_runtime(&self) -> Result<SpeechRuntime, InferenceError> {
3766 let mut guard = self.speech_runtime.lock().await;
3767 if let Some(runtime) = guard.as_ref() {
3768 if runtime.is_ready() {
3769 return Ok(runtime.clone());
3770 }
3771 }
3772
3773 let runtime = SpeechRuntime::new(speech_runtime_root_from_models_dir(
3774 &self.config.models_dir,
3775 ));
3776 if !runtime.is_ready() {
3777 bootstrap_speech_runtime(&runtime).await?;
3778 }
3779 if !runtime.is_ready() {
3780 return Err(InferenceError::InferenceFailed(format!(
3781 "managed speech runtime is not ready at {}",
3782 runtime.root.display()
3783 )));
3784 }
3785
3786 *guard = Some(runtime.clone());
3787 Ok(runtime)
3788 }
3789
/// Transcribes `req.audio_path` with a local MLX speech-to-text model.
///
/// Exactly one of the two `cfg` blocks below is compiled in:
///
/// * on macOS/aarch64 (unless `car_skip_mlx` is set) the native Parakeet
///   backend is used;
/// * everywhere else the managed Python runtime runs the `stt.generate`
///   command as a subprocess.
///
/// # Errors
///
/// Returns `InferenceError` when the model or runtime cannot be prepared,
/// when the backend/subprocess fails, or when no transcript text can be
/// recovered from the subprocess output.
///
/// # Panics
///
/// The subprocess path hits `unreachable!()` if `schema.source` is not
/// `ModelSource::Mlx`.
// NOTE(review): presumably callers dispatch here only for MLX-sourced schemas — confirm.
async fn transcribe_local_mlx(
    &self,
    schema: &ModelSchema,
    req: &TranscribeRequest,
) -> Result<TranscribeResult, InferenceError> {
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    {
        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
        // NOTE(review): the backend is loaded on every call here; no cache is
        // visible in this function.
        let parakeet = backend::mlx_parakeet::ParakeetBackend::load(&model_dir)?;
        // Word-level detail is only requested when the caller asked for
        // timestamps; otherwise the word list stays empty.
        let (text, words) = if req.timestamps {
            parakeet
                .transcribe_detailed(Path::new(&req.audio_path))
                .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?
        } else {
            let t = parakeet
                .transcribe(Path::new(&req.audio_path))
                .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?;
            (t, Vec::new())
        };
        return Ok(TranscribeResult {
            text,
            model_used: Some(schema.name.clone()),
            language: req.language.clone(),
            words,
        });
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    {
        let runtime = self.ensure_speech_runtime().await?;
        let hf_repo = match &schema.source {
            ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
            _ => unreachable!(),
        };
        let output_dir = temp_work_dir("stt")?;
        let output_prefix = output_dir.join("transcript");
        let mut args = vec![
            "--model".to_string(),
            hf_repo,
            "--audio".to_string(),
            req.audio_path.clone(),
            "--output-path".to_string(),
            output_prefix.display().to_string(),
            "--format".to_string(),
            "json".to_string(),
        ];
        if let Some(language) = &req.language {
            args.push("--language".to_string());
            args.push(normalize_lang_code(language));
        }
        // The request's prompt is forwarded as the command's `--context`.
        if let Some(prompt) = &req.prompt {
            args.push("--context".to_string());
            args.push(prompt.clone());
        }
        // Timestamps only toggle `--verbose` on this path; the returned
        // `words` list is always empty (see below).
        if req.timestamps {
            args.push("--verbose".to_string());
        }

        let output = run_mlx_audio_command(&runtime, "stt.generate", &args).await?;
        // Prefer the result read back via the output prefix; fall back to
        // extracting text from the command's stdout.
        let text = read_transcription_result(&output_prefix)?
            .or_else(|| extract_text_from_payload(&output.stdout))
            .ok_or_else(|| {
                InferenceError::InferenceFailed(format!(
                    "mlx-audio transcription returned no text: {}",
                    output.stderr
                ))
            })?;

        Ok(TranscribeResult {
            text,
            model_used: Some(schema.name.clone()),
            language: req.language.clone(),
            words: Vec::new(),
        })
    }
}
3869
/// Synthesizes speech for `req` with a local MLX text-to-speech model.
///
/// Advanced controls reported by `req.requested_advanced_controls()` are
/// treated as supported only when the schema's MLX `hf_repo` contains
/// `qwen3-tts` (case-insensitive). If they are requested on another model,
/// the call fails when `req.strict_capabilities` is set and otherwise logs a
/// warning and continues without them.
///
/// Exactly one of the two `cfg` blocks below is compiled in:
///
/// * on macOS/aarch64 (unless `car_skip_mlx` is set) the native Kokoro
///   backend is used;
/// * everywhere else the managed Python runtime is used, with an optional
///   fallback from the Kokoro 6-bit repo to the bf16 repo.
///
/// # Errors
///
/// Returns `InferenceError` for unsupported advanced controls in strict mode,
/// and for any model preparation, synthesis, or output materialization
/// failure.
///
/// # Panics
///
/// The subprocess path hits `unreachable!()` if `schema.source` is not
/// `ModelSource::Mlx`.
async fn synthesize_local_mlx(
    &self,
    schema: &ModelSchema,
    req: &SynthesizeRequest,
) -> Result<SynthesizeResult, InferenceError> {
    let requested = req.requested_advanced_controls();
    let repo_supports_advanced = match &schema.source {
        ModelSource::Mlx { hf_repo, .. } => {
            hf_repo.to_ascii_lowercase().contains("qwen3-tts")
        }
        _ => false,
    };
    if !requested.is_empty() && !repo_supports_advanced {
        if req.strict_capabilities {
            return Err(InferenceError::InferenceFailed(format!(
                "model {name} does not support Qwen3-TTS advanced controls {requested:?}; \
                 route to a Qwen3-TTS model or set strict_capabilities = false to degrade",
                name = schema.name,
            )));
        }
        tracing::warn!(
            model = %schema.name,
            fields = ?requested,
            "Qwen3-TTS advanced controls set on non-Qwen3-TTS backend — ignored \
             (set strict_capabilities=true to error instead)"
        );
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    {
        // Even for a Qwen3-TTS repo, the native path cannot honour advanced
        // controls: fail in strict mode, otherwise warn and synthesize
        // without them.
        if repo_supports_advanced && !requested.is_empty() {
            if req.strict_capabilities {
                return Err(InferenceError::InferenceFailed(format!(
                    "native MLX TTS backend does not yet implement Qwen3-TTS advanced \
                     controls {requested:?}; run on non-Apple-Silicon to use the Python \
                     mlx-audio fallback, or set strict_capabilities = false"
                )));
            }
            tracing::warn!(
                model = %schema.name,
                fields = ?requested,
                "Qwen3-TTS advanced controls are not yet implemented in the native MLX TTS \
                 backend; synthesizing without cloning/voice-design"
            );
        }
        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
        let size = backend_cache::estimate_model_size(&model_dir);
        let cache = Arc::clone(&self.kokoro_cache);
        let key = schema.id.clone();
        // The backend is loaded through the cache, keyed by schema id.
        // NOTE(review): the Kokoro backend is loaded for whatever MLX schema
        // arrives here, including Qwen3-TTS repos — confirm this is intended.
        let handle = cache.get_or_load(&key, size, || {
            backend::mlx_kokoro::KokoroBackend::load(&model_dir)
        })?;

        // NOTE(review): without an explicit output path every call writes to
        // the same `<temp>/car_tts/output.wav`; concurrent calls would
        // presumably overwrite each other — confirm. A failure to create the
        // directory is deliberately ignored here.
        let output_path = req.output_path.clone().unwrap_or_else(|| {
            let dir = std::env::temp_dir().join("car_tts");
            let _ = std::fs::create_dir_all(&dir);
            dir.join("output.wav").display().to_string()
        });
        let voice = req.voice.as_deref().unwrap_or("af_heart").to_string();
        let text = req.text.clone();
        // Synthesis is synchronous and runs behind a std mutex, so it is
        // moved onto the blocking thread pool.
        let op = tokio::task::spawn_blocking(move || -> Result<PathBuf, InferenceError> {
            let mut guard = handle.lock().map_err(|_| {
                InferenceError::InferenceFailed("kokoro backend mutex poisoned".into())
            })?;
            guard
                .synthesize(&text, Some(&voice), Path::new(&output_path))
                .map_err(|e| InferenceError::InferenceFailed(format!("native TTS: {e}")))
        })
        .await
        // First `?` handles the join error, second the synthesis error.
        .map_err(|e| InferenceError::InferenceFailed(format!("kokoro task join: {e}")))??;

        let final_path =
            materialize_audio_output(&op, req.output_path.as_deref(), &req.format)?;
        return Ok(SynthesizeResult {
            audio_path: final_path.display().to_string(),
            media_type: media_type_for_format(&req.format),
            model_used: Some(schema.name.clone()),
            // NOTE(review): reports the requested voice, which is `None` when
            // the "af_heart" default above was used.
            voice_used: req.voice.clone(),
        });
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    {
        let runtime = self.ensure_speech_runtime().await?;
        let primary_hf_repo = match &schema.source {
            ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
            _ => unreachable!(),
        };
        let (produced, model_used) = match self
            .synthesize_local_mlx_repo(&runtime, &primary_hf_repo, schema.name.as_str(), req)
            .await
        {
            Ok(result) => result,
            // Only the Kokoro 6-bit repo has a fallback, and only when the
            // runtime fallback is enabled: retry once with the bf16 repo.
            Err(primary_err)
                if primary_hf_repo == "mlx-community/Kokoro-82M-6bit"
                    && kokoro_runtime_fallback_enabled() =>
            {
                let fallback_repo = "mlx-community/Kokoro-82M-bf16";
                let fallback_name = "Kokoro-82M-bf16";
                match self
                    .synthesize_local_mlx_repo(&runtime, fallback_repo, fallback_name, req)
                    .await
                {
                    Ok(result) => result,
                    Err(fallback_err) => {
                        // Report both failures so the primary cause is not lost.
                        return Err(InferenceError::InferenceFailed(format!(
                            "{primary_err}; fallback {fallback_name} also failed: {fallback_err}"
                        )));
                    }
                }
            }
            Err(err) => return Err(err),
        };
        let final_path =
            materialize_audio_output(&produced, req.output_path.as_deref(), &req.format)?;

        Ok(SynthesizeResult {
            audio_path: final_path.display().to_string(),
            media_type: media_type_for_format(&req.format),
            // Name of the model that actually produced the audio (the
            // fallback name when the fallback was used).
            model_used: Some(model_used),
            voice_used: req.voice.clone(),
        })
    }
}
4005
4006 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4007 async fn synthesize_local_mlx_repo(
4008 &self,
4009 runtime: &SpeechRuntime,
4010 hf_repo: &str,
4011 model_name: &str,
4012 req: &SynthesizeRequest,
4013 ) -> Result<(PathBuf, String), InferenceError> {
4014 let output_dir = temp_work_dir("tts")?;
4015 let mut args = vec![
4016 "--model".to_string(),
4017 hf_repo.to_string(),
4018 "--text".to_string(),
4019 req.text.clone(),
4020 "--output_path".to_string(),
4021 output_dir.display().to_string(),
4022 ];
4023 if let Some(voice) = &req.voice {
4024 args.push("--voice".to_string());
4025 args.push(voice.clone());
4026 }
4027 if let Some(speed) = req.speed {
4028 args.push("--speed".to_string());
4029 args.push(speed.to_string());
4030 }
4031 let repo_lower = hf_repo.to_ascii_lowercase();
4032 if repo_lower.contains("kokoro") {
4033 args.push("--lang_code".to_string());
4034 args.push(kokoro_lang_code(req.language.as_deref()).to_string());
4035 } else if let Some(language) = &req.language {
4036 args.push("--lang_code".to_string());
4037 args.push(normalize_lang_code(language));
4038 }
4039
4040 if repo_lower.contains("qwen3-tts") {
4046 if let Some(ref_audio) = &req.reference_audio_path {
4047 args.push("--ref_audio".to_string());
4048 args.push(ref_audio.clone());
4049 }
4050 if let Some(ref_text) = &req.reference_text {
4051 args.push("--ref_text".to_string());
4052 args.push(ref_text.clone());
4053 }
4054 if let Some(instruct) = &req.voice_instruction {
4055 args.push("--instruct".to_string());
4056 args.push(instruct.clone());
4057 }
4058 }
4059
4060 let output = if repo_lower.contains("kokoro") {
4061 let device = std::env::var("CAR_SPEECH_KOKORO_DEVICE")
4062 .or_else(|_| std::env::var("CAR_SPEECH_MLX_DEVICE"))
4063 .unwrap_or_else(|_| "cpu".to_string());
4064 let extra_env = vec![
4065 ("MLX_DEVICE".to_string(), device),
4067 ("PYTORCH_ENABLE_MPS_FALLBACK".to_string(), "1".to_string()),
4069 ];
4070 run_mlx_audio_command_with_env(runtime, "tts.generate", &args, &extra_env).await?
4071 } else {
4072 run_mlx_audio_command(runtime, "tts.generate", &args).await?
4073 };
4074 let produced = find_audio_file(&output_dir)?.ok_or_else(|| {
4075 let hint = if repo_lower.contains("kokoro") {
4076 ". Kokoro models may crash on GPU — try CAR_SPEECH_KOKORO_DEVICE=cpu or use the default Qwen3-TTS model"
4077 } else {
4078 ""
4079 };
4080 InferenceError::InferenceFailed(format!(
4081 "mlx-audio synthesis produced no audio file: {}{}",
4082 output.stderr, hint
4083 ))
4084 })?;
4085 Ok((produced, model_name.to_string()))
4086 }
4087
4088
4089 async fn transcribe_elevenlabs(
4090 &self,
4091 schema: &ModelSchema,
4092 req: &TranscribeRequest,
4093 ) -> Result<TranscribeResult, InferenceError> {
4094 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4095 let file_name = Path::new(&req.audio_path)
4096 .file_name()
4097 .and_then(|f| f.to_str())
4098 .unwrap_or("audio.wav")
4099 .to_string();
4100 let audio_bytes = tokio::fs::read(&req.audio_path).await?;
4101 let file_part = Part::bytes(audio_bytes).file_name(file_name);
4102 let mut form = Form::new()
4103 .text("model_id", schema.name.clone())
4104 .part("file", file_part);
4105 if let Some(language) = &req.language {
4106 form = form.text("language_code", language.clone());
4107 }
4108
4109 let resp = self
4110 .remote_backend
4111 .client
4112 .post(format!(
4113 "{}/v1/speech-to-text",
4114 endpoint.trim_end_matches('/')
4115 ))
4116 .header("xi-api-key", api_key)
4117 .multipart(form)
4118 .send()
4119 .await
4120 .map_err(|e| {
4121 InferenceError::InferenceFailed(format!("ElevenLabs STT request failed: {e}"))
4122 })?;
4123 let status = resp.status();
4124 let body = resp.text().await.map_err(|e| {
4125 InferenceError::InferenceFailed(format!("read ElevenLabs STT body: {e}"))
4126 })?;
4127 if !status.is_success() {
4128 return Err(InferenceError::InferenceFailed(format!(
4129 "ElevenLabs STT returned {status}: {body}"
4130 )));
4131 }
4132 let payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
4133 InferenceError::InferenceFailed(format!("parse ElevenLabs STT response: {e}"))
4134 })?;
4135 let text = payload
4136 .get("text")
4137 .and_then(|v| v.as_str())
4138 .map(str::to_string)
4139 .ok_or_else(|| {
4140 InferenceError::InferenceFailed("ElevenLabs STT response missing text".into())
4141 })?;
4142
4143 Ok(TranscribeResult {
4144 text,
4145 model_used: Some(schema.name.clone()),
4146 language: payload
4147 .get("language_code")
4148 .and_then(|v| v.as_str())
4149 .map(str::to_string),
4150 words: Vec::new(),
4151 })
4152 }
4153
4154 async fn synthesize_elevenlabs(
4155 &self,
4156 schema: &ModelSchema,
4157 req: &SynthesizeRequest,
4158 ) -> Result<SynthesizeResult, InferenceError> {
4159 let requested = req.requested_advanced_controls();
4163 if !requested.is_empty() {
4164 if req.strict_capabilities {
4165 return Err(InferenceError::InferenceFailed(format!(
4166 "ElevenLabs backend does not support Qwen3-TTS advanced controls \
4167 {requested:?}; route to a Qwen3-TTS model or set strict_capabilities = false"
4168 )));
4169 }
4170 tracing::warn!(
4171 model = %schema.name,
4172 fields = ?requested,
4173 "Qwen3-TTS advanced controls ignored by ElevenLabs backend"
4174 );
4175 }
4176 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4177 let voice_id = req
4178 .voice
4179 .clone()
4180 .unwrap_or_else(|| "JBFqnCBsd6RMkjVDRZzb".to_string());
4181 let output_format = elevenlabs_output_format(&req.format);
4182 let url = format!(
4183 "{}/v1/text-to-speech/{}?output_format={}",
4184 endpoint.trim_end_matches('/'),
4185 voice_id,
4186 output_format
4187 );
4188
4189 let mut body = serde_json::json!({
4190 "text": req.text,
4191 "model_id": schema.name,
4192 });
4193 if let Some(language) = &req.language {
4194 body["language_code"] = serde_json::Value::String(language.clone());
4195 }
4196
4197 let resp = self
4198 .remote_backend
4199 .client
4200 .post(url)
4201 .header("xi-api-key", api_key)
4202 .header("Content-Type", "application/json")
4203 .json(&body)
4204 .send()
4205 .await
4206 .map_err(|e| {
4207 InferenceError::InferenceFailed(format!("ElevenLabs TTS request failed: {e}"))
4208 })?;
4209 let status = resp.status();
4210 let audio = resp.bytes().await.map_err(|e| {
4211 InferenceError::InferenceFailed(format!("read ElevenLabs TTS body: {e}"))
4212 })?;
4213 if !status.is_success() {
4214 let err_body = String::from_utf8_lossy(&audio);
4215 return Err(InferenceError::InferenceFailed(format!(
4216 "ElevenLabs TTS returned {status}: {err_body}"
4217 )));
4218 }
4219
4220 let final_path = requested_or_temp_output(req.output_path.as_deref(), &req.format)?;
4221 ensure_parent_dir(&final_path)?;
4222 tokio::fs::write(&final_path, &audio).await?;
4223
4224 Ok(SynthesizeResult {
4225 audio_path: final_path.display().to_string(),
4226 media_type: media_type_for_format(&req.format),
4227 model_used: Some(schema.name.clone()),
4228 voice_used: Some(voice_id),
4229 })
4230 }
4231}
4232
/// Mutable tally of facts about one provider: a configured flag, model
/// counts, and the set of capabilities seen.
///
/// `Default` yields an unconfigured provider with zero counts and an empty
/// capability set.
// NOTE(review): presumably one accumulator is kept per provider while the
// registry is scanned for a status report — confirm against the call site.
#[derive(Default)]
struct ProviderAccumulator {
    // Whether the provider is considered configured (presumably credentials
    // or endpoint present — confirm against the code that sets it).
    configured: bool,
    // Number of local models counted for this provider.
    local_models: usize,
    // Number of remote models counted for this provider.
    remote_models: usize,
    // Number of models counted as available for this provider.
    available_models: usize,
    // Distinct capabilities recorded for this provider.
    capabilities: std::collections::HashSet<ModelCapability>,
}
4241
/// Captured output streams of a finished subprocess, as text.
///
/// Only compiled for targets that use the subprocess-based speech runtime
/// (everything except macOS/aarch64 without `car_skip_mlx`).
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
struct CommandOutput {
    // Standard output of the process.
    stdout: String,
    // Standard error of the process.
    stderr: String,
}
4250
/// Filesystem locations of the managed speech runtime: its root directory
/// and the executables expected inside it.
///
/// Only compiled for targets that use the subprocess-based speech runtime
/// (everything except macOS/aarch64 without `car_skip_mlx`).
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
#[derive(Debug, Clone)]
struct SpeechRuntime {
    // Root directory of the runtime.
    root: PathBuf,
    // Path of the runtime's Python interpreter.
    python: PathBuf,
    // Path of the speech-to-text command.
    stt_program: PathBuf,
    // Path of the text-to-speech command.
    tts_program: PathBuf,
}
4259
4260#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4261impl SpeechRuntime {
4262 fn new(root: PathBuf) -> Self {
4263 let bin_dir = root.join("bin");
4264 Self {
4265 root,
4266 python: bin_dir.join("python"),
4267 stt_program: bin_dir.join("mlx_audio.stt.generate"),
4268 tts_program: bin_dir.join("mlx_audio.tts.generate"),
4269 }
4270 }
4271
4272 fn is_ready(&self) -> bool {
4273 self.python.exists() && self.stt_program.exists() && self.tts_program.exists()
4274 }
4275
4276 fn command_for(&self, subcommand: &str) -> Result<&Path, InferenceError> {
4277 match subcommand {
4278 "stt.generate" => Ok(&self.stt_program),
4279 "tts.generate" => Ok(&self.tts_program),
4280 _ => Err(InferenceError::InferenceFailed(format!(
4281 "unknown speech subcommand: {subcommand}"
4282 ))),
4283 }
4284 }
4285}
4286
4287
/// Runs a speech `subcommand` of the managed runtime with `args` and no
/// extra environment variables.
///
/// Delegates to `run_mlx_audio_command_with_env` with an empty environment
/// list; see that function for the error behavior.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn run_mlx_audio_command(
    runtime: &SpeechRuntime,
    subcommand: &str,
    args: &[String],
) -> Result<CommandOutput, InferenceError> {
    run_mlx_audio_command_with_env(runtime, subcommand, args, &[]).await
}
4296
4297#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4298async fn run_mlx_audio_command_with_env(
4299 runtime: &SpeechRuntime,
4300 subcommand: &str,
4301 args: &[String],
4302 envs: &[(String, String)],
4303) -> Result<CommandOutput, InferenceError> {
4304 let program = runtime.command_for(subcommand)?;
4305 let mut command = Command::new(program);
4306 command.args(args);
4307 for (key, value) in envs {
4308 command.env(key, value);
4309 }
4310 let output = command.output().await.map_err(|err| {
4311 InferenceError::InferenceFailed(format!("{}: {err}", program.display()))
4312 })?;
4313
4314 if output.status.success() {
4315 Ok(CommandOutput {
4316 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
4317 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
4318 })
4319 } else {
4320 Err(InferenceError::InferenceFailed(format!(
4321 "{} exited with {}: {}",
4322 program.display(),
4323 output.status,
4324 String::from_utf8_lossy(&output.stderr)
4325 )))
4326 }
4327}
4328
4329
4330#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4331async fn bootstrap_speech_runtime(runtime: &SpeechRuntime) -> Result<(), InferenceError> {
4332 std::fs::create_dir_all(&runtime.root)?;
4333 let python = select_speech_python()?;
4334
4335 run_command(
4336 "uv",
4337 &[
4338 "venv".to_string(),
4339 "--python".to_string(),
4340 python,
4341 runtime.root.display().to_string(),
4342 ],
4343 )
4344 .await?;
4345
4346 run_command(
4347 "uv",
4348 &[
4349 "pip".to_string(),
4350 "install".to_string(),
4351 "--python".to_string(),
4352 runtime.python.display().to_string(),
4353 speech_runtime_mlx_audio_spec(),
4354 "misaki[en]".to_string(),
4355 speech_runtime_spacy_model_spec(),
4356 ],
4357 )
4358 .await?;
4359
4360 Ok(())
4361}
4362
4363
4364#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4365async fn run_command(program: &str, args: &[String]) -> Result<(), InferenceError> {
4366 let output = Command::new(program)
4367 .args(args)
4368 .output()
4369 .await
4370 .map_err(|err| InferenceError::InferenceFailed(format!("{program}: {err}")))?;
4371
4372 if output.status.success() {
4373 Ok(())
4374 } else {
4375 Err(InferenceError::InferenceFailed(format!(
4376 "{} exited with {}: {}",
4377 program,
4378 output.status,
4379 String::from_utf8_lossy(&output.stderr)
4380 )))
4381 }
4382}
4383
4384#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4385fn select_speech_python() -> Result<String, InferenceError> {
4386 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4387 if !path.trim().is_empty() {
4388 return Ok(path);
4389 }
4390 }
4391
4392 for candidate in ["python3.13", "python3.12", "python3.11"] {
4393 if command_in_path(candidate) {
4394 return Ok(candidate.to_string());
4395 }
4396 }
4397
4398 Err(InferenceError::InferenceFailed(
4399 "no supported Python found for managed speech runtime (tried python3.13, python3.12, python3.11)".into(),
4400 ))
4401}
4402
4403#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4404fn detect_speech_python() -> Option<String> {
4405 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4406 if !path.trim().is_empty() {
4407 return Some(path);
4408 }
4409 }
4410
4411 ["python3.13", "python3.12", "python3.11"]
4412 .into_iter()
4413 .find(|candidate| command_in_path(candidate))
4414 .map(str::to_string)
4415}
4416
/// Returns the root directory of the managed speech runtime.
///
/// A non-blank `CAR_SPEECH_RUNTIME_DIR` environment variable is used
/// verbatim. Otherwise the root is `$HOME/.car/speech-runtime`, with `.`
/// standing in for `HOME` when that variable is unavailable. The
/// `_models_dir` argument is currently ignored.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_root_from_models_dir(_models_dir: &Path) -> PathBuf {
    let configured = std::env::var("CAR_SPEECH_RUNTIME_DIR")
        .ok()
        .filter(|dir| !dir.trim().is_empty());

    match configured {
        Some(dir) => PathBuf::from(dir),
        None => {
            let home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            home.join(".car").join("speech-runtime")
        }
    }
}
4431
/// Reports whether a regular file called `name` exists in any directory
/// listed in the `PATH` environment variable.
///
/// Returns `false` when `PATH` is not set. Only existence as a file is
/// checked, not whether the file is executable.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn command_in_path(name: &str) -> bool {
    let Some(search_path) = std::env::var_os("PATH") else {
        return false;
    };
    std::env::split_paths(&search_path)
        .map(|dir| dir.join(name))
        .any(|candidate| candidate.exists() && candidate.is_file())
}
4443
4444fn speech_model_cached(schema: &ModelSchema) -> bool {
4445 match &schema.source {
4446 ModelSource::Mlx { hf_repo, .. } => {
4447 huggingface_repo_has_snapshot(hf_repo)
4448 }
4449 ModelSource::Proprietary { auth, .. } => match auth {
4450 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4451 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4452 ProprietaryAuth::OAuth2Pkce { .. } => false,
4453 },
4454 _ => false,
4455 }
4456}
4457
4458fn remove_huggingface_repo_cache(repo_id: &str) -> Result<(), InferenceError> {
4459 let repo_dir = std::env::var("HF_HOME")
4460 .map(PathBuf::from)
4461 .unwrap_or_else(|_| {
4462 std::env::var("HOME")
4463 .map(PathBuf::from)
4464 .unwrap_or_else(|_| PathBuf::from("."))
4465 .join(".cache")
4466 .join("huggingface")
4467 })
4468 .join("hub")
4469 .join(format!("models--{}", repo_id.replace('/', "--")));
4470
4471 if repo_dir.exists() {
4472 std::fs::remove_dir_all(repo_dir)?;
4473 }
4474 Ok(())
4475}
4476
4477fn model_source_configured(schema: &ModelSchema) -> bool {
4478 match &schema.source {
4479 ModelSource::RemoteApi {
4480 api_key_env,
4481 api_key_envs,
4482 ..
4483 } => {
4484 std::env::var(api_key_env).is_ok()
4485 || api_key_envs
4486 .iter()
4487 .any(|env_var| std::env::var(env_var).is_ok())
4488 }
4489 ModelSource::Proprietary { auth, .. } => match auth {
4490 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4491 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4492 ProprietaryAuth::OAuth2Pkce { .. } => false,
4493 },
4494 ModelSource::VllmMlx { .. } => std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available,
4495 ModelSource::Ollama { .. } => schema.available,
4496 ModelSource::Mlx { .. } | ModelSource::Local { .. } => {
4497 true
4498 }
4499 ModelSource::AppleFoundationModels { .. } => schema.available,
4500 ModelSource::Delegated { .. } => true,
4505 }
4506}
4507
/// Returns the 13 model capabilities listed below, in a fixed order.
///
/// `sort_capabilities` uses a capability's position in this array as its
/// sort key, so the order of the entries here determines how capability
/// lists are ordered.
// NOTE(review): the array length is hard-coded; presumably this list has to
// be kept in sync with the `ModelCapability` enum by hand — confirm.
fn all_model_capabilities() -> [ModelCapability; 13] {
    [
        ModelCapability::Generate,
        ModelCapability::Embed,
        ModelCapability::Classify,
        ModelCapability::Code,
        ModelCapability::Reasoning,
        ModelCapability::Summarize,
        ModelCapability::ToolUse,
        ModelCapability::MultiToolCall,
        ModelCapability::Vision,
        ModelCapability::SpeechToText,
        ModelCapability::TextToSpeech,
        ModelCapability::ImageGeneration,
        ModelCapability::VideoGeneration,
    ]
}
4525
4526fn sort_capabilities(mut capabilities: Vec<ModelCapability>) -> Vec<ModelCapability> {
4527 capabilities.sort_by_key(|capability| {
4528 all_model_capabilities()
4529 .iter()
4530 .position(|candidate| candidate == capability)
4531 .unwrap_or(usize::MAX)
4532 });
4533 capabilities
4534}
4535
4536fn speech_model_source_label(schema: &ModelSchema) -> String {
4537 match &schema.source {
4538 ModelSource::Mlx { hf_repo, .. } => format!("mlx:{hf_repo}"),
4539 ModelSource::Proprietary {
4540 provider, endpoint, ..
4541 } => format!("proprietary:{provider}:{endpoint}"),
4542 ModelSource::RemoteApi { endpoint, .. } => format!("remote:{endpoint}"),
4543 ModelSource::Local { hf_repo, .. } => format!("local:{hf_repo}"),
4544 ModelSource::VllmMlx {
4545 endpoint,
4546 model_name,
4547 } => format!("vllm-mlx:{endpoint}:{model_name}"),
4548 ModelSource::Ollama { model_tag, host } => format!("ollama:{host}:{model_tag}"),
4549 ModelSource::AppleFoundationModels { use_case } => {
4550 format!("apple-foundation:{}", use_case.as_deref().unwrap_or("default"))
4551 }
4552 ModelSource::Delegated { hint } => {
4553 format!("delegated:{}", hint.as_deref().unwrap_or("(none)"))
4554 }
4555 }
4556}
4557
/// Builds the chat-formatted prompt asking a reranker model whether
/// `document` satisfies `query` under `instruction`.
///
/// The prompt consists of a fixed system message restricting the answer to
/// "yes" or "no", a user turn carrying the three inputs, and the opening of
/// the assistant turn with an empty `<think>` block already closed.
fn rerank_prompt(instruction: &str, query: &str, document: &str) -> String {
    const SYSTEM_PROMPT: &str = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".";
    format!(
        "<|im_start|>system\n{system}<|im_end|>\n\
         <|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n\
         <|im_start|>assistant\n<think>\n\n</think>\n\n",
        system = SYSTEM_PROMPT,
    )
}
4573
4574fn score_from_rerank_output(text: &str, model_name: &str) -> f32 {
4580 let normalized: String = text
4585 .to_ascii_lowercase()
4586 .chars()
4587 .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
4588 .collect();
4589 for tok in normalized.split_ascii_whitespace().take(5) {
4590 match tok {
4591 "yes" => return 1.0,
4592 "no" => return 0.0,
4593 _ => continue,
4594 }
4595 }
4596 tracing::warn!(
4597 model = %model_name,
4598 output = %text,
4599 "rerank: first tokens contain neither `yes` nor `no`; returning neutral 0.5"
4600 );
4601 0.5
4602}
4603
4604fn default_speech_voice(schema: &ModelSchema) -> Option<String> {
4605 if schema.provider == "elevenlabs" {
4606 Some("JBFqnCBsd6RMkjVDRZzb".to_string())
4607 } else if schema.name == "Kokoro-82M-6bit" || schema.name == "Kokoro-82M-bf16" {
4608 Some("af_heart".to_string())
4609 } else if schema.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
4610 Some("Chelsie".to_string())
4611 } else {
4612 None
4613 }
4614}
4615
4616fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
4617 find_latest_huggingface_snapshot(repo_id).is_some()
4618}
4619
/// Computes the cache directory for `repo_id` inside the Hugging Face hub
/// cache.
///
/// The cache root is `$HF_HOME` when that variable is set, otherwise
/// `$HOME/.cache/huggingface` (with `.` standing in for a missing `HOME`).
/// The repository folder lives under `hub/` and is named
/// `models--<repo_id>` with every `/` replaced by `--`.
fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(explicit) => PathBuf::from(explicit),
        Err(_) => {
            let home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            home.join(".cache").join("huggingface")
        }
    };

    let repo_folder = format!("models--{}", repo_id.replace('/', "--"));
    hf_home.join("hub").join(repo_folder)
}
4633
4634fn find_latest_huggingface_snapshot(repo_id: &str) -> Option<PathBuf> {
4635 let snapshots = huggingface_repo_dir(repo_id).join("snapshots");
4636 std::fs::read_dir(snapshots)
4637 .ok()?
4638 .filter_map(Result::ok)
4639 .map(|entry| entry.path())
4640 .find(|path| path.is_dir() && snapshot_looks_ready(path))
4641}
4642
4643fn snapshot_looks_ready(path: &Path) -> bool {
4644 if path.join("config.json").exists() || path.join("model_index.json").exists() {
4645 return true;
4646 }
4647 snapshot_contains_ext(path, "safetensors")
4648}
4649
/// Recursively checks whether any file under `root` has the extension `ext`
/// (compared ASCII case-insensitively).
///
/// Unreadable directories and unreadable entries are treated as containing no
/// match rather than as errors.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return false,
    };

    for entry in entries.filter_map(Result::ok) {
        let candidate = entry.path();
        let matched = if candidate.is_dir() {
            snapshot_contains_ext(&candidate, ext)
        } else {
            match candidate.extension().and_then(|raw| raw.to_str()) {
                Some(found) => found.eq_ignore_ascii_case(ext),
                None => false,
            }
        };
        if matched {
            return true;
        }
    }
    false
}
4666
/// Counts the regular files beneath `root`, descending into subdirectories.
///
/// A directory that cannot be read contributes zero, and entries that are
/// neither files nor directories are ignored.
fn count_files_recursive(root: &Path) -> usize {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return 0,
    };

    let mut total = 0usize;
    for entry in entries.filter_map(Result::ok) {
        let path = entry.path();
        if path.is_dir() {
            total += count_files_recursive(&path);
        } else if path.is_file() {
            total += 1;
        }
    }
    total
}
4685
4686async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
4687 let api = hf_hub::api::tokio::ApiBuilder::from_env()
4688 .with_progress(false)
4689 .build()
4690 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
4691 let repo = api.model(repo_id.to_string());
4692 let info = repo
4693 .info()
4694 .await
4695 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
4696
4697 let snapshot_path = huggingface_repo_dir(repo_id)
4698 .join("snapshots")
4699 .join(&info.sha);
4700 let mut downloaded = 0usize;
4701 for sibling in &info.siblings {
4702 let local_path = snapshot_path.join(&sibling.rfilename);
4703 if local_path.exists() {
4704 downloaded += 1;
4705 continue;
4706 }
4707 repo.download(&sibling.rfilename).await.map_err(|e| {
4708 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
4709 })?;
4710 downloaded += 1;
4711 }
4712
4713 Ok((snapshot_path, downloaded))
4714}
4715
4716fn temp_work_dir(prefix: &str) -> Result<PathBuf, InferenceError> {
4717 let unique = SystemTime::now()
4718 .duration_since(UNIX_EPOCH)
4719 .map_err(|e| InferenceError::InferenceFailed(format!("clock error: {e}")))?
4720 .as_nanos();
4721 let dir = std::env::temp_dir().join(format!("car-inference-{prefix}-{unique}"));
4722 std::fs::create_dir_all(&dir)?;
4723 Ok(dir)
4724}
4725
4726fn ensure_parent_dir(path: &Path) -> Result<(), InferenceError> {
4727 if let Some(parent) = path.parent() {
4728 std::fs::create_dir_all(parent)?;
4729 }
4730 Ok(())
4731}
4732
4733fn requested_or_temp_output(
4734 output_path: Option<&str>,
4735 format: &str,
4736) -> Result<PathBuf, InferenceError> {
4737 if let Some(path) = output_path {
4738 return Ok(PathBuf::from(path));
4739 }
4740 let dir = temp_work_dir("audio-out")?;
4741 Ok(dir.join(format!("speech.{format}")))
4742}
4743
4744fn requested_or_temp_media_output(
4745 output_path: Option<&str>,
4746 format: &str,
4747 stem: &str,
4748) -> Result<PathBuf, InferenceError> {
4749 if let Some(path) = output_path {
4750 return Ok(PathBuf::from(path));
4751 }
4752 let dir = temp_work_dir(&format!("{stem}-out"))?;
4753 Ok(dir.join(format!("{stem}.{format}")))
4754}
4755
4756fn materialize_audio_output(
4757 produced: &Path,
4758 requested: Option<&str>,
4759 format: &str,
4760) -> Result<PathBuf, InferenceError> {
4761 if let Some(path) = requested {
4762 let dest = PathBuf::from(path);
4763 ensure_parent_dir(&dest)?;
4764 std::fs::copy(produced, &dest)?;
4765 Ok(dest)
4766 } else {
4767 let dest = requested_or_temp_output(None, format)?;
4768 ensure_parent_dir(&dest)?;
4769 std::fs::copy(produced, &dest)?;
4770 Ok(dest)
4771 }
4772}
4773
4774fn materialize_binary_output(
4775 produced: &Path,
4776 requested: Option<&str>,
4777 format: &str,
4778 stem: &str,
4779) -> Result<PathBuf, InferenceError> {
4780 let dest = requested_or_temp_media_output(requested, format, stem)?;
4781 ensure_parent_dir(&dest)?;
4782 std::fs::copy(produced, &dest)?;
4783 Ok(dest)
4784}
4785
4786fn find_generated_file(
4787 root: &Path,
4788 extensions: &[&str],
4789) -> Result<Option<PathBuf>, InferenceError> {
4790 let entries = std::fs::read_dir(root)?;
4791 let mut candidates: Vec<PathBuf> = entries
4792 .filter_map(Result::ok)
4793 .map(|entry| entry.path())
4794 .filter(|path| {
4795 path.is_file()
4796 && path
4797 .extension()
4798 .and_then(|ext| ext.to_str())
4799 .map(|ext| extensions.iter().any(|candidate| candidate.eq_ignore_ascii_case(ext)))
4800 .unwrap_or(false)
4801 })
4802 .collect();
4803 candidates.sort();
4804 Ok(candidates.pop())
4805}
4806
/// Maps an image format name (ASCII case-insensitive) to its media type.
///
/// `jpg`/`jpeg` and `webp` are recognized; everything else is reported as
/// `image/png`.
fn media_type_for_image_format(format: &str) -> String {
    let is = |name: &str| format.eq_ignore_ascii_case(name);

    let media_type = if is("jpg") || is("jpeg") {
        "image/jpeg"
    } else if is("webp") {
        "image/webp"
    } else {
        "image/png"
    };
    media_type.to_string()
}
4814
/// Maps a video format name (ASCII case-insensitive) to its media type.
///
/// `mov` maps to QuickTime and `gif` to `image/gif`; everything else is
/// reported as `video/mp4`.
fn media_type_for_video_format(format: &str) -> String {
    let media_type = if format.eq_ignore_ascii_case("mov") {
        "video/quicktime"
    } else if format.eq_ignore_ascii_case("gif") {
        "image/gif"
    } else {
        "video/mp4"
    };
    media_type.to_string()
}
4822
4823fn read_transcription_result(output_prefix: &Path) -> Result<Option<String>, InferenceError> {
4824 let candidates = [
4825 output_prefix.with_extension("json"),
4826 output_prefix.to_path_buf(),
4827 ];
4828
4829 for path in candidates {
4830 if path.exists() {
4831 let contents = std::fs::read_to_string(path)?;
4832 if let Some(text) = extract_text_from_payload(&contents) {
4833 return Ok(Some(text));
4834 }
4835 }
4836 }
4837
4838 Ok(None)
4839}
4840
4841fn extract_text_from_payload(payload: &str) -> Option<String> {
4842 let value: serde_json::Value = serde_json::from_str(payload).ok()?;
4843 if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
4844 return Some(text.to_string());
4845 }
4846 if let Some(transcripts) = value.get("transcripts").and_then(|v| v.as_array()) {
4847 let joined = transcripts
4848 .iter()
4849 .filter_map(|item| item.get("text").and_then(|v| v.as_str()))
4850 .collect::<Vec<_>>()
4851 .join("\n");
4852 if !joined.is_empty() {
4853 return Some(joined);
4854 }
4855 }
4856 if let Some(items) = value.as_array() {
4857 let joined = items
4858 .iter()
4859 .filter_map(|item| {
4860 item.get("text")
4861 .or_else(|| item.get("Content"))
4862 .and_then(|v| v.as_str())
4863 })
4864 .collect::<Vec<_>>()
4865 .join(" ");
4866 if !joined.is_empty() {
4867 return Some(joined);
4868 }
4869 }
4870 None
4871}
4872
4873fn find_audio_file(output_dir: &Path) -> Result<Option<PathBuf>, InferenceError> {
4874 let mut audio_files = Vec::new();
4875 collect_audio_files(output_dir, &mut audio_files)?;
4876 audio_files.sort();
4877 Ok(audio_files.into_iter().next())
4878}
4879
4880fn collect_audio_files(dir: &Path, audio_files: &mut Vec<PathBuf>) -> Result<(), InferenceError> {
4881 for entry in std::fs::read_dir(dir)? {
4882 let path = entry?.path();
4883 if path.is_dir() {
4884 collect_audio_files(&path, audio_files)?;
4885 } else if matches!(
4886 path.extension().and_then(|ext| ext.to_str()),
4887 Some("wav" | "mp3" | "flac" | "pcm" | "m4a")
4888 ) {
4889 audio_files.push(path);
4890 }
4891 }
4892 Ok(())
4893}
4894
/// Maps an audio format name (ASCII case-insensitive) to its media type.
///
/// Unrecognized formats are reported as `audio/wav`.
fn media_type_for_format(format: &str) -> String {
    const KNOWN: [(&str, &str); 4] = [
        ("mp3", "audio/mpeg"),
        ("flac", "audio/flac"),
        ("pcm", "audio/L16"),
        ("m4a", "audio/mp4"),
    ];

    KNOWN
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case(format))
        .map(|&(_, media_type)| media_type)
        .unwrap_or("audio/wav")
        .to_string()
}
4904
/// Translates a language name or tag into the single-letter Kokoro language
/// code.
///
/// Matching is ASCII case-insensitive. A missing language is treated as `en`,
/// and anything not listed in the alias table maps to `a`.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_lang_code(language: Option<&str>) -> &'static str {
    const ALIASES: [(&str, &str); 13] = [
        ("en-gb", "b"),
        ("british", "b"),
        ("british english", "b"),
        ("ja", "j"),
        ("japanese", "j"),
        ("zh", "z"),
        ("zh-cn", "z"),
        ("mandarin", "z"),
        ("chinese", "z"),
        ("es", "e"),
        ("spanish", "e"),
        ("fr", "f"),
        ("french", "f"),
    ];

    let requested = language.unwrap_or("en");
    ALIASES
        .iter()
        .find(|(alias, _)| alias.eq_ignore_ascii_case(requested))
        .map(|&(_, code)| code)
        .unwrap_or("a")
}
4916
/// Normalizes a language name or tag to one of the two-letter codes `en`,
/// `es`, `fr`, `ja` or `zh`.
///
/// Matching is ASCII case-insensitive. Input that is neither a known alias
/// nor one of the five codes falls back to `en`.
fn normalize_lang_code(language: &str) -> String {
    let lowered = language.to_ascii_lowercase();
    let code = match lowered.as_str() {
        "en" | "english" | "en-us" | "en_us" => "en",
        "es" | "spanish" => "es",
        "fr" | "french" => "fr",
        "ja" | "japanese" => "ja",
        "zh" | "chinese" | "mandarin" => "zh",
        _ => "en",
    };
    code.to_string()
}
4930
4931fn elevenlabs_auth(schema: &ModelSchema) -> Result<(String, String), InferenceError> {
4932 match &schema.source {
4933 ModelSource::Proprietary {
4934 endpoint,
4935 auth: schema::ProprietaryAuth::ApiKeyEnv { env_var },
4936 ..
4937 } => {
4938 let key = car_secrets::resolve_env_or_keychain(env_var).ok_or_else(|| {
4939 InferenceError::InferenceFailed(format!(
4940 "missing API key {env_var}; set the environment variable or \
4941 store it with `car secrets put {env_var}`"
4942 ))
4943 })?;
4944 Ok((endpoint.clone(), key))
4945 }
4946 _ => Err(InferenceError::InferenceFailed(format!(
4947 "model {} is not an ElevenLabs proprietary model",
4948 schema.id
4949 ))),
4950 }
4951}
4952
/// Maps an audio format name (ASCII case-insensitive) to the output-format
/// identifier sent to ElevenLabs.
///
/// `mp3` and `pcm` have dedicated identifiers; every other format maps to
/// `wav_44100`.
fn elevenlabs_output_format(format: &str) -> &'static str {
    if format.eq_ignore_ascii_case("mp3") {
        "mp3_44100_128"
    } else if format.eq_ignore_ascii_case("pcm") {
        "pcm_16000"
    } else {
        "wav_44100"
    }
}
4960
/// Lists the candidate locations of `benchmark_priors.json`, without
/// duplicates and in lookup order:
///
/// 1. inside `models_dir`;
/// 2. inside the parent of `models_dir`, when it has one;
/// 3. the path named by `CAR_BENCHMARK_PRIORS_PATH`, when that variable is set.
///
/// No check is made that any of the paths exists.
fn benchmark_priors_paths(models_dir: &Path) -> Vec<PathBuf> {
    const FILE_NAME: &str = "benchmark_priors.json";

    let candidates = std::iter::once(models_dir.join(FILE_NAME))
        .chain(models_dir.parent().map(|parent| parent.join(FILE_NAME)))
        .chain(std::env::var_os("CAR_BENCHMARK_PRIORS_PATH").map(PathBuf::from));

    let mut paths: Vec<PathBuf> = Vec::new();
    for candidate in candidates {
        if !paths.contains(&candidate) {
            paths.push(candidate);
        }
    }
    paths
}
4985
4986fn load_benchmark_prior_health(
4987 models_dir: &Path,
4988 schemas: &[ModelSchema],
4989) -> Vec<ModelBenchmarkPriorHealth> {
4990 let mut priors = std::collections::BTreeMap::new();
4991 for path in benchmark_priors_paths(models_dir) {
4992 let Ok(loaded) = routing_ext::load_benchmark_priors(&path) else {
4993 continue;
4994 };
4995 for (model_id, prior) in loaded {
4996 let model_name = schemas
4997 .iter()
4998 .find(|schema| schema.id == model_id)
4999 .map(|schema| schema.name.clone());
5000 priors.insert(
5001 model_id.clone(),
5002 ModelBenchmarkPriorHealth {
5003 model_id,
5004 model_name,
5005 overall_score: prior.overall_score,
5006 overall_latency_ms: prior.overall_latency_ms,
5007 task_scores: prior.task_scores,
5008 task_latency_ms: prior.task_latency_ms,
5009 source_path: path.clone(),
5010 },
5011 );
5012 }
5013 }
5014
5015 priors.into_values().collect()
5016}
5017
/// Reports whether the Kokoro runtime fallback is enabled.
///
/// Controlled by `CAR_SPEECH_KOKORO_FALLBACK`: the fallback is on unless the
/// variable is set to `0`, `false` or `off` (surrounding whitespace and ASCII
/// case are ignored). An unset or unreadable variable means enabled.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_runtime_fallback_enabled() -> bool {
    let Ok(raw) = std::env::var("CAR_SPEECH_KOKORO_FALLBACK") else {
        return true;
    };
    let setting = raw.trim().to_ascii_lowercase();
    !["0", "false", "off"].contains(&setting.as_str())
}
5025
/// Returns the package requirement string for `mlx-audio`.
///
/// `CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC` overrides the default when it is set to
/// something other than whitespace; the override is returned untrimmed.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_mlx_audio_spec() -> String {
    match std::env::var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => "mlx-audio==0.4.2".to_string(),
    }
}
5033
/// Returns the package requirement string for the spaCy English model.
///
/// `CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC` overrides the default when it is set
/// to something other than whitespace; the override is returned untrimmed.
/// The default pins the `en_core_web_sm` 3.8.0 wheel by direct URL.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_spacy_model_spec() -> String {
    match std::env::var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl".to_string(),
    }
}
5043
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Serializes tests that mutate process environment variables. The test
    /// harness runs tests on multiple threads and the environment is
    /// process-global, so every test that writes to it must hold this lock.
    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_MUTEX`], recovering from poisoning.
    ///
    /// The mutex guards no data, so a panic in one environment test must not
    /// cascade into spurious failures of every other environment test.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_MUTEX
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Scoped handle on one environment variable that removes the variable
    /// when dropped, so a failing assertion cannot leak state into later
    /// tests. Create it *after* calling [`lock_env`] so that it is dropped
    /// while the lock is still held.
    struct ScopedEnvVar {
        key: &'static str,
    }

    impl ScopedEnvVar {
        /// Removes `key` from the environment and returns a handle to it.
        fn cleared(key: &'static str) -> Self {
            // SAFETY: callers hold `ENV_MUTEX`, so no other test in this
            // module writes to the environment concurrently.
            unsafe {
                std::env::remove_var(key);
            }
            Self { key }
        }

        /// Sets `key` to `value` and returns a handle to it.
        fn with_value(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
            let var = Self { key };
            var.set(value);
            var
        }

        /// Overwrites the variable with `value`.
        fn set(&self, value: impl AsRef<std::ffi::OsStr>) {
            // SAFETY: callers hold `ENV_MUTEX`, so no other test in this
            // module writes to the environment concurrently.
            unsafe {
                std::env::set_var(self.key, value);
            }
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            // SAFETY: the handle is dropped before the `ENV_MUTEX` guard that
            // was acquired ahead of it, so the lock is still held here.
            unsafe {
                std::env::remove_var(self.key);
            }
        }
    }

    /// Minimal engine configuration rooted at `models_dir`.
    fn test_config(models_dir: PathBuf) -> InferenceConfig {
        InferenceConfig {
            models_dir,
            device: None,
            generation_model: "Qwen3-0.6B".into(),
            preferred_generation_model: None,
            embedding_model: "Qwen3-Embedding-0.6B".into(),
            preferred_embedding_model: None,
            classification_model: "Qwen3-0.6B".into(),
            preferred_classification_model: None,
        }
    }

    /// Tokenize/detokenize against a non-local model must fail with
    /// `UnsupportedMode` rather than some other error.
    #[tokio::test]
    async fn tokenize_rejects_known_remote_model_with_unsupported_mode() {
        let tmp = TempDir::new().unwrap();
        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let remote_id = engine
            .list_schemas()
            .into_iter()
            .find(|s| !s.is_local())
            .map(|s| s.id)
            .expect("built-in catalog should include at least one remote model schema");

        let err = engine
            .tokenize(&remote_id, "hello")
            .await
            .expect_err("remote tokenize must error");
        match err {
            InferenceError::UnsupportedMode { mode, backend, .. } => {
                assert_eq!(mode, "tokenize/detokenize");
                assert_eq!(backend, "remote");
            }
            other => panic!("expected UnsupportedMode, got {other:?}"),
        }

        let err = engine
            .detokenize(&remote_id, &[1, 2, 3])
            .await
            .expect_err("remote detokenize must error");
        assert!(
            matches!(err, InferenceError::UnsupportedMode { .. }),
            "expected UnsupportedMode, got {err:?}"
        );
    }

    /// A priors file named by `CAR_BENCHMARK_PRIORS_PATH` seeds a profile for
    /// a model that has no observed outcomes.
    #[test]
    fn engine_loads_benchmark_priors_on_startup() {
        let _env = lock_env();
        let tmp = TempDir::new().unwrap();
        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.88
            })
            .to_string(),
        )
        .unwrap();

        let _priors_var = ScopedEnvVar::with_value("CAR_BENCHMARK_PRIORS_PATH", &priors_path);

        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("benchmark prior should create a profile");
        assert!((profile.ema_quality - 0.88).abs() < 0.01);
    }

    /// A persisted observed profile keeps its values even when a benchmark
    /// prior exists for the same model.
    #[test]
    fn benchmark_priors_do_not_override_observed_profiles() {
        let _env = lock_env();
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        std::fs::create_dir_all(&models_dir).unwrap();

        let observed = vec![ModelProfile {
            model_id: "qwen/qwen3-8b:q4_k_m".into(),
            total_calls: 12,
            success_count: 3,
            fail_count: 9,
            total_latency_ms: 1200,
            total_input_tokens: 0,
            total_output_tokens: 0,
            task_stats: std::collections::HashMap::new(),
            ema_quality: 0.21,
            quality_per_1k_tokens: 0.0,
            updated_at: 1,
        }];
        std::fs::write(
            models_dir.join("outcome_profiles.json"),
            serde_json::to_string(&observed).unwrap(),
        )
        .unwrap();

        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.95
            })
            .to_string(),
        )
        .unwrap();

        let _priors_var = ScopedEnvVar::with_value("CAR_BENCHMARK_PRIORS_PATH", &priors_path);

        let engine = InferenceEngine::new(test_config(models_dir));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("observed profile should remain present");
        assert!((profile.ema_quality - 0.21).abs() < 0.01);
        assert_eq!(profile.total_calls, 12);
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn speech_runtime_package_spec_defaults_and_overrides() {
        let _env = lock_env();
        let spec_var = ScopedEnvVar::cleared("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.2");

        spec_var.set("mlx-audio==0.4.1");
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.1");
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn speech_runtime_spacy_model_spec_defaults_and_overrides() {
        let _env = lock_env();
        let spec_var = ScopedEnvVar::cleared("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
        assert!(
            speech_runtime_spacy_model_spec().starts_with("en-core-web-sm @ https://github.com/")
        );

        spec_var.set("en-core-web-sm==3.8.0");
        assert_eq!(speech_runtime_spacy_model_spec(), "en-core-web-sm==3.8.0");
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn kokoro_runtime_fallback_defaults_on() {
        // This test writes to the process environment, so it must be
        // serialized with the other environment tests.
        let _env = lock_env();
        let fallback_var = ScopedEnvVar::cleared("CAR_SPEECH_KOKORO_FALLBACK");
        assert!(kokoro_runtime_fallback_enabled());

        fallback_var.set("false");
        assert!(!kokoro_runtime_fallback_enabled());
    }

    /// An explicitly preferred local TTS model is selected by the speech
    /// policy.
    #[test]
    fn preferred_local_tts_wins_over_builtin_rank() {
        let tmp = TempDir::new().unwrap();
        let mut engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        engine.set_speech_policy(SpeechPolicy {
            prefer_local: true,
            allow_remote_fallback: false,
            preferred_local_stt: None,
            preferred_local_tts: Some("Kokoro-82M-6bit".into()),
            preferred_remote_stt: None,
            preferred_remote_tts: None,
        });

        let schema = engine
            .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
            .expect("preferred local TTS should resolve");
        assert_eq!(schema.name, "Kokoro-82M-6bit");
    }

    /// A registered vLLM-MLX model named as the preferred generation model is
    /// chosen by adaptive routing with the explicit strategy.
    #[test]
    fn preferred_discovered_vllm_mlx_model_wins_generate_routing() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(tmp.path().join("models"));
        config.preferred_generation_model =
            Some("vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit".into());
        let mut engine = InferenceEngine::new(config);
        let schema = crate::vllm_mlx::to_model_schema(
            &crate::vllm_mlx::DiscoveredModel {
                id: "mlx-community/gemma-3n-E2B-it-lm-4bit".into(),
                owned_by: Some("mlx-community".into()),
            },
            "http://127.0.0.1:8001",
        );
        engine.register_model(schema);

        let rt = tokio::runtime::Runtime::new().unwrap();
        let decision = rt.block_on(engine.route_adaptive("say hello in one sentence"));
        assert_eq!(
            decision.model_id,
            "vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit"
        );
        assert_eq!(decision.strategy, RoutingStrategy::Explicit);
        assert_eq!(decision.reason, "preferred generation model override");
    }

    /// Pins the serialized field values of a fully populated
    /// `InferenceResult` (tool call, usage and time-to-first-token present).
    #[test]
    fn inference_result_serializes_with_full_shape() {
        use crate::tasks::generate::ToolCall;
        use std::collections::HashMap;

        let mut args = HashMap::new();
        args.insert("path".to_string(), serde_json::json!("README.md"));

        let result = InferenceResult {
            text: String::new(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![ToolCall {
                id: None,
                name: "read_file".into(),
                arguments: args,
            }],
            trace_id: "trace-abc".into(),
            model_used: "test-model".into(),
            latency_ms: 1234,
            time_to_first_token_ms: Some(180),
            usage: Some(TokenUsage {
                prompt_tokens: 100,
                completion_tokens: 50,
                total_tokens: 150,
                context_window: 8192,
            }),
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");

        assert_eq!(json["text"].as_str(), Some(""));
        assert_eq!(json["trace_id"].as_str(), Some("trace-abc"));
        assert_eq!(json["model_used"].as_str(), Some("test-model"));
        assert_eq!(json["latency_ms"].as_u64(), Some(1234));

        let tool_calls = json["tool_calls"].as_array().expect("tool_calls array");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0]["name"].as_str(), Some("read_file"));
        assert_eq!(
            tool_calls[0]["arguments"]["path"].as_str(),
            Some("README.md")
        );

        let usage = &json["usage"];
        assert_eq!(usage["prompt_tokens"].as_u64(), Some(100));
        assert_eq!(usage["completion_tokens"].as_u64(), Some(50));
        assert_eq!(usage["total_tokens"].as_u64(), Some(150));
        assert_eq!(usage["context_window"].as_u64(), Some(8192));

        assert_eq!(json["time_to_first_token_ms"].as_u64(), Some(180));
    }

    /// Locks the exact set of top-level keys in the serialized result.
    ///
    /// NOTE(review): `bounding_boxes` and `provider_output_items` are empty
    /// here and are expected to be absent from the output; presumably they
    /// are skipped when empty by the struct's serde attributes -- confirm
    /// against the `InferenceResult` definition.
    #[test]
    fn inference_result_top_level_keys_are_locked() {
        use std::collections::BTreeSet;

        let result = InferenceResult {
            text: "anything".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "t".into(),
            model_used: "m".into(),
            latency_ms: 0,
            time_to_first_token_ms: None,
            usage: None,
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");
        let keys: BTreeSet<&str> = json
            .as_object()
            .expect("top-level object")
            .keys()
            .map(String::as_str)
            .collect();

        let expected: BTreeSet<&str> = [
            "text",
            "tool_calls",
            "trace_id",
            "model_used",
            "latency_ms",
            "time_to_first_token_ms",
            "usage",
        ]
        .into_iter()
        .collect();

        assert_eq!(
            keys, expected,
            "infer response top-level keys drifted -- update both the test \
             and the WebSocket protocol documentation if this is intentional"
        );

        for key in &keys {
            assert!(
                !key.chars().any(|c| c.is_uppercase()) && !key.contains('-'),
                "key '{}' is not snake_case",
                key
            );
        }
    }

    /// A plain-text result serializes with an empty `tool_calls` array and
    /// explicit nulls for the optional fields.
    #[test]
    fn inference_result_serializes_plain_text_response() {
        let result = InferenceResult {
            text: "hello world".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "trace-xyz".into(),
            model_used: "test-model".into(),
            latency_ms: 42,
            time_to_first_token_ms: None,
            usage: None,
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");
        assert_eq!(json["text"], "hello world");
        assert!(json["tool_calls"].is_array());
        assert_eq!(json["tool_calls"].as_array().unwrap().len(), 0);
        assert_eq!(json["model_used"], "test-model");
        assert!(json["usage"].is_null());
        assert!(json["time_to_first_token_ms"].is_null());
    }

    /// The `intent` field of `GenerateRequest` round-trips through JSON, is
    /// defaulted when given as `{}`, and is `None` when omitted.
    #[test]
    fn generate_request_deserializes_intent_field_from_json_rpc_params() {
        use crate::intent::{IntentHint, TaskHint};
        use crate::schema::ModelCapability;

        let params = serde_json::json!({
            "prompt": "summarize this email",
            "intent": {
                "task": "chat",
                "prefer_local": true,
                "require": ["tool_use"],
            },
        });

        let req: GenerateRequest =
            serde_json::from_value(params).expect("GenerateRequest deserialize");

        let intent = req.intent.as_ref().expect("intent field deserialized");
        assert_eq!(intent.task, Some(TaskHint::Chat));
        assert!(intent.prefer_local);
        assert_eq!(intent.require, vec![ModelCapability::ToolUse]);

        let back: serde_json::Value =
            serde_json::to_value(&req).expect("re-serialize GenerateRequest");
        assert_eq!(back["intent"]["task"], "chat");
        assert_eq!(back["intent"]["prefer_local"], true);
        assert_eq!(back["intent"]["require"][0], "tool_use");

        let default_req: GenerateRequest = serde_json::from_value(serde_json::json!({
            "prompt": "x",
            "intent": {},
        }))
        .unwrap();
        let default_intent = default_req.intent.expect("present but empty");
        assert_eq!(default_intent.task, None);
        assert!(!default_intent.prefer_local);
        assert!(default_intent.require.is_empty());

        let no_intent: GenerateRequest =
            serde_json::from_value(serde_json::json!({"prompt": "x"})).unwrap();
        assert!(no_intent.intent.is_none());
    }

    #[test]
    fn rerank_prompt_matches_upstream_template_shape() {
        let p = rerank_prompt(
            "retrieve relevant passages",
            "who runs the treasury?",
            "doc x",
        );
        assert!(p.contains("<|im_start|>system"));
        assert!(p.contains("Note that the answer can only be \"yes\" or \"no\"."));
        assert!(p.contains("<|im_start|>user\n<Instruct>: retrieve relevant passages"));
        assert!(p.contains("<Query>: who runs the treasury?"));
        assert!(p.contains("<Document>: doc x<|im_end|>"));
        assert!(p.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    #[test]
    fn rerank_score_yes_and_no_exactly() {
        assert_eq!(score_from_rerank_output("yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("no", "m"), 0.0);
    }

    #[test]
    fn rerank_score_handles_case_leading_space_and_chat_sentinels() {
        assert_eq!(score_from_rerank_output(" Yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("\nno.", "m"), 0.0);
        assert_eq!(score_from_rerank_output("<|im_end|>yes", "m"), 1.0);
    }

    /// The verdict may be preceded by a few junk tokens. Despite this test's
    /// historical name, `score_from_rerank_output` scans the first *five*
    /// normalized tokens; both edges of that window are pinned here.
    #[test]
    fn rerank_score_scans_up_to_three_tokens() {
        assert_eq!(score_from_rerank_output("_bos_ yes", "m"), 1.0);
        // A verdict in the fifth token position is still found ...
        assert_eq!(score_from_rerank_output("a b c d yes", "m"), 1.0);
        // ... but one in the sixth position is outside the window.
        assert_eq!(score_from_rerank_output("a b c d e yes", "m"), 0.5);
    }

    #[test]
    fn rerank_score_unexpected_is_neutral() {
        assert_eq!(score_from_rerank_output("maybe", "m"), 0.5);
        assert_eq!(score_from_rerank_output("", "m"), 0.5);
    }
}