Skip to main content

entrenar/hf_pipeline/leaderboard/
client.rs

1//! HuggingFace leaderboard HTTP client
2//!
3//! Fetches leaderboard data from the HuggingFace datasets-server JSON API.
4//! Uses the rows endpoint to avoid Parquet parsing.
5
6use super::types::{HfLeaderboard, LeaderboardEntry, LeaderboardKind};
7use crate::hf_pipeline::error::FetchError;
8use crate::hf_pipeline::HfModelFetcher;
9
10/// HTTP client for fetching HuggingFace leaderboard data
11pub struct LeaderboardClient {
12    token: Option<String>,
13    client: reqwest::blocking::Client,
14}
15
16impl LeaderboardClient {
17    /// Create a new leaderboard client with automatic token resolution
18    pub fn new() -> Result<Self, FetchError> {
19        let token = HfModelFetcher::resolve_token();
20        let client =
21            reqwest::blocking::Client::builder().user_agent("entrenar/0.5").build().map_err(
22                |e| FetchError::HttpError { message: format!("Failed to create HTTP client: {e}") },
23            )?;
24
25        Ok(Self { token, client })
26    }
27
28    /// Create a client with an explicit token
29    pub fn with_token(token: impl Into<String>) -> Result<Self, FetchError> {
30        let client =
31            reqwest::blocking::Client::builder().user_agent("entrenar/0.5").build().map_err(
32                |e| FetchError::HttpError { message: format!("Failed to create HTTP client: {e}") },
33            )?;
34
35        Ok(Self { token: Some(token.into()), client })
36    }
37
38    /// Fetch leaderboard data (first page)
39    pub fn fetch(&self, kind: LeaderboardKind) -> Result<HfLeaderboard, FetchError> {
40        self.fetch_paginated(kind, 0, 100)
41    }
42
43    /// Fetch leaderboard data with pagination
44    pub fn fetch_paginated(
45        &self,
46        kind: LeaderboardKind,
47        offset: usize,
48        limit: usize,
49    ) -> Result<HfLeaderboard, FetchError> {
50        let repo_id = kind.dataset_repo_id();
51        let url = format!(
52            "https://datasets-server.huggingface.co/rows?dataset={repo_id}&config=default&split=train&offset={offset}&length={limit}"
53        );
54
55        let mut request = self.client.get(&url);
56        if let Some(token) = &self.token {
57            request = request.bearer_auth(token);
58        }
59
60        let response = request.send().map_err(|e| FetchError::HttpError {
61            message: format!("Leaderboard request failed: {e}"),
62        })?;
63
64        if !response.status().is_success() {
65            let status = response.status();
66            if status.as_u16() == 404 {
67                return Err(FetchError::LeaderboardNotFound { kind: kind.to_string() });
68            }
69            return Err(FetchError::HttpError {
70                message: format!("Leaderboard API returned {status} for {repo_id}"),
71            });
72        }
73
74        let body: serde_json::Value = response.json().map_err(|e| FetchError::HttpError {
75            message: format!("Failed to parse leaderboard JSON: {e}"),
76        })?;
77        parse_response(kind, &body)
78    }
79
80    /// Find a specific model in a leaderboard
81    pub fn find_model(
82        &self,
83        kind: LeaderboardKind,
84        model_repo_id: &str,
85    ) -> Result<Option<LeaderboardEntry>, FetchError> {
86        // Fetch the full leaderboard and search locally
87        // (HF datasets-server doesn't support server-side filtering by row content)
88        let leaderboard = self.fetch(kind)?;
89        Ok(leaderboard.find_model(model_repo_id).cloned())
90    }
91}
92
93/// Parse HF datasets-server JSON response into our types
94fn parse_response(
95    kind: LeaderboardKind,
96    body: &serde_json::Value,
97) -> Result<HfLeaderboard, FetchError> {
98    let mut leaderboard = HfLeaderboard::new(kind);
99
100    // Extract total count from "num_rows_total"
101    leaderboard.total_count =
102        body.get("num_rows_total").and_then(serde_json::Value::as_u64).unwrap_or(0) as usize;
103
104    // Extract rows
105    let rows = body.get("rows").and_then(|v| v.as_array()).ok_or_else(|| {
106        FetchError::LeaderboardParseError {
107            message: "Missing 'rows' array in response".to_string(),
108        }
109    })?;
110
111    for row in rows {
112        let row_data = row.get("row").unwrap_or(row);
113
114        // Try to extract model ID from common column names
115        let model_id = row_data
116            .get("model")
117            .or_else(|| row_data.get("model_id"))
118            .or_else(|| row_data.get("model_name"))
119            .and_then(|v| v.as_str())
120            .unwrap_or("unknown")
121            .to_string();
122
123        let mut entry = LeaderboardEntry::new(model_id);
124
125        // Extract all numeric values as scores
126        if let Some(obj) = row_data.as_object() {
127            for (key, value) in obj {
128                if let Some(num) = value.as_f64() {
129                    entry.scores.insert(key.clone(), num);
130                } else if let Some(s) = value.as_str() {
131                    entry.metadata.insert(key.clone(), s.to_string());
132                }
133            }
134        }
135
136        leaderboard.entries.push(entry);
137    }
138
139    Ok(leaderboard)
140}
141
142impl std::fmt::Debug for LeaderboardClient {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        f.debug_struct("LeaderboardClient")
145            .field("has_token", &self.token.is_some())
146            .finish_non_exhaustive()
147    }
148}