burn_dataset/source/huggingface/
downloader.rs1use std::fs::{self, create_dir_all};
2use std::path::{Path, PathBuf};
3use std::process::Command;
4
5use crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage};
6
7use sanitize_filename::sanitize;
8use serde::de::DeserializeOwned;
9use thiserror::Error;
10
11const PYTHON_SOURCE: &str = include_str!("importer.py");
/// Path of the Python interpreter relative to the root of a virtual
/// environment: `Scripts\python` on Windows, `bin/python3` everywhere else.
const VENV_BIN_PYTHON: &str = if cfg!(target_os = "windows") {
    "Scripts\\python"
} else {
    "bin/python3"
};
16
17#[derive(Error, Debug)]
19pub enum ImporterError {
20 #[error("unknown: `{0}`")]
22 Unknown(String),
23
24 #[error("fail to download python dependencies: `{0}`")]
26 FailToDownloadPythonDependencies(String),
27
28 #[error("sqlite dataset: `{0}`")]
30 SqliteDataset(#[from] SqliteDatasetError),
31
32 #[error("python3 is not installed")]
34 PythonNotInstalled,
35
36 #[error("venv environment is not initialized")]
38 VenvNotInitialized,
39}
40
/// Builder-style loader that imports a Hugging Face dataset into a local
/// SQLite database file by running a bundled Python script, and then exposes
/// the result as a [`SqliteDataset`].
///
/// Create it with [`HuggingfaceDatasetLoader::new`], adjust it through the
/// `with_*` methods, and finish with `dataset` or `db_file`.
pub struct HuggingfaceDatasetLoader {
    /// Name of the dataset; passed to the importer script as `--name`.
    name: String,
    /// Optional subset; passed as `--subset` and appended to the database
    /// file name.
    subset: Option<String>,
    /// Optional directory override handed to `SqliteDatasetStorage::base_dir`.
    base_dir: Option<PathBuf>,
    /// Optional access token; passed to the importer script as `--token`.
    huggingface_token: Option<String>,
    /// Optional cache directory; passed to the importer script as `--cache_dir`.
    huggingface_cache_dir: Option<String>,
    /// Optional data directory; passed to the importer script as `--data_dir`.
    huggingface_data_dir: Option<String>,
    /// When `true`, `--trust_remote_code True` is passed to the importer script.
    trust_remote_code: bool,
    /// When `true`, the importer runs inside a virtual environment created
    /// under the base directory; otherwise a Python 3 found on the system is used.
    use_python_venv: bool,
}
71
72impl HuggingfaceDatasetLoader {
73 pub fn new(name: &str) -> Self {
75 Self {
76 name: name.to_string(),
77 subset: None,
78 base_dir: None,
79 huggingface_token: None,
80 huggingface_cache_dir: None,
81 huggingface_data_dir: None,
82 trust_remote_code: false,
83 use_python_venv: true,
84 }
85 }
86
87 pub fn with_subset(mut self, subset: &str) -> Self {
93 self.subset = Some(subset.to_string());
94 self
95 }
96
97 pub fn with_base_dir(mut self, base_dir: &str) -> Self {
101 self.base_dir = Some(base_dir.into());
102 self
103 }
104
105 pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
109 self.huggingface_token = Some(huggingface_token.to_string());
110 self
111 }
112
113 pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
117 self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
118 self
119 }
120
121 pub fn with_huggingface_data_dir(mut self, huggingface_data_dir: &str) -> Self {
127 self.huggingface_data_dir = Some(huggingface_data_dir.to_string());
128 self
129 }
130
131 pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {
135 self.trust_remote_code = trust_remote_code;
136 self
137 }
138
139 pub fn with_use_python_venv(mut self, use_python_venv: bool) -> Self {
145 self.use_python_venv = use_python_venv;
146 self
147 }
148
149 pub fn dataset<I: DeserializeOwned + Clone>(
151 self,
152 split: &str,
153 ) -> Result<SqliteDataset<I>, ImporterError> {
154 let db_file = self.db_file()?;
155 let dataset = SqliteDataset::from_db_file(db_file, split)?;
156 Ok(dataset)
157 }
158
159 pub fn db_file(self) -> Result<PathBuf, ImporterError> {
163 let base_dir = SqliteDatasetStorage::base_dir(self.base_dir);
165
166 if !base_dir.exists() {
167 create_dir_all(&base_dir).expect("Failed to create base directory");
168 }
169
170 let name = sanitize(self.name.as_str());
172
173 let db_file_name = if let Some(subset) = self.subset.clone() {
175 format!("{name}-{}.db", sanitize(subset.as_str()))
176 } else {
177 format!("{name}.db")
178 };
179
180 let db_file = base_dir.join(db_file_name);
181
182 if !Path::new(&db_file).exists() {
184 import(
185 self.name,
186 self.subset,
187 db_file.clone(),
188 base_dir,
189 self.huggingface_token,
190 self.huggingface_cache_dir,
191 self.huggingface_data_dir,
192 self.trust_remote_code,
193 self.use_python_venv,
194 )?;
195 }
196
197 Ok(db_file)
198 }
199}
200
201#[allow(clippy::too_many_arguments)]
203fn import(
204 name: String,
205 subset: Option<String>,
206 base_file: PathBuf,
207 base_dir: PathBuf,
208 huggingface_token: Option<String>,
209 huggingface_cache_dir: Option<String>,
210 huggingface_data_dir: Option<String>,
211 trust_remote_code: bool,
212 use_python_venv: bool,
213) -> Result<(), ImporterError> {
214 let python_path = if use_python_venv {
215 install_python_deps(&base_dir)?
216 } else {
217 get_python_name()?.into()
218 };
219
220 let mut command = Command::new(python_path);
221
222 command.arg(importer_script_path(&base_dir));
223
224 command.arg("--name");
225 command.arg(name);
226
227 command.arg("--file");
228 command.arg(base_file);
229
230 if let Some(subset) = subset {
231 command.arg("--subset");
232 command.arg(subset);
233 }
234
235 if let Some(huggingface_token) = huggingface_token {
236 command.arg("--token");
237 command.arg(huggingface_token);
238 }
239
240 if let Some(huggingface_cache_dir) = huggingface_cache_dir {
241 command.arg("--cache_dir");
242 command.arg(huggingface_cache_dir);
243 }
244 if let Some(huggingface_data_dir) = huggingface_data_dir {
245 command.arg("--data_dir");
246 command.arg(huggingface_data_dir);
247 }
248 if trust_remote_code {
249 command.arg("--trust_remote_code");
250 command.arg("True");
251 }
252 let mut handle = command.spawn().unwrap();
253
254 let exit_status = handle
255 .wait()
256 .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;
257
258 if !exit_status.success() {
259 return Err(ImporterError::Unknown(format!("{exit_status}")));
260 }
261
262 Ok(())
263}
264
/// Reports whether running `python --version` succeeds and prints a version
/// beginning with `3.` on standard output.
///
/// The version is taken to be everything after the first space of the
/// output. Any failure (the program cannot be started, it exits
/// unsuccessfully, or the output contains no space) yields `false`.
fn check_python_version_is_3(python: &str) -> bool {
    Command::new(python)
        .arg("--version")
        .output()
        .ok()
        .filter(|output| output.status.success())
        .map_or(false, |output| {
            String::from_utf8_lossy(&output.stdout)
                .split_once(' ')
                .map_or(false, |(_, version)| version.starts_with("3."))
        })
}
285
286fn get_python_name() -> Result<&'static str, ImporterError> {
288 let python_name_list = ["python3", "python", "py"];
289 for python_name in python_name_list.iter() {
290 if check_python_version_is_3(python_name) {
291 return Ok(python_name);
292 }
293 }
294 Err(ImporterError::PythonNotInstalled)
295}
296
297fn importer_script_path(base_dir: &Path) -> PathBuf {
298 let path_file = base_dir.join("importer.py");
299
300 fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader");
301 path_file
302}
303
304fn install_python_deps(base_dir: &Path) -> Result<PathBuf, ImporterError> {
305 let venv_dir = base_dir.join("venv");
306 let venv_python_path = venv_dir.join(VENV_BIN_PYTHON);
307 if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
309 let python_name = get_python_name()?;
310 let mut command = Command::new(python_name);
311 command.args([
312 "-m",
313 "venv",
314 venv_dir
315 .as_os_str()
316 .to_str()
317 .expect("Path utf8 conversion should not fail"),
318 ]);
319
320 let mut handle = command.spawn().unwrap();
322
323 handle.wait().map_err(|err| {
324 ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}"))
325 })?;
326 if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
328 return Err(ImporterError::VenvNotInitialized);
329 }
330 }
331
332 let mut command = Command::new(&venv_python_path);
333 command.args([
334 "-m",
335 "pip",
336 "--quiet",
337 "install",
338 "pyarrow",
339 "sqlalchemy",
340 "Pillow",
341 "soundfile",
342 "datasets",
343 ]);
344
345 let mut handle = command.spawn().unwrap();
347 handle
348 .wait()
349 .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")))?;
350
351 Ok(venv_python_path)
352}