1use polars::prelude::*;
2use reqwest::blocking::Client;
3use serde::Deserialize;
4use std::io::Cursor;
5
6
7use super::{DatasetInfo, DataSource, Record};
10use crate::{Error, Result};
11
12const HF_DATASETS_SERVER: &str = "https://datasets-server.huggingface.co";
13
/// Loads a Hugging Face dataset by downloading its auto-converted parquet
/// files through the datasets-server API.
pub struct HuggingFaceSource {
    // Blocking HTTP client reused for every request this source makes.
    client: Client,
    // Dataset id, e.g. "cornell-movie-review-data/rotten_tomatoes".
    dataset: String,
    // Config ("subset") name; `None` falls back to "default" downstream.
    subset: Option<String>,
    // Split to load, e.g. "train".
    split: String,
    // Optional column projection applied after download in `load`.
    columns: Option<Vec<String>>,
    // Metadata filled in by `fetch_info` during construction.
    info: DatasetInfo,
    // Parquet shard URLs resolved by `fetch_parquet_urls` during construction.
    parquet_urls: Vec<String>,
}
23
/// A named split of a dataset together with its row count.
#[derive(Debug, Clone)]
pub struct Split {
    // Split name as reported by the datasets-server, e.g. "train".
    pub name: String,
    // Number of examples in the split (`num_examples` in the API response).
    pub num_rows: usize,
}
29
/// Top-level shape of the datasets-server `/info` response.
#[derive(Deserialize)]
struct InfoResponse {
    // Map from config ("subset") name to that config's metadata.
    dataset_info: std::collections::HashMap<String, ConfigInfo>,
}
34
/// Metadata for a single dataset config from the `/info` response.
#[derive(Deserialize)]
struct ConfigInfo {
    // Free-text dataset description; absent for many datasets.
    description: Option<String>,
    // Feature (column) name -> feature spec; only the keys are used here.
    features: std::collections::HashMap<String, serde_json::Value>,
    // Split name -> split metadata.
    splits: std::collections::HashMap<String, SplitInfo>,
}
41
/// Per-split metadata from the `/info` response.
#[derive(Deserialize)]
struct SplitInfo {
    // Split name, e.g. "train" / "validation" / "test".
    name: String,
    // Row count for the split.
    num_examples: usize,
}
47
/// Top-level shape of the datasets-server `/parquet` response.
#[derive(Deserialize)]
struct ParquetResponse {
    // One entry per parquet shard across all configs and splits.
    parquet_files: Vec<ParquetFile>,
}
52
/// A single parquet shard entry from the `/parquet` response.
#[derive(Deserialize)]
struct ParquetFile {
    // Config the shard belongs to; used to filter to the requested subset.
    config: String,
    // Split the shard belongs to; used to filter to the requested split.
    split: String,
    // Direct download URL of the shard.
    url: String,
}
59
60impl HuggingFaceSource {
61 pub fn new(
62 dataset: String,
63 subset: Option<String>,
64 split: String,
65 columns: Option<Vec<String>>,
66 ) -> Result<Self> {
67 let client = Client::new();
68 let mut source = Self {
69 client,
70 dataset,
71 subset,
72 split,
73 columns,
74 info: DatasetInfo::default(),
75 parquet_urls: Vec::new(),
76 };
77 source.fetch_info()?;
78 source.fetch_parquet_urls()?;
79 Ok(source)
80 }
81
82 fn fetch_info(&mut self) -> Result<()> {
83 let url = format!("{}/info?dataset={}", HF_DATASETS_SERVER, self.dataset);
84 let response: InfoResponse = self
85 .client
86 .get(&url)
87 .send()
88 .map_err(|e| Error::Dataset(format!("Failed to fetch dataset info: {}", e)))?
89 .json()
90 .map_err(|e| Error::Dataset(format!("Failed to parse dataset info: {}", e)))?;
91
92 let config_name = self.subset.as_deref().unwrap_or("default");
93 let config = response.dataset_info.get(config_name).ok_or_else(|| {
94 Error::Dataset(format!("Config '{}' not found in dataset", config_name))
95 })?;
96
97 let splits: Vec<Split> = config
98 .splits
99 .values()
100 .map(|s| Split {
101 name: s.name.clone(),
102 num_rows: s.num_examples,
103 })
104 .collect();
105
106 let split_info = splits
107 .iter()
108 .find(|s| s.name == self.split)
109 .ok_or_else(|| Error::Dataset(format!("Split '{}' not found", self.split)))?;
110
111 self.info = DatasetInfo {
112 name: self.dataset.clone(),
113 description: config.description.clone(),
114 num_rows: split_info.num_rows,
115 columns: config.features.keys().cloned().collect(),
116 splits,
117 };
118
119 Ok(())
120 }
121
122 fn fetch_parquet_urls(&mut self) -> Result<()> {
123 let url = format!("{}/parquet?dataset={}", HF_DATASETS_SERVER, self.dataset);
124 let response: ParquetResponse = self
125 .client
126 .get(&url)
127 .send()
128 .map_err(|e| Error::Dataset(format!("Failed to fetch parquet URLs: {}", e)))?
129 .json()
130 .map_err(|e| Error::Dataset(format!("Failed to parse parquet response: {}", e)))?;
131
132 let config_name = self.subset.as_deref().unwrap_or("default");
133 self.parquet_urls = response
134 .parquet_files
135 .into_iter()
136 .filter(|f| f.config == config_name && f.split == self.split)
137 .map(|f| f.url)
138 .collect();
139
140 if self.parquet_urls.is_empty() {
141 return Err(Error::Dataset(format!(
142 "No parquet files found for config '{}' split '{}'",
143 config_name, self.split
144 )));
145 }
146
147 Ok(())
148 }
149
150 fn download_and_read_parquet(&self) -> Result<DataFrame> {
151 let mut dfs = Vec::new();
152
153 for url in &self.parquet_urls {
154 let bytes = self
155 .client
156 .get(url)
157 .send()
158 .map_err(|e| Error::Dataset(format!("Failed to download parquet: {}", e)))?
159 .bytes()
160 .map_err(|e| Error::Dataset(format!("Failed to read parquet bytes: {}", e)))?;
161
162 let cursor = Cursor::new(bytes.to_vec());
163 let df = ParquetReader::new(cursor)
164 .finish()
165 .map_err(|e| Error::Dataset(format!("Failed to parse parquet: {}", e)))?;
166 dfs.push(df);
167 }
168
169 if dfs.len() == 1 {
170 Ok(dfs.remove(0))
171 } else {
172 let lazy_frames: Vec<_> = dfs.into_iter().map(|df| df.lazy()).collect();
173 concat(&lazy_frames, UnionArgs::default())
174 .map_err(|e| Error::Dataset(format!("Failed to concat dataframes: {}", e)))?
175 .collect()
176 .map_err(|e| Error::Dataset(format!("Failed to collect dataframe: {}", e)))
177 }
178 }
179}
180
181impl DataSource for HuggingFaceSource {
182 fn info(&self) -> &DatasetInfo {
183 &self.info
184 }
185
186 fn load(&mut self, sample: Option<usize>) -> Result<Vec<Record>> {
187 let mut df = self.download_and_read_parquet()?;
188
189 if let Some(cols) = &self.columns {
190 let col_exprs: Vec<_> = cols.iter().map(col).collect();
191 df = df
192 .lazy()
193 .select(col_exprs)
194 .collect()
195 .map_err(|e| Error::Dataset(format!("Failed to select columns: {}", e)))?;
196 }
197
198 if let Some(n) = sample {
199 df = df.head(Some(n));
200 }
201
202 let mut records = Vec::with_capacity(df.height());
203 for i in 0..df.height() {
204 let row = df.get(i).ok_or_else(|| Error::Dataset("Row not found".into()))?;
205 let mut map = serde_json::Map::new();
206
207 for (col_name, value) in df.get_column_names().iter().zip(row.iter()) {
208 let json_value = anyvalue_to_json(value);
209 map.insert(col_name.to_string(), json_value);
210 }
211
212 records.push(Record {
213 data: serde_json::Value::Object(map),
214 index: i,
215 });
216 }
217
218 Ok(records)
219 }
220}
221
222fn anyvalue_to_json(value: &AnyValue) -> serde_json::Value {
223 match value {
224 AnyValue::Null => serde_json::Value::Null,
225 AnyValue::Boolean(b) => serde_json::Value::Bool(*b),
226 AnyValue::String(s) => serde_json::Value::String(s.to_string()),
227 AnyValue::StringOwned(s) => serde_json::Value::String(s.to_string()),
228 AnyValue::Float32(n) => serde_json::Number::from_f64(*n as f64)
229 .map(serde_json::Value::Number)
230 .unwrap_or(serde_json::Value::Null),
231 AnyValue::Float64(n) => serde_json::Number::from_f64(*n)
232 .map(serde_json::Value::Number)
233 .unwrap_or(serde_json::Value::Null),
234 other => serde_json::Value::String(format!("{}", other)),
235 }
236}
237
238pub async fn search_datasets(query: &str, limit: usize) -> Result<Vec<DatasetSearchResult>> {
239 let client = reqwest::Client::new();
240 let url = format!(
241 "https://huggingface.co/api/datasets?search={}&limit={}",
242 query, limit
243 );
244
245 let results: Vec<DatasetSearchResult> = client
246 .get(&url)
247 .send()
248 .await
249 .map_err(|e| Error::Dataset(format!("Search failed: {}", e)))?
250 .json()
251 .await
252 .map_err(|e| Error::Dataset(format!("Failed to parse search results: {}", e)))?;
253
254 Ok(results)
255}
256
/// One entry from the Hub dataset search API.
#[derive(Debug, Deserialize)]
pub struct DatasetSearchResult {
    // Dataset id, e.g. "cornell-movie-review-data/rotten_tomatoes".
    pub id: String,
    // Like count; defaults to 0 when absent from the response.
    #[serde(default)]
    pub likes: u32,
    // Download count; defaults to 0 when absent from the response.
    #[serde(default)]
    pub downloads: u64,
}
265
266pub async fn get_dataset_info(dataset: &str) -> Result<DatasetInfo> {
267 let client = reqwest::Client::new();
268 let url = format!("{}/info?dataset={}", HF_DATASETS_SERVER, dataset);
269
270 let response: InfoResponse = client
271 .get(&url)
272 .send()
273 .await
274 .map_err(|e| Error::Dataset(format!("Failed to fetch info: {}", e)))?
275 .json()
276 .await
277 .map_err(|e| Error::Dataset(format!("Failed to parse info: {}", e)))?;
278
279 let (_config_name, config) = response
280 .dataset_info
281 .into_iter()
282 .next()
283 .ok_or_else(|| Error::Dataset("No config found".into()))?;
284
285 let splits: Vec<Split> = config
286 .splits
287 .values()
288 .map(|s| Split {
289 name: s.name.clone(),
290 num_rows: s.num_examples,
291 })
292 .collect();
293
294 let total_rows: usize = splits.iter().map(|s| s.num_rows).sum();
295
296 Ok(DatasetInfo {
297 name: dataset.to_string(),
298 description: config.description,
299 num_rows: total_rows,
300 columns: config.features.keys().cloned().collect(),
301 splits,
302 })
303}
304
305pub async fn preview_dataset(
306 dataset: &str,
307 config: Option<&str>,
308 split: &str,
309 rows: usize,
310) -> Result<Vec<serde_json::Value>> {
311 let client = reqwest::Client::new();
312 let config = config.unwrap_or("default");
313 let url = format!(
314 "{}/rows?dataset={}&config={}&split={}&offset=0&length={}",
315 HF_DATASETS_SERVER, dataset, config, split, rows.min(100)
316 );
317
318 #[derive(Deserialize)]
319 struct RowsResponse {
320 rows: Vec<RowEntry>,
321 }
322
323 #[derive(Deserialize)]
324 struct RowEntry {
325 row: serde_json::Value,
326 }
327
328 let response: RowsResponse = client
329 .get(&url)
330 .send()
331 .await
332 .map_err(|e| Error::Dataset(format!("Failed to fetch rows: {}", e)))?
333 .json()
334 .await
335 .map_err(|e| Error::Dataset(format!("Failed to parse rows: {}", e)))?;
336
337 Ok(response.rows.into_iter().map(|r| r.row).collect())
338}
339
// NOTE(review): every test below performs live HTTP requests against
// huggingface.co / datasets-server.huggingface.co, so they require network
// access and will fail offline or if the referenced dataset changes.
#[cfg(test)]
mod tests {
    use super::*;

