1pub mod adaptive_router;
26pub mod backend;
27pub mod backend_cache;
28pub mod concierge;
29pub mod download;
30pub mod handle;
31pub mod hardware;
32pub mod intent;
33pub mod key_pool;
34pub mod models;
35pub mod nudge;
36pub mod outcome;
37pub mod protocol;
38pub mod recommend;
39pub mod registry;
40pub mod remote;
41pub mod router;
42pub mod routing_ext;
43pub mod runner;
44pub mod schema;
45pub mod service;
46pub mod stream;
47pub mod tasks;
48pub mod update_prefs;
49pub mod upgrade;
50pub mod vllm_mlx;
51
52use std::path::{Path, PathBuf};
53use std::sync::Arc;
54use std::time::Instant;
55use std::time::{SystemTime, UNIX_EPOCH};
56
57use reqwest::multipart::{Form, Part};
58use serde::Serialize;
59use thiserror::Error;
60#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
61use tokio::process::Command;
62#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
63use tokio::sync::Mutex;
64use tokio::sync::RwLock;
65use tracing::{debug, instrument};
66
67pub use adaptive_router::{
69 AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy,
70};
71pub use handle::InferenceHandle;
72pub use intent::{
73 IntentHint, Privacy, QualityTier, TaskHint, TierWeights, UseCase, UseCaseRole,
74};
75pub use download::{DownloadEvent, DownloadProgress, ProgressSink};
76pub use update_prefs::{UpdateChannel, UpdatePolicy, UpdatePreferences};
77pub use nudge::{NudgeDecision, NudgeState, UpgradeNudge};
78pub use concierge::{
79 decide_concierge, ConciergeSuggestion, DEFAULT_CONCIERGE_THROTTLE_SECS,
80 DEFAULT_WATCHED_USE_CASES,
81};
82pub use upgrade::{HuggingFaceProbe, UpgradeFinding, UpgradeSource, UpstreamProbe};
83pub use recommend::{recommend, FitStatus, Recommendation, RecommendationSet};
84pub use key_pool::{KeyPool, KeyStats};
85pub use outcome::{
86 CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
87};
88pub use registry::{
89 ModelFilter, ModelInfo, ModelRuntimeRequirement, ModelUpgrade, UnifiedRegistry,
90};
91pub use remote::RemoteBackend;
92pub use routing_ext::{
93 CircuitBreaker, CircuitBreakerRegistry, CircuitState, ImplicitSignal, ImplicitSignalType,
94 RoutingMode, SpendControl, SpendLimitExceeded, SpendLimits, SpendStatus,
95};
96pub use runner::{
97 current_inference_runner, set_inference_runner, EventEmitter, InferenceRunner, RunnerError,
98 RunnerResult,
99};
100pub use schema::{
101 ApiProtocol, BenchmarkScore, CostModel, ModelCapability, ModelSchema, ModelSource,
102 PerformanceEnvelope, ProprietaryAuth, TrustTier,
103};
104
105pub use adaptive_router::TaskComplexity;
107#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
108pub use backend::CandleBackend;
109#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
110pub use backend::EmbeddingBackend;
111#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
112pub use backend::MlxBackend;
113pub use hardware::HardwareInfo;
114pub use models::{ModelRegistry, ModelRole};
115pub use router::{ModelRouter, RoutingDecision};
116pub use stream::{StreamAccumulator, StreamEvent};
117pub use tasks::{
118 parse_boxes, BoundingBox, ClassifyRequest, ClassifyResult, ContentBlock, EmbedRequest,
119 GenerateImageRequest, GenerateImageResult, GenerateParams, GenerateRequest,
120 GenerateVideoRequest, GenerateVideoResult, GroundRequest, GroundResult, Message, RerankRequest,
121 RerankResult, RerankedDocument, RoutingWorkload, SynthesizeRequest, SynthesizeResult,
122 ThinkingMode, ToolCall, TranscribeRequest, TranscribeResult, VideoMode,
123};
124#[derive(Error, Debug)]
127pub enum InferenceError {
128 #[error("model not found: {0}")]
129 ModelNotFound(String),
130
131 #[error("model download failed: {0}")]
132 DownloadFailed(String),
133
134 #[error("inference failed: {0}")]
135 InferenceFailed(String),
136
137 #[error("mode {mode} not implemented on backend {backend}: {reason}")]
142 UnsupportedMode {
143 mode: &'static str,
144 backend: &'static str,
145 reason: &'static str,
146 },
147
148 #[error("tokenization error: {0}")]
149 TokenizationError(String),
150
151 #[error("device error: {0}")]
152 DeviceError(String),
153
154 #[error("io error: {0}")]
155 Io(#[from] std::io::Error),
156}
157
/// Compute device a local backend runs on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    /// Host CPU.
    Cpu,
    /// Apple Metal GPU.
    Metal,
    /// CUDA GPU. NOTE(review): the `usize` is presumably the device ordinal (`auto` uses `0`).
    Cuda(usize),
}

