use std::collections::BTreeMap;
use std::hash::BuildHasher;
use serde::Deserialize;
use crate::error::FetchError;
/// A single model entry returned by [`search_models`].
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Repository identifier, taken from the API's `modelId` field.
pub model_id: String,
/// Download count reported by the API (`0` when the field is absent).
pub downloads: u64,
/// Library name reported by the API, if any.
pub library_name: Option<String>,
/// Pipeline tag reported by the API, if any.
pub pipeline_tag: Option<String>,
/// Tags attached to the model (empty when the API omits the field).
pub tags: Vec<String>,
}
/// A model family found by [`discover_new_families`] that is not in the
/// caller-supplied set of local families.
#[derive(Debug, Clone)]
pub struct DiscoveredFamily {
/// The `config.model_type` value that identifies the family.
pub model_type: String,
/// Identifier of the first model encountered with this `model_type`.
pub top_model: String,
/// Download count of `top_model` as reported by the API.
pub downloads: u64,
}
/// Wire format of one element of the model-list response.
///
/// Every field except `modelId` falls back to its `Default` value when the
/// response omits it.
#[derive(Debug, Deserialize)]
struct ApiModelEntry {
/// Repository identifier (`modelId` in the JSON payload).
#[serde(rename = "modelId")]
model_id: String,
/// Download count; `0` when absent.
#[serde(default)]
downloads: u64,
/// Model configuration object; only requested by `discover_new_families`
/// (via `config=true`).
#[serde(default)]
config: Option<ApiConfig>,
/// Library name, if reported.
#[serde(default)]
library_name: Option<String>,
/// Pipeline tag, if reported.
#[serde(default)]
pipeline_tag: Option<String>,
/// Tags attached to the model; empty when absent.
#[serde(default)]
tags: Vec<String>,
}
/// Subset of a model's `config` object; only `model_type` is read.
#[derive(Debug, Deserialize)]
struct ApiConfig {
/// Architecture family name, used as the key when grouping models.
model_type: Option<String>,
}
/// Access-gating state of a model repository.
///
/// Rendered through `Display` as the lowercase words `open`, `auto` and
/// `manual`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum GateStatus {
    /// The repository is not gated.
    Open,
    /// The repository is gated and the gating mode is not `manual`.
    Auto,
    /// The repository is gated with the `manual` gating mode.
    Manual,
}

impl GateStatus {
    /// Returns `true` for every state other than [`GateStatus::Open`].
    #[must_use]
    pub const fn is_gated(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly
        // whether it counts as gated.
        match self {
            Self::Open => false,
            Self::Auto | Self::Manual => true,
        }
    }
}

