use std::path::{Path, PathBuf};
use crate::{
backend::{HttpBackend, StorageBackend},
dataset::ArrowDataset,
error::{Error, Result},
};
/// Base URL of the Hugging Face Hub, used for both the `resolve` download
/// endpoint and the `api/datasets` listing endpoint.
const HF_HUB_URL: &str = "https://huggingface.co";
/// A dataset hosted on the Hugging Face Hub, identified by repository,
/// revision, and optional subset/split, with downloads cached locally.
#[derive(Debug, Clone)]
pub struct HfDataset {
    /// Hub repository identifier, e.g. `"user/dataset-name"`.
    repo_id: String,
    /// Git revision (branch, tag, or commit); defaults to `"main"` via the builder.
    revision: String,
    /// Optional dataset configuration/subset; `"default"` is used when absent.
    subset: Option<String>,
    /// Optional split name (e.g. `"train"`); `"train"` is used when absent.
    split: Option<String>,
    /// Root directory for the local download cache.
    cache_dir: PathBuf,
}
impl HfDataset {
pub fn builder(repo_id: impl Into<String>) -> HfDatasetBuilder {
HfDatasetBuilder::new(repo_id)
}
pub fn repo_id(&self) -> &str {
&self.repo_id
}
pub fn revision(&self) -> &str {
&self.revision
}
pub fn subset(&self) -> Option<&str> {
self.subset.as_deref()
}
pub fn split(&self) -> Option<&str> {
self.split.as_deref()
}
pub fn cache_dir(&self) -> &Path {
&self.cache_dir
}
pub fn download(&self) -> Result<ArrowDataset> {
let parquet_path = self.build_parquet_path();
let cache_file = self.cache_path_for(&parquet_path);
if cache_file.exists() {
return ArrowDataset::from_parquet(&cache_file);
}
let url = self.build_download_url(&parquet_path);
let http = HttpBackend::with_timeout(&url, 300)?;
let data = http.get("")?;
if let Some(parent) = cache_file.parent() {
std::fs::create_dir_all(parent).map_err(|e| Error::io(e, parent))?;
}
std::fs::write(&cache_file, &data).map_err(|e| Error::io(e, &cache_file))?;
ArrowDataset::from_parquet(&cache_file)
}
pub fn download_to(&self, output: impl AsRef<Path>) -> Result<ArrowDataset> {
let output = output.as_ref();
let parquet_path = self.build_parquet_path();
let url = self.build_download_url(&parquet_path);
let http = HttpBackend::with_timeout(&url, 300)?;
let data = http.get("")?;
if let Some(parent) = output.parent() {
std::fs::create_dir_all(parent).map_err(|e| Error::io(e, parent))?;
}
std::fs::write(output, &data).map_err(|e| Error::io(e, output))?;
ArrowDataset::from_parquet(output)
}
pub(crate) fn build_parquet_path(&self) -> String {
let mut path_parts = Vec::new();
if let Some(ref subset) = self.subset {
path_parts.push(subset.clone());
} else {
path_parts.push("default".to_string());
}
let split = self.split.as_deref().unwrap_or("train");
path_parts.push(format!("{split}.parquet"));
path_parts.join("/")
}
pub(crate) fn build_download_url(&self, parquet_path: &str) -> String {
format!(
"{}/datasets/{}/resolve/{}/data/{}",
HF_HUB_URL, self.repo_id, self.revision, parquet_path
)
}
pub(crate) fn cache_path_for(&self, parquet_path: &str) -> PathBuf {
self.cache_dir
.join("huggingface")
.join("datasets")
.join(&self.repo_id)
.join(&self.revision)
.join(parquet_path)
}
pub fn clear_cache(&self) -> Result<()> {
let cache_path = self
.cache_dir
.join("huggingface")
.join("datasets")
.join(&self.repo_id);
if cache_path.exists() {
std::fs::remove_dir_all(&cache_path).map_err(|e| Error::io(e, &cache_path))?;
}
Ok(())
}
}
/// Builder for [`HfDataset`]; construct via [`HfDataset::builder`] or
/// [`HfDatasetBuilder::new`].
#[derive(Debug, Clone)]
pub struct HfDatasetBuilder {
    /// Hub repository identifier; must be non-empty at `build` time.
    repo_id: String,
    /// Git revision; initialized to `"main"`.
    revision: String,
    /// Optional subset/configuration name.
    subset: Option<String>,
    /// Optional split name.
    split: Option<String>,
    /// Optional cache root; `build` falls back to `default_cache_dir()`.
    cache_dir: Option<PathBuf>,
}
impl HfDatasetBuilder {
pub fn new(repo_id: impl Into<String>) -> Self {
Self {
repo_id: repo_id.into(),
revision: "main".to_string(),
subset: None,
split: None,
cache_dir: None,
}
}
#[must_use]
pub fn revision(mut self, revision: impl Into<String>) -> Self {
self.revision = revision.into();
self
}
#[must_use]
pub fn subset(mut self, subset: impl Into<String>) -> Self {
self.subset = Some(subset.into());
self
}
#[must_use]
pub fn split(mut self, split: impl Into<String>) -> Self {
self.split = Some(split.into());
self
}
#[must_use]
pub fn cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
self.cache_dir = Some(path.into());
self
}
pub fn build(self) -> Result<HfDataset> {
if self.repo_id.is_empty() {
return Err(Error::invalid_config("Repository ID cannot be empty"));
}
let cache_dir = self.cache_dir.unwrap_or_else(default_cache_dir);
Ok(HfDataset {
repo_id: self.repo_id,
revision: self.revision,
subset: self.subset,
split: self.split,
cache_dir,
})
}
}
/// Resolves the platform-appropriate cache directory for alimentar.
///
/// Resolution order:
/// - Windows: `%LOCALAPPDATA%\alimentar\cache`
/// - Unix: `$XDG_CACHE_HOME/alimentar`, then `$HOME/.cache/alimentar`
/// - Fallback (either platform): `<temp_dir>/alimentar/cache`
///
/// Env vars that are set but empty are treated as unset, per the XDG Base
/// Directory specification — otherwise an empty `XDG_CACHE_HOME` would yield
/// a bogus relative path like `"alimentar"`.
pub(crate) fn default_cache_dir() -> PathBuf {
    #[cfg(target_os = "windows")]
    {
        if let Ok(local_app_data) = std::env::var("LOCALAPPDATA") {
            if !local_app_data.is_empty() {
                return PathBuf::from(local_app_data)
                    .join("alimentar")
                    .join("cache");
            }
        }
    }
    #[cfg(not(target_os = "windows"))]
    {
        if let Ok(xdg_cache) = std::env::var("XDG_CACHE_HOME") {
            if !xdg_cache.is_empty() {
                return PathBuf::from(xdg_cache).join("alimentar");
            }
        }
        if let Ok(home) = std::env::var("HOME") {
            if !home.is_empty() {
                return PathBuf::from(home).join(".cache").join("alimentar");
            }
        }
    }
    // Last resort: a world-writable but always-present location.
    std::env::temp_dir().join("alimentar").join("cache")
}
/// Lists the `.parquet` files in a Hub dataset repository's tree at the
/// given revision (defaulting to `"main"`).
///
/// A response that is not a JSON array yields an empty list rather than an
/// error (matching the original best-effort behavior).
///
/// # Errors
/// Returns an error if the HTTP request fails or the response body is not
/// valid JSON.
pub fn list_dataset_files(repo_id: &str, revision: Option<&str>) -> Result<Vec<String>> {
    let rev = revision.unwrap_or("main");
    let url = format!("{HF_HUB_URL}/api/datasets/{repo_id}/tree/{rev}");
    let http = HttpBackend::with_timeout(&url, 30)?;
    let body = http.get("")?;
    let json: serde_json::Value = serde_json::from_slice(&body)
        .map_err(|e| Error::storage(format!("Failed to parse HF Hub response: {e}")))?;
    let parquet_files = json
        .as_array()
        .map(|items| {
            items
                .iter()
                .filter_map(|item| item.get("path").and_then(|p| p.as_str()))
                .filter(|path| path.ends_with(".parquet"))
                .map(str::to_string)
                .collect()
        })
        .unwrap_or_default();
    Ok(parquet_files)
}
/// Summary metadata about a Hub dataset repository.
///
/// NOTE(review): no constructor or populating code is visible in this chunk;
/// field semantics below are inferred from names — confirm against the code
/// that fills this struct.
#[derive(Debug, Clone)]
pub struct DatasetInfo {
    /// Hub repository identifier, e.g. `"user/dataset-name"`.
    pub repo_id: String,
    /// Available split names (presumably e.g. `"train"`, `"test"`).
    pub splits: Vec<String>,
    /// Available subset/configuration names.
    pub subsets: Vec<String>,
    /// Total download size in bytes, when reported by the Hub.
    pub download_size: Option<u64>,
    /// Free-form dataset description, when available.
    pub description: Option<String>,
}