//! Hugging Face datasets-server backed data source.
//!
//! File: synth_claw/datasets/hf.rs
1use polars::prelude::*;
2use reqwest::blocking::Client;
3use serde::Deserialize;
4use std::io::Cursor;
5
6
7// TODO: create skill file for the cli/lib so that agents can use it.
8
9use super::{DatasetInfo, DataSource, Record};
10use crate::{Error, Result};
11
12const HF_DATASETS_SERVER: &str = "https://datasets-server.huggingface.co";
13
14pub struct HuggingFaceSource {
15    client: Client,
16    dataset: String,
17    subset: Option<String>,
18    split: String,
19    columns: Option<Vec<String>>,
20    info: DatasetInfo,
21    parquet_urls: Vec<String>,
22}
23
24#[derive(Debug, Clone)]
25pub struct Split {
26    pub name: String,
27    pub num_rows: usize,
28}
29
30#[derive(Deserialize)]
31struct InfoResponse {
32    dataset_info: std::collections::HashMap<String, ConfigInfo>,
33}
34
35#[derive(Deserialize)]
36struct ConfigInfo {
37    description: Option<String>,
38    features: std::collections::HashMap<String, serde_json::Value>,
39    splits: std::collections::HashMap<String, SplitInfo>,
40}
41
42#[derive(Deserialize)]
43struct SplitInfo {
44    name: String,
45    num_examples: usize,
46}
47
48#[derive(Deserialize)]
49struct ParquetResponse {
50    parquet_files: Vec<ParquetFile>,
51}
52
53#[derive(Deserialize)]
54struct ParquetFile {
55    config: String,
56    split: String,
57    url: String,
58}
59
60impl HuggingFaceSource {
61    pub fn new(
62        dataset: String,
63        subset: Option<String>,
64        split: String,
65        columns: Option<Vec<String>>,
66    ) -> Result<Self> {
67        let client = Client::new();
68        let mut source = Self {
69            client,
70            dataset,
71            subset,
72            split,
73            columns,
74            info: DatasetInfo::default(),
75            parquet_urls: Vec::new(),
76        };
77        source.fetch_info()?;
78        source.fetch_parquet_urls()?;
79        Ok(source)
80    }
81
82    fn fetch_info(&mut self) -> Result<()> {
83        let url = format!("{}/info?dataset={}", HF_DATASETS_SERVER, self.dataset);
84        let response: InfoResponse = self
85            .client
86            .get(&url)
87            .send()
88            .map_err(|e| Error::Dataset(format!("Failed to fetch dataset info: {}", e)))?
89            .json()
90            .map_err(|e| Error::Dataset(format!("Failed to parse dataset info: {}", e)))?;
91
92        let config_name = self.subset.as_deref().unwrap_or("default");
93        let config = response.dataset_info.get(config_name).ok_or_else(|| {
94            Error::Dataset(format!("Config '{}' not found in dataset", config_name))
95        })?;
96
97        let splits: Vec<Split> = config
98            .splits
99            .values()
100            .map(|s| Split {
101                name: s.name.clone(),
102                num_rows: s.num_examples,
103            })
104            .collect();
105
106        let split_info = splits
107            .iter()
108            .find(|s| s.name == self.split)
109            .ok_or_else(|| Error::Dataset(format!("Split '{}' not found", self.split)))?;
110
111        self.info = DatasetInfo {
112            name: self.dataset.clone(),
113            description: config.description.clone(),
114            num_rows: split_info.num_rows,
115            columns: config.features.keys().cloned().collect(),
116            splits,
117        };
118
119        Ok(())
120    }
121
122    fn fetch_parquet_urls(&mut self) -> Result<()> {
123        let url = format!("{}/parquet?dataset={}", HF_DATASETS_SERVER, self.dataset);
124        let response: ParquetResponse = self
125            .client
126            .get(&url)
127            .send()
128            .map_err(|e| Error::Dataset(format!("Failed to fetch parquet URLs: {}", e)))?
129            .json()
130            .map_err(|e| Error::Dataset(format!("Failed to parse parquet response: {}", e)))?;
131
132        let config_name = self.subset.as_deref().unwrap_or("default");
133        self.parquet_urls = response
134            .parquet_files
135            .into_iter()
136            .filter(|f| f.config == config_name && f.split == self.split)
137            .map(|f| f.url)
138            .collect();
139
140        if self.parquet_urls.is_empty() {
141            return Err(Error::Dataset(format!(
142                "No parquet files found for config '{}' split '{}'",
143                config_name, self.split
144            )));
145        }
146
147        Ok(())
148    }
149
150    fn download_and_read_parquet(&self) -> Result<DataFrame> {
151        let mut dfs = Vec::new();
152
153        for url in &self.parquet_urls {
154            let bytes = self
155                .client
156                .get(url)
157                .send()
158                .map_err(|e| Error::Dataset(format!("Failed to download parquet: {}", e)))?
159                .bytes()
160                .map_err(|e| Error::Dataset(format!("Failed to read parquet bytes: {}", e)))?;
161
162            let cursor = Cursor::new(bytes.to_vec());
163            let df = ParquetReader::new(cursor)
164                .finish()
165                .map_err(|e| Error::Dataset(format!("Failed to parse parquet: {}", e)))?;
166            dfs.push(df);
167        }
168
169        if dfs.len() == 1 {
170            Ok(dfs.remove(0))
171        } else {
172            let lazy_frames: Vec<_> = dfs.into_iter().map(|df| df.lazy()).collect();
173            concat(&lazy_frames, UnionArgs::default())
174                .map_err(|e| Error::Dataset(format!("Failed to concat dataframes: {}", e)))?
175                .collect()
176                .map_err(|e| Error::Dataset(format!("Failed to collect dataframe: {}", e)))
177        }
178    }
179}
180
181impl DataSource for HuggingFaceSource {
182    fn info(&self) -> &DatasetInfo {
183        &self.info
184    }
185
186    fn load(&mut self, sample: Option<usize>) -> Result<Vec<Record>> {
187        let mut df = self.download_and_read_parquet()?;
188
189        if let Some(cols) = &self.columns {
190            let col_exprs: Vec<_> = cols.iter().map(col).collect();
191            df = df
192                .lazy()
193                .select(col_exprs)
194                .collect()
195                .map_err(|e| Error::Dataset(format!("Failed to select columns: {}", e)))?;
196        }
197
198        if let Some(n) = sample {
199            df = df.head(Some(n));
200        }
201
202        let mut records = Vec::with_capacity(df.height());
203        for i in 0..df.height() {
204            let row = df.get(i).ok_or_else(|| Error::Dataset("Row not found".into()))?;
205            let mut map = serde_json::Map::new();
206
207            for (col_name, value) in df.get_column_names().iter().zip(row.iter()) {
208                let json_value = anyvalue_to_json(value);
209                map.insert(col_name.to_string(), json_value);
210            }
211
212            records.push(Record {
213                data: serde_json::Value::Object(map),
214                index: i,
215            });
216        }
217
218        Ok(records)
219    }
220}
221
222fn anyvalue_to_json(value: &AnyValue) -> serde_json::Value {
223    match value {
224        AnyValue::Null => serde_json::Value::Null,
225        AnyValue::Boolean(b) => serde_json::Value::Bool(*b),
226        AnyValue::String(s) => serde_json::Value::String(s.to_string()),
227        AnyValue::StringOwned(s) => serde_json::Value::String(s.to_string()),
228        AnyValue::Float32(n) => serde_json::Number::from_f64(*n as f64)
229            .map(serde_json::Value::Number)
230            .unwrap_or(serde_json::Value::Null),
231        AnyValue::Float64(n) => serde_json::Number::from_f64(*n)
232            .map(serde_json::Value::Number)
233            .unwrap_or(serde_json::Value::Null),
234        other => serde_json::Value::String(format!("{}", other)),
235    }
236}
237
238pub async fn search_datasets(query: &str, limit: usize) -> Result<Vec<DatasetSearchResult>> {
239    let client = reqwest::Client::new();
240    let url = format!(
241        "https://huggingface.co/api/datasets?search={}&limit={}",
242        query, limit
243    );
244
245    let results: Vec<DatasetSearchResult> = client
246        .get(&url)
247        .send()
248        .await
249        .map_err(|e| Error::Dataset(format!("Search failed: {}", e)))?
250        .json()
251        .await
252        .map_err(|e| Error::Dataset(format!("Failed to parse search results: {}", e)))?;
253
254    Ok(results)
255}
256
257#[derive(Debug, Deserialize)]
258pub struct DatasetSearchResult {
259    pub id: String,
260    #[serde(default)]
261    pub likes: u32,
262    #[serde(default)]
263    pub downloads: u64,
264}
265
266pub async fn get_dataset_info(dataset: &str) -> Result<DatasetInfo> {
267    let client = reqwest::Client::new();
268    let url = format!("{}/info?dataset={}", HF_DATASETS_SERVER, dataset);
269
270    let response: InfoResponse = client
271        .get(&url)
272        .send()
273        .await
274        .map_err(|e| Error::Dataset(format!("Failed to fetch info: {}", e)))?
275        .json()
276        .await
277        .map_err(|e| Error::Dataset(format!("Failed to parse info: {}", e)))?;
278
279    let (_config_name, config) = response
280        .dataset_info
281        .into_iter()
282        .next()
283        .ok_or_else(|| Error::Dataset("No config found".into()))?;
284
285    let splits: Vec<Split> = config
286        .splits
287        .values()
288        .map(|s| Split {
289            name: s.name.clone(),
290            num_rows: s.num_examples,
291        })
292        .collect();
293
294    let total_rows: usize = splits.iter().map(|s| s.num_rows).sum();
295
296    Ok(DatasetInfo {
297        name: dataset.to_string(),
298        description: config.description,
299        num_rows: total_rows,
300        columns: config.features.keys().cloned().collect(),
301        splits,
302    })
303}
304
305pub async fn preview_dataset(
306    dataset: &str,
307    config: Option<&str>,
308    split: &str,
309    rows: usize,
310) -> Result<Vec<serde_json::Value>> {
311    let client = reqwest::Client::new();
312    let config = config.unwrap_or("default");
313    let url = format!(
314        "{}/rows?dataset={}&config={}&split={}&offset=0&length={}",
315        HF_DATASETS_SERVER, dataset, config, split, rows.min(100)
316    );
317
318    #[derive(Deserialize)]
319    struct RowsResponse {
320        rows: Vec<RowEntry>,
321    }
322
323    #[derive(Deserialize)]
324    struct RowEntry {
325        row: serde_json::Value,
326    }
327
328    let response: RowsResponse = client
329        .get(&url)
330        .send()
331        .await
332        .map_err(|e| Error::Dataset(format!("Failed to fetch rows: {}", e)))?
333        .json()
334        .await
335        .map_err(|e| Error::Dataset(format!("Failed to parse rows: {}", e)))?;
336
337    Ok(response.rows.into_iter().map(|r| r.row).collect())
338}
339
340#[cfg(test)]
341mod tests {
342    use super::*;
343
344    #[tokio::test]
345    async fn test_search_datasets() {
346        let results = search_datasets("sentiment", 5).await.unwrap();
347        assert!(!results.is_empty());
348        assert!(results.len() <= 5);
349    }
350
351    #[tokio::test]
352    async fn test_get_dataset_info() {
353        let info = get_dataset_info("cornell-movie-review-data/rotten_tomatoes")
354            .await
355            .unwrap();
356        assert_eq!(info.name, "cornell-movie-review-data/rotten_tomatoes");
357        assert!(info.columns.contains(&"text".to_string()));
358        assert!(info.columns.contains(&"label".to_string()));
359    }
360
361    #[tokio::test]
362    async fn test_preview_dataset() {
363        let rows = preview_dataset(
364            "cornell-movie-review-data/rotten_tomatoes",
365            None,
366            "train",
367            3,
368        )
369        .await
370        .unwrap();
371        assert_eq!(rows.len(), 3);
372        assert!(rows[0].get("text").is_some());
373    }
374
375    #[test]
376    fn test_huggingface_source_load() {
377        let mut source = HuggingFaceSource::new(
378            "cornell-movie-review-data/rotten_tomatoes".to_string(),
379            None,
380            "train".to_string(),
381            None,
382        )
383        .unwrap();
384
385        let records = source.load(Some(10)).unwrap();
386        assert_eq!(records.len(), 10);
387        assert!(records[0].data.get("text").is_some());
388    }
389}