// Loader that downloads datasets from the Hugging Face hub and imports them into a SQLite database.
use std::fs::{self, create_dir_all};
use std::path::{Path, PathBuf};
use std::process::Command;

use crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage};

use sanitize_filename::sanitize;
use serde::de::DeserializeOwned;
use thiserror::Error;

/// Source of the Python importer script, embedded at compile time from `importer.py`.
const PYTHON_SOURCE: &str = include_str!("importer.py");
/// Location of the Python interpreter relative to the venv directory (non-Windows targets).
#[cfg(not(target_os = "windows"))]
const VENV_BIN_PYTHON: &str = "bin/python3";
/// Location of the Python interpreter relative to the venv directory (Windows targets).
#[cfg(target_os = "windows")]
const VENV_BIN_PYTHON: &str = "Scripts\\python";

/// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader).
///
/// Every fallible step of locating Python, preparing the virtual environment,
/// running the importer script and opening the resulting database reports its
/// failure through this enum.
#[derive(Error, Debug)]
pub enum ImporterError {
    /// Unknown error; the payload is a free-form description of what went wrong.
    #[error("unknown: `{0}`")]
    Unknown(String),

    /// Failed to download the python dependencies; the payload describes the failure.
    #[error("fail to download python dependencies: `{0}`")]
    FailToDownloadPythonDependencies(String),

    /// Failed to create the sqlite dataset.
    ///
    /// Converted automatically from [SqliteDatasetError] via `From`.
    #[error("sqlite dataset: `{0}`")]
    SqliteDataset(#[from] SqliteDatasetError),

    /// No usable python3 interpreter is installed.
    #[error("python3 is not installed")]
    PythonNotInstalled,

    /// The venv environment is not initialized, i.e. its interpreter could not be
    /// run as a Python 3 interpreter.
    #[error("venv environment is not initialized")]
    VenvNotInitialized,
}

/// Load a dataset from [huggingface datasets](https://huggingface.co/datasets).
///
/// The dataset with all splits is stored in a single sqlite database (see [SqliteDataset](SqliteDataset)).
///
/// # Example
/// ```no_run
///  use burn_dataset::HuggingfaceDatasetLoader;
///  use burn_dataset::SqliteDataset;
///  use serde::{Deserialize, Serialize};
///
/// #[derive(Deserialize, Debug, Clone)]
/// struct MNISTItemRaw {
///     pub image_bytes: Vec<u8>,
///     pub label: usize,
/// }
///
///  let train_ds:SqliteDataset<MNISTItemRaw> = HuggingfaceDatasetLoader::new("mnist")
///       .dataset("train")
///       .unwrap();
/// ```
pub struct HuggingfaceDatasetLoader {
    // Dataset name; forwarded to the importer script as `--name`.
    name: String,
    // Optional subset name; forwarded to the importer script as `--subset`.
    subset: Option<String>,
    // Optional directory that holds the sqlite file, the venv and the importer script.
    base_dir: Option<PathBuf>,
    // Optional access token; forwarded to the importer script as `--token`.
    huggingface_token: Option<String>,
    // Optional download cache directory; forwarded to the importer script as `--cache_dir`.
    huggingface_cache_dir: Option<String>,
}

impl HuggingfaceDatasetLoader {
    /// Create a huggingface dataset loader.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            subset: None,
            base_dir: None,
            huggingface_token: None,
            huggingface_cache_dir: None,
        }
    }

    /// Create a huggingface dataset loader for a subset of the dataset.
    ///
    /// The subset name must be one of the subsets listed in the dataset page.
    ///
    /// If no subset names are listed, then do not use this method.
    pub fn with_subset(mut self, subset: &str) -> Self {
        self.subset = Some(subset.to_string());
        self
    }

    /// Specify a base directory to store the dataset.
    ///
    /// If not specified, the dataset will be stored in `~/.cache/burn-dataset`.
    pub fn with_base_dir(mut self, base_dir: &str) -> Self {
        self.base_dir = Some(base_dir.into());
        self
    }

    /// Specify a huggingface token to download datasets behind authentication.
    ///
    /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens)
    pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
        self.huggingface_token = Some(huggingface_token.to_string());
        self
    }

    /// Specify a huggingface cache directory to store the downloaded datasets.
    ///
    /// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`.
    pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
        self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
        self
    }

    /// Load the dataset.
    pub fn dataset<I: DeserializeOwned + Clone>(
        self,
        split: &str,
    ) -> Result<SqliteDataset<I>, ImporterError> {
        let db_file = self.db_file()?;
        let dataset = SqliteDataset::from_db_file(db_file, split)?;
        Ok(dataset)
    }

    /// Get the path to the sqlite database file.
    ///
    /// If the database file does not exist, it will be downloaded and imported.
    pub fn db_file(self) -> Result<PathBuf, ImporterError> {
        // determine (and create if needed) the base directory
        let base_dir = SqliteDatasetStorage::base_dir(self.base_dir);

        if !base_dir.exists() {
            create_dir_all(&base_dir).expect("Failed to create base directory");
        }

        //sanitize the name and subset
        let name = sanitize(self.name.as_str());

        // create the db file path
        let db_file_name = if let Some(subset) = self.subset.clone() {
            format!("{}-{}.db", name, sanitize(subset.as_str()))
        } else {
            format!("{}.db", name)
        };

        let db_file = base_dir.join(db_file_name);

        // import the dataset if needed
        if !Path::new(&db_file).exists() {
            import(
                self.name,
                self.subset,
                db_file.clone(),
                base_dir,
                self.huggingface_token,
                self.huggingface_cache_dir,
            )?;
        }

        Ok(db_file)
    }
}

