1use serde::{Deserialize, Serialize};
14use std::collections::{HashMap, HashSet};
15
/// The organization that publishes a model, inferred from the `org/` prefix
/// of a model id (see [`Provider::from_model_id`]).
///
/// Serialized in lowercase, e.g. `OpenAI` -> `"openai"`, `AllenAI` -> `"allenai"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    OpenAI,
    Meta,
    Mistral,
    DeepSeek,
    Qwen,
    Nvidia,
    Google,
    AllenAI,
    /// Fallback for unrecognized or unprefixed ids; also the serde default.
    #[default]
    Other,
}
43
44impl Provider {
45 pub fn prefix(&self) -> &'static str {
57 match self {
58 Provider::OpenAI => "o",
59 Provider::Meta => "m",
60 Provider::Mistral => "mi",
61 Provider::DeepSeek => "d",
62 Provider::Qwen => "q",
63 Provider::Nvidia => "n",
64 Provider::Google => "g",
65 Provider::AllenAI => "a",
66 Provider::Other => "_",
67 }
68 }
69
70 pub fn from_model_id(id: &str) -> Self {
81 let prefix = id.split('/').next().unwrap_or(id);
82 let model_name = id.split('/').nth(1).unwrap_or("");
83 match prefix {
84 "openai" => Provider::OpenAI,
85 "meta-llama" => Provider::Meta,
86 "mistralai" => Provider::Mistral,
87 "deepseek" => Provider::DeepSeek,
88 "qwen" => Provider::Qwen,
89 "nvidia" => Provider::Nvidia,
90 "google" if model_name.starts_with("gemma") => Provider::Google,
92 "allenai" => Provider::AllenAI,
93 _ => Provider::Other,
94 }
95 }
96
97 pub fn name(&self) -> &'static str {
99 match self {
100 Provider::OpenAI => "OpenAI",
101 Provider::Meta => "Meta",
102 Provider::Mistral => "Mistral",
103 Provider::DeepSeek => "DeepSeek",
104 Provider::Qwen => "Qwen",
105 Provider::Nvidia => "Nvidia",
106 Provider::Google => "Google",
107 Provider::AllenAI => "Allen AI",
108 Provider::Other => "Other",
109 }
110 }
111}
112
/// Tokenizer encoding used to estimate token counts for a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum Encoding {
    /// OpenAI cl100k_base (GPT-3.5 / GPT-4 era); also the default.
    #[default]
    Cl100kBase,
    /// OpenAI o200k_base (gpt-4o, o1, o3 families).
    O200kBase,
    /// Llama-family BPE; this module also maps Mistral/Mixtral/Nemotron here.
    LlamaBpe,
    /// No known tokenizer; counts must be estimated heuristically.
    Heuristic,
}
129
130impl Encoding {
131 pub fn infer_from_id(id: &str) -> Self {
142 let id_lower = id.to_lowercase();
143
144 if id_lower.contains("gpt-4o")
146 || id_lower.contains("o1-")
147 || id_lower.contains("o3-")
148 || id_lower.contains("/o1")
149 || id_lower.contains("/o3")
150 {
151 return Encoding::O200kBase;
152 }
153
154 if id_lower.contains("gpt-3") || id_lower.contains("gpt-4") {
156 return Encoding::Cl100kBase;
157 }
158
159 if id_lower.contains("llama")
161 || id_lower.contains("mistral")
162 || id_lower.contains("mixtral")
163 || id_lower.contains("nemotron")
164 {
165 return Encoding::LlamaBpe;
166 }
167
168 Encoding::Heuristic
170 }
171
172 pub fn name(&self) -> &'static str {
174 match self {
175 Encoding::Cl100kBase => "cl100k_base",
176 Encoding::O200kBase => "o200k_base",
177 Encoding::LlamaBpe => "llama_bpe",
178 Encoding::Heuristic => "heuristic",
179 }
180 }
181}
182
/// Per-token prices. Values are stored per single token, not per million
/// (use [`Pricing::from_per_million`] for the usual quoting unit).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pricing {
    /// Price per prompt (input) token.
    pub prompt: f64,
    /// Price per completion (output) token.
    pub completion: f64,
}
191
192impl Pricing {
193 pub fn new(prompt: f64, completion: f64) -> Self {
195 Self { prompt, completion }
196 }
197
198 pub fn from_per_million(prompt_per_m: f64, completion_per_m: f64) -> Self {
200 Self {
201 prompt: prompt_per_m / 1_000_000.0,
202 completion: completion_per_m / 1_000_000.0,
203 }
204 }
205
206 pub fn calculate(&self, prompt_tokens: u64, completion_tokens: u64) -> f64 {
208 self.prompt * prompt_tokens as f64 + self.completion * completion_tokens as f64
209 }
210}
211
/// Static metadata for a single model: identity, tokenizer, context limit,
/// default request parameters, and capability flags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCard {
    /// Full model id, typically `org/model-name`.
    pub id: String,

    /// Short alias (provider prefix + compressed name); see
    /// [`ModelCard::generate_abbrev`].
    pub abbrev: String,

    /// Publishing organization, inferred from `id`.
    pub provider: Provider,

    /// Tokenizer encoding, inferred from `id`.
    pub encoding: Encoding,

    /// Maximum context window, in tokens.
    pub context_length: u32,

    /// Default request parameter values (temperature, top_p, …);
    /// see `default_params`.
    #[serde(default)]
    pub defaults: HashMap<String, serde_json::Value>,

    /// Request parameter names this model accepts; see `common_params`.
    #[serde(default)]
    pub supported_params: HashSet<String>,

    /// Per-token pricing, if known. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pricing: Option<Pricing>,

    /// Whether streaming responses are supported. Defaults to `true` when
    /// the field is absent during deserialization.
    #[serde(default = "default_true")]
    pub supports_streaming: bool,

    /// Whether tool/function calling is supported (default `false`).
    #[serde(default)]
    pub supports_tools: bool,

    /// Whether image inputs are supported (default `false`).
    #[serde(default)]
    pub supports_vision: bool,
}
256
/// Serde `default = "default_true"` helper for boolean fields that should be
/// enabled when absent (e.g. `supports_streaming`).
fn default_true() -> bool {
    true
}
260
261impl ModelCard {
262 pub fn new(id: impl Into<String>) -> Self {
264 let id = id.into();
265 let provider = Provider::from_model_id(&id);
266 let encoding = Encoding::infer_from_id(&id);
267 let abbrev = Self::generate_abbrev(&id, provider);
268
269 Self {
270 id,
271 abbrev,
272 provider,
273 encoding,
274 context_length: 128000, defaults: default_params(),
276 supported_params: common_params(),
277 pricing: None,
278 supports_streaming: true,
279 supports_tools: false,
280 supports_vision: false,
281 }
282 }
283
284 pub fn with_abbrev(id: impl Into<String>, abbrev: impl Into<String>) -> Self {
286 let id = id.into();
287 let provider = Provider::from_model_id(&id);
288 let encoding = Encoding::infer_from_id(&id);
289
290 Self {
291 id,
292 abbrev: abbrev.into(),
293 provider,
294 encoding,
295 context_length: 128000,
296 defaults: default_params(),
297 supported_params: common_params(),
298 pricing: None,
299 supports_streaming: true,
300 supports_tools: false,
301 supports_vision: false,
302 }
303 }
304
305 pub fn encoding(mut self, encoding: Encoding) -> Self {
307 self.encoding = encoding;
308 self
309 }
310
311 pub fn context_length(mut self, context_length: u32) -> Self {
313 self.context_length = context_length;
314 self
315 }
316
317 pub fn pricing(mut self, pricing: Pricing) -> Self {
319 self.pricing = Some(pricing);
320 self
321 }
322
323 pub fn with_tools(mut self) -> Self {
325 self.supports_tools = true;
326 self
327 }
328
329 pub fn with_vision(mut self) -> Self {
331 self.supports_vision = true;
332 self
333 }
334
335 pub fn generate_abbrev(id: &str, provider: Provider) -> String {
346 let prefix = provider.prefix();
347
348 let name = id.split('/').next_back().unwrap_or(id);
350
351 let short = name
353 .replace("gpt-", "g")
355 .replace("llama-", "l")
356 .replace("mistral-", "m")
357 .replace("mixtral-", "mx")
358 .replace("deepseek-", "")
359 .replace("qwen-", "q")
360 .replace("nemotron-", "n")
361 .replace("codestral-", "cod")
362 .replace("-turbo", "t")
364 .replace("-preview", "p")
365 .replace("-mini", "m")
366 .replace("-latest", "l")
367 .replace("-instruct", "i")
368 .replace("-chat", "")
369 .replace("-coder", "c")
370 .replace("-lite", "l")
371 .replace(['.', '-'], "");
373
374 format!("{prefix}{short}")
375 }
376}
377
378pub fn default_params() -> HashMap<String, serde_json::Value> {
382 let mut map = HashMap::new();
383 map.insert("temperature".into(), serde_json::json!(1.0));
384 map.insert("top_p".into(), serde_json::json!(1.0));
385 map.insert("n".into(), serde_json::json!(1));
386 map.insert("stream".into(), serde_json::json!(false));
387 map.insert("frequency_penalty".into(), serde_json::json!(0));
388 map.insert("presence_penalty".into(), serde_json::json!(0));
389 map
390}
391
/// The set of OpenAI-compatible request parameter names recognized by
/// default for freshly created model cards.
pub fn common_params() -> HashSet<String> {
    const PARAMS: &[&str] = &[
        "model",
        "messages",
        "temperature",
        "top_p",
        "n",
        "stream",
        "stop",
        "max_tokens",
        "frequency_penalty",
        "presence_penalty",
        "logit_bias",
        "tools",
        "tool_choice",
        "response_format",
        "seed",
        "user",
    ];
    PARAMS.iter().map(|p| (*p).to_string()).collect()
}
416
#[cfg(test)]
mod tests {
    use super::*;

