use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::{
model_type::{Endpoint, ModelType},
models::ModelObject,
worker::ProviderType,
};
/// Serde `skip_serializing_if` predicate: returns `true` when `n` is zero.
///
/// The parameter is `&u32` (not `u32`) because serde hands the field to the
/// predicate by reference, which is why the clippy lint is expected below.
#[expect(
    clippy::trivially_copy_pass_by_ref,
    reason = "serde skip_serializing_if passes &T"
)]
fn is_zero(n: &u32) -> bool {
    matches!(*n, 0)
}
/// Serde default for `ModelCard::model_type` when the field is missing from
/// the input; it is the same value `ModelCard::new` assigns.
fn default_model_type() -> ModelType {
    ModelType::LLM
}
/// Describes a single model: its identity, capabilities, serving provider,
/// and optional tokenizer/template/parser settings.
///
/// Every field except `id` may be omitted when deserializing; omitted
/// optional fields take their empty/`None` values, and `model_type` falls
/// back to `ModelType::LLM`. Empty/`None`/zero fields are skipped when
/// serializing.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ModelCard {
    /// Canonical model identifier.
    pub id: String,
    /// Optional human-readable name; when unset, `id` is used as the name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Additional identifiers that also match this model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub aliases: Vec<String>,
    /// Capability classification; defaults to `ModelType::LLM` when absent.
    #[serde(default = "default_model_type")]
    pub model_type: ModelType,
    /// Optional raw model-type string, stored as-is.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hf_model_type: Option<String>,
    /// Architecture names, stored as-is.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub architectures: Vec<String>,
    /// External provider serving this model; `None` means self-hosted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider: Option<ProviderType>,
    /// Maximum context length, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u32>,
    /// Optional tokenizer location.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,
    /// Optional chat template.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,
    /// Optional reasoning-parser name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_parser: Option<String>,
    /// Optional tool-parser name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_parser: Option<String>,
    /// Arbitrary extra JSON metadata.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
    /// Mapping from class index to label name (classifier models).
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub id2label: HashMap<u32, String>,
    /// Number of output labels; a non-zero value marks the model as a
    /// classifier. Omitted from serialized output when zero.
    #[serde(default, skip_serializing_if = "is_zero")]
    pub num_labels: u32,
}
impl ModelCard {
    /// Creates a card for `id` with every optional field unset.
    ///
    /// The model type starts as `ModelType::LLM`, the same value serde uses
    /// when `model_type` is missing from the input.
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            display_name: None,
            aliases: Vec::new(),
            model_type: ModelType::LLM,
            hf_model_type: None,
            architectures: Vec::new(),
            provider: None,
            context_length: None,
            tokenizer_path: None,
            chat_template: None,
            reasoning_parser: None,
            tool_parser: None,
            metadata: None,
            id2label: HashMap::new(),
            num_labels: 0,
        }
    }

    /// Sets the human-readable name returned by [`Self::name`].
    pub fn with_display_name(mut self, name: impl Into<String>) -> Self {
        self.display_name = Some(name.into());
        self
    }

    /// Appends one alias; aliases already present are kept.
    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
        self.aliases.push(alias.into());
        self
    }

    /// Appends every alias yielded by `aliases`, in iteration order.
    ///
    /// Existing aliases are kept and duplicates are not removed.
    pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.aliases.extend(aliases.into_iter().map(Into::into));
        self
    }

    /// Replaces the model type.
    pub fn with_model_type(mut self, model_type: ModelType) -> Self {
        self.model_type = model_type;
        self
    }

    /// Sets the raw model-type string.
    pub fn with_hf_model_type(mut self, hf_model_type: impl Into<String>) -> Self {
        self.hf_model_type = Some(hf_model_type.into());
        self
    }

    /// Replaces (does not append to) the architecture list.
    pub fn with_architectures(mut self, architectures: Vec<String>) -> Self {
        self.architectures = architectures;
        self
    }

    /// Marks the card as served by an external `provider`.
    pub fn with_provider(mut self, provider: ProviderType) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets the maximum context length.
    pub fn with_context_length(mut self, length: u32) -> Self {
        self.context_length = Some(length);
        self
    }

    /// Sets the tokenizer location.
    pub fn with_tokenizer_path(mut self, path: impl Into<String>) -> Self {
        self.tokenizer_path = Some(path.into());
        self
    }

    /// Sets the chat template.
    pub fn with_chat_template(mut self, template: impl Into<String>) -> Self {
        self.chat_template = Some(template.into());
        self
    }

    /// Sets the reasoning-parser name.
    pub fn with_reasoning_parser(mut self, parser: impl Into<String>) -> Self {
        self.reasoning_parser = Some(parser.into());
        self
    }

    /// Sets the tool-parser name.
    pub fn with_tool_parser(mut self, parser: impl Into<String>) -> Self {
        self.tool_parser = Some(parser.into());
        self
    }

    /// Attaches arbitrary JSON metadata, replacing any previous value.
    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
        self.metadata = Some(metadata);
        self
    }

    /// Sets the class-index → label map and sets `num_labels` to its number
    /// of entries.
    ///
    /// The count is the number of map entries, not the largest index + 1.
    /// It saturates at `u32::MAX` instead of silently wrapping. Call
    /// [`Self::with_num_labels`] afterwards to override it.
    pub fn with_id2label(mut self, id2label: HashMap<u32, String>) -> Self {
        self.num_labels = u32::try_from(id2label.len()).unwrap_or(u32::MAX);
        self.id2label = id2label;
        self
    }

    /// Overrides `num_labels` without touching `id2label`.
    pub fn with_num_labels(mut self, num_labels: u32) -> Self {
        self.num_labels = num_labels;
        self
    }

    /// Returns `true` if `model_id` exactly equals the canonical id or any
    /// alias.
    pub fn matches(&self, model_id: &str) -> bool {
        self.id == model_id || self.aliases.iter().any(|a| a == model_id)
    }

    /// Delegates to `ModelType::supports_endpoint` for this card's type.
    pub fn supports_endpoint(&self, endpoint: Endpoint) -> bool {
        self.model_type.supports_endpoint(endpoint)
    }

    /// Returns the display name if set, otherwise the canonical id.
    pub fn name(&self) -> &str {
        self.display_name.as_deref().unwrap_or(&self.id)
    }

    /// Returns `true` when no external provider is set.
    #[inline]
    pub fn is_native(&self) -> bool {
        self.provider.is_none()
    }

    /// Returns `true` when an external provider is set (inverse of
    /// [`Self::is_native`]).
    #[inline]
    pub fn has_external_provider(&self) -> bool {
        self.provider.is_some()
    }

    /// Delegates to `ModelType::is_llm`.
    #[inline]
    pub fn is_llm(&self) -> bool {
        self.model_type.is_llm()
    }

    /// Delegates to `ModelType::is_embedding_model`.
    #[inline]
    pub fn is_embedding_model(&self) -> bool {
        self.model_type.is_embedding_model()
    }

    /// Delegates to `ModelType::supports_vision`.
    #[inline]
    pub fn supports_vision(&self) -> bool {
        self.model_type.supports_vision()
    }

    /// Delegates to `ModelType::supports_tools`.
    #[inline]
    pub fn supports_tools(&self) -> bool {
        self.model_type.supports_tools()
    }

    /// Delegates to `ModelType::supports_reasoning`.
    #[inline]
    pub fn supports_reasoning(&self) -> bool {
        self.model_type.supports_reasoning()
    }

    /// Returns the provider's string form, or `"self_hosted"` when no
    /// provider is set.
    pub fn owned_by(&self) -> &str {
        match &self.provider {
            Some(p) => p.as_str(),
            None => "self_hosted",
        }
    }

    /// Consumes the card and builds a `ModelObject` from it.
    ///
    /// Only the id and the owner (see [`Self::owned_by`]) carry over. The
    /// object kind is `"model"`, `created` is `0`, and all other card fields
    /// are dropped.
    pub fn into_model_object(self) -> ModelObject {
        // Copy the owner out first, because `self.id` is moved below.
        let owned_by = self.owned_by().to_owned();
        ModelObject {
            id: self.id,
            object: "model".to_owned(),
            created: 0,
            owned_by,
        }
    }

    /// Returns `true` when `num_labels` is non-zero.
    ///
    /// A non-empty `id2label` alone does not make the card a classifier.
    #[inline]
    pub fn is_classifier(&self) -> bool {
        self.num_labels > 0
    }

    /// Returns the label for `class_idx`, or `LABEL_{class_idx}` when the
    /// index is not in `id2label`.
    pub fn get_label(&self, class_idx: u32) -> String {
        self.id2label
            .get(&class_idx)
            .cloned()
            .unwrap_or_else(|| format!("LABEL_{class_idx}"))
    }
}
impl Default for ModelCard {
fn default() -> Self {
Self::new(super::UNKNOWN_MODEL_ID)
}
}
impl std::fmt::Display for ModelCard {
    /// Writes the card's user-facing name (display name, or id if unset).
    ///
    /// Uses `Formatter::pad` so that width, fill, alignment and precision
    /// flags are honored the same way as for `str`, e.g. in
    /// `format!("{card:<20}")`. Previously `write!(f, "{}", …)` discarded them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}