pub mod adaptive_router;
pub mod backend;
pub mod backend_cache;
pub mod handle;
pub mod hardware;
pub mod intent;
pub mod key_pool;
pub mod models;
pub mod outcome;
pub mod protocol;
pub mod registry;
pub mod remote;
pub mod router;
pub mod routing_ext;
pub mod runner;
pub mod schema;
pub mod service;
pub mod stream;
pub mod tasks;
pub mod vllm_mlx;

use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};

use reqwest::multipart::{Form, Part};
use serde::Serialize;
use thiserror::Error;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use tokio::process::Command;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tracing::{debug, instrument};

pub use adaptive_router::{
    AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy,
};
pub use handle::InferenceHandle;
pub use intent::{IntentHint, TaskHint};
pub use key_pool::{KeyPool, KeyStats};
pub use outcome::{
    CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
};
pub use registry::{
    ModelFilter, ModelInfo, ModelRuntimeRequirement, ModelUpgrade, UnifiedRegistry,
};
pub use remote::RemoteBackend;
pub use routing_ext::{
    CircuitBreaker, CircuitBreakerRegistry, CircuitState, ImplicitSignal, ImplicitSignalType,
    RoutingMode, SpendControl, SpendLimitExceeded, SpendLimits, SpendStatus,
};
pub use runner::{
    current_inference_runner, set_inference_runner, EventEmitter, InferenceRunner, RunnerError,
    RunnerResult,
};
pub use schema::{
    ApiProtocol, BenchmarkScore, CostModel, ModelCapability, ModelSchema, ModelSource,
    PerformanceEnvelope, ProprietaryAuth,
};

pub use adaptive_router::TaskComplexity;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub use backend::CandleBackend;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub use backend::EmbeddingBackend;
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
pub use backend::MlxBackend;
pub use hardware::HardwareInfo;
pub use models::{ModelRegistry, ModelRole};
pub use router::{ModelRouter, RoutingDecision};
pub use stream::{StreamAccumulator, StreamEvent};
pub use tasks::{
    parse_boxes, BoundingBox, ClassifyRequest, ClassifyResult, ContentBlock, EmbedRequest,
    GenerateImageRequest, GenerateImageResult, GenerateParams, GenerateRequest,
    GenerateVideoRequest, GenerateVideoResult, GroundRequest, GroundResult, Message, RerankRequest,
    RerankResult, RerankedDocument, RoutingWorkload, SynthesizeRequest, SynthesizeResult,
    ThinkingMode, ToolCall, TranscribeRequest, TranscribeResult, VideoMode,
};

/// Errors surfaced by the inference engine and its backends.
#[derive(Error, Debug)]
pub enum InferenceError {
    #[error("model not found: {0}")]
    ModelNotFound(String),

    #[error("model download failed: {0}")]
    DownloadFailed(String),

    #[error("inference failed: {0}")]
    InferenceFailed(String),

    #[error("mode {mode} not implemented on backend {backend}: {reason}")]
    UnsupportedMode {
        mode: &'static str,
        backend: &'static str,
        reason: &'static str,
    },

    #[error("tokenization error: {0}")]
    TokenizationError(String),

    #[error("device error: {0}")]
    DeviceError(String),

    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Metal,
    Cuda(usize),
}

impl Device {
    pub fn auto() -> Self {
        #[cfg(feature = "metal")]
        {
            return Device::Metal;
        }
        #[cfg(feature = "cuda")]
        {
            return Device::Cuda(0);
        }
        #[cfg(not(any(feature = "metal", feature = "cuda")))]
        {
            Device::Cpu
        }
    }
}

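// Selection sketch for `Device::auto` above (illustrative, not additional
// API): the choice is made at compile time from cargo features, so a build
// with `--features metal` always resolves to `Device::Metal`. Typical call
// site, mirroring the backend-init code later in this module:
//
//     let device = config.device.unwrap_or_else(Device::auto);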
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Directory where model weights and routing state are stored.
    pub models_dir: std::path::PathBuf,
    /// Explicit device override; `None` selects a device via [`Device::auto`].
    pub device: Option<Device>,
    /// Default text-generation model.
    pub generation_model: String,
    /// User-preferred generation model; overrides adaptive routing when set.
    pub preferred_generation_model: Option<String>,
    /// Default embedding model.
    pub embedding_model: String,
    /// User-preferred embedding model; overrides the default when set.
    pub preferred_embedding_model: Option<String>,
    /// Default classification model.
    pub classification_model: String,
    /// User-preferred classification model; overrides the default when set.
    pub preferred_classification_model: Option<String>,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        let models_dir = dirs_next()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join(".car")
            .join("models");

        let hw = HardwareInfo::detect();

        Self {
            models_dir,
            device: None,
            generation_model: hw.recommended_model,
            preferred_generation_model: None,
            embedding_model: "Qwen3-Embedding-0.6B".to_string(),
            preferred_embedding_model: None,
            classification_model: "Qwen3-0.6B".to_string(),
            preferred_classification_model: None,
        }
    }
}

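// Construction sketch for `InferenceConfig` above (illustrative): override
// only the fields you need and keep the hardware-derived defaults for the
// rest. The model ID shown is one of the MLX catalog IDs referenced later in
// this module.
//
//     let config = InferenceConfig {
//         preferred_generation_model: Some("mlx/qwen3-4b:4bit".to_string()),
//         ..InferenceConfig::default()
//     };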
/// Minimal home-directory lookup (`$HOME`), standing in for a `dirs`-style crate.
fn dirs_next() -> Option<std::path::PathBuf> {
    std::env::var("HOME").ok().map(std::path::PathBuf::from)
}

#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt and any injected context.
    pub prompt_tokens: u64,
    /// Tokens produced in the completion.
    pub completion_tokens: u64,
    /// Total tokens for the request (prompt + completion).
    pub total_tokens: u64,
    /// Context window of the model that served the request.
    pub context_window: u64,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InferenceResult {
    /// Generated text.
    pub text: String,
    /// Tool calls emitted by the model, if any.
    pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
    /// Bounding boxes parsed from grounding output, if any.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub bounding_boxes: Vec<crate::tasks::grounding::BoundingBox>,
    /// Trace ID linking this result to the outcome tracker.
    pub trace_id: String,
    /// Name of the model that actually served the request.
    pub model_used: String,
    /// End-to-end latency in milliseconds.
    pub latency_ms: u64,
    /// Time to first token in milliseconds, when measured.
    #[serde(default)]
    pub time_to_first_token_ms: Option<u64>,
    /// Token accounting reported by the backend, when available.
    pub usage: Option<TokenUsage>,
    /// Raw provider output items passed through from the backend.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub provider_output_items: Vec<serde_json::Value>,
}

impl InferenceResult {
    pub fn has_tool_calls(&self) -> bool {
        !self.tool_calls.is_empty()
    }
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechRuntimeHealth {
    pub root: PathBuf,
    pub installed: bool,
    pub python: PathBuf,
    pub stt_command: PathBuf,
    pub tts_command: PathBuf,
    pub configured_python: Option<String>,
    pub detected_python: Option<String>,
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechModelHealth {
    pub id: String,
    pub name: String,
    pub provider: String,
    pub capability: ModelCapability,
    pub is_local: bool,
    pub available: bool,
    pub cached: bool,
    pub selected_by_default: bool,
    pub source: String,
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechHealthReport {
    pub runtime: SpeechRuntimeHealth,
    pub local_models: Vec<SpeechModelHealth>,
    pub remote_models: Vec<SpeechModelHealth>,
    pub elevenlabs_configured: bool,
    pub prefer_local: bool,
    pub allow_remote_fallback: bool,
    pub preferred_local_stt: Option<String>,
    pub preferred_local_tts: Option<String>,
    pub preferred_remote_stt: Option<String>,
    pub preferred_remote_tts: Option<String>,
    pub local_stt_default: Option<String>,
    pub local_tts_default: Option<String>,
    pub remote_stt_default: Option<String>,
    pub remote_tts_default: Option<String>,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelDefaultHealth {
    pub capability: ModelCapability,
    pub configured_model: String,
    pub available: bool,
    pub is_local: bool,
    pub provider: Option<String>,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelProviderHealth {
    pub provider: String,
    pub configured: bool,
    pub local_models: usize,
    pub remote_models: usize,
    pub available_models: usize,
    pub capabilities: Vec<ModelCapability>,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelCapabilityHealth {
    pub capability: ModelCapability,
    pub total_models: usize,
    pub available_models: usize,
    pub local_available_models: usize,
    pub remote_available_models: usize,
}

#[derive(Debug, Clone, Serialize)]
pub struct RoutingScenarioHealth {
    pub name: String,
    pub workload: RoutingWorkload,
    pub task_family: String,
    pub has_tools: bool,
    pub has_vision: bool,
    pub prefer_local: bool,
    pub quality_first_cold_start: bool,
    pub bootstrap_min_task_observations: u64,
    pub bootstrap_quality_floor: f64,
    pub model_id: String,
    pub model_name: String,
    pub reason: String,
    pub strategy: RoutingStrategy,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelBenchmarkPriorHealth {
    pub model_id: String,
    pub model_name: Option<String>,
    pub overall_score: f64,
    pub overall_latency_ms: Option<f64>,
    pub task_scores: std::collections::HashMap<String, f64>,
    pub task_latency_ms: std::collections::HashMap<String, f64>,
    pub source_path: PathBuf,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelHealthReport {
    pub total_models: usize,
    pub available_models: usize,
    pub local_models: usize,
    pub remote_models: usize,
    pub defaults: Vec<ModelDefaultHealth>,
    pub providers: Vec<ModelProviderHealth>,
    pub capabilities: Vec<ModelCapabilityHealth>,
    pub routing_prefer_local: bool,
    pub routing_quality_first_cold_start: bool,
    pub routing_min_observations: u64,
    pub routing_bootstrap_min_task_observations: u64,
    pub routing_bootstrap_quality_floor: f64,
    pub routing_quality_weight: f64,
    pub routing_latency_weight: f64,
    pub routing_cost_weight: f64,
    pub routing_scenarios: Vec<RoutingScenarioHealth>,
    pub benchmark_priors: Vec<ModelBenchmarkPriorHealth>,
    pub speech: SpeechHealthReport,
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechInstallReport {
    pub name: String,
    pub hf_repo: String,
    pub snapshot_path: PathBuf,
    pub files_downloaded: usize,
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechSmokePathReport {
    pub path: String,
    pub tts_model: String,
    pub stt_model: String,
    pub audio_path: PathBuf,
    pub transcript: String,
}

#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechSmokeReport {
    pub local: Option<SpeechSmokePathReport>,
    pub remote: Option<SpeechSmokePathReport>,
    pub skipped: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechPolicy {
    pub prefer_local: bool,
    pub allow_remote_fallback: bool,
    pub preferred_local_stt: Option<String>,
    pub preferred_local_tts: Option<String>,
    pub preferred_remote_stt: Option<String>,
    pub preferred_remote_tts: Option<String>,
}

pub struct InferenceEngine {
    pub config: InferenceConfig,
    pub unified_registry: UnifiedRegistry,
    pub adaptive_router: AdaptiveRouter,
    pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
    remote_backend: RemoteBackend,
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    mlx_backends: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    flux_cache: Arc<backend_cache::BackendCache<backend::mlx_flux::FluxBackend>>,
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    ltx_cache: Arc<backend_cache::BackendCache<backend::mlx_ltx::LtxBackend>>,
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    kokoro_cache: Arc<backend_cache::BackendCache<backend::mlx_kokoro::KokoroBackend>>,
    pub registry: models::ModelRegistry,
    pub router: ModelRouter,
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    backend: Arc<RwLock<Option<CandleBackend>>>,
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    speech_runtime: Arc<Mutex<Option<SpeechRuntime>>>,
    speech_policy: SpeechPolicy,
}

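// Backend layout: on Apple Silicon (unless `car_skip_mlx` is set) the engine
// keeps per-model MLX caches — text (`mlx_backends`), Flux image, LTX video,
// and Kokoro TTS. Everywhere else it lazily initializes a single Candle text
// backend plus an embedding backend behind `RwLock`s, and a Python speech
// runtime behind a `Mutex`.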
impl InferenceEngine {
    fn preferred_model_for_capability(&self, capability: ModelCapability) -> Option<&str> {
        match capability {
            ModelCapability::Generate => self.config.preferred_generation_model.as_deref(),
            ModelCapability::Embed => self.config.preferred_embedding_model.as_deref(),
            ModelCapability::Classify => self.config.preferred_classification_model.as_deref(),
            _ => None,
        }
    }

    fn request_needs_vision(req: &GenerateRequest) -> bool {
        req.images.as_ref().is_some_and(|images| !images.is_empty())
            || req.messages.as_ref().is_some_and(|messages| {
                messages
                    .iter()
                    .any(|msg| matches!(msg, Message::UserMultimodal { .. }))
            })
    }

    fn request_has_video(req: &GenerateRequest) -> bool {
        let images_have_video = req
            .images
            .as_ref()
            .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
        let messages_have_video = req.messages.as_ref().is_some_and(|messages| {
            messages.iter().any(|msg| match msg {
                Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_video),
                _ => false,
            })
        });
        images_have_video || messages_have_video
    }

    fn request_has_audio(req: &GenerateRequest) -> bool {
        let images_have_audio = req
            .images
            .as_ref()
            .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
        let messages_have_audio = req.messages.as_ref().is_some_and(|messages| {
            messages.iter().any(|msg| match msg {
                Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_audio),
                _ => false,
            })
        });
        images_have_audio || messages_have_audio
    }

    pub fn new(config: InferenceConfig) -> Self {
        let registry = models::ModelRegistry::new(config.models_dir.clone());
        let hw = HardwareInfo::detect();
        let router = ModelRouter::new(hw.clone());
        let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
        let adaptive_router = AdaptiveRouter::with_default_config(hw);
        let mut tracker = OutcomeTracker::new();
        let profiles_path = config.models_dir.join("outcome_profiles.json");
        if let Ok(n) = tracker.load_from_file(&profiles_path) {
            if n > 0 {
                tracing::info!(loaded = n, "loaded persisted model profiles");
            }
        }
        let mut benchmark_models_loaded = 0usize;
        for path in benchmark_priors_paths(&config.models_dir) {
            match routing_ext::load_benchmark_priors(&path) {
                Ok(priors) if !priors.is_empty() => {
                    benchmark_models_loaded += priors.len();
                    routing_ext::apply_benchmark_priors(&mut tracker, &priors);
                    tracing::info!(
                        path = %path.display(),
                        loaded = priors.len(),
                        "loaded benchmark quality priors"
                    );
                }
                Ok(_) => {}
                Err(error) => {
                    tracing::warn!(path = %path.display(), %error, "failed to load benchmark priors");
                }
            }
        }
        if benchmark_models_loaded > 0 {
            tracing::info!(
                loaded = benchmark_models_loaded,
                "applied benchmark priors to cold-start routing"
            );
        }
        let outcome_tracker = Arc::new(RwLock::new(tracker));

        let remote_backend = RemoteBackend::new();

        Self {
            config,
            unified_registry,
            adaptive_router,
            outcome_tracker,
            remote_backend,
            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            mlx_backends: Arc::new(backend_cache::BackendCache::from_env()),
            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            flux_cache: Arc::new(backend_cache::BackendCache::from_env()),
            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            ltx_cache: Arc::new(backend_cache::BackendCache::from_env()),
            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            kokoro_cache: Arc::new(backend_cache::BackendCache::from_env()),
            registry,
            router,
            #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
            backend: Arc::new(RwLock::new(None)),
            #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
            embedding_backend: Arc::new(RwLock::new(None)),
            #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
            speech_runtime: Arc::new(Mutex::new(None)),
            speech_policy: SpeechPolicy {
                prefer_local: cfg!(all(
                    target_os = "macos",
                    target_arch = "aarch64",
                    not(car_skip_mlx)
                )),
                allow_remote_fallback: true,
                preferred_local_stt: None,
                preferred_local_tts: None,
                preferred_remote_stt: None,
                preferred_remote_tts: None,
            },
        }
    }

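    // Lifecycle sketch (illustrative): construction is synchronous and loads
    // persisted outcome profiles and benchmark priors from `models_dir`;
    // registering remote API keys is the separate async step below.
    //
    //     let engine = InferenceEngine::new(InferenceConfig::default());
    //     engine.init_key_pool().await;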
    pub async fn init_key_pool(&self) {
        for schema in self.unified_registry.list() {
            if schema.is_remote() {
                self.remote_backend.register_model_keys(schema).await;
            }
        }

        let stats_path = self.config.models_dir.join("key_pool_stats.json");
        if let Ok(n) = self.remote_backend.key_pool.load_stats(&stats_path).await {
            if n > 0 {
                tracing::info!(loaded = n, "loaded persisted key pool stats");
            }
        }

        let total = self.remote_backend.key_pool.total_keys().await;
        if total > 0 {
            tracing::info!(keys = total, "key pool initialized");
        }
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
        // Double-checked init: fast path under the read lock.
        let read = self.backend.read().await;
        if read.is_some() {
            return Ok(());
        }
        drop(read);

        let mut write = self.backend.write().await;
        // Re-check under the write lock in case another task won the race.
        if write.is_some() {
            return Ok(());
        }

        let model_path = self.registry.ensure_model(model_name).await?;
        let device = self.config.device.unwrap_or_else(Device::auto);
        let backend = CandleBackend::load(&model_path, device)?;
        *write = Some(backend);
        Ok(())
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
        let read = self.embedding_backend.read().await;
        if read.is_some() {
            return Ok(());
        }
        drop(read);

        let mut write = self.embedding_backend.write().await;
        if write.is_some() {
            return Ok(());
        }

        let embedding_model = self
            .preferred_model_for_capability(ModelCapability::Embed)
            .unwrap_or(&self.config.embedding_model);
        let model_path = self.registry.ensure_model(embedding_model).await?;
        let device = self.config.device.unwrap_or_else(Device::auto);
        let backend = EmbeddingBackend::load(&model_path, device)?;
        *write = Some(backend);
        Ok(())
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    async fn ensure_mlx_embedding_backend(&self) -> Result<String, InferenceError> {
        let embedding_model_name = self
            .preferred_model_for_capability(ModelCapability::Embed)
            .unwrap_or(&self.config.embedding_model)
            .to_string();
        let schema = self
            .unified_registry
            .get(&embedding_model_name)
            .or_else(|| self.unified_registry.find_by_name(&embedding_model_name))
            .ok_or_else(|| InferenceError::ModelNotFound(embedding_model_name.clone()))?
            .clone();
        self.ensure_mlx_backend(&schema).await?;
        Ok(schema.id)
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    async fn ensure_mlx_backend(
        &self,
        schema: &ModelSchema,
    ) -> Result<backend_cache::CachedBackend<backend::MlxBackend>, InferenceError> {
        if !Self::supports_native_mlx(schema) {
            return Err(InferenceError::InferenceFailed(format!(
                "native MLX backend does not support {} ({}) yet; use vLLM-MLX or add a family-specific MLX backend",
                schema.name, schema.family
            )));
        }

        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
        let size = backend_cache::estimate_model_size(&model_dir);
        let cache = Arc::clone(&self.mlx_backends);
        let key = schema.id.clone();
        cache.get_or_load(&key, size, move || {
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                backend::MlxBackend::load(&model_dir)
            }))
            .map_err(|e| {
                InferenceError::InferenceFailed(format!(
                    "MLX backend loading panicked (possible Metal/accelerate exception): {:?}",
                    e
                ))
            })?
        })
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    pub async fn warm_up<S: AsRef<str>>(
        &self,
        schema_ids: &[S],
    ) -> Vec<Result<(), InferenceError>> {
        let mut results = Vec::with_capacity(schema_ids.len());
        for id in schema_ids {
            let id = id.as_ref();
            let outcome: Result<(), InferenceError> = async {
                let schema = self.unified_registry.get(id).cloned().ok_or_else(|| {
                    InferenceError::InferenceFailed(format!("warm_up: unknown schema id {id}"))
                })?;
                match schema.capabilities.first().copied() {
                    Some(ModelCapability::ImageGeneration) => {
                        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
                        let size = backend_cache::estimate_model_size(&model_dir);
                        let _ = self.flux_cache.get_or_load(&schema.id, size, || {
                            backend::mlx_flux::FluxBackend::load(&model_dir)
                        })?;
                    }
                    Some(ModelCapability::VideoGeneration) => {
                        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
                        let size = backend_cache::estimate_model_size(&model_dir);
                        let _ = self.ltx_cache.get_or_load(&schema.id, size, || {
                            backend::mlx_ltx::LtxBackend::load(&model_dir)
                        })?;
                    }
                    Some(ModelCapability::TextToSpeech) => {
                        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
                        let size = backend_cache::estimate_model_size(&model_dir);
                        let _ = self.kokoro_cache.get_or_load(&schema.id, size, || {
                            backend::mlx_kokoro::KokoroBackend::load(&model_dir)
                        })?;
                    }
                    _ => {
                        let _ = self.ensure_mlx_backend(&schema).await?;
                    }
                }
                Ok(())
            }
            .await;
            results.push(outcome);
        }
        results
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    pub async fn warm_up<S: AsRef<str>>(
        &self,
        _schema_ids: &[S],
    ) -> Vec<Result<(), InferenceError>> {
        Vec::new()
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    fn supports_native_mlx(schema: &ModelSchema) -> bool {
        matches!(schema.family.as_str(), "qwen3" | "qwen2.5-vl" | "qwen2-vl")
    }

    pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
        if let Some(model) = self.preferred_model_for_capability(ModelCapability::Generate) {
            let ctx_len = self
                .unified_registry
                .get(model)
                .or_else(|| self.unified_registry.find_by_name(model))
                .map(|s| s.context_length)
                .unwrap_or(0);
            return AdaptiveRoutingDecision {
                model_id: model.to_string(),
                model_name: model.to_string(),
                task: InferenceTask::Generate,
                complexity: TaskComplexity::assess(prompt),
                reason: "preferred generation model override".into(),
                strategy: RoutingStrategy::Explicit,
                predicted_quality: 0.5,
                fallbacks: vec![],
                context_length: ctx_len,
                needs_compaction: false,
            };
        }
        let tracker = self.outcome_tracker.read().await;
        self.adaptive_router
            .route(prompt, &self.unified_registry, &tracker)
    }

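    // Routing sketch (illustrative): `route_adaptive` above consults the
    // outcome tracker unless a preferred model short-circuits it; `route`
    // below is the simpler hardware-profile path.
    //
    //     let decision = engine.route_adaptive("refactor this function").await;
    //     tracing::info!(model = %decision.model_name, reason = %decision.reason, "routed");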
    pub fn route(&self, prompt: &str) -> RoutingDecision {
        self.router.route_generate(prompt, &self.registry)
    }

    pub fn estimated_tokens(
        &self,
        req: &GenerateRequest,
        model_id: Option<&str>,
    ) -> (usize, usize, bool) {
        let prompt_tokens = remote::estimate_tokens(&req.prompt);
        let context_tokens = req
            .context
            .as_ref()
            .map(|c| remote::estimate_tokens(c))
            .unwrap_or(0);
        let tools_tokens = req
            .tools
            .as_ref()
            .map(|t| remote::estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
            .unwrap_or(0);
        let total_input = prompt_tokens + context_tokens + tools_tokens;

        let context_window = model_id
            .and_then(|id| {
                self.unified_registry
                    .get(id)
                    .or_else(|| self.unified_registry.find_by_name(id))
            })
            .map(|s| s.context_length)
            .unwrap_or(0);

        let fits = context_window == 0 || (total_input + req.params.max_tokens) <= context_window;
        (total_input, context_window, fits)
    }

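    // Worked example for `estimated_tokens` above (hypothetical numbers):
    // prompt ≈ 6_000 tokens, context ≈ 24_000, tools ≈ 500 gives
    // total_input = 30_500. Against a 32_768-token window with
    // max_tokens = 4_096, 30_500 + 4_096 > 32_768, so `fits` is false and the
    // caller should compact or truncate. A window of 0 (unknown model) is
    // treated as "always fits".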
    #[instrument(
        name = "inference.generate",
        skip_all,
        fields(
            model = tracing::field::Empty,
            max_tokens = req.params.max_tokens,
            prompt_tokens = tracing::field::Empty,
            completion_tokens = tracing::field::Empty,
            latency_ms = tracing::field::Empty,
        )
    )]
    pub async fn generate_tracked(
        &self,
        req: GenerateRequest,
    ) -> Result<InferenceResult, InferenceError> {
        let start = Instant::now();

        let (estimated_input, _, _) = self.estimated_tokens(&req, None);
        let tracker_read = self.outcome_tracker.read().await;
        let has_tools = req.tools.is_some();
        let has_vision = Self::request_needs_vision(&req);
        let preferred_model = self
            .preferred_model_for_capability(ModelCapability::Generate)
            .map(str::to_string);
        let decision = match req.model.clone().or(preferred_model) {
            Some(m) => {
                let ctx_len = self
                    .unified_registry
                    .get(&m)
                    .or_else(|| self.unified_registry.find_by_name(&m))
                    .map(|s| s.context_length)
                    .unwrap_or(0);
                AdaptiveRoutingDecision {
                    model_id: m.clone(),
                    model_name: m.clone(),
                    task: InferenceTask::Generate,
                    complexity: TaskComplexity::assess(&req.prompt),
                    reason: "explicit model".into(),
                    strategy: RoutingStrategy::Explicit,
                    predicted_quality: 0.5,
                    fallbacks: vec![],
                    context_length: ctx_len,
                    needs_compaction: ctx_len > 0 && estimated_input > ctx_len,
                }
            }
            None => match &req.intent {
                Some(hint) => self.adaptive_router.route_context_aware_with_intent(
                    &req.prompt,
                    estimated_input,
                    &self.unified_registry,
                    &tracker_read,
                    has_tools,
                    has_vision,
                    req.params.workload,
                    hint,
                ),
                None => self.adaptive_router.route_context_aware(
                    &req.prompt,
                    estimated_input,
                    &self.unified_registry,
                    &tracker_read,
                    has_tools,
                    has_vision,
                    req.params.workload,
                ),
            },
        };
        drop(tracker_read);

        if decision.needs_compaction {
            tracing::info!(
                model = %decision.model_name,
                prompt_tokens = estimated_input,
                context_window = decision.context_length,
                "prompt exceeds model context window — compaction or truncation needed"
            );
        }

        let trace_id = {
            let mut tracker = self.outcome_tracker.write().await;
            tracker.record_start(&decision.model_id, decision.task, &decision.reason)
        };

        debug!(
            model = %decision.model_name,
            strategy = ?decision.strategy,
            reason = %decision.reason,
            trace = %trace_id,
            "adaptive-routed generate request"
        );

        // Auto-enable extended thinking on complex tasks when the routed model
        // supports it and the caller did not set a budget.
        let mut req = req;
        if req.params.budget_tokens == 0 && matches!(decision.complexity, TaskComplexity::Complex) {
            let supports_thinking = self
                .unified_registry
                .get(&decision.model_id)
                .map(|s| {
                    s.supported_params
                        .contains(&schema::GenerateParam::ExtendedThinking)
                })
                .unwrap_or(false);
            if supports_thinking {
                req.params.budget_tokens = 8000;
                tracing::info!(model = %decision.model_name, budget = 8000, "auto-enabled extended thinking for complex task");
            }
        }

        // Try the routed model first, then each fallback in order.
        let mut models_to_try = vec![decision.model_id.clone()];
        models_to_try.extend(decision.fallbacks.iter().cloned());

        let mut last_error = None;

        for candidate_id in &models_to_try {
            #[allow(unused_mut)]
            let mut schema = self
                .unified_registry
                .get(candidate_id)
                .or_else(|| self.unified_registry.find_by_name(candidate_id))
                .cloned();

            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            if let Some(ref s) = schema {
                if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
                    tracing::info!(
                        from = %s.id, to = %mlx_equiv.id,
                        "redirecting GGUF model to MLX equivalent on Apple Silicon"
                    );
                    schema = Some(mlx_equiv.clone());
                }
            }

            let candidate_name = schema
                .as_ref()
                .map(|s| s.name.clone())
                .unwrap_or_else(|| candidate_id.clone());

            let is_remote = schema
                .as_ref()
                .map(|s| s.is_remote() || s.is_vllm_mlx())
                .unwrap_or(false);
            let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);

            if is_delegated {
                let runner = match runner::current_inference_runner() {
                    Some(r) => r,
                    None => {
                        last_error = Some(InferenceError::InferenceFailed(
                            "model declares ModelSource::Delegated but no inference runner is registered"
                                .into(),
                        ));
                        continue;
                    }
                };
                let (tx, mut rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
                let emitter = runner::EventEmitter::new(tx);
                let runner_req = req.clone();
                let runner_handle =
                    tokio::spawn(async move { runner.run(runner_req, emitter).await });
                let mut accumulator = stream::StreamAccumulator::default();
                while let Some(evt) = rx.recv().await {
                    accumulator.push(&evt);
                }
                let (acc_text, acc_tool_calls) = accumulator.finish();
                match runner_handle.await {
                    Ok(Ok(_runner_result)) => {
                        let elapsed = start.elapsed().as_millis() as u64;
                        let mut tracker = self.outcome_tracker.write().await;
                        tracker.record_complete(&trace_id, elapsed, 0, 0);
                        return Ok(InferenceResult {
                            text: acc_text,
                            tool_calls: acc_tool_calls,
                            bounding_boxes: vec![],
                            trace_id,
                            model_used: candidate_name,
                            latency_ms: elapsed,
                            time_to_first_token_ms: None,
                            usage: None,
                            provider_output_items: vec![],
                        });
                    }
                    Ok(Err(e)) => {
                        last_error = Some(InferenceError::InferenceFailed(e.to_string()));
                        continue;
                    }
                    Err(join_err) => {
                        last_error = Some(InferenceError::InferenceFailed(format!(
                            "runner task panicked: {join_err}"
                        )));
                        continue;
                    }
                }
            }

            let has_tools = req.tools.is_some();

            let context = if has_tools
                && req.tools.as_ref().map_or(false, |t| {
                    t.iter().any(|tool| {
                        tool.get("function")
                            .and_then(|f| f.get("name"))
                            .and_then(|n| n.as_str())
                            == Some("done")
                    })
                }) {
                let base = req.context.as_deref().unwrap_or("");
                Some(format!(
                    "{base}\n\nIMPORTANT: When calling the `done` tool, the `result` field MUST contain a DETAILED summary of everything you found and did. This is the ONLY output the user sees. Do NOT just say 'completed' — include specific findings, data, and conclusions."
                ))
            } else {
                req.context.clone()
            };

            let result = if is_remote {
                let schema_val = schema.unwrap();
                let _ctx_len = schema_val.context_length;
                // -1.0 signals "temperature unsupported" when the schema lists
                // supported params without Temperature.
                let temperature = if !schema_val.supported_params.is_empty()
                    && !schema_val
                        .supported_params
                        .contains(&crate::schema::GenerateParam::Temperature)
                {
                    -1.0
                } else {
                    req.params.temperature
                };

                self.remote_backend
                    .generate_with_tools_multi(
                        &schema_val,
                        &req.prompt,
                        context.as_deref(),
                        temperature,
                        req.params.max_tokens,
                        req.tools.as_deref(),
                        req.images.as_deref(),
                        req.messages.as_deref(),
                        req.params.tool_choice.as_deref(),
                        req.params.parallel_tool_calls,
                        req.params.budget_tokens,
                        req.cache_control,
                        req.response_format.as_ref(),
                    )
                    .await
                    .map(|(t, c, u)| (t, c, u, None::<u64>))
            } else {
                #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
                {
                    let schema_ref = schema
                        .as_ref()
                        .ok_or_else(|| InferenceError::ModelNotFound(candidate_id.clone()))?;

                    if schema_ref.is_foundation_models() {
                        if has_tools {
                            Err(InferenceError::UnsupportedMode {
                                mode: "tool-use",
                                backend: "foundation-models",
                                reason: "tool calling not yet wired through the FoundationModels \
                                         bridge — route to a remote model with ToolUse capability",
                            })
                        } else if Self::request_has_video(&req)
                            || Self::request_has_audio(&req)
                            || req.images.as_ref().is_some_and(|imgs| !imgs.is_empty())
                        {
                            Err(InferenceError::UnsupportedMode {
                                mode: "multimodal-content",
                                backend: "foundation-models",
                                reason: "the FoundationModels bridge currently exposes text-only \
                                         generation — route image/audio/video to a remote VL model",
                            })
                        } else {
                            let prompt = req.prompt.clone();
                            let instructions = context.clone();
                            let max_tokens = req.params.max_tokens as u32;
                            let temperature = req.params.temperature;
                            tokio::task::spawn_blocking(move || {
                                crate::backend::foundation_models::generate(
                                    &prompt,
                                    instructions.as_deref(),
                                    max_tokens,
                                    temperature as f32,
                                )
                            })
                            .await
                            .map_err(|e| {
                                InferenceError::InferenceFailed(format!(
                                    "FoundationModels task panicked: {e}"
                                ))
                            })
                            .and_then(|r| r)
                            .map(|text| (text, vec![], None, None))
                        }
                    } else if !schema_ref.is_mlx() {
                        Err(InferenceError::InferenceFailed(format!(
                            "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
                            schema_ref.id
                        )))
                    } else if schema_ref.tags.iter().any(|t| t == "mlx-vlm-cli") {
                        let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
                        if !has_images {
                            return Err(InferenceError::UnsupportedMode {
                                mode: "text-only-on-mlx-vlm-id",
                                backend: "mlx-vlm-cli",
                                reason: "the `mlx-vlm/...` model IDs route exclusively \
                                         through the mlx-vlm CLI for image inference. \
                                         For text-only generation, route to a Qwen3 \
                                         text model (`mlx/qwen3-4b:4bit` etc.) — the \
                                         CLI shell-out has higher latency than the \
                                         in-process MLX text tower.",
                            });
                        }
                        let vlm_status = crate::backend::mlx_vlm_cli::runtime_status();
                        if !vlm_status.is_available() {
                            return Err(InferenceError::InferenceFailed(vlm_status.user_message()));
                        }
                        let repo = match &schema_ref.source {
                            crate::schema::ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
                            _ => {
                                return Err(InferenceError::InferenceFailed(format!(
                                    "model '{}' is tagged mlx-vlm-cli but its \
                                     source isn't ModelSource::Mlx — registry bug",
                                    schema_ref.id
                                )));
                            }
                        };
                        let imgs = req.images.clone().unwrap_or_default();
                        let temp = req.params.temperature;
                        let max_t = req.params.max_tokens;
                        let prompt = req.prompt.clone();
                        let text = tokio::task::spawn_blocking(move || {
                            crate::backend::mlx_vlm_cli::generate(
                                &repo, &prompt, &imgs, temp, max_t,
                            )
                        })
                        .await
                        .map_err(|e| {
                            InferenceError::InferenceFailed(format!(
                                "mlx_vlm CLI task panicked: {e}"
                            ))
                        })??;
                        let bounding_boxes = parse_boxes(&text);
                        let latency_ms = start.elapsed().as_millis() as u64;
                        {
                            let mut tracker = self.outcome_tracker.write().await;
                            tracker.record_complete(&trace_id, latency_ms, 0, 0);
                        }
                        return Ok(InferenceResult {
                            text,
                            tool_calls: vec![],
                            bounding_boxes,
                            trace_id: trace_id.clone(),
                            model_used: schema_ref.id.clone(),
                            latency_ms,
                            time_to_first_token_ms: None,
                            usage: None,
                            provider_output_items: Vec::new(),
                        });
                    } else {
                        let handle = self.ensure_mlx_backend(schema_ref).await?;
                        if Self::request_has_video(&req) {
                            return Err(InferenceError::UnsupportedMode {
                                mode: "video-content-block",
                                backend: "native-mlx-qwen25vl",
                                reason: "Qwen2.5-VL video understanding is on the request surface \
                                         but the video-tokenization path (frame sampling + merger) \
                                         is not yet wired; route to a remote VL provider for now",
                            });
                        }
                        if Self::request_has_audio(&req) {
                            return Err(InferenceError::UnsupportedMode {
                                mode: "audio-content-block",
                                backend: "native-mlx-qwen25vl",
                                reason: "audio understanding is on the request surface (Gemma 4 \
                                         E2B/E4B and Gemini accept it) but the native MLX path \
                                         for this model does not — route to Gemini or Gemma-4",
                            });
                        }
                        let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
                        if has_images {
                            let can_do_vision = {
                                let guard = handle.lock().map_err(|_| {
                                    InferenceError::InferenceFailed(
                                        "MLX backend mutex poisoned".into(),
                                    )
                                })?;
                                guard.supports_capability(crate::schema::ModelCapability::Vision)
                            };
                            if !can_do_vision {
                                return Err(InferenceError::UnsupportedMode {
                                    mode: "image-content-block",
                                    backend: "native-mlx-text",
                                    reason: "this MLX backend is a plain Qwen3 text tower. \
                                             For local image inference, route to \
                                             `mlx-vlm/qwen3-vl-2b:bf16` or another `mlx-vlm/...` \
                                             catalog ID so CAR shells out to `mlx_vlm.generate`. \
                                             Alternatives: a local vLLM-MLX VLM server, or a \
                                             remote VL model. (#115)",
                                });
                            }
                        }
                        self.generate_mlx(req.clone(), &schema_ref.id)
                            .await
                            .map(|(text, ttft)| (text, vec![], None, ttft))
                    }
                }

                #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
                {
                    match self.ensure_backend(&candidate_name).await {
                        Ok(()) => {
                            let mut write = self.backend.write().await;
                            let backend = write.as_mut().unwrap();
                            tasks::generate::generate(backend, req.clone())
                                .await
                                .map(|(text, ttft)| (text, vec![], None, ttft))
                        }
                        Err(e) => Err(e),
                    }
                }
            };

            match result {
                Ok((text, tool_calls, usage, time_to_first_token_ms)) => {
                    let latency_ms = start.elapsed().as_millis() as u64;
                    let estimated_tokens = usage
                        .as_ref()
                        .map(|u| u.completion_tokens as usize)
                        .unwrap_or_else(|| text.split_whitespace().count());
                    {
                        let mut tracker = self.outcome_tracker.write().await;
                        tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
                    }
                    if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
                        cb.record_success(candidate_id);
                    }
                    self.auto_save_outcomes().await;

                    let span = tracing::Span::current();
                    span.record("model", candidate_name.as_str());
                    span.record("latency_ms", latency_ms);
                    if let Some(ttft) = time_to_first_token_ms {
                        span.record("ttft_ms", ttft);
                    }
                    if let Some(ref u) = usage {
                        span.record("prompt_tokens", u.prompt_tokens);
                        span.record("completion_tokens", u.completion_tokens);
                    }

                    let bounding_boxes = tasks::grounding::parse_boxes(&text);
                    return Ok(InferenceResult {
                        text,
                        tool_calls,
                        bounding_boxes,
                        trace_id,
                        model_used: candidate_name,
                        latency_ms,
                        time_to_first_token_ms,
                        usage,
                        provider_output_items: Vec::new(),
                    });
                }
                Err(e) => {
                    tracing::warn!(
                        model = %candidate_name,
                        error = %e,
                        remaining = models_to_try.len().saturating_sub(
                            models_to_try.iter().position(|m| m == candidate_id).unwrap_or(0) + 1
                        ),
                        "model failed, trying next fallback immediately"
                    );
                    {
                        let mut tracker = self.outcome_tracker.write().await;
                        let fail_trace =
                            tracker.record_start(candidate_id, decision.task, "fallback");
                        tracker.record_failure(&fail_trace, &e.to_string());
                    }
                    {
                        // Client errors (4xx other than 429) earn a second
                        // breaker strike so misconfigured models trip faster.
                        let err_str = e.to_string();
                        let is_client_error =
                            err_str.contains("API returned 4") && !err_str.contains("429");
                        if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
                            cb.record_failure(candidate_id);
                            if is_client_error {
                                cb.record_failure(candidate_id);
                            }
                        }
                    }
                    #[cfg(not(all(
                        target_os = "macos",
                        target_arch = "aarch64",
                        not(car_skip_mlx)
                    )))]
                    {
                        let mut write = self.backend.write().await;
                        *write = None;
                    }
                    last_error = Some(e);
                }
            }
        }

        let e = last_error
            .unwrap_or_else(|| InferenceError::InferenceFailed("no models available".into()));
        {
            let mut tracker = self.outcome_tracker.write().await;
            tracker.record_failure(&trace_id, &e.to_string());
        }
        self.auto_save_outcomes().await;
        Err(e)
    }

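    // Failure-handling sketch for `generate_tracked` above: each failed
    // candidate records one circuit-breaker strike, and non-429 4xx errors
    // record a second strike so persistent client errors trip the breaker
    // roughly twice as fast. Illustrative timeline:
    //
    //     candidate A -> 401 (two strikes) -> fallback B -> 429 (one strike)
    //     -> fallback C -> Ok   => success recorded for C only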
    pub async fn generate_tracked_stream(
        &self,
        req: GenerateRequest,
    ) -> Result<tokio::sync::mpsc::Receiver<stream::StreamEvent>, InferenceError> {
        let has_tools = req.tools.is_some();
        let has_vision = Self::request_needs_vision(&req);
        let preferred_model = self
            .preferred_model_for_capability(ModelCapability::Generate)
            .map(str::to_string);
        let decision = match req.model.clone().or(preferred_model) {
            Some(m) => {
                let ctx_len = self
                    .unified_registry
                    .get(&m)
                    .or_else(|| self.unified_registry.find_by_name(&m))
                    .map(|s| s.context_length)
                    .unwrap_or(0);
                AdaptiveRoutingDecision {
                    model_id: m.clone(),
                    model_name: m,
                    task: InferenceTask::Generate,
                    complexity: TaskComplexity::assess(&req.prompt),
                    reason: "explicit model".into(),
                    strategy: RoutingStrategy::Explicit,
                    predicted_quality: 0.5,
                    fallbacks: vec![],
                    context_length: ctx_len,
                    needs_compaction: false,
                }
            }
            None => {
                let tracker_read = self.outcome_tracker.read().await;
                if has_vision {
                    self.adaptive_router.route_with_vision(
                        &req.prompt,
                        &self.unified_registry,
                        &tracker_read,
                        has_tools,
                    )
                } else if has_tools {
                    self.adaptive_router.route_with_tools(
                        &req.prompt,
                        &self.unified_registry,
                        &tracker_read,
                    )
                } else {
                    self.adaptive_router
                        .route(&req.prompt, &self.unified_registry, &tracker_read)
                }
            }
        };

        #[allow(unused_mut)]
        let mut schema = self
            .unified_registry
            .get(&decision.model_id)
            .or_else(|| self.unified_registry.find_by_name(&decision.model_id))
            .cloned();

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        if let Some(ref s) = schema {
            if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
                tracing::info!(
                    from = %s.id, to = %mlx_equiv.id,
                    "redirecting GGUF model to MLX equivalent on Apple Silicon (stream)"
                );
                schema = Some(mlx_equiv.clone());
            }
        }

        let is_remote = schema
            .as_ref()
            .map(|s| s.is_remote() || s.is_vllm_mlx())
            .unwrap_or(false);

        let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);

        if is_delegated {
            let runner = runner::current_inference_runner().ok_or_else(|| {
                InferenceError::InferenceFailed(
                    "model declares ModelSource::Delegated but no inference runner is registered \
                     (call set_inference_runner / registerInferenceRunner / register_inference_runner)"
                        .into(),
                )
            })?;
            let (tx, rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
            let emitter = runner::EventEmitter::new(tx);
            let request = req.clone();
            tokio::spawn(async move {
                if let Err(e) = runner.run(request, emitter).await {
                    tracing::warn!(error = %e, "delegated inference runner failed");
                }
            });
            return Ok(rx);
        }

        if is_remote {
            let schema = schema.unwrap();
            self.remote_backend.register_model_keys(&schema).await;

            self.remote_backend
                .generate_stream(
                    &schema,
                    &req.prompt,
                    req.context.as_deref(),
                    req.params.temperature,
                    req.params.max_tokens,
                    req.tools.as_deref(),
                    req.images.as_deref(),
                    req.params.tool_choice.as_deref(),
                    req.params.parallel_tool_calls,
                    req.response_format.as_ref(),
                )
                .await
        } else {
            let schema =
                schema.ok_or_else(|| InferenceError::ModelNotFound(decision.model_id.clone()))?;
            let (tx, rx) = tokio::sync::mpsc::channel(64);

            #[cfg(any(
                all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
                all(target_os = "ios", target_arch = "aarch64")
            ))]
            {
                if schema.is_foundation_models() {
                    let prompt = req.prompt.clone();
                    let instructions = req.context.clone();
                    let max_tokens = req.params.max_tokens as u32;
                    let temperature = req.params.temperature;
                    let tx_clone = tx.clone();
                    tokio::task::spawn_blocking(move || {
                        let accum = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
                        let accum_cb = accum.clone();
                        let cb = crate::backend::foundation_models::StreamCallback::new(
                            move |delta: &str| {
                                if let Ok(mut g) = accum_cb.lock() {
                                    g.push_str(delta);
                                }
                                tx_clone
                                    .blocking_send(stream::StreamEvent::TextDelta(
                                        delta.to_string(),
                                    ))
                                    .is_ok()
                            },
                        );
                        let result = crate::backend::foundation_models::stream(
                            &prompt,
                            instructions.as_deref(),
                            max_tokens,
                            temperature as f32,
                            cb,
                        );
                        let final_text = accum.lock().map(|g| g.clone()).unwrap_or_default();
                        let _ = tx.blocking_send(stream::StreamEvent::Done {
                            text: final_text,
                            tool_calls: vec![],
                        });
                        result
                    });
                    return Ok(rx);
                }
            }

            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            {
                if !schema.is_mlx() {
                    return Err(InferenceError::InferenceFailed(format!(
                        "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
                        schema.id
                    )));
                }
                let backend = self.ensure_mlx_backend(&schema).await?;
                let model_id = schema.id.clone();
                let cache = Arc::clone(&self.mlx_backends);
                tokio::task::spawn_blocking(move || {
                    let _ = Self::stream_local_mlx(backend, cache, model_id, req, tx);
                });
                return Ok(rx);
            }

            #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
            {
                self.ensure_backend(&schema.name).await?;
                let backend = self.backend.clone();
                tokio::spawn(async move {
                    let _ = Self::stream_local_candle(backend, req, tx).await;
                });
                Ok(rx)
            }
        }
    }

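    // Consumption sketch for `generate_tracked_stream` above (illustrative;
    // `StreamEvent` may carry more variants than shown here):
    //
    //     let mut rx = engine.generate_tracked_stream(req).await?;
    //     while let Some(evt) = rx.recv().await {
    //         match evt {
    //             stream::StreamEvent::TextDelta(chunk) => print!("{chunk}"),
    //             stream::StreamEvent::Done { text, tool_calls } => { /* finalize */ }
    //             _ => {}
    //         }
    //     }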
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    async fn stream_local_candle(
        backend_lock: Arc<RwLock<Option<CandleBackend>>>,
        req: GenerateRequest,
        tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
    ) -> Result<(), InferenceError> {
        let mut write = backend_lock.write().await;
        let backend = write
            .as_mut()
            .ok_or_else(|| InferenceError::InferenceFailed("backend not initialized".into()))?;
        backend.clear_kv_cache();

        let formatted = tasks::generate::apply_chat_template(
            &req.prompt,
            req.context.as_deref(),
            req.params.thinking,
        );
        let tokens = backend.encode(&formatted)?;
        let eos = backend.eos_token_id();
        let eos_alt = backend.token_id("<|im_end|>");
        let params = &req.params;

        if tokens.is_empty() {
            let _ = tx
                .send(stream::StreamEvent::Done {
                    text: String::new(),
                    tool_calls: vec![],
                })
                .await;
            return Ok(());
        }

        let max_ctx = backend.context_length().unwrap_or(32768);
        let headroom = params.max_tokens.min(max_ctx / 4);
        let max_prompt = max_ctx.saturating_sub(headroom);
        let tokens = if tokens.len() > max_prompt {
            tokens[tokens.len() - max_prompt..].to_vec()
        } else {
            tokens
        };

        let mut generated = Vec::new();
        let logits = backend.forward(&tokens, 0)?;
        let mut next_token = tasks::generate::sample_token(&logits, params)?;

        for _ in 0..params.max_tokens {
            if eos.map_or(false, |id| next_token == id)
                || eos_alt.map_or(false, |id| next_token == id)
            {
                break;
            }

            generated.push(next_token);
            let delta = backend.decode(&[next_token])?;
            if !delta.is_empty()
                && tx
                    .send(stream::StreamEvent::TextDelta(delta))
                    .await
                    .is_err()
            {
                return Ok(());
            }

            if !params.stop.is_empty() {
                let text_so_far = backend.decode(&generated)?;
                if params.stop.iter().any(|s| text_so_far.contains(s)) {
                    break;
                }
            }

            let pos = tokens.len() + generated.len() - 1;
            let logits = backend.forward(&[next_token], pos)?;
            next_token = tasks::generate::sample_token(&logits, params)?;
        }

        let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
        let _ = tx
            .send(stream::StreamEvent::Done {
                text,
                tool_calls: vec![],
            })
            .await;
        Ok(())
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    fn stream_local_mlx(
        handle: backend_cache::CachedBackend<backend::MlxBackend>,
        cache: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
        model_id: String,
        req: GenerateRequest,
        tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
    ) -> Result<(), InferenceError> {
        let mut guard = handle.lock().map_err(|_| {
            InferenceError::InferenceFailed(format!("MLX backend mutex poisoned for {model_id}"))
        })?;
        let backend: &mut backend::MlxBackend = &mut *guard;
        backend.clear_kv_cache();

        let formatted = tasks::generate::apply_chat_template(
            &req.prompt,
            req.context.as_deref(),
            req.params.thinking,
        );
        let tokens = backend.encode(&formatted)?;
        let eos = backend.eos_token_id();
        let eos_alt = backend.token_id("<|im_end|>");
        let params = &req.params;

        if tokens.is_empty() {
            let _ = tx.blocking_send(stream::StreamEvent::Done {
                text: String::new(),
                tool_calls: vec![],
            });
            return Ok(());
        }

        let max_ctx = backend.context_length();
        let headroom = params.max_tokens.min(max_ctx / 4);
        let max_prompt = max_ctx.saturating_sub(headroom);
        let tokens = if tokens.len() > max_prompt {
            tokens[tokens.len() - max_prompt..].to_vec()
        } else {
            tokens
        };

        let mut generated = Vec::new();
        let logits = match Self::catch_mlx("stream prefill", || backend.forward(&tokens, 0)) {
            Ok(v) => v,
            Err(e) => {
                cache.invalidate(&model_id);
                return Err(e);
            }
        };
        let mut next_token = Self::sample_from_logits(&logits, params)?;

        for _ in 0..params.max_tokens {
            if eos.map_or(false, |id| next_token == id)
                || eos_alt.map_or(false, |id| next_token == id)
            {
                break;
            }

            generated.push(next_token);
            let delta = backend.decode(&[next_token])?;
            if !delta.is_empty()
                && tx
                    .blocking_send(stream::StreamEvent::TextDelta(delta))
                    .is_err()
            {
                return Ok(());
            }

            if !params.stop.is_empty() {
                let text_so_far = backend.decode(&generated)?;
                if params.stop.iter().any(|s| text_so_far.contains(s)) {
                    break;
                }
            }

            let pos = tokens.len() + generated.len() - 1;
            let logits =
                match Self::catch_mlx("stream forward", || backend.forward(&[next_token], pos)) {
                    Ok(v) => v,
                    Err(e) => {
                        cache.invalidate(&model_id);
                        return Err(e);
                    }
                };
            next_token = Self::sample_from_logits(&logits, params)?;
        }

        let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
        let _ = tx.blocking_send(stream::StreamEvent::Done {
            text,
            tool_calls: vec![],
        });
        Ok(())
    }

    pub async fn route_context_snapshot(
        &self,
        prompt: &str,
        workload: RoutingWorkload,
        has_tools: bool,
        has_vision: bool,
    ) -> AdaptiveRoutingDecision {
        let tracker = self.outcome_tracker.read().await;
        self.adaptive_router.route_context_aware(
            prompt,
            0,
            &self.unified_registry,
            &tracker,
            has_tools,
            has_vision,
            workload,
        )
    }

    pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
        Ok(self.generate_tracked(req).await?.text)
    }

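    // Call-site sketch for `generate`/`generate_tracked` above: `generate` is
    // the plain-text wrapper; prefer `generate_tracked` when latency and token
    // usage matter. Illustrative helper:
    //
    //     async fn summarize(engine: &InferenceEngine, req: GenerateRequest)
    //         -> Result<String, InferenceError> {
    //         let result = engine.generate_tracked(req).await?;
    //         tracing::debug!(model = %result.model_used, ms = result.latency_ms, "done");
    //         Ok(result.text)
    //     }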
1935 pub async fn tokenize(&self, model: &str, text: &str) -> Result<Vec<u32>, InferenceError> {
1947 self.assert_local_for_tokenize(model)?;
1948
1949 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1950 {
1951 let schema = self
1952 .unified_registry
1953 .get(model)
1954 .or_else(|| self.unified_registry.find_by_name(model))
1955 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1956 .clone();
1957 let handle = self.ensure_mlx_backend(&schema).await?;
1958 let guard = handle.lock().map_err(|_| {
1959 InferenceError::InferenceFailed(format!(
1960 "MLX backend mutex poisoned for {}",
1961 schema.id
1962 ))
1963 })?;
1964 return guard.tokenize_raw(text);
1965 }
1966
1967 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1968 {
1969 self.ensure_backend(model).await?;
1970 let read = self.backend.read().await;
1971 let backend = read.as_ref().ok_or_else(|| {
1972 InferenceError::InferenceFailed(
1973 "candle backend missing after ensure_backend".to_string(),
1974 )
1975 })?;
1976 backend.tokenize_raw(text)
1977 }
1978 }
1979
1980 pub async fn detokenize(&self, model: &str, tokens: &[u32]) -> Result<String, InferenceError> {
1982 self.assert_local_for_tokenize(model)?;
1983
1984 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1985 {
1986 let schema = self
1987 .unified_registry
1988 .get(model)
1989 .or_else(|| self.unified_registry.find_by_name(model))
1990 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1991 .clone();
1992 let handle = self.ensure_mlx_backend(&schema).await?;
1993 let guard = handle.lock().map_err(|_| {
1994 InferenceError::InferenceFailed(format!(
1995 "MLX backend mutex poisoned for {}",
1996 schema.id
1997 ))
1998 })?;
1999 return guard.detokenize_raw(tokens);
2000 }
2001
2002 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2003 {
2004 self.ensure_backend(model).await?;
2005 let read = self.backend.read().await;
2006 let backend = read.as_ref().ok_or_else(|| {
2007 InferenceError::InferenceFailed(
2008 "candle backend missing after ensure_backend".to_string(),
2009 )
2010 })?;
2011 backend.detokenize_raw(tokens)
2012 }
2013 }
2014
2015 fn assert_local_for_tokenize(&self, model: &str) -> Result<(), InferenceError> {
2019 if let Some(schema) = self
2020 .unified_registry
2021 .get(model)
2022 .or_else(|| self.unified_registry.find_by_name(model))
2023 {
2024 if !schema.is_local() {
2025 return Err(InferenceError::UnsupportedMode {
2026 mode: "tokenize/detokenize",
2027 backend: "remote",
2028 reason: "remote provider tokenizer is not exposed by the runtime; \
2029 use a local model (Qwen3 GGUF / MLX) for tokenizer-correctness checks",
2030 });
2031 }
2032 }
2033 Ok(())
2035 }
2036
2037 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2043 fn catch_mlx<F, T>(context: &str, f: F) -> Result<T, InferenceError>
2044 where
2045 F: FnOnce() -> Result<T, InferenceError>,
2046 {
2047 std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)).map_err(|e| {
2048 InferenceError::InferenceFailed(format!("MLX panicked during {context}: {e:?}"))
2049 })?
2050 }
2051
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    async fn generate_mlx(
        &self,
        req: GenerateRequest,
        model_id: &str,
    ) -> Result<(String, Option<u64>), InferenceError> {
        let start = std::time::Instant::now();

        let schema = self
            .unified_registry
            .get(model_id)
            .cloned()
            .ok_or_else(|| {
                InferenceError::InferenceFailed(format!(
                    "generate_mlx: unknown schema id {model_id}"
                ))
            })?;
        let handle = self.ensure_mlx_backend(&schema).await?;
        let mut guard = handle.lock().map_err(|_| {
            InferenceError::InferenceFailed(format!("MLX backend mutex poisoned for {model_id}"))
        })?;
        let backend: &mut backend::MlxBackend = &mut *guard;
        backend.clear_kv_cache();

        let formatted = tasks::generate::apply_chat_template(
            &req.prompt,
            req.context.as_deref(),
            req.params.thinking,
        );
        let tokens = backend.encode(&formatted)?;
        let eos = backend.eos_token_id();
        let eos_alt = backend.token_id("<|im_end|>");
        let params = &req.params;

        if tokens.is_empty() {
            return Ok((String::new(), None));
        }

        let max_ctx = backend.context_length();
        let headroom = params.max_tokens.min(max_ctx / 4);
        let max_prompt = max_ctx.saturating_sub(headroom);
        let tokens = if tokens.len() > max_prompt {
            tokens[tokens.len() - max_prompt..].to_vec()
        } else {
            tokens
        };

        let mut generated = Vec::new();

        let logits = match Self::catch_mlx("prefill", || backend.forward(&tokens, 0)) {
            Ok(v) => v,
            Err(e) => {
                drop(guard);
                self.mlx_backends.invalidate(model_id);
                return Err(e);
            }
        };
        let mut next_token = Self::sample_from_logits(&logits, params)?;
        let ttft_ms = Some(start.elapsed().as_millis() as u64);

        for _ in 0..params.max_tokens {
            if eos.map_or(false, |id| next_token == id)
                || eos_alt.map_or(false, |id| next_token == id)
            {
                break;
            }

            generated.push(next_token);

            if !params.stop.is_empty() {
                let text_so_far = backend.decode(&generated)?;
                if params.stop.iter().any(|s| text_so_far.contains(s)) {
                    break;
                }
            }

            let pos = tokens.len() + generated.len() - 1;
            let logits = match Self::catch_mlx("forward", || backend.forward(&[next_token], pos)) {
                Ok(v) => v,
                Err(e) => {
                    drop(guard);
                    self.mlx_backends.invalidate(model_id);
                    return Err(e);
                }
            };
            next_token = Self::sample_from_logits(&logits, params)?;
        }

        let text = backend.decode(&generated)?;
        Ok((
            tasks::generate::strip_thinking(&text, params.thinking),
            ttft_ms,
        ))
    }

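    /// Samples a token id from raw logits. `temperature <= 0.0` is greedy
    /// argmax; otherwise logits are softmaxed at the given temperature (with
    /// a max-logit shift for numerical stability), optionally truncated to the
    /// top-p nucleus and renormalized, then drawn by inverse CDF. A worked toy
    /// example (values rounded, purely illustrative): logits `[2.0, 1.0, 0.0]`
    /// at temperature 1.0 give probabilities `[0.665, 0.245, 0.090]`; with
    /// `top_p = 0.8` the nucleus keeps the first two entries (the cumulative
    /// sum 0.665 + 0.245 = 0.910 first exceeds 0.8 there), which renormalize
    /// to `[0.731, 0.269, 0.0]` before the draw.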
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    fn sample_from_logits(logits: &[f32], params: &GenerateParams) -> Result<u32, InferenceError> {
        if params.temperature <= 0.0 {
            let (idx, _) = logits
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .ok_or_else(|| InferenceError::InferenceFailed("empty logits".into()))?;
            return Ok(idx as u32);
        }

        let temp = params.temperature as f32;
        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut probs: Vec<f32> = logits
            .iter()
            .map(|&l| ((l - max_logit) / temp).exp())
            .collect();
        let sum: f32 = probs.iter().sum();
        for p in &mut probs {
            *p /= sum;
        }

        if params.top_p < 1.0 {
            let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let mut cumsum = 0.0;
            let mut cutoff_idx = indexed.len();
            for (i, &(_, p)) in indexed.iter().enumerate() {
                cumsum += p;
                if cumsum > params.top_p as f32 {
                    cutoff_idx = i + 1;
                    break;
                }
            }
            let allowed: std::collections::HashSet<usize> =
                indexed[..cutoff_idx].iter().map(|(i, _)| *i).collect();
            for (i, p) in probs.iter_mut().enumerate() {
                if !allowed.contains(&i) {
                    *p = 0.0;
                }
            }
            let sum: f32 = probs.iter().sum();
            if sum > 0.0 {
                for p in &mut probs {
                    *p /= sum;
                }
            }
        }

        use rand::Rng;
        let mut rng = rand::rng();
        let r: f32 = rng.random();
        let mut cumsum = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            cumsum += p;
            if cumsum >= r {
                return Ok(i as u32);
            }
        }
        Ok((probs.len() - 1) as u32)
    }

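    /// Embeds every entry in `req.texts`, one vector per input, using the MLX
    /// embedding backend on Apple Silicon and the Candle embedding backend
    /// elsewhere. Queries are instruction-prefixed via `embed_query`;
    /// documents go through `embed_one`. A usage sketch; `svc` is an assumed
    /// instance of this service and the struct literal is partial, so treat
    /// the exact field set as illustrative:
    ///
    /// ```ignore
    /// let req = EmbedRequest {
    ///     texts: vec!["the cat sat on the mat".to_string()],
    ///     is_query: false,
    ///     instruction: None, // defaults to "Retrieve relevant memory facts"
    ///     ..Default::default()
    /// };
    /// let vectors: Vec<Vec<f32>> = svc.embed(req).await?;
    /// ```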
    pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
        let instruction = req
            .instruction
            .as_deref()
            .unwrap_or("Retrieve relevant memory facts");

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            let model_id = self.ensure_mlx_embedding_backend().await?;
            let schema = self
                .unified_registry
                .get(&model_id)
                .cloned()
                .ok_or_else(|| {
                    InferenceError::InferenceFailed(format!("embed: unknown schema id {model_id}"))
                })?;
            let handle = self.ensure_mlx_backend(&schema).await?;
            let mut guard = handle.lock().map_err(|_| {
                InferenceError::InferenceFailed(format!(
                    "MLX embedding backend mutex poisoned for {model_id}"
                ))
            })?;
            let backend: &mut backend::MlxBackend = &mut *guard;

            let mut results = Vec::with_capacity(req.texts.len());
            for text in &req.texts {
                let embedding = if req.is_query {
                    backend.embed_query(text, instruction)?
                } else {
                    backend.embed_one(text)?
                };
                results.push(embedding);
            }
            return Ok(results);
        }

        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        {
            self.ensure_embedding_backend().await?;
            let mut write = self.embedding_backend.write().await;
            let backend = write.as_mut().ok_or_else(|| {
                InferenceError::InferenceFailed(
                    "embedding backend missing after ensure_embedding_backend".to_string(),
                )
            })?;

            let mut results = Vec::with_capacity(req.texts.len());
            for text in &req.texts {
                let embedding = if req.is_query {
                    backend.embed_query(text, instruction)?
                } else {
                    backend.embed_one(text)?
                };
                results.push(embedding);
            }
            Ok(results)
        }
    }

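    /// Reranks `req.documents` against `req.query` by prompting a
    /// Rerank-capable model once per document (temperature 0, three max
    /// tokens) and scoring each short completion with
    /// `score_from_rerank_output`. Results sort by descending score with the
    /// original index as a stable tie-break, then truncate to `top_n`. A usage
    /// sketch; `svc` and the literals are illustrative:
    ///
    /// ```ignore
    /// let result = svc
    ///     .rerank(RerankRequest {
    ///         query: "rust async mutex".to_string(),
    ///         documents: vec!["tokio::sync::Mutex".into(), "a cake recipe".into()],
    ///         top_n: Some(1),
    ///         model: None,       // fall back to the preferred Rerank model
    ///         instruction: None, // fall back to the web-search default
    ///     })
    ///     .await?;
    /// assert_eq!(result.ranked.len(), 1);
    /// ```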
    pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResult, InferenceError> {
        if req.documents.is_empty() {
            return Ok(RerankResult {
                ranked: Vec::new(),
                model_used: None,
            });
        }

        let model_name = match req.model.clone() {
            Some(m) => m,
            None => self
                .preferred_model_for_capability(ModelCapability::Rerank)
                .map(str::to_string)
                .ok_or_else(|| {
                    InferenceError::InferenceFailed(
                        "no reranker model available — pull a Qwen3-Reranker model first".into(),
                    )
                })?,
        };

        let schema = self
            .unified_registry
            .find_by_name(&model_name)
            .or_else(|| self.unified_registry.get(&model_name))
            .cloned()
            .ok_or_else(|| {
                InferenceError::InferenceFailed(format!(
                    "rerank: unknown reranker model {model_name}"
                ))
            })?;
        if !schema.has_capability(ModelCapability::Rerank) {
            return Err(InferenceError::InferenceFailed(format!(
                "model {} does not declare the Rerank capability",
                schema.name
            )));
        }

        let instruction = req.instruction.as_deref().unwrap_or(
            "Given a web search query, retrieve relevant passages that answer the query",
        );

        let mut scored: Vec<RerankedDocument> = Vec::with_capacity(req.documents.len());
        for (idx, doc) in req.documents.iter().enumerate() {
            let prompt = rerank_prompt(instruction, &req.query, doc);
            let gen_req = GenerateRequest {
                prompt,
                model: Some(schema.id.clone()),
                params: tasks::generate::GenerateParams {
                    temperature: 0.0,
                    max_tokens: 3,
                    thinking: tasks::generate::ThinkingMode::Off,
                    ..Default::default()
                },
                context: None,
                tools: None,
                images: None,
                messages: None,
                cache_control: false,
                response_format: None,
                intent: None,
            };
            let out = self.generate(gen_req).await?;
            let score = score_from_rerank_output(&out, &schema.name);
            scored.push(RerankedDocument {
                index: idx,
                score,
                document: doc.clone(),
            });
        }

        scored.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.index.cmp(&b.index))
        });
        if let Some(n) = req.top_n {
            scored.truncate(n);
        }

        Ok(RerankResult {
            ranked: scored,
            model_used: Some(schema.name),
        })
    }

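    /// Visual grounding: sends `req.prompt` plus a single image to a
    /// Grounding-capable model (a Qwen2.5-VL model by default) and returns the
    /// bounding boxes parsed from the tracked generation. A sketch; `svc` and
    /// the image payload are illustrative, and the expected image encoding is
    /// whatever `GroundRequest::image` carries elsewhere in this crate:
    ///
    /// ```ignore
    /// let result = svc
    ///     .ground(GroundRequest {
    ///         prompt: "Locate the stop sign".to_string(),
    ///         image: image_payload,
    ///         model: None,
    ///     })
    ///     .await?;
    /// for b in &result.boxes {
    ///     // inspect BoundingBox coordinates
    /// }
    /// ```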
    pub async fn ground(&self, req: GroundRequest) -> Result<GroundResult, InferenceError> {
        let model_name = match req.model.clone() {
            Some(m) => m,
            None => self
                .preferred_model_for_capability(ModelCapability::Grounding)
                .map(str::to_string)
                .ok_or_else(|| {
                    InferenceError::InferenceFailed(
                        "no grounding-capable model available — pull a Qwen2.5-VL model first"
                            .into(),
                    )
                })?,
        };

        let gen_req = GenerateRequest {
            prompt: req.prompt.clone(),
            model: Some(model_name),
            params: GenerateParams::default(),
            context: None,
            tools: None,
            images: Some(vec![req.image.clone()]),
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        };
        let result = self.generate_tracked(gen_req).await?;
        Ok(GroundResult {
            boxes: result.bounding_boxes,
            raw_text: result.text,
            model_used: Some(result.model_used),
        })
    }

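    /// Zero-shot classification. Model resolution order: explicit `req.model`,
    /// then the preferred Classify-capable model, then the router's
    /// small-model fallback. On Apple Silicon this is emulated through
    /// constrained generation (`classify_via_generate` below); elsewhere it
    /// uses the Candle classify task. A sketch with illustrative names:
    ///
    /// ```ignore
    /// let scores = svc
    ///     .classify(ClassifyRequest {
    ///         text: "please refund my last order".to_string(),
    ///         labels: vec!["billing".into(), "bug report".into()],
    ///         model: None,
    ///     })
    ///     .await?;
    /// // On the MLX path, scores come back sorted best-first and normalized
    /// // to sum to 1 whenever any label matched.
    /// ```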
    pub async fn classify(
        &self,
        req: ClassifyRequest,
    ) -> Result<Vec<ClassifyResult>, InferenceError> {
        let model = match req.model.clone().or_else(|| {
            self.preferred_model_for_capability(ModelCapability::Classify)
                .map(str::to_string)
        }) {
            Some(m) => m,
            None => {
                let m = self.router.route_small(&self.registry);
                debug!(model = %m, "auto-routed classify request");
                m
            }
        };

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            return self.classify_via_generate(req, &model).await;
        }

        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        {
            self.ensure_backend(&model).await?;
            let mut write = self.backend.write().await;
            let backend = write.as_mut().ok_or_else(|| {
                InferenceError::InferenceFailed(
                    "candle backend missing after ensure_backend".to_string(),
                )
            })?;
            tasks::classify::classify(backend, req).await
        }
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    async fn classify_via_generate(
        &self,
        req: ClassifyRequest,
        model: &str,
    ) -> Result<Vec<ClassifyResult>, InferenceError> {
        let labels_str = req
            .labels
            .iter()
            .enumerate()
            .map(|(i, l)| format!("{}. {}", i + 1, l))
            .collect::<Vec<_>>()
            .join("\n");

        let prompt = format!(
            "Classify the following text into one of these categories:\n\
             {labels_str}\n\n\
             Text: {}\n\n\
             Respond with ONLY the category name, nothing else.",
            req.text
        );

        let gen_req = GenerateRequest {
            prompt,
            model: Some(model.to_string()),
            params: tasks::generate::GenerateParams {
                temperature: 0.0,
                max_tokens: 32,
                thinking: tasks::generate::ThinkingMode::Off,
                ..Default::default()
            },
            context: None,
            tools: None,
            images: None,
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        };

        let response = self.generate(gen_req).await?;
        let response_lower = response.trim().to_lowercase();

        let mut results: Vec<ClassifyResult> = req
            .labels
            .iter()
            .map(|label| {
                let label_lower = label.to_lowercase();
                let score = if response_lower == label_lower {
                    1.0
                } else if response_lower.contains(&label_lower) {
                    0.8
                } else {
                    let label_words: Vec<&str> = label_lower.split_whitespace().collect();
                    let matches = label_words
                        .iter()
                        .filter(|w| response_lower.contains(**w))
                        .count();
                    if label_words.is_empty() {
                        0.0
                    } else {
                        0.5 * (matches as f64 / label_words.len() as f64)
                    }
                };
                ClassifyResult {
                    label: label.clone(),
                    score,
                }
            })
            .collect();

        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let total: f64 = results.iter().map(|r| r.score).sum();
        if total > 0.0 {
            for r in &mut results {
                r.score /= total;
            }
        }

        Ok(results)
    }

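    /// Speech-to-text over the policy-ordered candidate chain from
    /// `speech_candidates`: candidates are tried in turn (local MLX models
    /// natively, ElevenLabs models over HTTP), the first success wins, and the
    /// last error surfaces only if every candidate fails. A sketch; `svc` and
    /// the path are illustrative:
    ///
    /// ```ignore
    /// let out = svc
    ///     .transcribe(TranscribeRequest {
    ///         audio_path: "/tmp/clip.wav".to_string(),
    ///         model: None, // let the speech policy pick
    ///         language: Some("en".to_string()),
    ///         prompt: None,
    ///         timestamps: false,
    ///     })
    ///     .await?;
    /// println!("{}", out.text);
    /// ```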
    pub async fn transcribe(
        &self,
        req: TranscribeRequest,
    ) -> Result<TranscribeResult, InferenceError> {
        let candidates =
            self.speech_candidates(ModelCapability::SpeechToText, req.model.as_deref())?;
        let mut last_error = None;

        for schema in candidates {
            let result = match &schema.source {
                ModelSource::Mlx { .. } => self.transcribe_local_mlx(&schema, &req).await,
                ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
                    self.transcribe_elevenlabs(&schema, &req).await
                }
                _ => Err(InferenceError::InferenceFailed(format!(
                    "speech-to-text not implemented for model source: {}",
                    schema.id
                ))),
            };

            match result {
                Ok(result) => return Ok(result),
                Err(err) => last_error = Some(err),
            }
        }

        Err(last_error.unwrap_or_else(|| {
            InferenceError::InferenceFailed("no speech-to-text models available".into())
        }))
    }

    pub async fn synthesize(
        &self,
        req: SynthesizeRequest,
    ) -> Result<SynthesizeResult, InferenceError> {
        let candidates =
            self.speech_candidates(ModelCapability::TextToSpeech, req.model.as_deref())?;
        let mut last_error = None;

        for schema in candidates {
            let result = match &schema.source {
                ModelSource::Mlx { .. } => self.synthesize_local_mlx(&schema, &req).await,
                ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
                    self.synthesize_elevenlabs(&schema, &req).await
                }
                _ => Err(InferenceError::InferenceFailed(format!(
                    "text-to-speech not implemented for model source: {}",
                    schema.id
                ))),
            };

            match result {
                Ok(result) => return Ok(result),
                Err(err) => last_error = Some(err),
            }
        }

        Err(last_error.unwrap_or_else(|| {
            InferenceError::InferenceFailed("no text-to-speech models available".into())
        }))
    }

    pub async fn generate_image(
        &self,
        req: GenerateImageRequest,
    ) -> Result<GenerateImageResult, InferenceError> {
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            use crate::backend::external_flux;
            let backend =
                std::env::var("CAR_IMAGE_BACKEND").unwrap_or_else(|_| "native".to_string());
            let use_external = match backend.as_str() {
                "external" => true,
                "native" => false,
                "auto-external" => external_flux::is_available(),
                _ => false,
            };
            if use_external {
                tracing::info!(
                    "routing image generation to external mflux \
                     (set CAR_IMAGE_BACKEND=native to use the Rust port)"
                );
                let mut req = req;
                req.model = self.resolve_external_hf_repo(
                    req.model.as_deref(),
                    ModelCapability::ImageGeneration,
                );
                return external_flux::generate_image(&req);
            }
            tracing::info!("using native Rust MLX Flux backend");
        }

        let candidates = self
            .media_generation_candidates(ModelCapability::ImageGeneration, req.model.as_deref())?;
        let mut last_error = None;

        for schema in candidates {
            let result = match &schema.source {
                #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
                ModelSource::Mlx { .. } => self.generate_image_native_mlx(&schema, &req).await,
                _ => Err(InferenceError::InferenceFailed(format!(
                    "image generation not implemented for model source: {}",
                    schema.id
                ))),
            };

            match result {
                Ok(result) => return Ok(result),
                Err(err) => last_error = Some(err),
            }
        }

        Err(last_error.unwrap_or_else(|| {
            InferenceError::InferenceFailed("no image generation models available".into())
        }))
    }

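    /// Batch variant of `generate_image`: derives one deterministic seed per
    /// variant as `base_seed.wrapping_add(i)` so a whole batch is reproducible
    /// from a single seed, and forces `variant_count = 1` on each inner
    /// request so the recursion bottoms out. A sketch; `svc` and the request
    /// construction are illustrative:
    ///
    /// ```ignore
    /// req.seed = Some(42);
    /// req.variant_count = Some(3);
    /// // Runs three generations seeded 42, 43, 44.
    /// let variants = svc.generate_image_batch(req).await?;
    /// assert_eq!(variants.len(), 3);
    /// ```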
    pub async fn generate_image_batch(
        &self,
        req: GenerateImageRequest,
    ) -> Result<Vec<GenerateImageResult>, InferenceError> {
        let count = req.variant_count.unwrap_or(1).max(1);
        if count == 1 {
            return self.generate_image(req).await.map(|r| vec![r]);
        }
        let base_seed = req.seed.unwrap_or(0);
        let mut results = Vec::with_capacity(count as usize);
        for i in 0..count {
            let mut variant_req = req.clone();
            variant_req.seed = Some(base_seed.wrapping_add(i as u64));
            variant_req.variant_count = Some(1);
            results.push(self.generate_image(variant_req).await?);
        }
        Ok(results)
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    async fn generate_image_native_mlx(
        &self,
        schema: &ModelSchema,
        req: &GenerateImageRequest,
    ) -> Result<GenerateImageResult, InferenceError> {
        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
        let size = backend_cache::estimate_model_size(&model_dir);
        let cache = Arc::clone(&self.flux_cache);
        let key = schema.id.clone();
        let handle = cache.get_or_load(&key, size, || {
            backend::mlx_flux::FluxBackend::load(&model_dir)
        })?;
        let req = req.clone();
        tokio::task::spawn_blocking(move || -> Result<GenerateImageResult, InferenceError> {
            let mut guard = handle.lock().map_err(|_| {
                InferenceError::InferenceFailed("flux backend mutex poisoned".into())
            })?;
            guard.generate(&req)
        })
        .await
        .map_err(|e| InferenceError::InferenceFailed(format!("flux task join: {e}")))?
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    async fn generate_video_native_mlx(
        &self,
        schema: &ModelSchema,
        req: &GenerateVideoRequest,
    ) -> Result<GenerateVideoResult, InferenceError> {
        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
        let size = backend_cache::estimate_model_size(&model_dir);
        let cache = Arc::clone(&self.ltx_cache);
        let key = schema.id.clone();
        let handle = cache.get_or_load(&key, size, || {
            backend::mlx_ltx::LtxBackend::load(&model_dir)
        })?;
        let req = req.clone();
        tokio::task::spawn_blocking(move || -> Result<GenerateVideoResult, InferenceError> {
            let mut guard = handle.lock().map_err(|_| {
                InferenceError::InferenceFailed("ltx backend mutex poisoned".into())
            })?;
            guard.generate(&req)
        })
        .await
        .map_err(|e| InferenceError::InferenceFailed(format!("ltx task join: {e}")))?
    }

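    /// Video generation with family-aware routing: Wan-family MLX models go
    /// through the external `mlx_video` helper, while other MLX models pick
    /// the native LTX backend or the external `ltx-2-mlx` CLI based on
    /// `CAR_VIDEO_BACKEND` ("native", the default; "external"; or
    /// "auto-external", which probes the CLI). Audio-reference conditioning
    /// always requires the external CLI. A sketch; `svc` and the request are
    /// illustrative:
    ///
    /// ```ignore
    /// // With CAR_VIDEO_BACKEND=auto-external set in the environment, LTX
    /// // requests route to the CLI when it is on PATH and use the native
    /// // backend otherwise.
    /// let clip = svc.generate_video(req).await?; // req: GenerateVideoRequest
    /// ```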
    pub async fn generate_video(
        &self,
        req: GenerateVideoRequest,
    ) -> Result<GenerateVideoResult, InferenceError> {
        if let Err(msg) = req.validate() {
            return Err(InferenceError::InferenceFailed(format!(
                "invalid GenerateVideoRequest: {}",
                msg
            )));
        }
        let requires_audio_conditioning = req.requires_audio_passthrough_opt_in();
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            use crate::backend::external_ltx;
            let backend =
                std::env::var("CAR_VIDEO_BACKEND").unwrap_or_else(|_| "native".to_string());
            let use_external = match backend.as_str() {
                "external" => true,
                "native" => false,
                "auto-external" => external_ltx::is_available(),
                _ => false,
            };
            if use_external {
                tracing::info!(
                    "CAR_VIDEO_BACKEND requested external LTX routing for LTX-family models"
                );
            } else {
                tracing::info!("using family-aware MLX video routing");
            }
        }

        let candidates = self
            .media_generation_candidates(ModelCapability::VideoGeneration, req.model.as_deref())?;
        let mut last_error = None;

        for schema in candidates {
            let result = match &schema.source {
                #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
                ModelSource::Mlx { hf_repo, .. } => {
                    if crate::backend::external_mlx_video::is_wan_family(&schema) {
                        match self.unified_registry.ensure_local(&schema.id).await {
                            Ok(model_dir) => {
                                crate::backend::external_mlx_video::generate_wan_video(
                                    &schema, &model_dir, &req,
                                )
                            }
                            Err(err) => Err(err),
                        }
                    } else {
                        let backend = std::env::var("CAR_VIDEO_BACKEND")
                            .unwrap_or_else(|_| "native".to_string());
                        let use_external_ltx = match backend.as_str() {
                            "external" => true,
                            "native" => false,
                            "auto-external" => crate::backend::external_ltx::is_available(),
                            _ => false,
                        };
                        let use_external_ltx = use_external_ltx || requires_audio_conditioning;
                        if requires_audio_conditioning
                            && !crate::backend::external_ltx::is_available()
                        {
                            return Err(InferenceError::InferenceFailed(
                                "audio-reference video conditioning requires the external `ltx-2-mlx a2v` CLI on PATH"
                                    .to_string(),
                            ));
                        }
                        if use_external_ltx {
                            let mut req = req.clone();
                            req.model = Some(hf_repo.clone());
                            crate::backend::external_ltx::generate_video(&req)
                        } else {
                            self.generate_video_native_mlx(&schema, &req).await
                        }
                    }
                }
                _ => Err(InferenceError::InferenceFailed(format!(
                    "video generation not implemented for model source: {}",
                    schema.id
                ))),
            };

            match result {
                Ok(result) => return Ok(result),
                Err(err) => last_error = Some(err),
            }
        }

        Err(last_error.unwrap_or_else(|| {
            InferenceError::InferenceFailed("no video generation models available".into())
        }))
    }

    pub fn list_models_unified(&self) -> Vec<ModelInfo> {
        self.unified_registry
            .list()
            .iter()
            .map(|m| ModelInfo::from(*m))
            .collect()
    }

    pub fn available_model_upgrades(&self) -> Vec<ModelUpgrade> {
        self.unified_registry.available_upgrades()
    }

    pub fn list_schemas(&self) -> Vec<ModelSchema> {
        self.unified_registry.list().into_iter().cloned().collect()
    }

    pub fn list_models(&self) -> Vec<models::ModelInfo> {
        self.registry.list_models()
    }

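    /// Resolves `name` (display name first, then schema id) and downloads the
    /// model's local artifacts, returning the on-disk directory. Its inverse,
    /// `remove_model` below, also evicts the matching Hugging Face repo cache
    /// entries. A sketch with an illustrative model name:
    ///
    /// ```ignore
    /// let dir = svc.pull_model("some-registered-model").await?;
    /// println!("model files at {}", dir.display());
    /// ```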
    pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
        let schema = self
            .unified_registry
            .find_by_name(name)
            .or_else(|| self.unified_registry.get(name))
            .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
        self.unified_registry.ensure_local(&schema.id).await
    }

    pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
        let schema = self
            .unified_registry
            .get(name)
            .or_else(|| {
                self.unified_registry
                    .list()
                    .into_iter()
                    .find(|schema| schema.name.eq_ignore_ascii_case(name))
            })
            .or_else(|| self.unified_registry.find_by_name(name))
            .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
        let model_dir = self.unified_registry.models_dir().join(&schema.name);
        if model_dir.exists() {
            std::fs::remove_dir_all(&model_dir)?;
        }
        match &schema.source {
            ModelSource::Mlx { hf_repo, .. } => {
                remove_huggingface_repo_cache(hf_repo)?;
            }
            ModelSource::Local {
                hf_repo,
                tokenizer_repo,
                ..
            } => {
                remove_huggingface_repo_cache(hf_repo)?;
                remove_huggingface_repo_cache(tokenizer_repo)?;
            }
            _ => {}
        }
        Ok(())
    }

    pub fn register_model(&mut self, schema: ModelSchema) {
        self.unified_registry.register(schema);
    }

    pub async fn discover_vllm_mlx_models(&mut self) -> usize {
        let config = vllm_mlx::VllmMlxConfig::default();
        if !config.auto_discover {
            return 0;
        }
        vllm_mlx::discover_and_register(&config, &mut self.unified_registry).await
    }

    pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
        self.outcome_tracker.clone()
    }

    async fn auto_save_outcomes(&self) {
        if let Err(e) = self.save_outcomes().await {
            tracing::debug!("auto-save outcomes failed: {}", e);
        }
        if let Err(e) = self.save_key_pool_stats().await {
            tracing::debug!("auto-save key pool stats failed: {}", e);
        }
    }

    pub async fn save_outcomes(&self) -> Result<(), std::io::Error> {
        let tracker = self.outcome_tracker.read().await;
        let path = self.config.models_dir.join("outcome_profiles.json");
        tracker.save_to_file(&path)
    }

    pub async fn save_key_pool_stats(&self) -> Result<(), std::io::Error> {
        let path = self.config.models_dir.join("key_pool_stats.json");
        self.remote_backend.key_pool.save_stats(&path).await
    }

    pub async fn key_pool_stats(
        &self,
    ) -> std::collections::HashMap<String, Vec<key_pool::KeyStats>> {
        self.remote_backend.key_pool.all_stats().await
    }

    pub async fn export_profiles(&self) -> Vec<ModelProfile> {
        let tracker = self.outcome_tracker.read().await;
        tracker.export_profiles()
    }

    pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
        let mut tracker = self.outcome_tracker.write().await;
        tracker.import_profiles(profiles);
    }

    pub async fn prepare_speech_runtime(&self) -> Result<PathBuf, InferenceError> {
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            Ok(self.config.models_dir.clone())
        }
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        {
            Ok(self.ensure_speech_runtime().await?.root)
        }
    }

    pub fn set_speech_policy(&mut self, policy: SpeechPolicy) {
        self.speech_policy = policy;
    }

    pub fn set_routing_config(&mut self, config: RoutingConfig) {
        self.adaptive_router.set_config(config);
    }

    pub async fn install_curated_speech(
        &mut self,
    ) -> Result<Vec<SpeechInstallReport>, InferenceError> {
        let _runtime_root = self.prepare_speech_runtime().await?;
        let schemas = self.list_schemas();
        let mut repos = Vec::new();
        for schema in &schemas {
            if !schema.is_mlx() || !schema.tags.iter().any(|tag| tag == "speech") {
                continue;
            }
            if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
                if !repos.iter().any(|existing: &String| existing == hf_repo) {
                    repos.push(hf_repo.clone());
                }
            }
        }

        let mut installed = Vec::new();
        for repo in repos {
            let (snapshot_path, files_downloaded) = download_hf_repo_snapshot(&repo).await?;
            let name = schemas
                .iter()
                .find(|schema| {
                    matches!(&schema.source, ModelSource::Mlx { hf_repo, .. } if hf_repo == &repo)
                })
                .map(|schema| schema.name.clone())
                .unwrap_or_else(|| repo.clone());
            installed.push(SpeechInstallReport {
                name,
                hf_repo: repo,
                snapshot_path,
                files_downloaded,
            });
        }

        self.unified_registry.refresh_availability();
        Ok(installed)
    }

    pub fn speech_health(&self) -> SpeechHealthReport {
        let local_stt_default =
            self.speech_health_default_name(ModelCapability::SpeechToText, true, false);
        let local_tts_default =
            self.speech_health_default_name(ModelCapability::TextToSpeech, true, false);
        let remote_stt_default =
            self.speech_health_default_name(ModelCapability::SpeechToText, false, true);
        let remote_tts_default =
            self.speech_health_default_name(ModelCapability::TextToSpeech, false, true);

        let mut local_models = Vec::new();
        let mut remote_models = Vec::new();
        for schema in self.list_schemas() {
            let capability = if schema.has_capability(ModelCapability::SpeechToText) {
                Some(ModelCapability::SpeechToText)
            } else if schema.has_capability(ModelCapability::TextToSpeech) {
                Some(ModelCapability::TextToSpeech)
            } else {
                None
            };
            let Some(capability) = capability else {
                continue;
            };

            let selected_by_default = local_stt_default
                .as_ref()
                .is_some_and(|name| name == &schema.name)
                || local_tts_default
                    .as_ref()
                    .is_some_and(|name| name == &schema.name)
                || remote_stt_default
                    .as_ref()
                    .is_some_and(|name| name == &schema.name)
                || remote_tts_default
                    .as_ref()
                    .is_some_and(|name| name == &schema.name);

            let health = SpeechModelHealth {
                id: schema.id.clone(),
                name: schema.name.clone(),
                provider: schema.provider.clone(),
                capability,
                is_local: schema.is_local(),
                available: schema.available,
                cached: speech_model_cached(&schema),
                selected_by_default,
                source: speech_model_source_label(&schema),
            };
            if schema.is_local() {
                local_models.push(health);
            } else {
                remote_models.push(health);
            }
        }

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        let runtime = SpeechRuntimeHealth {
            root: self.config.models_dir.clone(),
            installed: true,
            python: PathBuf::new(),
            stt_command: PathBuf::new(),
            tts_command: PathBuf::new(),
            configured_python: None,
            detected_python: None,
        };

        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        let runtime = {
            let rt =
                SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
            SpeechRuntimeHealth {
                root: rt.root.clone(),
                installed: rt.is_ready(),
                python: rt.python.clone(),
                stt_command: rt.stt_program.clone(),
                tts_command: rt.tts_program.clone(),
                configured_python: std::env::var("CAR_SPEECH_PYTHON")
                    .ok()
                    .filter(|value| !value.trim().is_empty()),
                detected_python: detect_speech_python(),
            }
        };

        SpeechHealthReport {
            runtime,
            local_models,
            remote_models,
            elevenlabs_configured: car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY")
                .is_some(),
            prefer_local: self.speech_policy.prefer_local,
            allow_remote_fallback: self.speech_policy.allow_remote_fallback,
            preferred_local_stt: self.speech_policy.preferred_local_stt.clone(),
            preferred_local_tts: self.speech_policy.preferred_local_tts.clone(),
            preferred_remote_stt: self.speech_policy.preferred_remote_stt.clone(),
            preferred_remote_tts: self.speech_policy.preferred_remote_tts.clone(),
            local_stt_default,
            local_tts_default,
            remote_stt_default,
            remote_tts_default,
        }
    }

    pub async fn model_health(&self) -> ModelHealthReport {
        let schemas = self.list_schemas();
        let total_models = schemas.len();
        let available_models = schemas.iter().filter(|schema| schema.available).count();
        let local_models = schemas.iter().filter(|schema| schema.is_local()).count();
        let remote_models = total_models.saturating_sub(local_models);

        let defaults = vec![
            self.model_default_health(
                ModelCapability::Generate,
                self.preferred_model_for_capability(ModelCapability::Generate)
                    .unwrap_or(&self.config.generation_model),
            ),
            self.model_default_health(
                ModelCapability::Embed,
                self.preferred_model_for_capability(ModelCapability::Embed)
                    .unwrap_or(&self.config.embedding_model),
            ),
            self.model_default_health(
                ModelCapability::Classify,
                self.preferred_model_for_capability(ModelCapability::Classify)
                    .unwrap_or(&self.config.classification_model),
            ),
        ];

        let mut providers = std::collections::BTreeMap::new();
        for schema in &schemas {
            let entry =
                providers
                    .entry(schema.provider.clone())
                    .or_insert_with(|| ProviderAccumulator {
                        configured: false,
                        local_models: 0,
                        remote_models: 0,
                        available_models: 0,
                        capabilities: std::collections::HashSet::new(),
                    });

            entry.configured |= model_source_configured(schema);
            if schema.is_local() {
                entry.local_models += 1;
            } else {
                entry.remote_models += 1;
            }
            if schema.available {
                entry.available_models += 1;
            }
            for capability in &schema.capabilities {
                entry.capabilities.insert(*capability);
            }
        }

        let providers = providers
            .into_iter()
            .map(|(provider, acc)| ModelProviderHealth {
                provider,
                configured: acc.configured,
                local_models: acc.local_models,
                remote_models: acc.remote_models,
                available_models: acc.available_models,
                capabilities: sort_capabilities(acc.capabilities.into_iter().collect()),
            })
            .collect();

        let capabilities = all_model_capabilities()
            .into_iter()
            .map(|capability| {
                let relevant: Vec<&ModelSchema> = schemas
                    .iter()
                    .filter(|schema| schema.has_capability(capability))
                    .collect();
                let available: Vec<&ModelSchema> = relevant
                    .iter()
                    .copied()
                    .filter(|schema| schema.available)
                    .collect();
                ModelCapabilityHealth {
                    capability,
                    total_models: relevant.len(),
                    available_models: available.len(),
                    local_available_models: available
                        .iter()
                        .filter(|schema| schema.is_local())
                        .count(),
                    remote_available_models: available
                        .iter()
                        .filter(|schema| !schema.is_local())
                        .count(),
                }
            })
            .collect();

        let routing = self.routing_scenarios().await;
        let routing_config = self.adaptive_router.config().clone();
        let benchmark_priors = load_benchmark_prior_health(&self.config.models_dir, &schemas);

        ModelHealthReport {
            total_models,
            available_models,
            local_models,
            remote_models,
            defaults,
            providers,
            capabilities,
            routing_prefer_local: routing_config.prefer_local,
            routing_quality_first_cold_start: routing_config.quality_first_cold_start,
            routing_min_observations: routing_config.min_observations,
            routing_bootstrap_min_task_observations: routing_config.bootstrap_min_task_observations,
            routing_bootstrap_quality_floor: routing_config.bootstrap_quality_floor,
            routing_quality_weight: routing_config.quality_weight,
            routing_latency_weight: routing_config.latency_weight,
            routing_cost_weight: routing_config.cost_weight,
            routing_scenarios: routing,
            benchmark_priors,
            speech: self.speech_health(),
        }
    }

    async fn routing_scenarios(&self) -> Vec<RoutingScenarioHealth> {
        let tracker = self.outcome_tracker.read().await;
        let config = self.adaptive_router.config().clone();
        let scenarios = [
            (
                "interactive_text",
                "Summarize the benefits of local-first AI routing in two sentences.",
                "text",
                RoutingWorkload::Interactive,
                false,
                false,
            ),
            (
                "background_code",
                "Write a Python function named fibonacci(n) that returns the nth Fibonacci number.",
                "code",
                RoutingWorkload::Background,
                false,
                false,
            ),
            (
                "interactive_tool_use",
                "Use the provided weather tool to get the weather for Boston.",
                "tool_use",
                RoutingWorkload::Interactive,
                true,
                false,
            ),
            (
                "interactive_vision",
                "What is in this image? Answer in one word.",
                "vision",
                RoutingWorkload::Interactive,
                false,
                true,
            ),
        ];

        scenarios
            .into_iter()
            .map(
                |(name, prompt, task_family, workload, has_tools, has_vision)| {
                    let decision = self.adaptive_router.route_context_aware(
                        prompt,
                        0,
                        &self.unified_registry,
                        &tracker,
                        has_tools,
                        has_vision,
                        workload,
                    );
                    let quality_first_cold_start = if has_tools || has_vision {
                        config.quality_first_cold_start
                    } else if task_family == "code"
                        && matches!(workload, RoutingWorkload::Background)
                    {
                        false
                    } else {
                        config.quality_first_cold_start
                    };
                    RoutingScenarioHealth {
                        name: name.to_string(),
                        task_family: task_family.to_string(),
                        workload,
                        has_tools,
                        has_vision,
                        prefer_local: if task_family == "speech" {
                            self.speech_policy.prefer_local
                        } else {
                            config.prefer_local
                        },
                        quality_first_cold_start,
                        bootstrap_min_task_observations: config.bootstrap_min_task_observations,
                        bootstrap_quality_floor: config.bootstrap_quality_floor,
                        model_id: decision.model_id,
                        model_name: decision.model_name,
                        reason: decision.reason,
                        strategy: decision.strategy,
                    }
                },
            )
            .collect()
    }

    pub async fn smoke_test_speech(
        &self,
        local: bool,
        remote: bool,
    ) -> Result<SpeechSmokeReport, InferenceError> {
        let mut report = SpeechSmokeReport::default();

        if local {
            let tts = self
                .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
                .ok_or_else(|| {
                    InferenceError::InferenceFailed(
                        "no local text-to-speech model available".into(),
                    )
                })?;
            let stt = self
                .preferred_speech_schema(ModelCapability::SpeechToText, true, false)
                .ok_or_else(|| {
                    InferenceError::InferenceFailed(
                        "no local speech-to-text model available".into(),
                    )
                })?;
            report.local = Some(
                self.run_speech_smoke_path("local", &tts, &stt, "Testing CAR local speech path.")
                    .await?,
            );
        } else {
            report.skipped.push("local".to_string());
        }

        if remote {
            let tts = self
                .preferred_speech_schema(ModelCapability::TextToSpeech, false, true)
                .ok_or_else(|| {
                    InferenceError::InferenceFailed(
                        "no remote text-to-speech model available".into(),
                    )
                })?;
            let stt = self
                .preferred_speech_schema(ModelCapability::SpeechToText, false, true)
                .ok_or_else(|| {
                    InferenceError::InferenceFailed(
                        "no remote speech-to-text model available".into(),
                    )
                })?;
            report.remote = Some(
                self.run_speech_smoke_path("remote", &tts, &stt, "Testing CAR remote speech path.")
                    .await?,
            );
        } else {
            report.skipped.push("remote".to_string());
        }

        Ok(report)
    }

    fn speech_candidates(
        &self,
        capability: ModelCapability,
        explicit: Option<&str>,
    ) -> Result<Vec<ModelSchema>, InferenceError> {
        if let Some(model) = explicit {
            let schema = self
                .unified_registry
                .get(model)
                .or_else(|| self.unified_registry.find_by_name(model))
                .cloned()
                .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
            if !schema.has_capability(capability) {
                return Err(InferenceError::InferenceFailed(format!(
                    "model {} does not support {:?}",
                    schema.name, capability
                )));
            }
            return Ok(vec![schema]);
        }

        let mut candidates: Vec<ModelSchema> = self
            .unified_registry
            .query(&ModelFilter {
                capabilities: vec![capability],
                ..Default::default()
            })
            .into_iter()
            .cloned()
            .collect();

        if candidates.is_empty() {
            return Err(InferenceError::InferenceFailed(format!(
                "no models registered for capability {:?}",
                capability
            )));
        }

        candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
        if !self.speech_policy.allow_remote_fallback
            && candidates.iter().any(|model| model.is_local())
        {
            candidates.retain(|model| model.is_local());
        }

        Ok(candidates)
    }

    fn resolve_external_hf_repo(
        &self,
        explicit: Option<&str>,
        capability: ModelCapability,
    ) -> Option<String> {
        let id = explicit?;
        let schema = self
            .unified_registry
            .get(id)
            .or_else(|| self.unified_registry.find_by_name(id))?;
        if !schema.has_capability(capability) {
            return Some(id.to_string());
        }
        if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
            return Some(hf_repo.clone());
        }
        Some(id.to_string())
    }

    fn media_generation_candidates(
        &self,
        capability: ModelCapability,
        explicit: Option<&str>,
    ) -> Result<Vec<ModelSchema>, InferenceError> {
        if let Some(model) = explicit {
            let schema = self
                .unified_registry
                .get(model)
                .or_else(|| self.unified_registry.find_by_name(model))
                .cloned()
                .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
            if !schema.has_capability(capability) {
                return Err(InferenceError::InferenceFailed(format!(
                    "model {} does not support {:?}",
                    schema.name, capability
                )));
            }
            return Ok(vec![schema]);
        }

        let mut candidates: Vec<ModelSchema> = self
            .unified_registry
            .query(&ModelFilter {
                capabilities: vec![capability],
                local_only: true,
                ..Default::default()
            })
            .into_iter()
            .cloned()
            .collect();
        candidates.sort_by_key(|schema| (!schema.available, schema.size_mb()));
        if candidates.is_empty() {
            return Err(InferenceError::InferenceFailed(format!(
                "no models registered for capability {:?}",
                capability
            )));
        }
        Ok(candidates)
    }

    fn preferred_speech_schema(
        &self,
        capability: ModelCapability,
        local_only: bool,
        remote_only: bool,
    ) -> Option<ModelSchema> {
        let available_only = remote_only;
        let mut candidates: Vec<ModelSchema> = self
            .unified_registry
            .query(&ModelFilter {
                capabilities: vec![capability],
                available_only,
                ..Default::default()
            })
            .into_iter()
            .filter(|schema| {
                (!local_only || schema.is_local()) && (!remote_only || schema.is_remote())
            })
            .cloned()
            .collect();
        candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
        candidates.into_iter().next()
    }

    fn speech_health_default_name(
        &self,
        capability: ModelCapability,
        local_only: bool,
        remote_only: bool,
    ) -> Option<String> {
        let preferred = match capability {
            ModelCapability::SpeechToText if local_only => {
                self.speech_policy.preferred_local_stt.as_ref()
            }
            ModelCapability::SpeechToText if remote_only => {
                self.speech_policy.preferred_remote_stt.as_ref()
            }
            ModelCapability::TextToSpeech if local_only => {
                self.speech_policy.preferred_local_tts.as_ref()
            }
            ModelCapability::TextToSpeech if remote_only => {
                self.speech_policy.preferred_remote_tts.as_ref()
            }
            _ => None,
        };

        preferred
            .filter(|name| {
                self.unified_registry.list().iter().any(|schema| {
                    schema.name == **name
                        && schema.has_capability(capability)
                        && (!local_only || schema.is_local())
                        && (!remote_only || schema.is_remote())
                })
            })
            .cloned()
            .or_else(|| {
                self.preferred_speech_schema(capability, local_only, remote_only)
                    .map(|schema| schema.name)
            })
    }

    fn model_default_health(
        &self,
        capability: ModelCapability,
        configured_model: &str,
    ) -> ModelDefaultHealth {
        let schema = self
            .unified_registry
            .find_by_name(configured_model)
            .or_else(|| self.unified_registry.get(configured_model));

        ModelDefaultHealth {
            capability,
            configured_model: configured_model.to_string(),
            available: schema.is_some_and(|model| model.available),
            is_local: schema.is_some_and(ModelSchema::is_local),
            provider: schema.map(|model| model.provider.clone()),
        }
    }

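    /// Builds the lexicographic sort key that orders speech candidates. Tuple
    /// fields compare left to right: locality per policy, availability,
    /// explicit policy preference, a curated per-model rank, then observed p50
    /// latency and model size as final tie-breaks (missing metrics sort last
    /// via `u64::MAX`). For example, with `prefer_local = true`, a local,
    /// available, non-preferred Kokoro-82M-bf16 keys as
    /// `(0, 0, 1, 1, latency, size)` and beats every remote candidate, whose
    /// first component is already 1.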
    fn speech_sort_key(
        &self,
        capability: ModelCapability,
        model: &ModelSchema,
    ) -> (u8, u8, u8, u8, u64, u64) {
        let policy_preference = match capability {
            ModelCapability::SpeechToText if model.is_local() => {
                self.speech_policy.preferred_local_stt.as_ref()
            }
            ModelCapability::SpeechToText => self.speech_policy.preferred_remote_stt.as_ref(),
            ModelCapability::TextToSpeech if model.is_local() => {
                self.speech_policy.preferred_local_tts.as_ref()
            }
            ModelCapability::TextToSpeech => self.speech_policy.preferred_remote_tts.as_ref(),
            _ => None,
        };
        let local_rank = if self.speech_policy.prefer_local {
            if model.is_local() {
                0
            } else {
                1
            }
        } else if model.is_remote() {
            0
        } else {
            1
        };
        let availability_rank = if model.available {
            0
        } else if model.is_local() {
            1
        } else {
            2
        };
        let policy_rank: u8 = if policy_preference.is_some_and(|preferred| preferred == &model.name)
        {
            0
        } else {
            1
        };
        let speech_rank = match capability {
            ModelCapability::TextToSpeech => {
                if model.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
                    0
                } else if model.name == "Kokoro-82M-bf16" {
                    1
                } else if model.name == "Kokoro-82M-6bit" {
                    2
                } else {
                    3
                }
            }
            ModelCapability::SpeechToText => {
                if model.name == "Parakeet-TDT-0.6B-v3-MLX" {
                    0
                } else {
                    1
                }
            }
            _ => 0,
        };
        let latency_rank = model.performance.latency_p50_ms.unwrap_or(u64::MAX);
        let size_rank = model.cost.size_mb.unwrap_or(u64::MAX);
        (
            local_rank,
            availability_rank,
            policy_rank,
            speech_rank,
            latency_rank,
            size_rank,
        )
    }

    async fn run_speech_smoke_path(
        &self,
        path: &str,
        tts: &ModelSchema,
        stt: &ModelSchema,
        text: &str,
    ) -> Result<SpeechSmokePathReport, InferenceError> {
        let work_dir = temp_work_dir(&format!("speech-smoke-{path}"))?;
        let audio_path = work_dir.join(format!("{path}.wav"));
        let synth = self
            .synthesize(SynthesizeRequest {
                text: text.to_string(),
                model: Some(tts.name.clone()),
                voice: default_speech_voice(tts),
                language: Some("en".to_string()),
                output_path: Some(audio_path.display().to_string()),
                ..SynthesizeRequest::default()
            })
            .await?;
        let transcript = self
            .transcribe(TranscribeRequest {
                audio_path: synth.audio_path.clone(),
                model: Some(stt.name.clone()),
                language: Some("en".to_string()),
                prompt: None,
                timestamps: false,
            })
            .await?;

        Ok(SpeechSmokePathReport {
            path: path.to_string(),
            tts_model: synth.model_used.unwrap_or_else(|| tts.name.clone()),
            stt_model: transcript.model_used.unwrap_or_else(|| stt.name.clone()),
            audio_path: PathBuf::from(synth.audio_path),
            transcript: transcript.text,
        })
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    async fn ensure_speech_runtime(&self) -> Result<SpeechRuntime, InferenceError> {
        let mut guard = self.speech_runtime.lock().await;
        if let Some(runtime) = guard.as_ref() {
            if runtime.is_ready() {
                return Ok(runtime.clone());
            }
        }

        let runtime =
            SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
        if !runtime.is_ready() {
            bootstrap_speech_runtime(&runtime).await?;
        }
        if !runtime.is_ready() {
            return Err(InferenceError::InferenceFailed(format!(
                "managed speech runtime is not ready at {}",
                runtime.root.display()
            )));
        }

        *guard = Some(runtime.clone());
        Ok(runtime)
    }

    async fn transcribe_local_mlx(
        &self,
        schema: &ModelSchema,
        req: &TranscribeRequest,
    ) -> Result<TranscribeResult, InferenceError> {
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
            let parakeet = backend::mlx_parakeet::ParakeetBackend::load(&model_dir)?;
            let (text, words) = if req.timestamps {
                parakeet
                    .transcribe_detailed(Path::new(&req.audio_path))
                    .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?
            } else {
                let t = parakeet
                    .transcribe(Path::new(&req.audio_path))
                    .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?;
                (t, Vec::new())
            };
            return Ok(TranscribeResult {
                text,
                model_used: Some(schema.name.clone()),
                language: req.language.clone(),
                words,
            });
        }

        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        {
            let runtime = self.ensure_speech_runtime().await?;
            let hf_repo = match &schema.source {
                ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
                _ => unreachable!(),
            };
            let output_dir = temp_work_dir("stt")?;
            let output_prefix = output_dir.join("transcript");
            let mut args = vec![
                "--model".to_string(),
                hf_repo,
                "--audio".to_string(),
                req.audio_path.clone(),
                "--output-path".to_string(),
                output_prefix.display().to_string(),
                "--format".to_string(),
                "json".to_string(),
            ];
            if let Some(language) = &req.language {
                args.push("--language".to_string());
                args.push(normalize_lang_code(language));
            }
            if let Some(prompt) = &req.prompt {
                args.push("--context".to_string());
                args.push(prompt.clone());
            }
            if req.timestamps {
                args.push("--verbose".to_string());
            }

            let output = run_mlx_audio_command(&runtime, "stt.generate", &args).await?;
            let text = read_transcription_result(&output_prefix)?
                .or_else(|| extract_text_from_payload(&output.stdout))
                .ok_or_else(|| {
                    InferenceError::InferenceFailed(format!(
                        "mlx-audio transcription returned no text: {}",
                        output.stderr
                    ))
                })?;

            Ok(TranscribeResult {
                text,
                model_used: Some(schema.name.clone()),
                language: req.language.clone(),
                words: Vec::new(),
            })
        }
    }

    async fn synthesize_local_mlx(
        &self,
        schema: &ModelSchema,
        req: &SynthesizeRequest,
    ) -> Result<SynthesizeResult, InferenceError> {
        let requested = req.requested_advanced_controls();
        let repo_supports_advanced = match &schema.source {
            ModelSource::Mlx { hf_repo, .. } => hf_repo.to_ascii_lowercase().contains("qwen3-tts"),
            _ => false,
        };
        if !requested.is_empty() && !repo_supports_advanced {
            if req.strict_capabilities {
                return Err(InferenceError::InferenceFailed(format!(
                    "model {name} does not support Qwen3-TTS advanced controls {requested:?}; \
                     route to a Qwen3-TTS model or set strict_capabilities = false to degrade",
                    name = schema.name,
                )));
            }
            tracing::warn!(
                model = %schema.name,
                fields = ?requested,
                "Qwen3-TTS advanced controls set on non-Qwen3-TTS backend — ignored \
                 (set strict_capabilities=true to error instead)"
            );
        }

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            if repo_supports_advanced && !requested.is_empty() {
                if req.strict_capabilities {
                    return Err(InferenceError::InferenceFailed(format!(
                        "native MLX TTS backend does not yet implement Qwen3-TTS advanced \
                         controls {requested:?}; run on non-Apple-Silicon to use the Python \
                         mlx-audio fallback, or set strict_capabilities = false"
                    )));
                }
                tracing::warn!(
                    model = %schema.name,
                    fields = ?requested,
                    "Qwen3-TTS advanced controls are not yet implemented in the native MLX TTS \
                     backend; synthesizing without cloning/voice-design"
                );
            }
            let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
            let size = backend_cache::estimate_model_size(&model_dir);
            let cache = Arc::clone(&self.kokoro_cache);
            let key = schema.id.clone();
            let handle = cache.get_or_load(&key, size, || {
                backend::mlx_kokoro::KokoroBackend::load(&model_dir)
            })?;

            let output_path = req.output_path.clone().unwrap_or_else(|| {
                let dir = std::env::temp_dir().join("car_tts");
                let _ = std::fs::create_dir_all(&dir);
                dir.join("output.wav").display().to_string()
            });
            let voice = req.voice.as_deref().unwrap_or("af_heart").to_string();
            let text = req.text.clone();
            let op = tokio::task::spawn_blocking(move || -> Result<PathBuf, InferenceError> {
                let mut guard = handle.lock().map_err(|_| {
                    InferenceError::InferenceFailed("kokoro backend mutex poisoned".into())
                })?;
                guard
                    .synthesize(&text, Some(&voice), Path::new(&output_path))
                    .map_err(|e| InferenceError::InferenceFailed(format!("native TTS: {e}")))
            })
            .await
            .map_err(|e| InferenceError::InferenceFailed(format!("kokoro task join: {e}")))??;

            let final_path =
                materialize_audio_output(&op, req.output_path.as_deref(), &req.format)?;
            return Ok(SynthesizeResult {
                audio_path: final_path.display().to_string(),
                media_type: media_type_for_format(&req.format),
                model_used: Some(schema.name.clone()),
                voice_used: req.voice.clone(),
            });
        }

        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        {
            let runtime = self.ensure_speech_runtime().await?;
            let primary_hf_repo = match &schema.source {
                ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
                _ => unreachable!(),
            };
            let (produced, model_used) = match self
                .synthesize_local_mlx_repo(&runtime, &primary_hf_repo, schema.name.as_str(), req)
                .await
            {
                Ok(result) => result,
                Err(primary_err)
                    if primary_hf_repo == "mlx-community/Kokoro-82M-6bit"
                        && kokoro_runtime_fallback_enabled() =>
                {
                    let fallback_repo = "mlx-community/Kokoro-82M-bf16";
                    let fallback_name = "Kokoro-82M-bf16";
                    match self
                        .synthesize_local_mlx_repo(&runtime, fallback_repo, fallback_name, req)
                        .await
                    {
                        Ok(result) => result,
                        Err(fallback_err) => {
                            return Err(InferenceError::InferenceFailed(format!(
                                "{primary_err}; fallback {fallback_name} also failed: {fallback_err}"
                            )));
                        }
                    }
                }
                Err(err) => return Err(err),
            };
            let final_path =
                materialize_audio_output(&produced, req.output_path.as_deref(), &req.format)?;

            Ok(SynthesizeResult {
                audio_path: final_path.display().to_string(),
                media_type: media_type_for_format(&req.format),
                model_used: Some(model_used),
                voice_used: req.voice.clone(),
            })
        }
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    async fn synthesize_local_mlx_repo(
        &self,
        runtime: &SpeechRuntime,
        hf_repo: &str,
        model_name: &str,
        req: &SynthesizeRequest,
    ) -> Result<(PathBuf, String), InferenceError> {
        let output_dir = temp_work_dir("tts")?;
        let mut args = vec![
            "--model".to_string(),
            hf_repo.to_string(),
            "--text".to_string(),
            req.text.clone(),
            "--output_path".to_string(),
            output_dir.display().to_string(),
        ];
        if let Some(voice) = &req.voice {
            args.push("--voice".to_string());
            args.push(voice.clone());
        }
        if let Some(speed) = req.speed {
            args.push("--speed".to_string());
            args.push(speed.to_string());
        }
        let repo_lower = hf_repo.to_ascii_lowercase();
        if repo_lower.contains("kokoro") {
            args.push("--lang_code".to_string());
            args.push(kokoro_lang_code(req.language.as_deref()).to_string());
        } else if let Some(language) = &req.language {
            args.push("--lang_code".to_string());
            args.push(normalize_lang_code(language));
        }

        if repo_lower.contains("qwen3-tts") {
            if let Some(ref_audio) = &req.reference_audio_path {
                args.push("--ref_audio".to_string());
                args.push(ref_audio.clone());
            }
            if let Some(ref_text) = &req.reference_text {
                args.push("--ref_text".to_string());
                args.push(ref_text.clone());
            }
            if let Some(instruct) = &req.voice_instruction {
                args.push("--instruct".to_string());
                args.push(instruct.clone());
            }
        }

        let output = if repo_lower.contains("kokoro") {
            let device = std::env::var("CAR_SPEECH_KOKORO_DEVICE")
                .or_else(|_| std::env::var("CAR_SPEECH_MLX_DEVICE"))
                .unwrap_or_else(|_| "cpu".to_string());
            let extra_env = vec![
                ("MLX_DEVICE".to_string(), device),
                ("PYTORCH_ENABLE_MPS_FALLBACK".to_string(), "1".to_string()),
            ];
            run_mlx_audio_command_with_env(runtime, "tts.generate", &args, &extra_env).await?
        } else {
            run_mlx_audio_command(runtime, "tts.generate", &args).await?
        };
        let produced = find_audio_file(&output_dir)?.ok_or_else(|| {
            let hint = if repo_lower.contains("kokoro") {
                ". Kokoro models may crash on GPU — try CAR_SPEECH_KOKORO_DEVICE=cpu or use the default Qwen3-TTS model"
            } else {
                ""
            };
            InferenceError::InferenceFailed(format!(
                "mlx-audio synthesis produced no audio file: {}{}",
                output.stderr, hint
            ))
        })?;
        Ok((produced, model_name.to_string()))
    }

    async fn transcribe_elevenlabs(
        &self,
        schema: &ModelSchema,
        req: &TranscribeRequest,
    ) -> Result<TranscribeResult, InferenceError> {
        let (endpoint, api_key) = elevenlabs_auth(schema)?;
        let file_name = Path::new(&req.audio_path)
            .file_name()
            .and_then(|f| f.to_str())
            .unwrap_or("audio.wav")
            .to_string();
        let audio_bytes = tokio::fs::read(&req.audio_path).await?;
        let file_part = Part::bytes(audio_bytes).file_name(file_name);
        let mut form = Form::new()
            .text("model_id", schema.name.clone())
            .part("file", file_part);
        if let Some(language) = &req.language {
            form = form.text("language_code", language.clone());
        }

        let resp = self
            .remote_backend
            .client
            .post(format!(
                "{}/v1/speech-to-text",
                endpoint.trim_end_matches('/')
            ))
            .header("xi-api-key", api_key)
            .multipart(form)
            .send()
            .await
            .map_err(|e| {
                InferenceError::InferenceFailed(format!("ElevenLabs STT request failed: {e}"))
            })?;
        let status = resp.status();
        let body = resp.text().await.map_err(|e| {
            InferenceError::InferenceFailed(format!("read ElevenLabs STT body: {e}"))
        })?;
        if !status.is_success() {
            return Err(InferenceError::InferenceFailed(format!(
                "ElevenLabs STT returned {status}: {body}"
            )));
        }
        let payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
            InferenceError::InferenceFailed(format!("parse ElevenLabs STT response: {e}"))
        })?;
        let text = payload
            .get("text")
            .and_then(|v| v.as_str())
            .map(str::to_string)
            .ok_or_else(|| {
                InferenceError::InferenceFailed("ElevenLabs STT response missing text".into())
            })?;

        Ok(TranscribeResult {
            text,
            model_used: Some(schema.name.clone()),
            language: payload
                .get("language_code")
                .and_then(|v| v.as_str())
                .map(str::to_string),
            words: Vec::new(),
        })
    }

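    /// Synthesizes speech through the ElevenLabs
    /// `/v1/text-to-speech/{voice_id}` endpoint. Qwen3-TTS-only advanced
    /// controls are rejected under `strict_capabilities` and warned about
    /// and ignored otherwise; the raw audio bytes are written to the
    /// requested (or a temporary) output path.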
    async fn synthesize_elevenlabs(
        &self,
        schema: &ModelSchema,
        req: &SynthesizeRequest,
    ) -> Result<SynthesizeResult, InferenceError> {
        let requested = req.requested_advanced_controls();
        if !requested.is_empty() {
            if req.strict_capabilities {
                return Err(InferenceError::InferenceFailed(format!(
                    "ElevenLabs backend does not support Qwen3-TTS advanced controls \
                     {requested:?}; route to a Qwen3-TTS model or set strict_capabilities = false"
                )));
            }
            tracing::warn!(
                model = %schema.name,
                fields = ?requested,
                "Qwen3-TTS advanced controls ignored by ElevenLabs backend"
            );
        }
        let (endpoint, api_key) = elevenlabs_auth(schema)?;
        let voice_id = req
            .voice
            .clone()
            .unwrap_or_else(|| "JBFqnCBsd6RMkjVDRZzb".to_string());
        let output_format = elevenlabs_output_format(&req.format);
        let url = format!(
            "{}/v1/text-to-speech/{}?output_format={}",
            endpoint.trim_end_matches('/'),
            voice_id,
            output_format
        );

        let mut body = serde_json::json!({
            "text": req.text,
            "model_id": schema.name,
        });
        if let Some(language) = &req.language {
            body["language_code"] = serde_json::Value::String(language.clone());
        }

        let resp = self
            .remote_backend
            .client
            .post(url)
            .header("xi-api-key", api_key)
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| {
                InferenceError::InferenceFailed(format!("ElevenLabs TTS request failed: {e}"))
            })?;
        let status = resp.status();
        let audio = resp.bytes().await.map_err(|e| {
            InferenceError::InferenceFailed(format!("read ElevenLabs TTS body: {e}"))
        })?;
        if !status.is_success() {
            let err_body = String::from_utf8_lossy(&audio);
            return Err(InferenceError::InferenceFailed(format!(
                "ElevenLabs TTS returned {status}: {err_body}"
            )));
        }

        let final_path = requested_or_temp_output(req.output_path.as_deref(), &req.format)?;
        ensure_parent_dir(&final_path)?;
        tokio::fs::write(&final_path, &audio).await?;

        Ok(SynthesizeResult {
            audio_path: final_path.display().to_string(),
            media_type: media_type_for_format(&req.format),
            model_used: Some(schema.name.clone()),
            voice_used: Some(voice_id),
        })
    }
}

#[derive(Default)]
struct ProviderAccumulator {
    configured: bool,
    local_models: usize,
    remote_models: usize,
    available_models: usize,
    capabilities: std::collections::HashSet<ModelCapability>,
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
struct CommandOutput {
    stdout: String,
    stderr: String,
}

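/// Layout of the managed speech virtualenv: `root` is the venv directory,
/// with the interpreter and the `mlx_audio.stt.generate` /
/// `mlx_audio.tts.generate` console scripts resolved under `root/bin`.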
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
#[derive(Debug, Clone)]
struct SpeechRuntime {
    root: PathBuf,
    python: PathBuf,
    stt_program: PathBuf,
    tts_program: PathBuf,
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
impl SpeechRuntime {
    fn new(root: PathBuf) -> Self {
        let bin_dir = root.join("bin");
        Self {
            root,
            python: bin_dir.join("python"),
            stt_program: bin_dir.join("mlx_audio.stt.generate"),
            tts_program: bin_dir.join("mlx_audio.tts.generate"),
        }
    }

    fn is_ready(&self) -> bool {
        self.python.exists() && self.stt_program.exists() && self.tts_program.exists()
    }

    fn command_for(&self, subcommand: &str) -> Result<&Path, InferenceError> {
        match subcommand {
            "stt.generate" => Ok(&self.stt_program),
            "tts.generate" => Ok(&self.tts_program),
            _ => Err(InferenceError::InferenceFailed(format!(
                "unknown speech subcommand: {subcommand}"
            ))),
        }
    }
}

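/// Convenience wrapper around [`run_mlx_audio_command_with_env`] that
/// passes no extra environment variables.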
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn run_mlx_audio_command(
    runtime: &SpeechRuntime,
    subcommand: &str,
    args: &[String],
) -> Result<CommandOutput, InferenceError> {
    run_mlx_audio_command_with_env(runtime, subcommand, args, &[]).await
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn run_mlx_audio_command_with_env(
    runtime: &SpeechRuntime,
    subcommand: &str,
    args: &[String],
    envs: &[(String, String)],
) -> Result<CommandOutput, InferenceError> {
    let program = runtime.command_for(subcommand)?;
    let mut command = Command::new(program);
    command.args(args);
    for (key, value) in envs {
        command.env(key, value);
    }
    let output = command
        .output()
        .await
        .map_err(|err| InferenceError::InferenceFailed(format!("{}: {err}", program.display())))?;

    if output.status.success() {
        Ok(CommandOutput {
            stdout: String::from_utf8_lossy(&output.stdout).to_string(),
            stderr: String::from_utf8_lossy(&output.stderr).to_string(),
        })
    } else {
        Err(InferenceError::InferenceFailed(format!(
            "{} exited with {}: {}",
            program.display(),
            output.status,
            String::from_utf8_lossy(&output.stderr)
        )))
    }
}

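/// Provisions the managed speech runtime with `uv`. Roughly equivalent to
/// the following shell sketch (the interpreter and package specs are
/// resolved at runtime, so treat the concrete values as illustrative):
///
/// ```text
/// uv venv --python python3.12 <runtime-root>
/// uv pip install --python <runtime-root>/bin/python \
///     mlx-audio==0.4.2 'misaki[en]' <spacy-model-spec>
/// ```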
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn bootstrap_speech_runtime(runtime: &SpeechRuntime) -> Result<(), InferenceError> {
    std::fs::create_dir_all(&runtime.root)?;
    let python = select_speech_python()?;

    run_command(
        "uv",
        &[
            "venv".to_string(),
            "--python".to_string(),
            python,
            runtime.root.display().to_string(),
        ],
    )
    .await?;

    run_command(
        "uv",
        &[
            "pip".to_string(),
            "install".to_string(),
            "--python".to_string(),
            runtime.python.display().to_string(),
            speech_runtime_mlx_audio_spec(),
            "misaki[en]".to_string(),
            speech_runtime_spacy_model_spec(),
        ],
    )
    .await?;

    Ok(())
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
async fn run_command(program: &str, args: &[String]) -> Result<(), InferenceError> {
    let output = Command::new(program)
        .args(args)
        .output()
        .await
        .map_err(|err| InferenceError::InferenceFailed(format!("{program}: {err}")))?;

    if output.status.success() {
        Ok(())
    } else {
        Err(InferenceError::InferenceFailed(format!(
            "{} exited with {}: {}",
            program,
            output.status,
            String::from_utf8_lossy(&output.stderr)
        )))
    }
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn select_speech_python() -> Result<String, InferenceError> {
    if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
        if !path.trim().is_empty() {
            return Ok(path);
        }
    }

    for candidate in ["python3.13", "python3.12", "python3.11"] {
        if command_in_path(candidate) {
            return Ok(candidate.to_string());
        }
    }

    Err(InferenceError::InferenceFailed(
        "no supported Python found for managed speech runtime (tried python3.13, python3.12, python3.11)".into(),
    ))
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn detect_speech_python() -> Option<String> {
    if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
        if !path.trim().is_empty() {
            return Some(path);
        }
    }

    ["python3.13", "python3.12", "python3.11"]
        .into_iter()
        .find(|candidate| command_in_path(candidate))
        .map(str::to_string)
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_root_from_models_dir(_models_dir: &Path) -> PathBuf {
    if let Ok(path) = std::env::var("CAR_SPEECH_RUNTIME_DIR") {
        if !path.trim().is_empty() {
            return PathBuf::from(path);
        }
    }

    std::env::var("HOME")
        .map(PathBuf::from)
        .unwrap_or_else(|_| PathBuf::from("."))
        .join(".car")
        .join("speech-runtime")
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn command_in_path(name: &str) -> bool {
    std::env::var_os("PATH")
        .map(|paths| {
            std::env::split_paths(&paths).any(|dir| {
                let path = dir.join(name);
                path.exists() && path.is_file()
            })
        })
        .unwrap_or(false)
}

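/// Whether a speech model is usable without further setup: MLX models need
/// a cached Hugging Face snapshot, proprietary models with API-key or
/// bearer auth need the corresponding environment variable set, and
/// everything else (including OAuth2 PKCE) reports `false`.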
fn speech_model_cached(schema: &ModelSchema) -> bool {
    match &schema.source {
        ModelSource::Mlx { hf_repo, .. } => huggingface_repo_has_snapshot(hf_repo),
        ModelSource::Proprietary { auth, .. } => match auth {
            ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
            ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
            ProprietaryAuth::OAuth2Pkce { .. } => false,
        },
        _ => false,
    }
}

fn remove_huggingface_repo_cache(repo_id: &str) -> Result<(), InferenceError> {
    // Same cache-root resolution as huggingface_repo_dir, so reuse it.
    let repo_dir = huggingface_repo_dir(repo_id);
    if repo_dir.exists() {
        std::fs::remove_dir_all(repo_dir)?;
    }
    Ok(())
}

fn model_source_configured(schema: &ModelSchema) -> bool {
    match &schema.source {
        ModelSource::RemoteApi {
            api_key_env,
            api_key_envs,
            ..
        } => {
            std::env::var(api_key_env).is_ok()
                || api_key_envs
                    .iter()
                    .any(|env_var| std::env::var(env_var).is_ok())
        }
        ModelSource::Proprietary { auth, .. } => match auth {
            ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
            ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
            ProprietaryAuth::OAuth2Pkce { .. } => false,
        },
        ModelSource::VllmMlx { .. } => {
            std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available
        }
        ModelSource::Ollama { .. } => schema.available,
        ModelSource::Mlx { .. } | ModelSource::Local { .. } => true,
        ModelSource::AppleFoundationModels { .. } => schema.available,
        ModelSource::Delegated { .. } => true,
    }
}

fn all_model_capabilities() -> [ModelCapability; 13] {
    [
        ModelCapability::Generate,
        ModelCapability::Embed,
        ModelCapability::Classify,
        ModelCapability::Code,
        ModelCapability::Reasoning,
        ModelCapability::Summarize,
        ModelCapability::ToolUse,
        ModelCapability::MultiToolCall,
        ModelCapability::Vision,
        ModelCapability::SpeechToText,
        ModelCapability::TextToSpeech,
        ModelCapability::ImageGeneration,
        ModelCapability::VideoGeneration,
    ]
}

fn sort_capabilities(mut capabilities: Vec<ModelCapability>) -> Vec<ModelCapability> {
    capabilities.sort_by_key(|capability| {
        all_model_capabilities()
            .iter()
            .position(|candidate| candidate == capability)
            .unwrap_or(usize::MAX)
    });
    capabilities
}

fn speech_model_source_label(schema: &ModelSchema) -> String {
    match &schema.source {
        ModelSource::Mlx { hf_repo, .. } => format!("mlx:{hf_repo}"),
        ModelSource::Proprietary {
            provider, endpoint, ..
        } => format!("proprietary:{provider}:{endpoint}"),
        ModelSource::RemoteApi { endpoint, .. } => format!("remote:{endpoint}"),
        ModelSource::Local { hf_repo, .. } => format!("local:{hf_repo}"),
        ModelSource::VllmMlx {
            endpoint,
            model_name,
        } => format!("vllm-mlx:{endpoint}:{model_name}"),
        ModelSource::Ollama { model_tag, host } => format!("ollama:{host}:{model_tag}"),
        ModelSource::AppleFoundationModels { use_case } => {
            format!(
                "apple-foundation:{}",
                use_case.as_deref().unwrap_or("default")
            )
        }
        ModelSource::Delegated { hint } => {
            format!("delegated:{}", hint.as_deref().unwrap_or("(none)"))
        }
    }
}

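/// Builds the ChatML reranker prompt matching the upstream template shape
/// (see `rerank_prompt_matches_upstream_template_shape` below): a fixed
/// system instruction constraining the answer to "yes"/"no", the
/// instruction/query/document in the user turn, and a pre-filled empty
/// `<think>` block so the model emits its verdict token directly.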
fn rerank_prompt(instruction: &str, query: &str, document: &str) -> String {
    const SYSTEM: &str = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".";
    format!(
        "<|im_start|>system\n{SYSTEM}<|im_end|>\n\
         <|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n\
         <|im_start|>assistant\n<think>\n\n</think>\n\n"
    )
}

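/// Maps raw reranker output to a relevance score: `1.0` for a leading
/// `yes`, `0.0` for a leading `no` (case-insensitive, with punctuation and
/// chat sentinels normalized away), scanning at most the first five
/// normalized tokens. Anything else logs a warning and returns a neutral
/// `0.5`.
///
/// ```ignore
/// // Expected mapping, mirroring the unit tests at the bottom of this file:
/// assert_eq!(score_from_rerank_output(" Yes", "model"), 1.0);
/// assert_eq!(score_from_rerank_output("maybe", "model"), 0.5);
/// ```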
fn score_from_rerank_output(text: &str, model_name: &str) -> f32 {
    let normalized: String = text
        .to_ascii_lowercase()
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
        .collect();
    for tok in normalized.split_ascii_whitespace().take(5) {
        match tok {
            "yes" => return 1.0,
            "no" => return 0.0,
            _ => continue,
        }
    }
    tracing::warn!(
        model = %model_name,
        output = %text,
        "rerank: first tokens contain neither `yes` nor `no`; returning neutral 0.5"
    );
    0.5
}

fn default_speech_voice(schema: &ModelSchema) -> Option<String> {
    if schema.provider == "elevenlabs" {
        Some("JBFqnCBsd6RMkjVDRZzb".to_string())
    } else if schema.name == "Kokoro-82M-6bit" || schema.name == "Kokoro-82M-bf16" {
        Some("af_heart".to_string())
    } else if schema.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
        Some("Chelsie".to_string())
    } else {
        None
    }
}

fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
    find_latest_huggingface_snapshot(repo_id).is_some()
}

fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
    let cache_root = std::env::var("HF_HOME")
        .map(PathBuf::from)
        .unwrap_or_else(|_| {
            std::env::var("HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| PathBuf::from("."))
                .join(".cache")
                .join("huggingface")
        })
        .join("hub");
    cache_root.join(format!("models--{}", repo_id.replace('/', "--")))
}

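/// Returns a ready-looking snapshot directory for the repo, if any. Note
/// that despite the name this picks the first ready snapshot yielded by
/// `read_dir`, which is not guaranteed to be the most recent revision.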
fn find_latest_huggingface_snapshot(repo_id: &str) -> Option<PathBuf> {
    let snapshots = huggingface_repo_dir(repo_id).join("snapshots");
    std::fs::read_dir(snapshots)
        .ok()?
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .find(|path| path.is_dir() && snapshot_looks_ready(path))
}

fn snapshot_looks_ready(path: &Path) -> bool {
    if path.join("config.json").exists() || path.join("model_index.json").exists() {
        return true;
    }
    snapshot_contains_ext(path, "safetensors")
}

fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let Ok(entries) = std::fs::read_dir(root) else {
        return false;
    };
    entries.filter_map(Result::ok).any(|entry| {
        let path = entry.path();
        if path.is_dir() {
            snapshot_contains_ext(&path, ext)
        } else {
            path.extension()
                .and_then(|value| value.to_str())
                .map(|value| value.eq_ignore_ascii_case(ext))
                .unwrap_or(false)
        }
    })
}

fn count_files_recursive(root: &Path) -> usize {
    let Ok(entries) = std::fs::read_dir(root) else {
        return 0;
    };
    entries
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .map(|path| {
            if path.is_dir() {
                count_files_recursive(&path)
            } else if path.is_file() {
                1
            } else {
                0
            }
        })
        .sum()
}

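/// Downloads every file of a Hugging Face repo into the standard
/// `models--{owner}--{repo}/snapshots/{sha}` cache layout, skipping files
/// that already exist locally, and returns the snapshot path together with
/// the number of files accounted for (pre-existing files are counted too).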
async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
    let api = hf_hub::api::tokio::ApiBuilder::from_env()
        .with_progress(false)
        .build()
        .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
    let repo = api.model(repo_id.to_string());
    let info = repo
        .info()
        .await
        .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;

    let snapshot_path = huggingface_repo_dir(repo_id)
        .join("snapshots")
        .join(&info.sha);
    let mut downloaded = 0usize;
    for sibling in &info.siblings {
        let local_path = snapshot_path.join(&sibling.rfilename);
        if local_path.exists() {
            downloaded += 1;
            continue;
        }
        repo.download(&sibling.rfilename).await.map_err(|e| {
            InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
        })?;
        downloaded += 1;
    }

    Ok((snapshot_path, downloaded))
}

fn temp_work_dir(prefix: &str) -> Result<PathBuf, InferenceError> {
    let unique = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|e| InferenceError::InferenceFailed(format!("clock error: {e}")))?
        .as_nanos();
    let dir = std::env::temp_dir().join(format!("car-inference-{prefix}-{unique}"));
    std::fs::create_dir_all(&dir)?;
    Ok(dir)
}

fn ensure_parent_dir(path: &Path) -> Result<(), InferenceError> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    Ok(())
}

fn requested_or_temp_output(
    output_path: Option<&str>,
    format: &str,
) -> Result<PathBuf, InferenceError> {
    if let Some(path) = output_path {
        return Ok(PathBuf::from(path));
    }
    let dir = temp_work_dir("audio-out")?;
    Ok(dir.join(format!("speech.{format}")))
}

fn requested_or_temp_media_output(
    output_path: Option<&str>,
    format: &str,
    stem: &str,
) -> Result<PathBuf, InferenceError> {
    if let Some(path) = output_path {
        return Ok(PathBuf::from(path));
    }
    let dir = temp_work_dir(&format!("{stem}-out"))?;
    Ok(dir.join(format!("{stem}.{format}")))
}

fn materialize_audio_output(
    produced: &Path,
    requested: Option<&str>,
    format: &str,
) -> Result<PathBuf, InferenceError> {
    // requested_or_temp_output already handles the Some/None split,
    // so a single copy path suffices.
    let dest = requested_or_temp_output(requested, format)?;
    ensure_parent_dir(&dest)?;
    std::fs::copy(produced, &dest)?;
    Ok(dest)
}

fn materialize_binary_output(
    produced: &Path,
    requested: Option<&str>,
    format: &str,
    stem: &str,
) -> Result<PathBuf, InferenceError> {
    let dest = requested_or_temp_media_output(requested, format, stem)?;
    ensure_parent_dir(&dest)?;
    std::fs::copy(produced, &dest)?;
    Ok(dest)
}

fn find_generated_file(
    root: &Path,
    extensions: &[&str],
) -> Result<Option<PathBuf>, InferenceError> {
    let entries = std::fs::read_dir(root)?;
    let mut candidates: Vec<PathBuf> = entries
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| {
            path.is_file()
                && path
                    .extension()
                    .and_then(|ext| ext.to_str())
                    .map(|ext| {
                        extensions
                            .iter()
                            .any(|candidate| candidate.eq_ignore_ascii_case(ext))
                    })
                    .unwrap_or(false)
        })
        .collect();
    candidates.sort();
    Ok(candidates.pop())
}

fn media_type_for_image_format(format: &str) -> String {
    match format.to_ascii_lowercase().as_str() {
        "jpg" | "jpeg" => "image/jpeg".to_string(),
        "webp" => "image/webp".to_string(),
        _ => "image/png".to_string(),
    }
}

fn media_type_for_video_format(format: &str) -> String {
    match format.to_ascii_lowercase().as_str() {
        "mov" => "video/quicktime".to_string(),
        "gif" => "image/gif".to_string(),
        _ => "video/mp4".to_string(),
    }
}

fn read_transcription_result(output_prefix: &Path) -> Result<Option<String>, InferenceError> {
    let candidates = [
        output_prefix.with_extension("json"),
        output_prefix.to_path_buf(),
    ];

    for path in candidates {
        if path.exists() {
            let contents = std::fs::read_to_string(path)?;
            if let Some(text) = extract_text_from_payload(&contents) {
                return Ok(Some(text));
            }
        }
    }

    Ok(None)
}

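/// Pulls transcription text out of the three JSON payload shapes this
/// module expects from the STT tooling:
///
/// ```text
/// {"text": "..."}                          // single object
/// {"transcripts": [{"text": "..."}, ...]}  // segments joined with newlines
/// [{"text": "..."} or {"Content": "..."}]  // items joined with spaces
/// ```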
fn extract_text_from_payload(payload: &str) -> Option<String> {
    let value: serde_json::Value = serde_json::from_str(payload).ok()?;
    if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
        return Some(text.to_string());
    }
    if let Some(transcripts) = value.get("transcripts").and_then(|v| v.as_array()) {
        let joined = transcripts
            .iter()
            .filter_map(|item| item.get("text").and_then(|v| v.as_str()))
            .collect::<Vec<_>>()
            .join("\n");
        if !joined.is_empty() {
            return Some(joined);
        }
    }
    if let Some(items) = value.as_array() {
        let joined = items
            .iter()
            .filter_map(|item| {
                item.get("text")
                    .or_else(|| item.get("Content"))
                    .and_then(|v| v.as_str())
            })
            .collect::<Vec<_>>()
            .join(" ");
        if !joined.is_empty() {
            return Some(joined);
        }
    }
    None
}

fn find_audio_file(output_dir: &Path) -> Result<Option<PathBuf>, InferenceError> {
    let mut audio_files = Vec::new();
    collect_audio_files(output_dir, &mut audio_files)?;
    audio_files.sort();
    Ok(audio_files.into_iter().next())
}

fn collect_audio_files(dir: &Path, audio_files: &mut Vec<PathBuf>) -> Result<(), InferenceError> {
    for entry in std::fs::read_dir(dir)? {
        let path = entry?.path();
        if path.is_dir() {
            collect_audio_files(&path, audio_files)?;
        } else if matches!(
            path.extension().and_then(|ext| ext.to_str()),
            Some("wav" | "mp3" | "flac" | "pcm" | "m4a")
        ) {
            audio_files.push(path);
        }
    }
    Ok(())
}

fn media_type_for_format(format: &str) -> String {
    match format.to_ascii_lowercase().as_str() {
        "mp3" => "audio/mpeg".to_string(),
        "flac" => "audio/flac".to_string(),
        "pcm" => "audio/L16".to_string(),
        "m4a" => "audio/mp4".to_string(),
        _ => "audio/wav".to_string(),
    }
}

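/// Maps a language hint to Kokoro's single-letter `--lang_code` values
/// (`a` American English, `b` British English, `j` Japanese, `z` Mandarin,
/// `e` Spanish, `f` French), defaulting to American English.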
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_lang_code(language: Option<&str>) -> &'static str {
    match language.unwrap_or("en").to_ascii_lowercase().as_str() {
        "en-gb" | "british" | "british english" => "b",
        "ja" | "japanese" => "j",
        "zh" | "zh-cn" | "mandarin" | "chinese" => "z",
        "es" | "spanish" => "e",
        "fr" | "french" => "f",
        _ => "a",
    }
}

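/// Normalizes a language name or locale to one of the two-letter codes the
/// local TTS tooling accepts (`en`, `es`, `fr`, `ja`, `zh`), falling back
/// to `en` for anything unrecognized.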
fn normalize_lang_code(language: &str) -> String {
    match language.to_ascii_lowercase().as_str() {
        "english" | "en-us" | "en_us" | "en" => "en".to_string(),
        "spanish" | "es" => "es".to_string(),
        "french" | "fr" => "fr".to_string(),
        "japanese" | "ja" => "ja".to_string(),
        "chinese" | "mandarin" | "zh" => "zh".to_string(),
        _ => "en".to_string(),
    }
}

fn elevenlabs_auth(schema: &ModelSchema) -> Result<(String, String), InferenceError> {
    match &schema.source {
        ModelSource::Proprietary {
            endpoint,
            auth: schema::ProprietaryAuth::ApiKeyEnv { env_var },
            ..
        } => {
            let key = car_secrets::resolve_env_or_keychain(env_var).ok_or_else(|| {
                InferenceError::InferenceFailed(format!(
                    "missing API key {env_var}; set the environment variable or \
                     store it with `car secrets put {env_var}`"
                ))
            })?;
            Ok((endpoint.clone(), key))
        }
        _ => Err(InferenceError::InferenceFailed(format!(
            "model {} is not an ElevenLabs proprietary model",
            schema.id
        ))),
    }
}

fn elevenlabs_output_format(format: &str) -> &'static str {
    match format.to_ascii_lowercase().as_str() {
        "mp3" => "mp3_44100_128",
        "pcm" => "pcm_16000",
        _ => "wav_44100",
    }
}

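/// Candidate locations for `benchmark_priors.json`, deduplicated and in
/// load order: the models directory itself, its parent, then an explicit
/// `CAR_BENCHMARK_PRIORS_PATH` override. Later files win when the same
/// model id appears more than once (see `load_benchmark_prior_health`).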
fn benchmark_priors_paths(models_dir: &Path) -> Vec<PathBuf> {
    let mut paths = Vec::new();

    paths.push(models_dir.join("benchmark_priors.json"));

    if let Some(parent) = models_dir.parent() {
        let parent_path = parent.join("benchmark_priors.json");
        if !paths.contains(&parent_path) {
            paths.push(parent_path);
        }
    }

    if let Some(path) = std::env::var_os("CAR_BENCHMARK_PRIORS_PATH") {
        let path = PathBuf::from(path);
        if !paths.contains(&path) {
            paths.push(path);
        }
    }

    paths
}

fn load_benchmark_prior_health(
    models_dir: &Path,
    schemas: &[ModelSchema],
) -> Vec<ModelBenchmarkPriorHealth> {
    let mut priors = std::collections::BTreeMap::new();
    for path in benchmark_priors_paths(models_dir) {
        let Ok(loaded) = routing_ext::load_benchmark_priors(&path) else {
            continue;
        };
        for (model_id, prior) in loaded {
            let model_name = schemas
                .iter()
                .find(|schema| schema.id == model_id)
                .map(|schema| schema.name.clone());
            priors.insert(
                model_id.clone(),
                ModelBenchmarkPriorHealth {
                    model_id,
                    model_name,
                    overall_score: prior.overall_score,
                    overall_latency_ms: prior.overall_latency_ms,
                    task_scores: prior.task_scores,
                    task_latency_ms: prior.task_latency_ms,
                    source_path: path.clone(),
                },
            );
        }
    }

    priors.into_values().collect()
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_runtime_fallback_enabled() -> bool {
    std::env::var("CAR_SPEECH_KOKORO_FALLBACK")
        .ok()
        .map(|value| {
            !matches!(
                value.trim().to_ascii_lowercase().as_str(),
                "0" | "false" | "off"
            )
        })
        .unwrap_or(true)
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_mlx_audio_spec() -> String {
    std::env::var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC")
        .ok()
        .filter(|value| !value.trim().is_empty())
        .unwrap_or_else(|| "mlx-audio==0.4.2".to_string())
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_spacy_model_spec() -> String {
    std::env::var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC")
        .ok()
        .filter(|value| !value.trim().is_empty())
        .unwrap_or_else(|| {
            "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl".to_string()
        })
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

    fn test_config(models_dir: PathBuf) -> InferenceConfig {
        InferenceConfig {
            models_dir,
            device: None,
            generation_model: "Qwen3-0.6B".into(),
            preferred_generation_model: None,
            embedding_model: "Qwen3-Embedding-0.6B".into(),
            preferred_embedding_model: None,
            classification_model: "Qwen3-0.6B".into(),
            preferred_classification_model: None,
        }
    }

    #[tokio::test]
    async fn tokenize_rejects_known_remote_model_with_unsupported_mode() {
        let tmp = TempDir::new().unwrap();
        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let remote_id = engine
            .list_schemas()
            .into_iter()
            .find(|s| !s.is_local())
            .map(|s| s.id)
            .expect("built-in catalog should include at least one remote model schema");

        let err = engine
            .tokenize(&remote_id, "hello")
            .await
            .expect_err("remote tokenize must error");
        match err {
            InferenceError::UnsupportedMode { mode, backend, .. } => {
                assert_eq!(mode, "tokenize/detokenize");
                assert_eq!(backend, "remote");
            }
            other => panic!("expected UnsupportedMode, got {other:?}"),
        }

        let err = engine
            .detokenize(&remote_id, &[1, 2, 3])
            .await
            .expect_err("remote detokenize must error");
        assert!(
            matches!(err, InferenceError::UnsupportedMode { .. }),
            "expected UnsupportedMode, got {err:?}"
        );
    }

    #[test]
    fn engine_loads_benchmark_priors_on_startup() {
        let _env = ENV_MUTEX.lock().unwrap();
        let tmp = TempDir::new().unwrap();
        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.88
            })
            .to_string(),
        )
        .unwrap();

        unsafe {
            std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
        }

        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("benchmark prior should create a profile");
        assert!((profile.ema_quality - 0.88).abs() < 0.01);

        unsafe {
            std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
        }
    }

    #[test]
    fn benchmark_priors_do_not_override_observed_profiles() {
        let _env = ENV_MUTEX.lock().unwrap();
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        std::fs::create_dir_all(&models_dir).unwrap();

        let observed = vec![ModelProfile {
            model_id: "qwen/qwen3-8b:q4_k_m".into(),
            total_calls: 12,
            success_count: 3,
            fail_count: 9,
            total_latency_ms: 1200,
            total_input_tokens: 0,
            total_output_tokens: 0,
            task_stats: std::collections::HashMap::new(),
            ema_quality: 0.21,
            quality_per_1k_tokens: 0.0,
            updated_at: 1,
        }];
        std::fs::write(
            models_dir.join("outcome_profiles.json"),
            serde_json::to_string(&observed).unwrap(),
        )
        .unwrap();

        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.95
            })
            .to_string(),
        )
        .unwrap();

        unsafe {
            std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
        }

        let engine = InferenceEngine::new(test_config(models_dir));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("observed profile should remain present");
        assert!((profile.ema_quality - 0.21).abs() < 0.01);
        assert_eq!(profile.total_calls, 12);

        unsafe {
            std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
        }
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn speech_runtime_package_spec_defaults_and_overrides() {
        let _env = ENV_MUTEX.lock().unwrap();
        unsafe {
            std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
        }
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.2");

        unsafe {
            std::env::set_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC", "mlx-audio==0.4.1");
        }
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.1");

        unsafe {
            std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
        }
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn speech_runtime_spacy_model_spec_defaults_and_overrides() {
        let _env = ENV_MUTEX.lock().unwrap();
        unsafe {
            std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
        }
        assert!(
            speech_runtime_spacy_model_spec().starts_with("en-core-web-sm @ https://github.com/")
        );

        unsafe {
            std::env::set_var(
                "CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC",
                "en-core-web-sm==3.8.0",
            );
        }
        assert_eq!(speech_runtime_spacy_model_spec(), "en-core-web-sm==3.8.0");

        unsafe {
            std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
        }
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn kokoro_runtime_fallback_defaults_on() {
        // Serialize with the other env-mutating tests to avoid races.
        let _env = ENV_MUTEX.lock().unwrap();
        unsafe {
            std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
        }
        assert!(kokoro_runtime_fallback_enabled());

        unsafe {
            std::env::set_var("CAR_SPEECH_KOKORO_FALLBACK", "false");
        }
        assert!(!kokoro_runtime_fallback_enabled());

        unsafe {
            std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
        }
    }

    #[test]
    fn preferred_local_tts_wins_over_builtin_rank() {
        let tmp = TempDir::new().unwrap();
        let mut engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        engine.set_speech_policy(SpeechPolicy {
            prefer_local: true,
            allow_remote_fallback: false,
            preferred_local_stt: None,
            preferred_local_tts: Some("Kokoro-82M-6bit".into()),
            preferred_remote_stt: None,
            preferred_remote_tts: None,
        });

        let schema = engine
            .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
            .expect("preferred local TTS should resolve");
        assert_eq!(schema.name, "Kokoro-82M-6bit");
    }

    #[test]
    fn preferred_discovered_vllm_mlx_model_wins_generate_routing() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(tmp.path().join("models"));
        config.preferred_generation_model =
            Some("vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit".into());
        let mut engine = InferenceEngine::new(config);
        let schema = crate::vllm_mlx::to_model_schema(
            &crate::vllm_mlx::DiscoveredModel {
                id: "mlx-community/gemma-3n-E2B-it-lm-4bit".into(),
                owned_by: Some("mlx-community".into()),
            },
            "http://127.0.0.1:8001",
        );
        engine.register_model(schema);

        let rt = tokio::runtime::Runtime::new().unwrap();
        let decision = rt.block_on(engine.route_adaptive("say hello in one sentence"));
        assert_eq!(
            decision.model_id,
            "vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit"
        );
        assert_eq!(decision.strategy, RoutingStrategy::Explicit);
        assert_eq!(decision.reason, "preferred generation model override");
    }

    #[test]
    fn inference_result_serializes_with_full_shape() {
        use crate::tasks::generate::ToolCall;
        use std::collections::HashMap;

        let mut args = HashMap::new();
        args.insert("path".to_string(), serde_json::json!("README.md"));

        let result = InferenceResult {
            text: String::new(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![ToolCall {
                id: None,
                name: "read_file".into(),
                arguments: args,
            }],
            trace_id: "trace-abc".into(),
            model_used: "test-model".into(),
            latency_ms: 1234,
            time_to_first_token_ms: Some(180),
            usage: Some(TokenUsage {
                prompt_tokens: 100,
                completion_tokens: 50,
                total_tokens: 150,
                context_window: 8192,
            }),
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");

        assert_eq!(json["text"].as_str(), Some(""));
        assert_eq!(json["trace_id"].as_str(), Some("trace-abc"));
        assert_eq!(json["model_used"].as_str(), Some("test-model"));
        assert_eq!(json["latency_ms"].as_u64(), Some(1234));

        let tool_calls = json["tool_calls"].as_array().expect("tool_calls array");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0]["name"].as_str(), Some("read_file"));
        assert_eq!(
            tool_calls[0]["arguments"]["path"].as_str(),
            Some("README.md")
        );

        let usage = &json["usage"];
        assert_eq!(usage["prompt_tokens"].as_u64(), Some(100));
        assert_eq!(usage["completion_tokens"].as_u64(), Some(50));
        assert_eq!(usage["total_tokens"].as_u64(), Some(150));
        assert_eq!(usage["context_window"].as_u64(), Some(8192));

        assert_eq!(json["time_to_first_token_ms"].as_u64(), Some(180));
    }

    #[test]
    fn inference_result_top_level_keys_are_locked() {
        use std::collections::BTreeSet;

        let result = InferenceResult {
            text: "anything".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "t".into(),
            model_used: "m".into(),
            latency_ms: 0,
            time_to_first_token_ms: None,
            usage: None,
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");
        let keys: BTreeSet<&str> = json
            .as_object()
            .expect("top-level object")
            .keys()
            .map(String::as_str)
            .collect();

        let expected: BTreeSet<&str> = [
            "text",
            "tool_calls",
            "trace_id",
            "model_used",
            "latency_ms",
            "time_to_first_token_ms",
            "usage",
        ]
        .into_iter()
        .collect();

        assert_eq!(
            keys, expected,
            "infer response top-level keys drifted -- update both the test \
             and the WebSocket protocol documentation if this is intentional"
        );

        for key in &keys {
            assert!(
                !key.chars().any(|c| c.is_uppercase()) && !key.contains('-'),
                "key '{}' is not snake_case",
                key
            );
        }
    }

    #[test]
    fn inference_result_serializes_plain_text_response() {
        let result = InferenceResult {
            text: "hello world".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "trace-xyz".into(),
            model_used: "test-model".into(),
            latency_ms: 42,
            time_to_first_token_ms: None,
            usage: None,
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");
        assert_eq!(json["text"], "hello world");
        assert!(json["tool_calls"].is_array());
        assert_eq!(json["tool_calls"].as_array().unwrap().len(), 0);
        assert_eq!(json["model_used"], "test-model");
        assert!(json["usage"].is_null());
        assert!(json["time_to_first_token_ms"].is_null());
    }

    #[test]
    fn generate_request_deserializes_intent_field_from_json_rpc_params() {
        use crate::intent::{IntentHint, TaskHint};
        use crate::schema::ModelCapability;

        let params = serde_json::json!({
            "prompt": "summarize this email",
            "intent": {
                "task": "chat",
                "prefer_local": true,
                "require": ["tool_use"],
            },
        });

        let req: GenerateRequest =
            serde_json::from_value(params).expect("GenerateRequest deserialize");

        let intent = req.intent.as_ref().expect("intent field deserialized");
        assert_eq!(intent.task, Some(TaskHint::Chat));
        assert!(intent.prefer_local);
        assert_eq!(intent.require, vec![ModelCapability::ToolUse]);

        let back: serde_json::Value =
            serde_json::to_value(&req).expect("re-serialize GenerateRequest");
        assert_eq!(back["intent"]["task"], "chat");
        assert_eq!(back["intent"]["prefer_local"], true);
        assert_eq!(back["intent"]["require"][0], "tool_use");

        let default_req: GenerateRequest = serde_json::from_value(serde_json::json!({
            "prompt": "x",
            "intent": {},
        }))
        .unwrap();
        let default_intent = default_req.intent.expect("present but empty");
        assert_eq!(default_intent.task, None);
        assert!(!default_intent.prefer_local);
        assert!(default_intent.require.is_empty());

        let no_intent: GenerateRequest =
            serde_json::from_value(serde_json::json!({"prompt": "x"})).unwrap();
        assert!(no_intent.intent.is_none());
    }

    #[test]
    fn rerank_prompt_matches_upstream_template_shape() {
        let p = rerank_prompt(
            "retrieve relevant passages",
            "who runs the treasury?",
            "doc x",
        );
        assert!(p.contains("<|im_start|>system"));
        assert!(p.contains("Note that the answer can only be \"yes\" or \"no\"."));
        assert!(p.contains("<|im_start|>user\n<Instruct>: retrieve relevant passages"));
        assert!(p.contains("<Query>: who runs the treasury?"));
        assert!(p.contains("<Document>: doc x<|im_end|>"));
        assert!(p.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    #[test]
    fn rerank_score_yes_and_no_exactly() {
        assert_eq!(score_from_rerank_output("yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("no", "m"), 0.0);
    }

    #[test]
    fn rerank_score_handles_case_leading_space_and_chat_sentinels() {
        assert_eq!(score_from_rerank_output(" Yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("\nno.", "m"), 0.0);
        assert_eq!(score_from_rerank_output("<|im_end|>yes", "m"), 1.0);
    }

    #[test]
    fn rerank_score_scans_up_to_five_tokens() {
        assert_eq!(score_from_rerank_output("_bos_ yes", "m"), 1.0);
    }

    #[test]
    fn rerank_score_unexpected_is_neutral() {
        assert_eq!(score_from_rerank_output("maybe", "m"), 0.5);
        assert_eq!(score_from_rerank_output("", "m"), 0.5);
    }
}