// alimentar/hf_hub/download.rs
1//! HuggingFace Hub dataset download functionality.
2
3use std::path::{Path, PathBuf};
4
5use crate::{
6    backend::{HttpBackend, StorageBackend},
7    dataset::ArrowDataset,
8    error::{Error, Result},
9};
10
/// Base URL for HuggingFace Hub; both dataset file downloads (`/datasets/...`)
/// and API queries (`/api/datasets/...`) are built on top of this host.
const HF_HUB_URL: &str = "https://huggingface.co";
13
/// HuggingFace Hub dataset configuration and loader.
///
/// This struct provides a builder pattern for configuring and downloading
/// datasets from the HuggingFace Hub. Construct it via [`HfDataset::builder`];
/// all fields are private and read through accessor methods.
#[derive(Debug, Clone)]
pub struct HfDataset {
    /// Dataset repository name (e.g., "squad", "glue", "openai/gsm8k")
    repo_id: String,
    /// Git revision (branch, tag, or commit hash); the builder defaults to "main"
    revision: String,
    /// Dataset subset/config (optional; "default" is assumed when unset)
    subset: Option<String>,
    /// Data split (train, validation, test); "train" is assumed when unset
    split: Option<String>,
    /// Local cache directory root under which downloads are stored
    cache_dir: PathBuf,
}
31
32impl HfDataset {
33    /// Creates a new builder for a HuggingFace dataset.
34    ///
35    /// # Arguments
36    ///
37    /// * `repo_id` - The dataset repository ID (e.g., "squad", "openai/gsm8k")
38    pub fn builder(repo_id: impl Into<String>) -> HfDatasetBuilder {
39        HfDatasetBuilder::new(repo_id)
40    }
41
42    /// Returns the repository ID.
43    pub fn repo_id(&self) -> &str {
44        &self.repo_id
45    }
46
47    /// Returns the revision being used.
48    pub fn revision(&self) -> &str {
49        &self.revision
50    }
51
52    /// Returns the subset/config if set.
53    pub fn subset(&self) -> Option<&str> {
54        self.subset.as_deref()
55    }
56
57    /// Returns the split if set.
58    pub fn split(&self) -> Option<&str> {
59        self.split.as_deref()
60    }
61
62    /// Returns the cache directory.
63    pub fn cache_dir(&self) -> &Path {
64        &self.cache_dir
65    }
66
67    /// Downloads the dataset and returns an ArrowDataset.
68    ///
69    /// This method:
70    /// 1. Checks the local cache for existing data
71    /// 2. Downloads parquet files from HuggingFace Hub if not cached
72    /// 3. Loads the parquet files into an ArrowDataset
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if:
77    /// - The dataset cannot be found on HuggingFace Hub
78    /// - The download fails
79    /// - The parquet files cannot be parsed
80    pub fn download(&self) -> Result<ArrowDataset> {
81        // Build the URL path for the parquet file
82        let parquet_path = self.build_parquet_path();
83        let cache_file = self.cache_path_for(&parquet_path);
84
85        // Check cache first
86        if cache_file.exists() {
87            return ArrowDataset::from_parquet(&cache_file);
88        }
89
90        // Download from HF Hub
91        let url = self.build_download_url(&parquet_path);
92        let http = HttpBackend::with_timeout(&url, 300)?;
93
94        // The key is empty since we've already built the full URL
95        let data = http.get("")?;
96
97        // Ensure cache directory exists
98        if let Some(parent) = cache_file.parent() {
99            std::fs::create_dir_all(parent).map_err(|e| Error::io(e, parent))?;
100        }
101
102        // Write to cache
103        std::fs::write(&cache_file, &data).map_err(|e| Error::io(e, &cache_file))?;
104
105        // Load from cache
106        ArrowDataset::from_parquet(&cache_file)
107    }
108
109    /// Downloads the dataset to a specific output path.
110    ///
111    /// # Arguments
112    ///
113    /// * `output` - Path where the dataset should be saved
114    ///
115    /// # Errors
116    ///
117    /// Returns an error if the download or save fails.
118    pub fn download_to(&self, output: impl AsRef<Path>) -> Result<ArrowDataset> {
119        let output = output.as_ref();
120        let parquet_path = self.build_parquet_path();
121        let url = self.build_download_url(&parquet_path);
122
123        let http = HttpBackend::with_timeout(&url, 300)?;
124        let data = http.get("")?;
125
126        // Ensure parent directory exists
127        if let Some(parent) = output.parent() {
128            std::fs::create_dir_all(parent).map_err(|e| Error::io(e, parent))?;
129        }
130
131        // Write to output
132        std::fs::write(output, &data).map_err(|e| Error::io(e, output))?;
133
134        // Load and return
135        ArrowDataset::from_parquet(output)
136    }
137
138    /// Builds the parquet file path within the repository.
139    pub(crate) fn build_parquet_path(&self) -> String {
140        let mut path_parts = Vec::new();
141
142        // Add subset/config if present
143        if let Some(ref subset) = self.subset {
144            path_parts.push(subset.clone());
145        } else {
146            path_parts.push("default".to_string());
147        }
148
149        // Add split
150        let split = self.split.as_deref().unwrap_or("train");
151        path_parts.push(format!("{split}.parquet"));
152
153        path_parts.join("/")
154    }
155
156    /// Builds the download URL for a parquet file.
157    pub(crate) fn build_download_url(&self, parquet_path: &str) -> String {
158        format!(
159            "{}/datasets/{}/resolve/{}/data/{}",
160            HF_HUB_URL, self.repo_id, self.revision, parquet_path
161        )
162    }
163
164    /// Returns the cache path for a given parquet path.
165    pub(crate) fn cache_path_for(&self, parquet_path: &str) -> PathBuf {
166        self.cache_dir
167            .join("huggingface")
168            .join("datasets")
169            .join(&self.repo_id)
170            .join(&self.revision)
171            .join(parquet_path)
172    }
173
174    /// Clears the local cache for this dataset.
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if the cache cannot be deleted.
179    pub fn clear_cache(&self) -> Result<()> {
180        let cache_path = self
181            .cache_dir
182            .join("huggingface")
183            .join("datasets")
184            .join(&self.repo_id);
185
186        if cache_path.exists() {
187            std::fs::remove_dir_all(&cache_path).map_err(|e| Error::io(e, &cache_path))?;
188        }
189
190        Ok(())
191    }
192}
193
/// Builder for configuring HuggingFace dataset downloads.
///
/// Created via [`HfDatasetBuilder::new`] or [`HfDataset::builder`]; finish
/// with `build()` to obtain an [`HfDataset`].
#[derive(Debug, Clone)]
pub struct HfDatasetBuilder {
    // Required; validated as non-empty in `build()`.
    repo_id: String,
    // Defaults to "main" in `new()`.
    revision: String,
    // Optional subset/config name.
    subset: Option<String>,
    // Optional split name.
    split: Option<String>,
    // Falls back to `default_cache_dir()` when left unset.
    cache_dir: Option<PathBuf>,
}
203
204impl HfDatasetBuilder {
205    /// Creates a new builder with the given repository ID.
206    pub fn new(repo_id: impl Into<String>) -> Self {
207        Self {
208            repo_id: repo_id.into(),
209            revision: "main".to_string(),
210            subset: None,
211            split: None,
212            cache_dir: None,
213        }
214    }
215
216    /// Sets the Git revision (branch, tag, or commit hash).
217    ///
218    /// Default is "main".
219    #[must_use]
220    pub fn revision(mut self, revision: impl Into<String>) -> Self {
221        self.revision = revision.into();
222        self
223    }
224
225    /// Sets the dataset subset/configuration.
226    ///
227    /// Some datasets have multiple configurations (e.g., "glue" has "cola",
228    /// "sst2", etc.)
229    #[must_use]
230    pub fn subset(mut self, subset: impl Into<String>) -> Self {
231        self.subset = Some(subset.into());
232        self
233    }
234
235    /// Sets the data split to download.
236    ///
237    /// Common values: "train", "validation", "test"
238    #[must_use]
239    pub fn split(mut self, split: impl Into<String>) -> Self {
240        self.split = Some(split.into());
241        self
242    }
243
244    /// Sets the local cache directory.
245    ///
246    /// Default is `~/.cache/alimentar` on Unix or `%LOCALAPPDATA%/alimentar` on
247    /// Windows.
248    #[must_use]
249    pub fn cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
250        self.cache_dir = Some(path.into());
251        self
252    }
253
254    /// Builds the HfDataset configuration.
255    ///
256    /// # Errors
257    ///
258    /// Returns an error if required fields are missing or invalid.
259    pub fn build(self) -> Result<HfDataset> {
260        if self.repo_id.is_empty() {
261            return Err(Error::invalid_config("Repository ID cannot be empty"));
262        }
263
264        let cache_dir = self.cache_dir.unwrap_or_else(default_cache_dir);
265
266        Ok(HfDataset {
267            repo_id: self.repo_id,
268            revision: self.revision,
269            subset: self.subset,
270            split: self.split,
271            cache_dir,
272        })
273    }
274}
275
/// Returns the default cache directory for the current platform.
///
/// Resolution order:
/// - Windows: `%LOCALAPPDATA%\alimentar\cache`
/// - Other platforms: `$XDG_CACHE_HOME/alimentar`, then `$HOME/.cache/alimentar`
/// - Fallback: `<system temp dir>/alimentar/cache`
///
/// Environment variables that are set but empty are treated as unset, per
/// the XDG Base Directory specification; previously an empty
/// `XDG_CACHE_HOME` produced the bogus relative path `alimentar`.
pub(crate) fn default_cache_dir() -> PathBuf {
    // Reads an environment variable, treating an empty value as unset.
    fn env_non_empty(key: &str) -> Option<String> {
        std::env::var(key).ok().filter(|v| !v.is_empty())
    }

    #[cfg(target_os = "windows")]
    {
        if let Some(local_app_data) = env_non_empty("LOCALAPPDATA") {
            return PathBuf::from(local_app_data)
                .join("alimentar")
                .join("cache");
        }
    }

    #[cfg(not(target_os = "windows"))]
    {
        if let Some(xdg_cache) = env_non_empty("XDG_CACHE_HOME") {
            return PathBuf::from(xdg_cache).join("alimentar");
        }
        if let Some(home) = env_non_empty("HOME") {
            return PathBuf::from(home).join(".cache").join("alimentar");
        }
    }

    // System scratch directory as last-resort cache location
    std::env::temp_dir().join("alimentar").join("cache")
}
300
301/// Lists available parquet files for a dataset on HuggingFace Hub.
302///
303/// This function queries the HuggingFace Hub API to list available
304/// parquet files for a given dataset.
305///
306/// # Arguments
307///
308/// * `repo_id` - The dataset repository ID
309/// * `revision` - Git revision (default "main")
310///
311/// # Errors
312///
313/// Returns an error if the HTTP request fails or the response cannot be parsed.
314///
315/// # Note
316///
317/// This function requires the HF Hub API and may be rate-limited.
318pub fn list_dataset_files(repo_id: &str, revision: Option<&str>) -> Result<Vec<String>> {
319    let revision = revision.unwrap_or("main");
320    let url = format!("{}/api/datasets/{}/tree/{}", HF_HUB_URL, repo_id, revision);
321
322    let http = HttpBackend::with_timeout(&url, 30)?;
323    let data = http.get("")?;
324
325    // Parse JSON response
326    let json: serde_json::Value = serde_json::from_slice(&data)
327        .map_err(|e| Error::storage(format!("Failed to parse HF Hub response: {e}")))?;
328
329    let mut parquet_files = Vec::new();
330
331    if let Some(items) = json.as_array() {
332        for item in items {
333            if let Some(path) = item.get("path").and_then(|p| p.as_str()) {
334                if path.ends_with(".parquet") {
335                    parquet_files.push(path.to_string());
336                }
337            }
338        }
339    }
340
341    Ok(parquet_files)
342}
343
/// Information about a HuggingFace dataset.
///
/// Plain metadata container; all fields are public and optional fields are
/// `None` when the Hub does not report them.
#[derive(Debug, Clone)]
pub struct DatasetInfo {
    /// Dataset repository ID
    pub repo_id: String,
    /// Available splits (e.g. "train", "validation", "test")
    pub splits: Vec<String>,
    /// Available subsets/configs
    pub subsets: Vec<String>,
    /// Total download size in bytes (if known)
    pub download_size: Option<u64>,
    /// Description
    pub description: Option<String>,
}