use std::fs::{self, create_dir_all};
use std::path::{Path, PathBuf};
use std::process::Command;
use crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage};
use sanitize_filename::sanitize;
use serde::de::DeserializeOwned;
use thiserror::Error;
/// Python importer script, embedded at compile time from `importer.py` located next to this
/// source file. It is written to the storage base directory and run with the venv interpreter.
const PYTHON_SOURCE: &str = include_str!("importer.py");
/// Location of the Python interpreter inside a virtual environment, relative to the venv root.
/// Windows venvs keep executables under `Scripts`, other platforms under `bin`.
const VENV_BIN_PYTHON: &str = if cfg!(target_os = "windows") {
    "Scripts\\python"
} else {
    "bin/python3"
};
#[derive(Error, Debug)]
pub enum ImporterError {
#[error("unknown: `{0}`")]
Unknown(String),
#[error("fail to download python dependencies: `{0}`")]
FailToDownloadPythonDependencies(String),
#[error("sqlite dataset: `{0}`")]
SqliteDataset(#[from] SqliteDatasetError),
#[error("python3 is not installed")]
PythonNotInstalled,
#[error("venv environment is not initialized")]
VenvNotInitialized,
}
pub struct HuggingfaceDatasetLoader {
name: String,
subset: Option<String>,
base_dir: Option<PathBuf>,
huggingface_token: Option<String>,
huggingface_cache_dir: Option<String>,
}
impl HuggingfaceDatasetLoader {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
subset: None,
base_dir: None,
huggingface_token: None,
huggingface_cache_dir: None,
}
}
pub fn with_subset(mut self, subset: &str) -> Self {
self.subset = Some(subset.to_string());
self
}
pub fn with_base_dir(mut self, base_dir: &str) -> Self {
self.base_dir = Some(base_dir.into());
self
}
pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
self.huggingface_token = Some(huggingface_token.to_string());
self
}
pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
self
}
pub fn dataset<I: DeserializeOwned + Clone>(
self,
split: &str,
) -> Result<SqliteDataset<I>, ImporterError> {
let db_file = self.db_file()?;
let dataset = SqliteDataset::from_db_file(db_file, split)?;
Ok(dataset)
}
pub fn db_file(self) -> Result<PathBuf, ImporterError> {
let base_dir = SqliteDatasetStorage::base_dir(self.base_dir);
if !base_dir.exists() {
create_dir_all(&base_dir).expect("Failed to create base directory");
}
let name = sanitize(self.name.as_str());
let db_file_name = if let Some(subset) = self.subset.clone() {
format!("{}-{}.db", name, sanitize(subset.as_str()))
} else {
format!("{}.db", name)
};
let db_file = base_dir.join(db_file_name);
if !Path::new(&db_file).exists() {
import(
self.name,
self.subset,
db_file.clone(),
base_dir,
self.huggingface_token,
self.huggingface_cache_dir,
)?;
}
Ok(db_file)
}
}
fn import(
name: String,
subset: Option<String>,
base_file: PathBuf,
base_dir: PathBuf,
huggingface_token: Option<String>,
huggingface_cache_dir: Option<String>,
) -> Result<(), ImporterError> {
let venv_python_path = install_python_deps(&base_dir)?;
let mut command = Command::new(venv_python_path);
command.arg(importer_script_path(&base_dir));
command.arg("--name");
command.arg(name);
command.arg("--file");
command.arg(base_file);
if let Some(subset) = subset {
command.arg("--subset");
command.arg(subset);
}
if let Some(huggingface_token) = huggingface_token {
command.arg("--token");
command.arg(huggingface_token);
}
if let Some(huggingface_cache_dir) = huggingface_cache_dir {
command.arg("--cache_dir");
command.arg(huggingface_cache_dir);
}
let mut handle = command.spawn().unwrap();
handle
.wait()
.map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;
Ok(())
}
/// Reports whether `python --version` runs successfully and prints a 3.x version.
///
/// The text after the first space of stdout (e.g. `3.11.4` in `Python 3.11.4`) must start with
/// `3.`. A command that cannot be launched, exits unsuccessfully, or prints no space yields
/// `false`. NOTE(review): only stdout is inspected.
fn check_python_version_is_3(python: &str) -> bool {
    let output = match Command::new(python).arg("--version").output() {
        Ok(output) => output,
        Err(_) => return false,
    };
    if !output.status.success() {
        return false;
    }

    let reported = String::from_utf8_lossy(&output.stdout);
    reported
        .find(' ')
        .map_or(false, |space| reported[space + 1..].starts_with("3."))
}
/// Returns the first interpreter command among `python3`, `python` and `py` (tried in that
/// order) whose `--version` output reports Python 3.
///
/// # Errors
///
/// [`ImporterError::PythonNotInstalled`] when none of the candidates qualifies.
fn get_python_name() -> Result<&'static str, ImporterError> {
    const CANDIDATES: [&str; 3] = ["python3", "python", "py"];
    CANDIDATES
        .iter()
        .copied()
        .find(|candidate| check_python_version_is_3(candidate))
        .ok_or(ImporterError::PythonNotInstalled)
}
/// Writes the embedded importer script to `<base_dir>/importer.py`, replacing any existing copy,
/// and returns that path.
///
/// # Panics
///
/// Panics if the file cannot be written.
/// NOTE(review): returning `Result` here would let `import` report this as an error instead.
fn importer_script_path(base_dir: &Path) -> PathBuf {
    let script = base_dir.join("importer.py");
    fs::write(&script, PYTHON_SOURCE).expect("Write python dataset downloader");
    script
}
/// Ensures a Python 3 virtual environment exists at `<base_dir>/venv` and installs the packages
/// the importer script needs into it. Returns the path of the venv's Python interpreter.
///
/// The venv is created with the first system Python 3 found only when its interpreter does not
/// already report a 3.x version; `pip install` runs on every call.
///
/// # Errors
///
/// * [`ImporterError::PythonNotInstalled`] if no system Python 3 is available to create the venv.
/// * [`ImporterError::VenvNotInitialized`] if the venv interpreter is unusable after creation.
/// * [`ImporterError::FailToDownloadPythonDependencies`] if `venv` or `pip` cannot be launched
///   or exits unsuccessfully.
/// * [`ImporterError::Unknown`] if the venv interpreter path is not valid UTF-8.
fn install_python_deps(base_dir: &Path) -> Result<PathBuf, ImporterError> {
    /// Runs `command` with inherited stdio and waits for it, mapping a launch failure or a
    /// non-zero exit status to `FailToDownloadPythonDependencies`.
    fn run_to_completion(command: &mut Command, step: &str) -> Result<(), ImporterError> {
        let status = command.status().map_err(|err| {
            ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err))
        })?;
        if status.success() {
            Ok(())
        } else {
            Err(ImporterError::FailToDownloadPythonDependencies(format!(
                "`{step}` exited with {status}"
            )))
        }
    }

    let venv_dir = base_dir.join("venv");
    let venv_python_path = venv_dir.join(VENV_BIN_PYTHON);
    // `check_python_version_is_3` takes `&str`; report a non-UTF-8 path instead of panicking.
    let venv_python = venv_python_path.to_str().ok_or_else(|| {
        ImporterError::Unknown(format!(
            "venv python path is not valid UTF-8: {}",
            venv_python_path.display()
        ))
    })?;

    if !check_python_version_is_3(venv_python) {
        let python_name = get_python_name()?;
        let mut command = Command::new(python_name);
        // The directory is passed as an OsStr, so no UTF-8 conversion is needed.
        command.arg("-m").arg("venv").arg(&venv_dir);
        run_to_completion(&mut command, "python -m venv")?;
        if !check_python_version_is_3(venv_python) {
            return Err(ImporterError::VenvNotInitialized);
        }
    }

    let mut command = Command::new(&venv_python_path);
    command.args([
        "-m",
        "pip",
        "--quiet",
        "install",
        "pyarrow",
        "sqlalchemy",
        "Pillow",
        "soundfile",
        "datasets",
    ]);
    // A failed install (e.g. a venv created without pip) is reported here rather than
    // surfacing later as an obscure failure of the importer script.
    run_to_completion(&mut command, "pip install")?;

    Ok(venv_python_path)
}