burn_dataset/source/huggingface/
downloader.rs1use std::fs::{self, create_dir_all};
2use std::path::{Path, PathBuf};
3use std::process::Command;
4
5use crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage};
6
7use sanitize_filename::sanitize;
8use serde::de::DeserializeOwned;
9use thiserror::Error;
10
/// Source of the Python importer script, embedded at compile time; it is written into the
/// base directory and executed to perform the actual dataset download.
const PYTHON_SOURCE: &str = include_str!("importer.py");
/// Location of the Python interpreter inside a virtual environment, relative to the venv
/// root (POSIX layout).
#[cfg(not(target_os = "windows"))]
const VENV_BIN_PYTHON: &str = "bin/python3";
/// Location of the Python interpreter inside a virtual environment, relative to the venv
/// root (Windows layout).
#[cfg(target_os = "windows")]
const VENV_BIN_PYTHON: &str = "Scripts\\python";
16
/// Errors that can occur while importing a Hugging Face dataset into SQLite.
#[derive(Error, Debug)]
pub enum ImporterError {
    /// Catch-all failure, e.g. the importer subprocess could not be waited on or exited
    /// with a non-success status.
    #[error("unknown: `{0}`")]
    Unknown(String),

    /// Preparing the Python environment (creating the venv, bootstrapping or running pip)
    /// failed.
    #[error("fail to download python dependencies: `{0}`")]
    FailToDownloadPythonDependencies(String),

    /// An error from the SQLite dataset layer; converted automatically via `From`, so `?`
    /// can propagate it.
    #[error("sqlite dataset: `{0}`")]
    SqliteDataset(#[from] SqliteDatasetError),

    /// None of the probed interpreter names reported a Python 3 version.
    #[error("python3 is not installed")]
    PythonNotInstalled,

    /// The virtual environment's interpreter still did not report Python 3 after an attempt
    /// to create the venv.
    #[error("venv environment is not initialized")]
    VenvNotInitialized,
}
40
/// Builder that downloads a dataset from the Hugging Face hub through an embedded Python
/// importer script and exposes it as a [`SqliteDataset`].
pub struct HuggingfaceDatasetLoader {
    /// Dataset name on the Hugging Face hub; forwarded to the importer as `--name`.
    name: String,
    /// Optional subset (configuration); forwarded as `--subset` and appended to the
    /// database file name.
    subset: Option<String>,
    /// Directory for the database, importer script and venv. It is resolved through
    /// `SqliteDatasetStorage::base_dir`, which presumably picks a default when `None`.
    base_dir: Option<PathBuf>,
    /// Optional access token; forwarded as `--token`.
    huggingface_token: Option<String>,
    /// Optional Hugging Face cache directory; forwarded as `--cache_dir`.
    huggingface_cache_dir: Option<String>,
    /// Optional `data_dir` for the dataset; forwarded as `--data_dir`.
    huggingface_data_dir: Option<String>,
    /// When `true`, the importer receives `--trust_remote_code True`.
    trust_remote_code: bool,
    /// When `true`, dependencies are installed into a venv under the base directory;
    /// otherwise a system Python 3 interpreter is used directly.
    use_python_venv: bool,
}
75
76impl HuggingfaceDatasetLoader {
77 pub fn new(name: &str) -> Self {
79 Self {
80 name: name.to_string(),
81 subset: None,
82 base_dir: None,
83 huggingface_token: None,
84 huggingface_cache_dir: None,
85 huggingface_data_dir: None,
86 trust_remote_code: false,
87 use_python_venv: true,
88 }
89 }
90
91 pub fn with_subset(mut self, subset: &str) -> Self {
97 self.subset = Some(subset.to_string());
98 self
99 }
100
101 pub fn with_base_dir(mut self, base_dir: &str) -> Self {
105 self.base_dir = Some(base_dir.into());
106 self
107 }
108
109 pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
113 self.huggingface_token = Some(huggingface_token.to_string());
114 self
115 }
116
117 pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
121 self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
122 self
123 }
124
125 pub fn with_huggingface_data_dir(mut self, huggingface_data_dir: &str) -> Self {
131 self.huggingface_data_dir = Some(huggingface_data_dir.to_string());
132 self
133 }
134
135 pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {
139 self.trust_remote_code = trust_remote_code;
140 self
141 }
142
143 pub fn with_use_python_venv(mut self, use_python_venv: bool) -> Self {
149 self.use_python_venv = use_python_venv;
150 self
151 }
152
153 pub fn dataset<I: DeserializeOwned + Clone>(
155 self,
156 split: &str,
157 ) -> Result<SqliteDataset<I>, ImporterError> {
158 let db_file = self.db_file()?;
159 let dataset = SqliteDataset::from_db_file(db_file, split)?;
160 Ok(dataset)
161 }
162
163 pub fn db_file(self) -> Result<PathBuf, ImporterError> {
167 let base_dir = SqliteDatasetStorage::base_dir(self.base_dir);
169
170 if !base_dir.exists() {
171 create_dir_all(&base_dir).expect("Failed to create base directory");
172 }
173
174 let name = sanitize(self.name.as_str());
176
177 let db_file_name = if let Some(subset) = self.subset.clone() {
179 format!("{name}-{}.db", sanitize(subset.as_str()))
180 } else {
181 format!("{name}.db")
182 };
183
184 let db_file = base_dir.join(db_file_name);
185
186 if !Path::new(&db_file).exists() {
188 import(
189 self.name,
190 self.subset,
191 db_file.clone(),
192 base_dir,
193 self.huggingface_token,
194 self.huggingface_cache_dir,
195 self.huggingface_data_dir,
196 self.trust_remote_code,
197 self.use_python_venv,
198 )?;
199 }
200
201 Ok(db_file)
202 }
203}
204
205#[allow(clippy::too_many_arguments)]
207fn import(
208 name: String,
209 subset: Option<String>,
210 base_file: PathBuf,
211 base_dir: PathBuf,
212 huggingface_token: Option<String>,
213 huggingface_cache_dir: Option<String>,
214 huggingface_data_dir: Option<String>,
215 trust_remote_code: bool,
216 use_python_venv: bool,
217) -> Result<(), ImporterError> {
218 let python_path = if use_python_venv {
219 install_python_deps(&base_dir)?
220 } else {
221 get_python_name()?.into()
222 };
223
224 let mut command = Command::new(python_path);
225
226 command.arg(importer_script_path(&base_dir));
227
228 command.arg("--name");
229 command.arg(name);
230
231 command.arg("--file");
232 command.arg(base_file);
233
234 if let Some(subset) = subset {
235 command.arg("--subset");
236 command.arg(subset);
237 }
238
239 if let Some(huggingface_token) = huggingface_token {
240 command.arg("--token");
241 command.arg(huggingface_token);
242 }
243
244 if let Some(huggingface_cache_dir) = huggingface_cache_dir {
245 command.arg("--cache_dir");
246 command.arg(huggingface_cache_dir);
247 }
248 if let Some(huggingface_data_dir) = huggingface_data_dir {
249 command.arg("--data_dir");
250 command.arg(huggingface_data_dir);
251 }
252 if trust_remote_code {
253 command.arg("--trust_remote_code");
254 command.arg("True");
255 }
256 let mut handle = command.spawn().unwrap();
257
258 let exit_status = handle
259 .wait()
260 .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;
261
262 if !exit_status.success() {
263 return Err(ImporterError::Unknown(format!("{exit_status}")));
264 }
265
266 Ok(())
267}
268
/// Returns `true` when `python --version` runs successfully and reports a 3.x interpreter.
///
/// `python` is a program name (looked up by [`Command`]) or a path to an interpreter. A
/// launch failure, a non-zero exit status or an unrecognised report all yield `false`.
fn check_python_version_is_3(python: &str) -> bool {
    match Command::new(python).arg("--version").output() {
        Ok(output) if output.status.success() => {
            // Python 3.4+ prints the version to stdout; older interpreters print it to
            // stderr, so fall back to stderr when stdout carries nothing.
            let report = if output.stdout.iter().all(u8::is_ascii_whitespace) {
                &output.stderr
            } else {
                &output.stdout
            };
            reports_python_3(&String::from_utf8_lossy(report))
        }
        _ => false,
    }
}