    // Search should return at least one hit for a common term, bounded by
    // the requested limit.
    #[tokio::test]
    async fn test_search_datasets() {
        let results = search_datasets("sentiment", 5).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
    }

    // Info for a well-known dataset should expose its two columns.
    #[tokio::test]
    async fn test_get_dataset_info() {
        let info = get_dataset_info("cornell-movie-review-data/rotten_tomatoes")
            .await
            .unwrap();
        assert_eq!(info.name, "cornell-movie-review-data/rotten_tomatoes");
        assert!(info.columns.contains(&"text".to_string()));
        assert!(info.columns.contains(&"label".to_string()));
    }

    // Previewing 3 rows should return exactly 3 JSON objects with the
    // expected "text" field.
    #[tokio::test]
    async fn test_preview_dataset() {
        let rows = preview_dataset(
            "cornell-movie-review-data/rotten_tomatoes",
            None,
            "train",
            3,
        )
        .await
        .unwrap();
        assert_eq!(rows.len(), 3);
        assert!(rows[0].get("text").is_some());
    }

    // End-to-end: construct the blocking source (downloads parquet) and load
    // a 10-row sample. Slow — depends on shard download time.
    #[test]
    fn test_huggingface_source_load() {
        let mut source = HuggingFaceSource::new(
            "cornell-movie-review-data/rotten_tomatoes".to_string(),
            None,
            "train".to_string(),
            None,
        )
        .unwrap();

        let records = source.load(Some(10)).unwrap();
        assert_eq!(records.len(), 10);
        assert!(records[0].data.get("text").is_some());
    }
}