use std::path::Path;
use crate::error::{Error, Result};
/// Base URL of the Hugging Face Hub REST API.
pub(crate) const HF_API_URL: &str = "https://huggingface.co/api";

/// Publishes dataset files to a Hugging Face Hub *dataset* repository.
///
/// Construct with [`HfPublisher::new`] (reads `HF_TOKEN` from the
/// environment) or via [`HfPublisherBuilder`].
#[derive(Debug, Clone)]
pub struct HfPublisher {
    // Target repo, either "org/name" or a bare "name".
    repo_id: String,
    // Bearer token for the Hub API; `None` means unauthenticated, in which
    // case every upload/create call fails with PermissionDenied.
    token: Option<String>,
    // Visibility used when creating the repository.
    private: bool,
    // Commit summary attached to every upload commit.
    commit_message: String,
}
impl HfPublisher {
    /// Creates a publisher for `repo_id` ("org/name" or bare "name").
    ///
    /// The token is taken from the `HF_TOKEN` environment variable if set;
    /// the repo defaults to public with a default commit message.
    pub fn new(repo_id: impl Into<String>) -> Self {
        Self {
            repo_id: repo_id.into(),
            token: std::env::var("HF_TOKEN").ok(),
            private: false,
            commit_message: "Upload via alimentar".to_string(),
        }
    }

    /// Overrides the API token (takes precedence over `HF_TOKEN`).
    #[must_use]
    pub fn with_token(mut self, token: impl Into<String>) -> Self {
        self.token = Some(token.into());
        self
    }

    /// Sets whether a newly created repository is private.
    #[must_use]
    pub fn with_private(mut self, private: bool) -> Self {
        self.private = private;
        self
    }

    /// Sets the commit summary used for subsequent uploads.
    #[must_use]
    pub fn with_commit_message(mut self, message: impl Into<String>) -> Self {
        self.commit_message = message.into();
        self
    }

    /// Returns the target repository id.
    pub fn repo_id(&self) -> &str {
        &self.repo_id
    }

