1pub mod adaptive_router;
26pub mod backend;
27pub mod backend_cache;
28pub mod handle;
29pub mod hardware;
30pub mod intent;
31pub mod key_pool;
32pub mod models;
33pub mod outcome;
34pub mod protocol;
35pub mod registry;
36pub mod remote;
37pub mod router;
38pub mod routing_ext;
39pub mod runner;
40pub mod schema;
41pub mod service;
42pub mod stream;
43pub mod tasks;
44pub mod vllm_mlx;
45
46use std::path::{Path, PathBuf};
47use std::sync::Arc;
48use std::time::Instant;
49use std::time::{SystemTime, UNIX_EPOCH};
50
51use reqwest::multipart::{Form, Part};
52use serde::Serialize;
53use thiserror::Error;
54#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
55use tokio::process::Command;
56#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
57use tokio::sync::Mutex;
58use tokio::sync::RwLock;
59use tracing::{debug, instrument};
60
61pub use adaptive_router::{
63 AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy,
64};
65pub use handle::InferenceHandle;
66pub use intent::{IntentHint, TaskHint};
67pub use key_pool::{KeyPool, KeyStats};
68pub use outcome::{
69 CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
70};
71pub use registry::{
72 ModelFilter, ModelInfo, ModelRuntimeRequirement, ModelUpgrade, UnifiedRegistry,
73};
74pub use remote::RemoteBackend;
75pub use routing_ext::{
76 CircuitBreaker, CircuitBreakerRegistry, CircuitState, ImplicitSignal, ImplicitSignalType,
77 RoutingMode, SpendControl, SpendLimitExceeded, SpendLimits, SpendStatus,
78};
79pub use runner::{
80 current_inference_runner, set_inference_runner, EventEmitter, InferenceRunner, RunnerError,
81 RunnerResult,
82};
83pub use schema::{
84 ApiProtocol, BenchmarkScore, CostModel, ModelCapability, ModelSchema, ModelSource,
85 PerformanceEnvelope, ProprietaryAuth,
86};
87
88pub use adaptive_router::TaskComplexity;
90#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
91pub use backend::CandleBackend;
92#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
93pub use backend::EmbeddingBackend;
94#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
95pub use backend::MlxBackend;
96pub use hardware::HardwareInfo;
97pub use models::{ModelRegistry, ModelRole};
98pub use router::{ModelRouter, RoutingDecision};
99pub use stream::{StreamAccumulator, StreamEvent};
100pub use tasks::{
101 parse_boxes, BoundingBox, ClassifyRequest, ClassifyResult, ContentBlock, EmbedRequest,
102 GenerateImageRequest, GenerateImageResult, GenerateParams, GenerateRequest,
103 GenerateVideoRequest, GenerateVideoResult, GroundRequest, GroundResult, Message, RerankRequest,
104 RerankResult, RerankedDocument, RoutingWorkload, SynthesizeRequest, SynthesizeResult,
105 ThinkingMode, ToolCall, TranscribeRequest, TranscribeResult, VideoMode,
106};
107#[derive(Error, Debug)]
110pub enum InferenceError {
111 #[error("model not found: {0}")]
112 ModelNotFound(String),
113
114 #[error("model download failed: {0}")]
115 DownloadFailed(String),
116
117 #[error("inference failed: {0}")]
118 InferenceFailed(String),
119
120 #[error("mode {mode} not implemented on backend {backend}: {reason}")]
125 UnsupportedMode {
126 mode: &'static str,
127 backend: &'static str,
128 reason: &'static str,
129 },
130
131 #[error("tokenization error: {0}")]
132 TokenizationError(String),
133
134 #[error("device error: {0}")]
135 DeviceError(String),
136
137 #[error("io error: {0}")]
138 Io(#[from] std::io::Error),
139}
140
/// Compute device a local backend is loaded onto.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Device {
    /// Plain CPU execution.
    Cpu,
    /// Apple Metal.
    Metal,
    /// CUDA; the payload is passed through as-is (`auto` uses `0`).
    Cuda(usize),
}

