pub mod adaptive_router;
pub mod backend;
pub mod hardware;
pub mod models;
pub mod outcome;
pub mod registry;
pub mod remote;
pub mod router;
pub mod schema;
pub mod service;
pub mod tasks;

use std::sync::Arc;
use std::time::Instant;

use thiserror::Error;
use tokio::sync::RwLock;
use tracing::debug;

pub use adaptive_router::{AdaptiveRouter, AdaptiveRoutingDecision, RoutingConfig, RoutingStrategy};
pub use outcome::{
    CodeOutcome, InferenceOutcome, InferenceTask, InferredOutcome, ModelProfile, OutcomeTracker,
};
pub use registry::{ModelFilter, ModelInfo, UnifiedRegistry};
pub use remote::RemoteBackend;
pub use schema::{
    ApiProtocol, CostModel, ModelCapability, ModelSchema, ModelSource, PerformanceEnvelope,
};

pub use adaptive_router::TaskComplexity;
pub use backend::CandleBackend;
pub use backend::EmbeddingBackend;
pub use hardware::HardwareInfo;
pub use models::{ModelRegistry, ModelRole};
pub use router::{ModelRouter, RoutingDecision};
pub use tasks::{ClassifyRequest, ClassifyResult, EmbedRequest, GenerateParams, GenerateRequest};

/// Errors produced by model management and inference operations.
#[derive(Error, Debug)]
pub enum InferenceError {
    #[error("model not found: {0}")]
    ModelNotFound(String),

    #[error("model download failed: {0}")]
    DownloadFailed(String),

    #[error("inference failed: {0}")]
    InferenceFailed(String),

    #[error("tokenization error: {0}")]
    TokenizationError(String),

    #[error("device error: {0}")]
    DeviceError(String),

    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}

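/// Compute device used for local inference.
///
/// A minimal selection sketch (illustrative only; availability is governed by
/// this crate's `metal`/`cuda` cargo features):
///
/// ```ignore
/// let device = Device::auto(); // Metal > CUDA(0) > CPU, by enabled feature
/// let pinned = Device::Cuda(1); // or pin a specific CUDA ordinal
/// ```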
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Metal,
    Cuda(usize),
}

impl Device {
    /// Selects the best available device from the enabled cargo features:
    /// Metal first, then CUDA (device 0), falling back to CPU.
    pub fn auto() -> Self {
        #[cfg(feature = "metal")]
        {
            return Device::Metal;
        }
        // Only compiled when Metal is absent, to avoid an unreachable-code warning
        // when both accelerator features are enabled.
        #[cfg(all(feature = "cuda", not(feature = "metal")))]
        {
            return Device::Cuda(0);
        }
        #[cfg(not(any(feature = "metal", feature = "cuda")))]
        {
            Device::Cpu
        }
    }
}

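/// Configuration for the inference engine.
///
/// A construction sketch (the overridden value is hypothetical; defaults come
/// from `HardwareInfo::detect()` and `~/.car/models`):
///
/// ```ignore
/// let config = InferenceConfig {
///     device: Some(Device::Cpu), // force CPU instead of auto-detection
///     ..InferenceConfig::default()
/// };
/// ```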
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Directory where model weights are cached.
    pub models_dir: std::path::PathBuf,
    /// Explicit device override; `None` means use `Device::auto()`.
    pub device: Option<Device>,
    /// Default model for text generation.
    pub generation_model: String,
    /// Default model for embeddings.
    pub embedding_model: String,
    /// Default model for classification.
    pub classification_model: String,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        let models_dir = dirs_next()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join(".car")
            .join("models");

        let hw = HardwareInfo::detect();

        Self {
            models_dir,
            device: None,
            generation_model: hw.recommended_model,
            embedding_model: "Qwen3-Embedding-0.6B".to_string(),
            classification_model: "Qwen3-0.6B".to_string(),
        }
    }
}

/// Home directory read from `$HOME`, if set (no Windows fallback).
fn dirs_next() -> Option<std::path::PathBuf> {
    std::env::var("HOME")
        .ok()
        .map(std::path::PathBuf::from)
}

/// Output of a tracked generation call.
#[derive(Debug, Clone)]
pub struct InferenceResult {
    /// Generated text.
    pub text: String,
    /// Trace id assigned by the outcome tracker for this call.
    pub trace_id: String,
    /// Name of the model that produced the text.
    pub model_used: String,
    /// End-to-end latency in milliseconds.
    pub latency_ms: u64,
}

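/// Unified entry point for local and remote inference.
///
/// A minimal usage sketch (assumes a tokio runtime and that the named model
/// is downloadable; not from the original docs):
///
/// ```ignore
/// let engine = InferenceEngine::new(InferenceConfig::default());
/// let path = engine.pull_model("Qwen3-0.6B").await?;
/// ```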
pub struct InferenceEngine {
    /// Engine configuration.
    pub config: InferenceConfig,
    /// Registry of all known model schemas, local and remote.
    pub unified_registry: UnifiedRegistry,
    /// Outcome-aware router used by `generate_tracked`.
    pub adaptive_router: AdaptiveRouter,
    /// Shared tracker recording per-call outcomes.
    pub outcome_tracker: Arc<RwLock<OutcomeTracker>>,
    remote_backend: RemoteBackend,
    /// Registry managing locally downloaded models.
    pub registry: models::ModelRegistry,
    /// Hardware-based router used by the untracked `generate` path.
    pub router: ModelRouter,
    backend: Arc<RwLock<Option<CandleBackend>>>,
    embedding_backend: Arc<RwLock<Option<EmbeddingBackend>>>,
}

impl InferenceEngine {
    /// Builds an engine from `config`; backends are loaded lazily on first use.
    pub fn new(config: InferenceConfig) -> Self {
        let registry = models::ModelRegistry::new(config.models_dir.clone());
        let hw = HardwareInfo::detect();
        let router = ModelRouter::new(hw.clone());
        let unified_registry = UnifiedRegistry::new(config.models_dir.clone());
        let adaptive_router = AdaptiveRouter::with_default_config(hw);
        let outcome_tracker = Arc::new(RwLock::new(OutcomeTracker::new()));

        Self {
            config,
            unified_registry,
            adaptive_router,
            outcome_tracker,
            remote_backend: RemoteBackend::new(),
            registry,
            router,
            backend: Arc::new(RwLock::new(None)),
            embedding_backend: Arc::new(RwLock::new(None)),
        }
    }

    /// Lazily loads the generation backend, double-checking under the write
    /// lock so concurrent callers only load once. Note that a single backend
    /// is resident at a time: if one is already loaded it is reused even when
    /// `model_name` differs.
    async fn ensure_backend(&self, model_name: &str) -> Result<(), InferenceError> {
        // Fast path: a backend is already loaded.
        let read = self.backend.read().await;
        if read.is_some() {
            return Ok(());
        }
        drop(read);

        // Slow path: re-check under the write lock before loading.
        let mut write = self.backend.write().await;
        if write.is_some() {
            return Ok(());
        }

        let model_path = self.registry.ensure_model(model_name).await?;
        let device = self.config.device.unwrap_or_else(Device::auto);
        let backend = CandleBackend::load(&model_path, device)?;
        *write = Some(backend);
        Ok(())
    }

    /// Lazily loads the embedding backend using the same double-checked
    /// locking pattern as `ensure_backend`.
    async fn ensure_embedding_backend(&self) -> Result<(), InferenceError> {
        let read = self.embedding_backend.read().await;
        if read.is_some() {
            return Ok(());
        }
        drop(read);

        let mut write = self.embedding_backend.write().await;
        if write.is_some() {
            return Ok(());
        }

        let model_path = self.registry.ensure_model(&self.config.embedding_model).await?;
        let device = self.config.device.unwrap_or_else(Device::auto);
        let backend = EmbeddingBackend::load(&model_path, device)?;
        *write = Some(backend);
        Ok(())
    }

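    /// Routes a prompt with the adaptive router, which consults recorded
    /// outcomes in addition to hardware limits.
    ///
    /// A sketch of inspecting a decision (the fields are the ones defined on
    /// `AdaptiveRoutingDecision`):
    ///
    /// ```ignore
    /// let decision = engine.route_adaptive("refactor this function").await;
    /// println!("{} because {}", decision.model_name, decision.reason);
    /// ```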
    pub async fn route_adaptive(&self, prompt: &str) -> AdaptiveRoutingDecision {
        let tracker = self.outcome_tracker.read().await;
        self.adaptive_router.route(prompt, &self.unified_registry, &tracker)
    }

    /// Routes a prompt with the hardware-based router only (no outcome data).
    pub fn route(&self, prompt: &str) -> RoutingDecision {
        self.router.route_generate(prompt, &self.registry)
    }

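    /// Generates text with adaptive routing, fallback, and outcome tracking.
    /// The routed model is tried first, then each fallback in order, and every
    /// attempt is recorded in the `OutcomeTracker`.
    ///
    /// A call sketch (assumes `GenerateRequest` implements `Default`, which is
    /// not shown in this file):
    ///
    /// ```ignore
    /// let req = GenerateRequest {
    ///     prompt: "Summarize the borrow checker".into(),
    ///     model: None, // let the adaptive router choose
    ///     ..Default::default()
    /// };
    /// let out = engine.generate_tracked(req).await?;
    /// println!("{} via {} in {} ms", out.text, out.model_used, out.latency_ms);
    /// ```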
    pub async fn generate_tracked(&self, req: GenerateRequest) -> Result<InferenceResult, InferenceError> {
        let start = Instant::now();

        // Honor an explicit model if given, otherwise ask the adaptive router.
        let tracker_read = self.outcome_tracker.read().await;
        let decision = match req.model.clone() {
            Some(m) => AdaptiveRoutingDecision {
                model_id: m.clone(),
                model_name: m,
                task: InferenceTask::Generate,
                complexity: TaskComplexity::assess(&req.prompt),
                reason: "explicit model".into(),
                strategy: RoutingStrategy::Explicit,
                predicted_quality: 0.5,
                fallbacks: vec![],
            },
            None => self.adaptive_router.route(&req.prompt, &self.unified_registry, &tracker_read),
        };
        drop(tracker_read);

        let trace_id = {
            let mut tracker = self.outcome_tracker.write().await;
            tracker.record_start(&decision.model_id, decision.task, &decision.reason)
        };

        debug!(
            model = %decision.model_name,
            strategy = ?decision.strategy,
            reason = %decision.reason,
            trace = %trace_id,
            "adaptive-routed generate request"
        );

        // Try the chosen model first, then each fallback in order.
        let mut models_to_try = vec![decision.model_id.clone()];
        models_to_try.extend(decision.fallbacks.iter().cloned());

        let mut last_error = None;

        for candidate_id in &models_to_try {
            let schema = self.unified_registry.get(candidate_id)
                .or_else(|| self.unified_registry.find_by_name(candidate_id))
                .cloned();

            let candidate_name = schema.as_ref()
                .map(|s| s.name.clone())
                .unwrap_or_else(|| candidate_id.clone());

            let is_remote = schema.as_ref().map(|s| s.is_remote()).unwrap_or(false);

            let result = if is_remote {
                let schema = schema.unwrap();
                self.remote_backend.generate(
                    &schema,
                    &req.prompt,
                    req.context.as_deref(),
                    req.params.temperature,
                    req.params.max_tokens,
                ).await
            } else {
                match self.ensure_backend(&candidate_name).await {
                    Ok(()) => {
                        let mut write = self.backend.write().await;
                        let backend = write.as_mut().unwrap();
                        tasks::generate::generate(backend, req.clone()).await
                    }
                    Err(e) => Err(e),
                }
            };

            match result {
                Ok(text) => {
                    let latency_ms = start.elapsed().as_millis() as u64;
                    // Rough token estimate; the backend does not report exact counts here.
                    let estimated_tokens = text.split_whitespace().count();
                    let mut tracker = self.outcome_tracker.write().await;
                    tracker.record_complete(&trace_id, latency_ms, 0, estimated_tokens);
                    return Ok(InferenceResult {
                        text,
                        trace_id,
                        model_used: candidate_name,
                        latency_ms,
                    });
                }
                Err(e) => {
                    debug!(model = %candidate_name, error = %e, "model failed, trying fallback");
                    // Record the failure on its own trace so the model's profile learns from it.
                    {
                        let mut tracker = self.outcome_tracker.write().await;
                        let fail_trace = tracker.record_start(candidate_id, decision.task, "fallback");
                        tracker.record_failure(&fail_trace, &e.to_string());
                    }
                    // Drop the possibly broken local backend so the next candidate reloads cleanly.
                    {
                        let mut write = self.backend.write().await;
                        *write = None;
                    }
                    last_error = Some(e);
                }
            }
        }

        // Every candidate failed: record the failure on the original trace.
        let e = last_error
            .unwrap_or_else(|| InferenceError::InferenceFailed("no models available".into()));
        let mut tracker = self.outcome_tracker.write().await;
        tracker.record_failure(&trace_id, &e.to_string());
        Err(e)
    }

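    /// Generates text via the simpler hardware-based router, without outcome
    /// tracking or fallback.
    ///
    /// Sketch (same `Default` assumption as `generate_tracked`):
    ///
    /// ```ignore
    /// let text = engine
    ///     .generate(GenerateRequest { prompt: "hello".into(), ..Default::default() })
    ///     .await?;
    /// ```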
    pub async fn generate(&self, req: GenerateRequest) -> Result<String, InferenceError> {
        let model = match req.model.clone() {
            Some(m) => m,
            None => {
                let decision = self.router.route_generate(&req.prompt, &self.registry);
                debug!(
                    model = %decision.model,
                    reason = %decision.reason,
                    "auto-routed generate request"
                );
                decision.model
            }
        };
        self.ensure_backend(&model).await?;

        let mut write = self.backend.write().await;
        let backend = write.as_mut().unwrap();
        tasks::generate::generate(backend, req).await
    }

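    /// Embeds a batch of texts with the configured embedding model. Queries
    /// are embedded with an instruction prefix, documents as-is.
    ///
    /// Sketch, assuming the three fields used below are the whole of
    /// `EmbedRequest` (not confirmed in this file):
    ///
    /// ```ignore
    /// let req = EmbedRequest {
    ///     texts: vec!["rust ownership".into()],
    ///     is_query: true,    // apply the query instruction prefix
    ///     instruction: None, // defaults to "Retrieve relevant memory facts"
    /// };
    /// let vectors = engine.embed(req).await?;
    /// assert_eq!(vectors.len(), 1);
    /// ```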
    pub async fn embed(&self, req: EmbedRequest) -> Result<Vec<Vec<f32>>, InferenceError> {
        self.ensure_embedding_backend().await?;

        let mut write = self.embedding_backend.write().await;
        let backend = write.as_mut().unwrap();

        // Queries get an instruction prefix; documents are embedded unchanged.
        let instruction = req.instruction.as_deref()
            .unwrap_or("Retrieve relevant memory facts");

        let mut results = Vec::with_capacity(req.texts.len());
        for text in &req.texts {
            let embedding = if req.is_query {
                backend.embed_query(text, instruction)?
            } else {
                backend.embed_one(text)?
            };
            results.push(embedding);
        }
        Ok(results)
    }

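    /// Classifies input with a small model chosen by `route_small`, unless a
    /// model is named explicitly.
    ///
    /// Sketch (assumes `ClassifyRequest: Default`; its other fields are not
    /// shown in this file and are elided here):
    ///
    /// ```ignore
    /// let results = engine.classify(ClassifyRequest {
    ///     model: None, // auto-route to a small model
    ///     ..Default::default()
    /// }).await?;
    /// ```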
    pub async fn classify(&self, req: ClassifyRequest) -> Result<Vec<ClassifyResult>, InferenceError> {
        let model = match req.model.clone() {
            Some(m) => m,
            None => {
                let m = self.router.route_small(&self.registry);
                debug!(model = %m, "auto-routed classify request");
                m
            }
        };
        self.ensure_backend(&model).await?;

        let mut write = self.backend.write().await;
        let backend = write.as_mut().unwrap();
        tasks::classify::classify(backend, req).await
    }

    /// Lists every model known to the unified registry.
    pub fn list_models_unified(&self) -> Vec<ModelInfo> {
        self.unified_registry.list().iter().map(|m| ModelInfo::from(*m)).collect()
    }

    /// Lists locally managed models.
    pub fn list_models(&self) -> Vec<models::ModelInfo> {
        self.registry.list_models()
    }

    /// Downloads `name` if necessary and returns its path on disk.
    pub async fn pull_model(&self, name: &str) -> Result<std::path::PathBuf, InferenceError> {
        self.registry.ensure_model(name).await
    }

    /// Deletes a locally downloaded model.
    pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
        self.registry.remove_model(name)
    }

    /// Adds a model schema to the unified registry.
    pub fn register_model(&mut self, schema: ModelSchema) {
        self.unified_registry.register(schema);
    }

    /// Returns a handle to the shared outcome tracker.
    pub fn outcome_tracker(&self) -> Arc<RwLock<OutcomeTracker>> {
        self.outcome_tracker.clone()
    }

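    /// Snapshots the tracker's model profiles, e.g. for persistence. A
    /// round-trip sketch (the serialization layer is not part of this file):
    ///
    /// ```ignore
    /// let profiles = engine.export_profiles().await;
    /// // ... persist `profiles`, restart, then:
    /// engine.import_profiles(profiles).await;
    /// ```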
    pub async fn export_profiles(&self) -> Vec<ModelProfile> {
        let tracker = self.outcome_tracker.read().await;
        tracker.export_profiles()
    }

    /// Replaces the tracker's profiles with previously exported ones.
    pub async fn import_profiles(&self, profiles: Vec<ModelProfile>) {
        let mut tracker = self.outcome_tracker.write().await;
        tracker.import_profiles(profiles);
    }
}