1use std::path::Path;
4
5use crate::error::{Error, Result};
6
/// Base URL of the Hugging Face Hub REST API; endpoint paths such as
/// `/repos/create` and `/datasets/{repo}/commit/main` are appended to it.
pub(crate) const HF_API_URL: &str = "https://huggingface.co/api";

/// Publishes files to a dataset repository on the Hugging Face Hub.
///
/// `Debug` is implemented by hand (instead of derived) so that the API token
/// never ends up in logs or panic messages.
#[derive(Clone)]
pub struct HfPublisher {
    /// Target repository, either `name` or `org/name`.
    repo_id: String,
    /// API token sent as a bearer credential; `None` makes uploads fail.
    token: Option<String>,
    /// Whether a repository created by `create_repo` is private.
    private: bool,
    /// Summary line used for every commit.
    commit_message: String,
}

impl std::fmt::Debug for HfPublisher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HfPublisher")
            .field("repo_id", &self.repo_id)
            // Only reveal whether a token is present, never its value.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .field("private", &self.private)
            .field("commit_message", &self.commit_message)
            .finish()
    }
}
35
36impl HfPublisher {
37 pub fn new(repo_id: impl Into<String>) -> Self {
39 Self {
40 repo_id: repo_id.into(),
41 token: std::env::var("HF_TOKEN").ok(),
42 private: false,
43 commit_message: "Upload via alimentar".to_string(),
44 }
45 }
46
47 #[must_use]
49 pub fn with_token(mut self, token: impl Into<String>) -> Self {
50 self.token = Some(token.into());
51 self
52 }
53
54 #[must_use]
56 pub fn with_private(mut self, private: bool) -> Self {
57 self.private = private;
58 self
59 }
60
61 #[must_use]
63 pub fn with_commit_message(mut self, message: impl Into<String>) -> Self {
64 self.commit_message = message.into();
65 self
66 }
67
68 pub fn repo_id(&self) -> &str {
70 &self.repo_id
71 }
72
73 #[cfg(feature = "http")]
75 pub async fn create_repo(&self) -> Result<()> {
76 let token = self.token.as_ref().ok_or_else(|| {
77 Error::io_no_path(std::io::Error::new(
78 std::io::ErrorKind::PermissionDenied,
79 "HF_TOKEN required for upload",
80 ))
81 })?;
82
83 let (org, name) = if let Some(slash_pos) = self.repo_id.find('/') {
85 let org = &self.repo_id[..slash_pos];
86 let name = &self.repo_id[slash_pos + 1..];
87 (Some(org), name)
88 } else {
89 (None, self.repo_id.as_str())
90 };
91
92 let client = reqwest::Client::new();
93 let url = format!("{}/repos/create", HF_API_URL);
94
95 let mut body = serde_json::json!({
96 "type": "dataset",
97 "name": name,
98 "private": self.private
99 });
100
101 if let Some(org_name) = org {
103 body["organization"] = serde_json::json!(org_name);
104 }
105
106 let response = client
107 .post(&url)
108 .header("Authorization", format!("Bearer {}", token))
109 .json(&body)
110 .send()
111 .await
112 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
113
114 if response.status().is_success() || response.status().as_u16() == 409 {
116 Ok(())
117 } else {
118 let status = response.status();
119 let body = response.text().await.unwrap_or_default();
120 Err(Error::io_no_path(std::io::Error::other(format!(
121 "Failed to create repo: {} - {}",
122 status, body
123 ))))
124 }
125 }
126
127 #[cfg(feature = "hf-hub")]
136 pub async fn upload_file(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
137 if is_binary_file(path_in_repo) {
138 self.upload_file_lfs(path_in_repo, data).await
139 } else {
140 self.upload_file_direct(path_in_repo, data).await
141 }
142 }
143
144 #[cfg(feature = "hf-hub")]
146 async fn upload_file_direct(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
147 let token = self.token.as_ref().ok_or_else(|| {
148 Error::io_no_path(std::io::Error::new(
149 std::io::ErrorKind::PermissionDenied,
150 "HF_TOKEN required for upload",
151 ))
152 })?;
153
154 let client = reqwest::Client::new();
155 let url = format!("{}/datasets/{}/commit/main", HF_API_URL, self.repo_id);
156
157 let ndjson_payload = build_ndjson_upload_payload(&self.commit_message, path_in_repo, data);
158
159 let response = client
160 .post(&url)
161 .header("Authorization", format!("Bearer {}", token))
162 .header("Content-Type", "application/x-ndjson")
163 .body(ndjson_payload)
164 .send()
165 .await
166 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
167
168 if response.status().is_success() {
169 Ok(())
170 } else {
171 let status = response.status();
172 let body = response.text().await.unwrap_or_default();
173 Err(Error::io_no_path(std::io::Error::other(format!(
174 "Failed to upload: {} - {}",
175 status, body
176 ))))
177 }
178 }
179
180 #[cfg(feature = "hf-hub")]
188 async fn upload_file_lfs(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
189 let token = self.token.as_ref().ok_or_else(|| {
190 Error::io_no_path(std::io::Error::new(
191 std::io::ErrorKind::PermissionDenied,
192 "HF_TOKEN required for upload",
193 ))
194 })?;
195
196 let client = reqwest::Client::new();
197
198 let oid = compute_sha256(data);
200 let size = data.len();
201
202 let batch_url = format!(
204 "https://huggingface.co/datasets/{}.git/info/lfs/objects/batch",
205 self.repo_id
206 );
207 let batch_body = build_lfs_batch_request(&oid, size);
208
209 let batch_response = client
210 .post(&batch_url)
211 .header("Authorization", format!("Bearer {}", token))
212 .header("Content-Type", "application/json")
213 .header("Accept", "application/vnd.git-lfs+json")
214 .body(batch_body)
215 .send()
216 .await
217 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
218
219 if !batch_response.status().is_success() {
220 let status = batch_response.status();
221 let body = batch_response.text().await.unwrap_or_default();
222 return Err(Error::io_no_path(std::io::Error::other(format!(
223 "LFS batch API failed: {} - {}",
224 status, body
225 ))));
226 }
227
228 let batch_json: serde_json::Value = batch_response
229 .json()
230 .await
231 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
232
233 let objects = batch_json["objects"].as_array().ok_or_else(|| {
235 Error::io_no_path(std::io::Error::other("Invalid LFS batch response"))
236 })?;
237
238 let object = objects
239 .first()
240 .ok_or_else(|| Error::io_no_path(std::io::Error::other("No object in LFS response")))?;
241
242 let upload_action = object.get("actions").and_then(|a| a.get("upload"));
244
245 if let Some(upload) = upload_action {
246 let upload_url = upload["href"].as_str().ok_or_else(|| {
247 Error::io_no_path(std::io::Error::other("No upload URL in LFS response"))
248 })?;
249
250 let upload_response = client
252 .put(upload_url)
253 .header("Content-Type", "application/octet-stream")
254 .body(data.to_vec())
255 .send()
256 .await
257 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
258
259 if !upload_response.status().is_success() {
260 let status = upload_response.status();
261 let body = upload_response.text().await.unwrap_or_default();
262 return Err(Error::io_no_path(std::io::Error::other(format!(
263 "LFS S3 upload failed: {} - {}",
264 status, body
265 ))));
266 }
267 }
268 let commit_url = format!("{}/datasets/{}/commit/main", HF_API_URL, self.repo_id);
272 let commit_payload =
273 build_ndjson_lfs_commit(&self.commit_message, path_in_repo, &oid, size);
274
275 let commit_response = client
276 .post(&commit_url)
277 .header("Authorization", format!("Bearer {}", token))
278 .header("Content-Type", "application/x-ndjson")
279 .body(commit_payload)
280 .send()
281 .await
282 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
283
284 if commit_response.status().is_success() {
285 Ok(())
286 } else {
287 let status = commit_response.status();
288 let body = commit_response.text().await.unwrap_or_default();
289 Err(Error::io_no_path(std::io::Error::other(format!(
290 "LFS commit failed: {} - {}",
291 status, body
292 ))))
293 }
294 }
295
296 #[cfg(feature = "hf-hub")]
298 pub async fn upload_batch(
299 &self,
300 path_in_repo: &str,
301 batch: &arrow::record_batch::RecordBatch,
302 ) -> Result<()> {
303 use parquet::arrow::ArrowWriter;
304
305 let mut buffer = Vec::new();
307 {
308 let mut writer =
309 ArrowWriter::try_new(&mut buffer, batch.schema(), None).map_err(Error::Parquet)?;
310 writer.write(batch).map_err(Error::Parquet)?;
311 writer.close().map_err(Error::Parquet)?;
312 }
313
314 self.upload_file(path_in_repo, &buffer).await
315 }
316
317 #[cfg(feature = "hf-hub")]
319 pub async fn upload_parquet_file(&self, local_path: &Path, path_in_repo: &str) -> Result<()> {
320 let data = std::fs::read(local_path).map_err(|e| Error::io(e, local_path))?;
321 self.upload_file(path_in_repo, &data).await
322 }
323
324 #[cfg(all(feature = "http", feature = "tokio-runtime"))]
326 pub fn create_repo_sync(&self) -> Result<()> {
327 tokio::runtime::Runtime::new()
328 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
329 .block_on(self.create_repo())
330 }
331
332 #[cfg(all(feature = "hf-hub", feature = "tokio-runtime"))]
334 pub fn upload_file_sync(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
335 tokio::runtime::Runtime::new()
336 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
337 .block_on(self.upload_file(path_in_repo, data))
338 }
339
340 #[cfg(all(feature = "hf-hub", feature = "tokio-runtime"))]
342 pub fn upload_parquet_file_sync(&self, local_path: &Path, path_in_repo: &str) -> Result<()> {
343 tokio::runtime::Runtime::new()
344 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
345 .block_on(self.upload_parquet_file(local_path, path_in_repo))
346 }
347
348 #[cfg(feature = "hf-hub")]
357 pub async fn upload_readme_validated(&self, content: &str) -> Result<()> {
358 super::validation::DatasetCardValidator::validate_readme_strict(content)?;
359 self.upload_file("README.md", content.as_bytes()).await
360 }
361
362 #[cfg(all(feature = "hf-hub", feature = "tokio-runtime"))]
364 pub fn upload_readme_validated_sync(&self, content: &str) -> Result<()> {
365 tokio::runtime::Runtime::new()
366 .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
367 .block_on(self.upload_readme_validated(content))
368 }
369}
370
/// Builder for `HfPublisher`.
///
/// `Debug` is implemented by hand (instead of derived) so that the API token
/// never ends up in logs or panic messages.
#[derive(Clone)]
pub struct HfPublisherBuilder {
    /// Target repository, either `name` or `org/name`.
    repo_id: String,
    /// Explicitly supplied token; `None` defers to `HF_TOKEN` at build time.
    token: Option<String>,
    /// Whether a created repository should be private.
    private: bool,
    /// Summary line used for every commit.
    commit_message: String,
}

