// m2m/models/card.rs

//! Model card data structures.
//!
//! This module defines the core types for model metadata:
//! - `ModelCard`: Complete model metadata including abbreviation, encoding, etc.
//! - `Provider`: LLM provider enum for models with accessible tokenizers
//! - `Encoding`: Tokenizer encoding type (cl100k_base, o200k_base, llama_bpe, etc.)
//! - `Pricing`: Token pricing information
//!
//! Note: Only models with publicly accessible tokenizers are supported.
//! This includes OpenAI (via tiktoken) and open source models (Llama, Mistral, etc.).
//! Closed tokenizer models (Anthropic, Google, X.AI, Cohere) are excluded.

use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
15
16/// LLM provider categorization
17///
18/// Only providers with publicly available tokenizers are supported.
19/// This ensures accurate token counting for M2M compression optimization.
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
21#[serde(rename_all = "lowercase")]
22pub enum Provider {
23    /// OpenAI models (GPT-5.x, GPT-4.x, o1-o4) - tokenizer available via tiktoken
24    OpenAI,
25    /// Meta Llama models (Llama 4, 3.x)
26    Meta,
27    /// Mistral AI models (Large, Ministral, Devstral, Codestral)
28    Mistral,
29    /// DeepSeek models (v3.2, R1)
30    DeepSeek,
31    /// Qwen models (3, 2.5, coder)
32    Qwen,
33    /// Nvidia models (Nemotron 3 - Llama-based)
34    Nvidia,
35    /// Google Gemma models (open source, NOT Gemini)
36    Google,
37    /// Allen AI OLMo models (fully open source)
38    AllenAI,
39    /// Other/unknown provider
40    #[default]
41    Other,
42}
43
44impl Provider {
45    /// Get the abbreviation prefix for this provider
46    ///
47    /// These prefixes are used in model abbreviations:
48    /// - `o` = OpenAI (e.g., `og4o` for gpt-4o)
49    /// - `m` = Meta (e.g., `ml3170i` for llama-3.1-70b-instruct)
50    /// - `mi` = Mistral (e.g., `mim-l` for mistral-large)
51    /// - `d` = DeepSeek (e.g., `dv3` for deepseek-v3)
52    /// - `q` = Qwen (e.g., `qq2572` for qwen-2.5-72b)
53    /// - `n` = Nvidia (e.g., `nn70` for nemotron-70b)
54    /// - `g` = Google Gemma (e.g., `gg327` for gemma-3-27b)
55    /// - `a` = Allen AI (e.g., `aolmo` for olmo)
56    pub fn prefix(&self) -> &'static str {
57        match self {
58            Provider::OpenAI => "o",
59            Provider::Meta => "m",
60            Provider::Mistral => "mi",
61            Provider::DeepSeek => "d",
62            Provider::Qwen => "q",
63            Provider::Nvidia => "n",
64            Provider::Google => "g",
65            Provider::AllenAI => "a",
66            Provider::Other => "_",
67        }
68    }
69
70    /// Parse provider from model ID prefix
71    ///
72    /// # Examples
73    /// ```
74    /// use m2m::models::Provider;
75    ///
76    /// assert_eq!(Provider::from_model_id("openai/gpt-4o"), Provider::OpenAI);
77    /// assert_eq!(Provider::from_model_id("meta-llama/llama-3.1-70b"), Provider::Meta);
78    /// assert_eq!(Provider::from_model_id("mistralai/mistral-large"), Provider::Mistral);
79    /// ```
80    pub fn from_model_id(id: &str) -> Self {
81        let prefix = id.split('/').next().unwrap_or(id);
82        let model_name = id.split('/').nth(1).unwrap_or("");
83        match prefix {
84            "openai" => Provider::OpenAI,
85            "meta-llama" => Provider::Meta,
86            "mistralai" => Provider::Mistral,
87            "deepseek" => Provider::DeepSeek,
88            "qwen" => Provider::Qwen,
89            "nvidia" => Provider::Nvidia,
90            // Google Gemma is open source, Gemini is not
91            "google" if model_name.starts_with("gemma") => Provider::Google,
92            "allenai" => Provider::AllenAI,
93            _ => Provider::Other,
94        }
95    }
96
97    /// Get provider display name
98    pub fn name(&self) -> &'static str {
99        match self {
100            Provider::OpenAI => "OpenAI",
101            Provider::Meta => "Meta",
102            Provider::Mistral => "Mistral",
103            Provider::DeepSeek => "DeepSeek",
104            Provider::Qwen => "Qwen",
105            Provider::Nvidia => "Nvidia",
106            Provider::Google => "Google",
107            Provider::AllenAI => "Allen AI",
108            Provider::Other => "Other",
109        }
110    }
111}
112
113/// Tokenizer encoding type
114///
115/// Different models use different tokenizers. The encoding type determines
116/// which tokenizer to use for accurate token counting.
117#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
118pub enum Encoding {
119    /// OpenAI cl100k_base encoding (GPT-3.5, GPT-4) - via tiktoken
120    #[default]
121    Cl100kBase,
122    /// OpenAI o200k_base encoding (GPT-4o, o1, o3) - via tiktoken
123    O200kBase,
124    /// Llama-style BPE tokenizer (Llama, Mistral, and derivatives)
125    LlamaBpe,
126    /// Heuristic fallback (~4 characters per token)
127    Heuristic,
128}
129
130impl Encoding {
131    /// Infer encoding from model ID
132    ///
133    /// # Examples
134    /// ```
135    /// use m2m::models::Encoding;
136    ///
137    /// assert_eq!(Encoding::infer_from_id("openai/gpt-4o"), Encoding::O200kBase);
138    /// assert_eq!(Encoding::infer_from_id("openai/gpt-4"), Encoding::Cl100kBase);
139    /// assert_eq!(Encoding::infer_from_id("meta-llama/llama-3.1-70b"), Encoding::LlamaBpe);
140    /// ```
141    pub fn infer_from_id(id: &str) -> Self {
142        let id_lower = id.to_lowercase();
143
144        // O200k models: GPT-4o family, o1, o3
145        if id_lower.contains("gpt-4o")
146            || id_lower.contains("o1-")
147            || id_lower.contains("o3-")
148            || id_lower.contains("/o1")
149            || id_lower.contains("/o3")
150        {
151            return Encoding::O200kBase;
152        }
153
154        // Cl100k models: GPT-3.5, GPT-4 (non-o)
155        if id_lower.contains("gpt-3") || id_lower.contains("gpt-4") {
156            return Encoding::Cl100kBase;
157        }
158
159        // Llama-based models
160        if id_lower.contains("llama")
161            || id_lower.contains("mistral")
162            || id_lower.contains("mixtral")
163            || id_lower.contains("nemotron")
164        {
165            return Encoding::LlamaBpe;
166        }
167
168        // Everything else uses heuristic
169        Encoding::Heuristic
170    }
171
172    /// Get encoding name as string
173    pub fn name(&self) -> &'static str {
174        match self {
175            Encoding::Cl100kBase => "cl100k_base",
176            Encoding::O200kBase => "o200k_base",
177            Encoding::LlamaBpe => "llama_bpe",
178            Encoding::Heuristic => "heuristic",
179        }
180    }
181}
182
183/// Token pricing information (USD per token)
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct Pricing {
186    /// Cost per prompt/input token (USD)
187    pub prompt: f64,
188    /// Cost per completion/output token (USD)
189    pub completion: f64,
190}
191
192impl Pricing {
193    /// Create new pricing
194    pub fn new(prompt: f64, completion: f64) -> Self {
195        Self { prompt, completion }
196    }
197
198    /// Create pricing from per-million token rates (common format)
199    pub fn from_per_million(prompt_per_m: f64, completion_per_m: f64) -> Self {
200        Self {
201            prompt: prompt_per_m / 1_000_000.0,
202            completion: completion_per_m / 1_000_000.0,
203        }
204    }
205
206    /// Calculate cost for given token counts
207    pub fn calculate(&self, prompt_tokens: u64, completion_tokens: u64) -> f64 {
208        self.prompt * prompt_tokens as f64 + self.completion * completion_tokens as f64
209    }
210}
211
212/// Model metadata card
213///
214/// Contains all metadata needed for compression, token counting, and optimization.
215#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct ModelCard {
217    /// Full model ID (e.g., "openai/gpt-4o")
218    pub id: String,
219
220    /// Short abbreviation for compression (e.g., "og4o")
221    pub abbrev: String,
222
223    /// Provider
224    pub provider: Provider,
225
226    /// Tokenizer encoding
227    pub encoding: Encoding,
228
229    /// Context window size (max input + output tokens)
230    pub context_length: u32,
231
232    /// Default parameter values (for removal during compression)
233    #[serde(default)]
234    pub defaults: HashMap<String, serde_json::Value>,
235
236    /// Supported parameters
237    #[serde(default)]
238    pub supported_params: HashSet<String>,
239
240    /// Pricing information (optional)
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub pricing: Option<Pricing>,
243
244    /// Whether this model supports streaming
245    #[serde(default = "default_true")]
246    pub supports_streaming: bool,
247
248    /// Whether this model supports function/tool calling
249    #[serde(default)]
250    pub supports_tools: bool,
251
252    /// Whether this model supports vision/images
253    #[serde(default)]
254    pub supports_vision: bool,
255}
256
257fn default_true() -> bool {
258    true
259}
260
261impl ModelCard {
262    /// Create a new model card with auto-generated abbreviation
263    pub fn new(id: impl Into<String>) -> Self {
264        let id = id.into();
265        let provider = Provider::from_model_id(&id);
266        let encoding = Encoding::infer_from_id(&id);
267        let abbrev = Self::generate_abbrev(&id, provider);
268
269        Self {
270            id,
271            abbrev,
272            provider,
273            encoding,
274            context_length: 128000, // Safe default
275            defaults: default_params(),
276            supported_params: common_params(),
277            pricing: None,
278            supports_streaming: true,
279            supports_tools: false,
280            supports_vision: false,
281        }
282    }
283
284    /// Create model card with explicit abbreviation
285    pub fn with_abbrev(id: impl Into<String>, abbrev: impl Into<String>) -> Self {
286        let id = id.into();
287        let provider = Provider::from_model_id(&id);
288        let encoding = Encoding::infer_from_id(&id);
289
290        Self {
291            id,
292            abbrev: abbrev.into(),
293            provider,
294            encoding,
295            context_length: 128000,
296            defaults: default_params(),
297            supported_params: common_params(),
298            pricing: None,
299            supports_streaming: true,
300            supports_tools: false,
301            supports_vision: false,
302        }
303    }
304
305    /// Builder: set encoding
306    pub fn encoding(mut self, encoding: Encoding) -> Self {
307        self.encoding = encoding;
308        self
309    }
310
311    /// Builder: set context length
312    pub fn context_length(mut self, context_length: u32) -> Self {
313        self.context_length = context_length;
314        self
315    }
316
317    /// Builder: set pricing
318    pub fn pricing(mut self, pricing: Pricing) -> Self {
319        self.pricing = Some(pricing);
320        self
321    }
322
323    /// Builder: enable tools support
324    pub fn with_tools(mut self) -> Self {
325        self.supports_tools = true;
326        self
327    }
328
329    /// Builder: enable vision support
330    pub fn with_vision(mut self) -> Self {
331        self.supports_vision = true;
332        self
333    }
334
335    /// Generate abbreviation from model ID
336    ///
337    /// The abbreviation scheme:
338    /// 1. Provider prefix (1-2 chars): o=OpenAI, m=Meta, mi=Mistral, etc.
339    /// 2. Compressed model name: remove common prefixes, compress version numbers
340    ///
341    /// Examples:
342    /// - `openai/gpt-4o` -> `og4o`
343    /// - `meta-llama/llama-3.1-405b` -> `ml31405b`
344    /// - `deepseek/deepseek-v3` -> `dv3`
345    pub fn generate_abbrev(id: &str, provider: Provider) -> String {
346        let prefix = provider.prefix();
347
348        // Extract model name part (after provider/)
349        let name = id.split('/').next_back().unwrap_or(id);
350
351        // Generate short form through a series of replacements
352        let short = name
353            // Remove common model prefixes
354            .replace("gpt-", "g")
355            .replace("llama-", "l")
356            .replace("mistral-", "m")
357            .replace("mixtral-", "mx")
358            .replace("deepseek-", "")
359            .replace("qwen-", "q")
360            .replace("nemotron-", "n")
361            .replace("codestral-", "cod")
362            // Compress version qualifiers
363            .replace("-turbo", "t")
364            .replace("-preview", "p")
365            .replace("-mini", "m")
366            .replace("-latest", "l")
367            .replace("-instruct", "i")
368            .replace("-chat", "")
369            .replace("-coder", "c")
370            .replace("-lite", "l")
371            // Remove punctuation
372            .replace(['.', '-'], "");
373
374        format!("{prefix}{short}")
375    }
376}
377
378/// Get common default parameter values
379///
380/// These are the OpenAI API defaults that can be safely removed during compression.
381pub fn default_params() -> HashMap<String, serde_json::Value> {
382    let mut map = HashMap::new();
383    map.insert("temperature".into(), serde_json::json!(1.0));
384    map.insert("top_p".into(), serde_json::json!(1.0));
385    map.insert("n".into(), serde_json::json!(1));
386    map.insert("stream".into(), serde_json::json!(false));
387    map.insert("frequency_penalty".into(), serde_json::json!(0));
388    map.insert("presence_penalty".into(), serde_json::json!(0));
389    map
390}
391
/// Get common supported parameters
///
/// Returns the baseline OpenAI-compatible request parameter names.
pub fn common_params() -> HashSet<String> {
    let names = [
        "model",
        "messages",
        "temperature",
        "top_p",
        "n",
        "stream",
        "stop",
        "max_tokens",
        "frequency_penalty",
        "presence_penalty",
        "logit_bias",
        "tools",
        "tool_choice",
        "response_format",
        "seed",
        "user",
    ];
    names.iter().map(|name| name.to_string()).collect()
}
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_from_model_id() {
        let cases = [
            ("openai/gpt-4o", Provider::OpenAI),
            ("meta-llama/llama-3.1-70b", Provider::Meta),
            ("mistralai/mistral-large", Provider::Mistral),
            ("deepseek/deepseek-v3", Provider::DeepSeek),
            // Closed tokenizer providers map to Other
            ("anthropic/claude-3.5-sonnet", Provider::Other),
            ("google/gemini-2.0-flash", Provider::Other),
            ("unknown/model", Provider::Other),
            ("gpt-4", Provider::Other),
        ];
        for (id, expected) in cases {
            assert_eq!(Provider::from_model_id(id), expected, "id = {id}");
        }
    }

    #[test]
    fn test_provider_prefix() {
        let cases = [
            (Provider::OpenAI, "o"),
            (Provider::Meta, "m"),
            (Provider::Mistral, "mi"),
            (Provider::DeepSeek, "d"),
            (Provider::Qwen, "q"),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider.prefix(), expected);
        }
    }

    #[test]
    fn test_encoding_inference() {
        let cases = [
            // OpenAI models
            ("openai/gpt-4o", Encoding::O200kBase),
            ("openai/gpt-4o-mini", Encoding::O200kBase),
            ("openai/o1", Encoding::O200kBase),
            ("openai/gpt-4-turbo", Encoding::Cl100kBase),
            ("openai/gpt-3.5-turbo", Encoding::Cl100kBase),
            // Llama-based models
            ("meta-llama/llama-3.1-70b", Encoding::LlamaBpe),
            ("mistralai/mistral-large", Encoding::LlamaBpe),
            // Other models use heuristic
            ("qwen/qwen-2.5-72b", Encoding::Heuristic),
        ];
        for (id, expected) in cases {
            assert_eq!(Encoding::infer_from_id(id), expected, "id = {id}");
        }
    }

    #[test]
    fn test_abbreviation_generation() {
        let cases = [
            // OpenAI models
            ("openai/gpt-4o", Provider::OpenAI, "og4o"),
            ("openai/gpt-4o-mini", Provider::OpenAI, "og4om"),
            ("openai/gpt-4-turbo", Provider::OpenAI, "og4t"),
            ("openai/o1", Provider::OpenAI, "oo1"),
            // Meta models
            ("meta-llama/llama-3.1-405b", Provider::Meta, "ml31405b"),
            // DeepSeek models
            ("deepseek/deepseek-v3", Provider::DeepSeek, "dv3"),
        ];
        for (id, provider, expected) in cases {
            assert_eq!(ModelCard::generate_abbrev(id, provider), expected, "id = {id}");
        }
    }

    #[test]
    fn test_model_card_creation() {
        let card = ModelCard::new("openai/gpt-4o");
        assert_eq!(card.id, "openai/gpt-4o");
        assert_eq!(card.abbrev, "og4o");
        assert_eq!(card.provider, Provider::OpenAI);
        assert_eq!(card.encoding, Encoding::O200kBase);
    }

    #[test]
    fn test_pricing_calculation() {
        // GPT-4o pricing: $2.50/M input, $10/M output
        let pricing = Pricing::from_per_million(2.50, 10.0);
        let cost = pricing.calculate(1000, 500);
        // 1000 * 2.5e-6 + 500 * 1e-5 = 0.0025 + 0.005 = 0.0075
        assert!((cost - 0.0075).abs() < 0.0001);
    }
}