Skip to main content

hf_fetch_model/
discover.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Model family discovery and search via the `HuggingFace` Hub API.
4//!
5//! Queries the HF Hub for popular models, extracts `model_type` metadata,
6//! compares against locally cached families, and fetches model card metadata.
7
8use std::collections::BTreeMap;
9use std::hash::BuildHasher;
10
11use serde::Deserialize;
12
13use crate::error::FetchError;
14
/// A single model returned by a `HuggingFace` Hub search.
///
/// Implements equality and hashing so results can be compared, deduplicated,
/// or asserted on directly in tests.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SearchResult {
    /// The repository identifier (e.g., `"RWKV/RWKV7-Goose-World3-1.5B-HF"`).
    pub model_id: String,
    /// Total download count.
    pub downloads: u64,
}
23
/// A model family discovered from the `HuggingFace` Hub.
///
/// Implements equality and hashing so discovery results can be compared,
/// deduplicated, or asserted on directly in tests.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DiscoveredFamily {
    /// The `model_type` identifier (e.g., `"gpt_neox"`, `"llama"`).
    pub model_type: String,
    /// The most-downloaded representative model for this family.
    pub top_model: String,
    /// Download count of the representative model.
    pub downloads: u64,
}
34
35/// JSON response structure for an individual model from the HF API.
36#[derive(Debug, Deserialize)]
37struct ApiModelEntry {
38    #[serde(rename = "modelId")]
39    model_id: String,
40    #[serde(default)]
41    downloads: u64,
42    #[serde(default)]
43    config: Option<ApiConfig>,
44}
45
46/// The `config` object embedded in a model API response.
47#[derive(Debug, Deserialize)]
48struct ApiConfig {
49    model_type: Option<String>,
50}
51
/// Access control status of a model on the `HuggingFace` Hub.
///
/// Some models require users to accept license terms before downloading.
/// The gating mode determines whether approval is automatic or manual.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum GateStatus {
    /// No gate — anyone can download without restrictions.
    Open,
    /// Automatic approval after the user accepts terms on the Hub.
    Auto,
    /// Manual approval by the model author after the user requests access.
    Manual,
}

impl GateStatus {
    /// Returns `true` if the model requires accepting terms before download.
    ///
    /// Both [`GateStatus::Auto`] and [`GateStatus::Manual`] count as gated;
    /// only [`GateStatus::Open`] does not.
    #[must_use]
    pub const fn is_gated(&self) -> bool {
        matches!(self, Self::Auto | Self::Manual)
    }
}

