Skip to main content

hf_fetch_model/
discover.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Model family discovery and search via the `HuggingFace` Hub API.
4//!
5//! Queries the HF Hub for popular models, extracts `model_type` metadata,
6//! compares against locally cached families, and fetches model card metadata.
7
8use std::collections::BTreeMap;
9use std::hash::BuildHasher;
10
11use serde::Deserialize;
12
13use crate::error::FetchError;
14
/// One hit returned by a free-text model search on the `HuggingFace` Hub.
///
/// Produced by [`search_models`]; every field is copied verbatim from the
/// corresponding entry of the API response.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Repository identifier in `owner/name` form
    /// (e.g., `"RWKV/RWKV7-Goose-World3-1.5B-HF"`).
    pub model_id: String,
    /// Download count reported by the Hub (`0` when the API omits it).
    pub downloads: u64,
    /// Framework the model is published for (e.g., `"transformers"`,
    /// `"peft"`, `"diffusers"`); `None` when the API does not report one.
    pub library_name: Option<String>,
    /// Task tag of the model (e.g., `"text-generation"`); `None` when the
    /// API does not report one.
    pub pipeline_tag: Option<String>,
}
27
/// A `model_type` family seen on the `HuggingFace` Hub, together with the
/// model chosen to represent it.
#[derive(Debug, Clone)]
pub struct DiscoveredFamily {
    /// Family identifier taken from the model's `config.model_type`
    /// (e.g., `"gpt_neox"`, `"llama"`).
    pub model_type: String,
    /// Repository identifier of the representative (most-downloaded) model.
    pub top_model: String,
    /// Download count of [`Self::top_model`].
    pub downloads: u64,
}
38
39/// JSON response structure for an individual model from the HF API.
40#[derive(Debug, Deserialize)]
41struct ApiModelEntry {
42    #[serde(rename = "modelId")]
43    model_id: String,
44    #[serde(default)]
45    downloads: u64,
46    #[serde(default)]
47    config: Option<ApiConfig>,
48    #[serde(default)]
49    library_name: Option<String>,
50    #[serde(default)]
51    pipeline_tag: Option<String>,
52}
53
54/// The `config` object embedded in a model API response.
55#[derive(Debug, Deserialize)]
56struct ApiConfig {
57    model_type: Option<String>,
58}
59
/// Access control status of a model on the `HuggingFace` Hub.
///
/// Some models require users to accept license terms before downloading.
/// The gating mode determines whether approval is automatic or manual.
///
/// The [`Display`](std::fmt::Display) form is the lowercase mode name
/// (`"open"`, `"auto"`, `"manual"`) and honors width, alignment, and
/// precision format specifiers, so it can be used directly in aligned
/// table output (e.g., `format!("{status:<8}")`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum GateStatus {
    /// No gate — anyone can download without restrictions.
    Open,
    /// Automatic approval after the user accepts terms on the Hub.
    Auto,
    /// Manual approval by the model author after the user requests access.
    Manual,
}

impl GateStatus {
    /// Returns `true` if the model requires accepting terms before download,
    /// i.e. for every status other than [`GateStatus::Open`].
    #[must_use]
    pub const fn is_gated(&self) -> bool {
        matches!(self, Self::Auto | Self::Manual)
    }
}

