Skip to main content

alimentar/hf_hub/
upload.rs

1//! HuggingFace Hub dataset upload functionality.
2
3use std::path::Path;
4
5use crate::error::{Error, Result};
6
/// Base URL of the HuggingFace Hub REST API, used for repo creation and
/// commit/upload endpoints throughout this module.
pub(crate) const HF_API_URL: &str = "https://huggingface.co/api";
9
/// Publisher for uploading datasets to HuggingFace Hub.
///
/// Construct with [`HfPublisher::new`] (which seeds the token from the
/// `HF_TOKEN` environment variable when present), then adjust settings via
/// the `with_*` builder methods. Upload methods require a token and return
/// an error when none is configured.
///
/// # Example
///
/// ```no_run
/// use alimentar::hf_hub::HfPublisher;
/// use arrow::record_batch::RecordBatch;
///
/// let publisher = HfPublisher::new("paiml/my-dataset")
///     .with_token(std::env::var("HF_TOKEN").unwrap())
///     .with_private(false);
///
/// // publisher.upload_parquet("train.parquet", &batch).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct HfPublisher {
    /// Repository ID (e.g., "paiml/depyler-citl")
    repo_id: String,
    /// HuggingFace API token; `None` means unauthenticated (uploads will fail)
    token: Option<String>,
    /// Whether the dataset should be private
    private: bool,
    /// Commit message used for all uploads from this publisher
    commit_message: String,
}
35
36impl HfPublisher {
37    /// Creates a new publisher for a HuggingFace dataset repository.
38    pub fn new(repo_id: impl Into<String>) -> Self {
39        Self {
40            repo_id: repo_id.into(),
41            token: std::env::var("HF_TOKEN").ok(),
42            private: false,
43            commit_message: "Upload via alimentar".to_string(),
44        }
45    }
46
47    /// Sets the HuggingFace API token.
48    #[must_use]
49    pub fn with_token(mut self, token: impl Into<String>) -> Self {
50        self.token = Some(token.into());
51        self
52    }
53
54    /// Sets whether the dataset should be private.
55    #[must_use]
56    pub fn with_private(mut self, private: bool) -> Self {
57        self.private = private;
58        self
59    }
60
61    /// Sets the commit message for uploads.
62    #[must_use]
63    pub fn with_commit_message(mut self, message: impl Into<String>) -> Self {
64        self.commit_message = message.into();
65        self
66    }
67
68    /// Returns the repository ID.
69    pub fn repo_id(&self) -> &str {
70        &self.repo_id
71    }
72
73    /// Creates the dataset repository on HuggingFace Hub if it doesn't exist.
74    #[cfg(feature = "http")]
75    pub async fn create_repo(&self) -> Result<()> {
76        let token = self.token.as_ref().ok_or_else(|| {
77            Error::io_no_path(std::io::Error::new(
78                std::io::ErrorKind::PermissionDenied,
79                "HF_TOKEN required for upload",
80            ))
81        })?;
82
83        // Split repo_id into org/name components
84        let (org, name) = if let Some(slash_pos) = self.repo_id.find('/') {
85            let org = &self.repo_id[..slash_pos];
86            let name = &self.repo_id[slash_pos + 1..];
87            (Some(org), name)
88        } else {
89            (None, self.repo_id.as_str())
90        };
91
92        let client = reqwest::Client::new();
93        let url = format!("{}/repos/create", HF_API_URL);
94
95        let mut body = serde_json::json!({
96            "type": "dataset",
97            "name": name,
98            "private": self.private
99        });
100
101        // Add organization if present
102        if let Some(org_name) = org {
103            body["organization"] = serde_json::json!(org_name);
104        }
105
106        let response = client
107            .post(&url)
108            .header("Authorization", format!("Bearer {}", token))
109            .json(&body)
110            .send()
111            .await
112            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
113
114        // 409 Conflict means repo already exists, which is fine
115        if response.status().is_success() || response.status().as_u16() == 409 {
116            Ok(())
117        } else {
118            let status = response.status();
119            let body = response.text().await.unwrap_or_default();
120            Err(Error::io_no_path(std::io::Error::other(format!(
121                "Failed to create repo: {} - {}",
122                status, body
123            ))))
124        }
125    }
126
127    /// Uploads a file to the repository.
128    ///
129    /// This method automatically selects the appropriate upload method:
130    /// - **Binary files** (parquet, images, etc.): Uses LFS preupload API
131    /// - **Text files** (README.md, JSON, etc.): Uses direct NDJSON commit API
132    ///
133    /// The official `hf-hub` crate only supports downloads, making this upload
134    /// capability a key differentiator for alimentar.
135    #[cfg(feature = "hf-hub")]
136    pub async fn upload_file(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
137        if is_binary_file(path_in_repo) {
138            self.upload_file_lfs(path_in_repo, data).await
139        } else {
140            self.upload_file_direct(path_in_repo, data).await
141        }
142    }
143
144    /// Uploads a text file directly using the NDJSON commit API.
145    #[cfg(feature = "hf-hub")]
146    async fn upload_file_direct(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
147        let token = self.token.as_ref().ok_or_else(|| {
148            Error::io_no_path(std::io::Error::new(
149                std::io::ErrorKind::PermissionDenied,
150                "HF_TOKEN required for upload",
151            ))
152        })?;
153
154        let client = reqwest::Client::new();
155        let url = format!("{}/datasets/{}/commit/main", HF_API_URL, self.repo_id);
156
157        let ndjson_payload = build_ndjson_upload_payload(&self.commit_message, path_in_repo, data);
158
159        let response = client
160            .post(&url)
161            .header("Authorization", format!("Bearer {}", token))
162            .header("Content-Type", "application/x-ndjson")
163            .body(ndjson_payload)
164            .send()
165            .await
166            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;
167
168        if response.status().is_success() {
169            Ok(())
170        } else {
171            let status = response.status();
172            let body = response.text().await.unwrap_or_default();
173            Err(Error::io_no_path(std::io::Error::other(format!(
174                "Failed to upload: {} - {}",
175                status, body
176            ))))
177        }
178    }
179
    /// Uploads a binary file using LFS batch API.
    ///
    /// Flow:
    /// 1. Compute SHA256 hash of the file content (OID)
    /// 2. POST to LFS batch API to get presigned S3 upload URL
    /// 3. PUT binary content to the S3 URL
    /// 4. POST commit with lfsFile reference
    ///
    /// # Errors
    ///
    /// Returns an error when no token is configured, when any of the batch /
    /// S3 upload / commit HTTP calls fails to send or returns a non-success
    /// status, or when the batch response is missing the expected fields.
    #[cfg(feature = "hf-hub")]
    async fn upload_file_lfs(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
        // Uploads require authentication; fail fast before any network I/O.
        let token = self.token.as_ref().ok_or_else(|| {
            Error::io_no_path(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "HF_TOKEN required for upload",
            ))
        })?;

        let client = reqwest::Client::new();

        // Step 1: Compute SHA256 hash (LFS Object ID)
        let oid = compute_sha256(data);
        let size = data.len();

        // Step 2: Call LFS batch API to get presigned S3 upload URL.
        // Note: the batch endpoint lives on the git remote path, not the REST
        // API base used elsewhere in this module.
        let batch_url = format!(
            "https://huggingface.co/datasets/{}.git/info/lfs/objects/batch",
            self.repo_id
        );
        let batch_body = build_lfs_batch_request(&oid, size);

        let batch_response = client
            .post(&batch_url)
            .header("Authorization", format!("Bearer {}", token))
            .header("Content-Type", "application/json")
            .header("Accept", "application/vnd.git-lfs+json")
            .body(batch_body)
            .send()
            .await
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;

        if !batch_response.status().is_success() {
            let status = batch_response.status();
            let body = batch_response.text().await.unwrap_or_default();
            return Err(Error::io_no_path(std::io::Error::other(format!(
                "LFS batch API failed: {} - {}",
                status, body
            ))));
        }

        let batch_json: serde_json::Value = batch_response
            .json()
            .await
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;

        // Extract S3 upload URL from response: objects[0].actions.upload.href
        let objects = batch_json["objects"].as_array().ok_or_else(|| {
            Error::io_no_path(std::io::Error::other("Invalid LFS batch response"))
        })?;

        // We only submitted one object, so only the first entry matters.
        let object = objects
            .first()
            .ok_or_else(|| Error::io_no_path(std::io::Error::other("No object in LFS response")))?;

        // Check if upload is needed (object might already exist)
        let upload_action = object.get("actions").and_then(|a| a.get("upload"));

        if let Some(upload) = upload_action {
            let upload_url = upload["href"].as_str().ok_or_else(|| {
                Error::io_no_path(std::io::Error::other("No upload URL in LFS response"))
            })?;

            // Step 3: Upload binary content to S3 (presigned URL, no auth header needed)
            let upload_response = client
                .put(upload_url)
                .header("Content-Type", "application/octet-stream")
                .body(data.to_vec())
                .send()
                .await
                .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;

            if !upload_response.status().is_success() {
                let status = upload_response.status();
                let body = upload_response.text().await.unwrap_or_default();
                return Err(Error::io_no_path(std::io::Error::other(format!(
                    "LFS S3 upload failed: {} - {}",
                    status, body
                ))));
            }
        }
        // If no upload action, object already exists in LFS - proceed to commit

        // Step 4: Commit with LFS file reference (OID + size, no inline content)
        let commit_url = format!("{}/datasets/{}/commit/main", HF_API_URL, self.repo_id);
        let commit_payload =
            build_ndjson_lfs_commit(&self.commit_message, path_in_repo, &oid, size);

        let commit_response = client
            .post(&commit_url)
            .header("Authorization", format!("Bearer {}", token))
            .header("Content-Type", "application/x-ndjson")
            .body(commit_payload)
            .send()
            .await
            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?;

        if commit_response.status().is_success() {
            Ok(())
        } else {
            let status = commit_response.status();
            let body = commit_response.text().await.unwrap_or_default();
            Err(Error::io_no_path(std::io::Error::other(format!(
                "LFS commit failed: {} - {}",
                status, body
            ))))
        }
    }
