1pub mod adaptive_router;
26pub mod backend;
27pub mod backend_cache;
28pub mod hardware;
29pub mod intent;
30pub mod key_pool;
31pub mod models;
32pub mod outcome;
33pub mod protocol;
34pub mod registry;
35pub mod remote;
36pub mod router;
37pub mod routing_ext;
38pub mod schema;
39pub mod service;
40pub mod stream;
41pub mod tasks;
42pub mod vllm_mlx;
43
44use std::path::{Path, PathBuf};
45use std::sync::Arc;
46use std::time::Instant;
47use std::time::{SystemTime, UNIX_EPOCH};
48
49use serde::Serialize;
50use reqwest::multipart::{Form, Part};
51use thiserror::Error;
52#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
53use tokio::process::Command;
54#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
55use tokio::sync::Mutex;
56use tokio::sync::RwLock;
57use tracing::{debug, instrument};
58
59pub use adaptive_router::{
61 AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy,
62};
63pub use intent::{IntentHint, TaskHint};
64pub use key_pool::{KeyPool, KeyStats};
65pub use outcome::{
66 CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
67};
68pub use registry::{ModelFilter, ModelInfo, ModelRuntimeRequirement, ModelUpgrade, UnifiedRegistry};
69pub use remote::RemoteBackend;
70pub use routing_ext::{
71 CircuitBreaker, CircuitBreakerRegistry, CircuitState, ImplicitSignal, ImplicitSignalType,
72 RoutingMode, SpendControl, SpendLimitExceeded, SpendLimits, SpendStatus,
73};
74pub use schema::{
75 ApiProtocol, BenchmarkScore, CostModel, ModelCapability, ModelSchema, ModelSource,
76 PerformanceEnvelope, ProprietaryAuth,
77};
78
79pub use adaptive_router::TaskComplexity;
81#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
82pub use backend::CandleBackend;
83#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
84pub use backend::EmbeddingBackend;
85#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
86pub use backend::MlxBackend;
87pub use hardware::HardwareInfo;
88pub use models::{ModelRegistry, ModelRole};
89pub use router::{ModelRouter, RoutingDecision};
90pub use stream::{StreamAccumulator, StreamEvent};
91pub use tasks::{
92 parse_boxes, BoundingBox, ClassifyRequest, ClassifyResult, ContentBlock, EmbedRequest,
93 GenerateImageRequest, GenerateImageResult, GenerateParams, GenerateRequest,
94 GenerateVideoRequest, GenerateVideoResult, GroundRequest, GroundResult, Message, RerankRequest,
95 RerankResult, RerankedDocument, RoutingWorkload, SynthesizeRequest, SynthesizeResult, ToolCall,
96 TranscribeRequest, TranscribeResult, VideoMode,
97};
98#[derive(Error, Debug)]
101pub enum InferenceError {
102 #[error("model not found: {0}")]
103 ModelNotFound(String),
104
105 #[error("model download failed: {0}")]
106 DownloadFailed(String),
107
108 #[error("inference failed: {0}")]
109 InferenceFailed(String),
110
111 #[error("mode {mode} not implemented on backend {backend}: {reason}")]
116 UnsupportedMode {
117 mode: &'static str,
118 backend: &'static str,
119 reason: &'static str,
120 },
121
122 #[error("tokenization error: {0}")]
123 TokenizationError(String),
124
125 #[error("device error: {0}")]
126 DeviceError(String),
127
128 #[error("io error: {0}")]
129 Io(#[from] std::io::Error),
130}
131
/// Compute device a backend should run on.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Device {
    /// Plain CPU execution.
    Cpu,
    /// Apple Metal.
    Metal,
    /// CUDA device; the payload is the device ordinal.
    Cuda(usize),
}

