Skip to main content

car_inference/
schema.rs

//! Model schema — declarative metadata for models, analogous to `ToolSchema` for tools.
//!
//! Every model (local GGUF, remote API, Ollama) is described by a `ModelSchema`
//! that declares its identity, capabilities, constraints, cost, and source.
//! The router uses this schema for initial routing; observed outcomes refine it.

7use serde::{Deserialize, Serialize};
8
9/// What a model can do.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum ModelCapability {
13    /// Text completion / chat generation
14    Generate,
15    /// Vector embeddings
16    Embed,
17    /// Cross-encoder relevance scoring (query + document → relevance
18    /// score). Qwen3-Reranker is the canonical local implementation.
19    Rerank,
20    /// Label assignment / classification
21    Classify,
22    /// Code generation, repair, refactoring
23    Code,
24    /// Chain-of-thought, planning, analysis
25    Reasoning,
26    /// Text condensation
27    Summarize,
28    /// Function/tool calling
29    ToolUse,
30    /// Multiple tool calls in a single response (parallel tool execution)
31    MultiToolCall,
32    /// Vision / image understanding
33    Vision,
34    /// Video understanding (multi-frame sampling + temporal tokens).
35    /// Distinct from `Vision` so routing can prefer video-trained
36    /// models when the caller attaches a video content block.
37    VideoUnderstanding,
38    /// Audio understanding (speech + non-speech audio as an input to
39    /// a chat/reasoning model). Distinct from `SpeechToText` which is
40    /// the transcription-only task. Gemma 4 E2B/E4B and Gemini do
41    /// this; Qwen2.5-VL does not.
42    AudioUnderstanding,
43    /// Visual grounding — structured object-localization output
44    /// (bounding boxes keyed to object labels) in addition to text.
45    Grounding,
46    /// Speech recognition / transcription
47    SpeechToText,
48    /// Speech synthesis / text-to-speech
49    TextToSpeech,
50    /// Image generation
51    ImageGeneration,
52    /// Video generation
53    VideoGeneration,
54}
55
56/// How to access the model.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58#[serde(tag = "type", rename_all = "snake_case")]
59pub enum ModelSource {
60    /// Local GGUF file via Candle backend.
61    Local {
62        hf_repo: String,
63        hf_filename: String,
64        tokenizer_repo: String,
65    },
66    /// Remote API endpoint (OpenAI-compatible, Anthropic, etc.)
67    RemoteApi {
68        endpoint: String,
69        /// Environment variable name containing the API key (never the key itself).
70        /// The env var value may contain comma-separated keys for load balancing.
71        api_key_env: String,
72        /// Additional environment variable names for load balancing across multiple keys.
73        /// Each env var may also contain comma-separated keys.
74        #[serde(default)]
75        api_key_envs: Vec<String>,
76        #[serde(default)]
77        api_version: Option<String>,
78        protocol: ApiProtocol,
79    },
80    /// Ollama local server.
81    Ollama {
82        model_tag: String,
83        #[serde(default = "default_ollama_host")]
84        host: String,
85    },
86    /// Local MLX model via mlx-rs backend (Apple Silicon, safetensors format).
87    /// Models from mlx-community on HuggingFace.
88    Mlx {
89        /// HuggingFace repo (e.g., "mlx-community/Qwen3-4B-4bit").
90        hf_repo: String,
91        /// Optional specific weight filename. If None, auto-discovers safetensors files.
92        #[serde(default)]
93        hf_weight_file: Option<String>,
94    },
95    /// Local vLLM-MLX server (Apple Silicon, OpenAI-compatible API).
96    /// Routes through RemoteBackend with OpenAI protocol handler.
97    VllmMlx {
98        /// Server endpoint (e.g., "http://localhost:8000").
99        endpoint: String,
100        /// The model name as known to vLLM-MLX (e.g., "mlx-community/Qwen3-4B-4bit").
101        model_name: String,
102    },
103    /// Apple's on-device system model via the FoundationModels framework
104    /// (macOS 26+, Apple Silicon). Inference happens in-process through a
105    /// Swift shim — there is no HTTP, no API key, and no model file: the
106    /// OS owns the weights. Availability is checked at runtime via
107    /// `@available(macOS 26.0, *)`; on older macOS or non-Apple-Silicon
108    /// hosts the backend reports `UnsupportedMode` and the router falls
109    /// through to the next candidate.
110    AppleFoundationModels {
111        /// Optional Apple use-case hint passed through to
112        /// `LanguageModelSession`. Apple's framework tunes its prompt and
113        /// safety scaffolding per use case (e.g. "general", "summarize").
114        /// `None` uses the default.
115        #[serde(default)]
116        use_case: Option<String>,
117    },
118    /// Proprietary provider with custom auth and protocol.
119    ///
120    /// For vendor-specific APIs that aren't generic OpenAI-compatible endpoints.
121    /// Parslee is the first proprietary provider — custom auth (OAuth2),
122    /// custom response format, multi-provider routing built into the API.
123    Proprietary {
124        /// Provider identifier (e.g., "parslee").
125        provider: String,
126        /// Base URL for the API.
127        endpoint: String,
128        /// Auth configuration.
129        auth: ProprietaryAuth,
130        /// Custom protocol details.
131        protocol: ProprietaryProtocol,
132    },
133}
134
135/// Authentication method for proprietary providers.
136#[derive(Debug, Clone, Serialize, Deserialize)]
137#[serde(tag = "type", rename_all = "snake_case")]
138pub enum ProprietaryAuth {
139    /// OAuth2 PKCE flow (e.g., Azure AD for Parslee).
140    OAuth2Pkce {
141        authority: String,
142        client_id: String,
143        scopes: Vec<String>,
144    },
145    /// Static API key from environment variable.
146    ApiKeyEnv { env_var: String },
147    /// Bearer token from environment variable.
148    BearerTokenEnv { env_var: String },
149}
150
151/// Protocol configuration for proprietary providers.
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct ProprietaryProtocol {
154    /// Chat/completion endpoint path (appended to base URL).
155    #[serde(default = "default_chat_path")]
156    pub chat_path: String,
157    /// Content type for requests.
158    #[serde(default = "default_content_type")]
159    pub content_type: String,
160    /// Whether the API streams responses via SSE.
161    #[serde(default)]
162    pub streaming: bool,
163    /// Custom headers to include in every request.
164    #[serde(default)]
165    pub extra_headers: std::collections::HashMap<String, String>,
166}
167
168impl Default for ProprietaryProtocol {
169    fn default() -> Self {
170        Self {
171            chat_path: default_chat_path(),
172            content_type: default_content_type(),
173            streaming: false,
174            extra_headers: std::collections::HashMap::new(),
175        }
176    }
177}
178
/// Serde default for `ProprietaryProtocol::chat_path`.
fn default_chat_path() -> String {
    String::from("/chat")
}
182
/// Serde default for `ProprietaryProtocol::content_type`.
fn default_content_type() -> String {
    String::from("application/json")
}
186
/// Serde default for `ModelSource::Ollama::host`: Ollama's standard local port.
fn default_ollama_host() -> String {
    String::from("http://localhost:11434")
}
190
191#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
192#[serde(rename_all = "snake_case")]
193pub enum ApiProtocol {
194    OpenAiCompat,
195    /// OpenAI Responses API (/v1/responses) — works with all OpenAI models including codex.
196    OpenAiResponses,
197    Anthropic,
198    Google,
199    /// Azure OpenAI — uses api-key header and deployment-based URLs.
200    /// Endpoint format: {base}/openai/deployments/{model}/chat/completions?api-version={version}
201    AzureOpenAi,
202}
203
204/// Declared performance expectations. Overridden by observed data once available.
205#[derive(Debug, Clone, Default, Serialize, Deserialize)]
206pub struct PerformanceEnvelope {
207    /// Median latency in milliseconds (declared/estimated).
208    #[serde(default)]
209    pub latency_p50_ms: Option<u64>,
210    /// 99th percentile latency in milliseconds.
211    #[serde(default)]
212    pub latency_p99_ms: Option<u64>,
213    /// Tokens per second throughput.
214    #[serde(default)]
215    pub tokens_per_second: Option<f64>,
216}
217
218/// Cost model for routing optimization.
219/// Generation parameters that a model may or may not support.
220/// Models declare which params they accept. The inference layer
221/// strips unsupported params before sending to the API.
222#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
223#[serde(rename_all = "snake_case")]
224pub enum GenerateParam {
225    Temperature,
226    TopP,
227    TopK,
228    MaxTokens,
229    StopSequences,
230    FrequencyPenalty,
231    PresencePenalty,
232    Seed,
233    ResponseFormat,
234    /// Extended thinking / internal reasoning before responding.
235    ExtendedThinking,
236}
237
238/// Standard parameter set for most models.
239pub fn standard_params() -> Vec<GenerateParam> {
240    vec![
241        GenerateParam::Temperature,
242        GenerateParam::TopP,
243        GenerateParam::MaxTokens,
244        GenerateParam::StopSequences,
245        GenerateParam::FrequencyPenalty,
246        GenerateParam::PresencePenalty,
247        GenerateParam::Seed,
248    ]
249}
250
251/// Parameter set for reasoning models (no temperature, no top_p).
252pub fn reasoning_params() -> Vec<GenerateParam> {
253    vec![GenerateParam::MaxTokens, GenerateParam::StopSequences]
254}
255
256#[derive(Debug, Clone, Default, Serialize, Deserialize)]
257pub struct CostModel {
258    /// USD per 1M input tokens (remote models).
259    #[serde(default)]
260    pub input_per_mtok: Option<f64>,
261    /// USD per 1M output tokens (remote models).
262    #[serde(default)]
263    pub output_per_mtok: Option<f64>,
264    /// On-disk size in MB (local models).
265    #[serde(default)]
266    pub size_mb: Option<u64>,
267    /// RAM required during inference in MB.
268    #[serde(default)]
269    pub ram_mb: Option<u64>,
270}
271
272/// A score on a public benchmark from a published source (model card,
273/// paper, leaderboard). The schema is deliberately permissive — no enum
274/// of benchmark names — so the catalog can carry whichever benchmarks
275/// the upstream provider chose to publish, and new ones can be added
276/// without a code change. Scores are stored on a 0.0–1.0 scale (e.g.
277/// 73.5% accuracy → 0.735) so they compare cleanly across benchmarks
278/// and so `routing_ext::apply_benchmark_priors` can consume them
279/// directly when wired in later.
280#[derive(Debug, Clone, Serialize, Deserialize)]
281pub struct BenchmarkScore {
282    /// Benchmark name as published (e.g., "MMLU-Pro", "GPQA-Diamond",
283    /// "SWE-bench-Verified", "HumanEval", "MATH").
284    pub name: String,
285    /// Score on a 0.0–1.0 scale.
286    pub score: f64,
287    /// Evaluation harness or setup label (e.g., "5-shot", "0-shot CoT",
288    /// "agentic", "pass@1"). Optional but strongly recommended — the
289    /// same benchmark name can mean different things under different
290    /// harnesses.
291    #[serde(default)]
292    pub harness: Option<String>,
293    /// Where the score came from (model card URL, paper, leaderboard
294    /// snapshot). Empty when the source is the upstream provider's
295    /// announcement and a stable URL is not yet known.
296    #[serde(default)]
297    pub source_url: Option<String>,
298    /// ISO 8601 date of the score snapshot (e.g., "2025-08-12"). Lets
299    /// downstream code judge how stale a number is.
300    #[serde(default)]
301    pub measured_at: Option<String>,
302}
303
304/// The full declarative schema for a model.
305///
306/// Analogous to `ToolSchema` — describes what a model is, what it can do,
307/// and how to access it. The router uses this for constraint-based filtering
308/// and cold-start scoring before observed performance data is available.
309#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct ModelSchema {
311    /// Unique identifier: "provider/model-name:variant" (e.g., "qwen/qwen3-4b:q4_k_m").
312    pub id: String,
313    /// Human-readable display name.
314    pub name: String,
315    /// Provider (qwen, openai, anthropic, google, meta, ollama, custom).
316    pub provider: String,
317    /// Model family for grouping (qwen3, gpt-4, claude-4, llama-3).
318    pub family: String,
319    /// Semantic version or checkpoint label.
320    #[serde(default)]
321    pub version: String,
322    /// What this model can do — ordered by primary capability first.
323    pub capabilities: Vec<ModelCapability>,
324    /// Context window in tokens.
325    pub context_length: usize,
326    /// Parameter count as human-readable string (e.g., "4B", "30B (3B active)").
327    #[serde(default)]
328    pub param_count: String,
329    /// Quantization (Q4_K_M, Q8_0, F16, none).
330    #[serde(default)]
331    pub quantization: Option<String>,
332    /// Declared performance envelope (initial estimate, overridden by observed data).
333    #[serde(default)]
334    pub performance: PerformanceEnvelope,
335    /// Cost structure.
336    #[serde(default)]
337    pub cost: CostModel,
338    /// How to access this model.
339    pub source: ModelSource,
340    /// Free-form tags for filtering (e.g., "fast", "multilingual", "moe").
341    #[serde(default)]
342    pub tags: Vec<String>,
343    /// Supported generation parameters. The inference layer strips any parameter
344    /// not in this set before sending to the API. Empty = all supported.
345    #[serde(default)]
346    pub supported_params: Vec<GenerateParam>,
347    /// Public benchmark scores as published by the model provider or
348    /// reproduced on a public leaderboard (MMLU-Pro, GPQA-Diamond,
349    /// SWE-bench, HumanEval, etc.). The built-in catalog ships this
350    /// empty — population is a curation step, not a code change. See
351    /// `BenchmarkScore` for the field shape and the 0.0–1.0 scoring
352    /// convention.
353    #[serde(default)]
354    pub public_benchmarks: Vec<BenchmarkScore>,
355    /// Whether this model is currently available (downloaded / reachable).
356    /// Not serialized — computed at runtime.
357    #[serde(skip)]
358    pub available: bool,
359}
360
361impl ModelSchema {
362    /// Check if this model has a given capability.
363    pub fn has_capability(&self, cap: ModelCapability) -> bool {
364        self.capabilities.contains(&cap)
365    }
366
367    /// Check if this model is local (runs on-device).
368    pub fn is_local(&self) -> bool {
369        matches!(
370            self.source,
371            ModelSource::Local { .. }
372                | ModelSource::Mlx { .. }
373                | ModelSource::VllmMlx { .. }
374                | ModelSource::AppleFoundationModels { .. }
375        )
376    }
377
378    /// Check if this model uses the MLX backend.
379    pub fn is_mlx(&self) -> bool {
380        matches!(self.source, ModelSource::Mlx { .. })
381    }
382
383    /// Check if this model routes to Apple's on-device FoundationModels
384    /// framework. True only for `ModelSource::AppleFoundationModels`;
385    /// callers must still verify runtime availability before dispatch
386    /// (the schema can describe the model on any host, but execution
387    /// requires macOS 26+ on Apple Silicon).
388    pub fn is_foundation_models(&self) -> bool {
389        matches!(self.source, ModelSource::AppleFoundationModels { .. })
390    }
391
392    /// Check if this model uses vLLM-MLX backend.
393    pub fn is_vllm_mlx(&self) -> bool {
394        matches!(self.source, ModelSource::VllmMlx { .. })
395    }
396
397    /// Check if this model is remote (requires API call).
398    pub fn is_remote(&self) -> bool {
399        matches!(
400            self.source,
401            ModelSource::RemoteApi { .. } | ModelSource::Proprietary { .. }
402        )
403    }
404
405    /// Collect all API key env var names for this model (primary + extras).
406    /// Returns empty vec for non-remote models.
407    pub fn all_api_key_envs(&self) -> Vec<String> {
408        match &self.source {
409            ModelSource::RemoteApi {
410                api_key_env,
411                api_key_envs,
412                ..
413            } => {
414                let mut all = vec![api_key_env.clone()];
415                all.extend(api_key_envs.iter().cloned());
416                all
417            }
418            ModelSource::Proprietary {
419                auth: ProprietaryAuth::ApiKeyEnv { env_var },
420                ..
421            }
422            | ModelSource::Proprietary {
423                auth: ProprietaryAuth::BearerTokenEnv { env_var },
424                ..
425            } => vec![env_var.clone()],
426            _ => vec![],
427        }
428    }
429
430    /// Get the size in MB (from cost model or 0 if unknown).
431    pub fn size_mb(&self) -> u64 {
432        self.cost.size_mb.unwrap_or(0)
433    }
434
435    /// Get the RAM requirement in MB (from cost model, falls back to size_mb).
436    pub fn ram_mb(&self) -> u64 {
437        self.cost.ram_mb.unwrap_or_else(|| self.size_mb())
438    }
439
440    /// Estimated cost per 1K output tokens in USD. Returns 0.0 for local models.
441    pub fn cost_per_1k_output(&self) -> f64 {
442        self.cost.output_per_mtok.map(|c| c / 1000.0).unwrap_or(0.0)
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// A quantized local GGUF model carrying only declared (not observed) data.
    fn sample_local() -> ModelSchema {
        ModelSchema {
            id: "qwen/qwen3-4b:q4_k_m".into(),
            name: "Qwen3-4B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code],
            context_length: 32768,
            param_count: "4B".into(),
            quantization: Some("Q4_K_M".into()),
            performance: PerformanceEnvelope {
                tokens_per_second: Some(45.0),
                ..Default::default()
            },
            cost: CostModel {
                size_mb: Some(2500),
                ram_mb: Some(2500),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-4B-GGUF".into(),
                hf_filename: "Qwen3-4B-Q4_K_M.gguf".into(),
                tokenizer_repo: "Qwen/Qwen3-4B".into(),
            },
            tags: vec!["code".into(), "fast".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        }
    }

    /// A remote Anthropic model with per-token pricing.
    fn sample_remote() -> ModelSchema {
        ModelSchema {
            id: "anthropic/claude-sonnet-4-6:latest".into(),
            name: "Claude Sonnet 4.6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::Vision,
            ],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(2000),
                latency_p99_ms: Some(8000),
                tokens_per_second: Some(80.0),
            },
            cost: CostModel {
                input_per_mtok: Some(3.0),
                output_per_mtok: Some(15.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_key_envs: vec![],
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec!["reasoning".into(), "tool_use".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        }
    }

    /// `sample_local()` with its access source replaced, for exercising the
    /// source-classification helpers.
    fn with_source(source: ModelSource) -> ModelSchema {
        ModelSchema {
            source,
            ..sample_local()
        }
    }

    #[test]
    fn capabilities() {
        let m = sample_local();
        assert!(m.has_capability(ModelCapability::Code));
        assert!(!m.has_capability(ModelCapability::Vision));
    }

    #[test]
    fn local_vs_remote() {
        assert!(sample_local().is_local());
        assert!(!sample_local().is_remote());
        assert!(sample_remote().is_remote());
        assert!(!sample_remote().is_local());
    }

    #[test]
    fn on_device_source_predicates() {
        let apple = with_source(ModelSource::AppleFoundationModels { use_case: None });
        assert!(apple.is_local());
        assert!(apple.is_foundation_models());
        assert!(!apple.is_mlx());
        assert!(!apple.is_remote());

        let mlx = with_source(ModelSource::Mlx {
            hf_repo: "mlx-community/Qwen3-4B-4bit".into(),
            hf_weight_file: None,
        });
        assert!(mlx.is_local());
        assert!(mlx.is_mlx());
        assert!(!mlx.is_vllm_mlx());

        let vllm = with_source(ModelSource::VllmMlx {
            endpoint: "http://localhost:8000".into(),
            model_name: "mlx-community/Qwen3-4B-4bit".into(),
        });
        assert!(vllm.is_local());
        assert!(vllm.is_vllm_mlx());
        assert!(!vllm.is_mlx());
    }

    #[test]
    fn cost() {
        let local = sample_local();
        assert_eq!(local.cost_per_1k_output(), 0.0);

        let remote = sample_remote();
        assert!(remote.cost_per_1k_output() > 0.0);
        // $15 per 1M output tokens == $0.015 per 1K.
        assert!((remote.cost_per_1k_output() - 0.015).abs() < 1e-12);
    }

    #[test]
    fn ram_falls_back_to_size() {
        let mut m = sample_local();
        m.cost.ram_mb = None;
        assert_eq!(m.ram_mb(), 2500);

        m.cost.size_mb = None;
        assert_eq!(m.size_mb(), 0);
        assert_eq!(m.ram_mb(), 0);
    }

    #[test]
    fn remote_api_key_envs_primary_first() {
        assert!(sample_local().all_api_key_envs().is_empty());

        let mut remote = sample_remote();
        match &mut remote.source {
            ModelSource::RemoteApi { api_key_envs, .. } => {
                api_key_envs.push("ANTHROPIC_API_KEY_2".into())
            }
            other => panic!("sample_remote should use RemoteApi, got {:?}", other),
        }
        assert_eq!(
            remote.all_api_key_envs(),
            vec![
                "ANTHROPIC_API_KEY".to_string(),
                "ANTHROPIC_API_KEY_2".to_string()
            ]
        );
    }

    #[test]
    fn proprietary_deserializes_with_protocol_defaults() {
        let json = r#"{
            "type": "proprietary",
            "provider": "parslee",
            "endpoint": "https://example.invalid",
            "auth": { "type": "api_key_env", "env_var": "PARSLEE_API_KEY" },
            "protocol": {}
        }"#;
        let source: ModelSource = serde_json::from_str(json).unwrap();
        match &source {
            ModelSource::Proprietary { protocol, .. } => {
                assert_eq!(protocol.chat_path, "/chat");
                assert_eq!(protocol.content_type, "application/json");
                assert!(!protocol.streaming);
                assert!(protocol.extra_headers.is_empty());
            }
            other => panic!("expected Proprietary, got {:?}", other),
        }

        let m = with_source(source);
        assert!(m.is_remote());
        assert!(!m.is_local());
        assert_eq!(m.all_api_key_envs(), vec!["PARSLEE_API_KEY".to_string()]);
    }

    #[test]
    fn proprietary_oauth_has_no_key_envs() {
        let m = with_source(ModelSource::Proprietary {
            provider: "parslee".into(),
            endpoint: "https://example.invalid".into(),
            auth: ProprietaryAuth::OAuth2Pkce {
                authority: "https://login.example.invalid".into(),
                client_id: "client".into(),
                scopes: vec![],
            },
            protocol: ProprietaryProtocol::default(),
        });
        assert!(m.is_remote());
        assert!(m.all_api_key_envs().is_empty());
    }

    #[test]
    fn minimal_schema_uses_serde_defaults() {
        let json = r#"{
            "id": "ollama/llama3:latest",
            "name": "Llama 3",
            "provider": "ollama",
            "family": "llama-3",
            "capabilities": ["generate", "tool_use"],
            "context_length": 8192,
            "source": { "type": "ollama", "model_tag": "llama3" }
        }"#;
        let m: ModelSchema = serde_json::from_str(json).unwrap();
        assert!(m.has_capability(ModelCapability::ToolUse));
        assert!(m.version.is_empty());
        assert!(m.param_count.is_empty());
        assert!(m.quantization.is_none());
        assert!(m.tags.is_empty());
        assert!(m.supported_params.is_empty());
        assert!(m.public_benchmarks.is_empty());
        assert!(m.cost.size_mb.is_none());
        assert!(!m.available);
        match &m.source {
            ModelSource::Ollama { model_tag, host } => {
                assert_eq!(model_tag, "llama3");
                assert_eq!(host, "http://localhost:11434");
            }
            other => panic!("expected Ollama, got {:?}", other),
        }
    }

    #[test]
    fn benchmark_optional_fields_default_to_none() {
        let b: BenchmarkScore =
            serde_json::from_str(r#"{"name":"MMLU-Pro","score":0.735}"#).unwrap();
        assert_eq!(b.name, "MMLU-Pro");
        assert!((b.score - 0.735).abs() < 1e-12);
        assert!(b.harness.is_none());
        assert!(b.source_url.is_none());
        assert!(b.measured_at.is_none());
    }

    #[test]
    fn serde_roundtrip() {
        let local = sample_local();
        let json = serde_json::to_string(&local).unwrap();
        let parsed: ModelSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, local.id);
        assert_eq!(parsed.capabilities, local.capabilities);

        let remote = sample_remote();
        let json = serde_json::to_string(&remote).unwrap();
        let parsed: ModelSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, remote.id);
        // available is skip-serialized, defaults to false
        assert!(!parsed.available);
    }
}