    // Provider inference keys off the `org/` prefix; anthropic and non-gemma
    // google ids (and bare ids with no org) deliberately map to Other.
    #[test]
    fn test_provider_from_model_id() {
        assert_eq!(Provider::from_model_id("openai/gpt-4o"), Provider::OpenAI);
        assert_eq!(
            Provider::from_model_id("meta-llama/llama-3.1-70b"),
            Provider::Meta
        );
        assert_eq!(
            Provider::from_model_id("mistralai/mistral-large"),
            Provider::Mistral
        );
        assert_eq!(
            Provider::from_model_id("deepseek/deepseek-v3"),
            Provider::DeepSeek
        );
        assert_eq!(
            Provider::from_model_id("anthropic/claude-3.5-sonnet"),
            Provider::Other
        );
        assert_eq!(
            Provider::from_model_id("google/gemini-2.0-flash"),
            Provider::Other
        );
        assert_eq!(Provider::from_model_id("unknown/model"), Provider::Other);
        assert_eq!(Provider::from_model_id("gpt-4"), Provider::Other);
    }

    // Abbreviation prefixes must stay stable: generated aliases embed them.
    #[test]
    fn test_provider_prefix() {
        assert_eq!(Provider::OpenAI.prefix(), "o");
        assert_eq!(Provider::Meta.prefix(), "m");
        assert_eq!(Provider::Mistral.prefix(), "mi");
        assert_eq!(Provider::DeepSeek.prefix(), "d");
        assert_eq!(Provider::Qwen.prefix(), "q");
    }

