Skip to main content

hf_fetch_model/
discover.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Model family discovery and search via the `HuggingFace` Hub API.
4//!
5//! Queries the HF Hub for popular models, extracts `model_type` metadata,
6//! compares against locally cached families, and fetches model card metadata.
7
8use std::collections::BTreeMap;
9use std::hash::BuildHasher;
10
11use serde::Deserialize;
12
13use crate::error::FetchError;
14
/// One hit returned by a `HuggingFace` Hub model search.
///
/// Each field is copied from the matching field of the Hub's JSON entry
/// for the model.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Repository identifier, e.g. `"RWKV/RWKV7-Goose-World3-1.5B-HF"`.
    pub model_id: String,
    /// Download count reported by the Hub (`0` when the Hub omits it).
    pub downloads: u64,
    /// Framework the model is published for (e.g. `"transformers"`,
    /// `"peft"`, `"diffusers"`); `None` when the Hub does not report one.
    pub library_name: Option<String>,
    /// Task tag such as `"text-generation"`; `None` when not reported.
    pub pipeline_tag: Option<String>,
    /// Metadata tags attached to the model (e.g. `["gguf", "conversational"]`);
    /// empty when the Hub reports none.
    pub tags: Vec<String>,
}
29
/// A `model_type` family found on the `HuggingFace` Hub, together with the
/// model chosen to represent it.
#[derive(Clone, Debug)]
pub struct DiscoveredFamily {
    /// Family identifier taken from the model's `config.model_type`
    /// (e.g. `"gpt_neox"`, `"llama"`).
    pub model_type: String,
    /// Repository identifier of the family's representative model.
    pub top_model: String,
    /// Download count the Hub reported for the representative model.
    pub downloads: u64,
}
40
41/// JSON response structure for an individual model from the HF API.
42#[derive(Debug, Deserialize)]
43struct ApiModelEntry {
44    #[serde(rename = "modelId")]
45    model_id: String,
46    #[serde(default)]
47    downloads: u64,
48    #[serde(default)]
49    config: Option<ApiConfig>,
50    #[serde(default)]
51    library_name: Option<String>,
52    #[serde(default)]
53    pipeline_tag: Option<String>,
54    #[serde(default)]
55    tags: Vec<String>,
56}
57
58/// The `config` object embedded in a model API response.
59#[derive(Debug, Deserialize)]
60struct ApiConfig {
61    model_type: Option<String>,
62}
63
/// Gating mode of a model repository on the `HuggingFace` Hub.
///
/// Gated models ask users to accept terms before their files can be
/// downloaded; the variant records whether access is then granted
/// automatically or has to be approved by the model author.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum GateStatus {
    /// Ungated — downloads need no prior acceptance of terms.
    Open,
    /// Gated, with access granted automatically once the user accepts the
    /// terms on the Hub.
    Auto,
    /// Gated, with each access request approved manually by the model author.
    Manual,
}

impl GateStatus {
    /// Returns `true` for every gated mode ([`Self::Auto`] and
    /// [`Self::Manual`]) and `false` for [`Self::Open`].
    #[must_use]
    pub const fn is_gated(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide which side it is on.
        match self {
            Self::Open => false,
            Self::Auto | Self::Manual => true,
        }
    }
}

