pub mod adaptive_router;
pub mod backend;
pub mod backend_cache;
pub mod hardware;
pub mod intent;
pub mod key_pool;
pub mod models;
pub mod outcome;
pub mod protocol;
pub mod registry;
pub mod remote;
pub mod router;
pub mod runner;
pub mod routing_ext;
pub mod schema;
pub mod service;
pub mod stream;
pub mod tasks;
pub mod vllm_mlx;

use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;
use std::time::{SystemTime, UNIX_EPOCH};

use serde::Serialize;
use reqwest::multipart::{Form, Part};
use thiserror::Error;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use tokio::process::Command;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tracing::{debug, instrument};

pub use adaptive_router::{
    AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy,
};
pub use intent::{IntentHint, TaskHint};
pub use key_pool::{KeyPool, KeyStats};
pub use outcome::{
    CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
};
pub use registry::{ModelFilter, ModelInfo, ModelRuntimeRequirement, ModelUpgrade, UnifiedRegistry};
pub use remote::RemoteBackend;
pub use runner::{
    current_inference_runner, set_inference_runner, EventEmitter, InferenceRunner, RunnerError,
    RunnerResult,
};
pub use routing_ext::{
    CircuitBreaker, CircuitBreakerRegistry, CircuitState, ImplicitSignal, ImplicitSignalType,
    RoutingMode, SpendControl, SpendLimitExceeded, SpendLimits, SpendStatus,
};
pub use schema::{
    ApiProtocol, BenchmarkScore, CostModel, ModelCapability, ModelSchema, ModelSource,
    PerformanceEnvelope, ProprietaryAuth,
};

pub use adaptive_router::TaskComplexity;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub use backend::CandleBackend;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub use backend::EmbeddingBackend;
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
pub use backend::MlxBackend;
pub use hardware::HardwareInfo;
pub use models::{ModelRegistry, ModelRole};
pub use router::{ModelRouter, RoutingDecision};
pub use stream::{StreamAccumulator, StreamEvent};
pub use tasks::{
    parse_boxes, BoundingBox, ClassifyRequest, ClassifyResult, ContentBlock, EmbedRequest,
    GenerateImageRequest, GenerateImageResult, GenerateParams, GenerateRequest,
    GenerateVideoRequest, GenerateVideoResult, GroundRequest, GroundResult, Message, RerankRequest,
    RerankResult, RerankedDocument, RoutingWorkload, SynthesizeRequest, SynthesizeResult,
    ThinkingMode, ToolCall, TranscribeRequest, TranscribeResult, VideoMode,
};
#[derive(Error, Debug)]
pub enum InferenceError {
    #[error("model not found: {0}")]
    ModelNotFound(String),

    #[error("model download failed: {0}")]
    DownloadFailed(String),

    #[error("inference failed: {0}")]
    InferenceFailed(String),

    #[error("mode {mode} not implemented on backend {backend}: {reason}")]
    UnsupportedMode {
        mode: &'static str,
        backend: &'static str,
        reason: &'static str,
    },

    #[error("tokenization error: {0}")]
    TokenizationError(String),

    #[error("device error: {0}")]
    DeviceError(String),

    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Metal,
    Cuda(usize),
}

impl Device {
    pub fn auto() -> Self {
        #[cfg(feature = "metal")]
        {
            return Device::Metal;
        }
        #[cfg(feature = "cuda")]
        {
            return Device::Cuda(0);
        }
        #[cfg(not(any(feature = "metal", feature = "cuda")))]
        {
            Device::Cpu
        }
    }
}

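/// A minimal construction sketch (illustrative, not from the original docs):
/// forcing a specific device instead of relying on `Device::auto()` feature
/// detection, while keeping the detected defaults for everything else.
///
/// ```ignore
/// let config = InferenceConfig {
///     device: Some(Device::Metal),
///     ..InferenceConfig::default()
/// };
/// let engine = InferenceEngine::new(config);
/// ```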
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    pub models_dir: std::path::PathBuf,
    pub device: Option<Device>,
    pub generation_model: String,
    pub preferred_generation_model: Option<String>,
    pub embedding_model: String,
    pub preferred_embedding_model: Option<String>,
    pub classification_model: String,
    pub preferred_classification_model: Option<String>,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        let models_dir = dirs_next()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join(".car")
            .join("models");

        let hw = HardwareInfo::detect();

        Self {
            models_dir,
            device: None,
            generation_model: hw.recommended_model,
            preferred_generation_model: None,
            embedding_model: "Qwen3-Embedding-0.6B".to_string(),
            preferred_embedding_model: None,
            classification_model: "Qwen3-0.6B".to_string(),
            preferred_classification_model: None,
        }
    }
}

fn dirs_next() -> Option<std::path::PathBuf> {
    std::env::var("HOME").ok().map(std::path::PathBuf::from)
}

#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
    pub context_window: u64,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InferenceResult {
    pub text: String,
    pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub bounding_boxes: Vec<crate::tasks::grounding::BoundingBox>,
    pub trace_id: String,
    pub model_used: String,
    pub latency_ms: u64,
    #[serde(default)]
    pub time_to_first_token_ms: Option<u64>,
    pub usage: Option<TokenUsage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub provider_output_items: Vec<serde_json::Value>,
}

impl InferenceResult {
    pub fn has_tool_calls(&self) -> bool {
        !self.tool_calls.is_empty()
    }
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechRuntimeHealth {
    pub root: PathBuf,
    pub installed: bool,
    pub python: PathBuf,
    pub stt_command: PathBuf,
    pub tts_command: PathBuf,
    pub configured_python: Option<String>,
    pub detected_python: Option<String>,
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechModelHealth {
    pub id: String,
    pub name: String,
    pub provider: String,
    pub capability: ModelCapability,
    pub is_local: bool,
    pub available: bool,
    pub cached: bool,
    pub selected_by_default: bool,
    pub source: String,
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechHealthReport {
    pub runtime: SpeechRuntimeHealth,
    pub local_models: Vec<SpeechModelHealth>,
    pub remote_models: Vec<SpeechModelHealth>,
    pub elevenlabs_configured: bool,
    pub prefer_local: bool,
    pub allow_remote_fallback: bool,
    pub preferred_local_stt: Option<String>,
    pub preferred_local_tts: Option<String>,
    pub preferred_remote_stt: Option<String>,
    pub preferred_remote_tts: Option<String>,
    pub local_stt_default: Option<String>,
    pub local_tts_default: Option<String>,
    pub remote_stt_default: Option<String>,
    pub remote_tts_default: Option<String>,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelDefaultHealth {
    pub capability: ModelCapability,
    pub configured_model: String,
    pub available: bool,
    pub is_local: bool,
    pub provider: Option<String>,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelProviderHealth {
    pub provider: String,
    pub configured: bool,
    pub local_models: usize,
    pub remote_models: usize,
    pub available_models: usize,
    pub capabilities: Vec<ModelCapability>,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelCapabilityHealth {
    pub capability: ModelCapability,
    pub total_models: usize,
    pub available_models: usize,
    pub local_available_models: usize,
    pub remote_available_models: usize,
}

#[derive(Debug, Clone, Serialize)]
pub struct RoutingScenarioHealth {
    pub name: String,
    pub workload: RoutingWorkload,
    pub task_family: String,
    pub has_tools: bool,
    pub has_vision: bool,
    pub prefer_local: bool,
    pub quality_first_cold_start: bool,
    pub bootstrap_min_task_observations: u64,
    pub bootstrap_quality_floor: f64,
    pub model_id: String,
    pub model_name: String,
    pub reason: String,
    pub strategy: RoutingStrategy,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelBenchmarkPriorHealth {
    pub model_id: String,
    pub model_name: Option<String>,
    pub overall_score: f64,
    pub overall_latency_ms: Option<f64>,
    pub task_scores: std::collections::HashMap<String, f64>,
    pub task_latency_ms: std::collections::HashMap<String, f64>,
    pub source_path: PathBuf,
}

#[derive(Debug, Clone, Serialize)]
pub struct ModelHealthReport {
    pub total_models: usize,
    pub available_models: usize,
    pub local_models: usize,
    pub remote_models: usize,
    pub defaults: Vec<ModelDefaultHealth>,
    pub providers: Vec<ModelProviderHealth>,
    pub capabilities: Vec<ModelCapabilityHealth>,
    pub routing_prefer_local: bool,
    pub routing_quality_first_cold_start: bool,
    pub routing_min_observations: u64,
    pub routing_bootstrap_min_task_observations: u64,
    pub routing_bootstrap_quality_floor: f64,
    pub routing_quality_weight: f64,
    pub routing_latency_weight: f64,
    pub routing_cost_weight: f64,
    pub routing_scenarios: Vec<RoutingScenarioHealth>,
    pub benchmark_priors: Vec<ModelBenchmarkPriorHealth>,
    pub speech: SpeechHealthReport,
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechInstallReport {
    pub name: String,
    pub hf_repo: String,
    pub snapshot_path: PathBuf,
    pub files_downloaded: usize,
}

#[derive(Debug, Clone, Serialize)]
pub struct SpeechSmokePathReport {
    pub path: String,
    pub tts_model: String,
    pub stt_model: String,
    pub audio_path: PathBuf,
    pub transcript: String,
}

#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechSmokeReport {
    pub local: Option<SpeechSmokePathReport>,
    pub remote: Option<SpeechSmokePathReport>,
    pub skipped: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Default)]
pub struct SpeechPolicy {
    pub prefer_local: bool,
    pub allow_remote_fallback: bool,
    pub preferred_local_stt: Option<String>,
    pub preferred_local_tts: Option<String>,
    pub preferred_remote_stt: Option<String>,
    pub preferred_remote_tts: Option<String>,
}

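/// The central inference engine: owns the model registries, routers, outcome
/// tracker, and per-platform backends.
///
/// A minimal usage sketch (illustrative only; assumes `GenerateRequest`
/// implements `Default` and that setting the prompt is enough for a simple call):
///
/// ```ignore
/// let engine = InferenceEngine::new(InferenceConfig::default());
/// engine.init_key_pool().await;
/// let req = GenerateRequest {
///     prompt: "Summarize the release notes".to_string(),
///     ..Default::default()
/// };
/// let result = engine.generate_tracked(req).await?;
/// println!("{} (via {})", result.text, result.model_used);
/// ```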
pub struct InferenceEngine {
    pub config: InferenceConfig,
    pub unified_registry: UnifiedRegistry,
    pub adaptive_router: AdaptiveRouter,
    pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
    remote_backend: RemoteBackend,
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    mlx_backends: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    flux_cache: Arc<backend_cache::BackendCache<backend::mlx_flux::FluxBackend>>,
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    ltx_cache: Arc<backend_cache::BackendCache<backend::mlx_ltx::LtxBackend>>,
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    kokoro_cache: Arc<backend_cache::BackendCache<backend::mlx_kokoro::KokoroBackend>>,
    pub registry: models::ModelRegistry,
    pub router: ModelRouter,
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    backend: Arc<RwLock<Option<CandleBackend>>>,
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    speech_runtime: Arc<Mutex<Option<SpeechRuntime>>>,
    speech_policy: SpeechPolicy,
}

impl InferenceEngine {
    fn preferred_model_for_capability(&self, capability: ModelCapability) -> Option<&str> {
        match capability {
            ModelCapability::Generate => self.config.preferred_generation_model.as_deref(),
            ModelCapability::Embed => self.config.preferred_embedding_model.as_deref(),
            ModelCapability::Classify => self.config.preferred_classification_model.as_deref(),
            _ => None,
        }
    }

    fn request_needs_vision(req: &GenerateRequest) -> bool {
        req.images.as_ref().is_some_and(|images| !images.is_empty())
            || req.messages.as_ref().is_some_and(|messages| {
                messages
                    .iter()
                    .any(|msg| matches!(msg, Message::UserMultimodal { .. }))
            })
    }

    fn request_has_video(req: &GenerateRequest) -> bool {
        let images_have_video = req
            .images
            .as_ref()
            .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
        let messages_have_video = req.messages.as_ref().is_some_and(|messages| {
            messages.iter().any(|msg| match msg {
                Message::UserMultimodal { content } => {
                    content.iter().any(ContentBlock::is_video)
                }
                _ => false,
            })
        });
        images_have_video || messages_have_video
    }

    fn request_has_audio(req: &GenerateRequest) -> bool {
        let images_have_audio = req
            .images
            .as_ref()
            .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
        let messages_have_audio = req.messages.as_ref().is_some_and(|messages| {
            messages.iter().any(|msg| match msg {
                Message::UserMultimodal { content } => {
                    content.iter().any(ContentBlock::is_audio)
                }
                _ => false,
            })
        });
        images_have_audio || messages_have_audio
    }

    pub fn new(config: InferenceConfig) -> Self {
        let registry = models::ModelRegistry::new(config.models_dir.clone());
        let hw = HardwareInfo::detect();
        let router = ModelRouter::new(hw.clone());
        let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
        let adaptive_router = AdaptiveRouter::with_default_config(hw);
        let mut tracker = OutcomeTracker::new();
        let profiles_path = config.models_dir.join("outcome_profiles.json");
        if let Ok(n) = tracker.load_from_file(&profiles_path) {
            if n > 0 {
                tracing::info!(loaded = n, "loaded persisted model profiles");
            }
        }
        let mut benchmark_models_loaded = 0usize;
        for path in benchmark_priors_paths(&config.models_dir) {
            match routing_ext::load_benchmark_priors(&path) {
                Ok(priors) if !priors.is_empty() => {
                    benchmark_models_loaded += priors.len();
                    routing_ext::apply_benchmark_priors(&mut tracker, &priors);
                    tracing::info!(
                        path = %path.display(),
                        loaded = priors.len(),
                        "loaded benchmark quality priors"
                    );
                }
                Ok(_) => {}
                Err(error) => {
                    tracing::warn!(path = %path.display(), %error, "failed to load benchmark priors");
                }
            }
        }
        if benchmark_models_loaded > 0 {
            tracing::info!(
                loaded = benchmark_models_loaded,
                "applied benchmark priors to cold-start routing"
            );
        }
        let outcome_tracker = Arc::new(RwLock::new(tracker));

        let remote_backend = RemoteBackend::new();

        Self {
            config,
            unified_registry,
            adaptive_router,
            outcome_tracker,
            remote_backend,
            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            mlx_backends: Arc::new(backend_cache::BackendCache::from_env()),
            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            flux_cache: Arc::new(backend_cache::BackendCache::from_env()),
            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            ltx_cache: Arc::new(backend_cache::BackendCache::from_env()),
            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            kokoro_cache: Arc::new(backend_cache::BackendCache::from_env()),
            registry,
            router,
            #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
            backend: Arc::new(RwLock::new(None)),
            #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
            embedding_backend: Arc::new(RwLock::new(None)),
            #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
            speech_runtime: Arc::new(Mutex::new(None)),
            speech_policy: SpeechPolicy {
                prefer_local: cfg!(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))),
                allow_remote_fallback: true,
                preferred_local_stt: None,
                preferred_local_tts: None,
                preferred_remote_stt: None,
                preferred_remote_tts: None,
            },
        }
    }

    pub async fn init_key_pool(&self) {
        for schema in self.unified_registry.list() {
            if schema.is_remote() {
                self.remote_backend.register_model_keys(schema).await;
            }
        }

        let stats_path = self.config.models_dir.join("key_pool_stats.json");
        if let Ok(n) = self.remote_backend.key_pool.load_stats(&stats_path).await {
            if n > 0 {
                tracing::info!(loaded = n, "loaded persisted key pool stats");
            }
        }

        let total = self.remote_backend.key_pool.total_keys().await;
        if total > 0 {
            tracing::info!(keys = total, "key pool initialized");
        }
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
        let read = self.backend.read().await;
        if read.is_some() {
            return Ok(());
        }
        drop(read);

        let mut write = self.backend.write().await;
        if write.is_some() {
            return Ok(());
        }

        let model_path = self.registry.ensure_model(model_name).await?;
        let device = self.config.device.unwrap_or_else(Device::auto);
        let backend = CandleBackend::load(&model_path, device)?;
        *write = Some(backend);
        Ok(())
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
        let read = self.embedding_backend.read().await;
        if read.is_some() {
            return Ok(());
        }
        drop(read);

        let mut write = self.embedding_backend.write().await;
        if write.is_some() {
            return Ok(());
        }

        let embedding_model = self
            .preferred_model_for_capability(ModelCapability::Embed)
            .unwrap_or(&self.config.embedding_model);
        let model_path = self
            .registry
            .ensure_model(embedding_model)
            .await?;
        let device = self.config.device.unwrap_or_else(Device::auto);
        let backend = EmbeddingBackend::load(&model_path, device)?;
        *write = Some(backend);
        Ok(())
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    async fn ensure_mlx_embedding_backend(&self) -> Result<String, InferenceError> {
        let embedding_model_name = self
            .preferred_model_for_capability(ModelCapability::Embed)
            .unwrap_or(&self.config.embedding_model)
            .to_string();
        let schema = self
            .unified_registry
            .get(&embedding_model_name)
            .or_else(|| self.unified_registry.find_by_name(&embedding_model_name))
            .ok_or_else(|| InferenceError::ModelNotFound(embedding_model_name.clone()))?
            .clone();
        self.ensure_mlx_backend(&schema).await?;
        Ok(schema.id)
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    async fn ensure_mlx_backend(
        &self,
        schema: &ModelSchema,
    ) -> Result<backend_cache::CachedBackend<backend::MlxBackend>, InferenceError> {
        if !Self::supports_native_mlx(schema) {
            return Err(InferenceError::InferenceFailed(format!(
                "native MLX backend does not support {} ({}) yet; use vLLM-MLX or add a family-specific MLX backend",
                schema.name, schema.family
            )));
        }

        let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
        let size = backend_cache::estimate_model_size(&model_dir);
        let cache = Arc::clone(&self.mlx_backends);
        let key = schema.id.clone();
        cache.get_or_load(&key, size, move || {
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                backend::MlxBackend::load(&model_dir)
            }))
            .map_err(|e| {
                InferenceError::InferenceFailed(format!(
                    "MLX backend loading panicked (possible Metal/accelerate exception): {:?}",
                    e
                ))
            })?
        })
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    pub async fn warm_up<S: AsRef<str>>(&self, schema_ids: &[S]) -> Vec<Result<(), InferenceError>> {
        let mut results = Vec::with_capacity(schema_ids.len());
        for id in schema_ids {
            let id = id.as_ref();
            let outcome: Result<(), InferenceError> = async {
                let schema = self
                    .unified_registry
                    .get(id)
                    .cloned()
                    .ok_or_else(|| InferenceError::InferenceFailed(
                        format!("warm_up: unknown schema id {id}"),
                    ))?;
                match schema.capabilities.first().copied() {
                    Some(ModelCapability::ImageGeneration) => {
                        let model_dir =
                            self.unified_registry.ensure_local(&schema.id).await?;
                        let size = backend_cache::estimate_model_size(&model_dir);
                        let _ = self.flux_cache.get_or_load(&schema.id, size, || {
                            backend::mlx_flux::FluxBackend::load(&model_dir)
                        })?;
                    }
                    Some(ModelCapability::VideoGeneration) => {
                        let model_dir =
                            self.unified_registry.ensure_local(&schema.id).await?;
                        let size = backend_cache::estimate_model_size(&model_dir);
                        let _ = self.ltx_cache.get_or_load(&schema.id, size, || {
                            backend::mlx_ltx::LtxBackend::load(&model_dir)
                        })?;
                    }
                    Some(ModelCapability::TextToSpeech) => {
                        let model_dir =
                            self.unified_registry.ensure_local(&schema.id).await?;
                        let size = backend_cache::estimate_model_size(&model_dir);
                        let _ = self.kokoro_cache.get_or_load(&schema.id, size, || {
                            backend::mlx_kokoro::KokoroBackend::load(&model_dir)
                        })?;
                    }
                    _ => {
                        let _ = self.ensure_mlx_backend(&schema).await?;
                    }
                }
                Ok(())
            }
            .await;
            results.push(outcome);
        }
        results
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    pub async fn warm_up<S: AsRef<str>>(&self, _schema_ids: &[S]) -> Vec<Result<(), InferenceError>> {
        Vec::new()
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    fn supports_native_mlx(schema: &ModelSchema) -> bool {
        matches!(schema.family.as_str(), "qwen3" | "qwen2.5-vl" | "qwen2-vl")
    }

    pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
        if let Some(model) = self.preferred_model_for_capability(ModelCapability::Generate) {
            let ctx_len = self
                .unified_registry
                .get(model)
                .or_else(|| self.unified_registry.find_by_name(model))
                .map(|s| s.context_length)
                .unwrap_or(0);
            return AdaptiveRoutingDecision {
                model_id: model.to_string(),
                model_name: model.to_string(),
                task: InferenceTask::Generate,
                complexity: TaskComplexity::assess(prompt),
                reason: "preferred generation model override".into(),
                strategy: RoutingStrategy::Explicit,
                predicted_quality: 0.5,
                fallbacks: vec![],
                context_length: ctx_len,
                needs_compaction: false,
            };
        }
        let tracker = self.outcome_tracker.read().await;
        self.adaptive_router
            .route(prompt, &self.unified_registry, &tracker)
    }

    pub fn route(&self, prompt: &str) -> RoutingDecision {
        self.router.route_generate(prompt, &self.registry)
    }

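    /// Rough token accounting used for routing and compaction decisions
    /// (illustrative note, not from the original docs): with a 32,768-token
    /// context window, an estimated 30,000 input tokens and
    /// `max_tokens = 4,096`, `fits` is `false` because 30,000 + 4,096 > 32,768.
    /// A `context_window` of 0 means the window is unknown and the check is
    /// skipped.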
    pub fn estimated_tokens(
        &self,
        req: &GenerateRequest,
        model_id: Option<&str>,
    ) -> (usize, usize, bool) {
        let prompt_tokens = remote::estimate_tokens(&req.prompt);
        let context_tokens = req
            .context
            .as_ref()
            .map(|c| remote::estimate_tokens(c))
            .unwrap_or(0);
        let tools_tokens = req
            .tools
            .as_ref()
            .map(|t| remote::estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
            .unwrap_or(0);
        let total_input = prompt_tokens + context_tokens + tools_tokens;

        let context_window = model_id
            .and_then(|id| {
                self.unified_registry
                    .get(id)
                    .or_else(|| self.unified_registry.find_by_name(id))
            })
            .map(|s| s.context_length)
            .unwrap_or(0);

        let fits = context_window == 0 || (total_input + req.params.max_tokens) <= context_window;
        (total_input, context_window, fits)
    }

    #[instrument(
        name = "inference.generate",
        skip_all,
        fields(
            model = tracing::field::Empty,
            max_tokens = req.params.max_tokens,
            prompt_tokens = tracing::field::Empty,
            completion_tokens = tracing::field::Empty,
            latency_ms = tracing::field::Empty,
        )
    )]
    pub async fn generate_tracked(
        &self,
        req: GenerateRequest,
    ) -> Result<InferenceResult, InferenceError> {
        let start = Instant::now();

        let (estimated_input, _, _) = self.estimated_tokens(&req, None);
        let tracker_read = self.outcome_tracker.read().await;
        let has_tools = req.tools.is_some();
        let has_vision = Self::request_needs_vision(&req);
        let preferred_model = self
            .preferred_model_for_capability(ModelCapability::Generate)
            .map(str::to_string);
        let decision = match req.model.clone().or(preferred_model) {
            Some(m) => {
                let ctx_len = self
                    .unified_registry
                    .get(&m)
                    .or_else(|| self.unified_registry.find_by_name(&m))
                    .map(|s| s.context_length)
                    .unwrap_or(0);
                AdaptiveRoutingDecision {
                    model_id: m.clone(),
                    model_name: m.clone(),
                    task: InferenceTask::Generate,
                    complexity: TaskComplexity::assess(&req.prompt),
                    reason: "explicit model".into(),
                    strategy: RoutingStrategy::Explicit,
                    predicted_quality: 0.5,
                    fallbacks: vec![],
                    context_length: ctx_len,
                    needs_compaction: ctx_len > 0 && estimated_input > ctx_len,
                }
            }
            None => match &req.intent {
                Some(hint) => self.adaptive_router.route_context_aware_with_intent(
                    &req.prompt,
                    estimated_input,
                    &self.unified_registry,
                    &tracker_read,
                    has_tools,
                    has_vision,
                    req.params.workload,
                    hint,
                ),
                None => self.adaptive_router.route_context_aware(
                    &req.prompt,
                    estimated_input,
                    &self.unified_registry,
                    &tracker_read,
                    has_tools,
                    has_vision,
                    req.params.workload,
                ),
            },
        };
        drop(tracker_read);

        if decision.needs_compaction {
            tracing::info!(
                model = %decision.model_name,
                prompt_tokens = estimated_input,
                context_window = decision.context_length,
                "prompt exceeds model context window — compaction or truncation needed"
            );
        }

        let trace_id = {
            let mut tracker = self.outcome_tracker.write().await;
            tracker.record_start(&decision.model_id, decision.task, &decision.reason)
        };

        debug!(
            model = %decision.model_name,
            strategy = ?decision.strategy,
            reason = %decision.reason,
            trace = %trace_id,
            "adaptive-routed generate request"
        );

        let mut req = req;
        if req.params.budget_tokens == 0 && matches!(decision.complexity, TaskComplexity::Complex) {
            let supports_thinking = self
                .unified_registry
                .get(&decision.model_id)
                .map(|s| {
                    s.supported_params
                        .contains(&schema::GenerateParam::ExtendedThinking)
                })
                .unwrap_or(false);
            if supports_thinking {
                req.params.budget_tokens = 8000;
                tracing::info!(model = %decision.model_name, budget = 8000, "auto-enabled extended thinking for complex task");
            }
        }

        let mut models_to_try = vec![decision.model_id.clone()];
        models_to_try.extend(decision.fallbacks.iter().cloned());

        let mut last_error = None;

        for candidate_id in &models_to_try {
            #[allow(unused_mut)]
            let mut schema = self
                .unified_registry
                .get(candidate_id)
                .or_else(|| self.unified_registry.find_by_name(candidate_id))
                .cloned();

            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            if let Some(ref s) = schema {
                if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
                    tracing::info!(
                        from = %s.id, to = %mlx_equiv.id,
                        "redirecting GGUF model to MLX equivalent on Apple Silicon"
                    );
                    schema = Some(mlx_equiv.clone());
                }
            }

            let candidate_name = schema
                .as_ref()
                .map(|s| s.name.clone())
                .unwrap_or_else(|| candidate_id.clone());

            let is_remote = schema
                .as_ref()
                .map(|s| s.is_remote() || s.is_vllm_mlx())
                .unwrap_or(false);
            let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);

            if is_delegated {
                let runner = match runner::current_inference_runner() {
                    Some(r) => r,
                    None => {
                        last_error = Some(InferenceError::InferenceFailed(
                            "model declares ModelSource::Delegated but no inference runner is registered"
                                .into(),
                        ));
                        continue;
                    }
                };
                let (tx, mut rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
                let emitter = runner::EventEmitter::new(tx);
                let runner_req = req.clone();
                let runner_handle = tokio::spawn(async move { runner.run(runner_req, emitter).await });
                let mut accumulator = stream::StreamAccumulator::default();
                while let Some(evt) = rx.recv().await {
                    accumulator.push(&evt);
                }
                let (acc_text, acc_tool_calls) = accumulator.finish();
                match runner_handle.await {
                    Ok(Ok(_runner_result)) => {
                        let elapsed = start.elapsed().as_millis() as u64;
                        let mut tracker = self.outcome_tracker.write().await;
                        tracker.record_complete(&trace_id, elapsed, 0, 0);
                        return Ok(InferenceResult {
                            text: acc_text,
                            tool_calls: acc_tool_calls,
                            bounding_boxes: vec![],
                            trace_id,
                            model_used: candidate_name,
                            latency_ms: elapsed,
                            time_to_first_token_ms: None,
                            usage: None,
                            provider_output_items: vec![],
                        });
                    }
                    Ok(Err(e)) => {
                        last_error = Some(InferenceError::InferenceFailed(e.to_string()));
                        continue;
                    }
                    Err(join_err) => {
                        last_error = Some(InferenceError::InferenceFailed(format!(
                            "runner task panicked: {join_err}"
                        )));
                        continue;
                    }
                }
            }

            let has_tools = req.tools.is_some();

            let context = if has_tools
                && req.tools.as_ref().map_or(false, |t| {
                    t.iter().any(|tool| {
                        tool.get("function")
                            .and_then(|f| f.get("name"))
                            .and_then(|n| n.as_str())
                            == Some("done")
                    })
                }) {
                let base = req.context.as_deref().unwrap_or("");
                Some(format!(
                    "{base}\n\nIMPORTANT: When calling the `done` tool, the `result` field MUST contain a DETAILED summary of everything you found and did. This is the ONLY output the user sees. Do NOT just say 'completed' — include specific findings, data, and conclusions."
                ))
            } else {
                req.context.clone()
            };

            let result = if is_remote {
                let schema_val = schema.unwrap();
                let _ctx_len = schema_val.context_length;
                let temperature = if !schema_val.supported_params.is_empty()
                    && !schema_val
                        .supported_params
                        .contains(&crate::schema::GenerateParam::Temperature)
                {
                    -1.0
                } else {
                    req.params.temperature
                };

                self.remote_backend
                    .generate_with_tools_multi(
                        &schema_val,
                        &req.prompt,
                        context.as_deref(),
                        temperature,
                        req.params.max_tokens,
                        req.tools.as_deref(),
                        req.images.as_deref(),
                        req.messages.as_deref(),
                        req.params.tool_choice.as_deref(),
                        req.params.parallel_tool_calls,
                        req.params.budget_tokens,
                        req.cache_control,
                        req.response_format.as_ref(),
                    )
                    .await
                    .map(|(t, c, u)| (t, c, u, None::<u64>))
            } else {
                #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
                {
                    let schema_ref = schema
                        .as_ref()
                        .ok_or_else(|| InferenceError::ModelNotFound(candidate_id.clone()))?;

                    if schema_ref.is_foundation_models() {
                        if has_tools {
                            Err(InferenceError::UnsupportedMode {
                                mode: "tool-use",
                                backend: "foundation-models",
                                reason:
                                    "tool calling not yet wired through the FoundationModels \
                                     bridge — route to a remote model with ToolUse capability",
                            })
                        } else if Self::request_has_video(&req)
                            || Self::request_has_audio(&req)
                            || req
                                .images
                                .as_ref()
                                .is_some_and(|imgs| !imgs.is_empty())
                        {
                            Err(InferenceError::UnsupportedMode {
                                mode: "multimodal-content",
                                backend: "foundation-models",
                                reason:
                                    "the FoundationModels bridge currently exposes text-only \
                                     generation — route image/audio/video to a remote VL model",
                            })
                        } else {
                            let prompt = req.prompt.clone();
                            let instructions = context.clone();
                            let max_tokens = req.params.max_tokens as u32;
                            let temperature = req.params.temperature;
                            tokio::task::spawn_blocking(move || {
                                crate::backend::foundation_models::generate(
                                    &prompt,
                                    instructions.as_deref(),
                                    max_tokens,
                                    temperature as f32,
                                )
                            })
                            .await
                            .map_err(|e| {
                                InferenceError::InferenceFailed(format!(
                                    "FoundationModels task panicked: {e}"
                                ))
                            })
                            .and_then(|r| r)
                            .map(|text| (text, vec![], None, None))
                        }
                    } else if !schema_ref.is_mlx() {
                        Err(InferenceError::InferenceFailed(format!(
                            "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
                            schema_ref.id
                        )))
                    } else if schema_ref.tags.iter().any(|t| t == "mlx-vlm-cli") {
                        let has_images = req
                            .images
                            .as_ref()
                            .is_some_and(|imgs| !imgs.is_empty());
                        if !has_images {
                            return Err(InferenceError::UnsupportedMode {
                                mode: "text-only-on-mlx-vlm-id",
                                backend: "mlx-vlm-cli",
                                reason:
                                    "the `mlx-vlm/...` model IDs route exclusively \
                                     through the mlx-vlm CLI for image inference. \
                                     For text-only generation, route to a Qwen3 \
                                     text model (`mlx/qwen3-4b:4bit` etc.) — the \
                                     CLI shell-out has higher latency than the \
                                     in-process MLX text tower.",
                            });
                        }
                        let vlm_status = crate::backend::mlx_vlm_cli::runtime_status();
                        if !vlm_status.is_available() {
                            return Err(InferenceError::InferenceFailed(vlm_status.user_message()));
                        }
                        let repo = match &schema_ref.source {
                            crate::schema::ModelSource::Mlx { hf_repo, .. } => {
                                hf_repo.clone()
                            }
                            _ => {
                                return Err(InferenceError::InferenceFailed(format!(
                                    "model '{}' is tagged mlx-vlm-cli but its \
                                     source isn't ModelSource::Mlx — registry bug",
                                    schema_ref.id
                                )));
                            }
                        };
                        let imgs = req.images.clone().unwrap_or_default();
                        let temp = req.params.temperature;
                        let max_t = req.params.max_tokens;
                        let prompt = req.prompt.clone();
                        let text = tokio::task::spawn_blocking(move || {
                            crate::backend::mlx_vlm_cli::generate(
                                &repo, &prompt, &imgs, temp, max_t,
                            )
                        })
                        .await
                        .map_err(|e| {
                            InferenceError::InferenceFailed(format!(
                                "mlx_vlm CLI task panicked: {e}"
                            ))
                        })??;
                        let bounding_boxes = parse_boxes(&text);
                        let latency_ms = start.elapsed().as_millis() as u64;
                        {
                            let mut tracker = self.outcome_tracker.write().await;
                            tracker.record_complete(&trace_id, latency_ms, 0, 0);
                        }
                        return Ok(InferenceResult {
                            text,
                            tool_calls: vec![],
                            bounding_boxes,
                            trace_id: trace_id.clone(),
                            model_used: schema_ref.id.clone(),
                            latency_ms,
                            time_to_first_token_ms: None,
                            usage: None,
                            provider_output_items: Vec::new(),
                        });
                    } else {
                        let handle = self.ensure_mlx_backend(schema_ref).await?;
                        if Self::request_has_video(&req) {
                            return Err(InferenceError::UnsupportedMode {
                                mode: "video-content-block",
                                backend: "native-mlx-qwen25vl",
                                reason:
                                    "Qwen2.5-VL video understanding is on the request surface \
                                     but the video-tokenization path (frame sampling + merger) \
                                     is not yet wired; route to a remote VL provider for now",
                            });
                        }
                        if Self::request_has_audio(&req) {
                            return Err(InferenceError::UnsupportedMode {
                                mode: "audio-content-block",
                                backend: "native-mlx-qwen25vl",
                                reason:
                                    "audio understanding is on the request surface (Gemma 4 \
                                     E2B/E4B and Gemini accept it) but the native MLX path \
                                     for this model does not — route to Gemini or Gemma-4",
                            });
                        }
                        let has_images = req
                            .images
                            .as_ref()
                            .is_some_and(|imgs| !imgs.is_empty());
                        if has_images {
                            let can_do_vision = {
                                let guard = handle.lock().map_err(|_| {
                                    InferenceError::InferenceFailed(
                                        "MLX backend mutex poisoned".into(),
                                    )
                                })?;
                                guard.supports_capability(
                                    crate::schema::ModelCapability::Vision,
                                )
                            };
                            if !can_do_vision {
                                return Err(InferenceError::UnsupportedMode {
                                    mode: "image-content-block",
                                    backend: "native-mlx-text",
                                    reason:
                                        "this MLX backend is a plain Qwen3 text tower. \
                                         For local image inference, route to \
                                         `mlx-vlm/qwen3-vl-2b:bf16` or another `mlx-vlm/...` \
                                         catalog ID so CAR shells out to `mlx_vlm.generate`. \
                                         Alternatives: a local vLLM-MLX VLM server, or a \
                                         remote VL model. (#115)",
                                });
                            }
                        }
                        self.generate_mlx(req.clone(), &schema_ref.id)
                            .await
                            .map(|(text, ttft)| (text, vec![], None, ttft))
                    }
                }

                #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
                {
                    match self.ensure_backend(&candidate_name).await {
                        Ok(()) => {
                            let mut write = self.backend.write().await;
                            let backend = write.as_mut().unwrap();
                            tasks::generate::generate(backend, req.clone())
                                .await
                                .map(|(text, ttft)| (text, vec![], None, ttft))
                        }
                        Err(e) => Err(e),
                    }
                }
            };

            match result {
                Ok((text, tool_calls, usage, time_to_first_token_ms)) => {
                    let latency_ms = start.elapsed().as_millis() as u64;
                    let estimated_tokens = usage
                        .as_ref()
                        .map(|u| u.completion_tokens as usize)
                        .unwrap_or_else(|| text.split_whitespace().count());
                    {
                        let mut tracker = self.outcome_tracker.write().await;
                        tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
                    }
                    if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
                        cb.record_success(candidate_id);
                    }
                    self.auto_save_outcomes().await;

                    let span = tracing::Span::current();
                    span.record("model", candidate_name.as_str());
                    span.record("latency_ms", latency_ms);
                    if let Some(ttft) = time_to_first_token_ms {
                        span.record("ttft_ms", ttft);
                    }
                    if let Some(ref u) = usage {
                        span.record("prompt_tokens", u.prompt_tokens);
                        span.record("completion_tokens", u.completion_tokens);
                    }

                    let bounding_boxes = tasks::grounding::parse_boxes(&text);
                    return Ok(InferenceResult {
                        text,
                        tool_calls,
                        bounding_boxes,
                        trace_id,
                        model_used: candidate_name,
                        latency_ms,
                        time_to_first_token_ms,
                        usage,
                        provider_output_items: Vec::new(),
                    });
                }
                Err(e) => {
                    tracing::warn!(
                        model = %candidate_name,
                        error = %e,
                        remaining = models_to_try.len().saturating_sub(
                            models_to_try.iter().position(|m| m == candidate_id).unwrap_or(0) + 1
                        ),
                        "model failed, trying next fallback immediately"
                    );
                    {
                        let mut tracker = self.outcome_tracker.write().await;
                        let fail_trace =
                            tracker.record_start(candidate_id, decision.task, "fallback");
                        tracker.record_failure(&fail_trace, &e.to_string());
                    }
                    {
                        let err_str = e.to_string();
                        let is_client_error =
                            err_str.contains("API returned 4") && !err_str.contains("429");
                        if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
                            cb.record_failure(candidate_id);
                            if is_client_error {
                                cb.record_failure(candidate_id);
                            }
                        }
                    }
                    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
                    {
                        let mut write = self.backend.write().await;
                        *write = None;
                    }
                    last_error = Some(e);
                }
            }
        }

        let e = last_error.unwrap_or(InferenceError::InferenceFailed(
            "no models available".into(),
        ));
        {
            let mut tracker = self.outcome_tracker.write().await;
            tracker.record_failure(&trace_id, &e.to_string());
        }
        self.auto_save_outcomes().await;
        Err(e)
    }

    pub async fn generate_tracked_stream(
        &self,
        req: GenerateRequest,
    ) -> Result<tokio::sync::mpsc::Receiver<stream::StreamEvent>, InferenceError> {
        let has_tools = req.tools.is_some();
        let has_vision = Self::request_needs_vision(&req);
        let preferred_model = self
            .preferred_model_for_capability(ModelCapability::Generate)
            .map(str::to_string);
        let decision = match req.model.clone().or(preferred_model) {
            Some(m) => {
                let ctx_len = self
                    .unified_registry
                    .get(&m)
                    .or_else(|| self.unified_registry.find_by_name(&m))
                    .map(|s| s.context_length)
                    .unwrap_or(0);
                AdaptiveRoutingDecision {
                    model_id: m.clone(),
                    model_name: m,
                    task: InferenceTask::Generate,
                    complexity: TaskComplexity::assess(&req.prompt),
                    reason: "explicit model".into(),
                    strategy: RoutingStrategy::Explicit,
                    predicted_quality: 0.5,
                    fallbacks: vec![],
                    context_length: ctx_len,
                    needs_compaction: false,
                }
            }
            None => {
                let tracker_read = self.outcome_tracker.read().await;
                if has_vision {
                    self.adaptive_router.route_with_vision(
                        &req.prompt,
                        &self.unified_registry,
                        &tracker_read,
                        has_tools,
                    )
                } else if has_tools {
                    self.adaptive_router.route_with_tools(
                        &req.prompt,
                        &self.unified_registry,
                        &tracker_read,
                    )
                } else {
                    self.adaptive_router
                        .route(&req.prompt, &self.unified_registry, &tracker_read)
                }
            }
        };

        #[allow(unused_mut)]
        let mut schema = self
            .unified_registry
            .get(&decision.model_id)
            .or_else(|| self.unified_registry.find_by_name(&decision.model_id))
            .cloned();

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        if let Some(ref s) = schema {
            if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
                tracing::info!(
                    from = %s.id, to = %mlx_equiv.id,
                    "redirecting GGUF model to MLX equivalent on Apple Silicon (stream)"
                );
                schema = Some(mlx_equiv.clone());
            }
        }

        let is_remote = schema
            .as_ref()
            .map(|s| s.is_remote() || s.is_vllm_mlx())
            .unwrap_or(false);

        let is_delegated = schema.as_ref().map(|s| s.is_delegated()).unwrap_or(false);

        if is_delegated {
            let runner = runner::current_inference_runner().ok_or_else(|| {
                InferenceError::InferenceFailed(
                    "model declares ModelSource::Delegated but no inference runner is registered \
                     (call set_inference_runner / registerInferenceRunner / register_inference_runner)"
                        .into(),
                )
            })?;
            let (tx, rx) = tokio::sync::mpsc::channel::<stream::StreamEvent>(64);
            let emitter = runner::EventEmitter::new(tx);
            let request = req.clone();
            tokio::spawn(async move {
                if let Err(e) = runner.run(request, emitter).await {
                    tracing::warn!(error = %e, "delegated inference runner failed");
                }
            });
            return Ok(rx);
        }

        if is_remote {
            let schema = schema.unwrap();
            self.remote_backend.register_model_keys(&schema).await;

            self.remote_backend
                .generate_stream(
                    &schema,
                    &req.prompt,
                    req.context.as_deref(),
                    req.params.temperature,
                    req.params.max_tokens,
                    req.tools.as_deref(),
                    req.images.as_deref(),
                    req.params.tool_choice.as_deref(),
                    req.params.parallel_tool_calls,
                    req.response_format.as_ref(),
                )
                .await
        } else {
            let schema = schema.ok_or_else(|| {
                InferenceError::ModelNotFound(decision.model_id.clone())
            })?;
            let (tx, rx) = tokio::sync::mpsc::channel(64);

            #[cfg(any(
                all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
                all(target_os = "ios", target_arch = "aarch64")
            ))]
            {
                if schema.is_foundation_models() {
                    let prompt = req.prompt.clone();
                    let instructions = req.context.clone();
                    let max_tokens = req.params.max_tokens as u32;
                    let temperature = req.params.temperature;
                    let tx_clone = tx.clone();
                    tokio::task::spawn_blocking(move || {
                        let accum = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
                        let accum_cb = accum.clone();
                        let cb = crate::backend::foundation_models::StreamCallback::new(
                            move |delta: &str| {
                                if let Ok(mut g) = accum_cb.lock() {
                                    g.push_str(delta);
                                }
                                tx_clone
                                    .blocking_send(stream::StreamEvent::TextDelta(
                                        delta.to_string(),
                                    ))
                                    .is_ok()
                            },
                        );
                        let result = crate::backend::foundation_models::stream(
                            &prompt,
                            instructions.as_deref(),
                            max_tokens,
                            temperature as f32,
                            cb,
                        );
                        let final_text = accum
                            .lock()
                            .map(|g| g.clone())
                            .unwrap_or_default();
                        let _ = tx.blocking_send(stream::StreamEvent::Done {
                            text: final_text,
                            tool_calls: vec![],
                        });
                        result
                    });
                    return Ok(rx);
                }
            }

            #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
            {
                if !schema.is_mlx() {
                    return Err(InferenceError::InferenceFailed(format!(
                        "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
                        schema.id
                    )));
                }
                let backend = self.ensure_mlx_backend(&schema).await?;
                let model_id = schema.id.clone();
                let cache = Arc::clone(&self.mlx_backends);
                tokio::task::spawn_blocking(move || {
                    let _ = Self::stream_local_mlx(backend, cache, model_id, req, tx);
                });
                return Ok(rx);
            }

            #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
            {
                self.ensure_backend(&schema.name).await?;
                let backend = self.backend.clone();
                tokio::spawn(async move {
                    let _ = Self::stream_local_candle(backend, req, tx).await;
                });
                Ok(rx)
            }
        }
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    async fn stream_local_candle(
        backend_lock: Arc<RwLock<Option<CandleBackend>>>,
        req: GenerateRequest,
        tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
    ) -> Result<(), InferenceError> {
        let mut write = backend_lock.write().await;
        let backend = write
            .as_mut()
            .ok_or_else(|| InferenceError::InferenceFailed("backend not initialized".into()))?;
        backend.clear_kv_cache();

        let formatted = tasks::generate::apply_chat_template(
            &req.prompt,
            req.context.as_deref(),
            req.params.thinking,
        );
        let tokens = backend.encode(&formatted)?;
        let eos = backend.eos_token_id();
        let eos_alt = backend.token_id("<|im_end|>");
        let params = &req.params;

        if tokens.is_empty() {
            let _ = tx
                .send(stream::StreamEvent::Done {
                    text: String::new(),
                    tool_calls: vec![],
                })
                .await;
            return Ok(());
        }

        let max_ctx = backend.context_length().unwrap_or(32768);
        let headroom = params.max_tokens.min(max_ctx / 4);
        let max_prompt = max_ctx.saturating_sub(headroom);
        let tokens = if tokens.len() > max_prompt {
            tokens[tokens.len() - max_prompt..].to_vec()
        } else {
            tokens
        };

        let mut generated = Vec::new();
        let logits = backend.forward(&tokens, 0)?;
        let mut next_token = tasks::generate::sample_token(&logits, params)?;

        for _ in 0..params.max_tokens {
            if eos.map_or(false, |id| next_token == id)
                || eos_alt.map_or(false, |id| next_token == id)
            {
                break;
            }

            generated.push(next_token);
            let delta = backend.decode(&[next_token])?;
            if !delta.is_empty()
                && tx.send(stream::StreamEvent::TextDelta(delta)).await.is_err()
            {
                return Ok(());
            }

            if !params.stop.is_empty() {
                let text_so_far = backend.decode(&generated)?;
                if params.stop.iter().any(|s| text_so_far.contains(s)) {
                    break;
                }
            }

            let pos = tokens.len() + generated.len() - 1;
            let logits = backend.forward(&[next_token], pos)?;
            next_token = tasks::generate::sample_token(&logits, params)?;
        }

        let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
        let _ = tx
            .send(stream::StreamEvent::Done {
                text,
                tool_calls: vec![],
            })
            .await;
        Ok(())
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    fn stream_local_mlx(
        handle: backend_cache::CachedBackend<backend::MlxBackend>,
        cache: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
        model_id: String,
        req: GenerateRequest,
        tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
    ) -> Result<(), InferenceError> {
        let mut guard = handle.lock().map_err(|_| {
            InferenceError::InferenceFailed(format!(
                "MLX backend mutex poisoned for {model_id}"
            ))
        })?;
        let backend: &mut backend::MlxBackend = &mut *guard;
        backend.clear_kv_cache();

        let formatted = tasks::generate::apply_chat_template(
            &req.prompt,
            req.context.as_deref(),
            req.params.thinking,
        );
        let tokens = backend.encode(&formatted)?;
        let eos = backend.eos_token_id();
        let eos_alt = backend.token_id("<|im_end|>");
        let params = &req.params;

        if tokens.is_empty() {
            let _ = tx.blocking_send(stream::StreamEvent::Done {
                text: String::new(),
                tool_calls: vec![],
            });
            return Ok(());
        }

        let max_ctx = backend.context_length();
        let headroom = params.max_tokens.min(max_ctx / 4);
        let max_prompt = max_ctx.saturating_sub(headroom);
        let tokens = if tokens.len() > max_prompt {
            tokens[tokens.len() - max_prompt..].to_vec()
        } else {
            tokens
        };

        let mut generated = Vec::new();
        let logits = match Self::catch_mlx("stream prefill", || backend.forward(&tokens, 0)) {
            Ok(v) => v,
            Err(e) => { cache.invalidate(&model_id); return Err(e); }
        };
        let mut next_token = Self::sample_from_logits(&logits, params)?;

        for _ in 0..params.max_tokens {
            if eos.map_or(false, |id| next_token == id)
                || eos_alt.map_or(false, |id| next_token == id)
            {
                break;
            }

            generated.push(next_token);
            let delta = backend.decode(&[next_token])?;
            if !delta.is_empty()
                && tx.blocking_send(stream::StreamEvent::TextDelta(delta)).is_err()
            {
                return Ok(());
            }

            if !params.stop.is_empty() {
                let text_so_far = backend.decode(&generated)?;
                if params.stop.iter().any(|s| text_so_far.contains(s)) {
                    break;
                }
            }

            let pos = tokens.len() + generated.len() - 1;
            let logits = match Self::catch_mlx("stream forward", || backend.forward(&[next_token], pos)) {
                Ok(v) => v,
                Err(e) => { cache.invalidate(&model_id); return Err(e); }
            };
            next_token = Self::sample_from_logits(&logits, params)?;
        }

        let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
        let _ = tx.blocking_send(stream::StreamEvent::Done {
            text,
            tool_calls: vec![],
        });
        Ok(())
    }

    pub async fn route_context_snapshot(
        &self,
        prompt: &str,
        workload: RoutingWorkload,
        has_tools: bool,
        has_vision: bool,
    ) -> AdaptiveRoutingDecision {
        let tracker = self.outcome_tracker.read().await;
        self.adaptive_router.route_context_aware(
            prompt,
            0,
            &self.unified_registry,
            &tracker,
            has_tools,
            has_vision,
            workload,
        )
    }

    pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
        Ok(self.generate_tracked(req).await?.text)
    }
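
    /// Tokenizer round-trip sketch (illustrative, not from the original docs;
    /// assumes the MLX Qwen3 catalog entry referenced elsewhere in this file
    /// is installed locally):
    ///
    /// ```ignore
    /// let ids = engine.tokenize("mlx/qwen3-4b:4bit", "hello world").await?;
    /// let text = engine.detokenize("mlx/qwen3-4b:4bit", &ids).await?;
    /// ```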
    pub async fn tokenize(
        &self,
        model: &str,
        text: &str,
    ) -> Result<Vec<u32>, InferenceError> {
        self.assert_local_for_tokenize(model)?;

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            let schema = self
                .unified_registry
                .get(model)
                .or_else(|| self.unified_registry.find_by_name(model))
                .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
                .clone();
            let handle = self.ensure_mlx_backend(&schema).await?;
            let guard = handle.lock().map_err(|_| {
                InferenceError::InferenceFailed(format!(
                    "MLX backend mutex poisoned for {}", schema.id
                ))
            })?;
            return guard.tokenize_raw(text);
        }

        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        {
            self.ensure_backend(model).await?;
            let read = self.backend.read().await;
            let backend = read.as_ref().ok_or_else(|| {
                InferenceError::InferenceFailed(
                    "candle backend missing after ensure_backend".to_string(),
                )
            })?;
            backend.tokenize_raw(text)
        }
    }

    pub async fn detokenize(
        &self,
        model: &str,
        tokens: &[u32],
    ) -> Result<String, InferenceError> {
        self.assert_local_for_tokenize(model)?;

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            let schema = self
                .unified_registry
                .get(model)
                .or_else(|| self.unified_registry.find_by_name(model))
                .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
                .clone();
            let handle = self.ensure_mlx_backend(&schema).await?;
            let guard = handle.lock().map_err(|_| {
                InferenceError::InferenceFailed(format!(
                    "MLX backend mutex poisoned for {}", schema.id
                ))
            })?;
            return guard.detokenize_raw(tokens);
        }

        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        {
            self.ensure_backend(model).await?;
            let read = self.backend.read().await;
            let backend = read.as_ref().ok_or_else(|| {
                InferenceError::InferenceFailed(
                    "candle backend missing after ensure_backend".to_string(),
                )
            })?;
            backend.detokenize_raw(tokens)
        }
    }

    fn assert_local_for_tokenize(&self, model: &str) -> Result<(), InferenceError> {
        if let Some(schema) = self
            .unified_registry
            .get(model)
            .or_else(|| self.unified_registry.find_by_name(model))
        {
            if !schema.is_local() {
                return Err(InferenceError::UnsupportedMode {
                    mode: "tokenize/detokenize",
                    backend: "remote",
                    reason:
                        "remote provider tokenizer is not exposed by the runtime; \
                         use a local model (Qwen3 GGUF / MLX) for tokenizer-correctness checks",
                });
            }
        }
        Ok(())
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    fn catch_mlx<F, T>(context: &str, f: F) -> Result<T, InferenceError>
    where
        F: FnOnce() -> Result<T, InferenceError>,
    {
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(f))
            .map_err(|e| InferenceError::InferenceFailed(
                format!("MLX panicked during {context}: {e:?}")
            ))?
    }

2068 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2071 async fn generate_mlx(
2072 &self,
2073 req: GenerateRequest,
2074 model_id: &str,
2075 ) -> Result<(String, Option<u64>), InferenceError> {
2076 let start = std::time::Instant::now();
2077
2078 let schema = self
2079 .unified_registry
2080 .get(model_id)
2081 .cloned()
2082 .ok_or_else(|| InferenceError::InferenceFailed(format!(
2083 "generate_mlx: unknown schema id {model_id}"
2084 )))?;
2085 let handle = self.ensure_mlx_backend(&schema).await?;
2086 let mut guard = handle.lock().map_err(|_| {
2087 InferenceError::InferenceFailed(format!(
2088 "MLX backend mutex poisoned for {model_id}"
2089 ))
2090 })?;
2091 let backend: &mut backend::MlxBackend = &mut *guard;
2092 backend.clear_kv_cache();
2093
2094 let formatted = tasks::generate::apply_chat_template(
2095 &req.prompt,
2096 req.context.as_deref(),
2097 req.params.thinking,
2098 );
2099 let tokens = backend.encode(&formatted)?;
2100 let eos = backend.eos_token_id();
2101 let eos_alt = backend.token_id("<|im_end|>");
2102 let params = &req.params;
2103
2104 if tokens.is_empty() {
2105 return Ok((String::new(), None));
2106 }
2107
2108 let max_ctx = backend.context_length();
2110 let headroom = params.max_tokens.min(max_ctx / 4);
2111 let max_prompt = max_ctx.saturating_sub(headroom);
2112 let tokens = if tokens.len() > max_prompt {
2113 tokens[tokens.len() - max_prompt..].to_vec()
2114 } else {
2115 tokens
2116 };
2117
2118 let mut generated = Vec::new();
2119
2120 let logits = match Self::catch_mlx("prefill", || backend.forward(&tokens, 0)) {
2126 Ok(v) => v,
2127 Err(e) => {
2128 drop(guard);
2129 self.mlx_backends.invalidate(model_id);
2130 return Err(e);
2131 }
2132 };
2133 let mut next_token = Self::sample_from_logits(&logits, params)?;
2134 let ttft_ms = Some(start.elapsed().as_millis() as u64);
2135
2136 for _ in 0..params.max_tokens {
2137 if eos.is_some_and(|id| next_token == id)
2138 || eos_alt.is_some_and(|id| next_token == id)
2139 {
2140 break;
2141 }
2142
2143 generated.push(next_token);
2144
2145 if !params.stop.is_empty() {
2146 let text_so_far = backend.decode(&generated)?;
2147 if params.stop.iter().any(|s| text_so_far.contains(s)) {
2148 break;
2149 }
2150 }
2151
2152 let pos = tokens.len() + generated.len() - 1;
2153 let logits = match Self::catch_mlx("forward", || backend.forward(&[next_token], pos)) {
2154 Ok(v) => v,
2155 Err(e) => {
2156 drop(guard);
2157 self.mlx_backends.invalidate(model_id);
2158 return Err(e);
2159 }
2160 };
2161 next_token = Self::sample_from_logits(&logits, params)?;
2162 }
2163
2164 let text = backend.decode(&generated)?;
2165 Ok((tasks::generate::strip_thinking(&text, params.thinking), ttft_ms))
2166 }
2167
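// sample_from_logits: greedy argmax when temperature <= 0, otherwise a
// temperature-scaled softmax followed by optional top-p (nucleus) filtering and a
// single multinomial draw. Illustrative numbers (not from any real model): logits
// [2.0, 1.0, 0.0] at temperature 1.0 give probs ~[0.665, 0.245, 0.090]; with
// top_p = 0.9 the third token is cut and the rest renormalizes to ~[0.731, 0.269].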
2168 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2170 fn sample_from_logits(logits: &[f32], params: &GenerateParams) -> Result<u32, InferenceError> {
2171 if params.temperature <= 0.0 {
2172 let (idx, _) = logits
2174 .iter()
2175 .enumerate()
2176 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
2177 .ok_or_else(|| InferenceError::InferenceFailed("empty logits".into()))?;
2178 return Ok(idx as u32);
2179 }
2180
2181 let temp = params.temperature as f32;
2183 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
2184 let mut probs: Vec<f32> = logits
2185 .iter()
2186 .map(|&l| ((l - max_logit) / temp).exp())
2187 .collect();
2188 let sum: f32 = probs.iter().sum();
2189 for p in &mut probs {
2190 *p /= sum;
2191 }
2192
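// Nucleus (top-p) filtering: sort probabilities descending, keep the smallest prefix
// whose cumulative mass exceeds top_p, zero out everything else, and renormalize
// before sampling.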
2193 if params.top_p < 1.0 {
2195 let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
2196 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
2197 let mut cumsum = 0.0;
2198 let mut cutoff_idx = indexed.len();
2199 for (i, &(_, p)) in indexed.iter().enumerate() {
2200 cumsum += p;
2201 if cumsum > params.top_p as f32 {
2202 cutoff_idx = i + 1;
2203 break;
2204 }
2205 }
2206 let allowed: std::collections::HashSet<usize> =
2208 indexed[..cutoff_idx].iter().map(|(i, _)| *i).collect();
2209 for (i, p) in probs.iter_mut().enumerate() {
2210 if !allowed.contains(&i) {
2211 *p = 0.0;
2212 }
2213 }
2214 let sum: f32 = probs.iter().sum();
2215 if sum > 0.0 {
2216 for p in &mut probs {
2217 *p /= sum;
2218 }
2219 }
2220 }
2221
2222 use rand::Rng;
2224 let mut rng = rand::rng();
2225 let r: f32 = rng.random();
2226 let mut cumsum = 0.0;
2227 for (i, &p) in probs.iter().enumerate() {
2228 cumsum += p;
2229 if cumsum >= r {
2230 return Ok(i as u32);
2231 }
2232 }
2233 Ok((probs.len() - 1) as u32)
2234 }
2235
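// embed: on Apple Silicon this routes through the shared MLX backend (queries get
// the instruction prefix via embed_query, documents use embed_one); elsewhere it
// falls back to the Candle embedding backend. Texts are embedded sequentially and
// returned in input order.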
2236 pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
2239 let instruction = req
2240 .instruction
2241 .as_deref()
2242 .unwrap_or("Retrieve relevant memory facts");
2243
2244 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2245 {
2246 let model_id = self.ensure_mlx_embedding_backend().await?;
2247 let schema = self
2248 .unified_registry
2249 .get(&model_id)
2250 .cloned()
2251 .ok_or_else(|| InferenceError::InferenceFailed(format!(
2252 "embed: unknown schema id {model_id}"
2253 )))?;
2254 let handle = self.ensure_mlx_backend(&schema).await?;
2255 let mut guard = handle.lock().map_err(|_| {
2256 InferenceError::InferenceFailed(format!(
2257 "MLX embedding backend mutex poisoned for {model_id}"
2258 ))
2259 })?;
2260 let backend: &mut backend::MlxBackend = &mut *guard;
2261
2262 let mut results = Vec::with_capacity(req.texts.len());
2263 for text in &req.texts {
2264 let embedding = if req.is_query {
2265 backend.embed_query(text, instruction)?
2266 } else {
2267 backend.embed_one(text)?
2268 };
2269 results.push(embedding);
2270 }
2271 return Ok(results);
2272 }
2273
2274 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2275 {
2276 self.ensure_embedding_backend().await?;
2277 let mut write = self.embedding_backend.write().await;
2278 let backend = write.as_mut().ok_or_else(|| InferenceError::InferenceFailed("embedding backend missing after ensure_embedding_backend".to_string()))?;
2279
2280 let mut results = Vec::with_capacity(req.texts.len());
2281 for text in &req.texts {
2282 let embedding = if req.is_query {
2283 backend.embed_query(text, instruction)?
2284 } else {
2285 backend.embed_one(text)?
2286 };
2287 results.push(embedding);
2288 }
2289 Ok(results)
2290 }
2291 }
2292
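// rerank: scores each document independently by prompting the reranker in generative
// mode (temperature 0, a handful of tokens, thinking off) and converting the output
// to a relevance score via score_from_rerank_output (assumed to map a yes/no-style
// judgment to a numeric score). Results are sorted by score with the original index
// as a stable tie-break, then truncated to top_n if requested.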
2293 pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResult, InferenceError> {
2323 if req.documents.is_empty() {
2324 return Ok(RerankResult { ranked: Vec::new(), model_used: None });
2325 }
2326
2327 let model_name = match req.model.clone() {
2328 Some(m) => m,
2329 None => self
2330 .preferred_model_for_capability(ModelCapability::Rerank)
2331 .map(str::to_string)
2332 .ok_or_else(|| {
2333 InferenceError::InferenceFailed(
2334 "no reranker model available — pull a Qwen3-Reranker model first".into(),
2335 )
2336 })?,
2337 };
2338
2339 let schema = self
2340 .unified_registry
2341 .find_by_name(&model_name)
2342 .or_else(|| self.unified_registry.get(&model_name))
2343 .cloned()
2344 .ok_or_else(|| {
2345 InferenceError::InferenceFailed(format!(
2346 "rerank: unknown reranker model {model_name}"
2347 ))
2348 })?;
2349 if !schema.has_capability(ModelCapability::Rerank) {
2350 return Err(InferenceError::InferenceFailed(format!(
2351 "model {} does not declare the Rerank capability",
2352 schema.name
2353 )));
2354 }
2355
2356 let instruction = req.instruction.as_deref().unwrap_or(
2357 "Given a web search query, retrieve relevant passages that answer the query",
2358 );
2359
2360 let mut scored: Vec<RerankedDocument> = Vec::with_capacity(req.documents.len());
2361 for (idx, doc) in req.documents.iter().enumerate() {
2362 let prompt = rerank_prompt(instruction, &req.query, doc);
2363 let gen_req = GenerateRequest {
2364 prompt,
2365 model: Some(schema.id.clone()),
2366 params: tasks::generate::GenerateParams {
2367 temperature: 0.0,
2368 max_tokens: 3,
2372 thinking: tasks::generate::ThinkingMode::Off,
2373 ..Default::default()
2374 },
2375 context: None,
2376 tools: None,
2377 images: None,
2378 messages: None,
2379 cache_control: false,
2380 response_format: None,
2381 intent: None,
2382 };
2383 let out = self.generate(gen_req).await?;
2384 let score = score_from_rerank_output(&out, &schema.name);
2385 scored.push(RerankedDocument {
2386 index: idx,
2387 score,
2388 document: doc.clone(),
2389 });
2390 }
2391
2392 scored.sort_by(|a, b| {
2395 b.score
2396 .partial_cmp(&a.score)
2397 .unwrap_or(std::cmp::Ordering::Equal)
2398 .then_with(|| a.index.cmp(&b.index))
2399 });
2400 if let Some(n) = req.top_n {
2401 scored.truncate(n);
2402 }
2403
2404 Ok(RerankResult {
2405 ranked: scored,
2406 model_used: Some(schema.name),
2407 })
2408 }
2409
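// ground: visual grounding. Sends the prompt plus a single image through
// generate_tracked and returns the bounding boxes parsed from the model's output,
// along with the raw text and the model that was actually used.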
2410 pub async fn ground(&self, req: GroundRequest) -> Result<GroundResult, InferenceError> {
2420 let model_name = match req.model.clone() {
2421 Some(m) => m,
2422 None => self
2423 .preferred_model_for_capability(ModelCapability::Grounding)
2424 .map(str::to_string)
2425 .ok_or_else(|| {
2426 InferenceError::InferenceFailed(
2427 "no grounding-capable model available — pull a Qwen2.5-VL model first"
2428 .into(),
2429 )
2430 })?,
2431 };
2432
2433 let gen_req = GenerateRequest {
2434 prompt: req.prompt.clone(),
2435 model: Some(model_name),
2436 params: GenerateParams::default(),
2437 context: None,
2438 tools: None,
2439 images: Some(vec![req.image.clone()]),
2440 messages: None,
2441 cache_control: false,
2442 response_format: None,
2443 intent: None,
2444 };
2445 let result = self.generate_tracked(gen_req).await?;
2446 Ok(GroundResult {
2447 boxes: result.bounding_boxes,
2448 raw_text: result.text,
2449 model_used: Some(result.model_used),
2450 })
2451 }
2452
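// classify: picks an explicit model, a capability-preferred model, or falls back to
// the small-model router. On Apple Silicon classification is emulated through the
// generative path (classify_via_generate builds a label-list prompt and scores labels
// by how well the response matches); elsewhere it uses the Candle classify task
// directly.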
2453 pub async fn classify(
2456 &self,
2457 req: ClassifyRequest,
2458 ) -> Result<Vec<ClassifyResult>, InferenceError> {
2459 let model = match req
2460 .model
2461 .clone()
2462 .or_else(|| self.preferred_model_for_capability(ModelCapability::Classify).map(str::to_string))
2463 {
2464 Some(m) => m,
2465 None => {
2466 let m = self.router.route_small(&self.registry);
2467 debug!(model = %m, "auto-routed classify request");
2468 m
2469 }
2470 };
2471
2472 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2475 {
2476 return self.classify_via_generate(req, &model).await;
2477 }
2478
2479 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
2480 {
2481 self.ensure_backend(&model).await?;
2482 let mut write = self.backend.write().await;
2483 let backend = write.as_mut().ok_or_else(|| InferenceError::InferenceFailed("candle backend missing after ensure_backend".to_string()))?;
2484 tasks::classify::classify(backend, req).await
2485 }
2486 }
2487
2488 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2490 async fn classify_via_generate(
2491 &self,
2492 req: ClassifyRequest,
2493 model: &str,
2494 ) -> Result<Vec<ClassifyResult>, InferenceError> {
2495 let labels_str = req
2496 .labels
2497 .iter()
2498 .enumerate()
2499 .map(|(i, l)| format!("{}. {}", i + 1, l))
2500 .collect::<Vec<_>>()
2501 .join("\n");
2502
2503 let prompt = format!(
2504 "Classify the following text into one of these categories:\n\
2505 {labels_str}\n\n\
2506 Text: {}\n\n\
2507 Respond with ONLY the category name, nothing else.",
2508 req.text
2509 );
2510
2511 let gen_req = GenerateRequest {
2512 prompt,
2513 model: Some(model.to_string()),
2514 params: tasks::generate::GenerateParams {
2515 temperature: 0.0,
2516 max_tokens: 32,
2517 thinking: tasks::generate::ThinkingMode::Off,
2520 ..Default::default()
2521 },
2522 context: None,
2523 tools: None,
2524 images: None,
2525 messages: None,
2526 cache_control: false,
2527 response_format: None,
2528 intent: None,
2529 };
2530
2531 let response = self.generate(gen_req).await?;
2532 let response_lower = response.trim().to_lowercase();
2533
2534 let mut results: Vec<ClassifyResult> = req
2535 .labels
2536 .iter()
2537 .map(|label| {
2538 let label_lower = label.to_lowercase();
2539 let score = if response_lower == label_lower {
2540 1.0
2541 } else if response_lower.contains(&label_lower) {
2542 0.8
2543 } else {
2544 let label_words: Vec<&str> = label_lower.split_whitespace().collect();
2545 let matches = label_words
2546 .iter()
2547 .filter(|w| response_lower.contains(**w))
2548 .count();
2549 if label_words.is_empty() {
2550 0.0
2551 } else {
2552 0.5 * (matches as f64 / label_words.len() as f64)
2553 }
2554 };
2555 ClassifyResult {
2556 label: label.clone(),
2557 score,
2558 }
2559 })
2560 .collect();
2561
2562 results.sort_by(|a, b| {
2563 b.score
2564 .partial_cmp(&a.score)
2565 .unwrap_or(std::cmp::Ordering::Equal)
2566 });
2567
2568 let total: f64 = results.iter().map(|r| r.score).sum();
2569 if total > 0.0 {
2570 for r in &mut results {
2571 r.score /= total;
2572 }
2573 }
2574
2575 Ok(results)
2576 }
2577
2578 pub async fn transcribe(
2580 &self,
2581 req: TranscribeRequest,
2582 ) -> Result<TranscribeResult, InferenceError> {
2583 let candidates =
2584 self.speech_candidates(ModelCapability::SpeechToText, req.model.as_deref())?;
2585 let mut last_error = None;
2586
2587 for schema in candidates {
2588 let result = match &schema.source {
2589 ModelSource::Mlx { .. } => self.transcribe_local_mlx(&schema, &req).await,
2590 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2591 self.transcribe_elevenlabs(&schema, &req).await
2592 }
2593 _ => Err(InferenceError::InferenceFailed(format!(
2594 "speech-to-text not implemented for model source: {}",
2595 schema.id
2596 ))),
2597 };
2598
2599 match result {
2600 Ok(result) => return Ok(result),
2601 Err(err) => last_error = Some(err),
2602 }
2603 }
2604
2605 Err(last_error.unwrap_or_else(|| {
2606 InferenceError::InferenceFailed("no speech-to-text models available".into())
2607 }))
2608 }
2609
2610 pub async fn synthesize(
2612 &self,
2613 req: SynthesizeRequest,
2614 ) -> Result<SynthesizeResult, InferenceError> {
2615 let candidates =
2616 self.speech_candidates(ModelCapability::TextToSpeech, req.model.as_deref())?;
2617 let mut last_error = None;
2618
2619 for schema in candidates {
2620 let result = match &schema.source {
2621 ModelSource::Mlx { .. } => self.synthesize_local_mlx(&schema, &req).await,
2622 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2623 self.synthesize_elevenlabs(&schema, &req).await
2624 }
2625 _ => Err(InferenceError::InferenceFailed(format!(
2626 "text-to-speech not implemented for model source: {}",
2627 schema.id
2628 ))),
2629 };
2630
2631 match result {
2632 Ok(result) => return Ok(result),
2633 Err(err) => last_error = Some(err),
2634 }
2635 }
2636
2637 Err(last_error.unwrap_or_else(|| {
2638 InferenceError::InferenceFailed("no text-to-speech models available".into())
2639 }))
2640 }
2641
2642 pub async fn generate_image(
2644 &self,
2645 req: GenerateImageRequest,
2646 ) -> Result<GenerateImageResult, InferenceError> {
2647 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2654 {
2655 use crate::backend::external_flux;
2656 let backend = std::env::var("CAR_IMAGE_BACKEND")
2657 .unwrap_or_else(|_| "native".to_string());
2658 let use_external = match backend.as_str() {
2659 "external" => true,
2660 "native" => false,
2661 "auto-external" => external_flux::is_available(),
2662 _ => false,
2663 };
2664 if use_external {
2665 tracing::info!(
2666 "routing image generation to external mflux \
2667 (set CAR_IMAGE_BACKEND=native to use the Rust port)"
2668 );
2669 let mut req = req;
2670 req.model = self.resolve_external_hf_repo(
2671 req.model.as_deref(),
2672 ModelCapability::ImageGeneration,
2673 );
2674 return external_flux::generate_image(&req);
2675 }
2676 tracing::info!("using native Rust MLX Flux backend");
2677 }
2678
2679 let candidates =
2680 self.media_generation_candidates(ModelCapability::ImageGeneration, req.model.as_deref())?;
2681 let mut last_error = None;
2682
2683 for schema in candidates {
2684 let result = match &schema.source {
2685 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2686 ModelSource::Mlx { .. } => self.generate_image_native_mlx(&schema, &req).await,
2687 _ => Err(InferenceError::InferenceFailed(format!(
2688 "image generation not implemented for model source: {}",
2689 schema.id
2690 ))),
2691 };
2692
2693 match result {
2694 Ok(result) => return Ok(result),
2695 Err(err) => last_error = Some(err),
2696 }
2697 }
2698
2699 Err(last_error.unwrap_or_else(|| {
2700 InferenceError::InferenceFailed("no image generation models available".into())
2701 }))
2702 }
2703
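// generate_image_batch: produces variant_count images by re-running single-image
// generation with deterministic per-variant seeds. For example, base seed 7 with
// three variants uses seeds 7, 8, 9 (base.wrapping_add(i)), so a fixed base seed
// reproduces the same batch.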
2704 pub async fn generate_image_batch(
2720 &self,
2721 req: GenerateImageRequest,
2722 ) -> Result<Vec<GenerateImageResult>, InferenceError> {
2723 let count = req.variant_count.unwrap_or(1).max(1);
2724 if count == 1 {
2725 return self.generate_image(req).await.map(|r| vec![r]);
2726 }
2727 let base_seed = req.seed.unwrap_or(0);
2728 let mut results = Vec::with_capacity(count as usize);
2729 for i in 0..count {
2730 let mut variant_req = req.clone();
2735 variant_req.seed = Some(base_seed.wrapping_add(i as u64));
2736 variant_req.variant_count = Some(1);
2740 results.push(self.generate_image(variant_req).await?);
2741 }
2742 Ok(results)
2743 }
2744
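// generate_image_native_mlx: ensures the model is downloaded, loads (or reuses) a
// Flux backend from the size-aware backend cache keyed by schema id, and runs the
// synchronous generation call on spawn_blocking so it does not stall the async
// runtime.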
2745 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2747 async fn generate_image_native_mlx(
2748 &self,
2749 schema: &ModelSchema,
2750 req: &GenerateImageRequest,
2751 ) -> Result<GenerateImageResult, InferenceError> {
2752 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
2753 let size = backend_cache::estimate_model_size(&model_dir);
2754 let cache = Arc::clone(&self.flux_cache);
2755 let key = schema.id.clone();
2756 let handle = cache.get_or_load(&key, size, || {
2757 backend::mlx_flux::FluxBackend::load(&model_dir)
2758 })?;
2759 let req = req.clone();
2763 tokio::task::spawn_blocking(move || -> Result<GenerateImageResult, InferenceError> {
2764 let mut guard = handle.lock().map_err(|_| {
2765 InferenceError::InferenceFailed("flux backend mutex poisoned".into())
2766 })?;
2767 guard.generate(&req)
2768 })
2769 .await
2770 .map_err(|e| InferenceError::InferenceFailed(format!("flux task join: {e}")))?
2771 }
2772
2773 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2775 async fn generate_video_native_mlx(
2776 &self,
2777 schema: &ModelSchema,
2778 req: &GenerateVideoRequest,
2779 ) -> Result<GenerateVideoResult, InferenceError> {
2780 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
2781 let size = backend_cache::estimate_model_size(&model_dir);
2782 let cache = Arc::clone(&self.ltx_cache);
2783 let key = schema.id.clone();
2784 let handle = cache.get_or_load(&key, size, || {
2785 backend::mlx_ltx::LtxBackend::load(&model_dir)
2786 })?;
2787 let req = req.clone();
2788 tokio::task::spawn_blocking(move || -> Result<GenerateVideoResult, InferenceError> {
2789 let mut guard = handle.lock().map_err(|_| {
2790 InferenceError::InferenceFailed("ltx backend mutex poisoned".into())
2791 })?;
2792 guard.generate(&req)
2793 })
2794 .await
2795 .map_err(|e| InferenceError::InferenceFailed(format!("ltx task join: {e}")))?
2796 }
2797
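// generate_video: validates the request, then walks video-capable local candidates.
// Wan-family models go through the external MLX video helper; other MLX models
// (LTX-family) honor CAR_VIDEO_BACKEND: "external" forces the external LTX runner,
// "native" (the default) uses the in-process MLX backend, and "auto-external" uses
// the external runner only when it is detected as available.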
2798 pub async fn generate_video(
2800 &self,
2801 req: GenerateVideoRequest,
2802 ) -> Result<GenerateVideoResult, InferenceError> {
2803 if let Err(msg) = req.validate() {
2806 return Err(InferenceError::InferenceFailed(format!(
2807 "invalid GenerateVideoRequest: {}",
2808 msg
2809 )));
2810 }
2811 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2816 {
2817 use crate::backend::external_ltx;
2818 let backend = std::env::var("CAR_VIDEO_BACKEND")
2823 .unwrap_or_else(|_| "native".to_string());
2824 let use_external = match backend.as_str() {
2825 "external" => true,
2826 "native" => false,
2827 "auto-external" => external_ltx::is_available(),
2831 _ => false,
2832 };
2833 if use_external {
2834 tracing::info!("CAR_VIDEO_BACKEND requested external LTX routing for LTX-family models");
2835 } else {
2836 tracing::info!("using family-aware MLX video routing");
2837 }
2838 }
2839
2840 let candidates =
2841 self.media_generation_candidates(ModelCapability::VideoGeneration, req.model.as_deref())?;
2842 let mut last_error = None;
2843
2844 for schema in candidates {
2845 let result = match &schema.source {
2846 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
2847 ModelSource::Mlx { hf_repo, .. } => {
2848 if crate::backend::external_mlx_video::is_wan_family(&schema) {
2849 match self.unified_registry.ensure_local(&schema.id).await {
2850 Ok(model_dir) => crate::backend::external_mlx_video::generate_wan_video(&schema, &model_dir, &req),
2851 Err(err) => Err(err),
2852 }
2853 } else {
2854 let backend = std::env::var("CAR_VIDEO_BACKEND")
2855 .unwrap_or_else(|_| "native".to_string());
2856 let use_external_ltx = match backend.as_str() {
2857 "external" => true,
2858 "native" => false,
2859 "auto-external" => crate::backend::external_ltx::is_available(),
2860 _ => false,
2861 };
2862 if use_external_ltx {
2863 let mut req = req.clone();
2864 req.model = Some(hf_repo.clone());
2865 crate::backend::external_ltx::generate_video(&req)
2866 } else {
2867 self.generate_video_native_mlx(&schema, &req).await
2868 }
2869 }
2870 }
2871 _ => Err(InferenceError::InferenceFailed(format!(
2872 "video generation not implemented for model source: {}",
2873 schema.id
2874 ))),
2875 };
2876
2877 match result {
2878 Ok(result) => return Ok(result),
2879 Err(err) => last_error = Some(err),
2880 }
2881 }
2882
2883 Err(last_error.unwrap_or_else(|| {
2884 InferenceError::InferenceFailed("no video generation models available".into())
2885 }))
2886 }
2887
2888 pub fn list_models_unified(&self) -> Vec<ModelInfo> {
2890 self.unified_registry
2891 .list()
2892 .iter()
2893 .map(|m| ModelInfo::from(*m))
2894 .collect()
2895 }
2896
2897 pub fn available_model_upgrades(&self) -> Vec<ModelUpgrade> {
2899 self.unified_registry.available_upgrades()
2900 }
2901
2902 pub fn list_schemas(&self) -> Vec<ModelSchema> {
2905 self.unified_registry.list().into_iter().cloned().collect()
2906 }
2907
2908 pub fn list_models(&self) -> Vec<models::ModelInfo> {
2909 self.registry.list_models()
2910 }
2911
2912 pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
2914 let schema = self
2915 .unified_registry
2916 .find_by_name(name)
2917 .or_else(|| self.unified_registry.get(name))
2918 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2919 self.unified_registry.ensure_local(&schema.id).await
2920 }
2921
2922 pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
2924 let schema = self
2925 .unified_registry
2926 .get(name)
2927 .or_else(|| {
2928 self.unified_registry
2929 .list()
2930 .into_iter()
2931 .find(|schema| schema.name.eq_ignore_ascii_case(name))
2932 })
2933 .or_else(|| self.unified_registry.find_by_name(name))
2934 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2935 let model_dir = self.unified_registry.models_dir().join(&schema.name);
2936 if model_dir.exists() {
2937 std::fs::remove_dir_all(&model_dir)?;
2938 }
2939 match &schema.source {
2940 ModelSource::Mlx { hf_repo, .. } => {
2941 remove_huggingface_repo_cache(hf_repo)?;
2942 }
2943 ModelSource::Local {
2944 hf_repo,
2945 tokenizer_repo,
2946 ..
2947 } => {
2948 remove_huggingface_repo_cache(hf_repo)?;
2949 remove_huggingface_repo_cache(tokenizer_repo)?;
2950 }
2951 _ => {}
2952 }
2953 Ok(())
2954 }
2955
2956 pub fn register_model(&mut self, schema: ModelSchema) {
2958 self.unified_registry.register(schema);
2959 }
2960
2961 pub async fn discover_vllm_mlx_models(&mut self) -> usize {
2964 let config = vllm_mlx::VllmMlxConfig::default();
2965 if !config.auto_discover {
2966 return 0;
2967 }
2968 vllm_mlx::discover_and_register(&config, &mut self.unified_registry).await
2969 }
2970
2971 pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
2973 self.outcome_tracker.clone()
2974 }
2975
2976 async fn auto_save_outcomes(&self) {
2978 if let Err(e) = self.save_outcomes().await {
2979 tracing::debug!("auto-save outcomes failed: {}", e);
2980 }
2981 if let Err(e) = self.save_key_pool_stats().await {
2982 tracing::debug!("auto-save key pool stats failed: {}", e);
2983 }
2984 }
2985
2986 pub async fn save_outcomes(&self) -> Result<(), std::io::Error> {
2988 let tracker = self.outcome_tracker.read().await;
2989 let path = self.config.models_dir.join("outcome_profiles.json");
2990 tracker.save_to_file(&path)
2991 }
2992
2993 pub async fn save_key_pool_stats(&self) -> Result<(), std::io::Error> {
2995 let path = self.config.models_dir.join("key_pool_stats.json");
2996 self.remote_backend.key_pool.save_stats(&path).await
2997 }
2998
2999 pub async fn key_pool_stats(
3001 &self,
3002 ) -> std::collections::HashMap<String, Vec<key_pool::KeyStats>> {
3003 self.remote_backend.key_pool.all_stats().await
3004 }
3005
3006 pub async fn export_profiles(&self) -> Vec<ModelProfile> {
3008 let tracker = self.outcome_tracker.read().await;
3009 tracker.export_profiles()
3010 }
3011
3012 pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
3014 let mut tracker = self.outcome_tracker.write().await;
3015 tracker.import_profiles(profiles);
3016 }
3017
3018 pub async fn prepare_speech_runtime(&self) -> Result<PathBuf, InferenceError> {
3021 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3022 {
3023 Ok(self.config.models_dir.clone())
3025 }
3026 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3027 {
3028 Ok(self.ensure_speech_runtime().await?.root)
3029 }
3030 }
3031
3032 pub fn set_speech_policy(&mut self, policy: SpeechPolicy) {
3034 self.speech_policy = policy;
3035 }
3036
3037 pub fn set_routing_config(&mut self, config: RoutingConfig) {
3038 self.adaptive_router.set_config(config);
3039 }
3040
3041 pub async fn install_curated_speech(
3043 &mut self,
3044 ) -> Result<Vec<SpeechInstallReport>, InferenceError> {
3045 let _runtime_root = self.prepare_speech_runtime().await?;
3046 let schemas = self.list_schemas();
3047 let mut repos = Vec::new();
3048 for schema in &schemas {
3049 if !schema.is_mlx() || !schema.tags.iter().any(|tag| tag == "speech") {
3050 continue;
3051 }
3052 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3053 if !repos.iter().any(|existing: &String| existing == hf_repo) {
3054 repos.push(hf_repo.clone());
3055 }
3056 }
3057 }
3058
3059 let mut installed = Vec::new();
3060 for repo in repos {
3061 let (snapshot_path, files_downloaded) = download_hf_repo_snapshot(&repo).await?;
3062 let name = schemas
3063 .iter()
3064 .find(|schema| {
3065 matches!(&schema.source, ModelSource::Mlx { hf_repo, .. } if hf_repo == &repo)
3066 })
3067 .map(|schema| schema.name.clone())
3068 .unwrap_or_else(|| repo.clone());
3069 installed.push(SpeechInstallReport {
3070 name,
3071 hf_repo: repo,
3072 snapshot_path,
3073 files_downloaded,
3074 });
3075 }
3076
3077 self.unified_registry.refresh_availability();
3078 Ok(installed)
3079 }
3080
3081
3082 pub fn speech_health(&self) -> SpeechHealthReport {
3084 let local_stt_default = self
3085 .speech_health_default_name(ModelCapability::SpeechToText, true, false);
3086 let local_tts_default = self
3087 .speech_health_default_name(ModelCapability::TextToSpeech, true, false);
3088 let remote_stt_default = self
3089 .speech_health_default_name(ModelCapability::SpeechToText, false, true);
3090 let remote_tts_default = self
3091 .speech_health_default_name(ModelCapability::TextToSpeech, false, true);
3092
3093 let mut local_models = Vec::new();
3094 let mut remote_models = Vec::new();
3095 for schema in self.list_schemas() {
3096 let capability = if schema.has_capability(ModelCapability::SpeechToText) {
3097 Some(ModelCapability::SpeechToText)
3098 } else if schema.has_capability(ModelCapability::TextToSpeech) {
3099 Some(ModelCapability::TextToSpeech)
3100 } else {
3101 None
3102 };
3103 let Some(capability) = capability else {
3104 continue;
3105 };
3106
3107 let selected_by_default = local_stt_default
3108 .as_ref()
3109 .is_some_and(|name| name == &schema.name)
3110 || local_tts_default
3111 .as_ref()
3112 .is_some_and(|name| name == &schema.name)
3113 || remote_stt_default
3114 .as_ref()
3115 .is_some_and(|name| name == &schema.name)
3116 || remote_tts_default
3117 .as_ref()
3118 .is_some_and(|name| name == &schema.name);
3119
3120 let health = SpeechModelHealth {
3121 id: schema.id.clone(),
3122 name: schema.name.clone(),
3123 provider: schema.provider.clone(),
3124 capability,
3125 is_local: schema.is_local(),
3126 available: schema.available,
3127 cached: speech_model_cached(&schema),
3128 selected_by_default,
3129 source: speech_model_source_label(&schema),
3130 };
3131 if schema.is_local() {
3132 local_models.push(health);
3133 } else {
3134 remote_models.push(health);
3135 }
3136 }
3137
3138 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3140 let runtime = SpeechRuntimeHealth {
3141 root: self.config.models_dir.clone(),
3142 installed: true,
3143 python: PathBuf::new(),
3144 stt_command: PathBuf::new(),
3145 tts_command: PathBuf::new(),
3146 configured_python: None,
3147 detected_python: None,
3148 };
3149
3150 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3151 let runtime = {
3152 let rt = SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3153 SpeechRuntimeHealth {
3154 root: rt.root.clone(),
3155 installed: rt.is_ready(),
3156 python: rt.python.clone(),
3157 stt_command: rt.stt_program.clone(),
3158 tts_command: rt.tts_program.clone(),
3159 configured_python: std::env::var("CAR_SPEECH_PYTHON")
3160 .ok()
3161 .filter(|value| !value.trim().is_empty()),
3162 detected_python: detect_speech_python(),
3163 }
3164 };
3165
3166 SpeechHealthReport {
3167 runtime,
3168 local_models,
3169 remote_models,
3170 elevenlabs_configured: car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY").is_some(),
3171 prefer_local: self.speech_policy.prefer_local,
3172 allow_remote_fallback: self.speech_policy.allow_remote_fallback,
3173 preferred_local_stt: self.speech_policy.preferred_local_stt.clone(),
3174 preferred_local_tts: self.speech_policy.preferred_local_tts.clone(),
3175 preferred_remote_stt: self.speech_policy.preferred_remote_stt.clone(),
3176 preferred_remote_tts: self.speech_policy.preferred_remote_tts.clone(),
3177 local_stt_default,
3178 local_tts_default,
3179 remote_stt_default,
3180 remote_tts_default,
3181 }
3182 }
3183
3184 pub async fn model_health(&self) -> ModelHealthReport {
3187 let schemas = self.list_schemas();
3188 let total_models = schemas.len();
3189 let available_models = schemas.iter().filter(|schema| schema.available).count();
3190 let local_models = schemas.iter().filter(|schema| schema.is_local()).count();
3191 let remote_models = total_models.saturating_sub(local_models);
3192
3193 let defaults = vec![
3194 self.model_default_health(
3195 ModelCapability::Generate,
3196 self.preferred_model_for_capability(ModelCapability::Generate)
3197 .unwrap_or(&self.config.generation_model),
3198 ),
3199 self.model_default_health(
3200 ModelCapability::Embed,
3201 self.preferred_model_for_capability(ModelCapability::Embed)
3202 .unwrap_or(&self.config.embedding_model),
3203 ),
3204 self.model_default_health(
3205 ModelCapability::Classify,
3206 self.preferred_model_for_capability(ModelCapability::Classify)
3207 .unwrap_or(&self.config.classification_model),
3208 ),
3209 ];
3210
3211 let mut providers = std::collections::BTreeMap::new();
3212 for schema in &schemas {
3213 let entry = providers
3214 .entry(schema.provider.clone())
3215 .or_insert_with(|| ProviderAccumulator {
3216 configured: false,
3217 local_models: 0,
3218 remote_models: 0,
3219 available_models: 0,
3220 capabilities: std::collections::HashSet::new(),
3221 });
3222
3223 entry.configured |= model_source_configured(schema);
3224 if schema.is_local() {
3225 entry.local_models += 1;
3226 } else {
3227 entry.remote_models += 1;
3228 }
3229 if schema.available {
3230 entry.available_models += 1;
3231 }
3232 for capability in &schema.capabilities {
3233 entry.capabilities.insert(*capability);
3234 }
3235 }
3236
3237 let providers = providers
3238 .into_iter()
3239 .map(|(provider, acc)| ModelProviderHealth {
3240 provider,
3241 configured: acc.configured,
3242 local_models: acc.local_models,
3243 remote_models: acc.remote_models,
3244 available_models: acc.available_models,
3245 capabilities: sort_capabilities(acc.capabilities.into_iter().collect()),
3246 })
3247 .collect();
3248
3249 let capabilities = all_model_capabilities()
3250 .into_iter()
3251 .map(|capability| {
3252 let relevant: Vec<&ModelSchema> = schemas
3253 .iter()
3254 .filter(|schema| schema.has_capability(capability))
3255 .collect();
3256 let available: Vec<&ModelSchema> = relevant
3257 .iter()
3258 .copied()
3259 .filter(|schema| schema.available)
3260 .collect();
3261 ModelCapabilityHealth {
3262 capability,
3263 total_models: relevant.len(),
3264 available_models: available.len(),
3265 local_available_models: available
3266 .iter()
3267 .filter(|schema| schema.is_local())
3268 .count(),
3269 remote_available_models: available
3270 .iter()
3271 .filter(|schema| !schema.is_local())
3272 .count(),
3273 }
3274 })
3275 .collect();
3276
3277 let routing = self.routing_scenarios().await;
3278 let routing_config = self.adaptive_router.config().clone();
3279 let benchmark_priors = load_benchmark_prior_health(&self.config.models_dir, &schemas);
3280
3281 ModelHealthReport {
3282 total_models,
3283 available_models,
3284 local_models,
3285 remote_models,
3286 defaults,
3287 providers,
3288 capabilities,
3289 routing_prefer_local: routing_config.prefer_local,
3290 routing_quality_first_cold_start: routing_config.quality_first_cold_start,
3291 routing_min_observations: routing_config.min_observations,
3292 routing_bootstrap_min_task_observations: routing_config.bootstrap_min_task_observations,
3293 routing_bootstrap_quality_floor: routing_config.bootstrap_quality_floor,
3294 routing_quality_weight: routing_config.quality_weight,
3295 routing_latency_weight: routing_config.latency_weight,
3296 routing_cost_weight: routing_config.cost_weight,
3297 routing_scenarios: routing,
3298 benchmark_priors,
3299 speech: self.speech_health(),
3300 }
3301 }
3302
3303 async fn routing_scenarios(&self) -> Vec<RoutingScenarioHealth> {
3304 let tracker = self.outcome_tracker.read().await;
3305 let config = self.adaptive_router.config().clone();
3306 let scenarios = [
3307 (
3308 "interactive_text",
3309 "Summarize the benefits of local-first AI routing in two sentences.",
3310 "text",
3311 RoutingWorkload::Interactive,
3312 false,
3313 false,
3314 ),
3315 (
3316 "background_code",
3317 "Write a Python function named fibonacci(n) that returns the nth Fibonacci number.",
3318 "code",
3319 RoutingWorkload::Background,
3320 false,
3321 false,
3322 ),
3323 (
3324 "interactive_tool_use",
3325 "Use the provided weather tool to get the weather for Boston.",
3326 "tool_use",
3327 RoutingWorkload::Interactive,
3328 true,
3329 false,
3330 ),
3331 (
3332 "interactive_vision",
3333 "What is in this image? Answer in one word.",
3334 "vision",
3335 RoutingWorkload::Interactive,
3336 false,
3337 true,
3338 ),
3339 ];
3340
3341 scenarios
3342 .into_iter()
3343 .map(|(name, prompt, task_family, workload, has_tools, has_vision)| {
3344 let decision = self.adaptive_router.route_context_aware(
3345 prompt,
3346 0,
3347 &self.unified_registry,
3348 &tracker,
3349 has_tools,
3350 has_vision,
3351 workload,
3352 );
3353 let quality_first_cold_start = if has_tools || has_vision {
3354 config.quality_first_cold_start
3355 } else if task_family == "code" && matches!(workload, RoutingWorkload::Background) {
3356 false
3357 } else {
3358 config.quality_first_cold_start
3359 };
3360 RoutingScenarioHealth {
3361 name: name.to_string(),
3362 task_family: task_family.to_string(),
3363 workload,
3364 has_tools,
3365 has_vision,
3366 prefer_local: if task_family == "speech" {
3367 self.speech_policy.prefer_local
3368 } else {
3369 config.prefer_local
3370 },
3371 quality_first_cold_start,
3372 bootstrap_min_task_observations: config.bootstrap_min_task_observations,
3373 bootstrap_quality_floor: config.bootstrap_quality_floor,
3374 model_id: decision.model_id,
3375 model_name: decision.model_name,
3376 reason: decision.reason,
3377 strategy: decision.strategy,
3378 }
3379 })
3380 .collect()
3381 }
3382
3383 pub async fn smoke_test_speech(
3385 &self,
3386 local: bool,
3387 remote: bool,
3388 ) -> Result<SpeechSmokeReport, InferenceError> {
3389 let mut report = SpeechSmokeReport::default();
3390
3391 if local {
3392 let tts = self
3393 .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
3394 .ok_or_else(|| {
3395 InferenceError::InferenceFailed(
3396 "no local text-to-speech model available".into(),
3397 )
3398 })?;
3399 let stt = self
3400 .preferred_speech_schema(ModelCapability::SpeechToText, true, false)
3401 .ok_or_else(|| {
3402 InferenceError::InferenceFailed(
3403 "no local speech-to-text model available".into(),
3404 )
3405 })?;
3406 report.local = Some(
3407 self.run_speech_smoke_path("local", &tts, &stt, "Testing CAR local speech path.")
3408 .await?,
3409 );
3410 } else {
3411 report.skipped.push("local".to_string());
3412 }
3413
3414 if remote {
3415 let tts = self
3416 .preferred_speech_schema(ModelCapability::TextToSpeech, false, true)
3417 .ok_or_else(|| {
3418 InferenceError::InferenceFailed(
3419 "no remote text-to-speech model available".into(),
3420 )
3421 })?;
3422 let stt = self
3423 .preferred_speech_schema(ModelCapability::SpeechToText, false, true)
3424 .ok_or_else(|| {
3425 InferenceError::InferenceFailed(
3426 "no remote speech-to-text model available".into(),
3427 )
3428 })?;
3429 report.remote = Some(
3430 self.run_speech_smoke_path("remote", &tts, &stt, "Testing CAR remote speech path.")
3431 .await?,
3432 );
3433 } else {
3434 report.skipped.push("remote".to_string());
3435 }
3436
3437 Ok(report)
3438 }
3439
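// speech_candidates: an explicitly requested model short-circuits policy (it only has
// to declare the capability); otherwise every model with the capability is sorted by
// speech_sort_key, and remote candidates are dropped when remote fallback is disabled
// and at least one local candidate exists.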
3440 fn speech_candidates(
3441 &self,
3442 capability: ModelCapability,
3443 explicit: Option<&str>,
3444 ) -> Result<Vec<ModelSchema>, InferenceError> {
3445 if let Some(model) = explicit {
3446 let schema = self
3447 .unified_registry
3448 .get(model)
3449 .or_else(|| self.unified_registry.find_by_name(model))
3450 .cloned()
3451 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3452 if !schema.has_capability(capability) {
3453 return Err(InferenceError::InferenceFailed(format!(
3454 "model {} does not support {:?}",
3455 schema.name, capability
3456 )));
3457 }
3458 return Ok(vec![schema]);
3459 }
3460
3461 let mut candidates: Vec<ModelSchema> = self
3462 .unified_registry
3463 .query(&ModelFilter {
3464 capabilities: vec![capability],
3465 ..Default::default()
3466 })
3467 .into_iter()
3468 .cloned()
3469 .collect();
3470
3471 if candidates.is_empty() {
3472 return Err(InferenceError::InferenceFailed(format!(
3473 "no models registered for capability {:?}",
3474 capability
3475 )));
3476 }
3477
3478 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3479 if !self.speech_policy.allow_remote_fallback && candidates.iter().any(|model| model.is_local()) {
3480 candidates.retain(|model| model.is_local());
3481 }
3482
3483 Ok(candidates)
3484 }
3485
3486 fn resolve_external_hf_repo(
3491 &self,
3492 explicit: Option<&str>,
3493 capability: ModelCapability,
3494 ) -> Option<String> {
3495 let id = explicit?;
3496 let schema = self
3497 .unified_registry
3498 .get(id)
3499 .or_else(|| self.unified_registry.find_by_name(id))?;
3500 if !schema.has_capability(capability) {
3501 return Some(id.to_string());
3502 }
3503 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3504 return Some(hf_repo.clone());
3505 }
3506 Some(id.to_string())
3507 }
3508
3509 fn media_generation_candidates(
3510 &self,
3511 capability: ModelCapability,
3512 explicit: Option<&str>,
3513 ) -> Result<Vec<ModelSchema>, InferenceError> {
3514 if let Some(model) = explicit {
3515 let schema = self
3516 .unified_registry
3517 .get(model)
3518 .or_else(|| self.unified_registry.find_by_name(model))
3519 .cloned()
3520 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3521 if !schema.has_capability(capability) {
3522 return Err(InferenceError::InferenceFailed(format!(
3523 "model {} does not support {:?}",
3524 schema.name, capability
3525 )));
3526 }
3527 return Ok(vec![schema]);
3528 }
3529
3530 let mut candidates: Vec<ModelSchema> = self
3531 .unified_registry
3532 .query(&ModelFilter {
3533 capabilities: vec![capability],
3534 local_only: true,
3535 ..Default::default()
3536 })
3537 .into_iter()
3538 .cloned()
3539 .collect();
3540 candidates.sort_by_key(|schema| (!schema.available, schema.size_mb()));
3541 if candidates.is_empty() {
3542 return Err(InferenceError::InferenceFailed(format!(
3543 "no models registered for capability {:?}",
3544 capability
3545 )));
3546 }
3547 Ok(candidates)
3548 }
3549
3550 fn preferred_speech_schema(
3551 &self,
3552 capability: ModelCapability,
3553 local_only: bool,
3554 remote_only: bool,
3555 ) -> Option<ModelSchema> {
3556 let available_only = remote_only;
3557 let mut candidates: Vec<ModelSchema> = self
3558 .unified_registry
3559 .query(&ModelFilter {
3560 capabilities: vec![capability],
3561 available_only,
3562 ..Default::default()
3563 })
3564 .into_iter()
3565 .filter(|schema| (!local_only || schema.is_local()) && (!remote_only || schema.is_remote()))
3566 .cloned()
3567 .collect();
3568 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3569 candidates.into_iter().next()
3570 }
3571
3572 fn speech_health_default_name(
3573 &self,
3574 capability: ModelCapability,
3575 local_only: bool,
3576 remote_only: bool,
3577 ) -> Option<String> {
3578 let preferred = match capability {
3579 ModelCapability::SpeechToText if local_only => {
3580 self.speech_policy.preferred_local_stt.as_ref()
3581 }
3582 ModelCapability::SpeechToText if remote_only => {
3583 self.speech_policy.preferred_remote_stt.as_ref()
3584 }
3585 ModelCapability::TextToSpeech if local_only => {
3586 self.speech_policy.preferred_local_tts.as_ref()
3587 }
3588 ModelCapability::TextToSpeech if remote_only => {
3589 self.speech_policy.preferred_remote_tts.as_ref()
3590 }
3591 _ => None,
3592 };
3593
3594 preferred
3595 .filter(|name| {
3596 self.unified_registry.list().iter().any(|schema| {
3597 schema.name == **name
3598 && schema.has_capability(capability)
3599 && (!local_only || schema.is_local())
3600 && (!remote_only || schema.is_remote())
3601 })
3602 })
3603 .cloned()
3604 .or_else(|| {
3605 self.preferred_speech_schema(capability, local_only, remote_only)
3606 .map(|schema| schema.name)
3607 })
3608 }
3609
3610 fn model_default_health(
3611 &self,
3612 capability: ModelCapability,
3613 configured_model: &str,
3614 ) -> ModelDefaultHealth {
3615 let schema = self
3616 .unified_registry
3617 .find_by_name(configured_model)
3618 .or_else(|| self.unified_registry.get(configured_model));
3619
3620 ModelDefaultHealth {
3621 capability,
3622 configured_model: configured_model.to_string(),
3623 available: schema.is_some_and(|model| model.available),
3624 is_local: schema.is_some_and(ModelSchema::is_local),
3625 provider: schema.map(|model| model.provider.clone()),
3626 }
3627 }
3628
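// speech_sort_key: lower tuples sort first. Precedence: locality according to the
// prefer_local policy, availability, whether the model is the explicitly preferred
// one in the policy, a curated per-family rank (Qwen3-TTS before Kokoro for TTS,
// Parakeet first for STT), then p50 latency and size in MB as tie-breakers.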
3629 fn speech_sort_key(
3630 &self,
3631 capability: ModelCapability,
3632 model: &ModelSchema,
3633 ) -> (u8, u8, u8, u8, u64, u64) {
3634 let policy_preference = match capability {
3635 ModelCapability::SpeechToText if model.is_local() => self
3636 .speech_policy
3637 .preferred_local_stt
3638 .as_ref(),
3639 ModelCapability::SpeechToText => self.speech_policy.preferred_remote_stt.as_ref(),
3640 ModelCapability::TextToSpeech if model.is_local() => self
3641 .speech_policy
3642 .preferred_local_tts
3643 .as_ref(),
3644 ModelCapability::TextToSpeech => self.speech_policy.preferred_remote_tts.as_ref(),
3645 _ => None,
3646 };
3647 let local_rank = if self.speech_policy.prefer_local {
3648 if model.is_local() {
3649 0
3650 } else {
3651 1
3652 }
3653 } else if model.is_remote() {
3654 0
3655 } else {
3656 1
3657 };
3658 let availability_rank = if model.available {
3659 0
3660 } else if model.is_local() {
3661 1
3662 } else {
3663 2
3664 };
3665 let policy_rank: u8 = if policy_preference.is_some_and(|preferred| preferred == &model.name) {
3666 0
3667 } else {
3668 1
3669 };
3670 let speech_rank = match capability {
3671 ModelCapability::TextToSpeech => {
3672 if model.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
3673 0
3674 } else if model.name == "Kokoro-82M-bf16" {
3675 1
3676 } else if model.name == "Kokoro-82M-6bit" {
3677 2
3678 } else {
3679 3
3680 }
3681 }
3682 ModelCapability::SpeechToText => {
3683 if model.name == "Parakeet-TDT-0.6B-v3-MLX" {
3684 0
3685 } else {
3686 1
3687 }
3688 }
3689 _ => 0,
3690 };
3691 let latency_rank = model.performance.latency_p50_ms.unwrap_or(u64::MAX);
3692 let size_rank = model.cost.size_mb.unwrap_or(u64::MAX);
3693 (
3694 local_rank,
3695 availability_rank,
3696 policy_rank,
3697 speech_rank,
3698 latency_rank,
3699 size_rank,
3700 )
3701 }
3702
3703 async fn run_speech_smoke_path(
3704 &self,
3705 path: &str,
3706 tts: &ModelSchema,
3707 stt: &ModelSchema,
3708 text: &str,
3709 ) -> Result<SpeechSmokePathReport, InferenceError> {
3710 let work_dir = temp_work_dir(&format!("speech-smoke-{path}"))?;
3711 let audio_path = work_dir.join(format!("{path}.wav"));
3712 let synth = self
3713 .synthesize(SynthesizeRequest {
3714 text: text.to_string(),
3715 model: Some(tts.name.clone()),
3716 voice: default_speech_voice(tts),
3717 language: Some("en".to_string()),
3718 output_path: Some(audio_path.display().to_string()),
3719 ..SynthesizeRequest::default()
3720 })
3721 .await?;
3722 let transcript = self
3723 .transcribe(TranscribeRequest {
3724 audio_path: synth.audio_path.clone(),
3725 model: Some(stt.name.clone()),
3726 language: Some("en".to_string()),
3727 prompt: None,
3728 timestamps: false,
3729 })
3730 .await?;
3731
3732 Ok(SpeechSmokePathReport {
3733 path: path.to_string(),
3734 tts_model: synth.model_used.unwrap_or_else(|| tts.name.clone()),
3735 stt_model: transcript.model_used.unwrap_or_else(|| stt.name.clone()),
3736 audio_path: PathBuf::from(synth.audio_path),
3737 transcript: transcript.text,
3738 })
3739 }
3740
3741 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3742 async fn ensure_speech_runtime(&self) -> Result<SpeechRuntime, InferenceError> {
3743 let mut guard = self.speech_runtime.lock().await;
3744 if let Some(runtime) = guard.as_ref() {
3745 if runtime.is_ready() {
3746 return Ok(runtime.clone());
3747 }
3748 }
3749
3750 let runtime = SpeechRuntime::new(speech_runtime_root_from_models_dir(
3751 &self.config.models_dir,
3752 ));
3753 if !runtime.is_ready() {
3754 bootstrap_speech_runtime(&runtime).await?;
3755 }
3756 if !runtime.is_ready() {
3757 return Err(InferenceError::InferenceFailed(format!(
3758 "managed speech runtime is not ready at {}",
3759 runtime.root.display()
3760 )));
3761 }
3762
3763 *guard = Some(runtime.clone());
3764 Ok(runtime)
3765 }
3766
3767 async fn transcribe_local_mlx(
3768 &self,
3769 schema: &ModelSchema,
3770 req: &TranscribeRequest,
3771 ) -> Result<TranscribeResult, InferenceError> {
3772 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3774 {
3775 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3776 let parakeet = backend::mlx_parakeet::ParakeetBackend::load(&model_dir)?;
3777 let (text, words) = if req.timestamps {
3779 parakeet
3780 .transcribe_detailed(Path::new(&req.audio_path))
3781 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?
3782 } else {
3783 let t = parakeet
3784 .transcribe(Path::new(&req.audio_path))
3785 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?;
3786 (t, Vec::new())
3787 };
3788 return Ok(TranscribeResult {
3789 text,
3790 model_used: Some(schema.name.clone()),
3791 language: req.language.clone(),
3792 words,
3793 });
3794 }
3795
3796 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3798 {
3799 let runtime = self.ensure_speech_runtime().await?;
3800 let hf_repo = match &schema.source {
3801 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3802 _ => unreachable!(),
3803 };
3804 let output_dir = temp_work_dir("stt")?;
3805 let output_prefix = output_dir.join("transcript");
3806 let mut args = vec![
3807 "--model".to_string(),
3808 hf_repo,
3809 "--audio".to_string(),
3810 req.audio_path.clone(),
3811 "--output-path".to_string(),
3812 output_prefix.display().to_string(),
3813 "--format".to_string(),
3814 "json".to_string(),
3815 ];
3816 if let Some(language) = &req.language {
3817 args.push("--language".to_string());
3818 args.push(normalize_lang_code(language));
3819 }
3820 if let Some(prompt) = &req.prompt {
3821 args.push("--context".to_string());
3822 args.push(prompt.clone());
3823 }
3824 if req.timestamps {
3825 args.push("--verbose".to_string());
3826 }
3827
3828 let output = run_mlx_audio_command(&runtime, "stt.generate", &args).await?;
3829 let text = read_transcription_result(&output_prefix)?
3830 .or_else(|| extract_text_from_payload(&output.stdout))
3831 .ok_or_else(|| {
3832 InferenceError::InferenceFailed(format!(
3833 "mlx-audio transcription returned no text: {}",
3834 output.stderr
3835 ))
3836 })?;
3837
3838 Ok(TranscribeResult {
3839 text,
3840 model_used: Some(schema.name.clone()),
3841 language: req.language.clone(),
3842 words: Vec::new(),
3843 })
3844 }
3845 }
3846
3847 async fn synthesize_local_mlx(
3848 &self,
3849 schema: &ModelSchema,
3850 req: &SynthesizeRequest,
3851 ) -> Result<SynthesizeResult, InferenceError> {
3852 let requested = req.requested_advanced_controls();
3857 let repo_supports_advanced = match &schema.source {
3858 ModelSource::Mlx { hf_repo, .. } => {
3859 hf_repo.to_ascii_lowercase().contains("qwen3-tts")
3860 }
3861 _ => false,
3862 };
3863 if !requested.is_empty() && !repo_supports_advanced {
3864 if req.strict_capabilities {
3865 return Err(InferenceError::InferenceFailed(format!(
3866 "model {name} does not support Qwen3-TTS advanced controls {requested:?}; \
3867 route to a Qwen3-TTS model or set strict_capabilities = false to degrade",
3868 name = schema.name,
3869 )));
3870 }
3871 tracing::warn!(
3872 model = %schema.name,
3873 fields = ?requested,
3874 "Qwen3-TTS advanced controls set on non-Qwen3-TTS backend — ignored \
3875 (set strict_capabilities=true to error instead)"
3876 );
3877 }
3878
3879 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
3881 {
3882 if repo_supports_advanced && !requested.is_empty() {
3888 if req.strict_capabilities {
3889 return Err(InferenceError::InferenceFailed(format!(
3890 "native MLX TTS backend does not yet implement Qwen3-TTS advanced \
3891 controls {requested:?}; run on non-Apple-Silicon to use the Python \
3892 mlx-audio fallback, or set strict_capabilities = false"
3893 )));
3894 }
3895 tracing::warn!(
3896 model = %schema.name,
3897 fields = ?requested,
3898 "Qwen3-TTS advanced controls are not yet implemented in the native MLX TTS \
3899 backend; synthesizing without cloning/voice-design"
3900 );
3901 }
3902 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3903 let size = backend_cache::estimate_model_size(&model_dir);
3904 let cache = Arc::clone(&self.kokoro_cache);
3905 let key = schema.id.clone();
3906 let handle = cache.get_or_load(&key, size, || {
3907 backend::mlx_kokoro::KokoroBackend::load(&model_dir)
3908 })?;
3909
3910 let output_path = req.output_path.clone().unwrap_or_else(|| {
3911 let dir = std::env::temp_dir().join("car_tts");
3912 let _ = std::fs::create_dir_all(&dir);
3913 dir.join("output.wav").display().to_string()
3914 });
3915 let voice = req.voice.as_deref().unwrap_or("af_heart").to_string();
3916 let text = req.text.clone();
3917 let op = tokio::task::spawn_blocking(move || -> Result<PathBuf, InferenceError> {
3918 let mut guard = handle.lock().map_err(|_| {
3919 InferenceError::InferenceFailed("kokoro backend mutex poisoned".into())
3920 })?;
3921 guard
3922 .synthesize(&text, Some(&voice), Path::new(&output_path))
3923 .map_err(|e| InferenceError::InferenceFailed(format!("native TTS: {e}")))
3924 })
3925 .await
3926 .map_err(|e| InferenceError::InferenceFailed(format!("kokoro task join: {e}")))??;
3927
3928 let final_path =
3929 materialize_audio_output(&op, req.output_path.as_deref(), &req.format)?;
3930 return Ok(SynthesizeResult {
3931 audio_path: final_path.display().to_string(),
3932 media_type: media_type_for_format(&req.format),
3933 model_used: Some(schema.name.clone()),
3934 voice_used: req.voice.clone(),
3935 });
3936 }
3937
3938 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3940 {
3941 let runtime = self.ensure_speech_runtime().await?;
3942 let primary_hf_repo = match &schema.source {
3943 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3944 _ => unreachable!(),
3945 };
3946 let (produced, model_used) = match self
3947 .synthesize_local_mlx_repo(&runtime, &primary_hf_repo, schema.name.as_str(), req)
3948 .await
3949 {
3950 Ok(result) => result,
3951 Err(primary_err)
3952 if primary_hf_repo == "mlx-community/Kokoro-82M-6bit"
3953 && kokoro_runtime_fallback_enabled() =>
3954 {
3955 let fallback_repo = "mlx-community/Kokoro-82M-bf16";
3956 let fallback_name = "Kokoro-82M-bf16";
3957 match self
3958 .synthesize_local_mlx_repo(&runtime, fallback_repo, fallback_name, req)
3959 .await
3960 {
3961 Ok(result) => result,
3962 Err(fallback_err) => {
3963 return Err(InferenceError::InferenceFailed(format!(
3964 "{primary_err}; fallback {fallback_name} also failed: {fallback_err}"
3965 )));
3966 }
3967 }
3968 }
3969 Err(err) => return Err(err),
3970 };
3971 let final_path =
3972 materialize_audio_output(&produced, req.output_path.as_deref(), &req.format)?;
3973
3974 Ok(SynthesizeResult {
3975 audio_path: final_path.display().to_string(),
3976 media_type: media_type_for_format(&req.format),
3977 model_used: Some(model_used),
3978 voice_used: req.voice.clone(),
3979 })
3980 }
3981 }
3982
3983 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
3984 async fn synthesize_local_mlx_repo(
3985 &self,
3986 runtime: &SpeechRuntime,
3987 hf_repo: &str,
3988 model_name: &str,
3989 req: &SynthesizeRequest,
3990 ) -> Result<(PathBuf, String), InferenceError> {
3991 let output_dir = temp_work_dir("tts")?;
3992 let mut args = vec![
3993 "--model".to_string(),
3994 hf_repo.to_string(),
3995 "--text".to_string(),
3996 req.text.clone(),
3997 "--output_path".to_string(),
3998 output_dir.display().to_string(),
3999 ];
4000 if let Some(voice) = &req.voice {
4001 args.push("--voice".to_string());
4002 args.push(voice.clone());
4003 }
4004 if let Some(speed) = req.speed {
4005 args.push("--speed".to_string());
4006 args.push(speed.to_string());
4007 }
4008 let repo_lower = hf_repo.to_ascii_lowercase();
4009 if repo_lower.contains("kokoro") {
4010 args.push("--lang_code".to_string());
4011 args.push(kokoro_lang_code(req.language.as_deref()).to_string());
4012 } else if let Some(language) = &req.language {
4013 args.push("--lang_code".to_string());
4014 args.push(normalize_lang_code(language));
4015 }
4016
4017 if repo_lower.contains("qwen3-tts") {
4023 if let Some(ref_audio) = &req.reference_audio_path {
4024 args.push("--ref_audio".to_string());
4025 args.push(ref_audio.clone());
4026 }
4027 if let Some(ref_text) = &req.reference_text {
4028 args.push("--ref_text".to_string());
4029 args.push(ref_text.clone());
4030 }
4031 if let Some(instruct) = &req.voice_instruction {
4032 args.push("--instruct".to_string());
4033 args.push(instruct.clone());
4034 }
4035 }
4036
4037 let output = if repo_lower.contains("kokoro") {
4038 let device = std::env::var("CAR_SPEECH_KOKORO_DEVICE")
4039 .or_else(|_| std::env::var("CAR_SPEECH_MLX_DEVICE"))
4040 .unwrap_or_else(|_| "cpu".to_string());
4041 let extra_env = vec![
4042 ("MLX_DEVICE".to_string(), device),
4044 ("PYTORCH_ENABLE_MPS_FALLBACK".to_string(), "1".to_string()),
4046 ];
4047 run_mlx_audio_command_with_env(runtime, "tts.generate", &args, &extra_env).await?
4048 } else {
4049 run_mlx_audio_command(runtime, "tts.generate", &args).await?
4050 };
4051 let produced = find_audio_file(&output_dir)?.ok_or_else(|| {
4052 let hint = if repo_lower.contains("kokoro") {
4053 ". Kokoro models may crash on GPU — try CAR_SPEECH_KOKORO_DEVICE=cpu or use the default Qwen3-TTS model"
4054 } else {
4055 ""
4056 };
4057 InferenceError::InferenceFailed(format!(
4058 "mlx-audio synthesis produced no audio file: {}{}",
4059 output.stderr, hint
4060 ))
4061 })?;
4062 Ok((produced, model_name.to_string()))
4063 }
4064
4065
4066 async fn transcribe_elevenlabs(
4067 &self,
4068 schema: &ModelSchema,
4069 req: &TranscribeRequest,
4070 ) -> Result<TranscribeResult, InferenceError> {
4071 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4072 let file_name = Path::new(&req.audio_path)
4073 .file_name()
4074 .and_then(|f| f.to_str())
4075 .unwrap_or("audio.wav")
4076 .to_string();
4077 let audio_bytes = tokio::fs::read(&req.audio_path).await?;
4078 let file_part = Part::bytes(audio_bytes).file_name(file_name);
4079 let mut form = Form::new()
4080 .text("model_id", schema.name.clone())
4081 .part("file", file_part);
4082 if let Some(language) = &req.language {
4083 form = form.text("language_code", language.clone());
4084 }
4085
4086 let resp = self
4087 .remote_backend
4088 .client
4089 .post(format!(
4090 "{}/v1/speech-to-text",
4091 endpoint.trim_end_matches('/')
4092 ))
4093 .header("xi-api-key", api_key)
4094 .multipart(form)
4095 .send()
4096 .await
4097 .map_err(|e| {
4098 InferenceError::InferenceFailed(format!("ElevenLabs STT request failed: {e}"))
4099 })?;
4100 let status = resp.status();
4101 let body = resp.text().await.map_err(|e| {
4102 InferenceError::InferenceFailed(format!("read ElevenLabs STT body: {e}"))
4103 })?;
4104 if !status.is_success() {
4105 return Err(InferenceError::InferenceFailed(format!(
4106 "ElevenLabs STT returned {status}: {body}"
4107 )));
4108 }
4109 let payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
4110 InferenceError::InferenceFailed(format!("parse ElevenLabs STT response: {e}"))
4111 })?;
4112 let text = payload
4113 .get("text")
4114 .and_then(|v| v.as_str())
4115 .map(str::to_string)
4116 .ok_or_else(|| {
4117 InferenceError::InferenceFailed("ElevenLabs STT response missing text".into())
4118 })?;
4119
4120 Ok(TranscribeResult {
4121 text,
4122 model_used: Some(schema.name.clone()),
4123 language: payload
4124 .get("language_code")
4125 .and_then(|v| v.as_str())
4126 .map(str::to_string),
4127 words: Vec::new(),
4128 })
4129 }
4130
4131 async fn synthesize_elevenlabs(
4132 &self,
4133 schema: &ModelSchema,
4134 req: &SynthesizeRequest,
4135 ) -> Result<SynthesizeResult, InferenceError> {
4136 let requested = req.requested_advanced_controls();
4140 if !requested.is_empty() {
4141 if req.strict_capabilities {
4142 return Err(InferenceError::InferenceFailed(format!(
4143 "ElevenLabs backend does not support Qwen3-TTS advanced controls \
4144 {requested:?}; route to a Qwen3-TTS model or set strict_capabilities = false"
4145 )));
4146 }
4147 tracing::warn!(
4148 model = %schema.name,
4149 fields = ?requested,
4150 "Qwen3-TTS advanced controls ignored by ElevenLabs backend"
4151 );
4152 }
4153 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4154 let voice_id = req
4155 .voice
4156 .clone()
4157 .unwrap_or_else(|| "JBFqnCBsd6RMkjVDRZzb".to_string());
4158 let output_format = elevenlabs_output_format(&req.format);
4159 let url = format!(
4160 "{}/v1/text-to-speech/{}?output_format={}",
4161 endpoint.trim_end_matches('/'),
4162 voice_id,
4163 output_format
4164 );
4165
4166 let mut body = serde_json::json!({
4167 "text": req.text,
4168 "model_id": schema.name,
4169 });
4170 if let Some(language) = &req.language {
4171 body["language_code"] = serde_json::Value::String(language.clone());
4172 }
4173
4174 let resp = self
4175 .remote_backend
4176 .client
4177 .post(url)
4178 .header("xi-api-key", api_key)
4179 .header("Content-Type", "application/json")
4180 .json(&body)
4181 .send()
4182 .await
4183 .map_err(|e| {
4184 InferenceError::InferenceFailed(format!("ElevenLabs TTS request failed: {e}"))
4185 })?;
4186 let status = resp.status();
4187 let audio = resp.bytes().await.map_err(|e| {
4188 InferenceError::InferenceFailed(format!("read ElevenLabs TTS body: {e}"))
4189 })?;
4190 if !status.is_success() {
4191 let err_body = String::from_utf8_lossy(&audio);
4192 return Err(InferenceError::InferenceFailed(format!(
4193 "ElevenLabs TTS returned {status}: {err_body}"
4194 )));
4195 }
4196
4197 let final_path = requested_or_temp_output(req.output_path.as_deref(), &req.format)?;
4198 ensure_parent_dir(&final_path)?;
4199 tokio::fs::write(&final_path, &audio).await?;
4200
4201 Ok(SynthesizeResult {
4202 audio_path: final_path.display().to_string(),
4203 media_type: media_type_for_format(&req.format),
4204 model_used: Some(schema.name.clone()),
4205 voice_used: Some(voice_id),
4206 })
4207 }
4208}
4209
4210#[derive(Default)]
4211struct ProviderAccumulator {
4212 configured: bool,
4213 local_models: usize,
4214 remote_models: usize,
4215 available_models: usize,
4216 capabilities: std::collections::HashSet<ModelCapability>,
4217}
4218
4219#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4223struct CommandOutput {
4224 stdout: String,
4225 stderr: String,
4226}
4227
4228#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4229#[derive(Debug, Clone)]
4230struct SpeechRuntime {
4231 root: PathBuf,
4232 python: PathBuf,
4233 stt_program: PathBuf,
4234 tts_program: PathBuf,
4235}
4236
4237#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4238impl SpeechRuntime {
4239 fn new(root: PathBuf) -> Self {
4240 let bin_dir = root.join("bin");
4241 Self {
4242 root,
4243 python: bin_dir.join("python"),
4244 stt_program: bin_dir.join("mlx_audio.stt.generate"),
4245 tts_program: bin_dir.join("mlx_audio.tts.generate"),
4246 }
4247 }
4248
4249 fn is_ready(&self) -> bool {
4250 self.python.exists() && self.stt_program.exists() && self.tts_program.exists()
4251 }
4252
4253 fn command_for(&self, subcommand: &str) -> Result<&Path, InferenceError> {
4254 match subcommand {
4255 "stt.generate" => Ok(&self.stt_program),
4256 "tts.generate" => Ok(&self.tts_program),
4257 _ => Err(InferenceError::InferenceFailed(format!(
4258 "unknown speech subcommand: {subcommand}"
4259 ))),
4260 }
4261 }
4262}
4263
4264
4265#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4266async fn run_mlx_audio_command(
4267 runtime: &SpeechRuntime,
4268 subcommand: &str,
4269 args: &[String],
4270) -> Result<CommandOutput, InferenceError> {
4271 run_mlx_audio_command_with_env(runtime, subcommand, args, &[]).await
4272}
4273
4274#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4275async fn run_mlx_audio_command_with_env(
4276 runtime: &SpeechRuntime,
4277 subcommand: &str,
4278 args: &[String],
4279 envs: &[(String, String)],
4280) -> Result<CommandOutput, InferenceError> {
4281 let program = runtime.command_for(subcommand)?;
4282 let mut command = Command::new(program);
4283 command.args(args);
4284 for (key, value) in envs {
4285 command.env(key, value);
4286 }
4287 let output = command.output().await.map_err(|err| {
4288 InferenceError::InferenceFailed(format!("{}: {err}", program.display()))
4289 })?;
4290
4291 if output.status.success() {
4292 Ok(CommandOutput {
4293 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
4294 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
4295 })
4296 } else {
4297 Err(InferenceError::InferenceFailed(format!(
4298 "{} exited with {}: {}",
4299 program.display(),
4300 output.status,
4301 String::from_utf8_lossy(&output.stderr)
4302 )))
4303 }
4304}
4305
4306
4307#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4308async fn bootstrap_speech_runtime(runtime: &SpeechRuntime) -> Result<(), InferenceError> {
4309 std::fs::create_dir_all(&runtime.root)?;
4310 let python = select_speech_python()?;
4311
4312 run_command(
4313 "uv",
4314 &[
4315 "venv".to_string(),
4316 "--python".to_string(),
4317 python,
4318 runtime.root.display().to_string(),
4319 ],
4320 )
4321 .await?;
4322
4323 run_command(
4324 "uv",
4325 &[
4326 "pip".to_string(),
4327 "install".to_string(),
4328 "--python".to_string(),
4329 runtime.python.display().to_string(),
4330 speech_runtime_mlx_audio_spec(),
4331 "misaki[en]".to_string(),
4332 speech_runtime_spacy_model_spec(),
4333 ],
4334 )
4335 .await?;
4336
4337 Ok(())
4338}
4339
4340
4341#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4342async fn run_command(program: &str, args: &[String]) -> Result<(), InferenceError> {
4343 let output = Command::new(program)
4344 .args(args)
4345 .output()
4346 .await
4347 .map_err(|err| InferenceError::InferenceFailed(format!("{program}: {err}")))?;
4348
4349 if output.status.success() {
4350 Ok(())
4351 } else {
4352 Err(InferenceError::InferenceFailed(format!(
4353 "{} exited with {}: {}",
4354 program,
4355 output.status,
4356 String::from_utf8_lossy(&output.stderr)
4357 )))
4358 }
4359}
4360
4361#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn select_speech_python() -> Result<String, InferenceError> {
    // Delegate to detect_speech_python() so the CAR_SPEECH_PYTHON override and
    // the interpreter candidate list are defined in exactly one place.
    detect_speech_python().ok_or_else(|| {
        InferenceError::InferenceFailed(
            "no supported Python found for managed speech runtime (tried python3.13, python3.12, python3.11)".into(),
        )
    })
}
4379
4380#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4381fn detect_speech_python() -> Option<String> {
4382 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4383 if !path.trim().is_empty() {
4384 return Some(path);
4385 }
4386 }
4387
4388 ["python3.13", "python3.12", "python3.11"]
4389 .into_iter()
4390 .find(|candidate| command_in_path(candidate))
4391 .map(str::to_string)
4392}
4393
4394#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4395fn speech_runtime_root_from_models_dir(_models_dir: &Path) -> PathBuf {
4396 if let Ok(path) = std::env::var("CAR_SPEECH_RUNTIME_DIR") {
4397 if !path.trim().is_empty() {
4398 return PathBuf::from(path);
4399 }
4400 }
4401
4402 std::env::var("HOME")
4403 .map(PathBuf::from)
4404 .unwrap_or_else(|_| PathBuf::from("."))
4405 .join(".car")
4406 .join("speech-runtime")
4407}
4408
4409#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4410fn command_in_path(name: &str) -> bool {
4411 std::env::var_os("PATH")
4412 .map(|paths| {
4413 std::env::split_paths(&paths).any(|dir| {
4414 let path = dir.join(name);
4415 path.exists() && path.is_file()
4416 })
4417 })
4418 .unwrap_or(false)
4419}
4420
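/// Returns true when the speech model is usable without a fresh download or
/// sign-in: a local MLX repo with a cached Hugging Face snapshot, or a
/// proprietary model whose API-key/bearer env var is set. All other sources
/// (including OAuth PKCE auth) report false here.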
4421fn speech_model_cached(schema: &ModelSchema) -> bool {
4422 match &schema.source {
4423 ModelSource::Mlx { hf_repo, .. } => {
4424 huggingface_repo_has_snapshot(hf_repo)
4425 }
4426 ModelSource::Proprietary { auth, .. } => match auth {
4427 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4428 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4429 ProprietaryAuth::OAuth2Pkce { .. } => false,
4430 },
4431 _ => false,
4432 }
4433}
4434
fn remove_huggingface_repo_cache(repo_id: &str) -> Result<(), InferenceError> {
    // Resolve the cache directory through huggingface_repo_dir() so the
    // HF_HOME / ~/.cache/huggingface hub layout is only encoded in one place.
    let repo_dir = huggingface_repo_dir(repo_id);
    if repo_dir.exists() {
        std::fs::remove_dir_all(repo_dir)?;
    }
    Ok(())
}
4453
4454fn model_source_configured(schema: &ModelSchema) -> bool {
4455 match &schema.source {
4456 ModelSource::RemoteApi {
4457 api_key_env,
4458 api_key_envs,
4459 ..
4460 } => {
4461 std::env::var(api_key_env).is_ok()
4462 || api_key_envs
4463 .iter()
4464 .any(|env_var| std::env::var(env_var).is_ok())
4465 }
4466 ModelSource::Proprietary { auth, .. } => match auth {
4467 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4468 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4469 ProprietaryAuth::OAuth2Pkce { .. } => false,
4470 },
4471 ModelSource::VllmMlx { .. } => std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available,
4472 ModelSource::Ollama { .. } => schema.available,
4473 ModelSource::Mlx { .. } | ModelSource::Local { .. } => {
4474 true
4475 }
4476 ModelSource::AppleFoundationModels { .. } => schema.available,
4477 ModelSource::Delegated { .. } => true,
4482 }
4483}
4484
4485fn all_model_capabilities() -> [ModelCapability; 13] {
4486 [
4487 ModelCapability::Generate,
4488 ModelCapability::Embed,
4489 ModelCapability::Classify,
4490 ModelCapability::Code,
4491 ModelCapability::Reasoning,
4492 ModelCapability::Summarize,
4493 ModelCapability::ToolUse,
4494 ModelCapability::MultiToolCall,
4495 ModelCapability::Vision,
4496 ModelCapability::SpeechToText,
4497 ModelCapability::TextToSpeech,
4498 ModelCapability::ImageGeneration,
4499 ModelCapability::VideoGeneration,
4500 ]
4501}
4502
4503fn sort_capabilities(mut capabilities: Vec<ModelCapability>) -> Vec<ModelCapability> {
4504 capabilities.sort_by_key(|capability| {
4505 all_model_capabilities()
4506 .iter()
4507 .position(|candidate| candidate == capability)
4508 .unwrap_or(usize::MAX)
4509 });
4510 capabilities
4511}
4512
4513fn speech_model_source_label(schema: &ModelSchema) -> String {
4514 match &schema.source {
4515 ModelSource::Mlx { hf_repo, .. } => format!("mlx:{hf_repo}"),
4516 ModelSource::Proprietary {
4517 provider, endpoint, ..
4518 } => format!("proprietary:{provider}:{endpoint}"),
4519 ModelSource::RemoteApi { endpoint, .. } => format!("remote:{endpoint}"),
4520 ModelSource::Local { hf_repo, .. } => format!("local:{hf_repo}"),
4521 ModelSource::VllmMlx {
4522 endpoint,
4523 model_name,
4524 } => format!("vllm-mlx:{endpoint}:{model_name}"),
4525 ModelSource::Ollama { model_tag, host } => format!("ollama:{host}:{model_tag}"),
4526 ModelSource::AppleFoundationModels { use_case } => {
4527 format!("apple-foundation:{}", use_case.as_deref().unwrap_or("default"))
4528 }
4529 ModelSource::Delegated { hint } => {
4530 format!("delegated:{}", hint.as_deref().unwrap_or("(none)"))
4531 }
4532 }
4533}
4534
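/// Builds the chat-style judging prompt for the reranker: a fixed system
/// instruction plus the instruct/query/document turns, ending with an empty
/// think block so the model answers with a bare "yes" or "no".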
4535fn rerank_prompt(instruction: &str, query: &str, document: &str) -> String {
4543 const SYSTEM: &str = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".";
4544 format!(
4545 "<|im_start|>system\n{SYSTEM}<|im_end|>\n\
4546 <|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n\
4547 <|im_start|>assistant\n<think>\n\n</think>\n\n"
4548 )
4549}
4550
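/// Maps a rerank judge completion to a relevance score by scanning the first
/// few normalized tokens for "yes" (1.0) or "no" (0.0); anything else falls
/// back to the neutral 0.5 with a warning.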
4551fn score_from_rerank_output(text: &str, model_name: &str) -> f32 {
4557 let normalized: String = text
4562 .to_ascii_lowercase()
4563 .chars()
4564 .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
4565 .collect();
4566 for tok in normalized.split_ascii_whitespace().take(5) {
4567 match tok {
4568 "yes" => return 1.0,
4569 "no" => return 0.0,
4570 _ => continue,
4571 }
4572 }
4573 tracing::warn!(
4574 model = %model_name,
4575 output = %text,
4576 "rerank: first tokens contain neither `yes` nor `no`; returning neutral 0.5"
4577 );
4578 0.5
4579}
4580
4581fn default_speech_voice(schema: &ModelSchema) -> Option<String> {
4582 if schema.provider == "elevenlabs" {
4583 Some("JBFqnCBsd6RMkjVDRZzb".to_string())
4584 } else if schema.name == "Kokoro-82M-6bit" || schema.name == "Kokoro-82M-bf16" {
4585 Some("af_heart".to_string())
4586 } else if schema.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
4587 Some("Chelsie".to_string())
4588 } else {
4589 None
4590 }
4591}
4592
4593fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
4594 find_latest_huggingface_snapshot(repo_id).is_some()
4595}
4596
4597fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
4598 let cache_root = std::env::var("HF_HOME")
4599 .map(PathBuf::from)
4600 .unwrap_or_else(|_| {
4601 std::env::var("HOME")
4602 .map(PathBuf::from)
4603 .unwrap_or_else(|_| PathBuf::from("."))
4604 .join(".cache")
4605 .join("huggingface")
4606 })
4607 .join("hub");
4608 cache_root.join(format!("models--{}", repo_id.replace('/', "--")))
4609}
4610
4611fn find_latest_huggingface_snapshot(repo_id: &str) -> Option<PathBuf> {
4612 let snapshots = huggingface_repo_dir(repo_id).join("snapshots");
4613 std::fs::read_dir(snapshots)
4614 .ok()?
4615 .filter_map(Result::ok)
4616 .map(|entry| entry.path())
4617 .find(|path| path.is_dir() && snapshot_looks_ready(path))
4618}
4619
4620fn snapshot_looks_ready(path: &Path) -> bool {
4621 if path.join("config.json").exists() || path.join("model_index.json").exists() {
4622 return true;
4623 }
4624 snapshot_contains_ext(path, "safetensors")
4625}
4626
4627fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
4628 let Ok(entries) = std::fs::read_dir(root) else {
4629 return false;
4630 };
4631 entries.filter_map(Result::ok).any(|entry| {
4632 let path = entry.path();
4633 if path.is_dir() {
4634 snapshot_contains_ext(&path, ext)
4635 } else {
4636 path.extension()
4637 .and_then(|value| value.to_str())
4638 .map(|value| value.eq_ignore_ascii_case(ext))
4639 .unwrap_or(false)
4640 }
4641 })
4642}
4643
4644fn count_files_recursive(root: &Path) -> usize {
4645 let Ok(entries) = std::fs::read_dir(root) else {
4646 return 0;
4647 };
4648 entries
4649 .filter_map(Result::ok)
4650 .map(|entry| entry.path())
4651 .map(|path| {
4652 if path.is_dir() {
4653 count_files_recursive(&path)
4654 } else if path.is_file() {
4655 1
4656 } else {
4657 0
4658 }
4659 })
4660 .sum()
4661}
4662
4663async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
4664 let api = hf_hub::api::tokio::ApiBuilder::from_env()
4665 .with_progress(false)
4666 .build()
4667 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
4668 let repo = api.model(repo_id.to_string());
4669 let info = repo
4670 .info()
4671 .await
4672 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
4673
4674 let snapshot_path = huggingface_repo_dir(repo_id)
4675 .join("snapshots")
4676 .join(&info.sha);
4677 let mut downloaded = 0usize;
4678 for sibling in &info.siblings {
4679 let local_path = snapshot_path.join(&sibling.rfilename);
4680 if local_path.exists() {
4681 downloaded += 1;
4682 continue;
4683 }
4684 repo.download(&sibling.rfilename).await.map_err(|e| {
4685 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
4686 })?;
4687 downloaded += 1;
4688 }
4689
4690 Ok((snapshot_path, downloaded))
4691}
4692
4693fn temp_work_dir(prefix: &str) -> Result<PathBuf, InferenceError> {
4694 let unique = SystemTime::now()
4695 .duration_since(UNIX_EPOCH)
4696 .map_err(|e| InferenceError::InferenceFailed(format!("clock error: {e}")))?
4697 .as_nanos();
4698 let dir = std::env::temp_dir().join(format!("car-inference-{prefix}-{unique}"));
4699 std::fs::create_dir_all(&dir)?;
4700 Ok(dir)
4701}
4702
4703fn ensure_parent_dir(path: &Path) -> Result<(), InferenceError> {
4704 if let Some(parent) = path.parent() {
4705 std::fs::create_dir_all(parent)?;
4706 }
4707 Ok(())
4708}
4709
4710fn requested_or_temp_output(
4711 output_path: Option<&str>,
4712 format: &str,
4713) -> Result<PathBuf, InferenceError> {
4714 if let Some(path) = output_path {
4715 return Ok(PathBuf::from(path));
4716 }
4717 let dir = temp_work_dir("audio-out")?;
4718 Ok(dir.join(format!("speech.{format}")))
4719}
4720
4721fn requested_or_temp_media_output(
4722 output_path: Option<&str>,
4723 format: &str,
4724 stem: &str,
4725) -> Result<PathBuf, InferenceError> {
4726 if let Some(path) = output_path {
4727 return Ok(PathBuf::from(path));
4728 }
4729 let dir = temp_work_dir(&format!("{stem}-out"))?;
4730 Ok(dir.join(format!("{stem}.{format}")))
4731}
4732
4733fn materialize_audio_output(
4734 produced: &Path,
4735 requested: Option<&str>,
4736 format: &str,
4737) -> Result<PathBuf, InferenceError> {
4738 if let Some(path) = requested {
4739 let dest = PathBuf::from(path);
4740 ensure_parent_dir(&dest)?;
4741 std::fs::copy(produced, &dest)?;
4742 Ok(dest)
4743 } else {
4744 let dest = requested_or_temp_output(None, format)?;
4745 ensure_parent_dir(&dest)?;
4746 std::fs::copy(produced, &dest)?;
4747 Ok(dest)
4748 }
4749}
4750
4751fn materialize_binary_output(
4752 produced: &Path,
4753 requested: Option<&str>,
4754 format: &str,
4755 stem: &str,
4756) -> Result<PathBuf, InferenceError> {
4757 let dest = requested_or_temp_media_output(requested, format, stem)?;
4758 ensure_parent_dir(&dest)?;
4759 std::fs::copy(produced, &dest)?;
4760 Ok(dest)
4761}
4762
4763fn find_generated_file(
4764 root: &Path,
4765 extensions: &[&str],
4766) -> Result<Option<PathBuf>, InferenceError> {
4767 let entries = std::fs::read_dir(root)?;
4768 let mut candidates: Vec<PathBuf> = entries
4769 .filter_map(Result::ok)
4770 .map(|entry| entry.path())
4771 .filter(|path| {
4772 path.is_file()
4773 && path
4774 .extension()
4775 .and_then(|ext| ext.to_str())
4776 .map(|ext| extensions.iter().any(|candidate| candidate.eq_ignore_ascii_case(ext)))
4777 .unwrap_or(false)
4778 })
4779 .collect();
4780 candidates.sort();
4781 Ok(candidates.pop())
4782}
4783
4784fn media_type_for_image_format(format: &str) -> String {
4785 match format.to_ascii_lowercase().as_str() {
4786 "jpg" | "jpeg" => "image/jpeg".to_string(),
4787 "webp" => "image/webp".to_string(),
4788 _ => "image/png".to_string(),
4789 }
4790}
4791
4792fn media_type_for_video_format(format: &str) -> String {
4793 match format.to_ascii_lowercase().as_str() {
4794 "mov" => "video/quicktime".to_string(),
4795 "gif" => "image/gif".to_string(),
4796 _ => "video/mp4".to_string(),
4797 }
4798}
4799
4800fn read_transcription_result(output_prefix: &Path) -> Result<Option<String>, InferenceError> {
4801 let candidates = [
4802 output_prefix.with_extension("json"),
4803 output_prefix.to_path_buf(),
4804 ];
4805
4806 for path in candidates {
4807 if path.exists() {
4808 let contents = std::fs::read_to_string(path)?;
4809 if let Some(text) = extract_text_from_payload(&contents) {
4810 return Ok(Some(text));
4811 }
4812 }
4813 }
4814
4815 Ok(None)
4816}
4817
4818fn extract_text_from_payload(payload: &str) -> Option<String> {
4819 let value: serde_json::Value = serde_json::from_str(payload).ok()?;
4820 if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
4821 return Some(text.to_string());
4822 }
4823 if let Some(transcripts) = value.get("transcripts").and_then(|v| v.as_array()) {
4824 let joined = transcripts
4825 .iter()
4826 .filter_map(|item| item.get("text").and_then(|v| v.as_str()))
4827 .collect::<Vec<_>>()
4828 .join("\n");
4829 if !joined.is_empty() {
4830 return Some(joined);
4831 }
4832 }
4833 if let Some(items) = value.as_array() {
4834 let joined = items
4835 .iter()
4836 .filter_map(|item| {
4837 item.get("text")
4838 .or_else(|| item.get("Content"))
4839 .and_then(|v| v.as_str())
4840 })
4841 .collect::<Vec<_>>()
4842 .join(" ");
4843 if !joined.is_empty() {
4844 return Some(joined);
4845 }
4846 }
4847 None
4848}
4849
4850fn find_audio_file(output_dir: &Path) -> Result<Option<PathBuf>, InferenceError> {
4851 let mut audio_files = Vec::new();
4852 collect_audio_files(output_dir, &mut audio_files)?;
4853 audio_files.sort();
4854 Ok(audio_files.into_iter().next())
4855}
4856
4857fn collect_audio_files(dir: &Path, audio_files: &mut Vec<PathBuf>) -> Result<(), InferenceError> {
4858 for entry in std::fs::read_dir(dir)? {
4859 let path = entry?.path();
4860 if path.is_dir() {
4861 collect_audio_files(&path, audio_files)?;
4862 } else if matches!(
4863 path.extension().and_then(|ext| ext.to_str()),
4864 Some("wav" | "mp3" | "flac" | "pcm" | "m4a")
4865 ) {
4866 audio_files.push(path);
4867 }
4868 }
4869 Ok(())
4870}
4871
4872fn media_type_for_format(format: &str) -> String {
4873 match format.to_ascii_lowercase().as_str() {
4874 "mp3" => "audio/mpeg".to_string(),
4875 "flac" => "audio/flac".to_string(),
4876 "pcm" => "audio/L16".to_string(),
4877 "m4a" => "audio/mp4".to_string(),
4878 _ => "audio/wav".to_string(),
4879 }
4880}
4881
4882#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4883fn kokoro_lang_code(language: Option<&str>) -> &'static str {
4884 match language.unwrap_or("en").to_ascii_lowercase().as_str() {
4885 "en-gb" | "british" | "british english" => "b",
4886 "ja" | "japanese" => "j",
4887 "zh" | "zh-cn" | "mandarin" | "chinese" => "z",
4888 "es" | "spanish" => "e",
4889 "fr" | "french" => "f",
4890 _ => "a",
4891 }
4892}
4893
fn normalize_lang_code(language: &str) -> String {
    match language.to_ascii_lowercase().as_str() {
        "english" | "en-us" | "en_us" | "en" => "en".to_string(),
        "spanish" | "es" => "es".to_string(),
        "french" | "fr" => "fr".to_string(),
        "japanese" | "ja" => "ja".to_string(),
        "chinese" | "mandarin" | "zh" => "zh".to_string(),
        _ => "en".to_string(),
    }
}
4907
4908fn elevenlabs_auth(schema: &ModelSchema) -> Result<(String, String), InferenceError> {
4909 match &schema.source {
4910 ModelSource::Proprietary {
4911 endpoint,
4912 auth: schema::ProprietaryAuth::ApiKeyEnv { env_var },
4913 ..
4914 } => {
4915 let key = car_secrets::resolve_env_or_keychain(env_var).ok_or_else(|| {
4916 InferenceError::InferenceFailed(format!(
4917 "missing API key {env_var}; set the environment variable or \
4918 store it with `car secrets put {env_var}`"
4919 ))
4920 })?;
4921 Ok((endpoint.clone(), key))
4922 }
4923 _ => Err(InferenceError::InferenceFailed(format!(
4924 "model {} is not an ElevenLabs proprietary model",
4925 schema.id
4926 ))),
4927 }
4928}
4929
4930fn elevenlabs_output_format(format: &str) -> &'static str {
4931 match format.to_ascii_lowercase().as_str() {
4932 "mp3" => "mp3_44100_128",
4933 "pcm" => "pcm_16000",
4934 _ => "wav_44100",
4935 }
4936}
4937
4938fn benchmark_priors_paths(models_dir: &Path) -> Vec<PathBuf> {
4939 let mut paths = Vec::new();
4940
4941 let direct = models_dir.join("benchmark_priors.json");
4942 if !paths.contains(&direct) {
4943 paths.push(direct);
4944 }
4945
4946 if let Some(parent) = models_dir.parent() {
4947 let parent_path = parent.join("benchmark_priors.json");
4948 if !paths.contains(&parent_path) {
4949 paths.push(parent_path);
4950 }
4951 }
4952
4953 if let Some(path) = std::env::var_os("CAR_BENCHMARK_PRIORS_PATH") {
4954 let path = PathBuf::from(path);
4955 if !paths.contains(&path) {
4956 paths.push(path);
4957 }
4958 }
4959
4960 paths
4961}
4962
4963fn load_benchmark_prior_health(
4964 models_dir: &Path,
4965 schemas: &[ModelSchema],
4966) -> Vec<ModelBenchmarkPriorHealth> {
4967 let mut priors = std::collections::BTreeMap::new();
4968 for path in benchmark_priors_paths(models_dir) {
4969 let Ok(loaded) = routing_ext::load_benchmark_priors(&path) else {
4970 continue;
4971 };
4972 for (model_id, prior) in loaded {
4973 let model_name = schemas
4974 .iter()
4975 .find(|schema| schema.id == model_id)
4976 .map(|schema| schema.name.clone());
4977 priors.insert(
4978 model_id.clone(),
4979 ModelBenchmarkPriorHealth {
4980 model_id,
4981 model_name,
4982 overall_score: prior.overall_score,
4983 overall_latency_ms: prior.overall_latency_ms,
4984 task_scores: prior.task_scores,
4985 task_latency_ms: prior.task_latency_ms,
4986 source_path: path.clone(),
4987 },
4988 );
4989 }
4990 }
4991
4992 priors.into_values().collect()
4993}
4994
4995#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4996fn kokoro_runtime_fallback_enabled() -> bool {
4997 std::env::var("CAR_SPEECH_KOKORO_FALLBACK")
4998 .ok()
4999 .map(|value| !matches!(value.trim().to_ascii_lowercase().as_str(), "0" | "false" | "off"))
5000 .unwrap_or(true)
5001}
5002
5003#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
5004fn speech_runtime_mlx_audio_spec() -> String {
5005 std::env::var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC")
5006 .ok()
5007 .filter(|value| !value.trim().is_empty())
5008 .unwrap_or_else(|| "mlx-audio==0.4.2".to_string())
5009}
5010
5011#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
5012fn speech_runtime_spacy_model_spec() -> String {
5013 std::env::var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC")
5014 .ok()
5015 .filter(|value| !value.trim().is_empty())
5016 .unwrap_or_else(|| {
5017 "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl".to_string()
5018 })
5019}
5020
5021#[cfg(test)]
5022mod tests {
5023 use super::*;
5024 use tempfile::TempDir;
5025
5026 static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());
5029
5030 fn test_config(models_dir: PathBuf) -> InferenceConfig {
5031 InferenceConfig {
5032 models_dir,
5033 device: None,
5034 generation_model: "Qwen3-0.6B".into(),
5035 preferred_generation_model: None,
5036 embedding_model: "Qwen3-Embedding-0.6B".into(),
5037 preferred_embedding_model: None,
5038 classification_model: "Qwen3-0.6B".into(),
5039 preferred_classification_model: None,
5040 }
5041 }
5042
5043 #[tokio::test]
5044 async fn tokenize_rejects_known_remote_model_with_unsupported_mode() {
5045 let tmp = TempDir::new().unwrap();
5050 let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
5051 let remote_id = engine
5052 .list_schemas()
5053 .into_iter()
5054 .find(|s| !s.is_local())
5055 .map(|s| s.id)
5056 .expect(
5057 "built-in catalog should include at least one remote model schema",
5058 );
5059
5060 let err = engine
5061 .tokenize(&remote_id, "hello")
5062 .await
5063 .expect_err("remote tokenize must error");
5064 match err {
5065 InferenceError::UnsupportedMode { mode, backend, .. } => {
5066 assert_eq!(mode, "tokenize/detokenize");
5067 assert_eq!(backend, "remote");
5068 }
5069 other => panic!("expected UnsupportedMode, got {other:?}"),
5070 }
5071
5072 let err = engine
5073 .detokenize(&remote_id, &[1, 2, 3])
5074 .await
5075 .expect_err("remote detokenize must error");
5076 assert!(
5077 matches!(err, InferenceError::UnsupportedMode { .. }),
5078 "expected UnsupportedMode, got {err:?}"
5079 );
5080 }
5081
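    // Illustrative check of the pure helper `extract_text_from_payload`,
    // covering the payload shapes the transcription path already handles:
    // a top-level `text` field, a `transcripts` array, and a plain array of
    // `text` items. The literal strings are example inputs only.
    #[test]
    fn extract_text_from_payload_handles_known_shapes() {
        assert_eq!(
            extract_text_from_payload(r#"{"text": "hello"}"#),
            Some("hello".to_string())
        );
        assert_eq!(
            extract_text_from_payload(r#"{"transcripts": [{"text": "a"}, {"text": "b"}]}"#),
            Some("a\nb".to_string())
        );
        assert_eq!(
            extract_text_from_payload(r#"[{"text": "x"}, {"text": "y"}]"#),
            Some("x y".to_string())
        );
        assert_eq!(extract_text_from_payload("not json"), None);
    }
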
5082 #[test]
5083 fn engine_loads_benchmark_priors_on_startup() {
5084 let _env = ENV_MUTEX.lock().unwrap();
5085 let tmp = TempDir::new().unwrap();
5086 let priors_path = tmp.path().join("benchmark_priors.json");
5087 std::fs::write(
5088 &priors_path,
5089 serde_json::json!({
5090 "model_id": "qwen/qwen3-8b:q4_k_m",
5091 "overall_score": 0.88
5092 })
5093 .to_string(),
5094 )
5095 .unwrap();
5096
5097 unsafe {
5098 std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
5099 }
5100
5101 let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
5102 let tracker = engine.outcome_tracker.blocking_read();
5103 let profile = tracker
5104 .profile("qwen/qwen3-8b:q4_k_m")
5105 .expect("benchmark prior should create a profile");
5106 assert!((profile.ema_quality - 0.88).abs() < 0.01);
5107
5108 unsafe {
5109 std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
5110 }
5111 }
5112
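    // Illustrative sketch of how benchmark prior files are discovered: the
    // models dir is checked first, then its parent, then an optional
    // CAR_BENCHMARK_PRIORS_PATH override. The override file name used here is
    // arbitrary; only the lookup order is being documented.
    #[test]
    fn benchmark_priors_paths_cover_models_dir_parent_and_env_override() {
        let _env = ENV_MUTEX.lock().unwrap();
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");

        unsafe {
            std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
        }
        let paths = benchmark_priors_paths(&models_dir);
        assert_eq!(paths[0], models_dir.join("benchmark_priors.json"));
        assert!(paths.contains(&tmp.path().join("benchmark_priors.json")));

        let override_path = tmp.path().join("override_priors.json");
        unsafe {
            std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &override_path);
        }
        assert!(benchmark_priors_paths(&models_dir).contains(&override_path));
        unsafe {
            std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
        }
    }
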
5113 #[test]
5114 fn benchmark_priors_do_not_override_observed_profiles() {
5115 let _env = ENV_MUTEX.lock().unwrap();
5116 let tmp = TempDir::new().unwrap();
5117 let models_dir = tmp.path().join("models");
5118 std::fs::create_dir_all(&models_dir).unwrap();
5119
5120 let observed = vec![ModelProfile {
5121 model_id: "qwen/qwen3-8b:q4_k_m".into(),
5122 total_calls: 12,
5123 success_count: 3,
5124 fail_count: 9,
5125 total_latency_ms: 1200,
5126 total_input_tokens: 0,
5127 total_output_tokens: 0,
5128 task_stats: std::collections::HashMap::new(),
5129 ema_quality: 0.21,
5130 quality_per_1k_tokens: 0.0,
5131 updated_at: 1,
5132 }];
5133 std::fs::write(
5134 models_dir.join("outcome_profiles.json"),
5135 serde_json::to_string(&observed).unwrap(),
5136 )
5137 .unwrap();
5138
5139 let priors_path = tmp.path().join("benchmark_priors.json");
5140 std::fs::write(
5141 &priors_path,
5142 serde_json::json!({
5143 "model_id": "qwen/qwen3-8b:q4_k_m",
5144 "overall_score": 0.95
5145 })
5146 .to_string(),
5147 )
5148 .unwrap();
5149
5150 unsafe {
5151 std::env::set_var("CAR_BENCHMARK_PRIORS_PATH", &priors_path);
5152 }
5153
5154 let engine = InferenceEngine::new(test_config(models_dir));
5155 let tracker = engine.outcome_tracker.blocking_read();
5156 let profile = tracker
5157 .profile("qwen/qwen3-8b:q4_k_m")
5158 .expect("observed profile should remain present");
5159 assert!((profile.ema_quality - 0.21).abs() < 0.01);
5160 assert_eq!(profile.total_calls, 12);
5161
5162 unsafe {
5163 std::env::remove_var("CAR_BENCHMARK_PRIORS_PATH");
5164 }
5165 }
5166
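    // Illustrative check: `read_transcription_result` first looks for a
    // `<prefix>.json` sidecar and pulls the `text` field out of it, and
    // returns Ok(None) when no output exists at all.
    #[test]
    fn read_transcription_result_reads_json_sidecar() {
        let tmp = TempDir::new().unwrap();
        let prefix = tmp.path().join("transcript");
        std::fs::write(prefix.with_extension("json"), r#"{"text": "hello world"}"#).unwrap();

        let text = read_transcription_result(&prefix).unwrap();
        assert_eq!(text, Some("hello world".to_string()));

        let missing = tmp.path().join("absent");
        assert_eq!(read_transcription_result(&missing).unwrap(), None);
    }
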
5167 #[test]
5168 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
5169 fn speech_runtime_package_spec_defaults_and_overrides() {
5170 let _env = ENV_MUTEX.lock().unwrap();
5171 unsafe {
5172 std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
5173 }
5174 assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.2");
5175
5176 unsafe {
5177 std::env::set_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC", "mlx-audio==0.4.1");
5178 }
5179 assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.1");
5180
5181 unsafe {
5182 std::env::remove_var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
5183 }
5184 }
5185
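    // Illustrative check of the language-code normalisation used for
    // non-Kokoro TTS models: full names and regional variants collapse to a
    // small ISO set, and anything unrecognised falls back to "en".
    #[test]
    fn normalize_lang_code_maps_names_and_unknowns() {
        assert_eq!(normalize_lang_code("English"), "en");
        assert_eq!(normalize_lang_code("en-US"), "en");
        assert_eq!(normalize_lang_code("Mandarin"), "zh");
        assert_eq!(normalize_lang_code("ja"), "ja");
        assert_eq!(normalize_lang_code("klingon"), "en");
    }
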
5186 #[test]
5187 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
5188 fn speech_runtime_spacy_model_spec_defaults_and_overrides() {
5189 let _env = ENV_MUTEX.lock().unwrap();
5190 unsafe {
5191 std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
5192 }
5193 assert!(
5194 speech_runtime_spacy_model_spec().starts_with("en-core-web-sm @ https://github.com/")
5195 );
5196
5197 unsafe {
5198 std::env::set_var(
5199 "CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC",
5200 "en-core-web-sm==3.8.0",
5201 );
5202 }
5203 assert_eq!(
5204 speech_runtime_spacy_model_spec(),
5205 "en-core-web-sm==3.8.0"
5206 );
5207
5208 unsafe {
5209 std::env::remove_var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
5210 }
5211 }
5212
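    // Illustrative mapping check for the audio format helpers: the ElevenLabs
    // output profile and the MIME type both fall back to WAV for unknown
    // formats, and lookups are case-insensitive.
    #[test]
    fn audio_format_helpers_default_to_wav() {
        assert_eq!(elevenlabs_output_format("mp3"), "mp3_44100_128");
        assert_eq!(elevenlabs_output_format("PCM"), "pcm_16000");
        assert_eq!(elevenlabs_output_format("ogg"), "wav_44100");

        assert_eq!(media_type_for_format("mp3"), "audio/mpeg");
        assert_eq!(media_type_for_format("M4A"), "audio/mp4");
        assert_eq!(media_type_for_format("unknown"), "audio/wav");
    }
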
5213 #[test]
5214 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    fn kokoro_runtime_fallback_defaults_on() {
        // Guard the shared env like the other env-mutating tests in this module.
        let _env = ENV_MUTEX.lock().unwrap();
        unsafe {
5217 std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
5218 }
5219 assert!(kokoro_runtime_fallback_enabled());
5220
5221 unsafe {
5222 std::env::set_var("CAR_SPEECH_KOKORO_FALLBACK", "false");
5223 }
5224 assert!(!kokoro_runtime_fallback_enabled());
5225
5226 unsafe {
5227 std::env::remove_var("CAR_SPEECH_KOKORO_FALLBACK");
5228 }
5229 }
5230
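    // Illustrative check: `sort_capabilities` orders capabilities by their
    // position in `all_model_capabilities`, so Generate and Embed sort ahead
    // of Vision regardless of input order. `assert!` with `==` is used rather
    // than `assert_eq!` to avoid assuming a Debug impl on ModelCapability.
    #[test]
    fn sort_capabilities_follows_canonical_order() {
        let sorted = sort_capabilities(vec![
            ModelCapability::Vision,
            ModelCapability::Generate,
            ModelCapability::Embed,
        ]);
        assert!(
            sorted
                == vec![
                    ModelCapability::Generate,
                    ModelCapability::Embed,
                    ModelCapability::Vision,
                ]
        );
    }
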
5231 #[test]
5232 fn preferred_local_tts_wins_over_builtin_rank() {
5233 let tmp = TempDir::new().unwrap();
5234 let mut engine = InferenceEngine::new(test_config(tmp.path().join("models")));
5235 engine.set_speech_policy(SpeechPolicy {
5236 prefer_local: true,
5237 allow_remote_fallback: false,
5238 preferred_local_stt: None,
5239 preferred_local_tts: Some("Kokoro-82M-6bit".into()),
5240 preferred_remote_stt: None,
5241 preferred_remote_tts: None,
5242 });
5243
5244 let schema = engine
5245 .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
5246 .expect("preferred local TTS should resolve");
5247 assert_eq!(schema.name, "Kokoro-82M-6bit");
5248 }
5249
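    // Illustrative check of output-path resolution for synthesis: an explicit
    // path is used verbatim, while omitting it yields a temp file named
    // `speech.<format>`. The explicit path below is an arbitrary example.
    #[test]
    fn requested_or_temp_output_respects_explicit_path() {
        let explicit = requested_or_temp_output(Some("/tmp/out/voice.wav"), "wav").unwrap();
        assert_eq!(explicit, PathBuf::from("/tmp/out/voice.wav"));

        let generated = requested_or_temp_output(None, "mp3").unwrap();
        assert_eq!(
            generated.file_name().and_then(|n| n.to_str()),
            Some("speech.mp3")
        );
    }
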
5250 #[test]
5251 fn preferred_discovered_vllm_mlx_model_wins_generate_routing() {
5252 let tmp = TempDir::new().unwrap();
5253 let mut config = test_config(tmp.path().join("models"));
5254 config.preferred_generation_model =
5255 Some("vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit".into());
5256 let mut engine = InferenceEngine::new(config);
5257 let schema = crate::vllm_mlx::to_model_schema(
5258 &crate::vllm_mlx::DiscoveredModel {
5259 id: "mlx-community/gemma-3n-E2B-it-lm-4bit".into(),
5260 owned_by: Some("mlx-community".into()),
5261 },
5262 "http://127.0.0.1:8001",
5263 );
5264 engine.register_model(schema);
5265
5266 let rt = tokio::runtime::Runtime::new().unwrap();
5267 let decision = rt.block_on(engine.route_adaptive("say hello in one sentence"));
5268 assert_eq!(decision.model_id, "vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit");
5269 assert_eq!(decision.strategy, RoutingStrategy::Explicit);
5270 assert_eq!(decision.reason, "preferred generation model override");
5271 }
5272
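    // Illustrative check of the generated-file scan used by the image/video
    // paths: extensions are matched case-insensitively and non-matching files
    // are ignored; the recursive file count sees both files.
    #[test]
    fn find_generated_file_matches_extensions_case_insensitively() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("notes.txt"), b"ignored").unwrap();
        std::fs::write(tmp.path().join("frame.PNG"), b"fake-image").unwrap();

        let found = find_generated_file(tmp.path(), &["png", "webp"]).unwrap();
        assert_eq!(found, Some(tmp.path().join("frame.PNG")));
        assert_eq!(count_files_recursive(tmp.path()), 2);
    }
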
5273 #[test]
5278 fn inference_result_serializes_with_full_shape() {
5279 use crate::tasks::generate::ToolCall;
5280 use std::collections::HashMap;
5281
5282 let mut args = HashMap::new();
5283 args.insert("path".to_string(), serde_json::json!("README.md"));
5284
5285 let result = InferenceResult {
5286 text: String::new(),
5287 bounding_boxes: Vec::new(),
5288 tool_calls: vec![ToolCall {
5289 id: None,
5290 name: "read_file".into(),
5291 arguments: args,
5292 }],
5293 trace_id: "trace-abc".into(),
5294 model_used: "test-model".into(),
5295 latency_ms: 1234,
5296 time_to_first_token_ms: Some(180),
5297 usage: Some(TokenUsage {
5298 prompt_tokens: 100,
5299 completion_tokens: 50,
5300 total_tokens: 150,
5301 context_window: 8192,
5302 }),
5303 provider_output_items: Vec::new(),
5304 };
5305
5306 let json = serde_json::to_value(&result).expect("serialize");
5307
5308 assert_eq!(json["text"].as_str(), Some(""));
5310 assert_eq!(json["trace_id"].as_str(), Some("trace-abc"));
5311 assert_eq!(json["model_used"].as_str(), Some("test-model"));
5312 assert_eq!(json["latency_ms"].as_u64(), Some(1234));
5313
5314 let tool_calls = json["tool_calls"].as_array().expect("tool_calls array");
5316 assert_eq!(tool_calls.len(), 1);
5317 assert_eq!(tool_calls[0]["name"].as_str(), Some("read_file"));
5318 assert_eq!(tool_calls[0]["arguments"]["path"].as_str(), Some("README.md"));
5319
5320 let usage = &json["usage"];
5322 assert_eq!(usage["prompt_tokens"].as_u64(), Some(100));
5323 assert_eq!(usage["completion_tokens"].as_u64(), Some(50));
5324 assert_eq!(usage["total_tokens"].as_u64(), Some(150));
5325 assert_eq!(usage["context_window"].as_u64(), Some(8192));
5326
5327 assert_eq!(json["time_to_first_token_ms"].as_u64(), Some(180));
5329 }
5330
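    // Illustrative check: `find_audio_file` walks the output directory
    // recursively and only reports files with known audio extensions.
    #[test]
    fn find_audio_file_recurses_and_skips_non_audio() {
        let tmp = TempDir::new().unwrap();
        let nested = tmp.path().join("nested");
        std::fs::create_dir_all(&nested).unwrap();
        std::fs::write(nested.join("out.wav"), b"RIFF").unwrap();
        std::fs::write(tmp.path().join("run.log"), b"ignored").unwrap();

        let found = find_audio_file(tmp.path()).unwrap();
        assert_eq!(found, Some(nested.join("out.wav")));
    }
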
5331 #[test]
5337 fn inference_result_top_level_keys_are_locked() {
5338 use std::collections::BTreeSet;
5339
5340 let result = InferenceResult {
5341 text: "anything".into(),
5342 bounding_boxes: Vec::new(),
5343 tool_calls: vec![],
5344 trace_id: "t".into(),
5345 model_used: "m".into(),
5346 latency_ms: 0,
5347 time_to_first_token_ms: None,
5348 usage: None,
5349 provider_output_items: Vec::new(),
5350 };
5351
5352 let json = serde_json::to_value(&result).expect("serialize");
5353 let keys: BTreeSet<&str> = json
5354 .as_object()
5355 .expect("top-level object")
5356 .keys()
5357 .map(String::as_str)
5358 .collect();
5359
5360 let expected: BTreeSet<&str> = [
5361 "text",
5362 "tool_calls",
5363 "trace_id",
5364 "model_used",
5365 "latency_ms",
5366 "time_to_first_token_ms",
5367 "usage",
5368 ]
5369 .into_iter()
5370 .collect();
5371
5372 assert_eq!(
5373 keys, expected,
5374 "infer response top-level keys drifted -- update both the test \
5375 and the WebSocket protocol documentation if this is intentional"
5376 );
5377
5378 for key in &keys {
5380 assert!(
5381 !key.chars().any(|c| c.is_uppercase()) && !key.contains('-'),
5382 "key '{}' is not snake_case",
5383 key
5384 );
5385 }
5386 }
5387
5388 #[test]
5392 fn inference_result_serializes_plain_text_response() {
5393 let result = InferenceResult {
5394 text: "hello world".into(),
5395 bounding_boxes: Vec::new(),
5396 tool_calls: vec![],
5397 trace_id: "trace-xyz".into(),
5398 model_used: "test-model".into(),
5399 latency_ms: 42,
5400 time_to_first_token_ms: None,
5401 usage: None,
5402 provider_output_items: Vec::new(),
5403 };
5404
5405 let json = serde_json::to_value(&result).expect("serialize");
5406 assert_eq!(json["text"], "hello world");
5407 assert!(json["tool_calls"].is_array());
5408 assert_eq!(json["tool_calls"].as_array().unwrap().len(), 0);
5409 assert_eq!(json["model_used"], "test-model");
5410 assert!(json["usage"].is_null());
5411 assert!(json["time_to_first_token_ms"].is_null());
5414 }
5415
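    // Illustrative check of how generated media is materialised: the produced
    // file is copied to the requested destination, creating parent directories
    // along the way. File names here are arbitrary example values.
    #[test]
    fn materialize_binary_output_copies_to_requested_destination() {
        let tmp = TempDir::new().unwrap();
        let produced = tmp.path().join("produced.png");
        std::fs::write(&produced, b"png-bytes").unwrap();

        let requested = tmp.path().join("nested").join("final.png");
        let dest =
            materialize_binary_output(&produced, requested.to_str(), "png", "image").unwrap();
        assert_eq!(dest, requested);
        assert_eq!(std::fs::read(&dest).unwrap(), b"png-bytes".to_vec());
    }
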
5416 #[test]
5428 fn generate_request_deserializes_intent_field_from_json_rpc_params() {
5429 use crate::intent::{IntentHint, TaskHint};
5430 use crate::schema::ModelCapability;
5431
5432 let params = serde_json::json!({
5435 "prompt": "summarize this email",
5436 "intent": {
5437 "task": "chat",
5438 "prefer_local": true,
5439 "require": ["tool_use"],
5440 },
5441 });
5442
5443 let req: GenerateRequest =
5444 serde_json::from_value(params).expect("GenerateRequest deserialize");
5445
5446 let intent = req.intent.as_ref().expect("intent field deserialized");
5447 assert_eq!(intent.task, Some(TaskHint::Chat));
5448 assert!(intent.prefer_local);
5449 assert_eq!(intent.require, vec![ModelCapability::ToolUse]);
5450
5451 let back: serde_json::Value =
5455 serde_json::to_value(&req).expect("re-serialize GenerateRequest");
5456 assert_eq!(back["intent"]["task"], "chat");
5457 assert_eq!(back["intent"]["prefer_local"], true);
5458 assert_eq!(back["intent"]["require"][0], "tool_use");
5459
5460 let default_req: GenerateRequest = serde_json::from_value(serde_json::json!({
5465 "prompt": "x",
5466 "intent": {},
5467 }))
5468 .unwrap();
5469 let default_intent = default_req.intent.expect("present but empty");
5470 assert_eq!(default_intent.task, None);
5471 assert!(!default_intent.prefer_local);
5472 assert!(default_intent.require.is_empty());
5473
5474 let no_intent: GenerateRequest =
5477 serde_json::from_value(serde_json::json!({"prompt": "x"})).unwrap();
5478 assert!(no_intent.intent.is_none());
5479 }
5480
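    // Illustrative mapping check for the image/video MIME helpers, mirroring
    // the audio helper test above: unknown formats default to PNG and MP4.
    #[test]
    fn image_and_video_media_types_have_sane_defaults() {
        assert_eq!(media_type_for_image_format("JPEG"), "image/jpeg");
        assert_eq!(media_type_for_image_format("bmp"), "image/png");
        assert_eq!(media_type_for_video_format("mov"), "video/quicktime");
        assert_eq!(media_type_for_video_format("avi"), "video/mp4");
    }
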
5481 #[test]
5482 fn rerank_prompt_matches_upstream_template_shape() {
5483 let p = rerank_prompt("retrieve relevant passages", "who runs the treasury?", "doc x");
5484 assert!(p.contains("<|im_start|>system"));
5485 assert!(p.contains("Note that the answer can only be \"yes\" or \"no\"."));
5486 assert!(p.contains("<|im_start|>user\n<Instruct>: retrieve relevant passages"));
5487 assert!(p.contains("<Query>: who runs the treasury?"));
5488 assert!(p.contains("<Document>: doc x<|im_end|>"));
5489 assert!(p.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
5490 }
5491
5492 #[test]
5493 fn rerank_score_yes_and_no_exactly() {
5494 assert_eq!(score_from_rerank_output("yes", "m"), 1.0);
5495 assert_eq!(score_from_rerank_output("no", "m"), 0.0);
5496 }
5497
5498 #[test]
5499 fn rerank_score_handles_case_leading_space_and_chat_sentinels() {
5500 assert_eq!(score_from_rerank_output(" Yes", "m"), 1.0);
5503 assert_eq!(score_from_rerank_output("\nno.", "m"), 0.0);
5504 assert_eq!(score_from_rerank_output("<|im_end|>yes", "m"), 1.0);
5505 }
5506
5507 #[test]
    fn rerank_score_scans_leading_tokens_for_answer() {
5509 assert_eq!(score_from_rerank_output("_bos_ yes", "m"), 1.0);
5512 }
5513
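    // Illustrative counterpart to the test above: the scan only covers the
    // leading tokens of the completion, so an answer buried deeper in the
    // output falls back to the neutral score.
    #[test]
    fn rerank_score_ignores_answers_beyond_leading_tokens() {
        assert_eq!(
            score_from_rerank_output("the answer to this question is probably yes", "m"),
            0.5
        );
    }
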
5514 #[test]
5515 fn rerank_score_unexpected_is_neutral() {
5516 assert_eq!(score_from_rerank_output("maybe", "m"), 0.5);
5519 assert_eq!(score_from_rerank_output("", "m"), 0.5);
5520 }
5521}