1use serde::{Deserialize, Serialize};
8
/// A capability that a model can advertise in [`ModelSchema::capabilities`].
///
/// Because of `rename_all = "snake_case"`, each variant is (de)serialized as a
/// snake_case string; for example `ToolUse` is written as `"tool_use"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelCapability {
    /// Generation capability; wire name `"generate"`.
    Generate,
    /// Embedding capability; wire name `"embed"`.
    Embed,
    /// Classification capability; wire name `"classify"`.
    Classify,
    /// Code capability; wire name `"code"`.
    Code,
    /// Reasoning capability; wire name `"reasoning"`.
    Reasoning,
    /// Summarization capability; wire name `"summarize"`.
    Summarize,
    /// Tool-use capability; wire name `"tool_use"`.
    ToolUse,
    /// Vision capability; wire name `"vision"`.
    Vision,
}
30
/// Describes where a model comes from and how it is reached.
///
/// The enum is internally tagged: the serialized form carries a `"type"` key
/// holding the snake_case variant name (`"local"`, `"remote_api"`, `"ollama"`
/// or `"proprietary"`) next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ModelSource {
    /// A model identified by repository and file name.
    ///
    /// NOTE(review): the `hf_` prefix presumably refers to Hugging Face (the
    /// tests use `"Qwen/Qwen3-4B-GGUF"`); confirm against the downloader.
    Local {
        /// Repository that holds the model file.
        hf_repo: String,
        /// File name inside `hf_repo` (the tests use a `.gguf` file).
        hf_filename: String,
        /// Repository that holds the tokenizer; may differ from `hf_repo`.
        tokenizer_repo: String,
    },
    /// A model served by a remote API speaking one of the [`ApiProtocol`]s.
    RemoteApi {
        /// URL of the API endpoint.
        endpoint: String,
        /// Name of the environment variable expected to hold the API key
        /// (e.g. `ANTHROPIC_API_KEY` in the tests). Nothing in this file reads
        /// the variable; resolution happens elsewhere.
        api_key_env: String,
        /// Optional API version string; `None` when the key is absent.
        #[serde(default)]
        api_version: Option<String>,
        /// Wire protocol used to talk to `endpoint`.
        protocol: ApiProtocol,
    },
    /// A model served by an Ollama instance.
    Ollama {
        /// Model tag to request from the Ollama host.
        model_tag: String,
        /// Base URL of the Ollama host; defaults to `http://localhost:11434`
        /// (see `default_ollama_host`) when the key is absent.
        #[serde(default = "default_ollama_host")]
        host: String,
    },
    /// A model behind a provider-specific endpoint with configurable
    /// authentication and request shape.
    Proprietary {
        /// Name of the provider.
        provider: String,
        /// URL of the provider endpoint.
        endpoint: String,
        /// How requests to `endpoint` are authenticated.
        auth: ProprietaryAuth,
        /// Request-shape settings. There is no `#[serde(default)]` on this
        /// field, so the key must be present on input even though every field
        /// of [`ProprietaryProtocol`] has its own default.
        protocol: ProprietaryProtocol,
    },
}
72
/// Authentication scheme for a [`ModelSource::Proprietary`] endpoint.
///
/// Internally tagged on `"type"` with snake_case variant names.
///
/// NOTE(review): serde's snake_case conversion inserts `_` before every
/// non-leading uppercase letter, so `OAuth2Pkce` is expected to use the tag
/// `"o_auth2_pkce"` rather than `"oauth2_pkce"`. Confirm against existing
/// configuration files before relying on either spelling.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProprietaryAuth {
    /// OAuth 2 authorization with PKCE, as suggested by the variant name; the
    /// flow itself is implemented outside this file.
    OAuth2Pkce {
        /// Authority (issuer) to authenticate against.
        authority: String,
        /// OAuth client identifier.
        client_id: String,
        /// Scopes to request.
        scopes: Vec<String>,
    },
    /// API key taken from the environment variable named by `env_var`
    /// (tag `"api_key_env"`). The variable is not read in this file.
    ApiKeyEnv { env_var: String },
    /// Bearer token taken from the environment variable named by `env_var`
    /// (tag `"bearer_token_env"`). The variable is not read in this file.
    BearerTokenEnv { env_var: String },
}
88
/// Request-shape settings for a [`ModelSource::Proprietary`] endpoint.
///
/// Every field has a serde default, so an empty map deserializes to the same
/// values that the `Default` implementation produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProprietaryProtocol {
    /// Path of the chat route; defaults to `"/chat"`.
    #[serde(default = "default_chat_path")]
    pub chat_path: String,
    /// Content type for requests; defaults to `"application/json"`.
    #[serde(default = "default_content_type")]
    pub content_type: String,
    /// Whether streaming is enabled; defaults to `false`.
    #[serde(default)]
    pub streaming: bool,
    /// Additional header name/value pairs; defaults to an empty map.
    #[serde(default)]
    pub extra_headers: std::collections::HashMap<String, String>,
}
105
106impl Default for ProprietaryProtocol {
107 fn default() -> Self {
108 Self {
109 chat_path: default_chat_path(),
110 content_type: default_content_type(),
111 streaming: false,
112 extra_headers: std::collections::HashMap::new(),
113 }
114 }
115}
116
/// Serde default for `ProprietaryProtocol::chat_path`: the `/chat` route.
fn default_chat_path() -> String {
    String::from("/chat")
}
120
/// Serde default for `ProprietaryProtocol::content_type`: JSON.
fn default_content_type() -> String {
    String::from("application/json")
}
124
/// Serde default for the `host` field of `ModelSource::Ollama`: port 11434 on
/// localhost over plain HTTP.
fn default_ollama_host() -> String {
    String::from("http://localhost:11434")
}
128
/// Wire protocol spoken by a [`ModelSource::RemoteApi`] endpoint.
///
/// Serialized as a snake_case string.
///
/// NOTE(review): serde's snake_case conversion splits on every non-leading
/// uppercase letter, so `OpenAiCompat` is expected to be written as
/// `"open_ai_compat"` (not `"openai_compat"`); confirm against existing
/// configuration files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApiProtocol {
    /// OpenAI-compatible protocol, as suggested by the variant name.
    OpenAiCompat,
    /// Anthropic protocol; wire name `"anthropic"`.
    Anthropic,
    /// Google protocol; wire name `"google"`.
    Google,
}
136
/// Optional performance figures for a model.
///
/// Every field is optional and is `None` when absent from the input; the
/// derived `Default` yields all `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PerformanceEnvelope {
    /// Median (p50) latency in milliseconds, per the field name.
    #[serde(default)]
    pub latency_p50_ms: Option<u64>,
    /// 99th-percentile latency in milliseconds, per the field name.
    #[serde(default)]
    pub latency_p99_ms: Option<u64>,
    /// Throughput in tokens per second.
    #[serde(default)]
    pub tokens_per_second: Option<f64>,
}
150
/// Optional monetary and resource costs of a model.
///
/// Every field is optional and is `None` when absent from the input; the
/// derived `Default` yields all `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostModel {
    /// Price per million input tokens (`mtok`); the currency is not specified
    /// in this file.
    #[serde(default)]
    pub input_per_mtok: Option<f64>,
    /// Price per million output tokens; `ModelSchema::cost_per_1k_output`
    /// divides this value by 1000.
    #[serde(default)]
    pub output_per_mtok: Option<f64>,
    /// Size in megabytes, per the field name; read by `ModelSchema::size_mb`.
    #[serde(default)]
    pub size_mb: Option<u64>,
    /// RAM requirement in megabytes, per the field name; `ModelSchema::ram_mb`
    /// falls back to the size when this is `None`.
    #[serde(default)]
    pub ram_mb: Option<u64>,
}
167
/// Full description of a single model: identity, capabilities, performance,
/// cost and where it is sourced from.
///
/// `id`, `name`, `provider`, `family`, `capabilities`, `context_length` and
/// `source` carry no serde default and are therefore required on input; every
/// other serialized field falls back to its default when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSchema {
    /// Identifier of the model. The tests use values shaped like
    /// `"qwen/qwen3-4b:q4_k_m"`; no format is enforced here.
    pub id: String,
    /// Display name, e.g. `"Qwen3-4B"`.
    pub name: String,
    /// Provider name, e.g. `"qwen"` or `"anthropic"`.
    pub provider: String,
    /// Model family, e.g. `"qwen3"`.
    pub family: String,
    /// Version string; empty when absent from the input.
    #[serde(default)]
    pub version: String,
    /// Capabilities the model advertises; queried by `has_capability`.
    pub capabilities: Vec<ModelCapability>,
    /// Context length (presumably in tokens — TODO confirm).
    pub context_length: usize,
    /// Free-form parameter count such as `"4B"`; empty when absent.
    #[serde(default)]
    pub param_count: String,
    /// Quantization label such as `"Q4_K_M"`, if any.
    #[serde(default)]
    pub quantization: Option<String>,
    /// Performance figures; all `None` when absent.
    #[serde(default)]
    pub performance: PerformanceEnvelope,
    /// Cost figures; all `None` when absent.
    #[serde(default)]
    pub cost: CostModel,
    /// Where the model comes from and how it is reached.
    pub source: ModelSource,
    /// Free-form tags; empty when absent.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Availability flag. Because of `#[serde(skip)]` it is never serialized
    /// and is always `false` right after deserialization; it is presumably
    /// filled in at runtime by code outside this file.
    #[serde(skip)]
    pub available: bool,
}
212
213impl ModelSchema {
214 pub fn has_capability(&self, cap: ModelCapability) -> bool {
216 self.capabilities.contains(&cap)
217 }
218
219 pub fn is_local(&self) -> bool {
221 matches!(self.source, ModelSource::Local { .. })
222 }
223
224 pub fn is_remote(&self) -> bool {
226 matches!(self.source, ModelSource::RemoteApi { .. })
227 }
228
229 pub fn size_mb(&self) -> u64 {
231 self.cost.size_mb.unwrap_or(0)
232 }
233
234 pub fn ram_mb(&self) -> u64 {
236 self.cost.ram_mb.unwrap_or_else(|| self.size_mb())
237 }
238
239 pub fn cost_per_1k_output(&self) -> f64 {
241 self.cost.output_per_mtok.map(|c| c / 1000.0).unwrap_or(0.0)
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a quantized model with a `Local` source, explicit size and
    /// RAM figures, and no pricing.
    fn sample_local() -> ModelSchema {
        let source = ModelSource::Local {
            hf_repo: String::from("Qwen/Qwen3-4B-GGUF"),
            hf_filename: String::from("Qwen3-4B-Q4_K_M.gguf"),
            tokenizer_repo: String::from("Qwen/Qwen3-4B"),
        };
        let performance = PerformanceEnvelope {
            tokens_per_second: Some(45.0),
            ..PerformanceEnvelope::default()
        };
        let cost = CostModel {
            size_mb: Some(2500),
            ram_mb: Some(2500),
            ..CostModel::default()
        };

        ModelSchema {
            id: String::from("qwen/qwen3-4b:q4_k_m"),
            name: String::from("Qwen3-4B"),
            provider: String::from("qwen"),
            family: String::from("qwen3"),
            version: String::from("1.0"),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code],
            context_length: 32768,
            param_count: String::from("4B"),
            quantization: Some(String::from("Q4_K_M")),
            performance,
            cost,
            source,
            tags: vec![String::from("code"), String::from("fast")],
            available: false,
        }
    }

    /// Fixture: a model with a `RemoteApi` source, full latency figures and
    /// per-token pricing.
    fn sample_remote() -> ModelSchema {
        let source = ModelSource::RemoteApi {
            endpoint: String::from("https://api.anthropic.com/v1/messages"),
            api_key_env: String::from("ANTHROPIC_API_KEY"),
            api_version: Some(String::from("2023-06-01")),
            protocol: ApiProtocol::Anthropic,
        };
        let performance = PerformanceEnvelope {
            latency_p50_ms: Some(2000),
            latency_p99_ms: Some(8000),
            tokens_per_second: Some(80.0),
        };
        let cost = CostModel {
            input_per_mtok: Some(3.0),
            output_per_mtok: Some(15.0),
            ..CostModel::default()
        };
        let capabilities = vec![
            ModelCapability::Generate,
            ModelCapability::Code,
            ModelCapability::Reasoning,
            ModelCapability::ToolUse,
            ModelCapability::Vision,
        ];

        ModelSchema {
            id: String::from("anthropic/claude-sonnet-4-6:latest"),
            name: String::from("Claude Sonnet 4.6"),
            provider: String::from("anthropic"),
            family: String::from("claude-4"),
            version: String::from("latest"),
            capabilities,
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance,
            cost,
            source,
            tags: vec![String::from("reasoning"), String::from("tool_use")],
            available: false,
        }
    }

    /// A listed capability is found; an unlisted one is not.
    #[test]
    fn capabilities() {
        let model = sample_local();
        assert!(model.has_capability(ModelCapability::Code));
        assert!(!model.has_capability(ModelCapability::Vision));
    }

    /// `is_local` and `is_remote` are mutually exclusive for both fixtures.
    #[test]
    fn local_vs_remote() {
        let local = sample_local();
        assert!(local.is_local());
        assert!(!local.is_remote());

        let remote = sample_remote();
        assert!(remote.is_remote());
        assert!(!remote.is_local());
    }

    /// A model without pricing costs nothing; a priced one costs something.
    #[test]
    fn cost() {
        let free = sample_local();
        assert_eq!(free.cost_per_1k_output(), 0.0);

        let priced = sample_remote();
        assert!(priced.cost_per_1k_output() > 0.0);
    }

    /// Both fixtures survive a JSON round trip, and the skipped `available`
    /// flag comes back as `false`.
    #[test]
    fn serde_roundtrip() {
        let original = sample_local();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ModelSchema = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, original.id);
        assert_eq!(decoded.capabilities, original.capabilities);

        let original = sample_remote();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ModelSchema = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, original.id);
        assert!(!decoded.available);
    }
}