hf-fetch-model 0.7.3

Fast HuggingFace model downloads for Rust — an embeddable library for downloading HuggingFace models with maximum throughput
Documentation
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Model family discovery and search via the `HuggingFace` Hub API.
//!
//! Queries the HF Hub for popular models, extracts `model_type` metadata,
//! compares against locally cached families, and fetches model card metadata.

use std::collections::BTreeMap;
use std::hash::BuildHasher;

use serde::Deserialize;

use crate::error::FetchError;

/// A single hit returned by a free-text search of the `HuggingFace` Hub.
///
/// Values are produced by [`search_models`]. The type implements
/// [`PartialEq`] / [`Eq`] so callers can compare or de-duplicate hits and
/// use `assert_eq!` on them in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SearchResult {
    /// The repository identifier (e.g., `"RWKV/RWKV7-Goose-World3-1.5B-HF"`).
    pub model_id: String,
    /// Total download count as reported by the Hub.
    pub downloads: u64,
}

/// A model family discovered from the `HuggingFace` Hub.
///
/// Values are produced by [`discover_new_families`], one per distinct
/// `model_type`. The type implements [`PartialEq`] / [`Eq`] so callers can
/// compare discovery results and use `assert_eq!` on them in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiscoveredFamily {
    /// The `model_type` identifier (e.g., `"gpt_neox"`, `"llama"`).
    pub model_type: String,
    /// Repository identifier of the representative model for this family.
    pub top_model: String,
    /// Download count of the representative model.
    pub downloads: u64,
}

/// JSON response structure for an individual model from the HF API.
///
/// This is the element type of the arrays decoded in
/// `discover_new_families` and `search_models`.
#[derive(Debug, Deserialize)]
struct ApiModelEntry {
    /// Repository identifier, read from the JSON key `modelId`.
    #[serde(rename = "modelId")]
    model_id: String,
    /// Download count; falls back to `0` when the key is absent.
    #[serde(default)]
    downloads: u64,
    /// Embedded `config` object; `None` when the key is absent.
    #[serde(default)]
    config: Option<ApiConfig>,
}

/// The `config` object embedded in a model API response.
///
/// Only `model_type` is read from it in this module.
#[derive(Debug, Deserialize)]
struct ApiConfig {
    /// Architecture family identifier (e.g., `"llama"`), if the config declares one.
    model_type: Option<String>,
}

/// How access to a model on the `HuggingFace` Hub is controlled.
///
/// Gated models make users accept license terms before they can download.
/// The variant records whether that acceptance is approved automatically
/// or has to be granted by the model author.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum GateStatus {
    /// Ungated: anyone can download without restrictions.
    Open,
    /// Gated, with approval granted automatically once terms are accepted on the Hub.
    Auto,
    /// Gated, with approval granted manually by the model author on request.
    Manual,
}

impl GateStatus {
    /// Reports whether terms must be accepted before the model can be downloaded.
    ///
    /// Returns `false` for [`GateStatus::Open`] and `true` for the gated variants.
    #[must_use]
    pub const fn is_gated(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide which side it is on.
        match self {
            Self::Open => false,
            Self::Auto | Self::Manual => true,
        }
    }
}

impl std::fmt::Display for GateStatus {
    /// Writes the lowercase gate mode name: `open`, `auto`, or `manual`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Open => "open",
            Self::Auto => "auto",
            Self::Manual => "manual",
        };
        f.write_str(label)
    }
}

/// Metadata from a `HuggingFace` model card.
///
/// Extracted from the single-model API endpoint
/// (`GET /api/models/{owner}/{model}`) by `fetch_model_card`. Model cards
/// may omit any of these values: missing scalar values are `None`, missing
/// lists are empty, and a missing gate flag is reported as
/// [`GateStatus::Open`].
#[derive(Debug, Clone)]
pub struct ModelCardMetadata {
    /// SPDX license identifier (e.g., `"apache-2.0"`), taken from the card data.
    pub license: Option<String>,
    /// Pipeline tag (e.g., `"text-generation"`).
    pub pipeline_tag: Option<String>,
    /// Tags associated with the model (e.g., `["pytorch", "safetensors"]`).
    pub tags: Vec<String>,
    /// Library name (e.g., `"transformers"`, `"vllm"`).
    pub library_name: Option<String>,
    /// Languages the model supports (e.g., `["en", "fr"]`); a single
    /// language in the card data becomes a one-element list.
    pub languages: Vec<String>,
    /// Access control status (open, auto-gated, or manually gated).
    pub gated: GateStatus,
}

