burn_dataset/source/huggingface/
downloader.rs

1use std::fs::{self, create_dir_all};
2use std::path::{Path, PathBuf};
3use std::process::Command;
4
5use crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage};
6
7use sanitize_filename::sanitize;
8use serde::de::DeserializeOwned;
9use thiserror::Error;
10
11const PYTHON_SOURCE: &str = include_str!("importer.py");
/// Location of the Python interpreter relative to the root of a virtualenv:
/// `Scripts\python` on Windows, `bin/python3` on every other target.
const VENV_BIN_PYTHON: &str = if cfg!(target_os = "windows") {
    "Scripts\\python"
} else {
    "bin/python3"
};
16
17/// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader).
18#[derive(Error, Debug)]
19pub enum ImporterError {
20    /// Unknown error.
21    #[error("unknown: `{0}`")]
22    Unknown(String),
23
24    /// Fail to download python dependencies.
25    #[error("fail to download python dependencies: `{0}`")]
26    FailToDownloadPythonDependencies(String),
27
28    /// Fail to create sqlite dataset.
29    #[error("sqlite dataset: `{0}`")]
30    SqliteDataset(#[from] SqliteDatasetError),
31
32    /// python3 is not installed.
33    #[error("python3 is not installed")]
34    PythonNotInstalled,
35
36    /// venv environment is not initialized.
37    #[error("venv environment is not initialized")]
38    VenvNotInitialized,
39}
40
/// Loader for datasets hosted on [huggingface datasets](https://huggingface.co/datasets).
///
/// All splits of a dataset are written into one sqlite database file
/// (see [`SqliteDataset`]); a single split is then opened from that file.
///
/// # Example
/// ```no_run
/// use burn_dataset::HuggingfaceDatasetLoader;
/// use burn_dataset::SqliteDataset;
/// use serde::Deserialize;
///
/// #[derive(Deserialize, Debug, Clone)]
/// struct MnistItemRaw {
///     pub image_bytes: Vec<u8>,
///     pub label: usize,
/// }
///
/// let train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
///     .dataset("train")
///     .unwrap();
/// ```
///
/// # Note
/// Downloading is delegated to the
/// [`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index),
/// which is a Python library, so an existing Python installation is required.
pub struct HuggingfaceDatasetLoader {
    /// Dataset name; forwarded to the importer as `--name` and, once sanitized,
    /// used as the stem of the database file name.
    name: String,
    /// Optional subset; forwarded as `--subset` and appended to the database file name.
    subset: Option<String>,
    /// Directory holding the database file, the importer script and the venv.
    base_dir: Option<PathBuf>,
    /// Optional access token; forwarded to the importer as `--token`.
    huggingface_token: Option<String>,
    /// Optional cache directory; forwarded to the importer as `--cache_dir`.
    huggingface_cache_dir: Option<String>,
    /// Optional data directory; forwarded to the importer as `--data_dir`.
    huggingface_data_dir: Option<String>,
    /// When `true`, `--trust_remote_code True` is passed to the importer.
    trust_remote_code: bool,
    /// When `true`, the importer runs with the venv interpreter under the base
    /// directory; otherwise the system Python 3 command is used.
    use_python_venv: bool,
}
75
76impl HuggingfaceDatasetLoader {
77    /// Create a huggingface dataset loader.
78    pub fn new(name: &str) -> Self {
79        Self {
80            name: name.to_string(),
81            subset: None,
82            base_dir: None,
83            huggingface_token: None,
84            huggingface_cache_dir: None,
85            huggingface_data_dir: None,
86            trust_remote_code: false,
87            use_python_venv: true,
88        }
89    }
90
91    /// Create a huggingface dataset loader for a subset of the dataset.
92    ///
93    /// The subset name must be one of the subsets listed in the dataset page.
94    ///
95    /// If no subset names are listed, then do not use this method.
96    pub fn with_subset(mut self, subset: &str) -> Self {
97        self.subset = Some(subset.to_string());
98        self
99    }
100
101    /// Specify a base directory to store the dataset.
102    ///
103    /// If not specified, the dataset will be stored in `~/.cache/burn-dataset`.
104    pub fn with_base_dir(mut self, base_dir: &str) -> Self {
105        self.base_dir = Some(base_dir.into());
106        self
107    }
108
109    /// Specify a huggingface token to download datasets behind authentication.
110    ///
111    /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens)
112    pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
113        self.huggingface_token = Some(huggingface_token.to_string());
114        self
115    }
116
117    /// Specify a huggingface cache directory to store the downloaded datasets.
118    ///
119    /// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`.
120    pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
121        self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
122        self
123    }
124
125    /// Specify a relative path to a subset of a dataset. This is used in some datasets for the
126    /// manual steps of dataset download process.
127    ///
128    /// Unless you've encountered a ManualDownloadError
129    /// when loading your dataset you probably don't have to worry about this setting.
130    pub fn with_huggingface_data_dir(mut self, huggingface_data_dir: &str) -> Self {
131        self.huggingface_data_dir = Some(huggingface_data_dir.to_string());
132        self
133    }
134
135    /// Specify whether or not to trust remote code.
136    ///
137    /// If not specified, trust remote code is set to true.
138    pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {
139        self.trust_remote_code = trust_remote_code;
140        self
141    }
142
143    /// Specify whether or not to use the burn-dataset Python
144    /// virtualenv for running the importer script. If false, local
145    /// `python3`'s environment is used.
146    ///
147    /// If not specified, the virtualenv is used.
148    pub fn with_use_python_venv(mut self, use_python_venv: bool) -> Self {
149        self.use_python_venv = use_python_venv;
150        self
151    }
152
153    /// Load the dataset.
154    pub fn dataset<I: DeserializeOwned + Clone>(
155        self,
156        split: &str,
157    ) -> Result<SqliteDataset<I>, ImporterError> {
158        let db_file = self.db_file()?;
159        let dataset = SqliteDataset::from_db_file(db_file, split)?;
160        Ok(dataset)
161    }
162
163    /// Get the path to the sqlite database file.
164    ///
165    /// If the database file does not exist, it will be downloaded and imported.
166    pub fn db_file(self) -> Result<PathBuf, ImporterError> {
167        // determine (and create if needed) the base directory
168        let base_dir = SqliteDatasetStorage::base_dir(self.base_dir);
169
170        if !base_dir.exists() {
171            create_dir_all(&base_dir).expect("Failed to create base directory");
172        }
173
174        //sanitize the name and subset
175        let name = sanitize(self.name.as_str());
176
177        // create the db file path
178        let db_file_name = if let Some(subset) = self.subset.clone() {
179            format!("{name}-{}.db", sanitize(subset.as_str()))
180        } else {
181            format!("{name}.db")
182        };
183
184        let db_file = base_dir.join(db_file_name);
185
186        // import the dataset if needed
187        if !Path::new(&db_file).exists() {
188            import(
189                self.name,
190                self.subset,
191                db_file.clone(),
192                base_dir,
193                self.huggingface_token,
194                self.huggingface_cache_dir,
195                self.huggingface_data_dir,
196                self.trust_remote_code,
197                self.use_python_venv,
198            )?;
199        }
200
201        Ok(db_file)
202    }
203}
204
205/// Import a dataset from huggingface. The transformed dataset is stored as sqlite database.
206#[allow(clippy::too_many_arguments)]
207fn import(
208    name: String,
209    subset: Option<String>,
210    base_file: PathBuf,
211    base_dir: PathBuf,
212    huggingface_token: Option<String>,
213    huggingface_cache_dir: Option<String>,
214    huggingface_data_dir: Option<String>,
215    trust_remote_code: bool,
216    use_python_venv: bool,
217) -> Result<(), ImporterError> {
218    let python_path = if use_python_venv {
219        install_python_deps(&base_dir)?
220    } else {
221        get_python_name()?.into()
222    };
223
224    let mut command = Command::new(python_path);
225
226    command.arg(importer_script_path(&base_dir));
227
228    command.arg("--name");
229    command.arg(name);
230
231    command.arg("--file");
232    command.arg(base_file);
233
234    if let Some(subset) = subset {
235        command.arg("--subset");
236        command.arg(subset);
237    }
238
239    if let Some(huggingface_token) = huggingface_token {
240        command.arg("--token");
241        command.arg(huggingface_token);
242    }
243
244    if let Some(huggingface_cache_dir) = huggingface_cache_dir {
245        command.arg("--cache_dir");
246        command.arg(huggingface_cache_dir);
247    }
248    if let Some(huggingface_data_dir) = huggingface_data_dir {
249        command.arg("--data_dir");
250        command.arg(huggingface_data_dir);
251    }
252    if trust_remote_code {
253        command.arg("--trust_remote_code");
254        command.arg("True");
255    }
256    let mut handle = command.spawn().unwrap();
257
258    let exit_status = handle
259        .wait()
260        .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;
261
262    if !exit_status.success() {
263        return Err(ImporterError::Unknown(format!("{exit_status}")));
264    }
265
266    Ok(())
267}
268
/// Returns `true` when running `<python> --version` succeeds and the text after
/// the first space on its standard output starts with `3.` (e.g. `Python 3.x.x`).
///
/// Any failure to run the command, a non-zero exit status, or output without a
/// space yields `false`.
fn check_python_version_is_3(python: &str) -> bool {
    match Command::new(python).arg("--version").output() {
        Ok(output) if output.status.success() => {
            let stdout = String::from_utf8_lossy(&output.stdout);
            match stdout.split_once(' ') {
                Some((_, version)) => version.starts_with("3."),
                None => false,
            }
        }
        _ => false,
    }
}
289
290/// get python3 name `python` `python3` or `py`
291fn get_python_name() -> Result<&'static str, ImporterError> {
292    let python_name_list = ["python3", "python", "py"];
293    for python_name in python_name_list.iter() {
294        if check_python_version_is_3(python_name) {
295            return Ok(python_name);
296        }
297    }
298    Err(ImporterError::PythonNotInstalled)
299}
300
301fn importer_script_path(base_dir: &Path) -> PathBuf {
302    let path_file = base_dir.join("importer.py");
303
304    fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader");
305    path_file
306}
307
308fn install_python_deps(base_dir: &Path) -> Result<PathBuf, ImporterError> {
309    let venv_dir = base_dir.join("venv");
310    let venv_python_path = venv_dir.join(VENV_BIN_PYTHON);
311    // If the venv environment is already initialized, skip the initialization.
312    if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
313        let python_name = get_python_name()?;
314        let mut command = Command::new(python_name);
315        command.args([
316            "-m",
317            "venv",
318            venv_dir
319                .as_os_str()
320                .to_str()
321                .expect("Path utf8 conversion should not fail"),
322        ]);
323
324        // Spawn the venv creation process and wait for it to complete.
325        let mut handle = command.spawn().unwrap();
326
327        handle.wait().map_err(|err| {
328            ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}"))
329        })?;
330        // Check if the venv environment can be used successfully."
331        if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
332            return Err(ImporterError::VenvNotInitialized);
333        }
334    }
335
336    let mut ensurepip_cmd = Command::new(&venv_python_path);
337    ensurepip_cmd.args(["-m", "ensurepip", "--upgrade"]);
338    let status = ensurepip_cmd.status().map_err(|err| {
339        ImporterError::FailToDownloadPythonDependencies(format!("failed to run ensurepip: {err}"))
340    })?;
341    if !status.success() {
342        return Err(ImporterError::FailToDownloadPythonDependencies(
343            "ensurepip failed to initialize pip".to_string(),
344        ));
345    }
346
347    let mut command = Command::new(&venv_python_path);
348    command.args([
349        "-m",
350        "pip",
351        "--quiet",
352        "install",
353        "pyarrow",
354        "sqlalchemy",
355        "Pillow",
356        "soundfile",
357        "datasets",
358    ]);
359
360    // Spawn the pip install process and wait for it to complete.
361    let mut handle = command.spawn().unwrap();
362    handle
363        .wait()
364        .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")))?;
365
366    Ok(venv_python_path)
367}