impl std::fmt::Display for GateStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Open => "open",
            Self::Auto => "auto",
            Self::Manual => "manual",
        };
        // `pad` (unlike `write!`) applies the caller's width / fill /
        // alignment / precision flags; with a bare `{}` it writes `label`
        // unchanged.
        f.pad(label)
    }
}
92
93/// Metadata from a `HuggingFace` model card.
94///
95/// Extracted from the single-model API endpoint
96/// (`GET /api/models/{owner}/{model}`). All fields are optional
97/// because model cards may omit any of them.
98#[derive(Debug, Clone)]
99pub struct ModelCardMetadata {
100    /// SPDX license identifier (e.g., `"apache-2.0"`).
101    pub license: Option<String>,
102    /// Pipeline tag (e.g., `"text-generation"`).
103    pub pipeline_tag: Option<String>,
104    /// Tags associated with the model (e.g., `["pytorch", "safetensors"]`).
105    pub tags: Vec<String>,
106    /// Library name (e.g., `"transformers"`, `"vllm"`).
107    pub library_name: Option<String>,
108    /// Languages the model supports (e.g., `["en", "fr"]`).
109    pub languages: Vec<String>,
110    /// Access control status (open, auto-gated, or manually gated).
111    pub gated: GateStatus,
112}
113
114/// JSON response for a single model from `GET /api/models/{model_id}`.
115#[derive(Debug, Deserialize)]
116struct ApiModelDetail {
117    #[serde(default)]
118    pipeline_tag: Option<String>,
119    #[serde(default)]
120    tags: Vec<String>,
121    #[serde(default)]
122    library_name: Option<String>,
123    #[serde(default)]
124    gated: ApiGated,
125    #[serde(default, rename = "cardData")]
126    card_data: Option<ApiCardData>,
127}
128
129/// The `cardData` sub-object (parsed YAML front matter from the model README).
130#[derive(Debug, Deserialize)]
131struct ApiCardData {
132    #[serde(default)]
133    license: Option<String>,
134    #[serde(default)]
135    language: Option<ApiLanguage>,
136}
137
138/// Languages in `cardData` can be a single string or a list of strings.
139#[derive(Debug, Deserialize)]
140#[serde(untagged)]
141enum ApiLanguage {
142    Single(String),
143    Multiple(Vec<String>),
144}
145
146/// The `gated` field can be `false` (boolean) or a string like `"auto"` / `"manual"`.
147#[derive(Debug, Deserialize)]
148#[serde(untagged)]
149enum ApiGated {
150    Bool(bool),
151    Mode(String),
152}
153
154impl Default for ApiGated {
155    fn default() -> Self {
156        Self::Bool(false)
157    }
158}
159
/// Upper bound on the number of models requested per API call when paging
/// through the model list.
const PAGE_SIZE: usize = 100;
/// Model-list endpoint of the `HuggingFace` Hub API; a single model's
/// endpoint is this URL followed by `/{model_id}`.
const HF_API_BASE: &str = "https://huggingface.co/api/models";
162
163/// Queries the `HuggingFace` Hub API for top models by downloads
164/// and returns families not present in the local cache.
165///
166/// # Arguments
167///
168/// * `local_families` — Set of `model_type` values already cached locally.
169/// * `max_models` — Maximum number of models to scan (paginated in batches of 100).
170///
171/// # Errors
172///
173/// Returns [`FetchError::Http`] if any API request fails.
174pub async fn discover_new_families<S: BuildHasher>(
175    local_families: &std::collections::HashSet<String, S>,
176    max_models: usize,
177) -> Result<Vec<DiscoveredFamily>, FetchError> {
178    let client = reqwest::Client::new();
179    let mut remote_families: BTreeMap<String, (String, u64)> = BTreeMap::new();
180    let mut offset: usize = 0;
181
182    while offset < max_models {
183        let limit = PAGE_SIZE.min(max_models.saturating_sub(offset));
184        let url = format!(
185            "{HF_API_BASE}?config=true&sort=downloads&direction=-1&limit={limit}&offset={offset}"
186        );
187
188        let response = client
189            .get(url.as_str()) // BORROW: explicit .as_str()
190            .send()
191            .await
192            .map_err(|e| FetchError::Http(e.to_string()))?;
193
194        if !response.status().is_success() {
195            return Err(FetchError::Http(format!(
196                "HF API returned status {}",
197                response.status()
198            )));
199        }
200
201        let models: Vec<ApiModelEntry> = response
202            .json()
203            .await
204            .map_err(|e| FetchError::Http(e.to_string()))?;
205
206        if models.is_empty() {
207            break;
208        }
209
210        for model in &models {
211            // BORROW: explicit .as_ref() and .as_str() for Option<String>
212            let model_type = model.config.as_ref().and_then(|c| c.model_type.as_deref());
213
214            if let Some(mt) = model_type {
215                remote_families
216                    .entry(mt.to_owned())
217                    .or_insert_with(|| (model.model_id.clone(), model.downloads));
218            }
219        }
220
221        offset = offset.saturating_add(models.len());
222    }
223
224    // Filter to families not already cached locally
225    // BORROW: explicit .as_str() instead of Deref coercion
226    let discovered: Vec<DiscoveredFamily> = remote_families
227        .into_iter()
228        .filter(|(mt, _)| !local_families.contains(mt.as_str()))
229        .map(|(model_type, (top_model, downloads))| DiscoveredFamily {
230            model_type,
231            top_model,
232            downloads,
233        })
234        .collect();
235
236    Ok(discovered)
237}
238
/// Rewrites common quantization spellings in a search query to one canonical
/// form, so that `"8bit"`, `"8-bit"`, and `"int8"` (in any letter case) all
/// produce the same query.
///
/// The query is split on whitespace and re-joined with single spaces; tokens
/// that are not a known spelling are passed through unchanged.
#[must_use]
fn normalize_quantization_terms(query: &str) -> String {
    /// Each entry pairs the accepted spellings (left) with the canonical
    /// spelling they are rewritten to (right).
    const SYNONYMS: &[(&[&str], &str)] = &[
        (&["8bit", "8-bit", "int8"], "8-bit"),
        (&["4bit", "4-bit", "int4"], "4-bit"),
        (&["fp8", "float8"], "fp8"),
    ];

    let mut normalized = String::with_capacity(query.len());

    for word in query.split_whitespace() {
        if !normalized.is_empty() {
            normalized.push(' ');
        }

        // All spellings are ASCII, so an ASCII case-insensitive comparison
        // is sufficient to match them.
        let canonical = SYNONYMS
            .iter()
            .find(|(spellings, _)| {
                spellings
                    .iter()
                    .any(|spelling| spelling.eq_ignore_ascii_case(word))
            })
            .map(|&(_, canonical)| canonical);

        normalized.push_str(canonical.unwrap_or(word));
    }

    normalized
}
269
270/// Searches the `HuggingFace` Hub for models matching a query string.
271///
272/// Optionally filters by `library` framework (e.g., `"transformers"`, `"peft"`)
273/// and/or `pipeline` task tag (e.g., `"text-generation"`). These filters are
274/// applied server-side by the `HuggingFace` API.
275///
276/// Common quantization synonyms (`"8bit"` / `"8-bit"` / `"int8"`,
277/// `"4bit"` / `"4-bit"` / `"int4"`, `"fp8"` / `"float8"`) are normalized
278/// before querying the API so that variant spellings return consistent results.
279///
280/// Results are sorted by download count (most popular first).
281///
282/// # Arguments
283///
284/// * `query` — Free-text search string (e.g., `"RWKV-7"`, `"llama 3"`).
285/// * `limit` — Maximum number of results to return.
286/// * `library` — Optional library filter (e.g., `"peft"`, `"transformers"`).
287/// * `pipeline` — Optional pipeline tag filter (e.g., `"text-generation"`).
288///
289/// # Errors
290///
291/// Returns [`FetchError::Http`] if the API request fails.
292pub async fn search_models(
293    query: &str,
294    limit: usize,
295    library: Option<&str>,
296    pipeline: Option<&str>,
297) -> Result<Vec<SearchResult>, FetchError> {
298    let normalized = normalize_quantization_terms(query);
299    let client = reqwest::Client::new();
300
301    // BORROW: explicit .as_str() instead of Deref coercion
302    let mut query_params: Vec<(&str, &str)> = vec![
303        ("search", normalized.as_str()),
304        ("sort", "downloads"),
305        ("direction", "-1"),
306    ];
307    if let Some(lib) = library {
308        query_params.push(("library", lib));
309    }
310    if let Some(pipe) = pipeline {
311        query_params.push(("pipeline_tag", pipe));
312    }
313
314    let response = client
315        .get(HF_API_BASE)
316        .query(&query_params)
317        .query(&[("limit", limit)])
318        .send()
319        .await
320        .map_err(|e| FetchError::Http(e.to_string()))?;
321
322    if !response.status().is_success() {
323        return Err(FetchError::Http(format!(
324            "HF API returned status {}",
325            response.status()
326        )));
327    }
328
329    let models: Vec<ApiModelEntry> = response
330        .json()
331        .await
332        .map_err(|e| FetchError::Http(e.to_string()))?;
333
334    // Client-side filtering: the HF search API may ignore library/pipeline_tag
335    // query parameters when combined with the `search` parameter, so we filter
336    // the results ourselves to guarantee correctness.
337    let results = models
338        .into_iter()
339        .filter(|m| {
340            if let Some(lib) = library {
341                match m.library_name {
342                    // BORROW: explicit .as_str() instead of Deref coercion
343                    Some(ref name) if name.as_str().eq_ignore_ascii_case(lib) => {}
344                    _ => return false,
345                }
346            }
347            if let Some(pipe) = pipeline {
348                match m.pipeline_tag {
349                    // BORROW: explicit .as_str() instead of Deref coercion
350                    Some(ref tag) if tag.as_str().eq_ignore_ascii_case(pipe) => {}
351                    _ => return false,
352                }
353            }
354            true
355        })
356        .map(|m| SearchResult {
357            model_id: m.model_id,
358            downloads: m.downloads,
359            library_name: m.library_name,
360            pipeline_tag: m.pipeline_tag,
361        })
362        .collect();
363
364    Ok(results)
365}
366
367/// Fetches model card metadata for a specific model from the `HuggingFace` Hub.
368///
369/// Queries `GET https://huggingface.co/api/models/{model_id}` and extracts
370/// license, pipeline tag, tags, library name, and languages from the response.
371///
372/// # Arguments
373///
374/// * `model_id` — The full model identifier (e.g., `"mistralai/Ministral-3-3B-Instruct-2512"`).
375///
376/// # Errors
377///
378/// Returns [`FetchError::Http`] if the API request fails or the model is not found.
379pub async fn fetch_model_card(model_id: &str) -> Result<ModelCardMetadata, FetchError> {
380    let client = reqwest::Client::new();
381    let url = format!("{HF_API_BASE}/{model_id}");
382
383    let response = client
384        .get(url.as_str()) // BORROW: explicit .as_str()
385        .send()
386        .await
387        .map_err(|e| FetchError::Http(e.to_string()))?;
388
389    if !response.status().is_success() {
390        return Err(FetchError::Http(format!(
391            "HF API returned status {} for model {model_id}",
392            response.status()
393        )));
394    }
395
396    let detail: ApiModelDetail = response
397        .json()
398        .await
399        .map_err(|e| FetchError::Http(e.to_string()))?;
400
401    let (license, languages) = if let Some(card) = detail.card_data {
402        let langs = match card.language {
403            Some(ApiLanguage::Single(s)) => vec![s],
404            Some(ApiLanguage::Multiple(v)) => v,
405            None => Vec::new(),
406        };
407        (card.license, langs)
408    } else {
409        (None, Vec::new())
410    };
411
412    let gated = match detail.gated {
413        ApiGated::Bool(false) => GateStatus::Open,
414        ApiGated::Mode(ref mode) if mode.eq_ignore_ascii_case("manual") => GateStatus::Manual,
415        ApiGated::Bool(true) | ApiGated::Mode(_) => GateStatus::Auto,
416    };
417
418    Ok(ModelCardMetadata {
419        license,
420        pipeline_tag: detail.pipeline_tag,
421        tags: detail.tags,
422        library_name: detail.library_name,
423        languages,
424        gated,
425    })
426}
427
428/// Fetches the raw README text for a `HuggingFace` model repository.
429///
430/// Downloads `README.md` from the repository at the given revision.
431/// Returns `Ok(None)` if the file does not exist (HTTP 404).
432///
433/// # Arguments
434///
435/// * `model_id` — The full model identifier (e.g., `"mistralai/Ministral-3-3B-Instruct-2512"`).
436/// * `revision` — Git revision to fetch (defaults to `"main"` when `None`).
437/// * `token` — Optional authentication token.
438///
439/// # Errors
440///
441/// Returns [`FetchError::Http`] if the request fails (other than 404).
442pub async fn fetch_readme(
443    model_id: &str,
444    revision: Option<&str>,
445    token: Option<&str>,
446) -> Result<Option<String>, FetchError> {
447    let rev = revision.unwrap_or("main");
448    let url = crate::chunked::build_download_url(model_id, rev, "README.md");
449    let client = crate::chunked::build_client(token)?;
450
451    let response = client
452        .get(url.as_str()) // BORROW: explicit .as_str() instead of Deref coercion
453        .send()
454        .await
455        .map_err(|e| FetchError::Http(format!("failed to fetch README for {model_id}: {e}")))?;
456
457    if response.status() == reqwest::StatusCode::NOT_FOUND {
458        return Ok(None);
459    }
460
461    if !response.status().is_success() {
462        return Err(FetchError::Http(format!(
463            "README request for {model_id} returned status {}",
464            response.status()
465        )));
466    }
467
468    let text = response
469        .text()
470        .await
471        .map_err(|e| FetchError::Http(format!("failed to read README for {model_id}: {e}")))?;
472
473    Ok(Some(text))
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that each `(input, expected)` pair normalizes to `expected`.
    fn assert_normalizes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(
                normalize_quantization_terms(input),
                expected,
                "unexpected normalization of {input:?}"
            );
        }
    }

    #[test]
    fn normalize_8bit_variants() {
        assert_normalizes(&[
            ("AWQ 8bit", "AWQ 8-bit"),
            ("AWQ 8-bit", "AWQ 8-bit"),
            ("AWQ int8", "AWQ 8-bit"),
            ("AWQ INT8", "AWQ 8-bit"),
        ]);
    }

    #[test]
    fn normalize_4bit_variants() {
        assert_normalizes(&[
            ("GPTQ 4bit", "GPTQ 4-bit"),
            ("GPTQ INT4", "GPTQ 4-bit"),
            ("GPTQ 4-bit", "GPTQ 4-bit"),
        ]);
    }

    #[test]
    fn normalize_fp8_variants() {
        assert_normalizes(&[("FP8", "fp8"), ("float8", "fp8"), ("fp8", "fp8")]);
    }

    #[test]
    fn normalize_passthrough() {
        // Queries without quantization terms must come back untouched.
        assert_normalizes(&[("llama 3", "llama 3"), ("RWKV-7", "RWKV-7")]);
    }
}