295
296    /// Uploads a RecordBatch as a parquet file.
297    #[cfg(feature = "hf-hub")]
298    pub async fn upload_batch(
299        &self,
300        path_in_repo: &str,
301        batch: &arrow::record_batch::RecordBatch,
302    ) -> Result<()> {
303        use parquet::arrow::ArrowWriter;
304
305        // Write batch to parquet in memory
306        let mut buffer = Vec::new();
307        {
308            let mut writer =
309                ArrowWriter::try_new(&mut buffer, batch.schema(), None).map_err(Error::Parquet)?;
310            writer.write(batch).map_err(Error::Parquet)?;
311            writer.close().map_err(Error::Parquet)?;
312        }
313
314        self.upload_file(path_in_repo, &buffer).await
315    }
316
317    /// Uploads a local parquet file to the repository.
318    #[cfg(feature = "hf-hub")]
319    pub async fn upload_parquet_file(&self, local_path: &Path, path_in_repo: &str) -> Result<()> {
320        let data = std::fs::read(local_path).map_err(|e| Error::io(e, local_path))?;
321        self.upload_file(path_in_repo, &data).await
322    }
323
324    /// Synchronous wrapper for creating repo (for CLI use).
325    #[cfg(all(feature = "http", feature = "tokio-runtime"))]
326    pub fn create_repo_sync(&self) -> Result<()> {
327        tokio::runtime::Runtime::new()
328            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
329            .block_on(self.create_repo())
330    }
331
332    /// Synchronous wrapper for uploading file (for CLI use).
333    #[cfg(all(feature = "hf-hub", feature = "tokio-runtime"))]
334    pub fn upload_file_sync(&self, path_in_repo: &str, data: &[u8]) -> Result<()> {
335        tokio::runtime::Runtime::new()
336            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
337            .block_on(self.upload_file(path_in_repo, data))
338    }
339
340    /// Synchronous wrapper for uploading parquet file (for CLI use).
341    #[cfg(all(feature = "hf-hub", feature = "tokio-runtime"))]
342    pub fn upload_parquet_file_sync(&self, local_path: &Path, path_in_repo: &str) -> Result<()> {
343        tokio::runtime::Runtime::new()
344            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
345            .block_on(self.upload_parquet_file(local_path, path_in_repo))
346    }
347
348    /// Uploads a README.md with validation.
349    ///
350    /// Validates the dataset card metadata before upload to catch issues like
351    /// invalid `task_categories` before they cause HuggingFace warnings.
352    ///
353    /// # Errors
354    ///
355    /// Returns an error if validation fails or upload fails.
356    #[cfg(feature = "hf-hub")]
357    pub async fn upload_readme_validated(&self, content: &str) -> Result<()> {
358        super::validation::DatasetCardValidator::validate_readme_strict(content)?;
359        self.upload_file("README.md", content.as_bytes()).await
360    }
361
362    /// Synchronous wrapper for validated README upload.
363    #[cfg(all(feature = "hf-hub", feature = "tokio-runtime"))]
364    pub fn upload_readme_validated_sync(&self, content: &str) -> Result<()> {
365        tokio::runtime::Runtime::new()
366            .map_err(|e| Error::io_no_path(std::io::Error::other(e)))?
367            .block_on(self.upload_readme_validated(content))
368    }
369}
370
/// Builder for HfPublisher with fluent interface.
///
/// Unlike [`HfPublisher::new`], the builder does not read `HF_TOKEN` up
/// front; the environment fallback happens in `build()` only when no token
/// was set explicitly.
#[derive(Debug, Clone)]
pub struct HfPublisherBuilder {
    // Target repository ID, e.g. "org/name".
    repo_id: String,
    // Explicit API token; `None` defers to the HF_TOKEN env var at build time.
    token: Option<String>,
    // Whether the dataset should be created as private.
    private: bool,
    // Commit message applied to uploads from the built publisher.
    commit_message: String,
}
379
380impl HfPublisherBuilder {
381    /// Creates a new builder.
382    pub fn new(repo_id: impl Into<String>) -> Self {
383        Self {
384            repo_id: repo_id.into(),
385            token: None,
386            private: false,
387            commit_message: "Upload via alimentar".to_string(),
388        }
389    }
390
391    /// Sets the token.
392    #[must_use]
393    pub fn token(mut self, token: impl Into<String>) -> Self {
394        self.token = Some(token.into());
395        self
396    }
397
398    /// Sets private flag.
399    #[must_use]
400    pub fn private(mut self, private: bool) -> Self {
401        self.private = private;
402        self
403    }
404
405    /// Sets commit message.
406    #[must_use]
407    pub fn commit_message(mut self, message: impl Into<String>) -> Self {
408        self.commit_message = message.into();
409        self
410    }
411
412    /// Builds the publisher.
413    pub fn build(self) -> HfPublisher {
414        HfPublisher {
415            repo_id: self.repo_id,
416            token: self.token.or_else(|| std::env::var("HF_TOKEN").ok()),
417            private: self.private,
418            commit_message: self.commit_message,
419        }
420    }
421}
422
423// ============================================================================
424// NDJSON Upload Payload Builder
425// ============================================================================
426
/// Builds an NDJSON payload for the HuggingFace Hub commit API.
///
/// The HuggingFace commit API uses NDJSON (Newline-Delimited JSON) format:
/// - Line 1: Header with commit message
/// - Line 2+: File operations with base64-encoded content
///
/// # Arguments
///
/// * `commit_message` - The commit summary message
/// * `path_in_repo` - The file path within the repository
/// * `data` - The raw file content to upload
///
/// # Returns
///
/// A string containing the NDJSON payload ready for upload.
///
/// # Example
///
/// ```ignore
/// let payload = build_ndjson_upload_payload(
///     "Upload training data",
///     "train.parquet",
///     &parquet_bytes
/// );
/// ```
#[cfg(feature = "hf-hub")]
pub fn build_ndjson_upload_payload(
    commit_message: &str,
    path_in_repo: &str,
    data: &[u8],
) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine};

    // Line 1: commit header with summary and (empty) description.
    let header = serde_json::json!({
        "key": "header",
        "value": {
            "summary": commit_message,
            "description": ""
        }
    });

    // Line 2: the file operation, content carried inline as base64.
    let file_op = serde_json::json!({
        "key": "file",
        "value": {
            "content": STANDARD.encode(data),
            "path": path_in_repo,
            "encoding": "base64"
        }
    });

    let mut payload = header.to_string();
    payload.push('\n');
    payload.push_str(&file_op.to_string());
    payload
}
481
482// ============================================================================
483// LFS Upload Support for Binary Files
484// ============================================================================
485
/// Binary file extensions that require LFS upload.
const BINARY_EXTENSIONS: &[&str] = &[
    "parquet",
    "arrow",
    "bin",
    "safetensors",
    "pt",
    "pth",
    "onnx",
    "png",
    "jpg",
    "jpeg",
    "gif",
    "webp",
    "bmp",
    "tiff",
    "mp3",
    "wav",
    "flac",
    "ogg",
    "mp4",
    "webm",
    "avi",
    "mkv",
    "zip",
    "tar",
    "gz",
    "bz2",
    "xz",
    "7z",
    "rar",
    "pdf",
    "doc",
    "docx",
    "xls",
    "xlsx",
    "npy",
    "npz",
    "h5",
    "hdf5",
    "pkl",
    "pickle",
];