impl Device {
    /// Picks a default device from the enabled cargo features.
    ///
    /// `metal` takes precedence over `cuda` when both are enabled; with neither feature
    /// the result is [`Device::Cpu`].
    pub fn auto() -> Self {
        // `cfg!` keeps one reachable expression per build. The previous `#[cfg]` + `return`
        // form left an unreachable `return` (and `needless_return`) when both features were
        // enabled, which breaks `-D warnings` builds.
        if cfg!(feature = "metal") {
            Device::Metal
        } else if cfg!(feature = "cuda") {
            Device::Cuda(0)
        } else {
            Device::Cpu
        }
    }
}
183
184#[derive(Debug, Clone)]
186pub struct InferenceConfig {
187 pub models_dir: std::path::PathBuf,
189 pub device: Option<Device>,
191 pub generation_model: String,
193 pub preferred_generation_model: Option<String>,
195 pub embedding_model: String,
197 pub preferred_embedding_model: Option<String>,
199 pub classification_model: String,
201 pub preferred_classification_model: Option<String>,
203}
204
205impl Default for InferenceConfig {
206 fn default() -> Self {
207 let models_dir = dirs_next()
208 .unwrap_or_else(|| std::path::PathBuf::from("."))
209 .join(".car")
210 .join("models");
211
212 let hw = HardwareInfo::detect();
213
214 Self {
215 models_dir,
216 device: None,
217 generation_model: hw.recommended_model,
218 preferred_generation_model: None,
219 embedding_model: "Qwen3-Embedding-0.6B".to_string(),
220 preferred_embedding_model: None,
221 classification_model: "Qwen3-0.6B".to_string(),
222 preferred_classification_model: None,
223 }
224 }
225}
226
/// Returns the current user's home directory, if one can be determined.
///
/// Reads `HOME` with [`std::env::var_os`], so non-UTF-8 home paths are kept rather than
/// silently discarded. On Windows, where `HOME` is often unset, falls back to
/// `USERPROFILE`. Empty values are treated as unset.
fn dirs_next() -> Option<std::path::PathBuf> {
    let non_empty_var = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());
    non_empty_var("HOME")
        .or_else(|| {
            if cfg!(windows) {
                non_empty_var("USERPROFILE")
            } else {
                None
            }
        })
        .map(std::path::PathBuf::from)
}
230
231#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
233pub struct TokenUsage {
234 pub prompt_tokens: u64,
236 pub completion_tokens: u64,
238 pub total_tokens: u64,
240 pub context_window: u64,
242}
243
244#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
246pub struct InferenceResult {
247 pub text: String,
249 pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
251 #[serde(default, skip_serializing_if = "Vec::is_empty")]
258 pub bounding_boxes: Vec<crate::tasks::grounding::BoundingBox>,
259 pub trace_id: String,
261 pub model_used: String,
263 pub latency_ms: u64,
265 #[serde(default)]
277 pub time_to_first_token_ms: Option<u64>,
278 pub usage: Option<TokenUsage>,
280 #[serde(default, skip_serializing_if = "Vec::is_empty")]
293 pub provider_output_items: Vec<serde_json::Value>,
294}
295
296impl InferenceResult {
297 pub fn has_tool_calls(&self) -> bool {
299 !self.tool_calls.is_empty()
300 }
301}
302
303#[derive(Debug, Clone, Serialize)]
304pub struct SpeechRuntimeHealth {
305 pub root: PathBuf,
306 pub installed: bool,
307 pub python: PathBuf,
308 pub stt_command: PathBuf,
309 pub tts_command: PathBuf,
310 pub configured_python: Option<String>,
311 pub detected_python: Option<String>,
312}
313
314#[derive(Debug, Clone, Serialize)]
315pub struct SpeechModelHealth {
316 pub id: String,
317 pub name: String,
318 pub provider: String,
319 pub capability: ModelCapability,
320 pub is_local: bool,
321 pub available: bool,
322 pub cached: bool,
323 pub selected_by_default: bool,
324 pub source: String,
325}
326
327#[derive(Debug, Clone, Serialize)]
328pub struct SpeechHealthReport {
329 pub runtime: SpeechRuntimeHealth,
330 pub local_models: Vec<SpeechModelHealth>,
331 pub remote_models: Vec<SpeechModelHealth>,
332 pub elevenlabs_configured: bool,
333 pub prefer_local: bool,
334 pub allow_remote_fallback: bool,
335 pub preferred_local_stt: Option<String>,
336 pub preferred_local_tts: Option<String>,
337 pub preferred_remote_stt: Option<String>,
338 pub preferred_remote_tts: Option<String>,
339 pub local_stt_default: Option<String>,
340 pub local_tts_default: Option<String>,
341 pub remote_stt_default: Option<String>,
342 pub remote_tts_default: Option<String>,
343}
344
345#[derive(Debug, Clone, Serialize)]
346pub struct ModelDefaultHealth {
347 pub capability: ModelCapability,
348 pub configured_model: String,
349 pub available: bool,
350 pub is_local: bool,
351 pub provider: Option<String>,
352}
353
354#[derive(Debug, Clone, Serialize)]
355pub struct ModelProviderHealth {
356 pub provider: String,
357 pub configured: bool,
358 pub local_models: usize,
359 pub remote_models: usize,
360 pub available_models: usize,
361 pub capabilities: Vec<ModelCapability>,
362}
363
364#[derive(Debug, Clone, Serialize)]
365pub struct ModelCapabilityHealth {
366 pub capability: ModelCapability,
367 pub total_models: usize,
368 pub available_models: usize,
369 pub local_available_models: usize,
370 pub remote_available_models: usize,
371}
372
373#[derive(Debug, Clone, Serialize)]
374pub struct RoutingScenarioHealth {
375 pub name: String,
376 pub workload: RoutingWorkload,
377 pub task_family: String,
378 pub has_tools: bool,
379 pub has_vision: bool,
380 pub prefer_local: bool,
381 pub quality_first_cold_start: bool,
382 pub bootstrap_min_task_observations: u64,
383 pub bootstrap_quality_floor: f64,
384 pub model_id: String,
385 pub model_name: String,
386 pub reason: String,
387 pub strategy: RoutingStrategy,
388}
389
390#[derive(Debug, Clone, Serialize)]
391pub struct ModelBenchmarkPriorHealth {
392 pub model_id: String,
393 pub model_name: Option<String>,
394 pub overall_score: f64,
395 pub overall_latency_ms: Option<f64>,
396 pub task_scores: std::collections::HashMap<String, f64>,
397 pub task_latency_ms: std::collections::HashMap<String, f64>,
398 pub source_path: PathBuf,
399}
400
401#[derive(Debug, Clone, Serialize)]
402pub struct ModelHealthReport {
403 pub total_models: usize,
404 pub available_models: usize,
405 pub local_models: usize,
406 pub remote_models: usize,
407 pub defaults: Vec<ModelDefaultHealth>,
408 pub providers: Vec<ModelProviderHealth>,
409 pub capabilities: Vec<ModelCapabilityHealth>,
410 pub routing_prefer_local: bool,
411 pub routing_quality_first_cold_start: bool,
412 pub routing_min_observations: u64,
413 pub routing_bootstrap_min_task_observations: u64,
414 pub routing_bootstrap_quality_floor: f64,
415 pub routing_quality_weight: f64,
416 pub routing_latency_weight: f64,
417 pub routing_cost_weight: f64,
418 pub routing_scenarios: Vec<RoutingScenarioHealth>,
419 pub benchmark_priors: Vec<ModelBenchmarkPriorHealth>,
420 pub speech: SpeechHealthReport,
421}
422
423#[derive(Debug, Clone, Serialize)]
424pub struct SpeechInstallReport {
425 pub name: String,
426 pub hf_repo: String,
427 pub snapshot_path: PathBuf,
428 pub files_downloaded: usize,
429}
430
431#[derive(Debug, Clone, Serialize)]
432pub struct SpeechSmokePathReport {
433 pub path: String,
434 pub tts_model: String,
435 pub stt_model: String,
436 pub audio_path: PathBuf,
437 pub transcript: String,
438}
439
440#[derive(Debug, Clone, Serialize, Default)]
441pub struct SpeechSmokeReport {
442 pub local: Option<SpeechSmokePathReport>,
443 pub remote: Option<SpeechSmokePathReport>,
444 pub skipped: Vec<String>,
445}
446
447#[derive(Debug, Clone, Serialize, Default)]
448pub struct SpeechPolicy {
449 pub prefer_local: bool,
450 pub allow_remote_fallback: bool,
451 pub preferred_local_stt: Option<String>,
452 pub preferred_local_tts: Option<String>,
453 pub preferred_remote_stt: Option<String>,
454 pub preferred_remote_tts: Option<String>,
455}
456
457pub struct InferenceEngine {
462 pub config: InferenceConfig,
463 pub unified_registry: UnifiedRegistry,
465 pub adaptive_router: AdaptiveRouter,
467 pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
469 remote_backend: RemoteBackend,
471 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
477 mlx_backends: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
478 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
483 flux_cache: Arc<backend_cache::BackendCache<backend::mlx_flux::FluxBackend>>,
484 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
486 ltx_cache: Arc<backend_cache::BackendCache<backend::mlx_ltx::LtxBackend>>,
487 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
490 kokoro_cache: Arc<backend_cache::BackendCache<backend::mlx_kokoro::KokoroBackend>>,
491 pub registry: models::ModelRegistry,
493 pub router: ModelRouter,
494 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
495 backend: Arc<RwLock<Option<CandleBackend>>>,
496 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
497 embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
498 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
499 speech_runtime: Arc<Mutex<Option<SpeechRuntime>>>,
500 speech_policy: SpeechPolicy,
501}
502
503impl InferenceEngine {
504 fn preferred_model_for_capability(&self, capability: ModelCapability) -> Option<&str> {
505 match capability {
506 ModelCapability::Generate => self.config.preferred_generation_model.as_deref(),
507 ModelCapability::Embed => self.config.preferred_embedding_model.as_deref(),
508 ModelCapability::Classify => self.config.preferred_classification_model.as_deref(),
509 _ => None,
510 }
511 }
512
513 fn request_needs_vision(req: &GenerateRequest) -> bool {
514 req.images.as_ref().is_some_and(|images| !images.is_empty())
515 || req.messages.as_ref().is_some_and(|messages| {
516 messages
517 .iter()
518 .any(|msg| matches!(msg, Message::UserMultimodal { .. }))
519 })
520 }
521
522 #[allow(dead_code)] fn request_has_video(req: &GenerateRequest) -> bool {
528 let images_have_video = req
529 .images
530 .as_ref()
531 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
532 let messages_have_video = req.messages.as_ref().is_some_and(|messages| {
533 messages.iter().any(|msg| match msg {
534 Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_video),
535 _ => false,
536 })
537 });
538 images_have_video || messages_have_video
539 }
540
541 #[allow(dead_code)] fn request_has_audio(req: &GenerateRequest) -> bool {
546 let images_have_audio = req
547 .images
548 .as_ref()
549 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
550 let messages_have_audio = req.messages.as_ref().is_some_and(|messages| {
551 messages.iter().any(|msg| match msg {
552 Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_audio),
553 _ => false,
554 })
555 });
556 images_have_audio || messages_have_audio
557 }
558
559 pub fn new(config: InferenceConfig) -> Self {
560 let registry = models::ModelRegistry::new(config.models_dir.clone());
561 let hw = HardwareInfo::detect();
562 let router = ModelRouter::new(hw.clone());
563 let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
564 let adaptive_router = AdaptiveRouter::with_default_config(hw);
565 let mut tracker = OutcomeTracker::new();
566 let profiles_path = config.models_dir.join("outcome_profiles.json");
568 if let Ok(n) = tracker.load_from_file(&profiles_path) {
569 if n > 0 {
570 tracing::info!(loaded = n, "loaded persisted model profiles");
571 }
572 }
573 let mut benchmark_models_loaded = 0usize;
574 for path in benchmark_priors_paths(&config.models_dir) {
575 match routing_ext::load_benchmark_priors(&path) {
576 Ok(priors) if !priors.is_empty() => {
577 benchmark_models_loaded += priors.len();
578 routing_ext::apply_benchmark_priors(&mut tracker, &priors);
579 tracing::info!(
580 path = %path.display(),
581 loaded = priors.len(),
582 "loaded benchmark quality priors"
583 );
584 }
585 Ok(_) => {}
586 Err(error) => {
587 tracing::warn!(path = %path.display(), %error, "failed to load benchmark priors");
588 }
589 }
590 }
591 if benchmark_models_loaded > 0 {
592 tracing::info!(
593 loaded = benchmark_models_loaded,
594 "applied benchmark priors to cold-start routing"
595 );
596 }
597 let outcome_tracker = Arc::new(RwLock::new(tracker));
598
599 let remote_backend = RemoteBackend::new();
600
601 Self {
602 config,
603 unified_registry,
604 adaptive_router,
605 outcome_tracker,
606 remote_backend,
607 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
608 mlx_backends: Arc::new(backend_cache::BackendCache::from_env()),
609 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
610 flux_cache: Arc::new(backend_cache::BackendCache::from_env()),
611 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
612 ltx_cache: Arc::new(backend_cache::BackendCache::from_env()),
613 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
614 kokoro_cache: Arc::new(backend_cache::BackendCache::from_env()),
615 registry,
616 router,
617 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
618 backend: Arc::new(RwLock::new(None)),
619 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
620 embedding_backend: Arc::new(RwLock::new(None)),
621 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
622 speech_runtime: Arc::new(Mutex::new(None)),
623 speech_policy: SpeechPolicy {
624 prefer_local: cfg!(all(
625 target_os = "macos",
626 target_arch = "aarch64",
627 not(car_skip_mlx)
628 )),
629 allow_remote_fallback: true,
630 preferred_local_stt: None,
631 preferred_local_tts: None,
632 preferred_remote_stt: None,
633 preferred_remote_tts: None,
634 },
635 }
636 }
637
638 pub async fn init_key_pool(&self) {
641 for schema in self.unified_registry.list() {
643 if schema.is_remote() {
644 self.remote_backend.register_model_keys(schema).await;
645 }
646 }
647
648 let stats_path = self.config.models_dir.join("key_pool_stats.json");
650 if let Ok(n) = self.remote_backend.key_pool.load_stats(&stats_path).await {
651 if n > 0 {
652 tracing::info!(loaded = n, "loaded persisted key pool stats");
653 }
654 }
655
656 let total = self.remote_backend.key_pool.total_keys().await;
657 if total > 0 {
658 tracing::info!(keys = total, "key pool initialized");
659 }
660 }
661
662 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
665 async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
666 let read = self.backend.read().await;
667 if read.is_some() {
668 return Ok(());
669 }
670 drop(read);
671
672 let mut write = self.backend.write().await;
673 if write.is_some() {
674 return Ok(());
675 }
676
677 let model_path = self.registry.ensure_model(model_name).await?;
678 let device = self.config.device.unwrap_or_else(Device::auto);
679 let backend = CandleBackend::load(&model_path, device)?;
680 *write = Some(backend);
681 Ok(())
682 }
683
684 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
687 async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
688 let read = self.embedding_backend.read().await;
689 if read.is_some() {
690 return Ok(());
691 }
692 drop(read);
693
694 let mut write = self.embedding_backend.write().await;
695 if write.is_some() {
696 return Ok(());
697 }
698
699 let embedding_model = self
700 .preferred_model_for_capability(ModelCapability::Embed)
701 .unwrap_or(&self.config.embedding_model);
702 let model_path = self.registry.ensure_model(embedding_model).await?;
703 let device = self.config.device.unwrap_or_else(Device::auto);
704 let backend = EmbeddingBackend::load(&model_path, device)?;
705 *write = Some(backend);
706 Ok(())
707 }
708
709 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
712 async fn ensure_mlx_embedding_backend(&self) -> Result<String, InferenceError> {
713 let embedding_model_name = self
714 .preferred_model_for_capability(ModelCapability::Embed)
715 .unwrap_or(&self.config.embedding_model)
716 .to_string();
717 let schema = self
718 .unified_registry
719 .get(&embedding_model_name)
720 .or_else(|| self.unified_registry.find_by_name(&embedding_model_name))
721 .ok_or_else(|| InferenceError::ModelNotFound(embedding_model_name.clone()))?
722 .clone();
723 self.ensure_mlx_backend(&schema).await?;
724 Ok(schema.id)
725 }
726
727 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
731 async fn ensure_mlx_backend(
732 &self,
733 schema: &ModelSchema,
734 ) -> Result<backend_cache::CachedBackend<backend::MlxBackend>, InferenceError> {
735 if !Self::supports_native_mlx(schema) {
736 return Err(InferenceError::InferenceFailed(format!(
737 "native MLX backend does not support {} ({}) yet; use vLLM-MLX or add a family-specific MLX backend",
738 schema.name, schema.family
739 )));
740 }
741
742 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
743 let size = backend_cache::estimate_model_size(&model_dir);
744 let cache = Arc::clone(&self.mlx_backends);
745 let key = schema.id.clone();
746 cache.get_or_load(&key, size, move || {
750 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
751 backend::MlxBackend::load(&model_dir)
752 }))
753 .map_err(|e| {
754 InferenceError::InferenceFailed(format!(
755 "MLX backend loading panicked (possible Metal/accelerate exception): {:?}",
756 e
757 ))
758 })?
759 })
760 }
761
762 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
769 pub async fn warm_up<S: AsRef<str>>(
770 &self,
771 schema_ids: &[S],
772 ) -> Vec<Result<(), InferenceError>> {
773 let mut results = Vec::with_capacity(schema_ids.len());
774 for id in schema_ids {
775 let id = id.as_ref();
776 let outcome: Result<(), InferenceError> = async {
777 let schema = self.unified_registry.get(id).cloned().ok_or_else(|| {
778 InferenceError::InferenceFailed(format!("warm_up: unknown schema id {id}"))
779 })?;
780 match schema.capabilities.first().copied() {
781 Some(ModelCapability::ImageGeneration) => {
782 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
783 let size = backend_cache::estimate_model_size(&model_dir);
784 let _ = self.flux_cache.get_or_load(&schema.id, size, || {
785 backend::mlx_flux::FluxBackend::load(&model_dir)
786 })?;
787 }
788 Some(ModelCapability::VideoGeneration) => {
789 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
790 let size = backend_cache::estimate_model_size(&model_dir);
791 let _ = self.ltx_cache.get_or_load(&schema.id, size, || {
792 backend::mlx_ltx::LtxBackend::load(&model_dir)
793 })?;
794 }
795 Some(ModelCapability::TextToSpeech) => {
796 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
797 let size = backend_cache::estimate_model_size(&model_dir);
798 let _ = self.kokoro_cache.get_or_load(&schema.id, size, || {
799 backend::mlx_kokoro::KokoroBackend::load(&model_dir)
800 })?;
801 }
802 _ => {
803 let _ = self.ensure_mlx_backend(&schema).await?;
804 }
805 }
806 Ok(())
807 }
808 .await;
809 results.push(outcome);
810 }
811 results
812 }
813
814 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
816 pub async fn warm_up<S: AsRef<str>>(
817 &self,
818 _schema_ids: &[S],
819 ) -> Vec<Result<(), InferenceError>> {
820 Vec::new()
821 }
822
823 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
824 fn supports_native_mlx(schema: &ModelSchema) -> bool {
825 matches!(schema.family.as_str(), "qwen3" | "qwen2.5-vl" | "qwen2-vl")
826 }
827
828 pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
830 if let Some(model) = self.preferred_model_for_capability(ModelCapability::Generate) {
831 let ctx_len = self
832 .unified_registry
833 .get(model)
834 .or_else(|| self.unified_registry.find_by_name(model))
835 .map(|s| s.context_length)
836 .unwrap_or(0);
837 return AdaptiveRoutingDecision {
838 model_id: model.to_string(),
839 model_name: model.to_string(),
840 task: InferenceTask::Generate,
841 complexity: TaskComplexity::assess(prompt),
842 reason: "preferred generation model override".into(),
843 strategy: RoutingStrategy::Explicit,
844 predicted_quality: 0.5,
845 fallbacks: vec![],
846 context_length: ctx_len,
847 needs_compaction: false,
848 };
849 }
850 let tracker = self.outcome_tracker.read().await;
851 self.adaptive_router
852 .route(prompt, &self.unified_registry, &tracker)
853 }
854
855 pub fn route(&self, prompt: &str) -> RoutingDecision {
857 self.router.route_generate(prompt, &self.registry)
858 }
859
860 pub fn estimated_tokens(
863 &self,
864 req: &GenerateRequest,
865 model_id: Option<&str>,
866 ) -> (usize, usize, bool) {
867 let prompt_tokens = remote::estimate_tokens(&req.prompt);
868 let context_tokens = req
869 .context
870 .as_ref()
871 .map(|c| remote::estimate_tokens(c))
872 .unwrap_or(0);
873 let tools_tokens = req
874 .tools
875 .as_ref()
876 .map(|t| remote::estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
877 .unwrap_or(0);
878 let total_input = prompt_tokens + context_tokens + tools_tokens;
879
880 let context_window = model_id
881 .and_then(|id| {
882 self.unified_registry
883 .get(id)
884 .or_else(|| self.unified_registry.find_by_name(id))
885 })
886 .map(|s| s.context_length)
887 .unwrap_or(0);
888
889 let fits = context_window == 0 || (total_input + req.params.max_tokens) <= context_window;
890 (total_input, context_window, fits)
891 }
892
893 #[instrument(
896 name = "inference.generate",
897 skip_all,
898 fields(
899 model = tracing::field::Empty,
900 max_tokens = req.params.max_tokens,
901 prompt_tokens = tracing::field::Empty,
902 completion_tokens = tracing::field::Empty,
903 latency_ms = tracing::field::Empty,
904 )
905 )]
906 pub async fn generate_tracked(
907 &self,
908 req: GenerateRequest,
909 ) -> Result<InferenceResult, InferenceError> {
910 let start = Instant::now();
911
912 let (estimated_input, _, _) = self.estimated_tokens(&req, None);
914 let tracker_read = self.outcome_tracker.read().await;
915 let has_tools = req.tools.is_some();
916 let has_vision = Self::request_needs_vision(&req);
917 let preferred_model = self
918 .preferred_model_for_capability(ModelCapability::Generate)
919 .map(str::to_string);
920 let decision = match req.model.clone().or(preferred_model) {
921 Some(m) => {
922 let ctx_len = self
923 .unified_registry
924 .get(&m)
925 .or_else(|| self.unified_registry.find_by_name(&m))
926 .map(|s| s.context_length)
927 .unwrap_or(0);
928 AdaptiveRoutingDecision {
929 model_id: m.clone(),
930 model_name: m.clone(),
931 task: InferenceTask::Generate,
932 complexity: TaskComplexity::assess(&req.prompt),
933 reason: "explicit model".into(),
934 strategy: RoutingStrategy::Explicit,
935 predicted_quality: 0.5,
936 fallbacks: vec![],
937 context_length: ctx_len,
938 needs_compaction: ctx_len > 0 && estimated_input > ctx_len,
939 }
940 }
941 None => match &req.intent {
942 Some(hint) => self.adaptive_router.route_context_aware_with_intent(
943 &req.prompt,
944 estimated_input,
945 &self.unified_registry,
946 &tracker_read,
947 has_tools,
948 has_vision,
949 req.params.workload,
950 hint,
951 ),
952 None => self.adaptive_router.route_context_aware(
953 &req.prompt,
954 estimated_input,
955 &self.unified_registry,
956 &tracker_read,
957 has_tools,
958 has_vision,
959 req.params.workload,
960 ),
961 },
962 };
963 drop(tracker_read);
964
965 if decision.needs_compaction {
966 tracing::info!(
967 model = %decision.model_name,
968 prompt_tokens = estimated_input,
969 context_window = decision.context_length,
970 "prompt exceeds model context window — compaction or truncation needed"
971 );
972 }
973
974 let trace_id = {
976 let mut tracker = self.outcome_tracker.write().await;
977 tracker.record_start(&decision.model_id, decision.task, &decision.reason)
978 };
979
980 debug!(
981 model = %decision.model_name,
982 strategy = ?decision.strategy,
983 reason = %decision.reason,
984 trace = %trace_id,
985 "adaptive-routed generate request"
986 );
987
988 let mut req = req;
991 if req.params.budget_tokens == 0 && matches!(decision.complexity, TaskComplexity::Complex) {
992 let supports_thinking = self
993 .unified_registry
994 .get(&decision.model_id)
995 .map(|s| {
996 s.supported_params
997 .contains(&schema::GenerateParam::ExtendedThinking)
998 })
999 .unwrap_or(false);
1000 if supports_thinking {
1001 req.params.budget_tokens = 8000;
1002 tracing::info!(model = %decision.model_name, budget = 8000, "auto-enabled extended thinking for complex task");
1003 }
1004 }
1005
1006 let mut models_to_try = vec![decision.model_id.clone()];
1008 models_to_try.extend(decision.fallbacks.iter().cloned());
1009
1010 let mut last_error = None;
1011
1012 for candidate_id in &models_to_try {
1013 #[allow(unused_mut)]
1016 let mut schema = self
1017 .unified_registry
1018 .get(candidate_id)
1019 .or_else(|| self.unified_registry.find_by_name(candidate_id))
1020 .cloned();
1021
1022 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1024 if let Some(ref s) = schema {
1025 if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
1026 tracing::info!(
1027 from = %s.id, to = %mlx_equiv.id,
1028 "redirecting GGUF model to MLX equivalent on Apple Silicon"
1029 );
1030 schema = Some(mlx_equiv.clone());
1031 }
1032 }
1033
1034 let candidate_name = schema
1035 .as_ref()
1036 .map(|s| s.name.clone())
1037 .unwrap_or_else(|| candidate_id.clone());
1038
1039 let is_remote = schema
1040 .as_ref()
1041 .map(|s| s.is_remote() || s.is_vllm_mlx())
1042 .unwrap_or(false);
1043 let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
1044
1045 if is_delegated {
1051 let runner = match runner::current_inference_runner() {
1052 Some(r) => r,
1053 None => {
1054 last_error = Some(InferenceError::InferenceFailed(
1055 "model declares ModelSource::Delegated but no inference runner is registered"
1056 .into(),
1057 ));
1058 continue;
1059 }
1060 };
1061 let (tx, mut rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
1062 let emitter = runner::EventEmitter::new(tx);
1063 let runner_req = req.clone();
1064 let runner_handle =
1065 tokio::spawn(async move { runner.run(runner_req, emitter).await });
1066 let mut accumulator = stream::StreamAccumulator::default();
1067 while let Some(evt) = rx.recv().await {
1068 accumulator.push(&evt);
1069 }
1070 let (acc_text, acc_tool_calls) = accumulator.finish();
1074 match runner_handle.await {
1075 Ok(Ok(_runner_result)) => {
1076 let elapsed = start.elapsed().as_millis() as u64;
1077 let mut tracker = self.outcome_tracker.write().await;
1078 tracker.record_complete(&trace_id, elapsed, 0, 0);
1079 return Ok(InferenceResult {
1080 text: acc_text,
1081 tool_calls: acc_tool_calls,
1082 bounding_boxes: vec![],
1083 trace_id,
1084 model_used: candidate_name,
1085 latency_ms: elapsed,
1086 time_to_first_token_ms: None,
1087 usage: None,
1088 provider_output_items: vec![],
1089 });
1090 }
1091 Ok(Err(e)) => {
1092 last_error = Some(InferenceError::InferenceFailed(e.to_string()));
1093 continue;
1094 }
1095 Err(join_err) => {
1096 last_error = Some(InferenceError::InferenceFailed(format!(
1097 "runner task panicked: {join_err}"
1098 )));
1099 continue;
1100 }
1101 }
1102 }
1103
1104 let has_tools = req.tools.is_some();
1105
1106 let context = if has_tools
1108 && req.tools.as_ref().map_or(false, |t| {
1109 t.iter().any(|tool| {
1110 tool.get("function")
1111 .and_then(|f| f.get("name"))
1112 .and_then(|n| n.as_str())
1113 == Some("done")
1114 })
1115 }) {
1116 let base = req.context.as_deref().unwrap_or("");
1117 Some(format!(
1118 "{base}\n\nIMPORTANT: When calling the `done` tool, the `result` field MUST contain a DETAILED summary of everything you found and did. This is the ONLY output the user sees. Do NOT just say 'completed' — include specific findings, data, and conclusions."
1119 ))
1120 } else {
1121 req.context.clone()
1122 };
1123
1124 let result = if is_remote {
1125 let schema_val = schema.unwrap();
1126 let _ctx_len = schema_val.context_length;
1127 let temperature = if !schema_val.supported_params.is_empty()
1130 && !schema_val
1131 .supported_params
1132 .contains(&crate::schema::GenerateParam::Temperature)
1133 {
1134 -1.0
1135 } else {
1136 req.params.temperature
1137 };
1138
1139 self.remote_backend
1145 .generate_with_tools_multi(
1146 &schema_val,
1147 &req.prompt,
1148 context.as_deref(),
1149 temperature,
1150 req.params.max_tokens,
1151 req.tools.as_deref(),
1152 req.images.as_deref(),
1153 req.messages.as_deref(),
1154 req.params.tool_choice.as_deref(),
1155 req.params.parallel_tool_calls,
1156 req.params.budget_tokens,
1157 req.cache_control,
1158 req.response_format.as_ref(),
1159 )
1160 .await
1161 .map(|(t, c, u)| (t, c, u, None::<u64>))
1166 } else {
1167 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1168 {
1169 let schema_ref = schema
1173 .as_ref()
1174 .ok_or_else(|| InferenceError::ModelNotFound(candidate_id.clone()))?;
1175
1176 if schema_ref.is_foundation_models() {
1181 if has_tools {
1182 Err(InferenceError::UnsupportedMode {
1183 mode: "tool-use",
1184 backend: "foundation-models",
1185 reason: "tool calling not yet wired through the FoundationModels \
1186 bridge — route to a remote model with ToolUse capability",
1187 })
1188 } else if Self::request_has_video(&req)
1189 || Self::request_has_audio(&req)
1190 || req.images.as_ref().is_some_and(|imgs| !imgs.is_empty())
1191 {
1192 Err(InferenceError::UnsupportedMode {
1193 mode: "multimodal-content",
1194 backend: "foundation-models",
1195 reason: "the FoundationModels bridge currently exposes text-only \
1196 generation — route image/audio/video to a remote VL model",
1197 })
1198 } else {
1199 let prompt = req.prompt.clone();
1200 let instructions = context.clone();
1201 let max_tokens = req.params.max_tokens as u32;
1202 let temperature = req.params.temperature;
1203 tokio::task::spawn_blocking(move || {
1204 crate::backend::foundation_models::generate(
1205 &prompt,
1206 instructions.as_deref(),
1207 max_tokens,
1208 temperature as f32,
1209 )
1210 })
1211 .await
1212 .map_err(|e| {
1213 InferenceError::InferenceFailed(format!(
1214 "FoundationModels task panicked: {e}"
1215 ))
1216 })
1217 .and_then(|r| r)
1218 .map(|text| (text, vec![], None, None))
1219 }
1220 } else if !schema_ref.is_mlx() {
1221 Err(InferenceError::InferenceFailed(format!(
1222 "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1223 schema_ref.id
1224 )))
1225 } else if schema_ref.tags.iter().any(|t| t == "mlx-vlm-cli") {
1226 let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
1238 if !has_images {
1239 return Err(InferenceError::UnsupportedMode {
1240 mode: "text-only-on-mlx-vlm-id",
1241 backend: "mlx-vlm-cli",
1242 reason: "the `mlx-vlm/...` model IDs route exclusively \
1243 through the mlx-vlm CLI for image inference. \
1244 For text-only generation, route to a Qwen3 \
1245 text model (`mlx/qwen3-4b:4bit` etc.) — the \
1246 CLI shell-out has higher latency than the \
1247 in-process MLX text tower.",
1248 });
1249 }
1250 let vlm_status = crate::backend::mlx_vlm_cli::runtime_status();
1251 if !vlm_status.is_available() {
1252 return Err(InferenceError::InferenceFailed(vlm_status.user_message()));
1253 }
1254 let repo = match &schema_ref.source {
1255 crate::schema::ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
1256 _ => {
1257 return Err(InferenceError::InferenceFailed(format!(
1258 "model '{}' is tagged mlx-vlm-cli but its \
1259 source isn't ModelSource::Mlx — registry bug",
1260 schema_ref.id
1261 )));
1262 }
1263 };
1264 let imgs = req.images.clone().unwrap_or_default();
1265 let temp = req.params.temperature;
1266 let max_t = req.params.max_tokens;
1267 let prompt = req.prompt.clone();
1268 let text = tokio::task::spawn_blocking(move || {
1269 crate::backend::mlx_vlm_cli::generate(
1270 &repo, &prompt, &imgs, temp, max_t,
1271 )
1272 })
1273 .await
1274 .map_err(|e| {
1275 InferenceError::InferenceFailed(format!(
1276 "mlx_vlm CLI task panicked: {e}"
1277 ))
1278 })??;
1279 let bounding_boxes = parse_boxes(&text);
1280 let latency_ms = start.elapsed().as_millis() as u64;
1281 {
1282 let mut tracker = self.outcome_tracker.write().await;
1283 tracker.record_complete(&trace_id, latency_ms, 0, 0);
1284 }
1285 return Ok(InferenceResult {
1286 text,
1287 tool_calls: vec![],
1288 bounding_boxes,
1289 trace_id: trace_id.clone(),
1290 model_used: schema_ref.id.clone(),
1291 latency_ms,
1292 time_to_first_token_ms: None,
1293 usage: None,
1294 provider_output_items: Vec::new(),
1295 });
1296 } else {
1297 let handle = self.ensure_mlx_backend(schema_ref).await?;
1307 if Self::request_has_video(&req) {
1313 return Err(InferenceError::UnsupportedMode {
1314 mode: "video-content-block",
1315 backend: "native-mlx-qwen25vl",
1316 reason: "Qwen2.5-VL video understanding is on the request surface \
1317 but the video-tokenization path (frame sampling + merger) \
1318 is not yet wired; route to a remote VL provider for now",
1319 });
1320 }
1321 if Self::request_has_audio(&req) {
1322 return Err(InferenceError::UnsupportedMode {
1323 mode: "audio-content-block",
1324 backend: "native-mlx-qwen25vl",
1325 reason: "audio understanding is on the request surface (Gemma 4 \
1326 E2B/E4B and Gemini accept it) but the native MLX path \
1327 for this model does not — route to Gemini or Gemma-4",
1328 });
1329 }
1330 let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
1331 if has_images {
1332 let can_do_vision = {
1333 let guard = handle.lock().map_err(|_| {
1334 InferenceError::InferenceFailed(
1335 "MLX backend mutex poisoned".into(),
1336 )
1337 })?;
1338 guard.supports_capability(crate::schema::ModelCapability::Vision)
1339 };
1340 if !can_do_vision {
1341 return Err(InferenceError::UnsupportedMode {
1351 mode: "image-content-block",
1352 backend: "native-mlx-text",
1353 reason: "this MLX backend is a plain Qwen3 text tower. \
1354 For local image inference, route to \
1355 `mlx-vlm/qwen3-vl-2b:bf16` or another `mlx-vlm/...` \
1356 catalog ID so CAR shells out to `mlx_vlm.generate`. \
1357 Alternatives: a local vLLM-MLX VLM server, or a \
1358 remote VL model. (#115)",
1359 });
1360 }
1361 }
1362 self.generate_mlx(req.clone(), &schema_ref.id)
1363 .await
1364 .map(|(text, ttft)| (text, vec![], None, ttft))
1365 }
1366 }
1367
1368 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1369 {
1370 match self.ensure_backend(&candidate_name).await {
1371 Ok(()) => {
1372 let mut write = self.backend.write().await;
1373 let backend = write.as_mut().unwrap();
1374 tasks::generate::generate(backend, req.clone())
1375 .await
1376 .map(|(text, ttft)| (text, vec![], None, ttft))
1377 }
1378 Err(e) => Err(e),
1379 }
1380 }
1381 };
1382
1383 match result {
1384 Ok((text, tool_calls, usage, time_to_first_token_ms)) => {
1385 let latency_ms = start.elapsed().as_millis() as u64;
1386 let estimated_tokens = usage
1387 .as_ref()
1388 .map(|u| u.completion_tokens as usize)
1389 .unwrap_or_else(|| text.split_whitespace().count());
1390 {
1391 let mut tracker = self.outcome_tracker.write().await;
1392 tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
1393 }
1394 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1396 cb.record_success(candidate_id);
1397 }
1398 self.auto_save_outcomes().await;
1400
1401 let span = tracing::Span::current();
1403 span.record("model", candidate_name.as_str());
1404 span.record("latency_ms", latency_ms);
1405 if let Some(ttft) = time_to_first_token_ms {
1406 span.record("ttft_ms", ttft);
1407 }
1408 if let Some(ref u) = usage {
1409 span.record("prompt_tokens", u.prompt_tokens);
1410 span.record("completion_tokens", u.completion_tokens);
1411 }
1412
1413 let bounding_boxes = tasks::grounding::parse_boxes(&text);
1416 return Ok(InferenceResult {
1417 text,
1418 tool_calls,
1419 bounding_boxes,
1420 trace_id,
1421 model_used: candidate_name,
1422 latency_ms,
1423 time_to_first_token_ms,
1424 usage,
1425 provider_output_items: Vec::new(),
1426 });
1427 }
1428 Err(e) => {
1429 tracing::warn!(
1430 model = %candidate_name,
1431 error = %e,
1432 remaining = models_to_try.len().saturating_sub(
1433 models_to_try.iter().position(|m| m == candidate_id).unwrap_or(0) + 1
1434 ),
1435 "model failed, trying next fallback immediately"
1436 );
1437 {
1439 let mut tracker = self.outcome_tracker.write().await;
1440 let fail_trace =
1441 tracker.record_start(candidate_id, decision.task, "fallback");
1442 tracker.record_failure(&fail_trace, &e.to_string());
1443 }
1444 {
1448 let err_str = e.to_string();
1449 let is_client_error =
1450 err_str.contains("API returned 4") && !err_str.contains("429");
1451 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1452 cb.record_failure(candidate_id);
1453 if is_client_error {
1456 cb.record_failure(candidate_id);
1457 }
1458 }
1459 }
1460 #[cfg(not(all(
1462 target_os = "macos",
1463 target_arch = "aarch64",
1464 not(car_skip_mlx)
1465 )))]
1466 {
1467 let mut write = self.backend.write().await;
1468 *write = None;
1469 }
1470 last_error = Some(e);
1471 }
1472 }
1473 }
1474
1475 let underlying = last_error.unwrap_or(InferenceError::InferenceFailed(
1477 "no models available".into(),
1478 ));
1479
1480 let e = match no_backend_recovery_hint(&underlying.to_string()) {
1489 Some(msg) => InferenceError::InferenceFailed(msg),
1490 None => underlying,
1491 };
1492 {
1493 let mut tracker = self.outcome_tracker.write().await;
1494 tracker.record_failure(&trace_id, &e.to_string());
1495 }
1496 self.auto_save_outcomes().await;
1497 Err(e)
1498 }
1499
1500 pub async fn generate_tracked_stream(
1537 &self,
1538 req: GenerateRequest,
1539 ) -> Result<tokio::sync::mpsc::Receiver<stream::StreamEvent>, InferenceError> {
1540 let has_tools = req.tools.is_some();
1541 let has_vision = Self::request_needs_vision(&req);
1542 let preferred_model = self
1543 .preferred_model_for_capability(ModelCapability::Generate)
1544 .map(str::to_string);
1545 let decision = match req.model.clone().or(preferred_model) {
1546 Some(m) => {
1547 let ctx_len = self
1548 .unified_registry
1549 .get(&m)
1550 .or_else(|| self.unified_registry.find_by_name(&m))
1551 .map(|s| s.context_length)
1552 .unwrap_or(0);
1553 AdaptiveRoutingDecision {
1554 model_id: m.clone(),
1555 model_name: m,
1556 task: InferenceTask::Generate,
1557 complexity: TaskComplexity::assess(&req.prompt),
1558 reason: "explicit model".into(),
1559 strategy: RoutingStrategy::Explicit,
1560 predicted_quality: 0.5,
1561 fallbacks: vec![],
1562 context_length: ctx_len,
1563 needs_compaction: false,
1564 }
1565 }
1566 None => {
1567 let tracker_read = self.outcome_tracker.read().await;
1568 if has_vision {
1569 self.adaptive_router.route_with_vision(
1570 &req.prompt,
1571 &self.unified_registry,
1572 &tracker_read,
1573 has_tools,
1574 )
1575 } else if has_tools {
1576 self.adaptive_router.route_with_tools(
1577 &req.prompt,
1578 &self.unified_registry,
1579 &tracker_read,
1580 )
1581 } else {
1582 self.adaptive_router
1583 .route(&req.prompt, &self.unified_registry, &tracker_read)
1584 }
1585 }
1586 };
1587
1588 #[allow(unused_mut)]
1591 let mut schema = self
1592 .unified_registry
1593 .get(&decision.model_id)
1594 .or_else(|| self.unified_registry.find_by_name(&decision.model_id))
1595 .cloned();
1596
1597 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1599 if let Some(ref s) = schema {
1600 if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
1601 tracing::info!(
1602 from = %s.id, to = %mlx_equiv.id,
1603 "redirecting GGUF model to MLX equivalent on Apple Silicon (stream)"
1604 );
1605 schema = Some(mlx_equiv.clone());
1606 }
1607 }
1608
1609 let is_remote = schema
1610 .as_ref()
1611 .map(|s| s.is_remote() || s.is_vllm_mlx())
1612 .unwrap_or(false);
1613
1614 let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
1615
1616 if is_delegated {
1617 let runner = runner::current_inference_runner().ok_or_else(|| {
1621 InferenceError::InferenceFailed(
1622 "model declares ModelSource::Delegated but no inference runner is registered \
1623 (call set_inference_runner / registerInferenceRunner / register_inference_runner)"
1624 .into(),
1625 )
1626 })?;
1627 let (tx, rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
1628 let emitter = runner::EventEmitter::new(tx);
1629 let request = req.clone();
1630 tokio::spawn(async move {
1631 if let Err(e) = runner.run(request, emitter).await {
1632 tracing::warn!(error = %e, "delegated inference runner failed");
1633 }
1634 });
1635 return Ok(rx);
1636 }
1637
1638 if is_remote {
1639 let schema = schema.unwrap();
1640 self.remote_backend.register_model_keys(&schema).await;
1642
1643 self.remote_backend
1644 .generate_stream(
1645 &schema,
1646 &req.prompt,
1647 req.context.as_deref(),
1648 req.params.temperature,
1649 req.params.max_tokens,
1650 req.tools.as_deref(),
1651 req.images.as_deref(),
1652 req.params.tool_choice.as_deref(),
1653 req.params.parallel_tool_calls,
1654 req.response_format.as_ref(),
1655 )
1656 .await
1657 } else {
1658 let schema =
1659 schema.ok_or_else(|| InferenceError::ModelNotFound(decision.model_id.clone()))?;
1660 let (tx, rx) = tokio::sync::mpsc::channel(64);
1661
1662 #[cfg(any(
1668 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
1669 all(target_os = "ios", target_arch = "aarch64")
1670 ))]
1671 {
1672 if schema.is_foundation_models() {
1673 let prompt = req.prompt.clone();
1674 let instructions = req.context.clone();
1675 let max_tokens = req.params.max_tokens as u32;
1676 let temperature = req.params.temperature;
1677 let tx_clone = tx.clone();
1678 tokio::task::spawn_blocking(move || {
1679 let accum = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
1686 let accum_cb = accum.clone();
1687 let cb = crate::backend::foundation_models::StreamCallback::new(
1688 move |delta: &str| {
1689 if let Ok(mut g) = accum_cb.lock() {
1690 g.push_str(delta);
1691 }
1692 tx_clone
1693 .blocking_send(stream::StreamEvent::TextDelta(
1694 delta.to_string(),
1695 ))
1696 .is_ok()
1697 },
1698 );
1699 let result = crate::backend::foundation_models::stream(
1700 &prompt,
1701 instructions.as_deref(),
1702 max_tokens,
1703 temperature as f32,
1704 cb,
1705 );
1706 let final_text = accum.lock().map(|g| g.clone()).unwrap_or_default();
1707 let _ = tx.blocking_send(stream::StreamEvent::Done {
1708 text: final_text,
1709 tool_calls: vec![],
1710 });
1711 result
1712 });
1713 return Ok(rx);
1714 }
1715 }
1716
1717 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1719 {
1720 if !schema.is_mlx() {
1723 return Err(InferenceError::InferenceFailed(format!(
1724 "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1725 schema.id
1726 )));
1727 }
1728 let backend = self.ensure_mlx_backend(&schema).await?;
1729 let model_id = schema.id.clone();
1730 let cache = Arc::clone(&self.mlx_backends);
1731 tokio::task::spawn_blocking(move || {
1737 let _ = Self::stream_local_mlx(backend, cache, model_id, req, tx);
1738 });
1739 return Ok(rx);
1740 }
1741
1742 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1743 {
1744 self.ensure_backend(&schema.name).await?;
1745 let backend = self.backend.clone();
1746 tokio::spawn(async move {
1747 let _ = Self::stream_local_candle(backend, req, tx).await;
1748 });
1749 Ok(rx)
1750 }
1751 }
1752 }
1753
1754 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1755 async fn stream_local_candle(
1756 backend_lock: Arc<RwLock<Option<CandleBackend>>>,
1757 req: GenerateRequest,
1758 tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
1759 ) -> Result<(), InferenceError> {
1760 let mut write = backend_lock.write().await;
1761 let backend = write
1762 .as_mut()
1763 .ok_or_else(|| InferenceError::InferenceFailed("backend not initialized".into()))?;
1764 backend.clear_kv_cache();
1765
1766 let formatted = tasks::generate::apply_chat_template(
1767 &req.prompt,
1768 req.context.as_deref(),
1769 req.params.thinking,
1770 );
1771 let tokens = backend.encode(&formatted)?;
1772 let eos = backend.eos_token_id();
1773 let eos_alt = backend.token_id("<|im_end|>");
1774 let params = &req.params;
1775
1776 if tokens.is_empty() {
1777 let _ = tx
1778 .send(stream::StreamEvent::Done {
1779 text: String::new(),
1780 tool_calls: vec![],
1781 })
1782 .await;
1783 return Ok(());
1784 }
1785
1786 let max_ctx = backend.context_length().unwrap_or(32768);
1787 let headroom = params.max_tokens.min(max_ctx / 4);
1788 let max_prompt = max_ctx.saturating_sub(headroom);
1789 let tokens = if tokens.len() > max_prompt {
1790 tokens[tokens.len() - max_prompt..].to_vec()
1791 } else {
1792 tokens
1793 };
1794
1795 let mut generated = Vec::new();
1796 let logits = backend.forward(&tokens, 0)?;
1797 let mut next_token = tasks::generate::sample_token(&logits, params)?;
1798
1799 for _ in 0..params.max_tokens {
1800 if eos.map_or(false, |id| next_token == id)
1801 || eos_alt.map_or(false, |id| next_token == id)
1802 {
1803 break;
1804 }
1805
1806 generated.push(next_token);
1807 let delta = backend.decode(&[next_token])?;
1808 if !delta.is_empty()
1809 && tx
1810 .send(stream::StreamEvent::TextDelta(delta))
1811 .await
1812 .is_err()
1813 {
1814 return Ok(());
1815 }
1816
1817 if !params.stop.is_empty() {
1818 let text_so_far = backend.decode(&generated)?;
1819 if params.stop.iter().any(|s| text_so_far.contains(s)) {
1820 break;
1821 }
1822 }
1823
1824 let pos = tokens.len() + generated.len() - 1;
1825 let logits = backend.forward(&[next_token], pos)?;
1826 next_token = tasks::generate::sample_token(&logits, params)?;
1827 }
1828
1829 let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
1830 let _ = tx
1831 .send(stream::StreamEvent::Done {
1832 text,
1833 tool_calls: vec![],
1834 })
1835 .await;
1836 Ok(())
1837 }
1838
1839 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1843 fn stream_local_mlx(
1844 handle: backend_cache::CachedBackend<backend::MlxBackend>,
1845 cache: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
1846 model_id: String,
1847 req: GenerateRequest,
1848 tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
1849 ) -> Result<(), InferenceError> {
1850 let mut guard = handle.lock().map_err(|_| {
1851 InferenceError::InferenceFailed(format!("MLX backend mutex poisoned for {model_id}"))
1852 })?;
1853 let backend: &mut backend::MlxBackend = &mut *guard;
1854 backend.clear_kv_cache();
1855
1856 let formatted = tasks::generate::apply_chat_template(
1857 &req.prompt,
1858 req.context.as_deref(),
1859 req.params.thinking,
1860 );
1861 let tokens = backend.encode(&formatted)?;
1862 let eos = backend.eos_token_id();
1863 let eos_alt = backend.token_id("<|im_end|>");
1864 let params = &req.params;
1865
1866 if tokens.is_empty() {
1867 let _ = tx.blocking_send(stream::StreamEvent::Done {
1868 text: String::new(),
1869 tool_calls: vec![],
1870 });
1871 return Ok(());
1872 }
1873
1874 let max_ctx = backend.context_length();
1875 let headroom = params.max_tokens.min(max_ctx / 4);
1876 let max_prompt = max_ctx.saturating_sub(headroom);
1877 let tokens = if tokens.len() > max_prompt {
1878 tokens[tokens.len() - max_prompt..].to_vec()
1879 } else {
1880 tokens
1881 };
1882
1883 let mut generated = Vec::new();
1884 let logits = match Self::catch_mlx("stream prefill", || backend.forward(&tokens, 0)) {
1889 Ok(v) => v,
1890 Err(e) => {
1891 cache.invalidate(&model_id);
1892 return Err(e);
1893 }
1894 };
1895 let mut next_token = Self::sample_from_logits(&logits, params)?;
1896
1897 for _ in 0..params.max_tokens {
1898 if eos.map_or(false, |id| next_token == id)
1899 || eos_alt.map_or(false, |id| next_token == id)
1900 {
1901 break;
1902 }
1903
1904 generated.push(next_token);
1905 let delta = backend.decode(&[next_token])?;
1906 if !delta.is_empty()
1907 && tx
1908 .blocking_send(stream::StreamEvent::TextDelta(delta))
1909 .is_err()
1910 {
1911 return Ok(());
1912 }
1913
1914 if !params.stop.is_empty() {
1915 let text_so_far = backend.decode(&generated)?;
1916 if params.stop.iter().any(|s| text_so_far.contains(s)) {
1917 break;
1918 }
1919 }
1920
1921 let pos = tokens.len() + generated.len() - 1;
1922 let logits =
1923 match Self::catch_mlx("stream forward", || backend.forward(&[next_token], pos)) {
1924 Ok(v) => v,
1925 Err(e) => {
1926 cache.invalidate(&model_id);
1927 return Err(e);
1928 }
1929 };
1930 next_token = Self::sample_from_logits(&logits, params)?;
1931 }
1932
1933 let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
1934 let _ = tx.blocking_send(stream::StreamEvent::Done {
1935 text,
1936 tool_calls: vec![],
1937 });
1938 Ok(())
1939 }
1940
1941 pub async fn route_context_snapshot(
1943 &self,
1944 prompt: &str,
1945 workload: RoutingWorkload,
1946 has_tools: bool,
1947 has_vision: bool,
1948 ) -> AdaptiveRoutingDecision {
1949 let tracker = self.outcome_tracker.read().await;
1950 self.adaptive_router.route_context_aware(
1951 prompt,
1952 0,
1953 &self.unified_registry,
1954 &tracker,
1955 has_tools,
1956 has_vision,
1957 workload,
1958 )
1959 }
1960
1961 pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
1964 Ok(self.generate_tracked(req).await?.text)
1965 }
1966
1967 pub async fn tokenize(&self, model: &str, text: &str) -> Result<Vec<u32>, InferenceError> {
1979 self.assert_local_for_tokenize(model)?;
1980
1981 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1982 {
1983 let schema = self
1984 .unified_registry
1985 .get(model)
1986 .or_else(|| self.unified_registry.find_by_name(model))
1987 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1988 .clone();
1989 let handle = self.ensure_mlx_backend(&schema).await?;
1990 let guard = handle.lock().map_err(|_| {
1991 InferenceError::InferenceFailed(format!(
1992 "MLX backend mutex poisoned for {}",
1993 schema.id
1994 ))
1995 })?;
1996 return guard.tokenize_raw(text);
1997 }
1998
1999 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2000 {
2001 self.ensure_backend(model).await?;
2002 let read = self.backend.read().await;
2003 let backend = read.as_ref().ok_or_else(|| {
2004 InferenceError::InferenceFailed(
2005 "candle backend missing after ensure_backend".to_string(),
2006 )
2007 })?;
2008 backend.tokenize_raw(text)
2009 }
2010 }
2011
2012 pub async fn detokenize(&self, model: &str, tokens: &[u32]) -> Result<String, InferenceError> {
2014 self.assert_local_for_tokenize(model)?;
2015
2016 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2017 {
2018 let schema = self
2019 .unified_registry
2020 .get(model)
2021 .or_else(|| self.unified_registry.find_by_name(model))
2022 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
2023 .clone();
2024 let handle = self.ensure_mlx_backend(&schema).await?;
2025 let guard = handle.lock().map_err(|_| {
2026 InferenceError::InferenceFailed(format!(
2027 "MLX backend mutex poisoned for {}",
2028 schema.id
2029 ))
2030 })?;
2031 return guard.detokenize_raw(tokens);
2032 }
2033
2034 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2035 {
2036 self.ensure_backend(model).await?;
2037 let read = self.backend.read().await;
2038 let backend = read.as_ref().ok_or_else(|| {
2039 InferenceError::InferenceFailed(
2040 "candle backend missing after ensure_backend".to_string(),
2041 )
2042 })?;
2043 backend.detokenize_raw(tokens)
2044 }
2045 }
2046
2047 fn assert_local_for_tokenize(&self, model: &str) -> Result<(), InferenceError> {
2051 if let Some(schema) = self
2052 .unified_registry
2053 .get(model)
2054 .or_else(|| self.unified_registry.find_by_name(model))
2055 {
2056 if !schema.is_local() {
2057 return Err(InferenceError::UnsupportedMode {
2058 mode: "tokenize/detokenize",
2059 backend: "remote",
2060 reason: "remote provider tokenizer is not exposed by the runtime; \
2061 use a local model (Qwen3 GGUF / MLX) for tokenizer-correctness checks",
2062 });
2063 }
2064 }
2065 Ok(())
2067 }
2068
2069 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2075 fn catch_mlx<F, T>(context: &str, f: F) -> Result<T, InferenceError>
2076 where
2077 F: FnOnce() -> Result<T, InferenceError>,
2078 {
2079 std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)).map_err(|e| {
2080 InferenceError::InferenceFailed(format!("MLX panicked during {context}: {e:?}"))
2081 })?
2082 }
2083
2084 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2087 async fn generate_mlx(
2088 &self,
2089 req: GenerateRequest,
2090 model_id: &str,
2091 ) -> Result<(String, Option<u64>), InferenceError> {
2092 let start = std::time::Instant::now();
2093
2094 let schema = self
2095 .unified_registry
2096 .get(model_id)
2097 .cloned()
2098 .ok_or_else(|| {
2099 InferenceError::InferenceFailed(format!(
2100 "generate_mlx: unknown schema id {model_id}"
2101 ))
2102 })?;
2103 let handle = self.ensure_mlx_backend(&schema).await?;
2104 let mut guard = handle.lock().map_err(|_| {
2105 InferenceError::InferenceFailed(format!("MLX backend mutex poisoned for {model_id}"))
2106 })?;
2107 let backend: &mut backend::MlxBackend = &mut *guard;
2108 backend.clear_kv_cache();
2109
2110 let formatted = tasks::generate::apply_chat_template(
2111 &req.prompt,
2112 req.context.as_deref(),
2113 req.params.thinking,
2114 );
2115 let tokens = backend.encode(&formatted)?;
2116 let eos = backend.eos_token_id();
2117 let eos_alt = backend.token_id("<|im_end|>");
2118 let params = &req.params;
2119
2120 if tokens.is_empty() {
2121 return Ok((String::new(), None));
2122 }
2123
2124 let max_ctx = backend.context_length();
2126 let headroom = params.max_tokens.min(max_ctx / 4);
2127 let max_prompt = max_ctx.saturating_sub(headroom);
2128 let tokens = if tokens.len() > max_prompt {
2129 tokens[tokens.len() - max_prompt..].to_vec()
2130 } else {
2131 tokens
2132 };
2133
2134 let mut generated = Vec::new();
2135
2136 let logits = match Self::catch_mlx("prefill", || backend.forward(&tokens, 0)) {
2142 Ok(v) => v,
2143 Err(e) => {
2144 drop(guard);
2145 self.mlx_backends.invalidate(model_id);
2146 return Err(e);
2147 }
2148 };
2149 let mut next_token = Self::sample_from_logits(&logits, params)?;
2150 let ttft_ms = Some(start.elapsed().as_millis() as u64);
2151
2152 for _ in 0..params.max_tokens {
2153 if eos.map_or(false, |id| next_token == id)
2154 || eos_alt.map_or(false, |id| next_token == id)
2155 {
2156 break;
2157 }
2158
2159 generated.push(next_token);
2160
2161 if !params.stop.is_empty() {
2162 let text_so_far = backend.decode(&generated)?;
2163 if params.stop.iter().any(|s| text_so_far.contains(s)) {
2164 break;
2165 }
2166 }
2167
2168 let pos = tokens.len() + generated.len() - 1;
2169 let logits = match Self::catch_mlx("forward", || backend.forward(&[next_token], pos)) {
2170 Ok(v) => v,
2171 Err(e) => {
2172 drop(guard);
2173 self.mlx_backends.invalidate(model_id);
2174 return Err(e);
2175 }
2176 };
2177 next_token = Self::sample_from_logits(&logits, params)?;
2178 }
2179
2180 let text = backend.decode(&generated)?;
2181 Ok((
2182 tasks::generate::strip_thinking(&text, params.thinking),
2183 ttft_ms,
2184 ))
2185 }
2186
2187 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2189 fn sample_from_logits(logits: &[f32], params: &GenerateParams) -> Result<u32, InferenceError> {
2190 if params.temperature <= 0.0 {
2191 let (idx, _) = logits
2193 .iter()
2194 .enumerate()
2195 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
2196 .ok_or_else(|| InferenceError::InferenceFailed("empty logits".into()))?;
2197 return Ok(idx as u32);
2198 }
2199
2200 let temp = params.temperature as f32;
2202 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
2203 let mut probs: Vec<f32> = logits
2204 .iter()
2205 .map(|&l| ((l - max_logit) / temp).exp())
2206 .collect();
2207 let sum: f32 = probs.iter().sum();
2208 for p in &mut probs {
2209 *p /= sum;
2210 }
2211
2212 if params.top_p < 1.0 {
2214 let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
2215 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
2216 let mut cumsum = 0.0;
2217 let mut cutoff_idx = indexed.len();
2218 for (i, &(_, p)) in indexed.iter().enumerate() {
2219 cumsum += p;
2220 if cumsum > params.top_p as f32 {
2221 cutoff_idx = i + 1;
2222 break;
2223 }
2224 }
2225 let allowed: std::collections::HashSet<usize> =
2227 indexed[..cutoff_idx].iter().map(|(i, _)| *i).collect();
2228 for (i, p) in probs.iter_mut().enumerate() {
2229 if !allowed.contains(&i) {
2230 *p = 0.0;
2231 }
2232 }
2233 let sum: f32 = probs.iter().sum();
2234 if sum > 0.0 {
2235 for p in &mut probs {
2236 *p /= sum;
2237 }
2238 }
2239 }
2240
2241 use rand::Rng;
2243 let mut rng = rand::rng();
2244 let r: f32 = rng.random();
2245 let mut cumsum = 0.0;
2246 for (i, &p) in probs.iter().enumerate() {
2247 cumsum += p;
2248 if cumsum >= r {
2249 return Ok(i as u32);
2250 }
2251 }
2252 Ok((probs.len() - 1) as u32)
2253 }
2254
2255 pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
2258 let instruction = req
2259 .instruction
2260 .as_deref()
2261 .unwrap_or("Retrieve relevant memory facts");
2262
2263 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2264 {
2265 let model_id = self.ensure_mlx_embedding_backend().await?;
2266 let schema = self
2267 .unified_registry
2268 .get(&model_id)
2269 .cloned()
2270 .ok_or_else(|| {
2271 InferenceError::InferenceFailed(format!("embed: unknown schema id {model_id}"))
2272 })?;
2273 let handle = self.ensure_mlx_backend(&schema).await?;
2274 let mut guard = handle.lock().map_err(|_| {
2275 InferenceError::InferenceFailed(format!(
2276 "MLX embedding backend mutex poisoned for {model_id}"
2277 ))
2278 })?;
2279 let backend: &mut backend::MlxBackend = &mut *guard;
2280
2281 let mut results = Vec::with_capacity(req.texts.len());
2282 for text in &req.texts {
2283 let embedding = if req.is_query {
2284 backend.embed_query(text, instruction)?
2285 } else {
2286 backend.embed_one(text)?
2287 };
2288 results.push(embedding);
2289 }
2290 return Ok(results);
2291 }
2292
2293 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2294 {
2295 self.ensure_embedding_backend().await?;
2296 let mut write = self.embedding_backend.write().await;
2297 let backend = write.as_mut().unwrap();
2298
2299 let mut results = Vec::with_capacity(req.texts.len());
2300 for text in &req.texts {
2301 let embedding = if req.is_query {
2302 backend.embed_query(text, instruction)?
2303 } else {
2304 backend.embed_one(text)?
2305 };
2306 results.push(embedding);
2307 }
2308 Ok(results)
2309 }
2310 }
2311
2312 pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResult, InferenceError> {
2342 if req.documents.is_empty() {
2343 return Ok(RerankResult {
2344 ranked: Vec::new(),
2345 model_used: None,
2346 });
2347 }
2348
2349 let model_name = match req.model.clone() {
2350 Some(m) => m,
2351 None => self
2352 .preferred_model_for_capability(ModelCapability::Rerank)
2353 .map(str::to_string)
2354 .ok_or_else(|| {
2355 InferenceError::InferenceFailed(
2356 "no reranker model available — pull a Qwen3-Reranker model first".into(),
2357 )
2358 })?,
2359 };
2360
2361 let schema = self
2362 .unified_registry
2363 .find_by_name(&model_name)
2364 .or_else(|| self.unified_registry.get(&model_name))
2365 .cloned()
2366 .ok_or_else(|| {
2367 InferenceError::InferenceFailed(format!(
2368 "rerank: unknown reranker model {model_name}"
2369 ))
2370 })?;
2371 if !schema.has_capability(ModelCapability::Rerank) {
2372 return Err(InferenceError::InferenceFailed(format!(
2373 "model {} does not declare the Rerank capability",
2374 schema.name
2375 )));
2376 }
2377
2378 let instruction = req.instruction.as_deref().unwrap_or(
2379 "Given a web search query, retrieve relevant passages that answer the query",
2380 );
2381
2382 let mut scored: Vec<RerankedDocument> = Vec::with_capacity(req.documents.len());
2383 for (idx, doc) in req.documents.iter().enumerate() {
2384 let prompt = rerank_prompt(instruction, &req.query, doc);
2385 let gen_req = GenerateRequest {
2386 prompt,
2387 model: Some(schema.id.clone()),
2388 params: tasks::generate::GenerateParams {
2389 temperature: 0.0,
2390 max_tokens: 3,
2394 thinking: tasks::generate::ThinkingMode::Off,
2395 ..Default::default()
2396 },
2397 context: None,
2398 tools: None,
2399 images: None,
2400 messages: None,
2401 cache_control: false,
2402 response_format: None,
2403 intent: None,
2404 };
2405 let out = self.generate(gen_req).await?;
2406 let score = score_from_rerank_output(&out, &schema.name);
2407 scored.push(RerankedDocument {
2408 index: idx,
2409 score,
2410 document: doc.clone(),
2411 });
2412 }
2413
2414 scored.sort_by(|a, b| {
2417 b.score
2418 .partial_cmp(&a.score)
2419 .unwrap_or(std::cmp::Ordering::Equal)
2420 .then_with(|| a.index.cmp(&b.index))
2421 });
2422 if let Some(n) = req.top_n {
2423 scored.truncate(n);
2424 }
2425
2426 Ok(RerankResult {
2427 ranked: scored,
2428 model_used: Some(schema.name),
2429 })
2430 }
2431
2432 pub async fn ground(&self, req: GroundRequest) -> Result<GroundResult, InferenceError> {
2442 let model_name = match req.model.clone() {
2443 Some(m) => m,
2444 None => self
2445 .preferred_model_for_capability(ModelCapability::Grounding)
2446 .map(str::to_string)
2447 .ok_or_else(|| {
2448 InferenceError::InferenceFailed(
2449 "no grounding-capable model available — pull a Qwen2.5-VL model first"
2450 .into(),
2451 )
2452 })?,
2453 };
2454
2455 let gen_req = GenerateRequest {
2456 prompt: req.prompt.clone(),
2457 model: Some(model_name),
2458 params: GenerateParams::default(),
2459 context: None,
2460 tools: None,
2461 images: Some(vec![req.image.clone()]),
2462 messages: None,
2463 cache_control: false,
2464 response_format: None,
2465 intent: None,
2466 };
2467 let result = self.generate_tracked(gen_req).await?;
2468 Ok(GroundResult {
2469 boxes: result.bounding_boxes,
2470 raw_text: result.text,
2471 model_used: Some(result.model_used),
2472 })
2473 }
2474
2475 pub async fn classify(
2478 &self,
2479 req: ClassifyRequest,
2480 ) -> Result<Vec<ClassifyResult>, InferenceError> {
2481 let model = match req.model.clone().or_else(|| {
2482 self.preferred_model_for_capability(ModelCapability::Classify)
2483 .map(str::to_string)
2484 }) {
2485 Some(m) => m,
2486 None => {
2487 let m = self.router.route_small(&self.registry);
2488 debug!(model = %m, "auto-routed classify request");
2489 m
2490 }
2491 };
2492
2493 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2496 {
2497 return self.classify_via_generate(req, &model).await;
2498 }
2499
2500 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2501 {
2502 self.ensure_backend(&model).await?;
2503 let mut write = self.backend.write().await;
2504 let backend = write.as_mut().unwrap();
2505 tasks::classify::classify(backend, req).await
2506 }
2507 }
2508
/// Apple Silicon MLX classification path: asks `model` to pick one of
/// `req.labels` with a plain generation prompt, then scores every label by
/// how well it matches the answer.
///
/// The answer and each label are normalised identically before comparing:
/// lower-cased, and stripped of surrounding whitespace, quotes, backticks,
/// `*` and `.,;:!`. The raw score per label is:
/// - 1.0 if the normalised answer equals the normalised label;
/// - 0.8 if the answer contains the label;
/// - otherwise 0.5 × the fraction of the label's words found in the answer;
/// - 0.0 for a label that is blank after normalisation.
///
/// Results are sorted by descending score. When any label scored above zero,
/// the scores are rescaled to sum to 1. The returned `label` is always the
/// caller's original, un-normalised string.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn classify_via_generate(
    &self,
    req: ClassifyRequest,
    model: &str,
) -> Result<Vec<ClassifyResult>, InferenceError> {
    /// Case-fold `text` and strip surrounding whitespace, quotes, markdown
    /// emphasis and sentence punctuation, so that `"Positive."` and
    /// `**positive**` both equal the label `Positive`.
    fn normalize(text: &str) -> String {
        text.trim_matches(|c: char| {
            c.is_whitespace()
                || matches!(c, '.' | ',' | ';' | ':' | '!' | '"' | '\'' | '`' | '*')
        })
        .to_lowercase()
    }

    let labels_str = req
        .labels
        .iter()
        .enumerate()
        .map(|(i, l)| format!("{}. {}", i + 1, l))
        .collect::<Vec<_>>()
        .join("\n");

    let prompt = format!(
        "Classify the following text into one of these categories:\n\
         {labels_str}\n\n\
         Text: {}\n\n\
         Respond with ONLY the category name, nothing else.",
        req.text
    );

    let gen_req = GenerateRequest {
        prompt,
        model: Some(model.to_string()),
        params: tasks::generate::GenerateParams {
            temperature: 0.0,
            max_tokens: 32,
            thinking: tasks::generate::ThinkingMode::Off,
            ..Default::default()
        },
        context: None,
        tools: None,
        images: None,
        messages: None,
        cache_control: false,
        response_format: None,
        intent: None,
    };

    let response = self.generate(gen_req).await?;
    let response_norm = normalize(&response);

    let mut results: Vec<ClassifyResult> = req
        .labels
        .iter()
        .map(|label| {
            let label_norm = normalize(label);
            let score = if label_norm.is_empty() {
                // `contains("")` is always true, so without this guard a
                // blank label would collect a spurious 0.8.
                0.0
            } else if response_norm == label_norm {
                1.0
            } else if response_norm.contains(label_norm.as_str()) {
                0.8
            } else {
                // `label_norm` is non-empty and has no leading or trailing
                // whitespace, so it contains at least one word and the
                // division is well-defined.
                let label_words: Vec<&str> = label_norm.split_whitespace().collect();
                let matches = label_words
                    .iter()
                    .filter(|w| response_norm.contains(**w))
                    .count();
                0.5 * (matches as f64 / label_words.len() as f64)
            };
            ClassifyResult {
                label: label.clone(),
                score,
            }
        })
        .collect();

    results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Rescale to a distribution only when something matched; an all-zero
    // result is returned as-is rather than divided by zero.
    let total: f64 = results.iter().map(|r| r.score).sum();
    if total > 0.0 {
        for r in &mut results {
            r.score /= total;
        }
    }

    Ok(results)
}
2598
2599 pub async fn transcribe(
2601 &self,
2602 req: TranscribeRequest,
2603 ) -> Result<TranscribeResult, InferenceError> {
2604 let candidates =
2605 self.speech_candidates(ModelCapability::SpeechToText, req.model.as_deref())?;
2606 let mut last_error = None;
2607
2608 for schema in candidates {
2609 let result = match &schema.source {
2610 ModelSource::Mlx { .. } => self.transcribe_local_mlx(&schema, &req).await,
2611 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2612 self.transcribe_elevenlabs(&schema, &req).await
2613 }
2614 _ => Err(InferenceError::InferenceFailed(format!(
2615 "speech-to-text not implemented for model source: {}",
2616 schema.id
2617 ))),
2618 };
2619
2620 match result {
2621 Ok(result) => return Ok(result),
2622 Err(err) => last_error = Some(err),
2623 }
2624 }
2625
2626 Err(last_error.unwrap_or_else(|| {
2627 InferenceError::InferenceFailed("no speech-to-text models available".into())
2628 }))
2629 }
2630
2631 pub async fn synthesize(
2633 &self,
2634 req: SynthesizeRequest,
2635 ) -> Result<SynthesizeResult, InferenceError> {
2636 let candidates =
2637 self.speech_candidates(ModelCapability::TextToSpeech, req.model.as_deref())?;
2638 let mut last_error = None;
2639
2640 for schema in candidates {
2641 let result = match &schema.source {
2642 ModelSource::Mlx { .. } => self.synthesize_local_mlx(&schema, &req).await,
2643 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2644 self.synthesize_elevenlabs(&schema, &req).await
2645 }
2646 _ => Err(InferenceError::InferenceFailed(format!(
2647 "text-to-speech not implemented for model source: {}",
2648 schema.id
2649 ))),
2650 };
2651
2652 match result {
2653 Ok(result) => return Ok(result),
2654 Err(err) => last_error = Some(err),
2655 }
2656 }
2657
2658 Err(last_error.unwrap_or_else(|| {
2659 InferenceError::InferenceFailed("no text-to-speech models available".into())
2660 }))
2661 }
2662
2663 pub async fn generate_image(
2665 &self,
2666 req: GenerateImageRequest,
2667 ) -> Result<GenerateImageResult, InferenceError> {
2668 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2675 {
2676 use crate::backend::external_flux;
2677 let backend =
2678 std::env::var("CAR_IMAGE_BACKEND").unwrap_or_else(|_| "native".to_string());
2679 let use_external = match backend.as_str() {
2680 "external" => true,
2681 "native" => false,
2682 _ => external_flux::is_available() && backend == "auto-external",
2684 };
2685 if use_external {
2686 tracing::info!(
2687 "routing image generation to external mflux \
2688 (set CAR_IMAGE_BACKEND=native to use the Rust port)"
2689 );
2690 let mut req = req;
2691 req.model = self.resolve_external_hf_repo(
2692 req.model.as_deref(),
2693 ModelCapability::ImageGeneration,
2694 );
2695 return external_flux::generate_image(&req);
2696 }
2697 tracing::info!("using native Rust MLX Flux backend");
2698 }
2699
2700 let candidates = self
2701 .media_generation_candidates(ModelCapability::ImageGeneration, req.model.as_deref())?;
2702 let mut last_error = None;
2703
2704 for schema in candidates {
2705 let result = match &schema.source {
2706 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2707 ModelSource::Mlx { .. } => self.generate_image_native_mlx(&schema, &req).await,
2708 _ => Err(InferenceError::InferenceFailed(format!(
2709 "image generation not implemented for model source: {}",
2710 schema.id
2711 ))),
2712 };
2713
2714 match result {
2715 Ok(result) => return Ok(result),
2716 Err(err) => last_error = Some(err),
2717 }
2718 }
2719
2720 Err(last_error.unwrap_or_else(|| {
2721 InferenceError::InferenceFailed("no image generation models available".into())
2722 }))
2723 }
2724
2725 pub async fn generate_image_batch(
2741 &self,
2742 req: GenerateImageRequest,
2743 ) -> Result<Vec<GenerateImageResult>, InferenceError> {
2744 let count = req.variant_count.unwrap_or(1).max(1);
2745 if count == 1 {
2746 return self.generate_image(req).await.map(|r| vec![r]);
2747 }
2748 let base_seed = req.seed.unwrap_or(0);
2749 let mut results = Vec::with_capacity(count as usize);
2750 for i in 0..count {
2751 let mut variant_req = req.clone();
2756 variant_req.seed = Some(base_seed.wrapping_add(i as u64));
2757 variant_req.variant_count = Some(1);
2761 results.push(self.generate_image(variant_req).await?);
2762 }
2763 Ok(results)
2764 }
2765
/// Run `req` on the in-process MLX Flux backend for `schema`.
///
/// The model files are ensured locally first. The loaded backend is fetched
/// from (or inserted into) `flux_cache`, keyed by `schema.id` and sized with
/// `estimate_model_size`. Generation itself is a synchronous call, so it runs
/// on tokio's blocking pool.
///
/// # Errors
/// Errors from `ensure_local` / the backend load / `generate`, plus
/// `InferenceFailed` for a poisoned backend mutex or a failed join of the
/// blocking task.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_image_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateImageRequest,
) -> Result<GenerateImageResult, InferenceError> {
    let weights_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let estimated_size = backend_cache::estimate_model_size(&weights_dir);
    let flux_cache = Arc::clone(&self.flux_cache);
    let cache_key = schema.id.clone();
    let flux = flux_cache.get_or_load(&cache_key, estimated_size, || {
        backend::mlx_flux::FluxBackend::load(&weights_dir)
    })?;

    let owned_req = req.clone();
    let job = move || -> Result<GenerateImageResult, InferenceError> {
        let mut pipeline = flux.lock().map_err(|_| {
            InferenceError::InferenceFailed("flux backend mutex poisoned".into())
        })?;
        pipeline.generate(&owned_req)
    };
    tokio::task::spawn_blocking(job)
        .await
        .map_err(|e| InferenceError::InferenceFailed(format!("flux task join: {e}")))?
}
2793
/// Run `req` on the in-process MLX LTX video backend for `schema`.
///
/// The model files are ensured locally first. The loaded backend is fetched
/// from (or inserted into) `ltx_cache`, keyed by `schema.id` and sized with
/// `estimate_model_size`. Generation itself is a synchronous call, so it runs
/// on tokio's blocking pool.
///
/// # Errors
/// Errors from `ensure_local` / the backend load / `generate`, plus
/// `InferenceFailed` for a poisoned backend mutex or a failed join of the
/// blocking task.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_video_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateVideoRequest,
) -> Result<GenerateVideoResult, InferenceError> {
    let weights_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let estimated_size = backend_cache::estimate_model_size(&weights_dir);
    let ltx_cache = Arc::clone(&self.ltx_cache);
    let cache_key = schema.id.clone();
    let ltx = ltx_cache.get_or_load(&cache_key, estimated_size, || {
        backend::mlx_ltx::LtxBackend::load(&weights_dir)
    })?;

    let owned_req = req.clone();
    let job = move || -> Result<GenerateVideoResult, InferenceError> {
        let mut pipeline = ltx.lock().map_err(|_| {
            InferenceError::InferenceFailed("ltx backend mutex poisoned".into())
        })?;
        pipeline.generate(&owned_req)
    };
    tokio::task::spawn_blocking(job)
        .await
        .map_err(|e| InferenceError::InferenceFailed(format!("ltx task join: {e}")))?
}
2818
2819 pub async fn generate_video(
2821 &self,
2822 req: GenerateVideoRequest,
2823 ) -> Result<GenerateVideoResult, InferenceError> {
2824 if let Err(msg) = req.validate() {
2827 return Err(InferenceError::InferenceFailed(format!(
2828 "invalid GenerateVideoRequest: {}",
2829 msg
2830 )));
2831 }
2832 let requires_audio_conditioning = req.requires_audio_passthrough_opt_in();
2833 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2838 {
2839 use crate::backend::external_ltx;
2840 let backend =
2845 std::env::var("CAR_VIDEO_BACKEND").unwrap_or_else(|_| "native".to_string());
2846 let use_external = match backend.as_str() {
2847 "external" => true,
2848 "native" => false,
2849 "auto-external" => external_ltx::is_available(),
2853 _ => false,
2854 };
2855 if use_external {
2856 tracing::info!(
2857 "CAR_VIDEO_BACKEND requested external LTX routing for LTX-family models"
2858 );
2859 } else {
2860 tracing::info!("using family-aware MLX video routing");
2861 }
2862 }
2863
2864 let candidates = self
2865 .media_generation_candidates(ModelCapability::VideoGeneration, req.model.as_deref())?;
2866 let mut last_error = None;
2867
2868 for schema in candidates {
2869 let result = match &schema.source {
2870 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2871 ModelSource::Mlx { hf_repo, .. } => {
2872 if crate::backend::external_mlx_video::is_wan_family(&schema) {
2873 match self.unified_registry.ensure_local(&schema.id).await {
2874 Ok(model_dir) => {
2875 crate::backend::external_mlx_video::generate_wan_video(
2876 &schema, &model_dir, &req,
2877 )
2878 }
2879 Err(err) => Err(err),
2880 }
2881 } else {
2882 let backend = std::env::var("CAR_VIDEO_BACKEND")
2883 .unwrap_or_else(|_| "native".to_string());
2884 let use_external_ltx = match backend.as_str() {
2885 "external" => true,
2886 "native" => false,
2887 "auto-external" => crate::backend::external_ltx::is_available(),
2888 _ => false,
2889 };
2890 let use_external_ltx = use_external_ltx || requires_audio_conditioning;
2891 if requires_audio_conditioning
2892 && !crate::backend::external_ltx::is_available()
2893 {
2894 return Err(InferenceError::InferenceFailed(
2895 "audio-reference video conditioning requires the external `ltx-2-mlx a2v` CLI on PATH"
2896 .to_string(),
2897 ));
2898 }
2899 if use_external_ltx {
2900 let mut req = req.clone();
2901 req.model = Some(hf_repo.clone());
2902 crate::backend::external_ltx::generate_video(&req)
2903 } else {
2904 self.generate_video_native_mlx(&schema, &req).await
2905 }
2906 }
2907 }
2908 _ => Err(InferenceError::InferenceFailed(format!(
2909 "video generation not implemented for model source: {}",
2910 schema.id
2911 ))),
2912 };
2913
2914 match result {
2915 Ok(result) => return Ok(result),
2916 Err(err) => last_error = Some(err),
2917 }
2918 }
2919
2920 Err(last_error.unwrap_or_else(|| {
2921 InferenceError::InferenceFailed("no video generation models available".into())
2922 }))
2923 }
2924
2925 pub fn list_models_unified(&self) -> Vec<ModelInfo> {
2927 self.unified_registry
2928 .list()
2929 .iter()
2930 .map(|m| ModelInfo::from(*m))
2931 .collect()
2932 }
2933
2934 pub fn available_model_upgrades(&self) -> Vec<ModelUpgrade> {
2936 self.unified_registry.available_upgrades()
2937 }
2938
2939 pub async fn check_upgrade_nudge(
2945 &self,
2946 inference_active: bool,
2947 ) -> (crate::nudge::NudgeDecision, crate::nudge::NudgeState) {
2948 let findings = self.detect_upgrades().await;
2949 let prefs = self.update_prefs();
2950 let state = crate::nudge::NudgeState::load_from(&crate::nudge::NudgeState::default_path());
2951 let now = std::time::SystemTime::now()
2952 .duration_since(std::time::UNIX_EPOCH)
2953 .map(|d| d.as_secs())
2954 .unwrap_or(0);
2955 let decision = crate::nudge::decide_nudge(
2956 &findings,
2957 &prefs,
2958 &state,
2959 now,
2960 crate::nudge::DEFAULT_THROTTLE_SECS,
2961 inference_active,
2962 );
2963 (decision, state)
2964 }
2965
2966 pub fn dismiss_upgrade_nudge(&self, dismiss_key: &str) -> Result<(), InferenceError> {
2969 let path = crate::nudge::NudgeState::default_path();
2970 let mut state = crate::nudge::NudgeState::load_from(&path);
2971 state.dismiss(dismiss_key);
2972 state.save_to(&path).map_err(InferenceError::InferenceFailed)
2973 }
2974
2975 pub async fn check_concierge(
2982 &self,
2983 inference_active: bool,
2984 ) -> (
2985 Vec<crate::concierge::ConciergeSuggestion>,
2986 crate::nudge::NudgeState,
2987 ) {
2988 let prefs = self.update_prefs();
2989 let state = crate::nudge::NudgeState::load_from(&crate::nudge::NudgeState::default_path());
2990 let hw = crate::hardware::HardwareInfo::detect();
2991 let schemas = self.list_schemas();
2992 let refs: Vec<&ModelSchema> = schemas.iter().collect();
2993 let now = std::time::SystemTime::now()
2994 .duration_since(std::time::UNIX_EPOCH)
2995 .map(|d| d.as_secs())
2996 .unwrap_or(0);
2997 let suggestions = crate::concierge::decide_concierge(
2998 &refs,
2999 &hw,
3000 crate::concierge::DEFAULT_WATCHED_USE_CASES,
3001 crate::intent::QualityTier::Balanced,
3002 &prefs,
3003 &state,
3004 now,
3005 crate::concierge::DEFAULT_CONCIERGE_THROTTLE_SECS,
3006 inference_active,
3007 );
3008 (suggestions, state)
3009 }
3010
3011 pub fn dismiss_concierge_suggestion(&self, dismiss_key: &str) -> Result<(), InferenceError> {
3016 self.dismiss_upgrade_nudge(dismiss_key)
3017 }
3018
3019 pub async fn detect_upgrades(&self) -> Vec<crate::upgrade::UpgradeFinding> {
3023 let prefs = self.update_prefs();
3024 let curated = self.unified_registry.available_upgrades();
3025 let schemas = self.list_schemas();
3026 let refs: Vec<&ModelSchema> = schemas.iter().collect();
3027 let probe = crate::upgrade::HuggingFaceProbe::new();
3028 let now = std::time::SystemTime::now()
3029 .duration_since(std::time::UNIX_EPOCH)
3030 .map(|d| d.as_secs())
3031 .unwrap_or(0);
3032 crate::upgrade::detect_upgrades(
3033 curated,
3034 &refs,
3035 &prefs,
3036 &probe,
3037 &crate::upgrade::UpgradeCache::default_path(),
3038 now,
3039 crate::upgrade::DEFAULT_TTL_SECS,
3040 )
3041 .await
3042 }
3043
3044 pub fn list_schemas(&self) -> Vec<ModelSchema> {
3047 self.unified_registry.list().into_iter().cloned().collect()
3048 }
3049
3050 pub fn list_models(&self) -> Vec<models::ModelInfo> {
3051 self.registry.list_models()
3052 }
3053
3054 pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
3056 self.pull_model_with_progress(name, &crate::download::ProgressSink::none())
3057 .await
3058 }
3059
3060 pub async fn pull_model_with_progress(
3064 &self,
3065 name: &str,
3066 sink: &crate::download::ProgressSink,
3067 ) -> Result<std::path::PathBuf, InferenceError> {
3068 let schema = self
3069 .unified_registry
3070 .find_by_name(name)
3071 .or_else(|| self.unified_registry.get(name))
3072 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
3073 self.unified_registry
3074 .ensure_local_with_progress(&schema.id, sink)
3075 .await
3076 }
3077
3078 pub fn update_prefs(&self) -> crate::update_prefs::UpdatePreferences {
3083 let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
3084 crate::update_prefs::UpdatePreferences::load_effective(&cwd).unwrap_or_default()
3085 }
3086
3087 pub fn set_update_prefs(
3089 &self,
3090 prefs: &crate::update_prefs::UpdatePreferences,
3091 ) -> Result<(), InferenceError> {
3092 prefs
3093 .save()
3094 .map_err(InferenceError::InferenceFailed)
3095 }
3096
3097 pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
3099 let schema = self
3100 .unified_registry
3101 .get(name)
3102 .or_else(|| {
3103 self.unified_registry
3104 .list()
3105 .into_iter()
3106 .find(|schema| schema.name.eq_ignore_ascii_case(name))
3107 })
3108 .or_else(|| self.unified_registry.find_by_name(name))
3109 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
3110 let model_dir = self.unified_registry.models_dir().join(&schema.name);
3111 if model_dir.exists() {
3112 std::fs::remove_dir_all(&model_dir)?;
3113 }
3114 match &schema.source {
3115 ModelSource::Mlx { hf_repo, .. } => {
3116 remove_huggingface_repo_cache(hf_repo)?;
3117 }
3118 ModelSource::Local {
3119 hf_repo,
3120 tokenizer_repo,
3121 ..
3122 } => {
3123 remove_huggingface_repo_cache(hf_repo)?;
3124 remove_huggingface_repo_cache(tokenizer_repo)?;
3125 }
3126 _ => {}
3127 }
3128 Ok(())
3129 }
3130
3131 pub fn register_model(&mut self, schema: ModelSchema) {
3133 self.unified_registry.register(schema);
3134 }
3135
3136 pub async fn discover_vllm_mlx_models(&mut self) -> usize {
3139 let config = vllm_mlx::VllmMlxConfig::default();
3140 if !config.auto_discover {
3141 return 0;
3142 }
3143 vllm_mlx::discover_and_register(&config, &mut self.unified_registry).await
3144 }
3145
3146 pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
3148 self.outcome_tracker.clone()
3149 }
3150
3151 async fn auto_save_outcomes(&self) {
3153 if let Err(e) = self.save_outcomes().await {
3154 tracing::debug!("auto-save outcomes failed: {}", e);
3155 }
3156 if let Err(e) = self.save_key_pool_stats().await {
3157 tracing::debug!("auto-save key pool stats failed: {}", e);
3158 }
3159 }
3160
3161 pub async fn save_outcomes(&self) -> Result<(), std::io::Error> {
3163 let tracker = self.outcome_tracker.read().await;
3164 let path = self.config.models_dir.join("outcome_profiles.json");
3165 tracker.save_to_file(&path)
3166 }
3167
3168 pub async fn save_key_pool_stats(&self) -> Result<(), std::io::Error> {
3170 let path = self.config.models_dir.join("key_pool_stats.json");
3171 self.remote_backend.key_pool.save_stats(&path).await
3172 }
3173
3174 pub async fn key_pool_stats(
3176 &self,
3177 ) -> std::collections::HashMap<String, Vec<key_pool::KeyStats>> {
3178 self.remote_backend.key_pool.all_stats().await
3179 }
3180
3181 pub async fn export_profiles(&self) -> Vec<ModelProfile> {
3183 let tracker = self.outcome_tracker.read().await;
3184 tracker.export_profiles()
3185 }
3186
3187 pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
3189 let mut tracker = self.outcome_tracker.write().await;
3190 tracker.import_profiles(profiles);
3191 }
3192
3193 pub async fn prepare_speech_runtime(&self) -> Result<PathBuf, InferenceError> {
3196 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3197 {
3198 Ok(self.config.models_dir.clone())
3200 }
3201 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3202 {
3203 Ok(self.ensure_speech_runtime().await?.root)
3204 }
3205 }
3206
3207 pub fn set_speech_policy(&mut self, policy: SpeechPolicy) {
3209 self.speech_policy = policy;
3210 }
3211
3212 pub fn set_routing_config(&mut self, config: RoutingConfig) {
3213 self.adaptive_router.set_config(config);
3214 }
3215
3216 pub async fn install_curated_speech(
3218 &mut self,
3219 ) -> Result<Vec<SpeechInstallReport>, InferenceError> {
3220 let _runtime_root = self.prepare_speech_runtime().await?;
3221 let schemas = self.list_schemas();
3222 let mut repos = Vec::new();
3223 for schema in &schemas {
3224 if !schema.is_mlx() || !schema.tags.iter().any(|tag| tag == "speech") {
3225 continue;
3226 }
3227 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3228 if !repos.iter().any(|existing: &String| existing == hf_repo) {
3229 repos.push(hf_repo.clone());
3230 }
3231 }
3232 }
3233
3234 let mut installed = Vec::new();
3235 for repo in repos {
3236 let (snapshot_path, files_downloaded) = download_hf_repo_snapshot(&repo).await?;
3237 let name = schemas
3238 .iter()
3239 .find(|schema| {
3240 matches!(&schema.source, ModelSource::Mlx { hf_repo, .. } if hf_repo == &repo)
3241 })
3242 .map(|schema| schema.name.clone())
3243 .unwrap_or_else(|| repo.clone());
3244 installed.push(SpeechInstallReport {
3245 name,
3246 hf_repo: repo,
3247 snapshot_path,
3248 files_downloaded,
3249 });
3250 }
3251
3252 self.unified_registry.refresh_availability();
3253 Ok(installed)
3254 }
3255
3256 pub fn speech_health(&self) -> SpeechHealthReport {
3258 let local_stt_default =
3259 self.speech_health_default_name(ModelCapability::SpeechToText, true, false);
3260 let local_tts_default =
3261 self.speech_health_default_name(ModelCapability::TextToSpeech, true, false);
3262 let remote_stt_default =
3263 self.speech_health_default_name(ModelCapability::SpeechToText, false, true);
3264 let remote_tts_default =
3265 self.speech_health_default_name(ModelCapability::TextToSpeech, false, true);
3266
3267 let mut local_models = Vec::new();
3268 let mut remote_models = Vec::new();
3269 for schema in self.list_schemas() {
3270 let capability = if schema.has_capability(ModelCapability::SpeechToText) {
3271 Some(ModelCapability::SpeechToText)
3272 } else if schema.has_capability(ModelCapability::TextToSpeech) {
3273 Some(ModelCapability::TextToSpeech)
3274 } else {
3275 None
3276 };
3277 let Some(capability) = capability else {
3278 continue;
3279 };
3280
3281 let selected_by_default = local_stt_default
3282 .as_ref()
3283 .is_some_and(|name| name == &schema.name)
3284 || local_tts_default
3285 .as_ref()
3286 .is_some_and(|name| name == &schema.name)
3287 || remote_stt_default
3288 .as_ref()
3289 .is_some_and(|name| name == &schema.name)
3290 || remote_tts_default
3291 .as_ref()
3292 .is_some_and(|name| name == &schema.name);
3293
3294 let health = SpeechModelHealth {
3295 id: schema.id.clone(),
3296 name: schema.name.clone(),
3297 provider: schema.provider.clone(),
3298 capability,
3299 is_local: schema.is_local(),
3300 available: schema.available,
3301 cached: speech_model_cached(&schema),
3302 selected_by_default,
3303 source: speech_model_source_label(&schema),
3304 };
3305 if schema.is_local() {
3306 local_models.push(health);
3307 } else {
3308 remote_models.push(health);
3309 }
3310 }
3311
3312 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3314 let runtime = SpeechRuntimeHealth {
3315 root: self.config.models_dir.clone(),
3316 installed: true,
3317 python: PathBuf::new(),
3318 stt_command: PathBuf::new(),
3319 tts_command: PathBuf::new(),
3320 configured_python: None,
3321 detected_python: None,
3322 };
3323
3324 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3325 let runtime = {
3326 let rt =
3327 SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3328 SpeechRuntimeHealth {
3329 root: rt.root.clone(),
3330 installed: rt.is_ready(),
3331 python: rt.python.clone(),
3332 stt_command: rt.stt_program.clone(),
3333 tts_command: rt.tts_program.clone(),
3334 configured_python: std::env::var("CAR_SPEECH_PYTHON")
3335 .ok()
3336 .filter(|value| !value.trim().is_empty()),
3337 detected_python: detect_speech_python(),
3338 }
3339 };
3340
3341 SpeechHealthReport {
3342 runtime,
3343 local_models,
3344 remote_models,
3345 elevenlabs_configured: car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY")
3346 .is_some(),
3347 prefer_local: self.speech_policy.prefer_local,
3348 allow_remote_fallback: self.speech_policy.allow_remote_fallback,
3349 preferred_local_stt: self.speech_policy.preferred_local_stt.clone(),
3350 preferred_local_tts: self.speech_policy.preferred_local_tts.clone(),
3351 preferred_remote_stt: self.speech_policy.preferred_remote_stt.clone(),
3352 preferred_remote_tts: self.speech_policy.preferred_remote_tts.clone(),
3353 local_stt_default,
3354 local_tts_default,
3355 remote_stt_default,
3356 remote_tts_default,
3357 }
3358 }
3359
3360 pub async fn model_health(&self) -> ModelHealthReport {
3363 let schemas = self.list_schemas();
3364 let total_models = schemas.len();
3365 let available_models = schemas.iter().filter(|schema| schema.available).count();
3366 let local_models = schemas.iter().filter(|schema| schema.is_local()).count();
3367 let remote_models = total_models.saturating_sub(local_models);
3368
3369 let defaults = vec![
3370 self.model_default_health(
3371 ModelCapability::Generate,
3372 self.preferred_model_for_capability(ModelCapability::Generate)
3373 .unwrap_or(&self.config.generation_model),
3374 ),
3375 self.model_default_health(
3376 ModelCapability::Embed,
3377 self.preferred_model_for_capability(ModelCapability::Embed)
3378 .unwrap_or(&self.config.embedding_model),
3379 ),
3380 self.model_default_health(
3381 ModelCapability::Classify,
3382 self.preferred_model_for_capability(ModelCapability::Classify)
3383 .unwrap_or(&self.config.classification_model),
3384 ),
3385 ];
3386
3387 let mut providers = std::collections::BTreeMap::new();
3388 for schema in &schemas {
3389 let entry =
3390 providers
3391 .entry(schema.provider.clone())
3392 .or_insert_with(|| ProviderAccumulator {
3393 configured: false,
3394 local_models: 0,
3395 remote_models: 0,
3396 available_models: 0,
3397 capabilities: std::collections::HashSet::new(),
3398 });
3399
3400 entry.configured |= model_source_configured(schema);
3401 if schema.is_local() {
3402 entry.local_models += 1;
3403 } else {
3404 entry.remote_models += 1;
3405 }
3406 if schema.available {
3407 entry.available_models += 1;
3408 }
3409 for capability in &schema.capabilities {
3410 entry.capabilities.insert(*capability);
3411 }
3412 }
3413
3414 let providers = providers
3415 .into_iter()
3416 .map(|(provider, acc)| ModelProviderHealth {
3417 provider,
3418 configured: acc.configured,
3419 local_models: acc.local_models,
3420 remote_models: acc.remote_models,
3421 available_models: acc.available_models,
3422 capabilities: sort_capabilities(acc.capabilities.into_iter().collect()),
3423 })
3424 .collect();
3425
3426 let capabilities = all_model_capabilities()
3427 .into_iter()
3428 .map(|capability| {
3429 let relevant: Vec<&ModelSchema> = schemas
3430 .iter()
3431 .filter(|schema| schema.has_capability(capability))
3432 .collect();
3433 let available: Vec<&ModelSchema> = relevant
3434 .iter()
3435 .copied()
3436 .filter(|schema| schema.available)
3437 .collect();
3438 ModelCapabilityHealth {
3439 capability,
3440 total_models: relevant.len(),
3441 available_models: available.len(),
3442 local_available_models: available
3443 .iter()
3444 .filter(|schema| schema.is_local())
3445 .count(),
3446 remote_available_models: available
3447 .iter()
3448 .filter(|schema| !schema.is_local())
3449 .count(),
3450 }
3451 })
3452 .collect();
3453
3454 let routing = self.routing_scenarios().await;
3455 let routing_config = self.adaptive_router.config().clone();
3456 let benchmark_priors = load_benchmark_prior_health(&self.config.models_dir, &schemas);
3457
3458 ModelHealthReport {
3459 total_models,
3460 available_models,
3461 local_models,
3462 remote_models,
3463 defaults,
3464 providers,
3465 capabilities,
3466 routing_prefer_local: routing_config.prefer_local,
3467 routing_quality_first_cold_start: routing_config.quality_first_cold_start,
3468 routing_min_observations: routing_config.min_observations,
3469 routing_bootstrap_min_task_observations: routing_config.bootstrap_min_task_observations,
3470 routing_bootstrap_quality_floor: routing_config.bootstrap_quality_floor,
3471 routing_quality_weight: routing_config.quality_weight,
3472 routing_latency_weight: routing_config.latency_weight,
3473 routing_cost_weight: routing_config.cost_weight,
3474 routing_scenarios: routing,
3475 benchmark_priors,
3476 speech: self.speech_health(),
3477 }
3478 }
3479
3480 async fn routing_scenarios(&self) -> Vec<RoutingScenarioHealth> {
3481 let tracker = self.outcome_tracker.read().await;
3482 let config = self.adaptive_router.config().clone();
3483 let scenarios = [
3484 (
3485 "interactive_text",
3486 "Summarize the benefits of local-first AI routing in two sentences.",
3487 "text",
3488 RoutingWorkload::Interactive,
3489 false,
3490 false,
3491 ),
3492 (
3493 "background_code",
3494 "Write a Python function named fibonacci(n) that returns the nth Fibonacci number.",
3495 "code",
3496 RoutingWorkload::Background,
3497 false,
3498 false,
3499 ),
3500 (
3501 "interactive_tool_use",
3502 "Use the provided weather tool to get the weather for Boston.",
3503 "tool_use",
3504 RoutingWorkload::Interactive,
3505 true,
3506 false,
3507 ),
3508 (
3509 "interactive_vision",
3510 "What is in this image? Answer in one word.",
3511 "vision",
3512 RoutingWorkload::Interactive,
3513 false,
3514 true,
3515 ),
3516 ];
3517
3518 scenarios
3519 .into_iter()
3520 .map(
3521 |(name, prompt, task_family, workload, has_tools, has_vision)| {
3522 let decision = self.adaptive_router.route_context_aware(
3523 prompt,
3524 0,
3525 &self.unified_registry,
3526 &tracker,
3527 has_tools,
3528 has_vision,
3529 workload,
3530 );
3531 let quality_first_cold_start = if has_tools || has_vision {
3532 config.quality_first_cold_start
3533 } else if task_family == "code"
3534 && matches!(workload, RoutingWorkload::Background)
3535 {
3536 false
3537 } else {
3538 config.quality_first_cold_start
3539 };
3540 RoutingScenarioHealth {
3541 name: name.to_string(),
3542 task_family: task_family.to_string(),
3543 workload,
3544 has_tools,
3545 has_vision,
3546 prefer_local: if task_family == "speech" {
3547 self.speech_policy.prefer_local
3548 } else {
3549 config.prefer_local
3550 },
3551 quality_first_cold_start,
3552 bootstrap_min_task_observations: config.bootstrap_min_task_observations,
3553 bootstrap_quality_floor: config.bootstrap_quality_floor,
3554 model_id: decision.model_id,
3555 model_name: decision.model_name,
3556 reason: decision.reason,
3557 strategy: decision.strategy,
3558 }
3559 },
3560 )
3561 .collect()
3562 }
3563
3564 pub async fn smoke_test_speech(
3566 &self,
3567 local: bool,
3568 remote: bool,
3569 ) -> Result<SpeechSmokeReport, InferenceError> {
3570 let mut report = SpeechSmokeReport::default();
3571
3572 if local {
3573 let tts = self
3574 .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
3575 .ok_or_else(|| {
3576 InferenceError::InferenceFailed(
3577 "no local text-to-speech model available".into(),
3578 )
3579 })?;
3580 let stt = self
3581 .preferred_speech_schema(ModelCapability::SpeechToText, true, false)
3582 .ok_or_else(|| {
3583 InferenceError::InferenceFailed(
3584 "no local speech-to-text model available".into(),
3585 )
3586 })?;
3587 report.local = Some(
3588 self.run_speech_smoke_path("local", &tts, &stt, "Testing CAR local speech path.")
3589 .await?,
3590 );
3591 } else {
3592 report.skipped.push("local".to_string());
3593 }
3594
3595 if remote {
3596 let tts = self
3597 .preferred_speech_schema(ModelCapability::TextToSpeech, false, true)
3598 .ok_or_else(|| {
3599 InferenceError::InferenceFailed(
3600 "no remote text-to-speech model available".into(),
3601 )
3602 })?;
3603 let stt = self
3604 .preferred_speech_schema(ModelCapability::SpeechToText, false, true)
3605 .ok_or_else(|| {
3606 InferenceError::InferenceFailed(
3607 "no remote speech-to-text model available".into(),
3608 )
3609 })?;
3610 report.remote = Some(
3611 self.run_speech_smoke_path("remote", &tts, &stt, "Testing CAR remote speech path.")
3612 .await?,
3613 );
3614 } else {
3615 report.skipped.push("remote".to_string());
3616 }
3617
3618 Ok(report)
3619 }
3620
3621 fn speech_candidates(
3622 &self,
3623 capability: ModelCapability,
3624 explicit: Option<&str>,
3625 ) -> Result<Vec<ModelSchema>, InferenceError> {
3626 if let Some(model) = explicit {
3627 let schema = self
3628 .unified_registry
3629 .get(model)
3630 .or_else(|| self.unified_registry.find_by_name(model))
3631 .cloned()
3632 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3633 if !schema.has_capability(capability) {
3634 return Err(InferenceError::InferenceFailed(format!(
3635 "model {} does not support {:?}",
3636 schema.name, capability
3637 )));
3638 }
3639 return Ok(vec![schema]);
3640 }
3641
3642 let mut candidates: Vec<ModelSchema> = self
3643 .unified_registry
3644 .query(&ModelFilter {
3645 capabilities: vec![capability],
3646 ..Default::default()
3647 })
3648 .into_iter()
3649 .cloned()
3650 .collect();
3651
3652 if candidates.is_empty() {
3653 return Err(InferenceError::InferenceFailed(format!(
3654 "no models registered for capability {:?}",
3655 capability
3656 )));
3657 }
3658
3659 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3660 if !self.speech_policy.allow_remote_fallback
3661 && candidates.iter().any(|model| model.is_local())
3662 {
3663 candidates.retain(|model| model.is_local());
3664 }
3665
3666 Ok(candidates)
3667 }
3668
3669 #[allow(dead_code)] fn resolve_external_hf_repo(
3675 &self,
3676 explicit: Option<&str>,
3677 capability: ModelCapability,
3678 ) -> Option<String> {
3679 let id = explicit?;
3680 let schema = self
3681 .unified_registry
3682 .get(id)
3683 .or_else(|| self.unified_registry.find_by_name(id))?;
3684 if !schema.has_capability(capability) {
3685 return Some(id.to_string());
3686 }
3687 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3688 return Some(hf_repo.clone());
3689 }
3690 Some(id.to_string())
3691 }
3692
3693 fn media_generation_candidates(
3694 &self,
3695 capability: ModelCapability,
3696 explicit: Option<&str>,
3697 ) -> Result<Vec<ModelSchema>, InferenceError> {
3698 if let Some(model) = explicit {
3699 let schema = self
3700 .unified_registry
3701 .get(model)
3702 .or_else(|| self.unified_registry.find_by_name(model))
3703 .cloned()
3704 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3705 if !schema.has_capability(capability) {
3706 return Err(InferenceError::InferenceFailed(format!(
3707 "model {} does not support {:?}",
3708 schema.name, capability
3709 )));
3710 }
3711 return Ok(vec![schema]);
3712 }
3713
3714 let mut candidates: Vec<ModelSchema> = self
3715 .unified_registry
3716 .query(&ModelFilter {
3717 capabilities: vec![capability],
3718 local_only: true,
3719 ..Default::default()
3720 })
3721 .into_iter()
3722 .cloned()
3723 .collect();
3724 candidates.sort_by_key(|schema| (!schema.available, schema.size_mb()));
3725 if candidates.is_empty() {
3726 return Err(InferenceError::InferenceFailed(format!(
3727 "no models registered for capability {:?}",
3728 capability
3729 )));
3730 }
3731 Ok(candidates)
3732 }
3733
3734 fn preferred_speech_schema(
3735 &self,
3736 capability: ModelCapability,
3737 local_only: bool,
3738 remote_only: bool,
3739 ) -> Option<ModelSchema> {
3740 let available_only = remote_only;
3741 let mut candidates: Vec<ModelSchema> = self
3742 .unified_registry
3743 .query(&ModelFilter {
3744 capabilities: vec![capability],
3745 available_only,
3746 ..Default::default()
3747 })
3748 .into_iter()
3749 .filter(|schema| {
3750 (!local_only || schema.is_local()) && (!remote_only || schema.is_remote())
3751 })
3752 .cloned()
3753 .collect();
3754 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3755 candidates.into_iter().next()
3756 }
3757
3758 fn speech_health_default_name(
3759 &self,
3760 capability: ModelCapability,
3761 local_only: bool,
3762 remote_only: bool,
3763 ) -> Option<String> {
3764 let preferred = match capability {
3765 ModelCapability::SpeechToText if local_only => {
3766 self.speech_policy.preferred_local_stt.as_ref()
3767 }
3768 ModelCapability::SpeechToText if remote_only => {
3769 self.speech_policy.preferred_remote_stt.as_ref()
3770 }
3771 ModelCapability::TextToSpeech if local_only => {
3772 self.speech_policy.preferred_local_tts.as_ref()
3773 }
3774 ModelCapability::TextToSpeech if remote_only => {
3775 self.speech_policy.preferred_remote_tts.as_ref()
3776 }
3777 _ => None,
3778 };
3779
3780 preferred
3781 .filter(|name| {
3782 self.unified_registry.list().iter().any(|schema| {
3783 schema.name == **name
3784 && schema.has_capability(capability)
3785 && (!local_only || schema.is_local())
3786 && (!remote_only || schema.is_remote())
3787 })
3788 })
3789 .cloned()
3790 .or_else(|| {
3791 self.preferred_speech_schema(capability, local_only, remote_only)
3792 .map(|schema| schema.name)
3793 })
3794 }
3795
3796 fn model_default_health(
3797 &self,
3798 capability: ModelCapability,
3799 configured_model: &str,
3800 ) -> ModelDefaultHealth {
3801 let schema = self
3802 .unified_registry
3803 .find_by_name(configured_model)
3804 .or_else(|| self.unified_registry.get(configured_model));
3805
3806 ModelDefaultHealth {
3807 capability,
3808 configured_model: configured_model.to_string(),
3809 available: schema.is_some_and(|model| model.available),
3810 is_local: schema.is_some_and(ModelSchema::is_local),
3811 provider: schema.map(|model| model.provider.clone()),
3812 }
3813 }
3814
3815 fn speech_sort_key(
3816 &self,
3817 capability: ModelCapability,
3818 model: &ModelSchema,
3819 ) -> (u8, u8, u8, u8, u64, u64) {
3820 let policy_preference = match capability {
3821 ModelCapability::SpeechToText if model.is_local() => {
3822 self.speech_policy.preferred_local_stt.as_ref()
3823 }
3824 ModelCapability::SpeechToText => self.speech_policy.preferred_remote_stt.as_ref(),
3825 ModelCapability::TextToSpeech if model.is_local() => {
3826 self.speech_policy.preferred_local_tts.as_ref()
3827 }
3828 ModelCapability::TextToSpeech => self.speech_policy.preferred_remote_tts.as_ref(),
3829 _ => None,
3830 };
3831 let local_rank = if self.speech_policy.prefer_local {
3832 if model.is_local() {
3833 0
3834 } else {
3835 1
3836 }
3837 } else if model.is_remote() {
3838 0
3839 } else {
3840 1
3841 };
3842 let availability_rank = if model.available {
3843 0
3844 } else if model.is_local() {
3845 1
3846 } else {
3847 2
3848 };
3849 let policy_rank: u8 = if policy_preference.is_some_and(|preferred| preferred == &model.name)
3850 {
3851 0
3852 } else {
3853 1
3854 };
3855 let speech_rank = match capability {
3856 ModelCapability::TextToSpeech => {
3857 if model.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
3858 0
3859 } else if model.name == "Kokoro-82M-bf16" {
3860 1
3861 } else if model.name == "Kokoro-82M-6bit" {
3862 2
3863 } else {
3864 3
3865 }
3866 }
3867 ModelCapability::SpeechToText => {
3868 if model.name == "Parakeet-TDT-0.6B-v3-MLX" {
3869 0
3870 } else {
3871 1
3872 }
3873 }
3874 _ => 0,
3875 };
3876 let latency_rank = model.performance.latency_p50_ms.unwrap_or(u64::MAX);
3877 let size_rank = model.cost.size_mb.unwrap_or(u64::MAX);
3878 (
3879 local_rank,
3880 availability_rank,
3881 policy_rank,
3882 speech_rank,
3883 latency_rank,
3884 size_rank,
3885 )
3886 }
3887
3888 async fn run_speech_smoke_path(
3889 &self,
3890 path: &str,
3891 tts: &ModelSchema,
3892 stt: &ModelSchema,
3893 text: &str,
3894 ) -> Result<SpeechSmokePathReport, InferenceError> {
3895 let work_dir = temp_work_dir(&format!("speech-smoke-{path}"))?;
3896 let audio_path = work_dir.join(format!("{path}.wav"));
3897 let synth = self
3898 .synthesize(SynthesizeRequest {
3899 text: text.to_string(),
3900 model: Some(tts.name.clone()),
3901 voice: default_speech_voice(tts),
3902 language: Some("en".to_string()),
3903 output_path: Some(audio_path.display().to_string()),
3904 ..SynthesizeRequest::default()
3905 })
3906 .await?;
3907 let transcript = self
3908 .transcribe(TranscribeRequest {
3909 audio_path: synth.audio_path.clone(),
3910 model: Some(stt.name.clone()),
3911 language: Some("en".to_string()),
3912 prompt: None,
3913 timestamps: false,
3914 })
3915 .await?;
3916
3917 Ok(SpeechSmokePathReport {
3918 path: path.to_string(),
3919 tts_model: synth.model_used.unwrap_or_else(|| tts.name.clone()),
3920 stt_model: transcript.model_used.unwrap_or_else(|| stt.name.clone()),
3921 audio_path: PathBuf::from(synth.audio_path),
3922 transcript: transcript.text,
3923 })
3924 }
3925
3926 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3927 async fn ensure_speech_runtime(&self) -> Result<SpeechRuntime, InferenceError> {
3928 let mut guard = self.speech_runtime.lock().await;
3929 if let Some(runtime) = guard.as_ref() {
3930 if runtime.is_ready() {
3931 return Ok(runtime.clone());
3932 }
3933 }
3934
3935 let runtime =
3936 SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3937 if !runtime.is_ready() {
3938 bootstrap_speech_runtime(&runtime).await?;
3939 }
3940 if !runtime.is_ready() {
3941 return Err(InferenceError::InferenceFailed(format!(
3942 "managed speech runtime is not ready at {}",
3943 runtime.root.display()
3944 )));
3945 }
3946
3947 *guard = Some(runtime.clone());
3948 Ok(runtime)
3949 }
3950
3951 async fn transcribe_local_mlx(
3952 &self,
3953 schema: &ModelSchema,
3954 req: &TranscribeRequest,
3955 ) -> Result<TranscribeResult, InferenceError> {
3956 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3958 {
3959 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3960 let parakeet = backend::mlx_parakeet::ParakeetBackend::load(&model_dir)?;
3961 let (text, words) = if req.timestamps {
3963 parakeet
3964 .transcribe_detailed(Path::new(&req.audio_path))
3965 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?
3966 } else {
3967 let t = parakeet
3968 .transcribe(Path::new(&req.audio_path))
3969 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?;
3970 (t, Vec::new())
3971 };
3972 return Ok(TranscribeResult {
3973 text,
3974 model_used: Some(schema.name.clone()),
3975 language: req.language.clone(),
3976 words,
3977 });
3978 }
3979
3980 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3982 {
3983 let runtime = self.ensure_speech_runtime().await?;
3984 let hf_repo = match &schema.source {
3985 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3986 _ => unreachable!(),
3987 };
3988 let output_dir = temp_work_dir("stt")?;
3989 let output_prefix = output_dir.join("transcript");
3990 let mut args = vec![
3991 "--model".to_string(),
3992 hf_repo,
3993 "--audio".to_string(),
3994 req.audio_path.clone(),
3995 "--output-path".to_string(),
3996 output_prefix.display().to_string(),
3997 "--format".to_string(),
3998 "json".to_string(),
3999 ];
4000 if let Some(language) = &req.language {
4001 args.push("--language".to_string());
4002 args.push(normalize_lang_code(language));
4003 }
4004 if let Some(prompt) = &req.prompt {
4005 args.push("--context".to_string());
4006 args.push(prompt.clone());
4007 }
4008 if req.timestamps {
4009 args.push("--verbose".to_string());
4010 }
4011
4012 let output = run_mlx_audio_command(&runtime, "stt.generate", &args).await?;
4013 let text = read_transcription_result(&output_prefix)?
4014 .or_else(|| extract_text_from_payload(&output.stdout))
4015 .ok_or_else(|| {
4016 InferenceError::InferenceFailed(format!(
4017 "mlx-audio transcription returned no text: {}",
4018 output.stderr
4019 ))
4020 })?;
4021
4022 Ok(TranscribeResult {
4023 text,
4024 model_used: Some(schema.name.clone()),
4025 language: req.language.clone(),
4026 words: Vec::new(),
4027 })
4028 }
4029 }
4030
4031 async fn synthesize_local_mlx(
4032 &self,
4033 schema: &ModelSchema,
4034 req: &SynthesizeRequest,
4035 ) -> Result<SynthesizeResult, InferenceError> {
4036 let requested = req.requested_advanced_controls();
4041 let repo_supports_advanced = match &schema.source {
4042 ModelSource::Mlx { hf_repo, .. } => hf_repo.to_ascii_lowercase().contains("qwen3-tts"),
4043 _ => false,
4044 };
4045 if !requested.is_empty() && !repo_supports_advanced {
4046 if req.strict_capabilities {
4047 return Err(InferenceError::InferenceFailed(format!(
4048 "model {name} does not support Qwen3-TTS advanced controls {requested:?}; \
4049 route to a Qwen3-TTS model or set strict_capabilities = false to degrade",
4050 name = schema.name,
4051 )));
4052 }
4053 tracing::warn!(
4054 model = %schema.name,
4055 fields = ?requested,
4056 "Qwen3-TTS advanced controls set on non-Qwen3-TTS backend — ignored \
4057 (set strict_capabilities=true to error instead)"
4058 );
4059 }
4060
4061 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
4063 {
4064 if repo_supports_advanced && !requested.is_empty() {
4070 if req.strict_capabilities {
4071 return Err(InferenceError::InferenceFailed(format!(
4072 "native MLX TTS backend does not yet implement Qwen3-TTS advanced \
4073 controls {requested:?}; run on non-Apple-Silicon to use the Python \
4074 mlx-audio fallback, or set strict_capabilities = false"
4075 )));
4076 }
4077 tracing::warn!(
4078 model = %schema.name,
4079 fields = ?requested,
4080 "Qwen3-TTS advanced controls are not yet implemented in the native MLX TTS \
4081 backend; synthesizing without cloning/voice-design"
4082 );
4083 }
4084 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
4085 let size = backend_cache::estimate_model_size(&model_dir);
4086 let cache = Arc::clone(&self.kokoro_cache);
4087 let key = schema.id.clone();
4088 let handle = cache.get_or_load(&key, size, || {
4089 backend::mlx_kokoro::KokoroBackend::load(&model_dir)
4090 })?;
4091
4092 let output_path = req.output_path.clone().unwrap_or_else(|| {
4093 let dir = std::env::temp_dir().join("car_tts");
4094 let _ = std::fs::create_dir_all(&dir);
4095 dir.join("output.wav").display().to_string()
4096 });
4097 let voice = req.voice.as_deref().unwrap_or("af_heart").to_string();
4098 let text = req.text.clone();
4099 let op = tokio::task::spawn_blocking(move || -> Result<PathBuf, InferenceError> {
4100 let mut guard = handle.lock().map_err(|_| {
4101 InferenceError::InferenceFailed("kokoro backend mutex poisoned".into())
4102 })?;
4103 guard
4104 .synthesize(&text, Some(&voice), Path::new(&output_path))
4105 .map_err(|e| InferenceError::InferenceFailed(format!("native TTS: {e}")))
4106 })
4107 .await
4108 .map_err(|e| InferenceError::InferenceFailed(format!("kokoro task join: {e}")))??;
4109
4110 let final_path =
4111 materialize_audio_output(&op, req.output_path.as_deref(), &req.format)?;
4112 return Ok(SynthesizeResult {
4113 audio_path: final_path.display().to_string(),
4114 media_type: media_type_for_format(&req.format),
4115 model_used: Some(schema.name.clone()),
4116 voice_used: req.voice.clone(),
4117 });
4118 }
4119
4120 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4122 {
4123 let runtime = self.ensure_speech_runtime().await?;
4124 let primary_hf_repo = match &schema.source {
4125 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
4126 _ => unreachable!(),
4127 };
4128 let (produced, model_used) = match self
4129 .synthesize_local_mlx_repo(&runtime, &primary_hf_repo, schema.name.as_str(), req)
4130 .await
4131 {
4132 Ok(result) => result,
4133 Err(primary_err)
4134 if primary_hf_repo == "mlx-community/Kokoro-82M-6bit"
4135 && kokoro_runtime_fallback_enabled() =>
4136 {
4137 let fallback_repo = "mlx-community/Kokoro-82M-bf16";
4138 let fallback_name = "Kokoro-82M-bf16";
4139 match self
4140 .synthesize_local_mlx_repo(&runtime, fallback_repo, fallback_name, req)
4141 .await
4142 {
4143 Ok(result) => result,
4144 Err(fallback_err) => {
4145 return Err(InferenceError::InferenceFailed(format!(
4146 "{primary_err}; fallback {fallback_name} also failed: {fallback_err}"
4147 )));
4148 }
4149 }
4150 }
4151 Err(err) => return Err(err),
4152 };
4153 let final_path =
4154 materialize_audio_output(&produced, req.output_path.as_deref(), &req.format)?;
4155
4156 Ok(SynthesizeResult {
4157 audio_path: final_path.display().to_string(),
4158 media_type: media_type_for_format(&req.format),
4159 model_used: Some(model_used),
4160 voice_used: req.voice.clone(),
4161 })
4162 }
4163 }
4164
4165 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4166 async fn synthesize_local_mlx_repo(
4167 &self,
4168 runtime: &SpeechRuntime,
4169 hf_repo: &str,
4170 model_name: &str,
4171 req: &SynthesizeRequest,
4172 ) -> Result<(PathBuf, String), InferenceError> {
4173 let output_dir = temp_work_dir("tts")?;
4174 let mut args = vec![
4175 "--model".to_string(),
4176 hf_repo.to_string(),
4177 "--text".to_string(),
4178 req.text.clone(),
4179 "--output_path".to_string(),
4180 output_dir.display().to_string(),
4181 ];
4182 if let Some(voice) = &req.voice {
4183 args.push("--voice".to_string());
4184 args.push(voice.clone());
4185 }
4186 if let Some(speed) = req.speed {
4187 args.push("--speed".to_string());
4188 args.push(speed.to_string());
4189 }
4190 let repo_lower = hf_repo.to_ascii_lowercase();
4191 if repo_lower.contains("kokoro") {
4192 args.push("--lang_code".to_string());
4193 args.push(kokoro_lang_code(req.language.as_deref()).to_string());
4194 } else if let Some(language) = &req.language {
4195 args.push("--lang_code".to_string());
4196 args.push(normalize_lang_code(language));
4197 }
4198
4199 if repo_lower.contains("qwen3-tts") {
4205 if let Some(ref_audio) = &req.reference_audio_path {
4206 args.push("--ref_audio".to_string());
4207 args.push(ref_audio.clone());
4208 }
4209 if let Some(ref_text) = &req.reference_text {
4210 args.push("--ref_text".to_string());
4211 args.push(ref_text.clone());
4212 }
4213 if let Some(instruct) = &req.voice_instruction {
4214 args.push("--instruct".to_string());
4215 args.push(instruct.clone());
4216 }
4217 }
4218
4219 let output = if repo_lower.contains("kokoro") {
4220 let device = std::env::var("CAR_SPEECH_KOKORO_DEVICE")
4221 .or_else(|_| std::env::var("CAR_SPEECH_MLX_DEVICE"))
4222 .unwrap_or_else(|_| "cpu".to_string());
4223 let extra_env = vec![
4224 ("MLX_DEVICE".to_string(), device),
4226 ("PYTORCH_ENABLE_MPS_FALLBACK".to_string(), "1".to_string()),
4228 ];
4229 run_mlx_audio_command_with_env(runtime, "tts.generate", &args, &extra_env).await?
4230 } else {
4231 run_mlx_audio_command(runtime, "tts.generate", &args).await?
4232 };
4233 let produced = find_audio_file(&output_dir)?.ok_or_else(|| {
4234 let hint = if repo_lower.contains("kokoro") {
4235 ". Kokoro models may crash on GPU — try CAR_SPEECH_KOKORO_DEVICE=cpu or use the default Qwen3-TTS model"
4236 } else {
4237 ""
4238 };
4239 InferenceError::InferenceFailed(format!(
4240 "mlx-audio synthesis produced no audio file: {}{}",
4241 output.stderr, hint
4242 ))
4243 })?;
4244 Ok((produced, model_name.to_string()))
4245 }
4246
4247 async fn transcribe_elevenlabs(
4248 &self,
4249 schema: &ModelSchema,
4250 req: &TranscribeRequest,
4251 ) -> Result<TranscribeResult, InferenceError> {
4252 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4253 let file_name = Path::new(&req.audio_path)
4254 .file_name()
4255 .and_then(|f| f.to_str())
4256 .unwrap_or("audio.wav")
4257 .to_string();
4258 let audio_bytes = tokio::fs::read(&req.audio_path).await?;
4259 let file_part = Part::bytes(audio_bytes).file_name(file_name);
4260 let mut form = Form::new()
4261 .text("model_id", schema.name.clone())
4262 .part("file", file_part);
4263 if let Some(language) = &req.language {
4264 form = form.text("language_code", language.clone());
4265 }
4266
4267 let resp = self
4268 .remote_backend
4269 .client
4270 .post(format!(
4271 "{}/v1/speech-to-text",
4272 endpoint.trim_end_matches('/')
4273 ))
4274 .header("xi-api-key", api_key)
4275 .multipart(form)
4276 .send()
4277 .await
4278 .map_err(|e| {
4279 InferenceError::InferenceFailed(format!("ElevenLabs STT request failed: {e}"))
4280 })?;
4281 let status = resp.status();
4282 let body = resp.text().await.map_err(|e| {
4283 InferenceError::InferenceFailed(format!("read ElevenLabs STT body: {e}"))
4284 })?;
4285 if !status.is_success() {
4286 return Err(InferenceError::InferenceFailed(format!(
4287 "ElevenLabs STT returned {status}: {body}"
4288 )));
4289 }
4290 let payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
4291 InferenceError::InferenceFailed(format!("parse ElevenLabs STT response: {e}"))
4292 })?;
4293 let text = payload
4294 .get("text")
4295 .and_then(|v| v.as_str())
4296 .map(str::to_string)
4297 .ok_or_else(|| {
4298 InferenceError::InferenceFailed("ElevenLabs STT response missing text".into())
4299 })?;
4300
4301 Ok(TranscribeResult {
4302 text,
4303 model_used: Some(schema.name.clone()),
4304 language: payload
4305 .get("language_code")
4306 .and_then(|v| v.as_str())
4307 .map(str::to_string),
4308 words: Vec::new(),
4309 })
4310 }
4311
4312 async fn synthesize_elevenlabs(
4313 &self,
4314 schema: &ModelSchema,
4315 req: &SynthesizeRequest,
4316 ) -> Result<SynthesizeResult, InferenceError> {
4317 let requested = req.requested_advanced_controls();
4321 if !requested.is_empty() {
4322 if req.strict_capabilities {
4323 return Err(InferenceError::InferenceFailed(format!(
4324 "ElevenLabs backend does not support Qwen3-TTS advanced controls \
4325 {requested:?}; route to a Qwen3-TTS model or set strict_capabilities = false"
4326 )));
4327 }
4328 tracing::warn!(
4329 model = %schema.name,
4330 fields = ?requested,
4331 "Qwen3-TTS advanced controls ignored by ElevenLabs backend"
4332 );
4333 }
4334 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4335 let voice_id = req
4336 .voice
4337 .clone()
4338 .unwrap_or_else(|| "JBFqnCBsd6RMkjVDRZzb".to_string());
4339 let output_format = elevenlabs_output_format(&req.format);
4340 let url = format!(
4341 "{}/v1/text-to-speech/{}?output_format={}",
4342 endpoint.trim_end_matches('/'),
4343 voice_id,
4344 output_format
4345 );
4346
4347 let mut body = serde_json::json!({
4348 "text": req.text,
4349 "model_id": schema.name,
4350 });
4351 if let Some(language) = &req.language {
4352 body["language_code"] = serde_json::Value::String(language.clone());
4353 }
4354
4355 let resp = self
4356 .remote_backend
4357 .client
4358 .post(url)
4359 .header("xi-api-key", api_key)
4360 .header("Content-Type", "application/json")
4361 .json(&body)
4362 .send()
4363 .await
4364 .map_err(|e| {
4365 InferenceError::InferenceFailed(format!("ElevenLabs TTS request failed: {e}"))
4366 })?;
4367 let status = resp.status();
4368 let audio = resp.bytes().await.map_err(|e| {
4369 InferenceError::InferenceFailed(format!("read ElevenLabs TTS body: {e}"))
4370 })?;
4371 if !status.is_success() {
4372 let err_body = String::from_utf8_lossy(&audio);
4373 return Err(InferenceError::InferenceFailed(format!(
4374 "ElevenLabs TTS returned {status}: {err_body}"
4375 )));
4376 }
4377
4378 let final_path = requested_or_temp_output(req.output_path.as_deref(), &req.format)?;
4379 ensure_parent_dir(&final_path)?;
4380 tokio::fs::write(&final_path, &audio).await?;
4381
4382 Ok(SynthesizeResult {
4383 audio_path: final_path.display().to_string(),
4384 media_type: media_type_for_format(&req.format),
4385 model_used: Some(schema.name.clone()),
4386 voice_used: Some(voice_id),
4387 })
4388 }
4389}
4390
4391#[derive(Default)]
4392struct ProviderAccumulator {
4393 configured: bool,
4394 local_models: usize,
4395 remote_models: usize,
4396 available_models: usize,
4397 capabilities: std::collections::HashSet<ModelCapability>,
4398}
4399
/// Captured, lossily UTF-8-decoded output of a successful subprocess run.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
struct CommandOutput {
    /// Standard error (diagnostics; often quoted in error messages).
    stderr: String,
    /// Standard output.
    stdout: String,
}
4408
4409#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4410#[derive(Debug, Clone)]
4411struct SpeechRuntime {
4412 root: PathBuf,
4413 python: PathBuf,
4414 stt_program: PathBuf,
4415 tts_program: PathBuf,
4416}
4417
4418#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4419impl SpeechRuntime {
4420 fn new(root: PathBuf) -> Self {
4421 let bin_dir = root.join("bin");
4422 Self {
4423 root,
4424 python: bin_dir.join("python"),
4425 stt_program: bin_dir.join("mlx_audio.stt.generate"),
4426 tts_program: bin_dir.join("mlx_audio.tts.generate"),
4427 }
4428 }
4429
4430 fn is_ready(&self) -> bool {
4431 self.python.exists() && self.stt_program.exists() && self.tts_program.exists()
4432 }
4433
4434 fn command_for(&self, subcommand: &str) -> Result<&Path, InferenceError> {
4435 match subcommand {
4436 "stt.generate" => Ok(&self.stt_program),
4437 "tts.generate" => Ok(&self.tts_program),
4438 _ => Err(InferenceError::InferenceFailed(format!(
4439 "unknown speech subcommand: {subcommand}"
4440 ))),
4441 }
4442 }
4443}
4444
4445#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4446async fn run_mlx_audio_command(
4447 runtime: &SpeechRuntime,
4448 subcommand: &str,
4449 args: &[String],
4450) -> Result<CommandOutput, InferenceError> {
4451 run_mlx_audio_command_with_env(runtime, subcommand, args, &[]).await
4452}
4453
4454#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4455async fn run_mlx_audio_command_with_env(
4456 runtime: &SpeechRuntime,
4457 subcommand: &str,
4458 args: &[String],
4459 envs: &[(String, String)],
4460) -> Result<CommandOutput, InferenceError> {
4461 let program = runtime.command_for(subcommand)?;
4462 let mut command = Command::new(program);
4463 command.args(args);
4464 for (key, value) in envs {
4465 command.env(key, value);
4466 }
4467 let output = command
4468 .output()
4469 .await
4470 .map_err(|err| InferenceError::InferenceFailed(format!("{}: {err}", program.display())))?;
4471
4472 if output.status.success() {
4473 Ok(CommandOutput {
4474 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
4475 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
4476 })
4477 } else {
4478 Err(InferenceError::InferenceFailed(format!(
4479 "{} exited with {}: {}",
4480 program.display(),
4481 output.status,
4482 String::from_utf8_lossy(&output.stderr)
4483 )))
4484 }
4485}
4486
4487#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4488async fn bootstrap_speech_runtime(runtime: &SpeechRuntime) -> Result<(), InferenceError> {
4489 std::fs::create_dir_all(&runtime.root)?;
4490 let python = select_speech_python()?;
4491
4492 run_command(
4493 "uv",
4494 &[
4495 "venv".to_string(),
4496 "--python".to_string(),
4497 python,
4498 runtime.root.display().to_string(),
4499 ],
4500 )
4501 .await?;
4502
4503 run_command(
4504 "uv",
4505 &[
4506 "pip".to_string(),
4507 "install".to_string(),
4508 "--python".to_string(),
4509 runtime.python.display().to_string(),
4510 speech_runtime_mlx_audio_spec(),
4511 "misaki[en]".to_string(),
4512 speech_runtime_spacy_model_spec(),
4513 ],
4514 )
4515 .await?;
4516
4517 Ok(())
4518}
4519
4520#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4521async fn run_command(program: &str, args: &[String]) -> Result<(), InferenceError> {
4522 let output = Command::new(program)
4523 .args(args)
4524 .output()
4525 .await
4526 .map_err(|err| InferenceError::InferenceFailed(format!("{program}: {err}")))?;
4527
4528 if output.status.success() {
4529 Ok(())
4530 } else {
4531 Err(InferenceError::InferenceFailed(format!(
4532 "{} exited with {}: {}",
4533 program,
4534 output.status,
4535 String::from_utf8_lossy(&output.stderr)
4536 )))
4537 }
4538}
4539
4540#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4541fn select_speech_python() -> Result<String, InferenceError> {
4542 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4543 if !path.trim().is_empty() {
4544 return Ok(path);
4545 }
4546 }
4547
4548 for candidate in ["python3.13", "python3.12", "python3.11"] {
4549 if command_in_path(candidate) {
4550 return Ok(candidate.to_string());
4551 }
4552 }
4553
4554 Err(InferenceError::InferenceFailed(
4555 "no supported Python found for managed speech runtime (tried python3.13, python3.12, python3.11)".into(),
4556 ))
4557}
4558
4559#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4560fn detect_speech_python() -> Option<String> {
4561 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4562 if !path.trim().is_empty() {
4563 return Some(path);
4564 }
4565 }
4566
4567 ["python3.13", "python3.12", "python3.11"]
4568 .into_iter()
4569 .find(|candidate| command_in_path(candidate))
4570 .map(str::to_string)
4571}
4572
/// Location of the managed speech runtime.
///
/// A non-blank `CAR_SPEECH_RUNTIME_DIR` wins. Otherwise the location is
/// `$HOME/.car/speech-runtime`, with `HOME` falling back to `.`.
/// `_models_dir` is currently unused.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_root_from_models_dir(_models_dir: &Path) -> PathBuf {
    let override_dir = std::env::var("CAR_SPEECH_RUNTIME_DIR")
        .ok()
        .filter(|dir| !dir.trim().is_empty());
    if let Some(dir) = override_dir {
        return PathBuf::from(dir);
    }

    let home = std::env::var("HOME").map_or_else(|_| PathBuf::from("."), PathBuf::from);
    home.join(".car").join("speech-runtime")
}
4587
/// True if a regular file named `name` exists in any `PATH` directory.
///
/// Does not check the executable bit or apply Windows `PATHEXT` suffixes.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn command_in_path(name: &str) -> bool {
    let Some(search_path) = std::env::var_os("PATH") else {
        return false;
    };
    // `is_file` is false for missing paths, so a separate `exists` check is redundant.
    std::env::split_paths(&search_path)
        .map(|dir| dir.join(name))
        .any(|candidate| candidate.is_file())
}
4599
4600fn speech_model_cached(schema: &ModelSchema) -> bool {
4601 match &schema.source {
4602 ModelSource::Mlx { hf_repo, .. } => huggingface_repo_has_snapshot(hf_repo),
4603 ModelSource::Proprietary { auth, .. } => match auth {
4604 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4605 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4606 ProprietaryAuth::OAuth2Pkce { .. } => false,
4607 },
4608 _ => false,
4609 }
4610}
4611
4612fn remove_huggingface_repo_cache(repo_id: &str) -> Result<(), InferenceError> {
4613 let repo_dir = std::env::var("HF_HOME")
4614 .map(PathBuf::from)
4615 .unwrap_or_else(|_| {
4616 std::env::var("HOME")
4617 .map(PathBuf::from)
4618 .unwrap_or_else(|_| PathBuf::from("."))
4619 .join(".cache")
4620 .join("huggingface")
4621 })
4622 .join("hub")
4623 .join(format!("models--{}", repo_id.replace('/', "--")));
4624
4625 if repo_dir.exists() {
4626 std::fs::remove_dir_all(repo_dir)?;
4627 }
4628 Ok(())
4629}
4630
4631fn model_source_configured(schema: &ModelSchema) -> bool {
4632 match &schema.source {
4633 ModelSource::RemoteApi {
4634 api_key_env,
4635 api_key_envs,
4636 ..
4637 } => {
4638 std::env::var(api_key_env).is_ok()
4639 || api_key_envs
4640 .iter()
4641 .any(|env_var| std::env::var(env_var).is_ok())
4642 }
4643 ModelSource::Proprietary { auth, .. } => match auth {
4644 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4645 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4646 ProprietaryAuth::OAuth2Pkce { .. } => false,
4647 },
4648 ModelSource::VllmMlx { .. } => {
4649 std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available
4650 }
4651 ModelSource::Ollama { .. } => schema.available,
4652 ModelSource::Mlx { .. } | ModelSource::Local { .. } => true,
4653 ModelSource::AppleFoundationModels { .. } => schema.available,
4654 ModelSource::Delegated { .. } => true,
4659 }
4660}
4661
4662fn all_model_capabilities() -> [ModelCapability; 13] {
4663 [
4664 ModelCapability::Generate,
4665 ModelCapability::Embed,
4666 ModelCapability::Classify,
4667 ModelCapability::Code,
4668 ModelCapability::Reasoning,
4669 ModelCapability::Summarize,
4670 ModelCapability::ToolUse,
4671 ModelCapability::MultiToolCall,
4672 ModelCapability::Vision,
4673 ModelCapability::SpeechToText,
4674 ModelCapability::TextToSpeech,
4675 ModelCapability::ImageGeneration,
4676 ModelCapability::VideoGeneration,
4677 ]
4678}
4679
4680fn sort_capabilities(mut capabilities: Vec<ModelCapability>) -> Vec<ModelCapability> {
4681 capabilities.sort_by_key(|capability| {
4682 all_model_capabilities()
4683 .iter()
4684 .position(|candidate| candidate == capability)
4685 .unwrap_or(usize::MAX)
4686 });
4687 capabilities
4688}
4689
4690fn speech_model_source_label(schema: &ModelSchema) -> String {
4691 match &schema.source {
4692 ModelSource::Mlx { hf_repo, .. } => format!("mlx:{hf_repo}"),
4693 ModelSource::Proprietary {
4694 provider, endpoint, ..
4695 } => format!("proprietary:{provider}:{endpoint}"),
4696 ModelSource::RemoteApi { endpoint, .. } => format!("remote:{endpoint}"),
4697 ModelSource::Local { hf_repo, .. } => format!("local:{hf_repo}"),
4698 ModelSource::VllmMlx {
4699 endpoint,
4700 model_name,
4701 } => format!("vllm-mlx:{endpoint}:{model_name}"),
4702 ModelSource::Ollama { model_tag, host } => format!("ollama:{host}:{model_tag}"),
4703 ModelSource::AppleFoundationModels { use_case } => {
4704 format!(
4705 "apple-foundation:{}",
4706 use_case.as_deref().unwrap_or("default")
4707 )
4708 }
4709 ModelSource::Delegated { hint } => {
4710 format!("delegated:{}", hint.as_deref().unwrap_or("(none)"))
4711 }
4712 }
4713}
4714
/// Renders the ChatML prompt used to ask a model for a yes/no relevance verdict.
///
/// The prompt has three turns: a fixed system instruction, a user turn
/// carrying `instruction`, `query` and `document`, and an assistant turn that
/// opens with an empty `<think></think>` block. Inputs are inserted verbatim
/// (no escaping).
fn rerank_prompt(instruction: &str, query: &str, document: &str) -> String {
    const SYSTEM: &str = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".";
    let system_turn = format!("<|im_start|>system\n{SYSTEM}<|im_end|>\n");
    let user_turn = format!(
        "<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
    );
    let assistant_turn = "<|im_start|>assistant\n<think>\n\n</think>\n\n";
    [system_turn.as_str(), user_turn.as_str(), assistant_turn].concat()
}
4730
/// Wraps "no backend available" errors with actionable setup instructions.
///
/// Returns `Some(hint)` when `underlying` contains one of a fixed set of
/// substrings indicating there is no usable backend at all. The hint ends
/// with the original error text. Returns `None` for every other error (auth
/// failures, rate limits, timeouts, ...) so those surface unchanged.
fn no_backend_recovery_hint(underlying: &str) -> Option<String> {
    const NO_BACKEND_MARKERS: [&str; 4] = [
        "no credential",
        "model not found",
        "no models available",
        "no inference runner",
    ];

    let matches_marker = NO_BACKEND_MARKERS
        .iter()
        .any(|marker| underlying.contains(*marker));
    if matches_marker {
        Some(format!(
            "no inference backend is available. To install a local CPU-only \
             model (no account required, works on Windows/Linux/macOS), run:\n \
             car models pull qwen/qwen3-1.7b:q8_0\n\
             To use Parslee's hosted models instead, run:\n \
             car auth login parslee\n\
             (underlying error: {underlying})"
        ))
    } else {
        None
    }
}
4766
4767fn score_from_rerank_output(text: &str, model_name: &str) -> f32 {
4768 let normalized: String = text
4773 .to_ascii_lowercase()
4774 .chars()
4775 .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
4776 .collect();
4777 for tok in normalized.split_ascii_whitespace().take(5) {
4778 match tok {
4779 "yes" => return 1.0,
4780 "no" => return 0.0,
4781 _ => continue,
4782 }
4783 }
4784 tracing::warn!(
4785 model = %model_name,
4786 output = %text,
4787 "rerank: first tokens contain neither `yes` nor `no`; returning neutral 0.5"
4788 );
4789 0.5
4790}
4791
4792fn default_speech_voice(schema: &ModelSchema) -> Option<String> {
4793 if schema.provider == "elevenlabs" {
4794 Some("JBFqnCBsd6RMkjVDRZzb".to_string())
4795 } else if schema.name == "Kokoro-82M-6bit" || schema.name == "Kokoro-82M-bf16" {
4796 Some("af_heart".to_string())
4797 } else if schema.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
4798 Some("Chelsie".to_string())
4799 } else {
4800 None
4801 }
4802}
4803
4804#[allow(dead_code)] fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
4806 find_latest_huggingface_snapshot(repo_id).is_some()
4807}
4808
/// Path of the Hugging Face hub cache directory for `repo_id`.
///
/// The hub root is `$HF_HOME/hub`, falling back to
/// `$HOME/.cache/huggingface/hub` (or `./.cache/huggingface/hub` when `HOME`
/// is unavailable). The repo folder is `models--<org>--<name>`, i.e. every
/// `/` in the id becomes `--`. The path is computed only; nothing is created.
fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => {
            let home = std::env::var("HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| PathBuf::from("."));
            home.join(".cache").join("huggingface")
        }
    };
    let folder = format!("models--{}", repo_id.replace('/', "--"));
    hf_home.join("hub").join(folder)
}
4822
4823fn find_latest_huggingface_snapshot(repo_id: &str) -> Option<PathBuf> {
4824 let snapshots = huggingface_repo_dir(repo_id).join("snapshots");
4825 std::fs::read_dir(snapshots)
4826 .ok()?
4827 .filter_map(Result::ok)
4828 .map(|entry| entry.path())
4829 .find(|path| path.is_dir() && snapshot_looks_ready(path))
4830}
4831
4832fn snapshot_looks_ready(path: &Path) -> bool {
4833 if path.join("config.json").exists() || path.join("model_index.json").exists() {
4834 return true;
4835 }
4836 snapshot_contains_ext(path, "safetensors")
4837}
4838
/// Recursively searches `root` for a file whose extension equals `ext`,
/// ignoring ASCII case.
///
/// Unreadable directories and entries are skipped. A missing `root` yields
/// `false`. Directory symlinks are followed, because `is_dir` follows links.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return false,
    };
    for entry in entries.flatten() {
        let candidate = entry.path();
        let found = if candidate.is_dir() {
            snapshot_contains_ext(&candidate, ext)
        } else {
            matches!(
                candidate.extension().and_then(|suffix| suffix.to_str()),
                Some(suffix) if suffix.eq_ignore_ascii_case(ext)
            )
        };
        if found {
            return true;
        }
    }
    false
}
4855
/// Counts regular files beneath `root`, descending into subdirectories.
///
/// Unreadable directories contribute zero. Entries that are neither files nor
/// directories (e.g. dangling symlinks) are not counted.
#[allow(dead_code)]
fn count_files_recursive(root: &Path) -> usize {
    let mut total = 0;
    if let Ok(entries) = std::fs::read_dir(root) {
        for entry in entries.flatten() {
            let path = entry.path();
            total += if path.is_dir() {
                count_files_recursive(&path)
            } else {
                usize::from(path.is_file())
            };
        }
    }
    total
}
4875
4876async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
4877 let api = hf_hub::api::tokio::ApiBuilder::from_env()
4878 .with_progress(false)
4879 .build()
4880 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
4881 let repo = api.model(repo_id.to_string());
4882 let info = repo
4883 .info()
4884 .await
4885 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
4886
4887 let snapshot_path = huggingface_repo_dir(repo_id)
4888 .join("snapshots")
4889 .join(&info.sha);
4890 let mut downloaded = 0usize;
4891 for sibling in &info.siblings {
4892 let local_path = snapshot_path.join(&sibling.rfilename);
4893 if local_path.exists() {
4894 downloaded += 1;
4895 continue;
4896 }
4897 repo.download(&sibling.rfilename).await.map_err(|e| {
4898 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
4899 })?;
4900 downloaded += 1;
4901 }
4902
4903 Ok((snapshot_path, downloaded))
4904}
4905
4906fn temp_work_dir(prefix: &str) -> Result<PathBuf, InferenceError> {
4907 let unique = SystemTime::now()
4908 .duration_since(UNIX_EPOCH)
4909 .map_err(|e| InferenceError::InferenceFailed(format!("clock error: {e}")))?
4910 .as_nanos();
4911 let dir = std::env::temp_dir().join(format!("car-inference-{prefix}-{unique}"));
4912 std::fs::create_dir_all(&dir)?;
4913 Ok(dir)
4914}
4915
4916fn ensure_parent_dir(path: &Path) -> Result<(), InferenceError> {
4917 if let Some(parent) = path.parent() {
4918 std::fs::create_dir_all(parent)?;
4919 }
4920 Ok(())
4921}
4922
4923fn requested_or_temp_output(
4924 output_path: Option<&str>,
4925 format: &str,
4926) -> Result<PathBuf, InferenceError> {
4927 if let Some(path) = output_path {
4928 return Ok(PathBuf::from(path));
4929 }
4930 let dir = temp_work_dir("audio-out")?;
4931 Ok(dir.join(format!("speech.{format}")))
4932}
4933
4934#[allow(dead_code)] fn requested_or_temp_media_output(
4936 output_path: Option<&str>,
4937 format: &str,
4938 stem: &str,
4939) -> Result<PathBuf, InferenceError> {
4940 if let Some(path) = output_path {
4941 return Ok(PathBuf::from(path));
4942 }
4943 let dir = temp_work_dir(&format!("{stem}-out"))?;
4944 Ok(dir.join(format!("{stem}.{format}")))
4945}
4946
4947fn materialize_audio_output(
4948 produced: &Path,
4949 requested: Option<&str>,
4950 format: &str,
4951) -> Result<PathBuf, InferenceError> {
4952 if let Some(path) = requested {
4953 let dest = PathBuf::from(path);
4954 ensure_parent_dir(&dest)?;
4955 std::fs::copy(produced, &dest)?;
4956 Ok(dest)
4957 } else {
4958 let dest = requested_or_temp_output(None, format)?;
4959 ensure_parent_dir(&dest)?;
4960 std::fs::copy(produced, &dest)?;
4961 Ok(dest)
4962 }
4963}
4964
4965#[allow(dead_code)] fn read_transcription_result(output_prefix: &Path) -> Result<Option<String>, InferenceError> {
4967 let candidates = [
4968 output_prefix.with_extension("json"),
4969 output_prefix.to_path_buf(),
4970 ];
4971
4972 for path in candidates {
4973 if path.exists() {
4974 let contents = std::fs::read_to_string(path)?;
4975 if let Some(text) = extract_text_from_payload(&contents) {
4976 return Ok(Some(text));
4977 }
4978 }
4979 }
4980
4981 Ok(None)
4982}
4983
4984#[allow(dead_code)] fn extract_text_from_payload(payload: &str) -> Option<String> {
4986 let value: serde_json::Value = serde_json::from_str(payload).ok()?;
4987 if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
4988 return Some(text.to_string());
4989 }
4990 if let Some(transcripts) = value.get("transcripts").and_then(|v| v.as_array()) {
4991 let joined = transcripts
4992 .iter()
4993 .filter_map(|item| item.get("text").and_then(|v| v.as_str()))
4994 .collect::<Vec<_>>()
4995 .join("\n");
4996 if !joined.is_empty() {
4997 return Some(joined);
4998 }
4999 }
5000 if let Some(items) = value.as_array() {
5001 let joined = items
5002 .iter()
5003 .filter_map(|item| {
5004 item.get("text")
5005 .or_else(|| item.get("Content"))
5006 .and_then(|v| v.as_str())
5007 })
5008 .collect::<Vec<_>>()
5009 .join(" ");
5010 if !joined.is_empty() {
5011 return Some(joined);
5012 }
5013 }
5014 None
5015}
5016
5017#[allow(dead_code)] fn find_audio_file(output_dir: &Path) -> Result<Option<PathBuf>, InferenceError> {
5019 let mut audio_files = Vec::new();
5020 collect_audio_files(output_dir, &mut audio_files)?;
5021 audio_files.sort();
5022 Ok(audio_files.into_iter().next())
5023}
5024
5025#[allow(dead_code)] fn collect_audio_files(dir: &Path, audio_files: &mut Vec<PathBuf>) -> Result<(), InferenceError> {
5027 for entry in std::fs::read_dir(dir)? {
5028 let path = entry?.path();
5029 if path.is_dir() {
5030 collect_audio_files(&path, audio_files)?;
5031 } else if matches!(
5032 path.extension().and_then(|ext| ext.to_str()),
5033 Some("wav" | "mp3" | "flac" | "pcm" | "m4a")
5034 ) {
5035 audio_files.push(path);
5036 }
5037 }
5038 Ok(())
5039}
5040
/// Maps an audio format name (case-insensitive) to its MIME type.
///
/// Unknown formats default to `audio/wav`.
fn media_type_for_format(format: &str) -> String {
    const MEDIA_TYPES: [(&str, &str); 4] = [
        ("mp3", "audio/mpeg"),
        ("flac", "audio/flac"),
        ("pcm", "audio/L16"),
        ("m4a", "audio/mp4"),
    ];
    MEDIA_TYPES
        .iter()
        .find(|(ext, _)| format.eq_ignore_ascii_case(ext))
        .map_or("audio/wav", |&(_, mime)| mime)
        .to_string()
}
5050
/// Maps a language name or tag to Kokoro's single-letter language code.
///
/// Matching is ASCII-case-insensitive. `None` is treated as `"en"`, and any
/// unrecognised language falls back to `"a"` (the code used for `en`).
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_lang_code(language: Option<&str>) -> &'static str {
    let normalized = language.map_or_else(|| "en".to_owned(), str::to_ascii_lowercase);
    match normalized.as_str() {
        "en-gb" | "british" | "british english" => "b",
        "ja" | "japanese" => "j",
        "zh" | "zh-cn" | "mandarin" | "chinese" => "z",
        "es" | "spanish" => "e",
        "fr" | "french" => "f",
        _ => "a",
    }
}
5062
/// Normalizes a language name or tag to one of `en`, `es`, `fr`, `ja`, `zh`.
///
/// Matching is ASCII-case-insensitive; anything unrecognised becomes `en`.
#[allow(dead_code)]
fn normalize_lang_code(language: &str) -> String {
    let lowered = language.to_ascii_lowercase();
    let code = match lowered.as_str() {
        "english" | "en-us" | "en_us" | "en" => "en",
        "spanish" | "es" => "es",
        "french" | "fr" => "fr",
        "japanese" | "ja" => "ja",
        "chinese" | "mandarin" | "zh" => "zh",
        _ => "en",
    };
    code.to_string()
}
5077
5078fn elevenlabs_auth(schema: &ModelSchema) -> Result<(String, String), InferenceError> {
5079 match &schema.source {
5080 ModelSource::Proprietary {
5081 endpoint,
5082 auth: schema::ProprietaryAuth::ApiKeyEnv { env_var },
5083 ..
5084 } => {
5085 let key = car_secrets::resolve_env_or_keychain(env_var).ok_or_else(|| {
5086 InferenceError::InferenceFailed(format!(
5087 "missing API key {env_var}; set the environment variable or \
5088 store it with `car secrets put {env_var}`"
5089 ))
5090 })?;
5091 Ok((endpoint.clone(), key))
5092 }
5093 _ => Err(InferenceError::InferenceFailed(format!(
5094 "model {} is not an ElevenLabs proprietary model",
5095 schema.id
5096 ))),
5097 }
5098}
5099
/// Maps a requested audio format (case-insensitive) to an ElevenLabs
/// `output_format` value; anything other than mp3/pcm requests `wav_44100`.
fn elevenlabs_output_format(format: &str) -> &'static str {
    if format.eq_ignore_ascii_case("mp3") {
        "mp3_44100_128"
    } else if format.eq_ignore_ascii_case("pcm") {
        "pcm_16000"
    } else {
        "wav_44100"
    }
}
5107
/// Candidate locations for `benchmark_priors.json`, in load order.
///
/// The candidates are:
/// 1. `models_dir/benchmark_priors.json`;
/// 2. `benchmark_priors.json` in the parent of `models_dir`, if any;
/// 3. `CAR_BENCHMARK_PRIORS_PATH`, if set.
///
/// Duplicates are dropped, keeping the first occurrence. Later entries win
/// when the caller merges, so the env override has the highest precedence.
fn benchmark_priors_paths(models_dir: &Path) -> Vec<PathBuf> {
    let candidates = std::iter::once(models_dir.join("benchmark_priors.json"))
        .chain(
            models_dir
                .parent()
                .map(|parent| parent.join("benchmark_priors.json")),
        )
        .chain(std::env::var_os("CAR_BENCHMARK_PRIORS_PATH").map(PathBuf::from));

    let mut unique: Vec<PathBuf> = Vec::new();
    for candidate in candidates {
        if !unique.contains(&candidate) {
            unique.push(candidate);
        }
    }
    unique
}
5132
5133fn load_benchmark_prior_health(
5134 models_dir: &Path,
5135 schemas: &[ModelSchema],
5136) -> Vec<ModelBenchmarkPriorHealth> {
5137 let mut priors = std::collections::BTreeMap::new();
5138 for path in benchmark_priors_paths(models_dir) {
5139 let Ok(loaded) = routing_ext::load_benchmark_priors(&path) else {
5140 continue;
5141 };
5142 for (model_id, prior) in loaded {
5143 let model_name = schemas
5144 .iter()
5145 .find(|schema| schema.id == model_id)
5146 .map(|schema| schema.name.clone());
5147 priors.insert(
5148 model_id.clone(),
5149 ModelBenchmarkPriorHealth {
5150 model_id,
5151 model_name,
5152 overall_score: prior.overall_score,
5153 overall_latency_ms: prior.overall_latency_ms,
5154 task_scores: prior.task_scores,
5155 task_latency_ms: prior.task_latency_ms,
5156 source_path: path.clone(),
5157 },
5158 );
5159 }
5160 }
5161
5162 priors.into_values().collect()
5163}
5164
/// Whether the Kokoro runtime fallback is enabled (default: yes).
///
/// Only `CAR_SPEECH_KOKORO_FALLBACK` set to `0`, `false` or `off` disables
/// it; the value is trimmed and compared case-insensitively. An unset or
/// non-Unicode value leaves the fallback enabled.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_runtime_fallback_enabled() -> bool {
    match std::env::var("CAR_SPEECH_KOKORO_FALLBACK") {
        Ok(raw) => {
            let flag = raw.trim().to_ascii_lowercase();
            !(flag == "0" || flag == "false" || flag == "off")
        }
        Err(_) => true,
    }
}
5177
/// pip requirement spec for `mlx-audio` in the managed speech runtime.
///
/// A non-blank `CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC` is returned untrimmed;
/// otherwise the pinned default `mlx-audio==0.4.2` is used.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_mlx_audio_spec() -> String {
    match std::env::var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => String::from("mlx-audio==0.4.2"),
    }
}
5185
/// pip requirement spec for the spaCy English model in the speech runtime.
///
/// A non-blank `CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC` is returned untrimmed;
/// otherwise a direct-URL spec for the `en_core_web_sm-3.8.0` wheel is used.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_spacy_model_spec() -> String {
    const DEFAULT_SPEC: &str = "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl";
    match std::env::var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => DEFAULT_SPEC.to_string(),
    }
}
5195
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn no_backend_hint_fires_on_missing_backend_phrases() {
        for phrase in [
            "no credential for proprietary provider 'parslee'",
            "model not found",
            "no models available",
            "model declares ModelSource::Delegated but no inference runner is registered",
        ] {
            let hint = no_backend_recovery_hint(phrase)
                .unwrap_or_else(|| panic!("expected a hint for {phrase:?}"));
            assert!(hint.contains("car models pull"));
            assert!(hint.contains("car auth login parslee"));
            assert!(hint.contains(phrase));
        }
    }

    #[test]
    fn no_backend_hint_passes_through_transient_errors() {
        for phrase in [
            "API returned 401 Unauthorized",
            "API returned 429 Too Many Requests",
            "API returned 500 Internal Server Error",
            "connection refused",
            "request timed out",
            "parse response: unexpected end of input",
        ] {
            assert!(
                no_backend_recovery_hint(phrase).is_none(),
                "transient error wrongly classified as no-backend: {phrase:?}"
            );
        }
    }

    /// Serializes every test that mutates process-wide environment variables.
    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes [`ENV_MUTEX`], recovering the guard if a previous holder panicked.
    ///
    /// The mutex protects no data, so a poisoned lock is still safe to use.
    /// Recovering keeps one failing env test from cascading into
    /// `PoisonError` panics in every other env test.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_MUTEX
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    fn test_config(models_dir: PathBuf) -> InferenceConfig {
        InferenceConfig {
            models_dir,
            device: None,
            generation_model: "Qwen3-0.6B".into(),
            preferred_generation_model: None,
            embedding_model: "Qwen3-Embedding-0.6B".into(),
            preferred_embedding_model: None,
            classification_model: "Qwen3-0.6B".into(),
            preferred_classification_model: None,
        }
    }

    #[tokio::test]
    async fn tokenize_rejects_known_remote_model_with_unsupported_mode() {
        let tmp = TempDir::new().unwrap();
        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let remote_id = engine
            .list_schemas()
            .into_iter()
            .find(|s| !s.is_local())
            .map(|s| s.id)
            .expect("built-in catalog should include at least one remote model schema");

        let err = engine
            .tokenize(&remote_id, "hello")
            .await
            .expect_err("remote tokenize must error");
        match err {
            InferenceError::UnsupportedMode { mode, backend, .. } => {
                assert_eq!(mode, "tokenize/detokenize");
                assert_eq!(backend, "remote");
            }
            other => panic!("expected UnsupportedMode, got {other:?}"),
        }

        let err = engine
            .detokenize(&remote_id, &[1, 2, 3])
            .await
            .expect_err("remote detokenize must error");
        assert!(
            matches!(err, InferenceError::UnsupportedMode { .. }),
            "expected UnsupportedMode, got {err:?}"
        );
    }

    #[test]
    fn engine_loads_benchmark_priors_on_startup() {
        let _env = env_lock();
        let tmp = TempDir::new().unwrap();
        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.88
            })
            .to_string(),
        )
        .unwrap();

        unsafe {
            std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
        }

        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("benchmark prior should create a profile");
        assert!((profile.ema_quality - 0.88).abs() < 0.01);

        unsafe {
            std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
        }
    }

    #[test]
    fn benchmark_priors_do_not_override_observed_profiles() {
        let _env = env_lock();
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        std::fs::create_dir_all(&models_dir).unwrap();

        let observed = vec![ModelProfile {
            model_id: "qwen/qwen3-8b:q4_k_m".into(),
            total_calls: 12,
            success_count: 3,
            fail_count: 9,
            total_latency_ms: 1200,
            total_input_tokens: 0,
            total_output_tokens: 0,
            task_stats: std::collections::HashMap::new(),
            ema_quality: 0.21,
            quality_per_1k_tokens: 0.0,
            updated_at: 1,
        }];
        std::fs::write(
            models_dir.join("outcome_profiles.json"),
            serde_json::to_string(&observed).unwrap(),
        )
        .unwrap();

        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.95
            })
            .to_string(),
        )
        .unwrap();

        unsafe {
            std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
        }

        let engine = InferenceEngine::new(test_config(models_dir));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("observed profile should remain present");
        assert!((profile.ema_quality - 0.21).abs() < 0.01);
        assert_eq!(profile.total_calls, 12);

        unsafe {
            std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
        }
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn speech_runtime_package_spec_defaults_and_overrides() {
        let _env = env_lock();
        unsafe {
            std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
        }
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.2");

        unsafe {
            std::env::set_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC", "mlx-audio==0.4.1");
        }
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.1");

        unsafe {
            std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
        }
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn speech_runtime_spacy_model_spec_defaults_and_overrides() {
        let _env = env_lock();
        unsafe {
            std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
        }
        assert!(
            speech_runtime_spacy_model_spec().starts_with("en-core-web-sm @ https://github.com/")
        );

        unsafe {
            std::env::set_var(
                "CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC",
                "en-core-web-sm==3.8.0",
            );
        }
        assert_eq!(speech_runtime_spacy_model_spec(), "en-core-web-sm==3.8.0");

        unsafe {
            std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
        }
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn kokoro_runtime_fallback_defaults_on() {
        // Mutates process env like the tests above, so it must hold the same lock.
        let _env = env_lock();
        unsafe {
            std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
        }
        assert!(kokoro_runtime_fallback_enabled());

        unsafe {
            std::env::set_var("CAR_SPEECH_KOKORO_FALLBACK", "false");
        }
        assert!(!kokoro_runtime_fallback_enabled());

        unsafe {
            std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
        }
    }

    #[test]
    fn preferred_local_tts_wins_over_builtin_rank() {
        let tmp = TempDir::new().unwrap();
        let mut engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        engine.set_speech_policy(SpeechPolicy {
            prefer_local: true,
            allow_remote_fallback: false,
            preferred_local_stt: None,
            preferred_local_tts: Some("Kokoro-82M-6bit".into()),
            preferred_remote_stt: None,
            preferred_remote_tts: None,
        });

        let schema = engine
            .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
            .expect("preferred local TTS should resolve");
        assert_eq!(schema.name, "Kokoro-82M-6bit");
    }

    #[test]
    fn preferred_discovered_vllm_mlx_model_wins_generate_routing() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(tmp.path().join("models"));
        config.preferred_generation_model =
            Some("vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit".into());
        let mut engine = InferenceEngine::new(config);
        let schema = crate::vllm_mlx::to_model_schema(
            &crate::vllm_mlx::DiscoveredModel {
                id: "mlx-community/gemma-3n-E2B-it-lm-4bit".into(),
                owned_by: Some("mlx-community".into()),
            },
            "http://127.0.0.1:8001",
        );
        engine.register_model(schema);

        let rt = tokio::runtime::Runtime::new().unwrap();
        let decision = rt.block_on(engine.route_adaptive("say hello in one sentence"));
        assert_eq!(
            decision.model_id,
            "vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit"
        );
        assert_eq!(decision.strategy, RoutingStrategy::Explicit);
        assert_eq!(decision.reason, "preferred generation model override");
    }

    #[test]
    fn inference_result_serializes_with_full_shape() {
        use crate::tasks::generate::ToolCall;
        use std::collections::HashMap;

        let mut args = HashMap::new();
        args.insert("path".to_string(), serde_json::json!("README.md"));

        let result = InferenceResult {
            text: String::new(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![ToolCall {
                id: None,
                name: "read_file".into(),
                arguments: args,
            }],
            trace_id: "trace-abc".into(),
            model_used: "test-model".into(),
            latency_ms: 1234,
            time_to_first_token_ms: Some(180),
            usage: Some(TokenUsage {
                prompt_tokens: 100,
                completion_tokens: 50,
                total_tokens: 150,
                context_window: 8192,
            }),
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");

        assert_eq!(json["text"].as_str(), Some(""));
        assert_eq!(json["trace_id"].as_str(), Some("trace-abc"));
        assert_eq!(json["model_used"].as_str(), Some("test-model"));
        assert_eq!(json["latency_ms"].as_u64(), Some(1234));

        let tool_calls = json["tool_calls"].as_array().expect("tool_calls array");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0]["name"].as_str(), Some("read_file"));
        assert_eq!(
            tool_calls[0]["arguments"]["path"].as_str(),
            Some("README.md")
        );

        let usage = &json["usage"];
        assert_eq!(usage["prompt_tokens"].as_u64(), Some(100));
        assert_eq!(usage["completion_tokens"].as_u64(), Some(50));
        assert_eq!(usage["total_tokens"].as_u64(), Some(150));
        assert_eq!(usage["context_window"].as_u64(), Some(8192));

        assert_eq!(json["time_to_first_token_ms"].as_u64(), Some(180));
    }

    #[test]
    fn inference_result_top_level_keys_are_locked() {
        use std::collections::BTreeSet;

        let result = InferenceResult {
            text: "anything".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "t".into(),
            model_used: "m".into(),
            latency_ms: 0,
            time_to_first_token_ms: None,
            usage: None,
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");
        let keys: BTreeSet<&str> = json
            .as_object()
            .expect("top-level object")
            .keys()
            .map(String::as_str)
            .collect();

        let expected: BTreeSet<&str> = [
            "text",
            "tool_calls",
            "trace_id",
            "model_used",
            "latency_ms",
            "time_to_first_token_ms",
            "usage",
        ]
        .into_iter()
        .collect();

        assert_eq!(
            keys, expected,
            "infer response top-level keys drifted -- update both the test \
             and the WebSocket protocol documentation if this is intentional"
        );

        for key in &keys {
            assert!(
                !key.chars().any(|c| c.is_uppercase()) && !key.contains('-'),
                "key '{}' is not snake_case",
                key
            );
        }
    }

    #[test]
    fn inference_result_serializes_plain_text_response() {
        let result = InferenceResult {
            text: "hello world".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "trace-xyz".into(),
            model_used: "test-model".into(),
            latency_ms: 42,
            time_to_first_token_ms: None,
            usage: None,
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");
        assert_eq!(json["text"], "hello world");
        assert!(json["tool_calls"].is_array());
        assert_eq!(json["tool_calls"].as_array().unwrap().len(), 0);
        assert_eq!(json["model_used"], "test-model");
        assert!(json["usage"].is_null());
        assert!(json["time_to_first_token_ms"].is_null());
    }

    #[test]
    fn generate_request_deserializes_intent_field_from_json_rpc_params() {
        use crate::intent::{IntentHint, TaskHint};
        use crate::schema::ModelCapability;

        let params = serde_json::json!({
            "prompt": "summarize this email",
            "intent": {
                "task": "chat",
                "prefer_local": true,
                "require": ["tool_use"],
            },
        });

        let req: GenerateRequest =
            serde_json::from_value(params).expect("GenerateRequest deserialize");

        let intent = req.intent.as_ref().expect("intent field deserialized");
        assert_eq!(intent.task, Some(TaskHint::Chat));
        assert!(intent.prefer_local);
        assert_eq!(intent.require, vec![ModelCapability::ToolUse]);

        let back: serde_json::Value =
            serde_json::to_value(&req).expect("re-serialize GenerateRequest");
        assert_eq!(back["intent"]["task"], "chat");
        assert_eq!(back["intent"]["prefer_local"], true);
        assert_eq!(back["intent"]["require"][0], "tool_use");

        let default_req: GenerateRequest = serde_json::from_value(serde_json::json!({
            "prompt": "x",
            "intent": {},
        }))
        .unwrap();
        let default_intent = default_req.intent.expect("present but empty");
        assert_eq!(default_intent.task, None);
        assert!(!default_intent.prefer_local);
        assert!(default_intent.require.is_empty());

        let no_intent: GenerateRequest =
            serde_json::from_value(serde_json::json!({"prompt": "x"})).unwrap();
        assert!(no_intent.intent.is_none());
    }

    #[test]
    fn rerank_prompt_matches_upstream_template_shape() {
        let p = rerank_prompt(
            "retrieve relevant passages",
            "who runs the treasury?",
            "doc x",
        );
        assert!(p.contains("<|im_start|>system"));
        assert!(p.contains("Note that the answer can only be \"yes\" or \"no\"."));
        assert!(p.contains("<|im_start|>user\n<Instruct>: retrieve relevant passages"));
        assert!(p.contains("<Query>: who runs the treasury?"));
        assert!(p.contains("<Document>: doc x<|im_end|>"));
        assert!(p.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    #[test]
    fn rerank_score_yes_and_no_exactly() {
        assert_eq!(score_from_rerank_output("yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("no", "m"), 0.0);
    }

    #[test]
    fn rerank_score_handles_case_leading_space_and_chat_sentinels() {
        assert_eq!(score_from_rerank_output(" Yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("\nno.", "m"), 0.0);
        assert_eq!(score_from_rerank_output("<|im_end|>yes", "m"), 1.0);
    }

    /// `score_from_rerank_output` inspects the first five alphanumeric tokens;
    /// pin both sides of that window.
    #[test]
    fn rerank_score_scans_first_five_tokens() {
        assert_eq!(score_from_rerank_output("_bos_ yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("a b c d yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("a b c d e yes", "m"), 0.5);
    }

    #[test]
    fn rerank_score_unexpected_is_neutral() {
        assert_eq!(score_from_rerank_output("maybe", "m"), 0.5);
        assert_eq!(score_from_rerank_output("", "m"), 0.5);
    }
}