/// Returns `true` if a `--version` report such as `"Python 3.11.4"` names a 3.x release,
/// i.e. the text after the first space starts with `"3."`.
fn reports_python_3(report: &str) -> bool {
    matches!(report.split_once(' '), Some((_, version)) if version.starts_with("3."))
}
289
290fn get_python_name() -> Result<&'static str, ImporterError> {
292 let python_name_list = ["python3", "python", "py"];
293 for python_name in python_name_list.iter() {
294 if check_python_version_is_3(python_name) {
295 return Ok(python_name);
296 }
297 }
298 Err(ImporterError::PythonNotInstalled)
299}
300
301fn importer_script_path(base_dir: &Path) -> PathBuf {
302 let path_file = base_dir.join("importer.py");
303
304 fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader");
305 path_file
306}
307
308fn install_python_deps(base_dir: &Path) -> Result<PathBuf, ImporterError> {
309 let venv_dir = base_dir.join("venv");
310 let venv_python_path = venv_dir.join(VENV_BIN_PYTHON);
311 if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
313 let python_name = get_python_name()?;
314 let mut command = Command::new(python_name);
315 command.args([
316 "-m",
317 "venv",
318 venv_dir
319 .as_os_str()
320 .to_str()
321 .expect("Path utf8 conversion should not fail"),
322 ]);
323
324 let mut handle = command.spawn().unwrap();
326
327 handle.wait().map_err(|err| {
328 ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}"))
329 })?;
330 if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
332 return Err(ImporterError::VenvNotInitialized);
333 }
334 }
335
336 let mut ensurepip_cmd = Command::new(&venv_python_path);
337 ensurepip_cmd.args(["-m", "ensurepip", "--upgrade"]);
338 let status = ensurepip_cmd.status().map_err(|err| {
339 ImporterError::FailToDownloadPythonDependencies(format!("failed to run ensurepip: {err}"))
340 })?;
341 if !status.success() {
342 return Err(ImporterError::FailToDownloadPythonDependencies(
343 "ensurepip failed to initialize pip".to_string(),
344 ));
345 }
346
347 let mut command = Command::new(&venv_python_path);
348 command.args([
349 "-m",
350 "pip",
351 "--quiet",
352 "install",
353 "pyarrow",
354 "sqlalchemy",
355 "Pillow",
356 "soundfile",
357 "datasets",
358 ]);
359
360 let mut handle = command.spawn().unwrap();
362 handle
363 .wait()
364 .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")))?;
365
366 Ok(venv_python_path)
367}