
car_inference/lib.rs

//! # car-inference
//!
//! Local model inference for the Common Agent Runtime.
//!
//! Provides on-device inference using Candle with automatic hardware detection:
//! - **macOS**: Metal (Apple Silicon GPU)
//! - **Linux**: CUDA (NVIDIA GPU) or CPU fallback
//!
//! Ships with Qwen3 models downloaded on first use from HuggingFace.
//! Supports remote API models (OpenAI, Anthropic, Google) via the same schema.
//!
//! ## Architecture
//!
//! Models are first-class typed resources described by `ModelSchema` (analogous
//! to `ToolSchema`). The `UnifiedRegistry` holds local and remote models.
//! The `AdaptiveRouter` selects the best model using a three-phase strategy:
//! filter → score → explore. The `OutcomeTracker` learns from results to
//! improve routing over time.
//!
//! ## Dual purpose
//!
//! 1. **Internal** — powers skill learning/repair, semantic memory, policy evaluation
//! 2. **Service** — exposes `infer`, `embed`, `classify` as built-in CAR tools
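//!
//! ## Example
//!
//! A minimal usage sketch. The exact `GenerateRequest` fields live in [`tasks`];
//! the construction below is illustrative and assumes a `Default` impl.
//!
//! ```ignore
//! use car_inference::{GenerateRequest, InferenceConfig, InferenceEngine};
//!
//! # async fn demo() -> Result<(), car_inference::InferenceError> {
//! // Auto-detects hardware and picks recommended default models.
//! let engine = InferenceEngine::new(InferenceConfig::default());
//!
//! // `model: None` lets the adaptive router pick a model for this prompt.
//! let req = GenerateRequest {
//!     model: None,
//!     prompt: "Summarize the CAR architecture in one sentence.".into(),
//!     ..Default::default() // assumption: GenerateRequest implements Default
//! };
//!
//! let result = engine.generate_tracked(req).await?;
//! println!("{} ({} ms via {})", result.text, result.latency_ms, result.model_used);
//! # Ok(())
//! # }
//! ```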

pub mod adaptive_router;
pub mod backend;
pub mod hardware;
pub mod models;
pub mod outcome;
pub mod registry;
pub mod remote;
pub mod router;
pub mod schema;
pub mod service;
pub mod tasks;

use std::sync::Arc;
use std::time::Instant;

use thiserror::Error;
use tokio::sync::RwLock;
use tracing::debug;

// --- New types ---
pub use adaptive_router::{AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy};
pub use outcome::{
    CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
};
pub use registry::{ModelFilter, ModelInfo, UnifiedRegistry};
pub use remote::RemoteBackend;
pub use schema::{ApiProtocol, CostModel, ModelCapability, ModelSchema, ModelSource, PerformanceEnvelope};

// --- Legacy re-exports (kept for backward compatibility) ---
pub use adaptive_router::TaskComplexity;
pub use backend::CandleBackend;
pub use backend::EmbeddingBackend;
pub use hardware::HardwareInfo;
pub use models::{ModelRegistry, ModelRole};
pub use router::{ModelRouter, RoutingDecision};
pub use tasks::{ClassifyRequest, ClassifyResult, EmbedRequest, GenerateParams, GenerateRequest};

#[derive(Error, Debug)]
pub enum InferenceError {
    #[error("model not found: {0}")]
    ModelNotFound(String),

    #[error("model download failed: {0}")]
    DownloadFailed(String),

    #[error("inference failed: {0}")]
    InferenceFailed(String),

    #[error("tokenization error: {0}")]
    TokenizationError(String),

    #[error("device error: {0}")]
    DeviceError(String),

    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}

/// Which device to run inference on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Metal,
    Cuda(usize), // device ordinal
}

impl Device {
    /// Auto-detect the best available device for this platform.
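    ///
    /// The choice is made at compile time from Cargo features. A minimal sketch:
    ///
    /// ```ignore
    /// // Built with `--features metal` this is Device::Metal; with `--features cuda`
    /// // it is Device::Cuda(0); with neither feature it falls back to Device::Cpu.
    /// let device = Device::auto();
    /// ```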
    pub fn auto() -> Self {
        #[cfg(feature = "metal")]
        {
            return Device::Metal;
        }
        #[cfg(feature = "cuda")]
        {
            return Device::Cuda(0);
        }
        #[cfg(not(any(feature = "metal", feature = "cuda")))]
        {
            Device::Cpu
        }
    }
}

/// Configuration for the inference engine.
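///
/// A minimal sketch of overriding selected defaults (the model name below is
/// illustrative, not a guaranteed registry entry):
///
/// ```no_run
/// use car_inference::{Device, InferenceConfig};
///
/// let config = InferenceConfig {
///     device: Some(Device::Cpu), // skip auto-detection, force CPU
///     generation_model: "Qwen3-4B".to_string(), // hypothetical model name
///     ..InferenceConfig::default()
/// };
/// ```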
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Where to store downloaded models. Defaults to ~/.car/models/
    pub models_dir: std::path::PathBuf,
    /// Device override. None = auto-detect.
    pub device: Option<Device>,
    /// Default model for generation tasks.
    pub generation_model: String,
    /// Default model for embedding tasks.
    pub embedding_model: String,
    /// Default model for classification tasks.
    pub classification_model: String,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        let models_dir = dirs_next()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join(".car")
            .join("models");

        let hw = HardwareInfo::detect();

        Self {
            models_dir,
            device: None,
            generation_model: hw.recommended_model,
            embedding_model: "Qwen3-Embedding-0.6B".to_string(),
            classification_model: "Qwen3-0.6B".to_string(),
        }
    }
}

fn dirs_next() -> Option<std::path::PathBuf> {
    std::env::var("HOME")
        .ok()
        .map(std::path::PathBuf::from)
}

/// Result of an inference call, including trace ID for outcome tracking.
#[derive(Debug, Clone)]
pub struct InferenceResult {
    /// The generated text.
    pub text: String,
    /// Trace ID for reporting outcomes back to the tracker.
    pub trace_id: String,
    /// Which model was used.
    pub model_used: String,
    /// Wall-clock latency in ms.
    pub latency_ms: u64,
}

/// The main inference engine. Thread-safe, lazily loads models.
///
/// Now includes the unified registry, adaptive router, and outcome tracker
/// for schema-driven model selection with learned performance profiles.
pub struct InferenceEngine {
    pub config: InferenceConfig,
    /// Unified model registry (local + remote).
    pub unified_registry: UnifiedRegistry,
    /// Adaptive router with three-phase selection.
    pub adaptive_router: AdaptiveRouter,
    /// Outcome tracker for learning from results.
    pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
    /// HTTP client for remote API models.
    remote_backend: RemoteBackend,
    // Legacy fields kept for backward compatibility
    pub registry: models::ModelRegistry,
    pub router: ModelRouter,
    backend: Arc<RwLock<Option<CandleBackend>>>,
    embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
}

impl InferenceEngine {
    pub fn new(config: InferenceConfig) -> Self {
        let registry = models::ModelRegistry::new(config.models_dir.clone());
        let hw = HardwareInfo::detect();
        let router = ModelRouter::new(hw.clone());
        let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
        let adaptive_router = AdaptiveRouter::with_default_config(hw);
        let outcome_tracker = Arc::new(RwLock::new(OutcomeTracker::new()));

        Self {
            config,
            unified_registry,
            adaptive_router,
            outcome_tracker,
            remote_backend: RemoteBackend::new(),
            registry,
            router,
            backend: Arc::new(RwLock::new(None)),
            embedding_backend: Arc::new(RwLock::new(None)),
        }
    }

    /// Get or initialize the generative backend, loading the specified model.
    async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
        let read = self.backend.read().await;
        if read.is_some() {
            return Ok(());
        }
        drop(read);

        let mut write = self.backend.write().await;
        if write.is_some() {
            return Ok(());
        }

        let model_path = self.registry.ensure_model(model_name).await?;
        let device = self.config.device.unwrap_or_else(Device::auto);
        let backend = CandleBackend::load(&model_path, device)?;
        *write = Some(backend);
        Ok(())
    }

    /// Get or initialize the embedding backend.
    async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
        let read = self.embedding_backend.read().await;
        if read.is_some() {
            return Ok(());
        }
        drop(read);

        let mut write = self.embedding_backend.write().await;
        if write.is_some() {
            return Ok(());
        }

        let model_path = self.registry.ensure_model(&self.config.embedding_model).await?;
        let device = self.config.device.unwrap_or_else(Device::auto);
        let backend = EmbeddingBackend::load(&model_path, device)?;
        *write = Some(backend);
        Ok(())
    }

    /// Route a prompt using the adaptive router (new). Returns full decision context.
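    ///
    /// A minimal sketch of inspecting a decision without executing it (assumes an
    /// `engine: InferenceEngine` in scope):
    ///
    /// ```ignore
    /// let decision = engine.route_adaptive("Explain the borrow checker").await;
    /// println!("{} via {:?}: {}", decision.model_name, decision.strategy, decision.reason);
    /// ```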
    pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
        let tracker = self.outcome_tracker.read().await;
        self.adaptive_router.route(prompt, &self.unified_registry, &tracker)
    }

    /// Route a prompt to the best model without executing (legacy compat).
    pub fn route(&self, prompt: &str) -> RoutingDecision {
        self.router.route_generate(prompt, &self.registry)
    }

    /// Generate text from a prompt with outcome tracking.
    /// Returns `InferenceResult` with trace_id for reporting outcomes.
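    ///
    /// A minimal sketch of feeding an outcome back (the `output_failed_downstream`
    /// helper is hypothetical; `engine` and `req` are assumed to be in scope):
    ///
    /// ```ignore
    /// let result = engine.generate_tracked(req).await?;
    /// // If the output later proves unusable, report it so routing can learn.
    /// if output_failed_downstream(&result.text) {
    ///     let tracker = engine.outcome_tracker();
    ///     tracker.write().await.record_failure(&result.trace_id, "downstream validation failed");
    /// }
    /// ```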
    pub async fn generate_tracked(&self, req: GenerateRequest) -> Result<InferenceResult, InferenceError> {
        let start = Instant::now();

        // Route using adaptive router
        let tracker_read = self.outcome_tracker.read().await;
        let decision = match req.model.clone() {
            Some(m) => AdaptiveRoutingDecision {
                model_id: m.clone(),
                model_name: m.clone(),
                task: InferenceTask::Generate,
                complexity: TaskComplexity::assess(&req.prompt),
                reason: "explicit model".into(),
                strategy: RoutingStrategy::Explicit,
                predicted_quality: 0.5,
                fallbacks: vec![],
            },
            None => self.adaptive_router.route(&req.prompt, &self.unified_registry, &tracker_read),
        };
        drop(tracker_read);

        // Record start
        let trace_id = {
            let mut tracker = self.outcome_tracker.write().await;
            tracker.record_start(&decision.model_id, decision.task, &decision.reason)
        };

        debug!(
            model = %decision.model_name,
            strategy = ?decision.strategy,
            reason = %decision.reason,
            trace = %trace_id,
            "adaptive-routed generate request"
        );

        // Execute — dispatch to local or remote backend, with fallback on failure
        let mut models_to_try = vec![decision.model_id.clone()];
        models_to_try.extend(decision.fallbacks.iter().cloned());

        let mut last_error = None;
        let mut used_model_name = decision.model_name.clone();

        for candidate_id in &models_to_try {
            let schema = self.unified_registry.get(candidate_id)
                .or_else(|| self.unified_registry.find_by_name(candidate_id))
                .cloned();

            let candidate_name = schema.as_ref()
                .map(|s| s.name.clone())
                .unwrap_or_else(|| candidate_id.clone());

            let is_remote = schema.as_ref().map(|s| s.is_remote()).unwrap_or(false);

            let result = if is_remote {
                let schema = schema.unwrap();
                self.remote_backend.generate(
                    &schema,
                    &req.prompt,
                    req.context.as_deref(),
                    req.params.temperature,
                    req.params.max_tokens,
                ).await
            } else {
                match self.ensure_backend(&candidate_name).await {
                    Ok(()) => {
                        let mut write = self.backend.write().await;
                        let backend = write.as_mut().unwrap();
                        tasks::generate::generate(backend, req.clone()).await
                    }
                    Err(e) => Err(e),
                }
            };

            match result {
                Ok(text) => {
                    let latency_ms = start.elapsed().as_millis() as u64;
                    let estimated_tokens = text.split_whitespace().count();
                    let mut tracker = self.outcome_tracker.write().await;
                    tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
                    used_model_name = candidate_name;
                    return Ok(InferenceResult {
                        text,
                        trace_id,
                        model_used: used_model_name,
                        latency_ms,
                    });
                }
                Err(e) => {
                    debug!(model = %candidate_name, error = %e, "model failed, trying fallback");
                    // Record failure for this model so the router learns
                    {
                        let mut tracker = self.outcome_tracker.write().await;
                        let fail_trace = tracker.record_start(candidate_id, decision.task, "fallback");
                        tracker.record_failure(&fail_trace, &e.to_string());
                    }
                    // Reset backend so next model can load
                    {
                        let mut write = self.backend.write().await;
                        *write = None;
                    }
                    last_error = Some(e);
                }
            }
        }

        // All models failed
        let e = last_error.unwrap_or(InferenceError::InferenceFailed("no models available".into()));
        let mut tracker = self.outcome_tracker.write().await;
        tracker.record_failure(&trace_id, &e.to_string());
        Err(e)
    }

    /// Generate text from a prompt (legacy API, no outcome tracking).
    /// When `req.model` is None, uses intelligent routing based on prompt complexity.
    pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
        let model = match req.model.clone() {
            Some(m) => m,
            None => {
                let decision = self.router.route_generate(&req.prompt, &self.registry);
                debug!(
                    model = %decision.model,
                    reason = %decision.reason,
                    "auto-routed generate request"
                );
                decision.model
            }
        };
        self.ensure_backend(&model).await?;

        let mut write = self.backend.write().await;
        let backend = write.as_mut().unwrap();
        tasks::generate::generate(backend, req).await
    }

    /// Generate embeddings for text using the dedicated embedding model.
    /// Uses Qwen3-Embedding with proper last-token hidden state extraction.
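    ///
    /// A minimal sketch (the `EmbedRequest` construction below is illustrative;
    /// see [`tasks`] for its actual fields, and `engine` is assumed in scope):
    ///
    /// ```ignore
    /// let req = EmbedRequest {
    ///     texts: vec!["the registry holds local and remote models".to_string()],
    ///     is_query: false,   // document embedding: no instruction prefix applied
    ///     instruction: None,
    /// };
    /// let vectors = engine.embed(req).await?;
    /// assert_eq!(vectors.len(), 1);
    /// ```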
    pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
        self.ensure_embedding_backend().await?;

        let mut write = self.embedding_backend.write().await;
        let backend = write.as_mut().unwrap();

        let instruction = req.instruction.as_deref()
            .unwrap_or("Retrieve relevant memory facts");

        let mut results = Vec::with_capacity(req.texts.len());
        for text in &req.texts {
            let embedding = if req.is_query {
                backend.embed_query(text, instruction)?
            } else {
                backend.embed_one(text)?
            };
            results.push(embedding);
        }
        Ok(results)
    }

    /// Classify text against candidate labels.
    /// When `req.model` is None, routes to the smallest available model.
    pub async fn classify(&self, req: ClassifyRequest) -> Result<Vec<ClassifyResult>, InferenceError> {
        let model = match req.model.clone() {
            Some(m) => m,
            None => {
                let m = self.router.route_small(&self.registry);
                debug!(model = %m, "auto-routed classify request");
                m
            }
        };
        self.ensure_backend(&model).await?;

        let mut write = self.backend.write().await;
        let backend = write.as_mut().unwrap();
        tasks::classify::classify(backend, req).await
    }

    /// List all known models and their status (new registry).
    pub fn list_models_unified(&self) -> Vec<ModelInfo> {
        self.unified_registry.list().iter().map(|m| ModelInfo::from(*m)).collect()
    }

    /// List all known models and their download status (legacy).
    pub fn list_models(&self) -> Vec<models::ModelInfo> {
        self.registry.list_models()
    }

    /// Download a model if not already present.
    pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
        self.registry.ensure_model(name).await
    }

    /// Remove a downloaded model.
    pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
        self.registry.remove_model(name)
    }

    /// Register a model in the unified registry.
    pub fn register_model(&mut self, schema: ModelSchema) {
        self.unified_registry.register(schema);
    }

    /// Get outcome tracker for external use (e.g., memgine integration).
    pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
        self.outcome_tracker.clone()
    }

    /// Export model performance profiles for persistence.
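    ///
    /// A minimal persistence sketch (the `save_profiles` / `load_profiles` helpers
    /// are hypothetical; how profiles are serialized is left to the caller):
    ///
    /// ```ignore
    /// // On shutdown: snapshot what the router has learned.
    /// let profiles = engine.export_profiles().await;
    /// save_profiles(&profiles);
    ///
    /// // On the next start: restore the profiles before serving requests.
    /// engine.import_profiles(load_profiles()).await;
    /// ```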
    pub async fn export_profiles(&self) -> Vec<ModelProfile> {
        let tracker = self.outcome_tracker.read().await;
        tracker.export_profiles()
    }

    /// Import model performance profiles (from persistence).
    pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
        let mut tracker = self.outcome_tracker.write().await;
        tracker.import_profiles(profiles);
    }
}