impl Device {
    /// Picks a device from the crate features enabled at compile time.
    ///
    /// Priority is `metal`, then `cuda`, then CPU when neither feature is on.
    pub fn auto() -> Self {
        if cfg!(feature = "metal") {
            Device::Metal
        } else if cfg!(feature = "cuda") {
            Device::Cuda(0)
        } else {
            Device::Cpu
        }
    }
}
157
158#[derive(Debug, Clone)]
160pub struct InferenceConfig {
161 pub models_dir: std::path::PathBuf,
163 pub device: Option<Device>,
165 pub generation_model: String,
167 pub preferred_generation_model: Option<String>,
169 pub embedding_model: String,
171 pub preferred_embedding_model: Option<String>,
173 pub classification_model: String,
175 pub preferred_classification_model: Option<String>,
177}
178
179impl Default for InferenceConfig {
180 fn default() -> Self {
181 let models_dir = dirs_next()
182 .unwrap_or_else(|| std::path::PathBuf::from("."))
183 .join(".car")
184 .join("models");
185
186 let hw = HardwareInfo::detect();
187
188 Self {
189 models_dir,
190 device: None,
191 generation_model: hw.recommended_model,
192 preferred_generation_model: None,
193 embedding_model: "Qwen3-Embedding-0.6B".to_string(),
194 preferred_embedding_model: None,
195 classification_model: "Qwen3-0.6B".to_string(),
196 preferred_classification_model: None,
197 }
198 }
199}
200
/// Best-effort lookup of the current user's home directory.
///
/// Checks `HOME` first and falls back to `USERPROFILE`, so platforms that do
/// not export `HOME` still resolve a per-user directory. Values are read with
/// `var_os`, so a home path that is not valid UTF-8 is still accepted. Empty
/// values are skipped. Returns `None` when neither variable holds a usable value.
fn dirs_next() -> Option<std::path::PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|key| std::env::var_os(key))
        .find(|value| !value.is_empty())
        .map(std::path::PathBuf::from)
}
204
205#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
207pub struct TokenUsage {
208 pub prompt_tokens: u64,
210 pub completion_tokens: u64,
212 pub total_tokens: u64,
214 pub context_window: u64,
216}
217
218#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
220pub struct InferenceResult {
221 pub text: String,
223 pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
225 #[serde(default, skip_serializing_if = "Vec::is_empty")]
232 pub bounding_boxes: Vec<crate::tasks::grounding::BoundingBox>,
233 pub trace_id: String,
235 pub model_used: String,
237 pub latency_ms: u64,
239 #[serde(default)]
251 pub time_to_first_token_ms: Option<u64>,
252 pub usage: Option<TokenUsage>,
254}
255
256impl InferenceResult {
257 pub fn has_tool_calls(&self) -> bool {
259 !self.tool_calls.is_empty()
260 }
261}
262
263#[derive(Debug, Clone, Serialize)]
264pub struct SpeechRuntimeHealth {
265 pub root: PathBuf,
266 pub installed: bool,
267 pub python: PathBuf,
268 pub stt_command: PathBuf,
269 pub tts_command: PathBuf,
270 pub configured_python: Option<String>,
271 pub detected_python: Option<String>,
272}
273
274#[derive(Debug, Clone, Serialize)]
275pub struct SpeechModelHealth {
276 pub id: String,
277 pub name: String,
278 pub provider: String,
279 pub capability: ModelCapability,
280 pub is_local: bool,
281 pub available: bool,
282 pub cached: bool,
283 pub selected_by_default: bool,
284 pub source: String,
285}
286
287#[derive(Debug, Clone, Serialize)]
288pub struct SpeechHealthReport {
289 pub runtime: SpeechRuntimeHealth,
290 pub local_models: Vec<SpeechModelHealth>,
291 pub remote_models: Vec<SpeechModelHealth>,
292 pub elevenlabs_configured: bool,
293 pub prefer_local: bool,
294 pub allow_remote_fallback: bool,
295 pub preferred_local_stt: Option<String>,
296 pub preferred_local_tts: Option<String>,
297 pub preferred_remote_stt: Option<String>,
298 pub preferred_remote_tts: Option<String>,
299 pub local_stt_default: Option<String>,
300 pub local_tts_default: Option<String>,
301 pub remote_stt_default: Option<String>,
302 pub remote_tts_default: Option<String>,
303}
304
305#[derive(Debug, Clone, Serialize)]
306pub struct ModelDefaultHealth {
307 pub capability: ModelCapability,
308 pub configured_model: String,
309 pub available: bool,
310 pub is_local: bool,
311 pub provider: Option<String>,
312}
313
314#[derive(Debug, Clone, Serialize)]
315pub struct ModelProviderHealth {
316 pub provider: String,
317 pub configured: bool,
318 pub local_models: usize,
319 pub remote_models: usize,
320 pub available_models: usize,
321 pub capabilities: Vec<ModelCapability>,
322}
323
324#[derive(Debug, Clone, Serialize)]
325pub struct ModelCapabilityHealth {
326 pub capability: ModelCapability,
327 pub total_models: usize,
328 pub available_models: usize,
329 pub local_available_models: usize,
330 pub remote_available_models: usize,
331}
332
333#[derive(Debug, Clone, Serialize)]
334pub struct RoutingScenarioHealth {
335 pub name: String,
336 pub workload: RoutingWorkload,
337 pub task_family: String,
338 pub has_tools: bool,
339 pub has_vision: bool,
340 pub prefer_local: bool,
341 pub quality_first_cold_start: bool,
342 pub bootstrap_min_task_observations: u64,
343 pub bootstrap_quality_floor: f64,
344 pub model_id: String,
345 pub model_name: String,
346 pub reason: String,
347 pub strategy: RoutingStrategy,
348}
349
350#[derive(Debug, Clone, Serialize)]
351pub struct ModelBenchmarkPriorHealth {
352 pub model_id: String,
353 pub model_name: Option<String>,
354 pub overall_score: f64,
355 pub overall_latency_ms: Option<f64>,
356 pub task_scores: std::collections::HashMap<String, f64>,
357 pub task_latency_ms: std::collections::HashMap<String, f64>,
358 pub source_path: PathBuf,
359}
360
361#[derive(Debug, Clone, Serialize)]
362pub struct ModelHealthReport {
363 pub total_models: usize,
364 pub available_models: usize,
365 pub local_models: usize,
366 pub remote_models: usize,
367 pub defaults: Vec<ModelDefaultHealth>,
368 pub providers: Vec<ModelProviderHealth>,
369 pub capabilities: Vec<ModelCapabilityHealth>,
370 pub routing_prefer_local: bool,
371 pub routing_quality_first_cold_start: bool,
372 pub routing_min_observations: u64,
373 pub routing_bootstrap_min_task_observations: u64,
374 pub routing_bootstrap_quality_floor: f64,
375 pub routing_quality_weight: f64,
376 pub routing_latency_weight: f64,
377 pub routing_cost_weight: f64,
378 pub routing_scenarios: Vec<RoutingScenarioHealth>,
379 pub benchmark_priors: Vec<ModelBenchmarkPriorHealth>,
380 pub speech: SpeechHealthReport,
381}
382
383#[derive(Debug, Clone, Serialize)]
384pub struct SpeechInstallReport {
385 pub name: String,
386 pub hf_repo: String,
387 pub snapshot_path: PathBuf,
388 pub files_downloaded: usize,
389}
390
391#[derive(Debug, Clone, Serialize)]
392pub struct SpeechSmokePathReport {
393 pub path: String,
394 pub tts_model: String,
395 pub stt_model: String,
396 pub audio_path: PathBuf,
397 pub transcript: String,
398}
399
400#[derive(Debug, Clone, Serialize, Default)]
401pub struct SpeechSmokeReport {
402 pub local: Option<SpeechSmokePathReport>,
403 pub remote: Option<SpeechSmokePathReport>,
404 pub skipped: Vec<String>,
405}
406
407#[derive(Debug, Clone, Serialize, Default)]
408pub struct SpeechPolicy {
409 pub prefer_local: bool,
410 pub allow_remote_fallback: bool,
411 pub preferred_local_stt: Option<String>,
412 pub preferred_local_tts: Option<String>,
413 pub preferred_remote_stt: Option<String>,
414 pub preferred_remote_tts: Option<String>,
415}
416
417pub struct InferenceEngine {
422 pub config: InferenceConfig,
423 pub unified_registry: UnifiedRegistry,
425 pub adaptive_router: AdaptiveRouter,
427 pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
429 remote_backend: RemoteBackend,
431 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
437 mlx_backends: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
438 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
443 flux_cache: Arc<backend_cache::BackendCache<backend::mlx_flux::FluxBackend>>,
444 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
446 ltx_cache: Arc<backend_cache::BackendCache<backend::mlx_ltx::LtxBackend>>,
447 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
450 kokoro_cache: Arc<backend_cache::BackendCache<backend::mlx_kokoro::KokoroBackend>>,
451 pub registry: models::ModelRegistry,
453 pub router: ModelRouter,
454 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
455 backend: Arc<RwLock<Option<CandleBackend>>>,
456 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
457 embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
458 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
459 speech_runtime: Arc<Mutex<Option<SpeechRuntime>>>,
460 speech_policy: SpeechPolicy,
461}
462
463impl InferenceEngine {
464 fn preferred_model_for_capability(&self, capability: ModelCapability) -> Option<&str> {
465 match capability {
466 ModelCapability::Generate => self.config.preferred_generation_model.as_deref(),
467 ModelCapability::Embed => self.config.preferred_embedding_model.as_deref(),
468 ModelCapability::Classify => self.config.preferred_classification_model.as_deref(),
469 _ => None,
470 }
471 }
472
473 fn request_needs_vision(req: &GenerateRequest) -> bool {
474 req.images.as_ref().is_some_and(|images| !images.is_empty())
475 || req.messages.as_ref().is_some_and(|messages| {
476 messages
477 .iter()
478 .any(|msg| matches!(msg, Message::UserMultimodal { .. }))
479 })
480 }
481
482 fn request_has_video(req: &GenerateRequest) -> bool {
487 let images_have_video = req
488 .images
489 .as_ref()
490 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
491 let messages_have_video = req.messages.as_ref().is_some_and(|messages| {
492 messages.iter().any(|msg| match msg {
493 Message::UserMultimodal { content } => {
494 content.iter().any(ContentBlock::is_video)
495 }
496 _ => false,
497 })
498 });
499 images_have_video || messages_have_video
500 }
501
502 fn request_has_audio(req: &GenerateRequest) -> bool {
506 let images_have_audio = req
507 .images
508 .as_ref()
509 .is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
510 let messages_have_audio = req.messages.as_ref().is_some_and(|messages| {
511 messages.iter().any(|msg| match msg {
512 Message::UserMultimodal { content } => {
513 content.iter().any(ContentBlock::is_audio)
514 }
515 _ => false,
516 })
517 });
518 images_have_audio || messages_have_audio
519 }
520
521 pub fn new(config: InferenceConfig) -> Self {
522 let registry = models::ModelRegistry::new(config.models_dir.clone());
523 let hw = HardwareInfo::detect();
524 let router = ModelRouter::new(hw.clone());
525 let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
526 let adaptive_router = AdaptiveRouter::with_default_config(hw);
527 let mut tracker = OutcomeTracker::new();
528 let profiles_path = config.models_dir.join("outcome_profiles.json");
530 if let Ok(n) = tracker.load_from_file(&profiles_path) {
531 if n > 0 {
532 tracing::info!(loaded = n, "loaded persisted model profiles");
533 }
534 }
535 let mut benchmark_models_loaded = 0usize;
536 for path in benchmark_priors_paths(&config.models_dir) {
537 match routing_ext::load_benchmark_priors(&path) {
538 Ok(priors) if !priors.is_empty() => {
539 benchmark_models_loaded += priors.len();
540 routing_ext::apply_benchmark_priors(&mut tracker, &priors);
541 tracing::info!(
542 path = %path.display(),
543 loaded = priors.len(),
544 "loaded benchmark quality priors"
545 );
546 }
547 Ok(_) => {}
548 Err(error) => {
549 tracing::warn!(path = %path.display(), %error, "failed to load benchmark priors");
550 }
551 }
552 }
553 if benchmark_models_loaded > 0 {
554 tracing::info!(
555 loaded = benchmark_models_loaded,
556 "applied benchmark priors to cold-start routing"
557 );
558 }
559 let outcome_tracker = Arc::new(RwLock::new(tracker));
560
561 let remote_backend = RemoteBackend::new();
562
563 Self {
564 config,
565 unified_registry,
566 adaptive_router,
567 outcome_tracker,
568 remote_backend,
569 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
570 mlx_backends: Arc::new(backend_cache::BackendCache::from_env()),
571 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
572 flux_cache: Arc::new(backend_cache::BackendCache::from_env()),
573 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
574 ltx_cache: Arc::new(backend_cache::BackendCache::from_env()),
575 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
576 kokoro_cache: Arc::new(backend_cache::BackendCache::from_env()),
577 registry,
578 router,
579 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
580 backend: Arc::new(RwLock::new(None)),
581 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
582 embedding_backend: Arc::new(RwLock::new(None)),
583 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
584 speech_runtime: Arc::new(Mutex::new(None)),
585 speech_policy: SpeechPolicy {
586 prefer_local: cfg!(all(target_os = "macos", target_arch = "aarch64")),
587 allow_remote_fallback: true,
588 preferred_local_stt: None,
589 preferred_local_tts: None,
590 preferred_remote_stt: None,
591 preferred_remote_tts: None,
592 },
593 }
594 }
595
596 pub async fn init_key_pool(&self) {
599 for schema in self.unified_registry.list() {
601 if schema.is_remote() {
602 self.remote_backend.register_model_keys(schema).await;
603 }
604 }
605
606 let stats_path = self.config.models_dir.join("key_pool_stats.json");
608 if let Ok(n) = self.remote_backend.key_pool.load_stats(&stats_path).await {
609 if n > 0 {
610 tracing::info!(loaded = n, "loaded persisted key pool stats");
611 }
612 }
613
614 let total = self.remote_backend.key_pool.total_keys().await;
615 if total > 0 {
616 tracing::info!(keys = total, "key pool initialized");
617 }
618 }
619
620 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
623 async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
624 let read = self.backend.read().await;
625 if read.is_some() {
626 return Ok(());
627 }
628 drop(read);
629
630 let mut write = self.backend.write().await;
631 if write.is_some() {
632 return Ok(());
633 }
634
635 let model_path = self.registry.ensure_model(model_name).await?;
636 let device = self.config.device.unwrap_or_else(Device::auto);
637 let backend = CandleBackend::load(&model_path, device)?;
638 *write = Some(backend);
639 Ok(())
640 }
641
642 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
645 async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
646 let read = self.embedding_backend.read().await;
647 if read.is_some() {
648 return Ok(());
649 }
650 drop(read);
651
652 let mut write = self.embedding_backend.write().await;
653 if write.is_some() {
654 return Ok(());
655 }
656
657 let embedding_model = self
658 .preferred_model_for_capability(ModelCapability::Embed)
659 .unwrap_or(&self.config.embedding_model);
660 let model_path = self
661 .registry
662 .ensure_model(embedding_model)
663 .await?;
664 let device = self.config.device.unwrap_or_else(Device::auto);
665 let backend = EmbeddingBackend::load(&model_path, device)?;
666 *write = Some(backend);
667 Ok(())
668 }
669
/// Resolves the embedding model (preferred override first, then the
/// configured default), makes sure its MLX backend is loaded and returns
/// the schema id.
///
/// # Errors
/// `InferenceError::ModelNotFound` when the name matches neither a schema id
/// nor a schema name, plus any error from `ensure_mlx_backend`.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
async fn ensure_mlx_embedding_backend(&self) -> Result<String, InferenceError> {
    let wanted = match self.preferred_model_for_capability(ModelCapability::Embed) {
        Some(preferred) => preferred.to_string(),
        None => self.config.embedding_model.clone(),
    };

    let registry = &self.unified_registry;
    let schema = match registry
        .get(&wanted)
        .or_else(|| registry.find_by_name(&wanted))
    {
        Some(found) => found.clone(),
        None => return Err(InferenceError::ModelNotFound(wanted)),
    };

    self.ensure_mlx_backend(&schema).await?;
    Ok(schema.id)
}
687
/// Returns a cached MLX backend for `schema`, loading it on a cache miss.
///
/// The model is made available locally first, and its estimated size is
/// handed to the cache along with the loader. The loader runs under
/// `catch_unwind`, so a panic during loading becomes an error, not a
/// crash of the calling task.
///
/// # Errors
/// * `InferenceError::InferenceFailed` when the model family is not handled
///   by the native MLX backend, or when loading panics. In the panic case the
///   message includes the panic text whenever the payload is a `&str` or
///   `String`.
/// * Any error from `ensure_local`, the loader or the cache.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
async fn ensure_mlx_backend(
    &self,
    schema: &ModelSchema,
) -> Result<backend_cache::CachedBackend<backend::MlxBackend>, InferenceError> {
    if !Self::supports_native_mlx(schema) {
        return Err(InferenceError::InferenceFailed(format!(
            "native MLX backend does not support {} ({}) yet; use vLLM-MLX or add a family-specific MLX backend",
            schema.name, schema.family
        )));
    }

    let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let size = backend_cache::estimate_model_size(&model_dir);
    let cache = Arc::clone(&self.mlx_backends);
    let key = schema.id.clone();
    cache.get_or_load(&key, size, move || {
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            backend::MlxBackend::load(&model_dir)
        }))
        .map_err(|payload| {
            // `Debug` on `Box<dyn Any + Send>` prints only `Any { .. }`;
            // downcast to recover the actual panic message.
            let detail = payload
                .downcast_ref::<&str>()
                .map(|msg| (*msg).to_string())
                .or_else(|| payload.downcast_ref::<String>().cloned())
                .unwrap_or_else(|| "non-string panic payload".to_string());
            InferenceError::InferenceFailed(format!(
                "MLX backend loading panicked (possible Metal/accelerate exception): {detail}"
            ))
        })?
    })
}
722
/// Pre-loads the backends for `schema_ids`, one after another, and returns
/// one result per id in input order.
///
/// The backend family is picked from the schema's first listed capability:
/// image generation uses the Flux cache, video generation the LTX cache,
/// text-to-speech the Kokoro cache, and everything else the generic MLX
/// backend. An id not present in the unified registry yields
/// `InferenceError::InferenceFailed`; a failure for one id does not stop
/// the remaining ids from being processed.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
pub async fn warm_up<S: AsRef<str>>(&self, schema_ids: &[S]) -> Vec<Result<(), InferenceError>> {
    let mut outcomes = Vec::with_capacity(schema_ids.len());
    for raw_id in schema_ids {
        let id = raw_id.as_ref();
        let warmed: Result<(), InferenceError> = async {
            let Some(schema) = self.unified_registry.get(id).cloned() else {
                return Err(InferenceError::InferenceFailed(format!(
                    "warm_up: unknown schema id {id}"
                )));
            };
            match schema.capabilities.first() {
                Some(ModelCapability::ImageGeneration) => {
                    let dir = self.unified_registry.ensure_local(&schema.id).await?;
                    let bytes = backend_cache::estimate_model_size(&dir);
                    let _ = self.flux_cache.get_or_load(&schema.id, bytes, || {
                        backend::mlx_flux::FluxBackend::load(&dir)
                    })?;
                }
                Some(ModelCapability::VideoGeneration) => {
                    let dir = self.unified_registry.ensure_local(&schema.id).await?;
                    let bytes = backend_cache::estimate_model_size(&dir);
                    let _ = self.ltx_cache.get_or_load(&schema.id, bytes, || {
                        backend::mlx_ltx::LtxBackend::load(&dir)
                    })?;
                }
                Some(ModelCapability::TextToSpeech) => {
                    let dir = self.unified_registry.ensure_local(&schema.id).await?;
                    let bytes = backend_cache::estimate_model_size(&dir);
                    let _ = self.kokoro_cache.get_or_load(&schema.id, bytes, || {
                        backend::mlx_kokoro::KokoroBackend::load(&dir)
                    })?;
                }
                _ => {
                    let _ = self.ensure_mlx_backend(&schema).await?;
                }
            }
            Ok(())
        }
        .await;
        outcomes.push(warmed);
    }
    outcomes
}
778
779 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
781 pub async fn warm_up<S: AsRef<str>>(&self, _schema_ids: &[S]) -> Vec<Result<(), InferenceError>> {
782 Vec::new()
783 }
784
/// Reports whether the schema's model family is one the native MLX backend
/// handles (`qwen3`, `qwen2.5-vl`, `qwen2-vl`).
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn supports_native_mlx(schema: &ModelSchema) -> bool {
    const NATIVE_FAMILIES: [&str; 3] = ["qwen3", "qwen2.5-vl", "qwen2-vl"];
    NATIVE_FAMILIES.contains(&schema.family.as_str())
}
789
790 pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
792 if let Some(model) = self.preferred_model_for_capability(ModelCapability::Generate) {
793 let ctx_len = self
794 .unified_registry
795 .get(model)
796 .or_else(|| self.unified_registry.find_by_name(model))
797 .map(|s| s.context_length)
798 .unwrap_or(0);
799 return AdaptiveRoutingDecision {
800 model_id: model.to_string(),
801 model_name: model.to_string(),
802 task: InferenceTask::Generate,
803 complexity: TaskComplexity::assess(prompt),
804 reason: "preferred generation model override".into(),
805 strategy: RoutingStrategy::Explicit,
806 predicted_quality: 0.5,
807 fallbacks: vec![],
808 context_length: ctx_len,
809 needs_compaction: false,
810 };
811 }
812 let tracker = self.outcome_tracker.read().await;
813 self.adaptive_router
814 .route(prompt, &self.unified_registry, &tracker)
815 }
816
817 pub fn route(&self, prompt: &str) -> RoutingDecision {
819 self.router.route_generate(prompt, &self.registry)
820 }
821
822 pub fn estimated_tokens(
825 &self,
826 req: &GenerateRequest,
827 model_id: Option<&str>,
828 ) -> (usize, usize, bool) {
829 let prompt_tokens = remote::estimate_tokens(&req.prompt);
830 let context_tokens = req
831 .context
832 .as_ref()
833 .map(|c| remote::estimate_tokens(c))
834 .unwrap_or(0);
835 let tools_tokens = req
836 .tools
837 .as_ref()
838 .map(|t| remote::estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
839 .unwrap_or(0);
840 let total_input = prompt_tokens + context_tokens + tools_tokens;
841
842 let context_window = model_id
843 .and_then(|id| {
844 self.unified_registry
845 .get(id)
846 .or_else(|| self.unified_registry.find_by_name(id))
847 })
848 .map(|s| s.context_length)
849 .unwrap_or(0);
850
851 let fits = context_window == 0 || (total_input + req.params.max_tokens) <= context_window;
852 (total_input, context_window, fits)
853 }
854
855 #[instrument(
858 name = "inference.generate",
859 skip_all,
860 fields(
861 model = tracing::field::Empty,
862 max_tokens = req.params.max_tokens,
863 prompt_tokens = tracing::field::Empty,
864 completion_tokens = tracing::field::Empty,
865 latency_ms = tracing::field::Empty,
866 )
867 )]
868 pub async fn generate_tracked(
869 &self,
870 req: GenerateRequest,
871 ) -> Result<InferenceResult, InferenceError> {
872 let start = Instant::now();
873
874 let (estimated_input, _, _) = self.estimated_tokens(&req, None);
876 let tracker_read = self.outcome_tracker.read().await;
877 let has_tools = req.tools.is_some();
878 let has_vision = Self::request_needs_vision(&req);
879 let preferred_model = self
880 .preferred_model_for_capability(ModelCapability::Generate)
881 .map(str::to_string);
882 let decision = match req.model.clone().or(preferred_model) {
883 Some(m) => {
884 let ctx_len = self
885 .unified_registry
886 .get(&m)
887 .or_else(|| self.unified_registry.find_by_name(&m))
888 .map(|s| s.context_length)
889 .unwrap_or(0);
890 AdaptiveRoutingDecision {
891 model_id: m.clone(),
892 model_name: m.clone(),
893 task: InferenceTask::Generate,
894 complexity: TaskComplexity::assess(&req.prompt),
895 reason: "explicit model".into(),
896 strategy: RoutingStrategy::Explicit,
897 predicted_quality: 0.5,
898 fallbacks: vec![],
899 context_length: ctx_len,
900 needs_compaction: ctx_len > 0 && estimated_input > ctx_len,
901 }
902 }
903 None => match &req.intent {
904 Some(hint) => self.adaptive_router.route_context_aware_with_intent(
905 &req.prompt,
906 estimated_input,
907 &self.unified_registry,
908 &tracker_read,
909 has_tools,
910 has_vision,
911 req.params.workload,
912 hint,
913 ),
914 None => self.adaptive_router.route_context_aware(
915 &req.prompt,
916 estimated_input,
917 &self.unified_registry,
918 &tracker_read,
919 has_tools,
920 has_vision,
921 req.params.workload,
922 ),
923 },
924 };
925 drop(tracker_read);
926
927 if decision.needs_compaction {
928 tracing::info!(
929 model = %decision.model_name,
930 prompt_tokens = estimated_input,
931 context_window = decision.context_length,
932 "prompt exceeds model context window — compaction or truncation needed"
933 );
934 }
935
936 let trace_id = {
938 let mut tracker = self.outcome_tracker.write().await;
939 tracker.record_start(&decision.model_id, decision.task, &decision.reason)
940 };
941
942 debug!(
943 model = %decision.model_name,
944 strategy = ?decision.strategy,
945 reason = %decision.reason,
946 trace = %trace_id,
947 "adaptive-routed generate request"
948 );
949
950 let mut req = req;
953 if req.params.budget_tokens == 0 && matches!(decision.complexity, TaskComplexity::Complex) {
954 let supports_thinking = self
955 .unified_registry
956 .get(&decision.model_id)
957 .map(|s| {
958 s.supported_params
959 .contains(&schema::GenerateParam::ExtendedThinking)
960 })
961 .unwrap_or(false);
962 if supports_thinking {
963 req.params.budget_tokens = 8000;
964 tracing::info!(model = %decision.model_name, budget = 8000, "auto-enabled extended thinking for complex task");
965 }
966 }
967
968 let mut models_to_try = vec![decision.model_id.clone()];
970 models_to_try.extend(decision.fallbacks.iter().cloned());
971
972 let mut last_error = None;
973
974 for candidate_id in &models_to_try {
975 #[allow(unused_mut)]
978 let mut schema = self
979 .unified_registry
980 .get(candidate_id)
981 .or_else(|| self.unified_registry.find_by_name(candidate_id))
982 .cloned();
983
984 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
986 if let Some(ref s) = schema {
987 if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
988 tracing::info!(
989 from = %s.id, to = %mlx_equiv.id,
990 "redirecting GGUF model to MLX equivalent on Apple Silicon"
991 );
992 schema = Some(mlx_equiv.clone());
993 }
994 }
995
996 let candidate_name = schema
997 .as_ref()
998 .map(|s| s.name.clone())
999 .unwrap_or_else(|| candidate_id.clone());
1000
1001 let is_remote = schema
1002 .as_ref()
1003 .map(|s| s.is_remote() || s.is_vllm_mlx())
1004 .unwrap_or(false);
1005
1006 let has_tools = req.tools.is_some();
1007
1008 let context = if has_tools
1010 && req.tools.as_ref().map_or(false, |t| {
1011 t.iter().any(|tool| {
1012 tool.get("function")
1013 .and_then(|f| f.get("name"))
1014 .and_then(|n| n.as_str())
1015 == Some("done")
1016 })
1017 }) {
1018 let base = req.context.as_deref().unwrap_or("");
1019 Some(format!(
1020 "{base}\n\nIMPORTANT: When calling the `done` tool, the `result` field MUST contain a DETAILED summary of everything you found and did. This is the ONLY output the user sees. Do NOT just say 'completed' — include specific findings, data, and conclusions."
1021 ))
1022 } else {
1023 req.context.clone()
1024 };
1025
1026 let result = if is_remote {
1027 let schema_val = schema.unwrap();
1028 let _ctx_len = schema_val.context_length;
1029 let temperature = if !schema_val.supported_params.is_empty()
1032 && !schema_val
1033 .supported_params
1034 .contains(&crate::schema::GenerateParam::Temperature)
1035 {
1036 -1.0
1037 } else {
1038 req.params.temperature
1039 };
1040
1041 self.remote_backend
1047 .generate_with_tools_multi(
1048 &schema_val,
1049 &req.prompt,
1050 context.as_deref(),
1051 temperature,
1052 req.params.max_tokens,
1053 req.tools.as_deref(),
1054 req.images.as_deref(),
1055 req.messages.as_deref(),
1056 req.params.tool_choice.as_deref(),
1057 req.params.parallel_tool_calls,
1058 req.params.budget_tokens,
1059 req.cache_control,
1060 req.response_format.as_ref(),
1061 )
1062 .await
1063 .map(|(t, c, u)| (t, c, u, None::<u64>))
1068 } else {
1069 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1070 {
1071 let schema_ref = schema
1075 .as_ref()
1076 .ok_or_else(|| InferenceError::ModelNotFound(candidate_id.clone()))?;
1077
1078 if schema_ref.is_foundation_models() {
1083 if has_tools {
1084 Err(InferenceError::UnsupportedMode {
1085 mode: "tool-use",
1086 backend: "foundation-models",
1087 reason:
1088 "tool calling not yet wired through the FoundationModels \
1089 bridge — route to a remote model with ToolUse capability",
1090 })
1091 } else if Self::request_has_video(&req)
1092 || Self::request_has_audio(&req)
1093 || req
1094 .images
1095 .as_ref()
1096 .is_some_and(|imgs| !imgs.is_empty())
1097 {
1098 Err(InferenceError::UnsupportedMode {
1099 mode: "multimodal-content",
1100 backend: "foundation-models",
1101 reason:
1102 "the FoundationModels bridge currently exposes text-only \
1103 generation — route image/audio/video to a remote VL model",
1104 })
1105 } else {
1106 let prompt = req.prompt.clone();
1107 let instructions = context.clone();
1108 let max_tokens = req.params.max_tokens as u32;
1109 let temperature = req.params.temperature;
1110 tokio::task::spawn_blocking(move || {
1111 crate::backend::foundation_models::generate(
1112 &prompt,
1113 instructions.as_deref(),
1114 max_tokens,
1115 temperature as f32,
1116 )
1117 })
1118 .await
1119 .map_err(|e| {
1120 InferenceError::InferenceFailed(format!(
1121 "FoundationModels task panicked: {e}"
1122 ))
1123 })
1124 .and_then(|r| r)
1125 .map(|text| (text, vec![], None, None))
1126 }
1127 } else if !schema_ref.is_mlx() {
1128 Err(InferenceError::InferenceFailed(format!(
1129 "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
1130 schema_ref.id
1131 )))
1132 } else if schema_ref.tags.iter().any(|t| t == "mlx-vlm-cli") {
1133 let has_images = req
1145 .images
1146 .as_ref()
1147 .is_some_and(|imgs| !imgs.is_empty());
1148 if !has_images {
1149 return Err(InferenceError::UnsupportedMode {
1150 mode: "text-only-on-mlx-vlm-id",
1151 backend: "mlx-vlm-cli",
1152 reason:
1153 "the `mlx-vlm/...` model IDs route exclusively \
1154 through the mlx-vlm CLI for image inference. \
1155 For text-only generation, route to a Qwen3 \
1156 text model (`mlx/qwen3-4b:4bit` etc.) — the \
1157 CLI shell-out has higher latency than the \
1158 in-process MLX text tower.",
1159 });
1160 }
1161 if !crate::backend::mlx_vlm_cli::is_available() {
1162 return Err(InferenceError::UnsupportedMode {
1163 mode: "image-content-block",
1164 backend: "mlx-vlm-cli",
1165 reason:
1166 "mlx-vlm CLI not found on PATH. Install with \
1167 `uv tool install mlx-vlm` (or `pip install \
1168 mlx-vlm`) and CAR will route through \
1169 `mlx_vlm.generate` automatically. \
1170 Alternatives: a local vLLM-MLX VLM server, \
1171 or a remote VL model. (#115)",
1172 });
1173 }
1174 let repo = match &schema_ref.source {
1175 crate::schema::ModelSource::Mlx { hf_repo, .. } => {
1176 hf_repo.clone()
1177 }
1178 _ => {
1179 return Err(InferenceError::InferenceFailed(format!(
1180 "model '{}' is tagged mlx-vlm-cli but its \
1181 source isn't ModelSource::Mlx — registry bug",
1182 schema_ref.id
1183 )));
1184 }
1185 };
1186 let imgs = req.images.clone().unwrap_or_default();
1187 let temp = req.params.temperature;
1188 let max_t = req.params.max_tokens;
1189 let prompt = req.prompt.clone();
1190 let text = tokio::task::spawn_blocking(move || {
1191 crate::backend::mlx_vlm_cli::generate(
1192 &repo, &prompt, &imgs, temp, max_t,
1193 )
1194 })
1195 .await
1196 .map_err(|e| {
1197 InferenceError::InferenceFailed(format!(
1198 "mlx_vlm CLI task panicked: {e}"
1199 ))
1200 })??;
1201 let bounding_boxes = parse_boxes(&text);
1202 let latency_ms = start.elapsed().as_millis() as u64;
1203 {
1204 let mut tracker = self.outcome_tracker.write().await;
1205 tracker.record_complete(&trace_id, latency_ms, 0, 0);
1206 }
1207 return Ok(InferenceResult {
1208 text,
1209 tool_calls: vec![],
1210 bounding_boxes,
1211 trace_id: trace_id.clone(),
1212 model_used: schema_ref.id.clone(),
1213 latency_ms,
1214 time_to_first_token_ms: None,
1215 usage: None,
1216 });
1217 } else {
1218 let handle = self.ensure_mlx_backend(schema_ref).await?;
1228 if Self::request_has_video(&req) {
1234 return Err(InferenceError::UnsupportedMode {
1235 mode: "video-content-block",
1236 backend: "native-mlx-qwen25vl",
1237 reason:
1238 "Qwen2.5-VL video understanding is on the request surface \
1239 but the video-tokenization path (frame sampling + merger) \
1240 is not yet wired; route to a remote VL provider for now",
1241 });
1242 }
1243 if Self::request_has_audio(&req) {
1244 return Err(InferenceError::UnsupportedMode {
1245 mode: "audio-content-block",
1246 backend: "native-mlx-qwen25vl",
1247 reason:
1248 "audio understanding is on the request surface (Gemma 4 \
1249 E2B/E4B and Gemini accept it) but the native MLX path \
1250 for this model does not — route to Gemini or Gemma-4",
1251 });
1252 }
1253 let has_images = req
1254 .images
1255 .as_ref()
1256 .is_some_and(|imgs| !imgs.is_empty());
1257 if has_images {
1258 let can_do_vision = {
1259 let guard = handle.lock().map_err(|_| {
1260 InferenceError::InferenceFailed(
1261 "MLX backend mutex poisoned".into(),
1262 )
1263 })?;
1264 guard.supports_capability(
1265 crate::schema::ModelCapability::Vision,
1266 )
1267 };
1268 if !can_do_vision {
1269 return Err(InferenceError::UnsupportedMode {
1279 mode: "image-content-block",
1280 backend: "native-mlx-text",
1281 reason:
1282 "this MLX backend is a plain Qwen3 text tower. \
1283 For local image inference, route to \
1284 `mlx-vlm/qwen3-vl-2b:bf16` or another `mlx-vlm/...` \
1285 catalog ID so CAR shells out to `mlx_vlm.generate`. \
1286 Alternatives: a local vLLM-MLX VLM server, or a \
1287 remote VL model. (#115)",
1288 });
1289 }
1290 }
1291 self.generate_mlx(req.clone(), &schema_ref.id)
1292 .await
1293 .map(|(text, ttft)| (text, vec![], None, ttft))
1294 }
1295 }
1296
1297 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1298 {
1299 match self.ensure_backend(&candidate_name).await {
1300 Ok(()) => {
1301 let mut write = self.backend.write().await;
1302 let backend = write.as_mut().unwrap();
1303 tasks::generate::generate(backend, req.clone())
1304 .await
1305 .map(|(text, ttft)| (text, vec![], None, ttft))
1306 }
1307 Err(e) => Err(e),
1308 }
1309 }
1310 };
1311
1312 match result {
1313 Ok((text, tool_calls, usage, time_to_first_token_ms)) => {
1314 let latency_ms = start.elapsed().as_millis() as u64;
1315 let estimated_tokens = usage
1316 .as_ref()
1317 .map(|u| u.completion_tokens as usize)
1318 .unwrap_or_else(|| text.split_whitespace().count());
1319 {
1320 let mut tracker = self.outcome_tracker.write().await;
1321 tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
1322 }
1323 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1325 cb.record_success(candidate_id);
1326 }
1327 self.auto_save_outcomes().await;
1329
1330 let span = tracing::Span::current();
1332 span.record("model", candidate_name.as_str());
1333 span.record("latency_ms", latency_ms);
1334 if let Some(ttft) = time_to_first_token_ms {
1335 span.record("ttft_ms", ttft);
1336 }
1337 if let Some(ref u) = usage {
1338 span.record("prompt_tokens", u.prompt_tokens);
1339 span.record("completion_tokens", u.completion_tokens);
1340 }
1341
1342 let bounding_boxes = tasks::grounding::parse_boxes(&text);
1345 return Ok(InferenceResult {
1346 text,
1347 tool_calls,
1348 bounding_boxes,
1349 trace_id,
1350 model_used: candidate_name,
1351 latency_ms,
1352 time_to_first_token_ms,
1353 usage,
1354 });
1355 }
1356 Err(e) => {
1357 tracing::warn!(
1358 model = %candidate_name,
1359 error = %e,
1360 remaining = models_to_try.len().saturating_sub(
1361 models_to_try.iter().position(|m| m == candidate_id).unwrap_or(0) + 1
1362 ),
1363 "model failed, trying next fallback immediately"
1364 );
1365 {
1367 let mut tracker = self.outcome_tracker.write().await;
1368 let fail_trace =
1369 tracker.record_start(candidate_id, decision.task, "fallback");
1370 tracker.record_failure(&fail_trace, &e.to_string());
1371 }
1372 {
1376 let err_str = e.to_string();
1377 let is_client_error =
1378 err_str.contains("API returned 4") && !err_str.contains("429");
1379 if let Ok(mut cb) = self.adaptive_router.circuit_breakers.lock() {
1380 cb.record_failure(candidate_id);
1381 if is_client_error {
1384 cb.record_failure(candidate_id);
1385 }
1386 }
1387 }
1388 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1390 {
1391 let mut write = self.backend.write().await;
1392 *write = None;
1393 }
1394 last_error = Some(e);
1395 }
1396 }
1397 }
1398
1399 let e = last_error.unwrap_or(InferenceError::InferenceFailed(
1401 "no models available".into(),
1402 ));
1403 {
1404 let mut tracker = self.outcome_tracker.write().await;
1405 tracker.record_failure(&trace_id, &e.to_string());
1406 }
1407 self.auto_save_outcomes().await;
1408 Err(e)
1409 }
1410
/// Streaming generation: picks one model for `req` and returns the receiving
/// half of a channel that yields [`stream::StreamEvent`]s.
///
/// Model selection:
/// - If `req.model` is set, or a preferred `Generate` model exists, that model
///   is used and wrapped in an `Explicit` routing decision with no fallbacks.
/// - Otherwise the adaptive router chooses, using the vision-aware or
///   tool-aware entry point when the request needs it.
///
/// Only `decision.model_id` is ever attempted here; this function does not
/// iterate over fallbacks.
///
/// Dispatch:
/// - Remote and vLLM-MLX schemas are delegated to
///   `remote_backend.generate_stream`.
/// - On Apple Silicon, FoundationModels and MLX schemas stream from a
///   blocking task; any other local schema is rejected.
/// - Elsewhere, the Candle backend streams from a spawned async task.
///
/// # Errors
/// - [`InferenceError::ModelNotFound`] when the routed model id has no schema
///   in the unified registry.
/// - [`InferenceError::InferenceFailed`] on Apple Silicon when the schema is
///   neither a FoundationModels nor an MLX model.
/// - Whatever `ensure_mlx_backend` / `ensure_backend` /
///   `remote_backend.generate_stream` return.
pub async fn generate_tracked_stream(
    &self,
    req: GenerateRequest,
) -> Result<tokio::sync::mpsc::Receiver<stream::StreamEvent>, InferenceError> {
    let has_tools = req.tools.is_some();
    let has_vision = Self::request_needs_vision(&req);
    let preferred_model = self
        .preferred_model_for_capability(ModelCapability::Generate)
        .map(str::to_string);
    // An explicit request model wins over the preferred model; only when
    // neither exists does the adaptive router get a say.
    let decision = match req.model.clone().or(preferred_model) {
        Some(m) => {
            // Look the model up by id first, then by name; unknown models
            // get a context length of 0.
            let ctx_len = self
                .unified_registry
                .get(&m)
                .or_else(|| self.unified_registry.find_by_name(&m))
                .map(|s| s.context_length)
                .unwrap_or(0);
            AdaptiveRoutingDecision {
                model_id: m.clone(),
                model_name: m,
                task: InferenceTask::Generate,
                complexity: TaskComplexity::assess(&req.prompt),
                reason: "explicit model".into(),
                strategy: RoutingStrategy::Explicit,
                predicted_quality: 0.5,
                fallbacks: vec![],
                context_length: ctx_len,
                needs_compaction: false,
            }
        }
        None => {
            // The outcome tracker is only read here, for routing.
            let tracker_read = self.outcome_tracker.read().await;
            if has_vision {
                self.adaptive_router.route_with_vision(
                    &req.prompt,
                    &self.unified_registry,
                    &tracker_read,
                    has_tools,
                )
            } else if has_tools {
                self.adaptive_router.route_with_tools(
                    &req.prompt,
                    &self.unified_registry,
                    &tracker_read,
                )
            } else {
                self.adaptive_router
                    .route(&req.prompt, &self.unified_registry, &tracker_read)
            }
        }
    };

    // `mut` is only exercised by the Apple Silicon redirect below, hence the
    // lint allowance for other targets.
    #[allow(unused_mut)]
    let mut schema = self
        .unified_registry
        .get(&decision.model_id)
        .or_else(|| self.unified_registry.find_by_name(&decision.model_id))
        .cloned();

    // Apple Silicon only: swap the schema for its MLX equivalent when the
    // registry can resolve one.
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    if let Some(ref s) = schema {
        if let Some(mlx_equiv) = self.unified_registry.resolve_mlx_equivalent(s) {
            tracing::info!(
                from = %s.id, to = %mlx_equiv.id,
                "redirecting GGUF model to MLX equivalent on Apple Silicon (stream)"
            );
            schema = Some(mlx_equiv.clone());
        }
    }

    // vLLM-MLX schemas go through the remote backend as well; a missing
    // schema counts as "not remote".
    let is_remote = schema
        .as_ref()
        .map(|s| s.is_remote() || s.is_vllm_mlx())
        .unwrap_or(false);

    if is_remote {
        // `is_remote` can only be true when `schema` is `Some`, so this
        // unwrap cannot fail.
        let schema = schema.unwrap();
        self.remote_backend.register_model_keys(&schema).await;

        self.remote_backend
            .generate_stream(
                &schema,
                &req.prompt,
                req.context.as_deref(),
                req.params.temperature,
                req.params.max_tokens,
                req.tools.as_deref(),
                req.images.as_deref(),
                req.params.tool_choice.as_deref(),
                req.params.parallel_tool_calls,
                req.response_format.as_ref(),
            )
            .await
    } else {
        let schema = schema.ok_or_else(|| {
            InferenceError::ModelNotFound(decision.model_id.clone())
        })?;
        // Bounded channel: producers block (or await) once 64 events are
        // queued and unread.
        let (tx, rx) = tokio::sync::mpsc::channel(64);

        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            if schema.is_foundation_models() {
                let prompt = req.prompt.clone();
                let instructions = req.context.clone();
                let max_tokens = req.params.max_tokens as u32;
                let temperature = req.params.temperature;
                let tx_clone = tx.clone();
                // NOTE(review): the JoinHandle is dropped, so the `result`
                // returned by the closure is never observed by this function.
                tokio::task::spawn_blocking(move || {
                    // Accumulates every delta so the final `Done` event can
                    // carry the full text.
                    let accum = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
                    let accum_cb = accum.clone();
                    let cb = crate::backend::foundation_models::StreamCallback::new(
                        move |delta: &str| {
                            if let Ok(mut g) = accum_cb.lock() {
                                g.push_str(delta);
                            }
                            // Evaluates to `false` once the receiver has been
                            // dropped; presumably the bridge treats that as a
                            // stop signal — TODO confirm.
                            tx_clone
                                .blocking_send(stream::StreamEvent::TextDelta(
                                    delta.to_string(),
                                ))
                                .is_ok()
                        },
                    );
                    let result = crate::backend::foundation_models::stream(
                        &prompt,
                        instructions.as_deref(),
                        max_tokens,
                        temperature as f32,
                        cb,
                    );
                    // A poisoned accumulator degrades to an empty final text.
                    let final_text = accum
                        .lock()
                        .map(|g| g.clone())
                        .unwrap_or_default();
                    // `Done` is sent whether `result` is `Ok` or `Err`.
                    let _ = tx.blocking_send(stream::StreamEvent::Done {
                        text: final_text,
                        tool_calls: vec![],
                    });
                    result
                });
                return Ok(rx);
            }
            if !schema.is_mlx() {
                return Err(InferenceError::InferenceFailed(format!(
                    "model '{}' has no MLX equivalent; Candle backend disabled on Apple Silicon",
                    schema.id
                )));
            }
            let backend = self.ensure_mlx_backend(&schema).await?;
            let model_id = schema.id.clone();
            let cache = Arc::clone(&self.mlx_backends);
            // `stream_local_mlx` is synchronous and uses `blocking_send`, so
            // it runs on the blocking pool. Its error, if any, is discarded;
            // the receiver just sees the channel close.
            tokio::task::spawn_blocking(move || {
                let _ = Self::stream_local_mlx(backend, cache, model_id, req, tx);
            });
            return Ok(rx);
        }

        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        {
            self.ensure_backend(&schema.name).await?;
            let backend = self.backend.clone();
            // The streaming task's error, if any, is discarded; the receiver
            // just sees the channel close.
            tokio::spawn(async move {
                let _ = Self::stream_local_candle(backend, req, tx).await;
            });
            Ok(rx)
        }
    }
}
1631
1632 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1633 async fn stream_local_candle(
1634 backend_lock: Arc<RwLock<Option<CandleBackend>>>,
1635 req: GenerateRequest,
1636 tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
1637 ) -> Result<(), InferenceError> {
1638 let mut write = backend_lock.write().await;
1639 let backend = write
1640 .as_mut()
1641 .ok_or_else(|| InferenceError::InferenceFailed("backend not initialized".into()))?;
1642 backend.clear_kv_cache();
1643
1644 let formatted = tasks::generate::apply_chat_template(
1645 &req.prompt,
1646 req.context.as_deref(),
1647 req.params.thinking,
1648 );
1649 let tokens = backend.encode(&formatted)?;
1650 let eos = backend.eos_token_id();
1651 let eos_alt = backend.token_id("<|im_end|>");
1652 let params = &req.params;
1653
1654 if tokens.is_empty() {
1655 let _ = tx
1656 .send(stream::StreamEvent::Done {
1657 text: String::new(),
1658 tool_calls: vec![],
1659 })
1660 .await;
1661 return Ok(());
1662 }
1663
1664 let max_ctx = backend.context_length().unwrap_or(32768);
1665 let headroom = params.max_tokens.min(max_ctx / 4);
1666 let max_prompt = max_ctx.saturating_sub(headroom);
1667 let tokens = if tokens.len() > max_prompt {
1668 tokens[tokens.len() - max_prompt..].to_vec()
1669 } else {
1670 tokens
1671 };
1672
1673 let mut generated = Vec::new();
1674 let logits = backend.forward(&tokens, 0)?;
1675 let mut next_token = tasks::generate::sample_token(&logits, params)?;
1676
1677 for _ in 0..params.max_tokens {
1678 if eos.map_or(false, |id| next_token == id)
1679 || eos_alt.map_or(false, |id| next_token == id)
1680 {
1681 break;
1682 }
1683
1684 generated.push(next_token);
1685 let delta = backend.decode(&[next_token])?;
1686 if !delta.is_empty()
1687 && tx.send(stream::StreamEvent::TextDelta(delta)).await.is_err()
1688 {
1689 return Ok(());
1690 }
1691
1692 if !params.stop.is_empty() {
1693 let text_so_far = backend.decode(&generated)?;
1694 if params.stop.iter().any(|s| text_so_far.contains(s)) {
1695 break;
1696 }
1697 }
1698
1699 let pos = tokens.len() + generated.len() - 1;
1700 let logits = backend.forward(&[next_token], pos)?;
1701 next_token = tasks::generate::sample_token(&logits, params)?;
1702 }
1703
1704 let text = tasks::generate::strip_thinking(&backend.decode(&generated)?, params.thinking);
1705 let _ = tx
1706 .send(stream::StreamEvent::Done {
1707 text,
1708 tool_calls: vec![],
1709 })
1710 .await;
1711 Ok(())
1712 }
1713
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
/// Synchronous token-by-token generation on a cached MLX backend, pushing
/// output through `tx` with `blocking_send`.
///
/// Each decoded token is sent as a `TextDelta`; a final `Done` event carries
/// the complete text with thinking stripped according to `req.params.thinking`.
/// If the receiver goes away mid-stream the function returns `Ok(())` without
/// sending `Done`. When a forward pass fails (including a caught panic) the
/// cache entry for `model_id` is invalidated before the error is returned.
///
/// # Errors
/// `InferenceFailed` when the backend mutex is poisoned, plus any error from
/// encoding, the forward pass, sampling or decoding.
fn stream_local_mlx(
    handle: backend_cache::CachedBackend<backend::MlxBackend>,
    cache: Arc<backend_cache::BackendCache<backend::MlxBackend>>,
    model_id: String,
    req: GenerateRequest,
    tx: tokio::sync::mpsc::Sender<stream::StreamEvent>,
) -> Result<(), InferenceError> {
    let mut locked = handle.lock().map_err(|_| {
        InferenceError::InferenceFailed(format!(
            "MLX backend mutex poisoned for {model_id}"
        ))
    })?;
    let mlx: &mut backend::MlxBackend = &mut *locked;
    mlx.clear_kv_cache();

    let chat_text = tasks::generate::apply_chat_template(
        &req.prompt,
        req.context.as_deref(),
        req.params.thinking,
    );
    let tokens = mlx.encode(&chat_text)?;
    let eos = mlx.eos_token_id();
    let eos_alt = mlx.token_id("<|im_end|>");
    let params = &req.params;

    // Nothing to condition on: report an empty completion straight away.
    if tokens.is_empty() {
        let done = stream::StreamEvent::Done {
            text: String::new(),
            tool_calls: vec![],
        };
        let _ = tx.blocking_send(done);
        return Ok(());
    }

    // Reserve room for the completion (at most a quarter of the context) and
    // keep only the most recent prompt tokens that still fit.
    let context_window = mlx.context_length();
    let reserve = params.max_tokens.min(context_window / 4);
    let prompt_budget = context_window.saturating_sub(reserve);
    let excess = tokens.len().saturating_sub(prompt_budget);
    let tokens = if excess > 0 {
        tokens[excess..].to_vec()
    } else {
        tokens
    };

    let mut generated = Vec::new();
    let prefill_logits = Self::catch_mlx("stream prefill", || mlx.forward(&tokens, 0))
        .map_err(|err| {
            cache.invalidate(&model_id);
            err
        })?;
    let mut next_token = Self::sample_from_logits(&prefill_logits, params)?;

    for _ in 0..params.max_tokens {
        let reached_end = eos.is_some_and(|id| next_token == id)
            || eos_alt.is_some_and(|id| next_token == id);
        if reached_end {
            break;
        }

        generated.push(next_token);
        let piece = mlx.decode(&[next_token])?;
        if !piece.is_empty() {
            let delivered = tx.blocking_send(stream::StreamEvent::TextDelta(piece));
            if delivered.is_err() {
                // Receiver dropped: nobody is listening any more.
                return Ok(());
            }
        }

        if !params.stop.is_empty() {
            let so_far = mlx.decode(&generated)?;
            let stop_hit = params.stop.iter().any(|s| so_far.contains(s));
            if stop_hit {
                break;
            }
        }

        // Position of the token just produced within the full sequence.
        let position = tokens.len() + generated.len() - 1;
        let step_logits =
            Self::catch_mlx("stream forward", || mlx.forward(&[next_token], position))
                .map_err(|err| {
                    cache.invalidate(&model_id);
                    err
                })?;
        next_token = Self::sample_from_logits(&step_logits, params)?;
    }

    let raw_text = mlx.decode(&generated)?;
    let text = tasks::generate::strip_thinking(&raw_text, params.thinking);
    let done = stream::StreamEvent::Done {
        text,
        tool_calls: vec![],
    };
    let _ = tx.blocking_send(done);
    Ok(())
}
1808
1809 pub async fn route_context_snapshot(
1811 &self,
1812 prompt: &str,
1813 workload: RoutingWorkload,
1814 has_tools: bool,
1815 has_vision: bool,
1816 ) -> AdaptiveRoutingDecision {
1817 let tracker = self.outcome_tracker.read().await;
1818 self.adaptive_router.route_context_aware(
1819 prompt,
1820 0,
1821 &self.unified_registry,
1822 &tracker,
1823 has_tools,
1824 has_vision,
1825 workload,
1826 )
1827 }
1828
1829 pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
1832 Ok(self.generate_tracked(req).await?.text)
1833 }
1834
1835 pub async fn tokenize(
1847 &self,
1848 model: &str,
1849 text: &str,
1850 ) -> Result<Vec<u32>, InferenceError> {
1851 self.assert_local_for_tokenize(model)?;
1852
1853 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1854 {
1855 let schema = self
1856 .unified_registry
1857 .get(model)
1858 .or_else(|| self.unified_registry.find_by_name(model))
1859 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1860 .clone();
1861 let handle = self.ensure_mlx_backend(&schema).await?;
1862 let guard = handle.lock().map_err(|_| {
1863 InferenceError::InferenceFailed(format!(
1864 "MLX backend mutex poisoned for {}", schema.id
1865 ))
1866 })?;
1867 return guard.tokenize_raw(text);
1868 }
1869
1870 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1871 {
1872 self.ensure_backend(model).await?;
1873 let read = self.backend.read().await;
1874 let backend = read.as_ref().ok_or_else(|| {
1875 InferenceError::InferenceFailed(
1876 "candle backend missing after ensure_backend".to_string(),
1877 )
1878 })?;
1879 backend.tokenize_raw(text)
1880 }
1881 }
1882
1883 pub async fn detokenize(
1885 &self,
1886 model: &str,
1887 tokens: &[u32],
1888 ) -> Result<String, InferenceError> {
1889 self.assert_local_for_tokenize(model)?;
1890
1891 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
1892 {
1893 let schema = self
1894 .unified_registry
1895 .get(model)
1896 .or_else(|| self.unified_registry.find_by_name(model))
1897 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?
1898 .clone();
1899 let handle = self.ensure_mlx_backend(&schema).await?;
1900 let guard = handle.lock().map_err(|_| {
1901 InferenceError::InferenceFailed(format!(
1902 "MLX backend mutex poisoned for {}", schema.id
1903 ))
1904 })?;
1905 return guard.detokenize_raw(tokens);
1906 }
1907
1908 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
1909 {
1910 self.ensure_backend(model).await?;
1911 let read = self.backend.read().await;
1912 let backend = read.as_ref().ok_or_else(|| {
1913 InferenceError::InferenceFailed(
1914 "candle backend missing after ensure_backend".to_string(),
1915 )
1916 })?;
1917 backend.detokenize_raw(tokens)
1918 }
1919 }
1920
1921 fn assert_local_for_tokenize(&self, model: &str) -> Result<(), InferenceError> {
1925 if let Some(schema) = self
1926 .unified_registry
1927 .get(model)
1928 .or_else(|| self.unified_registry.find_by_name(model))
1929 {
1930 if !schema.is_local() {
1931 return Err(InferenceError::UnsupportedMode {
1932 mode: "tokenize/detokenize",
1933 backend: "remote",
1934 reason:
1935 "remote provider tokenizer is not exposed by the runtime; \
1936 use a local model (Qwen3 GGUF / MLX) for tokenizer-correctness checks",
1937 });
1938 }
1939 }
1940 Ok(())
1942 }
1943
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
/// Runs `f`, converting a panic inside it into an
/// [`InferenceError::InferenceFailed`] instead of unwinding into the caller.
///
/// `context` names the phase being run and is embedded in the error message.
/// A normal `Ok`/`Err` from `f` is passed through unchanged.
///
/// The panic payload is a `Box<dyn Any + Send>`; formatting it with `{:?}`
/// only ever prints `Any { .. }`, so the message is recovered by downcasting
/// to the two payload types produced by `panic!` (`&'static str` and
/// `String`).
fn catch_mlx<F, T>(context: &str, f: F) -> Result<T, InferenceError>
where
    F: FnOnce() -> Result<T, InferenceError>,
{
    match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
        Ok(outcome) => outcome,
        Err(payload) => {
            let detail = payload
                .downcast_ref::<&str>()
                .map(|msg| (*msg).to_string())
                .or_else(|| payload.downcast_ref::<String>().cloned())
                .unwrap_or_else(|| "non-string panic payload".to_string());
            Err(InferenceError::InferenceFailed(format!(
                "MLX panicked during {context}: {detail}"
            )))
        }
    }
}
1959
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
/// Non-streaming generation on the cached MLX backend for `model_id`.
///
/// Returns the generated text (thinking stripped according to
/// `req.params.thinking`) and the time to first token in milliseconds. The
/// time is `None` only when the encoded prompt is empty, in which case the
/// text is empty too.
///
/// When a forward pass fails (including a caught panic) the backend lock is
/// released and the cache entry for `model_id` is invalidated before the
/// error is returned.
///
/// # Errors
/// `InferenceFailed` for an unknown schema id or a poisoned backend mutex,
/// plus any error from loading the backend, encoding, the forward pass,
/// sampling or decoding.
async fn generate_mlx(
    &self,
    req: GenerateRequest,
    model_id: &str,
) -> Result<(String, Option<u64>), InferenceError> {
    let started = std::time::Instant::now();

    let Some(schema) = self.unified_registry.get(model_id).cloned() else {
        return Err(InferenceError::InferenceFailed(format!(
            "generate_mlx: unknown schema id {model_id}"
        )));
    };
    let handle = self.ensure_mlx_backend(&schema).await?;
    let mut guard = handle.lock().map_err(|_| {
        InferenceError::InferenceFailed(format!(
            "MLX backend mutex poisoned for {model_id}"
        ))
    })?;
    let mlx: &mut backend::MlxBackend = &mut *guard;
    mlx.clear_kv_cache();

    let chat_text = tasks::generate::apply_chat_template(
        &req.prompt,
        req.context.as_deref(),
        req.params.thinking,
    );
    let tokens = mlx.encode(&chat_text)?;
    let eos = mlx.eos_token_id();
    let eos_alt = mlx.token_id("<|im_end|>");
    let params = &req.params;

    if tokens.is_empty() {
        return Ok((String::new(), None));
    }

    // Reserve room for the completion (at most a quarter of the context) and
    // keep only the most recent prompt tokens that still fit.
    let context_window = mlx.context_length();
    let reserve = params.max_tokens.min(context_window / 4);
    let prompt_budget = context_window.saturating_sub(reserve);
    let excess = tokens.len().saturating_sub(prompt_budget);
    let tokens = if excess > 0 {
        tokens[excess..].to_vec()
    } else {
        tokens
    };

    let mut generated = Vec::new();

    let prefill = Self::catch_mlx("prefill", || mlx.forward(&tokens, 0));
    let prefill_logits = match prefill {
        Ok(logits) => logits,
        Err(err) => {
            // Release the lock before dropping the backend from the cache.
            drop(guard);
            self.mlx_backends.invalidate(model_id);
            return Err(err);
        }
    };
    let mut next_token = Self::sample_from_logits(&prefill_logits, params)?;
    let elapsed_ms = started.elapsed().as_millis() as u64;
    let ttft_ms = Some(elapsed_ms);

    for _ in 0..params.max_tokens {
        let reached_end = eos.is_some_and(|id| next_token == id)
            || eos_alt.is_some_and(|id| next_token == id);
        if reached_end {
            break;
        }

        generated.push(next_token);

        if !params.stop.is_empty() {
            let so_far = mlx.decode(&generated)?;
            let stop_hit = params.stop.iter().any(|s| so_far.contains(s));
            if stop_hit {
                break;
            }
        }

        // Position of the token just produced within the full sequence.
        let position = tokens.len() + generated.len() - 1;
        let step = Self::catch_mlx("forward", || mlx.forward(&[next_token], position));
        let step_logits = match step {
            Ok(logits) => logits,
            Err(err) => {
                drop(guard);
                self.mlx_backends.invalidate(model_id);
                return Err(err);
            }
        };
        next_token = Self::sample_from_logits(&step_logits, params)?;
    }

    let raw_text = mlx.decode(&generated)?;
    let cleaned = tasks::generate::strip_thinking(&raw_text, params.thinking);
    Ok((cleaned, ttft_ms))
}
2059
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
/// Picks the next token id from `logits` according to `params`.
///
/// - `temperature <= 0`: greedy, returns the index of the largest logit.
/// - otherwise: temperature-scaled softmax, optionally restricted to the
///   nucleus whose cumulative probability first exceeds `top_p` (when
///   `top_p < 1`), followed by a random draw.
///
/// # Errors
/// `InferenceFailed("empty logits")` when `logits` is empty. The check is done
/// once up front so the sampling path can no longer underflow on
/// `probs.len() - 1`.
fn sample_from_logits(logits: &[f32], params: &GenerateParams) -> Result<u32, InferenceError> {
    if logits.is_empty() {
        return Err(InferenceError::InferenceFailed("empty logits".into()));
    }

    if params.temperature <= 0.0 {
        let (idx, _) = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .ok_or_else(|| InferenceError::InferenceFailed("empty logits".into()))?;
        return Ok(idx as u32);
    }

    let temp = params.temperature as f32;
    // Subtract the maximum before exponentiating so `exp` cannot overflow.
    let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits
        .iter()
        .map(|&l| ((l - max_logit) / temp).exp())
        .collect();
    let sum: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= sum;
    }

    if params.top_p < 1.0 {
        // Rank tokens by descending probability (stable, so ties keep their
        // original order).
        let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let mut cumsum = 0.0;
        let mut cutoff_idx = indexed.len();
        for (i, &(_, p)) in indexed.iter().enumerate() {
            cumsum += p;
            if cumsum > params.top_p as f32 {
                cutoff_idx = i + 1;
                break;
            }
        }
        // Every index appears exactly once in `indexed`, so the entries past
        // the cutoff are exactly the tokens outside the nucleus.
        for &(i, _) in &indexed[cutoff_idx..] {
            probs[i] = 0.0;
        }
        let sum: f32 = probs.iter().sum();
        if sum > 0.0 {
            for p in &mut probs {
                *p /= sum;
            }
        }
    }

    use rand::Rng;
    let mut rng = rand::rng();
    let r: f32 = rng.random();
    let mut cumsum = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if cumsum >= r {
            return Ok(i as u32);
        }
    }
    // Rounding can leave the running sum just below `r`; fall back to the
    // last token. `probs` is non-empty here, so the subtraction is safe.
    Ok((probs.len() - 1) as u32)
}
2127
2128 pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
2131 let instruction = req
2132 .instruction
2133 .as_deref()
2134 .unwrap_or("Retrieve relevant memory facts");
2135
2136 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2137 {
2138 let model_id = self.ensure_mlx_embedding_backend().await?;
2139 let schema = self
2140 .unified_registry
2141 .get(&model_id)
2142 .cloned()
2143 .ok_or_else(|| InferenceError::InferenceFailed(format!(
2144 "embed: unknown schema id {model_id}"
2145 )))?;
2146 let handle = self.ensure_mlx_backend(&schema).await?;
2147 let mut guard = handle.lock().map_err(|_| {
2148 InferenceError::InferenceFailed(format!(
2149 "MLX embedding backend mutex poisoned for {model_id}"
2150 ))
2151 })?;
2152 let backend: &mut backend::MlxBackend = &mut *guard;
2153
2154 let mut results = Vec::with_capacity(req.texts.len());
2155 for text in &req.texts {
2156 let embedding = if req.is_query {
2157 backend.embed_query(text, instruction)?
2158 } else {
2159 backend.embed_one(text)?
2160 };
2161 results.push(embedding);
2162 }
2163 return Ok(results);
2164 }
2165
2166 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
2167 {
2168 self.ensure_embedding_backend().await?;
2169 let mut write = self.embedding_backend.write().await;
2170 let backend = write.as_mut().unwrap();
2171
2172 let mut results = Vec::with_capacity(req.texts.len());
2173 for text in &req.texts {
2174 let embedding = if req.is_query {
2175 backend.embed_query(text, instruction)?
2176 } else {
2177 backend.embed_one(text)?
2178 };
2179 results.push(embedding);
2180 }
2181 Ok(results)
2182 }
2183 }
2184
2185 pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResult, InferenceError> {
2215 if req.documents.is_empty() {
2216 return Ok(RerankResult { ranked: Vec::new(), model_used: None });
2217 }
2218
2219 let model_name = match req.model.clone() {
2220 Some(m) => m,
2221 None => self
2222 .preferred_model_for_capability(ModelCapability::Rerank)
2223 .map(str::to_string)
2224 .ok_or_else(|| {
2225 InferenceError::InferenceFailed(
2226 "no reranker model available — pull a Qwen3-Reranker model first".into(),
2227 )
2228 })?,
2229 };
2230
2231 let schema = self
2232 .unified_registry
2233 .find_by_name(&model_name)
2234 .or_else(|| self.unified_registry.get(&model_name))
2235 .cloned()
2236 .ok_or_else(|| {
2237 InferenceError::InferenceFailed(format!(
2238 "rerank: unknown reranker model {model_name}"
2239 ))
2240 })?;
2241 if !schema.has_capability(ModelCapability::Rerank) {
2242 return Err(InferenceError::InferenceFailed(format!(
2243 "model {} does not declare the Rerank capability",
2244 schema.name
2245 )));
2246 }
2247
2248 let instruction = req.instruction.as_deref().unwrap_or(
2249 "Given a web search query, retrieve relevant passages that answer the query",
2250 );
2251
2252 let mut scored: Vec<RerankedDocument> = Vec::with_capacity(req.documents.len());
2253 for (idx, doc) in req.documents.iter().enumerate() {
2254 let prompt = rerank_prompt(instruction, &req.query, doc);
2255 let gen_req = GenerateRequest {
2256 prompt,
2257 model: Some(schema.id.clone()),
2258 params: tasks::generate::GenerateParams {
2259 temperature: 0.0,
2260 max_tokens: 3,
2264 thinking: tasks::generate::ThinkingMode::Off,
2265 ..Default::default()
2266 },
2267 context: None,
2268 tools: None,
2269 images: None,
2270 messages: None,
2271 cache_control: false,
2272 response_format: None,
2273 intent: None,
2274 };
2275 let out = self.generate(gen_req).await?;
2276 let score = score_from_rerank_output(&out, &schema.name);
2277 scored.push(RerankedDocument {
2278 index: idx,
2279 score,
2280 document: doc.clone(),
2281 });
2282 }
2283
2284 scored.sort_by(|a, b| {
2287 b.score
2288 .partial_cmp(&a.score)
2289 .unwrap_or(std::cmp::Ordering::Equal)
2290 .then_with(|| a.index.cmp(&b.index))
2291 });
2292 if let Some(n) = req.top_n {
2293 scored.truncate(n);
2294 }
2295
2296 Ok(RerankResult {
2297 ranked: scored,
2298 model_used: Some(schema.name),
2299 })
2300 }
2301
2302 pub async fn ground(&self, req: GroundRequest) -> Result<GroundResult, InferenceError> {
2312 let model_name = match req.model.clone() {
2313 Some(m) => m,
2314 None => self
2315 .preferred_model_for_capability(ModelCapability::Grounding)
2316 .map(str::to_string)
2317 .ok_or_else(|| {
2318 InferenceError::InferenceFailed(
2319 "no grounding-capable model available — pull a Qwen2.5-VL model first"
2320 .into(),
2321 )
2322 })?,
2323 };
2324
2325 let gen_req = GenerateRequest {
2326 prompt: req.prompt.clone(),
2327 model: Some(model_name),
2328 params: GenerateParams::default(),
2329 context: None,
2330 tools: None,
2331 images: Some(vec![req.image.clone()]),
2332 messages: None,
2333 cache_control: false,
2334 response_format: None,
2335 intent: None,
2336 };
2337 let result = self.generate_tracked(gen_req).await?;
2338 Ok(GroundResult {
2339 boxes: result.bounding_boxes,
2340 raw_text: result.text,
2341 model_used: Some(result.model_used),
2342 })
2343 }
2344
2345 pub async fn classify(
2348 &self,
2349 req: ClassifyRequest,
2350 ) -> Result<Vec<ClassifyResult>, InferenceError> {
2351 let model = match req
2352 .model
2353 .clone()
2354 .or_else(|| self.preferred_model_for_capability(ModelCapability::Classify).map(str::to_string))
2355 {
2356 Some(m) => m,
2357 None => {
2358 let m = self.router.route_small(&self.registry);
2359 debug!(model = %m, "auto-routed classify request");
2360 m
2361 }
2362 };
2363
2364 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2367 {
2368 return self.classify_via_generate(req, &model).await;
2369 }
2370
2371 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
2372 {
2373 self.ensure_backend(&model).await?;
2374 let mut write = self.backend.write().await;
2375 let backend = write.as_mut().unwrap();
2376 tasks::classify::classify(backend, req).await
2377 }
2378 }
2379
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
/// Classification implemented on top of text generation.
///
/// Builds a prompt listing the labels, asks `model` for the category name
/// (temperature 0, 32 tokens, thinking off) and scores each label against the
/// lower-cased, trimmed response:
/// - exact match: 1.0
/// - label contained in the response: 0.8
/// - otherwise: 0.5 × the fraction of the label's words found in the response
///   (0.0 for a label with no words)
///
/// Results are sorted by descending score and, when the total is positive,
/// normalized so the scores sum to 1.
///
/// # Errors
/// Propagates any error from `generate`.
async fn classify_via_generate(
    &self,
    req: ClassifyRequest,
    model: &str,
) -> Result<Vec<ClassifyResult>, InferenceError> {
    let mut numbered = Vec::with_capacity(req.labels.len());
    for (pos, label) in req.labels.iter().enumerate() {
        numbered.push(format!("{}. {}", pos + 1, label));
    }
    let labels_str = numbered.join("\n");

    let prompt = format!(
        "Classify the following text into one of these categories:\n\
         {labels_str}\n\n\
         Text: {}\n\n\
         Respond with ONLY the category name, nothing else.",
        req.text
    );

    let params = tasks::generate::GenerateParams {
        temperature: 0.0,
        max_tokens: 32,
        thinking: tasks::generate::ThinkingMode::Off,
        ..Default::default()
    };
    let gen_req = GenerateRequest {
        prompt,
        model: Some(model.to_string()),
        params,
        context: None,
        tools: None,
        images: None,
        messages: None,
        cache_control: false,
        response_format: None,
        intent: None,
    };

    let answer = self.generate(gen_req).await?;
    let answer_lower = answer.trim().to_lowercase();

    let mut results: Vec<ClassifyResult> = Vec::with_capacity(req.labels.len());
    for label in &req.labels {
        let needle = label.to_lowercase();
        let score = if answer_lower == needle {
            1.0
        } else if answer_lower.contains(&needle) {
            0.8
        } else {
            let word_total = needle.split_whitespace().count();
            if word_total == 0 {
                0.0
            } else {
                let word_hits = needle
                    .split_whitespace()
                    .filter(|word| answer_lower.contains(*word))
                    .count();
                0.5 * (word_hits as f64 / word_total as f64)
            }
        };
        results.push(ClassifyResult {
            label: label.clone(),
            score,
        });
    }

    // Best match first; incomparable scores are treated as equal.
    results.sort_by(|lhs, rhs| {
        rhs.score
            .partial_cmp(&lhs.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let total: f64 = results.iter().map(|entry| entry.score).sum();
    if total > 0.0 {
        for entry in results.iter_mut() {
            entry.score /= total;
        }
    }

    Ok(results)
}
2469
2470 pub async fn transcribe(
2472 &self,
2473 req: TranscribeRequest,
2474 ) -> Result<TranscribeResult, InferenceError> {
2475 let candidates =
2476 self.speech_candidates(ModelCapability::SpeechToText, req.model.as_deref())?;
2477 let mut last_error = None;
2478
2479 for schema in candidates {
2480 let result = match &schema.source {
2481 ModelSource::Mlx { .. } => self.transcribe_local_mlx(&schema, &req).await,
2482 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2483 self.transcribe_elevenlabs(&schema, &req).await
2484 }
2485 _ => Err(InferenceError::InferenceFailed(format!(
2486 "speech-to-text not implemented for model source: {}",
2487 schema.id
2488 ))),
2489 };
2490
2491 match result {
2492 Ok(result) => return Ok(result),
2493 Err(err) => last_error = Some(err),
2494 }
2495 }
2496
2497 Err(last_error.unwrap_or_else(|| {
2498 InferenceError::InferenceFailed("no speech-to-text models available".into())
2499 }))
2500 }
2501
2502 pub async fn synthesize(
2504 &self,
2505 req: SynthesizeRequest,
2506 ) -> Result<SynthesizeResult, InferenceError> {
2507 let candidates =
2508 self.speech_candidates(ModelCapability::TextToSpeech, req.model.as_deref())?;
2509 let mut last_error = None;
2510
2511 for schema in candidates {
2512 let result = match &schema.source {
2513 ModelSource::Mlx { .. } => self.synthesize_local_mlx(&schema, &req).await,
2514 ModelSource::Proprietary { provider, .. } if provider == "elevenlabs" => {
2515 self.synthesize_elevenlabs(&schema, &req).await
2516 }
2517 _ => Err(InferenceError::InferenceFailed(format!(
2518 "text-to-speech not implemented for model source: {}",
2519 schema.id
2520 ))),
2521 };
2522
2523 match result {
2524 Ok(result) => return Ok(result),
2525 Err(err) => last_error = Some(err),
2526 }
2527 }
2528
2529 Err(last_error.unwrap_or_else(|| {
2530 InferenceError::InferenceFailed("no text-to-speech models available".into())
2531 }))
2532 }
2533
2534 pub async fn generate_image(
2536 &self,
2537 req: GenerateImageRequest,
2538 ) -> Result<GenerateImageResult, InferenceError> {
2539 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2546 {
2547 use crate::backend::external_flux;
2548 let backend = std::env::var("CAR_IMAGE_BACKEND")
2549 .unwrap_or_else(|_| "native".to_string());
2550 let use_external = match backend.as_str() {
2551 "external" => true,
2552 "native" => false,
2553 _ => external_flux::is_available() && backend == "auto-external",
2555 };
2556 if use_external {
2557 tracing::info!(
2558 "routing image generation to external mflux \
2559 (set CAR_IMAGE_BACKEND=native to use the Rust port)"
2560 );
2561 let mut req = req;
2562 req.model = self.resolve_external_hf_repo(
2563 req.model.as_deref(),
2564 ModelCapability::ImageGeneration,
2565 );
2566 return external_flux::generate_image(&req);
2567 }
2568 tracing::info!("using native Rust MLX Flux backend");
2569 }
2570
2571 let candidates =
2572 self.media_generation_candidates(ModelCapability::ImageGeneration, req.model.as_deref())?;
2573 let mut last_error = None;
2574
2575 for schema in candidates {
2576 let result = match &schema.source {
2577 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2578 ModelSource::Mlx { .. } => self.generate_image_native_mlx(&schema, &req).await,
2579 _ => Err(InferenceError::InferenceFailed(format!(
2580 "image generation not implemented for model source: {}",
2581 schema.id
2582 ))),
2583 };
2584
2585 match result {
2586 Ok(result) => return Ok(result),
2587 Err(err) => last_error = Some(err),
2588 }
2589 }
2590
2591 Err(last_error.unwrap_or_else(|| {
2592 InferenceError::InferenceFailed("no image generation models available".into())
2593 }))
2594 }
2595
2596 pub async fn generate_image_batch(
2612 &self,
2613 req: GenerateImageRequest,
2614 ) -> Result<Vec<GenerateImageResult>, InferenceError> {
2615 let count = req.variant_count.unwrap_or(1).max(1);
2616 if count == 1 {
2617 return self.generate_image(req).await.map(|r| vec![r]);
2618 }
2619 let base_seed = req.seed.unwrap_or(0);
2620 let mut results = Vec::with_capacity(count as usize);
2621 for i in 0..count {
2622 let mut variant_req = req.clone();
2627 variant_req.seed = Some(base_seed.wrapping_add(i as u64));
2628 variant_req.variant_count = Some(1);
2632 results.push(self.generate_image(variant_req).await?);
2633 }
2634 Ok(results)
2635 }
2636
/// Runs image generation on the native MLX Flux backend for `schema`.
///
/// Ensures the model is present locally, obtains a backend handle from
/// `flux_cache` (loading it from the model directory on demand), then runs
/// `generate` on a blocking worker thread while holding the handle's mutex.
///
/// # Errors
/// Propagates errors from `ensure_local`, the cache load and the backend's
/// `generate`; reports a poisoned mutex or a failed task join as
/// `InferenceFailed`.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
async fn generate_image_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateImageRequest,
) -> Result<GenerateImageResult, InferenceError> {
    let weights_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let size_estimate = backend_cache::estimate_model_size(&weights_dir);
    let flux_cache = Arc::clone(&self.flux_cache);
    let cache_key = schema.id.clone();
    let backend_handle = flux_cache.get_or_load(&cache_key, size_estimate, || {
        backend::mlx_flux::FluxBackend::load(&weights_dir)
    })?;

    // The blocking task needs owned data, so hand it a copy of the request.
    let owned_req = req.clone();
    let joined =
        tokio::task::spawn_blocking(move || -> Result<GenerateImageResult, InferenceError> {
            let mut flux = backend_handle.lock().map_err(|_| {
                InferenceError::InferenceFailed("flux backend mutex poisoned".into())
            })?;
            flux.generate(&owned_req)
        })
        .await;

    match joined {
        Ok(outcome) => outcome,
        Err(e) => Err(InferenceError::InferenceFailed(format!(
            "flux task join: {e}"
        ))),
    }
}
2664
/// Runs video generation on the native MLX LTX backend for `schema`.
///
/// Ensures the model is present locally, obtains a backend handle from
/// `ltx_cache` (loading it from the model directory on demand), then runs
/// `generate` on a blocking worker thread while holding the handle's mutex.
///
/// # Errors
/// Propagates errors from `ensure_local`, the cache load and the backend's
/// `generate`; reports a poisoned mutex or a failed task join as
/// `InferenceFailed`.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
async fn generate_video_native_mlx(
    &self,
    schema: &ModelSchema,
    req: &GenerateVideoRequest,
) -> Result<GenerateVideoResult, InferenceError> {
    let weights_dir = self.unified_registry.ensure_local(&schema.id).await?;
    let size_estimate = backend_cache::estimate_model_size(&weights_dir);
    let ltx_cache = Arc::clone(&self.ltx_cache);
    let cache_key = schema.id.clone();
    let backend_handle = ltx_cache.get_or_load(&cache_key, size_estimate, || {
        backend::mlx_ltx::LtxBackend::load(&weights_dir)
    })?;

    // The blocking task needs owned data, so hand it a copy of the request.
    let owned_req = req.clone();
    let joined =
        tokio::task::spawn_blocking(move || -> Result<GenerateVideoResult, InferenceError> {
            let mut ltx = backend_handle.lock().map_err(|_| {
                InferenceError::InferenceFailed("ltx backend mutex poisoned".into())
            })?;
            ltx.generate(&owned_req)
        })
        .await;

    match joined {
        Ok(outcome) => outcome,
        Err(e) => Err(InferenceError::InferenceFailed(format!(
            "ltx task join: {e}"
        ))),
    }
}
2689
2690 pub async fn generate_video(
2692 &self,
2693 req: GenerateVideoRequest,
2694 ) -> Result<GenerateVideoResult, InferenceError> {
2695 if let Err(msg) = req.validate() {
2698 return Err(InferenceError::InferenceFailed(format!(
2699 "invalid GenerateVideoRequest: {}",
2700 msg
2701 )));
2702 }
2703 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2710 {
2711 use crate::backend::external_ltx;
2712 let backend = std::env::var("CAR_VIDEO_BACKEND")
2717 .unwrap_or_else(|_| "native".to_string());
2718 let use_external = match backend.as_str() {
2719 "external" => true,
2720 "native" => false,
2721 "auto-external" => external_ltx::is_available(),
2725 _ => false,
2726 };
2727 if use_external {
2728 tracing::info!(
2729 "routing video generation to external ltx-2-mlx \
2730 (set CAR_VIDEO_BACKEND=native to opt into the in-progress Rust port)"
2731 );
2732 let mut req = req;
2733 req.model = self.resolve_external_hf_repo(
2734 req.model.as_deref(),
2735 ModelCapability::VideoGeneration,
2736 );
2737 return external_ltx::generate_video(&req);
2738 }
2739 tracing::info!(
2740 "using native Rust MLX video backend — Gemma left-pad bug was \
2741 just fixed; verify output quality before making this the default"
2742 );
2743 }
2744
2745 let candidates =
2746 self.media_generation_candidates(ModelCapability::VideoGeneration, req.model.as_deref())?;
2747 let mut last_error = None;
2748
2749 for schema in candidates {
2750 let result = match &schema.source {
2751 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2752 ModelSource::Mlx { .. } => self.generate_video_native_mlx(&schema, &req).await,
2753 _ => Err(InferenceError::InferenceFailed(format!(
2754 "video generation not implemented for model source: {}",
2755 schema.id
2756 ))),
2757 };
2758
2759 match result {
2760 Ok(result) => return Ok(result),
2761 Err(err) => last_error = Some(err),
2762 }
2763 }
2764
2765 Err(last_error.unwrap_or_else(|| {
2766 InferenceError::InferenceFailed("no video generation models available".into())
2767 }))
2768 }
2769
2770 pub fn list_models_unified(&self) -> Vec<ModelInfo> {
2772 self.unified_registry
2773 .list()
2774 .iter()
2775 .map(|m| ModelInfo::from(*m))
2776 .collect()
2777 }
2778
2779 pub fn list_schemas(&self) -> Vec<ModelSchema> {
2782 self.unified_registry.list().into_iter().cloned().collect()
2783 }
2784
2785 pub fn list_models(&self) -> Vec<models::ModelInfo> {
2786 self.registry.list_models()
2787 }
2788
2789 pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
2791 let schema = self
2792 .unified_registry
2793 .find_by_name(name)
2794 .or_else(|| self.unified_registry.get(name))
2795 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2796 self.unified_registry.ensure_local(&schema.id).await
2797 }
2798
2799 pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
2801 let schema = self
2802 .unified_registry
2803 .get(name)
2804 .or_else(|| {
2805 self.unified_registry
2806 .list()
2807 .into_iter()
2808 .find(|schema| schema.name.eq_ignore_ascii_case(name))
2809 })
2810 .or_else(|| self.unified_registry.find_by_name(name))
2811 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
2812 let model_dir = self.unified_registry.models_dir().join(&schema.name);
2813 if model_dir.exists() {
2814 std::fs::remove_dir_all(&model_dir)?;
2815 }
2816 match &schema.source {
2817 ModelSource::Mlx { hf_repo, .. } => {
2818 remove_huggingface_repo_cache(hf_repo)?;
2819 }
2820 ModelSource::Local {
2821 hf_repo,
2822 tokenizer_repo,
2823 ..
2824 } => {
2825 remove_huggingface_repo_cache(hf_repo)?;
2826 remove_huggingface_repo_cache(tokenizer_repo)?;
2827 }
2828 _ => {}
2829 }
2830 Ok(())
2831 }
2832
2833 pub fn register_model(&mut self, schema: ModelSchema) {
2835 self.unified_registry.register(schema);
2836 }
2837
2838 pub async fn discover_vllm_mlx_models(&mut self) -> usize {
2841 let config = vllm_mlx::VllmMlxConfig::default();
2842 if !config.auto_discover {
2843 return 0;
2844 }
2845 vllm_mlx::discover_and_register(&config, &mut self.unified_registry).await
2846 }
2847
2848 pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
2850 self.outcome_tracker.clone()
2851 }
2852
2853 async fn auto_save_outcomes(&self) {
2855 if let Err(e) = self.save_outcomes().await {
2856 tracing::debug!("auto-save outcomes failed: {}", e);
2857 }
2858 if let Err(e) = self.save_key_pool_stats().await {
2859 tracing::debug!("auto-save key pool stats failed: {}", e);
2860 }
2861 }
2862
2863 pub async fn save_outcomes(&self) -> Result<(), std::io::Error> {
2865 let tracker = self.outcome_tracker.read().await;
2866 let path = self.config.models_dir.join("outcome_profiles.json");
2867 tracker.save_to_file(&path)
2868 }
2869
2870 pub async fn save_key_pool_stats(&self) -> Result<(), std::io::Error> {
2872 let path = self.config.models_dir.join("key_pool_stats.json");
2873 self.remote_backend.key_pool.save_stats(&path).await
2874 }
2875
2876 pub async fn key_pool_stats(
2878 &self,
2879 ) -> std::collections::HashMap<String, Vec<key_pool::KeyStats>> {
2880 self.remote_backend.key_pool.all_stats().await
2881 }
2882
2883 pub async fn export_profiles(&self) -> Vec<ModelProfile> {
2885 let tracker = self.outcome_tracker.read().await;
2886 tracker.export_profiles()
2887 }
2888
2889 pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
2891 let mut tracker = self.outcome_tracker.write().await;
2892 tracker.import_profiles(profiles);
2893 }
2894
2895 pub async fn prepare_speech_runtime(&self) -> Result<PathBuf, InferenceError> {
2898 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
2899 {
2900 Ok(self.config.models_dir.clone())
2902 }
2903 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
2904 {
2905 Ok(self.ensure_speech_runtime().await?.root)
2906 }
2907 }
2908
2909 pub fn set_speech_policy(&mut self, policy: SpeechPolicy) {
2911 self.speech_policy = policy;
2912 }
2913
2914 pub fn set_routing_config(&mut self, config: RoutingConfig) {
2915 self.adaptive_router.set_config(config);
2916 }
2917
2918 pub async fn install_curated_speech(
2920 &mut self,
2921 ) -> Result<Vec<SpeechInstallReport>, InferenceError> {
2922 let _runtime_root = self.prepare_speech_runtime().await?;
2923 let schemas = self.list_schemas();
2924 let mut repos = Vec::new();
2925 for schema in &schemas {
2926 if !schema.is_mlx() || !schema.tags.iter().any(|tag| tag == "speech") {
2927 continue;
2928 }
2929 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
2930 if !repos.iter().any(|existing: &String| existing == hf_repo) {
2931 repos.push(hf_repo.clone());
2932 }
2933 }
2934 }
2935
2936 let mut installed = Vec::new();
2937 for repo in repos {
2938 let (snapshot_path, files_downloaded) = download_hf_repo_snapshot(&repo).await?;
2939 let name = schemas
2940 .iter()
2941 .find(|schema| {
2942 matches!(&schema.source, ModelSource::Mlx { hf_repo, .. } if hf_repo == &repo)
2943 })
2944 .map(|schema| schema.name.clone())
2945 .unwrap_or_else(|| repo.clone());
2946 installed.push(SpeechInstallReport {
2947 name,
2948 hf_repo: repo,
2949 snapshot_path,
2950 files_downloaded,
2951 });
2952 }
2953
2954 self.unified_registry.refresh_availability();
2955 Ok(installed)
2956 }
2957
2958
2959 pub fn speech_health(&self) -> SpeechHealthReport {
2961 let local_stt_default = self
2962 .speech_health_default_name(ModelCapability::SpeechToText, true, false);
2963 let local_tts_default = self
2964 .speech_health_default_name(ModelCapability::TextToSpeech, true, false);
2965 let remote_stt_default = self
2966 .speech_health_default_name(ModelCapability::SpeechToText, false, true);
2967 let remote_tts_default = self
2968 .speech_health_default_name(ModelCapability::TextToSpeech, false, true);
2969
2970 let mut local_models = Vec::new();
2971 let mut remote_models = Vec::new();
2972 for schema in self.list_schemas() {
2973 let capability = if schema.has_capability(ModelCapability::SpeechToText) {
2974 Some(ModelCapability::SpeechToText)
2975 } else if schema.has_capability(ModelCapability::TextToSpeech) {
2976 Some(ModelCapability::TextToSpeech)
2977 } else {
2978 None
2979 };
2980 let Some(capability) = capability else {
2981 continue;
2982 };
2983
2984 let selected_by_default = local_stt_default
2985 .as_ref()
2986 .is_some_and(|name| name == &schema.name)
2987 || local_tts_default
2988 .as_ref()
2989 .is_some_and(|name| name == &schema.name)
2990 || remote_stt_default
2991 .as_ref()
2992 .is_some_and(|name| name == &schema.name)
2993 || remote_tts_default
2994 .as_ref()
2995 .is_some_and(|name| name == &schema.name);
2996
2997 let health = SpeechModelHealth {
2998 id: schema.id.clone(),
2999 name: schema.name.clone(),
3000 provider: schema.provider.clone(),
3001 capability,
3002 is_local: schema.is_local(),
3003 available: schema.available,
3004 cached: speech_model_cached(&schema),
3005 selected_by_default,
3006 source: speech_model_source_label(&schema),
3007 };
3008 if schema.is_local() {
3009 local_models.push(health);
3010 } else {
3011 remote_models.push(health);
3012 }
3013 }
3014
3015 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
3017 let runtime = SpeechRuntimeHealth {
3018 root: self.config.models_dir.clone(),
3019 installed: true,
3020 python: PathBuf::new(),
3021 stt_command: PathBuf::new(),
3022 tts_command: PathBuf::new(),
3023 configured_python: None,
3024 detected_python: None,
3025 };
3026
3027 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3028 let runtime = {
3029 let rt = SpeechRuntime::new(speech_runtime_root_from_models_dir(&self.config.models_dir));
3030 SpeechRuntimeHealth {
3031 root: rt.root.clone(),
3032 installed: rt.is_ready(),
3033 python: rt.python.clone(),
3034 stt_command: rt.stt_program.clone(),
3035 tts_command: rt.tts_program.clone(),
3036 configured_python: std::env::var("CAR_SPEECH_PYTHON")
3037 .ok()
3038 .filter(|value| !value.trim().is_empty()),
3039 detected_python: detect_speech_python(),
3040 }
3041 };
3042
3043 SpeechHealthReport {
3044 runtime,
3045 local_models,
3046 remote_models,
3047 elevenlabs_configured: car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY").is_some(),
3048 prefer_local: self.speech_policy.prefer_local,
3049 allow_remote_fallback: self.speech_policy.allow_remote_fallback,
3050 preferred_local_stt: self.speech_policy.preferred_local_stt.clone(),
3051 preferred_local_tts: self.speech_policy.preferred_local_tts.clone(),
3052 preferred_remote_stt: self.speech_policy.preferred_remote_stt.clone(),
3053 preferred_remote_tts: self.speech_policy.preferred_remote_tts.clone(),
3054 local_stt_default,
3055 local_tts_default,
3056 remote_stt_default,
3057 remote_tts_default,
3058 }
3059 }
3060
3061 pub async fn model_health(&self) -> ModelHealthReport {
3064 let schemas = self.list_schemas();
3065 let total_models = schemas.len();
3066 let available_models = schemas.iter().filter(|schema| schema.available).count();
3067 let local_models = schemas.iter().filter(|schema| schema.is_local()).count();
3068 let remote_models = total_models.saturating_sub(local_models);
3069
3070 let defaults = vec![
3071 self.model_default_health(
3072 ModelCapability::Generate,
3073 self.preferred_model_for_capability(ModelCapability::Generate)
3074 .unwrap_or(&self.config.generation_model),
3075 ),
3076 self.model_default_health(
3077 ModelCapability::Embed,
3078 self.preferred_model_for_capability(ModelCapability::Embed)
3079 .unwrap_or(&self.config.embedding_model),
3080 ),
3081 self.model_default_health(
3082 ModelCapability::Classify,
3083 self.preferred_model_for_capability(ModelCapability::Classify)
3084 .unwrap_or(&self.config.classification_model),
3085 ),
3086 ];
3087
3088 let mut providers = std::collections::BTreeMap::new();
3089 for schema in &schemas {
3090 let entry = providers
3091 .entry(schema.provider.clone())
3092 .or_insert_with(|| ProviderAccumulator {
3093 configured: false,
3094 local_models: 0,
3095 remote_models: 0,
3096 available_models: 0,
3097 capabilities: std::collections::HashSet::new(),
3098 });
3099
3100 entry.configured |= model_source_configured(schema);
3101 if schema.is_local() {
3102 entry.local_models += 1;
3103 } else {
3104 entry.remote_models += 1;
3105 }
3106 if schema.available {
3107 entry.available_models += 1;
3108 }
3109 for capability in &schema.capabilities {
3110 entry.capabilities.insert(*capability);
3111 }
3112 }
3113
3114 let providers = providers
3115 .into_iter()
3116 .map(|(provider, acc)| ModelProviderHealth {
3117 provider,
3118 configured: acc.configured,
3119 local_models: acc.local_models,
3120 remote_models: acc.remote_models,
3121 available_models: acc.available_models,
3122 capabilities: sort_capabilities(acc.capabilities.into_iter().collect()),
3123 })
3124 .collect();
3125
3126 let capabilities = all_model_capabilities()
3127 .into_iter()
3128 .map(|capability| {
3129 let relevant: Vec<&ModelSchema> = schemas
3130 .iter()
3131 .filter(|schema| schema.has_capability(capability))
3132 .collect();
3133 let available: Vec<&ModelSchema> = relevant
3134 .iter()
3135 .copied()
3136 .filter(|schema| schema.available)
3137 .collect();
3138 ModelCapabilityHealth {
3139 capability,
3140 total_models: relevant.len(),
3141 available_models: available.len(),
3142 local_available_models: available
3143 .iter()
3144 .filter(|schema| schema.is_local())
3145 .count(),
3146 remote_available_models: available
3147 .iter()
3148 .filter(|schema| !schema.is_local())
3149 .count(),
3150 }
3151 })
3152 .collect();
3153
3154 let routing = self.routing_scenarios().await;
3155 let routing_config = self.adaptive_router.config().clone();
3156 let benchmark_priors = load_benchmark_prior_health(&self.config.models_dir, &schemas);
3157
3158 ModelHealthReport {
3159 total_models,
3160 available_models,
3161 local_models,
3162 remote_models,
3163 defaults,
3164 providers,
3165 capabilities,
3166 routing_prefer_local: routing_config.prefer_local,
3167 routing_quality_first_cold_start: routing_config.quality_first_cold_start,
3168 routing_min_observations: routing_config.min_observations,
3169 routing_bootstrap_min_task_observations: routing_config.bootstrap_min_task_observations,
3170 routing_bootstrap_quality_floor: routing_config.bootstrap_quality_floor,
3171 routing_quality_weight: routing_config.quality_weight,
3172 routing_latency_weight: routing_config.latency_weight,
3173 routing_cost_weight: routing_config.cost_weight,
3174 routing_scenarios: routing,
3175 benchmark_priors,
3176 speech: self.speech_health(),
3177 }
3178 }
3179
3180 async fn routing_scenarios(&self) -> Vec<RoutingScenarioHealth> {
3181 let tracker = self.outcome_tracker.read().await;
3182 let config = self.adaptive_router.config().clone();
3183 let scenarios = [
3184 (
3185 "interactive_text",
3186 "Summarize the benefits of local-first AI routing in two sentences.",
3187 "text",
3188 RoutingWorkload::Interactive,
3189 false,
3190 false,
3191 ),
3192 (
3193 "background_code",
3194 "Write a Python function named fibonacci(n) that returns the nth Fibonacci number.",
3195 "code",
3196 RoutingWorkload::Background,
3197 false,
3198 false,
3199 ),
3200 (
3201 "interactive_tool_use",
3202 "Use the provided weather tool to get the weather for Boston.",
3203 "tool_use",
3204 RoutingWorkload::Interactive,
3205 true,
3206 false,
3207 ),
3208 (
3209 "interactive_vision",
3210 "What is in this image? Answer in one word.",
3211 "vision",
3212 RoutingWorkload::Interactive,
3213 false,
3214 true,
3215 ),
3216 ];
3217
3218 scenarios
3219 .into_iter()
3220 .map(|(name, prompt, task_family, workload, has_tools, has_vision)| {
3221 let decision = self.adaptive_router.route_context_aware(
3222 prompt,
3223 0,
3224 &self.unified_registry,
3225 &tracker,
3226 has_tools,
3227 has_vision,
3228 workload,
3229 );
3230 let quality_first_cold_start = if has_tools || has_vision {
3231 config.quality_first_cold_start
3232 } else if task_family == "code" && matches!(workload, RoutingWorkload::Background) {
3233 false
3234 } else {
3235 config.quality_first_cold_start
3236 };
3237 RoutingScenarioHealth {
3238 name: name.to_string(),
3239 task_family: task_family.to_string(),
3240 workload,
3241 has_tools,
3242 has_vision,
3243 prefer_local: if task_family == "speech" {
3244 self.speech_policy.prefer_local
3245 } else {
3246 config.prefer_local
3247 },
3248 quality_first_cold_start,
3249 bootstrap_min_task_observations: config.bootstrap_min_task_observations,
3250 bootstrap_quality_floor: config.bootstrap_quality_floor,
3251 model_id: decision.model_id,
3252 model_name: decision.model_name,
3253 reason: decision.reason,
3254 strategy: decision.strategy,
3255 }
3256 })
3257 .collect()
3258 }
3259
3260 pub async fn smoke_test_speech(
3262 &self,
3263 local: bool,
3264 remote: bool,
3265 ) -> Result<SpeechSmokeReport, InferenceError> {
3266 let mut report = SpeechSmokeReport::default();
3267
3268 if local {
3269 let tts = self
3270 .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
3271 .ok_or_else(|| {
3272 InferenceError::InferenceFailed(
3273 "no local text-to-speech model available".into(),
3274 )
3275 })?;
3276 let stt = self
3277 .preferred_speech_schema(ModelCapability::SpeechToText, true, false)
3278 .ok_or_else(|| {
3279 InferenceError::InferenceFailed(
3280 "no local speech-to-text model available".into(),
3281 )
3282 })?;
3283 report.local = Some(
3284 self.run_speech_smoke_path("local", &tts, &stt, "Testing CAR local speech path.")
3285 .await?,
3286 );
3287 } else {
3288 report.skipped.push("local".to_string());
3289 }
3290
3291 if remote {
3292 let tts = self
3293 .preferred_speech_schema(ModelCapability::TextToSpeech, false, true)
3294 .ok_or_else(|| {
3295 InferenceError::InferenceFailed(
3296 "no remote text-to-speech model available".into(),
3297 )
3298 })?;
3299 let stt = self
3300 .preferred_speech_schema(ModelCapability::SpeechToText, false, true)
3301 .ok_or_else(|| {
3302 InferenceError::InferenceFailed(
3303 "no remote speech-to-text model available".into(),
3304 )
3305 })?;
3306 report.remote = Some(
3307 self.run_speech_smoke_path("remote", &tts, &stt, "Testing CAR remote speech path.")
3308 .await?,
3309 );
3310 } else {
3311 report.skipped.push("remote".to_string());
3312 }
3313
3314 Ok(report)
3315 }
3316
3317 fn speech_candidates(
3318 &self,
3319 capability: ModelCapability,
3320 explicit: Option<&str>,
3321 ) -> Result<Vec<ModelSchema>, InferenceError> {
3322 if let Some(model) = explicit {
3323 let schema = self
3324 .unified_registry
3325 .get(model)
3326 .or_else(|| self.unified_registry.find_by_name(model))
3327 .cloned()
3328 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3329 if !schema.has_capability(capability) {
3330 return Err(InferenceError::InferenceFailed(format!(
3331 "model {} does not support {:?}",
3332 schema.name, capability
3333 )));
3334 }
3335 return Ok(vec![schema]);
3336 }
3337
3338 let mut candidates: Vec<ModelSchema> = self
3339 .unified_registry
3340 .query(&ModelFilter {
3341 capabilities: vec![capability],
3342 ..Default::default()
3343 })
3344 .into_iter()
3345 .cloned()
3346 .collect();
3347
3348 if candidates.is_empty() {
3349 return Err(InferenceError::InferenceFailed(format!(
3350 "no models registered for capability {:?}",
3351 capability
3352 )));
3353 }
3354
3355 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3356 if !self.speech_policy.allow_remote_fallback && candidates.iter().any(|model| model.is_local()) {
3357 candidates.retain(|model| model.is_local());
3358 }
3359
3360 Ok(candidates)
3361 }
3362
3363 fn resolve_external_hf_repo(
3368 &self,
3369 explicit: Option<&str>,
3370 capability: ModelCapability,
3371 ) -> Option<String> {
3372 let id = explicit?;
3373 let schema = self
3374 .unified_registry
3375 .get(id)
3376 .or_else(|| self.unified_registry.find_by_name(id))?;
3377 if !schema.has_capability(capability) {
3378 return Some(id.to_string());
3379 }
3380 if let ModelSource::Mlx { hf_repo, .. } = &schema.source {
3381 return Some(hf_repo.clone());
3382 }
3383 Some(id.to_string())
3384 }
3385
3386 fn media_generation_candidates(
3387 &self,
3388 capability: ModelCapability,
3389 explicit: Option<&str>,
3390 ) -> Result<Vec<ModelSchema>, InferenceError> {
3391 if let Some(model) = explicit {
3392 let schema = self
3393 .unified_registry
3394 .get(model)
3395 .or_else(|| self.unified_registry.find_by_name(model))
3396 .cloned()
3397 .ok_or_else(|| InferenceError::ModelNotFound(model.to_string()))?;
3398 if !schema.has_capability(capability) {
3399 return Err(InferenceError::InferenceFailed(format!(
3400 "model {} does not support {:?}",
3401 schema.name, capability
3402 )));
3403 }
3404 return Ok(vec![schema]);
3405 }
3406
3407 let mut candidates: Vec<ModelSchema> = self
3408 .unified_registry
3409 .query(&ModelFilter {
3410 capabilities: vec![capability],
3411 local_only: true,
3412 ..Default::default()
3413 })
3414 .into_iter()
3415 .cloned()
3416 .collect();
3417 candidates.sort_by_key(|schema| (!schema.available, schema.size_mb()));
3418 if candidates.is_empty() {
3419 return Err(InferenceError::InferenceFailed(format!(
3420 "no models registered for capability {:?}",
3421 capability
3422 )));
3423 }
3424 Ok(candidates)
3425 }
3426
3427 fn preferred_speech_schema(
3428 &self,
3429 capability: ModelCapability,
3430 local_only: bool,
3431 remote_only: bool,
3432 ) -> Option<ModelSchema> {
3433 let available_only = remote_only;
3434 let mut candidates: Vec<ModelSchema> = self
3435 .unified_registry
3436 .query(&ModelFilter {
3437 capabilities: vec![capability],
3438 available_only,
3439 ..Default::default()
3440 })
3441 .into_iter()
3442 .filter(|schema| (!local_only || schema.is_local()) && (!remote_only || schema.is_remote()))
3443 .cloned()
3444 .collect();
3445 candidates.sort_by_key(|model| self.speech_sort_key(capability, model));
3446 candidates.into_iter().next()
3447 }
3448
3449 fn speech_health_default_name(
3450 &self,
3451 capability: ModelCapability,
3452 local_only: bool,
3453 remote_only: bool,
3454 ) -> Option<String> {
3455 let preferred = match capability {
3456 ModelCapability::SpeechToText if local_only => {
3457 self.speech_policy.preferred_local_stt.as_ref()
3458 }
3459 ModelCapability::SpeechToText if remote_only => {
3460 self.speech_policy.preferred_remote_stt.as_ref()
3461 }
3462 ModelCapability::TextToSpeech if local_only => {
3463 self.speech_policy.preferred_local_tts.as_ref()
3464 }
3465 ModelCapability::TextToSpeech if remote_only => {
3466 self.speech_policy.preferred_remote_tts.as_ref()
3467 }
3468 _ => None,
3469 };
3470
3471 preferred
3472 .filter(|name| {
3473 self.unified_registry.list().iter().any(|schema| {
3474 schema.name == **name
3475 && schema.has_capability(capability)
3476 && (!local_only || schema.is_local())
3477 && (!remote_only || schema.is_remote())
3478 })
3479 })
3480 .cloned()
3481 .or_else(|| {
3482 self.preferred_speech_schema(capability, local_only, remote_only)
3483 .map(|schema| schema.name)
3484 })
3485 }
3486
3487 fn model_default_health(
3488 &self,
3489 capability: ModelCapability,
3490 configured_model: &str,
3491 ) -> ModelDefaultHealth {
3492 let schema = self
3493 .unified_registry
3494 .find_by_name(configured_model)
3495 .or_else(|| self.unified_registry.get(configured_model));
3496
3497 ModelDefaultHealth {
3498 capability,
3499 configured_model: configured_model.to_string(),
3500 available: schema.is_some_and(|model| model.available),
3501 is_local: schema.is_some_and(ModelSchema::is_local),
3502 provider: schema.map(|model| model.provider.clone()),
3503 }
3504 }
3505
3506 fn speech_sort_key(
3507 &self,
3508 capability: ModelCapability,
3509 model: &ModelSchema,
3510 ) -> (u8, u8, u8, u8, u64, u64) {
3511 let policy_preference = match capability {
3512 ModelCapability::SpeechToText if model.is_local() => self
3513 .speech_policy
3514 .preferred_local_stt
3515 .as_ref(),
3516 ModelCapability::SpeechToText => self.speech_policy.preferred_remote_stt.as_ref(),
3517 ModelCapability::TextToSpeech if model.is_local() => self
3518 .speech_policy
3519 .preferred_local_tts
3520 .as_ref(),
3521 ModelCapability::TextToSpeech => self.speech_policy.preferred_remote_tts.as_ref(),
3522 _ => None,
3523 };
3524 let local_rank = if self.speech_policy.prefer_local {
3525 if model.is_local() {
3526 0
3527 } else {
3528 1
3529 }
3530 } else if model.is_remote() {
3531 0
3532 } else {
3533 1
3534 };
3535 let availability_rank = if model.available {
3536 0
3537 } else if model.is_local() {
3538 1
3539 } else {
3540 2
3541 };
3542 let policy_rank: u8 = if policy_preference.is_some_and(|preferred| preferred == &model.name) {
3543 0
3544 } else {
3545 1
3546 };
3547 let speech_rank = match capability {
3548 ModelCapability::TextToSpeech => {
3549 if model.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
3550 0
3551 } else if model.name == "Kokoro-82M-bf16" {
3552 1
3553 } else if model.name == "Kokoro-82M-6bit" {
3554 2
3555 } else {
3556 3
3557 }
3558 }
3559 ModelCapability::SpeechToText => {
3560 if model.name == "Parakeet-TDT-0.6B-v3-MLX" {
3561 0
3562 } else {
3563 1
3564 }
3565 }
3566 _ => 0,
3567 };
3568 let latency_rank = model.performance.latency_p50_ms.unwrap_or(u64::MAX);
3569 let size_rank = model.cost.size_mb.unwrap_or(u64::MAX);
3570 (
3571 local_rank,
3572 availability_rank,
3573 policy_rank,
3574 speech_rank,
3575 latency_rank,
3576 size_rank,
3577 )
3578 }
3579
3580 async fn run_speech_smoke_path(
3581 &self,
3582 path: &str,
3583 tts: &ModelSchema,
3584 stt: &ModelSchema,
3585 text: &str,
3586 ) -> Result<SpeechSmokePathReport, InferenceError> {
3587 let work_dir = temp_work_dir(&format!("speech-smoke-{path}"))?;
3588 let audio_path = work_dir.join(format!("{path}.wav"));
3589 let synth = self
3590 .synthesize(SynthesizeRequest {
3591 text: text.to_string(),
3592 model: Some(tts.name.clone()),
3593 voice: default_speech_voice(tts),
3594 language: Some("en".to_string()),
3595 output_path: Some(audio_path.display().to_string()),
3596 ..SynthesizeRequest::default()
3597 })
3598 .await?;
3599 let transcript = self
3600 .transcribe(TranscribeRequest {
3601 audio_path: synth.audio_path.clone(),
3602 model: Some(stt.name.clone()),
3603 language: Some("en".to_string()),
3604 prompt: None,
3605 timestamps: false,
3606 })
3607 .await?;
3608
3609 Ok(SpeechSmokePathReport {
3610 path: path.to_string(),
3611 tts_model: synth.model_used.unwrap_or_else(|| tts.name.clone()),
3612 stt_model: transcript.model_used.unwrap_or_else(|| stt.name.clone()),
3613 audio_path: PathBuf::from(synth.audio_path),
3614 transcript: transcript.text,
3615 })
3616 }
3617
3618 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3619 async fn ensure_speech_runtime(&self) -> Result<SpeechRuntime, InferenceError> {
3620 let mut guard = self.speech_runtime.lock().await;
3621 if let Some(runtime) = guard.as_ref() {
3622 if runtime.is_ready() {
3623 return Ok(runtime.clone());
3624 }
3625 }
3626
3627 let runtime = SpeechRuntime::new(speech_runtime_root_from_models_dir(
3628 &self.config.models_dir,
3629 ));
3630 if !runtime.is_ready() {
3631 bootstrap_speech_runtime(&runtime).await?;
3632 }
3633 if !runtime.is_ready() {
3634 return Err(InferenceError::InferenceFailed(format!(
3635 "managed speech runtime is not ready at {}",
3636 runtime.root.display()
3637 )));
3638 }
3639
3640 *guard = Some(runtime.clone());
3641 Ok(runtime)
3642 }
3643
3644 async fn transcribe_local_mlx(
3645 &self,
3646 schema: &ModelSchema,
3647 req: &TranscribeRequest,
3648 ) -> Result<TranscribeResult, InferenceError> {
3649 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
3651 {
3652 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3653 let parakeet = backend::mlx_parakeet::ParakeetBackend::load(&model_dir)?;
3654 let (text, words) = if req.timestamps {
3656 parakeet
3657 .transcribe_detailed(Path::new(&req.audio_path))
3658 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?
3659 } else {
3660 let t = parakeet
3661 .transcribe(Path::new(&req.audio_path))
3662 .map_err(|e| InferenceError::InferenceFailed(format!("native STT: {e}")))?;
3663 (t, Vec::new())
3664 };
3665 return Ok(TranscribeResult {
3666 text,
3667 model_used: Some(schema.name.clone()),
3668 language: req.language.clone(),
3669 words,
3670 });
3671 }
3672
3673 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3675 {
3676 let runtime = self.ensure_speech_runtime().await?;
3677 let hf_repo = match &schema.source {
3678 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3679 _ => unreachable!(),
3680 };
3681 let output_dir = temp_work_dir("stt")?;
3682 let output_prefix = output_dir.join("transcript");
3683 let mut args = vec![
3684 "--model".to_string(),
3685 hf_repo,
3686 "--audio".to_string(),
3687 req.audio_path.clone(),
3688 "--output-path".to_string(),
3689 output_prefix.display().to_string(),
3690 "--format".to_string(),
3691 "json".to_string(),
3692 ];
3693 if let Some(language) = &req.language {
3694 args.push("--language".to_string());
3695 args.push(normalize_lang_code(language));
3696 }
3697 if let Some(prompt) = &req.prompt {
3698 args.push("--context".to_string());
3699 args.push(prompt.clone());
3700 }
3701 if req.timestamps {
3702 args.push("--verbose".to_string());
3703 }
3704
3705 let output = run_mlx_audio_command(&runtime, "stt.generate", &args).await?;
3706 let text = read_transcription_result(&output_prefix)?
3707 .or_else(|| extract_text_from_payload(&output.stdout))
3708 .ok_or_else(|| {
3709 InferenceError::InferenceFailed(format!(
3710 "mlx-audio transcription returned no text: {}",
3711 output.stderr
3712 ))
3713 })?;
3714
3715 Ok(TranscribeResult {
3716 text,
3717 model_used: Some(schema.name.clone()),
3718 language: req.language.clone(),
3719 words: Vec::new(),
3720 })
3721 }
3722 }
3723
3724 async fn synthesize_local_mlx(
3725 &self,
3726 schema: &ModelSchema,
3727 req: &SynthesizeRequest,
3728 ) -> Result<SynthesizeResult, InferenceError> {
3729 let requested = req.requested_advanced_controls();
3734 let repo_supports_advanced = match &schema.source {
3735 ModelSource::Mlx { hf_repo, .. } => {
3736 hf_repo.to_ascii_lowercase().contains("qwen3-tts")
3737 }
3738 _ => false,
3739 };
3740 if !requested.is_empty() && !repo_supports_advanced {
3741 if req.strict_capabilities {
3742 return Err(InferenceError::InferenceFailed(format!(
3743 "model {name} does not support Qwen3-TTS advanced controls {requested:?}; \
3744 route to a Qwen3-TTS model or set strict_capabilities = false to degrade",
3745 name = schema.name,
3746 )));
3747 }
3748 tracing::warn!(
3749 model = %schema.name,
3750 fields = ?requested,
3751 "Qwen3-TTS advanced controls set on non-Qwen3-TTS backend — ignored \
3752 (set strict_capabilities=true to error instead)"
3753 );
3754 }
3755
3756 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
3758 {
3759 if repo_supports_advanced && !requested.is_empty() {
3765 if req.strict_capabilities {
3766 return Err(InferenceError::InferenceFailed(format!(
3767 "native MLX TTS backend does not yet implement Qwen3-TTS advanced \
3768 controls {requested:?}; run on non-Apple-Silicon to use the Python \
3769 mlx-audio fallback, or set strict_capabilities = false"
3770 )));
3771 }
3772 tracing::warn!(
3773 model = %schema.name,
3774 fields = ?requested,
3775 "Qwen3-TTS advanced controls are not yet implemented in the native MLX TTS \
3776 backend; synthesizing without cloning/voice-design"
3777 );
3778 }
3779 let model_dir = self.unified_registry.ensure_local(&schema.id).await?;
3780 let size = backend_cache::estimate_model_size(&model_dir);
3781 let cache = Arc::clone(&self.kokoro_cache);
3782 let key = schema.id.clone();
3783 let handle = cache.get_or_load(&key, size, || {
3784 backend::mlx_kokoro::KokoroBackend::load(&model_dir)
3785 })?;
3786
3787 let output_path = req.output_path.clone().unwrap_or_else(|| {
3788 let dir = std::env::temp_dir().join("car_tts");
3789 let _ = std::fs::create_dir_all(&dir);
3790 dir.join("output.wav").display().to_string()
3791 });
3792 let voice = req.voice.as_deref().unwrap_or("af_heart").to_string();
3793 let text = req.text.clone();
3794 let op = tokio::task::spawn_blocking(move || -> Result<PathBuf, InferenceError> {
3795 let mut guard = handle.lock().map_err(|_| {
3796 InferenceError::InferenceFailed("kokoro backend mutex poisoned".into())
3797 })?;
3798 guard
3799 .synthesize(&text, Some(&voice), Path::new(&output_path))
3800 .map_err(|e| InferenceError::InferenceFailed(format!("native TTS: {e}")))
3801 })
3802 .await
3803 .map_err(|e| InferenceError::InferenceFailed(format!("kokoro task join: {e}")))??;
3804
3805 let final_path =
3806 materialize_audio_output(&op, req.output_path.as_deref(), &req.format)?;
3807 return Ok(SynthesizeResult {
3808 audio_path: final_path.display().to_string(),
3809 media_type: media_type_for_format(&req.format),
3810 model_used: Some(schema.name.clone()),
3811 voice_used: req.voice.clone(),
3812 });
3813 }
3814
3815 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3817 {
3818 let runtime = self.ensure_speech_runtime().await?;
3819 let primary_hf_repo = match &schema.source {
3820 ModelSource::Mlx { hf_repo, .. } => hf_repo.clone(),
3821 _ => unreachable!(),
3822 };
3823 let (produced, model_used) = match self
3824 .synthesize_local_mlx_repo(&runtime, &primary_hf_repo, schema.name.as_str(), req)
3825 .await
3826 {
3827 Ok(result) => result,
3828 Err(primary_err)
3829 if primary_hf_repo == "mlx-community/Kokoro-82M-6bit"
3830 && kokoro_runtime_fallback_enabled() =>
3831 {
3832 let fallback_repo = "mlx-community/Kokoro-82M-bf16";
3833 let fallback_name = "Kokoro-82M-bf16";
3834 match self
3835 .synthesize_local_mlx_repo(&runtime, fallback_repo, fallback_name, req)
3836 .await
3837 {
3838 Ok(result) => result,
3839 Err(fallback_err) => {
3840 return Err(InferenceError::InferenceFailed(format!(
3841 "{primary_err}; fallback {fallback_name} also failed: {fallback_err}"
3842 )));
3843 }
3844 }
3845 }
3846 Err(err) => return Err(err),
3847 };
3848 let final_path =
3849 materialize_audio_output(&produced, req.output_path.as_deref(), &req.format)?;
3850
3851 Ok(SynthesizeResult {
3852 audio_path: final_path.display().to_string(),
3853 media_type: media_type_for_format(&req.format),
3854 model_used: Some(model_used),
3855 voice_used: req.voice.clone(),
3856 })
3857 }
3858 }
3859
3860 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
3861 async fn synthesize_local_mlx_repo(
3862 &self,
3863 runtime: &SpeechRuntime,
3864 hf_repo: &str,
3865 model_name: &str,
3866 req: &SynthesizeRequest,
3867 ) -> Result<(PathBuf, String), InferenceError> {
3868 let output_dir = temp_work_dir("tts")?;
3869 let mut args = vec![
3870 "--model".to_string(),
3871 hf_repo.to_string(),
3872 "--text".to_string(),
3873 req.text.clone(),
3874 "--output_path".to_string(),
3875 output_dir.display().to_string(),
3876 ];
3877 if let Some(voice) = &req.voice {
3878 args.push("--voice".to_string());
3879 args.push(voice.clone());
3880 }
3881 if let Some(speed) = req.speed {
3882 args.push("--speed".to_string());
3883 args.push(speed.to_string());
3884 }
3885 let repo_lower = hf_repo.to_ascii_lowercase();
3886 if repo_lower.contains("kokoro") {
3887 args.push("--lang_code".to_string());
3888 args.push(kokoro_lang_code(req.language.as_deref()).to_string());
3889 } else if let Some(language) = &req.language {
3890 args.push("--lang_code".to_string());
3891 args.push(normalize_lang_code(language));
3892 }
3893
3894 if repo_lower.contains("qwen3-tts") {
3900 if let Some(ref_audio) = &req.reference_audio_path {
3901 args.push("--ref_audio".to_string());
3902 args.push(ref_audio.clone());
3903 }
3904 if let Some(ref_text) = &req.reference_text {
3905 args.push("--ref_text".to_string());
3906 args.push(ref_text.clone());
3907 }
3908 if let Some(instruct) = &req.voice_instruction {
3909 args.push("--instruct".to_string());
3910 args.push(instruct.clone());
3911 }
3912 }
3913
3914 let output = if repo_lower.contains("kokoro") {
3915 let device = std::env::var("CAR_SPEECH_KOKORO_DEVICE")
3916 .or_else(|_| std::env::var("CAR_SPEECH_MLX_DEVICE"))
3917 .unwrap_or_else(|_| "cpu".to_string());
3918 let extra_env = vec![
3919 ("MLX_DEVICE".to_string(), device),
3921 ("PYTORCH_ENABLE_MPS_FALLBACK".to_string(), "1".to_string()),
3923 ];
3924 run_mlx_audio_command_with_env(runtime, "tts.generate", &args, &extra_env).await?
3925 } else {
3926 run_mlx_audio_command(runtime, "tts.generate", &args).await?
3927 };
3928 let produced = find_audio_file(&output_dir)?.ok_or_else(|| {
3929 let hint = if repo_lower.contains("kokoro") {
3930 ". Kokoro models may crash on GPU — try CAR_SPEECH_KOKORO_DEVICE=cpu or use the default Qwen3-TTS model"
3931 } else {
3932 ""
3933 };
3934 InferenceError::InferenceFailed(format!(
3935 "mlx-audio synthesis produced no audio file: {}{}",
3936 output.stderr, hint
3937 ))
3938 })?;
3939 Ok((produced, model_name.to_string()))
3940 }
3941
3942
3943 async fn transcribe_elevenlabs(
3944 &self,
3945 schema: &ModelSchema,
3946 req: &TranscribeRequest,
3947 ) -> Result<TranscribeResult, InferenceError> {
3948 let (endpoint, api_key) = elevenlabs_auth(schema)?;
3949 let file_name = Path::new(&req.audio_path)
3950 .file_name()
3951 .and_then(|f| f.to_str())
3952 .unwrap_or("audio.wav")
3953 .to_string();
3954 let audio_bytes = tokio::fs::read(&req.audio_path).await?;
3955 let file_part = Part::bytes(audio_bytes).file_name(file_name);
3956 let mut form = Form::new()
3957 .text("model_id", schema.name.clone())
3958 .part("file", file_part);
3959 if let Some(language) = &req.language {
3960 form = form.text("language_code", language.clone());
3961 }
3962
3963 let resp = self
3964 .remote_backend
3965 .client
3966 .post(format!(
3967 "{}/v1/speech-to-text",
3968 endpoint.trim_end_matches('/')
3969 ))
3970 .header("xi-api-key", api_key)
3971 .multipart(form)
3972 .send()
3973 .await
3974 .map_err(|e| {
3975 InferenceError::InferenceFailed(format!("ElevenLabs STT request failed: {e}"))
3976 })?;
3977 let status = resp.status();
3978 let body = resp.text().await.map_err(|e| {
3979 InferenceError::InferenceFailed(format!("read ElevenLabs STT body: {e}"))
3980 })?;
3981 if !status.is_success() {
3982 return Err(InferenceError::InferenceFailed(format!(
3983 "ElevenLabs STT returned {status}: {body}"
3984 )));
3985 }
3986 let payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
3987 InferenceError::InferenceFailed(format!("parse ElevenLabs STT response: {e}"))
3988 })?;
3989 let text = payload
3990 .get("text")
3991 .and_then(|v| v.as_str())
3992 .map(str::to_string)
3993 .ok_or_else(|| {
3994 InferenceError::InferenceFailed("ElevenLabs STT response missing text".into())
3995 })?;
3996
3997 Ok(TranscribeResult {
3998 text,
3999 model_used: Some(schema.name.clone()),
4000 language: payload
4001 .get("language_code")
4002 .and_then(|v| v.as_str())
4003 .map(str::to_string),
4004 words: Vec::new(),
4005 })
4006 }
4007
4008 async fn synthesize_elevenlabs(
4009 &self,
4010 schema: &ModelSchema,
4011 req: &SynthesizeRequest,
4012 ) -> Result<SynthesizeResult, InferenceError> {
4013 let requested = req.requested_advanced_controls();
4017 if !requested.is_empty() {
4018 if req.strict_capabilities {
4019 return Err(InferenceError::InferenceFailed(format!(
4020 "ElevenLabs backend does not support Qwen3-TTS advanced controls \
4021 {requested:?}; route to a Qwen3-TTS model or set strict_capabilities = false"
4022 )));
4023 }
4024 tracing::warn!(
4025 model = %schema.name,
4026 fields = ?requested,
4027 "Qwen3-TTS advanced controls ignored by ElevenLabs backend"
4028 );
4029 }
4030 let (endpoint, api_key) = elevenlabs_auth(schema)?;
4031 let voice_id = req
4032 .voice
4033 .clone()
4034 .unwrap_or_else(|| "JBFqnCBsd6RMkjVDRZzb".to_string());
4035 let output_format = elevenlabs_output_format(&req.format);
4036 let url = format!(
4037 "{}/v1/text-to-speech/{}?output_format={}",
4038 endpoint.trim_end_matches('/'),
4039 voice_id,
4040 output_format
4041 );
4042
4043 let mut body = serde_json::json!({
4044 "text": req.text,
4045 "model_id": schema.name,
4046 });
4047 if let Some(language) = &req.language {
4048 body["language_code"] = serde_json::Value::String(language.clone());
4049 }
4050
4051 let resp = self
4052 .remote_backend
4053 .client
4054 .post(url)
4055 .header("xi-api-key", api_key)
4056 .header("Content-Type", "application/json")
4057 .json(&body)
4058 .send()
4059 .await
4060 .map_err(|e| {
4061 InferenceError::InferenceFailed(format!("ElevenLabs TTS request failed: {e}"))
4062 })?;
4063 let status = resp.status();
4064 let audio = resp.bytes().await.map_err(|e| {
4065 InferenceError::InferenceFailed(format!("read ElevenLabs TTS body: {e}"))
4066 })?;
4067 if !status.is_success() {
4068 let err_body = String::from_utf8_lossy(&audio);
4069 return Err(InferenceError::InferenceFailed(format!(
4070 "ElevenLabs TTS returned {status}: {err_body}"
4071 )));
4072 }
4073
4074 let final_path = requested_or_temp_output(req.output_path.as_deref(), &req.format)?;
4075 ensure_parent_dir(&final_path)?;
4076 tokio::fs::write(&final_path, &audio).await?;
4077
4078 Ok(SynthesizeResult {
4079 audio_path: final_path.display().to_string(),
4080 media_type: media_type_for_format(&req.format),
4081 model_used: Some(schema.name.clone()),
4082 voice_used: Some(voice_id),
4083 })
4084 }
4085}
4086
/// Per-provider tallies; `Default` starts every counter at zero, the flag at
/// `false`, and the capability set empty.
///
/// NOTE(review): the code that fills this struct is outside this chunk, so the
/// field notes below are inferred from the field names — confirm against it.
#[derive(Default)]
struct ProviderAccumulator {
    // Presumably whether the provider has usable configuration.
    configured: bool,
    // Presumably the number of local models seen for the provider.
    local_models: usize,
    // Presumably the number of remote models seen for the provider.
    remote_models: usize,
    // Presumably the number of models currently marked available.
    available_models: usize,
    // Distinct capabilities collected for the provider.
    capabilities: std::collections::HashSet<ModelCapability>,
}
4095
/// Captured output of a speech-runtime command that exited successfully.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
struct CommandOutput {
    // Standard output, decoded as UTF-8 with invalid bytes replaced.
    stdout: String,
    // Standard error, decoded as UTF-8 with invalid bytes replaced.
    stderr: String,
}
4104
/// Filesystem layout of the managed speech runtime used on targets other than
/// Apple Silicon macOS.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
#[derive(Debug, Clone)]
struct SpeechRuntime {
    // Root directory of the runtime.
    root: PathBuf,
    // Python interpreter at `<root>/bin/python`.
    python: PathBuf,
    // Speech-to-text entry point at `<root>/bin/mlx_audio.stt.generate`.
    stt_program: PathBuf,
    // Text-to-speech entry point at `<root>/bin/mlx_audio.tts.generate`.
    tts_program: PathBuf,
}
4113
4114#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4115impl SpeechRuntime {
4116 fn new(root: PathBuf) -> Self {
4117 let bin_dir = root.join("bin");
4118 Self {
4119 root,
4120 python: bin_dir.join("python"),
4121 stt_program: bin_dir.join("mlx_audio.stt.generate"),
4122 tts_program: bin_dir.join("mlx_audio.tts.generate"),
4123 }
4124 }
4125
4126 fn is_ready(&self) -> bool {
4127 self.python.exists() && self.stt_program.exists() && self.tts_program.exists()
4128 }
4129
4130 fn command_for(&self, subcommand: &str) -> Result<&Path, InferenceError> {
4131 match subcommand {
4132 "stt.generate" => Ok(&self.stt_program),
4133 "tts.generate" => Ok(&self.tts_program),
4134 _ => Err(InferenceError::InferenceFailed(format!(
4135 "unknown speech subcommand: {subcommand}"
4136 ))),
4137 }
4138 }
4139}
4140
4141
4142#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4143async fn run_mlx_audio_command(
4144 runtime: &SpeechRuntime,
4145 subcommand: &str,
4146 args: &[String],
4147) -> Result<CommandOutput, InferenceError> {
4148 run_mlx_audio_command_with_env(runtime, subcommand, args, &[]).await
4149}
4150
4151#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4152async fn run_mlx_audio_command_with_env(
4153 runtime: &SpeechRuntime,
4154 subcommand: &str,
4155 args: &[String],
4156 envs: &[(String, String)],
4157) -> Result<CommandOutput, InferenceError> {
4158 let program = runtime.command_for(subcommand)?;
4159 let mut command = Command::new(program);
4160 command.args(args);
4161 for (key, value) in envs {
4162 command.env(key, value);
4163 }
4164 let output = command.output().await.map_err(|err| {
4165 InferenceError::InferenceFailed(format!("{}: {err}", program.display()))
4166 })?;
4167
4168 if output.status.success() {
4169 Ok(CommandOutput {
4170 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
4171 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
4172 })
4173 } else {
4174 Err(InferenceError::InferenceFailed(format!(
4175 "{} exited with {}: {}",
4176 program.display(),
4177 output.status,
4178 String::from_utf8_lossy(&output.stderr)
4179 )))
4180 }
4181}
4182
4183
4184#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4185async fn bootstrap_speech_runtime(runtime: &SpeechRuntime) -> Result<(), InferenceError> {
4186 std::fs::create_dir_all(&runtime.root)?;
4187 let python = select_speech_python()?;
4188
4189 run_command(
4190 "uv",
4191 &[
4192 "venv".to_string(),
4193 "--python".to_string(),
4194 python,
4195 runtime.root.display().to_string(),
4196 ],
4197 )
4198 .await?;
4199
4200 run_command(
4201 "uv",
4202 &[
4203 "pip".to_string(),
4204 "install".to_string(),
4205 "--python".to_string(),
4206 runtime.python.display().to_string(),
4207 speech_runtime_mlx_audio_spec(),
4208 "misaki[en]".to_string(),
4209 speech_runtime_spacy_model_spec(),
4210 ],
4211 )
4212 .await?;
4213
4214 Ok(())
4215}
4216
4217
4218#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4219async fn run_command(program: &str, args: &[String]) -> Result<(), InferenceError> {
4220 let output = Command::new(program)
4221 .args(args)
4222 .output()
4223 .await
4224 .map_err(|err| InferenceError::InferenceFailed(format!("{program}: {err}")))?;
4225
4226 if output.status.success() {
4227 Ok(())
4228 } else {
4229 Err(InferenceError::InferenceFailed(format!(
4230 "{} exited with {}: {}",
4231 program,
4232 output.status,
4233 String::from_utf8_lossy(&output.stderr)
4234 )))
4235 }
4236}
4237
4238#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4239fn select_speech_python() -> Result<String, InferenceError> {
4240 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4241 if !path.trim().is_empty() {
4242 return Ok(path);
4243 }
4244 }
4245
4246 for candidate in ["python3.13", "python3.12", "python3.11"] {
4247 if command_in_path(candidate) {
4248 return Ok(candidate.to_string());
4249 }
4250 }
4251
4252 Err(InferenceError::InferenceFailed(
4253 "no supported Python found for managed speech runtime (tried python3.13, python3.12, python3.11)".into(),
4254 ))
4255}
4256
4257#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4258fn detect_speech_python() -> Option<String> {
4259 if let Ok(path) = std::env::var("CAR_SPEECH_PYTHON") {
4260 if !path.trim().is_empty() {
4261 return Some(path);
4262 }
4263 }
4264
4265 ["python3.13", "python3.12", "python3.11"]
4266 .into_iter()
4267 .find(|candidate| command_in_path(candidate))
4268 .map(str::to_string)
4269}
4270
/// Resolves the root directory of the managed speech runtime.
///
/// A non-blank `CAR_SPEECH_RUNTIME_DIR` is used verbatim; otherwise the root
/// is `$HOME/.car/speech-runtime`, with `.` standing in for an unset `HOME`.
/// The `_models_dir` argument is not consulted.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn speech_runtime_root_from_models_dir(_models_dir: &Path) -> PathBuf {
    match std::env::var("CAR_SPEECH_RUNTIME_DIR") {
        Ok(explicit) if !explicit.trim().is_empty() => PathBuf::from(explicit),
        _ => {
            let home = match std::env::var("HOME") {
                Ok(dir) => PathBuf::from(dir),
                Err(_) => PathBuf::from("."),
            };
            home.join(".car").join("speech-runtime")
        }
    }
}
4285
/// Reports whether some directory listed in `PATH` contains a regular file
/// called `name`. Returns `false` when `PATH` is unset.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn command_in_path(name: &str) -> bool {
    let Some(search_path) = std::env::var_os("PATH") else {
        return false;
    };
    for dir in std::env::split_paths(&search_path) {
        let candidate = dir.join(name);
        if candidate.exists() && candidate.is_file() {
            return true;
        }
    }
    false
}
4297
4298fn speech_model_cached(schema: &ModelSchema) -> bool {
4299 match &schema.source {
4300 ModelSource::Mlx { hf_repo, .. } => {
4301 huggingface_repo_has_snapshot(hf_repo)
4302 }
4303 ModelSource::Proprietary { auth, .. } => match auth {
4304 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4305 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4306 ProprietaryAuth::OAuth2Pkce { .. } => false,
4307 },
4308 _ => false,
4309 }
4310}
4311
4312fn remove_huggingface_repo_cache(repo_id: &str) -> Result<(), InferenceError> {
4313 let repo_dir = std::env::var("HF_HOME")
4314 .map(PathBuf::from)
4315 .unwrap_or_else(|_| {
4316 std::env::var("HOME")
4317 .map(PathBuf::from)
4318 .unwrap_or_else(|_| PathBuf::from("."))
4319 .join(".cache")
4320 .join("huggingface")
4321 })
4322 .join("hub")
4323 .join(format!("models--{}", repo_id.replace('/', "--")));
4324
4325 if repo_dir.exists() {
4326 std::fs::remove_dir_all(repo_dir)?;
4327 }
4328 Ok(())
4329}
4330
4331fn model_source_configured(schema: &ModelSchema) -> bool {
4332 match &schema.source {
4333 ModelSource::RemoteApi {
4334 api_key_env,
4335 api_key_envs,
4336 ..
4337 } => {
4338 std::env::var(api_key_env).is_ok()
4339 || api_key_envs
4340 .iter()
4341 .any(|env_var| std::env::var(env_var).is_ok())
4342 }
4343 ModelSource::Proprietary { auth, .. } => match auth {
4344 ProprietaryAuth::ApiKeyEnv { env_var } => std::env::var(env_var).is_ok(),
4345 ProprietaryAuth::BearerTokenEnv { env_var } => std::env::var(env_var).is_ok(),
4346 ProprietaryAuth::OAuth2Pkce { .. } => false,
4347 },
4348 ModelSource::VllmMlx { .. } => std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available,
4349 ModelSource::Ollama { .. } => schema.available,
4350 ModelSource::Mlx { .. } | ModelSource::Local { .. } => {
4351 true
4352 }
4353 ModelSource::AppleFoundationModels { .. } => schema.available,
4354 }
4355}
4356
4357fn all_model_capabilities() -> [ModelCapability; 13] {
4358 [
4359 ModelCapability::Generate,
4360 ModelCapability::Embed,
4361 ModelCapability::Classify,
4362 ModelCapability::Code,
4363 ModelCapability::Reasoning,
4364 ModelCapability::Summarize,
4365 ModelCapability::ToolUse,
4366 ModelCapability::MultiToolCall,
4367 ModelCapability::Vision,
4368 ModelCapability::SpeechToText,
4369 ModelCapability::TextToSpeech,
4370 ModelCapability::ImageGeneration,
4371 ModelCapability::VideoGeneration,
4372 ]
4373}
4374
4375fn sort_capabilities(mut capabilities: Vec<ModelCapability>) -> Vec<ModelCapability> {
4376 capabilities.sort_by_key(|capability| {
4377 all_model_capabilities()
4378 .iter()
4379 .position(|candidate| candidate == capability)
4380 .unwrap_or(usize::MAX)
4381 });
4382 capabilities
4383}
4384
4385fn speech_model_source_label(schema: &ModelSchema) -> String {
4386 match &schema.source {
4387 ModelSource::Mlx { hf_repo, .. } => format!("mlx:{hf_repo}"),
4388 ModelSource::Proprietary {
4389 provider, endpoint, ..
4390 } => format!("proprietary:{provider}:{endpoint}"),
4391 ModelSource::RemoteApi { endpoint, .. } => format!("remote:{endpoint}"),
4392 ModelSource::Local { hf_repo, .. } => format!("local:{hf_repo}"),
4393 ModelSource::VllmMlx {
4394 endpoint,
4395 model_name,
4396 } => format!("vllm-mlx:{endpoint}:{model_name}"),
4397 ModelSource::Ollama { model_tag, host } => format!("ollama:{host}:{model_tag}"),
4398 ModelSource::AppleFoundationModels { use_case } => {
4399 format!("apple-foundation:{}", use_case.as_deref().unwrap_or("default"))
4400 }
4401 }
4402}
4403
/// Builds the chat-formatted reranking prompt for one query/document pair.
///
/// The system turn asks for a strict "yes"/"no" judgement, the user turn
/// carries the instruction, query and document, and the assistant turn is
/// opened with an empty `<think>` block.
fn rerank_prompt(instruction: &str, query: &str, document: &str) -> String {
    const SYSTEM: &str = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".";
    format!(
        concat!(
            "<|im_start|>system\n{system}<|im_end|>\n",
            "<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n",
            "<|im_start|>assistant\n<think>\n\n</think>\n\n",
        ),
        system = SYSTEM,
        instruction = instruction,
        query = query,
        document = document,
    )
}
4419
4420fn score_from_rerank_output(text: &str, model_name: &str) -> f32 {
4426 let normalized: String = text
4431 .to_ascii_lowercase()
4432 .chars()
4433 .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
4434 .collect();
4435 for tok in normalized.split_ascii_whitespace().take(5) {
4436 match tok {
4437 "yes" => return 1.0,
4438 "no" => return 0.0,
4439 _ => continue,
4440 }
4441 }
4442 tracing::warn!(
4443 model = %model_name,
4444 output = %text,
4445 "rerank: first tokens contain neither `yes` nor `no`; returning neutral 0.5"
4446 );
4447 0.5
4448}
4449
4450fn default_speech_voice(schema: &ModelSchema) -> Option<String> {
4451 if schema.provider == "elevenlabs" {
4452 Some("JBFqnCBsd6RMkjVDRZzb".to_string())
4453 } else if schema.name == "Kokoro-82M-6bit" || schema.name == "Kokoro-82M-bf16" {
4454 Some("af_heart".to_string())
4455 } else if schema.name == "Qwen3-TTS-12Hz-1.7B-Base-5bit" {
4456 Some("Chelsie".to_string())
4457 } else {
4458 None
4459 }
4460}
4461
4462fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
4463 find_latest_huggingface_snapshot(repo_id).is_some()
4464}
4465
/// Computes the Hugging Face cache directory for `repo_id`.
///
/// The cache home is `HF_HOME` when set, else `$HOME/.cache/huggingface`
/// (with `.` standing in for an unset `HOME`). The repository lives under
/// `hub/models--<repo id with '/' replaced by '--'>`.
fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
    let cache_home = match std::env::var("HF_HOME") {
        Ok(explicit) => PathBuf::from(explicit),
        Err(_) => {
            let home = match std::env::var("HOME") {
                Ok(dir) => PathBuf::from(dir),
                Err(_) => PathBuf::from("."),
            };
            home.join(".cache").join("huggingface")
        }
    };
    let folder = format!("models--{}", repo_id.replace('/', "--"));
    cache_home.join("hub").join(folder)
}
4479
4480fn find_latest_huggingface_snapshot(repo_id: &str) -> Option<PathBuf> {
4481 let snapshots = huggingface_repo_dir(repo_id).join("snapshots");
4482 std::fs::read_dir(snapshots)
4483 .ok()?
4484 .filter_map(Result::ok)
4485 .map(|entry| entry.path())
4486 .find(|path| path.is_dir() && snapshot_looks_ready(path))
4487}
4488
4489fn snapshot_looks_ready(path: &Path) -> bool {
4490 if path.join("config.json").exists() || path.join("model_index.json").exists() {
4491 return true;
4492 }
4493 snapshot_contains_ext(path, "safetensors")
4494}
4495
/// Reports whether any file beneath `root`, searched recursively, has the
/// extension `ext` (ASCII case-insensitive).
///
/// Unreadable directories and entries are treated as containing no match.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return false,
    };

    for entry in entries.filter_map(Result::ok) {
        let child = entry.path();
        let found = if child.is_dir() {
            snapshot_contains_ext(&child, ext)
        } else {
            child
                .extension()
                .and_then(|value| value.to_str())
                .is_some_and(|value| value.eq_ignore_ascii_case(ext))
        };
        if found {
            return true;
        }
    }
    false
}
4512
/// Counts the regular files beneath `root`, descending into subdirectories.
///
/// Unreadable directories and entries contribute zero.
fn count_files_recursive(root: &Path) -> usize {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return 0,
    };

    let mut total = 0;
    for entry in entries.filter_map(Result::ok) {
        let child = entry.path();
        if child.is_dir() {
            total += count_files_recursive(&child);
        } else if child.is_file() {
            total += 1;
        }
    }
    total
}
4531
4532async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
4533 let api = hf_hub::api::tokio::ApiBuilder::from_env()
4534 .with_progress(false)
4535 .build()
4536 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
4537 let repo = api.model(repo_id.to_string());
4538 let info = repo
4539 .info()
4540 .await
4541 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
4542
4543 let snapshot_path = huggingface_repo_dir(repo_id)
4544 .join("snapshots")
4545 .join(&info.sha);
4546 let mut downloaded = 0usize;
4547 for sibling in &info.siblings {
4548 let local_path = snapshot_path.join(&sibling.rfilename);
4549 if local_path.exists() {
4550 downloaded += 1;
4551 continue;
4552 }
4553 repo.download(&sibling.rfilename).await.map_err(|e| {
4554 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
4555 })?;
4556 downloaded += 1;
4557 }
4558
4559 Ok((snapshot_path, downloaded))
4560}
4561
4562fn temp_work_dir(prefix: &str) -> Result<PathBuf, InferenceError> {
4563 let unique = SystemTime::now()
4564 .duration_since(UNIX_EPOCH)
4565 .map_err(|e| InferenceError::InferenceFailed(format!("clock error: {e}")))?
4566 .as_nanos();
4567 let dir = std::env::temp_dir().join(format!("car-inference-{prefix}-{unique}"));
4568 std::fs::create_dir_all(&dir)?;
4569 Ok(dir)
4570}
4571
4572fn ensure_parent_dir(path: &Path) -> Result<(), InferenceError> {
4573 if let Some(parent) = path.parent() {
4574 std::fs::create_dir_all(parent)?;
4575 }
4576 Ok(())
4577}
4578
4579fn requested_or_temp_output(
4580 output_path: Option<&str>,
4581 format: &str,
4582) -> Result<PathBuf, InferenceError> {
4583 if let Some(path) = output_path {
4584 return Ok(PathBuf::from(path));
4585 }
4586 let dir = temp_work_dir("audio-out")?;
4587 Ok(dir.join(format!("speech.{format}")))
4588}
4589
4590fn requested_or_temp_media_output(
4591 output_path: Option<&str>,
4592 format: &str,
4593 stem: &str,
4594) -> Result<PathBuf, InferenceError> {
4595 if let Some(path) = output_path {
4596 return Ok(PathBuf::from(path));
4597 }
4598 let dir = temp_work_dir(&format!("{stem}-out"))?;
4599 Ok(dir.join(format!("{stem}.{format}")))
4600}
4601
4602fn materialize_audio_output(
4603 produced: &Path,
4604 requested: Option<&str>,
4605 format: &str,
4606) -> Result<PathBuf, InferenceError> {
4607 if let Some(path) = requested {
4608 let dest = PathBuf::from(path);
4609 ensure_parent_dir(&dest)?;
4610 std::fs::copy(produced, &dest)?;
4611 Ok(dest)
4612 } else {
4613 let dest = requested_or_temp_output(None, format)?;
4614 ensure_parent_dir(&dest)?;
4615 std::fs::copy(produced, &dest)?;
4616 Ok(dest)
4617 }
4618}
4619
4620fn materialize_binary_output(
4621 produced: &Path,
4622 requested: Option<&str>,
4623 format: &str,
4624 stem: &str,
4625) -> Result<PathBuf, InferenceError> {
4626 let dest = requested_or_temp_media_output(requested, format, stem)?;
4627 ensure_parent_dir(&dest)?;
4628 std::fs::copy(produced, &dest)?;
4629 Ok(dest)
4630}
4631
4632fn find_generated_file(
4633 root: &Path,
4634 extensions: &[&str],
4635) -> Result<Option<PathBuf>, InferenceError> {
4636 let entries = std::fs::read_dir(root)?;
4637 let mut candidates: Vec<PathBuf> = entries
4638 .filter_map(Result::ok)
4639 .map(|entry| entry.path())
4640 .filter(|path| {
4641 path.is_file()
4642 && path
4643 .extension()
4644 .and_then(|ext| ext.to_str())
4645 .map(|ext| extensions.iter().any(|candidate| candidate.eq_ignore_ascii_case(ext)))
4646 .unwrap_or(false)
4647 })
4648 .collect();
4649 candidates.sort();
4650 Ok(candidates.pop())
4651}
4652
/// Returns the MIME type for an image `format` name (matched ignoring ASCII case).
///
/// `jpg`/`jpeg` map to `image/jpeg`, `webp` to `image/webp`; anything else is treated
/// as PNG.
fn media_type_for_image_format(format: &str) -> String {
    let normalized = format.to_ascii_lowercase();
    let mime = match normalized.as_str() {
        "jpg" | "jpeg" => "image/jpeg",
        "webp" => "image/webp",
        _ => "image/png",
    };
    mime.to_string()
}
4660
/// Returns the MIME type for a video `format` name (matched ignoring ASCII case).
///
/// `mov` maps to `video/quicktime` and `gif` to `image/gif`; anything else is treated
/// as MP4.
fn media_type_for_video_format(format: &str) -> String {
    let mime = if format.eq_ignore_ascii_case("mov") {
        "video/quicktime"
    } else if format.eq_ignore_ascii_case("gif") {
        "image/gif"
    } else {
        "video/mp4"
    };
    String::from(mime)
}
4668
4669fn read_transcription_result(output_prefix: &Path) -> Result<Option<String>, InferenceError> {
4670 let candidates = [
4671 output_prefix.with_extension("json"),
4672 output_prefix.to_path_buf(),
4673 ];
4674
4675 for path in candidates {
4676 if path.exists() {
4677 let contents = std::fs::read_to_string(path)?;
4678 if let Some(text) = extract_text_from_payload(&contents) {
4679 return Ok(Some(text));
4680 }
4681 }
4682 }
4683
4684 Ok(None)
4685}
4686
4687fn extract_text_from_payload(payload: &str) -> Option<String> {
4688 let value: serde_json::Value = serde_json::from_str(payload).ok()?;
4689 if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
4690 return Some(text.to_string());
4691 }
4692 if let Some(transcripts) = value.get("transcripts").and_then(|v| v.as_array()) {
4693 let joined = transcripts
4694 .iter()
4695 .filter_map(|item| item.get("text").and_then(|v| v.as_str()))
4696 .collect::<Vec<_>>()
4697 .join("\n");
4698 if !joined.is_empty() {
4699 return Some(joined);
4700 }
4701 }
4702 if let Some(items) = value.as_array() {
4703 let joined = items
4704 .iter()
4705 .filter_map(|item| {
4706 item.get("text")
4707 .or_else(|| item.get("Content"))
4708 .and_then(|v| v.as_str())
4709 })
4710 .collect::<Vec<_>>()
4711 .join(" ");
4712 if !joined.is_empty() {
4713 return Some(joined);
4714 }
4715 }
4716 None
4717}
4718
4719fn find_audio_file(output_dir: &Path) -> Result<Option<PathBuf>, InferenceError> {
4720 let mut audio_files = Vec::new();
4721 collect_audio_files(output_dir, &mut audio_files)?;
4722 audio_files.sort();
4723 Ok(audio_files.into_iter().next())
4724}
4725
4726fn collect_audio_files(dir: &Path, audio_files: &mut Vec<PathBuf>) -> Result<(), InferenceError> {
4727 for entry in std::fs::read_dir(dir)? {
4728 let path = entry?.path();
4729 if path.is_dir() {
4730 collect_audio_files(&path, audio_files)?;
4731 } else if matches!(
4732 path.extension().and_then(|ext| ext.to_str()),
4733 Some("wav" | "mp3" | "flac" | "pcm" | "m4a")
4734 ) {
4735 audio_files.push(path);
4736 }
4737 }
4738 Ok(())
4739}
4740
/// Returns the MIME type for an audio `format` name (matched ignoring ASCII case).
///
/// Recognizes `mp3`, `flac`, `pcm` and `m4a`; anything else is treated as WAV.
fn media_type_for_format(format: &str) -> String {
    let normalized = format.to_ascii_lowercase();
    let mime = match normalized.as_str() {
        "mp3" => "audio/mpeg",
        "flac" => "audio/flac",
        "pcm" => "audio/L16",
        "m4a" => "audio/mp4",
        _ => "audio/wav",
    };
    String::from(mime)
}
4750
/// Maps a language name or tag to a single-letter Kokoro language code.
///
/// Matching ignores ASCII case. A missing language, and any value not listed in the
/// alias table below, yields `"a"`.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn kokoro_lang_code(language: Option<&str>) -> &'static str {
    // (code, accepted lowercase aliases)
    const ALIASES: [(&str, &[&str]); 5] = [
        ("b", &["en-gb", "british", "british english"]),
        ("j", &["ja", "japanese"]),
        ("z", &["zh", "zh-cn", "mandarin", "chinese"]),
        ("e", &["es", "spanish"]),
        ("f", &["fr", "french"]),
    ];

    let requested = language.unwrap_or("en").to_ascii_lowercase();
    for (code, names) in ALIASES {
        if names.iter().any(|name| *name == requested.as_str()) {
            return code;
        }
    }
    "a"
}
4762
/// Normalizes a language name or tag to one of the two-letter codes `en`, `es`, `fr`,
/// `ja` or `zh`.
///
/// Matching ignores ASCII case. Both the codes themselves and the English language
/// names listed below are accepted; anything unrecognized falls back to `en`.
fn normalize_lang_code(language: &str) -> String {
    let lowered = language.to_ascii_lowercase();
    let code = match lowered.as_str() {
        "es" | "spanish" => "es",
        "fr" | "french" => "fr",
        "ja" | "japanese" => "ja",
        "zh" | "chinese" | "mandarin" => "zh",
        // Covers "en", "english", "en-us", "en_us" and every unknown input.
        _ => "en",
    };
    code.to_string()
}
4776
4777fn elevenlabs_auth(schema: &ModelSchema) -> Result<(String, String), InferenceError> {
4778 match &schema.source {
4779 ModelSource::Proprietary {
4780 endpoint,
4781 auth: schema::ProprietaryAuth::ApiKeyEnv { env_var },
4782 ..
4783 } => {
4784 let key = car_secrets::resolve_env_or_keychain(env_var).ok_or_else(|| {
4785 InferenceError::InferenceFailed(format!(
4786 "missing API key {env_var}; set the environment variable or \
4787 store it with `car secrets put {env_var}`"
4788 ))
4789 })?;
4790 Ok((endpoint.clone(), key))
4791 }
4792 _ => Err(InferenceError::InferenceFailed(format!(
4793 "model {} is not an ElevenLabs proprietary model",
4794 schema.id
4795 ))),
4796 }
4797}
4798
/// Maps a requested audio `format` (matched ignoring ASCII case) to an output-format
/// identifier string: `mp3` and `pcm` get dedicated identifiers, everything else gets
/// the WAV one.
///
/// NOTE(review): presumably these are values for the ElevenLabs `output_format`
/// request parameter — confirm at the call site.
fn elevenlabs_output_format(format: &str) -> &'static str {
    if format.eq_ignore_ascii_case("mp3") {
        "mp3_44100_128"
    } else if format.eq_ignore_ascii_case("pcm") {
        "pcm_16000"
    } else {
        "wav_44100"
    }
}
4806
/// Lists the candidate locations of `benchmark_priors.json`, without duplicates.
///
/// Order of the result:
/// 1. inside `models_dir`;
/// 2. inside the parent of `models_dir`, when it has one;
/// 3. the path given by the `CAR_BENCHMARK_PRIORS_PATH` environment variable, when set.
///
/// The files are not checked for existence.
fn benchmark_priors_paths(models_dir: &Path) -> Vec<PathBuf> {
    const FILE_NAME: &str = "benchmark_priors.json";

    let candidates = std::iter::once(models_dir.join(FILE_NAME))
        .chain(models_dir.parent().map(|parent| parent.join(FILE_NAME)))
        .chain(std::env::var_os("CAR_BENCHMARK_PRIORS_PATH").map(PathBuf::from));

    let mut ordered: Vec<PathBuf> = Vec::new();
    for candidate in candidates {
        if !ordered.contains(&candidate) {
            ordered.push(candidate);
        }
    }
    ordered
}
4831
4832fn load_benchmark_prior_health(
4833 models_dir: &Path,
4834 schemas: &[ModelSchema],
4835) -> Vec<ModelBenchmarkPriorHealth> {
4836 let mut priors = std::collections::BTreeMap::new();
4837 for path in benchmark_priors_paths(models_dir) {
4838 let Ok(loaded) = routing_ext::load_benchmark_priors(&path) else {
4839 continue;
4840 };
4841 for (model_id, prior) in loaded {
4842 let model_name = schemas
4843 .iter()
4844 .find(|schema| schema.id == model_id)
4845 .map(|schema| schema.name.clone());
4846 priors.insert(
4847 model_id.clone(),
4848 ModelBenchmarkPriorHealth {
4849 model_id,
4850 model_name,
4851 overall_score: prior.overall_score,
4852 overall_latency_ms: prior.overall_latency_ms,
4853 task_scores: prior.task_scores,
4854 task_latency_ms: prior.task_latency_ms,
4855 source_path: path.clone(),
4856 },
4857 );
4858 }
4859 }
4860
4861 priors.into_values().collect()
4862}
4863
/// Reports whether the Kokoro runtime fallback is enabled.
///
/// Reads `CAR_SPEECH_KOKORO_FALLBACK`. The fallback is disabled only when the value,
/// trimmed and lowercased, is `0`, `false` or `off`; it is enabled for any other value
/// and when the variable is unset or unreadable.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn kokoro_runtime_fallback_enabled() -> bool {
    let Ok(raw) = std::env::var("CAR_SPEECH_KOKORO_FALLBACK") else {
        return true;
    };
    let setting = raw.trim().to_ascii_lowercase();
    !["0", "false", "off"].contains(&setting.as_str())
}
4871
/// Returns the package spec to use for `mlx-audio`.
///
/// The value of `CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC` is used verbatim when it is set and
/// not blank; otherwise the pinned default `mlx-audio==0.4.2` is returned.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn speech_runtime_mlx_audio_spec() -> String {
    match std::env::var("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => String::from("mlx-audio==0.4.2"),
    }
}
4879
/// Returns the package spec to use for the spaCy English model.
///
/// The value of `CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC` is used verbatim when it is set
/// and not blank; otherwise a direct-URL spec for the `en_core_web_sm` 3.8.0 wheel is
/// returned.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn speech_runtime_spacy_model_spec() -> String {
    match std::env::var("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC") {
        Ok(spec) if !spec.trim().is_empty() => spec,
        _ => String::from(
            "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
        ),
    }
}
4889
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Serializes the tests in this module that mutate process-wide environment
    /// variables; the test harness runs tests on parallel threads.
    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Holds `ENV_MUTEX` for its whole lifetime and, on drop, puts one environment
    /// variable back to the value it had when the guard was acquired. Because the
    /// restore runs in `Drop`, it also happens when an assertion panics, so a failing
    /// test cannot leak its override into later tests.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<std::ffi::OsString>,
        // Declared last so the lock is released only after `Drop::drop` has restored
        // the variable.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl EnvVarGuard {
        /// Takes the environment lock and remembers the current value of `key`.
        fn acquire(key: &'static str) -> Self {
            // A panicking env test poisons the mutex. The protected value is `()`, so
            // there is no state to be inconsistent; recover the guard instead of
            // failing every later env test.
            let lock = ENV_MUTEX
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            Self {
                key,
                original: std::env::var_os(key),
                _lock: lock,
            }
        }

        /// Sets the guarded variable to `value`.
        fn set(&self, value: impl AsRef<std::ffi::OsStr>) {
            // SAFETY: `self` holds `ENV_MUTEX`, so no other test in this module
            // mutates the environment concurrently.
            // NOTE(review): tests that only read the environment do not take the lock.
            unsafe {
                std::env::set_var(self.key, value);
            }
        }

        /// Removes the guarded variable.
        // Only the speech-runtime tests call this, and they are compiled out on
        // macOS/aarch64, which would otherwise trigger a dead-code warning there.
        #[allow(dead_code)]
        fn unset(&self) {
            // SAFETY: `self` holds `ENV_MUTEX`; see `set`.
            unsafe {
                std::env::remove_var(self.key);
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: `_lock` is still held here (fields drop after this body runs).
            unsafe {
                match self.original.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Builds a minimal `InferenceConfig` rooted at `models_dir` with no preferred
    /// model overrides.
    fn test_config(models_dir: PathBuf) -> InferenceConfig {
        InferenceConfig {
            models_dir,
            device: None,
            generation_model: "Qwen3-0.6B".into(),
            preferred_generation_model: None,
            embedding_model: "Qwen3-Embedding-0.6B".into(),
            preferred_embedding_model: None,
            classification_model: "Qwen3-0.6B".into(),
            preferred_classification_model: None,
        }
    }

    /// Tokenize/detokenize against a non-local model must fail with `UnsupportedMode`.
    #[tokio::test]
    async fn tokenize_rejects_known_remote_model_with_unsupported_mode() {
        let tmp = TempDir::new().unwrap();
        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let remote_id = engine
            .list_schemas()
            .into_iter()
            .find(|s| !s.is_local())
            .map(|s| s.id)
            .expect("built-in catalog should include at least one remote model schema");

        let err = engine
            .tokenize(&remote_id, "hello")
            .await
            .expect_err("remote tokenize must error");
        match err {
            InferenceError::UnsupportedMode { mode, backend, .. } => {
                assert_eq!(mode, "tokenize/detokenize");
                assert_eq!(backend, "remote");
            }
            other => panic!("expected UnsupportedMode, got {other:?}"),
        }

        let err = engine
            .detokenize(&remote_id, &[1, 2, 3])
            .await
            .expect_err("remote detokenize must error");
        assert!(
            matches!(err, InferenceError::UnsupportedMode { .. }),
            "expected UnsupportedMode, got {err:?}"
        );
    }

    /// A priors file named by `CAR_BENCHMARK_PRIORS_PATH` seeds a profile at startup.
    #[test]
    fn engine_loads_benchmark_priors_on_startup() {
        let env = EnvVarGuard::acquire("CAR_BENCHMARK_PRIORS_PATH");
        let tmp = TempDir::new().unwrap();
        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.88
            })
            .to_string(),
        )
        .unwrap();
        env.set(&priors_path);

        let engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("benchmark prior should create a profile");
        assert!((profile.ema_quality - 0.88).abs() < 0.01);
    }

    /// Observed outcome profiles on disk take precedence over benchmark priors.
    #[test]
    fn benchmark_priors_do_not_override_observed_profiles() {
        let env = EnvVarGuard::acquire("CAR_BENCHMARK_PRIORS_PATH");
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        std::fs::create_dir_all(&models_dir).unwrap();

        let observed = vec![ModelProfile {
            model_id: "qwen/qwen3-8b:q4_k_m".into(),
            total_calls: 12,
            success_count: 3,
            fail_count: 9,
            total_latency_ms: 1200,
            total_input_tokens: 0,
            total_output_tokens: 0,
            task_stats: std::collections::HashMap::new(),
            ema_quality: 0.21,
            quality_per_1k_tokens: 0.0,
            updated_at: 1,
        }];
        std::fs::write(
            models_dir.join("outcome_profiles.json"),
            serde_json::to_string(&observed).unwrap(),
        )
        .unwrap();

        let priors_path = tmp.path().join("benchmark_priors.json");
        std::fs::write(
            &priors_path,
            serde_json::json!({
                "model_id": "qwen/qwen3-8b:q4_k_m",
                "overall_score": 0.95
            })
            .to_string(),
        )
        .unwrap();
        env.set(&priors_path);

        let engine = InferenceEngine::new(test_config(models_dir));
        let tracker = engine.outcome_tracker.blocking_read();
        let profile = tracker
            .profile("qwen/qwen3-8b:q4_k_m")
            .expect("observed profile should remain present");
        assert!((profile.ema_quality - 0.21).abs() < 0.01);
        assert_eq!(profile.total_calls, 12);
    }

    /// The mlx-audio spec falls back to the pinned default and honors the override.
    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
    fn speech_runtime_package_spec_defaults_and_overrides() {
        let env = EnvVarGuard::acquire("CAR_SPEECH_RUNTIME_MLX_AUDIO_SPEC");
        env.unset();
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.2");

        env.set("mlx-audio==0.4.1");
        assert_eq!(speech_runtime_mlx_audio_spec(), "mlx-audio==0.4.1");
    }

    /// The spaCy model spec falls back to the wheel URL and honors the override.
    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
    fn speech_runtime_spacy_model_spec_defaults_and_overrides() {
        let env = EnvVarGuard::acquire("CAR_SPEECH_RUNTIME_SPACY_MODEL_SPEC");
        env.unset();
        assert!(
            speech_runtime_spacy_model_spec().starts_with("en-core-web-sm @ https://github.com/")
        );

        env.set("en-core-web-sm==3.8.0");
        assert_eq!(speech_runtime_spacy_model_spec(), "en-core-web-sm==3.8.0");
    }

    /// The Kokoro fallback is on by default and can be switched off via the env var.
    #[test]
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
    fn kokoro_runtime_fallback_defaults_on() {
        // This test mutates the environment too, so it must hold the same lock as the
        // other env tests.
        let env = EnvVarGuard::acquire("CAR_SPEECH_KOKORO_FALLBACK");
        env.unset();
        assert!(kokoro_runtime_fallback_enabled());

        env.set("false");
        assert!(!kokoro_runtime_fallback_enabled());
    }

    /// An explicitly preferred local TTS model is the one the speech policy resolves.
    #[test]
    fn preferred_local_tts_wins_over_builtin_rank() {
        let tmp = TempDir::new().unwrap();
        let mut engine = InferenceEngine::new(test_config(tmp.path().join("models")));
        engine.set_speech_policy(SpeechPolicy {
            prefer_local: true,
            allow_remote_fallback: false,
            preferred_local_stt: None,
            preferred_local_tts: Some("Kokoro-82M-6bit".into()),
            preferred_remote_stt: None,
            preferred_remote_tts: None,
        });

        let schema = engine
            .preferred_speech_schema(ModelCapability::TextToSpeech, true, false)
            .expect("preferred local TTS should resolve");
        assert_eq!(schema.name, "Kokoro-82M-6bit");
    }

    /// A preferred generation model that was registered from vllm-mlx discovery is
    /// selected by adaptive routing as an explicit override.
    #[test]
    fn preferred_discovered_vllm_mlx_model_wins_generate_routing() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(tmp.path().join("models"));
        config.preferred_generation_model =
            Some("vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit".into());
        let mut engine = InferenceEngine::new(config);
        let schema = crate::vllm_mlx::to_model_schema(
            &crate::vllm_mlx::DiscoveredModel {
                id: "mlx-community/gemma-3n-E2B-it-lm-4bit".into(),
                owned_by: Some("mlx-community".into()),
            },
            "http://127.0.0.1:8001",
        );
        engine.register_model(schema);

        let rt = tokio::runtime::Runtime::new().unwrap();
        let decision = rt.block_on(engine.route_adaptive("say hello in one sentence"));
        assert_eq!(
            decision.model_id,
            "vllm-mlx/mlx-community_gemma-3n-E2B-it-lm-4bit"
        );
        assert_eq!(decision.strategy, RoutingStrategy::Explicit);
        assert_eq!(decision.reason, "preferred generation model override");
    }

    /// A fully populated `InferenceResult` serializes every field under its
    /// snake_case name with the expected JSON types.
    #[test]
    fn inference_result_serializes_with_full_shape() {
        use crate::tasks::generate::ToolCall;
        use std::collections::HashMap;

        let mut args = HashMap::new();
        args.insert("path".to_string(), serde_json::json!("README.md"));

        let result = InferenceResult {
            text: String::new(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![ToolCall {
                id: None,
                name: "read_file".into(),
                arguments: args,
            }],
            trace_id: "trace-abc".into(),
            model_used: "test-model".into(),
            latency_ms: 1234,
            time_to_first_token_ms: Some(180),
            usage: Some(TokenUsage {
                prompt_tokens: 100,
                completion_tokens: 50,
                total_tokens: 150,
                context_window: 8192,
            }),
        };

        let json = serde_json::to_value(&result).expect("serialize");

        assert_eq!(json["text"].as_str(), Some(""));
        assert_eq!(json["trace_id"].as_str(), Some("trace-abc"));
        assert_eq!(json["model_used"].as_str(), Some("test-model"));
        assert_eq!(json["latency_ms"].as_u64(), Some(1234));

        let tool_calls = json["tool_calls"].as_array().expect("tool_calls array");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0]["name"].as_str(), Some("read_file"));
        assert_eq!(
            tool_calls[0]["arguments"]["path"].as_str(),
            Some("README.md")
        );

        let usage = &json["usage"];
        assert_eq!(usage["prompt_tokens"].as_u64(), Some(100));
        assert_eq!(usage["completion_tokens"].as_u64(), Some(50));
        assert_eq!(usage["total_tokens"].as_u64(), Some(150));
        assert_eq!(usage["context_window"].as_u64(), Some(8192));

        assert_eq!(json["time_to_first_token_ms"].as_u64(), Some(180));
    }

    /// Locks the exact set of top-level keys in a serialized `InferenceResult`.
    ///
    /// NOTE(review): `bounding_boxes` is set (empty) but absent from the expected keys;
    /// presumably it is skipped when empty — confirm against the struct's serde
    /// attributes.
    #[test]
    fn inference_result_top_level_keys_are_locked() {
        use std::collections::BTreeSet;

        let result = InferenceResult {
            text: "anything".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "t".into(),
            model_used: "m".into(),
            latency_ms: 0,
            time_to_first_token_ms: None,
            usage: None,
        };

        let json = serde_json::to_value(&result).expect("serialize");
        let keys: BTreeSet<&str> = json
            .as_object()
            .expect("top-level object")
            .keys()
            .map(String::as_str)
            .collect();

        let expected: BTreeSet<&str> = [
            "text",
            "tool_calls",
            "trace_id",
            "model_used",
            "latency_ms",
            "time_to_first_token_ms",
            "usage",
        ]
        .into_iter()
        .collect();

        assert_eq!(
            keys, expected,
            "infer response top-level keys drifted -- update both the test \
             and the WebSocket protocol documentation if this is intentional"
        );

        for key in &keys {
            assert!(
                !key.chars().any(|c| c.is_uppercase()) && !key.contains('-'),
                "key '{}' is not snake_case",
                key
            );
        }
    }

    /// A plain-text result serializes with an empty `tool_calls` array and null
    /// optional fields.
    #[test]
    fn inference_result_serializes_plain_text_response() {
        let result = InferenceResult {
            text: "hello world".into(),
            bounding_boxes: Vec::new(),
            tool_calls: vec![],
            trace_id: "trace-xyz".into(),
            model_used: "test-model".into(),
            latency_ms: 42,
            time_to_first_token_ms: None,
            usage: None,
        };

        let json = serde_json::to_value(&result).expect("serialize");
        assert_eq!(json["text"], "hello world");
        assert!(json["tool_calls"].is_array());
        assert_eq!(json["tool_calls"].as_array().unwrap().len(), 0);
        assert_eq!(json["model_used"], "test-model");
        assert!(json["usage"].is_null());
        assert!(json["time_to_first_token_ms"].is_null());
    }

    /// The rerank prompt contains each section of the expected chat template.
    #[test]
    fn rerank_prompt_matches_upstream_template_shape() {
        let p = rerank_prompt(
            "retrieve relevant passages",
            "who runs the treasury?",
            "doc x",
        );
        assert!(p.contains("<|im_start|>system"));
        assert!(p.contains("Note that the answer can only be \"yes\" or \"no\"."));
        assert!(p.contains("<|im_start|>user\n<Instruct>: retrieve relevant passages"));
        assert!(p.contains("<Query>: who runs the treasury?"));
        assert!(p.contains("<Document>: doc x<|im_end|>"));
        assert!(p.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    /// Bare "yes"/"no" map to the extreme scores.
    #[test]
    fn rerank_score_yes_and_no_exactly() {
        assert_eq!(score_from_rerank_output("yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("no", "m"), 0.0);
    }

    /// Case, leading whitespace, trailing punctuation and chat sentinels are tolerated.
    #[test]
    fn rerank_score_handles_case_leading_space_and_chat_sentinels() {
        assert_eq!(score_from_rerank_output(" Yes", "m"), 1.0);
        assert_eq!(score_from_rerank_output("\nno.", "m"), 0.0);
        assert_eq!(score_from_rerank_output("<|im_end|>yes", "m"), 1.0);
    }

    /// An answer preceded by a stray leading token is still recognized.
    #[test]
    fn rerank_score_scans_up_to_three_tokens() {
        assert_eq!(score_from_rerank_output("_bos_ yes", "m"), 1.0);
    }

    /// Output that is neither "yes" nor "no" scores the neutral 0.5.
    #[test]
    fn rerank_score_unexpected_is_neutral() {
        assert_eq!(score_from_rerank_output("maybe", "m"), 0.5);
        assert_eq!(score_from_rerank_output("", "m"), 0.5);
    }
}