/// Formats the status as its lowercase name: `open`, `auto`, or `manual`.
///
/// Width, fill, alignment, and precision flags are honoured, so the value can
/// be used directly in column-aligned output (e.g. `format!("{status:<8}")`).
impl std::fmt::Display for GateStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Open => "open",
            Self::Auto => "auto",
            Self::Manual => "manual",
        };
        // `pad` applies the caller's formatting flags; `write!` with a bare
        // literal would silently ignore them.
        f.pad(label)
    }
}
84
85/// Metadata from a `HuggingFace` model card.
86///
87/// Extracted from the single-model API endpoint
88/// (`GET /api/models/{owner}/{model}`). All fields are optional
89/// because model cards may omit any of them.
90#[derive(Debug, Clone)]
91pub struct ModelCardMetadata {
92    /// SPDX license identifier (e.g., `"apache-2.0"`).
93    pub license: Option<String>,
94    /// Pipeline tag (e.g., `"text-generation"`).
95    pub pipeline_tag: Option<String>,
96    /// Tags associated with the model (e.g., `["pytorch", "safetensors"]`).
97    pub tags: Vec<String>,
98    /// Library name (e.g., `"transformers"`, `"vllm"`).
99    pub library_name: Option<String>,
100    /// Languages the model supports (e.g., `["en", "fr"]`).
101    pub languages: Vec<String>,
102    /// Access control status (open, auto-gated, or manually gated).
103    pub gated: GateStatus,
104}
105
106/// JSON response for a single model from `GET /api/models/{model_id}`.
107#[derive(Debug, Deserialize)]
108struct ApiModelDetail {
109    #[serde(default)]
110    pipeline_tag: Option<String>,
111    #[serde(default)]
112    tags: Vec<String>,
113    #[serde(default)]
114    library_name: Option<String>,
115    #[serde(default)]
116    gated: ApiGated,
117    #[serde(default, rename = "cardData")]
118    card_data: Option<ApiCardData>,
119}
120
121/// The `cardData` sub-object (parsed YAML front matter from the model README).
122#[derive(Debug, Deserialize)]
123struct ApiCardData {
124    #[serde(default)]
125    license: Option<String>,
126    #[serde(default)]
127    language: Option<ApiLanguage>,
128}
129
130/// Languages in `cardData` can be a single string or a list of strings.
131#[derive(Debug, Deserialize)]
132#[serde(untagged)]
133enum ApiLanguage {
134    Single(String),
135    Multiple(Vec<String>),
136}
137
138/// The `gated` field can be `false` (boolean) or a string like `"auto"` / `"manual"`.
139#[derive(Debug, Deserialize)]
140#[serde(untagged)]
141enum ApiGated {
142    Bool(bool),
143    Mode(String),
144}
145
146impl Default for ApiGated {
147    fn default() -> Self {
148        Self::Bool(false)
149    }
150}
151
/// Upper bound on the number of models requested per page of the listing API.
const PAGE_SIZE: usize = 100;
/// Base URL of the `HuggingFace` Hub model API (listing, search, and,
/// with a `/{model_id}` suffix, single-model lookups).
const HF_API_BASE: &str = "https://huggingface.co/api/models";
154
155/// Queries the `HuggingFace` Hub API for top models by downloads
156/// and returns families not present in the local cache.
157///
158/// # Arguments
159///
160/// * `local_families` — Set of `model_type` values already cached locally.
161/// * `max_models` — Maximum number of models to scan (paginated in batches of 100).
162///
163/// # Errors
164///
165/// Returns [`FetchError::Http`] if any API request fails.
166pub async fn discover_new_families<S: BuildHasher>(
167    local_families: &std::collections::HashSet<String, S>,
168    max_models: usize,
169) -> Result<Vec<DiscoveredFamily>, FetchError> {
170    let client = reqwest::Client::new();
171    let mut remote_families: BTreeMap<String, (String, u64)> = BTreeMap::new();
172    let mut offset: usize = 0;
173
174    while offset < max_models {
175        let limit = PAGE_SIZE.min(max_models.saturating_sub(offset));
176        let url = format!(
177            "{HF_API_BASE}?config=true&sort=downloads&direction=-1&limit={limit}&offset={offset}"
178        );
179
180        let response = client
181            .get(url.as_str()) // BORROW: explicit .as_str()
182            .send()
183            .await
184            .map_err(|e| FetchError::Http(e.to_string()))?;
185
186        if !response.status().is_success() {
187            return Err(FetchError::Http(format!(
188                "HF API returned status {}",
189                response.status()
190            )));
191        }
192
193        let models: Vec<ApiModelEntry> = response
194            .json()
195            .await
196            .map_err(|e| FetchError::Http(e.to_string()))?;
197
198        if models.is_empty() {
199            break;
200        }
201
202        for model in &models {
203            // BORROW: explicit .as_ref() and .as_str() for Option<String>
204            let model_type = model.config.as_ref().and_then(|c| c.model_type.as_deref());
205
206            if let Some(mt) = model_type {
207                remote_families
208                    .entry(mt.to_owned())
209                    .or_insert_with(|| (model.model_id.clone(), model.downloads));
210            }
211        }
212
213        offset = offset.saturating_add(models.len());
214    }
215
216    // Filter to families not already cached locally
217    // BORROW: explicit .as_str() instead of Deref coercion
218    let discovered: Vec<DiscoveredFamily> = remote_families
219        .into_iter()
220        .filter(|(mt, _)| !local_families.contains(mt.as_str()))
221        .map(|(model_type, (top_model, downloads))| DiscoveredFamily {
222            model_type,
223            top_model,
224            downloads,
225        })
226        .collect();
227
228    Ok(discovered)
229}
230
231/// Searches the `HuggingFace` Hub for models matching a query string.
232///
233/// Results are sorted by download count (most popular first).
234///
235/// # Arguments
236///
237/// * `query` — Free-text search string (e.g., `"RWKV-7"`, `"llama 3"`).
238/// * `limit` — Maximum number of results to return.
239///
240/// # Errors
241///
242/// Returns [`FetchError::Http`] if the API request fails.
243pub async fn search_models(query: &str, limit: usize) -> Result<Vec<SearchResult>, FetchError> {
244    let client = reqwest::Client::new();
245
246    let response = client
247        .get(HF_API_BASE)
248        .query(&[
249            ("search", query),
250            ("sort", "downloads"),
251            ("direction", "-1"),
252        ])
253        .query(&[("limit", limit)])
254        .send()
255        .await
256        .map_err(|e| FetchError::Http(e.to_string()))?;
257
258    if !response.status().is_success() {
259        return Err(FetchError::Http(format!(
260            "HF API returned status {}",
261            response.status()
262        )));
263    }
264
265    let models: Vec<ApiModelEntry> = response
266        .json()
267        .await
268        .map_err(|e| FetchError::Http(e.to_string()))?;
269
270    let results = models
271        .into_iter()
272        .map(|m| SearchResult {
273            model_id: m.model_id,
274            downloads: m.downloads,
275        })
276        .collect();
277
278    Ok(results)
279}
280
281/// Fetches model card metadata for a specific model from the `HuggingFace` Hub.
282///
283/// Queries `GET https://huggingface.co/api/models/{model_id}` and extracts
284/// license, pipeline tag, tags, library name, and languages from the response.
285///
286/// # Arguments
287///
288/// * `model_id` — The full model identifier (e.g., `"mistralai/Ministral-3-3B-Instruct-2512"`).
289///
290/// # Errors
291///
292/// Returns [`FetchError::Http`] if the API request fails or the model is not found.
293pub async fn fetch_model_card(model_id: &str) -> Result<ModelCardMetadata, FetchError> {
294    let client = reqwest::Client::new();
295    let url = format!("{HF_API_BASE}/{model_id}");
296
297    let response = client
298        .get(url.as_str()) // BORROW: explicit .as_str()
299        .send()
300        .await
301        .map_err(|e| FetchError::Http(e.to_string()))?;
302
303    if !response.status().is_success() {
304        return Err(FetchError::Http(format!(
305            "HF API returned status {} for model {model_id}",
306            response.status()
307        )));
308    }
309
310    let detail: ApiModelDetail = response
311        .json()
312        .await
313        .map_err(|e| FetchError::Http(e.to_string()))?;
314
315    let (license, languages) = if let Some(card) = detail.card_data {
316        let langs = match card.language {
317            Some(ApiLanguage::Single(s)) => vec![s],
318            Some(ApiLanguage::Multiple(v)) => v,
319            None => Vec::new(),
320        };
321        (card.license, langs)
322    } else {
323        (None, Vec::new())
324    };
325
326    let gated = match detail.gated {
327        ApiGated::Bool(false) => GateStatus::Open,
328        ApiGated::Mode(ref mode) if mode.eq_ignore_ascii_case("manual") => GateStatus::Manual,
329        ApiGated::Bool(true) | ApiGated::Mode(_) => GateStatus::Auto,
330    };
331
332    Ok(ModelCardMetadata {
333        license,
334        pipeline_tag: detail.pipeline_tag,
335        tags: detail.tags,
336        library_name: detail.library_name,
337        languages,
338        gated,
339    })
340}