impl Device {
    /// Picks a device from the crate features enabled at compile time.
    ///
    /// Priority is `metal`, then `cuda` (index `0`), otherwise [`Device::Cpu`].
    /// When both features are enabled, `metal` wins.
    pub fn auto() -> Self {
        if cfg!(feature = "metal") {
            Device::Metal
        } else if cfg!(feature = "cuda") {
            Device::Cuda(0)
        } else {
            Device::Cpu
        }
    }
}
166
167#[derive(Debug, Clone)]
169pub struct InferenceConfig {
170 pub models_dir: std::path::PathBuf,
172 pub device: Option<Device>,
174 pub generation_model: String,
176 pub preferred_generation_model: Option<String>,
178 pub embedding_model: String,
180 pub preferred_embedding_model: Option<String>,
182 pub classification_model: String,
184 pub preferred_classification_model: Option<String>,
186}
187
188impl Default for InferenceConfig {
189 fn default() -> Self {
190 let models_dir = dirs_next()
191 .unwrap_or_else(|| std::path::PathBuf::from("."))
192 .join(".car")
193 .join("models");
194
195 let hw = HardwareInfo::detect();
196
197 Self {
198 models_dir,
199 device: None,
200 generation_model: hw.recommended_model,
201 preferred_generation_model: None,
202 embedding_model: "Qwen3-Embedding-0.6B".to_string(),
203 preferred_embedding_model: None,
204 classification_model: "Qwen3-0.6B".to_string(),
205 preferred_classification_model: None,
206 }
207 }
208}
209
/// Resolves the current user's home directory from the environment.
///
/// Checks `HOME` first and then `USERPROFILE` (the variable Windows sets), skipping
/// variables that are unset or empty. Values are read with `var_os`, so a home path that
/// is not valid UTF-8 is kept rather than being treated as missing.
///
/// Returns `None` when neither variable yields a non-empty value.
fn dirs_next() -> Option<std::path::PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|key| std::env::var_os(key))
        // An empty value would otherwise silently resolve relative to the working directory.
        .find(|value| !value.is_empty())
        .map(std::path::PathBuf::from)
}
213
214#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
216pub struct TokenUsage {
217 pub prompt_tokens: u64,
219 pub completion_tokens: u64,
221 pub total_tokens: u64,
223 pub context_window: u64,
225}
226
227#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
229pub struct InferenceResult {
230 pub text: String,
232 pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
234 #[serde(default, skip_serializing_if = "Vec::is_empty")]
241 pub bounding_boxes: Vec<crate::tasks::grounding::BoundingBox>,
242 pub trace_id: String,
244 pub model_used: String,
246 pub latency_ms: u64,
248 #[serde(default)]
260 pub time_to_first_token_ms: Option<u64>,
261 pub usage: Option<TokenUsage>,
263 #[serde(default, skip_serializing_if = "Vec::is_empty")]
276 pub provider_output_items: Vec<serde_json::Value>,
277}
278
279impl InferenceResult {
280 pub fn has_tool_calls(&self) -> bool {
282 !self.tool_calls.is_empty()
283 }
284}
285
286#[derive(Debug, Clone, Serialize)]
287pub struct SpeechRuntimeHealth {
288 pub root: PathBuf,
289 pub installed: bool,
290 pub python: PathBuf,
291 pub stt_command: PathBuf,
292 pub tts_command: PathBuf,
293 pub configured_python: Option<String>,
294 pub detected_python: Option<String>,
295}
296
297#[derive(Debug, Clone, Serialize)]
298pub struct SpeechModelHealth {
299 pub id: String,
300 pub name: String,
301 pub provider: String,
302 pub capability: ModelCapability,
303 pub is_local: bool,
304 pub available: bool,
305 pub cached: bool,
306 pub selected_by_default: bool,
307 pub source: String,
308}
309
310#[derive(Debug, Clone, Serialize)]
311pub struct SpeechHealthReport {
312 pub runtime: SpeechRuntimeHealth,
313 pub local_models: Vec<SpeechModelHealth>,
314 pub remote_models: Vec<SpeechModelHealth>,
315 pub elevenlabs_configured: bool,
316 pub prefer_local: bool,
317 pub allow_remote_fallback: bool,
318 pub preferred_local_stt: Option<String>,
319 pub preferred_local_tts: Option<String>,
320 pub preferred_remote_stt: Option<String>,
321 pub preferred_remote_tts: Option<String>,
322 pub local_stt_default: Option<String>,
323 pub local_tts_default: Option<String>,
324 pub remote_stt_default: Option<String>,
325 pub remote_tts_default: Option<String>,
326}
327
328#[derive(Debug, Clone, Serialize)]
329pub struct ModelDefaultHealth {
330 pub capability: ModelCapability,
331 pub configured_model: String,
332 pub available: bool,
333 pub is_local: bool,
334 pub provider: Option<String>,
335}
336
337#[derive(Debug, Clone, Serialize)]
338pub struct ModelProviderHealth {
339 pub provider: String,
340 pub configured: bool,
341 pub local_models: usize,
342 pub remote_models: usize,
343 pub available_models: usize,
344 pub capabilities: Vec<ModelCapability>,
345}
346
347#[derive(Debug, Clone, Serialize)]
348pub struct ModelCapabilityHealth {
349 pub capability: ModelCapability,
350 pub total_models: usize,
351 pub available_models: usize,
352 pub local_available_models: usize,
353 pub remote_available_models: usize,
354}
355
356#[derive(Debug, Clone, Serialize)]
357pub struct RoutingScenarioHealth {
358 pub name: String,
359 pub workload: RoutingWorkload,
360 pub task_family: String,
361 pub has_tools: bool,
362 pub has_vision: bool,
363 pub prefer_local: bool,
364 pub quality_first_cold_start: bool,
365 pub bootstrap_min_task_observations: u64,
366 pub bootstrap_quality_floor: f64,
367 pub model_id: String,
368 pub model_name: String,
369 pub reason: String,
370 pub strategy: RoutingStrategy,
371}
372
373#[derive(Debug, Clone, Serialize)]
374pub struct ModelBenchmarkPriorHealth {
375 pub model_id: String,
376 pub model_name: Option<String>,
377 pub overall_score: f64,
378 pub overall_latency_ms: Option<f64>,
379 pub task_scores: std::collections::HashMap<String, f64>,
380 pub task_latency_ms: std::collections::HashMap<String, f64>,
381 pub source_path: PathBuf,
382}
383
384#[derive(Debug, Clone, Serialize)]
385pub struct ModelHealthReport {
386 pub total_models: usize,
387 pub available_models: usize,
388 pub local_models: usize,
389 pub remote_models: usize,
390 pub defaults: Vec<ModelDefaultHealth>,
391 pub providers: Vec<ModelProviderHealth>,
392 pub capabilities: Vec<ModelCapabilityHealth>,
393 pub routing_prefer_local: bool,
394 pub routing_quality_first_cold_start: bool,
395 pub routing_min_observations: u64,
396 pub routing_bootstrap_min_task_observations: u64,
397 pub routing_bootstrap_quality_floor: f64,
398 pub routing_quality_weight: f64,
399 pub routing_latency_weight: f64,
400 pub routing_cost_weight: f64,
401 pub routing_scenarios: Vec<RoutingScenarioHealth>,
402 pub benchmark_priors: Vec<ModelBenchmarkPriorHealth>,
403 pub speech: SpeechHealthReport,
404}
405
406#[derive(Debug, Clone, Serialize)]
407pub struct SpeechInstallReport {
408 pub name: String,
409 pub hf_repo: String,
410 pub snapshot_path: PathBuf,
411 pub files_downloaded: usize,
412}
413
414#[derive(Debug, Clone, Serialize)]
415pub struct SpeechSmokePathReport {
416 pub path: String,
417 pub tts_model: String,
418 pub stt_model: String,
419 pub audio_path: PathBuf,
420 pub transcript: String,
421}
422
423#[derive(Debug, Clone, Serialize, Default)]
424pub struct SpeechSmokeReport {
425 pub local: Option<SpeechSmokePathReport>,
426 pub remote: Option<SpeechSmokePathReport>,
427 pub skipped: Vec<String>,
428}
429
430#[derive(Debug, Clone, Serialize, Default)]
431pub struct SpeechPolicy {
432 pub prefer_local: bool,
433 pub allow_remote_fallback: bool,
434 pub preferred_local_stt: Option<String>,
435 pub preferred_local_tts: Option<String>,
436 pub preferred_remote_stt: Option<String>,
437 pub preferred_remote_tts: Option<String>,
438}
439
440pub struct InferenceEngine {
445 pub config: InferenceConfig,
446 pub unified_registry: UnifiedRegistry,
448 pub adaptive_router: AdaptiveRouter,
450 pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
452 remote_backend: RemoteBackend,
454 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
460 mlx_backends: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
461 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
466 flux_cache: Arc<backend_cache::BackendCache<backend::mlx_flux::FluxBackend>>,
467 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
469 ltx_cache: Arc<backend_cache::BackendCache<backend::mlx_ltx::LtxBackend>>,
470 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
473 kokoro_cache: Arc<backend_cache::BackendCache<backend::mlx_kokoro::KokoroBackend>>,
474 pub registry: models::ModelRegistry,
476 pub router: ModelRouter,
477 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
478 backend: Arc<RwLock<Option<CandleBackend>>>,
479 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
480 embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
481 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
482 speech_runtime: Arc<Mutex<Option<SpeechRuntime>>>,
483 speech_policy: SpeechPolicy,
484}
485
486impl InferenceEngine {
487 fn preferred_model_for_capability(&self, capability: ModelCapability) -> Option<&str> {
488 match capability {
489 ModelCapability::Generate => self.config.preferred_generation_model.as_deref(),
490 ModelCapability::Embed => self.config.preferred_embedding_model.as_deref(),
491 ModelCapability::Classify => self.config.preferred_classification_model.as_deref(),
492 _ => None,
493 }
494 }
495
496 fn request_needs_vision(req: &GenerateRequest) -> bool {
497 req.images.as_ref().is_some_and(|images| !images.is_empty())
498 || req.messages.as_ref().is_some_and(|messages| {
499 messages
500 .iter()
501 .any(|msg| matches!(msg, Message::UserMultimodal { .. }))
502 })
503 }
504
505 fn request_has_video(req: &GenerateRequest) -> bool {
510 let images_have_video = req
511 .images
512 .as_ref()
513 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
514 let messages_have_video = req.messages.as_ref().is_some_and(|messages| {
515 messages.iter().any(|msg| match msg {
516 Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_video),
517 _ => false,
518 })
519 });
520 images_have_video || messages_have_video
521 }
522
523 fn request_has_audio(req: &GenerateRequest) -> bool {
527 let images_have_audio = req
528 .images
529 .as_ref()
530 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
531 let messages_have_audio = req.messages.as_ref().is_some_and(|messages| {
532 messages.iter().any(|msg| match msg {
533 Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_audio),
534 _ => false,
535 })
536 });
537 images_have_audio || messages_have_audio
538 }
539
540 pub fn new(config: InferenceConfig) -> Self {
541 let registry = models::ModelRegistry::new(config.models_dir.clone());
542 let hw = HardwareInfo::detect();
543 let router = ModelRouter::new(hw.clone());
544 let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
545 let adaptive_router = AdaptiveRouter::with_default_config(hw);
546 let mut tracker = OutcomeTracker::new();
547 let profiles_path = config.models_dir.join("outcome_profiles.json");
549 if let Ok(n) = tracker.load_from_file(&profiles_path) {
550 if n > 0 {
551 tracing::info!(loaded = n, "loaded persisted model profiles");
552 }
553 }
554 let mut benchmark_models_loaded = 0usize;
555 for path in benchmark_priors_paths(&config.models_dir) {
556 match routing_ext::load_benchmark_priors(&path) {
557 Ok(priors) if !priors.is_empty() => {
558 benchmark_models_loaded += priors.len();
559 routing_ext::apply_benchmark_priors(&mut tracker, &priors);
560 tracing::info!(
561 path = %path.display(),
562 loaded = priors.len(),
563 "loaded benchmark quality priors"
564 );
565 }
566 Ok(_) => {}
567 Err(error) => {
568 tracing::warn!(path = %path.display(), %error, "failed to load benchmark priors");
569 }
570 }
571 }
572 if benchmark_models_loaded > 0 {
573 tracing::info!(
574 loaded = benchmark_models_loaded,
575 "applied benchmark priors to cold-start routing"
576 );
577 }
578 let outcome_tracker = Arc::new(RwLock::new(tracker));
579
580 let remote_backend = RemoteBackend::new();
581
582 Self {
583 config,
584 unified_registry,
585 adaptive_router,
586 outcome_tracker,
587 remote_backend,
588 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
589 mlx_backends: Arc::new(backend_cache::BackendCache::from_env()),
590 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
591 flux_cache: Arc::new(backend_cache::BackendCache::from_env()),
592 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
593 ltx_cache: Arc::new(backend_cache::BackendCache::from_env()),
594 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
595 kokoro_cache: Arc::new(backend_cache::BackendCache::from_env()),
596 registry,
597 router,
598 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
599 backend: Arc::new(RwLock::new(None)),
600 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
601 embedding_backend: Arc::new(RwLock::new(None)),
602 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
603 speech_runtime: Arc::new(Mutex::new(None)),
604 speech_policy: SpeechPolicy {
605 prefer_local: cfg!(all(
606 target_os = "macos",
607 target_arch = "aarch64",
608 not(car_skip_mlx)
609 )),
610 allow_remote_fallback: true,
611 preferred_local_stt: None,
612 preferred_local_tts: None,
613 preferred_remote_stt: None,
614 preferred_remote_tts: None,
615 },
616 }
617 }
618
619 pub async fn init_key_pool(&self) {
622 for schema in self.unified_registry.list() {
624 if schema.is_remote() {
625 self.remote_backend.register_model_keys(schema).await;
626 }
627 }
628
629 let stats_path = self.config.models_dir.join("key_pool_stats.json");
631 if let Ok(n) = self.remote_backend.key_pool.load_stats(&stats_path).await {
632 if n > 0 {
633 tracing::info!(loaded = n, "loaded persisted key pool stats");
634 }
635 }
636
637 let total = self.remote_backend.key_pool.total_keys().await;
638 if total > 0 {
639 tracing::info!(keys = total, "key pool initialized");
640 }
641 }
642
643 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
646 async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
647 let read = self.backend.read().await;
648 if read.is_some() {
649 return Ok(());
650 }
651 drop(read);
652
653 let mut write = self.backend.write().await;
654 if write.is_some() {
655 return Ok(());
656 }
657
658 let model_path = self.registry.ensure_model(model_name).await?;
659 let device = self.config.device.unwrap_or_else(Device::auto);
660 let backend = CandleBackend::load(&model_path, device)?;
661 *write = Some(backend);
662 Ok(())
663 }
664
665 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
668 async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
669 let read = self.embedding_backend.read().await;
670 if read.is_some() {
671 return Ok(());
672 }
673 drop(read);
674
675 let mut write = self.embedding_backend.write().await;
676 if write.is_some() {
677 return Ok(());
678 }
679
680 let embedding_model = self
681 .preferred_model_for_capability(ModelCapability::Embed)
682 .unwrap_or(&self.config.embedding_model);
683 let model_path = self.registry.ensure_model(embedding_model).await?;
684 let device = self.config.device.unwrap_or_else(Device::auto);
685 let backend = EmbeddingBackend::load(&model_path, device)?;
686 *write = Some(backend);
687 Ok(())
688 }
689
690 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
693 async fn ensure_mlx_embedding_backend(&self) -> Result<String, InferenceError> {
694 let embedding_model_name = self
695 .preferred_model_for_capability(ModelCapability::Embed)
696 .unwrap_or(&self.config.embedding_model)
697 .to_string();
698 let schema = self
699 .unified_registry
700 .get(&embedding_model_name)
701 .or_else(|| self.unified_registry.find_by_name(&embedding_model_name))
702 .ok_or_else(|| InferenceError::ModelNotFound(embedding_model_name.clone()))?
703 .clone();
704 self.ensure_mlx_backend(&schema).await?;
705 Ok(schema.id)
706 }
707
708 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
712 async fn ensure_mlx_backend(
713 &self,
714 schema: &ModelSchema,
715 ) -> Result<backend_cache::CachedBackend<backend::MlxBackend>, InferenceError> {
716 if !Self::supports_native_mlx(schema) {
717 return Err(InferenceError::InferenceFailed(format!(
718 "native MLX backend does not support {} ({}) yet; use vLLM-MLX or add a family-specific MLX backend",
719 schema.name, schema.family
720 )));
721 }
722
723 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
724 let size = backend_cache::estimate_model_size(&model_dir);
725 let cache = Arc::clone(&self.mlx_backends);
726 let key = schema.id.clone();
727 cache.get_or_load(&key, size, move || {
731 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
732 backend::MlxBackend::load(&model_dir)
733 }))
734 .map_err(|e| {
735 InferenceError::InferenceFailed(format!(
736 "MLX backend loading panicked (possible Metal/accelerate exception): {:?}",
737 e
738 ))
739 })?
740 })
741 }
742
743 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
750 pub async fn warm_up<S: AsRef<str>>(
751 &self,
752 schema_ids: &[S],
753 ) -> Vec<Result<(), InferenceError>> {
754 let mut results = Vec::with_capacity(schema_ids.len());
755 for id in schema_ids {
756 let id = id.as_ref();
757 let outcome: Result<(), InferenceError> = async {
758 let schema = self.unified_registry.get(id).cloned().ok_or_else(|| {
759 InferenceError::InferenceFailed(format!("warm_up: unknown schema id {id}"))
760 })?;
761 match schema.capabilities.first().copied() {
762 Some(ModelCapability::ImageGeneration) => {
763 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
764 let size = backend_cache::estimate_model_size(&model_dir);
765 let _ = self.flux_cache.get_or_load(&schema.id, size, || {
766 backend::mlx_flux::FluxBackend::load(&model_dir)
767 })?;
768 }
769 Some(ModelCapability::VideoGeneration) => {
770 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
771 let size = backend_cache::estimate_model_size(&model_dir);
772 let _ = self.ltx_cache.get_or_load(&schema.id, size, || {
773 backend::mlx_ltx::LtxBackend::load(&model_dir)
774 })?;
775 }
776 Some(ModelCapability::TextToSpeech) => {
777 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
778 let size = backend_cache::estimate_model_size(&model_dir);
779 let _ = self.kokoro_cache.get_or_load(&schema.id, size, || {
780 backend::mlx_kokoro::KokoroBackend::load(&model_dir)
781 })?;
782 }
783 _ => {
784 let _ = self.ensure_mlx_backend(&schema).await?;
785 }
786 }
787 Ok(())
788 }
789 .await;
790 results.push(outcome);
791 }
792 results
793 }
794
795 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
797 pub async fn warm_up<S: AsRef<str>>(
798 &self,
799 _schema_ids: &[S],
800 ) -> Vec<Result<(), InferenceError>> {
801 Vec::new()
802 }
803
804 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
805 fn supports_native_mlx(schema: &ModelSchema) -> bool {
806 matches!(schema.family.as_str(), "qwen3" | "qwen2.5-vl" | "qwen2-vl")
807 }
808
809 pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
811 if let Some(model) = self.preferred_model_for_capability(ModelCapability::Generate) {
812 let ctx_len = self
813 .unified_registry
814 .get(model)
815 .or_else(|| self.unified_registry.find_by_name(model))
816 .map(|s| s.context_length)
817 .unwrap_or(0);
818 return AdaptiveRoutingDecision {
819 model_id: model.to_string(),
820 model_name: model.to_string(),
821 task: InferenceTask::Generate,
822 complexity: TaskComplexity::assess(prompt),
823 reason: "preferred generation model override".into(),
824 strategy: RoutingStrategy::Explicit,
825 predicted_quality: 0.5,
826 fallbacks: vec![],
827 context_length: ctx_len,
828 needs_compaction: false,
829 };
830 }
831 let tracker = self.outcome_tracker.read().await;
832 self.adaptive_router
833 .route(prompt, &self.unified_registry, &tracker)
834 }
835
836 pub fn route(&self, prompt: &str) -> RoutingDecision {
838 self.router.route_generate(prompt, &self.registry)
839 }
840
841 pub fn estimated_tokens(
844 &self,
845 req: &GenerateRequest,
846 model_id: Option<&str>,
847 ) -> (usize, usize, bool) {
848 let prompt_tokens = remote::estimate_tokens(&req.prompt);
849 let context_tokens = req
850 .context
851 .as_ref()
852 .map(|c| remote::estimate_tokens(c))
853 .unwrap_or(0);
854 let tools_tokens = req
855 .tools
856 .as_ref()
857 .map(|t| remote::estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
858 .unwrap_or(0);
859 let total_input = prompt_tokens + context_tokens + tools_tokens;
860
861 let context_window = model_id
862 .and_then(|id| {
863 self.unified_registry
864 .get(id)
865 .or_else(|| self.unified_registry.find_by_name(id))
866 })
867 .map(|s| s.context_length)
868 .unwrap_or(0);
869
870 let fits = context_window == 0 || (total_input + req.params.max_tokens) <= context_window;
871 (total_input, context_window, fits)
872 }
873
874 #[instrument(
877 name = "inference.generate",
878 skip_all,
879 fields(
880 model = tracing::field::Empty,
881 max_tokens = req.params.max_tokens,
882 prompt_tokens = tracing::field::Empty,
883 completion_tokens = tracing::field::Empty,
884 latency_ms = tracing::field::Empty,
885 )
886 )]
887 pub async fn generate_tracked(
888 &self,
889 req: GenerateRequest,
890 ) -> Result<InferenceResult, InferenceError> {
891 let start = Instant::now();
892
893 let (estimated_input, _, _) = self.estimated_tokens(&req, None);
895 let tracker_read = self.outcome_tracker.read().await;
896 let has_tools = req.tools.is_some();
897 let has_vision = Self::request_needs_vision(&req);
898 let preferred_model = self
899 .preferred_model_for_capability(ModelCapability::Generate)
900 .map(str::to_string);
901 let decision = match req.model.clone().or(preferred_model) {
902 Some(m) => {
903 let ctx_len = self
904 .unified_registry
905 .get(&m)
906 .or_else(|| self.unified_registry.find_by_name(&m))
907 .map(|s| s.context_length)
908 .unwrap_or(0);
909 AdaptiveRoutingDecision {
910 model_id: m.clone(),
911 model_name: m.clone(),
912 task: InferenceTask::Generate,
913 complexity: TaskComplexity::assess(&req.prompt),
914 reason: "explicit model".into(),
915 strategy: RoutingStrategy::Explicit,
916 predicted_quality: 0.5,
917 fallbacks: vec![],
918 context_length: ctx_len,
919 needs_compaction: ctx_len > 0 && estimated_input > ctx_len,
920 }
921 }
922 None => match &req.intent {
923 Some(hint) => self.adaptive_router.route_context_aware_with_intent(
924 &req.prompt,
925 estimated_input,
926 &self.unified_registry,
927 &tracker_read,
928 has_tools,
929 has_vision,
930 req.params.workload,
931 hint,
932 ),
933 None => self.adaptive_router.route_context_aware(
934 &req.prompt,
935 estimated_input,
936 &self.unified_registry,
937 &tracker_read,
938 has_tools,
939 has_vision,
940 req.params.workload,
941 ),
942 },
943 };
944 drop(tracker_read);
945
946 if decision.needs_compaction {
947 tracing::info!(
948 model = %decision.model_name,
949 prompt_tokens = estimated_input,
950 context_window = decision.context_length,
951 "prompt exceeds model context window — compaction or truncation needed"
952 );
953 }
954
955 let trace_id = {
957 let mut tracker = self.outcome_tracker.write().await;
958 tracker.record_start(&decision.model_id, decision.task, &decision.reason)
959 };
960
961 debug!(
962 model = %decision.model_name,
963 strategy = ?decision.strategy,
964 reason = %decision.reason,
965 trace = %trace_id,
966 "adaptive-routed generate request"
967 );
968
969 let mut req = req;
972 if req.params.budget_tokens == 0 && matches!(decision.complexity, TaskComplexity::Complex) {
973 let supports_thinking = self
974 .unified_registry
975 .get(&decision.model_id)
976 .map(|s| {
977 s.supported_params
978 .contains(&schema::GenerateParam::ExtendedThinking)
979 })
980 .unwrap_or(false);
981 if supports_thinking {
982 req.params.budget_tokens = 8000;
983 tracing::info!(model = %decision.model_name, budget = 8000, "auto-enabled extended thinking for complex task");
984 }
985 }
986
987 let mut models_to_try = vec![decision.model_id.clone()];
989 models_to_try.extend(decision.fallbacks.iter().cloned());
990
991 let mut last_error = None;
992
993 for candidate_id in &models_to_try {
994 #[allow(unused_mut)]
997 let mut schema = self
998 .unified_registry
999 .get(candidate_id)
1000 .or_else(|| self.unified_registry.find_by_name(candidate_id))
1001 .cloned();
1002
1003 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1005 if let Some(ref s) = schema {
1006 if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
1007 tracing::info!(
1008 from = %s.id, to = %mlx_equiv.id,
1009 "redirecting GGUF model to MLX equivalent on Apple Silicon"
1010 );
1011 schema = Some(mlx_equiv.clone());
1012 }
1013 }
1014
1015 let candidate_name = schema
1016 .as_ref()
1017 .map(|s| s.name.clone())
1018 .unwrap_or_else(|| candidate_id.clone());
1019
1020 let is_remote = schema
1021 .as_ref()
1022 .map(|s| s.is_remote() || s.is_vllm_mlx())
1023 .unwrap_or(false);
1024 let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
1025
1026 if is_delegated {
1032 let runner = match runner::current_inference_runner() {
1033 Some(r) => r,
1034 None => {
1035 last_error = Some(InferenceError::InferenceFailed(
1036 "model declares ModelSource::Delegated but no inference runner is registered"
1037 .into(),
1038 ));
1039 continue;
1040 }
1041 };
1042 let (tx, mut rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
1043 let emitter = runner::EventEmitter::new(tx);
1044 let runner_req = req.clone();
1045 let runner_handle =
1046 tokio::spawn(async move { runner.run(runner_req, emitter).await });
1047 let mut accumulator = stream::StreamAccumulator::default();
1048 while let Some(evt) = rx.recv().await {
1049 accumulator.push(&evt);
1050 }
1051 let (acc_text, acc_tool_calls) = accumulator.finish();
1055 match runner_handle.await {
1056 Ok(Ok(_runner_result)) => {
1057 let elapsed = start.elapsed().as_millis() as u64;
1058 let mut tracker = self.outcome_tracker.write().await;
1059 tracker.record_complete(&trace_id, elapsed, 0, 0);
1060 return Ok(InferenceResult {
1061 text: acc_text,
1062 tool_calls: acc_tool_calls,
1063 bounding_boxes: vec![],
1064 trace_id,
1065 model_used: candidate_name,
1066 latency_ms: elapsed,
1067 time_to_first_token_ms: None,
1068 usage: None,
1069 provider_output_items: vec![],
1070 });
1071 }
1072 Ok(Err(e)) => {
1073 last_error = Some(InferenceError::InferenceFailed(e.to_string()));
1074 continue;
1075 }
1076 Err(join_err) => {
1077 last_error = Some(InferenceError::InferenceFailed(format!(
1078 "runner task panicked: {join_err}"
1079 )));
1080 continue;
1081 }
1082 }
1083 }
1084
1085 let has_tools = req.tools.is_some();
1086
1087 let context = if has_tools
1089 && req.tools.as_ref().map_or(false, |t| {
1090 t.iter().any(|tool| {
1091 tool.get("function")
1092 .and_then(|f| f.get("name"))
1093 .and_then(|n| n.as_str())
1094 == Some("done")
1095 })
1096 }) {
1097 let base = req.context.as_deref().unwrap_or("");
1098 Some(format!(
1099 "{base}\n\nIMPORTANT: When calling the `done` tool, the `result` field MUST contain a DETAILED summary of everything you found and did. This is the ONLY output the user sees. Do NOT just say 'completed' — include specific findings, data, and conclusions."
1100 ))
1101 } else {
1102 req.context.clone()
1103 };
1104
1105 let result = if is_remote {
1106 let schema_val = schema.unwrap();
1107 let _ctx_len = schema_val.context_length;
1108 let temperature = if !schema_val.supported_params.is_empty()
1111 && !schema_val
1112 .supported_params
1113 .contains(&crate::schema::GenerateParam::Temperature)
1114 {
1115 -1.0
1116 } else {
1117 req.params.temperature
1118 };
1119
1120 self.remote_backend
1126 .generate_with_tools_multi(
1127 &schema_val,
1128 &req.prompt,
1129 context.as_deref(),
1130 temperature,
1131 req.params.max_tokens,
1132 req.tools.as_deref(),
1133 req.images.as_deref(),
1134 req.messages.as_deref(),
1135 req.params.tool_choice.as_deref(),
1136 req.params.parallel_tool_calls,
1137 req.params.budget_tokens,
1138 req.cache_control,
1139 req.response_format.as_ref(),
1140 )
1141 .await
1142 .map(|(t, c, u)| (t, c, u, None::<u64>))
1147 } else {
1148 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1149 {
1150 let schema_ref = schema
1154 .as_ref()
1155 .ok_or_else(|| InferenceError::ModelNotFound(candidate_id.clone()))?;
1156
1157 if schema_ref.is_foundation_models() {
1162 if has_tools {
1163 Err(InferenceError::UnsupportedMode {
1164 mode: "tool-use",
1165 backend: "foundation-models",
1166 reason: "tool calling not yet wired through the FoundationModels \
1167 bridge — route to a remote model with ToolUse capability",
1168 })
1169 } else if Self::request_has_video(&req)
1170 || Self::request_has_audio(&req)
1171 || req.images.as_ref().is_some_and(|imgs| !imgs.is_empty())
1172 {
1173 Err(InferenceError::UnsupportedMode {
1174 mode: "multimodal-content",
1175 backend: "foundation-models",
1176 reason: "the FoundationModels bridge currently exposes text-only \
1177 generation — route image/audio/video to a remote VL model",
1178 })
1179 } else {
1180 let prompt = req.prompt.clone();
1181 let instructions = context.clone();
1182 let max_tokens = req.params.max_tokens as u32;
1183 let temperature = req.params.temperature;
1184 tokio::task::spawn_blocking(move || {
1185 crate::backend::foundation_models::generate(
1186 &prompt,
1187 instructions.as_deref(),
1188 max_tokens,
1189 temperature as f32,
1190 )
1191 })
1192 .await
1193 .map_err(|e| {
1194 InferenceError::InferenceFailed(format!(
1195 "FoundationModels task panicked: {e}"
1196 ))
1197 })
1198 .and_then(|r| r)
1199 .map(|text| (text, vec![], None, None))
1200 }
1201 } else if !schema_ref.is_mlx() {
1202 Err(InferenceError::InferenceFailed(format!(
1203 "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1204 schema_ref.id
1205 )))
1206 } else if schema_ref.tags.iter().any(|t| t == "mlx-vlm-cli") {
1207 let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
1219 if !has_images {
1220 return Err(InferenceError::UnsupportedMode {
1221 mode: "text-only-on-mlx-vlm-id",
1222 backend: "mlx-vlm-cli",
1223 reason: "the `mlx-vlm/...` model IDs route exclusively \
1224 through the mlx-vlm CLI for image inference. \
1225 For text-only generation, route to a Qwen3 \
1226 text model (`mlx/qwen3-4b:4bit` etc.) — the \
1227 CLI shell-out has higher latency than the \
1228 in-process MLX text tower.",
1229 });
1230 }
1231 let vlm_status = crate::backend::mlx_vlm_cli::runtime_status();
1232 if !vlm_status.is_available() {
1233 return Err(InferenceError::InferenceFailed(vlm_status.user_message()));
1234 }
1235 let repo = match &schema_ref.source {
1236 crate::schema::ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
1237 _ => {
1238 return Err(InferenceError::InferenceFailed(format!(
1239 "model '{}' is tagged mlx-vlm-cli but its \
1240 source isn't ModelSource::Mlx — registry bug",
1241 schema_ref.id
1242 )));
1243 }
1244 };
1245 let imgs = req.images.clone().unwrap_or_default();
1246 let temp = req.params.temperature;
1247 let max_t = req.params.max_tokens;
1248 let prompt = req.prompt.clone();
1249 let text = tokio::task::spawn_blocking(move || {
1250 crate::backend::mlx_vlm_cli::generate(
1251 &repo, &prompt, &imgs, temp, max_t,
1252 )
1253 })
1254 .await
1255 .map_err(|e| {
1256 InferenceError::InferenceFailed(format!(
1257 "mlx_vlm CLI task panicked: {e}"
1258 ))
1259 })??;
1260 let bounding_boxes = parse_boxes(&text);
1261 let latency_ms = start.elapsed().as_millis() as u64;
1262 {
1263 let mut tracker = self.outcome_tracker.write().await;
1264 tracker.record_complete(&trace_id, latency_ms, 0, 0);
1265 }
1266 return Ok(InferenceResult {
1267 text,
1268 tool_calls: vec![],
1269 bounding_boxes,
1270 trace_id: trace_id.clone(),
1271 model_used: schema_ref.id.clone(),
1272 latency_ms,
1273 time_to_first_token_ms: None,
1274 usage: None,
1275 provider_output_items: Vec::new(),
1276 });
1277 } else {
1278 let handle = self.ensure_mlx_backend(schema_ref).await?;
1288 if Self::request_has_video(&req) {
1294 return Err(InferenceError::UnsupportedMode {
1295 mode: "video-content-block",
1296 backend: "native-mlx-qwen25vl",
1297 reason: "Qwen2.5-VL video understanding is on the request surface \
1298 but the video-tokenization path (frame sampling + merger) \
1299 is not yet wired; route to a remote VL provider for now",
1300 });
1301 }
1302 if Self::request_has_audio(&req) {
1303 return Err(InferenceError::UnsupportedMode {
1304 mode: "audio-content-block",
1305 backend: "native-mlx-qwen25vl",
1306 reason: "audio understanding is on the request surface (Gemma 4 \
1307 E2B/E4B and Gemini accept it) but the native MLX path \
1308 for this model does not — route to Gemini or Gemma-4",
1309 });
1310 }
1311 let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
1312 if has_images {
1313 let can_do_vision = {
1314 let guard = handle.lock().map_err(|_| {
1315 InferenceError::InferenceFailed(
1316 "MLX backend mutex poisoned".into(),
1317 )
1318 })?;
1319 guard.supports_capability(crate::schema::ModelCapability::Vision)
1320 };
1321 if !can_do_vision {
1322 return Err(InferenceError::UnsupportedMode {
1332 mode: "image-content-block",
1333 backend: "native-mlx-text",
1334 reason: "this MLX backend is a plain Qwen3 text tower. \
1335 For local image inference, route to \
1336 `mlx-vlm/qwen3-vl-2b:bf16` or another `mlx-vlm/...` \
1337 catalog ID so CAR shells out to `mlx_vlm.generate`. \
1338 Alternatives: a local vLLM-MLX VLM server, or a \
1339 remote VL model. (#115)",
1340 });
1341 }
1342 }
1343 self.generate_mlx(req.clone(), &schema_ref.id)
1344 .await
1345 .map(|(text, ttft)| (text, vec![], None, ttft))
1346 }
1347 }
1348
1349 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1350 {
1351 match self.ensure_backend(&candidate_name).await {
1352 Ok(()) => {
1353 let mut write = self.backend.write().await;
1354 let backend = write.as_mut().unwrap();
1355 tasks::generate::generate(backend, req.clone())
1356 .await
1357 .map(|(text, ttft)| (text, vec![], None, ttft))
1358 }
1359 Err(e) => Err(e),
1360 }
1361 }
1362 };
1363
1364 match result {
1365 Ok((text, tool_calls, usage, time_to_first_token_ms)) => {
1366 let latency_ms = start.elapsed().as_millis() as u64;
1367 let estimated_tokens = usage
1368 .as_ref()
1369 .map(|u| u.completion_tokens as usize)
1370 .unwrap_or_else(|| text.split_whitespace().count());
1371 {
1372 let mut tracker = self.outcome_tracker.write().await;
1373 tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
1374 }
1375 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1377 cb.record_success(candidate_id);
1378 }
1379 self.auto_save_outcomes().await;
1381
1382 let span = tracing::Span::current();
1384 span.record("model", candidate_name.as_str());
1385 span.record("latency_ms", latency_ms);
1386 if let Some(ttft) = time_to_first_token_ms {
1387 span.record("ttft_ms", ttft);
1388 }
1389 if let Some(ref u) = usage {
1390 span.record("prompt_tokens", u.prompt_tokens);
1391 span.record("completion_tokens", u.completion_tokens);
1392 }
1393
1394 let bounding_boxes = tasks::grounding::parse_boxes(&text);
1397 return Ok(InferenceResult {
1398 text,
1399 tool_calls,
1400 bounding_boxes,
1401 trace_id,
1402 model_used: candidate_name,
1403 latency_ms,
1404 time_to_first_token_ms,
1405 usage,
1406 provider_output_items: Vec::new(),
1407 });
1408 }
1409 Err(e) => {
1410 tracing::warn!(
1411 model = %candidate_name,
1412 error = %e,
1413 remaining = models_to_try.len().saturating_sub(
1414 models_to_try.iter().position(|m| m == candidate_id).unwrap_or(0) + 1
1415 ),
1416 "model failed, trying next fallback immediately"
1417 );
1418 {
1420 let mut tracker = self.outcome_tracker.write().await;
1421 let fail_trace =
1422 tracker.record_start(candidate_id, decision.task, "fallback");
1423 tracker.record_failure(&fail_trace, &e.to_string());
1424 }
1425 {
1429 let err_str = e.to_string();
1430 let is_client_error =
1431 err_str.contains("API returned 4") && !err_str.contains("429");
1432 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1433 cb.record_failure(candidate_id);
1434 if is_client_error {
1437 cb.record_failure(candidate_id);
1438 }
1439 }
1440 }
1441 #[cfg(not(all(
1443 target_os = "macos",
1444 target_arch = "aarch64",
1445 not(car_skip_mlx)
1446 )))]
1447 {
1448 let mut write = self.backend.write().await;
1449 *write = None;
1450 }
1451 last_error = Some(e);
1452 }
1453 }
1454 }
1455
1456 let underlying = last_error.unwrap_or(InferenceError::InferenceFailed(
1458 "no models available".into(),
1459 ));
1460
1461 let e = match no_backend_recovery_hint(&underlying.to_string()) {
1470 Some(msg) => InferenceError::InferenceFailed(msg),
1471 None => underlying,
1472 };
1473 {
1474 let mut tracker = self.outcome_tracker.write().await;
1475 tracker.record_failure(&trace_id, &e.to_string());
1476 }
1477 self.auto_save_outcomes().await;
1478 Err(e)
1479 }
1480
1481 pub async fn generate_tracked_stream(
1518 &self,
1519 req: GenerateRequest,
1520 ) -> Result<tokio::sync::mpsc::Receiver<stream::StreamEvent>, InferenceError> {
1521 let has_tools = req.tools.is_some();
1522 let has_vision = Self::request_needs_vision(&req);
1523 let preferred_model = self
1524 .preferred_model_for_capability(ModelCapability::Generate)
1525 .map(str::to_string);
1526 let decision = match req.model.clone().or(preferred_model) {
1527 Some(m) => {
1528 let ctx_len = self
1529 .unified_registry
1530 .get(&m)
1531 .or_else(|| self.unified_registry.find_by_name(&m))
1532 .map(|s| s.context_length)
1533 .unwrap_or(0);
1534 AdaptiveRoutingDecision {
1535 model_id: m.clone(),
1536 model_name: m,
1537 task: InferenceTask::Generate,
1538 complexity: TaskComplexity::assess(&req.prompt),
1539 reason: "explicit model".into(),
1540 strategy: RoutingStrategy::Explicit,
1541 predicted_quality: 0.5,
1542 fallbacks: vec![],
1543 context_length: ctx_len,
1544 needs_compaction: false,
1545 }
1546 }
1547 None => {
1548 let tracker_read = self.outcome_tracker.read().await;
1549 if has_vision {
1550 self.adaptive_router.route_with_vision(
1551 &req.prompt,
1552 &self.unified_registry,
1553 &tracker_read,
1554 has_tools,
1555 )
1556 } else if has_tools {
1557 self.adaptive_router.route_with_tools(
1558 &req.prompt,
1559 &self.unified_registry,
1560 &tracker_read,
1561 )
1562 } else {
1563 self.adaptive_router
1564 .route(&req.prompt, &self.unified_registry, &tracker_read)
1565 }
1566 }
1567 };
1568
1569 #[allow(unused_mut)]
1572 let mut schema = self
1573 .unified_registry
1574 .get(&decision.model_id)
1575 .or_else(|| self.unified_registry.find_by_name(&decision.model_id))
1576 .cloned();
1577
1578 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1580 if let Some(ref s) = schema {
1581 if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
1582 tracing::info!(
1583 from = %s.id, to = %mlx_equiv.id,
1584 "redirecting GGUF model to MLX equivalent on Apple Silicon (stream)"
1585 );
1586 schema = Some(mlx_equiv.clone());
1587 }
1588 }
1589
1590 let is_remote = schema
1591 .as_ref()
1592 .map(|s| s.is_remote() || s.is_vllm_mlx())
1593 .unwrap_or(false);
1594
1595 let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
1596
1597 if is_delegated {
1598 let runner = runner::current_inference_runner().ok_or_else(|| {
1602 InferenceError::InferenceFailed(
1603 "model declares ModelSource::Delegated but no inference runner is registered \
1604 (call set_inference_runner / registerInferenceRunner / register_inference_runner)"
1605 .into(),
1606 )
1607 })?;
1608 let (tx, rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
1609 let emitter = runner::EventEmitter::new(tx);
1610 let request = req.clone();
1611 tokio::spawn(async move {
1612 if let Err(e) = runner.run(request, emitter).await {
1613 tracing::warn!(error = %e, "delegated inference runner failed");
1614 }
1615 });
1616 return Ok(rx);
1617 }
1618
1619 if is_remote {
1620 let schema = schema.unwrap();
1621 self.remote_backend.register_model_keys(&schema).await;
1623
1624 self.remote_backend
1625 .generate_stream(
1626 &schema,
1627 &req.prompt,
1628 req.context.as_deref(),
1629 req.params.temperature,
1630 req.params.max_tokens,
1631 req.tools.as_deref(),
1632 req.images.as_deref(),
1633 req.params.tool_choice.as_deref(),
1634 req.params.parallel_tool_calls,
1635 req.response_format.as_ref(),
1636 )
1637 .await
1638 } else {
1639 let schema =
1640 schema.ok_or_else(|| InferenceError::ModelNotFound(decision.model_id.clone()))?;
1641 let (tx, rx) = tokio::sync::mpsc::channel(64);
1642
1643 #[cfg(any(
1649 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
1650 all(target_os = "ios", target_arch = "aarch64")
1651 ))]
1652 {
1653 if schema.is_foundation_models() {
1654 let prompt = req.prompt.clone();
1655 let instructions = req.context.clone();
1656 let max_tokens = req.params.max_tokens as u32;
1657 let temperature = req.params.temperature;
1658 let tx_clone = tx.clone();
1659 tokio::task::spawn_blocking(move || {
1660 let accum = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
1667 let accum_cb = accum.clone();
1668 let cb = crate::backend::foundation_models::StreamCallback::new(
1669 move |delta: &str| {
1670 if let Ok(mut g) = accum_cb.lock() {
1671 g.push_str(delta);
1672 }
1673 tx_clone
1674 .blocking_send(stream::StreamEvent::TextDelta(
1675 delta.to_string(),
1676 ))
1677 .is_ok()
1678 },
1679 );
1680 let result = crate::backend::foundation_models::stream(
1681 &prompt,
1682 instructions.as_deref(),
1683 max_tokens,
1684 temperature as f32,
1685 cb,
1686 );
1687 let final_text = accum.lock().map(|g| g.clone()).unwrap_or_default();
1688 let _ = tx.blocking_send(stream::StreamEvent::Done {
1689 text: final_text,
1690 tool_calls: vec![],
1691 });
1692 result
1693 });
1694 return Ok(rx);
1695 }
1696 }
1697
1698 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1700 {
1701 if !schema.is_mlx() {
1704 return Err(InferenceError::InferenceFailed(format!(
1705 "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1706 schema.id
1707 )));
1708 }
1709 let backend = self.ensure_mlx_backend(&schema).await?;
1710 let model_id = schema.id.clone();
1711 let cache = Arc::clone(&self.mlx_backends);
1712 tokio::task::spawn_blocking(move || {
1718 let _ = Self::stream_local_mlx(backend, cache, model_id, req, tx);
1719 });
1720 return Ok(rx);
1721 }
1722
1723 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1724 {
1725 self.ensure_backend(&schema.name).await?;
1726 let backend = self.backend.clone();
1727 tokio::spawn(async move {
1728 let _ = Self::stream_local_candle(backend, req, tx).await;
1729 });
1730 Ok(rx)
1731 }
1732 }
1733 }
1734
1735 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1736 async fn stream_local_candle(
1737 backend_lock: Arc<RwLock<Option<CandleBackend>>>,
1738 req: GenerateRequest,
1739 tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
1740 ) -> Result<(), InferenceError> {
1741 let mut write = backend_lock.write().await;
1742 let backend = write
1743 .as_mut()
1744 .ok_or_else(|| InferenceError::InferenceFailed("backend not initialized".into()))?;
1745 backend.clear_kv_cache();
1746
1747 let formatted = tasks::generate::apply_chat_template(
1748 &req.prompt,
1749 req.context.as_deref(),
1750 req.params.thinking,
1751 );
1752 let tokens = backend.encode(&formatted)?;
1753 let eos = backend.eos_token_id();
1754 let eos_alt = backend.token_id("<|im_end|>");
1755 let params = &req.params;
1756
1757 if tokens.is_empty() {
1758 let _ = tx
1759 .send(stream::StreamEvent::Done {
1760 text: String::new(),
1761 tool_calls: vec![],
1762 })
1763 .await;
1764 return Ok(());
1765 }
1766
1767 let max_ctx = backend.context_length().unwrap_or(32768);
1768 let headroom = params.max_tokens.min(max_ctx / 4);
1769 let max_prompt = max_ctx.saturating_sub(headroom);
1770 let tokens = if tokens.len() > max_prompt {
1771 tokens[tokens.len() - max_prompt..].to_vec()
1772 } else {
1773 tokens
1774 };
1775
1776 let mut generated = Vec::new();
1777 let logits = backend.forward(&tokens, 0)?;
1778 let mut next_token = tasks::generate::sample_token(&logits, params)?;
1779
1780 for _ in 0..params.max_tokens {
1781 if eos.map_or(false, |id| next_token == id)
1782 || eos_alt.map_or(false, |id| next_token == id)
1783 {
1784 break;
1785 }
1786
1787 generated.push(next_token);
1788 let delta = backend.decode(&[next_token])?;
1789 if !delta.is_empty()
1790 && tx
1791 .send(stream::StreamEvent::TextDelta(delta))
1792 .await
1793 .is_err()
1794 {
1795 return Ok(());
1796 }
1797
1798 if !params.stop.is_empty() {
1799 let text_so_far = backend.decode(&generated)?;
1800 if params.stop.iter().any(|s| text_so_far.contains(s)) {
1801 break;
1802 }
1803 }
1804
1805 let pos = tokens.len() + generated.len() - 1;
1806 let logits = backend.forward(&[next_token], pos)?;
1807 next_token = tasks::generate::sample_token(&logits, params)?;
1808 }
1809
1810 let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
1811 let _ = tx
1812 .send(stream::StreamEvent::Done {
1813 text,
1814 tool_calls: vec![],
1815 })
1816 .await;
1817 Ok(())
1818 }
1819
/// Token-by-token streaming generation on a cached MLX backend.
///
/// Synchronous: events are pushed with `blocking_send`, and the visible caller
/// runs it via `spawn_blocking`. One `TextDelta` is sent per non-empty decoded
/// token, followed by a single `Done` carrying the full text after
/// `strip_thinking`. If the receiver goes away mid-stream the function stops
/// early and still reports `Ok(())`.
///
/// When a forward pass fails or panics (see `catch_mlx`) the backend mutex is
/// released first and the cache entry for `model_id` is then invalidated,
/// using the same order as `generate_mlx`, so invalidation never runs while
/// this thread still holds the backend lock.
///
/// # Errors
///
/// `InferenceError::InferenceFailed` when the backend mutex is poisoned or a
/// forward pass fails, plus anything propagated from encoding, sampling or
/// decoding.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn stream_local_mlx(
    handle: backend_cache::CachedBackend<backend::MlxBackend>,
    cache: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
    model_id: String,
    req: GenerateRequest,
    tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
) -> Result<(), InferenceError> {
    let mut guard = handle.lock().map_err(|_| {
        InferenceError::InferenceFailed(format!("MLX backend mutex poisoned for {model_id}"))
    })?;
    let backend: &mut backend::MlxBackend = &mut *guard;
    backend.clear_kv_cache();

    let formatted = tasks::generate::apply_chat_template(
        &req.prompt,
        req.context.as_deref(),
        req.params.thinking,
    );
    let tokens = backend.encode(&formatted)?;
    let eos = backend.eos_token_id();
    let eos_alt = backend.token_id("<|im_end|>");
    let params = &req.params;

    // Nothing to condition on: finish immediately with an empty result.
    if tokens.is_empty() {
        let _ = tx.blocking_send(stream::StreamEvent::Done {
            text: String::new(),
            tool_calls: vec![],
        });
        return Ok(());
    }

    // Reserve room for the reply (capped at a quarter of the context) and keep
    // only the most recent prompt tokens that still fit.
    let max_ctx = backend.context_length();
    let headroom = params.max_tokens.min(max_ctx / 4);
    let max_prompt = max_ctx.saturating_sub(headroom);
    let tokens = if tokens.len() > max_prompt {
        tokens[tokens.len() - max_prompt..].to_vec()
    } else {
        tokens
    };

    let mut generated = Vec::new();
    let logits = match Self::catch_mlx("stream prefill", || backend.forward(&tokens, 0)) {
        Ok(v) => v,
        Err(e) => {
            // Release the backend lock before touching the cache.
            drop(guard);
            cache.invalidate(&model_id);
            return Err(e);
        }
    };
    let mut next_token = Self::sample_from_logits(&logits, params)?;

    for _ in 0..params.max_tokens {
        if eos.map_or(false, |id| next_token == id)
            || eos_alt.map_or(false, |id| next_token == id)
        {
            break;
        }

        generated.push(next_token);
        let delta = backend.decode(&[next_token])?;
        if !delta.is_empty()
            && tx
                .blocking_send(stream::StreamEvent::TextDelta(delta))
                .is_err()
        {
            // Receiver dropped: nobody is listening any more.
            return Ok(());
        }

        // Stop sequences are matched against everything generated so far.
        if !params.stop.is_empty() {
            let text_so_far = backend.decode(&generated)?;
            if params.stop.iter().any(|s| text_so_far.contains(s)) {
                break;
            }
        }

        let pos = tokens.len() + generated.len() - 1;
        let logits =
            match Self::catch_mlx("stream forward", || backend.forward(&[next_token], pos)) {
                Ok(v) => v,
                Err(e) => {
                    // Release the backend lock before touching the cache.
                    drop(guard);
                    cache.invalidate(&model_id);
                    return Err(e);
                }
            };
        next_token = Self::sample_from_logits(&logits, params)?;
    }

    let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
    let _ = tx.blocking_send(stream::StreamEvent::Done {
        text,
        tool_calls: vec![],
    });
    Ok(())
}
1921
1922 pub async fn route_context_snapshot(
1924 &self,
1925 prompt: &str,
1926 workload: RoutingWorkload,
1927 has_tools: bool,
1928 has_vision: bool,
1929 ) -> AdaptiveRoutingDecision {
1930 let tracker = self.outcome_tracker.read().await;
1931 self.adaptive_router.route_context_aware(
1932 prompt,
1933 0,
1934 &self.unified_registry,
1935 &tracker,
1936 has_tools,
1937 has_vision,
1938 workload,
1939 )
1940 }
1941
1942 pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
1945 Ok(self.generate_tracked(req).await?.text)
1946 }
1947
1948 pub async fn tokenize(&self, model: &str, text: &str) -> Result<Vec<u32>, InferenceError> {
1960 self.assert_local_for_tokenize(model)?;
1961
1962 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1963 {
1964 let schema = self
1965 .unified_registry
1966 .get(model)
1967 .or_else(|| self.unified_registry.find_by_name(model))
1968 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1969 .clone();
1970 let handle = self.ensure_mlx_backend(&schema).await?;
1971 let guard = handle.lock().map_err(|_| {
1972 InferenceError::InferenceFailed(format!(
1973 "MLX backend mutex poisoned for {}",
1974 schema.id
1975 ))
1976 })?;
1977 return guard.tokenize_raw(text);
1978 }
1979
1980 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1981 {
1982 self.ensure_backend(model).await?;
1983 let read = self.backend.read().await;
1984 let backend = read.as_ref().ok_or_else(|| {
1985 InferenceError::InferenceFailed(
1986 "candle backend missing after ensure_backend".to_string(),
1987 )
1988 })?;
1989 backend.tokenize_raw(text)
1990 }
1991 }
1992
1993 pub async fn detokenize(&self, model: &str, tokens: &[u32]) -> Result<String, InferenceError> {
1995 self.assert_local_for_tokenize(model)?;
1996
1997 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1998 {
1999 let schema = self
2000 .unified_registry
2001 .get(model)
2002 .or_else(|| self.unified_registry.find_by_name(model))
2003 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
2004 .clone();
2005 let handle = self.ensure_mlx_backend(&schema).await?;
2006 let guard = handle.lock().map_err(|_| {
2007 InferenceError::InferenceFailed(format!(
2008 "MLX backend mutex poisoned for {}",
2009 schema.id
2010 ))
2011 })?;
2012 return guard.detokenize_raw(tokens);
2013 }
2014
2015 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2016 {
2017 self.ensure_backend(model).await?;
2018 let read = self.backend.read().await;
2019 let backend = read.as_ref().ok_or_else(|| {
2020 InferenceError::InferenceFailed(
2021 "candle backend missing after ensure_backend".to_string(),
2022 )
2023 })?;
2024 backend.detokenize_raw(tokens)
2025 }
2026 }
2027
2028 fn assert_local_for_tokenize(&self, model: &str) -> Result<(), InferenceError> {
2032 if let Some(schema) = self
2033 .unified_registry
2034 .get(model)
2035 .or_else(|| self.unified_registry.find_by_name(model))
2036 {
2037 if !schema.is_local() {
2038 return Err(InferenceError::UnsupportedMode {
2039 mode: "tokenize/detokenize",
2040 backend: "remote",
2041 reason: "remote provider tokenizer is not exposed by the runtime; \
2042 use a local model (Qwen3 GGUF / MLX) for tokenizer-correctness checks",
2043 });
2044 }
2045 }
2046 Ok(())
2048 }
2049
/// Run `f`, converting a panic inside it into an `InferenceError`.
///
/// `context` names the phase being executed and is embedded in the error
/// message. The panic message is recovered from the payload when it is a
/// `&str` or a `String` (what `panic!` produces); formatting the raw
/// `Box<dyn Any>` with `{:?}` would only print `Any { .. }`.
///
/// # Errors
///
/// Returns the closure's own error unchanged, or
/// `InferenceError::InferenceFailed("MLX panicked during …")` when it panics.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn catch_mlx<F, T>(context: &str, f: F) -> Result<T, InferenceError>
where
    F: FnOnce() -> Result<T, InferenceError>,
{
    match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
        Ok(outcome) => outcome,
        Err(payload) => {
            let detail = payload
                .downcast_ref::<&str>()
                .map(|msg| (*msg).to_string())
                .or_else(|| payload.downcast_ref::<String>().cloned())
                .unwrap_or_else(|| "non-string panic payload".to_string());
            Err(InferenceError::InferenceFailed(format!(
                "MLX panicked during {context}: {detail}"
            )))
        }
    }
}
2064
/// Non-streaming generation on the MLX backend registered under `model_id`.
///
/// Returns the generated text (after `strip_thinking`) together with the time
/// to first token in milliseconds, measured from function entry until the
/// first token has been sampled. An empty prompt encoding yields an empty
/// string and no timing.
///
/// If a forward pass fails or panics, the backend lock is released and the
/// cache entry for `model_id` is invalidated before the error is returned.
///
/// # Errors
///
/// `InferenceError::InferenceFailed` for an unknown schema id, a poisoned
/// backend mutex or a failed forward pass, plus anything propagated from
/// backend set-up, encoding, sampling or decoding.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_mlx(
    &self,
    req: GenerateRequest,
    model_id: &str,
) -> Result<(String, Option<u64>), InferenceError> {
    let started = std::time::Instant::now();

    let Some(schema) = self.unified_registry.get(model_id).cloned() else {
        return Err(InferenceError::InferenceFailed(format!(
            "generate_mlx: unknown schema id {model_id}"
        )));
    };
    let handle = self.ensure_mlx_backend(&schema).await?;
    let mut guard = handle.lock().map_err(|_| {
        InferenceError::InferenceFailed(format!("MLX backend mutex poisoned for {model_id}"))
    })?;
    let backend: &mut backend::MlxBackend = &mut *guard;
    backend.clear_kv_cache();

    let params = &req.params;
    let templated = tasks::generate::apply_chat_template(
        &req.prompt,
        req.context.as_deref(),
        params.thinking,
    );
    let prompt_ids = backend.encode(&templated)?;
    let eos = backend.eos_token_id();
    let eos_alt = backend.token_id("<|im_end|>");

    if prompt_ids.is_empty() {
        return Ok((String::new(), None));
    }

    // Reserve room for the reply (capped at a quarter of the context) and keep
    // only the most recent prompt tokens that still fit.
    let max_ctx = backend.context_length();
    let headroom = params.max_tokens.min(max_ctx / 4);
    let max_prompt = max_ctx.saturating_sub(headroom);
    let overflow = prompt_ids.len().saturating_sub(max_prompt);
    let prompt_ids = if overflow > 0 {
        prompt_ids[overflow..].to_vec()
    } else {
        prompt_ids
    };

    let mut produced = Vec::new();

    let prefill = Self::catch_mlx("prefill", || backend.forward(&prompt_ids, 0));
    let prefill_logits = match prefill {
        Ok(values) => values,
        Err(err) => {
            drop(guard);
            self.mlx_backends.invalidate(model_id);
            return Err(err);
        }
    };
    let mut next_token = Self::sample_from_logits(&prefill_logits, params)?;
    let ttft_ms = Some(started.elapsed().as_millis() as u64);

    for _ in 0..params.max_tokens {
        let hit_eos = eos.is_some_and(|id| next_token == id)
            || eos_alt.is_some_and(|id| next_token == id);
        if hit_eos {
            break;
        }

        produced.push(next_token);

        // Stop sequences are matched against everything generated so far.
        if !params.stop.is_empty() {
            let so_far = backend.decode(&produced)?;
            if params.stop.iter().any(|s| so_far.contains(s)) {
                break;
            }
        }

        let pos = prompt_ids.len() + produced.len() - 1;
        let step = Self::catch_mlx("forward", || backend.forward(&[next_token], pos));
        let step_logits = match step {
            Ok(values) => values,
            Err(err) => {
                drop(guard);
                self.mlx_backends.invalidate(model_id);
                return Err(err);
            }
        };
        next_token = Self::sample_from_logits(&step_logits, params)?;
    }

    let raw = backend.decode(&produced)?;
    let cleaned = tasks::generate::strip_thinking(&raw, params.thinking);
    Ok((cleaned, ttft_ms))
}
2167
/// Pick the next token id from a logits vector.
///
/// * `temperature <= 0` selects the arg-max (greedy decoding).
/// * Otherwise the logits are turned into a temperature-scaled softmax; when
///   `top_p < 1` only the highest-probability tokens whose cumulative mass
///   first exceeds `top_p` are kept and renormalised; a token is then drawn
///   from the resulting distribution.
///
/// Tokens with zero probability are never returned by the sampling path, and
/// if rounding leaves the cumulative sum just short of the random draw, the
/// last token that still has non-zero probability is used.
///
/// # Errors
///
/// `InferenceError::InferenceFailed("empty logits")` when `logits` is empty,
/// for both the greedy and the sampling path.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn sample_from_logits(logits: &[f32], params: &GenerateParams) -> Result<u32, InferenceError> {
    // Without this guard the sampling path would compute `probs.len() - 1` on
    // an empty vector and underflow.
    if logits.is_empty() {
        return Err(InferenceError::InferenceFailed("empty logits".into()));
    }

    if params.temperature <= 0.0 {
        let (idx, _) = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .ok_or_else(|| InferenceError::InferenceFailed("empty logits".into()))?;
        return Ok(idx as u32);
    }

    let temp = params.temperature as f32;
    // Subtract the maximum before exponentiating for numerical stability.
    let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits
        .iter()
        .map(|&l| ((l - max_logit) / temp).exp())
        .collect();
    let sum: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= sum;
    }

    if params.top_p < 1.0 {
        let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let mut cumsum = 0.0;
        let mut cutoff_idx = indexed.len();
        for (i, &(_, p)) in indexed.iter().enumerate() {
            cumsum += p;
            if cumsum > params.top_p as f32 {
                cutoff_idx = i + 1;
                break;
            }
        }
        // Everything past the cutoff in the sorted order is outside the
        // nucleus: zero it directly instead of building a lookup set.
        for &(idx, _) in &indexed[cutoff_idx..] {
            probs[idx] = 0.0;
        }
        let sum: f32 = probs.iter().sum();
        if sum > 0.0 {
            for p in &mut probs {
                *p /= sum;
            }
        }
    }

    use rand::Rng;
    let mut rng = rand::rng();
    let r: f32 = rng.random();
    let mut cumsum = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        // `p > 0.0` keeps a draw of exactly 0.0 from selecting a token that
        // was filtered out (or underflowed) to zero probability.
        if p > 0.0 && cumsum >= r {
            return Ok(i as u32);
        }
    }

    // Rounding left the cumulative sum below `r`: fall back to the last token
    // that is still eligible rather than blindly to the last vocabulary entry.
    let fallback = probs
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or(probs.len() - 1);
    Ok(fallback as u32)
}
2235
2236 pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
2239 let instruction = req
2240 .instruction
2241 .as_deref()
2242 .unwrap_or("Retrieve relevant memory facts");
2243
2244 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2245 {
2246 let model_id = self.ensure_mlx_embedding_backend().await?;
2247 let schema = self
2248 .unified_registry
2249 .get(&model_id)
2250 .cloned()
2251 .ok_or_else(|| {
2252 InferenceError::InferenceFailed(format!("embed: unknown schema id {model_id}"))
2253 })?;
2254 let handle = self.ensure_mlx_backend(&schema).await?;
2255 let mut guard = handle.lock().map_err(|_| {
2256 InferenceError::InferenceFailed(format!(
2257 "MLX embedding backend mutex poisoned for {model_id}"
2258 ))
2259 })?;
2260 let backend: &mut backend::MlxBackend = &mut *guard;
2261
2262 let mut results = Vec::with_capacity(req.texts.len());
2263 for text in &req.texts {
2264 let embedding = if req.is_query {
2265 backend.embed_query(text, instruction)?
2266 } else {
2267 backend.embed_one(text)?
2268 };
2269 results.push(embedding);
2270 }
2271 return Ok(results);
2272 }
2273
2274 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2275 {
2276 self.ensure_embedding_backend().await?;
2277 let mut write = self.embedding_backend.write().await;
2278 let backend = write.as_mut().unwrap();
2279
2280 let mut results = Vec::with_capacity(req.texts.len());
2281 for text in &req.texts {
2282 let embedding = if req.is_query {
2283 backend.embed_query(text, instruction)?
2284 } else {
2285 backend.embed_one(text)?
2286 };
2287 results.push(embedding);
2288 }
2289 Ok(results)
2290 }
2291 }
2292
2293 pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResult, InferenceError> {
2323 if req.documents.is_empty() {
2324 return Ok(RerankResult {
2325 ranked: Vec::new(),
2326 model_used: None,
2327 });
2328 }
2329
2330 let model_name = match req.model.clone() {
2331 Some(m) => m,
2332 None => self
2333 .preferred_model_for_capability(ModelCapability::Rerank)
2334 .map(str::to_string)
2335 .ok_or_else(|| {
2336 InferenceError::InferenceFailed(
2337 "no reranker model available — pull a Qwen3-Reranker model first".into(),
2338 )
2339 })?,
2340 };
2341
2342 let schema = self
2343 .unified_registry
2344 .find_by_name(&model_name)
2345 .or_else(|| self.unified_registry.get(&model_name))
2346 .cloned()
2347 .ok_or_else(|| {
2348 InferenceError::InferenceFailed(format!(
2349 "rerank: unknown reranker model {model_name}"
2350 ))
2351 })?;
2352 if !schema.has_capability(ModelCapability::Rerank) {
2353 return Err(InferenceError::InferenceFailed(format!(
2354 "model {} does not declare the Rerank capability",
2355 schema.name
2356 )));
2357 }
2358
2359 let instruction = req.instruction.as_deref().unwrap_or(
2360 "Given a web search query, retrieve relevant passages that answer the query",
2361 );
2362
2363 let mut scored: Vec<RerankedDocument> = Vec::with_capacity(req.documents.len());
2364 for (idx, doc) in req.documents.iter().enumerate() {
2365 let prompt = rerank_prompt(instruction, &req.query, doc);
2366 let gen_req = GenerateRequest {
2367 prompt,
2368 model: Some(schema.id.clone()),
2369 params: tasks::generate::GenerateParams {
2370 temperature: 0.0,
2371 max_tokens: 3,
2375 thinking: tasks::generate::ThinkingMode::Off,
2376 ..Default::default()
2377 },
2378 context: None,
2379 tools: None,
2380 images: None,
2381 messages: None,
2382 cache_control: false,
2383 response_format: None,
2384 intent: None,
2385 };
2386 let out = self.generate(gen_req).await?;
2387 let score = score_from_rerank_output(&out, &schema.name);
2388 scored.push(RerankedDocument {
2389 index: idx,
2390 score,
2391 document: doc.clone(),
2392 });
2393 }
2394
2395 scored.sort_by(|a, b| {
2398 b.score
2399 .partial_cmp(&a.score)
2400 .unwrap_or(std::cmp::Ordering::Equal)
2401 .then_with(|| a.index.cmp(&b.index))
2402 });
2403 if let Some(n) = req.top_n {
2404 scored.truncate(n);
2405 }
2406
2407 Ok(RerankResult {
2408 ranked: scored,
2409 model_used: Some(schema.name),
2410 })
2411 }
2412
2413 pub async fn ground(&self, req: GroundRequest) -> Result<GroundResult, InferenceError> {
2423 let model_name = match req.model.clone() {
2424 Some(m) => m,
2425 None => self
2426 .preferred_model_for_capability(ModelCapability::Grounding)
2427 .map(str::to_string)
2428 .ok_or_else(|| {
2429 InferenceError::InferenceFailed(
2430 "no grounding-capable model available — pull a Qwen2.5-VL model first"
2431 .into(),
2432 )
2433 })?,
2434 };
2435
2436 let gen_req = GenerateRequest {
2437 prompt: req.prompt.clone(),
2438 model: Some(model_name),
2439 params: GenerateParams::default(),
2440 context: None,
2441 tools: None,
2442 images: Some(vec![req.image.clone()]),
2443 messages: None,
2444 cache_control: false,
2445 response_format: None,
2446 intent: None,
2447 };
2448 let result = self.generate_tracked(gen_req).await?;
2449 Ok(GroundResult {
2450 boxes: result.bounding_boxes,
2451 raw_text: result.text,
2452 model_used: Some(result.model_used),
2453 })
2454 }
2455
2456 pub async fn classify(
2459 &self,
2460 req: ClassifyRequest,
2461 ) -> Result<Vec<ClassifyResult>, InferenceError> {
2462 let model = match req.model.clone().or_else(|| {
2463 self.preferred_model_for_capability(ModelCapability::Classify)
2464 .map(str::to_string)
2465 }) {
2466 Some(m) => m,
2467 None => {
2468 let m = self.router.route_small(&self.registry);
2469 debug!(model = %m, "auto-routed classify request");
2470 m
2471 }
2472 };
2473
2474 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2477 {
2478 return self.classify_via_generate(req, &model).await;
2479 }
2480
2481 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2482 {
2483 self.ensure_backend(&model).await?;
2484 let mut write = self.backend.write().await;
2485 let backend = write.as_mut().unwrap();
2486 tasks::classify::classify(backend, req).await
2487 }
2488 }
2489
/// Classifies text by prompting a generative model and scoring each label
/// against the model's reply.
///
/// The prompt lists the labels as a numbered list and asks for only the
/// category name. The trimmed, lower-cased reply is then compared with each
/// lower-cased label:
/// * exact match scores `1.0`;
/// * the reply containing the label scores `0.8`;
/// * otherwise `0.5 *` the fraction of the label's whitespace-separated words
///   found in the reply (`0.0` for a label with no words).
///
/// Results are sorted by descending score and, when the scores sum to more
/// than zero, normalized so they sum to one.
///
/// # Errors
/// Propagates any error from `generate`.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn classify_via_generate(
    &self,
    req: ClassifyRequest,
    model: &str,
) -> Result<Vec<ClassifyResult>, InferenceError> {
    // Build the "1. label" list, one entry per line, without a trailing newline.
    let mut labels_str = String::new();
    for (position, label) in req.labels.iter().enumerate() {
        if position > 0 {
            labels_str.push('\n');
        }
        labels_str.push_str(&format!("{}. {}", position + 1, label));
    }

    let prompt = format!(
        "Classify the following text into one of these categories:\n\
         {labels_str}\n\n\
         Text: {}\n\n\
         Respond with ONLY the category name, nothing else.",
        req.text
    );

    let params = tasks::generate::GenerateParams {
        temperature: 0.0,
        max_tokens: 32,
        thinking: tasks::generate::ThinkingMode::Off,
        ..Default::default()
    };
    let request = GenerateRequest {
        prompt,
        model: Some(model.to_string()),
        params,
        context: None,
        tools: None,
        images: None,
        messages: None,
        cache_control: false,
        response_format: None,
        intent: None,
    };

    let reply = self.generate(request).await?;
    let answer = reply.trim().to_lowercase();

    let mut scored: Vec<ClassifyResult> = Vec::with_capacity(req.labels.len());
    for label in req.labels.iter() {
        let needle = label.to_lowercase();
        let score = match needle.as_str() {
            exact if answer == exact => 1.0,
            part if answer.contains(part) => 0.8,
            other => {
                // Partial credit: share of the label's words present in the reply.
                let mut word_total = 0usize;
                let mut word_hits = 0usize;
                for word in other.split_whitespace() {
                    word_total += 1;
                    if answer.contains(word) {
                        word_hits += 1;
                    }
                }
                if word_total == 0 {
                    0.0
                } else {
                    0.5 * (word_hits as f64 / word_total as f64)
                }
            }
        };
        scored.push(ClassifyResult {
            label: label.clone(),
            score,
        });
    }

    // Highest score first; the stable sort keeps label order among ties.
    scored.sort_by(|left, right| {
        right
            .score
            .partial_cmp(&left.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let total: f64 = scored.iter().map(|entry| entry.score).sum();
    if total > 0.0 {
        scored.iter_mut().for_each(|entry| entry.score /= total);
    }

    Ok(scored)
}
2579
2580 pub async fn transcribe(
2582 &self,
2583 req: TranscribeRequest,
2584 ) -> Result<TranscribeResult, InferenceError> {
2585 let candidates =
2586 self.speech_candidates(ModelCapability::SpeechToText, req.model.as_deref())?;
2587 let mut last_error = None;
2588
2589 for schema in candidates {
2590 let result = match &schema.source {
2591 ModelSource::Mlx { .. } => self.transcribe_local_mlx(&schema, &req).await,
2592 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2593 self.transcribe_elevenlabs(&schema, &req).await
2594 }
2595 _ => Err(InferenceError::InferenceFailed(format!(
2596 "speech-to-text not implemented for model source: {}",
2597 schema.id
2598 ))),
2599 };
2600
2601 match result {
2602 Ok(result) => return Ok(result),
2603 Err(err) => last_error = Some(err),
2604 }
2605 }
2606
2607 Err(last_error.unwrap_or_else(|| {
2608 InferenceError::InferenceFailed("no speech-to-text models available".into())
2609 }))
2610 }
2611
2612 pub async fn synthesize(
2614 &self,
2615 req: SynthesizeRequest,
2616 ) -> Result<SynthesizeResult, InferenceError> {
2617 let candidates =
2618 self.speech_candidates(ModelCapability::TextToSpeech, req.model.as_deref())?;
2619 let mut last_error = None;
2620
2621 for schema in candidates {
2622 let result = match &schema.source {
2623 ModelSource::Mlx { .. } => self.synthesize_local_mlx(&schema, &req).await,
2624 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2625 self.synthesize_elevenlabs(&schema, &req).await
2626 }
2627 _ => Err(InferenceError::InferenceFailed(format!(
2628 "text-to-speech not implemented for model source: {}",
2629 schema.id
2630 ))),
2631 };
2632
2633 match result {
2634 Ok(result) => return Ok(result),
2635 Err(err) => last_error = Some(err),
2636 }
2637 }
2638
2639 Err(last_error.unwrap_or_else(|| {
2640 InferenceError::InferenceFailed("no text-to-speech models available".into())
2641 }))
2642 }
2643
2644 pub async fn generate_image(
2646 &self,
2647 req: GenerateImageRequest,
2648 ) -> Result<GenerateImageResult, InferenceError> {
2649 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2656 {
2657 use crate::backend::external_flux;
2658 let backend =
2659 std::env::var("CAR_IMAGE_BACKEND").unwrap_or_else(|_| "native".to_string());
2660 let use_external = match backend.as_str() {
2661 "external" => true,
2662 "native" => false,
2663 _ => external_flux::is_available() && backend == "auto-external",
2665 };
2666 if use_external {
2667 tracing::info!(
2668 "routing image generation to external mflux \
2669 (set CAR_IMAGE_BACKEND=native to use the Rust port)"
2670 );
2671 let mut req = req;
2672 req.model = self.resolve_external_hf_repo(
2673 req.model.as_deref(),
2674 ModelCapability::ImageGeneration,
2675 );
2676 return external_flux::generate_image(&req);
2677 }
2678 tracing::info!("using native Rust MLX Flux backend");
2679 }
2680
2681 let candidates = self
2682 .media_generation_candidates(ModelCapability::ImageGeneration, req.model.as_deref())?;
2683 let mut last_error = None;
2684
2685 for schema in candidates {
2686 let result = match &schema.source {
2687 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2688 ModelSource::Mlx { .. } => self.generate_image_native_mlx(&schema, &req).await,
2689 _ => Err(InferenceError::InferenceFailed(format!(
2690 "image generation not implemented for model source: {}",
2691 schema.id
2692 ))),
2693 };
2694
2695 match result {
2696 Ok(result) => return Ok(result),
2697 Err(err) => last_error = Some(err),
2698 }
2699 }
2700
2701 Err(last_error.unwrap_or_else(|| {
2702 InferenceError::InferenceFailed("no image generation models available".into())
2703 }))
2704 }
2705
2706 pub async fn generate_image_batch(
2722 &self,
2723 req: GenerateImageRequest,
2724 ) -> Result<Vec<GenerateImageResult>, InferenceError> {
2725 let count = req.variant_count.unwrap_or(1).max(1);
2726 if count == 1 {
2727 return self.generate_image(req).await.map(|r| vec![r]);
2728 }
2729 let base_seed = req.seed.unwrap_or(0);
2730 let mut results = Vec::with_capacity(count as usize);
2731 for i in 0..count {
2732 let mut variant_req = req.clone();
2737 variant_req.seed = Some(base_seed.wrapping_add(i as u64));
2738 variant_req.variant_count = Some(1);
2742 results.push(self.generate_image(variant_req).await?);
2743 }
2744 Ok(results)
2745 }
2746
/// Generates an image with the in-process MLX Flux backend.
///
/// Ensures the model is present locally, obtains a backend handle from
/// `flux_cache` (loading it from the model directory on a miss), then runs
/// generation inside `tokio::task::spawn_blocking` while holding the handle's
/// lock.
///
/// # Errors
/// Propagates errors from `ensure_local`, from loading the backend and from
/// generation; a poisoned lock or a failed blocking task is reported as
/// [`InferenceError::InferenceFailed`].
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_image_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateImageRequest,
) -> Result<GenerateImageResult, InferenceError> {
    let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let estimated_size = backend_cache::estimate_model_size(&model_dir);
    let cache_key = schema.id.clone();
    let backend = Arc::clone(&self.flux_cache).get_or_load(&cache_key, estimated_size, || {
        backend::mlx_flux::FluxBackend::load(&model_dir)
    })?;

    // The blocking task needs owned data, so hand it its own copy of the request.
    let owned_req = req.clone();
    let joined =
        tokio::task::spawn_blocking(move || -> Result<GenerateImageResult, InferenceError> {
            let mut guard = match backend.lock() {
                Ok(guard) => guard,
                Err(_) => {
                    return Err(InferenceError::InferenceFailed(
                        "flux backend mutex poisoned".into(),
                    ))
                }
            };
            guard.generate(&owned_req)
        })
        .await;

    match joined {
        Ok(outcome) => outcome,
        Err(e) => Err(InferenceError::InferenceFailed(format!(
            "flux task join: {e}"
        ))),
    }
}
2774
/// Generates a video with the in-process MLX LTX backend.
///
/// Ensures the model is present locally, obtains a backend handle from
/// `ltx_cache` (loading it from the model directory on a miss), then runs
/// generation inside `tokio::task::spawn_blocking` while holding the handle's
/// lock.
///
/// # Errors
/// Propagates errors from `ensure_local`, from loading the backend and from
/// generation; a poisoned lock or a failed blocking task is reported as
/// [`InferenceError::InferenceFailed`].
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_video_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateVideoRequest,
) -> Result<GenerateVideoResult, InferenceError> {
    let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let estimated_size = backend_cache::estimate_model_size(&model_dir);
    let cache_key = schema.id.clone();
    let backend = Arc::clone(&self.ltx_cache).get_or_load(&cache_key, estimated_size, || {
        backend::mlx_ltx::LtxBackend::load(&model_dir)
    })?;

    // The blocking task needs owned data, so hand it its own copy of the request.
    let owned_req = req.clone();
    let joined =
        tokio::task::spawn_blocking(move || -> Result<GenerateVideoResult, InferenceError> {
            let mut guard = match backend.lock() {
                Ok(guard) => guard,
                Err(_) => {
                    return Err(InferenceError::InferenceFailed(
                        "ltx backend mutex poisoned".into(),
                    ))
                }
            };
            guard.generate(&owned_req)
        })
        .await;

    match joined {
        Ok(outcome) => outcome,
        Err(e) => Err(InferenceError::InferenceFailed(format!(
            "ltx task join: {e}"
        ))),
    }
}
2799
2800 pub async fn generate_video(
2802 &self,
2803 req: GenerateVideoRequest,
2804 ) -> Result<GenerateVideoResult, InferenceError> {
2805 if let Err(msg) = req.validate() {
2808 return Err(InferenceError::InferenceFailed(format!(
2809 "invalid GenerateVideoRequest: {}",
2810 msg
2811 )));
2812 }
2813 let requires_audio_conditioning = req.requires_audio_passthrough_opt_in();
2814 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2819 {
2820 use crate::backend::external_ltx;
2821 let backend =
2826 std::env::var("CAR_VIDEO_BACKEND").unwrap_or_else(|_| "native".to_string());
2827 let use_external = match backend.as_str() {
2828 "external" => true,
2829 "native" => false,
2830 "auto-external" => external_ltx::is_available(),
2834 _ => false,
2835 };
2836 if use_external {
2837 tracing::info!(
2838 "CAR_VIDEO_BACKEND requested external LTX routing for LTX-family models"
2839 );
2840 } else {
2841 tracing::info!("using family-aware MLX video routing");
2842 }
2843 }
2844
2845 let candidates = self
2846 .media_generation_candidates(ModelCapability::VideoGeneration, req.model.as_deref())?;
2847 let mut last_error = None;
2848
2849 for schema in candidates {
2850 let result = match &schema.source {
2851 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2852 ModelSource::Mlx { hf_repo, .. } => {
2853 if crate::backend::external_mlx_video::is_wan_family(&schema) {
2854 match self.unified_registry.ensure_local(&schema.id).await {
2855 Ok(model_dir) => {
2856 crate::backend::external_mlx_video::generate_wan_video(
2857 &schema, &model_dir, &req,
2858 )
2859 }
2860 Err(err) => Err(err),
2861 }
2862 } else {
2863 let backend = std::env::var("CAR_VIDEO_BACKEND")
2864 .unwrap_or_else(|_| "native".to_string());
2865 let use_external_ltx = match backend.as_str() {
2866 "external" => true,
2867 "native" => false,
2868 "auto-external" => crate::backend::external_ltx::is_available(),
2869 _ => false,
2870 };
2871 let use_external_ltx = use_external_ltx || requires_audio_conditioning;
2872 if requires_audio_conditioning
2873 && !crate::backend::external_ltx::is_available()
2874 {
2875 return Err(InferenceError::InferenceFailed(
2876 "audio-reference video conditioning requires the external `ltx-2-mlx a2v` CLI on PATH"
2877 .to_string(),
2878 ));
2879 }
2880 if use_external_ltx {
2881 let mut req = req.clone();
2882 req.model = Some(hf_repo.clone());
2883 crate::backend::external_ltx::generate_video(&req)
2884 } else {
2885 self.generate_video_native_mlx(&schema, &req).await
2886 }
2887 }
2888 }
2889 _ => Err(InferenceError::InferenceFailed(format!(
2890 "video generation not implemented for model source: {}",
2891 schema.id
2892 ))),
2893 };
2894
2895 match result {
2896 Ok(result) => return Ok(result),
2897 Err(err) => last_error = Some(err),
2898 }
2899 }
2900
2901 Err(last_error.unwrap_or_else(|| {
2902 InferenceError::InferenceFailed("no video generation models available".into())
2903 }))
2904 }
2905
2906 pub fn list_models_unified(&self) -> Vec<ModelInfo> {
2908 self.unified_registry
2909 .list()
2910 .iter()
2911 .map(|m| ModelInfo::from(*m))
2912 .collect()
2913 }
2914
2915 pub fn available_model_upgrades(&self) -> Vec<ModelUpgrade> {
2917 self.unified_registry.available_upgrades()
2918 }
2919
2920 pub fn list_schemas(&self) -> Vec<ModelSchema> {
2923 self.unified_registry.list().into_iter().cloned().collect()
2924 }
2925
2926 pub fn list_models(&self) -> Vec<models::ModelInfo> {
2927 self.registry.list_models()
2928 }
2929
2930 pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
2932 let schema = self
2933 .unified_registry
2934 .find_by_name(name)
2935 .or_else(|| self.unified_registry.get(name))
2936 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2937 self.unified_registry.ensure_local(&schema.id).await
2938 }
2939
2940 pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
2942 let schema = self
2943 .unified_registry
2944 .get(name)
2945 .or_else(|| {
2946 self.unified_registry
2947 .list()
2948 .into_iter()
2949 .find(|schema| schema.name.eq_ignore_ascii_case(name))
2950 })
2951 .or_else(|| self.unified_registry.find_by_name(name))
2952 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2953 let model_dir = self.unified_registry.models_dir().join(&schema.name);
2954 if model_dir.exists() {
2955 std::fs::remove_dir_all(&model_dir)?;
2956 }
2957 match &schema.source {
2958 ModelSource::Mlx { hf_repo, .. } => {
2959 remove_huggingface_repo_cache(hf_repo)?;
2960 }
2961 ModelSource::Local {
2962 hf_repo,
2963 tokenizer_repo,
2964 ..
2965 } => {
2966 remove_huggingface_repo_cache(hf_repo)?;
2967 remove_huggingface_repo_cache(tokenizer_repo)?;
2968 }
2969 _ => {}
2970 }
2971 Ok(())
2972 }
2973
2974 pub fn register_model(&mut self, schema: ModelSchema) {
2976 self.unified_registry.register(schema);
2977 }
2978
2979 pub async fn discover_vllm_mlx_models(&mut self) -> usize {
2982 let config = vllm_mlx::VllmMlxConfig::default();
2983 if !config.auto_discover {
2984 return 0;
2985 }
2986 vllm_mlx::discover_and_register(&config, &mut self.unified_registry).await
2987 }
2988
2989 pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
2991 self.outcome_tracker.clone()
2992 }
2993
2994 async fn auto_save_outcomes(&self) {
2996 if let Err(e) = self.save_outcomes().await {
2997 tracing::debug!("auto-save outcomes failed: {}", e);
2998 }
2999 if let Err(e) = self.save_key_pool_stats().await {
3000 tracing::debug!("auto-save key pool stats failed: {}", e);
3001 }
3002 }
3003
3004 pub async fn save_outcomes(&self) -> Result<(), std::io::Error> {
3006 let tracker = self.outcome_tracker.read().await;
3007 let path = self.config.models_dir.join("outcome_profiles.json");
3008 tracker.save_to_file(&path)
3009 }
3010
3011 pub async fn save_key_pool_stats(&self) -> Result<(), std::io::Error> {
3013 let path = self.config.models_dir.join("key_pool_stats.json");
3014 self.remote_backend.key_pool.save_stats(&path).await
3015 }
3016
3017 pub async fn key_pool_stats(
3019 &self,
3020 ) -> std::collections::HashMap<String, Vec<key_pool::KeyStats>> {
3021 self.remote_backend.key_pool.all_stats().await
3022 }
3023
3024 pub async fn export_profiles(&self) -> Vec<ModelProfile> {
3026 let tracker = self.outcome_tracker.read().await;
3027 tracker.export_profiles()
3028 }
3029
3030 pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
3032 let mut tracker = self.outcome_tracker.write().await;
3033 tracker.import_profiles(profiles);
3034 }
3035
3036 pub async fn prepare_speech_runtime(&self) -> Result<PathBuf, InferenceError> {
3039 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3040 {
3041 Ok(self.config.models_dir.clone())
3043 }
3044 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3045 {
3046 Ok(self.ensure_speech_runtime().await?.root)
3047 }
3048 }
3049
3050 pub fn set_speech_policy(&mut self, policy: SpeechPolicy) {
3052 self.speech_policy = policy;
3053 }
3054
3055 pub fn set_routing_config(&mut self, config: RoutingConfig) {
3056 self.adaptive_router.set_config(config);
3057 }
3058
3059 pub async fn install_curated_speech(
3061 &mut self,
3062 ) -> Result<Vec<SpeechInstallReport>, InferenceError> {
3063 let _runtime_root = self.prepare_speech_runtime().await?;
3064 let schemas = self.list_schemas();
3065 let mut repos = Vec::new();
3066 for schema in &schemas {
3067 if !schema.is_mlx() || !schema.tags.iter().any(|tag| tag == "speech") {
3068 continue;
3069 }
3070 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3071 if !repos.iter().any(|existing: &String| existing == hf_repo) {
3072 repos.push(hf_repo.clone());
3073 }
3074 }
3075 }
3076
3077 let mut installed = Vec::new();
3078 for repo in repos {
3079 let (snapshot_path, files_downloaded) = download_hf_repo_snapshot(&repo).await?;
3080 let name = schemas
3081 .iter()
3082 .find(|schema| {
3083 matches!(&schema.source, ModelSource::Mlx { hf_repo, .. } if hf_repo == &repo)
3084 })
3085 .map(|schema| schema.name.clone())
3086 .unwrap_or_else(|| repo.clone());
3087 installed.push(SpeechInstallReport {
3088 name,
3089 hf_repo: repo,
3090 snapshot_path,
3091 files_downloaded,
3092 });
3093 }
3094
3095 self.unified_registry.refresh_availability();
3096 Ok(installed)
3097 }
3098
3099 pub fn speech_health(&self) -> SpeechHealthReport {
3101 let local_stt_default =
3102 self.speech_health_default_name(ModelCapability::SpeechToText, true, false);
3103 let local_tts_default =
3104 self.speech_health_default_name(ModelCapability::TextToSpeech, true, false);
3105 let remote_stt_default =
3106 self.speech_health_default_name(ModelCapability::SpeechToText, false, true);
3107 let remote_tts_default =
3108 self.speech_health_default_name(ModelCapability::TextToSpeech, false, true);
3109
3110 let mut local_models = Vec::new();
3111 let mut remote_models = Vec::new();
3112 for schema in self.list_schemas() {
3113 let capability = if schema.has_capability(ModelCapability::SpeechToText) {
3114 Some(ModelCapability::SpeechToText)
3115 } else if schema.has_capability(ModelCapability::TextToSpeech) {
3116 Some(ModelCapability::TextToSpeech)
3117 } else {
3118 None
3119 };
3120 let Some(capability) = capability else {
3121 continue;
3122 };
3123
3124 let selected_by_default = local_stt_default
3125 .as_ref()
3126 .is_some_and(|name| name == &schema.name)
3127 || local_tts_default
3128 .as_ref()
3129 .is_some_and(|name| name == &schema.name)
3130 || remote_stt_default
3131 .as_ref()
3132 .is_some_and(|name| name == &schema.name)
3133 || remote_tts_default
3134 .as_ref()
3135 .is_some_and(|name| name == &schema.name);
3136
3137 let health = SpeechModelHealth {
3138 id: schema.id.clone(),
3139 name: schema.name.clone(),
3140 provider: schema.provider.clone(),
3141 capability,
3142 is_local: schema.is_local(),
3143 available: schema.available,
3144 cached: speech_model_cached(&schema),
3145 selected_by_default,
3146 source: speech_model_source_label(&schema),
3147 };
3148 if schema.is_local() {
3149 local_models.push(health);
3150 } else {
3151 remote_models.push(health);
3152 }
3153 }
3154
3155 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3157 let runtime = SpeechRuntimeHealth {
3158 root: self.config.models_dir.clone(),
3159 installed: true,
3160 python: PathBuf::new(),
3161 stt_command: PathBuf::new(),
3162 tts_command: PathBuf::new(),
3163 configured_python: None,
3164 detected_python: None,
3165 };
3166
3167 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3168 let runtime = {
3169 let rt =
3170 SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3171 SpeechRuntimeHealth {
3172 root: rt.root.clone(),
3173 installed: rt.is_ready(),
3174 python: rt.python.clone(),
3175 stt_command: rt.stt_program.clone(),
3176 tts_command: rt.tts_program.clone(),
3177 configured_python: std::env::var("CAR_SPEECH_PYTHON")
3178 .ok()
3179 .filter(|value| !value.trim().is_empty()),
3180 detected_python: detect_speech_python(),
3181 }
3182 };
3183
3184 SpeechHealthReport {
3185 runtime,
3186 local_models,
3187 remote_models,
3188 elevenlabs_configured: car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY")
3189 .is_some(),
3190 prefer_local: self.speech_policy.prefer_local,
3191 allow_remote_fallback: self.speech_policy.allow_remote_fallback,
3192 preferred_local_stt: self.speech_policy.preferred_local_stt.clone(),
3193 preferred_local_tts: self.speech_policy.preferred_local_tts.clone(),
3194 preferred_remote_stt: self.speech_policy.preferred_remote_stt.clone(),
3195 preferred_remote_tts: self.speech_policy.preferred_remote_tts.clone(),
3196 local_stt_default,
3197 local_tts_default,
3198 remote_stt_default,
3199 remote_tts_default,
3200 }
3201 }
3202
3203 pub async fn model_health(&self) -> ModelHealthReport {
3206 let schemas = self.list_schemas();
3207 let total_models = schemas.len();
3208 let available_models = schemas.iter().filter(|schema| schema.available).count();
3209 let local_models = schemas.iter().filter(|schema| schema.is_local()).count();
3210 let remote_models = total_models.saturating_sub(local_models);
3211
3212 let defaults = vec![
3213 self.model_default_health(
3214 ModelCapability::Generate,
3215 self.preferred_model_for_capability(ModelCapability::Generate)
3216 .unwrap_or(&self.config.generation_model),
3217 ),
3218 self.model_default_health(
3219 ModelCapability::Embed,
3220 self.preferred_model_for_capability(ModelCapability::Embed)
3221 .unwrap_or(&self.config.embedding_model),
3222 ),
3223 self.model_default_health(
3224 ModelCapability::Classify,
3225 self.preferred_model_for_capability(ModelCapability::Classify)
3226 .unwrap_or(&self.config.classification_model),
3227 ),
3228 ];
3229
3230 let mut providers = std::collections::BTreeMap::new();
3231 for schema in &schemas {
3232 let entry =
3233 providers
3234 .entry(schema.provider.clone())
3235 .or_insert_with(|| ProviderAccumulator {
3236 configured: false,
3237 local_models: 0,
3238 remote_models: 0,
3239 available_models: 0,
3240 capabilities: std::collections::HashSet::new(),
3241 });
3242
3243 entry.configured |= model_source_configured(schema);
3244 if schema.is_local() {
3245 entry.local_models += 1;
3246 } else {
3247 entry.remote_models += 1;
3248 }
3249 if schema.available {
3250 entry.available_models += 1;
3251 }
3252 for capability in &schema.capabilities {
3253 entry.capabilities.insert(*capability);
3254 }
3255 }
3256
3257 let providers = providers
3258 .into_iter()
3259 .map(|(provider, acc)| ModelProviderHealth {
3260 provider,
3261 configured: acc.configured,
3262 local_models: acc.local_models,
3263 remote_models: acc.remote_models,
3264 available_models: acc.available_models,
3265 capabilities: sort_capabilities(acc.capabilities.into_iter().collect()),
3266 })
3267 .collect();
3268
3269 let capabilities = all_model_capabilities()
3270 .into_iter()
3271 .map(|capability| {
3272 let relevant: Vec<&ModelSchema> = schemas
3273 .iter()
3274 .filter(|schema| schema.has_capability(capability))
3275 .collect();
3276 let available: Vec<&ModelSchema> = relevant
3277 .iter()
3278 .copied()
3279 .filter(|schema| schema.available)
3280 .collect();
3281 ModelCapabilityHealth {
3282 capability,
3283 total_models: relevant.len(),
3284 available_models: available.len(),
3285 local_available_models: available
3286 .iter()
3287 .filter(|schema| schema.is_local())
3288 .count(),
3289 remote_available_models: available
3290 .iter()
3291 .filter(|schema| !schema.is_local())
3292 .count(),
3293 }
3294 })
3295 .collect();
3296
3297 let routing = self.routing_scenarios().await;
3298 let routing_config = self.adaptive_router.config().clone();
3299 let benchmark_priors = load_benchmark_prior_health(&self.config.models_dir, &schemas);
3300
3301 ModelHealthReport {
3302 total_models,
3303 available_models,
3304 local_models,
3305 remote_models,
3306 defaults,
3307 providers,
3308 capabilities,
3309 routing_prefer_local: routing_config.prefer_local,
3310 routing_quality_first_cold_start: routing_config.quality_first_cold_start,
3311 routing_min_observations: routing_config.min_observations,
3312 routing_bootstrap_min_task_observations: routing_config.bootstrap_min_task_observations,
3313 routing_bootstrap_quality_floor: routing_config.bootstrap_quality_floor,
3314 routing_quality_weight: routing_config.quality_weight,
3315 routing_latency_weight: routing_config.latency_weight,
3316 routing_cost_weight: routing_config.cost_weight,
3317 routing_scenarios: routing,
3318 benchmark_priors,
3319 speech: self.speech_health(),
3320 }
3321 }
3322
3323 async fn routing_scenarios(&self) -> Vec<RoutingScenarioHealth> {
3324 let tracker = self.outcome_tracker.read().await;
3325 let config = self.adaptive_router.config().clone();
3326 let scenarios = [
3327 (
3328 "interactive_text",
3329 "Summarize the benefits of local-first AI routing in two sentences.",
3330 "text",
3331 RoutingWorkload::Interactive,
3332 false,
3333 false,
3334 ),
3335 (
3336 "background_code",
3337 "Write a Python function named fibonacci(n) that returns the nth Fibonacci number.",
3338 "code",
3339 RoutingWorkload::Background,
3340 false,
3341 false,
3342 ),
3343 (
3344 "interactive_tool_use",
3345 "Use the provided weather tool to get the weather for Boston.",
3346 "tool_use",
3347 RoutingWorkload::Interactive,
3348 true,
3349 false,
3350 ),
3351 (
3352 "interactive_vision",
3353 "What is in this image? Answer in one word.",
3354 "vision",
3355 RoutingWorkload::Interactive,
3356 false,
3357 true,
3358 ),
3359 ];
3360
3361 scenarios
3362 .into_iter()
3363 .map(
3364 |(name, prompt, task_family, workload, has_tools, has_vision)| {
3365 let decision = self.adaptive_router.route_context_aware(
3366 prompt,
3367 0,
3368 &self.unified_registry,
3369 &tracker,
3370 has_tools,
3371 has_vision,
3372 workload,
3373 );
3374 let quality_first_cold_start = if has_tools || has_vision {
3375 config.quality_first_cold_start
3376 } else if task_family == "code"
3377 && matches!(workload, RoutingWorkload::Background)
3378 {
3379 false
3380 } else {
3381 config.quality_first_cold_start
3382 };
3383 RoutingScenarioHealth {
3384 name: name.to_string(),
3385 task_family: task_family.to_string(),
3386 workload,
3387 has_tools,
3388 has_vision,
3389 prefer_local: if task_family == "speech" {
3390 self.speech_policy.prefer_local
3391 } else {
3392 config.prefer_local
3393 },
3394 quality_first_cold_start,
3395 bootstrap_min_task_observations: config.bootstrap_min_task_observations,
3396 bootstrap_quality_floor: config.bootstrap_quality_floor,
3397 model_id: decision.model_id,
3398 model_name: decision.model_name,
3399 reason: decision.reason,
3400 strategy: decision.strategy,
3401 }
3402 },
3403 )
3404 .collect()
3405 }
3406
3407 pub async fn smoke_test_speech(
3409 &self,
3410 local: bool,
3411 remote: bool,
3412 ) -> Result<SpeechSmokeReport, InferenceError> {
3413 let mut report = SpeechSmokeReport::default();
3414
3415 if local {
3416 let tts = self
3417 .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
3418 .ok_or_else(|| {
3419 InferenceError::InferenceFailed(
3420 "no local text-to-speech model available".into(),
3421 )
3422 })?;
3423 let stt = self
3424 .preferred_speech_schema(ModelCapability::SpeechToText, true, false)
3425 .ok_or_else(|| {
3426 InferenceError::InferenceFailed(
3427 "no local speech-to-text model available".into(),
3428 )
3429 })?;
3430 report.local = Some(
3431 self.run_speech_smoke_path("local", &tts, &stt, "Testing CAR local speech path.")
3432 .await?,
3433 );
3434 } else {
3435 report.skipped.push("local".to_string());
3436 }
3437
3438 if remote {
3439 let tts = self
3440 .preferred_speech_schema(ModelCapability::TextToSpeech, false, true)
3441 .ok_or_else(|| {
3442 InferenceError::InferenceFailed(
3443 "no remote text-to-speech model available".into(),
3444 )
3445 })?;
3446 let stt = self
3447 .preferred_speech_schema(ModelCapability::SpeechToText, false, true)
3448 .ok_or_else(|| {
3449 InferenceError::InferenceFailed(
3450 "no remote speech-to-text model available".into(),
3451 )
3452 })?;
3453 report.remote = Some(
3454 self.run_speech_smoke_path("remote", &tts, &stt, "Testing CAR remote speech path.")
3455 .await?,
3456 );
3457 } else {
3458 report.skipped.push("remote".to_string());
3459 }
3460
3461 Ok(report)
3462 }
3463
3464 fn speech_candidates(
3465 &self,
3466 capability: ModelCapability,
3467 explicit: Option<&str>,
3468 ) -> Result<Vec<ModelSchema>, InferenceError> {
3469 if let Some(model) = explicit {
3470 let schema = self
3471 .unified_registry
3472 .get(model)
3473 .or_else(|| self.unified_registry.find_by_name(model))
3474 .cloned()
3475 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3476 if !schema.has_capability(capability) {
3477 return Err(InferenceError::InferenceFailed(format!(
3478 "model {} does not support {:?}",
3479 schema.name, capability
3480 )));
3481 }
3482 return Ok(vec![schema]);
3483 }
3484
3485 let mut candidates: Vec<ModelSchema> = self
3486 .unified_registry
3487 .query(&ModelFilter {
3488 capabilities: vec![capability],
3489 ..Default::default()
3490 })
3491 .into_iter()
3492 .cloned()
3493 .collect();
3494
3495 if candidates.is_empty() {
3496 return Err(InferenceError::InferenceFailed(format!(
3497 "no models registered for capability {:?}",
3498 capability
3499 )));
3500 }
3501
3502 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3503 if !self.speech_policy.allow_remote_fallback
3504 && candidates.iter().any(|model| model.is_local())
3505 {
3506 candidates.retain(|model| model.is_local());
3507 }
3508
3509 Ok(candidates)
3510 }
3511
3512 fn resolve_external_hf_repo(
3517 &self,
3518 explicit: Option<&str>,
3519 capability: ModelCapability,
3520 ) -> Option<String> {
3521 let id = explicit?;
3522 let schema = self
3523 .unified_registry
3524 .get(id)
3525 .or_else(|| self.unified_registry.find_by_name(id))?;
3526 if !schema.has_capability(capability) {
3527 return Some(id.to_string());
3528 }
3529 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3530 return Some(hf_repo.clone());
3531 }
3532 Some(id.to_string())
3533 }
3534
3535 fn media_generation_candidates(
3536 &self,
3537 capability: ModelCapability,
3538 explicit: Option<&str>,
3539 ) -> Result<Vec<ModelSchema>, InferenceError> {
3540 if let Some(model) = explicit {
3541 let schema = self
3542 .unified_registry
3543 .get(model)
3544 .or_else(|| self.unified_registry.find_by_name(model))
3545 .cloned()
3546 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3547 if !schema.has_capability(capability) {
3548 return Err(InferenceError::InferenceFailed(format!(
3549 "model {} does not support {:?}",
3550 schema.name, capability
3551 )));
3552 }
3553 return Ok(vec![schema]);
3554 }
3555
3556 let mut candidates: Vec<ModelSchema> = self
3557 .unified_registry
3558 .query(&ModelFilter {
3559 capabilities: vec![capability],
3560 local_only: true,
3561 ..Default::default()
3562 })
3563 .into_iter()
3564 .cloned()
3565 .collect();
3566 candidates.sort_by_key(|schema| (!schema.available, schema.size_mb()));
3567 if candidates.is_empty() {
3568 return Err(InferenceError::InferenceFailed(format!(
3569 "no models registered for capability {:?}",
3570 capability
3571 )));
3572 }
3573 Ok(candidates)
3574 }
3575
3576 fn preferred_speech_schema(
3577 &self,
3578 capability: ModelCapability,
3579 local_only: bool,
3580 remote_only: bool,
3581 ) -> Option<ModelSchema> {
3582 let available_only = remote_only;
3583 let mut candidates: Vec<ModelSchema> = self
3584 .unified_registry
3585 .query(&ModelFilter {
3586 capabilities: vec![capability],
3587 available_only,
3588 ..Default::default()
3589 })
3590 .into_iter()
3591 .filter(|schema| {
3592 (!local_only || schema.is_local()) && (!remote_only || schema.is_remote())
3593 })
3594 .cloned()
3595 .collect();
3596 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3597 candidates.into_iter().next()
3598 }
3599
3600 fn speech_health_default_name(
3601 &self,
3602 capability: ModelCapability,
3603 local_only: bool,
3604 remote_only: bool,
3605 ) -> Option<String> {
3606 let preferred = match capability {
3607 ModelCapability::SpeechToText if local_only => {
3608 self.speech_policy.preferred_local_stt.as_ref()
3609 }
3610 ModelCapability::SpeechToText if remote_only => {
3611 self.speech_policy.preferred_remote_stt.as_ref()
3612 }
3613 ModelCapability::TextToSpeech if local_only => {
3614 self.speech_policy.preferred_local_tts.as_ref()
3615 }
3616 ModelCapability::TextToSpeech if remote_only => {
3617 self.speech_policy.preferred_remote_tts.as_ref()
3618 }
3619 _ => None,
3620 };
3621
3622 preferred
3623 .filter(|name| {
3624 self.unified_registry.list().iter().any(|schema| {
3625 schema.name == **name
3626 && schema.has_capability(capability)
3627 && (!local_only || schema.is_local())
3628 && (!remote_only || schema.is_remote())
3629 })
3630 })
3631 .cloned()
3632 .or_else(|| {
3633 self.preferred_speech_schema(capability, local_only, remote_only)
3634 .map(|schema| schema.name)
3635 })
3636 }
3637
3638 fn model_default_health(
3639 &self,
3640 capability: ModelCapability,
3641 configured_model: &str,
3642 ) -> ModelDefaultHealth {
3643 let schema = self
3644 .unified_registry
3645 .find_by_name(configured_model)
3646 .or_else(|| self.unified_registry.get(configured_model));
3647
3648 ModelDefaultHealth {
3649 capability,
3650 configured_model: configured_model.to_string(),
3651 available: schema.is_some_and(|model| model.available),
3652 is_local: schema.is_some_and(ModelSchema::is_local),
3653 provider: schema.map(|model| model.provider.clone()),
3654 }
3655 }
3656
3657 fn speech_sort_key(
3658 &self,
3659 capability: ModelCapability,
3660 model: &ModelSchema,
3661 ) -> (u8, u8, u8, u8, u64, u64) {
3662 let policy_preference = match capability {
3663 ModelCapability::SpeechToText if model.is_local() => {
3664 self.speech_policy.preferred_local_stt.as_ref()
3665 }
3666 ModelCapability::SpeechToText => self.speech_policy.preferred_remote_stt.as_ref(),
3667 ModelCapability::TextToSpeech if model.is_local() => {
3668 self.speech_policy.preferred_local_tts.as_ref()
3669 }
3670 ModelCapability::TextToSpeech => self.speech_policy.preferred_remote_tts.as_ref(),
3671 _ => None,
3672 };
3673 let local_rank = if self.speech_policy.prefer_local {
3674 if model.is_local() {
3675 0
3676 } else {
3677 1
3678 }
3679 } else if model.is_remote() {
3680 0
3681 } else {
3682 1
3683 };
3684 let availability_rank = if model.available {
3685 0
3686 } else if model.is_local() {
3687 1
3688 } else {
3689 2
3690 };
3691 let policy_rank: u8 = if policy_preference.is_some_and(|preferred| preferred == &model.name)
3692 {
3693 0
3694 } else {
3695 1
3696 };
3697 let speech_rank = match capability {
3698 ModelCapability::TextToSpeech => {
3699 if model.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
3700 0
3701 } else if model.name == "Kokoro-82M-bf16" {
3702 1
3703 } else if model.name == "Kokoro-82M-6bit" {
3704 2
3705 } else {
3706 3
3707 }
3708 }
3709 ModelCapability::SpeechToText => {
3710 if model.name == "Parakeet-TDT-0.6B-v3-MLX" {
3711 0
3712 } else {
3713 1
3714 }
3715 }
3716 _ => 0,
3717 };
3718 let latency_rank = model.performance.latency_p50_ms.unwrap_or(u64::MAX);
3719 let size_rank = model.cost.size_mb.unwrap_or(u64::MAX);
3720 (
3721 local_rank,
3722 availability_rank,
3723 policy_rank,
3724 speech_rank,
3725 latency_rank,
3726 size_rank,
3727 )
3728 }
3729
3730 async fn run_speech_smoke_path(
3731 &self,
3732 path: &str,
3733 tts: &ModelSchema,
3734 stt: &ModelSchema,
3735 text: &str,
3736 ) -> Result<SpeechSmokePathReport, InferenceError> {
3737 let work_dir = temp_work_dir(&format!("speech-smoke-{path}"))?;
3738 let audio_path = work_dir.join(format!("{path}.wav"));
3739 let synth = self
3740 .synthesize(SynthesizeRequest {
3741 text: text.to_string(),
3742 model: Some(tts.name.clone()),
3743 voice: default_speech_voice(tts),
3744 language: Some("en".to_string()),
3745 output_path: Some(audio_path.display().to_string()),
3746 ..SynthesizeRequest::default()
3747 })
3748 .await?;
3749 let transcript = self
3750 .transcribe(TranscribeRequest {
3751 audio_path: synth.audio_path.clone(),
3752 model: Some(stt.name.clone()),
3753 language: Some("en".to_string()),
3754 prompt: None,
3755 timestamps: false,
3756 })
3757 .await?;
3758
3759 Ok(SpeechSmokePathReport {
3760 path: path.to_string(),
3761 tts_model: synth.model_used.unwrap_or_else(|| tts.name.clone()),
3762 stt_model: transcript.model_used.unwrap_or_else(|| stt.name.clone()),
3763 audio_path: PathBuf::from(synth.audio_path),
3764 transcript: transcript.text,
3765 })
3766 }
3767
3768 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3769 async fn ensure_speech_runtime(&self) -> Result<SpeechRuntime, InferenceError> {
3770 let mut guard = self.speech_runtime.lock().await;
3771 if let Some(runtime) = guard.as_ref() {
3772 if runtime.is_ready() {
3773 return Ok(runtime.clone());
3774 }
3775 }
3776
3777 let runtime =
3778 SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3779 if !runtime.is_ready() {
3780 bootstrap_speech_runtime(&runtime).await?;
3781 }
3782 if !runtime.is_ready() {
3783 return Err(InferenceError::InferenceFailed(format!(
3784 "managed speech runtime is not ready at {}",
3785 runtime.root.display()
3786 )));
3787 }
3788
3789 *guard = Some(runtime.clone());
3790 Ok(runtime)
3791 }
3792
3793 async fn transcribe_local_mlx(
3794 &self,
3795 schema: &ModelSchema,
3796 req: &TranscribeRequest,
3797 ) -> Result<TranscribeResult, InferenceError> {
3798 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3800 {
3801 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3802 let parakeet = backend::mlx_parakeet::ParakeetBackend::load(&model_dir)?;
3803 let (text, words) = if req.timestamps {
3805 parakeet
3806 .transcribe_detailed(Path::new(&req.audio_path))
3807 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?
3808 } else {
3809 let t = parakeet
3810 .transcribe(Path::new(&req.audio_path))
3811 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?;
3812 (t, Vec::new())
3813 };
3814 return Ok(TranscribeResult {
3815 text,
3816 model_used: Some(schema.name.clone()),
3817 language: req.language.clone(),
3818 words,
3819 });
3820 }
3821
3822 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3824 {
3825 let runtime = self.ensure_speech_runtime().await?;
3826 let hf_repo = match &schema.source {
3827 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3828 _ => unreachable!(),
3829 };
3830 let output_dir = temp_work_dir("stt")?;
3831 let output_prefix = output_dir.join("transcript");
3832 let mut args = vec![
3833 "--model".to_string(),
3834 hf_repo,
3835 "--audio".to_string(),
3836 req.audio_path.clone(),
3837 "--output-path".to_string(),
3838 output_prefix.display().to_string(),
3839 "--format".to_string(),
3840 "json".to_string(),
3841 ];
3842 if let Some(language) = &req.language {
3843 args.push("--language".to_string());
3844 args.push(normalize_lang_code(language));
3845 }
3846 if let Some(prompt) = &req.prompt {
3847 args.push("--context".to_string());
3848 args.push(prompt.clone());
3849 }
3850 if req.timestamps {
3851 args.push("--verbose".to_string());
3852 }
3853
3854 let output = run_mlx_audio_command(&runtime, "stt.generate", &args).await?;
3855 let text = read_transcription_result(&output_prefix)?
3856 .or_else(|| extract_text_from_payload(&output.stdout))
3857 .ok_or_else(|| {
3858 InferenceError::InferenceFailed(format!(
3859 "mlx-audio transcription returned no text: {}",
3860 output.stderr
3861 ))
3862 })?;
3863
3864 Ok(TranscribeResult {
3865 text,
3866 model_used: Some(schema.name.clone()),
3867 language: req.language.clone(),
3868 words: Vec::new(),
3869 })
3870 }
3871 }
3872
3873 async fn synthesize_local_mlx(
3874 &self,
3875 schema: &ModelSchema,
3876 req: &SynthesizeRequest,
3877 ) -> Result<SynthesizeResult, InferenceError> {
3878 let requested = req.requested_advanced_controls();
3883 let repo_supports_advanced = match &schema.source {
3884 ModelSource::Mlx { hf_repo, .. } => hf_repo.to_ascii_lowercase().contains("qwen3-tts"),
3885 _ => false,
3886 };
3887 if !requested.is_empty() && !repo_supports_advanced {
3888 if req.strict_capabilities {
3889 return Err(InferenceError::InferenceFailed(format!(
3890 "model {name} does not support Qwen3-TTS advanced controls {requested:?}; \
3891 route to a Qwen3-TTS model or set strict_capabilities = false to degrade",
3892 name = schema.name,
3893 )));
3894 }
3895 tracing::warn!(
3896 model = %schema.name,
3897 fields = ?requested,
3898 "Qwen3-TTS advanced controls set on non-Qwen3-TTS backend — ignored \
3899 (set strict_capabilities=true to error instead)"
3900 );
3901 }
3902
3903 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3905 {
3906 if repo_supports_advanced && !requested.is_empty() {
3912 if req.strict_capabilities {
3913 return Err(InferenceError::InferenceFailed(format!(
3914 "native MLX TTS backend does not yet implement Qwen3-TTS advanced \
3915 controls {requested:?}; run on non-Apple-Silicon to use the Python \
3916 mlx-audio fallback, or set strict_capabilities = false"
3917 )));
3918 }
3919 tracing::warn!(
3920 model = %schema.name,
3921 fields = ?requested,
3922 "Qwen3-TTS advanced controls are not yet implemented in the native MLX TTS \
3923 backend; synthesizing without cloning/voice-design"
3924 );
3925 }
3926 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3927 let size = backend_cache::estimate_model_size(&model_dir);
3928 let cache = Arc::clone(&self.kokoro_cache);
3929 let key = schema.id.clone();
3930 let handle = cache.get_or_load(&key, size, || {
3931 backend::mlx_kokoro::KokoroBackend::load(&model_dir)
3932 })?;
3933
3934 let output_path = req.output_path.clone().unwrap_or_else(|| {
3935 let dir = std::env::temp_dir().join("car_tts");
3936 let _ = std::fs::create_dir_all(&dir);
3937 dir.join("output.wav").display().to_string()
3938 });
3939 let voice = req.voice.as_deref().unwrap_or("af_heart").to_string();
3940 let text = req.text.clone();
3941 let op = tokio::task::spawn_blocking(move || -> Result<PathBuf, InferenceError> {
3942 let mut guard = handle.lock().map_err(|_| {
3943 InferenceError::InferenceFailed("kokoro backend mutex poisoned".into())
3944 })?;
3945 guard
3946 .synthesize(&text, Some(&voice), Path::new(&output_path))
3947 .map_err(|e| InferenceError::InferenceFailed(format!("native TTS: {e}")))
3948 })
3949 .await
3950 .map_err(|e| InferenceError::InferenceFailed(format!("kokoro task join: {e}")))??;
3951
3952 let final_path =
3953 materialize_audio_output(&op, req.output_path.as_deref(), &req.format)?;
3954 return Ok(SynthesizeResult {
3955 audio_path: final_path.display().to_string(),
3956 media_type: media_type_for_format(&req.format),
3957 model_used: Some(schema.name.clone()),
3958 voice_used: req.voice.clone(),
3959 });
3960 }
3961
3962 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3964 {
3965 let runtime = self.ensure_speech_runtime().await?;
3966 let primary_hf_repo = match &schema.source {
3967 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3968 _ => unreachable!(),
3969 };
3970 let (produced, model_used) = match self
3971 .synthesize_local_mlx_repo(&runtime, &primary_hf_repo, schema.name.as_str(), req)
3972 .await
3973 {
3974 Ok(result) => result,
3975 Err(primary_err)
3976 if primary_hf_repo == "mlx-community/Kokoro-82M-6bit"
3977 && kokoro_runtime_fallback_enabled() =>
3978 {
3979 let fallback_repo = "mlx-community/Kokoro-82M-bf16";
3980 let fallback_name = "Kokoro-82M-bf16";
3981 match self
3982 .synthesize_local_mlx_repo(&runtime, fallback_repo, fallback_name, req)
3983 .await
3984 {
3985 Ok(result) => result,
3986 Err(fallback_err) => {
3987 return Err(InferenceError::InferenceFailed(format!(
3988 "{primary_err}; fallback {fallback_name} also failed: {fallback_err}"
3989 )));
3990 }
3991 }
3992 }
3993 Err(err) => return Err(err),
3994 };
3995 let final_path =
3996 materialize_audio_output(&produced, req.output_path.as_deref(), &req.format)?;
3997
3998 Ok(SynthesizeResult {
3999 audio_path: final_path.display().to_string(),
4000 media_type: media_type_for_format(&req.format),
4001 model_used: Some(model_used),
4002 voice_used: req.voice.clone(),
4003 })
4004 }
4005 }
4006
4007 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4008 async fn synthesize_local_mlx_repo(
4009 &self,
4010 runtime: &SpeechRuntime,
4011 hf_repo: &str,
4012 model_name: &str,
4013 req: &SynthesizeRequest,
4014 ) -> Result<(PathBuf, String), InferenceError> {
4015 let output_dir = temp_work_dir("tts")?;
4016 let mut args = vec![
4017 "--model".to_string(),
4018 hf_repo.to_string(),
4019 "--text".to_string(),
4020 req.text.clone(),
4021 "--output_path".to_string(),
4022 output_dir.display().to_string(),
4023 ];
4024 if let Some(voice) = &req.voice {
4025 args.push("--voice".to_string());
4026 args.push(voice.clone());
4027 }
4028 if let Some(speed) = req.speed {
4029 args.push("--speed".to_string());
4030 args.push(speed.to_string());
4031 }
4032 let repo_lower = hf_repo.to_ascii_lowercase();
4033 if repo_lower.contains("kokoro") {
4034 args.push("--lang_code".to_string());
4035 args.push(kokoro_lang_code(req.language.as_deref()).to_string());
4036 } else if let Some(language) = &req.language {
4037 args.push("--lang_code".to_string());
4038 args.push(normalize_lang_code(language));
4039 }
4040
4041 if repo_lower.contains("qwen3-tts") {
4047 if let Some(ref_audio) = &req.reference_audio_path {
4048 args.push("--ref_audio".to_string());
4049 args.push(ref_audio.clone());
4050 }
4051 if let Some(ref_text) = &req.reference_text {
4052 args.push("--ref_text".to_string());
4053 args.push(ref_text.clone());
4054 }
4055 if let Some(instruct) = &req.voice_instruction {
4056 args.push("--instruct".to_string());
4057 args.push(instruct.clone());
4058 }
4059 }
4060
4061 let output = if repo_lower.contains("kokoro") {
4062 let device = std::env::var("CAR_SPEECH_KOKORO_DEVICE")
4063 .or_else(|_| std::env::var("CAR_SPEECH_MLX_DEVICE"))
4064 .unwrap_or_else(|_| "cpu".to_string());
4065 let extra_env = vec![
4066 ("MLX_DEVICE".to_string(), device),
4068 ("PYTORCH_ENABLE_MPS_FALLBACK".to_string(), "1".to_string()),
4070 ];
4071 run_mlx_audio_command_with_env(runtime, "tts.generate", &args, &extra_env).await?
4072 } else {
4073 run_mlx_audio_command(runtime, "tts.generate", &args).await?
4074 };
4075 let produced = find_audio_file(&output_dir)?.ok_or_else(|| {
4076 let hint = if repo_lower.contains("kokoro") {
4077 ". Kokoro models may crash on GPU — try CAR_SPEECH_KOKORO_DEVICE=cpu or use the default Qwen3-TTS model"
4078 } else {
4079 ""
4080 };
4081 InferenceError::InferenceFailed(format!(
4082 "mlx-audio synthesis produced no audio file: {}{}",
4083 output.stderr, hint
4084 ))
4085 })?;
4086 Ok((produced, model_name.to_string()))
4087 }
4088
4089 async fn transcribe_elevenlabs(
4090 &self,
4091 schema: &ModelSchema,
4092 req: &TranscribeRequest,
4093 ) -> Result<TranscribeResult, InferenceError> {
4094 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4095 let file_name = Path::new(&req.audio_path)
4096 .file_name()
4097 .and_then(|f| f.to_str())
4098 .unwrap_or("audio.wav")
4099 .to_string();
4100 let audio_bytes = tokio::fs::read(&req.audio_path).await?;
4101 let file_part = Part::bytes(audio_bytes).file_name(file_name);
4102 let mut form = Form::new()
4103 .text("model_id", schema.name.clone())
4104 .part("file", file_part);
4105 if let Some(language) = &req.language {
4106 form = form.text("language_code", language.clone());
4107 }
4108
4109 let resp = self
4110 .remote_backend
4111 .client
4112 .post(format!(
4113 "{}/v1/speech-to-text",
4114 endpoint.trim_end_matches('/')
4115 ))
4116 .header("xi-api-key", api_key)
4117 .multipart(form)
4118 .send()
4119 .await
4120 .map_err(|e| {
4121 InferenceError::InferenceFailed(format!("ElevenLabs STT request failed: {e}"))
4122 })?;
4123 let status = resp.status();
4124 let body = resp.text().await.map_err(|e| {
4125 InferenceError::InferenceFailed(format!("read ElevenLabs STT body: {e}"))
4126 })?;
4127 if !status.is_success() {
4128 return Err(InferenceError::InferenceFailed(format!(
4129 "ElevenLabs STT returned {status}: {body}"
4130 )));
4131 }
4132 let payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
4133 InferenceError::InferenceFailed(format!("parse ElevenLabs STT response: {e}"))
4134 })?;
4135 let text = payload
4136 .get("text")
4137 .and_then(|v| v.as_str())
4138 .map(str::to_string)
4139 .ok_or_else(|| {
4140 InferenceError::InferenceFailed("ElevenLabs STT response missing text".into())
4141 })?;
4142
4143 Ok(TranscribeResult {
4144 text,
4145 model_used: Some(schema.name.clone()),
4146 language: payload
4147 .get("language_code")
4148 .and_then(|v| v.as_str())
4149 .map(str::to_string),
4150 words: Vec::new(),
4151 })
4152 }
4153
4154 async fn synthesize_elevenlabs(
4155 &self,
4156 schema: &ModelSchema,
4157 req: &SynthesizeRequest,
4158 ) -> Result<SynthesizeResult, InferenceError> {
4159 let requested = req.requested_advanced_controls();
4163 if !requested.is_empty() {
4164 if req.strict_capabilities {
4165 return Err(InferenceError::InferenceFailed(format!(
4166 "ElevenLabs backend does not support Qwen3-TTS advanced controls \
4167 {requested:?}; route to a Qwen3-TTS model or set strict_capabilities = false"
4168 )));
4169 }
4170 tracing::warn!(
4171 model = %schema.name,
4172 fields = ?requested,
4173 "Qwen3-TTS advanced controls ignored by ElevenLabs backend"
4174 );
4175 }
4176 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4177 let voice_id = req
4178 .voice
4179 .clone()
4180 .unwrap_or_else(|| "JBFqnCBsd6RMkjVDRZzb".to_string());
4181 let output_format = elevenlabs_output_format(&req.format);
4182 let url = format!(
4183 "{}/v1/text-to-speech/{}?output_format={}",
4184 endpoint.trim_end_matches('/'),
4185 voice_id,
4186 output_format
4187 );
4188
4189 let mut body = serde_json::json!({
4190 "text": req.text,
4191 "model_id": schema.name,
4192 });
4193 if let Some(language) = &req.language {
4194 body["language_code"] = serde_json::Value::String(language.clone());
4195 }
4196
4197 let resp = self
4198 .remote_backend
4199 .client
4200 .post(url)
4201 .header("xi-api-key", api_key)
4202 .header("Content-Type", "application/json")
4203 .json(&body)
4204 .send()
4205 .await
4206 .map_err(|e| {
4207 InferenceError::InferenceFailed(format!("ElevenLabs TTS request failed: {e}"))
4208 })?;
4209 let status = resp.status();
4210 let audio = resp.bytes().await.map_err(|e| {
4211 InferenceError::InferenceFailed(format!("read ElevenLabs TTS body: {e}"))
4212 })?;
4213 if !status.is_success() {
4214 let err_body = String::from_utf8_lossy(&audio);
4215 return Err(InferenceError::InferenceFailed(format!(
4216 "ElevenLabs TTS returned {status}: {err_body}"
4217 )));
4218 }
4219
4220 let final_path = requested_or_temp_output(req.output_path.as_deref(), &req.format)?;
4221 ensure_parent_dir(&final_path)?;
4222 tokio::fs::write(&final_path, &audio).await?;
4223
4224 Ok(SynthesizeResult {
4225 audio_path: final_path.display().to_string(),
4226 media_type: media_type_for_format(&req.format),
4227 model_used: Some(schema.name.clone()),
4228 voice_used: Some(voice_id),
4229 })
4230 }
4231}
4232
/// Running per-provider tallies; every field starts at its `Default` value.
// NOTE(review): field meanings below are inferred from their names — the code
// that fills this struct is outside this section; confirm against it.
#[derive(Default)]
struct ProviderAccumulator {
    /// Presumably whether the provider is considered configured.
    configured: bool,
    /// Presumably the number of local models seen for the provider.
    local_models: usize,
    /// Presumably the number of remote models seen for the provider.
    remote_models: usize,
    /// Presumably the number of models currently marked available.
    available_models: usize,
    /// Set of capabilities collected for the provider (deduplicated by the set).
    capabilities: std::collections::HashSet<ModelCapability>,
}
4241
/// Captured output of a successfully finished speech-runtime command.
///
/// Only built when the native MLX backend is not compiled in.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
struct CommandOutput {
    /// Standard output, lossily decoded as UTF-8.
    stdout: String,
    /// Standard error, lossily decoded as UTF-8.
    stderr: String,
}
4250
/// Locations of the managed speech runtime and its executables.
///
/// Only built when the native MLX backend is not compiled in.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
#[derive(Debug, Clone)]
struct SpeechRuntime {
    /// Root directory of the runtime (the virtual environment created by bootstrap).
    root: PathBuf,
    /// Python interpreter at `<root>/bin/python`.
    python: PathBuf,
    /// Speech-to-text program at `<root>/bin/mlx_audio.stt.generate`.
    stt_program: PathBuf,
    /// Text-to-speech program at `<root>/bin/mlx_audio.tts.generate`.
    tts_program: PathBuf,
}
4259
4260#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4261impl SpeechRuntime {
4262 fn new(root: PathBuf) -> Self {
4263 let bin_dir = root.join("bin");
4264 Self {
4265 root,
4266 python: bin_dir.join("python"),
4267 stt_program: bin_dir.join("mlx_audio.stt.generate"),
4268 tts_program: bin_dir.join("mlx_audio.tts.generate"),
4269 }
4270 }
4271
4272 fn is_ready(&self) -> bool {
4273 self.python.exists() && self.stt_program.exists() && self.tts_program.exists()
4274 }
4275
4276 fn command_for(&self, subcommand: &str) -> Result<&Path, InferenceError> {
4277 match subcommand {
4278 "stt.generate" => Ok(&self.stt_program),
4279 "tts.generate" => Ok(&self.tts_program),
4280 _ => Err(InferenceError::InferenceFailed(format!(
4281 "unknown speech subcommand: {subcommand}"
4282 ))),
4283 }
4284 }
4285}
4286
/// Runs a speech-runtime subcommand with `args` and no extra environment
/// variables; see [`run_mlx_audio_command_with_env`] for the behaviour.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn run_mlx_audio_command(
    runtime: &SpeechRuntime,
    subcommand: &str,
    args: &[String],
) -> Result<CommandOutput, InferenceError> {
    run_mlx_audio_command_with_env(runtime, subcommand, args, &[]).await
}
4295
4296#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4297async fn run_mlx_audio_command_with_env(
4298 runtime: &SpeechRuntime,
4299 subcommand: &str,
4300 args: &[String],
4301 envs: &[(String, String)],
4302) -> Result<CommandOutput, InferenceError> {
4303 let program = runtime.command_for(subcommand)?;
4304 let mut command = Command::new(program);
4305 command.args(args);
4306 for (key, value) in envs {
4307 command.env(key, value);
4308 }
4309 let output = command
4310 .output()
4311 .await
4312 .map_err(|err| InferenceError::InferenceFailed(format!("{}: {err}", program.display())))?;
4313
4314 if output.status.success() {
4315 Ok(CommandOutput {
4316 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
4317 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
4318 })
4319 } else {
4320 Err(InferenceError::InferenceFailed(format!(
4321 "{} exited with {}: {}",
4322 program.display(),
4323 output.status,
4324 String::from_utf8_lossy(&output.stderr)
4325 )))
4326 }
4327}
4328
4329#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4330async fn bootstrap_speech_runtime(runtime: &SpeechRuntime) -> Result<(), InferenceError> {
4331 std::fs::create_dir_all(&runtime.root)?;
4332 let python = select_speech_python()?;
4333
4334 run_command(
4335 "uv",
4336 &[
4337 "venv".to_string(),
4338 "--python".to_string(),
4339 python,
4340 runtime.root.display().to_string(),
4341 ],
4342 )
4343 .await?;
4344
4345 run_command(
4346 "uv",
4347 &[
4348 "pip".to_string(),
4349 "install".to_string(),
4350 "--python".to_string(),
4351 runtime.python.display().to_string(),
4352 speech_runtime_mlx_audio_spec(),
4353 "misaki[en]".to_string(),
4354 speech_runtime_spacy_model_spec(),
4355 ],
4356 )
4357 .await?;
4358
4359 Ok(())
4360}
4361
4362#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4363async fn run_command(program: &str, args: &[String]) -> Result<(), InferenceError> {
4364 let output = Command::new(program)
4365 .args(args)
4366 .output()
4367 .await
4368 .map_err(|err| InferenceError::InferenceFailed(format!("{program}: {err}")))?;
4369
4370 if output.status.success() {
4371 Ok(())
4372 } else {
4373 Err(InferenceError::InferenceFailed(format!(
4374 "{} exited with {}: {}",
4375 program,
4376 output.status,
4377 String::from_utf8_lossy(&output.stderr)
4378 )))
4379 }
4380}
4381
4382#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4383fn select_speech_python() -> Result<String, InferenceError> {
4384 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4385 if !path.trim().is_empty() {
4386 return Ok(path);
4387 }
4388 }
4389
4390 for candidate in ["python3.13", "python3.12", "python3.11"] {
4391 if command_in_path(candidate) {
4392 return Ok(candidate.to_string());
4393 }
4394 }
4395
4396 Err(InferenceError::InferenceFailed(
4397 "no supported Python found for managed speech runtime (tried python3.13, python3.12, python3.11)".into(),
4398 ))
4399}
4400
4401#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4402fn detect_speech_python() -> Option<String> {
4403 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4404 if !path.trim().is_empty() {
4405 return Some(path);
4406 }
4407 }
4408
4409 ["python3.13", "python3.12", "python3.11"]
4410 .into_iter()
4411 .find(|candidate| command_in_path(candidate))
4412 .map(str::to_string)
4413}
4414
/// Returns the root directory of the managed speech runtime.
///
/// A `CAR_SPEECH_RUNTIME_DIR` value that is not blank is used verbatim.
/// Otherwise the root is `$HOME/.car/speech-runtime`, with `.` standing in for
/// an unset `HOME`. The `_models_dir` argument is currently unused.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_root_from_models_dir(_models_dir: &Path) -> PathBuf {
    match std::env::var("CAR_SPEECH_RUNTIME_DIR") {
        Ok(custom) if !custom.trim().is_empty() => return PathBuf::from(custom),
        _ => {}
    }

    let home = match std::env::var("HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => PathBuf::from("."),
    };
    home.join(".car").join("speech-runtime")
}
4429
/// Reports whether a regular file called `name` exists in any `PATH` entry.
///
/// Returns `false` when `PATH` is unset. Executable permission is not checked.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn command_in_path(name: &str) -> bool {
    let search_path = match std::env::var_os("PATH") {
        Some(value) => value,
        None => return false,
    };
    // `is_file` is already false for paths that do not exist.
    std::env::split_paths(&search_path)
        .map(|dir| dir.join(name))
        .any(|candidate| candidate.is_file())
}
4441
4442fn speech_model_cached(schema: &ModelSchema) -> bool {
4443 match &schema.source {
4444 ModelSource::Mlx { hf_repo, .. } => huggingface_repo_has_snapshot(hf_repo),
4445 ModelSource::Proprietary { auth, .. } => match auth {
4446 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4447 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4448 ProprietaryAuth::OAuth2Pkce { .. } => false,
4449 },
4450 _ => false,
4451 }
4452}
4453
4454fn remove_huggingface_repo_cache(repo_id: &str) -> Result<(), InferenceError> {
4455 let repo_dir = std::env::var("HF_HOME")
4456 .map(PathBuf::from)
4457 .unwrap_or_else(|_| {
4458 std::env::var("HOME")
4459 .map(PathBuf::from)
4460 .unwrap_or_else(|_| PathBuf::from("."))
4461 .join(".cache")
4462 .join("huggingface")
4463 })
4464 .join("hub")
4465 .join(format!("models--{}", repo_id.replace('/', "--")));
4466
4467 if repo_dir.exists() {
4468 std::fs::remove_dir_all(repo_dir)?;
4469 }
4470 Ok(())
4471}
4472
4473fn model_source_configured(schema: &ModelSchema) -> bool {
4474 match &schema.source {
4475 ModelSource::RemoteApi {
4476 api_key_env,
4477 api_key_envs,
4478 ..
4479 } => {
4480 std::env::var(api_key_env).is_ok()
4481 || api_key_envs
4482 .iter()
4483 .any(|env_var| std::env::var(env_var).is_ok())
4484 }
4485 ModelSource::Proprietary { auth, .. } => match auth {
4486 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4487 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4488 ProprietaryAuth::OAuth2Pkce { .. } => false,
4489 },
4490 ModelSource::VllmMlx { .. } => {
4491 std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available
4492 }
4493 ModelSource::Ollama { .. } => schema.available,
4494 ModelSource::Mlx { .. } | ModelSource::Local { .. } => true,
4495 ModelSource::AppleFoundationModels { .. } => schema.available,
4496 ModelSource::Delegated { .. } => true,
4501 }
4502}
4503
/// Returns the thirteen capabilities listed here in a fixed order.
///
/// `sort_capabilities` uses the position in this array as the sort key, so
/// the order of the entries is significant.
fn all_model_capabilities() -> [ModelCapability; 13] {
    [
        ModelCapability::Generate,
        ModelCapability::Embed,
        ModelCapability::Classify,
        ModelCapability::Code,
        ModelCapability::Reasoning,
        ModelCapability::Summarize,
        ModelCapability::ToolUse,
        ModelCapability::MultiToolCall,
        ModelCapability::Vision,
        ModelCapability::SpeechToText,
        ModelCapability::TextToSpeech,
        ModelCapability::ImageGeneration,
        ModelCapability::VideoGeneration,
    ]
}
4521
4522fn sort_capabilities(mut capabilities: Vec<ModelCapability>) -> Vec<ModelCapability> {
4523 capabilities.sort_by_key(|capability| {
4524 all_model_capabilities()
4525 .iter()
4526 .position(|candidate| candidate == capability)
4527 .unwrap_or(usize::MAX)
4528 });
4529 capabilities
4530}
4531
/// Renders a model's source as a short `kind:detail` label.
///
/// Each variant gets its own prefix (`mlx:`, `proprietary:`, `remote:`,
/// `local:`, `vllm-mlx:`, `ollama:`, `apple-foundation:`, `delegated:`)
/// followed by the identifying fields of that source.
fn speech_model_source_label(schema: &ModelSchema) -> String {
    match &schema.source {
        ModelSource::Mlx { hf_repo, .. } => format!("mlx:{hf_repo}"),
        ModelSource::Proprietary {
            provider, endpoint, ..
        } => format!("proprietary:{provider}:{endpoint}"),
        ModelSource::RemoteApi { endpoint, .. } => format!("remote:{endpoint}"),
        ModelSource::Local { hf_repo, .. } => format!("local:{hf_repo}"),
        ModelSource::VllmMlx {
            endpoint,
            model_name,
        } => format!("vllm-mlx:{endpoint}:{model_name}"),
        ModelSource::Ollama { model_tag, host } => format!("ollama:{host}:{model_tag}"),
        ModelSource::AppleFoundationModels { use_case } => {
            // A missing use case is shown as "default".
            format!(
                "apple-foundation:{}",
                use_case.as_deref().unwrap_or("default")
            )
        }
        ModelSource::Delegated { hint } => {
            // A missing hint is shown as "(none)".
            format!("delegated:{}", hint.as_deref().unwrap_or("(none)"))
        }
    }
}
4556
/// Builds the yes/no relevance-judgement prompt for one query/document pair.
///
/// The prompt has three `<|im_start|>`-delimited turns: a fixed system
/// instruction, a user turn carrying `instruction`, `query` and `document`,
/// and an assistant turn opened with an empty `<think>` block.
fn rerank_prompt(instruction: &str, query: &str, document: &str) -> String {
    const SYSTEM: &str = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".";
    const ASSISTANT_OPENING: &str = "<|im_start|>assistant\n<think>\n\n</think>\n\n";

    let system_turn = format!("<|im_start|>system\n{SYSTEM}<|im_end|>\n");
    let user_turn = format!(
        "<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
    );

    [system_turn.as_str(), user_turn.as_str(), ASSISTANT_OPENING].concat()
}
4572
/// Produces a recovery hint when an error text indicates that no inference
/// backend is usable.
///
/// Returns `None` unless `underlying` contains one of the known markers
/// (matched case-sensitively). The hint explains how to pull a local model or
/// log in to hosted models and quotes the underlying error.
fn no_backend_recovery_hint(underlying: &str) -> Option<String> {
    const NO_BACKEND_MARKERS: [&str; 4] = [
        "no credential",
        "model not found",
        "no models available",
        "no inference runner",
    ];

    let matches_marker = NO_BACKEND_MARKERS
        .iter()
        .any(|marker| underlying.contains(marker));
    if matches_marker {
        Some(format!(
            "no inference backend is available. To install a local CPU-only \
             model (no account required, works on Windows/Linux/macOS), run:\n \
             car models pull qwen/qwen3-1.7b:q8_0\n\
             To use Parslee's hosted models instead, run:\n \
             car auth login parslee\n\
             (underlying error: {underlying})"
        ))
    } else {
        None
    }
}
4608
4609fn score_from_rerank_output(text: &str, model_name: &str) -> f32 {
4610 let normalized: String = text
4615 .to_ascii_lowercase()
4616 .chars()
4617 .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
4618 .collect();
4619 for tok in normalized.split_ascii_whitespace().take(5) {
4620 match tok {
4621 "yes" => return 1.0,
4622 "no" => return 0.0,
4623 _ => continue,
4624 }
4625 }
4626 tracing::warn!(
4627 model = %model_name,
4628 output = %text,
4629 "rerank: first tokens contain neither `yes` nor `no`; returning neutral 0.5"
4630 );
4631 0.5
4632}
4633
4634fn default_speech_voice(schema: &ModelSchema) -> Option<String> {
4635 if schema.provider == "elevenlabs" {
4636 Some("JBFqnCBsd6RMkjVDRZzb".to_string())
4637 } else if schema.name == "Kokoro-82M-6bit" || schema.name == "Kokoro-82M-bf16" {
4638 Some("af_heart".to_string())
4639 } else if schema.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
4640 Some("Chelsie".to_string())
4641 } else {
4642 None
4643 }
4644}
4645
4646fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
4647 find_latest_huggingface_snapshot(repo_id).is_some()
4648}
4649
/// Returns the Hugging Face cache directory for `repo_id`
/// (`<cache>/hub/models--<org>--<name>`).
///
/// The cache root is `$HF_HOME` when that variable is set to valid Unicode;
/// otherwise it is `<home>/.cache/huggingface`. The home directory is taken
/// from `HOME`, then `USERPROFILE` (so Windows, where `HOME` is normally
/// unset, resolves to the user's profile instead of the working directory),
/// and finally falls back to `.`. Empty values are treated as unset.
///
/// The directory is not created or checked for existence.
fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
    let cache_root = match std::env::var("HF_HOME") {
        Ok(hf_home) => PathBuf::from(hf_home),
        Err(_) => {
            let home = ["HOME", "USERPROFILE"]
                .iter()
                .filter_map(|name| std::env::var_os(name))
                .find(|value| !value.is_empty())
                .map(PathBuf::from)
                .unwrap_or_else(|| PathBuf::from("."));
            home.join(".cache").join("huggingface")
        }
    };

    // Hub layout replaces every `/` in the repo id with `--`.
    let repo_folder = format!("models--{}", repo_id.replace('/', "--"));
    cache_root.join("hub").join(repo_folder)
}
4663
4664fn find_latest_huggingface_snapshot(repo_id: &str) -> Option<PathBuf> {
4665 let snapshots = huggingface_repo_dir(repo_id).join("snapshots");
4666 std::fs::read_dir(snapshots)
4667 .ok()?
4668 .filter_map(Result::ok)
4669 .map(|entry| entry.path())
4670 .find(|path| path.is_dir() && snapshot_looks_ready(path))
4671}
4672
4673fn snapshot_looks_ready(path: &Path) -> bool {
4674 if path.join("config.json").exists() || path.join("model_index.json").exists() {
4675 return true;
4676 }
4677 snapshot_contains_ext(path, "safetensors")
4678}
4679
/// Recursively checks whether any file under `root` has the extension `ext`
/// (compared ASCII case-insensitively).
///
/// Unreadable directories and unreadable entries are treated as containing
/// no match.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let listing = match std::fs::read_dir(root) {
        Ok(listing) => listing,
        Err(_) => return false,
    };

    for child in listing.filter_map(Result::ok) {
        let child_path = child.path();
        let found = if child_path.is_dir() {
            snapshot_contains_ext(&child_path, ext)
        } else {
            matches!(
                child_path.extension().and_then(|value| value.to_str()),
                Some(actual) if actual.eq_ignore_ascii_case(ext)
            )
        };
        if found {
            return true;
        }
    }
    false
}
4696
/// Counts the regular files beneath `root`, descending into subdirectories.
///
/// Unreadable directories contribute zero; entries that are neither a
/// directory nor a regular file are ignored.
fn count_files_recursive(root: &Path) -> usize {
    let mut total = 0usize;
    if let Ok(listing) = std::fs::read_dir(root) {
        for child in listing.filter_map(Result::ok) {
            let child_path = child.path();
            if child_path.is_dir() {
                total += count_files_recursive(&child_path);
            } else if child_path.is_file() {
                total += 1;
            }
        }
    }
    total
}
4715
4716async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
4717 let api = hf_hub::api::tokio::ApiBuilder::from_env()
4718 .with_progress(false)
4719 .build()
4720 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
4721 let repo = api.model(repo_id.to_string());
4722 let info = repo
4723 .info()
4724 .await
4725 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
4726
4727 let snapshot_path = huggingface_repo_dir(repo_id)
4728 .join("snapshots")
4729 .join(&info.sha);
4730 let mut downloaded = 0usize;
4731 for sibling in &info.siblings {
4732 let local_path = snapshot_path.join(&sibling.rfilename);
4733 if local_path.exists() {
4734 downloaded += 1;
4735 continue;
4736 }
4737 repo.download(&sibling.rfilename).await.map_err(|e| {
4738 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
4739 })?;
4740 downloaded += 1;
4741 }
4742
4743 Ok((snapshot_path, downloaded))
4744}
4745
4746fn temp_work_dir(prefix: &str) -> Result<PathBuf, InferenceError> {
4747 let unique = SystemTime::now()
4748 .duration_since(UNIX_EPOCH)
4749 .map_err(|e| InferenceError::InferenceFailed(format!("clock error: {e}")))?
4750 .as_nanos();
4751 let dir = std::env::temp_dir().join(format!("car-inference-{prefix}-{unique}"));
4752 std::fs::create_dir_all(&dir)?;
4753 Ok(dir)
4754}
4755
4756fn ensure_parent_dir(path: &Path) -> Result<(), InferenceError> {
4757 if let Some(parent) = path.parent() {
4758 std::fs::create_dir_all(parent)?;
4759 }
4760 Ok(())
4761}
4762
4763fn requested_or_temp_output(
4764 output_path: Option<&str>,
4765 format: &str,
4766) -> Result<PathBuf, InferenceError> {
4767 if let Some(path) = output_path {
4768 return Ok(PathBuf::from(path));
4769 }
4770 let dir = temp_work_dir("audio-out")?;
4771 Ok(dir.join(format!("speech.{format}")))
4772}
4773
4774fn requested_or_temp_media_output(
4775 output_path: Option<&str>,
4776 format: &str,
4777 stem: &str,
4778) -> Result<PathBuf, InferenceError> {
4779 if let Some(path) = output_path {
4780 return Ok(PathBuf::from(path));
4781 }
4782 let dir = temp_work_dir(&format!("{stem}-out"))?;
4783 Ok(dir.join(format!("{stem}.{format}")))
4784}
4785
4786fn materialize_audio_output(
4787 produced: &Path,
4788 requested: Option<&str>,
4789 format: &str,
4790) -> Result<PathBuf, InferenceError> {
4791 if let Some(path) = requested {
4792 let dest = PathBuf::from(path);
4793 ensure_parent_dir(&dest)?;
4794 std::fs::copy(produced, &dest)?;
4795 Ok(dest)
4796 } else {
4797 let dest = requested_or_temp_output(None, format)?;
4798 ensure_parent_dir(&dest)?;
4799 std::fs::copy(produced, &dest)?;
4800 Ok(dest)
4801 }
4802}
4803
4804fn materialize_binary_output(
4805 produced: &Path,
4806 requested: Option<&str>,
4807 format: &str,
4808 stem: &str,
4809) -> Result<PathBuf, InferenceError> {
4810 let dest = requested_or_temp_media_output(requested, format, stem)?;
4811 ensure_parent_dir(&dest)?;
4812 std::fs::copy(produced, &dest)?;
4813 Ok(dest)
4814}
4815
4816fn find_generated_file(
4817 root: &Path,
4818 extensions: &[&str],
4819) -> Result<Option<PathBuf>, InferenceError> {
4820 let entries = std::fs::read_dir(root)?;
4821 let mut candidates: Vec<PathBuf> = entries
4822 .filter_map(Result::ok)
4823 .map(|entry| entry.path())
4824 .filter(|path| {
4825 path.is_file()
4826 && path
4827 .extension()
4828 .and_then(|ext| ext.to_str())
4829 .map(|ext| {
4830 extensions
4831 .iter()
4832 .any(|candidate| candidate.eq_ignore_ascii_case(ext))
4833 })
4834 .unwrap_or(false)
4835 })
4836 .collect();
4837 candidates.sort();
4838 Ok(candidates.pop())
4839}
4840
/// Maps an image format name (case-insensitive) to its MIME type.
///
/// Anything other than `jpg`, `jpeg` or `webp` is reported as `image/png`.
fn media_type_for_image_format(format: &str) -> String {
    let is = |name: &str| format.eq_ignore_ascii_case(name);
    let mime = if is("jpg") || is("jpeg") {
        "image/jpeg"
    } else if is("webp") {
        "image/webp"
    } else {
        "image/png"
    };
    String::from(mime)
}
4848
/// Maps a video format name (case-insensitive) to its MIME type.
///
/// `mov` and `gif` have dedicated types; everything else is reported as
/// `video/mp4`.
fn media_type_for_video_format(format: &str) -> String {
    let mime = if format.eq_ignore_ascii_case("mov") {
        "video/quicktime"
    } else if format.eq_ignore_ascii_case("gif") {
        "image/gif"
    } else {
        "video/mp4"
    };
    String::from(mime)
}
4856
4857fn read_transcription_result(output_prefix: &Path) -> Result<Option<String>, InferenceError> {
4858 let candidates = [
4859 output_prefix.with_extension("json"),
4860 output_prefix.to_path_buf(),
4861 ];
4862
4863 for path in candidates {
4864 if path.exists() {
4865 let contents = std::fs::read_to_string(path)?;
4866 if let Some(text) = extract_text_from_payload(&contents) {
4867 return Ok(Some(text));
4868 }
4869 }
4870 }
4871
4872 Ok(None)
4873}
4874
4875fn extract_text_from_payload(payload: &str) -> Option<String> {
4876 let value: serde_json::Value = serde_json::from_str(payload).ok()?;
4877 if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
4878 return Some(text.to_string());
4879 }
4880 if let Some(transcripts) = value.get("transcripts").and_then(|v| v.as_array()) {
4881 let joined = transcripts
4882 .iter()
4883 .filter_map(|item| item.get("text").and_then(|v| v.as_str()))
4884 .collect::<Vec<_>>()
4885 .join("\n");
4886 if !joined.is_empty() {
4887 return Some(joined);
4888 }
4889 }
4890 if let Some(items) = value.as_array() {
4891 let joined = items
4892 .iter()
4893 .filter_map(|item| {
4894 item.get("text")
4895 .or_else(|| item.get("Content"))
4896 .and_then(|v| v.as_str())
4897 })
4898 .collect::<Vec<_>>()
4899 .join(" ");
4900 if !joined.is_empty() {
4901 return Some(joined);
4902 }
4903 }
4904 None
4905}
4906
4907fn find_audio_file(output_dir: &Path) -> Result<Option<PathBuf>, InferenceError> {
4908 let mut audio_files = Vec::new();
4909 collect_audio_files(output_dir, &mut audio_files)?;
4910 audio_files.sort();
4911 Ok(audio_files.into_iter().next())
4912}
4913
4914fn collect_audio_files(dir: &Path, audio_files: &mut Vec<PathBuf>) -> Result<(), InferenceError> {
4915 for entry in std::fs::read_dir(dir)? {
4916 let path = entry?.path();
4917 if path.is_dir() {
4918 collect_audio_files(&path, audio_files)?;
4919 } else if matches!(
4920 path.extension().and_then(|ext| ext.to_str()),
4921 Some("wav" | "mp3" | "flac" | "pcm" | "m4a")
4922 ) {
4923 audio_files.push(path);
4924 }
4925 }
4926 Ok(())
4927}
4928
/// Maps an audio format name (case-insensitive) to its MIME type.
///
/// Unrecognised formats are reported as `audio/wav`.
fn media_type_for_format(format: &str) -> String {
    const KNOWN: [(&str, &str); 4] = [
        ("mp3", "audio/mpeg"),
        ("flac", "audio/flac"),
        ("pcm", "audio/L16"),
        ("m4a", "audio/mp4"),
    ];
    KNOWN
        .iter()
        .find(|(name, _)| format.eq_ignore_ascii_case(name))
        .map(|(_, mime)| *mime)
        .unwrap_or("audio/wav")
        .to_string()
}
4938
/// Maps a language name or tag to the single-letter Kokoro language code.
///
/// Matching is ASCII case-insensitive. A missing language is treated as
/// `en`, and anything not listed in the table maps to `a`.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_lang_code(language: Option<&str>) -> &'static str {
    const CODES: &[(&[&str], &str)] = &[
        (&["en-gb", "british", "british english"], "b"),
        (&["ja", "japanese"], "j"),
        (&["zh", "zh-cn", "mandarin", "chinese"], "z"),
        (&["es", "spanish"], "e"),
        (&["fr", "french"], "f"),
    ];

    let requested = language.unwrap_or("en").to_ascii_lowercase();
    for (aliases, code) in CODES {
        if aliases.iter().any(|alias| *alias == requested) {
            return code;
        }
    }
    "a"
}
4950
/// Normalises a language name or tag (case-insensitive) to one of the
/// two-letter codes `en`, `es`, `fr`, `ja` or `zh`.
///
/// Any input that is not a recognised name or code falls back to `en`.
fn normalize_lang_code(language: &str) -> String {
    let lowered = language.to_ascii_lowercase();
    let code = match lowered.as_str() {
        "spanish" | "es" => "es",
        "french" | "fr" => "fr",
        "japanese" | "ja" => "ja",
        "chinese" | "mandarin" | "zh" => "zh",
        // English spellings, `en` itself and every unknown value.
        _ => "en",
    };
    code.to_string()
}
4964
4965fn elevenlabs_auth(schema: &ModelSchema) -> Result<(String, String), InferenceError> {
4966 match &schema.source {
4967 ModelSource::Proprietary {
4968 endpoint,
4969 auth: schema::ProprietaryAuth::ApiKeyEnv { env_var },
4970 ..
4971 } => {
4972 let key = car_secrets::resolve_env_or_keychain(env_var).ok_or_else(|| {
4973 InferenceError::InferenceFailed(format!(
4974 "missing API key {env_var}; set the environment variable or \
4975 store it with `car secrets put {env_var}`"
4976 ))
4977 })?;
4978 Ok((endpoint.clone(), key))
4979 }
4980 _ => Err(InferenceError::InferenceFailed(format!(
4981 "model {} is not an ElevenLabs proprietary model",
4982 schema.id
4983 ))),
4984 }
4985}
4986
/// Translates a generic audio format name (case-insensitive) into the
/// output-format identifier sent to ElevenLabs.
///
/// Formats other than `mp3` and `pcm` map to `wav_44100`.
fn elevenlabs_output_format(format: &str) -> &'static str {
    if format.eq_ignore_ascii_case("mp3") {
        "mp3_44100_128"
    } else if format.eq_ignore_ascii_case("pcm") {
        "pcm_16000"
    } else {
        "wav_44100"
    }
}
4994
/// Lists the candidate locations of `benchmark_priors.json`, in lookup
/// order and without duplicates:
///
/// 1. inside `models_dir`;
/// 2. inside the parent of `models_dir`, if it has one;
/// 3. the path in `CAR_BENCHMARK_PRIORS_PATH`, if set.
///
/// The paths are not checked for existence.
fn benchmark_priors_paths(models_dir: &Path) -> Vec<PathBuf> {
    const FILE_NAME: &str = "benchmark_priors.json";

    let candidates = vec![
        Some(models_dir.join(FILE_NAME)),
        models_dir.parent().map(|parent| parent.join(FILE_NAME)),
        std::env::var_os("CAR_BENCHMARK_PRIORS_PATH").map(PathBuf::from),
    ];

    let mut paths: Vec<PathBuf> = Vec::new();
    for candidate in candidates.into_iter().flatten() {
        if !paths.contains(&candidate) {
            paths.push(candidate);
        }
    }
    paths
}
5019
5020fn load_benchmark_prior_health(
5021 models_dir: &Path,
5022 schemas: &[ModelSchema],
5023) -> Vec<ModelBenchmarkPriorHealth> {
5024 let mut priors = std::collections::BTreeMap::new();
5025 for path in benchmark_priors_paths(models_dir) {
5026 let Ok(loaded) = routing_ext::load_benchmark_priors(&path) else {
5027 continue;
5028 };
5029 for (model_id, prior) in loaded {
5030 let model_name = schemas
5031 .iter()
5032 .find(|schema| schema.id == model_id)
5033 .map(|schema| schema.name.clone());
5034 priors.insert(
5035 model_id.clone(),
5036 ModelBenchmarkPriorHealth {
5037 model_id,
5038 model_name,
5039 overall_score: prior.overall_score,
5040 overall_latency_ms: prior.overall_latency_ms,
5041 task_scores: prior.task_scores,
5042 task_latency_ms: prior.task_latency_ms,
5043 source_path: path.clone(),
5044 },
5045 );
5046 }
5047 }
5048
5049 priors.into_values().collect()
5050}
5051
/// Reports whether the Kokoro runtime fallback is enabled.
///
/// Controlled by `CAR_SPEECH_KOKORO_FALLBACK`: the values `0`, `false` and
/// `off` (trimmed, case-insensitive) disable it. Any other value, or an
/// unset/unreadable variable, leaves it enabled.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_runtime_fallback_enabled() -> bool {
    match std::env::var("CAR_SPEECH_KOKORO_FALLBACK") {
        Ok(raw) => {
            let setting = raw.trim().to_ascii_lowercase();
            !["0", "false", "off"].contains(&setting.as_str())
        }
        Err(_) => true,
    }
}
5064
/// Returns the package spec used to install `mlx-audio`.
///
/// `CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC` overrides the default when it is set
/// to a non-blank value; the override is returned untrimmed.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_mlx_audio_spec() -> String {
    match std::env::var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => "mlx-audio==0.4.2".to_string(),
    }
}
5072
/// Returns the package spec used to install the spaCy English model.
///
/// `CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC` overrides the default when it is
/// set to a non-blank value; the override is returned untrimmed. The default
/// pins a specific wheel URL.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_spacy_model_spec() -> String {
    const DEFAULT_SPEC: &str = "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl";

    match std::env::var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => DEFAULT_SPEC.to_string(),
    }
}
5082
5083#[cfg(test)]
5084mod tests {
5085 use super::*;
5086 use tempfile::TempDir;
5087
5088 #[test]
5089 fn no_backend_hint_fires_on_missing_backend_phrases() {
5090 for phrase in [
5092 "no credential for proprietary provider 'parslee'",
5093 "model not found",
5094 "no models available",
5095 "model declares ModelSource::Delegated but no inference runner is registered",
5096 ] {
5097 let hint = no_backend_recovery_hint(phrase)
5098 .unwrap_or_else(|| panic!("expected a hint for {phrase:?}"));
5099 assert!(hint.contains("car models pull"));
5100 assert!(hint.contains("car auth login parslee"));
5101 assert!(hint.contains(phrase));
5103 }
5104 }
5105
5106 #[test]
5107 fn no_backend_hint_passes_through_transient_errors() {
5108 for phrase in [
5111 "API returned 401 Unauthorized",
5112 "API returned 429 Too Many Requests",
5113 "API returned 500 Internal Server Error",
5114 "connection refused",
5115 "request timed out",
5116 "parse response: unexpected end of input",
5117 ] {
5118 assert!(
5119 no_backend_recovery_hint(phrase).is_none(),
5120 "transient error wrongly classified as no-backend: {phrase:?}"
5121 );
5122 }
5123 }
5124
    /// Lock taken by tests in this module before they set or remove process
    /// environment variables, so that those tests do not run concurrently
    /// with each other.
    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());