impl std::fmt::Debug for HfPublisherBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HfPublisherBuilder")
            .field("repo_id", &self.repo_id)
            // Only reveal whether a token is present, never its value.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .field("private", &self.private)
            .field("commit_message", &self.commit_message)
            .finish()
    }
}
379
380impl HfPublisherBuilder {
381 pub fn new(repo_id: impl Into<String>) -> Self {
383 Self {
384 repo_id: repo_id.into(),
385 token: None,
386 private: false,
387 commit_message: "Upload via alimentar".to_string(),
388 }
389 }
390
391 #[must_use]
393 pub fn token(mut self, token: impl Into<String>) -> Self {
394 self.token = Some(token.into());
395 self
396 }
397
398 #[must_use]
400 pub fn private(mut self, private: bool) -> Self {
401 self.private = private;
402 self
403 }
404
405 #[must_use]
407 pub fn commit_message(mut self, message: impl Into<String>) -> Self {
408 self.commit_message = message.into();
409 self
410 }
411
412 pub fn build(self) -> HfPublisher {
414 HfPublisher {
415 repo_id: self.repo_id,
416 token: self.token.or_else(|| std::env::var("HF_TOKEN").ok()),
417 private: self.private,
418 commit_message: self.commit_message,
419 }
420 }
421}
422
/// Builds the two-line NDJSON body for a Hub commit that adds one file inline.
///
/// Line 1 is the commit `header` (`summary` = `commit_message`, empty
/// `description`). Line 2 is a `file` operation carrying `data` base64-encoded
/// under `path_in_repo`.
#[cfg(feature = "hf-hub")]
pub fn build_ndjson_upload_payload(
    commit_message: &str,
    path_in_repo: &str,
    data: &[u8],
) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine};

    let encoded = STANDARD.encode(data);

    let header_line = serde_json::json!({
        "key": "header",
        "value": { "summary": commit_message, "description": "" }
    })
    .to_string();
    let file_line = serde_json::json!({
        "key": "file",
        "value": { "content": encoded, "path": path_in_repo, "encoding": "base64" }
    })
    .to_string();

    let mut payload = header_line;
    payload.push('\n');
    payload.push_str(&file_line);
    payload
}
481
/// File extensions (lowercase, without the leading dot) whose files are
/// uploaded through Git LFS. `is_binary_file` compares them ASCII
/// case-insensitively.
const BINARY_EXTENSIONS: &[&str] = &[
    // Columnar, tensor and model formats
    "parquet",
    "arrow",
    "feather",
    "bin",
    "safetensors",
    "pt",
    "pth",
    "ckpt",
    "onnx",
    "gguf",
    "tflite",
    "msgpack",
    // Images
    "png",
    "jpg",
    "jpeg",
    "gif",
    "webp",
    "bmp",
    "tiff",
    // Audio and video
    "mp3",
    "wav",
    "flac",
    "ogg",
    "mp4",
    "webm",
    "avi",
    "mkv",
    // Archives and compression
    "zip",
    "tar",
    "gz",
    "tgz",
    "bz2",
    "xz",
    "zst",
    "lz4",
    "7z",
    "rar",
    // Office documents
    "pdf",
    "doc",
    "docx",
    "xls",
    "xlsx",
    // Array and serialized-object formats
    "npy",
    "npz",
    "h5",
    "hdf5",
    "pkl",
    "pickle",
];

