car_inference/schema.rs
1//! Model schema — declarative metadata for models, analogous to ToolSchema for tools.
2//!
3//! Every model (local GGUF, remote API, Ollama) is described by a `ModelSchema`
4//! that declares identity, capabilities, constraints, cost, and source.
5//! The router uses this schema for initial routing; observed outcomes refine it.
6
7use serde::{Deserialize, Serialize};
8
/// What a model can do.
///
/// Serialized in `snake_case` (e.g. `ToolUse` ↔ `"tool_use"`,
/// `MultiToolCall` ↔ `"multi_tool_call"`), so variant names are part of the
/// catalog/config wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelCapability {
    /// Text completion / chat generation.
    Generate,
    /// Vector embeddings.
    Embed,
    /// Cross-encoder relevance scoring (query + document → relevance
    /// score). Qwen3-Reranker is the canonical local implementation.
    Rerank,
    /// Label assignment / classification.
    Classify,
    /// Code generation, repair, refactoring.
    Code,
    /// Chain-of-thought, planning, analysis.
    Reasoning,
    /// Text condensation.
    Summarize,
    /// Function/tool calling.
    ToolUse,
    /// Multiple tool calls in a single response (parallel tool execution).
    MultiToolCall,
    /// Vision / image understanding.
    Vision,
    /// Video understanding (multi-frame sampling + temporal tokens).
    /// Distinct from `Vision` so routing can prefer video-trained
    /// models when the caller attaches a video content block.
    VideoUnderstanding,
    /// Audio understanding (speech + non-speech audio as an input to
    /// a chat/reasoning model). Distinct from `SpeechToText` which is
    /// the transcription-only task. Gemma 4 E2B/E4B and Gemini do
    /// this; Qwen2.5-VL does not.
    AudioUnderstanding,
    /// Visual grounding — structured object-localization output
    /// (bounding boxes keyed to object labels) in addition to text.
    Grounding,
    /// Speech recognition / transcription.
    SpeechToText,
    /// Speech synthesis / text-to-speech.
    TextToSpeech,
    /// Image generation.
    ImageGeneration,
    /// Video generation.
    VideoGeneration,
}
55
56/// How to access the model.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58#[serde(tag = "type", rename_all = "snake_case")]
59pub enum ModelSource {
60 /// Local GGUF file via Candle backend.
61 Local {
62 hf_repo: String,
63 hf_filename: String,
64 tokenizer_repo: String,
65 },
66 /// Remote API endpoint (OpenAI-compatible, Anthropic, etc.)
67 RemoteApi {
68 endpoint: String,
69 /// Environment variable name containing the API key (never the key itself).
70 /// The env var value may contain comma-separated keys for load balancing.
71 api_key_env: String,
72 /// Additional environment variable names for load balancing across multiple keys.
73 /// Each env var may also contain comma-separated keys.
74 #[serde(default)]
75 api_key_envs: Vec<String>,
76 #[serde(default)]
77 api_version: Option<String>,
78 protocol: ApiProtocol,
79 },
80 /// Ollama local server.
81 Ollama {
82 model_tag: String,
83 #[serde(default = "default_ollama_host")]
84 host: String,
85 },
86 /// Local MLX model via mlx-rs backend (Apple Silicon, safetensors format).
87 /// Models from mlx-community on HuggingFace.
88 Mlx {
89 /// HuggingFace repo (e.g., "mlx-community/Qwen3-4B-4bit").
90 hf_repo: String,
91 /// Optional specific weight filename. If None, auto-discovers safetensors files.
92 #[serde(default)]
93 hf_weight_file: Option<String>,
94 },
95 /// Local vLLM-MLX server (Apple Silicon, OpenAI-compatible API).
96 /// Routes through RemoteBackend with OpenAI protocol handler.
97 VllmMlx {
98 /// Server endpoint (e.g., "http://localhost:8000").
99 endpoint: String,
100 /// The model name as known to vLLM-MLX (e.g., "mlx-community/Qwen3-4B-4bit").
101 model_name: String,
102 },
103 /// Apple's on-device system model via the FoundationModels framework
104 /// (macOS 26+, Apple Silicon). Inference happens in-process through a
105 /// Swift shim — there is no HTTP, no API key, and no model file: the
106 /// OS owns the weights. Availability is checked at runtime via
107 /// `@available(macOS 26.0, *)`; on older macOS or non-Apple-Silicon
108 /// hosts the backend reports `UnsupportedMode` and the router falls
109 /// through to the next candidate.
110 AppleFoundationModels {
111 /// Optional Apple use-case hint passed through to
112 /// `LanguageModelSession`. Apple's framework tunes its prompt and
113 /// safety scaffolding per use case (e.g. "general", "summarize").
114 /// `None` uses the default.
115 #[serde(default)]
116 use_case: Option<String>,
117 },
118 /// Proprietary provider with custom auth and protocol.
119 ///
120 /// For vendor-specific APIs that aren't generic OpenAI-compatible endpoints.
121 /// Parslee is the first proprietary provider — custom auth (OAuth2),
122 /// custom response format, multi-provider routing built into the API.
123 Proprietary {
124 /// Provider identifier (e.g., "parslee").
125 provider: String,
126 /// Base URL for the API.
127 endpoint: String,
128 /// Auth configuration.
129 auth: ProprietaryAuth,
130 /// Custom protocol details.
131 protocol: ProprietaryProtocol,
132 },
133}
134
/// Authentication method for proprietary providers.
///
/// Serialized with a `"type"` tag holding the `snake_case` variant name:
/// `"api_key_env"`, `"bearer_token_env"`, and — note the extra underscore
/// produced by serde's snake_case rule — `"o_auth2_pkce"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProprietaryAuth {
    /// OAuth2 PKCE flow (e.g., Azure AD for Parslee).
    OAuth2Pkce {
        /// Authority / issuer URL for the OAuth2 flow (presumably the
        /// tenant-specific login URL for Azure AD — confirm with the caller).
        authority: String,
        /// OAuth2 client (application) id.
        client_id: String,
        /// Scopes requested for the access token.
        scopes: Vec<String>,
    },
    /// Static API key from environment variable.
    ApiKeyEnv {
        /// Name of the environment variable holding the key (not the key itself).
        env_var: String,
    },
    /// Bearer token from environment variable.
    BearerTokenEnv {
        /// Name of the environment variable holding the token (not the token itself).
        env_var: String,
    },
}
150
/// Protocol configuration for proprietary providers.
///
/// Every field carries a serde default, so an empty object `{}` deserializes
/// to the same value as `ProprietaryProtocol::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProprietaryProtocol {
    /// Chat/completion endpoint path (appended to base URL). Defaults to `/chat`.
    #[serde(default = "default_chat_path")]
    pub chat_path: String,
    /// Content type for requests. Defaults to `application/json`.
    #[serde(default = "default_content_type")]
    pub content_type: String,
    /// Whether the API streams responses via SSE. Defaults to `false`.
    #[serde(default)]
    pub streaming: bool,
    /// Custom headers to include in every request. Defaults to empty.
    #[serde(default)]
    pub extra_headers: std::collections::HashMap<String, String>,
}
167
168impl Default for ProprietaryProtocol {
169 fn default() -> Self {
170 Self {
171 chat_path: default_chat_path(),
172 content_type: default_content_type(),
173 streaming: false,
174 extra_headers: std::collections::HashMap::new(),
175 }
176 }
177}
178
/// Serde default for `ProprietaryProtocol::chat_path`.
fn default_chat_path() -> String {
    String::from("/chat")
}
182
/// Serde default for `ProprietaryProtocol::content_type`.
fn default_content_type() -> String {
    "application/json".to_owned()
}
186
/// Serde default for the `host` of `ModelSource::Ollama`: Ollama's
/// conventional local listen address.
fn default_ollama_host() -> String {
    "http://localhost:11434".into()
}
190
/// Wire protocol spoken by a `ModelSource::RemoteApi` endpoint.
///
/// Serialized in `snake_case`: `"open_ai_compat"`, `"open_ai_responses"`,
/// `"anthropic"`, `"google"`, `"azure_open_ai"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApiProtocol {
    /// Generic OpenAI-compatible API.
    OpenAiCompat,
    /// OpenAI Responses API (/v1/responses) — works with all OpenAI models including codex.
    OpenAiResponses,
    /// Anthropic API.
    Anthropic,
    /// Google API.
    Google,
    /// Azure OpenAI — uses api-key header and deployment-based URLs.
    /// Endpoint format: {base}/openai/deployments/{model}/chat/completions?api-version={version}
    AzureOpenAi,
}
203
/// Declared performance expectations. Overridden by observed data once available.
///
/// All fields are optional; `None` means "not declared".
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PerformanceEnvelope {
    /// Median latency in milliseconds (declared/estimated).
    #[serde(default)]
    pub latency_p50_ms: Option<u64>,
    /// 99th percentile latency in milliseconds.
    #[serde(default)]
    pub latency_p99_ms: Option<u64>,
    /// Tokens per second throughput.
    #[serde(default)]
    pub tokens_per_second: Option<f64>,
}
217
218/// Cost model for routing optimization.
219/// Generation parameters that a model may or may not support.
220/// Models declare which params they accept. The inference layer
221/// strips unsupported params before sending to the API.
222#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
223#[serde(rename_all = "snake_case")]
224pub enum GenerateParam {
225 Temperature,
226 TopP,
227 TopK,
228 MaxTokens,
229 StopSequences,
230 FrequencyPenalty,
231 PresencePenalty,
232 Seed,
233 ResponseFormat,
234 /// Extended thinking / internal reasoning before responding.
235 ExtendedThinking,
236}
237
238/// Standard parameter set for most models.
239pub fn standard_params() -> Vec<GenerateParam> {
240 vec![
241 GenerateParam::Temperature,
242 GenerateParam::TopP,
243 GenerateParam::MaxTokens,
244 GenerateParam::StopSequences,
245 GenerateParam::FrequencyPenalty,
246 GenerateParam::PresencePenalty,
247 GenerateParam::Seed,
248 ]
249}
250
251/// Parameter set for reasoning models (no temperature, no top_p).
252pub fn reasoning_params() -> Vec<GenerateParam> {
253 vec![GenerateParam::MaxTokens, GenerateParam::StopSequences]
254}
255
256#[derive(Debug, Clone, Default, Serialize, Deserialize)]
257pub struct CostModel {
258 /// USD per 1M input tokens (remote models).
259 #[serde(default)]
260 pub input_per_mtok: Option<f64>,
261 /// USD per 1M output tokens (remote models).
262 #[serde(default)]
263 pub output_per_mtok: Option<f64>,
264 /// On-disk size in MB (local models).
265 #[serde(default)]
266 pub size_mb: Option<u64>,
267 /// RAM required during inference in MB.
268 #[serde(default)]
269 pub ram_mb: Option<u64>,
270}
271
/// A score on a public benchmark from a published source (model card,
/// paper, leaderboard). The schema is deliberately permissive — no enum
/// of benchmark names — so the catalog can carry whichever benchmarks
/// the upstream provider chose to publish, and new ones can be added
/// without a code change. Scores are stored on a 0.0–1.0 scale (e.g.
/// 73.5% accuracy → 0.735) so they compare cleanly across benchmarks
/// and so `routing_ext::apply_benchmark_priors` can consume them
/// directly when wired in later.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkScore {
    /// Benchmark name as published (e.g., "MMLU-Pro", "GPQA-Diamond",
    /// "SWE-bench-Verified", "HumanEval", "MATH").
    pub name: String,
    /// Score on a 0.0–1.0 scale. The range is a convention, not validated
    /// here — curators must normalize percentages before storing them.
    pub score: f64,
    /// Evaluation harness or setup label (e.g., "5-shot", "0-shot CoT",
    /// "agentic", "pass@1"). Optional but strongly recommended — the
    /// same benchmark name can mean different things under different
    /// harnesses.
    #[serde(default)]
    pub harness: Option<String>,
    /// Where the score came from (model card URL, paper, leaderboard
    /// snapshot). `None` when the source is the upstream provider's
    /// announcement and a stable URL is not yet known.
    #[serde(default)]
    pub source_url: Option<String>,
    /// ISO 8601 date of the score snapshot (e.g., "2025-08-12"). Lets
    /// downstream code judge how stale a number is. Stored as a plain
    /// string; the format is not validated here.
    #[serde(default)]
    pub measured_at: Option<String>,
}
303
/// The full declarative schema for a model.
///
/// Analogous to `ToolSchema` — describes what a model is, what it can do,
/// and how to access it. The router uses this for constraint-based filtering
/// and cold-start scoring before observed performance data is available.
///
/// When deserializing, only `id`, `name`, `provider`, `family`,
/// `capabilities`, `context_length`, and `source` are required; every other
/// field falls back to its default (empty string/vec, `None`, or `false`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSchema {
    /// Unique identifier: "provider/model-name:variant" (e.g., "qwen/qwen3-4b:q4_k_m").
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Provider (qwen, openai, anthropic, google, meta, ollama, custom).
    pub provider: String,
    /// Model family for grouping (qwen3, gpt-4, claude-4, llama-3).
    pub family: String,
    /// Semantic version or checkpoint label. Empty when not declared.
    #[serde(default)]
    pub version: String,
    /// What this model can do — ordered by primary capability first.
    pub capabilities: Vec<ModelCapability>,
    /// Context window in tokens.
    pub context_length: usize,
    /// Parameter count as human-readable string (e.g., "4B", "30B (3B active)").
    /// Empty when not declared (e.g. for remote models).
    #[serde(default)]
    pub param_count: String,
    /// Quantization label (e.g. Q4_K_M, Q8_0, F16, none); `None` when not declared.
    #[serde(default)]
    pub quantization: Option<String>,
    /// Declared performance envelope (initial estimate, overridden by observed data).
    #[serde(default)]
    pub performance: PerformanceEnvelope,
    /// Cost structure.
    #[serde(default)]
    pub cost: CostModel,
    /// How to access this model.
    pub source: ModelSource,
    /// Free-form tags for filtering (e.g., "fast", "multilingual", "moe").
    #[serde(default)]
    pub tags: Vec<String>,
    /// Supported generation parameters. The inference layer strips any parameter
    /// not in this set before sending to the API. Empty = all supported.
    #[serde(default)]
    pub supported_params: Vec<GenerateParam>,
    /// Public benchmark scores as published by the model provider or
    /// reproduced on a public leaderboard (MMLU-Pro, GPQA-Diamond,
    /// SWE-bench, HumanEval, etc.). The built-in catalog ships this
    /// empty — population is a curation step, not a code change. See
    /// `BenchmarkScore` for the field shape and the 0.0–1.0 scoring
    /// convention.
    #[serde(default)]
    pub public_benchmarks: Vec<BenchmarkScore>,
    /// Whether this model is currently available (downloaded / reachable).
    /// Not serialized — computed at runtime; always `false` after deserializing.
    #[serde(skip)]
    pub available: bool,
}
360
361impl ModelSchema {
362 /// Check if this model has a given capability.
363 pub fn has_capability(&self, cap: ModelCapability) -> bool {
364 self.capabilities.contains(&cap)
365 }
366
367 /// Check if this model is local (runs on-device).
368 pub fn is_local(&self) -> bool {
369 matches!(
370 self.source,
371 ModelSource::Local { .. }
372 | ModelSource::Mlx { .. }
373 | ModelSource::VllmMlx { .. }
374 | ModelSource::AppleFoundationModels { .. }
375 )
376 }
377
378 /// Check if this model uses the MLX backend.
379 pub fn is_mlx(&self) -> bool {
380 matches!(self.source, ModelSource::Mlx { .. })
381 }
382
383 /// Check if this model routes to Apple's on-device FoundationModels
384 /// framework. True only for `ModelSource::AppleFoundationModels`;
385 /// callers must still verify runtime availability before dispatch
386 /// (the schema can describe the model on any host, but execution
387 /// requires macOS 26+ on Apple Silicon).
388 pub fn is_foundation_models(&self) -> bool {
389 matches!(self.source, ModelSource::AppleFoundationModels { .. })
390 }
391
392 /// Check if this model uses vLLM-MLX backend.
393 pub fn is_vllm_mlx(&self) -> bool {
394 matches!(self.source, ModelSource::VllmMlx { .. })
395 }
396
397 /// Check if this model is remote (requires API call).
398 pub fn is_remote(&self) -> bool {
399 matches!(
400 self.source,
401 ModelSource::RemoteApi { .. } | ModelSource::Proprietary { .. }
402 )
403 }
404
405 /// Collect all API key env var names for this model (primary + extras).
406 /// Returns empty vec for non-remote models.
407 pub fn all_api_key_envs(&self) -> Vec<String> {
408 match &self.source {
409 ModelSource::RemoteApi {
410 api_key_env,
411 api_key_envs,
412 ..
413 } => {
414 let mut all = vec![api_key_env.clone()];
415 all.extend(api_key_envs.iter().cloned());
416 all
417 }
418 ModelSource::Proprietary {
419 auth: ProprietaryAuth::ApiKeyEnv { env_var },
420 ..
421 }
422 | ModelSource::Proprietary {
423 auth: ProprietaryAuth::BearerTokenEnv { env_var },
424 ..
425 } => vec![env_var.clone()],
426 _ => vec![],
427 }
428 }
429
430 /// Get the size in MB (from cost model or 0 if unknown).
431 pub fn size_mb(&self) -> u64 {
432 self.cost.size_mb.unwrap_or(0)
433 }
434
435 /// Get the RAM requirement in MB (from cost model, falls back to size_mb).
436 pub fn ram_mb(&self) -> u64 {
437 self.cost.ram_mb.unwrap_or_else(|| self.size_mb())
438 }
439
440 /// Estimated cost per 1K output tokens in USD. Returns 0.0 for local models.
441 pub fn cost_per_1k_output(&self) -> f64 {
442 self.cost.output_per_mtok.map(|c| c / 1000.0).unwrap_or(0.0)
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_local() -> ModelSchema {
        ModelSchema {
            id: "qwen/qwen3-4b:q4_k_m".into(),
            name: "Qwen3-4B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code],
            context_length: 32768,
            param_count: "4B".into(),
            quantization: Some("Q4_K_M".into()),
            performance: PerformanceEnvelope {
                tokens_per_second: Some(45.0),
                ..Default::default()
            },
            cost: CostModel {
                size_mb: Some(2500),
                ram_mb: Some(2500),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-4B-GGUF".into(),
                hf_filename: "Qwen3-4B-Q4_K_M.gguf".into(),
                tokenizer_repo: "Qwen/Qwen3-4B".into(),
            },
            tags: vec!["code".into(), "fast".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        }
    }

    fn sample_remote() -> ModelSchema {
        ModelSchema {
            id: "anthropic/claude-sonnet-4-6:latest".into(),
            name: "Claude Sonnet 4.6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::Vision,
            ],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(2000),
                latency_p99_ms: Some(8000),
                tokens_per_second: Some(80.0),
            },
            cost: CostModel {
                input_per_mtok: Some(3.0),
                output_per_mtok: Some(15.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_key_envs: vec![],
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec!["reasoning".into(), "tool_use".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        }
    }

    /// `sample_remote` re-pointed at a proprietary provider using `auth`.
    fn with_proprietary_auth(auth: ProprietaryAuth) -> ModelSchema {
        let mut m = sample_remote();
        m.source = ModelSource::Proprietary {
            provider: "parslee".into(),
            endpoint: "https://api.example.invalid".into(),
            auth,
            protocol: ProprietaryProtocol::default(),
        };
        m
    }

    #[test]
    fn capabilities() {
        let m = sample_local();
        assert!(m.has_capability(ModelCapability::Code));
        assert!(!m.has_capability(ModelCapability::Vision));
    }

    #[test]
    fn local_vs_remote() {
        assert!(sample_local().is_local());
        assert!(!sample_local().is_remote());
        assert!(sample_remote().is_remote());
        assert!(!sample_remote().is_local());

        let proprietary = with_proprietary_auth(ProprietaryAuth::ApiKeyEnv {
            env_var: "PARSLEE_KEY".into(),
        });
        assert!(proprietary.is_remote());
        assert!(!proprietary.is_local());
    }

    #[test]
    fn backend_predicates() {
        let mut m = sample_local();
        assert!(!m.is_mlx());
        assert!(!m.is_vllm_mlx());
        assert!(!m.is_foundation_models());

        m.source = ModelSource::Mlx {
            hf_repo: "mlx-community/Qwen3-4B-4bit".into(),
            hf_weight_file: None,
        };
        assert!(m.is_mlx() && m.is_local() && !m.is_remote());

        m.source = ModelSource::VllmMlx {
            endpoint: "http://localhost:8000".into(),
            model_name: "mlx-community/Qwen3-4B-4bit".into(),
        };
        assert!(m.is_vllm_mlx() && m.is_local() && !m.is_remote());

        m.source = ModelSource::AppleFoundationModels { use_case: None };
        assert!(m.is_foundation_models() && m.is_local() && !m.is_remote());
    }

    #[test]
    fn cost() {
        let local = sample_local();
        assert_eq!(local.cost_per_1k_output(), 0.0);

        let remote = sample_remote();
        // $15 per 1M output tokens -> $0.015 per 1K.
        assert!((remote.cost_per_1k_output() - 0.015).abs() < 1e-12);
    }

    #[test]
    fn size_and_ram_fallbacks() {
        let mut m = sample_local();
        assert_eq!(m.size_mb(), 2500);
        assert_eq!(m.ram_mb(), 2500);

        m.cost.ram_mb = None;
        m.cost.size_mb = Some(1200);
        assert_eq!(m.ram_mb(), 1200, "ram_mb falls back to size_mb");

        m.cost.size_mb = None;
        assert_eq!(m.size_mb(), 0);
        assert_eq!(m.ram_mb(), 0);
    }

    #[test]
    fn api_key_envs() {
        assert!(sample_local().all_api_key_envs().is_empty());
        assert_eq!(sample_remote().all_api_key_envs(), vec!["ANTHROPIC_API_KEY"]);

        let mut remote = sample_remote();
        remote.source = ModelSource::RemoteApi {
            endpoint: "https://api.anthropic.com/v1/messages".into(),
            api_key_env: "ANTHROPIC_API_KEY".into(),
            api_key_envs: vec!["ANTHROPIC_API_KEY_2".into()],
            api_version: None,
            protocol: ApiProtocol::Anthropic,
        };
        assert_eq!(
            remote.all_api_key_envs(),
            vec!["ANTHROPIC_API_KEY", "ANTHROPIC_API_KEY_2"]
        );

        let api_key = with_proprietary_auth(ProprietaryAuth::ApiKeyEnv {
            env_var: "PARSLEE_KEY".into(),
        });
        assert_eq!(api_key.all_api_key_envs(), vec!["PARSLEE_KEY"]);

        let bearer = with_proprietary_auth(ProprietaryAuth::BearerTokenEnv {
            env_var: "PARSLEE_TOKEN".into(),
        });
        assert_eq!(bearer.all_api_key_envs(), vec!["PARSLEE_TOKEN"]);

        let oauth = with_proprietary_auth(ProprietaryAuth::OAuth2Pkce {
            authority: "https://login.example.invalid".into(),
            client_id: "client".into(),
            scopes: vec!["api".into()],
        });
        assert!(oauth.all_api_key_envs().is_empty());
    }

    #[test]
    fn serde_roundtrip() {
        let local = sample_local();
        let json = serde_json::to_string(&local).unwrap();
        let parsed: ModelSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, local.id);
        assert_eq!(parsed.capabilities, local.capabilities);

        let remote = sample_remote();
        let json = serde_json::to_string(&remote).unwrap();
        let parsed: ModelSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, remote.id);
        // available is skip-serialized, defaults to false
        assert!(!parsed.available);
    }

    #[test]
    fn deserialize_minimal_schema_applies_defaults() {
        let json = r#"{
            "id": "ollama/llama3:latest",
            "name": "Llama 3",
            "provider": "ollama",
            "family": "llama-3",
            "capabilities": ["generate", "tool_use"],
            "context_length": 8192,
            "source": {"type": "ollama", "model_tag": "llama3"}
        }"#;
        let m: ModelSchema = serde_json::from_str(json).unwrap();
        assert_eq!(
            m.capabilities,
            vec![ModelCapability::Generate, ModelCapability::ToolUse]
        );
        assert!(m.version.is_empty());
        assert!(m.param_count.is_empty());
        assert!(m.quantization.is_none());
        assert!(m.tags.is_empty());
        assert!(m.supported_params.is_empty());
        assert!(m.public_benchmarks.is_empty());
        assert!(!m.available);
        assert_eq!(m.size_mb(), 0);
        assert_eq!(m.cost_per_1k_output(), 0.0);
        match &m.source {
            ModelSource::Ollama { model_tag, host } => {
                assert_eq!(model_tag, "llama3");
                assert_eq!(host, "http://localhost:11434");
            }
            other => panic!("expected Ollama source, got {:?}", other),
        }
    }

    #[test]
    fn wire_names_are_snake_case() {
        assert_eq!(
            serde_json::to_string(&ModelCapability::MultiToolCall).unwrap(),
            "\"multi_tool_call\""
        );
        assert_eq!(
            serde_json::to_string(&GenerateParam::MaxTokens).unwrap(),
            "\"max_tokens\""
        );
        assert_eq!(
            serde_json::to_string(&ApiProtocol::OpenAiCompat).unwrap(),
            "\"open_ai_compat\""
        );
    }

    #[test]
    fn proprietary_protocol_defaults_match_serde_defaults() {
        let from_json: ProprietaryProtocol = serde_json::from_str("{}").unwrap();
        let default = ProprietaryProtocol::default();
        assert_eq!(from_json.chat_path, default.chat_path);
        assert_eq!(from_json.content_type, default.content_type);
        assert_eq!(from_json.streaming, default.streaming);
        assert_eq!(from_json.extra_headers, default.extra_headers);
        assert_eq!(default.chat_path, "/chat");
        assert_eq!(default.content_type, "application/json");
        assert!(!default.streaming);
        assert!(default.extra_headers.is_empty());
    }
}