Skip to main content

openai_protocol/
model_card.rs

1//! Model card definitions for worker model configuration.
2//!
3//! Defines [`ModelCard`] which consolidates model-related configuration:
4//! identity, capabilities, tokenization, and classification support.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use super::{
11    model_type::{Endpoint, ModelType},
12    worker::ProviderType,
13};
14
/// Serde `skip_serializing_if` predicate that reports whether a count is zero.
///
/// The parameter is a reference because serde hands the field to the predicate
/// by reference (`fn(&T) -> bool`).
fn is_zero(n: &u32) -> bool {
    matches!(*n, 0)
}
18
/// Serde default for [`ModelCard::model_type`], used when the field is absent
/// from the input.
///
/// Returns `ModelType::LLM`, the same model type [`ModelCard::new`] assigns.
fn default_model_type() -> ModelType {
    ModelType::LLM
}
22
/// Model card containing model configuration and capabilities.
///
/// A card groups the model-related settings this module tracks: identity
/// (ID, display name, aliases), capabilities (`model_type`, provider, context
/// length), tokenization/parsing hints, and optional classification metadata.
///
/// Cards can be built with [`ModelCard::new`] followed by the `with_*` builder
/// methods. When deserialized, `id` is the only required field. Every other
/// field falls back to empty/`None`/zero, except `model_type`, which falls back
/// to `ModelType::LLM`.
///
/// # Example
///
/// ```
/// use openai_protocol::{model_type::ModelType, model_card::ModelCard, worker::ProviderType};
///
/// let card = ModelCard::new("meta-llama/Llama-3.1-8B-Instruct")
///     .with_display_name("Llama 3.1 8B Instruct")
///     .with_alias("llama-3.1-8b")
///     .with_model_type(ModelType::VISION_LLM)
///     .with_context_length(128_000)
///     .with_tokenizer_path("meta-llama/Llama-3.1-8B-Instruct");
///
/// assert!(card.matches("llama-3.1-8b"));
/// assert!(card.model_type.supports_vision());
/// assert!(card.provider.is_none()); // Local model, no external provider
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCard {
    // === Identity ===
    /// Primary model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct").
    ///
    /// Required when deserializing: this is the only field without a serde default.
    pub id: String,

    /// Optional human-readable display name (e.g., "Llama 3.1 8B Instruct").
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,

    /// Alternative names/aliases that [`ModelCard::matches`] accepts in
    /// addition to `id`. Omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub aliases: Vec<String>,

    // === Capabilities ===
    /// Supported endpoint types (bitflags).
    ///
    /// Defaults to `ModelType::LLM` when absent during deserialization.
    #[serde(default = "default_model_type")]
    pub model_type: ModelType,

    /// HuggingFace model type string (e.g., "llama", "qwen2", "gpt-oss").
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hf_model_type: Option<String>,

    /// Model architectures from the HuggingFace config (e.g., ["LlamaForCausalLM"]).
    /// Omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub architectures: Vec<String>,

    /// Provider hint for API transformations.
    /// `None` means native/passthrough (no transformation needed); see
    /// [`ModelCard::is_native`]. Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider: Option<ProviderType>,

    /// Maximum context length in tokens. Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u32>,

    // === Tokenization & Parsing ===
    /// Path to the tokenizer (e.g., HuggingFace model ID or local path).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,

    /// Chat template (Jinja2 template string or path).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,

    /// Reasoning parser type (e.g., "deepseek", "qwen").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_parser: Option<String>,

    /// Tool/function calling parser type (e.g., "llama", "mistral").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_parser: Option<String>,

    /// Free-form user-defined metadata for information not covered by the
    /// typed fields above.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,

    // === Classification Support ===
    /// Classification label mapping (class index -> label name), consulted by
    /// [`ModelCard::get_label`].
    /// Empty if not a classification model; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub id2label: HashMap<u32, String>,

    /// Number of classification labels (0 if not a classifier).
    ///
    /// [`ModelCard::is_classifier`] checks this field, not `id2label`.
    /// [`ModelCard::with_id2label`] keeps it in sync with the map. Deserialization
    /// does NOT derive it from `id2label`: input that has `id2label` but no
    /// `num_labels` yields 0. Omitted from serialized output when 0.
    #[serde(default, skip_serializing_if = "is_zero")]
    pub num_labels: u32,
}
108
109impl ModelCard {
110    /// Create a new model card with minimal configuration.
111    ///
112    /// Defaults to `ModelType::LLM` and no provider (native/passthrough).
113    pub fn new(id: impl Into<String>) -> Self {
114        Self {
115            id: id.into(),
116            display_name: None,
117            aliases: Vec::new(),
118            model_type: ModelType::LLM,
119            hf_model_type: None,
120            architectures: Vec::new(),
121            provider: None,
122            context_length: None,
123            tokenizer_path: None,
124            chat_template: None,
125            reasoning_parser: None,
126            tool_parser: None,
127            metadata: None,
128            id2label: HashMap::new(),
129            num_labels: 0,
130        }
131    }
132
133    // === Builder-style methods ===
134
135    /// Set the display name
136    pub fn with_display_name(mut self, name: impl Into<String>) -> Self {
137        self.display_name = Some(name.into());
138        self
139    }
140
141    /// Add a single alias
142    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
143        self.aliases.push(alias.into());
144        self
145    }
146
147    /// Add multiple aliases
148    pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
149        self.aliases.extend(aliases.into_iter().map(|a| a.into()));
150        self
151    }
152
153    /// Set the model type (capabilities)
154    pub fn with_model_type(mut self, model_type: ModelType) -> Self {
155        self.model_type = model_type;
156        self
157    }
158
159    /// Set the HuggingFace model type string
160    pub fn with_hf_model_type(mut self, hf_model_type: impl Into<String>) -> Self {
161        self.hf_model_type = Some(hf_model_type.into());
162        self
163    }
164
165    /// Set the model architectures
166    pub fn with_architectures(mut self, architectures: Vec<String>) -> Self {
167        self.architectures = architectures;
168        self
169    }
170
171    /// Set the provider type (for external API transformations)
172    pub fn with_provider(mut self, provider: ProviderType) -> Self {
173        self.provider = Some(provider);
174        self
175    }
176
177    /// Set the context length
178    pub fn with_context_length(mut self, length: u32) -> Self {
179        self.context_length = Some(length);
180        self
181    }
182
183    /// Set the tokenizer path
184    pub fn with_tokenizer_path(mut self, path: impl Into<String>) -> Self {
185        self.tokenizer_path = Some(path.into());
186        self
187    }
188
189    /// Set the chat template
190    pub fn with_chat_template(mut self, template: impl Into<String>) -> Self {
191        self.chat_template = Some(template.into());
192        self
193    }
194
195    /// Set the reasoning parser type
196    pub fn with_reasoning_parser(mut self, parser: impl Into<String>) -> Self {
197        self.reasoning_parser = Some(parser.into());
198        self
199    }
200
201    /// Set the tool parser type
202    pub fn with_tool_parser(mut self, parser: impl Into<String>) -> Self {
203        self.tool_parser = Some(parser.into());
204        self
205    }
206
207    /// Set custom metadata
208    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
209        self.metadata = Some(metadata);
210        self
211    }
212
213    /// Set the id2label mapping for classification models
214    pub fn with_id2label(mut self, id2label: HashMap<u32, String>) -> Self {
215        self.num_labels = id2label.len() as u32;
216        self.id2label = id2label;
217        self
218    }
219
220    /// Set num_labels directly (alternative to with_id2label)
221    pub fn with_num_labels(mut self, num_labels: u32) -> Self {
222        self.num_labels = num_labels;
223        self
224    }
225
226    // === Query methods ===
227
228    /// Check if this model matches the given ID (including aliases)
229    pub fn matches(&self, model_id: &str) -> bool {
230        self.id == model_id || self.aliases.iter().any(|a| a == model_id)
231    }
232
233    /// Check if this model supports a given endpoint
234    pub fn supports_endpoint(&self, endpoint: Endpoint) -> bool {
235        self.model_type.supports_endpoint(endpoint)
236    }
237
238    /// Get the display name or fall back to ID
239    pub fn name(&self) -> &str {
240        self.display_name.as_deref().unwrap_or(&self.id)
241    }
242
243    /// Check if this is a native/local model (no external provider)
244    #[inline]
245    pub fn is_native(&self) -> bool {
246        self.provider.is_none()
247    }
248
249    /// Check if this model uses an external provider
250    #[inline]
251    pub fn has_external_provider(&self) -> bool {
252        self.provider.is_some()
253    }
254
255    /// Check if this is an LLM (supports chat)
256    #[inline]
257    pub fn is_llm(&self) -> bool {
258        self.model_type.is_llm()
259    }
260
261    /// Check if this is an embedding model
262    #[inline]
263    pub fn is_embedding_model(&self) -> bool {
264        self.model_type.is_embedding_model()
265    }
266
267    /// Check if this model supports vision/multimodal
268    #[inline]
269    pub fn supports_vision(&self) -> bool {
270        self.model_type.supports_vision()
271    }
272
273    /// Check if this model supports tools/function calling
274    #[inline]
275    pub fn supports_tools(&self) -> bool {
276        self.model_type.supports_tools()
277    }
278
279    /// Check if this model supports reasoning
280    #[inline]
281    pub fn supports_reasoning(&self) -> bool {
282        self.model_type.supports_reasoning()
283    }
284
285    /// Check if this is a classification model
286    #[inline]
287    pub fn is_classifier(&self) -> bool {
288        self.num_labels > 0
289    }
290
291    /// Get label for a class index, with fallback to generic label (LABEL_N)
292    pub fn get_label(&self, class_idx: u32) -> String {
293        self.id2label
294            .get(&class_idx)
295            .cloned()
296            .unwrap_or_else(|| format!("LABEL_{}", class_idx))
297    }
298}
299
300impl Default for ModelCard {
301    fn default() -> Self {
302        Self::new(super::UNKNOWN_MODEL_ID)
303    }
304}
305
306impl std::fmt::Display for ModelCard {
307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        write!(f, "{}", self.name())
309    }
310}