Skip to main content

car_inference/
schema.rs

1//! Model schema — declarative metadata for models, analogous to ToolSchema for tools.
2//!
3//! Every model (local GGUF, remote API, Ollama) is described by a `ModelSchema`
4//! that declares identity, capabilities, constraints, cost, and source.
5//! The router uses this schema for initial routing; observed outcomes refine it.
6
7use serde::{Deserialize, Serialize};
8
9/// What a model can do.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum ModelCapability {
13    /// Text completion / chat generation
14    Generate,
15    /// Vector embeddings
16    Embed,
17    /// Label assignment / classification
18    Classify,
19    /// Code generation, repair, refactoring
20    Code,
21    /// Chain-of-thought, planning, analysis
22    Reasoning,
23    /// Text condensation
24    Summarize,
25    /// Function/tool calling
26    ToolUse,
27    /// Vision / image understanding
28    Vision,
29}
30
31/// How to access the model.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33#[serde(tag = "type", rename_all = "snake_case")]
34pub enum ModelSource {
35    /// Local GGUF file via Candle backend.
36    Local {
37        hf_repo: String,
38        hf_filename: String,
39        tokenizer_repo: String,
40    },
41    /// Remote API endpoint (OpenAI-compatible, Anthropic, etc.)
42    RemoteApi {
43        endpoint: String,
44        /// Environment variable name containing the API key (never the key itself).
45        api_key_env: String,
46        #[serde(default)]
47        api_version: Option<String>,
48        protocol: ApiProtocol,
49    },
50    /// Ollama local server.
51    Ollama {
52        model_tag: String,
53        #[serde(default = "default_ollama_host")]
54        host: String,
55    },
56    /// Proprietary provider with custom auth and protocol.
57    ///
58    /// For vendor-specific APIs that aren't generic OpenAI-compatible endpoints.
59    /// Parslee is the first proprietary provider — custom auth (OAuth2),
60    /// custom response format, multi-provider routing built into the API.
61    Proprietary {
62        /// Provider identifier (e.g., "parslee").
63        provider: String,
64        /// Base URL for the API.
65        endpoint: String,
66        /// Auth configuration.
67        auth: ProprietaryAuth,
68        /// Custom protocol details.
69        protocol: ProprietaryProtocol,
70    },
71}
72
73/// Authentication method for proprietary providers.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75#[serde(tag = "type", rename_all = "snake_case")]
76pub enum ProprietaryAuth {
77    /// OAuth2 PKCE flow (e.g., Azure AD for Parslee).
78    OAuth2Pkce {
79        authority: String,
80        client_id: String,
81        scopes: Vec<String>,
82    },
83    /// Static API key from environment variable.
84    ApiKeyEnv { env_var: String },
85    /// Bearer token from environment variable.
86    BearerTokenEnv { env_var: String },
87}
88
89/// Protocol configuration for proprietary providers.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct ProprietaryProtocol {
92    /// Chat/completion endpoint path (appended to base URL).
93    #[serde(default = "default_chat_path")]
94    pub chat_path: String,
95    /// Content type for requests.
96    #[serde(default = "default_content_type")]
97    pub content_type: String,
98    /// Whether the API streams responses via SSE.
99    #[serde(default)]
100    pub streaming: bool,
101    /// Custom headers to include in every request.
102    #[serde(default)]
103    pub extra_headers: std::collections::HashMap<String, String>,
104}
105
106impl Default for ProprietaryProtocol {
107    fn default() -> Self {
108        Self {
109            chat_path: default_chat_path(),
110            content_type: default_content_type(),
111            streaming: false,
112            extra_headers: std::collections::HashMap::new(),
113        }
114    }
115}
116
/// Serde default for `ProprietaryProtocol::chat_path`.
fn default_chat_path() -> String {
    String::from("/chat")
}
120
/// Serde default for `ProprietaryProtocol::content_type`.
fn default_content_type() -> String {
    String::from("application/json")
}
124
/// Serde default for the `host` of an Ollama model source.
fn default_ollama_host() -> String {
    String::from("http://localhost:11434")
}
128
129#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
130#[serde(rename_all = "snake_case")]
131pub enum ApiProtocol {
132    OpenAiCompat,
133    Anthropic,
134    Google,
135}
136
137/// Declared performance expectations. Overridden by observed data once available.
138#[derive(Debug, Clone, Default, Serialize, Deserialize)]
139pub struct PerformanceEnvelope {
140    /// Median latency in milliseconds (declared/estimated).
141    #[serde(default)]
142    pub latency_p50_ms: Option<u64>,
143    /// 99th percentile latency in milliseconds.
144    #[serde(default)]
145    pub latency_p99_ms: Option<u64>,
146    /// Tokens per second throughput.
147    #[serde(default)]
148    pub tokens_per_second: Option<f64>,
149}
150
151/// Cost model for routing optimization.
152#[derive(Debug, Clone, Default, Serialize, Deserialize)]
153pub struct CostModel {
154    /// USD per 1M input tokens (remote models).
155    #[serde(default)]
156    pub input_per_mtok: Option<f64>,
157    /// USD per 1M output tokens (remote models).
158    #[serde(default)]
159    pub output_per_mtok: Option<f64>,
160    /// On-disk size in MB (local models).
161    #[serde(default)]
162    pub size_mb: Option<u64>,
163    /// RAM required during inference in MB.
164    #[serde(default)]
165    pub ram_mb: Option<u64>,
166}
167
168/// The full declarative schema for a model.
169///
170/// Analogous to `ToolSchema` — describes what a model is, what it can do,
171/// and how to access it. The router uses this for constraint-based filtering
172/// and cold-start scoring before observed performance data is available.
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct ModelSchema {
175    /// Unique identifier: "provider/model-name:variant" (e.g., "qwen/qwen3-4b:q4_k_m").
176    pub id: String,
177    /// Human-readable display name.
178    pub name: String,
179    /// Provider (qwen, openai, anthropic, google, meta, ollama, custom).
180    pub provider: String,
181    /// Model family for grouping (qwen3, gpt-4, claude-4, llama-3).
182    pub family: String,
183    /// Semantic version or checkpoint label.
184    #[serde(default)]
185    pub version: String,
186    /// What this model can do — ordered by primary capability first.
187    pub capabilities: Vec<ModelCapability>,
188    /// Context window in tokens.
189    pub context_length: usize,
190    /// Parameter count as human-readable string (e.g., "4B", "30B (3B active)").
191    #[serde(default)]
192    pub param_count: String,
193    /// Quantization (Q4_K_M, Q8_0, F16, none).
194    #[serde(default)]
195    pub quantization: Option<String>,
196    /// Declared performance envelope (initial estimate, overridden by observed data).
197    #[serde(default)]
198    pub performance: PerformanceEnvelope,
199    /// Cost structure.
200    #[serde(default)]
201    pub cost: CostModel,
202    /// How to access this model.
203    pub source: ModelSource,
204    /// Free-form tags for filtering (e.g., "fast", "multilingual", "moe").
205    #[serde(default)]
206    pub tags: Vec<String>,
207    /// Whether this model is currently available (downloaded / reachable).
208    /// Not serialized — computed at runtime.
209    #[serde(skip)]
210    pub available: bool,
211}
212
213impl ModelSchema {
214    /// Check if this model has a given capability.
215    pub fn has_capability(&self, cap: ModelCapability) -> bool {
216        self.capabilities.contains(&cap)
217    }
218
219    /// Check if this model is local (runs on-device).
220    pub fn is_local(&self) -> bool {
221        matches!(self.source, ModelSource::Local { .. })
222    }
223
224    /// Check if this model is remote (requires API call).
225    pub fn is_remote(&self) -> bool {
226        matches!(self.source, ModelSource::RemoteApi { .. })
227    }
228
229    /// Get the size in MB (from cost model or 0 if unknown).
230    pub fn size_mb(&self) -> u64 {
231        self.cost.size_mb.unwrap_or(0)
232    }
233
234    /// Get the RAM requirement in MB (from cost model, falls back to size_mb).
235    pub fn ram_mb(&self) -> u64 {
236        self.cost.ram_mb.unwrap_or_else(|| self.size_mb())
237    }
238
239    /// Estimated cost per 1K output tokens in USD. Returns 0.0 for local models.
240    pub fn cost_per_1k_output(&self) -> f64 {
241        self.cost.output_per_mtok.map(|c| c / 1000.0).unwrap_or(0.0)
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// A local Candle/GGUF model with size and RAM declared and no token price.
    fn sample_local() -> ModelSchema {
        ModelSchema {
            id: "qwen/qwen3-4b:q4_k_m".into(),
            name: "Qwen3-4B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code],
            context_length: 32768,
            param_count: "4B".into(),
            quantization: Some("Q4_K_M".into()),
            performance: PerformanceEnvelope {
                tokens_per_second: Some(45.0),
                ..Default::default()
            },
            cost: CostModel {
                size_mb: Some(2500),
                ram_mb: Some(2500),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-4B-GGUF".into(),
                hf_filename: "Qwen3-4B-Q4_K_M.gguf".into(),
                tokenizer_repo: "Qwen/Qwen3-4B".into(),
            },
            tags: vec!["code".into(), "fast".into()],
            available: false,
        }
    }

    /// A remote Anthropic model with token pricing declared.
    fn sample_remote() -> ModelSchema {
        ModelSchema {
            id: "anthropic/claude-sonnet-4-6:latest".into(),
            name: "Claude Sonnet 4.6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::Vision,
            ],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(2000),
                latency_p99_ms: Some(8000),
                tokens_per_second: Some(80.0),
            },
            cost: CostModel {
                input_per_mtok: Some(3.0),
                output_per_mtok: Some(15.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec!["reasoning".into(), "tool_use".into()],
            available: false,
        }
    }

    #[test]
    fn capabilities() {
        let m = sample_local();
        assert!(m.has_capability(ModelCapability::Code));
        assert!(!m.has_capability(ModelCapability::Vision));
    }

    #[test]
    fn local_vs_remote() {
        assert!(sample_local().is_local());
        assert!(!sample_local().is_remote());
        assert!(sample_remote().is_remote());
        assert!(!sample_remote().is_local());
    }

    #[test]
    fn cost() {
        let local = sample_local();
        assert_eq!(local.cost_per_1k_output(), 0.0);

        // $15 per 1M output tokens is $0.015 per 1K; pin the exact conversion,
        // not just its sign.
        let remote = sample_remote();
        assert!((remote.cost_per_1k_output() - 0.015).abs() < 1e-12);
    }

    #[test]
    fn ram_falls_back_to_size() {
        let mut m = sample_local();
        m.cost.ram_mb = None;
        assert_eq!(m.ram_mb(), 2500);

        m.cost.size_mb = None;
        assert_eq!(m.size_mb(), 0);
        assert_eq!(m.ram_mb(), 0);
    }

    #[test]
    fn serde_roundtrip() {
        let local = sample_local();
        let json = serde_json::to_string(&local).unwrap();
        let parsed: ModelSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, local.id);
        assert_eq!(parsed.capabilities, local.capabilities);
        assert!(parsed.is_local());

        // Set `available` to true first. With the sample's default `false`, the
        // final assertion would pass whether or not the field is skipped.
        let mut remote = sample_remote();
        remote.available = true;
        let json = serde_json::to_string(&remote).unwrap();
        let parsed: ModelSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, remote.id);
        assert!(parsed.is_remote());
        // available is skip-serialized, so it comes back as the default `false`.
        assert!(!parsed.available);
    }

    #[test]
    fn source_tag_wire_format() {
        let local = serde_json::to_value(sample_local()).unwrap();
        assert_eq!(local["source"]["type"], "local");
        assert_eq!(local["capabilities"][1], "code");

        let remote = serde_json::to_value(sample_remote()).unwrap();
        assert_eq!(remote["source"]["type"], "remote_api");
        assert_eq!(remote["source"]["protocol"], "anthropic");
    }

    #[test]
    fn proprietary_protocol_defaults() {
        let parsed: ProprietaryProtocol = serde_json::from_str("{}").unwrap();
        let expected = ProprietaryProtocol::default();
        assert_eq!(parsed.chat_path, expected.chat_path);
        assert_eq!(parsed.content_type, expected.content_type);
        assert_eq!(parsed.chat_path, "/chat");
        assert_eq!(parsed.content_type, "application/json");
        assert!(!parsed.streaming);
        assert!(parsed.extra_headers.is_empty());
    }

    #[test]
    fn ollama_host_defaults() {
        let src: ModelSource =
            serde_json::from_str(r#"{"type":"ollama","model_tag":"llama3"}"#).unwrap();
        match src {
            ModelSource::Ollama { host, .. } => assert_eq!(host, "http://localhost:11434"),
            other => panic!("expected Ollama source, got {:?}", other),
        }
    }
}