impl std::fmt::Display for GateStatus {
    /// Writes the lowercase mode name: `open`, `auto`, or `manual`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Open => "open",
            Self::Auto => "auto",
            Self::Manual => "manual",
        };
        f.write_str(label)
    }
}
96
97/// Metadata from a `HuggingFace` model card.
98///
99/// Extracted from the single-model API endpoint
100/// (`GET /api/models/{owner}/{model}`). All fields are optional
101/// because model cards may omit any of them.
102#[derive(Debug, Clone)]
103pub struct ModelCardMetadata {
104    /// SPDX license identifier (e.g., `"apache-2.0"`).
105    pub license: Option<String>,
106    /// Pipeline tag (e.g., `"text-generation"`).
107    pub pipeline_tag: Option<String>,
108    /// Tags associated with the model (e.g., `["pytorch", "safetensors"]`).
109    pub tags: Vec<String>,
110    /// Library name (e.g., `"transformers"`, `"vllm"`).
111    pub library_name: Option<String>,
112    /// Languages the model supports (e.g., `["en", "fr"]`).
113    pub languages: Vec<String>,
114    /// Access control status (open, auto-gated, or manually gated).
115    pub gated: GateStatus,
116}
117
118/// JSON response for a single model from `GET /api/models/{model_id}`.
119#[derive(Debug, Deserialize)]
120struct ApiModelDetail {
121    #[serde(default)]
122    pipeline_tag: Option<String>,
123    #[serde(default)]
124    tags: Vec<String>,
125    #[serde(default)]
126    library_name: Option<String>,
127    #[serde(default)]
128    gated: ApiGated,
129    #[serde(default, rename = "cardData")]
130    card_data: Option<ApiCardData>,
131}
132
133/// The `cardData` sub-object (parsed YAML front matter from the model README).
134#[derive(Debug, Deserialize)]
135struct ApiCardData {
136    #[serde(default)]
137    license: Option<String>,
138    #[serde(default)]
139    language: Option<ApiLanguage>,
140}
141
142/// Languages in `cardData` can be a single string or a list of strings.
143#[derive(Debug, Deserialize)]
144#[serde(untagged)]
145enum ApiLanguage {
146    Single(String),
147    Multiple(Vec<String>),
148}
149
150/// The `gated` field can be `false` (boolean) or a string like `"auto"` / `"manual"`.
151#[derive(Debug, Deserialize)]
152#[serde(untagged)]
153enum ApiGated {
154    Bool(bool),
155    Mode(String),
156}
157
158impl Default for ApiGated {
159    fn default() -> Self {
160        Self::Bool(false)
161    }
162}
163
/// Upper bound on the number of models requested per page while scanning
/// the Hub for model families.
const PAGE_SIZE: usize = 100;
/// Base URL of the `HuggingFace` Hub models API; listing and search requests
/// go to this URL, single-model lookups append `/{model_id}`.
const HF_API_BASE: &str = "https://huggingface.co/api/models";
166
167/// Queries the `HuggingFace` Hub API for top models by downloads
168/// and returns families not present in the local cache.
169///
170/// # Arguments
171///
172/// * `local_families` — Set of `model_type` values already cached locally.
173/// * `max_models` — Maximum number of models to scan (paginated in batches of 100).
174///
175/// # Errors
176///
177/// Returns [`FetchError::Http`] if any API request fails.
178pub async fn discover_new_families<S: BuildHasher>(
179    local_families: &std::collections::HashSet<String, S>,
180    max_models: usize,
181) -> Result<Vec<DiscoveredFamily>, FetchError> {
182    let client = reqwest::Client::new();
183    let mut remote_families: BTreeMap<String, (String, u64)> = BTreeMap::new();
184    let mut offset: usize = 0;
185
186    while offset < max_models {
187        let limit = PAGE_SIZE.min(max_models.saturating_sub(offset));
188        let url = format!(
189            "{HF_API_BASE}?config=true&sort=downloads&direction=-1&limit={limit}&offset={offset}"
190        );
191
192        let response = client
193            .get(url.as_str()) // BORROW: explicit .as_str()
194            .send()
195            .await
196            .map_err(|e| FetchError::Http(e.to_string()))?;
197
198        if !response.status().is_success() {
199            return Err(FetchError::Http(format!(
200                "HF API returned status {}",
201                response.status()
202            )));
203        }
204
205        let models: Vec<ApiModelEntry> = response
206            .json()
207            .await
208            .map_err(|e| FetchError::Http(e.to_string()))?;
209
210        if models.is_empty() {
211            break;
212        }
213
214        for model in &models {
215            // BORROW: explicit .as_ref() and .as_str() for Option<String>
216            let model_type = model.config.as_ref().and_then(|c| c.model_type.as_deref());
217
218            if let Some(mt) = model_type {
219                remote_families
220                    .entry(mt.to_owned())
221                    .or_insert_with(|| (model.model_id.clone(), model.downloads));
222            }
223        }
224
225        offset = offset.saturating_add(models.len());
226    }
227
228    // Filter to families not already cached locally
229    // BORROW: explicit .as_str() instead of Deref coercion
230    let discovered: Vec<DiscoveredFamily> = remote_families
231        .into_iter()
232        .filter(|(mt, _)| !local_families.contains(mt.as_str()))
233        .map(|(model_type, (top_model, downloads))| DiscoveredFamily {
234            model_type,
235            top_model,
236            downloads,
237        })
238        .collect();
239
240    Ok(discovered)
241}
242
/// Rewrites quantization synonyms in a search query to one canonical
/// spelling, so that e.g. `"8bit"`, `"8-bit"` and `"int8"` all search for
/// the same thing.
///
/// The query is split on whitespace; tokens that match a known synonym
/// (ignoring case) are replaced by its canonical form, all other tokens are
/// kept verbatim. Tokens are re-joined with single spaces, so leading,
/// trailing and repeated whitespace is not preserved.
#[must_use]
fn normalize_quantization_terms(query: &str) -> String {
    /// Synonym table: `(accepted spellings, canonical spelling)`. A token
    /// matching any accepted spelling is replaced by the canonical one.
    const SYNONYMS: &[(&[&str], &str)] = &[
        (&["8bit", "8-bit", "int8"], "8-bit"),
        (&["4bit", "4-bit", "int4"], "4-bit"),
        (&["fp8", "float8"], "fp8"),
    ];

    let mut normalized = String::with_capacity(query.len());
    for token in query.split_whitespace() {
        // Tokens are never empty, so a non-empty buffer means a token was
        // already written and a separator is due.
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        // Every accepted spelling is plain ASCII, so an ASCII
        // case-insensitive comparison is sufficient here.
        let replacement = SYNONYMS
            .iter()
            .find(|(variants, _)| variants.iter().any(|v| v.eq_ignore_ascii_case(token)))
            .map_or(token, |&(_, canonical)| canonical);
        normalized.push_str(replacement);
    }
    normalized
}
273
274/// Searches the `HuggingFace` Hub for models matching a query string.
275///
276/// Optionally filters by `library` framework (e.g., `"transformers"`, `"peft"`),
277/// `pipeline` task tag (e.g., `"text-generation"`), and/or `tag` (e.g., `"gguf"`).
278/// Library and pipeline filters are sent as query parameters; tag is sent via the
279/// `filter` parameter. All three are also applied client-side for correctness.
280///
281/// Common quantization synonyms (`"8bit"` / `"8-bit"` / `"int8"`,
282/// `"4bit"` / `"4-bit"` / `"int4"`, `"fp8"` / `"float8"`) are normalized
283/// before querying the API so that variant spellings return consistent results.
284///
285/// Results are sorted by download count (most popular first).
286///
287/// # Arguments
288///
289/// * `query` — Free-text search string (e.g., `"RWKV-7"`, `"llama 3"`).
290/// * `limit` — Maximum number of results to return.
291/// * `library` — Optional library filter (e.g., `"peft"`, `"transformers"`).
292/// * `pipeline` — Optional pipeline tag filter (e.g., `"text-generation"`).
293/// * `tag` — Optional tag filter (e.g., `"gguf"`, `"conversational"`).
294///
295/// # Errors
296///
297/// Returns [`FetchError::Http`] if the API request fails.
298pub async fn search_models(
299    query: &str,
300    limit: usize,
301    library: Option<&str>,
302    pipeline: Option<&str>,
303    tag: Option<&str>,
304) -> Result<Vec<SearchResult>, FetchError> {
305    let normalized = normalize_quantization_terms(query);
306    let client = reqwest::Client::new();
307
308    // BORROW: explicit .as_str() instead of Deref coercion
309    let mut query_params: Vec<(&str, &str)> = vec![
310        ("search", normalized.as_str()),
311        ("sort", "downloads"),
312        ("direction", "-1"),
313    ];
314    if let Some(lib) = library {
315        query_params.push(("library", lib));
316    }
317    if let Some(pipe) = pipeline {
318        query_params.push(("pipeline_tag", pipe));
319    }
320    if let Some(t) = tag {
321        query_params.push(("filter", t));
322    }
323
324    let response = client
325        .get(HF_API_BASE)
326        .query(&query_params)
327        .query(&[("limit", limit)])
328        .send()
329        .await
330        .map_err(|e| FetchError::Http(e.to_string()))?;
331
332    if !response.status().is_success() {
333        return Err(FetchError::Http(format!(
334            "HF API returned status {}",
335            response.status()
336        )));
337    }
338
339    let models: Vec<ApiModelEntry> = response
340        .json()
341        .await
342        .map_err(|e| FetchError::Http(e.to_string()))?;
343
344    // Client-side filtering: the HF search API may ignore library/pipeline_tag/filter
345    // query parameters when combined with the `search` parameter, so we filter
346    // the results ourselves to guarantee correctness.
347    let results = models
348        .into_iter()
349        .filter(|m| {
350            if let Some(lib) = library {
351                match m.library_name {
352                    // BORROW: explicit .as_str() instead of Deref coercion
353                    Some(ref name) if name.as_str().eq_ignore_ascii_case(lib) => {}
354                    _ => return false,
355                }
356            }
357            if let Some(pipe) = pipeline {
358                match m.pipeline_tag {
359                    // BORROW: explicit .as_str() instead of Deref coercion
360                    Some(ref t) if t.as_str().eq_ignore_ascii_case(pipe) => {}
361                    _ => return false,
362                }
363            }
364            if let Some(t) = tag {
365                if !m.tags.iter().any(|model_tag| {
366                    // BORROW: explicit .as_str() instead of Deref coercion
367                    model_tag.as_str().eq_ignore_ascii_case(t)
368                }) {
369                    return false;
370                }
371            }
372            true
373        })
374        .map(|m| SearchResult {
375            model_id: m.model_id,
376            downloads: m.downloads,
377            library_name: m.library_name,
378            pipeline_tag: m.pipeline_tag,
379            tags: m.tags,
380        })
381        .collect();
382
383    Ok(results)
384}
385
386/// Fetches model card metadata for a specific model from the `HuggingFace` Hub.
387///
388/// Queries `GET https://huggingface.co/api/models/{model_id}` and extracts
389/// license, pipeline tag, tags, library name, and languages from the response.
390///
391/// # Arguments
392///
393/// * `model_id` — The full model identifier (e.g., `"mistralai/Ministral-3-3B-Instruct-2512"`).
394///
395/// # Errors
396///
397/// Returns [`FetchError::Http`] if the API request fails or the model is not found.
398pub async fn fetch_model_card(model_id: &str) -> Result<ModelCardMetadata, FetchError> {
399    let client = reqwest::Client::new();
400    let url = format!("{HF_API_BASE}/{model_id}");
401
402    let response = client
403        .get(url.as_str()) // BORROW: explicit .as_str()
404        .send()
405        .await
406        .map_err(|e| FetchError::Http(e.to_string()))?;
407
408    if !response.status().is_success() {
409        return Err(FetchError::Http(format!(
410            "HF API returned status {} for model {model_id}",
411            response.status()
412        )));
413    }
414
415    let detail: ApiModelDetail = response
416        .json()
417        .await
418        .map_err(|e| FetchError::Http(e.to_string()))?;
419
420    let (license, languages) = if let Some(card) = detail.card_data {
421        let langs = match card.language {
422            Some(ApiLanguage::Single(s)) => vec![s],
423            Some(ApiLanguage::Multiple(v)) => v,
424            None => Vec::new(),
425        };
426        (card.license, langs)
427    } else {
428        (None, Vec::new())
429    };
430
431    let gated = match detail.gated {
432        ApiGated::Bool(false) => GateStatus::Open,
433        ApiGated::Mode(ref mode) if mode.eq_ignore_ascii_case("manual") => GateStatus::Manual,
434        ApiGated::Bool(true) | ApiGated::Mode(_) => GateStatus::Auto,
435    };
436
437    Ok(ModelCardMetadata {
438        license,
439        pipeline_tag: detail.pipeline_tag,
440        tags: detail.tags,
441        library_name: detail.library_name,
442        languages,
443        gated,
444    })
445}
446
447/// Fetches the raw README text for a `HuggingFace` model repository.
448///
449/// Downloads `README.md` from the repository at the given revision.
450/// Returns `Ok(None)` if the file does not exist (HTTP 404).
451///
452/// # Arguments
453///
454/// * `model_id` — The full model identifier (e.g., `"mistralai/Ministral-3-3B-Instruct-2512"`).
455/// * `revision` — Git revision to fetch (defaults to `"main"` when `None`).
456/// * `token` — Optional authentication token.
457///
458/// # Errors
459///
460/// Returns [`FetchError::Http`] if the request fails (other than 404).
461pub async fn fetch_readme(
462    model_id: &str,
463    revision: Option<&str>,
464    token: Option<&str>,
465) -> Result<Option<String>, FetchError> {
466    let rev = revision.unwrap_or("main");
467    let url = crate::chunked::build_download_url(model_id, rev, "README.md");
468    let client = crate::chunked::build_client(token)?;
469
470    let response = client
471        .get(url.as_str()) // BORROW: explicit .as_str() instead of Deref coercion
472        .send()
473        .await
474        .map_err(|e| FetchError::Http(format!("failed to fetch README for {model_id}: {e}")))?;
475
476    if response.status() == reqwest::StatusCode::NOT_FOUND {
477        return Ok(None);
478    }
479
480    if !response.status().is_success() {
481        return Err(FetchError::Http(format!(
482            "README request for {model_id} returned status {}",
483            response.status()
484        )));
485    }
486
487    let text = response
488        .text()
489        .await
490        .map_err(|e| FetchError::Http(format!("failed to read README for {model_id}: {e}")))?;
491
492    Ok(Some(text))
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every `(input, expected)` pair normalizes as expected.
    fn check(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(normalize_quantization_terms(input), expected);
        }
    }

    #[test]
    fn normalize_8bit_variants() {
        check(&[
            ("AWQ 8bit", "AWQ 8-bit"),
            ("AWQ 8-bit", "AWQ 8-bit"),
            ("AWQ int8", "AWQ 8-bit"),
            ("AWQ INT8", "AWQ 8-bit"),
        ]);
    }

    #[test]
    fn normalize_4bit_variants() {
        check(&[
            ("GPTQ 4bit", "GPTQ 4-bit"),
            ("GPTQ INT4", "GPTQ 4-bit"),
            ("GPTQ 4-bit", "GPTQ 4-bit"),
        ]);
    }

    #[test]
    fn normalize_fp8_variants() {
        check(&[("FP8", "fp8"), ("float8", "fp8"), ("fp8", "fp8")]);
    }

    #[test]
    fn normalize_passthrough() {
        // Tokens that are not quantization synonyms are left untouched.
        check(&[("llama 3", "llama 3"), ("RWKV-7", "RWKV-7")]);
    }
}
526}