/// Checks if a file path is a binary file that requires LFS upload.
///
/// HuggingFace Hub requires binary files to be uploaded via LFS/XET storage.
/// This function detects common binary file extensions, case-insensitively.
///
/// Paths with no extension (e.g. `"LICENSE"`, or a bare name such as
/// `"parquet"`) and dotfiles (e.g. `".gitignore"`) are treated as text.
/// Using `Path::extension` avoids the misclassification a naive
/// `rsplit('.')` produces for such names, where the whole remaining string
/// was compared against the extension list.
pub fn is_binary_file(path: &str) -> bool {
    Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| BINARY_EXTENSIONS.contains(&ext.to_lowercase().as_str()))
}
540
/// Computes SHA256 hash of data for LFS.
///
/// LFS uses lowercase-hex SHA256 hashes as object identifiers (OIDs).
#[cfg(feature = "hf-hub")]
pub fn compute_sha256(data: &[u8]) -> String {
    use sha2::{Digest, Sha256};

    // One-shot digest; equivalent to new() + update() + finalize().
    hex::encode(Sha256::digest(data))
}
552
/// Builds a preupload request for the HuggingFace LFS API.
///
/// # Arguments
///
/// * `path` - The file path in the repository
/// * `data` - The binary file content
///
/// # Returns
///
/// JSON string for the preupload API request.
#[cfg(feature = "hf-hub")]
pub fn build_lfs_preupload_request(path: &str, data: &[u8]) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine};

    // The API expects a base64-encoded sample of at most the first 512 bytes.
    let head = &data[..data.len().min(512)];

    serde_json::json!({
        "files": [{
            "path": path,
            "size": data.len(),
            "sample": STANDARD.encode(head)
        }]
    })
    .to_string()
}
581
/// Builds a request for the LFS batch API.
///
/// The LFS batch API is the Git LFS standard endpoint for uploading large
/// files. It returns presigned S3 URLs for actual binary upload.
///
/// # Arguments
///
/// * `oid` - The SHA256 hash of the file content
/// * `size` - The file size in bytes
///
/// # Returns
///
/// JSON string for the LFS batch API request.
#[cfg(feature = "hf-hub")]
pub fn build_lfs_batch_request(oid: &str, size: usize) -> String {
    // Single-object upload request using the "basic" transfer adapter.
    serde_json::json!({
        "operation": "upload",
        "transfers": ["basic"],
        "objects": [{
            "oid": oid,
            "size": size
        }]
    })
    .to_string()
}
608
/// Builds an NDJSON commit payload for LFS files.
///
/// Unlike regular files (which use base64 content), LFS files use
/// the `lfsFile` key with SHA256 OID and size.
///
/// # Arguments
///
/// * `commit_message` - The commit summary message
/// * `path_in_repo` - The file path within the repository
/// * `oid` - The SHA256 hash of the file content
/// * `size` - The file size in bytes
#[cfg(feature = "hf-hub")]
pub fn build_ndjson_lfs_commit(
    commit_message: &str,
    path_in_repo: &str,
    oid: &str,
    size: usize,
) -> String {
    // Line 1: commit header (same shape as regular commits).
    let header = serde_json::json!({
        "key": "header",
        "value": {
            "summary": commit_message,
            "description": ""
        }
    });

    // Line 2: LFS pointer — OID and size instead of inline content.
    let lfs_op = serde_json::json!({
        "key": "lfsFile",
        "value": {
            "path": path_in_repo,
            "algo": "sha256",
            "oid": oid,
            "size": size
        }
    });

    [header.to_string(), lfs_op.to_string()].join("\n")
}