// openai_protocol/model_card.rs
//! Model card definitions for worker model configuration.
//!
//! Defines [`ModelCard`] which consolidates model-related configuration:
//! identity, capabilities, tokenization, and classification support.

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use super::{
    model_type::{Endpoint, ModelType},
    models::ModelObject,
    worker::ProviderType,
};
15
/// Predicate for serde's `skip_serializing_if`: true when the count is zero.
#[expect(
    clippy::trivially_copy_pass_by_ref,
    reason = "serde skip_serializing_if passes &T"
)]
fn is_zero(n: &u32) -> bool {
    n == &0
}
23
/// Serde default for [`ModelCard::model_type`]: a plain text LLM.
fn default_model_type() -> ModelType {
    ModelType::LLM
}
27
/// Model card containing model configuration and capabilities.
///
/// Consolidates a served model's identity, capability flags, tokenization
/// settings, and classification metadata in a single serializable struct.
/// Optional/empty fields are omitted from serialized output.
///
/// # Example
///
/// ```
/// use openai_protocol::{model_type::ModelType, model_card::ModelCard, worker::ProviderType};
///
/// let card = ModelCard::new("meta-llama/Llama-3.1-8B-Instruct")
///     .with_display_name("Llama 3.1 8B Instruct")
///     .with_alias("llama-3.1-8b")
///     .with_model_type(ModelType::VISION_LLM)
///     .with_context_length(128_000)
///     .with_tokenizer_path("meta-llama/Llama-3.1-8B-Instruct");
///
/// assert!(card.matches("llama-3.1-8b"));
/// assert!(card.model_type.supports_vision());
/// assert!(card.provider.is_none()); // Local model, no external provider
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ModelCard {
    // === Identity ===
    /// Primary model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct")
    pub id: String,

    /// Optional display name (e.g., "Llama 3.1 8B Instruct")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,

    /// Alternative names/aliases for this model; all match via [`ModelCard::matches`]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub aliases: Vec<String>,

    // === Capabilities ===
    /// Supported endpoint types (bitflags); defaults to `ModelType::LLM` when absent
    #[serde(default = "default_model_type")]
    pub model_type: ModelType,

    /// HuggingFace model type string (e.g., "llama", "qwen2", "gpt-oss")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hf_model_type: Option<String>,

    /// Model architectures from HuggingFace config (e.g., ["LlamaForCausalLM"])
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub architectures: Vec<String>,

    /// Provider hint for API transformations.
    /// `None` means native/passthrough (no transformation needed).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider: Option<ProviderType>,

    /// Maximum context length in tokens
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u32>,

    // === Tokenization & Parsing ===
    /// Path to tokenizer (e.g., HuggingFace model ID or local path)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,

    /// Chat template (Jinja2 template string or path)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,

    /// Reasoning parser type (e.g., "deepseek", "qwen")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_parser: Option<String>,

    /// Tool/function calling parser type (e.g., "llama", "mistral")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_parser: Option<String>,

    /// User-defined metadata (for fields not covered above)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,

    // === Classification Support ===
    /// Classification label mapping (class index -> label name).
    /// Empty if not a classification model.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub id2label: HashMap<u32, String>,

    /// Number of classification labels (0 if not a classifier).
    /// Kept in sync with `id2label` by [`ModelCard::with_id2label`].
    #[serde(default, skip_serializing_if = "is_zero")]
    pub num_labels: u32,
}
113
114impl ModelCard {
115    /// Create a new model card with minimal configuration.
116    ///
117    /// Defaults to `ModelType::LLM` and no provider (native/passthrough).
118    pub fn new(id: impl Into<String>) -> Self {
119        Self {
120            id: id.into(),
121            display_name: None,
122            aliases: Vec::new(),
123            model_type: ModelType::LLM,
124            hf_model_type: None,
125            architectures: Vec::new(),
126            provider: None,
127            context_length: None,
128            tokenizer_path: None,
129            chat_template: None,
130            reasoning_parser: None,
131            tool_parser: None,
132            metadata: None,
133            id2label: HashMap::new(),
134            num_labels: 0,
135        }
136    }
137
138    // === Builder-style methods ===
139
140    /// Set the display name
141    pub fn with_display_name(mut self, name: impl Into<String>) -> Self {
142        self.display_name = Some(name.into());
143        self
144    }
145
146    /// Add a single alias
147    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
148        self.aliases.push(alias.into());
149        self
150    }
151
152    /// Add multiple aliases
153    pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
154        self.aliases.extend(aliases.into_iter().map(|a| a.into()));
155        self
156    }
157
158    /// Set the model type (capabilities)
159    pub fn with_model_type(mut self, model_type: ModelType) -> Self {
160        self.model_type = model_type;
161        self
162    }
163
164    /// Set the HuggingFace model type string
165    pub fn with_hf_model_type(mut self, hf_model_type: impl Into<String>) -> Self {
166        self.hf_model_type = Some(hf_model_type.into());
167        self
168    }
169
170    /// Set the model architectures
171    pub fn with_architectures(mut self, architectures: Vec<String>) -> Self {
172        self.architectures = architectures;
173        self
174    }
175
176    /// Set the provider type (for external API transformations)
177    pub fn with_provider(mut self, provider: ProviderType) -> Self {
178        self.provider = Some(provider);
179        self
180    }
181
182    /// Set the context length
183    pub fn with_context_length(mut self, length: u32) -> Self {
184        self.context_length = Some(length);
185        self
186    }
187
188    /// Set the tokenizer path
189    pub fn with_tokenizer_path(mut self, path: impl Into<String>) -> Self {
190        self.tokenizer_path = Some(path.into());
191        self
192    }
193
194    /// Set the chat template
195    pub fn with_chat_template(mut self, template: impl Into<String>) -> Self {
196        self.chat_template = Some(template.into());
197        self
198    }
199
200    /// Set the reasoning parser type
201    pub fn with_reasoning_parser(mut self, parser: impl Into<String>) -> Self {
202        self.reasoning_parser = Some(parser.into());
203        self
204    }
205
206    /// Set the tool parser type
207    pub fn with_tool_parser(mut self, parser: impl Into<String>) -> Self {
208        self.tool_parser = Some(parser.into());
209        self
210    }
211
212    /// Set custom metadata
213    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
214        self.metadata = Some(metadata);
215        self
216    }
217
218    /// Set the id2label mapping for classification models
219    pub fn with_id2label(mut self, id2label: HashMap<u32, String>) -> Self {
220        self.num_labels = id2label.len() as u32;
221        self.id2label = id2label;
222        self
223    }
224
225    /// Set num_labels directly (alternative to with_id2label)
226    pub fn with_num_labels(mut self, num_labels: u32) -> Self {
227        self.num_labels = num_labels;
228        self
229    }
230
231    // === Query methods ===
232
233    /// Check if this model matches the given ID (including aliases)
234    pub fn matches(&self, model_id: &str) -> bool {
235        self.id == model_id || self.aliases.iter().any(|a| a == model_id)
236    }
237
238    /// Check if this model supports a given endpoint
239    pub fn supports_endpoint(&self, endpoint: Endpoint) -> bool {
240        self.model_type.supports_endpoint(endpoint)
241    }
242
243    /// Get the display name or fall back to ID
244    pub fn name(&self) -> &str {
245        self.display_name.as_deref().unwrap_or(&self.id)
246    }
247
248    /// Check if this is a native/local model (no external provider)
249    #[inline]
250    pub fn is_native(&self) -> bool {
251        self.provider.is_none()
252    }
253
254    /// Check if this model uses an external provider
255    #[inline]
256    pub fn has_external_provider(&self) -> bool {
257        self.provider.is_some()
258    }
259
260    /// Check if this is an LLM (supports chat)
261    #[inline]
262    pub fn is_llm(&self) -> bool {
263        self.model_type.is_llm()
264    }
265
266    /// Check if this is an embedding model
267    #[inline]
268    pub fn is_embedding_model(&self) -> bool {
269        self.model_type.is_embedding_model()
270    }
271
272    /// Check if this model supports vision/multimodal
273    #[inline]
274    pub fn supports_vision(&self) -> bool {
275        self.model_type.supports_vision()
276    }
277
278    /// Check if this model supports tools/function calling
279    #[inline]
280    pub fn supports_tools(&self) -> bool {
281        self.model_type.supports_tools()
282    }
283
284    /// Check if this model supports reasoning
285    #[inline]
286    pub fn supports_reasoning(&self) -> bool {
287        self.model_type.supports_reasoning()
288    }
289
290    /// Get the `owned_by` string for this model.
291    ///
292    /// Maps `None` → `"self_hosted"`, provider → `provider.as_str()`.
293    pub fn owned_by(&self) -> &str {
294        match &self.provider {
295            Some(p) => p.as_str(),
296            None => "self_hosted",
297        }
298    }
299
300    /// Convert this model card into an OpenAI-compatible [`ModelObject`],
301    /// consuming `self` to avoid cloning the model ID.
302    pub fn into_model_object(self) -> ModelObject {
303        let owned_by = self.owned_by().to_owned();
304        ModelObject {
305            id: self.id,
306            object: "model".to_owned(),
307            created: 0,
308            owned_by,
309        }
310    }
311
312    /// Check if this is a classification model
313    #[inline]
314    pub fn is_classifier(&self) -> bool {
315        self.num_labels > 0
316    }
317
318    /// Get label for a class index, with fallback to generic label (LABEL_N)
319    pub fn get_label(&self, class_idx: u32) -> String {
320        self.id2label
321            .get(&class_idx)
322            .cloned()
323            .unwrap_or_else(|| format!("LABEL_{class_idx}"))
324    }
325}
326
impl Default for ModelCard {
    /// Placeholder card identified by [`super::UNKNOWN_MODEL_ID`]; all
    /// other fields take the minimal `ModelCard::new` defaults.
    fn default() -> Self {
        Self::new(super::UNKNOWN_MODEL_ID)
    }
}
332
333impl std::fmt::Display for ModelCard {
334    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335        write!(f, "{}", self.name())
336    }
337}