1pub mod adaptive_router;
26pub mod backend;
27pub mod backend_cache;
28pub mod download;
29pub mod handle;
30pub mod hardware;
31pub mod intent;
32pub mod key_pool;
33pub mod models;
34pub mod nudge;
35pub mod outcome;
36pub mod protocol;
37pub mod recommend;
38pub mod registry;
39pub mod remote;
40pub mod router;
41pub mod routing_ext;
42pub mod runner;
43pub mod schema;
44pub mod service;
45pub mod stream;
46pub mod tasks;
47pub mod update_prefs;
48pub mod upgrade;
49pub mod vllm_mlx;
50
51use std::path::{Path, PathBuf};
52use std::sync::Arc;
53use std::time::Instant;
54use std::time::{SystemTime, UNIX_EPOCH};
55
56use reqwest::multipart::{Form, Part};
57use serde::Serialize;
58use thiserror::Error;
59#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
60use tokio::process::Command;
61#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
62use tokio::sync::Mutex;
63use tokio::sync::RwLock;
64use tracing::{debug, instrument};
65
66pub use adaptive_router::{
68 AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy,
69};
70pub use handle::InferenceHandle;
71pub use intent::{
72 IntentHint, Privacy, QualityTier, TaskHint, TierWeights, UseCase, UseCaseRole,
73};
74pub use download::{DownloadEvent, DownloadProgress, ProgressSink};
75pub use update_prefs::{UpdateChannel, UpdatePolicy, UpdatePreferences};
76pub use nudge::{NudgeDecision, NudgeState, UpgradeNudge};
77pub use upgrade::{HuggingFaceProbe, UpgradeFinding, UpgradeSource, UpstreamProbe};
78pub use recommend::{recommend, FitStatus, Recommendation, RecommendationSet};
79pub use key_pool::{KeyPool, KeyStats};
80pub use outcome::{
81 CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
82};
83pub use registry::{
84 ModelFilter, ModelInfo, ModelRuntimeRequirement, ModelUpgrade, UnifiedRegistry,
85};
86pub use remote::RemoteBackend;
87pub use routing_ext::{
88 CircuitBreaker, CircuitBreakerRegistry, CircuitState, ImplicitSignal, ImplicitSignalType,
89 RoutingMode, SpendControl, SpendLimitExceeded, SpendLimits, SpendStatus,
90};
91pub use runner::{
92 current_inference_runner, set_inference_runner, EventEmitter, InferenceRunner, RunnerError,
93 RunnerResult,
94};
95pub use schema::{
96 ApiProtocol, BenchmarkScore, CostModel, ModelCapability, ModelSchema, ModelSource,
97 PerformanceEnvelope, ProprietaryAuth, TrustTier,
98};
99
100pub use adaptive_router::TaskComplexity;
102#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
103pub use backend::CandleBackend;
104#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
105pub use backend::EmbeddingBackend;
106#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
107pub use backend::MlxBackend;
108pub use hardware::HardwareInfo;
109pub use models::{ModelRegistry, ModelRole};
110pub use router::{ModelRouter, RoutingDecision};
111pub use stream::{StreamAccumulator, StreamEvent};
112pub use tasks::{
113 parse_boxes, BoundingBox, ClassifyRequest, ClassifyResult, ContentBlock, EmbedRequest,
114 GenerateImageRequest, GenerateImageResult, GenerateParams, GenerateRequest,
115 GenerateVideoRequest, GenerateVideoResult, GroundRequest, GroundResult, Message, RerankRequest,
116 RerankResult, RerankedDocument, RoutingWorkload, SynthesizeRequest, SynthesizeResult,
117 ThinkingMode, ToolCall, TranscribeRequest, TranscribeResult, VideoMode,
118};
/// Errors surfaced by the inference layer.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub enum InferenceError {
    /// A model could not be resolved. The payload is interpolated into the
    /// message; call sites in this file pass the requested model name or id.
    #[error("model not found: {0}")]
    ModelNotFound(String),

    /// Fetching a model failed; the payload is a free-form description.
    #[error("model download failed: {0}")]
    DownloadFailed(String),

    /// A backend failed while loading or running; the payload is a free-form
    /// description of the failure.
    #[error("inference failed: {0}")]
    InferenceFailed(String),

    /// The selected backend does not implement the requested mode.
    ///
    /// All three fields are `&'static str`, so they are compile-time labels
    /// rather than runtime-built strings.
    #[error("mode {mode} not implemented on backend {backend}: {reason}")]
    UnsupportedMode {
        /// Label of the requested mode (for example `"tool-use"`).
        mode: &'static str,
        /// Label of the backend that rejected the mode.
        backend: &'static str,
        /// Human-readable explanation, typically with a routing suggestion.
        reason: &'static str,
    },

    /// Tokenization failed; the payload is a free-form description.
    #[error("tokenization error: {0}")]
    TokenizationError(String),

    /// A compute-device problem; the payload is a free-form description.
    #[error("device error: {0}")]
    DeviceError(String),

    /// Wrapper around `std::io::Error`; `#[from]` lets `?` convert it
    /// automatically.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}
152
/// Compute device a local backend is loaded onto.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    /// Plain CPU execution.
    Cpu,
    /// Apple Metal GPU.
    Metal,
    /// CUDA GPU; the payload is presumably the device ordinal (`auto` uses `0`).
    Cuda(usize),
}