/// Returns `true` when `path` should be uploaded through Git LFS.
///
/// That is the case when the extension of its final component is in
/// `BINARY_EXTENSIONS`, compared ASCII case-insensitively. Only a real
/// extension counts: a file without one (for example `bin` or `models/pt`)
/// is never treated as binary, and neither is a dotfile such as `.gz`.
pub fn is_binary_file(path: &str) -> bool {
    Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| {
            BINARY_EXTENSIONS
                .iter()
                .any(|known| known.eq_ignore_ascii_case(ext))
        })
}
540
/// Returns the SHA-256 digest of `data` as a hex string.
///
/// The upload code uses this value as the Git LFS object id (`oid`).
#[cfg(feature = "hf-hub")]
pub fn compute_sha256(data: &[u8]) -> String {
    use sha2::{Digest, Sha256};

    // One-shot hashing, equivalent to new() + update() + finalize().
    let digest = Sha256::digest(data);
    hex::encode(digest)
}
552
/// Builds the JSON body of a pre-upload request for a single file.
///
/// The file entry carries `path`, the total `size` of `data`, and a `sample`.
/// The sample is the base64 encoding of at most the first 512 bytes.
#[cfg(feature = "hf-hub")]
pub fn build_lfs_preupload_request(path: &str, data: &[u8]) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine};

    const SAMPLE_LEN: usize = 512;
    let head = &data[..data.len().min(SAMPLE_LEN)];

    let entry = serde_json::json!({
        "path": path,
        "size": data.len(),
        "sample": STANDARD.encode(head)
    });

    serde_json::json!({ "files": [entry] }).to_string()
}
581
/// Builds the JSON body of a Git LFS batch `upload` request.
///
/// The request covers one object (`oid`, `size`) and the `basic` transfer adapter.
#[cfg(feature = "hf-hub")]
pub fn build_lfs_batch_request(oid: &str, size: usize) -> String {
    let object = serde_json::json!({ "oid": oid, "size": size });

    serde_json::json!({
        "operation": "upload",
        "transfers": ["basic"],
        "objects": [object]
    })
    .to_string()
}
608
/// Builds the two-line NDJSON body for a Hub commit that references an
/// already-uploaded LFS object.
///
/// Line 1 is the commit `header`. Line 2 is an `lfsFile` operation naming
/// `path_in_repo`, the `sha256` algorithm, the object `oid` and its `size`.
#[cfg(feature = "hf-hub")]
pub fn build_ndjson_lfs_commit(
    commit_message: &str,
    path_in_repo: &str,
    oid: &str,
    size: usize,
) -> String {
    let lines = [
        serde_json::json!({
            "key": "header",
            "value": { "summary": commit_message, "description": "" }
        }),
        serde_json::json!({
            "key": "lfsFile",
            "value": { "path": path_in_repo, "algo": "sha256", "oid": oid, "size": size }
        }),
    ];

    lines.map(|line| line.to_string()).join("\n")
}