Skip to main content

openai_protocol/
model_card.rs

1//! Model card definitions for worker model configuration.
2//!
3//! Defines [`ModelCard`] which consolidates model-related configuration:
4//! identity, capabilities, tokenization, and classification support.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use super::{
11    model_type::{Endpoint, ModelType},
12    worker::ProviderType,
13};
14
/// Serde `skip_serializing_if` predicate: `true` exactly when the value is `0`.
///
/// Takes `&u32` because serde hands the predicate a reference to the field.
#[expect(
    clippy::trivially_copy_pass_by_ref,
    reason = "serde skip_serializing_if passes &T"
)]
fn is_zero(n: &u32) -> bool {
    matches!(*n, 0)
}
22
23fn default_model_type() -> ModelType {
24    ModelType::LLM
25}
26
/// Model card containing model configuration and capabilities.
///
/// # Example
///
/// ```
/// use openai_protocol::{model_type::ModelType, model_card::ModelCard, worker::ProviderType};
///
/// let card = ModelCard::new("meta-llama/Llama-3.1-8B-Instruct")
///     .with_display_name("Llama 3.1 8B Instruct")
///     .with_alias("llama-3.1-8b")
///     .with_model_type(ModelType::VISION_LLM)
///     .with_context_length(128_000)
///     .with_tokenizer_path("meta-llama/Llama-3.1-8B-Instruct");
///
/// assert!(card.matches("llama-3.1-8b"));
/// assert!(card.model_type.supports_vision());
/// assert!(card.provider.is_none()); // Local model, no external provider
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ModelCard {
    // NOTE(maintainers): this struct is part of the wire format.
    //   * Derived `Serialize` emits fields in declaration order, so reordering fields
    //     reorders the serialized output.
    //   * The `schemars::JsonSchema` derive copies `///` text into the generated schema
    //     descriptions, so rewording a doc comment changes the published schema.
    //   * `id` is the only field without a serde default; every other field may be
    //     omitted from the input when deserializing.

    // === Identity ===
    /// Primary model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct")
    pub id: String,

    /// Optional display name (e.g., "Llama 3.1 8B Instruct")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,

    /// Alternative names/aliases for this model
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub aliases: Vec<String>,

    // === Capabilities ===
    // `model_type` is always serialized (no `skip_serializing_if`); when it is missing
    // from the input, `default_model_type()` supplies `ModelType::LLM`.
    /// Supported endpoint types (bitflags)
    #[serde(default = "default_model_type")]
    pub model_type: ModelType,

    /// HuggingFace model type string (e.g., "llama", "qwen2", "gpt-oss")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hf_model_type: Option<String>,

    /// Model architectures from HuggingFace config (e.g., ["LlamaForCausalLM"])
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub architectures: Vec<String>,

    /// Provider hint for API transformations.
    /// `None` means native/passthrough (no transformation needed).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider: Option<ProviderType>,

    /// Maximum context length in tokens
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u32>,

    // === Tokenization & Parsing ===
    // All optional; `None` values are left out of the serialized form.
    /// Path to tokenizer (e.g., HuggingFace model ID or local path)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,

    /// Chat template (Jinja2 template string or path)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,

    /// Reasoning parser type (e.g., "deepseek", "qwen")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_parser: Option<String>,

    /// Tool/function calling parser type (e.g., "llama", "mistral")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_parser: Option<String>,

    /// User-defined metadata (for fields not covered above)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,

    // === Classification Support ===
    // `id2label` and `num_labels` are stored independently: `with_id2label` sets
    // `num_labels` to the map size, while `with_num_labels` (or deserialization) can set
    // it to any value. `is_classifier` consults `num_labels` only.
    /// Classification label mapping (class index -> label name).
    /// Empty if not a classification model.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub id2label: HashMap<u32, String>,

    /// Number of classification labels (0 if not a classifier).
    // Omitted from serialized output when zero (see `is_zero`).
    #[serde(default, skip_serializing_if = "is_zero")]
    pub num_labels: u32,
}
112
113impl ModelCard {
114    /// Create a new model card with minimal configuration.
115    ///
116    /// Defaults to `ModelType::LLM` and no provider (native/passthrough).
117    pub fn new(id: impl Into<String>) -> Self {
118        Self {
119            id: id.into(),
120            display_name: None,
121            aliases: Vec::new(),
122            model_type: ModelType::LLM,
123            hf_model_type: None,
124            architectures: Vec::new(),
125            provider: None,
126            context_length: None,
127            tokenizer_path: None,
128            chat_template: None,
129            reasoning_parser: None,
130            tool_parser: None,
131            metadata: None,
132            id2label: HashMap::new(),
133            num_labels: 0,
134        }
135    }
136
137    // === Builder-style methods ===
138
139    /// Set the display name
140    pub fn with_display_name(mut self, name: impl Into<String>) -> Self {
141        self.display_name = Some(name.into());
142        self
143    }
144
145    /// Add a single alias
146    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
147        self.aliases.push(alias.into());
148        self
149    }
150
151    /// Add multiple aliases
152    pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
153        self.aliases.extend(aliases.into_iter().map(|a| a.into()));
154        self
155    }
156
157    /// Set the model type (capabilities)
158    pub fn with_model_type(mut self, model_type: ModelType) -> Self {
159        self.model_type = model_type;
160        self
161    }
162
163    /// Set the HuggingFace model type string
164    pub fn with_hf_model_type(mut self, hf_model_type: impl Into<String>) -> Self {
165        self.hf_model_type = Some(hf_model_type.into());
166        self
167    }
168
169    /// Set the model architectures
170    pub fn with_architectures(mut self, architectures: Vec<String>) -> Self {
171        self.architectures = architectures;
172        self
173    }
174
175    /// Set the provider type (for external API transformations)
176    pub fn with_provider(mut self, provider: ProviderType) -> Self {
177        self.provider = Some(provider);
178        self
179    }
180
181    /// Set the context length
182    pub fn with_context_length(mut self, length: u32) -> Self {
183        self.context_length = Some(length);
184        self
185    }
186
187    /// Set the tokenizer path
188    pub fn with_tokenizer_path(mut self, path: impl Into<String>) -> Self {
189        self.tokenizer_path = Some(path.into());
190        self
191    }
192
193    /// Set the chat template
194    pub fn with_chat_template(mut self, template: impl Into<String>) -> Self {
195        self.chat_template = Some(template.into());
196        self
197    }
198
199    /// Set the reasoning parser type
200    pub fn with_reasoning_parser(mut self, parser: impl Into<String>) -> Self {
201        self.reasoning_parser = Some(parser.into());
202        self
203    }
204
205    /// Set the tool parser type
206    pub fn with_tool_parser(mut self, parser: impl Into<String>) -> Self {
207        self.tool_parser = Some(parser.into());
208        self
209    }
210
211    /// Set custom metadata
212    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
213        self.metadata = Some(metadata);
214        self
215    }
216
217    /// Set the id2label mapping for classification models
218    pub fn with_id2label(mut self, id2label: HashMap<u32, String>) -> Self {
219        self.num_labels = id2label.len() as u32;
220        self.id2label = id2label;
221        self
222    }
223
224    /// Set num_labels directly (alternative to with_id2label)
225    pub fn with_num_labels(mut self, num_labels: u32) -> Self {
226        self.num_labels = num_labels;
227        self
228    }
229
230    // === Query methods ===
231
232    /// Check if this model matches the given ID (including aliases)
233    pub fn matches(&self, model_id: &str) -> bool {
234        self.id == model_id || self.aliases.iter().any(|a| a == model_id)
235    }
236
237    /// Check if this model supports a given endpoint
238    pub fn supports_endpoint(&self, endpoint: Endpoint) -> bool {
239        self.model_type.supports_endpoint(endpoint)
240    }
241
242    /// Get the display name or fall back to ID
243    pub fn name(&self) -> &str {
244        self.display_name.as_deref().unwrap_or(&self.id)
245    }
246
247    /// Check if this is a native/local model (no external provider)
248    #[inline]
249    pub fn is_native(&self) -> bool {
250        self.provider.is_none()
251    }
252
253    /// Check if this model uses an external provider
254    #[inline]
255    pub fn has_external_provider(&self) -> bool {
256        self.provider.is_some()
257    }
258
259    /// Check if this is an LLM (supports chat)
260    #[inline]
261    pub fn is_llm(&self) -> bool {
262        self.model_type.is_llm()
263    }
264
265    /// Check if this is an embedding model
266    #[inline]
267    pub fn is_embedding_model(&self) -> bool {
268        self.model_type.is_embedding_model()
269    }
270
271    /// Check if this model supports vision/multimodal
272    #[inline]
273    pub fn supports_vision(&self) -> bool {
274        self.model_type.supports_vision()
275    }
276
277    /// Check if this model supports tools/function calling
278    #[inline]
279    pub fn supports_tools(&self) -> bool {
280        self.model_type.supports_tools()
281    }
282
283    /// Check if this model supports reasoning
284    #[inline]
285    pub fn supports_reasoning(&self) -> bool {
286        self.model_type.supports_reasoning()
287    }
288
289    /// Check if this is a classification model
290    #[inline]
291    pub fn is_classifier(&self) -> bool {
292        self.num_labels > 0
293    }
294
295    /// Get label for a class index, with fallback to generic label (LABEL_N)
296    pub fn get_label(&self, class_idx: u32) -> String {
297        self.id2label
298            .get(&class_idx)
299            .cloned()
300            .unwrap_or_else(|| format!("LABEL_{class_idx}"))
301    }
302}
303
304impl Default for ModelCard {
305    fn default() -> Self {
306        Self::new(super::UNKNOWN_MODEL_ID)
307    }
308}
309
310impl std::fmt::Display for ModelCard {
311    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312        write!(f, "{}", self.name())
313    }
314}