impl std::fmt::Display for GateStatus {
    /// Writes the lowercase label of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Open => "open",
            Self::Auto => "auto",
            Self::Manual => "manual",
        };
        f.write_str(label)
    }
}
/// Metadata about a single model, as returned by [`fetch_model_card`].
#[derive(Debug, Clone)]
pub struct ModelCardMetadata {
/// License string from the model's `cardData`, if present.
pub license: Option<String>,
/// Pipeline tag reported for the model, if any.
pub pipeline_tag: Option<String>,
/// Tags attached to the model (empty when the API omits the field).
pub tags: Vec<String>,
/// Library name reported for the model, if any.
pub library_name: Option<String>,
/// Languages from `cardData.language`; a single string is stored as a
/// one-element list, and the list is empty when the field is absent.
pub languages: Vec<String>,
/// Gating state derived from the API's `gated` field.
pub gated: GateStatus,
}
/// Wire format of the single-model detail response.
///
/// Every field falls back to its `Default` value when the response omits it.
#[derive(Debug, Deserialize)]
struct ApiModelDetail {
/// Pipeline tag, if reported.
#[serde(default)]
pipeline_tag: Option<String>,
/// Tags attached to the model; empty when absent.
#[serde(default)]
tags: Vec<String>,
/// Library name, if reported.
#[serde(default)]
library_name: Option<String>,
/// Gating flag; defaults to `ApiGated::Bool(false)` when absent.
#[serde(default)]
gated: ApiGated,
/// Model-card front matter (`cardData` in the JSON payload), if present.
#[serde(default, rename = "cardData")]
card_data: Option<ApiCardData>,
}
/// Subset of the `cardData` object; only the license and language are read.
#[derive(Debug, Deserialize)]
struct ApiCardData {
/// License string, if declared.
#[serde(default)]
license: Option<String>,
/// Declared language(s), either a single string or a list of strings.
#[serde(default)]
language: Option<ApiLanguage>,
}
/// The `language` value of a model card.
///
/// Untagged: deserialization tries `Single` first, then `Multiple`, so the
/// JSON value may be either a string or an array of strings.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiLanguage {
/// A single language given as a bare string.
Single(String),
/// Several languages given as an array of strings.
Multiple(Vec<String>),
}
/// The `gated` value of a model detail response.
///
/// Untagged: the JSON value may be either a boolean or a string.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiGated {
/// Boolean form of the flag.
Bool(bool),
/// String form of the flag, carrying the gating mode.
Mode(String),
}
impl Default for ApiGated {
/// A missing `gated` field is treated as `false`.
fn default() -> Self {
Self::Bool(false)
}
}
/// Maximum number of models requested per page by `discover_new_families`.
const PAGE_SIZE: usize = 100;
/// Base URL of the model endpoints used by this module.
const HF_API_BASE: &str = "https://huggingface.co/api/models";
/// Pages through the model listing and reports every `config.model_type`
/// that is not contained in `local_families`.
///
/// The listing is requested with `sort=downloads&direction=-1`, in pages of at
/// most `PAGE_SIZE` entries, until `max_models` entries have been read or the
/// API returns an empty page. For each `model_type` only the first model
/// encountered is kept; given the requested sort order this is expected to be
/// the most-downloaded model of that family. Models without a `model_type` are
/// skipped.
///
/// The returned list is ordered by `model_type`, because the families are
/// accumulated in a `BTreeMap`.
///
/// # Errors
///
/// Returns [`FetchError::Http`] when a request cannot be sent, the API answers
/// with a non-success status, or a response body cannot be decoded.
pub async fn discover_new_families<S: BuildHasher>(
    local_families: &std::collections::HashSet<String, S>,
    max_models: usize,
) -> Result<Vec<DiscoveredFamily>, FetchError> {
    let http = reqwest::Client::new();
    // model_type -> (first model id seen, that model's download count)
    let mut first_seen: BTreeMap<String, (String, u64)> = BTreeMap::new();
    let mut offset: usize = 0;

    while offset < max_models {
        // Never ask for more than what is still missing to reach `max_models`.
        let remaining = max_models.saturating_sub(offset);
        let limit = remaining.min(PAGE_SIZE);
        let url = format!(
            "{HF_API_BASE}?config=true&sort=downloads&direction=-1&limit={limit}&offset={offset}"
        );

        let response = http
            .get(url.as_str())
            .send()
            .await
            .map_err(|e| FetchError::Http(e.to_string()))?;
        if !response.status().is_success() {
            return Err(FetchError::Http(format!(
                "HF API returned status {}",
                response.status()
            )));
        }

        let page: Vec<ApiModelEntry> = response
            .json()
            .await
            .map_err(|e| FetchError::Http(e.to_string()))?;
        if page.is_empty() {
            // The listing is exhausted before `max_models` was reached.
            break;
        }

        let typed_entries = page.iter().filter_map(|entry| {
            let family = entry.config.as_ref()?.model_type.as_deref()?;
            Some((family, entry))
        });
        for (family, entry) in typed_entries {
            // Only the first model of each family is recorded.
            first_seen
                .entry(family.to_owned())
                .or_insert_with(|| (entry.model_id.clone(), entry.downloads));
        }

        offset = offset.saturating_add(page.len());
    }

    let mut discovered = Vec::new();
    for (model_type, (top_model, downloads)) in first_seen {
        if local_families.contains(model_type.as_str()) {
            continue;
        }
        discovered.push(DiscoveredFamily {
            model_type,
            top_model,
            downloads,
        });
    }
    Ok(discovered)
}
/// Rewrites alternative spellings of quantization terms in `query` to one
/// canonical spelling per group.
///
/// The query is split on whitespace; each token that matches a known variant
/// (ASCII case-insensitively) is replaced by its canonical form, and every
/// other token is kept verbatim. Tokens are re-joined with single spaces, so
/// leading, trailing and repeated whitespace is collapsed.
///
/// Canonical forms: `8-bit` (for `8bit`, `8-bit`, `int8`), `4-bit` (for
/// `4bit`, `4-bit`, `int4`) and `fp8` (for `fp8`, `float8`).
#[must_use]
fn normalize_quantization_terms(query: &str) -> String {
    const SYNONYMS: &[(&[&str], &str)] = &[
        (&["8bit", "8-bit", "int8"], "8-bit"),
        (&["4bit", "4-bit", "int4"], "4-bit"),
        (&["fp8", "float8"], "fp8"),
    ];

    query
        .split_whitespace()
        .map(|token| {
            // All variants are ASCII, so an ASCII case-insensitive comparison
            // is sufficient and avoids allocating a lowercased copy per token.
            SYNONYMS
                .iter()
                .find(|(variants, _)| variants.iter().any(|v| v.eq_ignore_ascii_case(token)))
                .map_or(token, |&(_, canonical)| canonical)
        })
        .collect::<Vec<&str>>()
        .join(" ")
}
pub async fn search_models(
query: &str,
limit: usize,
library: Option<&str>,
pipeline: Option<&str>,
tag: Option<&str>,
) -> Result<Vec<SearchResult>, FetchError> {
let normalized = normalize_quantization_terms(query);
let client = reqwest::Client::new();
let mut query_params: Vec<(&str, &str)> = vec![
("search", normalized.as_str()),
("sort", "downloads"),
("direction", "-1"),
];
if let Some(lib) = library {
query_params.push(("library", lib));
}
if let Some(pipe) = pipeline {
query_params.push(("pipeline_tag", pipe));
}
if let Some(t) = tag {
query_params.push(("filter", t));
}
let response = client
.get(HF_API_BASE)
.query(&query_params)
.query(&[("limit", limit)])
.send()
.await
.map_err(|e| FetchError::Http(e.to_string()))?;
if !response.status().is_success() {
return Err(FetchError::Http(format!(
"HF API returned status {}",
response.status()
)));
}
let models: Vec<ApiModelEntry> = response
.json()
.await
.map_err(|e| FetchError::Http(e.to_string()))?;
let results = models
.into_iter()
.filter(|m| {
if let Some(lib) = library {
match m.library_name {
Some(ref name) if name.as_str().eq_ignore_ascii_case(lib) => {}
_ => return false,
}
}
if let Some(pipe) = pipeline {
match m.pipeline_tag {
Some(ref t) if t.as_str().eq_ignore_ascii_case(pipe) => {}
_ => return false,
}
}
if let Some(t) = tag {
if !m.tags.iter().any(|model_tag| {
model_tag.as_str().eq_ignore_ascii_case(t)
}) {
return false;
}
}
true
})
.map(|m| SearchResult {
model_id: m.model_id,
downloads: m.downloads,
library_name: m.library_name,
pipeline_tag: m.pipeline_tag,
tags: m.tags,
})
.collect();
Ok(results)
}
pub async fn fetch_model_card(model_id: &str) -> Result<ModelCardMetadata, FetchError> {
let client = reqwest::Client::new();
let url = format!("{HF_API_BASE}/{model_id}");
let response = client
.get(url.as_str()) .send()
.await
.map_err(|e| FetchError::Http(e.to_string()))?;
if !response.status().is_success() {
return Err(FetchError::Http(format!(
"HF API returned status {} for model {model_id}",
response.status()
)));
}
let detail: ApiModelDetail = response
.json()
.await
.map_err(|e| FetchError::Http(e.to_string()))?;
let (license, languages) = if let Some(card) = detail.card_data {
let langs = match card.language {
Some(ApiLanguage::Single(s)) => vec![s],
Some(ApiLanguage::Multiple(v)) => v,
None => Vec::new(),
};
(card.license, langs)
} else {
(None, Vec::new())
};
let gated = match detail.gated {
ApiGated::Bool(false) => GateStatus::Open,
ApiGated::Mode(ref mode) if mode.eq_ignore_ascii_case("manual") => GateStatus::Manual,
ApiGated::Bool(true) | ApiGated::Mode(_) => GateStatus::Auto,
};
Ok(ModelCardMetadata {
license,
pipeline_tag: detail.pipeline_tag,
tags: detail.tags,
library_name: detail.library_name,
languages,
gated,
})
}
/// Downloads the `README.md` of `model_id` as text.
///
/// `revision` defaults to `"main"` when `None`. The URL and the HTTP client
/// come from `crate::chunked::build_download_url` and
/// `crate::chunked::build_client`; `token` is handed to the latter
/// (presumably for authentication — confirm against that helper).
///
/// Returns `Ok(None)` when the server answers `404 Not Found`, and
/// `Ok(Some(text))` on success.
///
/// # Errors
///
/// Propagates any error from `build_client`, and returns [`FetchError::Http`]
/// when the request cannot be sent, the server answers with a non-success
/// status other than 404, or the body cannot be read as text.
pub async fn fetch_readme(
    model_id: &str,
    revision: Option<&str>,
    token: Option<&str>,
) -> Result<Option<String>, FetchError> {
    let url = crate::chunked::build_download_url(model_id, revision.unwrap_or("main"), "README.md");
    let http = crate::chunked::build_client(token)?;

    let response = http
        .get(url.as_str())
        .send()
        .await
        .map_err(|e| FetchError::Http(format!("failed to fetch README for {model_id}: {e}")))?;

    // A missing README is an expected outcome, not an error.
    if response.status() == reqwest::StatusCode::NOT_FOUND {
        return Ok(None);
    }
    if !response.status().is_success() {
        return Err(FetchError::Http(format!(
            "README request for {model_id} returned status {}",
            response.status()
        )));
    }

    response
        .text()
        .await
        .map(Some)
        .map_err(|e| FetchError::Http(format!("failed to read README for {model_id}: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that each `(input, expected)` pair normalizes as expected.
    fn assert_normalizes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(
                normalize_quantization_terms(input),
                expected,
                "unexpected normalization of {input:?}"
            );
        }
    }

    /// Every 8-bit spelling, in any ASCII case, becomes `8-bit`.
    #[test]
    fn normalize_8bit_variants() {
        assert_normalizes(&[
            ("AWQ 8bit", "AWQ 8-bit"),
            ("AWQ 8-bit", "AWQ 8-bit"),
            ("AWQ int8", "AWQ 8-bit"),
            ("AWQ INT8", "AWQ 8-bit"),
        ]);
    }

    /// Every 4-bit spelling, in any ASCII case, becomes `4-bit`.
    #[test]
    fn normalize_4bit_variants() {
        assert_normalizes(&[
            ("GPTQ 4bit", "GPTQ 4-bit"),
            ("GPTQ INT4", "GPTQ 4-bit"),
            ("GPTQ 4-bit", "GPTQ 4-bit"),
        ]);
    }

    /// Every fp8 spelling, in any ASCII case, becomes `fp8`.
    #[test]
    fn normalize_fp8_variants() {
        assert_normalizes(&[("FP8", "fp8"), ("float8", "fp8"), ("fp8", "fp8")]);
    }

    /// Tokens that are not quantization terms are returned unchanged.
    #[test]
    fn normalize_passthrough() {
        assert_normalizes(&[("llama 3", "llama 3"), ("RWKV-7", "RWKV-7")]);
    }
}