/// Import a dataset from huggingface. The transformed dataset is stored as sqlite database.
///
/// Prepares the Python virtual environment under `base_dir`, writes the importer
/// script there and runs it with the venv interpreter, passing `name`,
/// `base_file` (the target database file) and the optional `subset`,
/// `huggingface_token` and `huggingface_cache_dir` as command-line arguments.
///
/// # Errors
///
/// Forwards errors from the dependency installation, and returns
/// [`ImporterError::Unknown`] when the importer process cannot be run or exits
/// with a non-success status.
fn import(
    name: String,
    subset: Option<String>,
    base_file: PathBuf,
    base_dir: PathBuf,
    huggingface_token: Option<String>,
    huggingface_cache_dir: Option<String>,
) -> Result<(), ImporterError> {
    let venv_python_path = install_python_deps(&base_dir)?;

    let mut command = Command::new(venv_python_path);
    command
        .arg(importer_script_path(&base_dir))
        .arg("--name")
        .arg(name)
        .arg("--file")
        .arg(base_file);

    if let Some(subset) = subset {
        command.arg("--subset").arg(subset);
    }

    // NOTE(review): the token is passed on the command line, where other local
    // users may be able to see it in the process list — consider an env var.
    if let Some(huggingface_token) = huggingface_token {
        command.arg("--token").arg(huggingface_token);
    }

    if let Some(huggingface_cache_dir) = huggingface_cache_dir {
        command.arg("--cache_dir").arg(huggingface_cache_dir);
    }

    // `status` spawns the child and waits for it; both failure modes surface as
    // an error instead of a panic.
    let status = command
        .status()
        .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;

    // A failed importer run must not be reported as success, otherwise the
    // caller goes on to open a database that was never (fully) written.
    if !status.success() {
        return Err(ImporterError::Unknown(format!(
            "importer script failed: {status}"
        )));
    }

    Ok(())
}