    /// Creates the dataset repository on the Hub via `POST /repos/create`.
    ///
    /// An HTTP 409 (repo already exists) is treated as success so the call
    /// is idempotent.
    ///
    /// # Errors
    ///
    /// Fails with `PermissionDenied` when no token is configured, or with an
    /// I/O-wrapped error on transport failure / non-2xx, non-409 responses.
    #[cfg(feature = "http")]
    pub async fn create_repo(&self) -> Result<()> {
        let token = self.token.as_ref().ok_or_else(|| {
            Error::io_no_path(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "HF_TOKEN required for upload",
            ))
        })?;
        // Split "org/name" on the first slash; a bare name has no org and
        // is created under the token's user namespace.
        let (org, name) = if let Some(slash_pos) = self.repo_id.find('/') {
            let org = &self.repo_id[..slash_pos];
            let name = &self.repo_id[slash_pos + 1..];
            (Some(org), name)
        } else {
            (None, self.repo_id.as_str())
        };
        let client = reqwest::Client::new();
        let url = format!("{}/repos/create", HF_API_URL);
        let mut body = serde_json::json!({
            "type": "dataset",
            "name": name,
            "private": self.private
        });
        if let Some(org_name) = org {
            body["organization"] = serde_json::json!(org_name);
        }
        let response = client
            .post(&url)
            .header("Authorization", format!("Bearer {}", token))
            .json(&body)
            .send()
            .await
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
        // 409 means the repo already exists — idempotent success.
        if response.status().is_success() || response.status().as_u16() == 409 {
            Ok(())
        } else {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            Err(Error::io_no_path(std::io::Error::other(format!(
                "Failed to create repo: {} - {}",
                status, body
            ))))
        }
    }

    /// Uploads `data` as `path_in_repo`, choosing the transport by file
    /// extension: binary formats go through Git LFS, everything else is
    /// committed inline (base64 in an NDJSON commit).
    #[cfg(feature = "hf-hub")]
    pub async fn upload_file(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
        if is_binary_file(path_in_repo) {
            self.upload_file_lfs(path_in_repo, data).await
        } else {
            self.upload_file_direct(path_in_repo, data).await
        }
    }

    /// Commits `data` inline via the Hub's NDJSON commit endpoint
    /// (`POST /datasets/{repo}/commit/main`), base64-encoding the content.
    ///
    /// # Errors
    ///
    /// `PermissionDenied` without a token; otherwise transport or non-2xx
    /// failures wrapped as I/O errors.
    #[cfg(feature = "hf-hub")]
    async fn upload_file_direct(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
        let token = self.token.as_ref().ok_or_else(|| {
            Error::io_no_path(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "HF_TOKEN required for upload",
            ))
        })?;
        let client = reqwest::Client::new();
        // NOTE(review): branch is hard-coded to "main" — confirm that is
        // acceptable for all callers.
        let url = format!("{}/datasets/{}/commit/main", HF_API_URL, self.repo_id);
        let ndjson_payload = build_ndjson_upload_payload(&self.commit_message, path_in_repo, data);
        let response = client
            .post(&url)
            .header("Authorization", format!("Bearer {}", token))
            .header("Content-Type", "application/x-ndjson")
            .body(ndjson_payload)
            .send()
            .await
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
        if response.status().is_success() {
            Ok(())
        } else {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            Err(Error::io_no_path(std::io::Error::other(format!(
                "Failed to upload: {} - {}",
                status, body
            ))))
        }
    }

    /// Uploads `data` through Git LFS in three steps:
    /// 1. LFS batch API call announcing the object (sha256 oid + size);
    /// 2. if the server returns an `upload` action, PUT the bytes to the
    ///    returned href (typically a presigned URL — unauthenticated);
    ///    no `upload` action means the object already exists server-side;
    /// 3. NDJSON commit referencing the LFS pointer.
    ///
    /// NOTE(review): any `header` map the batch response attaches to the
    /// upload action is ignored here — confirm the Hub's presigned upload
    /// URLs never require extra headers.
    ///
    /// # Errors
    ///
    /// `PermissionDenied` without a token; otherwise transport, malformed
    /// batch responses, or non-2xx statuses wrapped as I/O errors.
    #[cfg(feature = "hf-hub")]
    async fn upload_file_lfs(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
        let token = self.token.as_ref().ok_or_else(|| {
            Error::io_no_path(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "HF_TOKEN required for upload",
            ))
        })?;
        let client = reqwest::Client::new();
        let oid = compute_sha256(data);
        let size = data.len();
        let batch_url = format!(
            "https://huggingface.co/datasets/{}.git/info/lfs/objects/batch",
            self.repo_id
        );
        let batch_body = build_lfs_batch_request(&oid, size);
        let batch_response = client
            .post(&batch_url)
            .header("Authorization", format!("Bearer {}", token))
            .header("Content-Type", "application/json")
            .header("Accept", "application/vnd.git-lfs+json")
            .body(batch_body)
            .send()
            .await
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
        if !batch_response.status().is_success() {
            let status = batch_response.status();
            let body = batch_response.text().await.unwrap_or_default();
            return Err(Error::io_no_path(std::io::Error::other(format!(
                "LFS batch API failed: {} - {}",
                status, body
            ))));
        }
        let batch_json: serde_json::Value = batch_response
            .json()
            .await
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
        let objects = batch_json["objects"].as_array().ok_or_else(|| {
            Error::io_no_path(std::io::Error::other("Invalid LFS batch response"))
        })?;
        // Only one object was announced, so only the first entry matters.
        let object = objects
            .first()
            .ok_or_else(|| Error::io_no_path(std::io::Error::other("No object in LFS response")))?;
        // Absent "upload" action => object already stored; skip to commit.
        let upload_action = object.get("actions").and_then(|a| a.get("upload"));
        if let Some(upload) = upload_action {
            let upload_url = upload["href"].as_str().ok_or_else(|| {
                Error::io_no_path(std::io::Error::other("No upload URL in LFS response"))
            })?;
            let upload_response = client
                .put(upload_url)
                .header("Content-Type", "application/octet-stream")
                .body(data.to_vec())
                .send()
                .await
                .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
            if !upload_response.status().is_success() {
                let status = upload_response.status();
                let body = upload_response.text().await.unwrap_or_default();
                return Err(Error::io_no_path(std::io::Error::other(format!(
                    "LFS S3 upload failed: {} - {}",
                    status, body
                ))));
            }
        }
        let commit_url = format!("{}/datasets/{}/commit/main", HF_API_URL, self.repo_id);
        let commit_payload =
            build_ndjson_lfs_commit(&self.commit_message, path_in_repo, &oid, size);
        let commit_response = client
            .post(&commit_url)
            .header("Authorization", format!("Bearer {}", token))
            .header("Content-Type", "application/x-ndjson")
            .body(commit_payload)
            .send()
            .await
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
        if commit_response.status().is_success() {
            Ok(())
        } else {
            let status = commit_response.status();
            let body = commit_response.text().await.unwrap_or_default();
            Err(Error::io_no_path(std::io::Error::other(format!(
                "LFS commit failed: {} - {}",
                status, body
            ))))
        }
    }

    /// Serializes a single Arrow `RecordBatch` to an in-memory Parquet file
    /// and uploads it as `path_in_repo`.
    ///
    /// # Errors
    ///
    /// Parquet serialization errors, plus everything [`Self::upload_file`]
    /// can return.
    #[cfg(feature = "hf-hub")]
    pub async fn upload_batch(
        &self,
        path_in_repo: &str,
        batch: &arrow::record_batch::RecordBatch,
    ) -> Result<()> {
        use parquet::arrow::ArrowWriter;
        let mut buffer = Vec::new();
        {
            // Scope the writer so the mutable borrow of `buffer` ends
            // before the upload below.
            let mut writer =
                ArrowWriter::try_new(&mut buffer, batch.schema(), None).map_err(Error::Parquet)?;
            writer.write(batch).map_err(Error::Parquet)?;
            writer.close().map_err(Error::Parquet)?;
        }
        self.upload_file(path_in_repo, &buffer).await
    }

    /// Reads a local file and uploads its bytes as `path_in_repo`.
    ///
    /// # Errors
    ///
    /// I/O errors reading `local_path`, plus upload failures.
    #[cfg(feature = "hf-hub")]
    pub async fn upload_parquet_file(&self, local_path: &Path, path_in_repo: &str) -> Result<()> {
        let data = std::fs::read(local_path).map_err(|e| Error::io(e, local_path))?;
        self.upload_file(path_in_repo, &data).await
    }

    /// Blocking wrapper around [`Self::create_repo`].
    ///
    /// Spawns a fresh Tokio runtime per call. NOTE(review): this will panic
    /// if invoked from inside an existing async runtime — confirm callers
    /// are synchronous-only.
    #[cfg(all(feature = "http", feature = "tokio-runtime"))]
    pub fn create_repo_sync(&self) -> Result<()> {
        tokio::runtime::Runtime::new()
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
            .block_on(self.create_repo())
    }

    /// Blocking wrapper around [`Self::upload_file`] (fresh runtime per call).
    #[cfg(all(feature = "hf-hub", feature = "tokio-runtime"))]
    pub fn upload_file_sync(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
        tokio::runtime::Runtime::new()
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
            .block_on(self.upload_file(path_in_repo, data))
    }

    /// Blocking wrapper around [`Self::upload_parquet_file`] (fresh runtime per call).
    #[cfg(all(feature = "hf-hub", feature = "tokio-runtime"))]
    pub fn upload_parquet_file_sync(&self, local_path: &Path, path_in_repo: &str) -> Result<()> {
        tokio::runtime::Runtime::new()
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
            .block_on(self.upload_parquet_file(local_path, path_in_repo))
    }

    /// Validates `content` as a dataset card (strict mode) and, only on
    /// success, uploads it as `README.md`.
    ///
    /// # Errors
    ///
    /// Validation failures from `DatasetCardValidator`, plus upload errors.
    #[cfg(feature = "hf-hub")]
    pub async fn upload_readme_validated(&self, content: &str) -> Result<()> {
        super::validation::DatasetCardValidator::validate_readme_strict(content)?;
        self.upload_file("README.md", content.as_bytes()).await
    }

    /// Blocking wrapper around [`Self::upload_readme_validated`] (fresh runtime per call).
    #[cfg(all(feature = "hf-hub", feature = "tokio-runtime"))]
    pub fn upload_readme_validated_sync(&self, content: &str) -> Result<()> {
        tokio::runtime::Runtime::new()
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
            .block_on(self.upload_readme_validated(content))
    }
}
/// Fluent builder for [`HfPublisher`].
///
/// Unlike [`HfPublisher::new`], the `HF_TOKEN` environment variable is only
/// consulted at [`HfPublisherBuilder::build`] time, and only when no token
/// was set explicitly.
#[derive(Debug, Clone)]
pub struct HfPublisherBuilder {
    repo_id: String,
    token: Option<String>,
    private: bool,
    commit_message: String,
}