impl Device {
    /// Picks a device from the Cargo features this crate was compiled with.
    ///
    /// Precedence is `metal`, then `cuda` (device `0`), then CPU. The choice is
    /// fixed at compile time; no hardware probing happens here.
    pub fn auto() -> Self {
        if cfg!(feature = "metal") {
            Device::Metal
        } else if cfg!(feature = "cuda") {
            Device::Cuda(0)
        } else {
            Device::Cpu
        }
    }
}
178
/// Static configuration for an `InferenceEngine`.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Directory holding local models and persisted state. The engine also
    /// reads `outcome_profiles.json` and `key_pool_stats.json` from here.
    pub models_dir: std::path::PathBuf,
    /// Explicit device for local backends; `None` falls back to
    /// `Device::auto()` when a backend is loaded.
    pub device: Option<Device>,
    /// Default generation model name.
    pub generation_model: String,
    /// Optional override for the generation model; when set, routing returns
    /// it as an explicit choice.
    pub preferred_generation_model: Option<String>,
    /// Default embedding model name.
    pub embedding_model: String,
    /// Optional override that takes precedence over `embedding_model`.
    pub preferred_embedding_model: Option<String>,
    /// Default classification model name.
    pub classification_model: String,
    /// Optional override for the classification model.
    pub preferred_classification_model: Option<String>,
}
199
200impl Default for InferenceConfig {
201 fn default() -> Self {
202 let models_dir = dirs_next()
203 .unwrap_or_else(|| std::path::PathBuf::from("."))
204 .join(".car")
205 .join("models");
206
207 let hw = HardwareInfo::detect();
208
209 Self {
210 models_dir,
211 device: None,
212 generation_model: hw.recommended_model,
213 preferred_generation_model: None,
214 embedding_model: "Qwen3-Embedding-0.6B".to_string(),
215 preferred_embedding_model: None,
216 classification_model: "Qwen3-0.6B".to_string(),
217 preferred_classification_model: None,
218 }
219 }
220}
221
/// Returns the user's home directory as given by the `HOME` environment
/// variable, or `None` when it is unset or empty.
///
/// `var_os` is used instead of `var` so a home path that is not valid UTF-8 is
/// still honoured rather than silently discarded. An empty `HOME` is treated as
/// missing so callers reach their explicit fallback instead of building paths
/// on an empty base.
fn dirs_next() -> Option<std::path::PathBuf> {
    std::env::var_os("HOME")
        .filter(|home| !home.is_empty())
        .map(std::path::PathBuf::from)
}
225
/// Token accounting attached to an inference result.
///
/// `Default` yields all-zero counters.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt/input side.
    pub prompt_tokens: u64,
    /// Tokens produced by the model.
    pub completion_tokens: u64,
    /// Total token count; presumably prompt + completion, but whoever fills
    /// this struct decides — confirm against the producer.
    pub total_tokens: u64,
    /// Context window of the model that served the request — TODO confirm
    /// whether `0` means "unknown", as it does in `estimated_tokens`.
    pub context_window: u64,
}
238
/// Outcome of a tracked generation call.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InferenceResult {
    /// Generated text.
    pub text: String,
    /// Tool calls emitted by the model; empty when there were none.
    pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
    /// Bounding boxes attached to the result. Omitted from serialized output
    /// when empty and defaulted to empty when absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub bounding_boxes: Vec<crate::tasks::grounding::BoundingBox>,
    /// Identifier handed out by the outcome tracker when the request started.
    pub trace_id: String,
    /// Name of the model that actually produced the result.
    pub model_used: String,
    /// Wall-clock latency of the call in milliseconds.
    pub latency_ms: u64,
    /// Time to first token in milliseconds, when the backend reports it.
    /// Defaults to `None` when absent on input.
    #[serde(default)]
    pub time_to_first_token_ms: Option<u64>,
    /// Token accounting, when the backend reports it.
    pub usage: Option<TokenUsage>,
    /// Raw provider-specific output items as JSON values. Omitted from
    /// serialized output when empty and defaulted to empty when absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub provider_output_items: Vec<serde_json::Value>,
}
290
291impl InferenceResult {
292 pub fn has_tool_calls(&self) -> bool {
294 !self.tool_calls.is_empty()
295 }
296}
297
/// Serializable snapshot describing the local speech runtime installation.
///
/// NOTE(review): field meanings below are inferred from names; the code that
/// fills this struct is outside this part of the file.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechRuntimeHealth {
    /// Root directory of the speech runtime.
    pub root: PathBuf,
    /// Whether the runtime is considered installed.
    pub installed: bool,
    /// Path to the Python interpreter associated with the runtime.
    pub python: PathBuf,
    /// Path to the speech-to-text command.
    pub stt_command: PathBuf,
    /// Path to the text-to-speech command.
    pub tts_command: PathBuf,
    /// Python interpreter supplied by configuration, if any.
    pub configured_python: Option<String>,
    /// Python interpreter found by detection, if any.
    pub detected_python: Option<String>,
}
308
/// Serializable health entry for one speech model.
///
/// NOTE(review): field meanings are inferred from names; confirm against the
/// code that builds the speech health report.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechModelHealth {
    /// Model identifier.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Provider label for the model.
    pub provider: String,
    /// Capability this entry describes.
    pub capability: ModelCapability,
    /// Whether the model runs locally.
    pub is_local: bool,
    /// Whether the model can currently be used.
    pub available: bool,
    /// Whether the model's files are already cached on disk.
    pub cached: bool,
    /// Whether this model is the default pick for its capability.
    pub selected_by_default: bool,
    /// Free-form description of where the model comes from.
    pub source: String,
}
321
/// Serializable report covering the speech runtime, its models, and policy.
///
/// The policy-shaped fields (`prefer_local` through `preferred_remote_tts`)
/// carry the same names and types as the fields of `SpeechPolicy`.
/// NOTE(review): remaining field meanings are inferred from names.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechHealthReport {
    /// State of the local speech runtime.
    pub runtime: SpeechRuntimeHealth,
    /// Health entries for local speech models.
    pub local_models: Vec<SpeechModelHealth>,
    /// Health entries for remote speech models.
    pub remote_models: Vec<SpeechModelHealth>,
    /// Whether ElevenLabs credentials/configuration are present.
    pub elevenlabs_configured: bool,
    /// Whether local speech models are preferred over remote ones.
    pub prefer_local: bool,
    /// Whether falling back to a remote model is permitted.
    pub allow_remote_fallback: bool,
    /// Preferred local speech-to-text model, if set.
    pub preferred_local_stt: Option<String>,
    /// Preferred local text-to-speech model, if set.
    pub preferred_local_tts: Option<String>,
    /// Preferred remote speech-to-text model, if set.
    pub preferred_remote_stt: Option<String>,
    /// Preferred remote text-to-speech model, if set.
    pub preferred_remote_tts: Option<String>,
    /// Local speech-to-text model that would be used by default, if any.
    pub local_stt_default: Option<String>,
    /// Local text-to-speech model that would be used by default, if any.
    pub local_tts_default: Option<String>,
    /// Remote speech-to-text model that would be used by default, if any.
    pub remote_stt_default: Option<String>,
    /// Remote text-to-speech model that would be used by default, if any.
    pub remote_tts_default: Option<String>,
}
339
/// Serializable health entry for the configured default model of one
/// capability.
///
/// NOTE(review): field meanings are inferred from names.
#[derive(Debug, Clone, Serialize)]
pub struct ModelDefaultHealth {
    /// Capability the default applies to.
    pub capability: ModelCapability,
    /// Model name taken from configuration.
    pub configured_model: String,
    /// Whether that model can currently be used.
    pub available: bool,
    /// Whether that model runs locally.
    pub is_local: bool,
    /// Provider of the model, when known.
    pub provider: Option<String>,
}
348
/// Serializable per-provider summary of registered models.
///
/// NOTE(review): field meanings are inferred from names.
#[derive(Debug, Clone, Serialize)]
pub struct ModelProviderHealth {
    /// Provider label.
    pub provider: String,
    /// Whether the provider is configured for use.
    pub configured: bool,
    /// Number of local models from this provider.
    pub local_models: usize,
    /// Number of remote models from this provider.
    pub remote_models: usize,
    /// Number of this provider's models that are currently usable.
    pub available_models: usize,
    /// Capabilities offered across this provider's models.
    pub capabilities: Vec<ModelCapability>,
}
358
/// Serializable per-capability model counts.
///
/// NOTE(review): field meanings are inferred from names.
#[derive(Debug, Clone, Serialize)]
pub struct ModelCapabilityHealth {
    /// Capability being summarized.
    pub capability: ModelCapability,
    /// Number of registered models with this capability.
    pub total_models: usize,
    /// Number of those models that are currently usable.
    pub available_models: usize,
    /// Usable models that run locally.
    pub local_available_models: usize,
    /// Usable models that are remote.
    pub remote_available_models: usize,
}
367
/// Serializable record of how the router resolves one named scenario.
///
/// NOTE(review): field meanings are inferred from names; `model_id`,
/// `model_name`, `reason` and `strategy` mirror fields of
/// `AdaptiveRoutingDecision` as constructed elsewhere in this file.
#[derive(Debug, Clone, Serialize)]
pub struct RoutingScenarioHealth {
    /// Scenario label.
    pub name: String,
    /// Workload class used for the scenario.
    pub workload: RoutingWorkload,
    /// Task family label for the scenario.
    pub task_family: String,
    /// Whether the scenario includes tools.
    pub has_tools: bool,
    /// Whether the scenario includes vision input.
    pub has_vision: bool,
    /// Routing preference for local models in effect for the scenario.
    pub prefer_local: bool,
    /// Whether quality-first cold-start routing is in effect.
    pub quality_first_cold_start: bool,
    /// Minimum per-task observations used by the bootstrap phase.
    pub bootstrap_min_task_observations: u64,
    /// Quality floor used by the bootstrap phase.
    pub bootstrap_quality_floor: f64,
    /// Identifier of the model the router selected.
    pub model_id: String,
    /// Name of the model the router selected.
    pub model_name: String,
    /// Router's explanation for the selection.
    pub reason: String,
    /// Strategy the router used for the selection.
    pub strategy: RoutingStrategy,
}
384
/// Serializable view of one model's benchmark prior.
///
/// NOTE(review): field meanings are inferred from names; score scale and the
/// key format of the per-task maps are not visible here.
#[derive(Debug, Clone, Serialize)]
pub struct ModelBenchmarkPriorHealth {
    /// Identifier of the model the prior applies to.
    pub model_id: String,
    /// Model name, when it could be resolved.
    pub model_name: Option<String>,
    /// Aggregate benchmark score.
    pub overall_score: f64,
    /// Aggregate latency in milliseconds, when recorded.
    pub overall_latency_ms: Option<f64>,
    /// Benchmark score per task; keys are presumably task names.
    pub task_scores: std::collections::HashMap<String, f64>,
    /// Latency in milliseconds per task; keys are presumably task names.
    pub task_latency_ms: std::collections::HashMap<String, f64>,
    /// File the prior was loaded from.
    pub source_path: PathBuf,
}
395
/// Top-level serializable health report for the model registry, routing
/// configuration, benchmark priors, and speech stack.
///
/// NOTE(review): field meanings are inferred from names; the builder of this
/// report is outside this part of the file.
#[derive(Debug, Clone, Serialize)]
pub struct ModelHealthReport {
    /// Number of registered models.
    pub total_models: usize,
    /// Number of registered models that are currently usable.
    pub available_models: usize,
    /// Number of local models.
    pub local_models: usize,
    /// Number of remote models.
    pub remote_models: usize,
    /// Health of the configured default model per capability.
    pub defaults: Vec<ModelDefaultHealth>,
    /// Per-provider summaries.
    pub providers: Vec<ModelProviderHealth>,
    /// Per-capability summaries.
    pub capabilities: Vec<ModelCapabilityHealth>,
    /// Routing setting: prefer local models.
    pub routing_prefer_local: bool,
    /// Routing setting: quality-first behaviour during cold start.
    pub routing_quality_first_cold_start: bool,
    /// Routing setting: minimum observations threshold.
    pub routing_min_observations: u64,
    /// Routing setting: minimum per-task observations for bootstrap.
    pub routing_bootstrap_min_task_observations: u64,
    /// Routing setting: quality floor during bootstrap.
    pub routing_bootstrap_quality_floor: f64,
    /// Routing setting: weight given to quality.
    pub routing_quality_weight: f64,
    /// Routing setting: weight given to latency.
    pub routing_latency_weight: f64,
    /// Routing setting: weight given to cost.
    pub routing_cost_weight: f64,
    /// Router decisions for a set of named scenarios.
    pub routing_scenarios: Vec<RoutingScenarioHealth>,
    /// Benchmark priors known to the engine.
    pub benchmark_priors: Vec<ModelBenchmarkPriorHealth>,
    /// Speech-specific health report.
    pub speech: SpeechHealthReport,
}
417
/// Serializable result of installing one speech model.
///
/// NOTE(review): field meanings are inferred from names.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechInstallReport {
    /// Name of the installed model.
    pub name: String,
    /// Hugging Face repository the model came from.
    pub hf_repo: String,
    /// Local path of the downloaded snapshot.
    pub snapshot_path: PathBuf,
    /// Number of files fetched during the install.
    pub files_downloaded: usize,
}
425
/// Serializable result of one speech smoke-test path.
///
/// NOTE(review): field meanings are inferred from names; presumably text is
/// synthesized with `tts_model` and transcribed back with `stt_model` —
/// confirm against the smoke-test implementation.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechSmokePathReport {
    /// Label of the path exercised (for example local vs remote).
    pub path: String,
    /// Text-to-speech model used.
    pub tts_model: String,
    /// Speech-to-text model used.
    pub stt_model: String,
    /// Location of the audio file involved in the test.
    pub audio_path: PathBuf,
    /// Transcript produced by the test.
    pub transcript: String,
}
434
/// Serializable summary of a speech smoke test.
///
/// `Default` yields no local result, no remote result, and an empty skip list.
#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechSmokeReport {
    /// Result of the local path, if it ran.
    pub local: Option<SpeechSmokePathReport>,
    /// Result of the remote path, if it ran.
    pub remote: Option<SpeechSmokePathReport>,
    /// Notes about skipped paths — presumably one human-readable reason per
    /// entry; confirm against the producer.
    pub skipped: Vec<String>,
}
441
/// Policy controlling how speech requests choose between local and remote
/// models.
///
/// The derived `Default` sets both booleans to `false` and every preference to
/// `None`. `InferenceEngine::new` does not use that default: it enables
/// `allow_remote_fallback` and sets `prefer_local` only on Apple Silicon
/// macOS builds without `car_skip_mlx`.
#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechPolicy {
    /// Prefer local speech models over remote ones.
    pub prefer_local: bool,
    /// Permit falling back to a remote model.
    pub allow_remote_fallback: bool,
    /// Preferred local speech-to-text model, if any.
    pub preferred_local_stt: Option<String>,
    /// Preferred local text-to-speech model, if any.
    pub preferred_local_tts: Option<String>,
    /// Preferred remote speech-to-text model, if any.
    pub preferred_remote_stt: Option<String>,
    /// Preferred remote text-to-speech model, if any.
    pub preferred_remote_tts: Option<String>,
}
451
/// Central entry point tying together the model registries, routers, outcome
/// tracking, and the local/remote backends.
///
/// Which local-backend fields exist depends on the build target: Apple Silicon
/// macOS builds (without `car_skip_mlx`) carry the MLX caches, every other
/// build carries the Candle/embedding/speech-runtime slots.
pub struct InferenceEngine {
    /// Configuration the engine was created with.
    pub config: InferenceConfig,
    /// Registry of model schemas, consulted by id and by name.
    pub unified_registry: UnifiedRegistry,
    /// Router that picks a model using recorded outcomes.
    pub adaptive_router: AdaptiveRouter,
    /// Shared outcome tracker; seeded in `new` from `outcome_profiles.json`
    /// and any benchmark-prior files.
    pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
    /// Backend for remote models; also owns the API key pool.
    remote_backend: RemoteBackend,
    /// Cache of loaded MLX backends, keyed by schema id.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    mlx_backends: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
    /// Cache of Flux backends (used for image-generation models in `warm_up`).
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    flux_cache: Arc<backend_cache::BackendCache<backend::mlx_flux::FluxBackend>>,
    /// Cache of LTX backends (used for video-generation models in `warm_up`).
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    ltx_cache: Arc<backend_cache::BackendCache<backend::mlx_ltx::LtxBackend>>,
    /// Cache of Kokoro backends (used for text-to-speech models in `warm_up`).
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    kokoro_cache: Arc<backend_cache::BackendCache<backend::mlx_kokoro::KokoroBackend>>,
    /// Model registry used by `route` and by the Candle loading path.
    pub registry: models::ModelRegistry,
    /// Router used by the non-adaptive `route` method.
    pub router: ModelRouter,
    /// Lazily loaded Candle generation backend; `None` until first use.
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    backend: Arc<RwLock<Option<CandleBackend>>>,
    /// Lazily loaded embedding backend; `None` until first use.
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
    /// Speech runtime slot; starts out as `None`.
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    speech_runtime: Arc<Mutex<Option<SpeechRuntime>>>,
    /// Policy for choosing between local and remote speech models.
    speech_policy: SpeechPolicy,
}
497
498impl InferenceEngine {
499 fn preferred_model_for_capability(&self, capability: ModelCapability) -> Option<&str> {
500 match capability {
501 ModelCapability::Generate => self.config.preferred_generation_model.as_deref(),
502 ModelCapability::Embed => self.config.preferred_embedding_model.as_deref(),
503 ModelCapability::Classify => self.config.preferred_classification_model.as_deref(),
504 _ => None,
505 }
506 }
507
508 fn request_needs_vision(req: &GenerateRequest) -> bool {
509 req.images.as_ref().is_some_and(|images| !images.is_empty())
510 || req.messages.as_ref().is_some_and(|messages| {
511 messages
512 .iter()
513 .any(|msg| matches!(msg, Message::UserMultimodal { .. }))
514 })
515 }
516
517 #[allow(dead_code)] fn request_has_video(req: &GenerateRequest) -> bool {
523 let images_have_video = req
524 .images
525 .as_ref()
526 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
527 let messages_have_video = req.messages.as_ref().is_some_and(|messages| {
528 messages.iter().any(|msg| match msg {
529 Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_video),
530 _ => false,
531 })
532 });
533 images_have_video || messages_have_video
534 }
535
536 #[allow(dead_code)] fn request_has_audio(req: &GenerateRequest) -> bool {
541 let images_have_audio = req
542 .images
543 .as_ref()
544 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
545 let messages_have_audio = req.messages.as_ref().is_some_and(|messages| {
546 messages.iter().any(|msg| match msg {
547 Message::UserMultimodal { content } => content.iter().any(ContentBlock::is_audio),
548 _ => false,
549 })
550 });
551 images_have_audio || messages_have_audio
552 }
553
554 pub fn new(config: InferenceConfig) -> Self {
555 let registry = models::ModelRegistry::new(config.models_dir.clone());
556 let hw = HardwareInfo::detect();
557 let router = ModelRouter::new(hw.clone());
558 let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
559 let adaptive_router = AdaptiveRouter::with_default_config(hw);
560 let mut tracker = OutcomeTracker::new();
561 let profiles_path = config.models_dir.join("outcome_profiles.json");
563 if let Ok(n) = tracker.load_from_file(&profiles_path) {
564 if n > 0 {
565 tracing::info!(loaded = n, "loaded persisted model profiles");
566 }
567 }
568 let mut benchmark_models_loaded = 0usize;
569 for path in benchmark_priors_paths(&config.models_dir) {
570 match routing_ext::load_benchmark_priors(&path) {
571 Ok(priors) if !priors.is_empty() => {
572 benchmark_models_loaded += priors.len();
573 routing_ext::apply_benchmark_priors(&mut tracker, &priors);
574 tracing::info!(
575 path = %path.display(),
576 loaded = priors.len(),
577 "loaded benchmark quality priors"
578 );
579 }
580 Ok(_) => {}
581 Err(error) => {
582 tracing::warn!(path = %path.display(), %error, "failed to load benchmark priors");
583 }
584 }
585 }
586 if benchmark_models_loaded > 0 {
587 tracing::info!(
588 loaded = benchmark_models_loaded,
589 "applied benchmark priors to cold-start routing"
590 );
591 }
592 let outcome_tracker = Arc::new(RwLock::new(tracker));
593
594 let remote_backend = RemoteBackend::new();
595
596 Self {
597 config,
598 unified_registry,
599 adaptive_router,
600 outcome_tracker,
601 remote_backend,
602 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
603 mlx_backends: Arc::new(backend_cache::BackendCache::from_env()),
604 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
605 flux_cache: Arc::new(backend_cache::BackendCache::from_env()),
606 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
607 ltx_cache: Arc::new(backend_cache::BackendCache::from_env()),
608 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
609 kokoro_cache: Arc::new(backend_cache::BackendCache::from_env()),
610 registry,
611 router,
612 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
613 backend: Arc::new(RwLock::new(None)),
614 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
615 embedding_backend: Arc::new(RwLock::new(None)),
616 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
617 speech_runtime: Arc::new(Mutex::new(None)),
618 speech_policy: SpeechPolicy {
619 prefer_local: cfg!(all(
620 target_os = "macos",
621 target_arch = "aarch64",
622 not(car_skip_mlx)
623 )),
624 allow_remote_fallback: true,
625 preferred_local_stt: None,
626 preferred_local_tts: None,
627 preferred_remote_stt: None,
628 preferred_remote_tts: None,
629 },
630 }
631 }
632
633 pub async fn init_key_pool(&self) {
636 for schema in self.unified_registry.list() {
638 if schema.is_remote() {
639 self.remote_backend.register_model_keys(schema).await;
640 }
641 }
642
643 let stats_path = self.config.models_dir.join("key_pool_stats.json");
645 if let Ok(n) = self.remote_backend.key_pool.load_stats(&stats_path).await {
646 if n > 0 {
647 tracing::info!(loaded = n, "loaded persisted key pool stats");
648 }
649 }
650
651 let total = self.remote_backend.key_pool.total_keys().await;
652 if total > 0 {
653 tracing::info!(keys = total, "key pool initialized");
654 }
655 }
656
657 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
660 async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
661 let read = self.backend.read().await;
662 if read.is_some() {
663 return Ok(());
664 }
665 drop(read);
666
667 let mut write = self.backend.write().await;
668 if write.is_some() {
669 return Ok(());
670 }
671
672 let model_path = self.registry.ensure_model(model_name).await?;
673 let device = self.config.device.unwrap_or_else(Device::auto);
674 let backend = CandleBackend::load(&model_path, device)?;
675 *write = Some(backend);
676 Ok(())
677 }
678
679 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
682 async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
683 let read = self.embedding_backend.read().await;
684 if read.is_some() {
685 return Ok(());
686 }
687 drop(read);
688
689 let mut write = self.embedding_backend.write().await;
690 if write.is_some() {
691 return Ok(());
692 }
693
694 let embedding_model = self
695 .preferred_model_for_capability(ModelCapability::Embed)
696 .unwrap_or(&self.config.embedding_model);
697 let model_path = self.registry.ensure_model(embedding_model).await?;
698 let device = self.config.device.unwrap_or_else(Device::auto);
699 let backend = EmbeddingBackend::load(&model_path, device)?;
700 *write = Some(backend);
701 Ok(())
702 }
703
704 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
707 async fn ensure_mlx_embedding_backend(&self) -> Result<String, InferenceError> {
708 let embedding_model_name = self
709 .preferred_model_for_capability(ModelCapability::Embed)
710 .unwrap_or(&self.config.embedding_model)
711 .to_string();
712 let schema = self
713 .unified_registry
714 .get(&embedding_model_name)
715 .or_else(|| self.unified_registry.find_by_name(&embedding_model_name))
716 .ok_or_else(|| InferenceError::ModelNotFound(embedding_model_name.clone()))?
717 .clone();
718 self.ensure_mlx_backend(&schema).await?;
719 Ok(schema.id)
720 }
721
722 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
726 async fn ensure_mlx_backend(
727 &self,
728 schema: &ModelSchema,
729 ) -> Result<backend_cache::CachedBackend<backend::MlxBackend>, InferenceError> {
730 if !Self::supports_native_mlx(schema) {
731 return Err(InferenceError::InferenceFailed(format!(
732 "native MLX backend does not support {} ({}) yet; use vLLM-MLX or add a family-specific MLX backend",
733 schema.name, schema.family
734 )));
735 }
736
737 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
738 let size = backend_cache::estimate_model_size(&model_dir);
739 let cache = Arc::clone(&self.mlx_backends);
740 let key = schema.id.clone();
741 cache.get_or_load(&key, size, move || {
745 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
746 backend::MlxBackend::load(&model_dir)
747 }))
748 .map_err(|e| {
749 InferenceError::InferenceFailed(format!(
750 "MLX backend loading panicked (possible Metal/accelerate exception): {:?}",
751 e
752 ))
753 })?
754 })
755 }
756
757 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
764 pub async fn warm_up<S: AsRef<str>>(
765 &self,
766 schema_ids: &[S],
767 ) -> Vec<Result<(), InferenceError>> {
768 let mut results = Vec::with_capacity(schema_ids.len());
769 for id in schema_ids {
770 let id = id.as_ref();
771 let outcome: Result<(), InferenceError> = async {
772 let schema = self.unified_registry.get(id).cloned().ok_or_else(|| {
773 InferenceError::InferenceFailed(format!("warm_up: unknown schema id {id}"))
774 })?;
775 match schema.capabilities.first().copied() {
776 Some(ModelCapability::ImageGeneration) => {
777 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
778 let size = backend_cache::estimate_model_size(&model_dir);
779 let _ = self.flux_cache.get_or_load(&schema.id, size, || {
780 backend::mlx_flux::FluxBackend::load(&model_dir)
781 })?;
782 }
783 Some(ModelCapability::VideoGeneration) => {
784 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
785 let size = backend_cache::estimate_model_size(&model_dir);
786 let _ = self.ltx_cache.get_or_load(&schema.id, size, || {
787 backend::mlx_ltx::LtxBackend::load(&model_dir)
788 })?;
789 }
790 Some(ModelCapability::TextToSpeech) => {
791 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
792 let size = backend_cache::estimate_model_size(&model_dir);
793 let _ = self.kokoro_cache.get_or_load(&schema.id, size, || {
794 backend::mlx_kokoro::KokoroBackend::load(&model_dir)
795 })?;
796 }
797 _ => {
798 let _ = self.ensure_mlx_backend(&schema).await?;
799 }
800 }
801 Ok(())
802 }
803 .await;
804 results.push(outcome);
805 }
806 results
807 }
808
809 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
811 pub async fn warm_up<S: AsRef<str>>(
812 &self,
813 _schema_ids: &[S],
814 ) -> Vec<Result<(), InferenceError>> {
815 Vec::new()
816 }
817
818 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
819 fn supports_native_mlx(schema: &ModelSchema) -> bool {
820 matches!(schema.family.as_str(), "qwen3" | "qwen2.5-vl" | "qwen2-vl")
821 }
822
823 pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
825 if let Some(model) = self.preferred_model_for_capability(ModelCapability::Generate) {
826 let ctx_len = self
827 .unified_registry
828 .get(model)
829 .or_else(|| self.unified_registry.find_by_name(model))
830 .map(|s| s.context_length)
831 .unwrap_or(0);
832 return AdaptiveRoutingDecision {
833 model_id: model.to_string(),
834 model_name: model.to_string(),
835 task: InferenceTask::Generate,
836 complexity: TaskComplexity::assess(prompt),
837 reason: "preferred generation model override".into(),
838 strategy: RoutingStrategy::Explicit,
839 predicted_quality: 0.5,
840 fallbacks: vec![],
841 context_length: ctx_len,
842 needs_compaction: false,
843 };
844 }
845 let tracker = self.outcome_tracker.read().await;
846 self.adaptive_router
847 .route(prompt, &self.unified_registry, &tracker)
848 }
849
850 pub fn route(&self, prompt: &str) -> RoutingDecision {
852 self.router.route_generate(prompt, &self.registry)
853 }
854
855 pub fn estimated_tokens(
858 &self,
859 req: &GenerateRequest,
860 model_id: Option<&str>,
861 ) -> (usize, usize, bool) {
862 let prompt_tokens = remote::estimate_tokens(&req.prompt);
863 let context_tokens = req
864 .context
865 .as_ref()
866 .map(|c| remote::estimate_tokens(c))
867 .unwrap_or(0);
868 let tools_tokens = req
869 .tools
870 .as_ref()
871 .map(|t| remote::estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
872 .unwrap_or(0);
873 let total_input = prompt_tokens + context_tokens + tools_tokens;
874
875 let context_window = model_id
876 .and_then(|id| {
877 self.unified_registry
878 .get(id)
879 .or_else(|| self.unified_registry.find_by_name(id))
880 })
881 .map(|s| s.context_length)
882 .unwrap_or(0);
883
884 let fits = context_window == 0 || (total_input + req.params.max_tokens) <= context_window;
885 (total_input, context_window, fits)
886 }
887
888 #[instrument(
891 name = "inference.generate",
892 skip_all,
893 fields(
894 model = tracing::field::Empty,
895 max_tokens = req.params.max_tokens,
896 prompt_tokens = tracing::field::Empty,
897 completion_tokens = tracing::field::Empty,
898 latency_ms = tracing::field::Empty,
899 )
900 )]
901 pub async fn generate_tracked(
902 &self,
903 req: GenerateRequest,
904 ) -> Result<InferenceResult, InferenceError> {
905 let start = Instant::now();
906
907 let (estimated_input, _, _) = self.estimated_tokens(&req, None);
909 let tracker_read = self.outcome_tracker.read().await;
910 let has_tools = req.tools.is_some();
911 let has_vision = Self::request_needs_vision(&req);
912 let preferred_model = self
913 .preferred_model_for_capability(ModelCapability::Generate)
914 .map(str::to_string);
915 let decision = match req.model.clone().or(preferred_model) {
916 Some(m) => {
917 let ctx_len = self
918 .unified_registry
919 .get(&m)
920 .or_else(|| self.unified_registry.find_by_name(&m))
921 .map(|s| s.context_length)
922 .unwrap_or(0);
923 AdaptiveRoutingDecision {
924 model_id: m.clone(),
925 model_name: m.clone(),
926 task: InferenceTask::Generate,
927 complexity: TaskComplexity::assess(&req.prompt),
928 reason: "explicit model".into(),
929 strategy: RoutingStrategy::Explicit,
930 predicted_quality: 0.5,
931 fallbacks: vec![],
932 context_length: ctx_len,
933 needs_compaction: ctx_len > 0 && estimated_input > ctx_len,
934 }
935 }
936 None => match &req.intent {
937 Some(hint) => self.adaptive_router.route_context_aware_with_intent(
938 &req.prompt,
939 estimated_input,
940 &self.unified_registry,
941 &tracker_read,
942 has_tools,
943 has_vision,
944 req.params.workload,
945 hint,
946 ),
947 None => self.adaptive_router.route_context_aware(
948 &req.prompt,
949 estimated_input,
950 &self.unified_registry,
951 &tracker_read,
952 has_tools,
953 has_vision,
954 req.params.workload,
955 ),
956 },
957 };
958 drop(tracker_read);
959
960 if decision.needs_compaction {
961 tracing::info!(
962 model = %decision.model_name,
963 prompt_tokens = estimated_input,
964 context_window = decision.context_length,
965 "prompt exceeds model context window — compaction or truncation needed"
966 );
967 }
968
969 let trace_id = {
971 let mut tracker = self.outcome_tracker.write().await;
972 tracker.record_start(&decision.model_id, decision.task, &decision.reason)
973 };
974
975 debug!(
976 model = %decision.model_name,
977 strategy = ?decision.strategy,
978 reason = %decision.reason,
979 trace = %trace_id,
980 "adaptive-routed generate request"
981 );
982
983 let mut req = req;
986 if req.params.budget_tokens == 0 && matches!(decision.complexity, TaskComplexity::Complex) {
987 let supports_thinking = self
988 .unified_registry
989 .get(&decision.model_id)
990 .map(|s| {
991 s.supported_params
992 .contains(&schema::GenerateParam::ExtendedThinking)
993 })
994 .unwrap_or(false);
995 if supports_thinking {
996 req.params.budget_tokens = 8000;
997 tracing::info!(model = %decision.model_name, budget = 8000, "auto-enabled extended thinking for complex task");
998 }
999 }
1000
1001 let mut models_to_try = vec![decision.model_id.clone()];
1003 models_to_try.extend(decision.fallbacks.iter().cloned());
1004
1005 let mut last_error = None;
1006
1007 for candidate_id in &models_to_try {
1008 #[allow(unused_mut)]
1011 let mut schema = self
1012 .unified_registry
1013 .get(candidate_id)
1014 .or_else(|| self.unified_registry.find_by_name(candidate_id))
1015 .cloned();
1016
1017 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1019 if let Some(ref s) = schema {
1020 if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
1021 tracing::info!(
1022 from = %s.id, to = %mlx_equiv.id,
1023 "redirecting GGUF model to MLX equivalent on Apple Silicon"
1024 );
1025 schema = Some(mlx_equiv.clone());
1026 }
1027 }
1028
1029 let candidate_name = schema
1030 .as_ref()
1031 .map(|s| s.name.clone())
1032 .unwrap_or_else(|| candidate_id.clone());
1033
1034 let is_remote = schema
1035 .as_ref()
1036 .map(|s| s.is_remote() || s.is_vllm_mlx())
1037 .unwrap_or(false);
1038 let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
1039
1040 if is_delegated {
1046 let runner = match runner::current_inference_runner() {
1047 Some(r) => r,
1048 None => {
1049 last_error = Some(InferenceError::InferenceFailed(
1050 "model declares ModelSource::Delegated but no inference runner is registered"
1051 .into(),
1052 ));
1053 continue;
1054 }
1055 };
1056 let (tx, mut rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
1057 let emitter = runner::EventEmitter::new(tx);
1058 let runner_req = req.clone();
1059 let runner_handle =
1060 tokio::spawn(async move { runner.run(runner_req, emitter).await });
1061 let mut accumulator = stream::StreamAccumulator::default();
1062 while let Some(evt) = rx.recv().await {
1063 accumulator.push(&evt);
1064 }
1065 let (acc_text, acc_tool_calls) = accumulator.finish();
1069 match runner_handle.await {
1070 Ok(Ok(_runner_result)) => {
1071 let elapsed = start.elapsed().as_millis() as u64;
1072 let mut tracker = self.outcome_tracker.write().await;
1073 tracker.record_complete(&trace_id, elapsed, 0, 0);
1074 return Ok(InferenceResult {
1075 text: acc_text,
1076 tool_calls: acc_tool_calls,
1077 bounding_boxes: vec![],
1078 trace_id,
1079 model_used: candidate_name,
1080 latency_ms: elapsed,
1081 time_to_first_token_ms: None,
1082 usage: None,
1083 provider_output_items: vec![],
1084 });
1085 }
1086 Ok(Err(e)) => {
1087 last_error = Some(InferenceError::InferenceFailed(e.to_string()));
1088 continue;
1089 }
1090 Err(join_err) => {
1091 last_error = Some(InferenceError::InferenceFailed(format!(
1092 "runner task panicked: {join_err}"
1093 )));
1094 continue;
1095 }
1096 }
1097 }
1098
1099 let has_tools = req.tools.is_some();
1100
1101 let context = if has_tools
1103 && req.tools.as_ref().map_or(false, |t| {
1104 t.iter().any(|tool| {
1105 tool.get("function")
1106 .and_then(|f| f.get("name"))
1107 .and_then(|n| n.as_str())
1108 == Some("done")
1109 })
1110 }) {
1111 let base = req.context.as_deref().unwrap_or("");
1112 Some(format!(
1113 "{base}\n\nIMPORTANT: When calling the `done` tool, the `result` field MUST contain a DETAILED summary of everything you found and did. This is the ONLY output the user sees. Do NOT just say 'completed' — include specific findings, data, and conclusions."
1114 ))
1115 } else {
1116 req.context.clone()
1117 };
1118
1119 let result = if is_remote {
1120 let schema_val = schema.unwrap();
1121 let _ctx_len = schema_val.context_length;
1122 let temperature = if !schema_val.supported_params.is_empty()
1125 && !schema_val
1126 .supported_params
1127 .contains(&crate::schema::GenerateParam::Temperature)
1128 {
1129 -1.0
1130 } else {
1131 req.params.temperature
1132 };
1133
1134 self.remote_backend
1140 .generate_with_tools_multi(
1141 &schema_val,
1142 &req.prompt,
1143 context.as_deref(),
1144 temperature,
1145 req.params.max_tokens,
1146 req.tools.as_deref(),
1147 req.images.as_deref(),
1148 req.messages.as_deref(),
1149 req.params.tool_choice.as_deref(),
1150 req.params.parallel_tool_calls,
1151 req.params.budget_tokens,
1152 req.cache_control,
1153 req.response_format.as_ref(),
1154 )
1155 .await
1156 .map(|(t, c, u)| (t, c, u, None::<u64>))
1161 } else {
1162 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1163 {
1164 let schema_ref = schema
1168 .as_ref()
1169 .ok_or_else(|| InferenceError::ModelNotFound(candidate_id.clone()))?;
1170
1171 if schema_ref.is_foundation_models() {
1176 if has_tools {
1177 Err(InferenceError::UnsupportedMode {
1178 mode: "tool-use",
1179 backend: "foundation-models",
1180 reason: "tool calling not yet wired through the FoundationModels \
1181 bridge — route to a remote model with ToolUse capability",
1182 })
1183 } else if Self::request_has_video(&req)
1184 || Self::request_has_audio(&req)
1185 || req.images.as_ref().is_some_and(|imgs| !imgs.is_empty())
1186 {
1187 Err(InferenceError::UnsupportedMode {
1188 mode: "multimodal-content",
1189 backend: "foundation-models",
1190 reason: "the FoundationModels bridge currently exposes text-only \
1191 generation — route image/audio/video to a remote VL model",
1192 })
1193 } else {
1194 let prompt = req.prompt.clone();
1195 let instructions = context.clone();
1196 let max_tokens = req.params.max_tokens as u32;
1197 let temperature = req.params.temperature;
1198 tokio::task::spawn_blocking(move || {
1199 crate::backend::foundation_models::generate(
1200 &prompt,
1201 instructions.as_deref(),
1202 max_tokens,
1203 temperature as f32,
1204 )
1205 })
1206 .await
1207 .map_err(|e| {
1208 InferenceError::InferenceFailed(format!(
1209 "FoundationModels task panicked: {e}"
1210 ))
1211 })
1212 .and_then(|r| r)
1213 .map(|text| (text, vec![], None, None))
1214 }
1215 } else if !schema_ref.is_mlx() {
1216 Err(InferenceError::InferenceFailed(format!(
1217 "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1218 schema_ref.id
1219 )))
1220 } else if schema_ref.tags.iter().any(|t| t == "mlx-vlm-cli") {
1221 let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
1233 if !has_images {
1234 return Err(InferenceError::UnsupportedMode {
1235 mode: "text-only-on-mlx-vlm-id",
1236 backend: "mlx-vlm-cli",
1237 reason: "the `mlx-vlm/...` model IDs route exclusively \
1238 through the mlx-vlm CLI for image inference. \
1239 For text-only generation, route to a Qwen3 \
1240 text model (`mlx/qwen3-4b:4bit` etc.) — the \
1241 CLI shell-out has higher latency than the \
1242 in-process MLX text tower.",
1243 });
1244 }
1245 let vlm_status = crate::backend::mlx_vlm_cli::runtime_status();
1246 if !vlm_status.is_available() {
1247 return Err(InferenceError::InferenceFailed(vlm_status.user_message()));
1248 }
1249 let repo = match &schema_ref.source {
1250 crate::schema::ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
1251 _ => {
1252 return Err(InferenceError::InferenceFailed(format!(
1253 "model '{}' is tagged mlx-vlm-cli but its \
1254 source isn't ModelSource::Mlx — registry bug",
1255 schema_ref.id
1256 )));
1257 }
1258 };
1259 let imgs = req.images.clone().unwrap_or_default();
1260 let temp = req.params.temperature;
1261 let max_t = req.params.max_tokens;
1262 let prompt = req.prompt.clone();
1263 let text = tokio::task::spawn_blocking(move || {
1264 crate::backend::mlx_vlm_cli::generate(
1265 &repo, &prompt, &imgs, temp, max_t,
1266 )
1267 })
1268 .await
1269 .map_err(|e| {
1270 InferenceError::InferenceFailed(format!(
1271 "mlx_vlm CLI task panicked: {e}"
1272 ))
1273 })??;
1274 let bounding_boxes = parse_boxes(&text);
1275 let latency_ms = start.elapsed().as_millis() as u64;
1276 {
1277 let mut tracker = self.outcome_tracker.write().await;
1278 tracker.record_complete(&trace_id, latency_ms, 0, 0);
1279 }
1280 return Ok(InferenceResult {
1281 text,
1282 tool_calls: vec![],
1283 bounding_boxes,
1284 trace_id: trace_id.clone(),
1285 model_used: schema_ref.id.clone(),
1286 latency_ms,
1287 time_to_first_token_ms: None,
1288 usage: None,
1289 provider_output_items: Vec::new(),
1290 });
1291 } else {
1292 let handle = self.ensure_mlx_backend(schema_ref).await?;
1302 if Self::request_has_video(&req) {
1308 return Err(InferenceError::UnsupportedMode {
1309 mode: "video-content-block",
1310 backend: "native-mlx-qwen25vl",
1311 reason: "Qwen2.5-VL video understanding is on the request surface \
1312 but the video-tokenization path (frame sampling + merger) \
1313 is not yet wired; route to a remote VL provider for now",
1314 });
1315 }
1316 if Self::request_has_audio(&req) {
1317 return Err(InferenceError::UnsupportedMode {
1318 mode: "audio-content-block",
1319 backend: "native-mlx-qwen25vl",
1320 reason: "audio understanding is on the request surface (Gemma 4 \
1321 E2B/E4B and Gemini accept it) but the native MLX path \
1322 for this model does not — route to Gemini or Gemma-4",
1323 });
1324 }
1325 let has_images = req.images.as_ref().is_some_and(|imgs| !imgs.is_empty());
1326 if has_images {
1327 let can_do_vision = {
1328 let guard = handle.lock().map_err(|_| {
1329 InferenceError::InferenceFailed(
1330 "MLX backend mutex poisoned".into(),
1331 )
1332 })?;
1333 guard.supports_capability(crate::schema::ModelCapability::Vision)
1334 };
1335 if !can_do_vision {
1336 return Err(InferenceError::UnsupportedMode {
1346 mode: "image-content-block",
1347 backend: "native-mlx-text",
1348 reason: "this MLX backend is a plain Qwen3 text tower. \
1349 For local image inference, route to \
1350 `mlx-vlm/qwen3-vl-2b:bf16` or another `mlx-vlm/...` \
1351 catalog ID so CAR shells out to `mlx_vlm.generate`. \
1352 Alternatives: a local vLLM-MLX VLM server, or a \
1353 remote VL model. (#115)",
1354 });
1355 }
1356 }
1357 self.generate_mlx(req.clone(), &schema_ref.id)
1358 .await
1359 .map(|(text, ttft)| (text, vec![], None, ttft))
1360 }
1361 }
1362
1363 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1364 {
1365 match self.ensure_backend(&candidate_name).await {
1366 Ok(()) => {
1367 let mut write = self.backend.write().await;
1368 let backend = write.as_mut().unwrap();
1369 tasks::generate::generate(backend, req.clone())
1370 .await
1371 .map(|(text, ttft)| (text, vec![], None, ttft))
1372 }
1373 Err(e) => Err(e),
1374 }
1375 }
1376 };
1377
1378 match result {
1379 Ok((text, tool_calls, usage, time_to_first_token_ms)) => {
1380 let latency_ms = start.elapsed().as_millis() as u64;
1381 let estimated_tokens = usage
1382 .as_ref()
1383 .map(|u| u.completion_tokens as usize)
1384 .unwrap_or_else(|| text.split_whitespace().count());
1385 {
1386 let mut tracker = self.outcome_tracker.write().await;
1387 tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
1388 }
1389 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1391 cb.record_success(candidate_id);
1392 }
1393 self.auto_save_outcomes().await;
1395
1396 let span = tracing::Span::current();
1398 span.record("model", candidate_name.as_str());
1399 span.record("latency_ms", latency_ms);
1400 if let Some(ttft) = time_to_first_token_ms {
1401 span.record("ttft_ms", ttft);
1402 }
1403 if let Some(ref u) = usage {
1404 span.record("prompt_tokens", u.prompt_tokens);
1405 span.record("completion_tokens", u.completion_tokens);
1406 }
1407
1408 let bounding_boxes = tasks::grounding::parse_boxes(&text);
1411 return Ok(InferenceResult {
1412 text,
1413 tool_calls,
1414 bounding_boxes,
1415 trace_id,
1416 model_used: candidate_name,
1417 latency_ms,
1418 time_to_first_token_ms,
1419 usage,
1420 provider_output_items: Vec::new(),
1421 });
1422 }
1423 Err(e) => {
1424 tracing::warn!(
1425 model = %candidate_name,
1426 error = %e,
1427 remaining = models_to_try.len().saturating_sub(
1428 models_to_try.iter().position(|m| m == candidate_id).unwrap_or(0) + 1
1429 ),
1430 "model failed, trying next fallback immediately"
1431 );
1432 {
1434 let mut tracker = self.outcome_tracker.write().await;
1435 let fail_trace =
1436 tracker.record_start(candidate_id, decision.task, "fallback");
1437 tracker.record_failure(&fail_trace, &e.to_string());
1438 }
1439 {
1443 let err_str = e.to_string();
1444 let is_client_error =
1445 err_str.contains("API returned 4") && !err_str.contains("429");
1446 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1447 cb.record_failure(candidate_id);
1448 if is_client_error {
1451 cb.record_failure(candidate_id);
1452 }
1453 }
1454 }
1455 #[cfg(not(all(
1457 target_os = "macos",
1458 target_arch = "aarch64",
1459 not(car_skip_mlx)
1460 )))]
1461 {
1462 let mut write = self.backend.write().await;
1463 *write = None;
1464 }
1465 last_error = Some(e);
1466 }
1467 }
1468 }
1469
1470 let underlying = last_error.unwrap_or(InferenceError::InferenceFailed(
1472 "no models available".into(),
1473 ));
1474
1475 let e = match no_backend_recovery_hint(&underlying.to_string()) {
1484 Some(msg) => InferenceError::InferenceFailed(msg),
1485 None => underlying,
1486 };
1487 {
1488 let mut tracker = self.outcome_tracker.write().await;
1489 tracker.record_failure(&trace_id, &e.to_string());
1490 }
1491 self.auto_save_outcomes().await;
1492 Err(e)
1493 }
1494
1495 pub async fn generate_tracked_stream(
1532 &self,
1533 req: GenerateRequest,
1534 ) -> Result<tokio::sync::mpsc::Receiver<stream::StreamEvent>, InferenceError> {
1535 let has_tools = req.tools.is_some();
1536 let has_vision = Self::request_needs_vision(&req);
1537 let preferred_model = self
1538 .preferred_model_for_capability(ModelCapability::Generate)
1539 .map(str::to_string);
1540 let decision = match req.model.clone().or(preferred_model) {
1541 Some(m) => {
1542 let ctx_len = self
1543 .unified_registry
1544 .get(&m)
1545 .or_else(|| self.unified_registry.find_by_name(&m))
1546 .map(|s| s.context_length)
1547 .unwrap_or(0);
1548 AdaptiveRoutingDecision {
1549 model_id: m.clone(),
1550 model_name: m,
1551 task: InferenceTask::Generate,
1552 complexity: TaskComplexity::assess(&req.prompt),
1553 reason: "explicit model".into(),
1554 strategy: RoutingStrategy::Explicit,
1555 predicted_quality: 0.5,
1556 fallbacks: vec![],
1557 context_length: ctx_len,
1558 needs_compaction: false,
1559 }
1560 }
1561 None => {
1562 let tracker_read = self.outcome_tracker.read().await;
1563 if has_vision {
1564 self.adaptive_router.route_with_vision(
1565 &req.prompt,
1566 &self.unified_registry,
1567 &tracker_read,
1568 has_tools,
1569 )
1570 } else if has_tools {
1571 self.adaptive_router.route_with_tools(
1572 &req.prompt,
1573 &self.unified_registry,
1574 &tracker_read,
1575 )
1576 } else {
1577 self.adaptive_router
1578 .route(&req.prompt, &self.unified_registry, &tracker_read)
1579 }
1580 }
1581 };
1582
1583 #[allow(unused_mut)]
1586 let mut schema = self
1587 .unified_registry
1588 .get(&decision.model_id)
1589 .or_else(|| self.unified_registry.find_by_name(&decision.model_id))
1590 .cloned();
1591
1592 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1594 if let Some(ref s) = schema {
1595 if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
1596 tracing::info!(
1597 from = %s.id, to = %mlx_equiv.id,
1598 "redirecting GGUF model to MLX equivalent on Apple Silicon (stream)"
1599 );
1600 schema = Some(mlx_equiv.clone());
1601 }
1602 }
1603
1604 let is_remote = schema
1605 .as_ref()
1606 .map(|s| s.is_remote() || s.is_vllm_mlx())
1607 .unwrap_or(false);
1608
1609 let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);
1610
1611 if is_delegated {
1612 let runner = runner::current_inference_runner().ok_or_else(|| {
1616 InferenceError::InferenceFailed(
1617 "model declares ModelSource::Delegated but no inference runner is registered \
1618 (call set_inference_runner / registerInferenceRunner / register_inference_runner)"
1619 .into(),
1620 )
1621 })?;
1622 let (tx, rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
1623 let emitter = runner::EventEmitter::new(tx);
1624 let request = req.clone();
1625 tokio::spawn(async move {
1626 if let Err(e) = runner.run(request, emitter).await {
1627 tracing::warn!(error = %e, "delegated inference runner failed");
1628 }
1629 });
1630 return Ok(rx);
1631 }
1632
1633 if is_remote {
1634 let schema = schema.unwrap();
1635 self.remote_backend.register_model_keys(&schema).await;
1637
1638 self.remote_backend
1639 .generate_stream(
1640 &schema,
1641 &req.prompt,
1642 req.context.as_deref(),
1643 req.params.temperature,
1644 req.params.max_tokens,
1645 req.tools.as_deref(),
1646 req.images.as_deref(),
1647 req.params.tool_choice.as_deref(),
1648 req.params.parallel_tool_calls,
1649 req.response_format.as_ref(),
1650 )
1651 .await
1652 } else {
1653 let schema =
1654 schema.ok_or_else(|| InferenceError::ModelNotFound(decision.model_id.clone()))?;
1655 let (tx, rx) = tokio::sync::mpsc::channel(64);
1656
1657 #[cfg(any(
1663 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
1664 all(target_os = "ios", target_arch = "aarch64")
1665 ))]
1666 {
1667 if schema.is_foundation_models() {
1668 let prompt = req.prompt.clone();
1669 let instructions = req.context.clone();
1670 let max_tokens = req.params.max_tokens as u32;
1671 let temperature = req.params.temperature;
1672 let tx_clone = tx.clone();
1673 tokio::task::spawn_blocking(move || {
1674 let accum = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
1681 let accum_cb = accum.clone();
1682 let cb = crate::backend::foundation_models::StreamCallback::new(
1683 move |delta: &str| {
1684 if let Ok(mut g) = accum_cb.lock() {
1685 g.push_str(delta);
1686 }
1687 tx_clone
1688 .blocking_send(stream::StreamEvent::TextDelta(
1689 delta.to_string(),
1690 ))
1691 .is_ok()
1692 },
1693 );
1694 let result = crate::backend::foundation_models::stream(
1695 &prompt,
1696 instructions.as_deref(),
1697 max_tokens,
1698 temperature as f32,
1699 cb,
1700 );
1701 let final_text = accum.lock().map(|g| g.clone()).unwrap_or_default();
1702 let _ = tx.blocking_send(stream::StreamEvent::Done {
1703 text: final_text,
1704 tool_calls: vec![],
1705 });
1706 result
1707 });
1708 return Ok(rx);
1709 }
1710 }
1711
1712 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1714 {
1715 if !schema.is_mlx() {
1718 return Err(InferenceError::InferenceFailed(format!(
1719 "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1720 schema.id
1721 )));
1722 }
1723 let backend = self.ensure_mlx_backend(&schema).await?;
1724 let model_id = schema.id.clone();
1725 let cache = Arc::clone(&self.mlx_backends);
1726 tokio::task::spawn_blocking(move || {
1732 let _ = Self::stream_local_mlx(backend, cache, model_id, req, tx);
1733 });
1734 return Ok(rx);
1735 }
1736
1737 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1738 {
1739 self.ensure_backend(&schema.name).await?;
1740 let backend = self.backend.clone();
1741 tokio::spawn(async move {
1742 let _ = Self::stream_local_candle(backend, req, tx).await;
1743 });
1744 Ok(rx)
1745 }
1746 }
1747 }
1748
1749 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1750 async fn stream_local_candle(
1751 backend_lock: Arc<RwLock<Option<CandleBackend>>>,
1752 req: GenerateRequest,
1753 tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
1754 ) -> Result<(), InferenceError> {
1755 let mut write = backend_lock.write().await;
1756 let backend = write
1757 .as_mut()
1758 .ok_or_else(|| InferenceError::InferenceFailed("backend not initialized".into()))?;
1759 backend.clear_kv_cache();
1760
1761 let formatted = tasks::generate::apply_chat_template(
1762 &req.prompt,
1763 req.context.as_deref(),
1764 req.params.thinking,
1765 );
1766 let tokens = backend.encode(&formatted)?;
1767 let eos = backend.eos_token_id();
1768 let eos_alt = backend.token_id("<|im_end|>");
1769 let params = &req.params;
1770
1771 if tokens.is_empty() {
1772 let _ = tx
1773 .send(stream::StreamEvent::Done {
1774 text: String::new(),
1775 tool_calls: vec![],
1776 })
1777 .await;
1778 return Ok(());
1779 }
1780
1781 let max_ctx = backend.context_length().unwrap_or(32768);
1782 let headroom = params.max_tokens.min(max_ctx / 4);
1783 let max_prompt = max_ctx.saturating_sub(headroom);
1784 let tokens = if tokens.len() > max_prompt {
1785 tokens[tokens.len() - max_prompt..].to_vec()
1786 } else {
1787 tokens
1788 };
1789
1790 let mut generated = Vec::new();
1791 let logits = backend.forward(&tokens, 0)?;
1792 let mut next_token = tasks::generate::sample_token(&logits, params)?;
1793
1794 for _ in 0..params.max_tokens {
1795 if eos.map_or(false, |id| next_token == id)
1796 || eos_alt.map_or(false, |id| next_token == id)
1797 {
1798 break;
1799 }
1800
1801 generated.push(next_token);
1802 let delta = backend.decode(&[next_token])?;
1803 if !delta.is_empty()
1804 && tx
1805 .send(stream::StreamEvent::TextDelta(delta))
1806 .await
1807 .is_err()
1808 {
1809 return Ok(());
1810 }
1811
1812 if !params.stop.is_empty() {
1813 let text_so_far = backend.decode(&generated)?;
1814 if params.stop.iter().any(|s| text_so_far.contains(s)) {
1815 break;
1816 }
1817 }
1818
1819 let pos = tokens.len() + generated.len() - 1;
1820 let logits = backend.forward(&[next_token], pos)?;
1821 next_token = tasks::generate::sample_token(&logits, params)?;
1822 }
1823
1824 let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
1825 let _ = tx
1826 .send(stream::StreamEvent::Done {
1827 text,
1828 tool_calls: vec![],
1829 })
1830 .await;
1831 Ok(())
1832 }
1833
/// Blocking, token-by-token streaming on a cached MLX backend.
///
/// Locks the backend handle, clears its KV cache, applies the chat template,
/// and emits one `TextDelta` per decoded token via `blocking_send`, followed
/// by a final `Done` with the complete text (thinking stripped according to
/// `params.thinking`). It uses `blocking_send`, and the only caller visible
/// here runs it under `spawn_blocking`.
///
/// Forward passes go through [`Self::catch_mlx`]; when one fails, the cache
/// entry for `model_id` is invalidated before the error is returned.
///
/// Returns `Ok(())` early, without a `Done` event, when the receiver has been
/// dropped.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn stream_local_mlx(
    handle: backend_cache::CachedBackend<backend::MlxBackend>,
    cache: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
    model_id: String,
    req: GenerateRequest,
    tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
) -> Result<(), InferenceError> {
    let mut guard = match handle.lock() {
        Ok(locked) => locked,
        Err(_) => {
            return Err(InferenceError::InferenceFailed(format!(
                "MLX backend mutex poisoned for {model_id}"
            )));
        }
    };
    let mlx: &mut backend::MlxBackend = &mut *guard;
    mlx.clear_kv_cache();

    let chat_prompt = tasks::generate::apply_chat_template(
        &req.prompt,
        req.context.as_deref(),
        req.params.thinking,
    );
    let prompt_tokens = mlx.encode(&chat_prompt)?;
    let eos = mlx.eos_token_id();
    let eos_alt = mlx.token_id("<|im_end|>");
    let params = &req.params;

    // Nothing to condition on: finish immediately with an empty result.
    if prompt_tokens.is_empty() {
        let _ = tx.blocking_send(stream::StreamEvent::Done {
            text: String::new(),
            tool_calls: vec![],
        });
        return Ok(());
    }

    // Reserve generation headroom, then keep only the tail of the prompt.
    let max_ctx = mlx.context_length();
    let headroom = params.max_tokens.min(max_ctx / 4);
    let max_prompt = max_ctx.saturating_sub(headroom);
    let overflow = prompt_tokens.len().saturating_sub(max_prompt);
    let prompt_tokens = if overflow > 0 {
        prompt_tokens[overflow..].to_vec()
    } else {
        prompt_tokens
    };

    let mut generated = Vec::new();
    let prefill = Self::catch_mlx("stream prefill", || mlx.forward(&prompt_tokens, 0));
    let prefill_logits = match prefill {
        Ok(logits) => logits,
        Err(failure) => {
            cache.invalidate(&model_id);
            return Err(failure);
        }
    };
    let mut next_token = Self::sample_from_logits(&prefill_logits, params)?;

    for _ in 0..params.max_tokens {
        let hit_eos = eos.is_some_and(|id| next_token == id)
            || eos_alt.is_some_and(|id| next_token == id);
        if hit_eos {
            break;
        }

        generated.push(next_token);
        let piece = mlx.decode(&[next_token])?;
        if !piece.is_empty() {
            let delivery = tx.blocking_send(stream::StreamEvent::TextDelta(piece));
            if delivery.is_err() {
                // Receiver is gone; stop generating.
                return Ok(());
            }
        }

        // Stop sequences are matched against everything generated so far,
        // after the current delta has already been sent.
        if !params.stop.is_empty() {
            let text_so_far = mlx.decode(&generated)?;
            if params.stop.iter().any(|s| text_so_far.contains(s)) {
                break;
            }
        }

        // Position of the token just appended within the full sequence.
        let pos = prompt_tokens.len() + generated.len() - 1;
        let step = Self::catch_mlx("stream forward", || mlx.forward(&[next_token], pos));
        let step_logits = match step {
            Ok(logits) => logits,
            Err(failure) => {
                cache.invalidate(&model_id);
                return Err(failure);
            }
        };
        next_token = Self::sample_from_logits(&step_logits, params)?;
    }

    let full_text = mlx.decode(&generated)?;
    let text = tasks::generate::strip_thinking(&full_text, params.thinking);
    let _ = tx.blocking_send(stream::StreamEvent::Done {
        text,
        tool_calls: vec![],
    });
    Ok(())
}
1935
1936 pub async fn route_context_snapshot(
1938 &self,
1939 prompt: &str,
1940 workload: RoutingWorkload,
1941 has_tools: bool,
1942 has_vision: bool,
1943 ) -> AdaptiveRoutingDecision {
1944 let tracker = self.outcome_tracker.read().await;
1945 self.adaptive_router.route_context_aware(
1946 prompt,
1947 0,
1948 &self.unified_registry,
1949 &tracker,
1950 has_tools,
1951 has_vision,
1952 workload,
1953 )
1954 }
1955
1956 pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
1959 Ok(self.generate_tracked(req).await?.text)
1960 }
1961
1962 pub async fn tokenize(&self, model: &str, text: &str) -> Result<Vec<u32>, InferenceError> {
1974 self.assert_local_for_tokenize(model)?;
1975
1976 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1977 {
1978 let schema = self
1979 .unified_registry
1980 .get(model)
1981 .or_else(|| self.unified_registry.find_by_name(model))
1982 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1983 .clone();
1984 let handle = self.ensure_mlx_backend(&schema).await?;
1985 let guard = handle.lock().map_err(|_| {
1986 InferenceError::InferenceFailed(format!(
1987 "MLX backend mutex poisoned for {}",
1988 schema.id
1989 ))
1990 })?;
1991 return guard.tokenize_raw(text);
1992 }
1993
1994 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1995 {
1996 self.ensure_backend(model).await?;
1997 let read = self.backend.read().await;
1998 let backend = read.as_ref().ok_or_else(|| {
1999 InferenceError::InferenceFailed(
2000 "candle backend missing after ensure_backend".to_string(),
2001 )
2002 })?;
2003 backend.tokenize_raw(text)
2004 }
2005 }
2006
2007 pub async fn detokenize(&self, model: &str, tokens: &[u32]) -> Result<String, InferenceError> {
2009 self.assert_local_for_tokenize(model)?;
2010
2011 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2012 {
2013 let schema = self
2014 .unified_registry
2015 .get(model)
2016 .or_else(|| self.unified_registry.find_by_name(model))
2017 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
2018 .clone();
2019 let handle = self.ensure_mlx_backend(&schema).await?;
2020 let guard = handle.lock().map_err(|_| {
2021 InferenceError::InferenceFailed(format!(
2022 "MLX backend mutex poisoned for {}",
2023 schema.id
2024 ))
2025 })?;
2026 return guard.detokenize_raw(tokens);
2027 }
2028
2029 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2030 {
2031 self.ensure_backend(model).await?;
2032 let read = self.backend.read().await;
2033 let backend = read.as_ref().ok_or_else(|| {
2034 InferenceError::InferenceFailed(
2035 "candle backend missing after ensure_backend".to_string(),
2036 )
2037 })?;
2038 backend.detokenize_raw(tokens)
2039 }
2040 }
2041
2042 fn assert_local_for_tokenize(&self, model: &str) -> Result<(), InferenceError> {
2046 if let Some(schema) = self
2047 .unified_registry
2048 .get(model)
2049 .or_else(|| self.unified_registry.find_by_name(model))
2050 {
2051 if !schema.is_local() {
2052 return Err(InferenceError::UnsupportedMode {
2053 mode: "tokenize/detokenize",
2054 backend: "remote",
2055 reason: "remote provider tokenizer is not exposed by the runtime; \
2056 use a local model (Qwen3 GGUF / MLX) for tokenizer-correctness checks",
2057 });
2058 }
2059 }
2060 Ok(())
2062 }
2063
/// Runs `f`, converting a panic inside it into `InferenceError::InferenceFailed`.
///
/// `context` names the operation (e.g. "prefill") and is embedded in the error
/// message. A normal `Err` returned by `f` is passed through unchanged.
///
/// The panic payload is a `Box<dyn Any + Send>`; its `Debug` output is just
/// `Any { .. }`, so the message is recovered by downcasting to the two payload
/// types produced by `panic!` (`&'static str` and `String`).
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn catch_mlx<F, T>(context: &str, f: F) -> Result<T, InferenceError>
where
    F: FnOnce() -> Result<T, InferenceError>,
{
    std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)).map_err(|payload| {
        let detail = payload
            .downcast_ref::<&str>()
            .map(|msg| (*msg).to_string())
            .or_else(|| payload.downcast_ref::<String>().cloned())
            .unwrap_or_else(|| "non-string panic payload".to_string());
        InferenceError::InferenceFailed(format!("MLX panicked during {context}: {detail}"))
    })?
}
2078
/// Non-streaming text generation on the MLX backend for `model_id`.
///
/// Resolves the schema by id, obtains (loading if needed) the cached MLX
/// backend, clears its KV cache, applies the chat template, and samples up to
/// `params.max_tokens` tokens until an EOS token or a stop sequence appears.
///
/// Returns the generated text (thinking stripped according to
/// `params.thinking`) together with the time to first token in milliseconds,
/// measured from function entry to the first sampled token. An empty encoded
/// prompt yields `("", None)`.
///
/// Forward passes go through [`Self::catch_mlx`]; when one fails, the backend
/// guard is released and the cache entry for `model_id` is invalidated before
/// the error is returned.
///
/// # Errors
/// `InferenceFailed` for an unknown schema id or a poisoned backend mutex,
/// plus any error from backend loading, encoding, decoding, or sampling.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
async fn generate_mlx(
    &self,
    req: GenerateRequest,
    model_id: &str,
) -> Result<(String, Option<u64>), InferenceError> {
    let started = Instant::now();

    let Some(schema) = self.unified_registry.get(model_id).cloned() else {
        return Err(InferenceError::InferenceFailed(format!(
            "generate_mlx: unknown schema id {model_id}"
        )));
    };
    let handle = self.ensure_mlx_backend(&schema).await?;
    let mut guard = match handle.lock() {
        Ok(locked) => locked,
        Err(_) => {
            return Err(InferenceError::InferenceFailed(format!(
                "MLX backend mutex poisoned for {model_id}"
            )));
        }
    };
    let mlx: &mut backend::MlxBackend = &mut *guard;
    mlx.clear_kv_cache();

    let chat_prompt = tasks::generate::apply_chat_template(
        &req.prompt,
        req.context.as_deref(),
        req.params.thinking,
    );
    let prompt_tokens = mlx.encode(&chat_prompt)?;
    let eos = mlx.eos_token_id();
    let eos_alt = mlx.token_id("<|im_end|>");
    let params = &req.params;

    if prompt_tokens.is_empty() {
        return Ok((String::new(), None));
    }

    // Reserve generation headroom, then keep only the tail of the prompt.
    let max_ctx = mlx.context_length();
    let headroom = params.max_tokens.min(max_ctx / 4);
    let max_prompt = max_ctx.saturating_sub(headroom);
    let overflow = prompt_tokens.len().saturating_sub(max_prompt);
    let prompt_tokens = if overflow > 0 {
        prompt_tokens[overflow..].to_vec()
    } else {
        prompt_tokens
    };

    let mut generated = Vec::new();

    let prefill = Self::catch_mlx("prefill", || mlx.forward(&prompt_tokens, 0));
    let prefill_logits = match prefill {
        Ok(logits) => logits,
        Err(failure) => {
            // Release the backend before evicting it from the cache.
            drop(guard);
            self.mlx_backends.invalidate(model_id);
            return Err(failure);
        }
    };
    let mut next_token = Self::sample_from_logits(&prefill_logits, params)?;
    let ttft_ms = Some(started.elapsed().as_millis() as u64);

    for _ in 0..params.max_tokens {
        let hit_eos = eos.is_some_and(|id| next_token == id)
            || eos_alt.is_some_and(|id| next_token == id);
        if hit_eos {
            break;
        }

        generated.push(next_token);

        // Stop sequences are matched against everything generated so far.
        if !params.stop.is_empty() {
            let text_so_far = mlx.decode(&generated)?;
            if params.stop.iter().any(|s| text_so_far.contains(s)) {
                break;
            }
        }

        // Position of the token just appended within the full sequence.
        let pos = prompt_tokens.len() + generated.len() - 1;
        let step = Self::catch_mlx("forward", || mlx.forward(&[next_token], pos));
        let step_logits = match step {
            Ok(logits) => logits,
            Err(failure) => {
                drop(guard);
                self.mlx_backends.invalidate(model_id);
                return Err(failure);
            }
        };
        next_token = Self::sample_from_logits(&step_logits, params)?;
    }

    let raw_text = mlx.decode(&generated)?;
    let text = tasks::generate::strip_thinking(&raw_text, params.thinking);
    Ok((text, ttft_ms))
}
2181
/// Picks the next token id from raw `logits`.
///
/// - `temperature <= 0`: greedy argmax.
/// - otherwise: temperature-scaled softmax, optionally restricted to the
///   smallest set of most-probable tokens whose cumulative probability
///   exceeds `top_p` (applied only when `top_p < 1`), then one random draw
///   from the resulting distribution.
///
/// Tokens whose probability is zero (including those masked out by top-p) are
/// never returned by the sampling path unless no token has positive
/// probability at all, in which case the last index is returned.
///
/// # Errors
/// `InferenceFailed("empty logits")` when `logits` is empty, on both the
/// greedy and the sampling path.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn sample_from_logits(logits: &[f32], params: &GenerateParams) -> Result<u32, InferenceError> {
    // Checked up front so the sampling path cannot underflow `len() - 1`.
    if logits.is_empty() {
        return Err(InferenceError::InferenceFailed("empty logits".into()));
    }

    if params.temperature <= 0.0 {
        let (idx, _) = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .ok_or_else(|| InferenceError::InferenceFailed("empty logits".into()))?;
        return Ok(idx as u32);
    }

    // Softmax with the max subtracted first for numerical stability.
    let temp = params.temperature as f32;
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits
        .iter()
        .map(|&l| ((l - max_logit) / temp).exp())
        .collect();
    let sum: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= sum;
    }

    if params.top_p < 1.0 {
        // Rank tokens by descending probability and find the nucleus cutoff.
        let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let top_p = params.top_p as f32;
        let mut cumsum = 0.0_f32;
        let mut cutoff_idx = indexed.len();
        for (rank, &(_, p)) in indexed.iter().enumerate() {
            cumsum += p;
            if cumsum > top_p {
                cutoff_idx = rank + 1;
                break;
            }
        }
        // Everything ranked after the cutoff is outside the nucleus.
        for &(token, _) in &indexed[cutoff_idx..] {
            probs[token] = 0.0;
        }
        let sum: f32 = probs.iter().sum();
        if sum > 0.0 {
            for p in &mut probs {
                *p /= sum;
            }
        }
    }

    use rand::Rng;
    let mut rng = rand::rng();
    let r: f32 = rng.random();
    let mut cumsum = 0.0_f32;
    // Remember the last token that can legitimately be drawn, so float
    // rounding in the cumulative sum never selects a masked-out token.
    let mut last_positive: Option<usize> = None;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if p > 0.0 {
            if cumsum >= r {
                return Ok(i as u32);
            }
            last_positive = Some(i);
        }
    }
    Ok(last_positive.unwrap_or(probs.len() - 1) as u32)
}
2249
2250 pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
2253 let instruction = req
2254 .instruction
2255 .as_deref()
2256 .unwrap_or("Retrieve relevant memory facts");
2257
2258 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2259 {
2260 let model_id = self.ensure_mlx_embedding_backend().await?;
2261 let schema = self
2262 .unified_registry
2263 .get(&model_id)
2264 .cloned()
2265 .ok_or_else(|| {
2266 InferenceError::InferenceFailed(format!("embed: unknown schema id {model_id}"))
2267 })?;
2268 let handle = self.ensure_mlx_backend(&schema).await?;
2269 let mut guard = handle.lock().map_err(|_| {
2270 InferenceError::InferenceFailed(format!(
2271 "MLX embedding backend mutex poisoned for {model_id}"
2272 ))
2273 })?;
2274 let backend: &mut backend::MlxBackend = &mut *guard;
2275
2276 let mut results = Vec::with_capacity(req.texts.len());
2277 for text in &req.texts {
2278 let embedding = if req.is_query {
2279 backend.embed_query(text, instruction)?
2280 } else {
2281 backend.embed_one(text)?
2282 };
2283 results.push(embedding);
2284 }
2285 return Ok(results);
2286 }
2287
2288 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2289 {
2290 self.ensure_embedding_backend().await?;
2291 let mut write = self.embedding_backend.write().await;
2292 let backend = write.as_mut().unwrap();
2293
2294 let mut results = Vec::with_capacity(req.texts.len());
2295 for text in &req.texts {
2296 let embedding = if req.is_query {
2297 backend.embed_query(text, instruction)?
2298 } else {
2299 backend.embed_one(text)?
2300 };
2301 results.push(embedding);
2302 }
2303 Ok(results)
2304 }
2305 }
2306
2307 pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResult, InferenceError> {
2337 if req.documents.is_empty() {
2338 return Ok(RerankResult {
2339 ranked: Vec::new(),
2340 model_used: None,
2341 });
2342 }
2343
2344 let model_name = match req.model.clone() {
2345 Some(m) => m,
2346 None => self
2347 .preferred_model_for_capability(ModelCapability::Rerank)
2348 .map(str::to_string)
2349 .ok_or_else(|| {
2350 InferenceError::InferenceFailed(
2351 "no reranker model available — pull a Qwen3-Reranker model first".into(),
2352 )
2353 })?,
2354 };
2355
2356 let schema = self
2357 .unified_registry
2358 .find_by_name(&model_name)
2359 .or_else(|| self.unified_registry.get(&model_name))
2360 .cloned()
2361 .ok_or_else(|| {
2362 InferenceError::InferenceFailed(format!(
2363 "rerank: unknown reranker model {model_name}"
2364 ))
2365 })?;
2366 if !schema.has_capability(ModelCapability::Rerank) {
2367 return Err(InferenceError::InferenceFailed(format!(
2368 "model {} does not declare the Rerank capability",
2369 schema.name
2370 )));
2371 }
2372
2373 let instruction = req.instruction.as_deref().unwrap_or(
2374 "Given a web search query, retrieve relevant passages that answer the query",
2375 );
2376
2377 let mut scored: Vec<RerankedDocument> = Vec::with_capacity(req.documents.len());
2378 for (idx, doc) in req.documents.iter().enumerate() {
2379 let prompt = rerank_prompt(instruction, &req.query, doc);
2380 let gen_req = GenerateRequest {
2381 prompt,
2382 model: Some(schema.id.clone()),
2383 params: tasks::generate::GenerateParams {
2384 temperature: 0.0,
2385 max_tokens: 3,
2389 thinking: tasks::generate::ThinkingMode::Off,
2390 ..Default::default()
2391 },
2392 context: None,
2393 tools: None,
2394 images: None,
2395 messages: None,
2396 cache_control: false,
2397 response_format: None,
2398 intent: None,
2399 };
2400 let out = self.generate(gen_req).await?;
2401 let score = score_from_rerank_output(&out, &schema.name);
2402 scored.push(RerankedDocument {
2403 index: idx,
2404 score,
2405 document: doc.clone(),
2406 });
2407 }
2408
2409 scored.sort_by(|a, b| {
2412 b.score
2413 .partial_cmp(&a.score)
2414 .unwrap_or(std::cmp::Ordering::Equal)
2415 .then_with(|| a.index.cmp(&b.index))
2416 });
2417 if let Some(n) = req.top_n {
2418 scored.truncate(n);
2419 }
2420
2421 Ok(RerankResult {
2422 ranked: scored,
2423 model_used: Some(schema.name),
2424 })
2425 }
2426
2427 pub async fn ground(&self, req: GroundRequest) -> Result<GroundResult, InferenceError> {
2437 let model_name = match req.model.clone() {
2438 Some(m) => m,
2439 None => self
2440 .preferred_model_for_capability(ModelCapability::Grounding)
2441 .map(str::to_string)
2442 .ok_or_else(|| {
2443 InferenceError::InferenceFailed(
2444 "no grounding-capable model available — pull a Qwen2.5-VL model first"
2445 .into(),
2446 )
2447 })?,
2448 };
2449
2450 let gen_req = GenerateRequest {
2451 prompt: req.prompt.clone(),
2452 model: Some(model_name),
2453 params: GenerateParams::default(),
2454 context: None,
2455 tools: None,
2456 images: Some(vec![req.image.clone()]),
2457 messages: None,
2458 cache_control: false,
2459 response_format: None,
2460 intent: None,
2461 };
2462 let result = self.generate_tracked(gen_req).await?;
2463 Ok(GroundResult {
2464 boxes: result.bounding_boxes,
2465 raw_text: result.text,
2466 model_used: Some(result.model_used),
2467 })
2468 }
2469
2470 pub async fn classify(
2473 &self,
2474 req: ClassifyRequest,
2475 ) -> Result<Vec<ClassifyResult>, InferenceError> {
2476 let model = match req.model.clone().or_else(|| {
2477 self.preferred_model_for_capability(ModelCapability::Classify)
2478 .map(str::to_string)
2479 }) {
2480 Some(m) => m,
2481 None => {
2482 let m = self.router.route_small(&self.registry);
2483 debug!(model = %m, "auto-routed classify request");
2484 m
2485 }
2486 };
2487
2488 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2491 {
2492 return self.classify_via_generate(req, &model).await;
2493 }
2494
2495 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2496 {
2497 self.ensure_backend(&model).await?;
2498 let mut write = self.backend.write().await;
2499 let backend = write.as_mut().unwrap();
2500 tasks::classify::classify(backend, req).await
2501 }
2502 }
2503
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
/// Classifies `req.text` by prompting a generative model and scoring each label
/// against the model's lower-cased, trimmed reply.
///
/// Scoring per label (before normalisation):
/// * `1.0` when the reply equals the label,
/// * `0.8` when the reply contains the label,
/// * otherwise `0.5 *` the fraction of the label's whitespace-separated words that
///   occur in the reply (`0.0` for a label with no words).
///
/// Results are sorted by descending score and, when the total is positive, divided
/// by the total so they sum to one.
///
/// # Errors
///
/// Propagates any error from `generate`.
async fn classify_via_generate(
    &self,
    req: ClassifyRequest,
    model: &str,
) -> Result<Vec<ClassifyResult>, InferenceError> {
    // Numbered list of labels, one per line, for the prompt.
    let mut numbered = Vec::with_capacity(req.labels.len());
    for (position, label) in req.labels.iter().enumerate() {
        numbered.push(format!("{}. {}", position + 1, label));
    }
    let labels_str = numbered.join("\n");

    let prompt = format!(
        "Classify the following text into one of these categories:\n\
         {labels_str}\n\n\
         Text: {}\n\n\
         Respond with ONLY the category name, nothing else.",
        req.text
    );

    let params = tasks::generate::GenerateParams {
        temperature: 0.0,
        max_tokens: 32,
        thinking: tasks::generate::ThinkingMode::Off,
        ..Default::default()
    };
    let gen_req = GenerateRequest {
        prompt,
        model: Some(model.to_string()),
        params,
        context: None,
        tools: None,
        images: None,
        messages: None,
        cache_control: false,
        response_format: None,
        intent: None,
    };

    let reply = self.generate(gen_req).await?;
    let answer = reply.trim().to_lowercase();

    // Heuristic match strength between the model's answer and one label.
    let score_for = |label: &str| -> f64 {
        let wanted = label.to_lowercase();
        if answer == wanted {
            return 1.0;
        }
        if answer.contains(&wanted) {
            return 0.8;
        }
        let words: Vec<&str> = wanted.split_whitespace().collect();
        if words.is_empty() {
            return 0.0;
        }
        let hits = words.iter().filter(|word| answer.contains(**word)).count();
        0.5 * (hits as f64 / words.len() as f64)
    };

    let mut scored: Vec<ClassifyResult> = Vec::with_capacity(req.labels.len());
    for label in &req.labels {
        scored.push(ClassifyResult {
            label: label.clone(),
            score: score_for(label),
        });
    }

    scored.sort_by(|left, right| {
        right
            .score
            .partial_cmp(&left.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let total: f64 = scored.iter().map(|entry| entry.score).sum();
    if total > 0.0 {
        scored.iter_mut().for_each(|entry| entry.score /= total);
    }

    Ok(scored)
}
2593
2594 pub async fn transcribe(
2596 &self,
2597 req: TranscribeRequest,
2598 ) -> Result<TranscribeResult, InferenceError> {
2599 let candidates =
2600 self.speech_candidates(ModelCapability::SpeechToText, req.model.as_deref())?;
2601 let mut last_error = None;
2602
2603 for schema in candidates {
2604 let result = match &schema.source {
2605 ModelSource::Mlx { .. } => self.transcribe_local_mlx(&schema, &req).await,
2606 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2607 self.transcribe_elevenlabs(&schema, &req).await
2608 }
2609 _ => Err(InferenceError::InferenceFailed(format!(
2610 "speech-to-text not implemented for model source: {}",
2611 schema.id
2612 ))),
2613 };
2614
2615 match result {
2616 Ok(result) => return Ok(result),
2617 Err(err) => last_error = Some(err),
2618 }
2619 }
2620
2621 Err(last_error.unwrap_or_else(|| {
2622 InferenceError::InferenceFailed("no speech-to-text models available".into())
2623 }))
2624 }
2625
2626 pub async fn synthesize(
2628 &self,
2629 req: SynthesizeRequest,
2630 ) -> Result<SynthesizeResult, InferenceError> {
2631 let candidates =
2632 self.speech_candidates(ModelCapability::TextToSpeech, req.model.as_deref())?;
2633 let mut last_error = None;
2634
2635 for schema in candidates {
2636 let result = match &schema.source {
2637 ModelSource::Mlx { .. } => self.synthesize_local_mlx(&schema, &req).await,
2638 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2639 self.synthesize_elevenlabs(&schema, &req).await
2640 }
2641 _ => Err(InferenceError::InferenceFailed(format!(
2642 "text-to-speech not implemented for model source: {}",
2643 schema.id
2644 ))),
2645 };
2646
2647 match result {
2648 Ok(result) => return Ok(result),
2649 Err(err) => last_error = Some(err),
2650 }
2651 }
2652
2653 Err(last_error.unwrap_or_else(|| {
2654 InferenceError::InferenceFailed("no text-to-speech models available".into())
2655 }))
2656 }
2657
2658 pub async fn generate_image(
2660 &self,
2661 req: GenerateImageRequest,
2662 ) -> Result<GenerateImageResult, InferenceError> {
2663 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2670 {
2671 use crate::backend::external_flux;
2672 let backend =
2673 std::env::var("CAR_IMAGE_BACKEND").unwrap_or_else(|_| "native".to_string());
2674 let use_external = match backend.as_str() {
2675 "external" => true,
2676 "native" => false,
2677 _ => external_flux::is_available() && backend == "auto-external",
2679 };
2680 if use_external {
2681 tracing::info!(
2682 "routing image generation to external mflux \
2683 (set CAR_IMAGE_BACKEND=native to use the Rust port)"
2684 );
2685 let mut req = req;
2686 req.model = self.resolve_external_hf_repo(
2687 req.model.as_deref(),
2688 ModelCapability::ImageGeneration,
2689 );
2690 return external_flux::generate_image(&req);
2691 }
2692 tracing::info!("using native Rust MLX Flux backend");
2693 }
2694
2695 let candidates = self
2696 .media_generation_candidates(ModelCapability::ImageGeneration, req.model.as_deref())?;
2697 let mut last_error = None;
2698
2699 for schema in candidates {
2700 let result = match &schema.source {
2701 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2702 ModelSource::Mlx { .. } => self.generate_image_native_mlx(&schema, &req).await,
2703 _ => Err(InferenceError::InferenceFailed(format!(
2704 "image generation not implemented for model source: {}",
2705 schema.id
2706 ))),
2707 };
2708
2709 match result {
2710 Ok(result) => return Ok(result),
2711 Err(err) => last_error = Some(err),
2712 }
2713 }
2714
2715 Err(last_error.unwrap_or_else(|| {
2716 InferenceError::InferenceFailed("no image generation models available".into())
2717 }))
2718 }
2719
2720 pub async fn generate_image_batch(
2736 &self,
2737 req: GenerateImageRequest,
2738 ) -> Result<Vec<GenerateImageResult>, InferenceError> {
2739 let count = req.variant_count.unwrap_or(1).max(1);
2740 if count == 1 {
2741 return self.generate_image(req).await.map(|r| vec![r]);
2742 }
2743 let base_seed = req.seed.unwrap_or(0);
2744 let mut results = Vec::with_capacity(count as usize);
2745 for i in 0..count {
2746 let mut variant_req = req.clone();
2751 variant_req.seed = Some(base_seed.wrapping_add(i as u64));
2752 variant_req.variant_count = Some(1);
2756 results.push(self.generate_image(variant_req).await?);
2757 }
2758 Ok(results)
2759 }
2760
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
/// Generates an image with the in-process `FluxBackend` for `schema`.
///
/// Ensures the model is present locally, obtains the backend from `flux_cache`
/// (loading it from the model directory on a miss), then runs `generate` on a
/// blocking task while holding the backend's mutex.
///
/// # Errors
///
/// Propagates errors from `ensure_local`, the cache load and the backend's
/// `generate`. A poisoned mutex or a failed task join is reported as
/// `InferenceError::InferenceFailed`.
async fn generate_image_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateImageRequest,
) -> Result<GenerateImageResult, InferenceError> {
    let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let estimated_size = backend_cache::estimate_model_size(&model_dir);
    let cache = Arc::clone(&self.flux_cache);
    let cache_key = schema.id.clone();
    let handle = cache.get_or_load(&cache_key, estimated_size, || {
        backend::mlx_flux::FluxBackend::load(&model_dir)
    })?;

    // The blocking task needs owned data, so it gets its own copy of the request.
    let owned_req = req.clone();
    let joined =
        tokio::task::spawn_blocking(move || -> Result<GenerateImageResult, InferenceError> {
            let mut backend = match handle.lock() {
                Ok(locked) => locked,
                Err(_) => {
                    return Err(InferenceError::InferenceFailed(
                        "flux backend mutex poisoned".into(),
                    ));
                }
            };
            backend.generate(&owned_req)
        })
        .await;

    match joined {
        Ok(outcome) => outcome,
        Err(e) => Err(InferenceError::InferenceFailed(format!(
            "flux task join: {e}"
        ))),
    }
}
2788
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
/// Generates a video with the in-process `LtxBackend` for `schema`.
///
/// Ensures the model is present locally, obtains the backend from `ltx_cache`
/// (loading it from the model directory on a miss), then runs `generate` on a
/// blocking task while holding the backend's mutex.
///
/// # Errors
///
/// Propagates errors from `ensure_local`, the cache load and the backend's
/// `generate`. A poisoned mutex or a failed task join is reported as
/// `InferenceError::InferenceFailed`.
async fn generate_video_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateVideoRequest,
) -> Result<GenerateVideoResult, InferenceError> {
    let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let estimated_size = backend_cache::estimate_model_size(&model_dir);
    let cache = Arc::clone(&self.ltx_cache);
    let cache_key = schema.id.clone();
    let handle = cache.get_or_load(&cache_key, estimated_size, || {
        backend::mlx_ltx::LtxBackend::load(&model_dir)
    })?;

    // The blocking task needs owned data, so it gets its own copy of the request.
    let owned_req = req.clone();
    let joined =
        tokio::task::spawn_blocking(move || -> Result<GenerateVideoResult, InferenceError> {
            let mut backend = match handle.lock() {
                Ok(locked) => locked,
                Err(_) => {
                    return Err(InferenceError::InferenceFailed(
                        "ltx backend mutex poisoned".into(),
                    ));
                }
            };
            backend.generate(&owned_req)
        })
        .await;

    match joined {
        Ok(outcome) => outcome,
        Err(e) => Err(InferenceError::InferenceFailed(format!(
            "ltx task join: {e}"
        ))),
    }
}
2813
2814 pub async fn generate_video(
2816 &self,
2817 req: GenerateVideoRequest,
2818 ) -> Result<GenerateVideoResult, InferenceError> {
2819 if let Err(msg) = req.validate() {
2822 return Err(InferenceError::InferenceFailed(format!(
2823 "invalid GenerateVideoRequest: {}",
2824 msg
2825 )));
2826 }
2827 let requires_audio_conditioning = req.requires_audio_passthrough_opt_in();
2828 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2833 {
2834 use crate::backend::external_ltx;
2835 let backend =
2840 std::env::var("CAR_VIDEO_BACKEND").unwrap_or_else(|_| "native".to_string());
2841 let use_external = match backend.as_str() {
2842 "external" => true,
2843 "native" => false,
2844 "auto-external" => external_ltx::is_available(),
2848 _ => false,
2849 };
2850 if use_external {
2851 tracing::info!(
2852 "CAR_VIDEO_BACKEND requested external LTX routing for LTX-family models"
2853 );
2854 } else {
2855 tracing::info!("using family-aware MLX video routing");
2856 }
2857 }
2858
2859 let candidates = self
2860 .media_generation_candidates(ModelCapability::VideoGeneration, req.model.as_deref())?;
2861 let mut last_error = None;
2862
2863 for schema in candidates {
2864 let result = match &schema.source {
2865 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2866 ModelSource::Mlx { hf_repo, .. } => {
2867 if crate::backend::external_mlx_video::is_wan_family(&schema) {
2868 match self.unified_registry.ensure_local(&schema.id).await {
2869 Ok(model_dir) => {
2870 crate::backend::external_mlx_video::generate_wan_video(
2871 &schema, &model_dir, &req,
2872 )
2873 }
2874 Err(err) => Err(err),
2875 }
2876 } else {
2877 let backend = std::env::var("CAR_VIDEO_BACKEND")
2878 .unwrap_or_else(|_| "native".to_string());
2879 let use_external_ltx = match backend.as_str() {
2880 "external" => true,
2881 "native" => false,
2882 "auto-external" => crate::backend::external_ltx::is_available(),
2883 _ => false,
2884 };
2885 let use_external_ltx = use_external_ltx || requires_audio_conditioning;
2886 if requires_audio_conditioning
2887 && !crate::backend::external_ltx::is_available()
2888 {
2889 return Err(InferenceError::InferenceFailed(
2890 "audio-reference video conditioning requires the external `ltx-2-mlx a2v` CLI on PATH"
2891 .to_string(),
2892 ));
2893 }
2894 if use_external_ltx {
2895 let mut req = req.clone();
2896 req.model = Some(hf_repo.clone());
2897 crate::backend::external_ltx::generate_video(&req)
2898 } else {
2899 self.generate_video_native_mlx(&schema, &req).await
2900 }
2901 }
2902 }
2903 _ => Err(InferenceError::InferenceFailed(format!(
2904 "video generation not implemented for model source: {}",
2905 schema.id
2906 ))),
2907 };
2908
2909 match result {
2910 Ok(result) => return Ok(result),
2911 Err(err) => last_error = Some(err),
2912 }
2913 }
2914
2915 Err(last_error.unwrap_or_else(|| {
2916 InferenceError::InferenceFailed("no video generation models available".into())
2917 }))
2918 }
2919
2920 pub fn list_models_unified(&self) -> Vec<ModelInfo> {
2922 self.unified_registry
2923 .list()
2924 .iter()
2925 .map(|m| ModelInfo::from(*m))
2926 .collect()
2927 }
2928
2929 pub fn available_model_upgrades(&self) -> Vec<ModelUpgrade> {
2931 self.unified_registry.available_upgrades()
2932 }
2933
2934 pub async fn check_upgrade_nudge(
2940 &self,
2941 inference_active: bool,
2942 ) -> (crate::nudge::NudgeDecision, crate::nudge::NudgeState) {
2943 let findings = self.detect_upgrades().await;
2944 let prefs = self.update_prefs();
2945 let state = crate::nudge::NudgeState::load_from(&crate::nudge::NudgeState::default_path());
2946 let now = std::time::SystemTime::now()
2947 .duration_since(std::time::UNIX_EPOCH)
2948 .map(|d| d.as_secs())
2949 .unwrap_or(0);
2950 let decision = crate::nudge::decide_nudge(
2951 &findings,
2952 &prefs,
2953 &state,
2954 now,
2955 crate::nudge::DEFAULT_THROTTLE_SECS,
2956 inference_active,
2957 );
2958 (decision, state)
2959 }
2960
2961 pub fn dismiss_upgrade_nudge(&self, dismiss_key: &str) -> Result<(), InferenceError> {
2964 let path = crate::nudge::NudgeState::default_path();
2965 let mut state = crate::nudge::NudgeState::load_from(&path);
2966 state.dismiss(dismiss_key);
2967 state.save_to(&path).map_err(InferenceError::InferenceFailed)
2968 }
2969
2970 pub async fn detect_upgrades(&self) -> Vec<crate::upgrade::UpgradeFinding> {
2974 let prefs = self.update_prefs();
2975 let curated = self.unified_registry.available_upgrades();
2976 let schemas = self.list_schemas();
2977 let refs: Vec<&ModelSchema> = schemas.iter().collect();
2978 let probe = crate::upgrade::HuggingFaceProbe::new();
2979 let now = std::time::SystemTime::now()
2980 .duration_since(std::time::UNIX_EPOCH)
2981 .map(|d| d.as_secs())
2982 .unwrap_or(0);
2983 crate::upgrade::detect_upgrades(
2984 curated,
2985 &refs,
2986 &prefs,
2987 &probe,
2988 &crate::upgrade::UpgradeCache::default_path(),
2989 now,
2990 crate::upgrade::DEFAULT_TTL_SECS,
2991 )
2992 .await
2993 }
2994
2995 pub fn list_schemas(&self) -> Vec<ModelSchema> {
2998 self.unified_registry.list().into_iter().cloned().collect()
2999 }
3000
3001 pub fn list_models(&self) -> Vec<models::ModelInfo> {
3002 self.registry.list_models()
3003 }
3004
3005 pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
3007 self.pull_model_with_progress(name, &crate::download::ProgressSink::none())
3008 .await
3009 }
3010
3011 pub async fn pull_model_with_progress(
3015 &self,
3016 name: &str,
3017 sink: &crate::download::ProgressSink,
3018 ) -> Result<std::path::PathBuf, InferenceError> {
3019 let schema = self
3020 .unified_registry
3021 .find_by_name(name)
3022 .or_else(|| self.unified_registry.get(name))
3023 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
3024 self.unified_registry
3025 .ensure_local_with_progress(&schema.id, sink)
3026 .await
3027 }
3028
3029 pub fn update_prefs(&self) -> crate::update_prefs::UpdatePreferences {
3034 let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
3035 crate::update_prefs::UpdatePreferences::load_effective(&cwd).unwrap_or_default()
3036 }
3037
3038 pub fn set_update_prefs(
3040 &self,
3041 prefs: &crate::update_prefs::UpdatePreferences,
3042 ) -> Result<(), InferenceError> {
3043 prefs
3044 .save()
3045 .map_err(InferenceError::InferenceFailed)
3046 }
3047
3048 pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
3050 let schema = self
3051 .unified_registry
3052 .get(name)
3053 .or_else(|| {
3054 self.unified_registry
3055 .list()
3056 .into_iter()
3057 .find(|schema| schema.name.eq_ignore_ascii_case(name))
3058 })
3059 .or_else(|| self.unified_registry.find_by_name(name))
3060 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
3061 let model_dir = self.unified_registry.models_dir().join(&schema.name);
3062 if model_dir.exists() {
3063 std::fs::remove_dir_all(&model_dir)?;
3064 }
3065 match &schema.source {
3066 ModelSource::Mlx { hf_repo, .. } => {
3067 remove_huggingface_repo_cache(hf_repo)?;
3068 }
3069 ModelSource::Local {
3070 hf_repo,
3071 tokenizer_repo,
3072 ..
3073 } => {
3074 remove_huggingface_repo_cache(hf_repo)?;
3075 remove_huggingface_repo_cache(tokenizer_repo)?;
3076 }
3077 _ => {}
3078 }
3079 Ok(())
3080 }
3081
3082 pub fn register_model(&mut self, schema: ModelSchema) {
3084 self.unified_registry.register(schema);
3085 }
3086
3087 pub async fn discover_vllm_mlx_models(&mut self) -> usize {
3090 let config = vllm_mlx::VllmMlxConfig::default();
3091 if !config.auto_discover {
3092 return 0;
3093 }
3094 vllm_mlx::discover_and_register(&config, &mut self.unified_registry).await
3095 }
3096
3097 pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
3099 self.outcome_tracker.clone()
3100 }
3101
3102 async fn auto_save_outcomes(&self) {
3104 if let Err(e) = self.save_outcomes().await {
3105 tracing::debug!("auto-save outcomes failed: {}", e);
3106 }
3107 if let Err(e) = self.save_key_pool_stats().await {
3108 tracing::debug!("auto-save key pool stats failed: {}", e);
3109 }
3110 }
3111
3112 pub async fn save_outcomes(&self) -> Result<(), std::io::Error> {
3114 let tracker = self.outcome_tracker.read().await;
3115 let path = self.config.models_dir.join("outcome_profiles.json");
3116 tracker.save_to_file(&path)
3117 }
3118
3119 pub async fn save_key_pool_stats(&self) -> Result<(), std::io::Error> {
3121 let path = self.config.models_dir.join("key_pool_stats.json");
3122 self.remote_backend.key_pool.save_stats(&path).await
3123 }
3124
3125 pub async fn key_pool_stats(
3127 &self,
3128 ) -> std::collections::HashMap<String, Vec<key_pool::KeyStats>> {
3129 self.remote_backend.key_pool.all_stats().await
3130 }
3131
3132 pub async fn export_profiles(&self) -> Vec<ModelProfile> {
3134 let tracker = self.outcome_tracker.read().await;
3135 tracker.export_profiles()
3136 }
3137
3138 pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
3140 let mut tracker = self.outcome_tracker.write().await;
3141 tracker.import_profiles(profiles);
3142 }
3143
3144 pub async fn prepare_speech_runtime(&self) -> Result<PathBuf, InferenceError> {
3147 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3148 {
3149 Ok(self.config.models_dir.clone())
3151 }
3152 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3153 {
3154 Ok(self.ensure_speech_runtime().await?.root)
3155 }
3156 }
3157
3158 pub fn set_speech_policy(&mut self, policy: SpeechPolicy) {
3160 self.speech_policy = policy;
3161 }
3162
3163 pub fn set_routing_config(&mut self, config: RoutingConfig) {
3164 self.adaptive_router.set_config(config);
3165 }
3166
3167 pub async fn install_curated_speech(
3169 &mut self,
3170 ) -> Result<Vec<SpeechInstallReport>, InferenceError> {
3171 let _runtime_root = self.prepare_speech_runtime().await?;
3172 let schemas = self.list_schemas();
3173 let mut repos = Vec::new();
3174 for schema in &schemas {
3175 if !schema.is_mlx() || !schema.tags.iter().any(|tag| tag == "speech") {
3176 continue;
3177 }
3178 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3179 if !repos.iter().any(|existing: &String| existing == hf_repo) {
3180 repos.push(hf_repo.clone());
3181 }
3182 }
3183 }
3184
3185 let mut installed = Vec::new();
3186 for repo in repos {
3187 let (snapshot_path, files_downloaded) = download_hf_repo_snapshot(&repo).await?;
3188 let name = schemas
3189 .iter()
3190 .find(|schema| {
3191 matches!(&schema.source, ModelSource::Mlx { hf_repo, .. } if hf_repo == &repo)
3192 })
3193 .map(|schema| schema.name.clone())
3194 .unwrap_or_else(|| repo.clone());
3195 installed.push(SpeechInstallReport {
3196 name,
3197 hf_repo: repo,
3198 snapshot_path,
3199 files_downloaded,
3200 });
3201 }
3202
3203 self.unified_registry.refresh_availability();
3204 Ok(installed)
3205 }
3206
3207 pub fn speech_health(&self) -> SpeechHealthReport {
3209 let local_stt_default =
3210 self.speech_health_default_name(ModelCapability::SpeechToText, true, false);
3211 let local_tts_default =
3212 self.speech_health_default_name(ModelCapability::TextToSpeech, true, false);
3213 let remote_stt_default =
3214 self.speech_health_default_name(ModelCapability::SpeechToText, false, true);
3215 let remote_tts_default =
3216 self.speech_health_default_name(ModelCapability::TextToSpeech, false, true);
3217
3218 let mut local_models = Vec::new();
3219 let mut remote_models = Vec::new();
3220 for schema in self.list_schemas() {
3221 let capability = if schema.has_capability(ModelCapability::SpeechToText) {
3222 Some(ModelCapability::SpeechToText)
3223 } else if schema.has_capability(ModelCapability::TextToSpeech) {
3224 Some(ModelCapability::TextToSpeech)
3225 } else {
3226 None
3227 };
3228 let Some(capability) = capability else {
3229 continue;
3230 };
3231
3232 let selected_by_default = local_stt_default
3233 .as_ref()
3234 .is_some_and(|name| name == &schema.name)
3235 || local_tts_default
3236 .as_ref()
3237 .is_some_and(|name| name == &schema.name)
3238 || remote_stt_default
3239 .as_ref()
3240 .is_some_and(|name| name == &schema.name)
3241 || remote_tts_default
3242 .as_ref()
3243 .is_some_and(|name| name == &schema.name);
3244
3245 let health = SpeechModelHealth {
3246 id: schema.id.clone(),
3247 name: schema.name.clone(),
3248 provider: schema.provider.clone(),
3249 capability,
3250 is_local: schema.is_local(),
3251 available: schema.available,
3252 cached: speech_model_cached(&schema),
3253 selected_by_default,
3254 source: speech_model_source_label(&schema),
3255 };
3256 if schema.is_local() {
3257 local_models.push(health);
3258 } else {
3259 remote_models.push(health);
3260 }
3261 }
3262
3263 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3265 let runtime = SpeechRuntimeHealth {
3266 root: self.config.models_dir.clone(),
3267 installed: true,
3268 python: PathBuf::new(),
3269 stt_command: PathBuf::new(),
3270 tts_command: PathBuf::new(),
3271 configured_python: None,
3272 detected_python: None,
3273 };
3274
3275 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3276 let runtime = {
3277 let rt =
3278 SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3279 SpeechRuntimeHealth {
3280 root: rt.root.clone(),
3281 installed: rt.is_ready(),
3282 python: rt.python.clone(),
3283 stt_command: rt.stt_program.clone(),
3284 tts_command: rt.tts_program.clone(),
3285 configured_python: std::env::var("CAR_SPEECH_PYTHON")
3286 .ok()
3287 .filter(|value| !value.trim().is_empty()),
3288 detected_python: detect_speech_python(),
3289 }
3290 };
3291
3292 SpeechHealthReport {
3293 runtime,
3294 local_models,
3295 remote_models,
3296 elevenlabs_configured: car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY")
3297 .is_some(),
3298 prefer_local: self.speech_policy.prefer_local,
3299 allow_remote_fallback: self.speech_policy.allow_remote_fallback,
3300 preferred_local_stt: self.speech_policy.preferred_local_stt.clone(),
3301 preferred_local_tts: self.speech_policy.preferred_local_tts.clone(),
3302 preferred_remote_stt: self.speech_policy.preferred_remote_stt.clone(),
3303 preferred_remote_tts: self.speech_policy.preferred_remote_tts.clone(),
3304 local_stt_default,
3305 local_tts_default,
3306 remote_stt_default,
3307 remote_tts_default,
3308 }
3309 }
3310
3311 pub async fn model_health(&self) -> ModelHealthReport {
3314 let schemas = self.list_schemas();
3315 let total_models = schemas.len();
3316 let available_models = schemas.iter().filter(|schema| schema.available).count();
3317 let local_models = schemas.iter().filter(|schema| schema.is_local()).count();
3318 let remote_models = total_models.saturating_sub(local_models);
3319
3320 let defaults = vec![
3321 self.model_default_health(
3322 ModelCapability::Generate,
3323 self.preferred_model_for_capability(ModelCapability::Generate)
3324 .unwrap_or(&self.config.generation_model),
3325 ),
3326 self.model_default_health(
3327 ModelCapability::Embed,
3328 self.preferred_model_for_capability(ModelCapability::Embed)
3329 .unwrap_or(&self.config.embedding_model),
3330 ),
3331 self.model_default_health(
3332 ModelCapability::Classify,
3333 self.preferred_model_for_capability(ModelCapability::Classify)
3334 .unwrap_or(&self.config.classification_model),
3335 ),
3336 ];
3337
3338 let mut providers = std::collections::BTreeMap::new();
3339 for schema in &schemas {
3340 let entry =
3341 providers
3342 .entry(schema.provider.clone())
3343 .or_insert_with(|| ProviderAccumulator {
3344 configured: false,
3345 local_models: 0,
3346 remote_models: 0,
3347 available_models: 0,
3348 capabilities: std::collections::HashSet::new(),
3349 });
3350
3351 entry.configured |= model_source_configured(schema);
3352 if schema.is_local() {
3353 entry.local_models += 1;
3354 } else {
3355 entry.remote_models += 1;
3356 }
3357 if schema.available {
3358 entry.available_models += 1;
3359 }
3360 for capability in &schema.capabilities {
3361 entry.capabilities.insert(*capability);
3362 }
3363 }
3364
3365 let providers = providers
3366 .into_iter()
3367 .map(|(provider, acc)| ModelProviderHealth {
3368 provider,
3369 configured: acc.configured,
3370 local_models: acc.local_models,
3371 remote_models: acc.remote_models,
3372 available_models: acc.available_models,
3373 capabilities: sort_capabilities(acc.capabilities.into_iter().collect()),
3374 })
3375 .collect();
3376
3377 let capabilities = all_model_capabilities()
3378 .into_iter()
3379 .map(|capability| {
3380 let relevant: Vec<&ModelSchema> = schemas
3381 .iter()
3382 .filter(|schema| schema.has_capability(capability))
3383 .collect();
3384 let available: Vec<&ModelSchema> = relevant
3385 .iter()
3386 .copied()
3387 .filter(|schema| schema.available)
3388 .collect();
3389 ModelCapabilityHealth {
3390 capability,
3391 total_models: relevant.len(),
3392 available_models: available.len(),
3393 local_available_models: available
3394 .iter()
3395 .filter(|schema| schema.is_local())
3396 .count(),
3397 remote_available_models: available
3398 .iter()
3399 .filter(|schema| !schema.is_local())
3400 .count(),
3401 }
3402 })
3403 .collect();
3404
3405 let routing = self.routing_scenarios().await;
3406 let routing_config = self.adaptive_router.config().clone();
3407 let benchmark_priors = load_benchmark_prior_health(&self.config.models_dir, &schemas);
3408
3409 ModelHealthReport {
3410 total_models,
3411 available_models,
3412 local_models,
3413 remote_models,
3414 defaults,
3415 providers,
3416 capabilities,
3417 routing_prefer_local: routing_config.prefer_local,
3418 routing_quality_first_cold_start: routing_config.quality_first_cold_start,
3419 routing_min_observations: routing_config.min_observations,
3420 routing_bootstrap_min_task_observations: routing_config.bootstrap_min_task_observations,
3421 routing_bootstrap_quality_floor: routing_config.bootstrap_quality_floor,
3422 routing_quality_weight: routing_config.quality_weight,
3423 routing_latency_weight: routing_config.latency_weight,
3424 routing_cost_weight: routing_config.cost_weight,
3425 routing_scenarios: routing,
3426 benchmark_priors,
3427 speech: self.speech_health(),
3428 }
3429 }
3430
3431 async fn routing_scenarios(&self) -> Vec<RoutingScenarioHealth> {
3432 let tracker = self.outcome_tracker.read().await;
3433 let config = self.adaptive_router.config().clone();
3434 let scenarios = [
3435 (
3436 "interactive_text",
3437 "Summarize the benefits of local-first AI routing in two sentences.",
3438 "text",
3439 RoutingWorkload::Interactive,
3440 false,
3441 false,
3442 ),
3443 (
3444 "background_code",
3445 "Write a Python function named fibonacci(n) that returns the nth Fibonacci number.",
3446 "code",
3447 RoutingWorkload::Background,
3448 false,
3449 false,
3450 ),
3451 (
3452 "interactive_tool_use",
3453 "Use the provided weather tool to get the weather for Boston.",
3454 "tool_use",
3455 RoutingWorkload::Interactive,
3456 true,
3457 false,
3458 ),
3459 (
3460 "interactive_vision",
3461 "What is in this image? Answer in one word.",
3462 "vision",
3463 RoutingWorkload::Interactive,
3464 false,
3465 true,
3466 ),
3467 ];
3468
3469 scenarios
3470 .into_iter()
3471 .map(
3472 |(name, prompt, task_family, workload, has_tools, has_vision)| {
3473 let decision = self.adaptive_router.route_context_aware(
3474 prompt,
3475 0,
3476 &self.unified_registry,
3477 &tracker,
3478 has_tools,
3479 has_vision,
3480 workload,
3481 );
3482 let quality_first_cold_start = if has_tools || has_vision {
3483 config.quality_first_cold_start
3484 } else if task_family == "code"
3485 && matches!(workload, RoutingWorkload::Background)
3486 {
3487 false
3488 } else {
3489 config.quality_first_cold_start
3490 };
3491 RoutingScenarioHealth {
3492 name: name.to_string(),
3493 task_family: task_family.to_string(),
3494 workload,
3495 has_tools,
3496 has_vision,
3497 prefer_local: if task_family == "speech" {
3498 self.speech_policy.prefer_local
3499 } else {
3500 config.prefer_local
3501 },
3502 quality_first_cold_start,
3503 bootstrap_min_task_observations: config.bootstrap_min_task_observations,
3504 bootstrap_quality_floor: config.bootstrap_quality_floor,
3505 model_id: decision.model_id,
3506 model_name: decision.model_name,
3507 reason: decision.reason,
3508 strategy: decision.strategy,
3509 }
3510 },
3511 )
3512 .collect()
3513 }
3514
3515 pub async fn smoke_test_speech(
3517 &self,
3518 local: bool,
3519 remote: bool,
3520 ) -> Result<SpeechSmokeReport, InferenceError> {
3521 let mut report = SpeechSmokeReport::default();
3522
3523 if local {
3524 let tts = self
3525 .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
3526 .ok_or_else(|| {
3527 InferenceError::InferenceFailed(
3528 "no local text-to-speech model available".into(),
3529 )
3530 })?;
3531 let stt = self
3532 .preferred_speech_schema(ModelCapability::SpeechToText, true, false)
3533 .ok_or_else(|| {
3534 InferenceError::InferenceFailed(
3535 "no local speech-to-text model available".into(),
3536 )
3537 })?;
3538 report.local = Some(
3539 self.run_speech_smoke_path("local", &tts, &stt, "Testing CAR local speech path.")
3540 .await?,
3541 );
3542 } else {
3543 report.skipped.push("local".to_string());
3544 }
3545
3546 if remote {
3547 let tts = self
3548 .preferred_speech_schema(ModelCapability::TextToSpeech, false, true)
3549 .ok_or_else(|| {
3550 InferenceError::InferenceFailed(
3551 "no remote text-to-speech model available".into(),
3552 )
3553 })?;
3554 let stt = self
3555 .preferred_speech_schema(ModelCapability::SpeechToText, false, true)
3556 .ok_or_else(|| {
3557 InferenceError::InferenceFailed(
3558 "no remote speech-to-text model available".into(),
3559 )
3560 })?;
3561 report.remote = Some(
3562 self.run_speech_smoke_path("remote", &tts, &stt, "Testing CAR remote speech path.")
3563 .await?,
3564 );
3565 } else {
3566 report.skipped.push("remote".to_string());
3567 }
3568
3569 Ok(report)
3570 }
3571
3572 fn speech_candidates(
3573 &self,
3574 capability: ModelCapability,
3575 explicit: Option<&str>,
3576 ) -> Result<Vec<ModelSchema>, InferenceError> {
3577 if let Some(model) = explicit {
3578 let schema = self
3579 .unified_registry
3580 .get(model)
3581 .or_else(|| self.unified_registry.find_by_name(model))
3582 .cloned()
3583 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3584 if !schema.has_capability(capability) {
3585 return Err(InferenceError::InferenceFailed(format!(
3586 "model {} does not support {:?}",
3587 schema.name, capability
3588 )));
3589 }
3590 return Ok(vec![schema]);
3591 }
3592
3593 let mut candidates: Vec<ModelSchema> = self
3594 .unified_registry
3595 .query(&ModelFilter {
3596 capabilities: vec![capability],
3597 ..Default::default()
3598 })
3599 .into_iter()
3600 .cloned()
3601 .collect();
3602
3603 if candidates.is_empty() {
3604 return Err(InferenceError::InferenceFailed(format!(
3605 "no models registered for capability {:?}",
3606 capability
3607 )));
3608 }
3609
3610 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3611 if !self.speech_policy.allow_remote_fallback
3612 && candidates.iter().any(|model| model.is_local())
3613 {
3614 candidates.retain(|model| model.is_local());
3615 }
3616
3617 Ok(candidates)
3618 }
3619
3620 #[allow(dead_code)] fn resolve_external_hf_repo(
3626 &self,
3627 explicit: Option<&str>,
3628 capability: ModelCapability,
3629 ) -> Option<String> {
3630 let id = explicit?;
3631 let schema = self
3632 .unified_registry
3633 .get(id)
3634 .or_else(|| self.unified_registry.find_by_name(id))?;
3635 if !schema.has_capability(capability) {
3636 return Some(id.to_string());
3637 }
3638 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3639 return Some(hf_repo.clone());
3640 }
3641 Some(id.to_string())
3642 }
3643
3644 fn media_generation_candidates(
3645 &self,
3646 capability: ModelCapability,
3647 explicit: Option<&str>,
3648 ) -> Result<Vec<ModelSchema>, InferenceError> {
3649 if let Some(model) = explicit {
3650 let schema = self
3651 .unified_registry
3652 .get(model)
3653 .or_else(|| self.unified_registry.find_by_name(model))
3654 .cloned()
3655 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3656 if !schema.has_capability(capability) {
3657 return Err(InferenceError::InferenceFailed(format!(
3658 "model {} does not support {:?}",
3659 schema.name, capability
3660 )));
3661 }
3662 return Ok(vec![schema]);
3663 }
3664
3665 let mut candidates: Vec<ModelSchema> = self
3666 .unified_registry
3667 .query(&ModelFilter {
3668 capabilities: vec![capability],
3669 local_only: true,
3670 ..Default::default()
3671 })
3672 .into_iter()
3673 .cloned()
3674 .collect();
3675 candidates.sort_by_key(|schema| (!schema.available, schema.size_mb()));
3676 if candidates.is_empty() {
3677 return Err(InferenceError::InferenceFailed(format!(
3678 "no models registered for capability {:?}",
3679 capability
3680 )));
3681 }
3682 Ok(candidates)
3683 }
3684
3685 fn preferred_speech_schema(
3686 &self,
3687 capability: ModelCapability,
3688 local_only: bool,
3689 remote_only: bool,
3690 ) -> Option<ModelSchema> {
3691 let available_only = remote_only;
3692 let mut candidates: Vec<ModelSchema> = self
3693 .unified_registry
3694 .query(&ModelFilter {
3695 capabilities: vec![capability],
3696 available_only,
3697 ..Default::default()
3698 })
3699 .into_iter()
3700 .filter(|schema| {
3701 (!local_only || schema.is_local()) && (!remote_only || schema.is_remote())
3702 })
3703 .cloned()
3704 .collect();
3705 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3706 candidates.into_iter().next()
3707 }
3708
3709 fn speech_health_default_name(
3710 &self,
3711 capability: ModelCapability,
3712 local_only: bool,
3713 remote_only: bool,
3714 ) -> Option<String> {
3715 let preferred = match capability {
3716 ModelCapability::SpeechToText if local_only => {
3717 self.speech_policy.preferred_local_stt.as_ref()
3718 }
3719 ModelCapability::SpeechToText if remote_only => {
3720 self.speech_policy.preferred_remote_stt.as_ref()
3721 }
3722 ModelCapability::TextToSpeech if local_only => {
3723 self.speech_policy.preferred_local_tts.as_ref()
3724 }
3725 ModelCapability::TextToSpeech if remote_only => {
3726 self.speech_policy.preferred_remote_tts.as_ref()
3727 }
3728 _ => None,
3729 };
3730
3731 preferred
3732 .filter(|name| {
3733 self.unified_registry.list().iter().any(|schema| {
3734 schema.name == **name
3735 && schema.has_capability(capability)
3736 && (!local_only || schema.is_local())
3737 && (!remote_only || schema.is_remote())
3738 })
3739 })
3740 .cloned()
3741 .or_else(|| {
3742 self.preferred_speech_schema(capability, local_only, remote_only)
3743 .map(|schema| schema.name)
3744 })
3745 }
3746
3747 fn model_default_health(
3748 &self,
3749 capability: ModelCapability,
3750 configured_model: &str,
3751 ) -> ModelDefaultHealth {
3752 let schema = self
3753 .unified_registry
3754 .find_by_name(configured_model)
3755 .or_else(|| self.unified_registry.get(configured_model));
3756
3757 ModelDefaultHealth {
3758 capability,
3759 configured_model: configured_model.to_string(),
3760 available: schema.is_some_and(|model| model.available),
3761 is_local: schema.is_some_and(ModelSchema::is_local),
3762 provider: schema.map(|model| model.provider.clone()),
3763 }
3764 }
3765
3766 fn speech_sort_key(
3767 &self,
3768 capability: ModelCapability,
3769 model: &ModelSchema,
3770 ) -> (u8, u8, u8, u8, u64, u64) {
3771 let policy_preference = match capability {
3772 ModelCapability::SpeechToText if model.is_local() => {
3773 self.speech_policy.preferred_local_stt.as_ref()
3774 }
3775 ModelCapability::SpeechToText => self.speech_policy.preferred_remote_stt.as_ref(),
3776 ModelCapability::TextToSpeech if model.is_local() => {
3777 self.speech_policy.preferred_local_tts.as_ref()
3778 }
3779 ModelCapability::TextToSpeech => self.speech_policy.preferred_remote_tts.as_ref(),
3780 _ => None,
3781 };
3782 let local_rank = if self.speech_policy.prefer_local {
3783 if model.is_local() {
3784 0
3785 } else {
3786 1
3787 }
3788 } else if model.is_remote() {
3789 0
3790 } else {
3791 1
3792 };
3793 let availability_rank = if model.available {
3794 0
3795 } else if model.is_local() {
3796 1
3797 } else {
3798 2
3799 };
3800 let policy_rank: u8 = if policy_preference.is_some_and(|preferred| preferred == &model.name)
3801 {
3802 0
3803 } else {
3804 1
3805 };
3806 let speech_rank = match capability {
3807 ModelCapability::TextToSpeech => {
3808 if model.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
3809 0
3810 } else if model.name == "Kokoro-82M-bf16" {
3811 1
3812 } else if model.name == "Kokoro-82M-6bit" {
3813 2
3814 } else {
3815 3
3816 }
3817 }
3818 ModelCapability::SpeechToText => {
3819 if model.name == "Parakeet-TDT-0.6B-v3-MLX" {
3820 0
3821 } else {
3822 1
3823 }
3824 }
3825 _ => 0,
3826 };
3827 let latency_rank = model.performance.latency_p50_ms.unwrap_or(u64::MAX);
3828 let size_rank = model.cost.size_mb.unwrap_or(u64::MAX);
3829 (
3830 local_rank,
3831 availability_rank,
3832 policy_rank,
3833 speech_rank,
3834 latency_rank,
3835 size_rank,
3836 )
3837 }
3838
3839 async fn run_speech_smoke_path(
3840 &self,
3841 path: &str,
3842 tts: &ModelSchema,
3843 stt: &ModelSchema,
3844 text: &str,
3845 ) -> Result<SpeechSmokePathReport, InferenceError> {
3846 let work_dir = temp_work_dir(&format!("speech-smoke-{path}"))?;
3847 let audio_path = work_dir.join(format!("{path}.wav"));
3848 let synth = self
3849 .synthesize(SynthesizeRequest {
3850 text: text.to_string(),
3851 model: Some(tts.name.clone()),
3852 voice: default_speech_voice(tts),
3853 language: Some("en".to_string()),
3854 output_path: Some(audio_path.display().to_string()),
3855 ..SynthesizeRequest::default()
3856 })
3857 .await?;
3858 let transcript = self
3859 .transcribe(TranscribeRequest {
3860 audio_path: synth.audio_path.clone(),
3861 model: Some(stt.name.clone()),
3862 language: Some("en".to_string()),
3863 prompt: None,
3864 timestamps: false,
3865 })
3866 .await?;
3867
3868 Ok(SpeechSmokePathReport {
3869 path: path.to_string(),
3870 tts_model: synth.model_used.unwrap_or_else(|| tts.name.clone()),
3871 stt_model: transcript.model_used.unwrap_or_else(|| stt.name.clone()),
3872 audio_path: PathBuf::from(synth.audio_path),
3873 transcript: transcript.text,
3874 })
3875 }
3876
3877 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3878 async fn ensure_speech_runtime(&self) -> Result<SpeechRuntime, InferenceError> {
3879 let mut guard = self.speech_runtime.lock().await;
3880 if let Some(runtime) = guard.as_ref() {
3881 if runtime.is_ready() {
3882 return Ok(runtime.clone());
3883 }
3884 }
3885
3886 let runtime =
3887 SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3888 if !runtime.is_ready() {
3889 bootstrap_speech_runtime(&runtime).await?;
3890 }
3891 if !runtime.is_ready() {
3892 return Err(InferenceError::InferenceFailed(format!(
3893 "managed speech runtime is not ready at {}",
3894 runtime.root.display()
3895 )));
3896 }
3897
3898 *guard = Some(runtime.clone());
3899 Ok(runtime)
3900 }
3901
3902 async fn transcribe_local_mlx(
3903 &self,
3904 schema: &ModelSchema,
3905 req: &TranscribeRequest,
3906 ) -> Result<TranscribeResult, InferenceError> {
3907 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3909 {
3910 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3911 let parakeet = backend::mlx_parakeet::ParakeetBackend::load(&model_dir)?;
3912 let (text, words) = if req.timestamps {
3914 parakeet
3915 .transcribe_detailed(Path::new(&req.audio_path))
3916 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?
3917 } else {
3918 let t = parakeet
3919 .transcribe(Path::new(&req.audio_path))
3920 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?;
3921 (t, Vec::new())
3922 };
3923 return Ok(TranscribeResult {
3924 text,
3925 model_used: Some(schema.name.clone()),
3926 language: req.language.clone(),
3927 words,
3928 });
3929 }
3930
3931 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3933 {
3934 let runtime = self.ensure_speech_runtime().await?;
3935 let hf_repo = match &schema.source {
3936 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3937 _ => unreachable!(),
3938 };
3939 let output_dir = temp_work_dir("stt")?;
3940 let output_prefix = output_dir.join("transcript");
3941 let mut args = vec![
3942 "--model".to_string(),
3943 hf_repo,
3944 "--audio".to_string(),
3945 req.audio_path.clone(),
3946 "--output-path".to_string(),
3947 output_prefix.display().to_string(),
3948 "--format".to_string(),
3949 "json".to_string(),
3950 ];
3951 if let Some(language) = &req.language {
3952 args.push("--language".to_string());
3953 args.push(normalize_lang_code(language));
3954 }
3955 if let Some(prompt) = &req.prompt {
3956 args.push("--context".to_string());
3957 args.push(prompt.clone());
3958 }
3959 if req.timestamps {
3960 args.push("--verbose".to_string());
3961 }
3962
3963 let output = run_mlx_audio_command(&runtime, "stt.generate", &args).await?;
3964 let text = read_transcription_result(&output_prefix)?
3965 .or_else(|| extract_text_from_payload(&output.stdout))
3966 .ok_or_else(|| {
3967 InferenceError::InferenceFailed(format!(
3968 "mlx-audio transcription returned no text: {}",
3969 output.stderr
3970 ))
3971 })?;
3972
3973 Ok(TranscribeResult {
3974 text,
3975 model_used: Some(schema.name.clone()),
3976 language: req.language.clone(),
3977 words: Vec::new(),
3978 })
3979 }
3980 }
3981
3982 async fn synthesize_local_mlx(
3983 &self,
3984 schema: &ModelSchema,
3985 req: &SynthesizeRequest,
3986 ) -> Result<SynthesizeResult, InferenceError> {
3987 let requested = req.requested_advanced_controls();
3992 let repo_supports_advanced = match &schema.source {
3993 ModelSource::Mlx { hf_repo, .. } => hf_repo.to_ascii_lowercase().contains("qwen3-tts"),
3994 _ => false,
3995 };
3996 if !requested.is_empty() && !repo_supports_advanced {
3997 if req.strict_capabilities {
3998 return Err(InferenceError::InferenceFailed(format!(
3999 "model {name} does not support Qwen3-TTS advanced controls {requested:?}; \
4000 route to a Qwen3-TTS model or set strict_capabilities = false to degrade",
4001 name = schema.name,
4002 )));
4003 }
4004 tracing::warn!(
4005 model = %schema.name,
4006 fields = ?requested,
4007 "Qwen3-TTS advanced controls set on non-Qwen3-TTS backend — ignored \
4008 (set strict_capabilities=true to error instead)"
4009 );
4010 }
4011
4012 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
4014 {
4015 if repo_supports_advanced && !requested.is_empty() {
4021 if req.strict_capabilities {
4022 return Err(InferenceError::InferenceFailed(format!(
4023 "native MLX TTS backend does not yet implement Qwen3-TTS advanced \
4024 controls {requested:?}; run on non-Apple-Silicon to use the Python \
4025 mlx-audio fallback, or set strict_capabilities = false"
4026 )));
4027 }
4028 tracing::warn!(
4029 model = %schema.name,
4030 fields = ?requested,
4031 "Qwen3-TTS advanced controls are not yet implemented in the native MLX TTS \
4032 backend; synthesizing without cloning/voice-design"
4033 );
4034 }
4035 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
4036 let size = backend_cache::estimate_model_size(&model_dir);
4037 let cache = Arc::clone(&self.kokoro_cache);
4038 let key = schema.id.clone();
4039 let handle = cache.get_or_load(&key, size, || {
4040 backend::mlx_kokoro::KokoroBackend::load(&model_dir)
4041 })?;
4042
4043 let output_path = req.output_path.clone().unwrap_or_else(|| {
4044 let dir = std::env::temp_dir().join("car_tts");
4045 let _ = std::fs::create_dir_all(&dir);
4046 dir.join("output.wav").display().to_string()
4047 });
4048 let voice = req.voice.as_deref().unwrap_or("af_heart").to_string();
4049 let text = req.text.clone();
4050 let op = tokio::task::spawn_blocking(move || -> Result<PathBuf, InferenceError> {
4051 let mut guard = handle.lock().map_err(|_| {
4052 InferenceError::InferenceFailed("kokoro backend mutex poisoned".into())
4053 })?;
4054 guard
4055 .synthesize(&text, Some(&voice), Path::new(&output_path))
4056 .map_err(|e| InferenceError::InferenceFailed(format!("native TTS: {e}")))
4057 })
4058 .await
4059 .map_err(|e| InferenceError::InferenceFailed(format!("kokoro task join: {e}")))??;
4060
4061 let final_path =
4062 materialize_audio_output(&op, req.output_path.as_deref(), &req.format)?;
4063 return Ok(SynthesizeResult {
4064 audio_path: final_path.display().to_string(),
4065 media_type: media_type_for_format(&req.format),
4066 model_used: Some(schema.name.clone()),
4067 voice_used: req.voice.clone(),
4068 });
4069 }
4070
4071 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4073 {
4074 let runtime = self.ensure_speech_runtime().await?;
4075 let primary_hf_repo = match &schema.source {
4076 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
4077 _ => unreachable!(),
4078 };
4079 let (produced, model_used) = match self
4080 .synthesize_local_mlx_repo(&runtime, &primary_hf_repo, schema.name.as_str(), req)
4081 .await
4082 {
4083 Ok(result) => result,
4084 Err(primary_err)
4085 if primary_hf_repo == "mlx-community/Kokoro-82M-6bit"
4086 && kokoro_runtime_fallback_enabled() =>
4087 {
4088 let fallback_repo = "mlx-community/Kokoro-82M-bf16";
4089 let fallback_name = "Kokoro-82M-bf16";
4090 match self
4091 .synthesize_local_mlx_repo(&runtime, fallback_repo, fallback_name, req)
4092 .await
4093 {
4094 Ok(result) => result,
4095 Err(fallback_err) => {
4096 return Err(InferenceError::InferenceFailed(format!(
4097 "{primary_err}; fallback {fallback_name} also failed: {fallback_err}"
4098 )));
4099 }
4100 }
4101 }
4102 Err(err) => return Err(err),
4103 };
4104 let final_path =
4105 materialize_audio_output(&produced, req.output_path.as_deref(), &req.format)?;
4106
4107 Ok(SynthesizeResult {
4108 audio_path: final_path.display().to_string(),
4109 media_type: media_type_for_format(&req.format),
4110 model_used: Some(model_used),
4111 voice_used: req.voice.clone(),
4112 })
4113 }
4114 }
4115
4116 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4117 async fn synthesize_local_mlx_repo(
4118 &self,
4119 runtime: &SpeechRuntime,
4120 hf_repo: &str,
4121 model_name: &str,
4122 req: &SynthesizeRequest,
4123 ) -> Result<(PathBuf, String), InferenceError> {
4124 let output_dir = temp_work_dir("tts")?;
4125 let mut args = vec![
4126 "--model".to_string(),
4127 hf_repo.to_string(),
4128 "--text".to_string(),
4129 req.text.clone(),
4130 "--output_path".to_string(),
4131 output_dir.display().to_string(),
4132 ];
4133 if let Some(voice) = &req.voice {
4134 args.push("--voice".to_string());
4135 args.push(voice.clone());
4136 }
4137 if let Some(speed) = req.speed {
4138 args.push("--speed".to_string());
4139 args.push(speed.to_string());
4140 }
4141 let repo_lower = hf_repo.to_ascii_lowercase();
4142 if repo_lower.contains("kokoro") {
4143 args.push("--lang_code".to_string());
4144 args.push(kokoro_lang_code(req.language.as_deref()).to_string());
4145 } else if let Some(language) = &req.language {
4146 args.push("--lang_code".to_string());
4147 args.push(normalize_lang_code(language));
4148 }
4149
4150 if repo_lower.contains("qwen3-tts") {
4156 if let Some(ref_audio) = &req.reference_audio_path {
4157 args.push("--ref_audio".to_string());
4158 args.push(ref_audio.clone());
4159 }
4160 if let Some(ref_text) = &req.reference_text {
4161 args.push("--ref_text".to_string());
4162 args.push(ref_text.clone());
4163 }
4164 if let Some(instruct) = &req.voice_instruction {
4165 args.push("--instruct".to_string());
4166 args.push(instruct.clone());
4167 }
4168 }
4169
4170 let output = if repo_lower.contains("kokoro") {
4171 let device = std::env::var("CAR_SPEECH_KOKORO_DEVICE")
4172 .or_else(|_| std::env::var("CAR_SPEECH_MLX_DEVICE"))
4173 .unwrap_or_else(|_| "cpu".to_string());
4174 let extra_env = vec![
4175 ("MLX_DEVICE".to_string(), device),
4177 ("PYTORCH_ENABLE_MPS_FALLBACK".to_string(), "1".to_string()),
4179 ];
4180 run_mlx_audio_command_with_env(runtime, "tts.generate", &args, &extra_env).await?
4181 } else {
4182 run_mlx_audio_command(runtime, "tts.generate", &args).await?
4183 };
4184 let produced = find_audio_file(&output_dir)?.ok_or_else(|| {
4185 let hint = if repo_lower.contains("kokoro") {
4186 ". Kokoro models may crash on GPU — try CAR_SPEECH_KOKORO_DEVICE=cpu or use the default Qwen3-TTS model"
4187 } else {
4188 ""
4189 };
4190 InferenceError::InferenceFailed(format!(
4191 "mlx-audio synthesis produced no audio file: {}{}",
4192 output.stderr, hint
4193 ))
4194 })?;
4195 Ok((produced, model_name.to_string()))
4196 }
4197
4198 async fn transcribe_elevenlabs(
4199 &self,
4200 schema: &ModelSchema,
4201 req: &TranscribeRequest,
4202 ) -> Result<TranscribeResult, InferenceError> {
4203 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4204 let file_name = Path::new(&req.audio_path)
4205 .file_name()
4206 .and_then(|f| f.to_str())
4207 .unwrap_or("audio.wav")
4208 .to_string();
4209 let audio_bytes = tokio::fs::read(&req.audio_path).await?;
4210 let file_part = Part::bytes(audio_bytes).file_name(file_name);
4211 let mut form = Form::new()
4212 .text("model_id", schema.name.clone())
4213 .part("file", file_part);
4214 if let Some(language) = &req.language {
4215 form = form.text("language_code", language.clone());
4216 }
4217
4218 let resp = self
4219 .remote_backend
4220 .client
4221 .post(format!(
4222 "{}/v1/speech-to-text",
4223 endpoint.trim_end_matches('/')
4224 ))
4225 .header("xi-api-key", api_key)
4226 .multipart(form)
4227 .send()
4228 .await
4229 .map_err(|e| {
4230 InferenceError::InferenceFailed(format!("ElevenLabs STT request failed: {e}"))
4231 })?;
4232 let status = resp.status();
4233 let body = resp.text().await.map_err(|e| {
4234 InferenceError::InferenceFailed(format!("read ElevenLabs STT body: {e}"))
4235 })?;
4236 if !status.is_success() {
4237 return Err(InferenceError::InferenceFailed(format!(
4238 "ElevenLabs STT returned {status}: {body}"
4239 )));
4240 }
4241 let payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
4242 InferenceError::InferenceFailed(format!("parse ElevenLabs STT response: {e}"))
4243 })?;
4244 let text = payload
4245 .get("text")
4246 .and_then(|v| v.as_str())
4247 .map(str::to_string)
4248 .ok_or_else(|| {
4249 InferenceError::InferenceFailed("ElevenLabs STT response missing text".into())
4250 })?;
4251
4252 Ok(TranscribeResult {
4253 text,
4254 model_used: Some(schema.name.clone()),
4255 language: payload
4256 .get("language_code")
4257 .and_then(|v| v.as_str())
4258 .map(str::to_string),
4259 words: Vec::new(),
4260 })
4261 }
4262
4263 async fn synthesize_elevenlabs(
4264 &self,
4265 schema: &ModelSchema,
4266 req: &SynthesizeRequest,
4267 ) -> Result<SynthesizeResult, InferenceError> {
4268 let requested = req.requested_advanced_controls();
4272 if !requested.is_empty() {
4273 if req.strict_capabilities {
4274 return Err(InferenceError::InferenceFailed(format!(
4275 "ElevenLabs backend does not support Qwen3-TTS advanced controls \
4276 {requested:?}; route to a Qwen3-TTS model or set strict_capabilities = false"
4277 )));
4278 }
4279 tracing::warn!(
4280 model = %schema.name,
4281 fields = ?requested,
4282 "Qwen3-TTS advanced controls ignored by ElevenLabs backend"
4283 );
4284 }
4285 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4286 let voice_id = req
4287 .voice
4288 .clone()
4289 .unwrap_or_else(|| "JBFqnCBsd6RMkjVDRZzb".to_string());
4290 let output_format = elevenlabs_output_format(&req.format);
4291 let url = format!(
4292 "{}/v1/text-to-speech/{}?output_format={}",
4293 endpoint.trim_end_matches('/'),
4294 voice_id,
4295 output_format
4296 );
4297
4298 let mut body = serde_json::json!({
4299 "text": req.text,
4300 "model_id": schema.name,
4301 });
4302 if let Some(language) = &req.language {
4303 body["language_code"] = serde_json::Value::String(language.clone());
4304 }
4305
4306 let resp = self
4307 .remote_backend
4308 .client
4309 .post(url)
4310 .header("xi-api-key", api_key)
4311 .header("Content-Type", "application/json")
4312 .json(&body)
4313 .send()
4314 .await
4315 .map_err(|e| {
4316 InferenceError::InferenceFailed(format!("ElevenLabs TTS request failed: {e}"))
4317 })?;
4318 let status = resp.status();
4319 let audio = resp.bytes().await.map_err(|e| {
4320 InferenceError::InferenceFailed(format!("read ElevenLabs TTS body: {e}"))
4321 })?;
4322 if !status.is_success() {
4323 let err_body = String::from_utf8_lossy(&audio);
4324 return Err(InferenceError::InferenceFailed(format!(
4325 "ElevenLabs TTS returned {status}: {err_body}"
4326 )));
4327 }
4328
4329 let final_path = requested_or_temp_output(req.output_path.as_deref(), &req.format)?;
4330 ensure_parent_dir(&final_path)?;
4331 tokio::fs::write(&final_path, &audio).await?;
4332
4333 Ok(SynthesizeResult {
4334 audio_path: final_path.display().to_string(),
4335 media_type: media_type_for_format(&req.format),
4336 model_used: Some(schema.name.clone()),
4337 voice_used: Some(voice_id),
4338 })
4339 }
4340}
4341
/// Per-provider tally of model counts and capabilities.
///
/// `Default` yields `configured == false`, zeroed counters and an empty capability set.
/// NOTE(review): the code that fills this in is outside this section; the field notes
/// below are inferred from the names — confirm against the aggregation site.
#[derive(Default)]
struct ProviderAccumulator {
    /// Presumably set once any model of the provider counts as configured.
    configured: bool,
    /// Presumably the number of local models seen for the provider.
    local_models: usize,
    /// Presumably the number of remote models seen for the provider.
    remote_models: usize,
    /// Presumably the number of models flagged as available.
    available_models: usize,
    /// Distinct capabilities collected for the provider (a set, so duplicates collapse).
    capabilities: std::collections::HashSet<ModelCapability>,
}
4350
/// Captured output of a successfully finished speech-runtime command.
///
/// Only compiled for builds that use the managed Python runtime (everything except
/// macOS/aarch64 without `car_skip_mlx`).
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
struct CommandOutput {
    /// Standard output, decoded as UTF-8 with invalid sequences replaced.
    stdout: String,
    /// Standard error, decoded as UTF-8 with invalid sequences replaced.
    stderr: String,
}
4359
/// Locations of the managed Python speech runtime and its entry points.
///
/// Only compiled for builds that use the managed Python runtime (everything except
/// macOS/aarch64 without `car_skip_mlx`). All tool paths are derived from `root` by
/// `SpeechRuntime::new`.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
#[derive(Debug, Clone)]
struct SpeechRuntime {
    /// Root directory of the runtime; bootstrap creates the virtual environment here.
    root: PathBuf,
    /// Python interpreter inside the runtime (`<root>/bin/python`).
    python: PathBuf,
    /// Speech-to-text entry point (`<root>/bin/mlx_audio.stt.generate`).
    stt_program: PathBuf,
    /// Text-to-speech entry point (`<root>/bin/mlx_audio.tts.generate`).
    tts_program: PathBuf,
}
4368
4369#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4370impl SpeechRuntime {
4371 fn new(root: PathBuf) -> Self {
4372 let bin_dir = root.join("bin");
4373 Self {
4374 root,
4375 python: bin_dir.join("python"),
4376 stt_program: bin_dir.join("mlx_audio.stt.generate"),
4377 tts_program: bin_dir.join("mlx_audio.tts.generate"),
4378 }
4379 }
4380
4381 fn is_ready(&self) -> bool {
4382 self.python.exists() && self.stt_program.exists() && self.tts_program.exists()
4383 }
4384
4385 fn command_for(&self, subcommand: &str) -> Result<&Path, InferenceError> {
4386 match subcommand {
4387 "stt.generate" => Ok(&self.stt_program),
4388 "tts.generate" => Ok(&self.tts_program),
4389 _ => Err(InferenceError::InferenceFailed(format!(
4390 "unknown speech subcommand: {subcommand}"
4391 ))),
4392 }
4393 }
4394}
4395
4396#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4397async fn run_mlx_audio_command(
4398 runtime: &SpeechRuntime,
4399 subcommand: &str,
4400 args: &[String],
4401) -> Result<CommandOutput, InferenceError> {
4402 run_mlx_audio_command_with_env(runtime, subcommand, args, &[]).await
4403}
4404
4405#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4406async fn run_mlx_audio_command_with_env(
4407 runtime: &SpeechRuntime,
4408 subcommand: &str,
4409 args: &[String],
4410 envs: &[(String, String)],
4411) -> Result<CommandOutput, InferenceError> {
4412 let program = runtime.command_for(subcommand)?;
4413 let mut command = Command::new(program);
4414 command.args(args);
4415 for (key, value) in envs {
4416 command.env(key, value);
4417 }
4418 let output = command
4419 .output()
4420 .await
4421 .map_err(|err| InferenceError::InferenceFailed(format!("{}: {err}", program.display())))?;
4422
4423 if output.status.success() {
4424 Ok(CommandOutput {
4425 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
4426 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
4427 })
4428 } else {
4429 Err(InferenceError::InferenceFailed(format!(
4430 "{} exited with {}: {}",
4431 program.display(),
4432 output.status,
4433 String::from_utf8_lossy(&output.stderr)
4434 )))
4435 }
4436}
4437
4438#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4439async fn bootstrap_speech_runtime(runtime: &SpeechRuntime) -> Result<(), InferenceError> {
4440 std::fs::create_dir_all(&runtime.root)?;
4441 let python = select_speech_python()?;
4442
4443 run_command(
4444 "uv",
4445 &[
4446 "venv".to_string(),
4447 "--python".to_string(),
4448 python,
4449 runtime.root.display().to_string(),
4450 ],
4451 )
4452 .await?;
4453
4454 run_command(
4455 "uv",
4456 &[
4457 "pip".to_string(),
4458 "install".to_string(),
4459 "--python".to_string(),
4460 runtime.python.display().to_string(),
4461 speech_runtime_mlx_audio_spec(),
4462 "misaki[en]".to_string(),
4463 speech_runtime_spacy_model_spec(),
4464 ],
4465 )
4466 .await?;
4467
4468 Ok(())
4469}
4470
4471#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4472async fn run_command(program: &str, args: &[String]) -> Result<(), InferenceError> {
4473 let output = Command::new(program)
4474 .args(args)
4475 .output()
4476 .await
4477 .map_err(|err| InferenceError::InferenceFailed(format!("{program}: {err}")))?;
4478
4479 if output.status.success() {
4480 Ok(())
4481 } else {
4482 Err(InferenceError::InferenceFailed(format!(
4483 "{} exited with {}: {}",
4484 program,
4485 output.status,
4486 String::from_utf8_lossy(&output.stderr)
4487 )))
4488 }
4489}
4490
4491#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4492fn select_speech_python() -> Result<String, InferenceError> {
4493 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4494 if !path.trim().is_empty() {
4495 return Ok(path);
4496 }
4497 }
4498
4499 for candidate in ["python3.13", "python3.12", "python3.11"] {
4500 if command_in_path(candidate) {
4501 return Ok(candidate.to_string());
4502 }
4503 }
4504
4505 Err(InferenceError::InferenceFailed(
4506 "no supported Python found for managed speech runtime (tried python3.13, python3.12, python3.11)".into(),
4507 ))
4508}
4509
4510#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4511fn detect_speech_python() -> Option<String> {
4512 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4513 if !path.trim().is_empty() {
4514 return Some(path);
4515 }
4516 }
4517
4518 ["python3.13", "python3.12", "python3.11"]
4519 .into_iter()
4520 .find(|candidate| command_in_path(candidate))
4521 .map(str::to_string)
4522}
4523
/// Resolves the directory that holds the managed speech runtime.
///
/// A non-blank `CAR_SPEECH_RUNTIME_DIR` wins. Otherwise the location is
/// `$HOME/.car/speech-runtime`, with `.` standing in for an unset `HOME`.
/// The `_models_dir` argument is currently not consulted.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_root_from_models_dir(_models_dir: &Path) -> PathBuf {
    let configured = std::env::var("CAR_SPEECH_RUNTIME_DIR")
        .ok()
        .filter(|dir| !dir.trim().is_empty());

    match configured {
        Some(dir) => PathBuf::from(dir),
        None => {
            let home = std::env::var("HOME").map_or_else(|_| PathBuf::from("."), PathBuf::from);
            home.join(".car").join("speech-runtime")
        }
    }
}
4538
/// Reports whether a regular file called `name` exists in any `PATH` directory.
///
/// Returns `false` when `PATH` is unset. Only existence as a file is checked, not
/// whether the file is executable.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn command_in_path(name: &str) -> bool {
    std::env::var_os("PATH").is_some_and(|paths| {
        // `is_file` is already false for paths that do not exist.
        std::env::split_paths(&paths).any(|dir| dir.join(name).is_file())
    })
}
4550
4551fn speech_model_cached(schema: &ModelSchema) -> bool {
4552 match &schema.source {
4553 ModelSource::Mlx { hf_repo, .. } => huggingface_repo_has_snapshot(hf_repo),
4554 ModelSource::Proprietary { auth, .. } => match auth {
4555 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4556 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4557 ProprietaryAuth::OAuth2Pkce { .. } => false,
4558 },
4559 _ => false,
4560 }
4561}
4562
4563fn remove_huggingface_repo_cache(repo_id: &str) -> Result<(), InferenceError> {
4564 let repo_dir = std::env::var("HF_HOME")
4565 .map(PathBuf::from)
4566 .unwrap_or_else(|_| {
4567 std::env::var("HOME")
4568 .map(PathBuf::from)
4569 .unwrap_or_else(|_| PathBuf::from("."))
4570 .join(".cache")
4571 .join("huggingface")
4572 })
4573 .join("hub")
4574 .join(format!("models--{}", repo_id.replace('/', "--")));
4575
4576 if repo_dir.exists() {
4577 std::fs::remove_dir_all(repo_dir)?;
4578 }
4579 Ok(())
4580}
4581
4582fn model_source_configured(schema: &ModelSchema) -> bool {
4583 match &schema.source {
4584 ModelSource::RemoteApi {
4585 api_key_env,
4586 api_key_envs,
4587 ..
4588 } => {
4589 std::env::var(api_key_env).is_ok()
4590 || api_key_envs
4591 .iter()
4592 .any(|env_var| std::env::var(env_var).is_ok())
4593 }
4594 ModelSource::Proprietary { auth, .. } => match auth {
4595 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4596 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4597 ProprietaryAuth::OAuth2Pkce { .. } => false,
4598 },
4599 ModelSource::VllmMlx { .. } => {
4600 std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available
4601 }
4602 ModelSource::Ollama { .. } => schema.available,
4603 ModelSource::Mlx { .. } | ModelSource::Local { .. } => true,
4604 ModelSource::AppleFoundationModels { .. } => schema.available,
4605 ModelSource::Delegated { .. } => true,
4610 }
4611}
4612
4613fn all_model_capabilities() -> [ModelCapability; 13] {
4614 [
4615 ModelCapability::Generate,
4616 ModelCapability::Embed,
4617 ModelCapability::Classify,
4618 ModelCapability::Code,
4619 ModelCapability::Reasoning,
4620 ModelCapability::Summarize,
4621 ModelCapability::ToolUse,
4622 ModelCapability::MultiToolCall,
4623 ModelCapability::Vision,
4624 ModelCapability::SpeechToText,
4625 ModelCapability::TextToSpeech,
4626 ModelCapability::ImageGeneration,
4627 ModelCapability::VideoGeneration,
4628 ]
4629}
4630
4631fn sort_capabilities(mut capabilities: Vec<ModelCapability>) -> Vec<ModelCapability> {
4632 capabilities.sort_by_key(|capability| {
4633 all_model_capabilities()
4634 .iter()
4635 .position(|candidate| candidate == capability)
4636 .unwrap_or(usize::MAX)
4637 });
4638 capabilities
4639}
4640
4641fn speech_model_source_label(schema: &ModelSchema) -> String {
4642 match &schema.source {
4643 ModelSource::Mlx { hf_repo, .. } => format!("mlx:{hf_repo}"),
4644 ModelSource::Proprietary {
4645 provider, endpoint, ..
4646 } => format!("proprietary:{provider}:{endpoint}"),
4647 ModelSource::RemoteApi { endpoint, .. } => format!("remote:{endpoint}"),
4648 ModelSource::Local { hf_repo, .. } => format!("local:{hf_repo}"),
4649 ModelSource::VllmMlx {
4650 endpoint,
4651 model_name,
4652 } => format!("vllm-mlx:{endpoint}:{model_name}"),
4653 ModelSource::Ollama { model_tag, host } => format!("ollama:{host}:{model_tag}"),
4654 ModelSource::AppleFoundationModels { use_case } => {
4655 format!(
4656 "apple-foundation:{}",
4657 use_case.as_deref().unwrap_or("default")
4658 )
4659 }
4660 ModelSource::Delegated { hint } => {
4661 format!("delegated:{}", hint.as_deref().unwrap_or("(none)"))
4662 }
4663 }
4664}
4665
/// Builds the chat-formatted reranker prompt for one (query, document) pair.
///
/// The prompt has a fixed system turn asking for a yes/no judgement, a user
/// turn carrying the instruction, query and document, and an assistant turn
/// pre-filled with an empty `<think>` block.
fn rerank_prompt(instruction: &str, query: &str, document: &str) -> String {
    const SYSTEM: &str = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".";

    let system_turn = format!("<|im_start|>system\n{SYSTEM}<|im_end|>\n");
    let user_turn = format!(
        "<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
    );
    let assistant_turn = "<|im_start|>assistant\n<think>\n\n</think>\n\n";

    [system_turn.as_str(), user_turn.as_str(), assistant_turn].concat()
}
4681
/// Turns a "nothing can serve this request" error into actionable guidance.
///
/// Returns `Some(message)` when `underlying` contains one of the known
/// no-backend phrases (matched case-sensitively); the message lists the
/// recovery commands and echoes the underlying error. Any other error text
/// yields `None`.
fn no_backend_recovery_hint(underlying: &str) -> Option<String> {
    const NO_BACKEND_MARKERS: [&str; 4] = [
        "no credential",
        "model not found",
        "no models available",
        "no inference runner",
    ];

    let is_no_backend = NO_BACKEND_MARKERS
        .iter()
        .any(|marker| underlying.contains(*marker));

    is_no_backend.then(|| {
        format!(
            "no inference backend is available. To install a local CPU-only \
             model (no account required, works on Windows/Linux/macOS), run:\n \
             car models pull qwen/qwen3-1.7b:q8_0\n\
             To use Parslee's hosted models instead, run:\n \
             car auth login parslee\n\
             (underlying error: {underlying})"
        )
    })
}
4717
4718fn score_from_rerank_output(text: &str, model_name: &str) -> f32 {
4719 let normalized: String = text
4724 .to_ascii_lowercase()
4725 .chars()
4726 .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
4727 .collect();
4728 for tok in normalized.split_ascii_whitespace().take(5) {
4729 match tok {
4730 "yes" => return 1.0,
4731 "no" => return 0.0,
4732 _ => continue,
4733 }
4734 }
4735 tracing::warn!(
4736 model = %model_name,
4737 output = %text,
4738 "rerank: first tokens contain neither `yes` nor `no`; returning neutral 0.5"
4739 );
4740 0.5
4741}
4742
4743fn default_speech_voice(schema: &ModelSchema) -> Option<String> {
4744 if schema.provider == "elevenlabs" {
4745 Some("JBFqnCBsd6RMkjVDRZzb".to_string())
4746 } else if schema.name == "Kokoro-82M-6bit" || schema.name == "Kokoro-82M-bf16" {
4747 Some("af_heart".to_string())
4748 } else if schema.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
4749 Some("Chelsie".to_string())
4750 } else {
4751 None
4752 }
4753}
4754
4755#[allow(dead_code)] fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
4757 find_latest_huggingface_snapshot(repo_id).is_some()
4758}
4759
/// Computes the cache directory for a Hugging Face model repository.
///
/// The base is `$HF_HOME` when that variable can be read, otherwise
/// `$HOME/.cache/huggingface` (with `.` in place of the home directory when
/// `HOME` is unreadable). The repository lives at
/// `<base>/hub/models--<repo id with '/' replaced by "--">`.
fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(configured) => PathBuf::from(configured),
        Err(_) => {
            let home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            home.join(".cache").join("huggingface")
        }
    };

    let folder = format!("models--{}", repo_id.replace('/', "--"));
    hf_home.join("hub").join(folder)
}
4773
4774fn find_latest_huggingface_snapshot(repo_id: &str) -> Option<PathBuf> {
4775 let snapshots = huggingface_repo_dir(repo_id).join("snapshots");
4776 std::fs::read_dir(snapshots)
4777 .ok()?
4778 .filter_map(Result::ok)
4779 .map(|entry| entry.path())
4780 .find(|path| path.is_dir() && snapshot_looks_ready(path))
4781}
4782
4783fn snapshot_looks_ready(path: &Path) -> bool {
4784 if path.join("config.json").exists() || path.join("model_index.json").exists() {
4785 return true;
4786 }
4787 snapshot_contains_ext(path, "safetensors")
4788}
4789
/// Recursively searches `root` for a file whose extension equals `ext`,
/// ignoring ASCII case. Unreadable directories and entries are treated as
/// containing nothing.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return false,
    };

    for entry in entries.flatten() {
        let candidate = entry.path();
        let found = if candidate.is_dir() {
            snapshot_contains_ext(&candidate, ext)
        } else {
            candidate
                .extension()
                .and_then(|value| value.to_str())
                .is_some_and(|value| value.eq_ignore_ascii_case(ext))
        };
        if found {
            return true;
        }
    }
    false
}
4806
/// Counts the regular files under `root`, descending into subdirectories.
/// An unreadable directory contributes zero; entries that are neither files
/// nor directories are ignored.
#[allow(dead_code)]
fn count_files_recursive(root: &Path) -> usize {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return 0,
    };

    let mut total = 0;
    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            total += count_files_recursive(&path);
        } else if path.is_file() {
            total += 1;
        }
    }
    total
}
4826
4827async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
4828 let api = hf_hub::api::tokio::ApiBuilder::from_env()
4829 .with_progress(false)
4830 .build()
4831 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
4832 let repo = api.model(repo_id.to_string());
4833 let info = repo
4834 .info()
4835 .await
4836 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
4837
4838 let snapshot_path = huggingface_repo_dir(repo_id)
4839 .join("snapshots")
4840 .join(&info.sha);
4841 let mut downloaded = 0usize;
4842 for sibling in &info.siblings {
4843 let local_path = snapshot_path.join(&sibling.rfilename);
4844 if local_path.exists() {
4845 downloaded += 1;
4846 continue;
4847 }
4848 repo.download(&sibling.rfilename).await.map_err(|e| {
4849 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
4850 })?;
4851 downloaded += 1;
4852 }
4853
4854 Ok((snapshot_path, downloaded))
4855}
4856
4857fn temp_work_dir(prefix: &str) -> Result<PathBuf, InferenceError> {
4858 let unique = SystemTime::now()
4859 .duration_since(UNIX_EPOCH)
4860 .map_err(|e| InferenceError::InferenceFailed(format!("clock error: {e}")))?
4861 .as_nanos();
4862 let dir = std::env::temp_dir().join(format!("car-inference-{prefix}-{unique}"));
4863 std::fs::create_dir_all(&dir)?;
4864 Ok(dir)
4865}
4866
4867fn ensure_parent_dir(path: &Path) -> Result<(), InferenceError> {
4868 if let Some(parent) = path.parent() {
4869 std::fs::create_dir_all(parent)?;
4870 }
4871 Ok(())
4872}
4873
4874fn requested_or_temp_output(
4875 output_path: Option<&str>,
4876 format: &str,
4877) -> Result<PathBuf, InferenceError> {
4878 if let Some(path) = output_path {
4879 return Ok(PathBuf::from(path));
4880 }
4881 let dir = temp_work_dir("audio-out")?;
4882 Ok(dir.join(format!("speech.{format}")))
4883}
4884
4885#[allow(dead_code)] fn requested_or_temp_media_output(
4887 output_path: Option<&str>,
4888 format: &str,
4889 stem: &str,
4890) -> Result<PathBuf, InferenceError> {
4891 if let Some(path) = output_path {
4892 return Ok(PathBuf::from(path));
4893 }
4894 let dir = temp_work_dir(&format!("{stem}-out"))?;
4895 Ok(dir.join(format!("{stem}.{format}")))
4896}
4897
4898fn materialize_audio_output(
4899 produced: &Path,
4900 requested: Option<&str>,
4901 format: &str,
4902) -> Result<PathBuf, InferenceError> {
4903 if let Some(path) = requested {
4904 let dest = PathBuf::from(path);
4905 ensure_parent_dir(&dest)?;
4906 std::fs::copy(produced, &dest)?;
4907 Ok(dest)
4908 } else {
4909 let dest = requested_or_temp_output(None, format)?;
4910 ensure_parent_dir(&dest)?;
4911 std::fs::copy(produced, &dest)?;
4912 Ok(dest)
4913 }
4914}
4915
4916#[allow(dead_code)] fn read_transcription_result(output_prefix: &Path) -> Result<Option<String>, InferenceError> {
4918 let candidates = [
4919 output_prefix.with_extension("json"),
4920 output_prefix.to_path_buf(),
4921 ];
4922
4923 for path in candidates {
4924 if path.exists() {
4925 let contents = std::fs::read_to_string(path)?;
4926 if let Some(text) = extract_text_from_payload(&contents) {
4927 return Ok(Some(text));
4928 }
4929 }
4930 }
4931
4932 Ok(None)
4933}
4934
4935#[allow(dead_code)] fn extract_text_from_payload(payload: &str) -> Option<String> {
4937 let value: serde_json::Value = serde_json::from_str(payload).ok()?;
4938 if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
4939 return Some(text.to_string());
4940 }
4941 if let Some(transcripts) = value.get("transcripts").and_then(|v| v.as_array()) {
4942 let joined = transcripts
4943 .iter()
4944 .filter_map(|item| item.get("text").and_then(|v| v.as_str()))
4945 .collect::<Vec<_>>()
4946 .join("\n");
4947 if !joined.is_empty() {
4948 return Some(joined);
4949 }
4950 }
4951 if let Some(items) = value.as_array() {
4952 let joined = items
4953 .iter()
4954 .filter_map(|item| {
4955 item.get("text")
4956 .or_else(|| item.get("Content"))
4957 .and_then(|v| v.as_str())
4958 })
4959 .collect::<Vec<_>>()
4960 .join(" ");
4961 if !joined.is_empty() {
4962 return Some(joined);
4963 }
4964 }
4965 None
4966}
4967
4968#[allow(dead_code)] fn find_audio_file(output_dir: &Path) -> Result<Option<PathBuf>, InferenceError> {
4970 let mut audio_files = Vec::new();
4971 collect_audio_files(output_dir, &mut audio_files)?;
4972 audio_files.sort();
4973 Ok(audio_files.into_iter().next())
4974}
4975
4976#[allow(dead_code)] fn collect_audio_files(dir: &Path, audio_files: &mut Vec<PathBuf>) -> Result<(), InferenceError> {
4978 for entry in std::fs::read_dir(dir)? {
4979 let path = entry?.path();
4980 if path.is_dir() {
4981 collect_audio_files(&path, audio_files)?;
4982 } else if matches!(
4983 path.extension().and_then(|ext| ext.to_str()),
4984 Some("wav" | "mp3" | "flac" | "pcm" | "m4a")
4985 ) {
4986 audio_files.push(path);
4987 }
4988 }
4989 Ok(())
4990}
4991
/// Maps an audio format name (ASCII case-insensitive) to its media type.
/// Unrecognised formats fall back to `audio/wav`.
fn media_type_for_format(format: &str) -> String {
    const MEDIA_TYPES: [(&str, &str); 4] = [
        ("mp3", "audio/mpeg"),
        ("flac", "audio/flac"),
        ("pcm", "audio/L16"),
        ("m4a", "audio/mp4"),
    ];

    for (extension, media_type) in MEDIA_TYPES {
        if format.eq_ignore_ascii_case(extension) {
            return media_type.to_string();
        }
    }
    "audio/wav".to_string()
}
5001
/// Translates a language name or tag (ASCII case-insensitive) into the
/// single-letter code this module passes to Kokoro. A missing or unrecognised
/// language yields `"a"`.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_lang_code(language: Option<&str>) -> &'static str {
    const CODES: [(&[&str], &str); 5] = [
        (&["en-gb", "british", "british english"], "b"),
        (&["ja", "japanese"], "j"),
        (&["zh", "zh-cn", "mandarin", "chinese"], "z"),
        (&["es", "spanish"], "e"),
        (&["fr", "french"], "f"),
    ];

    let requested = language.unwrap_or("en");
    for (aliases, code) in CODES {
        if aliases
            .iter()
            .any(|alias| requested.eq_ignore_ascii_case(alias))
        {
            return code;
        }
    }
    "a"
}
5013
/// Normalises a language name or tag (ASCII case-insensitive) to one of the
/// two-letter codes `en`, `es`, `fr`, `ja`, `zh`. Anything unrecognised
/// becomes `en`.
#[allow(dead_code)]
fn normalize_lang_code(language: &str) -> String {
    let lowered = language.to_ascii_lowercase();
    let code = match lowered.as_str() {
        "es" | "spanish" => "es",
        "fr" | "french" => "fr",
        "ja" | "japanese" => "ja",
        "zh" | "chinese" | "mandarin" => "zh",
        // Covers "en", "english", "en-us", "en_us" and every unknown input.
        _ => "en",
    };
    code.to_string()
}
5028
5029fn elevenlabs_auth(schema: &ModelSchema) -> Result<(String, String), InferenceError> {
5030 match &schema.source {
5031 ModelSource::Proprietary {
5032 endpoint,
5033 auth: schema::ProprietaryAuth::ApiKeyEnv { env_var },
5034 ..
5035 } => {
5036 let key = car_secrets::resolve_env_or_keychain(env_var).ok_or_else(|| {
5037 InferenceError::InferenceFailed(format!(
5038 "missing API key {env_var}; set the environment variable or \
5039 store it with `car secrets put {env_var}`"
5040 ))
5041 })?;
5042 Ok((endpoint.clone(), key))
5043 }
5044 _ => Err(InferenceError::InferenceFailed(format!(
5045 "model {} is not an ElevenLabs proprietary model",
5046 schema.id
5047 ))),
5048 }
5049}
5050
/// Maps an audio format name (ASCII case-insensitive) to the output-format
/// identifier sent to ElevenLabs. Anything other than `mp3` or `pcm` maps to
/// the WAV identifier.
fn elevenlabs_output_format(format: &str) -> &'static str {
    if format.eq_ignore_ascii_case("mp3") {
        "mp3_44100_128"
    } else if format.eq_ignore_ascii_case("pcm") {
        "pcm_16000"
    } else {
        "wav_44100"
    }
}
5058
/// Lists the candidate locations of `benchmark_priors.json`, without
/// duplicates, in this order:
///
/// 1. inside `models_dir`;
/// 2. inside `models_dir`'s parent, when it has one;
/// 3. the path named by `CAR_BENCHMARK_PRIORS_PATH`, when set.
fn benchmark_priors_paths(models_dir: &Path) -> Vec<PathBuf> {
    const FILE_NAME: &str = "benchmark_priors.json";

    let candidates = [
        Some(models_dir.join(FILE_NAME)),
        models_dir.parent().map(|parent| parent.join(FILE_NAME)),
        std::env::var_os("CAR_BENCHMARK_PRIORS_PATH").map(PathBuf::from),
    ];

    let mut paths: Vec<PathBuf> = Vec::with_capacity(candidates.len());
    for candidate in candidates {
        let Some(candidate) = candidate else {
            continue;
        };
        if !paths.contains(&candidate) {
            paths.push(candidate);
        }
    }
    paths
}
5083
5084fn load_benchmark_prior_health(
5085 models_dir: &Path,
5086 schemas: &[ModelSchema],
5087) -> Vec<ModelBenchmarkPriorHealth> {
5088 let mut priors = std::collections::BTreeMap::new();
5089 for path in benchmark_priors_paths(models_dir) {
5090 let Ok(loaded) = routing_ext::load_benchmark_priors(&path) else {
5091 continue;
5092 };
5093 for (model_id, prior) in loaded {
5094 let model_name = schemas
5095 .iter()
5096 .find(|schema| schema.id == model_id)
5097 .map(|schema| schema.name.clone());
5098 priors.insert(
5099 model_id.clone(),
5100 ModelBenchmarkPriorHealth {
5101 model_id,
5102 model_name,
5103 overall_score: prior.overall_score,
5104 overall_latency_ms: prior.overall_latency_ms,
5105 task_scores: prior.task_scores,
5106 task_latency_ms: prior.task_latency_ms,
5107 source_path: path.clone(),
5108 },
5109 );
5110 }
5111 }
5112
5113 priors.into_values().collect()
5114}
5115
/// Reports whether the Kokoro runtime fallback is enabled.
///
/// The fallback is on unless `CAR_SPEECH_KOKORO_FALLBACK` is set to `0`,
/// `false`, or `off` (surrounding whitespace and ASCII case are ignored).
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn kokoro_runtime_fallback_enabled() -> bool {
    let Ok(raw) = std::env::var("CAR_SPEECH_KOKORO_FALLBACK") else {
        return true;
    };
    let setting = raw.trim().to_ascii_lowercase();
    !["0", "false", "off"].contains(&setting.as_str())
}
5128
/// Returns the package spec for `mlx-audio`: a non-blank
/// `CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC` verbatim, otherwise the pinned default.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_mlx_audio_spec() -> String {
    match std::env::var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => "mlx-audio==0.4.2".to_string(),
    }
}
5136
/// Returns the package spec for the spaCy English model: a non-blank
/// `CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC` verbatim, otherwise the pinned wheel
/// URL.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn speech_runtime_spacy_model_spec() -> String {
    match std::env::var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => {
            "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl".to_string()
        }
    }
}
5146
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn no_backend_hint_fires_on_missing_backend_phrases() {
        for phrase in [
            "no credential for proprietary provider 'parslee'",
            "model not found",
            "no models available",
            "model declares ModelSource::Delegated but no inference runner is registered",
        ] {
            let hint = no_backend_recovery_hint(phrase)
                .unwrap_or_else(|| panic!("expected a hint for {phrase:?}"));
            assert!(hint.contains("car models pull"));
            assert!(hint.contains("car auth login parslee"));
            // The original error must be echoed so it is not lost.
            assert!(hint.contains(phrase));
        }
    }

    #[test]
    fn no_backend_hint_passes_through_transient_errors() {
        for phrase in [
            "API returned 401 Unauthorized",
            "API returned 429 Too Many Requests",
            "API returned 500 Internal Server Error",
            "connection refused",
            "request timed out",
            "parse response: unexpected end of input",
        ] {
            assert!(
                no_backend_recovery_hint(phrase).is_none(),
                "transient error wrongly classified as no-backend: {phrase:?}"
            );
        }
    }

    // Serializes every test that mutates process environment variables.
    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Owns one environment variable for the duration of a test.
    ///
    /// Creating the guard takes `ENV_MUTEX` and clears the variable; dropping
    /// it clears the variable again and only then releases the lock. Cleanup
    /// therefore also happens when an assertion panics, so a failing test
    /// cannot leak its variable into later tests.
    struct ScopedEnvVar {
        key: &'static str,
        // Dropped after `Drop::drop` has run, i.e. after the final cleanup.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl ScopedEnvVar {
        /// Locks the environment and removes `key`.
        fn cleared(key: &'static str) -> Self {
            // A poisoned mutex only means another env test failed; the `()`
            // payload carries no state, so recover instead of cascading.
            let lock = ENV_MUTEX
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            // SAFETY: `ENV_MUTEX` is held, so no other env-mutating test runs
            // concurrently with this write.
            unsafe {
                std::env::remove_var(key);
            }
            Self { key, _lock: lock }
        }

        /// Sets the guarded variable to `value`.
        fn set(&self, value: impl AsRef<std::ffi::OsStr>) {
            // SAFETY: `self` holds `ENV_MUTEX` for its whole lifetime.
            unsafe {
                std::env::set_var(self.key, value);
            }
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            // SAFETY: the lock field is still alive while `drop` runs.
            unsafe {
                std::env::remove_var(self.key);
            }
        }
    }

    fn test_config(models_dir: PathBuf) -> InferenceConfig {
        InferenceConfig {
            models_dir,
            device: None,
            generation_model: "Qwen3-0.6B".into(),
            preferred_generation_model: None,
            embedding_model: "Qwen3-Embedding-0.6B".into(),
            preferred_embedding_model: None,
            classification_model: "Qwen3-0.6B".into(),
            preferred_classification_model: None,
        }
    }

    #[tokio::test]
    async fn tokenize_rejects_known_remote_model_with_unsupported_mode() {
        let tmp = TempDir::new().unwrap();
        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let remote_id = engine
            .list_schemas()
            .into_iter()
            .find(|s| !s.is_local())
            .map(|s| s.id)
            .expect("built-in catalog should include at least one remote model schema");

        let err = engine
            .tokenize(&remote_id, "hello")
            .await
            .expect_err("remote tokenize must error");
        match err {
            InferenceError::UnsupportedMode { mode, backend, .. } => {
                assert_eq!(mode, "tokenize/detokenize");
                assert_eq!(backend, "remote");
            }
            other => panic!("expected UnsupportedMode, got {other:?}"),
        }

        let err = engine
            .detokenize(&remote_id, &[1, 2, 3])
            .await
            .expect_err("remote detokenize must error");
        assert!(
            matches!(err, InferenceError::UnsupportedMode { .. }),
            "expected UnsupportedMode, got {err:?}"
        );
    }

    #[test]
    fn engine_loads_benchmark_priors_on_startup() {
        let priors_env = ScopedEnvVar::cleared("CAR_BENCHMARK_PRIORS_PATH");
        let tmp = TempDir::new().unwrap();
        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.88
            })
            .to_string(),
        )
        .unwrap();

        priors_env.set(&priors_path);

        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("benchmark prior should create a profile");
        assert!((profile.ema_quality - 0.88).abs() < 0.01);
    }

    #[test]
    fn benchmark_priors_do_not_override_observed_profiles() {
        let priors_env = ScopedEnvVar::cleared("CAR_BENCHMARK_PRIORS_PATH");
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        std::fs::create_dir_all(&models_dir).unwrap();

        let observed = vec![ModelProfile {
            model_id: "qwen/qwen3-8b:q4_k_m".into(),
            total_calls: 12,
            success_count: 3,
            fail_count: 9,
            total_latency_ms: 1200,
            total_input_tokens: 0,
            total_output_tokens: 0,
            task_stats: std::collections::HashMap::new(),
            ema_quality: 0.21,
            quality_per_1k_tokens: 0.0,
            updated_at: 1,
        }];
        std::fs::write(
            models_dir.join("outcome_profiles.json"),
            serde_json::to_string(&observed).unwrap(),
        )
        .unwrap();

        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.95
            })
            .to_string(),
        )
        .unwrap();

        priors_env.set(&priors_path);

        let engine = InferenceEngine::new(test_config(models_dir));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("observed profile should remain present");
        // The observed values must survive; the 0.95 prior must not replace them.
        assert!((profile.ema_quality - 0.21).abs() < 0.01);
        assert_eq!(profile.total_calls, 12);
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn speech_runtime_package_spec_defaults_and_overrides() {
        let spec_env = ScopedEnvVar::cleared("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.2");

        spec_env.set("mlx-audio==0.4.1");
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.1");
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn speech_runtime_spacy_model_spec_defaults_and_overrides() {
        let spec_env = ScopedEnvVar::cleared("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
        assert!(
            speech_runtime_spacy_model_spec().starts_with("en-core-web-sm @ https://github.com/")
        );

        spec_env.set("en-core-web-sm==3.8.0");
        assert_eq!(speech_runtime_spacy_model_spec(), "en-core-web-sm==3.8.0");
    }

    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn kokoro_runtime_fallback_defaults_on() {
        // Takes the shared env lock like every other env-mutating test here.
        let fallback_env = ScopedEnvVar::cleared("CAR_SPEECH_KOKORO_FALLBACK");
        assert!(kokoro_runtime_fallback_enabled());

        fallback_env.set("false");
        assert!(!kokoro_runtime_fallback_enabled());
    }

    #[test]
    fn preferred_local_tts_wins_over_builtin_rank() {
        let tmp = TempDir::new().unwrap();
        let mut engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        engine.set_speech_policy(SpeechPolicy {
            prefer_local: true,
            allow_remote_fallback: false,
            preferred_local_stt: None,
            preferred_local_tts: Some("Kokoro-82M-6bit".into()),
            preferred_remote_stt: None,
            preferred_remote_tts: None,
        });

        let schema = engine
            .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
            .expect("preferred local TTS should resolve");
        assert_eq!(schema.name, "Kokoro-82M-6bit");
    }

    #[test]
    fn preferred_discovered_vllm_mlx_model_wins_generate_routing() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(tmp.path().join("models"));
        config.preferred_generation_model =
            Some("vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit".into());
        let mut engine = InferenceEngine::new(config);
        let schema = crate::vllm_mlx::to_model_schema(
            &crate::vllm_mlx::DiscoveredModel {
                id: "mlx-community/gemma-3n-E2B-it-lm-4bit".into(),
                owned_by: Some("mlx-community".into()),
            },
            "http://127.0.0.1:8001",
        );
        engine.register_model(schema);

        let rt = tokio::runtime::Runtime::new().unwrap();
        let decision = rt.block_on(engine.route_adaptive("say hello in one sentence"));
        assert_eq!(
            decision.model_id,
            "vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit"
        );
        assert_eq!(decision.strategy, RoutingStrategy::Explicit);
        assert_eq!(decision.reason, "preferred generation model override");
    }

    #[test]
    fn inference_result_serializes_with_full_shape() {
        use crate::tasks::generate::ToolCall;
        use std::collections::HashMap;

        let mut args = HashMap::new();
        args.insert("path".to_string(), serde_json::json!("README.md"));

        let result = InferenceResult {
            text: String::new(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![ToolCall {
                id: None,
                name: "read_file".into(),
                arguments: args,
            }],
            trace_id: "trace-abc".into(),
            model_used: "test-model".into(),
            latency_ms: 1234,
            time_to_first_token_ms: Some(180),
            usage: Some(TokenUsage {
                prompt_tokens: 100,
                completion_tokens: 50,
                total_tokens: 150,
                context_window: 8192,
            }),
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");

        // Scalar fields.
        assert_eq!(json["text"].as_str(), Some(""));
        assert_eq!(json["trace_id"].as_str(), Some("trace-abc"));
        assert_eq!(json["model_used"].as_str(), Some("test-model"));
        assert_eq!(json["latency_ms"].as_u64(), Some(1234));

        // Tool calls keep their name and structured arguments.
        let tool_calls = json["tool_calls"].as_array().expect("tool_calls array");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0]["name"].as_str(), Some("read_file"));
        assert_eq!(
            tool_calls[0]["arguments"]["path"].as_str(),
            Some("README.md")
        );

        // Token usage is a nested object.
        let usage = &json["usage"];
        assert_eq!(usage["prompt_tokens"].as_u64(), Some(100));
        assert_eq!(usage["completion_tokens"].as_u64(), Some(50));
        assert_eq!(usage["total_tokens"].as_u64(), Some(150));
        assert_eq!(usage["context_window"].as_u64(), Some(8192));

        assert_eq!(json["time_to_first_token_ms"].as_u64(), Some(180));
    }

    #[test]
    fn inference_result_top_level_keys_are_locked() {
        use std::collections::BTreeSet;

        let result = InferenceResult {
            text: "anything".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "t".into(),
            model_used: "m".into(),
            latency_ms: 0,
            time_to_first_token_ms: None,
            usage: None,
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");
        let keys: BTreeSet<&str> = json
            .as_object()
            .expect("top-level object")
            .keys()
            .map(String::as_str)
            .collect();

        // NOTE(review): `bounding_boxes` and `provider_output_items` are
        // absent here, presumably because they are skipped when empty —
        // confirm against the serde attributes on `InferenceResult`.
        let expected: BTreeSet<&str> = [
            "text",
            "tool_calls",
            "trace_id",
            "model_used",
            "latency_ms",
            "time_to_first_token_ms",
            "usage",
        ]
        .into_iter()
        .collect();

        assert_eq!(
            keys, expected,
            "infer response top-level keys drifted -- update both the test \
             and the WebSocket protocol documentation if this is intentional"
        );

        // Every key must be snake_case.
        for key in &keys {
            assert!(
                !key.chars().any(|c| c.is_uppercase()) && !key.contains('-'),
                "key '{}' is not snake_case",
                key
            );
        }
    }

    #[test]
    fn inference_result_serializes_plain_text_response() {
        let result = InferenceResult {
            text: "hello world".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "trace-xyz".into(),
            model_used: "test-model".into(),
            latency_ms: 42,
            time_to_first_token_ms: None,
            usage: None,
            provider_output_items: Vec::new(),
        };

        let json = serde_json::to_value(&result).expect("serialize");
        assert_eq!(json["text"], "hello world");
        assert!(json["tool_calls"].is_array());
        assert_eq!(json["tool_calls"].as_array().unwrap().len(), 0);
        assert_eq!(json["model_used"], "test-model");
        // Absent optionals serialize as explicit nulls rather than vanishing.
        assert!(json["usage"].is_null());
        assert!(json["time_to_first_token_ms"].is_null());
    }

    #[test]
    fn generate_request_deserializes_intent_field_from_json_rpc_params() {
        use crate::intent::{IntentHint, TaskHint};
        use crate::schema::ModelCapability;

        let params = serde_json::json!({
            "prompt": "summarize this email",
            "intent": {
                "task": "chat",
                "prefer_local": true,
                "require": ["tool_use"],
            },
        });

        let req: GenerateRequest =
            serde_json::from_value(params).expect("GenerateRequest deserialize");

        let intent = req.intent.as_ref().expect("intent field deserialized");
        assert_eq!(intent.task, Some(TaskHint::Chat));
        assert!(intent.prefer_local);
        assert_eq!(intent.require, vec![ModelCapability::ToolUse]);

        // Round trip: the same wire names come back out.
        let back: serde_json::Value =
            serde_json::to_value(&req).expect("re-serialize GenerateRequest");
        assert_eq!(back["intent"]["task"], "chat");
        assert_eq!(back["intent"]["prefer_local"], true);
        assert_eq!(back["intent"]["require"][0], "tool_use");

        // An empty intent object is present but carries only defaults.
        let default_req: GenerateRequest = serde_json::from_value(serde_json::json!({
            "prompt": "x",
            "intent": {},
        }))
        .unwrap();
        let default_intent = default_req.intent.expect("present but empty");
        assert_eq!(default_intent.task, None);
        assert!(!default_intent.prefer_local);
        assert!(default_intent.require.is_empty());

        // Omitting the field entirely leaves it `None`.
        let no_intent: GenerateRequest =
            serde_json::from_value(serde_json::json!({"prompt": "x"})).unwrap();
        assert!(no_intent.intent.is_none());
    }

    #[test]
    fn rerank_prompt_matches_upstream_template_shape() {
        let p = rerank_prompt(
            "retrieve relevant passages",
            "who runs the treasury?",
            "doc x",
        );
        assert!(p.contains("<|im_start|>system"));
        assert!(p.contains("Note that the answer can only be \"yes\" or \"no\"."));
        assert!(p.contains("<|im_start|>user\n<Instruct>: retrieve relevant passages"));
        assert!(p.contains("<Query>: who runs the treasury?"));
        assert!(p.contains("<Document>: doc x<|im_end|>"));
        assert!(p.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    #[test]
    fn rerank_score_yes_and_no_exactly() {
        assert_eq!(score_from_rerank_output("yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("no", "m"), 0.0);
    }

    #[test]
    fn rerank_score_handles_case_leading_space_and_chat_sentinels() {
        assert_eq!(score_from_rerank_output(" Yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("\nno.", "m"), 0.0);
        assert_eq!(score_from_rerank_output("<|im_end|>yes", "m"), 1.0);
    }

    #[test]
    fn rerank_score_scans_up_to_three_tokens() {
        // `_bos_` normalizes to the token `bos`, so `yes` is the second token
        // scanned and must still be found.
        assert_eq!(score_from_rerank_output("_bos_ yes", "m"), 1.0);
    }

    #[test]
    fn rerank_score_unexpected_is_neutral() {
        assert_eq!(score_from_rerank_output("maybe", "m"), 0.5);
        assert_eq!(score_from_rerank_output("", "m"), 0.5);
    }
}