burn_dataset/source/huggingface/downloader.rs

1use std::fs::{self, create_dir_all};
2use std::path::{Path, PathBuf};
3use std::process::Command;
4
5use crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage};
6
7use sanitize_filename::sanitize;
8use serde::de::DeserializeOwned;
9use thiserror::Error;
10
11const PYTHON_SOURCE: &str = include_str!("importer.py");
/// Location of the interpreter inside the burn-dataset virtualenv, relative to the venv root.
///
/// Windows virtualenvs keep executables under `Scripts\`, every other platform under `bin/`.
const VENV_BIN_PYTHON: &str = if cfg!(target_os = "windows") {
    "Scripts\\python"
} else {
    "bin/python3"
};
16
17/// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader).
18#[derive(Error, Debug)]
19pub enum ImporterError {
20    /// Unknown error.
21    #[error("unknown: `{0}`")]
22    Unknown(String),
23
24    /// Fail to download python dependencies.
25    #[error("fail to download python dependencies: `{0}`")]
26    FailToDownloadPythonDependencies(String),
27
28    /// Fail to create sqlite dataset.
29    #[error("sqlite dataset: `{0}`")]
30    SqliteDataset(#[from] SqliteDatasetError),
31
32    /// python3 is not installed.
33    #[error("python3 is not installed")]
34    PythonNotInstalled,
35
36    /// venv environment is not initialized.
37    #[error("venv environment is not initialized")]
38    VenvNotInitialized,
39}
40
/// Loader for datasets hosted on [Hugging Face](https://huggingface.co/datasets).
///
/// Every split of the requested dataset is stored in one SQLite database file (see
/// [`SqliteDataset`]). The file is created on first use and reused on later calls.
///
/// # Example
/// ```no_run
/// use burn_dataset::HuggingfaceDatasetLoader;
/// use burn_dataset::SqliteDataset;
/// use serde::Deserialize;
///
/// #[derive(Deserialize, Debug, Clone)]
/// struct MnistItemRaw {
///     pub image_bytes: Vec<u8>,
///     pub label: usize,
/// }
///
/// let train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
///     .dataset("train")
///     .unwrap();
/// ```
pub struct HuggingfaceDatasetLoader {
    // Dataset identifier on the Hub, e.g. "mnist".
    name: String,
    // Optional subset name; becomes part of the database file name.
    subset: Option<String>,
    // Where the SQLite files live; `None` defers to the storage default.
    base_dir: Option<PathBuf>,
    // Access token forwarded to the importer for authenticated downloads.
    huggingface_token: Option<String>,
    // Override for the Hugging Face download cache, forwarded to the importer.
    huggingface_cache_dir: Option<String>,
    // Relative data directory, forwarded to the importer.
    huggingface_data_dir: Option<String>,
    // Whether the importer is told to trust remote code.
    trust_remote_code: bool,
    // Run the importer in the managed virtualenv instead of the system interpreter.
    use_python_venv: bool,
}
71
72impl HuggingfaceDatasetLoader {
73    /// Create a huggingface dataset loader.
74    pub fn new(name: &str) -> Self {
75        Self {
76            name: name.to_string(),
77            subset: None,
78            base_dir: None,
79            huggingface_token: None,
80            huggingface_cache_dir: None,
81            huggingface_data_dir: None,
82            trust_remote_code: false,
83            use_python_venv: true,
84        }
85    }
86
87    /// Create a huggingface dataset loader for a subset of the dataset.
88    ///
89    /// The subset name must be one of the subsets listed in the dataset page.
90    ///
91    /// If no subset names are listed, then do not use this method.
92    pub fn with_subset(mut self, subset: &str) -> Self {
93        self.subset = Some(subset.to_string());
94        self
95    }
96
97    /// Specify a base directory to store the dataset.
98    ///
99    /// If not specified, the dataset will be stored in `~/.cache/burn-dataset`.
100    pub fn with_base_dir(mut self, base_dir: &str) -> Self {
101        self.base_dir = Some(base_dir.into());
102        self
103    }
104
105    /// Specify a huggingface token to download datasets behind authentication.
106    ///
107    /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens)
108    pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
109        self.huggingface_token = Some(huggingface_token.to_string());
110        self
111    }
112
113    /// Specify a huggingface cache directory to store the downloaded datasets.
114    ///
115    /// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`.
116    pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
117        self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
118        self
119    }
120
121    /// Specify a relative path to a subset of a dataset. This is used in some datasets for the
122    /// manual steps of dataset download process.
123    ///
124    /// Unless you've encountered a ManualDownloadError
125    /// when loading your dataset you probably don't have to worry about this setting.
126    pub fn with_huggingface_data_dir(mut self, huggingface_data_dir: &str) -> Self {
127        self.huggingface_data_dir = Some(huggingface_data_dir.to_string());
128        self
129    }
130
131    /// Specify whether or not to trust remote code.
132    ///
133    /// If not specified, trust remote code is set to true.
134    pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {
135        self.trust_remote_code = trust_remote_code;
136        self
137    }
138
139    /// Specify whether or not to use the burn-dataset Python
140    /// virtualenv for running the importer script. If false, local
141    /// `python3`'s environment is used.
142    ///
143    /// If not specified, the virtualenv is used.
144    pub fn with_use_python_venv(mut self, use_python_venv: bool) -> Self {
145        self.use_python_venv = use_python_venv;
146        self
147    }
148
149    /// Load the dataset.
150    pub fn dataset<I: DeserializeOwned + Clone>(
151        self,
152        split: &str,
153    ) -> Result<SqliteDataset<I>, ImporterError> {
154        let db_file = self.db_file()?;
155        let dataset = SqliteDataset::from_db_file(db_file, split)?;
156        Ok(dataset)
157    }
158
159    /// Get the path to the sqlite database file.
160    ///
161    /// If the database file does not exist, it will be downloaded and imported.
162    pub fn db_file(self) -> Result<PathBuf, ImporterError> {
163        // determine (and create if needed) the base directory
164        let base_dir = SqliteDatasetStorage::base_dir(self.base_dir);
165
166        if !base_dir.exists() {
167            create_dir_all(&base_dir).expect("Failed to create base directory");
168        }
169
170        //sanitize the name and subset
171        let name = sanitize(self.name.as_str());
172
173        // create the db file path
174        let db_file_name = if let Some(subset) = self.subset.clone() {
175            format!("{name}-{}.db", sanitize(subset.as_str()))
176        } else {
177            format!("{name}.db")
178        };
179
180        let db_file = base_dir.join(db_file_name);
181
182        // import the dataset if needed
183        if !Path::new(&db_file).exists() {
184            import(
185                self.name,
186                self.subset,
187                db_file.clone(),
188                base_dir,
189                self.huggingface_token,
190                self.huggingface_cache_dir,
191                self.huggingface_data_dir,
192                self.trust_remote_code,
193                self.use_python_venv,
194            )?;
195        }
196
197        Ok(db_file)
198    }
199}
200
201/// Import a dataset from huggingface. The transformed dataset is stored as sqlite database.
202#[allow(clippy::too_many_arguments)]
203fn import(
204    name: String,
205    subset: Option<String>,
206    base_file: PathBuf,
207    base_dir: PathBuf,
208    huggingface_token: Option<String>,
209    huggingface_cache_dir: Option<String>,
210    huggingface_data_dir: Option<String>,
211    trust_remote_code: bool,
212    use_python_venv: bool,
213) -> Result<(), ImporterError> {
214    let python_path = if use_python_venv {
215        install_python_deps(&base_dir)?
216    } else {
217        get_python_name()?.into()
218    };
219
220    let mut command = Command::new(python_path);
221
222    command.arg(importer_script_path(&base_dir));
223
224    command.arg("--name");
225    command.arg(name);
226
227    command.arg("--file");
228    command.arg(base_file);
229
230    if let Some(subset) = subset {
231        command.arg("--subset");
232        command.arg(subset);
233    }
234
235    if let Some(huggingface_token) = huggingface_token {
236        command.arg("--token");
237        command.arg(huggingface_token);
238    }
239
240    if let Some(huggingface_cache_dir) = huggingface_cache_dir {
241        command.arg("--cache_dir");
242        command.arg(huggingface_cache_dir);
243    }
244    if let Some(huggingface_data_dir) = huggingface_data_dir {
245        command.arg("--data_dir");
246        command.arg(huggingface_data_dir);
247    }
248    if trust_remote_code {
249        command.arg("--trust_remote_code");
250        command.arg("True");
251    }
252    let mut handle = command.spawn().unwrap();
253
254    let exit_status = handle
255        .wait()
256        .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;
257
258    if !exit_status.success() {
259        return Err(ImporterError::Unknown(format!("{exit_status}")));
260    }
261
262    Ok(())
263}
264
/// Check whether `<python> --version` reports a Python 3 interpreter (`Python 3.x.x`).
///
/// Returns `false` in three cases: the program cannot be started, it exits unsuccessfully, or
/// it reports a different major version.
fn check_python_version_is_3(python: &str) -> bool {
    let output = match Command::new(python).arg("--version").output() {
        Ok(output) => output,
        Err(_) => return false,
    };
    if !output.status.success() {
        return false;
    }

    // Current Python 3 releases print the version on stdout, but older interpreters
    // (e.g. Python 2) print it on stderr, so fall back to stderr when stdout is blank.
    let stdout = String::from_utf8_lossy(&output.stdout);
    let version_text = if stdout.trim().is_empty() {
        String::from_utf8_lossy(&output.stderr)
    } else {
        stdout
    };

    is_python3_version_output(&version_text)
}