/// JSON response for a single model from `GET /api/models/{model_id}`.
///
/// Every field carries `#[serde(default)]`, so a key missing from the
/// response yields the field's default instead of a decode error.
#[derive(Debug, Deserialize)]
struct ApiModelDetail {
    /// Pipeline tag, if present.
    #[serde(default)]
    pipeline_tag: Option<String>,
    /// Model tags; empty when the key is absent.
    #[serde(default)]
    tags: Vec<String>,
    /// Library name, if present.
    #[serde(default)]
    library_name: Option<String>,
    /// Raw gate flag; defaults to `ApiGated::Bool(false)` when absent.
    #[serde(default)]
    gated: ApiGated,
    /// Card data, read from the JSON key `cardData`.
    #[serde(default, rename = "cardData")]
    card_data: Option<ApiCardData>,
}

/// The `cardData` sub-object (parsed YAML front matter from the model README).
///
/// Only the license and language entries are read from it.
#[derive(Debug, Deserialize)]
struct ApiCardData {
    /// License identifier, if the card declares one.
    #[serde(default)]
    license: Option<String>,
    /// Declared language(s), as either a single string or a list.
    #[serde(default)]
    language: Option<ApiLanguage>,
}

/// Languages in `cardData` can be a single string or a list of strings.
///
/// Decoded as an untagged enum, so the variant is chosen by the shape of
/// the JSON value rather than by a tag.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiLanguage {
    /// A single language given as a plain string.
    Single(String),
    /// Several languages given as an array of strings.
    Multiple(Vec<String>),
}

/// The `gated` field can be `false` (boolean) or a string like `"auto"` / `"manual"`.
///
/// Decoded as an untagged enum, so the variant is chosen by the shape of
/// the JSON value rather than by a tag.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiGated {
    /// The field was a JSON boolean.
    Bool(bool),
    /// The field was a JSON string naming the gating mode.
    Mode(String),
}

impl Default for ApiGated {
    /// Treats a missing `gated` key as "not gated" (`Bool(false)`).
    fn default() -> Self {
        Self::Bool(false)
    }
}

/// Upper bound on the `limit` requested per page while scanning models.
const PAGE_SIZE: usize = 100;
/// URL of the Hub models endpoint; used for listing and searching, and as
/// the prefix of the single-model URL.
const HF_API_BASE: &str = "https://huggingface.co/api/models";

/// Scans the most-downloaded models on the `HuggingFace` Hub and reports
/// the model families that are not yet in the local cache.
///
/// Models are requested page by page, sorted by downloads in descending
/// order. For each distinct `model_type`, the first model encountered is
/// kept as the family's representative. The result is ordered by
/// `model_type`.
///
/// # Arguments
///
/// * `local_families` — Set of `model_type` values already cached locally.
/// * `max_models` — Maximum number of models to scan (paginated in batches of 100).
///
/// # Errors
///
/// Returns [`FetchError::Http`] if a request cannot be sent, the API answers
/// with a non-success status, or a response body cannot be decoded.
pub async fn discover_new_families<S: BuildHasher>(
    local_families: &std::collections::HashSet<String, S>,
    max_models: usize,
) -> Result<Vec<DiscoveredFamily>, FetchError> {
    let http = reqwest::Client::new();
    // Keyed by `model_type`; the value is the (model id, downloads) pair of
    // the first model seen for that family.
    let mut first_seen: BTreeMap<String, (String, u64)> = BTreeMap::new();
    let mut offset: usize = 0;

    while offset < max_models {
        // Never ask for more than what is left of the scan budget.
        let remaining = max_models.saturating_sub(offset);
        let limit = remaining.min(PAGE_SIZE);
        let url = format!(
            "{HF_API_BASE}?config=true&sort=downloads&direction=-1&limit={limit}&offset={offset}"
        );

        // BORROW: explicit .as_str()
        let response = http
            .get(url.as_str())
            .send()
            .await
            .map_err(|err| FetchError::Http(err.to_string()))?;

        if !response.status().is_success() {
            return Err(FetchError::Http(format!(
                "HF API returned status {}",
                response.status()
            )));
        }

        let page: Vec<ApiModelEntry> = response
            .json()
            .await
            .map_err(|err| FetchError::Http(err.to_string()))?;

        // An empty page means there is nothing further to scan.
        if page.is_empty() {
            break;
        }

        // Entries without a config or without a `model_type` are skipped.
        // BORROW: explicit .as_ref() and .as_deref() for Option<String>
        let typed_entries = page.iter().filter_map(|entry| {
            let family = entry.config.as_ref()?.model_type.as_deref()?;
            Some((family, entry))
        });
        for (family, entry) in typed_entries {
            first_seen
                .entry(family.to_owned())
                .or_insert_with(|| (entry.model_id.clone(), entry.downloads));
        }

        offset = offset.saturating_add(page.len());
    }

    // Keep only the families that are missing from the local cache.
    let mut discovered = Vec::new();
    for (model_type, (top_model, downloads)) in first_seen {
        // BORROW: explicit .as_str() instead of Deref coercion
        if local_families.contains(model_type.as_str()) {
            continue;
        }
        discovered.push(DiscoveredFamily {
            model_type,
            top_model,
            downloads,
        });
    }

    Ok(discovered)
}