    // gpt-4o/o1 resolve to o200k_base even though "gpt-4" would otherwise
    // match cl100k_base; unknown families fall back to Heuristic.
    #[test]
    fn test_encoding_inference() {
        assert_eq!(
            Encoding::infer_from_id("openai/gpt-4o"),
            Encoding::O200kBase
        );
        assert_eq!(
            Encoding::infer_from_id("openai/gpt-4o-mini"),
            Encoding::O200kBase
        );
        assert_eq!(Encoding::infer_from_id("openai/o1"), Encoding::O200kBase);
        assert_eq!(
            Encoding::infer_from_id("openai/gpt-4-turbo"),
            Encoding::Cl100kBase
        );
        assert_eq!(
            Encoding::infer_from_id("openai/gpt-3.5-turbo"),
            Encoding::Cl100kBase
        );
        assert_eq!(
            Encoding::infer_from_id("meta-llama/llama-3.1-70b"),
            Encoding::LlamaBpe
        );
        assert_eq!(
            Encoding::infer_from_id("mistralai/mistral-large"),
            Encoding::LlamaBpe
        );
        assert_eq!(
            Encoding::infer_from_id("qwen/qwen-2.5-72b"),
            Encoding::Heuristic
        );
    }

    // Pins the token-replacement rules in generate_abbrev (prefix + collapsed
    // name with separators stripped).
    #[test]
    fn test_abbreviation_generation() {
        assert_eq!(
            ModelCard::generate_abbrev("openai/gpt-4o", Provider::OpenAI),
            "og4o"
        );
        assert_eq!(
            ModelCard::generate_abbrev("openai/gpt-4o-mini", Provider::OpenAI),
            "og4om"
        );
        assert_eq!(
            ModelCard::generate_abbrev("openai/gpt-4-turbo", Provider::OpenAI),
            "og4t"
        );
        assert_eq!(
            ModelCard::generate_abbrev("openai/o1", Provider::OpenAI),
            "oo1"
        );

        assert_eq!(
            ModelCard::generate_abbrev("meta-llama/llama-3.1-405b", Provider::Meta),
            "ml31405b"
        );

        assert_eq!(
            ModelCard::generate_abbrev("deepseek/deepseek-v3", Provider::DeepSeek),
            "dv3"
        );
    }

    // ModelCard::new should infer provider, encoding, and abbrev from the id.
    #[test]
    fn test_model_card_creation() {
        let card = ModelCard::new("openai/gpt-4o");
        assert_eq!(card.id, "openai/gpt-4o");
        assert_eq!(card.abbrev, "og4o");
        assert_eq!(card.provider, Provider::OpenAI);
        assert_eq!(card.encoding, Encoding::O200kBase);
    }

    // $2.50/M prompt + $10/M completion: 1000 prompt + 500 completion tokens
    // = 0.0025 + 0.005 = 0.0075 (compared with a float tolerance).
    #[test]
    fn test_pricing_calculation() {
        let pricing = Pricing::from_per_million(2.50, 10.0);
        let cost = pricing.calculate(1000, 500);
        assert!((cost - 0.0075).abs() < 0.0001); }
}