/// Whether a `--version` line such as `Python 3.11.4` names a 3.x release.
///
/// Only the text after the first space is inspected, matching the `<name> <version>` format.
fn is_python3_version_output(text: &str) -> bool {
    matches!(text.split_once(' '), Some((_, version)) if version.starts_with("3."))
}
285
286/// get python3 name `python` `python3` or `py`
287fn get_python_name() -> Result<&'static str, ImporterError> {
288    let python_name_list = ["python3", "python", "py"];
289    for python_name in python_name_list.iter() {
290        if check_python_version_is_3(python_name) {
291            return Ok(python_name);
292        }
293    }
294    Err(ImporterError::PythonNotInstalled)
295}
296
297fn importer_script_path(base_dir: &Path) -> PathBuf {
298    let path_file = base_dir.join("importer.py");
299
300    fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader");
301    path_file
302}
303
304fn install_python_deps(base_dir: &Path) -> Result<PathBuf, ImporterError> {
305    let venv_dir = base_dir.join("venv");
306    let venv_python_path = venv_dir.join(VENV_BIN_PYTHON);
307    // If the venv environment is already initialized, skip the initialization.
308    if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
309        let python_name = get_python_name()?;
310        let mut command = Command::new(python_name);
311        command.args([
312            "-m",
313            "venv",
314            venv_dir
315                .as_os_str()
316                .to_str()
317                .expect("Path utf8 conversion should not fail"),
318        ]);
319
320        // Spawn the venv creation process and wait for it to complete.
321        let mut handle = command.spawn().unwrap();
322
323        handle.wait().map_err(|err| {
324            ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}"))
325        })?;
326        // Check if the venv environment can be used successfully."
327        if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
328            return Err(ImporterError::VenvNotInitialized);
329        }
330    }
331
332    let mut command = Command::new(&venv_python_path);
333    command.args([
334        "-m",
335        "pip",
336        "--quiet",
337        "install",
338        "pyarrow",
339        "sqlalchemy",
340        "Pillow",
341        "soundfile",
342        "datasets",
343    ]);
344
345    // Spawn the pip install process and wait for it to complete.
346    let mut handle = command.spawn().unwrap();
347    handle
348        .wait()
349        .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")))?;
350
351    Ok(venv_python_path)
352}