impl HfPublisherBuilder {
    /// Starts a builder targeting `repo_id` with public visibility, no
    /// explicit token, and the default commit message.
    pub fn new(repo_id: impl Into<String>) -> Self {
        Self {
            repo_id: repo_id.into(),
            token: None,
            private: false,
            commit_message: String::from("Upload via alimentar"),
        }
    }

    /// Sets an explicit API token.
    #[must_use]
    pub fn token(self, token: impl Into<String>) -> Self {
        Self {
            token: Some(token.into()),
            ..self
        }
    }

    /// Sets whether the repository should be private.
    #[must_use]
    pub fn private(self, private: bool) -> Self {
        Self { private, ..self }
    }

    /// Sets the commit summary for uploads.
    #[must_use]
    pub fn commit_message(self, message: impl Into<String>) -> Self {
        Self {
            commit_message: message.into(),
            ..self
        }
    }

    /// Finalizes the builder into an [`HfPublisher`], falling back to the
    /// `HF_TOKEN` environment variable when no token was provided.
    pub fn build(self) -> HfPublisher {
        let Self {
            repo_id,
            token,
            private,
            commit_message,
        } = self;
        HfPublisher {
            repo_id,
            token: token.or_else(|| std::env::var("HF_TOKEN").ok()),
            private,
            commit_message,
        }
    }
}
/// Builds the two-line NDJSON body for an inline file commit: a `header`
/// record carrying the commit summary, followed by a `file` record with the
/// base64-encoded content.
#[cfg(feature = "hf-hub")]
pub fn build_ndjson_upload_payload(
    commit_message: &str,
    path_in_repo: &str,
    data: &[u8],
) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine};
    let header = serde_json::json!({
        "key": "header",
        "value": {
            "summary": commit_message,
            "description": ""
        }
    })
    .to_string();
    let file_op = serde_json::json!({
        "key": "file",
        "value": {
            "content": STANDARD.encode(data),
            "path": path_in_repo,
            "encoding": "base64"
        }
    })
    .to_string();
    [header, file_op].join("\n")
}
/// File extensions routed through Git LFS instead of inline commits
/// (binary / large-object formats).
const BINARY_EXTENSIONS: &[&str] = &[
    "parquet", "arrow", "bin", "safetensors", "pt", "pth", "onnx", "png",
    "jpg", "jpeg", "gif", "webp", "bmp", "tiff", "mp3", "wav", "flac", "ogg",
    "mp4", "webm", "avi", "mkv", "zip", "tar", "gz", "bz2", "xz", "7z", "rar",
    "pdf", "doc", "docx", "xls", "xlsx", "npy", "npz", "h5", "hdf5", "pkl",
    "pickle",
];