5128
5129 fn test_config(models_dir: PathBuf) -> InferenceConfig {
5130 InferenceConfig {
5131 models_dir,
5132 device: None,
5133 generation_model: "Qwen3-0.6B".into(),
5134 preferred_generation_model: None,
5135 embedding_model: "Qwen3-Embedding-0.6B".into(),
5136 preferred_embedding_model: None,
5137 classification_model: "Qwen3-0.6B".into(),
5138 preferred_classification_model: None,
5139 }
5140 }
5141
5142 #[tokio::test]
5143 async fn tokenize_rejects_known_remote_model_with_unsupported_mode() {
5144 let tmp = TempDir::new().unwrap();
5149 let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
5150 let remote_id = engine
5151 .list_schemas()
5152 .into_iter()
5153 .find(|s| !s.is_local())
5154 .map(|s| s.id)
5155 .expect("built-in catalog should include at least one remote model schema");
5156
5157 let err = engine
5158 .tokenize(&remote_id, "hello")
5159 .await
5160 .expect_err("remote tokenize must error");
5161 match err {
5162 InferenceError::UnsupportedMode { mode, backend, .. } => {
5163 assert_eq!(mode, "tokenize/detokenize");
5164 assert_eq!(backend, "remote");
5165 }
5166 other => panic!("expected UnsupportedMode, got {other:?}"),
5167 }
5168
5169 let err = engine
5170 .detokenize(&remote_id, &[1, 2, 3])
5171 .await
5172 .expect_err("remote detokenize must error");
5173 assert!(
5174 matches!(err, InferenceError::UnsupportedMode { .. }),
5175 "expected UnsupportedMode, got {err:?}"
5176 );
5177 }
5178
5179 #[test]
5180 fn engine_loads_benchmark_priors_on_startup() {
5181 let _env = ENV_MUTEX.lock().unwrap();
5182 let tmp = TempDir::new().unwrap();
5183 let priors_path = tmp.path().join("benchmark_priors.json");
5184 std::fs::write(
5185 &priors_path,
5186 serde_json::json!({
5187 "model_id": "qwen/qwen3-8b:q4_k_m",
5188 "overall_score": 0.88
5189 })
5190 .to_string(),
5191 )
5192 .unwrap();
5193
5194 unsafe {
5195 std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
5196 }
5197
5198 let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
5199 let tracker = engine.outcome_tracker.blocking_read();
5200 let profile = tracker
5201 .profile("qwen/qwen3-8b:q4_k_m")
5202 .expect("benchmark prior should create a profile");
5203 assert!((profile.ema_quality - 0.88).abs() < 0.01);
5204
5205 unsafe {
5206 std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
5207 }
5208 }
5209
5210 #[test]
5211 fn benchmark_priors_do_not_override_observed_profiles() {
5212 let _env = ENV_MUTEX.lock().unwrap();
5213 let tmp = TempDir::new().unwrap();
5214 let models_dir = tmp.path().join("models");
5215 std::fs::create_dir_all(&models_dir).unwrap();
5216
5217 let observed = vec![ModelProfile {
5218 model_id: "qwen/qwen3-8b:q4_k_m".into(),
5219 total_calls: 12,
5220 success_count: 3,
5221 fail_count: 9,
5222 total_latency_ms: 1200,
5223 total_input_tokens: 0,
5224 total_output_tokens: 0,
5225 task_stats: std::collections::HashMap::new(),
5226 ema_quality: 0.21,
5227 quality_per_1k_tokens: 0.0,
5228 updated_at: 1,
5229 }];
5230 std::fs::write(
5231 models_dir.join("outcome_profiles.json"),
5232 serde_json::to_string(&observed).unwrap(),
5233 )
5234 .unwrap();
5235
5236 let priors_path = tmp.path().join("benchmark_priors.json");
5237 std::fs::write(
5238 &priors_path,
5239 serde_json::json!({
5240 "model_id": "qwen/qwen3-8b:q4_k_m",
5241 "overall_score": 0.95
5242 })
5243 .to_string(),
5244 )
5245 .unwrap();
5246
5247 unsafe {
5248 std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
5249 }
5250
5251 let engine = InferenceEngine::new(test_config(models_dir));
5252 let tracker = engine.outcome_tracker.blocking_read();
5253 let profile = tracker
5254 .profile("qwen/qwen3-8b:q4_k_m")
5255 .expect("observed profile should remain present");
5256 assert!((profile.ema_quality - 0.21).abs() < 0.01);
5257 assert_eq!(profile.total_calls, 12);
5258
5259 unsafe {
5260 std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
5261 }
5262 }
5263
5264 #[test]
5265 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
5266 fn speech_runtime_package_spec_defaults_and_overrides() {
5267 let _env = ENV_MUTEX.lock().unwrap();
5268 unsafe {
5269 std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
5270 }
5271 assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.2");
5272
5273 unsafe {
5274 std::env::set_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC", "mlx-audio==0.4.1");
5275 }
5276 assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.1");
5277
5278 unsafe {
5279 std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
5280 }
5281 }
5282
5283 #[test]
5284 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
5285 fn speech_runtime_spacy_model_spec_defaults_and_overrides() {
5286 let _env = ENV_MUTEX.lock().unwrap();
5287 unsafe {
5288 std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
5289 }
5290 assert!(
5291 speech_runtime_spacy_model_spec().starts_with("en-core-web-sm @ https://github.com/")
5292 );
5293
5294 unsafe {
5295 std::env::set_var(
5296 "CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC",
5297 "en-core-web-sm==3.8.0",
5298 );
5299 }
5300 assert_eq!(speech_runtime_spacy_model_spec(), "en-core-web-sm==3.8.0");
5301
5302 unsafe {
5303 std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
5304 }
5305 }
5306
5307 #[test]
5308 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
5309 fn kokoro_runtime_fallback_defaults_on() {
5310 unsafe {
5311 std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
5312 }
5313 assert!(kokoro_runtime_fallback_enabled());
5314
5315 unsafe {
5316 std::env::set_var("CAR_SPEECH_KOKORO_FALLBACK", "false");
5317 }
5318 assert!(!kokoro_runtime_fallback_enabled());
5319
5320 unsafe {
5321 std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
5322 }
5323 }
5324
5325 #[test]
5326 fn preferred_local_tts_wins_over_builtin_rank() {
5327 let tmp = TempDir::new().unwrap();
5328 let mut engine = InferenceEngine::new(test_config(tmp.path().join("models")));
5329 engine.set_speech_policy(SpeechPolicy {
5330 prefer_local: true,
5331 allow_remote_fallback: false,
5332 preferred_local_stt: None,
5333 preferred_local_tts: Some("Kokoro-82M-6bit".into()),
5334 preferred_remote_stt: None,
5335 preferred_remote_tts: None,
5336 });
5337
5338 let schema = engine
5339 .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
5340 .expect("preferred local TTS should resolve");
5341 assert_eq!(schema.name, "Kokoro-82M-6bit");
5342 }
5343
5344 #[test]
5345 fn preferred_discovered_vllm_mlx_model_wins_generate_routing() {
5346 let tmp = TempDir::new().unwrap();
5347 let mut config = test_config(tmp.path().join("models"));
5348 config.preferred_generation_model =
5349 Some("vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit".into());
5350 let mut engine = InferenceEngine::new(config);
5351 let schema = crate::vllm_mlx::to_model_schema(
5352 &crate::vllm_mlx::DiscoveredModel {
5353 id: "mlx-community/gemma-3n-E2B-it-lm-4bit".into(),
5354 owned_by: Some("mlx-community".into()),
5355 },
5356 "http://127.0.0.1:8001",
5357 );
5358 engine.register_model(schema);
5359
5360 let rt = tokio::runtime::Runtime::new().unwrap();
5361 let decision = rt.block_on(engine.route_adaptive("say hello in one sentence"));
5362 assert_eq!(
5363 decision.model_id,
5364 "vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit"
5365 );
5366 assert_eq!(decision.strategy, RoutingStrategy::Explicit);
5367 assert_eq!(decision.reason, "preferred generation model override");
5368 }
5369
5370 #[test]
5375 fn inference_result_serializes_with_full_shape() {
5376 use crate::tasks::generate::ToolCall;
5377 use std::collections::HashMap;
5378
5379 let mut args = HashMap::new();
5380 args.insert("path".to_string(), serde_json::json!("README.md"));
5381
5382 let result = InferenceResult {
5383 text: String::new(),
5384 bounding_boxes: Vec::new(),
5385 tool_calls: vec![ToolCall {
5386 id: None,
5387 name: "read_file".into(),
5388 arguments: args,
5389 }],
5390 trace_id: "trace-abc".into(),
5391 model_used: "test-model".into(),
5392 latency_ms: 1234,
5393 time_to_first_token_ms: Some(180),
5394 usage: Some(TokenUsage {
5395 prompt_tokens: 100,
5396 completion_tokens: 50,
5397 total_tokens: 150,
5398 context_window: 8192,
5399 }),
5400 provider_output_items: Vec::new(),
5401 };
5402
5403 let json = serde_json::to_value(&result).expect("serialize");
5404
5405 assert_eq!(json["text"].as_str(), Some(""));
5407 assert_eq!(json["trace_id"].as_str(), Some("trace-abc"));
5408 assert_eq!(json["model_used"].as_str(), Some("test-model"));
5409 assert_eq!(json["latency_ms"].as_u64(), Some(1234));
5410
5411 let tool_calls = json["tool_calls"].as_array().expect("tool_calls array");
5413 assert_eq!(tool_calls.len(), 1);
5414 assert_eq!(tool_calls[0]["name"].as_str(), Some("read_file"));
5415 assert_eq!(
5416 tool_calls[0]["arguments"]["path"].as_str(),
5417 Some("README.md")
5418 );
5419
5420 let usage = &json["usage"];
5422 assert_eq!(usage["prompt_tokens"].as_u64(), Some(100));
5423 assert_eq!(usage["completion_tokens"].as_u64(), Some(50));
5424 assert_eq!(usage["total_tokens"].as_u64(), Some(150));
5425 assert_eq!(usage["context_window"].as_u64(), Some(8192));
5426
5427 assert_eq!(json["time_to_first_token_ms"].as_u64(), Some(180));
5429 }
5430
5431 #[test]
5437 fn inference_result_top_level_keys_are_locked() {
5438 use std::collections::BTreeSet;
5439
5440 let result = InferenceResult {
5441 text: "anything".into(),
5442 bounding_boxes: Vec::new(),
5443 tool_calls: vec![],
5444 trace_id: "t".into(),
5445 model_used: "m".into(),
5446 latency_ms: 0,
5447 time_to_first_token_ms: None,
5448 usage: None,
5449 provider_output_items: Vec::new(),
5450 };
5451
5452 let json = serde_json::to_value(&result).expect("serialize");
5453 let keys: BTreeSet<&str> = json
5454 .as_object()
5455 .expect("top-level object")
5456 .keys()
5457 .map(String::as_str)
5458 .collect();
5459
5460 let expected: BTreeSet<&str> = [
5461 "text",
5462 "tool_calls",
5463 "trace_id",
5464 "model_used",
5465 "latency_ms",
5466 "time_to_first_token_ms",
5467 "usage",
5468 ]
5469 .into_iter()
5470 .collect();
5471
5472 assert_eq!(
5473 keys, expected,
5474 "infer response top-level keys drifted -- update both the test \
5475 and the WebSocket protocol documentation if this is intentional"
5476 );
5477
5478 for key in &keys {
5480 assert!(
5481 !key.chars().any(|c| c.is_uppercase()) && !key.contains('-'),
5482 "key '{}' is not snake_case",
5483 key
5484 );
5485 }
5486 }
5487
5488 #[test]
5492 fn inference_result_serializes_plain_text_response() {
5493 let result = InferenceResult {
5494 text: "hello world".into(),
5495 bounding_boxes: Vec::new(),
5496 tool_calls: vec![],
5497 trace_id: "trace-xyz".into(),
5498 model_used: "test-model".into(),
5499 latency_ms: 42,
5500 time_to_first_token_ms: None,
5501 usage: None,
5502 provider_output_items: Vec::new(),
5503 };
5504
5505 let json = serde_json::to_value(&result).expect("serialize");
5506 assert_eq!(json["text"], "hello world");
5507 assert!(json["tool_calls"].is_array());
5508 assert_eq!(json["tool_calls"].as_array().unwrap().len(), 0);
5509 assert_eq!(json["model_used"], "test-model");
5510 assert!(json["usage"].is_null());
5511 assert!(json["time_to_first_token_ms"].is_null());
5514 }
5515
5516 #[test]
5528 fn generate_request_deserializes_intent_field_from_json_rpc_params() {
5529 use crate::intent::{IntentHint, TaskHint};
5530 use crate::schema::ModelCapability;
5531
5532 let params = serde_json::json!({
5535 "prompt": "summarize this email",
5536 "intent": {
5537 "task": "chat",
5538 "prefer_local": true,
5539 "require": ["tool_use"],
5540 },
5541 });
5542
5543 let req: GenerateRequest =
5544 serde_json::from_value(params).expect("GenerateRequest deserialize");
5545
5546 let intent = req.intent.as_ref().expect("intent field deserialized");
5547 assert_eq!(intent.task, Some(TaskHint::Chat));
5548 assert!(intent.prefer_local);
5549 assert_eq!(intent.require, vec![ModelCapability::ToolUse]);
5550
5551 let back: serde_json::Value =
5555 serde_json::to_value(&req).expect("re-serialize GenerateRequest");
5556 assert_eq!(back["intent"]["task"], "chat");
5557 assert_eq!(back["intent"]["prefer_local"], true);
5558 assert_eq!(back["intent"]["require"][0], "tool_use");
5559
5560 let default_req: GenerateRequest = serde_json::from_value(serde_json::json!({
5565 "prompt": "x",
5566 "intent": {},
5567 }))
5568 .unwrap();
5569 let default_intent = default_req.intent.expect("present but empty");
5570 assert_eq!(default_intent.task, None);
5571 assert!(!default_intent.prefer_local);
5572 assert!(default_intent.require.is_empty());
5573
5574 let no_intent: GenerateRequest =
5577 serde_json::from_value(serde_json::json!({"prompt": "x"})).unwrap();
5578 assert!(no_intent.intent.is_none());
5579 }
5580
5581 #[test]
5582 fn rerank_prompt_matches_upstream_template_shape() {
5583 let p = rerank_prompt(
5584 "retrieve relevant passages",
5585 "who runs the treasury?",
5586 "doc x",
5587 );
5588 assert!(p.contains("<|im_start|>system"));
5589 assert!(p.contains("Note that the answer can only be \"yes\" or \"no\"."));
5590 assert!(p.contains("<|im_start|>user\n<Instruct>: retrieve relevant passages"));
5591 assert!(p.contains("<Query>: who runs the treasury?"));
5592 assert!(p.contains("<Document>: doc x<|im_end|>"));
5593 assert!(p.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
5594 }
5595
5596 #[test]
5597 fn rerank_score_yes_and_no_exactly() {
5598 assert_eq!(score_from_rerank_output("yes", "m"), 1.0);
5599 assert_eq!(score_from_rerank_output("no", "m"), 0.0);
5600 }
5601
5602 #[test]
5603 fn rerank_score_handles_case_leading_space_and_chat_sentinels() {
5604 assert_eq!(score_from_rerank_output(" Yes", "m"), 1.0);
5607 assert_eq!(score_from_rerank_output("\nno.", "m"), 0.0);
5608 assert_eq!(score_from_rerank_output("<|im_end|>yes", "m"), 1.0);
5609 }
5610
5611 #[test]
5612 fn rerank_score_scans_up_to_three_tokens() {
5613 assert_eq!(score_from_rerank_output("_bos_ yes", "m"), 1.0);
5616 }
5617
5618 #[test]
5619 fn rerank_score_unexpected_is_neutral() {
5620 assert_eq!(score_from_rerank_output("maybe", "m"), 0.5);
5623 assert_eq!(score_from_rerank_output("", "m"), 0.5);
5624 }
5625}