entrenar/hf_pipeline/leaderboard/
client.rs1use super::types::{HfLeaderboard, LeaderboardEntry, LeaderboardKind};
7use crate::hf_pipeline::error::FetchError;
8use crate::hf_pipeline::HfModelFetcher;
9
10pub struct LeaderboardClient {
12 token: Option<String>,
13 client: reqwest::blocking::Client,
14}
15
16impl LeaderboardClient {
17 pub fn new() -> Result<Self, FetchError> {
19 let token = HfModelFetcher::resolve_token();
20 let client =
21 reqwest::blocking::Client::builder().user_agent("entrenar/0.5").build().map_err(
22 |e| FetchError::HttpError { message: format!("Failed to create HTTP client: {e}") },
23 )?;
24
25 Ok(Self { token, client })
26 }
27
28 pub fn with_token(token: impl Into<String>) -> Result<Self, FetchError> {
30 let client =
31 reqwest::blocking::Client::builder().user_agent("entrenar/0.5").build().map_err(
32 |e| FetchError::HttpError { message: format!("Failed to create HTTP client: {e}") },
33 )?;
34
35 Ok(Self { token: Some(token.into()), client })
36 }
37
38 pub fn fetch(&self, kind: LeaderboardKind) -> Result<HfLeaderboard, FetchError> {
40 self.fetch_paginated(kind, 0, 100)
41 }
42
43 pub fn fetch_paginated(
45 &self,
46 kind: LeaderboardKind,
47 offset: usize,
48 limit: usize,
49 ) -> Result<HfLeaderboard, FetchError> {
50 let repo_id = kind.dataset_repo_id();
51 let url = format!(
52 "https://datasets-server.huggingface.co/rows?dataset={repo_id}&config=default&split=train&offset={offset}&length={limit}"
53 );
54
55 let mut request = self.client.get(&url);
56 if let Some(token) = &self.token {
57 request = request.bearer_auth(token);
58 }
59
60 let response = request.send().map_err(|e| FetchError::HttpError {
61 message: format!("Leaderboard request failed: {e}"),
62 })?;
63
64 if !response.status().is_success() {
65 let status = response.status();
66 if status.as_u16() == 404 {
67 return Err(FetchError::LeaderboardNotFound { kind: kind.to_string() });
68 }
69 return Err(FetchError::HttpError {
70 message: format!("Leaderboard API returned {status} for {repo_id}"),
71 });
72 }
73
74 let body: serde_json::Value = response.json().map_err(|e| FetchError::HttpError {
75 message: format!("Failed to parse leaderboard JSON: {e}"),
76 })?;
77 parse_response(kind, &body)
78 }
79
80 pub fn find_model(
82 &self,
83 kind: LeaderboardKind,
84 model_repo_id: &str,
85 ) -> Result<Option<LeaderboardEntry>, FetchError> {
86 let leaderboard = self.fetch(kind)?;
89 Ok(leaderboard.find_model(model_repo_id).cloned())
90 }
91}
92
93fn parse_response(
95 kind: LeaderboardKind,
96 body: &serde_json::Value,
97) -> Result<HfLeaderboard, FetchError> {
98 let mut leaderboard = HfLeaderboard::new(kind);
99
100 leaderboard.total_count =
102 body.get("num_rows_total").and_then(serde_json::Value::as_u64).unwrap_or(0) as usize;
103
104 let rows = body.get("rows").and_then(|v| v.as_array()).ok_or_else(|| {
106 FetchError::LeaderboardParseError {
107 message: "Missing 'rows' array in response".to_string(),
108 }
109 })?;
110
111 for row in rows {
112 let row_data = row.get("row").unwrap_or(row);
113
114 let model_id = row_data
116 .get("model")
117 .or_else(|| row_data.get("model_id"))
118 .or_else(|| row_data.get("model_name"))
119 .and_then(|v| v.as_str())
120 .unwrap_or("unknown")
121 .to_string();
122
123 let mut entry = LeaderboardEntry::new(model_id);
124
125 if let Some(obj) = row_data.as_object() {
127 for (key, value) in obj {
128 if let Some(num) = value.as_f64() {
129 entry.scores.insert(key.clone(), num);
130 } else if let Some(s) = value.as_str() {
131 entry.metadata.insert(key.clone(), s.to_string());
132 }
133 }
134 }
135
136 leaderboard.entries.push(entry);
137 }
138
139 Ok(leaderboard)
140}
141
142impl std::fmt::Debug for LeaderboardClient {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 f.debug_struct("LeaderboardClient")
145 .field("has_token", &self.token.is_some())
146 .finish_non_exhaustive()
147 }
148}