// alimentar/hf_hub/download.rs
use std::path::{Path, PathBuf};
4
5use crate::{
6 backend::{HttpBackend, StorageBackend},
7 dataset::ArrowDataset,
8 error::{Error, Result},
9};
10
/// Base URL of the Hugging Face Hub, used for both file downloads and API calls.
const HF_HUB_URL: &str = "https://huggingface.co";
13
/// A reference to a dataset hosted on the Hugging Face Hub.
///
/// Construct via [`HfDataset::builder`]; download with [`HfDataset::download`]
/// (cached) or [`HfDataset::download_to`] (explicit output path).
#[derive(Debug, Clone)]
pub struct HfDataset {
    // Repository identifier, e.g. "user/dataset".
    repo_id: String,
    // Git revision (branch, tag, or commit hash); the builder defaults to "main".
    revision: String,
    // Optional subset (configuration) name; "default" is used when unset.
    subset: Option<String>,
    // Optional split name; "train" is used when unset.
    split: Option<String>,
    // Root directory for locally cached downloads.
    cache_dir: PathBuf,
}
31
32impl HfDataset {
33 pub fn builder(repo_id: impl Into<String>) -> HfDatasetBuilder {
39 HfDatasetBuilder::new(repo_id)
40 }
41
42 pub fn repo_id(&self) -> &str {
44 &self.repo_id
45 }
46
47 pub fn revision(&self) -> &str {
49 &self.revision
50 }
51
52 pub fn subset(&self) -> Option<&str> {
54 self.subset.as_deref()
55 }
56
57 pub fn split(&self) -> Option<&str> {
59 self.split.as_deref()
60 }
61
62 pub fn cache_dir(&self) -> &Path {
64 &self.cache_dir
65 }
66
67 pub fn download(&self) -> Result<ArrowDataset> {
81 let parquet_path = self.build_parquet_path();
83 let cache_file = self.cache_path_for(&parquet_path);
84
85 if cache_file.exists() {
87 return ArrowDataset::from_parquet(&cache_file);
88 }
89
90 let url = self.build_download_url(&parquet_path);
92 let http = HttpBackend::with_timeout(&url, 300)?;
93
94 let data = http.get("")?;
96
97 if let Some(parent) = cache_file.parent() {
99 std::fs::create_dir_all(parent).map_err(|e| Error::io(e, parent))?;
100 }
101
102 std::fs::write(&cache_file, &data).map_err(|e| Error::io(e, &cache_file))?;
104
105 ArrowDataset::from_parquet(&cache_file)
107 }
108
109 pub fn download_to(&self, output: impl AsRef<Path>) -> Result<ArrowDataset> {
119 let output = output.as_ref();
120 let parquet_path = self.build_parquet_path();
121 let url = self.build_download_url(&parquet_path);
122
123 let http = HttpBackend::with_timeout(&url, 300)?;
124 let data = http.get("")?;
125
126 if let Some(parent) = output.parent() {
128 std::fs::create_dir_all(parent).map_err(|e| Error::io(e, parent))?;
129 }
130
131 std::fs::write(output, &data).map_err(|e| Error::io(e, output))?;
133
134 ArrowDataset::from_parquet(output)
136 }
137
138 pub(crate) fn build_parquet_path(&self) -> String {
140 let mut path_parts = Vec::new();
141
142 if let Some(ref subset) = self.subset {
144 path_parts.push(subset.clone());
145 } else {
146 path_parts.push("default".to_string());
147 }
148
149 let split = self.split.as_deref().unwrap_or("train");
151 path_parts.push(format!("{split}.parquet"));
152
153 path_parts.join("/")
154 }
155
156 pub(crate) fn build_download_url(&self, parquet_path: &str) -> String {
158 format!(
159 "{}/datasets/{}/resolve/{}/data/{}",
160 HF_HUB_URL, self.repo_id, self.revision, parquet_path
161 )
162 }
163
164 pub(crate) fn cache_path_for(&self, parquet_path: &str) -> PathBuf {
166 self.cache_dir
167 .join("huggingface")
168 .join("datasets")
169 .join(&self.repo_id)
170 .join(&self.revision)
171 .join(parquet_path)
172 }
173
174 pub fn clear_cache(&self) -> Result<()> {
180 let cache_path = self
181 .cache_dir
182 .join("huggingface")
183 .join("datasets")
184 .join(&self.repo_id);
185
186 if cache_path.exists() {
187 std::fs::remove_dir_all(&cache_path).map_err(|e| Error::io(e, &cache_path))?;
188 }
189
190 Ok(())
191 }
192}
193
/// Builder for [`HfDataset`]; created via [`HfDataset::builder`] or
/// [`HfDatasetBuilder::new`], finalized with `build()`.
#[derive(Debug, Clone)]
pub struct HfDatasetBuilder {
    // Required repository identifier; validated non-empty in `build()`.
    repo_id: String,
    // Git revision; initialized to "main".
    revision: String,
    // Optional subset (configuration) name.
    subset: Option<String>,
    // Optional split name.
    split: Option<String>,
    // Optional cache directory; `build()` falls back to `default_cache_dir()`.
    cache_dir: Option<PathBuf>,
}
203
204impl HfDatasetBuilder {
205 pub fn new(repo_id: impl Into<String>) -> Self {
207 Self {
208 repo_id: repo_id.into(),
209 revision: "main".to_string(),
210 subset: None,
211 split: None,
212 cache_dir: None,
213 }
214 }
215
216 #[must_use]
220 pub fn revision(mut self, revision: impl Into<String>) -> Self {
221 self.revision = revision.into();
222 self
223 }
224
225 #[must_use]
230 pub fn subset(mut self, subset: impl Into<String>) -> Self {
231 self.subset = Some(subset.into());
232 self
233 }
234
235 #[must_use]
239 pub fn split(mut self, split: impl Into<String>) -> Self {
240 self.split = Some(split.into());
241 self
242 }
243
244 #[must_use]
249 pub fn cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
250 self.cache_dir = Some(path.into());
251 self
252 }
253
254 pub fn build(self) -> Result<HfDataset> {
260 if self.repo_id.is_empty() {
261 return Err(Error::invalid_config("Repository ID cannot be empty"));
262 }
263
264 let cache_dir = self.cache_dir.unwrap_or_else(default_cache_dir);
265
266 Ok(HfDataset {
267 repo_id: self.repo_id,
268 revision: self.revision,
269 subset: self.subset,
270 split: self.split,
271 cache_dir,
272 })
273 }
274}
275
/// Resolves the platform-appropriate cache directory for alimentar.
///
/// Resolution order:
/// - Windows: `%LOCALAPPDATA%\alimentar\cache`
/// - elsewhere: `$XDG_CACHE_HOME/alimentar`, then `$HOME/.cache/alimentar`
/// - fallback on every platform: `<system temp dir>/alimentar/cache`
pub(crate) fn default_cache_dir() -> PathBuf {
    #[cfg(target_os = "windows")]
    {
        if let Ok(local_app_data) = std::env::var("LOCALAPPDATA") {
            return PathBuf::from(local_app_data)
                .join("alimentar")
                .join("cache");
        }
    }

    #[cfg(not(target_os = "windows"))]
    {
        // First match wins: XDG cache root directly, or ~/.cache under HOME.
        for (var, nest_dot_cache) in [("XDG_CACHE_HOME", false), ("HOME", true)] {
            if let Ok(base) = std::env::var(var) {
                let root = PathBuf::from(base);
                return if nest_dot_cache {
                    root.join(".cache").join("alimentar")
                } else {
                    root.join("alimentar")
                };
            }
        }
    }

    std::env::temp_dir().join("alimentar").join("cache")
}
300
301pub fn list_dataset_files(repo_id: &str, revision: Option<&str>) -> Result<Vec<String>> {
319 let revision = revision.unwrap_or("main");
320 let url = format!("{}/api/datasets/{}/tree/{}", HF_HUB_URL, repo_id, revision);
321
322 let http = HttpBackend::with_timeout(&url, 30)?;
323 let data = http.get("")?;
324
325 let json: serde_json::Value = serde_json::from_slice(&data)
327 .map_err(|e| Error::storage(format!("Failed to parse HF Hub response: {e}")))?;
328
329 let mut parquet_files = Vec::new();
330
331 if let Some(items) = json.as_array() {
332 for item in items {
333 if let Some(path) = item.get("path").and_then(|p| p.as_str()) {
334 if path.ends_with(".parquet") {
335 parquet_files.push(path.to_string());
336 }
337 }
338 }
339 }
340
341 Ok(parquet_files)
342}
343
/// Summary metadata for a Hub dataset repository.
///
/// NOTE(review): this type is declared but not populated anywhere in this
/// file — the field notes below are inferred from names; confirm against
/// the code that constructs it.
#[derive(Debug, Clone)]
pub struct DatasetInfo {
    // Repository identifier, e.g. "user/dataset".
    pub repo_id: String,
    // Available split names (presumably e.g. "train"/"test" — verify).
    pub splits: Vec<String>,
    // Available subset (configuration) names.
    pub subsets: Vec<String>,
    // Total download size when reported; presumably bytes — verify.
    pub download_size: Option<u64>,
    // Free-form dataset description, when available.
    pub description: Option<String>,
}