/// Searches the `HuggingFace` Hub for models matching a query string.
///
/// Results are sorted by download count (most popular first).
///
/// # Arguments
///
/// * `query` — Free-text search string (e.g., `"RWKV-7"`, `"llama 3"`).
/// * `limit` — Maximum number of results to return.
///
/// # Errors
///
/// Returns [`FetchError::Http`] if the API request fails.
pub async fn search_models(query: &str, limit: usize) -> Result<Vec<SearchResult>, FetchError> {
    let client = reqwest::Client::new();

    let response = client
        .get(HF_API_BASE)
        .query(&[
            ("search", query),
            ("sort", "downloads"),
            ("direction", "-1"),
        ])
        .query(&[("limit", limit)])
        .send()
        .await
        .map_err(|e| FetchError::Http(e.to_string()))?;

    if !response.status().is_success() {
        return Err(FetchError::Http(format!(
            "HF API returned status {}",
            response.status()
        )));
    }

    let models: Vec<ApiModelEntry> = response
        .json()
        .await
        .map_err(|e| FetchError::Http(e.to_string()))?;

    let results = models
        .into_iter()
        .map(|m| SearchResult {
            model_id: m.model_id,
            downloads: m.downloads,
        })
        .collect();

    Ok(results)
}

/// Fetches model card metadata for a specific model from the `HuggingFace` Hub.
///
/// Queries `GET https://huggingface.co/api/models/{model_id}` and extracts
/// license, pipeline tag, tags, library name, and languages from the response.
///
/// # Arguments
///
/// * `model_id` — The full model identifier (e.g., `"mistralai/Ministral-3-3B-Instruct-2512"`).
///
/// # Errors
///
/// Returns [`FetchError::Http`] if the API request fails or the model is not found.
pub async fn fetch_model_card(model_id: &str) -> Result<ModelCardMetadata, FetchError> {
    let client = reqwest::Client::new();
    let url = format!("{HF_API_BASE}/{model_id}");

    let response = client
        .get(url.as_str()) // BORROW: explicit .as_str()
        .send()
        .await
        .map_err(|e| FetchError::Http(e.to_string()))?;

    if !response.status().is_success() {
        return Err(FetchError::Http(format!(
            "HF API returned status {} for model {model_id}",
            response.status()
        )));
    }

    let detail: ApiModelDetail = response
        .json()
        .await
        .map_err(|e| FetchError::Http(e.to_string()))?;

    let (license, languages) = if let Some(card) = detail.card_data {
        let langs = match card.language {
            Some(ApiLanguage::Single(s)) => vec![s],
            Some(ApiLanguage::Multiple(v)) => v,
            None => Vec::new(),
        };
        (card.license, langs)
    } else {
        (None, Vec::new())
    };

    let gated = match detail.gated {
        ApiGated::Bool(false) => GateStatus::Open,
        ApiGated::Mode(ref mode) if mode.eq_ignore_ascii_case("manual") => GateStatus::Manual,
        ApiGated::Bool(true) | ApiGated::Mode(_) => GateStatus::Auto,
    };

    Ok(ModelCardMetadata {
        license,
        pipeline_tag: detail.pipeline_tag,
        tags: detail.tags,
        library_name: detail.library_name,
        languages,
        gated,
    })
}