/// Check that `<python> --version` reports a `Python 3.x.x` version.
///
/// Returns `true` only when the command can be run, exits successfully and the
/// text after the first space on its standard output starts with `3.`.
fn check_python_version_is_3(python: &str) -> bool {
    match Command::new(python).arg("--version").output() {
        Ok(result) if result.status.success() => {
            let stdout = String::from_utf8_lossy(&result.stdout);
            // Expected shape is `Python 3.x.x`: keep what follows the first space.
            matches!(
                stdout.split_once(' '),
                Some((_, version)) if version.starts_with("3.")
            )
        }
        // The command could not be run, or it exited with a failure status.
        _ => false,
    }
}

/// Find the command name of a Python 3 interpreter.
///
/// Tries `python3`, `python` and `py` in that order and returns the first one
/// that passes the version check, or [`ImporterError::PythonNotInstalled`] when
/// none does.
fn get_python_name() -> Result<&'static str, ImporterError> {
    const CANDIDATES: [&str; 3] = ["python3", "python", "py"];

    CANDIDATES
        .iter()
        .copied()
        .find(|candidate| check_python_version_is_3(candidate))
        .ok_or(ImporterError::PythonNotInstalled)
}

/// Write the embedded importer script to `importer.py` inside `base_dir` and
/// return the path of the written file.
///
/// # Panics
///
/// Panics when the script cannot be written.
fn importer_script_path(base_dir: &Path) -> PathBuf {
    let script = base_dir.join("importer.py");
    // The script is rewritten on every call.
    fs::write(script.as_path(), PYTHON_SOURCE).expect("Write python dataset downloader");

    script
}

/// Make sure a virtual environment with the importer's Python dependencies
/// exists under `base_dir` and return the path of its interpreter.
///
/// The venv lives in `<base_dir>/venv`. It is created with `python -m venv` when
/// its interpreter cannot already be run as Python 3; afterwards the required
/// packages are installed with `pip install`.
///
/// # Errors
///
/// * [`ImporterError::Unknown`] when the interpreter path is not valid UTF-8.
/// * [`ImporterError::PythonNotInstalled`] when no system Python 3 is found.
/// * [`ImporterError::VenvNotInitialized`] when the venv is unusable after creation.
/// * [`ImporterError::FailToDownloadPythonDependencies`] when the venv or pip
///   process cannot be run, or when `pip install` exits with a failure status.
fn install_python_deps(base_dir: &Path) -> Result<PathBuf, ImporterError> {
    let venv_dir = base_dir.join("venv");
    let venv_python_path = venv_dir.join(VENV_BIN_PYTHON);

    // The version check takes a `&str`; report a non-UTF-8 path instead of panicking.
    let venv_python = venv_python_path.to_str().ok_or_else(|| {
        ImporterError::Unknown(format!(
            "venv python path is not valid UTF-8: `{}`",
            venv_python_path.display()
        ))
    })?;

    // If the venv environment is already initialized, skip the initialization.
    if !check_python_version_is_3(venv_python) {
        let python_name = get_python_name()?;

        // `Command::arg` accepts paths directly, so no UTF-8 conversion is needed.
        // `status` spawns the process and waits for it to complete.
        Command::new(python_name)
            .arg("-m")
            .arg("venv")
            .arg(&venv_dir)
            .status()
            .map_err(|err| {
                ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err))
            })?;

        // The exit status is deliberately not inspected here: whether the venv
        // interpreter can be used is the criterion for success.
        if !check_python_version_is_3(venv_python) {
            return Err(ImporterError::VenvNotInitialized);
        }
    }

    // Run the pip install process and wait for it to complete.
    let pip_status = Command::new(&venv_python_path)
        .args([
            "-m",
            "pip",
            "--quiet",
            "install",
            "pyarrow",
            "sqlalchemy",
            "Pillow",
            "soundfile",
            "datasets",
        ])
        .status()
        .map_err(|err| {
            ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err))
        })?;

    // A failed installation would only surface later as an obscure import error
    // inside the importer script, so report it here.
    if !pip_status.success() {
        return Err(ImporterError::FailToDownloadPythonDependencies(format!(
            "pip install failed: {pip_status}"
        )));
    }

    Ok(venv_python_path)
}