/// Returns `true` when `path` has a real file extension listed in
/// [`BINARY_EXTENSIONS`] (case-insensitive), meaning the file should be
/// uploaded through LFS.
///
/// Uses [`Path::extension`] so only a genuine extension counts: the previous
/// `rsplit('.')` approach returned the whole name for dot-less paths, so a
/// file literally named `gz` or `tar` was misclassified as binary. The
/// comparison is done with `eq_ignore_ascii_case` to avoid allocating a
/// lowercased copy per call.
pub fn is_binary_file(path: &str) -> bool {
    Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| {
            BINARY_EXTENSIONS
                .iter()
                .any(|known| ext.eq_ignore_ascii_case(known))
        })
}
/// Computes the SHA-256 digest of `data`, returned as a lowercase hex
/// string (the Git LFS object id format).
#[cfg(feature = "hf-hub")]
pub fn compute_sha256(data: &[u8]) -> String {
    use sha2::{Digest, Sha256};
    // One-shot digest; equivalent to new() + update() + finalize().
    hex::encode(Sha256::digest(data))
}
/// Builds the JSON body for the Hub's preupload endpoint: one file entry
/// with its path, total size, and a base64 sample of at most the first
/// 512 bytes.
#[cfg(feature = "hf-hub")]
pub fn build_lfs_preupload_request(path: &str, data: &[u8]) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine};
    // Only a prefix of the content is sent as the sniffing sample.
    let sample_len = data.len().min(512);
    serde_json::json!({
        "files": [{
            "path": path,
            "size": data.len(),
            "sample": STANDARD.encode(&data[..sample_len])
        }]
    })
    .to_string()
}
/// Builds the Git LFS batch API request body announcing a single object
/// (`oid` = sha256 hex, `size` in bytes) for upload over the basic transfer.
#[cfg(feature = "hf-hub")]
pub fn build_lfs_batch_request(oid: &str, size: usize) -> String {
    serde_json::json!({
        "operation": "upload",
        "transfers": ["basic"],
        "objects": [{
            "oid": oid,
            "size": size
        }]
    })
    .to_string()
}
/// Builds the two-line NDJSON commit body for an LFS-tracked file: a
/// `header` record with the commit summary, then an `lfsFile` pointer record
/// (path, sha256 oid, size) instead of inline content.
#[cfg(feature = "hf-hub")]
pub fn build_ndjson_lfs_commit(
    commit_message: &str,
    path_in_repo: &str,
    oid: &str,
    size: usize,
) -> String {
    let records = [
        serde_json::json!({
            "key": "header",
            "value": {
                "summary": commit_message,
                "description": ""
            }
        }),
        serde_json::json!({
            "key": "lfsFile",
            "value": {
                "path": path_in_repo,
                "algo": "sha256",
                "oid": oid,
                "size": size
            }
        }),
    ];
    records
        .iter()
        .map(ToString::to_string)
        .collect::<Vec<_>>()
        .join("\n")
}